Add onnx op trilu (#2323)
* add trilu.py and onnx model for operators

* add node/trilu.rs under burn-import and modify related files

* update trilu codegen tests

* add op_configuration and to_burn rust files

* update how the diagonal value is read from the node's arguments

* add trilu tests in test_onnx.rs

* update files to follow clippy and format rules

* add tests for both upper and lower per review comment

* delete trilu.onnx since separate upper and lower models are added
tiruka authored Oct 8, 2024
1 parent dcd6e7f commit ebe2421
Showing 13 changed files with 295 additions and 4 deletions.
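For context, ONNX Trilu returns the upper or lower triangular part of its input (selected by the upper attribute, shifted by a diagonal offset k), which this commit maps onto Burn's triu/tril tensor ops. A minimal sketch of those semantics, assuming a crate with burn and its ndarray backend feature enabled:

    use burn::backend::NdArray;
    use burn::tensor::Tensor;

    fn main() {
        let device = Default::default();
        let x = Tensor::<NdArray, 2>::from_floats(
            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
            &device,
        );
        // upper = 1: zero everything below the k-th diagonal (here k = 0).
        let upper = x.clone().triu(0); // [[1, 2, 3], [0, 5, 6], [0, 0, 9]]
        // upper = 0: zero everything above the k-th diagonal.
        let lower = x.tril(0); // [[1, 0, 0], [4, 5, 0], [7, 8, 9]]
        println!("{upper}\n{lower}");
    }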
3 changes: 3 additions & 0 deletions .gitignore
@@ -12,3 +12,6 @@ target

# Generated IR and Burn Graph from ONNX
out

# Virtual Environment of Python
.venv
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
@@ -194,7 +194,7 @@ represent the corresponding Burn Op.
| [Tile][185] |||
| [TopK][186] |||
| [Transpose][187] |||
-| [Trilu][188] | ❌ | ❌ |
+| [Trilu][188] | ✅ | ✅ |
| [Unique][189] |||
| [Upsample][190] |||
| [Where][191] |||
2 changes: 2 additions & 0 deletions crates/burn-import/onnx-tests/build.rs
@@ -107,6 +107,8 @@ fn main() {
.input("tests/sum/sum_int.onnx")
.input("tests/tanh/tanh.onnx")
.input("tests/tile/tile.onnx")
.input("tests/trilu/trilu_upper.onnx")
.input("tests/trilu/trilu_lower.onnx")
.input("tests/transpose/transpose.onnx")
.input("tests/unsqueeze/unsqueeze.onnx")
.input("tests/unsqueeze/unsqueeze_opset11.onnx")
40 changes: 40 additions & 0 deletions crates/burn-import/onnx-tests/tests/test_onnx.rs
@@ -116,6 +116,8 @@ include_models!(
sum_int,
tanh,
tile,
trilu_upper,
trilu_lower,
transpose,
unsqueeze,
unsqueeze_opset11,
@@ -1871,6 +1873,44 @@
output.assert_eq(&expected, true);
}

#[test]
fn trilu_upper() {
let device = Default::default();
let model: trilu_upper::Model<Backend> = trilu_upper::Model::new(&device);
let input = Tensor::<Backend, 3>::from_floats(
[[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]],
&device,
);
let expected = TensorData::from([[
[1.0_f32, 2.0_f32, 3.0_f32],
[0.0_f32, 5.0_f32, 6.0_f32],
[0.0_f32, 0.0_f32, 9.0_f32],
]]);

let output = model.forward(input).to_data();

output.assert_eq(&expected, true);
}

#[test]
fn trilu_lower() {
let device = Default::default();
let model: trilu_lower::Model<Backend> = trilu_lower::Model::new(&device);
let input = Tensor::<Backend, 3>::from_floats(
[[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]],
&device,
);
let expected = TensorData::from([[
[1.0_f32, 0.0_f32, 0.0_f32],
[4.0_f32, 5.0_f32, 0.0_f32],
[7.0_f32, 8.0_f32, 9.0_f32],
]]);

let output = model.forward(input).to_data();

output.assert_eq(&expected, true);
}

#[test]
fn unsqueeze() {
let device = Default::default();
71 changes: 71 additions & 0 deletions crates/burn-import/onnx-tests/tests/trilu/trilu.py
@@ -0,0 +1,71 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn


class TriluModel(nn.Module):
def __init__(self, upper=True, diagonal=0):
super(TriluModel, self).__init__()
self.upper = upper # Determines upper or lower triangular
self.diagonal = diagonal # Diagonal offset

def forward(self, x):
# torch.tril or torch.triu based on 'upper' attribute
if self.upper:
return torch.triu(x, diagonal=self.diagonal)
else:
return torch.tril(x, diagonal=self.diagonal)


def main():
# Set seed for reproducibility
torch.manual_seed(42)

# Set print options for better precision output
torch.set_printoptions(precision=8)

    # Export Trilu Upper Model
    upper = True  # Set to False for a lower triangular matrix
    diagonal = 1  # Diagonal offset k; adjust to shift which diagonal is kept
    upper_model = TriluModel(upper=upper, diagonal=diagonal)
    upper_model.eval()
device = torch.device("cpu")

# Generate test input: a 2D matrix or batch of 2D matrices
upper_file_name = "trilu_upper.onnx"
test_input = torch.randn(2, 4, 4, device=device) # 2 batches of 4x4 matrices
    torch.onnx.export(upper_model, test_input, upper_file_name,
                      verbose=False, opset_version=16)

print("Finished exporting model to {}".format(upper_file_name))

# Output some test data for use in the test
print("Test input data: {}".format(test_input))
print("Test input data shape: {}".format(test_input.shape))
    output = upper_model(test_input)
print("Test output data shape: {}".format(output.shape))
print("Test output: {}".format(output))


# Export Trilu Lower Model
upper = False
diagonal = 1
lower_model = TriluModel(upper=upper, diagonal=diagonal)
lower_model.eval()
# Generate test input: a 2D matrix or batch of 2D matrices
    lower_file_name = "trilu_lower.onnx"
test_input = torch.randn(2, 4, 4, device=device) # 2 batches of 4x4 matrices
    torch.onnx.export(lower_model, test_input, lower_file_name,
                      verbose=False, opset_version=16)

print("Finished exporting model to {}".format(upper_file_name))

print("Test input data: {}".format(test_input))
print("Test input data shape: {}".format(test_input.shape))
    output = lower_model(test_input)
print("Test output data shape: {}".format(output.shape))
print("Test output: {}".format(output))

if __name__ == '__main__':
main()
Binary file added crates/burn-import/onnx-tests/tests/trilu/trilu_lower.onnx (not shown)
Binary file added crates/burn-import/onnx-tests/tests/trilu/trilu_upper.onnx (not shown)
6 changes: 5 additions & 1 deletion crates/burn-import/src/burn/node/base.rs
@@ -11,7 +11,8 @@ use super::{
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, pad::PadNode,
prelu::PReluNode, random_normal::RandomNormalNode, random_uniform::RandomUniformNode,
range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode,
-    squeeze::SqueezeNode, sum::SumNode, tile::TileNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
+    squeeze::SqueezeNode, sum::SumNode, tile::TileNode, trilu::TriluNode, unary::UnaryNode,
+    unsqueeze::UnsqueezeNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::backend::NdArray;
@@ -114,6 +115,7 @@ pub enum Node<PS: PrecisionSettings> {
Squeeze(SqueezeNode),
Sum(SumNode),
Tile(TileNode),
Trilu(TriluNode),
Unary(UnaryNode),
Unsqueeze(UnsqueezeNode),
Where(WhereNode),
@@ -162,6 +164,7 @@ macro_rules! match_all {
Node::Squeeze(node) => $func(node),
Node::Sum(node) => $func(node),
Node::Tile(node) => $func(node),
Node::Trilu(node) => $func(node),
Node::Unary(node) => $func(node),
Node::Unsqueeze(node) => $func(node),
Node::Where(node) => $func(node),
@@ -218,6 +221,7 @@ impl<PS: PrecisionSettings> Node<PS> {
Node::Squeeze(_) => "squeeze",
Node::Sum(_) => "add",
Node::Tile(_) => "tile",
Node::Trilu(_) => "trilu",
Node::Unary(unary) => unary.kind.as_str(),
Node::Unsqueeze(_) => "unsqueeze",
Node::Where(_) => "where",
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
@@ -37,6 +37,7 @@ pub(crate) mod slice;
pub(crate) mod squeeze;
pub(crate) mod sum;
pub(crate) mod tile;
pub(crate) mod trilu;
pub(crate) mod unary;
pub(crate) mod unsqueeze;
pub(crate) use base::*;
139 changes: 139 additions & 0 deletions crates/burn-import/src/burn/node/trilu.rs
@@ -0,0 +1,139 @@
use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, ToTokens, Type};
use burn::config::Config;
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;

#[derive(Config, Debug)]
pub struct TriluConfig {
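    /// Keep the upper triangle (`triu`) when true; the lower triangle (`tril`) otherwise.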
pub upper: bool,
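    /// Diagonal offset: 0 is the main diagonal, positive values lie above it, negative below.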
pub diagonal: i64,
}

#[derive(Debug, Clone, new)]
pub struct TriluNode {
pub input: TensorType,
pub output: TensorType,
pub config: TriluConfig,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for TriluNode {
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.output.clone())]
}
fn input_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.input.clone())]
}
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
let input = scope.tensor_use_owned(&self.input, node_position);
let output = &self.output.name;
let diagonal = self.config.diagonal.to_tokens();
if self.config.upper {
quote! {
let #output = #input.triu(#diagonal);
}
} else {
quote! {
let #output = #input.tril(#diagonal);
}
}
}
fn into_node(self) -> super::Node<PS> {
Node::Trilu(self)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::burn::{
graph::BurnGraph,
node::{test::assert_tokens, trilu::TriluConfig, trilu::TriluNode},
TensorType,
};
use burn::record::FullPrecisionSettings;

#[test]
fn test_codegen_triu() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
let config = TriluConfig::new(true, 0);
graph.register(TriluNode::new(
TensorType::new_float("input", 2),
TensorType::new_float("output", 2),
config,
));
graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);

let expected = quote! {
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}

impl<B: Backend> Model<B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
let output = input.triu(0);
output
}
}
};

assert_tokens(graph.codegen(), expected);
}

#[test]
fn test_codegen_tril() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
let config = TriluConfig::new(false, 0);
graph.register(TriluNode::new(
TensorType::new_float("input", 2),
TensorType::new_float("output", 2),
config,
));
graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);

let expected = quote! {
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}

impl<B: Backend> Model<B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
let output = input.tril(0);
output
}
}
};
assert_tokens(graph.codegen(), expected);
}
}
23 changes: 22 additions & 1 deletion crates/burn-import/src/onnx/op_configuration.rs
@@ -7,7 +7,9 @@ use burn::nn::{
PaddingConfig2d, PaddingConfig3d,
};

-use crate::burn::node::{expand::ExpandShape, pad::PadConfig, tile::TileConfig};
+use crate::burn::node::{
+    expand::ExpandShape, pad::PadConfig, tile::TileConfig, trilu::TriluConfig,
+};
use onnx_ir::ir::{ArgType, AttributeValue, Data, ElementType, Node};

/// Create a Conv1dConfig from the attributes of the node
@@ -795,6 +797,25 @@ pub fn tile_config(node: &Node) -> TileConfig {
TileConfig::new(repeat)
}

/// Create a TriluConfig from the attributes of the node
pub fn trilu_config(node: &Node) -> TriluConfig {
let mut upper = true;
let mut diagonal = 0;
for (key, value) in node.attrs.iter() {
match key.as_str() {
"upper" => upper = value.clone().into_i64() != 0,
_ => {}
}
}
// The second input of the Trilu node is the diagonal value, coming from a constant node
if let Some(diagonal_arg) = node.inputs.get(1) {
if let Some(Data::Int64(diagonal_val)) = &diagonal_arg.value {
diagonal = *diagonal_val;
}
}
TriluConfig::new(upper, diagonal)
}

/// Create a PadConfig from the attributes of the node
pub fn pad_config(node: &Node) -> PadConfig {
fn get_pads_input(node: &Node) -> Vec<i64> {
(Remaining file diffs not shown.)

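The to_burn.rs side of the change is among the diffs not shown; by analogy with neighboring ops such as Tile, the wiring presumably looks like the sketch below (the function name and overall shape are assumptions, not the commit's verbatim code):

    // Hypothetical conversion step from an onnx-ir node to the codegen node,
    // mirroring the pattern used for ops like Tile alongside tile_config.
    fn trilu_conversion(node: Node) -> TriluNode {
        let input = TensorType::from(node.inputs.first().unwrap());
        let output = TensorType::from(node.outputs.first().unwrap());
        let config = trilu_config(&node);
        TriluNode::new(input, output, config)
    }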