Hi,
I was trying to add a global parameter to my `TorchForce`, but I could not figure out how to make it work together with periodic boundary conditions (PBC). Could you take a look?
Here is the structure of the module I wrap in `TorchForce`:

```python
import torch

class ForceModule(torch.nn.Module):
    def __init__(self):
        ...

    def forward(self, coordinates: torch.Tensor, cell: torch.Tensor, scale: float, RcNL: float = 10.0):
        return energy
```
I then load the model with:

```python
from openmmtorch import TorchForce

force = TorchForce('model.pt')
force.setUsesPeriodicBoundaryConditions(True)
force.addGlobalParameter('scale', 0.9)
```

When I run an MD simulation with it, I get the following error:
```text
Traceback (most recent call last):
  File "/home/xie1/torchSANI_jit/scripts/openmm_test_3.py", line 128, in <module>
    state = simulation.context.getState(getEnergy=True, groups={10})
  File "/home/xie1/psi4conda/envs/nnpops_env/lib/python3.10/site-packages/openmm/openmm.py", line 12111, in getState
    state = _openmm.Context_getState(self, types, enforcePeriodicBox, groups_mask)
openmm.OpenMMException: forward() Expected a value of type 'float' for argument 'scale' but instead found type 'Tensor'.
Position: 3
Declaration: forward(__torch__.___torch_mangle_0.ForceModule self, Tensor coordinates, Tensor cell, float scale, float RcNL=10.) -> Tensor
```
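The message says the value of `scale` reaches `forward()` as a `Tensor`, while the TorchScript declaration expects a `float`. Below is a minimal sketch of a signature that matches that, assuming `TorchForce` passes each global parameter as a scalar tensor after `coordinates` and `cell`, as the error suggests. `ScaledForceModule`, `round_trip`, and the toy energy are hypothetical placeholders, not the model from this report. The round trip through an in-memory buffer only mimics writing `model.pt` and loading it back.

```python
import io

import torch


class ScaledForceModule(torch.nn.Module):
    """Hypothetical stand-in for ForceModule; the energy is a toy placeholder."""

    def forward(
        self,
        coordinates: torch.Tensor,
        cell: torch.Tensor,
        scale: torch.Tensor,  # global parameter annotated as a tensor, not a float
        RcNL: float = 10.0,
    ) -> torch.Tensor:
        # Toy energy: sum of the coordinates times the scale. `cell` and `RcNL`
        # are unused here and only keep the argument order from the report.
        return scale * coordinates.sum()


def round_trip(module: torch.nn.Module) -> torch.jit.ScriptModule:
    """Compile with TorchScript, serialize to memory, and load it back."""
    buffer = io.BytesIO()
    torch.jit.save(torch.jit.script(module), buffer)
    buffer.seek(0)
    return torch.jit.load(buffer)


if __name__ == "__main__":
    model = round_trip(ScaledForceModule())
    coords = torch.ones(4, 3)
    cell = torch.eye(3) * 2.0
    scale = torch.tensor(0.9)  # 0-d tensor standing in for the global parameter value
    print(model(coords, cell, scale).item())
```

Because `scale` stays a tensor, it can take part in the tensor arithmetic inside `forward()` directly.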
Thank you so much!