
Commit de7e264

Restrict ONNX opset to 16 and up (#3051)
* Update ONNX-IR documentation with a more comprehensive description
* Fix build issues with data structure changes:
  - Update struct TensorData field access patterns
  - Add support for non-optional data fields
  - Fix issues with tensor data handling
* Fix build issues with TensorType structure changes. The PR resolves several compilation errors
  caused by changes to the TensorType structure:
  1. Modified code to adapt to the removal of the `shape` field from TensorType
  2. Fixed pattern matching issues in rank_inference.rs to properly match against TensorData.data
  3. Updated from_onnx.rs's remap_unsqueeze_to_reshape function to work with the new API
  4. Fixed unused imports across multiple files
  5. Fixed function calls that were using Option.len() incorrectly
* Add static shape handling and rank inference for tensor operations: enhance the tensor type
  system to support both static shapes and dynamic ranks across multiple ONNX operations,
  including Expand, RandomNormal, Constant, and related nodes; ensure proper shape validation and
  improve type safety throughout the conversion process
* Fix clippy warnings
* Fix merge issues
* Enable unsqueeze with runtime axes values
* Fix clippy error
* Remove default fallback
* Removed dead code
* Removed rank from TensorData
* Removed elem_type from TensorData
* Simplify elem_type match expressions with pattern grouping
* Add static_shape back
* Add restriction for ONNX opset version >= 16
* Add ONNX opset upgrade script
* Update onnx-model.md
* Removed ONNX files for opsets < 16
* Skip opset upgrades if opset >= 16
* Bring back moved ONNX file
* Fix clippy
* Updated opset script per PR feedback
* Reimplement topk ONNX model and tests for opset 16
* Update README.md
* Include infer_shapes step in the upgrade script
1 parent: 009ad59

34 files changed: +443, -848 lines

burn-book/src/import/onnx-model.md

Lines changed: 90 additions & 79 deletions
@@ -1,72 +1,89 @@
 # Importing ONNX Models in Burn
 
-## Table of Contents
-
-1. [Introduction](#introduction)
-2. [Why Import Models?](#why-import-models)
-3. [Understanding ONNX](#understanding-onnx)
-4. [Burn's ONNX Support](#burns-onnx-support)
-5. [Step-by-Step Guide](#step-by-step-guide)
-6. [Advanced Configuration](#advanced-configuration)
-7. [Loading and Using Models](#loading-and-using-models)
-8. [Troubleshooting](#troubleshooting)
-9. [Examples and Resources](#examples-and-resources)
-10. [Conclusion](#conclusion)
-
 ## Introduction
 
-As the field of deep learning continues to evolve, the need for interoperability between different
-frameworks becomes increasingly important. Burn, a modern deep learning framework in Rust,
-recognizes this need and provides robust support for importing models from other popular frameworks.
-This section focuses on importing
+As deep learning evolves, interoperability between frameworks becomes crucial. Burn, a modern deep
+learning framework in Rust, provides robust support for importing models from other popular
+frameworks. This section focuses on importing
 [ONNX (Open Neural Network Exchange)](https://onnx.ai/onnx/intro/index.html) models into Burn,
-enabling you to leverage pre-trained models and seamlessly integrate them into your Rust-based deep
-learning projects.
+enabling you to leverage pre-trained models in your Rust-based deep learning projects.
 
 ## Why Import Models?
 
 Importing pre-trained models offers several advantages:
 
-1. **Time-saving**: Avoid the need to train models from scratch, which can be time-consuming and
-   resource-intensive.
+1. **Time-saving**: Skip the resource-intensive process of training models from scratch.
 2. **Access to state-of-the-art architectures**: Utilize cutting-edge models developed by
    researchers and industry leaders.
 3. **Transfer learning**: Fine-tune imported models for your specific tasks, benefiting from
    knowledge transfer.
-4. **Consistency across frameworks**: Ensure consistent performance when moving from one framework
-   to another.
+4. **Consistency across frameworks**: Maintain consistent performance when moving between
+   frameworks.
 
 ## Understanding ONNX
 
-ONNX (Open Neural Network Exchange) is an open format designed to represent machine learning models.
-Key features include:
+ONNX (Open Neural Network Exchange) is an open format designed to represent machine learning models
+with these key features:
 
-- **Framework agnostic**: ONNX provides a common format that works across various deep learning
+- **Framework agnostic**: Provides a common format that works across various deep learning
   frameworks.
-- **Comprehensive representation**: It captures both the model architecture and trained weights.
-- **Wide support**: Many popular frameworks like PyTorch, TensorFlow, and scikit-learn support ONNX
-  export.
+- **Comprehensive representation**: Captures both the model architecture and trained weights.
+- **Wide support**: Compatible with popular frameworks like PyTorch, TensorFlow, and scikit-learn.
 
-By using ONNX, you can easily move models between different frameworks and deployment environments.
+This standardization allows seamless movement of models between different frameworks and deployment
+environments.
 
 ## Burn's ONNX Support
 
-Burn takes a unique approach to ONNX import, offering several advantages:
+Burn's approach to ONNX import offers unique advantages:
 
-1. **Native Rust code generation**: ONNX models are translated into Rust source code, allowing for
-   deep integration with Burn's ecosystem.
-2. **Compile-time optimization**: The generated Rust code can be optimized by the Rust compiler,
+1. **Native Rust code generation**: Translates ONNX models into Rust source code for deep
+   integration with Burn's ecosystem.
+2. **Compile-time optimization**: Leverages the Rust compiler to optimize the generated code,
    potentially improving performance.
-3. **No runtime dependency**: Unlike some solutions that require an ONNX runtime, Burn's approach
-   eliminates this dependency.
-4. **Trainability**: Imported models can be further trained or fine-tuned using Burn.
-5. **Portability**: The generated Rust code can be compiled for various targets, including
-   WebAssembly and embedded devices.
-6. **Any Burn Backend**: The imported models can be used with any of Burn's backends.
+3. **No runtime dependency**: Eliminates the need for an ONNX runtime, unlike many other solutions.
+4. **Trainability**: Allows imported models to be further trained or fine-tuned using Burn.
+5. **Portability**: Enables compilation for various targets, including WebAssembly and embedded
+   devices.
+6. **Backend flexibility**: Works with any of Burn's supported backends.
+
+## ONNX Compatibility
+
+Burn requires ONNX models to use **opset version 16 or higher**. If your model uses an older
+version, you'll need to upgrade it using the ONNX version converter.
+
+### Upgrading ONNX Models
+
+There are two simple ways to upgrade your ONNX models to the required opset version:
+
+Option 1: Use the provided utility script:
+
+```
+uv run --script https://raw.githubusercontent.com/tracel-ai/burn/refs/heads/main/crates/burn-import/onnx_opset_upgrade.py
+```
+
+Option 2: Use a custom Python script:
+
+```python
+import onnx
+from onnx import version_converter, shape_inference
+
+# Load your ONNX model
+model = onnx.load('path/to/your/model.onnx')
+
+# Convert the model to opset version 16
+upgraded_model = version_converter.convert_version(model, 16)
+
+# Apply shape inference to the upgraded model
+inferred_model = shape_inference.infer_shapes(upgraded_model)
+
+# Save the converted model
+onnx.save(inferred_model, 'upgraded_model.onnx')
+```
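
After upgrading, it's worth confirming that the saved file actually reports the new opset before
pointing `burn-import` at it. A minimal sketch using the same `onnx` package (`opset_import` is the
standard `ModelProto` field; the file name follows the script above):

```python
import onnx

# Load the upgraded model and print each opset the model imports
model = onnx.load('upgraded_model.onnx')
for opset in model.opset_import:
    print(opset.domain or 'ai.onnx', opset.version)
```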
 
 ## Step-by-Step Guide
 
-Let's walk through the process of importing an ONNX model into a Burn project:
+Follow these steps to import an ONNX model into your Burn project:
 
 ### Step 1: Update `build.rs`
 
@@ -90,7 +107,7 @@ fn main() {
 }
 ```
 
-This script uses `ModelGen` to generate Rust code from your ONNX model during the build process.
+This generates Rust code from your ONNX model during the build process.
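
For reference, a complete minimal `build.rs` consistent with this setup might look like the sketch
below; the ONNX path is a placeholder, and `ModelGen` comes from the `burn-import` crate:

```rust
// build.rs: a minimal sketch; adjust the model path to your project layout
use burn_import::onnx::ModelGen;

fn main() {
    ModelGen::new()
        .input("src/model/my_model.onnx") // hypothetical path
        .out_dir("model/")
        .run_from_script();
}
```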

 ### Step 2: Modify `mod.rs`
 
@@ -102,11 +119,9 @@ pub mod my_model {
 }
 ```
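
The module body elided from this hunk pulls the generated source into the crate; a typical sketch,
assuming the model file is named `my_model.onnx` and `out_dir` is `model/`:

```rust
// src/model/mod.rs: expose the code generated into OUT_DIR at build time
pub mod my_model {
    include!(concat!(env!("OUT_DIR"), "/model/my_model.rs"));
}
```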

-This makes the generated model code available in your project.
-
 ### Step 3: Use the Imported Model
 
-Now you can use the imported model in your Rust code:
+Now you can use the imported model in your code:
 
 ```rust
 use burn::tensor;
@@ -116,8 +131,7 @@ use model::my_model::Model;
 fn main() {
     let device = NdArrayDevice::default();
 
-    // Create model instance and load weights from target dir default device.
-    // (see more load options below in "Loading and Using Models" section)
+    // Create model instance and load weights from target dir default device
     let model: Model<NdArray<f32>> = Model::default();
 
     // Create input tensor (replace with your actual input)
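
Assembled into a runnable program, the example looks roughly like the sketch below; the backend
choice, module path, and 4-dimensional input shape are assumptions to replace with your own:

```rust
use burn::backend::ndarray::{NdArray, NdArrayDevice};
use burn::tensor::Tensor;

use model::my_model::Model; // hypothetical module path

fn main() {
    let device = NdArrayDevice::default();

    // Create model instance and load weights from target dir default device
    let model: Model<NdArray<f32>> = Model::default();

    // Create input tensor (shape assumed; replace with your actual input)
    let input = Tensor::<NdArray<f32>, 4>::zeros([1, 3, 224, 224], &device);

    // Run inference and print the result
    let output = model.forward(input);
    println!("{output}");
}
```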
@@ -132,7 +146,7 @@ fn main() {
 
 ## Advanced Configuration
 
-The `ModelGen` struct offers several configuration options:
+The `ModelGen` struct provides several configuration options:
 
 ```rust
 ModelGen::new()
@@ -144,72 +158,69 @@ ModelGen::new()
     .run_from_script();
 ```
 
-- `record_type`: Specifies the format for storing weights (Bincode, NamedMpk, NamedMpkGz, or
+- `record_type`: Defines the format for storing weights (Bincode, NamedMpk, NamedMpkGz, or
   PrettyJson).
-- `half_precision`: Use half-precision (f16) for weights to reduce model size.
-- `embed_states`: Embed model weights directly in the generated Rust code. Note: This requires
-  record type `Bincode`.
+- `half_precision`: Reduces model size by using half-precision (f16) for weights.
+- `embed_states`: Embeds model weights directly in the generated Rust code (requires record type
+  `Bincode`).
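
Put together, a fully configured invocation might look like the sketch below; the values shown are
illustrative rather than defaults, and the `RecordType` import path is an assumption that may vary
by version:

```rust
use burn_import::burn::graph::RecordType; // assumed import path
use burn_import::onnx::ModelGen;

fn main() {
    ModelGen::new()
        .input("src/model/my_model.onnx") // hypothetical path
        .out_dir("model/")
        .record_type(RecordType::NamedMpk)
        .half_precision(true)
        .embed_states(false)
        .run_from_script();
}
```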

 ## Loading and Using Models
 
-Depending on your configuration, you can load models in different ways:
+Depending on your configuration, you can load models in several ways:
 
 ```rust
-// Create a new model instance with device. Initializes weights randomly and lazily.
-// You can load weights via `load_record` afterwards.
+// Create a new model instance with device
+// (initializes weights randomly and lazily; load weights via `load_record` afterward)
 let model = Model::<Backend>::new(&device);
 
-// Load from a file (must specify weights file in the target output directory or copy it from there).
-// File type should match the record type specified in `ModelGen`.
+// Load from a file
+// (file type should match the record type specified in `ModelGen`)
 let model = Model::<Backend>::from_file("path/to/weights", &device);
 
 // Load from embedded weights (if embed_states was true)
 let model = Model::<Backend>::from_embedded(&device);
 
-// Load from the out director location and load to default device (useful for testing)
+// Load from the output directory with default device (useful for testing)
 let model = Model::<Backend>::default();
 ```
 
 ## Troubleshooting
 
-Here are some common issues and their solutions:
+Common issues and solutions:
 
-1. **Unsupported ONNX operator**: If you encounter an error about an unsupported operator, check the
+1. **Unsupported ONNX operator**: Check the
    [list of supported ONNX operators](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md).
-   You may need to simplify your model or wait for support to be added.
+   You may need to simplify your model or wait for support.
 
-2. **Build errors**: Ensure that your `burn-import` version matches your Burn version. Also, check
-   that the ONNX file path in `build.rs` is correct.
+2. **Build errors**: Ensure your `burn-import` version matches your Burn version and verify the ONNX
+   file path in `build.rs`.
 
-3. **Runtime errors**: If you get errors when running your model, double-check that your input
-   tensors match the expected shape and data type of your model.
+3. **Runtime errors**: Confirm that your input tensors match the expected shape and data type of
+   your model.
 
-4. **Performance issues**: If your imported model is slower than expected, try using the
-   `half_precision` option to reduce memory usage, or experiment with different `record_type`
-   options.
+4. **Performance issues**: Try using the `half_precision` option to reduce memory usage or
+   experiment with different `record_type` options.
 
-5. **Artifact Files**: You can view the generated Rust code and weights files in the `OUT_DIR`
-   directory specified in `build.rs` (usually `target/debug/build/<project>/out`).
+5. **Viewing generated files**: Find the generated Rust code and weights in the `OUT_DIR` directory
+   (usually `target/debug/build/<project>/out`).
 
 ## Examples and Resources
 
-For more detailed examples, check out:
+For practical examples, check out:
 
 1. [MNIST Inference Example](https://github.com/tracel-ai/burn/tree/main/examples/onnx-inference)
 2. [SqueezeNet Image Classification](https://github.com/tracel-ai/models/tree/main/squeezenet-burn)
 
-These examples demonstrate real-world usage of ONNX import in Burn projects.
+These demonstrate real-world usage of ONNX import in Burn projects.
 
 ## Conclusion
 
-Importing ONNX models into Burn opens up a world of possibilities, allowing you to leverage
-pre-trained models from other frameworks while taking advantage of Burn's performance and Rust's
-safety features. By following this guide, you should be able to seamlessly integrate ONNX models
-into your Burn projects, whether for inference, fine-tuning, or as a starting point for further
-development.
+Importing ONNX models into Burn combines the vast ecosystem of pre-trained models with Burn's
+performance and Rust's safety features. Following this guide, you can seamlessly integrate ONNX
+models into your Burn projects for inference, fine-tuning, or further development.
 
-Remember that the `burn-import` crate is actively developed, with ongoing work to support more ONNX
-operators and improve performance. Stay tuned to the Burn repository for updates and new features!
+The `burn-import` crate is actively developed, with ongoing work to support more ONNX operators and
+improve performance. Stay tuned to the Burn repository for updates!
 
 ---
 
crates/burn-import/onnx-tests/build.rs

Lines changed: 7 additions & 12 deletions
@@ -13,8 +13,7 @@ fn main() {
         .input("tests/avg_pool2d/avg_pool2d.onnx")
         .input("tests/batch_norm/batch_norm.onnx")
         .input("tests/cast/cast.onnx")
-        .input("tests/clip/clip_opset16.onnx")
-        .input("tests/clip/clip_opset7.onnx")
+        .input("tests/clip/clip.onnx")
         .input("tests/concat/concat.onnx")
         .input("tests/constant/constant_f32.onnx")
         .input("tests/constant/constant_f64.onnx")
@@ -31,8 +30,7 @@ fn main() {
         .input("tests/cos/cos.onnx")
         .input("tests/cosh/cosh.onnx")
         .input("tests/div/div.onnx")
-        .input("tests/dropout/dropout_opset16.onnx")
-        .input("tests/dropout/dropout_opset7.onnx")
+        .input("tests/dropout/dropout.onnx")
         .input("tests/equal/equal.onnx")
         .input("tests/erf/erf.onnx")
         .input("tests/exp/exp.onnx")
@@ -97,8 +95,7 @@ fn main() {
         .input("tests/reduce_mean/reduce_mean.onnx")
         .input("tests/reduce_min/reduce_min.onnx")
         .input("tests/reduce_prod/reduce_prod.onnx")
-        .input("tests/reduce_sum/reduce_sum_opset11.onnx")
-        .input("tests/reduce_sum/reduce_sum_opset13.onnx")
+        .input("tests/reduce_sum/reduce_sum.onnx")
         .input("tests/relu/relu.onnx")
         .input("tests/reshape/reshape.onnx")
         .input("tests/resize/resize_with_sizes.onnx")
@@ -116,22 +113,20 @@ fn main() {
         .input("tests/softmax/softmax.onnx")
         .input("tests/sqrt/sqrt.onnx")
         .input("tests/squeeze/squeeze_multiple.onnx")
-        .input("tests/squeeze/squeeze_opset13.onnx")
-        .input("tests/squeeze/squeeze_opset16.onnx")
+        .input("tests/squeeze/squeeze.onnx")
         .input("tests/sub/sub.onnx")
         .input("tests/sub/sub_int.onnx")
         .input("tests/sum/sum.onnx")
         .input("tests/sum/sum_int.onnx")
         .input("tests/tan/tan.onnx")
         .input("tests/tanh/tanh.onnx")
         .input("tests/tile/tile.onnx")
-        .input("tests/top_k/top_k_opset_1.onnx")
+        .input("tests/topk/topk.onnx")
         .input("tests/trilu/trilu_upper.onnx")
         .input("tests/trilu/trilu_lower.onnx")
         .input("tests/transpose/transpose.onnx")
-        .input("tests/unsqueeze/unsqueeze.onnx")
-        .input("tests/unsqueeze/unsqueeze_opset11.onnx")
-        .input("tests/unsqueeze/unsqueeze_opset16.onnx")
+        .input("tests/unsqueeze/unsqueeze_runtime_axes.onnx")
+        .input("tests/unsqueeze/unsqueeze_like.onnx")
         .input("tests/split/split.onnx")
         .out_dir("model/")
         .run_from_script();

crates/burn-import/onnx-tests/tests/clip/clip_opset16.onnx renamed to crates/burn-import/onnx-tests/tests/clip/clip.onnx

581 Bytes
Binary file not shown.

crates/burn-import/onnx-tests/tests/clip/clip_opset16.py renamed to crates/burn-import/onnx-tests/tests/clip/clip.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ def main():
     model.eval()
     device = torch.device("cpu")
 
-    file_name = "clip_opset16.onnx"
+    file_name = "clip.onnx"
     test_input = torch.rand(6, device=device)
     torch.onnx.export(model, test_input, file_name,
                       verbose=False, opset_version=16)
crates/burn-import/onnx-tests/tests/clip/clip_opset7.onnx

-275 Bytes
Binary file not shown.

crates/burn-import/onnx-tests/tests/clip/clip_opset7.py

Lines changed: 0 additions & 52 deletions
This file was deleted.

crates/burn-import/onnx-tests/tests/dropout/dropout_opset16.onnx renamed to crates/burn-import/onnx-tests/tests/dropout/dropout.onnx

385 Bytes
Binary file not shown.
