  * Biological Structures at Stanford, funded under the NIH Roadmap for         *
  * Medical Research, grant U54 GM072970. See https://simtk.org.                *
  *                                                                             *
- * Portions copyright (c) 2018-2022 Stanford University and the Authors.       *
+ * Portions copyright (c) 2018-2024 Stanford University and the Authors.       *
  * Authors: Peter Eastman                                                      *
  * Contributors: Raimondas Galvelis, Raul P. Pelaez                            *
  *                                                                             *
@@ -66,6 +66,10 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce
     outputsForces = force.getOutputsForces();
     for (int i = 0; i < force.getNumGlobalParameters(); i++)
         globalNames.push_back(force.getGlobalParameterName(i));
+    for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
+        paramDerivs.insert(force.getEnergyParameterDerivativeName(i));
+        cu.addEnergyParameterDerivative(force.getEnergyParameterDerivativeName(i));
+    }
     int numParticles = system.getNumParticles();

     // Push the PyTorch context
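For orientation, a hedged sketch (not part of this diff) of the user-facing side of this hunk: the kernel reads derivative requests registered on the TorchForce, so a setter along the lines of addEnergyParameterDerivative, matching the getters read in initialize() above, is assumed; the parameter name "scale" is made up.

// Sketch only: register a global parameter and request dE/d(parameter).
// Assumes a TorchForce::addEnergyParameterDerivative setter matching the
// getNumEnergyParameterDerivatives/getEnergyParameterDerivativeName getters
// used above; "scale" is a hypothetical parameter name.
#include "TorchForce.h"

void configureForce(TorchPlugin::TorchForce& force) {
    force.addGlobalParameter("scale", 1.0);       // passed to the model as an input tensor
    force.addEnergyParameterDerivative("scale");  // ask the kernel to also report dE/d(scale)
}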
@@ -81,6 +85,8 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce
     boxTensor = torch::empty({3, 3}, options);
     energyTensor = torch::empty({0}, options);
     forceTensor = torch::empty({0}, options);
+    for (const string& name : globalNames)
+        globalTensors[name] = torch::tensor({0}, options);
     // Pop the PyToch context
     CUcontext ctx;
     CHECK_RESULT(cuCtxPopCurrent(&ctx), "Failed to pop the CUDA context");
@@ -125,7 +131,7 @@ static void* getTensorPointer(OpenMM::CudaContext& cu, torch::Tensor& tensor) {
 /**
  * Prepare the inputs for the PyTorch model, copying positions from the OpenMM context.
  */
-std::vector<torch::jit::IValue> CudaCalcTorchForceKernel::prepareTorchInputs(ContextImpl& context) {
+void CudaCalcTorchForceKernel::prepareTorchInputs(ContextImpl& context, vector<torch::jit::IValue>& inputs, map<string, torch::Tensor>& globalTensors) {
     int numParticles = cu.getNumAtoms();
     // Get pointers to the atomic positions and simulation box
     void* posData = getTensorPointer(cu, posTensor);
@@ -145,12 +151,17 @@ std::vector<torch::jit::IValue> CudaCalcTorchForceKernel::prepareTorchInputs(Con
         CHECK_RESULT(cuCtxSynchronize(), "Failed to synchronize the CUDA context"); // Synchronize before switching to the PyTorch context
     }
     // Prepare the input of the PyTorch model
-    vector<torch::jit::IValue> inputs = {posTensor};
+    inputs = {posTensor};
     if (usePeriodic)
         inputs.push_back(boxTensor);
-    for (const string& name : globalNames)
-        inputs.push_back(torch::tensor(context.getParameter(name)));
-    return inputs;
+    for (const string& name : globalNames) {
+        // PyTorch requires us to set requires_grad to false before initializing a tensor.
+        globalTensors[name].set_requires_grad(false);
+        globalTensors[name][0] = context.getParameter(name);
+        if (paramDerivs.find(name) != paramDerivs.end())
+            globalTensors[name].set_requires_grad(true);
+        inputs.push_back(globalTensors[name]);
+    }
 }

 /**
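The requires_grad toggling in this hunk is what later lets backward() produce dE/d(parameter) alongside the forces. A minimal, self-contained libtorch sketch of that mechanism, independent of OpenMM; the quadratic energy is a stand-in for the TorchScript model:

// Standalone libtorch sketch: a parameter tensor with requires_grad=true picks up
// a gradient when the scalar energy is backpropagated, just like globalTensors[name].
#include <torch/torch.h>
#include <iostream>

int main() {
    torch::Tensor pos = torch::randn({10, 3}, torch::requires_grad());
    torch::Tensor k = torch::tensor({2.0}, torch::requires_grad());
    torch::Tensor energy = k[0] * pos.pow(2).sum();   // toy model: E = k * sum(r^2)
    energy.backward();
    std::cout << "dE/dk  = " << k.grad().item<double>() << "\n";  // equals sum(r^2)
    std::cout << "-dE/dr =\n" << -pos.grad() << "\n";             // the forces
}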
@@ -173,26 +184,34 @@ void CudaCalcTorchForceKernel::addForces(torch::Tensor& forceTensor) {
 }

 /**
- * This function launches the workload in a way compatible with CUDA
- * graphs as far as OpenMM-Torch goes. Capturing this function when
- * the model is not itself graph compatible (due to, for instance,
+ * This function launches the workload in a way compatible with CUDA
+ * graphs as far as OpenMM-Torch goes. Capturing this function when
+ * the model is not itself graph compatible (due to, for instance,
  * implicit synchronizations) will result in a CUDA error.
  */
 static void executeGraph(bool outputsForces, bool includeForces, torch::jit::script::Module& module, vector<torch::jit::IValue>& inputs, torch::Tensor& posTensor, torch::Tensor& energyTensor,
-                         torch::Tensor& forceTensor) {
+                         torch::Tensor& forceTensor, map<string, torch::Tensor>& globalTensors, set<string> paramDerivs) {
+    vector<torch::Tensor> gradInputs;
+    if (!outputsForces && includeForces)
+        gradInputs.push_back(posTensor);
+    for (auto& name : paramDerivs)
+        gradInputs.push_back(globalTensors[name]);
+    auto none = torch::Tensor();
     if (outputsForces) {
         auto outputs = module.forward(inputs).toTuple();
         energyTensor = outputs->elements()[0].toTensor();
         forceTensor = outputs->elements()[1].toTensor();
+        if (gradInputs.size() > 0)
+            energyTensor.backward(none, false, false, gradInputs);
     } else {
         energyTensor = module.forward(inputs).toTensor();
         // Compute force by backpropagating the PyTorch model
+        // CUDA graph capture sometimes fails if backwards is not explicitly requested w.r.t positions
+        // See https://github.com/openmm/openmm-torch/pull/120/
+        if (gradInputs.size() > 0)
+            energyTensor.backward(none, false, false, gradInputs);
         if (includeForces) {
-            // CUDA graph capture sometimes fails if backwards is not explicitly requested w.r.t positions
-            // See https://github.com/openmm/openmm-torch/pull/120/
-            auto none = torch::Tensor();
-            energyTensor.backward(none, false, false, posTensor);
-            // This is minus the forces, we change the sign later on
+            // This is minus the forces, we change the sign later on
             forceTensor = posTensor.grad().clone();
             // Zero the gradient to avoid accumulating it
             posTensor.grad().zero_();
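The four-argument backward() call introduced above uses the overload whose last argument lists the inputs to differentiate with respect to, so gradients are materialized only for the position tensor and the flagged parameter tensors. A hedged standalone sketch of that overload:

// Sketch of Tensor::backward(gradient, retain_graph, create_graph, inputs):
// only the tensors listed in gradInputs receive a .grad().
#include <torch/torch.h>

int main() {
    torch::Tensor pos = torch::randn({5, 3}, torch::requires_grad());
    torch::Tensor k = torch::tensor({1.5}, torch::requires_grad());
    torch::Tensor energy = (k[0] * pos.pow(2)).sum();

    std::vector<torch::Tensor> gradInputs = {pos, k};
    auto none = torch::Tensor();                        // no explicit output gradient
    energy.backward(none, /*retain_graph=*/false, /*create_graph=*/false, gradInputs);

    torch::Tensor forces = -pos.grad();                 // -dE/dr
    double dEdk = k.grad().item<double>();              // dE/dk
    return 0;
}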
@@ -203,31 +222,32 @@ static void executeGraph(bool outputsForces, bool includeForces, torch::jit::scr
 double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
     // Push to the PyTorch context
     CHECK_RESULT(cuCtxPushCurrent(primaryContext), "Failed to push the CUDA context");
-    auto inputs = prepareTorchInputs(context);
+    vector<torch::jit::IValue> inputs;
+    prepareTorchInputs(context, inputs, globalTensors);
     if (!useGraphs) {
-        executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor);
+        executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor, globalTensors, paramDerivs);
     } else {
         // Record graph if not already done
         bool is_graph_captured = false;
         if (graphs.find(includeForces) == graphs.end()) {
             // CUDA graph capture must occur in a non-default stream
             const auto stream = c10::cuda::getStreamFromPool(false, cu.getDeviceIndex());
-            const c10::cuda::CUDAStreamGuard guard(stream);
+            const c10::cuda::CUDAStreamGuard guard(stream);
             // Warmup the graph workload before capturing. This first
             // run before capture sets up allocations so that no
             // allocations are needed after. Pytorch's allocator is
             // stream capture-aware and, after warmup, will provide
             // record static pointers and shapes during capture.
             try {
                 for (int i = 0; i < this->warmupSteps; i++)
-                    executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor);
+                    executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor, globalTensors, paramDerivs);
             }
             catch (std::exception& e) {
                 throw OpenMMException(string("TorchForce Failed to warmup the model before graph construction. Torch reported the following error:\n") + e.what());
             }
             graphs[includeForces].capture_begin();
             try {
-                executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor);
+                executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor, globalTensors, paramDerivs);
                 is_graph_captured = true;
                 graphs[includeForces].capture_end();
             }
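For readers unfamiliar with the pattern being extended here: torch CUDA graphs are warmed up on a side stream so the caching allocator settles, then captured once, then replayed on the simulation stream. A hedged, OpenMM-free sketch of that flow (assumes libtorch built with CUDA; the lambda stands in for the model call):

// Standalone sketch of the warmup -> capture -> replay pattern used in execute().
#include <torch/torch.h>
#include <ATen/cuda/CUDAGraph.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAGuard.h>

int main() {
    torch::Tensor x = torch::randn({1024}, torch::kCUDA);
    torch::Tensor y = torch::empty_like(x);
    auto workload = [&] { y = x.sin().cos(); };    // stand-in for module.forward(...)

    at::cuda::CUDAGraph graph;
    {
        // Capture must happen on a non-default stream.
        auto stream = c10::cuda::getStreamFromPool(false);
        c10::cuda::CUDAStreamGuard guard(stream);
        for (int i = 0; i < 3; i++)                // warmup so allocations are in place
            workload();
        graph.capture_begin();
        workload();
        graph.capture_end();
    }
    graph.replay();                                // reuses the recorded kernels and static buffers
}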
@@ -237,16 +257,23 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce
                 }
                 throw OpenMMException(string("TorchForce Failed to capture the model into a CUDA graph. Torch reported the following error:\n") + e.what());
             }
+            for (const string& name : paramDerivs)
+                globalTensors[name].grad().zero_();
         }
-        // Use the same stream as the OpenMM context, even if it is the default stream
+        // Use the same stream as the OpenMM context, even if it is the default stream
         const auto openmmStream = cu.getCurrentStream();
-        const auto stream = c10::cuda::getStreamFromExternal(openmmStream, cu.getDeviceIndex());
-        const c10::cuda::CUDAStreamGuard guard(stream);
+        const auto stream = c10::cuda::getStreamFromExternal(openmmStream, cu.getDeviceIndex());
+        const c10::cuda::CUDAStreamGuard guard(stream);
         graphs[includeForces].replay();
     }
     if (includeForces) {
         addForces(forceTensor);
     }
+    map<string, double>& energyParamDerivs = cu.getEnergyParamDerivWorkspace();
+    for (const string& name : paramDerivs) {
+        energyParamDerivs[name] += globalTensors[name].grad().item<double>();
+        globalTensors[name].grad().zero_();
+    }
     // Get energy
     const double energy = energyTensor.item<double>(); // This implicitly synchronizes the PyTorch context
     // Pop to the PyTorch context
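Downstream of this hunk, the values added to getEnergyParamDerivWorkspace() surface through OpenMM's normal State machinery. A hedged usage sketch, with "scale" again being a hypothetical global parameter name:

// Sketch: reading the accumulated dE/d(parameter) values back out of a Context.
#include "OpenMM.h"
#include <iostream>

void printParameterDerivatives(OpenMM::Context& context) {
    OpenMM::State state = context.getState(OpenMM::State::ParameterDerivatives);
    const std::map<std::string, double>& derivs = state.getEnergyParameterDerivatives();
    std::cout << "dE/d(scale) = " << derivs.at("scale") << std::endl;
}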