Skip to content

Commit e17b5d0

Browse files
authored
Merge pull request #143 from peastman/derivatives
Can compute parameter derivatives
2 parents 4fe23f6 + 1d83b86 commit e17b5d0

File tree

14 files changed

+219
-52
lines changed

14 files changed

+219
-52
lines changed

openmmapi/include/TorchForce.h

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
* Biological Structures at Stanford, funded under the NIH Roadmap for *
1010
* Medical Research, grant U54 GM072970. See https://simtk.org. *
1111
* *
12-
* Portions copyright (c) 2018-2022 Stanford University and the Authors. *
12+
* Portions copyright (c) 2018-2024 Stanford University and the Authors. *
1313
* Authors: Peter Eastman *
1414
* Contributors: Raimondas Galvelis, Raul P. Pelaez *
1515
* *
@@ -36,6 +36,7 @@
3636
#include "openmm/Force.h"
3737
#include <map>
3838
#include <string>
39+
#include <vector>
3940
#include <torch/torch.h>
4041
#include "internal/windowsExportTorch.h"
4142

@@ -106,6 +107,11 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force {
106107
* Get the number of global parameters that the interaction depends on.
107108
*/
108109
int getNumGlobalParameters() const;
110+
/**
111+
* Get the number of global parameters with respect to which the derivative of the energy
112+
* should be computed.
113+
*/
114+
int getNumEnergyParameterDerivatives() const;
109115
/**
110116
* Add a new global parameter that the interaction may depend on. The default value provided to
111117
* this method is the initial value of the parameter in newly created Contexts. You can change
@@ -144,6 +150,21 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force {
144150
* @param defaultValue the default value of the parameter
145151
*/
146152
void setGlobalParameterDefaultValue(int index, double defaultValue);
153+
/**
154+
* Request that this Force compute the derivative of its energy with respect to a global parameter.
155+
* The parameter must have already been added with addGlobalParameter().
156+
*
157+
* @param name the name of the parameter
158+
*/
159+
void addEnergyParameterDerivative(const std::string& name);
160+
/**
161+
* Get the name of a global parameter with respect to which this Force should compute the
162+
* derivative of the energy.
163+
*
164+
* @param index the index of the parameter derivative, between 0 and getNumEnergyParameterDerivatives()
165+
* @return the parameter name
166+
*/
167+
const std::string& getEnergyParameterDerivativeName(int index) const;
147168
/**
148169
* Set a value of a property.
149170
*
@@ -163,6 +184,7 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force {
163184
std::string file;
164185
bool usePeriodic, outputsForces;
165186
std::vector<GlobalParameterInfo> globalParameters;
187+
std::vector<int> energyParameterDerivatives;
166188
torch::jit::Module module;
167189
std::map<std::string, std::string> properties;
168190
std::string emptyProperty;

openmmapi/src/TorchForce.cpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
* Biological Structures at Stanford, funded under the NIH Roadmap for *
77
* Medical Research, grant U54 GM072970. See https://simtk.org. *
88
* *
9-
* Portions copyright (c) 2018-2022 Stanford University and the Authors. *
9+
* Portions copyright (c) 2018-2024 Stanford University and the Authors. *
1010
* Authors: Peter Eastman *
1111
* Contributors: Raimondas Galvelis, Raul P. Pelaez *
1212
* *
@@ -112,6 +112,24 @@ void TorchForce::setGlobalParameterDefaultValue(int index, double defaultValue)
112112
globalParameters[index].defaultValue = defaultValue;
113113
}
114114

115+
int TorchForce::getNumEnergyParameterDerivatives() const {
116+
return energyParameterDerivatives.size();
117+
}
118+
119+
void TorchForce::addEnergyParameterDerivative(const string& name) {
120+
for (int i = 0; i < globalParameters.size(); i++)
121+
if (name == globalParameters[i].name) {
122+
energyParameterDerivatives.push_back(i);
123+
return;
124+
}
125+
throw OpenMMException(string("addEnergyParameterDerivative: Unknown global parameter '"+name+"'"));
126+
}
127+
128+
const string& TorchForce::getEnergyParameterDerivativeName(int index) const {
129+
ASSERT_VALID_INDEX(index, energyParameterDerivatives);
130+
return globalParameters[energyParameterDerivatives[index]].name;
131+
}
132+
115133
void TorchForce::setProperty(const std::string& name, const std::string& value) {
116134
if (properties.find(name) == properties.end())
117135
throw OpenMMException("TorchForce: Unknown property '" + name + "'");

platforms/cuda/src/CudaTorchKernels.cpp

Lines changed: 50 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
* Biological Structures at Stanford, funded under the NIH Roadmap for *
77
* Medical Research, grant U54 GM072970. See https://simtk.org. *
88
* *
9-
* Portions copyright (c) 2018-2022 Stanford University and the Authors. *
9+
* Portions copyright (c) 2018-2024 Stanford University and the Authors. *
1010
* Authors: Peter Eastman *
1111
* Contributors: Raimondas Galvelis, Raul P. Pelaez *
1212
* *
@@ -66,6 +66,10 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce
6666
outputsForces = force.getOutputsForces();
6767
for (int i = 0; i < force.getNumGlobalParameters(); i++)
6868
globalNames.push_back(force.getGlobalParameterName(i));
69+
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
70+
paramDerivs.insert(force.getEnergyParameterDerivativeName(i));
71+
cu.addEnergyParameterDerivative(force.getEnergyParameterDerivativeName(i));
72+
}
6973
int numParticles = system.getNumParticles();
7074

7175
// Push the PyTorch context
@@ -81,6 +85,8 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce
8185
boxTensor = torch::empty({3, 3}, options);
8286
energyTensor = torch::empty({0}, options);
8387
forceTensor = torch::empty({0}, options);
88+
for (const string& name : globalNames)
89+
globalTensors[name] = torch::tensor({0}, options);
8490
// Pop the PyToch context
8591
CUcontext ctx;
8692
CHECK_RESULT(cuCtxPopCurrent(&ctx), "Failed to pop the CUDA context");
@@ -125,7 +131,7 @@ static void* getTensorPointer(OpenMM::CudaContext& cu, torch::Tensor& tensor) {
125131
/**
126132
* Prepare the inputs for the PyTorch model, copying positions from the OpenMM context.
127133
*/
128-
std::vector<torch::jit::IValue> CudaCalcTorchForceKernel::prepareTorchInputs(ContextImpl& context) {
134+
void CudaCalcTorchForceKernel::prepareTorchInputs(ContextImpl& context, vector<torch::jit::IValue>& inputs, map<string, torch::Tensor>& globalTensors) {
129135
int numParticles = cu.getNumAtoms();
130136
// Get pointers to the atomic positions and simulation box
131137
void* posData = getTensorPointer(cu, posTensor);
@@ -145,12 +151,17 @@ std::vector<torch::jit::IValue> CudaCalcTorchForceKernel::prepareTorchInputs(Con
145151
CHECK_RESULT(cuCtxSynchronize(), "Failed to synchronize the CUDA context"); // Synchronize before switching to the PyTorch context
146152
}
147153
// Prepare the input of the PyTorch model
148-
vector<torch::jit::IValue> inputs = {posTensor};
154+
inputs = {posTensor};
149155
if (usePeriodic)
150156
inputs.push_back(boxTensor);
151-
for (const string& name : globalNames)
152-
inputs.push_back(torch::tensor(context.getParameter(name)));
153-
return inputs;
157+
for (const string& name : globalNames) {
158+
// PyTorch requires us to set requires_grad to false before initializing a tensor.
159+
globalTensors[name].set_requires_grad(false);
160+
globalTensors[name][0] = context.getParameter(name);
161+
if (paramDerivs.find(name) != paramDerivs.end())
162+
globalTensors[name].set_requires_grad(true);
163+
inputs.push_back(globalTensors[name]);
164+
}
154165
}
155166

156167
/**
@@ -173,26 +184,34 @@ void CudaCalcTorchForceKernel::addForces(torch::Tensor& forceTensor) {
173184
}
174185

175186
/**
176-
* This function launches the workload in a way compatible with CUDA
177-
* graphs as far as OpenMM-Torch goes. Capturing this function when
178-
* the model is not itself graph compatible (due to, for instance,
187+
* This function launches the workload in a way compatible with CUDA
188+
* graphs as far as OpenMM-Torch goes. Capturing this function when
189+
* the model is not itself graph compatible (due to, for instance,
179190
* implicit synchronizations) will result in a CUDA error.
180191
*/
181192
static void executeGraph(bool outputsForces, bool includeForces, torch::jit::script::Module& module, vector<torch::jit::IValue>& inputs, torch::Tensor& posTensor, torch::Tensor& energyTensor,
182-
torch::Tensor& forceTensor) {
193+
torch::Tensor& forceTensor, map<string, torch::Tensor>& globalTensors, set<string> paramDerivs) {
194+
vector<torch::Tensor> gradInputs;
195+
if (!outputsForces && includeForces)
196+
gradInputs.push_back(posTensor);
197+
for (auto& name : paramDerivs)
198+
gradInputs.push_back(globalTensors[name]);
199+
auto none = torch::Tensor();
183200
if (outputsForces) {
184201
auto outputs = module.forward(inputs).toTuple();
185202
energyTensor = outputs->elements()[0].toTensor();
186203
forceTensor = outputs->elements()[1].toTensor();
204+
if (gradInputs.size() > 0)
205+
energyTensor.backward(none, false, false, gradInputs);
187206
} else {
188207
energyTensor = module.forward(inputs).toTensor();
189208
// Compute force by backpropagating the PyTorch model
209+
// CUDA graph capture sometimes fails if backwards is not explicitly requested w.r.t positions
210+
// See https://github.com/openmm/openmm-torch/pull/120/
211+
if (gradInputs.size() > 0)
212+
energyTensor.backward(none, false, false, gradInputs);
190213
if (includeForces) {
191-
// CUDA graph capture sometimes fails if backwards is not explicitly requested w.r.t positions
192-
// See https://github.com/openmm/openmm-torch/pull/120/
193-
auto none = torch::Tensor();
194-
energyTensor.backward(none, false, false, posTensor);
195-
// This is minus the forces, we change the sign later on
214+
// This is minus the forces, we change the sign later on
196215
forceTensor = posTensor.grad().clone();
197216
// Zero the gradient to avoid accumulating it
198217
posTensor.grad().zero_();
@@ -203,31 +222,32 @@ static void executeGraph(bool outputsForces, bool includeForces, torch::jit::scr
203222
double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
204223
// Push to the PyTorch context
205224
CHECK_RESULT(cuCtxPushCurrent(primaryContext), "Failed to push the CUDA context");
206-
auto inputs = prepareTorchInputs(context);
225+
vector<torch::jit::IValue> inputs;
226+
prepareTorchInputs(context, inputs, globalTensors);
207227
if (!useGraphs) {
208-
executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor);
228+
executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor, globalTensors, paramDerivs);
209229
} else {
210230
// Record graph if not already done
211231
bool is_graph_captured = false;
212232
if (graphs.find(includeForces) == graphs.end()) {
213233
//CUDA graph capture must occur in a non-default stream
214234
const auto stream = c10::cuda::getStreamFromPool(false, cu.getDeviceIndex());
215-
const c10::cuda::CUDAStreamGuard guard(stream);
235+
const c10::cuda::CUDAStreamGuard guard(stream);
216236
// Warmup the graph workload before capturing. This first
217237
// run before capture sets up allocations so that no
218238
// allocations are needed after. Pytorch's allocator is
219239
// stream capture-aware and, after warmup, will provide
220240
// record static pointers and shapes during capture.
221241
try {
222242
for (int i = 0; i < this->warmupSteps; i++)
223-
executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor);
243+
executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor, globalTensors, paramDerivs);
224244
}
225245
catch (std::exception& e) {
226246
throw OpenMMException(string("TorchForce Failed to warmup the model before graph construction. Torch reported the following error:\n") + e.what());
227247
}
228248
graphs[includeForces].capture_begin();
229249
try {
230-
executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor);
250+
executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor, globalTensors, paramDerivs);
231251
is_graph_captured = true;
232252
graphs[includeForces].capture_end();
233253
}
@@ -237,16 +257,23 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce
237257
}
238258
throw OpenMMException(string("TorchForce Failed to capture the model into a CUDA graph. Torch reported the following error:\n") + e.what());
239259
}
260+
for (const string& name : paramDerivs)
261+
globalTensors[name].grad().zero_();
240262
}
241-
// Use the same stream as the OpenMM context, even if it is the default stream
263+
// Use the same stream as the OpenMM context, even if it is the default stream
242264
const auto openmmStream = cu.getCurrentStream();
243-
const auto stream = c10::cuda::getStreamFromExternal(openmmStream, cu.getDeviceIndex());
244-
const c10::cuda::CUDAStreamGuard guard(stream);
265+
const auto stream = c10::cuda::getStreamFromExternal(openmmStream, cu.getDeviceIndex());
266+
const c10::cuda::CUDAStreamGuard guard(stream);
245267
graphs[includeForces].replay();
246268
}
247269
if (includeForces) {
248270
addForces(forceTensor);
249271
}
272+
map<string, double>& energyParamDerivs = cu.getEnergyParamDerivWorkspace();
273+
for (const string& name : paramDerivs) {
274+
energyParamDerivs[name] += globalTensors[name].grad().item<double>();
275+
globalTensors[name].grad().zero_();
276+
}
250277
// Get energy
251278
const double energy = energyTensor.item<double>(); // This implicitly synchronizes the PyTorch context
252279
// Pop to the PyTorch context

platforms/cuda/src/CudaTorchKernels.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
* Biological Structures at Stanford, funded under the NIH Roadmap for *
1010
* Medical Research, grant U54 GM072970. See https://simtk.org. *
1111
* *
12-
* Portions copyright (c) 2018-2022 Stanford University and the Authors. *
12+
* Portions copyright (c) 2018-2024 Stanford University and the Authors. *
1313
* Authors: Peter Eastman *
1414
* Contributors: Raimondas Galvelis, Raul P. Pelaez *
1515
* *
@@ -37,6 +37,7 @@
3737
#include "openmm/cuda/CudaArray.h"
3838
#include <torch/version.h>
3939
#include <ATen/cuda/CUDAGraph.h>
40+
#include <set>
4041

4142
namespace TorchPlugin {
4243

@@ -71,12 +72,14 @@ class CudaCalcTorchForceKernel : public CalcTorchForceKernel {
7172
torch::jit::script::Module module;
7273
torch::Tensor posTensor, boxTensor;
7374
torch::Tensor energyTensor, forceTensor;
75+
std::map<std::string, torch::Tensor> globalTensors;
7476
std::vector<std::string> globalNames;
77+
std::set<std::string> paramDerivs;
7578
bool usePeriodic, outputsForces;
7679
CUfunction copyInputsKernel, addForcesKernel;
7780
CUcontext primaryContext;
7881
std::map<bool, at::cuda::CUDAGraph> graphs;
79-
std::vector<torch::jit::IValue> prepareTorchInputs(OpenMM::ContextImpl& context);
82+
void prepareTorchInputs(OpenMM::ContextImpl& context, std::vector<torch::jit::IValue>& inputs, std::map<std::string, torch::Tensor>& derivInputs);
8083
bool useGraphs;
8184
void addForces(torch::Tensor& forceTensor);
8285
int warmupSteps;

platforms/cuda/tests/TestCudaTorchForce.cpp

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
* Biological Structures at Stanford, funded under the NIH Roadmap for *
77
* Medical Research, grant U54 GM072970. See https://simtk.org. *
88
* *
9-
* Portions copyright (c) 2018-2022 Stanford University and the Authors. *
9+
* Portions copyright (c) 2018-2024 Stanford University and the Authors. *
1010
* Authors: Peter Eastman *
1111
* Contributors: *
1212
* *
@@ -129,7 +129,7 @@ void testPeriodicForce() {
129129
ASSERT_EQUAL_TOL(expectedEnergy, state.getPotentialEnergy(), 1e-5);
130130
}
131131

132-
void testGlobal() {
132+
void testGlobal(bool useGraphs) {
133133
// Create a random cloud of particles.
134134

135135
const int numParticles = 10;
@@ -143,6 +143,8 @@ void testGlobal() {
143143
}
144144
TorchForce* force = new TorchForce("tests/global.pt");
145145
force->addGlobalParameter("k", 2.0);
146+
force->addEnergyParameterDerivative("k");
147+
force->setProperty("useCUDAGraphs", useGraphs ? "true" : "false");
146148
system.addForce(force);
147149

148150
// Compute the forces and energy.
@@ -151,7 +153,7 @@ void testGlobal() {
151153
Platform& platform = Platform::getPlatformByName("CUDA");
152154
Context context(system, integ, platform);
153155
context.setPositions(positions);
154-
State state = context.getState(State::Energy | State::Forces);
156+
State state = context.getState(State::Energy | State::Forces | State::ParameterDerivatives);
155157

156158
// See if the energy is correct. The network defines a potential of the form E(r) = k*|r|^2
157159

@@ -164,15 +166,26 @@ void testGlobal() {
164166
}
165167
ASSERT_EQUAL_TOL(expectedEnergy, state.getPotentialEnergy(), 1e-5);
166168

169+
// Check the gradient of the energy with respect to the parameter.
170+
171+
double expected = 0.0;
172+
for (int i = 0; i < numParticles; i++) {
173+
Vec3 pos = positions[i];
174+
expected += pos.dot(pos);
175+
}
176+
double actual = state.getEnergyParameterDerivatives().at("k");
177+
ASSERT_EQUAL_TOL(expected, actual, 1e-5);
178+
167179
// Change the global parameter and see if the forces are still correct.
168180

169181
context.setParameter("k", 3.0);
170-
state = context.getState(State::Forces);
182+
state = context.getState(State::Forces | State::ParameterDerivatives);
171183
for (int i = 0; i < numParticles; i++) {
172184
Vec3 pos = positions[i];
173185
double r = sqrt(pos.dot(pos));
174186
ASSERT_EQUAL_VEC(pos*(-6.0), state.getForces()[i], 1e-5);
175187
}
188+
ASSERT_EQUAL_TOL(expected, state.getEnergyParameterDerivatives().at("k"), 1e-5);
176189
}
177190

178191
int main(int argc, char* argv[]) {
@@ -183,7 +196,8 @@ int main(int argc, char* argv[]) {
183196
testForce(false);
184197
testForce(true);
185198
testPeriodicForce();
186-
testGlobal();
199+
testGlobal(false);
200+
testGlobal(true);
187201
}
188202
catch(const std::exception& e) {
189203
std::cout << "exception: " << e.what() << std::endl;

0 commit comments

Comments
 (0)