Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cherry picked changes from other PRs #291

Merged
merged 6 commits into from
Oct 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/getting_started.rst
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ Here is an example of a training routine definition:
interval = "epoch" # Interval for learning rate updates (per epoch)

[training.loss_parameter]
loss_property = ['per_system_energy', 'per_atom_force'] # Properties to include in the loss function
loss_components = ['per_system_energy', 'per_atom_force'] # Properties to include in the loss function

[training.loss_parameter.weight]
per_system_energy = 0.999 # Weight for per molecule energy in the loss calculation
Expand Down
4 changes: 2 additions & 2 deletions docs/training.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ Loss function
^^^^^^^^^^^^^^^^^^^^^^^^
The loss function quantifies the discrepancy between the model's predictions and the target properties, providing a scalar value that guides the optimizer in updating the model's parameters. This function is configured in the `[training.loss_parameter]` section of the training TOML file.

Depending on the specified `loss_property`` section, the loss function can combine various individual loss functions. *Modelforge* always includes the mean squared error (MSE) for energy prediction, and may also incorporate MSE for force prediction, dipole moment prediction, and partial charge prediction.
Depending on the specified `loss_components` entry, the loss function can combine various individual loss functions. *Modelforge* always includes the mean squared error (MSE) for energy prediction, and may also incorporate MSE for force prediction, dipole moment prediction, and partial charge prediction.

The design of the loss function is intrinsically linked to the structure of the energy function. For instance, if the energy function aggregates atomic energies, then loss_property should include `per_system_energy` and optionally, `per_atom_force`.
The design of the loss function is intrinsically linked to the structure of the energy function. For instance, if the energy function aggregates atomic energies, then `loss_components` should include `per_system_energy` and, optionally, `per_atom_force`.


Predicting Short-Range Atomic Energies
Expand Down
4 changes: 4 additions & 0 deletions modelforge/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,10 @@ def __init__(
)
self.lock_file = f"{self.cache_processed_dataset_filename}.lockfile"

def transfer_batch_to_device(self, batch, device, dataloader_idx):
# Move the batch to the target device by delegating to the batch's own
# `to_device` method; `dataloader_idx` is accepted but not used here.
return batch.to_device(device)

@lock_with_attribute("lock_file")
def prepare_data(
self,
Expand Down
5 changes: 1 addition & 4 deletions modelforge/potential/ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@
using a neural network model.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Dict, Tuple

from typing import Dict, Tuple, List
import torch
from loguru import logger as log
from torch import nn
Expand Down
4 changes: 1 addition & 3 deletions modelforge/potential/painn.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,10 +298,8 @@ def forward(

# featurize pairwise distances using radial basis functions (RBF)
f_ij = self.radial_symmetry_function_module(d_ij)
f_ij_cut = self.cutoff_module(d_ij)
# Apply the filter network and cutoff function
filters = torch.mul(self.filter_net(f_ij), f_ij_cut)

filters = torch.mul(self.filter_net(f_ij), self.cutoff_module(d_ij))
# depending on whether we share filters or not filters have different
# shape at dim=1 (dim=0 is always the number of atom pairs) if we share
# filters, we copy the filters and use the same filters for all blocks
Expand Down
12 changes: 9 additions & 3 deletions modelforge/potential/physnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,14 +220,20 @@ def forward(self, data: Dict[str, torch.Tensor]) -> torch.Tensor:
# first term in equation 6 in the PhysNet paper
embedding_atom_i = self.activation_function(
self.interaction_i(data["atomic_embedding"])
)
) # shape (nr_of_atoms_in_batch, atomic_embedding_dim)

# second term in equation 6 in the PhysNet paper
# apply attention mask G to radial basis functions f_ij
g = self.attention_mask(data["f_ij"])
g = self.attention_mask(
data["f_ij"]
) # shape (nr_of_atom_pairs_in_batch, atomic_embedding_dim)
# calculate the updated embedding for atom j
# NOTE: this changes the 2nd dimension from number_of_radial_basis_functions to atomic_embedding_dim
embedding_atom_j = self.activation_function(
self.interaction_j(data["atomic_embedding"][idx_j])
self.interaction_j(data["atomic_embedding"])[
idx_j
] # NOTE this is the same as the embedding_atom_i, but then we are selecting the embedding of atom j
# shape (nr_of_atom_pairs_in_batch, atomic_embedding_dim)
)
updated_embedding_atom_j = torch.mul(
g, embedding_atom_j
Expand Down
9 changes: 6 additions & 3 deletions modelforge/potential/potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,6 @@ def __init__(

super().__init__()

self.eval()
self.core_network = torch.jit.script(core_network) if jit else core_network
self.neighborlist = (
torch.jit.script(neighborlist) if jit_neighborlist else neighborlist
Expand Down Expand Up @@ -430,7 +429,6 @@ def load_state_dict(
strict=strict,
assign=assign,
)
self.eval() # Set the model to evaluation mode


def setup_potential(
Expand Down Expand Up @@ -507,7 +505,6 @@ def setup_potential(
jit=jit,
jit_neighborlist=False if use_training_mode_neighborlist else True,
)
model.eval()
return model


Expand Down Expand Up @@ -611,6 +608,12 @@ def generate_potential(
neighborlist_strategy=inference_neighborlist_strategy,
verlet_neighborlist_skin=verlet_neighborlist_skin,
)
# Disable gradients for model parameters
for param in potential.parameters():
param.requires_grad = False
# Set model to eval
potential.eval()

if simulation_environment == "JAX":
# register nnp_input as pytree
from modelforge.utils.io import import_
Expand Down
79 changes: 57 additions & 22 deletions modelforge/potential/representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,37 +151,72 @@ def forward(self, r_ij: torch.Tensor) -> torch.Tensor:
return sub_aev

def compute_angular_sub_aev(self, vectors12: torch.Tensor) -> torch.Tensor:
"""Compute the angular subAEV terms of the center atom given neighbor pairs.
"""
Compute the angular subAEV terms of the center atom given neighbor
pairs.

This corresponds to equation (4) in the ANI paper. This function just
computes the terms. The sum in the equation is not computed.
The input tensor have shape (conformations, atoms, N), where N
is the number of neighbor atom pairs within the cutoff radius and
output tensor should have shape
(conformations, atoms, ``self.angular_sublength()``)

Parameters
----------
vectors12: torch.Tensor
Pairwise distance vectors. Shape: [2, n_pairs, 3]
Returns
-------
torch.Tensor
Angular subAEV terms. Shape: [n_pairs, ShfZ_size * ShfA_size]

"""
vectors12 = vectors12.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
distances12 = vectors12.norm(2, dim=-5)
# vectors12 has shape: (2, n_pairs, 3)
distances12 = vectors12.norm(p=2, dim=-1) # Shape: (2, n_pairs)
distances_sum = distances12.sum(dim=0) / 2 # Shape: (n_pairs,)
fcj12 = self.cosine_cutoff(distances12) # Shape: (2, n_pairs)
fcj12_prod = fcj12.prod(dim=0) # Shape: (n_pairs,)

# cos_angles: (n_pairs,)

# 0.95 is multiplied to the cos values to prevent acos from
# returning NaN.
cos_angles = 0.95 * torch.nn.functional.cosine_similarity(
vectors12[0], vectors12[1], dim=-5
vectors12[0], vectors12[1], dim=-1
)
angles = torch.acos(cos_angles)
fcj12 = self.cosine_cutoff(distances12)
factor1 = ((1 + torch.cos(angles - self.ShfZ)) / 2) ** self.Zeta

angles = torch.acos(cos_angles) # Shape: (n_pairs,)

# Prepare shifts for broadcasting
angles = angles.unsqueeze(-1) # Shape: (n_pairs, 1)
distances_sum = distances_sum.unsqueeze(-1) # Shape: (n_pairs, 1)

# Compute factor1
delta_angles = angles - self.ShfZ.view(1, -1) # Shape: (n_pairs, ShfZ_size)
factor1 = (
(1 + torch.cos(delta_angles)) / 2
) ** self.Zeta # Shape: (n_pairs, ShfZ_size)

# Compute factor2
delta_distances = distances_sum - self.ShfA.view(
1, -1
) # Shape: (n_pairs, ShfA_size)
factor2 = torch.exp(
-self.EtaA * (distances12.sum(0) / 2 - self.ShfA) ** 2
).unsqueeze(-1)
factor2 = factor2.squeeze(4).squeeze(3)
ret = 2 * factor1 * factor2 * fcj12.prod(0)
# At this point, ret now have shape
# (conformations, atoms, N, ?, ?, ?, ?) where ? depend on constants.
# We then should flat the last 4 dimensions to view the subAEV as one
# dimension vector
return ret.flatten(start_dim=-4)
-self.EtaA * delta_distances**2
) # Shape: (n_pairs, ShfA_size)

# Compute the outer product of factor1 and factor2 efficiently
# fcj12_prod: (n_pairs, 1, 1)
fcj12_prod = fcj12_prod.unsqueeze(-1).unsqueeze(-1) # Shape: (n_pairs, 1, 1)

# factor1: (n_pairs, ShfZ_size, 1)
factor1 = factor1.unsqueeze(-1)
# factor2: (n_pairs, 1, ShfA_size)
factor2 = factor2.unsqueeze(-2)

# Compute ret: (n_pairs, ShfZ_size, ShfA_size)
ret = 2 * fcj12_prod * factor1 * factor2

# Flatten the last two dimensions to get the final subAEV
# ret: (n_pairs, ShfZ_size * ShfA_size)
ret = ret.reshape(distances12.size(dim=1), -1)

return ret


import math
Expand Down
14 changes: 7 additions & 7 deletions modelforge/potential/schnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,17 @@ def compute_properties(
# Compute the atomic representation
representation = self.schnet_representation_module(data, pairlist_output)
atomic_embedding = representation["atomic_embedding"]
f_ij = representation["f_ij"]
f_cutoff = representation["f_cutoff"]

# Apply interaction modules to update the atomic embedding
for interaction in self.interaction_modules:
v = interaction(
atomic_embedding = atomic_embedding + interaction(
atomic_embedding,
pairlist_output,
representation["f_ij"],
representation["f_cutoff"],
f_ij,
f_cutoff,
)
atomic_embedding = atomic_embedding + v # Update atomic features

return {
"per_atom_scalar_representation": atomic_embedding,
Expand Down Expand Up @@ -293,14 +294,13 @@ def forward(

# Generate interaction filters based on radial basis functions
W_ij = self.filter_network(f_ij.squeeze(1))
W_ij = W_ij * f_ij_cutoff
W_ij = W_ij * f_ij_cutoff # Shape: [n_pairs, number_of_filters]

# Perform continuous-filter convolution
x_j = atomic_embedding[idx_j]
x_ij = x_j * W_ij # Element-wise multiplication

out = torch.zeros_like(atomic_embedding)
out.scatter_add_(
out = torch.zeros_like(atomic_embedding).scatter_add_(
0, idx_i.unsqueeze(-1).expand_as(x_ij), x_ij
) # Aggregate per-atom pair to per-atom

Expand Down
2 changes: 1 addition & 1 deletion modelforge/tests/data/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ threshold_mode = "abs"
interval = "epoch"

[training.loss_parameter]
loss_property = ['per_system_energy', 'per_atom_force'] # use
loss_components = ['per_system_energy', 'per_atom_force'] # use

[training.loss_parameter.weight]
per_system_energy = 0.999 #NOTE: reciprocal units
Expand Down
2 changes: 1 addition & 1 deletion modelforge/tests/data/training_defaults/default.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ threshold_mode = "abs"
interval = "epoch"
# ------------------------------------------------------------ #
[training.loss_parameter]
loss_property = ['per_system_energy'] #, 'per_atom_force'] # use
loss_components = ['per_system_energy'] #, 'per_atom_force'] # use
# ------------------------------------------------------------ #
[training.loss_parameter.weight]
per_system_energy = 1.0 #NOTE: reciprocal units
Expand Down
4 changes: 2 additions & 2 deletions modelforge/tests/test_parameter_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,9 @@ def test_training_parameter_model():
with pytest.raises(ValidationError):
training_parameters.splitting_strategy.dataset_split = [0.7, 0.1, 0.1, 0.1]

# this will throw an error because the datafile has 1 entries for the loss_property dictionary
# this will throw an error because the datafile has only 1 entry for the loss_components dictionary
with pytest.raises(ValidationError):
training_parameters.loss_parameter.loss_property = [
training_parameters.loss_parameter.loss_components = [
"per_system_energy",
"per_atom_force",
]
14 changes: 7 additions & 7 deletions modelforge/tests/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,22 +87,22 @@ def get_trainer(config):
def add_force_to_loss_parameter(config):
"""
[training.loss_parameter]
loss_property = ['per_system_energy', 'per_atom_force']
loss_components = ['per_system_energy', 'per_atom_force']
# ------------------------------------------------------------ #
[training.loss_parameter.weight]
per_system_energy = 0.999 #NOTE: reciprocal units
per_atom_force = 0.001

"""
t_config = config["training"]
t_config.loss_parameter.loss_property.append("per_atom_force")
t_config.loss_parameter.loss_components.append("per_atom_force")
t_config.loss_parameter.weight["per_atom_force"] = 0.001


def add_dipole_moment_to_loss_parameter(config):
"""
[training.loss_parameter]
loss_property = [
loss_components = [
"per_system_energy",
"per_atom_force",
"per_system_dipole_moment",
Expand All @@ -116,8 +116,8 @@ def add_dipole_moment_to_loss_parameter(config):

"""
t_config = config["training"]
t_config.loss_parameter.loss_property.append("per_system_dipole_moment")
t_config.loss_parameter.loss_property.append("per_system_total_charge")
t_config.loss_parameter.loss_components.append("per_system_dipole_moment")
t_config.loss_parameter.loss_components.append("per_system_total_charge")
t_config.loss_parameter.weight["per_system_dipole_moment"] = 0.01
t_config.loss_parameter.weight["per_system_total_charge"] = 0.01

Expand All @@ -130,8 +130,8 @@ def add_dipole_moment_to_loss_parameter(config):

def replace_per_system_with_per_atom_loss(config):
t_config = config["training"]
t_config.loss_parameter.loss_property.remove("per_system_energy")
t_config.loss_parameter.loss_property.append("per_atom_energy")
t_config.loss_parameter.loss_components.remove("per_system_energy")
t_config.loss_parameter.loss_components.append("per_atom_energy")

t_config.loss_parameter.weight.pop("per_system_energy")
t_config.loss_parameter.weight["per_atom_energy"] = 0.999
Expand Down
Loading
Loading