Skip to content

Commit 49bc3e9

Browse files
authored
Merge pull request #49 from JDAI-CV/support_arbitrarily_channels
Support arbitrary channels
2 parents c08695a + ce6e843 commit 49bc3e9

37 files changed

+467
-8715
lines changed

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,6 @@
1313
[submodule "third_party/protobuf"]
1414
path = third_party/protobuf
1515
url = https://github.com/protocolbuffers/protobuf
16+
[submodule "third_party/flatbuffers"]
17+
path = third_party/flatbuffers
18+
url = https://github.com/google/flatbuffers

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ include(cmake/system.cmake)
4242
include(cmake/glog.cmake)
4343
configure_glog()
4444

45+
include(cmake/flatbuffers.cmake)
46+
configure_flatbuffers()
47+
4548
add_compile_options("-DEIGEN_MPL2_ONLY")
4649
if (${BNN_NET_BENCHMARK})
4750
add_compile_options("-DBNN_BENCHMARK")

benchmark/benchmark.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,17 @@
1414
#include <dabnn/net.h>
1515

1616
// Benchmark the 64-bit packing routine on a small tensor: packs a
// 1x32x32x128 float Mat into a binary (bit) Mat on every iteration.
static void BM_pack_mat_64_small(benchmark::State &state) {
    const bnn::Mat a(1, 32, 32, 128, bnn::DataType::Float, false);
    bnn::Mat b(1, 32, 32, 128, bnn::DataType::Bit, false);
    for (auto _ : state) {
        pack_mat_64(a, b);
    }
}
2323

2424
#ifdef __aarch64__
2525
static void BM_pack_mat_128_small(benchmark::State &state) {
26-
const bnn::Mat a(1, 32, 32, 128, bnn::DataType::Float, 0);
27-
bnn::Mat b(1, 32, 32, 128, bnn::DataType::Bit, 0);
26+
const bnn::Mat a(1, 32, 32, 128, bnn::DataType::Float, false);
27+
bnn::Mat b(1, 32, 32, 128, bnn::DataType::Bit, false);
2828
for (auto _ : state) {
2929
pack_mat_128(a, b);
3030
}

cmake/flatbuffers.cmake

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Configure the bundled FlatBuffers submodule and add it to the build.
# Only the flatbuffers library itself is built; tests/samples, flathash,
# and the flatc compiler are all switched off.
function(configure_flatbuffers)
    option(FLATBUFFERS_BUILD_TESTS "Enable the build of tests and samples." OFF)
    option(FLATBUFFERS_BUILD_FLATHASH "Enable the build of flathash" OFF)
    option(FLATBUFFERS_BUILD_FLATC "Enable the build of the flatbuffers compiler" OFF)
    option(FLATBUFFERS_BUILD_FLATLIB "Enable the build of the flatbuffers library" ON)
    add_subdirectory(third_party/flatbuffers)
endfunction()

common/baseline.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ inline void baseline_bconv(const Mat &input, const Mat &weight,
8080
const int stride_w, const int dilation_h,
8181
const int dilation_w, const int output_channels,
8282
Mat &output) {
83+
BNN_ASSERT(weight.total() % weight.n == 0, "");
84+
const auto HWC = weight.total() / weight.n;
8385
int input_y = 0;
8486
FORZ(th, output.h) {
8587
int input_x = 0;
@@ -91,7 +93,7 @@ inline void baseline_bconv(const Mat &input, const Mat &weight,
9193
FORZ(ww, kernel_w) {
9294
int x = input_x - pad_w + ww * dilation_w;
9395
FORZ(wc, input.c) {
94-
int idx = tc * kernel_h * kernel_w * input.c +
96+
int idx = tc * HWC +
9597
wh * kernel_w * input.c + ww * input.c +
9698
wc;
9799
const auto w_value =

common/common_bitpack.h

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,24 @@
99

1010
#include <common/helper.h>
1111

12-
inline void pack_64_bitset(const float *fptr, uint64_t *buf) {
12+
inline void pack_64_bitset(const float *fptr, uint64_t *buf,
13+
const size_t eff_bits = 64) {
14+
/**
 * The eff_bits parameter supports channel counts that are not a
 * multiple of 128. In that case, we need to pad the tensor so that
 * the channels are aligned to 128.
 */
19+
// BNN_ASSERT(eff_bits == 64, eff_bits);
1320
const size_t UNIT_LEN = 64;
21+
BNN_ASSERT(eff_bits <= UNIT_LEN, "The eff_bits ", eff_bits,
22+
" must be smaller than UNIT_LEN ", UNIT_LEN);
1423
std::bitset<UNIT_LEN> bits;
1524
for (size_t i = 0; i < UNIT_LEN; i++) {
16-
bits[i] = (*(fptr + i) > 0);
25+
if (i < eff_bits) {
26+
bits[i] = (*(fptr + i) > 0);
27+
} else {
28+
bits[i] = 0;
29+
}
1730
}
1831
static_assert(sizeof(decltype(bits.to_ullong())) * CHAR_BIT == 64,
1932
"bits.to_ullong() must return a 64-bit element");

common/dab.fbs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ table Tensor {
1010
float32_data: [float32];
1111
shape: [uint32];
1212
name: string;
13+
align_hwc_to_128: bool;
1314
}
1415

1516
table Input {

0 commit comments

Comments
 (0)