Allow torch model to be used as Integrator #56

Open
@jchodera


I wonder if we could also support using torch models as integrators in OpenMM.

Perhaps something like this could work:

import torch

class IntegratorModule(torch.nn.Module):
    """Sketch of a BAOAB Langevin integrator.

    Only the deterministic kicks and drift (equivalent to velocity Verlet) are shown;
    the stochastic O step, which would use temperature, is omitted for brevity.
    """
    def forward(self, positions, velocities, masses, timestep, temperature):
        """The forward method returns the updated positions and velocities given the current timestep and temperature.

        Parameters
        ----------
        positions : torch.Tensor with shape (nparticles,3)
           positions[i,k] is the position (in nanometers) of spatial dimension k of particle i
        velocities : torch.Tensor with shape (nparticles,3)
           velocities[i,k] is the velocity (in nanometers/picosecond) of spatial dimension k of particle i
        masses : torch.Tensor with shape (nparticles,3)
           masses[i,k] is the mass (in amu) of particle i for all k
        timestep : torch.Tensor with shape ()
            the integration timestep (in picoseconds, consistent with the velocity units)
        temperature : torch.Tensor with shape ()
            the temperature (in kelvin)

        Returns
        -------
        positions : torch.Tensor with shape (nparticles,3)
           positions[i,k] is the position (in nanometers) of spatial dimension k of particle i
        velocities : torch.Tensor with shape (nparticles,3)
           velocities[i,k] is the velocity (in nanometers/picosecond) of spatial dimension k of particle i
        """
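        # openmm_compute_forces is the proposed custom op and does not exist yet,
        # so torch.jit.script() below fails until such an op is registered.
        forces = openmm_compute_forces(positions)
        velocities = velocities + (timestep/2) * (forces/masses)  # B: half kick
        positions = positions + timestep * velocities             # A: full drift
        forces = openmm_compute_forces(positions)
        velocities = velocities + (timestep/2) * (forces/masses)  # B: half kick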

        return positions, velocities

# Render the compute graph to a TorchScript module
module = torch.jit.script(IntegratorModule())

# Serialize the compute graph to a file
module.save('integrator.pt')
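
For comparison, here is a minimal sketch of a complete BAOAB step that can actually be compiled with `torch.jit.script` today. It assumes a hypothetical stand-in `harmonic_forces` function in place of the proposed `openmm_compute_forces` op, an assumed `friction` collision rate in 1/ps, and OpenMM's internal units (nm, ps, amu, kJ/mol, K), so that kB*T/m comes out in nm^2/ps^2. The O step is the exact Ornstein-Uhlenbeck update v <- a*v + sqrt((1 - a^2)*kB*T/m)*xi with a = exp(-gamma*dt) and xi standard normal noise.

```python
import torch


def harmonic_forces(positions: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the proposed openmm_compute_forces op:
    # forces of a harmonic well U = 0.5 * k * |x|^2 with k = 100 kJ/(mol nm^2).
    return -100.0 * positions


class BAOABIntegrator(torch.nn.Module):
    """BAOAB Langevin integrator sketch in OpenMM units (nm, ps, amu, kJ/mol, K)."""

    def __init__(self, friction: float = 1.0):
        super().__init__()
        self.friction = friction  # collision rate gamma in 1/ps (assumed parameter)

    def forward(self, positions, velocities, masses, timestep, temperature):
        kB = 0.008314462618  # Boltzmann constant in kJ/(mol K)
        # B: half kick
        velocities = velocities + 0.5 * timestep * harmonic_forces(positions) / masses
        # A: half drift
        positions = positions + 0.5 * timestep * velocities
        # O: exact Ornstein-Uhlenbeck update of the velocities
        a = torch.exp(-self.friction * timestep)
        sigma = torch.sqrt((1.0 - a * a) * kB * temperature / masses)
        velocities = a * velocities + sigma * torch.randn_like(velocities)
        # A: half drift
        positions = positions + 0.5 * timestep * velocities
        # B: half kick with forces at the new positions
        velocities = velocities + 0.5 * timestep * harmonic_forces(positions) / masses
        return positions, velocities


# Compiles because harmonic_forces is plain TorchScript-compatible Python.
baoab = torch.jit.script(BAOABIntegrator(friction=10.0))
```

With `friction=0.0` the O step leaves the velocities unchanged (a = 1, zero noise) and the two half drifts merge into one full drift, so the step reduces to the velocity Verlet update in the example above.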

Here, we would have to extend TorchScript with the custom ops openmm_compute_forces, openmm_compute_potential, and openmm_compute_potential_and_forces, which would wrap the normal OpenMM energy/force computation. Ideally, these C++ functions would detect when the forces or potential do not need to be recomputed because no particles have moved, as happens when they are called at the end of one step and again at the beginning of the next.
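
The caching idea can be illustrated in Python. This is a hypothetical sketch (`CachedForceEvaluator` is not part of OpenMM or openmm-torch): it remembers the last positions and forces and skips the evaluation when called again with identical coordinates, which is what happens between the final kick of one step and the first kick of the next. A real C++ op would more likely track whether the Context's positions have been modified than compare every coordinate.

```python
from typing import Callable, Optional

import torch


class CachedForceEvaluator:
    """Hypothetical sketch: reuse the last forces if the positions have not changed."""

    def __init__(self, compute_forces: Callable[[torch.Tensor], torch.Tensor]):
        self._compute_forces = compute_forces
        self._last_positions: Optional[torch.Tensor] = None
        self._last_forces: Optional[torch.Tensor] = None
        self.evaluations = 0  # number of real (uncached) force evaluations

    def __call__(self, positions: torch.Tensor) -> torch.Tensor:
        if (self._last_positions is not None
                and self._last_forces is not None
                and torch.equal(positions, self._last_positions)):
            return self._last_forces  # nothing moved: reuse the cached forces
        forces = self._compute_forces(positions)
        self.evaluations += 1
        # Keep a private copy so later in-place edits by the caller cannot alias the cache.
        self._last_positions = positions.detach().clone()
        self._last_forces = forces
        return forces


# Usage: velocity Verlet with a toy harmonic force (k = 100, unit masses).
forces_fn = CachedForceEvaluator(lambda x: -100.0 * x)
positions = torch.ones(4, 3)
velocities = torch.zeros(4, 3)
dt = 0.001
nsteps = 10
for _ in range(nsteps):
    velocities = velocities + 0.5 * dt * forces_fn(positions)  # cache hit after the first step
    positions = positions + dt * velocities
    velocities = velocities + 0.5 * dt * forces_fn(positions)
print(forces_fn.evaluations)  # nsteps + 1 real evaluations instead of 2 * nsteps
```

Forces are requested twice per step, but only the first call of the run and the post-drift call of each step trigger a real evaluation.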

To use the integrator in a simulation, the user would create a TorchIntegrator object that would behave much like a normal Integrator:

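import openmm
from openmm import unit
timestep = 4.0*unit.femtoseconds
temperature = 300*unit.kelvin

# Create the TorchIntegrator from the serialized compute graph
# (TorchIntegrator is the proposed class; it does not exist in openmmtorch yet)
from openmmtorch import TorchIntegrator
torch_integrator = TorchIntegrator('integrator.pt', temperature, timestep)

# Create a Context with the integrator (system is an existing openmm.System)
context = openmm.Context(system, torch_integrator)
context.setPositions(positions)  # positions assumed loaded, e.g. from a PDB file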

# Run some dynamics
torch_integrator.step(100)

# Change the temperature and timestep
# (setStepSize matches the name used by OpenMM's built-in integrators)
torch_integrator.setTemperature(100*unit.kelvin)
torch_integrator.setStepSize(2.0*unit.femtoseconds)
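
To make the intended control flow concrete, here is a hypothetical pure-PyTorch mock of what TorchIntegrator.step() might do internally. `ToyTorchIntegrator` and `FreeDrift` are illustrative names, not openmm-torch classes; the mock stores positions and velocities itself instead of delegating to an OpenMM Context, takes plain floats in kelvin and picoseconds rather than `openmm.unit` quantities, and round-trips the scripted module through an in-memory buffer in place of `integrator.pt`. Because temperature and step size are passed as tensors on every call, changing them between `step()` calls needs no recompilation.

```python
import io

import torch


class FreeDrift(torch.nn.Module):
    """Toy integrator with the proposed signature: force-free particles drift at constant velocity."""

    def forward(self, positions, velocities, masses, timestep, temperature):
        return positions + timestep * velocities, velocities


class ToyTorchIntegrator:
    """Hypothetical mock of the proposed TorchIntegrator (not a real openmm-torch class)."""

    def __init__(self, source, temperature_kelvin: float, step_size_ps: float):
        self._module = torch.jit.load(source)  # file path or file-like object
        self.setTemperature(temperature_kelvin)
        self.setStepSize(step_size_ps)
        # In the real proposal this state would live in the OpenMM Context.
        self.positions = torch.empty(0, 3, dtype=torch.float64)
        self.velocities = torch.empty(0, 3, dtype=torch.float64)
        self.masses = torch.empty(0, 3, dtype=torch.float64)

    def setTemperature(self, temperature_kelvin: float) -> None:
        self._temperature = torch.tensor(temperature_kelvin, dtype=torch.float64)

    def setStepSize(self, step_size_ps: float) -> None:
        self._step_size = torch.tensor(step_size_ps, dtype=torch.float64)

    def step(self, steps: int) -> None:
        # Each step calls the serialized TorchScript graph once.
        for _ in range(steps):
            self.positions, self.velocities = self._module(
                self.positions, self.velocities, self.masses,
                self._step_size, self._temperature)


# Serialize to an in-memory buffer (stands in for module.save('integrator.pt')).
buffer = io.BytesIO()
torch.jit.save(torch.jit.script(FreeDrift()), buffer)
buffer.seek(0)

integrator = ToyTorchIntegrator(buffer, 300.0, 0.004)
integrator.positions = torch.zeros(2, 3, dtype=torch.float64)
integrator.velocities = torch.ones(2, 3, dtype=torch.float64)  # 1 nm/ps
integrator.masses = torch.ones(2, 3, dtype=torch.float64)
integrator.step(100)           # 100 steps of 0.004 ps: about 0.4 nm
integrator.setStepSize(0.002)
integrator.step(100)           # 100 more steps of 0.002 ps: about 0.6 nm in total
print(integrator.positions[0])
```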

Edit: It would also be important to let the integrator modify global parameters, and to define its own parameters that can be accessed through the OpenMM API. I'm not yet sure how that would work, however.

Labels: enhancement (New feature or request)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions