Skip to content

Commit 248cedd

Browse files
committed
some more fixes for ONNX
1 parent e2a4f06 commit 248cedd

File tree

3 files changed

+111
-40
lines changed

3 files changed

+111
-40
lines changed

tests/test_onnx.py

Lines changed: 83 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,17 @@
1313

1414
@pytest.mark.skipif(not ONNXSCRIPT_AVAILABLE, reason="onnxscript not available")
1515
def test_onnx_export(tmp_path):
16-
from torchmdnet.models.model import create_model
16+
from torchmdnet.models.model import create_model, load_model
1717
from utils import load_example_args
1818
import torch as pt
19+
import numpy as np
1920

2021
device = "cuda" # "cuda" if pt.cuda.is_available() else "cpu"
2122

2223
ben = {
2324
"z": pt.tensor(
2425
[6, 6, 6, 6, 6, 6, 6, 7, 7, 1, 1, 1, 1, 1, 1, 1, 1, 1],
25-
dtype=pt.long,
26+
dtype=pt.int,
2627
device=device,
2728
),
2829
"pos": pt.tensor(
@@ -52,7 +53,7 @@ def test_onnx_export(tmp_path):
5253
),
5354
"batch": pt.tensor(
5455
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
55-
dtype=pt.long,
56+
dtype=pt.int,
5657
device=device,
5758
),
5859
"box": pt.tensor(
@@ -64,11 +65,11 @@ def test_onnx_export(tmp_path):
6465
dtype=pt.float32,
6566
device=device,
6667
),
67-
"q": pt.tensor([1], dtype=pt.long, device=device),
68+
"q": pt.tensor([1], dtype=pt.int, device=device),
6869
}
6970
# Water example
70-
example = {
71-
"z": pt.tensor([8, 1, 1], dtype=pt.long, device=device, requires_grad=False),
71+
water = {
72+
"z": pt.tensor([8, 1, 1], dtype=pt.int, device=device, requires_grad=False),
7273
"pos": pt.tensor(
7374
[
7475
[60.243, 56.013, 55.451],
@@ -81,7 +82,7 @@ def test_onnx_export(tmp_path):
8182
),
8283
"batch": pt.tensor(
8384
[0, 0, 0],
84-
dtype=pt.long,
85+
dtype=pt.int,
8586
device=device,
8687
requires_grad=False,
8788
),
@@ -95,7 +96,7 @@ def test_onnx_export(tmp_path):
9596
device=device,
9697
requires_grad=False,
9798
),
98-
"q": pt.tensor([0], dtype=pt.long, device=device, requires_grad=False),
99+
"q": pt.tensor([0], dtype=pt.int, device=device, requires_grad=False),
99100
}
100101

101102
model = create_model(
@@ -108,11 +109,53 @@ def test_onnx_export(tmp_path):
108109
onnx_export=True,
109110
)
110111
)
112+
model = load_model(
113+
os.path.join(curr_dir, "aceff-1.2-xtb.ckpt"),
114+
static_shapes=False,
115+
onnx_export=True,
116+
)
111117

118+
example = ben
112119
model.to(device)
113120
model.eval()
114-
out = model(**example)
115-
print(out)
121+
ref_energy, ref_forces = model(**example)
122+
ref_energy = ref_energy.detach().cpu().numpy()
123+
ref_forces = ref_forces.detach().cpu().numpy()
124+
print(ref_energy, "\n", ref_forces)
125+
126+
n_atoms = 573
127+
pt.onnx.export(
128+
model, # model to export
129+
(
130+
pt.ones(n_atoms, dtype=pt.int, device=device, requires_grad=False),
131+
pt.ones((n_atoms, 3), dtype=pt.float32, device=device, requires_grad=True),
132+
pt.zeros(n_atoms, dtype=pt.int, device=device, requires_grad=False),
133+
pt.ones((3, 3), dtype=pt.float32, device=device, requires_grad=False),
134+
pt.zeros(1, dtype=pt.int, device=device, requires_grad=False),
135+
), # inputs of the model,
136+
os.path.join(
137+
tmp_path, f"aceff-1.2-xtb-{n_atoms}atoms.onnx"
138+
), # filename of the ONNX model
139+
input_names=[
140+
"atomic_numbers",
141+
"positions",
142+
"batch",
143+
"box",
144+
"total_charge",
145+
], # Rename inputs for the ONNX model
146+
output_names=["energy", "forces"],
147+
dynamic_axes={
148+
"atomic_numbers": {0: "atoms"},
149+
"positions": {0: "atoms"},
150+
"batch": {0: "atoms"},
151+
"forces": {0: "atoms"},
152+
},
153+
dynamo=False,
154+
# report=True,
155+
# opset_version=20,
156+
do_constant_folding=True,
157+
export_params=True,
158+
)
116159

117160
pt.onnx.export(
118161
model, # model to export
@@ -123,39 +166,51 @@ def test_onnx_export(tmp_path):
123166
example["box"],
124167
example["q"],
125168
), # inputs of the model,
126-
os.path.join(tmp_path, "my_model.onnx"), # filename of the ONNX model
169+
os.path.join(
170+
tmp_path, f"aceff-1.2-xtb-18atoms.onnx"
171+
), # filename of the ONNX model
127172
input_names=[
128-
"z",
129-
"pos",
173+
"atomic_numbers",
174+
"positions",
130175
"batch",
131176
"box",
132-
"q",
177+
"total_charge",
133178
], # Rename inputs for the ONNX model
134179
output_names=["energy", "forces"],
135180
dynamic_axes={
136-
"z": {0: "atoms"},
137-
"pos": {0: "atoms"},
181+
"atomic_numbers": {0: "atoms"},
182+
"positions": {0: "atoms"},
138183
"batch": {0: "atoms"},
139-
# "energy": {0: "batch"},
140184
"forces": {0: "atoms"},
141185
},
142-
dynamo=False, # True or False to select the exporter to use
143-
report=True,
144-
opset_version=20,
186+
dynamo=False,
187+
# report=True,
188+
# opset_version=20,
189+
do_constant_folding=True,
190+
export_params=True,
145191
)
146192

147193
# Test the exported ONNX model
148-
import onnx
149194
import onnxruntime as ort
195+
import onnx
150196

197+
example = ben
151198
model_path = os.path.join(tmp_path, "my_model.onnx")
152-
session = ort.InferenceSession(model_path)
199+
onnx.checker.check_model(onnx.load(model_path))
200+
session = ort.InferenceSession(model_path, providers=["CUDAExecutionProvider"])
153201
inputs = {
154-
"z": example["z"].cpu().numpy(),
155-
"pos": example["pos"].detach().cpu().numpy(),
202+
"atomic_numbers": example["z"].cpu().numpy(),
203+
"positions": example["pos"].detach().cpu().numpy(),
156204
"batch": example["batch"].cpu().numpy(),
157-
# "box": example["box"].cpu().numpy(),
158-
"q": example["q"].cpu().numpy(),
205+
"total_charge": example["q"].cpu().numpy(),
159206
}
160-
outputs = session.run(None, inputs)
161-
print(outputs)
207+
onnx_energy, onnx_forces = session.run(None, inputs)
208+
print(onnx_energy, "\n", onnx_forces)
209+
print("Forces diff", np.abs(ref_forces - onnx_forces).max())
210+
print("Energy diff", np.abs(ref_energy - onnx_energy).max())
211+
assert np.allclose(ref_forces, onnx_forces), "Forces are not close"
212+
assert np.allclose(ref_energy, onnx_energy), "Energy is not close"
213+
214+
215+
if __name__ == "__main__":
216+
test_onnx_export("/tmp/")

torchmdnet/models/model.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,9 @@ def energy_fn(
456456
for prior in self.prior_model:
457457
y = prior.post_reduce(y, z, pos, batch, box, extra_args)
458458

459+
if self.onnx_export:
460+
self.y = y
461+
459462
return y
460463

461464
def forward(
@@ -503,20 +506,23 @@ def forward(
503506
Returns:
504507
Tuple[Tensor, Optional[Tensor]]: The output of the model and the derivative of the output with respect to the positions if derivative is True, None otherwise.
505508
"""
506-
assert z.dim() == 1 and z.dtype == torch.long
509+
assert z.dim() == 1 and z.dtype == torch.int
507510
batch = torch.zeros_like(z) if batch is None else batch
508511

509512
if self.derivative:
510513
pos.requires_grad_(True)
511514

512-
y = self.energy_fn(z, pos, batch, box, q, s, extra_args)
513-
514-
def energy_wrapper(pos):
515-
return self.energy_fn(z, pos, batch, box, q, s, extra_args)
515+
if not self.onnx_export:
516+
y = self.energy_fn(z, pos, batch, box, q, s, extra_args)
516517

517518
if self.derivative:
518519
if self.onnx_export:
520+
521+
def energy_wrapper(pos):
522+
return self.energy_fn(z, pos, batch, box, q, s, extra_args)
523+
519524
dy = torch.autograd.functional.jacobian(energy_wrapper, pos)
525+
y = self.y
520526
else:
521527
grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(y)]
522528
dy = grad(

torchmdnet/models/tensornet.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,19 @@ def decompose_tensor(tensor):
5656
return I, A, S
5757

5858

59+
def custom_tril_indices(rows, cols, offset=0, device=None, dtype=torch.long):
60+
row_indices = torch.arange(rows, device=device).view(-1, 1) # shape (rows, 1)
61+
col_indices = torch.arange(cols, device=device).view(1, -1) # shape (1, cols)
62+
63+
# Create a mask for the lower triangle
64+
mask = row_indices - col_indices >= -offset # broadcasting comparison
65+
66+
# Get indices where the mask is True
67+
result = mask.nonzero(as_tuple=False).t().to(dtype) # shape (2, N)
68+
69+
return result
70+
71+
5972
def tensor_norm(tensor):
6073
"""Computes Frobenius norm."""
6174
return (tensor**2).sum((-2, -1))
@@ -136,7 +149,7 @@ def __init__(
136149
check_errors=True,
137150
dtype=torch.float32,
138151
box_vecs=None,
139-
onnx_export=True,
152+
onnx_export=False,
140153
):
141154
super(TensorNet, self).__init__()
142155

@@ -212,11 +225,6 @@ def __init__(
212225
box=box_vecs,
213226
long_edge_index=True,
214227
)
215-
else:
216-
# TODO: Make this work with given size
217-
self.register_buffer(
218-
"edge_index", torch.tensor([[1, 2, 2], [0, 0, 1]], dtype=torch.long)
219-
)
220228

221229
self.reset_parameters()
222230

@@ -240,7 +248,9 @@ def forward(
240248
if not self.onnx_export:
241249
edge_index, edge_weight, edge_vec = self.distance(pos, batch, box)
242250
else:
243-
edge_index = self.edge_index
251+
edge_index = custom_tril_indices(
252+
pos.shape[0], pos.shape[0], device=pos.device, offset=pos.shape[0]
253+
)
244254
edge_vec = pos[edge_index[0]] - pos[edge_index[1]]
245255
edge_weight = torch.norm(edge_vec, dim=-1)
246256

0 commit comments

Comments
 (0)