-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add binding to
addGather
to NetworkDefinition object
Implements the binding for the network definition as well as the `GatherLayer` object itself. Also added the tensorrt_rs_derive create to create a custom derive macro for deriving the `Layer` trait. We're going to be doing this a lot for the next handful of changes so it was just easier to create a custom derive macro.
- Loading branch information
Showing
14 changed files
with
214 additions
and
38 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
[workspace] | ||
members = [ | ||
"tensorrt", | ||
"tensorrt-sys" | ||
"tensorrt-sys", | ||
"tensorrt_rs_derive", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
// | ||
// Created by mason on 10/12/20. | ||
// | ||
|
||
#include "TRTGatherLayer.h" | ||
#include "TRTLayerInternal.hpp" | ||
|
||
int32_t gather_layer_get_gather_axis(Layer_t *layer) { | ||
auto concrete = dynamic_cast<nvinfer1::IGatherLayer*>(layer->internal_layer); | ||
return concrete->getGatherAxis(); | ||
} | ||
|
||
void gather_layer_set_gather_axis(Layer_t *layer, int32_t axis) { | ||
auto concrete = dynamic_cast<nvinfer1::IGatherLayer*>(layer->internal_layer); | ||
concrete->setGatherAxis(axis); | ||
} | ||
|
||
void gather_layer_destroy(Layer_t *layer) { | ||
delete layer; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
// | ||
// Created by mason on 10/12/20. | ||
// | ||
|
||
#ifndef LIBTRT_TRTGATHERLAYER_H | ||
#define LIBTRT_TRTGATHERLAYER_H | ||
|
||
#include <stdint.h> | ||
#include "TRTLayer.h" | ||
|
||
#ifdef __cplusplus | ||
extern "C" { | ||
#endif | ||
|
||
int32_t gather_layer_get_gather_axis(Layer_t *layer); | ||
|
||
void gather_layer_set_gather_axis(Layer_t *layer, int32_t axis); | ||
|
||
void gather_layer_destroy(Layer_t *layer); | ||
|
||
#ifdef __cplusplus | ||
}; | ||
#endif | ||
|
||
#endif //LIBTRT_TRTGATHERLAYER_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
use super::*; | ||
use tensorrt_rs_derive::Layer; | ||
use tensorrt_sys::{ | ||
gather_layer_destroy, gather_layer_get_gather_axis, gather_layer_set_gather_axis, | ||
}; | ||
|
||
#[derive(Layer)] | ||
pub struct GatherLayer { | ||
pub(crate) internal_layer: *mut tensorrt_sys::Layer_t, | ||
} | ||
|
||
impl GatherLayer { | ||
pub fn get_gather_axis(&self) -> i32 { | ||
unsafe { gather_layer_get_gather_axis(self.internal_layer) } | ||
} | ||
|
||
pub fn set_gather_axis(&self, axis: i32) { | ||
unsafe { gather_layer_set_gather_axis(self.internal_layer, axis) } | ||
} | ||
} | ||
|
||
impl Drop for GatherLayer { | ||
fn drop(&mut self) { | ||
unsafe { gather_layer_destroy(self.internal_layer) } | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
use crate::builder::Builder; | ||
use crate::dims::DimsHW; | ||
use crate::network::Network; | ||
use crate::runtime::Logger; | ||
use lazy_static::lazy_static; | ||
use std::sync::Mutex; | ||
|
||
lazy_static! { | ||
static ref LOGGER: Mutex<Logger> = Mutex::new(Logger::new()); | ||
} | ||
|
||
fn create_network(logger: &Logger) -> Network { | ||
let builder = Builder::new(logger); | ||
builder.create_network() | ||
} | ||
|
||
#[test] | ||
fn get_gather_axis() { | ||
let logger = match LOGGER.lock() { | ||
Ok(guard) => guard, | ||
Err(poisoned) => poisoned.into_inner(), | ||
}; | ||
let network = create_network(&logger); | ||
|
||
let input1 = network.add_input("new_input1", DataType::Float, DimsHW::new(10, 10)); | ||
let input2 = network.add_input("new_input2", DataType::Float, DimsHW::new(10, 10)); | ||
let gather_layer = network.add_gather_layer(&input1, &input2, 1); | ||
|
||
assert_eq!(gather_layer.get_gather_axis(), 1); | ||
} | ||
|
||
#[test] | ||
fn set_gather_axis() { | ||
let logger = match LOGGER.lock() { | ||
Ok(guard) => guard, | ||
Err(poisoned) => poisoned.into_inner(), | ||
}; | ||
let network = create_network(&logger); | ||
|
||
let input1 = network.add_input("new_input1", DataType::Float, DimsHW::new(10, 10)); | ||
let input2 = network.add_input("new_input2", DataType::Float, DimsHW::new(10, 10)); | ||
let gather_layer = network.add_gather_layer(&input1, &input2, 1); | ||
|
||
gather_layer.set_gather_axis(0); | ||
assert_eq!(gather_layer.get_gather_axis(), 0); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,7 @@ | ||
use super::private::LayerPrivate; | ||
use super::*; | ||
use tensorrt_rs_derive::Layer; | ||
|
||
#[derive(Layer)] | ||
pub struct IdentityLayer { | ||
pub(crate) internal_layer: *mut tensorrt_sys::Layer_t, | ||
} | ||
|
||
impl Layer for IdentityLayer {} | ||
|
||
impl LayerPrivate for IdentityLayer { | ||
fn get_internal_layer(&self) -> *mut tensorrt_sys::Layer_t { | ||
self.internal_layer | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
[package] | ||
name = "tensorrt_rs_derive" | ||
version = "0.1.0" | ||
authors = ["Mason Stallmo <[email protected]>"] | ||
edition = "2018" | ||
|
||
[lib] | ||
proc-macro = true | ||
|
||
[dependencies] | ||
syn = "1.0.44" | ||
quote = "1.0.7" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
use proc_macro::TokenStream; | ||
use quote::quote; | ||
use syn; | ||
|
||
#[proc_macro_derive(Layer)] | ||
pub fn layer_derive(input: TokenStream) -> TokenStream { | ||
let ast = syn::parse(input).unwrap(); | ||
|
||
impl_layer_derive(&ast) | ||
} | ||
|
||
fn impl_layer_derive(ast: &syn::DeriveInput) -> TokenStream { | ||
let name = &ast.ident; | ||
let gen = quote! { | ||
impl private::LayerPrivate for #name { | ||
fn get_internal_layer(&self) -> *mut tensorrt_sys::Layer_t { | ||
self.internal_layer | ||
} | ||
} | ||
|
||
impl Layer for #name {} | ||
}; | ||
|
||
gen.into() | ||
} |