-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathonnx_export.py
51 lines (38 loc) · 1.33 KB
/
onnx_export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import argparse
import os

import torch

from models.bisenet import BiSeNet
def torch2onnx_export(params):
    """Export a trained BiSeNet checkpoint to an ONNX file.

    The ONNX file is written next to the checkpoint, with the checkpoint's
    file extension replaced by ``.onnx`` (e.g. ``weights/resnet18.pt`` ->
    ``weights/resnet18.onnx``).

    Args:
        params: Parsed command-line namespace. ``params.model`` is passed to
            ``BiSeNet`` as ``backbone_name``; ``params.weight`` is the path of
            the state_dict checkpoint loaded with ``torch.load``.

    Returns:
        The path of the written ONNX file.
    """
    # NOTE(review): presumably the number of classes the checkpoint was
    # trained with — confirm it matches the weights being exported.
    num_classes = 19
    model = BiSeNet(num_classes, backbone_name=params.model)

    # Map tensors to CPU so checkpoints saved on a GPU also load on CPU-only
    # machines; the dummy input below is a CPU tensor as well.
    state_dict = torch.load(params.weight, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()

    # Swap only the trailing extension. A plain str.replace(".pt", ".onnx")
    # would rewrite every ".pt" in the path, turn "x.pth" into "x.onnxh", and
    # leave other extensions untouched — making the output path equal to the
    # checkpoint path, so the export would overwrite the weights.
    onnx_model_path = os.path.splitext(params.weight)[0] + ".onnx"

    # Fixed 1x3x512x512 tracing input; only the batch axis is exported as
    # dynamic (see dynamic_axes below). Gradients are not needed for export.
    dummy_input = torch.randn(1, 3, 512, 512)

    torch.onnx.export(
        model,
        dummy_input,
        onnx_model_path,
        export_params=True,
        opset_version=20,  # ONNX opset version to target
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": {0: "batch_size"},  # variable-length batch axis
            "output": {0: "batch_size"},
        },
    )
    return onnx_model_path
def parse_args(argv=None):
    """Parse command-line options for the ONNX export script.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, so running the
            script from the command line works unchanged.

    Returns:
        argparse.Namespace with ``model`` and ``weight`` attributes.
    """
    parser = argparse.ArgumentParser(description="Export a BiSeNet face parsing model to ONNX")
    parser.add_argument(
        "--model",
        type=str,
        default="resnet18",
        help="backbone name, e.g. resnet18, resnet34",
    )
    parser.add_argument(
        "--weight",
        type=str,
        default="./weights/resnet18.pt",
        help="path to the trained model weights, e.g. ./weights/resnet18.pt",
    )
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: read the CLI options and run the ONNX export.
    torch2onnx_export(params=parse_args())