* add trilu.py and ONNX models for the operator
* add node/trilu.rs under burn-import and modify related files
* update trilu codegen tests
* add op_configuration and to_burn Rust files
* update how the diagonal value is read from the node's arguments (see the sketch below)
* add trilu tests in test_onnx.rs
* update files to follow clippy and format rules
* add tests for both upper and lower, following review feedback
* delete trilu.onnx, since separate upper and lower models are now added
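Per the ONNX specification, Trilu takes its `upper` flag as an integer attribute (default 1), while the diagonal offset `k` arrives as an optional second input; that is why the diagonal value has to be read from the node's arguments rather than its attributes. The op_configuration changes are not shown on this page, so the following is only a minimal, hypothetical sketch of that extraction logic; the helper name and both parameter names are illustrative, not taken from the codebase.

// Hypothetical sketch: map ONNX Trilu settings onto (upper, diagonal).
// `upper_attr` stands in for the node's `upper` attribute and `k_input`
// for its optional second input; neither name is from the real code.
fn trilu_config(upper_attr: Option<i64>, k_input: Option<i64>) -> (bool, i64) {
    let upper = upper_attr.unwrap_or(1) != 0; // ONNX default: upper = 1
    let diagonal = k_input.unwrap_or(0); // ONNX default: k = 0
    (upper, diagonal)
}

fn main() {
    assert_eq!(trilu_config(None, None), (true, 0));
    assert_eq!(trilu_config(Some(0), Some(1)), (false, 1));
}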
Showing 13 changed files with 295 additions and 4 deletions.
.gitignore
@@ -12,3 +12,6 @@ target

# Generated IR and Burn Graph from ONNX
out

# Virtual Environment of Python
.venv
trilu.py
@@ -0,0 +1,71 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn


class TriluModel(nn.Module):
    def __init__(self, upper=True, diagonal=0):
        super(TriluModel, self).__init__()
        self.upper = upper  # Determines upper or lower triangular
        self.diagonal = diagonal  # Diagonal offset

    def forward(self, x):
        # torch.triu or torch.tril based on the 'upper' attribute
        if self.upper:
            return torch.triu(x, diagonal=self.diagonal)
        else:
            return torch.tril(x, diagonal=self.diagonal)


def main():
    # Set seed for reproducibility
    torch.manual_seed(42)

    # Set print options for better precision output
    torch.set_printoptions(precision=8)

    # Export Trilu Upper Model
    upper = True
    diagonal = 1  # Adjust to shift the diagonal
    upper_model = TriluModel(upper=upper, diagonal=diagonal)
    upper_model.eval()
    device = torch.device("cpu")

    # Generate test input: a batch of 2D matrices
    upper_file_name = "trilu_upper.onnx"
    test_input = torch.randn(2, 4, 4, device=device)  # 2 batches of 4x4 matrices
    torch.onnx.export(upper_model, test_input, upper_file_name,
                      verbose=False, opset_version=16)

    print("Finished exporting model to {}".format(upper_file_name))

    # Output some test data for use in the test
    print("Test input data: {}".format(test_input))
    print("Test input data shape: {}".format(test_input.shape))
    output = upper_model(test_input)
    print("Test output data shape: {}".format(output.shape))
    print("Test output: {}".format(output))

    # Export Trilu Lower Model
    upper = False
    diagonal = 1
    lower_model = TriluModel(upper=upper, diagonal=diagonal)
    lower_model.eval()
    # Generate test input: a batch of 2D matrices
    lower_file_name = "trilu_lower.onnx"
    test_input = torch.randn(2, 4, 4, device=device)  # 2 batches of 4x4 matrices
    torch.onnx.export(lower_model, test_input, lower_file_name,
                      verbose=False, opset_version=16)

    print("Finished exporting model to {}".format(lower_file_name))

    print("Test input data: {}".format(test_input))
    print("Test input data shape: {}".format(test_input.shape))
    output = lower_model(test_input)
    print("Test output data shape: {}".format(output.shape))
    print("Test output: {}".format(output))


if __name__ == '__main__':
    main()
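On the Burn side, the exported graphs map onto Tensor::triu and Tensor::tril, which the generated code calls (see node/trilu.rs below). A minimal sketch of the equivalent computation with diagonal = 1, assuming the NdArray backend feature is enabled:

use burn::backend::NdArray;
use burn::tensor::Tensor;

fn main() {
    let device = Default::default();
    let x = Tensor::<NdArray, 2>::from_floats(
        [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
        &device,
    );
    // triu(1) zeroes everything on and below the main diagonal:
    // [[0, 2, 3], [0, 0, 6], [0, 0, 0]]
    println!("{}", x.clone().triu(1));
    // tril(1) keeps the first superdiagonal and everything below it:
    // [[1, 2, 0], [4, 5, 6], [7, 8, 9]]
    println!("{}", x.tril(1));
}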
Binary files trilu_upper.onnx and trilu_lower.onnx not shown.
node/trilu.rs
@@ -0,0 +1,139 @@
use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, ToTokens, Type};
use burn::config::Config;
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;

#[derive(Config, Debug)]
pub struct TriluConfig {
    pub upper: bool,
    pub diagonal: i64,
}

#[derive(Debug, Clone, new)]
pub struct TriluNode {
    pub input: TensorType,
    pub output: TensorType,
    pub config: TriluConfig,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for TriluNode {
    fn output_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.output.clone())]
    }

    fn input_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.input.clone())]
    }

    fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
        let input = scope.tensor_use_owned(&self.input, node_position);
        let output = &self.output.name;
        let diagonal = self.config.diagonal.to_tokens();
        if self.config.upper {
            quote! {
                let #output = #input.triu(#diagonal);
            }
        } else {
            quote! {
                let #output = #input.tril(#diagonal);
            }
        }
    }

    fn into_node(self) -> super::Node<PS> {
        Node::Trilu(self)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::burn::{
        graph::BurnGraph,
        node::{test::assert_tokens, trilu::TriluConfig, trilu::TriluNode},
        TensorType,
    };
    use burn::record::FullPrecisionSettings;

    #[test]
    fn test_codegen_triu() {
        let mut graph = BurnGraph::<FullPrecisionSettings>::default();
        let config = TriluConfig::new(true, 0);
        graph.register(TriluNode::new(
            TensorType::new_float("input", 2),
            TensorType::new_float("output", 2),
            config,
        ));
        graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);

        let expected = quote! {
            use burn::{
                module::Module,
                tensor::{backend::Backend, Tensor},
            };

            #[derive(Module, Debug)]
            pub struct Model<B: Backend> {
                phantom: core::marker::PhantomData<B>,
                device: burn::module::Ignored<B::Device>,
            }

            impl<B: Backend> Model<B> {
                #[allow(unused_variables)]
                pub fn new(device: &B::Device) -> Self {
                    Self {
                        phantom: core::marker::PhantomData,
                        device: burn::module::Ignored(device.clone()),
                    }
                }
                #[allow(clippy::let_and_return, clippy::approx_constant)]
                pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
                    let output = input.triu(0);
                    output
                }
            }
        };

        assert_tokens(graph.codegen(), expected);
    }

    #[test]
    fn test_codegen_tril() {
        let mut graph = BurnGraph::<FullPrecisionSettings>::default();
        let config = TriluConfig::new(false, 0);
        graph.register(TriluNode::new(
            TensorType::new_float("input", 2),
            TensorType::new_float("output", 2),
            config,
        ));
        graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);

        let expected = quote! {
            use burn::{
                module::Module,
                tensor::{backend::Backend, Tensor},
            };

            #[derive(Module, Debug)]
            pub struct Model<B: Backend> {
                phantom: core::marker::PhantomData<B>,
                device: burn::module::Ignored<B::Device>,
            }

            impl<B: Backend> Model<B> {
                #[allow(unused_variables)]
                pub fn new(device: &B::Device) -> Self {
                    Self {
                        phantom: core::marker::PhantomData,
                        device: burn::module::Ignored(device.clone()),
                    }
                }
                #[allow(clippy::let_and_return, clippy::approx_constant)]
                pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
                    let output = input.tril(0);
                    output
                }
            }
        };

        assert_tokens(graph.codegen(), expected);
    }
}
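To make the expected codegen concrete, here is a standalone sketch whose Model mirrors the generated code asserted in test_codegen_triu above; the NdArray backend choice and the main function are assumptions for illustration, not part of the commit.

use burn::module::Module;
use burn::tensor::{backend::Backend, Tensor};

// Mirror of the code the triu codegen test above expects to be generated.
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    phantom: core::marker::PhantomData<B>,
    device: burn::module::Ignored<B::Device>,
}

impl<B: Backend> Model<B> {
    pub fn new(device: &B::Device) -> Self {
        Self {
            phantom: core::marker::PhantomData,
            device: burn::module::Ignored(device.clone()),
        }
    }

    pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        input.triu(0)
    }
}

fn main() {
    type B = burn::backend::NdArray;
    let device = Default::default();
    let model = Model::<B>::new(&device);
    let x = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    // triu(0) keeps the main diagonal and above: [[1.0, 2.0], [0.0, 4.0]]
    println!("{}", model.forward(x));
}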