Skip to content

Commit

Permalink
Merge pull request #143 from peastman/derivatives
Browse files Browse the repository at this point in the history
Can compute parameter derivatives
  • Loading branch information
RaulPPelaez authored May 22, 2024
2 parents 4fe23f6 + 1d83b86 commit e17b5d0
Show file tree
Hide file tree
Showing 14 changed files with 219 additions and 52 deletions.
24 changes: 23 additions & 1 deletion openmmapi/include/TorchForce.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2018-2022 Stanford University and the Authors. *
* Portions copyright (c) 2018-2024 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: Raimondas Galvelis, Raul P. Pelaez *
* *
Expand All @@ -36,6 +36,7 @@
#include "openmm/Force.h"
#include <map>
#include <string>
#include <vector>
#include <torch/torch.h>
#include "internal/windowsExportTorch.h"

Expand Down Expand Up @@ -106,6 +107,11 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force {
* Get the number of global parameters that the interaction depends on.
*/
int getNumGlobalParameters() const;
/**
* Get the number of global parameters with respect to which the derivative of the energy
* should be computed.
*/
int getNumEnergyParameterDerivatives() const;
/**
* Add a new global parameter that the interaction may depend on. The default value provided to
* this method is the initial value of the parameter in newly created Contexts. You can change
Expand Down Expand Up @@ -144,6 +150,21 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force {
* @param defaultValue the default value of the parameter
*/
void setGlobalParameterDefaultValue(int index, double defaultValue);
/**
* Request that this Force compute the derivative of its energy with respect to a global parameter.
* The parameter must have already been added with addGlobalParameter().
*
* @param name the name of the parameter
*/
void addEnergyParameterDerivative(const std::string& name);
/**
* Get the name of a global parameter with respect to which this Force should compute the
* derivative of the energy.
*
* @param index the index of the parameter derivative, between 0 and getNumEnergyParameterDerivatives()
* @return the parameter name
*/
const std::string& getEnergyParameterDerivativeName(int index) const;
/**
* Set a value of a property.
*
Expand All @@ -163,6 +184,7 @@ class OPENMM_EXPORT_NN TorchForce : public OpenMM::Force {
std::string file;
bool usePeriodic, outputsForces;
std::vector<GlobalParameterInfo> globalParameters;
std::vector<int> energyParameterDerivatives;
torch::jit::Module module;
std::map<std::string, std::string> properties;
std::string emptyProperty;
Expand Down
20 changes: 19 additions & 1 deletion openmmapi/src/TorchForce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2018-2022 Stanford University and the Authors. *
* Portions copyright (c) 2018-2024 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: Raimondas Galvelis, Raul P. Pelaez *
* *
Expand Down Expand Up @@ -112,6 +112,24 @@ void TorchForce::setGlobalParameterDefaultValue(int index, double defaultValue)
globalParameters[index].defaultValue = defaultValue;
}

int TorchForce::getNumEnergyParameterDerivatives() const {
return energyParameterDerivatives.size();
}

void TorchForce::addEnergyParameterDerivative(const string& name) {
for (int i = 0; i < globalParameters.size(); i++)
if (name == globalParameters[i].name) {
energyParameterDerivatives.push_back(i);
return;
}
throw OpenMMException(string("addEnergyParameterDerivative: Unknown global parameter '"+name+"'"));
}

const string& TorchForce::getEnergyParameterDerivativeName(int index) const {
ASSERT_VALID_INDEX(index, energyParameterDerivatives);
return globalParameters[energyParameterDerivatives[index]].name;
}

void TorchForce::setProperty(const std::string& name, const std::string& value) {
if (properties.find(name) == properties.end())
throw OpenMMException("TorchForce: Unknown property '" + name + "'");
Expand Down
73 changes: 50 additions & 23 deletions platforms/cuda/src/CudaTorchKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2018-2022 Stanford University and the Authors. *
* Portions copyright (c) 2018-2024 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: Raimondas Galvelis, Raul P. Pelaez *
* *
Expand Down Expand Up @@ -66,6 +66,10 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce
outputsForces = force.getOutputsForces();
for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalNames.push_back(force.getGlobalParameterName(i));
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
paramDerivs.insert(force.getEnergyParameterDerivativeName(i));
cu.addEnergyParameterDerivative(force.getEnergyParameterDerivativeName(i));
}
int numParticles = system.getNumParticles();

// Push the PyTorch context
Expand All @@ -81,6 +85,8 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce
boxTensor = torch::empty({3, 3}, options);
energyTensor = torch::empty({0}, options);
forceTensor = torch::empty({0}, options);
for (const string& name : globalNames)
globalTensors[name] = torch::tensor({0}, options);
// Pop the PyToch context
CUcontext ctx;
CHECK_RESULT(cuCtxPopCurrent(&ctx), "Failed to pop the CUDA context");
Expand Down Expand Up @@ -125,7 +131,7 @@ static void* getTensorPointer(OpenMM::CudaContext& cu, torch::Tensor& tensor) {
/**
* Prepare the inputs for the PyTorch model, copying positions from the OpenMM context.
*/
std::vector<torch::jit::IValue> CudaCalcTorchForceKernel::prepareTorchInputs(ContextImpl& context) {
void CudaCalcTorchForceKernel::prepareTorchInputs(ContextImpl& context, vector<torch::jit::IValue>& inputs, map<string, torch::Tensor>& globalTensors) {
int numParticles = cu.getNumAtoms();
// Get pointers to the atomic positions and simulation box
void* posData = getTensorPointer(cu, posTensor);
Expand All @@ -145,12 +151,17 @@ std::vector<torch::jit::IValue> CudaCalcTorchForceKernel::prepareTorchInputs(Con
CHECK_RESULT(cuCtxSynchronize(), "Failed to synchronize the CUDA context"); // Synchronize before switching to the PyTorch context
}
// Prepare the input of the PyTorch model
vector<torch::jit::IValue> inputs = {posTensor};
inputs = {posTensor};
if (usePeriodic)
inputs.push_back(boxTensor);
for (const string& name : globalNames)
inputs.push_back(torch::tensor(context.getParameter(name)));
return inputs;
for (const string& name : globalNames) {
// PyTorch requires us to set requires_grad to false before initializing a tensor.
globalTensors[name].set_requires_grad(false);
globalTensors[name][0] = context.getParameter(name);
if (paramDerivs.find(name) != paramDerivs.end())
globalTensors[name].set_requires_grad(true);
inputs.push_back(globalTensors[name]);
}
}

/**
Expand All @@ -173,26 +184,34 @@ void CudaCalcTorchForceKernel::addForces(torch::Tensor& forceTensor) {
}

/**
* This function launches the workload in a way compatible with CUDA
* graphs as far as OpenMM-Torch goes. Capturing this function when
* the model is not itself graph compatible (due to, for instance,
* This function launches the workload in a way compatible with CUDA
* graphs as far as OpenMM-Torch goes. Capturing this function when
* the model is not itself graph compatible (due to, for instance,
* implicit synchronizations) will result in a CUDA error.
*/
static void executeGraph(bool outputsForces, bool includeForces, torch::jit::script::Module& module, vector<torch::jit::IValue>& inputs, torch::Tensor& posTensor, torch::Tensor& energyTensor,
torch::Tensor& forceTensor) {
torch::Tensor& forceTensor, map<string, torch::Tensor>& globalTensors, set<string> paramDerivs) {
vector<torch::Tensor> gradInputs;
if (!outputsForces && includeForces)
gradInputs.push_back(posTensor);
for (auto& name : paramDerivs)
gradInputs.push_back(globalTensors[name]);
auto none = torch::Tensor();
if (outputsForces) {
auto outputs = module.forward(inputs).toTuple();
energyTensor = outputs->elements()[0].toTensor();
forceTensor = outputs->elements()[1].toTensor();
if (gradInputs.size() > 0)
energyTensor.backward(none, false, false, gradInputs);
} else {
energyTensor = module.forward(inputs).toTensor();
// Compute force by backpropagating the PyTorch model
// CUDA graph capture sometimes fails if backwards is not explicitly requested w.r.t positions
// See https://github.com/openmm/openmm-torch/pull/120/
if (gradInputs.size() > 0)
energyTensor.backward(none, false, false, gradInputs);
if (includeForces) {
// CUDA graph capture sometimes fails if backwards is not explicitly requested w.r.t positions
// See https://github.com/openmm/openmm-torch/pull/120/
auto none = torch::Tensor();
energyTensor.backward(none, false, false, posTensor);
// This is minus the forces, we change the sign later on
// This is minus the forces, we change the sign later on
forceTensor = posTensor.grad().clone();
// Zero the gradient to avoid accumulating it
posTensor.grad().zero_();
Expand All @@ -203,31 +222,32 @@ static void executeGraph(bool outputsForces, bool includeForces, torch::jit::scr
double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
// Push to the PyTorch context
CHECK_RESULT(cuCtxPushCurrent(primaryContext), "Failed to push the CUDA context");
auto inputs = prepareTorchInputs(context);
vector<torch::jit::IValue> inputs;
prepareTorchInputs(context, inputs, globalTensors);
if (!useGraphs) {
executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor);
executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor, globalTensors, paramDerivs);
} else {
// Record graph if not already done
bool is_graph_captured = false;
if (graphs.find(includeForces) == graphs.end()) {
//CUDA graph capture must occur in a non-default stream
const auto stream = c10::cuda::getStreamFromPool(false, cu.getDeviceIndex());
const c10::cuda::CUDAStreamGuard guard(stream);
const c10::cuda::CUDAStreamGuard guard(stream);
// Warmup the graph workload before capturing. This first
// run before capture sets up allocations so that no
// allocations are needed after. Pytorch's allocator is
// stream capture-aware and, after warmup, will provide
// record static pointers and shapes during capture.
try {
for (int i = 0; i < this->warmupSteps; i++)
executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor);
executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor, globalTensors, paramDerivs);
}
catch (std::exception& e) {
throw OpenMMException(string("TorchForce Failed to warmup the model before graph construction. Torch reported the following error:\n") + e.what());
}
graphs[includeForces].capture_begin();
try {
executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor);
executeGraph(outputsForces, includeForces, module, inputs, posTensor, energyTensor, forceTensor, globalTensors, paramDerivs);
is_graph_captured = true;
graphs[includeForces].capture_end();
}
Expand All @@ -237,16 +257,23 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce
}
throw OpenMMException(string("TorchForce Failed to capture the model into a CUDA graph. Torch reported the following error:\n") + e.what());
}
for (const string& name : paramDerivs)
globalTensors[name].grad().zero_();
}
// Use the same stream as the OpenMM context, even if it is the default stream
// Use the same stream as the OpenMM context, even if it is the default stream
const auto openmmStream = cu.getCurrentStream();
const auto stream = c10::cuda::getStreamFromExternal(openmmStream, cu.getDeviceIndex());
const c10::cuda::CUDAStreamGuard guard(stream);
const auto stream = c10::cuda::getStreamFromExternal(openmmStream, cu.getDeviceIndex());
const c10::cuda::CUDAStreamGuard guard(stream);
graphs[includeForces].replay();
}
if (includeForces) {
addForces(forceTensor);
}
map<string, double>& energyParamDerivs = cu.getEnergyParamDerivWorkspace();
for (const string& name : paramDerivs) {
energyParamDerivs[name] += globalTensors[name].grad().item<double>();
globalTensors[name].grad().zero_();
}
// Get energy
const double energy = energyTensor.item<double>(); // This implicitly synchronizes the PyTorch context
// Pop to the PyTorch context
Expand Down
7 changes: 5 additions & 2 deletions platforms/cuda/src/CudaTorchKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2018-2022 Stanford University and the Authors. *
* Portions copyright (c) 2018-2024 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: Raimondas Galvelis, Raul P. Pelaez *
* *
Expand Down Expand Up @@ -37,6 +37,7 @@
#include "openmm/cuda/CudaArray.h"
#include <torch/version.h>
#include <ATen/cuda/CUDAGraph.h>
#include <set>

namespace TorchPlugin {

Expand Down Expand Up @@ -71,12 +72,14 @@ class CudaCalcTorchForceKernel : public CalcTorchForceKernel {
torch::jit::script::Module module;
torch::Tensor posTensor, boxTensor;
torch::Tensor energyTensor, forceTensor;
std::map<std::string, torch::Tensor> globalTensors;
std::vector<std::string> globalNames;
std::set<std::string> paramDerivs;
bool usePeriodic, outputsForces;
CUfunction copyInputsKernel, addForcesKernel;
CUcontext primaryContext;
std::map<bool, at::cuda::CUDAGraph> graphs;
std::vector<torch::jit::IValue> prepareTorchInputs(OpenMM::ContextImpl& context);
void prepareTorchInputs(OpenMM::ContextImpl& context, std::vector<torch::jit::IValue>& inputs, std::map<std::string, torch::Tensor>& derivInputs);
bool useGraphs;
void addForces(torch::Tensor& forceTensor);
int warmupSteps;
Expand Down
24 changes: 19 additions & 5 deletions platforms/cuda/tests/TestCudaTorchForce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2018-2022 Stanford University and the Authors. *
* Portions copyright (c) 2018-2024 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
Expand Down Expand Up @@ -129,7 +129,7 @@ void testPeriodicForce() {
ASSERT_EQUAL_TOL(expectedEnergy, state.getPotentialEnergy(), 1e-5);
}

void testGlobal() {
void testGlobal(bool useGraphs) {
// Create a random cloud of particles.

const int numParticles = 10;
Expand All @@ -143,6 +143,8 @@ void testGlobal() {
}
TorchForce* force = new TorchForce("tests/global.pt");
force->addGlobalParameter("k", 2.0);
force->addEnergyParameterDerivative("k");
force->setProperty("useCUDAGraphs", useGraphs ? "true" : "false");
system.addForce(force);

// Compute the forces and energy.
Expand All @@ -151,7 +153,7 @@ void testGlobal() {
Platform& platform = Platform::getPlatformByName("CUDA");
Context context(system, integ, platform);
context.setPositions(positions);
State state = context.getState(State::Energy | State::Forces);
State state = context.getState(State::Energy | State::Forces | State::ParameterDerivatives);

// See if the energy is correct. The network defines a potential of the form E(r) = k*|r|^2

Expand All @@ -164,15 +166,26 @@ void testGlobal() {
}
ASSERT_EQUAL_TOL(expectedEnergy, state.getPotentialEnergy(), 1e-5);

// Check the gradient of the energy with respect to the parameter.

double expected = 0.0;
for (int i = 0; i < numParticles; i++) {
Vec3 pos = positions[i];
expected += pos.dot(pos);
}
double actual = state.getEnergyParameterDerivatives().at("k");
ASSERT_EQUAL_TOL(expected, actual, 1e-5);

// Change the global parameter and see if the forces are still correct.

context.setParameter("k", 3.0);
state = context.getState(State::Forces);
state = context.getState(State::Forces | State::ParameterDerivatives);
for (int i = 0; i < numParticles; i++) {
Vec3 pos = positions[i];
double r = sqrt(pos.dot(pos));
ASSERT_EQUAL_VEC(pos*(-6.0), state.getForces()[i], 1e-5);
}
ASSERT_EQUAL_TOL(expected, state.getEnergyParameterDerivatives().at("k"), 1e-5);
}

int main(int argc, char* argv[]) {
Expand All @@ -183,7 +196,8 @@ int main(int argc, char* argv[]) {
testForce(false);
testForce(true);
testPeriodicForce();
testGlobal();
testGlobal(false);
testGlobal(true);
}
catch(const std::exception& e) {
std::cout << "exception: " << e.what() << std::endl;
Expand Down
Loading

0 comments on commit e17b5d0

Please sign in to comment.