From 01b4ba510c43cdc4c49309288f713742de6de9ef Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Wed, 28 Aug 2024 11:39:39 -0400 Subject: [PATCH 01/51] First putting in MACE --- hydragnn/models/MACEStack.py | 677 +++++++++++ hydragnn/utils/mace_utils/data/__init__.py | 34 + hydragnn/utils/mace_utils/data/atomic_data.py | 227 ++++ .../utils/mace_utils/data/hdf5_dataset.py | 86 ++ .../utils/mace_utils/data/neighborhood.py | 66 + hydragnn/utils/mace_utils/data/utils.py | 393 ++++++ hydragnn/utils/mace_utils/modules/__init__.py | 109 ++ hydragnn/utils/mace_utils/modules/blocks.py | 758 ++++++++++++ .../utils/mace_utils/modules/irreps_tools.py | 86 ++ hydragnn/utils/mace_utils/modules/loss.py | 367 ++++++ hydragnn/utils/mace_utils/modules/models.py | 1065 +++++++++++++++++ hydragnn/utils/mace_utils/modules/radial.py | 323 +++++ .../modules/symmetric_contraction.py | 233 ++++ hydragnn/utils/mace_utils/modules/utils.py | 414 +++++++ hydragnn/utils/mace_utils/tools/__init__.py | 72 ++ hydragnn/utils/mace_utils/tools/arg_parser.py | 792 ++++++++++++ .../mace_utils/tools/arg_parser_tools.py | 113 ++ hydragnn/utils/mace_utils/tools/cg.py | 131 ++ hydragnn/utils/mace_utils/tools/checkpoint.py | 227 ++++ hydragnn/utils/mace_utils/tools/compile.py | 95 ++ .../mace_utils/tools/finetuning_utils.py | 149 +++ hydragnn/utils/mace_utils/tools/scatter.py | 112 ++ .../utils/mace_utils/tools/scripts_utils.py | 653 ++++++++++ .../mace_utils/tools/slurm_distributed.py | 34 + .../tools/torch_geometric/README.md | 12 + .../tools/torch_geometric/__init__.py | 7 + .../mace_utils/tools/torch_geometric/batch.py | 257 ++++ .../mace_utils/tools/torch_geometric/data.py | 441 +++++++ .../tools/torch_geometric/dataloader.py | 87 ++ .../tools/torch_geometric/dataset.py | 280 +++++ .../mace_utils/tools/torch_geometric/seed.py | 17 + .../mace_utils/tools/torch_geometric/utils.py | 54 + .../utils/mace_utils/tools/torch_tools.py | 138 +++ hydragnn/utils/mace_utils/tools/train.py | 524 ++++++++ hydragnn/utils/mace_utils/tools/utils.py | 168 +++ 35 files changed, 9201 insertions(+) create mode 100644 hydragnn/models/MACEStack.py create mode 100644 hydragnn/utils/mace_utils/data/__init__.py create mode 100644 hydragnn/utils/mace_utils/data/atomic_data.py create mode 100644 hydragnn/utils/mace_utils/data/hdf5_dataset.py create mode 100644 hydragnn/utils/mace_utils/data/neighborhood.py create mode 100644 hydragnn/utils/mace_utils/data/utils.py create mode 100644 hydragnn/utils/mace_utils/modules/__init__.py create mode 100644 hydragnn/utils/mace_utils/modules/blocks.py create mode 100644 hydragnn/utils/mace_utils/modules/irreps_tools.py create mode 100644 hydragnn/utils/mace_utils/modules/loss.py create mode 100644 hydragnn/utils/mace_utils/modules/models.py create mode 100644 hydragnn/utils/mace_utils/modules/radial.py create mode 100644 hydragnn/utils/mace_utils/modules/symmetric_contraction.py create mode 100644 hydragnn/utils/mace_utils/modules/utils.py create mode 100644 hydragnn/utils/mace_utils/tools/__init__.py create mode 100644 hydragnn/utils/mace_utils/tools/arg_parser.py create mode 100644 hydragnn/utils/mace_utils/tools/arg_parser_tools.py create mode 100644 hydragnn/utils/mace_utils/tools/cg.py create mode 100644 hydragnn/utils/mace_utils/tools/checkpoint.py create mode 100644 hydragnn/utils/mace_utils/tools/compile.py create mode 100644 hydragnn/utils/mace_utils/tools/finetuning_utils.py create mode 100644 hydragnn/utils/mace_utils/tools/scatter.py create mode 100644 hydragnn/utils/mace_utils/tools/scripts_utils.py 
create mode 100644 hydragnn/utils/mace_utils/tools/slurm_distributed.py create mode 100644 hydragnn/utils/mace_utils/tools/torch_geometric/README.md create mode 100644 hydragnn/utils/mace_utils/tools/torch_geometric/__init__.py create mode 100644 hydragnn/utils/mace_utils/tools/torch_geometric/batch.py create mode 100644 hydragnn/utils/mace_utils/tools/torch_geometric/data.py create mode 100644 hydragnn/utils/mace_utils/tools/torch_geometric/dataloader.py create mode 100644 hydragnn/utils/mace_utils/tools/torch_geometric/dataset.py create mode 100644 hydragnn/utils/mace_utils/tools/torch_geometric/seed.py create mode 100644 hydragnn/utils/mace_utils/tools/torch_geometric/utils.py create mode 100644 hydragnn/utils/mace_utils/tools/torch_tools.py create mode 100644 hydragnn/utils/mace_utils/tools/train.py create mode 100644 hydragnn/utils/mace_utils/tools/utils.py diff --git a/hydragnn/models/MACEStack.py b/hydragnn/models/MACEStack.py new file mode 100644 index 000000000..2537f945f --- /dev/null +++ b/hydragnn/models/MACEStack.py @@ -0,0 +1,677 @@ +############################################################################## +# Copyright (c) 2024, Oak Ridge National Laboratory # +# All rights reserved. # +# # +# This file is part of HydraGNN and is distributed under a BSD 3-clause # +# license. For the licensing terms see the LICENSE file in the top-level # +# directory. # +# # +# SPDX-License-Identifier: BSD-3-Clause # +############################################################################## + +# Adapted From: +# GitHub: https://github.com/ACEsuit/mace +# ArXiV: https://arxiv.org/pdf/2206.07697 +########################################################################################### +# Implementation of MACE models and other models based E(3)-Equivariant MPNNs +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +from typing import Any, Callable, Dict, List, Optional, Type, Union + +import numpy as np +import torch +from e3nn import o3 +from e3nn.util.jit import compile_mode +from torch_geometric.nn import ( + Sequential as PyGSequential, +) # This naming is because there is torch.nn.Sequential and torch_geometric.nn.Sequential + + +# Mace +from hydragnn.utils.mace_utils.data import AtomicData +from hydragnn.utils.mace_utils.modules import ZBLBasis +from hydragnn.utils.mace_utils.tools.scatter import scatter_sum + +from hydragnn.utils.mace_utils.modules.blocks import ( + AtomicEnergiesBlock, + EquivariantProductBasisBlock, + InteractionBlock, + LinearDipoleReadoutBlock, + LinearNodeEmbeddingBlock, + LinearReadoutBlock, + NonLinearDipoleReadoutBlock, + NonLinearReadoutBlock, + RadialEmbeddingBlock, + ScaleShiftBlock, +) +from hydragnn.utils.mace_utils.modules.utils import ( + compute_fixed_charge_dipole, + compute_forces, + get_edge_vectors_and_lengths, + get_outputs, + get_symmetric_displacement, +) + +# HydraGNN +from .Base import Base + +# pylint: disable=C0302 + + +@compile_mode("script") +class MACEStack(Base): + def __init__( + self, + radius: float, + num_radial: int, + irreps_cutoff: int, # What is this for? 
+        num_polynomial_cutoff: int,  # Order p of the polynomial cutoff envelope
+        max_ell: int,  # Max l for the spherical harmonics / CG tensor products
+        interaction_cls: Type[InteractionBlock],
+        interaction_cls_first: Type[InteractionBlock],
+        num_interactions: int,
+        num_elements: int,
+        hidden_irreps: o3.Irreps,
+        MLP_irreps: o3.Irreps,
+        atomic_energies: np.ndarray,
+        avg_num_neighbors: float,
+        atomic_numbers: List[int],
+        correlation: Union[int, List[int]],
+        gate: Optional[Callable],
+        pair_repulsion: bool = False,
+        distance_transform: str = "None",
+        radial_MLP: Optional[List[int]] = None,
+        radial_type: Optional[str] = "bessel",
+        *args,
+        **kwargs,
+    ):
+        self.num_radial = num_radial
+        self.radius = radius
+        self.irreps_cutoff = irreps_cutoff
+        self.num_polynomial_cutoff = num_polynomial_cutoff
+        self.max_ell = max_ell
+        self.interaction_cls = interaction_cls
+        self.interaction_cls_first = interaction_cls_first
+        self.num_interactions = num_interactions
+        self.num_elements = num_elements
+        self.hidden_irreps = hidden_irreps
+        self.MLP_irreps = MLP_irreps
+        self.atomic_energies = atomic_energies
+        self.avg_num_neighbors = avg_num_neighbors
+        self.atomic_numbers = atomic_numbers
+        self.correlation = correlation
+        self.gate = gate
+        self.pair_repulsion = pair_repulsion
+        self.distance_transform = distance_transform
+        self.radial_MLP = radial_MLP
+        self.radial_type = radial_type
+
+        super().__init__(*args, **kwargs)
+
+    def get_conv(self, input_dim, output_dim):
+        conv = MACEConv(
+            r_max=self.radius,
+            num_bessel=self.num_radial,
+            num_polynomial_cutoff=self.num_polynomial_cutoff,
+            max_ell=self.max_ell,
+            interaction_cls=self.interaction_cls,
+            interaction_cls_first=self.interaction_cls_first,
+            num_interactions=self.num_interactions,
+            num_elements=self.num_elements,
+            hidden_irreps=self.hidden_irreps,
+            MLP_irreps=self.MLP_irreps,
+            atomic_energies=self.atomic_energies,
+            avg_num_neighbors=self.avg_num_neighbors,
+            atomic_numbers=self.atomic_numbers,
+            correlation=self.correlation,
+            gate=self.gate,
+            pair_repulsion=self.pair_repulsion,
+            distance_transform=self.distance_transform,
+            radial_MLP=self.radial_MLP,
+            radial_type=self.radial_type,
+        )
+
+        input_args = "x, pos, edge_index, rbf"
+        conv_args = "x, edge_index, rbf"
+
+        if self.use_edge_attr:
+            input_args += ", edge_attr"
+            conv_args += ", edge_attr"
+
+        return PyGSequential(
+            input_args,
+            [
+                (conv, conv_args + " -> x"),
+                (lambda x, pos: [x, pos], "x, pos -> x, pos"),
+            ],
+        )
+
+    def _conv_args(self, data):
+        assert (
+            data.pos is not None
+        ), "MACE requires node positions (data.pos) to be set."
+
+        j, i = data.edge_index  # j->i
+        dist = (data.pos[i] - data.pos[j]).pow(2).sum(dim=-1).sqrt()
+        rbf = self.rbf(dist)  # radial basis expansion of the edge lengths
+        conv_args = {"edge_index": data.edge_index.to(torch.long), "rbf": rbf}
+
+        if self.use_edge_attr:
+            assert (
+                data.edge_attr is not None
+            ), "Data must have edge attributes if use_edge_attributes is set."
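+            # edge_attr is simply forwarded to the wrapped convolution; MACEConv
+            # constructs its spherical-harmonic edge attributes from the edge
+            # vectors inside its own forward().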
+            conv_args.update({"edge_attr": data.edge_attr})
+
+        return conv_args
+
+    def __str__(self):
+        return "MACEStack"
+
+
+@compile_mode("script")
+class MACEConv(torch.nn.Module):
+    def __init__(
+        self,
+        r_max: float,
+        num_bessel: int,
+        num_polynomial_cutoff: int,
+        max_ell: int,
+        interaction_cls: Type[InteractionBlock],
+        interaction_cls_first: Type[InteractionBlock],
+        num_interactions: int,
+        num_elements: int,
+        hidden_irreps: o3.Irreps,
+        MLP_irreps: o3.Irreps,
+        atomic_energies: np.ndarray,
+        avg_num_neighbors: float,
+        atomic_numbers: List[int],
+        correlation: Union[int, List[int]],
+        gate: Optional[Callable],
+        pair_repulsion: bool = False,
+        distance_transform: str = "None",
+        radial_MLP: Optional[List[int]] = None,
+        radial_type: Optional[str] = "bessel",
+    ):
+        super().__init__()
+        # Buffers are registered when values must be saved and transferred with the
+        # model (e.g. in the state dict and across devices), but never trained.
+        self.register_buffer(
+            "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64)
+        )
+        self.register_buffer(
+            "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype())
+        )
+        self.register_buffer(
+            "num_interactions", torch.tensor(num_interactions, dtype=torch.int64)
+        )
+        if isinstance(correlation, int):
+            correlation = [correlation] * num_interactions
+        # Embedding
+        node_attr_irreps = o3.Irreps([(num_elements, (0, 1))])
+        node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))])
+        self.node_embedding = LinearNodeEmbeddingBlock(
+            irreps_in=node_attr_irreps, irreps_out=node_feats_irreps
+        )
+        self.radial_embedding = RadialEmbeddingBlock(
+            r_max=r_max,
+            num_bessel=num_bessel,
+            num_polynomial_cutoff=num_polynomial_cutoff,
+            radial_type=radial_type,
+            distance_transform=distance_transform,
+        )
+        edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e")
+        if pair_repulsion:
+            self.pair_repulsion_fn = ZBLBasis(r_max=r_max, p=num_polynomial_cutoff)
+            self.pair_repulsion = True
+
+        sh_irreps = o3.Irreps.spherical_harmonics(max_ell)  # Irreps of the spherical harmonics up to order max_ell
+        num_features = hidden_irreps.count(o3.Irrep(0, 1))  # One copy of the spherical harmonics per scalar (0e) channel of hidden_irreps; the copies are merged during .simplify()
+        interaction_irreps = (sh_irreps * num_features).sort()[0].simplify()  # .sort() returns a namedtuple, so [0] extracts the sorted irreps
+        self.spherical_harmonics = o3.SphericalHarmonics(
+            sh_irreps, normalize=True, normalization="component"  # Module form, so the harmonics can be evaluated in forward()
+        )
+        if radial_MLP is None:
+            radial_MLP = [64, 64, 64]
+        # Interactions and readout
+        self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies)  # For atom ground-state energy.
It takes a one-hot encoding of atom types and returns the energy of each atom type + + inter = interaction_cls_first( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=node_feats_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps, + avg_num_neighbors=avg_num_neighbors, + radial_MLP=radial_MLP, + ) + self.interactions = torch.nn.ModuleList([inter]) + + # Use the appropriate self connection at the first layer for proper E0 + use_sc_first = False + if "Residual" in str(interaction_cls_first): + use_sc_first = True + + node_feats_irreps_out = inter.target_irreps + prod = EquivariantProductBasisBlock( + node_feats_irreps=node_feats_irreps_out, + target_irreps=hidden_irreps, + correlation=correlation[0], + num_elements=num_elements, + use_sc=use_sc_first, + ) + self.products = torch.nn.ModuleList([prod]) + + self.readouts = torch.nn.ModuleList() + self.readouts.append(LinearReadoutBlock(hidden_irreps)) + + for i in range(num_interactions - 1): + if i == num_interactions - 2: + hidden_irreps_out = str( + hidden_irreps[0] + ) # Select only scalars for last layer + else: + hidden_irreps_out = hidden_irreps + inter = interaction_cls( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=hidden_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps_out, + avg_num_neighbors=avg_num_neighbors, + radial_MLP=radial_MLP, + ) + self.interactions.append(inter) + prod = EquivariantProductBasisBlock( + node_feats_irreps=interaction_irreps, + target_irreps=hidden_irreps_out, + correlation=correlation[i + 1], + num_elements=num_elements, + use_sc=True, + ) + self.products.append(prod) + if i == num_interactions - 2: + self.readouts.append( + NonLinearReadoutBlock(hidden_irreps_out, MLP_irreps, gate) + ) + else: + self.readouts.append(LinearReadoutBlock(hidden_irreps)) + + def forward( + self, + data: Dict[str, torch.Tensor], + training: bool = False, + compute_force: bool = True, + compute_virials: bool = False, + compute_stress: bool = False, + compute_displacement: bool = False, + compute_hessian: bool = False, + ) -> Dict[str, Optional[torch.Tensor]]: + # Setup + data["node_attrs"].requires_grad_(True) + data["positions"].requires_grad_(True) + num_graphs = data["ptr"].numel() - 1 + displacement = torch.zeros( + (num_graphs, 3, 3), + dtype=data["positions"].dtype, + device=data["positions"].device, + ) + if compute_virials or compute_stress or compute_displacement: + ( + data["positions"], + data["shifts"], + displacement, + ) = get_symmetric_displacement( + positions=data["positions"], + unit_shifts=data["unit_shifts"], + cell=data["cell"], + edge_index=data["edge_index"], + num_graphs=num_graphs, + batch=data["batch"], + ) + + # Atomic energies + node_e0 = self.atomic_energies_fn(data["node_attrs"]) + e0 = scatter_sum( + src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs + ) # [n_graphs,] + # Embeddings + node_feats = self.node_embedding(data["node_attrs"]) + vectors, lengths = get_edge_vectors_and_lengths( + positions=data["positions"], + edge_index=data["edge_index"], + shifts=data["shifts"], + ) + edge_attrs = self.spherical_harmonics(vectors) + edge_feats = self.radial_embedding( + lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers + ) + if hasattr(self, "pair_repulsion"): + pair_node_energy = self.pair_repulsion_fn( + lengths, data["node_attrs"], data["edge_index"], 
self.atomic_numbers + ) + pair_energy = scatter_sum( + src=pair_node_energy, index=data["batch"], dim=-1, dim_size=num_graphs + ) # [n_graphs,] + else: + pair_node_energy = torch.zeros_like(node_e0) + pair_energy = torch.zeros_like(e0) + + # Interactions + energies = [e0, pair_energy] + node_energies_list = [node_e0, pair_node_energy] + node_feats_list = [] + for interaction, product, readout in zip( + self.interactions, self.products, self.readouts + ): + node_feats, sc = interaction( + node_attrs=data["node_attrs"], + node_feats=node_feats, + edge_attrs=edge_attrs, + edge_feats=edge_feats, + edge_index=data["edge_index"], + ) + node_feats = product( + node_feats=node_feats, + sc=sc, + node_attrs=data["node_attrs"], + ) + node_feats_list.append(node_feats) + node_energies = readout(node_feats).squeeze(-1) # [n_nodes, ] + energy = scatter_sum( + src=node_energies, index=data["batch"], dim=-1, dim_size=num_graphs + ) # [n_graphs,] + energies.append(energy) + node_energies_list.append(node_energies) + + # Concatenate node features + node_feats_out = torch.cat(node_feats_list, dim=-1) + + # Sum over energy contributions + contributions = torch.stack(energies, dim=-1) + total_energy = torch.sum(contributions, dim=-1) # [n_graphs, ] + node_energy_contributions = torch.stack(node_energies_list, dim=-1) + node_energy = torch.sum(node_energy_contributions, dim=-1) # [n_nodes, ] + + # Outputs + forces, virials, stress, hessian = get_outputs( + energy=total_energy, + positions=data["positions"], + displacement=displacement, + cell=data["cell"], + training=training, + compute_force=compute_force, + compute_virials=compute_virials, + compute_stress=compute_stress, + compute_hessian=compute_hessian, + ) + + return { + "energy": total_energy, + "node_energy": node_energy, + "contributions": contributions, + "forces": forces, + "virials": virials, + "stress": stress, + "displacement": displacement, + "hessian": hessian, + "node_feats": node_feats_out, + } + + + + + + +@compile_mode("script") +class MACE(torch.nn.Module): + def __init__( + self, + r_max: float, + num_bessel: int, + num_polynomial_cutoff: int, + max_ell: int, + interaction_cls: Type[InteractionBlock], + interaction_cls_first: Type[InteractionBlock], + num_interactions: int, + num_elements: int, + hidden_irreps: o3.Irreps, + MLP_irreps: o3.Irreps, + atomic_energies: np.ndarray, + avg_num_neighbors: float, + atomic_numbers: List[int], + correlation: Union[int, List[int]], + gate: Optional[Callable], + pair_repulsion: bool = False, + distance_transform: str = "None", + radial_MLP: Optional[List[int]] = None, + radial_type: Optional[str] = "bessel", + ): + super().__init__() + # Register buffers are made when parameters need to be saved and transferred with the model, but not trained. 
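+        # A buffer, unlike a torch.nn.Parameter, is written to state_dict() and moves
+        # with the module across .to(device) / .cuda() calls, yet optimizers never
+        # update it. A minimal sketch, with m a hypothetical torch.nn.Module:
+        #     m.register_buffer("r_max", torch.tensor(5.0))
+        #     "r_max" in dict(m.named_buffers())         # True
+        #     any(b is m.r_max for b in m.parameters())  # False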
+ self.register_buffer( + "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) + ) + self.register_buffer( + "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) + ) + self.register_buffer( + "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) + ) + if isinstance(correlation, int): + correlation = [correlation] * num_interactions + # Embedding + node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) + node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) + self.node_embedding = LinearNodeEmbeddingBlock( + irreps_in=node_attr_irreps, irreps_out=node_feats_irreps + ) + self.radial_embedding = RadialEmbeddingBlock( + r_max=r_max, + num_bessel=num_bessel, + num_polynomial_cutoff=num_polynomial_cutoff, + radial_type=radial_type, + distance_transform=distance_transform, + ) + edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") + if pair_repulsion: + self.pair_repulsion_fn = ZBLBasis(r_max=r_max, p=num_polynomial_cutoff) + self.pair_repulsion = True + + sh_irreps = o3.Irreps.spherical_harmonics(max_ell) # This makes the irreps string + num_features = hidden_irreps.count(o3.Irrep(0, 1)) # Multiple copies of spherical harmonics for multiple interactions. They are 'combined' in a certain way during .simplify() + interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() #.sort() is a tuple, so we need the [0] element for the sorted result + self.spherical_harmonics = o3.SphericalHarmonics( + sh_irreps, normalize=True, normalization="component" # This makes the spherical harmonic class to be called with forward + ) + if radial_MLP is None: + radial_MLP = [64, 64, 64] + # Interactions and readout + self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) # For atom ground-state energy. 
It takes a one-hot encoding of atom types and returns the energy of each atom type + + inter = interaction_cls_first( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=node_feats_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps, + avg_num_neighbors=avg_num_neighbors, + radial_MLP=radial_MLP, + ) + self.interactions = torch.nn.ModuleList([inter]) + + # Use the appropriate self connection at the first layer for proper E0 + use_sc_first = False + if "Residual" in str(interaction_cls_first): + use_sc_first = True + + node_feats_irreps_out = inter.target_irreps + prod = EquivariantProductBasisBlock( + node_feats_irreps=node_feats_irreps_out, + target_irreps=hidden_irreps, + correlation=correlation[0], + num_elements=num_elements, + use_sc=use_sc_first, + ) + self.products = torch.nn.ModuleList([prod]) + + self.readouts = torch.nn.ModuleList() + self.readouts.append(LinearReadoutBlock(hidden_irreps)) + + for i in range(num_interactions - 1): + if i == num_interactions - 2: + hidden_irreps_out = str( + hidden_irreps[0] + ) # Select only scalars for last layer + else: + hidden_irreps_out = hidden_irreps + inter = interaction_cls( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=hidden_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps_out, + avg_num_neighbors=avg_num_neighbors, + radial_MLP=radial_MLP, + ) + self.interactions.append(inter) + prod = EquivariantProductBasisBlock( + node_feats_irreps=interaction_irreps, + target_irreps=hidden_irreps_out, + correlation=correlation[i + 1], + num_elements=num_elements, + use_sc=True, + ) + self.products.append(prod) + if i == num_interactions - 2: + self.readouts.append( + NonLinearReadoutBlock(hidden_irreps_out, MLP_irreps, gate) + ) + else: + self.readouts.append(LinearReadoutBlock(hidden_irreps)) + + def forward( + self, + data: Dict[str, torch.Tensor], + training: bool = False, + compute_force: bool = True, + compute_virials: bool = False, + compute_stress: bool = False, + compute_displacement: bool = False, + compute_hessian: bool = False, + ) -> Dict[str, Optional[torch.Tensor]]: + # Setup + data["node_attrs"].requires_grad_(True) + data["positions"].requires_grad_(True) + num_graphs = data["ptr"].numel() - 1 + displacement = torch.zeros( + (num_graphs, 3, 3), + dtype=data["positions"].dtype, + device=data["positions"].device, + ) + if compute_virials or compute_stress or compute_displacement: + ( + data["positions"], + data["shifts"], + displacement, + ) = get_symmetric_displacement( + positions=data["positions"], + unit_shifts=data["unit_shifts"], + cell=data["cell"], + edge_index=data["edge_index"], + num_graphs=num_graphs, + batch=data["batch"], + ) + + # Atomic energies + node_e0 = self.atomic_energies_fn(data["node_attrs"]) + e0 = scatter_sum( + src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs + ) # [n_graphs,] + # Embeddings + node_feats = self.node_embedding(data["node_attrs"]) + vectors, lengths = get_edge_vectors_and_lengths( + positions=data["positions"], + edge_index=data["edge_index"], + shifts=data["shifts"], + ) + edge_attrs = self.spherical_harmonics(vectors) + edge_feats = self.radial_embedding( + lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers + ) + if hasattr(self, "pair_repulsion"): + pair_node_energy = self.pair_repulsion_fn( + lengths, data["node_attrs"], data["edge_index"], 
self.atomic_numbers + ) + pair_energy = scatter_sum( + src=pair_node_energy, index=data["batch"], dim=-1, dim_size=num_graphs + ) # [n_graphs,] + else: + pair_node_energy = torch.zeros_like(node_e0) + pair_energy = torch.zeros_like(e0) + + # Interactions + energies = [e0, pair_energy] + node_energies_list = [node_e0, pair_node_energy] + node_feats_list = [] + for interaction, product, readout in zip( + self.interactions, self.products, self.readouts + ): + node_feats, sc = interaction( + node_attrs=data["node_attrs"], + node_feats=node_feats, + edge_attrs=edge_attrs, + edge_feats=edge_feats, + edge_index=data["edge_index"], + ) + node_feats = product( + node_feats=node_feats, + sc=sc, + node_attrs=data["node_attrs"], + ) + node_feats_list.append(node_feats) + node_energies = readout(node_feats).squeeze(-1) # [n_nodes, ] + energy = scatter_sum( + src=node_energies, index=data["batch"], dim=-1, dim_size=num_graphs + ) # [n_graphs,] + energies.append(energy) + node_energies_list.append(node_energies) + + # Concatenate node features + node_feats_out = torch.cat(node_feats_list, dim=-1) + + # Sum over energy contributions + contributions = torch.stack(energies, dim=-1) + total_energy = torch.sum(contributions, dim=-1) # [n_graphs, ] + node_energy_contributions = torch.stack(node_energies_list, dim=-1) + node_energy = torch.sum(node_energy_contributions, dim=-1) # [n_nodes, ] + + # Outputs + forces, virials, stress, hessian = get_outputs( + energy=total_energy, + positions=data["positions"], + displacement=displacement, + cell=data["cell"], + training=training, + compute_force=compute_force, + compute_virials=compute_virials, + compute_stress=compute_stress, + compute_hessian=compute_hessian, + ) + + return { + "energy": total_energy, + "node_energy": node_energy, + "contributions": contributions, + "forces": forces, + "virials": virials, + "stress": stress, + "displacement": displacement, + "hessian": hessian, + "node_feats": node_feats_out, + } + diff --git a/hydragnn/utils/mace_utils/data/__init__.py b/hydragnn/utils/mace_utils/data/__init__.py new file mode 100644 index 000000000..c10a36982 --- /dev/null +++ b/hydragnn/utils/mace_utils/data/__init__.py @@ -0,0 +1,34 @@ +from .atomic_data import AtomicData +from .hdf5_dataset import HDF5Dataset, dataset_from_sharded_hdf5 +from .neighborhood import get_neighborhood +from .utils import ( + Configuration, + Configurations, + compute_average_E0s, + config_from_atoms, + config_from_atoms_list, + load_from_xyz, + random_train_valid_split, + save_AtomicData_to_HDF5, + save_configurations_as_HDF5, + save_dataset_as_HDF5, + test_config_types, +) + +__all__ = [ + "get_neighborhood", + "Configuration", + "Configurations", + "random_train_valid_split", + "load_from_xyz", + "test_config_types", + "config_from_atoms", + "config_from_atoms_list", + "AtomicData", + "compute_average_E0s", + "save_dataset_as_HDF5", + "HDF5Dataset", + "dataset_from_sharded_hdf5", + "save_AtomicData_to_HDF5", + "save_configurations_as_HDF5", +] diff --git a/hydragnn/utils/mace_utils/data/atomic_data.py b/hydragnn/utils/mace_utils/data/atomic_data.py new file mode 100644 index 000000000..edb91b14c --- /dev/null +++ b/hydragnn/utils/mace_utils/data/atomic_data.py @@ -0,0 +1,227 @@ +########################################################################################### +# Atomic Data Class for handling molecules as graphs +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) 
+########################################################################################### + +from typing import Optional, Sequence + +import torch.utils.data + +from mace.tools import ( + AtomicNumberTable, + atomic_numbers_to_indices, + to_one_hot, + torch_geometric, + voigt_to_matrix, +) + +from .neighborhood import get_neighborhood +from .utils import Configuration + + +class AtomicData(torch_geometric.data.Data): + num_graphs: torch.Tensor + batch: torch.Tensor + edge_index: torch.Tensor + node_attrs: torch.Tensor + edge_vectors: torch.Tensor + edge_lengths: torch.Tensor + positions: torch.Tensor + shifts: torch.Tensor + unit_shifts: torch.Tensor + cell: torch.Tensor + forces: torch.Tensor + energy: torch.Tensor + stress: torch.Tensor + virials: torch.Tensor + dipole: torch.Tensor + charges: torch.Tensor + weight: torch.Tensor + energy_weight: torch.Tensor + forces_weight: torch.Tensor + stress_weight: torch.Tensor + virials_weight: torch.Tensor + + def __init__( + self, + edge_index: torch.Tensor, # [2, n_edges] + node_attrs: torch.Tensor, # [n_nodes, n_node_feats] + positions: torch.Tensor, # [n_nodes, 3] + shifts: torch.Tensor, # [n_edges, 3], + unit_shifts: torch.Tensor, # [n_edges, 3] + cell: Optional[torch.Tensor], # [3,3] + weight: Optional[torch.Tensor], # [,] + energy_weight: Optional[torch.Tensor], # [,] + forces_weight: Optional[torch.Tensor], # [,] + stress_weight: Optional[torch.Tensor], # [,] + virials_weight: Optional[torch.Tensor], # [,] + forces: Optional[torch.Tensor], # [n_nodes, 3] + energy: Optional[torch.Tensor], # [, ] + stress: Optional[torch.Tensor], # [1,3,3] + virials: Optional[torch.Tensor], # [1,3,3] + dipole: Optional[torch.Tensor], # [, 3] + charges: Optional[torch.Tensor], # [n_nodes, ] + ): + # Check shapes + num_nodes = node_attrs.shape[0] + + assert edge_index.shape[0] == 2 and len(edge_index.shape) == 2 + assert positions.shape == (num_nodes, 3) + assert shifts.shape[1] == 3 + assert unit_shifts.shape[1] == 3 + assert len(node_attrs.shape) == 2 + assert weight is None or len(weight.shape) == 0 + assert energy_weight is None or len(energy_weight.shape) == 0 + assert forces_weight is None or len(forces_weight.shape) == 0 + assert stress_weight is None or len(stress_weight.shape) == 0 + assert virials_weight is None or len(virials_weight.shape) == 0 + assert cell is None or cell.shape == (3, 3) + assert forces is None or forces.shape == (num_nodes, 3) + assert energy is None or len(energy.shape) == 0 + assert stress is None or stress.shape == (1, 3, 3) + assert virials is None or virials.shape == (1, 3, 3) + assert dipole is None or dipole.shape[-1] == 3 + assert charges is None or charges.shape == (num_nodes,) + # Aggregate data + data = { + "num_nodes": num_nodes, + "edge_index": edge_index, + "positions": positions, + "shifts": shifts, + "unit_shifts": unit_shifts, + "cell": cell, + "node_attrs": node_attrs, + "weight": weight, + "energy_weight": energy_weight, + "forces_weight": forces_weight, + "stress_weight": stress_weight, + "virials_weight": virials_weight, + "forces": forces, + "energy": energy, + "stress": stress, + "virials": virials, + "dipole": dipole, + "charges": charges, + } + super().__init__(**data) + + @classmethod + def from_config( + cls, config: Configuration, z_table: AtomicNumberTable, cutoff: float + ) -> "AtomicData": + edge_index, shifts, unit_shifts = get_neighborhood( + positions=config.positions, cutoff=cutoff, pbc=config.pbc, cell=config.cell + ) + indices = atomic_numbers_to_indices(config.atomic_numbers, 
z_table=z_table) + one_hot = to_one_hot( + torch.tensor(indices, dtype=torch.long).unsqueeze(-1), + num_classes=len(z_table), + ) + + cell = ( + torch.tensor(config.cell, dtype=torch.get_default_dtype()) + if config.cell is not None + else torch.tensor( + 3 * [0.0, 0.0, 0.0], dtype=torch.get_default_dtype() + ).view(3, 3) + ) + + weight = ( + torch.tensor(config.weight, dtype=torch.get_default_dtype()) + if config.weight is not None + else 1 + ) + + energy_weight = ( + torch.tensor(config.energy_weight, dtype=torch.get_default_dtype()) + if config.energy_weight is not None + else 1 + ) + + forces_weight = ( + torch.tensor(config.forces_weight, dtype=torch.get_default_dtype()) + if config.forces_weight is not None + else 1 + ) + + stress_weight = ( + torch.tensor(config.stress_weight, dtype=torch.get_default_dtype()) + if config.stress_weight is not None + else 1 + ) + + virials_weight = ( + torch.tensor(config.virials_weight, dtype=torch.get_default_dtype()) + if config.virials_weight is not None + else 1 + ) + + forces = ( + torch.tensor(config.forces, dtype=torch.get_default_dtype()) + if config.forces is not None + else None + ) + energy = ( + torch.tensor(config.energy, dtype=torch.get_default_dtype()) + if config.energy is not None + else None + ) + stress = ( + voigt_to_matrix( + torch.tensor(config.stress, dtype=torch.get_default_dtype()) + ).unsqueeze(0) + if config.stress is not None + else None + ) + virials = ( + voigt_to_matrix( + torch.tensor(config.virials, dtype=torch.get_default_dtype()) + ).unsqueeze(0) + if config.virials is not None + else None + ) + dipole = ( + torch.tensor(config.dipole, dtype=torch.get_default_dtype()).unsqueeze(0) + if config.dipole is not None + else None + ) + charges = ( + torch.tensor(config.charges, dtype=torch.get_default_dtype()) + if config.charges is not None + else None + ) + + return cls( + edge_index=torch.tensor(edge_index, dtype=torch.long), + positions=torch.tensor(config.positions, dtype=torch.get_default_dtype()), + shifts=torch.tensor(shifts, dtype=torch.get_default_dtype()), + unit_shifts=torch.tensor(unit_shifts, dtype=torch.get_default_dtype()), + cell=cell, + node_attrs=one_hot, + weight=weight, + energy_weight=energy_weight, + forces_weight=forces_weight, + stress_weight=stress_weight, + virials_weight=virials_weight, + forces=forces, + energy=energy, + stress=stress, + virials=virials, + dipole=dipole, + charges=charges, + ) + + +def get_data_loader( + dataset: Sequence[AtomicData], + batch_size: int, + shuffle=True, + drop_last=False, +) -> torch.utils.data.DataLoader: + return torch_geometric.dataloader.DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=shuffle, + drop_last=drop_last, + ) diff --git a/hydragnn/utils/mace_utils/data/hdf5_dataset.py b/hydragnn/utils/mace_utils/data/hdf5_dataset.py new file mode 100644 index 000000000..5057fd7f1 --- /dev/null +++ b/hydragnn/utils/mace_utils/data/hdf5_dataset.py @@ -0,0 +1,86 @@ +from glob import glob +from typing import List + +import h5py +from torch.utils.data import ConcatDataset, Dataset + +from mace.data.atomic_data import AtomicData +from mace.data.utils import Configuration +from mace.tools.utils import AtomicNumberTable + + +class HDF5Dataset(Dataset): + def __init__(self, file_path, r_max, z_table, **kwargs): + super(HDF5Dataset, self).__init__() # pylint: disable=super-with-arguments + self.file_path = file_path + self._file = None + batch_key = list(self.file.keys())[0] + self.batch_size = len(self.file[batch_key].keys()) + self.length = 
len(self.file.keys()) * self.batch_size
+        self.r_max = r_max
+        self.z_table = z_table
+        try:
+            self.drop_last = bool(self.file.attrs["drop_last"])
+        except KeyError:
+            self.drop_last = False
+        self.kwargs = kwargs
+
+    @property
+    def file(self):
+        if self._file is None:
+            # If a file has not already been opened, open one here
+            self._file = h5py.File(self.file_path, "r")
+        return self._file
+
+    def __getstate__(self):
+        _d = dict(self.__dict__)
+
+        # An opened h5py.File cannot be pickled, so we must exclude it from the state
+        _d["_file"] = None
+        return _d
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, index):
+        # Compute the batch and the position of the configuration within the batch
+        batch_index = index // self.batch_size
+        config_index = index % self.batch_size
+        grp = self.file["config_batch_" + str(batch_index)]
+        subgrp = grp["config_" + str(config_index)]
+        config = Configuration(
+            atomic_numbers=subgrp["atomic_numbers"][()],
+            positions=subgrp["positions"][()],
+            energy=unpack_value(subgrp["energy"][()]),
+            forces=unpack_value(subgrp["forces"][()]),
+            stress=unpack_value(subgrp["stress"][()]),
+            virials=unpack_value(subgrp["virials"][()]),
+            dipole=unpack_value(subgrp["dipole"][()]),
+            charges=unpack_value(subgrp["charges"][()]),
+            weight=unpack_value(subgrp["weight"][()]),
+            energy_weight=unpack_value(subgrp["energy_weight"][()]),
+            forces_weight=unpack_value(subgrp["forces_weight"][()]),
+            stress_weight=unpack_value(subgrp["stress_weight"][()]),
+            virials_weight=unpack_value(subgrp["virials_weight"][()]),
+            config_type=unpack_value(subgrp["config_type"][()]),
+            pbc=unpack_value(subgrp["pbc"][()]),
+            cell=unpack_value(subgrp["cell"][()]),
+        )
+        atomic_data = AtomicData.from_config(
+            config, z_table=self.z_table, cutoff=self.r_max
+        )
+        return atomic_data
+
+
+def dataset_from_sharded_hdf5(files: str, z_table: AtomicNumberTable, r_max: float):
+    # `files` is the directory that holds the shard files
+    files = glob(files + "/*")
+    datasets = []
+    for file in files:
+        datasets.append(HDF5Dataset(file, z_table=z_table, r_max=r_max))
+    full_dataset = ConcatDataset(datasets)
+    return full_dataset
+
+
+def unpack_value(value):
+    value = value.decode("utf-8") if isinstance(value, bytes) else value
+    return None if str(value) == "None" else value
diff --git a/hydragnn/utils/mace_utils/data/neighborhood.py b/hydragnn/utils/mace_utils/data/neighborhood.py
new file mode 100644
index 000000000..293576af4
--- /dev/null
+++ b/hydragnn/utils/mace_utils/data/neighborhood.py
@@ -0,0 +1,66 @@
+from typing import Optional, Tuple
+
+import numpy as np
+from matscipy.neighbours import neighbour_list
+
+
+def get_neighborhood(
+    positions: np.ndarray,  # [num_positions, 3]
+    cutoff: float,
+    pbc: Optional[Tuple[bool, bool, bool]] = None,
+    cell: Optional[np.ndarray] = None,  # [3, 3]
+    true_self_interaction=False,
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    if pbc is None:
+        pbc = (False, False, False)
+
+    if cell is None or not cell.any():  # treat an all-zero cell like a missing one
+        cell = np.identity(3, dtype=float)
+
+    assert len(pbc) == 3 and all(isinstance(i, (bool, np.bool_)) for i in pbc)
+    assert cell.shape == (3, 3)
+
+    pbc_x = pbc[0]
+    pbc_y = pbc[1]
+    pbc_z = pbc[2]
+    identity = np.identity(3, dtype=float)
+    max_positions = np.max(np.absolute(positions)) + 1
+    # Extend cell in non-periodic directions
+    # For models with more than 5 layers, the multiplicative constant needs to be increased.
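+    # Worked example of the padding below, assuming cutoff=5.0 and no coordinate
+    # exceeding 9 Angstrom in magnitude: max_positions = 10, so each non-periodic
+    # cell vector is stretched to 10 * 5 * 5.0 = 250 Angstrom, far beyond any true
+    # neighbor distance, and neighbour_list finds no spurious periodic images.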
+ temp_cell = np.copy(cell) + if not pbc_x: + temp_cell[0, :] = max_positions * 5 * cutoff * identity[0, :] + if not pbc_y: + temp_cell[1, :] = max_positions * 5 * cutoff * identity[1, :] + if not pbc_z: + temp_cell[2, :] = max_positions * 5 * cutoff * identity[2, :] + + sender, receiver, unit_shifts = neighbour_list( + quantities="ijS", + pbc=pbc, + cell=temp_cell, + positions=positions, + cutoff=cutoff, + # self_interaction=True, # we want edges from atom to itself in different periodic images + # use_scaled_positions=False, # positions are not scaled positions + ) + + if not true_self_interaction: + # Eliminate self-edges that don't cross periodic boundaries + true_self_edge = sender == receiver + true_self_edge &= np.all(unit_shifts == 0, axis=1) + keep_edge = ~true_self_edge + + # Note: after eliminating self-edges, it can be that no edges remain in this system + sender = sender[keep_edge] + receiver = receiver[keep_edge] + unit_shifts = unit_shifts[keep_edge] + + # Build output + edge_index = np.stack((sender, receiver)) # [2, n_edges] + + # From the docs: With the shift vector S, the distances D between atoms can be computed from + # D = positions[j]-positions[i]+S.dot(cell) + shifts = np.dot(unit_shifts, cell) # [n_edges, 3] + + return edge_index, shifts, unit_shifts diff --git a/hydragnn/utils/mace_utils/data/utils.py b/hydragnn/utils/mace_utils/data/utils.py new file mode 100644 index 000000000..78e3e76fd --- /dev/null +++ b/hydragnn/utils/mace_utils/data/utils.py @@ -0,0 +1,393 @@ +########################################################################################### +# Data parsing utilities +# Authors: Ilyes Batatia, Gregor Simm and David Kovacs +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import logging +from dataclasses import dataclass +from typing import Dict, List, Optional, Sequence, Tuple + +import ase.data +import ase.io +import h5py +import numpy as np + +from mace.tools import AtomicNumberTable + +Vector = np.ndarray # [3,] +Positions = np.ndarray # [..., 3] +Forces = np.ndarray # [..., 3] +Stress = np.ndarray # [6, ], [3,3], [9, ] +Virials = np.ndarray # [6, ], [3,3], [9, ] +Charges = np.ndarray # [..., 1] +Cell = np.ndarray # [3,3] +Pbc = tuple # (3,) + +DEFAULT_CONFIG_TYPE = "Default" +DEFAULT_CONFIG_TYPE_WEIGHTS = {DEFAULT_CONFIG_TYPE: 1.0} + + +@dataclass +class Configuration: + atomic_numbers: np.ndarray + positions: Positions # Angstrom + energy: Optional[float] = None # eV + forces: Optional[Forces] = None # eV/Angstrom + stress: Optional[Stress] = None # eV/Angstrom^3 + virials: Optional[Virials] = None # eV + dipole: Optional[Vector] = None # Debye + charges: Optional[Charges] = None # atomic unit + cell: Optional[Cell] = None + pbc: Optional[Pbc] = None + + weight: float = 1.0 # weight of config in loss + energy_weight: float = 1.0 # weight of config energy in loss + forces_weight: float = 1.0 # weight of config forces in loss + stress_weight: float = 1.0 # weight of config stress in loss + virials_weight: float = 1.0 # weight of config virial in loss + config_type: Optional[str] = DEFAULT_CONFIG_TYPE # config_type of config + + +Configurations = List[Configuration] + + +def random_train_valid_split( + items: Sequence, valid_fraction: float, seed: int, work_dir: str +) -> Tuple[List, List]: + assert 0.0 < valid_fraction < 1.0 + + size = len(items) + train_size = size - int(valid_fraction * size) + + indices = list(range(size)) + rng = 
np.random.default_rng(seed) + rng.shuffle(indices) + if len(indices[train_size:]) < 10: + logging.info( + f"Using random {100 * valid_fraction:.0f}% of training set for validation with following indices: {indices[train_size:]}" + ) + else: + # Save indices to file + with open(work_dir + f"/valid_indices_{seed}.txt", "w", encoding="utf-8") as f: + for index in indices[train_size:]: + f.write(f"{index}\n") + + logging.info( + f"Using random {100 * valid_fraction:.0f}% of training set for validation with indices saved in: {work_dir}/valid_indices_{seed}.txt" + ) + + return ( + [items[i] for i in indices[:train_size]], + [items[i] for i in indices[train_size:]], + ) + + +def config_from_atoms_list( + atoms_list: List[ase.Atoms], + energy_key="REF_energy", + forces_key="REF_forces", + stress_key="REF_stress", + virials_key="REF_virials", + dipole_key="REF_dipole", + charges_key="REF_charges", + config_type_weights: Dict[str, float] = None, +) -> Configurations: + """Convert list of ase.Atoms into Configurations""" + if config_type_weights is None: + config_type_weights = DEFAULT_CONFIG_TYPE_WEIGHTS + + all_configs = [] + for atoms in atoms_list: + all_configs.append( + config_from_atoms( + atoms, + energy_key=energy_key, + forces_key=forces_key, + stress_key=stress_key, + virials_key=virials_key, + dipole_key=dipole_key, + charges_key=charges_key, + config_type_weights=config_type_weights, + ) + ) + return all_configs + + +def config_from_atoms( + atoms: ase.Atoms, + energy_key="REF_energy", + forces_key="REF_forces", + stress_key="REF_stress", + virials_key="REF_virials", + dipole_key="REF_dipole", + charges_key="REF_charges", + config_type_weights: Dict[str, float] = None, +) -> Configuration: + """Convert ase.Atoms to Configuration""" + if config_type_weights is None: + config_type_weights = DEFAULT_CONFIG_TYPE_WEIGHTS + + energy = atoms.info.get(energy_key, None) # eV + forces = atoms.arrays.get(forces_key, None) # eV / Ang + stress = atoms.info.get(stress_key, None) # eV / Ang ^ 3 + virials = atoms.info.get(virials_key, None) + dipole = atoms.info.get(dipole_key, None) # Debye + # Charges default to 0 instead of None if not found + charges = atoms.arrays.get(charges_key, np.zeros(len(atoms))) # atomic unit + atomic_numbers = np.array( + [ase.data.atomic_numbers[symbol] for symbol in atoms.symbols] + ) + pbc = tuple(atoms.get_pbc()) + cell = np.array(atoms.get_cell()) + config_type = atoms.info.get("config_type", "Default") + weight = atoms.info.get("config_weight", 1.0) * config_type_weights.get( + config_type, 1.0 + ) + energy_weight = atoms.info.get("config_energy_weight", 1.0) + forces_weight = atoms.info.get("config_forces_weight", 1.0) + stress_weight = atoms.info.get("config_stress_weight", 1.0) + virials_weight = atoms.info.get("config_virials_weight", 1.0) + + # fill in missing quantities but set their weight to 0.0 + if energy is None: + energy = 0.0 + energy_weight = 0.0 + if forces is None: + forces = np.zeros(np.shape(atoms.positions)) + forces_weight = 0.0 + if stress is None: + stress = np.zeros(6) + stress_weight = 0.0 + if virials is None: + virials = np.zeros((3, 3)) + virials_weight = 0.0 + if dipole is None: + dipole = np.zeros(3) + # dipoles_weight = 0.0 + + return Configuration( + atomic_numbers=atomic_numbers, + positions=atoms.get_positions(), + energy=energy, + forces=forces, + stress=stress, + virials=virials, + dipole=dipole, + charges=charges, + weight=weight, + energy_weight=energy_weight, + forces_weight=forces_weight, + stress_weight=stress_weight, + 
virials_weight=virials_weight, + config_type=config_type, + pbc=pbc, + cell=cell, + ) + + +def test_config_types( + test_configs: Configurations, +) -> List[Tuple[Optional[str], List[Configuration]]]: + """Split test set based on config_type-s""" + test_by_ct = [] + all_cts = [] + for conf in test_configs: + if conf.config_type not in all_cts: + all_cts.append(conf.config_type) + test_by_ct.append((conf.config_type, [conf])) + else: + ind = all_cts.index(conf.config_type) + test_by_ct[ind][1].append(conf) + return test_by_ct + + +def load_from_xyz( + file_path: str, + config_type_weights: Dict, + energy_key: str = "REF_energy", + forces_key: str = "REF_forces", + stress_key: str = "REF_stress", + virials_key: str = "REF_virials", + dipole_key: str = "REF_dipole", + charges_key: str = "REF_charges", + extract_atomic_energies: bool = False, + keep_isolated_atoms: bool = False, +) -> Tuple[Dict[int, float], Configurations]: + atoms_list = ase.io.read(file_path, index=":") + if energy_key == "energy": + logging.warning( + "Since ASE version 3.23.0b1, using energy_key 'energy' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'energy' to 'REF_energy'. You need to use --energy_key='REF_energy' to specify the chosen key name." + ) + energy_key = "REF_energy" + for atoms in atoms_list: + try: + atoms.info["REF_energy"] = atoms.get_potential_energy() + except Exception as e: # pylint: disable=W0703 + logging.error(f"Failed to extract energy: {e}") + atoms.info["REF_energy"] = None + if forces_key == "forces": + logging.warning( + "Since ASE version 3.23.0b1, using forces_key 'forces' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'forces' to 'REF_forces'. You need to use --forces_key='REF_forces' to specify the chosen key name." + ) + forces_key = "REF_forces" + for atoms in atoms_list: + try: + atoms.arrays["REF_forces"] = atoms.get_forces() + except Exception as e: # pylint: disable=W0703 + logging.error(f"Failed to extract forces: {e}") + atoms.arrays["REF_forces"] = None + if stress_key == "stress": + logging.warning( + "Since ASE version 3.23.0b1, using stress_key 'stress' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'stress' to 'REF_stress'. You need to use --stress_key='REF_stress' to specify the chosen key name." + ) + stress_key = "REF_stress" + for atoms in atoms_list: + try: + atoms.info["REF_stress"] = atoms.get_stress() + except Exception as e: # pylint: disable=W0703 + atoms.info["REF_stress"] = None + if not isinstance(atoms_list, list): + atoms_list = [atoms_list] + + atomic_energies_dict = {} + if extract_atomic_energies: + atoms_without_iso_atoms = [] + + for idx, atoms in enumerate(atoms_list): + isolated_atom_config = ( + len(atoms) == 1 and atoms.info.get("config_type") == "IsolatedAtom" + ) + if isolated_atom_config: + if energy_key in atoms.info.keys(): + atomic_energies_dict[atoms.get_atomic_numbers()[0]] = atoms.info[ + energy_key + ] + else: + logging.warning( + f"Configuration '{idx}' is marked as 'IsolatedAtom' " + "but does not contain an energy. Zero energy will be used." 
+ ) + atomic_energies_dict[atoms.get_atomic_numbers()[0]] = np.zeros(1) + else: + atoms_without_iso_atoms.append(atoms) + + if len(atomic_energies_dict) > 0: + logging.info("Using isolated atom energies from training file") + if not keep_isolated_atoms: + atoms_list = atoms_without_iso_atoms + + configs = config_from_atoms_list( + atoms_list, + config_type_weights=config_type_weights, + energy_key=energy_key, + forces_key=forces_key, + stress_key=stress_key, + virials_key=virials_key, + dipole_key=dipole_key, + charges_key=charges_key, + ) + return atomic_energies_dict, configs + + +def compute_average_E0s( + collections_train: Configurations, z_table: AtomicNumberTable +) -> Dict[int, float]: + """ + Function to compute the average interaction energy of each chemical element + returns dictionary of E0s + """ + len_train = len(collections_train) + len_zs = len(z_table) + A = np.zeros((len_train, len_zs)) + B = np.zeros(len_train) + for i in range(len_train): + B[i] = collections_train[i].energy + for j, z in enumerate(z_table.zs): + A[i, j] = np.count_nonzero(collections_train[i].atomic_numbers == z) + try: + E0s = np.linalg.lstsq(A, B, rcond=None)[0] + atomic_energies_dict = {} + for i, z in enumerate(z_table.zs): + atomic_energies_dict[z] = E0s[i] + except np.linalg.LinAlgError: + logging.error( + "Failed to compute E0s using least squares regression, using the same for all atoms" + ) + atomic_energies_dict = {} + for i, z in enumerate(z_table.zs): + atomic_energies_dict[z] = 0.0 + return atomic_energies_dict + + +def save_dataset_as_HDF5(dataset: List, out_name: str) -> None: + with h5py.File(out_name, "w") as f: + for i, data in enumerate(dataset): + grp = f.create_group(f"config_{i}") + grp["num_nodes"] = data.num_nodes + grp["edge_index"] = data.edge_index + grp["positions"] = data.positions + grp["shifts"] = data.shifts + grp["unit_shifts"] = data.unit_shifts + grp["cell"] = data.cell + grp["node_attrs"] = data.node_attrs + grp["weight"] = data.weight + grp["energy_weight"] = data.energy_weight + grp["forces_weight"] = data.forces_weight + grp["stress_weight"] = data.stress_weight + grp["virials_weight"] = data.virials_weight + grp["forces"] = data.forces + grp["energy"] = data.energy + grp["stress"] = data.stress + grp["virials"] = data.virials + grp["dipole"] = data.dipole + grp["charges"] = data.charges + + +def save_AtomicData_to_HDF5(data, i, h5_file) -> None: + grp = h5_file.create_group(f"config_{i}") + grp["num_nodes"] = data.num_nodes + grp["edge_index"] = data.edge_index + grp["positions"] = data.positions + grp["shifts"] = data.shifts + grp["unit_shifts"] = data.unit_shifts + grp["cell"] = data.cell + grp["node_attrs"] = data.node_attrs + grp["weight"] = data.weight + grp["energy_weight"] = data.energy_weight + grp["forces_weight"] = data.forces_weight + grp["stress_weight"] = data.stress_weight + grp["virials_weight"] = data.virials_weight + grp["forces"] = data.forces + grp["energy"] = data.energy + grp["stress"] = data.stress + grp["virials"] = data.virials + grp["dipole"] = data.dipole + grp["charges"] = data.charges + + +def save_configurations_as_HDF5(configurations: Configurations, _, h5_file) -> None: + grp = h5_file.create_group("config_batch_0") + for j, config in enumerate(configurations): + subgroup_name = f"config_{j}" + subgroup = grp.create_group(subgroup_name) + subgroup["atomic_numbers"] = write_value(config.atomic_numbers) + subgroup["positions"] = write_value(config.positions) + subgroup["energy"] = write_value(config.energy) + subgroup["forces"] = 
write_value(config.forces) + subgroup["stress"] = write_value(config.stress) + subgroup["virials"] = write_value(config.virials) + subgroup["dipole"] = write_value(config.dipole) + subgroup["charges"] = write_value(config.charges) + subgroup["cell"] = write_value(config.cell) + subgroup["pbc"] = write_value(config.pbc) + subgroup["weight"] = write_value(config.weight) + subgroup["energy_weight"] = write_value(config.energy_weight) + subgroup["forces_weight"] = write_value(config.forces_weight) + subgroup["stress_weight"] = write_value(config.stress_weight) + subgroup["virials_weight"] = write_value(config.virials_weight) + subgroup["config_type"] = write_value(config.config_type) + + +def write_value(value): + return value if value is not None else "None" diff --git a/hydragnn/utils/mace_utils/modules/__init__.py b/hydragnn/utils/mace_utils/modules/__init__.py new file mode 100644 index 000000000..9278130fd --- /dev/null +++ b/hydragnn/utils/mace_utils/modules/__init__.py @@ -0,0 +1,109 @@ +from typing import Callable, Dict, Optional, Type + +import torch + +from .blocks import ( + AgnosticNonlinearInteractionBlock, + AgnosticResidualNonlinearInteractionBlock, + AtomicEnergiesBlock, + EquivariantProductBasisBlock, + InteractionBlock, + LinearDipoleReadoutBlock, + LinearNodeEmbeddingBlock, + LinearReadoutBlock, + NonLinearDipoleReadoutBlock, + NonLinearReadoutBlock, + RadialEmbeddingBlock, + RealAgnosticAttResidualInteractionBlock, + RealAgnosticInteractionBlock, + RealAgnosticResidualInteractionBlock, + ResidualElementDependentInteractionBlock, + ScaleShiftBlock, +) +from .loss import ( + DipoleSingleLoss, + UniversalLoss, + WeightedEnergyForcesDipoleLoss, + WeightedEnergyForcesLoss, + WeightedEnergyForcesStressLoss, + WeightedEnergyForcesVirialsLoss, + WeightedForcesLoss, + WeightedHuberEnergyForcesStressLoss, +) +from .models import ( + MACE, + AtomicDipolesMACE, + BOTNet, + EnergyDipolesMACE, + ScaleShiftBOTNet, + ScaleShiftMACE, +) +from .radial import BesselBasis, GaussianBasis, PolynomialCutoff, ZBLBasis +from .symmetric_contraction import SymmetricContraction +from .utils import ( + compute_avg_num_neighbors, + compute_fixed_charge_dipole, + compute_mean_rms_energy_forces, + compute_mean_std_atomic_inter_energy, + compute_rms_dipoles, + compute_statistics, +) + +interaction_classes: Dict[str, Type[InteractionBlock]] = { + "AgnosticNonlinearInteractionBlock": AgnosticNonlinearInteractionBlock, + "ResidualElementDependentInteractionBlock": ResidualElementDependentInteractionBlock, + "AgnosticResidualNonlinearInteractionBlock": AgnosticResidualNonlinearInteractionBlock, + "RealAgnosticResidualInteractionBlock": RealAgnosticResidualInteractionBlock, + "RealAgnosticAttResidualInteractionBlock": RealAgnosticAttResidualInteractionBlock, + "RealAgnosticInteractionBlock": RealAgnosticInteractionBlock, +} + +scaling_classes: Dict[str, Callable] = { + "std_scaling": compute_mean_std_atomic_inter_energy, + "rms_forces_scaling": compute_mean_rms_energy_forces, + "rms_dipoles_scaling": compute_rms_dipoles, +} + +gate_dict: Dict[str, Optional[Callable]] = { + "abs": torch.abs, + "tanh": torch.tanh, + "silu": torch.nn.functional.silu, + "None": None, +} + +__all__ = [ + "AtomicEnergiesBlock", + "RadialEmbeddingBlock", + "ZBLBasis", + "LinearNodeEmbeddingBlock", + "LinearReadoutBlock", + "EquivariantProductBasisBlock", + "ScaleShiftBlock", + "LinearDipoleReadoutBlock", + "NonLinearDipoleReadoutBlock", + "InteractionBlock", + "NonLinearReadoutBlock", + "PolynomialCutoff", + "BesselBasis", + 
"GaussianBasis", + "MACE", + "ScaleShiftMACE", + "BOTNet", + "ScaleShiftBOTNet", + "AtomicDipolesMACE", + "EnergyDipolesMACE", + "WeightedEnergyForcesLoss", + "WeightedForcesLoss", + "WeightedEnergyForcesVirialsLoss", + "WeightedEnergyForcesStressLoss", + "DipoleSingleLoss", + "WeightedEnergyForcesDipoleLoss", + "WeightedHuberEnergyForcesStressLoss", + "UniversalLoss", + "SymmetricContraction", + "interaction_classes", + "compute_mean_std_atomic_inter_energy", + "compute_avg_num_neighbors", + "compute_statistics", + "compute_fixed_charge_dipole", +] diff --git a/hydragnn/utils/mace_utils/modules/blocks.py b/hydragnn/utils/mace_utils/modules/blocks.py new file mode 100644 index 000000000..e8645a8e7 --- /dev/null +++ b/hydragnn/utils/mace_utils/modules/blocks.py @@ -0,0 +1,758 @@ +########################################################################################### +# Elementary Block for Building O(3) Equivariant Higher Order Message Passing Neural Network +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +from abc import abstractmethod +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np +import torch.nn.functional +from e3nn import nn, o3 +from e3nn.util.jit import compile_mode + +from mace.tools.compile import simplify_if_compile +from mace.tools.scatter import scatter_sum + +from .irreps_tools import ( + linear_out_irreps, + reshape_irreps, + tp_out_irreps_with_instructions, +) +from .radial import ( + AgnesiTransform, + BesselBasis, + ChebychevBasis, + GaussianBasis, + PolynomialCutoff, + SoftTransform, +) +from .symmetric_contraction import SymmetricContraction + + +@compile_mode("script") +class LinearNodeEmbeddingBlock(torch.nn.Module): + def __init__(self, irreps_in: o3.Irreps, irreps_out: o3.Irreps): + super().__init__() + self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=irreps_out) + + def forward( + self, + node_attrs: torch.Tensor, + ) -> torch.Tensor: # [n_nodes, irreps] + return self.linear(node_attrs) + + +@compile_mode("script") +class LinearReadoutBlock(torch.nn.Module): + def __init__(self, irreps_in: o3.Irreps): + super().__init__() + self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=o3.Irreps("0e")) + + def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] + return self.linear(x) # [n_nodes, 1] + + +@simplify_if_compile +@compile_mode("script") +class NonLinearReadoutBlock(torch.nn.Module): + def __init__( + self, irreps_in: o3.Irreps, MLP_irreps: o3.Irreps, gate: Optional[Callable] + ): + super().__init__() + self.hidden_irreps = MLP_irreps + self.linear_1 = o3.Linear(irreps_in=irreps_in, irreps_out=self.hidden_irreps) + self.non_linearity = nn.Activation(irreps_in=self.hidden_irreps, acts=[gate]) + self.linear_2 = o3.Linear( + irreps_in=self.hidden_irreps, irreps_out=o3.Irreps("0e") + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] + x = self.non_linearity(self.linear_1(x)) + return self.linear_2(x) # [n_nodes, 1] + + +@compile_mode("script") +class LinearDipoleReadoutBlock(torch.nn.Module): + def __init__(self, irreps_in: o3.Irreps, dipole_only: bool = False): + super().__init__() + if dipole_only: + self.irreps_out = o3.Irreps("1x1o") + else: + self.irreps_out = o3.Irreps("1x0e + 1x1o") + self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=self.irreps_out) + + def forward(self, x: torch.Tensor) -> 
torch.Tensor: # [n_nodes, irreps] # [..., ] + return self.linear(x) # [n_nodes, 1] + + +@compile_mode("script") +class NonLinearDipoleReadoutBlock(torch.nn.Module): + def __init__( + self, + irreps_in: o3.Irreps, + MLP_irreps: o3.Irreps, + gate: Callable, + dipole_only: bool = False, + ): + super().__init__() + self.hidden_irreps = MLP_irreps + if dipole_only: + self.irreps_out = o3.Irreps("1x1o") + else: + self.irreps_out = o3.Irreps("1x0e + 1x1o") + irreps_scalars = o3.Irreps( + [(mul, ir) for mul, ir in MLP_irreps if ir.l == 0 and ir in self.irreps_out] + ) + irreps_gated = o3.Irreps( + [(mul, ir) for mul, ir in MLP_irreps if ir.l > 0 and ir in self.irreps_out] + ) + irreps_gates = o3.Irreps([mul, "0e"] for mul, _ in irreps_gated) + self.equivariant_nonlin = nn.Gate( + irreps_scalars=irreps_scalars, + act_scalars=[gate for _, ir in irreps_scalars], + irreps_gates=irreps_gates, + act_gates=[gate] * len(irreps_gates), + irreps_gated=irreps_gated, + ) + self.irreps_nonlin = self.equivariant_nonlin.irreps_in.simplify() + self.linear_1 = o3.Linear(irreps_in=irreps_in, irreps_out=self.irreps_nonlin) + self.linear_2 = o3.Linear( + irreps_in=self.hidden_irreps, irreps_out=self.irreps_out + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] + x = self.equivariant_nonlin(self.linear_1(x)) + return self.linear_2(x) # [n_nodes, 1] + + +@compile_mode("script") +class AtomicEnergiesBlock(torch.nn.Module): + atomic_energies: torch.Tensor + + def __init__(self, atomic_energies: Union[np.ndarray, torch.Tensor]): + super().__init__() + assert len(atomic_energies.shape) == 1 + + self.register_buffer( + "atomic_energies", + torch.tensor(atomic_energies, dtype=torch.get_default_dtype()), + ) # [n_elements, ] + + def forward( + self, x: torch.Tensor # one-hot of elements [..., n_elements] + ) -> torch.Tensor: # [..., ] + return torch.matmul(x, self.atomic_energies) + + def __repr__(self): + formatted_energies = ", ".join([f"{x:.4f}" for x in self.atomic_energies]) + return f"{self.__class__.__name__}(energies=[{formatted_energies}])" + + +@compile_mode("script") +class RadialEmbeddingBlock(torch.nn.Module): + def __init__( + self, + r_max: float, + num_bessel: int, + num_polynomial_cutoff: int, + radial_type: str = "bessel", + distance_transform: str = "None", + ): + super().__init__() + if radial_type == "bessel": + self.bessel_fn = BesselBasis(r_max=r_max, num_basis=num_bessel) + elif radial_type == "gaussian": + self.bessel_fn = GaussianBasis(r_max=r_max, num_basis=num_bessel) + elif radial_type == "chebyshev": + self.bessel_fn = ChebychevBasis(r_max=r_max, num_basis=num_bessel) + if distance_transform == "Agnesi": + self.distance_transform = AgnesiTransform() + elif distance_transform == "Soft": + self.distance_transform = SoftTransform() + self.cutoff_fn = PolynomialCutoff(r_max=r_max, p=num_polynomial_cutoff) + self.out_dim = num_bessel + + def forward( + self, + edge_lengths: torch.Tensor, # [n_edges, 1] + node_attrs: torch.Tensor, + edge_index: torch.Tensor, + atomic_numbers: torch.Tensor, + ): + cutoff = self.cutoff_fn(edge_lengths) # [n_edges, 1] + if hasattr(self, "distance_transform"): + edge_lengths = self.distance_transform( + edge_lengths, node_attrs, edge_index, atomic_numbers + ) + radial = self.bessel_fn(edge_lengths) # [n_edges, n_basis] + return radial * cutoff # [n_edges, n_basis] + + +@compile_mode("script") +class EquivariantProductBasisBlock(torch.nn.Module): + def __init__( + self, + node_feats_irreps: o3.Irreps, + target_irreps: o3.Irreps, + 
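
# A quick aside on the e3nn irreps notation used by these readout blocks
# (a sketch of e3nn facts, not code from this patch): "1x0e + 1x1o" means one
# even scalar plus one odd vector, 4 components in total, which is why the
# energy+dipole readout produces 4 numbers per node.
from e3nn import o3

irreps = o3.Irreps("1x0e + 1x1o")
assert irreps.dim == 4  # 1 scalar component + 3 vector components
assert o3.Irreps("0e").dim == 1  # the energy-only readout target
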
correlation: int, + use_sc: bool = True, + num_elements: Optional[int] = None, + ) -> None: + super().__init__() + + self.use_sc = use_sc + self.symmetric_contractions = SymmetricContraction( + irreps_in=node_feats_irreps, + irreps_out=target_irreps, + correlation=correlation, + num_elements=num_elements, + ) + # Update linear + self.linear = o3.Linear( + target_irreps, + target_irreps, + internal_weights=True, + shared_weights=True, + ) + + def forward( + self, + node_feats: torch.Tensor, + sc: Optional[torch.Tensor], + node_attrs: torch.Tensor, + ) -> torch.Tensor: + node_feats = self.symmetric_contractions(node_feats, node_attrs) + if self.use_sc and sc is not None: + return self.linear(node_feats) + sc + return self.linear(node_feats) + + +@compile_mode("script") +class InteractionBlock(torch.nn.Module): + def __init__( + self, + node_attrs_irreps: o3.Irreps, + node_feats_irreps: o3.Irreps, + edge_attrs_irreps: o3.Irreps, + edge_feats_irreps: o3.Irreps, + target_irreps: o3.Irreps, + hidden_irreps: o3.Irreps, + avg_num_neighbors: float, + radial_MLP: Optional[List[int]] = None, + ) -> None: + super().__init__() + self.node_attrs_irreps = node_attrs_irreps + self.node_feats_irreps = node_feats_irreps + self.edge_attrs_irreps = edge_attrs_irreps + self.edge_feats_irreps = edge_feats_irreps + self.target_irreps = target_irreps + self.hidden_irreps = hidden_irreps + self.avg_num_neighbors = avg_num_neighbors + if radial_MLP is None: + radial_MLP = [64, 64, 64] + self.radial_MLP = radial_MLP + + self._setup() + + @abstractmethod + def _setup(self) -> None: + raise NotImplementedError + + @abstractmethod + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> torch.Tensor: + raise NotImplementedError + + +nonlinearities = {1: torch.nn.functional.silu, -1: torch.tanh} + + +@compile_mode("script") +class TensorProductWeightsBlock(torch.nn.Module): + def __init__(self, num_elements: int, num_edge_feats: int, num_feats_out: int): + super().__init__() + + weights = torch.empty( + (num_elements, num_edge_feats, num_feats_out), + dtype=torch.get_default_dtype(), + ) + torch.nn.init.xavier_uniform_(weights) + self.weights = torch.nn.Parameter(weights) + + def forward( + self, + sender_or_receiver_node_attrs: torch.Tensor, # assumes that the node attributes are one-hot encoded + edge_feats: torch.Tensor, + ): + return torch.einsum( + "be, ba, aek -> bk", edge_feats, sender_or_receiver_node_attrs, self.weights + ) + + def __repr__(self): + return ( + f'{self.__class__.__name__}(shape=({", ".join(str(s) for s in self.weights.shape)}), ' + f"weights={np.prod(self.weights.shape)})" + ) + + +@compile_mode("script") +class ResidualElementDependentInteractionBlock(InteractionBlock): + def _setup(self) -> None: + self.linear_up = o3.Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps + ) + self.conv_tp = o3.TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + self.conv_tp_weights = TensorProductWeightsBlock( + num_elements=self.node_attrs_irreps.num_irreps, + num_edge_feats=self.edge_feats_irreps.num_irreps, + num_feats_out=self.conv_tp.weight_numel, + ) + + # Linear + irreps_mid = 
irreps_mid.simplify() + self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps) + self.irreps_out = self.irreps_out.simplify() + self.linear = o3.Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + ) + + # Selector TensorProduct + self.skip_tp = o3.FullyConnectedTensorProduct( + self.node_feats_irreps, self.node_attrs_irreps, self.irreps_out + ) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> torch.Tensor: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + sc = self.skip_tp(node_feats, node_attrs) + node_feats = self.linear_up(node_feats) + tp_weights = self.conv_tp_weights(node_attrs[sender], edge_feats) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.linear(message) / self.avg_num_neighbors + return message + sc # [n_nodes, irreps] + + +@compile_mode("script") +class AgnosticNonlinearInteractionBlock(InteractionBlock): + def _setup(self) -> None: + self.linear_up = o3.Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps + ) + self.conv_tp = o3.TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], + torch.nn.functional.silu, + ) + + # Linear + irreps_mid = irreps_mid.simplify() + self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps) + self.irreps_out = self.irreps_out.simplify() + self.linear = o3.Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + ) + + # Selector TensorProduct + self.skip_tp = o3.FullyConnectedTensorProduct( + self.irreps_out, self.node_attrs_irreps, self.irreps_out + ) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> torch.Tensor: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + tp_weights = self.conv_tp_weights(edge_feats) + node_feats = self.linear_up(node_feats) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.linear(message) / self.avg_num_neighbors + message = self.skip_tp(message, node_attrs) + return message # [n_nodes, irreps] + + +@compile_mode("script") +class AgnosticResidualNonlinearInteractionBlock(InteractionBlock): + def _setup(self) -> None: + # First linear + self.linear_up = o3.Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps + ) + self.conv_tp = o3.TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, 
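
# Note: this vendored copy still imports scatter_sum from `mace.tools`, while
# the file itself lives under `hydragnn/utils/mace_utils`; the sketch below
# assumes the vendored path. scatter_sum aggregates per-edge messages m_ji
# onto their receiver nodes, which is the message-passing sum used in every
# forward() in this file (values are illustrative):
import torch

from hydragnn.utils.mace_utils.tools.scatter import scatter_sum

mji = torch.tensor([[1.0], [2.0], [4.0]])  # 3 edges, 1 feature each
receiver = torch.tensor([0, 0, 1])  # edges 0,1 end at node 0; edge 2 at node 1
out = scatter_sum(src=mji, index=receiver, dim=0, dim_size=2)
assert out.tolist() == [[3.0], [4.0]]  # per-node sums of incoming messages
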
+ instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], + torch.nn.functional.silu, + ) + + # Linear + irreps_mid = irreps_mid.simplify() + self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps) + self.irreps_out = self.irreps_out.simplify() + self.linear = o3.Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + ) + + # Selector TensorProduct + self.skip_tp = o3.FullyConnectedTensorProduct( + self.node_feats_irreps, self.node_attrs_irreps, self.irreps_out + ) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> torch.Tensor: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + sc = self.skip_tp(node_feats, node_attrs) + node_feats = self.linear_up(node_feats) + tp_weights = self.conv_tp_weights(edge_feats) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.linear(message) / self.avg_num_neighbors + message = message + sc + return message # [n_nodes, irreps] + + +@compile_mode("script") +class RealAgnosticInteractionBlock(InteractionBlock): + def _setup(self) -> None: + # First linear + self.linear_up = o3.Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, + self.edge_attrs_irreps, + self.target_irreps, + ) + self.conv_tp = o3.TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], + torch.nn.functional.silu, + ) + + # Linear + irreps_mid = irreps_mid.simplify() + self.irreps_out = self.target_irreps + self.linear = o3.Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + ) + + # Selector TensorProduct + self.skip_tp = o3.FullyConnectedTensorProduct( + self.irreps_out, self.node_attrs_irreps, self.irreps_out + ) + self.reshape = reshape_irreps(self.irreps_out) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> Tuple[torch.Tensor, None]: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + node_feats = self.linear_up(node_feats) + tp_weights = self.conv_tp_weights(edge_feats) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.linear(message) / self.avg_num_neighbors + message = self.skip_tp(message, node_attrs) + return ( + self.reshape(message), + None, + ) # [n_nodes, channels, (lmax + 1)**2] + + +@compile_mode("script") +class RealAgnosticResidualInteractionBlock(InteractionBlock): + def _setup(self) -> None: + # First linear + self.linear_up 
= o3.Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, + self.edge_attrs_irreps, + self.target_irreps, + ) + self.conv_tp = o3.TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], + torch.nn.functional.silu, + ) + + # Linear + irreps_mid = irreps_mid.simplify() + self.irreps_out = self.target_irreps + self.linear = o3.Linear( + irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + ) + + # Selector TensorProduct + self.skip_tp = o3.FullyConnectedTensorProduct( + self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps + ) + self.reshape = reshape_irreps(self.irreps_out) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + sc = self.skip_tp(node_feats, node_attrs) + node_feats = self.linear_up(node_feats) + tp_weights = self.conv_tp_weights(edge_feats) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.linear(message) / self.avg_num_neighbors + return ( + self.reshape(message), + sc, + ) # [n_nodes, channels, (lmax + 1)**2] + + +@compile_mode("script") +class RealAgnosticAttResidualInteractionBlock(InteractionBlock): + def _setup(self) -> None: + self.node_feats_down_irreps = o3.Irreps("64x0e") + # First linear + self.linear_up = o3.Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, + self.edge_attrs_irreps, + self.target_irreps, + ) + self.conv_tp = o3.TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + + # Convolution weights + self.linear_down = o3.Linear( + self.node_feats_irreps, + self.node_feats_down_irreps, + internal_weights=True, + shared_weights=True, + ) + input_dim = ( + self.edge_feats_irreps.num_irreps + + 2 * self.node_feats_down_irreps.num_irreps + ) + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + 3 * [256] + [self.conv_tp.weight_numel], + torch.nn.functional.silu, + ) + + # Linear + irreps_mid = irreps_mid.simplify() + self.irreps_out = self.target_irreps + self.linear = o3.Linear( + irreps_mid, + self.irreps_out, + internal_weights=True, + shared_weights=True, + ) + + self.reshape = reshape_irreps(self.irreps_out) + + # Skip connection. 
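
# Shape sketch (illustrative values only): in forward() below, the radial MLP
# input concatenates the per-edge features with down-projected sender and
# receiver node features, matching input_dim computed in _setup above.
import torch

n_nodes, n_edges, n_radial, n_down = 10, 5, 8, 64
edge_feats = torch.randn(n_edges, n_radial)
node_feats_down = torch.randn(n_nodes, n_down)
sender = torch.randint(0, n_nodes, (n_edges,))
receiver = torch.randint(0, n_nodes, (n_edges,))
augmented = torch.cat(
    [edge_feats, node_feats_down[sender], node_feats_down[receiver]], dim=-1
)
assert augmented.shape == (n_edges, n_radial + 2 * n_down)
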
+ self.skip_linear = o3.Linear(self.node_feats_irreps, self.hidden_irreps) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> Tuple[torch.Tensor, None]: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + sc = self.skip_linear(node_feats) + node_feats_up = self.linear_up(node_feats) + node_feats_down = self.linear_down(node_feats) + augmented_edge_feats = torch.cat( + [ + edge_feats, + node_feats_down[sender], + node_feats_down[receiver], + ], + dim=-1, + ) + tp_weights = self.conv_tp_weights(augmented_edge_feats) + mji = self.conv_tp( + node_feats_up[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.linear(message) / self.avg_num_neighbors + return ( + self.reshape(message), + sc, + ) # [n_nodes, channels, (lmax + 1)**2] + + +@compile_mode("script") +class ScaleShiftBlock(torch.nn.Module): + def __init__(self, scale: float, shift: float): + super().__init__() + self.register_buffer( + "scale", torch.tensor(scale, dtype=torch.get_default_dtype()) + ) + self.register_buffer( + "shift", torch.tensor(shift, dtype=torch.get_default_dtype()) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.scale * x + self.shift + + def __repr__(self): + return ( + f"{self.__class__.__name__}(scale={self.scale:.6f}, shift={self.shift:.6f})" + ) diff --git a/hydragnn/utils/mace_utils/modules/irreps_tools.py b/hydragnn/utils/mace_utils/modules/irreps_tools.py new file mode 100644 index 000000000..642f3fa87 --- /dev/null +++ b/hydragnn/utils/mace_utils/modules/irreps_tools.py @@ -0,0 +1,86 @@ +########################################################################################### +# Elementary tools for handling irreducible representations +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +from typing import List, Tuple + +import torch +from e3nn import o3 +from e3nn.util.jit import compile_mode + + +# Based on mir-group/nequip +def tp_out_irreps_with_instructions( + irreps1: o3.Irreps, irreps2: o3.Irreps, target_irreps: o3.Irreps +) -> Tuple[o3.Irreps, List]: + trainable = True + + # Collect possible irreps and their instructions + irreps_out_list: List[Tuple[int, o3.Irreps]] = [] + instructions = [] + for i, (mul, ir_in) in enumerate(irreps1): + for j, (_, ir_edge) in enumerate(irreps2): + for ir_out in ir_in * ir_edge: # | l1 - l2 | <= l <= l1 + l2 + if ir_out in target_irreps: + k = len(irreps_out_list) # instruction index + irreps_out_list.append((mul, ir_out)) + instructions.append((i, j, k, "uvu", trainable)) + + # We sort the output irreps of the tensor product so that we can simplify them + # when they are provided to the second o3.Linear + irreps_out = o3.Irreps(irreps_out_list) + irreps_out, permut, _ = irreps_out.sort() + + # Permute the output indexes of the instructions to match the sorted irreps: + instructions = [ + (i_in1, i_in2, permut[i_out], mode, train) + for i_in1, i_in2, i_out, mode, train in instructions + ] + + instructions = sorted(instructions, key=lambda x: x[2]) + + return irreps_out, instructions + + +def linear_out_irreps(irreps: o3.Irreps, target_irreps: o3.Irreps) -> o3.Irreps: + # Assuming simplified irreps + irreps_mid = [] + 
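
# The selection rule in tp_out_irreps_with_instructions above enumerates all
# products ir_in * ir_edge with |l1 - l2| <= l <= l1 + l2 and parity p1 * p2;
# a standalone e3nn check (a sketch, not code from this patch):
from e3nn import o3

products = list(o3.Irrep("1o") * o3.Irrep("1o"))
# A vector times a vector decomposes into scalar + pseudovector + rank-2 part:
assert products == [o3.Irrep("0e"), o3.Irrep("1e"), o3.Irrep("2e")]
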
for _, ir_in in irreps: + found = False + + for mul, ir_out in target_irreps: + if ir_in == ir_out: + irreps_mid.append((mul, ir_out)) + found = True + break + + if not found: + raise RuntimeError(f"{ir_in} not in {target_irreps}") + + return o3.Irreps(irreps_mid) + + +@compile_mode("script") +class reshape_irreps(torch.nn.Module): + def __init__(self, irreps: o3.Irreps) -> None: + super().__init__() + self.irreps = o3.Irreps(irreps) + self.dims = [] + self.muls = [] + for mul, ir in self.irreps: + d = ir.dim + self.dims.append(d) + self.muls.append(mul) + + def forward(self, tensor: torch.Tensor) -> torch.Tensor: + ix = 0 + out = [] + batch, _ = tensor.shape + for mul, d in zip(self.muls, self.dims): + field = tensor[:, ix : ix + mul * d] # [batch, sample, mul * repr] + ix += mul * d + field = field.reshape(batch, mul, d) + out.append(field) + return torch.cat(out, dim=-1) diff --git a/hydragnn/utils/mace_utils/modules/loss.py b/hydragnn/utils/mace_utils/modules/loss.py new file mode 100644 index 000000000..b3421ef59 --- /dev/null +++ b/hydragnn/utils/mace_utils/modules/loss.py @@ -0,0 +1,367 @@ +########################################################################################### +# Implementation of different loss functions +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import torch + +from mace.tools import TensorDict +from mace.tools.torch_geometric import Batch + + +def mean_squared_error_energy(ref: Batch, pred: TensorDict) -> torch.Tensor: + # energy: [n_graphs, ] + return torch.mean(torch.square(ref["energy"] - pred["energy"])) # [] + + +def weighted_mean_squared_error_energy(ref: Batch, pred: TensorDict) -> torch.Tensor: + # energy: [n_graphs, ] + configs_weight = ref.weight # [n_graphs, ] + configs_energy_weight = ref.energy_weight # [n_graphs, ] + num_atoms = ref.ptr[1:] - ref.ptr[:-1] # [n_graphs,] + return torch.mean( + configs_weight + * configs_energy_weight + * torch.square((ref["energy"] - pred["energy"]) / num_atoms) + ) # [] + + +def weighted_mean_squared_stress(ref: Batch, pred: TensorDict) -> torch.Tensor: + # energy: [n_graphs, ] + configs_weight = ref.weight.view(-1, 1, 1) # [n_graphs, ] + configs_stress_weight = ref.stress_weight.view(-1, 1, 1) # [n_graphs, ] + return torch.mean( + configs_weight + * configs_stress_weight + * torch.square(ref["stress"] - pred["stress"]) + ) # [] + + +def weighted_mean_squared_virials(ref: Batch, pred: TensorDict) -> torch.Tensor: + # energy: [n_graphs, ] + configs_weight = ref.weight.view(-1, 1, 1) # [n_graphs, ] + configs_virials_weight = ref.virials_weight.view(-1, 1, 1) # [n_graphs, ] + num_atoms = (ref.ptr[1:] - ref.ptr[:-1]).view(-1, 1, 1) # [n_graphs,] + return torch.mean( + configs_weight + * configs_virials_weight + * torch.square((ref["virials"] - pred["virials"]) / num_atoms) + ) # [] + + +def mean_squared_error_forces(ref: Batch, pred: TensorDict) -> torch.Tensor: + # forces: [n_atoms, 3] + configs_weight = torch.repeat_interleave( + ref.weight, ref.ptr[1:] - ref.ptr[:-1] + ).unsqueeze( + -1 + ) # [n_atoms, 1] + configs_forces_weight = torch.repeat_interleave( + ref.forces_weight, ref.ptr[1:] - ref.ptr[:-1] + ).unsqueeze( + -1 + ) # [n_atoms, 1] + return torch.mean( + configs_weight + * configs_forces_weight + * torch.square(ref["forces"] - pred["forces"]) + ) # [] + + +def weighted_mean_squared_error_dipole(ref: Batch, pred: TensorDict) -> torch.Tensor: + # 
dipole: [n_graphs, ] + num_atoms = (ref.ptr[1:] - ref.ptr[:-1]).unsqueeze(-1) # [n_graphs,1] + return torch.mean(torch.square((ref["dipole"] - pred["dipole"]) / num_atoms)) # [] + # return torch.mean(torch.square((torch.reshape(ref['dipole'], pred["dipole"].shape) - pred['dipole']) / num_atoms)) # [] + + +def conditional_mse_forces(ref: Batch, pred: TensorDict) -> torch.Tensor: + # forces: [n_atoms, 3] + configs_weight = torch.repeat_interleave( + ref.weight, ref.ptr[1:] - ref.ptr[:-1] + ).unsqueeze( + -1 + ) # [n_atoms, 1] + configs_forces_weight = torch.repeat_interleave( + ref.forces_weight, ref.ptr[1:] - ref.ptr[:-1] + ).unsqueeze( + -1 + ) # [n_atoms, 1] + + # Define the multiplication factors for each condition + factors = torch.tensor([1.0, 0.7, 0.4, 0.1]) + + # Apply multiplication factors based on conditions + c1 = torch.norm(ref["forces"], dim=-1) < 100 + c2 = (torch.norm(ref["forces"], dim=-1) >= 100) & ( + torch.norm(ref["forces"], dim=-1) < 200 + ) + c3 = (torch.norm(ref["forces"], dim=-1) >= 200) & ( + torch.norm(ref["forces"], dim=-1) < 300 + ) + + err = ref["forces"] - pred["forces"] + + se = torch.zeros_like(err) + + se[c1] = torch.square(err[c1]) * factors[0] + se[c2] = torch.square(err[c2]) * factors[1] + se[c3] = torch.square(err[c3]) * factors[2] + se[~(c1 | c2 | c3)] = torch.square(err[~(c1 | c2 | c3)]) * factors[3] + + return torch.mean(configs_weight * configs_forces_weight * se) + + +def conditional_huber_forces( + ref: Batch, pred: TensorDict, huber_delta: float +) -> torch.Tensor: + # Define the multiplication factors for each condition + factors = huber_delta * torch.tensor([1.0, 0.7, 0.4, 0.1]) + + # Apply multiplication factors based on conditions + c1 = torch.norm(ref["forces"], dim=-1) < 100 + c2 = (torch.norm(ref["forces"], dim=-1) >= 100) & ( + torch.norm(ref["forces"], dim=-1) < 200 + ) + c3 = (torch.norm(ref["forces"], dim=-1) >= 200) & ( + torch.norm(ref["forces"], dim=-1) < 300 + ) + c4 = ~(c1 | c2 | c3) + + se = torch.zeros_like(pred["forces"]) + + se[c1] = torch.nn.functional.huber_loss( + ref["forces"][c1], pred["forces"][c1], reduction="none", delta=factors[0] + ) + se[c2] = torch.nn.functional.huber_loss( + ref["forces"][c2], pred["forces"][c2], reduction="none", delta=factors[1] + ) + se[c3] = torch.nn.functional.huber_loss( + ref["forces"][c3], pred["forces"][c3], reduction="none", delta=factors[2] + ) + se[c4] = torch.nn.functional.huber_loss( + ref["forces"][c4], pred["forces"][c4], reduction="none", delta=factors[3] + ) + + return torch.mean(se) + + +class WeightedEnergyForcesLoss(torch.nn.Module): + def __init__(self, energy_weight=1.0, forces_weight=1.0) -> None: + super().__init__() + self.register_buffer( + "energy_weight", + torch.tensor(energy_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "forces_weight", + torch.tensor(forces_weight, dtype=torch.get_default_dtype()), + ) + + def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: + return self.energy_weight * weighted_mean_squared_error_energy( + ref, pred + ) + self.forces_weight * mean_squared_error_forces(ref, pred) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " + f"forces_weight={self.forces_weight:.3f})" + ) + + +class WeightedForcesLoss(torch.nn.Module): + def __init__(self, forces_weight=1.0) -> None: + super().__init__() + self.register_buffer( + "forces_weight", + torch.tensor(forces_weight, dtype=torch.get_default_dtype()), + ) + + def forward(self, ref: Batch, pred: TensorDict) -> 
torch.Tensor: + return self.forces_weight * mean_squared_error_forces(ref, pred) + + def __repr__(self): + return f"{self.__class__.__name__}(" f"forces_weight={self.forces_weight:.3f})" + + +class WeightedEnergyForcesStressLoss(torch.nn.Module): + def __init__(self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0) -> None: + super().__init__() + self.register_buffer( + "energy_weight", + torch.tensor(energy_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "forces_weight", + torch.tensor(forces_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "stress_weight", + torch.tensor(stress_weight, dtype=torch.get_default_dtype()), + ) + + def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: + return ( + self.energy_weight * weighted_mean_squared_error_energy(ref, pred) + + self.forces_weight * mean_squared_error_forces(ref, pred) + + self.stress_weight * weighted_mean_squared_stress(ref, pred) + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " + f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})" + ) + + +class WeightedHuberEnergyForcesStressLoss(torch.nn.Module): + def __init__( + self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0, huber_delta=0.01 + ) -> None: + super().__init__() + self.huber_loss = torch.nn.HuberLoss(reduction="mean", delta=huber_delta) + self.register_buffer( + "energy_weight", + torch.tensor(energy_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "forces_weight", + torch.tensor(forces_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "stress_weight", + torch.tensor(stress_weight, dtype=torch.get_default_dtype()), + ) + + def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: + num_atoms = ref.ptr[1:] - ref.ptr[:-1] + return ( + self.energy_weight + * self.huber_loss(ref["energy"] / num_atoms, pred["energy"] / num_atoms) + + self.forces_weight * self.huber_loss(ref["forces"], pred["forces"]) + + self.stress_weight * self.huber_loss(ref["stress"], pred["stress"]) + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " + f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})" + ) + + +class UniversalLoss(torch.nn.Module): + def __init__( + self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0, huber_delta=0.01 + ) -> None: + super().__init__() + self.huber_delta = huber_delta + self.huber_loss = torch.nn.HuberLoss(reduction="mean", delta=huber_delta) + self.register_buffer( + "energy_weight", + torch.tensor(energy_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "forces_weight", + torch.tensor(forces_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "stress_weight", + torch.tensor(stress_weight, dtype=torch.get_default_dtype()), + ) + + def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: + num_atoms = ref.ptr[1:] - ref.ptr[:-1] + return ( + self.energy_weight + * self.huber_loss(ref["energy"] / num_atoms, pred["energy"] / num_atoms) + + self.forces_weight + * conditional_huber_forces(ref, pred, huber_delta=self.huber_delta) + + self.stress_weight * self.huber_loss(ref["stress"], pred["stress"]) + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " + f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})" + ) + + +class 
WeightedEnergyForcesVirialsLoss(torch.nn.Module): + def __init__( + self, energy_weight=1.0, forces_weight=1.0, virials_weight=1.0 + ) -> None: + super().__init__() + self.register_buffer( + "energy_weight", + torch.tensor(energy_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "forces_weight", + torch.tensor(forces_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "virials_weight", + torch.tensor(virials_weight, dtype=torch.get_default_dtype()), + ) + + def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: + return ( + self.energy_weight * weighted_mean_squared_error_energy(ref, pred) + + self.forces_weight * mean_squared_error_forces(ref, pred) + + self.virials_weight * weighted_mean_squared_virials(ref, pred) + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " + f"forces_weight={self.forces_weight:.3f}, virials_weight={self.virials_weight:.3f})" + ) + + +class DipoleSingleLoss(torch.nn.Module): + def __init__(self, dipole_weight=1.0) -> None: + super().__init__() + self.register_buffer( + "dipole_weight", + torch.tensor(dipole_weight, dtype=torch.get_default_dtype()), + ) + + def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: + return ( + self.dipole_weight * weighted_mean_squared_error_dipole(ref, pred) * 100.0 + ) # multiply by 100 to have the right scale for the loss + + def __repr__(self): + return f"{self.__class__.__name__}(" f"dipole_weight={self.dipole_weight:.3f})" + + +class WeightedEnergyForcesDipoleLoss(torch.nn.Module): + def __init__(self, energy_weight=1.0, forces_weight=1.0, dipole_weight=1.0) -> None: + super().__init__() + self.register_buffer( + "energy_weight", + torch.tensor(energy_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "forces_weight", + torch.tensor(forces_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "dipole_weight", + torch.tensor(dipole_weight, dtype=torch.get_default_dtype()), + ) + + def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: + return ( + self.energy_weight * weighted_mean_squared_error_energy(ref, pred) + + self.forces_weight * mean_squared_error_forces(ref, pred) + + self.dipole_weight * weighted_mean_squared_error_dipole(ref, pred) * 100 + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " + f"forces_weight={self.forces_weight:.3f}, dipole_weight={self.dipole_weight:.3f})" + ) diff --git a/hydragnn/utils/mace_utils/modules/models.py b/hydragnn/utils/mace_utils/modules/models.py new file mode 100644 index 000000000..3e5cb6626 --- /dev/null +++ b/hydragnn/utils/mace_utils/modules/models.py @@ -0,0 +1,1065 @@ +########################################################################################### +# Implementation of MACE models and other models based E(3)-Equivariant MPNNs +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +from typing import Any, Callable, Dict, List, Optional, Type, Union + +import numpy as np +import torch +from e3nn import o3 +from e3nn.util.jit import compile_mode + +from mace.data import AtomicData +from mace.modules.radial import ZBLBasis +from mace.tools.scatter import scatter_sum + +from .blocks import ( + AtomicEnergiesBlock, + EquivariantProductBasisBlock, + InteractionBlock, + LinearDipoleReadoutBlock, + 
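
# Hedged re-derivation of the weighting scheme in loss.py above: each graph's
# squared per-atom energy error is scaled by weight * energy_weight before
# averaging. Plain tensors stand in for the Batch fields; values illustrative.
import torch
from types import SimpleNamespace

ref = SimpleNamespace(
    weight=torch.tensor([1.0, 2.0]),  # per-graph weight
    energy_weight=torch.tensor([1.0, 1.0]),  # per-graph energy weight
    ptr=torch.tensor([0, 2, 5]),  # graph 0 has 2 atoms, graph 1 has 3
)
ref_energy = torch.tensor([-10.0, -20.0])
pred_energy = torch.tensor([-9.0, -21.0])
num_atoms = ref.ptr[1:] - ref.ptr[:-1]
loss = torch.mean(
    ref.weight * ref.energy_weight * ((ref_energy - pred_energy) / num_atoms) ** 2
)
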
LinearNodeEmbeddingBlock, + LinearReadoutBlock, + NonLinearDipoleReadoutBlock, + NonLinearReadoutBlock, + RadialEmbeddingBlock, + ScaleShiftBlock, +) +from .utils import ( + compute_fixed_charge_dipole, + compute_forces, + get_edge_vectors_and_lengths, + get_outputs, + get_symmetric_displacement, +) + +# pylint: disable=C0302 + + +@compile_mode("script") +class MACE(torch.nn.Module): + def __init__( + self, + r_max: float, + num_bessel: int, + num_polynomial_cutoff: int, + max_ell: int, + interaction_cls: Type[InteractionBlock], + interaction_cls_first: Type[InteractionBlock], + num_interactions: int, + num_elements: int, + hidden_irreps: o3.Irreps, + MLP_irreps: o3.Irreps, + atomic_energies: np.ndarray, + avg_num_neighbors: float, + atomic_numbers: List[int], + correlation: Union[int, List[int]], + gate: Optional[Callable], + pair_repulsion: bool = False, + distance_transform: str = "None", + radial_MLP: Optional[List[int]] = None, + radial_type: Optional[str] = "bessel", + ): + super().__init__() + self.register_buffer( + "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) + ) + self.register_buffer( + "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) + ) + self.register_buffer( + "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) + ) + if isinstance(correlation, int): + correlation = [correlation] * num_interactions + # Embedding + node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) + node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) + self.node_embedding = LinearNodeEmbeddingBlock( + irreps_in=node_attr_irreps, irreps_out=node_feats_irreps + ) + self.radial_embedding = RadialEmbeddingBlock( + r_max=r_max, + num_bessel=num_bessel, + num_polynomial_cutoff=num_polynomial_cutoff, + radial_type=radial_type, + distance_transform=distance_transform, + ) + edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") + if pair_repulsion: + self.pair_repulsion_fn = ZBLBasis(r_max=r_max, p=num_polynomial_cutoff) + self.pair_repulsion = True + + sh_irreps = o3.Irreps.spherical_harmonics(max_ell) + num_features = hidden_irreps.count(o3.Irrep(0, 1)) + interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() + self.spherical_harmonics = o3.SphericalHarmonics( + sh_irreps, normalize=True, normalization="component" + ) + if radial_MLP is None: + radial_MLP = [64, 64, 64] + # Interactions and readout + self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) + + inter = interaction_cls_first( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=node_feats_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps, + avg_num_neighbors=avg_num_neighbors, + radial_MLP=radial_MLP, + ) + self.interactions = torch.nn.ModuleList([inter]) + + # Use the appropriate self connection at the first layer for proper E0 + use_sc_first = False + if "Residual" in str(interaction_cls_first): + use_sc_first = True + + node_feats_irreps_out = inter.target_irreps + prod = EquivariantProductBasisBlock( + node_feats_irreps=node_feats_irreps_out, + target_irreps=hidden_irreps, + correlation=correlation[0], + num_elements=num_elements, + use_sc=use_sc_first, + ) + self.products = torch.nn.ModuleList([prod]) + + self.readouts = torch.nn.ModuleList() + self.readouts.append(LinearReadoutBlock(hidden_irreps)) + + for i in range(num_interactions - 1): + if i == num_interactions - 2: + hidden_irreps_out = str( + hidden_irreps[0] + ) # Select 
only scalars for last layer + else: + hidden_irreps_out = hidden_irreps + inter = interaction_cls( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=hidden_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps_out, + avg_num_neighbors=avg_num_neighbors, + radial_MLP=radial_MLP, + ) + self.interactions.append(inter) + prod = EquivariantProductBasisBlock( + node_feats_irreps=interaction_irreps, + target_irreps=hidden_irreps_out, + correlation=correlation[i + 1], + num_elements=num_elements, + use_sc=True, + ) + self.products.append(prod) + if i == num_interactions - 2: + self.readouts.append( + NonLinearReadoutBlock(hidden_irreps_out, MLP_irreps, gate) + ) + else: + self.readouts.append(LinearReadoutBlock(hidden_irreps)) + + def forward( + self, + data: Dict[str, torch.Tensor], + training: bool = False, + compute_force: bool = True, + compute_virials: bool = False, + compute_stress: bool = False, + compute_displacement: bool = False, + compute_hessian: bool = False, + ) -> Dict[str, Optional[torch.Tensor]]: + # Setup + data["node_attrs"].requires_grad_(True) + data["positions"].requires_grad_(True) + num_graphs = data["ptr"].numel() - 1 + displacement = torch.zeros( + (num_graphs, 3, 3), + dtype=data["positions"].dtype, + device=data["positions"].device, + ) + if compute_virials or compute_stress or compute_displacement: + ( + data["positions"], + data["shifts"], + displacement, + ) = get_symmetric_displacement( + positions=data["positions"], + unit_shifts=data["unit_shifts"], + cell=data["cell"], + edge_index=data["edge_index"], + num_graphs=num_graphs, + batch=data["batch"], + ) + + # Atomic energies + node_e0 = self.atomic_energies_fn(data["node_attrs"]) + e0 = scatter_sum( + src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs + ) # [n_graphs,] + # Embeddings + node_feats = self.node_embedding(data["node_attrs"]) + vectors, lengths = get_edge_vectors_and_lengths( + positions=data["positions"], + edge_index=data["edge_index"], + shifts=data["shifts"], + ) + edge_attrs = self.spherical_harmonics(vectors) + edge_feats = self.radial_embedding( + lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers + ) + if hasattr(self, "pair_repulsion"): + pair_node_energy = self.pair_repulsion_fn( + lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers + ) + pair_energy = scatter_sum( + src=pair_node_energy, index=data["batch"], dim=-1, dim_size=num_graphs + ) # [n_graphs,] + else: + pair_node_energy = torch.zeros_like(node_e0) + pair_energy = torch.zeros_like(e0) + + # Interactions + energies = [e0, pair_energy] + node_energies_list = [node_e0, pair_node_energy] + node_feats_list = [] + for interaction, product, readout in zip( + self.interactions, self.products, self.readouts + ): + node_feats, sc = interaction( + node_attrs=data["node_attrs"], + node_feats=node_feats, + edge_attrs=edge_attrs, + edge_feats=edge_feats, + edge_index=data["edge_index"], + ) + node_feats = product( + node_feats=node_feats, + sc=sc, + node_attrs=data["node_attrs"], + ) + node_feats_list.append(node_feats) + node_energies = readout(node_feats).squeeze(-1) # [n_nodes, ] + energy = scatter_sum( + src=node_energies, index=data["batch"], dim=-1, dim_size=num_graphs + ) # [n_graphs,] + energies.append(energy) + node_energies_list.append(node_energies) + + # Concatenate node features + node_feats_out = torch.cat(node_feats_list, dim=-1) + + # Sum over energy contributions + contributions = 
torch.stack(energies, dim=-1) + total_energy = torch.sum(contributions, dim=-1) # [n_graphs, ] + node_energy_contributions = torch.stack(node_energies_list, dim=-1) + node_energy = torch.sum(node_energy_contributions, dim=-1) # [n_nodes, ] + + # Outputs + forces, virials, stress, hessian = get_outputs( + energy=total_energy, + positions=data["positions"], + displacement=displacement, + cell=data["cell"], + training=training, + compute_force=compute_force, + compute_virials=compute_virials, + compute_stress=compute_stress, + compute_hessian=compute_hessian, + ) + + return { + "energy": total_energy, + "node_energy": node_energy, + "contributions": contributions, + "forces": forces, + "virials": virials, + "stress": stress, + "displacement": displacement, + "hessian": hessian, + "node_feats": node_feats_out, + } + + +@compile_mode("script") +class ScaleShiftMACE(MACE): + def __init__( + self, + atomic_inter_scale: float, + atomic_inter_shift: float, + **kwargs, + ): + super().__init__(**kwargs) + self.scale_shift = ScaleShiftBlock( + scale=atomic_inter_scale, shift=atomic_inter_shift + ) + + def forward( + self, + data: Dict[str, torch.Tensor], + training: bool = False, + compute_force: bool = True, + compute_virials: bool = False, + compute_stress: bool = False, + compute_displacement: bool = False, + compute_hessian: bool = False, + ) -> Dict[str, Optional[torch.Tensor]]: + # Setup + data["positions"].requires_grad_(True) + data["node_attrs"].requires_grad_(True) + num_graphs = data["ptr"].numel() - 1 + displacement = torch.zeros( + (num_graphs, 3, 3), + dtype=data["positions"].dtype, + device=data["positions"].device, + ) + if compute_virials or compute_stress or compute_displacement: + ( + data["positions"], + data["shifts"], + displacement, + ) = get_symmetric_displacement( + positions=data["positions"], + unit_shifts=data["unit_shifts"], + cell=data["cell"], + edge_index=data["edge_index"], + num_graphs=num_graphs, + batch=data["batch"], + ) + + # Atomic energies + node_e0 = self.atomic_energies_fn(data["node_attrs"]) + e0 = scatter_sum( + src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs + ) # [n_graphs,] + + # Embeddings + node_feats = self.node_embedding(data["node_attrs"]) + vectors, lengths = get_edge_vectors_and_lengths( + positions=data["positions"], + edge_index=data["edge_index"], + shifts=data["shifts"], + ) + edge_attrs = self.spherical_harmonics(vectors) + edge_feats = self.radial_embedding( + lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers + ) + if hasattr(self, "pair_repulsion"): + pair_node_energy = self.pair_repulsion_fn( + lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers + ) + else: + pair_node_energy = torch.zeros_like(node_e0) + # Interactions + node_es_list = [pair_node_energy] + node_feats_list = [] + for interaction, product, readout in zip( + self.interactions, self.products, self.readouts + ): + node_feats, sc = interaction( + node_attrs=data["node_attrs"], + node_feats=node_feats, + edge_attrs=edge_attrs, + edge_feats=edge_feats, + edge_index=data["edge_index"], + ) + node_feats = product( + node_feats=node_feats, sc=sc, node_attrs=data["node_attrs"] + ) + node_feats_list.append(node_feats) + node_es_list.append(readout(node_feats).squeeze(-1)) # {[n_nodes, ], } + # Concatenate node features + node_feats_out = torch.cat(node_feats_list, dim=-1) + # print("node_es_list", node_es_list) + # Sum over interactions + node_inter_es = torch.sum( + torch.stack(node_es_list, dim=0), dim=0 + ) # [n_nodes, ] + 
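
# ScaleShiftBlock (blocks.py) applies E = scale * x + shift elementwise to the
# per-node interaction energies summed just above; scale and shift typically
# come from dataset statistics (e.g. force RMS and mean interaction energy per
# atom). Values here are illustrative only:
import torch

scale, shift = 1.7, -3.2
node_es = torch.tensor([0.1, -0.4])
assert torch.allclose(scale * node_es + shift, torch.tensor([-3.03, -3.88]))
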
node_inter_es = self.scale_shift(node_inter_es) + + # Sum over nodes in graph + inter_e = scatter_sum( + src=node_inter_es, index=data["batch"], dim=-1, dim_size=num_graphs + ) # [n_graphs,] + + # Add E_0 and (scaled) interaction energy + total_energy = e0 + inter_e + node_energy = node_e0 + node_inter_es + forces, virials, stress, hessian = get_outputs( + energy=inter_e, + positions=data["positions"], + displacement=displacement, + cell=data["cell"], + training=training, + compute_force=compute_force, + compute_virials=compute_virials, + compute_stress=compute_stress, + compute_hessian=compute_hessian, + ) + output = { + "energy": total_energy, + "node_energy": node_energy, + "interaction_energy": inter_e, + "forces": forces, + "virials": virials, + "stress": stress, + "hessian": hessian, + "displacement": displacement, + "node_feats": node_feats_out, + } + + return output + + +class BOTNet(torch.nn.Module): + def __init__( + self, + r_max: float, + num_bessel: int, + num_polynomial_cutoff: int, + max_ell: int, + interaction_cls: Type[InteractionBlock], + interaction_cls_first: Type[InteractionBlock], + num_interactions: int, + num_elements: int, + hidden_irreps: o3.Irreps, + MLP_irreps: o3.Irreps, + atomic_energies: np.ndarray, + gate: Optional[Callable], + avg_num_neighbors: float, + atomic_numbers: List[int], + ): + super().__init__() + self.r_max = r_max + self.atomic_numbers = atomic_numbers + # Embedding + node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) + node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) + self.node_embedding = LinearNodeEmbeddingBlock( + irreps_in=node_attr_irreps, irreps_out=node_feats_irreps + ) + self.radial_embedding = RadialEmbeddingBlock( + r_max=r_max, + num_bessel=num_bessel, + num_polynomial_cutoff=num_polynomial_cutoff, + ) + edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") + + sh_irreps = o3.Irreps.spherical_harmonics(max_ell) + self.spherical_harmonics = o3.SphericalHarmonics( + sh_irreps, normalize=True, normalization="component" + ) + + # Interactions and readouts + self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) + + self.interactions = torch.nn.ModuleList() + self.readouts = torch.nn.ModuleList() + + inter = interaction_cls_first( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=node_feats_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=hidden_irreps, + avg_num_neighbors=avg_num_neighbors, + ) + self.interactions.append(inter) + self.readouts.append(LinearReadoutBlock(inter.irreps_out)) + + for i in range(num_interactions - 1): + inter = interaction_cls( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=inter.irreps_out, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=hidden_irreps, + avg_num_neighbors=avg_num_neighbors, + ) + self.interactions.append(inter) + if i == num_interactions - 2: + self.readouts.append( + NonLinearReadoutBlock(inter.irreps_out, MLP_irreps, gate) + ) + else: + self.readouts.append(LinearReadoutBlock(inter.irreps_out)) + + def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: + # Setup + data.positions.requires_grad = True + + # Atomic energies + node_e0 = self.atomic_energies_fn(data.node_attrs) + e0 = scatter_sum( + src=node_e0, index=data.batch, dim=-1, dim_size=data.num_graphs + ) # [n_graphs,] + + # Embeddings + node_feats = self.node_embedding(data.node_attrs) + vectors, lengths = get_edge_vectors_and_lengths( + 
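
# get_edge_vectors_and_lengths (modules/utils.py) returns, per edge, the
# displacement from sender to receiver plus the periodic-image shift, and its
# norm. A single-edge re-derivation (the sign convention is an assumption):
import torch

positions = torch.tensor([[0.0, 0.0, 0.0], [0.9, 0.0, 0.0]])
edge_index = torch.tensor([[0], [1]])  # sender 0 -> receiver 1
shifts = torch.zeros(1, 3)  # no periodic image crossed
vectors = positions[edge_index[1]] - positions[edge_index[0]] + shifts
lengths = vectors.norm(dim=-1, keepdim=True)  # [n_edges, 1], here 0.9
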
positions=data.positions, edge_index=data.edge_index, shifts=data.shifts + ) + edge_attrs = self.spherical_harmonics(vectors) + edge_feats = self.radial_embedding( + lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers + ) + + # Interactions + energies = [e0] + for interaction, readout in zip(self.interactions, self.readouts): + node_feats = interaction( + node_attrs=data.node_attrs, + node_feats=node_feats, + edge_attrs=edge_attrs, + edge_feats=edge_feats, + edge_index=data.edge_index, + ) + node_energies = readout(node_feats).squeeze(-1) # [n_nodes, ] + energy = scatter_sum( + src=node_energies, index=data.batch, dim=-1, dim_size=data.num_graphs + ) # [n_graphs,] + energies.append(energy) + + # Sum over energy contributions + contributions = torch.stack(energies, dim=-1) + total_energy = torch.sum(contributions, dim=-1) # [n_graphs, ] + + output = { + "energy": total_energy, + "contributions": contributions, + "forces": compute_forces( + energy=total_energy, positions=data.positions, training=training + ), + } + + return output + + +class ScaleShiftBOTNet(BOTNet): + def __init__( + self, + atomic_inter_scale: float, + atomic_inter_shift: float, + **kwargs, + ): + super().__init__(**kwargs) + self.scale_shift = ScaleShiftBlock( + scale=atomic_inter_scale, shift=atomic_inter_shift + ) + + def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: + # Setup + data.positions.requires_grad = True + + # Atomic energies + node_e0 = self.atomic_energies_fn(data.node_attrs) + e0 = scatter_sum( + src=node_e0, index=data.batch, dim=-1, dim_size=data.num_graphs + ) # [n_graphs,] + + # Embeddings + node_feats = self.node_embedding(data.node_attrs) + vectors, lengths = get_edge_vectors_and_lengths( + positions=data.positions, edge_index=data.edge_index, shifts=data.shifts + ) + edge_attrs = self.spherical_harmonics(vectors) + edge_feats = self.radial_embedding( + lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers + ) + + # Interactions + node_es_list = [] + for interaction, readout in zip(self.interactions, self.readouts): + node_feats = interaction( + node_attrs=data.node_attrs, + node_feats=node_feats, + edge_attrs=edge_attrs, + edge_feats=edge_feats, + edge_index=data.edge_index, + ) + + node_es_list.append(readout(node_feats).squeeze(-1)) # {[n_nodes, ], } + + # Sum over interactions + node_inter_es = torch.sum( + torch.stack(node_es_list, dim=0), dim=0 + ) # [n_nodes, ] + node_inter_es = self.scale_shift(node_inter_es) + + # Sum over nodes in graph + inter_e = scatter_sum( + src=node_inter_es, index=data.batch, dim=-1, dim_size=data.num_graphs + ) # [n_graphs,] + + # Add E_0 and (scaled) interaction energy + total_e = e0 + inter_e + + output = { + "energy": total_e, + "forces": compute_forces( + energy=inter_e, positions=data.positions, training=training + ), + } + + return output + + +@compile_mode("script") +class AtomicDipolesMACE(torch.nn.Module): + def __init__( + self, + r_max: float, + num_bessel: int, + num_polynomial_cutoff: int, + max_ell: int, + interaction_cls: Type[InteractionBlock], + interaction_cls_first: Type[InteractionBlock], + num_interactions: int, + num_elements: int, + hidden_irreps: o3.Irreps, + MLP_irreps: o3.Irreps, + avg_num_neighbors: float, + atomic_numbers: List[int], + correlation: int, + gate: Optional[Callable], + atomic_energies: Optional[ + None + ], # Just here to make it compatible with energy models, MUST be None + radial_type: Optional[str] = "bessel", + radial_MLP: Optional[List[int]] = None, + ): + 
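
# Dipole models need l=1 features in hidden_irreps, since atomic dipoles are
# read out from the 1o channel; the assertion when building the last layer
# below enforces this. An e3nn illustration (the irreps choice is hypothetical):
from e3nn import o3

hidden_irreps = o3.Irreps("128x0e + 128x1o")
assert len(hidden_irreps) > 1  # scalars alone cannot represent a dipole
assert str(hidden_irreps[1]) == "128x1o"  # the l=1 slice kept for the last layer
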
super().__init__() + self.register_buffer( + "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) + ) + self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.float64)) + self.register_buffer( + "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) + ) + assert atomic_energies is None + + # Embedding + node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) + node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) + self.node_embedding = LinearNodeEmbeddingBlock( + irreps_in=node_attr_irreps, irreps_out=node_feats_irreps + ) + self.radial_embedding = RadialEmbeddingBlock( + r_max=r_max, + num_bessel=num_bessel, + num_polynomial_cutoff=num_polynomial_cutoff, + radial_type=radial_type, + ) + edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") + + sh_irreps = o3.Irreps.spherical_harmonics(max_ell) + num_features = hidden_irreps.count(o3.Irrep(0, 1)) + interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() + self.spherical_harmonics = o3.SphericalHarmonics( + sh_irreps, normalize=True, normalization="component" + ) + if radial_MLP is None: + radial_MLP = [64, 64, 64] + + # Interactions and readouts + inter = interaction_cls_first( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=node_feats_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps, + avg_num_neighbors=avg_num_neighbors, + radial_MLP=radial_MLP, + ) + self.interactions = torch.nn.ModuleList([inter]) + + # Use the appropriate self connection at the first layer + use_sc_first = False + if "Residual" in str(interaction_cls_first): + use_sc_first = True + + node_feats_irreps_out = inter.target_irreps + prod = EquivariantProductBasisBlock( + node_feats_irreps=node_feats_irreps_out, + target_irreps=hidden_irreps, + correlation=correlation, + num_elements=num_elements, + use_sc=use_sc_first, + ) + self.products = torch.nn.ModuleList([prod]) + + self.readouts = torch.nn.ModuleList() + self.readouts.append(LinearDipoleReadoutBlock(hidden_irreps, dipole_only=True)) + + for i in range(num_interactions - 1): + if i == num_interactions - 2: + assert ( + len(hidden_irreps) > 1 + ), "To predict dipoles use at least l=1 hidden_irreps" + hidden_irreps_out = str( + hidden_irreps[1] + ) # Select only l=1 vectors for last layer + else: + hidden_irreps_out = hidden_irreps + inter = interaction_cls( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=hidden_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps_out, + avg_num_neighbors=avg_num_neighbors, + radial_MLP=radial_MLP, + ) + self.interactions.append(inter) + prod = EquivariantProductBasisBlock( + node_feats_irreps=interaction_irreps, + target_irreps=hidden_irreps_out, + correlation=correlation, + num_elements=num_elements, + use_sc=True, + ) + self.products.append(prod) + if i == num_interactions - 2: + self.readouts.append( + NonLinearDipoleReadoutBlock( + hidden_irreps_out, MLP_irreps, gate, dipole_only=True + ) + ) + else: + self.readouts.append( + LinearDipoleReadoutBlock(hidden_irreps, dipole_only=True) + ) + + def forward( + self, + data: Dict[str, torch.Tensor], + training: bool = False, # pylint: disable=W0613 + compute_force: bool = False, + compute_virials: bool = False, + compute_stress: bool = False, + compute_displacement: bool = False, + ) -> Dict[str, Optional[torch.Tensor]]: + assert compute_force is 
False + assert compute_virials is False + assert compute_stress is False + assert compute_displacement is False + # Setup + data["node_attrs"].requires_grad_(True) + data["positions"].requires_grad_(True) + num_graphs = data["ptr"].numel() - 1 + + # Embeddings + node_feats = self.node_embedding(data["node_attrs"]) + vectors, lengths = get_edge_vectors_and_lengths( + positions=data["positions"], + edge_index=data["edge_index"], + shifts=data["shifts"], + ) + edge_attrs = self.spherical_harmonics(vectors) + edge_feats = self.radial_embedding( + lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers + ) + + # Interactions + dipoles = [] + for interaction, product, readout in zip( + self.interactions, self.products, self.readouts + ): + node_feats, sc = interaction( + node_attrs=data["node_attrs"], + node_feats=node_feats, + edge_attrs=edge_attrs, + edge_feats=edge_feats, + edge_index=data["edge_index"], + ) + node_feats = product( + node_feats=node_feats, + sc=sc, + node_attrs=data["node_attrs"], + ) + node_dipoles = readout(node_feats).squeeze(-1) # [n_nodes,3] + dipoles.append(node_dipoles) + + # Compute the dipoles + contributions_dipoles = torch.stack( + dipoles, dim=-1 + ) # [n_nodes,3,n_contributions] + atomic_dipoles = torch.sum(contributions_dipoles, dim=-1) # [n_nodes,3] + total_dipole = scatter_sum( + src=atomic_dipoles, + index=data["batch"], + dim=0, + dim_size=num_graphs, + ) # [n_graphs,3] + baseline = compute_fixed_charge_dipole( + charges=data["charges"], + positions=data["positions"], + batch=data["batch"], + num_graphs=num_graphs, + ) # [n_graphs,3] + total_dipole = total_dipole + baseline + + output = { + "dipole": total_dipole, + "atomic_dipoles": atomic_dipoles, + } + return output + + +@compile_mode("script") +class EnergyDipolesMACE(torch.nn.Module): + def __init__( + self, + r_max: float, + num_bessel: int, + num_polynomial_cutoff: int, + max_ell: int, + interaction_cls: Type[InteractionBlock], + interaction_cls_first: Type[InteractionBlock], + num_interactions: int, + num_elements: int, + hidden_irreps: o3.Irreps, + MLP_irreps: o3.Irreps, + avg_num_neighbors: float, + atomic_numbers: List[int], + correlation: int, + gate: Optional[Callable], + atomic_energies: Optional[np.ndarray], + radial_MLP: Optional[List[int]] = None, + ): + super().__init__() + self.register_buffer( + "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) + ) + self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.float64)) + self.register_buffer( + "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) + ) + # Embedding + node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) + node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) + self.node_embedding = LinearNodeEmbeddingBlock( + irreps_in=node_attr_irreps, irreps_out=node_feats_irreps + ) + self.radial_embedding = RadialEmbeddingBlock( + r_max=r_max, + num_bessel=num_bessel, + num_polynomial_cutoff=num_polynomial_cutoff, + ) + edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") + + sh_irreps = o3.Irreps.spherical_harmonics(max_ell) + num_features = hidden_irreps.count(o3.Irrep(0, 1)) + interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() + self.spherical_harmonics = o3.SphericalHarmonics( + sh_irreps, normalize=True, normalization="component" + ) + if radial_MLP is None: + radial_MLP = [64, 64, 64] + # Interactions and readouts + self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) + + inter = interaction_cls_first( + 
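
# compute_fixed_charge_dipole (modules/utils.py), used in forward() above,
# adds a classical point-charge baseline mu = sum_i q_i * r_i per graph;
# re-derived here for a single graph (illustrative charges and positions):
import torch

charges = torch.tensor([0.5, -0.5])
positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
mu = (charges.unsqueeze(-1) * positions).sum(dim=0)
assert mu.tolist() == [-0.5, 0.0, 0.0]
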
node_attrs_irreps=node_attr_irreps, + node_feats_irreps=node_feats_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps, + avg_num_neighbors=avg_num_neighbors, + radial_MLP=radial_MLP, + ) + self.interactions = torch.nn.ModuleList([inter]) + + # Use the appropriate self connection at the first layer + use_sc_first = False + if "Residual" in str(interaction_cls_first): + use_sc_first = True + + node_feats_irreps_out = inter.target_irreps + prod = EquivariantProductBasisBlock( + node_feats_irreps=node_feats_irreps_out, + target_irreps=hidden_irreps, + correlation=correlation, + num_elements=num_elements, + use_sc=use_sc_first, + ) + self.products = torch.nn.ModuleList([prod]) + + self.readouts = torch.nn.ModuleList() + self.readouts.append(LinearDipoleReadoutBlock(hidden_irreps, dipole_only=False)) + + for i in range(num_interactions - 1): + if i == num_interactions - 2: + assert ( + len(hidden_irreps) > 1 + ), "To predict dipoles use at least l=1 hidden_irreps" + hidden_irreps_out = str( + hidden_irreps[:2] + ) # Select scalars and l=1 vectors for last layer + else: + hidden_irreps_out = hidden_irreps + inter = interaction_cls( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=hidden_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps_out, + avg_num_neighbors=avg_num_neighbors, + radial_MLP=radial_MLP, + ) + self.interactions.append(inter) + prod = EquivariantProductBasisBlock( + node_feats_irreps=interaction_irreps, + target_irreps=hidden_irreps_out, + correlation=correlation, + num_elements=num_elements, + use_sc=True, + ) + self.products.append(prod) + if i == num_interactions - 2: + self.readouts.append( + NonLinearDipoleReadoutBlock( + hidden_irreps_out, MLP_irreps, gate, dipole_only=False + ) + ) + else: + self.readouts.append( + LinearDipoleReadoutBlock(hidden_irreps, dipole_only=False) + ) + + def forward( + self, + data: Dict[str, torch.Tensor], + training: bool = False, + compute_force: bool = True, + compute_virials: bool = False, + compute_stress: bool = False, + compute_displacement: bool = False, + ) -> Dict[str, Optional[torch.Tensor]]: + # Setup + data["node_attrs"].requires_grad_(True) + data["positions"].requires_grad_(True) + num_graphs = data["ptr"].numel() - 1 + displacement = torch.zeros( + (num_graphs, 3, 3), + dtype=data["positions"].dtype, + device=data["positions"].device, + ) + if compute_virials or compute_stress or compute_displacement: + ( + data["positions"], + data["shifts"], + displacement, + ) = get_symmetric_displacement( + positions=data["positions"], + unit_shifts=data["unit_shifts"], + cell=data["cell"], + edge_index=data["edge_index"], + num_graphs=num_graphs, + batch=data["batch"], + ) + + # Atomic energies + node_e0 = self.atomic_energies_fn(data["node_attrs"]) + e0 = scatter_sum( + src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs + ) # [n_graphs,] + + # Embeddings + node_feats = self.node_embedding(data["node_attrs"]) + vectors, lengths = get_edge_vectors_and_lengths( + positions=data["positions"], + edge_index=data["edge_index"], + shifts=data["shifts"], + ) + edge_attrs = self.spherical_harmonics(vectors) + edge_feats = self.radial_embedding( + lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers + ) + + # Interactions + energies = [e0] + node_energies_list = [node_e0] + dipoles = [] + for interaction, product, readout in zip( + 
self.interactions, self.products, self.readouts
+ ):
+ node_feats, sc = interaction(
+ node_attrs=data["node_attrs"],
+ node_feats=node_feats,
+ edge_attrs=edge_attrs,
+ edge_feats=edge_feats,
+ edge_index=data["edge_index"],
+ )
+ node_feats = product(
+ node_feats=node_feats,
+ sc=sc,
+ node_attrs=data["node_attrs"],
+ )
+ node_out = readout(node_feats).squeeze(-1) # [n_nodes, ]
+ node_energies = node_out[:, 0]
+ energy = scatter_sum(
+ src=node_energies, index=data["batch"], dim=-1, dim_size=num_graphs
+ ) # [n_graphs,]
+ energies.append(energy)
+ node_dipoles = node_out[:, 1:]
+ dipoles.append(node_dipoles)
+
+ # Compute the energies and dipoles
+ contributions = torch.stack(energies, dim=-1)
+ total_energy = torch.sum(contributions, dim=-1) # [n_graphs, ]
+ node_energy_contributions = torch.stack(node_energies_list, dim=-1)
+ node_energy = torch.sum(node_energy_contributions, dim=-1) # [n_nodes, ]
+ contributions_dipoles = torch.stack(
+ dipoles, dim=-1
+ ) # [n_nodes,3,n_contributions]
+ atomic_dipoles = torch.sum(contributions_dipoles, dim=-1) # [n_nodes,3]
+ total_dipole = scatter_sum(
+ src=atomic_dipoles,
+ index=data["batch"].unsqueeze(-1),
+ dim=0,
+ dim_size=num_graphs,
+ ) # [n_graphs,3]
+ baseline = compute_fixed_charge_dipole(
+ charges=data["charges"],
+ positions=data["positions"],
+ batch=data["batch"],
+ num_graphs=num_graphs,
+ ) # [n_graphs,3]
+ total_dipole = total_dipole + baseline
+
+ forces, virials, stress, _ = get_outputs(
+ energy=total_energy,
+ positions=data["positions"],
+ displacement=displacement,
+ cell=data["cell"],
+ training=training,
+ compute_force=compute_force,
+ compute_virials=compute_virials,
+ compute_stress=compute_stress,
+ )
+
+ output = {
+ "energy": total_energy,
+ "node_energy": node_energy,
+ "contributions": contributions,
+ "forces": forces,
+ "virials": virials,
+ "stress": stress,
+ "displacement": displacement,
+ "dipole": total_dipole,
+ "atomic_dipoles": atomic_dipoles,
+ }
+ return output
diff --git a/hydragnn/utils/mace_utils/modules/radial.py b/hydragnn/utils/mace_utils/modules/radial.py
new file mode 100644
index 000000000..a928c1847
--- /dev/null
+++ b/hydragnn/utils/mace_utils/modules/radial.py
@@ -0,0 +1,323 @@
+###########################################################################################
+# Radial basis and cutoff
+# Authors: Ilyes Batatia, Gregor Simm
+# This program is distributed under the MIT License (see MIT.md)
+###########################################################################################
+
+import ase
+import numpy as np
+import torch
+from e3nn.util.jit import compile_mode
+
+from hydragnn.utils.mace_utils.tools.compile import simplify_if_compile
+from hydragnn.utils.mace_utils.tools.scatter import scatter_sum
+
+
+@compile_mode("script")
+class BesselBasis(torch.nn.Module):
+ """
+ Equation (7)
+ """
+
+ def __init__(self, r_max: float, num_basis=8, trainable=False):
+ super().__init__()
+
+ bessel_weights = (
+ np.pi
+ / r_max
+ * torch.linspace(
+ start=1.0,
+ end=num_basis,
+ steps=num_basis,
+ dtype=torch.get_default_dtype(),
+ )
+ )
+ if trainable:
+ self.bessel_weights = torch.nn.Parameter(bessel_weights)
+ else:
+ self.register_buffer("bessel_weights", bessel_weights)
+
+ self.register_buffer(
+ "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype())
+ )
+ self.register_buffer(
+ "prefactor",
+ torch.tensor(np.sqrt(2.0 / r_max), dtype=torch.get_default_dtype()),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1]
+ numerator =
torch.sin(self.bessel_weights * x) # [..., num_basis] + return self.prefactor * (numerator / x) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(r_max={self.r_max}, num_basis={len(self.bessel_weights)}, " + f"trainable={self.bessel_weights.requires_grad})" + ) + + +@compile_mode("script") +class ChebychevBasis(torch.nn.Module): + """ + Equation (7) + """ + + def __init__(self, r_max: float, num_basis=8): + super().__init__() + self.register_buffer( + "n", + torch.arange(1, num_basis + 1, dtype=torch.get_default_dtype()).unsqueeze( + 0 + ), + ) + self.num_basis = num_basis + self.r_max = r_max + + def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1] + x = x.repeat(1, self.num_basis) + n = self.n.repeat(len(x), 1) + return torch.special.chebyshev_polynomial_t(x, n) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(r_max={self.r_max}, num_basis={self.num_basis}," + ) + + +@compile_mode("script") +class GaussianBasis(torch.nn.Module): + """ + Gaussian basis functions + """ + + def __init__(self, r_max: float, num_basis=128, trainable=False): + super().__init__() + gaussian_weights = torch.linspace( + start=0.0, end=r_max, steps=num_basis, dtype=torch.get_default_dtype() + ) + if trainable: + self.gaussian_weights = torch.nn.Parameter( + gaussian_weights, requires_grad=True + ) + else: + self.register_buffer("gaussian_weights", gaussian_weights) + self.coeff = -0.5 / (r_max / (num_basis - 1)) ** 2 + + def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1] + x = x - self.gaussian_weights + return torch.exp(self.coeff * torch.pow(x, 2)) + + +@compile_mode("script") +class PolynomialCutoff(torch.nn.Module): + """ + Equation (8) + """ + + p: torch.Tensor + r_max: torch.Tensor + + def __init__(self, r_max: float, p=6): + super().__init__() + self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype())) + self.register_buffer( + "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # yapf: disable + envelope = ( + 1.0 + - ((self.p + 1.0) * (self.p + 2.0) / 2.0) * torch.pow(x / self.r_max, self.p) + + self.p * (self.p + 2.0) * torch.pow(x / self.r_max, self.p + 1) + - (self.p * (self.p + 1.0) / 2) * torch.pow(x / self.r_max, self.p + 2) + ) + # yapf: enable + + # noinspection PyUnresolvedReferences + return envelope * (x < self.r_max) + + def __repr__(self): + return f"{self.__class__.__name__}(p={self.p}, r_max={self.r_max})" + + +@compile_mode("script") +class ZBLBasis(torch.nn.Module): + """ + Implementation of the Ziegler-Biersack-Littmark (ZBL) potential + """ + + p: torch.Tensor + r_max: torch.Tensor + + def __init__(self, r_max: float, p=6, trainable=False): + super().__init__() + self.register_buffer( + "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) + ) + # Pre-calculate the p coefficients for the ZBL potential + self.register_buffer( + "c", + torch.tensor( + [0.1818, 0.5099, 0.2802, 0.02817], dtype=torch.get_default_dtype() + ), + ) + self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype())) + self.register_buffer( + "covalent_radii", + torch.tensor( + ase.data.covalent_radii, + dtype=torch.get_default_dtype(), + ), + ) + self.cutoff = PolynomialCutoff(r_max, p) + if trainable: + self.a_exp = torch.nn.Parameter(torch.tensor(0.300, requires_grad=True)) + self.a_prefactor = torch.nn.Parameter( + torch.tensor(0.4543, requires_grad=True) + ) + else: + self.register_buffer("a_exp", torch.tensor(0.300)) + 
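+ # Fixed (non-trainable) screening parameters of the universal ZBL
+ # repulsion; passing trainable=True above optimizes them instead.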
self.register_buffer("a_prefactor", torch.tensor(0.4543)) + + def forward( + self, + x: torch.Tensor, + node_attrs: torch.Tensor, + edge_index: torch.Tensor, + atomic_numbers: torch.Tensor, + ) -> torch.Tensor: + sender = edge_index[0] + receiver = edge_index[1] + node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze( + -1 + ) + Z_u = node_atomic_numbers[sender] + Z_v = node_atomic_numbers[receiver] + a = ( + self.a_prefactor + * 0.529 + / (torch.pow(Z_u, self.a_exp) + torch.pow(Z_v, self.a_exp)) + ) + r_over_a = x / a + phi = ( + self.c[0] * torch.exp(-3.2 * r_over_a) + + self.c[1] * torch.exp(-0.9423 * r_over_a) + + self.c[2] * torch.exp(-0.4028 * r_over_a) + + self.c[3] * torch.exp(-0.2016 * r_over_a) + ) + v_edges = (14.3996 * Z_u * Z_v) / x * phi + r_max = self.covalent_radii[Z_u] + self.covalent_radii[Z_v] + envelope = ( + 1.0 + - ((self.p + 1.0) * (self.p + 2.0) / 2.0) * torch.pow(x / r_max, self.p) + + self.p * (self.p + 2.0) * torch.pow(x / r_max, self.p + 1) + - (self.p * (self.p + 1.0) / 2) * torch.pow(x / r_max, self.p + 2) + ) * (x < r_max) + v_edges = 0.5 * v_edges * envelope + V_ZBL = scatter_sum(v_edges, receiver, dim=0, dim_size=node_attrs.size(0)) + return V_ZBL.squeeze(-1) + + def __repr__(self): + return f"{self.__class__.__name__}(r_max={self.r_max}, c={self.c})" + + +@compile_mode("script") +class AgnesiTransform(torch.nn.Module): + """ + Agnesi transform see ACEpotentials.jl, JCP 2023, p. 160 + """ + + def __init__( + self, + q: float = 0.9183, + p: float = 4.5791, + a: float = 1.0805, + trainable=False, + ): + super().__init__() + self.register_buffer("q", torch.tensor(q, dtype=torch.get_default_dtype())) + self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype())) + self.register_buffer("a", torch.tensor(a, dtype=torch.get_default_dtype())) + self.register_buffer( + "covalent_radii", + torch.tensor( + ase.data.covalent_radii, + dtype=torch.get_default_dtype(), + ), + ) + if trainable: + self.a = torch.nn.Parameter(torch.tensor(1.0805, requires_grad=True)) + self.q = torch.nn.Parameter(torch.tensor(0.9183, requires_grad=True)) + self.p = torch.nn.Parameter(torch.tensor(4.5791, requires_grad=True)) + + def forward( + self, + x: torch.Tensor, + node_attrs: torch.Tensor, + edge_index: torch.Tensor, + atomic_numbers: torch.Tensor, + ) -> torch.Tensor: + sender = edge_index[0] + receiver = edge_index[1] + node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze( + -1 + ) + Z_u = node_atomic_numbers[sender] + Z_v = node_atomic_numbers[receiver] + r_0 = 0.5 * (self.covalent_radii[Z_u] + self.covalent_radii[Z_v]) + return ( + 1 + (self.a * ((x / r_0) ** self.q) / (1 + (x / r_0) ** (self.q - self.p))) + ) ** (-1) + + def __repr__(self): + return f"{self.__class__.__name__}(a={self.a}, q={self.q}, p={self.p})" + + +@simplify_if_compile +@compile_mode("script") +class SoftTransform(torch.nn.Module): + """ + Soft Transform + """ + + def __init__(self, a: float = 0.2, b: float = 3.0, trainable=False): + super().__init__() + self.register_buffer( + "covalent_radii", + torch.tensor( + ase.data.covalent_radii, + dtype=torch.get_default_dtype(), + ), + ) + if trainable: + self.a = torch.nn.Parameter(torch.tensor(a, requires_grad=True)) + self.b = torch.nn.Parameter(torch.tensor(b, requires_grad=True)) + else: + self.register_buffer("a", torch.tensor(a)) + self.register_buffer("b", torch.tensor(b)) + + def forward( + self, + x: torch.Tensor, + node_attrs: torch.Tensor, + edge_index: torch.Tensor, + atomic_numbers: 
torch.Tensor,
+ ) -> torch.Tensor:
+ sender = edge_index[0]
+ receiver = edge_index[1]
+ node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze(
+ -1
+ )
+ Z_u = node_atomic_numbers[sender]
+ Z_v = node_atomic_numbers[receiver]
+ r_0 = (self.covalent_radii[Z_u] + self.covalent_radii[Z_v]) / 4
+ y = (
+ x
+ + (1 / 2) * torch.tanh(-(x / r_0) - self.a * ((x / r_0) ** self.b))
+ + 1 / 2
+ )
+ return y
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}(a={self.a.item()}, b={self.b.item()})"
diff --git a/hydragnn/utils/mace_utils/modules/symmetric_contraction.py b/hydragnn/utils/mace_utils/modules/symmetric_contraction.py
new file mode 100644
index 000000000..9db75da02
--- /dev/null
+++ b/hydragnn/utils/mace_utils/modules/symmetric_contraction.py
@@ -0,0 +1,233 @@
+###########################################################################################
+# Implementation of the symmetric contraction algorithm presented in the MACE paper
+# (Batatia et al, MACE: Higher Order Equivariant Message Passing Neural Networks for Fast and Accurate Force Fields, Eq. 10 and 11)
+# Authors: Ilyes Batatia
+# This program is distributed under the MIT License (see MIT.md)
+###########################################################################################
+
+from typing import Dict, Optional, Union
+
+import opt_einsum_fx
+import torch
+import torch.fx
+from e3nn import o3
+from e3nn.util.codegen import CodeGenMixin
+from e3nn.util.jit import compile_mode
+
+from hydragnn.utils.mace_utils.tools.cg import U_matrix_real
+
+BATCH_EXAMPLE = 10
+ALPHABET = ["w", "x", "v", "n", "z", "r", "t", "y", "u", "o", "p", "s"]
+
+
+@compile_mode("script")
+class SymmetricContraction(CodeGenMixin, torch.nn.Module):
+ def __init__(
+ self,
+ irreps_in: o3.Irreps,
+ irreps_out: o3.Irreps,
+ correlation: Union[int, Dict[str, int]],
+ irrep_normalization: str = "component",
+ path_normalization: str = "element",
+ internal_weights: Optional[bool] = None,
+ shared_weights: Optional[bool] = None,
+ num_elements: Optional[int] = None,
+ ) -> None:
+ super().__init__()
+
+ if irrep_normalization is None:
+ irrep_normalization = "component"
+
+ if path_normalization is None:
+ path_normalization = "element"
+
+ assert irrep_normalization in ["component", "norm", "none"]
+ assert path_normalization in ["element", "path", "none"]
+
+ self.irreps_in = o3.Irreps(irreps_in)
+ self.irreps_out = o3.Irreps(irreps_out)
+
+ del irreps_in, irreps_out
+
+ if not isinstance(correlation, tuple):
+ corr = correlation
+ correlation = {}
+ for irrep_out in self.irreps_out:
+ correlation[irrep_out] = corr
+
+ assert shared_weights or not internal_weights
+
+ if internal_weights is None:
+ internal_weights = True
+
+ self.internal_weights = internal_weights
+ self.shared_weights = shared_weights
+
+ del internal_weights, shared_weights
+
+ self.contractions = torch.nn.ModuleList()
+ for irrep_out in self.irreps_out:
+ self.contractions.append(
+ Contraction(
+ irreps_in=self.irreps_in,
+ irrep_out=o3.Irreps(str(irrep_out.ir)),
+ correlation=correlation[irrep_out],
+ internal_weights=self.internal_weights,
+ num_elements=num_elements,
+ weights=self.shared_weights,
+ )
+ )
+
+ def forward(self, x: torch.Tensor, y: torch.Tensor):
+ outs = [contraction(x, y) for contraction in self.contractions]
+ return torch.cat(outs, dim=-1)
+
+
+@compile_mode("script")
+class Contraction(torch.nn.Module):
+ def __init__(
+ self,
+ irreps_in: o3.Irreps,
+ irrep_out: o3.Irreps,
+ correlation: int,
+ internal_weights: bool = True,
+ num_elements:
Optional[int] = None, + weights: Optional[torch.Tensor] = None, + ) -> None: + super().__init__() + + self.num_features = irreps_in.count((0, 1)) + self.coupling_irreps = o3.Irreps([irrep.ir for irrep in irreps_in]) + self.correlation = correlation + dtype = torch.get_default_dtype() + for nu in range(1, correlation + 1): + U_matrix = U_matrix_real( + irreps_in=self.coupling_irreps, + irreps_out=irrep_out, + correlation=nu, + dtype=dtype, + )[-1] + self.register_buffer(f"U_matrix_{nu}", U_matrix) + + # Tensor contraction equations + self.contractions_weighting = torch.nn.ModuleList() + self.contractions_features = torch.nn.ModuleList() + + # Create weight for product basis + self.weights = torch.nn.ParameterList([]) + + for i in range(correlation, 0, -1): + # Shapes definying + num_params = self.U_tensors(i).size()[-1] + num_equivariance = 2 * irrep_out.lmax + 1 + num_ell = self.U_tensors(i).size()[-2] + + if i == correlation: + parse_subscript_main = ( + [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1) - 1)] + + ["ik,ekc,bci,be -> bc"] + + [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1) - 1)] + ) + graph_module_main = torch.fx.symbolic_trace( + lambda x, y, w, z: torch.einsum( + "".join(parse_subscript_main), x, y, w, z + ) + ) + + # Optimizing the contractions + self.graph_opt_main = opt_einsum_fx.optimize_einsums_full( + model=graph_module_main, + example_inputs=( + torch.randn( + [num_equivariance] + [num_ell] * i + [num_params] + ).squeeze(0), + torch.randn((num_elements, num_params, self.num_features)), + torch.randn((BATCH_EXAMPLE, self.num_features, num_ell)), + torch.randn((BATCH_EXAMPLE, num_elements)), + ), + ) + # Parameters for the product basis + w = torch.nn.Parameter( + torch.randn((num_elements, num_params, self.num_features)) + / num_params + ) + self.weights_max = w + else: + # Generate optimized contractions equations + parse_subscript_weighting = ( + [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1))] + + ["k,ekc,be->bc"] + + [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1))] + ) + parse_subscript_features = ( + ["bc"] + + [ALPHABET[j] for j in range(i - 1 + min(irrep_out.lmax, 1))] + + ["i,bci->bc"] + + [ALPHABET[j] for j in range(i - 1 + min(irrep_out.lmax, 1))] + ) + + # Symbolic tracing of contractions + graph_module_weighting = torch.fx.symbolic_trace( + lambda x, y, z: torch.einsum( + "".join(parse_subscript_weighting), x, y, z + ) + ) + graph_module_features = torch.fx.symbolic_trace( + lambda x, y: torch.einsum("".join(parse_subscript_features), x, y) + ) + + # Optimizing the contractions + graph_opt_weighting = opt_einsum_fx.optimize_einsums_full( + model=graph_module_weighting, + example_inputs=( + torch.randn( + [num_equivariance] + [num_ell] * i + [num_params] + ).squeeze(0), + torch.randn((num_elements, num_params, self.num_features)), + torch.randn((BATCH_EXAMPLE, num_elements)), + ), + ) + graph_opt_features = opt_einsum_fx.optimize_einsums_full( + model=graph_module_features, + example_inputs=( + torch.randn( + [BATCH_EXAMPLE, self.num_features, num_equivariance] + + [num_ell] * i + ).squeeze(2), + torch.randn((BATCH_EXAMPLE, self.num_features, num_ell)), + ), + ) + self.contractions_weighting.append(graph_opt_weighting) + self.contractions_features.append(graph_opt_features) + # Parameters for the product basis + w = torch.nn.Parameter( + torch.randn((num_elements, num_params, self.num_features)) + / num_params + ) + self.weights.append(w) + if not internal_weights: + self.weights = weights[:-1] + self.weights_max = weights[-1] 
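+
+ # Shape sketch for forward() below (an informal note; names as in __init__,
+ # shapes inferred from the example_inputs used for symbolic tracing above):
+ # U_tensors(nu): coupling tensor with nu copies of num_ell and a trailing num_params axis
+ # weights: [num_elements, num_params, num_features]
+ # x (node features): [n_nodes, num_features, num_ell]
+ # y (element one-hots): [n_nodes, num_elements]
+ # E.g., hypothetically, Contraction(o3.Irreps("2x0e+2x1o"), o3.Irreps("1x0e"),
+ # correlation=2, num_elements=3) would expect x of shape [n_nodes, 2, 4]
+ # and y of shape [n_nodes, 3].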
+
+ def forward(self, x: torch.Tensor, y: torch.Tensor):
+ out = self.graph_opt_main(
+ self.U_tensors(self.correlation),
+ self.weights_max,
+ x,
+ y,
+ )
+ for i, (weight, contract_weights, contract_features) in enumerate(
+ zip(self.weights, self.contractions_weighting, self.contractions_features)
+ ):
+ c_tensor = contract_weights(
+ self.U_tensors(self.correlation - i - 1),
+ weight,
+ y,
+ )
+ c_tensor = c_tensor + out
+ out = contract_features(c_tensor, x)
+
+ return out.view(out.shape[0], -1)
+
+ def U_tensors(self, nu: int):
+ return dict(self.named_buffers())[f"U_matrix_{nu}"]
diff --git a/hydragnn/utils/mace_utils/modules/utils.py b/hydragnn/utils/mace_utils/modules/utils.py
new file mode 100644
index 000000000..37fef1bbd
--- /dev/null
+++ b/hydragnn/utils/mace_utils/modules/utils.py
@@ -0,0 +1,414 @@
+###########################################################################################
+# Utilities
+# Authors: Ilyes Batatia, Gregor Simm and David Kovacs
+# This program is distributed under the MIT License (see MIT.md)
+###########################################################################################
+
+import logging
+from typing import List, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn
+import torch.utils.data
+from scipy.constants import c, e
+
+from hydragnn.utils.mace_utils.tools import to_numpy
+from hydragnn.utils.mace_utils.tools.scatter import scatter_sum
+from hydragnn.utils.mace_utils.tools.torch_geometric.batch import Batch
+
+from .blocks import AtomicEnergiesBlock
+
+
+def compute_forces(
+ energy: torch.Tensor, positions: torch.Tensor, training: bool = True
+) -> torch.Tensor:
+ grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)]
+ gradient = torch.autograd.grad(
+ outputs=[energy], # [n_graphs, ]
+ inputs=[positions], # [n_nodes, 3]
+ grad_outputs=grad_outputs,
+ retain_graph=training, # Make sure the graph is not destroyed during training
+ create_graph=training, # Create graph for second derivative
+ allow_unused=True, # For complete dissociation turn to true
+ )[
+ 0
+ ] # [n_nodes, 3]
+ if gradient is None:
+ return torch.zeros_like(positions)
+ return -1 * gradient
+
+
+def compute_forces_virials(
+ energy: torch.Tensor,
+ positions: torch.Tensor,
+ displacement: torch.Tensor,
+ cell: torch.Tensor,
+ training: bool = True,
+ compute_stress: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+ grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)]
+ forces, virials = torch.autograd.grad(
+ outputs=[energy], # [n_graphs, ]
+ inputs=[positions, displacement], # [n_nodes, 3]
+ grad_outputs=grad_outputs,
+ retain_graph=training, # Make sure the graph is not destroyed during training
+ create_graph=training, # Create graph for second derivative
+ allow_unused=True,
+ )
+ stress = torch.zeros_like(displacement)
+ if compute_stress and virials is not None:
+ cell = cell.view(-1, 3, 3)
+ volume = torch.linalg.det(cell).abs().unsqueeze(-1)
+ stress = virials / volume.view(-1, 1, 1)
+ stress = torch.where(torch.abs(stress) < 1e10, stress, torch.zeros_like(stress))
+ if forces is None:
+ forces = torch.zeros_like(positions)
+ if virials is None:
+ virials = torch.zeros((1, 3, 3))
+
+ return -1 * forces, -1 * virials, stress
+
+
+def get_symmetric_displacement(
+ positions: torch.Tensor,
+ unit_shifts: torch.Tensor,
+ cell: Optional[torch.Tensor],
+ edge_index: torch.Tensor,
+ num_graphs: int,
+ batch: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if cell is None:
+ cell = torch.zeros(
+
num_graphs * 3, + 3, + dtype=positions.dtype, + device=positions.device, + ) + sender = edge_index[0] + displacement = torch.zeros( + (num_graphs, 3, 3), + dtype=positions.dtype, + device=positions.device, + ) + displacement.requires_grad_(True) + symmetric_displacement = 0.5 * ( + displacement + displacement.transpose(-1, -2) + ) # From https://github.com/mir-group/nequip + positions = positions + torch.einsum( + "be,bec->bc", positions, symmetric_displacement[batch] + ) + cell = cell.view(-1, 3, 3) + cell = cell + torch.matmul(cell, symmetric_displacement) + shifts = torch.einsum( + "be,bec->bc", + unit_shifts, + cell[batch[sender]], + ) + return positions, shifts, displacement + + +@torch.jit.unused +def compute_hessians_vmap( + forces: torch.Tensor, + positions: torch.Tensor, +) -> torch.Tensor: + forces_flatten = forces.view(-1) + num_elements = forces_flatten.shape[0] + + def get_vjp(v): + return torch.autograd.grad( + -1 * forces_flatten, + positions, + v, + retain_graph=True, + create_graph=False, + allow_unused=False, + ) + + I_N = torch.eye(num_elements).to(forces.device) + try: + chunk_size = 1 if num_elements < 64 else 16 + gradient = torch.vmap(get_vjp, in_dims=0, out_dims=0, chunk_size=chunk_size)( + I_N + )[0] + except RuntimeError: + gradient = compute_hessians_loop(forces, positions) + if gradient is None: + return torch.zeros((positions.shape[0], forces.shape[0], 3, 3)) + return gradient + + +@torch.jit.unused +def compute_hessians_loop( + forces: torch.Tensor, + positions: torch.Tensor, +) -> torch.Tensor: + + hessian = [] + for grad_elem in forces.view(-1): + hess_row = torch.autograd.grad( + outputs=[-1 * grad_elem], + inputs=[positions], + grad_outputs=torch.ones_like(grad_elem), + retain_graph=True, + create_graph=False, + allow_unused=False, + )[0] + hess_row = hess_row.detach() # this makes it very slow? 
but needs less memory + if hess_row is None: + hessian.append(torch.zeros_like(positions)) + else: + hessian.append(hess_row) + hessian = torch.stack(hessian) + return hessian + + +def get_outputs( + energy: torch.Tensor, + positions: torch.Tensor, + displacement: Optional[torch.Tensor], + cell: torch.Tensor, + training: bool = False, + compute_force: bool = True, + compute_virials: bool = True, + compute_stress: bool = True, + compute_hessian: bool = False, +) -> Tuple[ + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], +]: + if (compute_virials or compute_stress) and displacement is not None: + # forces come for free + forces, virials, stress = compute_forces_virials( + energy=energy, + positions=positions, + displacement=displacement, + cell=cell, + compute_stress=compute_stress, + training=(training or compute_hessian), + ) + elif compute_force: + forces, virials, stress = ( + compute_forces( + energy=energy, + positions=positions, + training=(training or compute_hessian), + ), + None, + None, + ) + else: + forces, virials, stress = (None, None, None) + if compute_hessian: + assert forces is not None, "Forces must be computed to get the hessian" + hessian = compute_hessians_vmap(forces, positions) + else: + hessian = None + return forces, virials, stress, hessian + + +def get_edge_vectors_and_lengths( + positions: torch.Tensor, # [n_nodes, 3] + edge_index: torch.Tensor, # [2, n_edges] + shifts: torch.Tensor, # [n_edges, 3] + normalize: bool = False, + eps: float = 1e-9, +) -> Tuple[torch.Tensor, torch.Tensor]: + sender = edge_index[0] + receiver = edge_index[1] + vectors = positions[receiver] - positions[sender] + shifts # [n_edges, 3] + lengths = torch.linalg.norm(vectors, dim=-1, keepdim=True) # [n_edges, 1] + if normalize: + vectors_normed = vectors / (lengths + eps) + return vectors_normed, lengths + + return vectors, lengths + + +def _check_non_zero(std): + if std == 0.0: + logging.warning( + "Standard deviation of the scaling is zero, Changing to no scaling" + ) + std = 1.0 + return std + + +def extract_invariant(x: torch.Tensor, num_layers: int, num_features: int, l_max: int): + out = [] + for i in range(num_layers - 1): + out.append( + x[ + :, + i + * (l_max + 1) ** 2 + * num_features : (i * (l_max + 1) ** 2 + 1) + * num_features, + ] + ) + out.append(x[:, -num_features:]) + return torch.cat(out, dim=-1) + + +def compute_mean_std_atomic_inter_energy( + data_loader: torch.utils.data.DataLoader, + atomic_energies: np.ndarray, +) -> Tuple[float, float]: + atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies) + + avg_atom_inter_es_list = [] + + for batch in data_loader: + node_e0 = atomic_energies_fn(batch.node_attrs) + graph_e0s = scatter_sum( + src=node_e0, index=batch.batch, dim=-1, dim_size=batch.num_graphs + ) + graph_sizes = batch.ptr[1:] - batch.ptr[:-1] + avg_atom_inter_es_list.append( + (batch.energy - graph_e0s) / graph_sizes + ) # {[n_graphs], } + + avg_atom_inter_es = torch.cat(avg_atom_inter_es_list) # [total_n_graphs] + mean = to_numpy(torch.mean(avg_atom_inter_es)).item() + std = to_numpy(torch.std(avg_atom_inter_es)).item() + std = _check_non_zero(std) + + return mean, std + + +def _compute_mean_std_atomic_inter_energy( + batch: Batch, + atomic_energies_fn: AtomicEnergiesBlock, +) -> Tuple[torch.Tensor, torch.Tensor]: + node_e0 = atomic_energies_fn(batch.node_attrs) + graph_e0s = scatter_sum( + src=node_e0, index=batch.batch, dim=-1, dim_size=batch.num_graphs + ) + graph_sizes = batch.ptr[1:] - 
batch.ptr[:-1] + atom_energies = (batch.energy - graph_e0s) / graph_sizes + return atom_energies + + +def compute_mean_rms_energy_forces( + data_loader: torch.utils.data.DataLoader, + atomic_energies: np.ndarray, +) -> Tuple[float, float]: + atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies) + + atom_energy_list = [] + forces_list = [] + + for batch in data_loader: + node_e0 = atomic_energies_fn(batch.node_attrs) + graph_e0s = scatter_sum( + src=node_e0, index=batch.batch, dim=-1, dim_size=batch.num_graphs + ) + graph_sizes = batch.ptr[1:] - batch.ptr[:-1] + atom_energy_list.append( + (batch.energy - graph_e0s) / graph_sizes + ) # {[n_graphs], } + forces_list.append(batch.forces) # {[n_graphs*n_atoms,3], } + + atom_energies = torch.cat(atom_energy_list, dim=0) # [total_n_graphs] + forces = torch.cat(forces_list, dim=0) # {[total_n_graphs*n_atoms,3], } + + mean = to_numpy(torch.mean(atom_energies)).item() + rms = to_numpy(torch.sqrt(torch.mean(torch.square(forces)))).item() + rms = _check_non_zero(rms) + + return mean, rms + + +def _compute_mean_rms_energy_forces( + batch: Batch, + atomic_energies_fn: AtomicEnergiesBlock, +) -> Tuple[torch.Tensor, torch.Tensor]: + node_e0 = atomic_energies_fn(batch.node_attrs) + graph_e0s = scatter_sum( + src=node_e0, index=batch.batch, dim=-1, dim_size=batch.num_graphs + ) + graph_sizes = batch.ptr[1:] - batch.ptr[:-1] + atom_energies = (batch.energy - graph_e0s) / graph_sizes # {[n_graphs], } + forces = batch.forces # {[n_graphs*n_atoms,3], } + + return atom_energies, forces + + +def compute_avg_num_neighbors(data_loader: torch.utils.data.DataLoader) -> float: + num_neighbors = [] + + for batch in data_loader: + _, receivers = batch.edge_index + _, counts = torch.unique(receivers, return_counts=True) + num_neighbors.append(counts) + + avg_num_neighbors = torch.mean( + torch.cat(num_neighbors, dim=0).type(torch.get_default_dtype()) + ) + return to_numpy(avg_num_neighbors).item() + + +def compute_statistics( + data_loader: torch.utils.data.DataLoader, + atomic_energies: np.ndarray, +) -> Tuple[float, float, float, float]: + atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies) + + atom_energy_list = [] + forces_list = [] + num_neighbors = [] + + for batch in data_loader: + node_e0 = atomic_energies_fn(batch.node_attrs) + graph_e0s = scatter_sum( + src=node_e0, index=batch.batch, dim=-1, dim_size=batch.num_graphs + ) + graph_sizes = batch.ptr[1:] - batch.ptr[:-1] + atom_energy_list.append( + (batch.energy - graph_e0s) / graph_sizes + ) # {[n_graphs], } + forces_list.append(batch.forces) # {[n_graphs*n_atoms,3], } + + _, receivers = batch.edge_index + _, counts = torch.unique(receivers, return_counts=True) + num_neighbors.append(counts) + + atom_energies = torch.cat(atom_energy_list, dim=0) # [total_n_graphs] + forces = torch.cat(forces_list, dim=0) # {[total_n_graphs*n_atoms,3], } + + mean = to_numpy(torch.mean(atom_energies)).item() + rms = to_numpy(torch.sqrt(torch.mean(torch.square(forces)))).item() + + avg_num_neighbors = torch.mean( + torch.cat(num_neighbors, dim=0).type(torch.get_default_dtype()) + ) + + return to_numpy(avg_num_neighbors).item(), mean, rms + + +def compute_rms_dipoles( + data_loader: torch.utils.data.DataLoader, +) -> Tuple[float, float]: + dipoles_list = [] + for batch in data_loader: + dipoles_list.append(batch.dipole) # {[n_graphs,3], } + + dipoles = torch.cat(dipoles_list, dim=0) # {[total_n_graphs,3], } + rms = to_numpy(torch.sqrt(torch.mean(torch.square(dipoles)))).item() + rms = 
_check_non_zero(rms) + return rms + + +def compute_fixed_charge_dipole( + charges: torch.Tensor, + positions: torch.Tensor, + batch: torch.Tensor, + num_graphs: int, +) -> torch.Tensor: + mu = positions * charges.unsqueeze(-1) / (1e-11 / c / e) # [N_atoms,3] + return scatter_sum( + src=mu, index=batch.unsqueeze(-1), dim=0, dim_size=num_graphs + ) # [N_graphs,3] diff --git a/hydragnn/utils/mace_utils/tools/__init__.py b/hydragnn/utils/mace_utils/tools/__init__.py new file mode 100644 index 000000000..54c594550 --- /dev/null +++ b/hydragnn/utils/mace_utils/tools/__init__.py @@ -0,0 +1,72 @@ +from .arg_parser import build_default_arg_parser, build_preprocess_arg_parser +from .arg_parser_tools import check_args +from .cg import U_matrix_real +from .checkpoint import CheckpointHandler, CheckpointIO, CheckpointState +from .finetuning_utils import load_foundations +from .torch_tools import ( + TensorDict, + cartesian_to_spherical, + count_parameters, + init_device, + init_wandb, + set_default_dtype, + set_seeds, + spherical_to_cartesian, + to_numpy, + to_one_hot, + voigt_to_matrix, +) +from .train import SWAContainer, evaluate, train +from .utils import ( + AtomicNumberTable, + MetricsLogger, + atomic_numbers_to_indices, + compute_c, + compute_mae, + compute_q95, + compute_rel_mae, + compute_rel_rmse, + compute_rmse, + get_atomic_number_table_from_zs, + get_optimizer, + get_tag, + setup_logger, +) + +__all__ = [ + "TensorDict", + "AtomicNumberTable", + "atomic_numbers_to_indices", + "to_numpy", + "to_one_hot", + "build_default_arg_parser", + "check_args", + "set_seeds", + "init_device", + "setup_logger", + "get_tag", + "count_parameters", + "get_optimizer", + "MetricsLogger", + "get_atomic_number_table_from_zs", + "train", + "evaluate", + "SWAContainer", + "CheckpointHandler", + "CheckpointIO", + "CheckpointState", + "set_default_dtype", + "compute_mae", + "compute_rel_mae", + "compute_rmse", + "compute_rel_rmse", + "compute_q95", + "compute_c", + "U_matrix_real", + "spherical_to_cartesian", + "cartesian_to_spherical", + "voigt_to_matrix", + "init_wandb", + "load_foundations", + "build_preprocess_arg_parser", +] diff --git a/hydragnn/utils/mace_utils/tools/arg_parser.py b/hydragnn/utils/mace_utils/tools/arg_parser.py new file mode 100644 index 000000000..2b0e2b56e --- /dev/null +++ b/hydragnn/utils/mace_utils/tools/arg_parser.py @@ -0,0 +1,792 @@ +########################################################################################### +# Parsing functionalities +# Authors: Ilyes Batatia, Gregor Simm, David Kovacs +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import argparse +import os +from typing import Optional + + +def build_default_arg_parser() -> argparse.ArgumentParser: + try: + import configargparse + + parser = configargparse.ArgumentParser( + config_file_parser_class=configargparse.YAMLConfigFileParser, + ) + parser.add( + "--config", + type=str, + is_config_file=True, + help="config file to agregate options", + ) + except ImportError: + parser = argparse.ArgumentParser() + + # Name and seed + parser.add_argument("--name", help="experiment name", required=True) + parser.add_argument("--seed", help="random seed", type=int, default=123) + + # Directories + parser.add_argument( + "--work_dir", + help="set directory for all files and folders", + type=str, + default=".", + ) + parser.add_argument( + "--log_dir", help="directory for log files", type=str, default=None + ) + 
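+ # Note: directory options left at None are filled in relative to --work_dir
+ # by check_args() in arg_parser_tools.py.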
parser.add_argument( + "--model_dir", help="directory for final model", type=str, default=None + ) + parser.add_argument( + "--checkpoints_dir", + help="directory for checkpoint files", + type=str, + default=None, + ) + parser.add_argument( + "--results_dir", help="directory for results", type=str, default=None + ) + parser.add_argument( + "--downloads_dir", help="directory for downloads", type=str, default=None + ) + + # Device and logging + parser.add_argument( + "--device", + help="select device", + type=str, + choices=["cpu", "cuda", "mps"], + default="cpu", + ) + parser.add_argument( + "--default_dtype", + help="set default dtype", + type=str, + choices=["float32", "float64"], + default="float64", + ) + parser.add_argument( + "--distributed", + help="train in multi-GPU data parallel mode", + action="store_true", + default=False, + ) + parser.add_argument("--log_level", help="log level", type=str, default="INFO") + + parser.add_argument( + "--error_table", + help="Type of error table produced at the end of the training", + type=str, + choices=[ + "PerAtomRMSE", + "TotalRMSE", + "PerAtomRMSEstressvirials", + "PerAtomMAEstressvirials", + "PerAtomMAE", + "TotalMAE", + "DipoleRMSE", + "DipoleMAE", + "EnergyDipoleRMSE", + ], + default="PerAtomRMSE", + ) + + # Model + parser.add_argument( + "--model", + help="model type", + default="MACE", + choices=[ + "BOTNet", + "MACE", + "ScaleShiftMACE", + "ScaleShiftBOTNet", + "AtomicDipolesMACE", + "EnergyDipolesMACE", + ], + ) + parser.add_argument( + "--r_max", help="distance cutoff (in Ang)", type=float, default=5.0 + ) + parser.add_argument( + "--radial_type", + help="type of radial basis functions", + type=str, + default="bessel", + choices=["bessel", "gaussian", "chebyshev"], + ) + parser.add_argument( + "--num_radial_basis", + help="number of radial basis functions", + type=int, + default=8, + ) + parser.add_argument( + "--num_cutoff_basis", + help="number of basis functions for smooth cutoff", + type=int, + default=5, + ) + parser.add_argument( + "--pair_repulsion", + help="use pair repulsion term with ZBL potential", + action="store_true", + default=False, + ) + parser.add_argument( + "--distance_transform", + help="use distance transform for radial basis functions", + default="None", + choices=["None", "Agnesi", "Soft"], + ) + parser.add_argument( + "--interaction", + help="name of interaction block", + type=str, + default="RealAgnosticResidualInteractionBlock", + choices=[ + "RealAgnosticResidualInteractionBlock", + "RealAgnosticAttResidualInteractionBlock", + "RealAgnosticInteractionBlock", + ], + ) + parser.add_argument( + "--interaction_first", + help="name of interaction block", + type=str, + default="RealAgnosticResidualInteractionBlock", + choices=[ + "RealAgnosticResidualInteractionBlock", + "RealAgnosticInteractionBlock", + ], + ) + parser.add_argument( + "--max_ell", help=r"highest \ell of spherical harmonics", type=int, default=3 + ) + parser.add_argument( + "--correlation", help="correlation order at each layer", type=int, default=3 + ) + parser.add_argument( + "--num_interactions", help="number of interactions", type=int, default=2 + ) + parser.add_argument( + "--MLP_irreps", + help="hidden irreps of the MLP in last readout", + type=str, + default="16x0e", + ) + parser.add_argument( + "--radial_MLP", + help="width of the radial MLP", + type=str, + default="[64, 64, 64]", + ) + parser.add_argument( + "--hidden_irreps", + help="irreps for hidden node states", + type=str, + default=None, + ) + # add option to specify irreps by 
channel number and max L + parser.add_argument( + "--num_channels", + help="number of embedding channels", + type=int, + default=None, + ) + parser.add_argument( + "--max_L", + help="max L equivariance of the message", + type=int, + default=None, + ) + parser.add_argument( + "--gate", + help="non linearity for last readout", + type=str, + default="silu", + choices=["silu", "tanh", "abs", "None"], + ) + parser.add_argument( + "--scaling", + help="type of scaling to the output", + type=str, + default="rms_forces_scaling", + choices=["std_scaling", "rms_forces_scaling", "no_scaling"], + ) + parser.add_argument( + "--avg_num_neighbors", + help="normalization factor for the message", + type=float, + default=1, + ) + parser.add_argument( + "--compute_avg_num_neighbors", + help="normalization factor for the message", + type=bool, + default=True, + ) + parser.add_argument( + "--compute_stress", + help="Select True to compute stress", + type=bool, + default=False, + ) + parser.add_argument( + "--compute_forces", + help="Select True to compute forces", + type=bool, + default=True, + ) + + # Dataset + parser.add_argument( + "--train_file", + help="Training set file, format is .xyz or .h5", + type=str, + required=True, + ) + parser.add_argument( + "--valid_file", + help="Validation set .xyz or .h5 file", + default=None, + type=str, + required=False, + ) + parser.add_argument( + "--valid_fraction", + help="Fraction of training set used for validation", + type=float, + default=0.1, + required=False, + ) + parser.add_argument( + "--test_file", + help="Test set .xyz pt .h5 file", + type=str, + ) + parser.add_argument( + "--test_dir", + help="Path to directory with test files named as test_*.h5", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--multi_processed_test", + help="Boolean value for whether the test data was multiprocessed", + type=bool, + default=False, + required=False, + ) + parser.add_argument( + "--num_workers", + help="Number of workers for data loading", + type=int, + default=0, + ) + parser.add_argument( + "--pin_memory", + help="Pin memory for data loading", + default=True, + type=bool, + ) + parser.add_argument( + "--atomic_numbers", + help="List of atomic numbers", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--mean", + help="Mean energy per atom of training set", + type=float, + default=None, + required=False, + ) + parser.add_argument( + "--std", + help="Standard deviation of force components in the training set", + type=float, + default=None, + required=False, + ) + parser.add_argument( + "--statistics_file", + help="json file containing statistics of training set", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--E0s", + help="Dictionary of isolated atom energies", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--keep_isolated_atoms", + help="Keep isolated atoms in the dataset, useful for transfer learning", + type=bool, + default=False, + ) + parser.add_argument( + "--energy_key", + help="Key of reference energies in training xyz", + type=str, + default="REF_energy", + ) + parser.add_argument( + "--forces_key", + help="Key of reference forces in training xyz", + type=str, + default="REF_forces", + ) + parser.add_argument( + "--virials_key", + help="Key of reference virials in training xyz", + type=str, + default="REF_virials", + ) + parser.add_argument( + "--stress_key", + help="Key of reference stress in training xyz", + type=str, + default="REF_stress", + ) + 
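+ # Note: the *_key options name fields of the training file; for extxyz input
+ # a frame would carry, e.g., REF_energy=-76.4 (an illustrative value only)
+ # in its comment line, with per-atom REF_forces columns alongside.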
parser.add_argument( + "--dipole_key", + help="Key of reference dipoles in training xyz", + type=str, + default="REF_dipole", + ) + parser.add_argument( + "--charges_key", + help="Key of atomic charges in training xyz", + type=str, + default="REF_charges", + ) + + # Loss and optimization + parser.add_argument( + "--loss", + help="type of loss", + default="weighted", + choices=[ + "ef", + "weighted", + "forces_only", + "virials", + "stress", + "dipole", + "huber", + "universal", + "energy_forces_dipole", + ], + ) + parser.add_argument( + "--forces_weight", help="weight of forces loss", type=float, default=100.0 + ) + parser.add_argument( + "--swa_forces_weight", + "--stage_two_forces_weight", + help="weight of forces loss after starting Stage Two (previously called swa)", + type=float, + default=100.0, + dest="swa_forces_weight", + ) + parser.add_argument( + "--energy_weight", help="weight of energy loss", type=float, default=1.0 + ) + parser.add_argument( + "--swa_energy_weight", + "--stage_two_energy_weight", + help="weight of energy loss after starting Stage Two (previously called swa)", + type=float, + default=1000.0, + dest="swa_energy_weight", + ) + parser.add_argument( + "--virials_weight", help="weight of virials loss", type=float, default=1.0 + ) + parser.add_argument( + "--swa_virials_weight", + "--stage_two_virials_weight", + help="weight of virials loss after starting Stage Two (previously called swa)", + type=float, + default=10.0, + dest="swa_virials_weight", + ) + parser.add_argument( + "--stress_weight", help="weight of virials loss", type=float, default=1.0 + ) + parser.add_argument( + "--swa_stress_weight", + "--stage_two_stress_weight", + help="weight of stress loss after starting Stage Two (previously called swa)", + type=float, + default=10.0, + dest="swa_stress_weight", + ) + parser.add_argument( + "--dipole_weight", help="weight of dipoles loss", type=float, default=1.0 + ) + parser.add_argument( + "--swa_dipole_weight", + "--stage_two_dipole_weight", + help="weight of dipoles after starting Stage Two (previously called swa)", + type=float, + default=1.0, + dest="swa_dipole_weight", + ) + parser.add_argument( + "--config_type_weights", + help="String of dictionary containing the weights for each config type", + type=str, + default='{"Default":1.0}', + ) + parser.add_argument( + "--huber_delta", + help="delta parameter for huber loss", + type=float, + default=0.01, + ) + parser.add_argument( + "--optimizer", + help="Optimizer for parameter optimization", + type=str, + default="adam", + choices=["adam", "adamw", "schedulefree"], + ) + parser.add_argument( + "--beta", + help="Beta parameter for the optimizer", + type=float, + default=0.9, + ) + parser.add_argument("--batch_size", help="batch size", type=int, default=10) + parser.add_argument( + "--valid_batch_size", help="Validation batch size", type=int, default=10 + ) + parser.add_argument( + "--lr", help="Learning rate of optimizer", type=float, default=0.01 + ) + parser.add_argument( + "--swa_lr", + "--stage_two_lr", + help="Learning rate of optimizer in Stage Two (previously called swa)", + type=float, + default=1e-3, + dest="swa_lr", + ) + parser.add_argument( + "--weight_decay", help="weight decay (L2 penalty)", type=float, default=5e-7 + ) + parser.add_argument( + "--amsgrad", + help="use amsgrad variant of optimizer", + action="store_true", + default=True, + ) + parser.add_argument( + "--scheduler", help="Type of scheduler", type=str, default="ReduceLROnPlateau" + ) + parser.add_argument( + "--lr_factor", 
help="Learning rate factor", type=float, default=0.8 + ) + parser.add_argument( + "--scheduler_patience", help="Learning rate factor", type=int, default=50 + ) + parser.add_argument( + "--lr_scheduler_gamma", + help="Gamma of learning rate scheduler", + type=float, + default=0.9993, + ) + parser.add_argument( + "--swa", + "--stage_two", + help="use Stage Two loss weight, which decreases the learning rate and increases the energy weight at the end of the training to help converge them", + action="store_true", + default=False, + dest="swa", + ) + parser.add_argument( + "--start_swa", + "--start_stage_two", + help="Number of epochs before changing to Stage Two loss weights", + type=int, + default=None, + dest="start_swa", + ) + parser.add_argument( + "--ema", + help="use Exponential Moving Average", + action="store_true", + default=False, + ) + parser.add_argument( + "--ema_decay", + help="Exponential Moving Average decay", + type=float, + default=0.99, + ) + parser.add_argument( + "--max_num_epochs", help="Maximum number of epochs", type=int, default=2048 + ) + parser.add_argument( + "--patience", + help="Maximum number of consecutive epochs of increasing loss", + type=int, + default=2048, + ) + parser.add_argument( + "--foundation_model", + help="Path to the foundation model for transfer learning", + type=str, + default=None, + ) + parser.add_argument( + "--foundation_model_readout", + help="Use readout of foundation model for transfer learning", + action="store_false", + default=True, + ) + parser.add_argument( + "--eval_interval", help="evaluate model every epochs", type=int, default=1 + ) + parser.add_argument( + "--keep_checkpoints", + help="keep all checkpoints", + action="store_true", + default=False, + ) + parser.add_argument( + "--save_all_checkpoints", + help="save all checkpoints", + action="store_true", + default=False, + ) + parser.add_argument( + "--restart_latest", + help="restart optimizer from latest checkpoint", + action="store_true", + default=False, + ) + parser.add_argument( + "--save_cpu", + help="Save a model to be loaded on cpu", + action="store_true", + default=False, + ) + parser.add_argument( + "--clip_grad", + help="Gradient Clipping Value", + type=check_float_or_none, + default=10.0, + ) + # options for using Weights and Biases for experiment tracking + # to install see https://wandb.ai + parser.add_argument( + "--wandb", + help="Use Weights and Biases for experiment tracking", + action="store_true", + default=False, + ) + parser.add_argument( + "--wandb_dir", + help="An absolute path to a directory where Weights and Biases metadata will be stored", + type=str, + default=None, + ) + parser.add_argument( + "--wandb_project", + help="Weights and Biases project name", + type=str, + default="", + ) + parser.add_argument( + "--wandb_entity", + help="Weights and Biases entity name", + type=str, + default="", + ) + parser.add_argument( + "--wandb_name", + help="Weights and Biases experiment name", + type=str, + default="", + ) + parser.add_argument( + "--wandb_log_hypers", + help="The hyperparameters to log in Weights and Biases", + type=list, + default=[ + "num_channels", + "max_L", + "correlation", + "lr", + "swa_lr", + "weight_decay", + "batch_size", + "max_num_epochs", + "start_swa", + "energy_weight", + "forces_weight", + ], + ) + return parser + + +def build_preprocess_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument( + "--train_file", + help="Training set h5 file", + type=str, + default=None, + required=True, + ) 
+ parser.add_argument( + "--valid_file", + help="Training set xyz file", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--num_process", + help="The user defined number of processes to use, as well as the number of files created.", + type=int, + default=int(os.cpu_count() / 4), + ) + parser.add_argument( + "--valid_fraction", + help="Fraction of training set used for validation", + type=float, + default=0.1, + required=False, + ) + parser.add_argument( + "--test_file", + help="Test set xyz file", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--h5_prefix", + help="Prefix for h5 files when saving", + type=str, + default="", + ) + parser.add_argument( + "--r_max", help="distance cutoff (in Ang)", type=float, default=5.0 + ) + parser.add_argument( + "--config_type_weights", + help="String of dictionary containing the weights for each config type", + type=str, + default='{"Default":1.0}', + ) + parser.add_argument( + "--energy_key", + help="Key of reference energies in training xyz", + type=str, + default="REF_energy", + ) + parser.add_argument( + "--forces_key", + help="Key of reference forces in training xyz", + type=str, + default="REF_forces", + ) + parser.add_argument( + "--virials_key", + help="Key of reference virials in training xyz", + type=str, + default="REF_virials", + ) + parser.add_argument( + "--stress_key", + help="Key of reference stress in training xyz", + type=str, + default="REF_stress", + ) + parser.add_argument( + "--dipole_key", + help="Key of reference dipoles in training xyz", + type=str, + default="REF_dipole", + ) + parser.add_argument( + "--charges_key", + help="Key of atomic charges in training xyz", + type=str, + default="REF_charges", + ) + parser.add_argument( + "--atomic_numbers", + help="List of atomic numbers", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--compute_statistics", + help="Compute statistics for the dataset", + action="store_true", + default=False, + ) + parser.add_argument( + "--batch_size", + help="batch size to compute average number of neighbours", + type=int, + default=16, + ) + + parser.add_argument( + "--scaling", + help="type of scaling to the output", + type=str, + default="rms_forces_scaling", + choices=["std_scaling", "rms_forces_scaling", "no_scaling"], + ) + parser.add_argument( + "--E0s", + help="Dictionary of isolated atom energies", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--shuffle", + help="Shuffle the training dataset", + type=bool, + default=True, + ) + parser.add_argument( + "--seed", + help="Random seed for splitting training and validation sets", + type=int, + default=123, + ) + return parser + + +def check_float_or_none(value: str) -> Optional[float]: + try: + return float(value) + except ValueError: + if value != "None": + raise argparse.ArgumentTypeError( + f"{value} is an invalid value (float or None)" + ) from None + return None diff --git a/hydragnn/utils/mace_utils/tools/arg_parser_tools.py b/hydragnn/utils/mace_utils/tools/arg_parser_tools.py new file mode 100644 index 000000000..da64806a3 --- /dev/null +++ b/hydragnn/utils/mace_utils/tools/arg_parser_tools.py @@ -0,0 +1,113 @@ +import logging +import os + +from e3nn import o3 + + +def check_args(args): + """ + Check input arguments, update them if necessary for valid and consistent inputs, and return a tuple containing + the (potentially) modified args and a list of log messages. 
+ """ + log_messages = [] + + # Directories + # Use work_dir for all other directories as well, unless they were specified by the user + if args.log_dir is None: + args.log_dir = os.path.join(args.work_dir, "logs") + if args.model_dir is None: + args.model_dir = args.work_dir + if args.checkpoints_dir is None: + args.checkpoints_dir = os.path.join(args.work_dir, "checkpoints") + if args.results_dir is None: + args.results_dir = os.path.join(args.work_dir, "results") + if args.downloads_dir is None: + args.downloads_dir = os.path.join(args.work_dir, "downloads") + + # Model + # Check if hidden_irreps, num_channels and max_L are consistent + if args.hidden_irreps is None and args.num_channels is None and args.max_L is None: + args.hidden_irreps, args.num_channels, args.max_L = "128x0e + 128x1o", 128, 1 + elif ( + args.hidden_irreps is not None + and args.num_channels is not None + and args.max_L is not None + ): + args.hidden_irreps = o3.Irreps( + (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) + .sort() + .irreps.simplify() + ) + log_messages.append( + ( + "All of hidden_irreps, num_channels and max_L are specified", + logging.WARNING, + ) + ) + log_messages.append( + ( + f"Using num_channels and max_L to create hidden_irreps: {args.hidden_irreps}.", + logging.WARNING, + ) + ) + assert ( + len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 + ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" + elif args.num_channels is not None and args.max_L is not None: + assert args.num_channels > 0, "num_channels must be positive integer" + assert args.max_L >= 0, "max_L must be non-negative integer" + args.hidden_irreps = o3.Irreps( + (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) + .sort() + .irreps.simplify() + ) + assert ( + len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 + ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" + elif args.hidden_irreps is not None: + assert ( + len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 + ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" + + args.num_channels = list( + {irrep.mul for irrep in o3.Irreps(args.hidden_irreps)} + )[0] + args.max_L = o3.Irreps(args.hidden_irreps).lmax + elif args.max_L is not None and args.num_channels is None: + assert args.max_L >= 0, "max_L must be non-negative integer" + args.num_channels = 128 + args.hidden_irreps = o3.Irreps( + (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) + .sort() + .irreps.simplify() + ) + elif args.max_L is None and args.num_channels is not None: + assert args.num_channels > 0, "num_channels must be positive integer" + args.max_L = 1 + args.hidden_irreps = o3.Irreps( + (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) + .sort() + .irreps.simplify() + ) + + # Loss and optimization + # Check Stage Two loss start + if args.swa: + if args.start_swa is None: + args.start_swa = max(1, args.max_num_epochs // 4 * 3) + if args.start_swa > args.max_num_epochs: + log_messages.append( + ( + f"start_stage_two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}", + logging.WARNING, + ) + ) + log_messages.append( + ( + "Stage Two will not start, as start_stage_two > max_num_epochs", + logging.WARNING, + ) + ) + args.swa = 
False + + return args, log_messages diff --git a/hydragnn/utils/mace_utils/tools/cg.py b/hydragnn/utils/mace_utils/tools/cg.py new file mode 100644 index 000000000..2cca09c94 --- /dev/null +++ b/hydragnn/utils/mace_utils/tools/cg.py @@ -0,0 +1,131 @@ +########################################################################################### +# Higher Order Real Clebsch Gordan (based on e3nn by Mario Geiger) +# Authors: Ilyes Batatia +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import collections +from typing import List, Union + +import torch +from e3nn import o3 + +_TP = collections.namedtuple("_TP", "op, args") +_INPUT = collections.namedtuple("_INPUT", "tensor, start, stop") + + +def _wigner_nj( + irrepss: List[o3.Irreps], + normalization: str = "component", + filter_ir_mid=None, + dtype=None, +): + irrepss = [o3.Irreps(irreps) for irreps in irrepss] + if filter_ir_mid is not None: + filter_ir_mid = [o3.Irrep(ir) for ir in filter_ir_mid] + + if len(irrepss) == 1: + (irreps,) = irrepss + ret = [] + e = torch.eye(irreps.dim, dtype=dtype) + i = 0 + for mul, ir in irreps: + for _ in range(mul): + sl = slice(i, i + ir.dim) + ret += [(ir, _INPUT(0, sl.start, sl.stop), e[sl])] + i += ir.dim + return ret + + *irrepss_left, irreps_right = irrepss + ret = [] + for ir_left, path_left, C_left in _wigner_nj( + irrepss_left, + normalization=normalization, + filter_ir_mid=filter_ir_mid, + dtype=dtype, + ): + i = 0 + for mul, ir in irreps_right: + for ir_out in ir_left * ir: + if filter_ir_mid is not None and ir_out not in filter_ir_mid: + continue + + C = o3.wigner_3j(ir_out.l, ir_left.l, ir.l, dtype=dtype) + if normalization == "component": + C *= ir_out.dim**0.5 + if normalization == "norm": + C *= ir_left.dim**0.5 * ir.dim**0.5 + + C = torch.einsum("jk,ijl->ikl", C_left.flatten(1), C) + C = C.reshape( + ir_out.dim, *(irreps.dim for irreps in irrepss_left), ir.dim + ) + for u in range(mul): + E = torch.zeros( + ir_out.dim, + *(irreps.dim for irreps in irrepss_left), + irreps_right.dim, + dtype=dtype, + ) + sl = slice(i + u * ir.dim, i + (u + 1) * ir.dim) + E[..., sl] = C + ret += [ + ( + ir_out, + _TP( + op=(ir_left, ir, ir_out), + args=( + path_left, + _INPUT(len(irrepss_left), sl.start, sl.stop), + ), + ), + E, + ) + ] + i += mul * ir.dim + return sorted(ret, key=lambda x: x[0]) + + +def U_matrix_real( + irreps_in: Union[str, o3.Irreps], + irreps_out: Union[str, o3.Irreps], + correlation: int, + normalization: str = "component", + filter_ir_mid=None, + dtype=None, +): + irreps_out = o3.Irreps(irreps_out) + irrepss = [o3.Irreps(irreps_in)] * correlation + if correlation == 4: + filter_ir_mid = [ + (0, 1), + (1, -1), + (2, 1), + (3, -1), + (4, 1), + (5, -1), + (6, 1), + (7, -1), + (8, 1), + (9, -1), + (10, 1), + (11, -1), + ] + wigners = _wigner_nj(irrepss, normalization, filter_ir_mid, dtype) + current_ir = wigners[0][0] + out = [] + stack = torch.tensor([]) + + for ir, _, base_o3 in wigners: + if ir in irreps_out and ir == current_ir: + stack = torch.cat((stack, base_o3.squeeze().unsqueeze(-1)), dim=-1) + last_ir = current_ir + elif ir in irreps_out and ir != current_ir: + if len(stack) != 0: + out += [last_ir, stack] + stack = base_o3.squeeze().unsqueeze(-1) + current_ir, last_ir = ir, ir + else: + current_ir = ir + out += [last_ir, stack] + return out diff --git a/hydragnn/utils/mace_utils/tools/checkpoint.py b/hydragnn/utils/mace_utils/tools/checkpoint.py new file mode 100644 
index 000000000..8a62f1f27 --- /dev/null +++ b/hydragnn/utils/mace_utils/tools/checkpoint.py @@ -0,0 +1,227 @@ +########################################################################################### +# Checkpointing +# Authors: Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import dataclasses +import logging +import os +import re +from typing import Dict, List, Optional, Tuple + +import torch + +from .torch_tools import TensorDict + +Checkpoint = Dict[str, TensorDict] + + +@dataclasses.dataclass +class CheckpointState: + model: torch.nn.Module + optimizer: torch.optim.Optimizer + lr_scheduler: torch.optim.lr_scheduler.ExponentialLR + + +class CheckpointBuilder: + @staticmethod + def create_checkpoint(state: CheckpointState) -> Checkpoint: + return { + "model": state.model.state_dict(), + "optimizer": state.optimizer.state_dict(), + "lr_scheduler": state.lr_scheduler.state_dict(), + } + + @staticmethod + def load_checkpoint( + state: CheckpointState, checkpoint: Checkpoint, strict: bool + ) -> None: + state.model.load_state_dict(checkpoint["model"], strict=strict) # type: ignore + state.optimizer.load_state_dict(checkpoint["optimizer"]) + state.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) + + +@dataclasses.dataclass +class CheckpointPathInfo: + path: str + tag: str + epochs: int + swa: bool + + +class CheckpointIO: + def __init__( + self, directory: str, tag: str, keep: bool = False, swa_start: int = None + ) -> None: + self.directory = directory + self.tag = tag + self.keep = keep + self.old_path: Optional[str] = None + self.swa_start = swa_start + + self._epochs_string = "_epoch-" + self._filename_extension = "pt" + + def _get_checkpoint_filename(self, epochs: int, swa_start=None) -> str: + if swa_start is not None and epochs > swa_start: + return ( + self.tag + + self._epochs_string + + str(epochs) + + "_swa" + + "." + + self._filename_extension + ) + return ( + self.tag + + self._epochs_string + + str(epochs) + + "." 
+            + self._filename_extension
+        )
+
+    def _list_file_paths(self) -> List[str]:
+        if not os.path.isdir(self.directory):
+            return []
+        all_paths = [
+            os.path.join(self.directory, f) for f in os.listdir(self.directory)
+        ]
+        return [path for path in all_paths if os.path.isfile(path)]
+
+    def _parse_checkpoint_path(self, path: str) -> Optional[CheckpointPathInfo]:
+        filename = os.path.basename(path)
+        regex = re.compile(
+            rf"^(?P<tag>.+){self._epochs_string}(?P<epochs>\d+)\.{self._filename_extension}$"
+        )
+        regex2 = re.compile(
+            rf"^(?P<tag>.+){self._epochs_string}(?P<epochs>\d+)_swa\.{self._filename_extension}$"
+        )
+        match = regex.match(filename)
+        match2 = regex2.match(filename)
+        swa = False
+        if not match:
+            if not match2:
+                return None
+            match = match2
+            swa = True
+
+        return CheckpointPathInfo(
+            path=path,
+            tag=match.group("tag"),
+            epochs=int(match.group("epochs")),
+            swa=swa,
+        )
+
+    def _get_latest_checkpoint_path(self, swa) -> Optional[str]:
+        all_file_paths = self._list_file_paths()
+        checkpoint_info_list = [
+            self._parse_checkpoint_path(path) for path in all_file_paths
+        ]
+        selected_checkpoint_info_list = [
+            info for info in checkpoint_info_list if info and info.tag == self.tag
+        ]
+
+        if len(selected_checkpoint_info_list) == 0:
+            logging.warning(
+                f"Cannot find checkpoint with tag '{self.tag}' in '{self.directory}'"
+            )
+            return None
+
+        selected_checkpoint_info_list_swa = []
+        selected_checkpoint_info_list_no_swa = []
+
+        for ckp in selected_checkpoint_info_list:
+            if ckp.swa:
+                selected_checkpoint_info_list_swa.append(ckp)
+            else:
+                selected_checkpoint_info_list_no_swa.append(ckp)
+        if swa:
+            try:
+                latest_checkpoint_info = max(
+                    selected_checkpoint_info_list_swa, key=lambda info: info.epochs
+                )
+            except ValueError:
+                logging.warning(
+                    "No SWA checkpoint found, while SWA is enabled. Compare the swa_start parameter and the latest checkpoint."
+                )
+                return None
+        else:
+            latest_checkpoint_info = max(
+                selected_checkpoint_info_list_no_swa, key=lambda info: info.epochs
+            )
+        return latest_checkpoint_info.path
+
+    def save(
+        self, checkpoint: Checkpoint, epochs: int, keep_last: bool = False
+    ) -> None:
+        if not self.keep and self.old_path and not keep_last:
+            logging.debug(f"Deleting old checkpoint file: {self.old_path}")
+            os.remove(self.old_path)
+
+        filename = self._get_checkpoint_filename(epochs, self.swa_start)
+        path = os.path.join(self.directory, filename)
+        logging.debug(f"Saving checkpoint: {path}")
+        os.makedirs(self.directory, exist_ok=True)
+        torch.save(obj=checkpoint, f=path)
+        self.old_path = path
+
+    def load_latest(
+        self, swa: Optional[bool] = False, device: Optional[torch.device] = None
+    ) -> Optional[Tuple[Checkpoint, int]]:
+        path = self._get_latest_checkpoint_path(swa=swa)
+        if path is None:
+            return None
+
+        return self.load(path, device=device)
+
+    def load(
+        self, path: str, device: Optional[torch.device] = None
+    ) -> Tuple[Checkpoint, int]:
+        checkpoint_info = self._parse_checkpoint_path(path)
+
+        if checkpoint_info is None:
+            raise RuntimeError(f"Cannot find path '{path}'")
+
+        logging.info(f"Loading checkpoint: {checkpoint_info.path}")
+        return (
+            torch.load(f=checkpoint_info.path, map_location=device),
+            checkpoint_info.epochs,
+        )
+
+
+class CheckpointHandler:
+    def __init__(self, *args, **kwargs) -> None:
+        self.io = CheckpointIO(*args, **kwargs)
+        self.builder = CheckpointBuilder()
+
+    def save(
+        self, state: CheckpointState, epochs: int, keep_last: bool = False
+    ) -> None:
+        checkpoint = self.builder.create_checkpoint(state)
+        self.io.save(checkpoint, epochs, keep_last)
+
+    def load_latest(
+        self,
+        state: CheckpointState,
+        swa: Optional[bool] = False,
+        device: Optional[torch.device] = None,
+        strict=False,
+    ) -> Optional[int]:
+        result = self.io.load_latest(swa=swa, device=device)
+        if result is None:
+            return None
+
+        checkpoint, epochs = result
+        self.builder.load_checkpoint(state=state, checkpoint=checkpoint, strict=strict)
+        return epochs
+
+    def load(
+        self,
+        state: CheckpointState,
+        path: str,
+        strict=False,
+        device: Optional[torch.device] = None,
+    ) -> int:
+        checkpoint, epochs = self.io.load(path, device=device)
+        self.builder.load_checkpoint(state=state, checkpoint=checkpoint, strict=strict)
+        return epochs
diff --git a/hydragnn/utils/mace_utils/tools/compile.py b/hydragnn/utils/mace_utils/tools/compile.py
new file mode 100644
index 000000000..425e4c02d
--- /dev/null
+++ b/hydragnn/utils/mace_utils/tools/compile.py
@@ -0,0 +1,95 @@
+from contextlib import contextmanager
+from functools import wraps
+from typing import Callable, Tuple
+
+try:
+    import torch._dynamo as dynamo
+except ImportError:
+    dynamo = None
+from e3nn import get_optimization_defaults, set_optimization_defaults
+from torch import autograd, nn
+from torch.fx import symbolic_trace
+
+ModuleFactory = Callable[..., nn.Module]
+TypeTuple = Tuple[type, ...]
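+
+# Example usage (an illustrative sketch, not code from this patch; the model
+# factory name is an assumption): `prepare`, defined below, wraps a module
+# factory so that e3nn script codegen is disabled and registered submodules
+# are fx-traced, which makes the resulting module friendlier to torch.compile:
+#
+#     model_fn = prepare(ScaleShiftMACE)   # hypothetical module factory
+#     model = model_fn(**model_config)
+#     compiled_model = torch.compile(model)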
+
+
+@contextmanager
+def disable_e3nn_codegen():
+    """Context manager that disables the legacy PyTorch code generation used in e3nn."""
+    init_val = get_optimization_defaults()["jit_script_fx"]
+    set_optimization_defaults(jit_script_fx=False)
+    yield
+    set_optimization_defaults(jit_script_fx=init_val)
+
+
+def prepare(func: ModuleFactory, allow_autograd: bool = True) -> ModuleFactory:
+    """Function transform that prepares a MACE module for torch.compile
+
+    Args:
+        func (ModuleFactory): A function that creates an nn.Module
+        allow_autograd (bool, optional): Force inductor compiler to inline call to
+            `torch.autograd.grad`. Defaults to True.
+
+    Returns:
+        ModuleFactory: Decorated function that creates a torch.compile compatible module
+    """
+    if allow_autograd:
+        dynamo.allow_in_graph(autograd.grad)
+    elif dynamo.allowed_functions.is_allowed(autograd.grad):
+        dynamo.disallow_in_graph(autograd.grad)
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        with disable_e3nn_codegen():
+            model = func(*args, **kwargs)
+
+        model = simplify(model)
+        return model
+
+    return wrapper
+
+
+_SIMPLIFY_REGISTRY = set()
+
+
+def simplify_if_compile(module: nn.Module) -> nn.Module:
+    """Decorator to register a module for symbolic simplification
+
+    The decorated module will be simplified using `torch.fx.symbolic_trace`.
+    This constrains the module to not have any dynamic control flow, see:
+
+    https://pytorch.org/docs/stable/fx.html#limitations-of-symbolic-tracing
+
+    Args:
+        module (nn.Module): the module to register
+
+    Returns:
+        nn.Module: registered module
+    """
+    _SIMPLIFY_REGISTRY.add(module)
+    return module
+
+
+def simplify(module: nn.Module) -> nn.Module:
+    """Recursively searches for registered modules to simplify with
+    `torch.fx.symbolic_trace` to support compiling with the PyTorch Dynamo compiler.
+
+    Modules are registered with the `simplify_if_compile` decorator.
+
+    Args:
+        module (nn.Module): the module to simplify
+
+    Returns:
+        nn.Module: the simplified module
+    """
+    simplify_types = tuple(_SIMPLIFY_REGISTRY)
+
+    for name, child in module.named_children():
+        if isinstance(child, simplify_types):
+            traced = symbolic_trace(child)
+            setattr(module, name, traced)
+        else:
+            simplify(child)
+
+    return module
diff --git a/hydragnn/utils/mace_utils/tools/finetuning_utils.py b/hydragnn/utils/mace_utils/tools/finetuning_utils.py
new file mode 100644
index 000000000..0aad091ba
--- /dev/null
+++ b/hydragnn/utils/mace_utils/tools/finetuning_utils.py
@@ -0,0 +1,149 @@
+import torch
+
+from hydragnn.utils.mace_utils.tools.utils import AtomicNumberTable
+
+
+def load_foundations(
+    model: torch.nn.Module,
+    model_foundations: torch.nn.Module,
+    table: AtomicNumberTable,
+    load_readout=False,
+    use_shift=False,
+    use_scale=True,
+    max_L=2,
+):
+    """
+    Load the foundations of a model into a model for fine-tuning.
+ """ + assert model_foundations.r_max == model.r_max + z_table = AtomicNumberTable([int(z) for z in model_foundations.atomic_numbers]) + new_z_table = table + num_species_foundations = len(z_table.zs) + num_channels_foundation = ( + model_foundations.node_embedding.linear.weight.shape[0] + // num_species_foundations + ) + indices_weights = [z_table.z_to_index(z) for z in new_z_table.zs] + num_radial = model.radial_embedding.out_dim + num_species = len(indices_weights) + max_ell = model.spherical_harmonics._lmax # pylint: disable=protected-access + model.node_embedding.linear.weight = torch.nn.Parameter( + model_foundations.node_embedding.linear.weight.view( + num_species_foundations, -1 + )[indices_weights, :] + .flatten() + .clone() + / (num_species_foundations / num_species) ** 0.5 + ) + if model.radial_embedding.bessel_fn.__class__.__name__ == "BesselBasis": + model.radial_embedding.bessel_fn.bessel_weights = torch.nn.Parameter( + model_foundations.radial_embedding.bessel_fn.bessel_weights.clone() + ) + + for i in range(int(model.num_interactions)): + model.interactions[i].linear_up.weight = torch.nn.Parameter( + model_foundations.interactions[i].linear_up.weight.clone() + ) + model.interactions[i].avg_num_neighbors = model_foundations.interactions[ + i + ].avg_num_neighbors + for j in range(4): # Assuming 4 layers in conv_tp_weights, + layer_name = f"layer{j}" + if j == 0: + getattr(model.interactions[i].conv_tp_weights, layer_name).weight = ( + torch.nn.Parameter( + getattr( + model_foundations.interactions[i].conv_tp_weights, + layer_name, + ) + .weight[:num_radial, :] + .clone() + ) + ) + else: + getattr(model.interactions[i].conv_tp_weights, layer_name).weight = ( + torch.nn.Parameter( + getattr( + model_foundations.interactions[i].conv_tp_weights, + layer_name, + ).weight.clone() + ) + ) + + model.interactions[i].linear.weight = torch.nn.Parameter( + model_foundations.interactions[i].linear.weight.clone() + ) + if ( + model.interactions[i].__class__.__name__ + == "RealAgnosticResidualInteractionBlock" + ): + model.interactions[i].skip_tp.weight = torch.nn.Parameter( + model_foundations.interactions[i] + .skip_tp.weight.reshape( + num_channels_foundation, + num_species_foundations, + num_channels_foundation, + )[:, indices_weights, :] + .flatten() + .clone() + / (num_species_foundations / num_species) ** 0.5 + ) + else: + model.interactions[i].skip_tp.weight = torch.nn.Parameter( + model_foundations.interactions[i] + .skip_tp.weight.reshape( + num_channels_foundation, + (max_ell + 1), + num_species_foundations, + num_channels_foundation, + )[:, :, indices_weights, :] + .flatten() + .clone() + / (num_species_foundations / num_species) ** 0.5 + ) + # Transferring products + for i in range(2): # Assuming 2 products modules + max_range = max_L + 1 if i == 0 else 1 + for j in range(max_range): # Assuming 3 contractions in symmetric_contractions + model.products[i].symmetric_contractions.contractions[j].weights_max = ( + torch.nn.Parameter( + model_foundations.products[i] + .symmetric_contractions.contractions[j] + .weights_max[indices_weights, :, :] + .clone() + ) + ) + + for k in range(2): # Assuming 2 weights in each contraction + model.products[i].symmetric_contractions.contractions[j].weights[k] = ( + torch.nn.Parameter( + model_foundations.products[i] + .symmetric_contractions.contractions[j] + .weights[k][indices_weights, :, :] + .clone() + ) + ) + + model.products[i].linear.weight = torch.nn.Parameter( + model_foundations.products[i].linear.weight.clone() + ) + + if 
load_readout: + # Transferring readouts + model.readouts[0].linear.weight = torch.nn.Parameter( + model_foundations.readouts[0].linear.weight.clone() + ) + + model.readouts[1].linear_1.weight = torch.nn.Parameter( + model_foundations.readouts[1].linear_1.weight.clone() + ) + + model.readouts[1].linear_2.weight = torch.nn.Parameter( + model_foundations.readouts[1].linear_2.weight.clone() + ) + if model_foundations.scale_shift is not None: + if use_scale: + model.scale_shift.scale = model_foundations.scale_shift.scale.clone() + if use_shift: + model.scale_shift.shift = model_foundations.scale_shift.shift.clone() + return model diff --git a/hydragnn/utils/mace_utils/tools/scatter.py b/hydragnn/utils/mace_utils/tools/scatter.py new file mode 100644 index 000000000..7e1139a99 --- /dev/null +++ b/hydragnn/utils/mace_utils/tools/scatter.py @@ -0,0 +1,112 @@ +"""basic scatter_sum operations from torch_scatter from +https://github.com/mir-group/pytorch_runstats/blob/main/torch_runstats/scatter_sum.py +Using code from https://github.com/rusty1s/pytorch_scatter, but cut down to avoid a dependency. +PyTorch plans to move these features into the main repo, but until then, +to make installation simpler, we need this pure python set of wrappers +that don't require installing PyTorch C++ extensions. +See https://github.com/pytorch/pytorch/issues/63780. +""" + +from typing import Optional + +import torch + + +def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): + if dim < 0: + dim = other.dim() + dim + if src.dim() == 1: + for _ in range(0, dim): + src = src.unsqueeze(0) + for _ in range(src.dim(), other.dim()): + src = src.unsqueeze(-1) + src = src.expand_as(other) + return src + + +def scatter_sum( + src: torch.Tensor, + index: torch.Tensor, + dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None, + reduce: str = "sum", +) -> torch.Tensor: + assert reduce == "sum" # for now, TODO + index = _broadcast(index, src, dim) + if out is None: + size = list(src.size()) + if dim_size is not None: + size[dim] = dim_size + elif index.numel() == 0: + size[dim] = 0 + else: + size[dim] = int(index.max()) + 1 + out = torch.zeros(size, dtype=src.dtype, device=src.device) + return out.scatter_add_(dim, index, src) + else: + return out.scatter_add_(dim, index, src) + + +def scatter_std( + src: torch.Tensor, + index: torch.Tensor, + dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None, + unbiased: bool = True, +) -> torch.Tensor: + if out is not None: + dim_size = out.size(dim) + + if dim < 0: + dim = src.dim() + dim + + count_dim = dim + if index.dim() <= dim: + count_dim = index.dim() - 1 + + ones = torch.ones(index.size(), dtype=src.dtype, device=src.device) + count = scatter_sum(ones, index, count_dim, dim_size=dim_size) + + index = _broadcast(index, src, dim) + tmp = scatter_sum(src, index, dim, dim_size=dim_size) + count = _broadcast(count, tmp, dim).clamp(1) + mean = tmp.div(count) + + var = src - mean.gather(dim, index) + var = var * var + out = scatter_sum(var, index, dim, out, dim_size) + + if unbiased: + count = count.sub(1).clamp_(1) + out = out.div(count + 1e-6).sqrt() + + return out + + +def scatter_mean( + src: torch.Tensor, + index: torch.Tensor, + dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None, +) -> torch.Tensor: + out = scatter_sum(src, index, dim, out, dim_size) + dim_size = out.size(dim) + + index_dim = dim + if index_dim < 0: + index_dim = index_dim + src.dim() + if index.dim() <= 
index_dim:
+        index_dim = index.dim() - 1
+
+    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
+    count = scatter_sum(ones, index, index_dim, None, dim_size)
+    count[count < 1] = 1
+    count = _broadcast(count, out, dim)
+    if out.is_floating_point():
+        out.true_divide_(count)
+    else:
+        out.div_(count, rounding_mode="floor")
+    return out
diff --git a/hydragnn/utils/mace_utils/tools/scripts_utils.py b/hydragnn/utils/mace_utils/tools/scripts_utils.py
new file mode 100644
index 000000000..27455944b
--- /dev/null
+++ b/hydragnn/utils/mace_utils/tools/scripts_utils.py
@@ -0,0 +1,653 @@
+###########################################################################################
+# Training utils
+# Authors: David Kovacs, Ilyes Batatia
+# This program is distributed under the MIT License (see MIT.md)
+###########################################################################################
+
+import ast
+import dataclasses
+import json
+import logging
+import os
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.distributed
+from e3nn import o3
+from prettytable import PrettyTable
+
+from hydragnn.utils.mace_utils import data, modules
+from hydragnn.utils.mace_utils.tools import evaluate
+
+
+@dataclasses.dataclass
+class SubsetCollection:
+    train: data.Configurations
+    valid: data.Configurations
+    tests: List[Tuple[str, data.Configurations]]
+
+
+def get_dataset_from_xyz(
+    work_dir: str,
+    train_path: str,
+    valid_path: str,
+    valid_fraction: float,
+    config_type_weights: Dict,
+    test_path: str = None,
+    seed: int = 1234,
+    keep_isolated_atoms: bool = False,
+    energy_key: str = "REF_energy",
+    forces_key: str = "REF_forces",
+    stress_key: str = "REF_stress",
+    virials_key: str = "virials",
+    dipole_key: str = "dipoles",
+    charges_key: str = "charges",
+) -> Tuple[SubsetCollection, Optional[Dict[int, float]]]:
+    """Load training and test dataset from xyz file"""
+    atomic_energies_dict, all_train_configs = data.load_from_xyz(
+        file_path=train_path,
+        config_type_weights=config_type_weights,
+        energy_key=energy_key,
+        forces_key=forces_key,
+        stress_key=stress_key,
+        virials_key=virials_key,
+        dipole_key=dipole_key,
+        charges_key=charges_key,
+        extract_atomic_energies=True,
+        keep_isolated_atoms=keep_isolated_atoms,
+    )
+    logging.info(
+        f"Training set [{len(all_train_configs)} configs, {np.sum([1 if config.energy else 0 for config in all_train_configs])} energy, {np.sum([config.forces.size for config in all_train_configs])} forces] loaded from '{train_path}'"
+    )
+    if valid_path is not None:
+        _, valid_configs = data.load_from_xyz(
+            file_path=valid_path,
+            config_type_weights=config_type_weights,
+            energy_key=energy_key,
+            forces_key=forces_key,
+            stress_key=stress_key,
+            virials_key=virials_key,
+            dipole_key=dipole_key,
+            charges_key=charges_key,
+            extract_atomic_energies=False,
+        )
+        logging.info(
+            f"Validation set [{len(valid_configs)} configs, {np.sum([1 if config.energy else 0 for config in valid_configs])} energy, {np.sum([config.forces.size for config in valid_configs])} forces] loaded from '{valid_path}'"
+        )
+        train_configs = all_train_configs
+    else:
+        train_configs, valid_configs = data.random_train_valid_split(
+            all_train_configs, valid_fraction, seed, work_dir
+        )
+        logging.info(
+            f"Validation set contains {len(valid_configs)} configurations [{np.sum([1 if config.energy else 0 for config in valid_configs])} energy, {np.sum([config.forces.size for config in valid_configs])} forces]"
+        )
+
+    test_configs = []
+    if test_path is not None:
+        _,
all_test_configs = data.load_from_xyz( + file_path=test_path, + config_type_weights=config_type_weights, + energy_key=energy_key, + forces_key=forces_key, + dipole_key=dipole_key, + stress_key=stress_key, + virials_key=virials_key, + charges_key=charges_key, + extract_atomic_energies=False, + ) + # create list of tuples (config_type, list(Atoms)) + test_configs = data.test_config_types(all_test_configs) + logging.info( + f"Test set ({len(all_test_configs)} configs) loaded from '{test_path}':" + ) + for name, tmp_configs in test_configs: + logging.info( + f"{name}: {len(tmp_configs)} configs, {np.sum([1 if config.energy else 0 for config in tmp_configs])} energy, {np.sum([config.forces.size for config in tmp_configs])} forces" + ) + + return ( + SubsetCollection(train=train_configs, valid=valid_configs, tests=test_configs), + atomic_energies_dict, + ) + + +def get_config_type_weights(ct_weights): + """ + Get config type weights from command line argument + """ + try: + config_type_weights = ast.literal_eval(ct_weights) + assert isinstance(config_type_weights, dict) + except Exception as e: # pylint: disable=W0703 + logging.warning( + f"Config type weights not specified correctly ({e}), using Default" + ) + config_type_weights = {"Default": 1.0} + return config_type_weights + + +def print_git_commit(): + try: + import git + + repo = git.Repo(search_parent_directories=True) + commit = repo.head.commit.hexsha + logging.debug(f"Current Git commit: {commit}") + return commit + except Exception as e: # pylint: disable=W0703 + logging.debug(f"Error accessing Git repository: {e}") + return "None" + + +def extract_config_mace_model(model: torch.nn.Module) -> Dict[str, Any]: + if model.__class__.__name__ != "ScaleShiftMACE": + return {"error": "Model is not a ScaleShiftMACE model"} + + def radial_to_name(radial_type): + if radial_type == "BesselBasis": + return "bessel" + if radial_type == "GaussianBasis": + return "gaussian" + if radial_type == "ChebychevBasis": + return "chebyshev" + return radial_type + + def radial_to_transform(radial): + if not hasattr(radial, "distance_transform"): + return None + if radial.distance_transform.__class__.__name__ == "AgnesiTransform": + return "Agnesi" + if radial.distance_transform.__class__.__name__ == "SoftTransform": + return "Soft" + return radial.distance_transform.__class__.__name__ + + config = { + "r_max": model.r_max.item(), + "num_bessel": len(model.radial_embedding.bessel_fn.bessel_weights), + "num_polynomial_cutoff": model.radial_embedding.cutoff_fn.p.item(), + "max_ell": model.spherical_harmonics._lmax, # pylint: disable=protected-access + "interaction_cls": model.interactions[-1].__class__, + "interaction_cls_first": model.interactions[0].__class__, + "num_interactions": model.num_interactions.item(), + "num_elements": len(model.atomic_numbers), + "hidden_irreps": o3.Irreps(str(model.products[0].linear.irreps_out)), + "MLP_irreps": ( + o3.Irreps(str(model.readouts[-1].hidden_irreps)) + if model.num_interactions.item() > 1 + else 1 + ), + "gate": ( + model.readouts[-1] # pylint: disable=protected-access + .non_linearity._modules["acts"][0] + .f + if model.num_interactions.item() > 1 + else None + ), + "atomic_energies": model.atomic_energies_fn.atomic_energies.cpu().numpy(), + "avg_num_neighbors": model.interactions[0].avg_num_neighbors, + "atomic_numbers": model.atomic_numbers, + "correlation": len( + model.products[0].symmetric_contractions.contractions[0].weights + ) + + 1, + "radial_type": radial_to_name( + 
model.radial_embedding.bessel_fn.__class__.__name__
+        ),
+        "radial_MLP": model.interactions[0].conv_tp_weights.hs[1:-1],
+        "pair_repulsion": hasattr(model, "pair_repulsion_fn"),
+        "distance_transform": radial_to_transform(model.radial_embedding),
+        "atomic_inter_scale": model.scale_shift.scale.item(),
+        "atomic_inter_shift": model.scale_shift.shift.item(),
+    }
+    return config
+
+
+def extract_load(f: str, map_location: str = "cpu") -> torch.nn.Module:
+    model = torch.load(f=f, map_location=map_location)
+    model_copy = model.__class__(**extract_config_mace_model(model))
+    model_copy.load_state_dict(model.state_dict())
+    return model_copy.to(map_location)
+
+
+def extract_model(model: torch.nn.Module, map_location: str = "cpu") -> torch.nn.Module:
+    model_copy = model.__class__(**extract_config_mace_model(model))
+    model_copy.load_state_dict(model.state_dict())
+    return model_copy.to(map_location)
+
+
+def convert_to_json_format(dict_input):
+    for key, value in dict_input.items():
+        if isinstance(value, (np.ndarray, torch.Tensor)):
+            dict_input[key] = value.tolist()
+        # check if the value is a class and convert it to a string
+        elif hasattr(value, "__class__"):
+            dict_input[key] = str(value)
+    return dict_input
+
+
+def convert_from_json_format(dict_input):
+    dict_output = dict_input.copy()
+    if (
+        dict_input["interaction_cls"]
+        == "<class 'mace.modules.blocks.RealAgnosticResidualInteractionBlock'>"
+    ):
+        dict_output["interaction_cls"] = (
+            modules.blocks.RealAgnosticResidualInteractionBlock
+        )
+    if (
+        dict_input["interaction_cls"]
+        == "<class 'mace.modules.blocks.RealAgnosticInteractionBlock'>"
+    ):
+        dict_output["interaction_cls"] = modules.blocks.RealAgnosticInteractionBlock
+    if (
+        dict_input["interaction_cls_first"]
+        == "<class 'mace.modules.blocks.RealAgnosticResidualInteractionBlock'>"
+    ):
+        dict_output["interaction_cls_first"] = (
+            modules.blocks.RealAgnosticResidualInteractionBlock
+        )
+    if (
+        dict_input["interaction_cls_first"]
+        == "<class 'mace.modules.blocks.RealAgnosticInteractionBlock'>"
+    ):
+        dict_output["interaction_cls_first"] = (
+            modules.blocks.RealAgnosticInteractionBlock
+        )
+    dict_output["r_max"] = float(dict_input["r_max"])
+    dict_output["num_bessel"] = int(dict_input["num_bessel"])
+    dict_output["num_polynomial_cutoff"] = float(dict_input["num_polynomial_cutoff"])
+    dict_output["max_ell"] = int(dict_input["max_ell"])
+    dict_output["num_interactions"] = int(dict_input["num_interactions"])
+    dict_output["num_elements"] = int(dict_input["num_elements"])
+    dict_output["hidden_irreps"] = o3.Irreps(dict_input["hidden_irreps"])
+    dict_output["MLP_irreps"] = o3.Irreps(dict_input["MLP_irreps"])
+    dict_output["avg_num_neighbors"] = float(dict_input["avg_num_neighbors"])
+    dict_output["gate"] = torch.nn.functional.silu
+    dict_output["atomic_energies"] = np.array(dict_input["atomic_energies"])
+    dict_output["atomic_numbers"] = dict_input["atomic_numbers"]
+    dict_output["correlation"] = int(dict_input["correlation"])
+    dict_output["radial_type"] = dict_input["radial_type"]
+    dict_output["radial_MLP"] = ast.literal_eval(dict_input["radial_MLP"])
+    dict_output["pair_repulsion"] = ast.literal_eval(dict_input["pair_repulsion"])
+    dict_output["distance_transform"] = dict_input["distance_transform"]
+    dict_output["atomic_inter_scale"] = float(dict_input["atomic_inter_scale"])
+    dict_output["atomic_inter_shift"] = float(dict_input["atomic_inter_shift"])
+
+    return dict_output
+
+
+def load_from_json(f: str, map_location: str = "cpu") -> torch.nn.Module:
+    extra_files_extract = {"commit.txt": None, "config.json": None}
+    model_jit_load = torch.jit.load(
+        f, _extra_files=extra_files_extract, map_location=map_location
+    )
+    model_load_yaml = modules.ScaleShiftMACE(
**convert_from_json_format(json.loads(extra_files_extract["config.json"]))
+    )
+    model_load_yaml.load_state_dict(model_jit_load.state_dict())
+    return model_load_yaml.to(map_location)
+
+
+def get_atomic_energies(E0s, train_collection, z_table) -> dict:
+    if E0s is not None:
+        logging.info(
+            "Isolated Atomic Energies (E0s) not in training file, using command line argument"
+        )
+        if E0s.lower() == "average":
+            logging.info(
+                "Computing average Atomic Energies using least squares regression"
+            )
+            # catch if collections.train not defined above
+            try:
+                assert train_collection is not None
+                atomic_energies_dict = data.compute_average_E0s(
+                    train_collection, z_table
+                )
+            except Exception as e:
+                raise RuntimeError(
+                    f"Could not compute average E0s if no training xyz given, error {e} occurred"
+                ) from e
+        else:
+            if E0s.endswith(".json"):
+                logging.info(f"Loading atomic energies from {E0s}")
+                with open(E0s, "r", encoding="utf-8") as f:
+                    atomic_energies_dict = json.load(f)
+            else:
+                try:
+                    atomic_energies_dict = ast.literal_eval(E0s)
+                    assert isinstance(atomic_energies_dict, dict)
+                except Exception as e:
+                    raise RuntimeError(
+                        f"Invalid E0s specification, error {e} occurred"
+                    ) from e
+    else:
+        raise RuntimeError(
+            "E0s not found in training file and not specified in command line"
+        )
+    return atomic_energies_dict
+
+
+def get_loss_fn(
+    loss: str,
+    energy_weight: float,
+    forces_weight: float,
+    stress_weight: float,
+    virials_weight: float,
+    dipole_weight: float,
+    dipole_only: bool,
+    compute_dipole: bool,
+) -> torch.nn.Module:
+    if loss == "weighted":
+        loss_fn = modules.WeightedEnergyForcesLoss(
+            energy_weight=energy_weight, forces_weight=forces_weight
+        )
+    elif loss == "forces_only":
+        loss_fn = modules.WeightedForcesLoss(forces_weight=forces_weight)
+    elif loss == "virials":
+        loss_fn = modules.WeightedEnergyForcesVirialsLoss(
+            energy_weight=energy_weight,
+            forces_weight=forces_weight,
+            virials_weight=virials_weight,
+        )
+    elif loss == "stress":
+        loss_fn = modules.WeightedEnergyForcesStressLoss(
+            energy_weight=energy_weight,
+            forces_weight=forces_weight,
+            stress_weight=stress_weight,
+        )
+    elif loss == "dipole":
+        assert (
+            dipole_only is True
+        ), "dipole loss can only be used with AtomicDipolesMACE model"
+        loss_fn = modules.DipoleSingleLoss(
+            dipole_weight=dipole_weight,
+        )
+    elif loss == "energy_forces_dipole":
+        assert dipole_only is False and compute_dipole is True
+        loss_fn = modules.WeightedEnergyForcesDipoleLoss(
+            energy_weight=energy_weight,
+            forces_weight=forces_weight,
+            dipole_weight=dipole_weight,
+        )
+    else:
+        loss_fn = modules.EnergyForcesLoss(
+            energy_weight=energy_weight, forces_weight=forces_weight
+        )
+    return loss_fn
+
+
+def get_files_with_suffix(dir_path: str, suffix: str) -> List[str]:
+    return [
+        os.path.join(dir_path, f) for f in os.listdir(dir_path) if f.endswith(suffix)
+    ]
+
+
+def custom_key(key):
+    """
+    Helper function to sort the keys of the data loader dictionary
+    to ensure that the training set and validation set
+    are evaluated first
+    """
+    if key == "train":
+        return (0, key)
+    if key == "valid":
+        return (1, key)
+    return (2, key)
+
+
+class LRScheduler:
+    def __init__(self, optimizer, args) -> None:
+        self.scheduler = args.scheduler
+        self._optimizer_type = (
+            args.optimizer
+        )  # Schedulefree does not need a scheduler, but the checkpoint handler expects one.
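+        # ExponentialLR below decays the learning rate once per epoch, while
+        # ReduceLROnPlateau adjusts it based on the validation metric passed
+        # to `step(metrics=...)`.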
+ if args.scheduler == "ExponentialLR": + self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR( + optimizer=optimizer, gamma=args.lr_scheduler_gamma + ) + elif args.scheduler == "ReduceLROnPlateau": + self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer=optimizer, + factor=args.lr_factor, + patience=args.scheduler_patience, + ) + else: + raise RuntimeError(f"Unknown scheduler: '{args.scheduler}'") + + def step(self, metrics=None, epoch=None): # pylint: disable=E1123 + if self._optimizer_type == "schedulefree": + return # In principle, schedulefree optimizer can be used with a scheduler but the paper suggests it's not necessary + if self.scheduler == "ExponentialLR": + self.lr_scheduler.step(epoch=epoch) + elif self.scheduler == "ReduceLROnPlateau": + self.lr_scheduler.step( # pylint: disable=E1123 + metrics=metrics, epoch=epoch + ) + + def __getattr__(self, name): + if name == "step": + return self.step + return getattr(self.lr_scheduler, name) + + +def create_error_table( + table_type: str, + all_data_loaders: dict, + model: torch.nn.Module, + loss_fn: torch.nn.Module, + output_args: Dict[str, bool], + log_wandb: bool, + device: str, + distributed: bool = False, +) -> PrettyTable: + if log_wandb: + import wandb + table = PrettyTable() + if table_type == "TotalRMSE": + table.field_names = [ + "config_type", + "RMSE E / meV", + "RMSE F / meV / A", + "relative F RMSE %", + ] + elif table_type == "PerAtomRMSE": + table.field_names = [ + "config_type", + "RMSE E / meV / atom", + "RMSE F / meV / A", + "relative F RMSE %", + ] + elif table_type == "PerAtomRMSEstressvirials": + table.field_names = [ + "config_type", + "RMSE E / meV / atom", + "RMSE F / meV / A", + "relative F RMSE %", + "RMSE Stress (Virials) / meV / A (A^3)", + ] + elif table_type == "PerAtomMAEstressvirials": + table.field_names = [ + "config_type", + "MAE E / meV / atom", + "MAE F / meV / A", + "relative F MAE %", + "MAE Stress (Virials) / meV / A (A^3)", + ] + elif table_type == "TotalMAE": + table.field_names = [ + "config_type", + "MAE E / meV", + "MAE F / meV / A", + "relative F MAE %", + ] + elif table_type == "PerAtomMAE": + table.field_names = [ + "config_type", + "MAE E / meV / atom", + "MAE F / meV / A", + "relative F MAE %", + ] + elif table_type == "DipoleRMSE": + table.field_names = [ + "config_type", + "RMSE MU / mDebye / atom", + "relative MU RMSE %", + ] + elif table_type == "DipoleMAE": + table.field_names = [ + "config_type", + "MAE MU / mDebye / atom", + "relative MU MAE %", + ] + elif table_type == "EnergyDipoleRMSE": + table.field_names = [ + "config_type", + "RMSE E / meV / atom", + "RMSE F / meV / A", + "rel F RMSE %", + "RMSE MU / mDebye / atom", + "rel MU RMSE %", + ] + + for name in sorted(all_data_loaders, key=custom_key): + data_loader = all_data_loaders[name] + logging.info(f"Evaluating {name} ...") + _, metrics = evaluate( + model, + loss_fn=loss_fn, + data_loader=data_loader, + output_args=output_args, + device=device, + ) + if distributed: + torch.distributed.barrier() + + del data_loader + torch.cuda.empty_cache() + if log_wandb: + wandb_log_dict = { + name + + "_final_rmse_e_per_atom": metrics["rmse_e_per_atom"] + * 1e3, # meV / atom + name + "_final_rmse_f": metrics["rmse_f"] * 1e3, # meV / A + name + "_final_rel_rmse_f": metrics["rel_rmse_f"], + } + wandb.log(wandb_log_dict) + if table_type == "TotalRMSE": + table.add_row( + [ + name, + f"{metrics['rmse_e'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.2f}", + ] + ) + elif 
table_type == "PerAtomRMSE": + table.add_row( + [ + name, + f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.2f}", + ] + ) + elif ( + table_type == "PerAtomRMSEstressvirials" + and metrics["rmse_stress"] is not None + ): + table.add_row( + [ + name, + f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.2f}", + f"{metrics['rmse_stress'] * 1000:8.1f}", + ] + ) + elif ( + table_type == "PerAtomRMSEstressvirials" + and metrics["rmse_virials"] is not None + ): + table.add_row( + [ + name, + f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.2f}", + f"{metrics['rmse_virials'] * 1000:8.1f}", + ] + ) + elif ( + table_type == "PerAtomMAEstressvirials" + and metrics["mae_stress"] is not None + ): + table.add_row( + [ + name, + f"{metrics['mae_e_per_atom'] * 1000:8.1f}", + f"{metrics['mae_f'] * 1000:8.1f}", + f"{metrics['rel_mae_f']:8.2f}", + f"{metrics['mae_stress'] * 1000:8.1f}", + ] + ) + elif ( + table_type == "PerAtomMAEstressvirials" + and metrics["mae_virials"] is not None + ): + table.add_row( + [ + name, + f"{metrics['mae_e_per_atom'] * 1000:8.1f}", + f"{metrics['mae_f'] * 1000:8.1f}", + f"{metrics['rel_mae_f']:8.2f}", + f"{metrics['mae_virials'] * 1000:8.1f}", + ] + ) + elif table_type == "TotalMAE": + table.add_row( + [ + name, + f"{metrics['mae_e'] * 1000:8.1f}", + f"{metrics['mae_f'] * 1000:8.1f}", + f"{metrics['rel_mae_f']:8.2f}", + ] + ) + elif table_type == "PerAtomMAE": + table.add_row( + [ + name, + f"{metrics['mae_e_per_atom'] * 1000:8.1f}", + f"{metrics['mae_f'] * 1000:8.1f}", + f"{metrics['rel_mae_f']:8.2f}", + ] + ) + elif table_type == "DipoleRMSE": + table.add_row( + [ + name, + f"{metrics['rmse_mu_per_atom'] * 1000:8.2f}", + f"{metrics['rel_rmse_mu']:8.1f}", + ] + ) + elif table_type == "DipoleMAE": + table.add_row( + [ + name, + f"{metrics['mae_mu_per_atom'] * 1000:8.2f}", + f"{metrics['rel_mae_mu']:8.1f}", + ] + ) + elif table_type == "EnergyDipoleRMSE": + table.add_row( + [ + name, + f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.1f}", + f"{metrics['rmse_mu_per_atom'] * 1000:8.1f}", + f"{metrics['rel_rmse_mu']:8.1f}", + ] + ) + return table diff --git a/hydragnn/utils/mace_utils/tools/slurm_distributed.py b/hydragnn/utils/mace_utils/tools/slurm_distributed.py new file mode 100644 index 000000000..78de52a1b --- /dev/null +++ b/hydragnn/utils/mace_utils/tools/slurm_distributed.py @@ -0,0 +1,34 @@ +########################################################################################### +# Slurm environment setup for distributed training. 
+# This code is refactored from rsarm's contribution at:
+# https://github.com/Lumi-supercomputer/lumi-reframe-tests/blob/main/checks/apps/deeplearning/pytorch/src/pt_distr_env.py
+# This program is distributed under the MIT License (see MIT.md)
+###########################################################################################
+
+import os
+
+import hostlist
+
+
+class DistributedEnvironment:
+    def __init__(self):
+        self._setup_distr_env()
+        self.master_addr = os.environ["MASTER_ADDR"]
+        self.master_port = os.environ["MASTER_PORT"]
+        self.world_size = int(os.environ["WORLD_SIZE"])
+        self.local_rank = int(os.environ["LOCAL_RANK"])
+        self.rank = int(os.environ["RANK"])
+
+    def _setup_distr_env(self):
+        hostname = hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"])[0]
+        os.environ["MASTER_ADDR"] = hostname
+        os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "33333")
+        os.environ["WORLD_SIZE"] = os.environ.get(
+            "SLURM_NTASKS",
+            str(
+                int(os.environ["SLURM_NTASKS_PER_NODE"])
+                * int(os.environ["SLURM_NNODES"])
+            ),
+        )
+        os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"]
+        os.environ["RANK"] = os.environ["SLURM_PROCID"]
diff --git a/hydragnn/utils/mace_utils/tools/torch_geometric/README.md b/hydragnn/utils/mace_utils/tools/torch_geometric/README.md
new file mode 100644
index 000000000..261ebbbc7
--- /dev/null
+++ b/hydragnn/utils/mace_utils/tools/torch_geometric/README.md
@@ -0,0 +1,12 @@
+# Trimmed-down `pytorch_geometric`
+
+MACE uses the [`pytorch_geometric`](https://pytorch-geometric.readthedocs.io/en/latest/) framework [1, 2]. However, we only use a very limited subset of that library: the most basic graph data structures.
+
+We follow the same approach as NequIP (https://github.com/mir-group/nequip/tree/main/nequip) and copy their code here.
+
+To avoid adding a large number of unnecessary second-degree dependencies, and to simplify installation, we include and modify here the small subset of `torch_geometric` that is necessary for our code.
+
+We are grateful to the developers of PyTorch Geometric for their ongoing and very useful work on graph learning with PyTorch.
+
+[1] Fey, M., & Lenssen, J. E. (2019). Fast Graph Representation Learning with PyTorch Geometric (Version 2.0.1) [Computer software]. https://github.com/pyg-team/pytorch_geometric
+[2] https://arxiv.org/abs/1903.02428 diff --git a/hydragnn/utils/mace_utils/tools/torch_geometric/__init__.py b/hydragnn/utils/mace_utils/tools/torch_geometric/__init__.py new file mode 100644 index 000000000..486f0d09d --- /dev/null +++ b/hydragnn/utils/mace_utils/tools/torch_geometric/__init__.py @@ -0,0 +1,7 @@ +from .batch import Batch +from .data import Data +from .dataloader import DataLoader +from .dataset import Dataset +from .seed import seed_everything + +__all__ = ["Batch", "Data", "Dataset", "DataLoader", "seed_everything"] diff --git a/hydragnn/utils/mace_utils/tools/torch_geometric/batch.py b/hydragnn/utils/mace_utils/tools/torch_geometric/batch.py new file mode 100644 index 000000000..be5ec9d0c --- /dev/null +++ b/hydragnn/utils/mace_utils/tools/torch_geometric/batch.py @@ -0,0 +1,257 @@ +from collections.abc import Sequence +from typing import List + +import numpy as np +import torch +from torch import Tensor + +from .data import Data +from .dataset import IndexType + + +class Batch(Data): + r"""A plain old python object modeling a batch of graphs as one big + (disconnected) graph. With :class:`torch_geometric.data.Data` being the + base class, all its methods can also be used here. + In addition, single graphs can be reconstructed via the assignment vector + :obj:`batch`, which maps each node to its respective graph identifier. + """ + + def __init__(self, batch=None, ptr=None, **kwargs): + super(Batch, self).__init__(**kwargs) + + for key, item in kwargs.items(): + if key == "num_nodes": + self.__num_nodes__ = item + else: + self[key] = item + + self.batch = batch + self.ptr = ptr + self.__data_class__ = Data + self.__slices__ = None + self.__cumsum__ = None + self.__cat_dims__ = None + self.__num_nodes_list__ = None + self.__num_graphs__ = None + + @classmethod + def from_data_list(cls, data_list, follow_batch=[], exclude_keys=[]): + r"""Constructs a batch object from a python list holding + :class:`torch_geometric.data.Data` objects. + The assignment vector :obj:`batch` is created on the fly. + Additionally, creates assignment batch vectors for each key in + :obj:`follow_batch`. + Will exclude any keys given in :obj:`exclude_keys`.""" + + keys = list(set(data_list[0].keys) - set(exclude_keys)) + assert "batch" not in keys and "ptr" not in keys + + batch = cls() + for key in data_list[0].__dict__.keys(): + if key[:2] != "__" and key[-2:] != "__": + batch[key] = None + + batch.__num_graphs__ = len(data_list) + batch.__data_class__ = data_list[0].__class__ + for key in keys + ["batch"]: + batch[key] = [] + batch["ptr"] = [0] + + device = None + slices = {key: [0] for key in keys} + cumsum = {key: [0] for key in keys} + cat_dims = {} + num_nodes_list = [] + for i, data in enumerate(data_list): + for key in keys: + item = data[key] + + # Increase values by `cumsum` value. + cum = cumsum[key][-1] + if isinstance(item, Tensor) and item.dtype != torch.bool: + if not isinstance(cum, int) or cum != 0: + item = item + cum + elif isinstance(item, (int, float)): + item = item + cum + + # Gather the size of the `cat` dimension. + size = 1 + cat_dim = data.__cat_dim__(key, data[key]) + # 0-dimensional tensors have no dimension along which to + # concatenate, so we set `cat_dim` to `None`. + if isinstance(item, Tensor) and item.dim() == 0: + cat_dim = None + cat_dims[key] = cat_dim + + # Add a batch dimension to items whose `cat_dim` is `None`: + if isinstance(item, Tensor) and cat_dim is None: + cat_dim = 0 # Concatenate along this new batch dimension. 
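+                    # e.g. a 0-d per-graph scalar becomes shape (1,), so the
+                    # final torch.cat over the batch stacks it to (num_graphs,).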
+ item = item.unsqueeze(0) + device = item.device + elif isinstance(item, Tensor): + size = item.size(cat_dim) + device = item.device + + batch[key].append(item) # Append item to the attribute list. + + slices[key].append(size + slices[key][-1]) + inc = data.__inc__(key, item) + if isinstance(inc, (tuple, list)): + inc = torch.tensor(inc) + cumsum[key].append(inc + cumsum[key][-1]) + + if key in follow_batch: + if isinstance(size, Tensor): + for j, size in enumerate(size.tolist()): + tmp = f"{key}_{j}_batch" + batch[tmp] = [] if i == 0 else batch[tmp] + batch[tmp].append( + torch.full((size,), i, dtype=torch.long, device=device) + ) + else: + tmp = f"{key}_batch" + batch[tmp] = [] if i == 0 else batch[tmp] + batch[tmp].append( + torch.full((size,), i, dtype=torch.long, device=device) + ) + + if hasattr(data, "__num_nodes__"): + num_nodes_list.append(data.__num_nodes__) + else: + num_nodes_list.append(None) + + num_nodes = data.num_nodes + if num_nodes is not None: + item = torch.full((num_nodes,), i, dtype=torch.long, device=device) + batch.batch.append(item) + batch.ptr.append(batch.ptr[-1] + num_nodes) + + batch.batch = None if len(batch.batch) == 0 else batch.batch + batch.ptr = None if len(batch.ptr) == 1 else batch.ptr + batch.__slices__ = slices + batch.__cumsum__ = cumsum + batch.__cat_dims__ = cat_dims + batch.__num_nodes_list__ = num_nodes_list + + ref_data = data_list[0] + for key in batch.keys: + items = batch[key] + item = items[0] + cat_dim = ref_data.__cat_dim__(key, item) + cat_dim = 0 if cat_dim is None else cat_dim + if isinstance(item, Tensor): + batch[key] = torch.cat(items, cat_dim) + elif isinstance(item, (int, float)): + batch[key] = torch.tensor(items) + + # if torch_geometric.is_debug_enabled(): + # batch.debug() + + return batch.contiguous() + + def get_example(self, idx: int) -> Data: + r"""Reconstructs the :class:`torch_geometric.data.Data` object at index + :obj:`idx` from the batch object. + The batch object must have been created via :meth:`from_data_list` in + order to be able to reconstruct the initial objects.""" + + if self.__slices__ is None: + raise RuntimeError( + ( + "Cannot reconstruct data list from batch because the batch " + "object was not created using `Batch.from_data_list()`." + ) + ) + + data = self.__data_class__() + idx = self.num_graphs + idx if idx < 0 else idx + + for key in self.__slices__.keys(): + item = self[key] + if self.__cat_dims__[key] is None: + # The item was concatenated along a new batch dimension, + # so just index in that dimension: + item = item[idx] + else: + # Narrow the item based on the values in `__slices__`. 
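+                # `__slices__[key]` holds cumulative sizes, so the half-open range
+                # [__slices__[key][idx], __slices__[key][idx + 1]) is this graph's span.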
+                if isinstance(item, Tensor):
+                    dim = self.__cat_dims__[key]
+                    start = self.__slices__[key][idx]
+                    end = self.__slices__[key][idx + 1]
+                    item = item.narrow(dim, start, end - start)
+                else:
+                    start = self.__slices__[key][idx]
+                    end = self.__slices__[key][idx + 1]
+                    item = item[start:end]
+                    item = item[0] if len(item) == 1 else item
+
+            # Decrease its value by `cumsum` value:
+            cum = self.__cumsum__[key][idx]
+            if isinstance(item, Tensor):
+                if not isinstance(cum, int) or cum != 0:
+                    item = item - cum
+            elif isinstance(item, (int, float)):
+                item = item - cum
+
+            data[key] = item
+
+        if self.__num_nodes_list__[idx] is not None:
+            data.num_nodes = self.__num_nodes_list__[idx]
+
+        return data
+
+    def index_select(self, idx: IndexType) -> List[Data]:
+        if isinstance(idx, slice):
+            idx = list(range(self.num_graphs)[idx])
+
+        elif isinstance(idx, Tensor) and idx.dtype == torch.long:
+            idx = idx.flatten().tolist()
+
+        elif isinstance(idx, Tensor) and idx.dtype == torch.bool:
+            idx = idx.flatten().nonzero(as_tuple=False).flatten().tolist()
+
+        elif isinstance(idx, np.ndarray) and idx.dtype == np.int64:
+            idx = idx.flatten().tolist()
+
+        elif isinstance(idx, np.ndarray) and idx.dtype == np.bool_:
+            idx = idx.flatten().nonzero()[0].flatten().tolist()
+
+        elif isinstance(idx, Sequence) and not isinstance(idx, str):
+            pass
+
+        else:
+            raise IndexError(
+                f"Only integers, slices (':'), lists, tuples, torch.tensor and "
+                f"np.ndarray of dtype long or bool are valid indices (got "
+                f"'{type(idx).__name__}')"
+            )
+
+        return [self.get_example(i) for i in idx]
+
+    def __getitem__(self, idx):
+        if isinstance(idx, str):
+            return super(Batch, self).__getitem__(idx)
+        elif isinstance(idx, (int, np.integer)):
+            return self.get_example(idx)
+        else:
+            return self.index_select(idx)
+
+    def to_data_list(self) -> List[Data]:
+        r"""Reconstructs the list of :class:`torch_geometric.data.Data` objects
+        from the batch object.
+        The batch object must have been created via :meth:`from_data_list` in
+        order to be able to reconstruct the initial objects."""
+        return [self.get_example(i) for i in range(self.num_graphs)]
+
+    @property
+    def num_graphs(self) -> int:
+        """Returns the number of graphs in the batch."""
+        if self.__num_graphs__ is not None:
+            return self.__num_graphs__
+        elif self.ptr is not None:
+            return self.ptr.numel() - 1
+        elif self.batch is not None:
+            return int(self.batch.max()) + 1
+        else:
+            raise ValueError
diff --git a/hydragnn/utils/mace_utils/tools/torch_geometric/data.py b/hydragnn/utils/mace_utils/tools/torch_geometric/data.py
new file mode 100644
index 000000000..4e1ab3084
--- /dev/null
+++ b/hydragnn/utils/mace_utils/tools/torch_geometric/data.py
@@ -0,0 +1,441 @@
+import collections
+import copy
+import re
+
+import torch
+
+# from ..utils.num_nodes import maybe_num_nodes
+
+__num_nodes_warn_msg__ = (
+    "The number of nodes in your data object can only be inferred by its {} "
+    "indices, and hence may result in unexpected batch-wise behavior, e.g., "
+    "in case there exist isolated nodes. Please consider explicitly setting "
+    "the number of nodes for this data object by assigning it to "
+    "data.num_nodes."
+) + + +def size_repr(key, item, indent=0): + indent_str = " " * indent + if torch.is_tensor(item) and item.dim() == 0: + out = item.item() + elif torch.is_tensor(item): + out = str(list(item.size())) + elif isinstance(item, list) or isinstance(item, tuple): + out = str([len(item)]) + elif isinstance(item, dict): + lines = [indent_str + size_repr(k, v, 2) for k, v in item.items()] + out = "{\n" + ",\n".join(lines) + "\n" + indent_str + "}" + elif isinstance(item, str): + out = f'"{item}"' + else: + out = str(item) + + return f"{indent_str}{key}={out}" + + +class Data(object): + r"""A plain old python object modeling a single graph with various + (optional) attributes: + + Args: + x (Tensor, optional): Node feature matrix with shape :obj:`[num_nodes, + num_node_features]`. (default: :obj:`None`) + edge_index (LongTensor, optional): Graph connectivity in COO format + with shape :obj:`[2, num_edges]`. (default: :obj:`None`) + edge_attr (Tensor, optional): Edge feature matrix with shape + :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`) + y (Tensor, optional): Graph or node targets with arbitrary shape. + (default: :obj:`None`) + pos (Tensor, optional): Node position matrix with shape + :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`) + normal (Tensor, optional): Normal vector matrix with shape + :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`) + face (LongTensor, optional): Face adjacency matrix with shape + :obj:`[3, num_faces]`. (default: :obj:`None`) + + The data object is not restricted to these attributes and can be extended + by any other additional data. + + Example:: + + data = Data(x=x, edge_index=edge_index) + data.train_idx = torch.tensor([...], dtype=torch.long) + data.test_mask = torch.tensor([...], dtype=torch.bool) + """ + + def __init__( + self, + x=None, + edge_index=None, + edge_attr=None, + y=None, + pos=None, + normal=None, + face=None, + **kwargs, + ): + self.x = x + self.edge_index = edge_index + self.edge_attr = edge_attr + self.y = y + self.pos = pos + self.normal = normal + self.face = face + for key, item in kwargs.items(): + if key == "num_nodes": + self.__num_nodes__ = item + else: + self[key] = item + + if edge_index is not None and edge_index.dtype != torch.long: + raise ValueError( + ( + f"Argument `edge_index` needs to be of type `torch.long` but " + f"found type `{edge_index.dtype}`." + ) + ) + + if face is not None and face.dtype != torch.long: + raise ValueError( + ( + f"Argument `face` needs to be of type `torch.long` but found " + f"type `{face.dtype}`." 
+                )
+            )
+
+    @classmethod
+    def from_dict(cls, dictionary):
+        r"""Creates a data object from a python dictionary."""
+        data = cls()
+
+        for key, item in dictionary.items():
+            data[key] = item
+
+        return data
+
+    def to_dict(self):
+        return {key: item for key, item in self}
+
+    def to_namedtuple(self):
+        keys = self.keys
+        DataTuple = collections.namedtuple("DataTuple", keys)
+        return DataTuple(*[self[key] for key in keys])
+
+    def __getitem__(self, key):
+        r"""Gets the data of the attribute :obj:`key`."""
+        return getattr(self, key, None)
+
+    def __setitem__(self, key, value):
+        """Sets the attribute :obj:`key` to :obj:`value`."""
+        setattr(self, key, value)
+
+    def __delitem__(self, key):
+        r"""Delete the data of the attribute :obj:`key`."""
+        return delattr(self, key)
+
+    @property
+    def keys(self):
+        r"""Returns all names of graph attributes."""
+        keys = [key for key in self.__dict__.keys() if self[key] is not None]
+        keys = [key for key in keys if key[:2] != "__" and key[-2:] != "__"]
+        return keys
+
+    def __len__(self):
+        r"""Returns the number of all present attributes."""
+        return len(self.keys)
+
+    def __contains__(self, key):
+        r"""Returns :obj:`True`, if the attribute :obj:`key` is present in the
+        data."""
+        return key in self.keys
+
+    def __iter__(self):
+        r"""Iterates over all present attributes in the data, yielding their
+        attribute names and content."""
+        for key in sorted(self.keys):
+            yield key, self[key]
+
+    def __call__(self, *keys):
+        r"""Iterates over all attributes :obj:`*keys` in the data, yielding
+        their attribute names and content.
+        If :obj:`*keys` is not given this method will iterate over all
+        present attributes."""
+        for key in sorted(self.keys) if not keys else keys:
+            if key in self:
+                yield key, self[key]
+
+    def __cat_dim__(self, key, value):
+        r"""Returns the dimension for which :obj:`value` of attribute
+        :obj:`key` will get concatenated when creating batches.
+
+        .. note::
+
+            This method is for internal use only, and should only be overridden
+            if the batch concatenation process is corrupted for a specific data
+            attribute.
+        """
+        if bool(re.search("(index|face)", key)):
+            return -1
+        return 0
+
+    def __inc__(self, key, value):
+        r"""Returns the incremental count to cumulatively increase the value
+        of the next attribute of :obj:`key` when creating batches.
+
+        .. note::
+
+            This method is for internal use only, and should only be overridden
+            if the batch concatenation process is corrupted for a specific data
+            attribute.
+        """
+        # Only `*index*` and `*face*` attributes should be cumulatively summed
+        # up when creating batches.
+        return self.num_nodes if bool(re.search("(index|face)", key)) else 0
+
+    @property
+    def num_nodes(self):
+        r"""Returns or sets the number of nodes in the graph.
+
+        .. note::
+            The number of nodes in your data object is typically automatically
+            inferred, *e.g.*, when node features :obj:`x` are present.
+            In some cases however, a graph may only be given by its edge
+            indices :obj:`edge_index`.
+            PyTorch Geometric then *guesses* the number of nodes
+            according to :obj:`edge_index.max().item() + 1`, but in case there
+            exist isolated nodes, this number may not be correct and can
+            therefore result in unexpected batch-wise behavior.
+            Thus, we recommend setting the number of nodes in your data object
+            explicitly via :obj:`data.num_nodes = ...`.
+            You will be given a warning that requests you to do so.
+ """ + if hasattr(self, "__num_nodes__"): + return self.__num_nodes__ + for key, item in self("x", "pos", "normal", "batch"): + return item.size(self.__cat_dim__(key, item)) + if hasattr(self, "adj"): + return self.adj.size(0) + if hasattr(self, "adj_t"): + return self.adj_t.size(1) + # if self.face is not None: + # logging.warning(__num_nodes_warn_msg__.format("face")) + # return maybe_num_nodes(self.face) + # if self.edge_index is not None: + # logging.warning(__num_nodes_warn_msg__.format("edge")) + # return maybe_num_nodes(self.edge_index) + return None + + @num_nodes.setter + def num_nodes(self, num_nodes): + self.__num_nodes__ = num_nodes + + @property + def num_edges(self): + """ + Returns the number of edges in the graph. + For undirected graphs, this will return the number of bi-directional + edges, which is double the amount of unique edges. + """ + for key, item in self("edge_index", "edge_attr"): + return item.size(self.__cat_dim__(key, item)) + for key, item in self("adj", "adj_t"): + return item.nnz() + return None + + @property + def num_faces(self): + r"""Returns the number of faces in the mesh.""" + if self.face is not None: + return self.face.size(self.__cat_dim__("face", self.face)) + return None + + @property + def num_node_features(self): + r"""Returns the number of features per node in the graph.""" + if self.x is None: + return 0 + return 1 if self.x.dim() == 1 else self.x.size(1) + + @property + def num_features(self): + r"""Alias for :py:attr:`~num_node_features`.""" + return self.num_node_features + + @property + def num_edge_features(self): + r"""Returns the number of features per edge in the graph.""" + if self.edge_attr is None: + return 0 + return 1 if self.edge_attr.dim() == 1 else self.edge_attr.size(1) + + def __apply__(self, item, func): + if torch.is_tensor(item): + return func(item) + elif isinstance(item, (tuple, list)): + return [self.__apply__(v, func) for v in item] + elif isinstance(item, dict): + return {k: self.__apply__(v, func) for k, v in item.items()} + else: + return item + + def apply(self, func, *keys): + r"""Applies the function :obj:`func` to all tensor attributes + :obj:`*keys`. If :obj:`*keys` is not given, :obj:`func` is applied to + all present attributes. + """ + for key, item in self(*keys): + self[key] = self.__apply__(item, func) + return self + + def contiguous(self, *keys): + r"""Ensures a contiguous memory layout for all attributes :obj:`*keys`. + If :obj:`*keys` is not given, all present attributes are ensured to + have a contiguous memory layout.""" + return self.apply(lambda x: x.contiguous(), *keys) + + def to(self, device, *keys, **kwargs): + r"""Performs tensor dtype and/or device conversion to all attributes + :obj:`*keys`. + If :obj:`*keys` is not given, the conversion is applied to all present + attributes.""" + return self.apply(lambda x: x.to(device, **kwargs), *keys) + + def cpu(self, *keys): + r"""Copies all attributes :obj:`*keys` to CPU memory. + If :obj:`*keys` is not given, the conversion is applied to all present + attributes.""" + return self.apply(lambda x: x.cpu(), *keys) + + def cuda(self, device=None, non_blocking=False, *keys): + r"""Copies all attributes :obj:`*keys` to CUDA memory. 
+ If :obj:`*keys` is not given, the conversion is applied to all present + attributes.""" + return self.apply( + lambda x: x.cuda(device=device, non_blocking=non_blocking), *keys + ) + + def clone(self): + r"""Performs a deep-copy of the data object.""" + return self.__class__.from_dict( + { + k: v.clone() if torch.is_tensor(v) else copy.deepcopy(v) + for k, v in self.__dict__.items() + } + ) + + def pin_memory(self, *keys): + r"""Copies all attributes :obj:`*keys` to pinned memory. + If :obj:`*keys` is not given, the conversion is applied to all present + attributes.""" + return self.apply(lambda x: x.pin_memory(), *keys) + + def debug(self): + if self.edge_index is not None: + if self.edge_index.dtype != torch.long: + raise RuntimeError( + ( + "Expected edge indices of dtype {}, but found dtype " " {}" + ).format(torch.long, self.edge_index.dtype) + ) + + if self.face is not None: + if self.face.dtype != torch.long: + raise RuntimeError( + ( + "Expected face indices of dtype {}, but found dtype " " {}" + ).format(torch.long, self.face.dtype) + ) + + if self.edge_index is not None: + if self.edge_index.dim() != 2 or self.edge_index.size(0) != 2: + raise RuntimeError( + ( + "Edge indices should have shape [2, num_edges] but found" + " shape {}" + ).format(self.edge_index.size()) + ) + + if self.edge_index is not None and self.num_nodes is not None: + if self.edge_index.numel() > 0: + min_index = self.edge_index.min() + max_index = self.edge_index.max() + else: + min_index = max_index = 0 + if min_index < 0 or max_index > self.num_nodes - 1: + raise RuntimeError( + ( + "Edge indices must lay in the interval [0, {}]" + " but found them in the interval [{}, {}]" + ).format(self.num_nodes - 1, min_index, max_index) + ) + + if self.face is not None: + if self.face.dim() != 2 or self.face.size(0) != 3: + raise RuntimeError( + ( + "Face indices should have shape [3, num_faces] but found" + " shape {}" + ).format(self.face.size()) + ) + + if self.face is not None and self.num_nodes is not None: + if self.face.numel() > 0: + min_index = self.face.min() + max_index = self.face.max() + else: + min_index = max_index = 0 + if min_index < 0 or max_index > self.num_nodes - 1: + raise RuntimeError( + ( + "Face indices must lay in the interval [0, {}]" + " but found them in the interval [{}, {}]" + ).format(self.num_nodes - 1, min_index, max_index) + ) + + if self.edge_index is not None and self.edge_attr is not None: + if self.edge_index.size(1) != self.edge_attr.size(0): + raise RuntimeError( + ( + "Edge indices and edge attributes hold a differing " + "number of edges, found {} and {}" + ).format(self.edge_index.size(), self.edge_attr.size()) + ) + + if self.x is not None and self.num_nodes is not None: + if self.x.size(0) != self.num_nodes: + raise RuntimeError( + ( + "Node features should hold {} elements in the first " + "dimension but found {}" + ).format(self.num_nodes, self.x.size(0)) + ) + + if self.pos is not None and self.num_nodes is not None: + if self.pos.size(0) != self.num_nodes: + raise RuntimeError( + ( + "Node positions should hold {} elements in the first " + "dimension but found {}" + ).format(self.num_nodes, self.pos.size(0)) + ) + + if self.normal is not None and self.num_nodes is not None: + if self.normal.size(0) != self.num_nodes: + raise RuntimeError( + ( + "Node normals should hold {} elements in the first " + "dimension but found {}" + ).format(self.num_nodes, self.normal.size(0)) + ) + + def __repr__(self): + cls = str(self.__class__.__name__) + has_dict = 
any([isinstance(item, dict) for _, item in self]) + + if not has_dict: + info = [size_repr(key, item) for key, item in self] + return "{}({})".format(cls, ", ".join(info)) + else: + info = [size_repr(key, item, indent=2) for key, item in self] + return "{}(\n{}\n)".format(cls, ",\n".join(info)) diff --git a/hydragnn/utils/mace_utils/tools/torch_geometric/dataloader.py b/hydragnn/utils/mace_utils/tools/torch_geometric/dataloader.py new file mode 100644 index 000000000..396b7e728 --- /dev/null +++ b/hydragnn/utils/mace_utils/tools/torch_geometric/dataloader.py @@ -0,0 +1,87 @@ +from collections.abc import Mapping, Sequence +from typing import List, Optional, Union + +import torch.utils.data +from torch.utils.data.dataloader import default_collate + +from .batch import Batch +from .data import Data +from .dataset import Dataset + + +class Collater: + def __init__(self, follow_batch, exclude_keys): + self.follow_batch = follow_batch + self.exclude_keys = exclude_keys + + def __call__(self, batch): + elem = batch[0] + if isinstance(elem, Data): + return Batch.from_data_list( + batch, + follow_batch=self.follow_batch, + exclude_keys=self.exclude_keys, + ) + elif isinstance(elem, torch.Tensor): + return default_collate(batch) + elif isinstance(elem, float): + return torch.tensor(batch, dtype=torch.float) + elif isinstance(elem, int): + return torch.tensor(batch) + elif isinstance(elem, str): + return batch + elif isinstance(elem, Mapping): + return {key: self([data[key] for data in batch]) for key in elem} + elif isinstance(elem, tuple) and hasattr(elem, "_fields"): + return type(elem)(*(self(s) for s in zip(*batch))) + elif isinstance(elem, Sequence) and not isinstance(elem, str): + return [self(s) for s in zip(*batch)] + + raise TypeError(f"DataLoader found invalid type: {type(elem)}") + + def collate(self, batch): # Deprecated... + return self(batch) + + +class DataLoader(torch.utils.data.DataLoader): + r"""A data loader which merges data objects from a + :class:`torch_geometric.data.Dataset` to a mini-batch. + Data objects can be either of type :class:`~torch_geometric.data.Data` or + :class:`~torch_geometric.data.HeteroData`. + Args: + dataset (Dataset): The dataset from which to load the data. + batch_size (int, optional): How many samples per batch to load. + (default: :obj:`1`) + shuffle (bool, optional): If set to :obj:`True`, the data will be + reshuffled at every epoch. (default: :obj:`False`) + follow_batch (List[str], optional): Creates assignment batch + vectors for each key in the list. (default: :obj:`None`) + exclude_keys (List[str], optional): Will exclude each key in the + list. (default: :obj:`None`) + **kwargs (optional): Additional arguments of + :class:`torch.utils.data.DataLoader`. 
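+
+    Example::
+
+        # Illustrative usage; `dataset` is assumed to be any Dataset of
+        # `Data` objects (the name is hypothetical, not part of this module).
+        loader = DataLoader(dataset, batch_size=32, shuffle=True)
+        for batch in loader:
+            ...  # each `batch` is a `Batch` merging several `Data` graphs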
+ """ + + def __init__( + self, + dataset: Dataset, + batch_size: int = 1, + shuffle: bool = False, + follow_batch: Optional[List[str]] = [None], + exclude_keys: Optional[List[str]] = [None], + **kwargs, + ): + if "collate_fn" in kwargs: + del kwargs["collate_fn"] + + # Save for PyTorch Lightning < 1.6: + self.follow_batch = follow_batch + self.exclude_keys = exclude_keys + + super().__init__( + dataset, + batch_size, + shuffle, + collate_fn=Collater(follow_batch, exclude_keys), + **kwargs, + ) diff --git a/hydragnn/utils/mace_utils/tools/torch_geometric/dataset.py b/hydragnn/utils/mace_utils/tools/torch_geometric/dataset.py new file mode 100644 index 000000000..b4aeb2be9 --- /dev/null +++ b/hydragnn/utils/mace_utils/tools/torch_geometric/dataset.py @@ -0,0 +1,280 @@ +import copy +import os.path as osp +import re +import warnings +from collections.abc import Sequence +from typing import Any, Callable, List, Optional, Tuple, Union + +import numpy as np +import torch.utils.data +from torch import Tensor + +from .data import Data +from .utils import makedirs + +IndexType = Union[slice, Tensor, np.ndarray, Sequence] + + +class Dataset(torch.utils.data.Dataset): + r"""Dataset base class for creating graph datasets. + See `here `__ for the accompanying tutorial. + + Args: + root (string, optional): Root directory where the dataset should be + saved. (optional: :obj:`None`) + transform (callable, optional): A function/transform that takes in an + :obj:`torch_geometric.data.Data` object and returns a transformed + version. The data object will be transformed before every access. + (default: :obj:`None`) + pre_transform (callable, optional): A function/transform that takes in + an :obj:`torch_geometric.data.Data` object and returns a + transformed version. The data object will be transformed before + being saved to disk. (default: :obj:`None`) + pre_filter (callable, optional): A function that takes in an + :obj:`torch_geometric.data.Data` object and returns a boolean + value, indicating whether the data object should be included in the + final dataset. 
(default: :obj:`None`) + """ + + @property + def raw_file_names(self) -> Union[str, List[str], Tuple]: + r"""The name of the files to find in the :obj:`self.raw_dir` folder in + order to skip the download.""" + raise NotImplementedError + + @property + def processed_file_names(self) -> Union[str, List[str], Tuple]: + r"""The name of the files to find in the :obj:`self.processed_dir` + folder in order to skip the processing.""" + raise NotImplementedError + + def download(self): + r"""Downloads the dataset to the :obj:`self.raw_dir` folder.""" + raise NotImplementedError + + def process(self): + r"""Processes the dataset to the :obj:`self.processed_dir` folder.""" + raise NotImplementedError + + def len(self) -> int: + raise NotImplementedError + + def get(self, idx: int) -> Data: + r"""Gets the data object at index :obj:`idx`.""" + raise NotImplementedError + + def __init__( + self, + root: Optional[str] = None, + transform: Optional[Callable] = None, + pre_transform: Optional[Callable] = None, + pre_filter: Optional[Callable] = None, + ): + super().__init__() + + if isinstance(root, str): + root = osp.expanduser(osp.normpath(root)) + + self.root = root + self.transform = transform + self.pre_transform = pre_transform + self.pre_filter = pre_filter + self._indices: Optional[Sequence] = None + + if "download" in self.__class__.__dict__.keys(): + self._download() + + if "process" in self.__class__.__dict__.keys(): + self._process() + + def indices(self) -> Sequence: + return range(self.len()) if self._indices is None else self._indices + + @property + def raw_dir(self) -> str: + return osp.join(self.root, "raw") + + @property + def processed_dir(self) -> str: + return osp.join(self.root, "processed") + + @property + def num_node_features(self) -> int: + r"""Returns the number of features per node in the dataset.""" + data = self[0] + if hasattr(data, "num_node_features"): + return data.num_node_features + raise AttributeError( + f"'{data.__class__.__name__}' object has no " + f"attribute 'num_node_features'" + ) + + @property + def num_features(self) -> int: + r"""Alias for :py:attr:`~num_node_features`.""" + return self.num_node_features + + @property + def num_edge_features(self) -> int: + r"""Returns the number of features per edge in the dataset.""" + data = self[0] + if hasattr(data, "num_edge_features"): + return data.num_edge_features + raise AttributeError( + f"'{data.__class__.__name__}' object has no " + f"attribute 'num_edge_features'" + ) + + @property + def raw_paths(self) -> List[str]: + r"""The filepaths to find in order to skip the download.""" + files = to_list(self.raw_file_names) + return [osp.join(self.raw_dir, f) for f in files] + + @property + def processed_paths(self) -> List[str]: + r"""The filepaths to find in the :obj:`self.processed_dir` + folder in order to skip the processing.""" + files = to_list(self.processed_file_names) + return [osp.join(self.processed_dir, f) for f in files] + + def _download(self): + if files_exist(self.raw_paths): # pragma: no cover + return + + makedirs(self.raw_dir) + self.download() + + def _process(self): + f = osp.join(self.processed_dir, "pre_transform.pt") + if osp.exists(f) and torch.load(f) != _repr(self.pre_transform): + warnings.warn( + f"The `pre_transform` argument differs from the one used in " + f"the pre-processed version of this dataset. 
If you want to "
                f"make use of another pre-processing technique, make sure to "
                f"delete '{self.processed_dir}' first"
+            )
+
+        f = osp.join(self.processed_dir, "pre_filter.pt")
+        if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):
+            warnings.warn(
+                "The `pre_filter` argument differs from the one used in the "
+                "pre-processed version of this dataset. If you want to make "
+                "use of another pre-filtering technique, make sure to delete "
+                f"'{self.processed_dir}' first"
+            )
+
+        if files_exist(self.processed_paths):  # pragma: no cover
+            return
+
+        print("Processing...")
+
+        makedirs(self.processed_dir)
+        self.process()
+
+        path = osp.join(self.processed_dir, "pre_transform.pt")
+        torch.save(_repr(self.pre_transform), path)
+        path = osp.join(self.processed_dir, "pre_filter.pt")
+        torch.save(_repr(self.pre_filter), path)
+
+        print("Done!")
+
+    def __len__(self) -> int:
+        r"""The number of examples in the dataset."""
+        return len(self.indices())
+
+    def __getitem__(
+        self,
+        idx: Union[int, np.integer, IndexType],
+    ) -> Union["Dataset", Data]:
+        r"""In case :obj:`idx` is of type integer, will return the data object
+        at index :obj:`idx` (and transforms it in case :obj:`transform` is
+        present).
+        In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a
+        tuple, a PyTorch :obj:`LongTensor` or a :obj:`BoolTensor`, or a numpy
+        :obj:`np.array`, will return a subset of the dataset at the specified
+        indices."""
+        if (
+            isinstance(idx, (int, np.integer))
+            or (isinstance(idx, Tensor) and idx.dim() == 0)
+            or (isinstance(idx, np.ndarray) and np.isscalar(idx))
+        ):
+            data = self.get(self.indices()[idx])
+            data = data if self.transform is None else self.transform(data)
+            return data
+
+        else:
+            return self.index_select(idx)
+
+    def index_select(self, idx: IndexType) -> "Dataset":
+        indices = self.indices()
+
+        if isinstance(idx, slice):
+            indices = indices[idx]
+
+        elif isinstance(idx, Tensor) and idx.dtype == torch.long:
+            return self.index_select(idx.flatten().tolist())
+
+        elif isinstance(idx, Tensor) and idx.dtype == torch.bool:
+            idx = idx.flatten().nonzero(as_tuple=False)
+            return self.index_select(idx.flatten().tolist())
+
+        elif isinstance(idx, np.ndarray) and idx.dtype == np.int64:
+            return self.index_select(idx.flatten().tolist())
+
+        elif isinstance(idx, np.ndarray) and idx.dtype == np.bool_:
+            idx = idx.flatten().nonzero()[0]
+            return self.index_select(idx.flatten().tolist())
+
+        elif isinstance(idx, Sequence) and not isinstance(idx, str):
+            indices = [indices[i] for i in idx]
+
+        else:
+            raise IndexError(
+                f"Only integers, slices (':'), lists, tuples, torch.tensor and "
+                f"np.ndarray of dtype long or bool are valid indices (got "
+                f"'{type(idx).__name__}')"
+            )
+
+        dataset = copy.copy(self)
+        dataset._indices = indices
+        return dataset
+
+    def shuffle(
+        self,
+        return_perm: bool = False,
+    ) -> Union["Dataset", Tuple["Dataset", Tensor]]:
+        r"""Randomly shuffles the examples in the dataset.
+
+        Args:
+            return_perm (bool, optional): If set to :obj:`True`, will return
+                the random permutation used to shuffle the dataset in addition. (default: :obj:`False`)
+        """
+        perm = torch.randperm(len(self))
+        dataset = self.index_select(perm)
+        return (dataset, perm) if return_perm is True else dataset
+
+    def __repr__(self) -> str:
+        arg_repr = str(len(self)) if len(self) > 1 else ""
+        return f"{self.__class__.__name__}({arg_repr})"
+
+
+def to_list(value: Any) -> Sequence:
+    if isinstance(value, Sequence) and not isinstance(value, str):
+        return value
+    else:
+        return [value]
+
+
+def files_exist(files: List[str]) -> bool:
+    # NOTE: We return `False` in case `files` is empty, leading to a
+    # re-processing of files on every instantiation.
+    return len(files) != 0 and all([osp.exists(f) for f in files])
+
+
+def _repr(obj: Any) -> str:
+    if obj is None:
+        return "None"
+    return re.sub("(<.*?)\\s.*(>)", r"\1\2", obj.__repr__())
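+
+
+# Illustrative sketch (an assumption, not part of the vendored API): a minimal
+# subclass only needs to implement `len()` and `get()`; `download()` and
+# `process()` are optional hooks that run only when a subclass defines them.
+# All names below are hypothetical.
+#
+#     class MyListDataset(Dataset):
+#         def __init__(self, data_list, root=None, transform=None):
+#             super().__init__(root, transform)
+#             self.data_list = data_list
+#
+#         def len(self):
+#             return len(self.data_list)
+#
+#         def get(self, idx):
+#             return self.data_list[idx]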
diff --git a/hydragnn/utils/mace_utils/tools/torch_geometric/seed.py b/hydragnn/utils/mace_utils/tools/torch_geometric/seed.py
new file mode 100644
index 000000000..be27fcaa1
--- /dev/null
+++ b/hydragnn/utils/mace_utils/tools/torch_geometric/seed.py
@@ -0,0 +1,17 @@
+import random
+
+import numpy as np
+import torch
+
+
+def seed_everything(seed: int):
+    r"""Sets the seed for generating random numbers in :pytorch:`PyTorch`,
+    :obj:`numpy` and Python.
+
+    Args:
+        seed (int): The desired seed.
+    """
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
diff --git a/hydragnn/utils/mace_utils/tools/torch_geometric/utils.py b/hydragnn/utils/mace_utils/tools/torch_geometric/utils.py
new file mode 100644
index 000000000..f53b8f809
--- /dev/null
+++ b/hydragnn/utils/mace_utils/tools/torch_geometric/utils.py
@@ -0,0 +1,54 @@
+import os
+import os.path as osp
+import ssl
+import urllib
+import zipfile
+
+
+def makedirs(dir):
+    os.makedirs(dir, exist_ok=True)
+
+
+def download_url(url, folder, log=True):
+    r"""Downloads the content of a URL to a specific folder.
+
+    Args:
+        url (string): The url.
+        folder (string): The folder.
+        log (bool, optional): If :obj:`False`, will not print anything to the
+            console. (default: :obj:`True`)
+    """
+
+    filename = url.rpartition("/")[2].split("?")[0]
+    path = osp.join(folder, filename)
+
+    if osp.exists(path):  # pragma: no cover
+        if log:
+            print("Using existing file", filename)
+        return path
+
+    if log:
+        print("Downloading", url)
+
+    makedirs(folder)
+
+    context = ssl._create_unverified_context()
+    data = urllib.request.urlopen(url, context=context)
+
+    with open(path, "wb") as f:
+        f.write(data.read())
+
+    return path
+
+
+def extract_zip(path, folder, log=True):
+    r"""Extracts a zip archive to a specific folder.
+
+    Args:
+        path (string): The path to the zip archive.
+        folder (string): The folder.
+        log (bool, optional): If :obj:`False`, will not print anything to the
+            console. (default: :obj:`True`)
+    """
+    with zipfile.ZipFile(path, "r") as f:
+        f.extractall(folder)
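+
+
+# Illustrative usage (the URL is a placeholder, not a real endpoint):
+#
+#     path = download_url("https://example.com/dataset.zip", "raw")
+#     extract_zip(path, "raw")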
diff --git a/hydragnn/utils/mace_utils/tools/torch_tools.py b/hydragnn/utils/mace_utils/tools/torch_tools.py
new file mode 100644
index 000000000..1ec3ecde7
--- /dev/null
+++ b/hydragnn/utils/mace_utils/tools/torch_tools.py
@@ -0,0 +1,138 @@
+###########################################################################################
+# Tools for torch
+# Authors: Ilyes Batatia, Gregor Simm
+# This program is distributed under the MIT License (see MIT.md)
+###########################################################################################
+
+import logging
+from contextlib import contextmanager
+from typing import Dict
+
+import numpy as np
+import torch
+from e3nn.io import CartesianTensor
+
+TensorDict = Dict[str, torch.Tensor]
+
+
+def to_one_hot(indices: torch.Tensor, num_classes: int) -> torch.Tensor:
+    """
+    Generates a one-hot encoding with `num_classes` classes from `indices`
+    :param indices: (N x 1) tensor of class indices
+    :param num_classes: number of classes
+    :return: (N x num_classes) tensor (on the same device as `indices`)
+    """
+    shape = indices.shape[:-1] + (num_classes,)
+    oh = torch.zeros(shape, device=indices.device).view(shape)
+
+    # scatter_ is the in-place version of scatter
+    oh.scatter_(dim=-1, index=indices, value=1)
+
+    return oh.view(*shape)
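+
+
+# Worked example (illustrative): index tensors of shape (N, 1) map to one-hot
+# rows of shape (N, num_classes).
+#
+#     >>> to_one_hot(torch.tensor([[0], [2]]), num_classes=3)
+#     tensor([[1., 0., 0.],
+#             [0., 0., 1.]])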
+ logging.info("Using MPS GPU acceleration") + return torch.device("mps") + + logging.info("Using CPU") + return torch.device("cpu") + + +dtype_dict = {"float32": torch.float32, "float64": torch.float64} + + +def set_default_dtype(dtype: str) -> None: + torch.set_default_dtype(dtype_dict[dtype]) + + +def spherical_to_cartesian(t: torch.Tensor): + """ + Convert spherical notation to cartesian notation + """ + stress_cart_tensor = CartesianTensor("ij=ji") + stress_rtp = stress_cart_tensor.reduced_tensor_products() + return stress_cart_tensor.to_cartesian(t, rtp=stress_rtp) + + +def cartesian_to_spherical(t: torch.Tensor): + """ + Convert cartesian notation to spherical notation + """ + stress_cart_tensor = CartesianTensor("ij=ji") + stress_rtp = stress_cart_tensor.reduced_tensor_products() + return stress_cart_tensor.to_cartesian(t, rtp=stress_rtp) + + +def voigt_to_matrix(t: torch.Tensor): + """ + Convert voigt notation to matrix notation + :param t: (6,) tensor or (3, 3) tensor or (9,) tensor + :return: (3, 3) tensor + """ + if t.shape == (3, 3): + return t + if t.shape == (6,): + return torch.tensor( + [ + [t[0], t[5], t[4]], + [t[5], t[1], t[3]], + [t[4], t[3], t[2]], + ], + dtype=t.dtype, + ) + if t.shape == (9,): + return t.view(3, 3) + + raise ValueError( + f"Stress tensor must be of shape (6,) or (3, 3), or (9,) but has shape {t.shape}" + ) + + +def init_wandb(project: str, entity: str, name: str, config: dict, directory: str): + import wandb + + wandb.init(project=project, entity=entity, name=name, config=config, dir=directory) + + +@contextmanager +def default_dtype(dtype: torch.dtype): + """Context manager for configuring the default_dtype used by torch + + Args: + dtype (torch.dtype): the default dtype to use within this context manager + """ + init = torch.get_default_dtype() + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(init) diff --git a/hydragnn/utils/mace_utils/tools/train.py b/hydragnn/utils/mace_utils/tools/train.py new file mode 100644 index 000000000..b38bce167 --- /dev/null +++ b/hydragnn/utils/mace_utils/tools/train.py @@ -0,0 +1,524 @@ +########################################################################################### +# Training script +# Authors: Ilyes Batatia, Gregor Simm, David Kovacs +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import dataclasses +import logging +import time +from contextlib import nullcontext +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.distributed +from torch.nn.parallel import DistributedDataParallel +from torch.optim.swa_utils import SWALR, AveragedModel +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from torch_ema import ExponentialMovingAverage +from torchmetrics import Metric + +from . 
+
+
+def init_wandb(project: str, entity: str, name: str, config: dict, directory: str):
+    import wandb
+
+    wandb.init(project=project, entity=entity, name=name, config=config, dir=directory)
+
+
+@contextmanager
+def default_dtype(dtype: torch.dtype):
+    """Context manager for configuring the default_dtype used by torch
+
+    Args:
+        dtype (torch.dtype): the default dtype to use within this context manager
+    """
+    init = torch.get_default_dtype()
+    torch.set_default_dtype(dtype)
+    yield
+    torch.set_default_dtype(init)
diff --git a/hydragnn/utils/mace_utils/tools/train.py b/hydragnn/utils/mace_utils/tools/train.py
new file mode 100644
index 000000000..b38bce167
--- /dev/null
+++ b/hydragnn/utils/mace_utils/tools/train.py
@@ -0,0 +1,524 @@
+###########################################################################################
+# Training script
+# Authors: Ilyes Batatia, Gregor Simm, David Kovacs
+# This program is distributed under the MIT License (see MIT.md)
+###########################################################################################
+
+import dataclasses
+import logging
+import time
+from contextlib import nullcontext
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.distributed
+from torch.nn.parallel import DistributedDataParallel
+from torch.optim.swa_utils import SWALR, AveragedModel
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+from torch_ema import ExponentialMovingAverage
+from torchmetrics import Metric
+
+from . import torch_geometric
+from .checkpoint import CheckpointHandler, CheckpointState
+from .torch_tools import to_numpy
+from .utils import (
+    MetricsLogger,
+    compute_mae,
+    compute_q95,
+    compute_rel_mae,
+    compute_rel_rmse,
+    compute_rmse,
+)
+
+
+@dataclasses.dataclass
+class SWAContainer:
+    model: AveragedModel
+    scheduler: SWALR
+    start: int
+    loss_fn: torch.nn.Module
+
+
+def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None):
+    eval_metrics["mode"] = "eval"
+    eval_metrics["epoch"] = epoch
+    logger.log(eval_metrics)
+    if epoch is None:
+        initial_phrase = "Initial"
+    else:
+        initial_phrase = f"Epoch {epoch}"
+    if log_errors == "PerAtomRMSE":
+        error_e = eval_metrics["rmse_e_per_atom"] * 1e3
+        error_f = eval_metrics["rmse_f"] * 1e3
+        logging.info(
+            f"{initial_phrase}: loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A"
+        )
+    elif (
+        log_errors == "PerAtomRMSEstressvirials"
+        and eval_metrics["rmse_stress_per_atom"] is not None
+    ):
+        error_e = eval_metrics["rmse_e_per_atom"] * 1e3
+        error_f = eval_metrics["rmse_f"] * 1e3
+        error_stress = eval_metrics["rmse_stress_per_atom"] * 1e3
+        logging.info(
+            f"{initial_phrase}: loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_stress_per_atom={error_stress:8.1f} meV / A^3",
+        )
+    elif (
+        log_errors == "PerAtomRMSEstressvirials"
+        and eval_metrics["rmse_virials_per_atom"] is not None
+    ):
+        error_e = eval_metrics["rmse_e_per_atom"] * 1e3
+        error_f = eval_metrics["rmse_f"] * 1e3
+        error_virials = eval_metrics["rmse_virials_per_atom"] * 1e3
+        logging.info(
+            f"{initial_phrase}: loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_virials_per_atom={error_virials:8.1f} meV",
+        )
+    elif (
+        log_errors == "PerAtomMAEstressvirials"
+        and eval_metrics["mae_stress_per_atom"] is not None
+    ):
+        error_e = eval_metrics["mae_e_per_atom"] * 1e3
+        error_f = eval_metrics["mae_f"] * 1e3
+        error_stress = eval_metrics["mae_stress"] * 1e3
+        logging.info(
+            f"{initial_phrase}: loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A, MAE_stress={error_stress:8.1f} meV / A^3"
+        )
+    elif (
+        log_errors == "PerAtomMAEstressvirials"
+        and eval_metrics["mae_virials_per_atom"] is not None
+    ):
+        error_e = eval_metrics["mae_e_per_atom"] * 1e3
+        error_f = eval_metrics["mae_f"] * 1e3
+        error_virials = eval_metrics["mae_virials"] * 1e3
+        logging.info(
+            f"{initial_phrase}: loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A, MAE_virials={error_virials:8.1f} meV"
+        )
+    elif log_errors == "TotalRMSE":
+        error_e = eval_metrics["rmse_e"] * 1e3
+        error_f = eval_metrics["rmse_f"] * 1e3
+        logging.info(
+            f"{initial_phrase}: loss={valid_loss:8.4f}, RMSE_E={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A",
+        )
+    elif log_errors == "PerAtomMAE":
+        error_e = eval_metrics["mae_e_per_atom"] * 1e3
+        error_f = eval_metrics["mae_f"] * 1e3
+        logging.info(
+            f"{initial_phrase}: loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A",
+        )
+    elif log_errors == "TotalMAE":
+        error_e = eval_metrics["mae_e"] * 1e3
+        error_f = eval_metrics["mae_f"] * 1e3
+        logging.info(
+            f"{initial_phrase}: loss={valid_loss:8.4f}, MAE_E={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A",
+        )
+    elif log_errors == "DipoleRMSE":
+        error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3
+        logging.info(
+            f"{initial_phrase}: loss={valid_loss:8.4f}, RMSE_MU_per_atom={error_mu:8.2f} mDebye",
+        )
+    elif log_errors == "EnergyDipoleRMSE":
+        error_e = eval_metrics["rmse_e_per_atom"] * 1e3
+        error_f = eval_metrics["rmse_f"] * 1e3
+        error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3
+        logging.info(
+            f"{initial_phrase}: loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_Mu_per_atom={error_mu:8.2f} mDebye",
+        )
+
+
+def train(
+    model: torch.nn.Module,
+    loss_fn: torch.nn.Module,
+    train_loader: DataLoader,
+    valid_loader: Dict[str, DataLoader],
+    optimizer: torch.optim.Optimizer,
+    lr_scheduler: torch.optim.lr_scheduler.ExponentialLR,
+    start_epoch: int,
+    max_num_epochs: int,
+    patience: int,
+    checkpoint_handler: CheckpointHandler,
+    logger: MetricsLogger,
+    eval_interval: int,
+    output_args: Dict[str, bool],
+    device: torch.device,
+    log_errors: str,
+    swa: Optional[SWAContainer] = None,
+    ema: Optional[ExponentialMovingAverage] = None,
+    max_grad_norm: Optional[float] = 10.0,
+    log_wandb: bool = False,
+    distributed: bool = False,
+    save_all_checkpoints: bool = False,
+    distributed_model: Optional[DistributedDataParallel] = None,
+    train_sampler: Optional[DistributedSampler] = None,
+    rank: Optional[int] = 0,
+):
+    lowest_loss = np.inf
+    valid_loss = np.inf
+    patience_counter = 0
+    swa_start = True
+    keep_last = False
+    if log_wandb:
+        import wandb
+
+    if max_grad_norm is not None:
+        logging.info(f"Using gradient clipping with tolerance={max_grad_norm:.3f}")
+
+    logging.info("")
+    logging.info("===========TRAINING===========")
+    logging.info("Started training, reporting errors on validation set")
+    logging.info("Loss metrics on validation set")
+    epoch = start_epoch
+
+    # Log validation loss before _any_ training
+    param_context = ema.average_parameters() if ema is not None else nullcontext()
+    with param_context:
+        valid_loss, eval_metrics = evaluate(
+            model=model,
+            loss_fn=loss_fn,
+            data_loader=valid_loader,
+            output_args=output_args,
+            device=device,
+        )
+        valid_err_log(valid_loss, eval_metrics, logger, log_errors, None)
+
+    while epoch < max_num_epochs:
+        # LR scheduler and SWA update
+        if swa is None or epoch < swa.start:
+            if epoch > start_epoch:
+                lr_scheduler.step(
+                    metrics=valid_loss
+                )  # Can break if exponential LR, TODO fix that!
+ else: + if swa_start: + logging.info("Changing loss based on Stage Two Weights") + lowest_loss = np.inf + swa_start = False + keep_last = True + loss_fn = swa.loss_fn + swa.model.update_parameters(model) + if epoch > start_epoch: + swa.scheduler.step() + + # Train + if distributed: + train_sampler.set_epoch(epoch) + if "ScheduleFree" in type(optimizer).__name__: + optimizer.train() + train_one_epoch( + model=model, + loss_fn=loss_fn, + data_loader=train_loader, + optimizer=optimizer, + epoch=epoch, + output_args=output_args, + max_grad_norm=max_grad_norm, + ema=ema, + logger=logger, + device=device, + distributed_model=distributed_model, + rank=rank, + ) + if distributed: + torch.distributed.barrier() + + # Validate + if epoch % eval_interval == 0: + model_to_evaluate = ( + model if distributed_model is None else distributed_model + ) + param_context = ( + ema.average_parameters() if ema is not None else nullcontext() + ) + if "ScheduleFree" in type(optimizer).__name__: + optimizer.eval() + with param_context: + valid_loss, eval_metrics = evaluate( + model=model_to_evaluate, + loss_fn=loss_fn, + data_loader=valid_loader, + output_args=output_args, + device=device, + ) + if rank == 0: + valid_err_log( + valid_loss, + eval_metrics, + logger, + log_errors, + epoch, + ) + if log_wandb: + wandb_log_dict = { + "epoch": epoch, + "valid_loss": valid_loss, + "valid_rmse_e_per_atom": eval_metrics["rmse_e_per_atom"], + "valid_rmse_f": eval_metrics["rmse_f"], + } + wandb.log(wandb_log_dict) + + if valid_loss >= lowest_loss: + patience_counter += 1 + if patience_counter >= patience and epoch < swa.start: + logging.info( + f"Stopping optimization after {patience_counter} epochs without improvement and starting Stage Two" + ) + epoch = swa.start + elif patience_counter >= patience and epoch >= swa.start: + logging.info( + f"Stopping optimization after {patience_counter} epochs without improvement" + ) + break + if save_all_checkpoints: + param_context = ( + ema.average_parameters() + if ema is not None + else nullcontext() + ) + with param_context: + checkpoint_handler.save( + state=CheckpointState(model, optimizer, lr_scheduler), + epochs=epoch, + keep_last=True, + ) + else: + lowest_loss = valid_loss + patience_counter = 0 + param_context = ( + ema.average_parameters() if ema is not None else nullcontext() + ) + with param_context: + checkpoint_handler.save( + state=CheckpointState(model, optimizer, lr_scheduler), + epochs=epoch, + keep_last=keep_last, + ) + keep_last = False or save_all_checkpoints + if distributed: + torch.distributed.barrier() + epoch += 1 + + logging.info("Training complete") + + +def train_one_epoch( + model: torch.nn.Module, + loss_fn: torch.nn.Module, + data_loader: DataLoader, + optimizer: torch.optim.Optimizer, + epoch: int, + output_args: Dict[str, bool], + max_grad_norm: Optional[float], + ema: Optional[ExponentialMovingAverage], + logger: MetricsLogger, + device: torch.device, + distributed_model: Optional[DistributedDataParallel] = None, + rank: Optional[int] = 0, +) -> None: + model_to_train = model if distributed_model is None else distributed_model + for batch in data_loader: + _, opt_metrics = take_step( + model=model_to_train, + loss_fn=loss_fn, + batch=batch, + optimizer=optimizer, + ema=ema, + output_args=output_args, + max_grad_norm=max_grad_norm, + device=device, + ) + opt_metrics["mode"] = "opt" + opt_metrics["epoch"] = epoch + if rank == 0: + logger.log(opt_metrics) + + +def take_step( + model: torch.nn.Module, + loss_fn: torch.nn.Module, + batch: 
torch_geometric.batch.Batch, + optimizer: torch.optim.Optimizer, + ema: Optional[ExponentialMovingAverage], + output_args: Dict[str, bool], + max_grad_norm: Optional[float], + device: torch.device, +) -> Tuple[float, Dict[str, Any]]: + start_time = time.time() + batch = batch.to(device) + optimizer.zero_grad(set_to_none=True) + batch_dict = batch.to_dict() + output = model( + batch_dict, + training=True, + compute_force=output_args["forces"], + compute_virials=output_args["virials"], + compute_stress=output_args["stress"], + ) + loss = loss_fn(pred=output, ref=batch) + loss.backward() + if max_grad_norm is not None: + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm) + optimizer.step() + + if ema is not None: + ema.update() + + loss_dict = { + "loss": to_numpy(loss), + "time": time.time() - start_time, + } + + return loss, loss_dict + + +def evaluate( + model: torch.nn.Module, + loss_fn: torch.nn.Module, + data_loader: DataLoader, + output_args: Dict[str, bool], + device: torch.device, +) -> Tuple[float, Dict[str, Any]]: + for param in model.parameters(): + param.requires_grad = False + + metrics = MACELoss(loss_fn=loss_fn).to(device) + + start_time = time.time() + for batch in data_loader: + batch = batch.to(device) + batch_dict = batch.to_dict() + output = model( + batch_dict, + training=False, + compute_force=output_args["forces"], + compute_virials=output_args["virials"], + compute_stress=output_args["stress"], + ) + avg_loss, aux = metrics(batch, output) + + avg_loss, aux = metrics.compute() + aux["time"] = time.time() - start_time + metrics.reset() + + for param in model.parameters(): + param.requires_grad = True + + return avg_loss, aux + + +class MACELoss(Metric): + def __init__(self, loss_fn: torch.nn.Module): + super().__init__() + self.loss_fn = loss_fn + self.add_state("total_loss", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("num_data", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("E_computed", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("delta_es", default=[], dist_reduce_fx="cat") + self.add_state("delta_es_per_atom", default=[], dist_reduce_fx="cat") + self.add_state("Fs_computed", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("fs", default=[], dist_reduce_fx="cat") + self.add_state("delta_fs", default=[], dist_reduce_fx="cat") + self.add_state( + "stress_computed", default=torch.tensor(0.0), dist_reduce_fx="sum" + ) + self.add_state("delta_stress", default=[], dist_reduce_fx="cat") + self.add_state("delta_stress_per_atom", default=[], dist_reduce_fx="cat") + self.add_state( + "virials_computed", default=torch.tensor(0.0), dist_reduce_fx="sum" + ) + self.add_state("delta_virials", default=[], dist_reduce_fx="cat") + self.add_state("delta_virials_per_atom", default=[], dist_reduce_fx="cat") + self.add_state("Mus_computed", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("mus", default=[], dist_reduce_fx="cat") + self.add_state("delta_mus", default=[], dist_reduce_fx="cat") + self.add_state("delta_mus_per_atom", default=[], dist_reduce_fx="cat") + + def update(self, batch, output): # pylint: disable=arguments-differ + loss = self.loss_fn(pred=output, ref=batch) + self.total_loss += loss + self.num_data += batch.num_graphs + + if output.get("energy") is not None and batch.energy is not None: + self.E_computed += 1.0 + self.delta_es.append(batch.energy - output["energy"]) + self.delta_es_per_atom.append( + (batch.energy - output["energy"]) / 
(batch.ptr[1:] - batch.ptr[:-1]) + ) + if output.get("forces") is not None and batch.forces is not None: + self.Fs_computed += 1.0 + self.fs.append(batch.forces) + self.delta_fs.append(batch.forces - output["forces"]) + if output.get("stress") is not None and batch.stress is not None: + self.stress_computed += 1.0 + self.delta_stress.append(batch.stress - output["stress"]) + self.delta_stress_per_atom.append( + (batch.stress - output["stress"]) + / (batch.ptr[1:] - batch.ptr[:-1]).view(-1, 1, 1) + ) + if output.get("virials") is not None and batch.virials is not None: + self.virials_computed += 1.0 + self.delta_virials.append(batch.virials - output["virials"]) + self.delta_virials_per_atom.append( + (batch.virials - output["virials"]) + / (batch.ptr[1:] - batch.ptr[:-1]).view(-1, 1, 1) + ) + if output.get("dipole") is not None and batch.dipole is not None: + self.Mus_computed += 1.0 + self.mus.append(batch.dipole) + self.delta_mus.append(batch.dipole - output["dipole"]) + self.delta_mus_per_atom.append( + (batch.dipole - output["dipole"]) + / (batch.ptr[1:] - batch.ptr[:-1]).unsqueeze(-1) + ) + + def convert(self, delta: Union[torch.Tensor, List[torch.Tensor]]) -> np.ndarray: + if isinstance(delta, list): + delta = torch.cat(delta) + return to_numpy(delta) + + def compute(self): + aux = {} + aux["loss"] = to_numpy(self.total_loss / self.num_data).item() + if self.E_computed: + delta_es = self.convert(self.delta_es) + delta_es_per_atom = self.convert(self.delta_es_per_atom) + aux["mae_e"] = compute_mae(delta_es) + aux["mae_e_per_atom"] = compute_mae(delta_es_per_atom) + aux["rmse_e"] = compute_rmse(delta_es) + aux["rmse_e_per_atom"] = compute_rmse(delta_es_per_atom) + aux["q95_e"] = compute_q95(delta_es) + if self.Fs_computed: + fs = self.convert(self.fs) + delta_fs = self.convert(self.delta_fs) + aux["mae_f"] = compute_mae(delta_fs) + aux["rel_mae_f"] = compute_rel_mae(delta_fs, fs) + aux["rmse_f"] = compute_rmse(delta_fs) + aux["rel_rmse_f"] = compute_rel_rmse(delta_fs, fs) + aux["q95_f"] = compute_q95(delta_fs) + if self.stress_computed: + delta_stress = self.convert(self.delta_stress) + delta_stress_per_atom = self.convert(self.delta_stress_per_atom) + aux["mae_stress"] = compute_mae(delta_stress) + aux["rmse_stress"] = compute_rmse(delta_stress) + aux["rmse_stress_per_atom"] = compute_rmse(delta_stress_per_atom) + aux["q95_stress"] = compute_q95(delta_stress) + if self.virials_computed: + delta_virials = self.convert(self.delta_virials) + delta_virials_per_atom = self.convert(self.delta_virials_per_atom) + aux["mae_virials"] = compute_mae(delta_virials) + aux["rmse_virials"] = compute_rmse(delta_virials) + aux["rmse_virials_per_atom"] = compute_rmse(delta_virials_per_atom) + aux["q95_virials"] = compute_q95(delta_virials) + if self.Mus_computed: + mus = self.convert(self.mus) + delta_mus = self.convert(self.delta_mus) + delta_mus_per_atom = self.convert(self.delta_mus_per_atom) + aux["mae_mu"] = compute_mae(delta_mus) + aux["mae_mu_per_atom"] = compute_mae(delta_mus_per_atom) + aux["rel_mae_mu"] = compute_rel_mae(delta_mus, mus) + aux["rmse_mu"] = compute_rmse(delta_mus) + aux["rmse_mu_per_atom"] = compute_rmse(delta_mus_per_atom) + aux["rel_rmse_mu"] = compute_rel_rmse(delta_mus, mus) + aux["q95_mu"] = compute_q95(delta_mus) + + return aux["loss"], aux diff --git a/hydragnn/utils/mace_utils/tools/utils.py b/hydragnn/utils/mace_utils/tools/utils.py new file mode 100644 index 000000000..762d98802 --- /dev/null +++ b/hydragnn/utils/mace_utils/tools/utils.py @@ -0,0 +1,168 @@ 
+###########################################################################################
+# Statistics utilities
+# Authors: Ilyes Batatia, Gregor Simm, David Kovacs
+# This program is distributed under the MIT License (see MIT.md)
+###########################################################################################
+
+import json
+import logging
+import os
+import sys
+from typing import Any, Dict, Iterable, Optional, Sequence, Union
+
+import numpy as np
+import torch
+
+from .torch_tools import to_numpy
+
+
+def compute_mae(delta: np.ndarray) -> float:
+    return np.mean(np.abs(delta)).item()
+
+
+def compute_rel_mae(delta: np.ndarray, target_val: np.ndarray) -> float:
+    target_norm = np.mean(np.abs(target_val))
+    return np.mean(np.abs(delta)).item() / (target_norm + 1e-9) * 100
+
+
+def compute_rmse(delta: np.ndarray) -> float:
+    return np.sqrt(np.mean(np.square(delta))).item()
+
+
+def compute_rel_rmse(delta: np.ndarray, target_val: np.ndarray) -> float:
+    target_norm = np.sqrt(np.mean(np.square(target_val))).item()
+    return np.sqrt(np.mean(np.square(delta))).item() / (target_norm + 1e-9) * 100
+
+
+def compute_q95(delta: np.ndarray) -> float:
+    return np.percentile(np.abs(delta), q=95).item()
+
+
+def compute_c(delta: np.ndarray, eta: float) -> float:
+    return np.mean(np.abs(delta) < eta).item()
+
+
+def get_tag(name: str, seed: int) -> str:
+    return f"{name}_run-{seed}"
+
+
+def setup_logger(
+    level: Union[int, str] = logging.INFO,
+    tag: Optional[str] = None,
+    directory: Optional[str] = None,
+    rank: Optional[int] = 0,
+):
+    # Create a logger
+    logger = logging.getLogger()
+    logger.setLevel(logging.DEBUG)  # Set to DEBUG to capture all levels
+
+    # Create formatters
+    formatter = logging.Formatter(
+        "%(asctime)s.%(msecs)03d %(levelname)s: %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+    )
+
+    # Add filter for rank
+    logger.addFilter(lambda _: rank == 0)
+
+    # Create console handler
+    ch = logging.StreamHandler(stream=sys.stdout)
+    ch.setLevel(level)
+    ch.setFormatter(formatter)
+    logger.addHandler(ch)
+
+    if directory is not None and tag is not None:
+        os.makedirs(name=directory, exist_ok=True)
+
+        # Create file handler for non-debug logs
+        main_log_path = os.path.join(directory, f"{tag}.log")
+        fh_main = logging.FileHandler(main_log_path)
+        fh_main.setLevel(level)
+        fh_main.setFormatter(formatter)
+        logger.addHandler(fh_main)
+
+        # Create file handler for debug logs
+        debug_log_path = os.path.join(directory, f"{tag}_debug.log")
+        fh_debug = logging.FileHandler(debug_log_path)
+        fh_debug.setLevel(logging.DEBUG)
+        fh_debug.setFormatter(formatter)
+        fh_debug.addFilter(lambda record: record.levelno >= logging.DEBUG)
+        logger.addHandler(fh_debug)
+
+
+class AtomicNumberTable:
+    def __init__(self, zs: Sequence[int]):
+        self.zs = zs
+
+    def __len__(self) -> int:
+        return len(self.zs)
+
+    def __str__(self):
+        return f"AtomicNumberTable: {tuple(s for s in self.zs)}"
+
+    def index_to_z(self, index: int) -> int:
+        return self.zs[index]
+
+    def z_to_index(self, atomic_number: int) -> int:
+        return self.zs.index(atomic_number)
+
+
+def get_atomic_number_table_from_zs(zs: Iterable[int]) -> AtomicNumberTable:
+    z_set = set()
+    for z in zs:
+        z_set.add(z)
+    return AtomicNumberTable(sorted(list(z_set)))
+
+
+def atomic_numbers_to_indices(
+    atomic_numbers: np.ndarray, z_table: AtomicNumberTable
+) -> np.ndarray:
+    to_index_fn = np.vectorize(z_table.z_to_index)
+    return to_index_fn(atomic_numbers)
+
+
+def get_optimizer(
+    name: str,
+    amsgrad: bool,
+    learning_rate: float,
+    weight_decay: float,
+ parameters: Iterable[torch.Tensor], +) -> torch.optim.Optimizer: + if name == "adam": + return torch.optim.Adam( + parameters, lr=learning_rate, amsgrad=amsgrad, weight_decay=weight_decay + ) + + if name == "adamw": + return torch.optim.AdamW( + parameters, lr=learning_rate, amsgrad=amsgrad, weight_decay=weight_decay + ) + + raise RuntimeError(f"Unknown optimizer '{name}'") + + +class UniversalEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, np.integer): + return int(o) + if isinstance(o, np.floating): + return float(o) + if isinstance(o, np.ndarray): + return o.tolist() + if isinstance(o, torch.Tensor): + return to_numpy(o) + return json.JSONEncoder.default(self, o) + + +class MetricsLogger: + def __init__(self, directory: str, tag: str) -> None: + self.directory = directory + self.filename = tag + ".txt" + self.path = os.path.join(self.directory, self.filename) + + def log(self, d: Dict[str, Any]) -> None: + logging.debug(f"Saving info: {self.path}") + os.makedirs(name=self.directory, exist_ok=True) + with open(self.path, mode="a", encoding="utf-8") as f: + f.write(json.dumps(d, cls=UniversalEncoder)) + f.write("\n") From 8ab9eb2327541ad64b1b10e4f5f49cccf59cf845 Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Wed, 11 Sep 2024 19:29:00 -0400 Subject: [PATCH 02/51] MACE Rebased --- .vscode/.vscode/launch.json | 15 + .vscode/launch.json | 74 ++ examples/qm9/qm9.json | 11 +- hydragnn/models/MACEStack.py | 1030 ++++++++--------- hydragnn/models/create.py | 42 + .../input_config_parsing/config_utils.py | 27 +- hydragnn/utils/mace_utils/data/atomic_data.py | 2 +- .../utils/mace_utils/data/hdf5_dataset.py | 6 +- hydragnn/utils/mace_utils/data/utils.py | 2 +- hydragnn/utils/mace_utils/modules/__init__.py | 20 +- hydragnn/utils/mace_utils/modules/blocks.py | 373 +----- hydragnn/utils/mace_utils/modules/loss.py | 4 +- hydragnn/utils/mace_utils/modules/models.py | 6 +- hydragnn/utils/mace_utils/modules/radial.py | 4 +- .../modules/symmetric_contraction.py | 2 +- hydragnn/utils/mace_utils/modules/utils.py | 6 +- .../mace_utils/tools/finetuning_utils.py | 2 +- hydragnn/utils/model/model.py | 47 +- tests/inputs/ci.json | 2 + tests/inputs/ci_equivariant.json | 2 + tests/inputs/ci_multihead.json | 2 + tests/inputs/ci_vectoroutput.json | 2 + tests/test_graphs.py | 41 +- 23 files changed, 757 insertions(+), 965 deletions(-) create mode 100644 .vscode/.vscode/launch.json create mode 100644 .vscode/launch.json diff --git a/.vscode/.vscode/launch.json b/.vscode/.vscode/launch.json new file mode 100644 index 000000000..6b76b4fab --- /dev/null +++ b/.vscode/.vscode/launch.json @@ -0,0 +1,15 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Current File", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal" + } + ] +} \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 000000000..ba662d2e6 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,74 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. 
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+
+    // "version": "0.2.0",
+    // "configurations": [
+    //     {
+    //         "name": "Python Debugger: Current File with Conda",
+    //         "type": "python",
+    //         "request": "launch",
+    //         "program": "${file}",
+    //         "console": "integratedTerminal",
+    //         "pythonPath": "/opt/anaconda3/envs/Force/bin/python",
+    //         "env": {
+    //             "PYTHONPATH": "/Users/r9w/Coding/HydraGNN:/Users/r9w/Coding/matsciml"
+    //         }
+    //     }
+    // ]
+
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Python Debugger: Current File with Conda",
+            "type": "python",
+            "request": "launch",
+            "program": "${file}",
+            "console": "integratedTerminal",
+            "pythonPath": "/opt/anaconda3/envs/HYDRAEq/bin/python",
+            "env": {
+                "PYTHONPATH": "/Users/r9w/Coding/My_Fork/HydraGNN"
+
+            },
+            // "args": ["--pickle"]
+        }
+    ]

+    // "version": "0.2.0",
+    // "configurations": [
+    //     {
+    //         "name": "Python Debugger: Current File with Conda",
+    //         "type": "python",
+    //         "request": "launch",
+    //         "program": "${file}",
+    //         "console": "integratedTerminal",
+    //         "pythonPath": "/opt/anaconda3/envs/MACE2/bin/python",
+    //         "args": [
+    //             "--name=MACE_model",
+    //             "--train_file=train.xyz",
+    //             "--valid_fraction=0.05",
+    //             "--test_file=test.xyz",
+    //             "--config_type_weights={\"Default\":1.0}",
+    //             "--E0s={1:-13.663181292231226, 6:-1029.2809654211628, 7:-1484.1187695035828, 8:-2042.0330099956639, 9:-256.207655}",
+    //             "--model=MACE",
+    //             "--hidden_irreps=128x0e + 128x1o",
+    //             "--r_max=5.0",
+    //             "--batch_size=200",
+    //             "--max_num_epochs=1500",
+    //             "--swa",
+    //             "--start_swa=1200",
+    //             "--ema",
+    //             "--ema_decay=0.99",
+    //             "--amsgrad",
+    //             "--restart_latest"
+    //         ],
+    //         "env": {
+    //             "PYTHONPATH": "/Users/r9w/Coding/MACE/mace"
+
+    //         },
+    //     }
+    // ]
+
+
+}
\ No newline at end of file
diff --git a/examples/qm9/qm9.json b/examples/qm9/qm9.json
index 78b5c7c96..789ba5ccf 100644
--- a/examples/qm9/qm9.json
+++ b/examples/qm9/qm9.json
@@ -5,12 +5,15 @@
     "NeuralNetwork": {
         "Profile": {"enable": 1},
         "Architecture": {
-            "model_type": "GIN",
+            "model_type": "MACE",
             "radius": 7,
+            "num_radial": 6,
+            "max_ell": 1,
+            "node_max_ell": 1,
             "max_neighbours": 5,
             "periodic_boundary_conditions": false,
-            "hidden_dim": 5,
-            "num_conv_layers": 6,
+            "hidden_dim": 15,
+            "num_conv_layers": 3,
             "output_heads": {
                 "graph":{
                     "num_sharedlayers": 2,
@@ -30,7 +33,7 @@
         "denormalize_output": false
     },
     "Training": {
-        "num_epoch": 2,
+        "num_epoch": 100,
         "perc_train": 0.7,
         "loss_function_type": "mse",
         "batch_size": 64,
diff --git a/hydragnn/models/MACEStack.py b/hydragnn/models/MACEStack.py
index 2537f945f..c1313ad60 100644
--- a/hydragnn/models/MACEStack.py
+++ b/hydragnn/models/MACEStack.py
@@ -18,45 +18,51 @@
 # This program is distributed under the MIT License (see MIT.md)
 ###########################################################################################
 
+# NOTE MACE Architecture:
+## There are two key ideas of MACE:
+### (1) Message passing and interaction blocks are equivariant to the O(3) group and invariant to the T(3) group (translations).
+### (2) Predictions are made in an n-body expansion, where n is the number of layers. This is done by creating multi-body
+###     interactions, then decoding them. Layer 1 will decode 1-body interactions, layer 2 will decode 2-body interactions,
+###     and so on. So, for a 3-layer model predicting energy, there are 3 outputs for energy, one at each layer, and they
+###     are summed at the end. This requires some adjustment to the behavior from Base.py
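+###     (For energy, this amounts to E = E^(1) + E^(2) + E^(3), one readout per layer.)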
+
 from typing import Any, Callable, Dict, List, Optional, Type, Union
 
-import numpy as np
+# Torch
 import torch
-from e3nn import o3
-from e3nn.util.jit import compile_mode
+from torch.nn import ModuleList, Sequential
+from torch.utils.checkpoint import checkpoint
+from torch_scatter import scatter
+
+# Torch Geo
 from torch_geometric.nn import (
     Sequential as PyGSequential,
 )  # This naming is because there is torch.nn.Sequential and torch_geometric.nn.Sequential
-
+from torch_geometric.nn import global_mean_pool
 
 # Mace
-from hydragnn.utils.mace_utils.data import AtomicData
-from hydragnn.utils.mace_utils.modules import ZBLBasis
-from hydragnn.utils.mace_utils.tools.scatter import scatter_sum
-
 from hydragnn.utils.mace_utils.modules.blocks import (
-    AtomicEnergiesBlock,
     EquivariantProductBasisBlock,
-    InteractionBlock,
-    LinearDipoleReadoutBlock,
     LinearNodeEmbeddingBlock,
-    LinearReadoutBlock,
-    NonLinearDipoleReadoutBlock,
-    NonLinearReadoutBlock,
     RadialEmbeddingBlock,
-    ScaleShiftBlock,
+    RealAgnosticAttResidualInteractionBlock,
 )
 from hydragnn.utils.mace_utils.modules.utils import (
-    compute_fixed_charge_dipole,
-    compute_forces,
     get_edge_vectors_and_lengths,
-    get_outputs,
-    get_symmetric_displacement,
 )
 
+# E3NN
+from e3nn import nn, o3
+from e3nn.util.jit import compile_mode
+
+
 # HydraGNN
 from .Base import Base
 
+# Etc
+import numpy as np
+import math
+
 # pylint: disable=C0302
 
 
@@ -64,137 +70,62 @@ class MACEStack(Base):
     def __init__(
         self,
-        radius: float,
-        num_radial: int,
-        irreps_cutoff: int,  # What is this for?
-        max_ell: int,  # Max l-type for CG-tensor product
-        interaction_cls: Type[InteractionBlock],
-        interaction_cls_first: Type[InteractionBlock],
-        num_interactions: int,
-        num_elements: int,
-        hidden_irreps: o3.Irreps,
-        MLP_irreps: o3.Irreps,
-        atomic_energies: np.ndarray,
+        r_max: float,  # The cutoff radius for the radial basis functions and edge_index
+        num_bessel: int,  # The number of radial Bessel functions. This dictates the richness of radial information in message-passing.
+        max_ell: int,  # Max l-type for CG-tensor product. Theoretically, there is no max l-type, but in practice, we need to truncate the CG-tensor product to keep computation tractable
+        node_max_ell: int,  # Max l-type for node features
         avg_num_neighbors: float,
-        atomic_numbers: List[int],
-        correlation: Union[int, List[int]],
-        gate: Optional[Callable],
-        pair_repulsion: bool = False,
-        distance_transform: str = "None",
-        radial_MLP: Optional[List[int]] = None,
-        radial_type: Optional[str] = "bessel",
+        num_polynomial_cutoff,  # The polynomial cutoff function ensures that the function goes to zero at the cutoff radius smoothly. Same as envelope_exponent for DimeNet
+        correlation,  # Used in the product basis block and *roughly* determines the richness of interaction in the n-body interaction of layer 'n'.
+        radial_type,  # The type of radial basis function to use
         *args,
         **kwargs,
     ):
+        """Notes On MACEStack Arguments:"""
+        # MACE args that we have given definitions for and the reasons why:
+        ## Note: These can be changed in the future if the desired argument options change
+        ## interaction_cls / interaction_cls_first: The choice of interaction block type should not make much of a difference and would require more imports in create.py and/or string handling
+        ## Atomic Energies: This is not agnostic to what we're predicting, which is a requirement of HYDRA. We also don't have base atomic energies to load, so we simply one-hot encode the atomic numbers and train.
+        ## Atomic Numbers / num_elements: It's more robust in preventing errors to just cover the entire periodic table (1-118)
+
+        # MACE args that we have dropped and the reasons why:
+        ## pair repulsion, distance_transform, compute_virials, etc: HYDRA's framework is meant to compute based on graph or node type, so must be agnostic to these property-specific types of computations
+
+        # MACE args constructed by HYDRA args
+        ## Reasoning: Oftentimes, MACE arguments show similarity to HYDRA arguments, but are labelled differently
+        ## num_interactions is represented by num_conv_layers
+        ## radial_MLP uses ceil(hidden_dim/3) for its layer sizes
+        ## hidden_irreps and MLP_irreps are constructed from hidden_dim
+        ##   - Note that this is a nontrivial choice... 
reconstructing the irreps from hidden_dim means users don't need to know the e3nn library and keeps the model closer to the HYDRA framework, at a slight cost in customization
+        ## - The node_max_ell argument additionally lets the user set the max ell of the hidden (node-feature) irreps
+        """"""
+
+        # Init Args
+        ## Passed
+        self.node_max_ell = node_max_ell
+        num_interactions = kwargs["num_conv_layers"]
         self.avg_num_neighbors = avg_num_neighbors
-        self.atomic_numbers = atomic_numbers
-        self.correlation = correlation
-        self.gate = gate
-        self.pair_repulsion = pair_repulsion
-        self.distance_transform = distance_transform
-        self.radial_MLP = radial_MLP
-        self.radial_type = radial_type
-
+        ## Defined
+        self.interaction_cls = RealAgnosticAttResidualInteractionBlock
+        self.interaction_cls_first = RealAgnosticAttResidualInteractionBlock
+        self.num_elements = 118  # Number of elements in the periodic table
+        atomic_numbers = list(range(1, self.num_elements + 1))
+        # Optional
+        num_polynomial_cutoff = 5 if num_polynomial_cutoff is None else num_polynomial_cutoff
+        self.correlation = [2] if correlation is None else correlation
+        radial_type = "bessel" if radial_type is None else radial_type
+
+        # Making Irreps
+        self.node_attr_irreps = o3.Irreps([(self.num_elements, (0, 1))])  # 118 is the number of elements in the periodic table
+        self.sh_irreps = o3.Irreps.spherical_harmonics(max_ell)  # This makes the irreps string
+        self.edge_feats_irreps = o3.Irreps(f"{num_bessel}x0e")
+
         super().__init__(*args, **kwargs)
-
-    def get_conv(self, input_dim, output_dim):
-        conv = MACEConv(
-            self,
-            r_max=self.radius,
-            num_bessel=self.num_radial,
-            max_ell=self.max_ell,
-            interaction_cls=self.interaction_cls,
-            interaction_cls_first=self.interaction_cls_first,
-            num_interactions=self.num_interactions,
-            num_elements=self.num_elements,
-            hidden_irreps=self.hidden_irreps,
-            MLP_irreps=self.MLP_irreps,
-            atomic_energies=self.atomic_energies,
-            avg_num_neighbors=self.avg_num_neighbors,
-            atomic_numbers=self.atomic_numbers,
-            correlation=self.correlation,
-            gate=self.gate,
-            pair_repulsion=self.pair_repulsion,
-            distance_transform=self.distance_transform,
-            radial_MLP=self.radial_MLP,
-            radial_type=self.radial_type,
-        )
-
-        input_args = "x, pos, edge_index, rbf"
-        conv_args = "x, edge_index, rbf"
-
-        if self.use_edge_attr:
-            input_args += ", edge_attr"
-            conv_args += ", edge_attr"
-
-        return PyGSequential(
-            input_args,
-            [
-                (conv, conv_args + " -> x"),
-                (lambda x, pos: [x, pos], "x, pos -> x, pos"),
-            ],
+
+        self.spherical_harmonics = o3.SphericalHarmonics(
+            self.sh_irreps, normalize=True, normalization="component"  # Instantiated here; called in the forward pass on edge vectors
         )
-
-    def _conv_args(self, data):
-        assert (
-            data.pos is not None
-        ), "PNA+ requires node positions (data.pos) to be set."
-
-        j, i = data.edge_index  # j->i
-        dist = (data.pos[i] - data.pos[j]).pow(2).sum(dim=-1).sqrt()
-        rbf = self.rbf(dist)
-        # rbf = dist.unsqueeze(-1)
-        conv_args = {"edge_index": data.edge_index.to(torch.long), "rbf": rbf}
-
-        if self.use_edge_attr:
-            assert (
-                data.edge_attr is not None
-            ), "Data must have edge attributes if use_edge_attributes is set."
- conv_args.update({"edge_attr": data.edge_attr}) - - return conv_args - - def __str__(self): - return "PNAStack" - - - -@compile_mode("script") -class MACEConv(torch.nn.Module): - def __init__( - self, - r_max: float, - num_bessel: int, - num_polynomial_cutoff: int, - max_ell: int, - interaction_cls: Type[InteractionBlock], - interaction_cls_first: Type[InteractionBlock], - num_interactions: int, - num_elements: int, - hidden_irreps: o3.Irreps, - MLP_irreps: o3.Irreps, - atomic_energies: np.ndarray, - avg_num_neighbors: float, - atomic_numbers: List[int], - correlation: Union[int, List[int]], - gate: Optional[Callable], - pair_repulsion: bool = False, - distance_transform: str = "None", - radial_MLP: Optional[List[int]] = None, - radial_type: Optional[str] = "bessel", - ): - super().__init__() + # Register buffers are made when parameters need to be saved and transferred with the model, but not trained. self.register_buffer( "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) @@ -206,472 +137,433 @@ def __init__( "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) ) if isinstance(correlation, int): - correlation = [correlation] * num_interactions - # Embedding - node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) - node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) - self.node_embedding = LinearNodeEmbeddingBlock( - irreps_in=node_attr_irreps, irreps_out=node_feats_irreps - ) + self.correlation = [self.correlation] * self.num_interactions self.radial_embedding = RadialEmbeddingBlock( r_max=r_max, num_bessel=num_bessel, num_polynomial_cutoff=num_polynomial_cutoff, radial_type=radial_type, - distance_transform=distance_transform, - ) - edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") - if pair_repulsion: - self.pair_repulsion_fn = ZBLBasis(r_max=r_max, p=num_polynomial_cutoff) - self.pair_repulsion = True - - sh_irreps = o3.Irreps.spherical_harmonics(max_ell) # This makes the irreps string - num_features = hidden_irreps.count(o3.Irrep(0, 1)) # Multiple copies of spherical harmonics for multiple interactions. They are 'combined' in a certain way during .simplify() - interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() #.sort() is a tuple, so we need the [0] element for the sorted result - self.spherical_harmonics = o3.SphericalHarmonics( - sh_irreps, normalize=True, normalization="component" # This makes the spherical harmonic class to be called with forward - ) - if radial_MLP is None: - radial_MLP = [64, 64, 64] - # Interactions and readout - self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) # For atom ground-state energy. 
It takes a one-hot encoding of atom types and returns the energy of each atom type - - inter = interaction_cls_first( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=node_feats_irreps, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=interaction_irreps, - hidden_irreps=hidden_irreps, - avg_num_neighbors=avg_num_neighbors, - radial_MLP=radial_MLP, + distance_transform=None, ) - self.interactions = torch.nn.ModuleList([inter]) - - # Use the appropriate self connection at the first layer for proper E0 - use_sc_first = False - if "Residual" in str(interaction_cls_first): - use_sc_first = True - - node_feats_irreps_out = inter.target_irreps - prod = EquivariantProductBasisBlock( - node_feats_irreps=node_feats_irreps_out, - target_irreps=hidden_irreps, - correlation=correlation[0], - num_elements=num_elements, - use_sc=use_sc_first, + self.node_embedding = LinearNodeEmbeddingBlock( + irreps_in=self.node_attr_irreps, irreps_out=create_irreps_string(self.hidden_dim, 0) # Changed this to hidden_dim because no longer had node_feats_irreps ) - self.products = torch.nn.ModuleList([prod]) - self.readouts = torch.nn.ModuleList() - self.readouts.append(LinearReadoutBlock(hidden_irreps)) - for i in range(num_interactions - 1): - if i == num_interactions - 2: - hidden_irreps_out = str( - hidden_irreps[0] - ) # Select only scalars for last layer - else: - hidden_irreps_out = hidden_irreps - inter = interaction_cls( - node_attrs_irreps=node_attr_irreps, + def _init_conv(self): + # Multihead Decoders + ## This integrates HYDRA multihead nature with MACE's layer-wise readouts + ## NOTE Norm techniques (feature_layers in HYDRA) are not advised for use in equivariant models as it can break equivariance + self.multihead_decoders = ModuleList() + hidden_irreps = o3.Irreps(create_irreps_string(self.hidden_dim, self.node_max_ell)) + final_hidden_irreps = o3.Irreps(create_irreps_string(self.hidden_dim, 0)) # Only scalars are outputted in the last layer + + last_layer = 1 == self.num_conv_layers + + self.multihead_decoders.append(MultiheadDecoderBlock(self.node_attr_irreps, self.node_max_ell, self.config_heads, self.head_dims, self.head_type, self.num_heads, self.activation_function, self.num_nodes, nonlinear=True)) # For base-node traits + self.graph_convs.append(self.get_conv(self.input_dim, self.hidden_dim, first_layer=True)) + self.multihead_decoders.append(MultiheadDecoderBlock(hidden_irreps, self.node_max_ell, self.config_heads, self.head_dims, self.head_type, self.num_heads, self.activation_function, self.num_nodes, nonlinear=last_layer)) + for i in range(self.num_conv_layers - 1): + last_layer = i == self.num_conv_layers - 2 + conv = self.get_conv(self.hidden_dim, self.hidden_dim, last_layer=last_layer) + self.graph_convs.append(conv) + self.multihead_decoders.append(MultiheadDecoderBlock(final_hidden_irreps, self.node_max_ell, self.config_heads, self.head_dims, self.head_type, self.num_heads, self.activation_function, self.num_nodes, nonlinear=last_layer)) # Last layer will be nonlinear node decoding + + def get_conv(self, input_dim, output_dim, first_layer=False, last_layer=False): + hidden_dim = output_dim if input_dim == 1 else input_dim + + # All of these should be constructed with HYDRA dimensional arguments + ## Radial + radial_MLP_dim = math.ceil(float(hidden_dim) / 3) # Go based off hidden_dim for radial_MLP + radial_MLP = [radial_MLP_dim, radial_MLP_dim, radial_MLP_dim] + ## Input, Hidden, and Output irreps sizing (this is usually just hidden in MACE) + 
### Input dimensions are handled implicitly + ### Hidden + hidden_irreps = create_irreps_string(hidden_dim, self.node_max_ell) + hidden_irreps = o3.Irreps(hidden_irreps) + node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) + num_features = hidden_irreps.count(o3.Irrep(0, 1)) # Multiple copies of spherical harmonics for multiple interactions. They are 'combined' in a certain way during .simplify() ## This makes it a requirement that hidden irreps all have the same number of channels + interaction_irreps = (self.sh_irreps * num_features).sort()[0].simplify() #.sort() is a tuple, so we need the [0] element for the sorted result + ### Output + output_irreps = create_irreps_string(output_dim, self.node_max_ell) + output_irreps = o3.Irreps(output_irreps) + + # Constructing convolutional layers + if first_layer: + hidden_irreps_out = hidden_irreps + inter = self.interaction_cls_first( + node_attrs_irreps=self.node_attr_irreps, + node_feats_irreps=node_feats_irreps, + edge_attrs_irreps=self.sh_irreps, + edge_feats_irreps=self.edge_feats_irreps, + target_irreps=interaction_irreps, # Replace with output? + hidden_irreps=hidden_irreps_out, + avg_num_neighbors=self.avg_num_neighbors, + radial_MLP=radial_MLP, + ) + # Use the appropriate self connection at the first layer for proper E0 + use_sc_first = False + if "Residual" in str(self.interaction_cls_first): + use_sc_first = True + prod = EquivariantProductBasisBlock( + node_feats_irreps=interaction_irreps, + target_irreps=hidden_irreps, + correlation=self.correlation[0], + num_elements=self.num_elements, + use_sc=use_sc_first, + ) + sizing = o3.Linear(hidden_irreps_out, output_irreps) # Change sizing to output_irreps + elif last_layer: + # Select only scalars output for last layer + hidden_irreps_out = str( + hidden_irreps[0] + ) + output_irreps = str( + output_irreps[0] + ) + inter = self.interaction_cls( + node_attrs_irreps=self.node_attr_irreps, node_feats_irreps=hidden_irreps, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, + edge_attrs_irreps=self.sh_irreps, + edge_feats_irreps=self.edge_feats_irreps, target_irreps=interaction_irreps, hidden_irreps=hidden_irreps_out, - avg_num_neighbors=avg_num_neighbors, + avg_num_neighbors=self.avg_num_neighbors, radial_MLP=radial_MLP, ) - self.interactions.append(inter) prod = EquivariantProductBasisBlock( node_feats_irreps=interaction_irreps, target_irreps=hidden_irreps_out, - correlation=correlation[i + 1], - num_elements=num_elements, + correlation=self.correlation[0], + num_elements=self.num_elements, use_sc=True, ) - self.products.append(prod) - if i == num_interactions - 2: - self.readouts.append( - NonLinearReadoutBlock(hidden_irreps_out, MLP_irreps, gate) - ) - else: - self.readouts.append(LinearReadoutBlock(hidden_irreps)) + sizing = o3.Linear(hidden_irreps_out, output_irreps) # Change sizing to output_irreps + else: + hidden_irreps_out = hidden_irreps + inter = self.interaction_cls( + node_attrs_irreps=self.node_attr_irreps, + node_feats_irreps=hidden_irreps, + edge_attrs_irreps=self.sh_irreps, + edge_feats_irreps=self.edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps_out, + avg_num_neighbors=self.avg_num_neighbors, + radial_MLP=radial_MLP, + ) + prod = EquivariantProductBasisBlock( + node_feats_irreps=interaction_irreps, + target_irreps=hidden_irreps_out, + correlation=self.correlation[0], # Should this be i+1? 
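+                # NOTE: middle layers currently reuse correlation[0]; upstream MACE indexes
+                # correlation per layer (correlation[i + 1]), which would require passing
+                # the layer index into get_conv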
+ num_elements=self.num_elements, + use_sc=True, + ) + sizing = o3.Linear(hidden_irreps_out, output_irreps) # Change sizing to output_irreps - def forward( - self, - data: Dict[str, torch.Tensor], - training: bool = False, - compute_force: bool = True, - compute_virials: bool = False, - compute_stress: bool = False, - compute_displacement: bool = False, - compute_hessian: bool = False, - ) -> Dict[str, Optional[torch.Tensor]]: - # Setup - data["node_attrs"].requires_grad_(True) - data["positions"].requires_grad_(True) - num_graphs = data["ptr"].numel() - 1 - displacement = torch.zeros( - (num_graphs, 3, 3), - dtype=data["positions"].dtype, - device=data["positions"].device, - ) - if compute_virials or compute_stress or compute_displacement: - ( - data["positions"], - data["shifts"], - displacement, - ) = get_symmetric_displacement( - positions=data["positions"], - unit_shifts=data["unit_shifts"], - cell=data["cell"], - edge_index=data["edge_index"], - num_graphs=num_graphs, - batch=data["batch"], + + input_args = "node_attributes, pos, node_features, edge_attributes, edge_features, edge_index" + # readout_args = "node_energies" + conv_args = "node_attributes, edge_attributes, edge_features, edge_index" # node_features is not used here because it's passed through in the forward + + if self.use_edge_attr: + input_args += ", edge_attr" + conv_args += ", edge_attr" + + if not last_layer: + return PyGSequential( + input_args, + [ + (inter, "node_features, " + conv_args + " -> node_features, sc"), + (prod, "node_features, sc, node_attributes -> node_features"), + (sizing, "node_features -> node_features"), + (lambda node_features, pos: [node_features, pos], "node_features, pos -> node_features, pos"), + ], + ) + else: + return PyGSequential( + input_args, + [ + (inter, "node_features, " + conv_args + " -> node_features, sc"), + (prod, "node_features, sc, node_attributes -> node_features"), + (sizing, "node_features -> node_features"), + (lambda node_features, pos: [node_features, pos], "node_features, pos -> node_features, pos"), + ], ) + + def forward(self, data): + data, conv_args = self._conv_args(data) + node_features = data.node_features + node_attributes = data.node_attributes + pos = data.pos + + ### encoder / decoder part #### + ## NOTE Norm techniques (feature_layers in HYDRA) are not advised for use in equivariant models as it can break equivariance + + ### There is a readout before the first convolution layer ### + outputs = [] + output = self.multihead_decoders[0](data, node_attributes) # [index][n_output, size_output] + # Create outputs first + outputs = output + + ### Do conv --> readout --> repeat for each convolution layer ### + for conv, readout in zip(self.graph_convs, self.multihead_decoders[1:]): + if not self.conv_checkpointing: + node_features, pos = conv(node_features=node_features, pos=pos, **conv_args) + output = readout(data, node_features) # [index][n_output, size_output] + else: + node_features, pos = checkpoint( + conv, use_reentrant=False, node_features=node_features, pos=pos, **conv_args + ) + output = readout(data, node_features) # output is a list of tensors with [index][n_output, size_output] + # Sum predictions for each index, taking care of size differences + for idx, prediction in enumerate(output): + outputs[idx] = outputs[idx] + prediction + + return outputs - # Atomic energies - node_e0 = self.atomic_energies_fn(data["node_attrs"]) - e0 = scatter_sum( - src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs - ) # [n_graphs,] + def _conv_args(self, 
data):
+        assert (
+            data.pos is not None
+        ), "MACE requires node positions (data.pos) to be set."
+
+        # Center positions at 0 per graph. This is a requirement for equivariant models that
+        # initialize the spherical harmonics, since the initial spherical harmonic projection
+        # uses the nodal position vector x/||x|| as the input to the spherical harmonics.
+        # If we didn't center at 0, these models wouldn't even be invariant to translation.
+        mean_pos = scatter(data.pos, data.batch, dim=0, reduce="mean")
+        data.pos = data.pos - mean_pos[data.batch]
+
+        # Create node_attrs from atomic numbers. Later on it may contain more information
+        ## Node attrs are intrinsic properties of the atoms, like charge, atomic number, etc.
+        ## data.node_attrs is already used as a method elsewhere, so MACE's node_attrs has been
+        ## renamed to data.node_attributes, and other data variable names are renamed likewise
+        one_hot = torch.nn.functional.one_hot(data["x"].long().squeeze(-1), num_classes=118).float()  # [n_atoms, 118] ## 118 elements in the periodic table
+        data.node_attributes = one_hot  # To-Do: Add more information to node_attrs
+        data.shifts = torch.zeros((data.edge_index.shape[1], 3), dtype=data.pos.dtype, device=data.pos.device)  # Shifts would account for periodic boundary conditions, but we believe data.pos is already generated with PBC taken into account
+
         # Embeddings
-        node_feats = self.node_embedding(data["node_attrs"])
+        node_feats = self.node_embedding(data["node_attributes"])
         vectors, lengths = get_edge_vectors_and_lengths(
-            positions=data["positions"],
+            positions=data["pos"],
             edge_index=data["edge_index"],
             shifts=data["shifts"],
         )
-        edge_attrs = self.spherical_harmonics(vectors)
-        edge_feats = self.radial_embedding(
-            lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers
-        )
-        if hasattr(self, "pair_repulsion"):
-            pair_node_energy = self.pair_repulsion_fn(
-                lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers
-            )
-            pair_energy = scatter_sum(
-                src=pair_node_energy, index=data["batch"], dim=-1, dim_size=num_graphs
-            )  # [n_graphs,]
-        else:
-            pair_node_energy = torch.zeros_like(node_e0)
-            pair_energy = torch.zeros_like(e0)
-
-        # Interactions
-        energies = [e0, pair_energy]
-        node_energies_list = [node_e0, pair_node_energy]
-        node_feats_list = []
-        for interaction, product, readout in zip(
-            self.interactions, self.products, self.readouts
-        ):
-            node_feats, sc = interaction(
-                node_attrs=data["node_attrs"],
-                node_feats=node_feats,
-                edge_attrs=edge_attrs,
-                edge_feats=edge_feats,
-                edge_index=data["edge_index"],
-            )
-            node_feats = product(
-                node_feats=node_feats,
-                sc=sc,
-                node_attrs=data["node_attrs"],
-            )
-            node_feats_list.append(node_feats)
-            node_energies = readout(node_feats).squeeze(-1)  # [n_nodes, ]
-            energy = scatter_sum(
-                src=node_energies, index=data["batch"], dim=-1, dim_size=num_graphs
-            )  # [n_graphs,]
-            energies.append(energy)
-            node_energies_list.append(node_energies)
-
-        # Concatenate node features
-        node_feats_out = torch.cat(node_feats_list, dim=-1)
-
-        # Sum over energy contributions
-        contributions = torch.stack(energies, dim=-1)
-        total_energy = torch.sum(contributions, dim=-1)  # [n_graphs, ]
-        node_energy_contributions = torch.stack(node_energies_list, dim=-1)
-        node_energy = torch.sum(node_energy_contributions, dim=-1)  # [n_nodes, ]
-
-        # Outputs
-        forces, virials, stress, hessian = get_outputs(
-            energy=total_energy,
-            positions=data["positions"],
-            displacement=displacement,
-            cell=data["cell"],
-            training=training,
-            compute_force=compute_force,
-
compute_virials=compute_virials, - compute_stress=compute_stress, - compute_hessian=compute_hessian, + edge_attributes = self.spherical_harmonics(vectors) + edge_features = self.radial_embedding( + lengths, data["node_attributes"], data["edge_index"], self.atomic_numbers ) - - return { - "energy": total_energy, - "node_energy": node_energy, - "contributions": contributions, - "forces": forces, - "virials": virials, - "stress": stress, - "displacement": displacement, - "hessian": hessian, - "node_feats": node_feats_out, + + # Variable names + data.node_features = node_feats + data.edge_attributes = edge_attributes + data.edge_features = edge_features + data.lengths = lengths + + conv_args = { + "node_attributes": data.node_attributes, + "edge_attributes": data.edge_attributes, + "edge_features": data.edge_features, + "edge_index": data.edge_index, } + return data, conv_args + + + def _multihead(self): + # NOTE Multihead is skipped as it's an integral part of MACE's architecture to have a decoder after every layer, + # and a convolutional layer in decoding is not supported. Therefore, this final step is not necessary for MACE. + # However, various parts of multihead are applied in the MultiheadLinearBlock and MultiheadNonLinearBlock classes. + pass + def __str__(self): + return "MACEStack" + + + +def create_irreps_string(n: int, ell: int): # Custom function to allow for use of HYDRA arguments in creating irreps + irreps = [f"{n}x{ell}{'e' if ell % 2 == 0 else 'o'}" for ell in range(ell + 1)] + return " + ".join(irreps) + - +@compile_mode("script") +class MultiheadDecoderBlock(torch.nn.Module): + def __init__(self, input_irreps, node_max_ell, config_heads, head_dims, head_type, num_heads, activation_function, num_nodes, nonlinear=False): + super(MultiheadDecoderBlock, self).__init__() + self.input_irreps = input_irreps + self.node_max_ell = node_max_ell if not nonlinear else 0 + self.config_heads = config_heads + self.head_dims = head_dims + self.head_type = head_type + self.num_heads = num_heads + self.activation_function = activation_function + self.num_nodes = num_nodes + + self.graph_shared = None + self.node_NN_type = None + self.heads = ModuleList() + + # Create shared dense layers for graph-level output if applicable + if "graph" in self.config_heads: + graph_input_irreps = o3.Irreps(f"{self.input_irreps.count(o3.Irrep(0, 1))}x0e") + dim_sharedlayers = self.config_heads["graph"]["dim_sharedlayers"] + sharedlayers_irreps = o3.Irreps(f"{dim_sharedlayers}x0e") + denselayers = [] + denselayers.append(o3.Linear(graph_input_irreps, sharedlayers_irreps)) + denselayers.append(nn.Activation(irreps_in=sharedlayers_irreps, acts=[self.activation_function])) + for _ in range(self.config_heads["graph"]["num_sharedlayers"] - 1): + denselayers.append(o3.Linear(sharedlayers_irreps, sharedlayers_irreps)) + denselayers.append(nn.Activation(irreps_in=sharedlayers_irreps, acts=[self.activation_function])) + self.graph_shared = Sequential(*denselayers) + + # Create layers for each head + for ihead in range(self.num_heads): + if self.head_type[ihead] == "graph": + num_layers_graph = self.config_heads["graph"]["num_headlayers"] + hidden_dim_graph = self.config_heads["graph"]["dim_headlayers"] + denselayers = [] + head_hidden_irreps = o3.Irreps(f"{hidden_dim_graph[0]}x0e") + denselayers.append(o3.Linear(sharedlayers_irreps, head_hidden_irreps)) + denselayers.append(nn.Activation(irreps_in=head_hidden_irreps, acts=[self.activation_function])) + for ilayer in range(num_layers_graph - 1): + input_irreps = 
o3.Irreps(f"{hidden_dim_graph[ilayer]}x0e") + output_irreps = o3.Irreps(f"{hidden_dim_graph[ilayer + 1]}x0e") + denselayers.append(o3.Linear(input_irreps, output_irreps)) + denselayers.append(nn.Activation(irreps_in=output_irreps, acts=[self.activation_function])) + input_irreps = o3.Irreps(f"{hidden_dim_graph[-1]}x0e") + output_irreps = o3.Irreps(f"{self.head_dims[ihead]}x0e") + denselayers.append(o3.Linear(input_irreps, output_irreps)) + self.heads.append(Sequential(*denselayers)) + elif self.head_type[ihead] == "node": + self.node_NN_type = self.config_heads["node"]["type"] + head = ModuleList() + if self.node_NN_type == "mlp" or self.node_NN_type == "mlp_per_node": + self.num_mlp = 1 if self.node_NN_type == "mlp" else self.num_nodes + assert ( + self.num_nodes is not None + ), "num_nodes must be a positive integer for MLP" + num_layers_node = self.config_heads["node"]["num_headlayers"] + hidden_dim_node = self.config_heads["node"]["dim_headlayers"] + head = MLPNode( + self.input_irreps, + self.node_max_ell, + self.config_heads, + num_layers_node, + hidden_dim_node, + self.head_dims[ihead], + self.num_mlp, + self.num_nodes, + self.config_heads["node"]["type"], + self.activation_function, + nonlinear=nonlinear + ) + self.heads.append(head) + else: + raise ValueError(f"Unknown head NN structure for node features: {self.node_NN_type}") + else: + raise ValueError(f"Unknown head type: {self.head_type[ihead]}; supported types are 'graph' or 'node'") + + def forward(self, data, node_features): + if data.batch is None: + graph_features = node_features[:,:self.hidden_dim].mean(dim=0, keepdim=True) # Need to take only the type-0 irreps for aggregation + else: + graph_features = global_mean_pool(node_features[:,:self.input_irreps.count(o3.Irrep(0, 1))], data.batch.to(node_features.device)) + outputs = [] + for (headloc, type_head) in zip(self.heads, self.head_type): + if type_head == "graph": + x_graph_head = self.graph_shared(graph_features) + outputs.append(headloc(x_graph_head)) + else: # Node-level output + if self.node_NN_type == "conv": + raise ValueError("Node-level convolutional layers are not supported in MACE") + else: + x_node = headloc(node_features, data.batch) + outputs.append(x_node) + return outputs @compile_mode("script") -class MACE(torch.nn.Module): +class MLPNode(torch.nn.Module): def __init__( self, - r_max: float, - num_bessel: int, - num_polynomial_cutoff: int, - max_ell: int, - interaction_cls: Type[InteractionBlock], - interaction_cls_first: Type[InteractionBlock], - num_interactions: int, - num_elements: int, - hidden_irreps: o3.Irreps, - MLP_irreps: o3.Irreps, - atomic_energies: np.ndarray, - avg_num_neighbors: float, - atomic_numbers: List[int], - correlation: Union[int, List[int]], - gate: Optional[Callable], - pair_repulsion: bool = False, - distance_transform: str = "None", - radial_MLP: Optional[List[int]] = None, - radial_type: Optional[str] = "bessel", + input_irreps, + node_max_ell, + config_heads, + num_layers, + hidden_dims, + output_dim, + num_mlp, + num_nodes, + node_type, + activation_function, + nonlinear=False ): super().__init__() - # Register buffers are made when parameters need to be saved and transferred with the model, but not trained. 
- self.register_buffer( - "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) + self.input_irreps = input_irreps + self.hidden_dims = hidden_dims + self.output_dim = output_dim + self.node_max_ell = node_max_ell if not nonlinear else 0 + self.config_heads = config_heads + self.num_layers = num_layers + self.node_type = node_type + self.num_mlp = num_mlp + self.num_nodes = num_nodes + self.activation_function = activation_function + + self.mlp = ModuleList() + + # Create dense layers for each MLP based on node_type ("mlp" or "mlp_per_node") + for _ in range(self.num_mlp): + denselayers = [] + + # Input and hidden irreps for each MLP layer + input_irreps = input_irreps + hidden_irreps = o3.Irreps(f"{hidden_dims[0]}x0e") # Hidden irreps + + denselayers.append(o3.Linear(input_irreps, hidden_irreps)) + denselayers.append(nn.Activation(irreps_in=hidden_irreps, acts=[self.activation_function])) + + # Add intermediate layers + for ilayer in range(self.num_layers - 1): + input_irreps = o3.Irreps(f"{hidden_dims[ilayer]}x0e") + hidden_irreps = o3.Irreps(f"{hidden_dims[ilayer + 1]}x0e") + denselayers.append(o3.Linear(input_irreps, hidden_irreps)) + denselayers.append(nn.Activation(irreps_in=hidden_irreps, acts=[self.activation_function])) + + # Last layer + hidden_irreps = o3.Irreps(f"{hidden_dims[-1]}x0e") + output_irreps = o3.Irreps(f"{self.output_dim}x0e") # Assuming head_dims has been passed for the final output + denselayers.append(o3.Linear(hidden_irreps, output_irreps)) + + # Append to MLP + self.mlp.append(Sequential(*denselayers)) + + def node_features_reshape(self, node_features, batch): + """Reshape node_features from [batch_size*num_nodes, num_features] to [batch_size, num_features, num_nodes]""" + num_features = node_features.shape[1] + batch_size = batch.max() + 1 + out = torch.zeros( + (batch_size, num_features, self.num_nodes), + dtype=node_features.dtype, + device=node_features.device, ) - self.register_buffer( - "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) - ) - self.register_buffer( - "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) - ) - if isinstance(correlation, int): - correlation = [correlation] * num_interactions - # Embedding - node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) - node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) - self.node_embedding = LinearNodeEmbeddingBlock( - irreps_in=node_attr_irreps, irreps_out=node_feats_irreps - ) - self.radial_embedding = RadialEmbeddingBlock( - r_max=r_max, - num_bessel=num_bessel, - num_polynomial_cutoff=num_polynomial_cutoff, - radial_type=radial_type, - distance_transform=distance_transform, - ) - edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") - if pair_repulsion: - self.pair_repulsion_fn = ZBLBasis(r_max=r_max, p=num_polynomial_cutoff) - self.pair_repulsion = True - - sh_irreps = o3.Irreps.spherical_harmonics(max_ell) # This makes the irreps string - num_features = hidden_irreps.count(o3.Irrep(0, 1)) # Multiple copies of spherical harmonics for multiple interactions. 
They are 'combined' in a certain way during .simplify() - interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() #.sort() is a tuple, so we need the [0] element for the sorted result - self.spherical_harmonics = o3.SphericalHarmonics( - sh_irreps, normalize=True, normalization="component" # This makes the spherical harmonic class to be called with forward - ) - if radial_MLP is None: - radial_MLP = [64, 64, 64] - # Interactions and readout - self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) # For atom ground-state energy. It takes a one-hot encoding of atom types and returns the energy of each atom type - - inter = interaction_cls_first( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=node_feats_irreps, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=interaction_irreps, - hidden_irreps=hidden_irreps, - avg_num_neighbors=avg_num_neighbors, - radial_MLP=radial_MLP, - ) - self.interactions = torch.nn.ModuleList([inter]) - - # Use the appropriate self connection at the first layer for proper E0 - use_sc_first = False - if "Residual" in str(interaction_cls_first): - use_sc_first = True - - node_feats_irreps_out = inter.target_irreps - prod = EquivariantProductBasisBlock( - node_feats_irreps=node_feats_irreps_out, - target_irreps=hidden_irreps, - correlation=correlation[0], - num_elements=num_elements, - use_sc=use_sc_first, - ) - self.products = torch.nn.ModuleList([prod]) - - self.readouts = torch.nn.ModuleList() - self.readouts.append(LinearReadoutBlock(hidden_irreps)) - - for i in range(num_interactions - 1): - if i == num_interactions - 2: - hidden_irreps_out = str( - hidden_irreps[0] - ) # Select only scalars for last layer - else: - hidden_irreps_out = hidden_irreps - inter = interaction_cls( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=hidden_irreps, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=interaction_irreps, - hidden_irreps=hidden_irreps_out, - avg_num_neighbors=avg_num_neighbors, - radial_MLP=radial_MLP, - ) - self.interactions.append(inter) - prod = EquivariantProductBasisBlock( - node_feats_irreps=interaction_irreps, - target_irreps=hidden_irreps_out, - correlation=correlation[i + 1], - num_elements=num_elements, - use_sc=True, - ) - self.products.append(prod) - if i == num_interactions - 2: - self.readouts.append( - NonLinearReadoutBlock(hidden_irreps_out, MLP_irreps, gate) - ) - else: - self.readouts.append(LinearReadoutBlock(hidden_irreps)) - - def forward( - self, - data: Dict[str, torch.Tensor], - training: bool = False, - compute_force: bool = True, - compute_virials: bool = False, - compute_stress: bool = False, - compute_displacement: bool = False, - compute_hessian: bool = False, - ) -> Dict[str, Optional[torch.Tensor]]: - # Setup - data["node_attrs"].requires_grad_(True) - data["positions"].requires_grad_(True) - num_graphs = data["ptr"].numel() - 1 - displacement = torch.zeros( - (num_graphs, 3, 3), - dtype=data["positions"].dtype, - device=data["positions"].device, - ) - if compute_virials or compute_stress or compute_displacement: - ( - data["positions"], - data["shifts"], - displacement, - ) = get_symmetric_displacement( - positions=data["positions"], - unit_shifts=data["unit_shifts"], - cell=data["cell"], - edge_index=data["edge_index"], - num_graphs=num_graphs, - batch=data["batch"], - ) - - # Atomic energies - node_e0 = self.atomic_energies_fn(data["node_attrs"]) - e0 = scatter_sum( - src=node_e0, index=data["batch"], dim=-1, 
dim_size=num_graphs
-        )  # [n_graphs,]
-        # Embeddings
-        node_feats = self.node_embedding(data["node_attrs"])
-        vectors, lengths = get_edge_vectors_and_lengths(
-            positions=data["positions"],
-            edge_index=data["edge_index"],
-            shifts=data["shifts"],
-        )
-        edge_attrs = self.spherical_harmonics(vectors)
-        edge_feats = self.radial_embedding(
-            lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers
-        )
-        if hasattr(self, "pair_repulsion"):
-            pair_node_energy = self.pair_repulsion_fn(
-                lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers
-            )
-            pair_energy = scatter_sum(
-                src=pair_node_energy, index=data["batch"], dim=-1, dim_size=num_graphs
-            )  # [n_graphs,]
+        for inode in range(self.num_nodes):
+            inode_index = [i for i in range(inode, batch.shape[0], self.num_nodes)]
+            out[:, :, inode] = node_features[inode_index, :]
+        return out
+
+    def forward(self, node_features: torch.Tensor, batch: torch.Tensor):
+        if self.node_type == "mlp":
+            outs = self.mlp[0](node_features)
         else:
-            pair_node_energy = torch.zeros_like(node_e0)
-            pair_energy = torch.zeros_like(e0)
-
-        # Interactions
-        energies = [e0, pair_energy]
-        node_energies_list = [node_e0, pair_node_energy]
-        node_feats_list = []
-        for interaction, product, readout in zip(
-            self.interactions, self.products, self.readouts
-        ):
-            node_feats, sc = interaction(
-                node_attrs=data["node_attrs"],
-                node_feats=node_feats,
-                edge_attrs=edge_attrs,
-                edge_feats=edge_feats,
-                edge_index=data["edge_index"],
-            )
-            node_feats = product(
-                node_feats=node_feats,
-                sc=sc,
-                node_attrs=data["node_attrs"],
+            outs = torch.zeros(
+                (node_features.shape[0], self.output_dim),  # output_dim defines the final per-node output dimension
+                dtype=node_features.dtype,
+                device=node_features.device,
             )
-            node_feats_list.append(node_feats)
-            node_energies = readout(node_feats).squeeze(-1)  # [n_nodes, ]
-            energy = scatter_sum(
-                src=node_energies, index=data["batch"], dim=-1, dim_size=num_graphs
-            )  # [n_graphs,]
-            energies.append(energy)
-            node_energies_list.append(node_energies)
-
-        # Concatenate node features
-        node_feats_out = torch.cat(node_feats_list, dim=-1)
-
-        # Sum over energy contributions
-        contributions = torch.stack(energies, dim=-1)
-        total_energy = torch.sum(contributions, dim=-1)  # [n_graphs, ]
-        node_energy_contributions = torch.stack(node_energies_list, dim=-1)
-        node_energy = torch.sum(node_energy_contributions, dim=-1)  # [n_nodes, ]
-
-        # Outputs
-        forces, virials, stress, hessian = get_outputs(
-            energy=total_energy,
-            positions=data["positions"],
-            displacement=displacement,
-            cell=data["cell"],
-            training=training,
-            compute_force=compute_force,
-            compute_virials=compute_virials,
-            compute_stress=compute_stress,
-            compute_hessian=compute_hessian,
-        )
-
-        return {
-            "energy": total_energy,
-            "node_energy": node_energy,
-            "contributions": contributions,
-            "forces": forces,
-            "virials": virials,
-            "stress": stress,
-            "displacement": displacement,
-            "hessian": hessian,
-            "node_feats": node_feats_out,
-        }
+            x_nodes = self.node_features_reshape(node_features, batch)
+            for inode in range(self.num_nodes):
+                inode_index = [i for i in range(inode, batch.shape[0], self.num_nodes)]
+                outs[inode_index, :] = self.mlp[inode](x_nodes[:, :, inode])
+        return outs
+
+    def __str__(self):
+        return "MLPNode"
diff --git a/hydragnn/models/create.py b/hydragnn/models/create.py
index c39f1414f..8f6346ce7 100644
--- a/hydragnn/models/create.py
+++ b/hydragnn/models/create.py
@@ -12,6 +12,7 @@ import os
 import torch
 from torch_geometric.data import Data
+from
typing import List, Union from hydragnn.models.GINStack import GINStack from hydragnn.models.PNAStack import PNAStack @@ -23,6 +24,7 @@ from hydragnn.models.SCFStack import SCFStack from hydragnn.models.DIMEStack import DIMEStack from hydragnn.models.EGCLStack import EGCLStack +from hydragnn.models.MACEStack import MACEStack from hydragnn.utils.distributed import get_device from hydragnn.utils.profiling_and_tracing.time_utils import Timer @@ -53,6 +55,7 @@ def create_model_config( config["Architecture"]["num_before_skip"], config["Architecture"]["num_after_skip"], config["Architecture"]["num_radial"], + config["Architecture"]["radial_type"], config["Architecture"]["basis_emb_size"], config["Architecture"]["int_emb_size"], config["Architecture"]["out_emb_size"], @@ -62,6 +65,10 @@ def create_model_config( config["Architecture"]["num_filters"], config["Architecture"]["radius"], config["Architecture"]["equivariance"], + config["Architecture"]["correlation"], + config["Architecture"]["max_ell"], + config["Architecture"]["node_max_ell"], + config["Architecture"]["avg_num_neighbors"], config["Training"]["conv_checkpointing"], verbosity, use_gpu, @@ -89,6 +96,7 @@ def create_model( num_before_skip: int = None, num_after_skip: int = None, num_radial: int = None, + radial_type: str = None, basis_emb_size: int = None, int_emb_size: int = None, out_emb_size: int = None, @@ -98,6 +106,10 @@ def create_model( num_filters: int = None, radius: float = None, equivariance: bool = False, + correlation: Union[int, List[int]] = None, + max_ell: int = None, + node_max_ell: int = None, + avg_num_neighbors: int = None, conv_checkopinting: bool = False, verbosity: int = 0, use_gpu: bool = True, @@ -328,6 +340,36 @@ def create_model( num_conv_layers=num_conv_layers, num_nodes=num_nodes, ) + elif model_type == "MACE": + assert radius is not None, "MACE requires radius input." + assert num_radial is not None, "MACE requires num_radial input." + assert max_ell is not None, "MACE requires max_ell input." + assert node_max_ell is not None, "MACE requires node_max_ell input." + assert max_ell >= 1, "MACE requires max_ell >= 1." + assert node_max_ell >= 1, "MACE requires node_max_ell >= 1." 
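+        # Positional mapping into MACEStack: radius -> r_max, num_radial -> num_bessel,
+        # and envelope_exponent -> num_polynomial_cutoff (MACE's smooth radial cutoff),
+        # so the existing DimeNet-style config keys drive MACE without new options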
+ model = MACEStack( + radius, + num_radial, + max_ell, + node_max_ell, + avg_num_neighbors, + envelope_exponent, + correlation, + radial_type, + input_dim, + hidden_dim, + output_dim, + output_type, + output_heads, + activation_function, + loss_function_type, + equivariance, + loss_weights=task_weights, + freeze_conv=freeze_conv, + initial_bias=initial_bias, + num_conv_layers=num_conv_layers, + num_nodes=num_nodes, + ) else: raise ValueError("Unknown model_type: {0}".format(model_type)) diff --git a/hydragnn/utils/input_config_parsing/config_utils.py b/hydragnn/utils/input_config_parsing/config_utils.py index 86525c766..ce19024a5 100644 --- a/hydragnn/utils/input_config_parsing/config_utils.py +++ b/hydragnn/utils/input_config_parsing/config_utils.py @@ -14,6 +14,7 @@ check_if_graph_size_variable, gather_deg, ) +from hydragnn.utils.model import calculate_avg_deg from hydragnn.utils.distributed import get_comm_size_and_rank from copy import deepcopy import json @@ -54,6 +55,16 @@ def update_config(config, train_loader, val_loader, test_loader): config["NeuralNetwork"]["Architecture"]["max_neighbours"] = len(deg) - 1 else: config["NeuralNetwork"]["Architecture"]["pna_deg"] = None + + if config["NeuralNetwork"]["Architecture"]["model_type"] == "MACE": + if hasattr(train_loader.dataset, "avg_num_neighbors"): + ## Use avg neighbours used in the dataset. + avg_num_neighbors = torch.tensor(train_loader.dataset.avg_num_neighbors) + else: + avg_num_neighbors = float(calculate_avg_deg(train_loader.dataset)) + config["NeuralNetwork"]["Architecture"]["avg_num_neighbors"] = avg_num_neighbors + else: + config["NeuralNetwork"]["Architecture"]["avg_num_neighbors"] = None if "radius" not in config["NeuralNetwork"]["Architecture"]: config["NeuralNetwork"]["Architecture"]["radius"] = None @@ -77,6 +88,14 @@ def update_config(config, train_loader, val_loader, test_loader): config["NeuralNetwork"]["Architecture"]["num_radial"] = None if "num_spherical" not in config["NeuralNetwork"]["Architecture"]: config["NeuralNetwork"]["Architecture"]["num_spherical"] = None + if "radial_type" not in config["NeuralNetwork"]["Architecture"]: + config["NeuralNetwork"]["Architecture"]["radial_type"] = None + if "correlation" not in config["NeuralNetwork"]["Architecture"]: + config["NeuralNetwork"]["Architecture"]["correlation"] = None + if "max_ell" not in config["NeuralNetwork"]["Architecture"]: + config["NeuralNetwork"]["Architecture"]["max_ell"] = None + if "node_max_ell" not in config["NeuralNetwork"]["Architecture"]: + config["NeuralNetwork"]["Architecture"]["node_max_ell"] = None config["NeuralNetwork"]["Architecture"] = update_config_edge_dim( config["NeuralNetwork"]["Architecture"] @@ -112,11 +131,11 @@ def update_config(config, train_loader, val_loader, test_loader): def update_config_equivariance(config): - equivariant_models = ["EGNN", "SchNet"] + equivariant_models = ["EGNN", "SchNet", "MACE"] if "equivariance" in config and config["equivariance"]: assert ( config["model_type"] in equivariant_models - ), "E(3) equivariance can only be ensured for EGNN and SchNet." + ), "E(3) equivariance can only be ensured for EGNN, SchNet, and MACE." 
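+    # MACE is E(3)-equivariant by construction: its node features are e3nn irreps built
+    # from hidden_dim and node_max_ell via create_irreps_string in MACEStack.py,
+    # e.g. create_irreps_string(8, 2) -> "8x0e + 8x1o + 8x2e"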
elif "equivariance" not in config: config["equivariance"] = False return config @@ -124,11 +143,11 @@ def update_config_equivariance(config): def update_config_edge_dim(config): config["edge_dim"] = None - edge_models = ["PNAPlus", "PNA", "CGCNN", "SchNet", "EGNN"] + edge_models = ["PNAPlus", "PNA", "CGCNN", "SchNet", "EGNN", "MACE"] if "edge_features" in config and config["edge_features"]: assert ( config["model_type"] in edge_models - ), "Edge features can only be used with EGNN, SchNet, PNA, PNAPlus, and CGCNN." + ), "Edge features can only be used with MACE, EGNN, SchNet, PNA, PNAPlus, and CGCNN." config["edge_dim"] = len(config["edge_features"]) elif config["model_type"] == "CGCNN": # CG always needs an integer edge_dim diff --git a/hydragnn/utils/mace_utils/data/atomic_data.py b/hydragnn/utils/mace_utils/data/atomic_data.py index edb91b14c..01815fdc5 100644 --- a/hydragnn/utils/mace_utils/data/atomic_data.py +++ b/hydragnn/utils/mace_utils/data/atomic_data.py @@ -8,7 +8,7 @@ import torch.utils.data -from mace.tools import ( +from hydragnn.utils.mace_utils.tools import ( AtomicNumberTable, atomic_numbers_to_indices, to_one_hot, diff --git a/hydragnn/utils/mace_utils/data/hdf5_dataset.py b/hydragnn/utils/mace_utils/data/hdf5_dataset.py index 5057fd7f1..12d0ff295 100644 --- a/hydragnn/utils/mace_utils/data/hdf5_dataset.py +++ b/hydragnn/utils/mace_utils/data/hdf5_dataset.py @@ -4,9 +4,9 @@ import h5py from torch.utils.data import ConcatDataset, Dataset -from mace.data.atomic_data import AtomicData -from mace.data.utils import Configuration -from mace.tools.utils import AtomicNumberTable +from hydragnn.utils.mace_utils.data.atomic_data import AtomicData +from hydragnn.utils.mace_utils.data.utils import Configuration +from hydragnn.utils.mace_utils.tools.utils import AtomicNumberTable class HDF5Dataset(Dataset): diff --git a/hydragnn/utils/mace_utils/data/utils.py b/hydragnn/utils/mace_utils/data/utils.py index 78e3e76fd..4b008079b 100644 --- a/hydragnn/utils/mace_utils/data/utils.py +++ b/hydragnn/utils/mace_utils/data/utils.py @@ -13,7 +13,7 @@ import h5py import numpy as np -from mace.tools import AtomicNumberTable +from hydragnn.utils.mace_utils.tools import AtomicNumberTable Vector = np.ndarray # [3,] Positions = np.ndarray # [..., 3] diff --git a/hydragnn/utils/mace_utils/modules/__init__.py b/hydragnn/utils/mace_utils/modules/__init__.py index 9278130fd..b767383f3 100644 --- a/hydragnn/utils/mace_utils/modules/__init__.py +++ b/hydragnn/utils/mace_utils/modules/__init__.py @@ -3,8 +3,8 @@ import torch from .blocks import ( - AgnosticNonlinearInteractionBlock, - AgnosticResidualNonlinearInteractionBlock, + # AgnosticNonlinearInteractionBlock, + # AgnosticResidualNonlinearInteractionBlock, AtomicEnergiesBlock, EquivariantProductBasisBlock, InteractionBlock, @@ -15,9 +15,9 @@ NonLinearReadoutBlock, RadialEmbeddingBlock, RealAgnosticAttResidualInteractionBlock, - RealAgnosticInteractionBlock, - RealAgnosticResidualInteractionBlock, - ResidualElementDependentInteractionBlock, + # RealAgnosticInteractionBlock, + # RealAgnosticResidualInteractionBlock, + # ResidualElementDependentInteractionBlock, ScaleShiftBlock, ) from .loss import ( @@ -50,12 +50,12 @@ ) interaction_classes: Dict[str, Type[InteractionBlock]] = { - "AgnosticNonlinearInteractionBlock": AgnosticNonlinearInteractionBlock, - "ResidualElementDependentInteractionBlock": ResidualElementDependentInteractionBlock, - "AgnosticResidualNonlinearInteractionBlock": AgnosticResidualNonlinearInteractionBlock, - 
"RealAgnosticResidualInteractionBlock": RealAgnosticResidualInteractionBlock, + # "AgnosticNonlinearInteractionBlock": AgnosticNonlinearInteractionBlock, + # "ResidualElementDependentInteractionBlock": ResidualElementDependentInteractionBlock, + # "AgnosticResidualNonlinearInteractionBlock": AgnosticResidualNonlinearInteractionBlock, + # "RealAgnosticResidualInteractionBlock": RealAgnosticResidualInteractionBlock, "RealAgnosticAttResidualInteractionBlock": RealAgnosticAttResidualInteractionBlock, - "RealAgnosticInteractionBlock": RealAgnosticInteractionBlock, + # "RealAgnosticInteractionBlock": RealAgnosticInteractionBlock, } scaling_classes: Dict[str, Callable] = { diff --git a/hydragnn/utils/mace_utils/modules/blocks.py b/hydragnn/utils/mace_utils/modules/blocks.py index e8645a8e7..698a38678 100644 --- a/hydragnn/utils/mace_utils/modules/blocks.py +++ b/hydragnn/utils/mace_utils/modules/blocks.py @@ -12,8 +12,8 @@ from e3nn import nn, o3 from e3nn.util.jit import compile_mode -from mace.tools.compile import simplify_if_compile -from mace.tools.scatter import scatter_sum +from hydragnn.utils.mace_utils.tools.compile import simplify_if_compile +from hydragnn.utils.mace_utils.tools.scatter import scatter_sum from .irreps_tools import ( linear_out_irreps, @@ -63,7 +63,7 @@ def __init__( super().__init__() self.hidden_irreps = MLP_irreps self.linear_1 = o3.Linear(irreps_in=irreps_in, irreps_out=self.hidden_irreps) - self.non_linearity = nn.Activation(irreps_in=self.hidden_irreps, acts=[gate]) + self.non_linearity = nn.Activation(irreps_in=self.hidden_irreps, acts=[gate]) # Need to adjust this to actually use the gate self.linear_2 = o3.Linear( irreps_in=self.hidden_irreps, irreps_out=o3.Irreps("0e") ) @@ -148,7 +148,17 @@ def forward( def __repr__(self): formatted_energies = ", ".join([f"{x:.4f}" for x in self.atomic_energies]) return f"{self.__class__.__name__}(energies=[{formatted_energies}])" + +@compile_mode("script") +class AtomicBlock(torch.nn.Module): + def __init__(self, output_dim): + super().__init__() + # Initialize the atomic energies as a trainable parameter + self.atomic_energies = torch.nn.Parameter(torch.randn(118, output_dim)) # There are 118 known elements + def forward(self, atomic_numbers): + # Perform the linear multiplication (no bias) + return atomic_numbers @ self.atomic_energies # Output will now have shape [batch_size, output_dim] @compile_mode("script") class RadialEmbeddingBlock(torch.nn.Module): @@ -303,352 +313,10 @@ def __repr__(self): ) -@compile_mode("script") -class ResidualElementDependentInteractionBlock(InteractionBlock): - def _setup(self) -> None: - self.linear_up = o3.Linear( - self.node_feats_irreps, - self.node_feats_irreps, - internal_weights=True, - shared_weights=True, - ) - # TensorProduct - irreps_mid, instructions = tp_out_irreps_with_instructions( - self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps - ) - self.conv_tp = o3.TensorProduct( - self.node_feats_irreps, - self.edge_attrs_irreps, - irreps_mid, - instructions=instructions, - shared_weights=False, - internal_weights=False, - ) - self.conv_tp_weights = TensorProductWeightsBlock( - num_elements=self.node_attrs_irreps.num_irreps, - num_edge_feats=self.edge_feats_irreps.num_irreps, - num_feats_out=self.conv_tp.weight_numel, - ) - - # Linear - irreps_mid = irreps_mid.simplify() - self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps) - self.irreps_out = self.irreps_out.simplify() - self.linear = o3.Linear( - irreps_mid, self.irreps_out, internal_weights=True, 
shared_weights=True - ) - - # Selector TensorProduct - self.skip_tp = o3.FullyConnectedTensorProduct( - self.node_feats_irreps, self.node_attrs_irreps, self.irreps_out - ) - - def forward( - self, - node_attrs: torch.Tensor, - node_feats: torch.Tensor, - edge_attrs: torch.Tensor, - edge_feats: torch.Tensor, - edge_index: torch.Tensor, - ) -> torch.Tensor: - sender = edge_index[0] - receiver = edge_index[1] - num_nodes = node_feats.shape[0] - sc = self.skip_tp(node_feats, node_attrs) - node_feats = self.linear_up(node_feats) - tp_weights = self.conv_tp_weights(node_attrs[sender], edge_feats) - mji = self.conv_tp( - node_feats[sender], edge_attrs, tp_weights - ) # [n_edges, irreps] - message = scatter_sum( - src=mji, index=receiver, dim=0, dim_size=num_nodes - ) # [n_nodes, irreps] - message = self.linear(message) / self.avg_num_neighbors - return message + sc # [n_nodes, irreps] - - -@compile_mode("script") -class AgnosticNonlinearInteractionBlock(InteractionBlock): - def _setup(self) -> None: - self.linear_up = o3.Linear( - self.node_feats_irreps, - self.node_feats_irreps, - internal_weights=True, - shared_weights=True, - ) - # TensorProduct - irreps_mid, instructions = tp_out_irreps_with_instructions( - self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps - ) - self.conv_tp = o3.TensorProduct( - self.node_feats_irreps, - self.edge_attrs_irreps, - irreps_mid, - instructions=instructions, - shared_weights=False, - internal_weights=False, - ) - - # Convolution weights - input_dim = self.edge_feats_irreps.num_irreps - self.conv_tp_weights = nn.FullyConnectedNet( - [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], - torch.nn.functional.silu, - ) - - # Linear - irreps_mid = irreps_mid.simplify() - self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps) - self.irreps_out = self.irreps_out.simplify() - self.linear = o3.Linear( - irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True - ) - - # Selector TensorProduct - self.skip_tp = o3.FullyConnectedTensorProduct( - self.irreps_out, self.node_attrs_irreps, self.irreps_out - ) - - def forward( - self, - node_attrs: torch.Tensor, - node_feats: torch.Tensor, - edge_attrs: torch.Tensor, - edge_feats: torch.Tensor, - edge_index: torch.Tensor, - ) -> torch.Tensor: - sender = edge_index[0] - receiver = edge_index[1] - num_nodes = node_feats.shape[0] - tp_weights = self.conv_tp_weights(edge_feats) - node_feats = self.linear_up(node_feats) - mji = self.conv_tp( - node_feats[sender], edge_attrs, tp_weights - ) # [n_edges, irreps] - message = scatter_sum( - src=mji, index=receiver, dim=0, dim_size=num_nodes - ) # [n_nodes, irreps] - message = self.linear(message) / self.avg_num_neighbors - message = self.skip_tp(message, node_attrs) - return message # [n_nodes, irreps] - - -@compile_mode("script") -class AgnosticResidualNonlinearInteractionBlock(InteractionBlock): - def _setup(self) -> None: - # First linear - self.linear_up = o3.Linear( - self.node_feats_irreps, - self.node_feats_irreps, - internal_weights=True, - shared_weights=True, - ) - # TensorProduct - irreps_mid, instructions = tp_out_irreps_with_instructions( - self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps - ) - self.conv_tp = o3.TensorProduct( - self.node_feats_irreps, - self.edge_attrs_irreps, - irreps_mid, - instructions=instructions, - shared_weights=False, - internal_weights=False, - ) - - # Convolution weights - input_dim = self.edge_feats_irreps.num_irreps - self.conv_tp_weights = nn.FullyConnectedNet( - [input_dim] + 
self.radial_MLP + [self.conv_tp.weight_numel], - torch.nn.functional.silu, - ) - - # Linear - irreps_mid = irreps_mid.simplify() - self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps) - self.irreps_out = self.irreps_out.simplify() - self.linear = o3.Linear( - irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True - ) - - # Selector TensorProduct - self.skip_tp = o3.FullyConnectedTensorProduct( - self.node_feats_irreps, self.node_attrs_irreps, self.irreps_out - ) - - def forward( - self, - node_attrs: torch.Tensor, - node_feats: torch.Tensor, - edge_attrs: torch.Tensor, - edge_feats: torch.Tensor, - edge_index: torch.Tensor, - ) -> torch.Tensor: - sender = edge_index[0] - receiver = edge_index[1] - num_nodes = node_feats.shape[0] - sc = self.skip_tp(node_feats, node_attrs) - node_feats = self.linear_up(node_feats) - tp_weights = self.conv_tp_weights(edge_feats) - mji = self.conv_tp( - node_feats[sender], edge_attrs, tp_weights - ) # [n_edges, irreps] - message = scatter_sum( - src=mji, index=receiver, dim=0, dim_size=num_nodes - ) # [n_nodes, irreps] - message = self.linear(message) / self.avg_num_neighbors - message = message + sc - return message # [n_nodes, irreps] - - -@compile_mode("script") -class RealAgnosticInteractionBlock(InteractionBlock): - def _setup(self) -> None: - # First linear - self.linear_up = o3.Linear( - self.node_feats_irreps, - self.node_feats_irreps, - internal_weights=True, - shared_weights=True, - ) - # TensorProduct - irreps_mid, instructions = tp_out_irreps_with_instructions( - self.node_feats_irreps, - self.edge_attrs_irreps, - self.target_irreps, - ) - self.conv_tp = o3.TensorProduct( - self.node_feats_irreps, - self.edge_attrs_irreps, - irreps_mid, - instructions=instructions, - shared_weights=False, - internal_weights=False, - ) - - # Convolution weights - input_dim = self.edge_feats_irreps.num_irreps - self.conv_tp_weights = nn.FullyConnectedNet( - [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], - torch.nn.functional.silu, - ) - - # Linear - irreps_mid = irreps_mid.simplify() - self.irreps_out = self.target_irreps - self.linear = o3.Linear( - irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True - ) - - # Selector TensorProduct - self.skip_tp = o3.FullyConnectedTensorProduct( - self.irreps_out, self.node_attrs_irreps, self.irreps_out - ) - self.reshape = reshape_irreps(self.irreps_out) - - def forward( - self, - node_attrs: torch.Tensor, - node_feats: torch.Tensor, - edge_attrs: torch.Tensor, - edge_feats: torch.Tensor, - edge_index: torch.Tensor, - ) -> Tuple[torch.Tensor, None]: - sender = edge_index[0] - receiver = edge_index[1] - num_nodes = node_feats.shape[0] - node_feats = self.linear_up(node_feats) - tp_weights = self.conv_tp_weights(edge_feats) - mji = self.conv_tp( - node_feats[sender], edge_attrs, tp_weights - ) # [n_edges, irreps] - message = scatter_sum( - src=mji, index=receiver, dim=0, dim_size=num_nodes - ) # [n_nodes, irreps] - message = self.linear(message) / self.avg_num_neighbors - message = self.skip_tp(message, node_attrs) - return ( - self.reshape(message), - None, - ) # [n_nodes, channels, (lmax + 1)**2] - - -@compile_mode("script") -class RealAgnosticResidualInteractionBlock(InteractionBlock): - def _setup(self) -> None: - # First linear - self.linear_up = o3.Linear( - self.node_feats_irreps, - self.node_feats_irreps, - internal_weights=True, - shared_weights=True, - ) - # TensorProduct - irreps_mid, instructions = tp_out_irreps_with_instructions( - 
self.node_feats_irreps, - self.edge_attrs_irreps, - self.target_irreps, - ) - self.conv_tp = o3.TensorProduct( - self.node_feats_irreps, - self.edge_attrs_irreps, - irreps_mid, - instructions=instructions, - shared_weights=False, - internal_weights=False, - ) - - # Convolution weights - input_dim = self.edge_feats_irreps.num_irreps - self.conv_tp_weights = nn.FullyConnectedNet( - [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], - torch.nn.functional.silu, - ) - - # Linear - irreps_mid = irreps_mid.simplify() - self.irreps_out = self.target_irreps - self.linear = o3.Linear( - irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True - ) - - # Selector TensorProduct - self.skip_tp = o3.FullyConnectedTensorProduct( - self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps - ) - self.reshape = reshape_irreps(self.irreps_out) - - def forward( - self, - node_attrs: torch.Tensor, - node_feats: torch.Tensor, - edge_attrs: torch.Tensor, - edge_feats: torch.Tensor, - edge_index: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - sender = edge_index[0] - receiver = edge_index[1] - num_nodes = node_feats.shape[0] - sc = self.skip_tp(node_feats, node_attrs) - node_feats = self.linear_up(node_feats) - tp_weights = self.conv_tp_weights(edge_feats) - mji = self.conv_tp( - node_feats[sender], edge_attrs, tp_weights - ) # [n_edges, irreps] - message = scatter_sum( - src=mji, index=receiver, dim=0, dim_size=num_nodes - ) # [n_nodes, irreps] - message = self.linear(message) / self.avg_num_neighbors - return ( - self.reshape(message), - sc, - ) # [n_nodes, channels, (lmax + 1)**2] - - @compile_mode("script") class RealAgnosticAttResidualInteractionBlock(InteractionBlock): def _setup(self) -> None: - self.node_feats_down_irreps = o3.Irreps("64x0e") + self.node_feats_down_irreps = o3.Irreps("64x0e") # Interesting, this seems to be a required shaping # First linear self.linear_up = o3.Linear( self.node_feats_irreps, @@ -679,16 +347,19 @@ def _setup(self) -> None: shared_weights=True, ) input_dim = ( - self.edge_feats_irreps.num_irreps + self.edge_feats_irreps.num_irreps # The irreps here should be scalars because they are fed through an activation in self.conv_tp_weights + 2 * self.node_feats_down_irreps.num_irreps ) + # The following specifies the network architecture for embedding l=0 (scalar) irreps + ## It is worth double-checking, but I believe this means that type 0 (scalar) irreps + ## are being embedded by 3 layers of size 256 and the output dim, then activated. self.conv_tp_weights = nn.FullyConnectedNet( - [input_dim] + 3 * [256] + [self.conv_tp.weight_numel], + [input_dim] + 3 * [256] + [self.conv_tp.weight_numel], torch.nn.functional.silu, ) # Linear - irreps_mid = irreps_mid.simplify() + irreps_mid = irreps_mid.simplify() # .simplify() essentially combines irreps of the same type so that normalization is done across them all. The site has an in-depth explanation self.irreps_out = self.target_irreps self.linear = o3.Linear( irreps_mid, @@ -700,12 +371,12 @@ def _setup(self) -> None: self.reshape = reshape_irreps(self.irreps_out) # Skip connection. 
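+        # sc is the residual path: raw node features are mapped straight to hidden_irreps,
+        # bypassing message passing, and added back inside the product basis block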
- self.skip_linear = o3.Linear(self.node_feats_irreps, self.hidden_irreps) + self.skip_linear = o3.Linear(self.node_feats_irreps, self.hidden_irreps) # This will be size (num_nodes, 64*9) when there are 64 channels and irreps 0, 1, 2 (1+3+5=9) ## This becomes sc def forward( self, - node_attrs: torch.Tensor, node_feats: torch.Tensor, + node_attrs: torch.Tensor, edge_attrs: torch.Tensor, edge_feats: torch.Tensor, edge_index: torch.Tensor, diff --git a/hydragnn/utils/mace_utils/modules/loss.py b/hydragnn/utils/mace_utils/modules/loss.py index b3421ef59..5c754defc 100644 --- a/hydragnn/utils/mace_utils/modules/loss.py +++ b/hydragnn/utils/mace_utils/modules/loss.py @@ -6,8 +6,8 @@ import torch -from mace.tools import TensorDict -from mace.tools.torch_geometric import Batch +from hydragnn.utils.mace_utils.tools import TensorDict +from hydragnn.utils.mace_utils.tools.torch_geometric import Batch def mean_squared_error_energy(ref: Batch, pred: TensorDict) -> torch.Tensor: diff --git a/hydragnn/utils/mace_utils/modules/models.py b/hydragnn/utils/mace_utils/modules/models.py index 3e5cb6626..e0fa51ee2 100644 --- a/hydragnn/utils/mace_utils/modules/models.py +++ b/hydragnn/utils/mace_utils/modules/models.py @@ -11,9 +11,9 @@ from e3nn import o3 from e3nn.util.jit import compile_mode -from mace.data import AtomicData -from mace.modules.radial import ZBLBasis -from mace.tools.scatter import scatter_sum +from hydragnn.utils.mace_utils.data import AtomicData +from hydragnn.utils.mace_utils.modules.radial import ZBLBasis +from hydragnn.utils.mace_utils.tools.scatter import scatter_sum from .blocks import ( AtomicEnergiesBlock, diff --git a/hydragnn/utils/mace_utils/modules/radial.py b/hydragnn/utils/mace_utils/modules/radial.py index a928c1847..94c0b8064 100644 --- a/hydragnn/utils/mace_utils/modules/radial.py +++ b/hydragnn/utils/mace_utils/modules/radial.py @@ -9,8 +9,8 @@ import torch from e3nn.util.jit import compile_mode -from mace.tools.compile import simplify_if_compile -from mace.tools.scatter import scatter_sum +from hydragnn.utils.mace_utils.tools.compile import simplify_if_compile +from hydragnn.utils.mace_utils.tools.scatter import scatter_sum @compile_mode("script") diff --git a/hydragnn/utils/mace_utils/modules/symmetric_contraction.py b/hydragnn/utils/mace_utils/modules/symmetric_contraction.py index 9db75da02..5c807c717 100644 --- a/hydragnn/utils/mace_utils/modules/symmetric_contraction.py +++ b/hydragnn/utils/mace_utils/modules/symmetric_contraction.py @@ -14,7 +14,7 @@ from e3nn.util.codegen import CodeGenMixin from e3nn.util.jit import compile_mode -from mace.tools.cg import U_matrix_real +from hydragnn.utils.mace_utils.tools.cg import U_matrix_real BATCH_EXAMPLE = 10 ALPHABET = ["w", "x", "v", "n", "z", "r", "t", "y", "u", "o", "p", "s"] diff --git a/hydragnn/utils/mace_utils/modules/utils.py b/hydragnn/utils/mace_utils/modules/utils.py index 37fef1bbd..c6a44fff6 100644 --- a/hydragnn/utils/mace_utils/modules/utils.py +++ b/hydragnn/utils/mace_utils/modules/utils.py @@ -13,9 +13,9 @@ import torch.utils.data from scipy.constants import c, e -from mace.tools import to_numpy -from mace.tools.scatter import scatter_sum -from mace.tools.torch_geometric.batch import Batch +from hydragnn.utils.mace_utils.tools import to_numpy +from hydragnn.utils.mace_utils.tools.scatter import scatter_sum +from hydragnn.utils.mace_utils.tools.torch_geometric.batch import Batch from .blocks import AtomicEnergiesBlock diff --git a/hydragnn/utils/mace_utils/tools/finetuning_utils.py
b/hydragnn/utils/mace_utils/tools/finetuning_utils.py index 0aad091ba..9443add6f 100644 --- a/hydragnn/utils/mace_utils/tools/finetuning_utils.py +++ b/hydragnn/utils/mace_utils/tools/finetuning_utils.py @@ -1,6 +1,6 @@ import torch -from mace.tools.utils import AtomicNumberTable +from hydragnn.utils.mace_utils.tools.utils import AtomicNumberTable def load_foundations( diff --git a/hydragnn/utils/model/model.py b/hydragnn/utils/model/model.py index 6b6d3eb56..768ac1ff2 100644 --- a/hydragnn/utils/model/model.py +++ b/hydragnn/utils/model/model.py @@ -124,7 +124,7 @@ def load_existing_model( ## This function may cause OOM if datasets is too large ## to fit in a single GPU (i.e., with DDP). Use with caution. -## Recommend to use calculate_PNA_degree_dist +## Recommend to use calculate_PNA_degree_dist or calculate_avg_deg_dist def calculate_PNA_degree(loader, max_neighbours): backend = os.getenv("HYDRAGNN_AGGR_BACKEND", "torch") if backend == "torch": @@ -137,6 +137,21 @@ def calculate_PNA_degree(loader, max_neighbours): d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long) deg += torch.bincount(d, minlength=deg.numel())[: max_neighbours + 1] return deg + +def calculate_avg_deg(loader): + backend = os.getenv("HYDRAGNN_AGGR_BACKEND", "torch") + if backend == "torch": + return calculate_avg_deg_dist(loader) + elif backend == "mpi": + return calculate_avg_deg_mpi(loader) + else: + deg = 0 + counter = 0 + for data in iterate_tqdm(loader, 2, desc="Calculate avg degree"): + d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long) + deg += d.sum() + counter += d.size(0) + return deg / counter def calculate_PNA_degree_dist(loader, max_neighbours): @@ -150,6 +165,22 @@ def calculate_PNA_degree_dist(loader, max_neighbours): deg = deg.detach().cpu() return deg +def calculate_avg_deg_dist(loader): + assert dist.is_initialized() + deg = 0 + counter = 0 + for data in iterate_tqdm(loader, 2, desc="Calculate avg degree"): + d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long) + deg += d.sum() + counter += d.size(0) + deg = torch.tensor(deg) + counter = torch.tensor(counter) + dist.all_reduce(deg, op=dist.ReduceOp.SUM) + dist.all_reduce(counter, op=dist.ReduceOp.SUM) + deg = deg.detach().cpu() + counter = counter.detach().cpu() + return deg / counter + def calculate_PNA_degree_mpi(loader, max_neighbours): assert dist.is_initialized() @@ -162,6 +193,20 @@ def calculate_PNA_degree_mpi(loader, max_neighbours): deg = MPI.COMM_WORLD.allreduce(deg.numpy(), op=MPI.SUM) return torch.tensor(deg) +def calculate_avg_deg_mpi(loader): + assert dist.is_initialized() + deg = 0 + counter = 0 + for data in iterate_tqdm(loader, 2, desc="Calculate avg degree"): + d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long) + deg += d.sum() + counter += d.size(0) + from mpi4py import MPI + + deg = MPI.COMM_WORLD.allreduce(deg, op=MPI.SUM) + counter = MPI.COMM_WORLD.allreduce(counter, op=MPI.SUM) + return deg / counter + def unsorted_segment_mean(data, segment_ids, num_segments): result_shape = (num_segments, data.size(1)) diff --git a/tests/inputs/ci.json b/tests/inputs/ci.json index 36613eaa8..951a44adf 100644 --- a/tests/inputs/ci.json +++ b/tests/inputs/ci.json @@ -38,6 +38,8 @@ "num_radial": 6, "num_spherical": 7, "num_filters": 126, + "max_ell": 1, + "node_max_ell": 1, "periodic_boundary_conditions": false, "hidden_dim": 8, "num_conv_layers": 2, diff --git a/tests/inputs/ci_equivariant.json b/tests/inputs/ci_equivariant.json index 
1afc93c21..51175e9e7 100644 --- a/tests/inputs/ci_equivariant.json +++ b/tests/inputs/ci_equivariant.json @@ -39,6 +39,8 @@ "num_radial": 6, "num_spherical": 7, "num_filters": 126, + "max_ell": 1, + "node_max_ell": 1, "periodic_boundary_conditions": false, "hidden_dim": 8, "num_conv_layers": 2, diff --git a/tests/inputs/ci_multihead.json b/tests/inputs/ci_multihead.json index 4d0a743e4..f408c4aa8 100644 --- a/tests/inputs/ci_multihead.json +++ b/tests/inputs/ci_multihead.json @@ -36,6 +36,8 @@ "num_radial": 6, "num_spherical": 7, "num_filters": 126, + "max_ell": 1, + "node_max_ell": 1, "periodic_boundary_conditions": false, "hidden_dim": 8, "num_conv_layers": 2, diff --git a/tests/inputs/ci_vectoroutput.json b/tests/inputs/ci_vectoroutput.json index 238e1c793..ddb616615 100644 --- a/tests/inputs/ci_vectoroutput.json +++ b/tests/inputs/ci_vectoroutput.json @@ -28,6 +28,8 @@ "max_neighbours": 100, "envelope_exponent": 5, "num_radial": 6, + "max_ell": 1, + "node_max_ell": 1, "periodic_boundary_conditions": false, "hidden_dim": 8, "num_conv_layers": 2, diff --git a/tests/test_graphs.py b/tests/test_graphs.py index 6222d707b..2468d9df9 100755 --- a/tests/test_graphs.py +++ b/tests/test_graphs.py @@ -147,6 +147,7 @@ def unittest_train_model( "SchNet": [0.20, 0.20], "DimeNet": [0.50, 0.50], "EGNN": [0.20, 0.20], + "MACE": [0.60, 0.70], } if use_lengths and ("vector" not in ci_input): thresholds["CGCNN"] = [0.175, 0.175] @@ -206,6 +207,7 @@ def unittest_train_model( "SchNet", "DimeNet", "EGNN", + "MACE", ], ) @pytest.mark.parametrize("ci_input", ["ci.json", "ci_multihead.json"]) @@ -214,26 +216,47 @@ def pytest_train_model(model_type, ci_input, overwrite_data=False): # Test only models -@pytest.mark.parametrize("model_type", ["PNA", "PNAPlus", "CGCNN", "SchNet", "EGNN"]) +# "PNA", "PNAPlus", "CGCNN", "SchNet", "EGNN", +@pytest.mark.parametrize("model_type", ["MACE"]) def pytest_train_model_lengths(model_type, overwrite_data=False): unittest_train_model(model_type, "ci.json", True, overwrite_data) # Test across equivariant models -@pytest.mark.parametrize("model_type", ["EGNN", "SchNet"]) +# "EGNN", "SchNet", +@pytest.mark.parametrize("model_type", ["MACE"]) def pytest_train_equivariant_model(model_type, overwrite_data=False): unittest_train_model(model_type, "ci_equivariant.json", False, overwrite_data) # Test vector output -@pytest.mark.parametrize("model_type", ["PNA", "PNAPlus"]) +# "PNA", "PNAPlus", +@pytest.mark.parametrize("model_type", ["MACE"]) def pytest_train_model_vectoroutput(model_type, overwrite_data=False): unittest_train_model(model_type, "ci_vectoroutput.json", True, overwrite_data) -@pytest.mark.parametrize( - "model_type", - ["SAGE", "GIN", "GAT", "MFC", "PNA", "PNAPlus", "SchNet", "DimeNet", "EGNN"], -) -def pytest_train_model_conv_head(model_type, overwrite_data=False): - unittest_train_model(model_type, "ci_conv_head.json", False, overwrite_data) +# @pytest.mark.parametrize( +# "model_type", +# ["SAGE", "GIN", "GAT", "MFC", "PNA", "PNAPlus", "SchNet", "DimeNet", "EGNN"], +# ) +# def pytest_train_model_conv_head(model_type, overwrite_data=False): +# unittest_train_model(model_type, "ci_conv_head.json", False, overwrite_data) + + +# def debug_train_model_vectoroutput(model_type="MACE", overwrite_data=False): +# """ +# A function to test vector output that can be run in VSCode's debug mode. +# Set breakpoints in VSCode as needed for debugging. 
+# """ +# # Call the test function directly +# unittest_train_model(model_type, "ci_vectoroutput.json", True, overwrite_data) + +# # After execution, print some debug information if needed +# print(f"Finished training model: {model_type} with overwrite_data={overwrite_data}") + +# # Manual testing entry point +# if __name__ == "__main__": +# # Manually call the function with the desired parameters for debugging in VSCode +# debug_train_model_vectoroutput(model_type="MACE", overwrite_data=True) + From d1fc744037ce7a366c9797d2680ea1c880b97673 Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Wed, 25 Sep 2024 10:55:41 -0400 Subject: [PATCH 03/51] test first MACE push --- requirements-torch.txt | 3 ++- requirements.txt | 2 ++ tests/test_graphs.py | 23 ++++++++++------------- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/requirements-torch.txt b/requirements-torch.txt index d0673bd15..b955d1b21 100644 --- a/requirements-torch.txt +++ b/requirements-torch.txt @@ -1,4 +1,5 @@ torch==2.0.1 torchvision torchaudio - +torch-ema +torchmetrics diff --git a/requirements.txt b/requirements.txt index 2e9bfef74..75ed57e75 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,5 @@ tqdm tensorboard psutil sympy +e3nn +matscipy diff --git a/tests/test_graphs.py b/tests/test_graphs.py index 2468d9df9..1ea5fff05 100755 --- a/tests/test_graphs.py +++ b/tests/test_graphs.py @@ -215,33 +215,30 @@ def pytest_train_model(model_type, ci_input, overwrite_data=False): unittest_train_model(model_type, ci_input, False, overwrite_data) -# Test only models -# "PNA", "PNAPlus", "CGCNN", "SchNet", "EGNN", -@pytest.mark.parametrize("model_type", ["MACE"]) +# Test only models +@pytest.mark.parametrize("model_type", ["PNA", "PNAPlus", "CGCNN", "SchNet", "EGNN", "MACE"]) def pytest_train_model_lengths(model_type, overwrite_data=False): unittest_train_model(model_type, "ci.json", True, overwrite_data) # Test across equivariant models -# "EGNN", "SchNet", -@pytest.mark.parametrize("model_type", ["MACE"]) +@pytest.mark.parametrize("model_type", ["EGNN", "SchNet", "MACE"]) def pytest_train_equivariant_model(model_type, overwrite_data=False): unittest_train_model(model_type, "ci_equivariant.json", False, overwrite_data) # Test vector output -# "PNA", "PNAPlus", -@pytest.mark.parametrize("model_type", ["MACE"]) +@pytest.mark.parametrize("model_type", ["PNA", "PNAPlus", "MACE"]) def pytest_train_model_vectoroutput(model_type, overwrite_data=False): unittest_train_model(model_type, "ci_vectoroutput.json", True, overwrite_data) -# @pytest.mark.parametrize( -# "model_type", -# ["SAGE", "GIN", "GAT", "MFC", "PNA", "PNAPlus", "SchNet", "DimeNet", "EGNN"], -# ) -# def pytest_train_model_conv_head(model_type, overwrite_data=False): -# unittest_train_model(model_type, "ci_conv_head.json", False, overwrite_data) +@pytest.mark.parametrize( + "model_type", + ["SAGE", "GIN", "GAT", "MFC", "PNA", "PNAPlus", "SchNet", "DimeNet", "EGNN"], +) +def pytest_train_model_conv_head(model_type, overwrite_data=False): + unittest_train_model(model_type, "ci_conv_head.json", False, overwrite_data) # def debug_train_model_vectoroutput(model_type="MACE", overwrite_data=False): From eb46b76b06fba040d412c100b63639c61104a369 Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Wed, 25 Sep 2024 11:09:32 -0400 Subject: [PATCH 04/51] formatting and clean-up --- .vscode/.vscode/launch.json | 15 - .vscode/launch.json | 74 ---- examples/qm9/qm9.json | 11 +- hydragnn/models/MACEStack.py | 316 +++++++++++++----- 
.../input_config_parsing/config_utils.py | 2 +- hydragnn/utils/mace_utils/modules/blocks.py | 30 +- tests/test_graphs.py | 7 +- 7 files changed, 255 insertions(+), 200 deletions(-) delete mode 100644 .vscode/.vscode/launch.json delete mode 100644 .vscode/launch.json diff --git a/.vscode/.vscode/launch.json b/.vscode/.vscode/launch.json deleted file mode 100644 index 6b76b4fab..000000000 --- a/.vscode/.vscode/launch.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - // Use IntelliSense to learn about possible attributes. - // Hover to view descriptions of existing attributes. - // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 - "version": "0.2.0", - "configurations": [ - { - "name": "Python Debugger: Current File", - "type": "debugpy", - "request": "launch", - "program": "${file}", - "console": "integratedTerminal" - } - ] -} \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json deleted file mode 100644 index ba662d2e6..000000000 --- a/.vscode/launch.json +++ /dev/null @@ -1,74 +0,0 @@ -{ - // Use IntelliSense to learn about possible attributes. - // Hover to view descriptions of existing attributes. - // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 - - // "version": "0.2.0", - // "configurations": [ - // { - // "name": "Python Debugger: Current File with Conda", - // "type": "python", - // "request": "launch", - // "program": "${file}", - // "console": "integratedTerminal", - // "pythonPath": "/opt/anaconda3/envs/Force/bin/python", - // "env": { - // "PYTHONPATH": "/Users/r9w/Coding/HydraGNN:/Users/r9w/Coding/matsciml" - // } - // } - // ] - - "version": "0.2.0", - "configurations": [ - { - "name": "Python Debugger: Current File with Conda", - "type": "python", - "request": "launch", - "program": "${file}", - "console": "integratedTerminal", - "pythonPath": "/opt/anaconda3/envs/HYDRAEq/bin/python", - "env": { - "PYTHONPATH": "/Users/r9w/Coding/My_Fork/HydraGNN" - - }, - // "args": ["--pickle"] - } - ] - - // "version": "0.2.0", - // "configurations": [ - // { - // "name": "Python Debugger: Current File with Conda", - // "type": "python", - // "request": "launch", - // "program": "${file}", - // "console": "integratedTerminal", - // "pythonPath": "/opt/anaconda3/envs/MACE2/bin/python", - // "args": [ - // "--name=MACE_model", - // "--train_file=train.xyz", - // "--valid_fraction=0.05", - // "--test_file=test.xyz", - // "--config_type_weights={\"Default\":1.0}", - // "--E0s={1:-13.663181292231226, 6:-1029.2809654211628, 7:-1484.1187695035828, 8:-2042.0330099956639, 9:-256.207655}", - // "--model=MACE", - // "--hidden_irreps=128x0e + 128x1o", - // "--r_max=5.0", - // "--batch_size=200", - // "--max_num_epochs=1500", - // "--swa", - // "--start_swa=1200", - // "--ema", - // "--ema_decay=0.99", - // "--amsgrad", - // "--restart_latest" - // ], - // "env": { - // "PYTHONPATH": "/Users/r9w/Coding/MACE/mace" - - // }, - // } - // ] - - -} \ No newline at end of file diff --git a/examples/qm9/qm9.json b/examples/qm9/qm9.json index 789ba5ccf..78b5c7c96 100644 --- a/examples/qm9/qm9.json +++ b/examples/qm9/qm9.json @@ -5,15 +5,12 @@ "NeuralNetwork": { "Profile": {"enable": 1}, "Architecture": { - "model_type": "MACE", + "model_type": "GIN", "radius": 7, - "num_radial": 6, - "max_ell": 1, - "node_max_ell": 1, "max_neighbours": 5, "periodic_boundary_conditions": false, - "hidden_dim": 15, - "num_conv_layers": 3, + "hidden_dim": 5, + "num_conv_layers": 6, "output_heads": { "graph":{ "num_sharedlayers": 2, @@ -33,7 +30,7 @@ 
"denormalize_output": false }, "Training": { - "num_epoch": 100, + "num_epoch": 2, "perc_train": 0.7, "loss_function_type": "mse", "batch_size": 64, diff --git a/hydragnn/models/MACEStack.py b/hydragnn/models/MACEStack.py index c1313ad60..11b78591a 100644 --- a/hydragnn/models/MACEStack.py +++ b/hydragnn/models/MACEStack.py @@ -23,7 +23,7 @@ ### (1) Message passing and interaction blocks are equivariant to the O(3) group. And invariant to the T(3) group (translations). ### (2) Predictions are made in an n-body expansion, where n is the numnber of layers. This is done by creating multi-body ### interactions, then decoding them. Layer 1 will decode 1-body interactions, layer 2 will decode w-body interactions, -### and so on. So, for a 3-layer model predicting energy, there are 3 outputs for energy, one at each layer, and they +### and so on. So, for a 3-layer model predicting energy, there are 3 outputs for energy, one at each layer, and they ### are summed at the end. This requires some adjustment to the behavior from Base.py from typing import Any, Callable, Dict, List, Optional, Type, Union @@ -81,16 +81,16 @@ def __init__( *args, **kwargs, ): - """Notes On MACEStack Arguments:""" + """Notes On MACEStack Arguments:""" # MACE args that we have given definitions for and the reasons why: ## Note: These can be changed in the future if the desired argument options change ## interaction_cls / interaction_cls_first: The choice of interaction block type should not make much of a difference and would require more imports in create.py and/or string handling ## Atomic Energies: This is not agnostic to what we're predicting, which is a requirement of HYDRA. We also don't have base atomic energies to load, so we simply one-hot encode the atomic numbers and train. 
## Atomic Numbers / num_elements: It's more robust in preventing errors to just cover the entire periodic table (1-118) - + # MACE args that we have dropped and the reasons why: ## pair repulsion, distance_transform, compute_virials, etc: HYDRA's framework is meant to compute based on graph or node type, so must be agnostic to these property-specific types of computations - + # MACE args constructed by HYDRA args ## Reasoning: Oftentimes, MACE arguments show similarity to HYDRA arguments, but are labelled differently ## num_interactions is represented by num_conv_layers @@ -109,23 +109,31 @@ def __init__( self.interaction_cls = RealAgnosticAttResidualInteractionBlock self.interaction_cls_first = RealAgnosticAttResidualInteractionBlock self.num_elements = 118 # Number of elements in the periodic table - atomic_numbers = list(range(1, self.num_elements+1)) + atomic_numbers = list(range(1, self.num_elements + 1)) # Optional - num_polynomial_cutoff = 5 if num_polynomial_cutoff is None else num_polynomial_cutoff + num_polynomial_cutoff = ( + 5 if num_polynomial_cutoff is None else num_polynomial_cutoff + ) self.correlation = [2] if correlation is None else correlation radial_type = "bessel" if radial_type is None else radial_type - + # Making Irreps - self.node_attr_irreps = o3.Irreps([(self.num_elements, (0, 1))]) # 118 is the number of elements in the periodic table - self.sh_irreps = o3.Irreps.spherical_harmonics(max_ell) # This makes the irreps string + self.node_attr_irreps = o3.Irreps( + [(self.num_elements, (0, 1))] + ) # 118 is the number of elements in the periodic table + self.sh_irreps = o3.Irreps.spherical_harmonics( + max_ell + ) # This makes the irreps string self.edge_feats_irreps = o3.Irreps(f"{num_bessel}x0e") - + super().__init__(*args, **kwargs) - + self.spherical_harmonics = o3.SphericalHarmonics( + self.sh_irreps, + normalize=True, + normalization="component", # This creates the spherical harmonics module, which is applied via its forward call ) - + # Register buffers are made when parameters need to be saved and transferred with the model, but not trained.
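# (Illustrative aside, not part of the original patch): a registered buffer is
# stored in state_dict and moved by .to(device) like a parameter, but it receives
# no gradients and is therefore ignored by the optimizer, e.g.:
#   layer = torch.nn.Linear(2, 2)
#   layer.register_buffer("scale", torch.ones(2))
#   "scale" in layer.state_dict()   # -> True
#   layer.scale.requires_grad       # -> False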
self.register_buffer( "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) @@ -146,48 +154,100 @@ def __init__( distance_transform=None, ) self.node_embedding = LinearNodeEmbeddingBlock( - irreps_in=self.node_attr_irreps, irreps_out=create_irreps_string(self.hidden_dim, 0) # Changed this to hidden_dim because no longer had node_feats_irreps + irreps_in=self.node_attr_irreps, + irreps_out=create_irreps_string( + self.hidden_dim, 0 + ), # Changed this to hidden_dim because no longer had node_feats_irreps ) - def _init_conv(self): # Multihead Decoders ## This integrates HYDRA multihead nature with MACE's layer-wise readouts ## NOTE Norm techniques (feature_layers in HYDRA) are not advised for use in equivariant models as it can break equivariance self.multihead_decoders = ModuleList() - hidden_irreps = o3.Irreps(create_irreps_string(self.hidden_dim, self.node_max_ell)) - final_hidden_irreps = o3.Irreps(create_irreps_string(self.hidden_dim, 0)) # Only scalars are outputted in the last layer - + hidden_irreps = o3.Irreps( + create_irreps_string(self.hidden_dim, self.node_max_ell) + ) + final_hidden_irreps = o3.Irreps( + create_irreps_string(self.hidden_dim, 0) + ) # Only scalars are outputted in the last layer + last_layer = 1 == self.num_conv_layers - - self.multihead_decoders.append(MultiheadDecoderBlock(self.node_attr_irreps, self.node_max_ell, self.config_heads, self.head_dims, self.head_type, self.num_heads, self.activation_function, self.num_nodes, nonlinear=True)) # For base-node traits - self.graph_convs.append(self.get_conv(self.input_dim, self.hidden_dim, first_layer=True)) - self.multihead_decoders.append(MultiheadDecoderBlock(hidden_irreps, self.node_max_ell, self.config_heads, self.head_dims, self.head_type, self.num_heads, self.activation_function, self.num_nodes, nonlinear=last_layer)) + + self.multihead_decoders.append( + MultiheadDecoderBlock( + self.node_attr_irreps, + self.node_max_ell, + self.config_heads, + self.head_dims, + self.head_type, + self.num_heads, + self.activation_function, + self.num_nodes, + nonlinear=True, + ) + ) # For base-node traits + self.graph_convs.append( + self.get_conv(self.input_dim, self.hidden_dim, first_layer=True) + ) + self.multihead_decoders.append( + MultiheadDecoderBlock( + hidden_irreps, + self.node_max_ell, + self.config_heads, + self.head_dims, + self.head_type, + self.num_heads, + self.activation_function, + self.num_nodes, + nonlinear=last_layer, + ) + ) for i in range(self.num_conv_layers - 1): last_layer = i == self.num_conv_layers - 2 - conv = self.get_conv(self.hidden_dim, self.hidden_dim, last_layer=last_layer) + conv = self.get_conv( + self.hidden_dim, self.hidden_dim, last_layer=last_layer + ) self.graph_convs.append(conv) - self.multihead_decoders.append(MultiheadDecoderBlock(final_hidden_irreps, self.node_max_ell, self.config_heads, self.head_dims, self.head_type, self.num_heads, self.activation_function, self.num_nodes, nonlinear=last_layer)) # Last layer will be nonlinear node decoding - + self.multihead_decoders.append( + MultiheadDecoderBlock( + final_hidden_irreps, + self.node_max_ell, + self.config_heads, + self.head_dims, + self.head_type, + self.num_heads, + self.activation_function, + self.num_nodes, + nonlinear=last_layer, + ) + ) # Last layer will be nonlinear node decoding + def get_conv(self, input_dim, output_dim, first_layer=False, last_layer=False): hidden_dim = output_dim if input_dim == 1 else input_dim - + # All of these should be constructed with HYDRA dimensional arguments ## Radial - 
radial_MLP_dim = math.ceil(float(hidden_dim) / 3) # Go based off hidden_dim for radial_MLP + radial_MLP_dim = math.ceil( + float(hidden_dim) / 3 + ) # Go based off hidden_dim for radial_MLP radial_MLP = [radial_MLP_dim, radial_MLP_dim, radial_MLP_dim] ## Input, Hidden, and Output irreps sizing (this is usually just hidden in MACE) ### Input dimensions are handled implicitly ### Hidden - hidden_irreps = create_irreps_string(hidden_dim, self.node_max_ell) + hidden_irreps = create_irreps_string(hidden_dim, self.node_max_ell) hidden_irreps = o3.Irreps(hidden_irreps) node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) - num_features = hidden_irreps.count(o3.Irrep(0, 1)) # Multiple copies of spherical harmonics for multiple interactions. They are 'combined' in a certain way during .simplify() ## This makes it a requirement that hidden irreps all have the same number of channels - interaction_irreps = (self.sh_irreps * num_features).sort()[0].simplify() #.sort() is a tuple, so we need the [0] element for the sorted result + num_features = hidden_irreps.count( + o3.Irrep(0, 1) + ) # Multiple copies of spherical harmonics for multiple interactions. They are 'combined' in a certain way during .simplify() ## This makes it a requirement that hidden irreps all have the same number of channels + interaction_irreps = ( + (self.sh_irreps * num_features).sort()[0].simplify() + ) # .sort() is a tuple, so we need the [0] element for the sorted result ### Output - output_irreps = create_irreps_string(output_dim, self.node_max_ell) + output_irreps = create_irreps_string(output_dim, self.node_max_ell) output_irreps = o3.Irreps(output_irreps) - + # Constructing convolutional layers if first_layer: hidden_irreps_out = hidden_irreps @@ -212,15 +272,13 @@ def get_conv(self, input_dim, output_dim, first_layer=False, last_layer=False): num_elements=self.num_elements, use_sc=use_sc_first, ) - sizing = o3.Linear(hidden_irreps_out, output_irreps) # Change sizing to output_irreps + sizing = o3.Linear( + hidden_irreps_out, output_irreps + ) # Change sizing to output_irreps elif last_layer: # Select only scalars output for last layer - hidden_irreps_out = str( - hidden_irreps[0] - ) - output_irreps = str( - output_irreps[0] - ) + hidden_irreps_out = str(hidden_irreps[0]) + output_irreps = str(output_irreps[0]) inter = self.interaction_cls( node_attrs_irreps=self.node_attr_irreps, node_feats_irreps=hidden_irreps, @@ -238,7 +296,9 @@ def get_conv(self, input_dim, output_dim, first_layer=False, last_layer=False): num_elements=self.num_elements, use_sc=True, ) - sizing = o3.Linear(hidden_irreps_out, output_irreps) # Change sizing to output_irreps + sizing = o3.Linear( + hidden_irreps_out, output_irreps + ) # Change sizing to output_irreps else: hidden_irreps_out = hidden_irreps inter = self.interaction_cls( @@ -258,13 +318,14 @@ def get_conv(self, input_dim, output_dim, first_layer=False, last_layer=False): num_elements=self.num_elements, use_sc=True, ) - sizing = o3.Linear(hidden_irreps_out, output_irreps) # Change sizing to output_irreps + sizing = o3.Linear( + hidden_irreps_out, output_irreps + ) # Change sizing to output_irreps - input_args = "node_attributes, pos, node_features, edge_attributes, edge_features, edge_index" # readout_args = "node_energies" conv_args = "node_attributes, edge_attributes, edge_features, edge_index" # node_features is not used here because it's passed through in the forward - + if self.use_edge_attr: input_args += ", edge_attr" conv_args += ", edge_attr" @@ -276,7 
+337,10 @@ def get_conv(self, input_dim, output_dim, first_layer=False, last_layer=False): (inter, "node_features, " + conv_args + " -> node_features, sc"), (prod, "node_features, sc, node_attributes -> node_features"), (sizing, "node_features -> node_features"), - (lambda node_features, pos: [node_features, pos], "node_features, pos -> node_features, pos"), + ( + lambda node_features, pos: [node_features, pos], + "node_features, pos -> node_features, pos", + ), ], ) else: @@ -286,60 +350,77 @@ def get_conv(self, input_dim, output_dim, first_layer=False, last_layer=False): (inter, "node_features, " + conv_args + " -> node_features, sc"), (prod, "node_features, sc, node_attributes -> node_features"), (sizing, "node_features -> node_features"), - (lambda node_features, pos: [node_features, pos], "node_features, pos -> node_features, pos"), + ( + lambda node_features, pos: [node_features, pos], + "node_features, pos -> node_features, pos", + ), ], ) - + def forward(self, data): data, conv_args = self._conv_args(data) node_features = data.node_features node_attributes = data.node_attributes pos = data.pos - + ### encoder / decoder part #### ## NOTE Norm techniques (feature_layers in HYDRA) are not advised for use in equivariant models as it can break equivariance - + ### There is a readout before the first convolution layer ### outputs = [] - output = self.multihead_decoders[0](data, node_attributes) # [index][n_output, size_output] + output = self.multihead_decoders[0]( + data, node_attributes + ) # [index][n_output, size_output] # Create outputs first outputs = output - + ### Do conv --> readout --> repeat for each convolution layer ### for conv, readout in zip(self.graph_convs, self.multihead_decoders[1:]): if not self.conv_checkpointing: - node_features, pos = conv(node_features=node_features, pos=pos, **conv_args) - output = readout(data, node_features) # [index][n_output, size_output] + node_features, pos = conv( + node_features=node_features, pos=pos, **conv_args + ) + output = readout(data, node_features) # [index][n_output, size_output] else: node_features, pos = checkpoint( - conv, use_reentrant=False, node_features=node_features, pos=pos, **conv_args + conv, + use_reentrant=False, + node_features=node_features, + pos=pos, + **conv_args, ) - output = readout(data, node_features) # output is a list of tensors with [index][n_output, size_output] + output = readout( + data, node_features + ) # output is a list of tensors with [index][n_output, size_output] # Sum predictions for each index, taking care of size differences for idx, prediction in enumerate(output): outputs[idx] = outputs[idx] + prediction - + return outputs def _conv_args(self, data): assert ( data.pos is not None ), "MACE requires node positions (data.pos) to be set." - - # Center positions at 0 per graph. This is a requirement for equivariant models that - # initialize the spherical harmonics, since the initial spherical harmonic projection + + # Center positions at 0 per graph. This is a requirement for equivariant models that + # initialize the spherical harmonics, since the initial spherical harmonic projection # uses the nodal position vector x/||x|| as the input to the spherical harmonics. # If we didn't center at 0, these models wouldn't even be invariant to translation. mean_pos = scatter(data.pos, data.batch, dim=0, reduce="mean") data.pos = data.pos - mean_pos[data.batch] - + # Create node_attrs from atomic numbers. 
Later on it may contain more information ## Node attrs are intrinsic properties of the atoms, like charge, atomic number, etc.. ## data.node_attrs is already used as a method or smt in another place, so has been renamed to data.node_attributes from MACE and same with other data variable names - one_hot = torch.nn.functional.one_hot(data["x"].long().squeeze(-1), num_classes=118).float() # [n_atoms, 118] ## 118 atoms in the peridoic table + one_hot = torch.nn.functional.one_hot( + data["x"].long().squeeze(-1), num_classes=118 + ).float() # [n_atoms, 118] ## 118 atoms in the peridoic table data.node_attributes = one_hot # To-Do: Add more information to node_attrs - data.shifts = torch.zeros((data.edge_index.shape[1], 3), dtype=data.pos.dtype, device=data.pos.device) # Shifts takes into account pbc conditions, but I believe we already generate data.pos to take it into account - + data.shifts = torch.zeros( + (data.edge_index.shape[1], 3), dtype=data.pos.dtype, device=data.pos.device + ) # Shifts takes into account pbc conditions, but I believe we already generate data.pos to take it into account + # Embeddings node_feats = self.node_embedding(data["node_attributes"]) vectors, lengths = get_edge_vectors_and_lengths( @@ -351,13 +432,13 @@ def _conv_args(self, data): edge_features = self.radial_embedding( lengths, data["node_attributes"], data["edge_index"], self.atomic_numbers ) - + # Variable names data.node_features = node_feats data.edge_attributes = edge_attributes data.edge_features = edge_features data.lengths = lengths - + conv_args = { "node_attributes": data.node_attributes, "edge_attributes": data.edge_attributes, @@ -366,8 +447,7 @@ def _conv_args(self, data): } return data, conv_args - - + def _multihead(self): # NOTE Multihead is skipped as it's an integral part of MACE's architecture to have a decoder after every layer, # and a convolutional layer in decoding is not supported. Therefore, this final step is not necessary for MACE. 
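A minimal sketch of the per-graph centering performed in _conv_args above, written with plain torch ops rather than HydraGNN's scatter helper (the tensors and sizes here are invented for illustration):

import torch

pos = torch.randn(5, 3)                # 5 atoms, xyz coordinates
batch = torch.tensor([0, 0, 0, 1, 1])  # atoms 0-2 in graph 0, atoms 3-4 in graph 1
num_graphs = int(batch.max()) + 1
sums = torch.zeros(num_graphs, 3).index_add_(0, batch, pos)  # per-graph position sums
counts = torch.bincount(batch, minlength=num_graphs).to(pos.dtype).unsqueeze(-1)
centered = pos - (sums / counts)[batch]  # each graph's centroid moves to the origin

With the centroid at the origin, the unit vectors x/||x|| fed to the spherical harmonics do not change under rigid translations of the structure, which is exactly the invariance the comment in _conv_args describes.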
@@ -376,17 +456,29 @@ def _multihead(self): def __str__(self): return "MACEStack" - - - -def create_irreps_string(n: int, ell: int): # Custom function to allow for use of HYDRA arguments in creating irreps - irreps = [f"{n}x{ell}{'e' if ell % 2 == 0 else 'o'}" for ell in range(ell + 1)] - return " + ".join(irreps) - + + +def create_irreps_string( + n: int, ell: int +): # Custom function to allow for use of HYDRA arguments in creating irreps + irreps = [f"{n}x{ell}{'e' if ell % 2 == 0 else 'o'}" for ell in range(ell + 1)] + return " + ".join(irreps) + @compile_mode("script") class MultiheadDecoderBlock(torch.nn.Module): - def __init__(self, input_irreps, node_max_ell, config_heads, head_dims, head_type, num_heads, activation_function, num_nodes, nonlinear=False): + def __init__( + self, + input_irreps, + node_max_ell, + config_heads, + head_dims, + head_type, + num_heads, + activation_function, + num_nodes, + nonlinear=False, + ): super(MultiheadDecoderBlock, self).__init__() self.input_irreps = input_irreps self.node_max_ell = node_max_ell if not nonlinear else 0 @@ -403,15 +495,25 @@ def __init__(self, input_irreps, node_max_ell, config_heads, head_dims, head_typ # Create shared dense layers for graph-level output if applicable if "graph" in self.config_heads: - graph_input_irreps = o3.Irreps(f"{self.input_irreps.count(o3.Irrep(0, 1))}x0e") + graph_input_irreps = o3.Irreps( + f"{self.input_irreps.count(o3.Irrep(0, 1))}x0e" + ) dim_sharedlayers = self.config_heads["graph"]["dim_sharedlayers"] sharedlayers_irreps = o3.Irreps(f"{dim_sharedlayers}x0e") denselayers = [] denselayers.append(o3.Linear(graph_input_irreps, sharedlayers_irreps)) - denselayers.append(nn.Activation(irreps_in=sharedlayers_irreps, acts=[self.activation_function])) + denselayers.append( + nn.Activation( + irreps_in=sharedlayers_irreps, acts=[self.activation_function] + ) + ) for _ in range(self.config_heads["graph"]["num_sharedlayers"] - 1): denselayers.append(o3.Linear(sharedlayers_irreps, sharedlayers_irreps)) - denselayers.append(nn.Activation(irreps_in=sharedlayers_irreps, acts=[self.activation_function])) + denselayers.append( + nn.Activation( + irreps_in=sharedlayers_irreps, acts=[self.activation_function] + ) + ) self.graph_shared = Sequential(*denselayers) # Create layers for each head @@ -422,12 +524,20 @@ def __init__(self, input_irreps, node_max_ell, config_heads, head_dims, head_typ denselayers = [] head_hidden_irreps = o3.Irreps(f"{hidden_dim_graph[0]}x0e") denselayers.append(o3.Linear(sharedlayers_irreps, head_hidden_irreps)) - denselayers.append(nn.Activation(irreps_in=head_hidden_irreps, acts=[self.activation_function])) + denselayers.append( + nn.Activation( + irreps_in=head_hidden_irreps, acts=[self.activation_function] + ) + ) for ilayer in range(num_layers_graph - 1): input_irreps = o3.Irreps(f"{hidden_dim_graph[ilayer]}x0e") output_irreps = o3.Irreps(f"{hidden_dim_graph[ilayer + 1]}x0e") denselayers.append(o3.Linear(input_irreps, output_irreps)) - denselayers.append(nn.Activation(irreps_in=output_irreps, acts=[self.activation_function])) + denselayers.append( + nn.Activation( + irreps_in=output_irreps, acts=[self.activation_function] + ) + ) input_irreps = o3.Irreps(f"{hidden_dim_graph[-1]}x0e") output_irreps = o3.Irreps(f"{self.head_dims[ihead]}x0e") denselayers.append(o3.Linear(input_irreps, output_irreps)) @@ -453,27 +563,38 @@ def __init__(self, input_irreps, node_max_ell, config_heads, head_dims, head_typ self.num_nodes, self.config_heads["node"]["type"], self.activation_function, - 
nonlinear=nonlinear + nonlinear=nonlinear, ) self.heads.append(head) else: - raise ValueError(f"Unknown head NN structure for node features: {self.node_NN_type}") + raise ValueError( + f"Unknown head NN structure for node features: {self.node_NN_type}" + ) else: - raise ValueError(f"Unknown head type: {self.head_type[ihead]}; supported types are 'graph' or 'node'") - + raise ValueError( + f"Unknown head type: {self.head_type[ihead]}; supported types are 'graph' or 'node'" + ) + def forward(self, data, node_features): if data.batch is None: - graph_features = node_features[:,:self.hidden_dim].mean(dim=0, keepdim=True) # Need to take only the type-0 irreps for aggregation + graph_features = node_features[:, : self.hidden_dim].mean( + dim=0, keepdim=True + ) # Need to take only the type-0 irreps for aggregation else: - graph_features = global_mean_pool(node_features[:,:self.input_irreps.count(o3.Irrep(0, 1))], data.batch.to(node_features.device)) + graph_features = global_mean_pool( + node_features[:, : self.input_irreps.count(o3.Irrep(0, 1))], + data.batch.to(node_features.device), + ) outputs = [] - for (headloc, type_head) in zip(self.heads, self.head_type): + for headloc, type_head in zip(self.heads, self.head_type): if type_head == "graph": x_graph_head = self.graph_shared(graph_features) outputs.append(headloc(x_graph_head)) else: # Node-level output if self.node_NN_type == "conv": - raise ValueError("Node-level convolutional layers are not supported in MACE") + raise ValueError( + "Node-level convolutional layers are not supported in MACE" + ) else: x_node = headloc(node_features, data.batch) outputs.append(x_node) @@ -494,7 +615,7 @@ def __init__( num_nodes, node_type, activation_function, - nonlinear=False + nonlinear=False, ): super().__init__() self.input_irreps = input_irreps @@ -519,18 +640,26 @@ def __init__( hidden_irreps = o3.Irreps(f"{hidden_dims[0]}x0e") # Hidden irreps denselayers.append(o3.Linear(input_irreps, hidden_irreps)) - denselayers.append(nn.Activation(irreps_in=hidden_irreps, acts=[self.activation_function])) + denselayers.append( + nn.Activation(irreps_in=hidden_irreps, acts=[self.activation_function]) + ) # Add intermediate layers for ilayer in range(self.num_layers - 1): input_irreps = o3.Irreps(f"{hidden_dims[ilayer]}x0e") hidden_irreps = o3.Irreps(f"{hidden_dims[ilayer + 1]}x0e") denselayers.append(o3.Linear(input_irreps, hidden_irreps)) - denselayers.append(nn.Activation(irreps_in=hidden_irreps, acts=[self.activation_function])) + denselayers.append( + nn.Activation( + irreps_in=hidden_irreps, acts=[self.activation_function] + ) + ) # Last layer hidden_irreps = o3.Irreps(f"{hidden_dims[-1]}x0e") - output_irreps = o3.Irreps(f"{self.output_dim}x0e") # Assuming head_dims has been passed for the final output + output_irreps = o3.Irreps( + f"{self.output_dim}x0e" + ) # Assuming head_dims has been passed for the final output denselayers.append(o3.Linear(hidden_irreps, output_irreps)) # Append to MLP @@ -555,7 +684,10 @@ def forward(self, node_features: torch.Tensor, batch: torch.Tensor): outs = self.mlp[0](node_features) else: outs = torch.zeros( - (node_features.shape[0], self.head_dims[0]), # Assuming `head_dims` defines the final output dimension + ( + node_features.shape[0], + self.head_dims[0], + ), # Assuming `head_dims` defines the final output dimension dtype=node_features.dtype, device=node_features.device, ) diff --git a/hydragnn/utils/input_config_parsing/config_utils.py b/hydragnn/utils/input_config_parsing/config_utils.py index ce19024a5..c37e0881e 
100644 --- a/hydragnn/utils/input_config_parsing/config_utils.py +++ b/hydragnn/utils/input_config_parsing/config_utils.py @@ -55,7 +55,7 @@ def update_config(config, train_loader, val_loader, test_loader): config["NeuralNetwork"]["Architecture"]["max_neighbours"] = len(deg) - 1 else: config["NeuralNetwork"]["Architecture"]["pna_deg"] = None - + if config["NeuralNetwork"]["Architecture"]["model_type"] == "MACE": if hasattr(train_loader.dataset, "avg_num_neighbors"): ## Use avg neighbours used in the dataset. diff --git a/hydragnn/utils/mace_utils/modules/blocks.py b/hydragnn/utils/mace_utils/modules/blocks.py index 698a38678..b77f4811d 100644 --- a/hydragnn/utils/mace_utils/modules/blocks.py +++ b/hydragnn/utils/mace_utils/modules/blocks.py @@ -63,7 +63,9 @@ def __init__( super().__init__() self.hidden_irreps = MLP_irreps self.linear_1 = o3.Linear(irreps_in=irreps_in, irreps_out=self.hidden_irreps) - self.non_linearity = nn.Activation(irreps_in=self.hidden_irreps, acts=[gate]) # Need to adjust this to actually use the gate + self.non_linearity = nn.Activation( + irreps_in=self.hidden_irreps, acts=[gate] + ) # Need to adjust this to actually use the gate self.linear_2 = o3.Linear( irreps_in=self.hidden_irreps, irreps_out=o3.Irreps("0e") ) @@ -148,17 +150,23 @@ def forward( def __repr__(self): formatted_energies = ", ".join([f"{x:.4f}" for x in self.atomic_energies]) return f"{self.__class__.__name__}(energies=[{formatted_energies}])" - + + @compile_mode("script") class AtomicBlock(torch.nn.Module): def __init__(self, output_dim): super().__init__() # Initialize the atomic energies as a trainable parameter - self.atomic_energies = torch.nn.Parameter(torch.randn(118, output_dim)) # There are 118 known elements + self.atomic_energies = torch.nn.Parameter( + torch.randn(118, output_dim) + ) # There are 118 known elements def forward(self, atomic_numbers): # Perform the linear multiplication (no bias) - return atomic_numbers @ self.atomic_energies # Output will now have shape [batch_size, output_dim] + return ( + atomic_numbers @ self.atomic_energies + ) # Output will now have shape [batch_size, output_dim] + @compile_mode("script") class RadialEmbeddingBlock(torch.nn.Module): @@ -316,7 +324,9 @@ def __repr__(self): @compile_mode("script") class RealAgnosticAttResidualInteractionBlock(InteractionBlock): def _setup(self) -> None: - self.node_feats_down_irreps = o3.Irreps("64x0e") # Interesting, this seems to be a required shaping + self.node_feats_down_irreps = o3.Irreps( + "64x0e" + ) # Interesting, this seems to be a required shaping # First linear self.linear_up = o3.Linear( self.node_feats_irreps, @@ -354,12 +364,14 @@ def _setup(self) -> None: ## It is worth double-checking, but I believe this means that type 0 (scalar) irreps ## are being embedded by 3 layers of size 256 and the output dim, then activated. self.conv_tp_weights = nn.FullyConnectedNet( - [input_dim] + 3 * [256] + [self.conv_tp.weight_numel], + [input_dim] + 3 * [256] + [self.conv_tp.weight_numel], torch.nn.functional.silu, ) # Linear - irreps_mid = irreps_mid.simplify() # .simplify() essentially combines irreps of the same type so that normalization is done across them all. The site has an in-depth explanation + irreps_mid = ( + irreps_mid.simplify() + ) # .simplify() essentially combines irreps of the same type so that normalization is done across them all. 
The e3nn documentation has an in-depth explanation self.irreps_out = self.target_irreps self.linear = o3.Linear( irreps_mid, @@ -371,7 +383,9 @@ def _setup(self) -> None: self.reshape = reshape_irreps(self.irreps_out) # Skip connection. - self.skip_linear = o3.Linear(self.node_feats_irreps, self.hidden_irreps) # This will be size (num_nodes, 64*9) when there are 64 channels and irreps 0, 1, 2 (1+3+5=9) ## This becomes sc + self.skip_linear = o3.Linear( + self.node_feats_irreps, self.hidden_irreps + ) # This will be size (num_nodes, 64*9) when there are 64 channels and irreps 0, 1, 2 (1+3+5=9) ## This becomes sc def forward( self, diff --git a/tests/test_graphs.py b/tests/test_graphs.py index 1ea5fff05..ea3c0c333 100755 --- a/tests/test_graphs.py +++ b/tests/test_graphs.py @@ -215,8 +215,10 @@ def pytest_train_model(model_type, ci_input, overwrite_data=False): unittest_train_model(model_type, ci_input, False, overwrite_data) -# Test only models -@pytest.mark.parametrize("model_type", ["PNA", "PNAPlus", "CGCNN", "SchNet", "EGNN", "MACE"]) +# Test only models +@pytest.mark.parametrize( + "model_type", ["PNA", "PNAPlus", "CGCNN", "SchNet", "EGNN", "MACE"] +) def pytest_train_model_lengths(model_type, overwrite_data=False): unittest_train_model(model_type, "ci.json", True, overwrite_data) @@ -256,4 +258,3 @@ def pytest_train_model_conv_head(model_type, overwrite_data=False): # if __name__ == "__main__": # # Manually call the function with the desired parameters for debugging in VSCode # debug_train_model_vectoroutput(model_type="MACE", overwrite_data=True) - From e5ba82fce4b4f170a51c60882770afad85681d9a Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Wed, 25 Sep 2024 11:27:47 -0400 Subject: [PATCH 05/51] revise library downloads --- requirements-torch.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/requirements-torch.txt b/requirements-torch.txt index b955d1b21..71041d97a 100644 --- a/requirements-torch.txt +++ b/requirements-torch.txt @@ -1,5 +1,5 @@ -torch==2.0.1 -torchvision +torch==2.0.1 +torchvision torchaudio -torch-ema +git+https://github.com/fadel/pytorch_ema torchmetrics From 1c34beda59e3314279a8c51d155baf7c559ea08a Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Wed, 25 Sep 2024 11:32:11 -0400 Subject: [PATCH 06/51] formatting and typo --- hydragnn/utils/model/model.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/hydragnn/utils/model/model.py b/hydragnn/utils/model/model.py index 768ac1ff2..8dacafc7d 100644 --- a/hydragnn/utils/model/model.py +++ b/hydragnn/utils/model/model.py @@ -122,7 +122,7 @@ def load_existing_model( model.load_checkpoint(os.path.join(path, model_name), model_name) -## This function may cause OOM if datasets is too large +## These functions may cause OOM if dataset is too large ## to fit in a single GPU (i.e., with DDP). Use with caution.
## Recommend to use calculate_PNA_degree_dist or calculate_avg_deg_dist def calculate_PNA_degree(loader, max_neighbours): @@ -137,7 +137,8 @@ def calculate_PNA_degree(loader, max_neighbours): d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long) deg += torch.bincount(d, minlength=deg.numel())[: max_neighbours + 1] return deg - + + def calculate_avg_deg(loader): backend = os.getenv("HYDRAGNN_AGGR_BACKEND", "torch") if backend == "torch": @@ -165,6 +166,7 @@ def calculate_PNA_degree_dist(loader, max_neighbours): deg = deg.detach().cpu() return deg + def calculate_avg_deg_dist(loader): assert dist.is_initialized() deg = 0 @@ -193,6 +195,7 @@ def calculate_PNA_degree_mpi(loader, max_neighbours): deg = MPI.COMM_WORLD.allreduce(deg.numpy(), op=MPI.SUM) return torch.tensor(deg) + def calculate_avg_deg_mpi(loader): assert dist.is_initialized() deg = 0 From 258bf78ecc0acebcc1437f7b4a77bbd631675496 Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Wed, 25 Sep 2024 11:35:30 -0400 Subject: [PATCH 07/51] installation change --- requirements-torch.txt | 1 + requirements.txt | 1 + 2 files changed, 2 insertions(+) diff --git a/requirements-torch.txt b/requirements-torch.txt index 71041d97a..cbf4e15fc 100644 --- a/requirements-torch.txt +++ b/requirements-torch.txt @@ -1,5 +1,6 @@ torch==2.0.1 torchvision torchaudio +wheel # Pre-requisite for pytorch-ema git+https://github.com/fadel/pytorch_ema torchmetrics diff --git a/requirements.txt b/requirements.txt index 75ed57e75..2fc81d6df 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ psutil sympy e3nn matscipy +wheel From 239629193fe627631b20e0cd12e324c3ed584be0 Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Wed, 25 Sep 2024 11:54:45 -0400 Subject: [PATCH 08/51] change versioning of torchmetrics for github platform. 
GitHub cannot find lightning-utilities version above 0.7, which is a dependency of torchmetrics --- requirements-torch.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-torch.txt b/requirements-torch.txt index cbf4e15fc..48c92ab20 100644 --- a/requirements-torch.txt +++ b/requirements-torch.txt @@ -3,4 +3,4 @@ torchvision torchaudio wheel # Pre-requisite for pytorch-ema git+https://github.com/fadel/pytorch_ema -torchmetrics +torchmetrics==1.3.2 From de269a14c9aeb13f1685cba714b609266128ddc6 Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Wed, 25 Sep 2024 12:00:14 -0400 Subject: [PATCH 09/51] GitHub only has access to some torchmetrics version --- requirements-torch.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-torch.txt b/requirements-torch.txt index 48c92ab20..7e0178202 100644 --- a/requirements-torch.txt +++ b/requirements-torch.txt @@ -3,4 +3,4 @@ torchvision torchaudio wheel # Pre-requisite for pytorch-ema git+https://github.com/fadel/pytorch_ema -torchmetrics==1.3.2 +torchmetrics==1.0.3 From 3de20069cd89c7bce7ccbd98fa59d3f9d93a8d33 Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Wed, 25 Sep 2024 12:58:47 -0400 Subject: [PATCH 10/51] testing for issue with index url and torch-ema/torchmetrics --- requirements-torch.txt | 3 --- requirements.txt | 3 ++- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/requirements-torch.txt b/requirements-torch.txt index 7e0178202..41211d9e3 100644 --- a/requirements-torch.txt +++ b/requirements-torch.txt @@ -1,6 +1,3 @@ torch==2.0.1 torchvision torchaudio -wheel # Pre-requisite for pytorch-ema -git+https://github.com/fadel/pytorch_ema -torchmetrics==1.0.3 diff --git a/requirements.txt b/requirements.txt index 2fc81d6df..72f254d8e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,5 @@ psutil sympy e3nn matscipy -wheel +torch-ema +torchmetrics From a76fd4feba54457db5d21f2265387bc7c2259e8b Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Wed, 25 Sep 2024 13:09:56 -0400 Subject: [PATCH 11/51] fix installs to account for index-url and formatting --- .../utils/mace_utils/tools/scripts_utils.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/hydragnn/utils/mace_utils/tools/scripts_utils.py b/hydragnn/utils/mace_utils/tools/scripts_utils.py index 27455944b..bca80bf64 100644 --- a/hydragnn/utils/mace_utils/tools/scripts_utils.py +++ b/hydragnn/utils/mace_utils/tools/scripts_utils.py @@ -233,9 +233,9 @@ def convert_from_json_format(dict_input): dict_input["interaction_cls"] == "<class 'mace.modules.blocks.RealAgnosticResidualInteractionBlock'>" ): - dict_output["interaction_cls"] = ( - modules.blocks.RealAgnosticResidualInteractionBlock - ) + dict_output[ + "interaction_cls" + ] = modules.blocks.RealAgnosticResidualInteractionBlock if ( dict_input["interaction_cls"] == "<class 'mace.modules.blocks.RealAgnosticInteractionBlock'>" ): @@ -245,16 +245,16 @@ dict_input["interaction_cls_first"] == "<class 'mace.modules.blocks.RealAgnosticResidualInteractionBlock'>" ): - dict_output["interaction_cls_first"] = ( - modules.blocks.RealAgnosticResidualInteractionBlock - ) + dict_output[ + "interaction_cls_first" + ] = modules.blocks.RealAgnosticResidualInteractionBlock if ( dict_input["interaction_cls_first"] == "<class 'mace.modules.blocks.RealAgnosticInteractionBlock'>" ): - dict_output["interaction_cls_first"] = ( - modules.blocks.RealAgnosticInteractionBlock - ) + dict_output[ + "interaction_cls_first" + ] = modules.blocks.RealAgnosticInteractionBlock dict_output["r_max"] = float(dict_input["r_max"]) dict_output["num_bessel"] = int(dict_input["num_bessel"]) dict_output["num_polynomial_cutoff"] = float(dict_input["num_polynomial_cutoff"])
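A side note on the "<class '...'>" comparisons in convert_from_json_format above: a class object cannot be stored in JSON, so the serialized config appears to hold str(cls), and loading matches on that string to recover the class. A small self-contained sketch (the Dummy class is hypothetical, for illustration only):

# str(cls) yields the "<class 'module.Name'>" form that the JSON round-trip matches on.
class Dummy:
    pass

print(str(Dummy))  # -> "<class '__main__.Dummy'>"
assert str(Dummy) == f"<class '{Dummy.__module__}.{Dummy.__qualname__}'>"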
From e9ed32ced0fdc0c50b89ad1d5e3750e3bc044fb3 Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Wed, 25 Sep 2024 13:24:07 -0400 Subject: [PATCH 12/51] formatting --- hydragnn/utils/mace_utils/tools/cg.py | 4 +- .../mace_utils/tools/finetuning_utils.py | 58 +++++++++---------- 2 files changed, 31 insertions(+), 31 deletions(-) diff --git a/hydragnn/utils/mace_utils/tools/cg.py b/hydragnn/utils/mace_utils/tools/cg.py index 2cca09c94..6c1b94864 100644 --- a/hydragnn/utils/mace_utils/tools/cg.py +++ b/hydragnn/utils/mace_utils/tools/cg.py @@ -52,9 +52,9 @@ def _wigner_nj( C = o3.wigner_3j(ir_out.l, ir_left.l, ir.l, dtype=dtype) if normalization == "component": - C *= ir_out.dim**0.5 + C *= ir_out.dim ** 0.5 if normalization == "norm": - C *= ir_left.dim**0.5 * ir.dim**0.5 + C *= ir_left.dim ** 0.5 * ir.dim ** 0.5 C = torch.einsum("jk,ijl->ikl", C_left.flatten(1), C) C = C.reshape( diff --git a/hydragnn/utils/mace_utils/tools/finetuning_utils.py b/hydragnn/utils/mace_utils/tools/finetuning_utils.py index 9443add6f..0b214a287 100644 --- a/hydragnn/utils/mace_utils/tools/finetuning_utils.py +++ b/hydragnn/utils/mace_utils/tools/finetuning_utils.py @@ -50,24 +50,24 @@ def load_foundations( for j in range(4): # Assuming 4 layers in conv_tp_weights, layer_name = f"layer{j}" if j == 0: - getattr(model.interactions[i].conv_tp_weights, layer_name).weight = ( - torch.nn.Parameter( - getattr( - model_foundations.interactions[i].conv_tp_weights, - layer_name, - ) - .weight[:num_radial, :] - .clone() + getattr( + model.interactions[i].conv_tp_weights, layer_name + ).weight = torch.nn.Parameter( + getattr( + model_foundations.interactions[i].conv_tp_weights, + layer_name, ) + .weight[:num_radial, :] + .clone() ) else: - getattr(model.interactions[i].conv_tp_weights, layer_name).weight = ( - torch.nn.Parameter( - getattr( - model_foundations.interactions[i].conv_tp_weights, - layer_name, - ).weight.clone() - ) + getattr( + model.interactions[i].conv_tp_weights, layer_name + ).weight = torch.nn.Parameter( + getattr( + model_foundations.interactions[i].conv_tp_weights, + layer_name, + ).weight.clone() ) model.interactions[i].linear.weight = torch.nn.Parameter( @@ -105,23 +105,23 @@ def load_foundations( for i in range(2): # Assuming 2 products modules max_range = max_L + 1 if i == 0 else 1 for j in range(max_range): # Assuming 3 contractions in symmetric_contractions - model.products[i].symmetric_contractions.contractions[j].weights_max = ( - torch.nn.Parameter( - model_foundations.products[i] - .symmetric_contractions.contractions[j] - .weights_max[indices_weights, :, :] - .clone() - ) + model.products[i].symmetric_contractions.contractions[ + j + ].weights_max = torch.nn.Parameter( + model_foundations.products[i] + .symmetric_contractions.contractions[j] + .weights_max[indices_weights, :, :] + .clone() ) for k in range(2): # Assuming 2 weights in each contraction - model.products[i].symmetric_contractions.contractions[j].weights[k] = ( - torch.nn.Parameter( - model_foundations.products[i] - .symmetric_contractions.contractions[j] - .weights[k][indices_weights, :, :] - .clone() - ) + model.products[i].symmetric_contractions.contractions[j].weights[ + k + ] = torch.nn.Parameter( + model_foundations.products[i] + .symmetric_contractions.contractions[j] + .weights[k][indices_weights, :, :] + .clone() ) model.products[i].linear.weight = torch.nn.Parameter( From a9d06a2abdad2df728e2bd3dec24e121d4cc48e9 Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Wed, 25 Sep 2024 13:35:04 -0400 Subject: [PATCH 13/51] 
try/except import to remove dependency --- hydragnn/utils/mace_utils/data/hdf5_dataset.py | 6 +++++- hydragnn/utils/mace_utils/data/utils.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/hydragnn/utils/mace_utils/data/hdf5_dataset.py b/hydragnn/utils/mace_utils/data/hdf5_dataset.py index 12d0ff295..df741fa45 100644 --- a/hydragnn/utils/mace_utils/data/hdf5_dataset.py +++ b/hydragnn/utils/mace_utils/data/hdf5_dataset.py @@ -1,8 +1,12 @@ from glob import glob from typing import List -import h5py from torch.utils.data import ConcatDataset, Dataset +# Try import but pass otherwise +try: + import h5py +except ImportError: + pass from hydragnn.utils.mace_utils.data.atomic_data import AtomicData from hydragnn.utils.mace_utils.data.utils import Configuration diff --git a/hydragnn/utils/mace_utils/data/utils.py b/hydragnn/utils/mace_utils/data/utils.py index 4b008079b..8cc442cc6 100644 --- a/hydragnn/utils/mace_utils/data/utils.py +++ b/hydragnn/utils/mace_utils/data/utils.py @@ -10,8 +10,12 @@ import ase.data import ase.io -import h5py import numpy as np +# Try import but pass otherwise +try: + import h5py +except ImportError: + pass from hydragnn.utils.mace_utils.tools import AtomicNumberTable From 5a2da6fb785b356506acde4dcce0bd153b4a69c5 Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Wed, 25 Sep 2024 13:37:30 -0400 Subject: [PATCH 14/51] formatting --- hydragnn/utils/mace_utils/data/hdf5_dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/hydragnn/utils/mace_utils/data/hdf5_dataset.py b/hydragnn/utils/mace_utils/data/hdf5_dataset.py index df741fa45..affa6a8d5 100644 --- a/hydragnn/utils/mace_utils/data/hdf5_dataset.py +++ b/hydragnn/utils/mace_utils/data/hdf5_dataset.py @@ -2,6 +2,7 @@ from typing import List from torch.utils.data import ConcatDataset, Dataset + # Try import but pass otherwise try: import h5py From 4127d086cf740b7c6e0f40a4edceca09dc9c742e Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Wed, 25 Sep 2024 13:45:30 -0400 Subject: [PATCH 15/51] formatting --- hydragnn/utils/mace_utils/data/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/hydragnn/utils/mace_utils/data/utils.py b/hydragnn/utils/mace_utils/data/utils.py index 8cc442cc6..6458e7107 100644 --- a/hydragnn/utils/mace_utils/data/utils.py +++ b/hydragnn/utils/mace_utils/data/utils.py @@ -11,6 +11,7 @@ import ase.data import ase.io import numpy as np + # Try import but pass otherwise try: import h5py From f6c32cfe212baf261425b1974a98c12fa166b31a Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Wed, 25 Sep 2024 19:12:08 -0400 Subject: [PATCH 16/51] Add MACE to test --- tests/test_forces_equivariant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_forces_equivariant.py b/tests/test_forces_equivariant.py index 4609844c1..3f0b51218 100644 --- a/tests/test_forces_equivariant.py +++ b/tests/test_forces_equivariant.py @@ -16,7 +16,7 @@ @pytest.mark.parametrize("example", ["LennardJones"]) -@pytest.mark.parametrize("model_type", ["SchNet", "EGNN", "DimeNet", "PNAPlus"]) +@pytest.mark.parametrize("model_type", ["SchNet", "EGNN", "DimeNet", "PNAPlus", "MACE"]) @pytest.mark.mpi_skip() def pytest_examples(example, model_type): path = os.path.join(os.path.dirname(__file__), "..", "examples", example) From ebae4d0266ff279c69ef42bc64d2c89d50e7cfe3 Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Wed, 25 Sep 2024 19:14:23 -0400 Subject: [PATCH 17/51] testing --- examples/LennardJones.py | 327 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 327 
insertions(+) create mode 100644 examples/LennardJones.py diff --git a/examples/LennardJones.py b/examples/LennardJones.py new file mode 100644 index 000000000..2f4e57774 --- /dev/null +++ b/examples/LennardJones.py @@ -0,0 +1,327 @@ +############################################################################## +# Copyright (c) 2024, Oak Ridge National Laboratory # +# All rights reserved. # +# # +# This file is part of HydraGNN and is distributed under a BSD 3-clause # +# license. For the licensing terms see the LICENSE file in the top-level # +# directory. # +# # +# SPDX-License-Identifier: BSD-3-Clause # +############################################################################## + +# General +import os, json +import logging +import sys +import argparse + +# Torch +import torch + +# torch.set_default_tensor_type(torch.DoubleTensor) +# torch.set_default_dtype(torch.float64) + +# Distributed +import mpi4py +from mpi4py import MPI + +mpi4py.rc.thread_level = "serialized" +mpi4py.rc.threads = False + +# HydraGNN +import hydragnn +from hydragnn.utils.print_utils import log +from hydragnn.utils.time_utils import Timer +import hydragnn.utils.tracer as tr +from hydragnn.preprocess.load_data import split_dataset +from hydragnn.utils.distdataset import DistDataset +from hydragnn.utils.pickledataset import SimplePickleWriter, SimplePickleDataset +from hydragnn.preprocess.utils import gather_deg + +try: + from hydragnn.utils.adiosdataset import AdiosWriter, AdiosDataset +except ImportError: + pass + +# Lennard Jones +from LJ_data import create_dataset, LJDataset, info + + +################################################################################################################## + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("--sampling", type=float, help="sampling ratio", default=None) + parser.add_argument( + "--preonly", + action="store_true", + help="preprocess only (no training)", + ) + parser.add_argument("--inputfile", help="input file", type=str, default="LJ.json") + parser.add_argument("--model_type", help="model type", type=str, default=None) + parser.add_argument("--mae", action="store_true", help="do mae calculation") + parser.add_argument("--ddstore", action="store_true", help="ddstore dataset") + parser.add_argument("--ddstore_width", type=int, help="ddstore width", default=None) + parser.add_argument("--shmem", action="store_true", help="shmem") + parser.add_argument("--log", help="log name") + parser.add_argument("--batch_size", type=int, help="batch_size", default=None) + parser.add_argument("--everyone", action="store_true", help="gptimer") + + group = parser.add_mutually_exclusive_group() + group.add_argument( + "--adios", + help="Adios dataset", + action="store_const", + dest="format", + const="adios", + ) + group.add_argument( + "--pickle", + help="Pickle dataset", + action="store_const", + dest="format", + const="pickle", + ) + parser.set_defaults(format="pickle") # Changed this for my PC + args = parser.parse_args() + + graph_feature_names = ["total_energy"] + graph_feature_dims = [1] + node_feature_names = ["atomic_number", "potential", "forces"] + node_feature_dims = [1, 1, 3] + dirpwd = os.path.dirname(os.path.abspath(__file__)) + ################################################################################################################## + input_filename = os.path.join(dirpwd, args.inputfile) + 
################################################################################################################## + # Configurable run choices (JSON file that accompanies this example script). + with open(input_filename, "r") as f: + config = json.load(f) + config["NeuralNetwork"]["Architecture"]["model_type"] = ( + args.model_type + if args.model_type + else config["NeuralNetwork"]["Architecture"]["model_type"] + ) + verbosity = config["Verbosity"]["level"] + config["NeuralNetwork"]["Variables_of_interest"][ + "graph_feature_names" + ] = graph_feature_names + config["NeuralNetwork"]["Variables_of_interest"][ + "graph_feature_dims" + ] = graph_feature_dims + config["NeuralNetwork"]["Variables_of_interest"][ + "node_feature_names" + ] = node_feature_names + config["NeuralNetwork"]["Variables_of_interest"][ + "node_feature_dims" + ] = node_feature_dims + + if args.batch_size is not None: + config["NeuralNetwork"]["Training"]["batch_size"] = args.batch_size + + ################################################################################################################## + # Always initialize for multi-rank training. + comm_size, rank = hydragnn.utils.setup_ddp() + ################################################################################################################## + + comm = MPI.COMM_WORLD + + ## Set up logging + logging.basicConfig( + level=logging.INFO, + format="%%(levelname)s (rank %d): %%(message)s" % (rank), + datefmt="%H:%M:%S", + ) + + log_name = "LJ" if args.log is None else args.log + hydragnn.utils.setup_log(log_name) + writer = hydragnn.utils.get_summary_writer(log_name) + + log("Command: {0}\n".format(" ".join([x for x in sys.argv])), rank=0) + + modelname = "LJ" + # Check for dataset for each format + if args.format == "pickle": + basedir = os.path.join(dirpwd, "dataset", "%s.pickle" % modelname) + dataset_exists = os.path.exists(os.path.join(dirpwd, "dataset/LJ.pickle")) + if args.format == "adios": + fname = os.path.join(dirpwd, "./dataset/%s.bp" % modelname) + dataset_exists = os.path.exists( + os.path.join(dirpwd, "dataset", "%s.bp" % modelname) + ) + + # Create dataset if preonly specified or dataset does not exist + if not dataset_exists: + + ## local data + create_dataset(os.path.join(dirpwd, "dataset/data"), config) + total = LJDataset( + os.path.join(dirpwd, "dataset/data"), + config, + dist=True, + ) + ## This is a local split + trainset, valset, testset = split_dataset( + dataset=total, + perc_train=config["NeuralNetwork"]["Training"]["perc_train"], + stratify_splitting=False, + ) + print("Local splitting: ", len(total), len(trainset), len(valset), len(testset)) + + deg = gather_deg(trainset) + config["pna_deg"] = deg.tolist() + + setnames = ["trainset", "valset", "testset"] + + if args.format == "pickle": + + ## pickle + attrs = dict() + attrs["pna_deg"] = deg + SimplePickleWriter( + trainset, + basedir, + "trainset", + # minmax_node_feature=total.minmax_node_feature, + # minmax_graph_feature=total.minmax_graph_feature, + use_subdir=True, + attrs=attrs, + ) + SimplePickleWriter( + valset, + basedir, + "valset", + # minmax_node_feature=total.minmax_node_feature, + # minmax_graph_feature=total.minmax_graph_feature, + use_subdir=True, + ) + SimplePickleWriter( + testset, + basedir, + "testset", + # minmax_node_feature=total.minmax_node_feature, + # minmax_graph_feature=total.minmax_graph_feature, + use_subdir=True, + ) + + if args.format == "adios": + ## adios + adwriter = AdiosWriter(fname, comm) + adwriter.add("trainset", trainset) + 
adwriter.add("valset", valset) + adwriter.add("testset", testset) + # adwriter.add_global("minmax_node_feature", total.minmax_node_feature) + # adwriter.add_global("minmax_graph_feature", total.minmax_graph_feature) + adwriter.add_global("pna_deg", deg) + adwriter.save() + + tr.initialize() + tr.disable() + timer = Timer("load_data") + timer.start() + if args.format == "adios": + info("Adios load") + assert not (args.shmem and args.ddstore), "Cannot use both ddstore and shmem" + opt = { + "preload": False, + "shmem": args.shmem, + "ddstore": args.ddstore, + "ddstore_width": args.ddstore_width, + } + trainset = AdiosDataset(fname, "trainset", comm, **opt) + valset = AdiosDataset(fname, "valset", comm, **opt) + testset = AdiosDataset(fname, "testset", comm, **opt) + elif args.format == "pickle": + info("Pickle load") + var_config = config["NeuralNetwork"]["Variables_of_interest"] + trainset = SimplePickleDataset( + basedir=basedir, label="trainset", preload=True, var_config=var_config + ) + valset = SimplePickleDataset( + basedir=basedir, label="valset", var_config=var_config + ) + testset = SimplePickleDataset( + basedir=basedir, label="testset", var_config=var_config + ) + # minmax_node_feature = trainset.minmax_node_feature + # minmax_graph_feature = trainset.minmax_graph_feature + pna_deg = trainset.pna_deg + if args.ddstore: + opt = {"ddstore_width": args.ddstore_width} + trainset = DistDataset(trainset, "trainset", comm, **opt) + valset = DistDataset(valset, "valset", comm, **opt) + testset = DistDataset(testset, "testset", comm, **opt) + # trainset.minmax_node_feature = minmax_node_feature + # trainset.minmax_graph_feature = minmax_graph_feature + trainset.pna_deg = pna_deg + else: + raise NotImplementedError("No supported format: %s" % (args.format)) + + info( + "trainset,valset,testset size: %d %d %d" + % (len(trainset), len(valset), len(testset)) + ) + + if args.ddstore: + os.environ["HYDRAGNN_AGGR_BACKEND"] = "mpi" + os.environ["HYDRAGNN_USE_ddstore"] = "1" + + (train_loader, val_loader, test_loader,) = hydragnn.preprocess.create_dataloaders( + trainset, valset, testset, config["NeuralNetwork"]["Training"]["batch_size"] + ) + + config = hydragnn.utils.update_config(config, train_loader, val_loader, test_loader) + ## Good to sync with everyone right after DDStore setup + comm.Barrier() + + hydragnn.utils.save_config(config, log_name) + + timer.stop() + + model = hydragnn.models.create_model_config( + config=config["NeuralNetwork"], + verbosity=verbosity, + ) + model = hydragnn.utils.get_distributed_model(model, verbosity) + + learning_rate = config["NeuralNetwork"]["Training"]["Optimizer"]["learning_rate"] + optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.5, patience=5, min_lr=0.00001 + ) + + hydragnn.utils.load_existing_model_config( + model, config["NeuralNetwork"]["Training"], optimizer=optimizer + ) + + ################################################################################################################## + + hydragnn.train.train_validate_test( + model, + optimizer, + train_loader, + val_loader, + test_loader, + writer, + scheduler, + config["NeuralNetwork"], + log_name, + verbosity, + create_plots=False, + compute_grad_energy=True, + ) + + hydragnn.utils.save_model(model, optimizer, log_name) + hydragnn.utils.print_timers(verbosity) + + if tr.has("GPTLTracer"): + import gptl4py as gp + + eligible = rank if args.everyone else 0 + if rank == eligible: + 
gp.pr_file(os.path.join("logs", log_name, "gp_timing.p%d" % rank)) + gp.pr_summary_file(os.path.join("logs", log_name, "gp_timing.summary")) + gp.finalize() + sys.exit(0) From 1039592076711b4f32134c0cf3896ed7e4af26a8 Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Wed, 25 Sep 2024 19:15:03 -0400 Subject: [PATCH 18/51] revert separate attempt to get test_forces in --- examples/LJ.json | 75 +++++ examples/LJ_data.py | 504 +++++++++++++++++++++++++++++++ examples/LJ_inference_plots.py | 241 +++++++++++++++ examples/LennardJones.py | 327 -------------------- tests/test_forces_equivariant.py | 27 -- 5 files changed, 820 insertions(+), 354 deletions(-) create mode 100644 examples/LJ.json create mode 100644 examples/LJ_data.py create mode 100644 examples/LJ_inference_plots.py delete mode 100644 examples/LennardJones.py delete mode 100644 tests/test_forces_equivariant.py diff --git a/examples/LJ.json b/examples/LJ.json new file mode 100644 index 000000000..a6b18f12b --- /dev/null +++ b/examples/LJ.json @@ -0,0 +1,75 @@ +{ + "Verbosity": { + "level": 2 + }, + "Dataset": { + "name": "LJdataset", + "format": "XYZ", + "node_features": { + "name": ["atom_type"], + "dim": [1], + "column_index": [0] + }, + "graph_features":{ + "name": ["total_energy"], + "dim": [1], + "column_index": [0] + } + }, + "NeuralNetwork": { + "Architecture": { + "periodic_boundary_conditions": true, + "model_type": "DimeNet", + "radius": 5.0, + "max_neighbours": 5, + "int_emb_size": 32, + "out_emb_size": 16, + "basis_emb_size": 8, + "num_gaussians": 10, + "num_filters": 8, + "num_before_skip": 1, + "num_after_skip": 1, + "envelope_exponent": 5, + "num_radial": 5, + "num_spherical": 2, + "hidden_dim": 20, + "num_conv_layers": 4, + "output_heads": { + "node": { + "num_headlayers": 2, + "dim_headlayers": [60,20], + "type": "mlp" + } + }, + "task_weights": [1] + }, + "Variables_of_interest": { + "input_node_features": [0], + "output_index": [ + 0 + ], + "type": [ + "node" + ], + "output_dim": [1], + "output_names": ["graph_energy"] + }, + "Training": { + "num_epoch": 15, + "batch_size": 64, + "perc_train": 0.7, + "patience": 20, + "early_stopping": true, + "Optimizer": { + "type": "Adam", + "learning_rate": 0.005 + }, + "conv_checkpointing": false + } + }, + "Visualization": { + "plot_init_solution": true, + "plot_hist_solution": true, + "create_plots": true + } +} diff --git a/examples/LJ_data.py b/examples/LJ_data.py new file mode 100644 index 000000000..594d6d154 --- /dev/null +++ b/examples/LJ_data.py @@ -0,0 +1,504 @@ +############################################################################## +# Copyright (c) 2024, Oak Ridge National Laboratory # +# All rights reserved. # +# # +# This file is part of HydraGNN and is distributed under a BSD 3-clause # +# license. For the licensing terms see the LICENSE file in the top-level # +# directory. 
# +# # +# SPDX-License-Identifier: BSD-3-Clause # +############################################################################## + +# General +import os +import logging +import numpy + +numpy.set_printoptions(threshold=numpy.inf) +numpy.set_printoptions(linewidth=numpy.inf) + +# Torch +import torch +from torch_geometric.data import Data + +# torch.set_default_tensor_type(torch.DoubleTensor) +# torch.set_default_dtype(torch.float64) + +# Distributed +import mpi4py +from mpi4py import MPI + +mpi4py.rc.thread_level = "serialized" +mpi4py.rc.threads = False + +# HydraGNN +from hydragnn.utils.abstractrawdataset import AbstractBaseDataset +from hydragnn.utils import nsplit +from hydragnn.preprocess.utils import get_radius_graph_pbc + +# Angstrom unit +primitive_bravais_lattice_constant_x = 3.8 +primitive_bravais_lattice_constant_y = 3.8 +primitive_bravais_lattice_constant_z = 3.8 + + +################################################################################################################## + + +"""High-Level Function""" + + +def create_dataset(path, config): + radius_cutoff = config["NeuralNetwork"]["Architecture"]["radius"] + number_configurations = ( + config["Dataset"]["number_configurations"] + if "number_configurations" in config["Dataset"] + else 300 + ) + atom_types = [1] + formula = LJpotential(1.0, 3.4) + atomic_structure_handler = AtomicStructureHandler( + atom_types, + [ + primitive_bravais_lattice_constant_x, + primitive_bravais_lattice_constant_y, + primitive_bravais_lattice_constant_z, + ], + radius_cutoff, + formula, + ) + deterministic_graph_data( + path, + atom_types, + atomic_structure_handler=atomic_structure_handler, + radius_cutoff=radius_cutoff, + relative_maximum_atomic_displacement=1e-1, + number_configurations=number_configurations, + ) + + +"""Reading/Transforming Data""" + + +class LJDataset(AbstractBaseDataset): + """LJDataset dataset class""" + + def __init__(self, dirpath, config, dist=False, sampling=None): + super().__init__() + + self.dist = dist + self.world_size = 1 + self.rank = 1 + if self.dist: + assert torch.distributed.is_initialized() + self.world_size = torch.distributed.get_world_size() + self.rank = torch.distributed.get_rank() + + self.radius = config["NeuralNetwork"]["Architecture"]["radius"] + self.max_neighbours = config["NeuralNetwork"]["Architecture"]["max_neighbours"] + + dirfiles = sorted(os.listdir(dirpath)) + + rx = list(nsplit((dirfiles), self.world_size))[self.rank] + + for file in rx: + filepath = os.path.join(dirpath, file) + self.dataset.append(self.transform_input_to_data_object_base(filepath)) + + def transform_input_to_data_object_base(self, filepath): + + # Using readline() + file = open(filepath, "r") + + torch_data = torch.empty((0, 8), dtype=torch.float32) + torch_supercell = torch.zeros((0, 3), dtype=torch.float32) + + count = 0 + + while True: + count += 1 + + # Get next line from file + line = file.readline() + + # if line is empty + # end of file is reached + if not line: + break + + if count == 1: + total_energy = float(line) + elif count == 2: + energy_per_atom = float(line) + elif 2 < count < 6: + array_line = numpy.fromstring(line, dtype=float, sep="\t") + torch_supercell = torch.cat( + [torch_supercell, torch.from_numpy(array_line).unsqueeze(0)], axis=0 + ) + elif count > 5: + array_line = numpy.fromstring(line, dtype=float, sep="\t") + torch_data = torch.cat( + [torch_data, torch.from_numpy(array_line).unsqueeze(0)], axis=0 + ) + # print("Line{}: {}".format(count, line.strip())) + + file.close() + + num_nodes = 
torch_data.shape[0]
+
+        energy_pre_translation_factor = 0.0
+        energy_pre_scaling_factor = 1.0 / num_nodes
+        energy_per_atom_pretransformed = (
+            energy_per_atom - energy_pre_translation_factor
+        ) * energy_pre_scaling_factor
+        grad_energy_post_scaling_factor = (
+            1.0 / energy_pre_scaling_factor * torch.ones(num_nodes, 1)
+        )
+        forces = torch_data[:, [5, 6, 7]]
+        forces_pre_scaling_factor = 1.0
+        forces_pre_scaled = forces * forces_pre_scaling_factor
+
+        data = Data(
+            supercell_size=torch_supercell.to(torch.float32),
+            num_nodes=num_nodes,
+            grad_energy_post_scaling_factor=grad_energy_post_scaling_factor,
+            forces_pre_scaling_factor=torch.tensor(forces_pre_scaling_factor).to(
+                torch.float32
+            ),
+            forces=forces,
+            forces_pre_scaled=forces_pre_scaled,
+            pos=torch_data[:, [1, 2, 3]].to(torch.float32),
+            x=torch.cat([torch_data[:, [0, 4]]], axis=1).to(torch.float32),
+            y=torch.tensor(total_energy).unsqueeze(0).to(torch.float32),
+            energy_per_atom=torch.tensor(energy_per_atom_pretransformed)
+            .unsqueeze(0)
+            .to(torch.float32),
+            energy=torch.tensor(total_energy).unsqueeze(0).to(torch.float32),
+        )
+
+        # Create pbc edges and lengths
+        edge_creation = get_radius_graph_pbc(self.radius, self.max_neighbours)
+        data = edge_creation(data)
+
+        return data
+
+    def len(self):
+        return len(self.dataset)
+
+    def get(self, idx):
+        return self.dataset[idx]
+
+
+"""Create Data"""
+
+
+def deterministic_graph_data(
+    path: str,
+    atom_types: list,
+    atomic_structure_handler,
+    radius_cutoff=float("inf"),
+    max_num_neighbors=float("inf"),
+    number_configurations: int = 500,
+    configuration_start: int = 0,
+    unit_cell_x_range: list = [3, 4],
+    unit_cell_y_range: list = [3, 4],
+    unit_cell_z_range: list = [3, 4],
+    relative_maximum_atomic_displacement: float = 1e-1,
+):
+
+    comm = MPI.COMM_WORLD
+    comm_size = comm.Get_size()
+    comm_rank = comm.Get_rank()
+    torch.manual_seed(comm_rank)
+
+    if 0 == comm_rank:
+        os.makedirs(path, exist_ok=False)
+    comm.Barrier()
+
+    # We assume that the unit cell is Simple Cubic (SC)
+    unit_cell_x = torch.randint(
+        unit_cell_x_range[0],
+        unit_cell_x_range[1],
+        (number_configurations,),
+    )
+    unit_cell_y = torch.randint(
+        unit_cell_y_range[0],
+        unit_cell_y_range[1],
+        (number_configurations,),
+    )
+    unit_cell_z = torch.randint(
+        unit_cell_z_range[0],
+        unit_cell_z_range[1],
+        (number_configurations,),
+    )
+
+    configurations_list = range(number_configurations)
+    rx = list(nsplit(configurations_list, comm_size))[comm_rank]
+
+    for configuration in configurations_list[rx.start : rx.stop]:
+        uc_x = unit_cell_x[configuration]
+        uc_y = unit_cell_y[configuration]
+        uc_z = unit_cell_z[configuration]
+        create_configuration(
+            path,
+            atomic_structure_handler,
+            configuration,
+            configuration_start,
+            uc_x,
+            uc_y,
+            uc_z,
+            atom_types,
+            radius_cutoff,
+            max_num_neighbors,
+            relative_maximum_atomic_displacement,
+        )
+
+
+def create_configuration(
+    path,
+    atomic_structure_handler,
+    configuration,
+    configuration_start,
+    uc_x,
+    uc_y,
+    uc_z,
+    types,
+    radius_cutoff,
+    max_num_neighbors,
+    relative_maximum_atomic_displacement,
+):
+    ###############################################################################################
+    ################################### STRUCTURE OF THE DATA ##################################
+    ###############################################################################################
+
+    # GLOBAL_OUTPUT1
+    # GLOBAL_OUTPUT2
+    # NODE1_FEATURE NODE1_INDEX NODE1_COORDINATE_X NODE1_COORDINATE_Y NODE1_COORDINATE_Z NODAL_OUTPUT1 NODAL_OUTPUT2 NODAL_OUTPUT3
+    # NODE2_FEATURE NODE2_INDEX NODE2_COORDINATE_X NODE2_COORDINATE_Y NODE2_COORDINATE_Z NODAL_OUTPUT1 NODAL_OUTPUT2 NODAL_OUTPUT3
+    # ...
+    # NODEn_FEATURE NODEn_INDEX NODEn_COORDINATE_X NODEn_COORDINATE_Y NODEn_COORDINATE_Z NODAL_OUTPUT1 NODAL_OUTPUT2 NODAL_OUTPUT3
+
+    ###############################################################################################
+    ################################# FORMULAS FOR NODAL FEATURE ###############################
+    ###############################################################################################
+
+    # NODAL_FEATURE = ATOM SPECIES
+
+    ###############################################################################################
+    ########################## FORMULAS FOR GLOBAL AND NODAL OUTPUTS ###########################
+    ###############################################################################################
+
+    # GLOBAL_OUTPUT = TOTAL ENERGY
+    # GLOBAL_OUTPUT = TOTAL ENERGY / NUMBER OF NODES
+    # NODAL_OUTPUT1(X) = FORCE ACTING ON ATOM IN X DIRECTION
+    # NODAL_OUTPUT2(X) = FORCE ACTING ON ATOM IN Y DIRECTION
+    # NODAL_OUTPUT3(X) = FORCE ACTING ON ATOM IN Z DIRECTION
+
+    ###############################################################################################
+    count_pos = 0
+    number_nodes = uc_x * uc_y * uc_z
+    positions = torch.zeros(number_nodes, 3)
+    for x in range(uc_x):
+        for y in range(uc_y):
+            for z in range(uc_z):
+                positions[count_pos][0] = (
+                    x
+                    + relative_maximum_atomic_displacement
+                    * ((torch.rand(1, 1).item()) - 0.5)
+                ) * primitive_bravais_lattice_constant_x
+                positions[count_pos][1] = (
+                    y
+                    + relative_maximum_atomic_displacement
+                    * ((torch.rand(1, 1).item()) - 0.5)
+                ) * primitive_bravais_lattice_constant_y
+                positions[count_pos][2] = (
+                    z
+                    + relative_maximum_atomic_displacement
+                    * ((torch.rand(1, 1).item()) - 0.5)
+                ) * primitive_bravais_lattice_constant_z
+
+                count_pos = count_pos + 1
+
+    atom_types = torch.randint(min(types), max(types) + 1, (number_nodes, 1))
+
+    data = Data()
+
+    data.pos = positions
+    supercell_size_x = primitive_bravais_lattice_constant_x * uc_x
+    supercell_size_y = primitive_bravais_lattice_constant_y * uc_y
+    supercell_size_z = primitive_bravais_lattice_constant_z * uc_z
+    data.supercell_size = torch.diag(
+        torch.tensor([supercell_size_x, supercell_size_y, supercell_size_z])
+    )
+
+    create_graph_connectivity_pbc = get_radius_graph_pbc(
+        radius_cutoff, max_num_neighbors
+    )
+    data = create_graph_connectivity_pbc(data)
+
+    atomic_descriptors = torch.cat(
+        (
+            atom_types,
+            positions,
+        ),
+        1,
+    )
+
+    data.x = atomic_descriptors
+
+    data = atomic_structure_handler.compute(data)
+
+    total_energy = torch.sum(data.x[:, 4])
+    energy_per_atom = total_energy / number_nodes
+
+    total_energy_str = numpy.array2string(total_energy.detach().numpy())
+    energy_per_atom_str = numpy.array2string(energy_per_atom.detach().numpy())
+    filetxt = total_energy_str + "\n" + energy_per_atom_str
+
+    for index in range(0, 3):
+        numpy_row = data.supercell_size[index, :].detach().numpy()
+        numpy_string_row = numpy.array2string(numpy_row, precision=64, separator="\t")
+        filetxt += "\n" + numpy_string_row.lstrip("[").rstrip("]")
+
+    for index in range(0, number_nodes):
+        numpy_row = data.x[index, :].detach().numpy()
+        numpy_string_row = numpy.array2string(numpy_row, precision=64, separator="\t")
+        filetxt += "\n" + numpy_string_row.lstrip("[").rstrip("]")
+
+    filename = os.path.join(
+        path, "output" + str(configuration + configuration_start) + ".txt"
+    )
+    with open(filename, "w") as f:
+        f.write(filetxt)
+
+
+"""Function Calculation"""
+
+
+class AtomicStructureHandler:
+    def __init__(
+        self, list_atom_types, bravais_lattice_constants, radius_cutoff, formula
+    ):
+
+        self.bravais_lattice_constants = bravais_lattice_constants
+        self.radius_cutoff = radius_cutoff
+        self.formula = formula
+
+    def compute(self, data):
+
+        assert data.pos.shape[0] == data.x.shape[0]
+
+        interatomic_potential = torch.zeros([data.pos.shape[0], 1])
+        interatomic_forces = torch.zeros([data.pos.shape[0], 3])
+
+        for node_id in range(data.pos.shape[0]):
+
+            neighbor_list_indices = torch.where(data.edge_index[0, :] == node_id)[
+                0
+            ].tolist()
+            neighbor_list = data.edge_index[1, neighbor_list_indices]
+
+            for neighbor_id, edge_id in zip(neighbor_list, neighbor_list_indices):
+
+                neighbor_pos = data.pos[neighbor_id, :]
+                distance_vector = data.pos[neighbor_id, :] - data.pos[node_id, :]
+
+                # Adjust the neighbor position based on periodic boundary conditions (PBC)
+                ## If the distance between the atoms is larger than the cutoff radius, the edge is because of PBC conditions
+                if torch.norm(distance_vector) > self.radius_cutoff:
+                    ## At this point, we know that the edge is due to PBC conditions, so we need to adjust the neighbor position. We also know
+                    ## that this connection MUST be the closest connection possible, as a result of the asserted radius_cutoff < supercell_size
+                    ## earlier in the code. Because of this, we can simply adjust the neighbor position coordinate-wise, as done in the
+                    ## following lines of code. The logic is that if distance_vector[index] is larger than half the supercell size,
+                    ## then there is a closer image at +- supercell_size[index], and we adjust to that for each coordinate
+                    if abs(distance_vector[0]) > data.supercell_size[0, 0] / 2:
+                        if distance_vector[0] > 0:
+                            neighbor_pos[0] -= data.supercell_size[0, 0]
+                        else:
+                            neighbor_pos[0] += data.supercell_size[0, 0]
+
+                    if abs(distance_vector[1]) > data.supercell_size[1, 1] / 2:
+                        if distance_vector[1] > 0:
+                            neighbor_pos[1] -= data.supercell_size[1, 1]
+                        else:
+                            neighbor_pos[1] += data.supercell_size[1, 1]
+
+                    if abs(distance_vector[2]) > data.supercell_size[2, 2] / 2:
+                        if distance_vector[2] > 0:
+                            neighbor_pos[2] -= data.supercell_size[2, 2]
+                        else:
+                            neighbor_pos[2] += data.supercell_size[2, 2]
+
+                # The distance vector may need to be updated after applying PBCs
+                distance_vector = data.pos[node_id, :] - neighbor_pos
+
+                # pair_distance = data.edge_attr[edge_id].item()
+                interatomic_potential[node_id] += self.formula.potential_energy(
+                    distance_vector
+                )
+
+                derivative_x = self.formula.derivative_x(distance_vector)
+                derivative_y = self.formula.derivative_y(distance_vector)
+                derivative_z = self.formula.derivative_z(distance_vector)
+
+                interatomic_forces_contribution_x = -derivative_x
+                interatomic_forces_contribution_y = -derivative_y
+                interatomic_forces_contribution_z = -derivative_z
+
+                interatomic_forces[node_id, 0] += interatomic_forces_contribution_x
+                interatomic_forces[node_id, 1] += interatomic_forces_contribution_y
+                interatomic_forces[node_id, 2] += interatomic_forces_contribution_z
+
+        data.x = torch.cat(
+            (data.x, interatomic_potential, interatomic_forces),
+            1,
+        )
+
+        return data
+
+
+class LJpotential:
+    def __init__(self, epsilon, sigma):
+        self.epsilon = epsilon
+        self.sigma = sigma
+
+    def potential_energy(self, distance_vector):
+        pair_distance = torch.norm(distance_vector)
+        return (
+            4
+            * self.epsilon
+            * ((self.sigma / pair_distance) ** 12 - (self.sigma 
/ pair_distance) ** 6) + ) + + def radial_derivative(self, distance_vector): + pair_distance = torch.norm(distance_vector) + return ( + 4 + * self.epsilon + * ( + -12 * (self.sigma / pair_distance) ** 12 * 1 / pair_distance + + 6 * (self.sigma / pair_distance) ** 6 * 1 / pair_distance + ) + ) + + def derivative_x(self, distance_vector): + pair_distance = torch.norm(distance_vector) + radial_derivative = self.radial_derivative(pair_distance) + return radial_derivative * (distance_vector[0].item()) / pair_distance + + def derivative_y(self, distance_vector): + pair_distance = torch.norm(distance_vector) + radial_derivative = self.radial_derivative(pair_distance) + return radial_derivative * (distance_vector[1].item()) / pair_distance + + def derivative_z(self, distance_vector): + pair_distance = torch.norm(distance_vector) + radial_derivative = self.radial_derivative(pair_distance) + return radial_derivative * (distance_vector[2].item()) / pair_distance + + +"""Etc""" + + +def info(*args, logtype="info", sep=" "): + getattr(logging, logtype)(sep.join(map(str, args))) diff --git a/examples/LJ_inference_plots.py b/examples/LJ_inference_plots.py new file mode 100644 index 000000000..324da425f --- /dev/null +++ b/examples/LJ_inference_plots.py @@ -0,0 +1,241 @@ +############################################################################## +# Copyright (c) 2024, Oak Ridge National Laboratory # +# All rights reserved. # +# # +# This file is part of HydraGNN and is distributed under a BSD 3-clause # +# license. For the licensing terms see the LICENSE file in the top-level # +# directory. # +# # +# SPDX-License-Identifier: BSD-3-Clause # +############################################################################## + +import json, os +import sys +import logging +import pickle +from tqdm import tqdm +from mpi4py import MPI +import argparse + +import torch +import torch_scatter +import numpy as np + +import hydragnn +from hydragnn.utils.time_utils import Timer +from hydragnn.utils.distributed import get_device +from hydragnn.utils.model import load_existing_model +from hydragnn.utils.pickledataset import SimplePickleDataset +from hydragnn.utils.config_utils import ( + update_config, +) +from hydragnn.models.create import create_model_config +from hydragnn.preprocess import create_dataloaders + +from scipy.interpolate import griddata + +try: + from hydragnn.utils.adiosdataset import AdiosWriter, AdiosDataset +except ImportError: + pass + +from LJ_data import info + +import matplotlib.pyplot as plt + +plt.rcParams.update({"font.size": 16}) + + +def get_log_name_config(config): + return ( + config["NeuralNetwork"]["Architecture"]["model_type"] + + "-r-" + + str(config["NeuralNetwork"]["Architecture"]["radius"]) + + "-ncl-" + + str(config["NeuralNetwork"]["Architecture"]["num_conv_layers"]) + + "-hd-" + + str(config["NeuralNetwork"]["Architecture"]["hidden_dim"]) + + "-ne-" + + str(config["NeuralNetwork"]["Training"]["num_epoch"]) + + "-lr-" + + str(config["NeuralNetwork"]["Training"]["Optimizer"]["learning_rate"]) + + "-bs-" + + str(config["NeuralNetwork"]["Training"]["batch_size"]) + + "-node_ft-" + + "".join( + str(x) + for x in config["NeuralNetwork"]["Variables_of_interest"][ + "input_node_features" + ] + ) + + "-task_weights-" + + "".join( + str(weigh) + "-" + for weigh in config["NeuralNetwork"]["Architecture"]["task_weights"] + ) + ) + + +def getcolordensity(xdata, ydata): + ############################### + nbin = 20 + hist2d, xbins_edge, ybins_edge = np.histogram2d(x=xdata, y=ydata, 
bins=[nbin, nbin])
+    xbin_cen = 0.5 * (xbins_edge[0:-1] + xbins_edge[1:])
+    ybin_cen = 0.5 * (ybins_edge[0:-1] + ybins_edge[1:])
+    BCTY, BCTX = np.meshgrid(ybin_cen, xbin_cen)
+    hist2d = hist2d / np.amax(hist2d)
+    print(np.amax(hist2d))
+
+    bctx1d = np.reshape(BCTX, len(xbin_cen) * nbin)
+    bcty1d = np.reshape(BCTY, len(xbin_cen) * nbin)
+    loc_pts = np.zeros((len(xbin_cen) * nbin, 2))
+    loc_pts[:, 0] = bctx1d
+    loc_pts[:, 1] = bcty1d
+    hist2d_norm = griddata(
+        loc_pts,
+        hist2d.reshape(len(xbin_cen) * nbin),
+        (xdata, ydata),
+        method="linear",
+        fill_value=0,
+    )  # np.nan)
+    return hist2d_norm
+
+
+if __name__ == "__main__":
+
+    modelname = "LJ"
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--inputfile", help="input file", type=str, default="./logs/LJ/config.json"
+    )
+    group = parser.add_mutually_exclusive_group()
+    group.add_argument(
+        "--adios",
+        help="Adios dataset",
+        action="store_const",
+        dest="format",
+        const="adios",
+    )
+    group.add_argument(
+        "--pickle",
+        help="Pickle dataset",
+        action="store_const",
+        dest="format",
+        const="pickle",
+    )
+    parser.set_defaults(format="pickle")
+
+    args = parser.parse_args()
+
+    dirpwd = os.path.dirname(os.path.abspath(__file__))
+    input_filename = os.path.join(dirpwd, args.inputfile)
+    with open(input_filename, "r") as f:
+        config = json.load(f)
+    hydragnn.utils.setup_log(get_log_name_config(config))
+    ##################################################################################################################
+    # Always initialize for multi-rank training.
+    comm_size, rank = hydragnn.utils.setup_ddp()
+    ##################################################################################################################
+    comm = MPI.COMM_WORLD
+
+    datasetname = "LJ"
+
+    comm.Barrier()
+
+    timer = Timer("load_data")
+    timer.start()
+    if args.format == "pickle":
+        info("Pickle load")
+        basedir = os.path.join(
+            os.path.dirname(__file__), "dataset", "%s.pickle" % modelname
+        )
+        trainset = SimplePickleDataset(
+            basedir=basedir,
+            label="trainset",
+            var_config=config["NeuralNetwork"]["Variables_of_interest"],
+        )
+        valset = SimplePickleDataset(
+            basedir=basedir,
+            label="valset",
+            var_config=config["NeuralNetwork"]["Variables_of_interest"],
+        )
+        testset = SimplePickleDataset(
+            basedir=basedir,
+            label="testset",
+            var_config=config["NeuralNetwork"]["Variables_of_interest"],
+        )
+        pna_deg = trainset.pna_deg
+    else:
+        raise NotImplementedError("No supported format: %s" % (args.format))
+
+    model = create_model_config(
+        config=config["NeuralNetwork"],
+        verbosity=config["Verbosity"]["level"],
+    )
+
+    model = torch.nn.parallel.DistributedDataParallel(model)
+
+    load_existing_model(model, modelname, path="./logs/")
+    model.eval()
+
+    variable_index = 0
+    # for output_name, output_type, output_dim in zip(config["NeuralNetwork"]["Variables_of_interest"]["output_names"], config["NeuralNetwork"]["Variables_of_interest"]["type"], config["NeuralNetwork"]["Variables_of_interest"]["output_dim"]):
+
+    test_MAE = 0.0
+
+    num_samples = len(testset)
+    energy_true_list = []
+    energy_pred_list = []
+    forces_true_list = []
+    forces_pred_list = []
+
+    for data_id, data in enumerate(tqdm(testset)):
+        data.pos.requires_grad = True
+        node_energy_pred = model(data.to(get_device()))[
+            0
+        ]  # Note that this is sensitive to energy and forces prediction being single-task (current requirement)
+        energy_pred = torch.sum(node_energy_pred, dim=0).float()
+        test_MAE += torch.norm(energy_pred - data.energy, p=1).item() / len(testset)
+        # 
predicted.backward(retain_graph=True) + # gradients = data.pos.grad + grads_energy = torch.autograd.grad( + outputs=energy_pred, + inputs=data.pos, + grad_outputs=torch.ones_like(energy_pred), + retain_graph=False, + create_graph=True, + )[0] + energy_pred_list.extend(energy_pred.tolist()) + energy_true_list.extend(data.energy.tolist()) + forces_pred_list.extend((-grads_energy).flatten().tolist()) + forces_true_list.extend(data.forces.flatten().tolist()) + + hist2d_norm = getcolordensity(energy_true_list, energy_pred_list) + + fig, ax = plt.subplots() + plt.scatter(energy_true_list, energy_pred_list, s=8, c=hist2d_norm, vmin=0, vmax=1) + plt.clim(0, 1) + ax.plot(ax.get_xlim(), ax.get_xlim(), ls="--", color="red") + plt.colorbar() + plt.xlabel("True values") + plt.ylabel("Predicted values") + plt.title(f"energy") + plt.draw() + plt.tight_layout() + plt.savefig(f"./energy_Scatterplot" + ".png", dpi=400) + + print(f"Test MAE energy: ", test_MAE) + + hist2d_norm = getcolordensity(forces_pred_list, forces_true_list) + fig, ax = plt.subplots() + plt.scatter(forces_pred_list, forces_true_list, s=8, c=hist2d_norm, vmin=0, vmax=1) + plt.clim(0, 1) + ax.plot(ax.get_xlim(), ax.get_xlim(), ls="--", color="red") + plt.colorbar() + plt.xlabel("Predicted Values") + plt.ylabel("True Values") + plt.title("Forces") + plt.draw() + plt.tight_layout() + plt.savefig(f"./Forces_Scatterplot" + ".png", dpi=400) diff --git a/examples/LennardJones.py b/examples/LennardJones.py deleted file mode 100644 index 2f4e57774..000000000 --- a/examples/LennardJones.py +++ /dev/null @@ -1,327 +0,0 @@ -############################################################################## -# Copyright (c) 2024, Oak Ridge National Laboratory # -# All rights reserved. # -# # -# This file is part of HydraGNN and is distributed under a BSD 3-clause # -# license. For the licensing terms see the LICENSE file in the top-level # -# directory. 
# -# # -# SPDX-License-Identifier: BSD-3-Clause # -############################################################################## - -# General -import os, json -import logging -import sys -import argparse - -# Torch -import torch - -# torch.set_default_tensor_type(torch.DoubleTensor) -# torch.set_default_dtype(torch.float64) - -# Distributed -import mpi4py -from mpi4py import MPI - -mpi4py.rc.thread_level = "serialized" -mpi4py.rc.threads = False - -# HydraGNN -import hydragnn -from hydragnn.utils.print_utils import log -from hydragnn.utils.time_utils import Timer -import hydragnn.utils.tracer as tr -from hydragnn.preprocess.load_data import split_dataset -from hydragnn.utils.distdataset import DistDataset -from hydragnn.utils.pickledataset import SimplePickleWriter, SimplePickleDataset -from hydragnn.preprocess.utils import gather_deg - -try: - from hydragnn.utils.adiosdataset import AdiosWriter, AdiosDataset -except ImportError: - pass - -# Lennard Jones -from LJ_data import create_dataset, LJDataset, info - - -################################################################################################################## - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - parser.add_argument("--sampling", type=float, help="sampling ratio", default=None) - parser.add_argument( - "--preonly", - action="store_true", - help="preprocess only (no training)", - ) - parser.add_argument("--inputfile", help="input file", type=str, default="LJ.json") - parser.add_argument("--model_type", help="model type", type=str, default=None) - parser.add_argument("--mae", action="store_true", help="do mae calculation") - parser.add_argument("--ddstore", action="store_true", help="ddstore dataset") - parser.add_argument("--ddstore_width", type=int, help="ddstore width", default=None) - parser.add_argument("--shmem", action="store_true", help="shmem") - parser.add_argument("--log", help="log name") - parser.add_argument("--batch_size", type=int, help="batch_size", default=None) - parser.add_argument("--everyone", action="store_true", help="gptimer") - - group = parser.add_mutually_exclusive_group() - group.add_argument( - "--adios", - help="Adios dataset", - action="store_const", - dest="format", - const="adios", - ) - group.add_argument( - "--pickle", - help="Pickle dataset", - action="store_const", - dest="format", - const="pickle", - ) - parser.set_defaults(format="pickle") # Changed this for my PC - args = parser.parse_args() - - graph_feature_names = ["total_energy"] - graph_feature_dims = [1] - node_feature_names = ["atomic_number", "potential", "forces"] - node_feature_dims = [1, 1, 3] - dirpwd = os.path.dirname(os.path.abspath(__file__)) - ################################################################################################################## - input_filename = os.path.join(dirpwd, args.inputfile) - ################################################################################################################## - # Configurable run choices (JSON file that accompanies this example script). 
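The `# Configurable run choices` step above (part of the example driver being reverted here) distills to a small pattern that every script in this series relies on: read the JSON config once, then let the `--model_type` CLI flag win over the value stored in the file. A minimal, self-contained sketch of that pattern, assuming an `LJ.json` with the nested `NeuralNetwork`/`Architecture`/`model_type` keys used throughout these patches (the helper name below is illustrative, not part of the patch):

```python
import argparse
import json


def load_config_with_override(inputfile, model_type=None):
    """Load the example's JSON config; a CLI-provided model_type wins over the file."""
    with open(inputfile, "r") as f:
        config = json.load(f)
    arch = config["NeuralNetwork"]["Architecture"]
    # The CLI flag takes precedence over the JSON default (e.g. MACE over DimeNet).
    arch["model_type"] = model_type if model_type else arch["model_type"]
    return config


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--inputfile", type=str, default="LJ.json")
    parser.add_argument("--model_type", type=str, default=None)
    args = parser.parse_args()
    config = load_config_with_override(args.inputfile, args.model_type)
    print(config["NeuralNetwork"]["Architecture"]["model_type"])
```

This hook is exactly what the (briefly added) `tests/test_forces_equivariant.py` exercises when it calls `subprocess.call(["python", file_path, "--model_type", model_type])` to sweep `SchNet`, `EGNN`, `DimeNet`, `PNAPlus`, and `MACE` over the same example script.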
- with open(input_filename, "r") as f: - config = json.load(f) - config["NeuralNetwork"]["Architecture"]["model_type"] = ( - args.model_type - if args.model_type - else config["NeuralNetwork"]["Architecture"]["model_type"] - ) - verbosity = config["Verbosity"]["level"] - config["NeuralNetwork"]["Variables_of_interest"][ - "graph_feature_names" - ] = graph_feature_names - config["NeuralNetwork"]["Variables_of_interest"][ - "graph_feature_dims" - ] = graph_feature_dims - config["NeuralNetwork"]["Variables_of_interest"][ - "node_feature_names" - ] = node_feature_names - config["NeuralNetwork"]["Variables_of_interest"][ - "node_feature_dims" - ] = node_feature_dims - - if args.batch_size is not None: - config["NeuralNetwork"]["Training"]["batch_size"] = args.batch_size - - ################################################################################################################## - # Always initialize for multi-rank training. - comm_size, rank = hydragnn.utils.setup_ddp() - ################################################################################################################## - - comm = MPI.COMM_WORLD - - ## Set up logging - logging.basicConfig( - level=logging.INFO, - format="%%(levelname)s (rank %d): %%(message)s" % (rank), - datefmt="%H:%M:%S", - ) - - log_name = "LJ" if args.log is None else args.log - hydragnn.utils.setup_log(log_name) - writer = hydragnn.utils.get_summary_writer(log_name) - - log("Command: {0}\n".format(" ".join([x for x in sys.argv])), rank=0) - - modelname = "LJ" - # Check for dataset for each format - if args.format == "pickle": - basedir = os.path.join(dirpwd, "dataset", "%s.pickle" % modelname) - dataset_exists = os.path.exists(os.path.join(dirpwd, "dataset/LJ.pickle")) - if args.format == "adios": - fname = os.path.join(dirpwd, "./dataset/%s.bp" % modelname) - dataset_exists = os.path.exists( - os.path.join(dirpwd, "dataset", "%s.bp" % modelname) - ) - - # Create dataset if preonly specified or dataset does not exist - if not dataset_exists: - - ## local data - create_dataset(os.path.join(dirpwd, "dataset/data"), config) - total = LJDataset( - os.path.join(dirpwd, "dataset/data"), - config, - dist=True, - ) - ## This is a local split - trainset, valset, testset = split_dataset( - dataset=total, - perc_train=config["NeuralNetwork"]["Training"]["perc_train"], - stratify_splitting=False, - ) - print("Local splitting: ", len(total), len(trainset), len(valset), len(testset)) - - deg = gather_deg(trainset) - config["pna_deg"] = deg.tolist() - - setnames = ["trainset", "valset", "testset"] - - if args.format == "pickle": - - ## pickle - attrs = dict() - attrs["pna_deg"] = deg - SimplePickleWriter( - trainset, - basedir, - "trainset", - # minmax_node_feature=total.minmax_node_feature, - # minmax_graph_feature=total.minmax_graph_feature, - use_subdir=True, - attrs=attrs, - ) - SimplePickleWriter( - valset, - basedir, - "valset", - # minmax_node_feature=total.minmax_node_feature, - # minmax_graph_feature=total.minmax_graph_feature, - use_subdir=True, - ) - SimplePickleWriter( - testset, - basedir, - "testset", - # minmax_node_feature=total.minmax_node_feature, - # minmax_graph_feature=total.minmax_graph_feature, - use_subdir=True, - ) - - if args.format == "adios": - ## adios - adwriter = AdiosWriter(fname, comm) - adwriter.add("trainset", trainset) - adwriter.add("valset", valset) - adwriter.add("testset", testset) - # adwriter.add_global("minmax_node_feature", total.minmax_node_feature) - # adwriter.add_global("minmax_graph_feature", 
total.minmax_graph_feature) - adwriter.add_global("pna_deg", deg) - adwriter.save() - - tr.initialize() - tr.disable() - timer = Timer("load_data") - timer.start() - if args.format == "adios": - info("Adios load") - assert not (args.shmem and args.ddstore), "Cannot use both ddstore and shmem" - opt = { - "preload": False, - "shmem": args.shmem, - "ddstore": args.ddstore, - "ddstore_width": args.ddstore_width, - } - trainset = AdiosDataset(fname, "trainset", comm, **opt) - valset = AdiosDataset(fname, "valset", comm, **opt) - testset = AdiosDataset(fname, "testset", comm, **opt) - elif args.format == "pickle": - info("Pickle load") - var_config = config["NeuralNetwork"]["Variables_of_interest"] - trainset = SimplePickleDataset( - basedir=basedir, label="trainset", preload=True, var_config=var_config - ) - valset = SimplePickleDataset( - basedir=basedir, label="valset", var_config=var_config - ) - testset = SimplePickleDataset( - basedir=basedir, label="testset", var_config=var_config - ) - # minmax_node_feature = trainset.minmax_node_feature - # minmax_graph_feature = trainset.minmax_graph_feature - pna_deg = trainset.pna_deg - if args.ddstore: - opt = {"ddstore_width": args.ddstore_width} - trainset = DistDataset(trainset, "trainset", comm, **opt) - valset = DistDataset(valset, "valset", comm, **opt) - testset = DistDataset(testset, "testset", comm, **opt) - # trainset.minmax_node_feature = minmax_node_feature - # trainset.minmax_graph_feature = minmax_graph_feature - trainset.pna_deg = pna_deg - else: - raise NotImplementedError("No supported format: %s" % (args.format)) - - info( - "trainset,valset,testset size: %d %d %d" - % (len(trainset), len(valset), len(testset)) - ) - - if args.ddstore: - os.environ["HYDRAGNN_AGGR_BACKEND"] = "mpi" - os.environ["HYDRAGNN_USE_ddstore"] = "1" - - (train_loader, val_loader, test_loader,) = hydragnn.preprocess.create_dataloaders( - trainset, valset, testset, config["NeuralNetwork"]["Training"]["batch_size"] - ) - - config = hydragnn.utils.update_config(config, train_loader, val_loader, test_loader) - ## Good to sync with everyone right after DDStore setup - comm.Barrier() - - hydragnn.utils.save_config(config, log_name) - - timer.stop() - - model = hydragnn.models.create_model_config( - config=config["NeuralNetwork"], - verbosity=verbosity, - ) - model = hydragnn.utils.get_distributed_model(model, verbosity) - - learning_rate = config["NeuralNetwork"]["Training"]["Optimizer"]["learning_rate"] - optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate) - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer, mode="min", factor=0.5, patience=5, min_lr=0.00001 - ) - - hydragnn.utils.load_existing_model_config( - model, config["NeuralNetwork"]["Training"], optimizer=optimizer - ) - - ################################################################################################################## - - hydragnn.train.train_validate_test( - model, - optimizer, - train_loader, - val_loader, - test_loader, - writer, - scheduler, - config["NeuralNetwork"], - log_name, - verbosity, - create_plots=False, - compute_grad_energy=True, - ) - - hydragnn.utils.save_model(model, optimizer, log_name) - hydragnn.utils.print_timers(verbosity) - - if tr.has("GPTLTracer"): - import gptl4py as gp - - eligible = rank if args.everyone else 0 - if rank == eligible: - gp.pr_file(os.path.join("logs", log_name, "gp_timing.p%d" % rank)) - gp.pr_summary_file(os.path.join("logs", log_name, "gp_timing.summary")) - gp.finalize() - sys.exit(0) diff --git 
a/tests/test_forces_equivariant.py b/tests/test_forces_equivariant.py deleted file mode 100644 index 3f0b51218..000000000 --- a/tests/test_forces_equivariant.py +++ /dev/null @@ -1,27 +0,0 @@ -############################################################################## -# Copyright (c) 2024, Oak Ridge National Laboratory # -# All rights reserved. # -# # -# This file is part of HydraGNN and is distributed under a BSD 3-clause # -# license. For the licensing terms see the LICENSE file in the top-level # -# directory. # -# # -# SPDX-License-Identifier: BSD-3-Clause # -############################################################################## - -import os -import pytest - -import subprocess - - -@pytest.mark.parametrize("example", ["LennardJones"]) -@pytest.mark.parametrize("model_type", ["SchNet", "EGNN", "DimeNet", "PNAPlus", "MACE"]) -@pytest.mark.mpi_skip() -def pytest_examples(example, model_type): - path = os.path.join(os.path.dirname(__file__), "..", "examples", example) - file_path = os.path.join(path, example + ".py") # Assuming different model scripts - return_code = subprocess.call(["python", file_path, "--model_type", model_type]) - - # Check the file ran without error. - assert return_code == 0 From 06c539a94586f711166a90a014c00f68270fd2ce Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Wed, 25 Sep 2024 19:16:33 -0400 Subject: [PATCH 19/51] revert separate attempt to get test_forces in --- examples/LJ.json | 75 ----- examples/LJ_data.py | 504 --------------------------------- examples/LJ_inference_plots.py | 241 ---------------- 3 files changed, 820 deletions(-) delete mode 100644 examples/LJ.json delete mode 100644 examples/LJ_data.py delete mode 100644 examples/LJ_inference_plots.py diff --git a/examples/LJ.json b/examples/LJ.json deleted file mode 100644 index a6b18f12b..000000000 --- a/examples/LJ.json +++ /dev/null @@ -1,75 +0,0 @@ -{ - "Verbosity": { - "level": 2 - }, - "Dataset": { - "name": "LJdataset", - "format": "XYZ", - "node_features": { - "name": ["atom_type"], - "dim": [1], - "column_index": [0] - }, - "graph_features":{ - "name": ["total_energy"], - "dim": [1], - "column_index": [0] - } - }, - "NeuralNetwork": { - "Architecture": { - "periodic_boundary_conditions": true, - "model_type": "DimeNet", - "radius": 5.0, - "max_neighbours": 5, - "int_emb_size": 32, - "out_emb_size": 16, - "basis_emb_size": 8, - "num_gaussians": 10, - "num_filters": 8, - "num_before_skip": 1, - "num_after_skip": 1, - "envelope_exponent": 5, - "num_radial": 5, - "num_spherical": 2, - "hidden_dim": 20, - "num_conv_layers": 4, - "output_heads": { - "node": { - "num_headlayers": 2, - "dim_headlayers": [60,20], - "type": "mlp" - } - }, - "task_weights": [1] - }, - "Variables_of_interest": { - "input_node_features": [0], - "output_index": [ - 0 - ], - "type": [ - "node" - ], - "output_dim": [1], - "output_names": ["graph_energy"] - }, - "Training": { - "num_epoch": 15, - "batch_size": 64, - "perc_train": 0.7, - "patience": 20, - "early_stopping": true, - "Optimizer": { - "type": "Adam", - "learning_rate": 0.005 - }, - "conv_checkpointing": false - } - }, - "Visualization": { - "plot_init_solution": true, - "plot_hist_solution": true, - "create_plots": true - } -} diff --git a/examples/LJ_data.py b/examples/LJ_data.py deleted file mode 100644 index 594d6d154..000000000 --- a/examples/LJ_data.py +++ /dev/null @@ -1,504 +0,0 @@ -############################################################################## -# Copyright (c) 2024, Oak Ridge National Laboratory # -# All rights reserved. 
# -# # -# This file is part of HydraGNN and is distributed under a BSD 3-clause # -# license. For the licensing terms see the LICENSE file in the top-level # -# directory. # -# # -# SPDX-License-Identifier: BSD-3-Clause # -############################################################################## - -# General -import os -import logging -import numpy - -numpy.set_printoptions(threshold=numpy.inf) -numpy.set_printoptions(linewidth=numpy.inf) - -# Torch -import torch -from torch_geometric.data import Data - -# torch.set_default_tensor_type(torch.DoubleTensor) -# torch.set_default_dtype(torch.float64) - -# Distributed -import mpi4py -from mpi4py import MPI - -mpi4py.rc.thread_level = "serialized" -mpi4py.rc.threads = False - -# HydraGNN -from hydragnn.utils.abstractrawdataset import AbstractBaseDataset -from hydragnn.utils import nsplit -from hydragnn.preprocess.utils import get_radius_graph_pbc - -# Angstrom unit -primitive_bravais_lattice_constant_x = 3.8 -primitive_bravais_lattice_constant_y = 3.8 -primitive_bravais_lattice_constant_z = 3.8 - - -################################################################################################################## - - -"""High-Level Function""" - - -def create_dataset(path, config): - radius_cutoff = config["NeuralNetwork"]["Architecture"]["radius"] - number_configurations = ( - config["Dataset"]["number_configurations"] - if "number_configurations" in config["Dataset"] - else 300 - ) - atom_types = [1] - formula = LJpotential(1.0, 3.4) - atomic_structure_handler = AtomicStructureHandler( - atom_types, - [ - primitive_bravais_lattice_constant_x, - primitive_bravais_lattice_constant_y, - primitive_bravais_lattice_constant_z, - ], - radius_cutoff, - formula, - ) - deterministic_graph_data( - path, - atom_types, - atomic_structure_handler=atomic_structure_handler, - radius_cutoff=radius_cutoff, - relative_maximum_atomic_displacement=1e-1, - number_configurations=number_configurations, - ) - - -"""Reading/Transforming Data""" - - -class LJDataset(AbstractBaseDataset): - """LJDataset dataset class""" - - def __init__(self, dirpath, config, dist=False, sampling=None): - super().__init__() - - self.dist = dist - self.world_size = 1 - self.rank = 1 - if self.dist: - assert torch.distributed.is_initialized() - self.world_size = torch.distributed.get_world_size() - self.rank = torch.distributed.get_rank() - - self.radius = config["NeuralNetwork"]["Architecture"]["radius"] - self.max_neighbours = config["NeuralNetwork"]["Architecture"]["max_neighbours"] - - dirfiles = sorted(os.listdir(dirpath)) - - rx = list(nsplit((dirfiles), self.world_size))[self.rank] - - for file in rx: - filepath = os.path.join(dirpath, file) - self.dataset.append(self.transform_input_to_data_object_base(filepath)) - - def transform_input_to_data_object_base(self, filepath): - - # Using readline() - file = open(filepath, "r") - - torch_data = torch.empty((0, 8), dtype=torch.float32) - torch_supercell = torch.zeros((0, 3), dtype=torch.float32) - - count = 0 - - while True: - count += 1 - - # Get next line from file - line = file.readline() - - # if line is empty - # end of file is reached - if not line: - break - - if count == 1: - total_energy = float(line) - elif count == 2: - energy_per_atom = float(line) - elif 2 < count < 6: - array_line = numpy.fromstring(line, dtype=float, sep="\t") - torch_supercell = torch.cat( - [torch_supercell, torch.from_numpy(array_line).unsqueeze(0)], axis=0 - ) - elif count > 5: - array_line = numpy.fromstring(line, dtype=float, sep="\t") - 
torch_data = torch.cat( - [torch_data, torch.from_numpy(array_line).unsqueeze(0)], axis=0 - ) - # print("Line{}: {}".format(count, line.strip())) - - file.close() - - num_nodes = torch_data.shape[0] - - energy_pre_translation_factor = 0.0 - energy_pre_scaling_factor = 1.0 / num_nodes - energy_per_atom_pretransformed = ( - energy_per_atom - energy_pre_translation_factor - ) * energy_pre_scaling_factor - grad_energy_post_scaling_factor = ( - 1.0 / energy_pre_scaling_factor * torch.ones(num_nodes, 1) - ) - forces = torch_data[:, [5, 6, 7]] - forces_pre_scaling_factor = 1.0 - forces_pre_scaled = forces * forces_pre_scaling_factor - - data = Data( - supercell_size=torch_supercell.to(torch.float32), - num_nodes=num_nodes, - grad_energy_post_scaling_factor=grad_energy_post_scaling_factor, - forces_pre_scaling_factor=torch.tensor(forces_pre_scaling_factor).to( - torch.float32 - ), - forces=forces, - forces_pre_scaled=forces_pre_scaled, - pos=torch_data[:, [1, 2, 3]].to(torch.float32), - x=torch.cat([torch_data[:, [0, 4]]], axis=1).to(torch.float32), - y=torch.tensor(total_energy).unsqueeze(0).to(torch.float32), - energy_per_atom=torch.tensor(energy_per_atom_pretransformed) - .unsqueeze(0) - .to(torch.float32), - energy=torch.tensor(total_energy).unsqueeze(0).to(torch.float32), - ) - - # Create pbc edges and lengths - edge_creation = get_radius_graph_pbc(self.radius, self.max_neighbours) - data = edge_creation(data) - - return data - - def len(self): - return len(self.dataset) - - def get(self, idx): - return self.dataset[idx] - - -"""Create Data""" - - -def deterministic_graph_data( - path: str, - atom_types: list, - atomic_structure_handler, - radius_cutoff=float("inf"), - max_num_neighbors=float("inf"), - number_configurations: int = 500, - configuration_start: int = 0, - unit_cell_x_range: list = [3, 4], - unit_cell_y_range: list = [3, 4], - unit_cell_z_range: list = [3, 4], - relative_maximum_atomic_displacement: float = 1e-1, -): - - comm = MPI.COMM_WORLD - comm_size = comm.Get_size() - comm_rank = comm.Get_rank() - torch.manual_seed(comm_rank) - - if 0 == comm_rank: - os.makedirs(path, exist_ok=False) - comm.Barrier() - - # We assume that the unit cell is Simple Center Cubic (SCC) - unit_cell_x = torch.randint( - unit_cell_x_range[0], - unit_cell_x_range[1], - (number_configurations,), - ) - unit_cell_y = torch.randint( - unit_cell_y_range[0], - unit_cell_y_range[1], - (number_configurations,), - ) - unit_cell_z = torch.randint( - unit_cell_z_range[0], - unit_cell_z_range[1], - (number_configurations,), - ) - - configurations_list = range(number_configurations) - rx = list(nsplit(configurations_list, comm_size))[comm_rank] - - for configuration in configurations_list[rx.start : rx.stop]: - uc_x = unit_cell_x[configuration] - uc_y = unit_cell_y[configuration] - uc_z = unit_cell_z[configuration] - create_configuration( - path, - atomic_structure_handler, - configuration, - configuration_start, - uc_x, - uc_y, - uc_z, - atom_types, - radius_cutoff, - max_num_neighbors, - relative_maximum_atomic_displacement, - ) - - -def create_configuration( - path, - atomic_structure_handler, - configuration, - configuration_start, - uc_x, - uc_y, - uc_z, - types, - radius_cutoff, - max_num_neighbors, - relative_maximum_atomic_displacement, -): - ############################################################################################### - ################################### STRUCTURE OF THE DATA ################################## - 
-
-
-def create_configuration(
-    path,
-    atomic_structure_handler,
-    configuration,
-    configuration_start,
-    uc_x,
-    uc_y,
-    uc_z,
-    types,
-    radius_cutoff,
-    max_num_neighbors,
-    relative_maximum_atomic_displacement,
-):
-    ###############################################################################################
-    ################################### STRUCTURE OF THE DATA ##################################
-    ###############################################################################################
-
-    # GLOBAL_OUTPUT1
-    # GLOBAL_OUTPUT2
-    # NODE1_FEATURE NODE1_INDEX NODE1_COORDINATE_X NODE1_COORDINATE_Y NODE1_COORDINATE_Z NODAL_OUTPUT1 NODAL_OUTPUT2 NODAL_OUTPUT3
-    # NODE2_FEATURE NODE2_INDEX NODE2_COORDINATE_X NODE2_COORDINATE_Y NODE2_COORDINATE_Z NODAL_OUTPUT1 NODAL_OUTPUT2 NODAL_OUTPUT3
-    # ...
-    # NODEn_FEATURE NODEn_INDEX NODEn_COORDINATE_X NODEn_COORDINATE_Y NODEn_COORDINATE_Z NODAL_OUTPUT1 NODAL_OUTPUT2 NODAL_OUTPUT3
-
-    ###############################################################################################
-    ################################# FORMULAS FOR NODAL FEATURE ###############################
-    ###############################################################################################
-
-    # NODAL_FEATURE = ATOM SPECIES
-
-    ###############################################################################################
-    ########################## FORMULAS FOR GLOBAL AND NODAL OUTPUTS ###########################
-    ###############################################################################################
-
-    # GLOBAL_OUTPUT1 = TOTAL ENERGY
-    # GLOBAL_OUTPUT2 = TOTAL ENERGY / NUMBER OF NODES
-    # NODAL_OUTPUT1(X) = FORCE ACTING ON ATOM IN X DIRECTION
-    # NODAL_OUTPUT2(X) = FORCE ACTING ON ATOM IN Y DIRECTION
-    # NODAL_OUTPUT3(X) = FORCE ACTING ON ATOM IN Z DIRECTION
-
-    ###############################################################################################
-    count_pos = 0
-    number_nodes = uc_x * uc_y * uc_z
-    positions = torch.zeros(number_nodes, 3)
-    for x in range(uc_x):
-        for y in range(uc_y):
-            for z in range(uc_z):
-                positions[count_pos][0] = (
-                    x
-                    + relative_maximum_atomic_displacement
-                    * ((torch.rand(1, 1).item()) - 0.5)
-                ) * primitive_bravais_lattice_constant_x
-                positions[count_pos][1] = (
-                    y
-                    + relative_maximum_atomic_displacement
-                    * ((torch.rand(1, 1).item()) - 0.5)
-                ) * primitive_bravais_lattice_constant_y
-                positions[count_pos][2] = (
-                    z
-                    + relative_maximum_atomic_displacement
-                    * ((torch.rand(1, 1).item()) - 0.5)
-                ) * primitive_bravais_lattice_constant_z
-
-                count_pos = count_pos + 1
-
-    atom_types = torch.randint(min(types), max(types) + 1, (number_nodes, 1))
-
-    data = Data()
-
-    data.pos = positions
-    supercell_size_x = primitive_bravais_lattice_constant_x * uc_x
-    supercell_size_y = primitive_bravais_lattice_constant_y * uc_y
-    supercell_size_z = primitive_bravais_lattice_constant_z * uc_z
-    data.supercell_size = torch.diag(
-        torch.tensor([supercell_size_x, supercell_size_y, supercell_size_z])
-    )
-
-    create_graph_connectivity_pbc = get_radius_graph_pbc(
-        radius_cutoff, max_num_neighbors
-    )
-    data = create_graph_connectivity_pbc(data)
-
-    atomic_descriptors = torch.cat(
-        (
-            atom_types,
-            positions,
-        ),
-        1,
-    )
-
-    data.x = atomic_descriptors
-
-    data = atomic_structure_handler.compute(data)
-
-    total_energy = torch.sum(data.x[:, 4])
-    energy_per_atom = total_energy / number_nodes
-
-    total_energy_str = numpy.array2string(total_energy.detach().numpy())
-    energy_per_atom_str = numpy.array2string(energy_per_atom.detach().numpy())
-    filetxt = total_energy_str + "\n" + energy_per_atom_str
-
-    for index in range(0, 3):
-        numpy_row = data.supercell_size[index, :].detach().numpy()
-        numpy_string_row = numpy.array2string(numpy_row, precision=64, separator="\t")
-        filetxt += "\n" + numpy_string_row.lstrip("[").rstrip("]")
-
-    for index in range(0, number_nodes):
-        numpy_row = data.x[index, :].detach().numpy()
-        numpy_string_row = numpy.array2string(numpy_row, precision=64, separator="\t")
-        filetxt += "\n" + numpy_string_row.lstrip("[").rstrip("]")
-
-    filename = os.path.join(
-        path, "output" + str(configuration + configuration_start) + ".txt"
-    )
-    with open(filename, "w") as f:
-        f.write(filetxt)
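For reference, the record just written can be parsed back in a few lines. This is a hypothetical companion reader (read_lj_record_sketch is not part of the example), assuming the final 8-column node rows produced by AtomicStructureHandler.compute below: [atom_type, x, y, z, potential, force_x, force_y, force_z].

import numpy as np

def read_lj_record_sketch(filepath):
    # Line 1: total energy; line 2: energy per atom;
    # lines 3-5: supercell rows; remaining lines: per-node rows.
    with open(filepath, "r") as f:
        lines = f.read().splitlines()
    total_energy = float(lines[0])
    energy_per_atom = float(lines[1])
    supercell = np.array(
        [np.fromstring(row, dtype=float, sep="\t") for row in lines[2:5]]
    )
    nodes = np.array(
        [np.fromstring(row, dtype=float, sep="\t") for row in lines[5:]]
    )
    return total_energy, energy_per_atom, supercell, nodes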
-
-
-"""Function Calculation"""
-
-
-class AtomicStructureHandler:
-    def __init__(
-        self, list_atom_types, bravais_lattice_constants, radius_cutoff, formula
-    ):
-
-        self.bravais_lattice_constants = bravais_lattice_constants
-        self.radius_cutoff = radius_cutoff
-        self.formula = formula
-
-    def compute(self, data):
-
-        assert data.pos.shape[0] == data.x.shape[0]
-
-        interatomic_potential = torch.zeros([data.pos.shape[0], 1])
-        interatomic_forces = torch.zeros([data.pos.shape[0], 3])
-
-        for node_id in range(data.pos.shape[0]):
-
-            neighbor_list_indices = torch.where(data.edge_index[0, :] == node_id)[
-                0
-            ].tolist()
-            neighbor_list = data.edge_index[1, neighbor_list_indices]
-
-            for neighbor_id, edge_id in zip(neighbor_list, neighbor_list_indices):
-
-                neighbor_pos = data.pos[neighbor_id, :]
-                distance_vector = data.pos[neighbor_id, :] - data.pos[node_id, :]
-
-                # Adjust the neighbor position based on periodic boundary conditions (PBC)
-                ## If the distance between the atoms is larger than the cutoff radius, the edge exists because of PBC conditions
-                if torch.norm(distance_vector) > self.radius_cutoff:
-                    ## At this point, we know that the edge is due to PBC conditions, so we need to adjust the neighbor position. We also know
-                    ## that this connection MUST be the closest connection possible, because radius_cutoff < supercell_size was asserted earlier
-                    ## in the code. Because of this, we can simply adjust the neighbor position coordinate-wise, as done in the
-                    ## following lines of code. The logic is that if distance_vector[index] is larger than half the supercell size,
-                    ## then there is a closer image at +- supercell_size[index], and we adjust each coordinate accordingly
-                    if abs(distance_vector[0]) > data.supercell_size[0, 0] / 2:
-                        if distance_vector[0] > 0:
-                            neighbor_pos[0] -= data.supercell_size[0, 0]
-                        else:
-                            neighbor_pos[0] += data.supercell_size[0, 0]
-
-                    if abs(distance_vector[1]) > data.supercell_size[1, 1] / 2:
-                        if distance_vector[1] > 0:
-                            neighbor_pos[1] -= data.supercell_size[1, 1]
-                        else:
-                            neighbor_pos[1] += data.supercell_size[1, 1]
-
-                    if abs(distance_vector[2]) > data.supercell_size[2, 2] / 2:
-                        if distance_vector[2] > 0:
-                            neighbor_pos[2] -= data.supercell_size[2, 2]
-                        else:
-                            neighbor_pos[2] += data.supercell_size[2, 2]
-
-                # The distance vector may need to be updated after applying PBCs
-                distance_vector = data.pos[node_id, :] - neighbor_pos
-
-                # pair_distance = data.edge_attr[edge_id].item()
-                interatomic_potential[node_id] += self.formula.potential_energy(
-                    distance_vector
-                )
-
-                derivative_x = self.formula.derivative_x(distance_vector)
-                derivative_y = self.formula.derivative_y(distance_vector)
-                derivative_z = self.formula.derivative_z(distance_vector)
-
-                interatomic_forces_contribution_x = -derivative_x
-                interatomic_forces_contribution_y = -derivative_y
-                interatomic_forces_contribution_z = -derivative_z
-
-                interatomic_forces[node_id, 0] += interatomic_forces_contribution_x
-                interatomic_forces[node_id, 1] += interatomic_forces_contribution_y
-                interatomic_forces[node_id, 2] += interatomic_forces_contribution_z
-
-        data.x = torch.cat(
-            (data.x, interatomic_potential, interatomic_forces),
-            1,
-        )
-
-        return data
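The branching above applies the minimum-image convention coordinate-wise for an orthorhombic cell. A compact vectorized restatement, as a sketch only (minimum_image_sketch is hypothetical and assumes the cell lengths sit on the diagonal of supercell_size):

import torch

def minimum_image_sketch(distance_vector, cell_diag):
    # Wrap each component to the nearest periodic image,
    # i.e. into [-cell/2, cell/2] componentwise.
    return distance_vector - cell_diag * torch.round(distance_vector / cell_diag)

# A 9.0 Angstrom displacement in a 10.0 Angstrom box wraps to -1.0
cell_diag = torch.tensor([10.0, 10.0, 10.0])
d = torch.tensor([9.0, 0.0, -9.0])
assert torch.allclose(minimum_image_sketch(d, cell_diag), torch.tensor([-1.0, 0.0, 1.0]))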
-
-
-class LJpotential:
-    def __init__(self, epsilon, sigma):
-        self.epsilon = epsilon
-        self.sigma = sigma
-
-    def potential_energy(self, distance_vector):
-        pair_distance = torch.norm(distance_vector)
-        return (
-            4
-            * self.epsilon
-            * ((self.sigma / pair_distance) ** 12 - (self.sigma / pair_distance) ** 6)
-        )
-
-    def radial_derivative(self, distance_vector):
-        pair_distance = torch.norm(distance_vector)
-        return (
-            4
-            * self.epsilon
-            * (
-                -12 * (self.sigma / pair_distance) ** 12 * 1 / pair_distance
-                + 6 * (self.sigma / pair_distance) ** 6 * 1 / pair_distance
-            )
-        )
-
-    def derivative_x(self, distance_vector):
-        pair_distance = torch.norm(distance_vector)
-        radial_derivative = self.radial_derivative(distance_vector)
-        return radial_derivative * (distance_vector[0].item()) / pair_distance
-
-    def derivative_y(self, distance_vector):
-        pair_distance = torch.norm(distance_vector)
-        radial_derivative = self.radial_derivative(distance_vector)
-        return radial_derivative * (distance_vector[1].item()) / pair_distance
-
-    def derivative_z(self, distance_vector):
-        pair_distance = torch.norm(distance_vector)
-        radial_derivative = self.radial_derivative(distance_vector)
-        return radial_derivative * (distance_vector[2].item()) / pair_distance
-
-
-"""Etc"""
-
-
-def info(*args, logtype="info", sep=" "):
-    getattr(logging, logtype)(sep.join(map(str, args)))
diff --git a/examples/LJ_inference_plots.py b/examples/LJ_inference_plots.py
deleted file mode 100644
index 324da425f..000000000
--- a/examples/LJ_inference_plots.py
+++ /dev/null
@@ -1,241 +0,0 @@
-##############################################################################
-# Copyright (c) 2024, Oak Ridge National Laboratory                          #
-# All rights reserved.                                                       #
-#                                                                            #
-# This file is part of HydraGNN and is distributed under a BSD 3-clause     #
-# license.
For the licensing terms see the LICENSE file in the top-level # -# directory. # -# # -# SPDX-License-Identifier: BSD-3-Clause # -############################################################################## - -import json, os -import sys -import logging -import pickle -from tqdm import tqdm -from mpi4py import MPI -import argparse - -import torch -import torch_scatter -import numpy as np - -import hydragnn -from hydragnn.utils.time_utils import Timer -from hydragnn.utils.distributed import get_device -from hydragnn.utils.model import load_existing_model -from hydragnn.utils.pickledataset import SimplePickleDataset -from hydragnn.utils.config_utils import ( - update_config, -) -from hydragnn.models.create import create_model_config -from hydragnn.preprocess import create_dataloaders - -from scipy.interpolate import griddata - -try: - from hydragnn.utils.adiosdataset import AdiosWriter, AdiosDataset -except ImportError: - pass - -from LJ_data import info - -import matplotlib.pyplot as plt - -plt.rcParams.update({"font.size": 16}) - - -def get_log_name_config(config): - return ( - config["NeuralNetwork"]["Architecture"]["model_type"] - + "-r-" - + str(config["NeuralNetwork"]["Architecture"]["radius"]) - + "-ncl-" - + str(config["NeuralNetwork"]["Architecture"]["num_conv_layers"]) - + "-hd-" - + str(config["NeuralNetwork"]["Architecture"]["hidden_dim"]) - + "-ne-" - + str(config["NeuralNetwork"]["Training"]["num_epoch"]) - + "-lr-" - + str(config["NeuralNetwork"]["Training"]["Optimizer"]["learning_rate"]) - + "-bs-" - + str(config["NeuralNetwork"]["Training"]["batch_size"]) - + "-node_ft-" - + "".join( - str(x) - for x in config["NeuralNetwork"]["Variables_of_interest"][ - "input_node_features" - ] - ) - + "-task_weights-" - + "".join( - str(weigh) + "-" - for weigh in config["NeuralNetwork"]["Architecture"]["task_weights"] - ) - ) - - -def getcolordensity(xdata, ydata): - ############################### - nbin = 20 - hist2d, xbins_edge, ybins_edge = np.histogram2d(x=xdata, y=ydata, bins=[nbin, nbin]) - xbin_cen = 0.5 * (xbins_edge[0:-1] + xbins_edge[1:]) - ybin_cen = 0.5 * (ybins_edge[0:-1] + ybins_edge[1:]) - BCTY, BCTX = np.meshgrid(ybin_cen, xbin_cen) - hist2d = hist2d / np.amax(hist2d) - print(np.amax(hist2d)) - - bctx1d = np.reshape(BCTX, len(xbin_cen) * nbin) - bcty1d = np.reshape(BCTY, len(xbin_cen) * nbin) - loc_pts = np.zeros((len(xbin_cen) * nbin, 2)) - loc_pts[:, 0] = bctx1d - loc_pts[:, 1] = bcty1d - hist2d_norm = griddata( - loc_pts, - hist2d.reshape(len(xbin_cen) * nbin), - (xdata, ydata), - method="linear", - fill_value=0, - ) # np.nan) - return hist2d_norm - - -if __name__ == "__main__": - - modelname = "LJ" - - parser = argparse.ArgumentParser() - parser.add_argument( - "--inputfile", help="input file", type=str, default="./logs/LJ/config.json" - ) - group = parser.add_mutually_exclusive_group() - group.add_argument( - "--adios", - help="Adios gan_dataset", - action="store_const", - dest="format", - const="adios", - ) - group.add_argument( - "--pickle", - help="Pickle gan_dataset", - action="store_const", - dest="format", - const="pickle", - ) - parser.set_defaults(format="pickle") - - args = parser.parse_args() - - dirpwd = os.path.dirname(os.path.abspath(__file__)) - input_filename = os.path.join(dirpwd, args.inputfile) - with open(input_filename, "r") as f: - config = json.load(f) - hydragnn.utils.setup_log(get_log_name_config(config)) - ################################################################################################################## - # Always initialize 
for multi-rank training. - comm_size, rank = hydragnn.utils.setup_ddp() - ################################################################################################################## - comm = MPI.COMM_WORLD - - datasetname = "LJ" - - comm.Barrier() - - timer = Timer("load_data") - timer.start() - if args.format == "pickle": - info("Pickle load") - basedir = os.path.join( - os.path.dirname(__file__), "dataset", "%s.pickle" % modelname - ) - trainset = SimplePickleDataset( - basedir=basedir, - label="trainset", - var_config=config["NeuralNetwork"]["Variables_of_interest"], - ) - valset = SimplePickleDataset( - basedir=basedir, - label="valset", - var_config=config["NeuralNetwork"]["Variables_of_interest"], - ) - testset = SimplePickleDataset( - basedir=basedir, - label="testset", - var_config=config["NeuralNetwork"]["Variables_of_interest"], - ) - pna_deg = trainset.pna_deg - else: - raise NotImplementedError("No supported format: %s" % (args.format)) - - model = create_model_config( - config=config["NeuralNetwork"], - verbosity=config["Verbosity"]["level"], - ) - - model = torch.nn.parallel.DistributedDataParallel(model) - - load_existing_model(model, modelname, path="./logs/") - model.eval() - - variable_index = 0 - # for output_name, output_type, output_dim in zip(config["NeuralNetwork"]["Variables_of_interest"]["output_names"], config["NeuralNetwork"]["Variables_of_interest"]["type"], config["NeuralNetwork"]["Variables_of_interest"]["output_dim"]): - - test_MAE = 0.0 - - num_samples = len(testset) - energy_true_list = [] - energy_pred_list = [] - forces_true_list = [] - forces_pred_list = [] - - for data_id, data in enumerate(tqdm(testset)): - data.pos.requires_grad = True - node_energy_pred = model(data.to(get_device()))[ - 0 - ] # Note that this is sensitive to energy and forces prediction being single-task (current requirement) - energy_pred = torch.sum(node_energy_pred, dim=0).float() - test_MAE += torch.norm(energy_pred - data.energy, p=1).item() / len(testset) - # predicted.backward(retain_graph=True) - # gradients = data.pos.grad - grads_energy = torch.autograd.grad( - outputs=energy_pred, - inputs=data.pos, - grad_outputs=torch.ones_like(energy_pred), - retain_graph=False, - create_graph=True, - )[0] - energy_pred_list.extend(energy_pred.tolist()) - energy_true_list.extend(data.energy.tolist()) - forces_pred_list.extend((-grads_energy).flatten().tolist()) - forces_true_list.extend(data.forces.flatten().tolist()) - - hist2d_norm = getcolordensity(energy_true_list, energy_pred_list) - - fig, ax = plt.subplots() - plt.scatter(energy_true_list, energy_pred_list, s=8, c=hist2d_norm, vmin=0, vmax=1) - plt.clim(0, 1) - ax.plot(ax.get_xlim(), ax.get_xlim(), ls="--", color="red") - plt.colorbar() - plt.xlabel("True values") - plt.ylabel("Predicted values") - plt.title(f"energy") - plt.draw() - plt.tight_layout() - plt.savefig(f"./energy_Scatterplot" + ".png", dpi=400) - - print(f"Test MAE energy: ", test_MAE) - - hist2d_norm = getcolordensity(forces_pred_list, forces_true_list) - fig, ax = plt.subplots() - plt.scatter(forces_pred_list, forces_true_list, s=8, c=hist2d_norm, vmin=0, vmax=1) - plt.clim(0, 1) - ax.plot(ax.get_xlim(), ax.get_xlim(), ls="--", color="red") - plt.colorbar() - plt.xlabel("Predicted Values") - plt.ylabel("True Values") - plt.title("Forces") - plt.draw() - plt.tight_layout() - plt.savefig(f"./Forces_Scatterplot" + ".png", dpi=400) From d037fd331218add864f609382968b49a7dad1118 Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Thu, 26 Sep 2024 13:12:50 -0400 
Subject: [PATCH 20/51] commenting things that aren't needed for MACE in HYDRA (draft 1) --- hydragnn/models/MACEStack.py | 6 +- hydragnn/utils/mace_utils/data/__init__.py | 66 +- hydragnn/utils/mace_utils/data/atomic_data.py | 454 ++-- .../utils/mace_utils/data/hdf5_dataset.py | 154 +- .../utils/mace_utils/data/neighborhood.py | 108 +- hydragnn/utils/mace_utils/data/utils.py | 796 +++--- hydragnn/utils/mace_utils/modules/__init__.py | 62 +- hydragnn/utils/mace_utils/modules/loss.py | 734 +++--- hydragnn/utils/mace_utils/modules/models.py | 2130 ++++++++--------- hydragnn/utils/mace_utils/modules/utils.py | 694 +++--- hydragnn/utils/mace_utils/tools/__init__.py | 38 +- hydragnn/utils/mace_utils/tools/arg_parser.py | 1552 ++++++------ .../mace_utils/tools/arg_parser_tools.py | 210 +- hydragnn/utils/mace_utils/tools/checkpoint.py | 454 ++-- .../utils/mace_utils/tools/scripts_utils.py | 1306 +++++----- .../mace_utils/tools/slurm_distributed.py | 58 +- hydragnn/utils/mace_utils/tools/train.py | 1048 ++++---- tests/test_forces_equivariant.py | 28 + 18 files changed, 4964 insertions(+), 4934 deletions(-) create mode 100644 tests/test_forces_equivariant.py diff --git a/hydragnn/models/MACEStack.py b/hydragnn/models/MACEStack.py index 11b78591a..aaeffee9d 100644 --- a/hydragnn/models/MACEStack.py +++ b/hydragnn/models/MACEStack.py @@ -190,9 +190,10 @@ def _init_conv(self): self.graph_convs.append( self.get_conv(self.input_dim, self.hidden_dim, first_layer=True) ) + irreps = hidden_irreps if not last_layer else final_hidden_irreps self.multihead_decoders.append( MultiheadDecoderBlock( - hidden_irreps, + irreps, self.node_max_ell, self.config_heads, self.head_dims, @@ -209,9 +210,10 @@ def _init_conv(self): self.hidden_dim, self.hidden_dim, last_layer=last_layer ) self.graph_convs.append(conv) + irreps = hidden_irreps if not last_layer else final_hidden_irreps self.multihead_decoders.append( MultiheadDecoderBlock( - final_hidden_irreps, + irreps, self.node_max_ell, self.config_heads, self.head_dims, diff --git a/hydragnn/utils/mace_utils/data/__init__.py b/hydragnn/utils/mace_utils/data/__init__.py index c10a36982..ace87d766 100644 --- a/hydragnn/utils/mace_utils/data/__init__.py +++ b/hydragnn/utils/mace_utils/data/__init__.py @@ -1,34 +1,34 @@ -from .atomic_data import AtomicData -from .hdf5_dataset import HDF5Dataset, dataset_from_sharded_hdf5 -from .neighborhood import get_neighborhood -from .utils import ( - Configuration, - Configurations, - compute_average_E0s, - config_from_atoms, - config_from_atoms_list, - load_from_xyz, - random_train_valid_split, - save_AtomicData_to_HDF5, - save_configurations_as_HDF5, - save_dataset_as_HDF5, - test_config_types, -) +# from .atomic_data import AtomicData +# from .hdf5_dataset import HDF5Dataset, dataset_from_sharded_hdf5 +# from .neighborhood import get_neighborhood +# from .utils import ( +# Configuration, +# Configurations, +# compute_average_E0s, +# config_from_atoms, +# config_from_atoms_list, +# load_from_xyz, +# random_train_valid_split, +# save_AtomicData_to_HDF5, +# save_configurations_as_HDF5, +# save_dataset_as_HDF5, +# test_config_types, +# ) -__all__ = [ - "get_neighborhood", - "Configuration", - "Configurations", - "random_train_valid_split", - "load_from_xyz", - "test_config_types", - "config_from_atoms", - "config_from_atoms_list", - "AtomicData", - "compute_average_E0s", - "save_dataset_as_HDF5", - "HDF5Dataset", - "dataset_from_sharded_hdf5", - "save_AtomicData_to_HDF5", - "save_configurations_as_HDF5", -] +# __all__ = [ +# 
"get_neighborhood", +# "Configuration", +# "Configurations", +# "random_train_valid_split", +# "load_from_xyz", +# "test_config_types", +# "config_from_atoms", +# "config_from_atoms_list", +# "AtomicData", +# "compute_average_E0s", +# "save_dataset_as_HDF5", +# "HDF5Dataset", +# "dataset_from_sharded_hdf5", +# "save_AtomicData_to_HDF5", +# "save_configurations_as_HDF5", +# ] diff --git a/hydragnn/utils/mace_utils/data/atomic_data.py b/hydragnn/utils/mace_utils/data/atomic_data.py index 01815fdc5..ded8d438d 100644 --- a/hydragnn/utils/mace_utils/data/atomic_data.py +++ b/hydragnn/utils/mace_utils/data/atomic_data.py @@ -1,227 +1,227 @@ -########################################################################################### -# Atomic Data Class for handling molecules as graphs -# Authors: Ilyes Batatia, Gregor Simm -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -from typing import Optional, Sequence - -import torch.utils.data - -from hydragnn.utils.mace_utils.tools import ( - AtomicNumberTable, - atomic_numbers_to_indices, - to_one_hot, - torch_geometric, - voigt_to_matrix, -) - -from .neighborhood import get_neighborhood -from .utils import Configuration - - -class AtomicData(torch_geometric.data.Data): - num_graphs: torch.Tensor - batch: torch.Tensor - edge_index: torch.Tensor - node_attrs: torch.Tensor - edge_vectors: torch.Tensor - edge_lengths: torch.Tensor - positions: torch.Tensor - shifts: torch.Tensor - unit_shifts: torch.Tensor - cell: torch.Tensor - forces: torch.Tensor - energy: torch.Tensor - stress: torch.Tensor - virials: torch.Tensor - dipole: torch.Tensor - charges: torch.Tensor - weight: torch.Tensor - energy_weight: torch.Tensor - forces_weight: torch.Tensor - stress_weight: torch.Tensor - virials_weight: torch.Tensor - - def __init__( - self, - edge_index: torch.Tensor, # [2, n_edges] - node_attrs: torch.Tensor, # [n_nodes, n_node_feats] - positions: torch.Tensor, # [n_nodes, 3] - shifts: torch.Tensor, # [n_edges, 3], - unit_shifts: torch.Tensor, # [n_edges, 3] - cell: Optional[torch.Tensor], # [3,3] - weight: Optional[torch.Tensor], # [,] - energy_weight: Optional[torch.Tensor], # [,] - forces_weight: Optional[torch.Tensor], # [,] - stress_weight: Optional[torch.Tensor], # [,] - virials_weight: Optional[torch.Tensor], # [,] - forces: Optional[torch.Tensor], # [n_nodes, 3] - energy: Optional[torch.Tensor], # [, ] - stress: Optional[torch.Tensor], # [1,3,3] - virials: Optional[torch.Tensor], # [1,3,3] - dipole: Optional[torch.Tensor], # [, 3] - charges: Optional[torch.Tensor], # [n_nodes, ] - ): - # Check shapes - num_nodes = node_attrs.shape[0] - - assert edge_index.shape[0] == 2 and len(edge_index.shape) == 2 - assert positions.shape == (num_nodes, 3) - assert shifts.shape[1] == 3 - assert unit_shifts.shape[1] == 3 - assert len(node_attrs.shape) == 2 - assert weight is None or len(weight.shape) == 0 - assert energy_weight is None or len(energy_weight.shape) == 0 - assert forces_weight is None or len(forces_weight.shape) == 0 - assert stress_weight is None or len(stress_weight.shape) == 0 - assert virials_weight is None or len(virials_weight.shape) == 0 - assert cell is None or cell.shape == (3, 3) - assert forces is None or forces.shape == (num_nodes, 3) - assert energy is None or len(energy.shape) == 0 - assert stress is None or stress.shape == (1, 3, 3) - assert virials is None or virials.shape == (1, 3, 3) - assert dipole is None or dipole.shape[-1] == 
3 - assert charges is None or charges.shape == (num_nodes,) - # Aggregate data - data = { - "num_nodes": num_nodes, - "edge_index": edge_index, - "positions": positions, - "shifts": shifts, - "unit_shifts": unit_shifts, - "cell": cell, - "node_attrs": node_attrs, - "weight": weight, - "energy_weight": energy_weight, - "forces_weight": forces_weight, - "stress_weight": stress_weight, - "virials_weight": virials_weight, - "forces": forces, - "energy": energy, - "stress": stress, - "virials": virials, - "dipole": dipole, - "charges": charges, - } - super().__init__(**data) - - @classmethod - def from_config( - cls, config: Configuration, z_table: AtomicNumberTable, cutoff: float - ) -> "AtomicData": - edge_index, shifts, unit_shifts = get_neighborhood( - positions=config.positions, cutoff=cutoff, pbc=config.pbc, cell=config.cell - ) - indices = atomic_numbers_to_indices(config.atomic_numbers, z_table=z_table) - one_hot = to_one_hot( - torch.tensor(indices, dtype=torch.long).unsqueeze(-1), - num_classes=len(z_table), - ) - - cell = ( - torch.tensor(config.cell, dtype=torch.get_default_dtype()) - if config.cell is not None - else torch.tensor( - 3 * [0.0, 0.0, 0.0], dtype=torch.get_default_dtype() - ).view(3, 3) - ) - - weight = ( - torch.tensor(config.weight, dtype=torch.get_default_dtype()) - if config.weight is not None - else 1 - ) - - energy_weight = ( - torch.tensor(config.energy_weight, dtype=torch.get_default_dtype()) - if config.energy_weight is not None - else 1 - ) - - forces_weight = ( - torch.tensor(config.forces_weight, dtype=torch.get_default_dtype()) - if config.forces_weight is not None - else 1 - ) - - stress_weight = ( - torch.tensor(config.stress_weight, dtype=torch.get_default_dtype()) - if config.stress_weight is not None - else 1 - ) - - virials_weight = ( - torch.tensor(config.virials_weight, dtype=torch.get_default_dtype()) - if config.virials_weight is not None - else 1 - ) - - forces = ( - torch.tensor(config.forces, dtype=torch.get_default_dtype()) - if config.forces is not None - else None - ) - energy = ( - torch.tensor(config.energy, dtype=torch.get_default_dtype()) - if config.energy is not None - else None - ) - stress = ( - voigt_to_matrix( - torch.tensor(config.stress, dtype=torch.get_default_dtype()) - ).unsqueeze(0) - if config.stress is not None - else None - ) - virials = ( - voigt_to_matrix( - torch.tensor(config.virials, dtype=torch.get_default_dtype()) - ).unsqueeze(0) - if config.virials is not None - else None - ) - dipole = ( - torch.tensor(config.dipole, dtype=torch.get_default_dtype()).unsqueeze(0) - if config.dipole is not None - else None - ) - charges = ( - torch.tensor(config.charges, dtype=torch.get_default_dtype()) - if config.charges is not None - else None - ) - - return cls( - edge_index=torch.tensor(edge_index, dtype=torch.long), - positions=torch.tensor(config.positions, dtype=torch.get_default_dtype()), - shifts=torch.tensor(shifts, dtype=torch.get_default_dtype()), - unit_shifts=torch.tensor(unit_shifts, dtype=torch.get_default_dtype()), - cell=cell, - node_attrs=one_hot, - weight=weight, - energy_weight=energy_weight, - forces_weight=forces_weight, - stress_weight=stress_weight, - virials_weight=virials_weight, - forces=forces, - energy=energy, - stress=stress, - virials=virials, - dipole=dipole, - charges=charges, - ) - - -def get_data_loader( - dataset: Sequence[AtomicData], - batch_size: int, - shuffle=True, - drop_last=False, -) -> torch.utils.data.DataLoader: - return torch_geometric.dataloader.DataLoader( - dataset=dataset, - 
batch_size=batch_size, - shuffle=shuffle, - drop_last=drop_last, - ) +# ########################################################################################### +# # Atomic Data Class for handling molecules as graphs +# # Authors: Ilyes Batatia, Gregor Simm +# # This program is distributed under the MIT License (see MIT.md) +# ########################################################################################### + +# from typing import Optional, Sequence + +# import torch.utils.data + +# from hydragnn.utils.mace_utils.tools import ( +# AtomicNumberTable, +# atomic_numbers_to_indices, +# to_one_hot, +# torch_geometric, +# voigt_to_matrix, +# ) + +# from .neighborhood import get_neighborhood +# from .utils import Configuration + + +# class AtomicData(torch_geometric.data.Data): +# num_graphs: torch.Tensor +# batch: torch.Tensor +# edge_index: torch.Tensor +# node_attrs: torch.Tensor +# edge_vectors: torch.Tensor +# edge_lengths: torch.Tensor +# positions: torch.Tensor +# shifts: torch.Tensor +# unit_shifts: torch.Tensor +# cell: torch.Tensor +# forces: torch.Tensor +# energy: torch.Tensor +# stress: torch.Tensor +# virials: torch.Tensor +# dipole: torch.Tensor +# charges: torch.Tensor +# weight: torch.Tensor +# energy_weight: torch.Tensor +# forces_weight: torch.Tensor +# stress_weight: torch.Tensor +# virials_weight: torch.Tensor + +# def __init__( +# self, +# edge_index: torch.Tensor, # [2, n_edges] +# node_attrs: torch.Tensor, # [n_nodes, n_node_feats] +# positions: torch.Tensor, # [n_nodes, 3] +# shifts: torch.Tensor, # [n_edges, 3], +# unit_shifts: torch.Tensor, # [n_edges, 3] +# cell: Optional[torch.Tensor], # [3,3] +# weight: Optional[torch.Tensor], # [,] +# energy_weight: Optional[torch.Tensor], # [,] +# forces_weight: Optional[torch.Tensor], # [,] +# stress_weight: Optional[torch.Tensor], # [,] +# virials_weight: Optional[torch.Tensor], # [,] +# forces: Optional[torch.Tensor], # [n_nodes, 3] +# energy: Optional[torch.Tensor], # [, ] +# stress: Optional[torch.Tensor], # [1,3,3] +# virials: Optional[torch.Tensor], # [1,3,3] +# dipole: Optional[torch.Tensor], # [, 3] +# charges: Optional[torch.Tensor], # [n_nodes, ] +# ): +# # Check shapes +# num_nodes = node_attrs.shape[0] + +# assert edge_index.shape[0] == 2 and len(edge_index.shape) == 2 +# assert positions.shape == (num_nodes, 3) +# assert shifts.shape[1] == 3 +# assert unit_shifts.shape[1] == 3 +# assert len(node_attrs.shape) == 2 +# assert weight is None or len(weight.shape) == 0 +# assert energy_weight is None or len(energy_weight.shape) == 0 +# assert forces_weight is None or len(forces_weight.shape) == 0 +# assert stress_weight is None or len(stress_weight.shape) == 0 +# assert virials_weight is None or len(virials_weight.shape) == 0 +# assert cell is None or cell.shape == (3, 3) +# assert forces is None or forces.shape == (num_nodes, 3) +# assert energy is None or len(energy.shape) == 0 +# assert stress is None or stress.shape == (1, 3, 3) +# assert virials is None or virials.shape == (1, 3, 3) +# assert dipole is None or dipole.shape[-1] == 3 +# assert charges is None or charges.shape == (num_nodes,) +# # Aggregate data +# data = { +# "num_nodes": num_nodes, +# "edge_index": edge_index, +# "positions": positions, +# "shifts": shifts, +# "unit_shifts": unit_shifts, +# "cell": cell, +# "node_attrs": node_attrs, +# "weight": weight, +# "energy_weight": energy_weight, +# "forces_weight": forces_weight, +# "stress_weight": stress_weight, +# "virials_weight": virials_weight, +# "forces": forces, +# "energy": energy, +# 
"stress": stress, +# "virials": virials, +# "dipole": dipole, +# "charges": charges, +# } +# super().__init__(**data) + +# @classmethod +# def from_config( +# cls, config: Configuration, z_table: AtomicNumberTable, cutoff: float +# ) -> "AtomicData": +# edge_index, shifts, unit_shifts = get_neighborhood( +# positions=config.positions, cutoff=cutoff, pbc=config.pbc, cell=config.cell +# ) +# indices = atomic_numbers_to_indices(config.atomic_numbers, z_table=z_table) +# one_hot = to_one_hot( +# torch.tensor(indices, dtype=torch.long).unsqueeze(-1), +# num_classes=len(z_table), +# ) + +# cell = ( +# torch.tensor(config.cell, dtype=torch.get_default_dtype()) +# if config.cell is not None +# else torch.tensor( +# 3 * [0.0, 0.0, 0.0], dtype=torch.get_default_dtype() +# ).view(3, 3) +# ) + +# weight = ( +# torch.tensor(config.weight, dtype=torch.get_default_dtype()) +# if config.weight is not None +# else 1 +# ) + +# energy_weight = ( +# torch.tensor(config.energy_weight, dtype=torch.get_default_dtype()) +# if config.energy_weight is not None +# else 1 +# ) + +# forces_weight = ( +# torch.tensor(config.forces_weight, dtype=torch.get_default_dtype()) +# if config.forces_weight is not None +# else 1 +# ) + +# stress_weight = ( +# torch.tensor(config.stress_weight, dtype=torch.get_default_dtype()) +# if config.stress_weight is not None +# else 1 +# ) + +# virials_weight = ( +# torch.tensor(config.virials_weight, dtype=torch.get_default_dtype()) +# if config.virials_weight is not None +# else 1 +# ) + +# forces = ( +# torch.tensor(config.forces, dtype=torch.get_default_dtype()) +# if config.forces is not None +# else None +# ) +# energy = ( +# torch.tensor(config.energy, dtype=torch.get_default_dtype()) +# if config.energy is not None +# else None +# ) +# stress = ( +# voigt_to_matrix( +# torch.tensor(config.stress, dtype=torch.get_default_dtype()) +# ).unsqueeze(0) +# if config.stress is not None +# else None +# ) +# virials = ( +# voigt_to_matrix( +# torch.tensor(config.virials, dtype=torch.get_default_dtype()) +# ).unsqueeze(0) +# if config.virials is not None +# else None +# ) +# dipole = ( +# torch.tensor(config.dipole, dtype=torch.get_default_dtype()).unsqueeze(0) +# if config.dipole is not None +# else None +# ) +# charges = ( +# torch.tensor(config.charges, dtype=torch.get_default_dtype()) +# if config.charges is not None +# else None +# ) + +# return cls( +# edge_index=torch.tensor(edge_index, dtype=torch.long), +# positions=torch.tensor(config.positions, dtype=torch.get_default_dtype()), +# shifts=torch.tensor(shifts, dtype=torch.get_default_dtype()), +# unit_shifts=torch.tensor(unit_shifts, dtype=torch.get_default_dtype()), +# cell=cell, +# node_attrs=one_hot, +# weight=weight, +# energy_weight=energy_weight, +# forces_weight=forces_weight, +# stress_weight=stress_weight, +# virials_weight=virials_weight, +# forces=forces, +# energy=energy, +# stress=stress, +# virials=virials, +# dipole=dipole, +# charges=charges, +# ) + + +# def get_data_loader( +# dataset: Sequence[AtomicData], +# batch_size: int, +# shuffle=True, +# drop_last=False, +# ) -> torch.utils.data.DataLoader: +# return torch_geometric.dataloader.DataLoader( +# dataset=dataset, +# batch_size=batch_size, +# shuffle=shuffle, +# drop_last=drop_last, +# ) diff --git a/hydragnn/utils/mace_utils/data/hdf5_dataset.py b/hydragnn/utils/mace_utils/data/hdf5_dataset.py index affa6a8d5..f617e02f5 100644 --- a/hydragnn/utils/mace_utils/data/hdf5_dataset.py +++ b/hydragnn/utils/mace_utils/data/hdf5_dataset.py @@ -1,91 +1,91 @@ -from glob 
import glob -from typing import List +# from glob import glob +# from typing import List -from torch.utils.data import ConcatDataset, Dataset +# from torch.utils.data import ConcatDataset, Dataset -# Try import but pass otherwise -try: - import h5py -except ImportError: - pass +# # Try import but pass otherwise +# try: +# import h5py +# except ImportError: +# pass -from hydragnn.utils.mace_utils.data.atomic_data import AtomicData -from hydragnn.utils.mace_utils.data.utils import Configuration -from hydragnn.utils.mace_utils.tools.utils import AtomicNumberTable +# from hydragnn.utils.mace_utils.data.atomic_data import AtomicData +# from hydragnn.utils.mace_utils.data.utils import Configuration +# from hydragnn.utils.mace_utils.tools.utils import AtomicNumberTable -class HDF5Dataset(Dataset): - def __init__(self, file_path, r_max, z_table, **kwargs): - super(HDF5Dataset, self).__init__() # pylint: disable=super-with-arguments - self.file_path = file_path - self._file = None - batch_key = list(self.file.keys())[0] - self.batch_size = len(self.file[batch_key].keys()) - self.length = len(self.file.keys()) * self.batch_size - self.r_max = r_max - self.z_table = z_table - try: - self.drop_last = bool(self.file.attrs["drop_last"]) - except KeyError: - self.drop_last = False - self.kwargs = kwargs +# class HDF5Dataset(Dataset): +# def __init__(self, file_path, r_max, z_table, **kwargs): +# super(HDF5Dataset, self).__init__() # pylint: disable=super-with-arguments +# self.file_path = file_path +# self._file = None +# batch_key = list(self.file.keys())[0] +# self.batch_size = len(self.file[batch_key].keys()) +# self.length = len(self.file.keys()) * self.batch_size +# self.r_max = r_max +# self.z_table = z_table +# try: +# self.drop_last = bool(self.file.attrs["drop_last"]) +# except KeyError: +# self.drop_last = False +# self.kwargs = kwargs - @property - def file(self): - if self._file is None: - # If a file has not already been opened, open one here - self._file = h5py.File(self.file_path, "r") - return self._file +# @property +# def file(self): +# if self._file is None: +# # If a file has not already been opened, open one here +# self._file = h5py.File(self.file_path, "r") +# return self._file - def __getstate__(self): - _d = dict(self.__dict__) +# def __getstate__(self): +# _d = dict(self.__dict__) - # An opened h5py.File cannot be pickled, so we must exclude it from the state - _d["_file"] = None - return _d +# # An opened h5py.File cannot be pickled, so we must exclude it from the state +# _d["_file"] = None +# return _d - def __len__(self): - return self.length +# def __len__(self): +# return self.length - def __getitem__(self, index): - # compute the index of the batch - batch_index = index // self.batch_size - config_index = index % self.batch_size - grp = self.file["config_batch_" + str(batch_index)] - subgrp = grp["config_" + str(config_index)] - config = Configuration( - atomic_numbers=subgrp["atomic_numbers"][()], - positions=subgrp["positions"][()], - energy=unpack_value(subgrp["energy"][()]), - forces=unpack_value(subgrp["forces"][()]), - stress=unpack_value(subgrp["stress"][()]), - virials=unpack_value(subgrp["virials"][()]), - dipole=unpack_value(subgrp["dipole"][()]), - charges=unpack_value(subgrp["charges"][()]), - weight=unpack_value(subgrp["weight"][()]), - energy_weight=unpack_value(subgrp["energy_weight"][()]), - forces_weight=unpack_value(subgrp["forces_weight"][()]), - stress_weight=unpack_value(subgrp["stress_weight"][()]), - 
virials_weight=unpack_value(subgrp["virials_weight"][()]), - config_type=unpack_value(subgrp["config_type"][()]), - pbc=unpack_value(subgrp["pbc"][()]), - cell=unpack_value(subgrp["cell"][()]), - ) - atomic_data = AtomicData.from_config( - config, z_table=self.z_table, cutoff=self.r_max - ) - return atomic_data +# def __getitem__(self, index): +# # compute the index of the batch +# batch_index = index // self.batch_size +# config_index = index % self.batch_size +# grp = self.file["config_batch_" + str(batch_index)] +# subgrp = grp["config_" + str(config_index)] +# config = Configuration( +# atomic_numbers=subgrp["atomic_numbers"][()], +# positions=subgrp["positions"][()], +# energy=unpack_value(subgrp["energy"][()]), +# forces=unpack_value(subgrp["forces"][()]), +# stress=unpack_value(subgrp["stress"][()]), +# virials=unpack_value(subgrp["virials"][()]), +# dipole=unpack_value(subgrp["dipole"][()]), +# charges=unpack_value(subgrp["charges"][()]), +# weight=unpack_value(subgrp["weight"][()]), +# energy_weight=unpack_value(subgrp["energy_weight"][()]), +# forces_weight=unpack_value(subgrp["forces_weight"][()]), +# stress_weight=unpack_value(subgrp["stress_weight"][()]), +# virials_weight=unpack_value(subgrp["virials_weight"][()]), +# config_type=unpack_value(subgrp["config_type"][()]), +# pbc=unpack_value(subgrp["pbc"][()]), +# cell=unpack_value(subgrp["cell"][()]), +# ) +# atomic_data = AtomicData.from_config( +# config, z_table=self.z_table, cutoff=self.r_max +# ) +# return atomic_data -def dataset_from_sharded_hdf5(files: List, z_table: AtomicNumberTable, r_max: float): - files = glob(files + "/*") - datasets = [] - for file in files: - datasets.append(HDF5Dataset(file, z_table=z_table, r_max=r_max)) - full_dataset = ConcatDataset(datasets) - return full_dataset +# def dataset_from_sharded_hdf5(files: List, z_table: AtomicNumberTable, r_max: float): +# files = glob(files + "/*") +# datasets = [] +# for file in files: +# datasets.append(HDF5Dataset(file, z_table=z_table, r_max=r_max)) +# full_dataset = ConcatDataset(datasets) +# return full_dataset -def unpack_value(value): - value = value.decode("utf-8") if isinstance(value, bytes) else value - return None if str(value) == "None" else value +# def unpack_value(value): +# value = value.decode("utf-8") if isinstance(value, bytes) else value +# return None if str(value) == "None" else value diff --git a/hydragnn/utils/mace_utils/data/neighborhood.py b/hydragnn/utils/mace_utils/data/neighborhood.py index 293576af4..5bd70b6eb 100644 --- a/hydragnn/utils/mace_utils/data/neighborhood.py +++ b/hydragnn/utils/mace_utils/data/neighborhood.py @@ -1,66 +1,66 @@ -from typing import Optional, Tuple +# from typing import Optional, Tuple -import numpy as np -from matscipy.neighbours import neighbour_list +# import numpy as np +# from matscipy.neighbours import neighbour_list -def get_neighborhood( - positions: np.ndarray, # [num_positions, 3] - cutoff: float, - pbc: Optional[Tuple[bool, bool, bool]] = None, - cell: Optional[np.ndarray] = None, # [3, 3] - true_self_interaction=False, -) -> Tuple[np.ndarray, np.ndarray]: - if pbc is None: - pbc = (False, False, False) +# def get_neighborhood( +# positions: np.ndarray, # [num_positions, 3] +# cutoff: float, +# pbc: Optional[Tuple[bool, bool, bool]] = None, +# cell: Optional[np.ndarray] = None, # [3, 3] +# true_self_interaction=False, +# ) -> Tuple[np.ndarray, np.ndarray]: +# if pbc is None: +# pbc = (False, False, False) - if cell is None or cell.any() == np.zeros((3, 3)).any(): - cell = np.identity(3, 
dtype=float) +# if cell is None or cell.any() == np.zeros((3, 3)).any(): +# cell = np.identity(3, dtype=float) - assert len(pbc) == 3 and all(isinstance(i, (bool, np.bool_)) for i in pbc) - assert cell.shape == (3, 3) +# assert len(pbc) == 3 and all(isinstance(i, (bool, np.bool_)) for i in pbc) +# assert cell.shape == (3, 3) - pbc_x = pbc[0] - pbc_y = pbc[1] - pbc_z = pbc[2] - identity = np.identity(3, dtype=float) - max_positions = np.max(np.absolute(positions)) + 1 - # Extend cell in non-periodic directions - # For models with more than 5 layers, the multiplicative constant needs to be increased. - temp_cell = np.copy(cell) - if not pbc_x: - temp_cell[0, :] = max_positions * 5 * cutoff * identity[0, :] - if not pbc_y: - temp_cell[1, :] = max_positions * 5 * cutoff * identity[1, :] - if not pbc_z: - temp_cell[2, :] = max_positions * 5 * cutoff * identity[2, :] +# pbc_x = pbc[0] +# pbc_y = pbc[1] +# pbc_z = pbc[2] +# identity = np.identity(3, dtype=float) +# max_positions = np.max(np.absolute(positions)) + 1 +# # Extend cell in non-periodic directions +# # For models with more than 5 layers, the multiplicative constant needs to be increased. +# temp_cell = np.copy(cell) +# if not pbc_x: +# temp_cell[0, :] = max_positions * 5 * cutoff * identity[0, :] +# if not pbc_y: +# temp_cell[1, :] = max_positions * 5 * cutoff * identity[1, :] +# if not pbc_z: +# temp_cell[2, :] = max_positions * 5 * cutoff * identity[2, :] - sender, receiver, unit_shifts = neighbour_list( - quantities="ijS", - pbc=pbc, - cell=temp_cell, - positions=positions, - cutoff=cutoff, - # self_interaction=True, # we want edges from atom to itself in different periodic images - # use_scaled_positions=False, # positions are not scaled positions - ) +# sender, receiver, unit_shifts = neighbour_list( +# quantities="ijS", +# pbc=pbc, +# cell=temp_cell, +# positions=positions, +# cutoff=cutoff, +# # self_interaction=True, # we want edges from atom to itself in different periodic images +# # use_scaled_positions=False, # positions are not scaled positions +# ) - if not true_self_interaction: - # Eliminate self-edges that don't cross periodic boundaries - true_self_edge = sender == receiver - true_self_edge &= np.all(unit_shifts == 0, axis=1) - keep_edge = ~true_self_edge +# if not true_self_interaction: +# # Eliminate self-edges that don't cross periodic boundaries +# true_self_edge = sender == receiver +# true_self_edge &= np.all(unit_shifts == 0, axis=1) +# keep_edge = ~true_self_edge - # Note: after eliminating self-edges, it can be that no edges remain in this system - sender = sender[keep_edge] - receiver = receiver[keep_edge] - unit_shifts = unit_shifts[keep_edge] +# # Note: after eliminating self-edges, it can be that no edges remain in this system +# sender = sender[keep_edge] +# receiver = receiver[keep_edge] +# unit_shifts = unit_shifts[keep_edge] - # Build output - edge_index = np.stack((sender, receiver)) # [2, n_edges] +# # Build output +# edge_index = np.stack((sender, receiver)) # [2, n_edges] - # From the docs: With the shift vector S, the distances D between atoms can be computed from - # D = positions[j]-positions[i]+S.dot(cell) - shifts = np.dot(unit_shifts, cell) # [n_edges, 3] +# # From the docs: With the shift vector S, the distances D between atoms can be computed from +# # D = positions[j]-positions[i]+S.dot(cell) +# shifts = np.dot(unit_shifts, cell) # [n_edges, 3] - return edge_index, shifts, unit_shifts +# return edge_index, shifts, unit_shifts diff --git a/hydragnn/utils/mace_utils/data/utils.py 
b/hydragnn/utils/mace_utils/data/utils.py index 6458e7107..0eb3cd187 100644 --- a/hydragnn/utils/mace_utils/data/utils.py +++ b/hydragnn/utils/mace_utils/data/utils.py @@ -1,398 +1,398 @@ -########################################################################################### -# Data parsing utilities -# Authors: Ilyes Batatia, Gregor Simm and David Kovacs -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import logging -from dataclasses import dataclass -from typing import Dict, List, Optional, Sequence, Tuple - -import ase.data -import ase.io -import numpy as np - -# Try import but pass otherwise -try: - import h5py -except ImportError: - pass - -from hydragnn.utils.mace_utils.tools import AtomicNumberTable - -Vector = np.ndarray # [3,] -Positions = np.ndarray # [..., 3] -Forces = np.ndarray # [..., 3] -Stress = np.ndarray # [6, ], [3,3], [9, ] -Virials = np.ndarray # [6, ], [3,3], [9, ] -Charges = np.ndarray # [..., 1] -Cell = np.ndarray # [3,3] -Pbc = tuple # (3,) - -DEFAULT_CONFIG_TYPE = "Default" -DEFAULT_CONFIG_TYPE_WEIGHTS = {DEFAULT_CONFIG_TYPE: 1.0} - - -@dataclass -class Configuration: - atomic_numbers: np.ndarray - positions: Positions # Angstrom - energy: Optional[float] = None # eV - forces: Optional[Forces] = None # eV/Angstrom - stress: Optional[Stress] = None # eV/Angstrom^3 - virials: Optional[Virials] = None # eV - dipole: Optional[Vector] = None # Debye - charges: Optional[Charges] = None # atomic unit - cell: Optional[Cell] = None - pbc: Optional[Pbc] = None - - weight: float = 1.0 # weight of config in loss - energy_weight: float = 1.0 # weight of config energy in loss - forces_weight: float = 1.0 # weight of config forces in loss - stress_weight: float = 1.0 # weight of config stress in loss - virials_weight: float = 1.0 # weight of config virial in loss - config_type: Optional[str] = DEFAULT_CONFIG_TYPE # config_type of config - - -Configurations = List[Configuration] - - -def random_train_valid_split( - items: Sequence, valid_fraction: float, seed: int, work_dir: str -) -> Tuple[List, List]: - assert 0.0 < valid_fraction < 1.0 - - size = len(items) - train_size = size - int(valid_fraction * size) - - indices = list(range(size)) - rng = np.random.default_rng(seed) - rng.shuffle(indices) - if len(indices[train_size:]) < 10: - logging.info( - f"Using random {100 * valid_fraction:.0f}% of training set for validation with following indices: {indices[train_size:]}" - ) - else: - # Save indices to file - with open(work_dir + f"/valid_indices_{seed}.txt", "w", encoding="utf-8") as f: - for index in indices[train_size:]: - f.write(f"{index}\n") - - logging.info( - f"Using random {100 * valid_fraction:.0f}% of training set for validation with indices saved in: {work_dir}/valid_indices_{seed}.txt" - ) - - return ( - [items[i] for i in indices[:train_size]], - [items[i] for i in indices[train_size:]], - ) - - -def config_from_atoms_list( - atoms_list: List[ase.Atoms], - energy_key="REF_energy", - forces_key="REF_forces", - stress_key="REF_stress", - virials_key="REF_virials", - dipole_key="REF_dipole", - charges_key="REF_charges", - config_type_weights: Dict[str, float] = None, -) -> Configurations: - """Convert list of ase.Atoms into Configurations""" - if config_type_weights is None: - config_type_weights = DEFAULT_CONFIG_TYPE_WEIGHTS - - all_configs = [] - for atoms in atoms_list: - all_configs.append( - config_from_atoms( - atoms, - energy_key=energy_key, - 
forces_key=forces_key, - stress_key=stress_key, - virials_key=virials_key, - dipole_key=dipole_key, - charges_key=charges_key, - config_type_weights=config_type_weights, - ) - ) - return all_configs - - -def config_from_atoms( - atoms: ase.Atoms, - energy_key="REF_energy", - forces_key="REF_forces", - stress_key="REF_stress", - virials_key="REF_virials", - dipole_key="REF_dipole", - charges_key="REF_charges", - config_type_weights: Dict[str, float] = None, -) -> Configuration: - """Convert ase.Atoms to Configuration""" - if config_type_weights is None: - config_type_weights = DEFAULT_CONFIG_TYPE_WEIGHTS - - energy = atoms.info.get(energy_key, None) # eV - forces = atoms.arrays.get(forces_key, None) # eV / Ang - stress = atoms.info.get(stress_key, None) # eV / Ang ^ 3 - virials = atoms.info.get(virials_key, None) - dipole = atoms.info.get(dipole_key, None) # Debye - # Charges default to 0 instead of None if not found - charges = atoms.arrays.get(charges_key, np.zeros(len(atoms))) # atomic unit - atomic_numbers = np.array( - [ase.data.atomic_numbers[symbol] for symbol in atoms.symbols] - ) - pbc = tuple(atoms.get_pbc()) - cell = np.array(atoms.get_cell()) - config_type = atoms.info.get("config_type", "Default") - weight = atoms.info.get("config_weight", 1.0) * config_type_weights.get( - config_type, 1.0 - ) - energy_weight = atoms.info.get("config_energy_weight", 1.0) - forces_weight = atoms.info.get("config_forces_weight", 1.0) - stress_weight = atoms.info.get("config_stress_weight", 1.0) - virials_weight = atoms.info.get("config_virials_weight", 1.0) - - # fill in missing quantities but set their weight to 0.0 - if energy is None: - energy = 0.0 - energy_weight = 0.0 - if forces is None: - forces = np.zeros(np.shape(atoms.positions)) - forces_weight = 0.0 - if stress is None: - stress = np.zeros(6) - stress_weight = 0.0 - if virials is None: - virials = np.zeros((3, 3)) - virials_weight = 0.0 - if dipole is None: - dipole = np.zeros(3) - # dipoles_weight = 0.0 - - return Configuration( - atomic_numbers=atomic_numbers, - positions=atoms.get_positions(), - energy=energy, - forces=forces, - stress=stress, - virials=virials, - dipole=dipole, - charges=charges, - weight=weight, - energy_weight=energy_weight, - forces_weight=forces_weight, - stress_weight=stress_weight, - virials_weight=virials_weight, - config_type=config_type, - pbc=pbc, - cell=cell, - ) - - -def test_config_types( - test_configs: Configurations, -) -> List[Tuple[Optional[str], List[Configuration]]]: - """Split test set based on config_type-s""" - test_by_ct = [] - all_cts = [] - for conf in test_configs: - if conf.config_type not in all_cts: - all_cts.append(conf.config_type) - test_by_ct.append((conf.config_type, [conf])) - else: - ind = all_cts.index(conf.config_type) - test_by_ct[ind][1].append(conf) - return test_by_ct - - -def load_from_xyz( - file_path: str, - config_type_weights: Dict, - energy_key: str = "REF_energy", - forces_key: str = "REF_forces", - stress_key: str = "REF_stress", - virials_key: str = "REF_virials", - dipole_key: str = "REF_dipole", - charges_key: str = "REF_charges", - extract_atomic_energies: bool = False, - keep_isolated_atoms: bool = False, -) -> Tuple[Dict[int, float], Configurations]: - atoms_list = ase.io.read(file_path, index=":") - if energy_key == "energy": - logging.warning( - "Since ASE version 3.23.0b1, using energy_key 'energy' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'energy' to 'REF_energy'. 
You need to use --energy_key='REF_energy' to specify the chosen key name." - ) - energy_key = "REF_energy" - for atoms in atoms_list: - try: - atoms.info["REF_energy"] = atoms.get_potential_energy() - except Exception as e: # pylint: disable=W0703 - logging.error(f"Failed to extract energy: {e}") - atoms.info["REF_energy"] = None - if forces_key == "forces": - logging.warning( - "Since ASE version 3.23.0b1, using forces_key 'forces' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'forces' to 'REF_forces'. You need to use --forces_key='REF_forces' to specify the chosen key name." - ) - forces_key = "REF_forces" - for atoms in atoms_list: - try: - atoms.arrays["REF_forces"] = atoms.get_forces() - except Exception as e: # pylint: disable=W0703 - logging.error(f"Failed to extract forces: {e}") - atoms.arrays["REF_forces"] = None - if stress_key == "stress": - logging.warning( - "Since ASE version 3.23.0b1, using stress_key 'stress' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'stress' to 'REF_stress'. You need to use --stress_key='REF_stress' to specify the chosen key name." - ) - stress_key = "REF_stress" - for atoms in atoms_list: - try: - atoms.info["REF_stress"] = atoms.get_stress() - except Exception as e: # pylint: disable=W0703 - atoms.info["REF_stress"] = None - if not isinstance(atoms_list, list): - atoms_list = [atoms_list] - - atomic_energies_dict = {} - if extract_atomic_energies: - atoms_without_iso_atoms = [] - - for idx, atoms in enumerate(atoms_list): - isolated_atom_config = ( - len(atoms) == 1 and atoms.info.get("config_type") == "IsolatedAtom" - ) - if isolated_atom_config: - if energy_key in atoms.info.keys(): - atomic_energies_dict[atoms.get_atomic_numbers()[0]] = atoms.info[ - energy_key - ] - else: - logging.warning( - f"Configuration '{idx}' is marked as 'IsolatedAtom' " - "but does not contain an energy. Zero energy will be used." 
- ) - atomic_energies_dict[atoms.get_atomic_numbers()[0]] = np.zeros(1) - else: - atoms_without_iso_atoms.append(atoms) - - if len(atomic_energies_dict) > 0: - logging.info("Using isolated atom energies from training file") - if not keep_isolated_atoms: - atoms_list = atoms_without_iso_atoms - - configs = config_from_atoms_list( - atoms_list, - config_type_weights=config_type_weights, - energy_key=energy_key, - forces_key=forces_key, - stress_key=stress_key, - virials_key=virials_key, - dipole_key=dipole_key, - charges_key=charges_key, - ) - return atomic_energies_dict, configs - - -def compute_average_E0s( - collections_train: Configurations, z_table: AtomicNumberTable -) -> Dict[int, float]: - """ - Function to compute the average interaction energy of each chemical element - returns dictionary of E0s - """ - len_train = len(collections_train) - len_zs = len(z_table) - A = np.zeros((len_train, len_zs)) - B = np.zeros(len_train) - for i in range(len_train): - B[i] = collections_train[i].energy - for j, z in enumerate(z_table.zs): - A[i, j] = np.count_nonzero(collections_train[i].atomic_numbers == z) - try: - E0s = np.linalg.lstsq(A, B, rcond=None)[0] - atomic_energies_dict = {} - for i, z in enumerate(z_table.zs): - atomic_energies_dict[z] = E0s[i] - except np.linalg.LinAlgError: - logging.error( - "Failed to compute E0s using least squares regression, using the same for all atoms" - ) - atomic_energies_dict = {} - for i, z in enumerate(z_table.zs): - atomic_energies_dict[z] = 0.0 - return atomic_energies_dict - - -def save_dataset_as_HDF5(dataset: List, out_name: str) -> None: - with h5py.File(out_name, "w") as f: - for i, data in enumerate(dataset): - grp = f.create_group(f"config_{i}") - grp["num_nodes"] = data.num_nodes - grp["edge_index"] = data.edge_index - grp["positions"] = data.positions - grp["shifts"] = data.shifts - grp["unit_shifts"] = data.unit_shifts - grp["cell"] = data.cell - grp["node_attrs"] = data.node_attrs - grp["weight"] = data.weight - grp["energy_weight"] = data.energy_weight - grp["forces_weight"] = data.forces_weight - grp["stress_weight"] = data.stress_weight - grp["virials_weight"] = data.virials_weight - grp["forces"] = data.forces - grp["energy"] = data.energy - grp["stress"] = data.stress - grp["virials"] = data.virials - grp["dipole"] = data.dipole - grp["charges"] = data.charges - - -def save_AtomicData_to_HDF5(data, i, h5_file) -> None: - grp = h5_file.create_group(f"config_{i}") - grp["num_nodes"] = data.num_nodes - grp["edge_index"] = data.edge_index - grp["positions"] = data.positions - grp["shifts"] = data.shifts - grp["unit_shifts"] = data.unit_shifts - grp["cell"] = data.cell - grp["node_attrs"] = data.node_attrs - grp["weight"] = data.weight - grp["energy_weight"] = data.energy_weight - grp["forces_weight"] = data.forces_weight - grp["stress_weight"] = data.stress_weight - grp["virials_weight"] = data.virials_weight - grp["forces"] = data.forces - grp["energy"] = data.energy - grp["stress"] = data.stress - grp["virials"] = data.virials - grp["dipole"] = data.dipole - grp["charges"] = data.charges - - -def save_configurations_as_HDF5(configurations: Configurations, _, h5_file) -> None: - grp = h5_file.create_group("config_batch_0") - for j, config in enumerate(configurations): - subgroup_name = f"config_{j}" - subgroup = grp.create_group(subgroup_name) - subgroup["atomic_numbers"] = write_value(config.atomic_numbers) - subgroup["positions"] = write_value(config.positions) - subgroup["energy"] = write_value(config.energy) - subgroup["forces"] = 
write_value(config.forces) -        subgroup["stress"] = write_value(config.stress) -        subgroup["virials"] = write_value(config.virials) -        subgroup["dipole"] = write_value(config.dipole) -        subgroup["charges"] = write_value(config.charges) -        subgroup["cell"] = write_value(config.cell) -        subgroup["pbc"] = write_value(config.pbc) -        subgroup["weight"] = write_value(config.weight) -        subgroup["energy_weight"] = write_value(config.energy_weight) -        subgroup["forces_weight"] = write_value(config.forces_weight) -        subgroup["stress_weight"] = write_value(config.stress_weight) -        subgroup["virials_weight"] = write_value(config.virials_weight) -        subgroup["config_type"] = write_value(config.config_type) - - -def write_value(value): -    return value if value is not None else "None" +# ########################################################################################### +# # Data parsing utilities +# # Authors: Ilyes Batatia, Gregor Simm and David Kovacs +# # This program is distributed under the MIT License (see MIT.md) +# ########################################################################################### + +# import logging +# from dataclasses import dataclass +# from typing import Dict, List, Optional, Sequence, Tuple + +# import ase.data +# import ase.io +# import numpy as np + +# # Try to import h5py, but pass silently if it is unavailable +# try: +#     import h5py +# except ImportError: +#     pass + +# from hydragnn.utils.mace_utils.tools import AtomicNumberTable + +# Vector = np.ndarray  # [3,] +# Positions = np.ndarray  # [..., 3] +# Forces = np.ndarray  # [..., 3] +# Stress = np.ndarray  # [6, ], [3,3], [9, ] +# Virials = np.ndarray  # [6, ], [3,3], [9, ] +# Charges = np.ndarray  # [..., 1] +# Cell = np.ndarray  # [3,3] +# Pbc = tuple  # (3,) + +# DEFAULT_CONFIG_TYPE = "Default" +# DEFAULT_CONFIG_TYPE_WEIGHTS = {DEFAULT_CONFIG_TYPE: 1.0} + + +# @dataclass +# class Configuration: +#     atomic_numbers: np.ndarray +#     positions: Positions  # Angstrom +#     energy: Optional[float] = None  # eV +#     forces: Optional[Forces] = None  # eV/Angstrom +#     stress: Optional[Stress] = None  # eV/Angstrom^3 +#     virials: Optional[Virials] = None  # eV +#     dipole: Optional[Vector] = None  # Debye +#     charges: Optional[Charges] = None  # atomic unit +#     cell: Optional[Cell] = None +#     pbc: Optional[Pbc] = None + +#     weight: float = 1.0  # weight of config in loss +#     energy_weight: float = 1.0  # weight of config energy in loss +#     forces_weight: float = 1.0  # weight of config forces in loss +#     stress_weight: float = 1.0  # weight of config stress in loss +#     virials_weight: float = 1.0  # weight of config virial in loss +#     config_type: Optional[str] = DEFAULT_CONFIG_TYPE  # config_type of config + + +# Configurations = List[Configuration] + + +# def random_train_valid_split( +#     items: Sequence, valid_fraction: float, seed: int, work_dir: str +# ) -> Tuple[List, List]: +#     assert 0.0 < valid_fraction < 1.0 + +#     size = len(items) +#     train_size = size - int(valid_fraction * size) + +#     indices = list(range(size)) +#     rng = np.random.default_rng(seed) +#     rng.shuffle(indices) +#     if len(indices[train_size:]) < 10: +#         logging.info( +#             f"Using random {100 * valid_fraction:.0f}% of training set for validation with the following indices: {indices[train_size:]}" +#         ) +#     else: +#         # Save indices to file +#         with open(work_dir + f"/valid_indices_{seed}.txt", "w", encoding="utf-8") as f: +#             for index in indices[train_size:]: +#                 f.write(f"{index}\n") + +#         logging.info( +#             f"Using random {100 * valid_fraction:.0f}% of training set for validation with indices saved in: {work_dir}/valid_indices_{seed}.txt" +#
) + +# return ( +# [items[i] for i in indices[:train_size]], +# [items[i] for i in indices[train_size:]], +# ) + + +# def config_from_atoms_list( +# atoms_list: List[ase.Atoms], +# energy_key="REF_energy", +# forces_key="REF_forces", +# stress_key="REF_stress", +# virials_key="REF_virials", +# dipole_key="REF_dipole", +# charges_key="REF_charges", +# config_type_weights: Dict[str, float] = None, +# ) -> Configurations: +# """Convert list of ase.Atoms into Configurations""" +# if config_type_weights is None: +# config_type_weights = DEFAULT_CONFIG_TYPE_WEIGHTS + +# all_configs = [] +# for atoms in atoms_list: +# all_configs.append( +# config_from_atoms( +# atoms, +# energy_key=energy_key, +# forces_key=forces_key, +# stress_key=stress_key, +# virials_key=virials_key, +# dipole_key=dipole_key, +# charges_key=charges_key, +# config_type_weights=config_type_weights, +# ) +# ) +# return all_configs + + +# def config_from_atoms( +# atoms: ase.Atoms, +# energy_key="REF_energy", +# forces_key="REF_forces", +# stress_key="REF_stress", +# virials_key="REF_virials", +# dipole_key="REF_dipole", +# charges_key="REF_charges", +# config_type_weights: Dict[str, float] = None, +# ) -> Configuration: +# """Convert ase.Atoms to Configuration""" +# if config_type_weights is None: +# config_type_weights = DEFAULT_CONFIG_TYPE_WEIGHTS + +# energy = atoms.info.get(energy_key, None) # eV +# forces = atoms.arrays.get(forces_key, None) # eV / Ang +# stress = atoms.info.get(stress_key, None) # eV / Ang ^ 3 +# virials = atoms.info.get(virials_key, None) +# dipole = atoms.info.get(dipole_key, None) # Debye +# # Charges default to 0 instead of None if not found +# charges = atoms.arrays.get(charges_key, np.zeros(len(atoms))) # atomic unit +# atomic_numbers = np.array( +# [ase.data.atomic_numbers[symbol] for symbol in atoms.symbols] +# ) +# pbc = tuple(atoms.get_pbc()) +# cell = np.array(atoms.get_cell()) +# config_type = atoms.info.get("config_type", "Default") +# weight = atoms.info.get("config_weight", 1.0) * config_type_weights.get( +# config_type, 1.0 +# ) +# energy_weight = atoms.info.get("config_energy_weight", 1.0) +# forces_weight = atoms.info.get("config_forces_weight", 1.0) +# stress_weight = atoms.info.get("config_stress_weight", 1.0) +# virials_weight = atoms.info.get("config_virials_weight", 1.0) + +# # fill in missing quantities but set their weight to 0.0 +# if energy is None: +# energy = 0.0 +# energy_weight = 0.0 +# if forces is None: +# forces = np.zeros(np.shape(atoms.positions)) +# forces_weight = 0.0 +# if stress is None: +# stress = np.zeros(6) +# stress_weight = 0.0 +# if virials is None: +# virials = np.zeros((3, 3)) +# virials_weight = 0.0 +# if dipole is None: +# dipole = np.zeros(3) +# # dipoles_weight = 0.0 + +# return Configuration( +# atomic_numbers=atomic_numbers, +# positions=atoms.get_positions(), +# energy=energy, +# forces=forces, +# stress=stress, +# virials=virials, +# dipole=dipole, +# charges=charges, +# weight=weight, +# energy_weight=energy_weight, +# forces_weight=forces_weight, +# stress_weight=stress_weight, +# virials_weight=virials_weight, +# config_type=config_type, +# pbc=pbc, +# cell=cell, +# ) + + +# def test_config_types( +# test_configs: Configurations, +# ) -> List[Tuple[Optional[str], List[Configuration]]]: +# """Split test set based on config_type-s""" +# test_by_ct = [] +# all_cts = [] +# for conf in test_configs: +# if conf.config_type not in all_cts: +# all_cts.append(conf.config_type) +# test_by_ct.append((conf.config_type, [conf])) +# else: +# ind = 
all_cts.index(conf.config_type) +# test_by_ct[ind][1].append(conf) +# return test_by_ct + + +# def load_from_xyz( +# file_path: str, +# config_type_weights: Dict, +# energy_key: str = "REF_energy", +# forces_key: str = "REF_forces", +# stress_key: str = "REF_stress", +# virials_key: str = "REF_virials", +# dipole_key: str = "REF_dipole", +# charges_key: str = "REF_charges", +# extract_atomic_energies: bool = False, +# keep_isolated_atoms: bool = False, +# ) -> Tuple[Dict[int, float], Configurations]: +# atoms_list = ase.io.read(file_path, index=":") +# if energy_key == "energy": +# logging.warning( +# "Since ASE version 3.23.0b1, using energy_key 'energy' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'energy' to 'REF_energy'. You need to use --energy_key='REF_energy' to specify the chosen key name." +# ) +# energy_key = "REF_energy" +# for atoms in atoms_list: +# try: +# atoms.info["REF_energy"] = atoms.get_potential_energy() +# except Exception as e: # pylint: disable=W0703 +# logging.error(f"Failed to extract energy: {e}") +# atoms.info["REF_energy"] = None +# if forces_key == "forces": +# logging.warning( +# "Since ASE version 3.23.0b1, using forces_key 'forces' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'forces' to 'REF_forces'. You need to use --forces_key='REF_forces' to specify the chosen key name." +# ) +# forces_key = "REF_forces" +# for atoms in atoms_list: +# try: +# atoms.arrays["REF_forces"] = atoms.get_forces() +# except Exception as e: # pylint: disable=W0703 +# logging.error(f"Failed to extract forces: {e}") +# atoms.arrays["REF_forces"] = None +# if stress_key == "stress": +# logging.warning( +# "Since ASE version 3.23.0b1, using stress_key 'stress' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'stress' to 'REF_stress'. You need to use --stress_key='REF_stress' to specify the chosen key name." +# ) +# stress_key = "REF_stress" +# for atoms in atoms_list: +# try: +# atoms.info["REF_stress"] = atoms.get_stress() +# except Exception as e: # pylint: disable=W0703 +# atoms.info["REF_stress"] = None +# if not isinstance(atoms_list, list): +# atoms_list = [atoms_list] + +# atomic_energies_dict = {} +# if extract_atomic_energies: +# atoms_without_iso_atoms = [] + +# for idx, atoms in enumerate(atoms_list): +# isolated_atom_config = ( +# len(atoms) == 1 and atoms.info.get("config_type") == "IsolatedAtom" +# ) +# if isolated_atom_config: +# if energy_key in atoms.info.keys(): +# atomic_energies_dict[atoms.get_atomic_numbers()[0]] = atoms.info[ +# energy_key +# ] +# else: +# logging.warning( +# f"Configuration '{idx}' is marked as 'IsolatedAtom' " +# "but does not contain an energy. Zero energy will be used." 
+# ) +#                     atomic_energies_dict[atoms.get_atomic_numbers()[0]] = np.zeros(1) +#             else: +#                 atoms_without_iso_atoms.append(atoms) + +#         if len(atomic_energies_dict) > 0: +#             logging.info("Using isolated atom energies from training file") +#         if not keep_isolated_atoms: +#             atoms_list = atoms_without_iso_atoms + +#     configs = config_from_atoms_list( +#         atoms_list, +#         config_type_weights=config_type_weights, +#         energy_key=energy_key, +#         forces_key=forces_key, +#         stress_key=stress_key, +#         virials_key=virials_key, +#         dipole_key=dipole_key, +#         charges_key=charges_key, +#     ) +#     return atomic_energies_dict, configs + + +# def compute_average_E0s( +#     collections_train: Configurations, z_table: AtomicNumberTable +# ) -> Dict[int, float]: +#     """ +#     Compute the average interaction energy (E0) of each chemical element; +#     returns a dictionary of E0s. Solves the least-squares system A @ E0s = B, +#     where A[i, j] counts the atoms of element z_j in configuration i and B[i] +#     is the total energy of configuration i. +#     """ +#     len_train = len(collections_train) +#     len_zs = len(z_table) +#     A = np.zeros((len_train, len_zs)) +#     B = np.zeros(len_train) +#     for i in range(len_train): +#         B[i] = collections_train[i].energy +#         for j, z in enumerate(z_table.zs): +#             A[i, j] = np.count_nonzero(collections_train[i].atomic_numbers == z) +#     try: +#         E0s = np.linalg.lstsq(A, B, rcond=None)[0] +#         atomic_energies_dict = {} +#         for i, z in enumerate(z_table.zs): +#             atomic_energies_dict[z] = E0s[i] +#     except np.linalg.LinAlgError: +#         logging.error( +#             "Failed to compute E0s using least squares regression, using 0.0 for all elements" +#         ) +#         atomic_energies_dict = {} +#         for i, z in enumerate(z_table.zs): +#             atomic_energies_dict[z] = 0.0 +#     return atomic_energies_dict + + +# def save_dataset_as_HDF5(dataset: List, out_name: str) -> None: +#     with h5py.File(out_name, "w") as f: +#         for i, data in enumerate(dataset): +#             grp = f.create_group(f"config_{i}") +#             grp["num_nodes"] = data.num_nodes +#             grp["edge_index"] = data.edge_index +#             grp["positions"] = data.positions +#             grp["shifts"] = data.shifts +#             grp["unit_shifts"] = data.unit_shifts +#             grp["cell"] = data.cell +#             grp["node_attrs"] = data.node_attrs +#             grp["weight"] = data.weight +#             grp["energy_weight"] = data.energy_weight +#             grp["forces_weight"] = data.forces_weight +#             grp["stress_weight"] = data.stress_weight +#             grp["virials_weight"] = data.virials_weight +#             grp["forces"] = data.forces +#             grp["energy"] = data.energy +#             grp["stress"] = data.stress +#             grp["virials"] = data.virials +#             grp["dipole"] = data.dipole +#             grp["charges"] = data.charges + + +# def save_AtomicData_to_HDF5(data, i, h5_file) -> None: +#     grp = h5_file.create_group(f"config_{i}") +#     grp["num_nodes"] = data.num_nodes +#     grp["edge_index"] = data.edge_index +#     grp["positions"] = data.positions +#     grp["shifts"] = data.shifts +#     grp["unit_shifts"] = data.unit_shifts +#     grp["cell"] = data.cell +#     grp["node_attrs"] = data.node_attrs +#     grp["weight"] = data.weight +#     grp["energy_weight"] = data.energy_weight +#     grp["forces_weight"] = data.forces_weight +#     grp["stress_weight"] = data.stress_weight +#     grp["virials_weight"] = data.virials_weight +#     grp["forces"] = data.forces +#     grp["energy"] = data.energy +#     grp["stress"] = data.stress +#     grp["virials"] = data.virials +#     grp["dipole"] = data.dipole +#     grp["charges"] = data.charges + + +# def save_configurations_as_HDF5(configurations: Configurations, _, h5_file) -> None: +#     grp = h5_file.create_group("config_batch_0") +#     for j, config in enumerate(configurations): +#         subgroup_name = f"config_{j}" +#         subgroup = grp.create_group(subgroup_name) +#         subgroup["atomic_numbers"] = write_value(config.atomic_numbers) +#         subgroup["positions"] =
write_value(config.positions) +# subgroup["energy"] = write_value(config.energy) +# subgroup["forces"] = write_value(config.forces) +# subgroup["stress"] = write_value(config.stress) +# subgroup["virials"] = write_value(config.virials) +# subgroup["dipole"] = write_value(config.dipole) +# subgroup["charges"] = write_value(config.charges) +# subgroup["cell"] = write_value(config.cell) +# subgroup["pbc"] = write_value(config.pbc) +# subgroup["weight"] = write_value(config.weight) +# subgroup["energy_weight"] = write_value(config.energy_weight) +# subgroup["forces_weight"] = write_value(config.forces_weight) +# subgroup["stress_weight"] = write_value(config.stress_weight) +# subgroup["virials_weight"] = write_value(config.virials_weight) +# subgroup["config_type"] = write_value(config.config_type) + + +# def write_value(value): +# return value if value is not None else "None" diff --git a/hydragnn/utils/mace_utils/modules/__init__.py b/hydragnn/utils/mace_utils/modules/__init__.py index b767383f3..2c8bb160a 100644 --- a/hydragnn/utils/mace_utils/modules/__init__.py +++ b/hydragnn/utils/mace_utils/modules/__init__.py @@ -20,34 +20,34 @@ # ResidualElementDependentInteractionBlock, ScaleShiftBlock, ) -from .loss import ( - DipoleSingleLoss, - UniversalLoss, - WeightedEnergyForcesDipoleLoss, - WeightedEnergyForcesLoss, - WeightedEnergyForcesStressLoss, - WeightedEnergyForcesVirialsLoss, - WeightedForcesLoss, - WeightedHuberEnergyForcesStressLoss, -) -from .models import ( - MACE, - AtomicDipolesMACE, - BOTNet, - EnergyDipolesMACE, - ScaleShiftBOTNet, - ScaleShiftMACE, -) +# from .loss import ( +# DipoleSingleLoss, +# UniversalLoss, +# WeightedEnergyForcesDipoleLoss, +# WeightedEnergyForcesLoss, +# WeightedEnergyForcesStressLoss, +# WeightedEnergyForcesVirialsLoss, +# WeightedForcesLoss, +# WeightedHuberEnergyForcesStressLoss, +# ) +# from .models import ( +# MACE, +# AtomicDipolesMACE, +# BOTNet, +# EnergyDipolesMACE, +# ScaleShiftBOTNet, +# ScaleShiftMACE, +# ) from .radial import BesselBasis, GaussianBasis, PolynomialCutoff, ZBLBasis from .symmetric_contraction import SymmetricContraction -from .utils import ( - compute_avg_num_neighbors, - compute_fixed_charge_dipole, - compute_mean_rms_energy_forces, - compute_mean_std_atomic_inter_energy, - compute_rms_dipoles, - compute_statistics, -) +# from .utils import ( +# compute_avg_num_neighbors, +# compute_fixed_charge_dipole, +# compute_mean_rms_energy_forces, +# compute_mean_std_atomic_inter_energy, +# compute_rms_dipoles, +# compute_statistics, +# ) interaction_classes: Dict[str, Type[InteractionBlock]] = { # "AgnosticNonlinearInteractionBlock": AgnosticNonlinearInteractionBlock, @@ -58,11 +58,11 @@ # "RealAgnosticInteractionBlock": RealAgnosticInteractionBlock, } -scaling_classes: Dict[str, Callable] = { - "std_scaling": compute_mean_std_atomic_inter_energy, - "rms_forces_scaling": compute_mean_rms_energy_forces, - "rms_dipoles_scaling": compute_rms_dipoles, -} +# scaling_classes: Dict[str, Callable] = { +# "std_scaling": compute_mean_std_atomic_inter_energy, +# "rms_forces_scaling": compute_mean_rms_energy_forces, +# "rms_dipoles_scaling": compute_rms_dipoles, +# } gate_dict: Dict[str, Optional[Callable]] = { "abs": torch.abs, diff --git a/hydragnn/utils/mace_utils/modules/loss.py b/hydragnn/utils/mace_utils/modules/loss.py index 5c754defc..9ece6e0c2 100644 --- a/hydragnn/utils/mace_utils/modules/loss.py +++ b/hydragnn/utils/mace_utils/modules/loss.py @@ -1,367 +1,367 @@ 
-########################################################################################### -# Implementation of different loss functions -# Authors: Ilyes Batatia, Gregor Simm -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import torch - -from hydragnn.utils.mace_utils.tools import TensorDict -from hydragnn.utils.mace_utils.tools.torch_geometric import Batch - - -def mean_squared_error_energy(ref: Batch, pred: TensorDict) -> torch.Tensor: - # energy: [n_graphs, ] - return torch.mean(torch.square(ref["energy"] - pred["energy"])) # [] - - -def weighted_mean_squared_error_energy(ref: Batch, pred: TensorDict) -> torch.Tensor: - # energy: [n_graphs, ] - configs_weight = ref.weight # [n_graphs, ] - configs_energy_weight = ref.energy_weight # [n_graphs, ] - num_atoms = ref.ptr[1:] - ref.ptr[:-1] # [n_graphs,] - return torch.mean( - configs_weight - * configs_energy_weight - * torch.square((ref["energy"] - pred["energy"]) / num_atoms) - ) # [] - - -def weighted_mean_squared_stress(ref: Batch, pred: TensorDict) -> torch.Tensor: - # energy: [n_graphs, ] - configs_weight = ref.weight.view(-1, 1, 1) # [n_graphs, ] - configs_stress_weight = ref.stress_weight.view(-1, 1, 1) # [n_graphs, ] - return torch.mean( - configs_weight - * configs_stress_weight - * torch.square(ref["stress"] - pred["stress"]) - ) # [] - - -def weighted_mean_squared_virials(ref: Batch, pred: TensorDict) -> torch.Tensor: - # energy: [n_graphs, ] - configs_weight = ref.weight.view(-1, 1, 1) # [n_graphs, ] - configs_virials_weight = ref.virials_weight.view(-1, 1, 1) # [n_graphs, ] - num_atoms = (ref.ptr[1:] - ref.ptr[:-1]).view(-1, 1, 1) # [n_graphs,] - return torch.mean( - configs_weight - * configs_virials_weight - * torch.square((ref["virials"] - pred["virials"]) / num_atoms) - ) # [] - - -def mean_squared_error_forces(ref: Batch, pred: TensorDict) -> torch.Tensor: - # forces: [n_atoms, 3] - configs_weight = torch.repeat_interleave( - ref.weight, ref.ptr[1:] - ref.ptr[:-1] - ).unsqueeze( - -1 - ) # [n_atoms, 1] - configs_forces_weight = torch.repeat_interleave( - ref.forces_weight, ref.ptr[1:] - ref.ptr[:-1] - ).unsqueeze( - -1 - ) # [n_atoms, 1] - return torch.mean( - configs_weight - * configs_forces_weight - * torch.square(ref["forces"] - pred["forces"]) - ) # [] - - -def weighted_mean_squared_error_dipole(ref: Batch, pred: TensorDict) -> torch.Tensor: - # dipole: [n_graphs, ] - num_atoms = (ref.ptr[1:] - ref.ptr[:-1]).unsqueeze(-1) # [n_graphs,1] - return torch.mean(torch.square((ref["dipole"] - pred["dipole"]) / num_atoms)) # [] - # return torch.mean(torch.square((torch.reshape(ref['dipole'], pred["dipole"].shape) - pred['dipole']) / num_atoms)) # [] - - -def conditional_mse_forces(ref: Batch, pred: TensorDict) -> torch.Tensor: - # forces: [n_atoms, 3] - configs_weight = torch.repeat_interleave( - ref.weight, ref.ptr[1:] - ref.ptr[:-1] - ).unsqueeze( - -1 - ) # [n_atoms, 1] - configs_forces_weight = torch.repeat_interleave( - ref.forces_weight, ref.ptr[1:] - ref.ptr[:-1] - ).unsqueeze( - -1 - ) # [n_atoms, 1] - - # Define the multiplication factors for each condition - factors = torch.tensor([1.0, 0.7, 0.4, 0.1]) - - # Apply multiplication factors based on conditions - c1 = torch.norm(ref["forces"], dim=-1) < 100 - c2 = (torch.norm(ref["forces"], dim=-1) >= 100) & ( - torch.norm(ref["forces"], dim=-1) < 200 - ) - c3 = (torch.norm(ref["forces"], dim=-1) >= 200) & ( - torch.norm(ref["forces"], dim=-1) < 300 - ) - 
- err = ref["forces"] - pred["forces"] - - se = torch.zeros_like(err) - - se[c1] = torch.square(err[c1]) * factors[0] - se[c2] = torch.square(err[c2]) * factors[1] - se[c3] = torch.square(err[c3]) * factors[2] - se[~(c1 | c2 | c3)] = torch.square(err[~(c1 | c2 | c3)]) * factors[3] - - return torch.mean(configs_weight * configs_forces_weight * se) - - -def conditional_huber_forces( - ref: Batch, pred: TensorDict, huber_delta: float -) -> torch.Tensor: - # Define the multiplication factors for each condition - factors = huber_delta * torch.tensor([1.0, 0.7, 0.4, 0.1]) - - # Apply multiplication factors based on conditions - c1 = torch.norm(ref["forces"], dim=-1) < 100 - c2 = (torch.norm(ref["forces"], dim=-1) >= 100) & ( - torch.norm(ref["forces"], dim=-1) < 200 - ) - c3 = (torch.norm(ref["forces"], dim=-1) >= 200) & ( - torch.norm(ref["forces"], dim=-1) < 300 - ) - c4 = ~(c1 | c2 | c3) - - se = torch.zeros_like(pred["forces"]) - - se[c1] = torch.nn.functional.huber_loss( - ref["forces"][c1], pred["forces"][c1], reduction="none", delta=factors[0] - ) - se[c2] = torch.nn.functional.huber_loss( - ref["forces"][c2], pred["forces"][c2], reduction="none", delta=factors[1] - ) - se[c3] = torch.nn.functional.huber_loss( - ref["forces"][c3], pred["forces"][c3], reduction="none", delta=factors[2] - ) - se[c4] = torch.nn.functional.huber_loss( - ref["forces"][c4], pred["forces"][c4], reduction="none", delta=factors[3] - ) - - return torch.mean(se) - - -class WeightedEnergyForcesLoss(torch.nn.Module): - def __init__(self, energy_weight=1.0, forces_weight=1.0) -> None: - super().__init__() - self.register_buffer( - "energy_weight", - torch.tensor(energy_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "forces_weight", - torch.tensor(forces_weight, dtype=torch.get_default_dtype()), - ) - - def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: - return self.energy_weight * weighted_mean_squared_error_energy( - ref, pred - ) + self.forces_weight * mean_squared_error_forces(ref, pred) - - def __repr__(self): - return ( - f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " - f"forces_weight={self.forces_weight:.3f})" - ) - - -class WeightedForcesLoss(torch.nn.Module): - def __init__(self, forces_weight=1.0) -> None: - super().__init__() - self.register_buffer( - "forces_weight", - torch.tensor(forces_weight, dtype=torch.get_default_dtype()), - ) - - def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: - return self.forces_weight * mean_squared_error_forces(ref, pred) - - def __repr__(self): - return f"{self.__class__.__name__}(" f"forces_weight={self.forces_weight:.3f})" - - -class WeightedEnergyForcesStressLoss(torch.nn.Module): - def __init__(self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0) -> None: - super().__init__() - self.register_buffer( - "energy_weight", - torch.tensor(energy_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "forces_weight", - torch.tensor(forces_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "stress_weight", - torch.tensor(stress_weight, dtype=torch.get_default_dtype()), - ) - - def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: - return ( - self.energy_weight * weighted_mean_squared_error_energy(ref, pred) - + self.forces_weight * mean_squared_error_forces(ref, pred) - + self.stress_weight * weighted_mean_squared_stress(ref, pred) - ) - - def __repr__(self): - return ( - f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " - 
f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})" - ) - - -class WeightedHuberEnergyForcesStressLoss(torch.nn.Module): - def __init__( - self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0, huber_delta=0.01 - ) -> None: - super().__init__() - self.huber_loss = torch.nn.HuberLoss(reduction="mean", delta=huber_delta) - self.register_buffer( - "energy_weight", - torch.tensor(energy_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "forces_weight", - torch.tensor(forces_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "stress_weight", - torch.tensor(stress_weight, dtype=torch.get_default_dtype()), - ) - - def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: - num_atoms = ref.ptr[1:] - ref.ptr[:-1] - return ( - self.energy_weight - * self.huber_loss(ref["energy"] / num_atoms, pred["energy"] / num_atoms) - + self.forces_weight * self.huber_loss(ref["forces"], pred["forces"]) - + self.stress_weight * self.huber_loss(ref["stress"], pred["stress"]) - ) - - def __repr__(self): - return ( - f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " - f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})" - ) - - -class UniversalLoss(torch.nn.Module): - def __init__( - self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0, huber_delta=0.01 - ) -> None: - super().__init__() - self.huber_delta = huber_delta - self.huber_loss = torch.nn.HuberLoss(reduction="mean", delta=huber_delta) - self.register_buffer( - "energy_weight", - torch.tensor(energy_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "forces_weight", - torch.tensor(forces_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "stress_weight", - torch.tensor(stress_weight, dtype=torch.get_default_dtype()), - ) - - def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: - num_atoms = ref.ptr[1:] - ref.ptr[:-1] - return ( - self.energy_weight - * self.huber_loss(ref["energy"] / num_atoms, pred["energy"] / num_atoms) - + self.forces_weight - * conditional_huber_forces(ref, pred, huber_delta=self.huber_delta) - + self.stress_weight * self.huber_loss(ref["stress"], pred["stress"]) - ) - - def __repr__(self): - return ( - f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " - f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})" - ) - - -class WeightedEnergyForcesVirialsLoss(torch.nn.Module): - def __init__( - self, energy_weight=1.0, forces_weight=1.0, virials_weight=1.0 - ) -> None: - super().__init__() - self.register_buffer( - "energy_weight", - torch.tensor(energy_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "forces_weight", - torch.tensor(forces_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "virials_weight", - torch.tensor(virials_weight, dtype=torch.get_default_dtype()), - ) - - def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: - return ( - self.energy_weight * weighted_mean_squared_error_energy(ref, pred) - + self.forces_weight * mean_squared_error_forces(ref, pred) - + self.virials_weight * weighted_mean_squared_virials(ref, pred) - ) - - def __repr__(self): - return ( - f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " - f"forces_weight={self.forces_weight:.3f}, virials_weight={self.virials_weight:.3f})" - ) - - -class DipoleSingleLoss(torch.nn.Module): - def __init__(self, dipole_weight=1.0) -> None: - 
super().__init__() -        self.register_buffer( -            "dipole_weight", -            torch.tensor(dipole_weight, dtype=torch.get_default_dtype()), -        ) - -    def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: -        return ( -            self.dipole_weight * weighted_mean_squared_error_dipole(ref, pred) * 100.0 -        )  # multiply by 100 to have the right scale for the loss - -    def __repr__(self): -        return f"{self.__class__.__name__}(" f"dipole_weight={self.dipole_weight:.3f})" - - -class WeightedEnergyForcesDipoleLoss(torch.nn.Module): -    def __init__(self, energy_weight=1.0, forces_weight=1.0, dipole_weight=1.0) -> None: -        super().__init__() -        self.register_buffer( -            "energy_weight", -            torch.tensor(energy_weight, dtype=torch.get_default_dtype()), -        ) -        self.register_buffer( -            "forces_weight", -            torch.tensor(forces_weight, dtype=torch.get_default_dtype()), -        ) -        self.register_buffer( -            "dipole_weight", -            torch.tensor(dipole_weight, dtype=torch.get_default_dtype()), -        ) - -    def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: -        return ( -            self.energy_weight * weighted_mean_squared_error_energy(ref, pred) -            + self.forces_weight * mean_squared_error_forces(ref, pred) -            + self.dipole_weight * weighted_mean_squared_error_dipole(ref, pred) * 100 -        ) - -    def __repr__(self): -        return ( -            f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " -            f"forces_weight={self.forces_weight:.3f}, dipole_weight={self.dipole_weight:.3f})" -        ) +# ########################################################################################### +# # Implementation of different loss functions +# # Authors: Ilyes Batatia, Gregor Simm +# # This program is distributed under the MIT License (see MIT.md) +# ########################################################################################### + +# import torch + +# from hydragnn.utils.mace_utils.tools import TensorDict +# from hydragnn.utils.mace_utils.tools.torch_geometric import Batch + + +# def mean_squared_error_energy(ref: Batch, pred: TensorDict) -> torch.Tensor: +#     # energy: [n_graphs, ] +#     return torch.mean(torch.square(ref["energy"] - pred["energy"]))  # [] + + +# def weighted_mean_squared_error_energy(ref: Batch, pred: TensorDict) -> torch.Tensor: +#     # energy: [n_graphs, ] +#     configs_weight = ref.weight  # [n_graphs, ] +#     configs_energy_weight = ref.energy_weight  # [n_graphs, ] +#     num_atoms = ref.ptr[1:] - ref.ptr[:-1]  # [n_graphs,] +#     return torch.mean( +#         configs_weight +#         * configs_energy_weight +#         * torch.square((ref["energy"] - pred["energy"]) / num_atoms) +#     )  # [] + + +# def weighted_mean_squared_stress(ref: Batch, pred: TensorDict) -> torch.Tensor: +#     # stress: [n_graphs, 3, 3] +#     configs_weight = ref.weight.view(-1, 1, 1)  # [n_graphs, 1, 1] +#     configs_stress_weight = ref.stress_weight.view(-1, 1, 1)  # [n_graphs, 1, 1] +#     return torch.mean( +#         configs_weight +#         * configs_stress_weight +#         * torch.square(ref["stress"] - pred["stress"]) +#     )  # [] + + +# def weighted_mean_squared_virials(ref: Batch, pred: TensorDict) -> torch.Tensor: +#     # virials: [n_graphs, 3, 3] +#     configs_weight = ref.weight.view(-1, 1, 1)  # [n_graphs, 1, 1] +#     configs_virials_weight = ref.virials_weight.view(-1, 1, 1)  # [n_graphs, 1, 1] +#     num_atoms = (ref.ptr[1:] - ref.ptr[:-1]).view(-1, 1, 1)  # [n_graphs, 1, 1] +#     return torch.mean( +#         configs_weight +#         * configs_virials_weight +#         * torch.square((ref["virials"] - pred["virials"]) / num_atoms) +#     )  # [] + + +# def mean_squared_error_forces(ref: Batch, pred: TensorDict) -> torch.Tensor: +#     # forces: [n_atoms, 3] +#     configs_weight = torch.repeat_interleave( +#         ref.weight,
ref.ptr[1:] - ref.ptr[:-1] +# ).unsqueeze( +# -1 +# ) # [n_atoms, 1] +# configs_forces_weight = torch.repeat_interleave( +# ref.forces_weight, ref.ptr[1:] - ref.ptr[:-1] +# ).unsqueeze( +# -1 +# ) # [n_atoms, 1] +# return torch.mean( +# configs_weight +# * configs_forces_weight +# * torch.square(ref["forces"] - pred["forces"]) +# ) # [] + + +# def weighted_mean_squared_error_dipole(ref: Batch, pred: TensorDict) -> torch.Tensor: +# # dipole: [n_graphs, ] +# num_atoms = (ref.ptr[1:] - ref.ptr[:-1]).unsqueeze(-1) # [n_graphs,1] +# return torch.mean(torch.square((ref["dipole"] - pred["dipole"]) / num_atoms)) # [] +# # return torch.mean(torch.square((torch.reshape(ref['dipole'], pred["dipole"].shape) - pred['dipole']) / num_atoms)) # [] + + +# def conditional_mse_forces(ref: Batch, pred: TensorDict) -> torch.Tensor: +# # forces: [n_atoms, 3] +# configs_weight = torch.repeat_interleave( +# ref.weight, ref.ptr[1:] - ref.ptr[:-1] +# ).unsqueeze( +# -1 +# ) # [n_atoms, 1] +# configs_forces_weight = torch.repeat_interleave( +# ref.forces_weight, ref.ptr[1:] - ref.ptr[:-1] +# ).unsqueeze( +# -1 +# ) # [n_atoms, 1] + +# # Define the multiplication factors for each condition +# factors = torch.tensor([1.0, 0.7, 0.4, 0.1]) + +# # Apply multiplication factors based on conditions +# c1 = torch.norm(ref["forces"], dim=-1) < 100 +# c2 = (torch.norm(ref["forces"], dim=-1) >= 100) & ( +# torch.norm(ref["forces"], dim=-1) < 200 +# ) +# c3 = (torch.norm(ref["forces"], dim=-1) >= 200) & ( +# torch.norm(ref["forces"], dim=-1) < 300 +# ) + +# err = ref["forces"] - pred["forces"] + +# se = torch.zeros_like(err) + +# se[c1] = torch.square(err[c1]) * factors[0] +# se[c2] = torch.square(err[c2]) * factors[1] +# se[c3] = torch.square(err[c3]) * factors[2] +# se[~(c1 | c2 | c3)] = torch.square(err[~(c1 | c2 | c3)]) * factors[3] + +# return torch.mean(configs_weight * configs_forces_weight * se) + + +# def conditional_huber_forces( +# ref: Batch, pred: TensorDict, huber_delta: float +# ) -> torch.Tensor: +# # Define the multiplication factors for each condition +# factors = huber_delta * torch.tensor([1.0, 0.7, 0.4, 0.1]) + +# # Apply multiplication factors based on conditions +# c1 = torch.norm(ref["forces"], dim=-1) < 100 +# c2 = (torch.norm(ref["forces"], dim=-1) >= 100) & ( +# torch.norm(ref["forces"], dim=-1) < 200 +# ) +# c3 = (torch.norm(ref["forces"], dim=-1) >= 200) & ( +# torch.norm(ref["forces"], dim=-1) < 300 +# ) +# c4 = ~(c1 | c2 | c3) + +# se = torch.zeros_like(pred["forces"]) + +# se[c1] = torch.nn.functional.huber_loss( +# ref["forces"][c1], pred["forces"][c1], reduction="none", delta=factors[0] +# ) +# se[c2] = torch.nn.functional.huber_loss( +# ref["forces"][c2], pred["forces"][c2], reduction="none", delta=factors[1] +# ) +# se[c3] = torch.nn.functional.huber_loss( +# ref["forces"][c3], pred["forces"][c3], reduction="none", delta=factors[2] +# ) +# se[c4] = torch.nn.functional.huber_loss( +# ref["forces"][c4], pred["forces"][c4], reduction="none", delta=factors[3] +# ) + +# return torch.mean(se) + + +# class WeightedEnergyForcesLoss(torch.nn.Module): +# def __init__(self, energy_weight=1.0, forces_weight=1.0) -> None: +# super().__init__() +# self.register_buffer( +# "energy_weight", +# torch.tensor(energy_weight, dtype=torch.get_default_dtype()), +# ) +# self.register_buffer( +# "forces_weight", +# torch.tensor(forces_weight, dtype=torch.get_default_dtype()), +# ) + +# def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: +# return self.energy_weight * weighted_mean_squared_error_energy( +# 
ref, pred +# ) + self.forces_weight * mean_squared_error_forces(ref, pred) + +# def __repr__(self): +# return ( +# f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " +# f"forces_weight={self.forces_weight:.3f})" +# ) + + +# class WeightedForcesLoss(torch.nn.Module): +# def __init__(self, forces_weight=1.0) -> None: +# super().__init__() +# self.register_buffer( +# "forces_weight", +# torch.tensor(forces_weight, dtype=torch.get_default_dtype()), +# ) + +# def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: +# return self.forces_weight * mean_squared_error_forces(ref, pred) + +# def __repr__(self): +# return f"{self.__class__.__name__}(" f"forces_weight={self.forces_weight:.3f})" + + +# class WeightedEnergyForcesStressLoss(torch.nn.Module): +# def __init__(self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0) -> None: +# super().__init__() +# self.register_buffer( +# "energy_weight", +# torch.tensor(energy_weight, dtype=torch.get_default_dtype()), +# ) +# self.register_buffer( +# "forces_weight", +# torch.tensor(forces_weight, dtype=torch.get_default_dtype()), +# ) +# self.register_buffer( +# "stress_weight", +# torch.tensor(stress_weight, dtype=torch.get_default_dtype()), +# ) + +# def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: +# return ( +# self.energy_weight * weighted_mean_squared_error_energy(ref, pred) +# + self.forces_weight * mean_squared_error_forces(ref, pred) +# + self.stress_weight * weighted_mean_squared_stress(ref, pred) +# ) + +# def __repr__(self): +# return ( +# f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " +# f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})" +# ) + + +# class WeightedHuberEnergyForcesStressLoss(torch.nn.Module): +# def __init__( +# self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0, huber_delta=0.01 +# ) -> None: +# super().__init__() +# self.huber_loss = torch.nn.HuberLoss(reduction="mean", delta=huber_delta) +# self.register_buffer( +# "energy_weight", +# torch.tensor(energy_weight, dtype=torch.get_default_dtype()), +# ) +# self.register_buffer( +# "forces_weight", +# torch.tensor(forces_weight, dtype=torch.get_default_dtype()), +# ) +# self.register_buffer( +# "stress_weight", +# torch.tensor(stress_weight, dtype=torch.get_default_dtype()), +# ) + +# def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: +# num_atoms = ref.ptr[1:] - ref.ptr[:-1] +# return ( +# self.energy_weight +# * self.huber_loss(ref["energy"] / num_atoms, pred["energy"] / num_atoms) +# + self.forces_weight * self.huber_loss(ref["forces"], pred["forces"]) +# + self.stress_weight * self.huber_loss(ref["stress"], pred["stress"]) +# ) + +# def __repr__(self): +# return ( +# f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " +# f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})" +# ) + + +# class UniversalLoss(torch.nn.Module): +# def __init__( +# self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0, huber_delta=0.01 +# ) -> None: +# super().__init__() +# self.huber_delta = huber_delta +# self.huber_loss = torch.nn.HuberLoss(reduction="mean", delta=huber_delta) +# self.register_buffer( +# "energy_weight", +# torch.tensor(energy_weight, dtype=torch.get_default_dtype()), +# ) +# self.register_buffer( +# "forces_weight", +# torch.tensor(forces_weight, dtype=torch.get_default_dtype()), +# ) +# self.register_buffer( +# "stress_weight", +# torch.tensor(stress_weight, dtype=torch.get_default_dtype()), 
+# ) + +# def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: +# num_atoms = ref.ptr[1:] - ref.ptr[:-1] +# return ( +# self.energy_weight +# * self.huber_loss(ref["energy"] / num_atoms, pred["energy"] / num_atoms) +# + self.forces_weight +# * conditional_huber_forces(ref, pred, huber_delta=self.huber_delta) +# + self.stress_weight * self.huber_loss(ref["stress"], pred["stress"]) +# ) + +# def __repr__(self): +# return ( +# f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " +# f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})" +# ) + + +# class WeightedEnergyForcesVirialsLoss(torch.nn.Module): +# def __init__( +# self, energy_weight=1.0, forces_weight=1.0, virials_weight=1.0 +# ) -> None: +# super().__init__() +# self.register_buffer( +# "energy_weight", +# torch.tensor(energy_weight, dtype=torch.get_default_dtype()), +# ) +# self.register_buffer( +# "forces_weight", +# torch.tensor(forces_weight, dtype=torch.get_default_dtype()), +# ) +# self.register_buffer( +# "virials_weight", +# torch.tensor(virials_weight, dtype=torch.get_default_dtype()), +# ) + +# def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: +# return ( +# self.energy_weight * weighted_mean_squared_error_energy(ref, pred) +# + self.forces_weight * mean_squared_error_forces(ref, pred) +# + self.virials_weight * weighted_mean_squared_virials(ref, pred) +# ) + +# def __repr__(self): +# return ( +# f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " +# f"forces_weight={self.forces_weight:.3f}, virials_weight={self.virials_weight:.3f})" +# ) + + +# class DipoleSingleLoss(torch.nn.Module): +# def __init__(self, dipole_weight=1.0) -> None: +# super().__init__() +# self.register_buffer( +# "dipole_weight", +# torch.tensor(dipole_weight, dtype=torch.get_default_dtype()), +# ) + +# def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: +# return ( +# self.dipole_weight * weighted_mean_squared_error_dipole(ref, pred) * 100.0 +# ) # multiply by 100 to have the right scale for the loss + +# def __repr__(self): +# return f"{self.__class__.__name__}(" f"dipole_weight={self.dipole_weight:.3f})" + + +# class WeightedEnergyForcesDipoleLoss(torch.nn.Module): +# def __init__(self, energy_weight=1.0, forces_weight=1.0, dipole_weight=1.0) -> None: +# super().__init__() +# self.register_buffer( +# "energy_weight", +# torch.tensor(energy_weight, dtype=torch.get_default_dtype()), +# ) +# self.register_buffer( +# "forces_weight", +# torch.tensor(forces_weight, dtype=torch.get_default_dtype()), +# ) +# self.register_buffer( +# "dipole_weight", +# torch.tensor(dipole_weight, dtype=torch.get_default_dtype()), +# ) + +# def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: +# return ( +# self.energy_weight * weighted_mean_squared_error_energy(ref, pred) +# + self.forces_weight * mean_squared_error_forces(ref, pred) +# + self.dipole_weight * weighted_mean_squared_error_dipole(ref, pred) * 100 +# ) + +# def __repr__(self): +# return ( +# f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " +# f"forces_weight={self.forces_weight:.3f}, dipole_weight={self.dipole_weight:.3f})" +# ) diff --git a/hydragnn/utils/mace_utils/modules/models.py b/hydragnn/utils/mace_utils/modules/models.py index e0fa51ee2..cc87fed91 100644 --- a/hydragnn/utils/mace_utils/modules/models.py +++ b/hydragnn/utils/mace_utils/modules/models.py @@ -1,1065 +1,1065 @@ -########################################################################################### -# 
Implementation of MACE models and other models based E(3)-Equivariant MPNNs -# Authors: Ilyes Batatia, Gregor Simm -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -from typing import Any, Callable, Dict, List, Optional, Type, Union - -import numpy as np -import torch -from e3nn import o3 -from e3nn.util.jit import compile_mode - -from hydragnn.utils.mace_utils.data import AtomicData -from hydragnn.utils.mace_utils.modules.radial import ZBLBasis -from hydragnn.utils.mace_utils.tools.scatter import scatter_sum - -from .blocks import ( - AtomicEnergiesBlock, - EquivariantProductBasisBlock, - InteractionBlock, - LinearDipoleReadoutBlock, - LinearNodeEmbeddingBlock, - LinearReadoutBlock, - NonLinearDipoleReadoutBlock, - NonLinearReadoutBlock, - RadialEmbeddingBlock, - ScaleShiftBlock, -) -from .utils import ( - compute_fixed_charge_dipole, - compute_forces, - get_edge_vectors_and_lengths, - get_outputs, - get_symmetric_displacement, -) - -# pylint: disable=C0302 - - -@compile_mode("script") -class MACE(torch.nn.Module): - def __init__( - self, - r_max: float, - num_bessel: int, - num_polynomial_cutoff: int, - max_ell: int, - interaction_cls: Type[InteractionBlock], - interaction_cls_first: Type[InteractionBlock], - num_interactions: int, - num_elements: int, - hidden_irreps: o3.Irreps, - MLP_irreps: o3.Irreps, - atomic_energies: np.ndarray, - avg_num_neighbors: float, - atomic_numbers: List[int], - correlation: Union[int, List[int]], - gate: Optional[Callable], - pair_repulsion: bool = False, - distance_transform: str = "None", - radial_MLP: Optional[List[int]] = None, - radial_type: Optional[str] = "bessel", - ): - super().__init__() - self.register_buffer( - "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) - ) - self.register_buffer( - "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) - ) - self.register_buffer( - "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) - ) - if isinstance(correlation, int): - correlation = [correlation] * num_interactions - # Embedding - node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) - node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) - self.node_embedding = LinearNodeEmbeddingBlock( - irreps_in=node_attr_irreps, irreps_out=node_feats_irreps - ) - self.radial_embedding = RadialEmbeddingBlock( - r_max=r_max, - num_bessel=num_bessel, - num_polynomial_cutoff=num_polynomial_cutoff, - radial_type=radial_type, - distance_transform=distance_transform, - ) - edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") - if pair_repulsion: - self.pair_repulsion_fn = ZBLBasis(r_max=r_max, p=num_polynomial_cutoff) - self.pair_repulsion = True - - sh_irreps = o3.Irreps.spherical_harmonics(max_ell) - num_features = hidden_irreps.count(o3.Irrep(0, 1)) - interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() - self.spherical_harmonics = o3.SphericalHarmonics( - sh_irreps, normalize=True, normalization="component" - ) - if radial_MLP is None: - radial_MLP = [64, 64, 64] - # Interactions and readout - self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) - - inter = interaction_cls_first( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=node_feats_irreps, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=interaction_irreps, - hidden_irreps=hidden_irreps, - avg_num_neighbors=avg_num_neighbors, - 
radial_MLP=radial_MLP, - ) - self.interactions = torch.nn.ModuleList([inter]) - - # Use the appropriate self connection at the first layer for proper E0 - use_sc_first = False - if "Residual" in str(interaction_cls_first): - use_sc_first = True - - node_feats_irreps_out = inter.target_irreps - prod = EquivariantProductBasisBlock( - node_feats_irreps=node_feats_irreps_out, - target_irreps=hidden_irreps, - correlation=correlation[0], - num_elements=num_elements, - use_sc=use_sc_first, - ) - self.products = torch.nn.ModuleList([prod]) - - self.readouts = torch.nn.ModuleList() - self.readouts.append(LinearReadoutBlock(hidden_irreps)) - - for i in range(num_interactions - 1): - if i == num_interactions - 2: - hidden_irreps_out = str( - hidden_irreps[0] - ) # Select only scalars for last layer - else: - hidden_irreps_out = hidden_irreps - inter = interaction_cls( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=hidden_irreps, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=interaction_irreps, - hidden_irreps=hidden_irreps_out, - avg_num_neighbors=avg_num_neighbors, - radial_MLP=radial_MLP, - ) - self.interactions.append(inter) - prod = EquivariantProductBasisBlock( - node_feats_irreps=interaction_irreps, - target_irreps=hidden_irreps_out, - correlation=correlation[i + 1], - num_elements=num_elements, - use_sc=True, - ) - self.products.append(prod) - if i == num_interactions - 2: - self.readouts.append( - NonLinearReadoutBlock(hidden_irreps_out, MLP_irreps, gate) - ) - else: - self.readouts.append(LinearReadoutBlock(hidden_irreps)) - - def forward( - self, - data: Dict[str, torch.Tensor], - training: bool = False, - compute_force: bool = True, - compute_virials: bool = False, - compute_stress: bool = False, - compute_displacement: bool = False, - compute_hessian: bool = False, - ) -> Dict[str, Optional[torch.Tensor]]: - # Setup - data["node_attrs"].requires_grad_(True) - data["positions"].requires_grad_(True) - num_graphs = data["ptr"].numel() - 1 - displacement = torch.zeros( - (num_graphs, 3, 3), - dtype=data["positions"].dtype, - device=data["positions"].device, - ) - if compute_virials or compute_stress or compute_displacement: - ( - data["positions"], - data["shifts"], - displacement, - ) = get_symmetric_displacement( - positions=data["positions"], - unit_shifts=data["unit_shifts"], - cell=data["cell"], - edge_index=data["edge_index"], - num_graphs=num_graphs, - batch=data["batch"], - ) - - # Atomic energies - node_e0 = self.atomic_energies_fn(data["node_attrs"]) - e0 = scatter_sum( - src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs - ) # [n_graphs,] - # Embeddings - node_feats = self.node_embedding(data["node_attrs"]) - vectors, lengths = get_edge_vectors_and_lengths( - positions=data["positions"], - edge_index=data["edge_index"], - shifts=data["shifts"], - ) - edge_attrs = self.spherical_harmonics(vectors) - edge_feats = self.radial_embedding( - lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers - ) - if hasattr(self, "pair_repulsion"): - pair_node_energy = self.pair_repulsion_fn( - lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers - ) - pair_energy = scatter_sum( - src=pair_node_energy, index=data["batch"], dim=-1, dim_size=num_graphs - ) # [n_graphs,] - else: - pair_node_energy = torch.zeros_like(node_e0) - pair_energy = torch.zeros_like(e0) - - # Interactions - energies = [e0, pair_energy] - node_energies_list = [node_e0, pair_node_energy] - node_feats_list = [] - for interaction, product, 
readout in zip( - self.interactions, self.products, self.readouts - ): - node_feats, sc = interaction( - node_attrs=data["node_attrs"], - node_feats=node_feats, - edge_attrs=edge_attrs, - edge_feats=edge_feats, - edge_index=data["edge_index"], - ) - node_feats = product( - node_feats=node_feats, - sc=sc, - node_attrs=data["node_attrs"], - ) - node_feats_list.append(node_feats) - node_energies = readout(node_feats).squeeze(-1) # [n_nodes, ] - energy = scatter_sum( - src=node_energies, index=data["batch"], dim=-1, dim_size=num_graphs - ) # [n_graphs,] - energies.append(energy) - node_energies_list.append(node_energies) - - # Concatenate node features - node_feats_out = torch.cat(node_feats_list, dim=-1) - - # Sum over energy contributions - contributions = torch.stack(energies, dim=-1) - total_energy = torch.sum(contributions, dim=-1) # [n_graphs, ] - node_energy_contributions = torch.stack(node_energies_list, dim=-1) - node_energy = torch.sum(node_energy_contributions, dim=-1) # [n_nodes, ] - - # Outputs - forces, virials, stress, hessian = get_outputs( - energy=total_energy, - positions=data["positions"], - displacement=displacement, - cell=data["cell"], - training=training, - compute_force=compute_force, - compute_virials=compute_virials, - compute_stress=compute_stress, - compute_hessian=compute_hessian, - ) - - return { - "energy": total_energy, - "node_energy": node_energy, - "contributions": contributions, - "forces": forces, - "virials": virials, - "stress": stress, - "displacement": displacement, - "hessian": hessian, - "node_feats": node_feats_out, - } - - -@compile_mode("script") -class ScaleShiftMACE(MACE): - def __init__( - self, - atomic_inter_scale: float, - atomic_inter_shift: float, - **kwargs, - ): - super().__init__(**kwargs) - self.scale_shift = ScaleShiftBlock( - scale=atomic_inter_scale, shift=atomic_inter_shift - ) - - def forward( - self, - data: Dict[str, torch.Tensor], - training: bool = False, - compute_force: bool = True, - compute_virials: bool = False, - compute_stress: bool = False, - compute_displacement: bool = False, - compute_hessian: bool = False, - ) -> Dict[str, Optional[torch.Tensor]]: - # Setup - data["positions"].requires_grad_(True) - data["node_attrs"].requires_grad_(True) - num_graphs = data["ptr"].numel() - 1 - displacement = torch.zeros( - (num_graphs, 3, 3), - dtype=data["positions"].dtype, - device=data["positions"].device, - ) - if compute_virials or compute_stress or compute_displacement: - ( - data["positions"], - data["shifts"], - displacement, - ) = get_symmetric_displacement( - positions=data["positions"], - unit_shifts=data["unit_shifts"], - cell=data["cell"], - edge_index=data["edge_index"], - num_graphs=num_graphs, - batch=data["batch"], - ) - - # Atomic energies - node_e0 = self.atomic_energies_fn(data["node_attrs"]) - e0 = scatter_sum( - src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs - ) # [n_graphs,] - - # Embeddings - node_feats = self.node_embedding(data["node_attrs"]) - vectors, lengths = get_edge_vectors_and_lengths( - positions=data["positions"], - edge_index=data["edge_index"], - shifts=data["shifts"], - ) - edge_attrs = self.spherical_harmonics(vectors) - edge_feats = self.radial_embedding( - lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers - ) - if hasattr(self, "pair_repulsion"): - pair_node_energy = self.pair_repulsion_fn( - lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers - ) - else: - pair_node_energy = torch.zeros_like(node_e0) - # Interactions - node_es_list = 
[pair_node_energy] - node_feats_list = [] - for interaction, product, readout in zip( - self.interactions, self.products, self.readouts - ): - node_feats, sc = interaction( - node_attrs=data["node_attrs"], - node_feats=node_feats, - edge_attrs=edge_attrs, - edge_feats=edge_feats, - edge_index=data["edge_index"], - ) - node_feats = product( - node_feats=node_feats, sc=sc, node_attrs=data["node_attrs"] - ) - node_feats_list.append(node_feats) - node_es_list.append(readout(node_feats).squeeze(-1)) # {[n_nodes, ], } - # Concatenate node features - node_feats_out = torch.cat(node_feats_list, dim=-1) - # print("node_es_list", node_es_list) - # Sum over interactions - node_inter_es = torch.sum( - torch.stack(node_es_list, dim=0), dim=0 - ) # [n_nodes, ] - node_inter_es = self.scale_shift(node_inter_es) - - # Sum over nodes in graph - inter_e = scatter_sum( - src=node_inter_es, index=data["batch"], dim=-1, dim_size=num_graphs - ) # [n_graphs,] - - # Add E_0 and (scaled) interaction energy - total_energy = e0 + inter_e - node_energy = node_e0 + node_inter_es - forces, virials, stress, hessian = get_outputs( - energy=inter_e, - positions=data["positions"], - displacement=displacement, - cell=data["cell"], - training=training, - compute_force=compute_force, - compute_virials=compute_virials, - compute_stress=compute_stress, - compute_hessian=compute_hessian, - ) - output = { - "energy": total_energy, - "node_energy": node_energy, - "interaction_energy": inter_e, - "forces": forces, - "virials": virials, - "stress": stress, - "hessian": hessian, - "displacement": displacement, - "node_feats": node_feats_out, - } - - return output - - -class BOTNet(torch.nn.Module): - def __init__( - self, - r_max: float, - num_bessel: int, - num_polynomial_cutoff: int, - max_ell: int, - interaction_cls: Type[InteractionBlock], - interaction_cls_first: Type[InteractionBlock], - num_interactions: int, - num_elements: int, - hidden_irreps: o3.Irreps, - MLP_irreps: o3.Irreps, - atomic_energies: np.ndarray, - gate: Optional[Callable], - avg_num_neighbors: float, - atomic_numbers: List[int], - ): - super().__init__() - self.r_max = r_max - self.atomic_numbers = atomic_numbers - # Embedding - node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) - node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) - self.node_embedding = LinearNodeEmbeddingBlock( - irreps_in=node_attr_irreps, irreps_out=node_feats_irreps - ) - self.radial_embedding = RadialEmbeddingBlock( - r_max=r_max, - num_bessel=num_bessel, - num_polynomial_cutoff=num_polynomial_cutoff, - ) - edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") - - sh_irreps = o3.Irreps.spherical_harmonics(max_ell) - self.spherical_harmonics = o3.SphericalHarmonics( - sh_irreps, normalize=True, normalization="component" - ) - - # Interactions and readouts - self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) - - self.interactions = torch.nn.ModuleList() - self.readouts = torch.nn.ModuleList() - - inter = interaction_cls_first( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=node_feats_irreps, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=hidden_irreps, - avg_num_neighbors=avg_num_neighbors, - ) - self.interactions.append(inter) - self.readouts.append(LinearReadoutBlock(inter.irreps_out)) - - for i in range(num_interactions - 1): - inter = interaction_cls( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=inter.irreps_out, - edge_attrs_irreps=sh_irreps, - 
edge_feats_irreps=edge_feats_irreps, - target_irreps=hidden_irreps, - avg_num_neighbors=avg_num_neighbors, - ) - self.interactions.append(inter) - if i == num_interactions - 2: - self.readouts.append( - NonLinearReadoutBlock(inter.irreps_out, MLP_irreps, gate) - ) - else: - self.readouts.append(LinearReadoutBlock(inter.irreps_out)) - - def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: - # Setup - data.positions.requires_grad = True - - # Atomic energies - node_e0 = self.atomic_energies_fn(data.node_attrs) - e0 = scatter_sum( - src=node_e0, index=data.batch, dim=-1, dim_size=data.num_graphs - ) # [n_graphs,] - - # Embeddings - node_feats = self.node_embedding(data.node_attrs) - vectors, lengths = get_edge_vectors_and_lengths( - positions=data.positions, edge_index=data.edge_index, shifts=data.shifts - ) - edge_attrs = self.spherical_harmonics(vectors) - edge_feats = self.radial_embedding( - lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers - ) - - # Interactions - energies = [e0] - for interaction, readout in zip(self.interactions, self.readouts): - node_feats = interaction( - node_attrs=data.node_attrs, - node_feats=node_feats, - edge_attrs=edge_attrs, - edge_feats=edge_feats, - edge_index=data.edge_index, - ) - node_energies = readout(node_feats).squeeze(-1) # [n_nodes, ] - energy = scatter_sum( - src=node_energies, index=data.batch, dim=-1, dim_size=data.num_graphs - ) # [n_graphs,] - energies.append(energy) - - # Sum over energy contributions - contributions = torch.stack(energies, dim=-1) - total_energy = torch.sum(contributions, dim=-1) # [n_graphs, ] - - output = { - "energy": total_energy, - "contributions": contributions, - "forces": compute_forces( - energy=total_energy, positions=data.positions, training=training - ), - } - - return output - - -class ScaleShiftBOTNet(BOTNet): - def __init__( - self, - atomic_inter_scale: float, - atomic_inter_shift: float, - **kwargs, - ): - super().__init__(**kwargs) - self.scale_shift = ScaleShiftBlock( - scale=atomic_inter_scale, shift=atomic_inter_shift - ) - - def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: - # Setup - data.positions.requires_grad = True - - # Atomic energies - node_e0 = self.atomic_energies_fn(data.node_attrs) - e0 = scatter_sum( - src=node_e0, index=data.batch, dim=-1, dim_size=data.num_graphs - ) # [n_graphs,] - - # Embeddings - node_feats = self.node_embedding(data.node_attrs) - vectors, lengths = get_edge_vectors_and_lengths( - positions=data.positions, edge_index=data.edge_index, shifts=data.shifts - ) - edge_attrs = self.spherical_harmonics(vectors) - edge_feats = self.radial_embedding( - lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers - ) - - # Interactions - node_es_list = [] - for interaction, readout in zip(self.interactions, self.readouts): - node_feats = interaction( - node_attrs=data.node_attrs, - node_feats=node_feats, - edge_attrs=edge_attrs, - edge_feats=edge_feats, - edge_index=data.edge_index, - ) - - node_es_list.append(readout(node_feats).squeeze(-1)) # {[n_nodes, ], } - - # Sum over interactions - node_inter_es = torch.sum( - torch.stack(node_es_list, dim=0), dim=0 - ) # [n_nodes, ] - node_inter_es = self.scale_shift(node_inter_es) - - # Sum over nodes in graph - inter_e = scatter_sum( - src=node_inter_es, index=data.batch, dim=-1, dim_size=data.num_graphs - ) # [n_graphs,] - - # Add E_0 and (scaled) interaction energy - total_e = e0 + inter_e - - output = { - "energy": total_e, - "forces": compute_forces( - 
energy=inter_e, positions=data.positions, training=training - ), - } - - return output - - -@compile_mode("script") -class AtomicDipolesMACE(torch.nn.Module): - def __init__( - self, - r_max: float, - num_bessel: int, - num_polynomial_cutoff: int, - max_ell: int, - interaction_cls: Type[InteractionBlock], - interaction_cls_first: Type[InteractionBlock], - num_interactions: int, - num_elements: int, - hidden_irreps: o3.Irreps, - MLP_irreps: o3.Irreps, - avg_num_neighbors: float, - atomic_numbers: List[int], - correlation: int, - gate: Optional[Callable], - atomic_energies: Optional[ - None - ], # Just here to make it compatible with energy models, MUST be None - radial_type: Optional[str] = "bessel", - radial_MLP: Optional[List[int]] = None, - ): - super().__init__() - self.register_buffer( - "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) - ) - self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.float64)) - self.register_buffer( - "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) - ) - assert atomic_energies is None - - # Embedding - node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) - node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) - self.node_embedding = LinearNodeEmbeddingBlock( - irreps_in=node_attr_irreps, irreps_out=node_feats_irreps - ) - self.radial_embedding = RadialEmbeddingBlock( - r_max=r_max, - num_bessel=num_bessel, - num_polynomial_cutoff=num_polynomial_cutoff, - radial_type=radial_type, - ) - edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") - - sh_irreps = o3.Irreps.spherical_harmonics(max_ell) - num_features = hidden_irreps.count(o3.Irrep(0, 1)) - interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() - self.spherical_harmonics = o3.SphericalHarmonics( - sh_irreps, normalize=True, normalization="component" - ) - if radial_MLP is None: - radial_MLP = [64, 64, 64] - - # Interactions and readouts - inter = interaction_cls_first( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=node_feats_irreps, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=interaction_irreps, - hidden_irreps=hidden_irreps, - avg_num_neighbors=avg_num_neighbors, - radial_MLP=radial_MLP, - ) - self.interactions = torch.nn.ModuleList([inter]) - - # Use the appropriate self connection at the first layer - use_sc_first = False - if "Residual" in str(interaction_cls_first): - use_sc_first = True - - node_feats_irreps_out = inter.target_irreps - prod = EquivariantProductBasisBlock( - node_feats_irreps=node_feats_irreps_out, - target_irreps=hidden_irreps, - correlation=correlation, - num_elements=num_elements, - use_sc=use_sc_first, - ) - self.products = torch.nn.ModuleList([prod]) - - self.readouts = torch.nn.ModuleList() - self.readouts.append(LinearDipoleReadoutBlock(hidden_irreps, dipole_only=True)) - - for i in range(num_interactions - 1): - if i == num_interactions - 2: - assert ( - len(hidden_irreps) > 1 - ), "To predict dipoles use at least l=1 hidden_irreps" - hidden_irreps_out = str( - hidden_irreps[1] - ) # Select only l=1 vectors for last layer - else: - hidden_irreps_out = hidden_irreps - inter = interaction_cls( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=hidden_irreps, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=interaction_irreps, - hidden_irreps=hidden_irreps_out, - avg_num_neighbors=avg_num_neighbors, - radial_MLP=radial_MLP, - ) - self.interactions.append(inter) - prod = 
EquivariantProductBasisBlock( - node_feats_irreps=interaction_irreps, - target_irreps=hidden_irreps_out, - correlation=correlation, - num_elements=num_elements, - use_sc=True, - ) - self.products.append(prod) - if i == num_interactions - 2: - self.readouts.append( - NonLinearDipoleReadoutBlock( - hidden_irreps_out, MLP_irreps, gate, dipole_only=True - ) - ) - else: - self.readouts.append( - LinearDipoleReadoutBlock(hidden_irreps, dipole_only=True) - ) - - def forward( - self, - data: Dict[str, torch.Tensor], - training: bool = False, # pylint: disable=W0613 - compute_force: bool = False, - compute_virials: bool = False, - compute_stress: bool = False, - compute_displacement: bool = False, - ) -> Dict[str, Optional[torch.Tensor]]: - assert compute_force is False - assert compute_virials is False - assert compute_stress is False - assert compute_displacement is False - # Setup - data["node_attrs"].requires_grad_(True) - data["positions"].requires_grad_(True) - num_graphs = data["ptr"].numel() - 1 - - # Embeddings - node_feats = self.node_embedding(data["node_attrs"]) - vectors, lengths = get_edge_vectors_and_lengths( - positions=data["positions"], - edge_index=data["edge_index"], - shifts=data["shifts"], - ) - edge_attrs = self.spherical_harmonics(vectors) - edge_feats = self.radial_embedding( - lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers - ) - - # Interactions - dipoles = [] - for interaction, product, readout in zip( - self.interactions, self.products, self.readouts - ): - node_feats, sc = interaction( - node_attrs=data["node_attrs"], - node_feats=node_feats, - edge_attrs=edge_attrs, - edge_feats=edge_feats, - edge_index=data["edge_index"], - ) - node_feats = product( - node_feats=node_feats, - sc=sc, - node_attrs=data["node_attrs"], - ) - node_dipoles = readout(node_feats).squeeze(-1) # [n_nodes,3] - dipoles.append(node_dipoles) - - # Compute the dipoles - contributions_dipoles = torch.stack( - dipoles, dim=-1 - ) # [n_nodes,3,n_contributions] - atomic_dipoles = torch.sum(contributions_dipoles, dim=-1) # [n_nodes,3] - total_dipole = scatter_sum( - src=atomic_dipoles, - index=data["batch"], - dim=0, - dim_size=num_graphs, - ) # [n_graphs,3] - baseline = compute_fixed_charge_dipole( - charges=data["charges"], - positions=data["positions"], - batch=data["batch"], - num_graphs=num_graphs, - ) # [n_graphs,3] - total_dipole = total_dipole + baseline - - output = { - "dipole": total_dipole, - "atomic_dipoles": atomic_dipoles, - } - return output - - -@compile_mode("script") -class EnergyDipolesMACE(torch.nn.Module): - def __init__( - self, - r_max: float, - num_bessel: int, - num_polynomial_cutoff: int, - max_ell: int, - interaction_cls: Type[InteractionBlock], - interaction_cls_first: Type[InteractionBlock], - num_interactions: int, - num_elements: int, - hidden_irreps: o3.Irreps, - MLP_irreps: o3.Irreps, - avg_num_neighbors: float, - atomic_numbers: List[int], - correlation: int, - gate: Optional[Callable], - atomic_energies: Optional[np.ndarray], - radial_MLP: Optional[List[int]] = None, - ): - super().__init__() - self.register_buffer( - "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) - ) - self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.float64)) - self.register_buffer( - "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) - ) - # Embedding - node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) - node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) - self.node_embedding = 
LinearNodeEmbeddingBlock( - irreps_in=node_attr_irreps, irreps_out=node_feats_irreps - ) - self.radial_embedding = RadialEmbeddingBlock( - r_max=r_max, - num_bessel=num_bessel, - num_polynomial_cutoff=num_polynomial_cutoff, - ) - edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") - - sh_irreps = o3.Irreps.spherical_harmonics(max_ell) - num_features = hidden_irreps.count(o3.Irrep(0, 1)) - interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() - self.spherical_harmonics = o3.SphericalHarmonics( - sh_irreps, normalize=True, normalization="component" - ) - if radial_MLP is None: - radial_MLP = [64, 64, 64] - # Interactions and readouts - self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) - - inter = interaction_cls_first( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=node_feats_irreps, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=interaction_irreps, - hidden_irreps=hidden_irreps, - avg_num_neighbors=avg_num_neighbors, - radial_MLP=radial_MLP, - ) - self.interactions = torch.nn.ModuleList([inter]) - - # Use the appropriate self connection at the first layer - use_sc_first = False - if "Residual" in str(interaction_cls_first): - use_sc_first = True - - node_feats_irreps_out = inter.target_irreps - prod = EquivariantProductBasisBlock( - node_feats_irreps=node_feats_irreps_out, - target_irreps=hidden_irreps, - correlation=correlation, - num_elements=num_elements, - use_sc=use_sc_first, - ) - self.products = torch.nn.ModuleList([prod]) - - self.readouts = torch.nn.ModuleList() - self.readouts.append(LinearDipoleReadoutBlock(hidden_irreps, dipole_only=False)) - - for i in range(num_interactions - 1): - if i == num_interactions - 2: - assert ( - len(hidden_irreps) > 1 - ), "To predict dipoles use at least l=1 hidden_irreps" - hidden_irreps_out = str( - hidden_irreps[:2] - ) # Select scalars and l=1 vectors for last layer - else: - hidden_irreps_out = hidden_irreps - inter = interaction_cls( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=hidden_irreps, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=interaction_irreps, - hidden_irreps=hidden_irreps_out, - avg_num_neighbors=avg_num_neighbors, - radial_MLP=radial_MLP, - ) - self.interactions.append(inter) - prod = EquivariantProductBasisBlock( - node_feats_irreps=interaction_irreps, - target_irreps=hidden_irreps_out, - correlation=correlation, - num_elements=num_elements, - use_sc=True, - ) - self.products.append(prod) - if i == num_interactions - 2: - self.readouts.append( - NonLinearDipoleReadoutBlock( - hidden_irreps_out, MLP_irreps, gate, dipole_only=False - ) - ) - else: - self.readouts.append( - LinearDipoleReadoutBlock(hidden_irreps, dipole_only=False) - ) - - def forward( - self, - data: Dict[str, torch.Tensor], - training: bool = False, - compute_force: bool = True, - compute_virials: bool = False, - compute_stress: bool = False, - compute_displacement: bool = False, - ) -> Dict[str, Optional[torch.Tensor]]: - # Setup - data["node_attrs"].requires_grad_(True) - data["positions"].requires_grad_(True) - num_graphs = data["ptr"].numel() - 1 - displacement = torch.zeros( - (num_graphs, 3, 3), - dtype=data["positions"].dtype, - device=data["positions"].device, - ) - if compute_virials or compute_stress or compute_displacement: - ( - data["positions"], - data["shifts"], - displacement, - ) = get_symmetric_displacement( - positions=data["positions"], - unit_shifts=data["unit_shifts"], - cell=data["cell"], - 
edge_index=data["edge_index"], - num_graphs=num_graphs, - batch=data["batch"], - ) - - # Atomic energies - node_e0 = self.atomic_energies_fn(data["node_attrs"]) - e0 = scatter_sum( - src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs - ) # [n_graphs,] - - # Embeddings - node_feats = self.node_embedding(data["node_attrs"]) - vectors, lengths = get_edge_vectors_and_lengths( - positions=data["positions"], - edge_index=data["edge_index"], - shifts=data["shifts"], - ) - edge_attrs = self.spherical_harmonics(vectors) - edge_feats = self.radial_embedding( - lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers - ) - - # Interactions - energies = [e0] - node_energies_list = [node_e0] - dipoles = [] - for interaction, product, readout in zip( - self.interactions, self.products, self.readouts - ): - node_feats, sc = interaction( - node_attrs=data["node_attrs"], - node_feats=node_feats, - edge_attrs=edge_attrs, - edge_feats=edge_feats, - edge_index=data["edge_index"], - ) - node_feats = product( - node_feats=node_feats, - sc=sc, - node_attrs=data["node_attrs"], - ) - node_out = readout(node_feats).squeeze(-1) # [n_nodes, ] - # node_energies = readout(node_feats).squeeze(-1) # [n_nodes, ] - node_energies = node_out[:, 0] - energy = scatter_sum( - src=node_energies, index=data["batch"], dim=-1, dim_size=num_graphs - ) # [n_graphs,] - energies.append(energy) - node_dipoles = node_out[:, 1:] - dipoles.append(node_dipoles) - - # Compute the energies and dipoles - contributions = torch.stack(energies, dim=-1) - total_energy = torch.sum(contributions, dim=-1) # [n_graphs, ] - node_energy_contributions = torch.stack(node_energies_list, dim=-1) - node_energy = torch.sum(node_energy_contributions, dim=-1) # [n_nodes, ] - contributions_dipoles = torch.stack( - dipoles, dim=-1 - ) # [n_nodes,3,n_contributions] - atomic_dipoles = torch.sum(contributions_dipoles, dim=-1) # [n_nodes,3] - total_dipole = scatter_sum( - src=atomic_dipoles, - index=data["batch"].unsqueeze(-1), - dim=0, - dim_size=num_graphs, - ) # [n_graphs,3] - baseline = compute_fixed_charge_dipole( - charges=data["charges"], - positions=data["positions"], - batch=data["batch"], - num_graphs=num_graphs, - ) # [n_graphs,3] - total_dipole = total_dipole + baseline - - forces, virials, stress, _ = get_outputs( - energy=total_energy, - positions=data["positions"], - displacement=displacement, - cell=data["cell"], - training=training, - compute_force=compute_force, - compute_virials=compute_virials, - compute_stress=compute_stress, - ) - - output = { - "energy": total_energy, - "node_energy": node_energy, - "contributions": contributions, - "forces": forces, - "virials": virials, - "stress": stress, - "displacement": displacement, - "dipole": total_dipole, - "atomic_dipoles": atomic_dipoles, - } - return output +# ########################################################################################### +# # Implementation of MACE models and other models based E(3)-Equivariant MPNNs +# # Authors: Ilyes Batatia, Gregor Simm +# # This program is distributed under the MIT License (see MIT.md) +# ########################################################################################### + +# from typing import Any, Callable, Dict, List, Optional, Type, Union + +# import numpy as np +# import torch +# from e3nn import o3 +# from e3nn.util.jit import compile_mode + +# from hydragnn.utils.mace_utils.data import AtomicData +# from hydragnn.utils.mace_utils.modules.radial import ZBLBasis +# from hydragnn.utils.mace_utils.tools.scatter 
import scatter_sum + +# from .blocks import ( +# AtomicEnergiesBlock, +# EquivariantProductBasisBlock, +# InteractionBlock, +# LinearDipoleReadoutBlock, +# LinearNodeEmbeddingBlock, +# LinearReadoutBlock, +# NonLinearDipoleReadoutBlock, +# NonLinearReadoutBlock, +# RadialEmbeddingBlock, +# ScaleShiftBlock, +# ) +# from .utils import ( +# compute_fixed_charge_dipole, +# compute_forces, +# get_edge_vectors_and_lengths, +# get_outputs, +# get_symmetric_displacement, +# ) + +# # pylint: disable=C0302 + + +# @compile_mode("script") +# class MACE(torch.nn.Module): +# def __init__( +# self, +# r_max: float, +# num_bessel: int, +# num_polynomial_cutoff: int, +# max_ell: int, +# interaction_cls: Type[InteractionBlock], +# interaction_cls_first: Type[InteractionBlock], +# num_interactions: int, +# num_elements: int, +# hidden_irreps: o3.Irreps, +# MLP_irreps: o3.Irreps, +# atomic_energies: np.ndarray, +# avg_num_neighbors: float, +# atomic_numbers: List[int], +# correlation: Union[int, List[int]], +# gate: Optional[Callable], +# pair_repulsion: bool = False, +# distance_transform: str = "None", +# radial_MLP: Optional[List[int]] = None, +# radial_type: Optional[str] = "bessel", +# ): +# super().__init__() +# self.register_buffer( +# "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) +# ) +# self.register_buffer( +# "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) +# ) +# self.register_buffer( +# "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) +# ) +# if isinstance(correlation, int): +# correlation = [correlation] * num_interactions +# # Embedding +# node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) +# node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) +# self.node_embedding = LinearNodeEmbeddingBlock( +# irreps_in=node_attr_irreps, irreps_out=node_feats_irreps +# ) +# self.radial_embedding = RadialEmbeddingBlock( +# r_max=r_max, +# num_bessel=num_bessel, +# num_polynomial_cutoff=num_polynomial_cutoff, +# radial_type=radial_type, +# distance_transform=distance_transform, +# ) +# edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") +# if pair_repulsion: +# self.pair_repulsion_fn = ZBLBasis(r_max=r_max, p=num_polynomial_cutoff) +# self.pair_repulsion = True + +# sh_irreps = o3.Irreps.spherical_harmonics(max_ell) +# num_features = hidden_irreps.count(o3.Irrep(0, 1)) +# interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() +# self.spherical_harmonics = o3.SphericalHarmonics( +# sh_irreps, normalize=True, normalization="component" +# ) +# if radial_MLP is None: +# radial_MLP = [64, 64, 64] +# # Interactions and readout +# self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) + +# inter = interaction_cls_first( +# node_attrs_irreps=node_attr_irreps, +# node_feats_irreps=node_feats_irreps, +# edge_attrs_irreps=sh_irreps, +# edge_feats_irreps=edge_feats_irreps, +# target_irreps=interaction_irreps, +# hidden_irreps=hidden_irreps, +# avg_num_neighbors=avg_num_neighbors, +# radial_MLP=radial_MLP, +# ) +# self.interactions = torch.nn.ModuleList([inter]) + +# # Use the appropriate self connection at the first layer for proper E0 +# use_sc_first = False +# if "Residual" in str(interaction_cls_first): +# use_sc_first = True + +# node_feats_irreps_out = inter.target_irreps +# prod = EquivariantProductBasisBlock( +# node_feats_irreps=node_feats_irreps_out, +# target_irreps=hidden_irreps, +# correlation=correlation[0], +# num_elements=num_elements, +# use_sc=use_sc_first, +# ) +# self.products = 
torch.nn.ModuleList([prod]) + +# self.readouts = torch.nn.ModuleList() +# self.readouts.append(LinearReadoutBlock(hidden_irreps)) + +# for i in range(num_interactions - 1): +# if i == num_interactions - 2: +# hidden_irreps_out = str( +# hidden_irreps[0] +# ) # Select only scalars for last layer +# else: +# hidden_irreps_out = hidden_irreps +# inter = interaction_cls( +# node_attrs_irreps=node_attr_irreps, +# node_feats_irreps=hidden_irreps, +# edge_attrs_irreps=sh_irreps, +# edge_feats_irreps=edge_feats_irreps, +# target_irreps=interaction_irreps, +# hidden_irreps=hidden_irreps_out, +# avg_num_neighbors=avg_num_neighbors, +# radial_MLP=radial_MLP, +# ) +# self.interactions.append(inter) +# prod = EquivariantProductBasisBlock( +# node_feats_irreps=interaction_irreps, +# target_irreps=hidden_irreps_out, +# correlation=correlation[i + 1], +# num_elements=num_elements, +# use_sc=True, +# ) +# self.products.append(prod) +# if i == num_interactions - 2: +# self.readouts.append( +# NonLinearReadoutBlock(hidden_irreps_out, MLP_irreps, gate) +# ) +# else: +# self.readouts.append(LinearReadoutBlock(hidden_irreps)) + +# def forward( +# self, +# data: Dict[str, torch.Tensor], +# training: bool = False, +# compute_force: bool = True, +# compute_virials: bool = False, +# compute_stress: bool = False, +# compute_displacement: bool = False, +# compute_hessian: bool = False, +# ) -> Dict[str, Optional[torch.Tensor]]: +# # Setup +# data["node_attrs"].requires_grad_(True) +# data["positions"].requires_grad_(True) +# num_graphs = data["ptr"].numel() - 1 +# displacement = torch.zeros( +# (num_graphs, 3, 3), +# dtype=data["positions"].dtype, +# device=data["positions"].device, +# ) +# if compute_virials or compute_stress or compute_displacement: +# ( +# data["positions"], +# data["shifts"], +# displacement, +# ) = get_symmetric_displacement( +# positions=data["positions"], +# unit_shifts=data["unit_shifts"], +# cell=data["cell"], +# edge_index=data["edge_index"], +# num_graphs=num_graphs, +# batch=data["batch"], +# ) + +# # Atomic energies +# node_e0 = self.atomic_energies_fn(data["node_attrs"]) +# e0 = scatter_sum( +# src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs +# ) # [n_graphs,] +# # Embeddings +# node_feats = self.node_embedding(data["node_attrs"]) +# vectors, lengths = get_edge_vectors_and_lengths( +# positions=data["positions"], +# edge_index=data["edge_index"], +# shifts=data["shifts"], +# ) +# edge_attrs = self.spherical_harmonics(vectors) +# edge_feats = self.radial_embedding( +# lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers +# ) +# if hasattr(self, "pair_repulsion"): +# pair_node_energy = self.pair_repulsion_fn( +# lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers +# ) +# pair_energy = scatter_sum( +# src=pair_node_energy, index=data["batch"], dim=-1, dim_size=num_graphs +# ) # [n_graphs,] +# else: +# pair_node_energy = torch.zeros_like(node_e0) +# pair_energy = torch.zeros_like(e0) + +# # Interactions +# energies = [e0, pair_energy] +# node_energies_list = [node_e0, pair_node_energy] +# node_feats_list = [] +# for interaction, product, readout in zip( +# self.interactions, self.products, self.readouts +# ): +# node_feats, sc = interaction( +# node_attrs=data["node_attrs"], +# node_feats=node_feats, +# edge_attrs=edge_attrs, +# edge_feats=edge_feats, +# edge_index=data["edge_index"], +# ) +# node_feats = product( +# node_feats=node_feats, +# sc=sc, +# node_attrs=data["node_attrs"], +# ) +# node_feats_list.append(node_feats) +# node_energies = 
readout(node_feats).squeeze(-1) # [n_nodes, ] +# energy = scatter_sum( +# src=node_energies, index=data["batch"], dim=-1, dim_size=num_graphs +# ) # [n_graphs,] +# energies.append(energy) +# node_energies_list.append(node_energies) + +# # Concatenate node features +# node_feats_out = torch.cat(node_feats_list, dim=-1) + +# # Sum over energy contributions +# contributions = torch.stack(energies, dim=-1) +# total_energy = torch.sum(contributions, dim=-1) # [n_graphs, ] +# node_energy_contributions = torch.stack(node_energies_list, dim=-1) +# node_energy = torch.sum(node_energy_contributions, dim=-1) # [n_nodes, ] + +# # Outputs +# forces, virials, stress, hessian = get_outputs( +# energy=total_energy, +# positions=data["positions"], +# displacement=displacement, +# cell=data["cell"], +# training=training, +# compute_force=compute_force, +# compute_virials=compute_virials, +# compute_stress=compute_stress, +# compute_hessian=compute_hessian, +# ) + +# return { +# "energy": total_energy, +# "node_energy": node_energy, +# "contributions": contributions, +# "forces": forces, +# "virials": virials, +# "stress": stress, +# "displacement": displacement, +# "hessian": hessian, +# "node_feats": node_feats_out, +# } + + +# @compile_mode("script") +# class ScaleShiftMACE(MACE): +# def __init__( +# self, +# atomic_inter_scale: float, +# atomic_inter_shift: float, +# **kwargs, +# ): +# super().__init__(**kwargs) +# self.scale_shift = ScaleShiftBlock( +# scale=atomic_inter_scale, shift=atomic_inter_shift +# ) + +# def forward( +# self, +# data: Dict[str, torch.Tensor], +# training: bool = False, +# compute_force: bool = True, +# compute_virials: bool = False, +# compute_stress: bool = False, +# compute_displacement: bool = False, +# compute_hessian: bool = False, +# ) -> Dict[str, Optional[torch.Tensor]]: +# # Setup +# data["positions"].requires_grad_(True) +# data["node_attrs"].requires_grad_(True) +# num_graphs = data["ptr"].numel() - 1 +# displacement = torch.zeros( +# (num_graphs, 3, 3), +# dtype=data["positions"].dtype, +# device=data["positions"].device, +# ) +# if compute_virials or compute_stress or compute_displacement: +# ( +# data["positions"], +# data["shifts"], +# displacement, +# ) = get_symmetric_displacement( +# positions=data["positions"], +# unit_shifts=data["unit_shifts"], +# cell=data["cell"], +# edge_index=data["edge_index"], +# num_graphs=num_graphs, +# batch=data["batch"], +# ) + +# # Atomic energies +# node_e0 = self.atomic_energies_fn(data["node_attrs"]) +# e0 = scatter_sum( +# src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs +# ) # [n_graphs,] + +# # Embeddings +# node_feats = self.node_embedding(data["node_attrs"]) +# vectors, lengths = get_edge_vectors_and_lengths( +# positions=data["positions"], +# edge_index=data["edge_index"], +# shifts=data["shifts"], +# ) +# edge_attrs = self.spherical_harmonics(vectors) +# edge_feats = self.radial_embedding( +# lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers +# ) +# if hasattr(self, "pair_repulsion"): +# pair_node_energy = self.pair_repulsion_fn( +# lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers +# ) +# else: +# pair_node_energy = torch.zeros_like(node_e0) +# # Interactions +# node_es_list = [pair_node_energy] +# node_feats_list = [] +# for interaction, product, readout in zip( +# self.interactions, self.products, self.readouts +# ): +# node_feats, sc = interaction( +# node_attrs=data["node_attrs"], +# node_feats=node_feats, +# edge_attrs=edge_attrs, +# edge_feats=edge_feats, +# 
edge_index=data["edge_index"], +# ) +# node_feats = product( +# node_feats=node_feats, sc=sc, node_attrs=data["node_attrs"] +# ) +# node_feats_list.append(node_feats) +# node_es_list.append(readout(node_feats).squeeze(-1)) # {[n_nodes, ], } +# # Concatenate node features +# node_feats_out = torch.cat(node_feats_list, dim=-1) +# # print("node_es_list", node_es_list) +# # Sum over interactions +# node_inter_es = torch.sum( +# torch.stack(node_es_list, dim=0), dim=0 +# ) # [n_nodes, ] +# node_inter_es = self.scale_shift(node_inter_es) + +# # Sum over nodes in graph +# inter_e = scatter_sum( +# src=node_inter_es, index=data["batch"], dim=-1, dim_size=num_graphs +# ) # [n_graphs,] + +# # Add E_0 and (scaled) interaction energy +# total_energy = e0 + inter_e +# node_energy = node_e0 + node_inter_es +# forces, virials, stress, hessian = get_outputs( +# energy=inter_e, +# positions=data["positions"], +# displacement=displacement, +# cell=data["cell"], +# training=training, +# compute_force=compute_force, +# compute_virials=compute_virials, +# compute_stress=compute_stress, +# compute_hessian=compute_hessian, +# ) +# output = { +# "energy": total_energy, +# "node_energy": node_energy, +# "interaction_energy": inter_e, +# "forces": forces, +# "virials": virials, +# "stress": stress, +# "hessian": hessian, +# "displacement": displacement, +# "node_feats": node_feats_out, +# } + +# return output + + +# class BOTNet(torch.nn.Module): +# def __init__( +# self, +# r_max: float, +# num_bessel: int, +# num_polynomial_cutoff: int, +# max_ell: int, +# interaction_cls: Type[InteractionBlock], +# interaction_cls_first: Type[InteractionBlock], +# num_interactions: int, +# num_elements: int, +# hidden_irreps: o3.Irreps, +# MLP_irreps: o3.Irreps, +# atomic_energies: np.ndarray, +# gate: Optional[Callable], +# avg_num_neighbors: float, +# atomic_numbers: List[int], +# ): +# super().__init__() +# self.r_max = r_max +# self.atomic_numbers = atomic_numbers +# # Embedding +# node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) +# node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) +# self.node_embedding = LinearNodeEmbeddingBlock( +# irreps_in=node_attr_irreps, irreps_out=node_feats_irreps +# ) +# self.radial_embedding = RadialEmbeddingBlock( +# r_max=r_max, +# num_bessel=num_bessel, +# num_polynomial_cutoff=num_polynomial_cutoff, +# ) +# edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") + +# sh_irreps = o3.Irreps.spherical_harmonics(max_ell) +# self.spherical_harmonics = o3.SphericalHarmonics( +# sh_irreps, normalize=True, normalization="component" +# ) + +# # Interactions and readouts +# self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) + +# self.interactions = torch.nn.ModuleList() +# self.readouts = torch.nn.ModuleList() + +# inter = interaction_cls_first( +# node_attrs_irreps=node_attr_irreps, +# node_feats_irreps=node_feats_irreps, +# edge_attrs_irreps=sh_irreps, +# edge_feats_irreps=edge_feats_irreps, +# target_irreps=hidden_irreps, +# avg_num_neighbors=avg_num_neighbors, +# ) +# self.interactions.append(inter) +# self.readouts.append(LinearReadoutBlock(inter.irreps_out)) + +# for i in range(num_interactions - 1): +# inter = interaction_cls( +# node_attrs_irreps=node_attr_irreps, +# node_feats_irreps=inter.irreps_out, +# edge_attrs_irreps=sh_irreps, +# edge_feats_irreps=edge_feats_irreps, +# target_irreps=hidden_irreps, +# avg_num_neighbors=avg_num_neighbors, +# ) +# self.interactions.append(inter) +# if i == num_interactions - 2: +# 
self.readouts.append( +# NonLinearReadoutBlock(inter.irreps_out, MLP_irreps, gate) +# ) +# else: +# self.readouts.append(LinearReadoutBlock(inter.irreps_out)) + +# def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: +# # Setup +# data.positions.requires_grad = True + +# # Atomic energies +# node_e0 = self.atomic_energies_fn(data.node_attrs) +# e0 = scatter_sum( +# src=node_e0, index=data.batch, dim=-1, dim_size=data.num_graphs +# ) # [n_graphs,] + +# # Embeddings +# node_feats = self.node_embedding(data.node_attrs) +# vectors, lengths = get_edge_vectors_and_lengths( +# positions=data.positions, edge_index=data.edge_index, shifts=data.shifts +# ) +# edge_attrs = self.spherical_harmonics(vectors) +# edge_feats = self.radial_embedding( +# lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers +# ) + +# # Interactions +# energies = [e0] +# for interaction, readout in zip(self.interactions, self.readouts): +# node_feats = interaction( +# node_attrs=data.node_attrs, +# node_feats=node_feats, +# edge_attrs=edge_attrs, +# edge_feats=edge_feats, +# edge_index=data.edge_index, +# ) +# node_energies = readout(node_feats).squeeze(-1) # [n_nodes, ] +# energy = scatter_sum( +# src=node_energies, index=data.batch, dim=-1, dim_size=data.num_graphs +# ) # [n_graphs,] +# energies.append(energy) + +# # Sum over energy contributions +# contributions = torch.stack(energies, dim=-1) +# total_energy = torch.sum(contributions, dim=-1) # [n_graphs, ] + +# output = { +# "energy": total_energy, +# "contributions": contributions, +# "forces": compute_forces( +# energy=total_energy, positions=data.positions, training=training +# ), +# } + +# return output + + +# class ScaleShiftBOTNet(BOTNet): +# def __init__( +# self, +# atomic_inter_scale: float, +# atomic_inter_shift: float, +# **kwargs, +# ): +# super().__init__(**kwargs) +# self.scale_shift = ScaleShiftBlock( +# scale=atomic_inter_scale, shift=atomic_inter_shift +# ) + +# def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: +# # Setup +# data.positions.requires_grad = True + +# # Atomic energies +# node_e0 = self.atomic_energies_fn(data.node_attrs) +# e0 = scatter_sum( +# src=node_e0, index=data.batch, dim=-1, dim_size=data.num_graphs +# ) # [n_graphs,] + +# # Embeddings +# node_feats = self.node_embedding(data.node_attrs) +# vectors, lengths = get_edge_vectors_and_lengths( +# positions=data.positions, edge_index=data.edge_index, shifts=data.shifts +# ) +# edge_attrs = self.spherical_harmonics(vectors) +# edge_feats = self.radial_embedding( +# lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers +# ) + +# # Interactions +# node_es_list = [] +# for interaction, readout in zip(self.interactions, self.readouts): +# node_feats = interaction( +# node_attrs=data.node_attrs, +# node_feats=node_feats, +# edge_attrs=edge_attrs, +# edge_feats=edge_feats, +# edge_index=data.edge_index, +# ) + +# node_es_list.append(readout(node_feats).squeeze(-1)) # {[n_nodes, ], } + +# # Sum over interactions +# node_inter_es = torch.sum( +# torch.stack(node_es_list, dim=0), dim=0 +# ) # [n_nodes, ] +# node_inter_es = self.scale_shift(node_inter_es) + +# # Sum over nodes in graph +# inter_e = scatter_sum( +# src=node_inter_es, index=data.batch, dim=-1, dim_size=data.num_graphs +# ) # [n_graphs,] + +# # Add E_0 and (scaled) interaction energy +# total_e = e0 + inter_e + +# output = { +# "energy": total_e, +# "forces": compute_forces( +# energy=inter_e, positions=data.positions, training=training +# ), +# } + +# return output 
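+# # NOTE (editorial sketch, not part of the upstream MACE source): every model
+# # in this file shares one readout pattern -- interaction layers produce
+# # per-node features, a readout block maps them to per-node scalars (or
+# # vectors), and scatter_sum pools them into per-graph quantities via the
+# # `batch` index. For this 1-D case, scatter_sum behaves like
+# # torch.Tensor.index_add_; a minimal standalone illustration with
+# # hypothetical tensors (two graphs, five atoms):
+# #
+# #     import torch
+# #     node_energies = torch.randn(5)         # per-node readout output
+# #     batch = torch.tensor([0, 0, 1, 1, 1])  # node -> graph assignment
+# #     energy = torch.zeros(2).index_add_(0, batch, node_energies)
+# #     # `energy` now holds one summed value per graph: shape [n_graphs,]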
+ + +# @compile_mode("script") +# class AtomicDipolesMACE(torch.nn.Module): +# def __init__( +# self, +# r_max: float, +# num_bessel: int, +# num_polynomial_cutoff: int, +# max_ell: int, +# interaction_cls: Type[InteractionBlock], +# interaction_cls_first: Type[InteractionBlock], +# num_interactions: int, +# num_elements: int, +# hidden_irreps: o3.Irreps, +# MLP_irreps: o3.Irreps, +# avg_num_neighbors: float, +# atomic_numbers: List[int], +# correlation: int, +# gate: Optional[Callable], +# atomic_energies: Optional[ +# None +# ], # Just here to make it compatible with energy models, MUST be None +# radial_type: Optional[str] = "bessel", +# radial_MLP: Optional[List[int]] = None, +# ): +# super().__init__() +# self.register_buffer( +# "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) +# ) +# self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.float64)) +# self.register_buffer( +# "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) +# ) +# assert atomic_energies is None + +# # Embedding +# node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) +# node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) +# self.node_embedding = LinearNodeEmbeddingBlock( +# irreps_in=node_attr_irreps, irreps_out=node_feats_irreps +# ) +# self.radial_embedding = RadialEmbeddingBlock( +# r_max=r_max, +# num_bessel=num_bessel, +# num_polynomial_cutoff=num_polynomial_cutoff, +# radial_type=radial_type, +# ) +# edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") + +# sh_irreps = o3.Irreps.spherical_harmonics(max_ell) +# num_features = hidden_irreps.count(o3.Irrep(0, 1)) +# interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() +# self.spherical_harmonics = o3.SphericalHarmonics( +# sh_irreps, normalize=True, normalization="component" +# ) +# if radial_MLP is None: +# radial_MLP = [64, 64, 64] + +# # Interactions and readouts +# inter = interaction_cls_first( +# node_attrs_irreps=node_attr_irreps, +# node_feats_irreps=node_feats_irreps, +# edge_attrs_irreps=sh_irreps, +# edge_feats_irreps=edge_feats_irreps, +# target_irreps=interaction_irreps, +# hidden_irreps=hidden_irreps, +# avg_num_neighbors=avg_num_neighbors, +# radial_MLP=radial_MLP, +# ) +# self.interactions = torch.nn.ModuleList([inter]) + +# # Use the appropriate self connection at the first layer +# use_sc_first = False +# if "Residual" in str(interaction_cls_first): +# use_sc_first = True + +# node_feats_irreps_out = inter.target_irreps +# prod = EquivariantProductBasisBlock( +# node_feats_irreps=node_feats_irreps_out, +# target_irreps=hidden_irreps, +# correlation=correlation, +# num_elements=num_elements, +# use_sc=use_sc_first, +# ) +# self.products = torch.nn.ModuleList([prod]) + +# self.readouts = torch.nn.ModuleList() +# self.readouts.append(LinearDipoleReadoutBlock(hidden_irreps, dipole_only=True)) + +# for i in range(num_interactions - 1): +# if i == num_interactions - 2: +# assert ( +# len(hidden_irreps) > 1 +# ), "To predict dipoles use at least l=1 hidden_irreps" +# hidden_irreps_out = str( +# hidden_irreps[1] +# ) # Select only l=1 vectors for last layer +# else: +# hidden_irreps_out = hidden_irreps +# inter = interaction_cls( +# node_attrs_irreps=node_attr_irreps, +# node_feats_irreps=hidden_irreps, +# edge_attrs_irreps=sh_irreps, +# edge_feats_irreps=edge_feats_irreps, +# target_irreps=interaction_irreps, +# hidden_irreps=hidden_irreps_out, +# avg_num_neighbors=avg_num_neighbors, +# radial_MLP=radial_MLP, +# ) +# self.interactions.append(inter) 
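+# # NOTE (editorial, based on the MACE paper, arXiv:2206.07697): the product
+# # block constructed next takes symmetric tensor products of the interaction
+# # output with itself up to order `correlation`, building the higher
+# # body-order equivariant features that distinguish MACE from plain two-body
+# # message passing.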
+# prod = EquivariantProductBasisBlock( +# node_feats_irreps=interaction_irreps, +# target_irreps=hidden_irreps_out, +# correlation=correlation, +# num_elements=num_elements, +# use_sc=True, +# ) +# self.products.append(prod) +# if i == num_interactions - 2: +# self.readouts.append( +# NonLinearDipoleReadoutBlock( +# hidden_irreps_out, MLP_irreps, gate, dipole_only=True +# ) +# ) +# else: +# self.readouts.append( +# LinearDipoleReadoutBlock(hidden_irreps, dipole_only=True) +# ) + +# def forward( +# self, +# data: Dict[str, torch.Tensor], +# training: bool = False, # pylint: disable=W0613 +# compute_force: bool = False, +# compute_virials: bool = False, +# compute_stress: bool = False, +# compute_displacement: bool = False, +# ) -> Dict[str, Optional[torch.Tensor]]: +# assert compute_force is False +# assert compute_virials is False +# assert compute_stress is False +# assert compute_displacement is False +# # Setup +# data["node_attrs"].requires_grad_(True) +# data["positions"].requires_grad_(True) +# num_graphs = data["ptr"].numel() - 1 + +# # Embeddings +# node_feats = self.node_embedding(data["node_attrs"]) +# vectors, lengths = get_edge_vectors_and_lengths( +# positions=data["positions"], +# edge_index=data["edge_index"], +# shifts=data["shifts"], +# ) +# edge_attrs = self.spherical_harmonics(vectors) +# edge_feats = self.radial_embedding( +# lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers +# ) + +# # Interactions +# dipoles = [] +# for interaction, product, readout in zip( +# self.interactions, self.products, self.readouts +# ): +# node_feats, sc = interaction( +# node_attrs=data["node_attrs"], +# node_feats=node_feats, +# edge_attrs=edge_attrs, +# edge_feats=edge_feats, +# edge_index=data["edge_index"], +# ) +# node_feats = product( +# node_feats=node_feats, +# sc=sc, +# node_attrs=data["node_attrs"], +# ) +# node_dipoles = readout(node_feats).squeeze(-1) # [n_nodes,3] +# dipoles.append(node_dipoles) + +# # Compute the dipoles +# contributions_dipoles = torch.stack( +# dipoles, dim=-1 +# ) # [n_nodes,3,n_contributions] +# atomic_dipoles = torch.sum(contributions_dipoles, dim=-1) # [n_nodes,3] +# total_dipole = scatter_sum( +# src=atomic_dipoles, +# index=data["batch"], +# dim=0, +# dim_size=num_graphs, +# ) # [n_graphs,3] +# baseline = compute_fixed_charge_dipole( +# charges=data["charges"], +# positions=data["positions"], +# batch=data["batch"], +# num_graphs=num_graphs, +# ) # [n_graphs,3] +# total_dipole = total_dipole + baseline + +# output = { +# "dipole": total_dipole, +# "atomic_dipoles": atomic_dipoles, +# } +# return output + + +# @compile_mode("script") +# class EnergyDipolesMACE(torch.nn.Module): +# def __init__( +# self, +# r_max: float, +# num_bessel: int, +# num_polynomial_cutoff: int, +# max_ell: int, +# interaction_cls: Type[InteractionBlock], +# interaction_cls_first: Type[InteractionBlock], +# num_interactions: int, +# num_elements: int, +# hidden_irreps: o3.Irreps, +# MLP_irreps: o3.Irreps, +# avg_num_neighbors: float, +# atomic_numbers: List[int], +# correlation: int, +# gate: Optional[Callable], +# atomic_energies: Optional[np.ndarray], +# radial_MLP: Optional[List[int]] = None, +# ): +# super().__init__() +# self.register_buffer( +# "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) +# ) +# self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.float64)) +# self.register_buffer( +# "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) +# ) +# # Embedding +# node_attr_irreps = o3.Irreps([(num_elements, 
(0, 1))]) +# node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) +# self.node_embedding = LinearNodeEmbeddingBlock( +# irreps_in=node_attr_irreps, irreps_out=node_feats_irreps +# ) +# self.radial_embedding = RadialEmbeddingBlock( +# r_max=r_max, +# num_bessel=num_bessel, +# num_polynomial_cutoff=num_polynomial_cutoff, +# ) +# edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") + +# sh_irreps = o3.Irreps.spherical_harmonics(max_ell) +# num_features = hidden_irreps.count(o3.Irrep(0, 1)) +# interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() +# self.spherical_harmonics = o3.SphericalHarmonics( +# sh_irreps, normalize=True, normalization="component" +# ) +# if radial_MLP is None: +# radial_MLP = [64, 64, 64] +# # Interactions and readouts +# self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) + +# inter = interaction_cls_first( +# node_attrs_irreps=node_attr_irreps, +# node_feats_irreps=node_feats_irreps, +# edge_attrs_irreps=sh_irreps, +# edge_feats_irreps=edge_feats_irreps, +# target_irreps=interaction_irreps, +# hidden_irreps=hidden_irreps, +# avg_num_neighbors=avg_num_neighbors, +# radial_MLP=radial_MLP, +# ) +# self.interactions = torch.nn.ModuleList([inter]) + +# # Use the appropriate self connection at the first layer +# use_sc_first = False +# if "Residual" in str(interaction_cls_first): +# use_sc_first = True + +# node_feats_irreps_out = inter.target_irreps +# prod = EquivariantProductBasisBlock( +# node_feats_irreps=node_feats_irreps_out, +# target_irreps=hidden_irreps, +# correlation=correlation, +# num_elements=num_elements, +# use_sc=use_sc_first, +# ) +# self.products = torch.nn.ModuleList([prod]) + +# self.readouts = torch.nn.ModuleList() +# self.readouts.append(LinearDipoleReadoutBlock(hidden_irreps, dipole_only=False)) + +# for i in range(num_interactions - 1): +# if i == num_interactions - 2: +# assert ( +# len(hidden_irreps) > 1 +# ), "To predict dipoles use at least l=1 hidden_irreps" +# hidden_irreps_out = str( +# hidden_irreps[:2] +# ) # Select scalars and l=1 vectors for last layer +# else: +# hidden_irreps_out = hidden_irreps +# inter = interaction_cls( +# node_attrs_irreps=node_attr_irreps, +# node_feats_irreps=hidden_irreps, +# edge_attrs_irreps=sh_irreps, +# edge_feats_irreps=edge_feats_irreps, +# target_irreps=interaction_irreps, +# hidden_irreps=hidden_irreps_out, +# avg_num_neighbors=avg_num_neighbors, +# radial_MLP=radial_MLP, +# ) +# self.interactions.append(inter) +# prod = EquivariantProductBasisBlock( +# node_feats_irreps=interaction_irreps, +# target_irreps=hidden_irreps_out, +# correlation=correlation, +# num_elements=num_elements, +# use_sc=True, +# ) +# self.products.append(prod) +# if i == num_interactions - 2: +# self.readouts.append( +# NonLinearDipoleReadoutBlock( +# hidden_irreps_out, MLP_irreps, gate, dipole_only=False +# ) +# ) +# else: +# self.readouts.append( +# LinearDipoleReadoutBlock(hidden_irreps, dipole_only=False) +# ) + +# def forward( +# self, +# data: Dict[str, torch.Tensor], +# training: bool = False, +# compute_force: bool = True, +# compute_virials: bool = False, +# compute_stress: bool = False, +# compute_displacement: bool = False, +# ) -> Dict[str, Optional[torch.Tensor]]: +# # Setup +# data["node_attrs"].requires_grad_(True) +# data["positions"].requires_grad_(True) +# num_graphs = data["ptr"].numel() - 1 +# displacement = torch.zeros( +# (num_graphs, 3, 3), +# dtype=data["positions"].dtype, +# device=data["positions"].device, +# ) +# if compute_virials or 
compute_stress or compute_displacement: +# ( +# data["positions"], +# data["shifts"], +# displacement, +# ) = get_symmetric_displacement( +# positions=data["positions"], +# unit_shifts=data["unit_shifts"], +# cell=data["cell"], +# edge_index=data["edge_index"], +# num_graphs=num_graphs, +# batch=data["batch"], +# ) + +# # Atomic energies +# node_e0 = self.atomic_energies_fn(data["node_attrs"]) +# e0 = scatter_sum( +# src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs +# ) # [n_graphs,] + +# # Embeddings +# node_feats = self.node_embedding(data["node_attrs"]) +# vectors, lengths = get_edge_vectors_and_lengths( +# positions=data["positions"], +# edge_index=data["edge_index"], +# shifts=data["shifts"], +# ) +# edge_attrs = self.spherical_harmonics(vectors) +# edge_feats = self.radial_embedding( +# lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers +# ) + +# # Interactions +# energies = [e0] +# node_energies_list = [node_e0] +# dipoles = [] +# for interaction, product, readout in zip( +# self.interactions, self.products, self.readouts +# ): +# node_feats, sc = interaction( +# node_attrs=data["node_attrs"], +# node_feats=node_feats, +# edge_attrs=edge_attrs, +# edge_feats=edge_feats, +# edge_index=data["edge_index"], +# ) +# node_feats = product( +# node_feats=node_feats, +# sc=sc, +# node_attrs=data["node_attrs"], +# ) +# node_out = readout(node_feats).squeeze(-1) # [n_nodes, ] +# # node_energies = readout(node_feats).squeeze(-1) # [n_nodes, ] +# node_energies = node_out[:, 0] +# energy = scatter_sum( +# src=node_energies, index=data["batch"], dim=-1, dim_size=num_graphs +# ) # [n_graphs,] +# energies.append(energy) +# node_dipoles = node_out[:, 1:] +# dipoles.append(node_dipoles) + +# # Compute the energies and dipoles +# contributions = torch.stack(energies, dim=-1) +# total_energy = torch.sum(contributions, dim=-1) # [n_graphs, ] +# node_energy_contributions = torch.stack(node_energies_list, dim=-1) +# node_energy = torch.sum(node_energy_contributions, dim=-1) # [n_nodes, ] +# contributions_dipoles = torch.stack( +# dipoles, dim=-1 +# ) # [n_nodes,3,n_contributions] +# atomic_dipoles = torch.sum(contributions_dipoles, dim=-1) # [n_nodes,3] +# total_dipole = scatter_sum( +# src=atomic_dipoles, +# index=data["batch"].unsqueeze(-1), +# dim=0, +# dim_size=num_graphs, +# ) # [n_graphs,3] +# baseline = compute_fixed_charge_dipole( +# charges=data["charges"], +# positions=data["positions"], +# batch=data["batch"], +# num_graphs=num_graphs, +# ) # [n_graphs,3] +# total_dipole = total_dipole + baseline + +# forces, virials, stress, _ = get_outputs( +# energy=total_energy, +# positions=data["positions"], +# displacement=displacement, +# cell=data["cell"], +# training=training, +# compute_force=compute_force, +# compute_virials=compute_virials, +# compute_stress=compute_stress, +# ) + +# output = { +# "energy": total_energy, +# "node_energy": node_energy, +# "contributions": contributions, +# "forces": forces, +# "virials": virials, +# "stress": stress, +# "displacement": displacement, +# "dipole": total_dipole, +# "atomic_dipoles": atomic_dipoles, +# } +# return output diff --git a/hydragnn/utils/mace_utils/modules/utils.py b/hydragnn/utils/mace_utils/modules/utils.py index c6a44fff6..a2e569475 100644 --- a/hydragnn/utils/mace_utils/modules/utils.py +++ b/hydragnn/utils/mace_utils/modules/utils.py @@ -20,194 +20,194 @@ from .blocks import AtomicEnergiesBlock -def compute_forces( - energy: torch.Tensor, positions: torch.Tensor, training: bool = True -) -> torch.Tensor: - 
grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)] - gradient = torch.autograd.grad( - outputs=[energy], # [n_graphs, ] - inputs=[positions], # [n_nodes, 3] - grad_outputs=grad_outputs, - retain_graph=training, # Make sure the graph is not destroyed during training - create_graph=training, # Create graph for second derivative - allow_unused=True, # For complete dissociation turn to true - )[ - 0 - ] # [n_nodes, 3] - if gradient is None: - return torch.zeros_like(positions) - return -1 * gradient - - -def compute_forces_virials( - energy: torch.Tensor, - positions: torch.Tensor, - displacement: torch.Tensor, - cell: torch.Tensor, - training: bool = True, - compute_stress: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)] - forces, virials = torch.autograd.grad( - outputs=[energy], # [n_graphs, ] - inputs=[positions, displacement], # [n_nodes, 3] - grad_outputs=grad_outputs, - retain_graph=training, # Make sure the graph is not destroyed during training - create_graph=training, # Create graph for second derivative - allow_unused=True, - ) - stress = torch.zeros_like(displacement) - if compute_stress and virials is not None: - cell = cell.view(-1, 3, 3) - volume = torch.linalg.det(cell).abs().unsqueeze(-1) - stress = virials / volume.view(-1, 1, 1) - stress = torch.where(torch.abs(stress) < 1e10, stress, torch.zeros_like(stress)) - if forces is None: - forces = torch.zeros_like(positions) - if virials is None: - virials = torch.zeros((1, 3, 3)) - - return -1 * forces, -1 * virials, stress - - -def get_symmetric_displacement( - positions: torch.Tensor, - unit_shifts: torch.Tensor, - cell: Optional[torch.Tensor], - edge_index: torch.Tensor, - num_graphs: int, - batch: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - if cell is None: - cell = torch.zeros( - num_graphs * 3, - 3, - dtype=positions.dtype, - device=positions.device, - ) - sender = edge_index[0] - displacement = torch.zeros( - (num_graphs, 3, 3), - dtype=positions.dtype, - device=positions.device, - ) - displacement.requires_grad_(True) - symmetric_displacement = 0.5 * ( - displacement + displacement.transpose(-1, -2) - ) # From https://github.com/mir-group/nequip - positions = positions + torch.einsum( - "be,bec->bc", positions, symmetric_displacement[batch] - ) - cell = cell.view(-1, 3, 3) - cell = cell + torch.matmul(cell, symmetric_displacement) - shifts = torch.einsum( - "be,bec->bc", - unit_shifts, - cell[batch[sender]], - ) - return positions, shifts, displacement - - -@torch.jit.unused -def compute_hessians_vmap( - forces: torch.Tensor, - positions: torch.Tensor, -) -> torch.Tensor: - forces_flatten = forces.view(-1) - num_elements = forces_flatten.shape[0] - - def get_vjp(v): - return torch.autograd.grad( - -1 * forces_flatten, - positions, - v, - retain_graph=True, - create_graph=False, - allow_unused=False, - ) - - I_N = torch.eye(num_elements).to(forces.device) - try: - chunk_size = 1 if num_elements < 64 else 16 - gradient = torch.vmap(get_vjp, in_dims=0, out_dims=0, chunk_size=chunk_size)( - I_N - )[0] - except RuntimeError: - gradient = compute_hessians_loop(forces, positions) - if gradient is None: - return torch.zeros((positions.shape[0], forces.shape[0], 3, 3)) - return gradient - - -@torch.jit.unused -def compute_hessians_loop( - forces: torch.Tensor, - positions: torch.Tensor, -) -> torch.Tensor: - - hessian = [] - for grad_elem in forces.view(-1): - hess_row = 
torch.autograd.grad( - outputs=[-1 * grad_elem], - inputs=[positions], - grad_outputs=torch.ones_like(grad_elem), - retain_graph=True, - create_graph=False, - allow_unused=False, - )[0] - hess_row = hess_row.detach() # this makes it very slow? but needs less memory - if hess_row is None: - hessian.append(torch.zeros_like(positions)) - else: - hessian.append(hess_row) - hessian = torch.stack(hessian) - return hessian - - -def get_outputs( - energy: torch.Tensor, - positions: torch.Tensor, - displacement: Optional[torch.Tensor], - cell: torch.Tensor, - training: bool = False, - compute_force: bool = True, - compute_virials: bool = True, - compute_stress: bool = True, - compute_hessian: bool = False, -) -> Tuple[ - Optional[torch.Tensor], - Optional[torch.Tensor], - Optional[torch.Tensor], - Optional[torch.Tensor], -]: - if (compute_virials or compute_stress) and displacement is not None: - # forces come for free - forces, virials, stress = compute_forces_virials( - energy=energy, - positions=positions, - displacement=displacement, - cell=cell, - compute_stress=compute_stress, - training=(training or compute_hessian), - ) - elif compute_force: - forces, virials, stress = ( - compute_forces( - energy=energy, - positions=positions, - training=(training or compute_hessian), - ), - None, - None, - ) - else: - forces, virials, stress = (None, None, None) - if compute_hessian: - assert forces is not None, "Forces must be computed to get the hessian" - hessian = compute_hessians_vmap(forces, positions) - else: - hessian = None - return forces, virials, stress, hessian +# def compute_forces( +# energy: torch.Tensor, positions: torch.Tensor, training: bool = True +# ) -> torch.Tensor: +# grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)] +# gradient = torch.autograd.grad( +# outputs=[energy], # [n_graphs, ] +# inputs=[positions], # [n_nodes, 3] +# grad_outputs=grad_outputs, +# retain_graph=training, # Make sure the graph is not destroyed during training +# create_graph=training, # Create graph for second derivative +# allow_unused=True, # For complete dissociation turn to true +# )[ +# 0 +# ] # [n_nodes, 3] +# if gradient is None: +# return torch.zeros_like(positions) +# return -1 * gradient + + +# def compute_forces_virials( +# energy: torch.Tensor, +# positions: torch.Tensor, +# displacement: torch.Tensor, +# cell: torch.Tensor, +# training: bool = True, +# compute_stress: bool = False, +# ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: +# grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)] +# forces, virials = torch.autograd.grad( +# outputs=[energy], # [n_graphs, ] +# inputs=[positions, displacement], # [n_nodes, 3] +# grad_outputs=grad_outputs, +# retain_graph=training, # Make sure the graph is not destroyed during training +# create_graph=training, # Create graph for second derivative +# allow_unused=True, +# ) +# stress = torch.zeros_like(displacement) +# if compute_stress and virials is not None: +# cell = cell.view(-1, 3, 3) +# volume = torch.linalg.det(cell).abs().unsqueeze(-1) +# stress = virials / volume.view(-1, 1, 1) +# stress = torch.where(torch.abs(stress) < 1e10, stress, torch.zeros_like(stress)) +# if forces is None: +# forces = torch.zeros_like(positions) +# if virials is None: +# virials = torch.zeros((1, 3, 3)) + +# return -1 * forces, -1 * virials, stress + + +# def get_symmetric_displacement( +# positions: torch.Tensor, +# unit_shifts: torch.Tensor, +# cell: Optional[torch.Tensor], +# edge_index: torch.Tensor, +# 
num_graphs: int, +# batch: torch.Tensor, +# ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +# if cell is None: +# cell = torch.zeros( +# num_graphs * 3, +# 3, +# dtype=positions.dtype, +# device=positions.device, +# ) +# sender = edge_index[0] +# displacement = torch.zeros( +# (num_graphs, 3, 3), +# dtype=positions.dtype, +# device=positions.device, +# ) +# displacement.requires_grad_(True) +# symmetric_displacement = 0.5 * ( +# displacement + displacement.transpose(-1, -2) +# ) # From https://github.com/mir-group/nequip +# positions = positions + torch.einsum( +# "be,bec->bc", positions, symmetric_displacement[batch] +# ) +# cell = cell.view(-1, 3, 3) +# cell = cell + torch.matmul(cell, symmetric_displacement) +# shifts = torch.einsum( +# "be,bec->bc", +# unit_shifts, +# cell[batch[sender]], +# ) +# return positions, shifts, displacement + + +# @torch.jit.unused +# def compute_hessians_vmap( +# forces: torch.Tensor, +# positions: torch.Tensor, +# ) -> torch.Tensor: +# forces_flatten = forces.view(-1) +# num_elements = forces_flatten.shape[0] + +# def get_vjp(v): +# return torch.autograd.grad( +# -1 * forces_flatten, +# positions, +# v, +# retain_graph=True, +# create_graph=False, +# allow_unused=False, +# ) + +# I_N = torch.eye(num_elements).to(forces.device) +# try: +# chunk_size = 1 if num_elements < 64 else 16 +# gradient = torch.vmap(get_vjp, in_dims=0, out_dims=0, chunk_size=chunk_size)( +# I_N +# )[0] +# except RuntimeError: +# gradient = compute_hessians_loop(forces, positions) +# if gradient is None: +# return torch.zeros((positions.shape[0], forces.shape[0], 3, 3)) +# return gradient + + +# @torch.jit.unused +# def compute_hessians_loop( +# forces: torch.Tensor, +# positions: torch.Tensor, +# ) -> torch.Tensor: + +# hessian = [] +# for grad_elem in forces.view(-1): +# hess_row = torch.autograd.grad( +# outputs=[-1 * grad_elem], +# inputs=[positions], +# grad_outputs=torch.ones_like(grad_elem), +# retain_graph=True, +# create_graph=False, +# allow_unused=False, +# )[0] +# hess_row = hess_row.detach() # this makes it very slow? 
but needs less memory +# if hess_row is None: +# hessian.append(torch.zeros_like(positions)) +# else: +# hessian.append(hess_row) +# hessian = torch.stack(hessian) +# return hessian + + +# def get_outputs( +# energy: torch.Tensor, +# positions: torch.Tensor, +# displacement: Optional[torch.Tensor], +# cell: torch.Tensor, +# training: bool = False, +# compute_force: bool = True, +# compute_virials: bool = True, +# compute_stress: bool = True, +# compute_hessian: bool = False, +# ) -> Tuple[ +# Optional[torch.Tensor], +# Optional[torch.Tensor], +# Optional[torch.Tensor], +# Optional[torch.Tensor], +# ]: +# if (compute_virials or compute_stress) and displacement is not None: +# # forces come for free +# forces, virials, stress = compute_forces_virials( +# energy=energy, +# positions=positions, +# displacement=displacement, +# cell=cell, +# compute_stress=compute_stress, +# training=(training or compute_hessian), +# ) +# elif compute_force: +# forces, virials, stress = ( +# compute_forces( +# energy=energy, +# positions=positions, +# training=(training or compute_hessian), +# ), +# None, +# None, +# ) +# else: +# forces, virials, stress = (None, None, None) +# if compute_hessian: +# assert forces is not None, "Forces must be computed to get the hessian" +# hessian = compute_hessians_vmap(forces, positions) +# else: +# hessian = None +# return forces, virials, stress, hessian def get_edge_vectors_and_lengths( @@ -253,162 +253,162 @@ def extract_invariant(x: torch.Tensor, num_layers: int, num_features: int, l_max return torch.cat(out, dim=-1) -def compute_mean_std_atomic_inter_energy( - data_loader: torch.utils.data.DataLoader, - atomic_energies: np.ndarray, -) -> Tuple[float, float]: - atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies) - - avg_atom_inter_es_list = [] - - for batch in data_loader: - node_e0 = atomic_energies_fn(batch.node_attrs) - graph_e0s = scatter_sum( - src=node_e0, index=batch.batch, dim=-1, dim_size=batch.num_graphs - ) - graph_sizes = batch.ptr[1:] - batch.ptr[:-1] - avg_atom_inter_es_list.append( - (batch.energy - graph_e0s) / graph_sizes - ) # {[n_graphs], } - - avg_atom_inter_es = torch.cat(avg_atom_inter_es_list) # [total_n_graphs] - mean = to_numpy(torch.mean(avg_atom_inter_es)).item() - std = to_numpy(torch.std(avg_atom_inter_es)).item() - std = _check_non_zero(std) - - return mean, std - - -def _compute_mean_std_atomic_inter_energy( - batch: Batch, - atomic_energies_fn: AtomicEnergiesBlock, -) -> Tuple[torch.Tensor, torch.Tensor]: - node_e0 = atomic_energies_fn(batch.node_attrs) - graph_e0s = scatter_sum( - src=node_e0, index=batch.batch, dim=-1, dim_size=batch.num_graphs - ) - graph_sizes = batch.ptr[1:] - batch.ptr[:-1] - atom_energies = (batch.energy - graph_e0s) / graph_sizes - return atom_energies - - -def compute_mean_rms_energy_forces( - data_loader: torch.utils.data.DataLoader, - atomic_energies: np.ndarray, -) -> Tuple[float, float]: - atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies) - - atom_energy_list = [] - forces_list = [] - - for batch in data_loader: - node_e0 = atomic_energies_fn(batch.node_attrs) - graph_e0s = scatter_sum( - src=node_e0, index=batch.batch, dim=-1, dim_size=batch.num_graphs - ) - graph_sizes = batch.ptr[1:] - batch.ptr[:-1] - atom_energy_list.append( - (batch.energy - graph_e0s) / graph_sizes - ) # {[n_graphs], } - forces_list.append(batch.forces) # {[n_graphs*n_atoms,3], } - - atom_energies = torch.cat(atom_energy_list, dim=0) # [total_n_graphs] - forces = torch.cat(forces_list, 
dim=0) # {[total_n_graphs*n_atoms,3], } - - mean = to_numpy(torch.mean(atom_energies)).item() - rms = to_numpy(torch.sqrt(torch.mean(torch.square(forces)))).item() - rms = _check_non_zero(rms) - - return mean, rms - - -def _compute_mean_rms_energy_forces( - batch: Batch, - atomic_energies_fn: AtomicEnergiesBlock, -) -> Tuple[torch.Tensor, torch.Tensor]: - node_e0 = atomic_energies_fn(batch.node_attrs) - graph_e0s = scatter_sum( - src=node_e0, index=batch.batch, dim=-1, dim_size=batch.num_graphs - ) - graph_sizes = batch.ptr[1:] - batch.ptr[:-1] - atom_energies = (batch.energy - graph_e0s) / graph_sizes # {[n_graphs], } - forces = batch.forces # {[n_graphs*n_atoms,3], } - - return atom_energies, forces - - -def compute_avg_num_neighbors(data_loader: torch.utils.data.DataLoader) -> float: - num_neighbors = [] - - for batch in data_loader: - _, receivers = batch.edge_index - _, counts = torch.unique(receivers, return_counts=True) - num_neighbors.append(counts) - - avg_num_neighbors = torch.mean( - torch.cat(num_neighbors, dim=0).type(torch.get_default_dtype()) - ) - return to_numpy(avg_num_neighbors).item() - - -def compute_statistics( - data_loader: torch.utils.data.DataLoader, - atomic_energies: np.ndarray, -) -> Tuple[float, float, float, float]: - atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies) - - atom_energy_list = [] - forces_list = [] - num_neighbors = [] - - for batch in data_loader: - node_e0 = atomic_energies_fn(batch.node_attrs) - graph_e0s = scatter_sum( - src=node_e0, index=batch.batch, dim=-1, dim_size=batch.num_graphs - ) - graph_sizes = batch.ptr[1:] - batch.ptr[:-1] - atom_energy_list.append( - (batch.energy - graph_e0s) / graph_sizes - ) # {[n_graphs], } - forces_list.append(batch.forces) # {[n_graphs*n_atoms,3], } - - _, receivers = batch.edge_index - _, counts = torch.unique(receivers, return_counts=True) - num_neighbors.append(counts) - - atom_energies = torch.cat(atom_energy_list, dim=0) # [total_n_graphs] - forces = torch.cat(forces_list, dim=0) # {[total_n_graphs*n_atoms,3], } - - mean = to_numpy(torch.mean(atom_energies)).item() - rms = to_numpy(torch.sqrt(torch.mean(torch.square(forces)))).item() - - avg_num_neighbors = torch.mean( - torch.cat(num_neighbors, dim=0).type(torch.get_default_dtype()) - ) - - return to_numpy(avg_num_neighbors).item(), mean, rms - - -def compute_rms_dipoles( - data_loader: torch.utils.data.DataLoader, -) -> Tuple[float, float]: - dipoles_list = [] - for batch in data_loader: - dipoles_list.append(batch.dipole) # {[n_graphs,3], } - - dipoles = torch.cat(dipoles_list, dim=0) # {[total_n_graphs,3], } - rms = to_numpy(torch.sqrt(torch.mean(torch.square(dipoles)))).item() - rms = _check_non_zero(rms) - return rms - - -def compute_fixed_charge_dipole( - charges: torch.Tensor, - positions: torch.Tensor, - batch: torch.Tensor, - num_graphs: int, -) -> torch.Tensor: - mu = positions * charges.unsqueeze(-1) / (1e-11 / c / e) # [N_atoms,3] - return scatter_sum( - src=mu, index=batch.unsqueeze(-1), dim=0, dim_size=num_graphs - ) # [N_graphs,3] +# def compute_mean_std_atomic_inter_energy( +# data_loader: torch.utils.data.DataLoader, +# atomic_energies: np.ndarray, +# ) -> Tuple[float, float]: +# atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies) + +# avg_atom_inter_es_list = [] + +# for batch in data_loader: +# node_e0 = atomic_energies_fn(batch.node_attrs) +# graph_e0s = scatter_sum( +# src=node_e0, index=batch.batch, dim=-1, dim_size=batch.num_graphs +# ) +# graph_sizes = batch.ptr[1:] - 
batch.ptr[:-1] +# avg_atom_inter_es_list.append( +# (batch.energy - graph_e0s) / graph_sizes +# ) # {[n_graphs], } + +# avg_atom_inter_es = torch.cat(avg_atom_inter_es_list) # [total_n_graphs] +# mean = to_numpy(torch.mean(avg_atom_inter_es)).item() +# std = to_numpy(torch.std(avg_atom_inter_es)).item() +# std = _check_non_zero(std) + +# return mean, std + + +# def _compute_mean_std_atomic_inter_energy( +# batch: Batch, +# atomic_energies_fn: AtomicEnergiesBlock, +# ) -> Tuple[torch.Tensor, torch.Tensor]: +# node_e0 = atomic_energies_fn(batch.node_attrs) +# graph_e0s = scatter_sum( +# src=node_e0, index=batch.batch, dim=-1, dim_size=batch.num_graphs +# ) +# graph_sizes = batch.ptr[1:] - batch.ptr[:-1] +# atom_energies = (batch.energy - graph_e0s) / graph_sizes +# return atom_energies + + +# def compute_mean_rms_energy_forces( +# data_loader: torch.utils.data.DataLoader, +# atomic_energies: np.ndarray, +# ) -> Tuple[float, float]: +# atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies) + +# atom_energy_list = [] +# forces_list = [] + +# for batch in data_loader: +# node_e0 = atomic_energies_fn(batch.node_attrs) +# graph_e0s = scatter_sum( +# src=node_e0, index=batch.batch, dim=-1, dim_size=batch.num_graphs +# ) +# graph_sizes = batch.ptr[1:] - batch.ptr[:-1] +# atom_energy_list.append( +# (batch.energy - graph_e0s) / graph_sizes +# ) # {[n_graphs], } +# forces_list.append(batch.forces) # {[n_graphs*n_atoms,3], } + +# atom_energies = torch.cat(atom_energy_list, dim=0) # [total_n_graphs] +# forces = torch.cat(forces_list, dim=0) # {[total_n_graphs*n_atoms,3], } + +# mean = to_numpy(torch.mean(atom_energies)).item() +# rms = to_numpy(torch.sqrt(torch.mean(torch.square(forces)))).item() +# rms = _check_non_zero(rms) + +# return mean, rms + + +# def _compute_mean_rms_energy_forces( +# batch: Batch, +# atomic_energies_fn: AtomicEnergiesBlock, +# ) -> Tuple[torch.Tensor, torch.Tensor]: +# node_e0 = atomic_energies_fn(batch.node_attrs) +# graph_e0s = scatter_sum( +# src=node_e0, index=batch.batch, dim=-1, dim_size=batch.num_graphs +# ) +# graph_sizes = batch.ptr[1:] - batch.ptr[:-1] +# atom_energies = (batch.energy - graph_e0s) / graph_sizes # {[n_graphs], } +# forces = batch.forces # {[n_graphs*n_atoms,3], } + +# return atom_energies, forces + + +# def compute_avg_num_neighbors(data_loader: torch.utils.data.DataLoader) -> float: +# num_neighbors = [] + +# for batch in data_loader: +# _, receivers = batch.edge_index +# _, counts = torch.unique(receivers, return_counts=True) +# num_neighbors.append(counts) + +# avg_num_neighbors = torch.mean( +# torch.cat(num_neighbors, dim=0).type(torch.get_default_dtype()) +# ) +# return to_numpy(avg_num_neighbors).item() + + +# def compute_statistics( +# data_loader: torch.utils.data.DataLoader, +# atomic_energies: np.ndarray, +# ) -> Tuple[float, float, float, float]: +# atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies) + +# atom_energy_list = [] +# forces_list = [] +# num_neighbors = [] + +# for batch in data_loader: +# node_e0 = atomic_energies_fn(batch.node_attrs) +# graph_e0s = scatter_sum( +# src=node_e0, index=batch.batch, dim=-1, dim_size=batch.num_graphs +# ) +# graph_sizes = batch.ptr[1:] - batch.ptr[:-1] +# atom_energy_list.append( +# (batch.energy - graph_e0s) / graph_sizes +# ) # {[n_graphs], } +# forces_list.append(batch.forces) # {[n_graphs*n_atoms,3], } + +# _, receivers = batch.edge_index +# _, counts = torch.unique(receivers, return_counts=True) +# num_neighbors.append(counts) + +# atom_energies = 
torch.cat(atom_energy_list, dim=0) # [total_n_graphs] +# forces = torch.cat(forces_list, dim=0) # {[total_n_graphs*n_atoms,3], } + +# mean = to_numpy(torch.mean(atom_energies)).item() +# rms = to_numpy(torch.sqrt(torch.mean(torch.square(forces)))).item() + +# avg_num_neighbors = torch.mean( +# torch.cat(num_neighbors, dim=0).type(torch.get_default_dtype()) +# ) + +# return to_numpy(avg_num_neighbors).item(), mean, rms + + +# def compute_rms_dipoles( +# data_loader: torch.utils.data.DataLoader, +# ) -> Tuple[float, float]: +# dipoles_list = [] +# for batch in data_loader: +# dipoles_list.append(batch.dipole) # {[n_graphs,3], } + +# dipoles = torch.cat(dipoles_list, dim=0) # {[total_n_graphs,3], } +# rms = to_numpy(torch.sqrt(torch.mean(torch.square(dipoles)))).item() +# rms = _check_non_zero(rms) +# return rms + + +# def compute_fixed_charge_dipole( +# charges: torch.Tensor, +# positions: torch.Tensor, +# batch: torch.Tensor, +# num_graphs: int, +# ) -> torch.Tensor: +# mu = positions * charges.unsqueeze(-1) / (1e-11 / c / e) # [N_atoms,3] +# return scatter_sum( +# src=mu, index=batch.unsqueeze(-1), dim=0, dim_size=num_graphs +# ) # [N_graphs,3] diff --git a/hydragnn/utils/mace_utils/tools/__init__.py b/hydragnn/utils/mace_utils/tools/__init__.py index 54c594550..ce89f5967 100644 --- a/hydragnn/utils/mace_utils/tools/__init__.py +++ b/hydragnn/utils/mace_utils/tools/__init__.py @@ -1,7 +1,7 @@ -from .arg_parser import build_default_arg_parser, build_preprocess_arg_parser -from .arg_parser_tools import check_args +# from .arg_parser import build_default_arg_parser, build_preprocess_arg_parser +# from .arg_parser_tools import check_args from .cg import U_matrix_real -from .checkpoint import CheckpointHandler, CheckpointIO, CheckpointState +# from .checkpoint import CheckpointHandler, CheckpointIO, CheckpointState from .finetuning_utils import load_foundations from .torch_tools import ( TensorDict, @@ -16,22 +16,22 @@ to_one_hot, voigt_to_matrix, ) -from .train import SWAContainer, evaluate, train -from .utils import ( - AtomicNumberTable, - MetricsLogger, - atomic_numbers_to_indices, - compute_c, - compute_mae, - compute_q95, - compute_rel_mae, - compute_rel_rmse, - compute_rmse, - get_atomic_number_table_from_zs, - get_optimizer, - get_tag, - setup_logger, -) +# from .train import SWAContainer, evaluate, train +# from .utils import ( +# AtomicNumberTable, +# MetricsLogger, +# atomic_numbers_to_indices, +# compute_c, +# compute_mae, +# compute_q95, +# compute_rel_mae, +# compute_rel_rmse, +# compute_rmse, +# get_atomic_number_table_from_zs, +# get_optimizer, +# get_tag, +# setup_logger, +# ) __all__ = [ "TensorDict", diff --git a/hydragnn/utils/mace_utils/tools/arg_parser.py b/hydragnn/utils/mace_utils/tools/arg_parser.py index 2b0e2b56e..73e1e9d24 100644 --- a/hydragnn/utils/mace_utils/tools/arg_parser.py +++ b/hydragnn/utils/mace_utils/tools/arg_parser.py @@ -1,792 +1,792 @@ -########################################################################################### -# Parsing functionalities -# Authors: Ilyes Batatia, Gregor Simm, David Kovacs -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### +# ########################################################################################### +# # Parsing functionalities +# # Authors: Ilyes Batatia, Gregor Simm, David Kovacs +# # This program is distributed under the MIT License (see MIT.md) +# 
###########################################################################################

-import argparse
-import os
-from typing import Optional

+# import argparse
+# import os
+# from typing import Optional


-def build_default_arg_parser() -> argparse.ArgumentParser:
-    try:
-        import configargparse

+# def build_default_arg_parser() -> argparse.ArgumentParser:
+#     try:
+#         import configargparse

-        parser = configargparse.ArgumentParser(
-            config_file_parser_class=configargparse.YAMLConfigFileParser,
-        )
-        parser.add(
-            "--config",
-            type=str,
-            is_config_file=True,
-            help="config file to agregate options",
-        )
-    except ImportError:
-        parser = argparse.ArgumentParser()

+#         parser = configargparse.ArgumentParser(
+#             config_file_parser_class=configargparse.YAMLConfigFileParser,
+#         )
+#         parser.add(
+#             "--config",
+#             type=str,
+#             is_config_file=True,
+#             help="config file to aggregate options",
+#         )
+#     except ImportError:
+#         parser = argparse.ArgumentParser()

-    # Name and seed
-    parser.add_argument("--name", help="experiment name", required=True)
-    parser.add_argument("--seed", help="random seed", type=int, default=123)

+#     # Name and seed
+#     parser.add_argument("--name", help="experiment name", required=True)
+#     parser.add_argument("--seed", help="random seed", type=int, default=123)

-    # Directories
-    parser.add_argument(
-        "--work_dir",
-        help="set directory for all files and folders",
-        type=str,
-        default=".",
-    )
-    parser.add_argument(
-        "--log_dir", help="directory for log files", type=str, default=None
-    )
-    parser.add_argument(
-        "--model_dir", help="directory for final model", type=str, default=None
-    )
-    parser.add_argument(
-        "--checkpoints_dir",
-        help="directory for checkpoint files",
-        type=str,
-        default=None,
-    )
-    parser.add_argument(
-        "--results_dir", help="directory for results", type=str, default=None
-    )
-    parser.add_argument(
-        "--downloads_dir", help="directory for downloads", type=str, default=None
-    )

+#     # Directories
+#     parser.add_argument(
+#         "--work_dir",
+#         help="set directory for all files and folders",
+#         type=str,
+#         default=".",
+#     )
+#     parser.add_argument(
+#         "--log_dir", help="directory for log files", type=str, default=None
+#     )
+#     parser.add_argument(
+#         "--model_dir", help="directory for final model", type=str, default=None
+#     )
+#     parser.add_argument(
+#         "--checkpoints_dir",
+#         help="directory for checkpoint files",
+#         type=str,
+#         default=None,
+#     )
+#     parser.add_argument(
+#         "--results_dir", help="directory for results", type=str, default=None
+#     )
+#     parser.add_argument(
+#         "--downloads_dir", help="directory for downloads", type=str, default=None
+#     )

-    # Device and logging
-    parser.add_argument(
-        "--device",
-        help="select device",
-        type=str,
-        choices=["cpu", "cuda", "mps"],
-        default="cpu",
-    )
-    parser.add_argument(
-        "--default_dtype",
-        help="set default dtype",
-        type=str,
-        choices=["float32", "float64"],
-        default="float64",
-    )
-    parser.add_argument(
-        "--distributed",
-        help="train in multi-GPU data parallel mode",
-        action="store_true",
-        default=False,
-    )
-    parser.add_argument("--log_level", help="log level", type=str, default="INFO")

+#     # Device and logging
+#     parser.add_argument(
+#         "--device",
+#         help="select device",
+#         type=str,
+#         choices=["cpu", "cuda", "mps"],
+#         default="cpu",
+#     )
+#     parser.add_argument(
+#         "--default_dtype",
+#         help="set default dtype",
+#         type=str,
+#         choices=["float32", "float64"],
+#         default="float64",
+#     )
+#     parser.add_argument(
+#         "--distributed",
+#         help="train in multi-GPU data parallel mode",
+#
action="store_true", +# default=False, +# ) +# parser.add_argument("--log_level", help="log level", type=str, default="INFO") - parser.add_argument( - "--error_table", - help="Type of error table produced at the end of the training", - type=str, - choices=[ - "PerAtomRMSE", - "TotalRMSE", - "PerAtomRMSEstressvirials", - "PerAtomMAEstressvirials", - "PerAtomMAE", - "TotalMAE", - "DipoleRMSE", - "DipoleMAE", - "EnergyDipoleRMSE", - ], - default="PerAtomRMSE", - ) +# parser.add_argument( +# "--error_table", +# help="Type of error table produced at the end of the training", +# type=str, +# choices=[ +# "PerAtomRMSE", +# "TotalRMSE", +# "PerAtomRMSEstressvirials", +# "PerAtomMAEstressvirials", +# "PerAtomMAE", +# "TotalMAE", +# "DipoleRMSE", +# "DipoleMAE", +# "EnergyDipoleRMSE", +# ], +# default="PerAtomRMSE", +# ) - # Model - parser.add_argument( - "--model", - help="model type", - default="MACE", - choices=[ - "BOTNet", - "MACE", - "ScaleShiftMACE", - "ScaleShiftBOTNet", - "AtomicDipolesMACE", - "EnergyDipolesMACE", - ], - ) - parser.add_argument( - "--r_max", help="distance cutoff (in Ang)", type=float, default=5.0 - ) - parser.add_argument( - "--radial_type", - help="type of radial basis functions", - type=str, - default="bessel", - choices=["bessel", "gaussian", "chebyshev"], - ) - parser.add_argument( - "--num_radial_basis", - help="number of radial basis functions", - type=int, - default=8, - ) - parser.add_argument( - "--num_cutoff_basis", - help="number of basis functions for smooth cutoff", - type=int, - default=5, - ) - parser.add_argument( - "--pair_repulsion", - help="use pair repulsion term with ZBL potential", - action="store_true", - default=False, - ) - parser.add_argument( - "--distance_transform", - help="use distance transform for radial basis functions", - default="None", - choices=["None", "Agnesi", "Soft"], - ) - parser.add_argument( - "--interaction", - help="name of interaction block", - type=str, - default="RealAgnosticResidualInteractionBlock", - choices=[ - "RealAgnosticResidualInteractionBlock", - "RealAgnosticAttResidualInteractionBlock", - "RealAgnosticInteractionBlock", - ], - ) - parser.add_argument( - "--interaction_first", - help="name of interaction block", - type=str, - default="RealAgnosticResidualInteractionBlock", - choices=[ - "RealAgnosticResidualInteractionBlock", - "RealAgnosticInteractionBlock", - ], - ) - parser.add_argument( - "--max_ell", help=r"highest \ell of spherical harmonics", type=int, default=3 - ) - parser.add_argument( - "--correlation", help="correlation order at each layer", type=int, default=3 - ) - parser.add_argument( - "--num_interactions", help="number of interactions", type=int, default=2 - ) - parser.add_argument( - "--MLP_irreps", - help="hidden irreps of the MLP in last readout", - type=str, - default="16x0e", - ) - parser.add_argument( - "--radial_MLP", - help="width of the radial MLP", - type=str, - default="[64, 64, 64]", - ) - parser.add_argument( - "--hidden_irreps", - help="irreps for hidden node states", - type=str, - default=None, - ) - # add option to specify irreps by channel number and max L - parser.add_argument( - "--num_channels", - help="number of embedding channels", - type=int, - default=None, - ) - parser.add_argument( - "--max_L", - help="max L equivariance of the message", - type=int, - default=None, - ) - parser.add_argument( - "--gate", - help="non linearity for last readout", - type=str, - default="silu", - choices=["silu", "tanh", "abs", "None"], - ) - parser.add_argument( - "--scaling", - help="type 
of scaling to the output", - type=str, - default="rms_forces_scaling", - choices=["std_scaling", "rms_forces_scaling", "no_scaling"], - ) - parser.add_argument( - "--avg_num_neighbors", - help="normalization factor for the message", - type=float, - default=1, - ) - parser.add_argument( - "--compute_avg_num_neighbors", - help="normalization factor for the message", - type=bool, - default=True, - ) - parser.add_argument( - "--compute_stress", - help="Select True to compute stress", - type=bool, - default=False, - ) - parser.add_argument( - "--compute_forces", - help="Select True to compute forces", - type=bool, - default=True, - ) +# # Model +# parser.add_argument( +# "--model", +# help="model type", +# default="MACE", +# choices=[ +# "BOTNet", +# "MACE", +# "ScaleShiftMACE", +# "ScaleShiftBOTNet", +# "AtomicDipolesMACE", +# "EnergyDipolesMACE", +# ], +# ) +# parser.add_argument( +# "--r_max", help="distance cutoff (in Ang)", type=float, default=5.0 +# ) +# parser.add_argument( +# "--radial_type", +# help="type of radial basis functions", +# type=str, +# default="bessel", +# choices=["bessel", "gaussian", "chebyshev"], +# ) +# parser.add_argument( +# "--num_radial_basis", +# help="number of radial basis functions", +# type=int, +# default=8, +# ) +# parser.add_argument( +# "--num_cutoff_basis", +# help="number of basis functions for smooth cutoff", +# type=int, +# default=5, +# ) +# parser.add_argument( +# "--pair_repulsion", +# help="use pair repulsion term with ZBL potential", +# action="store_true", +# default=False, +# ) +# parser.add_argument( +# "--distance_transform", +# help="use distance transform for radial basis functions", +# default="None", +# choices=["None", "Agnesi", "Soft"], +# ) +# parser.add_argument( +# "--interaction", +# help="name of interaction block", +# type=str, +# default="RealAgnosticResidualInteractionBlock", +# choices=[ +# "RealAgnosticResidualInteractionBlock", +# "RealAgnosticAttResidualInteractionBlock", +# "RealAgnosticInteractionBlock", +# ], +# ) +# parser.add_argument( +# "--interaction_first", +# help="name of interaction block", +# type=str, +# default="RealAgnosticResidualInteractionBlock", +# choices=[ +# "RealAgnosticResidualInteractionBlock", +# "RealAgnosticInteractionBlock", +# ], +# ) +# parser.add_argument( +# "--max_ell", help=r"highest \ell of spherical harmonics", type=int, default=3 +# ) +# parser.add_argument( +# "--correlation", help="correlation order at each layer", type=int, default=3 +# ) +# parser.add_argument( +# "--num_interactions", help="number of interactions", type=int, default=2 +# ) +# parser.add_argument( +# "--MLP_irreps", +# help="hidden irreps of the MLP in last readout", +# type=str, +# default="16x0e", +# ) +# parser.add_argument( +# "--radial_MLP", +# help="width of the radial MLP", +# type=str, +# default="[64, 64, 64]", +# ) +# parser.add_argument( +# "--hidden_irreps", +# help="irreps for hidden node states", +# type=str, +# default=None, +# ) +# # add option to specify irreps by channel number and max L +# parser.add_argument( +# "--num_channels", +# help="number of embedding channels", +# type=int, +# default=None, +# ) +# parser.add_argument( +# "--max_L", +# help="max L equivariance of the message", +# type=int, +# default=None, +# ) +# parser.add_argument( +# "--gate", +# help="non linearity for last readout", +# type=str, +# default="silu", +# choices=["silu", "tanh", "abs", "None"], +# ) +# parser.add_argument( +# "--scaling", +# help="type of scaling to the output", +# type=str, +# 
default="rms_forces_scaling", +# choices=["std_scaling", "rms_forces_scaling", "no_scaling"], +# ) +# parser.add_argument( +# "--avg_num_neighbors", +# help="normalization factor for the message", +# type=float, +# default=1, +# ) +# parser.add_argument( +# "--compute_avg_num_neighbors", +# help="normalization factor for the message", +# type=bool, +# default=True, +# ) +# parser.add_argument( +# "--compute_stress", +# help="Select True to compute stress", +# type=bool, +# default=False, +# ) +# parser.add_argument( +# "--compute_forces", +# help="Select True to compute forces", +# type=bool, +# default=True, +# ) - # Dataset - parser.add_argument( - "--train_file", - help="Training set file, format is .xyz or .h5", - type=str, - required=True, - ) - parser.add_argument( - "--valid_file", - help="Validation set .xyz or .h5 file", - default=None, - type=str, - required=False, - ) - parser.add_argument( - "--valid_fraction", - help="Fraction of training set used for validation", - type=float, - default=0.1, - required=False, - ) - parser.add_argument( - "--test_file", - help="Test set .xyz pt .h5 file", - type=str, - ) - parser.add_argument( - "--test_dir", - help="Path to directory with test files named as test_*.h5", - type=str, - default=None, - required=False, - ) - parser.add_argument( - "--multi_processed_test", - help="Boolean value for whether the test data was multiprocessed", - type=bool, - default=False, - required=False, - ) - parser.add_argument( - "--num_workers", - help="Number of workers for data loading", - type=int, - default=0, - ) - parser.add_argument( - "--pin_memory", - help="Pin memory for data loading", - default=True, - type=bool, - ) - parser.add_argument( - "--atomic_numbers", - help="List of atomic numbers", - type=str, - default=None, - required=False, - ) - parser.add_argument( - "--mean", - help="Mean energy per atom of training set", - type=float, - default=None, - required=False, - ) - parser.add_argument( - "--std", - help="Standard deviation of force components in the training set", - type=float, - default=None, - required=False, - ) - parser.add_argument( - "--statistics_file", - help="json file containing statistics of training set", - type=str, - default=None, - required=False, - ) - parser.add_argument( - "--E0s", - help="Dictionary of isolated atom energies", - type=str, - default=None, - required=False, - ) - parser.add_argument( - "--keep_isolated_atoms", - help="Keep isolated atoms in the dataset, useful for transfer learning", - type=bool, - default=False, - ) - parser.add_argument( - "--energy_key", - help="Key of reference energies in training xyz", - type=str, - default="REF_energy", - ) - parser.add_argument( - "--forces_key", - help="Key of reference forces in training xyz", - type=str, - default="REF_forces", - ) - parser.add_argument( - "--virials_key", - help="Key of reference virials in training xyz", - type=str, - default="REF_virials", - ) - parser.add_argument( - "--stress_key", - help="Key of reference stress in training xyz", - type=str, - default="REF_stress", - ) - parser.add_argument( - "--dipole_key", - help="Key of reference dipoles in training xyz", - type=str, - default="REF_dipole", - ) - parser.add_argument( - "--charges_key", - help="Key of atomic charges in training xyz", - type=str, - default="REF_charges", - ) +# # Dataset +# parser.add_argument( +# "--train_file", +# help="Training set file, format is .xyz or .h5", +# type=str, +# required=True, +# ) +# parser.add_argument( +# "--valid_file", +# help="Validation set 
.xyz or .h5 file",
+#         default=None,
+#         type=str,
+#         required=False,
+#     )
+#     parser.add_argument(
+#         "--valid_fraction",
+#         help="Fraction of training set used for validation",
+#         type=float,
+#         default=0.1,
+#         required=False,
+#     )
+#     parser.add_argument(
+#         "--test_file",
+#         help="Test set .xyz or .h5 file",
+#         type=str,
+#     )
+#     parser.add_argument(
+#         "--test_dir",
+#         help="Path to directory with test files named as test_*.h5",
+#         type=str,
+#         default=None,
+#         required=False,
+#     )
+#     parser.add_argument(
+#         "--multi_processed_test",
+#         help="Boolean value for whether the test data was multiprocessed",
+#         type=bool,
+#         default=False,
+#         required=False,
+#     )
+#     parser.add_argument(
+#         "--num_workers",
+#         help="Number of workers for data loading",
+#         type=int,
+#         default=0,
+#     )
+#     parser.add_argument(
+#         "--pin_memory",
+#         help="Pin memory for data loading",
+#         default=True,
+#         type=bool,
+#     )
+#     parser.add_argument(
+#         "--atomic_numbers",
+#         help="List of atomic numbers",
+#         type=str,
+#         default=None,
+#         required=False,
+#     )
+#     parser.add_argument(
+#         "--mean",
+#         help="Mean energy per atom of training set",
+#         type=float,
+#         default=None,
+#         required=False,
+#     )
+#     parser.add_argument(
+#         "--std",
+#         help="Standard deviation of force components in the training set",
+#         type=float,
+#         default=None,
+#         required=False,
+#     )
+#     parser.add_argument(
+#         "--statistics_file",
+#         help="json file containing statistics of training set",
+#         type=str,
+#         default=None,
+#         required=False,
+#     )
+#     parser.add_argument(
+#         "--E0s",
+#         help="Dictionary of isolated atom energies",
+#         type=str,
+#         default=None,
+#         required=False,
+#     )
+#     parser.add_argument(
+#         "--keep_isolated_atoms",
+#         help="Keep isolated atoms in the dataset, useful for transfer learning",
+#         type=bool,
+#         default=False,
+#     )
+#     parser.add_argument(
+#         "--energy_key",
+#         help="Key of reference energies in training xyz",
+#         type=str,
+#         default="REF_energy",
+#     )
+#     parser.add_argument(
+#         "--forces_key",
+#         help="Key of reference forces in training xyz",
+#         type=str,
+#         default="REF_forces",
+#     )
+#     parser.add_argument(
+#         "--virials_key",
+#         help="Key of reference virials in training xyz",
+#         type=str,
+#         default="REF_virials",
+#     )
+#     parser.add_argument(
+#         "--stress_key",
+#         help="Key of reference stress in training xyz",
+#         type=str,
+#         default="REF_stress",
+#     )
+#     parser.add_argument(
+#         "--dipole_key",
+#         help="Key of reference dipoles in training xyz",
+#         type=str,
+#         default="REF_dipole",
+#     )
+#     parser.add_argument(
+#         "--charges_key",
+#         help="Key of atomic charges in training xyz",
+#         type=str,
+#         default="REF_charges",
+#     )

-    # Loss and optimization
-    parser.add_argument(
-        "--loss",
-        help="type of loss",
-        default="weighted",
-        choices=[
-            "ef",
-            "weighted",
-            "forces_only",
-            "virials",
-            "stress",
-            "dipole",
-            "huber",
-            "universal",
-            "energy_forces_dipole",
-        ],
-    )
-    parser.add_argument(
-        "--forces_weight", help="weight of forces loss", type=float, default=100.0
-    )
-    parser.add_argument(
-        "--swa_forces_weight",
-        "--stage_two_forces_weight",
-        help="weight of forces loss after starting Stage Two (previously called swa)",
-        type=float,
-        default=100.0,
-        dest="swa_forces_weight",
-    )
-    parser.add_argument(
-        "--energy_weight", help="weight of energy loss", type=float, default=1.0
-    )
-    parser.add_argument(
-        "--swa_energy_weight",
-        "--stage_two_energy_weight",
-        help="weight of energy loss after starting Stage Two (previously called swa)",
-        type=float,
-        default=1000.0,
-        dest="swa_energy_weight",
-    )
-
parser.add_argument( - "--virials_weight", help="weight of virials loss", type=float, default=1.0 - ) - parser.add_argument( - "--swa_virials_weight", - "--stage_two_virials_weight", - help="weight of virials loss after starting Stage Two (previously called swa)", - type=float, - default=10.0, - dest="swa_virials_weight", - ) - parser.add_argument( - "--stress_weight", help="weight of virials loss", type=float, default=1.0 - ) - parser.add_argument( - "--swa_stress_weight", - "--stage_two_stress_weight", - help="weight of stress loss after starting Stage Two (previously called swa)", - type=float, - default=10.0, - dest="swa_stress_weight", - ) - parser.add_argument( - "--dipole_weight", help="weight of dipoles loss", type=float, default=1.0 - ) - parser.add_argument( - "--swa_dipole_weight", - "--stage_two_dipole_weight", - help="weight of dipoles after starting Stage Two (previously called swa)", - type=float, - default=1.0, - dest="swa_dipole_weight", - ) - parser.add_argument( - "--config_type_weights", - help="String of dictionary containing the weights for each config type", - type=str, - default='{"Default":1.0}', - ) - parser.add_argument( - "--huber_delta", - help="delta parameter for huber loss", - type=float, - default=0.01, - ) - parser.add_argument( - "--optimizer", - help="Optimizer for parameter optimization", - type=str, - default="adam", - choices=["adam", "adamw", "schedulefree"], - ) - parser.add_argument( - "--beta", - help="Beta parameter for the optimizer", - type=float, - default=0.9, - ) - parser.add_argument("--batch_size", help="batch size", type=int, default=10) - parser.add_argument( - "--valid_batch_size", help="Validation batch size", type=int, default=10 - ) - parser.add_argument( - "--lr", help="Learning rate of optimizer", type=float, default=0.01 - ) - parser.add_argument( - "--swa_lr", - "--stage_two_lr", - help="Learning rate of optimizer in Stage Two (previously called swa)", - type=float, - default=1e-3, - dest="swa_lr", - ) - parser.add_argument( - "--weight_decay", help="weight decay (L2 penalty)", type=float, default=5e-7 - ) - parser.add_argument( - "--amsgrad", - help="use amsgrad variant of optimizer", - action="store_true", - default=True, - ) - parser.add_argument( - "--scheduler", help="Type of scheduler", type=str, default="ReduceLROnPlateau" - ) - parser.add_argument( - "--lr_factor", help="Learning rate factor", type=float, default=0.8 - ) - parser.add_argument( - "--scheduler_patience", help="Learning rate factor", type=int, default=50 - ) - parser.add_argument( - "--lr_scheduler_gamma", - help="Gamma of learning rate scheduler", - type=float, - default=0.9993, - ) - parser.add_argument( - "--swa", - "--stage_two", - help="use Stage Two loss weight, which decreases the learning rate and increases the energy weight at the end of the training to help converge them", - action="store_true", - default=False, - dest="swa", - ) - parser.add_argument( - "--start_swa", - "--start_stage_two", - help="Number of epochs before changing to Stage Two loss weights", - type=int, - default=None, - dest="start_swa", - ) - parser.add_argument( - "--ema", - help="use Exponential Moving Average", - action="store_true", - default=False, - ) - parser.add_argument( - "--ema_decay", - help="Exponential Moving Average decay", - type=float, - default=0.99, - ) - parser.add_argument( - "--max_num_epochs", help="Maximum number of epochs", type=int, default=2048 - ) - parser.add_argument( - "--patience", - help="Maximum number of consecutive epochs of increasing loss", 
-        type=int,
-        default=2048,
-    )
-    parser.add_argument(
-        "--foundation_model",
-        help="Path to the foundation model for transfer learning",
-        type=str,
-        default=None,
-    )
-    parser.add_argument(
-        "--foundation_model_readout",
-        help="Use readout of foundation model for transfer learning",
-        action="store_false",
-        default=True,
-    )
-    parser.add_argument(
-        "--eval_interval", help="evaluate model every epochs", type=int, default=1
-    )
-    parser.add_argument(
-        "--keep_checkpoints",
-        help="keep all checkpoints",
-        action="store_true",
-        default=False,
-    )
-    parser.add_argument(
-        "--save_all_checkpoints",
-        help="save all checkpoints",
-        action="store_true",
-        default=False,
-    )
-    parser.add_argument(
-        "--restart_latest",
-        help="restart optimizer from latest checkpoint",
-        action="store_true",
-        default=False,
-    )
-    parser.add_argument(
-        "--save_cpu",
-        help="Save a model to be loaded on cpu",
-        action="store_true",
-        default=False,
-    )
-    parser.add_argument(
-        "--clip_grad",
-        help="Gradient Clipping Value",
-        type=check_float_or_none,
-        default=10.0,
-    )
-    # options for using Weights and Biases for experiment tracking
-    # to install see https://wandb.ai
-    parser.add_argument(
-        "--wandb",
-        help="Use Weights and Biases for experiment tracking",
-        action="store_true",
-        default=False,
-    )
-    parser.add_argument(
-        "--wandb_dir",
-        help="An absolute path to a directory where Weights and Biases metadata will be stored",
-        type=str,
-        default=None,
-    )
-    parser.add_argument(
-        "--wandb_project",
-        help="Weights and Biases project name",
-        type=str,
-        default="",
-    )
-    parser.add_argument(
-        "--wandb_entity",
-        help="Weights and Biases entity name",
-        type=str,
-        default="",
-    )
-    parser.add_argument(
-        "--wandb_name",
-        help="Weights and Biases experiment name",
-        type=str,
-        default="",
-    )
-    parser.add_argument(
-        "--wandb_log_hypers",
-        help="The hyperparameters to log in Weights and Biases",
-        type=list,
-        default=[
-            "num_channels",
-            "max_L",
-            "correlation",
-            "lr",
-            "swa_lr",
-            "weight_decay",
-            "batch_size",
-            "max_num_epochs",
-            "start_swa",
-            "energy_weight",
-            "forces_weight",
-        ],
-    )
-    return parser

+#     # Loss and optimization
+#     parser.add_argument(
+#         "--loss",
+#         help="type of loss",
+#         default="weighted",
+#         choices=[
+#             "ef",
+#             "weighted",
+#             "forces_only",
+#             "virials",
+#             "stress",
+#             "dipole",
+#             "huber",
+#             "universal",
+#             "energy_forces_dipole",
+#         ],
+#     )
+#     parser.add_argument(
+#         "--forces_weight", help="weight of forces loss", type=float, default=100.0
+#     )
+#     parser.add_argument(
+#         "--swa_forces_weight",
+#         "--stage_two_forces_weight",
+#         help="weight of forces loss after starting Stage Two (previously called swa)",
+#         type=float,
+#         default=100.0,
+#         dest="swa_forces_weight",
+#     )
+#     parser.add_argument(
+#         "--energy_weight", help="weight of energy loss", type=float, default=1.0
+#     )
+#     parser.add_argument(
+#         "--swa_energy_weight",
+#         "--stage_two_energy_weight",
+#         help="weight of energy loss after starting Stage Two (previously called swa)",
+#         type=float,
+#         default=1000.0,
+#         dest="swa_energy_weight",
+#     )
+#     parser.add_argument(
+#         "--virials_weight", help="weight of virials loss", type=float, default=1.0
+#     )
+#     parser.add_argument(
+#         "--swa_virials_weight",
+#         "--stage_two_virials_weight",
+#         help="weight of virials loss after starting Stage Two (previously called swa)",
+#         type=float,
+#         default=10.0,
+#         dest="swa_virials_weight",
+#     )
+#     parser.add_argument(
+#         "--stress_weight", help="weight of stress loss", type=float, default=1.0
+#     )
+#     parser.add_argument(
+#         "--swa_stress_weight",
+#         "--stage_two_stress_weight",
+#         help="weight of stress loss after starting Stage Two (previously called swa)",
+#         type=float,
+#         default=10.0,
+#         dest="swa_stress_weight",
+#     )
+#     parser.add_argument(
+#         "--dipole_weight", help="weight of dipoles loss", type=float, default=1.0
+#     )
+#     parser.add_argument(
+#         "--swa_dipole_weight",
+#         "--stage_two_dipole_weight",
+#         help="weight of dipoles after starting Stage Two (previously called swa)",
+#         type=float,
+#         default=1.0,
+#         dest="swa_dipole_weight",
+#     )
+#     parser.add_argument(
+#         "--config_type_weights",
+#         help="String of dictionary containing the weights for each config type",
+#         type=str,
+#         default='{"Default":1.0}',
+#     )
+#     parser.add_argument(
+#         "--huber_delta",
+#         help="delta parameter for huber loss",
+#         type=float,
+#         default=0.01,
+#     )
+#     parser.add_argument(
+#         "--optimizer",
+#         help="Optimizer for parameter optimization",
+#         type=str,
+#         default="adam",
+#         choices=["adam", "adamw", "schedulefree"],
+#     )
+#     parser.add_argument(
+#         "--beta",
+#         help="Beta parameter for the optimizer",
+#         type=float,
+#         default=0.9,
+#     )
+#     parser.add_argument("--batch_size", help="batch size", type=int, default=10)
+#     parser.add_argument(
+#         "--valid_batch_size", help="Validation batch size", type=int, default=10
+#     )
+#     parser.add_argument(
+#         "--lr", help="Learning rate of optimizer", type=float, default=0.01
+#     )
+#     parser.add_argument(
+#         "--swa_lr",
+#         "--stage_two_lr",
+#         help="Learning rate of optimizer in Stage Two (previously called swa)",
+#         type=float,
+#         default=1e-3,
+#         dest="swa_lr",
+#     )
+#     parser.add_argument(
+#         "--weight_decay", help="weight decay (L2 penalty)", type=float, default=5e-7
+#     )
+#     parser.add_argument(
+#         "--amsgrad",
+#         help="use amsgrad variant of optimizer",
+#         action="store_true",
+#         default=True,
+#     )
+#     parser.add_argument(
+#         "--scheduler", help="Type of scheduler", type=str, default="ReduceLROnPlateau"
+#     )
+#     parser.add_argument(
+#         "--lr_factor", help="Learning rate factor", type=float, default=0.8
+#     )
+#     parser.add_argument(
+#         "--scheduler_patience", help="Patience of learning rate scheduler", type=int, default=50
+#     )
+#     parser.add_argument(
+#         "--lr_scheduler_gamma",
+#         help="Gamma of learning rate scheduler",
+#         type=float,
+#         default=0.9993,
+#     )
+#     parser.add_argument(
+#         "--swa",
+#         "--stage_two",
+#         help="use Stage Two loss weight, which decreases the learning rate and increases the energy weight at the end of the training to help converge them",
+#         action="store_true",
+#         default=False,
+#         dest="swa",
+#     )
+#     parser.add_argument(
+#         "--start_swa",
+#         "--start_stage_two",
+#         help="Number of epochs before changing to Stage Two loss weights",
+#         type=int,
+#         default=None,
+#         dest="start_swa",
+#     )
+#     parser.add_argument(
+#         "--ema",
+#         help="use Exponential Moving Average",
+#         action="store_true",
+#         default=False,
+#     )
+#     parser.add_argument(
+#         "--ema_decay",
+#         help="Exponential Moving Average decay",
+#         type=float,
+#         default=0.99,
+#     )
+#     parser.add_argument(
+#         "--max_num_epochs", help="Maximum number of epochs", type=int, default=2048
+#     )
+#     parser.add_argument(
+#         "--patience",
+#         help="Maximum number of consecutive epochs of increasing loss",
+#         type=int,
+#         default=2048,
+#     )
+#     parser.add_argument(
+#         "--foundation_model",
+#         help="Path to the foundation model for transfer learning",
+#         type=str,
+#         default=None,
+#     )
+#     parser.add_argument(
+#         "--foundation_model_readout",
+#         help="Use readout of foundation model for transfer learning",
+#
action="store_false", +# default=True, +# ) +# parser.add_argument( +# "--eval_interval", help="evaluate model every epochs", type=int, default=1 +# ) +# parser.add_argument( +# "--keep_checkpoints", +# help="keep all checkpoints", +# action="store_true", +# default=False, +# ) +# parser.add_argument( +# "--save_all_checkpoints", +# help="save all checkpoints", +# action="store_true", +# default=False, +# ) +# parser.add_argument( +# "--restart_latest", +# help="restart optimizer from latest checkpoint", +# action="store_true", +# default=False, +# ) +# parser.add_argument( +# "--save_cpu", +# help="Save a model to be loaded on cpu", +# action="store_true", +# default=False, +# ) +# parser.add_argument( +# "--clip_grad", +# help="Gradient Clipping Value", +# type=check_float_or_none, +# default=10.0, +# ) +# # options for using Weights and Biases for experiment tracking +# # to install see https://wandb.ai +# parser.add_argument( +# "--wandb", +# help="Use Weights and Biases for experiment tracking", +# action="store_true", +# default=False, +# ) +# parser.add_argument( +# "--wandb_dir", +# help="An absolute path to a directory where Weights and Biases metadata will be stored", +# type=str, +# default=None, +# ) +# parser.add_argument( +# "--wandb_project", +# help="Weights and Biases project name", +# type=str, +# default="", +# ) +# parser.add_argument( +# "--wandb_entity", +# help="Weights and Biases entity name", +# type=str, +# default="", +# ) +# parser.add_argument( +# "--wandb_name", +# help="Weights and Biases experiment name", +# type=str, +# default="", +# ) +# parser.add_argument( +# "--wandb_log_hypers", +# help="The hyperparameters to log in Weights and Biases", +# type=list, +# default=[ +# "num_channels", +# "max_L", +# "correlation", +# "lr", +# "swa_lr", +# "weight_decay", +# "batch_size", +# "max_num_epochs", +# "start_swa", +# "energy_weight", +# "forces_weight", +# ], +# ) +# return parser -def build_preprocess_arg_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument( - "--train_file", - help="Training set h5 file", - type=str, - default=None, - required=True, - ) - parser.add_argument( - "--valid_file", - help="Training set xyz file", - type=str, - default=None, - required=False, - ) - parser.add_argument( - "--num_process", - help="The user defined number of processes to use, as well as the number of files created.", - type=int, - default=int(os.cpu_count() / 4), - ) - parser.add_argument( - "--valid_fraction", - help="Fraction of training set used for validation", - type=float, - default=0.1, - required=False, - ) - parser.add_argument( - "--test_file", - help="Test set xyz file", - type=str, - default=None, - required=False, - ) - parser.add_argument( - "--h5_prefix", - help="Prefix for h5 files when saving", - type=str, - default="", - ) - parser.add_argument( - "--r_max", help="distance cutoff (in Ang)", type=float, default=5.0 - ) - parser.add_argument( - "--config_type_weights", - help="String of dictionary containing the weights for each config type", - type=str, - default='{"Default":1.0}', - ) - parser.add_argument( - "--energy_key", - help="Key of reference energies in training xyz", - type=str, - default="REF_energy", - ) - parser.add_argument( - "--forces_key", - help="Key of reference forces in training xyz", - type=str, - default="REF_forces", - ) - parser.add_argument( - "--virials_key", - help="Key of reference virials in training xyz", - type=str, - default="REF_virials", - ) - parser.add_argument( - 
"--stress_key", - help="Key of reference stress in training xyz", - type=str, - default="REF_stress", - ) - parser.add_argument( - "--dipole_key", - help="Key of reference dipoles in training xyz", - type=str, - default="REF_dipole", - ) - parser.add_argument( - "--charges_key", - help="Key of atomic charges in training xyz", - type=str, - default="REF_charges", - ) - parser.add_argument( - "--atomic_numbers", - help="List of atomic numbers", - type=str, - default=None, - required=False, - ) - parser.add_argument( - "--compute_statistics", - help="Compute statistics for the dataset", - action="store_true", - default=False, - ) - parser.add_argument( - "--batch_size", - help="batch size to compute average number of neighbours", - type=int, - default=16, - ) +# def build_preprocess_arg_parser() -> argparse.ArgumentParser: +# parser = argparse.ArgumentParser() +# parser.add_argument( +# "--train_file", +# help="Training set h5 file", +# type=str, +# default=None, +# required=True, +# ) +# parser.add_argument( +# "--valid_file", +# help="Training set xyz file", +# type=str, +# default=None, +# required=False, +# ) +# parser.add_argument( +# "--num_process", +# help="The user defined number of processes to use, as well as the number of files created.", +# type=int, +# default=int(os.cpu_count() / 4), +# ) +# parser.add_argument( +# "--valid_fraction", +# help="Fraction of training set used for validation", +# type=float, +# default=0.1, +# required=False, +# ) +# parser.add_argument( +# "--test_file", +# help="Test set xyz file", +# type=str, +# default=None, +# required=False, +# ) +# parser.add_argument( +# "--h5_prefix", +# help="Prefix for h5 files when saving", +# type=str, +# default="", +# ) +# parser.add_argument( +# "--r_max", help="distance cutoff (in Ang)", type=float, default=5.0 +# ) +# parser.add_argument( +# "--config_type_weights", +# help="String of dictionary containing the weights for each config type", +# type=str, +# default='{"Default":1.0}', +# ) +# parser.add_argument( +# "--energy_key", +# help="Key of reference energies in training xyz", +# type=str, +# default="REF_energy", +# ) +# parser.add_argument( +# "--forces_key", +# help="Key of reference forces in training xyz", +# type=str, +# default="REF_forces", +# ) +# parser.add_argument( +# "--virials_key", +# help="Key of reference virials in training xyz", +# type=str, +# default="REF_virials", +# ) +# parser.add_argument( +# "--stress_key", +# help="Key of reference stress in training xyz", +# type=str, +# default="REF_stress", +# ) +# parser.add_argument( +# "--dipole_key", +# help="Key of reference dipoles in training xyz", +# type=str, +# default="REF_dipole", +# ) +# parser.add_argument( +# "--charges_key", +# help="Key of atomic charges in training xyz", +# type=str, +# default="REF_charges", +# ) +# parser.add_argument( +# "--atomic_numbers", +# help="List of atomic numbers", +# type=str, +# default=None, +# required=False, +# ) +# parser.add_argument( +# "--compute_statistics", +# help="Compute statistics for the dataset", +# action="store_true", +# default=False, +# ) +# parser.add_argument( +# "--batch_size", +# help="batch size to compute average number of neighbours", +# type=int, +# default=16, +# ) - parser.add_argument( - "--scaling", - help="type of scaling to the output", - type=str, - default="rms_forces_scaling", - choices=["std_scaling", "rms_forces_scaling", "no_scaling"], - ) - parser.add_argument( - "--E0s", - help="Dictionary of isolated atom energies", - type=str, - default=None, - 
required=False, - ) - parser.add_argument( - "--shuffle", - help="Shuffle the training dataset", - type=bool, - default=True, - ) - parser.add_argument( - "--seed", - help="Random seed for splitting training and validation sets", - type=int, - default=123, - ) - return parser +# parser.add_argument( +# "--scaling", +# help="type of scaling to the output", +# type=str, +# default="rms_forces_scaling", +# choices=["std_scaling", "rms_forces_scaling", "no_scaling"], +# ) +# parser.add_argument( +# "--E0s", +# help="Dictionary of isolated atom energies", +# type=str, +# default=None, +# required=False, +# ) +# parser.add_argument( +# "--shuffle", +# help="Shuffle the training dataset", +# type=bool, +# default=True, +# ) +# parser.add_argument( +# "--seed", +# help="Random seed for splitting training and validation sets", +# type=int, +# default=123, +# ) +# return parser -def check_float_or_none(value: str) -> Optional[float]: - try: - return float(value) - except ValueError: - if value != "None": - raise argparse.ArgumentTypeError( - f"{value} is an invalid value (float or None)" - ) from None - return None +# def check_float_or_none(value: str) -> Optional[float]: +# try: +# return float(value) +# except ValueError: +# if value != "None": +# raise argparse.ArgumentTypeError( +# f"{value} is an invalid value (float or None)" +# ) from None +# return None diff --git a/hydragnn/utils/mace_utils/tools/arg_parser_tools.py b/hydragnn/utils/mace_utils/tools/arg_parser_tools.py index da64806a3..dc76c1f43 100644 --- a/hydragnn/utils/mace_utils/tools/arg_parser_tools.py +++ b/hydragnn/utils/mace_utils/tools/arg_parser_tools.py @@ -1,113 +1,113 @@ -import logging -import os +# import logging +# import os -from e3nn import o3 +# from e3nn import o3 -def check_args(args): - """ - Check input arguments, update them if necessary for valid and consistent inputs, and return a tuple containing - the (potentially) modified args and a list of log messages. - """ - log_messages = [] +# def check_args(args): +# """ +# Check input arguments, update them if necessary for valid and consistent inputs, and return a tuple containing +# the (potentially) modified args and a list of log messages. 
+# """ +# log_messages = [] - # Directories - # Use work_dir for all other directories as well, unless they were specified by the user - if args.log_dir is None: - args.log_dir = os.path.join(args.work_dir, "logs") - if args.model_dir is None: - args.model_dir = args.work_dir - if args.checkpoints_dir is None: - args.checkpoints_dir = os.path.join(args.work_dir, "checkpoints") - if args.results_dir is None: - args.results_dir = os.path.join(args.work_dir, "results") - if args.downloads_dir is None: - args.downloads_dir = os.path.join(args.work_dir, "downloads") +# # Directories +# # Use work_dir for all other directories as well, unless they were specified by the user +# if args.log_dir is None: +# args.log_dir = os.path.join(args.work_dir, "logs") +# if args.model_dir is None: +# args.model_dir = args.work_dir +# if args.checkpoints_dir is None: +# args.checkpoints_dir = os.path.join(args.work_dir, "checkpoints") +# if args.results_dir is None: +# args.results_dir = os.path.join(args.work_dir, "results") +# if args.downloads_dir is None: +# args.downloads_dir = os.path.join(args.work_dir, "downloads") - # Model - # Check if hidden_irreps, num_channels and max_L are consistent - if args.hidden_irreps is None and args.num_channels is None and args.max_L is None: - args.hidden_irreps, args.num_channels, args.max_L = "128x0e + 128x1o", 128, 1 - elif ( - args.hidden_irreps is not None - and args.num_channels is not None - and args.max_L is not None - ): - args.hidden_irreps = o3.Irreps( - (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) - .sort() - .irreps.simplify() - ) - log_messages.append( - ( - "All of hidden_irreps, num_channels and max_L are specified", - logging.WARNING, - ) - ) - log_messages.append( - ( - f"Using num_channels and max_L to create hidden_irreps: {args.hidden_irreps}.", - logging.WARNING, - ) - ) - assert ( - len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 - ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" - elif args.num_channels is not None and args.max_L is not None: - assert args.num_channels > 0, "num_channels must be positive integer" - assert args.max_L >= 0, "max_L must be non-negative integer" - args.hidden_irreps = o3.Irreps( - (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) - .sort() - .irreps.simplify() - ) - assert ( - len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 - ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" - elif args.hidden_irreps is not None: - assert ( - len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 - ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" +# # Model +# # Check if hidden_irreps, num_channels and max_L are consistent +# if args.hidden_irreps is None and args.num_channels is None and args.max_L is None: +# args.hidden_irreps, args.num_channels, args.max_L = "128x0e + 128x1o", 128, 1 +# elif ( +# args.hidden_irreps is not None +# and args.num_channels is not None +# and args.max_L is not None +# ): +# args.hidden_irreps = o3.Irreps( +# (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) +# .sort() +# .irreps.simplify() +# ) +# log_messages.append( +# ( +# "All of hidden_irreps, num_channels and max_L are specified", +# logging.WARNING, +# ) +# ) +# log_messages.append( +# ( 
+# f"Using num_channels and max_L to create hidden_irreps: {args.hidden_irreps}.", +# logging.WARNING, +# ) +# ) +# assert ( +# len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 +# ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" +# elif args.num_channels is not None and args.max_L is not None: +# assert args.num_channels > 0, "num_channels must be positive integer" +# assert args.max_L >= 0, "max_L must be non-negative integer" +# args.hidden_irreps = o3.Irreps( +# (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) +# .sort() +# .irreps.simplify() +# ) +# assert ( +# len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 +# ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" +# elif args.hidden_irreps is not None: +# assert ( +# len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 +# ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" - args.num_channels = list( - {irrep.mul for irrep in o3.Irreps(args.hidden_irreps)} - )[0] - args.max_L = o3.Irreps(args.hidden_irreps).lmax - elif args.max_L is not None and args.num_channels is None: - assert args.max_L >= 0, "max_L must be non-negative integer" - args.num_channels = 128 - args.hidden_irreps = o3.Irreps( - (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) - .sort() - .irreps.simplify() - ) - elif args.max_L is None and args.num_channels is not None: - assert args.num_channels > 0, "num_channels must be positive integer" - args.max_L = 1 - args.hidden_irreps = o3.Irreps( - (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) - .sort() - .irreps.simplify() - ) +# args.num_channels = list( +# {irrep.mul for irrep in o3.Irreps(args.hidden_irreps)} +# )[0] +# args.max_L = o3.Irreps(args.hidden_irreps).lmax +# elif args.max_L is not None and args.num_channels is None: +# assert args.max_L >= 0, "max_L must be non-negative integer" +# args.num_channels = 128 +# args.hidden_irreps = o3.Irreps( +# (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) +# .sort() +# .irreps.simplify() +# ) +# elif args.max_L is None and args.num_channels is not None: +# assert args.num_channels > 0, "num_channels must be positive integer" +# args.max_L = 1 +# args.hidden_irreps = o3.Irreps( +# (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) +# .sort() +# .irreps.simplify() +# ) - # Loss and optimization - # Check Stage Two loss start - if args.swa: - if args.start_swa is None: - args.start_swa = max(1, args.max_num_epochs // 4 * 3) - if args.start_swa > args.max_num_epochs: - log_messages.append( - ( - f"start_stage_two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}", - logging.WARNING, - ) - ) - log_messages.append( - ( - "Stage Two will not start, as start_stage_two > max_num_epochs", - logging.WARNING, - ) - ) - args.swa = False +# # Loss and optimization +# # Check Stage Two loss start +# if args.swa: +# if args.start_swa is None: +# args.start_swa = max(1, args.max_num_epochs // 4 * 3) +# if args.start_swa > args.max_num_epochs: +# log_messages.append( +# ( +# f"start_stage_two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}", +# logging.WARNING, +# ) +# ) +# log_messages.append( +# ( +# "Stage Two will not start, as start_stage_two > 
max_num_epochs", +# logging.WARNING, +# ) +# ) +# args.swa = False - return args, log_messages +# return args, log_messages diff --git a/hydragnn/utils/mace_utils/tools/checkpoint.py b/hydragnn/utils/mace_utils/tools/checkpoint.py index 8a62f1f27..c1f2f690e 100644 --- a/hydragnn/utils/mace_utils/tools/checkpoint.py +++ b/hydragnn/utils/mace_utils/tools/checkpoint.py @@ -1,227 +1,227 @@ -########################################################################################### -# Checkpointing -# Authors: Gregor Simm -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import dataclasses -import logging -import os -import re -from typing import Dict, List, Optional, Tuple - -import torch - -from .torch_tools import TensorDict - -Checkpoint = Dict[str, TensorDict] - - -@dataclasses.dataclass -class CheckpointState: - model: torch.nn.Module - optimizer: torch.optim.Optimizer - lr_scheduler: torch.optim.lr_scheduler.ExponentialLR - - -class CheckpointBuilder: - @staticmethod - def create_checkpoint(state: CheckpointState) -> Checkpoint: - return { - "model": state.model.state_dict(), - "optimizer": state.optimizer.state_dict(), - "lr_scheduler": state.lr_scheduler.state_dict(), - } - - @staticmethod - def load_checkpoint( - state: CheckpointState, checkpoint: Checkpoint, strict: bool - ) -> None: - state.model.load_state_dict(checkpoint["model"], strict=strict) # type: ignore - state.optimizer.load_state_dict(checkpoint["optimizer"]) - state.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) - - -@dataclasses.dataclass -class CheckpointPathInfo: - path: str - tag: str - epochs: int - swa: bool - - -class CheckpointIO: - def __init__( - self, directory: str, tag: str, keep: bool = False, swa_start: int = None - ) -> None: - self.directory = directory - self.tag = tag - self.keep = keep - self.old_path: Optional[str] = None - self.swa_start = swa_start - - self._epochs_string = "_epoch-" - self._filename_extension = "pt" - - def _get_checkpoint_filename(self, epochs: int, swa_start=None) -> str: - if swa_start is not None and epochs > swa_start: - return ( - self.tag - + self._epochs_string - + str(epochs) - + "_swa" - + "." - + self._filename_extension - ) - return ( - self.tag - + self._epochs_string - + str(epochs) - + "." 
-            + self._filename_extension
-        )
-
-    def _list_file_paths(self) -> List[str]:
-        if not os.path.isdir(self.directory):
-            return []
-        all_paths = [
-            os.path.join(self.directory, f) for f in os.listdir(self.directory)
-        ]
-        return [path for path in all_paths if os.path.isfile(path)]
-
-    def _parse_checkpoint_path(self, path: str) -> Optional[CheckpointPathInfo]:
-        filename = os.path.basename(path)
-        regex = re.compile(
-            rf"^(?P<tag>.+){self._epochs_string}(?P<epochs>\d+)\.{self._filename_extension}$"
-        )
-        regex2 = re.compile(
-            rf"^(?P<tag>.+){self._epochs_string}(?P<epochs>\d+)_swa\.{self._filename_extension}$"
-        )
-        match = regex.match(filename)
-        match2 = regex2.match(filename)
-        swa = False
-        if not match:
-            if not match2:
-                return None
-            match = match2
-            swa = True
-
-        return CheckpointPathInfo(
-            path=path,
-            tag=match.group("tag"),
-            epochs=int(match.group("epochs")),
-            swa=swa,
-        )
-
-    def _get_latest_checkpoint_path(self, swa) -> Optional[str]:
-        all_file_paths = self._list_file_paths()
-        checkpoint_info_list = [
-            self._parse_checkpoint_path(path) for path in all_file_paths
-        ]
-        selected_checkpoint_info_list = [
-            info for info in checkpoint_info_list if info and info.tag == self.tag
-        ]
-
-        if len(selected_checkpoint_info_list) == 0:
-            logging.warning(
-                f"Cannot find checkpoint with tag '{self.tag}' in '{self.directory}'"
-            )
-            return None
-
-        selected_checkpoint_info_list_swa = []
-        selected_checkpoint_info_list_no_swa = []
-
-        for ckp in selected_checkpoint_info_list:
-            if ckp.swa:
-                selected_checkpoint_info_list_swa.append(ckp)
-            else:
-                selected_checkpoint_info_list_no_swa.append(ckp)
-        if swa:
-            try:
-                latest_checkpoint_info = max(
-                    selected_checkpoint_info_list_swa, key=lambda info: info.epochs
-                )
-            except ValueError:
-                logging.warning(
-                    "No SWA checkpoint found, while SWA is enabled. Compare the swa_start parameter and the latest checkpoint."
- ) - else: - latest_checkpoint_info = max( - selected_checkpoint_info_list_no_swa, key=lambda info: info.epochs - ) - return latest_checkpoint_info.path - - def save( - self, checkpoint: Checkpoint, epochs: int, keep_last: bool = False - ) -> None: - if not self.keep and self.old_path and not keep_last: - logging.debug(f"Deleting old checkpoint file: {self.old_path}") - os.remove(self.old_path) - - filename = self._get_checkpoint_filename(epochs, self.swa_start) - path = os.path.join(self.directory, filename) - logging.debug(f"Saving checkpoint: {path}") - os.makedirs(self.directory, exist_ok=True) - torch.save(obj=checkpoint, f=path) - self.old_path = path - - def load_latest( - self, swa: Optional[bool] = False, device: Optional[torch.device] = None - ) -> Optional[Tuple[Checkpoint, int]]: - path = self._get_latest_checkpoint_path(swa=swa) - if path is None: - return None - - return self.load(path, device=device) - - def load( - self, path: str, device: Optional[torch.device] = None - ) -> Tuple[Checkpoint, int]: - checkpoint_info = self._parse_checkpoint_path(path) - - if checkpoint_info is None: - raise RuntimeError(f"Cannot find path '{path}'") - - logging.info(f"Loading checkpoint: {checkpoint_info.path}") - return ( - torch.load(f=checkpoint_info.path, map_location=device), - checkpoint_info.epochs, - ) - - -class CheckpointHandler: - def __init__(self, *args, **kwargs) -> None: - self.io = CheckpointIO(*args, **kwargs) - self.builder = CheckpointBuilder() - - def save( - self, state: CheckpointState, epochs: int, keep_last: bool = False - ) -> None: - checkpoint = self.builder.create_checkpoint(state) - self.io.save(checkpoint, epochs, keep_last) - - def load_latest( - self, - state: CheckpointState, - swa: Optional[bool] = False, - device: Optional[torch.device] = None, - strict=False, - ) -> Optional[int]: - result = self.io.load_latest(swa=swa, device=device) - if result is None: - return None - - checkpoint, epochs = result - self.builder.load_checkpoint(state=state, checkpoint=checkpoint, strict=strict) - return epochs - - def load( - self, - state: CheckpointState, - path: str, - strict=False, - device: Optional[torch.device] = None, - ) -> int: - checkpoint, epochs = self.io.load(path, device=device) - self.builder.load_checkpoint(state=state, checkpoint=checkpoint, strict=strict) - return epochs +# ########################################################################################### +# # Checkpointing +# # Authors: Gregor Simm +# # This program is distributed under the MIT License (see MIT.md) +# ########################################################################################### + +# import dataclasses +# import logging +# import os +# import re +# from typing import Dict, List, Optional, Tuple + +# import torch + +# from .torch_tools import TensorDict + +# Checkpoint = Dict[str, TensorDict] + + +# @dataclasses.dataclass +# class CheckpointState: +# model: torch.nn.Module +# optimizer: torch.optim.Optimizer +# lr_scheduler: torch.optim.lr_scheduler.ExponentialLR + + +# class CheckpointBuilder: +# @staticmethod +# def create_checkpoint(state: CheckpointState) -> Checkpoint: +# return { +# "model": state.model.state_dict(), +# "optimizer": state.optimizer.state_dict(), +# "lr_scheduler": state.lr_scheduler.state_dict(), +# } + +# @staticmethod +# def load_checkpoint( +# state: CheckpointState, checkpoint: Checkpoint, strict: bool +# ) -> None: +# state.model.load_state_dict(checkpoint["model"], strict=strict) # type: ignore +# 
state.optimizer.load_state_dict(checkpoint["optimizer"])
+#         state.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
+
+
+# @dataclasses.dataclass
+# class CheckpointPathInfo:
+#     path: str
+#     tag: str
+#     epochs: int
+#     swa: bool
+
+
+# class CheckpointIO:
+#     def __init__(
+#         self, directory: str, tag: str, keep: bool = False, swa_start: int = None
+#     ) -> None:
+#         self.directory = directory
+#         self.tag = tag
+#         self.keep = keep
+#         self.old_path: Optional[str] = None
+#         self.swa_start = swa_start
+
+#         self._epochs_string = "_epoch-"
+#         self._filename_extension = "pt"
+
+#     def _get_checkpoint_filename(self, epochs: int, swa_start=None) -> str:
+#         if swa_start is not None and epochs > swa_start:
+#             return (
+#                 self.tag
+#                 + self._epochs_string
+#                 + str(epochs)
+#                 + "_swa"
+#                 + "."
+#                 + self._filename_extension
+#             )
+#         return (
+#             self.tag
+#             + self._epochs_string
+#             + str(epochs)
+#             + "."
+#             + self._filename_extension
+#         )
+
+#     def _list_file_paths(self) -> List[str]:
+#         if not os.path.isdir(self.directory):
+#             return []
+#         all_paths = [
+#             os.path.join(self.directory, f) for f in os.listdir(self.directory)
+#         ]
+#         return [path for path in all_paths if os.path.isfile(path)]
+
+#     def _parse_checkpoint_path(self, path: str) -> Optional[CheckpointPathInfo]:
+#         filename = os.path.basename(path)
+#         regex = re.compile(
+#             rf"^(?P<tag>.+){self._epochs_string}(?P<epochs>\d+)\.{self._filename_extension}$"
+#         )
+#         regex2 = re.compile(
+#             rf"^(?P<tag>.+){self._epochs_string}(?P<epochs>\d+)_swa\.{self._filename_extension}$"
+#         )
+#         match = regex.match(filename)
+#         match2 = regex2.match(filename)
+#         swa = False
+#         if not match:
+#             if not match2:
+#                 return None
+#             match = match2
+#             swa = True
+
+#         return CheckpointPathInfo(
+#             path=path,
+#             tag=match.group("tag"),
+#             epochs=int(match.group("epochs")),
+#             swa=swa,
+#         )
+
+#     def _get_latest_checkpoint_path(self, swa) -> Optional[str]:
+#         all_file_paths = self._list_file_paths()
+#         checkpoint_info_list = [
+#             self._parse_checkpoint_path(path) for path in all_file_paths
+#         ]
+#         selected_checkpoint_info_list = [
+#             info for info in checkpoint_info_list if info and info.tag == self.tag
+#         ]
+
+#         if len(selected_checkpoint_info_list) == 0:
+#             logging.warning(
+#                 f"Cannot find checkpoint with tag '{self.tag}' in '{self.directory}'"
+#             )
+#             return None
+
+#         selected_checkpoint_info_list_swa = []
+#         selected_checkpoint_info_list_no_swa = []
+
+#         for ckp in selected_checkpoint_info_list:
+#             if ckp.swa:
+#                 selected_checkpoint_info_list_swa.append(ckp)
+#             else:
+#                 selected_checkpoint_info_list_no_swa.append(ckp)
+#         if swa:
+#             try:
+#                 latest_checkpoint_info = max(
+#                     selected_checkpoint_info_list_swa, key=lambda info: info.epochs
+#                 )
+#             except ValueError:
+#                 logging.warning(
+#                     "No SWA checkpoint found, while SWA is enabled. Compare the swa_start parameter and the latest checkpoint."
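+#                 # Checkpoint filenames follow tag + "_epoch-" + <epochs>
+#                 # [+ "_swa"] + ".pt"; e.g. a hypothetical
+#                 # "mace_run_epoch-10_swa.pt" parses to
+#                 # CheckpointPathInfo(tag="mace_run", epochs=10, swa=True).
+#                 # Note: if no "_swa" checkpoint exists, only this warning is
+#                 # logged and `latest_checkpoint_info` stays unbound, so the
+#                 # return below would raise UnboundLocalError if this path
+#                 # were hit once the code is re-enabled.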
+# ) +# else: +# latest_checkpoint_info = max( +# selected_checkpoint_info_list_no_swa, key=lambda info: info.epochs +# ) +# return latest_checkpoint_info.path + +# def save( +# self, checkpoint: Checkpoint, epochs: int, keep_last: bool = False +# ) -> None: +# if not self.keep and self.old_path and not keep_last: +# logging.debug(f"Deleting old checkpoint file: {self.old_path}") +# os.remove(self.old_path) + +# filename = self._get_checkpoint_filename(epochs, self.swa_start) +# path = os.path.join(self.directory, filename) +# logging.debug(f"Saving checkpoint: {path}") +# os.makedirs(self.directory, exist_ok=True) +# torch.save(obj=checkpoint, f=path) +# self.old_path = path + +# def load_latest( +# self, swa: Optional[bool] = False, device: Optional[torch.device] = None +# ) -> Optional[Tuple[Checkpoint, int]]: +# path = self._get_latest_checkpoint_path(swa=swa) +# if path is None: +# return None + +# return self.load(path, device=device) + +# def load( +# self, path: str, device: Optional[torch.device] = None +# ) -> Tuple[Checkpoint, int]: +# checkpoint_info = self._parse_checkpoint_path(path) + +# if checkpoint_info is None: +# raise RuntimeError(f"Cannot find path '{path}'") + +# logging.info(f"Loading checkpoint: {checkpoint_info.path}") +# return ( +# torch.load(f=checkpoint_info.path, map_location=device), +# checkpoint_info.epochs, +# ) + + +# class CheckpointHandler: +# def __init__(self, *args, **kwargs) -> None: +# self.io = CheckpointIO(*args, **kwargs) +# self.builder = CheckpointBuilder() + +# def save( +# self, state: CheckpointState, epochs: int, keep_last: bool = False +# ) -> None: +# checkpoint = self.builder.create_checkpoint(state) +# self.io.save(checkpoint, epochs, keep_last) + +# def load_latest( +# self, +# state: CheckpointState, +# swa: Optional[bool] = False, +# device: Optional[torch.device] = None, +# strict=False, +# ) -> Optional[int]: +# result = self.io.load_latest(swa=swa, device=device) +# if result is None: +# return None + +# checkpoint, epochs = result +# self.builder.load_checkpoint(state=state, checkpoint=checkpoint, strict=strict) +# return epochs + +# def load( +# self, +# state: CheckpointState, +# path: str, +# strict=False, +# device: Optional[torch.device] = None, +# ) -> int: +# checkpoint, epochs = self.io.load(path, device=device) +# self.builder.load_checkpoint(state=state, checkpoint=checkpoint, strict=strict) +# return epochs diff --git a/hydragnn/utils/mace_utils/tools/scripts_utils.py b/hydragnn/utils/mace_utils/tools/scripts_utils.py index bca80bf64..15dd155d1 100644 --- a/hydragnn/utils/mace_utils/tools/scripts_utils.py +++ b/hydragnn/utils/mace_utils/tools/scripts_utils.py @@ -1,653 +1,653 @@ -########################################################################################### -# Training utils -# Authors: David Kovacs, Ilyes Batatia -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import ast -import dataclasses -import json -import logging -import os -from typing import Any, Dict, List, Optional, Tuple - -import numpy as np -import torch -import torch.distributed -from e3nn import o3 -from prettytable import PrettyTable - -from mace import data, modules -from mace.tools import evaluate - - -@dataclasses.dataclass -class SubsetCollection: - train: data.Configurations - valid: data.Configurations - tests: List[Tuple[str, data.Configurations]] - - -def get_dataset_from_xyz( - work_dir: str, - train_path: str, - 
valid_path: str, - valid_fraction: float, - config_type_weights: Dict, - test_path: str = None, - seed: int = 1234, - keep_isolated_atoms: bool = False, - energy_key: str = "REF_energy", - forces_key: str = "REF_forces", - stress_key: str = "REF_stress", - virials_key: str = "virials", - dipole_key: str = "dipoles", - charges_key: str = "charges", -) -> Tuple[SubsetCollection, Optional[Dict[int, float]]]: - """Load training and test dataset from xyz file""" - atomic_energies_dict, all_train_configs = data.load_from_xyz( - file_path=train_path, - config_type_weights=config_type_weights, - energy_key=energy_key, - forces_key=forces_key, - stress_key=stress_key, - virials_key=virials_key, - dipole_key=dipole_key, - charges_key=charges_key, - extract_atomic_energies=True, - keep_isolated_atoms=keep_isolated_atoms, - ) - logging.info( - f"Training set [{len(all_train_configs)} configs, {np.sum([1 if config.energy else 0 for config in all_train_configs])} energy, {np.sum([config.forces.size for config in all_train_configs])} forces] loaded from '{train_path}'" - ) - if valid_path is not None: - _, valid_configs = data.load_from_xyz( - file_path=valid_path, - config_type_weights=config_type_weights, - energy_key=energy_key, - forces_key=forces_key, - stress_key=stress_key, - virials_key=virials_key, - dipole_key=dipole_key, - charges_key=charges_key, - extract_atomic_energies=False, - ) - logging.info( - f"Validation set [{len(valid_configs)} configs, {np.sum([1 if config.energy else 0 for config in valid_configs])} energy, {np.sum([config.forces.size for config in valid_configs])} forces] loaded from '{valid_path}'" - ) - train_configs = all_train_configs - else: - train_configs, valid_configs = data.random_train_valid_split( - all_train_configs, valid_fraction, seed, work_dir - ) - logging.info( - f"Validaton set contains {len(valid_configs)} configurations [{np.sum([1 if config.energy else 0 for config in valid_configs])} energy, {np.sum([config.forces.size for config in valid_configs])} forces]" - ) - - test_configs = [] - if test_path is not None: - _, all_test_configs = data.load_from_xyz( - file_path=test_path, - config_type_weights=config_type_weights, - energy_key=energy_key, - forces_key=forces_key, - dipole_key=dipole_key, - stress_key=stress_key, - virials_key=virials_key, - charges_key=charges_key, - extract_atomic_energies=False, - ) - # create list of tuples (config_type, list(Atoms)) - test_configs = data.test_config_types(all_test_configs) - logging.info( - f"Test set ({len(all_test_configs)} configs) loaded from '{test_path}':" - ) - for name, tmp_configs in test_configs: - logging.info( - f"{name}: {len(tmp_configs)} configs, {np.sum([1 if config.energy else 0 for config in tmp_configs])} energy, {np.sum([config.forces.size for config in tmp_configs])} forces" - ) - - return ( - SubsetCollection(train=train_configs, valid=valid_configs, tests=test_configs), - atomic_energies_dict, - ) - - -def get_config_type_weights(ct_weights): - """ - Get config type weights from command line argument - """ - try: - config_type_weights = ast.literal_eval(ct_weights) - assert isinstance(config_type_weights, dict) - except Exception as e: # pylint: disable=W0703 - logging.warning( - f"Config type weights not specified correctly ({e}), using Default" - ) - config_type_weights = {"Default": 1.0} - return config_type_weights - - -def print_git_commit(): - try: - import git - - repo = git.Repo(search_parent_directories=True) - commit = repo.head.commit.hexsha - logging.debug(f"Current Git commit: 
{commit}")
-        return commit
-    except Exception as e:  # pylint: disable=W0703
-        logging.debug(f"Error accessing Git repository: {e}")
-        return "None"
-
-
-def extract_config_mace_model(model: torch.nn.Module) -> Dict[str, Any]:
-    if model.__class__.__name__ != "ScaleShiftMACE":
-        return {"error": "Model is not a ScaleShiftMACE model"}
-
-    def radial_to_name(radial_type):
-        if radial_type == "BesselBasis":
-            return "bessel"
-        if radial_type == "GaussianBasis":
-            return "gaussian"
-        if radial_type == "ChebychevBasis":
-            return "chebyshev"
-        return radial_type
-
-    def radial_to_transform(radial):
-        if not hasattr(radial, "distance_transform"):
-            return None
-        if radial.distance_transform.__class__.__name__ == "AgnesiTransform":
-            return "Agnesi"
-        if radial.distance_transform.__class__.__name__ == "SoftTransform":
-            return "Soft"
-        return radial.distance_transform.__class__.__name__
-
-    config = {
-        "r_max": model.r_max.item(),
-        "num_bessel": len(model.radial_embedding.bessel_fn.bessel_weights),
-        "num_polynomial_cutoff": model.radial_embedding.cutoff_fn.p.item(),
-        "max_ell": model.spherical_harmonics._lmax,  # pylint: disable=protected-access
-        "interaction_cls": model.interactions[-1].__class__,
-        "interaction_cls_first": model.interactions[0].__class__,
-        "num_interactions": model.num_interactions.item(),
-        "num_elements": len(model.atomic_numbers),
-        "hidden_irreps": o3.Irreps(str(model.products[0].linear.irreps_out)),
-        "MLP_irreps": (
-            o3.Irreps(str(model.readouts[-1].hidden_irreps))
-            if model.num_interactions.item() > 1
-            else 1
-        ),
-        "gate": (
-            model.readouts[-1]  # pylint: disable=protected-access
-            .non_linearity._modules["acts"][0]
-            .f
-            if model.num_interactions.item() > 1
-            else None
-        ),
-        "atomic_energies": model.atomic_energies_fn.atomic_energies.cpu().numpy(),
-        "avg_num_neighbors": model.interactions[0].avg_num_neighbors,
-        "atomic_numbers": model.atomic_numbers,
-        "correlation": len(
-            model.products[0].symmetric_contractions.contractions[0].weights
-        )
-        + 1,
-        "radial_type": radial_to_name(
-            model.radial_embedding.bessel_fn.__class__.__name__
-        ),
-        "radial_MLP": model.interactions[0].conv_tp_weights.hs[1:-1],
-        "pair_repulsion": hasattr(model, "pair_repulsion_fn"),
-        "distance_transform": radial_to_transform(model.radial_embedding),
-        "atomic_inter_scale": model.scale_shift.scale.item(),
-        "atomic_inter_shift": model.scale_shift.shift.item(),
-    }
-    return config
-
-
-def extract_load(f: str, map_location: str = "cpu") -> torch.nn.Module:
-    model = torch.load(f=f, map_location=map_location)
-    model_copy = model.__class__(**extract_config_mace_model(model))
-    model_copy.load_state_dict(model.state_dict())
-    return model_copy.to(map_location)
-
-
-def extract_model(model: torch.nn.Module, map_location: str = "cpu") -> torch.nn.Module:
-    model_copy = model.__class__(**extract_config_mace_model(model))
-    model_copy.load_state_dict(model.state_dict())
-    return model_copy.to(map_location)
-
-
-def convert_to_json_format(dict_input):
-    for key, value in dict_input.items():
-        if isinstance(value, (np.ndarray, torch.Tensor)):
-            dict_input[key] = value.tolist()
-        # # check if the value is a class and convert it to a string
-        elif hasattr(value, "__class__"):
-            dict_input[key] = str(value)
-    return dict_input
-
-
-def convert_from_json_format(dict_input):
-    dict_output = dict_input.copy()
-    if (
-        dict_input["interaction_cls"]
-        == "<class 'mace.modules.blocks.RealAgnosticResidualInteractionBlock'>"
-    ):
-        dict_output[
-            "interaction_cls"
-        ] = modules.blocks.RealAgnosticResidualInteractionBlock
-    if (
-        dict_input["interaction_cls"]
-        == "<class 'mace.modules.blocks.RealAgnosticInteractionBlock'>"
-    ):
-        dict_output["interaction_cls"] = modules.blocks.RealAgnosticInteractionBlock
-    if (
-        dict_input["interaction_cls_first"]
-        == "<class 'mace.modules.blocks.RealAgnosticResidualInteractionBlock'>"
-    ):
-        dict_output[
-            "interaction_cls_first"
-        ] = modules.blocks.RealAgnosticResidualInteractionBlock
-    if (
-        dict_input["interaction_cls_first"]
-        == "<class 'mace.modules.blocks.RealAgnosticInteractionBlock'>"
-    ):
-        dict_output[
-            "interaction_cls_first"
-        ] = modules.blocks.RealAgnosticInteractionBlock
-    dict_output["r_max"] = float(dict_input["r_max"])
-    dict_output["num_bessel"] = int(dict_input["num_bessel"])
-    dict_output["num_polynomial_cutoff"] = float(dict_input["num_polynomial_cutoff"])
-    dict_output["max_ell"] = int(dict_input["max_ell"])
-    dict_output["num_interactions"] = int(dict_input["num_interactions"])
-    dict_output["num_elements"] = int(dict_input["num_elements"])
-    dict_output["hidden_irreps"] = o3.Irreps(dict_input["hidden_irreps"])
-    dict_output["MLP_irreps"] = o3.Irreps(dict_input["MLP_irreps"])
-    dict_output["avg_num_neighbors"] = float(dict_input["avg_num_neighbors"])
-    dict_output["gate"] = torch.nn.functional.silu
-    dict_output["atomic_energies"] = np.array(dict_input["atomic_energies"])
-    dict_output["atomic_numbers"] = dict_input["atomic_numbers"]
-    dict_output["correlation"] = int(dict_input["correlation"])
-    dict_output["radial_type"] = dict_input["radial_type"]
-    dict_output["radial_MLP"] = ast.literal_eval(dict_input["radial_MLP"])
-    dict_output["pair_repulsion"] = ast.literal_eval(dict_input["pair_repulsion"])
-    dict_output["distance_transform"] = dict_input["distance_transform"]
-    dict_output["atomic_inter_scale"] = float(dict_input["atomic_inter_scale"])
-    dict_output["atomic_inter_shift"] = float(dict_input["atomic_inter_shift"])
-
-    return dict_output
-
-
-def load_from_json(f: str, map_location: str = "cpu") -> torch.nn.Module:
-    extra_files_extract = {"commit.txt": None, "config.json": None}
-    model_jit_load = torch.jit.load(
-        f, _extra_files=extra_files_extract, map_location=map_location
-    )
-    model_load_yaml = modules.ScaleShiftMACE(
-        **convert_from_json_format(json.loads(extra_files_extract["config.json"]))
-    )
-    model_load_yaml.load_state_dict(model_jit_load.state_dict())
-    return model_load_yaml.to(map_location)
-
-
-def get_atomic_energies(E0s, train_collection, z_table) -> dict:
-    if E0s is not None:
-        logging.info(
-            "Isolated Atomic Energies (E0s) not in training file, using command line argument"
-        )
-        if E0s.lower() == "average":
-            logging.info(
-                "Computing average Atomic Energies using least squares regression"
-            )
-            # catch if colections.train not defined above
-            try:
-                assert train_collection is not None
-                atomic_energies_dict = data.compute_average_E0s(
-                    train_collection, z_table
-                )
-            except Exception as e:
-                raise RuntimeError(
-                    f"Could not compute average E0s if no training xyz given, error {e} occured"
-                ) from e
-        else:
-            if E0s.endswith(".json"):
-                logging.info(f"Loading atomic energies from {E0s}")
-                with open(E0s, "r", encoding="utf-8") as f:
-                    atomic_energies_dict = json.load(f)
-            else:
-                try:
-                    atomic_energies_dict = ast.literal_eval(E0s)
-                    assert isinstance(atomic_energies_dict, dict)
-                except Exception as e:
-                    raise RuntimeError(
-                        f"E0s specified invalidly, error {e} occured"
-                    ) from e
-    else:
-        raise RuntimeError(
-            "E0s not found in training file and not specified in command line"
-        )
-    return atomic_energies_dict
-
-
-def get_loss_fn(
-    loss: str,
-    energy_weight: float,
-    forces_weight: float,
-    stress_weight: float,
-    virials_weight: float,
-    dipole_weight: float,
-    dipole_only: bool,
-    compute_dipole: bool,
-) -> torch.nn.Module:
- if loss == "weighted": - loss_fn = modules.WeightedEnergyForcesLoss( - energy_weight=energy_weight, forces_weight=forces_weight - ) - elif loss == "forces_only": - loss_fn = modules.WeightedForcesLoss(forces_weight=forces_weight) - elif loss == "virials": - loss_fn = modules.WeightedEnergyForcesVirialsLoss( - energy_weight=energy_weight, - forces_weight=forces_weight, - virials_weight=virials_weight, - ) - elif loss == "stress": - loss_fn = modules.WeightedEnergyForcesStressLoss( - energy_weight=energy_weight, - forces_weight=forces_weight, - stress_weight=stress_weight, - ) - elif loss == "dipole": - assert ( - dipole_only is True - ), "dipole loss can only be used with AtomicDipolesMACE model" - loss_fn = modules.DipoleSingleLoss( - dipole_weight=dipole_weight, - ) - elif loss == "energy_forces_dipole": - assert dipole_only is False and compute_dipole is True - loss_fn = modules.WeightedEnergyForcesDipoleLoss( - energy_weight=energy_weight, - forces_weight=forces_weight, - dipole_weight=dipole_weight, - ) - else: - loss_fn = modules.EnergyForcesLoss( - energy_weight=energy_weight, forces_weight=forces_weight - ) - return loss_fn - - -def get_files_with_suffix(dir_path: str, suffix: str) -> List[str]: - return [ - os.path.join(dir_path, f) for f in os.listdir(dir_path) if f.endswith(suffix) - ] - - -def custom_key(key): - """ - Helper function to sort the keys of the data loader dictionary - to ensure that the training set, and validation set - are evaluated first - """ - if key == "train": - return (0, key) - if key == "valid": - return (1, key) - return (2, key) - - -class LRScheduler: - def __init__(self, optimizer, args) -> None: - self.scheduler = args.scheduler - self._optimizer_type = ( - args.optimizer - ) # Schedulefree does not need an optimizer but checkpoint handler does. 
- if args.scheduler == "ExponentialLR": - self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR( - optimizer=optimizer, gamma=args.lr_scheduler_gamma - ) - elif args.scheduler == "ReduceLROnPlateau": - self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer=optimizer, - factor=args.lr_factor, - patience=args.scheduler_patience, - ) - else: - raise RuntimeError(f"Unknown scheduler: '{args.scheduler}'") - - def step(self, metrics=None, epoch=None): # pylint: disable=E1123 - if self._optimizer_type == "schedulefree": - return # In principle, schedulefree optimizer can be used with a scheduler but the paper suggests it's not necessary - if self.scheduler == "ExponentialLR": - self.lr_scheduler.step(epoch=epoch) - elif self.scheduler == "ReduceLROnPlateau": - self.lr_scheduler.step( # pylint: disable=E1123 - metrics=metrics, epoch=epoch - ) - - def __getattr__(self, name): - if name == "step": - return self.step - return getattr(self.lr_scheduler, name) - - -def create_error_table( - table_type: str, - all_data_loaders: dict, - model: torch.nn.Module, - loss_fn: torch.nn.Module, - output_args: Dict[str, bool], - log_wandb: bool, - device: str, - distributed: bool = False, -) -> PrettyTable: - if log_wandb: - import wandb - table = PrettyTable() - if table_type == "TotalRMSE": - table.field_names = [ - "config_type", - "RMSE E / meV", - "RMSE F / meV / A", - "relative F RMSE %", - ] - elif table_type == "PerAtomRMSE": - table.field_names = [ - "config_type", - "RMSE E / meV / atom", - "RMSE F / meV / A", - "relative F RMSE %", - ] - elif table_type == "PerAtomRMSEstressvirials": - table.field_names = [ - "config_type", - "RMSE E / meV / atom", - "RMSE F / meV / A", - "relative F RMSE %", - "RMSE Stress (Virials) / meV / A (A^3)", - ] - elif table_type == "PerAtomMAEstressvirials": - table.field_names = [ - "config_type", - "MAE E / meV / atom", - "MAE F / meV / A", - "relative F MAE %", - "MAE Stress (Virials) / meV / A (A^3)", - ] - elif table_type == "TotalMAE": - table.field_names = [ - "config_type", - "MAE E / meV", - "MAE F / meV / A", - "relative F MAE %", - ] - elif table_type == "PerAtomMAE": - table.field_names = [ - "config_type", - "MAE E / meV / atom", - "MAE F / meV / A", - "relative F MAE %", - ] - elif table_type == "DipoleRMSE": - table.field_names = [ - "config_type", - "RMSE MU / mDebye / atom", - "relative MU RMSE %", - ] - elif table_type == "DipoleMAE": - table.field_names = [ - "config_type", - "MAE MU / mDebye / atom", - "relative MU MAE %", - ] - elif table_type == "EnergyDipoleRMSE": - table.field_names = [ - "config_type", - "RMSE E / meV / atom", - "RMSE F / meV / A", - "rel F RMSE %", - "RMSE MU / mDebye / atom", - "rel MU RMSE %", - ] - - for name in sorted(all_data_loaders, key=custom_key): - data_loader = all_data_loaders[name] - logging.info(f"Evaluating {name} ...") - _, metrics = evaluate( - model, - loss_fn=loss_fn, - data_loader=data_loader, - output_args=output_args, - device=device, - ) - if distributed: - torch.distributed.barrier() - - del data_loader - torch.cuda.empty_cache() - if log_wandb: - wandb_log_dict = { - name - + "_final_rmse_e_per_atom": metrics["rmse_e_per_atom"] - * 1e3, # meV / atom - name + "_final_rmse_f": metrics["rmse_f"] * 1e3, # meV / A - name + "_final_rel_rmse_f": metrics["rel_rmse_f"], - } - wandb.log(wandb_log_dict) - if table_type == "TotalRMSE": - table.add_row( - [ - name, - f"{metrics['rmse_e'] * 1000:8.1f}", - f"{metrics['rmse_f'] * 1000:8.1f}", - f"{metrics['rel_rmse_f']:8.2f}", - ] - ) - elif 
table_type == "PerAtomRMSE": - table.add_row( - [ - name, - f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", - f"{metrics['rmse_f'] * 1000:8.1f}", - f"{metrics['rel_rmse_f']:8.2f}", - ] - ) - elif ( - table_type == "PerAtomRMSEstressvirials" - and metrics["rmse_stress"] is not None - ): - table.add_row( - [ - name, - f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", - f"{metrics['rmse_f'] * 1000:8.1f}", - f"{metrics['rel_rmse_f']:8.2f}", - f"{metrics['rmse_stress'] * 1000:8.1f}", - ] - ) - elif ( - table_type == "PerAtomRMSEstressvirials" - and metrics["rmse_virials"] is not None - ): - table.add_row( - [ - name, - f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", - f"{metrics['rmse_f'] * 1000:8.1f}", - f"{metrics['rel_rmse_f']:8.2f}", - f"{metrics['rmse_virials'] * 1000:8.1f}", - ] - ) - elif ( - table_type == "PerAtomMAEstressvirials" - and metrics["mae_stress"] is not None - ): - table.add_row( - [ - name, - f"{metrics['mae_e_per_atom'] * 1000:8.1f}", - f"{metrics['mae_f'] * 1000:8.1f}", - f"{metrics['rel_mae_f']:8.2f}", - f"{metrics['mae_stress'] * 1000:8.1f}", - ] - ) - elif ( - table_type == "PerAtomMAEstressvirials" - and metrics["mae_virials"] is not None - ): - table.add_row( - [ - name, - f"{metrics['mae_e_per_atom'] * 1000:8.1f}", - f"{metrics['mae_f'] * 1000:8.1f}", - f"{metrics['rel_mae_f']:8.2f}", - f"{metrics['mae_virials'] * 1000:8.1f}", - ] - ) - elif table_type == "TotalMAE": - table.add_row( - [ - name, - f"{metrics['mae_e'] * 1000:8.1f}", - f"{metrics['mae_f'] * 1000:8.1f}", - f"{metrics['rel_mae_f']:8.2f}", - ] - ) - elif table_type == "PerAtomMAE": - table.add_row( - [ - name, - f"{metrics['mae_e_per_atom'] * 1000:8.1f}", - f"{metrics['mae_f'] * 1000:8.1f}", - f"{metrics['rel_mae_f']:8.2f}", - ] - ) - elif table_type == "DipoleRMSE": - table.add_row( - [ - name, - f"{metrics['rmse_mu_per_atom'] * 1000:8.2f}", - f"{metrics['rel_rmse_mu']:8.1f}", - ] - ) - elif table_type == "DipoleMAE": - table.add_row( - [ - name, - f"{metrics['mae_mu_per_atom'] * 1000:8.2f}", - f"{metrics['rel_mae_mu']:8.1f}", - ] - ) - elif table_type == "EnergyDipoleRMSE": - table.add_row( - [ - name, - f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", - f"{metrics['rmse_f'] * 1000:8.1f}", - f"{metrics['rel_rmse_f']:8.1f}", - f"{metrics['rmse_mu_per_atom'] * 1000:8.1f}", - f"{metrics['rel_rmse_mu']:8.1f}", - ] - ) - return table +# ########################################################################################### +# # Training utils +# # Authors: David Kovacs, Ilyes Batatia +# # This program is distributed under the MIT License (see MIT.md) +# ########################################################################################### + +# import ast +# import dataclasses +# import json +# import logging +# import os +# from typing import Any, Dict, List, Optional, Tuple + +# import numpy as np +# import torch +# import torch.distributed +# from e3nn import o3 +# from prettytable import PrettyTable + +# from mace import data, modules +# from mace.tools import evaluate + + +# @dataclasses.dataclass +# class SubsetCollection: +# train: data.Configurations +# valid: data.Configurations +# tests: List[Tuple[str, data.Configurations]] + + +# def get_dataset_from_xyz( +# work_dir: str, +# train_path: str, +# valid_path: str, +# valid_fraction: float, +# config_type_weights: Dict, +# test_path: str = None, +# seed: int = 1234, +# keep_isolated_atoms: bool = False, +# energy_key: str = "REF_energy", +# forces_key: str = "REF_forces", +# stress_key: str = "REF_stress", +# virials_key: str = "virials", +# dipole_key: 
str = "dipoles", +# charges_key: str = "charges", +# ) -> Tuple[SubsetCollection, Optional[Dict[int, float]]]: +# """Load training and test dataset from xyz file""" +# atomic_energies_dict, all_train_configs = data.load_from_xyz( +# file_path=train_path, +# config_type_weights=config_type_weights, +# energy_key=energy_key, +# forces_key=forces_key, +# stress_key=stress_key, +# virials_key=virials_key, +# dipole_key=dipole_key, +# charges_key=charges_key, +# extract_atomic_energies=True, +# keep_isolated_atoms=keep_isolated_atoms, +# ) +# logging.info( +# f"Training set [{len(all_train_configs)} configs, {np.sum([1 if config.energy else 0 for config in all_train_configs])} energy, {np.sum([config.forces.size for config in all_train_configs])} forces] loaded from '{train_path}'" +# ) +# if valid_path is not None: +# _, valid_configs = data.load_from_xyz( +# file_path=valid_path, +# config_type_weights=config_type_weights, +# energy_key=energy_key, +# forces_key=forces_key, +# stress_key=stress_key, +# virials_key=virials_key, +# dipole_key=dipole_key, +# charges_key=charges_key, +# extract_atomic_energies=False, +# ) +# logging.info( +# f"Validation set [{len(valid_configs)} configs, {np.sum([1 if config.energy else 0 for config in valid_configs])} energy, {np.sum([config.forces.size for config in valid_configs])} forces] loaded from '{valid_path}'" +# ) +# train_configs = all_train_configs +# else: +# train_configs, valid_configs = data.random_train_valid_split( +# all_train_configs, valid_fraction, seed, work_dir +# ) +# logging.info( +# f"Validaton set contains {len(valid_configs)} configurations [{np.sum([1 if config.energy else 0 for config in valid_configs])} energy, {np.sum([config.forces.size for config in valid_configs])} forces]" +# ) + +# test_configs = [] +# if test_path is not None: +# _, all_test_configs = data.load_from_xyz( +# file_path=test_path, +# config_type_weights=config_type_weights, +# energy_key=energy_key, +# forces_key=forces_key, +# dipole_key=dipole_key, +# stress_key=stress_key, +# virials_key=virials_key, +# charges_key=charges_key, +# extract_atomic_energies=False, +# ) +# # create list of tuples (config_type, list(Atoms)) +# test_configs = data.test_config_types(all_test_configs) +# logging.info( +# f"Test set ({len(all_test_configs)} configs) loaded from '{test_path}':" +# ) +# for name, tmp_configs in test_configs: +# logging.info( +# f"{name}: {len(tmp_configs)} configs, {np.sum([1 if config.energy else 0 for config in tmp_configs])} energy, {np.sum([config.forces.size for config in tmp_configs])} forces" +# ) + +# return ( +# SubsetCollection(train=train_configs, valid=valid_configs, tests=test_configs), +# atomic_energies_dict, +# ) + + +# def get_config_type_weights(ct_weights): +# """ +# Get config type weights from command line argument +# """ +# try: +# config_type_weights = ast.literal_eval(ct_weights) +# assert isinstance(config_type_weights, dict) +# except Exception as e: # pylint: disable=W0703 +# logging.warning( +# f"Config type weights not specified correctly ({e}), using Default" +# ) +# config_type_weights = {"Default": 1.0} +# return config_type_weights + + +# def print_git_commit(): +# try: +# import git + +# repo = git.Repo(search_parent_directories=True) +# commit = repo.head.commit.hexsha +# logging.debug(f"Current Git commit: {commit}") +# return commit +# except Exception as e: # pylint: disable=W0703 +# logging.debug(f"Error accessing Git repository: {e}") +# return "None" + + +# def extract_config_mace_model(model: torch.nn.Module) 
-> Dict[str, Any]:
+#     if model.__class__.__name__ != "ScaleShiftMACE":
+#         return {"error": "Model is not a ScaleShiftMACE model"}
+
+#     def radial_to_name(radial_type):
+#         if radial_type == "BesselBasis":
+#             return "bessel"
+#         if radial_type == "GaussianBasis":
+#             return "gaussian"
+#         if radial_type == "ChebychevBasis":
+#             return "chebyshev"
+#         return radial_type
+
+#     def radial_to_transform(radial):
+#         if not hasattr(radial, "distance_transform"):
+#             return None
+#         if radial.distance_transform.__class__.__name__ == "AgnesiTransform":
+#             return "Agnesi"
+#         if radial.distance_transform.__class__.__name__ == "SoftTransform":
+#             return "Soft"
+#         return radial.distance_transform.__class__.__name__
+
+#     config = {
+#         "r_max": model.r_max.item(),
+#         "num_bessel": len(model.radial_embedding.bessel_fn.bessel_weights),
+#         "num_polynomial_cutoff": model.radial_embedding.cutoff_fn.p.item(),
+#         "max_ell": model.spherical_harmonics._lmax,  # pylint: disable=protected-access
+#         "interaction_cls": model.interactions[-1].__class__,
+#         "interaction_cls_first": model.interactions[0].__class__,
+#         "num_interactions": model.num_interactions.item(),
+#         "num_elements": len(model.atomic_numbers),
+#         "hidden_irreps": o3.Irreps(str(model.products[0].linear.irreps_out)),
+#         "MLP_irreps": (
+#             o3.Irreps(str(model.readouts[-1].hidden_irreps))
+#             if model.num_interactions.item() > 1
+#             else 1
+#         ),
+#         "gate": (
+#             model.readouts[-1]  # pylint: disable=protected-access
+#             .non_linearity._modules["acts"][0]
+#             .f
+#             if model.num_interactions.item() > 1
+#             else None
+#         ),
+#         "atomic_energies": model.atomic_energies_fn.atomic_energies.cpu().numpy(),
+#         "avg_num_neighbors": model.interactions[0].avg_num_neighbors,
+#         "atomic_numbers": model.atomic_numbers,
+#         "correlation": len(
+#             model.products[0].symmetric_contractions.contractions[0].weights
+#         )
+#         + 1,
+#         "radial_type": radial_to_name(
+#             model.radial_embedding.bessel_fn.__class__.__name__
+#         ),
+#         "radial_MLP": model.interactions[0].conv_tp_weights.hs[1:-1],
+#         "pair_repulsion": hasattr(model, "pair_repulsion_fn"),
+#         "distance_transform": radial_to_transform(model.radial_embedding),
+#         "atomic_inter_scale": model.scale_shift.scale.item(),
+#         "atomic_inter_shift": model.scale_shift.shift.item(),
+#     }
+#     return config
+
+
+# def extract_load(f: str, map_location: str = "cpu") -> torch.nn.Module:
+#     model = torch.load(f=f, map_location=map_location)
+#     model_copy = model.__class__(**extract_config_mace_model(model))
+#     model_copy.load_state_dict(model.state_dict())
+#     return model_copy.to(map_location)
+
+
+# def extract_model(model: torch.nn.Module, map_location: str = "cpu") -> torch.nn.Module:
+#     model_copy = model.__class__(**extract_config_mace_model(model))
+#     model_copy.load_state_dict(model.state_dict())
+#     return model_copy.to(map_location)
+
+
+# def convert_to_json_format(dict_input):
+#     for key, value in dict_input.items():
+#         if isinstance(value, (np.ndarray, torch.Tensor)):
+#             dict_input[key] = value.tolist()
+#         # # check if the value is a class and convert it to a string
+#         elif hasattr(value, "__class__"):
+#             dict_input[key] = str(value)
+#     return dict_input
+
+
+# def convert_from_json_format(dict_input):
+#     dict_output = dict_input.copy()
+#     if (
+#         dict_input["interaction_cls"]
+#         == "<class 'mace.modules.blocks.RealAgnosticResidualInteractionBlock'>"
+#     ):
+#         dict_output[
+#             "interaction_cls"
+#         ] = modules.blocks.RealAgnosticResidualInteractionBlock
+#     if (
+#         dict_input["interaction_cls"]
+#         == "<class 'mace.modules.blocks.RealAgnosticInteractionBlock'>"
+#     ):
+#         dict_output["interaction_cls"] = modules.blocks.RealAgnosticInteractionBlock
+#     if (
+#         dict_input["interaction_cls_first"]
+#         == "<class 'mace.modules.blocks.RealAgnosticResidualInteractionBlock'>"
+#     ):
+#         dict_output[
+#             "interaction_cls_first"
+#         ] = modules.blocks.RealAgnosticResidualInteractionBlock
+#     if (
+#         dict_input["interaction_cls_first"]
+#         == "<class 'mace.modules.blocks.RealAgnosticInteractionBlock'>"
+#     ):
+#         dict_output[
+#             "interaction_cls_first"
+#         ] = modules.blocks.RealAgnosticInteractionBlock
+#     dict_output["r_max"] = float(dict_input["r_max"])
+#     dict_output["num_bessel"] = int(dict_input["num_bessel"])
+#     dict_output["num_polynomial_cutoff"] = float(dict_input["num_polynomial_cutoff"])
+#     dict_output["max_ell"] = int(dict_input["max_ell"])
+#     dict_output["num_interactions"] = int(dict_input["num_interactions"])
+#     dict_output["num_elements"] = int(dict_input["num_elements"])
+#     dict_output["hidden_irreps"] = o3.Irreps(dict_input["hidden_irreps"])
+#     dict_output["MLP_irreps"] = o3.Irreps(dict_input["MLP_irreps"])
+#     dict_output["avg_num_neighbors"] = float(dict_input["avg_num_neighbors"])
+#     dict_output["gate"] = torch.nn.functional.silu
+#     dict_output["atomic_energies"] = np.array(dict_input["atomic_energies"])
+#     dict_output["atomic_numbers"] = dict_input["atomic_numbers"]
+#     dict_output["correlation"] = int(dict_input["correlation"])
+#     dict_output["radial_type"] = dict_input["radial_type"]
+#     dict_output["radial_MLP"] = ast.literal_eval(dict_input["radial_MLP"])
+#     dict_output["pair_repulsion"] = ast.literal_eval(dict_input["pair_repulsion"])
+#     dict_output["distance_transform"] = dict_input["distance_transform"]
+#     dict_output["atomic_inter_scale"] = float(dict_input["atomic_inter_scale"])
+#     dict_output["atomic_inter_shift"] = float(dict_input["atomic_inter_shift"])
+
+#     return dict_output
+
+
+# def load_from_json(f: str, map_location: str = "cpu") -> torch.nn.Module:
+#     extra_files_extract = {"commit.txt": None, "config.json": None}
+#     model_jit_load = torch.jit.load(
+#         f, _extra_files=extra_files_extract, map_location=map_location
+#     )
+#     model_load_yaml = modules.ScaleShiftMACE(
+#         **convert_from_json_format(json.loads(extra_files_extract["config.json"]))
+#     )
+#     model_load_yaml.load_state_dict(model_jit_load.state_dict())
+#     return model_load_yaml.to(map_location)
+
+
+# def get_atomic_energies(E0s, train_collection, z_table) -> dict:
+#     if E0s is not None:
+#         logging.info(
+#             "Isolated Atomic Energies (E0s) not in training file, using command line argument"
+#         )
+#         if E0s.lower() == "average":
+#             logging.info(
+#                 "Computing average Atomic Energies using least squares regression"
+#             )
+#             # catch if collections.train not defined above
+#             try:
+#                 assert train_collection is not None
+#                 atomic_energies_dict = data.compute_average_E0s(
+#                     train_collection, z_table
+#                 )
+#             except Exception as e:
+#                 raise RuntimeError(
+#                     f"Could not compute average E0s if no training xyz given, error {e} occurred"
+#                 ) from e
+#         else:
+#             if E0s.endswith(".json"):
+#                 logging.info(f"Loading atomic energies from {E0s}")
+#                 with open(E0s, "r", encoding="utf-8") as f:
+#                     atomic_energies_dict = json.load(f)
+#             else:
+#                 try:
+#                     atomic_energies_dict = ast.literal_eval(E0s)
+#                     assert isinstance(atomic_energies_dict, dict)
+#                 except Exception as e:
+#                     raise RuntimeError(
+#                         f"E0s specified invalidly, error {e} occurred"
+#                     ) from e
+#     else:
+#         raise RuntimeError(
+#             "E0s not found in training file and not specified in command line"
+#         )
+#     return atomic_energies_dict
+
+
+# def get_loss_fn(
+#     loss: str,
+#     energy_weight: float,
+#     forces_weight: float,
+#     stress_weight: float,
+#     virials_weight: float,
+#     dipole_weight: float,
+#     dipole_only: bool,
+#     compute_dipole: bool,
+# ) -> 
torch.nn.Module: +# if loss == "weighted": +# loss_fn = modules.WeightedEnergyForcesLoss( +# energy_weight=energy_weight, forces_weight=forces_weight +# ) +# elif loss == "forces_only": +# loss_fn = modules.WeightedForcesLoss(forces_weight=forces_weight) +# elif loss == "virials": +# loss_fn = modules.WeightedEnergyForcesVirialsLoss( +# energy_weight=energy_weight, +# forces_weight=forces_weight, +# virials_weight=virials_weight, +# ) +# elif loss == "stress": +# loss_fn = modules.WeightedEnergyForcesStressLoss( +# energy_weight=energy_weight, +# forces_weight=forces_weight, +# stress_weight=stress_weight, +# ) +# elif loss == "dipole": +# assert ( +# dipole_only is True +# ), "dipole loss can only be used with AtomicDipolesMACE model" +# loss_fn = modules.DipoleSingleLoss( +# dipole_weight=dipole_weight, +# ) +# elif loss == "energy_forces_dipole": +# assert dipole_only is False and compute_dipole is True +# loss_fn = modules.WeightedEnergyForcesDipoleLoss( +# energy_weight=energy_weight, +# forces_weight=forces_weight, +# dipole_weight=dipole_weight, +# ) +# else: +# loss_fn = modules.EnergyForcesLoss( +# energy_weight=energy_weight, forces_weight=forces_weight +# ) +# return loss_fn + + +# def get_files_with_suffix(dir_path: str, suffix: str) -> List[str]: +# return [ +# os.path.join(dir_path, f) for f in os.listdir(dir_path) if f.endswith(suffix) +# ] + + +# def custom_key(key): +# """ +# Helper function to sort the keys of the data loader dictionary +# to ensure that the training set, and validation set +# are evaluated first +# """ +# if key == "train": +# return (0, key) +# if key == "valid": +# return (1, key) +# return (2, key) + + +# class LRScheduler: +# def __init__(self, optimizer, args) -> None: +# self.scheduler = args.scheduler +# self._optimizer_type = ( +# args.optimizer +# ) # Schedulefree does not need an optimizer but checkpoint handler does. 
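+#         # The wrapper keeps the optimizer type because ScheduleFree-style
+#         # optimizers skip scheduler steps (see step() below), while the
+#         # checkpoint handler still expects an lr_scheduler attribute.
+#         # Unknown attributes are forwarded to the wrapped scheduler through
+#         # __getattr__, so a hypothetical usage reads:
+#         #     scheduler = LRScheduler(optimizer, args)
+#         #     scheduler.step(metrics=valid_loss, epoch=epoch)
+#         #     scheduler.state_dict()  # delegated to the wrapped scheduler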
+# if args.scheduler == "ExponentialLR": +# self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR( +# optimizer=optimizer, gamma=args.lr_scheduler_gamma +# ) +# elif args.scheduler == "ReduceLROnPlateau": +# self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( +# optimizer=optimizer, +# factor=args.lr_factor, +# patience=args.scheduler_patience, +# ) +# else: +# raise RuntimeError(f"Unknown scheduler: '{args.scheduler}'") + +# def step(self, metrics=None, epoch=None): # pylint: disable=E1123 +# if self._optimizer_type == "schedulefree": +# return # In principle, schedulefree optimizer can be used with a scheduler but the paper suggests it's not necessary +# if self.scheduler == "ExponentialLR": +# self.lr_scheduler.step(epoch=epoch) +# elif self.scheduler == "ReduceLROnPlateau": +# self.lr_scheduler.step( # pylint: disable=E1123 +# metrics=metrics, epoch=epoch +# ) + +# def __getattr__(self, name): +# if name == "step": +# return self.step +# return getattr(self.lr_scheduler, name) + + +# def create_error_table( +# table_type: str, +# all_data_loaders: dict, +# model: torch.nn.Module, +# loss_fn: torch.nn.Module, +# output_args: Dict[str, bool], +# log_wandb: bool, +# device: str, +# distributed: bool = False, +# ) -> PrettyTable: +# if log_wandb: +# import wandb +# table = PrettyTable() +# if table_type == "TotalRMSE": +# table.field_names = [ +# "config_type", +# "RMSE E / meV", +# "RMSE F / meV / A", +# "relative F RMSE %", +# ] +# elif table_type == "PerAtomRMSE": +# table.field_names = [ +# "config_type", +# "RMSE E / meV / atom", +# "RMSE F / meV / A", +# "relative F RMSE %", +# ] +# elif table_type == "PerAtomRMSEstressvirials": +# table.field_names = [ +# "config_type", +# "RMSE E / meV / atom", +# "RMSE F / meV / A", +# "relative F RMSE %", +# "RMSE Stress (Virials) / meV / A (A^3)", +# ] +# elif table_type == "PerAtomMAEstressvirials": +# table.field_names = [ +# "config_type", +# "MAE E / meV / atom", +# "MAE F / meV / A", +# "relative F MAE %", +# "MAE Stress (Virials) / meV / A (A^3)", +# ] +# elif table_type == "TotalMAE": +# table.field_names = [ +# "config_type", +# "MAE E / meV", +# "MAE F / meV / A", +# "relative F MAE %", +# ] +# elif table_type == "PerAtomMAE": +# table.field_names = [ +# "config_type", +# "MAE E / meV / atom", +# "MAE F / meV / A", +# "relative F MAE %", +# ] +# elif table_type == "DipoleRMSE": +# table.field_names = [ +# "config_type", +# "RMSE MU / mDebye / atom", +# "relative MU RMSE %", +# ] +# elif table_type == "DipoleMAE": +# table.field_names = [ +# "config_type", +# "MAE MU / mDebye / atom", +# "relative MU MAE %", +# ] +# elif table_type == "EnergyDipoleRMSE": +# table.field_names = [ +# "config_type", +# "RMSE E / meV / atom", +# "RMSE F / meV / A", +# "rel F RMSE %", +# "RMSE MU / mDebye / atom", +# "rel MU RMSE %", +# ] + +# for name in sorted(all_data_loaders, key=custom_key): +# data_loader = all_data_loaders[name] +# logging.info(f"Evaluating {name} ...") +# _, metrics = evaluate( +# model, +# loss_fn=loss_fn, +# data_loader=data_loader, +# output_args=output_args, +# device=device, +# ) +# if distributed: +# torch.distributed.barrier() + +# del data_loader +# torch.cuda.empty_cache() +# if log_wandb: +# wandb_log_dict = { +# name +# + "_final_rmse_e_per_atom": metrics["rmse_e_per_atom"] +# * 1e3, # meV / atom +# name + "_final_rmse_f": metrics["rmse_f"] * 1e3, # meV / A +# name + "_final_rel_rmse_f": metrics["rel_rmse_f"], +# } +# wandb.log(wandb_log_dict) +# if table_type == "TotalRMSE": +# table.add_row( +# [ +# name, 
+# f"{metrics['rmse_e'] * 1000:8.1f}", +# f"{metrics['rmse_f'] * 1000:8.1f}", +# f"{metrics['rel_rmse_f']:8.2f}", +# ] +# ) +# elif table_type == "PerAtomRMSE": +# table.add_row( +# [ +# name, +# f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", +# f"{metrics['rmse_f'] * 1000:8.1f}", +# f"{metrics['rel_rmse_f']:8.2f}", +# ] +# ) +# elif ( +# table_type == "PerAtomRMSEstressvirials" +# and metrics["rmse_stress"] is not None +# ): +# table.add_row( +# [ +# name, +# f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", +# f"{metrics['rmse_f'] * 1000:8.1f}", +# f"{metrics['rel_rmse_f']:8.2f}", +# f"{metrics['rmse_stress'] * 1000:8.1f}", +# ] +# ) +# elif ( +# table_type == "PerAtomRMSEstressvirials" +# and metrics["rmse_virials"] is not None +# ): +# table.add_row( +# [ +# name, +# f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", +# f"{metrics['rmse_f'] * 1000:8.1f}", +# f"{metrics['rel_rmse_f']:8.2f}", +# f"{metrics['rmse_virials'] * 1000:8.1f}", +# ] +# ) +# elif ( +# table_type == "PerAtomMAEstressvirials" +# and metrics["mae_stress"] is not None +# ): +# table.add_row( +# [ +# name, +# f"{metrics['mae_e_per_atom'] * 1000:8.1f}", +# f"{metrics['mae_f'] * 1000:8.1f}", +# f"{metrics['rel_mae_f']:8.2f}", +# f"{metrics['mae_stress'] * 1000:8.1f}", +# ] +# ) +# elif ( +# table_type == "PerAtomMAEstressvirials" +# and metrics["mae_virials"] is not None +# ): +# table.add_row( +# [ +# name, +# f"{metrics['mae_e_per_atom'] * 1000:8.1f}", +# f"{metrics['mae_f'] * 1000:8.1f}", +# f"{metrics['rel_mae_f']:8.2f}", +# f"{metrics['mae_virials'] * 1000:8.1f}", +# ] +# ) +# elif table_type == "TotalMAE": +# table.add_row( +# [ +# name, +# f"{metrics['mae_e'] * 1000:8.1f}", +# f"{metrics['mae_f'] * 1000:8.1f}", +# f"{metrics['rel_mae_f']:8.2f}", +# ] +# ) +# elif table_type == "PerAtomMAE": +# table.add_row( +# [ +# name, +# f"{metrics['mae_e_per_atom'] * 1000:8.1f}", +# f"{metrics['mae_f'] * 1000:8.1f}", +# f"{metrics['rel_mae_f']:8.2f}", +# ] +# ) +# elif table_type == "DipoleRMSE": +# table.add_row( +# [ +# name, +# f"{metrics['rmse_mu_per_atom'] * 1000:8.2f}", +# f"{metrics['rel_rmse_mu']:8.1f}", +# ] +# ) +# elif table_type == "DipoleMAE": +# table.add_row( +# [ +# name, +# f"{metrics['mae_mu_per_atom'] * 1000:8.2f}", +# f"{metrics['rel_mae_mu']:8.1f}", +# ] +# ) +# elif table_type == "EnergyDipoleRMSE": +# table.add_row( +# [ +# name, +# f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", +# f"{metrics['rmse_f'] * 1000:8.1f}", +# f"{metrics['rel_rmse_f']:8.1f}", +# f"{metrics['rmse_mu_per_atom'] * 1000:8.1f}", +# f"{metrics['rel_rmse_mu']:8.1f}", +# ] +# ) +# return table diff --git a/hydragnn/utils/mace_utils/tools/slurm_distributed.py b/hydragnn/utils/mace_utils/tools/slurm_distributed.py index 78de52a1b..35915fe26 100644 --- a/hydragnn/utils/mace_utils/tools/slurm_distributed.py +++ b/hydragnn/utils/mace_utils/tools/slurm_distributed.py @@ -1,34 +1,34 @@ -########################################################################################### -# Slurm environment setup for distributed training. -# This code is refactored from rsarm's contribution at: -# https://github.com/Lumi-supercomputer/lumi-reframe-tests/blob/main/checks/apps/deeplearning/pytorch/src/pt_distr_env.py -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### +# ########################################################################################### +# # Slurm environment setup for distributed training. 
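+# # (The helper below maps SLURM_* variables onto the MASTER_ADDR, MASTER_PORT,
+# # WORLD_SIZE, LOCAL_RANK, and RANK variables expected by torch.distributed,
+# # e.g. SLURM_PROCID=3 becomes RANK=3.)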
+# # This code is refactored from rsarm's contribution at: +# # https://github.com/Lumi-supercomputer/lumi-reframe-tests/blob/main/checks/apps/deeplearning/pytorch/src/pt_distr_env.py +# # This program is distributed under the MIT License (see MIT.md) +# ########################################################################################### -import os +# import os -import hostlist +# import hostlist -class DistributedEnvironment: - def __init__(self): - self._setup_distr_env() - self.master_addr = os.environ["MASTER_ADDR"] - self.master_port = os.environ["MASTER_PORT"] - self.world_size = int(os.environ["WORLD_SIZE"]) - self.local_rank = int(os.environ["LOCAL_RANK"]) - self.rank = int(os.environ["RANK"]) +# class DistributedEnvironment: +# def __init__(self): +# self._setup_distr_env() +# self.master_addr = os.environ["MASTER_ADDR"] +# self.master_port = os.environ["MASTER_PORT"] +# self.world_size = int(os.environ["WORLD_SIZE"]) +# self.local_rank = int(os.environ["LOCAL_RANK"]) +# self.rank = int(os.environ["RANK"]) - def _setup_distr_env(self): - hostname = hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"])[0] - os.environ["MASTER_ADDR"] = hostname - os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "33333") - os.environ["WORLD_SIZE"] = os.environ.get( - "SLURM_NTASKS", - str( - int(os.environ["SLURM_NTASKS_PER_NODE"]) - * int(os.environ["SLURM_NNODES"]) - ), - ) - os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"] - os.environ["RANK"] = os.environ["SLURM_PROCID"] +# def _setup_distr_env(self): +# hostname = hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"])[0] +# os.environ["MASTER_ADDR"] = hostname +# os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "33333") +# os.environ["WORLD_SIZE"] = os.environ.get( +# "SLURM_NTASKS", +# str( +# int(os.environ["SLURM_NTASKS_PER_NODE"]) +# * int(os.environ["SLURM_NNODES"]) +# ), +# ) +# os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"] +# os.environ["RANK"] = os.environ["SLURM_PROCID"] diff --git a/hydragnn/utils/mace_utils/tools/train.py b/hydragnn/utils/mace_utils/tools/train.py index b38bce167..a0b710222 100644 --- a/hydragnn/utils/mace_utils/tools/train.py +++ b/hydragnn/utils/mace_utils/tools/train.py @@ -1,524 +1,524 @@ -########################################################################################### -# Training script -# Authors: Ilyes Batatia, Gregor Simm, David Kovacs -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import dataclasses -import logging -import time -from contextlib import nullcontext -from typing import Any, Dict, List, Optional, Tuple, Union - -import numpy as np -import torch -import torch.distributed -from torch.nn.parallel import DistributedDataParallel -from torch.optim.swa_utils import SWALR, AveragedModel -from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler -from torch_ema import ExponentialMovingAverage -from torchmetrics import Metric - -from . 
import torch_geometric -from .checkpoint import CheckpointHandler, CheckpointState -from .torch_tools import to_numpy -from .utils import ( - MetricsLogger, - compute_mae, - compute_q95, - compute_rel_mae, - compute_rel_rmse, - compute_rmse, -) - - -@dataclasses.dataclass -class SWAContainer: - model: AveragedModel - scheduler: SWALR - start: int - loss_fn: torch.nn.Module - - -def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): - eval_metrics["mode"] = "eval" - eval_metrics["epoch"] = epoch - logger.log(eval_metrics) - if epoch is None: - inintial_phrase = "Initial" - else: - inintial_phrase = f"Epoch {epoch}" - if log_errors == "PerAtomRMSE": - error_e = eval_metrics["rmse_e_per_atom"] * 1e3 - error_f = eval_metrics["rmse_f"] * 1e3 - logging.info( - f"{inintial_phrase}: loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A" - ) - elif ( - log_errors == "PerAtomRMSEstressvirials" - and eval_metrics["rmse_stress_per_atom"] is not None - ): - error_e = eval_metrics["rmse_e_per_atom"] * 1e3 - error_f = eval_metrics["rmse_f"] * 1e3 - error_stress = eval_metrics["rmse_stress_per_atom"] * 1e3 - logging.info( - f"{inintial_phrase}: loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_stress_per_atom={error_stress:8.1f} meV / A^3", - ) - elif ( - log_errors == "PerAtomRMSEstressvirials" - and eval_metrics["rmse_virials_per_atom"] is not None - ): - error_e = eval_metrics["rmse_e_per_atom"] * 1e3 - error_f = eval_metrics["rmse_f"] * 1e3 - error_virials = eval_metrics["rmse_virials_per_atom"] * 1e3 - logging.info( - f"{inintial_phrase}: loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_virials_per_atom={error_virials:8.1f} meV", - ) - elif ( - log_errors == "PerAtomMAEstressvirials" - and eval_metrics["mae_stress_per_atom"] is not None - ): - error_e = eval_metrics["mae_e_per_atom"] * 1e3 - error_f = eval_metrics["mae_f"] * 1e3 - error_stress = eval_metrics["mae_stress"] * 1e3 - logging.info( - f"{inintial_phrase}: loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A, MAE_stress={error_stress:8.1f} meV / A^3" - ) - elif ( - log_errors == "PerAtomMAEstressvirials" - and eval_metrics["mae_virials_per_atom"] is not None - ): - error_e = eval_metrics["mae_e_per_atom"] * 1e3 - error_f = eval_metrics["mae_f"] * 1e3 - error_virials = eval_metrics["mae_virials"] * 1e3 - logging.info( - f"{inintial_phrase}: loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A, MAE_virials={error_virials:8.1f} meV" - ) - elif log_errors == "TotalRMSE": - error_e = eval_metrics["rmse_e"] * 1e3 - error_f = eval_metrics["rmse_f"] * 1e3 - logging.info( - f"{inintial_phrase}: loss={valid_loss:8.4f}, RMSE_E={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A", - ) - elif log_errors == "PerAtomMAE": - error_e = eval_metrics["mae_e_per_atom"] * 1e3 - error_f = eval_metrics["mae_f"] * 1e3 - logging.info( - f"{inintial_phrase}: loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A", - ) - elif log_errors == "TotalMAE": - error_e = eval_metrics["mae_e"] * 1e3 - error_f = eval_metrics["mae_f"] * 1e3 - logging.info( - f"{inintial_phrase}: loss={valid_loss:8.4f}, MAE_E={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A", - ) - elif log_errors == "DipoleRMSE": - error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 - logging.info( - f"{inintial_phrase}: loss={valid_loss:8.4f}, 
RMSE_MU_per_atom={error_mu:8.2f} mDebye", - ) - elif log_errors == "EnergyDipoleRMSE": - error_e = eval_metrics["rmse_e_per_atom"] * 1e3 - error_f = eval_metrics["rmse_f"] * 1e3 - error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 - logging.info( - f"{inintial_phrase}: loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_Mu_per_atom={error_mu:8.2f} mDebye", - ) - - -def train( - model: torch.nn.Module, - loss_fn: torch.nn.Module, - train_loader: DataLoader, - valid_loader: Dict[str, DataLoader], - optimizer: torch.optim.Optimizer, - lr_scheduler: torch.optim.lr_scheduler.ExponentialLR, - start_epoch: int, - max_num_epochs: int, - patience: int, - checkpoint_handler: CheckpointHandler, - logger: MetricsLogger, - eval_interval: int, - output_args: Dict[str, bool], - device: torch.device, - log_errors: str, - swa: Optional[SWAContainer] = None, - ema: Optional[ExponentialMovingAverage] = None, - max_grad_norm: Optional[float] = 10.0, - log_wandb: bool = False, - distributed: bool = False, - save_all_checkpoints: bool = False, - distributed_model: Optional[DistributedDataParallel] = None, - train_sampler: Optional[DistributedSampler] = None, - rank: Optional[int] = 0, -): - lowest_loss = np.inf - valid_loss = np.inf - patience_counter = 0 - swa_start = True - keep_last = False - if log_wandb: - import wandb - - if max_grad_norm is not None: - logging.info(f"Using gradient clipping with tolerance={max_grad_norm:.3f}") - - logging.info("") - logging.info("===========TRAINING===========") - logging.info("Started training, reporting errors on validation set") - logging.info("Loss metrics on validation set") - epoch = start_epoch - - # # log validation loss before _any_ training - param_context = ema.average_parameters() if ema is not None else nullcontext() - with param_context: - valid_loss, eval_metrics = evaluate( - model=model, - loss_fn=loss_fn, - data_loader=valid_loader, - output_args=output_args, - device=device, - ) - valid_err_log(valid_loss, eval_metrics, logger, log_errors, None) - - while epoch < max_num_epochs: - # LR scheduler and SWA update - if swa is None or epoch < swa.start: - if epoch > start_epoch: - lr_scheduler.step( - metrics=valid_loss - ) # Can break if exponential LR, TODO fix that! 
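
Note on the TODO above: ReduceLROnPlateau.step() accepts a metrics argument, but metric-free schedulers such as ExponentialLR do not, so lr_scheduler.step(metrics=valid_loss) raises a TypeError for them. A minimal sketch of a scheduler-agnostic step, assuming a standard torch.optim.lr_scheduler object; the helper name is hypothetical and not part of this patch:

    import torch

    def scheduler_step(lr_scheduler, valid_loss):
        # ReduceLROnPlateau consumes the validation metric; metric-free
        # schedulers (e.g. ExponentialLR) are stepped without arguments.
        if isinstance(lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            lr_scheduler.step(metrics=valid_loss)
        else:
            lr_scheduler.step()
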
- else: - if swa_start: - logging.info("Changing loss based on Stage Two Weights") - lowest_loss = np.inf - swa_start = False - keep_last = True - loss_fn = swa.loss_fn - swa.model.update_parameters(model) - if epoch > start_epoch: - swa.scheduler.step() - - # Train - if distributed: - train_sampler.set_epoch(epoch) - if "ScheduleFree" in type(optimizer).__name__: - optimizer.train() - train_one_epoch( - model=model, - loss_fn=loss_fn, - data_loader=train_loader, - optimizer=optimizer, - epoch=epoch, - output_args=output_args, - max_grad_norm=max_grad_norm, - ema=ema, - logger=logger, - device=device, - distributed_model=distributed_model, - rank=rank, - ) - if distributed: - torch.distributed.barrier() - - # Validate - if epoch % eval_interval == 0: - model_to_evaluate = ( - model if distributed_model is None else distributed_model - ) - param_context = ( - ema.average_parameters() if ema is not None else nullcontext() - ) - if "ScheduleFree" in type(optimizer).__name__: - optimizer.eval() - with param_context: - valid_loss, eval_metrics = evaluate( - model=model_to_evaluate, - loss_fn=loss_fn, - data_loader=valid_loader, - output_args=output_args, - device=device, - ) - if rank == 0: - valid_err_log( - valid_loss, - eval_metrics, - logger, - log_errors, - epoch, - ) - if log_wandb: - wandb_log_dict = { - "epoch": epoch, - "valid_loss": valid_loss, - "valid_rmse_e_per_atom": eval_metrics["rmse_e_per_atom"], - "valid_rmse_f": eval_metrics["rmse_f"], - } - wandb.log(wandb_log_dict) - - if valid_loss >= lowest_loss: - patience_counter += 1 - if patience_counter >= patience and epoch < swa.start: - logging.info( - f"Stopping optimization after {patience_counter} epochs without improvement and starting Stage Two" - ) - epoch = swa.start - elif patience_counter >= patience and epoch >= swa.start: - logging.info( - f"Stopping optimization after {patience_counter} epochs without improvement" - ) - break - if save_all_checkpoints: - param_context = ( - ema.average_parameters() - if ema is not None - else nullcontext() - ) - with param_context: - checkpoint_handler.save( - state=CheckpointState(model, optimizer, lr_scheduler), - epochs=epoch, - keep_last=True, - ) - else: - lowest_loss = valid_loss - patience_counter = 0 - param_context = ( - ema.average_parameters() if ema is not None else nullcontext() - ) - with param_context: - checkpoint_handler.save( - state=CheckpointState(model, optimizer, lr_scheduler), - epochs=epoch, - keep_last=keep_last, - ) - keep_last = False or save_all_checkpoints - if distributed: - torch.distributed.barrier() - epoch += 1 - - logging.info("Training complete") - - -def train_one_epoch( - model: torch.nn.Module, - loss_fn: torch.nn.Module, - data_loader: DataLoader, - optimizer: torch.optim.Optimizer, - epoch: int, - output_args: Dict[str, bool], - max_grad_norm: Optional[float], - ema: Optional[ExponentialMovingAverage], - logger: MetricsLogger, - device: torch.device, - distributed_model: Optional[DistributedDataParallel] = None, - rank: Optional[int] = 0, -) -> None: - model_to_train = model if distributed_model is None else distributed_model - for batch in data_loader: - _, opt_metrics = take_step( - model=model_to_train, - loss_fn=loss_fn, - batch=batch, - optimizer=optimizer, - ema=ema, - output_args=output_args, - max_grad_norm=max_grad_norm, - device=device, - ) - opt_metrics["mode"] = "opt" - opt_metrics["epoch"] = epoch - if rank == 0: - logger.log(opt_metrics) - - -def take_step( - model: torch.nn.Module, - loss_fn: torch.nn.Module, - batch: 
torch_geometric.batch.Batch, - optimizer: torch.optim.Optimizer, - ema: Optional[ExponentialMovingAverage], - output_args: Dict[str, bool], - max_grad_norm: Optional[float], - device: torch.device, -) -> Tuple[float, Dict[str, Any]]: - start_time = time.time() - batch = batch.to(device) - optimizer.zero_grad(set_to_none=True) - batch_dict = batch.to_dict() - output = model( - batch_dict, - training=True, - compute_force=output_args["forces"], - compute_virials=output_args["virials"], - compute_stress=output_args["stress"], - ) - loss = loss_fn(pred=output, ref=batch) - loss.backward() - if max_grad_norm is not None: - torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm) - optimizer.step() - - if ema is not None: - ema.update() - - loss_dict = { - "loss": to_numpy(loss), - "time": time.time() - start_time, - } - - return loss, loss_dict - - -def evaluate( - model: torch.nn.Module, - loss_fn: torch.nn.Module, - data_loader: DataLoader, - output_args: Dict[str, bool], - device: torch.device, -) -> Tuple[float, Dict[str, Any]]: - for param in model.parameters(): - param.requires_grad = False - - metrics = MACELoss(loss_fn=loss_fn).to(device) - - start_time = time.time() - for batch in data_loader: - batch = batch.to(device) - batch_dict = batch.to_dict() - output = model( - batch_dict, - training=False, - compute_force=output_args["forces"], - compute_virials=output_args["virials"], - compute_stress=output_args["stress"], - ) - avg_loss, aux = metrics(batch, output) - - avg_loss, aux = metrics.compute() - aux["time"] = time.time() - start_time - metrics.reset() - - for param in model.parameters(): - param.requires_grad = True - - return avg_loss, aux - - -class MACELoss(Metric): - def __init__(self, loss_fn: torch.nn.Module): - super().__init__() - self.loss_fn = loss_fn - self.add_state("total_loss", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("num_data", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("E_computed", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("delta_es", default=[], dist_reduce_fx="cat") - self.add_state("delta_es_per_atom", default=[], dist_reduce_fx="cat") - self.add_state("Fs_computed", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("fs", default=[], dist_reduce_fx="cat") - self.add_state("delta_fs", default=[], dist_reduce_fx="cat") - self.add_state( - "stress_computed", default=torch.tensor(0.0), dist_reduce_fx="sum" - ) - self.add_state("delta_stress", default=[], dist_reduce_fx="cat") - self.add_state("delta_stress_per_atom", default=[], dist_reduce_fx="cat") - self.add_state( - "virials_computed", default=torch.tensor(0.0), dist_reduce_fx="sum" - ) - self.add_state("delta_virials", default=[], dist_reduce_fx="cat") - self.add_state("delta_virials_per_atom", default=[], dist_reduce_fx="cat") - self.add_state("Mus_computed", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("mus", default=[], dist_reduce_fx="cat") - self.add_state("delta_mus", default=[], dist_reduce_fx="cat") - self.add_state("delta_mus_per_atom", default=[], dist_reduce_fx="cat") - - def update(self, batch, output): # pylint: disable=arguments-differ - loss = self.loss_fn(pred=output, ref=batch) - self.total_loss += loss - self.num_data += batch.num_graphs - - if output.get("energy") is not None and batch.energy is not None: - self.E_computed += 1.0 - self.delta_es.append(batch.energy - output["energy"]) - self.delta_es_per_atom.append( - (batch.energy - output["energy"]) / 
(batch.ptr[1:] - batch.ptr[:-1]) - ) - if output.get("forces") is not None and batch.forces is not None: - self.Fs_computed += 1.0 - self.fs.append(batch.forces) - self.delta_fs.append(batch.forces - output["forces"]) - if output.get("stress") is not None and batch.stress is not None: - self.stress_computed += 1.0 - self.delta_stress.append(batch.stress - output["stress"]) - self.delta_stress_per_atom.append( - (batch.stress - output["stress"]) - / (batch.ptr[1:] - batch.ptr[:-1]).view(-1, 1, 1) - ) - if output.get("virials") is not None and batch.virials is not None: - self.virials_computed += 1.0 - self.delta_virials.append(batch.virials - output["virials"]) - self.delta_virials_per_atom.append( - (batch.virials - output["virials"]) - / (batch.ptr[1:] - batch.ptr[:-1]).view(-1, 1, 1) - ) - if output.get("dipole") is not None and batch.dipole is not None: - self.Mus_computed += 1.0 - self.mus.append(batch.dipole) - self.delta_mus.append(batch.dipole - output["dipole"]) - self.delta_mus_per_atom.append( - (batch.dipole - output["dipole"]) - / (batch.ptr[1:] - batch.ptr[:-1]).unsqueeze(-1) - ) - - def convert(self, delta: Union[torch.Tensor, List[torch.Tensor]]) -> np.ndarray: - if isinstance(delta, list): - delta = torch.cat(delta) - return to_numpy(delta) - - def compute(self): - aux = {} - aux["loss"] = to_numpy(self.total_loss / self.num_data).item() - if self.E_computed: - delta_es = self.convert(self.delta_es) - delta_es_per_atom = self.convert(self.delta_es_per_atom) - aux["mae_e"] = compute_mae(delta_es) - aux["mae_e_per_atom"] = compute_mae(delta_es_per_atom) - aux["rmse_e"] = compute_rmse(delta_es) - aux["rmse_e_per_atom"] = compute_rmse(delta_es_per_atom) - aux["q95_e"] = compute_q95(delta_es) - if self.Fs_computed: - fs = self.convert(self.fs) - delta_fs = self.convert(self.delta_fs) - aux["mae_f"] = compute_mae(delta_fs) - aux["rel_mae_f"] = compute_rel_mae(delta_fs, fs) - aux["rmse_f"] = compute_rmse(delta_fs) - aux["rel_rmse_f"] = compute_rel_rmse(delta_fs, fs) - aux["q95_f"] = compute_q95(delta_fs) - if self.stress_computed: - delta_stress = self.convert(self.delta_stress) - delta_stress_per_atom = self.convert(self.delta_stress_per_atom) - aux["mae_stress"] = compute_mae(delta_stress) - aux["rmse_stress"] = compute_rmse(delta_stress) - aux["rmse_stress_per_atom"] = compute_rmse(delta_stress_per_atom) - aux["q95_stress"] = compute_q95(delta_stress) - if self.virials_computed: - delta_virials = self.convert(self.delta_virials) - delta_virials_per_atom = self.convert(self.delta_virials_per_atom) - aux["mae_virials"] = compute_mae(delta_virials) - aux["rmse_virials"] = compute_rmse(delta_virials) - aux["rmse_virials_per_atom"] = compute_rmse(delta_virials_per_atom) - aux["q95_virials"] = compute_q95(delta_virials) - if self.Mus_computed: - mus = self.convert(self.mus) - delta_mus = self.convert(self.delta_mus) - delta_mus_per_atom = self.convert(self.delta_mus_per_atom) - aux["mae_mu"] = compute_mae(delta_mus) - aux["mae_mu_per_atom"] = compute_mae(delta_mus_per_atom) - aux["rel_mae_mu"] = compute_rel_mae(delta_mus, mus) - aux["rmse_mu"] = compute_rmse(delta_mus) - aux["rmse_mu_per_atom"] = compute_rmse(delta_mus_per_atom) - aux["rel_rmse_mu"] = compute_rel_rmse(delta_mus, mus) - aux["q95_mu"] = compute_q95(delta_mus) - - return aux["loss"], aux +# ########################################################################################### +# # Training script +# # Authors: Ilyes Batatia, Gregor Simm, David Kovacs +# # This program is distributed under the MIT License (see 
MIT.md) +# ########################################################################################### + +# import dataclasses +# import logging +# import time +# from contextlib import nullcontext +# from typing import Any, Dict, List, Optional, Tuple, Union + +# import numpy as np +# import torch +# import torch.distributed +# from torch.nn.parallel import DistributedDataParallel +# from torch.optim.swa_utils import SWALR, AveragedModel +# from torch.utils.data import DataLoader +# from torch.utils.data.distributed import DistributedSampler +# from torch_ema import ExponentialMovingAverage +# from torchmetrics import Metric + +# from . import torch_geometric +# from .checkpoint import CheckpointHandler, CheckpointState +# from .torch_tools import to_numpy +# from .utils import ( +# MetricsLogger, +# compute_mae, +# compute_q95, +# compute_rel_mae, +# compute_rel_rmse, +# compute_rmse, +# ) + + +# @dataclasses.dataclass +# class SWAContainer: +# model: AveragedModel +# scheduler: SWALR +# start: int +# loss_fn: torch.nn.Module + + +# def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): +# eval_metrics["mode"] = "eval" +# eval_metrics["epoch"] = epoch +# logger.log(eval_metrics) +# if epoch is None: +# inintial_phrase = "Initial" +# else: +# inintial_phrase = f"Epoch {epoch}" +# if log_errors == "PerAtomRMSE": +# error_e = eval_metrics["rmse_e_per_atom"] * 1e3 +# error_f = eval_metrics["rmse_f"] * 1e3 +# logging.info( +# f"{inintial_phrase}: loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A" +# ) +# elif ( +# log_errors == "PerAtomRMSEstressvirials" +# and eval_metrics["rmse_stress_per_atom"] is not None +# ): +# error_e = eval_metrics["rmse_e_per_atom"] * 1e3 +# error_f = eval_metrics["rmse_f"] * 1e3 +# error_stress = eval_metrics["rmse_stress_per_atom"] * 1e3 +# logging.info( +# f"{inintial_phrase}: loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_stress_per_atom={error_stress:8.1f} meV / A^3", +# ) +# elif ( +# log_errors == "PerAtomRMSEstressvirials" +# and eval_metrics["rmse_virials_per_atom"] is not None +# ): +# error_e = eval_metrics["rmse_e_per_atom"] * 1e3 +# error_f = eval_metrics["rmse_f"] * 1e3 +# error_virials = eval_metrics["rmse_virials_per_atom"] * 1e3 +# logging.info( +# f"{inintial_phrase}: loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_virials_per_atom={error_virials:8.1f} meV", +# ) +# elif ( +# log_errors == "PerAtomMAEstressvirials" +# and eval_metrics["mae_stress_per_atom"] is not None +# ): +# error_e = eval_metrics["mae_e_per_atom"] * 1e3 +# error_f = eval_metrics["mae_f"] * 1e3 +# error_stress = eval_metrics["mae_stress"] * 1e3 +# logging.info( +# f"{inintial_phrase}: loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A, MAE_stress={error_stress:8.1f} meV / A^3" +# ) +# elif ( +# log_errors == "PerAtomMAEstressvirials" +# and eval_metrics["mae_virials_per_atom"] is not None +# ): +# error_e = eval_metrics["mae_e_per_atom"] * 1e3 +# error_f = eval_metrics["mae_f"] * 1e3 +# error_virials = eval_metrics["mae_virials"] * 1e3 +# logging.info( +# f"{inintial_phrase}: loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A, MAE_virials={error_virials:8.1f} meV" +# ) +# elif log_errors == "TotalRMSE": +# error_e = eval_metrics["rmse_e"] * 1e3 +# error_f = eval_metrics["rmse_f"] * 1e3 +# logging.info( +# f"{inintial_phrase}: loss={valid_loss:8.4f}, 
RMSE_E={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A", +# ) +# elif log_errors == "PerAtomMAE": +# error_e = eval_metrics["mae_e_per_atom"] * 1e3 +# error_f = eval_metrics["mae_f"] * 1e3 +# logging.info( +# f"{inintial_phrase}: loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A", +# ) +# elif log_errors == "TotalMAE": +# error_e = eval_metrics["mae_e"] * 1e3 +# error_f = eval_metrics["mae_f"] * 1e3 +# logging.info( +# f"{inintial_phrase}: loss={valid_loss:8.4f}, MAE_E={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A", +# ) +# elif log_errors == "DipoleRMSE": +# error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 +# logging.info( +# f"{inintial_phrase}: loss={valid_loss:8.4f}, RMSE_MU_per_atom={error_mu:8.2f} mDebye", +# ) +# elif log_errors == "EnergyDipoleRMSE": +# error_e = eval_metrics["rmse_e_per_atom"] * 1e3 +# error_f = eval_metrics["rmse_f"] * 1e3 +# error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 +# logging.info( +# f"{inintial_phrase}: loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_Mu_per_atom={error_mu:8.2f} mDebye", +# ) + + +# def train( +# model: torch.nn.Module, +# loss_fn: torch.nn.Module, +# train_loader: DataLoader, +# valid_loader: Dict[str, DataLoader], +# optimizer: torch.optim.Optimizer, +# lr_scheduler: torch.optim.lr_scheduler.ExponentialLR, +# start_epoch: int, +# max_num_epochs: int, +# patience: int, +# checkpoint_handler: CheckpointHandler, +# logger: MetricsLogger, +# eval_interval: int, +# output_args: Dict[str, bool], +# device: torch.device, +# log_errors: str, +# swa: Optional[SWAContainer] = None, +# ema: Optional[ExponentialMovingAverage] = None, +# max_grad_norm: Optional[float] = 10.0, +# log_wandb: bool = False, +# distributed: bool = False, +# save_all_checkpoints: bool = False, +# distributed_model: Optional[DistributedDataParallel] = None, +# train_sampler: Optional[DistributedSampler] = None, +# rank: Optional[int] = 0, +# ): +# lowest_loss = np.inf +# valid_loss = np.inf +# patience_counter = 0 +# swa_start = True +# keep_last = False +# if log_wandb: +# import wandb + +# if max_grad_norm is not None: +# logging.info(f"Using gradient clipping with tolerance={max_grad_norm:.3f}") + +# logging.info("") +# logging.info("===========TRAINING===========") +# logging.info("Started training, reporting errors on validation set") +# logging.info("Loss metrics on validation set") +# epoch = start_epoch + +# # # log validation loss before _any_ training +# param_context = ema.average_parameters() if ema is not None else nullcontext() +# with param_context: +# valid_loss, eval_metrics = evaluate( +# model=model, +# loss_fn=loss_fn, +# data_loader=valid_loader, +# output_args=output_args, +# device=device, +# ) +# valid_err_log(valid_loss, eval_metrics, logger, log_errors, None) + +# while epoch < max_num_epochs: +# # LR scheduler and SWA update +# if swa is None or epoch < swa.start: +# if epoch > start_epoch: +# lr_scheduler.step( +# metrics=valid_loss +# ) # Can break if exponential LR, TODO fix that! 
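
The else branch below is the "Stage Two" switch: swa.model is a torch.optim.swa_utils.AveragedModel whose weights track a running average, and swa.scheduler is an SWALR instance (see SWAContainer above). Stripped of the surrounding training loop, the pattern looks roughly like this sketch; the epoch counts and learning rates are illustrative only:

    import torch
    from torch.optim.swa_utils import SWALR, AveragedModel

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    swa_model = AveragedModel(model)  # running average of the weights
    swa_scheduler = SWALR(optimizer, swa_lr=5e-4)
    swa_start = 5  # plays the role of swa.start

    for epoch in range(10):
        # ... one epoch of training on model ...
        if epoch >= swa_start:
            swa_model.update_parameters(model)  # fold current weights into the average
            swa_scheduler.step()
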
+# else: +# if swa_start: +# logging.info("Changing loss based on Stage Two Weights") +# lowest_loss = np.inf +# swa_start = False +# keep_last = True +# loss_fn = swa.loss_fn +# swa.model.update_parameters(model) +# if epoch > start_epoch: +# swa.scheduler.step() + +# # Train +# if distributed: +# train_sampler.set_epoch(epoch) +# if "ScheduleFree" in type(optimizer).__name__: +# optimizer.train() +# train_one_epoch( +# model=model, +# loss_fn=loss_fn, +# data_loader=train_loader, +# optimizer=optimizer, +# epoch=epoch, +# output_args=output_args, +# max_grad_norm=max_grad_norm, +# ema=ema, +# logger=logger, +# device=device, +# distributed_model=distributed_model, +# rank=rank, +# ) +# if distributed: +# torch.distributed.barrier() + +# # Validate +# if epoch % eval_interval == 0: +# model_to_evaluate = ( +# model if distributed_model is None else distributed_model +# ) +# param_context = ( +# ema.average_parameters() if ema is not None else nullcontext() +# ) +# if "ScheduleFree" in type(optimizer).__name__: +# optimizer.eval() +# with param_context: +# valid_loss, eval_metrics = evaluate( +# model=model_to_evaluate, +# loss_fn=loss_fn, +# data_loader=valid_loader, +# output_args=output_args, +# device=device, +# ) +# if rank == 0: +# valid_err_log( +# valid_loss, +# eval_metrics, +# logger, +# log_errors, +# epoch, +# ) +# if log_wandb: +# wandb_log_dict = { +# "epoch": epoch, +# "valid_loss": valid_loss, +# "valid_rmse_e_per_atom": eval_metrics["rmse_e_per_atom"], +# "valid_rmse_f": eval_metrics["rmse_f"], +# } +# wandb.log(wandb_log_dict) + +# if valid_loss >= lowest_loss: +# patience_counter += 1 +# if patience_counter >= patience and epoch < swa.start: +# logging.info( +# f"Stopping optimization after {patience_counter} epochs without improvement and starting Stage Two" +# ) +# epoch = swa.start +# elif patience_counter >= patience and epoch >= swa.start: +# logging.info( +# f"Stopping optimization after {patience_counter} epochs without improvement" +# ) +# break +# if save_all_checkpoints: +# param_context = ( +# ema.average_parameters() +# if ema is not None +# else nullcontext() +# ) +# with param_context: +# checkpoint_handler.save( +# state=CheckpointState(model, optimizer, lr_scheduler), +# epochs=epoch, +# keep_last=True, +# ) +# else: +# lowest_loss = valid_loss +# patience_counter = 0 +# param_context = ( +# ema.average_parameters() if ema is not None else nullcontext() +# ) +# with param_context: +# checkpoint_handler.save( +# state=CheckpointState(model, optimizer, lr_scheduler), +# epochs=epoch, +# keep_last=keep_last, +# ) +# keep_last = False or save_all_checkpoints +# if distributed: +# torch.distributed.barrier() +# epoch += 1 + +# logging.info("Training complete") + + +# def train_one_epoch( +# model: torch.nn.Module, +# loss_fn: torch.nn.Module, +# data_loader: DataLoader, +# optimizer: torch.optim.Optimizer, +# epoch: int, +# output_args: Dict[str, bool], +# max_grad_norm: Optional[float], +# ema: Optional[ExponentialMovingAverage], +# logger: MetricsLogger, +# device: torch.device, +# distributed_model: Optional[DistributedDataParallel] = None, +# rank: Optional[int] = 0, +# ) -> None: +# model_to_train = model if distributed_model is None else distributed_model +# for batch in data_loader: +# _, opt_metrics = take_step( +# model=model_to_train, +# loss_fn=loss_fn, +# batch=batch, +# optimizer=optimizer, +# ema=ema, +# output_args=output_args, +# max_grad_norm=max_grad_norm, +# device=device, +# ) +# opt_metrics["mode"] = "opt" +# opt_metrics["epoch"] = epoch 
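
Throughout this loop, logging is gated on rank so that each record is written exactly once, with torch.distributed.barrier() keeping the ranks in lockstep between phases. In isolation the pattern is a few lines; the logger is only assumed here to expose a log(dict) method, matching how MetricsLogger is used above:

    import torch.distributed as dist

    def log_metrics(metrics, rank, logger, distributed=False):
        # Only rank 0 performs the side effect.
        if rank == 0:
            logger.log(metrics)
        # Re-synchronize all ranks before the next phase.
        if distributed:
            dist.barrier()
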
+# if rank == 0: +# logger.log(opt_metrics) + + +# def take_step( +# model: torch.nn.Module, +# loss_fn: torch.nn.Module, +# batch: torch_geometric.batch.Batch, +# optimizer: torch.optim.Optimizer, +# ema: Optional[ExponentialMovingAverage], +# output_args: Dict[str, bool], +# max_grad_norm: Optional[float], +# device: torch.device, +# ) -> Tuple[float, Dict[str, Any]]: +# start_time = time.time() +# batch = batch.to(device) +# optimizer.zero_grad(set_to_none=True) +# batch_dict = batch.to_dict() +# output = model( +# batch_dict, +# training=True, +# compute_force=output_args["forces"], +# compute_virials=output_args["virials"], +# compute_stress=output_args["stress"], +# ) +# loss = loss_fn(pred=output, ref=batch) +# loss.backward() +# if max_grad_norm is not None: +# torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm) +# optimizer.step() + +# if ema is not None: +# ema.update() + +# loss_dict = { +# "loss": to_numpy(loss), +# "time": time.time() - start_time, +# } + +# return loss, loss_dict + + +# def evaluate( +# model: torch.nn.Module, +# loss_fn: torch.nn.Module, +# data_loader: DataLoader, +# output_args: Dict[str, bool], +# device: torch.device, +# ) -> Tuple[float, Dict[str, Any]]: +# for param in model.parameters(): +# param.requires_grad = False + +# metrics = MACELoss(loss_fn=loss_fn).to(device) + +# start_time = time.time() +# for batch in data_loader: +# batch = batch.to(device) +# batch_dict = batch.to_dict() +# output = model( +# batch_dict, +# training=False, +# compute_force=output_args["forces"], +# compute_virials=output_args["virials"], +# compute_stress=output_args["stress"], +# ) +# avg_loss, aux = metrics(batch, output) + +# avg_loss, aux = metrics.compute() +# aux["time"] = time.time() - start_time +# metrics.reset() + +# for param in model.parameters(): +# param.requires_grad = True + +# return avg_loss, aux + + +# class MACELoss(Metric): +# def __init__(self, loss_fn: torch.nn.Module): +# super().__init__() +# self.loss_fn = loss_fn +# self.add_state("total_loss", default=torch.tensor(0.0), dist_reduce_fx="sum") +# self.add_state("num_data", default=torch.tensor(0.0), dist_reduce_fx="sum") +# self.add_state("E_computed", default=torch.tensor(0.0), dist_reduce_fx="sum") +# self.add_state("delta_es", default=[], dist_reduce_fx="cat") +# self.add_state("delta_es_per_atom", default=[], dist_reduce_fx="cat") +# self.add_state("Fs_computed", default=torch.tensor(0.0), dist_reduce_fx="sum") +# self.add_state("fs", default=[], dist_reduce_fx="cat") +# self.add_state("delta_fs", default=[], dist_reduce_fx="cat") +# self.add_state( +# "stress_computed", default=torch.tensor(0.0), dist_reduce_fx="sum" +# ) +# self.add_state("delta_stress", default=[], dist_reduce_fx="cat") +# self.add_state("delta_stress_per_atom", default=[], dist_reduce_fx="cat") +# self.add_state( +# "virials_computed", default=torch.tensor(0.0), dist_reduce_fx="sum" +# ) +# self.add_state("delta_virials", default=[], dist_reduce_fx="cat") +# self.add_state("delta_virials_per_atom", default=[], dist_reduce_fx="cat") +# self.add_state("Mus_computed", default=torch.tensor(0.0), dist_reduce_fx="sum") +# self.add_state("mus", default=[], dist_reduce_fx="cat") +# self.add_state("delta_mus", default=[], dist_reduce_fx="cat") +# self.add_state("delta_mus_per_atom", default=[], dist_reduce_fx="cat") + +# def update(self, batch, output): # pylint: disable=arguments-differ +# loss = self.loss_fn(pred=output, ref=batch) +# self.total_loss += loss +# self.num_data += batch.num_graphs + +# if 
output.get("energy") is not None and batch.energy is not None: +# self.E_computed += 1.0 +# self.delta_es.append(batch.energy - output["energy"]) +# self.delta_es_per_atom.append( +# (batch.energy - output["energy"]) / (batch.ptr[1:] - batch.ptr[:-1]) +# ) +# if output.get("forces") is not None and batch.forces is not None: +# self.Fs_computed += 1.0 +# self.fs.append(batch.forces) +# self.delta_fs.append(batch.forces - output["forces"]) +# if output.get("stress") is not None and batch.stress is not None: +# self.stress_computed += 1.0 +# self.delta_stress.append(batch.stress - output["stress"]) +# self.delta_stress_per_atom.append( +# (batch.stress - output["stress"]) +# / (batch.ptr[1:] - batch.ptr[:-1]).view(-1, 1, 1) +# ) +# if output.get("virials") is not None and batch.virials is not None: +# self.virials_computed += 1.0 +# self.delta_virials.append(batch.virials - output["virials"]) +# self.delta_virials_per_atom.append( +# (batch.virials - output["virials"]) +# / (batch.ptr[1:] - batch.ptr[:-1]).view(-1, 1, 1) +# ) +# if output.get("dipole") is not None and batch.dipole is not None: +# self.Mus_computed += 1.0 +# self.mus.append(batch.dipole) +# self.delta_mus.append(batch.dipole - output["dipole"]) +# self.delta_mus_per_atom.append( +# (batch.dipole - output["dipole"]) +# / (batch.ptr[1:] - batch.ptr[:-1]).unsqueeze(-1) +# ) + +# def convert(self, delta: Union[torch.Tensor, List[torch.Tensor]]) -> np.ndarray: +# if isinstance(delta, list): +# delta = torch.cat(delta) +# return to_numpy(delta) + +# def compute(self): +# aux = {} +# aux["loss"] = to_numpy(self.total_loss / self.num_data).item() +# if self.E_computed: +# delta_es = self.convert(self.delta_es) +# delta_es_per_atom = self.convert(self.delta_es_per_atom) +# aux["mae_e"] = compute_mae(delta_es) +# aux["mae_e_per_atom"] = compute_mae(delta_es_per_atom) +# aux["rmse_e"] = compute_rmse(delta_es) +# aux["rmse_e_per_atom"] = compute_rmse(delta_es_per_atom) +# aux["q95_e"] = compute_q95(delta_es) +# if self.Fs_computed: +# fs = self.convert(self.fs) +# delta_fs = self.convert(self.delta_fs) +# aux["mae_f"] = compute_mae(delta_fs) +# aux["rel_mae_f"] = compute_rel_mae(delta_fs, fs) +# aux["rmse_f"] = compute_rmse(delta_fs) +# aux["rel_rmse_f"] = compute_rel_rmse(delta_fs, fs) +# aux["q95_f"] = compute_q95(delta_fs) +# if self.stress_computed: +# delta_stress = self.convert(self.delta_stress) +# delta_stress_per_atom = self.convert(self.delta_stress_per_atom) +# aux["mae_stress"] = compute_mae(delta_stress) +# aux["rmse_stress"] = compute_rmse(delta_stress) +# aux["rmse_stress_per_atom"] = compute_rmse(delta_stress_per_atom) +# aux["q95_stress"] = compute_q95(delta_stress) +# if self.virials_computed: +# delta_virials = self.convert(self.delta_virials) +# delta_virials_per_atom = self.convert(self.delta_virials_per_atom) +# aux["mae_virials"] = compute_mae(delta_virials) +# aux["rmse_virials"] = compute_rmse(delta_virials) +# aux["rmse_virials_per_atom"] = compute_rmse(delta_virials_per_atom) +# aux["q95_virials"] = compute_q95(delta_virials) +# if self.Mus_computed: +# mus = self.convert(self.mus) +# delta_mus = self.convert(self.delta_mus) +# delta_mus_per_atom = self.convert(self.delta_mus_per_atom) +# aux["mae_mu"] = compute_mae(delta_mus) +# aux["mae_mu_per_atom"] = compute_mae(delta_mus_per_atom) +# aux["rel_mae_mu"] = compute_rel_mae(delta_mus, mus) +# aux["rmse_mu"] = compute_rmse(delta_mus) +# aux["rmse_mu_per_atom"] = compute_rmse(delta_mus_per_atom) +# aux["rel_rmse_mu"] = compute_rel_rmse(delta_mus, mus) +# 
aux["q95_mu"] = compute_q95(delta_mus) + +# return aux["loss"], aux diff --git a/tests/test_forces_equivariant.py b/tests/test_forces_equivariant.py new file mode 100644 index 000000000..8573312ec --- /dev/null +++ b/tests/test_forces_equivariant.py @@ -0,0 +1,28 @@ +############################################################################## +# Copyright (c) 2024, Oak Ridge National Laboratory # +# All rights reserved. # +# # +# This file is part of HydraGNN and is distributed under a BSD 3-clause # +# license. For the licensing terms see the LICENSE file in the top-level # +# directory. # +# # +# SPDX-License-Identifier: BSD-3-Clause # +############################################################################## + +import os +import pytest + +import subprocess + + +@pytest.mark.parametrize("example", ["LennardJones"]) +# @pytest.mark.parametrize("model_type", ["SchNet", "EGNN", "DimeNet", "PNAPlus", "MACE"]) +@pytest.mark.parametrize("model_type", ["MACE"]) +@pytest.mark.mpi_skip() +def pytest_examples(example, model_type): + path = os.path.join(os.path.dirname(__file__), "..", "examples", example) + file_path = os.path.join(path, example + ".py") # Assuming different model scripts + return_code = subprocess.call(["python", file_path, "--model_type", model_type]) + + # Check the file ran without error. + assert return_code == 0 From 61730a650cd8719abd9a0354b2aaa60283ca0fbd Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Thu, 26 Sep 2024 13:41:35 -0400 Subject: [PATCH 21/51] Making tests run faster and adjusting requirements --- examples/LennardJones/LJ.json | 10 +++++---- requirements-torch.txt | 4 ++-- requirements-torch2.txt | 3 +++ requirements.txt | 3 --- tests/test_graphs.py | 41 +++++++++++++++++++---------------- 5 files changed, 33 insertions(+), 28 deletions(-) create mode 100644 requirements-torch2.txt diff --git a/examples/LennardJones/LJ.json b/examples/LennardJones/LJ.json index a6b18f12b..05b3c6cb6 100644 --- a/examples/LennardJones/LJ.json +++ b/examples/LennardJones/LJ.json @@ -30,10 +30,12 @@ "num_before_skip": 1, "num_after_skip": 1, "envelope_exponent": 5, + "max_ell": 1, + "node_max_ell": 1, "num_radial": 5, "num_spherical": 2, - "hidden_dim": 20, - "num_conv_layers": 4, + "hidden_dim": 2, + "num_conv_layers": 2, "output_heads": { "node": { "num_headlayers": 2, @@ -55,9 +57,9 @@ "output_names": ["graph_energy"] }, "Training": { - "num_epoch": 15, + "num_epoch": 2, "batch_size": 64, - "perc_train": 0.7, + "perc_train": 0.1, "patience": 20, "early_stopping": true, "Optimizer": { diff --git a/requirements-torch.txt b/requirements-torch.txt index 41211d9e3..b4b14066d 100644 --- a/requirements-torch.txt +++ b/requirements-torch.txt @@ -1,3 +1,3 @@ -torch==2.0.1 -torchvision +torch==2.0.1 +torchvision torchaudio diff --git a/requirements-torch2.txt b/requirements-torch2.txt new file mode 100644 index 000000000..9fb3c2dc1 --- /dev/null +++ b/requirements-torch2.txt @@ -0,0 +1,3 @@ +e3nn +torch-ema +torchmetrics diff --git a/requirements.txt b/requirements.txt index 72f254d8e..b8107fd54 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,4 @@ tqdm tensorboard psutil sympy -e3nn matscipy -torch-ema -torchmetrics diff --git a/tests/test_graphs.py b/tests/test_graphs.py index ea3c0c333..755b47668 100755 --- a/tests/test_graphs.py +++ b/tests/test_graphs.py @@ -197,16 +197,16 @@ def unittest_train_model( @pytest.mark.parametrize( "model_type", [ - "SAGE", - "GIN", - "GAT", - "MFC", - "PNA", - "PNAPlus", - "CGCNN", - "SchNet", - "DimeNet", - "EGNN", + # "SAGE", + 
# "GIN", + # "GAT", + # "MFC", + # "PNA", + # "PNAPlus", + # "CGCNN", + # "SchNet", + # "DimeNet", + # "EGNN", "MACE", ], ) @@ -217,30 +217,33 @@ def pytest_train_model(model_type, ci_input, overwrite_data=False): # Test only models @pytest.mark.parametrize( - "model_type", ["PNA", "PNAPlus", "CGCNN", "SchNet", "EGNN", "MACE"] + # "model_type", ["PNA", "PNAPlus", "CGCNN", "SchNet", "EGNN", "MACE"] + "model_type", ["MACE"] ) def pytest_train_model_lengths(model_type, overwrite_data=False): unittest_train_model(model_type, "ci.json", True, overwrite_data) # Test across equivariant models -@pytest.mark.parametrize("model_type", ["EGNN", "SchNet", "MACE"]) +# @pytest.mark.parametrize("model_type", ["EGNN", "SchNet", "MACE"]) +@pytest.mark.parametrize("model_type", ["MACE"]) def pytest_train_equivariant_model(model_type, overwrite_data=False): unittest_train_model(model_type, "ci_equivariant.json", False, overwrite_data) # Test vector output -@pytest.mark.parametrize("model_type", ["PNA", "PNAPlus", "MACE"]) +# @pytest.mark.parametrize("model_type", ["PNA", "PNAPlus", "MACE"]) +@pytest.mark.parametrize("model_type", ["MACE"]) def pytest_train_model_vectoroutput(model_type, overwrite_data=False): unittest_train_model(model_type, "ci_vectoroutput.json", True, overwrite_data) -@pytest.mark.parametrize( - "model_type", - ["SAGE", "GIN", "GAT", "MFC", "PNA", "PNAPlus", "SchNet", "DimeNet", "EGNN"], -) -def pytest_train_model_conv_head(model_type, overwrite_data=False): - unittest_train_model(model_type, "ci_conv_head.json", False, overwrite_data) +# @pytest.mark.parametrize( +# "model_type", +# ["SAGE", "GIN", "GAT", "MFC", "PNA", "PNAPlus", "SchNet", "DimeNet", "EGNN"], +# ) +# def pytest_train_model_conv_head(model_type, overwrite_data=False): +# unittest_train_model(model_type, "ci_conv_head.json", False, overwrite_data) # def debug_train_model_vectoroutput(model_type="MACE", overwrite_data=False): From d7652e3b641e99bef64168c308aaa7c81af93e5e Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Thu, 26 Sep 2024 13:44:11 -0400 Subject: [PATCH 22/51] need to install new requirements file --- .github/workflows/CI.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index d148ee2bb..18d7823fe 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -41,6 +41,7 @@ jobs: python -m pip install --upgrade pip python -m pip install --upgrade -r requirements.txt -r requirements-dev.txt python -m pip install --upgrade -r requirements-torch.txt --index-url https://download.pytorch.org/whl/cpu + python -m pip install --upgrade -r requirements-torch2.txt python -m pip install --upgrade -r requirements-pyg.txt --find-links https://data.pyg.org/whl/torch-2.0.1+cpu.html python -m pip install --upgrade -r requirements-deepspeed.txt - name: Format black From e7a5d2fedd8574e929dedd637a343cc57301596f Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Thu, 26 Sep 2024 13:44:40 -0400 Subject: [PATCH 23/51] formatting --- hydragnn/utils/mace_utils/modules/__init__.py | 2 ++ hydragnn/utils/mace_utils/tools/__init__.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/hydragnn/utils/mace_utils/modules/__init__.py b/hydragnn/utils/mace_utils/modules/__init__.py index 2c8bb160a..a9b38f67b 100644 --- a/hydragnn/utils/mace_utils/modules/__init__.py +++ b/hydragnn/utils/mace_utils/modules/__init__.py @@ -20,6 +20,7 @@ # ResidualElementDependentInteractionBlock, ScaleShiftBlock, ) + # from .loss import ( # DipoleSingleLoss, # UniversalLoss, @@ -40,6 +41,7 @@ # ) 
from .radial import BesselBasis, GaussianBasis, PolynomialCutoff, ZBLBasis from .symmetric_contraction import SymmetricContraction + # from .utils import ( # compute_avg_num_neighbors, # compute_fixed_charge_dipole, diff --git a/hydragnn/utils/mace_utils/tools/__init__.py b/hydragnn/utils/mace_utils/tools/__init__.py index ce89f5967..8d1e7bc22 100644 --- a/hydragnn/utils/mace_utils/tools/__init__.py +++ b/hydragnn/utils/mace_utils/tools/__init__.py @@ -1,6 +1,7 @@ # from .arg_parser import build_default_arg_parser, build_preprocess_arg_parser # from .arg_parser_tools import check_args from .cg import U_matrix_real + # from .checkpoint import CheckpointHandler, CheckpointIO, CheckpointState from .finetuning_utils import load_foundations from .torch_tools import ( @@ -16,6 +17,7 @@ to_one_hot, voigt_to_matrix, ) + # from .train import SWAContainer, evaluate, train # from .utils import ( # AtomicNumberTable, From 076aafa2c80ddafdb2997a6c3341860b1066b081 Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Thu, 26 Sep 2024 13:48:53 -0400 Subject: [PATCH 24/51] formatting --- tests/test_graphs.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_graphs.py b/tests/test_graphs.py index 755b47668..2528fe4c5 100755 --- a/tests/test_graphs.py +++ b/tests/test_graphs.py @@ -218,7 +218,8 @@ def pytest_train_model(model_type, ci_input, overwrite_data=False): # Test only models @pytest.mark.parametrize( # "model_type", ["PNA", "PNAPlus", "CGCNN", "SchNet", "EGNN", "MACE"] - "model_type", ["MACE"] + "model_type", + ["MACE"], ) def pytest_train_model_lengths(model_type, overwrite_data=False): unittest_train_model(model_type, "ci.json", True, overwrite_data) From a12ce417005808a83004d94a4d200d200cac4a2f Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Thu, 26 Sep 2024 14:19:51 -0400 Subject: [PATCH 25/51] debugging for GitHub test --- tests/test_model_loadpred.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/test_model_loadpred.py b/tests/test_model_loadpred.py index a8d650b43..af7b79539 100755 --- a/tests/test_model_loadpred.py +++ b/tests/test_model_loadpred.py @@ -12,7 +12,7 @@ import torch import random import hydragnn -from .test_graphs import unittest_train_model +from tests.test_graphs import unittest_train_model def unittest_model_prediction(config): @@ -81,11 +81,15 @@ def pytest_model_loadpred(): else: with open(config_file, "r") as f: config = json.load(f) + print("\nFIRST CHECK ON INPUT DIM:") + print(config["NeuralNetwork"]["Architecture"]["input_dim"], "\n") for dataset_name, raw_data_path in config["Dataset"]["path"].items(): if not os.path.isfile(raw_data_path): print(dataset_name, "datasets not found: ", raw_data_path) case_exist = False break + print("\nSECOND CHECK ON INPUT DIM:") + print(config["NeuralNetwork"]["Architecture"]["input_dim"], "\n") if not case_exist: unittest_train_model( config["NeuralNetwork"]["Architecture"]["model_type"], @@ -93,4 +97,10 @@ def pytest_model_loadpred(): False, False, ) + print("\nTHIRD CHECK ON INPUT DIM:") + print(config["NeuralNetwork"]["Architecture"]["input_dim"], "\n") unittest_model_prediction(config) + + +if __name__ == "__main__": + pytest_model_loadpred() \ No newline at end of file From f32d0404589e9b33f747a5b8fa03672a35b6d5ff Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Thu, 26 Sep 2024 14:28:22 -0400 Subject: [PATCH 26/51] formatting --- tests/test_model_loadpred.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_model_loadpred.py 
b/tests/test_model_loadpred.py index af7b79539..df35f120f 100755 --- a/tests/test_model_loadpred.py +++ b/tests/test_model_loadpred.py @@ -102,5 +102,5 @@ def pytest_model_loadpred(): unittest_model_prediction(config) -if __name__ == "__main__": - pytest_model_loadpred() \ No newline at end of file +# if __name__ == "__main__": +# pytest_model_loadpred() From 31164166b29e8212d346132a42ded448a68d1d18 Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Thu, 26 Sep 2024 15:25:59 -0400 Subject: [PATCH 27/51] Update config to avoid key error on input_dim --- tests/test_model_loadpred.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/test_model_loadpred.py b/tests/test_model_loadpred.py index df35f120f..74d84b7dc 100755 --- a/tests/test_model_loadpred.py +++ b/tests/test_model_loadpred.py @@ -13,6 +13,7 @@ import random import hydragnn from tests.test_graphs import unittest_train_model +from hydragnn.utils.config_utils import update_config def unittest_model_prediction(config): @@ -23,6 +24,7 @@ def unittest_model_prediction(config): val_loader, test_loader, ) = hydragnn.preprocess.load_data.dataset_loading_and_splitting(config=config) + config = update_config(config, train_loader, val_loader, test_loader) model = hydragnn.models.create.create_model_config( config=config["NeuralNetwork"], @@ -81,15 +83,11 @@ def pytest_model_loadpred(): else: with open(config_file, "r") as f: config = json.load(f) - print("\nFIRST CHECK ON INPUT DIM:") - print(config["NeuralNetwork"]["Architecture"]["input_dim"], "\n") for dataset_name, raw_data_path in config["Dataset"]["path"].items(): if not os.path.isfile(raw_data_path): print(dataset_name, "datasets not found: ", raw_data_path) case_exist = False break - print("\nSECOND CHECK ON INPUT DIM:") - print(config["NeuralNetwork"]["Architecture"]["input_dim"], "\n") if not case_exist: unittest_train_model( config["NeuralNetwork"]["Architecture"]["model_type"], @@ -97,8 +95,6 @@ def pytest_model_loadpred(): False, False, ) - print("\nTHIRD CHECK ON INPUT DIM:") - print(config["NeuralNetwork"]["Architecture"]["input_dim"], "\n") unittest_model_prediction(config) From dbc6ec99c957887bdedd95695aa6a325e8f172ba Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Thu, 26 Sep 2024 17:48:06 -0400 Subject: [PATCH 28/51] Removing Unnecessary MACE files (draft 1) --- hydragnn/models/MACEStack.py | 3 +- hydragnn/utils/mace_utils/data/__init__.py | 34 - hydragnn/utils/mace_utils/data/atomic_data.py | 227 ---- .../utils/mace_utils/data/hdf5_dataset.py | 91 -- .../utils/mace_utils/data/neighborhood.py | 66 - hydragnn/utils/mace_utils/data/utils.py | 398 ------ hydragnn/utils/mace_utils/modules/__init__.py | 51 - hydragnn/utils/mace_utils/modules/loss.py | 367 ------ hydragnn/utils/mace_utils/modules/models.py | 1065 ----------------- hydragnn/utils/mace_utils/modules/utils.py | 351 ------ tests/test_forces_equivariant.py | 4 + 11 files changed, 6 insertions(+), 2651 deletions(-) delete mode 100644 hydragnn/utils/mace_utils/data/__init__.py delete mode 100644 hydragnn/utils/mace_utils/data/atomic_data.py delete mode 100644 hydragnn/utils/mace_utils/data/hdf5_dataset.py delete mode 100644 hydragnn/utils/mace_utils/data/neighborhood.py delete mode 100644 hydragnn/utils/mace_utils/data/utils.py delete mode 100644 hydragnn/utils/mace_utils/modules/loss.py delete mode 100644 hydragnn/utils/mace_utils/modules/models.py diff --git a/hydragnn/models/MACEStack.py b/hydragnn/models/MACEStack.py index aaeffee9d..c09fd39b8 100644 --- a/hydragnn/models/MACEStack.py 
+++ b/hydragnn/models/MACEStack.py @@ -12,6 +12,7 @@ # Adapted From: # GitHub: https://github.com/ACEsuit/mace # ArXiV: https://arxiv.org/pdf/2206.07697 +# Date: August 27, 2024 | 12:37 (EST) ########################################################################################### # Implementation of MACE models and other models based E(3)-Equivariant MPNNs # Authors: Ilyes Batatia, Gregor Simm @@ -414,7 +415,7 @@ def _conv_args(self, data): # Create node_attrs from atomic numbers. Later on it may contain more information ## Node attrs are intrinsic properties of the atoms, like charge, atomic number, etc.. - ## data.node_attrs is already used as a method or smt in another place, so has been renamed to data.node_attributes from MACE and same with other data variable names + ## data.node_attrs is already used in another place, so has been renamed to data.node_attributes from MACE and same with other data variable names one_hot = torch.nn.functional.one_hot( data["x"].long().squeeze(-1), num_classes=118 ).float() # [n_atoms, 118] ## 118 atoms in the peridoic table diff --git a/hydragnn/utils/mace_utils/data/__init__.py b/hydragnn/utils/mace_utils/data/__init__.py deleted file mode 100644 index ace87d766..000000000 --- a/hydragnn/utils/mace_utils/data/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -# from .atomic_data import AtomicData -# from .hdf5_dataset import HDF5Dataset, dataset_from_sharded_hdf5 -# from .neighborhood import get_neighborhood -# from .utils import ( -# Configuration, -# Configurations, -# compute_average_E0s, -# config_from_atoms, -# config_from_atoms_list, -# load_from_xyz, -# random_train_valid_split, -# save_AtomicData_to_HDF5, -# save_configurations_as_HDF5, -# save_dataset_as_HDF5, -# test_config_types, -# ) - -# __all__ = [ -# "get_neighborhood", -# "Configuration", -# "Configurations", -# "random_train_valid_split", -# "load_from_xyz", -# "test_config_types", -# "config_from_atoms", -# "config_from_atoms_list", -# "AtomicData", -# "compute_average_E0s", -# "save_dataset_as_HDF5", -# "HDF5Dataset", -# "dataset_from_sharded_hdf5", -# "save_AtomicData_to_HDF5", -# "save_configurations_as_HDF5", -# ] diff --git a/hydragnn/utils/mace_utils/data/atomic_data.py b/hydragnn/utils/mace_utils/data/atomic_data.py deleted file mode 100644 index ded8d438d..000000000 --- a/hydragnn/utils/mace_utils/data/atomic_data.py +++ /dev/null @@ -1,227 +0,0 @@ -# ########################################################################################### -# # Atomic Data Class for handling molecules as graphs -# # Authors: Ilyes Batatia, Gregor Simm -# # This program is distributed under the MIT License (see MIT.md) -# ########################################################################################### - -# from typing import Optional, Sequence - -# import torch.utils.data - -# from hydragnn.utils.mace_utils.tools import ( -# AtomicNumberTable, -# atomic_numbers_to_indices, -# to_one_hot, -# torch_geometric, -# voigt_to_matrix, -# ) - -# from .neighborhood import get_neighborhood -# from .utils import Configuration - - -# class AtomicData(torch_geometric.data.Data): -# num_graphs: torch.Tensor -# batch: torch.Tensor -# edge_index: torch.Tensor -# node_attrs: torch.Tensor -# edge_vectors: torch.Tensor -# edge_lengths: torch.Tensor -# positions: torch.Tensor -# shifts: torch.Tensor -# unit_shifts: torch.Tensor -# cell: torch.Tensor -# forces: torch.Tensor -# energy: torch.Tensor -# stress: torch.Tensor -# virials: torch.Tensor -# dipole: torch.Tensor -# charges: torch.Tensor -# weight: 
torch.Tensor -# energy_weight: torch.Tensor -# forces_weight: torch.Tensor -# stress_weight: torch.Tensor -# virials_weight: torch.Tensor - -# def __init__( -# self, -# edge_index: torch.Tensor, # [2, n_edges] -# node_attrs: torch.Tensor, # [n_nodes, n_node_feats] -# positions: torch.Tensor, # [n_nodes, 3] -# shifts: torch.Tensor, # [n_edges, 3], -# unit_shifts: torch.Tensor, # [n_edges, 3] -# cell: Optional[torch.Tensor], # [3,3] -# weight: Optional[torch.Tensor], # [,] -# energy_weight: Optional[torch.Tensor], # [,] -# forces_weight: Optional[torch.Tensor], # [,] -# stress_weight: Optional[torch.Tensor], # [,] -# virials_weight: Optional[torch.Tensor], # [,] -# forces: Optional[torch.Tensor], # [n_nodes, 3] -# energy: Optional[torch.Tensor], # [, ] -# stress: Optional[torch.Tensor], # [1,3,3] -# virials: Optional[torch.Tensor], # [1,3,3] -# dipole: Optional[torch.Tensor], # [, 3] -# charges: Optional[torch.Tensor], # [n_nodes, ] -# ): -# # Check shapes -# num_nodes = node_attrs.shape[0] - -# assert edge_index.shape[0] == 2 and len(edge_index.shape) == 2 -# assert positions.shape == (num_nodes, 3) -# assert shifts.shape[1] == 3 -# assert unit_shifts.shape[1] == 3 -# assert len(node_attrs.shape) == 2 -# assert weight is None or len(weight.shape) == 0 -# assert energy_weight is None or len(energy_weight.shape) == 0 -# assert forces_weight is None or len(forces_weight.shape) == 0 -# assert stress_weight is None or len(stress_weight.shape) == 0 -# assert virials_weight is None or len(virials_weight.shape) == 0 -# assert cell is None or cell.shape == (3, 3) -# assert forces is None or forces.shape == (num_nodes, 3) -# assert energy is None or len(energy.shape) == 0 -# assert stress is None or stress.shape == (1, 3, 3) -# assert virials is None or virials.shape == (1, 3, 3) -# assert dipole is None or dipole.shape[-1] == 3 -# assert charges is None or charges.shape == (num_nodes,) -# # Aggregate data -# data = { -# "num_nodes": num_nodes, -# "edge_index": edge_index, -# "positions": positions, -# "shifts": shifts, -# "unit_shifts": unit_shifts, -# "cell": cell, -# "node_attrs": node_attrs, -# "weight": weight, -# "energy_weight": energy_weight, -# "forces_weight": forces_weight, -# "stress_weight": stress_weight, -# "virials_weight": virials_weight, -# "forces": forces, -# "energy": energy, -# "stress": stress, -# "virials": virials, -# "dipole": dipole, -# "charges": charges, -# } -# super().__init__(**data) - -# @classmethod -# def from_config( -# cls, config: Configuration, z_table: AtomicNumberTable, cutoff: float -# ) -> "AtomicData": -# edge_index, shifts, unit_shifts = get_neighborhood( -# positions=config.positions, cutoff=cutoff, pbc=config.pbc, cell=config.cell -# ) -# indices = atomic_numbers_to_indices(config.atomic_numbers, z_table=z_table) -# one_hot = to_one_hot( -# torch.tensor(indices, dtype=torch.long).unsqueeze(-1), -# num_classes=len(z_table), -# ) - -# cell = ( -# torch.tensor(config.cell, dtype=torch.get_default_dtype()) -# if config.cell is not None -# else torch.tensor( -# 3 * [0.0, 0.0, 0.0], dtype=torch.get_default_dtype() -# ).view(3, 3) -# ) - -# weight = ( -# torch.tensor(config.weight, dtype=torch.get_default_dtype()) -# if config.weight is not None -# else 1 -# ) - -# energy_weight = ( -# torch.tensor(config.energy_weight, dtype=torch.get_default_dtype()) -# if config.energy_weight is not None -# else 1 -# ) - -# forces_weight = ( -# torch.tensor(config.forces_weight, dtype=torch.get_default_dtype()) -# if config.forces_weight is not None -# else 1 -# ) - -# 
stress_weight = ( -# torch.tensor(config.stress_weight, dtype=torch.get_default_dtype()) -# if config.stress_weight is not None -# else 1 -# ) - -# virials_weight = ( -# torch.tensor(config.virials_weight, dtype=torch.get_default_dtype()) -# if config.virials_weight is not None -# else 1 -# ) - -# forces = ( -# torch.tensor(config.forces, dtype=torch.get_default_dtype()) -# if config.forces is not None -# else None -# ) -# energy = ( -# torch.tensor(config.energy, dtype=torch.get_default_dtype()) -# if config.energy is not None -# else None -# ) -# stress = ( -# voigt_to_matrix( -# torch.tensor(config.stress, dtype=torch.get_default_dtype()) -# ).unsqueeze(0) -# if config.stress is not None -# else None -# ) -# virials = ( -# voigt_to_matrix( -# torch.tensor(config.virials, dtype=torch.get_default_dtype()) -# ).unsqueeze(0) -# if config.virials is not None -# else None -# ) -# dipole = ( -# torch.tensor(config.dipole, dtype=torch.get_default_dtype()).unsqueeze(0) -# if config.dipole is not None -# else None -# ) -# charges = ( -# torch.tensor(config.charges, dtype=torch.get_default_dtype()) -# if config.charges is not None -# else None -# ) - -# return cls( -# edge_index=torch.tensor(edge_index, dtype=torch.long), -# positions=torch.tensor(config.positions, dtype=torch.get_default_dtype()), -# shifts=torch.tensor(shifts, dtype=torch.get_default_dtype()), -# unit_shifts=torch.tensor(unit_shifts, dtype=torch.get_default_dtype()), -# cell=cell, -# node_attrs=one_hot, -# weight=weight, -# energy_weight=energy_weight, -# forces_weight=forces_weight, -# stress_weight=stress_weight, -# virials_weight=virials_weight, -# forces=forces, -# energy=energy, -# stress=stress, -# virials=virials, -# dipole=dipole, -# charges=charges, -# ) - - -# def get_data_loader( -# dataset: Sequence[AtomicData], -# batch_size: int, -# shuffle=True, -# drop_last=False, -# ) -> torch.utils.data.DataLoader: -# return torch_geometric.dataloader.DataLoader( -# dataset=dataset, -# batch_size=batch_size, -# shuffle=shuffle, -# drop_last=drop_last, -# ) diff --git a/hydragnn/utils/mace_utils/data/hdf5_dataset.py b/hydragnn/utils/mace_utils/data/hdf5_dataset.py deleted file mode 100644 index f617e02f5..000000000 --- a/hydragnn/utils/mace_utils/data/hdf5_dataset.py +++ /dev/null @@ -1,91 +0,0 @@ -# from glob import glob -# from typing import List - -# from torch.utils.data import ConcatDataset, Dataset - -# # Try import but pass otherwise -# try: -# import h5py -# except ImportError: -# pass - -# from hydragnn.utils.mace_utils.data.atomic_data import AtomicData -# from hydragnn.utils.mace_utils.data.utils import Configuration -# from hydragnn.utils.mace_utils.tools.utils import AtomicNumberTable - - -# class HDF5Dataset(Dataset): -# def __init__(self, file_path, r_max, z_table, **kwargs): -# super(HDF5Dataset, self).__init__() # pylint: disable=super-with-arguments -# self.file_path = file_path -# self._file = None -# batch_key = list(self.file.keys())[0] -# self.batch_size = len(self.file[batch_key].keys()) -# self.length = len(self.file.keys()) * self.batch_size -# self.r_max = r_max -# self.z_table = z_table -# try: -# self.drop_last = bool(self.file.attrs["drop_last"]) -# except KeyError: -# self.drop_last = False -# self.kwargs = kwargs - -# @property -# def file(self): -# if self._file is None: -# # If a file has not already been opened, open one here -# self._file = h5py.File(self.file_path, "r") -# return self._file - -# def __getstate__(self): -# _d = dict(self.__dict__) - -# # An opened h5py.File cannot be pickled, 
so we must exclude it from the state -# _d["_file"] = None -# return _d - -# def __len__(self): -# return self.length - -# def __getitem__(self, index): -# # compute the index of the batch -# batch_index = index // self.batch_size -# config_index = index % self.batch_size -# grp = self.file["config_batch_" + str(batch_index)] -# subgrp = grp["config_" + str(config_index)] -# config = Configuration( -# atomic_numbers=subgrp["atomic_numbers"][()], -# positions=subgrp["positions"][()], -# energy=unpack_value(subgrp["energy"][()]), -# forces=unpack_value(subgrp["forces"][()]), -# stress=unpack_value(subgrp["stress"][()]), -# virials=unpack_value(subgrp["virials"][()]), -# dipole=unpack_value(subgrp["dipole"][()]), -# charges=unpack_value(subgrp["charges"][()]), -# weight=unpack_value(subgrp["weight"][()]), -# energy_weight=unpack_value(subgrp["energy_weight"][()]), -# forces_weight=unpack_value(subgrp["forces_weight"][()]), -# stress_weight=unpack_value(subgrp["stress_weight"][()]), -# virials_weight=unpack_value(subgrp["virials_weight"][()]), -# config_type=unpack_value(subgrp["config_type"][()]), -# pbc=unpack_value(subgrp["pbc"][()]), -# cell=unpack_value(subgrp["cell"][()]), -# ) -# atomic_data = AtomicData.from_config( -# config, z_table=self.z_table, cutoff=self.r_max -# ) -# return atomic_data - - -# def dataset_from_sharded_hdf5(files: List, z_table: AtomicNumberTable, r_max: float): -# files = glob(files + "/*") -# datasets = [] -# for file in files: -# datasets.append(HDF5Dataset(file, z_table=z_table, r_max=r_max)) -# full_dataset = ConcatDataset(datasets) -# return full_dataset - - -# def unpack_value(value): -# value = value.decode("utf-8") if isinstance(value, bytes) else value -# return None if str(value) == "None" else value diff --git a/hydragnn/utils/mace_utils/data/neighborhood.py b/hydragnn/utils/mace_utils/data/neighborhood.py deleted file mode 100644 index 5bd70b6eb..000000000 --- a/hydragnn/utils/mace_utils/data/neighborhood.py +++ /dev/null @@ -1,66 +0,0 @@ -# from typing import Optional, Tuple - -# import numpy as np -# from matscipy.neighbours import neighbour_list - - -# def get_neighborhood( -# positions: np.ndarray, # [num_positions, 3] -# cutoff: float, -# pbc: Optional[Tuple[bool, bool, bool]] = None, -# cell: Optional[np.ndarray] = None, # [3, 3] -# true_self_interaction=False, -# ) -> Tuple[np.ndarray, np.ndarray]: -# if pbc is None: -# pbc = (False, False, False) - -# if cell is None or cell.any() == np.zeros((3, 3)).any(): -# cell = np.identity(3, dtype=float) - -# assert len(pbc) == 3 and all(isinstance(i, (bool, np.bool_)) for i in pbc) -# assert cell.shape == (3, 3) - -# pbc_x = pbc[0] -# pbc_y = pbc[1] -# pbc_z = pbc[2] -# identity = np.identity(3, dtype=float) -# max_positions = np.max(np.absolute(positions)) + 1 -# # Extend cell in non-periodic directions -# # For models with more than 5 layers, the multiplicative constant needs to be increased. 
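
The extension step flagged in the comment above exists because matscipy's neighbour_list always works with a cell: along each non-periodic axis, the cell vector is inflated far beyond any atomic position plus the cutoff, so no wrapped periodic image can land inside the cutoff sphere. Condensed into a standalone form (the same logic as the deleted lines below, including the factor of 5; the helper name is illustrative):

    import numpy as np

    def pad_cell_nonperiodic(cell, positions, cutoff, pbc):
        # Replace each non-periodic cell vector with one long enough that
        # wrapped neighbor images cannot fall within the cutoff.
        max_positions = np.max(np.absolute(positions)) + 1
        identity = np.identity(3, dtype=float)
        temp_cell = np.copy(cell)
        for axis in range(3):
            if not pbc[axis]:
                temp_cell[axis, :] = max_positions * 5 * cutoff * identity[axis, :]
        return temp_cell
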
-# temp_cell = np.copy(cell) -# if not pbc_x: -# temp_cell[0, :] = max_positions * 5 * cutoff * identity[0, :] -# if not pbc_y: -# temp_cell[1, :] = max_positions * 5 * cutoff * identity[1, :] -# if not pbc_z: -# temp_cell[2, :] = max_positions * 5 * cutoff * identity[2, :] - -# sender, receiver, unit_shifts = neighbour_list( -# quantities="ijS", -# pbc=pbc, -# cell=temp_cell, -# positions=positions, -# cutoff=cutoff, -# # self_interaction=True, # we want edges from atom to itself in different periodic images -# # use_scaled_positions=False, # positions are not scaled positions -# ) - -# if not true_self_interaction: -# # Eliminate self-edges that don't cross periodic boundaries -# true_self_edge = sender == receiver -# true_self_edge &= np.all(unit_shifts == 0, axis=1) -# keep_edge = ~true_self_edge - -# # Note: after eliminating self-edges, it can be that no edges remain in this system -# sender = sender[keep_edge] -# receiver = receiver[keep_edge] -# unit_shifts = unit_shifts[keep_edge] - -# # Build output -# edge_index = np.stack((sender, receiver)) # [2, n_edges] - -# # From the docs: With the shift vector S, the distances D between atoms can be computed from -# # D = positions[j]-positions[i]+S.dot(cell) -# shifts = np.dot(unit_shifts, cell) # [n_edges, 3] - -# return edge_index, shifts, unit_shifts diff --git a/hydragnn/utils/mace_utils/data/utils.py b/hydragnn/utils/mace_utils/data/utils.py deleted file mode 100644 index 0eb3cd187..000000000 --- a/hydragnn/utils/mace_utils/data/utils.py +++ /dev/null @@ -1,398 +0,0 @@ -# ########################################################################################### -# # Data parsing utilities -# # Authors: Ilyes Batatia, Gregor Simm and David Kovacs -# # This program is distributed under the MIT License (see MIT.md) -# ########################################################################################### - -# import logging -# from dataclasses import dataclass -# from typing import Dict, List, Optional, Sequence, Tuple - -# import ase.data -# import ase.io -# import numpy as np - -# # Try import but pass otherwise -# try: -# import h5py -# except ImportError: -# pass - -# from hydragnn.utils.mace_utils.tools import AtomicNumberTable - -# Vector = np.ndarray # [3,] -# Positions = np.ndarray # [..., 3] -# Forces = np.ndarray # [..., 3] -# Stress = np.ndarray # [6, ], [3,3], [9, ] -# Virials = np.ndarray # [6, ], [3,3], [9, ] -# Charges = np.ndarray # [..., 1] -# Cell = np.ndarray # [3,3] -# Pbc = tuple # (3,) - -# DEFAULT_CONFIG_TYPE = "Default" -# DEFAULT_CONFIG_TYPE_WEIGHTS = {DEFAULT_CONFIG_TYPE: 1.0} - - -# @dataclass -# class Configuration: -# atomic_numbers: np.ndarray -# positions: Positions # Angstrom -# energy: Optional[float] = None # eV -# forces: Optional[Forces] = None # eV/Angstrom -# stress: Optional[Stress] = None # eV/Angstrom^3 -# virials: Optional[Virials] = None # eV -# dipole: Optional[Vector] = None # Debye -# charges: Optional[Charges] = None # atomic unit -# cell: Optional[Cell] = None -# pbc: Optional[Pbc] = None - -# weight: float = 1.0 # weight of config in loss -# energy_weight: float = 1.0 # weight of config energy in loss -# forces_weight: float = 1.0 # weight of config forces in loss -# stress_weight: float = 1.0 # weight of config stress in loss -# virials_weight: float = 1.0 # weight of config virial in loss -# config_type: Optional[str] = DEFAULT_CONFIG_TYPE # config_type of config - - -# Configurations = List[Configuration] - - -# def random_train_valid_split( -# items: Sequence, valid_fraction: 
float, seed: int, work_dir: str -# ) -> Tuple[List, List]: -# assert 0.0 < valid_fraction < 1.0 - -# size = len(items) -# train_size = size - int(valid_fraction * size) - -# indices = list(range(size)) -# rng = np.random.default_rng(seed) -# rng.shuffle(indices) -# if len(indices[train_size:]) < 10: -# logging.info( -# f"Using random {100 * valid_fraction:.0f}% of training set for validation with following indices: {indices[train_size:]}" -# ) -# else: -# # Save indices to file -# with open(work_dir + f"/valid_indices_{seed}.txt", "w", encoding="utf-8") as f: -# for index in indices[train_size:]: -# f.write(f"{index}\n") - -# logging.info( -# f"Using random {100 * valid_fraction:.0f}% of training set for validation with indices saved in: {work_dir}/valid_indices_{seed}.txt" -# ) - -# return ( -# [items[i] for i in indices[:train_size]], -# [items[i] for i in indices[train_size:]], -# ) - - -# def config_from_atoms_list( -# atoms_list: List[ase.Atoms], -# energy_key="REF_energy", -# forces_key="REF_forces", -# stress_key="REF_stress", -# virials_key="REF_virials", -# dipole_key="REF_dipole", -# charges_key="REF_charges", -# config_type_weights: Dict[str, float] = None, -# ) -> Configurations: -# """Convert list of ase.Atoms into Configurations""" -# if config_type_weights is None: -# config_type_weights = DEFAULT_CONFIG_TYPE_WEIGHTS - -# all_configs = [] -# for atoms in atoms_list: -# all_configs.append( -# config_from_atoms( -# atoms, -# energy_key=energy_key, -# forces_key=forces_key, -# stress_key=stress_key, -# virials_key=virials_key, -# dipole_key=dipole_key, -# charges_key=charges_key, -# config_type_weights=config_type_weights, -# ) -# ) -# return all_configs - - -# def config_from_atoms( -# atoms: ase.Atoms, -# energy_key="REF_energy", -# forces_key="REF_forces", -# stress_key="REF_stress", -# virials_key="REF_virials", -# dipole_key="REF_dipole", -# charges_key="REF_charges", -# config_type_weights: Dict[str, float] = None, -# ) -> Configuration: -# """Convert ase.Atoms to Configuration""" -# if config_type_weights is None: -# config_type_weights = DEFAULT_CONFIG_TYPE_WEIGHTS - -# energy = atoms.info.get(energy_key, None) # eV -# forces = atoms.arrays.get(forces_key, None) # eV / Ang -# stress = atoms.info.get(stress_key, None) # eV / Ang ^ 3 -# virials = atoms.info.get(virials_key, None) -# dipole = atoms.info.get(dipole_key, None) # Debye -# # Charges default to 0 instead of None if not found -# charges = atoms.arrays.get(charges_key, np.zeros(len(atoms))) # atomic unit -# atomic_numbers = np.array( -# [ase.data.atomic_numbers[symbol] for symbol in atoms.symbols] -# ) -# pbc = tuple(atoms.get_pbc()) -# cell = np.array(atoms.get_cell()) -# config_type = atoms.info.get("config_type", "Default") -# weight = atoms.info.get("config_weight", 1.0) * config_type_weights.get( -# config_type, 1.0 -# ) -# energy_weight = atoms.info.get("config_energy_weight", 1.0) -# forces_weight = atoms.info.get("config_forces_weight", 1.0) -# stress_weight = atoms.info.get("config_stress_weight", 1.0) -# virials_weight = atoms.info.get("config_virials_weight", 1.0) - -# # fill in missing quantities but set their weight to 0.0 -# if energy is None: -# energy = 0.0 -# energy_weight = 0.0 -# if forces is None: -# forces = np.zeros(np.shape(atoms.positions)) -# forces_weight = 0.0 -# if stress is None: -# stress = np.zeros(6) -# stress_weight = 0.0 -# if virials is None: -# virials = np.zeros((3, 3)) -# virials_weight = 0.0 -# if dipole is None: -# dipole = np.zeros(3) -# # dipoles_weight = 0.0 - -# return 
Configuration( -# atomic_numbers=atomic_numbers, -# positions=atoms.get_positions(), -# energy=energy, -# forces=forces, -# stress=stress, -# virials=virials, -# dipole=dipole, -# charges=charges, -# weight=weight, -# energy_weight=energy_weight, -# forces_weight=forces_weight, -# stress_weight=stress_weight, -# virials_weight=virials_weight, -# config_type=config_type, -# pbc=pbc, -# cell=cell, -# ) - - -# def test_config_types( -# test_configs: Configurations, -# ) -> List[Tuple[Optional[str], List[Configuration]]]: -# """Split test set based on config_type-s""" -# test_by_ct = [] -# all_cts = [] -# for conf in test_configs: -# if conf.config_type not in all_cts: -# all_cts.append(conf.config_type) -# test_by_ct.append((conf.config_type, [conf])) -# else: -# ind = all_cts.index(conf.config_type) -# test_by_ct[ind][1].append(conf) -# return test_by_ct - - -# def load_from_xyz( -# file_path: str, -# config_type_weights: Dict, -# energy_key: str = "REF_energy", -# forces_key: str = "REF_forces", -# stress_key: str = "REF_stress", -# virials_key: str = "REF_virials", -# dipole_key: str = "REF_dipole", -# charges_key: str = "REF_charges", -# extract_atomic_energies: bool = False, -# keep_isolated_atoms: bool = False, -# ) -> Tuple[Dict[int, float], Configurations]: -# atoms_list = ase.io.read(file_path, index=":") -# if energy_key == "energy": -# logging.warning( -# "Since ASE version 3.23.0b1, using energy_key 'energy' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'energy' to 'REF_energy'. You need to use --energy_key='REF_energy' to specify the chosen key name." -# ) -# energy_key = "REF_energy" -# for atoms in atoms_list: -# try: -# atoms.info["REF_energy"] = atoms.get_potential_energy() -# except Exception as e: # pylint: disable=W0703 -# logging.error(f"Failed to extract energy: {e}") -# atoms.info["REF_energy"] = None -# if forces_key == "forces": -# logging.warning( -# "Since ASE version 3.23.0b1, using forces_key 'forces' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'forces' to 'REF_forces'. You need to use --forces_key='REF_forces' to specify the chosen key name." -# ) -# forces_key = "REF_forces" -# for atoms in atoms_list: -# try: -# atoms.arrays["REF_forces"] = atoms.get_forces() -# except Exception as e: # pylint: disable=W0703 -# logging.error(f"Failed to extract forces: {e}") -# atoms.arrays["REF_forces"] = None -# if stress_key == "stress": -# logging.warning( -# "Since ASE version 3.23.0b1, using stress_key 'stress' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'stress' to 'REF_stress'. You need to use --stress_key='REF_stress' to specify the chosen key name." 
-# ) -# stress_key = "REF_stress" -# for atoms in atoms_list: -# try: -# atoms.info["REF_stress"] = atoms.get_stress() -# except Exception as e: # pylint: disable=W0703 -# atoms.info["REF_stress"] = None -# if not isinstance(atoms_list, list): -# atoms_list = [atoms_list] - -# atomic_energies_dict = {} -# if extract_atomic_energies: -# atoms_without_iso_atoms = [] - -# for idx, atoms in enumerate(atoms_list): -# isolated_atom_config = ( -# len(atoms) == 1 and atoms.info.get("config_type") == "IsolatedAtom" -# ) -# if isolated_atom_config: -# if energy_key in atoms.info.keys(): -# atomic_energies_dict[atoms.get_atomic_numbers()[0]] = atoms.info[ -# energy_key -# ] -# else: -# logging.warning( -# f"Configuration '{idx}' is marked as 'IsolatedAtom' " -# "but does not contain an energy. Zero energy will be used." -# ) -# atomic_energies_dict[atoms.get_atomic_numbers()[0]] = np.zeros(1) -# else: -# atoms_without_iso_atoms.append(atoms) - -# if len(atomic_energies_dict) > 0: -# logging.info("Using isolated atom energies from training file") -# if not keep_isolated_atoms: -# atoms_list = atoms_without_iso_atoms - -# configs = config_from_atoms_list( -# atoms_list, -# config_type_weights=config_type_weights, -# energy_key=energy_key, -# forces_key=forces_key, -# stress_key=stress_key, -# virials_key=virials_key, -# dipole_key=dipole_key, -# charges_key=charges_key, -# ) -# return atomic_energies_dict, configs - - -# def compute_average_E0s( -# collections_train: Configurations, z_table: AtomicNumberTable -# ) -> Dict[int, float]: -# """ -# Function to compute the average interaction energy of each chemical element -# returns dictionary of E0s -# """ -# len_train = len(collections_train) -# len_zs = len(z_table) -# A = np.zeros((len_train, len_zs)) -# B = np.zeros(len_train) -# for i in range(len_train): -# B[i] = collections_train[i].energy -# for j, z in enumerate(z_table.zs): -# A[i, j] = np.count_nonzero(collections_train[i].atomic_numbers == z) -# try: -# E0s = np.linalg.lstsq(A, B, rcond=None)[0] -# atomic_energies_dict = {} -# for i, z in enumerate(z_table.zs): -# atomic_energies_dict[z] = E0s[i] -# except np.linalg.LinAlgError: -# logging.error( -# "Failed to compute E0s using least squares regression, using the same for all atoms" -# ) -# atomic_energies_dict = {} -# for i, z in enumerate(z_table.zs): -# atomic_energies_dict[z] = 0.0 -# return atomic_energies_dict - - -# def save_dataset_as_HDF5(dataset: List, out_name: str) -> None: -# with h5py.File(out_name, "w") as f: -# for i, data in enumerate(dataset): -# grp = f.create_group(f"config_{i}") -# grp["num_nodes"] = data.num_nodes -# grp["edge_index"] = data.edge_index -# grp["positions"] = data.positions -# grp["shifts"] = data.shifts -# grp["unit_shifts"] = data.unit_shifts -# grp["cell"] = data.cell -# grp["node_attrs"] = data.node_attrs -# grp["weight"] = data.weight -# grp["energy_weight"] = data.energy_weight -# grp["forces_weight"] = data.forces_weight -# grp["stress_weight"] = data.stress_weight -# grp["virials_weight"] = data.virials_weight -# grp["forces"] = data.forces -# grp["energy"] = data.energy -# grp["stress"] = data.stress -# grp["virials"] = data.virials -# grp["dipole"] = data.dipole -# grp["charges"] = data.charges - - -# def save_AtomicData_to_HDF5(data, i, h5_file) -> None: -# grp = h5_file.create_group(f"config_{i}") -# grp["num_nodes"] = data.num_nodes -# grp["edge_index"] = data.edge_index -# grp["positions"] = data.positions -# grp["shifts"] = data.shifts -# grp["unit_shifts"] = data.unit_shifts -# grp["cell"] 
= data.cell -# grp["node_attrs"] = data.node_attrs -# grp["weight"] = data.weight -# grp["energy_weight"] = data.energy_weight -# grp["forces_weight"] = data.forces_weight -# grp["stress_weight"] = data.stress_weight -# grp["virials_weight"] = data.virials_weight -# grp["forces"] = data.forces -# grp["energy"] = data.energy -# grp["stress"] = data.stress -# grp["virials"] = data.virials -# grp["dipole"] = data.dipole -# grp["charges"] = data.charges - - -# def save_configurations_as_HDF5(configurations: Configurations, _, h5_file) -> None: -# grp = h5_file.create_group("config_batch_0") -# for j, config in enumerate(configurations): -# subgroup_name = f"config_{j}" -# subgroup = grp.create_group(subgroup_name) -# subgroup["atomic_numbers"] = write_value(config.atomic_numbers) -# subgroup["positions"] = write_value(config.positions) -# subgroup["energy"] = write_value(config.energy) -# subgroup["forces"] = write_value(config.forces) -# subgroup["stress"] = write_value(config.stress) -# subgroup["virials"] = write_value(config.virials) -# subgroup["dipole"] = write_value(config.dipole) -# subgroup["charges"] = write_value(config.charges) -# subgroup["cell"] = write_value(config.cell) -# subgroup["pbc"] = write_value(config.pbc) -# subgroup["weight"] = write_value(config.weight) -# subgroup["energy_weight"] = write_value(config.energy_weight) -# subgroup["forces_weight"] = write_value(config.forces_weight) -# subgroup["stress_weight"] = write_value(config.stress_weight) -# subgroup["virials_weight"] = write_value(config.virials_weight) -# subgroup["config_type"] = write_value(config.config_type) - - -# def write_value(value): -# return value if value is not None else "None" diff --git a/hydragnn/utils/mace_utils/modules/__init__.py b/hydragnn/utils/mace_utils/modules/__init__.py index a9b38f67b..971aeb294 100644 --- a/hydragnn/utils/mace_utils/modules/__init__.py +++ b/hydragnn/utils/mace_utils/modules/__init__.py @@ -21,36 +21,9 @@ ScaleShiftBlock, ) -# from .loss import ( -# DipoleSingleLoss, -# UniversalLoss, -# WeightedEnergyForcesDipoleLoss, -# WeightedEnergyForcesLoss, -# WeightedEnergyForcesStressLoss, -# WeightedEnergyForcesVirialsLoss, -# WeightedForcesLoss, -# WeightedHuberEnergyForcesStressLoss, -# ) -# from .models import ( -# MACE, -# AtomicDipolesMACE, -# BOTNet, -# EnergyDipolesMACE, -# ScaleShiftBOTNet, -# ScaleShiftMACE, -# ) from .radial import BesselBasis, GaussianBasis, PolynomialCutoff, ZBLBasis from .symmetric_contraction import SymmetricContraction -# from .utils import ( -# compute_avg_num_neighbors, -# compute_fixed_charge_dipole, -# compute_mean_rms_energy_forces, -# compute_mean_std_atomic_inter_energy, -# compute_rms_dipoles, -# compute_statistics, -# ) - interaction_classes: Dict[str, Type[InteractionBlock]] = { # "AgnosticNonlinearInteractionBlock": AgnosticNonlinearInteractionBlock, # "ResidualElementDependentInteractionBlock": ResidualElementDependentInteractionBlock, @@ -60,12 +33,6 @@ # "RealAgnosticInteractionBlock": RealAgnosticInteractionBlock, } -# scaling_classes: Dict[str, Callable] = { -# "std_scaling": compute_mean_std_atomic_inter_energy, -# "rms_forces_scaling": compute_mean_rms_energy_forces, -# "rms_dipoles_scaling": compute_rms_dipoles, -# } - gate_dict: Dict[str, Optional[Callable]] = { "abs": torch.abs, "tanh": torch.tanh, @@ -88,24 +55,6 @@ "PolynomialCutoff", "BesselBasis", "GaussianBasis", - "MACE", - "ScaleShiftMACE", - "BOTNet", - "ScaleShiftBOTNet", - "AtomicDipolesMACE", - "EnergyDipolesMACE", - "WeightedEnergyForcesLoss", - 
"WeightedForcesLoss", - "WeightedEnergyForcesVirialsLoss", - "WeightedEnergyForcesStressLoss", - "DipoleSingleLoss", - "WeightedEnergyForcesDipoleLoss", - "WeightedHuberEnergyForcesStressLoss", - "UniversalLoss", "SymmetricContraction", "interaction_classes", - "compute_mean_std_atomic_inter_energy", - "compute_avg_num_neighbors", - "compute_statistics", - "compute_fixed_charge_dipole", ] diff --git a/hydragnn/utils/mace_utils/modules/loss.py b/hydragnn/utils/mace_utils/modules/loss.py deleted file mode 100644 index 9ece6e0c2..000000000 --- a/hydragnn/utils/mace_utils/modules/loss.py +++ /dev/null @@ -1,367 +0,0 @@ -# ########################################################################################### -# # Implementation of different loss functions -# # Authors: Ilyes Batatia, Gregor Simm -# # This program is distributed under the MIT License (see MIT.md) -# ########################################################################################### - -# import torch - -# from hydragnn.utils.mace_utils.tools import TensorDict -# from hydragnn.utils.mace_utils.tools.torch_geometric import Batch - - -# def mean_squared_error_energy(ref: Batch, pred: TensorDict) -> torch.Tensor: -# # energy: [n_graphs, ] -# return torch.mean(torch.square(ref["energy"] - pred["energy"])) # [] - - -# def weighted_mean_squared_error_energy(ref: Batch, pred: TensorDict) -> torch.Tensor: -# # energy: [n_graphs, ] -# configs_weight = ref.weight # [n_graphs, ] -# configs_energy_weight = ref.energy_weight # [n_graphs, ] -# num_atoms = ref.ptr[1:] - ref.ptr[:-1] # [n_graphs,] -# return torch.mean( -# configs_weight -# * configs_energy_weight -# * torch.square((ref["energy"] - pred["energy"]) / num_atoms) -# ) # [] - - -# def weighted_mean_squared_stress(ref: Batch, pred: TensorDict) -> torch.Tensor: -# # energy: [n_graphs, ] -# configs_weight = ref.weight.view(-1, 1, 1) # [n_graphs, ] -# configs_stress_weight = ref.stress_weight.view(-1, 1, 1) # [n_graphs, ] -# return torch.mean( -# configs_weight -# * configs_stress_weight -# * torch.square(ref["stress"] - pred["stress"]) -# ) # [] - - -# def weighted_mean_squared_virials(ref: Batch, pred: TensorDict) -> torch.Tensor: -# # energy: [n_graphs, ] -# configs_weight = ref.weight.view(-1, 1, 1) # [n_graphs, ] -# configs_virials_weight = ref.virials_weight.view(-1, 1, 1) # [n_graphs, ] -# num_atoms = (ref.ptr[1:] - ref.ptr[:-1]).view(-1, 1, 1) # [n_graphs,] -# return torch.mean( -# configs_weight -# * configs_virials_weight -# * torch.square((ref["virials"] - pred["virials"]) / num_atoms) -# ) # [] - - -# def mean_squared_error_forces(ref: Batch, pred: TensorDict) -> torch.Tensor: -# # forces: [n_atoms, 3] -# configs_weight = torch.repeat_interleave( -# ref.weight, ref.ptr[1:] - ref.ptr[:-1] -# ).unsqueeze( -# -1 -# ) # [n_atoms, 1] -# configs_forces_weight = torch.repeat_interleave( -# ref.forces_weight, ref.ptr[1:] - ref.ptr[:-1] -# ).unsqueeze( -# -1 -# ) # [n_atoms, 1] -# return torch.mean( -# configs_weight -# * configs_forces_weight -# * torch.square(ref["forces"] - pred["forces"]) -# ) # [] - - -# def weighted_mean_squared_error_dipole(ref: Batch, pred: TensorDict) -> torch.Tensor: -# # dipole: [n_graphs, ] -# num_atoms = (ref.ptr[1:] - ref.ptr[:-1]).unsqueeze(-1) # [n_graphs,1] -# return torch.mean(torch.square((ref["dipole"] - pred["dipole"]) / num_atoms)) # [] -# # return torch.mean(torch.square((torch.reshape(ref['dipole'], pred["dipole"].shape) - pred['dipole']) / num_atoms)) # [] - - -# def conditional_mse_forces(ref: Batch, pred: TensorDict) -> 
torch.Tensor: -# # forces: [n_atoms, 3] -# configs_weight = torch.repeat_interleave( -# ref.weight, ref.ptr[1:] - ref.ptr[:-1] -# ).unsqueeze( -# -1 -# ) # [n_atoms, 1] -# configs_forces_weight = torch.repeat_interleave( -# ref.forces_weight, ref.ptr[1:] - ref.ptr[:-1] -# ).unsqueeze( -# -1 -# ) # [n_atoms, 1] - -# # Define the multiplication factors for each condition -# factors = torch.tensor([1.0, 0.7, 0.4, 0.1]) - -# # Apply multiplication factors based on conditions -# c1 = torch.norm(ref["forces"], dim=-1) < 100 -# c2 = (torch.norm(ref["forces"], dim=-1) >= 100) & ( -# torch.norm(ref["forces"], dim=-1) < 200 -# ) -# c3 = (torch.norm(ref["forces"], dim=-1) >= 200) & ( -# torch.norm(ref["forces"], dim=-1) < 300 -# ) - -# err = ref["forces"] - pred["forces"] - -# se = torch.zeros_like(err) - -# se[c1] = torch.square(err[c1]) * factors[0] -# se[c2] = torch.square(err[c2]) * factors[1] -# se[c3] = torch.square(err[c3]) * factors[2] -# se[~(c1 | c2 | c3)] = torch.square(err[~(c1 | c2 | c3)]) * factors[3] - -# return torch.mean(configs_weight * configs_forces_weight * se) - - -# def conditional_huber_forces( -# ref: Batch, pred: TensorDict, huber_delta: float -# ) -> torch.Tensor: -# # Define the multiplication factors for each condition -# factors = huber_delta * torch.tensor([1.0, 0.7, 0.4, 0.1]) - -# # Apply multiplication factors based on conditions -# c1 = torch.norm(ref["forces"], dim=-1) < 100 -# c2 = (torch.norm(ref["forces"], dim=-1) >= 100) & ( -# torch.norm(ref["forces"], dim=-1) < 200 -# ) -# c3 = (torch.norm(ref["forces"], dim=-1) >= 200) & ( -# torch.norm(ref["forces"], dim=-1) < 300 -# ) -# c4 = ~(c1 | c2 | c3) - -# se = torch.zeros_like(pred["forces"]) - -# se[c1] = torch.nn.functional.huber_loss( -# ref["forces"][c1], pred["forces"][c1], reduction="none", delta=factors[0] -# ) -# se[c2] = torch.nn.functional.huber_loss( -# ref["forces"][c2], pred["forces"][c2], reduction="none", delta=factors[1] -# ) -# se[c3] = torch.nn.functional.huber_loss( -# ref["forces"][c3], pred["forces"][c3], reduction="none", delta=factors[2] -# ) -# se[c4] = torch.nn.functional.huber_loss( -# ref["forces"][c4], pred["forces"][c4], reduction="none", delta=factors[3] -# ) - -# return torch.mean(se) - - -# class WeightedEnergyForcesLoss(torch.nn.Module): -# def __init__(self, energy_weight=1.0, forces_weight=1.0) -> None: -# super().__init__() -# self.register_buffer( -# "energy_weight", -# torch.tensor(energy_weight, dtype=torch.get_default_dtype()), -# ) -# self.register_buffer( -# "forces_weight", -# torch.tensor(forces_weight, dtype=torch.get_default_dtype()), -# ) - -# def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: -# return self.energy_weight * weighted_mean_squared_error_energy( -# ref, pred -# ) + self.forces_weight * mean_squared_error_forces(ref, pred) - -# def __repr__(self): -# return ( -# f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " -# f"forces_weight={self.forces_weight:.3f})" -# ) - - -# class WeightedForcesLoss(torch.nn.Module): -# def __init__(self, forces_weight=1.0) -> None: -# super().__init__() -# self.register_buffer( -# "forces_weight", -# torch.tensor(forces_weight, dtype=torch.get_default_dtype()), -# ) - -# def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: -# return self.forces_weight * mean_squared_error_forces(ref, pred) - -# def __repr__(self): -# return f"{self.__class__.__name__}(" f"forces_weight={self.forces_weight:.3f})" - - -# class WeightedEnergyForcesStressLoss(torch.nn.Module): -# def __init__(self, 
energy_weight=1.0, forces_weight=1.0, stress_weight=1.0) -> None: -# super().__init__() -# self.register_buffer( -# "energy_weight", -# torch.tensor(energy_weight, dtype=torch.get_default_dtype()), -# ) -# self.register_buffer( -# "forces_weight", -# torch.tensor(forces_weight, dtype=torch.get_default_dtype()), -# ) -# self.register_buffer( -# "stress_weight", -# torch.tensor(stress_weight, dtype=torch.get_default_dtype()), -# ) - -# def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: -# return ( -# self.energy_weight * weighted_mean_squared_error_energy(ref, pred) -# + self.forces_weight * mean_squared_error_forces(ref, pred) -# + self.stress_weight * weighted_mean_squared_stress(ref, pred) -# ) - -# def __repr__(self): -# return ( -# f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " -# f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})" -# ) - - -# class WeightedHuberEnergyForcesStressLoss(torch.nn.Module): -# def __init__( -# self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0, huber_delta=0.01 -# ) -> None: -# super().__init__() -# self.huber_loss = torch.nn.HuberLoss(reduction="mean", delta=huber_delta) -# self.register_buffer( -# "energy_weight", -# torch.tensor(energy_weight, dtype=torch.get_default_dtype()), -# ) -# self.register_buffer( -# "forces_weight", -# torch.tensor(forces_weight, dtype=torch.get_default_dtype()), -# ) -# self.register_buffer( -# "stress_weight", -# torch.tensor(stress_weight, dtype=torch.get_default_dtype()), -# ) - -# def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: -# num_atoms = ref.ptr[1:] - ref.ptr[:-1] -# return ( -# self.energy_weight -# * self.huber_loss(ref["energy"] / num_atoms, pred["energy"] / num_atoms) -# + self.forces_weight * self.huber_loss(ref["forces"], pred["forces"]) -# + self.stress_weight * self.huber_loss(ref["stress"], pred["stress"]) -# ) - -# def __repr__(self): -# return ( -# f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " -# f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})" -# ) - - -# class UniversalLoss(torch.nn.Module): -# def __init__( -# self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0, huber_delta=0.01 -# ) -> None: -# super().__init__() -# self.huber_delta = huber_delta -# self.huber_loss = torch.nn.HuberLoss(reduction="mean", delta=huber_delta) -# self.register_buffer( -# "energy_weight", -# torch.tensor(energy_weight, dtype=torch.get_default_dtype()), -# ) -# self.register_buffer( -# "forces_weight", -# torch.tensor(forces_weight, dtype=torch.get_default_dtype()), -# ) -# self.register_buffer( -# "stress_weight", -# torch.tensor(stress_weight, dtype=torch.get_default_dtype()), -# ) - -# def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: -# num_atoms = ref.ptr[1:] - ref.ptr[:-1] -# return ( -# self.energy_weight -# * self.huber_loss(ref["energy"] / num_atoms, pred["energy"] / num_atoms) -# + self.forces_weight -# * conditional_huber_forces(ref, pred, huber_delta=self.huber_delta) -# + self.stress_weight * self.huber_loss(ref["stress"], pred["stress"]) -# ) - -# def __repr__(self): -# return ( -# f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " -# f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})" -# ) - - -# class WeightedEnergyForcesVirialsLoss(torch.nn.Module): -# def __init__( -# self, energy_weight=1.0, forces_weight=1.0, virials_weight=1.0 -# ) -> None: -# super().__init__() -# 
self.register_buffer( -# "energy_weight", -# torch.tensor(energy_weight, dtype=torch.get_default_dtype()), -# ) -# self.register_buffer( -# "forces_weight", -# torch.tensor(forces_weight, dtype=torch.get_default_dtype()), -# ) -# self.register_buffer( -# "virials_weight", -# torch.tensor(virials_weight, dtype=torch.get_default_dtype()), -# ) - -# def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: -# return ( -# self.energy_weight * weighted_mean_squared_error_energy(ref, pred) -# + self.forces_weight * mean_squared_error_forces(ref, pred) -# + self.virials_weight * weighted_mean_squared_virials(ref, pred) -# ) - -# def __repr__(self): -# return ( -# f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " -# f"forces_weight={self.forces_weight:.3f}, virials_weight={self.virials_weight:.3f})" -# ) - - -# class DipoleSingleLoss(torch.nn.Module): -# def __init__(self, dipole_weight=1.0) -> None: -# super().__init__() -# self.register_buffer( -# "dipole_weight", -# torch.tensor(dipole_weight, dtype=torch.get_default_dtype()), -# ) - -# def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: -# return ( -# self.dipole_weight * weighted_mean_squared_error_dipole(ref, pred) * 100.0 -# ) # multiply by 100 to have the right scale for the loss - -# def __repr__(self): -# return f"{self.__class__.__name__}(" f"dipole_weight={self.dipole_weight:.3f})" - - -# class WeightedEnergyForcesDipoleLoss(torch.nn.Module): -# def __init__(self, energy_weight=1.0, forces_weight=1.0, dipole_weight=1.0) -> None: -# super().__init__() -# self.register_buffer( -# "energy_weight", -# torch.tensor(energy_weight, dtype=torch.get_default_dtype()), -# ) -# self.register_buffer( -# "forces_weight", -# torch.tensor(forces_weight, dtype=torch.get_default_dtype()), -# ) -# self.register_buffer( -# "dipole_weight", -# torch.tensor(dipole_weight, dtype=torch.get_default_dtype()), -# ) - -# def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: -# return ( -# self.energy_weight * weighted_mean_squared_error_energy(ref, pred) -# + self.forces_weight * mean_squared_error_forces(ref, pred) -# + self.dipole_weight * weighted_mean_squared_error_dipole(ref, pred) * 100 -# ) - -# def __repr__(self): -# return ( -# f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " -# f"forces_weight={self.forces_weight:.3f}, dipole_weight={self.dipole_weight:.3f})" -# ) diff --git a/hydragnn/utils/mace_utils/modules/models.py b/hydragnn/utils/mace_utils/modules/models.py deleted file mode 100644 index cc87fed91..000000000 --- a/hydragnn/utils/mace_utils/modules/models.py +++ /dev/null @@ -1,1065 +0,0 @@ -# ########################################################################################### -# # Implementation of MACE models and other models based E(3)-Equivariant MPNNs -# # Authors: Ilyes Batatia, Gregor Simm -# # This program is distributed under the MIT License (see MIT.md) -# ########################################################################################### - -# from typing import Any, Callable, Dict, List, Optional, Type, Union - -# import numpy as np -# import torch -# from e3nn import o3 -# from e3nn.util.jit import compile_mode - -# from hydragnn.utils.mace_utils.data import AtomicData -# from hydragnn.utils.mace_utils.modules.radial import ZBLBasis -# from hydragnn.utils.mace_utils.tools.scatter import scatter_sum - -# from .blocks import ( -# AtomicEnergiesBlock, -# EquivariantProductBasisBlock, -# InteractionBlock, -# LinearDipoleReadoutBlock, -# 
LinearNodeEmbeddingBlock, -# LinearReadoutBlock, -# NonLinearDipoleReadoutBlock, -# NonLinearReadoutBlock, -# RadialEmbeddingBlock, -# ScaleShiftBlock, -# ) -# from .utils import ( -# compute_fixed_charge_dipole, -# compute_forces, -# get_edge_vectors_and_lengths, -# get_outputs, -# get_symmetric_displacement, -# ) - -# # pylint: disable=C0302 - - -# @compile_mode("script") -# class MACE(torch.nn.Module): -# def __init__( -# self, -# r_max: float, -# num_bessel: int, -# num_polynomial_cutoff: int, -# max_ell: int, -# interaction_cls: Type[InteractionBlock], -# interaction_cls_first: Type[InteractionBlock], -# num_interactions: int, -# num_elements: int, -# hidden_irreps: o3.Irreps, -# MLP_irreps: o3.Irreps, -# atomic_energies: np.ndarray, -# avg_num_neighbors: float, -# atomic_numbers: List[int], -# correlation: Union[int, List[int]], -# gate: Optional[Callable], -# pair_repulsion: bool = False, -# distance_transform: str = "None", -# radial_MLP: Optional[List[int]] = None, -# radial_type: Optional[str] = "bessel", -# ): -# super().__init__() -# self.register_buffer( -# "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) -# ) -# self.register_buffer( -# "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) -# ) -# self.register_buffer( -# "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) -# ) -# if isinstance(correlation, int): -# correlation = [correlation] * num_interactions -# # Embedding -# node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) -# node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) -# self.node_embedding = LinearNodeEmbeddingBlock( -# irreps_in=node_attr_irreps, irreps_out=node_feats_irreps -# ) -# self.radial_embedding = RadialEmbeddingBlock( -# r_max=r_max, -# num_bessel=num_bessel, -# num_polynomial_cutoff=num_polynomial_cutoff, -# radial_type=radial_type, -# distance_transform=distance_transform, -# ) -# edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") -# if pair_repulsion: -# self.pair_repulsion_fn = ZBLBasis(r_max=r_max, p=num_polynomial_cutoff) -# self.pair_repulsion = True - -# sh_irreps = o3.Irreps.spherical_harmonics(max_ell) -# num_features = hidden_irreps.count(o3.Irrep(0, 1)) -# interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() -# self.spherical_harmonics = o3.SphericalHarmonics( -# sh_irreps, normalize=True, normalization="component" -# ) -# if radial_MLP is None: -# radial_MLP = [64, 64, 64] -# # Interactions and readout -# self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) - -# inter = interaction_cls_first( -# node_attrs_irreps=node_attr_irreps, -# node_feats_irreps=node_feats_irreps, -# edge_attrs_irreps=sh_irreps, -# edge_feats_irreps=edge_feats_irreps, -# target_irreps=interaction_irreps, -# hidden_irreps=hidden_irreps, -# avg_num_neighbors=avg_num_neighbors, -# radial_MLP=radial_MLP, -# ) -# self.interactions = torch.nn.ModuleList([inter]) - -# # Use the appropriate self connection at the first layer for proper E0 -# use_sc_first = False -# if "Residual" in str(interaction_cls_first): -# use_sc_first = True - -# node_feats_irreps_out = inter.target_irreps -# prod = EquivariantProductBasisBlock( -# node_feats_irreps=node_feats_irreps_out, -# target_irreps=hidden_irreps, -# correlation=correlation[0], -# num_elements=num_elements, -# use_sc=use_sc_first, -# ) -# self.products = torch.nn.ModuleList([prod]) - -# self.readouts = torch.nn.ModuleList() -# self.readouts.append(LinearReadoutBlock(hidden_irreps)) - -# for i in 
range(num_interactions - 1): -# if i == num_interactions - 2: -# hidden_irreps_out = str( -# hidden_irreps[0] -# ) # Select only scalars for last layer -# else: -# hidden_irreps_out = hidden_irreps -# inter = interaction_cls( -# node_attrs_irreps=node_attr_irreps, -# node_feats_irreps=hidden_irreps, -# edge_attrs_irreps=sh_irreps, -# edge_feats_irreps=edge_feats_irreps, -# target_irreps=interaction_irreps, -# hidden_irreps=hidden_irreps_out, -# avg_num_neighbors=avg_num_neighbors, -# radial_MLP=radial_MLP, -# ) -# self.interactions.append(inter) -# prod = EquivariantProductBasisBlock( -# node_feats_irreps=interaction_irreps, -# target_irreps=hidden_irreps_out, -# correlation=correlation[i + 1], -# num_elements=num_elements, -# use_sc=True, -# ) -# self.products.append(prod) -# if i == num_interactions - 2: -# self.readouts.append( -# NonLinearReadoutBlock(hidden_irreps_out, MLP_irreps, gate) -# ) -# else: -# self.readouts.append(LinearReadoutBlock(hidden_irreps)) - -# def forward( -# self, -# data: Dict[str, torch.Tensor], -# training: bool = False, -# compute_force: bool = True, -# compute_virials: bool = False, -# compute_stress: bool = False, -# compute_displacement: bool = False, -# compute_hessian: bool = False, -# ) -> Dict[str, Optional[torch.Tensor]]: -# # Setup -# data["node_attrs"].requires_grad_(True) -# data["positions"].requires_grad_(True) -# num_graphs = data["ptr"].numel() - 1 -# displacement = torch.zeros( -# (num_graphs, 3, 3), -# dtype=data["positions"].dtype, -# device=data["positions"].device, -# ) -# if compute_virials or compute_stress or compute_displacement: -# ( -# data["positions"], -# data["shifts"], -# displacement, -# ) = get_symmetric_displacement( -# positions=data["positions"], -# unit_shifts=data["unit_shifts"], -# cell=data["cell"], -# edge_index=data["edge_index"], -# num_graphs=num_graphs, -# batch=data["batch"], -# ) - -# # Atomic energies -# node_e0 = self.atomic_energies_fn(data["node_attrs"]) -# e0 = scatter_sum( -# src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs -# ) # [n_graphs,] -# # Embeddings -# node_feats = self.node_embedding(data["node_attrs"]) -# vectors, lengths = get_edge_vectors_and_lengths( -# positions=data["positions"], -# edge_index=data["edge_index"], -# shifts=data["shifts"], -# ) -# edge_attrs = self.spherical_harmonics(vectors) -# edge_feats = self.radial_embedding( -# lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers -# ) -# if hasattr(self, "pair_repulsion"): -# pair_node_energy = self.pair_repulsion_fn( -# lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers -# ) -# pair_energy = scatter_sum( -# src=pair_node_energy, index=data["batch"], dim=-1, dim_size=num_graphs -# ) # [n_graphs,] -# else: -# pair_node_energy = torch.zeros_like(node_e0) -# pair_energy = torch.zeros_like(e0) - -# # Interactions -# energies = [e0, pair_energy] -# node_energies_list = [node_e0, pair_node_energy] -# node_feats_list = [] -# for interaction, product, readout in zip( -# self.interactions, self.products, self.readouts -# ): -# node_feats, sc = interaction( -# node_attrs=data["node_attrs"], -# node_feats=node_feats, -# edge_attrs=edge_attrs, -# edge_feats=edge_feats, -# edge_index=data["edge_index"], -# ) -# node_feats = product( -# node_feats=node_feats, -# sc=sc, -# node_attrs=data["node_attrs"], -# ) -# node_feats_list.append(node_feats) -# node_energies = readout(node_feats).squeeze(-1) # [n_nodes, ] -# energy = scatter_sum( -# src=node_energies, index=data["batch"], dim=-1, dim_size=num_graphs -# 
) # [n_graphs,] -# energies.append(energy) -# node_energies_list.append(node_energies) - -# # Concatenate node features -# node_feats_out = torch.cat(node_feats_list, dim=-1) - -# # Sum over energy contributions -# contributions = torch.stack(energies, dim=-1) -# total_energy = torch.sum(contributions, dim=-1) # [n_graphs, ] -# node_energy_contributions = torch.stack(node_energies_list, dim=-1) -# node_energy = torch.sum(node_energy_contributions, dim=-1) # [n_nodes, ] - -# # Outputs -# forces, virials, stress, hessian = get_outputs( -# energy=total_energy, -# positions=data["positions"], -# displacement=displacement, -# cell=data["cell"], -# training=training, -# compute_force=compute_force, -# compute_virials=compute_virials, -# compute_stress=compute_stress, -# compute_hessian=compute_hessian, -# ) - -# return { -# "energy": total_energy, -# "node_energy": node_energy, -# "contributions": contributions, -# "forces": forces, -# "virials": virials, -# "stress": stress, -# "displacement": displacement, -# "hessian": hessian, -# "node_feats": node_feats_out, -# } - - -# @compile_mode("script") -# class ScaleShiftMACE(MACE): -# def __init__( -# self, -# atomic_inter_scale: float, -# atomic_inter_shift: float, -# **kwargs, -# ): -# super().__init__(**kwargs) -# self.scale_shift = ScaleShiftBlock( -# scale=atomic_inter_scale, shift=atomic_inter_shift -# ) - -# def forward( -# self, -# data: Dict[str, torch.Tensor], -# training: bool = False, -# compute_force: bool = True, -# compute_virials: bool = False, -# compute_stress: bool = False, -# compute_displacement: bool = False, -# compute_hessian: bool = False, -# ) -> Dict[str, Optional[torch.Tensor]]: -# # Setup -# data["positions"].requires_grad_(True) -# data["node_attrs"].requires_grad_(True) -# num_graphs = data["ptr"].numel() - 1 -# displacement = torch.zeros( -# (num_graphs, 3, 3), -# dtype=data["positions"].dtype, -# device=data["positions"].device, -# ) -# if compute_virials or compute_stress or compute_displacement: -# ( -# data["positions"], -# data["shifts"], -# displacement, -# ) = get_symmetric_displacement( -# positions=data["positions"], -# unit_shifts=data["unit_shifts"], -# cell=data["cell"], -# edge_index=data["edge_index"], -# num_graphs=num_graphs, -# batch=data["batch"], -# ) - -# # Atomic energies -# node_e0 = self.atomic_energies_fn(data["node_attrs"]) -# e0 = scatter_sum( -# src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs -# ) # [n_graphs,] - -# # Embeddings -# node_feats = self.node_embedding(data["node_attrs"]) -# vectors, lengths = get_edge_vectors_and_lengths( -# positions=data["positions"], -# edge_index=data["edge_index"], -# shifts=data["shifts"], -# ) -# edge_attrs = self.spherical_harmonics(vectors) -# edge_feats = self.radial_embedding( -# lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers -# ) -# if hasattr(self, "pair_repulsion"): -# pair_node_energy = self.pair_repulsion_fn( -# lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers -# ) -# else: -# pair_node_energy = torch.zeros_like(node_e0) -# # Interactions -# node_es_list = [pair_node_energy] -# node_feats_list = [] -# for interaction, product, readout in zip( -# self.interactions, self.products, self.readouts -# ): -# node_feats, sc = interaction( -# node_attrs=data["node_attrs"], -# node_feats=node_feats, -# edge_attrs=edge_attrs, -# edge_feats=edge_feats, -# edge_index=data["edge_index"], -# ) -# node_feats = product( -# node_feats=node_feats, sc=sc, node_attrs=data["node_attrs"] -# ) -# 
node_feats_list.append(node_feats) -# node_es_list.append(readout(node_feats).squeeze(-1)) # {[n_nodes, ], } -# # Concatenate node features -# node_feats_out = torch.cat(node_feats_list, dim=-1) -# # print("node_es_list", node_es_list) -# # Sum over interactions -# node_inter_es = torch.sum( -# torch.stack(node_es_list, dim=0), dim=0 -# ) # [n_nodes, ] -# node_inter_es = self.scale_shift(node_inter_es) - -# # Sum over nodes in graph -# inter_e = scatter_sum( -# src=node_inter_es, index=data["batch"], dim=-1, dim_size=num_graphs -# ) # [n_graphs,] - -# # Add E_0 and (scaled) interaction energy -# total_energy = e0 + inter_e -# node_energy = node_e0 + node_inter_es -# forces, virials, stress, hessian = get_outputs( -# energy=inter_e, -# positions=data["positions"], -# displacement=displacement, -# cell=data["cell"], -# training=training, -# compute_force=compute_force, -# compute_virials=compute_virials, -# compute_stress=compute_stress, -# compute_hessian=compute_hessian, -# ) -# output = { -# "energy": total_energy, -# "node_energy": node_energy, -# "interaction_energy": inter_e, -# "forces": forces, -# "virials": virials, -# "stress": stress, -# "hessian": hessian, -# "displacement": displacement, -# "node_feats": node_feats_out, -# } - -# return output - - -# class BOTNet(torch.nn.Module): -# def __init__( -# self, -# r_max: float, -# num_bessel: int, -# num_polynomial_cutoff: int, -# max_ell: int, -# interaction_cls: Type[InteractionBlock], -# interaction_cls_first: Type[InteractionBlock], -# num_interactions: int, -# num_elements: int, -# hidden_irreps: o3.Irreps, -# MLP_irreps: o3.Irreps, -# atomic_energies: np.ndarray, -# gate: Optional[Callable], -# avg_num_neighbors: float, -# atomic_numbers: List[int], -# ): -# super().__init__() -# self.r_max = r_max -# self.atomic_numbers = atomic_numbers -# # Embedding -# node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) -# node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) -# self.node_embedding = LinearNodeEmbeddingBlock( -# irreps_in=node_attr_irreps, irreps_out=node_feats_irreps -# ) -# self.radial_embedding = RadialEmbeddingBlock( -# r_max=r_max, -# num_bessel=num_bessel, -# num_polynomial_cutoff=num_polynomial_cutoff, -# ) -# edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") - -# sh_irreps = o3.Irreps.spherical_harmonics(max_ell) -# self.spherical_harmonics = o3.SphericalHarmonics( -# sh_irreps, normalize=True, normalization="component" -# ) - -# # Interactions and readouts -# self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) - -# self.interactions = torch.nn.ModuleList() -# self.readouts = torch.nn.ModuleList() - -# inter = interaction_cls_first( -# node_attrs_irreps=node_attr_irreps, -# node_feats_irreps=node_feats_irreps, -# edge_attrs_irreps=sh_irreps, -# edge_feats_irreps=edge_feats_irreps, -# target_irreps=hidden_irreps, -# avg_num_neighbors=avg_num_neighbors, -# ) -# self.interactions.append(inter) -# self.readouts.append(LinearReadoutBlock(inter.irreps_out)) - -# for i in range(num_interactions - 1): -# inter = interaction_cls( -# node_attrs_irreps=node_attr_irreps, -# node_feats_irreps=inter.irreps_out, -# edge_attrs_irreps=sh_irreps, -# edge_feats_irreps=edge_feats_irreps, -# target_irreps=hidden_irreps, -# avg_num_neighbors=avg_num_neighbors, -# ) -# self.interactions.append(inter) -# if i == num_interactions - 2: -# self.readouts.append( -# NonLinearReadoutBlock(inter.irreps_out, MLP_irreps, gate) -# ) -# else: -# 
self.readouts.append(LinearReadoutBlock(inter.irreps_out)) - -# def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: -# # Setup -# data.positions.requires_grad = True - -# # Atomic energies -# node_e0 = self.atomic_energies_fn(data.node_attrs) -# e0 = scatter_sum( -# src=node_e0, index=data.batch, dim=-1, dim_size=data.num_graphs -# ) # [n_graphs,] - -# # Embeddings -# node_feats = self.node_embedding(data.node_attrs) -# vectors, lengths = get_edge_vectors_and_lengths( -# positions=data.positions, edge_index=data.edge_index, shifts=data.shifts -# ) -# edge_attrs = self.spherical_harmonics(vectors) -# edge_feats = self.radial_embedding( -# lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers -# ) - -# # Interactions -# energies = [e0] -# for interaction, readout in zip(self.interactions, self.readouts): -# node_feats = interaction( -# node_attrs=data.node_attrs, -# node_feats=node_feats, -# edge_attrs=edge_attrs, -# edge_feats=edge_feats, -# edge_index=data.edge_index, -# ) -# node_energies = readout(node_feats).squeeze(-1) # [n_nodes, ] -# energy = scatter_sum( -# src=node_energies, index=data.batch, dim=-1, dim_size=data.num_graphs -# ) # [n_graphs,] -# energies.append(energy) - -# # Sum over energy contributions -# contributions = torch.stack(energies, dim=-1) -# total_energy = torch.sum(contributions, dim=-1) # [n_graphs, ] - -# output = { -# "energy": total_energy, -# "contributions": contributions, -# "forces": compute_forces( -# energy=total_energy, positions=data.positions, training=training -# ), -# } - -# return output - - -# class ScaleShiftBOTNet(BOTNet): -# def __init__( -# self, -# atomic_inter_scale: float, -# atomic_inter_shift: float, -# **kwargs, -# ): -# super().__init__(**kwargs) -# self.scale_shift = ScaleShiftBlock( -# scale=atomic_inter_scale, shift=atomic_inter_shift -# ) - -# def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: -# # Setup -# data.positions.requires_grad = True - -# # Atomic energies -# node_e0 = self.atomic_energies_fn(data.node_attrs) -# e0 = scatter_sum( -# src=node_e0, index=data.batch, dim=-1, dim_size=data.num_graphs -# ) # [n_graphs,] - -# # Embeddings -# node_feats = self.node_embedding(data.node_attrs) -# vectors, lengths = get_edge_vectors_and_lengths( -# positions=data.positions, edge_index=data.edge_index, shifts=data.shifts -# ) -# edge_attrs = self.spherical_harmonics(vectors) -# edge_feats = self.radial_embedding( -# lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers -# ) - -# # Interactions -# node_es_list = [] -# for interaction, readout in zip(self.interactions, self.readouts): -# node_feats = interaction( -# node_attrs=data.node_attrs, -# node_feats=node_feats, -# edge_attrs=edge_attrs, -# edge_feats=edge_feats, -# edge_index=data.edge_index, -# ) - -# node_es_list.append(readout(node_feats).squeeze(-1)) # {[n_nodes, ], } - -# # Sum over interactions -# node_inter_es = torch.sum( -# torch.stack(node_es_list, dim=0), dim=0 -# ) # [n_nodes, ] -# node_inter_es = self.scale_shift(node_inter_es) - -# # Sum over nodes in graph -# inter_e = scatter_sum( -# src=node_inter_es, index=data.batch, dim=-1, dim_size=data.num_graphs -# ) # [n_graphs,] - -# # Add E_0 and (scaled) interaction energy -# total_e = e0 + inter_e - -# output = { -# "energy": total_e, -# "forces": compute_forces( -# energy=inter_e, positions=data.positions, training=training -# ), -# } - -# return output - - -# @compile_mode("script") -# class AtomicDipolesMACE(torch.nn.Module): -# def __init__( -# 
self, -# r_max: float, -# num_bessel: int, -# num_polynomial_cutoff: int, -# max_ell: int, -# interaction_cls: Type[InteractionBlock], -# interaction_cls_first: Type[InteractionBlock], -# num_interactions: int, -# num_elements: int, -# hidden_irreps: o3.Irreps, -# MLP_irreps: o3.Irreps, -# avg_num_neighbors: float, -# atomic_numbers: List[int], -# correlation: int, -# gate: Optional[Callable], -# atomic_energies: Optional[ -# None -# ], # Just here to make it compatible with energy models, MUST be None -# radial_type: Optional[str] = "bessel", -# radial_MLP: Optional[List[int]] = None, -# ): -# super().__init__() -# self.register_buffer( -# "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) -# ) -# self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.float64)) -# self.register_buffer( -# "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) -# ) -# assert atomic_energies is None - -# # Embedding -# node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) -# node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) -# self.node_embedding = LinearNodeEmbeddingBlock( -# irreps_in=node_attr_irreps, irreps_out=node_feats_irreps -# ) -# self.radial_embedding = RadialEmbeddingBlock( -# r_max=r_max, -# num_bessel=num_bessel, -# num_polynomial_cutoff=num_polynomial_cutoff, -# radial_type=radial_type, -# ) -# edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") - -# sh_irreps = o3.Irreps.spherical_harmonics(max_ell) -# num_features = hidden_irreps.count(o3.Irrep(0, 1)) -# interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() -# self.spherical_harmonics = o3.SphericalHarmonics( -# sh_irreps, normalize=True, normalization="component" -# ) -# if radial_MLP is None: -# radial_MLP = [64, 64, 64] - -# # Interactions and readouts -# inter = interaction_cls_first( -# node_attrs_irreps=node_attr_irreps, -# node_feats_irreps=node_feats_irreps, -# edge_attrs_irreps=sh_irreps, -# edge_feats_irreps=edge_feats_irreps, -# target_irreps=interaction_irreps, -# hidden_irreps=hidden_irreps, -# avg_num_neighbors=avg_num_neighbors, -# radial_MLP=radial_MLP, -# ) -# self.interactions = torch.nn.ModuleList([inter]) - -# # Use the appropriate self connection at the first layer -# use_sc_first = False -# if "Residual" in str(interaction_cls_first): -# use_sc_first = True - -# node_feats_irreps_out = inter.target_irreps -# prod = EquivariantProductBasisBlock( -# node_feats_irreps=node_feats_irreps_out, -# target_irreps=hidden_irreps, -# correlation=correlation, -# num_elements=num_elements, -# use_sc=use_sc_first, -# ) -# self.products = torch.nn.ModuleList([prod]) - -# self.readouts = torch.nn.ModuleList() -# self.readouts.append(LinearDipoleReadoutBlock(hidden_irreps, dipole_only=True)) - -# for i in range(num_interactions - 1): -# if i == num_interactions - 2: -# assert ( -# len(hidden_irreps) > 1 -# ), "To predict dipoles use at least l=1 hidden_irreps" -# hidden_irreps_out = str( -# hidden_irreps[1] -# ) # Select only l=1 vectors for last layer -# else: -# hidden_irreps_out = hidden_irreps -# inter = interaction_cls( -# node_attrs_irreps=node_attr_irreps, -# node_feats_irreps=hidden_irreps, -# edge_attrs_irreps=sh_irreps, -# edge_feats_irreps=edge_feats_irreps, -# target_irreps=interaction_irreps, -# hidden_irreps=hidden_irreps_out, -# avg_num_neighbors=avg_num_neighbors, -# radial_MLP=radial_MLP, -# ) -# self.interactions.append(inter) -# prod = EquivariantProductBasisBlock( -# node_feats_irreps=interaction_irreps, -# 
target_irreps=hidden_irreps_out, -# correlation=correlation, -# num_elements=num_elements, -# use_sc=True, -# ) -# self.products.append(prod) -# if i == num_interactions - 2: -# self.readouts.append( -# NonLinearDipoleReadoutBlock( -# hidden_irreps_out, MLP_irreps, gate, dipole_only=True -# ) -# ) -# else: -# self.readouts.append( -# LinearDipoleReadoutBlock(hidden_irreps, dipole_only=True) -# ) - -# def forward( -# self, -# data: Dict[str, torch.Tensor], -# training: bool = False, # pylint: disable=W0613 -# compute_force: bool = False, -# compute_virials: bool = False, -# compute_stress: bool = False, -# compute_displacement: bool = False, -# ) -> Dict[str, Optional[torch.Tensor]]: -# assert compute_force is False -# assert compute_virials is False -# assert compute_stress is False -# assert compute_displacement is False -# # Setup -# data["node_attrs"].requires_grad_(True) -# data["positions"].requires_grad_(True) -# num_graphs = data["ptr"].numel() - 1 - -# # Embeddings -# node_feats = self.node_embedding(data["node_attrs"]) -# vectors, lengths = get_edge_vectors_and_lengths( -# positions=data["positions"], -# edge_index=data["edge_index"], -# shifts=data["shifts"], -# ) -# edge_attrs = self.spherical_harmonics(vectors) -# edge_feats = self.radial_embedding( -# lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers -# ) - -# # Interactions -# dipoles = [] -# for interaction, product, readout in zip( -# self.interactions, self.products, self.readouts -# ): -# node_feats, sc = interaction( -# node_attrs=data["node_attrs"], -# node_feats=node_feats, -# edge_attrs=edge_attrs, -# edge_feats=edge_feats, -# edge_index=data["edge_index"], -# ) -# node_feats = product( -# node_feats=node_feats, -# sc=sc, -# node_attrs=data["node_attrs"], -# ) -# node_dipoles = readout(node_feats).squeeze(-1) # [n_nodes,3] -# dipoles.append(node_dipoles) - -# # Compute the dipoles -# contributions_dipoles = torch.stack( -# dipoles, dim=-1 -# ) # [n_nodes,3,n_contributions] -# atomic_dipoles = torch.sum(contributions_dipoles, dim=-1) # [n_nodes,3] -# total_dipole = scatter_sum( -# src=atomic_dipoles, -# index=data["batch"], -# dim=0, -# dim_size=num_graphs, -# ) # [n_graphs,3] -# baseline = compute_fixed_charge_dipole( -# charges=data["charges"], -# positions=data["positions"], -# batch=data["batch"], -# num_graphs=num_graphs, -# ) # [n_graphs,3] -# total_dipole = total_dipole + baseline - -# output = { -# "dipole": total_dipole, -# "atomic_dipoles": atomic_dipoles, -# } -# return output - - -# @compile_mode("script") -# class EnergyDipolesMACE(torch.nn.Module): -# def __init__( -# self, -# r_max: float, -# num_bessel: int, -# num_polynomial_cutoff: int, -# max_ell: int, -# interaction_cls: Type[InteractionBlock], -# interaction_cls_first: Type[InteractionBlock], -# num_interactions: int, -# num_elements: int, -# hidden_irreps: o3.Irreps, -# MLP_irreps: o3.Irreps, -# avg_num_neighbors: float, -# atomic_numbers: List[int], -# correlation: int, -# gate: Optional[Callable], -# atomic_energies: Optional[np.ndarray], -# radial_MLP: Optional[List[int]] = None, -# ): -# super().__init__() -# self.register_buffer( -# "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) -# ) -# self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.float64)) -# self.register_buffer( -# "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) -# ) -# # Embedding -# node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) -# node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 
1))]) -# self.node_embedding = LinearNodeEmbeddingBlock( -# irreps_in=node_attr_irreps, irreps_out=node_feats_irreps -# ) -# self.radial_embedding = RadialEmbeddingBlock( -# r_max=r_max, -# num_bessel=num_bessel, -# num_polynomial_cutoff=num_polynomial_cutoff, -# ) -# edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") - -# sh_irreps = o3.Irreps.spherical_harmonics(max_ell) -# num_features = hidden_irreps.count(o3.Irrep(0, 1)) -# interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() -# self.spherical_harmonics = o3.SphericalHarmonics( -# sh_irreps, normalize=True, normalization="component" -# ) -# if radial_MLP is None: -# radial_MLP = [64, 64, 64] -# # Interactions and readouts -# self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) - -# inter = interaction_cls_first( -# node_attrs_irreps=node_attr_irreps, -# node_feats_irreps=node_feats_irreps, -# edge_attrs_irreps=sh_irreps, -# edge_feats_irreps=edge_feats_irreps, -# target_irreps=interaction_irreps, -# hidden_irreps=hidden_irreps, -# avg_num_neighbors=avg_num_neighbors, -# radial_MLP=radial_MLP, -# ) -# self.interactions = torch.nn.ModuleList([inter]) - -# # Use the appropriate self connection at the first layer -# use_sc_first = False -# if "Residual" in str(interaction_cls_first): -# use_sc_first = True - -# node_feats_irreps_out = inter.target_irreps -# prod = EquivariantProductBasisBlock( -# node_feats_irreps=node_feats_irreps_out, -# target_irreps=hidden_irreps, -# correlation=correlation, -# num_elements=num_elements, -# use_sc=use_sc_first, -# ) -# self.products = torch.nn.ModuleList([prod]) - -# self.readouts = torch.nn.ModuleList() -# self.readouts.append(LinearDipoleReadoutBlock(hidden_irreps, dipole_only=False)) - -# for i in range(num_interactions - 1): -# if i == num_interactions - 2: -# assert ( -# len(hidden_irreps) > 1 -# ), "To predict dipoles use at least l=1 hidden_irreps" -# hidden_irreps_out = str( -# hidden_irreps[:2] -# ) # Select scalars and l=1 vectors for last layer -# else: -# hidden_irreps_out = hidden_irreps -# inter = interaction_cls( -# node_attrs_irreps=node_attr_irreps, -# node_feats_irreps=hidden_irreps, -# edge_attrs_irreps=sh_irreps, -# edge_feats_irreps=edge_feats_irreps, -# target_irreps=interaction_irreps, -# hidden_irreps=hidden_irreps_out, -# avg_num_neighbors=avg_num_neighbors, -# radial_MLP=radial_MLP, -# ) -# self.interactions.append(inter) -# prod = EquivariantProductBasisBlock( -# node_feats_irreps=interaction_irreps, -# target_irreps=hidden_irreps_out, -# correlation=correlation, -# num_elements=num_elements, -# use_sc=True, -# ) -# self.products.append(prod) -# if i == num_interactions - 2: -# self.readouts.append( -# NonLinearDipoleReadoutBlock( -# hidden_irreps_out, MLP_irreps, gate, dipole_only=False -# ) -# ) -# else: -# self.readouts.append( -# LinearDipoleReadoutBlock(hidden_irreps, dipole_only=False) -# ) - -# def forward( -# self, -# data: Dict[str, torch.Tensor], -# training: bool = False, -# compute_force: bool = True, -# compute_virials: bool = False, -# compute_stress: bool = False, -# compute_displacement: bool = False, -# ) -> Dict[str, Optional[torch.Tensor]]: -# # Setup -# data["node_attrs"].requires_grad_(True) -# data["positions"].requires_grad_(True) -# num_graphs = data["ptr"].numel() - 1 -# displacement = torch.zeros( -# (num_graphs, 3, 3), -# dtype=data["positions"].dtype, -# device=data["positions"].device, -# ) -# if compute_virials or compute_stress or compute_displacement: -# ( -# data["positions"], -# data["shifts"], -# 
displacement, -# ) = get_symmetric_displacement( -# positions=data["positions"], -# unit_shifts=data["unit_shifts"], -# cell=data["cell"], -# edge_index=data["edge_index"], -# num_graphs=num_graphs, -# batch=data["batch"], -# ) - -# # Atomic energies -# node_e0 = self.atomic_energies_fn(data["node_attrs"]) -# e0 = scatter_sum( -# src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs -# ) # [n_graphs,] - -# # Embeddings -# node_feats = self.node_embedding(data["node_attrs"]) -# vectors, lengths = get_edge_vectors_and_lengths( -# positions=data["positions"], -# edge_index=data["edge_index"], -# shifts=data["shifts"], -# ) -# edge_attrs = self.spherical_harmonics(vectors) -# edge_feats = self.radial_embedding( -# lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers -# ) - -# # Interactions -# energies = [e0] -# node_energies_list = [node_e0] -# dipoles = [] -# for interaction, product, readout in zip( -# self.interactions, self.products, self.readouts -# ): -# node_feats, sc = interaction( -# node_attrs=data["node_attrs"], -# node_feats=node_feats, -# edge_attrs=edge_attrs, -# edge_feats=edge_feats, -# edge_index=data["edge_index"], -# ) -# node_feats = product( -# node_feats=node_feats, -# sc=sc, -# node_attrs=data["node_attrs"], -# ) -# node_out = readout(node_feats).squeeze(-1) # [n_nodes, ] -# # node_energies = readout(node_feats).squeeze(-1) # [n_nodes, ] -# node_energies = node_out[:, 0] -# energy = scatter_sum( -# src=node_energies, index=data["batch"], dim=-1, dim_size=num_graphs -# ) # [n_graphs,] -# energies.append(energy) -# node_dipoles = node_out[:, 1:] -# dipoles.append(node_dipoles) - -# # Compute the energies and dipoles -# contributions = torch.stack(energies, dim=-1) -# total_energy = torch.sum(contributions, dim=-1) # [n_graphs, ] -# node_energy_contributions = torch.stack(node_energies_list, dim=-1) -# node_energy = torch.sum(node_energy_contributions, dim=-1) # [n_nodes, ] -# contributions_dipoles = torch.stack( -# dipoles, dim=-1 -# ) # [n_nodes,3,n_contributions] -# atomic_dipoles = torch.sum(contributions_dipoles, dim=-1) # [n_nodes,3] -# total_dipole = scatter_sum( -# src=atomic_dipoles, -# index=data["batch"].unsqueeze(-1), -# dim=0, -# dim_size=num_graphs, -# ) # [n_graphs,3] -# baseline = compute_fixed_charge_dipole( -# charges=data["charges"], -# positions=data["positions"], -# batch=data["batch"], -# num_graphs=num_graphs, -# ) # [n_graphs,3] -# total_dipole = total_dipole + baseline - -# forces, virials, stress, _ = get_outputs( -# energy=total_energy, -# positions=data["positions"], -# displacement=displacement, -# cell=data["cell"], -# training=training, -# compute_force=compute_force, -# compute_virials=compute_virials, -# compute_stress=compute_stress, -# ) - -# output = { -# "energy": total_energy, -# "node_energy": node_energy, -# "contributions": contributions, -# "forces": forces, -# "virials": virials, -# "stress": stress, -# "displacement": displacement, -# "dipole": total_dipole, -# "atomic_dipoles": atomic_dipoles, -# } -# return output diff --git a/hydragnn/utils/mace_utils/modules/utils.py b/hydragnn/utils/mace_utils/modules/utils.py index a2e569475..b66cc897e 100644 --- a/hydragnn/utils/mace_utils/modules/utils.py +++ b/hydragnn/utils/mace_utils/modules/utils.py @@ -20,196 +20,6 @@ from .blocks import AtomicEnergiesBlock -# def compute_forces( -# energy: torch.Tensor, positions: torch.Tensor, training: bool = True -# ) -> torch.Tensor: -# grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)] -# gradient = 
torch.autograd.grad( -# outputs=[energy], # [n_graphs, ] -# inputs=[positions], # [n_nodes, 3] -# grad_outputs=grad_outputs, -# retain_graph=training, # Make sure the graph is not destroyed during training -# create_graph=training, # Create graph for second derivative -# allow_unused=True, # For complete dissociation turn to true -# )[ -# 0 -# ] # [n_nodes, 3] -# if gradient is None: -# return torch.zeros_like(positions) -# return -1 * gradient - - -# def compute_forces_virials( -# energy: torch.Tensor, -# positions: torch.Tensor, -# displacement: torch.Tensor, -# cell: torch.Tensor, -# training: bool = True, -# compute_stress: bool = False, -# ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: -# grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)] -# forces, virials = torch.autograd.grad( -# outputs=[energy], # [n_graphs, ] -# inputs=[positions, displacement], # [n_nodes, 3] -# grad_outputs=grad_outputs, -# retain_graph=training, # Make sure the graph is not destroyed during training -# create_graph=training, # Create graph for second derivative -# allow_unused=True, -# ) -# stress = torch.zeros_like(displacement) -# if compute_stress and virials is not None: -# cell = cell.view(-1, 3, 3) -# volume = torch.linalg.det(cell).abs().unsqueeze(-1) -# stress = virials / volume.view(-1, 1, 1) -# stress = torch.where(torch.abs(stress) < 1e10, stress, torch.zeros_like(stress)) -# if forces is None: -# forces = torch.zeros_like(positions) -# if virials is None: -# virials = torch.zeros((1, 3, 3)) - -# return -1 * forces, -1 * virials, stress - - -# def get_symmetric_displacement( -# positions: torch.Tensor, -# unit_shifts: torch.Tensor, -# cell: Optional[torch.Tensor], -# edge_index: torch.Tensor, -# num_graphs: int, -# batch: torch.Tensor, -# ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: -# if cell is None: -# cell = torch.zeros( -# num_graphs * 3, -# 3, -# dtype=positions.dtype, -# device=positions.device, -# ) -# sender = edge_index[0] -# displacement = torch.zeros( -# (num_graphs, 3, 3), -# dtype=positions.dtype, -# device=positions.device, -# ) -# displacement.requires_grad_(True) -# symmetric_displacement = 0.5 * ( -# displacement + displacement.transpose(-1, -2) -# ) # From https://github.com/mir-group/nequip -# positions = positions + torch.einsum( -# "be,bec->bc", positions, symmetric_displacement[batch] -# ) -# cell = cell.view(-1, 3, 3) -# cell = cell + torch.matmul(cell, symmetric_displacement) -# shifts = torch.einsum( -# "be,bec->bc", -# unit_shifts, -# cell[batch[sender]], -# ) -# return positions, shifts, displacement - - -# @torch.jit.unused -# def compute_hessians_vmap( -# forces: torch.Tensor, -# positions: torch.Tensor, -# ) -> torch.Tensor: -# forces_flatten = forces.view(-1) -# num_elements = forces_flatten.shape[0] - -# def get_vjp(v): -# return torch.autograd.grad( -# -1 * forces_flatten, -# positions, -# v, -# retain_graph=True, -# create_graph=False, -# allow_unused=False, -# ) - -# I_N = torch.eye(num_elements).to(forces.device) -# try: -# chunk_size = 1 if num_elements < 64 else 16 -# gradient = torch.vmap(get_vjp, in_dims=0, out_dims=0, chunk_size=chunk_size)( -# I_N -# )[0] -# except RuntimeError: -# gradient = compute_hessians_loop(forces, positions) -# if gradient is None: -# return torch.zeros((positions.shape[0], forces.shape[0], 3, 3)) -# return gradient - - -# @torch.jit.unused -# def compute_hessians_loop( -# forces: torch.Tensor, -# positions: torch.Tensor, -# ) -> torch.Tensor: - -# hessian = [] -# for 
grad_elem in forces.view(-1): -# hess_row = torch.autograd.grad( -# outputs=[-1 * grad_elem], -# inputs=[positions], -# grad_outputs=torch.ones_like(grad_elem), -# retain_graph=True, -# create_graph=False, -# allow_unused=False, -# )[0] -# hess_row = hess_row.detach() # this makes it very slow? but needs less memory -# if hess_row is None: -# hessian.append(torch.zeros_like(positions)) -# else: -# hessian.append(hess_row) -# hessian = torch.stack(hessian) -# return hessian - - -# def get_outputs( -# energy: torch.Tensor, -# positions: torch.Tensor, -# displacement: Optional[torch.Tensor], -# cell: torch.Tensor, -# training: bool = False, -# compute_force: bool = True, -# compute_virials: bool = True, -# compute_stress: bool = True, -# compute_hessian: bool = False, -# ) -> Tuple[ -# Optional[torch.Tensor], -# Optional[torch.Tensor], -# Optional[torch.Tensor], -# Optional[torch.Tensor], -# ]: -# if (compute_virials or compute_stress) and displacement is not None: -# # forces come for free -# forces, virials, stress = compute_forces_virials( -# energy=energy, -# positions=positions, -# displacement=displacement, -# cell=cell, -# compute_stress=compute_stress, -# training=(training or compute_hessian), -# ) -# elif compute_force: -# forces, virials, stress = ( -# compute_forces( -# energy=energy, -# positions=positions, -# training=(training or compute_hessian), -# ), -# None, -# None, -# ) -# else: -# forces, virials, stress = (None, None, None) -# if compute_hessian: -# assert forces is not None, "Forces must be computed to get the hessian" -# hessian = compute_hessians_vmap(forces, positions) -# else: -# hessian = None -# return forces, virials, stress, hessian - - def get_edge_vectors_and_lengths( positions: torch.Tensor, # [n_nodes, 3] edge_index: torch.Tensor, # [2, n_edges] @@ -251,164 +61,3 @@ def extract_invariant(x: torch.Tensor, num_layers: int, num_features: int, l_max ) out.append(x[:, -num_features:]) return torch.cat(out, dim=-1) - - -# def compute_mean_std_atomic_inter_energy( -# data_loader: torch.utils.data.DataLoader, -# atomic_energies: np.ndarray, -# ) -> Tuple[float, float]: -# atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies) - -# avg_atom_inter_es_list = [] - -# for batch in data_loader: -# node_e0 = atomic_energies_fn(batch.node_attrs) -# graph_e0s = scatter_sum( -# src=node_e0, index=batch.batch, dim=-1, dim_size=batch.num_graphs -# ) -# graph_sizes = batch.ptr[1:] - batch.ptr[:-1] -# avg_atom_inter_es_list.append( -# (batch.energy - graph_e0s) / graph_sizes -# ) # {[n_graphs], } - -# avg_atom_inter_es = torch.cat(avg_atom_inter_es_list) # [total_n_graphs] -# mean = to_numpy(torch.mean(avg_atom_inter_es)).item() -# std = to_numpy(torch.std(avg_atom_inter_es)).item() -# std = _check_non_zero(std) - -# return mean, std - - -# def _compute_mean_std_atomic_inter_energy( -# batch: Batch, -# atomic_energies_fn: AtomicEnergiesBlock, -# ) -> Tuple[torch.Tensor, torch.Tensor]: -# node_e0 = atomic_energies_fn(batch.node_attrs) -# graph_e0s = scatter_sum( -# src=node_e0, index=batch.batch, dim=-1, dim_size=batch.num_graphs -# ) -# graph_sizes = batch.ptr[1:] - batch.ptr[:-1] -# atom_energies = (batch.energy - graph_e0s) / graph_sizes -# return atom_energies - - -# def compute_mean_rms_energy_forces( -# data_loader: torch.utils.data.DataLoader, -# atomic_energies: np.ndarray, -# ) -> Tuple[float, float]: -# atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies) - -# atom_energy_list = [] -# forces_list = [] - -# for batch in 
data_loader: -# node_e0 = atomic_energies_fn(batch.node_attrs) -# graph_e0s = scatter_sum( -# src=node_e0, index=batch.batch, dim=-1, dim_size=batch.num_graphs -# ) -# graph_sizes = batch.ptr[1:] - batch.ptr[:-1] -# atom_energy_list.append( -# (batch.energy - graph_e0s) / graph_sizes -# ) # {[n_graphs], } -# forces_list.append(batch.forces) # {[n_graphs*n_atoms,3], } - -# atom_energies = torch.cat(atom_energy_list, dim=0) # [total_n_graphs] -# forces = torch.cat(forces_list, dim=0) # {[total_n_graphs*n_atoms,3], } - -# mean = to_numpy(torch.mean(atom_energies)).item() -# rms = to_numpy(torch.sqrt(torch.mean(torch.square(forces)))).item() -# rms = _check_non_zero(rms) - -# return mean, rms - - -# def _compute_mean_rms_energy_forces( -# batch: Batch, -# atomic_energies_fn: AtomicEnergiesBlock, -# ) -> Tuple[torch.Tensor, torch.Tensor]: -# node_e0 = atomic_energies_fn(batch.node_attrs) -# graph_e0s = scatter_sum( -# src=node_e0, index=batch.batch, dim=-1, dim_size=batch.num_graphs -# ) -# graph_sizes = batch.ptr[1:] - batch.ptr[:-1] -# atom_energies = (batch.energy - graph_e0s) / graph_sizes # {[n_graphs], } -# forces = batch.forces # {[n_graphs*n_atoms,3], } - -# return atom_energies, forces - - -# def compute_avg_num_neighbors(data_loader: torch.utils.data.DataLoader) -> float: -# num_neighbors = [] - -# for batch in data_loader: -# _, receivers = batch.edge_index -# _, counts = torch.unique(receivers, return_counts=True) -# num_neighbors.append(counts) - -# avg_num_neighbors = torch.mean( -# torch.cat(num_neighbors, dim=0).type(torch.get_default_dtype()) -# ) -# return to_numpy(avg_num_neighbors).item() - - -# def compute_statistics( -# data_loader: torch.utils.data.DataLoader, -# atomic_energies: np.ndarray, -# ) -> Tuple[float, float, float, float]: -# atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies) - -# atom_energy_list = [] -# forces_list = [] -# num_neighbors = [] - -# for batch in data_loader: -# node_e0 = atomic_energies_fn(batch.node_attrs) -# graph_e0s = scatter_sum( -# src=node_e0, index=batch.batch, dim=-1, dim_size=batch.num_graphs -# ) -# graph_sizes = batch.ptr[1:] - batch.ptr[:-1] -# atom_energy_list.append( -# (batch.energy - graph_e0s) / graph_sizes -# ) # {[n_graphs], } -# forces_list.append(batch.forces) # {[n_graphs*n_atoms,3], } - -# _, receivers = batch.edge_index -# _, counts = torch.unique(receivers, return_counts=True) -# num_neighbors.append(counts) - -# atom_energies = torch.cat(atom_energy_list, dim=0) # [total_n_graphs] -# forces = torch.cat(forces_list, dim=0) # {[total_n_graphs*n_atoms,3], } - -# mean = to_numpy(torch.mean(atom_energies)).item() -# rms = to_numpy(torch.sqrt(torch.mean(torch.square(forces)))).item() - -# avg_num_neighbors = torch.mean( -# torch.cat(num_neighbors, dim=0).type(torch.get_default_dtype()) -# ) - -# return to_numpy(avg_num_neighbors).item(), mean, rms - - -# def compute_rms_dipoles( -# data_loader: torch.utils.data.DataLoader, -# ) -> Tuple[float, float]: -# dipoles_list = [] -# for batch in data_loader: -# dipoles_list.append(batch.dipole) # {[n_graphs,3], } - -# dipoles = torch.cat(dipoles_list, dim=0) # {[total_n_graphs,3], } -# rms = to_numpy(torch.sqrt(torch.mean(torch.square(dipoles)))).item() -# rms = _check_non_zero(rms) -# return rms - - -# def compute_fixed_charge_dipole( -# charges: torch.Tensor, -# positions: torch.Tensor, -# batch: torch.Tensor, -# num_graphs: int, -# ) -> torch.Tensor: -# mu = positions * charges.unsqueeze(-1) / (1e-11 / c / e) # [N_atoms,3] -# return scatter_sum( -# src=mu, 
index=batch.unsqueeze(-1), dim=0, dim_size=num_graphs -# ) # [N_graphs,3] diff --git a/tests/test_forces_equivariant.py b/tests/test_forces_equivariant.py index 8573312ec..5ff770adc 100644 --- a/tests/test_forces_equivariant.py +++ b/tests/test_forces_equivariant.py @@ -26,3 +26,7 @@ def pytest_examples(example, model_type): # Check the file ran without error. assert return_code == 0 + + +# if __name__ == "__main__": +# pytest_examples("LennardJones", "MACE") From 5bb5cb6f0472af5750fbf376085843803785d7aa Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Fri, 27 Sep 2024 11:45:45 -0400 Subject: [PATCH 29/51] Commenting MACE utils in torch geometric (draft 2) --- hydragnn/utils/mace_utils/modules/__init__.py | 10 - hydragnn/utils/mace_utils/modules/utils.py | 2 +- .../tools/torch_geometric/__init__.py | 13 +- .../mace_utils/tools/torch_geometric/batch.py | 514 +++++----- .../mace_utils/tools/torch_geometric/data.py | 882 +++++++++--------- .../tools/torch_geometric/dataloader.py | 150 +-- .../tools/torch_geometric/dataset.py | 560 +++++------ .../mace_utils/tools/torch_geometric/seed.py | 26 +- .../mace_utils/tools/torch_geometric/utils.py | 78 +- 9 files changed, 1113 insertions(+), 1122 deletions(-) diff --git a/hydragnn/utils/mace_utils/modules/__init__.py b/hydragnn/utils/mace_utils/modules/__init__.py index 971aeb294..f461a338a 100644 --- a/hydragnn/utils/mace_utils/modules/__init__.py +++ b/hydragnn/utils/mace_utils/modules/__init__.py @@ -3,8 +3,6 @@ import torch from .blocks import ( - # AgnosticNonlinearInteractionBlock, - # AgnosticResidualNonlinearInteractionBlock, AtomicEnergiesBlock, EquivariantProductBasisBlock, InteractionBlock, @@ -15,9 +13,6 @@ NonLinearReadoutBlock, RadialEmbeddingBlock, RealAgnosticAttResidualInteractionBlock, - # RealAgnosticInteractionBlock, - # RealAgnosticResidualInteractionBlock, - # ResidualElementDependentInteractionBlock, ScaleShiftBlock, ) @@ -25,12 +20,7 @@ from .symmetric_contraction import SymmetricContraction interaction_classes: Dict[str, Type[InteractionBlock]] = { - # "AgnosticNonlinearInteractionBlock": AgnosticNonlinearInteractionBlock, - # "ResidualElementDependentInteractionBlock": ResidualElementDependentInteractionBlock, - # "AgnosticResidualNonlinearInteractionBlock": AgnosticResidualNonlinearInteractionBlock, - # "RealAgnosticResidualInteractionBlock": RealAgnosticResidualInteractionBlock, "RealAgnosticAttResidualInteractionBlock": RealAgnosticAttResidualInteractionBlock, - # "RealAgnosticInteractionBlock": RealAgnosticInteractionBlock, } gate_dict: Dict[str, Optional[Callable]] = { diff --git a/hydragnn/utils/mace_utils/modules/utils.py b/hydragnn/utils/mace_utils/modules/utils.py index b66cc897e..ae952ed25 100644 --- a/hydragnn/utils/mace_utils/modules/utils.py +++ b/hydragnn/utils/mace_utils/modules/utils.py @@ -15,7 +15,7 @@ from hydragnn.utils.mace_utils.tools import to_numpy from hydragnn.utils.mace_utils.tools.scatter import scatter_sum -from hydragnn.utils.mace_utils.tools.torch_geometric.batch import Batch +# from hydragnn.utils.mace_utils.tools.torch_geometric.batch import Batch from .blocks import AtomicEnergiesBlock diff --git a/hydragnn/utils/mace_utils/tools/torch_geometric/__init__.py b/hydragnn/utils/mace_utils/tools/torch_geometric/__init__.py index 486f0d09d..ea70a022f 100644 --- a/hydragnn/utils/mace_utils/tools/torch_geometric/__init__.py +++ b/hydragnn/utils/mace_utils/tools/torch_geometric/__init__.py @@ -1,7 +1,8 @@ -from .batch import Batch -from .data import Data -from .dataloader import DataLoader -from 
.dataset import Dataset -from .seed import seed_everything +# from .batch import Batch +# from .data import Data +# from .dataloader import DataLoader +# from .dataset import Dataset +# from .seed import seed_everything -__all__ = ["Batch", "Data", "Dataset", "DataLoader", "seed_everything"] +# __all__ = ["Batch", "Data", "Dataset", "DataLoader", "seed_everything"] +# __all__ = ["Data", "Dataset", "DataLoader", "seed_everything"] diff --git a/hydragnn/utils/mace_utils/tools/torch_geometric/batch.py b/hydragnn/utils/mace_utils/tools/torch_geometric/batch.py index be5ec9d0c..8dfd3ddc1 100644 --- a/hydragnn/utils/mace_utils/tools/torch_geometric/batch.py +++ b/hydragnn/utils/mace_utils/tools/torch_geometric/batch.py @@ -1,257 +1,257 @@ -from collections.abc import Sequence -from typing import List - -import numpy as np -import torch -from torch import Tensor - -from .data import Data -from .dataset import IndexType - - -class Batch(Data): - r"""A plain old python object modeling a batch of graphs as one big - (disconnected) graph. With :class:`torch_geometric.data.Data` being the - base class, all its methods can also be used here. - In addition, single graphs can be reconstructed via the assignment vector - :obj:`batch`, which maps each node to its respective graph identifier. - """ - - def __init__(self, batch=None, ptr=None, **kwargs): - super(Batch, self).__init__(**kwargs) - - for key, item in kwargs.items(): - if key == "num_nodes": - self.__num_nodes__ = item - else: - self[key] = item - - self.batch = batch - self.ptr = ptr - self.__data_class__ = Data - self.__slices__ = None - self.__cumsum__ = None - self.__cat_dims__ = None - self.__num_nodes_list__ = None - self.__num_graphs__ = None - - @classmethod - def from_data_list(cls, data_list, follow_batch=[], exclude_keys=[]): - r"""Constructs a batch object from a python list holding - :class:`torch_geometric.data.Data` objects. - The assignment vector :obj:`batch` is created on the fly. - Additionally, creates assignment batch vectors for each key in - :obj:`follow_batch`. - Will exclude any keys given in :obj:`exclude_keys`.""" - - keys = list(set(data_list[0].keys) - set(exclude_keys)) - assert "batch" not in keys and "ptr" not in keys - - batch = cls() - for key in data_list[0].__dict__.keys(): - if key[:2] != "__" and key[-2:] != "__": - batch[key] = None - - batch.__num_graphs__ = len(data_list) - batch.__data_class__ = data_list[0].__class__ - for key in keys + ["batch"]: - batch[key] = [] - batch["ptr"] = [0] - - device = None - slices = {key: [0] for key in keys} - cumsum = {key: [0] for key in keys} - cat_dims = {} - num_nodes_list = [] - for i, data in enumerate(data_list): - for key in keys: - item = data[key] - - # Increase values by `cumsum` value. - cum = cumsum[key][-1] - if isinstance(item, Tensor) and item.dtype != torch.bool: - if not isinstance(cum, int) or cum != 0: - item = item + cum - elif isinstance(item, (int, float)): - item = item + cum - - # Gather the size of the `cat` dimension. - size = 1 - cat_dim = data.__cat_dim__(key, data[key]) - # 0-dimensional tensors have no dimension along which to - # concatenate, so we set `cat_dim` to `None`. - if isinstance(item, Tensor) and item.dim() == 0: - cat_dim = None - cat_dims[key] = cat_dim - - # Add a batch dimension to items whose `cat_dim` is `None`: - if isinstance(item, Tensor) and cat_dim is None: - cat_dim = 0 # Concatenate along this new batch dimension. 
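The collation logic being commented out here follows the usual torch_geometric recipe: per-key tensors are concatenated along their `__cat_dim__`, while index-valued keys (`edge_index`, `face`) are first shifted by a running node count recorded via `__inc__`/`__cumsum__`. A minimal standalone sketch of that offset rule, assuming plain `torch` and an illustrative `collate_edge_indices` helper that is not part of this patch:

import torch

def collate_edge_indices(edge_indices, num_nodes_list):
    # Shift each graph's edge_index by the running node count, then
    # concatenate along the edge dimension -- the cumulative-sum step
    # that `from_data_list` performs per key.
    offset, shifted = 0, []
    for edge_index, num_nodes in zip(edge_indices, num_nodes_list):
        shifted.append(edge_index + offset)
        offset += num_nodes
    return torch.cat(shifted, dim=-1)

# Two 3-node triangles become one 6-node disconnected graph; the second
# triangle's node indices start at 3.
tri = torch.tensor([[0, 1, 2], [1, 2, 0]])
print(collate_edge_indices([tri, tri], [3, 3]))
# tensor([[0, 1, 2, 3, 4, 5],
#         [1, 2, 0, 4, 5, 3]])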
- item = item.unsqueeze(0) - device = item.device - elif isinstance(item, Tensor): - size = item.size(cat_dim) - device = item.device - - batch[key].append(item) # Append item to the attribute list. - - slices[key].append(size + slices[key][-1]) - inc = data.__inc__(key, item) - if isinstance(inc, (tuple, list)): - inc = torch.tensor(inc) - cumsum[key].append(inc + cumsum[key][-1]) - - if key in follow_batch: - if isinstance(size, Tensor): - for j, size in enumerate(size.tolist()): - tmp = f"{key}_{j}_batch" - batch[tmp] = [] if i == 0 else batch[tmp] - batch[tmp].append( - torch.full((size,), i, dtype=torch.long, device=device) - ) - else: - tmp = f"{key}_batch" - batch[tmp] = [] if i == 0 else batch[tmp] - batch[tmp].append( - torch.full((size,), i, dtype=torch.long, device=device) - ) - - if hasattr(data, "__num_nodes__"): - num_nodes_list.append(data.__num_nodes__) - else: - num_nodes_list.append(None) - - num_nodes = data.num_nodes - if num_nodes is not None: - item = torch.full((num_nodes,), i, dtype=torch.long, device=device) - batch.batch.append(item) - batch.ptr.append(batch.ptr[-1] + num_nodes) - - batch.batch = None if len(batch.batch) == 0 else batch.batch - batch.ptr = None if len(batch.ptr) == 1 else batch.ptr - batch.__slices__ = slices - batch.__cumsum__ = cumsum - batch.__cat_dims__ = cat_dims - batch.__num_nodes_list__ = num_nodes_list - - ref_data = data_list[0] - for key in batch.keys: - items = batch[key] - item = items[0] - cat_dim = ref_data.__cat_dim__(key, item) - cat_dim = 0 if cat_dim is None else cat_dim - if isinstance(item, Tensor): - batch[key] = torch.cat(items, cat_dim) - elif isinstance(item, (int, float)): - batch[key] = torch.tensor(items) - - # if torch_geometric.is_debug_enabled(): - # batch.debug() - - return batch.contiguous() - - def get_example(self, idx: int) -> Data: - r"""Reconstructs the :class:`torch_geometric.data.Data` object at index - :obj:`idx` from the batch object. - The batch object must have been created via :meth:`from_data_list` in - order to be able to reconstruct the initial objects.""" - - if self.__slices__ is None: - raise RuntimeError( - ( - "Cannot reconstruct data list from batch because the batch " - "object was not created using `Batch.from_data_list()`." - ) - ) - - data = self.__data_class__() - idx = self.num_graphs + idx if idx < 0 else idx - - for key in self.__slices__.keys(): - item = self[key] - if self.__cat_dims__[key] is None: - # The item was concatenated along a new batch dimension, - # so just index in that dimension: - item = item[idx] - else: - # Narrow the item based on the values in `__slices__`. 
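Reconstruction in `get_example` is the inverse of that collation: each key is narrowed to its recorded slice along the recorded concatenation dimension, and the cumulative offset is subtracted again. A sketch of the inverse, where `narrow_example` is a hypothetical helper and the slice/cumsum lists stand in for the `__slices__`/`__cumsum__` bookkeeping built by `from_data_list`:

import torch

def narrow_example(item, slices, cumsum, idx, dim):
    # Take graph `idx`'s span along `dim`, then undo the batching offset.
    start, end = slices[idx], slices[idx + 1]
    return item.narrow(dim, start, end - start) - cumsum[idx]

batched = torch.tensor([[0, 1, 2, 3, 4, 5], [1, 2, 0, 4, 5, 3]])
print(narrow_example(batched, slices=[0, 3, 6], cumsum=[0, 3], idx=1, dim=-1))
# tensor([[0, 1, 2],
#         [1, 2, 0]])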
- if isinstance(item, Tensor): - dim = self.__cat_dims__[key] - start = self.__slices__[key][idx] - end = self.__slices__[key][idx + 1] - item = item.narrow(dim, start, end - start) - else: - start = self.__slices__[key][idx] - end = self.__slices__[key][idx + 1] - item = item[start:end] - item = item[0] if len(item) == 1 else item - - # Decrease its value by `cumsum` value: - cum = self.__cumsum__[key][idx] - if isinstance(item, Tensor): - if not isinstance(cum, int) or cum != 0: - item = item - cum - elif isinstance(item, (int, float)): - item = item - cum - - data[key] = item - - if self.__num_nodes_list__[idx] is not None: - data.num_nodes = self.__num_nodes_list__[idx] - - return data - - def index_select(self, idx: IndexType) -> List[Data]: - if isinstance(idx, slice): - idx = list(range(self.num_graphs)[idx]) - - elif isinstance(idx, Tensor) and idx.dtype == torch.long: - idx = idx.flatten().tolist() - - elif isinstance(idx, Tensor) and idx.dtype == torch.bool: - idx = idx.flatten().nonzero(as_tuple=False).flatten().tolist() - - elif isinstance(idx, np.ndarray) and idx.dtype == np.int64: - idx = idx.flatten().tolist() - - elif isinstance(idx, np.ndarray) and idx.dtype == np.bool: - idx = idx.flatten().nonzero()[0].flatten().tolist() - - elif isinstance(idx, Sequence) and not isinstance(idx, str): - pass - - else: - raise IndexError( - f"Only integers, slices (':'), list, tuples, torch.tensor and " - f"np.ndarray of dtype long or bool are valid indices (got " - f"'{type(idx).__name__}')" - ) - - return [self.get_example(i) for i in idx] - - def __getitem__(self, idx): - if isinstance(idx, str): - return super(Batch, self).__getitem__(idx) - elif isinstance(idx, (int, np.integer)): - return self.get_example(idx) - else: - return self.index_select(idx) - - def to_data_list(self) -> List[Data]: - r"""Reconstructs the list of :class:`torch_geometric.data.Data` objects - from the batch object. - The batch object must have been created via :meth:`from_data_list` in - order to be able to reconstruct the initial objects.""" - return [self.get_example(i) for i in range(self.num_graphs)] - - @property - def num_graphs(self) -> int: - """Returns the number of graphs in the batch.""" - if self.__num_graphs__ is not None: - return self.__num_graphs__ - elif self.ptr is not None: - return self.ptr.numel() - 1 - elif self.batch is not None: - return int(self.batch.max()) + 1 - else: - raise ValueError +# from collections.abc import Sequence +# from typing import List + +# import numpy as np +# import torch +# from torch import Tensor + +# from .data import Data +# from .dataset import IndexType + + +# class Batch(Data): +# r"""A plain old python object modeling a batch of graphs as one big +# (disconnected) graph. With :class:`torch_geometric.data.Data` being the +# base class, all its methods can also be used here. +# In addition, single graphs can be reconstructed via the assignment vector +# :obj:`batch`, which maps each node to its respective graph identifier. 
+# """ + +# def __init__(self, batch=None, ptr=None, **kwargs): +# super(Batch, self).__init__(**kwargs) + +# for key, item in kwargs.items(): +# if key == "num_nodes": +# self.__num_nodes__ = item +# else: +# self[key] = item + +# self.batch = batch +# self.ptr = ptr +# self.__data_class__ = Data +# self.__slices__ = None +# self.__cumsum__ = None +# self.__cat_dims__ = None +# self.__num_nodes_list__ = None +# self.__num_graphs__ = None + +# @classmethod +# def from_data_list(cls, data_list, follow_batch=[], exclude_keys=[]): +# r"""Constructs a batch object from a python list holding +# :class:`torch_geometric.data.Data` objects. +# The assignment vector :obj:`batch` is created on the fly. +# Additionally, creates assignment batch vectors for each key in +# :obj:`follow_batch`. +# Will exclude any keys given in :obj:`exclude_keys`.""" + +# keys = list(set(data_list[0].keys) - set(exclude_keys)) +# assert "batch" not in keys and "ptr" not in keys + +# batch = cls() +# for key in data_list[0].__dict__.keys(): +# if key[:2] != "__" and key[-2:] != "__": +# batch[key] = None + +# batch.__num_graphs__ = len(data_list) +# batch.__data_class__ = data_list[0].__class__ +# for key in keys + ["batch"]: +# batch[key] = [] +# batch["ptr"] = [0] + +# device = None +# slices = {key: [0] for key in keys} +# cumsum = {key: [0] for key in keys} +# cat_dims = {} +# num_nodes_list = [] +# for i, data in enumerate(data_list): +# for key in keys: +# item = data[key] + +# # Increase values by `cumsum` value. +# cum = cumsum[key][-1] +# if isinstance(item, Tensor) and item.dtype != torch.bool: +# if not isinstance(cum, int) or cum != 0: +# item = item + cum +# elif isinstance(item, (int, float)): +# item = item + cum + +# # Gather the size of the `cat` dimension. +# size = 1 +# cat_dim = data.__cat_dim__(key, data[key]) +# # 0-dimensional tensors have no dimension along which to +# # concatenate, so we set `cat_dim` to `None`. +# if isinstance(item, Tensor) and item.dim() == 0: +# cat_dim = None +# cat_dims[key] = cat_dim + +# # Add a batch dimension to items whose `cat_dim` is `None`: +# if isinstance(item, Tensor) and cat_dim is None: +# cat_dim = 0 # Concatenate along this new batch dimension. +# item = item.unsqueeze(0) +# device = item.device +# elif isinstance(item, Tensor): +# size = item.size(cat_dim) +# device = item.device + +# batch[key].append(item) # Append item to the attribute list. 
+ +# slices[key].append(size + slices[key][-1]) +# inc = data.__inc__(key, item) +# if isinstance(inc, (tuple, list)): +# inc = torch.tensor(inc) +# cumsum[key].append(inc + cumsum[key][-1]) + +# if key in follow_batch: +# if isinstance(size, Tensor): +# for j, size in enumerate(size.tolist()): +# tmp = f"{key}_{j}_batch" +# batch[tmp] = [] if i == 0 else batch[tmp] +# batch[tmp].append( +# torch.full((size,), i, dtype=torch.long, device=device) +# ) +# else: +# tmp = f"{key}_batch" +# batch[tmp] = [] if i == 0 else batch[tmp] +# batch[tmp].append( +# torch.full((size,), i, dtype=torch.long, device=device) +# ) + +# if hasattr(data, "__num_nodes__"): +# num_nodes_list.append(data.__num_nodes__) +# else: +# num_nodes_list.append(None) + +# num_nodes = data.num_nodes +# if num_nodes is not None: +# item = torch.full((num_nodes,), i, dtype=torch.long, device=device) +# batch.batch.append(item) +# batch.ptr.append(batch.ptr[-1] + num_nodes) + +# batch.batch = None if len(batch.batch) == 0 else batch.batch +# batch.ptr = None if len(batch.ptr) == 1 else batch.ptr +# batch.__slices__ = slices +# batch.__cumsum__ = cumsum +# batch.__cat_dims__ = cat_dims +# batch.__num_nodes_list__ = num_nodes_list + +# ref_data = data_list[0] +# for key in batch.keys: +# items = batch[key] +# item = items[0] +# cat_dim = ref_data.__cat_dim__(key, item) +# cat_dim = 0 if cat_dim is None else cat_dim +# if isinstance(item, Tensor): +# batch[key] = torch.cat(items, cat_dim) +# elif isinstance(item, (int, float)): +# batch[key] = torch.tensor(items) + +# # if torch_geometric.is_debug_enabled(): +# # batch.debug() + +# return batch.contiguous() + +# def get_example(self, idx: int) -> Data: +# r"""Reconstructs the :class:`torch_geometric.data.Data` object at index +# :obj:`idx` from the batch object. +# The batch object must have been created via :meth:`from_data_list` in +# order to be able to reconstruct the initial objects.""" + +# if self.__slices__ is None: +# raise RuntimeError( +# ( +# "Cannot reconstruct data list from batch because the batch " +# "object was not created using `Batch.from_data_list()`." +# ) +# ) + +# data = self.__data_class__() +# idx = self.num_graphs + idx if idx < 0 else idx + +# for key in self.__slices__.keys(): +# item = self[key] +# if self.__cat_dims__[key] is None: +# # The item was concatenated along a new batch dimension, +# # so just index in that dimension: +# item = item[idx] +# else: +# # Narrow the item based on the values in `__slices__`. 
+# if isinstance(item, Tensor): +# dim = self.__cat_dims__[key] +# start = self.__slices__[key][idx] +# end = self.__slices__[key][idx + 1] +# item = item.narrow(dim, start, end - start) +# else: +# start = self.__slices__[key][idx] +# end = self.__slices__[key][idx + 1] +# item = item[start:end] +# item = item[0] if len(item) == 1 else item + +# # Decrease its value by `cumsum` value: +# cum = self.__cumsum__[key][idx] +# if isinstance(item, Tensor): +# if not isinstance(cum, int) or cum != 0: +# item = item - cum +# elif isinstance(item, (int, float)): +# item = item - cum + +# data[key] = item + +# if self.__num_nodes_list__[idx] is not None: +# data.num_nodes = self.__num_nodes_list__[idx] + +# return data + +# def index_select(self, idx: IndexType) -> List[Data]: +# if isinstance(idx, slice): +# idx = list(range(self.num_graphs)[idx]) + +# elif isinstance(idx, Tensor) and idx.dtype == torch.long: +# idx = idx.flatten().tolist() + +# elif isinstance(idx, Tensor) and idx.dtype == torch.bool: +# idx = idx.flatten().nonzero(as_tuple=False).flatten().tolist() + +# elif isinstance(idx, np.ndarray) and idx.dtype == np.int64: +# idx = idx.flatten().tolist() + +# elif isinstance(idx, np.ndarray) and idx.dtype == np.bool: +# idx = idx.flatten().nonzero()[0].flatten().tolist() + +# elif isinstance(idx, Sequence) and not isinstance(idx, str): +# pass + +# else: +# raise IndexError( +# f"Only integers, slices (':'), list, tuples, torch.tensor and " +# f"np.ndarray of dtype long or bool are valid indices (got " +# f"'{type(idx).__name__}')" +# ) + +# return [self.get_example(i) for i in idx] + +# def __getitem__(self, idx): +# if isinstance(idx, str): +# return super(Batch, self).__getitem__(idx) +# elif isinstance(idx, (int, np.integer)): +# return self.get_example(idx) +# else: +# return self.index_select(idx) + +# def to_data_list(self) -> List[Data]: +# r"""Reconstructs the list of :class:`torch_geometric.data.Data` objects +# from the batch object. +# The batch object must have been created via :meth:`from_data_list` in +# order to be able to reconstruct the initial objects.""" +# return [self.get_example(i) for i in range(self.num_graphs)] + +# @property +# def num_graphs(self) -> int: +# """Returns the number of graphs in the batch.""" +# if self.__num_graphs__ is not None: +# return self.__num_graphs__ +# elif self.ptr is not None: +# return self.ptr.numel() - 1 +# elif self.batch is not None: +# return int(self.batch.max()) + 1 +# else: +# raise ValueError diff --git a/hydragnn/utils/mace_utils/tools/torch_geometric/data.py b/hydragnn/utils/mace_utils/tools/torch_geometric/data.py index 4e1ab3084..6fdd25d47 100644 --- a/hydragnn/utils/mace_utils/tools/torch_geometric/data.py +++ b/hydragnn/utils/mace_utils/tools/torch_geometric/data.py @@ -1,441 +1,441 @@ -import collections -import copy -import re - -import torch - -# from ..utils.num_nodes import maybe_num_nodes - -__num_nodes_warn_msg__ = ( - "The number of nodes in your data object can only be inferred by its {} " - "indices, and hence may result in unexpected batch-wise behavior, e.g., " - "in case there exists isolated nodes. Please consider explicitly setting " - "the number of nodes for this data object by assigning it to " - "data.num_nodes." 
-) - - -def size_repr(key, item, indent=0): - indent_str = " " * indent - if torch.is_tensor(item) and item.dim() == 0: - out = item.item() - elif torch.is_tensor(item): - out = str(list(item.size())) - elif isinstance(item, list) or isinstance(item, tuple): - out = str([len(item)]) - elif isinstance(item, dict): - lines = [indent_str + size_repr(k, v, 2) for k, v in item.items()] - out = "{\n" + ",\n".join(lines) + "\n" + indent_str + "}" - elif isinstance(item, str): - out = f'"{item}"' - else: - out = str(item) - - return f"{indent_str}{key}={out}" - - -class Data(object): - r"""A plain old python object modeling a single graph with various - (optional) attributes: - - Args: - x (Tensor, optional): Node feature matrix with shape :obj:`[num_nodes, - num_node_features]`. (default: :obj:`None`) - edge_index (LongTensor, optional): Graph connectivity in COO format - with shape :obj:`[2, num_edges]`. (default: :obj:`None`) - edge_attr (Tensor, optional): Edge feature matrix with shape - :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`) - y (Tensor, optional): Graph or node targets with arbitrary shape. - (default: :obj:`None`) - pos (Tensor, optional): Node position matrix with shape - :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`) - normal (Tensor, optional): Normal vector matrix with shape - :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`) - face (LongTensor, optional): Face adjacency matrix with shape - :obj:`[3, num_faces]`. (default: :obj:`None`) - - The data object is not restricted to these attributes and can be extended - by any other additional data. - - Example:: - - data = Data(x=x, edge_index=edge_index) - data.train_idx = torch.tensor([...], dtype=torch.long) - data.test_mask = torch.tensor([...], dtype=torch.bool) - """ - - def __init__( - self, - x=None, - edge_index=None, - edge_attr=None, - y=None, - pos=None, - normal=None, - face=None, - **kwargs, - ): - self.x = x - self.edge_index = edge_index - self.edge_attr = edge_attr - self.y = y - self.pos = pos - self.normal = normal - self.face = face - for key, item in kwargs.items(): - if key == "num_nodes": - self.__num_nodes__ = item - else: - self[key] = item - - if edge_index is not None and edge_index.dtype != torch.long: - raise ValueError( - ( - f"Argument `edge_index` needs to be of type `torch.long` but " - f"found type `{edge_index.dtype}`." - ) - ) - - if face is not None and face.dtype != torch.long: - raise ValueError( - ( - f"Argument `face` needs to be of type `torch.long` but found " - f"type `{face.dtype}`." 
- ) - ) - - @classmethod - def from_dict(cls, dictionary): - r"""Creates a data object from a python dictionary.""" - data = cls() - - for key, item in dictionary.items(): - data[key] = item - - return data - - def to_dict(self): - return {key: item for key, item in self} - - def to_namedtuple(self): - keys = self.keys - DataTuple = collections.namedtuple("DataTuple", keys) - return DataTuple(*[self[key] for key in keys]) - - def __getitem__(self, key): - r"""Gets the data of the attribute :obj:`key`.""" - return getattr(self, key, None) - - def __setitem__(self, key, value): - """Sets the attribute :obj:`key` to :obj:`value`.""" - setattr(self, key, value) - - def __delitem__(self, key): - r"""Delete the data of the attribute :obj:`key`.""" - return delattr(self, key) - - @property - def keys(self): - r"""Returns all names of graph attributes.""" - keys = [key for key in self.__dict__.keys() if self[key] is not None] - keys = [key for key in keys if key[:2] != "__" and key[-2:] != "__"] - return keys - - def __len__(self): - r"""Returns the number of all present attributes.""" - return len(self.keys) - - def __contains__(self, key): - r"""Returns :obj:`True`, if the attribute :obj:`key` is present in the - data.""" - return key in self.keys - - def __iter__(self): - r"""Iterates over all present attributes in the data, yielding their - attribute names and content.""" - for key in sorted(self.keys): - yield key, self[key] - - def __call__(self, *keys): - r"""Iterates over all attributes :obj:`*keys` in the data, yielding - their attribute names and content. - If :obj:`*keys` is not given this method will iterative over all - present attributes.""" - for key in sorted(self.keys) if not keys else keys: - if key in self: - yield key, self[key] - - def __cat_dim__(self, key, value): - r"""Returns the dimension for which :obj:`value` of attribute - :obj:`key` will get concatenated when creating batches. - - .. note:: - - This method is for internal use only, and should only be overridden - if the batch concatenation process is corrupted for a specific data - attribute. - """ - if bool(re.search("(index|face)", key)): - return -1 - return 0 - - def __inc__(self, key, value): - r"""Returns the incremental count to cumulatively increase the value - of the next attribute of :obj:`key` when creating batches. - - .. note:: - - This method is for internal use only, and should only be overridden - if the batch concatenation process is corrupted for a specific data - attribute. - """ - # Only `*index*` and `*face*` attributes should be cumulatively summed - # up when creating batches. - return self.num_nodes if bool(re.search("(index|face)", key)) else 0 - - @property - def num_nodes(self): - r"""Returns or sets the number of nodes in the graph. - - .. note:: - The number of nodes in your data object is typically automatically - inferred, *e.g.*, when node features :obj:`x` are present. - In some cases however, a graph may only be given by its edge - indices :obj:`edge_index`. - PyTorch Geometric then *guesses* the number of nodes - according to :obj:`edge_index.max().item() + 1`, but in case there - exists isolated nodes, this number has not to be correct and can - therefore result in unexpected batch-wise behavior. - Thus, we recommend to set the number of nodes in your data object - explicitly via :obj:`data.num_nodes = ...`. - You will be given a warning that requests you to do so. 
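The caveat in this docstring deserves a concrete case: isolated nodes are invisible to `edge_index`, so guessing the node count from it under-counts. A standalone illustration (plain `torch`, no PyG types assumed):

import torch

# A graph with three nodes, where node 2 has no edges.
edge_index = torch.tensor([[0, 1], [1, 0]])
guessed = int(edge_index.max()) + 1  # -> 2: the isolated node is missed
explicit = 3                         # what data.num_nodes should be set to
print(guessed, explicit)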
- """ - if hasattr(self, "__num_nodes__"): - return self.__num_nodes__ - for key, item in self("x", "pos", "normal", "batch"): - return item.size(self.__cat_dim__(key, item)) - if hasattr(self, "adj"): - return self.adj.size(0) - if hasattr(self, "adj_t"): - return self.adj_t.size(1) - # if self.face is not None: - # logging.warning(__num_nodes_warn_msg__.format("face")) - # return maybe_num_nodes(self.face) - # if self.edge_index is not None: - # logging.warning(__num_nodes_warn_msg__.format("edge")) - # return maybe_num_nodes(self.edge_index) - return None - - @num_nodes.setter - def num_nodes(self, num_nodes): - self.__num_nodes__ = num_nodes - - @property - def num_edges(self): - """ - Returns the number of edges in the graph. - For undirected graphs, this will return the number of bi-directional - edges, which is double the amount of unique edges. - """ - for key, item in self("edge_index", "edge_attr"): - return item.size(self.__cat_dim__(key, item)) - for key, item in self("adj", "adj_t"): - return item.nnz() - return None - - @property - def num_faces(self): - r"""Returns the number of faces in the mesh.""" - if self.face is not None: - return self.face.size(self.__cat_dim__("face", self.face)) - return None - - @property - def num_node_features(self): - r"""Returns the number of features per node in the graph.""" - if self.x is None: - return 0 - return 1 if self.x.dim() == 1 else self.x.size(1) - - @property - def num_features(self): - r"""Alias for :py:attr:`~num_node_features`.""" - return self.num_node_features - - @property - def num_edge_features(self): - r"""Returns the number of features per edge in the graph.""" - if self.edge_attr is None: - return 0 - return 1 if self.edge_attr.dim() == 1 else self.edge_attr.size(1) - - def __apply__(self, item, func): - if torch.is_tensor(item): - return func(item) - elif isinstance(item, (tuple, list)): - return [self.__apply__(v, func) for v in item] - elif isinstance(item, dict): - return {k: self.__apply__(v, func) for k, v in item.items()} - else: - return item - - def apply(self, func, *keys): - r"""Applies the function :obj:`func` to all tensor attributes - :obj:`*keys`. If :obj:`*keys` is not given, :obj:`func` is applied to - all present attributes. - """ - for key, item in self(*keys): - self[key] = self.__apply__(item, func) - return self - - def contiguous(self, *keys): - r"""Ensures a contiguous memory layout for all attributes :obj:`*keys`. - If :obj:`*keys` is not given, all present attributes are ensured to - have a contiguous memory layout.""" - return self.apply(lambda x: x.contiguous(), *keys) - - def to(self, device, *keys, **kwargs): - r"""Performs tensor dtype and/or device conversion to all attributes - :obj:`*keys`. - If :obj:`*keys` is not given, the conversion is applied to all present - attributes.""" - return self.apply(lambda x: x.to(device, **kwargs), *keys) - - def cpu(self, *keys): - r"""Copies all attributes :obj:`*keys` to CPU memory. - If :obj:`*keys` is not given, the conversion is applied to all present - attributes.""" - return self.apply(lambda x: x.cpu(), *keys) - - def cuda(self, device=None, non_blocking=False, *keys): - r"""Copies all attributes :obj:`*keys` to CUDA memory. 
- If :obj:`*keys` is not given, the conversion is applied to all present - attributes.""" - return self.apply( - lambda x: x.cuda(device=device, non_blocking=non_blocking), *keys - ) - - def clone(self): - r"""Performs a deep-copy of the data object.""" - return self.__class__.from_dict( - { - k: v.clone() if torch.is_tensor(v) else copy.deepcopy(v) - for k, v in self.__dict__.items() - } - ) - - def pin_memory(self, *keys): - r"""Copies all attributes :obj:`*keys` to pinned memory. - If :obj:`*keys` is not given, the conversion is applied to all present - attributes.""" - return self.apply(lambda x: x.pin_memory(), *keys) - - def debug(self): - if self.edge_index is not None: - if self.edge_index.dtype != torch.long: - raise RuntimeError( - ( - "Expected edge indices of dtype {}, but found dtype " " {}" - ).format(torch.long, self.edge_index.dtype) - ) - - if self.face is not None: - if self.face.dtype != torch.long: - raise RuntimeError( - ( - "Expected face indices of dtype {}, but found dtype " " {}" - ).format(torch.long, self.face.dtype) - ) - - if self.edge_index is not None: - if self.edge_index.dim() != 2 or self.edge_index.size(0) != 2: - raise RuntimeError( - ( - "Edge indices should have shape [2, num_edges] but found" - " shape {}" - ).format(self.edge_index.size()) - ) - - if self.edge_index is not None and self.num_nodes is not None: - if self.edge_index.numel() > 0: - min_index = self.edge_index.min() - max_index = self.edge_index.max() - else: - min_index = max_index = 0 - if min_index < 0 or max_index > self.num_nodes - 1: - raise RuntimeError( - ( - "Edge indices must lay in the interval [0, {}]" - " but found them in the interval [{}, {}]" - ).format(self.num_nodes - 1, min_index, max_index) - ) - - if self.face is not None: - if self.face.dim() != 2 or self.face.size(0) != 3: - raise RuntimeError( - ( - "Face indices should have shape [3, num_faces] but found" - " shape {}" - ).format(self.face.size()) - ) - - if self.face is not None and self.num_nodes is not None: - if self.face.numel() > 0: - min_index = self.face.min() - max_index = self.face.max() - else: - min_index = max_index = 0 - if min_index < 0 or max_index > self.num_nodes - 1: - raise RuntimeError( - ( - "Face indices must lay in the interval [0, {}]" - " but found them in the interval [{}, {}]" - ).format(self.num_nodes - 1, min_index, max_index) - ) - - if self.edge_index is not None and self.edge_attr is not None: - if self.edge_index.size(1) != self.edge_attr.size(0): - raise RuntimeError( - ( - "Edge indices and edge attributes hold a differing " - "number of edges, found {} and {}" - ).format(self.edge_index.size(), self.edge_attr.size()) - ) - - if self.x is not None and self.num_nodes is not None: - if self.x.size(0) != self.num_nodes: - raise RuntimeError( - ( - "Node features should hold {} elements in the first " - "dimension but found {}" - ).format(self.num_nodes, self.x.size(0)) - ) - - if self.pos is not None and self.num_nodes is not None: - if self.pos.size(0) != self.num_nodes: - raise RuntimeError( - ( - "Node positions should hold {} elements in the first " - "dimension but found {}" - ).format(self.num_nodes, self.pos.size(0)) - ) - - if self.normal is not None and self.num_nodes is not None: - if self.normal.size(0) != self.num_nodes: - raise RuntimeError( - ( - "Node normals should hold {} elements in the first " - "dimension but found {}" - ).format(self.num_nodes, self.normal.size(0)) - ) - - def __repr__(self): - cls = str(self.__class__.__name__) - has_dict = 
any([isinstance(item, dict) for _, item in self]) - - if not has_dict: - info = [size_repr(key, item) for key, item in self] - return "{}({})".format(cls, ", ".join(info)) - else: - info = [size_repr(key, item, indent=2) for key, item in self] - return "{}(\n{}\n)".format(cls, ",\n".join(info)) +# import collections +# import copy +# import re + +# import torch + +# # from ..utils.num_nodes import maybe_num_nodes + +# __num_nodes_warn_msg__ = ( +# "The number of nodes in your data object can only be inferred by its {} " +# "indices, and hence may result in unexpected batch-wise behavior, e.g., " +# "in case there exists isolated nodes. Please consider explicitly setting " +# "the number of nodes for this data object by assigning it to " +# "data.num_nodes." +# ) + + +# def size_repr(key, item, indent=0): +# indent_str = " " * indent +# if torch.is_tensor(item) and item.dim() == 0: +# out = item.item() +# elif torch.is_tensor(item): +# out = str(list(item.size())) +# elif isinstance(item, list) or isinstance(item, tuple): +# out = str([len(item)]) +# elif isinstance(item, dict): +# lines = [indent_str + size_repr(k, v, 2) for k, v in item.items()] +# out = "{\n" + ",\n".join(lines) + "\n" + indent_str + "}" +# elif isinstance(item, str): +# out = f'"{item}"' +# else: +# out = str(item) + +# return f"{indent_str}{key}={out}" + + +# class Data(object): +# r"""A plain old python object modeling a single graph with various +# (optional) attributes: + +# Args: +# x (Tensor, optional): Node feature matrix with shape :obj:`[num_nodes, +# num_node_features]`. (default: :obj:`None`) +# edge_index (LongTensor, optional): Graph connectivity in COO format +# with shape :obj:`[2, num_edges]`. (default: :obj:`None`) +# edge_attr (Tensor, optional): Edge feature matrix with shape +# :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`) +# y (Tensor, optional): Graph or node targets with arbitrary shape. +# (default: :obj:`None`) +# pos (Tensor, optional): Node position matrix with shape +# :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`) +# normal (Tensor, optional): Normal vector matrix with shape +# :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`) +# face (LongTensor, optional): Face adjacency matrix with shape +# :obj:`[3, num_faces]`. (default: :obj:`None`) + +# The data object is not restricted to these attributes and can be extended +# by any other additional data. + +# Example:: + +# data = Data(x=x, edge_index=edge_index) +# data.train_idx = torch.tensor([...], dtype=torch.long) +# data.test_mask = torch.tensor([...], dtype=torch.bool) +# """ + +# def __init__( +# self, +# x=None, +# edge_index=None, +# edge_attr=None, +# y=None, +# pos=None, +# normal=None, +# face=None, +# **kwargs, +# ): +# self.x = x +# self.edge_index = edge_index +# self.edge_attr = edge_attr +# self.y = y +# self.pos = pos +# self.normal = normal +# self.face = face +# for key, item in kwargs.items(): +# if key == "num_nodes": +# self.__num_nodes__ = item +# else: +# self[key] = item + +# if edge_index is not None and edge_index.dtype != torch.long: +# raise ValueError( +# ( +# f"Argument `edge_index` needs to be of type `torch.long` but " +# f"found type `{edge_index.dtype}`." +# ) +# ) + +# if face is not None and face.dtype != torch.long: +# raise ValueError( +# ( +# f"Argument `face` needs to be of type `torch.long` but found " +# f"type `{face.dtype}`." 
+# ) +# ) + +# @classmethod +# def from_dict(cls, dictionary): +# r"""Creates a data object from a python dictionary.""" +# data = cls() + +# for key, item in dictionary.items(): +# data[key] = item + +# return data + +# def to_dict(self): +# return {key: item for key, item in self} + +# def to_namedtuple(self): +# keys = self.keys +# DataTuple = collections.namedtuple("DataTuple", keys) +# return DataTuple(*[self[key] for key in keys]) + +# def __getitem__(self, key): +# r"""Gets the data of the attribute :obj:`key`.""" +# return getattr(self, key, None) + +# def __setitem__(self, key, value): +# """Sets the attribute :obj:`key` to :obj:`value`.""" +# setattr(self, key, value) + +# def __delitem__(self, key): +# r"""Delete the data of the attribute :obj:`key`.""" +# return delattr(self, key) + +# @property +# def keys(self): +# r"""Returns all names of graph attributes.""" +# keys = [key for key in self.__dict__.keys() if self[key] is not None] +# keys = [key for key in keys if key[:2] != "__" and key[-2:] != "__"] +# return keys + +# def __len__(self): +# r"""Returns the number of all present attributes.""" +# return len(self.keys) + +# def __contains__(self, key): +# r"""Returns :obj:`True`, if the attribute :obj:`key` is present in the +# data.""" +# return key in self.keys + +# def __iter__(self): +# r"""Iterates over all present attributes in the data, yielding their +# attribute names and content.""" +# for key in sorted(self.keys): +# yield key, self[key] + +# def __call__(self, *keys): +# r"""Iterates over all attributes :obj:`*keys` in the data, yielding +# their attribute names and content. +# If :obj:`*keys` is not given this method will iterative over all +# present attributes.""" +# for key in sorted(self.keys) if not keys else keys: +# if key in self: +# yield key, self[key] + +# def __cat_dim__(self, key, value): +# r"""Returns the dimension for which :obj:`value` of attribute +# :obj:`key` will get concatenated when creating batches. + +# .. note:: + +# This method is for internal use only, and should only be overridden +# if the batch concatenation process is corrupted for a specific data +# attribute. +# """ +# if bool(re.search("(index|face)", key)): +# return -1 +# return 0 + +# def __inc__(self, key, value): +# r"""Returns the incremental count to cumulatively increase the value +# of the next attribute of :obj:`key` when creating batches. + +# .. note:: + +# This method is for internal use only, and should only be overridden +# if the batch concatenation process is corrupted for a specific data +# attribute. +# """ +# # Only `*index*` and `*face*` attributes should be cumulatively summed +# # up when creating batches. +# return self.num_nodes if bool(re.search("(index|face)", key)) else 0 + +# @property +# def num_nodes(self): +# r"""Returns or sets the number of nodes in the graph. + +# .. note:: +# The number of nodes in your data object is typically automatically +# inferred, *e.g.*, when node features :obj:`x` are present. +# In some cases however, a graph may only be given by its edge +# indices :obj:`edge_index`. +# PyTorch Geometric then *guesses* the number of nodes +# according to :obj:`edge_index.max().item() + 1`, but in case there +# exists isolated nodes, this number has not to be correct and can +# therefore result in unexpected batch-wise behavior. +# Thus, we recommend to set the number of nodes in your data object +# explicitly via :obj:`data.num_nodes = ...`. +# You will be given a warning that requests you to do so. 
+# """ +# if hasattr(self, "__num_nodes__"): +# return self.__num_nodes__ +# for key, item in self("x", "pos", "normal", "batch"): +# return item.size(self.__cat_dim__(key, item)) +# if hasattr(self, "adj"): +# return self.adj.size(0) +# if hasattr(self, "adj_t"): +# return self.adj_t.size(1) +# # if self.face is not None: +# # logging.warning(__num_nodes_warn_msg__.format("face")) +# # return maybe_num_nodes(self.face) +# # if self.edge_index is not None: +# # logging.warning(__num_nodes_warn_msg__.format("edge")) +# # return maybe_num_nodes(self.edge_index) +# return None + +# @num_nodes.setter +# def num_nodes(self, num_nodes): +# self.__num_nodes__ = num_nodes + +# @property +# def num_edges(self): +# """ +# Returns the number of edges in the graph. +# For undirected graphs, this will return the number of bi-directional +# edges, which is double the amount of unique edges. +# """ +# for key, item in self("edge_index", "edge_attr"): +# return item.size(self.__cat_dim__(key, item)) +# for key, item in self("adj", "adj_t"): +# return item.nnz() +# return None + +# @property +# def num_faces(self): +# r"""Returns the number of faces in the mesh.""" +# if self.face is not None: +# return self.face.size(self.__cat_dim__("face", self.face)) +# return None + +# @property +# def num_node_features(self): +# r"""Returns the number of features per node in the graph.""" +# if self.x is None: +# return 0 +# return 1 if self.x.dim() == 1 else self.x.size(1) + +# @property +# def num_features(self): +# r"""Alias for :py:attr:`~num_node_features`.""" +# return self.num_node_features + +# @property +# def num_edge_features(self): +# r"""Returns the number of features per edge in the graph.""" +# if self.edge_attr is None: +# return 0 +# return 1 if self.edge_attr.dim() == 1 else self.edge_attr.size(1) + +# def __apply__(self, item, func): +# if torch.is_tensor(item): +# return func(item) +# elif isinstance(item, (tuple, list)): +# return [self.__apply__(v, func) for v in item] +# elif isinstance(item, dict): +# return {k: self.__apply__(v, func) for k, v in item.items()} +# else: +# return item + +# def apply(self, func, *keys): +# r"""Applies the function :obj:`func` to all tensor attributes +# :obj:`*keys`. If :obj:`*keys` is not given, :obj:`func` is applied to +# all present attributes. +# """ +# for key, item in self(*keys): +# self[key] = self.__apply__(item, func) +# return self + +# def contiguous(self, *keys): +# r"""Ensures a contiguous memory layout for all attributes :obj:`*keys`. +# If :obj:`*keys` is not given, all present attributes are ensured to +# have a contiguous memory layout.""" +# return self.apply(lambda x: x.contiguous(), *keys) + +# def to(self, device, *keys, **kwargs): +# r"""Performs tensor dtype and/or device conversion to all attributes +# :obj:`*keys`. +# If :obj:`*keys` is not given, the conversion is applied to all present +# attributes.""" +# return self.apply(lambda x: x.to(device, **kwargs), *keys) + +# def cpu(self, *keys): +# r"""Copies all attributes :obj:`*keys` to CPU memory. +# If :obj:`*keys` is not given, the conversion is applied to all present +# attributes.""" +# return self.apply(lambda x: x.cpu(), *keys) + +# def cuda(self, device=None, non_blocking=False, *keys): +# r"""Copies all attributes :obj:`*keys` to CUDA memory. 
+# If :obj:`*keys` is not given, the conversion is applied to all present +# attributes.""" +# return self.apply( +# lambda x: x.cuda(device=device, non_blocking=non_blocking), *keys +# ) + +# def clone(self): +# r"""Performs a deep-copy of the data object.""" +# return self.__class__.from_dict( +# { +# k: v.clone() if torch.is_tensor(v) else copy.deepcopy(v) +# for k, v in self.__dict__.items() +# } +# ) + +# def pin_memory(self, *keys): +# r"""Copies all attributes :obj:`*keys` to pinned memory. +# If :obj:`*keys` is not given, the conversion is applied to all present +# attributes.""" +# return self.apply(lambda x: x.pin_memory(), *keys) + +# def debug(self): +# if self.edge_index is not None: +# if self.edge_index.dtype != torch.long: +# raise RuntimeError( +# ( +# "Expected edge indices of dtype {}, but found dtype " " {}" +# ).format(torch.long, self.edge_index.dtype) +# ) + +# if self.face is not None: +# if self.face.dtype != torch.long: +# raise RuntimeError( +# ( +# "Expected face indices of dtype {}, but found dtype " " {}" +# ).format(torch.long, self.face.dtype) +# ) + +# if self.edge_index is not None: +# if self.edge_index.dim() != 2 or self.edge_index.size(0) != 2: +# raise RuntimeError( +# ( +# "Edge indices should have shape [2, num_edges] but found" +# " shape {}" +# ).format(self.edge_index.size()) +# ) + +# if self.edge_index is not None and self.num_nodes is not None: +# if self.edge_index.numel() > 0: +# min_index = self.edge_index.min() +# max_index = self.edge_index.max() +# else: +# min_index = max_index = 0 +# if min_index < 0 or max_index > self.num_nodes - 1: +# raise RuntimeError( +# ( +# "Edge indices must lay in the interval [0, {}]" +# " but found them in the interval [{}, {}]" +# ).format(self.num_nodes - 1, min_index, max_index) +# ) + +# if self.face is not None: +# if self.face.dim() != 2 or self.face.size(0) != 3: +# raise RuntimeError( +# ( +# "Face indices should have shape [3, num_faces] but found" +# " shape {}" +# ).format(self.face.size()) +# ) + +# if self.face is not None and self.num_nodes is not None: +# if self.face.numel() > 0: +# min_index = self.face.min() +# max_index = self.face.max() +# else: +# min_index = max_index = 0 +# if min_index < 0 or max_index > self.num_nodes - 1: +# raise RuntimeError( +# ( +# "Face indices must lay in the interval [0, {}]" +# " but found them in the interval [{}, {}]" +# ).format(self.num_nodes - 1, min_index, max_index) +# ) + +# if self.edge_index is not None and self.edge_attr is not None: +# if self.edge_index.size(1) != self.edge_attr.size(0): +# raise RuntimeError( +# ( +# "Edge indices and edge attributes hold a differing " +# "number of edges, found {} and {}" +# ).format(self.edge_index.size(), self.edge_attr.size()) +# ) + +# if self.x is not None and self.num_nodes is not None: +# if self.x.size(0) != self.num_nodes: +# raise RuntimeError( +# ( +# "Node features should hold {} elements in the first " +# "dimension but found {}" +# ).format(self.num_nodes, self.x.size(0)) +# ) + +# if self.pos is not None and self.num_nodes is not None: +# if self.pos.size(0) != self.num_nodes: +# raise RuntimeError( +# ( +# "Node positions should hold {} elements in the first " +# "dimension but found {}" +# ).format(self.num_nodes, self.pos.size(0)) +# ) + +# if self.normal is not None and self.num_nodes is not None: +# if self.normal.size(0) != self.num_nodes: +# raise RuntimeError( +# ( +# "Node normals should hold {} elements in the first " +# "dimension but found {}" +# ).format(self.num_nodes, 
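self.normal.size(0))
+# )

+# def __repr__(self):
+# cls = str(self.__class__.__name__)
+# has_dict = any([isinstance(item, dict) for _, item in self])

+# if not has_dict:
+# info = [size_repr(key, item) for key, item in self]
+# return "{}({})".format(cls, ", ".join(info))
+# else:
+# info = [size_repr(key, item, indent=2) for key, item in self]
+# return "{}(\n{}\n)".format(cls, ",\n".join(info))

The `__cat_dim__`/`__inc__` hooks in the `Data` class above are the heart of graph batching: when several graphs are merged into one large disconnected graph, `*index*`/`*face*` attributes are concatenated along the last dimension and shifted by the running node count, while all other attributes are concatenated along dimension 0. A minimal sketch of those semantics, not part of this patch and written against the upstream `torch_geometric` API that HydraGNN already depends on:

```python
import torch
from torch_geometric.data import Batch, Data

# Two small graphs: 3 nodes and 2 nodes.
g1 = Data(x=torch.randn(3, 4), edge_index=torch.tensor([[0, 1], [1, 2]]))
g2 = Data(x=torch.randn(2, 4), edge_index=torch.tensor([[0], [1]]))

batch = Batch.from_data_list([g1, g2])
print(batch.x.shape)     # torch.Size([5, 4]): `x` concatenated along dim 0
print(batch.edge_index)  # tensor([[0, 1, 3],
                         #         [1, 2, 4]]): g2's indices shifted by g1.num_nodes = 3
print(batch.batch)       # tensor([0, 0, 0, 1, 1]): node-to-graph assignment
```

The vendored `Batch.from_data_list` being commented out in this patch implements exactly this offsetting via the per-key `cumsum` values returned by `__inc__`.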
diff --git a/hydragnn/utils/mace_utils/tools/torch_geometric/dataloader.py b/hydragnn/utils/mace_utils/tools/torch_geometric/dataloader.py
index 396b7e728..81786ad4a 100644
--- a/hydragnn/utils/mace_utils/tools/torch_geometric/dataloader.py
+++ b/hydragnn/utils/mace_utils/tools/torch_geometric/dataloader.py
@@ -1,87 +1,87 @@
-from collections.abc import Mapping, Sequence
-from typing import List, Optional, Union
+# from collections.abc import Mapping, Sequence
+# from typing import List, Optional, Union

-import torch.utils.data
-from torch.utils.data.dataloader import default_collate
+# import torch.utils.data
+# from torch.utils.data.dataloader import default_collate

-from .batch import Batch
-from .data import Data
-from .dataset import Dataset
+# from .batch import Batch
+# from .data import Data
+# from .dataset import Dataset

-class Collater:
- def __init__(self, follow_batch, exclude_keys):
- self.follow_batch = follow_batch
- self.exclude_keys = exclude_keys
+# class Collater:
+# def __init__(self, follow_batch, exclude_keys):
+# self.follow_batch = follow_batch
+# self.exclude_keys = exclude_keys

- def __call__(self, batch):
- elem = batch[0]
- if isinstance(elem, Data):
- return Batch.from_data_list(
- batch,
- follow_batch=self.follow_batch,
- exclude_keys=self.exclude_keys,
- )
- elif isinstance(elem, torch.Tensor):
- return default_collate(batch)
- elif isinstance(elem, float):
- return torch.tensor(batch, dtype=torch.float)
- elif isinstance(elem, int):
- return torch.tensor(batch)
- elif isinstance(elem, str):
- return batch
- elif isinstance(elem, Mapping):
- return {key: self([data[key] for data in batch]) for key in elem}
- elif isinstance(elem, tuple) and hasattr(elem, "_fields"):
- return type(elem)(*(self(s) for s in zip(*batch)))
- elif isinstance(elem, Sequence) and not isinstance(elem, str):
- return [self(s) for s in zip(*batch)]
+# def __call__(self, batch):
+# elem = batch[0]
+# if isinstance(elem, Data):
+# return Batch.from_data_list(
+# batch,
+# follow_batch=self.follow_batch,
+# exclude_keys=self.exclude_keys,
+# )
+# elif isinstance(elem, torch.Tensor):
+# return default_collate(batch)
+# elif isinstance(elem, float):
+# return torch.tensor(batch, dtype=torch.float)
+# elif isinstance(elem, int):
+# return torch.tensor(batch)
+# elif isinstance(elem, str):
+# return batch
+# elif isinstance(elem, Mapping):
+# return {key: self([data[key] for data in batch]) for key in elem}
+# elif isinstance(elem, tuple) and hasattr(elem, "_fields"):
+# return type(elem)(*(self(s) for s in zip(*batch)))
+# elif isinstance(elem, Sequence) and not isinstance(elem, str):
+# return [self(s) for s in zip(*batch)]

- raise TypeError(f"DataLoader found invalid type: {type(elem)}")
+# raise TypeError(f"DataLoader found invalid type: {type(elem)}")

- def collate(self, batch): # Deprecated...
- return self(batch)
+# def collate(self, batch): # Deprecated...
+# return self(batch)

-class DataLoader(torch.utils.data.DataLoader):
- r"""A data loader which merges data objects from a
- :class:`torch_geometric.data.Dataset` to a mini-batch.
- Data objects can be either of type :class:`~torch_geometric.data.Data` or - :class:`~torch_geometric.data.HeteroData`. - Args: - dataset (Dataset): The dataset from which to load the data. - batch_size (int, optional): How many samples per batch to load. - (default: :obj:`1`) - shuffle (bool, optional): If set to :obj:`True`, the data will be - reshuffled at every epoch. (default: :obj:`False`) - follow_batch (List[str], optional): Creates assignment batch - vectors for each key in the list. (default: :obj:`None`) - exclude_keys (List[str], optional): Will exclude each key in the - list. (default: :obj:`None`) - **kwargs (optional): Additional arguments of - :class:`torch.utils.data.DataLoader`. - """ +# class DataLoader(torch.utils.data.DataLoader): +# r"""A data loader which merges data objects from a +# :class:`torch_geometric.data.Dataset` to a mini-batch. +# Data objects can be either of type :class:`~torch_geometric.data.Data` or +# :class:`~torch_geometric.data.HeteroData`. +# Args: +# dataset (Dataset): The dataset from which to load the data. +# batch_size (int, optional): How many samples per batch to load. +# (default: :obj:`1`) +# shuffle (bool, optional): If set to :obj:`True`, the data will be +# reshuffled at every epoch. (default: :obj:`False`) +# follow_batch (List[str], optional): Creates assignment batch +# vectors for each key in the list. (default: :obj:`None`) +# exclude_keys (List[str], optional): Will exclude each key in the +# list. (default: :obj:`None`) +# **kwargs (optional): Additional arguments of +# :class:`torch.utils.data.DataLoader`. +# """ - def __init__( - self, - dataset: Dataset, - batch_size: int = 1, - shuffle: bool = False, - follow_batch: Optional[List[str]] = [None], - exclude_keys: Optional[List[str]] = [None], - **kwargs, - ): - if "collate_fn" in kwargs: - del kwargs["collate_fn"] +# def __init__( +# self, +# dataset: Dataset, +# batch_size: int = 1, +# shuffle: bool = False, +# follow_batch: Optional[List[str]] = [None], +# exclude_keys: Optional[List[str]] = [None], +# **kwargs, +# ): +# if "collate_fn" in kwargs: +# del kwargs["collate_fn"] - # Save for PyTorch Lightning < 1.6: - self.follow_batch = follow_batch - self.exclude_keys = exclude_keys +# # Save for PyTorch Lightning < 1.6: +# self.follow_batch = follow_batch +# self.exclude_keys = exclude_keys - super().__init__( - dataset, - batch_size, - shuffle, - collate_fn=Collater(follow_batch, exclude_keys), - **kwargs, - ) +# super().__init__( +# dataset, +# batch_size, +# shuffle, +# collate_fn=Collater(follow_batch, exclude_keys), +# **kwargs, +# ) diff --git a/hydragnn/utils/mace_utils/tools/torch_geometric/dataset.py b/hydragnn/utils/mace_utils/tools/torch_geometric/dataset.py index b4aeb2be9..4e0add3c3 100644 --- a/hydragnn/utils/mace_utils/tools/torch_geometric/dataset.py +++ b/hydragnn/utils/mace_utils/tools/torch_geometric/dataset.py @@ -1,280 +1,280 @@ -import copy -import os.path as osp -import re -import warnings -from collections.abc import Sequence -from typing import Any, Callable, List, Optional, Tuple, Union - -import numpy as np -import torch.utils.data -from torch import Tensor - -from .data import Data -from .utils import makedirs - -IndexType = Union[slice, Tensor, np.ndarray, Sequence] - - -class Dataset(torch.utils.data.Dataset): - r"""Dataset base class for creating graph datasets. - See `here `__ for the accompanying tutorial. - - Args: - root (string, optional): Root directory where the dataset should be - saved. 
(optional: :obj:`None`) - transform (callable, optional): A function/transform that takes in an - :obj:`torch_geometric.data.Data` object and returns a transformed - version. The data object will be transformed before every access. - (default: :obj:`None`) - pre_transform (callable, optional): A function/transform that takes in - an :obj:`torch_geometric.data.Data` object and returns a - transformed version. The data object will be transformed before - being saved to disk. (default: :obj:`None`) - pre_filter (callable, optional): A function that takes in an - :obj:`torch_geometric.data.Data` object and returns a boolean - value, indicating whether the data object should be included in the - final dataset. (default: :obj:`None`) - """ - - @property - def raw_file_names(self) -> Union[str, List[str], Tuple]: - r"""The name of the files to find in the :obj:`self.raw_dir` folder in - order to skip the download.""" - raise NotImplementedError - - @property - def processed_file_names(self) -> Union[str, List[str], Tuple]: - r"""The name of the files to find in the :obj:`self.processed_dir` - folder in order to skip the processing.""" - raise NotImplementedError - - def download(self): - r"""Downloads the dataset to the :obj:`self.raw_dir` folder.""" - raise NotImplementedError - - def process(self): - r"""Processes the dataset to the :obj:`self.processed_dir` folder.""" - raise NotImplementedError - - def len(self) -> int: - raise NotImplementedError - - def get(self, idx: int) -> Data: - r"""Gets the data object at index :obj:`idx`.""" - raise NotImplementedError - - def __init__( - self, - root: Optional[str] = None, - transform: Optional[Callable] = None, - pre_transform: Optional[Callable] = None, - pre_filter: Optional[Callable] = None, - ): - super().__init__() - - if isinstance(root, str): - root = osp.expanduser(osp.normpath(root)) - - self.root = root - self.transform = transform - self.pre_transform = pre_transform - self.pre_filter = pre_filter - self._indices: Optional[Sequence] = None - - if "download" in self.__class__.__dict__.keys(): - self._download() - - if "process" in self.__class__.__dict__.keys(): - self._process() - - def indices(self) -> Sequence: - return range(self.len()) if self._indices is None else self._indices - - @property - def raw_dir(self) -> str: - return osp.join(self.root, "raw") - - @property - def processed_dir(self) -> str: - return osp.join(self.root, "processed") - - @property - def num_node_features(self) -> int: - r"""Returns the number of features per node in the dataset.""" - data = self[0] - if hasattr(data, "num_node_features"): - return data.num_node_features - raise AttributeError( - f"'{data.__class__.__name__}' object has no " - f"attribute 'num_node_features'" - ) - - @property - def num_features(self) -> int: - r"""Alias for :py:attr:`~num_node_features`.""" - return self.num_node_features - - @property - def num_edge_features(self) -> int: - r"""Returns the number of features per edge in the dataset.""" - data = self[0] - if hasattr(data, "num_edge_features"): - return data.num_edge_features - raise AttributeError( - f"'{data.__class__.__name__}' object has no " - f"attribute 'num_edge_features'" - ) - - @property - def raw_paths(self) -> List[str]: - r"""The filepaths to find in order to skip the download.""" - files = to_list(self.raw_file_names) - return [osp.join(self.raw_dir, f) for f in files] - - @property - def processed_paths(self) -> List[str]: - r"""The filepaths to find in the :obj:`self.processed_dir` - folder in order to skip 
the processing.""" - files = to_list(self.processed_file_names) - return [osp.join(self.processed_dir, f) for f in files] - - def _download(self): - if files_exist(self.raw_paths): # pragma: no cover - return - - makedirs(self.raw_dir) - self.download() - - def _process(self): - f = osp.join(self.processed_dir, "pre_transform.pt") - if osp.exists(f) and torch.load(f) != _repr(self.pre_transform): - warnings.warn( - f"The `pre_transform` argument differs from the one used in " - f"the pre-processed version of this dataset. If you want to " - f"make use of another pre-processing technique, make sure to " - f"sure to delete '{self.processed_dir}' first" - ) - - f = osp.join(self.processed_dir, "pre_filter.pt") - if osp.exists(f) and torch.load(f) != _repr(self.pre_filter): - warnings.warn( - "The `pre_filter` argument differs from the one used in the " - "pre-processed version of this dataset. If you want to make " - "use of another pre-fitering technique, make sure to delete " - "'{self.processed_dir}' first" - ) - - if files_exist(self.processed_paths): # pragma: no cover - return - - print("Processing...") - - makedirs(self.processed_dir) - self.process() - - path = osp.join(self.processed_dir, "pre_transform.pt") - torch.save(_repr(self.pre_transform), path) - path = osp.join(self.processed_dir, "pre_filter.pt") - torch.save(_repr(self.pre_filter), path) - - print("Done!") - - def __len__(self) -> int: - r"""The number of examples in the dataset.""" - return len(self.indices()) - - def __getitem__( - self, - idx: Union[int, np.integer, IndexType], - ) -> Union["Dataset", Data]: - r"""In case :obj:`idx` is of type integer, will return the data object - at index :obj:`idx` (and transforms it in case :obj:`transform` is - present). - In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a - tuple, a PyTorch :obj:`LongTensor` or a :obj:`BoolTensor`, or a numpy - :obj:`np.array`, will return a subset of the dataset at the specified - indices.""" - if ( - isinstance(idx, (int, np.integer)) - or (isinstance(idx, Tensor) and idx.dim() == 0) - or (isinstance(idx, np.ndarray) and np.isscalar(idx)) - ): - data = self.get(self.indices()[idx]) - data = data if self.transform is None else self.transform(data) - return data - - else: - return self.index_select(idx) - - def index_select(self, idx: IndexType) -> "Dataset": - indices = self.indices() - - if isinstance(idx, slice): - indices = indices[idx] - - elif isinstance(idx, Tensor) and idx.dtype == torch.long: - return self.index_select(idx.flatten().tolist()) - - elif isinstance(idx, Tensor) and idx.dtype == torch.bool: - idx = idx.flatten().nonzero(as_tuple=False) - return self.index_select(idx.flatten().tolist()) - - elif isinstance(idx, np.ndarray) and idx.dtype == np.int64: - return self.index_select(idx.flatten().tolist()) - - elif isinstance(idx, np.ndarray) and idx.dtype == np.bool: - idx = idx.flatten().nonzero()[0] - return self.index_select(idx.flatten().tolist()) - - elif isinstance(idx, Sequence) and not isinstance(idx, str): - indices = [indices[i] for i in idx] - - else: - raise IndexError( - f"Only integers, slices (':'), list, tuples, torch.tensor and " - f"np.ndarray of dtype long or bool are valid indices (got " - f"'{type(idx).__name__}')" - ) - - dataset = copy.copy(self) - dataset._indices = indices - return dataset - - def shuffle( - self, - return_perm: bool = False, - ) -> Union["Dataset", Tuple["Dataset", Tensor]]: - r"""Randomly shuffles the examples in the dataset. 
- - Args: - return_perm (bool, optional): If set to :obj:`True`, will return - the random permutation used to shuffle the dataset in addition. - (default: :obj:`False`) - """ - perm = torch.randperm(len(self)) - dataset = self.index_select(perm) - return (dataset, perm) if return_perm is True else dataset - - def __repr__(self) -> str: - arg_repr = str(len(self)) if len(self) > 1 else "" - return f"{self.__class__.__name__}({arg_repr})" - - -def to_list(value: Any) -> Sequence: - if isinstance(value, Sequence) and not isinstance(value, str): - return value - else: - return [value] - - -def files_exist(files: List[str]) -> bool: - # NOTE: We return `False` in case `files` is empty, leading to a - # re-processing of files on every instantiation. - return len(files) != 0 and all([osp.exists(f) for f in files]) - - -def _repr(obj: Any) -> str: - if obj is None: - return "None" - return re.sub("(<.*?)\\s.*(>)", r"\1\2", obj.__repr__()) +# import copy +# import os.path as osp +# import re +# import warnings +# from collections.abc import Sequence +# from typing import Any, Callable, List, Optional, Tuple, Union + +# import numpy as np +# import torch.utils.data +# from torch import Tensor + +# from .data import Data +# from .utils import makedirs + +# IndexType = Union[slice, Tensor, np.ndarray, Sequence] + + +# class Dataset(torch.utils.data.Dataset): +# r"""Dataset base class for creating graph datasets. +# See `here `__ for the accompanying tutorial. + +# Args: +# root (string, optional): Root directory where the dataset should be +# saved. (optional: :obj:`None`) +# transform (callable, optional): A function/transform that takes in an +# :obj:`torch_geometric.data.Data` object and returns a transformed +# version. The data object will be transformed before every access. +# (default: :obj:`None`) +# pre_transform (callable, optional): A function/transform that takes in +# an :obj:`torch_geometric.data.Data` object and returns a +# transformed version. The data object will be transformed before +# being saved to disk. (default: :obj:`None`) +# pre_filter (callable, optional): A function that takes in an +# :obj:`torch_geometric.data.Data` object and returns a boolean +# value, indicating whether the data object should be included in the +# final dataset. 
(default: :obj:`None`) +# """ + +# @property +# def raw_file_names(self) -> Union[str, List[str], Tuple]: +# r"""The name of the files to find in the :obj:`self.raw_dir` folder in +# order to skip the download.""" +# raise NotImplementedError + +# @property +# def processed_file_names(self) -> Union[str, List[str], Tuple]: +# r"""The name of the files to find in the :obj:`self.processed_dir` +# folder in order to skip the processing.""" +# raise NotImplementedError + +# def download(self): +# r"""Downloads the dataset to the :obj:`self.raw_dir` folder.""" +# raise NotImplementedError + +# def process(self): +# r"""Processes the dataset to the :obj:`self.processed_dir` folder.""" +# raise NotImplementedError + +# def len(self) -> int: +# raise NotImplementedError + +# def get(self, idx: int) -> Data: +# r"""Gets the data object at index :obj:`idx`.""" +# raise NotImplementedError + +# def __init__( +# self, +# root: Optional[str] = None, +# transform: Optional[Callable] = None, +# pre_transform: Optional[Callable] = None, +# pre_filter: Optional[Callable] = None, +# ): +# super().__init__() + +# if isinstance(root, str): +# root = osp.expanduser(osp.normpath(root)) + +# self.root = root +# self.transform = transform +# self.pre_transform = pre_transform +# self.pre_filter = pre_filter +# self._indices: Optional[Sequence] = None + +# if "download" in self.__class__.__dict__.keys(): +# self._download() + +# if "process" in self.__class__.__dict__.keys(): +# self._process() + +# def indices(self) -> Sequence: +# return range(self.len()) if self._indices is None else self._indices + +# @property +# def raw_dir(self) -> str: +# return osp.join(self.root, "raw") + +# @property +# def processed_dir(self) -> str: +# return osp.join(self.root, "processed") + +# @property +# def num_node_features(self) -> int: +# r"""Returns the number of features per node in the dataset.""" +# data = self[0] +# if hasattr(data, "num_node_features"): +# return data.num_node_features +# raise AttributeError( +# f"'{data.__class__.__name__}' object has no " +# f"attribute 'num_node_features'" +# ) + +# @property +# def num_features(self) -> int: +# r"""Alias for :py:attr:`~num_node_features`.""" +# return self.num_node_features + +# @property +# def num_edge_features(self) -> int: +# r"""Returns the number of features per edge in the dataset.""" +# data = self[0] +# if hasattr(data, "num_edge_features"): +# return data.num_edge_features +# raise AttributeError( +# f"'{data.__class__.__name__}' object has no " +# f"attribute 'num_edge_features'" +# ) + +# @property +# def raw_paths(self) -> List[str]: +# r"""The filepaths to find in order to skip the download.""" +# files = to_list(self.raw_file_names) +# return [osp.join(self.raw_dir, f) for f in files] + +# @property +# def processed_paths(self) -> List[str]: +# r"""The filepaths to find in the :obj:`self.processed_dir` +# folder in order to skip the processing.""" +# files = to_list(self.processed_file_names) +# return [osp.join(self.processed_dir, f) for f in files] + +# def _download(self): +# if files_exist(self.raw_paths): # pragma: no cover +# return + +# makedirs(self.raw_dir) +# self.download() + +# def _process(self): +# f = osp.join(self.processed_dir, "pre_transform.pt") +# if osp.exists(f) and torch.load(f) != _repr(self.pre_transform): +# warnings.warn( +# f"The `pre_transform` argument differs from the one used in " +# f"the pre-processed version of this dataset. 
If you want to " +# f"make use of another pre-processing technique, make sure to " +# f"sure to delete '{self.processed_dir}' first" +# ) + +# f = osp.join(self.processed_dir, "pre_filter.pt") +# if osp.exists(f) and torch.load(f) != _repr(self.pre_filter): +# warnings.warn( +# "The `pre_filter` argument differs from the one used in the " +# "pre-processed version of this dataset. If you want to make " +# "use of another pre-fitering technique, make sure to delete " +# "'{self.processed_dir}' first" +# ) + +# if files_exist(self.processed_paths): # pragma: no cover +# return + +# print("Processing...") + +# makedirs(self.processed_dir) +# self.process() + +# path = osp.join(self.processed_dir, "pre_transform.pt") +# torch.save(_repr(self.pre_transform), path) +# path = osp.join(self.processed_dir, "pre_filter.pt") +# torch.save(_repr(self.pre_filter), path) + +# print("Done!") + +# def __len__(self) -> int: +# r"""The number of examples in the dataset.""" +# return len(self.indices()) + +# def __getitem__( +# self, +# idx: Union[int, np.integer, IndexType], +# ) -> Union["Dataset", Data]: +# r"""In case :obj:`idx` is of type integer, will return the data object +# at index :obj:`idx` (and transforms it in case :obj:`transform` is +# present). +# In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a +# tuple, a PyTorch :obj:`LongTensor` or a :obj:`BoolTensor`, or a numpy +# :obj:`np.array`, will return a subset of the dataset at the specified +# indices.""" +# if ( +# isinstance(idx, (int, np.integer)) +# or (isinstance(idx, Tensor) and idx.dim() == 0) +# or (isinstance(idx, np.ndarray) and np.isscalar(idx)) +# ): +# data = self.get(self.indices()[idx]) +# data = data if self.transform is None else self.transform(data) +# return data + +# else: +# return self.index_select(idx) + +# def index_select(self, idx: IndexType) -> "Dataset": +# indices = self.indices() + +# if isinstance(idx, slice): +# indices = indices[idx] + +# elif isinstance(idx, Tensor) and idx.dtype == torch.long: +# return self.index_select(idx.flatten().tolist()) + +# elif isinstance(idx, Tensor) and idx.dtype == torch.bool: +# idx = idx.flatten().nonzero(as_tuple=False) +# return self.index_select(idx.flatten().tolist()) + +# elif isinstance(idx, np.ndarray) and idx.dtype == np.int64: +# return self.index_select(idx.flatten().tolist()) + +# elif isinstance(idx, np.ndarray) and idx.dtype == np.bool: +# idx = idx.flatten().nonzero()[0] +# return self.index_select(idx.flatten().tolist()) + +# elif isinstance(idx, Sequence) and not isinstance(idx, str): +# indices = [indices[i] for i in idx] + +# else: +# raise IndexError( +# f"Only integers, slices (':'), list, tuples, torch.tensor and " +# f"np.ndarray of dtype long or bool are valid indices (got " +# f"'{type(idx).__name__}')" +# ) + +# dataset = copy.copy(self) +# dataset._indices = indices +# return dataset + +# def shuffle( +# self, +# return_perm: bool = False, +# ) -> Union["Dataset", Tuple["Dataset", Tensor]]: +# r"""Randomly shuffles the examples in the dataset. + +# Args: +# return_perm (bool, optional): If set to :obj:`True`, will return +# the random permutation used to shuffle the dataset in addition. 
+# (default: :obj:`False`) +# """ +# perm = torch.randperm(len(self)) +# dataset = self.index_select(perm) +# return (dataset, perm) if return_perm is True else dataset + +# def __repr__(self) -> str: +# arg_repr = str(len(self)) if len(self) > 1 else "" +# return f"{self.__class__.__name__}({arg_repr})" + + +# def to_list(value: Any) -> Sequence: +# if isinstance(value, Sequence) and not isinstance(value, str): +# return value +# else: +# return [value] + + +# def files_exist(files: List[str]) -> bool: +# # NOTE: We return `False` in case `files` is empty, leading to a +# # re-processing of files on every instantiation. +# return len(files) != 0 and all([osp.exists(f) for f in files]) + + +# def _repr(obj: Any) -> str: +# if obj is None: +# return "None" +# return re.sub("(<.*?)\\s.*(>)", r"\1\2", obj.__repr__()) diff --git a/hydragnn/utils/mace_utils/tools/torch_geometric/seed.py b/hydragnn/utils/mace_utils/tools/torch_geometric/seed.py index be27fcaa1..2222226f4 100644 --- a/hydragnn/utils/mace_utils/tools/torch_geometric/seed.py +++ b/hydragnn/utils/mace_utils/tools/torch_geometric/seed.py @@ -1,17 +1,17 @@ -import random +# import random -import numpy as np -import torch +# import numpy as np +# import torch -def seed_everything(seed: int): - r"""Sets the seed for generating random numbers in :pytorch:`PyTorch`, - :obj:`numpy` and Python. +# def seed_everything(seed: int): +# r"""Sets the seed for generating random numbers in :pytorch:`PyTorch`, +# :obj:`numpy` and Python. - Args: - seed (int): The desired seed. - """ - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) +# Args: +# seed (int): The desired seed. +# """ +# random.seed(seed) +# np.random.seed(seed) +# torch.manual_seed(seed) +# torch.cuda.manual_seed_all(seed) diff --git a/hydragnn/utils/mace_utils/tools/torch_geometric/utils.py b/hydragnn/utils/mace_utils/tools/torch_geometric/utils.py index f53b8f809..b03a56ceb 100644 --- a/hydragnn/utils/mace_utils/tools/torch_geometric/utils.py +++ b/hydragnn/utils/mace_utils/tools/torch_geometric/utils.py @@ -1,54 +1,54 @@ -import os -import os.path as osp -import ssl -import urllib -import zipfile +# import os +# import os.path as osp +# import ssl +# import urllib +# import zipfile -def makedirs(dir): - os.makedirs(dir, exist_ok=True) +# def makedirs(dir): +# os.makedirs(dir, exist_ok=True) -def download_url(url, folder, log=True): - r"""Downloads the content of an URL to a specific folder. +# def download_url(url, folder, log=True): +# r"""Downloads the content of an URL to a specific folder. - Args: - url (string): The url. - folder (string): The folder. - log (bool, optional): If :obj:`False`, will not print anything to the - console. (default: :obj:`True`) - """ +# Args: +# url (string): The url. +# folder (string): The folder. +# log (bool, optional): If :obj:`False`, will not print anything to the +# console. 
(default: :obj:`True`) +# """ - filename = url.rpartition("/")[2].split("?")[0] - path = osp.join(folder, filename) +# filename = url.rpartition("/")[2].split("?")[0] +# path = osp.join(folder, filename) - if osp.exists(path): # pragma: no cover - if log: - print("Using exist file", filename) - return path +# if osp.exists(path): # pragma: no cover +# if log: +# print("Using exist file", filename) +# return path - if log: - print("Downloading", url) +# if log: +# print("Downloading", url) - makedirs(folder) +# makedirs(folder) - context = ssl._create_unverified_context() - data = urllib.request.urlopen(url, context=context) +# context = ssl._create_unverified_context() +# data = urllib.request.urlopen(url, context=context) - with open(path, "wb") as f: - f.write(data.read()) +# with open(path, "wb") as f: +# f.write(data.read()) - return path +# return path -def extract_zip(path, folder, log=True): - r"""Extracts a zip archive to a specific folder. +# def extract_zip(path, folder, log=True): +# r"""Extracts a zip archive to a specific folder. - Args: - path (string): The path to the tar archive. - folder (string): The folder. - log (bool, optional): If :obj:`False`, will not print anything to the - console. (default: :obj:`True`) - """ - with zipfile.ZipFile(path, "r") as f: - f.extractall(folder) +# Args: +# path (string): The path to the tar archive. +# folder (string): The folder. +# log (bool, optional): If :obj:`False`, will not print anything to the +# console. (default: :obj:`True`) +# """ +# with zipfile.ZipFile(path, "r") as f: +# f.extractall(folder) From f5148b82e43ab1a652b29e0380ca365edf5ad23e Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Fri, 27 Sep 2024 11:48:08 -0400 Subject: [PATCH 30/51] formatting --- hydragnn/utils/mace_utils/modules/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/hydragnn/utils/mace_utils/modules/utils.py b/hydragnn/utils/mace_utils/modules/utils.py index ae952ed25..52eab322d 100644 --- a/hydragnn/utils/mace_utils/modules/utils.py +++ b/hydragnn/utils/mace_utils/modules/utils.py @@ -15,6 +15,7 @@ from hydragnn.utils.mace_utils.tools import to_numpy from hydragnn.utils.mace_utils.tools.scatter import scatter_sum + # from hydragnn.utils.mace_utils.tools.torch_geometric.batch import Batch from .blocks import AtomicEnergiesBlock From 7f8129de83cc3ad19102f89b5ea9cda1809f2ade Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Fri, 27 Sep 2024 12:13:34 -0400 Subject: [PATCH 31/51] delete comments --- .../tools/torch_geometric/README.md | 12 - .../tools/torch_geometric/__init__.py | 8 - .../mace_utils/tools/torch_geometric/batch.py | 257 ---------- .../mace_utils/tools/torch_geometric/data.py | 441 ------------------ .../tools/torch_geometric/dataloader.py | 87 ---- .../tools/torch_geometric/dataset.py | 280 ----------- .../mace_utils/tools/torch_geometric/seed.py | 17 - .../mace_utils/tools/torch_geometric/utils.py | 54 --- 8 files changed, 1156 deletions(-) delete mode 100644 hydragnn/utils/mace_utils/tools/torch_geometric/README.md delete mode 100644 hydragnn/utils/mace_utils/tools/torch_geometric/__init__.py delete mode 100644 hydragnn/utils/mace_utils/tools/torch_geometric/batch.py delete mode 100644 hydragnn/utils/mace_utils/tools/torch_geometric/data.py delete mode 100644 hydragnn/utils/mace_utils/tools/torch_geometric/dataloader.py delete mode 100644 hydragnn/utils/mace_utils/tools/torch_geometric/dataset.py delete mode 100644 hydragnn/utils/mace_utils/tools/torch_geometric/seed.py delete mode 100644 
hydragnn/utils/mace_utils/tools/torch_geometric/utils.py
diff --git a/hydragnn/utils/mace_utils/tools/torch_geometric/README.md b/hydragnn/utils/mace_utils/tools/torch_geometric/README.md
deleted file mode 100644
index 261ebbbc7..000000000
--- a/hydragnn/utils/mace_utils/tools/torch_geometric/README.md
+++ /dev/null
@@ -1,12 +0,0 @@
-# Trimmed-down `pytorch_geometric`

-MACE uses the [`pytorch_geometric`](https://pytorch-geometric.readthedocs.io/en/latest/) [1, 2] framework. However, we only use a very limited subset of that library: the most basic graph data structures.

-We follow the same approach as NequIP (https://github.com/mir-group/nequip/tree/main/nequip) and copy their code here.

-To avoid adding a large number of unnecessary second-degree dependencies, and to simplify installation, we include and modify here the small subset of `torch_geometric` that is necessary for our code.

-We are grateful to the developers of PyTorch Geometric for their ongoing and very useful work on graph learning with PyTorch.

-[1] Fey, M., & Lenssen, J. E. (2019). Fast Graph Representation Learning with PyTorch Geometric (Version 2.0.1) [Computer software]. https://github.com/pyg-team/pytorch_geometric
-[2] https://arxiv.org/abs/1903.02428
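With the vendored copy deleted, the same data structures come from the upstream `torch_geometric` package, which HydraGNN already imports elsewhere (e.g., `torch_geometric.nn.Sequential` in `MACEStack.py`). As a minimal sketch of the replacement pattern, not part of this patch and assuming upstream PyTorch Geometric >= 2.0 (where the loader lives in `torch_geometric.loader`):

```python
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

# A toy dataset of eight 3-node graphs with 4 node features each.
dataset = [
    Data(x=torch.randn(3, 4), edge_index=torch.tensor([[0, 1], [1, 2]]))
    for _ in range(8)
]

loader = DataLoader(dataset, batch_size=4, shuffle=True)
for batch in loader:
    print(batch.num_graphs, batch.x.shape)  # 4 torch.Size([12, 4])
```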
diff --git a/hydragnn/utils/mace_utils/tools/torch_geometric/__init__.py b/hydragnn/utils/mace_utils/tools/torch_geometric/__init__.py
deleted file mode 100644
index ea70a022f..000000000
--- a/hydragnn/utils/mace_utils/tools/torch_geometric/__init__.py
+++ /dev/null
@@ -1,8 +0,0 @@
-# from .batch import Batch
-# from .data import Data
-# from .dataloader import DataLoader
-# from .dataset import Dataset
-# from .seed import seed_everything
-
-# __all__ = ["Batch", "Data", "Dataset", "DataLoader", "seed_everything"]
-# __all__ = ["Data", "Dataset", "DataLoader", "seed_everything"]
diff --git a/hydragnn/utils/mace_utils/tools/torch_geometric/batch.py b/hydragnn/utils/mace_utils/tools/torch_geometric/batch.py
deleted file mode 100644
index 8dfd3ddc1..000000000
--- a/hydragnn/utils/mace_utils/tools/torch_geometric/batch.py
+++ /dev/null
@@ -1,257 +0,0 @@
-# from collections.abc import Sequence
-# from typing import List
-
-# import numpy as np
-# import torch
-# from torch import Tensor
-
-# from .data import Data
-# from .dataset import IndexType
-
-
-# class Batch(Data):
-# r"""A plain old python object modeling a batch of graphs as one big
-# (disconnected) graph. With :class:`torch_geometric.data.Data` being the
-# base class, all its methods can also be used here.
-# In addition, single graphs can be reconstructed via the assignment vector
-# :obj:`batch`, which maps each node to its respective graph identifier.
-# """
-
-# def __init__(self, batch=None, ptr=None, **kwargs):
-# super(Batch, self).__init__(**kwargs)
-
-# for key, item in kwargs.items():
-# if key == "num_nodes":
-# self.__num_nodes__ = item
-# else:
-# self[key] = item
-
-# self.batch = batch
-# self.ptr = ptr
-# self.__data_class__ = Data
-# self.__slices__ = None
-# self.__cumsum__ = None
-# self.__cat_dims__ = None
-# self.__num_nodes_list__ = None
-# self.__num_graphs__ = None
-
-# @classmethod
-# def from_data_list(cls, data_list, follow_batch=[], exclude_keys=[]):
-# r"""Constructs a batch object from a python list holding
-# :class:`torch_geometric.data.Data` objects.
-# The assignment vector :obj:`batch` is created on the fly.
-# Additionally, creates assignment batch vectors for each key in
-# :obj:`follow_batch`.
-# Will exclude any keys given in :obj:`exclude_keys`."""
-
-# keys = list(set(data_list[0].keys) - set(exclude_keys))
-# assert "batch" not in keys and "ptr" not in keys
-
-# batch = cls()
-# for key in data_list[0].__dict__.keys():
-# if key[:2] != "__" and key[-2:] != "__":
-# batch[key] = None
-
-# batch.__num_graphs__ = len(data_list)
-# batch.__data_class__ = data_list[0].__class__
-# for key in keys + ["batch"]:
-# batch[key] = []
-# batch["ptr"] = [0]
-
-# device = None
-# slices = {key: [0] for key in keys}
-# cumsum = {key: [0] for key in keys}
-# cat_dims = {}
-# num_nodes_list = []
-# for i, data in enumerate(data_list):
-# for key in keys:
-# item = data[key]
-
-# # Increase values by `cumsum` value.
-# cum = cumsum[key][-1]
-# if isinstance(item, Tensor) and item.dtype != torch.bool:
-# if not isinstance(cum, int) or cum != 0:
-# item = item + cum
-# elif isinstance(item, (int, float)):
-# item = item + cum
-
-# # Gather the size of the `cat` dimension.
-# size = 1
-# cat_dim = data.__cat_dim__(key, data[key])
-# # 0-dimensional tensors have no dimension along which to
-# # concatenate, so we set `cat_dim` to `None`.
-# if isinstance(item, Tensor) and item.dim() == 0: -# cat_dim = None -# cat_dims[key] = cat_dim - -# # Add a batch dimension to items whose `cat_dim` is `None`: -# if isinstance(item, Tensor) and cat_dim is None: -# cat_dim = 0 # Concatenate along this new batch dimension. -# item = item.unsqueeze(0) -# device = item.device -# elif isinstance(item, Tensor): -# size = item.size(cat_dim) -# device = item.device - -# batch[key].append(item) # Append item to the attribute list. - -# slices[key].append(size + slices[key][-1]) -# inc = data.__inc__(key, item) -# if isinstance(inc, (tuple, list)): -# inc = torch.tensor(inc) -# cumsum[key].append(inc + cumsum[key][-1]) - -# if key in follow_batch: -# if isinstance(size, Tensor): -# for j, size in enumerate(size.tolist()): -# tmp = f"{key}_{j}_batch" -# batch[tmp] = [] if i == 0 else batch[tmp] -# batch[tmp].append( -# torch.full((size,), i, dtype=torch.long, device=device) -# ) -# else: -# tmp = f"{key}_batch" -# batch[tmp] = [] if i == 0 else batch[tmp] -# batch[tmp].append( -# torch.full((size,), i, dtype=torch.long, device=device) -# ) - -# if hasattr(data, "__num_nodes__"): -# num_nodes_list.append(data.__num_nodes__) -# else: -# num_nodes_list.append(None) - -# num_nodes = data.num_nodes -# if num_nodes is not None: -# item = torch.full((num_nodes,), i, dtype=torch.long, device=device) -# batch.batch.append(item) -# batch.ptr.append(batch.ptr[-1] + num_nodes) - -# batch.batch = None if len(batch.batch) == 0 else batch.batch -# batch.ptr = None if len(batch.ptr) == 1 else batch.ptr -# batch.__slices__ = slices -# batch.__cumsum__ = cumsum -# batch.__cat_dims__ = cat_dims -# batch.__num_nodes_list__ = num_nodes_list - -# ref_data = data_list[0] -# for key in batch.keys: -# items = batch[key] -# item = items[0] -# cat_dim = ref_data.__cat_dim__(key, item) -# cat_dim = 0 if cat_dim is None else cat_dim -# if isinstance(item, Tensor): -# batch[key] = torch.cat(items, cat_dim) -# elif isinstance(item, (int, float)): -# batch[key] = torch.tensor(items) - -# # if torch_geometric.is_debug_enabled(): -# # batch.debug() - -# return batch.contiguous() - -# def get_example(self, idx: int) -> Data: -# r"""Reconstructs the :class:`torch_geometric.data.Data` object at index -# :obj:`idx` from the batch object. -# The batch object must have been created via :meth:`from_data_list` in -# order to be able to reconstruct the initial objects.""" - -# if self.__slices__ is None: -# raise RuntimeError( -# ( -# "Cannot reconstruct data list from batch because the batch " -# "object was not created using `Batch.from_data_list()`." -# ) -# ) - -# data = self.__data_class__() -# idx = self.num_graphs + idx if idx < 0 else idx - -# for key in self.__slices__.keys(): -# item = self[key] -# if self.__cat_dims__[key] is None: -# # The item was concatenated along a new batch dimension, -# # so just index in that dimension: -# item = item[idx] -# else: -# # Narrow the item based on the values in `__slices__`. 
-# if isinstance(item, Tensor): -# dim = self.__cat_dims__[key] -# start = self.__slices__[key][idx] -# end = self.__slices__[key][idx + 1] -# item = item.narrow(dim, start, end - start) -# else: -# start = self.__slices__[key][idx] -# end = self.__slices__[key][idx + 1] -# item = item[start:end] -# item = item[0] if len(item) == 1 else item - -# # Decrease its value by `cumsum` value: -# cum = self.__cumsum__[key][idx] -# if isinstance(item, Tensor): -# if not isinstance(cum, int) or cum != 0: -# item = item - cum -# elif isinstance(item, (int, float)): -# item = item - cum - -# data[key] = item - -# if self.__num_nodes_list__[idx] is not None: -# data.num_nodes = self.__num_nodes_list__[idx] - -# return data - -# def index_select(self, idx: IndexType) -> List[Data]: -# if isinstance(idx, slice): -# idx = list(range(self.num_graphs)[idx]) - -# elif isinstance(idx, Tensor) and idx.dtype == torch.long: -# idx = idx.flatten().tolist() - -# elif isinstance(idx, Tensor) and idx.dtype == torch.bool: -# idx = idx.flatten().nonzero(as_tuple=False).flatten().tolist() - -# elif isinstance(idx, np.ndarray) and idx.dtype == np.int64: -# idx = idx.flatten().tolist() - -# elif isinstance(idx, np.ndarray) and idx.dtype == np.bool: -# idx = idx.flatten().nonzero()[0].flatten().tolist() - -# elif isinstance(idx, Sequence) and not isinstance(idx, str): -# pass - -# else: -# raise IndexError( -# f"Only integers, slices (':'), list, tuples, torch.tensor and " -# f"np.ndarray of dtype long or bool are valid indices (got " -# f"'{type(idx).__name__}')" -# ) - -# return [self.get_example(i) for i in idx] - -# def __getitem__(self, idx): -# if isinstance(idx, str): -# return super(Batch, self).__getitem__(idx) -# elif isinstance(idx, (int, np.integer)): -# return self.get_example(idx) -# else: -# return self.index_select(idx) - -# def to_data_list(self) -> List[Data]: -# r"""Reconstructs the list of :class:`torch_geometric.data.Data` objects -# from the batch object. -# The batch object must have been created via :meth:`from_data_list` in -# order to be able to reconstruct the initial objects.""" -# return [self.get_example(i) for i in range(self.num_graphs)] - -# @property -# def num_graphs(self) -> int: -# """Returns the number of graphs in the batch.""" -# if self.__num_graphs__ is not None: -# return self.__num_graphs__ -# elif self.ptr is not None: -# return self.ptr.numel() - 1 -# elif self.batch is not None: -# return int(self.batch.max()) + 1 -# else: -# raise ValueError diff --git a/hydragnn/utils/mace_utils/tools/torch_geometric/data.py b/hydragnn/utils/mace_utils/tools/torch_geometric/data.py deleted file mode 100644 index 6fdd25d47..000000000 --- a/hydragnn/utils/mace_utils/tools/torch_geometric/data.py +++ /dev/null @@ -1,441 +0,0 @@ -# import collections -# import copy -# import re - -# import torch - -# # from ..utils.num_nodes import maybe_num_nodes - -# __num_nodes_warn_msg__ = ( -# "The number of nodes in your data object can only be inferred by its {} " -# "indices, and hence may result in unexpected batch-wise behavior, e.g., " -# "in case there exists isolated nodes. Please consider explicitly setting " -# "the number of nodes for this data object by assigning it to " -# "data.num_nodes." 
-# ) - - -# def size_repr(key, item, indent=0): -# indent_str = " " * indent -# if torch.is_tensor(item) and item.dim() == 0: -# out = item.item() -# elif torch.is_tensor(item): -# out = str(list(item.size())) -# elif isinstance(item, list) or isinstance(item, tuple): -# out = str([len(item)]) -# elif isinstance(item, dict): -# lines = [indent_str + size_repr(k, v, 2) for k, v in item.items()] -# out = "{\n" + ",\n".join(lines) + "\n" + indent_str + "}" -# elif isinstance(item, str): -# out = f'"{item}"' -# else: -# out = str(item) - -# return f"{indent_str}{key}={out}" - - -# class Data(object): -# r"""A plain old python object modeling a single graph with various -# (optional) attributes: - -# Args: -# x (Tensor, optional): Node feature matrix with shape :obj:`[num_nodes, -# num_node_features]`. (default: :obj:`None`) -# edge_index (LongTensor, optional): Graph connectivity in COO format -# with shape :obj:`[2, num_edges]`. (default: :obj:`None`) -# edge_attr (Tensor, optional): Edge feature matrix with shape -# :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`) -# y (Tensor, optional): Graph or node targets with arbitrary shape. -# (default: :obj:`None`) -# pos (Tensor, optional): Node position matrix with shape -# :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`) -# normal (Tensor, optional): Normal vector matrix with shape -# :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`) -# face (LongTensor, optional): Face adjacency matrix with shape -# :obj:`[3, num_faces]`. (default: :obj:`None`) - -# The data object is not restricted to these attributes and can be extended -# by any other additional data. - -# Example:: - -# data = Data(x=x, edge_index=edge_index) -# data.train_idx = torch.tensor([...], dtype=torch.long) -# data.test_mask = torch.tensor([...], dtype=torch.bool) -# """ - -# def __init__( -# self, -# x=None, -# edge_index=None, -# edge_attr=None, -# y=None, -# pos=None, -# normal=None, -# face=None, -# **kwargs, -# ): -# self.x = x -# self.edge_index = edge_index -# self.edge_attr = edge_attr -# self.y = y -# self.pos = pos -# self.normal = normal -# self.face = face -# for key, item in kwargs.items(): -# if key == "num_nodes": -# self.__num_nodes__ = item -# else: -# self[key] = item - -# if edge_index is not None and edge_index.dtype != torch.long: -# raise ValueError( -# ( -# f"Argument `edge_index` needs to be of type `torch.long` but " -# f"found type `{edge_index.dtype}`." -# ) -# ) - -# if face is not None and face.dtype != torch.long: -# raise ValueError( -# ( -# f"Argument `face` needs to be of type `torch.long` but found " -# f"type `{face.dtype}`." 
-# ) -# ) - -# @classmethod -# def from_dict(cls, dictionary): -# r"""Creates a data object from a python dictionary.""" -# data = cls() - -# for key, item in dictionary.items(): -# data[key] = item - -# return data - -# def to_dict(self): -# return {key: item for key, item in self} - -# def to_namedtuple(self): -# keys = self.keys -# DataTuple = collections.namedtuple("DataTuple", keys) -# return DataTuple(*[self[key] for key in keys]) - -# def __getitem__(self, key): -# r"""Gets the data of the attribute :obj:`key`.""" -# return getattr(self, key, None) - -# def __setitem__(self, key, value): -# """Sets the attribute :obj:`key` to :obj:`value`.""" -# setattr(self, key, value) - -# def __delitem__(self, key): -# r"""Delete the data of the attribute :obj:`key`.""" -# return delattr(self, key) - -# @property -# def keys(self): -# r"""Returns all names of graph attributes.""" -# keys = [key for key in self.__dict__.keys() if self[key] is not None] -# keys = [key for key in keys if key[:2] != "__" and key[-2:] != "__"] -# return keys - -# def __len__(self): -# r"""Returns the number of all present attributes.""" -# return len(self.keys) - -# def __contains__(self, key): -# r"""Returns :obj:`True`, if the attribute :obj:`key` is present in the -# data.""" -# return key in self.keys - -# def __iter__(self): -# r"""Iterates over all present attributes in the data, yielding their -# attribute names and content.""" -# for key in sorted(self.keys): -# yield key, self[key] - -# def __call__(self, *keys): -# r"""Iterates over all attributes :obj:`*keys` in the data, yielding -# their attribute names and content. -# If :obj:`*keys` is not given this method will iterative over all -# present attributes.""" -# for key in sorted(self.keys) if not keys else keys: -# if key in self: -# yield key, self[key] - -# def __cat_dim__(self, key, value): -# r"""Returns the dimension for which :obj:`value` of attribute -# :obj:`key` will get concatenated when creating batches. - -# .. note:: - -# This method is for internal use only, and should only be overridden -# if the batch concatenation process is corrupted for a specific data -# attribute. -# """ -# if bool(re.search("(index|face)", key)): -# return -1 -# return 0 - -# def __inc__(self, key, value): -# r"""Returns the incremental count to cumulatively increase the value -# of the next attribute of :obj:`key` when creating batches. - -# .. note:: - -# This method is for internal use only, and should only be overridden -# if the batch concatenation process is corrupted for a specific data -# attribute. -# """ -# # Only `*index*` and `*face*` attributes should be cumulatively summed -# # up when creating batches. -# return self.num_nodes if bool(re.search("(index|face)", key)) else 0 - -# @property -# def num_nodes(self): -# r"""Returns or sets the number of nodes in the graph. - -# .. note:: -# The number of nodes in your data object is typically automatically -# inferred, *e.g.*, when node features :obj:`x` are present. -# In some cases however, a graph may only be given by its edge -# indices :obj:`edge_index`. -# PyTorch Geometric then *guesses* the number of nodes -# according to :obj:`edge_index.max().item() + 1`, but in case there -# exists isolated nodes, this number has not to be correct and can -# therefore result in unexpected batch-wise behavior. -# Thus, we recommend to set the number of nodes in your data object -# explicitly via :obj:`data.num_nodes = ...`. -# You will be given a warning that requests you to do so. 
-# """ -# if hasattr(self, "__num_nodes__"): -# return self.__num_nodes__ -# for key, item in self("x", "pos", "normal", "batch"): -# return item.size(self.__cat_dim__(key, item)) -# if hasattr(self, "adj"): -# return self.adj.size(0) -# if hasattr(self, "adj_t"): -# return self.adj_t.size(1) -# # if self.face is not None: -# # logging.warning(__num_nodes_warn_msg__.format("face")) -# # return maybe_num_nodes(self.face) -# # if self.edge_index is not None: -# # logging.warning(__num_nodes_warn_msg__.format("edge")) -# # return maybe_num_nodes(self.edge_index) -# return None - -# @num_nodes.setter -# def num_nodes(self, num_nodes): -# self.__num_nodes__ = num_nodes - -# @property -# def num_edges(self): -# """ -# Returns the number of edges in the graph. -# For undirected graphs, this will return the number of bi-directional -# edges, which is double the amount of unique edges. -# """ -# for key, item in self("edge_index", "edge_attr"): -# return item.size(self.__cat_dim__(key, item)) -# for key, item in self("adj", "adj_t"): -# return item.nnz() -# return None - -# @property -# def num_faces(self): -# r"""Returns the number of faces in the mesh.""" -# if self.face is not None: -# return self.face.size(self.__cat_dim__("face", self.face)) -# return None - -# @property -# def num_node_features(self): -# r"""Returns the number of features per node in the graph.""" -# if self.x is None: -# return 0 -# return 1 if self.x.dim() == 1 else self.x.size(1) - -# @property -# def num_features(self): -# r"""Alias for :py:attr:`~num_node_features`.""" -# return self.num_node_features - -# @property -# def num_edge_features(self): -# r"""Returns the number of features per edge in the graph.""" -# if self.edge_attr is None: -# return 0 -# return 1 if self.edge_attr.dim() == 1 else self.edge_attr.size(1) - -# def __apply__(self, item, func): -# if torch.is_tensor(item): -# return func(item) -# elif isinstance(item, (tuple, list)): -# return [self.__apply__(v, func) for v in item] -# elif isinstance(item, dict): -# return {k: self.__apply__(v, func) for k, v in item.items()} -# else: -# return item - -# def apply(self, func, *keys): -# r"""Applies the function :obj:`func` to all tensor attributes -# :obj:`*keys`. If :obj:`*keys` is not given, :obj:`func` is applied to -# all present attributes. -# """ -# for key, item in self(*keys): -# self[key] = self.__apply__(item, func) -# return self - -# def contiguous(self, *keys): -# r"""Ensures a contiguous memory layout for all attributes :obj:`*keys`. -# If :obj:`*keys` is not given, all present attributes are ensured to -# have a contiguous memory layout.""" -# return self.apply(lambda x: x.contiguous(), *keys) - -# def to(self, device, *keys, **kwargs): -# r"""Performs tensor dtype and/or device conversion to all attributes -# :obj:`*keys`. -# If :obj:`*keys` is not given, the conversion is applied to all present -# attributes.""" -# return self.apply(lambda x: x.to(device, **kwargs), *keys) - -# def cpu(self, *keys): -# r"""Copies all attributes :obj:`*keys` to CPU memory. -# If :obj:`*keys` is not given, the conversion is applied to all present -# attributes.""" -# return self.apply(lambda x: x.cpu(), *keys) - -# def cuda(self, device=None, non_blocking=False, *keys): -# r"""Copies all attributes :obj:`*keys` to CUDA memory. 
-# If :obj:`*keys` is not given, the conversion is applied to all present -# attributes.""" -# return self.apply( -# lambda x: x.cuda(device=device, non_blocking=non_blocking), *keys -# ) - -# def clone(self): -# r"""Performs a deep-copy of the data object.""" -# return self.__class__.from_dict( -# { -# k: v.clone() if torch.is_tensor(v) else copy.deepcopy(v) -# for k, v in self.__dict__.items() -# } -# ) - -# def pin_memory(self, *keys): -# r"""Copies all attributes :obj:`*keys` to pinned memory. -# If :obj:`*keys` is not given, the conversion is applied to all present -# attributes.""" -# return self.apply(lambda x: x.pin_memory(), *keys) - -# def debug(self): -# if self.edge_index is not None: -# if self.edge_index.dtype != torch.long: -# raise RuntimeError( -# ( -# "Expected edge indices of dtype {}, but found dtype " " {}" -# ).format(torch.long, self.edge_index.dtype) -# ) - -# if self.face is not None: -# if self.face.dtype != torch.long: -# raise RuntimeError( -# ( -# "Expected face indices of dtype {}, but found dtype " " {}" -# ).format(torch.long, self.face.dtype) -# ) - -# if self.edge_index is not None: -# if self.edge_index.dim() != 2 or self.edge_index.size(0) != 2: -# raise RuntimeError( -# ( -# "Edge indices should have shape [2, num_edges] but found" -# " shape {}" -# ).format(self.edge_index.size()) -# ) - -# if self.edge_index is not None and self.num_nodes is not None: -# if self.edge_index.numel() > 0: -# min_index = self.edge_index.min() -# max_index = self.edge_index.max() -# else: -# min_index = max_index = 0 -# if min_index < 0 or max_index > self.num_nodes - 1: -# raise RuntimeError( -# ( -# "Edge indices must lay in the interval [0, {}]" -# " but found them in the interval [{}, {}]" -# ).format(self.num_nodes - 1, min_index, max_index) -# ) - -# if self.face is not None: -# if self.face.dim() != 2 or self.face.size(0) != 3: -# raise RuntimeError( -# ( -# "Face indices should have shape [3, num_faces] but found" -# " shape {}" -# ).format(self.face.size()) -# ) - -# if self.face is not None and self.num_nodes is not None: -# if self.face.numel() > 0: -# min_index = self.face.min() -# max_index = self.face.max() -# else: -# min_index = max_index = 0 -# if min_index < 0 or max_index > self.num_nodes - 1: -# raise RuntimeError( -# ( -# "Face indices must lay in the interval [0, {}]" -# " but found them in the interval [{}, {}]" -# ).format(self.num_nodes - 1, min_index, max_index) -# ) - -# if self.edge_index is not None and self.edge_attr is not None: -# if self.edge_index.size(1) != self.edge_attr.size(0): -# raise RuntimeError( -# ( -# "Edge indices and edge attributes hold a differing " -# "number of edges, found {} and {}" -# ).format(self.edge_index.size(), self.edge_attr.size()) -# ) - -# if self.x is not None and self.num_nodes is not None: -# if self.x.size(0) != self.num_nodes: -# raise RuntimeError( -# ( -# "Node features should hold {} elements in the first " -# "dimension but found {}" -# ).format(self.num_nodes, self.x.size(0)) -# ) - -# if self.pos is not None and self.num_nodes is not None: -# if self.pos.size(0) != self.num_nodes: -# raise RuntimeError( -# ( -# "Node positions should hold {} elements in the first " -# "dimension but found {}" -# ).format(self.num_nodes, self.pos.size(0)) -# ) - -# if self.normal is not None and self.num_nodes is not None: -# if self.normal.size(0) != self.num_nodes: -# raise RuntimeError( -# ( -# "Node normals should hold {} elements in the first " -# "dimension but found {}" -# ).format(self.num_nodes, 
self.normal.size(0)) -# ) - -# def __repr__(self): -# cls = str(self.__class__.__name__) -# has_dict = any([isinstance(item, dict) for _, item in self]) - -# if not has_dict: -# info = [size_repr(key, item) for key, item in self] -# return "{}({})".format(cls, ", ".join(info)) -# else: -# info = [size_repr(key, item, indent=2) for key, item in self] -# return "{}(\n{}\n)".format(cls, ",\n".join(info)) diff --git a/hydragnn/utils/mace_utils/tools/torch_geometric/dataloader.py b/hydragnn/utils/mace_utils/tools/torch_geometric/dataloader.py deleted file mode 100644 index 81786ad4a..000000000 --- a/hydragnn/utils/mace_utils/tools/torch_geometric/dataloader.py +++ /dev/null @@ -1,87 +0,0 @@ -# from collections.abc import Mapping, Sequence -# from typing import List, Optional, Union - -# import torch.utils.data -# from torch.utils.data.dataloader import default_collate - -# from .batch import Batch -# from .data import Data -# from .dataset import Dataset - - -# class Collater: -# def __init__(self, follow_batch, exclude_keys): -# self.follow_batch = follow_batch -# self.exclude_keys = exclude_keys - -# def __call__(self, batch): -# elem = batch[0] -# if isinstance(elem, Data): -# return Batch.from_data_list( -# batch, -# follow_batch=self.follow_batch, -# exclude_keys=self.exclude_keys, -# ) -# elif isinstance(elem, torch.Tensor): -# return default_collate(batch) -# elif isinstance(elem, float): -# return torch.tensor(batch, dtype=torch.float) -# elif isinstance(elem, int): -# return torch.tensor(batch) -# elif isinstance(elem, str): -# return batch -# elif isinstance(elem, Mapping): -# return {key: self([data[key] for data in batch]) for key in elem} -# elif isinstance(elem, tuple) and hasattr(elem, "_fields"): -# return type(elem)(*(self(s) for s in zip(*batch))) -# elif isinstance(elem, Sequence) and not isinstance(elem, str): -# return [self(s) for s in zip(*batch)] - -# raise TypeError(f"DataLoader found invalid type: {type(elem)}") - -# def collate(self, batch): # Deprecated... -# return self(batch) - - -# class DataLoader(torch.utils.data.DataLoader): -# r"""A data loader which merges data objects from a -# :class:`torch_geometric.data.Dataset` to a mini-batch. -# Data objects can be either of type :class:`~torch_geometric.data.Data` or -# :class:`~torch_geometric.data.HeteroData`. -# Args: -# dataset (Dataset): The dataset from which to load the data. -# batch_size (int, optional): How many samples per batch to load. -# (default: :obj:`1`) -# shuffle (bool, optional): If set to :obj:`True`, the data will be -# reshuffled at every epoch. (default: :obj:`False`) -# follow_batch (List[str], optional): Creates assignment batch -# vectors for each key in the list. (default: :obj:`None`) -# exclude_keys (List[str], optional): Will exclude each key in the -# list. (default: :obj:`None`) -# **kwargs (optional): Additional arguments of -# :class:`torch.utils.data.DataLoader`. 
-# """ - -# def __init__( -# self, -# dataset: Dataset, -# batch_size: int = 1, -# shuffle: bool = False, -# follow_batch: Optional[List[str]] = [None], -# exclude_keys: Optional[List[str]] = [None], -# **kwargs, -# ): -# if "collate_fn" in kwargs: -# del kwargs["collate_fn"] - -# # Save for PyTorch Lightning < 1.6: -# self.follow_batch = follow_batch -# self.exclude_keys = exclude_keys - -# super().__init__( -# dataset, -# batch_size, -# shuffle, -# collate_fn=Collater(follow_batch, exclude_keys), -# **kwargs, -# ) diff --git a/hydragnn/utils/mace_utils/tools/torch_geometric/dataset.py b/hydragnn/utils/mace_utils/tools/torch_geometric/dataset.py deleted file mode 100644 index 4e0add3c3..000000000 --- a/hydragnn/utils/mace_utils/tools/torch_geometric/dataset.py +++ /dev/null @@ -1,280 +0,0 @@ -# import copy -# import os.path as osp -# import re -# import warnings -# from collections.abc import Sequence -# from typing import Any, Callable, List, Optional, Tuple, Union - -# import numpy as np -# import torch.utils.data -# from torch import Tensor - -# from .data import Data -# from .utils import makedirs - -# IndexType = Union[slice, Tensor, np.ndarray, Sequence] - - -# class Dataset(torch.utils.data.Dataset): -# r"""Dataset base class for creating graph datasets. -# See `here `__ for the accompanying tutorial. - -# Args: -# root (string, optional): Root directory where the dataset should be -# saved. (optional: :obj:`None`) -# transform (callable, optional): A function/transform that takes in an -# :obj:`torch_geometric.data.Data` object and returns a transformed -# version. The data object will be transformed before every access. -# (default: :obj:`None`) -# pre_transform (callable, optional): A function/transform that takes in -# an :obj:`torch_geometric.data.Data` object and returns a -# transformed version. The data object will be transformed before -# being saved to disk. (default: :obj:`None`) -# pre_filter (callable, optional): A function that takes in an -# :obj:`torch_geometric.data.Data` object and returns a boolean -# value, indicating whether the data object should be included in the -# final dataset. 
(default: :obj:`None`) -# """ - -# @property -# def raw_file_names(self) -> Union[str, List[str], Tuple]: -# r"""The name of the files to find in the :obj:`self.raw_dir` folder in -# order to skip the download.""" -# raise NotImplementedError - -# @property -# def processed_file_names(self) -> Union[str, List[str], Tuple]: -# r"""The name of the files to find in the :obj:`self.processed_dir` -# folder in order to skip the processing.""" -# raise NotImplementedError - -# def download(self): -# r"""Downloads the dataset to the :obj:`self.raw_dir` folder.""" -# raise NotImplementedError - -# def process(self): -# r"""Processes the dataset to the :obj:`self.processed_dir` folder.""" -# raise NotImplementedError - -# def len(self) -> int: -# raise NotImplementedError - -# def get(self, idx: int) -> Data: -# r"""Gets the data object at index :obj:`idx`.""" -# raise NotImplementedError - -# def __init__( -# self, -# root: Optional[str] = None, -# transform: Optional[Callable] = None, -# pre_transform: Optional[Callable] = None, -# pre_filter: Optional[Callable] = None, -# ): -# super().__init__() - -# if isinstance(root, str): -# root = osp.expanduser(osp.normpath(root)) - -# self.root = root -# self.transform = transform -# self.pre_transform = pre_transform -# self.pre_filter = pre_filter -# self._indices: Optional[Sequence] = None - -# if "download" in self.__class__.__dict__.keys(): -# self._download() - -# if "process" in self.__class__.__dict__.keys(): -# self._process() - -# def indices(self) -> Sequence: -# return range(self.len()) if self._indices is None else self._indices - -# @property -# def raw_dir(self) -> str: -# return osp.join(self.root, "raw") - -# @property -# def processed_dir(self) -> str: -# return osp.join(self.root, "processed") - -# @property -# def num_node_features(self) -> int: -# r"""Returns the number of features per node in the dataset.""" -# data = self[0] -# if hasattr(data, "num_node_features"): -# return data.num_node_features -# raise AttributeError( -# f"'{data.__class__.__name__}' object has no " -# f"attribute 'num_node_features'" -# ) - -# @property -# def num_features(self) -> int: -# r"""Alias for :py:attr:`~num_node_features`.""" -# return self.num_node_features - -# @property -# def num_edge_features(self) -> int: -# r"""Returns the number of features per edge in the dataset.""" -# data = self[0] -# if hasattr(data, "num_edge_features"): -# return data.num_edge_features -# raise AttributeError( -# f"'{data.__class__.__name__}' object has no " -# f"attribute 'num_edge_features'" -# ) - -# @property -# def raw_paths(self) -> List[str]: -# r"""The filepaths to find in order to skip the download.""" -# files = to_list(self.raw_file_names) -# return [osp.join(self.raw_dir, f) for f in files] - -# @property -# def processed_paths(self) -> List[str]: -# r"""The filepaths to find in the :obj:`self.processed_dir` -# folder in order to skip the processing.""" -# files = to_list(self.processed_file_names) -# return [osp.join(self.processed_dir, f) for f in files] - -# def _download(self): -# if files_exist(self.raw_paths): # pragma: no cover -# return - -# makedirs(self.raw_dir) -# self.download() - -# def _process(self): -# f = osp.join(self.processed_dir, "pre_transform.pt") -# if osp.exists(f) and torch.load(f) != _repr(self.pre_transform): -# warnings.warn( -# f"The `pre_transform` argument differs from the one used in " -# f"the pre-processed version of this dataset. 
If you want to " -# f"make use of another pre-processing technique, make sure to " -# f"sure to delete '{self.processed_dir}' first" -# ) - -# f = osp.join(self.processed_dir, "pre_filter.pt") -# if osp.exists(f) and torch.load(f) != _repr(self.pre_filter): -# warnings.warn( -# "The `pre_filter` argument differs from the one used in the " -# "pre-processed version of this dataset. If you want to make " -# "use of another pre-fitering technique, make sure to delete " -# "'{self.processed_dir}' first" -# ) - -# if files_exist(self.processed_paths): # pragma: no cover -# return - -# print("Processing...") - -# makedirs(self.processed_dir) -# self.process() - -# path = osp.join(self.processed_dir, "pre_transform.pt") -# torch.save(_repr(self.pre_transform), path) -# path = osp.join(self.processed_dir, "pre_filter.pt") -# torch.save(_repr(self.pre_filter), path) - -# print("Done!") - -# def __len__(self) -> int: -# r"""The number of examples in the dataset.""" -# return len(self.indices()) - -# def __getitem__( -# self, -# idx: Union[int, np.integer, IndexType], -# ) -> Union["Dataset", Data]: -# r"""In case :obj:`idx` is of type integer, will return the data object -# at index :obj:`idx` (and transforms it in case :obj:`transform` is -# present). -# In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a -# tuple, a PyTorch :obj:`LongTensor` or a :obj:`BoolTensor`, or a numpy -# :obj:`np.array`, will return a subset of the dataset at the specified -# indices.""" -# if ( -# isinstance(idx, (int, np.integer)) -# or (isinstance(idx, Tensor) and idx.dim() == 0) -# or (isinstance(idx, np.ndarray) and np.isscalar(idx)) -# ): -# data = self.get(self.indices()[idx]) -# data = data if self.transform is None else self.transform(data) -# return data - -# else: -# return self.index_select(idx) - -# def index_select(self, idx: IndexType) -> "Dataset": -# indices = self.indices() - -# if isinstance(idx, slice): -# indices = indices[idx] - -# elif isinstance(idx, Tensor) and idx.dtype == torch.long: -# return self.index_select(idx.flatten().tolist()) - -# elif isinstance(idx, Tensor) and idx.dtype == torch.bool: -# idx = idx.flatten().nonzero(as_tuple=False) -# return self.index_select(idx.flatten().tolist()) - -# elif isinstance(idx, np.ndarray) and idx.dtype == np.int64: -# return self.index_select(idx.flatten().tolist()) - -# elif isinstance(idx, np.ndarray) and idx.dtype == np.bool: -# idx = idx.flatten().nonzero()[0] -# return self.index_select(idx.flatten().tolist()) - -# elif isinstance(idx, Sequence) and not isinstance(idx, str): -# indices = [indices[i] for i in idx] - -# else: -# raise IndexError( -# f"Only integers, slices (':'), list, tuples, torch.tensor and " -# f"np.ndarray of dtype long or bool are valid indices (got " -# f"'{type(idx).__name__}')" -# ) - -# dataset = copy.copy(self) -# dataset._indices = indices -# return dataset - -# def shuffle( -# self, -# return_perm: bool = False, -# ) -> Union["Dataset", Tuple["Dataset", Tensor]]: -# r"""Randomly shuffles the examples in the dataset. - -# Args: -# return_perm (bool, optional): If set to :obj:`True`, will return -# the random permutation used to shuffle the dataset in addition. 
-# (default: :obj:`False`) -# """ -# perm = torch.randperm(len(self)) -# dataset = self.index_select(perm) -# return (dataset, perm) if return_perm is True else dataset - -# def __repr__(self) -> str: -# arg_repr = str(len(self)) if len(self) > 1 else "" -# return f"{self.__class__.__name__}({arg_repr})" - - -# def to_list(value: Any) -> Sequence: -# if isinstance(value, Sequence) and not isinstance(value, str): -# return value -# else: -# return [value] - - -# def files_exist(files: List[str]) -> bool: -# # NOTE: We return `False` in case `files` is empty, leading to a -# # re-processing of files on every instantiation. -# return len(files) != 0 and all([osp.exists(f) for f in files]) - - -# def _repr(obj: Any) -> str: -# if obj is None: -# return "None" -# return re.sub("(<.*?)\\s.*(>)", r"\1\2", obj.__repr__()) diff --git a/hydragnn/utils/mace_utils/tools/torch_geometric/seed.py b/hydragnn/utils/mace_utils/tools/torch_geometric/seed.py deleted file mode 100644 index 2222226f4..000000000 --- a/hydragnn/utils/mace_utils/tools/torch_geometric/seed.py +++ /dev/null @@ -1,17 +0,0 @@ -# import random - -# import numpy as np -# import torch - - -# def seed_everything(seed: int): -# r"""Sets the seed for generating random numbers in :pytorch:`PyTorch`, -# :obj:`numpy` and Python. - -# Args: -# seed (int): The desired seed. -# """ -# random.seed(seed) -# np.random.seed(seed) -# torch.manual_seed(seed) -# torch.cuda.manual_seed_all(seed) diff --git a/hydragnn/utils/mace_utils/tools/torch_geometric/utils.py b/hydragnn/utils/mace_utils/tools/torch_geometric/utils.py deleted file mode 100644 index b03a56ceb..000000000 --- a/hydragnn/utils/mace_utils/tools/torch_geometric/utils.py +++ /dev/null @@ -1,54 +0,0 @@ -# import os -# import os.path as osp -# import ssl -# import urllib -# import zipfile - - -# def makedirs(dir): -# os.makedirs(dir, exist_ok=True) - - -# def download_url(url, folder, log=True): -# r"""Downloads the content of an URL to a specific folder. - -# Args: -# url (string): The url. -# folder (string): The folder. -# log (bool, optional): If :obj:`False`, will not print anything to the -# console. (default: :obj:`True`) -# """ - -# filename = url.rpartition("/")[2].split("?")[0] -# path = osp.join(folder, filename) - -# if osp.exists(path): # pragma: no cover -# if log: -# print("Using exist file", filename) -# return path - -# if log: -# print("Downloading", url) - -# makedirs(folder) - -# context = ssl._create_unverified_context() -# data = urllib.request.urlopen(url, context=context) - -# with open(path, "wb") as f: -# f.write(data.read()) - -# return path - - -# def extract_zip(path, folder, log=True): -# r"""Extracts a zip archive to a specific folder. - -# Args: -# path (string): The path to the tar archive. -# folder (string): The folder. -# log (bool, optional): If :obj:`False`, will not print anything to the -# console. 
(default: :obj:`True`) -# """ -# with zipfile.ZipFile(path, "r") as f: -# f.extractall(folder) From b2a95983da928764ae50632ce58ccc10640912d0 Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Fri, 27 Sep 2024 12:37:34 -0400 Subject: [PATCH 32/51] distributed, parsing, and checkpointing utils taken out (draft 3) --- hydragnn/utils/mace_utils/modules/utils.py | 2 - hydragnn/utils/mace_utils/tools/__init__.py | 17 - hydragnn/utils/mace_utils/tools/arg_parser.py | 792 ------------------ .../mace_utils/tools/arg_parser_tools.py | 113 --- hydragnn/utils/mace_utils/tools/checkpoint.py | 227 ----- .../utils/mace_utils/tools/scripts_utils.py | 653 --------------- .../mace_utils/tools/slurm_distributed.py | 34 - hydragnn/utils/mace_utils/tools/train.py | 524 ------------ 8 files changed, 2362 deletions(-) delete mode 100644 hydragnn/utils/mace_utils/tools/arg_parser.py delete mode 100644 hydragnn/utils/mace_utils/tools/arg_parser_tools.py delete mode 100644 hydragnn/utils/mace_utils/tools/checkpoint.py delete mode 100644 hydragnn/utils/mace_utils/tools/scripts_utils.py delete mode 100644 hydragnn/utils/mace_utils/tools/slurm_distributed.py delete mode 100644 hydragnn/utils/mace_utils/tools/train.py diff --git a/hydragnn/utils/mace_utils/modules/utils.py b/hydragnn/utils/mace_utils/modules/utils.py index 52eab322d..488cd90a3 100644 --- a/hydragnn/utils/mace_utils/modules/utils.py +++ b/hydragnn/utils/mace_utils/modules/utils.py @@ -16,8 +16,6 @@ from hydragnn.utils.mace_utils.tools import to_numpy from hydragnn.utils.mace_utils.tools.scatter import scatter_sum -# from hydragnn.utils.mace_utils.tools.torch_geometric.batch import Batch - from .blocks import AtomicEnergiesBlock diff --git a/hydragnn/utils/mace_utils/tools/__init__.py b/hydragnn/utils/mace_utils/tools/__init__.py index 8d1e7bc22..3703f3152 100644 --- a/hydragnn/utils/mace_utils/tools/__init__.py +++ b/hydragnn/utils/mace_utils/tools/__init__.py @@ -18,23 +18,6 @@ voigt_to_matrix, ) -# from .train import SWAContainer, evaluate, train -# from .utils import ( -# AtomicNumberTable, -# MetricsLogger, -# atomic_numbers_to_indices, -# compute_c, -# compute_mae, -# compute_q95, -# compute_rel_mae, -# compute_rel_rmse, -# compute_rmse, -# get_atomic_number_table_from_zs, -# get_optimizer, -# get_tag, -# setup_logger, -# ) - __all__ = [ "TensorDict", "AtomicNumberTable", diff --git a/hydragnn/utils/mace_utils/tools/arg_parser.py b/hydragnn/utils/mace_utils/tools/arg_parser.py deleted file mode 100644 index 73e1e9d24..000000000 --- a/hydragnn/utils/mace_utils/tools/arg_parser.py +++ /dev/null @@ -1,792 +0,0 @@ -# ########################################################################################### -# # Parsing functionalities -# # Authors: Ilyes Batatia, Gregor Simm, David Kovacs -# # This program is distributed under the MIT License (see MIT.md) -# ########################################################################################### - -# import argparse -# import os -# from typing import Optional - - -# def build_default_arg_parser() -> argparse.ArgumentParser: -# try: -# import configargparse - -# parser = configargparse.ArgumentParser( -# config_file_parser_class=configargparse.YAMLConfigFileParser, -# ) -# parser.add( -# "--config", -# type=str, -# is_config_file=True, -# help="config file to agregate options", -# ) -# except ImportError: -# parser = argparse.ArgumentParser() - -# # Name and seed -# parser.add_argument("--name", help="experiment name", required=True) -# parser.add_argument("--seed", help="random seed", type=int, 
default=123) - -# # Directories -# parser.add_argument( -# "--work_dir", -# help="set directory for all files and folders", -# type=str, -# default=".", -# ) -# parser.add_argument( -# "--log_dir", help="directory for log files", type=str, default=None -# ) -# parser.add_argument( -# "--model_dir", help="directory for final model", type=str, default=None -# ) -# parser.add_argument( -# "--checkpoints_dir", -# help="directory for checkpoint files", -# type=str, -# default=None, -# ) -# parser.add_argument( -# "--results_dir", help="directory for results", type=str, default=None -# ) -# parser.add_argument( -# "--downloads_dir", help="directory for downloads", type=str, default=None -# ) - -# # Device and logging -# parser.add_argument( -# "--device", -# help="select device", -# type=str, -# choices=["cpu", "cuda", "mps"], -# default="cpu", -# ) -# parser.add_argument( -# "--default_dtype", -# help="set default dtype", -# type=str, -# choices=["float32", "float64"], -# default="float64", -# ) -# parser.add_argument( -# "--distributed", -# help="train in multi-GPU data parallel mode", -# action="store_true", -# default=False, -# ) -# parser.add_argument("--log_level", help="log level", type=str, default="INFO") - -# parser.add_argument( -# "--error_table", -# help="Type of error table produced at the end of the training", -# type=str, -# choices=[ -# "PerAtomRMSE", -# "TotalRMSE", -# "PerAtomRMSEstressvirials", -# "PerAtomMAEstressvirials", -# "PerAtomMAE", -# "TotalMAE", -# "DipoleRMSE", -# "DipoleMAE", -# "EnergyDipoleRMSE", -# ], -# default="PerAtomRMSE", -# ) - -# # Model -# parser.add_argument( -# "--model", -# help="model type", -# default="MACE", -# choices=[ -# "BOTNet", -# "MACE", -# "ScaleShiftMACE", -# "ScaleShiftBOTNet", -# "AtomicDipolesMACE", -# "EnergyDipolesMACE", -# ], -# ) -# parser.add_argument( -# "--r_max", help="distance cutoff (in Ang)", type=float, default=5.0 -# ) -# parser.add_argument( -# "--radial_type", -# help="type of radial basis functions", -# type=str, -# default="bessel", -# choices=["bessel", "gaussian", "chebyshev"], -# ) -# parser.add_argument( -# "--num_radial_basis", -# help="number of radial basis functions", -# type=int, -# default=8, -# ) -# parser.add_argument( -# "--num_cutoff_basis", -# help="number of basis functions for smooth cutoff", -# type=int, -# default=5, -# ) -# parser.add_argument( -# "--pair_repulsion", -# help="use pair repulsion term with ZBL potential", -# action="store_true", -# default=False, -# ) -# parser.add_argument( -# "--distance_transform", -# help="use distance transform for radial basis functions", -# default="None", -# choices=["None", "Agnesi", "Soft"], -# ) -# parser.add_argument( -# "--interaction", -# help="name of interaction block", -# type=str, -# default="RealAgnosticResidualInteractionBlock", -# choices=[ -# "RealAgnosticResidualInteractionBlock", -# "RealAgnosticAttResidualInteractionBlock", -# "RealAgnosticInteractionBlock", -# ], -# ) -# parser.add_argument( -# "--interaction_first", -# help="name of interaction block", -# type=str, -# default="RealAgnosticResidualInteractionBlock", -# choices=[ -# "RealAgnosticResidualInteractionBlock", -# "RealAgnosticInteractionBlock", -# ], -# ) -# parser.add_argument( -# "--max_ell", help=r"highest \ell of spherical harmonics", type=int, default=3 -# ) -# parser.add_argument( -# "--correlation", help="correlation order at each layer", type=int, default=3 -# ) -# parser.add_argument( -# "--num_interactions", help="number of interactions", type=int, default=2 -# ) -# 
parser.add_argument( -# "--MLP_irreps", -# help="hidden irreps of the MLP in last readout", -# type=str, -# default="16x0e", -# ) -# parser.add_argument( -# "--radial_MLP", -# help="width of the radial MLP", -# type=str, -# default="[64, 64, 64]", -# ) -# parser.add_argument( -# "--hidden_irreps", -# help="irreps for hidden node states", -# type=str, -# default=None, -# ) -# # add option to specify irreps by channel number and max L -# parser.add_argument( -# "--num_channels", -# help="number of embedding channels", -# type=int, -# default=None, -# ) -# parser.add_argument( -# "--max_L", -# help="max L equivariance of the message", -# type=int, -# default=None, -# ) -# parser.add_argument( -# "--gate", -# help="non linearity for last readout", -# type=str, -# default="silu", -# choices=["silu", "tanh", "abs", "None"], -# ) -# parser.add_argument( -# "--scaling", -# help="type of scaling to the output", -# type=str, -# default="rms_forces_scaling", -# choices=["std_scaling", "rms_forces_scaling", "no_scaling"], -# ) -# parser.add_argument( -# "--avg_num_neighbors", -# help="normalization factor for the message", -# type=float, -# default=1, -# ) -# parser.add_argument( -# "--compute_avg_num_neighbors", -# help="normalization factor for the message", -# type=bool, -# default=True, -# ) -# parser.add_argument( -# "--compute_stress", -# help="Select True to compute stress", -# type=bool, -# default=False, -# ) -# parser.add_argument( -# "--compute_forces", -# help="Select True to compute forces", -# type=bool, -# default=True, -# ) - -# # Dataset -# parser.add_argument( -# "--train_file", -# help="Training set file, format is .xyz or .h5", -# type=str, -# required=True, -# ) -# parser.add_argument( -# "--valid_file", -# help="Validation set .xyz or .h5 file", -# default=None, -# type=str, -# required=False, -# ) -# parser.add_argument( -# "--valid_fraction", -# help="Fraction of training set used for validation", -# type=float, -# default=0.1, -# required=False, -# ) -# parser.add_argument( -# "--test_file", -# help="Test set .xyz pt .h5 file", -# type=str, -# ) -# parser.add_argument( -# "--test_dir", -# help="Path to directory with test files named as test_*.h5", -# type=str, -# default=None, -# required=False, -# ) -# parser.add_argument( -# "--multi_processed_test", -# help="Boolean value for whether the test data was multiprocessed", -# type=bool, -# default=False, -# required=False, -# ) -# parser.add_argument( -# "--num_workers", -# help="Number of workers for data loading", -# type=int, -# default=0, -# ) -# parser.add_argument( -# "--pin_memory", -# help="Pin memory for data loading", -# default=True, -# type=bool, -# ) -# parser.add_argument( -# "--atomic_numbers", -# help="List of atomic numbers", -# type=str, -# default=None, -# required=False, -# ) -# parser.add_argument( -# "--mean", -# help="Mean energy per atom of training set", -# type=float, -# default=None, -# required=False, -# ) -# parser.add_argument( -# "--std", -# help="Standard deviation of force components in the training set", -# type=float, -# default=None, -# required=False, -# ) -# parser.add_argument( -# "--statistics_file", -# help="json file containing statistics of training set", -# type=str, -# default=None, -# required=False, -# ) -# parser.add_argument( -# "--E0s", -# help="Dictionary of isolated atom energies", -# type=str, -# default=None, -# required=False, -# ) -# parser.add_argument( -# "--keep_isolated_atoms", -# help="Keep isolated atoms in the dataset, useful for transfer learning", -# type=bool, -# 
default=False, -# ) -# parser.add_argument( -# "--energy_key", -# help="Key of reference energies in training xyz", -# type=str, -# default="REF_energy", -# ) -# parser.add_argument( -# "--forces_key", -# help="Key of reference forces in training xyz", -# type=str, -# default="REF_forces", -# ) -# parser.add_argument( -# "--virials_key", -# help="Key of reference virials in training xyz", -# type=str, -# default="REF_virials", -# ) -# parser.add_argument( -# "--stress_key", -# help="Key of reference stress in training xyz", -# type=str, -# default="REF_stress", -# ) -# parser.add_argument( -# "--dipole_key", -# help="Key of reference dipoles in training xyz", -# type=str, -# default="REF_dipole", -# ) -# parser.add_argument( -# "--charges_key", -# help="Key of atomic charges in training xyz", -# type=str, -# default="REF_charges", -# ) - -# # Loss and optimization -# parser.add_argument( -# "--loss", -# help="type of loss", -# default="weighted", -# choices=[ -# "ef", -# "weighted", -# "forces_only", -# "virials", -# "stress", -# "dipole", -# "huber", -# "universal", -# "energy_forces_dipole", -# ], -# ) -# parser.add_argument( -# "--forces_weight", help="weight of forces loss", type=float, default=100.0 -# ) -# parser.add_argument( -# "--swa_forces_weight", -# "--stage_two_forces_weight", -# help="weight of forces loss after starting Stage Two (previously called swa)", -# type=float, -# default=100.0, -# dest="swa_forces_weight", -# ) -# parser.add_argument( -# "--energy_weight", help="weight of energy loss", type=float, default=1.0 -# ) -# parser.add_argument( -# "--swa_energy_weight", -# "--stage_two_energy_weight", -# help="weight of energy loss after starting Stage Two (previously called swa)", -# type=float, -# default=1000.0, -# dest="swa_energy_weight", -# ) -# parser.add_argument( -# "--virials_weight", help="weight of virials loss", type=float, default=1.0 -# ) -# parser.add_argument( -# "--swa_virials_weight", -# "--stage_two_virials_weight", -# help="weight of virials loss after starting Stage Two (previously called swa)", -# type=float, -# default=10.0, -# dest="swa_virials_weight", -# ) -# parser.add_argument( -# "--stress_weight", help="weight of virials loss", type=float, default=1.0 -# ) -# parser.add_argument( -# "--swa_stress_weight", -# "--stage_two_stress_weight", -# help="weight of stress loss after starting Stage Two (previously called swa)", -# type=float, -# default=10.0, -# dest="swa_stress_weight", -# ) -# parser.add_argument( -# "--dipole_weight", help="weight of dipoles loss", type=float, default=1.0 -# ) -# parser.add_argument( -# "--swa_dipole_weight", -# "--stage_two_dipole_weight", -# help="weight of dipoles after starting Stage Two (previously called swa)", -# type=float, -# default=1.0, -# dest="swa_dipole_weight", -# ) -# parser.add_argument( -# "--config_type_weights", -# help="String of dictionary containing the weights for each config type", -# type=str, -# default='{"Default":1.0}', -# ) -# parser.add_argument( -# "--huber_delta", -# help="delta parameter for huber loss", -# type=float, -# default=0.01, -# ) -# parser.add_argument( -# "--optimizer", -# help="Optimizer for parameter optimization", -# type=str, -# default="adam", -# choices=["adam", "adamw", "schedulefree"], -# ) -# parser.add_argument( -# "--beta", -# help="Beta parameter for the optimizer", -# type=float, -# default=0.9, -# ) -# parser.add_argument("--batch_size", help="batch size", type=int, default=10) -# parser.add_argument( -# "--valid_batch_size", help="Validation batch size", 
type=int, default=10 -# ) -# parser.add_argument( -# "--lr", help="Learning rate of optimizer", type=float, default=0.01 -# ) -# parser.add_argument( -# "--swa_lr", -# "--stage_two_lr", -# help="Learning rate of optimizer in Stage Two (previously called swa)", -# type=float, -# default=1e-3, -# dest="swa_lr", -# ) -# parser.add_argument( -# "--weight_decay", help="weight decay (L2 penalty)", type=float, default=5e-7 -# ) -# parser.add_argument( -# "--amsgrad", -# help="use amsgrad variant of optimizer", -# action="store_true", -# default=True, -# ) -# parser.add_argument( -# "--scheduler", help="Type of scheduler", type=str, default="ReduceLROnPlateau" -# ) -# parser.add_argument( -# "--lr_factor", help="Learning rate factor", type=float, default=0.8 -# ) -# parser.add_argument( -# "--scheduler_patience", help="Learning rate factor", type=int, default=50 -# ) -# parser.add_argument( -# "--lr_scheduler_gamma", -# help="Gamma of learning rate scheduler", -# type=float, -# default=0.9993, -# ) -# parser.add_argument( -# "--swa", -# "--stage_two", -# help="use Stage Two loss weight, which decreases the learning rate and increases the energy weight at the end of the training to help converge them", -# action="store_true", -# default=False, -# dest="swa", -# ) -# parser.add_argument( -# "--start_swa", -# "--start_stage_two", -# help="Number of epochs before changing to Stage Two loss weights", -# type=int, -# default=None, -# dest="start_swa", -# ) -# parser.add_argument( -# "--ema", -# help="use Exponential Moving Average", -# action="store_true", -# default=False, -# ) -# parser.add_argument( -# "--ema_decay", -# help="Exponential Moving Average decay", -# type=float, -# default=0.99, -# ) -# parser.add_argument( -# "--max_num_epochs", help="Maximum number of epochs", type=int, default=2048 -# ) -# parser.add_argument( -# "--patience", -# help="Maximum number of consecutive epochs of increasing loss", -# type=int, -# default=2048, -# ) -# parser.add_argument( -# "--foundation_model", -# help="Path to the foundation model for transfer learning", -# type=str, -# default=None, -# ) -# parser.add_argument( -# "--foundation_model_readout", -# help="Use readout of foundation model for transfer learning", -# action="store_false", -# default=True, -# ) -# parser.add_argument( -# "--eval_interval", help="evaluate model every epochs", type=int, default=1 -# ) -# parser.add_argument( -# "--keep_checkpoints", -# help="keep all checkpoints", -# action="store_true", -# default=False, -# ) -# parser.add_argument( -# "--save_all_checkpoints", -# help="save all checkpoints", -# action="store_true", -# default=False, -# ) -# parser.add_argument( -# "--restart_latest", -# help="restart optimizer from latest checkpoint", -# action="store_true", -# default=False, -# ) -# parser.add_argument( -# "--save_cpu", -# help="Save a model to be loaded on cpu", -# action="store_true", -# default=False, -# ) -# parser.add_argument( -# "--clip_grad", -# help="Gradient Clipping Value", -# type=check_float_or_none, -# default=10.0, -# ) -# # options for using Weights and Biases for experiment tracking -# # to install see https://wandb.ai -# parser.add_argument( -# "--wandb", -# help="Use Weights and Biases for experiment tracking", -# action="store_true", -# default=False, -# ) -# parser.add_argument( -# "--wandb_dir", -# help="An absolute path to a directory where Weights and Biases metadata will be stored", -# type=str, -# default=None, -# ) -# parser.add_argument( -# "--wandb_project", -# help="Weights and Biases project 
name", -# type=str, -# default="", -# ) -# parser.add_argument( -# "--wandb_entity", -# help="Weights and Biases entity name", -# type=str, -# default="", -# ) -# parser.add_argument( -# "--wandb_name", -# help="Weights and Biases experiment name", -# type=str, -# default="", -# ) -# parser.add_argument( -# "--wandb_log_hypers", -# help="The hyperparameters to log in Weights and Biases", -# type=list, -# default=[ -# "num_channels", -# "max_L", -# "correlation", -# "lr", -# "swa_lr", -# "weight_decay", -# "batch_size", -# "max_num_epochs", -# "start_swa", -# "energy_weight", -# "forces_weight", -# ], -# ) -# return parser - - -# def build_preprocess_arg_parser() -> argparse.ArgumentParser: -# parser = argparse.ArgumentParser() -# parser.add_argument( -# "--train_file", -# help="Training set h5 file", -# type=str, -# default=None, -# required=True, -# ) -# parser.add_argument( -# "--valid_file", -# help="Training set xyz file", -# type=str, -# default=None, -# required=False, -# ) -# parser.add_argument( -# "--num_process", -# help="The user defined number of processes to use, as well as the number of files created.", -# type=int, -# default=int(os.cpu_count() / 4), -# ) -# parser.add_argument( -# "--valid_fraction", -# help="Fraction of training set used for validation", -# type=float, -# default=0.1, -# required=False, -# ) -# parser.add_argument( -# "--test_file", -# help="Test set xyz file", -# type=str, -# default=None, -# required=False, -# ) -# parser.add_argument( -# "--h5_prefix", -# help="Prefix for h5 files when saving", -# type=str, -# default="", -# ) -# parser.add_argument( -# "--r_max", help="distance cutoff (in Ang)", type=float, default=5.0 -# ) -# parser.add_argument( -# "--config_type_weights", -# help="String of dictionary containing the weights for each config type", -# type=str, -# default='{"Default":1.0}', -# ) -# parser.add_argument( -# "--energy_key", -# help="Key of reference energies in training xyz", -# type=str, -# default="REF_energy", -# ) -# parser.add_argument( -# "--forces_key", -# help="Key of reference forces in training xyz", -# type=str, -# default="REF_forces", -# ) -# parser.add_argument( -# "--virials_key", -# help="Key of reference virials in training xyz", -# type=str, -# default="REF_virials", -# ) -# parser.add_argument( -# "--stress_key", -# help="Key of reference stress in training xyz", -# type=str, -# default="REF_stress", -# ) -# parser.add_argument( -# "--dipole_key", -# help="Key of reference dipoles in training xyz", -# type=str, -# default="REF_dipole", -# ) -# parser.add_argument( -# "--charges_key", -# help="Key of atomic charges in training xyz", -# type=str, -# default="REF_charges", -# ) -# parser.add_argument( -# "--atomic_numbers", -# help="List of atomic numbers", -# type=str, -# default=None, -# required=False, -# ) -# parser.add_argument( -# "--compute_statistics", -# help="Compute statistics for the dataset", -# action="store_true", -# default=False, -# ) -# parser.add_argument( -# "--batch_size", -# help="batch size to compute average number of neighbours", -# type=int, -# default=16, -# ) - -# parser.add_argument( -# "--scaling", -# help="type of scaling to the output", -# type=str, -# default="rms_forces_scaling", -# choices=["std_scaling", "rms_forces_scaling", "no_scaling"], -# ) -# parser.add_argument( -# "--E0s", -# help="Dictionary of isolated atom energies", -# type=str, -# default=None, -# required=False, -# ) -# parser.add_argument( -# "--shuffle", -# help="Shuffle the training dataset", -# type=bool, -# 
default=True, -# ) -# parser.add_argument( -# "--seed", -# help="Random seed for splitting training and validation sets", -# type=int, -# default=123, -# ) -# return parser - - -# def check_float_or_none(value: str) -> Optional[float]: -# try: -# return float(value) -# except ValueError: -# if value != "None": -# raise argparse.ArgumentTypeError( -# f"{value} is an invalid value (float or None)" -# ) from None -# return None diff --git a/hydragnn/utils/mace_utils/tools/arg_parser_tools.py b/hydragnn/utils/mace_utils/tools/arg_parser_tools.py deleted file mode 100644 index dc76c1f43..000000000 --- a/hydragnn/utils/mace_utils/tools/arg_parser_tools.py +++ /dev/null @@ -1,113 +0,0 @@ -# import logging -# import os - -# from e3nn import o3 - - -# def check_args(args): -# """ -# Check input arguments, update them if necessary for valid and consistent inputs, and return a tuple containing -# the (potentially) modified args and a list of log messages. -# """ -# log_messages = [] - -# # Directories -# # Use work_dir for all other directories as well, unless they were specified by the user -# if args.log_dir is None: -# args.log_dir = os.path.join(args.work_dir, "logs") -# if args.model_dir is None: -# args.model_dir = args.work_dir -# if args.checkpoints_dir is None: -# args.checkpoints_dir = os.path.join(args.work_dir, "checkpoints") -# if args.results_dir is None: -# args.results_dir = os.path.join(args.work_dir, "results") -# if args.downloads_dir is None: -# args.downloads_dir = os.path.join(args.work_dir, "downloads") - -# # Model -# # Check if hidden_irreps, num_channels and max_L are consistent -# if args.hidden_irreps is None and args.num_channels is None and args.max_L is None: -# args.hidden_irreps, args.num_channels, args.max_L = "128x0e + 128x1o", 128, 1 -# elif ( -# args.hidden_irreps is not None -# and args.num_channels is not None -# and args.max_L is not None -# ): -# args.hidden_irreps = o3.Irreps( -# (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) -# .sort() -# .irreps.simplify() -# ) -# log_messages.append( -# ( -# "All of hidden_irreps, num_channels and max_L are specified", -# logging.WARNING, -# ) -# ) -# log_messages.append( -# ( -# f"Using num_channels and max_L to create hidden_irreps: {args.hidden_irreps}.", -# logging.WARNING, -# ) -# ) -# assert ( -# len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 -# ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" -# elif args.num_channels is not None and args.max_L is not None: -# assert args.num_channels > 0, "num_channels must be positive integer" -# assert args.max_L >= 0, "max_L must be non-negative integer" -# args.hidden_irreps = o3.Irreps( -# (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) -# .sort() -# .irreps.simplify() -# ) -# assert ( -# len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 -# ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" -# elif args.hidden_irreps is not None: -# assert ( -# len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 -# ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" - -# args.num_channels = list( -# {irrep.mul for irrep in o3.Irreps(args.hidden_irreps)} -# )[0] -# args.max_L = o3.Irreps(args.hidden_irreps).lmax -# elif args.max_L is not None and 
args.num_channels is None: -# assert args.max_L >= 0, "max_L must be non-negative integer" -# args.num_channels = 128 -# args.hidden_irreps = o3.Irreps( -# (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) -# .sort() -# .irreps.simplify() -# ) -# elif args.max_L is None and args.num_channels is not None: -# assert args.num_channels > 0, "num_channels must be positive integer" -# args.max_L = 1 -# args.hidden_irreps = o3.Irreps( -# (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) -# .sort() -# .irreps.simplify() -# ) - -# # Loss and optimization -# # Check Stage Two loss start -# if args.swa: -# if args.start_swa is None: -# args.start_swa = max(1, args.max_num_epochs // 4 * 3) -# if args.start_swa > args.max_num_epochs: -# log_messages.append( -# ( -# f"start_stage_two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}", -# logging.WARNING, -# ) -# ) -# log_messages.append( -# ( -# "Stage Two will not start, as start_stage_two > max_num_epochs", -# logging.WARNING, -# ) -# ) -# args.swa = False - -# return args, log_messages diff --git a/hydragnn/utils/mace_utils/tools/checkpoint.py b/hydragnn/utils/mace_utils/tools/checkpoint.py deleted file mode 100644 index c1f2f690e..000000000 --- a/hydragnn/utils/mace_utils/tools/checkpoint.py +++ /dev/null @@ -1,227 +0,0 @@ -# ########################################################################################### -# # Checkpointing -# # Authors: Gregor Simm -# # This program is distributed under the MIT License (see MIT.md) -# ########################################################################################### - -# import dataclasses -# import logging -# import os -# import re -# from typing import Dict, List, Optional, Tuple - -# import torch - -# from .torch_tools import TensorDict - -# Checkpoint = Dict[str, TensorDict] - - -# @dataclasses.dataclass -# class CheckpointState: -# model: torch.nn.Module -# optimizer: torch.optim.Optimizer -# lr_scheduler: torch.optim.lr_scheduler.ExponentialLR - - -# class CheckpointBuilder: -# @staticmethod -# def create_checkpoint(state: CheckpointState) -> Checkpoint: -# return { -# "model": state.model.state_dict(), -# "optimizer": state.optimizer.state_dict(), -# "lr_scheduler": state.lr_scheduler.state_dict(), -# } - -# @staticmethod -# def load_checkpoint( -# state: CheckpointState, checkpoint: Checkpoint, strict: bool -# ) -> None: -# state.model.load_state_dict(checkpoint["model"], strict=strict) # type: ignore -# state.optimizer.load_state_dict(checkpoint["optimizer"]) -# state.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) - - -# @dataclasses.dataclass -# class CheckpointPathInfo: -# path: str -# tag: str -# epochs: int -# swa: bool - - -# class CheckpointIO: -# def __init__( -# self, directory: str, tag: str, keep: bool = False, swa_start: int = None -# ) -> None: -# self.directory = directory -# self.tag = tag -# self.keep = keep -# self.old_path: Optional[str] = None -# self.swa_start = swa_start - -# self._epochs_string = "_epoch-" -# self._filename_extension = "pt" - -# def _get_checkpoint_filename(self, epochs: int, swa_start=None) -> str: -# if swa_start is not None and epochs > swa_start: -# return ( -# self.tag -# + self._epochs_string -# + str(epochs) -# + "_swa" -# + "." -# + self._filename_extension -# ) -# return ( -# self.tag -# + self._epochs_string -# + str(epochs) -# + "." 
-#             + self._filename_extension
-#         )
-
-#     def _list_file_paths(self) -> List[str]:
-#         if not os.path.isdir(self.directory):
-#             return []
-#         all_paths = [
-#             os.path.join(self.directory, f) for f in os.listdir(self.directory)
-#         ]
-#         return [path for path in all_paths if os.path.isfile(path)]
-
-#     def _parse_checkpoint_path(self, path: str) -> Optional[CheckpointPathInfo]:
-#         filename = os.path.basename(path)
-#         regex = re.compile(
-#             rf"^(?P<tag>.+){self._epochs_string}(?P<epochs>\d+)\.{self._filename_extension}$"
-#         )
-#         regex2 = re.compile(
-#             rf"^(?P<tag>.+){self._epochs_string}(?P<epochs>\d+)_swa\.{self._filename_extension}$"
-#         )
-#         match = regex.match(filename)
-#         match2 = regex2.match(filename)
-#         swa = False
-#         if not match:
-#             if not match2:
-#                 return None
-#             match = match2
-#             swa = True
-
-#         return CheckpointPathInfo(
-#             path=path,
-#             tag=match.group("tag"),
-#             epochs=int(match.group("epochs")),
-#             swa=swa,
-#         )
-
-#     def _get_latest_checkpoint_path(self, swa) -> Optional[str]:
-#         all_file_paths = self._list_file_paths()
-#         checkpoint_info_list = [
-#             self._parse_checkpoint_path(path) for path in all_file_paths
-#         ]
-#         selected_checkpoint_info_list = [
-#             info for info in checkpoint_info_list if info and info.tag == self.tag
-#         ]
-
-#         if len(selected_checkpoint_info_list) == 0:
-#             logging.warning(
-#                 f"Cannot find checkpoint with tag '{self.tag}' in '{self.directory}'"
-#             )
-#             return None
-
-#         selected_checkpoint_info_list_swa = []
-#         selected_checkpoint_info_list_no_swa = []
-
-#         for ckp in selected_checkpoint_info_list:
-#             if ckp.swa:
-#                 selected_checkpoint_info_list_swa.append(ckp)
-#             else:
-#                 selected_checkpoint_info_list_no_swa.append(ckp)
-#         if swa:
-#             try:
-#                 latest_checkpoint_info = max(
-#                     selected_checkpoint_info_list_swa, key=lambda info: info.epochs
-#                 )
-#             except ValueError:
-#                 logging.warning(
-#                     "No SWA checkpoint found, while SWA is enabled. Compare the swa_start parameter and the latest checkpoint."
-# ) -# else: -# latest_checkpoint_info = max( -# selected_checkpoint_info_list_no_swa, key=lambda info: info.epochs -# ) -# return latest_checkpoint_info.path - -# def save( -# self, checkpoint: Checkpoint, epochs: int, keep_last: bool = False -# ) -> None: -# if not self.keep and self.old_path and not keep_last: -# logging.debug(f"Deleting old checkpoint file: {self.old_path}") -# os.remove(self.old_path) - -# filename = self._get_checkpoint_filename(epochs, self.swa_start) -# path = os.path.join(self.directory, filename) -# logging.debug(f"Saving checkpoint: {path}") -# os.makedirs(self.directory, exist_ok=True) -# torch.save(obj=checkpoint, f=path) -# self.old_path = path - -# def load_latest( -# self, swa: Optional[bool] = False, device: Optional[torch.device] = None -# ) -> Optional[Tuple[Checkpoint, int]]: -# path = self._get_latest_checkpoint_path(swa=swa) -# if path is None: -# return None - -# return self.load(path, device=device) - -# def load( -# self, path: str, device: Optional[torch.device] = None -# ) -> Tuple[Checkpoint, int]: -# checkpoint_info = self._parse_checkpoint_path(path) - -# if checkpoint_info is None: -# raise RuntimeError(f"Cannot find path '{path}'") - -# logging.info(f"Loading checkpoint: {checkpoint_info.path}") -# return ( -# torch.load(f=checkpoint_info.path, map_location=device), -# checkpoint_info.epochs, -# ) - - -# class CheckpointHandler: -# def __init__(self, *args, **kwargs) -> None: -# self.io = CheckpointIO(*args, **kwargs) -# self.builder = CheckpointBuilder() - -# def save( -# self, state: CheckpointState, epochs: int, keep_last: bool = False -# ) -> None: -# checkpoint = self.builder.create_checkpoint(state) -# self.io.save(checkpoint, epochs, keep_last) - -# def load_latest( -# self, -# state: CheckpointState, -# swa: Optional[bool] = False, -# device: Optional[torch.device] = None, -# strict=False, -# ) -> Optional[int]: -# result = self.io.load_latest(swa=swa, device=device) -# if result is None: -# return None - -# checkpoint, epochs = result -# self.builder.load_checkpoint(state=state, checkpoint=checkpoint, strict=strict) -# return epochs - -# def load( -# self, -# state: CheckpointState, -# path: str, -# strict=False, -# device: Optional[torch.device] = None, -# ) -> int: -# checkpoint, epochs = self.io.load(path, device=device) -# self.builder.load_checkpoint(state=state, checkpoint=checkpoint, strict=strict) -# return epochs diff --git a/hydragnn/utils/mace_utils/tools/scripts_utils.py b/hydragnn/utils/mace_utils/tools/scripts_utils.py deleted file mode 100644 index 15dd155d1..000000000 --- a/hydragnn/utils/mace_utils/tools/scripts_utils.py +++ /dev/null @@ -1,653 +0,0 @@ -# ########################################################################################### -# # Training utils -# # Authors: David Kovacs, Ilyes Batatia -# # This program is distributed under the MIT License (see MIT.md) -# ########################################################################################### - -# import ast -# import dataclasses -# import json -# import logging -# import os -# from typing import Any, Dict, List, Optional, Tuple - -# import numpy as np -# import torch -# import torch.distributed -# from e3nn import o3 -# from prettytable import PrettyTable - -# from mace import data, modules -# from mace.tools import evaluate - - -# @dataclasses.dataclass -# class SubsetCollection: -# train: data.Configurations -# valid: data.Configurations -# tests: List[Tuple[str, data.Configurations]] - - -# def get_dataset_from_xyz( -# work_dir: str, -# 
train_path: str, -# valid_path: str, -# valid_fraction: float, -# config_type_weights: Dict, -# test_path: str = None, -# seed: int = 1234, -# keep_isolated_atoms: bool = False, -# energy_key: str = "REF_energy", -# forces_key: str = "REF_forces", -# stress_key: str = "REF_stress", -# virials_key: str = "virials", -# dipole_key: str = "dipoles", -# charges_key: str = "charges", -# ) -> Tuple[SubsetCollection, Optional[Dict[int, float]]]: -# """Load training and test dataset from xyz file""" -# atomic_energies_dict, all_train_configs = data.load_from_xyz( -# file_path=train_path, -# config_type_weights=config_type_weights, -# energy_key=energy_key, -# forces_key=forces_key, -# stress_key=stress_key, -# virials_key=virials_key, -# dipole_key=dipole_key, -# charges_key=charges_key, -# extract_atomic_energies=True, -# keep_isolated_atoms=keep_isolated_atoms, -# ) -# logging.info( -# f"Training set [{len(all_train_configs)} configs, {np.sum([1 if config.energy else 0 for config in all_train_configs])} energy, {np.sum([config.forces.size for config in all_train_configs])} forces] loaded from '{train_path}'" -# ) -# if valid_path is not None: -# _, valid_configs = data.load_from_xyz( -# file_path=valid_path, -# config_type_weights=config_type_weights, -# energy_key=energy_key, -# forces_key=forces_key, -# stress_key=stress_key, -# virials_key=virials_key, -# dipole_key=dipole_key, -# charges_key=charges_key, -# extract_atomic_energies=False, -# ) -# logging.info( -# f"Validation set [{len(valid_configs)} configs, {np.sum([1 if config.energy else 0 for config in valid_configs])} energy, {np.sum([config.forces.size for config in valid_configs])} forces] loaded from '{valid_path}'" -# ) -# train_configs = all_train_configs -# else: -# train_configs, valid_configs = data.random_train_valid_split( -# all_train_configs, valid_fraction, seed, work_dir -# ) -# logging.info( -# f"Validaton set contains {len(valid_configs)} configurations [{np.sum([1 if config.energy else 0 for config in valid_configs])} energy, {np.sum([config.forces.size for config in valid_configs])} forces]" -# ) - -# test_configs = [] -# if test_path is not None: -# _, all_test_configs = data.load_from_xyz( -# file_path=test_path, -# config_type_weights=config_type_weights, -# energy_key=energy_key, -# forces_key=forces_key, -# dipole_key=dipole_key, -# stress_key=stress_key, -# virials_key=virials_key, -# charges_key=charges_key, -# extract_atomic_energies=False, -# ) -# # create list of tuples (config_type, list(Atoms)) -# test_configs = data.test_config_types(all_test_configs) -# logging.info( -# f"Test set ({len(all_test_configs)} configs) loaded from '{test_path}':" -# ) -# for name, tmp_configs in test_configs: -# logging.info( -# f"{name}: {len(tmp_configs)} configs, {np.sum([1 if config.energy else 0 for config in tmp_configs])} energy, {np.sum([config.forces.size for config in tmp_configs])} forces" -# ) - -# return ( -# SubsetCollection(train=train_configs, valid=valid_configs, tests=test_configs), -# atomic_energies_dict, -# ) - - -# def get_config_type_weights(ct_weights): -# """ -# Get config type weights from command line argument -# """ -# try: -# config_type_weights = ast.literal_eval(ct_weights) -# assert isinstance(config_type_weights, dict) -# except Exception as e: # pylint: disable=W0703 -# logging.warning( -# f"Config type weights not specified correctly ({e}), using Default" -# ) -# config_type_weights = {"Default": 1.0} -# return config_type_weights - - -# def print_git_commit(): -# try: -# import git - -# repo 
= git.Repo(search_parent_directories=True) -# commit = repo.head.commit.hexsha -# logging.debug(f"Current Git commit: {commit}") -# return commit -# except Exception as e: # pylint: disable=W0703 -# logging.debug(f"Error accessing Git repository: {e}") -# return "None" - - -# def extract_config_mace_model(model: torch.nn.Module) -> Dict[str, Any]: -# if model.__class__.__name__ != "ScaleShiftMACE": -# return {"error": "Model is not a ScaleShiftMACE model"} - -# def radial_to_name(radial_type): -# if radial_type == "BesselBasis": -# return "bessel" -# if radial_type == "GaussianBasis": -# return "gaussian" -# if radial_type == "ChebychevBasis": -# return "chebyshev" -# return radial_type - -# def radial_to_transform(radial): -# if not hasattr(radial, "distance_transform"): -# return None -# if radial.distance_transform.__class__.__name__ == "AgnesiTransform": -# return "Agnesi" -# if radial.distance_transform.__class__.__name__ == "SoftTransform": -# return "Soft" -# return radial.distance_transform.__class__.__name__ - -# config = { -# "r_max": model.r_max.item(), -# "num_bessel": len(model.radial_embedding.bessel_fn.bessel_weights), -# "num_polynomial_cutoff": model.radial_embedding.cutoff_fn.p.item(), -# "max_ell": model.spherical_harmonics._lmax, # pylint: disable=protected-access -# "interaction_cls": model.interactions[-1].__class__, -# "interaction_cls_first": model.interactions[0].__class__, -# "num_interactions": model.num_interactions.item(), -# "num_elements": len(model.atomic_numbers), -# "hidden_irreps": o3.Irreps(str(model.products[0].linear.irreps_out)), -# "MLP_irreps": ( -# o3.Irreps(str(model.readouts[-1].hidden_irreps)) -# if model.num_interactions.item() > 1 -# else 1 -# ), -# "gate": ( -# model.readouts[-1] # pylint: disable=protected-access -# .non_linearity._modules["acts"][0] -# .f -# if model.num_interactions.item() > 1 -# else None -# ), -# "atomic_energies": model.atomic_energies_fn.atomic_energies.cpu().numpy(), -# "avg_num_neighbors": model.interactions[0].avg_num_neighbors, -# "atomic_numbers": model.atomic_numbers, -# "correlation": len( -# model.products[0].symmetric_contractions.contractions[0].weights -# ) -# + 1, -# "radial_type": radial_to_name( -# model.radial_embedding.bessel_fn.__class__.__name__ -# ), -# "radial_MLP": model.interactions[0].conv_tp_weights.hs[1:-1], -# "pair_repulsion": hasattr(model, "pair_repulsion_fn"), -# "distance_transform": radial_to_transform(model.radial_embedding), -# "atomic_inter_scale": model.scale_shift.scale.item(), -# "atomic_inter_shift": model.scale_shift.shift.item(), -# } -# return config - - -# def extract_load(f: str, map_location: str = "cpu") -> torch.nn.Module: -# model = torch.load(f=f, map_location=map_location) -# model_copy = model.__class__(**extract_config_mace_model(model)) -# model_copy.load_state_dict(model.state_dict()) -# return model_copy.to(map_location) - - -# def extract_model(model: torch.nn.Module, map_location: str = "cpu") -> torch.nn.Module: -# model_copy = model.__class__(**extract_config_mace_model(model)) -# model_copy.load_state_dict(model.state_dict()) -# return model_copy.to(map_location) - - -# def convert_to_json_format(dict_input): -# for key, value in dict_input.items(): -# if isinstance(value, (np.ndarray, torch.Tensor)): -# dict_input[key] = value.tolist() -# # # check if the value is a class and convert it to a string -# elif hasattr(value, "__class__"): -# dict_input[key] = str(value) -# return dict_input - - -# def convert_from_json_format(dict_input): -# dict_output = 
dict_input.copy()
-#     if (
-#         dict_input["interaction_cls"]
-#         == "<class 'mace.modules.blocks.RealAgnosticResidualInteractionBlock'>"
-#     ):
-#         dict_output[
-#             "interaction_cls"
-#         ] = modules.blocks.RealAgnosticResidualInteractionBlock
-#     if (
-#         dict_input["interaction_cls"]
-#         == "<class 'mace.modules.blocks.RealAgnosticInteractionBlock'>"
-#     ):
-#         dict_output["interaction_cls"] = modules.blocks.RealAgnosticInteractionBlock
-#     if (
-#         dict_input["interaction_cls_first"]
-#         == "<class 'mace.modules.blocks.RealAgnosticResidualInteractionBlock'>"
-#     ):
-#         dict_output[
-#             "interaction_cls_first"
-#         ] = modules.blocks.RealAgnosticResidualInteractionBlock
-#     if (
-#         dict_input["interaction_cls_first"]
-#         == "<class 'mace.modules.blocks.RealAgnosticInteractionBlock'>"
-#     ):
-#         dict_output[
-#             "interaction_cls_first"
-#         ] = modules.blocks.RealAgnosticInteractionBlock
-#     dict_output["r_max"] = float(dict_input["r_max"])
-#     dict_output["num_bessel"] = int(dict_input["num_bessel"])
-#     dict_output["num_polynomial_cutoff"] = float(dict_input["num_polynomial_cutoff"])
-#     dict_output["max_ell"] = int(dict_input["max_ell"])
-#     dict_output["num_interactions"] = int(dict_input["num_interactions"])
-#     dict_output["num_elements"] = int(dict_input["num_elements"])
-#     dict_output["hidden_irreps"] = o3.Irreps(dict_input["hidden_irreps"])
-#     dict_output["MLP_irreps"] = o3.Irreps(dict_input["MLP_irreps"])
-#     dict_output["avg_num_neighbors"] = float(dict_input["avg_num_neighbors"])
-#     dict_output["gate"] = torch.nn.functional.silu
-#     dict_output["atomic_energies"] = np.array(dict_input["atomic_energies"])
-#     dict_output["atomic_numbers"] = dict_input["atomic_numbers"]
-#     dict_output["correlation"] = int(dict_input["correlation"])
-#     dict_output["radial_type"] = dict_input["radial_type"]
-#     dict_output["radial_MLP"] = ast.literal_eval(dict_input["radial_MLP"])
-#     dict_output["pair_repulsion"] = ast.literal_eval(dict_input["pair_repulsion"])
-#     dict_output["distance_transform"] = dict_input["distance_transform"]
-#     dict_output["atomic_inter_scale"] = float(dict_input["atomic_inter_scale"])
-#     dict_output["atomic_inter_shift"] = float(dict_input["atomic_inter_shift"])
-
-#     return dict_output
-
-# def load_from_json(f: str, map_location: str = "cpu") -> torch.nn.Module:
-#     extra_files_extract = {"commit.txt": None, "config.json": None}
-#     model_jit_load = torch.jit.load(
-#         f, _extra_files=extra_files_extract, map_location=map_location
-#     )
-#     model_load_yaml = modules.ScaleShiftMACE(
-#         **convert_from_json_format(json.loads(extra_files_extract["config.json"]))
-#     )
-#     model_load_yaml.load_state_dict(model_jit_load.state_dict())
-#     return model_load_yaml.to(map_location)
-
-# def get_atomic_energies(E0s, train_collection, z_table) -> dict:
-#     if E0s is not None:
-#         logging.info(
-#             "Isolated Atomic Energies (E0s) not in training file, using command line argument"
-#         )
-#         if E0s.lower() == "average":
-#             logging.info(
-#                 "Computing average Atomic Energies using least squares regression"
-#             )
-#             # catch if colections.train not defined above
-#             try:
-#                 assert train_collection is not None
-#                 atomic_energies_dict = data.compute_average_E0s(
-#                     train_collection, z_table
-#                 )
-#             except Exception as e:
-#                 raise RuntimeError(
-#                     f"Could not compute average E0s if no training xyz given, error {e} occured"
-#                 ) from e
-#         else:
-#             if E0s.endswith(".json"):
-#                 logging.info(f"Loading atomic energies from {E0s}")
-#                 with open(E0s, "r", encoding="utf-8") as f:
-#                     atomic_energies_dict = json.load(f)
-#             else:
-#                 try:
-#                     atomic_energies_dict = ast.literal_eval(E0s)
-#                     assert isinstance(atomic_energies_dict, dict)
-#                 except Exception as e:
-#                     raise RuntimeError(
-#                         f"E0s specified invalidly, error {e} occured"
-#                     ) from e
-#     else:
-#         raise RuntimeError(
-#             "E0s not found in
training file and not specified in command line" -# ) -# return atomic_energies_dict - - -# def get_loss_fn( -# loss: str, -# energy_weight: float, -# forces_weight: float, -# stress_weight: float, -# virials_weight: float, -# dipole_weight: float, -# dipole_only: bool, -# compute_dipole: bool, -# ) -> torch.nn.Module: -# if loss == "weighted": -# loss_fn = modules.WeightedEnergyForcesLoss( -# energy_weight=energy_weight, forces_weight=forces_weight -# ) -# elif loss == "forces_only": -# loss_fn = modules.WeightedForcesLoss(forces_weight=forces_weight) -# elif loss == "virials": -# loss_fn = modules.WeightedEnergyForcesVirialsLoss( -# energy_weight=energy_weight, -# forces_weight=forces_weight, -# virials_weight=virials_weight, -# ) -# elif loss == "stress": -# loss_fn = modules.WeightedEnergyForcesStressLoss( -# energy_weight=energy_weight, -# forces_weight=forces_weight, -# stress_weight=stress_weight, -# ) -# elif loss == "dipole": -# assert ( -# dipole_only is True -# ), "dipole loss can only be used with AtomicDipolesMACE model" -# loss_fn = modules.DipoleSingleLoss( -# dipole_weight=dipole_weight, -# ) -# elif loss == "energy_forces_dipole": -# assert dipole_only is False and compute_dipole is True -# loss_fn = modules.WeightedEnergyForcesDipoleLoss( -# energy_weight=energy_weight, -# forces_weight=forces_weight, -# dipole_weight=dipole_weight, -# ) -# else: -# loss_fn = modules.EnergyForcesLoss( -# energy_weight=energy_weight, forces_weight=forces_weight -# ) -# return loss_fn - - -# def get_files_with_suffix(dir_path: str, suffix: str) -> List[str]: -# return [ -# os.path.join(dir_path, f) for f in os.listdir(dir_path) if f.endswith(suffix) -# ] - - -# def custom_key(key): -# """ -# Helper function to sort the keys of the data loader dictionary -# to ensure that the training set, and validation set -# are evaluated first -# """ -# if key == "train": -# return (0, key) -# if key == "valid": -# return (1, key) -# return (2, key) - - -# class LRScheduler: -# def __init__(self, optimizer, args) -> None: -# self.scheduler = args.scheduler -# self._optimizer_type = ( -# args.optimizer -# ) # Schedulefree does not need an optimizer but checkpoint handler does. 
-# if args.scheduler == "ExponentialLR": -# self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR( -# optimizer=optimizer, gamma=args.lr_scheduler_gamma -# ) -# elif args.scheduler == "ReduceLROnPlateau": -# self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( -# optimizer=optimizer, -# factor=args.lr_factor, -# patience=args.scheduler_patience, -# ) -# else: -# raise RuntimeError(f"Unknown scheduler: '{args.scheduler}'") - -# def step(self, metrics=None, epoch=None): # pylint: disable=E1123 -# if self._optimizer_type == "schedulefree": -# return # In principle, schedulefree optimizer can be used with a scheduler but the paper suggests it's not necessary -# if self.scheduler == "ExponentialLR": -# self.lr_scheduler.step(epoch=epoch) -# elif self.scheduler == "ReduceLROnPlateau": -# self.lr_scheduler.step( # pylint: disable=E1123 -# metrics=metrics, epoch=epoch -# ) - -# def __getattr__(self, name): -# if name == "step": -# return self.step -# return getattr(self.lr_scheduler, name) - - -# def create_error_table( -# table_type: str, -# all_data_loaders: dict, -# model: torch.nn.Module, -# loss_fn: torch.nn.Module, -# output_args: Dict[str, bool], -# log_wandb: bool, -# device: str, -# distributed: bool = False, -# ) -> PrettyTable: -# if log_wandb: -# import wandb -# table = PrettyTable() -# if table_type == "TotalRMSE": -# table.field_names = [ -# "config_type", -# "RMSE E / meV", -# "RMSE F / meV / A", -# "relative F RMSE %", -# ] -# elif table_type == "PerAtomRMSE": -# table.field_names = [ -# "config_type", -# "RMSE E / meV / atom", -# "RMSE F / meV / A", -# "relative F RMSE %", -# ] -# elif table_type == "PerAtomRMSEstressvirials": -# table.field_names = [ -# "config_type", -# "RMSE E / meV / atom", -# "RMSE F / meV / A", -# "relative F RMSE %", -# "RMSE Stress (Virials) / meV / A (A^3)", -# ] -# elif table_type == "PerAtomMAEstressvirials": -# table.field_names = [ -# "config_type", -# "MAE E / meV / atom", -# "MAE F / meV / A", -# "relative F MAE %", -# "MAE Stress (Virials) / meV / A (A^3)", -# ] -# elif table_type == "TotalMAE": -# table.field_names = [ -# "config_type", -# "MAE E / meV", -# "MAE F / meV / A", -# "relative F MAE %", -# ] -# elif table_type == "PerAtomMAE": -# table.field_names = [ -# "config_type", -# "MAE E / meV / atom", -# "MAE F / meV / A", -# "relative F MAE %", -# ] -# elif table_type == "DipoleRMSE": -# table.field_names = [ -# "config_type", -# "RMSE MU / mDebye / atom", -# "relative MU RMSE %", -# ] -# elif table_type == "DipoleMAE": -# table.field_names = [ -# "config_type", -# "MAE MU / mDebye / atom", -# "relative MU MAE %", -# ] -# elif table_type == "EnergyDipoleRMSE": -# table.field_names = [ -# "config_type", -# "RMSE E / meV / atom", -# "RMSE F / meV / A", -# "rel F RMSE %", -# "RMSE MU / mDebye / atom", -# "rel MU RMSE %", -# ] - -# for name in sorted(all_data_loaders, key=custom_key): -# data_loader = all_data_loaders[name] -# logging.info(f"Evaluating {name} ...") -# _, metrics = evaluate( -# model, -# loss_fn=loss_fn, -# data_loader=data_loader, -# output_args=output_args, -# device=device, -# ) -# if distributed: -# torch.distributed.barrier() - -# del data_loader -# torch.cuda.empty_cache() -# if log_wandb: -# wandb_log_dict = { -# name -# + "_final_rmse_e_per_atom": metrics["rmse_e_per_atom"] -# * 1e3, # meV / atom -# name + "_final_rmse_f": metrics["rmse_f"] * 1e3, # meV / A -# name + "_final_rel_rmse_f": metrics["rel_rmse_f"], -# } -# wandb.log(wandb_log_dict) -# if table_type == "TotalRMSE": -# table.add_row( -# [ -# name, 
-# f"{metrics['rmse_e'] * 1000:8.1f}", -# f"{metrics['rmse_f'] * 1000:8.1f}", -# f"{metrics['rel_rmse_f']:8.2f}", -# ] -# ) -# elif table_type == "PerAtomRMSE": -# table.add_row( -# [ -# name, -# f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", -# f"{metrics['rmse_f'] * 1000:8.1f}", -# f"{metrics['rel_rmse_f']:8.2f}", -# ] -# ) -# elif ( -# table_type == "PerAtomRMSEstressvirials" -# and metrics["rmse_stress"] is not None -# ): -# table.add_row( -# [ -# name, -# f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", -# f"{metrics['rmse_f'] * 1000:8.1f}", -# f"{metrics['rel_rmse_f']:8.2f}", -# f"{metrics['rmse_stress'] * 1000:8.1f}", -# ] -# ) -# elif ( -# table_type == "PerAtomRMSEstressvirials" -# and metrics["rmse_virials"] is not None -# ): -# table.add_row( -# [ -# name, -# f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", -# f"{metrics['rmse_f'] * 1000:8.1f}", -# f"{metrics['rel_rmse_f']:8.2f}", -# f"{metrics['rmse_virials'] * 1000:8.1f}", -# ] -# ) -# elif ( -# table_type == "PerAtomMAEstressvirials" -# and metrics["mae_stress"] is not None -# ): -# table.add_row( -# [ -# name, -# f"{metrics['mae_e_per_atom'] * 1000:8.1f}", -# f"{metrics['mae_f'] * 1000:8.1f}", -# f"{metrics['rel_mae_f']:8.2f}", -# f"{metrics['mae_stress'] * 1000:8.1f}", -# ] -# ) -# elif ( -# table_type == "PerAtomMAEstressvirials" -# and metrics["mae_virials"] is not None -# ): -# table.add_row( -# [ -# name, -# f"{metrics['mae_e_per_atom'] * 1000:8.1f}", -# f"{metrics['mae_f'] * 1000:8.1f}", -# f"{metrics['rel_mae_f']:8.2f}", -# f"{metrics['mae_virials'] * 1000:8.1f}", -# ] -# ) -# elif table_type == "TotalMAE": -# table.add_row( -# [ -# name, -# f"{metrics['mae_e'] * 1000:8.1f}", -# f"{metrics['mae_f'] * 1000:8.1f}", -# f"{metrics['rel_mae_f']:8.2f}", -# ] -# ) -# elif table_type == "PerAtomMAE": -# table.add_row( -# [ -# name, -# f"{metrics['mae_e_per_atom'] * 1000:8.1f}", -# f"{metrics['mae_f'] * 1000:8.1f}", -# f"{metrics['rel_mae_f']:8.2f}", -# ] -# ) -# elif table_type == "DipoleRMSE": -# table.add_row( -# [ -# name, -# f"{metrics['rmse_mu_per_atom'] * 1000:8.2f}", -# f"{metrics['rel_rmse_mu']:8.1f}", -# ] -# ) -# elif table_type == "DipoleMAE": -# table.add_row( -# [ -# name, -# f"{metrics['mae_mu_per_atom'] * 1000:8.2f}", -# f"{metrics['rel_mae_mu']:8.1f}", -# ] -# ) -# elif table_type == "EnergyDipoleRMSE": -# table.add_row( -# [ -# name, -# f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", -# f"{metrics['rmse_f'] * 1000:8.1f}", -# f"{metrics['rel_rmse_f']:8.1f}", -# f"{metrics['rmse_mu_per_atom'] * 1000:8.1f}", -# f"{metrics['rel_rmse_mu']:8.1f}", -# ] -# ) -# return table diff --git a/hydragnn/utils/mace_utils/tools/slurm_distributed.py b/hydragnn/utils/mace_utils/tools/slurm_distributed.py deleted file mode 100644 index 35915fe26..000000000 --- a/hydragnn/utils/mace_utils/tools/slurm_distributed.py +++ /dev/null @@ -1,34 +0,0 @@ -# ########################################################################################### -# # Slurm environment setup for distributed training. 
-# # This code is refactored from rsarm's contribution at: -# # https://github.com/Lumi-supercomputer/lumi-reframe-tests/blob/main/checks/apps/deeplearning/pytorch/src/pt_distr_env.py -# # This program is distributed under the MIT License (see MIT.md) -# ########################################################################################### - -# import os - -# import hostlist - - -# class DistributedEnvironment: -# def __init__(self): -# self._setup_distr_env() -# self.master_addr = os.environ["MASTER_ADDR"] -# self.master_port = os.environ["MASTER_PORT"] -# self.world_size = int(os.environ["WORLD_SIZE"]) -# self.local_rank = int(os.environ["LOCAL_RANK"]) -# self.rank = int(os.environ["RANK"]) - -# def _setup_distr_env(self): -# hostname = hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"])[0] -# os.environ["MASTER_ADDR"] = hostname -# os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "33333") -# os.environ["WORLD_SIZE"] = os.environ.get( -# "SLURM_NTASKS", -# str( -# int(os.environ["SLURM_NTASKS_PER_NODE"]) -# * int(os.environ["SLURM_NNODES"]) -# ), -# ) -# os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"] -# os.environ["RANK"] = os.environ["SLURM_PROCID"] diff --git a/hydragnn/utils/mace_utils/tools/train.py b/hydragnn/utils/mace_utils/tools/train.py deleted file mode 100644 index a0b710222..000000000 --- a/hydragnn/utils/mace_utils/tools/train.py +++ /dev/null @@ -1,524 +0,0 @@ -# ########################################################################################### -# # Training script -# # Authors: Ilyes Batatia, Gregor Simm, David Kovacs -# # This program is distributed under the MIT License (see MIT.md) -# ########################################################################################### - -# import dataclasses -# import logging -# import time -# from contextlib import nullcontext -# from typing import Any, Dict, List, Optional, Tuple, Union - -# import numpy as np -# import torch -# import torch.distributed -# from torch.nn.parallel import DistributedDataParallel -# from torch.optim.swa_utils import SWALR, AveragedModel -# from torch.utils.data import DataLoader -# from torch.utils.data.distributed import DistributedSampler -# from torch_ema import ExponentialMovingAverage -# from torchmetrics import Metric - -# from . 
import torch_geometric -# from .checkpoint import CheckpointHandler, CheckpointState -# from .torch_tools import to_numpy -# from .utils import ( -# MetricsLogger, -# compute_mae, -# compute_q95, -# compute_rel_mae, -# compute_rel_rmse, -# compute_rmse, -# ) - - -# @dataclasses.dataclass -# class SWAContainer: -# model: AveragedModel -# scheduler: SWALR -# start: int -# loss_fn: torch.nn.Module - - -# def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): -# eval_metrics["mode"] = "eval" -# eval_metrics["epoch"] = epoch -# logger.log(eval_metrics) -# if epoch is None: -# initial_phrase = "Initial" -# else: -# initial_phrase = f"Epoch {epoch}" -# if log_errors == "PerAtomRMSE": -# error_e = eval_metrics["rmse_e_per_atom"] * 1e3 -# error_f = eval_metrics["rmse_f"] * 1e3 -# logging.info( -# f"{initial_phrase}: loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A" -# ) -# elif ( -# log_errors == "PerAtomRMSEstressvirials" -# and eval_metrics["rmse_stress_per_atom"] is not None -# ): -# error_e = eval_metrics["rmse_e_per_atom"] * 1e3 -# error_f = eval_metrics["rmse_f"] * 1e3 -# error_stress = eval_metrics["rmse_stress_per_atom"] * 1e3 -# logging.info( -# f"{initial_phrase}: loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_stress_per_atom={error_stress:8.1f} meV / A^3", -# ) -# elif ( -# log_errors == "PerAtomRMSEstressvirials" -# and eval_metrics["rmse_virials_per_atom"] is not None -# ): -# error_e = eval_metrics["rmse_e_per_atom"] * 1e3 -# error_f = eval_metrics["rmse_f"] * 1e3 -# error_virials = eval_metrics["rmse_virials_per_atom"] * 1e3 -# logging.info( -# f"{initial_phrase}: loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_virials_per_atom={error_virials:8.1f} meV", -# ) -# elif ( -# log_errors == "PerAtomMAEstressvirials" -# and eval_metrics["mae_stress_per_atom"] is not None -# ): -# error_e = eval_metrics["mae_e_per_atom"] * 1e3 -# error_f = eval_metrics["mae_f"] * 1e3 -# error_stress = eval_metrics["mae_stress"] * 1e3 -# logging.info( -# f"{initial_phrase}: loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A, MAE_stress={error_stress:8.1f} meV / A^3" -# ) -# elif ( -# log_errors == "PerAtomMAEstressvirials" -# and eval_metrics["mae_virials_per_atom"] is not None -# ): -# error_e = eval_metrics["mae_e_per_atom"] * 1e3 -# error_f = eval_metrics["mae_f"] * 1e3 -# error_virials = eval_metrics["mae_virials"] * 1e3 -# logging.info( -# f"{initial_phrase}: loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A, MAE_virials={error_virials:8.1f} meV" -# ) -# elif log_errors == "TotalRMSE": -# error_e = eval_metrics["rmse_e"] * 1e3 -# error_f = eval_metrics["rmse_f"] * 1e3 -# logging.info( -# f"{initial_phrase}: loss={valid_loss:8.4f}, RMSE_E={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A", -# ) -# elif log_errors == "PerAtomMAE": -# error_e = eval_metrics["mae_e_per_atom"] * 1e3 -# error_f = eval_metrics["mae_f"] * 1e3 -# logging.info( -# f"{initial_phrase}: loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A", -# ) -# elif log_errors == "TotalMAE": -# error_e = eval_metrics["mae_e"] * 1e3 -# error_f = eval_metrics["mae_f"] * 1e3 -# logging.info( -# f"{initial_phrase}: loss={valid_loss:8.4f}, MAE_E={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A", -# ) -# elif log_errors == "DipoleRMSE": -# error_mu =
eval_metrics["rmse_mu_per_atom"] * 1e3 -# logging.info( -# f"{initial_phrase}: loss={valid_loss:8.4f}, RMSE_MU_per_atom={error_mu:8.2f} mDebye", -# ) -# elif log_errors == "EnergyDipoleRMSE": -# error_e = eval_metrics["rmse_e_per_atom"] * 1e3 -# error_f = eval_metrics["rmse_f"] * 1e3 -# error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 -# logging.info( -# f"{initial_phrase}: loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_Mu_per_atom={error_mu:8.2f} mDebye", -# ) - - -# def train( -# model: torch.nn.Module, -# loss_fn: torch.nn.Module, -# train_loader: DataLoader, -# valid_loader: Dict[str, DataLoader], -# optimizer: torch.optim.Optimizer, -# lr_scheduler: torch.optim.lr_scheduler.ExponentialLR, -# start_epoch: int, -# max_num_epochs: int, -# patience: int, -# checkpoint_handler: CheckpointHandler, -# logger: MetricsLogger, -# eval_interval: int, -# output_args: Dict[str, bool], -# device: torch.device, -# log_errors: str, -# swa: Optional[SWAContainer] = None, -# ema: Optional[ExponentialMovingAverage] = None, -# max_grad_norm: Optional[float] = 10.0, -# log_wandb: bool = False, -# distributed: bool = False, -# save_all_checkpoints: bool = False, -# distributed_model: Optional[DistributedDataParallel] = None, -# train_sampler: Optional[DistributedSampler] = None, -# rank: Optional[int] = 0, -# ): -# lowest_loss = np.inf -# valid_loss = np.inf -# patience_counter = 0 -# swa_start = True -# keep_last = False -# if log_wandb: -# import wandb - -# if max_grad_norm is not None: -# logging.info(f"Using gradient clipping with tolerance={max_grad_norm:.3f}") - -# logging.info("") -# logging.info("===========TRAINING===========") -# logging.info("Started training, reporting errors on validation set") -# logging.info("Loss metrics on validation set") -# epoch = start_epoch - -# # # log validation loss before _any_ training -# param_context = ema.average_parameters() if ema is not None else nullcontext() -# with param_context: -# valid_loss, eval_metrics = evaluate( -# model=model, -# loss_fn=loss_fn, -# data_loader=valid_loader, -# output_args=output_args, -# device=device, -# ) -# valid_err_log(valid_loss, eval_metrics, logger, log_errors, None) - -# while epoch < max_num_epochs: -# # LR scheduler and SWA update -# if swa is None or epoch < swa.start: -# if epoch > start_epoch: -# lr_scheduler.step( -# metrics=valid_loss -# ) # Can break if exponential LR, TODO fix that!
-# else: -# if swa_start: -# logging.info("Changing loss based on Stage Two Weights") -# lowest_loss = np.inf -# swa_start = False -# keep_last = True -# loss_fn = swa.loss_fn -# swa.model.update_parameters(model) -# if epoch > start_epoch: -# swa.scheduler.step() - -# # Train -# if distributed: -# train_sampler.set_epoch(epoch) -# if "ScheduleFree" in type(optimizer).__name__: -# optimizer.train() -# train_one_epoch( -# model=model, -# loss_fn=loss_fn, -# data_loader=train_loader, -# optimizer=optimizer, -# epoch=epoch, -# output_args=output_args, -# max_grad_norm=max_grad_norm, -# ema=ema, -# logger=logger, -# device=device, -# distributed_model=distributed_model, -# rank=rank, -# ) -# if distributed: -# torch.distributed.barrier() - -# # Validate -# if epoch % eval_interval == 0: -# model_to_evaluate = ( -# model if distributed_model is None else distributed_model -# ) -# param_context = ( -# ema.average_parameters() if ema is not None else nullcontext() -# ) -# if "ScheduleFree" in type(optimizer).__name__: -# optimizer.eval() -# with param_context: -# valid_loss, eval_metrics = evaluate( -# model=model_to_evaluate, -# loss_fn=loss_fn, -# data_loader=valid_loader, -# output_args=output_args, -# device=device, -# ) -# if rank == 0: -# valid_err_log( -# valid_loss, -# eval_metrics, -# logger, -# log_errors, -# epoch, -# ) -# if log_wandb: -# wandb_log_dict = { -# "epoch": epoch, -# "valid_loss": valid_loss, -# "valid_rmse_e_per_atom": eval_metrics["rmse_e_per_atom"], -# "valid_rmse_f": eval_metrics["rmse_f"], -# } -# wandb.log(wandb_log_dict) - -# if valid_loss >= lowest_loss: -# patience_counter += 1 -# if patience_counter >= patience and epoch < swa.start: -# logging.info( -# f"Stopping optimization after {patience_counter} epochs without improvement and starting Stage Two" -# ) -# epoch = swa.start -# elif patience_counter >= patience and epoch >= swa.start: -# logging.info( -# f"Stopping optimization after {patience_counter} epochs without improvement" -# ) -# break -# if save_all_checkpoints: -# param_context = ( -# ema.average_parameters() -# if ema is not None -# else nullcontext() -# ) -# with param_context: -# checkpoint_handler.save( -# state=CheckpointState(model, optimizer, lr_scheduler), -# epochs=epoch, -# keep_last=True, -# ) -# else: -# lowest_loss = valid_loss -# patience_counter = 0 -# param_context = ( -# ema.average_parameters() if ema is not None else nullcontext() -# ) -# with param_context: -# checkpoint_handler.save( -# state=CheckpointState(model, optimizer, lr_scheduler), -# epochs=epoch, -# keep_last=keep_last, -# ) -# keep_last = False or save_all_checkpoints -# if distributed: -# torch.distributed.barrier() -# epoch += 1 - -# logging.info("Training complete") - - -# def train_one_epoch( -# model: torch.nn.Module, -# loss_fn: torch.nn.Module, -# data_loader: DataLoader, -# optimizer: torch.optim.Optimizer, -# epoch: int, -# output_args: Dict[str, bool], -# max_grad_norm: Optional[float], -# ema: Optional[ExponentialMovingAverage], -# logger: MetricsLogger, -# device: torch.device, -# distributed_model: Optional[DistributedDataParallel] = None, -# rank: Optional[int] = 0, -# ) -> None: -# model_to_train = model if distributed_model is None else distributed_model -# for batch in data_loader: -# _, opt_metrics = take_step( -# model=model_to_train, -# loss_fn=loss_fn, -# batch=batch, -# optimizer=optimizer, -# ema=ema, -# output_args=output_args, -# max_grad_norm=max_grad_norm, -# device=device, -# ) -# opt_metrics["mode"] = "opt" -# opt_metrics["epoch"] = epoch 
-# if rank == 0: -# logger.log(opt_metrics) - - -# def take_step( -# model: torch.nn.Module, -# loss_fn: torch.nn.Module, -# batch: torch_geometric.batch.Batch, -# optimizer: torch.optim.Optimizer, -# ema: Optional[ExponentialMovingAverage], -# output_args: Dict[str, bool], -# max_grad_norm: Optional[float], -# device: torch.device, -# ) -> Tuple[float, Dict[str, Any]]: -# start_time = time.time() -# batch = batch.to(device) -# optimizer.zero_grad(set_to_none=True) -# batch_dict = batch.to_dict() -# output = model( -# batch_dict, -# training=True, -# compute_force=output_args["forces"], -# compute_virials=output_args["virials"], -# compute_stress=output_args["stress"], -# ) -# loss = loss_fn(pred=output, ref=batch) -# loss.backward() -# if max_grad_norm is not None: -# torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm) -# optimizer.step() - -# if ema is not None: -# ema.update() - -# loss_dict = { -# "loss": to_numpy(loss), -# "time": time.time() - start_time, -# } - -# return loss, loss_dict - - -# def evaluate( -# model: torch.nn.Module, -# loss_fn: torch.nn.Module, -# data_loader: DataLoader, -# output_args: Dict[str, bool], -# device: torch.device, -# ) -> Tuple[float, Dict[str, Any]]: -# for param in model.parameters(): -# param.requires_grad = False - -# metrics = MACELoss(loss_fn=loss_fn).to(device) - -# start_time = time.time() -# for batch in data_loader: -# batch = batch.to(device) -# batch_dict = batch.to_dict() -# output = model( -# batch_dict, -# training=False, -# compute_force=output_args["forces"], -# compute_virials=output_args["virials"], -# compute_stress=output_args["stress"], -# ) -# avg_loss, aux = metrics(batch, output) - -# avg_loss, aux = metrics.compute() -# aux["time"] = time.time() - start_time -# metrics.reset() - -# for param in model.parameters(): -# param.requires_grad = True - -# return avg_loss, aux - - -# class MACELoss(Metric): -# def __init__(self, loss_fn: torch.nn.Module): -# super().__init__() -# self.loss_fn = loss_fn -# self.add_state("total_loss", default=torch.tensor(0.0), dist_reduce_fx="sum") -# self.add_state("num_data", default=torch.tensor(0.0), dist_reduce_fx="sum") -# self.add_state("E_computed", default=torch.tensor(0.0), dist_reduce_fx="sum") -# self.add_state("delta_es", default=[], dist_reduce_fx="cat") -# self.add_state("delta_es_per_atom", default=[], dist_reduce_fx="cat") -# self.add_state("Fs_computed", default=torch.tensor(0.0), dist_reduce_fx="sum") -# self.add_state("fs", default=[], dist_reduce_fx="cat") -# self.add_state("delta_fs", default=[], dist_reduce_fx="cat") -# self.add_state( -# "stress_computed", default=torch.tensor(0.0), dist_reduce_fx="sum" -# ) -# self.add_state("delta_stress", default=[], dist_reduce_fx="cat") -# self.add_state("delta_stress_per_atom", default=[], dist_reduce_fx="cat") -# self.add_state( -# "virials_computed", default=torch.tensor(0.0), dist_reduce_fx="sum" -# ) -# self.add_state("delta_virials", default=[], dist_reduce_fx="cat") -# self.add_state("delta_virials_per_atom", default=[], dist_reduce_fx="cat") -# self.add_state("Mus_computed", default=torch.tensor(0.0), dist_reduce_fx="sum") -# self.add_state("mus", default=[], dist_reduce_fx="cat") -# self.add_state("delta_mus", default=[], dist_reduce_fx="cat") -# self.add_state("delta_mus_per_atom", default=[], dist_reduce_fx="cat") - -# def update(self, batch, output): # pylint: disable=arguments-differ -# loss = self.loss_fn(pred=output, ref=batch) -# self.total_loss += loss -# self.num_data += batch.num_graphs - -# if 
output.get("energy") is not None and batch.energy is not None: -# self.E_computed += 1.0 -# self.delta_es.append(batch.energy - output["energy"]) -# self.delta_es_per_atom.append( -# (batch.energy - output["energy"]) / (batch.ptr[1:] - batch.ptr[:-1]) -# ) -# if output.get("forces") is not None and batch.forces is not None: -# self.Fs_computed += 1.0 -# self.fs.append(batch.forces) -# self.delta_fs.append(batch.forces - output["forces"]) -# if output.get("stress") is not None and batch.stress is not None: -# self.stress_computed += 1.0 -# self.delta_stress.append(batch.stress - output["stress"]) -# self.delta_stress_per_atom.append( -# (batch.stress - output["stress"]) -# / (batch.ptr[1:] - batch.ptr[:-1]).view(-1, 1, 1) -# ) -# if output.get("virials") is not None and batch.virials is not None: -# self.virials_computed += 1.0 -# self.delta_virials.append(batch.virials - output["virials"]) -# self.delta_virials_per_atom.append( -# (batch.virials - output["virials"]) -# / (batch.ptr[1:] - batch.ptr[:-1]).view(-1, 1, 1) -# ) -# if output.get("dipole") is not None and batch.dipole is not None: -# self.Mus_computed += 1.0 -# self.mus.append(batch.dipole) -# self.delta_mus.append(batch.dipole - output["dipole"]) -# self.delta_mus_per_atom.append( -# (batch.dipole - output["dipole"]) -# / (batch.ptr[1:] - batch.ptr[:-1]).unsqueeze(-1) -# ) - -# def convert(self, delta: Union[torch.Tensor, List[torch.Tensor]]) -> np.ndarray: -# if isinstance(delta, list): -# delta = torch.cat(delta) -# return to_numpy(delta) - -# def compute(self): -# aux = {} -# aux["loss"] = to_numpy(self.total_loss / self.num_data).item() -# if self.E_computed: -# delta_es = self.convert(self.delta_es) -# delta_es_per_atom = self.convert(self.delta_es_per_atom) -# aux["mae_e"] = compute_mae(delta_es) -# aux["mae_e_per_atom"] = compute_mae(delta_es_per_atom) -# aux["rmse_e"] = compute_rmse(delta_es) -# aux["rmse_e_per_atom"] = compute_rmse(delta_es_per_atom) -# aux["q95_e"] = compute_q95(delta_es) -# if self.Fs_computed: -# fs = self.convert(self.fs) -# delta_fs = self.convert(self.delta_fs) -# aux["mae_f"] = compute_mae(delta_fs) -# aux["rel_mae_f"] = compute_rel_mae(delta_fs, fs) -# aux["rmse_f"] = compute_rmse(delta_fs) -# aux["rel_rmse_f"] = compute_rel_rmse(delta_fs, fs) -# aux["q95_f"] = compute_q95(delta_fs) -# if self.stress_computed: -# delta_stress = self.convert(self.delta_stress) -# delta_stress_per_atom = self.convert(self.delta_stress_per_atom) -# aux["mae_stress"] = compute_mae(delta_stress) -# aux["rmse_stress"] = compute_rmse(delta_stress) -# aux["rmse_stress_per_atom"] = compute_rmse(delta_stress_per_atom) -# aux["q95_stress"] = compute_q95(delta_stress) -# if self.virials_computed: -# delta_virials = self.convert(self.delta_virials) -# delta_virials_per_atom = self.convert(self.delta_virials_per_atom) -# aux["mae_virials"] = compute_mae(delta_virials) -# aux["rmse_virials"] = compute_rmse(delta_virials) -# aux["rmse_virials_per_atom"] = compute_rmse(delta_virials_per_atom) -# aux["q95_virials"] = compute_q95(delta_virials) -# if self.Mus_computed: -# mus = self.convert(self.mus) -# delta_mus = self.convert(self.delta_mus) -# delta_mus_per_atom = self.convert(self.delta_mus_per_atom) -# aux["mae_mu"] = compute_mae(delta_mus) -# aux["mae_mu_per_atom"] = compute_mae(delta_mus_per_atom) -# aux["rel_mae_mu"] = compute_rel_mae(delta_mus, mus) -# aux["rmse_mu"] = compute_rmse(delta_mus) -# aux["rmse_mu_per_atom"] = compute_rmse(delta_mus_per_atom) -# aux["rel_rmse_mu"] = compute_rel_rmse(delta_mus, mus) -# 
aux["q95_mu"] = compute_q95(delta_mus) - -# return aux["loss"], aux From 6c89dd642fd1d70ddbf6b4abb80a49d38e7cc793 Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Fri, 27 Sep 2024 13:17:10 -0400 Subject: [PATCH 33/51] mace_utils comments (draft 4) --- hydragnn/utils/mace_utils/modules/utils.py | 14 +- hydragnn/utils/mace_utils/tools/__init__.py | 55 +--- .../mace_utils/tools/finetuning_utils.py | 278 +++++++++--------- .../utils/mace_utils/tools/torch_tools.py | 129 ++++---- 4 files changed, 222 insertions(+), 254 deletions(-) diff --git a/hydragnn/utils/mace_utils/modules/utils.py b/hydragnn/utils/mace_utils/modules/utils.py index 488cd90a3..7e6fbfc25 100644 --- a/hydragnn/utils/mace_utils/modules/utils.py +++ b/hydragnn/utils/mace_utils/modules/utils.py @@ -37,13 +37,13 @@ def get_edge_vectors_and_lengths( return vectors, lengths -def _check_non_zero(std): - if std == 0.0: - logging.warning( - "Standard deviation of the scaling is zero, Changing to no scaling" - ) - std = 1.0 - return std +# def _check_non_zero(std): +# if std == 0.0: +# logging.warning( +# "Standard deviation of the scaling is zero, Changing to no scaling" +# ) +# std = 1.0 +# return std def extract_invariant(x: torch.Tensor, num_layers: int, num_features: int, l_max: int): diff --git a/hydragnn/utils/mace_utils/tools/__init__.py b/hydragnn/utils/mace_utils/tools/__init__.py index 3703f3152..9464eef4f 100644 --- a/hydragnn/utils/mace_utils/tools/__init__.py +++ b/hydragnn/utils/mace_utils/tools/__init__.py @@ -1,57 +1,30 @@ -# from .arg_parser import build_default_arg_parser, build_preprocess_arg_parser -# from .arg_parser_tools import check_args from .cg import U_matrix_real -# from .checkpoint import CheckpointHandler, CheckpointIO, CheckpointState -from .finetuning_utils import load_foundations +# from .finetuning_utils import load_foundations from .torch_tools import ( TensorDict, - cartesian_to_spherical, + # cartesian_to_spherical, count_parameters, - init_device, - init_wandb, - set_default_dtype, - set_seeds, - spherical_to_cartesian, + # init_device, + # init_wandb, + # set_default_dtype, + # spherical_to_cartesian, to_numpy, - to_one_hot, + # to_one_hot, voigt_to_matrix, ) __all__ = [ "TensorDict", - "AtomicNumberTable", - "atomic_numbers_to_indices", "to_numpy", - "to_one_hot", - "build_default_arg_parser", - "check_args", - "set_seeds", - "init_device", - "setup_logger", - "get_tag", + # "to_one_hot", + # "init_device", "count_parameters", - "get_optimizer", - "MetricsLogger", - "get_atomic_number_table_from_zs", - "train", - "evaluate", - "SWAContainer", - "CheckpointHandler", - "CheckpointIO", - "CheckpointState", - "set_default_dtype", - "compute_mae", - "compute_rel_mae", - "compute_rmse", - "compute_rel_rmse", - "compute_q95", - "compute_c", + # "set_default_dtype", "U_matrix_real", - "spherical_to_cartesian", - "cartesian_to_spherical", + # "spherical_to_cartesian", + # "cartesian_to_spherical", "voigt_to_matrix", - "init_wandb", - "load_foundations", - "build_preprocess_arg_parser", + # "init_wandb", + # "load_foundations", ] diff --git a/hydragnn/utils/mace_utils/tools/finetuning_utils.py b/hydragnn/utils/mace_utils/tools/finetuning_utils.py index 0b214a287..219e416ae 100644 --- a/hydragnn/utils/mace_utils/tools/finetuning_utils.py +++ b/hydragnn/utils/mace_utils/tools/finetuning_utils.py @@ -1,149 +1,149 @@ -import torch +# import torch -from hydragnn.utils.mace_utils.tools.utils import AtomicNumberTable +# from hydragnn.utils.mace_utils.tools.utils import AtomicNumberTable -def load_foundations( - 
model: torch.nn.Module, - model_foundations: torch.nn.Module, - table: AtomicNumberTable, - load_readout=False, - use_shift=False, - use_scale=True, - max_L=2, -): - """ - Load the foundations of a model into a model for fine-tuning. - """ - assert model_foundations.r_max == model.r_max - z_table = AtomicNumberTable([int(z) for z in model_foundations.atomic_numbers]) - new_z_table = table - num_species_foundations = len(z_table.zs) - num_channels_foundation = ( - model_foundations.node_embedding.linear.weight.shape[0] - // num_species_foundations - ) - indices_weights = [z_table.z_to_index(z) for z in new_z_table.zs] - num_radial = model.radial_embedding.out_dim - num_species = len(indices_weights) - max_ell = model.spherical_harmonics._lmax # pylint: disable=protected-access - model.node_embedding.linear.weight = torch.nn.Parameter( - model_foundations.node_embedding.linear.weight.view( - num_species_foundations, -1 - )[indices_weights, :] - .flatten() - .clone() - / (num_species_foundations / num_species) ** 0.5 - ) - if model.radial_embedding.bessel_fn.__class__.__name__ == "BesselBasis": - model.radial_embedding.bessel_fn.bessel_weights = torch.nn.Parameter( - model_foundations.radial_embedding.bessel_fn.bessel_weights.clone() - ) +# def load_foundations( +# model: torch.nn.Module, +# model_foundations: torch.nn.Module, +# table: AtomicNumberTable, +# load_readout=False, +# use_shift=False, +# use_scale=True, +# max_L=2, +# ): +# """ +# Load the foundations of a model into a model for fine-tuning. +# """ +# assert model_foundations.r_max == model.r_max +# z_table = AtomicNumberTable([int(z) for z in model_foundations.atomic_numbers]) +# new_z_table = table +# num_species_foundations = len(z_table.zs) +# num_channels_foundation = ( +# model_foundations.node_embedding.linear.weight.shape[0] +# // num_species_foundations +# ) +# indices_weights = [z_table.z_to_index(z) for z in new_z_table.zs] +# num_radial = model.radial_embedding.out_dim +# num_species = len(indices_weights) +# max_ell = model.spherical_harmonics._lmax # pylint: disable=protected-access +# model.node_embedding.linear.weight = torch.nn.Parameter( +# model_foundations.node_embedding.linear.weight.view( +# num_species_foundations, -1 +# )[indices_weights, :] +# .flatten() +# .clone() +# / (num_species_foundations / num_species) ** 0.5 +# ) +# if model.radial_embedding.bessel_fn.__class__.__name__ == "BesselBasis": +# model.radial_embedding.bessel_fn.bessel_weights = torch.nn.Parameter( +# model_foundations.radial_embedding.bessel_fn.bessel_weights.clone() +# ) - for i in range(int(model.num_interactions)): - model.interactions[i].linear_up.weight = torch.nn.Parameter( - model_foundations.interactions[i].linear_up.weight.clone() - ) - model.interactions[i].avg_num_neighbors = model_foundations.interactions[ - i - ].avg_num_neighbors - for j in range(4): # Assuming 4 layers in conv_tp_weights, - layer_name = f"layer{j}" - if j == 0: - getattr( - model.interactions[i].conv_tp_weights, layer_name - ).weight = torch.nn.Parameter( - getattr( - model_foundations.interactions[i].conv_tp_weights, - layer_name, - ) - .weight[:num_radial, :] - .clone() - ) - else: - getattr( - model.interactions[i].conv_tp_weights, layer_name - ).weight = torch.nn.Parameter( - getattr( - model_foundations.interactions[i].conv_tp_weights, - layer_name, - ).weight.clone() - ) +# for i in range(int(model.num_interactions)): +# model.interactions[i].linear_up.weight = torch.nn.Parameter( +# model_foundations.interactions[i].linear_up.weight.clone() +# ) 
+# model.interactions[i].avg_num_neighbors = model_foundations.interactions[ +# i +# ].avg_num_neighbors +# for j in range(4): # Assuming 4 layers in conv_tp_weights, +# layer_name = f"layer{j}" +# if j == 0: +# getattr( +# model.interactions[i].conv_tp_weights, layer_name +# ).weight = torch.nn.Parameter( +# getattr( +# model_foundations.interactions[i].conv_tp_weights, +# layer_name, +# ) +# .weight[:num_radial, :] +# .clone() +# ) +# else: +# getattr( +# model.interactions[i].conv_tp_weights, layer_name +# ).weight = torch.nn.Parameter( +# getattr( +# model_foundations.interactions[i].conv_tp_weights, +# layer_name, +# ).weight.clone() +# ) - model.interactions[i].linear.weight = torch.nn.Parameter( - model_foundations.interactions[i].linear.weight.clone() - ) - if ( - model.interactions[i].__class__.__name__ - == "RealAgnosticResidualInteractionBlock" - ): - model.interactions[i].skip_tp.weight = torch.nn.Parameter( - model_foundations.interactions[i] - .skip_tp.weight.reshape( - num_channels_foundation, - num_species_foundations, - num_channels_foundation, - )[:, indices_weights, :] - .flatten() - .clone() - / (num_species_foundations / num_species) ** 0.5 - ) - else: - model.interactions[i].skip_tp.weight = torch.nn.Parameter( - model_foundations.interactions[i] - .skip_tp.weight.reshape( - num_channels_foundation, - (max_ell + 1), - num_species_foundations, - num_channels_foundation, - )[:, :, indices_weights, :] - .flatten() - .clone() - / (num_species_foundations / num_species) ** 0.5 - ) - # Transferring products - for i in range(2): # Assuming 2 products modules - max_range = max_L + 1 if i == 0 else 1 - for j in range(max_range): # Assuming 3 contractions in symmetric_contractions - model.products[i].symmetric_contractions.contractions[ - j - ].weights_max = torch.nn.Parameter( - model_foundations.products[i] - .symmetric_contractions.contractions[j] - .weights_max[indices_weights, :, :] - .clone() - ) +# model.interactions[i].linear.weight = torch.nn.Parameter( +# model_foundations.interactions[i].linear.weight.clone() +# ) +# if ( +# model.interactions[i].__class__.__name__ +# == "RealAgnosticResidualInteractionBlock" +# ): +# model.interactions[i].skip_tp.weight = torch.nn.Parameter( +# model_foundations.interactions[i] +# .skip_tp.weight.reshape( +# num_channels_foundation, +# num_species_foundations, +# num_channels_foundation, +# )[:, indices_weights, :] +# .flatten() +# .clone() +# / (num_species_foundations / num_species) ** 0.5 +# ) +# else: +# model.interactions[i].skip_tp.weight = torch.nn.Parameter( +# model_foundations.interactions[i] +# .skip_tp.weight.reshape( +# num_channels_foundation, +# (max_ell + 1), +# num_species_foundations, +# num_channels_foundation, +# )[:, :, indices_weights, :] +# .flatten() +# .clone() +# / (num_species_foundations / num_species) ** 0.5 +# ) +# # Transferring products +# for i in range(2): # Assuming 2 products modules +# max_range = max_L + 1 if i == 0 else 1 +# for j in range(max_range): # Assuming 3 contractions in symmetric_contractions +# model.products[i].symmetric_contractions.contractions[ +# j +# ].weights_max = torch.nn.Parameter( +# model_foundations.products[i] +# .symmetric_contractions.contractions[j] +# .weights_max[indices_weights, :, :] +# .clone() +# ) - for k in range(2): # Assuming 2 weights in each contraction - model.products[i].symmetric_contractions.contractions[j].weights[ - k - ] = torch.nn.Parameter( - model_foundations.products[i] - .symmetric_contractions.contractions[j] - .weights[k][indices_weights, :, 
:] - .clone() - ) +# for k in range(2): # Assuming 2 weights in each contraction +# model.products[i].symmetric_contractions.contractions[j].weights[ +# k +# ] = torch.nn.Parameter( +# model_foundations.products[i] +# .symmetric_contractions.contractions[j] +# .weights[k][indices_weights, :, :] +# .clone() +# ) - model.products[i].linear.weight = torch.nn.Parameter( - model_foundations.products[i].linear.weight.clone() - ) +# model.products[i].linear.weight = torch.nn.Parameter( +# model_foundations.products[i].linear.weight.clone() +# ) - if load_readout: - # Transferring readouts - model.readouts[0].linear.weight = torch.nn.Parameter( - model_foundations.readouts[0].linear.weight.clone() - ) +# if load_readout: +# # Transferring readouts +# model.readouts[0].linear.weight = torch.nn.Parameter( +# model_foundations.readouts[0].linear.weight.clone() +# ) - model.readouts[1].linear_1.weight = torch.nn.Parameter( - model_foundations.readouts[1].linear_1.weight.clone() - ) +# model.readouts[1].linear_1.weight = torch.nn.Parameter( +# model_foundations.readouts[1].linear_1.weight.clone() +# ) - model.readouts[1].linear_2.weight = torch.nn.Parameter( - model_foundations.readouts[1].linear_2.weight.clone() - ) - if model_foundations.scale_shift is not None: - if use_scale: - model.scale_shift.scale = model_foundations.scale_shift.scale.clone() - if use_shift: - model.scale_shift.shift = model_foundations.scale_shift.shift.clone() - return model +# model.readouts[1].linear_2.weight = torch.nn.Parameter( +# model_foundations.readouts[1].linear_2.weight.clone() +# ) +# if model_foundations.scale_shift is not None: +# if use_scale: +# model.scale_shift.scale = model_foundations.scale_shift.scale.clone() +# if use_shift: +# model.scale_shift.shift = model_foundations.scale_shift.shift.clone() +# return model diff --git a/hydragnn/utils/mace_utils/tools/torch_tools.py b/hydragnn/utils/mace_utils/tools/torch_tools.py index 1ec3ecde7..dc949cf50 100644 --- a/hydragnn/utils/mace_utils/tools/torch_tools.py +++ b/hydragnn/utils/mace_utils/tools/torch_tools.py @@ -15,83 +15,78 @@ TensorDict = Dict[str, torch.Tensor] -def to_one_hot(indices: torch.Tensor, num_classes: int) -> torch.Tensor: - """ - Generates one-hot encoding with classes from - :param indices: (N x 1) tensor - :param num_classes: number of classes - :param device: torch device - :return: (N x num_classes) tensor - """ - shape = indices.shape[:-1] + (num_classes,) - oh = torch.zeros(shape, device=indices.device).view(shape) +# def to_one_hot(indices: torch.Tensor, num_classes: int) -> torch.Tensor: +# """ +# Generates one-hot encoding with classes from +# :param indices: (N x 1) tensor +# :param num_classes: number of classes +# :param device: torch device +# :return: (N x num_classes) tensor +# """ +# shape = indices.shape[:-1] + (num_classes,) +# oh = torch.zeros(shape, device=indices.device).view(shape) - # scatter_ is the in-place version of scatter - oh.scatter_(dim=-1, index=indices, value=1) +# # scatter_ is the in-place version of scatter +# oh.scatter_(dim=-1, index=indices, value=1) - return oh.view(*shape) +# return oh.view(*shape) def count_parameters(module: torch.nn.Module) -> int: return int(sum(np.prod(p.shape) for p in module.parameters())) -def tensor_dict_to_device(td: TensorDict, device: torch.device) -> TensorDict: - return {k: v.to(device) if v is not None else None for k, v in td.items()} - - -def set_seeds(seed: int) -> None: - np.random.seed(seed) - torch.manual_seed(seed) +# def tensor_dict_to_device(td: TensorDict, 
device: torch.device) -> TensorDict: +# return {k: v.to(device) if v is not None else None for k, v in td.items()} def to_numpy(t: torch.Tensor) -> np.ndarray: return t.cpu().detach().numpy() -def init_device(device_str: str) -> torch.device: - if "cuda" in device_str: - assert torch.cuda.is_available(), "No CUDA device available!" - if ":" in device_str: - # Check if the desired device is available - assert int(device_str.split(":")[-1]) < torch.cuda.device_count() - logging.info( - f"CUDA version: {torch.version.cuda}, CUDA device: {torch.cuda.current_device()}" - ) - torch.cuda.init() - return torch.device(device_str) - if device_str == "mps": - assert torch.backends.mps.is_available(), "No MPS backend is available!" - logging.info("Using MPS GPU acceleration") - return torch.device("mps") +# def init_device(device_str: str) -> torch.device: +# if "cuda" in device_str: +# assert torch.cuda.is_available(), "No CUDA device available!" +# if ":" in device_str: +# # Check if the desired device is available +# assert int(device_str.split(":")[-1]) < torch.cuda.device_count() +# logging.info( +# f"CUDA version: {torch.version.cuda}, CUDA device: {torch.cuda.current_device()}" +# ) +# torch.cuda.init() +# return torch.device(device_str) +# if device_str == "mps": +# assert torch.backends.mps.is_available(), "No MPS backend is available!" +# logging.info("Using MPS GPU acceleration") +# return torch.device("mps") - logging.info("Using CPU") - return torch.device("cpu") +# logging.info("Using CPU") +# return torch.device("cpu") -dtype_dict = {"float32": torch.float32, "float64": torch.float64} +# dtype_dict = {"float32": torch.float32, "float64": torch.float64} -def set_default_dtype(dtype: str) -> None: - torch.set_default_dtype(dtype_dict[dtype]) +# def set_default_dtype(dtype: str) -> None: +# torch.set_default_dtype(dtype_dict[dtype]) -def spherical_to_cartesian(t: torch.Tensor): - """ - Convert spherical notation to cartesian notation - """ - stress_cart_tensor = CartesianTensor("ij=ji") - stress_rtp = stress_cart_tensor.reduced_tensor_products() - return stress_cart_tensor.to_cartesian(t, rtp=stress_rtp) +# def spherical_to_cartesian(t: torch.Tensor): +# """ +# Convert spherical notation to cartesian notation +# """ +# stress_cart_tensor = CartesianTensor("ij=ji") +# stress_rtp = stress_cart_tensor.reduced_tensor_products() +# return stress_cart_tensor.to_cartesian(t, rtp=stress_rtp) -def cartesian_to_spherical(t: torch.Tensor): - """ - Convert cartesian notation to spherical notation - """ - stress_cart_tensor = CartesianTensor("ij=ji") - stress_rtp = stress_cart_tensor.reduced_tensor_products() - return stress_cart_tensor.to_cartesian(t, rtp=stress_rtp) +# def cartesian_to_spherical(t: torch.Tensor): +# """ +# Convert cartesian notation to spherical notation +# """ +# stress_cart_tensor = CartesianTensor("ij=ji") +# stress_rtp = stress_cart_tensor.reduced_tensor_products() +# return stress_cart_tensor.to_cartesian(t, rtp=stress_rtp) def voigt_to_matrix(t: torch.Tensor): @@ -119,20 +114,20 @@ def voigt_to_matrix(t: torch.Tensor): ) -def init_wandb(project: str, entity: str, name: str, config: dict, directory: str): - import wandb +# def init_wandb(project: str, entity: str, name: str, config: dict, directory: str): +# import wandb - wandb.init(project=project, entity=entity, name=name, config=config, dir=directory) +# wandb.init(project=project, entity=entity, name=name, config=config, dir=directory) -@contextmanager -def default_dtype(dtype: torch.dtype): - """Context manager for 
configuring the default_dtype used by torch +# @contextmanager +# def default_dtype(dtype: torch.dtype): +# """Context manager for configuring the default_dtype used by torch - Args: - dtype (torch.dtype): the default dtype to use within this context manager - """ - init = torch.get_default_dtype() - torch.set_default_dtype(dtype) - yield - torch.set_default_dtype(init) +# Args: +# dtype (torch.dtype): the default dtype to use within this context manager +# """ +# init = torch.get_default_dtype() +# torch.set_default_dtype(dtype) +# yield +# torch.set_default_dtype(init) From af49a73dfd9e8ef371fb2cb70332044277813fc9 Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Fri, 27 Sep 2024 14:11:27 -0400 Subject: [PATCH 34/51] taking more mace utils out --- hydragnn/utils/mace_utils/modules/utils.py | 16 +- hydragnn/utils/mace_utils/tools/__init__.py | 25 --- hydragnn/utils/mace_utils/tools/compile.py | 11 -- .../mace_utils/tools/finetuning_utils.py | 149 ---------------- .../utils/mace_utils/tools/torch_tools.py | 133 -------------- hydragnn/utils/mace_utils/tools/utils.py | 168 ------------------ 6 files changed, 1 insertion(+), 501 deletions(-) delete mode 100644 hydragnn/utils/mace_utils/tools/finetuning_utils.py delete mode 100644 hydragnn/utils/mace_utils/tools/torch_tools.py delete mode 100644 hydragnn/utils/mace_utils/tools/utils.py diff --git a/hydragnn/utils/mace_utils/modules/utils.py b/hydragnn/utils/mace_utils/modules/utils.py index 7e6fbfc25..a390304cd 100644 --- a/hydragnn/utils/mace_utils/modules/utils.py +++ b/hydragnn/utils/mace_utils/modules/utils.py @@ -4,17 +4,12 @@ # This program is distributed under the MIT License (see MIT.md) ########################################################################################### -import logging -from typing import List, Optional, Tuple +from typing import List, Tuple import numpy as np import torch import torch.nn import torch.utils.data -from scipy.constants import c, e - -from hydragnn.utils.mace_utils.tools import to_numpy -from hydragnn.utils.mace_utils.tools.scatter import scatter_sum from .blocks import AtomicEnergiesBlock @@ -37,15 +32,6 @@ def get_edge_vectors_and_lengths( return vectors, lengths -# def _check_non_zero(std): -# if std == 0.0: -# logging.warning( -# "Standard deviation of the scaling is zero, Changing to no scaling" -# ) -# std = 1.0 -# return std - - def extract_invariant(x: torch.Tensor, num_layers: int, num_features: int, l_max: int): out = [] for i in range(num_layers - 1): diff --git a/hydragnn/utils/mace_utils/tools/__init__.py b/hydragnn/utils/mace_utils/tools/__init__.py index 9464eef4f..cd5fb8634 100644 --- a/hydragnn/utils/mace_utils/tools/__init__.py +++ b/hydragnn/utils/mace_utils/tools/__init__.py @@ -1,30 +1,5 @@ from .cg import U_matrix_real -# from .finetuning_utils import load_foundations -from .torch_tools import ( - TensorDict, - # cartesian_to_spherical, - count_parameters, - # init_device, - # init_wandb, - # set_default_dtype, - # spherical_to_cartesian, - to_numpy, - # to_one_hot, - voigt_to_matrix, -) - __all__ = [ - "TensorDict", - "to_numpy", - # "to_one_hot", - # "init_device", - "count_parameters", - # "set_default_dtype", "U_matrix_real", - # "spherical_to_cartesian", - # "cartesian_to_spherical", - "voigt_to_matrix", - # "init_wandb", - # "load_foundations", ] diff --git a/hydragnn/utils/mace_utils/tools/compile.py b/hydragnn/utils/mace_utils/tools/compile.py index 425e4c02d..9bd2620af 100644 --- a/hydragnn/utils/mace_utils/tools/compile.py +++ 
b/hydragnn/utils/mace_utils/tools/compile.py @@ -1,4 +1,3 @@ -from contextlib import contextmanager from functools import wraps from typing import Callable, Tuple @@ -6,7 +5,6 @@ import torch._dynamo as dynamo except ImportError: dynamo = None -from e3nn import get_optimization_defaults, set_optimization_defaults from torch import autograd, nn from torch.fx import symbolic_trace @@ -14,15 +12,6 @@ TypeTuple = Tuple[type, ...] -@contextmanager -def disable_e3nn_codegen(): - """Context manager that disables the legacy PyTorch code generation used in e3nn.""" - init_val = get_optimization_defaults()["jit_script_fx"] - set_optimization_defaults(jit_script_fx=False) - yield - set_optimization_defaults(jit_script_fx=init_val) - - def prepare(func: ModuleFactory, allow_autograd: bool = True) -> ModuleFactory: """Function transform that prepares a MACE module for torch.compile diff --git a/hydragnn/utils/mace_utils/tools/finetuning_utils.py b/hydragnn/utils/mace_utils/tools/finetuning_utils.py deleted file mode 100644 index 219e416ae..000000000 --- a/hydragnn/utils/mace_utils/tools/finetuning_utils.py +++ /dev/null @@ -1,149 +0,0 @@ -# import torch - -# from hydragnn.utils.mace_utils.tools.utils import AtomicNumberTable - - -# def load_foundations( -# model: torch.nn.Module, -# model_foundations: torch.nn.Module, -# table: AtomicNumberTable, -# load_readout=False, -# use_shift=False, -# use_scale=True, -# max_L=2, -# ): -# """ -# Load the foundations of a model into a model for fine-tuning. -# """ -# assert model_foundations.r_max == model.r_max -# z_table = AtomicNumberTable([int(z) for z in model_foundations.atomic_numbers]) -# new_z_table = table -# num_species_foundations = len(z_table.zs) -# num_channels_foundation = ( -# model_foundations.node_embedding.linear.weight.shape[0] -# // num_species_foundations -# ) -# indices_weights = [z_table.z_to_index(z) for z in new_z_table.zs] -# num_radial = model.radial_embedding.out_dim -# num_species = len(indices_weights) -# max_ell = model.spherical_harmonics._lmax # pylint: disable=protected-access -# model.node_embedding.linear.weight = torch.nn.Parameter( -# model_foundations.node_embedding.linear.weight.view( -# num_species_foundations, -1 -# )[indices_weights, :] -# .flatten() -# .clone() -# / (num_species_foundations / num_species) ** 0.5 -# ) -# if model.radial_embedding.bessel_fn.__class__.__name__ == "BesselBasis": -# model.radial_embedding.bessel_fn.bessel_weights = torch.nn.Parameter( -# model_foundations.radial_embedding.bessel_fn.bessel_weights.clone() -# ) - -# for i in range(int(model.num_interactions)): -# model.interactions[i].linear_up.weight = torch.nn.Parameter( -# model_foundations.interactions[i].linear_up.weight.clone() -# ) -# model.interactions[i].avg_num_neighbors = model_foundations.interactions[ -# i -# ].avg_num_neighbors -# for j in range(4): # Assuming 4 layers in conv_tp_weights, -# layer_name = f"layer{j}" -# if j == 0: -# getattr( -# model.interactions[i].conv_tp_weights, layer_name -# ).weight = torch.nn.Parameter( -# getattr( -# model_foundations.interactions[i].conv_tp_weights, -# layer_name, -# ) -# .weight[:num_radial, :] -# .clone() -# ) -# else: -# getattr( -# model.interactions[i].conv_tp_weights, layer_name -# ).weight = torch.nn.Parameter( -# getattr( -# model_foundations.interactions[i].conv_tp_weights, -# layer_name, -# ).weight.clone() -# ) - -# model.interactions[i].linear.weight = torch.nn.Parameter( -# model_foundations.interactions[i].linear.weight.clone() -# ) -# if ( -# 
model.interactions[i].__class__.__name__ -# == "RealAgnosticResidualInteractionBlock" -# ): -# model.interactions[i].skip_tp.weight = torch.nn.Parameter( -# model_foundations.interactions[i] -# .skip_tp.weight.reshape( -# num_channels_foundation, -# num_species_foundations, -# num_channels_foundation, -# )[:, indices_weights, :] -# .flatten() -# .clone() -# / (num_species_foundations / num_species) ** 0.5 -# ) -# else: -# model.interactions[i].skip_tp.weight = torch.nn.Parameter( -# model_foundations.interactions[i] -# .skip_tp.weight.reshape( -# num_channels_foundation, -# (max_ell + 1), -# num_species_foundations, -# num_channels_foundation, -# )[:, :, indices_weights, :] -# .flatten() -# .clone() -# / (num_species_foundations / num_species) ** 0.5 -# ) -# # Transferring products -# for i in range(2): # Assuming 2 products modules -# max_range = max_L + 1 if i == 0 else 1 -# for j in range(max_range): # Assuming 3 contractions in symmetric_contractions -# model.products[i].symmetric_contractions.contractions[ -# j -# ].weights_max = torch.nn.Parameter( -# model_foundations.products[i] -# .symmetric_contractions.contractions[j] -# .weights_max[indices_weights, :, :] -# .clone() -# ) - -# for k in range(2): # Assuming 2 weights in each contraction -# model.products[i].symmetric_contractions.contractions[j].weights[ -# k -# ] = torch.nn.Parameter( -# model_foundations.products[i] -# .symmetric_contractions.contractions[j] -# .weights[k][indices_weights, :, :] -# .clone() -# ) - -# model.products[i].linear.weight = torch.nn.Parameter( -# model_foundations.products[i].linear.weight.clone() -# ) - -# if load_readout: -# # Transferring readouts -# model.readouts[0].linear.weight = torch.nn.Parameter( -# model_foundations.readouts[0].linear.weight.clone() -# ) - -# model.readouts[1].linear_1.weight = torch.nn.Parameter( -# model_foundations.readouts[1].linear_1.weight.clone() -# ) - -# model.readouts[1].linear_2.weight = torch.nn.Parameter( -# model_foundations.readouts[1].linear_2.weight.clone() -# ) -# if model_foundations.scale_shift is not None: -# if use_scale: -# model.scale_shift.scale = model_foundations.scale_shift.scale.clone() -# if use_shift: -# model.scale_shift.shift = model_foundations.scale_shift.shift.clone() -# return model diff --git a/hydragnn/utils/mace_utils/tools/torch_tools.py b/hydragnn/utils/mace_utils/tools/torch_tools.py deleted file mode 100644 index dc949cf50..000000000 --- a/hydragnn/utils/mace_utils/tools/torch_tools.py +++ /dev/null @@ -1,133 +0,0 @@ -########################################################################################### -# Tools for torch -# Authors: Ilyes Batatia, Gregor Simm -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import logging -from contextlib import contextmanager -from typing import Dict - -import numpy as np -import torch -from e3nn.io import CartesianTensor - -TensorDict = Dict[str, torch.Tensor] - - -# def to_one_hot(indices: torch.Tensor, num_classes: int) -> torch.Tensor: -# """ -# Generates one-hot encoding with <num_classes> classes from <indices> -# :param indices: (N x 1) tensor -# :param num_classes: number of classes -# :param device: torch device -# :return: (N x num_classes) tensor -# """ -# shape = indices.shape[:-1] + (num_classes,) -# oh = torch.zeros(shape, device=indices.device).view(shape) - -# # scatter_ is the in-place version of scatter -# oh.scatter_(dim=-1, index=indices, value=1) - -# return oh.view(*shape) - -
-def count_parameters(module: torch.nn.Module) -> int: - return int(sum(np.prod(p.shape) for p in module.parameters())) - - -# def tensor_dict_to_device(td: TensorDict, device: torch.device) -> TensorDict: -# return {k: v.to(device) if v is not None else None for k, v in td.items()} - - -def to_numpy(t: torch.Tensor) -> np.ndarray: - return t.cpu().detach().numpy() - - -# def init_device(device_str: str) -> torch.device: -# if "cuda" in device_str: -# assert torch.cuda.is_available(), "No CUDA device available!" -# if ":" in device_str: -# # Check if the desired device is available -# assert int(device_str.split(":")[-1]) < torch.cuda.device_count() -# logging.info( -# f"CUDA version: {torch.version.cuda}, CUDA device: {torch.cuda.current_device()}" -# ) -# torch.cuda.init() -# return torch.device(device_str) -# if device_str == "mps": -# assert torch.backends.mps.is_available(), "No MPS backend is available!" -# logging.info("Using MPS GPU acceleration") -# return torch.device("mps") - -# logging.info("Using CPU") -# return torch.device("cpu") - - -# dtype_dict = {"float32": torch.float32, "float64": torch.float64} - - -# def set_default_dtype(dtype: str) -> None: -# torch.set_default_dtype(dtype_dict[dtype]) - - -# def spherical_to_cartesian(t: torch.Tensor): -# """ -# Convert spherical notation to cartesian notation -# """ -# stress_cart_tensor = CartesianTensor("ij=ji") -# stress_rtp = stress_cart_tensor.reduced_tensor_products() -# return stress_cart_tensor.to_cartesian(t, rtp=stress_rtp) - - -# def cartesian_to_spherical(t: torch.Tensor): -# """ -# Convert cartesian notation to spherical notation -# """ -# stress_cart_tensor = CartesianTensor("ij=ji") -# stress_rtp = stress_cart_tensor.reduced_tensor_products() -# return stress_cart_tensor.to_cartesian(t, rtp=stress_rtp) - - -def voigt_to_matrix(t: torch.Tensor): - """ - Convert voigt notation to matrix notation - :param t: (6,) tensor or (3, 3) tensor or (9,) tensor - :return: (3, 3) tensor - """ - if t.shape == (3, 3): - return t - if t.shape == (6,): - return torch.tensor( - [ - [t[0], t[5], t[4]], - [t[5], t[1], t[3]], - [t[4], t[3], t[2]], - ], - dtype=t.dtype, - ) - if t.shape == (9,): - return t.view(3, 3) - - raise ValueError( - f"Stress tensor must be of shape (6,) or (3, 3), or (9,) but has shape {t.shape}" - ) - - -# def init_wandb(project: str, entity: str, name: str, config: dict, directory: str): -# import wandb - -# wandb.init(project=project, entity=entity, name=name, config=config, dir=directory) - - -# @contextmanager -# def default_dtype(dtype: torch.dtype): -# """Context manager for configuring the default_dtype used by torch - -# Args: -# dtype (torch.dtype): the default dtype to use within this context manager -# """ -# init = torch.get_default_dtype() -# torch.set_default_dtype(dtype) -# yield -# torch.set_default_dtype(init) diff --git a/hydragnn/utils/mace_utils/tools/utils.py b/hydragnn/utils/mace_utils/tools/utils.py deleted file mode 100644 index 762d98802..000000000 --- a/hydragnn/utils/mace_utils/tools/utils.py +++ /dev/null @@ -1,168 +0,0 @@ -########################################################################################### -# Statistics utilities -# Authors: Ilyes Batatia, Gregor Simm, David Kovacs -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import json -import logging -import os -import sys -from typing import Any, Dict, Iterable, Optional, Sequence, Union - -import numpy as 
np -import torch - -from .torch_tools import to_numpy - - -def compute_mae(delta: np.ndarray) -> float: - return np.mean(np.abs(delta)).item() - - -def compute_rel_mae(delta: np.ndarray, target_val: np.ndarray) -> float: - target_norm = np.mean(np.abs(target_val)) - return np.mean(np.abs(delta)).item() / (target_norm + 1e-9) * 100 - - -def compute_rmse(delta: np.ndarray) -> float: - return np.sqrt(np.mean(np.square(delta))).item() - - -def compute_rel_rmse(delta: np.ndarray, target_val: np.ndarray) -> float: - target_norm = np.sqrt(np.mean(np.square(target_val))).item() - return np.sqrt(np.mean(np.square(delta))).item() / (target_norm + 1e-9) * 100 - - -def compute_q95(delta: np.ndarray) -> float: - return np.percentile(np.abs(delta), q=95) - - -def compute_c(delta: np.ndarray, eta: float) -> float: - return np.mean(np.abs(delta) < eta).item() - - -def get_tag(name: str, seed: int) -> str: - return f"{name}_run-{seed}" - - -def setup_logger( - level: Union[int, str] = logging.INFO, - tag: Optional[str] = None, - directory: Optional[str] = None, - rank: Optional[int] = 0, -): - # Create a logger - logger = logging.getLogger() - logger.setLevel(logging.DEBUG) # Set to DEBUG to capture all levels - - # Create formatters - formatter = logging.Formatter( - "%(asctime)s.%(msecs)03d %(levelname)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - ) - - # Add filter for rank - logger.addFilter(lambda _: rank == 0) - - # Create console handler - ch = logging.StreamHandler(stream=sys.stdout) - ch.setLevel(level) - ch.setFormatter(formatter) - logger.addHandler(ch) - - if directory is not None and tag is not None: - os.makedirs(name=directory, exist_ok=True) - - # Create file handler for non-debug logs - main_log_path = os.path.join(directory, f"{tag}.log") - fh_main = logging.FileHandler(main_log_path) - fh_main.setLevel(level) - fh_main.setFormatter(formatter) - logger.addHandler(fh_main) - - # Create file handler for debug logs - debug_log_path = os.path.join(directory, f"{tag}_debug.log") - fh_debug = logging.FileHandler(debug_log_path) - fh_debug.setLevel(logging.DEBUG) - fh_debug.setFormatter(formatter) - fh_debug.addFilter(lambda record: record.levelno >= logging.DEBUG) - logger.addHandler(fh_debug) - - -class AtomicNumberTable: - def __init__(self, zs: Sequence[int]): - self.zs = zs - - def __len__(self) -> int: - return len(self.zs) - - def __str__(self): - return f"AtomicNumberTable: {tuple(s for s in self.zs)}" - - def index_to_z(self, index: int) -> int: - return self.zs[index] - - def z_to_index(self, atomic_number: str) -> int: - return self.zs.index(atomic_number) - - -def get_atomic_number_table_from_zs(zs: Iterable[int]) -> AtomicNumberTable: - z_set = set() - for z in zs: - z_set.add(z) - return AtomicNumberTable(sorted(list(z_set))) - - -def atomic_numbers_to_indices( - atomic_numbers: np.ndarray, z_table: AtomicNumberTable -) -> np.ndarray: - to_index_fn = np.vectorize(z_table.z_to_index) - return to_index_fn(atomic_numbers) - - -def get_optimizer( - name: str, - amsgrad: bool, - learning_rate: float, - weight_decay: float, - parameters: Iterable[torch.Tensor], -) -> torch.optim.Optimizer: - if name == "adam": - return torch.optim.Adam( - parameters, lr=learning_rate, amsgrad=amsgrad, weight_decay=weight_decay - ) - - if name == "adamw": - return torch.optim.AdamW( - parameters, lr=learning_rate, amsgrad=amsgrad, weight_decay=weight_decay - ) - - raise RuntimeError(f"Unknown optimizer '{name}'") - - -class UniversalEncoder(json.JSONEncoder): - def default(self, o): - if isinstance(o, 
np.integer): - return int(o) - if isinstance(o, np.floating): - return float(o) - if isinstance(o, np.ndarray): - return o.tolist() - if isinstance(o, torch.Tensor): - return to_numpy(o) - return json.JSONEncoder.default(self, o) - - -class MetricsLogger: - def __init__(self, directory: str, tag: str) -> None: - self.directory = directory - self.filename = tag + ".txt" - self.path = os.path.join(self.directory, self.filename) - - def log(self, d: Dict[str, Any]) -> None: - logging.debug(f"Saving info: {self.path}") - os.makedirs(name=self.directory, exist_ok=True) - with open(self.path, mode="a", encoding="utf-8") as f: - f.write(json.dumps(d, cls=UniversalEncoder)) - f.write("\n") From d4bef5e10dca91c9fee3016e3d1894b35e4e63d8 Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Fri, 27 Sep 2024 15:29:42 -0400 Subject: [PATCH 35/51] rebase MACE and make the tests run a little faster --- hydragnn/models/MACEStack.py | 4 +- .../input_config_parsing/config_utils.py | 2 +- .../mace_utils/modules/__init__.py | 0 .../{ => model}/mace_utils/modules/blocks.py | 4 +- .../mace_utils/modules/irreps_tools.py | 0 .../{ => model}/mace_utils/modules/radial.py | 4 +- .../modules/symmetric_contraction.py | 2 +- .../{ => model}/mace_utils/modules/utils.py | 0 .../{ => model}/mace_utils/tools/__init__.py | 0 .../utils/{ => model}/mace_utils/tools/cg.py | 0 .../{ => model}/mace_utils/tools/compile.py | 50 +++++++++---------- .../{ => model}/mace_utils/tools/scatter.py | 0 tests/inputs/ci.json | 2 +- tests/inputs/ci_equivariant.json | 2 +- tests/inputs/ci_multihead.json | 2 +- tests/inputs/ci_vectoroutput.json | 2 +- tests/test_model_loadpred.py | 2 +- 17 files changed, 38 insertions(+), 38 deletions(-) rename hydragnn/utils/{ => model}/mace_utils/modules/__init__.py (100%) rename hydragnn/utils/{ => model}/mace_utils/modules/blocks.py (99%) rename hydragnn/utils/{ => model}/mace_utils/modules/irreps_tools.py (100%) rename hydragnn/utils/{ => model}/mace_utils/modules/radial.py (98%) rename hydragnn/utils/{ => model}/mace_utils/modules/symmetric_contraction.py (99%) rename hydragnn/utils/{ => model}/mace_utils/modules/utils.py (100%) rename hydragnn/utils/{ => model}/mace_utils/tools/__init__.py (100%) rename hydragnn/utils/{ => model}/mace_utils/tools/cg.py (100%) rename hydragnn/utils/{ => model}/mace_utils/tools/compile.py (63%) rename hydragnn/utils/{ => model}/mace_utils/tools/scatter.py (100%) diff --git a/hydragnn/models/MACEStack.py b/hydragnn/models/MACEStack.py index c09fd39b8..7c015714a 100644 --- a/hydragnn/models/MACEStack.py +++ b/hydragnn/models/MACEStack.py @@ -42,13 +42,13 @@ from torch_geometric.nn import global_mean_pool # Mace -from hydragnn.utils.mace_utils.modules.blocks import ( +from hydragnn.utils.model.mace_utils.modules.blocks import ( EquivariantProductBasisBlock, LinearNodeEmbeddingBlock, RadialEmbeddingBlock, RealAgnosticAttResidualInteractionBlock, ) -from hydragnn.utils.mace_utils.modules.utils import ( +from hydragnn.utils.model.mace_utils.modules.utils import ( get_edge_vectors_and_lengths, ) diff --git a/hydragnn/utils/input_config_parsing/config_utils.py b/hydragnn/utils/input_config_parsing/config_utils.py index c37e0881e..daa3f68ec 100644 --- a/hydragnn/utils/input_config_parsing/config_utils.py +++ b/hydragnn/utils/input_config_parsing/config_utils.py @@ -14,7 +14,7 @@ check_if_graph_size_variable, gather_deg, ) -from hydragnn.utils.model import calculate_avg_deg +from hydragnn.utils.model.model import calculate_avg_deg from hydragnn.utils.distributed import 
get_comm_size_and_rank from copy import deepcopy import json diff --git a/hydragnn/utils/mace_utils/modules/__init__.py b/hydragnn/utils/model/mace_utils/modules/__init__.py similarity index 100% rename from hydragnn/utils/mace_utils/modules/__init__.py rename to hydragnn/utils/model/mace_utils/modules/__init__.py diff --git a/hydragnn/utils/mace_utils/modules/blocks.py b/hydragnn/utils/model/mace_utils/modules/blocks.py similarity index 99% rename from hydragnn/utils/mace_utils/modules/blocks.py rename to hydragnn/utils/model/mace_utils/modules/blocks.py index b77f4811d..4291912ae 100644 --- a/hydragnn/utils/mace_utils/modules/blocks.py +++ b/hydragnn/utils/model/mace_utils/modules/blocks.py @@ -12,8 +12,8 @@ from e3nn import nn, o3 from e3nn.util.jit import compile_mode -from hydragnn.utils.mace_utils.tools.compile import simplify_if_compile -from hydragnn.utils.mace_utils.tools.scatter import scatter_sum +from hydragnn.utils.model.mace_utils.tools.compile import simplify_if_compile +from hydragnn.utils.model.mace_utils.tools.scatter import scatter_sum from .irreps_tools import ( linear_out_irreps, diff --git a/hydragnn/utils/mace_utils/modules/irreps_tools.py b/hydragnn/utils/model/mace_utils/modules/irreps_tools.py similarity index 100% rename from hydragnn/utils/mace_utils/modules/irreps_tools.py rename to hydragnn/utils/model/mace_utils/modules/irreps_tools.py diff --git a/hydragnn/utils/mace_utils/modules/radial.py b/hydragnn/utils/model/mace_utils/modules/radial.py similarity index 98% rename from hydragnn/utils/mace_utils/modules/radial.py rename to hydragnn/utils/model/mace_utils/modules/radial.py index 94c0b8064..8b3f7ebf1 100644 --- a/hydragnn/utils/mace_utils/modules/radial.py +++ b/hydragnn/utils/model/mace_utils/modules/radial.py @@ -9,8 +9,8 @@ import torch from e3nn.util.jit import compile_mode -from hydragnn.utils.mace_utils.tools.compile import simplify_if_compile -from hydragnn.utils.mace_utils.tools.scatter import scatter_sum +from hydragnn.utils.model.mace_utils.tools.compile import simplify_if_compile +from hydragnn.utils.model.mace_utils.tools.scatter import scatter_sum @compile_mode("script") diff --git a/hydragnn/utils/mace_utils/modules/symmetric_contraction.py b/hydragnn/utils/model/mace_utils/modules/symmetric_contraction.py similarity index 99% rename from hydragnn/utils/mace_utils/modules/symmetric_contraction.py rename to hydragnn/utils/model/mace_utils/modules/symmetric_contraction.py index 5c807c717..8f2edd1c5 100644 --- a/hydragnn/utils/mace_utils/modules/symmetric_contraction.py +++ b/hydragnn/utils/model/mace_utils/modules/symmetric_contraction.py @@ -14,7 +14,7 @@ from e3nn.util.codegen import CodeGenMixin from e3nn.util.jit import compile_mode -from hydragnn.utils.mace_utils.tools.cg import U_matrix_real +from hydragnn.utils.model.mace_utils.tools.cg import U_matrix_real BATCH_EXAMPLE = 10 ALPHABET = ["w", "x", "v", "n", "z", "r", "t", "y", "u", "o", "p", "s"] diff --git a/hydragnn/utils/mace_utils/modules/utils.py b/hydragnn/utils/model/mace_utils/modules/utils.py similarity index 100% rename from hydragnn/utils/mace_utils/modules/utils.py rename to hydragnn/utils/model/mace_utils/modules/utils.py diff --git a/hydragnn/utils/mace_utils/tools/__init__.py b/hydragnn/utils/model/mace_utils/tools/__init__.py similarity index 100% rename from hydragnn/utils/mace_utils/tools/__init__.py rename to hydragnn/utils/model/mace_utils/tools/__init__.py diff --git a/hydragnn/utils/mace_utils/tools/cg.py b/hydragnn/utils/model/mace_utils/tools/cg.py similarity 
index 100% rename from hydragnn/utils/mace_utils/tools/cg.py rename to hydragnn/utils/model/mace_utils/tools/cg.py diff --git a/hydragnn/utils/mace_utils/tools/compile.py b/hydragnn/utils/model/mace_utils/tools/compile.py similarity index 63% rename from hydragnn/utils/mace_utils/tools/compile.py rename to hydragnn/utils/model/mace_utils/tools/compile.py index 9bd2620af..10f13d1ac 100644 --- a/hydragnn/utils/mace_utils/tools/compile.py +++ b/hydragnn/utils/model/mace_utils/tools/compile.py @@ -12,31 +12,31 @@ TypeTuple = Tuple[type, ...] -def prepare(func: ModuleFactory, allow_autograd: bool = True) -> ModuleFactory: - """Function transform that prepares a MACE module for torch.compile - - Args: - func (ModuleFactory): A function that creates an nn.Module - allow_autograd (bool, optional): Force inductor compiler to inline call to - `torch.autograd.grad`. Defaults to True. - - Returns: - ModuleFactory: Decorated function that creates a torch.compile compatible module - """ - if allow_autograd: - dynamo.allow_in_graph(autograd.grad) - elif dynamo.allowed_functions.is_allowed(autograd.grad): - dynamo.disallow_in_graph(autograd.grad) - - @wraps(func) - def wrapper(*args, **kwargs): - with disable_e3nn_codegen(): - model = func(*args, **kwargs) - - model = simplify(model) - return model - - return wrapper +# def prepare(func: ModuleFactory, allow_autograd: bool = True) -> ModuleFactory: +# """Function transform that prepares a MACE module for torch.compile + +# Args: +# func (ModuleFactory): A function that creates an nn.Module +# allow_autograd (bool, optional): Force inductor compiler to inline call to +# `torch.autograd.grad`. Defaults to True. + +# Returns: +# ModuleFactory: Decorated function that creates a torch.compile compatible module +# """ +# if allow_autograd: +# dynamo.allow_in_graph(autograd.grad) +# elif dynamo.allowed_functions.is_allowed(autograd.grad): +# dynamo.disallow_in_graph(autograd.grad) + +# @wraps(func) +# def wrapper(*args, **kwargs): +# with disable_e3nn_codegen(): +# model = func(*args, **kwargs) + +# model = simplify(model) +# return model + +# return wrapper _SIMPLIFY_REGISTRY = set() diff --git a/hydragnn/utils/mace_utils/tools/scatter.py b/hydragnn/utils/model/mace_utils/tools/scatter.py similarity index 100% rename from hydragnn/utils/mace_utils/tools/scatter.py rename to hydragnn/utils/model/mace_utils/tools/scatter.py diff --git a/tests/inputs/ci.json b/tests/inputs/ci.json index 951a44adf..ad05c10e4 100644 --- a/tests/inputs/ci.json +++ b/tests/inputs/ci.json @@ -66,7 +66,7 @@ "denormalize_output": false }, "Training": { - "num_epoch": 100, + "num_epoch": 10, "perc_train": 0.7, "EarlyStopping": true, "patience": 10, diff --git a/tests/inputs/ci_equivariant.json b/tests/inputs/ci_equivariant.json index 51175e9e7..92f02e974 100644 --- a/tests/inputs/ci_equivariant.json +++ b/tests/inputs/ci_equivariant.json @@ -67,7 +67,7 @@ "denormalize_output": false }, "Training": { - "num_epoch": 100, + "num_epoch": 10, "perc_train": 0.7, "EarlyStopping": true, "patience": 10, diff --git a/tests/inputs/ci_multihead.json b/tests/inputs/ci_multihead.json index f408c4aa8..5922241dd 100644 --- a/tests/inputs/ci_multihead.json +++ b/tests/inputs/ci_multihead.json @@ -64,7 +64,7 @@ "denormalize_output": false }, "Training": { - "num_epoch": 100, + "num_epoch": 10, "Checkpoint": true, "checkpoint_warmup": 10, "perc_train": 0.7, diff --git a/tests/inputs/ci_vectoroutput.json b/tests/inputs/ci_vectoroutput.json index ddb616615..4c2980be0 100644 --- 
a/tests/inputs/ci_vectoroutput.json +++ b/tests/inputs/ci_vectoroutput.json @@ -56,7 +56,7 @@ "denormalize_output": false }, "Training": { - "num_epoch": 80, + "num_epoch": 10, "Checkpoint": true, "checkpoint_warmup": 10, "perc_train": 0.7, diff --git a/tests/test_model_loadpred.py b/tests/test_model_loadpred.py index 74d84b7dc..a3e3438e3 100755 --- a/tests/test_model_loadpred.py +++ b/tests/test_model_loadpred.py @@ -13,7 +13,7 @@ import random import hydragnn from tests.test_graphs import unittest_train_model -from hydragnn.utils.config_utils import update_config +from hydragnn.utils.input_config_parsing.config_utils import update_config def unittest_model_prediction(config): From a6ae103a1cba5c7b5ff3317e6e7ab7876b957428 Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Mon, 30 Sep 2024 13:52:07 -0400 Subject: [PATCH 36/51] torch scatter update --- .../model/mace_utils/modules/__init__.py | 4 +- .../utils/model/mace_utils/modules/blocks.py | 60 +--------- .../utils/model/mace_utils/modules/radial.py | 6 +- .../utils/model/mace_utils/tools/compile.py | 27 ----- .../utils/model/mace_utils/tools/scatter.py | 112 ------------------ 5 files changed, 9 insertions(+), 200 deletions(-) delete mode 100644 hydragnn/utils/model/mace_utils/tools/scatter.py diff --git a/hydragnn/utils/model/mace_utils/modules/__init__.py b/hydragnn/utils/model/mace_utils/modules/__init__.py index f461a338a..f1b701ec2 100644 --- a/hydragnn/utils/model/mace_utils/modules/__init__.py +++ b/hydragnn/utils/model/mace_utils/modules/__init__.py @@ -6,10 +6,10 @@ AtomicEnergiesBlock, EquivariantProductBasisBlock, InteractionBlock, - LinearDipoleReadoutBlock, + # LinearDipoleReadoutBlock, LinearNodeEmbeddingBlock, LinearReadoutBlock, - NonLinearDipoleReadoutBlock, + # NonLinearDipoleReadoutBlock, NonLinearReadoutBlock, RadialEmbeddingBlock, RealAgnosticAttResidualInteractionBlock, diff --git a/hydragnn/utils/model/mace_utils/modules/blocks.py b/hydragnn/utils/model/mace_utils/modules/blocks.py index 4291912ae..010051dda 100644 --- a/hydragnn/utils/model/mace_utils/modules/blocks.py +++ b/hydragnn/utils/model/mace_utils/modules/blocks.py @@ -13,7 +13,7 @@ from e3nn.util.jit import compile_mode from hydragnn.utils.model.mace_utils.tools.compile import simplify_if_compile -from hydragnn.utils.model.mace_utils.tools.scatter import scatter_sum +from torch_scatter import scatter from .irreps_tools import ( linear_out_irreps, @@ -75,60 +75,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [... 
return self.linear_2(x) # [n_nodes, 1] -@compile_mode("script") -class LinearDipoleReadoutBlock(torch.nn.Module): - def __init__(self, irreps_in: o3.Irreps, dipole_only: bool = False): - super().__init__() - if dipole_only: - self.irreps_out = o3.Irreps("1x1o") - else: - self.irreps_out = o3.Irreps("1x0e + 1x1o") - self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=self.irreps_out) - - def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] - return self.linear(x) # [n_nodes, 1] - - -@compile_mode("script") -class NonLinearDipoleReadoutBlock(torch.nn.Module): - def __init__( - self, - irreps_in: o3.Irreps, - MLP_irreps: o3.Irreps, - gate: Callable, - dipole_only: bool = False, - ): - super().__init__() - self.hidden_irreps = MLP_irreps - if dipole_only: - self.irreps_out = o3.Irreps("1x1o") - else: - self.irreps_out = o3.Irreps("1x0e + 1x1o") - irreps_scalars = o3.Irreps( - [(mul, ir) for mul, ir in MLP_irreps if ir.l == 0 and ir in self.irreps_out] - ) - irreps_gated = o3.Irreps( - [(mul, ir) for mul, ir in MLP_irreps if ir.l > 0 and ir in self.irreps_out] - ) - irreps_gates = o3.Irreps([mul, "0e"] for mul, _ in irreps_gated) - self.equivariant_nonlin = nn.Gate( - irreps_scalars=irreps_scalars, - act_scalars=[gate for _, ir in irreps_scalars], - irreps_gates=irreps_gates, - act_gates=[gate] * len(irreps_gates), - irreps_gated=irreps_gated, - ) - self.irreps_nonlin = self.equivariant_nonlin.irreps_in.simplify() - self.linear_1 = o3.Linear(irreps_in=irreps_in, irreps_out=self.irreps_nonlin) - self.linear_2 = o3.Linear( - irreps_in=self.hidden_irreps, irreps_out=self.irreps_out - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] - x = self.equivariant_nonlin(self.linear_1(x)) - return self.linear_2(x) # [n_nodes, 1] - - @compile_mode("script") class AtomicEnergiesBlock(torch.nn.Module): atomic_energies: torch.Tensor @@ -413,8 +359,8 @@ def forward( mji = self.conv_tp( node_feats_up[sender], edge_attrs, tp_weights ) # [n_edges, irreps] - message = scatter_sum( - src=mji, index=receiver, dim=0, dim_size=num_nodes + message = scatter( + src=mji, index=receiver, dim=0, dim_size=num_nodes, reduce="sum" ) # [n_nodes, irreps] message = self.linear(message) / self.avg_num_neighbors return ( diff --git a/hydragnn/utils/model/mace_utils/modules/radial.py b/hydragnn/utils/model/mace_utils/modules/radial.py index 8b3f7ebf1..f9c703b79 100644 --- a/hydragnn/utils/model/mace_utils/modules/radial.py +++ b/hydragnn/utils/model/mace_utils/modules/radial.py @@ -10,7 +10,7 @@ from e3nn.util.jit import compile_mode from hydragnn.utils.model.mace_utils.tools.compile import simplify_if_compile -from hydragnn.utils.model.mace_utils.tools.scatter import scatter_sum +from torch_scatter import scatter @compile_mode("script") @@ -215,7 +215,9 @@ def forward( - (self.p * (self.p + 1.0) / 2) * torch.pow(x / r_max, self.p + 2) ) * (x < r_max) v_edges = 0.5 * v_edges * envelope - V_ZBL = scatter_sum(v_edges, receiver, dim=0, dim_size=node_attrs.size(0)) + V_ZBL = scatter( + v_edges, receiver, dim=0, dim_size=node_attrs.size(0), reduce="sum" + ) return V_ZBL.squeeze(-1) def __repr__(self): diff --git a/hydragnn/utils/model/mace_utils/tools/compile.py b/hydragnn/utils/model/mace_utils/tools/compile.py index 10f13d1ac..1f5722687 100644 --- a/hydragnn/utils/model/mace_utils/tools/compile.py +++ b/hydragnn/utils/model/mace_utils/tools/compile.py @@ -12,33 +12,6 @@ TypeTuple = Tuple[type, ...] 
-# def prepare(func: ModuleFactory, allow_autograd: bool = True) -> ModuleFactory: -# """Function transform that prepares a MACE module for torch.compile - -# Args: -# func (ModuleFactory): A function that creates an nn.Module -# allow_autograd (bool, optional): Force inductor compiler to inline call to -# `torch.autograd.grad`. Defaults to True. - -# Returns: -# ModuleFactory: Decorated function that creates a torch.compile compatible module -# """ -# if allow_autograd: -# dynamo.allow_in_graph(autograd.grad) -# elif dynamo.allowed_functions.is_allowed(autograd.grad): -# dynamo.disallow_in_graph(autograd.grad) - -# @wraps(func) -# def wrapper(*args, **kwargs): -# with disable_e3nn_codegen(): -# model = func(*args, **kwargs) - -# model = simplify(model) -# return model - -# return wrapper - - _SIMPLIFY_REGISTRY = set() diff --git a/hydragnn/utils/model/mace_utils/tools/scatter.py b/hydragnn/utils/model/mace_utils/tools/scatter.py deleted file mode 100644 index 7e1139a99..000000000 --- a/hydragnn/utils/model/mace_utils/tools/scatter.py +++ /dev/null @@ -1,112 +0,0 @@ -"""basic scatter_sum operations from torch_scatter from -https://github.com/mir-group/pytorch_runstats/blob/main/torch_runstats/scatter_sum.py -Using code from https://github.com/rusty1s/pytorch_scatter, but cut down to avoid a dependency. -PyTorch plans to move these features into the main repo, but until then, -to make installation simpler, we need this pure python set of wrappers -that don't require installing PyTorch C++ extensions. -See https://github.com/pytorch/pytorch/issues/63780. -""" - -from typing import Optional - -import torch - - -def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): - if dim < 0: - dim = other.dim() + dim - if src.dim() == 1: - for _ in range(0, dim): - src = src.unsqueeze(0) - for _ in range(src.dim(), other.dim()): - src = src.unsqueeze(-1) - src = src.expand_as(other) - return src - - -def scatter_sum( - src: torch.Tensor, - index: torch.Tensor, - dim: int = -1, - out: Optional[torch.Tensor] = None, - dim_size: Optional[int] = None, - reduce: str = "sum", -) -> torch.Tensor: - assert reduce == "sum" # for now, TODO - index = _broadcast(index, src, dim) - if out is None: - size = list(src.size()) - if dim_size is not None: - size[dim] = dim_size - elif index.numel() == 0: - size[dim] = 0 - else: - size[dim] = int(index.max()) + 1 - out = torch.zeros(size, dtype=src.dtype, device=src.device) - return out.scatter_add_(dim, index, src) - else: - return out.scatter_add_(dim, index, src) - - -def scatter_std( - src: torch.Tensor, - index: torch.Tensor, - dim: int = -1, - out: Optional[torch.Tensor] = None, - dim_size: Optional[int] = None, - unbiased: bool = True, -) -> torch.Tensor: - if out is not None: - dim_size = out.size(dim) - - if dim < 0: - dim = src.dim() + dim - - count_dim = dim - if index.dim() <= dim: - count_dim = index.dim() - 1 - - ones = torch.ones(index.size(), dtype=src.dtype, device=src.device) - count = scatter_sum(ones, index, count_dim, dim_size=dim_size) - - index = _broadcast(index, src, dim) - tmp = scatter_sum(src, index, dim, dim_size=dim_size) - count = _broadcast(count, tmp, dim).clamp(1) - mean = tmp.div(count) - - var = src - mean.gather(dim, index) - var = var * var - out = scatter_sum(var, index, dim, out, dim_size) - - if unbiased: - count = count.sub(1).clamp_(1) - out = out.div(count + 1e-6).sqrt() - - return out - - -def scatter_mean( - src: torch.Tensor, - index: torch.Tensor, - dim: int = -1, - out: Optional[torch.Tensor] = None, - dim_size: 
Optional[int] = None, -) -> torch.Tensor: - out = scatter_sum(src, index, dim, out, dim_size) - dim_size = out.size(dim) - - index_dim = dim - if index_dim < 0: - index_dim = index_dim + src.dim() - if index.dim() <= index_dim: - index_dim = index.dim() - 1 - - ones = torch.ones(index.size(), dtype=src.dtype, device=src.device) - count = scatter_sum(ones, index, index_dim, None, dim_size) - count[count < 1] = 1 - count = _broadcast(count, out, dim) - if out.is_floating_point(): - out.true_divide_(count) - else: - out.div_(count, rounding_mode="floor") - return out From b1e146d5c1e5759a595277bc466163efb40be970 Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Mon, 30 Sep 2024 14:15:02 -0400 Subject: [PATCH 37/51] Move around utils --- hydragnn/models/MACEStack.py | 2 +- .../{mace_utils/modules => }/irreps_tools.py | 16 +++++++ .../utils/model/mace_utils/modules/blocks.py | 2 +- .../utils/model/mace_utils/tools/compile.py | 44 +++++++++++++++++++ .../modules/utils.py => operations.py} | 33 +++----------- hydragnn/utils/model/utils.py | 20 +++++++++ 6 files changed, 87 insertions(+), 30 deletions(-) rename hydragnn/utils/model/{mace_utils/modules => }/irreps_tools.py (87%) rename hydragnn/utils/model/{mace_utils/modules/utils.py => operations.py} (58%) create mode 100644 hydragnn/utils/model/utils.py diff --git a/hydragnn/models/MACEStack.py b/hydragnn/models/MACEStack.py index 7c015714a..3aa9ea74f 100644 --- a/hydragnn/models/MACEStack.py +++ b/hydragnn/models/MACEStack.py @@ -48,7 +48,7 @@ RadialEmbeddingBlock, RealAgnosticAttResidualInteractionBlock, ) -from hydragnn.utils.model.mace_utils.modules.utils import ( +from hydragnn.utils.model.utils import ( get_edge_vectors_and_lengths, ) diff --git a/hydragnn/utils/model/mace_utils/modules/irreps_tools.py b/hydragnn/utils/model/irreps_tools.py similarity index 87% rename from hydragnn/utils/model/mace_utils/modules/irreps_tools.py rename to hydragnn/utils/model/irreps_tools.py index 642f3fa87..acb513114 100644 --- a/hydragnn/utils/model/mace_utils/modules/irreps_tools.py +++ b/hydragnn/utils/model/irreps_tools.py @@ -84,3 +84,19 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor: field = field.reshape(batch, mul, d) out.append(field) return torch.cat(out, dim=-1) + + +def extract_invariant(x: torch.Tensor, num_layers: int, num_features: int, l_max: int): + out = [] + for i in range(num_layers - 1): + out.append( + x[ + :, + i + * (l_max + 1) ** 2 + * num_features : (i * (l_max + 1) ** 2 + 1) + * num_features, + ] + ) + out.append(x[:, -num_features:]) + return torch.cat(out, dim=-1) \ No newline at end of file diff --git a/hydragnn/utils/model/mace_utils/modules/blocks.py b/hydragnn/utils/model/mace_utils/modules/blocks.py index 010051dda..e3d3b1da0 100644 --- a/hydragnn/utils/model/mace_utils/modules/blocks.py +++ b/hydragnn/utils/model/mace_utils/modules/blocks.py @@ -15,7 +15,7 @@ from hydragnn.utils.model.mace_utils.tools.compile import simplify_if_compile from torch_scatter import scatter -from .irreps_tools import ( +from hydragnn.utils.model.irreps_tools import ( linear_out_irreps, reshape_irreps, tp_out_irreps_with_instructions, diff --git a/hydragnn/utils/model/mace_utils/tools/compile.py b/hydragnn/utils/model/mace_utils/tools/compile.py index 1f5722687..2df0e4573 100644 --- a/hydragnn/utils/model/mace_utils/tools/compile.py +++ b/hydragnn/utils/model/mace_utils/tools/compile.py @@ -1,3 +1,9 @@ +########################################################################################### +# Compilation utilities for MACE +# 
Authors: Ilyes Batatia, Gregor Simm and David Kovacs +# This code is pulled from the MACE repository, which is distributed under the MIT License +########################################################################################### +from contextlib import contextmanager from functools import wraps from typing import Callable, Tuple @@ -5,6 +11,7 @@ import torch._dynamo as dynamo except ImportError: dynamo = None +from e3nn import get_optimization_defaults, set_optimization_defaults from torch import autograd, nn from torch.fx import symbolic_trace @@ -12,9 +19,44 @@ TypeTuple = Tuple[type, ...] +@contextmanager +def disable_e3nn_codegen(): + """Context manager that disables the legacy PyTorch code generation used in e3nn.""" + init_val = get_optimization_defaults()["jit_script_fx"] + set_optimization_defaults(jit_script_fx=False) + yield + set_optimization_defaults(jit_script_fx=init_val) + + +def prepare(func: ModuleFactory, allow_autograd: bool = True) -> ModuleFactory: + """Function transform that prepares a MACE module for torch.compile + Args: + func (ModuleFactory): A function that creates an nn.Module + allow_autograd (bool, optional): Force inductor compiler to inline call to + `torch.autograd.grad`. Defaults to True. + Returns: + ModuleFactory: Decorated function that creates a torch.compile compatible module + """ + if allow_autograd: + dynamo.allow_in_graph(autograd.grad) + elif dynamo.allowed_functions.is_allowed(autograd.grad): + dynamo.disallow_in_graph(autograd.grad) + + @wraps(func) + def wrapper(*args, **kwargs): + with disable_e3nn_codegen(): + model = func(*args, **kwargs) + + model = simplify(model) + return model + + return wrapper + + _SIMPLIFY_REGISTRY = set() + def simplify_if_compile(module: nn.Module) -> nn.Module: """Decorator to register a module for symbolic simplification @@ -33,6 +75,8 @@ def simplify_if_compile(module: nn.Module) -> nn.Module: return module +# This is a different type of simplify() function than provided in e3nn. In e3nn, simplify() +# combines irreps, while this one simplifies the module for symbolic tracing. def simplify(module: nn.Module) -> nn.Module: """Recursively searches for registered modules to simplify with `torch.fx.symbolic_trace` to support compiling with the PyTorch Dynamo compiler. 
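For reference, a minimal usage sketch of the prepare() helper restored above, assuming a simple stand-in factory rather than a real MACE module: prepare() wraps a module factory, builds the module with e3nn codegen disabled, and simplifies any registered submodules so the result can be handed to torch.compile.

import torch
from torch import nn
from hydragnn.utils.model.mace_utils.tools.compile import prepare

def make_block() -> nn.Module:
    # Stand-in factory; an e3nn-based MACE block would be wrapped the same way.
    return nn.Sequential(nn.Linear(8, 8), nn.SiLU())

block = prepare(make_block)()    # factory in, fx-friendly module out
compiled = torch.compile(block)  # ready for the Dynamo/inductor pipeline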
diff --git a/hydragnn/utils/model/mace_utils/modules/utils.py b/hydragnn/utils/model/operations.py similarity index 58% rename from hydragnn/utils/model/mace_utils/modules/utils.py rename to hydragnn/utils/model/operations.py index a390304cd..a222606d4 100644 --- a/hydragnn/utils/model/mace_utils/modules/utils.py +++ b/hydragnn/utils/model/operations.py @@ -1,19 +1,12 @@ -########################################################################################### -# Utilities -# Authors: Ilyes Batatia, Gregor Simm and David Kovacs -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - from typing import List, Tuple - import numpy as np import torch -import torch.nn -import torch.utils.data - -from .blocks import AtomicEnergiesBlock +########################################################################################### +# Function for the computation of edge vectors and lengths (MIT License (see MIT.md)) +# Authors: Ilyes Batatia, Gregor Simm and David Kovacs +########################################################################################### def get_edge_vectors_and_lengths( positions: torch.Tensor, # [n_nodes, 3] edge_index: torch.Tensor, # [2, n_edges] @@ -29,20 +22,4 @@ def get_edge_vectors_and_lengths( vectors_normed = vectors / (lengths + eps) return vectors_normed, lengths - return vectors, lengths - - -def extract_invariant(x: torch.Tensor, num_layers: int, num_features: int, l_max: int): - out = [] - for i in range(num_layers - 1): - out.append( - x[ - :, - i - * (l_max + 1) ** 2 - * num_features : (i * (l_max + 1) ** 2 + 1) - * num_features, - ] - ) - out.append(x[:, -num_features:]) - return torch.cat(out, dim=-1) + return vectors, lengths \ No newline at end of file diff --git a/hydragnn/utils/model/utils.py b/hydragnn/utils/model/utils.py new file mode 100644 index 000000000..7db557de9 --- /dev/null +++ b/hydragnn/utils/model/utils.py @@ -0,0 +1,20 @@ +########################################################################################### +# Utilities +# Authors: Ilyes Batatia, Gregor Simm and David Kovacs +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +from typing import List, Tuple + +import numpy as np +import torch +import torch.nn +import torch.utils.data + +from hydragnn.utils.model.mace_utils.modules.blocks import AtomicEnergiesBlock + + + + + + From 5ad74f39d04c8ca3ff922cfc28ad797d08335497 Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Mon, 30 Sep 2024 14:21:56 -0400 Subject: [PATCH 38/51] clean up imports and move files --- hydragnn/models/MACEStack.py | 4 ++-- .../model/mace_utils/modules/__init__.py | 2 -- .../utils/model/mace_utils/modules/blocks.py | 6 +++--- .../utils/model/mace_utils/modules/radial.py | 3 ++- hydragnn/utils/model/operations.py | 10 ++++++++++ hydragnn/utils/model/utils.py | 20 ------------------- 6 files changed, 17 insertions(+), 28 deletions(-) delete mode 100644 hydragnn/utils/model/utils.py diff --git a/hydragnn/models/MACEStack.py b/hydragnn/models/MACEStack.py index 3aa9ea74f..f52cfd63a 100644 --- a/hydragnn/models/MACEStack.py +++ b/hydragnn/models/MACEStack.py @@ -41,14 +41,14 @@ ) # This naming is because there is torch.nn.Sequential and torch_geometric.nn.Sequential from torch_geometric.nn import global_mean_pool -# Mace +# MACE from hydragnn.utils.model.mace_utils.modules.blocks import ( 
EquivariantProductBasisBlock, LinearNodeEmbeddingBlock, RadialEmbeddingBlock, RealAgnosticAttResidualInteractionBlock, ) -from hydragnn.utils.model.utils import ( +from hydragnn.utils.model.operations import ( get_edge_vectors_and_lengths, ) diff --git a/hydragnn/utils/model/mace_utils/modules/__init__.py b/hydragnn/utils/model/mace_utils/modules/__init__.py index f1b701ec2..8ffab253e 100644 --- a/hydragnn/utils/model/mace_utils/modules/__init__.py +++ b/hydragnn/utils/model/mace_utils/modules/__init__.py @@ -6,10 +6,8 @@ AtomicEnergiesBlock, EquivariantProductBasisBlock, InteractionBlock, - # LinearDipoleReadoutBlock, LinearNodeEmbeddingBlock, LinearReadoutBlock, - # NonLinearDipoleReadoutBlock, NonLinearReadoutBlock, RadialEmbeddingBlock, RealAgnosticAttResidualInteractionBlock, diff --git a/hydragnn/utils/model/mace_utils/modules/blocks.py b/hydragnn/utils/model/mace_utils/modules/blocks.py index e3d3b1da0..892da7b55 100644 --- a/hydragnn/utils/model/mace_utils/modules/blocks.py +++ b/hydragnn/utils/model/mace_utils/modules/blocks.py @@ -8,18 +8,18 @@ from typing import Callable, List, Optional, Tuple, Union import numpy as np +import torch import torch.nn.functional +from torch_scatter import scatter from e3nn import nn, o3 from e3nn.util.jit import compile_mode from hydragnn.utils.model.mace_utils.tools.compile import simplify_if_compile -from torch_scatter import scatter - from hydragnn.utils.model.irreps_tools import ( - linear_out_irreps, reshape_irreps, tp_out_irreps_with_instructions, ) + from .radial import ( AgnesiTransform, BesselBasis, diff --git a/hydragnn/utils/model/mace_utils/modules/radial.py b/hydragnn/utils/model/mace_utils/modules/radial.py index f9c703b79..6d64c80d1 100644 --- a/hydragnn/utils/model/mace_utils/modules/radial.py +++ b/hydragnn/utils/model/mace_utils/modules/radial.py @@ -5,12 +5,13 @@ ########################################################################################### import ase + import numpy as np import torch +from torch_scatter import scatter from e3nn.util.jit import compile_mode from hydragnn.utils.model.mace_utils.tools.compile import simplify_if_compile -from torch_scatter import scatter @compile_mode("script") diff --git a/hydragnn/utils/model/operations.py b/hydragnn/utils/model/operations.py index a222606d4..68afe653b 100644 --- a/hydragnn/utils/model/operations.py +++ b/hydragnn/utils/model/operations.py @@ -1,3 +1,13 @@ +############################################################################## +# Copyright (c) 2024, Oak Ridge National Laboratory # +# All rights reserved. # +# # +# This file is part of HydraGNN and is distributed under a BSD 3-clause # +# license. For the licensing terms see the LICENSE file in the top-level # +# directory. 
# +# # +# SPDX-License-Identifier: BSD-3-Clause # +############################################################################## from typing import List, Tuple import numpy as np import torch diff --git a/hydragnn/utils/model/utils.py b/hydragnn/utils/model/utils.py deleted file mode 100644 index 7db557de9..000000000 --- a/hydragnn/utils/model/utils.py +++ /dev/null @@ -1,20 +0,0 @@ -########################################################################################### -# Utilities -# Authors: Ilyes Batatia, Gregor Simm and David Kovacs -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -from typing import List, Tuple - -import numpy as np -import torch -import torch.nn -import torch.utils.data - -from hydragnn.utils.model.mace_utils.modules.blocks import AtomicEnergiesBlock - - - - - - From e6df82f58321cc13478a718fe0c850a6628ec0c5 Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Mon, 30 Sep 2024 14:22:37 -0400 Subject: [PATCH 39/51] formatting --- hydragnn/utils/model/irreps_tools.py | 2 +- hydragnn/utils/model/mace_utils/tools/cg.py | 4 ++-- hydragnn/utils/model/mace_utils/tools/compile.py | 1 - hydragnn/utils/model/model.py | 1 - hydragnn/utils/model/operations.py | 2 +- 5 files changed, 4 insertions(+), 6 deletions(-) diff --git a/hydragnn/utils/model/irreps_tools.py b/hydragnn/utils/model/irreps_tools.py index acb513114..71d74e6ff 100644 --- a/hydragnn/utils/model/irreps_tools.py +++ b/hydragnn/utils/model/irreps_tools.py @@ -99,4 +99,4 @@ def extract_invariant(x: torch.Tensor, num_layers: int, num_features: int, l_max ] ) out.append(x[:, -num_features:]) - return torch.cat(out, dim=-1) \ No newline at end of file + return torch.cat(out, dim=-1) diff --git a/hydragnn/utils/model/mace_utils/tools/cg.py b/hydragnn/utils/model/mace_utils/tools/cg.py index 6c1b94864..2cca09c94 100644 --- a/hydragnn/utils/model/mace_utils/tools/cg.py +++ b/hydragnn/utils/model/mace_utils/tools/cg.py @@ -52,9 +52,9 @@ def _wigner_nj( C = o3.wigner_3j(ir_out.l, ir_left.l, ir.l, dtype=dtype) if normalization == "component": - C *= ir_out.dim ** 0.5 + C *= ir_out.dim**0.5 if normalization == "norm": - C *= ir_left.dim ** 0.5 * ir.dim ** 0.5 + C *= ir_left.dim**0.5 * ir.dim**0.5 C = torch.einsum("jk,ijl->ikl", C_left.flatten(1), C) C = C.reshape( diff --git a/hydragnn/utils/model/mace_utils/tools/compile.py b/hydragnn/utils/model/mace_utils/tools/compile.py index 2df0e4573..2b7995d48 100644 --- a/hydragnn/utils/model/mace_utils/tools/compile.py +++ b/hydragnn/utils/model/mace_utils/tools/compile.py @@ -56,7 +56,6 @@ def wrapper(*args, **kwargs): _SIMPLIFY_REGISTRY = set() - def simplify_if_compile(module: nn.Module) -> nn.Module: """Decorator to register a module for symbolic simplification diff --git a/hydragnn/utils/model/model.py b/hydragnn/utils/model/model.py index 8dacafc7d..7e3251e08 100644 --- a/hydragnn/utils/model/model.py +++ b/hydragnn/utils/model/model.py @@ -281,7 +281,6 @@ def __init__( self.use_deepspeed = use_deepspeed def __call__(self, model, optimizer, perf_metric): - if (perf_metric > self.min_perf_metric + self.min_delta) or ( self.count < self.warmup ): diff --git a/hydragnn/utils/model/operations.py b/hydragnn/utils/model/operations.py index 68afe653b..5101f70ac 100644 --- a/hydragnn/utils/model/operations.py +++ b/hydragnn/utils/model/operations.py @@ -32,4 +32,4 @@ def get_edge_vectors_and_lengths( vectors_normed = vectors / (lengths + eps) return vectors_normed, lengths 
- return vectors, lengths \ No newline at end of file + return vectors, lengths From 32fe8e5f030c7d7649a3e78968bbb335fbd542b2 Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Mon, 30 Sep 2024 14:31:22 -0400 Subject: [PATCH 40/51] formatting --- hydragnn/utils/model/mace_utils/tools/cg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hydragnn/utils/model/mace_utils/tools/cg.py b/hydragnn/utils/model/mace_utils/tools/cg.py index 2cca09c94..6c1b94864 100644 --- a/hydragnn/utils/model/mace_utils/tools/cg.py +++ b/hydragnn/utils/model/mace_utils/tools/cg.py @@ -52,9 +52,9 @@ def _wigner_nj( C = o3.wigner_3j(ir_out.l, ir_left.l, ir.l, dtype=dtype) if normalization == "component": - C *= ir_out.dim**0.5 + C *= ir_out.dim ** 0.5 if normalization == "norm": - C *= ir_left.dim**0.5 * ir.dim**0.5 + C *= ir_left.dim ** 0.5 * ir.dim ** 0.5 C = torch.einsum("jk,ijl->ikl", C_left.flatten(1), C) C = C.reshape( From 17f283bc90589243339346f8b35cf77238a22c22 Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Mon, 30 Sep 2024 14:42:14 -0400 Subject: [PATCH 41/51] Add source information --- .../model/mace_utils/modules/__init__.py | 23 ++++++++++++++----- .../utils/model/mace_utils/modules/blocks.py | 5 ++++ .../utils/model/mace_utils/modules/radial.py | 5 ++++ .../modules/symmetric_contraction.py | 5 ++++ .../utils/model/mace_utils/tools/__init__.py | 11 +++++++++ hydragnn/utils/model/mace_utils/tools/cg.py | 5 ++++ .../utils/model/mace_utils/tools/compile.py | 6 +++++ requirements-torch.txt | 1 + 8 files changed, 55 insertions(+), 6 deletions(-) diff --git a/hydragnn/utils/model/mace_utils/modules/__init__.py b/hydragnn/utils/model/mace_utils/modules/__init__.py index 8ffab253e..d21d2b811 100644 --- a/hydragnn/utils/model/mace_utils/modules/__init__.py +++ b/hydragnn/utils/model/mace_utils/modules/__init__.py @@ -1,3 +1,14 @@ +########################################################################################### +# __init__ file for Modules +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### +# Taken From: +# GitHub: https://github.com/ACEsuit/mace +# ArXiV: https://arxiv.org/pdf/2206.07697 +# Date: August 27, 2024 | 12:37 (EST) +########################################################################################### + from typing import Callable, Dict, Optional, Type import torch @@ -21,12 +32,12 @@ "RealAgnosticAttResidualInteractionBlock": RealAgnosticAttResidualInteractionBlock, } -gate_dict: Dict[str, Optional[Callable]] = { - "abs": torch.abs, - "tanh": torch.tanh, - "silu": torch.nn.functional.silu, - "None": None, -} +# gate_dict: Dict[str, Optional[Callable]] = { +# "abs": torch.abs, +# "tanh": torch.tanh, +# "silu": torch.nn.functional.silu, +# "None": None, +# } __all__ = [ "AtomicEnergiesBlock", diff --git a/hydragnn/utils/model/mace_utils/modules/blocks.py b/hydragnn/utils/model/mace_utils/modules/blocks.py index 892da7b55..0e7223bf1 100644 --- a/hydragnn/utils/model/mace_utils/modules/blocks.py +++ b/hydragnn/utils/model/mace_utils/modules/blocks.py @@ -3,6 +3,11 @@ # Authors: Ilyes Batatia, Gregor Simm # This program is distributed under the MIT License (see MIT.md) ########################################################################################### +# Taken From: +# GitHub: https://github.com/ACEsuit/mace +# ArXiV: https://arxiv.org/pdf/2206.07697 +# Date: August 27, 2024 | 12:37 (EST) 
+########################################################################################### from abc import abstractmethod from typing import Callable, List, Optional, Tuple, Union diff --git a/hydragnn/utils/model/mace_utils/modules/radial.py b/hydragnn/utils/model/mace_utils/modules/radial.py index 6d64c80d1..cf5043a78 100644 --- a/hydragnn/utils/model/mace_utils/modules/radial.py +++ b/hydragnn/utils/model/mace_utils/modules/radial.py @@ -3,6 +3,11 @@ # Authors: Ilyes Batatia, Gregor Simm # This program is distributed under the MIT License (see MIT.md) ########################################################################################### +# Taken From: +# GitHub: https://github.com/ACEsuit/mace +# ArXiV: https://arxiv.org/pdf/2206.07697 +# Date: August 27, 2024 | 12:37 (EST) +########################################################################################### import ase diff --git a/hydragnn/utils/model/mace_utils/modules/symmetric_contraction.py b/hydragnn/utils/model/mace_utils/modules/symmetric_contraction.py index 8f2edd1c5..465d8fa9e 100644 --- a/hydragnn/utils/model/mace_utils/modules/symmetric_contraction.py +++ b/hydragnn/utils/model/mace_utils/modules/symmetric_contraction.py @@ -4,6 +4,11 @@ # Authors: Ilyes Batatia # This program is distributed under the MIT License (see MIT.md) ########################################################################################### +# Taken From: +# GitHub: https://github.com/ACEsuit/mace +# ArXiV: https://arxiv.org/pdf/2206.07697 +# Date: August 27, 2024 | 12:37 (EST) +########################################################################################### from typing import Dict, Optional, Union diff --git a/hydragnn/utils/model/mace_utils/tools/__init__.py b/hydragnn/utils/model/mace_utils/tools/__init__.py index cd5fb8634..26207ecbe 100644 --- a/hydragnn/utils/model/mace_utils/tools/__init__.py +++ b/hydragnn/utils/model/mace_utils/tools/__init__.py @@ -1,3 +1,14 @@ +########################################################################################### +# __init__ file for Tools +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### +# Taken From: +# GitHub: https://github.com/ACEsuit/mace +# ArXiV: https://arxiv.org/pdf/2206.07697 +# Date: August 27, 2024 | 12:37 (EST) +########################################################################################### + from .cg import U_matrix_real __all__ = [ diff --git a/hydragnn/utils/model/mace_utils/tools/cg.py b/hydragnn/utils/model/mace_utils/tools/cg.py index 6c1b94864..f0349998b 100644 --- a/hydragnn/utils/model/mace_utils/tools/cg.py +++ b/hydragnn/utils/model/mace_utils/tools/cg.py @@ -3,6 +3,11 @@ # Authors: Ilyes Batatia # This program is distributed under the MIT License (see MIT.md) ########################################################################################### +# Taken From: +# GitHub: https://github.com/ACEsuit/mace +# ArXiV: https://arxiv.org/pdf/2206.07697 +# Date: August 27, 2024 | 12:37 (EST) +########################################################################################### import collections from typing import List, Union diff --git a/hydragnn/utils/model/mace_utils/tools/compile.py b/hydragnn/utils/model/mace_utils/tools/compile.py index 2b7995d48..823d1511e 100644 --- a/hydragnn/utils/model/mace_utils/tools/compile.py +++ b/hydragnn/utils/model/mace_utils/tools/compile.py @@ -3,6 
+3,12 @@ # Authors: Ilyes Batatia, Gregor Simm and David Kovacs # This code is pulled from the MACE repository, which is distributed under the MIT License ########################################################################################### +# Taken From: +# GitHub: https://github.com/ACEsuit/mace +# ArXiV: https://arxiv.org/pdf/2206.07697 +# Date: August 27, 2024 | 12:37 (EST) +########################################################################################### + from contextlib import contextmanager from functools import wraps from typing import Callable, Tuple diff --git a/requirements-torch.txt b/requirements-torch.txt index b4b14066d..f1998597e 100644 --- a/requirements-torch.txt +++ b/requirements-torch.txt @@ -1,3 +1,4 @@ torch==2.0.1 torchvision torchaudio + \ No newline at end of file From 92ddccd6335420fc0bc4ad0552a144e351d91b0d Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Tue, 1 Oct 2024 11:50:52 -0400 Subject: [PATCH 42/51] Add checking and processing for node attributes --- hydragnn/models/MACEStack.py | 64 ++++++++++++++++++++++++++---------- 1 file changed, 46 insertions(+), 18 deletions(-) diff --git a/hydragnn/models/MACEStack.py b/hydragnn/models/MACEStack.py index f52cfd63a..c334c9a38 100644 --- a/hydragnn/models/MACEStack.py +++ b/hydragnn/models/MACEStack.py @@ -8,7 +8,6 @@ # # # SPDX-License-Identifier: BSD-3-Clause # ############################################################################## - # Adapted From: # GitHub: https://github.com/ACEsuit/mace # ArXiV: https://arxiv.org/pdf/2206.07697 @@ -22,12 +21,14 @@ # NOTE MACE Architecture: ## There are two key ideas of MACE: ### (1) Message passing and interaction blocks are equivariant to the O(3) group. And invariant to the T(3) group (translations). -### (2) Predictions are made in an n-body expansion, where n is the numnber of layers. This is done by creating multi-body -### interactions, then decoding them. Layer 1 will decode 1-body interactions, layer 2 will decode w-body interactions, -### and so on. So, for a 3-layer model predicting energy, there are 3 outputs for energy, one at each layer, and they -### are summed at the end. This requires some adjustment to the behavior from Base.py +### (2) Predictions are made in an n-body expansion, where n is the number of convolutional layers+1. This is done by creating +### multi-body interactions, then decoding them. The decoding before any interaction layer handles 1-body interactions; Interaction Layer 1 +### will decode 2-body interactions, layer 2 will decode 3-body interactions, and so on. So, for a 3-convolutional-layer model +### predicting energy, there are 4 outputs for energy: 1 before convolution + 3*(1 after each layer). These outputs are summed +### at the end. 
This requires some adjustment to the behavior from Base.py -from typing import Any, Callable, Dict, List, Optional, Type, Union +# from typing import Any, Callable, Dict, List, Optional, Type, Union +import warnings # Torch import torch @@ -64,8 +65,6 @@ import numpy as np import math -# pylint: disable=C0302 - @compile_mode("script") class MACEStack(Base): @@ -109,8 +108,8 @@ def __init__( ## Defined self.interaction_cls = RealAgnosticAttResidualInteractionBlock self.interaction_cls_first = RealAgnosticAttResidualInteractionBlock - self.num_elements = 118 # Number of elements in the periodic table - atomic_numbers = list(range(1, self.num_elements + 1)) + atomic_numbers = list(range(1, 119)) # 118 elements in the periodic table + self.num_elements = len(atomic_numbers) # Optional num_polynomial_cutoff = ( 5 if num_polynomial_cutoff is None else num_polynomial_cutoff @@ -119,9 +118,6 @@ def __init__( radial_type = "bessel" if radial_type is None else radial_type # Making Irreps - self.node_attr_irreps = o3.Irreps( - [(self.num_elements, (0, 1))] - ) # 118 is the number of elements in the periodic table self.sh_irreps = o3.Irreps.spherical_harmonics( max_ell ) # This makes the irreps string @@ -166,6 +162,9 @@ def _init_conv(self): ## This integrates HYDRA multihead nature with MACE's layer-wise readouts ## NOTE Norm techniques (feature_layers in HYDRA) are not advised for use in equivariant models as it can break equivariance self.multihead_decoders = ModuleList() + self.node_attr_irreps = o3.Irreps( + [(self.num_elements, (0, 1))] + ) # node_attr_irreps is created here because we need input_dim, which requires super(base) to be called, which calls _init_conv hidden_irreps = o3.Irreps( create_irreps_string(self.hidden_dim, self.node_max_ell) ) @@ -414,12 +413,9 @@ def _conv_args(self, data): data.pos = data.pos - mean_pos[data.batch] # Create node_attrs from atomic numbers. Later on it may contain more information - ## Node attrs are intrinsic properties of the atoms, like charge, atomic number, etc.. + ## Node attrs are intrinsic properties of the atoms. Currently, MACE only supports atomic number node attributes ## data.node_attrs is already used in another place, so has been renamed to data.node_attributes from MACE and same with other data variable names - one_hot = torch.nn.functional.one_hot( - data["x"].long().squeeze(-1), num_classes=118 - ).float() # [n_atoms, 118] ## 118 atoms in the peridoic table - data.node_attributes = one_hot # To-Do: Add more information to node_attrs + data.node_attributes = process_node_attributes(data["x"], self.num_elements) data.shifts = torch.zeros( (data.edge_index.shape[1], 3), dtype=data.pos.dtype, device=data.pos.device ) # Shifts takes into account pbc conditions, but I believe we already generate data.pos to take it into account @@ -468,6 +464,38 @@ def create_irreps_string( return " + ".join(irreps) +def process_node_attributes(node_attributes, num_elements): + # Check that node attributes are atomic numbers and process accordingly + node_attributes = node_attributes.squeeze() # Squeeze all unnecessary dimensions + assert ( + node_attributes.dim() == 1 + ), "MACE only supports raw atomic numbers as node_attributes. Your data.x \ + isn't a 1D tensor after squeezing, are you using vector features? 
" + + # Check that all elements are integers or integer-like (e.g., 1.0, 2.0), not floats like 1.1 + # This is only a warning so that we don't enforce requirements on the tests + if not torch.all(node_attributes == node_attributes.round()): + warnings.warn( + "MACE only supports raw atomic numbers as node_attributes. Your data.x " + "contains floats that do not align with atomic numbers." + ) + + # Check that all atomic numbers are within the valid range (1 to num_elements) + # This is only a warning so that we don't enforce requirements on the tests + if not torch.all((node_attributes >= 1) & (node_attributes <= num_elements)): + warnings.warn( + "MACE only supports raw atomic numbers as node_attributes. Your data.x \ + is not in the range 1-118, which doesn't align with atomic numbers." + ) + + # Perform one-hot encoding + one_hot = torch.nn.functional.one_hot( + node_attributes.long(), num_classes=num_elements + ).float() # [n_atoms, 118] + + return one_hot + + @compile_mode("script") class MultiheadDecoderBlock(torch.nn.Module): def __init__( From 173ae67766d7c2b3d3114f8e1665e1a8034d888e Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Tue, 1 Oct 2024 12:35:08 -0400 Subject: [PATCH 43/51] MACE natively oly handles atomic numbers as node_attributes. Add warnings, errors, and fixes if data.x doesn't match what's expected --- hydragnn/models/MACEStack.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/hydragnn/models/MACEStack.py b/hydragnn/models/MACEStack.py index c334c9a38..1b0da191f 100644 --- a/hydragnn/models/MACEStack.py +++ b/hydragnn/models/MACEStack.py @@ -470,27 +470,29 @@ def process_node_attributes(node_attributes, num_elements): assert ( node_attributes.dim() == 1 ), "MACE only supports raw atomic numbers as node_attributes. Your data.x \ - isn't a 1D tensor after squeezing, are you using vector features? " + isn't a 1D tensor after squeezing, are you using vector features?" # Check that all elements are integers or integer-like (e.g., 1.0, 2.0), not floats like 1.1 - # This is only a warning so that we don't enforce requirements on the tests + # This is only a warning so that we don't enforce this requirement on the tests. if not torch.all(node_attributes == node_attributes.round()): warnings.warn( - "MACE only supports raw atomic numbers as node_attributes. Your data.x " - "contains floats that do not align with atomic numbers." + "MACE only supports raw atomic numbers as node_attributes. Your data.x \ + contains floats, which does not align with atomic numbers." ) # Check that all atomic numbers are within the valid range (1 to num_elements) - # This is only a warning so that we don't enforce requirements on the tests + # This is only a warning so that we don't enforce this requirement on the tests. if not torch.all((node_attributes >= 1) & (node_attributes <= num_elements)): warnings.warn( "MACE only supports raw atomic numbers as node_attributes. Your data.x \ - is not in the range 1-118, which doesn't align with atomic numbers." + is not in the range 1-118, which does not align with atomic numbers." 
) + node_attributes = torch.clamp(node_attributes, min=1, max=118) # Perform one-hot encoding one_hot = torch.nn.functional.one_hot( - node_attributes.long(), num_classes=num_elements + (node_attributes - 1).long(), + num_classes=num_elements, # Subtract 1 to make atomic numbers 0-indexed for one-hot encoding ).float() # [n_atoms, 118] return one_hot From 6d53ff68ac94d038279a6cba2df1271670b23306 Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Wed, 2 Oct 2024 15:12:02 -0400 Subject: [PATCH 44/51] adjust requirements.txt installation and use hidden_dim for sizing more effectively in InteractionBlock --- .github/workflows/CI.yml | 3 +-- .../utils/model/mace_utils/modules/blocks.py | 18 ++++++++++++++---- requirements-torch.txt | 4 +++- requirements-torch2.txt | 3 --- 4 files changed, 18 insertions(+), 10 deletions(-) delete mode 100644 requirements-torch2.txt diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 18d7823fe..24d88444a 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -40,8 +40,7 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install --upgrade -r requirements.txt -r requirements-dev.txt - python -m pip install --upgrade -r requirements-torch.txt --index-url https://download.pytorch.org/whl/cpu - python -m pip install --upgrade -r requirements-torch2.txt + python -m pip install --upgrade -r requirements-torch.txt --index-url https://download.pytorch.org/whl/cpu --extra-index-url https://pypi.org/simple python -m pip install --upgrade -r requirements-pyg.txt --find-links https://data.pyg.org/whl/torch-2.0.1+cpu.html python -m pip install --upgrade -r requirements-deepspeed.txt - name: Format black diff --git a/hydragnn/utils/model/mace_utils/modules/blocks.py b/hydragnn/utils/model/mace_utils/modules/blocks.py index 0e7223bf1..27574828d 100644 --- a/hydragnn/utils/model/mace_utils/modules/blocks.py +++ b/hydragnn/utils/model/mace_utils/modules/blocks.py @@ -272,12 +272,19 @@ def __repr__(self): ) +########################################################################################### +# NOTE: Below is one of the many possible Interaction Blocks in the MACE architecture. +# Since there are adaptations to the original code in order to be integrated with +# the HydraGNN framework, and the changes between blocks are relatively minor, we've +# elected to adapt one general-purpose block here. Users can access the other blocks +# and adapt similarly from the original MACE code (linked with date at the top). +########################################################################################### @compile_mode("script") class RealAgnosticAttResidualInteractionBlock(InteractionBlock): def _setup(self) -> None: self.node_feats_down_irreps = o3.Irreps( - "64x0e" - ) # Interesting, this seems to be a required shaping + [(o3.Irreps(self.hidden_irreps).count(o3.Irrep(0, 1)), (0, 1))] + ) # First linear self.linear_up = o3.Linear( self.node_feats_irreps, @@ -313,9 +320,12 @@ def _setup(self) -> None: ) # The following specifies the network architecture for embedding l=0 (scalar) irreps ## It is worth double-checking, but I believe this means that type 0 (scalar) irreps - ## are being embedded by 3 layers of size 256 and the output dim, then activated. + # are being embedded by 3 layers of size self.hidden_dim (scalar irreps) and the + # output dim, then activated. 
         self.conv_tp_weights = nn.FullyConnectedNet(
-            [input_dim] + 3 * [256] + [self.conv_tp.weight_numel],
+            [input_dim]
+            + 3 * [o3.Irreps(self.hidden_irreps).count(o3.Irrep(0, 1))]
+            + [self.conv_tp.weight_numel],
             torch.nn.functional.silu,
         )
 
diff --git a/requirements-torch.txt b/requirements-torch.txt
index f1998597e..7bbb88c33 100644
--- a/requirements-torch.txt
+++ b/requirements-torch.txt
@@ -1,4 +1,6 @@
 torch==2.0.1
 torchvision
 torchaudio
-
\ No newline at end of file
+e3nn==0.5.1
+torch-ema==0.3
+torchmetrics==1.4.0
diff --git a/requirements-torch2.txt b/requirements-torch2.txt
deleted file mode 100644
index 9fb3c2dc1..000000000
--- a/requirements-torch2.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-e3nn
-torch-ema
-torchmetrics

From 6034c0c90beed1e2104c487cea4759a763154631 Mon Sep 17 00:00:00 2001
From: Rylie Weaver
Date: Thu, 3 Oct 2024 12:40:16 -0400
Subject: [PATCH 45/51] Add comments in compile file

---
 hydragnn/utils/model/mace_utils/tools/compile.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/hydragnn/utils/model/mace_utils/tools/compile.py b/hydragnn/utils/model/mace_utils/tools/compile.py
index 823d1511e..7179676c6 100644
--- a/hydragnn/utils/model/mace_utils/tools/compile.py
+++ b/hydragnn/utils/model/mace_utils/tools/compile.py
@@ -8,6 +8,12 @@
 # ArXiV: https://arxiv.org/pdf/2206.07697
 # Date: August 27, 2024 | 12:37 (EST)
 ###########################################################################################
+# NOTE: This file relates heavily to speedups in the MACE architecture, which can be
+# achieved by compiling correctly. As a result, much of this code can be commented out
+# and MACE will still function correctly. Specifically, only simplify_if_compile() and
+# its imports are required. However, for optimal performance, it is recommended to keep
+# the entire file as-is.
+###########################################################################################
 
 from contextlib import contextmanager
 from functools import wraps

From 8b9d52b8359248583e24a50eb6f31ee25b449117 Mon Sep 17 00:00:00 2001
From: Rylie Weaver
Date: Fri, 4 Oct 2024 16:33:41 -0400
Subject: [PATCH 46/51] tests for different radial transforms and exposing
 those options in MACE

---
 hydragnn/models/MACEStack.py                  |   5 +-
 hydragnn/models/create.py                     |   5 +-
 .../input_config_parsing/config_utils.py      |   4 +
 tests/inputs/ci.json                          |   1 +
 tests/test_radial_transforms.py               | 208 ++++++++++++++++++
 5 files changed, 220 insertions(+), 3 deletions(-)
 create mode 100755 tests/test_radial_transforms.py

diff --git a/hydragnn/models/MACEStack.py b/hydragnn/models/MACEStack.py
index 1b0da191f..533fa91c5 100644
--- a/hydragnn/models/MACEStack.py
+++ b/hydragnn/models/MACEStack.py
@@ -71,13 +71,14 @@ class MACEStack(Base):
     def __init__(
         self,
         r_max: float,  # The cutoff radius for the radial basis functions and edge_index
+        radial_type: str,  # The type of radial basis function to use
+        distance_transform: str,  # The distance transform to use
         num_bessel: int,  # The number of radial bessel functions. This dictates the richness of radial information in message-passing.
         max_ell: int,  # Max l-type for CG-tensor product. Theoretically, there is no max l-type, but in practice, we need to truncate the CG-tensor product to keep tractable computation
         node_max_ell: int,  # Max l-type for node features
         avg_num_neighbors: float,
         num_polynomial_cutoff,  # The polynomial cutoff function ensures that the function goes to zero at the cutoff radius smoothly.
Same as envelope_exponent for DimeNet correlation, # Used in the product basis block and *roughly* determines the richness of interaction in the n-body interaction of layer 'n'. - radial_type, # The type of radial basis function to use *args, **kwargs, ): @@ -148,7 +149,7 @@ def __init__( num_bessel=num_bessel, num_polynomial_cutoff=num_polynomial_cutoff, radial_type=radial_type, - distance_transform=None, + distance_transform=distance_transform, ) self.node_embedding = LinearNodeEmbeddingBlock( irreps_in=self.node_attr_irreps, diff --git a/hydragnn/models/create.py b/hydragnn/models/create.py index 8f6346ce7..e6c09bf20 100644 --- a/hydragnn/models/create.py +++ b/hydragnn/models/create.py @@ -56,6 +56,7 @@ def create_model_config( config["Architecture"]["num_after_skip"], config["Architecture"]["num_radial"], config["Architecture"]["radial_type"], + config["Architecture"]["distance_transform"], config["Architecture"]["basis_emb_size"], config["Architecture"]["int_emb_size"], config["Architecture"]["out_emb_size"], @@ -97,6 +98,7 @@ def create_model( num_after_skip: int = None, num_radial: int = None, radial_type: str = None, + distance_transform: str = None, basis_emb_size: int = None, int_emb_size: int = None, out_emb_size: int = None, @@ -349,13 +351,14 @@ def create_model( assert node_max_ell >= 1, "MACE requires node_max_ell >= 1." model = MACEStack( radius, + radial_type, + distance_transform, num_radial, max_ell, node_max_ell, avg_num_neighbors, envelope_exponent, correlation, - radial_type, input_dim, hidden_dim, output_dim, diff --git a/hydragnn/utils/input_config_parsing/config_utils.py b/hydragnn/utils/input_config_parsing/config_utils.py index daa3f68ec..a4245c938 100644 --- a/hydragnn/utils/input_config_parsing/config_utils.py +++ b/hydragnn/utils/input_config_parsing/config_utils.py @@ -68,6 +68,10 @@ def update_config(config, train_loader, val_loader, test_loader): if "radius" not in config["NeuralNetwork"]["Architecture"]: config["NeuralNetwork"]["Architecture"]["radius"] = None + if "radial_type" not in config["NeuralNetwork"]["Architecture"]: + config["NeuralNetwork"]["Architecture"]["radial_type"] = None + if "distance_transform" not in config["NeuralNetwork"]["Architecture"]: + config["NeuralNetwork"]["Architecture"]["distance_transform"] = None if "num_gaussians" not in config["NeuralNetwork"]["Architecture"]: config["NeuralNetwork"]["Architecture"]["num_gaussians"] = None if "num_filters" not in config["NeuralNetwork"]["Architecture"]: diff --git a/tests/inputs/ci.json b/tests/inputs/ci.json index ad05c10e4..03e0786be 100644 --- a/tests/inputs/ci.json +++ b/tests/inputs/ci.json @@ -28,6 +28,7 @@ "model_type": "PNA", "radius": 2.0, "max_neighbours": 100, + "radial_type": "bessel", "num_gaussians": 50, "envelope_exponent": 5, "int_emb_size": 64, diff --git a/tests/test_radial_transforms.py b/tests/test_radial_transforms.py new file mode 100755 index 000000000..9238288c1 --- /dev/null +++ b/tests/test_radial_transforms.py @@ -0,0 +1,208 @@ +############################################################################## +# Copyright (c) 2024, Oak Ridge National Laboratory # +# All rights reserved. # +# # +# This file is part of HydraGNN and is distributed under a BSD 3-clause # +# license. For the licensing terms see the LICENSE file in the top-level # +# directory. 
#
+#                                                                            #
+# SPDX-License-Identifier: BSD-3-Clause                                      #
+##############################################################################
+
+import sys, os, json
+import pytest
+
+import torch
+
+torch.manual_seed(97)
+import shutil
+
+import hydragnn, tests
+from hydragnn.utils.input_config_parsing.config_utils import merge_config
+
+
+# Main unit test function called by pytest wrappers.
+## Adapted from test_graphs.py ... Currently, only the single-head model json is tested, although the multihead functionality remains.
+def unittest_train_model(
+    model_type,
+    radial_type,
+    distance_transform,
+    ci_input,
+    use_lengths=True,
+    overwrite_data=False,
+    use_deepspeed=False,
+    overwrite_config=None,
+):
+    world_size, rank = hydragnn.utils.distributed.get_comm_size_and_rank()
+
+    os.environ["SERIALIZED_DATA_PATH"] = os.getcwd()
+
+    # Read in config settings and override model type.
+    config_file = os.path.join(os.getcwd(), "tests/inputs", ci_input)
+    with open(config_file, "r") as f:
+        config = json.load(f)
+    config["NeuralNetwork"]["Architecture"]["model_type"] = model_type
+    config["NeuralNetwork"]["Architecture"]["radial_type"] = radial_type
+    config["NeuralNetwork"]["Architecture"]["distance_transform"] = distance_transform
+
+    # Overwrite config settings if provided
+    if overwrite_config:
+        config = merge_config(config, overwrite_config)
+
+    """
+    to test this locally, set ci.json as
+    "Dataset": {
+       ...
+       "path": {
+             "train": "serialized_dataset/unit_test_singlehead_train.pkl",
+             "test": "serialized_dataset/unit_test_singlehead_test.pkl",
+             "validate": "serialized_dataset/unit_test_singlehead_validate.pkl"}
+       ...
+    """
+    # Use pkl files if they exist, by default
+    for dataset_name in config["Dataset"]["path"].keys():
+        if dataset_name == "total":
+            pkl_file = (
+                os.environ["SERIALIZED_DATA_PATH"]
+                + "/serialized_dataset/"
+                + config["Dataset"]["name"]
+                + ".pkl"
+            )
+        else:
+            pkl_file = (
+                os.environ["SERIALIZED_DATA_PATH"]
+                + "/serialized_dataset/"
+                + config["Dataset"]["name"]
+                + "_"
+                + dataset_name
+                + ".pkl"
+            )
+        if os.path.exists(pkl_file):
+            config["Dataset"]["path"][dataset_name] = pkl_file
+
+    # In the unit test runs, it is found that MFC favors graph-level features over node-level features, compared with other models;
+    # hence here we decrease the loss weight coefficient for the graph-level head in MFC.
+    if model_type == "MFC" and ci_input == "ci_multihead.json":
+        config["NeuralNetwork"]["Architecture"]["task_weights"][0] = 2
+
+    # Only run with edge lengths for models that support them.
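Since the radial options are plain strings in the Architecture block, each parametrized case reduces to a config override before training. A standalone sketch of that pattern (the values mirror the parametrization below; the snippet itself is illustrative and not part of the patch):

```python
import json

# Load the CI config and point MACE at a different radial basis and transform.
with open("tests/inputs/ci.json") as f:
    config = json.load(f)

arch = config["NeuralNetwork"]["Architecture"]
arch["model_type"] = "MACE"
arch["radial_type"] = "gaussian"       # one of "bessel", "gaussian", "chebyshev"
arch["distance_transform"] = "Soft"    # one of "None", "Agnesi", "Soft"
arch["edge_features"] = ["lengths"]    # MACE runs with edge lengths enabled
```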
+ if use_lengths: + config["NeuralNetwork"]["Architecture"]["edge_features"] = ["lengths"] + + if rank == 0: + num_samples_tot = 500 + # check if serialized pickle files or folders for raw files provided + pkl_input = False + if list(config["Dataset"]["path"].values())[0].endswith(".pkl"): + pkl_input = True + # only generate new datasets, if not pkl + if not pkl_input: + for dataset_name, data_path in config["Dataset"]["path"].items(): + if overwrite_data: + shutil.rmtree(data_path) + if not os.path.exists(data_path): + os.makedirs(data_path) + if dataset_name == "total": + num_samples = num_samples_tot + elif dataset_name == "train": + num_samples = int( + num_samples_tot + * config["NeuralNetwork"]["Training"]["perc_train"] + ) + elif dataset_name == "test": + num_samples = int( + num_samples_tot + * (1 - config["NeuralNetwork"]["Training"]["perc_train"]) + * 0.5 + ) + elif dataset_name == "validate": + num_samples = int( + num_samples_tot + * (1 - config["NeuralNetwork"]["Training"]["perc_train"]) + * 0.5 + ) + if not os.listdir(data_path): + tests.deterministic_graph_data( + data_path, number_configurations=num_samples + ) + + # Run Training + hydragnn.run_training(config, use_deepspeed) + + ( + error, + error_mse_task, + true_values, + predicted_values, + ) = hydragnn.run_prediction(config, use_deepspeed) + + # Set RMSE and sample MAE error thresholds + thresholds = { + "SAGE": [0.20, 0.20], + "PNA": [0.10, 0.10], + "PNAPlus": [0.10, 0.10], + "MFC": [0.20, 0.30], + "GIN": [0.25, 0.20], + "GAT": [0.60, 0.70], + "CGCNN": [0.175, 0.175], + "SchNet": [0.20, 0.20], + "DimeNet": [0.50, 0.50], + "EGNN": [0.20, 0.20], + "MACE": [0.60, 0.70], + } + + verbosity = 2 + + for ihead in range(len(true_values)): + error_head_mse = error_mse_task[ihead] + error_str = ( + str("{:.6f}".format(error_head_mse)) + + " < " + + str(thresholds[model_type][0]) + ) + hydragnn.utils.print.print_distributed(verbosity, "head: " + error_str) + assert ( + error_head_mse < thresholds[model_type][0] + ), "Head RMSE checking failed for " + str(ihead) + + head_true = true_values[ihead] + head_pred = predicted_values[ihead] + # Check individual samples + mae = torch.nn.L1Loss() + sample_mean_abs_error = mae(head_true, head_pred) + error_str = ( + "{:.6f}".format(sample_mean_abs_error) + + " < " + + str(thresholds[model_type][1]) + ) + assert ( + sample_mean_abs_error < thresholds[model_type][1] + ), "MAE sample checking failed!" + + # Check RMSE error + error_str = str("{:.6f}".format(error)) + " < " + str(thresholds[model_type][0]) + hydragnn.utils.print.print_distributed(verbosity, "total: " + error_str) + assert error < thresholds[model_type][0], "Total RMSE checking failed!" 
+ str(error)
+
+
+@pytest.mark.parametrize(
+    "model_type",
+    ["MACE"],
+)
+@pytest.mark.parametrize("basis_function", ["bessel", "gaussian", "chebyshev"])
+@pytest.mark.parametrize("distance_transform", ["None", "Agnesi", "Soft"])
+def pytest_train_model_transforms(
+    model_type,
+    basis_function,
+    distance_transform,
+    use_lengths=True,
+    overwrite_data=False,
+):
+    unittest_train_model(
+        model_type,
+        basis_function,
+        distance_transform,
+        "ci.json",
+        use_lengths,
+        overwrite_data,
+    )

From 81fbe1173ce9cd7e88c9478e9312da70ac9d3240 Mon Sep 17 00:00:00 2001
From: Rylie Weaver
Date: Tue, 8 Oct 2024 12:15:38 -0400
Subject: [PATCH 47/51] Fully reverse test changes

---
 examples/LennardJones/LJ.json                  |  8 +-
 .../utils/model/mace_utils/modules/radial.py  | 83 -------------------
 tests/inputs/ci.json                           |  2 +-
 tests/inputs/ci_equivariant.json               |  2 +-
 tests/inputs/ci_multihead.json                 |  2 +-
 tests/inputs/ci_vectoroutput.json              |  2 +-
 tests/test_graphs.py                           |  7 +-
 tests/test_model_loadpred.py                   |  4 -
 8 files changed, 10 insertions(+), 100 deletions(-)

diff --git a/examples/LennardJones/LJ.json b/examples/LennardJones/LJ.json
index 05b3c6cb6..942052003 100644
--- a/examples/LennardJones/LJ.json
+++ b/examples/LennardJones/LJ.json
@@ -34,8 +34,8 @@
         "node_max_ell": 1,
         "num_radial": 5,
         "num_spherical": 2,
-        "hidden_dim": 2,
-        "num_conv_layers": 2,
+        "hidden_dim": 20,
+        "num_conv_layers": 4,
         "output_heads": {
             "node": {
                 "num_headlayers": 2,
@@ -57,9 +57,9 @@
         "output_names": ["graph_energy"]
     },
     "Training": {
-        "num_epoch": 2,
+        "num_epoch": 15,
         "batch_size": 64,
-        "perc_train": 0.1,
+        "perc_train": 0.7,
         "patience": 20,
         "early_stopping": true,
         "Optimizer": {
diff --git a/hydragnn/utils/model/mace_utils/modules/radial.py b/hydragnn/utils/model/mace_utils/modules/radial.py
index cf5043a78..f53896c49 100644
--- a/hydragnn/utils/model/mace_utils/modules/radial.py
+++ b/hydragnn/utils/model/mace_utils/modules/radial.py
@@ -147,89 +147,6 @@ def __repr__(self):
         return f"{self.__class__.__name__}(p={self.p}, r_max={self.r_max})"
 
 
-@compile_mode("script")
-class ZBLBasis(torch.nn.Module):
-    """
-    Implementation of the Ziegler-Biersack-Littmark (ZBL) potential
-    """
-
-    p: torch.Tensor
-    r_max: torch.Tensor
-
-    def __init__(self, r_max: float, p=6, trainable=False):
-        super().__init__()
-        self.register_buffer(
-            "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype())
-        )
-        # Pre-calculate the p coefficients for the ZBL potential
-        self.register_buffer(
-            "c",
-            torch.tensor(
-                [0.1818, 0.5099, 0.2802, 0.02817], dtype=torch.get_default_dtype()
-            ),
-        )
-        self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype()))
-        self.register_buffer(
-            "covalent_radii",
-            torch.tensor(
-                ase.data.covalent_radii,
-                dtype=torch.get_default_dtype(),
-            ),
-        )
-        self.cutoff = PolynomialCutoff(r_max, p)
-        if trainable:
-            self.a_exp = torch.nn.Parameter(torch.tensor(0.300, requires_grad=True))
-            self.a_prefactor = torch.nn.Parameter(
-                torch.tensor(0.4543, requires_grad=True)
-            )
-        else:
-            self.register_buffer("a_exp", torch.tensor(0.300))
-            self.register_buffer("a_prefactor", torch.tensor(0.4543))
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        node_attrs: torch.Tensor,
-        edge_index: torch.Tensor,
-        atomic_numbers: torch.Tensor,
-    ) -> torch.Tensor:
-        sender = edge_index[0]
-        receiver = edge_index[1]
-        node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze(
-            -1
-        )
-        Z_u = node_atomic_numbers[sender]
-        Z_v = node_atomic_numbers[receiver]
-        a = (
-            self.a_prefactor
-            * 0.529
-            / (torch.pow(Z_u,
self.a_exp) + torch.pow(Z_v, self.a_exp)) - ) - r_over_a = x / a - phi = ( - self.c[0] * torch.exp(-3.2 * r_over_a) - + self.c[1] * torch.exp(-0.9423 * r_over_a) - + self.c[2] * torch.exp(-0.4028 * r_over_a) - + self.c[3] * torch.exp(-0.2016 * r_over_a) - ) - v_edges = (14.3996 * Z_u * Z_v) / x * phi - r_max = self.covalent_radii[Z_u] + self.covalent_radii[Z_v] - envelope = ( - 1.0 - - ((self.p + 1.0) * (self.p + 2.0) / 2.0) * torch.pow(x / r_max, self.p) - + self.p * (self.p + 2.0) * torch.pow(x / r_max, self.p + 1) - - (self.p * (self.p + 1.0) / 2) * torch.pow(x / r_max, self.p + 2) - ) * (x < r_max) - v_edges = 0.5 * v_edges * envelope - V_ZBL = scatter( - v_edges, receiver, dim=0, dim_size=node_attrs.size(0), reduce="sum" - ) - return V_ZBL.squeeze(-1) - - def __repr__(self): - return f"{self.__class__.__name__}(r_max={self.r_max}, c={self.c})" - - @compile_mode("script") class AgnesiTransform(torch.nn.Module): """ diff --git a/tests/inputs/ci.json b/tests/inputs/ci.json index 03e0786be..14ebd43af 100644 --- a/tests/inputs/ci.json +++ b/tests/inputs/ci.json @@ -67,7 +67,7 @@ "denormalize_output": false }, "Training": { - "num_epoch": 10, + "num_epoch": 100, "perc_train": 0.7, "EarlyStopping": true, "patience": 10, diff --git a/tests/inputs/ci_equivariant.json b/tests/inputs/ci_equivariant.json index 92f02e974..51175e9e7 100644 --- a/tests/inputs/ci_equivariant.json +++ b/tests/inputs/ci_equivariant.json @@ -67,7 +67,7 @@ "denormalize_output": false }, "Training": { - "num_epoch": 10, + "num_epoch": 100, "perc_train": 0.7, "EarlyStopping": true, "patience": 10, diff --git a/tests/inputs/ci_multihead.json b/tests/inputs/ci_multihead.json index 5922241dd..b2c752bb9 100644 --- a/tests/inputs/ci_multihead.json +++ b/tests/inputs/ci_multihead.json @@ -64,7 +64,7 @@ "denormalize_output": false }, "Training": { - "num_epoch": 10, + "num_epoch": 80, "Checkpoint": true, "checkpoint_warmup": 10, "perc_train": 0.7, diff --git a/tests/inputs/ci_vectoroutput.json b/tests/inputs/ci_vectoroutput.json index 4c2980be0..ddb616615 100644 --- a/tests/inputs/ci_vectoroutput.json +++ b/tests/inputs/ci_vectoroutput.json @@ -56,7 +56,7 @@ "denormalize_output": false }, "Training": { - "num_epoch": 10, + "num_epoch": 80, "Checkpoint": true, "checkpoint_warmup": 10, "perc_train": 0.7, diff --git a/tests/test_graphs.py b/tests/test_graphs.py index 5cd9b2d77..2149220cc 100755 --- a/tests/test_graphs.py +++ b/tests/test_graphs.py @@ -219,9 +219,7 @@ def pytest_train_model(model_type, ci_input, overwrite_data=False): # Test only models @pytest.mark.parametrize( - # "model_type", ["PNA", "PNAPlus", "CGCNN", "SchNet", "EGNN", "MACE"] - "model_type", - ["MACE"], + "model_type", ["PNA", "PNAPlus", "CGCNN", "SchNet", "EGNN", "MACE"] ) def pytest_train_model_lengths(model_type, overwrite_data=False): unittest_train_model(model_type, "ci.json", True, overwrite_data) @@ -234,8 +232,7 @@ def pytest_train_equivariant_model(model_type, overwrite_data=False): # Test vector output -# @pytest.mark.parametrize("model_type", ["PNA", "PNAPlus", "MACE"]) -@pytest.mark.parametrize("model_type", ["MACE"]) +@pytest.mark.parametrize("model_type", ["PNA", "PNAPlus", "MACE"]) def pytest_train_model_vectoroutput(model_type, overwrite_data=False): unittest_train_model(model_type, "ci_vectoroutput.json", True, overwrite_data) diff --git a/tests/test_model_loadpred.py b/tests/test_model_loadpred.py index a3e3438e3..8b3617959 100755 --- a/tests/test_model_loadpred.py +++ b/tests/test_model_loadpred.py @@ -96,7 +96,3 @@ def 
pytest_model_loadpred(): False, ) unittest_model_prediction(config) - - -# if __name__ == "__main__": -# pytest_model_loadpred() From 6c21612e8d25cae2f8e33fe87fd9cda41ed486c3 Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Tue, 8 Oct 2024 12:17:20 -0400 Subject: [PATCH 48/51] Missed reversed change --- tests/inputs/ci_multihead.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/inputs/ci_multihead.json b/tests/inputs/ci_multihead.json index b2c752bb9..f408c4aa8 100644 --- a/tests/inputs/ci_multihead.json +++ b/tests/inputs/ci_multihead.json @@ -64,7 +64,7 @@ "denormalize_output": false }, "Training": { - "num_epoch": 80, + "num_epoch": 100, "Checkpoint": true, "checkpoint_warmup": 10, "perc_train": 0.7, From 0feb249a1875c2240b0a715db1e3876e764a377d Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Tue, 8 Oct 2024 12:53:59 -0400 Subject: [PATCH 49/51] fix errors --- hydragnn/utils/model/mace_utils/modules/__init__.py | 10 +--------- tests/test_graphs.py | 5 ----- 2 files changed, 1 insertion(+), 14 deletions(-) diff --git a/hydragnn/utils/model/mace_utils/modules/__init__.py b/hydragnn/utils/model/mace_utils/modules/__init__.py index d21d2b811..3f6c5d027 100644 --- a/hydragnn/utils/model/mace_utils/modules/__init__.py +++ b/hydragnn/utils/model/mace_utils/modules/__init__.py @@ -25,24 +25,16 @@ ScaleShiftBlock, ) -from .radial import BesselBasis, GaussianBasis, PolynomialCutoff, ZBLBasis +from .radial import BesselBasis, GaussianBasis, PolynomialCutoff from .symmetric_contraction import SymmetricContraction interaction_classes: Dict[str, Type[InteractionBlock]] = { "RealAgnosticAttResidualInteractionBlock": RealAgnosticAttResidualInteractionBlock, } -# gate_dict: Dict[str, Optional[Callable]] = { -# "abs": torch.abs, -# "tanh": torch.tanh, -# "silu": torch.nn.functional.silu, -# "None": None, -# } - __all__ = [ "AtomicEnergiesBlock", "RadialEmbeddingBlock", - "ZBLBasis", "LinearNodeEmbeddingBlock", "LinearReadoutBlock", "EquivariantProductBasisBlock", diff --git a/tests/test_graphs.py b/tests/test_graphs.py index 2149220cc..2aec7fc2d 100755 --- a/tests/test_graphs.py +++ b/tests/test_graphs.py @@ -250,12 +250,7 @@ def pytest_train_model_vectoroutput(model_type, overwrite_data=False): "DimeNet", "EGNN", "PNAEq", - "MACE", ], ) def pytest_train_model_conv_head(model_type, overwrite_data=False): unittest_train_model(model_type, "ci_conv_head.json", False, overwrite_data) - - -def train_model_conv_head(model_type, overwrite_data=False): - unittest_train_model(model_type, "ci_conv_head.json", False, overwrite_data) From e910ef65346a005cbc7c990e277c5b798995bee6 Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Tue, 8 Oct 2024 16:49:14 -0400 Subject: [PATCH 50/51] Fix edge attr usage --- hydragnn/models/MACEStack.py | 32 +++++++++++++++++++------------- hydragnn/models/create.py | 1 + 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/hydragnn/models/MACEStack.py b/hydragnn/models/MACEStack.py index 533fa91c5..d61696a6f 100644 --- a/hydragnn/models/MACEStack.py +++ b/hydragnn/models/MACEStack.py @@ -74,6 +74,7 @@ def __init__( radial_type: str, # The type of radial basis function to use distance_transform: str, # The distance transform to use num_bessel: int, # The number of radial bessel functions. This dictates the richness of radial information in message-passing. + edge_dim: int, # The dimension of HYDRA's optional edge attributes max_ell: int, # Max l-type for CG-tensor product. 
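To see concretely why the CG-tensor product must be truncated (the signature comment continues below), couple two l=1 features in e3nn; this small illustration is independent of the patch:

```python
from e3nn import o3

# The full Clebsch-Gordan product of two l=1 (odd parity) inputs
# contains l = 0, 1, 2 components, all of even parity.
tp = o3.FullTensorProduct(o3.Irreps("1x1o"), o3.Irreps("1x1o"))
print(tp.irreps_out)  # 1x0e+1x1e+1x2e

# max_ell caps the spherical-harmonic basis kept for the edge attributes,
# e.g. for max_ell = 2:
print(o3.Irreps.spherical_harmonics(2))  # 1x0e+1x1o+1x2e
```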
Theoretically, there is no max l-type, but in practice, we need to truncate the CG-tensor product to keep tractable computation
         node_max_ell: int,  # Max l-type for node features
         avg_num_neighbors: float,
@@ -105,6 +106,7 @@ def __init__(
         ## Passed
         self.node_max_ell = node_max_ell
         num_interactions = kwargs["num_conv_layers"]
+        self.edge_dim = edge_dim
         self.avg_num_neighbors = avg_num_neighbors
         ## Defined
         self.interaction_cls = RealAgnosticAttResidualInteractionBlock
@@ -163,9 +165,15 @@ def _init_conv(self):
         ## This integrates HYDRA multihead nature with MACE's layer-wise readouts
         ## NOTE Norm techniques (feature_layers in HYDRA) are not advised for use in equivariant models as it can break equivariance
         self.multihead_decoders = ModuleList()
-        self.node_attr_irreps = o3.Irreps(
-            [(self.num_elements, (0, 1))]
-        )  # node_attr_irreps is created here because we need input_dim, which requires super(base) to be called, which calls _init_conv
+        # attr_irreps for node and edges are created here because we need input_dim, which requires super(base) to be called, which calls _init_conv
+        self.node_attr_irreps = o3.Irreps([(self.num_elements, (0, 1))])
+        # Edge attributes are by default the spherical harmonics, but can be extended to include HYDRA's edge_attr if desired
+        if self.use_edge_attr:
+            self.edge_attrs_irreps = (
+                o3.Irreps(f"{self.edge_dim}x0e") + self.sh_irreps
+            ).simplify()  # Simplify combines irreps of the same type
+        else:
+            self.edge_attrs_irreps = self.sh_irreps
         hidden_irreps = o3.Irreps(
             create_irreps_string(self.hidden_dim, self.node_max_ell)
         )
@@ -245,7 +253,9 @@ def get_conv(self, input_dim, output_dim, first_layer=False, last_layer=False):
             o3.Irrep(0, 1)
         )  # Multiple copies of spherical harmonics for multiple interactions. They are 'combined' in a certain way during .simplify() ## This makes it a requirement that hidden irreps all have the same number of channels
         interaction_irreps = (
-            (self.sh_irreps * num_features).sort()[0].simplify()
+            (self.sh_irreps * num_features)
+            .sort()[0]
+            .simplify()  # Kept as sh_irreps for the output of reshape irreps, whether or not edge_attr irreps are added from HYDRA functionality
         )  # .sort() is a tuple, so we need the [0] element for the sorted result
         ### Output
         output_irreps = create_irreps_string(output_dim, self.node_max_ell)
@@ -257,7 +267,7 @@ def get_conv(self, input_dim, output_dim, first_layer=False, last_layer=False):
             inter = self.interaction_cls_first(
                 node_attrs_irreps=self.node_attr_irreps,
                 node_feats_irreps=node_feats_irreps,
-                edge_attrs_irreps=self.sh_irreps,
+                edge_attrs_irreps=self.edge_attrs_irreps,
                 edge_feats_irreps=self.edge_feats_irreps,
                 target_irreps=interaction_irreps,  # Replace with output?
hidden_irreps=hidden_irreps_out, @@ -285,7 +295,7 @@ def get_conv(self, input_dim, output_dim, first_layer=False, last_layer=False): inter = self.interaction_cls( node_attrs_irreps=self.node_attr_irreps, node_feats_irreps=hidden_irreps, - edge_attrs_irreps=self.sh_irreps, + edge_attrs_irreps=self.edge_attrs_irreps, edge_feats_irreps=self.edge_feats_irreps, target_irreps=interaction_irreps, hidden_irreps=hidden_irreps_out, @@ -307,7 +317,7 @@ def get_conv(self, input_dim, output_dim, first_layer=False, last_layer=False): inter = self.interaction_cls( node_attrs_irreps=self.node_attr_irreps, node_feats_irreps=hidden_irreps, - edge_attrs_irreps=self.sh_irreps, + edge_attrs_irreps=self.edge_attrs_irreps, edge_feats_irreps=self.edge_feats_irreps, target_irreps=interaction_irreps, hidden_irreps=hidden_irreps_out, @@ -326,13 +336,8 @@ def get_conv(self, input_dim, output_dim, first_layer=False, last_layer=False): ) # Change sizing to output_irreps input_args = "node_attributes, pos, node_features, edge_attributes, edge_features, edge_index" - # readout_args = "node_energies" conv_args = "node_attributes, edge_attributes, edge_features, edge_index" # node_features is not used here because it's passed through in the forward - if self.use_edge_attr: - input_args += ", edge_attr" - conv_args += ", edge_attr" - if not last_layer: return PyGSequential( input_args, @@ -429,6 +434,8 @@ def _conv_args(self, data): shifts=data["shifts"], ) edge_attributes = self.spherical_harmonics(vectors) + if self.use_edge_attr: + edge_attributes = torch.cat([data.edge_attr, edge_attributes], dim=1) edge_features = self.radial_embedding( lengths, data["node_attributes"], data["edge_index"], self.atomic_numbers ) @@ -437,7 +444,6 @@ def _conv_args(self, data): data.node_features = node_feats data.edge_attributes = edge_attributes data.edge_features = edge_features - data.lengths = lengths conv_args = { "node_attributes": data.node_attributes, diff --git a/hydragnn/models/create.py b/hydragnn/models/create.py index 218783561..b28ebb83d 100644 --- a/hydragnn/models/create.py +++ b/hydragnn/models/create.py @@ -378,6 +378,7 @@ def create_model( radial_type, distance_transform, num_radial, + edge_dim, max_ell, node_max_ell, avg_num_neighbors, From a120e5e0eebff31dcb3c9da1170c558b8cce944f Mon Sep 17 00:00:00 2001 From: Rylie Weaver Date: Tue, 8 Oct 2024 18:38:09 -0400 Subject: [PATCH 51/51] fix bug from merge resolve --- hydragnn/models/create.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/hydragnn/models/create.py b/hydragnn/models/create.py index c99fb5b2e..08d21b77e 100644 --- a/hydragnn/models/create.py +++ b/hydragnn/models/create.py @@ -109,6 +109,10 @@ def create_model( num_filters: int = None, radius: float = None, equivariance: bool = False, + correlation: Union[int, List[int]] = None, + max_ell: int = None, + node_max_ell: int = None, + avg_num_neighbors: int = None, conv_checkpointing: bool = False, verbosity: int = 0, use_gpu: bool = True,
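To make the edge-attribute change above concrete: when use_edge_attr is set, HYDRA's scalar edge features are concatenated in front of the spherical-harmonic expansion of the edge vectors, and edge_attrs_irreps must describe exactly that concatenation. A minimal sketch with made-up sizes (edge_dim, lmax, and the random tensors are illustrative):

```python
import torch
from e3nn import o3

edge_dim, lmax, num_edges = 4, 3, 10                # illustrative values
sh_irreps = o3.Irreps.spherical_harmonics(lmax)     # 1x0e+1x1o+1x2e+1x3o, dim 16

# Scalar (0e) edge features first, then the spherical harmonics, as in _conv_args.
edge_attrs_irreps = (o3.Irreps(f"{edge_dim}x0e") + sh_irreps).simplify()

edge_attr = torch.randn(num_edges, edge_dim)        # stands in for data.edge_attr
sh = torch.randn(num_edges, sh_irreps.dim)          # stands in for spherical_harmonics(vectors)
edge_attributes = torch.cat([edge_attr, sh], dim=1)

assert edge_attributes.shape[1] == edge_attrs_irreps.dim  # 4 + 16 = 20
```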