
Commit ebe2421

Add onnx op trilu (#2323)
* add trilu.py and onnx model for operators
* add node/trilu.rs under burn-import and modify related files
* update trilu codegen tests
* add op_configuration and to_burn rust files
* update how to get diagonal value from nodes argument
* add trilu tests in test_onnx.rs
* update files to follow clippy and format rules
* add tests for both of upper and lower to follow review comment
* delete onnx model of trilu.onnx since upper and lower models are added
1 parent dcd6e7f commit ebe2421
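In Burn terms, ONNX `Trilu` maps onto the tensor methods `triu` (keep the upper triangle) and `tril` (keep the lower triangle), selected by the op's `upper` attribute. A minimal sketch of that mapping (the helper name and `NdArray` backend are illustrative only, not part of the commit):

```rust
use burn::backend::NdArray;
use burn::tensor::Tensor;

// ONNX Trilu with upper = 1 corresponds to triu, upper = 0 to tril;
// `diagonal` (the ONNX `k` input) shifts which diagonal the triangle is cut along.
fn trilu_like(x: Tensor<NdArray, 2>, upper: bool, diagonal: i64) -> Tensor<NdArray, 2> {
    if upper {
        x.triu(diagonal) // zero out elements below the chosen diagonal
    } else {
        x.tril(diagonal) // zero out elements above the chosen diagonal
    }
}
```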

File tree

13 files changed (+295, -4 lines)

.gitignore

Lines changed: 3 additions & 0 deletions

```diff
@@ -12,3 +12,6 @@ target
 
 # Generated IR and Burn Graph from ONNX
 out
+
+# Virtual Environment of Python
+.venv
```

crates/burn-import/SUPPORTED-ONNX-OPS.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -194,7 +194,7 @@ represent the corresponding Burn Op.
 | [Tile][185] | ✅ | ✅ |
 | [TopK][186] | ✅ | ✅ |
 | [Transpose][187] | ✅ | ✅ |
-| [Trilu][188] | ❌ | ❌ |
+| [Trilu][188] | ✅ | ✅ |
 | [Unique][189] | ❌ | ❌ |
 | [Upsample][190] | ❌ | ❌ |
 | [Where][191] | ✅ | ✅ |
```

crates/burn-import/onnx-tests/build.rs

Lines changed: 2 additions & 0 deletions

```diff
@@ -107,6 +107,8 @@ fn main() {
         .input("tests/sum/sum_int.onnx")
         .input("tests/tanh/tanh.onnx")
         .input("tests/tile/tile.onnx")
+        .input("tests/trilu/trilu_upper.onnx")
+        .input("tests/trilu/trilu_lower.onnx")
         .input("tests/transpose/transpose.onnx")
         .input("tests/unsqueeze/unsqueeze.onnx")
         .input("tests/unsqueeze/unsqueeze_opset11.onnx")
```

crates/burn-import/onnx-tests/tests/test_onnx.rs

Lines changed: 40 additions & 0 deletions

```diff
@@ -116,6 +116,8 @@ include_models!(
     sum_int,
     tanh,
     tile,
+    trilu_upper,
+    trilu_lower,
     transpose,
     unsqueeze,
     unsqueeze_opset11,
@@ -1871,6 +1873,44 @@ mod tests {
         output.assert_eq(&expected, true);
     }
 
+    #[test]
+    fn trilu_upper() {
+        let device = Default::default();
+        let model: trilu_upper::Model<Backend> = trilu_upper::Model::new(&device);
+        let input = Tensor::<Backend, 3>::from_floats(
+            [[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]],
+            &device,
+        );
+        let expected = TensorData::from([[
+            [1.0_f32, 2.0_f32, 3.0_f32],
+            [0.0_f32, 5.0_f32, 6.0_f32],
+            [0.0_f32, 0.0_f32, 9.0_f32],
+        ]]);
+
+        let output = model.forward(input).to_data();
+
+        output.assert_eq(&expected, true);
+    }
+
+    #[test]
+    fn trilu_lower() {
+        let device = Default::default();
+        let model: trilu_lower::Model<Backend> = trilu_lower::Model::new(&device);
+        let input = Tensor::<Backend, 3>::from_floats(
+            [[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]],
+            &device,
+        );
+        let expected = TensorData::from([[
+            [1.0_f32, 0.0_f32, 0.0_f32],
+            [4.0_f32, 5.0_f32, 0.0_f32],
+            [7.0_f32, 8.0_f32, 9.0_f32],
+        ]]);
+
+        let output = model.forward(input).to_data();
+
+        output.assert_eq(&expected, true);
+    }
+
     #[test]
     fn unsqueeze() {
         let device = Default::default();
```
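The `include_models!` entries bring in the Rust modules that `ModelGen` generated for each ONNX file. As a rough sketch of what the macro plausibly expands to for the two new models (an assumption based on the out-dir convention, not a quote of the actual macro):

```rust
// Hypothetical expansion of include_models!(trilu_upper, trilu_lower):
pub mod trilu_upper {
    include!(concat!(env!("OUT_DIR"), "/model/trilu_upper.rs"));
}
pub mod trilu_lower {
    include!(concat!(env!("OUT_DIR"), "/model/trilu_lower.rs"));
}
```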
crates/burn-import/onnx-tests/tests/trilu/trilu.py (new file)

Lines changed: 71 additions & 0 deletions

```python
#!/usr/bin/env python3

import torch
import torch.nn as nn


class TriluModel(nn.Module):
    def __init__(self, upper=True, diagonal=0):
        super(TriluModel, self).__init__()
        self.upper = upper  # Determines upper or lower triangular
        self.diagonal = diagonal  # Diagonal offset

    def forward(self, x):
        # torch.tril or torch.triu based on 'upper' attribute
        if self.upper:
            return torch.triu(x, diagonal=self.diagonal)
        else:
            return torch.tril(x, diagonal=self.diagonal)


def main():
    # Set seed for reproducibility
    torch.manual_seed(42)

    # Set print options for better precision output
    torch.set_printoptions(precision=8)

    # Export Trilu Upper Model
    upper = True  # Change to False for lower triangular matrix
    diagonal = 1  # Adjust to change the diagonal offset (the ONNX `k` input)
    upper_model = TriluModel(upper=upper, diagonal=diagonal)
    upper_model.eval()
    device = torch.device("cpu")

    # Generate test input: a 2D matrix or batch of 2D matrices
    upper_file_name = "trilu_upper.onnx"
    test_input = torch.randn(2, 4, 4, device=device)  # 2 batches of 4x4 matrices
    torch.onnx.export(upper_model, test_input, upper_file_name,
                      verbose=False, opset_version=16)

    print("Finished exporting model to {}".format(upper_file_name))

    # Output some test data for use in the test
    print("Test input data: {}".format(test_input))
    print("Test input data shape: {}".format(test_input.shape))
    output = upper_model.forward(test_input)
    print("Test output data shape: {}".format(output.shape))
    print("Test output: {}".format(output))

    # Export Trilu Lower Model
    upper = False
    diagonal = 1
    lower_model = TriluModel(upper=upper, diagonal=diagonal)
    lower_model.eval()
    # Generate test input: a 2D matrix or batch of 2D matrices
    lower_file_name = "trilu_lower.onnx"
    test_input = torch.randn(2, 4, 4, device=device)  # 2 batches of 4x4 matrices
    torch.onnx.export(lower_model, test_input, lower_file_name,
                      verbose=False, opset_version=16)

    print("Finished exporting model to {}".format(lower_file_name))

    print("Test input data: {}".format(test_input))
    print("Test input data shape: {}".format(test_input.shape))
    output = lower_model.forward(test_input)
    print("Test output data shape: {}".format(output.shape))
    print("Test output: {}".format(output))


if __name__ == '__main__':
    main()
```
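Running this script from its own directory regenerates `trilu_upper.onnx` and `trilu_lower.onnx` beside it; the `build.rs` additions above then convert those files into the `trilu_upper` and `trilu_lower` modules exercised in `test_onnx.rs`.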
crates/burn-import/onnx-tests/tests/trilu/trilu_lower.onnx

Binary file not shown.

crates/burn-import/onnx-tests/tests/trilu/trilu_upper.onnx

Binary file not shown.

crates/burn-import/src/burn/node/base.rs

Lines changed: 5 additions & 1 deletion

```diff
@@ -11,7 +11,8 @@ use super::{
     max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, pad::PadNode,
     prelu::PReluNode, random_normal::RandomNormalNode, random_uniform::RandomUniformNode,
     range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode,
-    squeeze::SqueezeNode, sum::SumNode, tile::TileNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
+    squeeze::SqueezeNode, sum::SumNode, tile::TileNode, trilu::TriluNode, unary::UnaryNode,
+    unsqueeze::UnsqueezeNode,
 };
 use crate::burn::{BurnImports, Scope, Type};
 use burn::backend::NdArray;
@@ -114,6 +115,7 @@ pub enum Node<PS: PrecisionSettings> {
     Squeeze(SqueezeNode),
     Sum(SumNode),
     Tile(TileNode),
+    Trilu(TriluNode),
     Unary(UnaryNode),
     Unsqueeze(UnsqueezeNode),
     Where(WhereNode),
@@ -162,6 +164,7 @@ macro_rules! match_all {
             Node::Squeeze(node) => $func(node),
             Node::Sum(node) => $func(node),
             Node::Tile(node) => $func(node),
+            Node::Trilu(node) => $func(node),
             Node::Unary(node) => $func(node),
             Node::Unsqueeze(node) => $func(node),
             Node::Where(node) => $func(node),
@@ -218,6 +221,7 @@ impl<PS: PrecisionSettings> Node<PS> {
             Node::Squeeze(_) => "squeeze",
             Node::Sum(_) => "add",
             Node::Tile(_) => "tile",
+            Node::Trilu(_) => "trilu",
             Node::Unary(unary) => unary.kind.as_str(),
             Node::Unsqueeze(_) => "unsqueeze",
             Node::Where(_) => "where",
```

crates/burn-import/src/burn/node/mod.rs

Lines changed: 1 addition & 0 deletions

```diff
@@ -37,6 +37,7 @@ pub(crate) mod slice;
 pub(crate) mod squeeze;
 pub(crate) mod sum;
 pub(crate) mod tile;
+pub(crate) mod trilu;
 pub(crate) mod unary;
 pub(crate) mod unsqueeze;
 pub(crate) use base::*;
```
crates/burn-import/src/burn/node/trilu.rs (new file)

Lines changed: 139 additions & 0 deletions

```rust
use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, ToTokens, Type};
use burn::config::Config;
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;

#[derive(Config, Debug)]
pub struct TriluConfig {
    pub upper: bool,
    pub diagonal: i64,
}

#[derive(Debug, Clone, new)]
pub struct TriluNode {
    pub input: TensorType,
    pub output: TensorType,
    pub config: TriluConfig,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for TriluNode {
    fn output_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.output.clone())]
    }
    fn input_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.input.clone())]
    }
    fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
        let input = scope.tensor_use_owned(&self.input, node_position);
        let output = &self.output.name;
        let diagonal = self.config.diagonal.to_tokens();
        if self.config.upper {
            quote! {
                let #output = #input.triu(#diagonal);
            }
        } else {
            quote! {
                let #output = #input.tril(#diagonal);
            }
        }
    }
    fn into_node(self) -> super::Node<PS> {
        Node::Trilu(self)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::burn::{
        graph::BurnGraph,
        node::{test::assert_tokens, trilu::TriluConfig, trilu::TriluNode},
        TensorType,
    };
    use burn::record::FullPrecisionSettings;

    #[test]
    fn test_codegen_triu() {
        let mut graph = BurnGraph::<FullPrecisionSettings>::default();
        let config = TriluConfig::new(true, 0);
        graph.register(TriluNode::new(
            TensorType::new_float("input", 2),
            TensorType::new_float("output", 2),
            config,
        ));
        graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);

        let expected = quote! {
            use burn::{
                module::Module,
                tensor::{backend::Backend, Tensor},
            };

            #[derive(Module, Debug)]
            pub struct Model<B: Backend> {
                phantom: core::marker::PhantomData<B>,
                device: burn::module::Ignored<B::Device>,
            }

            impl<B: Backend> Model<B> {
                #[allow(unused_variables)]
                pub fn new(device: &B::Device) -> Self {
                    Self {
                        phantom: core::marker::PhantomData,
                        device: burn::module::Ignored(device.clone()),
                    }
                }
                #[allow(clippy::let_and_return, clippy::approx_constant)]
                pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
                    let output = input.triu(0);
                    output
                }
            }
        };

        assert_tokens(graph.codegen(), expected);
    }

    #[test]
    fn test_codegen_tril() {
        let mut graph = BurnGraph::<FullPrecisionSettings>::default();
        let config = TriluConfig::new(false, 0);
        graph.register(TriluNode::new(
            TensorType::new_float("input", 2),
            TensorType::new_float("output", 2),
            config,
        ));
        graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);

        let expected = quote! {
            use burn::{
                module::Module,
                tensor::{backend::Backend, Tensor},
            };

            #[derive(Module, Debug)]
            pub struct Model<B: Backend> {
                phantom: core::marker::PhantomData<B>,
                device: burn::module::Ignored<B::Device>,
            }

            impl<B: Backend> Model<B> {
                #[allow(unused_variables)]
                pub fn new(device: &B::Device) -> Self {
                    Self {
                        phantom: core::marker::PhantomData,
                        device: burn::module::Ignored(device.clone()),
                    }
                }
                #[allow(clippy::let_and_return, clippy::approx_constant)]
                pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
                    let output = input.tril(0);
                    output
                }
            }
        };
        assert_tokens(graph.codegen(), expected);
    }
}
```
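The commit message also mentions changes to `to_burn.rs`, which this view does not show. Mirroring the existing conversions there (tile, pad, and friends), the wiring is presumably along these lines (a hypothetical sketch, not the committed code):

```rust
// Hypothetical trilu_conversion in to_burn.rs, mirroring tile_conversion:
fn trilu_conversion(node: Node) -> TriluNode {
    let input = TensorType::from(node.inputs.first().unwrap());
    let output = TensorType::from(node.outputs.first().unwrap());
    let config = trilu_config(&node);

    TriluNode::new(input, output, config)
}
```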

crates/burn-import/src/onnx/op_configuration.rs

Lines changed: 22 additions & 1 deletion

```diff
@@ -7,7 +7,9 @@ use burn::nn::{
     PaddingConfig2d, PaddingConfig3d,
 };
 
-use crate::burn::node::{expand::ExpandShape, pad::PadConfig, tile::TileConfig};
+use crate::burn::node::{
+    expand::ExpandShape, pad::PadConfig, tile::TileConfig, trilu::TriluConfig,
+};
 use onnx_ir::ir::{ArgType, AttributeValue, Data, ElementType, Node};
 
 /// Create a Conv1dConfig from the attributes of the node
@@ -795,6 +797,25 @@ pub fn tile_config(node: &Node) -> TileConfig {
     TileConfig::new(repeat)
 }
 
+/// Create a TriluConfig from the attributes of the node
+pub fn trilu_config(node: &Node) -> TriluConfig {
+    let mut upper = true;
+    let mut diagonal = 0;
+    for (key, value) in node.attrs.iter() {
+        match key.as_str() {
+            "upper" => upper = value.clone().into_i64() != 0,
+            _ => {}
+        }
+    }
+    // The second input of the Trilu node is the diagonal value, coming from a constant node
+    if let Some(diagonal_arg) = node.inputs.get(1) {
+        if let Some(Data::Int64(diagonal_val)) = &diagonal_arg.value {
+            diagonal = *diagonal_val;
+        }
+    }
+    TriluConfig::new(upper, diagonal)
+}
+
 /// Create a PadConfig from the attributes of the node
 pub fn pad_config(node: &Node) -> PadConfig {
     fn get_pads_input(node: &Node) -> Vec<i64> {
```
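The fallbacks in `trilu_config` (`upper = true`, `diagonal = 0`) match the ONNX spec, where the `upper` attribute defaults to 1 and the optional second input `k` defaults to 0. Since the offset is easy to misread, here is a small sketch of its effect using the same tensor methods the node generates (backend and function name chosen purely for illustration):

```rust
use burn::backend::NdArray;
use burn::tensor::Tensor;

fn diagonal_offsets() {
    let device = Default::default();
    let x = Tensor::<NdArray, 2>::from_floats(
        [[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]],
        &device,
    );
    // diagonal = 0 keeps the main diagonal:
    let _u0 = x.clone().triu(0); // [[1, 2, 3], [0, 5, 6], [0, 0, 9]]
    // diagonal = 1 cuts one diagonal higher, dropping the main diagonal too:
    let _u1 = x.triu(1); // [[0, 2, 3], [0, 0, 6], [0, 0, 0]]
}
```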
