Skip to content

Commit

Permalink
s32 popcount subgraph
Browse files Browse the repository at this point in the history
  • Loading branch information
parth8mcw committed Aug 13, 2024
1 parent 5012886 commit 63fb0de
Show file tree
Hide file tree
Showing 11 changed files with 344 additions and 11 deletions.
1 change: 1 addition & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ SUBGRAPH_SRCS = [
"src/subgraph/minimum2.c",
"src/subgraph/multiply2.c",
"src/subgraph/negate.c",
"src/subgraph/pop-count.c",
"src/subgraph/prelu.c",
"src/subgraph/reciprocal-square-root.c",
"src/subgraph/reshape-helpers.c",
Expand Down
6 changes: 6 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,7 @@ SET(SUBGRAPH_SRCS
src/subgraph/minimum2.c
src/subgraph/multiply2.c
src/subgraph/negate.c
src/subgraph/pop-count.c
src/subgraph/prelu.c
src/subgraph/reciprocal-square-root.c
src/subgraph/reshape-helpers.c
Expand Down Expand Up @@ -1992,6 +1993,11 @@ IF(XNNPACK_BUILD_TESTS)
TARGET_LINK_LIBRARIES(reciprocal-square-root-test PRIVATE XNNPACK fp16 GTest::gtest GTest::gtest_main subgraph)
ADD_TEST(NAME reciprocal-square-root-test COMMAND reciprocal-square-root-test)

ADD_EXECUTABLE(pop-count-test test/pop-count.cc)
TARGET_INCLUDE_DIRECTORIES(pop-count-test PRIVATE src test)
TARGET_LINK_LIBRARIES(pop-count-test PRIVATE XNNPACK fp16 GTest::gtest GTest::gtest_main subgraph)
ADD_TEST(NAME pop-count-test COMMAND pop-count-test)

ADD_EXECUTABLE(reshape-helpers-test test/reshape-helpers.cc)
TARGET_INCLUDE_DIRECTORIES(reshape-helpers-test PRIVATE src test)
TARGET_LINK_LIBRARIES(reshape-helpers-test PRIVATE XNNPACK fp16 GTest::gtest GTest::gtest_main subgraph)
Expand Down
13 changes: 13 additions & 0 deletions include/xnnpack.h
Original file line number Diff line number Diff line change
Expand Up @@ -1863,6 +1863,19 @@ enum xnn_status xnn_define_square_root(
uint32_t output_id,
uint32_t flags);

/// Define a Pop Count Node and add it to a Subgraph.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph and must
///                   have the int32 datatype.
/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, must
///                    have the int32 datatype, and its shape must match the shape of the input tensor.
/// @param flags - binary features of the Pop Count Node. No supported flags are currently defined.
enum xnn_status xnn_define_pop_count(
xnn_subgraph_t subgraph,
uint32_t input_id,
uint32_t output_id,
uint32_t flags);

/// Define a Reciprocal Square Root Node and add it to a Subgraph.
///
/// @param subgraph - a Subgraph object that will own the created Node.
Expand Down
21 changes: 15 additions & 6 deletions src/configs/unary-elementwise-config.c
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ XNN_INIT_ONCE_GUARD(qu8_to_f32_cvt);
XNN_INIT_ONCE_GUARD(s8_clamp);
XNN_INIT_ONCE_GUARD(u8_clamp);
XNN_INIT_ONCE_GUARD(xx_copy);
XNN_INIT_ONCE_GUARD(s32_popcout);
XNN_INIT_ONCE_GUARD(s32_popcnt);

static void init_f16_abs_config(void) {
#if XNN_ARCH_ARM && XNN_ENABLE_ARM_FP16_VECTOR && XNN_ENABLE_ARM_FP16_SCALAR
Expand Down Expand Up @@ -1466,22 +1466,22 @@ static void init_s32_popcnt_config(void) {
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
assert(hardware_config != NULL);
if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512f) {
s32_popcnt_config.ukernel = (xnn_vbinary_ukernel_fn) xnn_s32_vpopcnt_ukernel__avx512f_u32;
s32_popcnt_config.ukernel = (xnn_vunary_ukernel_fn) xnn_s32_vpopcnt_ukernel__avx512f_u32;
s32_popcnt_config.element_tile = 32;
} else if (hardware_config->use_x86_avx2) {
s32_popcnt_config.ukernel = (xnn_vbinary_ukernel_fn) xnn_s32_vpopcnt_ukernel__avx2_u32;
s32_popcnt_config.ukernel = (xnn_vunary_ukernel_fn) xnn_s32_vpopcnt_ukernel__avx2_u32;
s32_popcnt_config.element_tile = 32;
} else{
s32_popcnt_config.ukernel = (xnn_vbinary_ukernel_fn) xnn_s32_vpopcnt_ukernel__sse41_u8;
s32_popcnt_config.ukernel = (xnn_vunary_ukernel_fn) xnn_s32_vpopcnt_ukernel__sse41_u8;
s32_popcnt_config.element_tile = 8;
}
#elif XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
assert(hardware_config != NULL);
s32_popcnt_config.ukernel = (xnn_vbinary_ukernel_fn) xnn_s32_vpopcnt_ukernel__wasmsimd_u16;
s32_popcnt_config.ukernel = (xnn_vunary_ukernel_fn) xnn_s32_vpopcnt_ukernel__wasmsimd_u16;
s32_popcnt_config.element_tile = 16;
#else
s32_popcnt_config.ukernel = (xnn_vbinary_ukernel_fn) xnn_s32_vpopcnt_ukernel__scalar_u2;
s32_popcnt_config.ukernel = (xnn_vunary_ukernel_fn) xnn_s32_vpopcnt_ukernel__scalar_u2;
s32_popcnt_config.element_tile = 2;
#endif
}
Expand Down Expand Up @@ -2760,6 +2760,15 @@ const struct xnn_unary_elementwise_config* xnn_init_f32_sqrt_config() {
return &f32_sqrt_config;
}

// Returns the unary-elementwise config for the s32 pop count microkernel, or
// NULL when the hardware config could not be initialized.
const struct xnn_unary_elementwise_config* xnn_init_s32_popcnt_config() {
  // The hardware config is only probed for availability here; the kernel
  // selection itself happens in the guarded initializer below.
  if (xnn_init_hardware_config() == NULL) {
    return NULL;
  }
  // NOTE(review): XNN_INIT_ONCE presumably runs init_s32_popcnt_config() a
  // single time via the s32_popcnt guard - confirm against the macro definition.
  XNN_INIT_ONCE(s32_popcnt);
  return &s32_popcnt_config;
}

const struct xnn_unary_elementwise_config* xnn_init_f32_tanh_config() {
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
if (hardware_config == NULL) {
Expand Down
7 changes: 4 additions & 3 deletions src/enums/node-type.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
#include "xnnpack/node-type.h"

#if XNN_LOG_LEVEL > 0
static const uint16_t offset[62] = {
static const uint16_t offset[63] = {
0, 8, 12, 17, 35, 54, 71, 93, 101, 107, 120, 133, 146, 159, 167, 182, 187, 197, 214, 232, 257, 264, 268, 272, 284,
296, 308, 314, 330, 353, 358, 384, 410, 432, 454, 464, 468, 479, 494, 503, 512, 522, 529, 535, 558, 569, 574, 603,
611, 619, 637, 644, 656, 675, 695, 707, 722, 748, 761, 778, 787, 792
296, 308, 314, 330, 353, 358, 384, 410, 432, 454, 464, 468, 479, 494, 503, 512, 522, 529, 535, 545, 568, 579, 584,
613, 621, 629, 647, 654, 666, 685, 705, 717, 732, 758, 771, 788, 797, 802
};

static const char data[] =
Expand Down Expand Up @@ -63,6 +63,7 @@ static const char data[] =
"Multiply2\0"
"Negate\0"
"PReLU\0"
"Pop Count\0"
"Reciprocal Square Root\0"
"Reshape 2D\0"
"RoPE\0"
Expand Down
2 changes: 2 additions & 0 deletions src/enums/node-type.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@
string: "Negate"
- name: xnn_node_type_prelu
string: "PReLU"
- name: xnn_node_type_pop_count
string: "Pop Count"
- name: xnn_node_type_reciprocal_square_root
string: "Reciprocal Square Root"
- name: xnn_node_type_reshape_2d
Expand Down
4 changes: 2 additions & 2 deletions src/operators/unary-elementwise-nc.c
Original file line number Diff line number Diff line change
Expand Up @@ -3140,8 +3140,8 @@ enum xnn_status xnn_setup_square_root_nc_f32(

enum xnn_status xnn_setup_pop_count_nc_s32(
xnn_operator_t popcnt_op,
const float* input,
float* output)
const int32_t* input,
int32_t* output)
{
return setup_unary_elementwise_nc(
popcnt_op, xnn_operator_type_pop_count_nc_s32,
Expand Down
187 changes: 187 additions & 0 deletions src/subgraph/pop-count.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <inttypes.h>
#include <stddef.h>
#include <stdint.h>

#include "xnnpack.h"
#include "xnnpack/common.h"
#include "xnnpack/log.h"
#include "xnnpack/node-type.h"
#include "xnnpack/operator-type.h"
#include "xnnpack/operator.h"
#include "xnnpack/reshape-helpers.h"
#include "xnnpack/subgraph-validation.h"
#include "xnnpack/subgraph.h"
#include "pthreadpool.h"

// Creates the operator object backing a Pop Count node.
//
// The node must have exactly one input and one output. Only the s32 compute
// type is handled; the created operator is stored in opdata->operator_objects[0].
// The values, num_values, code_cache and weights_cache arguments are not used.
static enum xnn_status create_pop_count_operator(
  const struct xnn_node* node,
  const struct xnn_value* values,
  size_t num_values,
  struct xnn_operator_data* opdata,
  struct xnn_code_cache* code_cache,
  xnn_weights_cache_t weights_cache)
{
  assert(node->num_inputs == 1);
  assert(node->num_outputs == 1);

  switch (node->compute_type) {
    case xnn_compute_type_s32:
      return xnn_create_pop_count_nc_s32(
        node->flags,
        &opdata->operator_objects[0]);
    default:
      // xnn_define_pop_count only ever assigns the s32 compute type.
      XNN_UNREACHABLE;
  }
}

// Reshapes the Pop Count operator to the current shape of its input value and
// then resizes the output tensor through the shared unary-elementwise helper.
//
// The innermost dimension is passed as the channel count and as both strides
// (a zero-dimensional input is treated as a single channel); the remaining
// dimensions are folded into the batch size.
static enum xnn_status reshape_pop_count_operator(
  struct xnn_operator_data* opdata,
  struct xnn_value* values,
  size_t num_values,
  pthreadpool_t threadpool)
{
  const uint32_t input_id = opdata->inputs[0];
  assert(input_id < num_values);

  struct xnn_value* input_value = &values[input_id];
  const size_t batch_size = xnn_shape_multiply_non_channel_dims(&input_value->shape);
  const size_t rank = input_value->shape.num_dims;
  size_t channels = 1;
  if (rank != 0) {
    channels = input_value->shape.dim[rank - 1];
  }

  // Captured before the reshape call and handed to the output-resize helper.
  const size_t old_workspace_size = opdata->workspace_size;

  enum xnn_status status = xnn_status_invalid_state;
  switch (opdata->operator_objects[0]->type) {
    case xnn_operator_type_pop_count_nc_s32:
      status = xnn_reshape_pop_count_nc_s32(
        opdata->operator_objects[0],
        batch_size,
        /*channels=*/channels, /*input_stride=*/channels, /*output_stride=*/channels,
        threadpool);
      break;
    default:
      XNN_UNREACHABLE;
  }

  if (status == xnn_status_success) {
    status = resize_unary_elementwise_output_tensor(opdata, values, num_values, old_workspace_size, threadpool);
  }
  return status;
}

// Binds the input and output data pointers of the node's values to the
// Pop Count operator. Both value IDs must be valid and both values must
// already have non-NULL data. The threadpool argument is not used.
static enum xnn_status setup_pop_count_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_value* values,
  size_t num_values,
  pthreadpool_t threadpool)
{
  const uint32_t input_id = opdata->inputs[0];
  const uint32_t output_id = opdata->outputs[0];

  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_values);
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_values);

  const void* input_data = values[input_id].data;
  assert(input_data != NULL);

  void* output_data = values[output_id].data;
  assert(output_data != NULL);

  switch (opdata->operator_objects[0]->type) {
    case xnn_operator_type_pop_count_nc_s32:
      return xnn_setup_pop_count_nc_s32(opdata->operator_objects[0], input_data, output_data);
    default:
      XNN_UNREACHABLE;
  }
}

// Defines a Pop Count Node and appends it to the subgraph.
//
// Validation performed, in order:
//   1. XNNPACK is initialized.
//   2. input_id refers to an existing Value that passes the dense-input check
//      and has the int32 datatype.
//   3. output_id passes the output-ID check, the Value passes the dense-output
//      check and has the int32 datatype.
//
// Returns xnn_status_success once the node is added, xnn_status_invalid_parameter
// for a bad ID or datatype, xnn_status_out_of_memory if the node cannot be
// allocated, or whatever status a validation helper reports.
enum xnn_status xnn_define_pop_count(
  xnn_subgraph_t subgraph,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags)
{
  enum xnn_status status;
  // Pass this node's own type so a failure is attributed to Pop Count
  // (previously this passed xnn_node_type_square_root by mistake).
  if ((status = xnn_subgraph_check_xnnpack_initialized(xnn_node_type_pop_count)) != xnn_status_success) {
    return status;
  }

  if (input_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_pop_count), input_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* input_value = &subgraph->values[input_id];
  status = xnn_subgraph_check_input_type_dense(xnn_node_type_pop_count, input_id, input_value);
  if (status != xnn_status_success) {
    return status;
  }

  // Only 32-bit integer inputs are supported.
  switch (input_value->datatype) {
    case xnn_datatype_int32:
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_pop_count), input_id,
        xnn_datatype_to_string(input_value->datatype), input_value->datatype);
      return xnn_status_invalid_parameter;
  }

  status = xnn_subgraph_check_output_node_id(xnn_node_type_pop_count, output_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  status = xnn_subgraph_check_output_type_dense(xnn_node_type_pop_count, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

  // The output datatype selects the compute type; only int32 -> s32 exists.
  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  switch (output_value->datatype) {
    case xnn_datatype_int32:
      compute_type = xnn_compute_type_s32;
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_pop_count), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->type = xnn_node_type_pop_count;
  node->compute_type = compute_type;
  node->num_inputs = 1;
  node->inputs[0] = input_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  // Callbacks that create, reshape and set up the backing operator.
  node->create = create_pop_count_operator;
  node->reshape = reshape_pop_count_operator;
  node->setup = setup_pop_count_operator;

  return xnn_status_success;
}
1 change: 1 addition & 0 deletions src/xnnpack/node-type.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ enum xnn_node_type {
xnn_node_type_multiply2,
xnn_node_type_negate,
xnn_node_type_prelu,
xnn_node_type_pop_count,
xnn_node_type_reciprocal_square_root,
xnn_node_type_reshape_2d,
xnn_node_type_rope,
Expand Down
1 change: 1 addition & 0 deletions test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -2029,6 +2029,7 @@ xnnpack_cc_library(
"static_slice",
"static_transpose",
"tanh",
"pop_count",
]]

xnnpack_cc_library(
Expand Down
Loading

0 comments on commit 63fb0de

Please sign in to comment.