From 25348cf181049d58656c5182222bf8e866d2d183 Mon Sep 17 00:00:00 2001 From: Joshua Ferguson Date: Tue, 2 Jul 2024 15:17:44 -0500 Subject: [PATCH] Separating ONNX parsing from burn-import (#1921) * separating onnx parsing from burn-import * ran clippy and cargo-fmt * removed unused deps from onnx-ir * fixed clippy warnings that were causing run-checks to fail * removed dead code * removed unused dependencies from burn-import * updated contributor-book, updated publish.yml, added readme * update cargo lock * formatted md document with prettier, rephrased sentence * missed the errors with reduce_prod_conversion during merge * formatted onnx-to-burn-conversion-tool.md, forgot to save --- .github/workflows/publish.yml | 6 + Cargo.lock | 23 +- .../guides/adding-a-new-operation-to-burn.md | 28 +- .../guides/onnx-to-burn-conversion-tool.md | 28 +- crates/burn-import/Cargo.toml | 9 +- crates/burn-import/build.rs | 11 - crates/burn-import/src/onnx/mod.rs | 11 - .../burn-import/src/onnx/op_configuration.rs | 2 +- crates/burn-import/src/onnx/to_burn.rs | 383 +++++++++--------- crates/onnx-ir/Cargo.toml | 31 ++ crates/onnx-ir/README.md | 7 + crates/onnx-ir/build.rs | 9 + .../src/onnx => onnx-ir/src}/coalesce.rs | 2 +- .../src/onnx => onnx-ir/src}/dim_inference.rs | 4 +- .../src/onnx => onnx-ir/src}/from_onnx.rs | 65 ++- .../src/onnx => onnx-ir/src}/ir.rs | 6 +- crates/onnx-ir/src/lib.rs | 12 + .../src/onnx => onnx-ir/src}/node_remap.rs | 0 .../onnx => onnx-ir/src}/proto_conversion.rs | 2 +- .../src/onnx => onnx-ir/src}/protos/mod.rs | 0 .../onnx => onnx-ir/src}/protos/onnx.proto | 0 crates/onnx-ir/src/util.rs | 45 ++ 22 files changed, 405 insertions(+), 279 deletions(-) delete mode 100644 crates/burn-import/build.rs create mode 100644 crates/onnx-ir/Cargo.toml create mode 100644 crates/onnx-ir/README.md create mode 100644 crates/onnx-ir/build.rs rename crates/{burn-import/src/onnx => onnx-ir/src}/coalesce.rs (99%) rename crates/{burn-import/src/onnx => 
onnx-ir/src}/dim_inference.rs (99%) rename crates/{burn-import/src/onnx => onnx-ir/src}/from_onnx.rs (92%) rename crates/{burn-import/src/onnx => onnx-ir/src}/ir.rs (99%) create mode 100644 crates/onnx-ir/src/lib.rs rename crates/{burn-import/src/onnx => onnx-ir/src}/node_remap.rs (100%) rename crates/{burn-import/src/onnx => onnx-ir/src}/proto_conversion.rs (99%) rename crates/{burn-import/src/onnx => onnx-ir/src}/protos/mod.rs (100%) rename crates/{burn-import/src/onnx => onnx-ir/src}/protos/onnx.proto (100%) create mode 100644 crates/onnx-ir/src/util.rs diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 0d8b36ef81..ef8ea52d8b 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -161,3 +161,9 @@ jobs: with: crate: burn-import secrets: inherit + + publish-onnx-ir: + uses: tracel-ai/burn/.github/workflows/publish-template.yml@main + with: + crate: onnx-ir + secrets: inherit diff --git a/Cargo.lock b/Cargo.lock index 30e47d9004..e6dfedc830 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -628,23 +628,19 @@ name = "burn-import" version = "0.14.0" dependencies = [ "burn", - "bytemuck", "candle-core", "derive-new", "half", "log", + "onnx-ir", "pretty_assertions", "proc-macro2", - "protobuf", - "protobuf-codegen", "quote", "regex", "rstest", "rust-format", "serde", "serde_json", - "strum", - "strum_macros", "syn 2.0.68", "thiserror", "tracing-core", @@ -3689,6 +3685,23 @@ dependencies = [ "serde", ] +[[package]] +name = "onnx-ir" +version = "0.14.0" +dependencies = [ + "bytemuck", + "half", + "log", + "pretty_assertions", + "protobuf", + "protobuf-codegen", + "regex", + "rstest", + "serde", + "strum", + "strum_macros", +] + [[package]] name = "onnx-tests" version = "0.14.0" diff --git a/contributor-book/src/guides/adding-a-new-operation-to-burn.md b/contributor-book/src/guides/adding-a-new-operation-to-burn.md index 63cfa9c7dd..0af235d350 100644 --- a/contributor-book/src/guides/adding-a-new-operation-to-burn.md +++ 
b/contributor-book/src/guides/adding-a-new-operation-to-burn.md @@ -117,7 +117,8 @@ plug in your operator in terms of \\(x\\) and \\(y\\), and just swap out the var ### Testing autodiff -For testing the `autodiff` operations, please refer to [this section](../getting-started/testing.md). +For testing the `autodiff` operations, please refer to +[this section](../getting-started/testing.md). ## Adding the Op to other backends @@ -199,11 +200,15 @@ Generating the ONNX test files or tests is already covered about the specific changes you need to make when adding new operators after you have generated the tests. -The crate is divided into two sections `src/burn` and `src/onnx`. The code under the former -corresponds to the operation you've implemented earlier in this guide, and the latter to the -operations defined in the ONNX specification. So when you are loading a model, the operator is first -parsed to an intermediate representation defined by `src/onnx`, and then mapped to a Burn operation -defined under `src/burn/node`. +Changes will need to be made to both `onnx-ir` and `burn-import`. The code within `onnx-ir` defines +how to parse the nodes in an onnx file and produces the intermediate representation. The code within +`burn-import` is divided into two sections: `src/onnx` and `src/burn`. The code under the former +maps that intermediate representation to one used for code generation and the latter defines how to +generate code for the operator you've implemented earlier in this guide. 
+ +So when you are loading a model, the operator is first parsed to an intermediate representation +defined by `onnx-ir` and then mapped to a Burn operation defined under `src/burn/node`; the +mapping from ONNX to Burn is aptly defined in `src/onnx/to_burn`. Let's review the changes made for powf starting from `src/burn` and moving to `src/onnx`: @@ -218,17 +223,20 @@ Let's review the changes made for powf starting from `src/burn` and moving to `s [`{op}_conversion` function](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-import/src/onnx/to_burn.rs#L717) that maps the ONNX node to the binary type 3. Specify how dimensions for the output should be derived in - [crates/burn-import/src/onnx/dim_inference.rs](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-import/src/onnx/dim_inference.rs#L55) + [crates/onnx-ir/src/dim_inference.rs](https://github.com/tracel-ai/burn/blob/d4ae82b21ac3dd1def01bd380ab7ea4d3293eccb/crates/onnx-ir/src/dim_inference.rs#L17) And you're done! Congrats, you just fully added a new operation to burn, and we are all one step closer to the answer to [Are we learning yet?](https://www.arewelearningyet.com/) being "Yes, and it's freaking fast!". Buy yourself a coffee. 
-[^supertrait]: for more on supertraits see +[^supertrait]: + for more on supertraits see [the advanced trait section of the rust book](https://doc.rust-lang.org/book/ch19-03-advanced-traits.html#using-supertraits-to-require-one-traits-functionality-within-another-trait) -[^autodiff]: wiki link for +[^autodiff]: + wiki link for [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation) -[^absolute_units]: for more information on unit structs see +[^absolute_units]: + for more information on unit structs see [the defining and instantiating structs section of the rust book](https://doc.rust-lang.org/book/ch05-01-defining-structs.html#unit-like-structs-without-any-fields) diff --git a/contributor-book/src/guides/onnx-to-burn-conversion-tool.md b/contributor-book/src/guides/onnx-to-burn-conversion-tool.md index d9d31282cd..feced3fd43 100644 --- a/contributor-book/src/guides/onnx-to-burn-conversion-tool.md +++ b/contributor-book/src/guides/onnx-to-burn-conversion-tool.md @@ -16,7 +16,17 @@ For an introduction to ONNX import in Burn, see - [Design Goals](#design-goals) - [Design Decisions](#design-decisions) - [Adding New Operators](#adding-new-operators) - - [Implementing a New Operator](#implementing-a-new-operator) + - [Implementing a New Operator](#implementing-a-new-operator) + - [Step 1: Visibility](#step-1-visibility) + - [Step 2: Node Implementation](#step-2-node-implementation) + - [Within Onnx-IR](#within-onnx-ir) + - [Within burn-import](#within-burn-import) + - [Step 3: Registering New Operations](#step-3-registering-new-operations) + - [Step 4: Create a Config Function](#step-4-create-a-config-function) + - [Step 5: Dimension Inference](#step-5-dimension-inference) + - [Step 6: Integrate into the Graph Building Process](#step-6-integrate-into-the-graph-building-process) + - [Step 7: Add Newly Supported Op!](#step-7-add-newly-supported-op) + - [Misc:](#misc) - [Testing](#testing) - [Resources](#resources) @@ -91,6 +101,22 @@ located 
in the `src/burn/node/` directory. ### Step 2: Node Implementation +#### Within Onnx-IR + +If the node type does not exist within the +[`NodeType` enum](https://github.com/tracel-ai/burn/blob/d4ae82b21ac3dd1def01bd380ab7ea4d3293eccb/crates/onnx-ir/src/ir.rs#L246), +it will need to be added (support for custom operators is planned). If the node might be provided an +input which is a constant or the output of an identity node, it will need to be added to the list of +node types +[checked for constants](https://github.com/tracel-ai/burn/blob/d4ae82b21ac3dd1def01bd380ab7ea4d3293eccb/crates/onnx-ir/src/from_onnx.rs#L21). +The node will need to be added to `dim_inference`, and in most cases the work on the parsing side +will then be done. If a node requires extra parsing (such as handling an edge case like potentially +remapping an unsqueeze to a reshape), the best place for that is after the constant check and prior +to `dim_inference` in +[`OnnxGraphBuilder::Build`](https://github.com/tracel-ai/burn/blob/d4ae82b21ac3dd1def01bd380ab7ea4d3293eccb/crates/onnx-ir/src/from_onnx.rs#L221). + +#### Within burn-import + Create a new file named `.rs` in the `src/burn/node/` directory. This file will define the structure and functionality of your new operation. 
By convention, the necessary information for carrying out an operation is encapsulated within a struct named diff --git a/crates/burn-import/Cargo.toml b/crates/burn-import/Cargo.toml index b12296eaac..acfefeb2fc 100644 --- a/crates/burn-import/Cargo.toml +++ b/crates/burn-import/Cargo.toml @@ -20,30 +20,23 @@ pytorch = ["burn/record-item-custom-serde", "thiserror", "zip"] [dependencies] burn = { path = "../burn", version = "0.14.0", features = ["ndarray"] } - -bytemuck = { workspace = true } +onnx-ir = { path = "../onnx-ir" } candle-core = { workspace = true } derive-new = { workspace = true } half = { workspace = true } log = { workspace = true } proc-macro2 = { workspace = true } -protobuf = { workspace = true, features = ["with-bytes"] } quote = { workspace = true } regex = { workspace = true } rust-format = { workspace = true, features = ["token_stream", "post_process"] } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true, features = ["std"] } -strum = { workspace = true } -strum_macros = { workspace = true } syn = { workspace = true, features = ["parsing"] } thiserror = { workspace = true, optional = true } tracing-core = { workspace = true } tracing-subscriber = { workspace = true } zip = { workspace = true, optional = true } -[build-dependencies] -protobuf-codegen = { workspace = true } - [dev-dependencies] pretty_assertions = { workspace = true } rstest = { workspace = true } diff --git a/crates/burn-import/build.rs b/crates/burn-import/build.rs deleted file mode 100644 index f6cd681658..0000000000 --- a/crates/burn-import/build.rs +++ /dev/null @@ -1,11 +0,0 @@ -fn main() { - if cfg!(feature = "onnx") { - // Generate the onnx protobuf files - protobuf_codegen::Codegen::new() - .pure() - .includes(["src"]) - .input("src/onnx/protos/onnx.proto") - .cargo_out_dir("onnx-protos") - .run_from_script(); - } -} diff --git a/crates/burn-import/src/onnx/mod.rs b/crates/burn-import/src/onnx/mod.rs index b0d14549fb..b0b67d79c3 100644 
--- a/crates/burn-import/src/onnx/mod.rs +++ b/crates/burn-import/src/onnx/mod.rs @@ -1,14 +1,3 @@ -mod coalesce; -mod dim_inference; -mod from_onnx; -mod ir; -mod node_remap; mod op_configuration; -mod proto_conversion; -mod protos; mod to_burn; - pub use to_burn::*; - -pub use from_onnx::parse_onnx; -pub use ir::OnnxGraph; diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index 4cb8d2d019..f99cb960e6 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -5,8 +5,8 @@ use burn::nn::{ PaddingConfig2d, }; -use super::ir::{ArgType, AttributeValue, Data, Node}; use crate::burn::node::resize::ResizeMode; +use onnx_ir::ir::{ArgType, AttributeValue, Data, Node}; /// Create a Conv1dConfig from the attributes of the node pub fn conv1d_config(curr: &Node) -> Conv1dConfig { diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index f4d287a602..09b416bbd5 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -53,20 +53,24 @@ use crate::{ }, format_tokens, logger::init_log, - onnx::{ - from_onnx::convert_constant_value, - ir::{Node, NodeType}, - op_configuration::*, - }, }; -use super::{ - from_onnx::parse_onnx, - ir::{self, ArgType, Argument, Data, ElementType, OnnxGraph}, - op_configuration::{ - avg_pool2d_config, clip_config, concat_config, dropout_config, reshape_config, - resize_config, softmax_config, +use super::op_configuration::{ + argmax_config, avg_pool1d_config, avg_pool2d_config, batch_norm_config, clip_config, + concat_config, conv1d_config, conv2d_config, conv_transpose2d_config, dropout_config, + expand_config, flatten_config, gather_config, layer_norm_config, leaky_relu_config, + linear_config, log_softmax_config, max_pool1d_config, max_pool2d_config, reduce_max_config, + reduce_mean_config, reduce_min_config, reduce_prod_config, reduce_sum_config, 
reshape_config, + resize_config, shape_config, slice_config, softmax_config, squeeze_config, transpose_config, + unsqueeze_config, +}; +use onnx_ir::{ + convert_constant_value, + ir::{ + ArgType, Argument as OnnxArgument, Data, ElementType, Node, NodeType, OnnxGraph, + TensorType as OnnxTensorType, }, + parse_onnx, }; pub use crate::burn::graph::RecordType; @@ -197,6 +201,7 @@ impl ModelGen { log::debug!("Output file: {:?}", out_file); let graph = parse_onnx(input.as_ref()); + let graph = ParsedOnnxGraph(graph); if self.development { // export the graph @@ -231,15 +236,16 @@ impl ModelGen { log::info!("Model generated"); } } - -impl OnnxGraph { +#[derive(Debug)] +struct ParsedOnnxGraph(OnnxGraph); +impl ParsedOnnxGraph { /// Converts ONNX graph to Burn graph. pub fn into_burn(self) -> BurnGraph { let mut graph = BurnGraph::::default(); let mut unsupported_ops = vec![]; - for node in self.nodes { + for node in self.0.nodes { match node.node_type { NodeType::Add => graph.register(Self::add_conversion(node)), NodeType::ArgMax => graph.register(Self::argmax_conversion(node)), @@ -328,11 +334,13 @@ impl OnnxGraph { // Get input and output names let input_names = self + .0 .inputs .iter() .map(|input| input.name.clone()) .collect::>(); let output_names = self + .0 .outputs .iter() .map(|output| output.name.clone()) @@ -390,13 +398,13 @@ impl OnnxGraph { ArgType::Shape(_) => panic!("Shape is not supported as constant value."), }; - ConstantNode::new(node.name.clone(), const_value, output.to_type()) + ConstantNode::new(node.name.clone(), const_value, Type::from(output)) } fn random_uniform_conversion(node: Node) -> RandomUniformNode { let output = node.outputs.first().unwrap(); // cannot use output.to_tensor_type() here, since it drops the shape info... 
- let output_type = if let Type::Tensor(t) = output.to_type() { + let output_type = if let Type::Tensor(t) = Type::from(output) { t } else { panic!("RandomUniform output type is no Tensor."); @@ -423,7 +431,7 @@ impl OnnxGraph { fn random_normal_conversion(node: Node) -> RandomNormalNode { let output = node.outputs.first().unwrap(); // cannot use output.to_tensor_type() here, since it drops the shape info... - let output_type = if let Type::Tensor(t) = output.to_type() { + let output_type = if let Type::Tensor(t) = Type::from(output) { t } else { panic!("RandomNormal output type is no Tensor."); @@ -448,141 +456,141 @@ impl OnnxGraph { } fn add_conversion(node: Node) -> BinaryNode { - let lhs = node.inputs.first().unwrap().to_type(); - let rhs = node.inputs.get(1).unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let lhs = Type::from(node.inputs.first().unwrap()); + let rhs = Type::from(node.inputs.get(1).unwrap()); + let output = Type::from(node.outputs.first().unwrap()); BinaryNode::add(lhs, rhs, output) } fn sub_conversion(node: Node) -> BinaryNode { - let lhs = node.inputs.first().unwrap().to_type(); - let rhs = node.inputs.get(1).unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let lhs = Type::from(node.inputs.first().unwrap()); + let rhs = Type::from(node.inputs.get(1).unwrap()); + let output = Type::from(node.outputs.first().unwrap()); BinaryNode::sub(lhs, rhs, output) } fn mul_conversion(node: Node) -> BinaryNode { - let lhs = node.inputs.first().unwrap().to_type(); - let rhs = node.inputs.get(1).unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let lhs = Type::from(node.inputs.first().unwrap()); + let rhs = Type::from(node.inputs.get(1).unwrap()); + let output = Type::from(node.outputs.first().unwrap()); BinaryNode::mul(lhs, rhs, output) } fn div_conversion(node: Node) -> BinaryNode { - let lhs = node.inputs.first().unwrap().to_type(); - let rhs = 
node.inputs.get(1).unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let lhs = Type::from(node.inputs.first().unwrap()); + let rhs = Type::from(node.inputs.get(1).unwrap()); + let output = Type::from(node.outputs.first().unwrap()); BinaryNode::div(lhs, rhs, output) } fn matmul_conversion(node: Node) -> MatmulNode { - let lhs = node.inputs.first().unwrap().to_tensor_type(); - let rhs = node.inputs.get(1).unwrap().to_tensor_type(); - let output = node.outputs.first().unwrap().to_tensor_type(); + let lhs = TensorType::from(node.inputs.first().unwrap()); + let rhs = TensorType::from(node.inputs.get(1).unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); MatmulNode::new(lhs, rhs, output) } fn equal_conversion(node: Node) -> BinaryNode { - let lhs = node.inputs.first().unwrap().to_type(); - let rhs = node.inputs.get(1).unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let lhs = Type::from(node.inputs.first().unwrap()); + let rhs = Type::from(node.inputs.get(1).unwrap()); + let output = Type::from(node.outputs.first().unwrap()); BinaryNode::equal(lhs, rhs, output) } fn max_conversion(node: Node) -> BinaryNode { - let lhs = node.inputs.first().unwrap().to_type(); - let rhs = node.inputs.get(1).unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let lhs = Type::from(node.inputs.first().unwrap()); + let rhs = Type::from(node.inputs.get(1).unwrap()); + let output = Type::from(node.outputs.first().unwrap()); BinaryNode::max_pair(lhs, rhs, output) } fn erf_conversion(node: Node) -> UnaryNode { - let input = node.inputs.first().unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let input = Type::from(node.inputs.first().unwrap()); + let output = Type::from(node.outputs.first().unwrap()); UnaryNode::erf(input, output) } fn leaky_relu_conversion(node: Node) -> UnaryNode { - let input = node.inputs.first().unwrap().to_type(); - let output = 
node.outputs.first().unwrap().to_type(); + let input = Type::from(node.inputs.first().unwrap()); + let output = Type::from(node.outputs.first().unwrap()); let alpha = leaky_relu_config(&node); UnaryNode::leaky_relu(input, output, alpha) } fn relu_conversion(node: Node) -> UnaryNode { - let input = node.inputs.first().unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let input = Type::from(node.inputs.first().unwrap()); + let output = Type::from(node.outputs.first().unwrap()); UnaryNode::relu(input, output) } fn gelu_conversion(node: Node) -> UnaryNode { - let input = node.inputs.first().unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let input = Type::from(node.inputs.first().unwrap()); + let output = Type::from(node.outputs.first().unwrap()); UnaryNode::gelu(input, output) } fn log_conversion(node: Node) -> UnaryNode { - let input = node.inputs.first().unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let input = Type::from(node.inputs.first().unwrap()); + let output = Type::from(node.outputs.first().unwrap()); UnaryNode::log(input, output) } fn flatten_conversion(node: Node) -> UnaryNode { - let input = node.inputs.first().unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let input = Type::from(node.inputs.first().unwrap()); + let output = Type::from(node.outputs.first().unwrap()); let (start_dim, end_dim) = flatten_config(&node); UnaryNode::flatten(input, output, start_dim, end_dim) } fn gather_conversion(node: Node) -> GatherNode { - let input = node.inputs.first().unwrap().to_tensor_type(); - let index = node.inputs.get(1).unwrap().to_tensor_type(); - let output = node.outputs.first().unwrap().to_tensor_type(); + let input = TensorType::from(node.inputs.first().unwrap()); + let index = TensorType::from(node.inputs.get(1).unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); let dim = gather_config(&node); GatherNode::new(input, 
index, output, dim) } fn gather_elements_conversion(node: Node) -> GatherElementsNode { - let input = node.inputs.first().unwrap().to_tensor_type(); - let index = node.inputs.get(1).unwrap().to_tensor_type(); - let output = node.outputs.first().unwrap().to_tensor_type(); + let input = TensorType::from(node.inputs.first().unwrap()); + let index = TensorType::from(node.inputs.get(1).unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); let dim = gather_config(&node); GatherElementsNode::new(input, index, output, dim) } fn transpose_conversion(node: Node) -> UnaryNode { - let input = node.inputs.first().unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let input = Type::from(node.inputs.first().unwrap()); + let output = Type::from(node.outputs.first().unwrap()); let perm = transpose_config(&node); UnaryNode::transpose(input, output, perm) } fn cast_conversion(node: Node) -> UnaryNode { - let input = node.inputs.first().unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let input = Type::from(node.inputs.first().unwrap()); + let output = Type::from(node.outputs.first().unwrap()); UnaryNode::cast(input, output) } fn reshape_conversion(node: Node) -> ReshapeNode { - let input = node.inputs.first().unwrap().to_tensor_type(); - let output = node.outputs.first().unwrap().to_tensor_type(); + let input = TensorType::from(node.inputs.first().unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); let shape = reshape_config(&node); ReshapeNode::new(input, output, shape) @@ -591,10 +599,10 @@ impl OnnxGraph { fn resize_conversion(node: Node) -> ResizeNode { let name = &node.name; - let input = node.inputs[0].to_tensor_type(); - let output_size = node.inputs[3].to_tensor_type(); + let input = TensorType::from(&node.inputs[0]); + let output_size = TensorType::from(&node.inputs[3]); - let output = node.outputs.first().unwrap().to_tensor_type(); + let output = 
TensorType::from(node.outputs.first().unwrap()); let mode = resize_config(&node); @@ -602,15 +610,15 @@ impl OnnxGraph { } fn min_conversion(node: Node) -> BinaryNode { - let lhs = node.inputs.first().unwrap().to_type(); - let rhs = node.inputs.get(1).unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let lhs = Type::from(node.inputs.first().unwrap()); + let rhs = Type::from(node.inputs.get(1).unwrap()); + let output = Type::from(node.outputs.first().unwrap()); BinaryNode::min_pair(lhs, rhs, output) } fn range_conversion(node: Node) -> RangeNode { - fn convert_arg_to_scalar(arg: &Argument) -> ScalarType { + fn convert_arg_to_scalar(arg: &OnnxArgument) -> ScalarType { match &arg.ty { ArgType::Scalar(scalar) => { ScalarType::new(arg.name.clone(), ScalarKind::from(scalar)) @@ -624,7 +632,7 @@ impl OnnxGraph { _ => panic!("Range node requires scalar inputs"), } } - let output = node.outputs.first().unwrap().to_tensor_type(); + let output = TensorType::from(node.outputs.first().unwrap()); let start = convert_arg_to_scalar(node.inputs.first().unwrap()); let end = convert_arg_to_scalar(node.inputs.get(1).unwrap()); let step = convert_arg_to_scalar(node.inputs.get(2).unwrap()); @@ -633,164 +641,156 @@ impl OnnxGraph { } fn reduce_max_conversion(node: Node) -> UnaryNode { - let input = node.inputs.first().unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let input = Type::from(node.inputs.first().unwrap()); + let output = Type::from(node.outputs.first().unwrap()); let dim = reduce_max_config(&node); UnaryNode::reduce_max(input, output, dim) } fn reduce_min_conversion(node: Node) -> UnaryNode { - let input = node.inputs.first().unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let input = Type::from(node.inputs.first().unwrap()); + let output = Type::from(node.outputs.first().unwrap()); let dim = reduce_min_config(&node); UnaryNode::reduce_min(input, output, dim) } fn 
reduce_mean_conversion(node: Node) -> UnaryNode { - let input = node.inputs.first().unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let input = Type::from(node.inputs.first().unwrap()); + let output = Type::from(node.outputs.first().unwrap()); let dim = reduce_mean_config(&node); UnaryNode::reduce_mean(input, output, dim) } fn reduce_prod_conversion(node: Node) -> UnaryNode { - let input = node.inputs.first().unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let input = Type::from(node.inputs.first().unwrap()); + let output = Type::from(node.outputs.first().unwrap()); let dim = reduce_prod_config(&node); UnaryNode::reduce_prod(input, output, dim) } fn reduce_sum_conversion(node: Node) -> UnaryNode { - let input = node.inputs.first().unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let input = Type::from(node.inputs.first().unwrap()); + let output = Type::from(node.outputs.first().unwrap()); let dim = reduce_sum_config(&node); UnaryNode::reduce_sum(input, output, dim) } fn shape_conversion(node: Node) -> UnaryNode { - let input = node.inputs.first().unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let input = Type::from(node.inputs.first().unwrap()); + let output = Type::from(node.outputs.first().unwrap()); let (start_dim, end_dim) = shape_config(&node); UnaryNode::shape(input, output, start_dim, end_dim) } fn unsqueeze_conversion(node: Node) -> UnsqueezeNode { - let input = node.inputs.first().unwrap().to_type(); - let output = node.outputs.first().unwrap().to_tensor_type(); + let input = Type::from(node.inputs.first().unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); let dims = unsqueeze_config(&node); UnsqueezeNode::new(input, output, dims) } fn where_conversion(node: Node) -> WhereNode { - let condition = node.inputs.first().unwrap().to_tensor_type(); - let x = node.inputs.get(1).unwrap().to_tensor_type(); - let y = 
node.inputs.get(2).unwrap().to_tensor_type(); - let output = node.outputs.first().unwrap().to_tensor_type(); + let condition = TensorType::from(node.inputs.first().unwrap()); + let x = TensorType::from(node.inputs.get(1).unwrap()); + let y = TensorType::from(node.inputs.get(2).unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); WhereNode::new(condition, x, y, output) } fn clip_conversion(node: Node) -> ClipNode { - let input = node.inputs.first().unwrap().to_tensor_type(); - let output = node.outputs.first().unwrap().to_tensor_type(); + let input = TensorType::from(node.inputs.first().unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); let (min, max) = clip_config(&node); ClipNode::new(input, output, min, max) } fn sigmoid_conversion(node: Node) -> UnaryNode { - let input = node.inputs.first().unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let input = Type::from(node.inputs.first().unwrap()); + let output = Type::from(node.outputs.first().unwrap()); UnaryNode::sigmoid(input, output) } fn sin_conversion(node: Node) -> UnaryNode { - let input = node.inputs.first().unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let input = Type::from(node.inputs.first().unwrap()); + let output = Type::from(node.outputs.first().unwrap()); UnaryNode::sin(input, output) } fn slice_conversion(node: Node) -> SliceNode { - let input = node.inputs.first().unwrap().to_tensor_type(); - let output = node.outputs.first().unwrap().to_tensor_type(); + let input = TensorType::from(node.inputs.first().unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); let (starts, ends) = slice_config(&node); SliceNode::new(input, output, starts, ends) } fn sum_conversion(node: Node) -> SumNode { - let inputs = node - .inputs - .iter() - .map(|input| input.to_tensor_type()) - .collect(); - let output = node.outputs.first().unwrap().to_tensor_type(); + let inputs = 
node.inputs.iter().map(TensorType::from).collect(); + let output = TensorType::from(node.outputs.first().unwrap()); SumNode::new(inputs, output) } fn reciprocal_conversion(node: Node) -> UnaryNode { - let input = node.inputs.first().unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let input = Type::from(node.inputs.first().unwrap()); + let output = Type::from(node.outputs.first().unwrap()); UnaryNode::reciprocal(input, output) } fn log_softmax_conversion(node: Node) -> UnaryNode { - let input = node.inputs.first().unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let input = Type::from(node.inputs.first().unwrap()); + let output = Type::from(node.outputs.first().unwrap()); let dim = log_softmax_config(&node); UnaryNode::log_softmax(input, output, dim) } fn softmax_conversion(node: Node) -> UnaryNode { - let input = node.inputs.first().unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let input = Type::from(node.inputs.first().unwrap()); + let output = Type::from(node.outputs.first().unwrap()); let dim = softmax_config(&node); UnaryNode::softmax(input, output, dim) } fn sqrt_conversion(node: Node) -> UnaryNode { - let input = node.inputs.first().unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let input = Type::from(node.inputs.first().unwrap()); + let output = Type::from(node.outputs.first().unwrap()); UnaryNode::sqrt(input, output) } fn tanh_conversion(node: Node) -> UnaryNode { - let input = node.inputs.first().unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let input = Type::from(node.inputs.first().unwrap()); + let output = Type::from(node.outputs.first().unwrap()); UnaryNode::tanh(input, output) } fn argmax_conversion(node: Node) -> ArgMaxNode { - let input = node.inputs.first().unwrap().to_tensor_type(); - let output = node.outputs.first().unwrap().to_tensor_type(); + let input = 
TensorType::from(node.inputs.first().unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); let axis = argmax_config(&node); ArgMaxNode::new(input, output, axis) } fn concat_conversion(node: Node) -> ConcatNode { - let inputs = node - .inputs - .iter() - .map(|input| input.to_tensor_type()) - .collect(); + let inputs = node.inputs.iter().map(TensorType::from).collect(); - let output = node.outputs.first().unwrap().to_tensor_type(); + let output = TensorType::from(node.outputs.first().unwrap()); let dim = concat_config(&node); ConcatNode::new(inputs, output, dim) @@ -798,8 +798,8 @@ impl OnnxGraph { fn linear_conversion(node: Node) -> LinearNode { let name = &node.name; - let input = node.inputs.first().unwrap().to_tensor_type(); - let output = node.outputs.first().unwrap().to_tensor_type(); + let input = TensorType::from(node.inputs.first().unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); let config = linear_config(&node); let weight = extract_data_serialize::(1, &node).expect("Weight is required"); @@ -811,8 +811,8 @@ impl OnnxGraph { fn dropout_conversion(node: Node) -> DropoutNode { let name = &node.name; - let input = node.inputs.first().unwrap().to_tensor_type(); - let output = node.outputs.first().unwrap().to_tensor_type(); + let input = TensorType::from(node.inputs.first().unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); let config = dropout_config(&node); DropoutNode::new(name, input, output, config) @@ -820,8 +820,8 @@ impl OnnxGraph { fn batch_norm_conversion(node: Node) -> BatchNormNode { let config = batch_norm_config(&node); - let input = node.inputs.first().unwrap().to_tensor_type(); - let output = node.outputs.first().unwrap().to_tensor_type(); + let input = TensorType::from(node.inputs.first().unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); let dim = input.dim - 2; let gamma = extract_data_serialize::(1, &node).expect("Gamma is required"); @@ -848,8 
+848,8 @@ impl OnnxGraph { fn layer_norm_conversion(node: Node) -> LayerNormNode { let (config, full_precision) = layer_norm_config(&node); - let input = node.inputs.first().unwrap().to_tensor_type(); - let output = node.outputs.first().unwrap().to_tensor_type(); + let input = TensorType::from(node.inputs.first().unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); // Scale tensor (aka gamma) let gamma = extract_data_serialize::(1, &node).expect("Gamma is required"); @@ -862,8 +862,8 @@ impl OnnxGraph { } fn conv1d_conversion(node: Node) -> Conv1dNode { - let input = node.inputs.first().unwrap().to_tensor_type(); - let output = node.outputs.first().unwrap().to_tensor_type(); + let input = TensorType::from(node.inputs.first().unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); let config = conv1d_config(&node); let bias = node.inputs.len() == 3; @@ -878,8 +878,8 @@ impl OnnxGraph { } fn conv2d_conversion(node: Node) -> Conv2dNode { - let input = node.inputs.first().unwrap().to_tensor_type(); - let output = node.outputs.first().unwrap().to_tensor_type(); + let input = TensorType::from(node.inputs.first().unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); let config = conv2d_config(&node); let bias = node.inputs.len() == 3; @@ -894,8 +894,8 @@ impl OnnxGraph { } fn max_pool1d_conversion(node: Node) -> MaxPool1dNode { - let input = node.inputs.first().unwrap().to_tensor_type(); - let output = node.outputs.first().unwrap().to_tensor_type(); + let input = TensorType::from(node.inputs.first().unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); let config = max_pool1d_config(&node); let name = &node.name; @@ -903,8 +903,8 @@ impl OnnxGraph { } fn max_pool2d_conversion(node: Node) -> MaxPool2dNode { - let input = node.inputs.first().unwrap().to_tensor_type(); - let output = node.outputs.first().unwrap().to_tensor_type(); + let input = TensorType::from(node.inputs.first().unwrap()); + 
let output = TensorType::from(node.outputs.first().unwrap()); let config = max_pool2d_config(&node); let name = &node.name; @@ -912,16 +912,16 @@ impl OnnxGraph { } fn prelu_conversion(node: Node) -> PReluNode { - let input = node.inputs.first().unwrap().to_tensor_type(); - let output = node.outputs.first().unwrap().to_tensor_type(); + let input = TensorType::from(node.inputs.first().unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); let weight = extract_data_serialize::(1, &node).unwrap(); let config = PReluConfig::new(); let name = &node.name; PReluNode::new(name, input, output, weight, config) } fn conv_transpose2d_conversion(node: Node) -> ConvTranspose2dNode { - let input = node.inputs.first().unwrap().to_tensor_type(); - let output = node.outputs.first().unwrap().to_tensor_type(); + let input = TensorType::from(node.inputs.first().unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); let config = conv_transpose2d_config(&node); let bias = node.inputs.len() == 3; @@ -935,8 +935,8 @@ impl OnnxGraph { ConvTranspose2dNode::new(name, input, output, weight, bias, config) } fn avg_pool_1d_conversion(node: Node) -> AvgPool1dNode { - let input = node.inputs.first().unwrap().to_tensor_type(); - let output = node.outputs.first().unwrap().to_tensor_type(); + let input = TensorType::from(node.inputs.first().unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); let config = avg_pool1d_config(&node); let name = &node.name; @@ -944,8 +944,8 @@ impl OnnxGraph { } fn avg_pool_2d_conversion(node: Node) -> AvgPool2dNode { - let input = node.inputs.first().unwrap().to_tensor_type(); - let output = node.outputs.first().unwrap().to_tensor_type(); + let input = TensorType::from(node.inputs.first().unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); let config = avg_pool2d_config(&node); let name = &node.name; @@ -953,8 +953,8 @@ impl OnnxGraph { } fn global_avg_pool_conversion(node: Node) -> 
GlobalAvgPoolNode { - let input = node.inputs.first().unwrap().to_tensor_type(); - let output = node.outputs.first().unwrap().to_tensor_type(); + let input = TensorType::from(node.inputs.first().unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); let name = &node.name; @@ -962,71 +962,71 @@ impl OnnxGraph { } fn cos_conversion(node: Node) -> UnaryNode { - let input = node.inputs.first().unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let input = Type::from(node.inputs.first().unwrap()); + let output = Type::from(node.outputs.first().unwrap()); UnaryNode::cos(input, output) } fn exp_conversion(node: Node) -> UnaryNode { - let input = node.inputs.first().unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let input = Type::from(node.inputs.first().unwrap()); + let output = Type::from(node.outputs.first().unwrap()); UnaryNode::exp(input, output) } fn expand_conversion(node: Node) -> ExpandNode { - let input = node.inputs.first().unwrap().to_tensor_type(); - let output = node.outputs.first().unwrap().to_tensor_type(); + let input = TensorType::from(node.inputs.first().unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); let shape = expand_config(&node); ExpandNode::new(input, output, shape) } fn neg_conversion(node: Node) -> UnaryNode { - let input = node.inputs.first().unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let input = Type::from(node.inputs.first().unwrap()); + let output = Type::from(node.outputs.first().unwrap()); UnaryNode::neg(input, output) } fn not_conversion(node: Node) -> UnaryNode { - let input = node.inputs.first().unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let input = Type::from(node.inputs.first().unwrap()); + let output = Type::from(node.outputs.first().unwrap()); UnaryNode::not(input, output) } fn greater_conversion(node: Node) -> BinaryNode { - let lhs = 
node.inputs.first().unwrap().to_type(); - let rhs = node.inputs.get(1).unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let lhs = Type::from(node.inputs.first().unwrap()); + let rhs = Type::from(node.inputs.get(1).unwrap()); + let output = Type::from(node.outputs.first().unwrap()); BinaryNode::greater(lhs, rhs, output) } fn less_conversion(node: Node) -> BinaryNode { - let lhs = node.inputs.first().unwrap().to_type(); - let rhs = node.inputs.get(1).unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let lhs = Type::from(node.inputs.first().unwrap()); + let rhs = Type::from(node.inputs.get(1).unwrap()); + let output = Type::from(node.outputs.first().unwrap()); BinaryNode::lower(lhs, rhs, output) } fn greater_or_equal_conversion(node: Node) -> BinaryNode { - let lhs = node.inputs.first().unwrap().to_type(); - let rhs = node.inputs.get(1).unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let lhs = Type::from(node.inputs.first().unwrap()); + let rhs = Type::from(node.inputs.get(1).unwrap()); + let output = Type::from(node.outputs.first().unwrap()); BinaryNode::greater_equal(lhs, rhs, output) } fn less_or_equal_conversion(node: Node) -> BinaryNode { - let lhs = node.inputs.first().unwrap().to_type(); - let rhs = node.inputs.get(1).unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let lhs = Type::from(node.inputs.first().unwrap()); + let rhs = Type::from(node.inputs.get(1).unwrap()); + let output = Type::from(node.outputs.first().unwrap()); BinaryNode::lower_equal(lhs, rhs, output) } fn pow_conversion(node: Node) -> BinaryNode { - let lhs = node.inputs.first().unwrap().to_type(); - let rhs = node.inputs.get(1).unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let lhs = Type::from(node.inputs.first().unwrap()); + let rhs = Type::from(node.inputs.get(1).unwrap()); + let output = Type::from(node.outputs.first().unwrap()); match &rhs { 
Type::Tensor(x) => match x.kind { TensorKind::Int => BinaryNode::powi(lhs, rhs, output), @@ -1043,14 +1043,14 @@ impl OnnxGraph { } fn sign_conversion(node: Node) -> UnaryNode { - let input = node.inputs.first().unwrap().to_type(); - let output = node.outputs.first().unwrap().to_type(); + let input = Type::from(node.inputs.first().unwrap()); + let output = Type::from(node.outputs.first().unwrap()); UnaryNode::sign(input, output) } fn squeeze_conversion(node: Node) -> SqueezeNode { - let input = node.inputs.first().unwrap().to_tensor_type(); - let output = node.outputs.first().unwrap().to_tensor_type(); + let input = TensorType::from(node.inputs.first().unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); let axes = squeeze_config(&node); SqueezeNode::new(input, output, axes) @@ -1101,48 +1101,49 @@ fn serialize_data(data: Data, shape: Vec) -> TensorData { } } -impl Argument { - pub fn to_tensor_type(&self) -> TensorType { - match &self.ty { - ArgType::Tensor(ir::TensorType { +impl From<&OnnxArgument> for TensorType { + fn from(arg: &OnnxArgument) -> Self { + match &arg.ty { + ArgType::Tensor(OnnxTensorType { elem_type: ElementType::Float16 | ElementType::Float32 | ElementType::Float64, dim, .. - }) => TensorType::new_float(self.name.clone(), *dim), - ArgType::Tensor(ir::TensorType { + }) => TensorType::new_float(arg.name.clone(), *dim), + ArgType::Tensor(OnnxTensorType { elem_type: ElementType::Int32 | ElementType::Int64, dim, .. - }) => TensorType::new_int(self.name.clone(), *dim), - ArgType::Tensor(ir::TensorType { + }) => TensorType::new_int(arg.name.clone(), *dim), + ArgType::Tensor(OnnxTensorType { elem_type: ElementType::Bool, dim, .. 
- }) => TensorType::new_bool(self.name.clone(), *dim), - _ => panic!("Can't transform to tensor."), + }) => TensorType::new_bool(arg.name.clone(), *dim), + _ => panic!("Can't transform scalar to tensor."), } } - - pub fn to_type(&self) -> Type { - match &self.ty { +} +impl From<&OnnxArgument> for Type { + fn from(arg: &OnnxArgument) -> Self { + match &arg.ty { ArgType::Tensor(tensor) => { // Treat tensor with dim 0 as scalar if tensor.dim == 0 { Type::Scalar(ScalarType::new( - self.name.clone(), + arg.name.clone(), ScalarKind::from(&tensor.elem_type), )) } else { let kind: TensorKind = tensor.elem_type.clone().into(); let dim = tensor.dim; - let name = self.name.clone(); + let name = arg.name.clone(); let shape = tensor.shape.clone(); Type::Tensor(TensorType::new(name, dim, kind, shape)) } } ArgType::Scalar(elem_type) => { - Type::Scalar(ScalarType::new(self.name.clone(), elem_type.into())) + Type::Scalar(ScalarType::new(arg.name.clone(), elem_type.into())) } ArgType::Shape(_shape) => panic!("Can't transform shape to tensor."), } diff --git a/crates/onnx-ir/Cargo.toml b/crates/onnx-ir/Cargo.toml new file mode 100644 index 0000000000..e8cfee51c5 --- /dev/null +++ b/crates/onnx-ir/Cargo.toml @@ -0,0 +1,31 @@ +[package] +authors = [ + "Dilshod Tadjibaev (@antimora)", + "Nathaniel Simard (@nathanielsimard)", +] +description = "Library for parsing ONNX models" +edition.workspace = true +license.workspace = true +name = "onnx-ir" +readme.workspace = true +repository = "https://github.com/tracel-ai/burn/tree/main/crates/onnx-ir" +version.workspace = true + + +[dependencies] +bytemuck = { workspace = true } +half = { workspace = true } +log = { workspace = true } +protobuf = { workspace = true, features = ["with-bytes"] } +regex = { workspace = true } +serde = { workspace = true, features = ["derive"] } +strum = { workspace = true } +strum_macros = { workspace = true } + + +[build-dependencies] +protobuf-codegen = { workspace = true } + +[dev-dependencies] 
+pretty_assertions = { workspace = true } +rstest = { workspace = true } diff --git a/crates/onnx-ir/README.md b/crates/onnx-ir/README.md new file mode 100644 index 0000000000..5b480be516 --- /dev/null +++ b/crates/onnx-ir/README.md @@ -0,0 +1,7 @@ +# ONNX-IR + +A pure rust Onnx Parser. Creates an intermediate representation useful for generating code in any ML/DL framework + +For a full list of currently supported operators, please check [here](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md) + +To see how to use this for generating burn graphs, see [here](crates/burn-import/src/onnx/to_burn.rs). \ No newline at end of file diff --git a/crates/onnx-ir/build.rs b/crates/onnx-ir/build.rs new file mode 100644 index 0000000000..d35064abc5 --- /dev/null +++ b/crates/onnx-ir/build.rs @@ -0,0 +1,9 @@ +fn main() { + // Generate the onnx protobuf files + protobuf_codegen::Codegen::new() + .pure() + .includes(["src"]) + .input("src/protos/onnx.proto") + .cargo_out_dir("onnx-protos") + .run_from_script(); +} diff --git a/crates/burn-import/src/onnx/coalesce.rs b/crates/onnx-ir/src/coalesce.rs similarity index 99% rename from crates/burn-import/src/onnx/coalesce.rs rename to crates/onnx-ir/src/coalesce.rs index efccadeb21..c5e8ac550e 100644 --- a/crates/burn-import/src/onnx/coalesce.rs +++ b/crates/onnx-ir/src/coalesce.rs @@ -6,7 +6,7 @@ use super::{ proto_conversion::convert_node_proto, protos::NodeProto, }; -use crate::onnx::ir::{ArgType, Data, TensorType}; +use crate::ir::{ArgType, Data, TensorType}; /// The function transforms the graph into a new one where the nodes are coalesced into a single node. 
pub fn coalesce( diff --git a/crates/burn-import/src/onnx/dim_inference.rs b/crates/onnx-ir/src/dim_inference.rs similarity index 99% rename from crates/burn-import/src/onnx/dim_inference.rs rename to crates/onnx-ir/src/dim_inference.rs index 9d0d5b3087..e39ef73bdf 100644 --- a/crates/burn-import/src/onnx/dim_inference.rs +++ b/crates/onnx-ir/src/dim_inference.rs @@ -3,10 +3,10 @@ use core::panic; use protobuf::Enum; -use super::{ +use crate::{ ir::{ArgType, AttributeValue, Data, ElementType, Node, NodeType, TensorType}, - op_configuration::flatten_config, protos::tensor_proto::DataType, + util::flatten_config, }; /// Infer the dimension of each output tensor and update them. diff --git a/crates/burn-import/src/onnx/from_onnx.rs b/crates/onnx-ir/src/from_onnx.rs similarity index 92% rename from crates/burn-import/src/onnx/from_onnx.rs rename to crates/onnx-ir/src/from_onnx.rs index e6fc091225..0dc50dc7fb 100644 --- a/crates/burn-import/src/onnx/from_onnx.rs +++ b/crates/onnx-ir/src/from_onnx.rs @@ -4,7 +4,7 @@ use std::{ path::Path, }; -use crate::onnx::node_remap::remap_node_type; +use crate::node_remap::remap_node_type; use super::{ coalesce::coalesce, @@ -56,9 +56,9 @@ pub struct GraphData { impl GraphData { pub(crate) fn new( - inputs: &Vec, - outputs: &Vec, - initializers: &Vec, + inputs: &[ValueInfoProto], + outputs: &[ValueInfoProto], + initializers: &[TensorProto], ) -> Self { let mut input_name_map = HashMap::new(); let mut input_key_map = HashMap::new(); @@ -375,35 +375,32 @@ pub fn parse_onnx(onnx_path: &Path) -> OnnxGraph { /// properly deleted if nothing else uses it /// Remap the unsqueeze node to a reshape node pub(crate) fn remap_unsqueeze_to_reshape(node: &mut Node, out_arg: &Argument) { - match &out_arg.ty { - ArgType::Tensor(output_tensor) => { - let inner = output_tensor - .shape - .clone() - .unwrap() - .into_iter() - .map(|x| x as i64) - .collect::>(); - let shape_len = inner.len(); - let new_rhs_value = Some(Data::Int64s(inner)); - //moving 
the remap to here - let rhs_arg = Argument { - name: format!("{}_generated_const", &node.name), - ty: ArgType::Tensor(TensorType { - elem_type: super::ir::ElementType::Int64, - dim: 1, - shape: Some(vec![shape_len]), - }), - value: new_rhs_value, - passed: false, - }; - // ? should this replace the old input (reuse the old key) or should it be a new key - // going with new key for now - node.inputs[1] = rhs_arg; - node.outputs[0] = out_arg.clone(); - node.node_type = NodeType::Reshape; - } - _ => {} + if let ArgType::Tensor(output_tensor) = &out_arg.ty { + let inner = output_tensor + .shape + .clone() + .unwrap() + .into_iter() + .map(|x| x as i64) + .collect::>(); + let shape_len = inner.len(); + let new_rhs_value = Some(Data::Int64s(inner)); + //moving the remap to here + let rhs_arg = Argument { + name: format!("{}_generated_const", &node.name), + ty: ArgType::Tensor(TensorType { + elem_type: super::ir::ElementType::Int64, + dim: 1, + shape: Some(vec![shape_len]), + }), + value: new_rhs_value, + passed: false, + }; + // ? 
should this replace the old input (reuse the old key) or should it be a new key + // going with new key for now + node.inputs[1] = rhs_arg; + node.outputs[0] = out_arg.clone(); + node.node_type = NodeType::Reshape; } } // Define a trait for topological sorting @@ -444,7 +441,7 @@ impl TopologicalSortable for Vec { } /// Get the value of a constant node from its attributes -pub(crate) fn convert_constant_value(node: &Node) -> Argument { +pub fn convert_constant_value(node: &Node) -> Argument { // A value can be stored in any of these attributes let keys = [ "value", diff --git a/crates/burn-import/src/onnx/ir.rs b/crates/onnx-ir/src/ir.rs similarity index 99% rename from crates/burn-import/src/onnx/ir.rs rename to crates/onnx-ir/src/ir.rs index 620ffed4ec..23020bec22 100644 --- a/crates/burn-import/src/onnx/ir.rs +++ b/crates/onnx-ir/src/ir.rs @@ -3,7 +3,7 @@ use half::f16; use std::{collections::HashMap, fmt::Formatter}; use strum_macros::{Display, EnumString}; -use super::protos::TensorProto; +use crate::protos::TensorProto; pub type Dim = usize; pub type Shape = Vec; @@ -241,7 +241,7 @@ impl PartialEq for Argument { } /// The list of supported node types (ONNX operators and some extra ones to map easily to Burn's ops) -/// Refer: https://github.com/onnx/onnx/blob/main/docs/Operators.md +/// Refer: #[derive(Debug, Hash, Eq, PartialEq, EnumString, Clone, Display)] pub enum NodeType { Abs, @@ -444,7 +444,7 @@ pub enum NodeType { } /// Truncate the vector display for debug display -fn trunc(v: &Vec) -> String { +fn trunc(v: &[T]) -> String { const BEGIN_INDEX: usize = 0; const MAX_LEN: usize = 5; let mut s = String::new(); diff --git a/crates/onnx-ir/src/lib.rs b/crates/onnx-ir/src/lib.rs new file mode 100644 index 0000000000..8c16b23adf --- /dev/null +++ b/crates/onnx-ir/src/lib.rs @@ -0,0 +1,12 @@ +mod coalesce; +mod dim_inference; +mod from_onnx; +pub mod ir; +mod node_remap; +mod proto_conversion; +mod protos; +mod util; + +pub use 
from_onnx::convert_constant_value; +pub use from_onnx::parse_onnx; +pub use ir::OnnxGraph; diff --git a/crates/burn-import/src/onnx/node_remap.rs b/crates/onnx-ir/src/node_remap.rs similarity index 100% rename from crates/burn-import/src/onnx/node_remap.rs rename to crates/onnx-ir/src/node_remap.rs diff --git a/crates/burn-import/src/onnx/proto_conversion.rs b/crates/onnx-ir/src/proto_conversion.rs similarity index 99% rename from crates/burn-import/src/onnx/proto_conversion.rs rename to crates/onnx-ir/src/proto_conversion.rs index 740db218e4..43adb76e43 100644 --- a/crates/burn-import/src/onnx/proto_conversion.rs +++ b/crates/onnx-ir/src/proto_conversion.rs @@ -1,6 +1,6 @@ use std::str::{from_utf8, FromStr}; -use crate::onnx::ir::TensorType; +use crate::ir::TensorType; use super::from_onnx::GraphData; use super::ir::Dim; diff --git a/crates/burn-import/src/onnx/protos/mod.rs b/crates/onnx-ir/src/protos/mod.rs similarity index 100% rename from crates/burn-import/src/onnx/protos/mod.rs rename to crates/onnx-ir/src/protos/mod.rs diff --git a/crates/burn-import/src/onnx/protos/onnx.proto b/crates/onnx-ir/src/protos/onnx.proto similarity index 100% rename from crates/burn-import/src/onnx/protos/onnx.proto rename to crates/onnx-ir/src/protos/onnx.proto diff --git a/crates/onnx-ir/src/util.rs b/crates/onnx-ir/src/util.rs new file mode 100644 index 0000000000..98bf0871dd --- /dev/null +++ b/crates/onnx-ir/src/util.rs @@ -0,0 +1,45 @@ +use crate::ir::{ArgType, Node}; +/// Create a FlattenConfig from the attributes of the node +pub fn flatten_config(curr: &Node) -> (usize, usize) { + // the begin dimension is the first dimension (Default: 1 per ONNX spec) + let mut start_dim: i64 = 1; + + // check if the node has only one input + if curr.inputs.len() != 1 { + panic!( + "Flatten: multiple inputs are not supported (got {:?})", + curr.inputs.len() + ); + } + + // extract the shape of the input tensor + let tensor = match curr.inputs.first().unwrap().clone().ty { + 
ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // check if the input tensor has at least 2 dimensions + if tensor.dim < 2 { + panic!( + "Flatten: input tensor must have at least 2 dimensions (got {:?})", + tensor.dim + ); + } + + // the end dimension is the last dimension + let end_dim = tensor.dim - 1; + + // extract the attributes + for (key, value) in curr.attrs.iter() { + if key.as_str() == "axis" { + start_dim = value.clone().into_i64(); + } + } + + // if beg_dim is negative, it is counted from the end + if start_dim < 0 { + start_dim += tensor.dim as i64; + } + + (start_dim as usize, end_dim) +}