Graphs_integration #161

Open: wants to merge 285 commits into base: main
Conversation

@EnricoTrizio EnricoTrizio commented Nov 13, 2024

General description

Add the code for CVs based on GNNs in the most organic way possible.
This largely inherits from Jintu's work (kudos! @jintuzhang), where all the code was organized as a "library within the library".
Some functions were quite different from the rest of the code (e.g., all the code for GNN models), while others were mostly redundant (e.g., GraphDataset, GraphDataModule, and the base and specific CV classes).

It seemed wise to reduce code duplication and redundancy and to make the whole library more organic, while still including all the new functionality.

The modifications have been made to keep the user experience as close as possible to what it was when only descriptor-based models were implemented. For example, the defaults are based on that scenario, when applicable, and the use of GNN models needs to be explicitly "activated" by the user.

SPOILER: This required some thinking and some modifications here and there

General questions

  • Shall we make the overall structure smoother? I.e., avoid too many utils.py files here and there and too many submodules?
  • Shall we keep the current names for the graph keys in the datasets? I.e., data_list, z_table, etc.
  • Do we like the metadata thing for the datasets?
  • What shall we do with the BLOCKS? Is it worth it to keep this thing?

General todos

  • Check everything 😄
  • Double check the Docs

Point-by-point description

Data handling

Overview

So far, we have a DictDataset (based on torch.utils.data.Dataset) and the corresponding DictModule (based on lightning.LightningDataModule).

For GNNs, there was a GraphDataset (based on lists) and the corresponding GraphDataModule (based on lightning.LightningDataModule).
Here, the data are handled for convenience using the PyTorch Geometric framework.
There are also a bunch of auxiliary functions for neighborhoods and for handling atom types, plus some utilities to initialize the dataset easily from files.

Implemented solution

The two are merged:

  1. A single DictDataset that can handle both types of data.
  • It also has a metadata attribute that stores general properties in a dict (e.g., cutoff and atom_types).
  • In the __init__, the user can specify the data_type (either descriptors, the default, or graphs). This is then stored in metadata and used in the DictModule to handle the data the right way (see below).
  • New utils have been added in mlcolvar.data.utils: save_dataset, load_dataset, and save_dataset_configurations_as_extyz.
  2. A single DictModule that can handle both types of data. Depending on the metadata['data_type'] of the incoming dataset, it uses either our DictLoader or the torch_geometric DataLoader.
  3. A new submodule data.graph containing:
  • atomic.py for handling atomic quantities based on the data class Configuration;
  • neighborhood.py for building neighbor lists using matscipy;
  • utils.py to frame Configurations into datasets and one-hot embeddings. It also contains create_test_graph_input, as creating inputs for testing here requires several lines of code.
  4. A new create_dataset_from_trajectories util in mlcolvar.utils.io that allows creating a dataset directly from trajectory files, given topology files, using mdtraj, thus allowing easy handling of the more complex bio-simulation formats. For solids/surfaces/chemistry, the util handles .xyz files using a combination of ase and mdtraj to be efficient while retaining the convenient mdtraj atom selection.
  5. A single create_timelagged_dataset that can also create a time-lagged dataset starting from a DictDataset with data_type=='graphs'.
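As an aside, the neighbor-list construction that neighborhood.py delegates to matscipy boils down to finding all atom pairs within the cutoff. The brute-force sketch below is purely illustrative (the function name and arguments are not the library API); the real implementation is far more efficient:

```python
from itertools import combinations

def brute_force_neighbor_list(positions, cutoff):
    """Illustrative O(N^2) neighbor list: return (src, dst) index pairs for
    atoms closer than `cutoff`, in both directions, as in a directed graph."""
    edges = []
    for i, j in combinations(range(len(positions)), 2):
        dist2 = sum((a - b) ** 2 for a, b in zip(positions[i], positions[j]))
        if dist2 < cutoff ** 2:
            edges.append((i, j))
            edges.append((j, i))
    return edges

# Three atoms on a line, 1.0 apart: with cutoff 1.5 only adjacent pairs connect.
pos = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (2.0, 0.0, 0.0)]
print(brute_force_neighbor_list(pos, 1.5))
```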

NB: For graph datasets, the keys are the original ones:

  • data_list: all the graph data, e.g., edge src and dst, batch index... (this goes in DictDataset)
  • z_table: atomic numbers map (this goes in DictDataset.metadata)
  • cutoff: cutoff used in the graph (this goes in DictDataset.metadata)
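The data_type dispatch described above can be pictured as follows; this is a schematic stand-in for the actual DictModule logic, with loader names returned as plain strings for illustration:

```python
def pick_loader(dataset_metadata):
    """Schematic version of the DictModule dispatch: descriptor datasets go
    through the in-house DictLoader, graph datasets through the PyTorch
    Geometric DataLoader. Illustrative only, not the mlcolvar code."""
    data_type = dataset_metadata.get("data_type", "descriptors")
    if data_type == "graphs":
        return "torch_geometric.loader.DataLoader"
    return "mlcolvar.data.DictLoader"

print(pick_loader({"data_type": "graphs", "cutoff": 5.0}))  # graph branch
print(pick_loader({}))                                      # descriptor default
```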

GNN models

Overview

Of course, they needed to be implemented 😄 but we could inherit most of the code from Jintu.
As an overview, there is a BaseGNN parent class that implements the common features, and then each model (e.g., SchNet or GVP) is implemented on top of that.
There is also a radial.py that implements a bunch of tools for radial embeddings.

Implemented solution

The GNN code is now implemented in mlcolvar.core.nn.graph.

  1. There is a BaseGNN class that serves as a template for the architecture-specific code. For example, it already has the methods for embedding edges and for setting some common properties.
  2. The Radial module implements the tools for radial embeddings.
  3. The SchNetModel and GVPModel are implemented on top of BaseGNN.
  4. In utils.py, there is a function that creates data for this module's tests. It could be replaced by the very similar function mlcolvar.data.graph.utils.create_test_graph_input, which is more general and also used elsewhere.
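As an illustration of what a radial embedding does, here is a minimal Gaussian-basis expansion of an edge length, a common choice in SchNet-style models; the actual basis functions and parameters in radial.py may differ:

```python
import math

def gaussian_rbf(distance, cutoff, n_basis=8):
    """Expand a scalar distance into n_basis Gaussian features whose centers
    are spread uniformly on [0, cutoff]. Illustrative sketch only."""
    centers = [k * cutoff / (n_basis - 1) for k in range(n_basis)]
    width = cutoff / (n_basis - 1)
    return [math.exp(-((distance - c) ** 2) / (2 * width ** 2)) for c in centers]

feats = gaussian_rbf(2.5, cutoff=5.0)
print(len(feats))  # 8 features per edge
```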

CV models

Overview

In Jintu's implementation, all the CV classes we tested were re-implemented, still using the original loss-function code.
The point is that the initialization of the underlying ML model (also in the current version of the library) is performed within the CV class.
We did it this way to keep things simple, and indeed it is simple for feed-forward networks, as they have very few things to set (i.e., layers, nodes, activations), and also because there were no alternatives at the time.
For GNNs, however, the initialization can vary a lot (i.e., different architectures and many parameters one could set).

We couldn't cut corners here and still include everything, so somewhere we need to add an extra layer of complexity, either to the workflow or to the CV models.

Implemented solution

We keep everything similar to what it used to be in the library, except for:

  1. We rename the layers keyword to the more general model in the init of the CV classes, which can accept:
  • A list of integers, as before. This works like the old layers keyword and initializes a FeedForward with it and all the DEFAULT_BLOCKS (see point 2), e.g., for DeepLDA: ['norm_in', 'nn', 'lda'].
  • A mlcolvar.core.nn.FeedForward or mlcolvar.core.nn.graph.BaseGNN model initialized outside the CV class. This overrides the old default: one provides an external model, and the MODEL_BLOCKS are used, e.g., for DeepLDA: ['nn', 'lda']. For example, the initialization can look like this:
```python
# for GNN-based CVs
gnn_model = SchNet(...)
model = DeepLDA(..., model=gnn_model, ...)

# for FFNN-based CVs, alternative 1: keeps the normalization from DEFAULT_BLOCKS
model = DeepLDA(..., model=[2, 3], ...)

# for FFNN-based CVs, alternative 2: uses the MODEL_BLOCKS
ff_model = FeedForward(layers=[2, 3])
model = DeepLDA(..., model=ff_model, ...)
```
  2. The BLOCKS of each CV model are duplicated into DEFAULT_BLOCKS and MODEL_BLOCKS to account for the different behaviors. This was a simple way to initialize everything in all cases (maybe not the best one, see questions).
  3. In the training step, the change amounts to a different setup of the data depending on the type of ML model we are using; the rest is basically the same as before.
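The branching on the model argument can be summarized schematically; the class FakeFeedForward and the function resolve_model below are placeholders for illustration, not the actual mlcolvar implementation:

```python
class FakeFeedForward:
    """Stand-in for mlcolvar.core.nn.FeedForward, for illustration only."""
    def __init__(self, layers):
        self.layers = layers

DEFAULT_BLOCKS = ["norm_in", "nn", "lda"]  # used when `model` is a list of layer sizes
MODEL_BLOCKS = ["nn", "lda"]               # used when an external model is passed

def resolve_model(model):
    """Schematic: a list of integers builds a feed-forward network with the
    default blocks; an already-built model is used as-is with the reduced
    block list (no internal normalization)."""
    if isinstance(model, list):
        return FakeFeedForward(model), DEFAULT_BLOCKS
    return model, MODEL_BLOCKS

nn1, blocks1 = resolve_model([2, 3])                   # old-style list input
nn2, blocks2 = resolve_model(FakeFeedForward([2, 3]))  # external model input
print(blocks1, blocks2)
```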

Things to note

  1. All the loss functions are untouched! Except for the CommittorLoss, as it depends not only on the output space but also on the derivatives w.r.t. the input positions.
  2. When an external GNN model is provided, logging still does not work. I left this for the very end of the PR, focusing first on making things work.
  3. Autoencoder-based CVs only raise a NotImplementedError, as we do not have a stable GNN-based autoencoder for now. As a consequence, MultiTaskCV also does not support GNN models, since, in the way we intend it, it wouldn't make much sense without a GNN-based AE.

TODOs

  • Make logger work with graph models 🗡️
  • Add autoencoders (in the future)

Explain module

Overview

There is a new module graph_sensitivity that performs a per-node sensitivity analysis. Some internal functions have been adapted to handle both dataset types.
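A per-node sensitivity analysis can be pictured as perturbing one node at a time and recording how much the model output changes. This toy sketch uses a plain scalar function in place of a trained CV and is not the graph_sensitivity implementation:

```python
def node_sensitivity(model_fn, node_features, eps=1e-3):
    """For each node, perturb its (scalar) feature by eps and measure the
    absolute change in the model output, normalized by eps. Toy example:
    a finite-difference stand-in for a gradient-based analysis."""
    base = model_fn(node_features)
    sens = []
    for i in range(len(node_features)):
        perturbed = list(node_features)
        perturbed[i] += eps
        sens.append(abs(model_fn(perturbed) - base) / eps)
    return sens

# Toy CV: depends strongly on node 0, weakly on node 1, not at all on node 2.
def toy_cv(x):
    return 10.0 * x[0] + 0.1 * x[1]

print(node_sensitivity(toy_cv, [1.0, 1.0, 1.0]))
```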

TODOs

  • Maybe we can add something to visualize the results on the molecule?

Status

  • Ready to go
