@@ -49,7 +49,14 @@ if (result != CUDA_SUCCESS) { \
49
49
throw OpenMMException (m.str ());\
50
50
}
51
51
52
+ CudaCalcTorchForceKernel::CudaCalcTorchForceKernel (string name, const Platform& platform, CudaContext& cu) :
53
+ CalcTorchForceKernel(name, platform), hasInitializedKernel(false ), cu(cu) {
54
+ // Explicitly activate the primary context
55
+ CHECK_RESULT (cuDevicePrimaryCtxRetain (&primaryContext, cu.getDevice ()), " Failed to retain the primary context" );
56
+ }
57
+
52
58
// Destructor: balance the cuDevicePrimaryCtxRetain() performed in the constructor.
// NOTE(review): the return value is left unchecked — presumably intentional,
// since CHECK_RESULT throws and a destructor must not throw; confirm.
CudaCalcTorchForceKernel::~CudaCalcTorchForceKernel() {
    cuDevicePrimaryCtxRelease(cu.getDevice());
}
54
61
55
62
void CudaCalcTorchForceKernel::initialize (const System& system, const TorchForce& force, torch::jit::script::Module& module ) {
@@ -60,6 +67,11 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce
60
67
globalNames.push_back (force.getGlobalParameterName (i));
61
68
int numParticles = system.getNumParticles ();
62
69
70
+ // Push the PyTorch context
71
+ // NOTE: Pytorch is always using the primary context.
72
+ // It makes the primary context current, if it is not a case.
73
+ CHECK_RESULT (cuCtxPushCurrent (primaryContext), " Failed to push the CUDA context" );
74
+
63
75
// Initialize CUDA objects for PyTorch
64
76
const torch::Device device (torch::kCUDA , cu.getDeviceIndex ()); // This implicitly initialize PyTorch
65
77
module .to (device);
@@ -69,8 +81,13 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce
69
81
posTensor = torch::empty ({numParticles, 3 }, options.requires_grad (!outputsForces));
70
82
boxTensor = torch::empty ({3 , 3 }, options);
71
83
84
+ // Pop the PyToch context
85
+ CUcontext ctx;
86
+ CHECK_RESULT (cuCtxPopCurrent (&ctx), " Failed to pop the CUDA context" );
87
+ assert (primaryContext == ctx); // Check that PyTorch haven't messed up the context stack
88
+
72
89
// Initialize CUDA objects for OpenMM-Torch
73
- ContextSelector selector (cu);
90
+ ContextSelector selector (cu); // Switch to the OpenMM context
74
91
map<string, string> defines;
75
92
CUmodule program = cu.createModule (CudaTorchKernelSources::torchForce, defines);
76
93
copyInputsKernel = cu.getKernel (program, " copyInputs" );
@@ -80,6 +97,9 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce
80
97
double CudaCalcTorchForceKernel::execute (ContextImpl& context, bool includeForces, bool includeEnergy) {
81
98
int numParticles = cu.getNumAtoms ();
82
99
100
+ // Push to the PyTorch context
101
+ CHECK_RESULT (cuCtxPushCurrent (primaryContext), " Failed to push the CUDA context" );
102
+
83
103
// Get pointers to the atomic positions and simulation box
84
104
void * posData;
85
105
void * boxData;
@@ -94,11 +114,11 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce
94
114
95
115
// Copy the atomic positions and simulation box to PyTorch tensors
96
116
{
97
- ContextSelector selector (cu);
117
+ ContextSelector selector (cu); // Switch to the OpenMM context
98
118
void * inputArgs[] = {&posData, &boxData, &cu.getPosq ().getDevicePointer (), &cu.getAtomIndexArray ().getDevicePointer (),
99
119
&numParticles, cu.getPeriodicBoxVecXPointer (), cu.getPeriodicBoxVecYPointer (), cu.getPeriodicBoxVecZPointer ()};
100
120
cu.executeKernel (copyInputsKernel, inputArgs, numParticles);
101
- CHECK_RESULT (cuCtxSynchronize (), " Error synchronizing CUDA context" ); // Synchronize before switching to the PyTorch context
121
+ CHECK_RESULT (cuCtxSynchronize (), " Failed to synchronize the CUDA context" ); // Synchronize before switching to the PyTorch context
102
122
}
103
123
104
124
// Prepare the input of the PyTorch model
@@ -138,21 +158,30 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce
138
158
forceTensor = forceTensor.to (torch::kFloat32 );
139
159
forceData = forceTensor.data_ptr <float >();
140
160
}
141
- CHECK_RESULT (cuCtxSynchronize (), " Error synchronizing CUDA context" ); // Synchronize before switching to the OpenMM context
161
+ CHECK_RESULT (cuCtxSynchronize (), " Failed to synchronize the CUDA context" ); // Synchronize before switching to the OpenMM context
142
162
143
163
// Add the computed forces to the total atomic forces
144
164
{
145
- ContextSelector selector (cu);
165
+ ContextSelector selector (cu); // Switch to the OpenMM context
146
166
int paddedNumAtoms = cu.getPaddedNumAtoms ();
147
167
int forceSign = (outputsForces ? 1 : -1 );
148
168
void * forceArgs[] = {&forceData, &cu.getForce ().getDevicePointer (), &cu.getAtomIndexArray ().getDevicePointer (), &numParticles, &paddedNumAtoms, &forceSign};
149
169
cu.executeKernel (addForcesKernel, forceArgs, numParticles);
150
- CHECK_RESULT (cuCtxSynchronize (), " Error synchronizing CUDA context" ); // Synchronize before switching to the PyTorch context
170
+ CHECK_RESULT (cuCtxSynchronize (), " Failed to synchronize the CUDA context" ); // Synchronize before switching to the PyTorch context
151
171
}
152
172
153
173
// Reset the forces
154
174
if (!outputsForces)
155
175
posTensor.grad ().zero_ ();
156
176
}
157
- return energyTensor.item <double >(); // This implicitly synchronize the PyTorch context
177
+
178
+ // Get energy
179
+ const double energy = energyTensor.item <double >(); // This implicitly synchronizes the PyTorch context
180
+
181
+ // Pop to the PyTorch context
182
+ CUcontext ctx;
183
+ CHECK_RESULT (cuCtxPopCurrent (&ctx), " Failed to pop the CUDA context" );
184
+ assert (primaryContext == ctx); // Check that the correct context was popped
185
+
186
+ return energy;
158
187
}
0 commit comments