Skip to content

Commit 5dc7279

Browse files
author
Raimondas Galvelis
authored
Fix interoperability with CustomCVForce (#80)
* Add a test with CustomCVForce * Test all the platforms * Add an interoperability test for TorchANI and NNPOps * Add missing dependencies * Skip for MacOS * Move imports * Fix import * Retain the primary context * Switch the contexts properly * Set the oldest CUDA to 11.0 * Fix nvcc version * Enable an extra check * Clean up a temporary file * Add more checks * Add comments * Remove a sync and clean up * Move the primary context activation
1 parent 661b004 commit 5dc7279

File tree

6 files changed

+132
-19
lines changed

6 files changed

+132
-19
lines changed

.github/workflows/CI.yml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,12 @@ jobs:
2323
matrix:
2424
include:
2525
# Oldest supported versions
26-
- name: Linux (CUDA 10.2, Python 3.7, PyTorch 1.11)
26+
# NOTE: re-enable CUDA 10.2 when it is supported by NNPOps (https://github.com/conda-forge/nnpops-feedstock/pull/8)
27+
- name: Linux (CUDA 11.0, Python 3.7, PyTorch 1.11)
2728
os: ubuntu-18.04
28-
cuda-version: "10.2.89"
29+
cuda-version: "11.0.3"
2930
gcc-version: "8.5.*"
30-
nvcc-version: "10.2"
31+
nvcc-version: "11.0"
3132
python-version: "3.7"
3233
pytorch-version: "1.11.*"
3334

devtools/conda-envs/build-ubuntu-18.04.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ dependencies:
66
- cudatoolkit @CUDATOOLKIT_VERSION@
77
- gxx_linux-64 @GCC_VERSION@
88
- make
9+
- nnpops
910
- nvcc_linux-64 @NVCC_VERSION@
1011
- ocl-icd
1112
- openmm >=7.7
@@ -15,4 +16,5 @@ dependencies:
1516
- python
1617
- pytorch-gpu @PYTORCH_VERSION@
1718
- swig
18-
- sysroot_linux-64 2.17
19+
- sysroot_linux-64 2.17
20+
- torchani

platforms/cuda/src/CudaTorchKernels.cpp

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,14 @@ if (result != CUDA_SUCCESS) { \
4949
throw OpenMMException(m.str());\
5050
}
5151

52+
CudaCalcTorchForceKernel::CudaCalcTorchForceKernel(string name, const Platform& platform, CudaContext& cu) :
53+
CalcTorchForceKernel(name, platform), hasInitializedKernel(false), cu(cu) {
54+
// Explicitly activate the primary context
55+
CHECK_RESULT(cuDevicePrimaryCtxRetain(&primaryContext, cu.getDevice()), "Failed to retain the primary context");
56+
}
57+
5258
CudaCalcTorchForceKernel::~CudaCalcTorchForceKernel() {
59+
cuDevicePrimaryCtxRelease(cu.getDevice());
5360
}
5461

5562
void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce& force, torch::jit::script::Module& module) {
@@ -60,6 +67,11 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce
6067
globalNames.push_back(force.getGlobalParameterName(i));
6168
int numParticles = system.getNumParticles();
6269

70+
// Push the PyTorch context
71+
// NOTE: PyTorch always uses the primary context.
72+
// It makes the primary context current if that is not already the case.
73+
CHECK_RESULT(cuCtxPushCurrent(primaryContext), "Failed to push the CUDA context");
74+
6375
// Initialize CUDA objects for PyTorch
6476
const torch::Device device(torch::kCUDA, cu.getDeviceIndex()); // This implicitly initialize PyTorch
6577
module.to(device);
@@ -69,8 +81,13 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce
6981
posTensor = torch::empty({numParticles, 3}, options.requires_grad(!outputsForces));
7082
boxTensor = torch::empty({3, 3}, options);
7183

84+
// Pop the PyTorch context
85+
CUcontext ctx;
86+
CHECK_RESULT(cuCtxPopCurrent(&ctx), "Failed to pop the CUDA context");
87+
assert(primaryContext == ctx); // Check that PyTorch haven't messed up the context stack
88+
7289
// Initialize CUDA objects for OpenMM-Torch
73-
ContextSelector selector(cu);
90+
ContextSelector selector(cu); // Switch to the OpenMM context
7491
map<string, string> defines;
7592
CUmodule program = cu.createModule(CudaTorchKernelSources::torchForce, defines);
7693
copyInputsKernel = cu.getKernel(program, "copyInputs");
@@ -80,6 +97,9 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce
8097
double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
8198
int numParticles = cu.getNumAtoms();
8299

100+
// Push the PyTorch context
101+
CHECK_RESULT(cuCtxPushCurrent(primaryContext), "Failed to push the CUDA context");
102+
83103
// Get pointers to the atomic positions and simulation box
84104
void* posData;
85105
void* boxData;
@@ -94,11 +114,11 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce
94114

95115
// Copy the atomic positions and simulation box to PyTorch tensors
96116
{
97-
ContextSelector selector(cu);
117+
ContextSelector selector(cu); // Switch to the OpenMM context
98118
void* inputArgs[] = {&posData, &boxData, &cu.getPosq().getDevicePointer(), &cu.getAtomIndexArray().getDevicePointer(),
99119
&numParticles, cu.getPeriodicBoxVecXPointer(), cu.getPeriodicBoxVecYPointer(), cu.getPeriodicBoxVecZPointer()};
100120
cu.executeKernel(copyInputsKernel, inputArgs, numParticles);
101-
CHECK_RESULT(cuCtxSynchronize(), "Error synchronizing CUDA context"); // Synchronize before switching to the PyTorch context
121+
CHECK_RESULT(cuCtxSynchronize(), "Failed to synchronize the CUDA context"); // Synchronize before switching to the PyTorch context
102122
}
103123

104124
// Prepare the input of the PyTorch model
@@ -138,21 +158,30 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce
138158
forceTensor = forceTensor.to(torch::kFloat32);
139159
forceData = forceTensor.data_ptr<float>();
140160
}
141-
CHECK_RESULT(cuCtxSynchronize(), "Error synchronizing CUDA context"); // Synchronize before switching to the OpenMM context
161+
CHECK_RESULT(cuCtxSynchronize(), "Failed to synchronize the CUDA context"); // Synchronize before switching to the OpenMM context
142162

143163
// Add the computed forces to the total atomic forces
144164
{
145-
ContextSelector selector(cu);
165+
ContextSelector selector(cu); // Switch to the OpenMM context
146166
int paddedNumAtoms = cu.getPaddedNumAtoms();
147167
int forceSign = (outputsForces ? 1 : -1);
148168
void* forceArgs[] = {&forceData, &cu.getForce().getDevicePointer(), &cu.getAtomIndexArray().getDevicePointer(), &numParticles, &paddedNumAtoms, &forceSign};
149169
cu.executeKernel(addForcesKernel, forceArgs, numParticles);
150-
CHECK_RESULT(cuCtxSynchronize(), "Error synchronizing CUDA context"); // Synchronize before switching to the PyTorch context
170+
CHECK_RESULT(cuCtxSynchronize(), "Failed to synchronize the CUDA context"); // Synchronize before switching to the PyTorch context
151171
}
152172

153173
// Reset the forces
154174
if (!outputsForces)
155175
posTensor.grad().zero_();
156176
}
157-
return energyTensor.item<double>(); // This implicitly synchronize the PyTorch context
177+
178+
// Get energy
179+
const double energy = energyTensor.item<double>(); // This implicitly synchronizes the PyTorch context
180+
181+
// Pop the PyTorch context
182+
CUcontext ctx;
183+
CHECK_RESULT(cuCtxPopCurrent(&ctx), "Failed to pop the CUDA context");
184+
assert(primaryContext == ctx); // Check that the correct context was popped
185+
186+
return energy;
158187
}

platforms/cuda/src/CudaTorchKernels.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434

3535
#include "TorchKernels.h"
3636
#include "openmm/cuda/CudaContext.h"
37-
#include "openmm/cuda/CudaArray.h"
3837

3938
namespace TorchPlugin {
4039

@@ -43,9 +42,7 @@ namespace TorchPlugin {
4342
*/
4443
class CudaCalcTorchForceKernel : public CalcTorchForceKernel {
4544
public:
46-
CudaCalcTorchForceKernel(std::string name, const OpenMM::Platform& platform, OpenMM::CudaContext& cu) :
47-
CalcTorchForceKernel(name, platform), hasInitializedKernel(false), cu(cu) {
48-
}
45+
CudaCalcTorchForceKernel(std::string name, const OpenMM::Platform& platform, OpenMM::CudaContext& cu);
4946
~CudaCalcTorchForceKernel();
5047
/**
5148
* Initialize the kernel.
@@ -72,6 +69,7 @@ class CudaCalcTorchForceKernel : public CalcTorchForceKernel {
7269
std::vector<std::string> globalNames;
7370
bool usePeriodic, outputsForces;
7471
CUfunction copyInputsKernel, addForcesKernel;
72+
CUcontext primaryContext;
7573
};
7674

7775
} // namespace TorchPlugin

python/tests/TestInteroperability.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import openmm as mm
2+
import openmm.unit as unit
3+
import openmmtorch as ot
4+
import platform
5+
import pytest
6+
from tempfile import NamedTemporaryFile
7+
import torch as pt
8+
9+
10+
@pytest.mark.skipif(platform.system() == 'Darwin', reason='There is no NNPOps package for MacOS')
11+
@pytest.mark.parametrize('use_cv_force', [True, False])
12+
@pytest.mark.parametrize('platform', ['Reference', 'CPU', 'CUDA', 'OpenCL'])
13+
def testTorchANI(use_cv_force, platform):
14+
15+
if pt.cuda.device_count() < 1 and platform == 'CUDA':
16+
pytest.skip('A CUDA device is not available')
17+
18+
import NNPOps # There is no NNPOps package for MacOS
19+
import torchani
20+
21+
class Model(pt.nn.Module):
22+
23+
def __init__(self):
24+
super().__init__()
25+
self.register_buffer('atomic_numbers', pt.tensor([[1, 1]]))
26+
self.model = torchani.models.ANI2x(periodic_table_index=True)
27+
self.model = NNPOps.OptimizedTorchANI(self.model, self.atomic_numbers)
28+
29+
def forward(self, positions):
30+
positions = positions.float().unsqueeze(0) * 10 # nm --> Ang
31+
return self.model((self.atomic_numbers, positions)).energies[0] * 2625.5 # Hartree --> kJ/mol
32+
33+
# Create a system
34+
system = mm.System()
35+
for _ in range(2):
36+
system.addParticle(1.0)
37+
positions = pt.tensor([[-5, 0.0, 0.0], [5, 0.0, 0.0]], requires_grad=True)
38+
39+
with NamedTemporaryFile() as model_file:
40+
41+
# Save the model
42+
pt.jit.script(Model()).save(model_file.name)
43+
44+
# Compute reference energy and forces
45+
model = pt.jit.load(model_file)
46+
ref_energy = model(positions)
47+
ref_energy.backward()
48+
ref_forces = positions.grad
49+
50+
# Create a force
51+
force = ot.TorchForce(model_file.name)
52+
if use_cv_force:
53+
# Wrap TorchForce into CustomCVForce
54+
cv_force = mm.CustomCVForce('force')
55+
cv_force.addCollectiveVariable('force', force)
56+
system.addForce(cv_force)
57+
else:
58+
system.addForce(force)
59+
60+
# Compute energy and forces
61+
integ = mm.VerletIntegrator(1.0)
62+
platform = mm.Platform.getPlatformByName(platform)
63+
context = mm.Context(system, integ, platform)
64+
context.setPositions(positions.detach().numpy())
65+
state = context.getState(getEnergy=True, getForces=True)
66+
energy = state.getPotentialEnergy().value_in_unit(unit.kilojoules_per_mole)
67+
forces = state.getForces(asNumpy=True).value_in_unit(unit.kilojoules_per_mole/unit.nanometers)
68+
69+
# Check energy and forces
70+
assert pt.allclose(ref_energy, pt.tensor(energy, dtype=ref_energy.dtype))
71+
assert pt.allclose(ref_forces, pt.tensor(forces, dtype=ref_forces.dtype))

python/tests/TestTorchForce.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,25 +9,37 @@
99
@pytest.mark.parametrize('model_file, output_forces,',
1010
[('../../tests/central.pt', False),
1111
('../../tests/forces.pt', True)])
12-
def testForce(model_file, output_forces):
12+
@pytest.mark.parametrize('use_cv_force', [True, False])
13+
@pytest.mark.parametrize('platform', ['Reference', 'CPU', 'CUDA', 'OpenCL'])
14+
def testForce(model_file, output_forces, use_cv_force, platform):
15+
16+
if pt.cuda.device_count() < 1 and platform == 'CUDA':
17+
pytest.skip('A CUDA device is not available')
1318

1419
# Create a random cloud of particles.
1520
numParticles = 10
1621
system = mm.System()
1722
positions = np.random.rand(numParticles, 3)
18-
for i in range(numParticles):
23+
for _ in range(numParticles):
1924
system.addParticle(1.0)
2025

2126
# Create a force
2227
force = ot.TorchForce(model_file)
2328
assert not force.getOutputsForces() # Check the default
2429
force.setOutputsForces(output_forces)
2530
assert force.getOutputsForces() == output_forces
26-
system.addForce(force)
31+
if use_cv_force:
32+
# Wrap TorchForce into CustomCVForce
33+
cv_force = mm.CustomCVForce('force')
34+
cv_force.addCollectiveVariable('force', force)
35+
system.addForce(cv_force)
36+
else:
37+
system.addForce(force)
2738

2839
# Compute the forces and energy.
2940
integ = mm.VerletIntegrator(1.0)
30-
context = mm.Context(system, integ, mm.Platform.getPlatformByName('Reference'))
41+
platform = mm.Platform.getPlatformByName(platform)
42+
context = mm.Context(system, integ, platform)
3143
context.setPositions(positions)
3244
state = context.getState(getEnergy=True, getForces=True)
3345

0 commit comments

Comments
 (0)