diff --git a/README.md b/README.md index e6ec94fa..19a0bf5d 100644 --- a/README.md +++ b/README.md @@ -11,12 +11,11 @@ Machine Learning Collective Variables for Enhanced Sampling **PAPER** [![paper](https://img.shields.io/badge/JCP-10.1063/5.0156343-blue)](https://doi.org/10.1063/5.0156343) [![preprint](https://img.shields.io/badge/arXiv-2305.19980-lightblue)](https://arxiv.org/abs/2305.19980) -The documentation is available at: -- **stable** version: https://mlcolvar.readthedocs.io -- **latest** version: https://mlcolvar.readthedocs.io/en/latest/ - +--- --- +## Overview + `mlcolvar` is a Python library aimed to help design data-driven collective-variables (CVs) for enhanced sampling simulations. The key features are: 1. A unified framework to help test and use (some) of the CVs proposed in the literature. @@ -26,26 +25,28 @@ The documentation is available at: The library is built upon the [PyTorch](https://pytorch.org/) ML library as well as the [Lightning](https://lightning.ai/) high-level framework. --- +--- -Some of the **CVs** which are implemented, organized by learning setting: -- _Unsupervised_: PCA, (Variational) AutoEncoders [[1](http://dx.doi.org/%2010.1002/jcc.25520),[2](http://dx.doi.org/%2010.1021/acs.jctc.1c00415)] -- _Supervised_: LDA [[3](http://dx.doi.org/10.1021/acs.jpclett.8b00733)], DeepLDA [[4](http://dx.doi.org/%2010.1021/acs.jpclett.0c00535)], DeepTDA [[5](http://dx.doi.org/%2010.1021/acs.jpclett.1c02317)] -- _Time-informed_: TICA [[6](http://dx.doi.org/%2010.1063/1.4811489)], DeepTICA/SRVs [[7](http://dx.doi.org/10.1073/pnas.2113533118),[8](http://dx.doi.org/%2010.1063/1.5092521)], VDE [[9](http://dx.doi.org/10.1103/PhysRevE.97.062412)] -And many others can be implemented based on the building blocks or with simple modifications. Check out the [tutorials](https://mlcolvar.readthedocs.io/en/stable/tutorials.html) and the [examples](https://mlcolvar.readthedocs.io/en/stable/examples.html) section of the documentation. 
+## Documentation +The documentation is available at: +- **stable** version: https://mlcolvar.readthedocs.io +- **latest** version: https://mlcolvar.readthedocs.io/en/latest/ +--- --- +## Installation -**Install with `pip`** +**1. Install latest stable version with `pip`** -The library is available on [PyPi](https://pypi.org/project/mlcolvar/) and can be installed with `pip`. This is the preferred choice for **users** as it automatically installs the package requirements. +The **latest stable version** of the library is available on [PyPi](https://pypi.org/project/mlcolvar/) and can be installed with `pip`. This is the preferred choice for **users** as it automatically installs the package requirements. ```bash pip install mlcolvar ``` -**Clone from GitHub** +**2. Clone repository from GitHub** The library can also be installed cloning the repository from GitHub. This is the preferred choice for **developers** as it provides more flexibility and allows editable installation. @@ -55,16 +56,52 @@ cd mlcolvar pip -e install . 
``` +--- +--- + +## CV methods + +Some of the **CVs** which are implemented, organized by learning setting: +- _Unsupervised_: PCA, (Variational) AutoEncoders [[1](http://dx.doi.org/%2010.1002/jcc.25520),[2](http://dx.doi.org/%2010.1021/acs.jctc.1c00415)] +- _Supervised_: LDA [[3](http://dx.doi.org/10.1021/acs.jpclett.8b00733)], DeepLDA [[4](http://dx.doi.org/%2010.1021/acs.jpclett.0c00535)], DeepTDA [[5](http://dx.doi.org/%2010.1021/acs.jpclett.1c02317)] +- _Time-informed_: TICA [[6](http://dx.doi.org/%2010.1063/1.4811489)], DeepTICA/SRVs [[7](http://dx.doi.org/10.1073/pnas.2113533118),[8](http://dx.doi.org/%2010.1063/1.5092521)], VDE [[9](http://dx.doi.org/10.1103/PhysRevE.97.062412)] +- _Committor-based_ [[10](https://doi.org/10.1038/s43588-024-00645-0),[11](https://doi.org/10.1038/s43588-025-00799-5)] +- _Multi-task_ [[12](https://doi.org/10.1063/5.0156343)] + +And many others can be implemented based on the building blocks or with simple modifications. Check out the [tutorials](https://mlcolvar.readthedocs.io/en/stable/tutorials.html) and the [examples](https://mlcolvar.readthedocs.io/en/stable/examples.html) section of the documentation. + +--- +--- + +## Model architectures: feed-forward vs graph-based + +- **Feed-forward**: All the CV methods can be used with *standard* neural networks as the architecture, either feed-forward or autoencoders. +In this case, for the inputs there are two possibilities: + - Directly use precomputed physical descriptors, ideally obtained using PLUMED. This option is faster and covers most use cases. + - Compute physical descriptors within the model starting from the atomic positions, ideally obtained from PLUMED. This can be done using as a *preprocessing module* the tools available in the **transform** module of the library or implementing your own descriptors. This option is typically slower and, for example, it should be chosen if the desired descriptors are not already available in PLUMED. 
+ +- **Graph neural networks**: All the CV methods **not based on autoencoders** can also be used with graph neural networks as the architecture, taking **atomic positions** directly as inputs, following the scheme reported in [[JCTC 2024](https://doi.org/10.1021/acs.jctc.4c01197)]. In this case, the inputs are directly the atomic positions and species. +Note that, in general, feed-forward-based methods are faster than graph-based ones. --- +--- + +### PLUMED interfaces + The resulting CVs can be deployed for enhancing sampling with the [PLUMED](https://www.plumed.org/) plugin compiled with `libtorch`. In particular: -**PLUMED interface**: the resulting CVs can be deployed for enhancing sampling with the [PLUMED](https://www.plumed.org/) package via the [pytorch](https://www.plumed.org/doc-master/user-doc/html/PYTORCH_MODEL/) interface, available since version 2.9. +- **Feed-forward-based** CV models can be employed via the [pytorch](https://www.plumed.org/doc-master/user-doc/html/PYTORCH_MODEL/) interface, available with the official release of PLUMED since version 2.9. + - Note: The transition-state-oriented Kolmogorov bias proposed in [[Nat.Comp.Sci. 2024](https://doi.org/10.1038/s43588-024-00645-0) and [2025](https://doi.org/10.1038/s43588-025-00799-5)] can be employed using the custom interface available at #TODO +- **Graph-based** models can be employed using the custom interface developed in [[JCTC 2024](https://doi.org/10.1021/acs.jctc.4c01197)] available at #TODO. + - Note: This interface already supports the calculation of the transition-state-oriented Kolmogorov bias proposed in [[Nat.Comp.Sci. 2024](https://doi.org/10.1038/s43588-024-00645-0) and [2025](https://doi.org/10.1038/s43588-025-00799-5)] --- +--- -**Notes**: in early versions (`v<=0.2.*`) the library was called `mlcvs`. This is still accessible for compatibility with PLUMED masterclasses in the [releases](https://github.com/luigibonati/mlcolvar/releases) or by cloning the `pre-lightning` branch. 
+## Notes +In early versions (`v<=0.2.*`) the library was called `mlcvs`. This is still accessible for compatibility with PLUMED masterclasses in the [releases](https://github.com/luigibonati/mlcolvar/releases) or by cloning the `pre-lightning` branch. +--- --- Copyright (c) 2023 Luigi Bonati, Enrico Trizio, Andrea Rizzi and Michele Parrinello. diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml index ac8b28e6..077b6e45 100644 --- a/devtools/conda-envs/test_env.yaml +++ b/devtools/conda-envs/test_env.yaml @@ -3,6 +3,7 @@ channels: - pytorch - conda-forge + - pyg - defaults dependencies: @@ -24,9 +25,13 @@ dependencies: - matplotlib - scikit-learn - scipy - + - pyg + # Pip-only installs - pip: - KDEpy - nbmake + - mdtraj + - matscipy + diff --git a/docs/api_core.rst b/docs/api_core.rst index 75af705b..66366362 100644 --- a/docs/api_core.rst +++ b/docs/api_core.rst @@ -1,10 +1,14 @@ Core modules ------------- +============ These are the building blocks which are used to construct the CVs. -.. rubric:: NN +NN +-- +This module implements the architectures with learnable weights that can be used to build CV models. +Descriptors-based +^^^^^^^^^^^^^^^^ .. currentmodule:: mlcolvar.core.nn .. autosummary:: @@ -13,7 +17,31 @@ These are the building blocks which are used to construct the CVs. FeedForward -.. rubric:: Loss +Graphs-based +^^^^^^^^^^^^ +.. currentmodule:: mlcolvar.core.nn.graph + +Base class +"""""""""" +.. autosummary:: + :toctree: autosummary + :template: custom-class-template.rst + + BaseGNN + +Architectures +""""""""""""" +.. autosummary:: + :toctree: autosummary + :template: custom-class-template.rst + + SchNetModel + GVPModel + + +Loss +---- +This module implements the loss functions that can be used to optimize CV models. .. currentmodule:: mlcolvar.core.loss @@ -31,8 +59,13 @@ These are the building blocks which are used to construct the CVs. GeneratorLoss SmartDerivatives -.. 
rubric:: Stats +Stats +----- +This module implements statistical methods with learnable weights that can be used in CV models. + +Base class +^^^^^^^^^^ .. currentmodule:: mlcolvar.core.stats .. autosummary:: @@ -40,11 +73,26 @@ These are the building blocks which are used to construct the CVs. :template: custom-class-template.rst Stats + +Linear methods +^^^^^^^^^^^^^^ +.. currentmodule:: mlcolvar.core.stats + +.. autosummary:: + :toctree: autosummary + :template: custom-class-template.rst + PCA LDA TICA -.. rubric:: Transform + +Transform +--------- +This module implements **non-learnable** pre/postprocessing tools + +Base class +^^^^^^^^^^ .. currentmodule:: mlcolvar.core.transform @@ -55,7 +103,9 @@ These are the building blocks which are used to construct the CVs. Transform -.. rubric:: Transform.descriptors +Descriptors +^^^^^^^^^^^ +This submodule implements several descriptors that can be computed starting from atomic positions. .. currentmodule:: mlcolvar.core.transform.descriptors @@ -69,7 +119,9 @@ These are the building blocks which are used to construct the CVs. EigsAdjMat MultipleDescriptors -.. rubric:: Transform.tools +Tools +^^^^^ +This submodule implements pre/postprocessing tools. .. currentmodule:: mlcolvar.core.transform.tools diff --git a/docs/api_data.rst b/docs/api_data.rst index 671cff12..a3e6be5d 100644 --- a/docs/api_data.rst +++ b/docs/api_data.rst @@ -1,6 +1,9 @@ Data ---- +General: dataset, module and loader +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + .. currentmodule:: mlcolvar.data This module contains the classes used for handling datasets and for feeding them to the Lightning trainer. @@ -11,4 +14,19 @@ This module contains the classes used for handling datasets and for feeding them DictDataset DictLoader - DictModule \ No newline at end of file + DictModule + +Graph specific tools +^^^^^^^^^^^^^^^^^^^^ +.. currentmodule:: mlcolvar.data.graph + +This module contains the classes used for handling and creating graphs. + +.. 
autosummary:: + :toctree: autosummary + :template: custom-class-template.rst + + AtomicNumberTable + Configuration + get_neighborhood + create_dataset_from_configurations \ No newline at end of file diff --git a/docs/api_explain.rst b/docs/api_explain.rst index 91352802..23f514f0 100644 --- a/docs/api_explain.rst +++ b/docs/api_explain.rst @@ -1,12 +1,16 @@ Explain ------- -.. rubric:: Sensitivity analysis +Sensitivity analysis +^^^^^^^^^^^^^^^^^^^^ Perform sensitivity analysis to identify feature relevances .. currentmodule:: mlcolvar.explain.sensitivity +Descriptors-based +"""""""""""""""""" + .. autosummary:: :toctree: autosummary :template: custom-class-template.rst @@ -14,7 +18,19 @@ Perform sensitivity analysis to identify feature relevances sensitivity_analysis plot_sensitivity -.. rubric:: Sparse linear model +Graph-based +""""""""""" +.. currentmodule:: mlcolvar.explain.graph_sensitivity + +.. autosummary:: + :toctree: autosummary + :template: custom-class-template.rst + + graph_node_sensitivity + + +Sparse linear models +^^^^^^^^^^^^^^^^^^^^ Use sparse models to approximate classification/regression tasks diff --git a/docs/api_utils.rst b/docs/api_utils.rst index 1fe87a1c..d9e5c5e9 100644 --- a/docs/api_utils.rst +++ b/docs/api_utils.rst @@ -1,7 +1,9 @@ Utils ----- -.. rubric:: Input/Output + +Input/Output +^^^^^^^^^^^^ Helper functions for loading dataframes (incl. PLUMED files) and directly creating datasets from them. @@ -14,7 +16,9 @@ Helper functions for loading dataframes (incl. PLUMED files) and directly creati load_dataframe create_dataset_from_files -.. rubric:: Time-lagged datasets + +Time-lagged datasets +^^^^^^^^^^^^^^^^^^^^ Create a dataset of pairs of time-lagged configurations. @@ -26,31 +30,35 @@ Create a dataset of pairs of time-lagged configurations. create_timelagged_dataset -.. rubric:: FES - -.. rubric:: Trainer -Functions used in conjunction with the lightning Trainer (e.g. logging, metrics...). 
+FES +^^^ +Compute (and plot) the free energy surface along the CVs. -.. currentmodule:: mlcolvar.utils.trainer +.. currentmodule:: mlcolvar.utils.fes .. autosummary:: :toctree: autosummary :template: custom-class-template.rst - MetricsCallback + compute_fes -Compute (and plot) the free energy surface along the CVs. -.. currentmodule:: mlcolvar.utils.fes +Trainer +^^^^^^^ +Functions used in conjunction with the lightning Trainer (e.g. logging, metrics...). + +.. currentmodule:: mlcolvar.utils.trainer .. autosummary:: :toctree: autosummary :template: custom-class-template.rst - compute_fes + MetricsCallback + -Plotting utils +Plot +^^^^ .. currentmodule:: mlcolvar.utils.plot diff --git a/docs/autosummary/mlcolvar.core.nn.FeedForward.rst b/docs/autosummary/mlcolvar.core.nn.FeedForward.rst index 1078460e..5ed3efa2 100644 --- a/docs/autosummary/mlcolvar.core.nn.FeedForward.rst +++ b/docs/autosummary/mlcolvar.core.nn.FeedForward.rst @@ -17,6 +17,7 @@ .. autosummary:: ~FeedForward.__init__ + ~FeedForward.backward ~FeedForward.forward @@ -29,27 +30,9 @@ .. 
autosummary:: - ~FeedForward.CHECKPOINT_HYPER_PARAMS_KEY - ~FeedForward.CHECKPOINT_HYPER_PARAMS_NAME - ~FeedForward.CHECKPOINT_HYPER_PARAMS_TYPE ~FeedForward.T_destination - ~FeedForward.automatic_optimization ~FeedForward.call_super_init - ~FeedForward.current_epoch - ~FeedForward.device - ~FeedForward.dtype ~FeedForward.dump_patches - ~FeedForward.example_input_array - ~FeedForward.fabric - ~FeedForward.global_rank - ~FeedForward.global_step - ~FeedForward.hparams - ~FeedForward.hparams_initial - ~FeedForward.local_rank - ~FeedForward.logger - ~FeedForward.loggers - ~FeedForward.on_gpu - ~FeedForward.trainer ~FeedForward.training diff --git a/docs/autosummary/mlcolvar.core.nn.graph.BaseGNN.rst b/docs/autosummary/mlcolvar.core.nn.graph.BaseGNN.rst new file mode 100644 index 00000000..0d3f69da --- /dev/null +++ b/docs/autosummary/mlcolvar.core.nn.graph.BaseGNN.rst @@ -0,0 +1,41 @@ +mlcolvar.core.nn.graph.BaseGNN +============================== + +.. currentmodule:: mlcolvar.core.nn.graph + +.. autoclass:: BaseGNN + :members: + :show-inheritance: + :inherited-members: Module,LightningModule + + + .. automethod:: __init__ + + + .. rubric:: Methods + + .. autosummary:: + + ~BaseGNN.__init__ + ~BaseGNN.embed_edge + + + + +.. + + + .. rubric:: Attributes + + .. autosummary:: + + ~BaseGNN.T_destination + ~BaseGNN.call_super_init + ~BaseGNN.dump_patches + ~BaseGNN.in_features + ~BaseGNN.out_features + ~BaseGNN.training + + + + \ No newline at end of file diff --git a/docs/autosummary/mlcolvar.core.nn.graph.GVPModel.rst b/docs/autosummary/mlcolvar.core.nn.graph.GVPModel.rst new file mode 100644 index 00000000..68ba5df3 --- /dev/null +++ b/docs/autosummary/mlcolvar.core.nn.graph.GVPModel.rst @@ -0,0 +1,41 @@ +mlcolvar.core.nn.graph.GVPModel +=============================== + +.. currentmodule:: mlcolvar.core.nn.graph + +.. autoclass:: GVPModel + :members: + :show-inheritance: + :inherited-members: Module,LightningModule + + + .. automethod:: __init__ + + + .. 
rubric:: Methods + + .. autosummary:: + + ~GVPModel.__init__ + ~GVPModel.forward + + + + +.. + + + .. rubric:: Attributes + + .. autosummary:: + + ~GVPModel.T_destination + ~GVPModel.call_super_init + ~GVPModel.dump_patches + ~GVPModel.in_features + ~GVPModel.out_features + ~GVPModel.training + + + + \ No newline at end of file diff --git a/docs/autosummary/mlcolvar.core.nn.graph.SchNetModel.rst b/docs/autosummary/mlcolvar.core.nn.graph.SchNetModel.rst new file mode 100644 index 00000000..05a5e532 --- /dev/null +++ b/docs/autosummary/mlcolvar.core.nn.graph.SchNetModel.rst @@ -0,0 +1,42 @@ +mlcolvar.core.nn.graph.SchNetModel +================================== + +.. currentmodule:: mlcolvar.core.nn.graph + +.. autoclass:: SchNetModel + :members: + :show-inheritance: + :inherited-members: Module,LightningModule + + + .. automethod:: __init__ + + + .. rubric:: Methods + + .. autosummary:: + + ~SchNetModel.__init__ + ~SchNetModel.forward + ~SchNetModel.reset_parameters + + + + +.. + + + .. rubric:: Attributes + + .. autosummary:: + + ~SchNetModel.T_destination + ~SchNetModel.call_super_init + ~SchNetModel.dump_patches + ~SchNetModel.in_features + ~SchNetModel.out_features + ~SchNetModel.training + + + + \ No newline at end of file diff --git a/docs/autosummary/mlcolvar.core.nn.graph.radial.RadialEmbeddingBlock.rst b/docs/autosummary/mlcolvar.core.nn.graph.radial.RadialEmbeddingBlock.rst new file mode 100644 index 00000000..962513c3 --- /dev/null +++ b/docs/autosummary/mlcolvar.core.nn.graph.radial.RadialEmbeddingBlock.rst @@ -0,0 +1,39 @@ +mlcolvar.core.nn.graph.radial.RadialEmbeddingBlock +================================================== + +.. currentmodule:: mlcolvar.core.nn.graph.radial + +.. autoclass:: RadialEmbeddingBlock + :members: + :show-inheritance: + :inherited-members: Module,LightningModule + + + .. automethod:: __init__ + + + .. rubric:: Methods + + .. autosummary:: + + ~RadialEmbeddingBlock.__init__ + ~RadialEmbeddingBlock.forward + + + + +.. + + + .. 
rubric:: Attributes + + .. autosummary:: + + ~RadialEmbeddingBlock.T_destination + ~RadialEmbeddingBlock.call_super_init + ~RadialEmbeddingBlock.dump_patches + ~RadialEmbeddingBlock.training + + + + \ No newline at end of file diff --git a/docs/autosummary/mlcolvar.cvs.AutoEncoderCV.rst b/docs/autosummary/mlcolvar.cvs.AutoEncoderCV.rst index 13d32fab..93425727 100644 --- a/docs/autosummary/mlcolvar.cvs.AutoEncoderCV.rst +++ b/docs/autosummary/mlcolvar.cvs.AutoEncoderCV.rst @@ -32,10 +32,11 @@ .. autosummary:: - ~AutoEncoderCV.BLOCKS ~AutoEncoderCV.CHECKPOINT_HYPER_PARAMS_KEY ~AutoEncoderCV.CHECKPOINT_HYPER_PARAMS_NAME ~AutoEncoderCV.CHECKPOINT_HYPER_PARAMS_TYPE + ~AutoEncoderCV.DEFAULT_BLOCKS + ~AutoEncoderCV.MODEL_BLOCKS ~AutoEncoderCV.T_destination ~AutoEncoderCV.automatic_optimization ~AutoEncoderCV.call_super_init diff --git a/docs/autosummary/mlcolvar.cvs.BaseCV.rst b/docs/autosummary/mlcolvar.cvs.BaseCV.rst index 33fa73c0..4d368e28 100644 --- a/docs/autosummary/mlcolvar.cvs.BaseCV.rst +++ b/docs/autosummary/mlcolvar.cvs.BaseCV.rst @@ -22,6 +22,7 @@ ~BaseCV.forward_cv ~BaseCV.initialize_blocks ~BaseCV.initialize_transforms + ~BaseCV.parse_model ~BaseCV.parse_options ~BaseCV.setup ~BaseCV.test_step @@ -37,6 +38,8 @@ .. autosummary:: + ~BaseCV.DEFAULT_BLOCKS + ~BaseCV.MODEL_BLOCKS ~BaseCV.example_input_array ~BaseCV.n_cvs ~BaseCV.optimizer_name diff --git a/docs/autosummary/mlcolvar.cvs.Committor.rst b/docs/autosummary/mlcolvar.cvs.Committor.rst index 3aa14710..8b91866a 100644 --- a/docs/autosummary/mlcolvar.cvs.Committor.rst +++ b/docs/autosummary/mlcolvar.cvs.Committor.rst @@ -17,6 +17,7 @@ .. autosummary:: ~Committor.__init__ + ~Committor.forward_nn ~Committor.training_step @@ -29,10 +30,11 @@ .. 
autosummary:: - ~Committor.BLOCKS ~Committor.CHECKPOINT_HYPER_PARAMS_KEY ~Committor.CHECKPOINT_HYPER_PARAMS_NAME ~Committor.CHECKPOINT_HYPER_PARAMS_TYPE + ~Committor.DEFAULT_BLOCKS + ~Committor.MODEL_BLOCKS ~Committor.T_destination ~Committor.automatic_optimization ~Committor.call_super_init diff --git a/docs/autosummary/mlcolvar.cvs.DeepLDA.rst b/docs/autosummary/mlcolvar.cvs.DeepLDA.rst index 2d1eeb29..d4113105 100644 --- a/docs/autosummary/mlcolvar.cvs.DeepLDA.rst +++ b/docs/autosummary/mlcolvar.cvs.DeepLDA.rst @@ -32,10 +32,11 @@ .. autosummary:: - ~DeepLDA.BLOCKS ~DeepLDA.CHECKPOINT_HYPER_PARAMS_KEY ~DeepLDA.CHECKPOINT_HYPER_PARAMS_NAME ~DeepLDA.CHECKPOINT_HYPER_PARAMS_TYPE + ~DeepLDA.DEFAULT_BLOCKS + ~DeepLDA.MODEL_BLOCKS ~DeepLDA.T_destination ~DeepLDA.automatic_optimization ~DeepLDA.call_super_init diff --git a/docs/autosummary/mlcolvar.cvs.DeepTDA.rst b/docs/autosummary/mlcolvar.cvs.DeepTDA.rst index 9ebe1355..011f71b2 100644 --- a/docs/autosummary/mlcolvar.cvs.DeepTDA.rst +++ b/docs/autosummary/mlcolvar.cvs.DeepTDA.rst @@ -29,10 +29,11 @@ .. autosummary:: - ~DeepTDA.BLOCKS ~DeepTDA.CHECKPOINT_HYPER_PARAMS_KEY ~DeepTDA.CHECKPOINT_HYPER_PARAMS_NAME ~DeepTDA.CHECKPOINT_HYPER_PARAMS_TYPE + ~DeepTDA.DEFAULT_BLOCKS + ~DeepTDA.MODEL_BLOCKS ~DeepTDA.T_destination ~DeepTDA.automatic_optimization ~DeepTDA.call_super_init diff --git a/docs/autosummary/mlcolvar.cvs.DeepTICA.rst b/docs/autosummary/mlcolvar.cvs.DeepTICA.rst index 8dc7e0a7..6e8b0929 100644 --- a/docs/autosummary/mlcolvar.cvs.DeepTICA.rst +++ b/docs/autosummary/mlcolvar.cvs.DeepTICA.rst @@ -31,10 +31,11 @@ .. 
autosummary:: - ~DeepTICA.BLOCKS ~DeepTICA.CHECKPOINT_HYPER_PARAMS_KEY ~DeepTICA.CHECKPOINT_HYPER_PARAMS_NAME ~DeepTICA.CHECKPOINT_HYPER_PARAMS_TYPE + ~DeepTICA.DEFAULT_BLOCKS + ~DeepTICA.MODEL_BLOCKS ~DeepTICA.T_destination ~DeepTICA.automatic_optimization ~DeepTICA.call_super_init diff --git a/docs/autosummary/mlcolvar.cvs.RegressionCV.rst b/docs/autosummary/mlcolvar.cvs.RegressionCV.rst index 54edee56..e9e0422c 100644 --- a/docs/autosummary/mlcolvar.cvs.RegressionCV.rst +++ b/docs/autosummary/mlcolvar.cvs.RegressionCV.rst @@ -29,10 +29,11 @@ .. autosummary:: - ~RegressionCV.BLOCKS ~RegressionCV.CHECKPOINT_HYPER_PARAMS_KEY ~RegressionCV.CHECKPOINT_HYPER_PARAMS_NAME ~RegressionCV.CHECKPOINT_HYPER_PARAMS_TYPE + ~RegressionCV.DEFAULT_BLOCKS + ~RegressionCV.MODEL_BLOCKS ~RegressionCV.T_destination ~RegressionCV.automatic_optimization ~RegressionCV.call_super_init diff --git a/docs/autosummary/mlcolvar.cvs.VariationalAutoEncoderCV.rst b/docs/autosummary/mlcolvar.cvs.VariationalAutoEncoderCV.rst index dbfa2c47..5ac5a041 100644 --- a/docs/autosummary/mlcolvar.cvs.VariationalAutoEncoderCV.rst +++ b/docs/autosummary/mlcolvar.cvs.VariationalAutoEncoderCV.rst @@ -32,10 +32,11 @@ .. autosummary:: - ~VariationalAutoEncoderCV.BLOCKS ~VariationalAutoEncoderCV.CHECKPOINT_HYPER_PARAMS_KEY ~VariationalAutoEncoderCV.CHECKPOINT_HYPER_PARAMS_NAME ~VariationalAutoEncoderCV.CHECKPOINT_HYPER_PARAMS_TYPE + ~VariationalAutoEncoderCV.DEFAULT_BLOCKS + ~VariationalAutoEncoderCV.MODEL_BLOCKS ~VariationalAutoEncoderCV.T_destination ~VariationalAutoEncoderCV.automatic_optimization ~VariationalAutoEncoderCV.call_super_init diff --git a/docs/autosummary/mlcolvar.data.graph.AtomicNumberTable.rst b/docs/autosummary/mlcolvar.data.graph.AtomicNumberTable.rst new file mode 100644 index 00000000..881eff0e --- /dev/null +++ b/docs/autosummary/mlcolvar.data.graph.AtomicNumberTable.rst @@ -0,0 +1,34 @@ +mlcolvar.data.graph.AtomicNumberTable +===================================== + +.. 
currentmodule:: mlcolvar.data.graph + +.. autoclass:: AtomicNumberTable + :members: + :show-inheritance: + :inherited-members: Module,LightningModule + + + .. automethod:: __init__ + + + .. rubric:: Methods + + .. autosummary:: + + ~AtomicNumberTable.__init__ + ~AtomicNumberTable.from_zs + ~AtomicNumberTable.index_to_symbol + ~AtomicNumberTable.index_to_z + ~AtomicNumberTable.z_to_index + ~AtomicNumberTable.zs_to_indices + + + + +.. + + + + + \ No newline at end of file diff --git a/docs/autosummary/mlcolvar.data.graph.Configuration.rst b/docs/autosummary/mlcolvar.data.graph.Configuration.rst new file mode 100644 index 00000000..41e2a5fd --- /dev/null +++ b/docs/autosummary/mlcolvar.data.graph.Configuration.rst @@ -0,0 +1,43 @@ +mlcolvar.data.graph.Configuration +================================= + +.. currentmodule:: mlcolvar.data.graph + +.. autoclass:: Configuration + :members: + :show-inheritance: + :inherited-members: Module,LightningModule + + + .. automethod:: __init__ + + + .. rubric:: Methods + + .. autosummary:: + + ~Configuration.__init__ + + + + +.. + + + .. rubric:: Attributes + + .. autosummary:: + + ~Configuration.environment + ~Configuration.system + ~Configuration.weight + ~Configuration.atomic_numbers + ~Configuration.positions + ~Configuration.cell + ~Configuration.pbc + ~Configuration.node_labels + ~Configuration.graph_labels + + + + \ No newline at end of file diff --git a/docs/autosummary/mlcolvar.data.graph.create_dataset_from_configurations.rst b/docs/autosummary/mlcolvar.data.graph.create_dataset_from_configurations.rst new file mode 100644 index 00000000..d07941ab --- /dev/null +++ b/docs/autosummary/mlcolvar.data.graph.create_dataset_from_configurations.rst @@ -0,0 +1,23 @@ +mlcolvar.data.graph.create\_dataset\_from\_configurations +========================================================= + +.. currentmodule:: mlcolvar.data.graph + +.. 
autoclass:: create_dataset_from_configurations + :members: + :show-inheritance: + :inherited-members: Module,LightningModule + + + .. automethod:: __init__ + + + + + +.. + + + + + \ No newline at end of file diff --git a/docs/autosummary/mlcolvar.data.graph.get_neighborhood.rst b/docs/autosummary/mlcolvar.data.graph.get_neighborhood.rst new file mode 100644 index 00000000..0602c951 --- /dev/null +++ b/docs/autosummary/mlcolvar.data.graph.get_neighborhood.rst @@ -0,0 +1,23 @@ +mlcolvar.data.graph.get\_neighborhood +===================================== + +.. currentmodule:: mlcolvar.data.graph + +.. autoclass:: get_neighborhood + :members: + :show-inheritance: + :inherited-members: Module,LightningModule + + + .. automethod:: __init__ + + + + + +.. + + + + + \ No newline at end of file diff --git a/docs/autosummary/mlcolvar.explain.graph_sensitivity.graph_node_sensitivity.rst b/docs/autosummary/mlcolvar.explain.graph_sensitivity.graph_node_sensitivity.rst new file mode 100644 index 00000000..1da846f6 --- /dev/null +++ b/docs/autosummary/mlcolvar.explain.graph_sensitivity.graph_node_sensitivity.rst @@ -0,0 +1,23 @@ +mlcolvar.explain.graph\_sensitivity.graph\_node\_sensitivity +============================================================ + +.. currentmodule:: mlcolvar.explain.graph_sensitivity + +.. autoclass:: graph_node_sensitivity + :members: + :show-inheritance: + :inherited-members: Module,LightningModule + + + .. automethod:: __init__ + + + + + +.. 
+ + + + + \ No newline at end of file diff --git a/docs/notebooks/examples/ex_TPI-DeepTDA.ipynb b/docs/notebooks/examples/ex_TPI-DeepTDA.ipynb index 52805569..c7ab7826 100644 --- a/docs/notebooks/examples/ex_TPI-DeepTDA.ipynb +++ b/docs/notebooks/examples/ex_TPI-DeepTDA.ipynb @@ -167,7 +167,7 @@ "target_sigmas = [0.2, 0.2]\n", "nn_layers = [45,24,12,1]\n", "# MODEL\n", - "model = DeepTDA(n_states=n_states, n_cvs=1,target_centers=target_centers, target_sigmas=target_sigmas, layers=nn_layers)" + "model = DeepTDA(n_states=n_states, n_cvs=1,target_centers=target_centers, target_sigmas=target_sigmas, model=nn_layers)" ] }, { @@ -417,7 +417,7 @@ "target_sigmas = [0.2, 1.5, 0.2]\n", "nn_layers = [45,24,12,1]\n", "# MODEL\n", - "model = DeepTDA(n_states=n_states, n_cvs=1,target_centers=target_centers, target_sigmas=target_sigmas, layers=nn_layers)" + "model = DeepTDA(n_states=n_states, n_cvs=1,target_centers=target_centers, target_sigmas=target_sigmas, model=nn_layers)" ] }, { diff --git a/docs/notebooks/examples/ex_committor.ipynb b/docs/notebooks/examples/ex_committor.ipynb index c07e874d..b3d41279 100644 --- a/docs/notebooks/examples/ex_committor.ipynb +++ b/docs/notebooks/examples/ex_committor.ipynb @@ -323,7 +323,7 @@ " 'nn' : {'activation' : 'tanh'}}\n", "\n", "# initialize model\n", - "model = Committor(layers=[45, 32, 32, 1],\n", + "model = Committor(model=[45, 32, 32, 1],\n", " atomic_masses=atomic_masses,\n", " alpha=1e1,\n", " options=options, \n", @@ -807,7 +807,7 @@ " 'nn' : {'activation' : 'tanh'}}\n", "\n", "# initialize model\n", - "model = Committor(layers=[45, 32, 32, 1],\n", + "model = Committor(model=[45, 32, 32, 1],\n", " atomic_masses=atomic_masses,\n", " alpha=1e1,\n", " options=options, \n", diff --git a/docs/notebooks/paper_experiments/paper_2_supervised.ipynb b/docs/notebooks/paper_experiments/paper_2_supervised.ipynb index 38842fde..e5a524c4 100644 --- a/docs/notebooks/paper_experiments/paper_2_supervised.ipynb +++ 
b/docs/notebooks/paper_experiments/paper_2_supervised.ipynb @@ -190,7 +190,7 @@ "options = {'nn' : {'activation' : 'shifted_softplus'} }\n", "# MODEL\n", "if run_calculations:\n", - " model = DeepTDA(n_states=n_states, n_cvs=1,target_centers=target_centers, target_sigmas=target_sigmas, layers=nn_layers)\n", + " model = DeepTDA(n_states=n_states, n_cvs=1,target_centers=target_centers, target_sigmas=target_sigmas, model=nn_layers)\n", "else:\n", " model = torch.jit.load(f'{RESULTS_FOLDER}/model_deepTDA.pt')" ] diff --git a/docs/notebooks/paper_experiments/paper_4_multitask.ipynb b/docs/notebooks/paper_experiments/paper_4_multitask.ipynb index 03ca22bd..ea4deac4 100644 --- a/docs/notebooks/paper_experiments/paper_4_multitask.ipynb +++ b/docs/notebooks/paper_experiments/paper_4_multitask.ipynb @@ -216,7 +216,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -598,7 +598,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ diff --git a/docs/notebooks/tutorials/adv_gnn_based_cvs.ipynb b/docs/notebooks/tutorials/adv_gnn_based_cvs.ipynb new file mode 100644 index 00000000..d7ac2d9c --- /dev/null +++ b/docs/notebooks/tutorials/adv_gnn_based_cvs.ipynb @@ -0,0 +1,7487 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Using `mlcolvar` with graph neural networks (GNNs)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/luigibonati/mlcolvar/blob/main/docs/notebooks/tutorials/adv_gnn_based_cvs.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### NOTE\n", + "Most of the workings of the library are the same using standard feed-forward-nn-based machine-learning CVs or GNN-based ones.\n", + "Thus, it is recommended to first go through the basic tutorials for the standard scenario before 
moving to this tutorial." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Feed-Forward-based CVs vs GNN-based CVs\n", + "\n", + "The default setting of `mlcolvar` is to represent the CVs as the output nodes of Feed-Forward Neural Networks (FFNNs or NNs, for simplicity) which take as input a set of physical descriptors (e.g., distances, angles, etc.).\n", + "The code is thus designed to reflect this choice, with the default values of the classes set to initialize the CV model in this framework, which is currently the most widespread in the field of machine-learning CVs and suits the needs of most users.\n", + "\n", + "However, recently a different approach has been proposed, in which the CVs are represented as Graph Neural Networks (GNNs) which directly take as input the Cartesian coordinates of the atoms in the studied system and return the CV space after a node-pooling operation on the output layer.\n", + "This approach is thus descriptor-free and goes in the direction of a more automated way of designing CVs.\n", + "Unfortunately, it typically comes at a higher computational cost (i.e., slower training and evaluation of the CV) and the underlying codebase is more complex (i.e., more complex models and data formats).\n", + "\n", + "In this tutorial, we show how GNN models can be used within `mlcolvar` to build CVs using the implemented CV methods.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# Colab setup\n", + "import os\n", + "\n", + "if os.getenv(\"COLAB_RELEASE_TAG\"):\n", + " import subprocess\n", + " subprocess.run('wget https://raw.githubusercontent.com/luigibonati/mlcolvar/main/colab_setup.sh', shell=True)\n", + " cmd = subprocess.run('bash colab_setup.sh TUTORIAL', shell=True, stdout=subprocess.PIPE)\n", + " print(cmd.stdout.decode('utf-8'))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Outline\n", + "Typically, the 
process of constructing a GNN-based CV requires the following ingredients:\n", + "1. A **dataset** of attributed connected graphs (nodes and edges), which are constructed from the atomic positions.\n", + "2. A **GNN-model** to represent the CV. Different architectures can be used in this regard.\n", + "3. A **CV method** and the associated **loss function**. All the methods implemented for *standard* machine-learning CVs are available, except for those based on autoencoders. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load data\n", + "#### The inputs of GNN CVs\n", + "The inputs of GNN models are attributed and connected graphs, in which nodes (representing the atoms, in our case) are connected by edges (the lines of the graph).\n", + "Nodes and edges are assigned scalar and, possibly, vector features that are then processed through the layers of the GNN.\n", + "\n", + "In the context of GNN-CVs, such graphs are typically created directly from the atomic coordinates in a trajectory file, and the connectivity between the nodes is determined according to a radial `cutoff`.\n", + "\n", + "#### Truncated graphs\n", + "In some cases, graphs can be built focusing on a subset of the whole system, e.g., a molecule on a surface, while still taking into account the interaction with the environment, e.g., the surface.\n", + "In this case, only the nodes from the `system_selection` will be used for the final pooling, whereas the nodes from the `enviroment_selection` will be used only to update the information through the layers.\n", + "Moreover, to reduce the computational cost, only the atoms closer to the `system_selection` atoms will be included in the graphs, according to the set `cutoff` and a `buffer` value to ensure stability and continuity. 
\n", + "For example, this setup is useful when treating solvent or surface interactions.\n", + "\n", + "#### Create dataset from trajectory files\n", + "To make this process easier, `mlcolvar` provides a utility function that does this under the hood: `create_dataset_from_trajectories`, which is analogous to the `create_dataset_from_files` used with descriptors.\n", + "The loading process is built on the external library [`MDTraj`](https://www.mdtraj.org/), which can natively load the most common trajectory+topology formats used in biophysics.\n", + "On the other hand, for less-bio applications (e.g., solids, surfaces, molecules) we recommend using the `.xyz` file format.\n", + "\n", + "One advantage of MDTraj is that it comes with a simple and user-friendly syntax for atom selection, which can also be used here.\n", + "\n", + "Here, as an example, we load some data about states A and B of alanine dipeptide." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dataset info:\n", + " DictDataset( \"data_list\": 4000, \"z_table\": [6, 7, 8], \"cutoff\": 10.0, \"used_idx\": tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), \"used_names\": [ACE1-CH3, ACE1-C, ACE1-O, ALA2-N, ALA2-CA, ALA2-CB, ALA2-C, ALA2-O, NME3-N, NME3-C], \"data_type\": graphs )\n", + "\n", + "Datamodule info:\n", + " DictModule(dataset -> DictDataset( \"data_list\": 4000, \"z_table\": [6, 7, 8], \"cutoff\": 10.0, \"used_idx\": tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), \"used_names\": [ACE1-CH3, ACE1-C, ACE1-O, ALA2-N, ALA2-CA, ALA2-CB, ALA2-C, ALA2-O, NME3-N, NME3-C], \"data_type\": graphs ),\n", + "\t\t train_loader -> DictLoader(length=0.8, batch_size=4000, shuffle=True),\n", + "\t\t valid_loader -> DictLoader(length=0.2, batch_size=4000, shuffle=True))\n" + ] + } + ], + "source": [ + "from mlcolvar.data import DictModule\n", + "from mlcolvar.utils.io import create_dataset_from_trajectories\n", + "\n", + "# 
loading arguments \n", + "# same as for load_dataframe\n", + "load_args = [{'start' : 0, 'stop' : 10000, 'stride' : 5},\n", + " {'start' : 0, 'stop' : 10000, 'stride' : 5}]\n", + "\n", + "# create dataset\n", + "dataset = create_dataset_from_trajectories(\n", + " trajectories=[\"alad_A.trr\", \n", + " \"alad_B.trr\"],\n", + " top=\"alad.gro\", \n", + " folder=\"data/alanine_gnn\", \n", + " cutoff=10.0, # Angstrom \n", + " labels=None, \n", + " system_selection='all and not type H',\n", + " show_progress=False,\n", + " load_args=load_args,\n", + " lengths_conversion=10.0, # MDTraj uses nm by default, we use Angstroms\n", + " )\n", + "print('Dataset info:\\n', dataset, end=\"\\n\\n\")\n", + "\n", + "# load dataset into a DictModule\n", + "datamodule = DictModule(dataset=dataset)\n", + "print('Datamodule info:\\n', datamodule)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Accessing graph data\n", + "The graphs are stored as `torch_geometric.data.Data` objects in the usual `DictDataset`: the information about each graph entry (e.g., node positions, edges, weights, labels, etc.) is stored under the key `data_list`, while the information common to all the graphs (e.g., map from types to chemical species, cutoff) is stored in the `metadata` attribute dictionary." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Example of a graph entry:\n", + " Data(edge_index=[2, 90], shifts=[90, 3], unit_shifts=[90, 3], positions=[10, 3], cell=[3, 3], node_attrs=[10, 3], graph_labels=[1, 1], n_system=[1, 1], n_env=[1, 1], weight=1.0, names_idx=[10])\n", + "\n", + "Dataset metadata:\n", + " {'z_table': [6, 7, 8], 'cutoff': 10.0, 'used_idx': tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), 'used_names': [ACE1-CH3, ACE1-C, ACE1-O, ALA2-N, ALA2-CA, ALA2-CB, ALA2-C, ALA2-O, NME3-N, NME3-C], 'data_type': 'graphs'}\n" + ] + } + ], + "source": [ + "print('Example of a graph entry:\\n', dataset['data_list'][0], end='\\n\\n')\n", + "print('Dataset metadata:\\n', dataset.metadata)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Initializing the GNN model\n", + "Unlike the procedure with FFNNs, here the model is initialized **outside** the CV class and is only later passed to it as an input.\n", + "GNN architectures are indeed much more complex than FFNNs and have many parameters that can be set.\n", + "In addition, when introducing GNN models into the code, we kept the standard CVs as the default, as they still cover the needs of most users.\n", + "\n", + "Here, for example, we initialize a `SchNetModel`.\n", + "Many other architectures are available in [`pytorch_geometric`](https://pytorch-geometric.readthedocs.io/en/latest/) and can be readily adapted to this library.\n", + "\n", + "#### NOTE\n", + "As the input graphs are built with the dataset and then processed by the GNN model, it is wise to initialize the model by directly referring to the values stored in `dataset.metadata` (e.g., cutoff, z_table)." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from mlcolvar.core.nn.graph.schnet import SchNetModel\n", + "\n", + "gnn_model = 
SchNetModel(n_out=1,\n", + " cutoff=dataset.metadata['cutoff'],\n", + " atomic_numbers=dataset.metadata['z_table'],\n", + " pooling_operation=\"mean\",\n", + " n_bases=16,\n", + " n_layers=2,\n", + " n_filters=16,\n", + " n_hidden_channels=16,\n", + " w_out_after_pool=True,\n", + " aggr='mean'\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Initializing CV class\n", + "The initialization of the CV class is almost identical to the standard case, with the only difference that we provide the initialized GNN object as model.\n", + "\n", + "Here, for example, we use the `DeepTDA` CV." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/etrizio@iit.local/Bin/miniconda3/envs/graph_mlcolvar_test_2.5/lib/python3.9/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.\n" + ] + } + ], + "source": [ + "import torch\n", + "from mlcolvar.cvs import DeepTDA\n", + "\n", + "# we can still set the options for the optimizer the usual way\n", + "# options for the BLOCKS of the cv are disabled when passing an external model\n", + "options = {'optimizer' : {'lr' : 1e-3},\n", + " 'lr_scheduler': {\n", + " 'scheduler': torch.optim.lr_scheduler.ExponentialLR,\n", + " 'gamma': 0.9999}\n", + " }\n", + "\n", + "model = DeepTDA(n_states=2,\n", + " n_cvs=1,\n", + " target_centers=[-7, 7],\n", + " target_sigmas=[0.2, 0.2],\n", + " model=gnn_model,\n", + " options=options)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training the CV\n", + "Here, everything works the same as for standard CVs!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ad302baf46f44ca4b3aa26af17ab165c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: | | 0/? [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(1,1,figsize=(4,3))\n", + "plot_metrics(metrics.metrics,\n", + " keys=['train_loss', 'valid_loss'],\n", + " colors=['fessa1', 'fessa5'],\n", + " yscale='linear',\n", + " ax = ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Testing the model\n", + "As the graph data are stored as `torch_geometric.data.Data` objects, they need to be loaded using a loader object.\n", + "For convenience, both `DictDataset` and `DictModule` implement a `.get_graph_inputs` method that does this, so that one can simply evaluate the model by calling either:\n", + "- `model(dataset.get_graph_inputs())` --> Evaluates the model on the **whole dataset**\n", + "- `model(datamodule.get_graph_inputs(\"train\"))` --> Evaluates the model on either the **train or valid dataset**" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, axs = plt.subplots(1,2, figsize=(10,3))\n", + "\n", + "ax = axs[0]\n", + "out_graph = model(dataset.get_graph_inputs())\n", + "ax.hist(out_graph.detach().squeeze(), bins=100)\n", + "ax.set_title('From Dataset')\n", + "ax.set_xlabel('GNN CV')\n", + "ax.set_ylim(0,850)\n", + "\n", + "ax = axs[1]\n", + "out_graph = model(datamodule.get_graph_inputs(\"train\"))\n", + "ax.hist(out_graph.detach().squeeze(), bins=100)\n", + "out_graph = model(datamodule.get_graph_inputs(\"valid\"))\n", + "ax.hist(out_graph.detach().squeeze(), bins=100)\n", + "\n", + "ax.set_title('From Datamodule')\n", + "ax.set_xlabel('GNN CV')\n", + "ax.set_ylim(0,850)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Save the model to TorchScript\n", + "As for standard CVs, the frozen model can be exported to TorchScript using the `Lightning` utility `to_torchscript` with `method='trace'`." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/etrizio@iit.local/Bin/dev/mlcolvar/mlcolvar/data/datamodule.py:322: UserWarning: Length of split at index 1 is 0. 
This might result in an empty dataset.\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "traced_model = model.to_torchscript('gnn_model.pt', method='trace')\n", + "\n", + "# we can also check the outputs coincide\n", + "torch.allclose(model(dataset.get_graph_inputs()), traced_model(dataset.get_graph_inputs()))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "graph_mlcolvar_test_2.5", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/notebooks/tutorials/adv_newcv_scratch.ipynb b/docs/notebooks/tutorials/adv_newcv_scratch.ipynb index 972775f1..b405b274 100644 --- a/docs/notebooks/tutorials/adv_newcv_scratch.ipynb +++ b/docs/notebooks/tutorials/adv_newcv_scratch.ipynb @@ -53,7 +53,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -72,7 +72,7 @@ "from mlcolvar.cvs import BaseCV\n", "\n", "class AutoEncoderCV(BaseCV, lightning.LightningModule):\n", - " BLOCKS = ['norm_in','encoder','decoder'] " + " DEFAULT_BLOCKS = ['norm_in','encoder','decoder'] " ] }, { @@ -87,7 +87,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -104,7 +104,7 @@ " with the input 'data'.\n", " \"\"\"\n", " \n", - " BLOCKS = ['norm_in','encoder','decoder'] " + " DEFAULT_BLOCKS = ['norm_in','encoder','decoder'] " ] }, { @@ -136,12 +136,12 @@ }, { "cell_type": "code", - 
"execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class AutoEncoderCV(BaseCV, lightning.LightningModule):\n", - " BLOCKS = ['norm_in','encoder','decoder'] \n", + " DEFAULT_BLOCKS = ['norm_in','encoder','decoder'] \n", "\n", " def __init__(self,\n", "# ================================================ LOOK HERE 0.0 ================================================ \n", @@ -165,7 +165,7 @@ " Available blocks: ['norm_in', 'encoder','decoder'].\n", " Set 'block_name' = None or False to turn off that block\n", " \"\"\"\n", - " super().__init__(in_features=encoder_layers[0], out_features=encoder_layers[-1], **kwargs)\n", + " super().__init__(model=encoder_layers, **kwargs)\n", " \n", "# ================================================ LOOK HERE 0.0 ================================================ \n" ] @@ -185,19 +185,19 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class AutoEncoderCV(BaseCV, lightning.LightningModule):\n", - " BLOCKS = ['norm_in','encoder','decoder'] \n", + " DEFAULT_BLOCKS = ['norm_in','encoder','decoder'] \n", " \n", " def __init__(self,\n", " encoder_layers : list, \n", " decoder_layers : list = None, \n", " options : dict = None, \n", " **kwargs):\n", - " super().__init__(in_features=encoder_layers[0], out_features=encoder_layers[-1], **kwargs)\n", + " super().__init__(model=encoder_layers, **kwargs)\n", "\n", "# ================================================ LOOK HERE 0.0 ================================================ \n", " \n", @@ -224,21 +224,21 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from mlcolvar.core.loss import MSELoss\n", "\n", "class AutoEncoderCV(BaseCV, lightning.LightningModule):\n", - " BLOCKS = ['norm_in','encoder','decoder'] \n", + " DEFAULT_BLOCKS = ['norm_in','encoder','decoder'] \n", " \n", " def __init__(self,\n", " 
encoder_layers : list, \n", " decoder_layers : list = None, \n", " options : dict = None, \n", " **kwargs):\n", - " super().__init__(in_features=encoder_layers[0], out_features=encoder_layers[-1], **kwargs)\n", + " super().__init__(model=encoder_layers, **kwargs)\n", "\n", " # ======= OPTIONS ======= \n", " # parse and sanitize\n", @@ -283,7 +283,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -291,14 +291,14 @@ "from mlcolvar.core.transform import Normalization\n", "\n", "class AutoEncoderCV(BaseCV, lightning.LightningModule):\n", - " BLOCKS = ['norm_in','encoder','decoder'] \n", + " DEFAULT_BLOCKS = ['norm_in','encoder','decoder'] \n", " \n", " def __init__(self,\n", " encoder_layers : list, \n", " decoder_layers : list = None, \n", " options : dict = None, \n", " **kwargs):\n", - " super().__init__(in_features=encoder_layers[0], out_features=encoder_layers[-1], **kwargs)\n", + " super().__init__(model=encoder_layers, **kwargs)\n", "\n", " # ======= OPTIONS ======= \n", " # parse and sanitize\n", @@ -425,7 +425,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -442,7 +442,7 @@ " with the input 'data'.\n", " \"\"\"\n", " \n", - " BLOCKS = ['norm_in','encoder','decoder'] \n", + " DEFAULT_BLOCKS = ['norm_in','encoder','decoder'] \n", " \n", " def __init__(self,\n", " encoder_layers : list, \n", @@ -465,7 +465,7 @@ " Available blocks: ['norm_in', 'encoder','decoder'].\n", " Set 'block_name' = None or False to turn off that block\n", " \"\"\"\n", - " super().__init__(in_features=encoder_layers[0], out_features=encoder_layers[-1], **kwargs)\n", + " super().__init__(model=encoder_layers, **kwargs)\n", "\n", " # ======= OPTIONS ======= \n", " # parse and sanitize\n", @@ -625,7 +625,7 @@ ], "metadata": { "kernelspec": { - "display_name": "pytorch", + "display_name": "graph_mlcolvar_test", "language": "python", 
"name": "python3" }, @@ -639,14 +639,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.9.18" }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": "1cbeac1d7079eaeba64f3210ccac5ee24400128e300a45ae35eee837885b08b3" - } - } + "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 diff --git a/docs/notebooks/tutorials/adv_preprocessing.ipynb b/docs/notebooks/tutorials/adv_preprocessing.ipynb index dd0b4a7d..766467b8 100644 --- a/docs/notebooks/tutorials/adv_preprocessing.ipynb +++ b/docs/notebooks/tutorials/adv_preprocessing.ipynb @@ -247,7 +247,7 @@ "source": [ "from mlcolvar.cvs import RegressionCV\n", "\n", - "model = RegressionCV(layers=[2,10,10,1], \n", + "model = RegressionCV(model=[2,10,10,1], \n", " preprocessing = pca ) \n", "\n", "# the preprocessing can also be saved later, like in:\n", diff --git a/docs/notebooks/tutorials/cvs_DeepTDA.ipynb b/docs/notebooks/tutorials/cvs_DeepTDA.ipynb index a2ee41e7..47ba9b2e 100644 --- a/docs/notebooks/tutorials/cvs_DeepTDA.ipynb +++ b/docs/notebooks/tutorials/cvs_DeepTDA.ipynb @@ -7,7 +7,7 @@ "source": [ "# Deep-TDA: Deep Targeted Discriminant Analysis\n", "Reference papers: \n", - "- *Deep-TDA*: _Trizio and Parrinello, [JPCL](https://pubs.acs.org/doi/full/10.1021/acs.jpclett.1c02317) (2021)_ [[arXiv]](https://128.84.4.34/abs/2107.05444).\n", + "- *Deep-TDA*: _Trizio and Parrinello, [JPCL](https://pubs.acs.org/doi/full/10.1021/acs.jpclett.1c02021)_ [[arXiv]](https://128.84.4.34/abs/2107.05444).\n", "- *TPI-Deep-TDA*: _Ray, Trizio and Parrinello, [JCP](https://pubs.aip.org/aip/jcp/article/158/20/204102/2891484) (2023)_ [[arXiv]](https://arxiv.org/abs/2303.01629).\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/luigibonati/mlcolvar/blob/main/docs/notebooks/tutorials/cvs_DeepTDA.ipynb)" @@ -81,18 +81,10 @@ "execution_count": 1, "metadata": {}, "outputs": [ 
- { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/etrizio@iit.local/Bin/miniconda3/envs/mlcvs_test/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - }, { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 1, @@ -160,7 +152,7 @@ { "data": { "text/plain": [ - "DictModule(dataset -> DictDataset( \"data\": [4002, 2], \"labels\": [4002] ),\n", + "DictModule(dataset -> DictDataset( \"data\": [4002, 2], \"labels\": [4002], \"data_type\": descriptors ),\n", "\t\t train_loader -> DictLoader(length=0.8, batch_size=0, shuffle=True),\n", "\t\t valid_loader -> DictLoader(length=0.2, batch_size=0, shuffle=True))" ] @@ -200,7 +192,7 @@ "outputs": [ { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -241,7 +233,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -254,7 +246,7 @@ "nn_layers = [2,24,12,1]\n", "\n", "# Initialize DeepTDA model\n", - "model = DeepTDA(n_states=n_states, n_cvs=1,target_centers=target_centers, target_sigmas=target_sigmas, layers=nn_layers)" + "model = DeepTDA(n_states=n_states, n_cvs=1,target_centers=target_centers, target_sigmas=target_sigmas, model=nn_layers)" ] }, { @@ -267,7 +259,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -354,7 +346,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -393,7 +385,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -426,7 +418,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -480,7 +472,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -522,7 +514,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -564,7 +556,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -583,7 +575,7 @@ " n_cvs=2,\n", " target_centers=target_centers, \n", " target_sigmas=target_sigmas,\n", - " layers=nn_layers)" + " model=nn_layers)" ] }, { @@ -596,7 +588,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -683,7 +675,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -722,7 +714,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -756,7 +748,7 @@ }, { "cell_type": "code", - 
"execution_count": 15, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -809,7 +801,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -851,7 +843,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -890,7 +882,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -899,7 +891,7 @@ "target_sigmas = [0.2, 0.2, 0.2]\n", "nn_layers = [2,24,12,1]\n", "# MODEL\n", - "model = DeepTDA(n_states=n_states, n_cvs=1,target_centers=target_centers, target_sigmas=target_sigmas, layers=nn_layers)" + "model = DeepTDA(n_states=n_states, n_cvs=1,target_centers=target_centers, target_sigmas=target_sigmas, model=nn_layers)" ] }, { @@ -912,7 +904,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -999,7 +991,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1038,7 +1030,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1071,7 +1063,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1127,7 +1119,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1176,7 +1168,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1187,7 +1179,7 @@ "target_sigmas = [0.2, 1.5, 0.2]\n", "nn_layers = [2,24,12,1]\n", "# MODEL\n", - "model = DeepTDA(n_states=n_states, n_cvs=1,target_centers=target_centers, target_sigmas=target_sigmas, layers=nn_layers)" + "model = DeepTDA(n_states=n_states, n_cvs=1,target_centers=target_centers, target_sigmas=target_sigmas, model=nn_layers)" ] }, { @@ -1199,7 
+1191,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1285,7 +1277,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1322,7 +1314,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1366,7 +1358,7 @@ ], "metadata": { "kernelspec": { - "display_name": "pytorch", + "display_name": "graph_mlcolvar_test", "language": "python", "name": "python3" }, @@ -1380,14 +1372,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.9.18" }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": "1cbeac1d7079eaeba64f3210ccac5ee24400128e300a45ae35eee837885b08b3" - } - } + "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 diff --git a/docs/notebooks/tutorials/cvs_committor.ipynb b/docs/notebooks/tutorials/cvs_committor.ipynb index 0c76e0fe..29a6616e 100644 --- a/docs/notebooks/tutorials/cvs_committor.ipynb +++ b/docs/notebooks/tutorials/cvs_committor.ipynb @@ -133,7 +133,7 @@ "options = {'optimizer' : {'lr': 1e-3, 'weight_decay': 1e-5}, \n", " 'lr_scheduler' : { 'scheduler' : lr_scheduler, 'gamma' : 0.99999 }}\n", "\n", - "model = Committor(layers=[2, 32, 32, 1],\n", + "model = Committor(model=[2, 32, 32, 1],\n", " atomic_masses=atomic_masses,\n", " alpha=1e-1,\n", " delta_f=0,\n", diff --git a/docs/notebooks/tutorials/data/alanine_gnn/alad.gro b/docs/notebooks/tutorials/data/alanine_gnn/alad.gro new file mode 100644 index 00000000..c633b39b --- /dev/null +++ b/docs/notebooks/tutorials/data/alanine_gnn/alad.gro @@ -0,0 +1,25 @@ +Generated by trjconv : Alanine in vacuum in water t= 0.00000 + 22 + 1ACE HH31 1 0.152 0.743 2.212 1.3463 0.7349 -0.6803 + 1ACE CH3 2 0.131 0.822 2.284 0.3769 0.1596 -0.3198 + 1ACE HH32 3 0.108 0.767 2.375 0.4564 -1.3627 -1.1820 + 1ACE HH33 4 0.052 0.894 2.264 -0.4328 -0.2112 
1.3689 + 1ACE C 5 0.265 0.895 2.297 0.3480 0.2517 0.0099 + 1ACE O 6 0.269 0.977 2.388 0.4177 0.6042 0.0340 + 2ALA N 7 0.368 0.871 2.208 -0.4973 0.2362 0.3098 + 2ALA H 8 0.341 0.815 2.129 0.3380 2.0597 -1.3356 + 2ALA CA 9 0.488 0.955 2.197 0.3024 0.6262 0.4482 + 2ALA HA 10 0.544 0.900 2.122 0.3036 -0.7766 1.4157 + 2ALA CB 11 0.448 1.092 2.138 -0.8250 -0.8273 -0.4712 + 2ALA HB1 12 0.538 1.131 2.091 -2.1604 -0.2744 -2.7135 + 2ALA HB2 13 0.382 1.084 2.051 -1.4850 -0.3191 -0.0220 + 2ALA HB3 14 0.423 1.156 2.222 -1.1056 -1.9931 0.3539 + 2ALA C 15 0.582 0.976 2.321 -0.6113 -0.1699 -0.1353 + 2ALA O 16 0.703 0.990 2.301 0.0894 0.1629 0.1034 + 3NME N 17 0.532 0.966 2.446 -0.5630 1.1218 -0.6656 + 3NME H 18 0.432 0.954 2.452 -0.4594 0.4793 0.0514 + 3NME CH3 19 0.599 0.972 2.578 -0.1680 0.4666 0.2671 + 3NME HH31 20 0.661 0.882 2.577 -0.5665 0.1075 -2.8701 + 3NME HH32 21 0.656 1.064 2.574 -1.3881 1.2509 0.5814 + 3NME HH33 22 0.527 0.949 2.656 -1.4662 -1.5983 -1.3731 + 3.02334 3.02334 3.02334 diff --git a/docs/notebooks/tutorials/data/alanine_gnn/alad_A.trr b/docs/notebooks/tutorials/data/alanine_gnn/alad_A.trr new file mode 100644 index 00000000..6fe9d209 Binary files /dev/null and b/docs/notebooks/tutorials/data/alanine_gnn/alad_A.trr differ diff --git a/docs/notebooks/tutorials/data/alanine_gnn/alad_B.trr b/docs/notebooks/tutorials/data/alanine_gnn/alad_B.trr new file mode 100644 index 00000000..e37d8648 Binary files /dev/null and b/docs/notebooks/tutorials/data/alanine_gnn/alad_B.trr differ diff --git a/docs/notebooks/tutorials/intro_3_loss_optim.ipynb b/docs/notebooks/tutorials/intro_3_loss_optim.ipynb index 36e27ccf..7eb10823 100644 --- a/docs/notebooks/tutorials/intro_3_loss_optim.ipynb +++ b/docs/notebooks/tutorials/intro_3_loss_optim.ipynb @@ -83,7 +83,7 @@ "from mlcolvar.cvs import RegressionCV\n", "\n", "# define example CV\n", - "cv = RegressionCV(layers=[10,5,5,1], options={})\n", + "cv = RegressionCV(model=[10,5,5,1], options={})\n", "\n", "# choose optimizer\n", 
"cv.optimizer_name = 'Adam' \n", @@ -123,7 +123,7 @@ "options = {'optimizer' : {'lr' : 2e-3, 'weight_decay' : 1e-4} }\n", "\n", "# define example CV\n", - "cv = RegressionCV(layers=[10,5,5,1], options=options)\n", + "cv = RegressionCV(model=[10,5,5,1], options=options)\n", "\n", "print(f'optimizer_kwargs: {cv.optimizer_kwargs}')" ] @@ -155,7 +155,7 @@ "options = {'lr_scheduler' : { 'scheduler' : lr_scheduler, 'gamma' : 0.9999} }\n", "\n", "# define example CV\n", - "cv = RegressionCV(layers=[10,5,5,1], options=options)" + "cv = RegressionCV(model=[10,5,5,1], options=options)" ] }, { @@ -251,7 +251,7 @@ "from mlcolvar.cvs import DeepTICA\n", "\n", "# define CV\n", - "cv = DeepTICA(layers=[10, 5, 5, 2], options={})\n", + "cv = DeepTICA(model=[10, 5, 5, 2], options={})\n", "\n", "# print default loss mode\n", "print(f'default mode: {cv.loss_fn.mode}')\n", @@ -554,7 +554,7 @@ ], "metadata": { "kernelspec": { - "display_name": "pytorch", + "display_name": "mlcvs_test", "language": "python", "name": "python3" }, @@ -570,12 +570,7 @@ "pygments_lexer": "ipython3", "version": "3.10.8" }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": "1cbeac1d7079eaeba64f3210ccac5ee24400128e300a45ae35eee837885b08b3" - } - } + "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 diff --git a/docs/requirements.yaml b/docs/requirements.yaml index 6693ccbb..d15ecac4 100644 --- a/docs/requirements.yaml +++ b/docs/requirements.yaml @@ -26,9 +26,12 @@ dependencies: - ipykernel - scikit-learn - scipy + - pyg # Pip-only installs - pip: - sphinx-copybutton - furo - KDEpy + - mdtraj + - matscipy diff --git a/mlcolvar/core/loss/committor_loss.py b/mlcolvar/core/loss/committor_loss.py index 03196b0f..5ed4e132 100644 --- a/mlcolvar/core/loss/committor_loss.py +++ b/mlcolvar/core/loss/committor_loss.py @@ -17,7 +17,10 @@ import torch from typing import Tuple, Union from mlcolvar.core.loss.utils.smart_derivatives import SmartDerivatives +import torch_geometric +import warnings +from 
mlcolvar.utils._code import scatter_sum # ============================================================================= # LOSS FUNCTIONS # ============================================================================= @@ -91,7 +94,7 @@ def __init__(self, self.n_dim = n_dim def forward(self, - x: torch.Tensor, + x: Union[torch.Tensor, torch_geometric.data.Batch], z: torch.Tensor, q: torch.Tensor, labels: torch.Tensor, @@ -103,7 +106,7 @@ def forward(self, Parameters ---------- - x : torch.Tensor + x : torch.Tensor or torch_geometric.data.Batch Model input, i.e., either positions or descriptors if using descriptors_derivatives z : torch.Tensor Model unactivated output, i.e., z value @@ -229,7 +232,23 @@ def committor_loss(x: torch.Tensor, if (z_threshold is not None and (z_regularization == 0 or z_threshold <= 0)) or (z_threshold is None and z_regularization != 0) or z_regularization < 0: raise ValueError(f"To apply the regularization on z space both z_threshold and z_regularization key must be positive. 
Found {z_threshold} and {z_regularization}!") + + # check if input is graph + if isinstance(x, torch_geometric.data.batch.Batch): + _is_graph_data = True + batch = torch.clone(x['batch']) + node_types = torch.where(x['node_attrs'])[1] + x = x['positions'] + else: + _is_graph_data = False + # checks and warnings + if _is_graph_data and descriptors_derivatives is not None: + raise ValueError("The descriptors_derivatives key cannot be used with GNN-based models!") + + if _is_graph_data and separate_boundary_dataset: + warnings.warn("Using GNN-based models it may be better to set separate_boundary_dataset to False") + # ------------------------ SETUP ------------------------ # inherit right device device = x.device @@ -242,16 +261,33 @@ def committor_loss(x: torch.Tensor, # Create masks to access different states data - mask_A = labels == 0 - mask_B = labels == 1 + mask_A = torch.nonzero(labels == 0, as_tuple=True) + mask_B = torch.nonzero(labels == 1, as_tuple=True) # create mask for variational data if separate_boundary_dataset: - mask_var = labels > 1 + mask_var = torch.nonzero(labels > 1, as_tuple=not(_is_graph_data)) else: mask_var = torch.ones_like(labels, dtype=torch.bool) + if _is_graph_data: + # this needs to be on the batch index, not only the labels + aux = torch.where(mask_var)[0].to(device) + mask_var_batches = torch.isin(batch, aux) + mask_var_batches = batch[mask_var_batches] + else: + mask_var_batches = mask_var + + # setup atomic masses + atomic_masses = atomic_masses.to(device) + + # mass should have size [1, n_atoms*spatial_dims] + if _is_graph_data: + atomic_masses = atomic_masses[node_types[mask_var_batches]].unsqueeze(-1) + else: + atomic_masses = atomic_masses.unsqueeze(0) + # Update weights of basin B using the information on the delta_f delta_f = torch.Tensor([delta_f]).to(device) # B higher in energy --> A-B < 0 @@ -275,44 +311,41 @@ def committor_loss(x: torch.Tensor, grad_outputs=grad_outputs, retain_graph=True, create_graph=create_graph)[0] 
- grad = grad[mask_var] - - if cell is not None: - grad = grad / cell - - # in case the input is not positions but descriptors, we need to correct the gradients up to the positions - if isinstance(descriptors_derivatives, SmartDerivatives): - # we use the precomputed derivatives from descriptors to pos - gradient_positions = descriptors_derivatives(grad, ref_idx[mask_var]).view(x[mask_var].shape[0], -1) - - # --> If we directly pass the matrix d_desc/d_pos - elif isinstance(descriptors_derivatives, torch.Tensor): - descriptors_derivatives = descriptors_derivatives.to(device) - gradient_positions = torch.einsum("bd,badx->bax", grad, descriptors_derivatives[ref_idx[mask_var]]).contiguous() - gradient_positions = gradient_positions.view(x[mask_var].shape[0], -1) + grad = grad[mask_var_batches] - # If the input was already positions + if descriptors_derivatives is not None: + # in case the input is not positions but descriptors, we need to correct the gradients up to the positions + if isinstance(descriptors_derivatives, SmartDerivatives): + # we use the precomputed derivatives from descriptors to pos + gradient_positions = descriptors_derivatives(grad, ref_idx[mask_var]).view(x[mask_var].shape[0], -1) + + # --> If we directly pass the matrix d_desc/d_pos + elif isinstance(descriptors_derivatives, torch.Tensor): + descriptors_derivatives = descriptors_derivatives.to(device) + gradient_positions = torch.einsum("bd,badx->bax", grad, descriptors_derivatives[ref_idx[mask_var]]).contiguous() + gradient_positions = gradient_positions.view(x[mask_var].shape[0], -1) else: + # we get the square of grad(q) and we multiply by the weight gradient_positions = grad - + + if cell is not None: + gradient_positions = gradient_positions / cell + # we do the square grad_square = torch.pow(gradient_positions, 2) - - # multiply by masses - try: - grad_square = torch.sum((grad_square * (1/atomic_masses)), - axis=1, - keepdim=True) - except RuntimeError as e: - raise RuntimeError(e, 
"""[HINT]: Is you system in 3 dimension? By default the code assumes so, if it's not the case change the n_dim key to the right dimensionality.""") + grad_square = torch.sum((grad_square * (1/atomic_masses)), axis=1, keepdim=True) + + if _is_graph_data: + # we need to sum on the right batch first + grad_square = scatter_sum(grad_square, mask_var_batches, dim=0) + # variational contribution to loss: we sum over the batch loss_var = torch.mean(grad_square * w[mask_var]) if log_var: loss_var = torch.log1p(loss_var) else: - loss_var *= gamma - + loss_var = gamma*loss_var # 2. ----- BOUNDARY LOSS loss_A = gamma * torch.mean( q[mask_A].pow(2) ) @@ -329,4 +362,6 @@ def committor_loss(x: torch.Tensor, # 4. ----- TOTAL LOSS loss = loss_var + alpha*(loss_A + loss_B) + loss_z_diff + + # TODO maybe there is no need to detach them for logging return loss, loss_var.detach(), alpha*loss_A.detach(), alpha*loss_B.detach() \ No newline at end of file diff --git a/mlcolvar/core/loss/eigvals.py b/mlcolvar/core/loss/eigvals.py index e0c84345..4fc2c826 100644 --- a/mlcolvar/core/loss/eigvals.py +++ b/mlcolvar/core/loss/eigvals.py @@ -129,7 +129,7 @@ def reduce_eigenvalues_loss( else: n_eig = len(evals) - loss = None + loss = torch.zeros(1, dtype=evals.dtype, device=evals.device) if mode == "sum": loss = torch.sum(evals[:n_eig]) diff --git a/mlcolvar/core/loss/tda_loss.py b/mlcolvar/core/loss/tda_loss.py index 8bd8e830..50234ed4 100644 --- a/mlcolvar/core/loss/tda_loss.py +++ b/mlcolvar/core/loss/tda_loss.py @@ -15,7 +15,7 @@ # GLOBAL IMPORTS # ============================================================================= -from typing import Union +from typing import Union, List, Tuple from warnings import warn import torch @@ -32,10 +32,10 @@ class TDALoss(torch.nn.Module): def __init__( self, n_states: int, - target_centers: Union[list, torch.Tensor], - target_sigmas: Union[list, torch.Tensor], - alpha: float = 1, - beta: float = 100, + target_centers: Union[List[float], torch.Tensor], 
+ target_sigmas: Union[List[float], torch.Tensor], + alpha: float = 1.0, + beta: float = 100.0, ): """Constructor. @@ -66,7 +66,7 @@ def __init__( def forward( self, H: torch.Tensor, labels: torch.Tensor, return_loss_terms: bool = False - ) -> torch.Tensor: + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: """Compute the value of the loss function. Parameters @@ -107,12 +107,12 @@ def tda_loss( H: torch.Tensor, labels: torch.Tensor, n_states: int, - target_centers: Union[list, torch.Tensor], - target_sigmas: Union[list, torch.Tensor], + target_centers: Union[List[float], torch.Tensor], + target_sigmas: Union[List[float], torch.Tensor], alpha: float = 1, beta: float = 100, return_loss_terms: bool = False, -) -> torch.Tensor: +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: """ Compute a loss function as the distance from a simple Gaussian target distribution. @@ -148,9 +148,9 @@ def tda_loss( term associated to the standard deviations of the target Gaussians. """ if not isinstance(target_centers, torch.Tensor): - target_centers = torch.Tensor(target_centers) + target_centers = torch.tensor(target_centers, dtype=H.dtype) if not isinstance(target_sigmas, torch.Tensor): - target_sigmas = torch.Tensor(target_sigmas) + target_sigmas = torch.tensor(target_sigmas, dtype=H.dtype) device = H.device target_centers = target_centers.to(device) @@ -165,7 +165,7 @@ def tda_loss( f"State {i} was not represented in this batch! Either use bigger batch_size or a more equilibrated dataset composition!" ) else: - H_red = H[torch.nonzero(labels == i, as_tuple=True)] + H_red = H[labels == i] # compute mean and standard deviation over the class i mu = torch.mean(H_red, 0) @@ -173,7 +173,7 @@ def tda_loss( warn( f"There is only one sample for state {i} in this batch! Std is set to 0, this may affect the training! Either use bigger batch_size or a more equilibrated dataset composition!" 
) - sigma = 0 + sigma = torch.zeros_like(mu) else: sigma = torch.std(H_red, 0) @@ -189,3 +189,18 @@ def tda_loss( if return_loss_terms: return loss, loss_centers, loss_sigmas return loss + +def test_tda_loss(): + H = torch.randn(100) + H.requires_grad = True + labels = torch.zeros_like(H) + labels[-50:] = 1 + + Loss = TDALoss(n_states=2, target_centers=[-1, 1], target_sigmas=[0.1, 0.1]) + + loss = Loss(H=H, labels=labels, return_loss_terms=True) + + loss[0].backward() + +if __name__ == '__main__': + test_tda_loss() \ No newline at end of file diff --git a/mlcolvar/core/loss/utils/smart_derivatives.py b/mlcolvar/core/loss/utils/smart_derivatives.py index ac699da0..a2e32f63 100644 --- a/mlcolvar/core/loss/utils/smart_derivatives.py +++ b/mlcolvar/core/loss/utils/smart_derivatives.py @@ -970,7 +970,7 @@ def test_train_with_smart_derivatives(): datamodule = DictModule(dataset=smart_dataset, lengths=[0.8, 0.2], batch_size=80) - model = Committor(layers=[45, 10, 1], + model = Committor(model=[45, 10, 1], atomic_masses=atomic_masses, alpha=1, separate_boundary_dataset=True, diff --git a/mlcolvar/core/nn/__init__.py b/mlcolvar/core/nn/__init__.py index 4ccf68d1..fa34b8af 100644 --- a/mlcolvar/core/nn/__init__.py +++ b/mlcolvar/core/nn/__init__.py @@ -1,3 +1,4 @@ -__all__ = ["FeedForward"] +__all__ = ["FeedForward", "BaseGNN", "SchNetModel", "GVPModel"] from .feedforward import * +from .graph import * \ No newline at end of file diff --git a/mlcolvar/core/nn/feedforward.py b/mlcolvar/core/nn/feedforward.py index f84596dd..19c233bb 100644 --- a/mlcolvar/core/nn/feedforward.py +++ b/mlcolvar/core/nn/feedforward.py @@ -15,10 +15,9 @@ # GLOBAL IMPORTS # ============================================================================= -from typing import Optional, Union +from typing import Optional, Union, Any import torch -import lightning from mlcolvar.core.nn.utils import get_activation, parse_nn_options @@ -27,7 +26,7 @@ # 
============================================================================= -class FeedForward(lightning.LightningModule): +class FeedForward(torch.nn.Module): """Define a feedforward neural network given the list of layers. Optionally dropout and batchnorm can be applied (the order is activation -> dropout -> batchnorm). @@ -110,3 +109,6 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: return self.nn(x) + + def backward(self, loss: torch.Tensor, *args: Any, **kwargs: Any): + return loss.backward() diff --git a/mlcolvar/core/nn/graph/__init__.py b/mlcolvar/core/nn/graph/__init__.py new file mode 100644 index 00000000..a8b3be6d --- /dev/null +++ b/mlcolvar/core/nn/graph/__init__.py @@ -0,0 +1,5 @@ +__all__ = ["BaseGNN", "SchNetModel", "GVPModel"] + +from .gnn import BaseGNN +from .schnet import SchNetModel +from .gvp import GVPModel \ No newline at end of file diff --git a/mlcolvar/core/nn/graph/gnn.py b/mlcolvar/core/nn/graph/gnn.py new file mode 100644 index 00000000..1f2deb1d --- /dev/null +++ b/mlcolvar/core/nn/graph/gnn.py @@ -0,0 +1,241 @@ +import torch +from torch import nn +from typing import List, Dict, Tuple + +from mlcolvar.core.nn.graph import radial +from mlcolvar.utils import _code + +""" +GNN models. +""" + +__all__ = ['BaseGNN'] + + +class BaseGNN(nn.Module): + """ + Base class for Graph Neural Network (GNN) models + """ + + def __init__( + self, + n_out: int, + cutoff: float, + atomic_numbers: List[int], + pooling_operation: str, + n_bases: int = 6, + n_polynomials: int = 6, + basis_type: str = 'bessel', + ) -> None: + """Initializes the core of a GNN model, taking care of edge embeddings. + + Parameters + ---------- + n_out : int + Number of the output scalar node features. + cutoff : float + Cutoff radius of the basis functions. Should be the same as the cutoff + radius used to build the graphs. + atomic_numbers : List[int] + The atomic numbers mapping. 
+ pooling_operation : str + Type of pooling operation to combine node-level features into graph-level features, either mean or sum + n_bases : int, optional + Size of the basis set used for the embedding, by default 6 + n_polynomials : int, optional + Order of the polynomials in the basis functions, by default 6 + basis_type : str, optional + Type of the basis function, by default 'bessel' + """ + super().__init__() + + self._radial_embedding = radial.RadialEmbeddingBlock(cutoff=cutoff, + n_bases=n_bases, + n_polynomials=n_polynomials, + basis_type=basis_type + ) + self.register_buffer( + 'n_out', torch.tensor(n_out, dtype=torch.int64) + ) + self.register_buffer( + 'cutoff', torch.tensor(cutoff, dtype=torch.get_default_dtype()) + ) + self.register_buffer( + 'atomic_numbers', torch.tensor(atomic_numbers, dtype=torch.int64) + ) + self.pooling_operation = pooling_operation + + @property + def out_features(self): + return self.n_out + + @property + def in_features(self): + return None + + def embed_edge( + self, data: Dict[str, torch.Tensor], normalize: bool = True + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Performs the model edge embedding from a `torch_geometric.data.Batch` object. + + Parameters + ---------- + data: Dict[str, torch.Tensor] + The data dict. Usually from the `to_dict` method of a + `torch_geometric.data.Batch` object. + normalize: bool + Whether to return the normalized distance vectors, by default True. + + Returns + ------- + edge_lengths: torch.Tensor (shape: [n_edges, 1]) + The edge lengths. + edge_length_embeddings: torch.Tensor (shape: [n_edges, n_bases]) + The edge length embeddings. + edge_unit_vectors: torch.Tensor (shape: [n_edges, 3]) + The normalized edge vectors. 
+ """ + vectors, lengths = get_edge_vectors_and_lengths( + positions=data['positions'], + edge_index=data['edge_index'], + shifts=data['shifts'], + normalize=normalize, + ) + return lengths, self._radial_embedding(lengths), vectors + + def pooling(self, + input : torch.Tensor, + data : Dict[str, torch.Tensor]) -> torch.Tensor: + """Performs pooling of the node-level outputs to obtain a graph-level output + + Parameters + ---------- + input : torch.Tensor + Node-level features to be pooled + data : Dict[str, torch.Tensor] + Data batch containing the graph data + + Returns + ------- + torch.Tensor + Pooled output + """ + if self.pooling_operation == 'mean': + if 'system_masks' not in data.keys(): + out = _code.scatter_mean(input, data['batch'], dim=0) + else: + out = input * data['system_masks'] + out = _code.scatter_sum(out, data['batch'], dim=0) + out = out / data['n_system'] + + elif self.pooling_operation == 'sum': + if 'system_masks' in data.keys(): + input = input * data['system_masks'] + # in both cases, sum the (possibly masked) node features over each graph + out = _code.scatter_sum(input, data['batch'], dim=0) + else: + raise ValueError(f"Invalid pooling operation! Found {self.pooling_operation}") + + return out + +def get_edge_vectors_and_lengths( + positions: torch.Tensor, + edge_index: torch.Tensor, + shifts: torch.Tensor, + normalize: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Calculates edge vectors and lengths by indices and shift vectors. + + Parameters + ---------- + positions: torch.Tensor (shape: [n_atoms, 3]) + The positions tensor. + edge_index: torch.Tensor (shape: [2, n_edges]) + The edge indices. + shifts: torch.Tensor (shape: [n_edges, 3]) + The shifts vector. + normalize: bool + Whether to return the normalized distance vectors, by default True. + + Returns + ------- + vectors: torch.Tensor (shape: [n_edges, 3]) + The distance vectors. + lengths: torch.Tensor (shape: [n_edges, 1]) + The edge lengths. 
+ """ + sender = edge_index[0] + receiver = edge_index[1] + vectors = positions[receiver] - positions[sender] + shifts # [n_edges, 3] + lengths = torch.linalg.norm(vectors, dim=-1, keepdim=True) # [n_edges, 1] + + if normalize: + vectors = torch.nan_to_num(torch.div(vectors, lengths)) + + return vectors, lengths + + +def test_get_edge_vectors_and_lengths() -> None: + dtype = torch.get_default_dtype() + torch.set_default_dtype(torch.float64) + + data = dict() + data['positions'] = torch.tensor( + [[0.0, 0.0, 0.0], [0.07, 0.07, 0.0], [0.07, -0.07, 0.0]], + dtype=torch.float64 + ) + data['edge_index'] = torch.tensor( + [[0, 0, 1, 1, 2, 2], [2, 1, 0, 2, 1, 0]] + ) + data['shifts'] = torch.tensor([ + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.2, 0.0], + [0.0, -0.2, 0.0], + [0.0, 0.0, 0.0], + ]) + + vectors, distances = get_edge_vectors_and_lengths(**data, normalize=False) + assert(torch.allclose(vectors, torch.tensor([[0.0700, -0.0700, 0.0000], + [0.0700, 0.0700, 0.0000], + [-0.070, -0.0700, 0.0000], + [0.0000, 0.0600, 0.0000], + [0.0000, -0.0600, 0.0000], + [-0.070, 0.0700, 0.0000]]) + ) + ) + assert(torch.allclose(distances,torch.tensor([[0.09899494936611666], + [0.09899494936611666], + [0.09899494936611666], + [0.06000000000000000], + [0.06000000000000000], + [0.09899494936611666]]) + ) + ) + + vectors, distances = get_edge_vectors_and_lengths(**data, normalize=True) + assert(torch.allclose(vectors, torch.tensor([[0.70710678118654757, -0.70710678118654757, 0.0], + [0.70710678118654757, 0.70710678118654757, 0.0], + [-0.7071067811865476, -0.70710678118654757, 0.0], + [0.00000000000000000, 1.00000000000000000, 0.0], + [0.00000000000000000, -1.00000000000000000, 0.0], + [-0.7071067811865476, 0.70710678118654757, 0.0]]) + ) + ) + + assert(torch.allclose(distances, torch.tensor([[0.09899494936611666], + [0.09899494936611666], + [0.09899494936611666], + [0.06000000000000000], + [0.06000000000000000], + [0.09899494936611666]]) + ) + ) + + 
torch.set_default_dtype(dtype) + +if __name__ == "__main__": + test_get_edge_vectors_and_lengths() \ No newline at end of file diff --git a/mlcolvar/core/nn/graph/gvp.py b/mlcolvar/core/nn/graph/gvp.py new file mode 100644 index 00000000..e1f6d7c7 --- /dev/null +++ b/mlcolvar/core/nn/graph/gvp.py @@ -0,0 +1,920 @@ +import functools +import math +import torch +from torch import nn +from torch_geometric.nn import MessagePassing +from typing import Tuple, Callable, Optional, List, Dict + +from mlcolvar.core.nn.graph.gnn import BaseGNN + +""" +The Geometric Vector Perceptron (GVP) layer. This module is taken from: +https://github.com/chaitjo/geometric-gnn-dojo/blob/main/models/layers/py, +and made compilable. +""" + +__all__ = ['GVPModel', 'GVPConvLayer', 'LayerNorm', 'Dropout'] + + +class GVPModel(BaseGNN): + """ + The Geometric Vector Perceptron (GVP) model [1, 2] with vector gate [2]. + + References + ---------- + .. [1] Jing, Bowen, et al. + "Learning from protein structure with geometric vector perceptrons." + International Conference on Learning Representations. 2020. + .. [2] Jing, Bowen, et al. + "Equivariant graph neural networks for 3d macromolecular structure." + arXiv preprint arXiv:2106.03843 (2021). + """ + def __init__( + self, + n_out: int, + cutoff: float, + atomic_numbers: List[int], + pooling_operation : str = 'mean', + n_bases: int = 8, + n_polynomials: int = 6, + n_layers: int = 1, + n_messages: int = 2, + n_feedforwards: int = 2, + n_scalars_node: int = 8, + n_vectors_node: int = 8, + n_scalars_edge: int = 8, + drop_rate: float = 0.1, + activation: str = 'SiLU', + basis_type: str = 'bessel', + smooth: bool = False, + ) -> None: + """Initializes a Geometric Vector Perceptron (GVP) model. + + Parameters + ---------- + n_out: int + Number of the output scalar node features. + cutoff: float + Cutoff radius of the basis functions. Should be the same as the cutoff + radius used to build the graphs. 
+ atomic_numbers: List[int] + The atomic numbers mapping. + pooling_operation : str + Type of pooling operation to combine node-level features into graph-level features, either mean or sum, by default 'mean' + n_bases: int + Size of the basis set used for the embedding, by default 8. + n_polynomials: int + Order of the polynomials in the basis functions, by default 6. + n_layers: int + Number of the graph convolution layers, by default 1. + n_messages: int + Number of GVP layers to be used in the message functions, by default 2. + n_feedforwards: int + Number of GVP layers to be used in the feedforward functions, by default 2. + n_scalars_node: int + Size of the scalar channel of the node embedding in hidden layers, by default 8. + n_vectors_node: int + Size of the vector channel of the node embedding in hidden layers, by default 8. + n_scalars_edge: int + Size of the scalar channel of the edge embedding in hidden layers, by default 8. + drop_rate: float + Drop probability in all dropout layers, by default 0.1. + activation: str + Name of the activation function to be used in the GVPs (case sensitive), by default SiLU. + basis_type: str + Type of the basis function, by default bessel. + smooth: bool + Whether to use the smoothed GVPConv, by default False. 
+ """ + super().__init__( + n_out=n_out, + cutoff=cutoff, + atomic_numbers=atomic_numbers, + pooling_operation=pooling_operation, + n_bases=n_bases, + n_polynomials=n_polynomials, + basis_type=basis_type + ) + + self.W_e = nn.ModuleList([ + LayerNorm((n_bases, 1)), + GVP(in_dims=(n_bases, 1), + out_dims=(n_scalars_edge, 1), + activations=(None, None), + vector_gate=True + ) + ]) + + self.W_v = nn.ModuleList([ + LayerNorm((len(atomic_numbers), 0)), + GVP(in_dims=(len(atomic_numbers), 0), + out_dims=(n_scalars_node, n_vectors_node), + activations=(None, None), + vector_gate=True + ) + ]) + + self.layers = nn.ModuleList( + GVPConvLayer(node_dims=(n_scalars_node, n_vectors_node), + edge_dims=(n_scalars_edge, 1), + n_message=n_messages, + n_feedforward=n_feedforwards, + drop_rate=drop_rate, + activations=(getattr(torch.nn, activation)(), None), + vector_gate=True, + cutoff=(cutoff if smooth else -1) + ) + for _ in range(n_layers) + ) + + self.W_out = nn.ModuleList([ + LayerNorm((n_scalars_node, n_vectors_node)), + GVP(in_dims=(n_scalars_node, n_vectors_node), + out_dims=(n_out, 0), + activations=(None, None), + vector_gate=True) + ]) + + def forward( + self, data: Dict[str, torch.Tensor], pool: bool = True + ) -> torch.Tensor: + """The forward pass. + + Parameters + ---------- + data: Dict[str, torch.Tensor] + The data dict. Usually comes from the `to_dict` method of a + `torch_geometric.data.Batch` object. + pool: bool + Whether to pool the model output, by default True. 
+ """ + h_V = (data['node_attrs'], None) + for w in self.W_v: + h_V = w(h_V) + h_V_1, h_V_2 = h_V + assert h_V_2 is not None + h_V = (h_V_1, h_V_2) + + h_E = self.embed_edge(data) + lengths = h_E[0] + h_E = (h_E[1], h_E[2].unsqueeze(-2)) + for w in self.W_e: + h_E = w(h_E) + h_E_1, h_E_2 = h_E + assert h_E_2 is not None + h_E = (h_E_1, h_E_2) + + for layer in self.layers: + h_V = layer(h_V, data['edge_index'], h_E, lengths) + + for w in self.W_out: + h_V = w(h_V) + out = h_V[0] + + if pool: + out = self.pooling(input=out, data=data) + + return out + + +class GVP(nn.Module): + """ + Geometric Vector Perceptron (GVP) layer from [1, 2] with vector gate [2]. + + References + ---------- + .. [1] Jing, Bowen, et al. + "Learning from protein structure with geometric vector perceptrons." + International Conference on Learning Representations. 2020. + .. [2] Jing, Bowen, et al. + "Equivariant graph neural networks for 3d macromolecular structure." + arXiv preprint arXiv:2106.03843 (2021). + """ + + def __init__( + self, + in_dims: Tuple[int, Optional[int]], + out_dims: Tuple[int, Optional[int]], + h_dim: Optional[int] = None, + activations: Tuple[ + Optional[Callable], Optional[Callable] + ] = (nn.functional.relu, torch.sigmoid), + vector_gate: bool = True, + ) -> None: + r"""Geometric Vector Perceptron layer. + + Updates the scalar node feature s as: + .. math:: \bm{s}^n \leftarrow \sigma \left(\bm{s}'\right) \quad\text{with}\quad \bm{s}' \coloneq \bm{W}_m \left[{\|\bm{W}_h\vec{\bm{v}}^o\|_2 \atop \bm{s}^o}\right] + \bm{b} + + And the vector node features as: + .. 
math:: \vec{\bm{v}}^n \leftarrow \sigma_g \left(\bm{W}_g\left(\sigma^+ \left(\bm{s}'\right)\right) + \bm{b}_g \right) \odot \bm{W}_\mu\bm{W}_h\vec{\bm{v}}^o + + Parameters + ---------- + in_dims : Tuple[int, Optional[int]] + Dimension of inputs + out_dims : Tuple[int, Optional[int]] + Dimension of outputs + h_dim : Optional[int], optional + Intermediate number of vector channels, by default None + activations : Tuple[ Optional[Callable], Optional[Callable] ], optional + Scalar and vector activation functions (scalar_act, vector_act), by default (nn.functional.relu, torch.sigmoid) + vector_gate : bool, optional + Whether to use vector gating, by default True. The vector activation will be used as sigma^+ in vector gating if `True` + """ + super(GVP, self).__init__() + self.si, self.vi = in_dims + self.so, self.vo = out_dims + self.vector_gate = vector_gate + if self.vi: + self.h_dim = h_dim or max(self.vi, self.vo) + self.wh = nn.Linear(self.vi, self.h_dim, bias=False) + self.ws = nn.Linear(self.h_dim + self.si, self.so) + if self.vo: + self.wv = nn.Linear(self.h_dim, self.vo, bias=False) + if self.vector_gate: + self.wsv = nn.Linear(self.so, self.vo) + else: + self.wv = None + self.wsv = None + else: + self.wh = None + self.wv = None + self.wsv = None + self.ws = nn.Linear(self.si, self.so) + + self.scalar_act, self.vector_act = activations + self.dummy_param = nn.Parameter(torch.empty(0)) + + def forward( + self, + x: Tuple[torch.Tensor, Optional[torch.Tensor]] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Forward pass of GVP + + Parameters + ---------- + x : Tuple[torch.Tensor, Optional[torch.Tensor]] + Input scalar and vector node embeddings + + Returns + ------- + Tuple[torch.Tensor, Optional[torch.Tensor]] + Output scalar and vector node embeddings + """ + + s, v = x + if v is not None: + assert self.wh is not None + v = torch.transpose(v, -1, -2) + vh = self.wh(v) + vn = _norm_no_nan(vh, axis=-2) + s = self.ws(torch.cat([s, vn], -1)) + 
if self.vo: + assert self.wv is not None + v = self.wv(vh) + v = torch.transpose(v, -1, -2) + if self.vector_gate: + assert self.wsv is not None + gate = ( + self.wsv(self.vector_act(s)) + if self.vector_act is not None + else self.wsv(s) + ) + v = v * torch.sigmoid(gate).unsqueeze(-1) + elif self.vector_act is not None: + v = v * self.vector_act( + _norm_no_nan(v, axis=-1, keepdims=True) + ) + else: + s = self.ws(s) + if self.vo: + v = torch.zeros( + s.shape[0], + self.vo, + 3, + device=self.dummy_param.device, + dtype=s.dtype + ) + else: + v = None + + if self.scalar_act is not None: + s = self.scalar_act(s) + + return s, v + + +class GVPConv(MessagePassing): + """ + Graph convolution / message passing with Geometric Vector Perceptrons. + """ + propagate_type = { + 's': torch.Tensor, + 'v': torch.Tensor, + 'edge_attr_s': torch.Tensor, + 'edge_attr_v': torch.Tensor, + 'edge_lengths': torch.Tensor, + } + + def __init__( + self, + in_dims, + out_dims, + edge_dims, + n_layers=3, + aggr='mean', + activations=(nn.functional.relu, torch.sigmoid), + vector_gate=True, + cutoff: float = -1.0, + ) -> None: + """Graph convolution / message passing with Geometric Vector Perceptrons. + Takes in a graph with node and edge embeddings, + and returns new node embeddings. + + This does NOT do residual updates and pointwise feedforward layers + --- see `GVPConvLayer` instead. 
+ + Parameters + ---------- + in_dims : + input node embedding dimensions (n_scalar, n_vector) + out_dims : + output node embedding dimensions (n_scalar, n_vector) + edge_dims : + input edge embedding dimensions (n_scalar, n_vector) + n_layers : int, optional + number of GVPs in the message function, by default 3 + aggr : str, optional + Type of message aggregate function, by default 'mean' + activations : tuple, optional + activation functions (scalar_act, vector_act) to be used in GVPs, by default (nn.functional.relu, torch.sigmoid) + vector_gate : bool, optional + Whether to use vector gating, by default True. The vector activation will be used as sigma^+ in vector gating if `True` + cutoff : float, optional + Radial cutoff, by default -1.0 + """ + super(GVPConv, self).__init__(aggr=aggr) + self.si, self.vi = in_dims + self.so, self.vo = out_dims + self.se, self.ve = edge_dims + self.cutoff = cutoff + + GVP_ = functools.partial( + GVP, activations=activations, vector_gate=vector_gate + ) + + self._module_list = torch.nn.ModuleList() + if n_layers == 1: + self._module_list.append( + GVP_(in_dims=(2 * self.si + self.se, 2 * self.vi + self.ve), + out_dims=(self.so, self.vo), + activations=(None, None)) + ) + else: + self._module_list.append( + GVP_(in_dims=(2 * self.si + self.se, 2 * self.vi + self.ve), + out_dims=out_dims) + ) + for i in range(n_layers - 2): + self._module_list.append(GVP_(out_dims, out_dims)) + self._module_list.append( + GVP_(in_dims=out_dims, + out_dims=out_dims, + activations=(None, None)) + ) + + def forward( + self, + x: Tuple[torch.Tensor, torch.Tensor], + edge_index: torch.Tensor, + edge_attr: Tuple[torch.Tensor, torch.Tensor], + edge_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass of GVPConv + + Parameters + ---------- + x : Tuple[torch.Tensor, torch.Tensor] + Input scalar and vector node embeddings + edge_index : torch.Tensor + Index of edge sources and destinations + edge_attr : Tuple[torch.Tensor, 
torch.Tensor] + Edge attributes + edge_lengths : torch.Tensor + Edge lengths + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + Output scalar and vector node embeddings + """ + x_s, x_v = x + assert x_v is not None + message = self.propagate( + edge_index, + s=x_s, + v=x_v.contiguous().view(x_v.shape[0], x_v.shape[1] * 3), + edge_attr_s=edge_attr[0], + edge_attr_v=edge_attr[1], + edge_lengths=edge_lengths, + ) + return _split(message, self.vo) + + def message( + self, + s_i: torch.Tensor, + v_i: torch.Tensor, + s_j: torch.Tensor, + v_j: torch.Tensor, + edge_attr_s: torch.Tensor, + edge_attr_v: torch.Tensor, + edge_lengths: torch.Tensor, + ) -> torch.Tensor: + assert edge_attr_s is not None + assert edge_attr_v is not None + v_j = v_j.view(v_j.shape[0], v_j.shape[1] // 3, 3) + v_i = v_i.view(v_i.shape[0], v_i.shape[1] // 3, 3) + message = _tuple_cat( + (s_j, v_j), (edge_attr_s, edge_attr_v), (s_i, v_i) + ) + message = self.message_func(message) + message_merged = _merge(*message) + if self.cutoff > 0: + # apply SchNet-style cutoff function + c = 0.5 * (torch.cos(edge_lengths * math.pi / self.cutoff) + 1.0) + message_merged = message_merged * c.view(-1, 1) + return message_merged + + def message_func( + self, x: Tuple[torch.Tensor, torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor]: + for m in self._module_list: + x = m(x) + output_1, output_2 = x + assert output_2 is not None + return output_1, output_2 + + +class GVPConvLayer(nn.Module): + """ + Full graph convolution / message passing layer with + Geometric Vector Perceptrons. + Residually updates node embeddings with + aggregated incoming messages, applies a pointwise feedforward + network to node embeddings, and returns updated node embeddings. + + To only compute the aggregated messages, see `GVPConv`. 
+ """ + + def __init__( + self, + node_dims, + edge_dims, + n_message=3, + n_feedforward=2, + drop_rate=0.1, + activations=(nn.functional.relu, torch.sigmoid), + vector_gate=True, + residual=True, + cutoff: float = -1.0, + ) -> None: + """Full graph convolution / message passing layer with + Geometric Vector Perceptrons. + Residually updates node embeddings with + aggregated incoming messages, applies a pointwise feedforward + network to node embeddings, and returns updated node embeddings. + + To only compute the aggregated messages see `GVPConv` instead. + + Parameters + ---------- + node_dims : + node embedding dimensions (n_scalar, n_vector) + edge_dims : + input edge embedding dimensions (n_scalar, n_vector) + n_message : int, optional + number of GVP layers to be used in message function, by default 3 + n_feedforward : int, optional + number of GVPs to be used in feedforward function, by default 2 + drop_rate : float, optional + drop probability in all dropout layers, by default 0.1 + activations : tuple, optional + activation functions (scalar_act, vector_act) to be used in GVPs, by default (nn.functional.relu, torch.sigmoid) + vector_gate : bool, optional + whether to use vector gating, by default True. 
The vector activation will be used as sigma^+ in vector gating if `True` + residual : bool, optional + whether to perform the update residually, by default True + cutoff : float, optional + radial cutoff, by default -1.0 + """ + super(GVPConvLayer, self).__init__() + self.conv = GVPConv( + node_dims, + node_dims, + edge_dims, + n_message, + aggr='mean', + activations=activations, + vector_gate=vector_gate, + cutoff=cutoff, + ) + GVP_ = functools.partial( + GVP, activations=activations, vector_gate=vector_gate + ) + self.norm = nn.ModuleList([LayerNorm(node_dims) for _ in range(2)]) + self.dropout = nn.ModuleList([Dropout(drop_rate) for _ in range(2)]) + + self._module_list = nn.ModuleList() + if n_feedforward == 1: + self._module_list.append( + GVP_(in_dims=node_dims, + out_dims=node_dims, + activations=(None, None)) + ) + else: + hid_dims = 4 * node_dims[0], 2 * node_dims[1] + self._module_list.append(GVP_(node_dims, hid_dims)) + self._module_list.extend( + GVP_(in_dims=hid_dims, out_dims=hid_dims) for _ in range(n_feedforward - 2) + ) + self._module_list.append( + GVP_(in_dims=hid_dims, out_dims=node_dims, activations=(None, None)) + ) + self.residual = residual + + def forward( + self, + x: Tuple[torch.Tensor, torch.Tensor], + edge_index: torch.Tensor, + edge_attr: Tuple[torch.Tensor, torch.Tensor], + edge_lengths: torch.Tensor, + node_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass of GVPConvLayer + + Parameters + ---------- + x : Tuple[torch.Tensor, torch.Tensor] + Input scalar and vector node embeddings + edge_index : torch.Tensor + Index of edge sources and destinations + edge_attr : Tuple[torch.Tensor, torch.Tensor] + Edge attributes + edge_lengths : torch.Tensor + Edge lengths + node_mask : Optional[torch.Tensor], optional + Mask to restrict the node update to a subset. + It should be a tensor of type `bool` to index the first dim of node embeddings (s, V), by default None. 
+ If not `None`, only the selected nodes will be updated. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + Output scalar and vector node embeddings + """ + + dh = self.conv(x, edge_index, edge_attr, edge_lengths) + + x_ = x + if node_mask is not None: + x, dh = _tuple_index(x, node_mask), _tuple_index(dh, node_mask) + + if self.residual: + input_1, input_2 = self.dropout[0](dh) + assert input_2 is not None + output_1, output_2 = self.norm[0]( + _tuple_sum(x, (input_1, input_2)) + ) + assert output_2 is not None + x = (output_1, output_2) + else: + x = dh + + dh = self.ff_func(x) + if self.residual: + input_1, input_2 = self.dropout[1](dh) + assert input_2 is not None + output_1, output_2 = self.norm[1]( + _tuple_sum(x, (input_1, input_2)) + ) + assert output_2 is not None + x = (output_1, output_2) + else: + x = dh + + if node_mask is not None: + x_[0][node_mask], x_[1][node_mask] = x[0], x[1] + x = x_ + return x + + def ff_func( + self, x: Tuple[torch.Tensor, torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor]: + for m in self._module_list: + x = m(x) + output_1 = x[0] + output_2 = x[1] + assert output_2 is not None + return output_1, output_2 + + +class LayerNorm(nn.Module): + """ + Combined LayerNorm for tuples (s, V). + Takes tuples (s, V) as input and as output. 
+ """ + + def __init__(self, dims) -> None: + super(LayerNorm, self).__init__() + self.s, self.v = dims + self.scalar_norm = nn.LayerNorm(self.s) + + def forward( + self, + x: Tuple[torch.Tensor, Optional[torch.Tensor]] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Forward pass of LayerNorm + + Parameters + ---------- + x : Tuple[torch.Tensor, Optional[torch.Tensor]] + Input channels, if a single tensor is provided it assumes it to be the scalar channel + + Returns + ------- + Tuple[torch.Tensor, Optional[torch.Tensor]] + Normalized channels + """ + + s, v = x + if not self.v: + return self.scalar_norm(s), None + else: + assert v is not None + vn = _norm_no_nan(v, axis=-1, keepdims=True, sqrt=False) + vn = torch.sqrt(torch.mean(vn, dim=-2, keepdim=True)) + return self.scalar_norm(s), v / vn + + +class Dropout(nn.Module): + """ + Combined dropout for tuples (s, V). + Takes tuples (s, V) as input and as output. + """ + + def __init__(self, drop_rate) -> None: + super(Dropout, self).__init__() + self.sdropout = nn.Dropout(drop_rate) + self.vdropout = _VDropout(drop_rate) + + def forward( + self, + x: Tuple[torch.Tensor, Optional[torch.Tensor]] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Forward pass of Dropout + + Parameters + ---------- + x : Tuple[torch.Tensor, Optional[torch.Tensor]] + Input channels, if a single tensor is provided it assumes it to be the scalar channel + + Returns + ------- + Tuple[torch.Tensor, Optional[torch.Tensor]] + Dropped out channels + """ + s, v = x + if v is None: + return self.sdropout(s), None + else: + assert v is not None + return self.sdropout(s), self.vdropout(v) + + +class _VDropout(nn.Module): + """ + Vector channel dropout where the elements of each + vector channel are dropped together. 
+ """ + + def __init__(self, drop_rate) -> None: + super(_VDropout, self).__init__() + self.drop_rate = drop_rate + self.dummy_param = nn.Parameter(torch.empty(0)) + + def forward(self, x : torch.Tensor) -> torch.Tensor: + """Forward pass of _VDropout + + Parameters + ---------- + x : torch.Tensor + Vector channel + + Returns + ------- + torch.Tensor + Dropped out vector channel + """ + device = self.dummy_param.device + if not self.training: + return x + mask = torch.bernoulli( + (1 - self.drop_rate) * torch.ones(x.shape[:-1], device=device) + ).unsqueeze(-1) + x = mask * x / (1 - self.drop_rate) + return x + + +def _tuple_sum( + input_1: Tuple[torch.Tensor, torch.Tensor], + input_2: Tuple[torch.Tensor, torch.Tensor] +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Sums any number of tuples (s, V) elementwise. + """ + out = [i + j for i, j in zip(input_1, input_2)] + return out[0], out[1] + + +@torch.jit.script_if_tracing +def _tuple_cat( + input_1: Tuple[torch.Tensor, torch.Tensor], + input_2: Tuple[torch.Tensor, torch.Tensor], + input_3: Tuple[torch.Tensor, torch.Tensor], + dim: int = -1 +) -> Tuple[torch.Tensor, torch.Tensor]: + """Concatenates any number of tuples (s, V) elementwise. + + Parameters + ---------- + input_1 : Tuple[torch.Tensor, torch.Tensor] + First input to concatenate + input_2 : Tuple[torch.Tensor, torch.Tensor] + Second input to concatenate + input_3 : Tuple[torch.Tensor, torch.Tensor] + Third input to concatenate + dim : int, optional + dimension along which to concatenate when viewed + as the `dim` index for the scalar-channel tensors, by default -1. + This means that `dim=-1` will be applied as + `dim=-2` for the vector-channel tensors. 
+ + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + Concatenated tuple + """ + + dim = int(dim % len(input_1[0].shape)) + s_args, v_args = list(zip(input_1, input_2, input_3)) + return torch.cat(s_args, dim=dim), torch.cat(v_args, dim=dim) + + +@torch.jit.script_if_tracing +def _tuple_index( + x: Tuple[torch.Tensor, torch.Tensor], idx: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Indexes a tuple (s, V) along the first dimension at a given index. + + Parameters + ---------- + x : Tuple[torch.Tensor, torch.Tensor] + Tuple to be indexed + idx : torch.Tensor + any object which can be used to index a `torch.Tensor` + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + Tuple with the element at the given index + """ + return x[0][idx], x[1][idx] + + +@torch.jit.script_if_tracing +def _norm_no_nan( + x: torch.Tensor, + axis: int = -1, + keepdims: bool = False, + eps: float = 1e-8, + sqrt: bool = True +) -> torch.Tensor: + """L2 norm of tensor clamped above a minimum value `eps`. + + Parameters + ---------- + x : torch.Tensor + Input tensor + axis : int, optional + Axis along which to compute the norm, by default -1 + keepdims : bool, optional + Whether to keep the original dimensions, by default False + eps : float, optional + Lowest threshold for clamping the norm, by default 1e-8 + sqrt : bool, optional + Compute the square root in L2 norm, by default True. + If `False`, returns the square of the L2 norm + + Returns + ------- + torch.Tensor + Norm of the input tensor + """ + out = torch.clamp(torch.sum(torch.square(x), axis, keepdims), min=eps) + return torch.sqrt(out) if sqrt else out + + +@torch.jit.script_if_tracing +def _split(x: torch.Tensor, nv: int) -> Tuple[torch.Tensor, torch.Tensor]: + """Splits a merged representation of (s, V) back into a tuple. + Should be used only with `_merge(s, V)` and only if the tuple + representation cannot be used. 
+ + + Parameters + ---------- + x : torch.Tensor + the `torch.Tensor` returned from `_merge` + nv : int + the number of vector channels in the input to `_merge` + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + split representation + """ + s = x[..., :-3 * nv] + v = x[..., -3 * nv:].contiguous().view(x.shape[0], nv, 3) + return s, v + + +@torch.jit.script_if_tracing +def _merge(s: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + """Merges a tuple (s, V) into a single `torch.Tensor`, where the + vector channels are flattened and appended to the scalar channels. + Should be used only if the tuple representation cannot be used. + Use `_split(x, nv)` to reverse. + """ + v = v.contiguous().view(v.shape[0], v.shape[1] * 3) + return torch.cat([s, v], -1) + + +def test_gvp() -> None: + from mlcolvar.core.nn.graph.utils import _test_get_data + from mlcolvar.data.graph.utils import create_graph_tracing_example + + torch.manual_seed(0) + torch.set_default_dtype(torch.float64) + + model = GVPModel( + n_out=2, + cutoff=0.1, + atomic_numbers=[1, 8], + n_bases=6, + n_polynomials=6, + n_layers=2, + n_messages=2, + n_feedforwards=1, + n_scalars_node=16, + n_vectors_node=8, + n_scalars_edge=16, + drop_rate=0, + activation='SiLU', + ) + + data = _test_get_data() + ref_out = torch.tensor([[0.6100070244145421, -0.2559670171962067]] * 6) + assert ( torch.allclose(model(data), ref_out) ) + + traced_model = torch.jit.trace(model, example_inputs=create_graph_tracing_example(2)) + assert ( torch.allclose(traced_model(data), ref_out) ) + + model = GVPModel( + n_out=2, + cutoff=0.1, + atomic_numbers=[1, 8], + n_bases=6, + n_polynomials=6, + n_layers=2, + n_messages=2, + n_feedforwards=2, + n_scalars_node=16, + n_vectors_node=8, + n_scalars_edge=16, + drop_rate=0, + activation='SiLU', + ) + + data = _test_get_data() + ref_out = torch.tensor([[-0.3065361946949377, 0.16624918721972567]] * 6) + assert ( torch.allclose(model(data), ref_out) ) + + traced_model = torch.jit.trace(model, 
example_inputs=create_graph_tracing_example(2)) + assert ( torch.allclose(traced_model(data), ref_out) ) + + + torch.set_default_dtype(torch.float32) + +if __name__ == '__main__': + test_gvp() diff --git a/mlcolvar/core/nn/graph/radial.py b/mlcolvar/core/nn/graph/radial.py new file mode 100644 index 00000000..36224292 --- /dev/null +++ b/mlcolvar/core/nn/graph/radial.py @@ -0,0 +1,381 @@ +import torch +import numpy as np + +""" +The radial functions. This module is taken from MACE directly: +https://github.com/ACEsuit/mace/blob/main/mace/modules/radial.py +""" + +__all__ = ['RadialEmbeddingBlock'] + + +class GaussianBasis(torch.nn.Module): + """ + Gaussian basis functions. + """ + def __init__(self, cutoff: float, n_bases=32) -> None: + """Initialize a Gaussian basis function + + Parameters + ---------- + cutoff : float + Cutoff radius of the basis set + n_bases : int, optional + Size of the basis set, by default 32 + """ + super().__init__() + + offset = torch.linspace( + start=0.0, + end=cutoff, + steps=n_bases, + dtype=torch.get_default_dtype(), + ) + coeff = -0.5 / (offset[1] - offset[0]).item() ** 2 + self.register_buffer( + 'cutoff', torch.tensor(cutoff, dtype=torch.get_default_dtype()) + ) + self.register_buffer( + 'coeff', torch.tensor(coeff, dtype=torch.get_default_dtype()) + ) + self.register_buffer('offset', offset) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + dist = x.view(-1, 1) - self.offset.view(1, -1) + return torch.exp(self.coeff * torch.pow(dist, 2)) + + def __repr__(self) -> str: + result = 'GAUSSIANBASIS [ ' + + data_string = '\033[32m{:d}\033[0m\033[36m 󰯰 \033[0m' + result = result + data_string.format(len(self.offset)) + result = result + '| ' + data_string = '\033[32m{:f}\033[0m\033[36m 󰳁 \033[0m' + result = result + data_string.format(self.cutoff) + result = result + ']' + + return result + + +class BesselBasis(torch.nn.Module): + r""" + Bessel radial basis functions (equation (7) in [1]) + + .. 
math:: RBF_n(d) = \sqrt{\frac{2}{c}}\frac{\sin(\frac{n\pi}{c}d)}{d} + + References + ---------- + .. [1] Gasteiger, J.; Groß, J.; Günnemann, S. Directional Message Passing + for Molecular Graphs; ICLR 2020. + """ + + def __init__(self, cutoff: float, n_bases=8, trainable=False) -> None: + """Initializes Bessel radial basis function + + Parameters + ---------- + cutoff: float + Cutoff radius of the basis set + n_bases: int + Size of the basis set, by default 8 + trainable: bool + Whether to make the basis set parameters trainable, by default False + """ + super().__init__() + + bessel_weights = ( + np.pi + / cutoff + * torch.linspace( + start=1.0, + end=n_bases, + steps=n_bases, + dtype=torch.get_default_dtype(), + ) + ) + if trainable: + self.bessel_weights = torch.nn.Parameter(bessel_weights) + else: + self.register_buffer('bessel_weights', bessel_weights) + + self.register_buffer( + 'cutoff', torch.tensor(cutoff, dtype=torch.get_default_dtype()) + ) + self.register_buffer( + 'prefactor', + torch.tensor( + np.sqrt(2.0 / cutoff), dtype=torch.get_default_dtype() + ) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + numerator = torch.sin(self.bessel_weights * x) + return self.prefactor * (numerator / x) + + def __repr__(self) -> str: + result = 'BESSELBASIS [ ' + + data_string = '\033[32m{:d}\033[0m\033[36m 󰯰 \033[0m' + result = result + data_string.format(len(self.bessel_weights)) + result = result + '| ' + data_string = '\033[32m{:f}\033[0m\033[36m 󰳁 \033[0m' + result = result + data_string.format(self.cutoff) + if self.bessel_weights.requires_grad: + result = result + '|\033[36m TRAINABLE \033[0m' + result = result + ']' + + return result + + +class PolynomialCutoff(torch.nn.Module): + r"""Continuous cutoff function (equation (8) in [1]) + + .. math:: u(d) = 1 - \frac{(p+1)(p+2)}{2}d^p + p(p+2)d^{p+1} - \frac{p(p+1)}{2}d^{p+2} + + References + ---------- + .. [1] Gasteiger, J.; Groß, J.; Günnemann, S. Directional Message Passing + for Molecular Graphs; ICLR 2020. 
+ """ + p: torch.Tensor + cutoff: torch.Tensor + + def __init__(self, cutoff: float, p: int = 6) -> None: + """Initializes a polynomial cutoff function. + + Parameters + ---------- + cutoff: float + The cutoff radius. + p: int + Order of the polynomial, by default 6 + """ + super().__init__() + self.register_buffer( + 'p', torch.tensor(p, dtype=torch.get_default_dtype()) + ) + self.register_buffer( + 'cutoff', torch.tensor(cutoff, dtype=torch.get_default_dtype()) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # fmt: off + envelope = ( + 1.0 + - (self.p + 1.0) * (self.p + 2.0) / 2.0 + * torch.pow(x / self.cutoff, self.p) + + self.p * (self.p + 2.0) + * torch.pow(x / self.cutoff, self.p + 1) + - self.p * (self.p + 1.0) / 2 + * torch.pow(x / self.cutoff, self.p + 2) + ) + # fmt: on + + # noinspection PyUnresolvedReferences + return envelope * (x < self.cutoff) + + def __repr__(self) -> str: + result = 'POLYNOMIALCUTOFF [ ' + + data_string = '\033[32m{:d}\033[0m\033[36m 󰰚 \033[0m' + result = result + data_string.format(int(self.p)) + result = result + '| ' + data_string = '\033[32m{:f}\033[0m\033[36m 󰳁 \033[0m' + result = result + data_string.format(self.cutoff) + result = result + ']' + + return result + + +class RadialEmbeddingBlock(torch.nn.Module): + """ + Radial embedding block [1] + + References + ---------- + .. [1] Gasteiger, J.; Groß, J.; Günnemann, S. Directional Message Passing + for Molecular Graphs; ICLR 2020. + """ + + def __init__( + self, + cutoff: float, + n_bases: int = 8, + n_polynomials: int = 6, + basis_type: str = 'bessel', + ) -> None: + """Initializes a radial embedding block + + Parameters + ---------- + cutoff : float + Cutoff radius. 
+ n_bases : int, optional + Size of the basis set, by default 8 + n_polynomials : int, optional + Order of the polynomial for the polynomial cutoff, by default 6 + basis_type : str, optional + Type of the basis function, either 'bessel' or 'gaussian', by default 'bessel' + + Raises + ------ + RuntimeError + If `basis_type` is neither 'bessel' nor 'gaussian' + """ + super().__init__() + self.n_out = n_bases + if basis_type == 'bessel': + self.bessel_fn = BesselBasis(cutoff=cutoff, n_bases=n_bases) + self.cutoff_fn = PolynomialCutoff(cutoff=cutoff, p=n_polynomials) + elif basis_type == 'gaussian': + self.bessel_fn = GaussianBasis(cutoff=cutoff, n_bases=n_bases) + self.cutoff_fn = None + else: + raise RuntimeError( + 'Unknown basis function type "{:s}" !'.format(basis_type) + ) + + def forward(self, edge_lengths: torch.Tensor) -> torch.Tensor: + """ + The forward pass of RadialEmbeddingBlock + + Parameters + ---------- + edge_lengths: torch.Tensor (shape: [n_edges, 1]) + Lengths of edges. + + Returns + ------- + edge_embedding: torch.Tensor (shape: [n_edges, n_bases]) + The radial edge embedding. 
+ """ + r = self.bessel_fn(edge_lengths) # shape: [n_edges, n_bases] + if self.cutoff_fn is not None: + c = self.cutoff_fn(edge_lengths) # shape: [n_edges, 1] + return r * c + else: + return r + + +def test_bessel_basis() -> None: + dtype = torch.get_default_dtype() + torch.set_default_dtype(torch.float64) + + data = torch.tensor([ + [0.30216178425160090, 0.603495364055576400], + [0.29735174147757487, 0.565596622727919000], + [0.28586135770645804, 0.479487014442650350], + [0.26815929064765680, 0.358867177503655900], + [0.24496326504279375, 0.222421990229218020], + [0.21720530022724968, 0.090319042449653110], + [0.18598678410040770, -0.019467592388889482], + [0.15252575991598738, -0.094266103787986490], + [0.11809918979627002, -0.128642857533393970], + [0.08398320341397922, -0.124823366088228150] + ]) + + rbf = BesselBasis(6.0, 2) + + data_new = torch.stack( + [rbf(torch.ones(1) * i * 0.5 + 0.1) for i in range(0, 10)] + ) + + assert (torch.abs(data - data_new) < 1E-12).all() + + torch.set_default_dtype(dtype) + + print(rbf) + + +def test_gaussian_basis() -> None: + dtype = torch.get_default_dtype() + torch.set_default_dtype(torch.float64) + + data = torch.tensor([ + [0.9998611207557263, 0.6166385641763439], + [0.9950124791926823, 0.6669768108584744], + [0.9833348700493460, 0.7164317992468783], + [0.9650691177896804, 0.7642281651714904], + [0.9405880633643421, 0.8095716486678869], + [0.9103839103891423, 0.8516705072294410], + [0.8750517756337902, 0.8897581848801761], + [0.8352702114112720, 0.9231163463866358], + [0.7917795893122607, 0.9510973184771084], + [0.7453593045429805, 0.9731449630580510] + ]) + + rbf = GaussianBasis(6.0, 2) + + data_new = torch.stack( + [rbf(torch.ones(1) * i * 0.5 + 0.1)[0] for i in range(0, 10)] + ) + + assert (torch.abs(data - data_new) < 1E-12).all() + + torch.set_default_dtype(dtype) + + print(rbf) + + +def test_polynomial_cutoff() -> None: + dtype = torch.get_default_dtype() + torch.set_default_dtype(torch.float64) + + data = 
torch.tensor([ + [1.0000000000000000], + [0.9999919136092714], + [0.9995588277320531], + [0.9957733154296875], + [0.9803383630544124], + [0.9390599059360889], + [0.8554687500000000], + [0.7184512221655127], + [0.5317786922725198], + [0.3214569091796875] + ]) + + cutoff_function = PolynomialCutoff(6.0) + + data_new = torch.stack( + [cutoff_function(torch.ones(1) * i * 0.5) for i in range(0, 10)] + ) + + assert (torch.abs(data - data_new) < 1E-12).all() + + torch.set_default_dtype(dtype) + + print(cutoff_function) + +def test_radial_embedding_block(): + dtype = torch.get_default_dtype() + torch.set_default_dtype(torch.float64) + + data = torch.tensor([ + [0.302161784075405670, 0.603495363703668900], + [0.297344780473306900, 0.565583382110980900], + [0.285645292705329600, 0.479124599728231300], + [0.266549578182040000, 0.356712961747292670], + [0.238761404317085600, 0.216790818528859370], + [0.201179558989195350, 0.083655164534829570], + [0.154832684273361420, -0.016206633178216297], + [0.104419964978618930, -0.064535087460860160], + [0.057909938358517744, -0.063080025890725560], + [0.023554408472511446, -0.035008673547055544] + ]) + + embedding = RadialEmbeddingBlock(6, 2, 6) + + data_new = torch.stack( + [embedding(torch.ones(1) * i * 0.5 + 0.1) for i in range(0, 10)] + ) + + assert (torch.abs(data - data_new) < 1E-12).all() + + torch.set_default_dtype(dtype) + + +if __name__ == '__main__': + test_bessel_basis() + test_gaussian_basis() + test_polynomial_cutoff() + test_radial_embedding_block() diff --git a/mlcolvar/core/nn/graph/schnet.py b/mlcolvar/core/nn/graph/schnet.py new file mode 100644 index 00000000..9b2c0f17 --- /dev/null +++ b/mlcolvar/core/nn/graph/schnet.py @@ -0,0 +1,382 @@ +import math +import torch +from torch import nn +from torch_geometric.nn import MessagePassing + +from mlcolvar.core.nn.graph.gnn import BaseGNN + +from typing import List, Dict + +""" +The SchNet components. 
This module is taken from the PyG (torch_geometric) package: +https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/models/schnet.py +""" + +__all__ = ["SchNetModel", "InteractionBlock", "ShiftedSoftplus"] + +class SchNetModel(BaseGNN): + """ + The SchNet [1] model. This implementation is taken from torch_geometric: + https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/models/schnet.py + + Parameters + ---------- + n_out: int + Size of the output node features. + cutoff: float + Cutoff radius of the basis functions. Should be the same as the cutoff + radius used to build the graphs. + atomic_numbers: List[int] + The atomic numbers mapping, e.g. the `atomic_numbers` attribute of a + `mlcolvar.graph.data.GraphDataSet` instance. + n_bases: int + Size of the basis set. + n_layers: int + Number of the graph convolution layers. + n_filters: int + Number of filters. + n_hidden_channels: int + Size of hidden embeddings. + aggr: str + Type of aggregation function for the GNN message passing. + w_out_after_pool: bool + Whether to apply the readout MLP layer after the pooling. + References + ---------- + .. [1] Schütt, Kristof T., et al. "Schnet–a deep learning architecture for + molecules and materials." The Journal of Chemical Physics 148.24 + (2018). + """ + + def __init__( + self, + n_out: int, + cutoff: float, + atomic_numbers: List[int], + pooling_operation : str = 'mean', + n_bases: int = 16, + n_layers: int = 2, + n_filters: int = 16, + n_hidden_channels: int = 16, + aggr: str = 'mean', + w_out_after_pool: bool = False, + ) -> None: + """The SchNet model. This implementation is taken from torch_geometric: + https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/models/schnet.py + + Parameters + ---------- + n_out : int + Size of the output node features. + cutoff : float + Cutoff radius of the basis functions. Should be the same as the cutoff + radius used to build the graphs. 
+ atomic_numbers : List[int] + The atomic numbers mapping. + pooling_operation : str + Type of pooling operation to combine node-level features into graph-level features, either 'mean' or 'sum', by default 'mean' + n_bases : int, optional + Size of the basis set used for the embedding, by default 16 + n_layers : int, optional + Number of the graph convolution layers, by default 2 + n_filters : int, optional + Number of filters, by default 16 + n_hidden_channels : int, optional + Size of hidden embeddings, by default 16 + aggr : str, optional + Type of the GNN aggregation function, by default 'mean' + w_out_after_pool : bool, optional + Whether to apply the last linear transformation from hidden to output channels after the pooling, by default False + """ + + super().__init__( + n_out=n_out, + cutoff=cutoff, + atomic_numbers=atomic_numbers, + pooling_operation=pooling_operation, + n_bases=n_bases, + n_polynomials=0, + basis_type='gaussian' + ) + + # transforms embedding into hidden channels + self.W_v = nn.Linear( + in_features=len(atomic_numbers), + out_features=n_hidden_channels, + bias=False + ) + + # initialize layers with interaction blocks + self.layers = nn.ModuleList([ + InteractionBlock( + n_hidden_channels, n_bases, n_filters, cutoff, aggr + ) + for _ in range(n_layers) + ]) + + # transforms hidden channels into output channels + self.W_out = nn.ModuleList([ + nn.Linear(n_hidden_channels, n_hidden_channels // 2), + ShiftedSoftplus(), + nn.Linear(n_hidden_channels // 2, n_out) + ]) + + self._w_out_after_pool = w_out_after_pool + + self.reset_parameters() + + def reset_parameters(self) -> None: + """ + Resets all learnable parameters of the module. 
+ """ + self.W_v.reset_parameters() + + for layer in self.layers: + layer.reset_parameters() + + nn.init.xavier_uniform_(self.W_out[0].weight) + self.W_out[0].bias.data.fill_(0) + nn.init.xavier_uniform_(self.W_out[2].weight) + self.W_out[2].bias.data.fill_(0) + + def forward( + self, data: Dict[str, torch.Tensor], pool: bool = True + ) -> torch.Tensor: + """ + The forward pass. + Parameters + ---------- + data: Dict[str, torch.Tensor] + The data dict. Usually comes from the `to_dict` method of a + `torch_geometric.data.Batch` object. + pool: bool + Whether to apply the pooling to the model output. + """ + + # embed edges and node attrs + h_E = self.embed_edge(data) + h_V = self.W_v(data['node_attrs']) + + # update through layers + for layer in self.layers: + h_V = h_V + layer(h_V, data['edge_index'], h_E[0], h_E[1]) + + # in case the last linear transformation is performed BEFORE pooling + if not self._w_out_after_pool: + for w in self.W_out: + h_V = w(h_V) + out = h_V + + # perform pooling of the node-level outputs + if pool: + out = self.pooling(input=out, data=data) + + # in case the last linear transformation is performed AFTER pooling + if self._w_out_after_pool: + for w in self.W_out: + out = w(out) + + return out + +class InteractionBlock(nn.Module): + def __init__( + self, + hidden_channels: int, + num_gaussians: int, + num_filters: int, + cutoff: float, + aggr: str = 'mean' + ) -> None: + """SchNet interaction block + + Parameters + ---------- + hidden_channels : int + Size of hidden embeddings + num_gaussians : int + Number of Gaussians for the embedding + num_filters : int + Number of filters + cutoff : float + Radial cutoff + aggr : str, optional + Aggregation function, by default 'mean' + """ + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(num_gaussians, num_filters), + ShiftedSoftplus(), + nn.Linear(num_filters, num_filters), + ) + self.conv = CFConv( + hidden_channels, + hidden_channels, + num_filters, + self.mlp, + cutoff, + aggr + ) + 
self.act = ShiftedSoftplus() + self.lin = nn.Linear(hidden_channels, hidden_channels) + + self.reset_parameters() + + def reset_parameters(self) -> None: + """ + Resets all learnable parameters of the module. + """ + nn.init.xavier_uniform_(self.mlp[0].weight) + self.mlp[0].bias.data.fill_(0) + nn.init.xavier_uniform_(self.mlp[2].weight) + self.mlp[2].bias.data.fill_(0) + self.conv.reset_parameters() + nn.init.xavier_uniform_(self.lin.weight) + self.lin.bias.data.fill_(0) + + def forward( + self, + x: torch.Tensor, + edge_index: torch.Tensor, + edge_weight: torch.Tensor, + edge_attr: torch.Tensor, + ) -> torch.Tensor: + x = self.conv(x, edge_index, edge_weight, edge_attr) + x = self.act(x) + x = self.lin(x) + return x + + +class CFConv(MessagePassing): + """Continuous-filter convolution from SchNet""" + def __init__( + self, + in_channels: int, + out_channels: int, + num_filters: int, + network: nn.Sequential, + cutoff: float, + aggr: str = 'mean' + ) -> None: + """Applies a continuous-filter convolution + + Parameters + ---------- + in_channels : int + Number of input channels + out_channels : int + Number of output channels + num_filters : int + Number of filters + network : nn.Sequential + Neural network + cutoff : float + Radial cutoff + aggr : str, optional + Aggregation function, by default 'mean' + """ + super().__init__(aggr=aggr) + self.lin1 = nn.Linear(in_channels, num_filters, bias=False) + self.lin2 = nn.Linear(num_filters, out_channels) + self.network = network + self.cutoff = cutoff + + self.reset_parameters() + + def reset_parameters(self): + nn.init.xavier_uniform_(self.lin1.weight) + nn.init.xavier_uniform_(self.lin2.weight) + self.lin2.bias.data.fill_(0) + + def forward( + self, + x: torch.Tensor, + edge_index: torch.Tensor, + edge_weight: torch.Tensor, + edge_attr: torch.Tensor, + ) -> torch.Tensor: + C = 0.5 * (torch.cos(edge_weight * math.pi / self.cutoff) + 1.0) + W = self.network(edge_attr) * C.view(-1, 1) + + x = self.lin1(x) + x = 
self.propagate(edge_index, x=x, W=W) + x = self.lin2(x) + return x + + def message(self, x_j: torch.Tensor, W: torch.Tensor) -> torch.Tensor: + return x_j * W + +# TODO maybe remove and use the common one +class ShiftedSoftplus(nn.Module): + def __init__(self) -> None: + super().__init__() + self.shift = torch.log(torch.tensor(2.0)).item() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return nn.functional.softplus(x) - self.shift + + + +from mlcolvar.core.nn.graph.utils import _test_get_data +from mlcolvar.data.graph.utils import create_graph_tracing_example + +def test_schnet_1() -> None: + torch.manual_seed(0) + torch.set_default_dtype(torch.float64) + + model = SchNetModel( + n_out=2, + cutoff=0.1, + atomic_numbers=[1, 8], + n_bases=6, + n_layers=2, + n_filters=16, + n_hidden_channels=16 + ) + + data = _test_get_data() + ref_out = torch.tensor([[0.40384621527953063, -0.1257513365138969]] * 6) + assert ( torch.allclose(model(data), ref_out) ) + + model = SchNetModel( + n_out=2, + cutoff=0.1, + atomic_numbers=[1, 8], + n_bases=6, + n_layers=2, + n_filters=16, + n_hidden_channels=16, + pooling_operation='sum', + ) + + data = _test_get_data() + ref_out = torch.tensor([[0.5760462255365488, -0.4465858318467991]] * 6) + assert ( torch.allclose(model(data), ref_out) ) + + traced_model = torch.jit.trace(model, example_inputs=create_graph_tracing_example(2)) + assert ( torch.allclose(traced_model(data), ref_out) ) + + +def test_schnet_2() -> None: + torch.manual_seed(0) + torch.set_default_dtype(torch.float64) + + model = SchNetModel( + n_out=2, + cutoff=0.1, + atomic_numbers=[1, 8], + n_bases=6, + n_layers=2, + n_filters=16, + n_hidden_channels=16, + aggr='min', + w_out_after_pool=True + ) + + data = _test_get_data() + ref_out = torch.tensor([[0.3654537816221449, -0.0748265132499575]] * 6) + assert ( torch.allclose(model(data), ref_out) ) + + torch.set_default_dtype(torch.float32) + +if __name__ == "__main__": + test_schnet_1() \ No newline at end of file 
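Reviewer note on `schnet.py`: the two scalar functions that shape the SchNet filters in this file are the cosine envelope in `CFConv.forward` and the `ShiftedSoftplus` activation. The sketch below restates them in plain Python, only to make their boundary values easy to check by hand. The function names `cosine_cutoff` and `shifted_softplus` are illustrative and are not part of the patch; the tensor versions above remain the reference.

```python
import math


def cosine_cutoff(r: float, cutoff: float) -> float:
    """Scalar version of the envelope `C` computed in `CFConv.forward`.

    Equals 1 at r = 0 and 0 at r = cutoff. The expression itself is not
    clamped for r > cutoff, so it relies on edges being built with the
    same cutoff radius.
    """
    return 0.5 * (math.cos(r * math.pi / cutoff) + 1.0)


def shifted_softplus(x: float) -> float:
    """Scalar version of `ShiftedSoftplus`: softplus(x) - log(2).

    The shift makes the activation vanish at x = 0.
    """
    # log1p(exp(x)) is softplus; adequate for moderate x in this sketch
    return math.log1p(math.exp(x)) - math.log(2.0)


if __name__ == "__main__":
    rc = 0.1  # same cutoff value used by the tests in this file
    for r in (0.0, 0.5 * rc, rc):
        print(f"r = {r:.3f}  C(r) = {cosine_cutoff(r, rc):.6f}")
    print(f"ssp(0) = {shifted_softplus(0.0):.6f}")
```

Because the envelope reaches zero at the cutoff, filters generated from `edge_attr` fade out smoothly for edges close to the cutoff radius, which is presumably why the docstrings ask for the model cutoff to match the one used to build the graphs.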
diff --git a/mlcolvar/core/nn/graph/utils.py b/mlcolvar/core/nn/graph/utils.py new file mode 100644 index 00000000..1a0e3780 --- /dev/null +++ b/mlcolvar/core/nn/graph/utils.py @@ -0,0 +1,54 @@ +import torch +import torch_geometric +import numpy as np + +from mlcolvar.data.graph import atomic, create_dataset_from_configurations +from mlcolvar.data import DictModule + + +def _test_get_data() -> torch_geometric.data.Batch: + # TODO: This is not a real test, but a helper function for other tests. + # Maybe should change its name. + torch.manual_seed(0) + torch.set_default_dtype(torch.float64) + + numbers = [8, 1, 1] + positions = np.array( + [ + [[0.0, 0.0, 0.0], [0.07, 0.07, 0.0], [0.07, -0.07, 0.0]], + [[0.0, 0.0, 0.0], [-0.07, 0.07, 0.0], [0.07, 0.07, 0.0]], + [[0.0, 0.0, 0.0], [0.07, -0.07, 0.0], [0.07, 0.07, 0.0]], + [[0.0, 0.0, 0.0], [0.0, -0.07, 0.07], [0.0, 0.07, 0.07]], + [[0.0, 0.0, 0.0], [0.07, 0.0, 0.07], [-0.07, 0.0, 0.07]], + [[0.1, 0.0, 1.1], [0.17, 0.07, 1.1], [0.17, -0.07, 1.1]], + ], + dtype=np.float64 + ) + cell = np.identity(3, dtype=float) * 0.2 + graph_labels = np.array([[1]]) + node_labels = np.array([[0], [1], [1]]) + z_table = atomic.AtomicNumberTable.from_zs(numbers) + + config = [ + atomic.Configuration( + atomic_numbers=numbers, + positions=p, + cell=cell, + pbc=[True] * 3, + node_labels=node_labels, + graph_labels=graph_labels, + ) for p in positions + ] + dataset = create_dataset_from_configurations( + config, z_table, 0.1, show_progress=False + ) + + datamodule = DictModule( + dataset, + lengths=(1.0,), + batch_size=10, + shuffle=False, + ) + datamodule.setup() + + return next(iter(datamodule.train_dataloader()))['data_list'] \ No newline at end of file diff --git a/mlcolvar/core/transform/__init__.py b/mlcolvar/core/transform/__init__.py index 51d14adf..feaf3d19 100644 --- a/mlcolvar/core/transform/__init__.py +++ b/mlcolvar/core/transform/__init__.py @@ -1,4 +1,14 @@ -__all__ = 
["Transform","Normalization","Statistics","SwitchingFunctions","MultipleDescriptors","PairwiseDistances","EigsAdjMat","ContinuousHistogram","Inverse","TorsionalAngle","SequentialTransform"] +__all__ = ["Transform", + "Normalization", + "Statistics", + "SwitchingFunctions", + "MultipleDescriptors", + "PairwiseDistances", + "EigsAdjMat", + "ContinuousHistogram", + "Inverse", + "TorsionalAngle", + "SequentialTransform"] from .transform import * from .utils import * diff --git a/mlcolvar/core/transform/tools/utils.py b/mlcolvar/core/transform/tools/utils.py index 6705ffff..3a73e589 100644 --- a/mlcolvar/core/transform/tools/utils.py +++ b/mlcolvar/core/transform/tools/utils.py @@ -3,7 +3,7 @@ from typing import Union, List -def batch_reshape(t: torch.Tensor, size: torch.Size) -> torch.Tensor: +def batch_reshape(t: torch.Tensor, size: List[int]) -> torch.Tensor: """Return value reshaped according to size. In case of batch unsqueeze and expand along the first dimension. For single inputs just pass. 
diff --git a/mlcolvar/core/transform/utils.py b/mlcolvar/core/transform/utils.py index 5a769ecb..23c35fad 100644 --- a/mlcolvar/core/transform/utils.py +++ b/mlcolvar/core/transform/utils.py @@ -146,7 +146,7 @@ def test_sequential_transform(): import lightning masses = initialize_committor_masses(atom_types=[0,0,0,0], masses=[1.008]) - model = Committor(layers=[6,2,1], atomic_masses=masses, alpha=1) + model = Committor(model=[6,2,1], atomic_masses=masses, alpha=1) model.preprocessing = sequential pos = torch.rand((5, 4, 3)) diff --git a/mlcolvar/cvs/committor/committor.py b/mlcolvar/cvs/committor/committor.py index 65dd5ceb..ad75217f 100644 --- a/mlcolvar/cvs/committor/committor.py +++ b/mlcolvar/cvs/committor/committor.py @@ -1,9 +1,10 @@ import torch import lightning from mlcolvar.cvs import BaseCV -from mlcolvar.core import FeedForward +from mlcolvar.core import FeedForward, BaseGNN from mlcolvar.core.loss import CommittorLoss from mlcolvar.core.nn.utils import Custom_Sigmoid +from typing import Union, List __all__ = ["Committor"] @@ -13,8 +14,10 @@ class Committor(BaseCV, lightning.LightningModule): The committor function q is expressed as the output of a neural network optimized with a self-consistent approach based on the Kolmogorov's variational principle for the committor and on the imposition of its boundary conditions. - **Data**: for training it requires a DictDataset with the keys 'data', 'labels' and 'weights' - + **Data**: for training it requires a DictDataset containing: + - If using descriptors as input, the keys 'data', 'labels' and 'weights'. + - If using graphs as input, `torch_geometric.data` with 'graph_labels' and 'weight' in their 'data_list'. + **Loss**: Minimize Kolmogorov's variational functional of q and impose boundary condition on the metastable states (CommittorLoss) References @@ -34,11 +37,12 @@ class Committor(BaseCV, lightning.LightningModule): Class to optimize the gradients calculation imporving speed and memory efficiency. 
""" - BLOCKS = ["nn", "sigmoid"] + DEFAULT_BLOCKS = ["nn", "sigmoid"] + MODEL_BLOCKS = ["nn", "sigmoid"] def __init__( self, - layers: list, + model: Union[List[int], FeedForward, BaseGNN], atomic_masses: torch.Tensor, alpha: float, gamma: float = 10000, @@ -74,7 +78,7 @@ def __init__( separate_boundary_dataset : bool, optional Switch to exculde boundary condition labeled data from the variational loss, by default True descriptors_derivatives : torch.nn.Module, optional - `SmartDerivatives` object to save memory and time when using descriptors. + `SmartDerivatives` object to save memory and time when using descriptors. Cannot be used with GNN models. See also mlcolvar.core.loss.committor_loss.SmartDerivatives log_var : bool, optional Switch to minimize the log of the variational functional, by default False. @@ -88,16 +92,18 @@ def __init__( Number of dimensions, by default 3. options : dict[str, Any], optional Options for the building blocks of the model, by default {}. - Available blocks: ['nn'] . + Available blocks: ['nn']. 
""" - super().__init__(in_features=layers[0], out_features=layers[-1], **kwargs) + super().__init__(model, **kwargs) + + self.register_buffer('is_committor', torch.tensor(1, dtype=int)) # ======= LOSS ======= self.loss_fn = CommittorLoss(atomic_masses=atomic_masses, alpha=alpha, + cell=cell, gamma=gamma, delta_f=delta_f, - cell=cell, separate_boundary_dataset=separate_boundary_dataset, descriptors_derivatives=descriptors_derivatives, log_var=log_var, @@ -111,12 +117,18 @@ def __init__( options = self.parse_options(options) # ======= BLOCKS ======= - # initialize NN turning - o = "nn" - # set default activation to tanh - if "activation" not in options[o]: - options[o]["activation"] = "tanh" - self.nn = FeedForward(layers, **options[o]) + if not self._override_model: + # initialize NN + o = "nn" + # set default activation to tanh + if "activation" not in options[o]: + options[o]["activation"] = "tanh" + self.nn = FeedForward(self.layers, **options[o]) + elif self._override_model: + self.nn = model + + if self.nn.out_features != 1: + raise ValueError('Output of the model must be of dimension 1') # separately add sigmoid activation on last layer, this way it can be deactived o = "sigmoid" @@ -134,13 +146,19 @@ def training_step(self, train_batch, batch_idx): """Compute and return the training loss and record metrics.""" # =================get data=================== - x = train_batch["data"] - # check data are have shape (n_data, -1) - x = x.reshape((x.shape[0], -1)) - x.requires_grad = True - - labels = train_batch["labels"] - weights = train_batch["weights"] + if isinstance(self.nn, FeedForward): + x = train_batch["data"] + # check data have shape (n_data, -1) + x = x.reshape((x.shape[0], -1)) + x.requires_grad = True + + labels = train_batch["labels"] + weights = train_batch["weights"] + elif isinstance(self.nn, BaseGNN): + x = self._setup_graph_data(train_batch) + labels = x['graph_labels'] + weights = x['weight'].clone() + try: ref_idx = train_batch["ref_idx"] 
except KeyError: @@ -172,8 +190,7 @@ def training_step(self, train_batch, batch_idx): self.log(f"{name}_loss_bound_B", loss_bound_B, on_epoch=True) return loss - -def test_committor(): +def test_committor_1(): from mlcolvar.data import DictDataset, DictModule from mlcolvar.cvs.committor.utils import initialize_committor_masses, KolmogorovBias @@ -218,7 +235,7 @@ def test_committor(): -6.7121, -7.6094, -7.9009, -7.0479, -5.2398, -7.8241, -5.8642, -7.0701, -7.0348, -7.2577, -6.6142, -7.6322, -7.3279, -7.6393, -7.8608, -7.7037, -6.6949, -6.3947, -7.2246, -7.7009, -6.7359, -7.2186, -7.7849, -5.6882]) - model = Committor(layers=[6, 4, 2, 1], atomic_masses=atomic_masses, alpha=1e-1, delta_f=0) + model = Committor(model=[6, 4, 2, 1], atomic_masses=atomic_masses, alpha=1e-1, delta_f=0) trainer.fit(model, datamodule) out = model(X) out.sum().backward() @@ -238,7 +255,7 @@ def test_committor(): [0.0783],[0.1384],[0.0689],[0.0649],[0.0983],[0.1548],[0.0778],[0.0934],[0.0858],[0.1203], [0.1073],[0.1139],[0.0716],[0.0988],[0.0918],[0.1109],[0.0918],[0.0928],[0.1070],[0.0742]]) trainer = lightning.Trainer(max_epochs=5, logger=None, enable_checkpointing=False, limit_val_batches=0, num_sanity_val_steps=0) - model = Committor(layers=[6, 4, 2, 1], atomic_masses=atomic_masses, alpha=1e-1, delta_f=0, separate_boundary_dataset=False) + model = Committor(model=[6, 4, 2, 1], atomic_masses=atomic_masses, alpha=1e-1, delta_f=0, separate_boundary_dataset=False) trainer.fit(model, datamodule) out = model(X) out.sum().backward() @@ -254,7 +271,7 @@ def test_committor(): [0.7714],[0.5826],[0.6442],[0.5796],[0.6132],[0.5923],[0.7023],[0.5731],[0.7308],[0.6404], [0.5781],[0.6850],[0.5960],[0.6718],[0.6626],[0.6069],[0.7319],[0.5498],[0.6772],[0.5847]]) trainer = lightning.Trainer(max_epochs=5, logger=None, enable_checkpointing=False, limit_val_batches=0, num_sanity_val_steps=0) - model = Committor(layers=[6, 4, 2, 1], atomic_masses=atomic_masses, alpha=1e-1, delta_f=0, log_var=True) + model = 
Committor(model=[6, 4, 2, 1], atomic_masses=atomic_masses, alpha=1e-1, delta_f=0, log_var=True) trainer.fit(model, datamodule) out = model(X) out.sum().backward() @@ -270,7 +287,7 @@ def test_committor(): [0.1337],[0.1444],[0.1603],[0.1396],[0.2043],[0.1964],[0.1459],[0.2243],[0.1930],[0.1893], [0.2634],[0.1868],[0.1340],[0.2483],[0.1550],[0.1559],[0.1614],[0.2020],[0.1270],[0.2555]]) trainer = lightning.Trainer(max_epochs=5, logger=None, enable_checkpointing=False, limit_val_batches=0, num_sanity_val_steps=0) - model = Committor(layers=[6, 4, 2, 1], atomic_masses=atomic_masses, alpha=1e-1, delta_f=0, z_regularization=100, z_threshold=0.000001) + model = Committor(model=[6, 4, 2, 1], atomic_masses=atomic_masses, alpha=1e-1, delta_f=0, z_regularization=100, z_threshold=0.000001) trainer.fit(model, datamodule) out = model(X) out.sum().backward() @@ -281,7 +298,7 @@ def test_committor(): for z_regularization, z_threshold in zip([10, 0, -1, 10], [None, 10, 1, -1]): try: - model = Committor(layers=[6, 4, 2, 1], atomic_masses=atomic_masses, alpha=1e-1, delta_f=0, z_regularization=z_regularization, z_threshold=z_threshold, n_dim=2) + model = Committor(model=[6, 4, 2, 1], atomic_masses=atomic_masses, alpha=1e-1, delta_f=0, z_regularization=z_regularization, z_threshold=z_threshold, n_dim=2) trainer.fit(model, datamodule) except ValueError as e: print("[TEST LOG] Checked this error: ", e) @@ -289,10 +306,91 @@ def test_committor(): # test dimension error try: trainer = lightning.Trainer(max_epochs=5, logger=None, enable_checkpointing=False, limit_val_batches=0, num_sanity_val_steps=0) - model = Committor(layers=[6, 4, 2, 1], atomic_masses=atomic_masses, alpha=1e-1, delta_f=0, z_regularization=10, z_threshold=1, n_dim=2) + model = Committor(model=[6, 4, 2, 1], atomic_masses=atomic_masses, alpha=1e-1, delta_f=0, z_regularization=10, z_threshold=1, n_dim=2) trainer.fit(model, datamodule) except RuntimeError as e: print("[TEST LOG] Checked this error: ", e) + + +def 
test_committor_2(): + from mlcolvar.data import DictDataset, DictModule + from mlcolvar.cvs.committor.utils import initialize_committor_masses, KolmogorovBias + + # create two fake atoms and use their fake positions + atomic_masses = initialize_committor_masses(atom_types=[0,1], masses=[15.999, 1.008]) + # create dataset + samples = 50 + X = torch.randn((4*samples, 6)) + + # create labels + y = torch.zeros(X.shape[0]) + y[samples:] += 1 + y[int(2*samples):] += 1 + y[int(3*samples):] += 1 + + # create weights + w = torch.ones(X.shape[0]) + + dataset = DictDataset({"data": X, "labels": y, "weights": w}) + datamodule = DictModule(dataset, lengths=[1]) + + # train model + trainer = lightning.Trainer(max_epochs=5, logger=False, enable_checkpointing=False, limit_val_batches=0, num_sanity_val_steps=0) + + print() + print('NORMAL') + print() + # dataset separation + model = Committor(model=[6, 4, 2, 1], atomic_masses=atomic_masses, alpha=1e-1, delta_f=0) + trainer.fit(model, datamodule) + model(X).sum().backward() + bias_model = KolmogorovBias(input_model=model, beta=1, epsilon=1e-6, lambd=1) + bias_model(X) + + # naive whole dataset + trainer = lightning.Trainer(max_epochs=5, logger=False, enable_checkpointing=False, limit_val_batches=0, num_sanity_val_steps=0) + model = Committor(model=[6, 4, 2, 1], atomic_masses=atomic_masses, alpha=1e-1, delta_f=0, separate_boundary_dataset=False) + trainer.fit(model, datamodule) + model(X).sum().backward() + + + print() + print('EXTERNAL FEEDFORWARD') + print() + # dataset separation + ff_model = FeedForward([6, 4, 2, 1]) + model = Committor(model=ff_model, atomic_masses=atomic_masses, alpha=1e-1, delta_f=0) + trainer.fit(model, datamodule) + model(X).sum().backward() + bias_model = KolmogorovBias(input_model=model, beta=1, epsilon=1e-6, lambd=1) + bias_model(X) + + # naive whole dataset + trainer = lightning.Trainer(max_epochs=5, logger=False, enable_checkpointing=False, limit_val_batches=0, num_sanity_val_steps=0) + model = 
Committor(model=ff_model, atomic_masses=atomic_masses, alpha=1e-1, delta_f=0, separate_boundary_dataset=False) + trainer.fit(model, datamodule) + model(X).sum().backward() + + + print() + print('EXTERNAL GNN') + print() + from mlcolvar.core.nn.graph import SchNetModel + from mlcolvar.data.graph.utils import create_test_graph_input + gnn_model = SchNetModel(1, 0.1, [1, 8]) + + model = Committor(model=gnn_model, + atomic_masses=atomic_masses, + alpha=1e-1, + delta_f=0) + + datamodule = create_test_graph_input(output_type='datamodule', n_samples=100, n_states=3, n_atoms=3) + trainer = lightning.Trainer(max_epochs=5, logger=False, enable_checkpointing=False, limit_val_batches=0, num_sanity_val_steps=0, enable_model_summary=False) + trainer.fit(model, datamodule) + + example_input_graph_test = create_test_graph_input(output_type='example', n_atoms=4, n_samples=3, n_states=2) + + model(example_input_graph_test).sum().backward() @@ -352,7 +450,7 @@ def test_committor_with_derivatives(): datamodule = DictModule(dataset, lengths=[1.0]) # seed for reproducibility - model = Committor(layers=[45, 20, 1], + model = Committor(model=[45, 20, 1], atomic_masses=masses, alpha=1, separate_boundary_dataset=separate_boundary_dataset) @@ -408,7 +506,7 @@ def test_committor_with_derivatives(): torch.manual_seed(42) datamodule = DictModule(dataset_desc, lengths=[1.0]) - model = Committor(layers=[45, 20, 1], + model = Committor(model=[45, 20, 1], atomic_masses=masses, alpha=1, separate_boundary_dataset=separate_boundary_dataset, @@ -438,7 +536,7 @@ def test_committor_with_derivatives(): # test errors try: # separate boundary with explicit derivatives - model = Committor(layers=[45, 20, 1], + model = Committor(model=[45, 20, 1], atomic_masses=masses, alpha=1, separate_boundary_dataset=True, @@ -470,7 +568,7 @@ def test_committor_with_derivatives(): torch.manual_seed(42) datamodule = DictModule(smart_dataset, lengths=[1.0]) - model = Committor(layers=[45, 20, 1], + model = 
Committor(model=[45, 20, 1], atomic_masses=masses, alpha=1, separate_boundary_dataset=separate_boundary_dataset, @@ -516,6 +614,29 @@ def test_committor_with_derivatives(): except ValueError as e: print("[TEST LOG] Checked this error: ", e) + + print() + print('EXTERNAL GNN') + print() + from mlcolvar.core.nn.graph import SchNetModel + from mlcolvar.data.graph.utils import create_test_graph_input + gnn_model = SchNetModel(1, 0.1, [1, 8]) + + model = Committor(model=gnn_model, + atomic_masses=masses, + alpha=1e-1, + delta_f=0) + + datamodule = create_test_graph_input(output_type='datamodule', n_samples=100, n_states=3, n_atoms=3) + trainer = lightning.Trainer(max_epochs=5, logger=False, enable_checkpointing=False, limit_val_batches=0, num_sanity_val_steps=0, enable_model_summary=False) + trainer.fit(model, datamodule) + + example_input_graph_test = create_test_graph_input(output_type='example', n_atoms=4, n_samples=3, n_states=2) + + model(example_input_graph_test).sum().backward() + + if __name__ == "__main__": - test_committor() + test_committor_1() + test_committor_2() test_committor_with_derivatives() \ No newline at end of file diff --git a/mlcolvar/cvs/committor/utils.py b/mlcolvar/cvs/committor/utils.py index 040f3a84..4aba561a 100644 --- a/mlcolvar/cvs/committor/utils.py +++ b/mlcolvar/cvs/committor/utils.py @@ -1,6 +1,8 @@ import torch import numpy as np from typing import List +from mlcolvar.core import FeedForward, BaseGNN +from mlcolvar.utils import _code from mlcolvar.data import DictDataset __all__ = ["KolmogorovBias", "compute_committor_weights", "initialize_committor_masses"] @@ -35,12 +37,34 @@ def __init__(self, self.epsilon = epsilon def forward(self, x): - x.requires_grad = True + if isinstance(self.input_model.nn, FeedForward): + x.requires_grad = True + + elif isinstance(self.input_model.nn, BaseGNN): + x['positions'].requires_grad_(True) + x['node_attrs'].requires_grad_(True) + q = self.input_model(x) grad_outputs = torch.ones_like(q) - grads 
= torch.autograd.grad(q, x, grad_outputs, retain_graph=True)[0] + + if isinstance(self.input_model.nn, BaseGNN): + grads = torch.autograd.grad(q, x['positions'], grad_outputs, retain_graph=True)[0] + + elif isinstance(self.input_model.nn, FeedForward): + grads = torch.autograd.grad(q, x, grad_outputs, retain_graph=True)[0] + grads_squared = torch.sum(torch.pow(grads, 2), 1) - bias = - self.lambd*(1/self.beta)*(torch.log( grads_squared + self.epsilon ) - torch.log(self.epsilon)) + + # gnn models need an additional scatter + if isinstance(self.input_model.nn, BaseGNN): + grads_squared = _code.scatter_sum(grads_squared, + x['batch'], + dim=0) + + print(grads_squared.shape) + + bias = - self.lambd*(1/self.beta)*(torch.log( grads_squared + self.epsilon ) - torch.log(self.epsilon)) + return bias def compute_committor_weights(dataset : DictDataset, @@ -66,23 +90,29 @@ def compute_committor_weights(dataset : DictDataset, ------- Updated dataset with weights and updated labels """ + if len(dataset) != len(bias): + raise ValueError('Dataset and bias have different lengths!') if bias.isnan().any(): raise(ValueError('Found Nan(s) in bias tensor. Check before proceeding! If no bias was applied replace Nan with zero!')) - n_labels = len(torch.unique(dataset['labels'])) + if dataset.metadata['data_type'] == 'descriptors': + original_labels = dataset['labels'] + else: + original_labels = torch.Tensor([dataset['data_list'][i]['graph_labels'] for i in range(len(dataset))]) + + n_labels = len(torch.unique(original_labels)) if n_labels != len(data_groups): raise(ValueError(f'The number of labels ({n_labels}) and data groups ({len(data_groups)}) do not match! 
Ensure you are correctly mapping the data in your training set!')) - # TODO sign if not from committor bias weights = torch.exp(beta * bias) - new_labels = torch.zeros_like(dataset['labels']) + new_labels = torch.zeros_like(original_labels) data_groups = torch.Tensor(data_groups) # correct data labels according to iteration for j,index in enumerate(data_groups): - new_labels[torch.nonzero(dataset['labels'] == j, as_tuple=True)] = index + new_labels[torch.nonzero(original_labels == j, as_tuple=True)] = index for i in np.unique(data_groups): # compute average of exp(beta*V) on this simualtions @@ -90,10 +120,15 @@ def compute_committor_weights(dataset : DictDataset, # update the weights weights[torch.nonzero(new_labels == i, as_tuple=True)] = coeff * weights[torch.nonzero(new_labels == i, as_tuple=True)] - + # update dataset - dataset['weights'] = weights - dataset['labels'] = new_labels + if dataset.metadata['data_type'] == 'descriptors': + dataset['weights'] = weights + dataset['labels'] = new_labels + else: + for i in range(len(dataset)): + dataset['data_list'][i]['weight'] = weights[i] + dataset['data_list'][i]['graph_labels'] = new_labels[i] return dataset @@ -123,4 +158,72 @@ def initialize_committor_masses(atom_types: list, masses: list): # make it a tensor atomic_masses = torch.Tensor(atomic_masses) - return atomic_masses \ No newline at end of file + return atomic_masses + +def test_Kolmogorov_bias(): + # test on feed forward + from mlcolvar import DeepTDA + model = DeepTDA(n_states=2, + n_cvs=1, + target_centers=[-1,1], + target_sigmas=[0.1, 0.1], + model=[4,2,1]) + inp = torch.randn((10, 4)) + model_bias = KolmogorovBias(input_model=model, beta=1.0) + model_bias(inp) + + # test on GNN + from mlcolvar.core.nn.graph import SchNetModel + from mlcolvar.data.graph.utils import create_test_graph_input + + dataset = create_test_graph_input('dataset') + inp = dataset.get_graph_inputs() + + gnn_model = SchNetModel(n_out=1, + cutoff=0.1, + atomic_numbers=[1,8]) + + 
model = DeepTDA(n_states=2, + n_cvs=1, + target_centers=[-1,1], + target_sigmas=[0.1, 0.1], + model=gnn_model) + + model_bias = KolmogorovBias(input_model=model, beta=1.0) + model_bias(inp) + + +def test_compute_committor_weights(): + # descriptors + # create dataset + samples = 50 + X = torch.randn((3*samples, 6)) + + # create labels, bias and weights + y = torch.zeros(X.shape[0]) + y[samples:] += 1 + y[int(2*samples):] += 1 + bias = torch.zeros(X.shape[0]) + w = torch.zeros(X.shape[0]) + + # create and edit dataset + dataset = DictDataset({"data": X, "labels": y, "weights": w}) + dataset = compute_committor_weights(dataset=dataset, bias=bias, data_groups=[0,1,2], beta=1.0) + print(dataset) + assert (torch.allclose(dataset['weights'], torch.ones(X.shape[0]))) + + + # graphs + # create dataset + from mlcolvar.data.graph.utils import create_test_graph_input + dataset = create_test_graph_input('dataset', n_states=4, random_weights=True) + bias = torch.zeros(len(dataset)) + dataset = compute_committor_weights(dataset=dataset, bias=bias, data_groups=[0,1,2,3], beta=1) + aux = [] + for i in range(len(dataset)): + aux.append(dataset['data_list'][i]['weight']) + assert (torch.allclose(torch.ones(len(dataset)), torch.Tensor(aux))) + +if __name__ == '__main__': + test_Kolmogorov_bias() + test_compute_committor_weights() \ No newline at end of file diff --git a/mlcolvar/cvs/cv.py b/mlcolvar/cvs/cv.py index 7bfd95c5..0cff165d 100644 --- a/mlcolvar/cvs/cv.py +++ b/mlcolvar/cvs/cv.py @@ -1,5 +1,8 @@ import torch from mlcolvar.core.transform import Transform +from typing import Union, List +from mlcolvar.core.nn import FeedForward, BaseGNN +from mlcolvar.data.graph.utils import create_graph_tracing_example class BaseCV: @@ -9,10 +12,12 @@ class BaseCV: To inherit from this class, the class must define a BLOCKS class attribute. 
""" + DEFAULT_BLOCKS = [] + MODEL_BLOCKS = [] + def __init__( self, - in_features, - out_features, + model: Union[List[int], FeedForward, BaseGNN], preprocessing: torch.nn.Module = None, postprocessing: torch.nn.Module = None, *args, @@ -22,10 +27,6 @@ def __init__( Parameters ---------- - in_features : int - Number of inputs of the CV model - out_features : int - Number of outputs of the CV model, should be the number of CVs preprocessing : torch.nn.Module, optional Preprocessing module, default None postprocessing : torch.nn.Module, optional @@ -35,13 +36,13 @@ def __init__( super().__init__(*args, **kwargs) # The parent class sets in_features and out_features based on their own - # init arguments so we don't need to save them here (see #103). + # init arguments so we don't need to save them here (see #103). + # It is needed for compatibility with multiclass CVs self.save_hyperparameters(ignore=['in_features', 'out_features']) # MODEL + self.parse_model(model=model) self.initialize_blocks() - self.in_features = in_features - self.out_features = out_features # OPTIM self._optimizer_name = "Adam" @@ -59,12 +60,39 @@ def n_cvs(self): @property def example_input_array(self): - return torch.randn( - (1,self.in_features) - if self.preprocessing is None - or not hasattr(self.preprocessing, "in_features") - else self.preprocessing.in_features - ) + if self.in_features is not None: + return torch.randn( + (1,self.in_features) + if self.preprocessing is None + or not hasattr(self.preprocessing, "in_features") + else self.preprocessing.in_features + ) + else: + return create_graph_tracing_example(n_species=len(self.atomic_numbers)) + + + # TODO add general torch.nn.Module + def parse_model(self, model: Union[List[int], FeedForward, BaseGNN]): + if isinstance(model, list): + self.layers = model + self.BLOCKS = self.DEFAULT_BLOCKS + self._override_model = False + self.in_features = self.layers[0] + self.out_features = self.layers[-1] + elif isinstance(model, FeedForward) or 
isinstance(model, BaseGNN): + self.BLOCKS = self.MODEL_BLOCKS + self._override_model = True + self.in_features = model.in_features + self.out_features = model.out_features + # save buffers for the interface for PLUMED + if isinstance(model, BaseGNN): + self.register_buffer('n_out', model.n_out) + self.register_buffer('cutoff', model.cutoff) + self.register_buffer('atomic_numbers', model.atomic_numbers) + else: + raise ValueError( + f"Keyword model can either accept type list, FeedForward or BaseGNN. Found {type(model)}" + ) def parse_options(self, options: dict = None): """ @@ -78,7 +106,13 @@ def parse_options(self, options: dict = None): """ if options is None: options = {} - + else: + for o in options.keys(): + if o in self.DEFAULT_BLOCKS and self._override_model: + raise ValueError( + "Options on blocks are disabled if a model is provided!" + ) + for b in self.BLOCKS: options.setdefault(b, {}) @@ -225,3 +259,9 @@ def __setattr__(self, key, value): if (key == "loss_fn") and ("cannot assign" in str(e)): del self.loss_fn super().__setattr__(key, value) + + def _setup_graph_data(self, train_batch, key : str='data_list'): + data = train_batch[key] + data['positions'].requires_grad_(True) + data['node_attrs'].requires_grad_(True) + return data \ No newline at end of file diff --git a/mlcolvar/cvs/generator/generator.py b/mlcolvar/cvs/generator/generator.py index 6ce151ea..0d19b94c 100644 --- a/mlcolvar/cvs/generator/generator.py +++ b/mlcolvar/cvs/generator/generator.py @@ -35,7 +35,7 @@ class Generator(BaseCV, lightning.LightningModule): """ - BLOCKS = ["nn"] + DEFAULT_BLOCKS = ["nn"] def __init__(self, r: int, @@ -79,7 +79,7 @@ def __init__(self, Options for the building blocks of the model, by default {}. Available blocks: ['nn'] . 
""" - super().__init__(in_features=layers[0], out_features=r, **kwargs) + super().__init__(model=layers, **kwargs) # ======= LOSS ======= self.loss_fn = GeneratorLoss(r=r, diff --git a/mlcolvar/cvs/multitask/multitask.py b/mlcolvar/cvs/multitask/multitask.py index 90712122..1be64eaa 100644 --- a/mlcolvar/cvs/multitask/multitask.py +++ b/mlcolvar/cvs/multitask/multitask.py @@ -19,10 +19,11 @@ import torch from mlcolvar.cvs.cv import BaseCV +from mlcolvar.core.nn import BaseGNN # ============================================================================= -# VARIATIONAL AUTOENCODER CV +# MULTITASK CV # ============================================================================= @@ -98,6 +99,10 @@ def __init__( has always coefficient 1). """ + # check if model is GNN, not implemented yet TODO + if hasattr(main_cv, "nn") and isinstance(main_cv.nn, BaseGNN): + raise NotImplementedError('Multitask not supported (yet) for GNN-based CVs') + # This changes dynamically the class of this object to inherit both from # MultiTaskCV and main_cv.__class__ so that we can access all members of # main_cv and still be able to override some of them. diff --git a/mlcolvar/cvs/supervised/deeplda.py b/mlcolvar/cvs/supervised/deeplda.py index 55d949ff..cf754e29 100644 --- a/mlcolvar/cvs/supervised/deeplda.py +++ b/mlcolvar/cvs/supervised/deeplda.py @@ -1,10 +1,12 @@ import torch import lightning from mlcolvar.cvs import BaseCV -from mlcolvar.core import FeedForward, Normalization +from mlcolvar.core import FeedForward, BaseGNN, Normalization from mlcolvar.data import DictModule from mlcolvar.core.stats import LDA from mlcolvar.core.loss import ReduceEigenvaluesLoss +from typing import Union, List + __all__ = ["DeepLDA"] @@ -14,7 +16,9 @@ class DeepLDA(BaseCV, lightning.LightningModule): Non-linear generalization of LDA in which a feature map is learned by a neural network optimized as to maximize the classes separation. The method is described in [1]_. 
- **Data**: for training it requires a DictDataset with the keys 'data' and 'labels'. + **Data**: for training it requires a DictDataset containing: + - If using descriptors as input, the keys 'data' and 'labels' + - If using graphs as input, `torch_geometric.data` with 'graph_labels' in their 'data_list'. **Loss**: maximize LDA eigenvalues (ReduceEigenvaluesLoss) @@ -31,9 +35,14 @@ class DeepLDA(BaseCV, lightning.LightningModule): Eigenvalue reduction to a scalar quantity """ - BLOCKS = ["norm_in", "nn", "lda"] + DEFAULT_BLOCKS = ["norm_in", "nn", "lda"] + MODEL_BLOCKS = ["nn", "lda"] - def __init__(self, layers: list, n_states: int, options: dict = None, **kwargs): + def __init__(self, + model: Union[List[int], FeedForward, BaseGNN], + n_states: int, + options: dict = None, + **kwargs): """ Define a Deep Linear Discriminant Analysis (Deep-LDA) CV composed by a neural network module and a LDA object. @@ -41,8 +50,13 @@ def __init__(self, layers: list, n_states: int, options: dict = None, **kwargs): Parameters ---------- - layers : list - Number of neurons per layer + model : list or FeedForward or BaseGNN + Determines the underlying machine-learning model. One can pass: + 1. A list of integers corresponding to the number of neurons per layer of a feed-forward NN. + The model will be automatically initialized using a `mlcolvar.core.nn.feedforward.FeedForward` object. + The CV class will be initialized according to the DEFAULT_BLOCKS. + 2. An externally initialized model (either `mlcolvar.core.nn.feedforward.FeedForward` or `mlcolvar.core.nn.graph.BaseGNN` object). + The CV class will be initialized according to the MODEL_BLOCKS. n_states : int Number of states for the training options : dict[str, Any], optional @@ -50,7 +64,8 @@ def __init__(self, layers: list, n_states: int, options: dict = None, **kwargs): Available blocks: ['norm_in','nn','lda'] . 
Set 'block_name' = None or False to turn off that block """ - super().__init__(in_features=layers[0], out_features=layers[-1], **kwargs) + super().__init__(model=model, **kwargs) + self.save_hyperparameters(ignore=['model']) # ======= LOSS ======= # Maximize the sum of all the LDA eigenvalues. @@ -65,26 +80,31 @@ def __init__(self, layers: list, n_states: int, options: dict = None, **kwargs): # ======= BLOCKS ======= - # initialize norm_in - o = "norm_in" - if (options[o] is not False) and (options[o] is not None): - self.norm_in = Normalization(self.in_features, **options[o]) + if not self._override_model: + # initialize norm_in + o = "norm_in" + if (options[o] is not False) and (options[o] is not None): + self.norm_in = Normalization(self.in_features, **options[o]) + + # initialize nn + o = "nn" + self.nn = FeedForward(self.layers, **options[o]) - # initialize nn - o = "nn" - self.nn = FeedForward(layers, **options[o]) + elif self._override_model: + self.nn = model # initialize lda o = "lda" - self.lda = LDA(layers[-1], n_states, **options[o]) + self.lda = LDA(self.nn.out_features, n_states, **options[o]) # regularization self.lorentzian_reg = 40 # == 2/sw_reg, see set_regularization self.set_regularization(sw_reg=0.05) def forward_nn(self, x: torch.Tensor) -> torch.Tensor: - if self.norm_in is not None: - x = self.norm_in(x) + if not self._override_model: + if self.norm_in is not None: + x = self.norm_in(x) x = self.nn(x) return x @@ -137,13 +157,19 @@ def regularization_lorentzian(self, x): def training_step(self, train_batch, batch_idx): """Compute and return the training loss and record metrics.""" # =================get data=================== - x = train_batch["data"] - y = train_batch["labels"] + if isinstance(self.nn, FeedForward): + x = train_batch["data"] + labels = train_batch["labels"] + elif isinstance(self.nn, BaseGNN): + x = self._setup_graph_data(train_batch) + labels = x['graph_labels'].squeeze() + # =================forward==================== h 
= self.forward_nn(x) + # ===================lda====================== eigvals, _ = self.lda.compute( - h, y, save_params=True if self.training else False + h, labels, save_params=True if self.training else False ) # ===================loss===================== loss = self.loss_fn(eigvals) @@ -151,6 +177,7 @@ def training_step(self, train_batch, batch_idx): s = self.lda(h) lorentzian_reg = self.regularization_lorentzian(s) loss += lorentzian_reg + # ====================log===================== name = "train" if self.training else "valid" loss_dict = {f"{name}_loss": loss, f"{name}_lorentzian_reg": lorentzian_reg} @@ -164,7 +191,7 @@ def test_deeplda(n_states=2): in_features, out_features = 2, n_states - 1 layers = [in_features, 50, 50, out_features] - + # create dataset n_points = 500 X, y = [], [] @@ -187,6 +214,9 @@ def test_deeplda(n_states=2): "nn": {"activation": "relu"}, "lda": {}, } + print() + print('NORMAL') + print() model = DeepLDA(layers, n_states, options=opts) # create trainer and fit @@ -200,6 +230,55 @@ def test_deeplda(n_states=2): with torch.no_grad(): s = model(X).numpy() + + # feedforward external + print() + print('EXTERNAL') + print() + ff_model = FeedForward(layers=layers) + model = DeepLDA(ff_model, n_states) + + # create trainer and fit + trainer = lightning.Trainer( + max_epochs=1, log_every_n_steps=2, logger=None, enable_checkpointing=False + ) + trainer.fit(model, datamodule) + + # eval + model.eval() + with torch.no_grad(): + s = model(X).numpy() + print(s) + + # gnn external + print() + print('GNN') + print() + from mlcolvar.core.nn.graph.schnet import SchNetModel + from mlcolvar.data.graph.utils import create_test_graph_input + gnn_model = SchNetModel(2, 0.1, [1, 8]) + model = DeepLDA(gnn_model, n_states) + + datamodule = create_test_graph_input(output_type='datamodule', n_samples=200, n_states=n_states) + + # create trainer and fit + trainer = lightning.Trainer( + max_epochs=1, log_every_n_steps=2, logger=False, 
enable_checkpointing=False, enable_model_summary=False + ) + trainer.fit(model, datamodule) + + traced_model = model.to_torchscript( + file_path=None, method="trace") + + example_input_graph_test = create_test_graph_input(output_type='example', n_atoms=4, n_samples=3, n_states=n_states) + assert torch.allclose(model(example_input_graph_test), traced_model(example_input_graph_test)) + + # eval + model.eval() + with torch.no_grad(): + s = model(example_input_graph_test).numpy() + print(s) + if __name__ == "__main__": test_deeplda(n_states=2) diff --git a/mlcolvar/cvs/supervised/deeptda.py b/mlcolvar/cvs/supervised/deeptda.py index e58f215c..1dfeb321 100644 --- a/mlcolvar/cvs/supervised/deeptda.py +++ b/mlcolvar/cvs/supervised/deeptda.py @@ -1,35 +1,36 @@ import torch import lightning from mlcolvar.cvs import BaseCV -from mlcolvar.core import FeedForward, Normalization +from mlcolvar.core import FeedForward, BaseGNN, Normalization from mlcolvar.core.loss import TDALoss from mlcolvar.data import DictModule +from typing import Union, List __all__ = ["DeepTDA"] - class DeepTDA(BaseCV, lightning.LightningModule): """ Deep Targeted Discriminant Analysis (Deep-TDA) CV. Combine the inputs with a neural-network and optimize it in a way such that the data are distributed accordingly to a mixture of Gaussians. The method is described in [1]_. - - **Data**: for training it requires a DictDataset with the keys 'data' and 'labels'. - + **Data**: for training it requires a DictDataset containing: + - If using descriptors as input, the keys 'data' and 'labels'. + - If using graphs as input, `torch_geometric.data` with 'graph_labels' in their 'data_list'. + **Loss**: distance of the samples of each class from a set of Gaussians (TDALoss) - References ---------- .. [1] E. Trizio and M. Parrinello, "From enhanced sampling to reaction profiles", The Journal of Physical Chemistry Letters 12, 8621– 8626 (2021). 
- See also -------- mlcolvar.core.loss.TDALoss Distance from a simple Gaussian target distribution. """ - BLOCKS = ["norm_in", "nn"] + DEFAULT_BLOCKS = ["norm_in", "nn"] + MODEL_BLOCKS = ["nn"] + # TODO n_states optional? def __init__( @@ -38,14 +39,13 @@ def __init__( n_cvs: int, target_centers: list, target_sigmas: list, - layers: list, + model: Union[List[int], FeedForward, BaseGNN], options: dict = None, **kwargs, ): """ Define Deep Targeted Discriminant Analysis (Deep-TDA) CV composed by a neural network module. By default a module standardizing the inputs is also used. - Parameters ---------- n_states : int @@ -56,15 +56,20 @@ def __init__( Centers of the Gaussian targets target_sigmas : list Standard deviations of the Gaussian targets - layers : list - Number of neurons per layer + model : list or FeedForward or BaseGNN + Determines the underlying machine-learning model. One can pass: + 1. A list of integers corresponding to the number of neurons per layer of a feed-forward NN. + The model will be automatically initialized using a `mlcolvar.core.nn.feedforward.FeedForward` object. + The CV class will be initialized according to the DEFAULT_BLOCKS. + 2. An externally initialized model (either `mlcolvar.core.nn.feedforward.FeedForward` or `mlcolvar.core.nn.graph.BaseGNN` object). + The CV class will be initialized according to the MODEL_BLOCKS. options : dict[str, Any], optional Options for the building blocks of the model, by default {}. Available blocks: ['norm_in', 'nn']. 
Set 'block_name' = None or False to turn off that block - """ - - super().__init__(in_features=layers[0], out_features=layers[-1], **kwargs) + """ + super().__init__(model, **kwargs) + self.save_hyperparameters(ignore=['model']) # ======= LOSS ======= self.loss_fn = TDALoss( @@ -106,43 +111,54 @@ def __init__( ) # ======= BLOCKS ======= - - # Initialize norm_in - o = "norm_in" - if (options[o] is not False) and (options[o] is not None): - self.norm_in = Normalization(self.in_features, **options[o]) - - # initialize NN - o = "nn" - self.nn = FeedForward(layers, **options[o]) - - def training_step(self, train_batch, batch_idx): + if not self._override_model: + # Initialize norm_in + o = "norm_in" + if (options[o] is not False) and (options[o] is not None): + self.norm_in = Normalization(self.in_features, **options[o]) + + # initialize NN + o = "nn" + self.nn = FeedForward(self.layers, **options[o]) + elif self._override_model: + self.nn = model + + def training_step(self, train_batch, *args, **kwargs) -> torch.Tensor: """Compute and return the training loss and record metrics.""" # =================get data=================== - x = train_batch["data"] - labels = train_batch["labels"] + if isinstance(self.nn, FeedForward): + x = train_batch["data"] + labels = train_batch["labels"] + elif isinstance(self.nn, BaseGNN): + x = self._setup_graph_data(train_batch) + labels = x['graph_labels'].squeeze() + # =================forward==================== z = self.forward_cv(x) + # ===================loss===================== - loss, loss_centers, loss_sigmas = self.loss_fn( - z, labels, return_loss_terms=True - ) - # ====================log=====================+ + loss, loss_centers, loss_sigmas = self.loss_fn(z, + labels, + return_loss_terms=True + ) + + # ====================log===================== name = "train" if self.training else "valid" self.log(f"{name}_loss", loss, on_epoch=True) self.log(f"{name}_loss_centers", loss_centers, on_epoch=True) 
self.log(f"{name}_loss_sigmas", loss_sigmas, on_epoch=True) + return loss -# TODO signature of tests? import numpy as np - def test_deeptda_cv(): from mlcolvar.data import DictDataset + # feedforward with layers for states_and_cvs in [[2, 1], [3, 1], [3, 2], [5, 4]]: + print(states_and_cvs) # get the number of states and cvs for the test run n_states = states_and_cvs[0] n_cvs = states_and_cvs[1] @@ -155,18 +171,18 @@ def test_deeptda_cv(): # test initialize via dictionary options = {"nn": {"activation": "relu"}} + print() + print('NORMAL') + print() model = DeepTDA( n_states=n_states, n_cvs=n_cvs, target_centers=target_centers, target_sigmas=target_sigmas, - layers=layers, + model=layers, options=options, ) - print("----------") - print(model) - # create dataset samples = 100 X = torch.randn((samples * n_states, 2)) @@ -180,17 +196,71 @@ def test_deeptda_cv(): datamodule = DictModule(dataset, lengths=[0.75, 0.2, 0.05], batch_size=samples) # train model trainer = lightning.Trainer( - accelerator="cpu", max_epochs=2, logger=None, enable_checkpointing=False + accelerator="cpu", max_epochs=2, logger=None, enable_checkpointing=False, enable_model_summary=False ) trainer.fit(model, datamodule) # trace model traced_model = model.to_torchscript( - file_path=None, method="trace", example_inputs=X[0] + file_path=None, method="trace") + model.eval() + assert torch.allclose(model(X), traced_model(X)) + + print() + print('EXTERNAL FEEDFORWARD') + print() + # feedforward external + ff_model = FeedForward(layers=layers) + model = DeepTDA( + n_states=n_states, + n_cvs=n_cvs, + target_centers=target_centers, + target_sigmas=target_sigmas, + model=ff_model + ) + + # train model + trainer = lightning.Trainer( + accelerator="cpu", max_epochs=2, logger=None, enable_checkpointing=False, enable_model_summary=False ) + trainer.fit(model, datamodule) + + # trace model + traced_model = model.to_torchscript( + file_path=None, method="trace") model.eval() assert torch.allclose(model(X), 
traced_model(X)) + print() + print('EXTERNAL GNN') + print() + # gnn external + from mlcolvar.core.nn.graph.schnet import SchNetModel + from mlcolvar.data.graph.utils import create_test_graph_input + gnn_model = SchNetModel(n_cvs, 0.1, [1, 8]) + model = DeepTDA( + n_states=n_states, + n_cvs=n_cvs, + target_centers=target_centers, + target_sigmas=target_sigmas, + model=gnn_model + ) + datamodule = create_test_graph_input(output_type='datamodule', n_samples=100, n_states=n_states) + + # train model + trainer = lightning.Trainer( + accelerator="cpu", max_epochs=2, logger=False, enable_checkpointing=False, enable_model_summary=False + ) + trainer.fit(model, datamodule) + + # trace model + traced_model = model.to_torchscript( + file_path=None, method="trace") + + # check on a different number of atoms + example_input_graph_test = create_test_graph_input(output_type='example', n_atoms=4, n_samples=3, n_states=2) + assert torch.allclose(model(example_input_graph_test), traced_model(example_input_graph_test)) if __name__ == "__main__": test_deeptda_cv() + diff --git a/mlcolvar/cvs/supervised/regression.py b/mlcolvar/cvs/supervised/regression.py index 1ea52331..558db234 100644 --- a/mlcolvar/cvs/supervised/regression.py +++ b/mlcolvar/cvs/supervised/regression.py @@ -1,8 +1,10 @@ import torch import lightning from mlcolvar.cvs import BaseCV -from mlcolvar.core import FeedForward, Normalization +from mlcolvar.core import FeedForward, Normalization, BaseGNN from mlcolvar.core.loss import MSELoss +from typing import Union, List + __all__ = ["RegressionCV"] @@ -12,8 +14,10 @@ class RegressionCV(BaseCV, lightning.LightningModule): Example of collective variable obtained with a regression task. Combine the inputs with a neural-network and optimize it to match a target function. - **Data**: for training it requires a DictDataset with the keys 'data' and 'target' and optionally 'weights'. 
- + **Data**: for training it requires a DictDataset containing: + - If using descriptors as input, the keys 'data', 'target' and optionally 'weights'. + - If using graphs as input, `torch_geometric.data` with 'graph_labels' with the target values and 'weight' in their 'data_list'. + **Loss**: least squares (MSELoss). See also @@ -22,22 +26,34 @@ class RegressionCV(BaseCV, lightning.LightningModule): (weighted) Mean Squared Error (MSE) loss function. """ - BLOCKS = ["norm_in", "nn"] + DEFAULT_BLOCKS = ["norm_in", "nn"] + MODEL_BLOCKS = ["nn"] - def __init__(self, layers: list, options: dict = None, **kwargs): + + def __init__(self, + model: Union[List[int], FeedForward, BaseGNN], + options: dict = None, + **kwargs): """Example of collective variable obtained with a regression task. By default a module standardizing the inputs is used. Parameters ---------- - layers : list - Number of neurons per layer + model : list or FeedForward or BaseGNN + Determines the underlying machine-learning model. One can pass: + 1. A list of integers corresponding to the number of neurons per layer of a feed-forward NN. + The model will be automatically initialized using a `mlcolvar.core.nn.feedforward.FeedForward` object. + The CV class will be initialized according to the DEFAULT_BLOCKS. + 2. An externally initialized model (either `mlcolvar.core.nn.feedforward.FeedForward` or `mlcolvar.core.nn.graph.BaseGNN` object). + The CV class will be initialized according to the MODEL_BLOCKS. options : dict[str, Any], optional Options for the building blocks of the model, by default None. Available blocks: ['norm_in', 'nn']. 
Set 'block_name' = None or False to turn off that block """ - super().__init__(in_features=layers[0], out_features=layers[-1], **kwargs) + super().__init__(model, **kwargs) + self.save_hyperparameters(ignore=['model']) + # ======= LOSS ======= self.loss_fn = MSELoss() @@ -46,25 +62,38 @@ def __init__(self, layers: list, options: dict = None, **kwargs): # parse and sanitize options = self.parse_options(options) - # Initialize norm_in - o = "norm_in" - if (options[o] is not False) and (options[o] is not None): - self.norm_in = Normalization(self.in_features, **options[o]) + # ======= BLOCKS ======= + if not self._override_model: + # Initialize norm_in + o = "norm_in" + if (options[o] is not False) and (options[o] is not None): + self.norm_in = Normalization(self.in_features, **options[o]) - # initialize NN - o = "nn" - self.nn = FeedForward(layers, **options[o]) + # initialize NN + o = "nn" + self.nn = FeedForward(self.layers, **options[o]) + elif self._override_model: + self.nn = model def training_step(self, train_batch, batch_idx): """Compute and return the training loss and record metrics.""" # =================get data=================== - x = train_batch["data"] - labels = train_batch["target"] loss_kwargs = {} - if "weights" in train_batch: - loss_kwargs["weights"] = train_batch["weights"] + if isinstance(self.nn, FeedForward): + x = train_batch["data"] + labels = train_batch["target"] + if "weights" in train_batch: + loss_kwargs["weights"] = train_batch["weights"] + elif isinstance(self.nn, BaseGNN): + x = self._setup_graph_data(train_batch) + # TODO maybe add an external key like target? 
+ labels = x['graph_labels'].squeeze() + if "weights" in x: + loss_kwargs["weights"] = x["weights"] + # =================forward==================== y = self.forward_cv(x) + # ===================loss===================== loss = self.loss_fn(y, labels, **loss_kwargs) # ====================log===================== @@ -82,10 +111,13 @@ def test_regression_cv(): in_features, out_features = 2, 1 layers = [in_features, 5, 10, out_features] + print() + print('NORMAL') + print() # initialize via dictionary options = {"nn": {"activation": "relu"}} - model = RegressionCV(layers=layers, options=options) + model = RegressionCV(model=layers, options=options) print("----------") print(model) @@ -123,7 +155,89 @@ def test_regression_cv(): accelerator="cpu", max_epochs=1, logger=None, enable_checkpointing=False ) - model = RegressionCV(layers=[2, 10, 10, 1]) + model = RegressionCV(model=[2, 10, 10, 1]) + model.loss_fn = lambda y, y_ref: (y - y_ref).abs().mean() + trainer.fit(model, datamodule) + + print() + print('EXTERNAL FEEDFORWARD') + print() + ff_model = FeedForward(layers=layers) + # create model + model = RegressionCV(model=ff_model) + + # create dataset + X = torch.randn((100, 2)) + y = X.square().sum(1) + dataset = DictDataset({"data": X, "target": y}) + datamodule = DictModule(dataset, lengths=[0.75, 0.2, 0.05], batch_size=25) + # train model + model.optimizer_name = "SGD" + model.optimizer_kwargs.update(dict(lr=1e-2)) + trainer = lightning.Trainer( + accelerator="cpu", max_epochs=1, logger=None, enable_checkpointing=False + ) + trainer.fit(model, datamodule) + model.eval() + # trace model + traced_model = model.to_torchscript( + file_path=None, method="trace", example_inputs=X[0] + ) + assert torch.allclose(model(X), traced_model(X)) + + # weighted loss + print("weighted loss") + w = torch.randn((100)) + dataset_weights = DictDataset({"data": X, "target": y, "weights": w}) + datamodule_weights = DictModule( + dataset_weights, lengths=[0.75, 0.2, 0.05], batch_size=25 + ) 
+ trainer.fit(model, datamodule_weights) + + # use custom loss + print("custom loss") + trainer = lightning.Trainer( + accelerator="cpu", max_epochs=1, logger=None, enable_checkpointing=False + ) + + model = RegressionCV(model=ff_model) + model.loss_fn = lambda y, y_ref: (y - y_ref).abs().mean() + trainer.fit(model, datamodule) + + print() + print('EXTERNAL GNN') + print() + # gnn external + from mlcolvar.core.nn.graph.schnet import SchNetModel + from mlcolvar.data.graph.utils import create_test_graph_input + gnn_model = SchNetModel(1, 0.1, [1, 8]) + # create model + model = RegressionCV(model=gnn_model) + + datamodule = create_test_graph_input(output_type='datamodule', n_samples=100, n_states=2) + # train model + trainer = lightning.Trainer( + accelerator="cpu", max_epochs=1, logger=False, enable_checkpointing=False, enable_model_summary=False + ) + trainer.fit(model, datamodule) + model.eval() + # trace model + traced_model = model.to_torchscript(file_path=None, method="trace") + example_input_graph_test = create_test_graph_input(output_type='example', n_atoms=4, n_samples=3, n_states=2) + assert torch.allclose(model(example_input_graph_test), traced_model(example_input_graph_test)) + + # weighted loss + print("weighted loss") + datamodule_weights = create_test_graph_input(output_type='datamodule', n_samples=100, n_states=2, random_weights=True) + trainer.fit(model, datamodule_weights) + + # use custom loss + print("custom loss") + trainer = lightning.Trainer( + accelerator="cpu", max_epochs=1, logger=False, enable_checkpointing=False, enable_model_summary=False + ) + + model = RegressionCV(model=gnn_model) model.loss_fn = lambda y, y_ref: (y - y_ref).abs().mean() trainer.fit(model, datamodule) diff --git a/mlcolvar/cvs/timelagged/deeptica.py b/mlcolvar/cvs/timelagged/deeptica.py index b95c7547..428c7b41 100644 --- a/mlcolvar/cvs/timelagged/deeptica.py +++ b/mlcolvar/cvs/timelagged/deeptica.py @@ -1,9 +1,10 @@ import torch import lightning from mlcolvar.cvs 
import BaseCV -from mlcolvar.core import FeedForward, Normalization +from mlcolvar.core import FeedForward, BaseGNN, Normalization from mlcolvar.core.stats import TICA from mlcolvar.core.loss import ReduceEigenvaluesLoss +from typing import Union, List __all__ = ["DeepTICA"] @@ -16,10 +17,12 @@ class DeepTICA(BaseCV, lightning.LightningModule): approximated by TICA. The method is described in [1]_. Note that from the point of view of the architecture DeepTICA is similar to the SRV [2] method. - **Data**: for training it requires a DictDataset with the keys 'data' (input at time t) - and 'data_lag' (input at time t+lag), as well as the corresponding 'weights' and - 'weights_lag' which will be used to weight the time correlation functions. - This can be created with the helper function `create_timelagged_dataset`. + **Data**: for training it requires a DictDataset containing: + - If using descriptors as input, the keys 'data' (input at time t) + and 'data_lag' (input at time t+lag), as well as the corresponding 'weights' and + 'weights_lag' which will be used to weight the time correlation functions. + - If using graphs as input, the keys 'data_list' and 'data_list_lag', each containing the respective 'weight' + This can be created in both cases with the helper function `create_timelagged_dataset`. **Loss**: maximize TICA eigenvalues (ReduceEigenvaluesLoss) @@ -40,17 +43,26 @@ class DeepTICA(BaseCV, lightning.LightningModule): Create dataset of time-lagged data. """ - BLOCKS = ["norm_in", "nn", "tica"] + DEFAULT_BLOCKS = ["norm_in", "nn", "tica"] + MODEL_BLOCKS = ["nn", "tica"] - def __init__(self, layers: list, n_cvs: int = None, options: dict = None, **kwargs): + def __init__(self, + model: Union[List[int], FeedForward, BaseGNN], + n_cvs: int = None, + options: dict = None, **kwargs): """ Define a Deep-TICA CV, composed of a neural network module and a TICA object. By default a module standardizing the inputs is also used. 
Parameters ---------- - layers : list - Number of neurons per layer + model : list or FeedForward or BaseGNN + Determines the underlying machine-learning model. One can pass: + 1. A list of integers corresponding to the number of neurons per layer of a feed-forward NN. + The model will be automatically initialized using a `mlcolvar.core.nn.feedforward.FeedForward` object. + The CV class will be initialized according to the DEFAULT_BLOCKS. + 2. An externally initialized model (either `mlcolvar.core.nn.feedforward.FeedForward` or `mlcolvar.core.nn.graph.BaseGNN` object). + The CV class will be initialized according to the MODEL_BLOCKS. n_cvs : int, optional Number of cvs to optimize, default None (= last layer) options : dict[str, Any], optional @@ -58,15 +70,13 @@ def __init__(self, layers: list, n_cvs: int = None, options: dict = None, **kwar Available blocks: ['norm_in','nn','tica']. Set 'block_name' = None or False to turn off that block """ - super().__init__( - in_features=layers[0], - out_features=n_cvs if n_cvs is not None else layers[-1], - **kwargs, - ) + super().__init__(model, **kwargs) # ======= LOSS ======= # Maximize the squared sum of all the TICA eigenvalues. 
self.loss_fn = ReduceEigenvaluesLoss(mode="sum2") + # here we need to override the self.out_features attribute + self.out_features = n_cvs # ======= OPTIONS ======= # parse and sanitize @@ -74,22 +84,27 @@ def __init__(self, layers: list, n_cvs: int = None, options: dict = None, **kwar # ======= BLOCKS ======= - # initialize norm_in - o = "norm_in" - if (options[o] is not False) and (options[o] is not None): - self.norm_in = Normalization(self.in_features, **options[o]) + if not self._override_model: + # initialize norm_in + o = "norm_in" + if (options[o] is not False) and (options[o] is not None): + self.norm_in = Normalization(self.in_features, **options[o]) - # initialize nn - o = "nn" - self.nn = FeedForward(layers, **options[o]) + # initialize nn + o = "nn" + self.nn = FeedForward(self.layers, **options[o]) + + elif self._override_model: + self.nn = model - # initialize lda + # initialize tica o = "tica" - self.tica = TICA(layers[-1], n_cvs, **options[o]) + self.tica = TICA(self.nn.out_features, n_cvs, **options[o]) def forward_nn(self, x: torch.Tensor) -> torch.Tensor: - if self.norm_in is not None: - x = self.norm_in(x) + if not self._override_model: + if self.norm_in is not None: + x = self.norm_in(x) x = self.nn(x) return x @@ -111,10 +126,17 @@ def training_step(self, train_batch, batch_idx): 3) Compute TICA """ # =================get data=================== - x_t = train_batch["data"] - x_lag = train_batch["data_lag"] - w_t = train_batch["weights"] - w_lag = train_batch["weights_lag"] + if isinstance(self.nn, FeedForward): + x_t = train_batch["data"] + x_lag = train_batch["data_lag"] + w_t = train_batch["weights"] + w_lag = train_batch["weights_lag"] + elif isinstance(self.nn, BaseGNN): + x_t = self._setup_graph_data(train_batch, key='data_list') + x_lag = self._setup_graph_data(train_batch, key='data_list_lag') + w_t = x_t['weight'] + w_lag = x_lag['weight'] + # =================forward==================== f_t = self.forward_nn(x_t) f_lag = 
self.forward_nn(x_lag) @@ -139,12 +161,15 @@ def test_deep_tica(): from mlcolvar.utils.timelagged import create_timelagged_dataset # create dataset - X = np.loadtxt("mlcolvar/tests/data/mb-mcmc.dat") - X = torch.Tensor(X) + # X = np.loadtxt("mlcolvar/tests/data/mb-mcmc.dat") + X = torch.randn((10000, 2)) dataset = create_timelagged_dataset(X, lag_time=1) datamodule = DictModule(dataset, batch_size=10000) # create cv + print() + print('NORMAL') + print() layers = [2, 10, 10, 2] model = DeepTICA(layers, n_cvs=1) @@ -163,5 +188,55 @@ def test_deep_tica(): print(X.shape, "-->", s.shape) + print() + print('EXTERNAL') + print() + ff_model = FeedForward(layers=layers) + model = DeepTICA(ff_model, n_cvs=1) + + # change loss options + model.loss_fn.mode = "sum2" + + # create trainer and fit + trainer = lightning.Trainer( + max_epochs=1, log_every_n_steps=2, logger=None, enable_checkpointing=False + ) + trainer.fit(model, datamodule) + + model.eval() + with torch.no_grad(): + s = model(X).numpy() + print(X.shape, "-->", s.shape) + + + # gnn external + print() + print('GNN') + print() + from mlcolvar.core.nn.graph.schnet import SchNetModel + from mlcolvar.data.graph.utils import create_test_graph_input + gnn_model = SchNetModel(2, 0.1, [1, 8]) + model = DeepTICA(gnn_model, n_cvs=1) + + # change loss options + model.loss_fn.mode = "sum2" + + # create trainer and fit + trainer = lightning.Trainer( + max_epochs=1, log_every_n_steps=2, logger=False, enable_checkpointing=False, enable_model_summary=False, + ) + + dataset = create_test_graph_input(output_type='dataset', n_samples=200, n_states=2) + lagged_dataset = create_timelagged_dataset(dataset, logweights=torch.randn(len(dataset))) + + datamodule = DictModule(dataset=lagged_dataset) + trainer.fit(model, datamodule) + + model.eval() + with torch.no_grad(): + example_input_graph_test = create_test_graph_input(output_type='example', n_atoms=4, n_samples=3, n_states=2) + s = model(example_input_graph_test).numpy() + print(X.shape, 
"-->", s.shape) + if __name__ == "__main__": test_deep_tica() diff --git a/mlcolvar/cvs/unsupervised/autoencoder.py b/mlcolvar/cvs/unsupervised/autoencoder.py index bb9839bb..ac8a7741 100644 --- a/mlcolvar/cvs/unsupervised/autoencoder.py +++ b/mlcolvar/cvs/unsupervised/autoencoder.py @@ -41,7 +41,7 @@ class AutoEncoderCV(BaseCV, lightning.LightningModule): (weighted) Mean Squared Error (MSE) loss function. """ - BLOCKS = ["norm_in", "encoder", "decoder"] + DEFAULT_BLOCKS = ["norm_in", "encoder", "decoder"] def __init__( self, @@ -67,9 +67,16 @@ def __init__( Available blocks: ['norm_in', 'encoder','decoder']. Set 'block_name' = None or False to turn off that block """ - super().__init__( - in_features=encoder_layers[0], out_features=encoder_layers[-1], **kwargs - ) + # external model not supported for autoencoder CVs (yet) + if not isinstance(encoder_layers, list): + raise NotImplementedError( + f"Encoder layer must be a list. found {type(encoder_layers)}" + ) + super().__init__(model=encoder_layers, **kwargs) + # this makes checkpointing safe, to avoid double model keys + self.save_hyperparameters(ignore=['model']) + self.hparams.pop('model') + # ======= LOSS ======= # Reconstruction (MSE) loss diff --git a/mlcolvar/cvs/unsupervised/vae.py b/mlcolvar/cvs/unsupervised/vae.py index 7ff297b4..a530b469 100644 --- a/mlcolvar/cvs/unsupervised/vae.py +++ b/mlcolvar/cvs/unsupervised/vae.py @@ -58,7 +58,7 @@ class VariationalAutoEncoderCV(BaseCV, lightning.LightningModule): Evidence Lower BOund loss function """ - BLOCKS = ["norm_in", "encoder", "decoder"] + DEFAULT_BLOCKS = ["norm_in", "encoder", "decoder"] def __init__( self, @@ -90,7 +90,17 @@ def __init__( Set ``'block_name' = None`` or ``False`` to turn off a block. Encoder and decoder cannot be turned off. """ - super().__init__(in_features=encoder_layers[0], out_features=n_cvs, **kwargs) + if not isinstance(encoder_layers, list): + raise NotImplementedError( + f"Encoder layer must be a list. 
found {type(encoder_layers)}" + ) + super().__init__(model=encoder_layers, **kwargs) + # this makes checkpointing safe, to avoid double model keys + self.save_hyperparameters(ignore=['model']) + self.hparams.pop('model') + + # here we need to override the self.out_features attribute + self.out_features = n_cvs # ======= LOSS ======= # ELBO loss function when latent space and reconstruction distributions are Gaussians. diff --git a/mlcolvar/data/dataloader.py b/mlcolvar/data/dataloader.py index 7ac18861..cc2c1ed4 100644 --- a/mlcolvar/data/dataloader.py +++ b/mlcolvar/data/dataloader.py @@ -223,7 +223,7 @@ def set_dataset_and_batch_size( self._dataset = old_dataset self._batch_size = old_batch_size raise ValueError( - f"batch_size (length {batch_size_len}) must have length equal to the number of datasets (length {len(self.dataset)}." + f"batch_size (length {len(self._batch_size)}) must have length equal to the number of datasets (length {len(self.dataset)})." ) # The number of batches per epoch must be the same for all datasets. diff --git a/mlcolvar/data/datamodule.py b/mlcolvar/data/datamodule.py index 4323b211..2bf26fae 100644 --- a/mlcolvar/data/datamodule.py +++ b/mlcolvar/data/datamodule.py @@ -20,14 +20,14 @@ import warnings import torch +import torch_geometric import numpy as np import lightning -from torch.utils.data import random_split, Subset +from torch.utils.data import Subset from torch import default_generator, randperm from mlcolvar.data import DictLoader, DictDataset - # ============================================================================= # DICTIONARY DATAMODULE CLASS # ============================================================================= @@ -122,7 +122,10 @@ def __init__( """ super().__init__() self.dataset = dataset - self.lengths = lengths + self.DataLoader = self._get_dataloader() + + self._lengths = lengths + # Keeping this private for now. Changing it at runtime would # require changing dataset_split and the dataloaders. 
self._random_split = random_split @@ -135,6 +138,9 @@ def __init__( ) # Make sure batch_size and shuffle are lists. + + if self._dataset_type == 'graphs' and batch_size == 0: + batch_size = len(dataset) # make this explicit for torch_geometric if isinstance(batch_size, int): self.batch_size = [batch_size for _ in lengths] else: @@ -152,6 +158,25 @@ def __init__( self.valid_loader = None self.test_loader = None + @property + def _dataset_type(self): + if not isinstance(self.dataset, list): + _dataset_type = self.dataset.metadata['data_type'] + else: + it = iter(list(self.dataset)) + _dataset_type = next(it).metadata['data_type'] + if not all(d.metadata['data_type'] == _dataset_type for d in it): + raise ValueError("not all the datasets are of the same type!") + return _dataset_type + + def _get_dataloader(self): + # decide which loader to use + if self._dataset_type == 'descriptors': + DataLoader = DictLoader + elif self._dataset_type == 'graphs': + DataLoader = torch_geometric.loader.DataLoader + return DataLoader + def setup(self, stage: Optional[str] = None): if self._dataset_split is None: if isinstance(self.dataset, DictDataset): @@ -165,7 +190,7 @@ def train_dataloader(self): """Return training dataloader.""" self._check_setup() if self.train_loader is None: - self.train_loader = DictLoader( + self.train_loader = self.DataLoader( self._dataset_split[0], batch_size=self.batch_size[0], shuffle=self.shuffle[0], @@ -175,12 +200,12 @@ def train_dataloader(self): def val_dataloader(self): """Return validation dataloader.""" self._check_setup() - if len(self.lengths) < 2: + if len(self._lengths) < 2: raise NotImplementedError( "Validation dataset not available, you need to pass two lengths to datamodule." 
) if self.valid_loader is None: - self.valid_loader = DictLoader( + self.valid_loader = self.DataLoader( self._dataset_split[1], batch_size=self.batch_size[1], shuffle=self.shuffle[1], @@ -190,12 +215,12 @@ def val_dataloader(self): def test_dataloader(self): """Return test dataloader.""" self._check_setup() - if len(self.lengths) < 3: + if len(self._lengths) < 3: raise NotImplementedError( "Test dataset not available, you need to pass three lengths to datamodule." ) if self.test_loader is None: - self.test_loader = DictLoader( + self.test_loader = self.DataLoader( self._dataset_split[2], batch_size=self.batch_size[2], shuffle=self.shuffle[2], @@ -210,11 +235,11 @@ def teardown(self, stage: str): def __repr__(self) -> str: string = f"DictModule(dataset -> {self.dataset.__repr__()}" - string += f",\n\t\t train_loader -> DictLoader(length={self.lengths[0]}, batch_size={self.batch_size[0]}, shuffle={self.shuffle[0]})" - if len(self.lengths) >= 2: - string += f",\n\t\t valid_loader -> DictLoader(length={self.lengths[1]}, batch_size={self.batch_size[1]}, shuffle={self.shuffle[1]})" - if len(self.lengths) >= 3: - string += f",\n\t\t\ttest_loader =DictLoader(length={self.lengths[2]}, batch_size={self.batch_size[2]}, shuffle={self.shuffle[2]})" + string += f",\n\t\t train_loader -> DictLoader(length={self._lengths[0]}, batch_size={self.batch_size[0]}, shuffle={self.shuffle[0]})" + if len(self._lengths) >= 2: + string += f",\n\t\t valid_loader -> DictLoader(length={self._lengths[1]}, batch_size={self.batch_size[1]}, shuffle={self.shuffle[1]})" + if len(self._lengths) >= 3: + string += f",\n\t\t\ttest_loader =DictLoader(length={self._lengths[2]}, batch_size={self.batch_size[2]}, shuffle={self.shuffle[2]})" string += f")" return string @@ -225,7 +250,7 @@ def _split(self, dataset): """ dataset_split = split_dataset( - dataset, self.lengths, self._random_split, self.generator + dataset, self._lengths, self._random_split, self.generator ) return dataset_split @@ -237,6 +262,23 
@@ def _check_setup(self): "outside a Lightning trainer please call .setup() first." ) + def get_graph_inputs(self, mode='train'): + """Generate an input that can be used as input for a GNN model + + Parameters + ---------- + mode : str, optional + Type of loader to be used, either 'train' or 'val'/'valid', by default 'train' + """ + self.setup() + if mode == 'train': + loader=self.train_dataloader + elif (mode=='val' or mode=='valid'): + loader=self.val_dataloader + else: + raise ValueError(f"Mode can either be 'train', 'val', 'valid', found {mode}!") + + return next(iter(loader()))['data_list'] def split_dataset( dataset, diff --git a/mlcolvar/data/dataset.py b/mlcolvar/data/dataset.py index d9b274a5..cccf2431 100644 --- a/mlcolvar/data/dataset.py +++ b/mlcolvar/data/dataset.py @@ -1,7 +1,9 @@ import torch +import torch_geometric import numpy as np from mlcolvar.core.transform.utils import Statistics from torch.utils.data import Dataset +from operator import itemgetter __all__ = ["DictDataset"] @@ -14,7 +16,13 @@ class DictDataset(Dataset): 'weights' : np.asarray([0.5,1.5,1.5,0.5]) } """ - def __init__(self, dictionary: dict = None, feature_names=None, create_ref_idx : bool = False, **kwargs): + def __init__(self, + dictionary: dict=None, + feature_names = None, + metadata: dict = None, + data_type : str = 'descriptors', + create_ref_idx : bool = False, + **kwargs): """Create a Dataset from a dictionary or from a list of kwargs. Parameters @@ -23,6 +31,12 @@ def __init__(self, dictionary: dict = None, feature_names=None, create_ref_idx : Dictionary with names and tensors feature_names : array-like List or numpy array with feature names + metadata : dict + Dictionary with metadata quantities shared across the whole dataset. + data_type : str + Type of data stored in the dataset, either 'descriptors' or 'graphs', by default 'descriptors'. + This will be stored in the dataset.metadata dictionary. 
+ """ # assert type dict @@ -30,7 +44,18 @@ def __init__(self, dictionary: dict = None, feature_names=None, create_ref_idx : raise TypeError( f"DictDataset requires a dictionary , not {type(dictionary)}." ) - + + if (metadata is not None) and (not isinstance(metadata, dict)): + raise TypeError( + f"DictDataset metadata requires a dictionary , not {type(metadata)}." + ) + + # assert data_type is 'descriptors' or 'graphs' + if not data_type in ['descriptors', 'graphs']: + raise TypeError( + f"data_type expected to be either 'descriptors' or 'graph', found {data_type}" + ) + # Add kwargs to dict if dictionary is None: dictionary = {} @@ -38,10 +63,23 @@ def __init__(self, dictionary: dict = None, feature_names=None, create_ref_idx : if len(dictionary) == 0: raise ValueError("Empty datasets are not supported") + # initialize metadata as dict + if metadata is None: + metadata = {} + + if 'data_type' in metadata.keys(): + if not metadata['data_type'] == data_type: + raise ValueError(f"Two different data_type specified. 
Found {metadata['data_type']} in metadata and {data_type} as keyword") + else: + metadata['data_type'] = data_type + # convert to torch.Tensors for key, val in dictionary.items(): if not isinstance(val, torch.Tensor): - dictionary[key] = torch.Tensor(val) + if key in ["data_list", "data_list_lag"]: + dictionary[key] = val + else: + dictionary[key] = torch.Tensor(val) # save dictionary self._dictionary = dictionary @@ -49,10 +87,13 @@ def __init__(self, dictionary: dict = None, feature_names=None, create_ref_idx : # save feature names self.feature_names = feature_names + # save metadata + self.metadata = metadata + # check that all elements of dict have same length it = iter(dictionary.values()) self.length = len(next(it)) - if not all(len(l) == self.length for l in it): + if not all([len(l) == self.length for l in it]): raise ValueError("not all arrays in dictionary have same length!") # add indexing of entries for shuffling and slicing reference @@ -62,12 +103,14 @@ def __init__(self, dictionary: dict = None, feature_names=None, create_ref_idx : def __getitem__(self, index): if isinstance(index, str): - # raise TypeError(f'Index ("{index}") should be a slice, and not a string. To access the stored dictionary use .dictionary["{index}"] instead.') return self._dictionary[index] - else: + else: slice_dict = {} for key, val in self._dictionary.items(): - slice_dict[key] = val[index] + try: + slice_dict[key] = val[index] + except: + slice_dict[key] = list(itemgetter(*index)(val)) return slice_dict def __setitem__(self, index, value): @@ -95,6 +138,10 @@ def get_stats(self): stats dictionary of dictionaries with statistics """ + if self.metadata['data_type'] == 'graphs': + raise ValueError( + "Method get_stats not supported for graph-based dataset!" 
+ ) stats = {} for k in self.keys: print("KEY: ", k, end="\n\n\n") @@ -105,7 +152,12 @@ def get_stats(self): def __repr__(self) -> str: string = "DictDataset(" for key, val in self._dictionary.items(): - string += f' "{key}": {list(val.shape)},' + if key in ["data_list", "data_list_lag"]: + string += f' "{key}": {len(val)},' + else: + string += f' "{key}": {list(val.shape)},' + for key, val in self.metadata.items(): + string += f' "{key}": {val},' string = string[:-1] + " )" return string @@ -124,15 +176,30 @@ def feature_names(self, value): np.asarray(value, dtype=str) if value is not None else value ) + def get_graph_inputs(self): + """Generate an input suitable for graph models. Returns the whole dataset as a single, unshuffled batch.""" + assert self.metadata['data_type'] == 'graphs', ( + 'Graph inputs can only be generated for graph-based datasets' + ) + loader = torch_geometric.loader.DataLoader(self, + batch_size=len(self), + shuffle=False ) + return next(iter(loader))['data_list'] def test_DictDataset(): + # descriptors based # from list + data = torch.Tensor([[1.0], [2.0], [0.3], [0.4]]) + labels = [0, 0, 1, 1] + weights = np.asarray([0.5, 1.5, 1.5, 0.5]) dataset_dict = { - "data": torch.Tensor([[1.0], [2.0], [0.3], [0.4]]), - "labels": [0, 0, 1, 1], - "weights": np.asarray([0.5, 1.5, 1.5, 0.5]), + "data": data, + "labels": labels, + "weights": weights, } - + + # this is needed to have the right signature in asserts + from mlcolvar.data.dataset import DictDataset dataset = DictDataset(dataset_dict) print(len(dataset)) print(dataset[0]) @@ -141,18 +208,106 @@ def test_DictDataset(): # test with dataloader from torch.utils.data import DataLoader - loader = DataLoader(dataset, batch_size=1) batch = next(iter(loader)) print(batch["data"]) # test with fastdataloader - from .dataloader import DictLoader - + from mlcolvar.data import DictLoader loader = DictLoader(dataset, batch_size=1) batch = next(iter(loader)) print(batch) + from mlcolvar.data.graph.atomic import 
AtomicNumberTable, Configuration + from mlcolvar.data.graph.utils import create_dataset_from_configurations + # graphs based + numbers = [8, 1, 1] + positions = np.array( + [[[0.0, 0.0, 0.0], [0.07, 0.07, 0.0], [0.07, -0.07, 0.0]], + [[0.0, 0.0, 0.0], [0.07, 0.07, 0.0], [0.07, -0.07, 0.0]], + [[0.0, 0.0, 0.0], [0.07, 0.07, 0.0], [0.07, -0.07, 0.0]]], + dtype=float + ) + cell = np.identity(3, dtype=float) * 0.2 + graph_labels = np.array([[1], [0], [1]]) + node_labels = np.array([[0], [1], [1]]) + z_table = AtomicNumberTable.from_zs(numbers) + + config = [Configuration( + atomic_numbers=numbers, + positions=positions[i] + 0.1*i, + cell=cell, + pbc=[True] * 3, + node_labels=node_labels[i], + graph_labels=graph_labels, + ) for i in range(3)] + graph_dataset = create_dataset_from_configurations(config, + z_table, + 0.1, + show_progress=False + ) + print(graph_dataset) + assert(isinstance(graph_dataset, DictDataset)) + + # check __getitem__ + # string + out = dataset['data'] + assert( torch.allclose(out, data) ) + out = graph_dataset['data_list'] + assert( torch.allclose(out[1]['positions'], torch.Tensor(positions+0.1))) + + # int + out = dataset[1] + assert( torch.allclose(out['data'], data[1]) ) + out = graph_dataset[1] + assert( torch.allclose(out['data_list']['positions'], torch.Tensor(positions+0.1))) + + + # list + out = dataset[[0,1,2]] + assert( torch.allclose(out['data'], data[[0,1,2]]) ) + out = graph_dataset[[0,1,2]] + for i in [0,1,2]: + assert( torch.allclose(out['data_list'][i]['positions'], torch.Tensor(positions+0.1*i))) + + # slice + out = dataset[0:2] + assert( torch.allclose(out['data'], data[[0,1]]) ) + out = graph_dataset[0:2] + for i in [0,1]: + assert( torch.allclose(out['data_list'][i]['positions'], torch.Tensor(positions+0.1*i))) + + # range + out = dataset[range(0,2)] + assert( torch.allclose(out['data'], data[[0,1]]) ) + out = graph_dataset[range(0,2)] + for i in [0,1]: + assert( torch.allclose(out['data_list'][i]['positions'], 
torch.Tensor(positions+0.1*i))) + + # np.ndarray + out = dataset[np.array(1)] + assert( torch.allclose(out['data'], data[1]) ) + out = graph_dataset[np.array(1)] + assert( torch.allclose(out['data_list']['positions'], torch.Tensor(positions+0.1))) + + out = dataset[np.array([0,1,2])] + assert( torch.allclose(out['data'], data[[0,1,2]]) ) + out = graph_dataset[np.array([0,1,2])] + for i in [0,1,2]: + assert( torch.allclose(out['data_list'][i]['positions'], torch.Tensor(positions+0.1*i))) + + # torch.Tensor + out = dataset[torch.tensor([1], dtype=torch.long)] + assert( torch.allclose(out['data'], data[1]) ) + out = graph_dataset[torch.tensor([1], dtype=torch.long)] + assert( torch.allclose(out['data_list']['positions'], torch.Tensor(positions+0.1))) + + out = dataset[torch.tensor([0,1,2], dtype=torch.long)] + assert( torch.allclose(out['data'], data[[0,1,2]]) ) + out = graph_dataset[torch.tensor([0,1,2], dtype=torch.long)] + for i in [0,1,2]: + assert( torch.allclose(out['data_list'][i]['positions'], torch.Tensor(positions+0.1*i))) + if __name__ == "__main__": - test_DictDataset() + test_DictDataset() \ No newline at end of file diff --git a/mlcolvar/data/graph/__init__.py b/mlcolvar/data/graph/__init__.py new file mode 100644 index 00000000..5322ece6 --- /dev/null +++ b/mlcolvar/data/graph/__init__.py @@ -0,0 +1,5 @@ +__all__ = ["AtomicNumberTable", "Configuration", "Configurations", "get_neighborhood", "create_dataset_from_configurations", "create_test_graph_input"] + +from .atomic import * +from .neighborhood import * +from .utils import * \ No newline at end of file diff --git a/mlcolvar/data/graph/atomic.py b/mlcolvar/data/graph/atomic.py new file mode 100644 index 00000000..78c0eb31 --- /dev/null +++ b/mlcolvar/data/graph/atomic.py @@ -0,0 +1,151 @@ +import warnings +import numpy as np +import mdtraj as md +from dataclasses import dataclass +from typing import List, Iterable, Optional + +""" +The helper functions for atomic data. 
This module is taken from MACE directly: +https://github.com/ACEsuit/mace/blob/main/mace/tools/utils.py +https://github.com/ACEsuit/mace/blob/main/mace/data/utils.py +""" + +__all__ = ['AtomicNumberTable', 'Configuration', 'Configurations'] + + +class AtomicNumberTable: + """The atomic number table. + Used to map between one hot encodings and a given set of actual atomic numbers. + """ + + def __init__(self, zs: List[int]) -> None: + """Initializes an atomic number table object + + Parameters + ---------- + zs: List[int] + The atomic numbers in this table + """ + self.zs = zs + self.masses = [1.0] * len(zs) + for i in range(len(zs)): + try: + m = md.element.Element.getByAtomicNumber(zs[i]).mass + self.masses[i] = m + except Exception: + warnings.warn( + 'Cannot assign mass for atomic number: {:d}'.format(zs[i]) + ) + + def __len__(self) -> int: + """Number of elements in the table""" + return len(self.zs) + + def __str__(self) -> str: + return f'AtomicNumberTable: {tuple(s for s in self.zs)}' + + def index_to_z(self, index: int) -> int: + """Maps the encoding to the actual atomic number + + Parameters + ---------- + index: int + Index of the encoding to be mapped + """ + return self.zs[index] + + def index_to_symbol(self, index: int) -> str: + """Map the encoding to the atomic symbol + + Parameters + ---------- + index: int + Index of the encoding to be mapped + """ + return md.element.Element.getByAtomicNumber(self.zs[index]).symbol + + def z_to_index(self, atomic_number: int) -> int: + """Maps an atomic number to the encoding. + + Parameters + ---------- + atomic_number: int + The atomic number to be mapped + """ + return self.zs.index(atomic_number) + + def zs_to_indices(self, atomic_numbers: np.ndarray) -> np.ndarray: + """Maps an array of atomic numbers to the encodings. 
+ + Parameters + ---------- + atomic_numbers: numpy.ndarray + The atomic numbers to be mapped + """ + to_index_fn = np.vectorize(self.z_to_index) + return to_index_fn(atomic_numbers) + + @classmethod + def from_zs(cls, atomic_numbers: Iterable[int]) -> 'AtomicNumberTable': + """Build the table from an array of atomic numbers. + + Parameters + ---------- + atomic_numbers: Iterable[int] + The atomic numbers to be used for building the table + """ + z_set = set() + for z in atomic_numbers: + z_set.add(z) + return cls(sorted(list(z_set))) + + +def get_masses(atomic_numbers: Iterable[int]) -> List[float]: + """Get atomic masses from atomic numbers. + + Parameters + ---------- + atomic_numbers: Iterable[int] + The atomic numbers for which to return the atomic masses + """ + return AtomicNumberTable.from_zs(atomic_numbers).masses.copy() + + +@dataclass +class Configuration: + """ + Internal helper class that describes a given configuration of the system. + """ + atomic_numbers: np.ndarray # shape: [n_atoms] + positions: np.ndarray # shape: [n_atoms, 3], units: Ang + cell: np.ndarray # shape: [3, 3], units: Ang + pbc: Optional[tuple] # shape: [3] + node_labels: Optional[np.ndarray] # shape: [n_atoms, n_node_labels] + graph_labels: Optional[np.ndarray] # shape: [n_graph_labels, 1] + weight: Optional[float] = 1.0 # shape: [] + system: Optional[np.ndarray] = None # shape: [n_system_atoms] + environment: Optional[np.ndarray] = None # shape: [n_environment_atoms] + + +Configurations = List[Configuration] + + +def test_atomic_number_table() -> None: + table = AtomicNumberTable([1, 6, 7, 8]) + + numbers = np.array([1, 7, 6, 8]) + assert ( + table.zs_to_indices(numbers) == np.array([0, 2, 1, 3], dtype=int) + ).all() + + numbers = np.array([1, 1, 1, 6, 8, 1]) + assert ( + table.zs_to_indices(numbers) == np.array([0, 0, 0, 1, 3, 0], dtype=int) + ).all() + + table_1 = AtomicNumberTable.from_zs([6] * 3 + [1] * 10 + [7] * 3 + [8] * 2) + assert table_1.zs == table.zs + + +if __name__ 
== '__main__': + test_atomic_number_table() diff --git a/mlcolvar/data/graph/neighborhood.py b/mlcolvar/data/graph/neighborhood.py new file mode 100644 index 00000000..09791f34 --- /dev/null +++ b/mlcolvar/data/graph/neighborhood.py @@ -0,0 +1,254 @@ +import numpy as np +from matscipy.neighbours import neighbour_list +from typing import Optional, Tuple, List + +""" +The neighbor list function. This module is taken from MACE directly: +https://github.com/ACEsuit/mace/blob/main/mace/data/neighborhood.py +""" + +__all__ = ['get_neighborhood'] + + +def get_neighborhood( + positions: np.ndarray, # [num_positions, 3] + cutoff: float, + pbc: Optional[Tuple[bool, bool, bool]] = None, + cell: Optional[np.ndarray] = None, # [3, 3] + true_self_interaction: Optional[bool] = False, + system_indices: Optional[List[int]] = None, + environment_indices: Optional[List[int]] = None, + buffer: float = 0.0 +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Get the neighbor list of a given set atoms. + + Parameters + ---------- + positions: numpy.ndarray (shape: [N, 3]) + The positions array. + cutoff: float + The cutoff radius. + pbc: Tuple[bool, bool, bool] (shape: [3]) + If to enable PBC in the directions of the three lattice vectors. + cell: numpy.ndarray (shape: [3, 3]) + The lattice vectors. + true_self_interaction: bool + If to keep self-edges that don't cross periodic boundaries. + system_indices: List[int] + Indices of the atoms to be considered as the 'system' if + restricting the neighborhood to a subsystem (i.e., system + environment), see also Notes section. + environment_indices: List[int] + Indices of the atoms to be considered as the 'environment' if + restricting the neighborhood to a subsystem (i.e., system + environment), see also Notes section. 
+ Only atoms within the cutoff will be included as active environment atoms + buffer: float + Buffer size used in finding active environment atoms, if + restricting the neighborhood to a subsystem (i.e., system + environment), see also Notes section. + + Returns + ------- + edge_index: numpy.ndarray (shape: [2, n_edges]) + The edge indices (i.e., source and destination) in the graph. + shifts: numpy.ndarray (shape: [n_edges, 3]) + The shift vectors (unit_shifts * cell_lengths). + unit_shifts: numpy.ndarray (shape: [n_edges, 3]) + The unit shift vectors (number of PBC crossed by the edges). + + Notes + ----- + Arguments `system_indices` and `environment_indices` must be present at the + same time. When these arguments are given, only edges in the [subsystem] + formed by [the system atoms] and [the environment atoms within the cutoff + radius of the system atoms] will be kept. + These two lists must not contain common atoms. + """ + + if system_indices is not None or environment_indices is not None: + assert system_indices is not None and environment_indices is not None + + system_indices = np.array(system_indices) + environment_indices = np.array(environment_indices) + assert np.intersect1d(system_indices, environment_indices).size == 0 + else: + assert buffer == 0.0 + + if pbc is None: + pbc = (False, False, False) + + if cell is None or cell.any() == np.zeros((3, 3)).any(): + cell = np.identity(3, dtype=float) + + assert len(pbc) == 3 and all(isinstance(i, (bool, np.bool_)) for i in pbc) + assert cell.shape == (3, 3) + + pbc_x = pbc[0] + pbc_y = pbc[1] + pbc_z = pbc[2] + identity = np.identity(3, dtype=float) + max_positions = np.max(np.absolute(positions)) + 1 + # Extend cell in non-periodic directions + # For models with more than 5 layers, the multiplicative constant needs to + # be increased. 
+ if not pbc_x: + cell[:, 0] = max_positions * 5 * cutoff * identity[:, 0] + if not pbc_y: + cell[:, 1] = max_positions * 5 * cutoff * identity[:, 1] + if not pbc_z: + cell[:, 2] = max_positions * 5 * cutoff * identity[:, 2] + + sender, receiver, unit_shifts, distances = neighbour_list( + quantities='ijSd', + pbc=pbc, + cell=cell, + positions=positions, + cutoff=float(cutoff + buffer), + # self_interaction=True, # we want edges from atom to itself in different periodic images + # use_scaled_positions=False, # positions are not scaled positions + ) + + if not true_self_interaction: + # Eliminate self-edges that don't cross periodic boundaries + true_self_edge = sender == receiver + true_self_edge &= np.all(unit_shifts == 0, axis=1) + keep_edge = ~true_self_edge + + # NOTE: after eliminating self-edges, it can be that no edges remain + # in this system + sender = sender[keep_edge] + receiver = receiver[keep_edge] + unit_shifts = unit_shifts[keep_edge] + distances = distances[keep_edge] + + if system_indices is not None: + # Get environment atoms that are neighbors of the system. + keep_edge = np.where(np.in1d(receiver, system_indices))[0] + keep_sender = np.intersect1d(sender[keep_edge], environment_indices) + keep_atom = np.concatenate((system_indices, np.unique(keep_sender))) + # Get the edges in the subsystem. 
+ keep_sender = np.where(np.in1d(sender, keep_atom))[0] + keep_receiver = np.where(np.in1d(receiver, keep_atom))[0] + keep_edge = np.intersect1d(keep_sender, keep_receiver) + keep_edge_distance = np.where(distances <= cutoff)[0] + keep_edge = np.intersect1d(keep_edge, keep_edge_distance) + # Get the edges + sender = sender[keep_edge] + receiver = receiver[keep_edge] + unit_shifts = unit_shifts[keep_edge] + + # Build output + edge_index = np.stack((sender, receiver)) # [2, n_edges] + + # From the docs: With the shift vector S, the distances D between atoms + # can be computed from: D = positions[j]-positions[i]+S.dot(cell) + shifts = np.dot(unit_shifts, cell) # [n_edges, 3] + + return edge_index, shifts, unit_shifts + + +def test_get_neighborhood() -> None: + + positions = np.array( + [[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3]], dtype=float + ) + cell = np.array([[4, 0, 0], [0, 4, 0], [0, 0, 4]], dtype=float) + + n, s, u = get_neighborhood(positions, cutoff=5.0) + assert ( + n == np.array( + [[0, 0, 1, 1, 1, 2, 2, 2, 3, 3], [1, 2, 0, 2, 3, 0, 1, 3, 1, 2]], + dtype=int + ) + ).all() + + n, s, u = get_neighborhood(positions, cutoff=2.0) + assert ( + n == np.array([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]], dtype=int) + ).all() + + n, s, u = get_neighborhood( + positions, cutoff=2.0, pbc=[True] * 3, cell=cell + ) + assert ( + n == np.array( + [[0, 0, 1, 1, 2, 2, 3, 3], [3, 1, 0, 2, 1, 3, 2, 0]], dtype=int + ) + ).all() + assert ( + s == np.array( + [ + [-4.0, -4.0, -4.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [4.0, 4.0, 4.0] + ], + dtype=float + ) + ).all() + assert ( + u == np.array( + [ + [-1, -1, -1], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [1, 1, 1] + ], + dtype=int + ) + ).all() + + n, s, u = get_neighborhood( + positions, + cutoff=2.0, + pbc=[True] * 3, + cell=cell, + system_indices=[0, 1], + environment_indices=[2, 3] + ) + assert ( + n == 
np.array( + [[0, 0, 1, 1, 2, 2, 3, 3], [3, 1, 0, 2, 1, 3, 2, 0]], dtype=int + ) + ).all() + + n, s, u = get_neighborhood( + positions, + cutoff=2.0, + pbc=[True] * 3, + cell=cell, + system_indices=[0], + environment_indices=[1, 2, 3] + ) + assert ( + n == np.array( + [[0, 0, 1, 3], [3, 1, 0, 0]], dtype=int + ) + ).all() + assert ( + s == np.array( + [ + [-4.0, -4.0, -4.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [4.0, 4.0, 4.0] + ], + dtype=float + ) + ).all() + assert ( + u == np.array( + [[-1, -1, -1], [0, 0, 0], [0, 0, 0], [1, 1, 1]], + dtype=int + ) + ).all() + + +if __name__ == "__main__": + test_get_neighborhood() diff --git a/mlcolvar/data/graph/utils.py b/mlcolvar/data/graph/utils.py new file mode 100644 index 00000000..1732ce65 --- /dev/null +++ b/mlcolvar/data/graph/utils.py @@ -0,0 +1,839 @@ +import copy +from collections import defaultdict +from typing import Union + +import torch +import torch_geometric +from torch_geometric.data import Data, HeteroData +from torch_geometric.transforms import BaseTransform + +from mlcolvar.data import DictDataset, DictModule +from mlcolvar.data.graph import atomic +from mlcolvar.data.graph.neighborhood import get_neighborhood +from mlcolvar.utils.plot import pbar + +from typing import List + +__all__ = ["create_dataset_from_configurations", "create_test_graph_input"] + +def _create_dataset_from_configuration( + config: atomic.Configuration, + z_table: atomic.AtomicNumberTable, + cutoff: float, + buffer: float = 0.0, +) -> torch_geometric.data.Data: + """Build the torch_geometric graph data object from a configuration. 
+ + Parameters + ---------- + config: mlcolvar.data.graph.atomic.Configuration + The configuration from which to generate the graph data + z_table: mlcolvar.data.graph.atomic.AtomicNumberTable + The atomic number table used to build the node attributes + cutoff: float + The graph cutoff radius + buffer: float + Buffer size used in finding active environment atoms if + restricting the neighborhood to a subsystem (i.e., system + environment), + `see also mlcolvar.data.grap.neighborhood.get_neighborhood` + """ + + assert config.graph_labels is None or len(config.graph_labels.shape) == 2 + + # NOTE: here we do not take care about the nodes that are not taking part + # the graph, like, we don't even change the node indices in `edge_index`. + # Here we simply ignore them, and rely on the `RemoveIsolatedNodes` method + # that will be called later (in `create_dataset_from_configurations`). + edge_index, shifts, unit_shifts = get_neighborhood( + positions=config.positions, + cutoff=cutoff, + cell=config.cell, + pbc=config.pbc, + system_indices=config.system, + environment_indices=config.environment, + buffer=buffer + ) + edge_index = torch.tensor(edge_index, dtype=torch.long) + shifts = torch.tensor(shifts, dtype=torch.get_default_dtype()) + unit_shifts = torch.tensor( + unit_shifts, dtype=torch.get_default_dtype() + ) + + positions = torch.tensor( + config.positions, dtype=torch.get_default_dtype() + ) + cell = torch.tensor(config.cell, dtype=torch.get_default_dtype()) + + indices = z_table.zs_to_indices(config.atomic_numbers) + one_hot = to_one_hot( + torch.tensor(indices, dtype=torch.long).unsqueeze(-1), + n_classes=len(z_table), + ) + + node_labels = ( + torch.tensor(config.node_labels, dtype=torch.get_default_dtype()) + if config.node_labels is not None + else None + ) + + graph_labels = ( + torch.tensor(config.graph_labels, dtype=torch.get_default_dtype()) + if config.graph_labels is not None + else None + ) + + weight = ( + torch.tensor(config.weight, 
dtype=torch.get_default_dtype()) + if config.weight is not None + else 1 + ) + + n_system = ( + torch.tensor( + [[len(config.system)]], dtype=torch.get_default_dtype() + ) if config.system is not None + else torch.tensor( + [[one_hot.shape[0]]], dtype=torch.get_default_dtype() + ) + ) + + n_env = ( + torch.tensor( + [[one_hot.shape[0] - n_system.to(torch.int).item()]], dtype=torch.get_default_dtype() + ) + ) + + if config.system is not None: + system_masks = torch.zeros((one_hot.shape[0], 1), dtype=torch.bool) + system_masks[config.system, 0] = 1 + else: + system_masks = None + + return torch_geometric.data.Data( + edge_index=edge_index, + shifts=shifts, + unit_shifts=unit_shifts, + positions=positions, + cell=cell, + node_attrs=one_hot, + node_labels=node_labels, + graph_labels=graph_labels, + n_system=n_system, + n_env=n_env, + system_masks=system_masks, + weight=weight, + ) + + +def create_dataset_from_configurations( + config: atomic.Configurations, + z_table: atomic.AtomicNumberTable, + cutoff: float, + buffer: float = 0.0, + atom_names: List = None, + remove_isolated_nodes: bool = False, + show_progress: bool = True +) -> DictDataset: + """Build DictDataset object containing torch_geometric graph data objects from configurations. 
+ + Parameters + ---------- + config: mlcolvar.data.graph.atomic.Configurations + The configurations from which to generate the dataset + z_table: mlcolvar.data.graph.atomic.AtomicNumberTable + The atomic number table used to build the node attributes + cutoff: float + The graph cutoff radius + buffer: float + Buffer size used in finding active environment atoms if + restricting the neighborhood to a subsystem (i.e., system + environment), + `see also mlcolvar.data.graph.neighborhood.get_neighborhood` + remove_isolated_nodes: bool + Whether to remove isolated nodes from the dataset + show_progress: bool + Whether to show the progress bar + """ + if show_progress: + items = pbar(config, frequency=0.0001, prefix='Making graphs') + else: + items = config + + data_list = [ + _create_dataset_from_configuration( + config=c, + z_table=z_table, + cutoff=cutoff, + buffer=buffer, + ) for c in items + ] + + if atom_names is None: + atom_names_system = [f"X{i}" for i in range(data_list[0]['n_system'].to(torch.int64).item())] + atom_names_env = [f"Y{i}" for i in range(data_list[0]['n_env'].to(torch.int64).item())] + atom_names = atom_names_system + atom_names_env + + # this is only to check which isolated nodes have been removed + _aux_pos = torch.Tensor((np.array([d['positions'].numpy() for d in data_list]))) + if remove_isolated_nodes: + # TODO: not the worst way to fake the `is_node_attr` method of + # `torch_geometric.data.storage.GlobalStorage` ... + # I mean, when there are exactly three atoms in the graph, the + # `RemoveIsolatedNodes` method will remove the cell vectors that + # correspond to the isolated node ... This is a consequence of + # pyg regarding the cell vectors as some kind of node features. + # So here we first remove the isolated nodes, then set the cell back. 
+ cell_list = [d.cell.clone() for d in data_list] + transform = _RemoveIsolatedNodes() + data_list = [transform(d) for d in data_list] + + # check what have been removed and restore cell + unique_idx = [] # store the indeces of the atoms that have been used at least once + for i in range(len(data_list)): + data_list[i].cell = cell_list[i] + # get and save the original index before removing isolated nodes for each entry + original_idx = torch.unique( torch.where(torch.isin(torch.round(_aux_pos[i], decimals=5), + torch.round(data_list[i]['positions'], decimals=5)) + )[0] + ) + data_list[i]['names_idx'] = original_idx.to(torch.int64) + + # update if needed the overall list + check = np.isin(original_idx.numpy(), unique_idx, invert=True) + if check.any(): + aux = np.where(check)[0] + unique_idx.extend(original_idx[aux].tolist()) + + unique_idx.sort() + unique_idx = torch.Tensor(unique_idx).to(torch.int64) + # here we simply have to take all the atoms + else: + unique_idx = torch.arange(data_list[0]['n_system'].item()).to(torch.int64) + for i in range(len(data_list)): + data_list[i]['names_idx'] = unique_idx + + # we also save the names of the atoms that have been actually used + unique_names = np.array(atom_names)[unique_idx] + unique_names = unique_names.tolist() + + dataset = DictDataset(dictionary={'data_list' : data_list}, + metadata={'z_table' : z_table.zs, + 'cutoff' : cutoff, + 'used_idx' : unique_idx, + 'used_names' : unique_names}, + data_type='graphs') + + return dataset + +def to_one_hot(indices: torch.Tensor, n_classes: int) -> torch.Tensor: + """Generates one-hot encoding with `n_classes` classes from `indices` + + Parameters + ---------- + indices: torch.Tensor (shape: [N, 1]) + Node indices + n_classes: int + Number of classes + + Returns + ------- + encoding: torch.tensor (shape: [N, n_classes]) + The one-hot encoding + """ + shape = indices.shape[:-1] + (n_classes,) + oh = torch.zeros(shape, device=indices.device).view(shape) + + # scatter_ is the 
in-place version of scatter + oh.scatter_(dim=-1, index=indices, value=1) + + return oh.view(*shape) + + +class _RemoveIsolatedNodes(BaseTransform): + r"""Removes isolated nodes from the graph + This is taken from pytorch_geometric with a small modification to avoid the bug when n_nodes==n_edges + """ + def forward( + self, + data: Union[Data, HeteroData], + ) -> Union[Data, HeteroData]: + # Gather all nodes that occur in at least one edge (across all types): + n_ids_dict = defaultdict(list) + for edge_store in data.edge_stores: + if 'edge_index' not in edge_store: + continue + + if edge_store._key is None: + src = dst = None + else: + src, _, dst = edge_store._key + + n_ids_dict[src].append(edge_store.edge_index[0]) + n_ids_dict[dst].append(edge_store.edge_index[1]) + + n_id_dict = {k: torch.cat(v).unique() for k, v in n_ids_dict.items()} + + n_map_dict = {} + for node_store in data.node_stores: + if node_store._key not in n_id_dict: + n_id_dict[node_store._key] = torch.empty(0, dtype=torch.long) + + idx = n_id_dict[node_store._key] + assert data.num_nodes is not None + mapping = idx.new_zeros(data.num_nodes) + mapping[idx] = torch.arange(idx.numel(), device=mapping.device) + n_map_dict[node_store._key] = mapping + + for edge_store in data.edge_stores: + if 'edge_index' not in edge_store: + continue + + if edge_store._key is None: + src = dst = None + else: + src, _, dst = edge_store._key + + row = n_map_dict[src][edge_store.edge_index[0]] + col = n_map_dict[dst][edge_store.edge_index[1]] + edge_store.edge_index = torch.stack([row, col], dim=0) + + old_data = copy.copy(data) + for out, node_store in zip(data.node_stores, old_data.node_stores): + for key, value in node_store.items(): + if key == 'num_nodes': + out.num_nodes = n_id_dict[node_store._key].numel() + elif node_store.is_node_attr(key) and key not in ['shifts', 'unit_shifts']: + out[key] = value[n_id_dict[node_store._key]] + + return data + + +def create_test_graph_input(output_type: str, + n_atoms: int = 
3, + n_samples: int = 60, + n_states: int = 2, + random_weights = False, + add_noise = True): + """ + Util function to generate several types of mock graph data objects for testing purposes. + The graphs are created drawing positions from a predefined set of positions that cover most use cases. + It can generate: one or some configuration objects, a dataset, a datamodule, a batch of example inputs or a single item. + + Parameters + ---------- + output_type : str + Type of graph data object to create. Can be: 'configuration', 'configurations', 'datamodule', 'dataset', 'batch', 'example' + n_atoms : int, optional + Number of atoms for creating the graph, either 3 or 4, by default 3 + n_samples : int, optional + Number of samples per state to create, by default 60 + n_states : int, optional + Number of states for which to create data, by default 2. Configurations are then labelled accordingly. + random_weights : bool, optional + If to assign random weights to the entries, otherwise unitary weights are given, by default False + add_noise : bool, optional + If to add a random noise for each entry to the predefined positions, by default True + + Returns + ------- + Graph data object of the chosen type + """ + if n_atoms == 3: + numbers = [8, 1, 1] + node_labels = np.array([[0], [1], [1]]) + _ref_positions = np.array( + [ + [[0.0, 0.0, 0.0], [0.07, 0.07, 0.0], [0.07, -0.07, 0.0]], + [[0.0, 0.0, 0.0], [-0.07, 0.07, 0.0], [0.07, 0.07, 0.0]], + [[0.0, 0.0, 0.0], [0.07, -0.07, 0.0], [0.07, 0.07, 0.0]], + [[0.0, 0.0, 0.0], [0.0, -0.07, 0.07], [0.0, 0.07, 0.07]], + [[0.0, 0.0, 0.0], [0.11, 0.11, 0.11], [-0.07, 0.0, 0.07]], + [[0.1, 0.0, 1.1], [0.17, 0.07, 1.1], [0.17, -0.07, 1.1]], + ], + dtype=np.float64 + ) + + elif n_atoms == 4: + numbers = [8, 1, 1, 8] + node_labels = np.array([[0], [1], [1], [0]]) + _ref_positions = np.array( + [ + [[0.0, 0.0, 0.0], [0.07, 0.07, 0.0] , [0.07, -0.07, 0.0], [0.05, -0.05, 0.0]], + [[0.0, 0.0, 0.0], [-0.07, 0.07, 0.0], [0.07, 0.07, 0.0], 
[0.05, 0.05, 0.0]], + [[0.0, 0.0, 0.0], [0.07, -0.07, 0.0], [0.07, 0.07, 0.0], [0.05, 0.05, 0.0]], + [[0.0, 0.0, 0.0], [0.0, -0.07, 0.07], [0.0, 0.07, 0.07], [0.0, 0.05, 0.05]], + [[0.0, 0.0, 0.0], [0.11, 0.11, 0.11] , [-0.07, 0.0, 0.07], [-0.05, 0.0, 0.05]], + [[0.1, 0.0, 1.1], [0.17, 0.07, 1.1] , [0.17, -0.07, 1.1], [0.15, -0.05, 1.1]], + ], + dtype=np.float64 + ) + else: + raise ValueError(f'Example input can be generated either with 3 or 4 atoms, found {n_atoms}') + + + idx = np.random.randint(low=0, high=6, size=(n_samples*n_states)) + positions = _ref_positions[idx, :, :] + + # let's add some noise to the positions for fun + if add_noise: + noise = np.random.randn(*positions.shape)*1e-5 + positions = positions + noise + + cell = np.identity(3, dtype=float) * 0.2 + graph_labels = np.zeros((n_samples*n_states, 1, 1)) + for i in range(1, n_states): + graph_labels[n_samples * i :] += 1 + z_table = atomic.AtomicNumberTable.from_zs(numbers) + + if random_weights: + weights = np.random.random_sample((n_samples*n_states, 1, 1)) + else: + weights = np.ones((n_samples*n_states, 1, 1)) + config = [ + atomic.Configuration( + atomic_numbers=numbers, + positions=positions[i], + cell=cell, + pbc=[True] * 3, + node_labels=node_labels, + graph_labels=graph_labels[i], + weight=weights[i] + ) for i in range(0, n_samples*n_states) + ] + + if output_type == 'configuration': + return config[0] + if output_type == 'configurations': + return config + + dataset = create_dataset_from_configurations( + config, z_table, 0.1, show_progress=False, remove_isolated_nodes=True + ) + + if output_type == 'dataset': + return dataset + + datamodule = DictModule( + dataset, + lengths=(0.8, 0.2), + batch_size=0, + shuffle=False, + ) + + if output_type == 'datamodule': + return datamodule + + datamodule.setup() + batch = next(iter(datamodule.train_dataloader())) + if output_type == 'batch': + return batch + example = batch['data_list'].get_example(0) + example['batch'] = 
torch.zeros(len(example['positions']), dtype=torch.int64) + if output_type == 'example': + return example + + return None + +def create_graph_tracing_example(n_species : int): + """ + Util to create a tracing example for graph based models. + + Parameters + ---------- + n_species : int + Number of chemical species to be considered in the model. + + Returns + ------- + dict + Tracing graph input example as dict. + """ + numbers = [1, 1, 1] + node_labels = np.array([[0], [0], [0]]) + _ref_positions = np.array( + [ + [[0.0, 0.0, 0.0], [0.07, 0.07, 0.0], [0.07, -0.07, 0.0]], + [[0.0, 0.0, 0.0], [-0.07, 0.07, 0.0], [0.07, 0.07, 0.0]], + [[0.0, 0.0, 0.0], [0.07, -0.07, 0.0], [0.07, 0.07, 0.0]], + [[0.0, 0.0, 0.0], [0.0, -0.07, 0.07], [0.0, 0.07, 0.07]], + [[0.0, 0.0, 0.0], [0.11, 0.11, 0.11], [-0.07, 0.0, 0.07]], + [[0.1, 0.0, 1.1], [0.17, 0.07, 1.1], [0.17, -0.07, 1.1]], + ], + dtype=np.float64 + ) + + idx = np.random.randint(low=0, high=6, size=1) + positions = _ref_positions[idx, :, :] + cell = np.identity(3, dtype=float) * 0.2 + graph_labels = np.zeros((1, 1, 1)) + + z_table = atomic.AtomicNumberTable.from_zs(numbers) + + weights = np.ones((1, 1, 1)) + config = [ + atomic.Configuration( + atomic_numbers=numbers, + positions=positions[i], + cell=cell, + pbc=[True] * 3, + node_labels=node_labels, + graph_labels=graph_labels[i], + weight=weights[i] + ) for i in range(0, 1) + ] + + # here we do not remove isolated nodes + dataset = create_dataset_from_configurations( + config, z_table, 0.1, show_progress=False, remove_isolated_nodes=False + ) + + datamodule = DictModule( + dataset, + lengths=(0.8, 0.2), + batch_size=0, + shuffle=False, + ) + + datamodule.setup() + batch = next(iter(datamodule.train_dataloader())) + example = batch['data_list'].get_example(0) + example['batch'] = torch.zeros(len(example['positions']), dtype=torch.int64) + + example = example.to_dict() + example['node_attrs'] = torch.cat((example['node_attrs'], torch.zeros(3, n_species - 1)), 1) + return 
example + +# =============================================================================== +# =============================================================================== +# ==================================== TESTS ==================================== +# =============================================================================== +# =============================================================================== + +import numpy as np + +def test_to_one_hot() -> None: + i = torch.tensor([[0], [2], [1]], dtype=torch.int64) + e = to_one_hot(i, 4) + assert ( + e == torch.tensor( + [[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0]], dtype=torch.int64 + ) + ).all() + +def test_from_configuration() -> None: + # fake atomic numbers, positions, cell, graph label, node labels + numbers = [8, 1, 1] + positions = np.array([[0.0, 0.0, 0.0], + [0.07, 0.07, 0.0], + [0.07, -0.07, 0.0]], + dtype=float + ) + cell = np.identity(3, dtype=float) * 0.2 + graph_labels = np.array([[1]]) + node_labels = np.array([[0], [1], [1]]) + + # init AtomicNumber object + z_table = atomic.AtomicNumberTable.from_zs(numbers) + + # initialize configuration using all atoms + config = atomic.Configuration( + atomic_numbers=numbers, + positions=positions, + cell=cell, + pbc=[True] * 3, + node_labels=node_labels, + graph_labels=graph_labels, + ) + + # create dataset from a configuration + data = _create_dataset_from_configuration(config, z_table, 0.1) + + # check edges and shifts are created correctly + assert(data['edge_index'] == torch.tensor([[0, 0, 1, 1, 2, 2], + [2, 1, 0, 2, 1, 0]]) + ).all() + + assert(data['shifts'] == torch.tensor([[0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.2, 0.0], + [0.0, -0.2, 0.0], + [0.0, 0.0, 0.0]]) + ).all() + + assert(data['unit_shifts'] == torch.tensor([[0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, -1.0, 0.0], + [0.0, 0.0, 0.0]]) + ).all() + + # check correct storage + assert(data['positions'] == torch.tensor([[0.0, 
0.0, 0.0], + [0.07, 0.07, 0.0], + [0.07, -0.07, 0.0]]) + ).all() + + assert(data['cell'] == torch.tensor([[0.2, 0.0, 0.0], + [0.0, 0.2, 0.0], + [0.0, 0.0, 0.2]]) + ).all() + + assert(data['node_attrs'] == torch.tensor([[0.0, 1.0], + [1.0, 0.0], + [1.0, 0.0]]) + ).all() + + assert(data['node_labels'] == torch.tensor([[0.0], + [1.0], + [1.0]]) + ).all() + + assert(data['graph_labels'] == torch.tensor([[1.0]])).all() + assert(data['weight'] == 1.0) + + # initialize configuration using two atoms (1 system, 1 env) as a subset + config = atomic.Configuration( + atomic_numbers=numbers, + positions=positions, + cell=cell, + pbc=[True] * 3, + node_labels=node_labels, + graph_labels=graph_labels, + system=[1], + environment=[2] + ) + + data = _create_dataset_from_configuration(config, z_table, 0.1) + + # check edges and shift are computed correctly + assert(data['edge_index'] == torch.tensor([[1, 2], + [2, 1]]) + ).all() + assert (data['shifts'] == torch.tensor([[0.0, 0.2, 0.0], + [0.0, -0.2, 0.0]]) + ).all() + assert(data['unit_shifts'] == torch.tensor([[0.0, 1.0, 0.0], + [0.0, -1.0, 0.0]]) + ).all() + + # initialize configuration using three atoms (1 system, 2 env) as a subset and no buffer + config = atomic.Configuration( + atomic_numbers=numbers, + positions=positions, + cell=cell, + pbc=[True] * 3, + node_labels=node_labels, + graph_labels=graph_labels, + system=[0], + environment=[1, 2] + ) + + data = _create_dataset_from_configuration(config, z_table, 0.1) + assert(data['edge_index'] == torch.tensor([[0, 0, 1, 1, 2, 2], + [2, 1, 0, 2, 1, 0]]) + ).all() + + + # check if pbc and cutoffs works. 
now the third atoms is too far + positions = np.array([[0.0, 0.0, 0.0], + [0.07, 0.07, 0.0], + [0.07, -0.08, 0.0]], + dtype=float + ) + + config = atomic.Configuration( + atomic_numbers=numbers, + positions=positions, + cell=cell, + pbc=[True] * 3, + node_labels=node_labels, + graph_labels=graph_labels, + system=[0], + environment=[1, 2] + ) + # create dataset with same cutoff + data = _create_dataset_from_configuration(config, z_table, 0.1) + + # check third atom is not included anymore + assert (data['edge_index'] == torch.tensor([[0, 1], + [1, 0]]) + ).all() + + # create dataset with slightly large cutoff + data = _create_dataset_from_configuration(config, z_table, 0.11) + + # check the edge with the third atom is created once again + assert(data['edge_index'] == torch.tensor([[0, 0, 1, 1, 2, 2], + [2, 1, 0, 2, 1, 0]]) + ).all() + + # check with buffer layer + # the third atoms should be included but with no edge to the system atom + data = _create_dataset_from_configuration(config, z_table, 0.1, 0.01) + assert(data['edge_index'] == torch.tensor([[0, 1, 1, 2], + [1, 0, 2, 1]]) + ).all() + assert(data['shifts'] == torch.tensor([[0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.2, 0.0], + [0.0, -0.2, 0.0]]) + ).all() + assert(data['unit_shifts'] == torch.tensor([[0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, -1.0, 0.0]]) + ).all() + + # create a list of configurations + config = [atomic.Configuration( + atomic_numbers=numbers, + positions=positions, + cell=cell, + pbc=[True] * 3, + node_labels=node_labels, + graph_labels=np.array([[i]]), + ) for i in range(0, 10)] + + # create dataset from list of configurations + dataset = create_dataset_from_configurations(config, + z_table, + 0.1, + show_progress=False) + + # check if the labels of the entries are created correctly + assert dataset.metadata['z_table'] == [1, 8] + assert (dataset[0]['data_list']['graph_labels'] == torch.tensor([[0.0]])).all() + assert (dataset[2]['data_list']['graph_labels'] == 
torch.tensor([[2.0]])).all() + assert (dataset[4]['data_list']['graph_labels'] == torch.tensor([[4.0]])).all() + + # dataset_1 = dataset[np.array([0, -1])] + assert dataset.metadata['z_table'] == [1, 8] + assert (dataset[ 0]['data_list']['graph_labels'] == torch.tensor([[0.0]])).all() + assert (dataset[-1]['data_list']['graph_labels'] == torch.tensor([[9.0]])).all() + + + +def test_from_configurations() -> None: + # fake atomic numbers, positions, cell, graph label, node labels + numbers = [8, 1, 1] + positions = np.array([[0.0, 0.0, 0.0], + [0.07, 0.07, 0.0], + [0.07, -0.07, 0.0]], + dtype=float + ) + cell = np.identity(3, dtype=float) * 0.2 + graph_labels = np.array([[1]]) + node_labels = np.array([[0], [1], [1]]) + + # init AtomicNumber object + z_table = atomic.AtomicNumberTable.from_zs(numbers) + + # initialize configuration using all atoms + config = atomic.Configuration( + atomic_numbers=numbers, + positions=positions, + cell=cell, + pbc=[True] * 3, + node_labels=node_labels, + graph_labels=graph_labels, + ) + + # create dataset from a configuration, even if single is the multiple function + dataset = create_dataset_from_configurations([config], + z_table, + 0.1, + remove_isolated_nodes=True, + show_progress=False + )[0] + + # take data entry from the DictDataset + data = dataset['data_list'] + + # check edges and shifts are created correctly + assert(data['edge_index'] == torch.tensor([[0, 0, 1, 1, 2, 2], + [2, 1, 0, 2, 1, 0]]) + ).all() + assert(data['shifts'] == torch.tensor([[0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.2, 0.0], + [0.0, -0.2, 0.0], + [0.0, 0.0, 0.0]]) + ).all() + + assert(data['unit_shifts'] == torch.tensor([[0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, -1.0, 0.0], + [0.0, 0.0, 0.0]]) + ).all() + + # check correct storage + assert(data['positions'] == torch.tensor([[0.0, 0.0, 0.0], + [0.07, 0.07, 0.0], + [0.07, -0.07, 0.0]]) + ).all() + + assert(data['cell'] == torch.tensor([[0.2, 
0.0, 0.0], + [0.0, 0.2, 0.0], + [0.0, 0.0, 0.2]]) + ).all() + + assert(data['node_attrs'] == torch.tensor([[0.0, 1.0], + [1.0, 0.0], + [1.0, 0.0]]) + ).all() + assert(data['node_labels'] == torch.tensor([[0.0], + [1.0], + [1.0]]) + ).all() + assert(data['graph_labels'] == torch.tensor([[1.0]])).all() + assert(data['weight'] == 1.0) + + # initialize configuration using three atoms (1 system, 2 env) as a subset and no buffer + config = atomic.Configuration( + atomic_numbers=numbers, + positions=positions, + cell=cell, + pbc=[True] * 3, + node_labels=node_labels, + graph_labels=graph_labels, + system=[1], + environment=[2] + ) + dataset = create_dataset_from_configurations([config], + z_table, + 0.1, + remove_isolated_nodes=True, + show_progress=False + )[0] + + # take data entry from the DictDataset + data = dataset['data_list'] + + assert(data['positions'] == torch.tensor([[0.07, 0.07, 0.0], + [0.07, -0.07, 0.0]]) + ).all() + assert(data['cell'] == torch.tensor([[0.2, 0.0, 0.0], + [0.0, 0.2, 0.0], + [0.0, 0.0, 0.2]]) + ).all() + assert(data['node_attrs'] == torch.tensor([[1.0, 0.0], + [1.0, 0.0]]) + ).all() + assert(data['edge_index'] == torch.tensor([[0, 1], + [1, 0]]) + ).all() + assert(data['shifts'] == torch.tensor([[0.0, 0.2, 0.0], + [0.0, -0.2, 0.0]]) + ).all() + assert(data['unit_shifts'] == torch.tensor([[0.0, 1.0, 0.0], + [0.0, -1.0, 0.0]]) + ).all() + +if __name__ == '__main__': + test_to_one_hot() + test_from_configuration() + test_from_configurations() \ No newline at end of file diff --git a/mlcolvar/data/utils.py b/mlcolvar/data/utils.py new file mode 100644 index 00000000..46deee84 --- /dev/null +++ b/mlcolvar/data/utils.py @@ -0,0 +1,151 @@ +import torch +import numpy as np + +from mlcolvar.data import DictDataset +from mlcolvar.data.graph.atomic import AtomicNumberTable + +__all__ = ["save_dataset", "load_dataset", "save_dataset_configurations_as_extyz"] + +def save_dataset(dataset: DictDataset, file_name: str) -> None: + """Save a dataset to disk. 
+ + Parameters + ---------- + dataset: DictDataset + Dataset to be saved + file_name: str + Name of the file to save to + """ + assert isinstance(dataset, DictDataset) + + torch.save(dataset, file_name) + + +def load_dataset(file_name: str) -> DictDataset: + """Load a dataset from disk. + + Parameters + ---------- + file_name: str + Name of the file to load the dataset from + """ + dataset = torch.load(file_name) + + assert isinstance(dataset, DictDataset) + + return dataset + + +def save_dataset_configurations_as_extyz(dataset: DictDataset, file_name: str) -> None: + """Save a dataset to disk in the extxyz format. + + Parameters + ---------- + dataset: DictDataset + Dataset to be saved with data_type graphs + file_name: str + Name of the file to save to + """ + # check the dataset type is 'graphs' + if not dataset.metadata["data_type"] == "graphs": + raise( + ValueError("Can only save to extxyz dataset with data_type='graphs'!") + ) + + # initialize the atomic number object + z_table = AtomicNumberTable.from_zs(dataset.metadata["z_table"]) + + # create file + fp = open(file_name, 'w') + + for i in range(len(dataset)): + d = dataset[i]['data_list'] + + # print number of atoms + print(len(d['positions']), file=fp) + + # header line for configuration d + # Lattice, properties, pbc + line = ( + 'Lattice="{:s}" '.format((r'{:.5f} ' * 9).strip()) + + 'Properties=species:S:1:pos:R:3 pbc="T T T"' + ) + + # cell info + cell = [c.item() for c in d['cell'].flatten()] + print(line.format(*cell), file=fp) + + # write atoms positions + for j in range(0, len(d['positions'])): + # chemical symbol + s = z_table.index_to_symbol(np.where(d['node_attrs'][j])[0][0]) + print('{:2s}'.format(s), file=fp, end=' ') + + # positions + positions = [p.item() for p in d['positions'][j]] + print('{:10.5f} {:10.5f} {:10.5f}'.format(*positions), file=fp) + fp.close() + + + + +import tempfile + +def test_save_dataset(): + # check using descriptors dataset + dataset_dict = { + "data": 
torch.Tensor([[1.0], [2.0], [0.3], [0.4]]), + "labels": [0, 0, 1, 1], + "weights": np.asarray([0.5, 1.5, 1.5, 0.5]), + } + dataset = DictDataset(dataset_dict) + + # save to temporary working directory + with tempfile.TemporaryDirectory() as tmpdir: + save_dataset(dataset=dataset, file_name=f'{tmpdir}/saved_dataset') + + # load and check it's ok + loaded = load_dataset(file_name=f'{tmpdir}/saved_dataset') + assert(torch.allclose(dataset['data'], loaded['data'])) + + # check using graph dataset + from mlcolvar.data.graph.atomic import AtomicNumberTable, Configuration + from mlcolvar.data.graph.utils import create_dataset_from_configurations + numbers = [8, 1, 1] + positions = np.array( + [[0.0, 0.0, 0.0], [0.07, 0.07, 0.0], [0.07, -0.07, 0.0]], + dtype=float + ) + cell = np.identity(3, dtype=float) * 0.2 + graph_labels = np.array([[1]]) + node_labels = np.array([[0], [1], [1]]) + z_table = AtomicNumberTable.from_zs(numbers) + + config = [Configuration( + atomic_numbers=numbers, + positions=positions, + cell=cell, + pbc=[True] * 3, + node_labels=node_labels, + graph_labels=graph_labels, + )] + dataset = create_dataset_from_configurations( + config, z_table, 0.1, show_progress=False + ) + + # save dataset + with tempfile.TemporaryDirectory() as tmpdir: + save_dataset(dataset=dataset, file_name=f'{tmpdir}/saved_dataset') + + # load and check it's ok + loaded = load_dataset(file_name=f'{tmpdir}/saved_dataset') + assert(torch.allclose(dataset['data_list'][0]['positions'], loaded['data_list'][0]['positions'])) + + # save to extxyz + with tempfile.TemporaryDirectory() as tmpdir: + save_dataset_configurations_as_extyz(dataset=dataset, file_name=f'{tmpdir}/saved_dataset') + +if __name__ == "__main__": + test_save_dataset() + + diff --git a/mlcolvar/explain/__init__.py b/mlcolvar/explain/__init__.py index 7cfef572..40afa6c4 100644 --- a/mlcolvar/explain/__init__.py +++ b/mlcolvar/explain/__init__.py @@ -1,7 +1,9 @@ __all__ = [ "sensitivity_analysis", "plot_sensitivity", + 
"graph_node_sensitivity" ] from .sensitivity import * +from .graph_sensitivity import * # from .lasso import * # lasso requires additional dependencies diff --git a/mlcolvar/explain/graph_sensitivity.py b/mlcolvar/explain/graph_sensitivity.py new file mode 100644 index 00000000..c6934c22 --- /dev/null +++ b/mlcolvar/explain/graph_sensitivity.py @@ -0,0 +1,277 @@ +import numpy as np +from typing import Dict +import torch + +from mlcolvar.data import DictModule +from mlcolvar.utils.plot import pbar +from mlcolvar.core.nn import BaseGNN + + +__all__ = ['graph_node_sensitivity'] + + +def graph_node_sensitivity( + model, + dataset, + component: int = 0, + device: str = 'cpu', + batch_size: int = None, + show_progress: bool = True +) -> Dict[str, np.ndarray]: + """Performs a sensitivity analysis on a GNN-based CV model using + partial derivatives w.r.t. nodes' positions. + This allows us to measure which atom is most important to the CV model. + + Parameters + ---------- + model: mlcolvar.cvs.BaseCV + Collective variable model based on GNN + dataset: mlcovar.data.DictDataset + Graph-based dataset on which to compute the sensitivity analysis + device: str + Name of the device on which to perform the computation + batch_size: + Batch size used for evaluating the CV + show_progress: bool + If to show the progress bar + + Returns + ------- + results: dictionary + Results of the sensitivity analysis, containing 'node_indices', + 'sensitivities', and 'sensitivities_components', ordered according to + the node indices. + + See also + -------- + mlcolvar.utils.explain.sensitivity_analysis + Perform the sensitivity analysis of a feedforward model. + """ + # check model is GNN-based + if not isinstance(model.nn, BaseGNN): + raise ValueError ( + "The CV model is not based on GNN! Maybe you should use the feedforward sensitivity_analysis from mlcolvar.utils.explain.sensitivity!" 
+ ) + + model = model.to(device) + + gradients = get_dataset_cv_gradients( + model=model, + dataset=dataset, + component=component, + batch_size=batch_size, + show_progress=show_progress, + progress_prefix='Getting gradients' + ) + sensitivities_components = np.linalg.norm(gradients, axis=-1) + + results = {} + results['atoms_list'] = np.array(dataset.metadata['used_names']) + results['node_labels'] = [str(a) for a in results['atoms_list']] + # results['node_labels_components'] = np.array([np.array(dataset.metadata['used_names'])[dataset[i]['data_list']['names_idx']] for i in range(len(dataset))]) + results['sensitivities'] = sensitivities_components.mean(axis=0) + results['sensitivities_components'] = sensitivities_components + + return results + +def get_dataset_cv_values( + model, + dataset, + batch_size: int = None, + show_progress: bool = True, + progress_prefix: str = 'Calculating CV values' +) -> np.ndarray: + """Gets the values of a CV model on a given dataset. + The calculation will run on the device where the model is on. 
+ + Parameters + ---------- + model: mlcolvar.cvs.BaseCV + Collective variable model + dataset: mlcolvar.data.DictDataset + Dataset on which to compute the CV values + batch_size: + Batch size used for evaluating the CV + show_progress: bool + Whether to show the progress bar + """ + datamodule = DictModule( + dataset=dataset, + lengths=(1.0,), + batch_size=batch_size, + random_split=False, + shuffle=False + ) + datamodule.setup() + + cv_values = [] + device = next(model.parameters()).device + + if show_progress: + items = pbar( + datamodule.train_dataloader(), + frequency=0.001, + prefix=progress_prefix + ) + else: + items = datamodule.train_dataloader() + + with torch.no_grad(): + for batchs in items: + outputs = model(batchs['data_list'].to(device).to_dict()) + outputs = outputs.cpu().numpy() + cv_values.append(outputs) + + return np.concatenate(cv_values) + + +def get_dataset_cv_gradients( + model, + dataset, + component: int = 0, + batch_size: int = None, + show_progress: bool = True, + progress_prefix: str = 'Calculating CV gradients' +) -> np.ndarray: + """Get gradients of a GNN-based CV w.r.t. node positions in a given dataset. + The calculation will run on the device on which the model resides. 
+ + Parameters + ---------- + model: mlcolvar.cvs.BaseCV + Collective variable model based on GNN + dataset: mlcolvar.data.DictDataset + Graph-based dataset on which to compute the gradients + component: int + Component of the CV to analyse + batch_size: + Batch size used for evaluating the CV + show_progress: bool + Whether to show the progress bar + """ + datamodule = DictModule( + dataset=dataset, + lengths=(1.0,), + batch_size=batch_size, + random_split=False, + shuffle=False + ) + datamodule.setup() + + cv_value_gradients = [] + device = next(model.parameters()).device + + if show_progress: + items = pbar( + datamodule.train_dataloader(), + frequency=0.001, + prefix=progress_prefix + ) + else: + items = datamodule.train_dataloader() + + for batchs in items: + batch_dict = batchs['data_list'].to(device) + batch_dict['positions'].requires_grad_(True) + cv_values = model(batch_dict) + cv_values = cv_values[:, component] + grad_outputs = [torch.ones_like(cv_values, device=device)] + gradients = torch.autograd.grad( + outputs=[cv_values], + inputs=[batch_dict['positions']], + grad_outputs=grad_outputs, + retain_graph=False, + create_graph=False, + ) + graph_sizes = batch_dict['ptr'][1:] - batch_dict['ptr'][:-1] + + # if isolated atoms were removed, this will give an inhomogeneous tensor! 
+ gradients = torch.split( + gradients[0].detach(), graph_sizes.cpu().numpy().tolist() + ) + + # here we ensure that all the gradients have the correct shape + # and that each entry is at the correct index accordingly + max_used_atoms = len(dataset.metadata['used_idx']) + for i,g in enumerate(gradients): + aux = torch.zeros((max_used_atoms, 3)) + # this populates the right entries according to the original indexing + aux[batch_dict[i]['names_idx'], :] = g + cv_value_gradients.extend(aux.unsqueeze(0).cpu().numpy()) + + return np.array(cv_value_gradients) + + +def test_get_cv_values_graph(): + import lightning + from mlcolvar.cvs import DeepTDA + from mlcolvar.core.nn.graph import SchNetModel + from mlcolvar.data.graph.utils import create_test_graph_input + + # create data, we need the dataset for sensitivity analysis later + dataset = create_test_graph_input(output_type='dataset', n_samples=50, n_states=2, n_atoms=3) + datamodule = DictModule(dataset=dataset, lengths=[0.8, 0.2], shuffle=[1, 0]) + + # create model + gnn_model = SchNetModel(n_out=1, cutoff=0.1, atomic_numbers=[8, 1]) + model = DeepTDA( + n_states=2, + n_cvs=1, + target_centers=[-5, 5], + target_sigmas=[0.2, 0.2], + model=gnn_model + ) + + # train model + trainer = lightning.Trainer( + accelerator="cpu", max_epochs=2, logger=False, enable_checkpointing=False, enable_model_summary=False + ) + trainer.fit(model, datamodule) + + # do analysis + cv_values = get_dataset_cv_values(model=model, dataset=dataset, batch_size=0) + + # print results + print(cv_values) + + assert (torch.allclose(model(dataset.get_graph_inputs()), torch.Tensor(cv_values))) + + + +def test_graph_sensitivity(): + import lightning + from mlcolvar.cvs import DeepTDA + from mlcolvar.core.nn.graph import SchNetModel + from mlcolvar.data.graph.utils import create_test_graph_input + + # create data, we need the dataset for sensitivity analysis later + dataset = create_test_graph_input(output_type='dataset', n_samples=100, n_states=2, 
n_atoms=3) + datamodule = DictModule(dataset=dataset, lengths=[0.8, 0.2], shuffle=[1, 0]) + + # create model + gnn_model = SchNetModel(n_out=1, cutoff=0.1, atomic_numbers=[8, 1]) + model = DeepTDA( + n_states=2, + n_cvs=1, + target_centers=[-5, 5], + target_sigmas=[0.2, 0.2], + model=gnn_model + ) + + # train model + trainer = lightning.Trainer( + accelerator="cpu", max_epochs=2, logger=False, enable_checkpointing=False, enable_model_summary=False + ) + trainer.fit(model, datamodule) + + # do analysis + test_sensitivity = graph_node_sensitivity(model=model, + dataset=dataset, + batch_size=0) + + # print results + print(test_sensitivity) + +if __name__ == '__main__': + test_graph_sensitivity() + test_get_cv_values_graph() \ No newline at end of file diff --git a/mlcolvar/explain/lasso.py b/mlcolvar/explain/lasso.py index 55d90248..ca44f837 100644 --- a/mlcolvar/explain/lasso.py +++ b/mlcolvar/explain/lasso.py @@ -3,7 +3,6 @@ import matplotlib import matplotlib.pyplot as plt -import mlcolvar.utils.plot try: import sklearn diff --git a/mlcolvar/explain/sensitivity.py b/mlcolvar/explain/sensitivity.py index a553a4ba..f6f37451 100644 --- a/mlcolvar/explain/sensitivity.py +++ b/mlcolvar/explain/sensitivity.py @@ -2,7 +2,6 @@ import torch from matplotlib import patches as mpatches import matplotlib.pyplot as plt -import mlcolvar.utils.plot __all__ = [ "sensitivity_analysis", "plot_sensitivity" ] diff --git a/mlcolvar/tests/data/Cu.xyz b/mlcolvar/tests/data/Cu.xyz new file mode 100644 index 00000000..111af6a3 --- /dev/null +++ b/mlcolvar/tests/data/Cu.xyz @@ -0,0 +1,54 @@ +16 +Lattice="7.15486134 0.0 0.0 0.0 3.57743067 0.0 0.0 0.0 7.15486134" Properties=species:S:1:pos:R:3 +Cu 0.0 0.0 0.0 +Cu 0.0 1.78871534 1.78871534 +Cu 1.78871534 0.0 1.78871534 +Cu 1.78871534 1.78871534 0.0 +Cu 0.0 0.0 3.57743067 +Cu 0.0 1.78871534 5.36614601 +Cu 1.78871534 0.0 5.36614601 +Cu 1.78871534 1.78871534 3.57743067 +Cu 3.57743067 0.0 0.0 +Cu 3.57743067 1.78871534 1.78871534 +Cu 5.36614601 0.0 
1.78871534 +Cu 5.36614601 1.78871534 0.0 +Cu 3.57743067 0.0 3.57743067 +Cu 3.57743067 1.78871534 5.36614601 +Cu 5.36614601 0.0 5.36614601 +Cu 5.36614601 1.78871534 3.57743067 +16 +Lattice="7.15486134 0.0 0.0 0.0 3.57743067 0.0 0.0 0.0 7.15486134" Properties=species:S:1:pos:R:3 +Cu 0.0 0.0 0.0 +Cu 0.0 1.78871534 1.78871534 +Cu 1.78871534 0.0 1.78871534 +Cu 1.78871534 1.78871534 0.0 +Cu 0.0 0.0 3.57743067 +Cu 0.0 1.78871534 5.36614601 +Cu 1.78871534 0.0 5.36614601 +Cu 1.78871534 1.78871534 3.57743067 +Cu 3.57743067 0.0 0.0 +Cu 3.57743067 1.78871534 1.78871534 +Cu 5.36614601 0.0 1.78871534 +Cu 5.36614601 1.78871534 0.0 +Cu 3.57743067 0.0 3.57743067 +Cu 3.57743067 1.78871534 5.36614601 +Cu 5.36614601 0.0 5.36614601 +Cu 5.36614601 1.78871534 3.57743067 +16 +Lattice="7.15486134 0.0 0.0 0.0 3.57743067 0.0 0.0 0.0 7.15486134" Properties=species:S:1:pos:R:3 +Cu 0.0 0.0 0.0 +Cu 0.0 1.78871534 1.78871534 +Cu 1.78871534 0.0 1.78871534 +Cu 1.78871534 1.78871534 0.0 +Cu 0.0 0.0 3.57743067 +Cu 0.0 1.78871534 5.36614601 +Cu 1.78871534 0.0 5.36614601 +Cu 1.78871534 1.78871534 3.57743067 +Cu 3.57743067 0.0 0.0 +Cu 3.57743067 1.78871534 1.78871534 +Cu 5.36614601 0.0 1.78871534 +Cu 5.36614601 1.78871534 0.0 +Cu 3.57743067 0.0 3.57743067 +Cu 3.57743067 1.78871534 5.36614601 +Cu 5.36614601 0.0 5.36614601 +Cu 5.36614601 1.78871534 3.57743067 \ No newline at end of file diff --git a/mlcolvar/tests/data/Cu_top.pdb b/mlcolvar/tests/data/Cu_top.pdb new file mode 100644 index 00000000..9e7a14a1 --- /dev/null +++ b/mlcolvar/tests/data/Cu_top.pdb @@ -0,0 +1,19 @@ +CRYST1 7.155 3.577 7.155 90.00 90.00 90.00 P 1 +MODEL 1 +ATOM 1 Cu MOL 1 0.000 0.000 0.000 1.00 0.00 CU +ATOM 2 Cu MOL 1 0.000 1.789 1.789 1.00 0.00 CU +ATOM 3 Cu MOL 1 1.789 0.000 1.789 1.00 0.00 CU +ATOM 4 Cu MOL 1 1.789 1.789 0.000 1.00 0.00 CU +ATOM 5 Cu MOL 1 0.000 0.000 3.577 1.00 0.00 CU +ATOM 6 Cu MOL 1 0.000 1.789 5.366 1.00 0.00 CU +ATOM 7 Cu MOL 1 1.789 0.000 5.366 1.00 0.00 CU +ATOM 8 Cu MOL 1 1.789 1.789 3.577 1.00 0.00 
CU +ATOM 9 Cu MOL 1 3.577 0.000 0.000 1.00 0.00 CU +ATOM 10 Cu MOL 1 3.577 1.789 1.789 1.00 0.00 CU +ATOM 11 Cu MOL 1 5.366 0.000 1.789 1.00 0.00 CU +ATOM 12 Cu MOL 1 5.366 1.789 0.000 1.00 0.00 CU +ATOM 13 Cu MOL 1 3.577 0.000 3.577 1.00 0.00 CU +ATOM 14 Cu MOL 1 3.577 1.789 5.366 1.00 0.00 CU +ATOM 15 Cu MOL 1 5.366 0.000 5.366 1.00 0.00 CU +ATOM 16 Cu MOL 1 5.366 1.789 3.577 1.00 0.00 CU +ENDMDL diff --git a/mlcolvar/tests/data/p.dcd b/mlcolvar/tests/data/p.dcd new file mode 100644 index 00000000..9ac637e3 Binary files /dev/null and b/mlcolvar/tests/data/p.dcd differ diff --git a/mlcolvar/tests/data/p.pdb b/mlcolvar/tests/data/p.pdb new file mode 100644 index 00000000..9928ce81 --- /dev/null +++ b/mlcolvar/tests/data/p.pdb @@ -0,0 +1,21 @@ +CRYST1 100.000 100.000 100.000 90.00 90.00 90.00 P 1 1 +ATOM 1 C UNL X 1 -2.477 -2.092 -0.388 1.00 0.00 C +ATOM 2 C UNL X 1 -3.520 -1.057 -0.204 1.00 0.00 C +ATOM 3 C UNL X 1 -2.495 -2.603 1.038 1.00 0.00 C +ATOM 4 H UNL X 1 -2.456 -1.819 1.767 1.00 0.00 H +ATOM 5 H UNL X 1 -3.448 -3.157 1.109 1.00 0.00 H +ATOM 6 H UNL X 1 -1.674 -3.291 1.255 1.00 0.00 H +ATOM 7 C UNL X 1 -2.786 -3.294 -1.306 1.00 0.00 C +ATOM 8 H UNL X 1 -2.634 -2.975 -2.351 1.00 0.00 H +ATOM 9 H UNL X 1 -2.089 -4.135 -1.164 1.00 0.00 H +ATOM 10 H UNL X 1 -3.798 -3.583 -1.079 1.00 0.00 H +ATOM 11 C UNL X 1 -1.151 -1.437 -0.717 1.00 0.00 C +ATOM 12 H UNL X 1 -0.988 -0.621 0.002 1.00 0.00 H +ATOM 13 H UNL X 1 -0.343 -2.138 -0.507 1.00 0.00 H +ATOM 14 H UNL X 1 -1.088 -1.151 -1.779 1.00 0.00 H +ATOM 15 C UNL X 1 -4.964 -1.269 -0.340 1.00 0.00 C +ATOM 16 H UNL X 1 -5.253 -1.213 -1.409 1.00 0.00 H +ATOM 17 H UNL X 1 -5.103 -2.218 0.074 1.00 0.00 H +ATOM 18 H UNL X 1 -5.607 -0.523 0.107 1.00 0.00 H +ATOM 19 F UNL X 1 -3.038 0.027 0.136 1.00 0.00 F +END diff --git a/mlcolvar/tests/data/r.dcd b/mlcolvar/tests/data/r.dcd new file mode 100644 index 00000000..c302acd2 Binary files /dev/null and b/mlcolvar/tests/data/r.dcd differ diff --git a/mlcolvar/tests/data/r.pdb 
b/mlcolvar/tests/data/r.pdb new file mode 100644 index 00000000..75656832 --- /dev/null +++ b/mlcolvar/tests/data/r.pdb @@ -0,0 +1,21 @@ +CRYST1 100.000 100.000 100.000 90.00 90.00 90.00 P 1 1 +ATOM 1 C UNL X 1 -2.394 -1.013 0.390 1.00 0.00 C +ATOM 2 C UNL X 1 -2.588 -1.774 -0.881 1.00 0.00 C +ATOM 3 C UNL X 1 -2.555 -1.607 1.679 1.00 0.00 C +ATOM 4 H UNL X 1 -2.178 -2.584 1.842 1.00 0.00 H +ATOM 5 H UNL X 1 -2.075 -0.998 2.471 1.00 0.00 H +ATOM 6 H UNL X 1 -3.633 -1.736 1.862 1.00 0.00 H +ATOM 7 C UNL X 1 -1.780 0.373 0.296 1.00 0.00 C +ATOM 8 H UNL X 1 -0.759 0.225 0.823 1.00 0.00 H +ATOM 9 H UNL X 1 -1.819 0.834 -0.686 1.00 0.00 H +ATOM 10 H UNL X 1 -2.358 0.965 1.025 1.00 0.00 H +ATOM 11 C UNL X 1 -1.348 -1.777 -1.672 1.00 0.00 C +ATOM 12 H UNL X 1 -1.110 -0.748 -2.010 1.00 0.00 H +ATOM 13 H UNL X 1 -0.424 -2.178 -1.122 1.00 0.00 H +ATOM 14 H UNL X 1 -1.446 -2.390 -2.576 1.00 0.00 H +ATOM 15 C UNL X 1 -3.773 -1.199 -1.601 1.00 0.00 C +ATOM 16 H UNL X 1 -3.816 -1.599 -2.581 1.00 0.00 H +ATOM 17 H UNL X 1 -4.663 -1.441 -1.003 1.00 0.00 H +ATOM 18 H UNL X 1 -3.757 -0.102 -1.693 1.00 0.00 H +ATOM 19 F UNL X 1 -2.878 -3.051 -0.412 1.00 0.00 F +END diff --git a/mlcolvar/tests/test_core_nn_graph.py b/mlcolvar/tests/test_core_nn_graph.py new file mode 100644 index 00000000..821acc8d --- /dev/null +++ b/mlcolvar/tests/test_core_nn_graph.py @@ -0,0 +1,9 @@ +from mlcolvar.core.nn.graph.gnn import test_get_edge_vectors_and_lengths +from mlcolvar.core.nn.graph.radial import test_bessel_basis, test_gaussian_basis, test_polynomial_cutoff, test_radial_embedding_block + +if __name__ == "__main__": + test_get_edge_vectors_and_lengths() + test_bessel_basis() + test_gaussian_basis() + test_polynomial_cutoff() + test_radial_embedding_block() \ No newline at end of file diff --git a/mlcolvar/tests/test_core_nn_graph_gvp.py b/mlcolvar/tests/test_core_nn_graph_gvp.py new file mode 100644 index 00000000..dcbc8176 --- /dev/null +++ b/mlcolvar/tests/test_core_nn_graph_gvp.py @@ -0,0 +1,4 
@@ +from mlcolvar.core.nn.graph.gvp import test_gvp + +if __name__ == "__main__": + test_gvp() \ No newline at end of file diff --git a/mlcolvar/tests/test_core_nn_graph_schnet.py b/mlcolvar/tests/test_core_nn_graph_schnet.py new file mode 100644 index 00000000..2d83fdbf --- /dev/null +++ b/mlcolvar/tests/test_core_nn_graph_schnet.py @@ -0,0 +1,5 @@ +from mlcolvar.core.nn.graph.schnet import test_schnet_1, test_schnet_2 + +if __name__ == "__main__": + test_schnet_1() + test_schnet_2() \ No newline at end of file diff --git a/mlcolvar/tests/test_cvs.py b/mlcolvar/tests/test_cvs.py index 40d1425a..46f44d97 100644 --- a/mlcolvar/tests/test_cvs.py +++ b/mlcolvar/tests/test_cvs.py @@ -64,10 +64,10 @@ def dataset(): # ============================================================================= @pytest.mark.parametrize("cv_model", [ - mlcolvar.cvs.DeepLDA(layers=LAYERS, n_states=N_STATES), - mlcolvar.cvs.DeepTDA(n_states=N_STATES, n_cvs=1, target_centers=[-1., 1.], target_sigmas=[0.1, 0.1], layers=LAYERS), - mlcolvar.cvs.RegressionCV(layers=LAYERS), - mlcolvar.cvs.DeepTICA(layers=LAYERS, n_cvs=1), + mlcolvar.cvs.DeepLDA(model=LAYERS, n_states=N_STATES), + mlcolvar.cvs.DeepTDA(n_states=N_STATES, n_cvs=1, target_centers=[-1., 1.], target_sigmas=[0.1, 0.1], model=LAYERS), + mlcolvar.cvs.RegressionCV(model=LAYERS), + mlcolvar.cvs.DeepTICA(model=LAYERS, n_cvs=1), mlcolvar.cvs.AutoEncoderCV(encoder_layers=LAYERS), mlcolvar.cvs.VariationalAutoEncoderCV(n_cvs=1, encoder_layers=LAYERS[:-1]), ]) @@ -113,7 +113,7 @@ def test_lr_scheduler(): initial_lr = 1e-3 options = {'optimizer' : {'lr' : initial_lr}, 'lr_scheduler' : { 'scheduler' : lr_scheduler, 'gamma' : 0.9999}} - model = mlcolvar.cvs.RegressionCV(layers=[2,5,1], options=options) + model = mlcolvar.cvs.RegressionCV(model=[2,5,1], options=options) # check training and lr scheduling trainer = lightning.Trainer(max_epochs=10, diff --git a/mlcolvar/tests/test_cvs_committor.py b/mlcolvar/tests/test_cvs_committor.py index 
ac941394..aa658601 100644 --- a/mlcolvar/tests/test_cvs_committor.py +++ b/mlcolvar/tests/test_cvs_committor.py @@ -1,5 +1,10 @@ -from mlcolvar.cvs.committor.committor import test_committor, test_committor_with_derivatives +from mlcolvar.cvs.committor.committor import test_committor_1, test_committor_2 , test_committor_with_derivatives +from mlcolvar.cvs.committor.utils import test_compute_committor_weights, test_Kolmogorov_bias + if __name__ == "__main__": - test_committor() - test_committor_with_derivatives() \ No newline at end of file + test_committor_1() + test_committor_2() + test_committor_with_derivatives() + test_Kolmogorov_bias() + test_compute_committor_weights() \ No newline at end of file diff --git a/mlcolvar/tests/test_cvs_multitask_multitask.py b/mlcolvar/tests/test_cvs_multitask_multitask.py index 67a23087..75106305 100644 --- a/mlcolvar/tests/test_cvs_multitask_multitask.py +++ b/mlcolvar/tests/test_cvs_multitask_multitask.py @@ -21,6 +21,7 @@ import lightning import torch +from mlcolvar.core.nn import FeedForward from mlcolvar.core.loss import TDALoss, FisherDiscriminantLoss, AutocorrelationLoss from mlcolvar.cvs.cv import BaseCV from mlcolvar.cvs.multitask.multitask import MultiTaskCV @@ -62,11 +63,13 @@ def forward(self, data, data_lag=None, **kwargs): class MockCV(BaseCV, lightning.LightningModule): """Mock CV for mock testing.""" - BLOCKS = [] + DEFAULT_BLOCKS = [] + MODEL_BLOCKS = [] def __init__(self, in_features=N_DESCRIPTORS, out_features=N_CVS): """Constructor.""" - super().__init__(in_features=in_features, out_features=out_features) + model = FeedForward(layers=[in_features, in_features]) + super().__init__(model=model) self.loss_fn = MockAuxLoss(in_features, out_features) def training_step(self, train_batch, batch_idx): @@ -129,7 +132,7 @@ def create_cv(cv_name, n_descriptors=N_DESCRIPTORS, n_cvs=N_CVS): n_cvs=n_cvs, encoder_layers=[n_descriptors, 10] ) elif cv_name == "deeptica": - returned = "time-lagged", 
DeepTICA(layers=[n_descriptors, 10, n_cvs]) + returned = "time-lagged", DeepTICA(model=[n_descriptors, 10, n_cvs]) else: raise ValueError("Unrecognized cv_name.") diff --git a/mlcolvar/tests/test_cvs_supervised_tda.py b/mlcolvar/tests/test_cvs_supervised_tda.py index c1ec6bac..b3300b7f 100644 --- a/mlcolvar/tests/test_cvs_supervised_tda.py +++ b/mlcolvar/tests/test_cvs_supervised_tda.py @@ -1,4 +1,6 @@ from mlcolvar.cvs.supervised.deeptda import test_deeptda_cv +from mlcolvar.core.loss.tda_loss import test_tda_loss if __name__ == "__main__": test_deeptda_cv() + test_tda_loss() diff --git a/mlcolvar/tests/test_data_graph.py b/mlcolvar/tests/test_data_graph.py new file mode 100644 index 00000000..9c18f8c7 --- /dev/null +++ b/mlcolvar/tests/test_data_graph.py @@ -0,0 +1,6 @@ +from mlcolvar.data.graph.atomic import test_atomic_number_table +from mlcolvar.data.graph.neighborhood import test_get_neighborhood + +if __name__ == '__main__': + test_atomic_number_table() + test_get_neighborhood() \ No newline at end of file diff --git a/mlcolvar/tests/test_data_graph_utils.py b/mlcolvar/tests/test_data_graph_utils.py new file mode 100644 index 00000000..f765c6a9 --- /dev/null +++ b/mlcolvar/tests/test_data_graph_utils.py @@ -0,0 +1,6 @@ +from mlcolvar.data.graph.utils import test_from_configuration, test_from_configurations, test_to_one_hot + +if __name__ == "__main__": + test_to_one_hot() + test_from_configuration() + test_from_configurations() \ No newline at end of file diff --git a/mlcolvar/tests/test_data_utils.py b/mlcolvar/tests/test_data_utils.py new file mode 100644 index 00000000..6ef6eb43 --- /dev/null +++ b/mlcolvar/tests/test_data_utils.py @@ -0,0 +1,4 @@ +from mlcolvar.data.utils import test_save_dataset + +if __name__ == "__main__": + test_save_dataset() \ No newline at end of file diff --git a/mlcolvar/tests/test_explain_sensitivity.py b/mlcolvar/tests/test_explain_sensitivity.py index 0b80078d..69d4edef 100644 --- a/mlcolvar/tests/test_explain_sensitivity.py +++ 
b/mlcolvar/tests/test_explain_sensitivity.py @@ -1,6 +1,9 @@ import pytest from mlcolvar.explain.sensitivity import test_sensitivity_analysis +from mlcolvar.explain.graph_sensitivity import test_graph_sensitivity, test_get_cv_values_graph if __name__ == "__main__": test_sensitivity_analysis() + test_graph_sensitivity() + test_get_cv_values_graph() diff --git a/mlcolvar/tests/test_utils_io.py b/mlcolvar/tests/test_utils_io.py index 629ff095..173f49b3 100644 --- a/mlcolvar/tests/test_utils_io.py +++ b/mlcolvar/tests/test_utils_io.py @@ -2,6 +2,9 @@ import urllib from mlcolvar.utils.io import load_dataframe from mlcolvar.utils.io import test_datasetFromFile +from mlcolvar.utils.io import test_datasesetFromTrajectories +from mlcolvar.utils.io import test_create_dataset_from_trajectories +from mlcolvar.utils.io import test_dataset_from_xyz example_files = { "str": "mlcolvar/tests/data/state_A.dat", @@ -23,6 +26,66 @@ def test_loadDataframe(file_type): df = load_dataframe(filename, start=0, stop=10, stride=1) +inputs = [""" +CRYST1 2.000 2.000 2.000 90.00 90.00 90.00 P 1 1 +ATOM 1 OH2 TIP3W 1 0.000 0.000 0.000 1.00 0.00 WT1 O +ATOM 2 H1 TIP3W 1 0.700 0.700 0.000 1.00 0.00 WT1 H +ATOM 3 H2 TIP3W 1 0.700 -0.700 0.000 1.00 0.00 WT1 H +ENDMODEL +ATOM 1 OH2 TIP3W 1 0.000 0.000 0.000 1.00 0.00 WT1 O +ATOM 2 H1 TIP3W 1 0.700 0.700 0.000 1.00 0.00 WT1 H +ATOM 3 H2 TIP3W 1 0.700 -0.700 0.000 1.00 0.00 WT1 H +END +""", +""" +CRYST1 2.000 2.000 2.000 90.00 90.00 90.00 P 1 1 +ATOM 1 OH2 TIP3W 1 0.000 0.000 0.000 1.00 0.00 WT1 O +ATOM 2 H1 TIP3W 1 0.700 0.700 0.000 1.00 0.00 WT1 H +ATOM 3 H2 TIP3W 1 0.700 -0.700 0.000 1.00 0.00 WT1 H +ATOM 4 OH2 XXXXW 2 0.000 0.000 0.000 1.00 0.00 WT1 O +ATOM 5 H1 XXXXW 2 0.300 0.300 0.000 1.00 0.00 WT1 H +ATOM 6 H2 XXXXW 2 0.300 -0.300 0.000 1.00 0.00 WT1 H +ENDMODEL +ATOM 1 OH2 TIP3W 1 0.000 0.000 0.000 1.00 0.00 WT1 O +ATOM 2 H1 TIP3W 1 0.700 0.700 0.000 1.00 0.00 WT1 H +ATOM 3 H2 TIP3W 1 0.700 -0.700 0.000 1.00 0.00 WT1 H +ATOM 4 OH2 XXXXW 2 
0.000 0.000 0.000 1.00 0.00 WT1 O +ATOM 5 H1 XXXXW 2 0.300 0.300 0.000 1.00 0.00 WT1 H +ATOM 6 H2 XXXXW 2 0.300 -0.300 0.000 1.00 0.00 WT1 H +END +""", +""" +CRYST1 2.000 2.000 2.000 90.00 90.00 90.00 P 1 1 +ATOM 1 OH2 XXXXW 1 0.000 0.000 0.000 1.00 0.00 WT1 O +ATOM 2 OH2 TIP3W 2 0.000 0.000 0.000 1.00 0.00 WT1 O +ATOM 3 H1 XXXXW 1 0.300 0.300 0.000 1.00 0.00 WT1 H +ATOM 4 H1 TIP3W 2 0.700 0.700 0.000 1.00 0.00 WT1 H +ATOM 5 H2 XXXXW 1 0.300 -0.300 0.000 1.00 0.00 WT1 H +ATOM 6 H2 TIP3W 2 0.700 -0.700 0.000 1.00 0.00 WT1 H +ENDMODEL +ATOM 1 OH2 XXXXW 1 0.000 0.000 0.000 1.00 0.00 WT1 O +ATOM 2 OH2 TIP3W 2 0.000 0.000 0.000 1.00 0.00 WT1 O +ATOM 3 H1 XXXXW 1 0.300 0.300 0.000 1.00 0.00 WT1 H +ATOM 4 H1 TIP3W 2 0.700 0.700 0.000 1.00 0.00 WT1 H +ATOM 5 H2 XXXXW 1 0.300 -0.300 0.000 1.00 0.00 WT1 H +ATOM 6 H2 TIP3W 2 0.700 -0.700 0.000 1.00 0.00 WT1 H +END +""" +] + +@pytest.mark.parametrize("text,selection", + [(inputs[0], None), + (inputs[1], 'not resname XXXX'), + (inputs[2], 'not resname XXXX') + ] + ) +# @pytest.mark.parametrize("text", inputs) +def test_dataset_from_trajectories(text, selection): + print(selection) + test_create_dataset_from_trajectories(text, selection) + + if __name__ == "__main__": - # test_loadDataframe() + test_dataset_from_xyz() test_datasetFromFile() + test_datasesetFromTrajectories() \ No newline at end of file diff --git a/mlcolvar/utils/io.py b/mlcolvar/utils/io.py index 7bbc96db..608f29f5 100644 --- a/mlcolvar/utils/io.py +++ b/mlcolvar/utils/io.py @@ -9,10 +9,24 @@ import numpy as np import torch import os +import tempfile import urllib.request -from typing import Union +from typing import Union, List, Tuple +import mdtraj +from warnings import warn + +# Import ASE for xyz to pdb conversion. 
+try: + from ase.io import read, write + from ase import Atoms +except ImportError as e: + raise ImportError("ASE is required for xyz to pdb conversion.", e) + from mlcolvar.data import DictDataset +from mlcolvar.data.graph.atomic import AtomicNumberTable, Configuration, Configurations +from mlcolvar.data.graph.utils import create_dataset_from_configurations + __all__ = ["load_dataframe", "plumed_to_pandas", "create_dataset_from_files"] @@ -117,7 +131,11 @@ def load_dataframe( if "http" in filename: download = True url = filename - filename = "tmp_" + filename.split("/")[-1] + if delete_download: + temp = tempfile.NamedTemporaryFile() + filename = temp.name + else: + filename = "tmp_" + filename.split("/")[-1] urllib.request.urlretrieve(url, filename) # check if file is in PLUMED format @@ -137,7 +155,7 @@ def load_dataframe( # delete temporary data if necessary if download: if delete_download: - os.remove(filename) + temp.close() else: print(f"downloaded file ({url}) saved as ({filename}).") @@ -251,13 +269,371 @@ def create_dataset_from_files( dictionary = {"data": torch.Tensor(df_data.values)} if create_labels: dictionary["labels"] = torch.Tensor(df["labels"].values) - dataset = DictDataset(dictionary, feature_names=df_data.columns.values) + dataset = DictDataset(dictionary, feature_names=df_data.columns.values, data_type='descriptors') if return_dataframe: return dataset, df else: return dataset +def create_pdb_from_xyz(input_filename: str, output_filename: str) -> str: + """ + Convert the first frame of an XYZ file into a PDB file using ASE. + This pdb file can then serve as the topology for MDTraj. + + Parameters: + input_filename: Path to the input .xyz file. + output_filename: Path to the output .pdb file. + + Returns: + The path to the generated PDB file. 
+ """ + atoms: Atoms = read(input_filename, index=0) + + if (atoms.cell == 0).all(): + warn("A topology file was generated from the xyz trajectory file but no cell information was provided!") + if not atoms.pbc.any(): + warn("A topology file was generated from the xyz trajectory file but no PBC information was provided!") + elif not atoms.pbc.all(): + warn( f"Partial PBC are not supported! The provided input has pbc {atoms.pbc}") + + write(output_filename, atoms, format='proteindatabank') + return output_filename + + + +def create_dataset_from_trajectories( + trajectories: Union[List[str], str], + top: Union[List[str], str, None], + cutoff: float, + buffer: float = 0.0, + z_table: AtomicNumberTable = None, + load_args: list = None, + folder: str = None, + labels: list = None, + system_selection: str = None, + environment_selection: str = None, + return_trajectories: bool = False, + remove_isolated_nodes: bool = True, + show_progress: bool = True, + save_names=True, + lengths_conversion : float = 10.0, +) -> Union[ + DictDataset, + Tuple[ + DictDataset, + Union[List[List[mdtraj.Trajectory]], List[mdtraj.Trajectory]] + ] +]: + """ + Create a dataset from a set of trajectory files. + + Parameters + ---------- + trajectories: Union[List[str], str] + Paths to the trajectory files. + top: Union[List[str], str, None] + Paths to the topology files. For .xyz trajectories only, it can be set to None or left empty, in which case a topology file is generated automatically. + cutoff: float (units: Ang) + The graph cutoff radius. + buffer: float + Buffer size used in finding active environment atoms. + z_table: mlcolvar.data.graph.atomic.AtomicNumberTable + The atomic number table used to build the node attributes. If not + given, it will be created from the given trajectories. + load_args: list[dict], optional + List of dictionaries with the loading options for each file (keys: start, stop, stride), by default None + folder: str + Common path for the files to be imported. 
If set, filenames become + `folder/file_name`. 
+ labels: list + List of labels to be assigned to the given files, by default None. + If None, it simply enumerates the files. + system_selection: str + MDTraj style atom selections [1] of the system atoms. If given, only + selected atoms will be loaded from the trajectories. This option may + increase the speed of building graphs. + environment_selection: str + MDTraj style atom selections [1] of the environment atoms. If given, + only the system atoms and [the environment atoms within the cutoff + radius of the system atoms] will be kept in the graph. + return_trajectories: bool + Whether to also return the loaded trajectory objects. + remove_isolated_nodes: bool + Whether to remove isolated nodes from the dataset. + show_progress: bool + Whether to show the progress bar. + save_names: bool + Whether to save the atom names from the topology file, by default True + lengths_conversion: float, + Conversion factor for length units, by default 10. + MDTraj uses nanometers, the default converts to Angstroms. + + Returns + ------- + dataset: DictDataset + The graph dataset. + trajectories: Union[List[List[mdtraj.Trajectory]], List[mdtraj.Trajectory]] + The loaded trajectory objects. + + Notes + ----- + The logic behind this method is as follows: + 1. If only `system_selection` is given, the method will only load atoms + selected by this selection, from the trajectories. + 2. If both `system_selection` and `environment_selection` are given, + the method will load the atoms selected by both selections, but will + build graphs using [the system atoms] and [the environment atoms within + the cutoff radius of the system atoms]. + + References + ---------- + .. [1] https://www.mdtraj.org/1.9.8.dev0/atom_selection.html + """ + + # check if using truncated graph + if environment_selection is not None: + assert system_selection is not None, ( + 'the `environment_selection` argument requires the ' + '`system_selection` argument to be defined!' 
+ ) + selection = '({:s}) or ({:s})'.format( + system_selection, environment_selection + ) + elif system_selection is not None: + selection = system_selection + else: + selection = None + + if environment_selection is None: + assert buffer == 0, ( + 'No `environment_selection` given! Cannot define buffer size!' + ) + + # ensure trajectories is a list, so that the length checks below count files and not characters + if isinstance(trajectories, str): + trajectories = [trajectories] + + # initialize simple labels if not provided + if labels is None: + labels = [i for i in range(len(trajectories))] + else: + assert len(labels) == len(trajectories), ( + "Number of labels and trajectories must be the same!" + ) + + # check topologies if given + if top is not None: + assert len(trajectories) == len(top) or len(top)==1 or isinstance(top, str), ( + 'Either a single topology file or as many as the trajectory files must be provided!' + ) + + # --- Handle topologies input --- + # Allow top to be None or empty. In that case, create a list of empty strings. + if isinstance(top, str): + top = [top for _ in trajectories] + if top is None or (isinstance(top, list) and len(top) == 0): + top = ["" for _ in trajectories] + elif len(top) == 1 and len(trajectories) > 1: + top = [top[0] for _ in trajectories] + + # For each trajectory file (and its associated topology), if the trajectory file + # has a ".xyz" extension and no topology is provided, convert it. 
+ for i in range(len(trajectories)): + if folder is not None: + trajectories[i] = os.path.join(folder, trajectories[i]) + if top[i]: + top[i] = os.path.join(folder, top[i]) + assert isinstance(trajectories[i], str) + _, ext = os.path.splitext(trajectories[i]) + if (ext.lower() == ".xyz") and (not top[i]): + pdb_file = trajectories[i].replace('.xyz', '_top.pdb') + top[i] = create_pdb_from_xyz(trajectories[i], pdb_file) + + # check if per file args are given, otherwise set to {} + if load_args is not None: + if (not isinstance(load_args, list)) or (len(trajectories) != len(load_args)): + raise TypeError( + "load_args should be a list of dictionaries of arguments of same length as trajectories." + ) + + + # load topologies and trajectories + topologies = [] + trajectories_in_memory = [] + for i in range(len(trajectories)): + # load trajectory + traj = mdtraj.load(trajectories[i], top=top[i]) + traj.top = mdtraj.core.trajectory.load_topology(top[i]) + + # mdtraj does not load cell info from xyz, so we use ASE and add it + _, ext = os.path.splitext(trajectories[i]) + if (ext.lower() == ".xyz"): + ase_atoms = read(trajectories[i], index=':') + ase_cells = np.array([a.get_cell().array for a in ase_atoms], dtype=float) + # the pdb for the topology are in nm, ase work in A so we need to scale it + traj.unitcell_vectors = ase_cells/10 + + if selection is not None: + subset = traj.top.select(selection) + assert len(subset) > 0, ( + 'No atoms will be selected with selection string ' + + '"{:s}"!'.format(selection) + ) + traj = traj.atom_slice(subset) + trajectories_in_memory.append(traj) + topologies.append(traj.top) + + if z_table is None: + z_table = _z_table_from_top(topologies) + + if save_names: + atom_names = _names_from_top(topologies) + else: + atom_names = None + + # create configurations objects from trajectories + configurations = [] + for i in range(len(trajectories_in_memory)): + configuration = _configures_from_trajectory( + trajectory=trajectories_in_memory[i], 
+ label=labels[i], + system_selection=system_selection, + environment_selection=environment_selection, + start=load_args[i]['start'] if load_args is not None else 0, + stop=load_args[i]['stop'] if load_args is not None else None, + stride=load_args[i]['stride'] if load_args is not None else 1, + lengths_conversion=lengths_conversion, + ) + configurations.extend(configuration) + + # convert configurations into DictDataset + dataset = create_dataset_from_configurations( + config=configurations, + z_table=z_table, + cutoff=cutoff, + buffer=buffer, + atom_names=atom_names, + remove_isolated_nodes=remove_isolated_nodes, + show_progress=show_progress + ) + + if return_trajectories: + return dataset, trajectories_in_memory + else: + return dataset + + +def _names_from_top(top: List[mdtraj.Topology] ): + it = iter(top) + atom_names = list(next(it).atoms) + if not all([atom_names == list(n.atoms) for n in it]): + raise ValueError( + "The atoms names or their order are different in the topology files. Check or deactivate save_names" + ) + + return atom_names + + +def _z_table_from_top( + top: List[mdtraj.Topology] +) -> AtomicNumberTable: + """ + Create an atomic number table from the topologies. + + Parameters + ---------- + top: List[mdtraj.Topology] + The topology objects. + """ + atomic_numbers = [] + for t in top: + atomic_numbers.extend([a.element.number for a in t.atoms]) + # atomic_numbers = np.array(atomic_numbers, dtype=int) + z_table = AtomicNumberTable.from_zs(atomic_numbers) + return z_table + + +def _configures_from_trajectory( + trajectory: mdtraj.Trajectory, + label: int = None, + system_selection: str = None, + environment_selection: str = None, + start: int = 0, + stop: int = None, + stride: int = 1, + lengths_conversion : float = 10.0) -> Configurations: + """ + Create configurations from one trajectory. + + Parameters + ---------- + trajectory: mdtraj.Trajectory + The MDTraj Trajectory object. + label: int + The graph label. 
+ system_selection: str + MDTraj style atom selections of the system atoms. If given, only + selected atoms will be loaded from the trajectories. This option may + increase the speed of building graphs. + environment_selection: str + MDTraj style atom selections of the environment atoms. If given, + only the system atoms and [the environment atoms within the cutoff + radius of the system atoms] will be kept in the graph. + lengths_conversion: float, + Conversion factor for length units, by default 10. + MDTraj uses nanometers, the default sends to Angstroms. + """ + if label is not None: + label = np.array([[label]]) + + if system_selection is not None and environment_selection is not None: + system_atoms = trajectory.top.select(system_selection) + assert len(system_atoms) > 0, ( + 'No atoms will be selected with `system_selection`: ' + + '"{:s}"!'.format(system_selection) + ) + environment_atoms = trajectory.top.select(environment_selection) + assert len(environment_atoms) > 0, ( + 'No atoms will be selected with `environment_selection`: ' + + '"{:s}"!'.format(environment_selection) + ) + else: + system_atoms = None + environment_atoms = None + + atomic_numbers = [a.element.number for a in trajectory.top.atoms] + if trajectory.unitcell_vectors is not None: + pbc = [True] * 3 + cell = trajectory.unitcell_vectors + else: + pbc = [False] * 3 + cell = [None] * len(trajectory) + + if stop is None: + stop = len(trajectory) + + configurations = [] + + for i in range(start,stop,stride): + configuration = Configuration( + atomic_numbers=atomic_numbers, + positions=trajectory.xyz[i] * lengths_conversion, + cell=cell[i] * lengths_conversion, + pbc=pbc, + graph_labels=label, + node_labels=None, # TODO: Add supports for per-node labels. 
+ system=system_atoms, + environment=environment_atoms + ) + configurations.append(configuration) + + return configurations + + +# ================================================================================================= +# ============================================= TESTS ============================================= +# ================================================================================================= def test_datasetFromFile(): # Test with unlabeled dataset @@ -316,6 +692,230 @@ def test_modifier(x): stride=1, ) +def test_datasesetFromTrajectories(): + create_dataset_from_trajectories( + trajectories=['r.dcd', + 'p.dcd'], + top=['r.pdb', + 'p.pdb'], + folder="mlcolvar/tests/data", + cutoff=8.0, # Ang + labels=None, + system_selection='all and not type H', + show_progress=False, + ) + + dataset = create_dataset_from_trajectories( + trajectories=['r.dcd', + 'p.dcd'], + top=['r.pdb', + 'p.pdb'], + folder="mlcolvar/tests/data", + cutoff=8.0, # Ang + labels=[0,1], + system_selection='all and not type H', + show_progress=False, + load_args=[{'start' : 0, 'stop' : 10, 'stride' : 1}, + {'start' : 6, 'stop' : 10, 'stride' : 2}] + ) + assert(len(dataset)==12) + + dataset = create_dataset_from_trajectories( + trajectories=['r.dcd', 'r.dcd', + 'p.dcd', 'p.dcd'], + top=['r.pdb', 'r.pdb', + 'p.pdb', 'p.pdb'], + folder="mlcolvar/tests/data", + cutoff=8.0, # Ang + labels=[0,1,2,3], + system_selection='all and not type H', + show_progress=False, + load_args=[{'start' : 0, 'stop' : 10, 'stride' : 1}, {'start' : 0, 'stop' : 10, 'stride' : 1}, + {'start' : 6, 'stop' : 10, 'stride' : 2}, {'start' : 6, 'stop' : 10, 'stride' : 2}] + ) + assert(len(dataset)==24) + + +def test_create_dataset_from_trajectories(text: str = """ +CRYST1 2.000 2.000 2.000 90.00 90.00 90.00 P 1 1 +ATOM 1 OH2 TIP3W 1 0.000 0.000 0.000 1.00 0.00 WT1 O +ATOM 2 H1 TIP3W 1 0.700 0.700 0.000 1.00 0.00 WT1 H +ATOM 3 H2 TIP3W 1 0.700 -0.700 0.000 1.00 0.00 WT1 H +ENDMODEL +ATOM 1 OH2 TIP3W 1 
0.000 0.000 0.000 1.00 0.00 WT1 O +ATOM 2 H1 TIP3W 1 0.700 0.700 0.000 1.00 0.00 WT1 H +ATOM 3 H2 TIP3W 1 0.700 -0.700 0.000 1.00 0.00 WT1 H +END +""", +system_selection: str = None +) -> None: + import tempfile + with tempfile.TemporaryDirectory() as tmpdir: + test_dataset_path = "test_dataset.pdb" + test_dataset_path = os.path.join(tmpdir, test_dataset_path) + with open(test_dataset_path, 'w') as fp: + print(text, file=fp) + + dataset, trajectories = create_dataset_from_trajectories( + trajectories=[test_dataset_path, test_dataset_path, test_dataset_path], + top=[test_dataset_path, test_dataset_path, test_dataset_path], + cutoff=1.0, + system_selection=system_selection, + return_trajectories=True, + show_progress=False + ) + + assert len(dataset) == 6 + assert dataset.metadata["cutoff"] == 1.0 + assert dataset.metadata["z_table"] == [1, 8] + assert len(trajectories[0]) == 2 + assert len(trajectories[1]) == 2 + assert len(trajectories[2]) == 2 + + assert dataset[0]["data_list"]['graph_labels'] == torch.tensor([[0.0]]) + assert dataset[1]["data_list"]['graph_labels'] == torch.tensor([[0.0]]) + assert dataset[2]["data_list"]['graph_labels'] == torch.tensor([[1.0]]) + assert dataset[3]["data_list"]['graph_labels'] == torch.tensor([[1.0]]) + assert dataset[4]["data_list"]['graph_labels'] == torch.tensor([[2.0]]) + assert dataset[5]["data_list"]['graph_labels'] == torch.tensor([[2.0]]) + + dataset, trajectories = create_dataset_from_trajectories( + trajectories=[test_dataset_path, test_dataset_path, test_dataset_path], + top=test_dataset_path, + cutoff=1.0, + labels=None, + system_selection=system_selection, + return_trajectories=True, + show_progress=False + ) + + assert dataset[0]["data_list"]['graph_labels'] == torch.tensor([[0.0]]) + assert dataset[1]["data_list"]['graph_labels'] == torch.tensor([[0.0]]) + assert dataset[2]["data_list"]['graph_labels'] == torch.tensor([[1.0]]) + assert dataset[3]["data_list"]['graph_labels'] == torch.tensor([[1.0]]) + assert 
dataset[4]["data_list"]['graph_labels'] == torch.tensor([[2.0]]) + assert dataset[5]["data_list"]['graph_labels'] == torch.tensor([[2.0]]) + + def check_data_1(data) -> None: + assert(torch.allclose(data["data_list"]['edge_index'], torch.tensor([[0, 0, 1, 1, 2, 2], + [2, 1, 0, 2, 1, 0]]) + ) + ) + assert(torch.allclose(data["data_list"]['shifts'], torch.tensor([[0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 2.0, 0.0], + [0.0, -2.0, 0.0], + [0.0, 0.0, 0.0]]) + ) + ) + assert(torch.allclose(data["data_list"]['unit_shifts'], torch.tensor([[0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, -1.0, 0.0], + [0.0, 0.0, 0.0]]) + ) + ) + assert(torch.allclose(data["data_list"]['positions'], torch.tensor([[0.0, 0.0, 0.0], + [0.7, 0.7, 0.0], + [0.7, -0.7, 0.0]]) + ) + ) + assert(torch.allclose(data["data_list"]['cell'], torch.tensor([[2.0, 0.0, 0.0], + [0.0, 2.0, 0.0], + [0.0, 0.0, 2.0]]) + ) + ) + assert(torch.allclose(data["data_list"]['node_attrs'], torch.tensor([[0.0, 1.0], + [1.0, 0.0], + [1.0, 0.0]]) + ) + ) + + for i in range(6): + check_data_1(dataset[i]) + + if system_selection is not None: + + dataset = create_dataset_from_trajectories( + trajectories=[test_dataset_path, test_dataset_path, test_dataset_path], + top=[test_dataset_path, test_dataset_path, test_dataset_path], + cutoff=1.0, + system_selection='type O and {:s}'.format(system_selection), + environment_selection='type H and {:s}'.format(system_selection), + show_progress=False + ) + + for i in range(6): + check_data_1(dataset[i]) + + dataset = create_dataset_from_trajectories( + trajectories=[test_dataset_path, test_dataset_path, test_dataset_path], + top=[test_dataset_path, test_dataset_path, test_dataset_path], + cutoff=1.0, + system_selection='name H1 and {:s}'.format(system_selection), + environment_selection='name H2 and {:s}'.format(system_selection), + show_progress=False + ) -if __name__ == "__main__": - test_datasetFromFile() + def check_data_2(data) -> 
None: + assert(torch.allclose(data["data_list"]['edge_index'], torch.tensor([[0, 1], [1, 0]]))) + assert(torch.allclose(data["data_list"]['shifts'], torch.tensor([[0.0, 2.0, 0.0], + [0.0, -2.0, 0.0]]) + ) + ) + assert(torch.allclose(data["data_list"]['unit_shifts'], torch.tensor([[0.0, 1.0, 0.0], + [0.0, -1.0, 0.0]]) + ) + ) + assert(torch.allclose(data["data_list"]['positions'], torch.tensor([[0.7, 0.7, 0.0], + [0.7, -0.7, 0.0]]) + ) + ) + assert(torch.allclose(data["data_list"]['cell'], torch.tensor([[2.0, 0.0, 0.0], + [0.0, 2.0, 0.0], + [0.0, 0.0, 2.0]]) + ) + ) + assert(torch.allclose(data["data_list"]['node_attrs'], torch.tensor([[1.0], + [1.0]]) + ) + ) + + for i in range(6): + check_data_2(dataset[i]) + + +def test_dataset_from_xyz(): + # load single file + load_args = [{'start' : 0, 'stop' : 2, 'stride' : 1}] + dataset = create_dataset_from_trajectories(trajectories="Cu.xyz", + folder="mlcolvar/tests/data", + top=None, + cutoff=3.5, # Ang + labels=None, + system_selection="index 0", + environment_selection="not index 0", + show_progress=False, + load_args=load_args, + buffer=1, + ) + + print(dataset) + + # load multiple files + load_args = [{'start' : 0, 'stop' : 2, 'stride' : 1}, + {'start' : 0, 'stop' : 4, 'stride' : 2}] + dataset = create_dataset_from_trajectories(trajectories=["Cu.xyz", "Cu.xyz"], + folder="mlcolvar/tests/data", + top=None, + cutoff=3.5, # Ang + labels=None, + system_selection="index 0 or index 1", + environment_selection="not index 0 and not index 1", + show_progress=False, + load_args=load_args, + buffer=1, + ) + print(dataset) \ No newline at end of file diff --git a/mlcolvar/utils/plot.py b/mlcolvar/utils/plot.py index 62f1a3ee..41b86d5f 100644 --- a/mlcolvar/utils/plot.py +++ b/mlcolvar/utils/plot.py @@ -328,6 +328,72 @@ def plot_features_distribution(dataset, features, titles=None, axs=None): ax.set_yticks([]) ax.legend([],[],title=feat,loc='upper center',frameon=False) +import sys +import time +import typing + +""" +A simple 
progress bar. +""" + +__all__ = ['pbar'] + + +def pbar( + item: typing.List[int], + prefix: str = '', + size: int = 25, + frequency: int = 0.05, + use_unicode: bool = True, + file: typing.TextIO = sys.stdout +): + """ + A simple progress bar. Taken from stackoverflow: + https://stackoverflow.com/questions/3160699 + Parameters + ---------- + it : List[int] + The looped item. + prefix : str + Prefix of the bar. + size : int + Size of the bar. + frequency : float + Flush frequency of the bar. + use_unicode : bool + If use unicode char to draw the bar. + file : TextIO + The output file. + """ + if (use_unicode): + c_1 = '' + c_2 = '█' + c_3 = '━' + c_4 = '' + else: + c_1 = '|' + c_2 = '|' + c_3 = '-' + c_4 = '|' + count = len(item) + start = time.time() + interval = max(int(count * frequency), 1) + + def show(j) -> None: + x = int(size * j / count) + remaining = ((time.time() - start) / j) * (count - j) + mins, sec = divmod(remaining, 60) + time_string = f'{int(mins):02}:{sec:02.1f}' + output = f' {prefix} {c_1}{c_2 * (x - 1) + c_4}{c_3 * (size - x)} ' + \ + f'{j}/{count} Est. 
{time_string}' + print('\x1b[1A\x1b[2K' + output, file=file, flush=True) + + for i, it in enumerate(item): + yield it + if ((i % interval) == 0 or i in [0, (count - 1)]): + show(i + 1) + print(flush=True, file=file) + def test_utils_plot(): import matplotlib @@ -344,3 +410,10 @@ def test_utils_plot(): cmap = matplotlib.colors.Colormap("fessa_r", 2) cmap = matplotlib.colors.Colormap("cortina80", 2) cmap = matplotlib.colors.Colormap("cortina80_r", 2) + + import time + for i in pbar(range(15), "Computing: ", 40): + time.sleep(0.1) + + for i in pbar(range(15), "Computing: ", 40, use_unicode=False): + time.sleep(0.1) \ No newline at end of file diff --git a/mlcolvar/utils/timelagged.py b/mlcolvar/utils/timelagged.py index 3d751f0a..989fd64e 100644 --- a/mlcolvar/utils/timelagged.py +++ b/mlcolvar/utils/timelagged.py @@ -3,6 +3,8 @@ from bisect import bisect_left from mlcolvar.data import DictDataset import warnings +from typing import Union +import copy # optional packages # pandas @@ -193,7 +195,7 @@ def progress(iter, progress_bar=progress_bar): def create_timelagged_dataset( - X: torch.Tensor, + X: Union[torch.Tensor, np.ndarray, DictDataset], t: torch.Tensor = None, lag_time: float = 1, reweight_mode: str = None, @@ -223,8 +225,8 @@ def create_timelagged_dataset( Parameters ---------- - X : array-like - input descriptors + X : torch.Tensor or np.ndarray or DictDataset + Input data, graph data can only be provided as DictDataset t : array-like, optional time series, by default np.arange(len(X)) reweight_mode: str, optional @@ -287,13 +289,23 @@ def create_timelagged_dataset( tprime = t # find pairs of configurations separated by lag_time - x_t, x_lag, w_t, w_lag = find_timelagged_configurations( - X, - tprime, - lag_time=lag_time, - logweights=logweights if reweight_mode == "weights_t" else None, - progress_bar=progress_bar, - ) + if isinstance(X, torch.Tensor) or isinstance(X, np.ndarray): + x_t, x_lag, w_t, w_lag = find_timelagged_configurations( + X, + tprime, + 
lag_time=lag_time, + logweights=logweights if reweight_mode == "weights_t" else None, + progress_bar=progress_bar, + ) + elif isinstance(X, DictDataset): + index = torch.arange(len(X), dtype=torch.long) + x_t, x_lag, w_t, w_lag = find_timelagged_configurations( + index, + tprime, + lag_time=lag_time, + logweights=logweights if reweight_mode == "weights_t" else None, + progress_bar=progress_bar, + ) # return only a slice of the data (N. Pedrani) if interval is not None: @@ -306,34 +318,92 @@ def create_timelagged_dataset( data[i] = data[i][interval[0] : interval[1]] x_t, x_lag, w_t, w_lag = data - dataset = DictDataset( - {"data": x_t, "data_lag": x_lag, "weights": w_t, "weights_lag": w_lag} - ) - - return dataset + if isinstance(X, torch.Tensor) or isinstance(X, np.ndarray): + dataset = DictDataset({"data": x_t, + "data_lag": x_lag, + "weights": w_t, + "weights_lag": w_lag}, + data_type='descriptors') + return dataset + + elif isinstance(X, DictDataset): + if X.metadata["data_type"] == "descriptors": + dataset = DictDataset({"data": X['data'][x_t], + "data_lag": X['data'][x_lag], + "weights": w_t, + "weights_lag": w_lag}, + data_type='descriptors') + + elif X.metadata["data_type"] == "graphs": + # we use deepcopy to avoid editing the original dataset + dataset = DictDataset(dictionary={"data_list" : copy.deepcopy(X[x_t.numpy().tolist()]["data_list"]), + "data_list_lag" : copy.deepcopy(X[x_lag.numpy().tolist()]["data_list"])}, + metadata={"z_table" : X.metadata["z_table"], + "cutoff" : X.metadata["cutoff"]}, + data_type="graphs") + # update weights + for i in range(len(dataset)): + dataset['data_list'][i]['weight'] = w_t[i] + dataset['data_list_lag'][i]['weight'] = w_lag[i] + + return dataset def test_create_timelagged_dataset(): in_features = 2 - n_points = 100 + n_points = 20 X = torch.rand(n_points, in_features) * 100 + dataset = DictDataset(data=X, data_type='descriptors') + # unbiased case t = np.arange(n_points) - dataset = create_timelagged_dataset(X, t, 
lag_time=10) - print(len(dataset)) + lagged_dataset_1 = create_timelagged_dataset(X, t, lag_time=10) + print(len(lagged_dataset_1)) + lagged_dataset_2 = create_timelagged_dataset(dataset, t, lag_time=10) + print(len(lagged_dataset_2)) + assert(torch.allclose(lagged_dataset_1['data'], lagged_dataset_2['data'])) + assert(torch.allclose(lagged_dataset_1['data_lag'], lagged_dataset_2['data_lag'])) + assert(torch.allclose(lagged_dataset_1['weights'], lagged_dataset_2['weights'])) + # reweight mode rescale_time (default) logweights = np.random.rand(n_points) - dataset = create_timelagged_dataset(X, t, logweights=logweights) - print(len(dataset)) + lagged_dataset_1 = create_timelagged_dataset(X, t, logweights=logweights) + print(len(lagged_dataset_1)) + lagged_dataset_2 = create_timelagged_dataset(dataset, t, logweights=logweights) + print(len(lagged_dataset_2)) + assert(torch.allclose(lagged_dataset_1['data'], lagged_dataset_2['data'])) + assert(torch.allclose(lagged_dataset_1['data_lag'], lagged_dataset_2['data_lag'])) + assert(torch.allclose(lagged_dataset_1['weights'], lagged_dataset_2['weights'])) + # reweight mode weights_t logweights = np.random.rand(n_points) - dataset = create_timelagged_dataset( + lagged_dataset_1 = create_timelagged_dataset( X, t, logweights=logweights, reweight_mode="weights_t" ) + print(len(lagged_dataset_1)) + lagged_dataset_2 = create_timelagged_dataset( + dataset, t, logweights=logweights, reweight_mode="weights_t" + ) + print(len(lagged_dataset_2)) + assert(torch.allclose(lagged_dataset_1['data'], lagged_dataset_2['data'])) + assert(torch.allclose(lagged_dataset_1['data_lag'], lagged_dataset_2['data_lag'])) + assert(torch.allclose(lagged_dataset_1['weights'], lagged_dataset_2['weights'])) + + + + # graph data + from mlcolvar.data.graph.utils import create_test_graph_input + dataset = create_test_graph_input('dataset') + print(dataset['data_list'][0]) + lagged_dataset = create_timelagged_dataset(dataset, 
logweights=torch.randn(len(dataset))) + print(lagged_dataset['data_list'][0]) + print(dataset['data_list'][0]) + print(len(dataset)) + if __name__ == "__main__": diff --git a/requirements.txt b/requirements.txt index c5283ab0..ba6536de 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,7 @@ torch numpy<2 pandas matplotlib -kdepy \ No newline at end of file +kdepy +torch_geometric +matscipy +mdtraj \ No newline at end of file
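Note on the `DictDataset` branch of `create_timelagged_dataset` in the patch above: instead of pairing the data directly, it pairs an index tensor (`torch.arange(len(X))`) and then uses the resulting indices to slice the dataset, which is what lets graph data go through the same search. The sketch below illustrates that idea only. It is not the library's `find_timelagged_configurations`; the helper name `lagged_index_pairs` and the list-based stand-in dataset are hypothetical, and it assumes `t` is sorted in increasing order and ignores reweighting.

```python
from bisect import bisect_left


def lagged_index_pairs(t, lag_time):
    """Pair each frame i with the first frame j such that t[j] >= t[i] + lag_time.

    Hypothetical helper for illustration; assumes `t` is sorted ascending.
    """
    idx_t, idx_lag = [], []
    n = len(t)
    for i in range(n):
        j = bisect_left(t, t[i] + lag_time)
        if j >= n:
            # t is sorted, so no later frame can have a lagged partner either
            break
        idx_t.append(i)
        idx_lag.append(j)
    return idx_t, idx_lag


# Stand-in for a dataset that can only be indexed (e.g. a list of graphs)
frames = [f"frame_{k}" for k in range(20)]
t = list(range(20))

idx_t, idx_lag = lagged_index_pairs(t, lag_time=10)

# The indices, not the data, come out of the search; slicing happens afterwards
data = [frames[i] for i in idx_t]
data_lag = [frames[j] for j in idx_lag]
```

Because the search only ever sees integers, the same pairing can be reused for any container that supports indexing, at the cost of one extra slicing step at the end.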