Skip to content

Commit

Permalink
Add binding to addGather to NetworkDefinition object
Browse files Browse the repository at this point in the history
Implements the binding for the network definition as well as the `GatherLayer` object
itself.

Also added the tensorrt_rs_derive crate to create a custom derive macro for deriving the
`Layer` trait. We're going to be doing this a lot for the next handful of changes so it was
just easier to create a custom derive macro.
  • Loading branch information
mstallmo committed Oct 13, 2020
1 parent ea8ff5f commit 58be907
Show file tree
Hide file tree
Showing 14 changed files with 214 additions and 38 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[workspace]
members = [
"tensorrt",
"tensorrt-sys"
"tensorrt-sys",
"tensorrt_rs_derive",
]
20 changes: 20 additions & 0 deletions tensorrt-sys/trt-sys/TRTLayer/TRTGatherLayer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//
// Created by mason on 10/12/20.
//

#include "TRTGatherLayer.h"
#include "TRTLayerInternal.hpp"

// Returns the gather axis of the wrapped IGatherLayer.
// NOTE(review): the dynamic_cast result is dereferenced unchecked, so passing
// a Layer_t that does not wrap an IGatherLayer is undefined behavior here —
// callers (the Rust GatherLayer wrapper) must guarantee the layer type.
int32_t gather_layer_get_gather_axis(Layer_t *layer) {
    auto *gather = dynamic_cast<nvinfer1::IGatherLayer *>(layer->internal_layer);
    return gather->getGatherAxis();
}

// Sets the gather axis of the wrapped IGatherLayer.
// NOTE(review): as with the getter, the dynamic_cast result is dereferenced
// unchecked; callers must only pass gather layers.
void gather_layer_set_gather_axis(Layer_t *layer, int32_t axis) {
    auto *gather = dynamic_cast<nvinfer1::IGatherLayer *>(layer->internal_layer);
    gather->setGatherAxis(axis);
}

// Frees the heap-allocated Layer wrapper created by network_add_gather.
// Only the wrapper is released; the underlying TensorRT layer is owned by
// the network it was added to.
void gather_layer_destroy(Layer_t *layer) {
    delete layer;
}
25 changes: 25 additions & 0 deletions tensorrt-sys/trt-sys/TRTLayer/TRTGatherLayer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
//
// Created by mason on 10/12/20.
//

#ifndef LIBTRT_TRTGATHERLAYER_H
#define LIBTRT_TRTGATHERLAYER_H

#include <stdint.h>
#include "TRTLayer.h"

#ifdef __cplusplus
extern "C" {
#endif

/* Returns the axis along which the wrapped IGatherLayer gathers entries. */
int32_t gather_layer_get_gather_axis(Layer_t *layer);

/* Sets the axis along which the wrapped IGatherLayer gathers entries. */
void gather_layer_set_gather_axis(Layer_t *layer, int32_t axis);

/* Frees the Layer wrapper itself, not the underlying TensorRT layer. */
void gather_layer_destroy(Layer_t *layer);

#ifdef __cplusplus
};
#endif

#endif //LIBTRT_TRTGATHERLAYER_H
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,7 @@ Layer_t *network_add_element_wise(Network_t *network, Tensor_t *input1, Tensor_t
return new Layer(network->internal_network->addElementWise(*input1->internal_tensor, *input2->internal_tensor,
static_cast<nvinfer1::ElementWiseOperation>(op)));
}

// Appends an IGatherLayer to `network`, gathering entries of `data` along
// `axis` at the positions given by `indices`. The returned Layer wrapper is
// heap-allocated; the caller owns it and releases it via gather_layer_destroy.
Layer_t *network_add_gather(Network_t *network, Tensor_t *data, Tensor_t *indices, int32_t axis) {
    auto *gather = network->internal_network->addGather(
            *data->internal_tensor, *indices->internal_tensor, axis);
    return new Layer(gather);
}
10 changes: 6 additions & 4 deletions tensorrt-sys/trt-sys/TRTNetworkDefinition/TRTNetworkDefinition.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ typedef struct Network Network_t;

void destroy_network(Network_t *network);

Tensor_t* network_add_input(Network_t *network, const char* name, DataType_t dataType, Dims_t *dims);
Tensor_t *network_add_input(Network_t *network, const char *name, DataType_t dataType, Dims_t *dims);

int network_get_nb_layers(Network_t *network);

Layer_t* network_get_layer(Network_t *network, int index);
Layer_t *network_get_layer(Network_t *network, int index);

Layer_t* network_add_identity_layer(Network_t *network, Tensor_t* inputTensor);
Layer_t *network_add_identity_layer(Network_t *network, Tensor_t *inputTensor);

int network_get_nb_inputs(Network_t *network);

Expand All @@ -42,7 +42,9 @@ void network_mark_output(Network_t *network, Tensor_t *tensor);
void network_unmark_output(Network_t *network, Tensor_t *tensor);


Layer_t* network_add_element_wise(Network_t *network, Tensor_t *input1, Tensor_t *input2, ElementWiseOperation_t op);
Layer_t *network_add_element_wise(Network_t *network, Tensor_t *input1, Tensor_t *input2, ElementWiseOperation_t op);

Layer_t *network_add_gather(Network_t *network, Tensor_t *data, Tensor_t *indices, int32_t axis);

#ifdef __cplusplus
};
Expand Down
1 change: 1 addition & 0 deletions tensorrt-sys/trt-sys/tensorrt_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@
#include "TRTTensor/TRTTensor.h"
#include "TRTLayer/TRTLayer.h"
#include "TRTLayer/TRTElementWiseLayer.h"
#include "TRTLayer/TRTGatherLayer.h"

#endif //TENSRORT_SYS_TENSORRT_API_H
3 changes: 2 additions & 1 deletion tensorrt/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ trt-7 = ["tensorrt-sys/trt-7"]

[dependencies]
# Uncomment when working locally
tensorrt-sys = { path = "../tensorrt-sys"}
tensorrt-sys = { path = "../tensorrt-sys" }
# tensorrt-sys = "0.3"
tensorrt_rs_derive = { path = "../tensorrt_rs_derive" }
ndarray = "0.13"
ndarray-image = "0.2"
image = "0.23"
Expand Down
11 changes: 2 additions & 9 deletions tensorrt/src/network/layer/element_wise_layer.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::private::LayerPrivate;
use super::*;
use num_derive::FromPrimitive;
use num_traits::FromPrimitive;
use tensorrt_rs_derive::Layer;
use tensorrt_sys::{elementwise_destroy, elementwise_get_operation, elementwise_set_operation};

#[repr(C)]
Expand All @@ -16,6 +16,7 @@ pub enum ElementWiseOperation {
Pow,
}

#[derive(Layer)]
pub struct ElementWiseLayer {
pub(crate) internal_layer: *mut tensorrt_sys::Layer_t,
}
Expand All @@ -31,14 +32,6 @@ impl ElementWiseLayer {
}
}

impl LayerPrivate for ElementWiseLayer {
fn get_internal_layer(&self) -> *mut tensorrt_sys::Layer_t {
self.internal_layer
}
}

impl Layer for ElementWiseLayer {}

impl Drop for ElementWiseLayer {
fn drop(&mut self) {
unsafe { elementwise_destroy(self.internal_layer) }
Expand Down
77 changes: 77 additions & 0 deletions tensorrt/src/network/layer/gather_layer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
use super::*;
use tensorrt_rs_derive::Layer;
use tensorrt_sys::{
gather_layer_destroy, gather_layer_get_gather_axis, gather_layer_set_gather_axis,
};

/// Safe wrapper around a TensorRT `IGatherLayer`.
///
/// The derived `Layer` impl provides the common layer accessors through the
/// raw pointer stored in `internal_layer`.
#[derive(Layer)]
pub struct GatherLayer {
    // Owned raw pointer to the C-side wrapper; released in `Drop`.
    pub(crate) internal_layer: *mut tensorrt_sys::Layer_t,
}

impl GatherLayer {
    /// Returns the axis along which this layer gathers elements.
    pub fn get_gather_axis(&self) -> i32 {
        // SAFETY: `internal_layer` is a valid pointer for the lifetime of this
        // wrapper (created by `Network::add_gather_layer`, freed only in Drop).
        unsafe { gather_layer_get_gather_axis(self.internal_layer) }
    }

    /// Sets the axis along which this layer gathers elements.
    ///
    /// NOTE(review): this mutates the underlying TensorRT layer through
    /// `&self` (interior mutability via the FFI pointer); `&mut self` would be
    /// more idiomatic but would change the public interface.
    pub fn set_gather_axis(&self, axis: i32) {
        unsafe { gather_layer_set_gather_axis(self.internal_layer, axis) }
    }
}

impl Drop for GatherLayer {
    fn drop(&mut self) {
        // Releases only the C-side wrapper struct; the TensorRT layer itself
        // remains owned by the network it was added to.
        unsafe { gather_layer_destroy(self.internal_layer) }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::Builder;
    use crate::dims::DimsHW;
    use crate::network::Network;
    use crate::runtime::Logger;
    use lazy_static::lazy_static;
    use std::sync::Mutex;

    lazy_static! {
        // Shared logger behind a mutex: TensorRT loggers are process-global,
        // so tests serialize access to a single instance.
        static ref LOGGER: Mutex<Logger> = Mutex::new(Logger::new());
    }

    // Builds a fresh empty network for each test.
    fn create_network(logger: &Logger) -> Network {
        let builder = Builder::new(logger);
        builder.create_network()
    }

    #[test]
    fn get_gather_axis() {
        // Recover the lock even if another test panicked while holding it.
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);

        let input1 = network.add_input("new_input1", DataType::Float, DimsHW::new(10, 10));
        let input2 = network.add_input("new_input2", DataType::Float, DimsHW::new(10, 10));
        let gather_layer = network.add_gather_layer(&input1, &input2, 1);

        // Axis passed at construction should be readable back.
        assert_eq!(gather_layer.get_gather_axis(), 1);
    }

    #[test]
    fn set_gather_axis() {
        // Recover the lock even if another test panicked while holding it.
        let logger = match LOGGER.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let network = create_network(&logger);

        let input1 = network.add_input("new_input1", DataType::Float, DimsHW::new(10, 10));
        let input2 = network.add_input("new_input2", DataType::Float, DimsHW::new(10, 10));
        let gather_layer = network.add_gather_layer(&input1, &input2, 1);

        // Updating the axis after construction should be observable.
        gather_layer.set_gather_axis(0);
        assert_eq!(gather_layer.get_gather_axis(), 0);
    }
}
11 changes: 2 additions & 9 deletions tensorrt/src/network/layer/identity_layer.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,7 @@
use super::private::LayerPrivate;
use super::*;
use tensorrt_rs_derive::Layer;

/// Safe wrapper around a TensorRT identity layer.
///
/// The derived `Layer` impl provides the common layer accessors through the
/// raw pointer stored in `internal_layer`.
#[derive(Layer)]
pub struct IdentityLayer {
    // Raw pointer to the C-side wrapper.
    pub(crate) internal_layer: *mut tensorrt_sys::Layer_t,
}

impl Layer for IdentityLayer {}

impl LayerPrivate for IdentityLayer {
fn get_internal_layer(&self) -> *mut tensorrt_sys::Layer_t {
self.internal_layer
}
}
14 changes: 4 additions & 10 deletions tensorrt/src/network/layer/mod.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
pub use element_wise_layer::{ElementWiseLayer, ElementWiseOperation};
pub use gather_layer::GatherLayer;
pub use identity_layer::IdentityLayer;

mod element_wise_layer;
mod gather_layer;
mod identity_layer;

use crate::engine::DataType;
use crate::network::Tensor;
use num_derive::FromPrimitive;
use num_traits::FromPrimitive;
use std::ffi::{CStr, CString};
use tensorrt_rs_derive::Layer;
use tensorrt_sys::{
layer_get_input, layer_get_name, layer_get_nb_inputs, layer_get_nb_outputs, layer_get_output,
layer_get_output_type, layer_get_precision, layer_get_type, layer_output_type_is_set,
Expand Down Expand Up @@ -133,20 +136,11 @@ mod private {
}
}

/// Generic wrapper for a network layer whose concrete type is not known
/// (e.g. layers returned by `Network::get_layer`).
#[derive(Layer)]
pub struct BaseLayer {
    // Raw pointer to the C-side wrapper.
    pub(crate) internal_layer: *mut tensorrt_sys::Layer_t,
}

impl BaseLayer {}

impl Layer for BaseLayer {}

impl private::LayerPrivate for BaseLayer {
fn get_internal_layer(&self) -> *mut tensorrt_sys::Layer_t {
self.internal_layer
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
36 changes: 32 additions & 4 deletions tensorrt/src/network/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ use crate::engine::DataType;
use layer::*;
use std::ffi::{CStr, CString};
use tensorrt_sys::{
destroy_network, network_add_element_wise, network_add_identity_layer, network_add_input,
network_get_input, network_get_layer, network_get_nb_inputs, network_get_nb_layers,
network_get_nb_outputs, network_get_output, network_mark_output, network_remove_tensor,
network_unmark_output, tensor_destroy, tensor_get_name, tensor_set_dimensions,
destroy_network, network_add_element_wise, network_add_gather, network_add_identity_layer,
network_add_input, network_get_input, network_get_layer, network_get_nb_inputs,
network_get_nb_layers, network_get_nb_outputs, network_get_output, network_mark_output,
network_remove_tensor, network_unmark_output, tensor_destroy, tensor_get_name,
tensor_set_dimensions,
};

pub struct Network {
Expand Down Expand Up @@ -94,6 +95,18 @@ impl Network {
};
ElementWiseLayer { internal_layer }
}

/// Adds a gather layer to the network.
///
/// `data` is the tensor to gather from, `indices` supplies the gather
/// positions, and `axis` selects the dimension of `data` that is indexed.
pub fn add_gather_layer(&self, data: &Tensor, indices: &Tensor, axis: i32) -> GatherLayer {
    let raw_layer = unsafe {
        network_add_gather(
            self.internal_network,
            data.internal_tensor,
            indices.internal_tensor,
            axis,
        )
    };
    GatherLayer {
        internal_layer: raw_layer,
    }
}
}

impl Drop for Network {
Expand Down Expand Up @@ -316,4 +329,19 @@ mod tests {
assert_eq!(network.get_nb_layers(), 1);
assert_eq!(network.get_layer(0).get_type(), LayerType::ElementWise);
}

#[test]
fn add_gather_layer() {
    // Recover the lock even if another test panicked while holding it.
    let logger = match LOGGER.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    };
    let network = create_network(&logger);
    let input1 = network.add_input("new_input1", DataType::Float, DimsCHW::new(1, 28, 28));
    let input2 = network.add_input("new_input2", DataType::Float, DimsCHW::new(1, 28, 28));
    network.add_gather_layer(&input1, &input2, 1);

    // The network should report exactly one layer, typed as Gather.
    assert_eq!(network.get_nb_layers(), 1);
    assert_eq!(network.get_layer(0).get_type(), LayerType::Gather);
}
}
12 changes: 12 additions & 0 deletions tensorrt_rs_derive/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[package]
name = "tensorrt_rs_derive"
version = "0.1.0"
authors = ["Mason Stallmo <[email protected]>"]
edition = "2018"

[lib]
proc-macro = true

[dependencies]
syn = "1.0.44"
quote = "1.0.7"
25 changes: 25 additions & 0 deletions tensorrt_rs_derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use proc_macro::TokenStream;
use quote::quote;
use syn;

/// Custom derive entry point for `#[derive(Layer)]`.
///
/// Parses the annotated item and emits `Layer`/`LayerPrivate` impls for it.
#[proc_macro_derive(Layer)]
pub fn layer_derive(input: TokenStream) -> TokenStream {
    // `syn::parse` only fails on a malformed token stream, which rustc has
    // already validated for derive input; state that invariant in the panic
    // message instead of using a bare unwrap.
    let ast = syn::parse(input).expect("#[derive(Layer)]: failed to parse derive input");

    impl_layer_derive(&ast)
}

/// Builds the generated impls for the type named in `ast`.
///
/// NOTE(review): the expansion assumes the call site has `private::LayerPrivate`,
/// the `Layer` trait, and the `tensorrt_sys` crate in scope, and that the type
/// has an `internal_layer: *mut tensorrt_sys::Layer_t` field — confirm every
/// deriving module satisfies this.
fn impl_layer_derive(ast: &syn::DeriveInput) -> TokenStream {
    let name = &ast.ident;
    let expanded = quote! {
        impl private::LayerPrivate for #name {
            fn get_internal_layer(&self) -> *mut tensorrt_sys::Layer_t {
                self.internal_layer
            }
        }

        impl Layer for #name {}
    };
    expanded.into()
}

0 comments on commit 58be907

Please sign in to comment.