Graphs_integration #161
Open
EnricoTrizio wants to merge 285 commits into main from graphs_integration
Conversation
General description
Add the code for CVs based on GNNs in the most organic way possible.
This largely inherits from Jintu's work (kudos! @jintuzhang), where all the code was based on a "library within the library".
Some parts were quite different from the rest of the code (e.g., all the code for GNN models), others were mostly redundant (e.g., GraphDataset, GraphDataModule, CV base, and specific CV classes).
It therefore made sense to reduce code duplication and redundancy and make the whole library more organic, while still including all the new functionality.
The modifications have been made to keep the user experience as close as possible to what it was when only descriptor-based models were implemented. For example, the defaults correspond to that scenario where applicable, and the use of GNN models needs to be explicitly "activated" by the user.
SPOILER: This required some thinking and some modifications here and there
General questions
- `BLOCKS`? Is it worth it to keep this thing?
General todos
Point-by-point description
Data handling
Overview
So far, we have a `DictDataset` (based on `torch.Dataset`) and the corresponding `DictModule` (based on `lightning.LightningDataModule`).
For GNNs, there was a `GraphDataset` (based on lists) and the corresponding `GraphDataModule` (also based on `lightning.LightningDataModule`). Here, the data are handled for convenience using the `PyTorch Geometric` framework.
There are also a bunch of auxiliary functions for neighborhoods and handling of atom types, plus some utilities to initialize the dataset easily from files.
Implemented solution
The two things are merged:
- A unique `DictDataset` that can handle both types of data.
  - It has a `metadata` attribute that stores general properties in a `dict` (e.g., cutoff and atom types).
  - In the `__init__`, the user can specify the `data_type` (either `descriptors`, the default, or `graphs`). This is then stored in `metadata` and is used in the `DictLoader` to handle the data the right way (see below).
- Save/load utilities in `mlcolvar.data.utils`: `save_dataset`, `load_dataset` and `save_dataset_configurations_as_extyz`.
- A unique `DictModule` that can handle both types of data. Depending on the `metadata['data_type']` of the incoming dataset, it either uses our `DictLoader` or the `torch_geometric.DataLoader`.
- A new module `data.graph` containing:
  - `atomic.py` for the handling of atomic quantities based on the data class `Configuration`.
  - `neighborhood.py` for building neighbor lists using `matscipy`.
  - `utils.py` to frame `Configuration`s into a dataset and one-hot embeddings. It also contains `create_test_graph_input`, as creating inputs for testing here requires several lines of code.
- A `create_dataset_from_trajectories` util in `mlcolvar.utils.io` that allows creating a dataset directly from trajectory files, providing topology files and using `mdtraj`, thus allowing easy handling of the more complex bio-simulation formats. For solids/surfaces/chemistry, the util handles `.xyz` files using a combination of `ase` and `mdtraj`, to be efficient and retain the convenient `mdtraj` atom selection.
- `create_timelagged_dataset` can also create a time-lagged dataset starting from a `DictDataset` with `data_type=='graphs'`.
NB: for the graph datasets, the keys are the original ones:
- `data_list`: all the graph data, e.g., edge src and dst, batch index, ... (this goes in `DictDataset`)
- `z_table`: atomic numbers map (this goes in `DictDataset.metadata`)
- `cutoff`: cutoff used in the graph (this goes in `DictDataset.metadata`)
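A hedged sketch of how these pieces could fit together for graph data (the `create_dataset_from_trajectories` arguments and file names below are illustrative assumptions, not the exact signature):

```python
from mlcolvar.data import DictModule
from mlcolvar.utils.io import create_dataset_from_trajectories

# Build a graph dataset directly from trajectory + topology files (via mdtraj).
# Argument names and file names here are hypothetical.
dataset = create_dataset_from_trajectories(
    trajectories=["state_A.dcd", "state_B.dcd"],
    top=["state_A.pdb", "state_B.pdb"],
    cutoff=5.0,  # graph cutoff, stored in dataset.metadata
)

# The data type is stored in the dataset metadata...
assert dataset.metadata["data_type"] == "graphs"

# ...and DictModule dispatches on it: our DictLoader for descriptors,
# torch_geometric's DataLoader for graphs.
datamodule = DictModule(dataset, lengths=[0.8, 0.2])
```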
GNN models
Overview
Of course, they needed to be implemented 😄 but we could inherit most of the code from Jintu.
As an overview, there is a `BaseGNN` parent class that implements the common features, and then each model (e.g., `SchNet` or GVP) is implemented on top of that. There is also a `radial.py` that implements a bunch of tools for radial embeddings.
Implemented solution
The GNN code is now implemented in `mlcolvar.core.nn.graph`:
- A `BaseGNN` class that is a template for the architecture-specific code. This, for example, already has the methods for embedding edges and setting some common properties.
- A `Radial` module that implements the tools for radial embeddings.
- `SchNetModel` and `GVPModel` are implemented based on `BaseGNN`.
- In `utils.py`, there is a function that creates data for the tests of this module. This could be replaced by the very similar function `mlcolvar.data.graph.utils.create_test_graph_input`, which is more general and is also used for other things.
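As a rough illustration of this template pattern (a sketch only: the constructor arguments and the exact `BaseGNN` interface below are assumptions):

```python
from mlcolvar.core.nn.graph import BaseGNN

class MyGNNModel(BaseGNN):
    """Sketch of an architecture-specific model built on the BaseGNN template."""

    def __init__(self, n_out, cutoff, atomic_numbers):
        # BaseGNN is described as taking care of the common setup (e.g., edge
        # embeddings and common properties); the arguments here are hypothetical.
        super().__init__(n_out=n_out, cutoff=cutoff, atomic_numbers=atomic_numbers)
        # architecture-specific layers (message passing, readout, ...) would go here

    def forward(self, data):
        # only the architecture-specific forward pass is left to the subclass;
        # the format of `data` (a batch of graphs) is also an assumption here
        raise NotImplementedError
```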
CV models
Overview
In Jintu's implementation, all the CV classes we tested were re-implemented, still using the original loss function code.
The point there is that the initialization of the underlying ML model (also in the current version of the library) is performed within the CV class.
We did it this way to keep things simple, and indeed it is simple for feed-forward networks, as they have very few things to set (i.e., layers, nodes, activations), and also because there were no alternatives at the time.
For GNNs, however, the initialization can vary a lot (e.g., different architectures and many parameters one could set).
We can't cut corners here if we want to include everything, so we need to add an extra layer of complexity somewhere, either in the workflow or in the CV models.
Implemented solution
We keep everything similar to what it used to be in the library, except for:
1. The `layers` keyword is changed to the more general `model` in the init of the CV classes, which can accept:
   - a `layers` keyword, in which case a FeedForward is initialized with it and all the `DEFAULT_BLOCKS` (see point 2), e.g., for DeepLDA: `['norm_in', 'nn', 'lda']`;
   - a `mlcolvar.core.nn.FeedForward` or `mlcolvar.core.nn.graph.BaseGNN` model that you had initialized outside the CV class. This way, one overrides the old default, provides an external model, and uses the `MODEL_BLOCKS`, e.g., for DeepLDA: `['nn', 'lda']`. For example, the initialization can be something like the sketch after this list.
2. The `BLOCKS` of each CV model are duplicated into `DEFAULT_BLOCKS` and `MODEL_BLOCKS` to account for the different behaviors. This was a simple way to initialize everything in all cases (maybe not the best one, see questions).
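A minimal sketch of the two initialization modes (only `model` follows the description above; the other names, and in particular the GNN constructor arguments, are assumptions):

```python
from mlcolvar.cvs import DeepLDA
from mlcolvar.core.nn.graph import SchNetModel

# Descriptor-based, close to the old behavior: pass the layer sizes and a
# FeedForward is built internally, using DEFAULT_BLOCKS (['norm_in', 'nn', 'lda']).
cv = DeepLDA(model=[45, 30, 15, 2], n_states=2)

# GNN-based: build the network outside the CV class and pass it via `model`;
# the CV then uses MODEL_BLOCKS (['nn', 'lda']). SchNetModel arguments are hypothetical.
gnn = SchNetModel(n_out=2, cutoff=5.0, atomic_numbers=[1, 8])
cv = DeepLDA(model=gnn, n_states=2)
```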
Things to note
- The `CommittorLoss` required special care, as it does not depend only on the output space but also on the derivatives wrt the input/positions.
- Autoencoder-based CVs raise a `NotImplementedError` when used with GNN models, as we do not have, for now, a stable GNN-AE. As a consequence, the `MultiTaskCV` also does not support GNN models, as, in the way we intend it, it wouldn't make much sense without a GNN-based AE.
TODOs
Explain module
Overview
There is a new module `graph_sensitivity` that performs a per-node sensitivity analysis. Some internal functions have been adapted to handle both types of datasets.
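Purely as an illustration of the intended usage (the import path, function name, and signature below are hypothetical, inferred only from the description above):

```python
# Hypothetical sketch: per-node sensitivity analysis of a trained GNN-based CV.
from mlcolvar.explain import graph_sensitivity  # hypothetical import path

results = graph_sensitivity(model=cv, dataset=dataset)  # hypothetical signature
# e.g., one sensitivity score per node/atom, which can then be mapped
# back onto the structure for visualization.
```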
TODOs
Status