
Graphs_integration #161

Draft · wants to merge 145 commits into main

Conversation

@EnricoTrizio (Collaborator) commented Nov 13, 2024

General description

Add the code for GNN-based CVs in the most organic way possible.
This largely inherits from Jintu's work (kudos! @jintuzhang), where all the code was organized as a "library within the library".
Some parts were quite different from the rest of the code (e.g., all the code for the GNNs and the GraphDataModule), others were mostly redundant (e.g., GraphDataset, CV base and specific classes).

It would be wise to reduce code duplication and redundancy and make the whole library more organic, while still including all the new functionality.
SPOILER: this requires some thinking and some modifications here and there.

(We could also split the story into multiple PRs if needed.)


Point-by-point description

Data handling

Affecting --> mlcolvar.data, mlcolvar.utils.io

Overview

So far, we have a DictDataset (based on torch.utils.data.Dataset) and the corresponding DictModule (based on lightning.LightningDataModule).

For GNNs, there was a GraphDataset (based on lists) and the corresponding GraphDataModule (also based on lightning.LightningDataModule).
Here, the data are handled using PyTorch Geometric for convenience.
There are also a bunch of auxiliary functions for neighborhoods and for handling atom types, plus some utils to initialize the dataset easily from files.
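
To make the role of these utilities concrete, here is a minimal, hypothetical sketch of a neighborhood function (brute force, no periodic boundary conditions; the actual utilities in the PR are of course more complete):

```python
# Hypothetical sketch, not the actual utility shipped in this PR:
# build an edge index connecting atoms closer than a given cutoff.
import torch

def neighbor_list(positions: torch.Tensor, cutoff: float) -> torch.Tensor:
    """Return a (2, n_edges) edge index for all pairs within the cutoff (no PBC)."""
    dist = torch.cdist(positions, positions)               # (n_atoms, n_atoms) distances
    src, dst = torch.where((dist < cutoff) & (dist > 0))   # drop self-interactions
    return torch.stack([src, dst])

edges = neighbor_list(torch.rand(5, 3), cutoff=0.8)
print(edges.shape)  # torch.Size([2, n_edges])
```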

Possible solution

Merge the two dataset classes into DictDataset as they basically do the same thing.
Add a metadata attribute to the dataset to store quantities that are not data (e.g., cutoff and atom_types).
Add some exceptions for the torch-geometric-like data (still easy thanks to the dictionary structure, i.e., they will always be found with the same key).
For now, the keys are the original ones:

  • data_list: all the graph data, e.g., edge src and dst, batch index... (this goes in DictDataset)
  • z_table: atomic numbers map (this goes in DictDataset.metadata)
  • cutoff: cutoff used in the graph (this goes in DictDataset.metadata)

All the other classes/functions that differ will go into a mlcolvar.data.graph module (almost) as they are; where possible, we should then match the API and the functionality, for example, in the datamodule.
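
A minimal sketch of what the proposed interface could look like (the metadata keyword and the graph handling are part of the proposal, not the current DictDataset API):

```python
# Sketch of the proposed API, assuming a `metadata` attribute on DictDataset
# and a 'data_list' key holding PyTorch Geometric Data objects.
import torch
from torch_geometric.data import Data
from mlcolvar.data import DictDataset

# toy list of graphs (in practice built from trajectories by the graph utils)
data_list = [
    Data(x=torch.rand(3, 1), edge_index=torch.tensor([[0, 1, 2], [1, 2, 0]]))
    for _ in range(10)
]

dataset = DictDataset(
    {"data_list": data_list},                      # graph data
    metadata={"z_table": [1, 8], "cutoff": 5.0},   # non-data quantities
)
print(dataset.metadata["cutoff"])  # 5.0
```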

Questions

  • Maybe some keys could be renamed to something clearer?
  • Do we like the metadata approach?
  • Should there be a single DataModule?
  • Could we make the overall structure smoother, i.e., not too many utils.py files here and there and not too many submodules?

GNN models

Affecting --> mlcolvar.core.nn

Overview

Of course, they need to be implemented 😄 but we can inherit most of the code from Jintu.
As an overview, there is a BaseGNN parent class that implements the common features, and then each model (e.g., SchNet or GVP) is implemented on top of that.
There is also a radial.py that implements a bunch of tools for radial embeddings.
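
For illustration, a self-contained sketch of the kind of tool radial.py provides (a simple Gaussian basis here; the actual module may use different basis functions and cutoff envelopes):

```python
# Illustrative Gaussian radial basis; not the actual implementation in radial.py.
import torch

class GaussianBasis(torch.nn.Module):
    """Expand interatomic distances onto equally spaced Gaussians up to the cutoff."""

    def __init__(self, cutoff: float, n_basis: int = 16):
        super().__init__()
        self.register_buffer("centers", torch.linspace(0.0, cutoff, n_basis))
        self.width = cutoff / n_basis

    def forward(self, distances: torch.Tensor) -> torch.Tensor:
        # distances: (n_edges,) -> embeddings: (n_edges, n_basis)
        return torch.exp(-((distances.unsqueeze(-1) - self.centers) ** 2) / (2 * self.width ** 2))

rbf = GaussianBasis(cutoff=5.0)
print(rbf(torch.tensor([1.2, 3.4])).shape)  # torch.Size([2, 16])
```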

Possible solution

Include the GNN code (almost) as it is, possibly with some revision, in a mlcolvar.core.nn.gnn module.

Questions

  • Maybe we can try to rely on PyTorch Geometric also here?
  • Scripting, tracing, and other nasty things?

CV models

Affecting --> mlcolvar.cvs

Overview

In Jintu's implementation, all the CV classes we tested were re-implemented, still using the original loss-function code.
The point there is that the initialization of the underlying ML model (also in the current version of the library) is performed within the CV class.
We did it to keep things simple, and indeed it is simple for feed-forward networks, as they have very few things to set (i.e., layers, nodes, activations), and also because there were no alternatives at the time.
For GNNs, however, the initialization can vary a lot (i.e., different architectures and many parameters one could set).

I am afraid we can't cut corners here: if we want to include everything, somewhere we need to add an extra layer of complexity, either to the workflow or to the CV models.

Possible solution(s)

  1. (My first implementation, now in the PR.) We keep everything similar to how it is now. We add a gnn_model=None keyword to the init of the CV classes that can use GNNs (plus some error messages and checks here and there), so that you can pass a GNN model that you initialized outside the CV class:
# for GNN-based
gnn_model = SchNet(...)
model = DeepTDA(..., gnn_model=gnn_model, ...)

# for FFNN-based
model = DeepTDA(...)  # gnn_model=None by default
  • if gnn_model is None: we use the feed-forward implementation as before (read: if you don't know you can use GNNs, you can't mess anything up)
  • if gnn_model is a BaseGNN instance: we override under the hood the code that needs to be adapted (i.e., the blocks and the training_step); everything else can stay the same, from what I've seen so far (only TDA tested)

PRO: only one base CV class (maybe with a few modifications) and only one specific CV class (with a decent amount of modifications); the user experience will not change much.
CONS: the whole mechanism may not be super clear and clean.

  2. The same as 1 but more general: we also take the feed-forward model initialization out of the CV models and add some signature to the different models (e.g., a model.model_type attribute that can be ff or gnn) so that the CV can then dispatch to the right code (see the sketch after this list):
# for GNN-based
gnn_model = SchNet(...)
model = DeepTDA(..., model=gnn_model, ...)

# for FFNN-based
ff_model = FeedForward(...)
model = DeepTDA(..., model=ff_model, ...)

PRO: similar to 1 but more general, and maybe less confusing than 1.
CONS: it always adds one more step to the workflow; it may sound more complicated than before (even if it's just one line).

  3. (Jintu's way) Keep two separate classes for graph-based and feed-forward-based CVs:
# for GNN-based
model = GraphDeepTDA(TDA_params... , GNN_params...)

# for FFNN-based
model = DeepTDA(TDA_params..., FFNN_params)

PRO: no (potentially breaking) changes anywhere and, for sure, a much lower activation energy for this PR.
CONS: quite a lot of redundant code; the GNNs still require many parameters to be set.

  4. Strange things with classes of classes that @andrrizzi may suggest.
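
As a rough illustration of the dispatch idea in option 2 (names and attributes here are hypothetical, not the actual mlcolvar API):

```python
# Hypothetical sketch: the CV class inspects a `model_type` attribute to decide
# whether to use the graph-specific or feed-forward-specific code paths.
import torch

class FeedForward(torch.nn.Module):
    model_type = "ff"

class SchNet(torch.nn.Module):
    model_type = "gnn"

class DeepTDA(torch.nn.Module):
    def __init__(self, model: torch.nn.Module):
        super().__init__()
        if getattr(model, "model_type", None) not in ("ff", "gnn"):
            raise ValueError("model must expose model_type = 'ff' or 'gnn'")
        self.model = model
        self.is_graph = model.model_type == "gnn"
        # blocks and training_step would be selected based on self.is_graph

cv = DeepTDA(model=SchNet())
print(cv.is_graph)  # True
```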

Questions

  • What do we prefer? User experience? Code conciseness? Fewer changes?

General todos

  • Check everything 😄
  • Fix dependencies
  • Fix and clean imports
  • Fix init files
  • Remove commented lines
  • Tests !!!
  • DOCS !!!

General questions

  • How many new dependencies do we want to take on? Can we make some of them optional?

Status

  • Ready to go


@github-advanced-security bot left a comment:


CodeQL found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.


codecov bot commented Nov 14, 2024

Codecov Report

Attention: Patch coverage is 3.61991% with 1917 lines in your changes missing coverage. Please review.

Project coverage is 59.68%. Comparing base (6576f08) to head (af5d238).

❗ There is a different number of reports uploaded between BASE (6576f08) and HEAD (af5d238): HEAD has 3 fewer uploads than BASE.

| Flag    | BASE (6576f08) | HEAD (af5d238) |
|---------|----------------|----------------|
| codecov | 6              | 3              |
