13
13
14
14
@pytest .mark .skipif (not ONNXSCRIPT_AVAILABLE , reason = "onnxscript not available" )
15
15
def test_onnx_export (tmp_path ):
16
- from torchmdnet .models .model import create_model
16
+ from torchmdnet .models .model import create_model , load_model
17
17
from utils import load_example_args
18
18
import torch as pt
19
+ import numpy as np
19
20
20
21
device = "cuda" # "cuda" if pt.cuda.is_available() else "cpu"
21
22
22
23
ben = {
23
24
"z" : pt .tensor (
24
25
[6 , 6 , 6 , 6 , 6 , 6 , 6 , 7 , 7 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ],
25
- dtype = pt .long ,
26
+ dtype = pt .int ,
26
27
device = device ,
27
28
),
28
29
"pos" : pt .tensor (
@@ -52,7 +53,7 @@ def test_onnx_export(tmp_path):
52
53
),
53
54
"batch" : pt .tensor (
54
55
[0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ],
55
- dtype = pt .long ,
56
+ dtype = pt .int ,
56
57
device = device ,
57
58
),
58
59
"box" : pt .tensor (
@@ -64,11 +65,11 @@ def test_onnx_export(tmp_path):
64
65
dtype = pt .float32 ,
65
66
device = device ,
66
67
),
67
- "q" : pt .tensor ([1 ], dtype = pt .long , device = device ),
68
+ "q" : pt .tensor ([1 ], dtype = pt .int , device = device ),
68
69
}
69
70
# Water example
70
- example = {
71
- "z" : pt .tensor ([8 , 1 , 1 ], dtype = pt .long , device = device , requires_grad = False ),
71
+ water = {
72
+ "z" : pt .tensor ([8 , 1 , 1 ], dtype = pt .int , device = device , requires_grad = False ),
72
73
"pos" : pt .tensor (
73
74
[
74
75
[60.243 , 56.013 , 55.451 ],
@@ -81,7 +82,7 @@ def test_onnx_export(tmp_path):
81
82
),
82
83
"batch" : pt .tensor (
83
84
[0 , 0 , 0 ],
84
- dtype = pt .long ,
85
+ dtype = pt .int ,
85
86
device = device ,
86
87
requires_grad = False ,
87
88
),
@@ -95,7 +96,7 @@ def test_onnx_export(tmp_path):
95
96
device = device ,
96
97
requires_grad = False ,
97
98
),
98
- "q" : pt .tensor ([0 ], dtype = pt .long , device = device , requires_grad = False ),
99
+ "q" : pt .tensor ([0 ], dtype = pt .int , device = device , requires_grad = False ),
99
100
}
100
101
101
102
model = create_model (
@@ -108,11 +109,53 @@ def test_onnx_export(tmp_path):
108
109
onnx_export = True ,
109
110
)
110
111
)
112
+ model = load_model (
113
+ os .path .join (curr_dir , "aceff-1.2-xtb.ckpt" ),
114
+ static_shapes = False ,
115
+ onnx_export = True ,
116
+ )
111
117
118
+ example = ben
112
119
model .to (device )
113
120
model .eval ()
114
- out = model (** example )
115
- print (out )
121
+ ref_energy , ref_forces = model (** example )
122
+ ref_energy = ref_energy .detach ().cpu ().numpy ()
123
+ ref_forces = ref_forces .detach ().cpu ().numpy ()
124
+ print (ref_energy , "\n " , ref_forces )
125
+
126
+ n_atoms = 573
127
+ pt .onnx .export (
128
+ model , # model to export
129
+ (
130
+ pt .ones (n_atoms , dtype = pt .int , device = device , requires_grad = False ),
131
+ pt .ones ((n_atoms , 3 ), dtype = pt .float32 , device = device , requires_grad = True ),
132
+ pt .zeros (n_atoms , dtype = pt .int , device = device , requires_grad = False ),
133
+ pt .ones ((3 , 3 ), dtype = pt .float32 , device = device , requires_grad = False ),
134
+ pt .zeros (1 , dtype = pt .int , device = device , requires_grad = False ),
135
+ ), # inputs of the model,
136
+ os .path .join (
137
+ tmp_path , f"aceff-1.2-xtb-{ n_atoms } atoms.onnx"
138
+ ), # filename of the ONNX model
139
+ input_names = [
140
+ "atomic_numbers" ,
141
+ "positions" ,
142
+ "batch" ,
143
+ "box" ,
144
+ "total_charge" ,
145
+ ], # Rename inputs for the ONNX model
146
+ output_names = ["energy" , "forces" ],
147
+ dynamic_axes = {
148
+ "atomic_numbers" : {0 : "atoms" },
149
+ "positions" : {0 : "atoms" },
150
+ "batch" : {0 : "atoms" },
151
+ "forces" : {0 : "atoms" },
152
+ },
153
+ dynamo = False ,
154
+ # report=True,
155
+ # opset_version=20,
156
+ do_constant_folding = True ,
157
+ export_params = True ,
158
+ )
116
159
117
160
pt .onnx .export (
118
161
model , # model to export
@@ -123,39 +166,51 @@ def test_onnx_export(tmp_path):
123
166
example ["box" ],
124
167
example ["q" ],
125
168
), # inputs of the model,
126
- os .path .join (tmp_path , "my_model.onnx" ), # filename of the ONNX model
169
+ os .path .join (
170
+ tmp_path , f"aceff-1.2-xtb-18atoms.onnx"
171
+ ), # filename of the ONNX model
127
172
input_names = [
128
- "z " ,
129
- "pos " ,
173
+ "atomic_numbers " ,
174
+ "positions " ,
130
175
"batch" ,
131
176
"box" ,
132
- "q " ,
177
+ "total_charge " ,
133
178
], # Rename inputs for the ONNX model
134
179
output_names = ["energy" , "forces" ],
135
180
dynamic_axes = {
136
- "z " : {0 : "atoms" },
137
- "pos " : {0 : "atoms" },
181
+ "atomic_numbers " : {0 : "atoms" },
182
+ "positions " : {0 : "atoms" },
138
183
"batch" : {0 : "atoms" },
139
- # "energy": {0: "batch"},
140
184
"forces" : {0 : "atoms" },
141
185
},
142
- dynamo = False , # True or False to select the exporter to use
143
- report = True ,
144
- opset_version = 20 ,
186
+ dynamo = False ,
187
+ # report=True,
188
+ # opset_version=20,
189
+ do_constant_folding = True ,
190
+ export_params = True ,
145
191
)
146
192
147
193
# Test the exported ONNX model
148
- import onnx
149
194
import onnxruntime as ort
195
+ import onnx
150
196
197
+ example = ben
151
198
model_path = os .path .join (tmp_path , "my_model.onnx" )
152
- session = ort .InferenceSession (model_path )
199
+ onnx .checker .check_model (onnx .load (model_path ))
200
+ session = ort .InferenceSession (model_path , providers = ["CUDAExecutionProvider" ])
153
201
inputs = {
154
- "z " : example ["z" ].cpu ().numpy (),
155
- "pos " : example ["pos" ].detach ().cpu ().numpy (),
202
+ "atomic_numbers " : example ["z" ].cpu ().numpy (),
203
+ "positions " : example ["pos" ].detach ().cpu ().numpy (),
156
204
"batch" : example ["batch" ].cpu ().numpy (),
157
- # "box": example["box"].cpu().numpy(),
158
- "q" : example ["q" ].cpu ().numpy (),
205
+ "total_charge" : example ["q" ].cpu ().numpy (),
159
206
}
160
- outputs = session .run (None , inputs )
161
- print (outputs )
207
+ onnx_energy , onnx_forces = session .run (None , inputs )
208
+ print (onnx_energy , "\n " , onnx_forces )
209
+ print ("Forces diff" , np .abs (ref_forces - onnx_forces ).max ())
210
+ print ("Energy diff" , np .abs (ref_energy - onnx_energy ).max ())
211
+ assert np .allclose (ref_forces , onnx_forces ), "Forces are not close"
212
+ assert np .allclose (ref_energy , onnx_energy ), "Energy is not close"
213
+
214
+
215
+ if __name__ == "__main__" :
216
+ test_onnx_export ("/tmp/" )
0 commit comments