Skip to content

Commit 611cff0

Browse files
committed
added docstring to AbstractWrapper.forward
1 parent e0a9abd commit 611cff0

File tree

1 file changed

+10
-1
lines changed
  • equitrain/backends/torch_wrappers

1 file changed

+10
-1
lines changed

equitrain/backends/torch_wrappers/base.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,16 @@ def __init__(self, model):
1818

1919
@abstractmethod
2020
def forward(self, *args):
21-
"""Implement the model forward pass."""
21+
"""Implement the model forward pass.
22+
23+
Returns:
24+
Mapping with at least an ``'energy'`` entry (shape ``[batch, 1]`` or
25+
``[batch]``). Wrappers that produce forces or stresses should also
26+
return ``'forces'`` (``[num_atoms, 3]``) and ``'stress'``
27+
(``[batch, 3, 3]``). Additional observables (dipoles, virials, etc.)
28+
can be included as extra keys; they are forwarded to the loss /
29+
metrics stack unchanged.
30+
"""
2231
raise NotImplementedError
2332

2433
@property

0 commit comments

Comments
 (0)