From ebe24216a28e818f5e200a00688d632b18d9bd60 Mon Sep 17 00:00:00 2001 From: tiruka <33803972+tiruka@users.noreply.github.com> Date: Tue, 8 Oct 2024 21:24:42 +0900 Subject: [PATCH] Add onnx op trilu (#2323) * add trilu.py and onnx model for operators * add node/trilu.rs under burn-import and modify related files * update trilu codegen tests * add op_configuration and to_burn rust files * update how to get diagonal value from nodes argument * add trilu tests in test_onnx.rs * update files to follow clippy and format rules * add tests for both of upper and lower to follow review comment * delete onnx model of trilu.onnx since upper and lower models are added --- .gitignore | 3 + crates/burn-import/SUPPORTED-ONNX-OPS.md | 2 +- crates/burn-import/onnx-tests/build.rs | 2 + .../burn-import/onnx-tests/tests/test_onnx.rs | 40 +++++ .../onnx-tests/tests/trilu/trilu.py | 71 +++++++++ .../onnx-tests/tests/trilu/trilu_lower.onnx | Bin 0 -> 237 bytes .../onnx-tests/tests/trilu/trilu_upper.onnx | Bin 0 -> 237 bytes crates/burn-import/src/burn/node/base.rs | 6 +- crates/burn-import/src/burn/node/mod.rs | 1 + crates/burn-import/src/burn/node/trilu.rs | 139 ++++++++++++++++++ .../burn-import/src/onnx/op_configuration.rs | 23 ++- crates/burn-import/src/onnx/to_burn.rs | 11 +- crates/onnx-ir/src/dim_inference.rs | 1 + 13 files changed, 295 insertions(+), 4 deletions(-) create mode 100644 crates/burn-import/onnx-tests/tests/trilu/trilu.py create mode 100644 crates/burn-import/onnx-tests/tests/trilu/trilu_lower.onnx create mode 100644 crates/burn-import/onnx-tests/tests/trilu/trilu_upper.onnx create mode 100644 crates/burn-import/src/burn/node/trilu.rs diff --git a/.gitignore b/.gitignore index ffa113f25c..c47cb50116 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,6 @@ target # Generated IR and Burn Graph from ONNX out + +# Virtual Environment of Python +.venv diff --git a/crates/burn-import/SUPPORTED-ONNX-OPS.md b/crates/burn-import/SUPPORTED-ONNX-OPS.md index 6319fedbe7..b01aa9bdfb 100644 --- a/crates/burn-import/SUPPORTED-ONNX-OPS.md +++ b/crates/burn-import/SUPPORTED-ONNX-OPS.md @@ -194,7 +194,7 @@ represent the corresponding Burn Op. | [Tile][185] | ✅ | ✅ | | [TopK][186] | ❌ | ✅ | | [Transpose][187] | ✅ | ✅ | -| [Trilu][188] | ❌ | ✅ | +| [Trilu][188] | ✅ | ✅ | | [Unique][189] | ❌ | ❌ | | [Upsample][190] | ❌ | ❌ | | [Where][191] | ✅ | ✅ | diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs index a7360e012d..7091c43056 100644 --- a/crates/burn-import/onnx-tests/build.rs +++ b/crates/burn-import/onnx-tests/build.rs @@ -107,6 +107,8 @@ fn main() { .input("tests/sum/sum_int.onnx") .input("tests/tanh/tanh.onnx") .input("tests/tile/tile.onnx") + .input("tests/trilu/trilu_upper.onnx") + .input("tests/trilu/trilu_lower.onnx") .input("tests/transpose/transpose.onnx") .input("tests/unsqueeze/unsqueeze.onnx") .input("tests/unsqueeze/unsqueeze_opset11.onnx") diff --git a/crates/burn-import/onnx-tests/tests/test_onnx.rs b/crates/burn-import/onnx-tests/tests/test_onnx.rs index 49380c6d89..b7dd160246 100644 --- a/crates/burn-import/onnx-tests/tests/test_onnx.rs +++ b/crates/burn-import/onnx-tests/tests/test_onnx.rs @@ -116,6 +116,8 @@ include_models!( sum_int, tanh, tile, + trilu_upper, + trilu_lower, transpose, unsqueeze, unsqueeze_opset11, @@ -1871,6 +1873,44 @@ mod tests { output.assert_eq(&expected, true); } + #[test] + fn trilu_upper() { + let device = Default::default(); + let model: trilu_upper::Model = trilu_upper::Model::new(&device); + let input = Tensor::::from_floats( + [[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]], + &device, + ); + let expected = TensorData::from([[ + [1.0_f32, 2.0_f32, 3.0_f32], + [0.0_f32, 5.0_f32, 6.0_f32], + [0.0_f32, 0.0_f32, 9.0_f32], + ]]); + + let output = model.forward(input).to_data(); + + output.assert_eq(&expected, true); + } + + #[test] + fn trilu_lower() { + let device = Default::default(); + let model: trilu_lower::Model = trilu_lower::Model::new(&device); + let input = Tensor::::from_floats( + [[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]], + &device, + ); + let expected = TensorData::from([[ + [1.0_f32, 0.0_f32, 0.0_f32], + [4.0_f32, 5.0_f32, 0.0_f32], + [7.0_f32, 8.0_f32, 9.0_f32], + ]]); + + let output = model.forward(input).to_data(); + + output.assert_eq(&expected, true); + } + #[test] fn unsqueeze() { let device = Default::default(); diff --git a/crates/burn-import/onnx-tests/tests/trilu/trilu.py b/crates/burn-import/onnx-tests/tests/trilu/trilu.py new file mode 100644 index 0000000000..036bf71f90 --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/trilu/trilu.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 + +import torch +import torch.nn as nn + + +class TriluModel(nn.Module): + def __init__(self, upper=True, diagonal=0): + super(TriluModel, self).__init__() + self.upper = upper # Determines upper or lower triangular + self.diagonal = diagonal # Diagonal offset + + def forward(self, x): + # torch.tril or torch.triu based on 'upper' attribute + if self.upper: + return torch.triu(x, diagonal=self.diagonal) + else: + return torch.tril(x, diagonal=self.diagonal) + + +def main(): + # Set seed for reproducibility + torch.manual_seed(42) + + # Set print options for better precision output + torch.set_printoptions(precision=8) + + # Export Trilu Upper Model + upper = True # Change to False for lower triangular matrix + diagonal = 1 # Change k to adjust the diagonal + lower_model = TriluModel(upper=upper, diagonal=diagonal) + lower_model.eval() + device = torch.device("cpu") + + # Generate test input: a 2D matrix or batch of 2D matrices + upper_file_name = "trilu_upper.onnx" + test_input = torch.randn(2, 4, 4, device=device) # 2 batches of 4x4 matrices + torch.onnx.export(lower_model, test_input, upper_file_name, + verbose=False, opset_version=16) + + print("Finished exporting model to {}".format(upper_file_name)) + + # Output some test data for use in the test + print("Test input data: {}".format(test_input)) + print("Test input data shape: {}".format(test_input.shape)) + output = lower_model.forward(test_input) + print("Test output data shape: {}".format(output.shape)) + print("Test output: {}".format(output)) + + + # Export Trilu Lower Model + upper = False + diagonal = 1 + lower_model = TriluModel(upper=upper, diagonal=diagonal) + lower_model.eval() + # Generate test input: a 2D matrix or batch of 2D matrices + upper_file_name = "trilu_lower.onnx" + test_input = torch.randn(2, 4, 4, device=device) # 2 batches of 4x4 matrices + torch.onnx.export(lower_model, test_input, upper_file_name, + verbose=False, opset_version=16) + + print("Finished exporting model to {}".format(upper_file_name)) + + print("Test input data: {}".format(test_input)) + print("Test input data shape: {}".format(test_input.shape)) + output = lower_model.forward(test_input) + print("Test output data shape: {}".format(output.shape)) + print("Test output: {}".format(output)) + +if __name__ == '__main__': + main() diff --git a/crates/burn-import/onnx-tests/tests/trilu/trilu_lower.onnx b/crates/burn-import/onnx-tests/tests/trilu/trilu_lower.onnx new file mode 100644 index 0000000000000000000000000000000000000000..d20eb150fa4d6ac70bfd3aac4e7739886de4ce6e GIT binary patch literal 237 zcmd4Oz0 zv4SZr9xm3>f`Zf{35ErXOhR0_iJ5uv=|zbJ8Bxk;1_}vr32-n9@o+J5Faa?O2q%ei NF&bf$aAFb=007&8I3551 literal 0 HcmV?d00001 diff --git a/crates/burn-import/onnx-tests/tests/trilu/trilu_upper.onnx b/crates/burn-import/onnx-tests/tests/trilu/trilu_upper.onnx new file mode 100644 index 0000000000000000000000000000000000000000..ddd92475683e6cd2e529c19a50e2395c28477831 GIT binary patch literal 237 zcmd4Oz0 zv4SZr9xm3>f`Zf{3C0DCOhR0_iJ5uv=|zbJ8Bxk;1_}vr32-n9@o+J5Faa?O2q%ei NF&bf$aAFb=007(5I3EB2 literal 0 HcmV?d00001 diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs index a1c9103b41..1846f3e4c2 100644 --- a/crates/burn-import/src/burn/node/base.rs +++ b/crates/burn-import/src/burn/node/base.rs @@ -11,7 +11,8 @@ use super::{ max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, pad::PadNode, prelu::PReluNode, random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode, - squeeze::SqueezeNode, sum::SumNode, tile::TileNode, unary::UnaryNode, unsqueeze::UnsqueezeNode, + squeeze::SqueezeNode, sum::SumNode, tile::TileNode, trilu::TriluNode, unary::UnaryNode, + unsqueeze::UnsqueezeNode, }; use crate::burn::{BurnImports, Scope, Type}; use burn::backend::NdArray; @@ -114,6 +115,7 @@ pub enum Node { Squeeze(SqueezeNode), Sum(SumNode), Tile(TileNode), + Trilu(TriluNode), Unary(UnaryNode), Unsqueeze(UnsqueezeNode), Where(WhereNode), @@ -162,6 +164,7 @@ macro_rules! match_all { Node::Squeeze(node) => $func(node), Node::Sum(node) => $func(node), Node::Tile(node) => $func(node), + Node::Trilu(node) => $func(node), Node::Unary(node) => $func(node), Node::Unsqueeze(node) => $func(node), Node::Where(node) => $func(node), @@ -218,6 +221,7 @@ impl Node { Node::Squeeze(_) => "squeeze", Node::Sum(_) => "add", Node::Tile(_) => "tile", + Node::Trilu(_) => "trilu", Node::Unary(unary) => unary.kind.as_str(), Node::Unsqueeze(_) => "unsqueeze", Node::Where(_) => "where", diff --git a/crates/burn-import/src/burn/node/mod.rs b/crates/burn-import/src/burn/node/mod.rs index ee294ddfd7..62411cb764 100644 --- a/crates/burn-import/src/burn/node/mod.rs +++ b/crates/burn-import/src/burn/node/mod.rs @@ -37,6 +37,7 @@ pub(crate) mod slice; pub(crate) mod squeeze; pub(crate) mod sum; pub(crate) mod tile; +pub(crate) mod trilu; pub(crate) mod unary; pub(crate) mod unsqueeze; pub(crate) use base::*; diff --git a/crates/burn-import/src/burn/node/trilu.rs b/crates/burn-import/src/burn/node/trilu.rs new file mode 100644 index 0000000000..0fa211d871 --- /dev/null +++ b/crates/burn-import/src/burn/node/trilu.rs @@ -0,0 +1,139 @@ +use super::{Node, NodeCodegen}; +use crate::burn::{Scope, TensorType, ToTokens, Type}; +use burn::config::Config; +use burn::record::PrecisionSettings; +use proc_macro2::TokenStream; +use quote::quote; + +#[derive(Config, Debug)] +pub struct TriluConfig { + pub upper: bool, + pub diagonal: i64, +} + +#[derive(Debug, Clone, new)] +pub struct TriluNode { + pub input: TensorType, + pub output: TensorType, + pub config: TriluConfig, +} + +impl NodeCodegen for TriluNode { + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + fn input_types(&self) -> Vec { + vec![Type::Tensor(self.input.clone())] + } + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name; + let diagonal = self.config.diagonal.to_tokens(); + if self.config.upper { + quote! { + let #output = #input.triu(#diagonal); + } + } else { + quote! { + let #output = #input.tril(#diagonal); + } + } + } + fn into_node(self) -> super::Node { + Node::Trilu(self) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::{test::assert_tokens, trilu::TriluConfig, trilu::TriluNode}, + TensorType, + }; + use burn::record::FullPrecisionSettings; + + #[test] + fn test_codegen_triu() { + let mut graph = BurnGraph::::default(); + let config = TriluConfig::new(true, 0); + graph.register(TriluNode::new( + TensorType::new_float("input", 2), + TensorType::new_float("output", 2), + config, + )); + graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + Self { + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward(&self, input: Tensor) -> Tensor { + let output = input.triu(0); + output + } + } + }; + + assert_tokens(graph.codegen(), expected); + } + + #[test] + fn test_codegen_tril() { + let mut graph = BurnGraph::::default(); + let config = TriluConfig::new(false, 0); + graph.register(TriluNode::new( + TensorType::new_float("input", 2), + TensorType::new_float("output", 2), + config, + )); + graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + Self { + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward(&self, input: Tensor) -> Tensor { + let output = input.tril(0); + output + } + } + }; + assert_tokens(graph.codegen(), expected); + } +} diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index 4621b129d6..e8fcf27830 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -7,7 +7,9 @@ use burn::nn::{ PaddingConfig2d, PaddingConfig3d, }; -use crate::burn::node::{expand::ExpandShape, pad::PadConfig, tile::TileConfig}; +use crate::burn::node::{ + expand::ExpandShape, pad::PadConfig, tile::TileConfig, trilu::TriluConfig, +}; use onnx_ir::ir::{ArgType, AttributeValue, Data, ElementType, Node}; /// Create a Conv1dConfig from the attributes of the node @@ -795,6 +797,25 @@ pub fn tile_config(node: &Node) -> TileConfig { TileConfig::new(repeat) } +/// Create a TriluConfig from the attributes of the node +pub fn trilu_config(node: &Node) -> TriluConfig { + let mut upper = true; + let mut diagonal = 0; + for (key, value) in node.attrs.iter() { + match key.as_str() { + "upper" => upper = value.clone().into_i64() != 0, + _ => {} + } + } + // The second input of the Trilu node is the diagonal value, coming from a constant node + if let Some(diagonal_arg) = node.inputs.get(1) { + if let Some(Data::Int64(diagonal_val)) = &diagonal_arg.value { + diagonal = *diagonal_val; + } + } + TriluConfig::new(upper, diagonal) +} + /// Create a PadConfig from the attributes of the node pub fn pad_config(node: &Node) -> PadConfig { fn get_pads_input(node: &Node) -> Vec { diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index 5c4f34078f..f302b044c1 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -51,6 +51,7 @@ use crate::{ squeeze::SqueezeNode, sum::SumNode, tile::TileNode, + trilu::TriluNode, unary::UnaryNode, unsqueeze::UnsqueezeNode, }, @@ -68,7 +69,7 @@ use super::op_configuration::{ max_pool1d_config, max_pool2d_config, pad_config, reduce_max_config, reduce_mean_config, reduce_min_config, reduce_prod_config, reduce_sum_config, reshape_config, resize_config, shape_config, slice_config, softmax_config, squeeze_config, tile_config, transpose_config, - unsqueeze_config, + trilu_config, unsqueeze_config, }; use onnx_ir::{ convert_constant_value, @@ -338,6 +339,7 @@ impl ParsedOnnxGraph { NodeType::Squeeze => graph.register(Self::squeeze_conversion(node)), NodeType::RandomUniform => graph.register(Self::random_uniform_conversion(node)), NodeType::Tile => graph.register(Self::tile_conversion(node)), + NodeType::Trilu => graph.register(Self::trilu_conversion(node)), NodeType::RandomNormal => graph.register(Self::random_normal_conversion(node)), NodeType::ConstantOfShape => { graph.register(Self::constant_of_shape_conversion(node)) @@ -1184,6 +1186,13 @@ impl ParsedOnnxGraph { TileNode::new(input, output, config) } + + fn trilu_conversion(node: Node) -> TriluNode { + let input = TensorType::from(node.inputs.first().unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); + let config = trilu_config(&node); + TriluNode::new(input, output, config) + } } /// Extract data from node states and convert it to `TensorData`. diff --git a/crates/onnx-ir/src/dim_inference.rs b/crates/onnx-ir/src/dim_inference.rs index ff580b37aa..7bd156a569 100644 --- a/crates/onnx-ir/src/dim_inference.rs +++ b/crates/onnx-ir/src/dim_inference.rs @@ -83,6 +83,7 @@ pub fn dim_inference(node: &mut Node) { NodeType::Sum => same_as_input_broadcast(node), NodeType::Tanh => same_as_input(node), NodeType::Transpose => same_as_input(node), + NodeType::Trilu => same_as_input(node), NodeType::Unsqueeze => unsqueeze_update_output(node), NodeType::Where => where_update_outputs(node), // Intentionally letting outputs leave unchanged but issue a warning so IR file can be generated.