Skip to content

Commit

Permalink
Wired QP8_QB4W Subgraph APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
mcr229 committed Sep 4, 2024
1 parent eefcad9 commit 84c8708
Show file tree
Hide file tree
Showing 2 changed files with 236 additions and 16 deletions.
72 changes: 56 additions & 16 deletions src/subgraph/fully-connected.c
Original file line number Diff line number Diff line change
Expand Up @@ -199,22 +199,40 @@ static enum xnn_status create_fully_connected_operator(
}
break;
case xnn_datatype_qbint4:
status = xnn_create_fully_connected_nc_qd8_f32_qb4w(
input_channels,
output_channels,
/*input_stride=*/input_channels,
/*output_stride=*/output_channels,
/*block_size=*/values[filter_id].quantization.block_size,
/*kernel_zero_point=*/values[filter_id].quantization.zero_point,
(const uint16_t*) values[filter_id].quantization.blockwise_scale,
kernel_data,
bias_data,
node->activation.output_min,
node->activation.output_max,
node->flags,
code_cache,
weights_cache,
&opdata->operator_objects[0]);
switch (input_datatype) {
case xnn_datatype_qdint8:
status = xnn_create_fully_connected_nc_qd8_f32_qb4w(
input_channels,
output_channels,
/*input_stride=*/input_channels,
/*output_stride=*/output_channels,
/*block_size=*/values[filter_id].quantization.block_size,
/*kernel_zero_point=*/values[filter_id].quantization.zero_point,
(const uint16_t*) values[filter_id].quantization.blockwise_scale,
kernel_data,
bias_data,
node->activation.output_min,
node->activation.output_max,
node->flags,
code_cache,
weights_cache,
&opdata->operator_objects[0]);
break;
case xnn_datatype_qpint8:
status = xnn_create_fully_connected_nc_qp8_f32_qb4w(
input_channels, output_channels,
/*input_stride=*/input_channels,
/*output_stride=*/output_channels,
/*block_size=*/values[filter_id].quantization.block_size,
/*kernel_zero_point=*/values[filter_id].quantization.zero_point,
values[filter_id].quantization.blockwise_scale, kernel_data,
bias_data, node->activation.output_min,
node->activation.output_max, node->flags, code_cache,
weights_cache, &opdata->operator_objects[0]);
break;
default:
XNN_UNREACHABLE;
}
break;
case xnn_datatype_qcint4:
switch (input_datatype) {
Expand Down Expand Up @@ -555,6 +573,12 @@ static enum xnn_status reshape_fully_connected_operator(
batch_size,
threadpool);
break;
case xnn_operator_type_fully_connected_nc_qp8_f32_qb4w:
status = xnn_reshape_fully_connected_nc_qp8_f32_qb4w(
opdata->operator_objects[0],
batch_size,
threadpool);
break;
case xnn_operator_type_fully_connected_nc_qs8:
status = xnn_reshape_fully_connected_nc_qs8(
opdata->operator_objects[0],
Expand Down Expand Up @@ -747,6 +771,15 @@ static enum xnn_status setup_fully_connected_operator(
input_data,
output_data);
}
case xnn_operator_type_fully_connected_nc_qp8_f32_qb4w:
{
assert(kernel_data == NULL);
assert(bias_data == NULL);
return xnn_setup_fully_connected_nc_qp8_f32_qb4w(
opdata->operator_objects[0],
input_data,
output_data);
}
case xnn_operator_type_fully_connected_nc_qs8:
assert(kernel_data == NULL);
assert(bias_data == NULL);
Expand Down Expand Up @@ -833,6 +866,11 @@ static inline enum xnn_compute_type validate_datatypes_with_bias(
output_datatype == xnn_datatype_fp16)
{
return xnn_compute_type_qd8_to_fp16;
} else if (input_datatype == xnn_datatype_qpint8 &&
bias_datatype == xnn_datatype_fp32 &&
output_datatype == xnn_datatype_fp32)
{
return xnn_compute_type_qp8_to_fp32;
}
break;
case xnn_datatype_qcint8:
Expand Down Expand Up @@ -919,6 +957,8 @@ static inline enum xnn_compute_type validate_datatypes_without_bias(
return xnn_compute_type_qd8_to_fp32;
} else if (input_datatype == xnn_datatype_qdint8 && output_datatype == xnn_datatype_fp16) {
return xnn_compute_type_qd8_to_fp16;
} else if (input_datatype == xnn_datatype_qpint8 && output_datatype == xnn_datatype_fp32) {
return xnn_compute_type_qp8_to_fp32;
}
break;
case xnn_datatype_qcint8:
Expand Down
180 changes: 180 additions & 0 deletions test/fully-connected.cc
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,9 @@ class FullyConnectedTestF32QC4W : public FullyConnectedTestBase<float, uint8_t,

class FullyConnectedTestF32QC8W : public FullyConnectedTestBase<float, int8_t, float> {
};
// Fixture for fully-connected subgraph tests pairing a QP8 (packed,
// dynamically quantized int8) input with a QB4W (blockwise-quantized 4-bit)
// kernel producing fp32 output.
// NOTE(review): template arguments presumably map to
// <InputType=int8_t, KernelType=uint8_t, BiasType=float, OutputType=float,
// /*qb4w or blockwise flag*/=true> — FullyConnectedTestBase is declared
// outside this view, so confirm the parameter meanings there.
class FullyConnectedTestQP8F32QB4W
: public FullyConnectedTestBase<int8_t, uint8_t, float, float, true> {};


using FullyConnectedTestQC8 = QuantizedFullyConnectedTestBase<int8_t>;
using FullyConnectedTestQS8 = QuantizedFullyConnectedTestBase<int8_t>;
Expand Down Expand Up @@ -4228,3 +4231,180 @@ TEST_F(FullyConnectedTestF32, reshape)
size_t num_output_elements = std::accumulate(new_input_dims.begin(), new_input_dims.end() - 1, size_t{1}, std::multiplies<size_t>()) * kernel_shape->dim[0];
ASSERT_EQ(runtime->values[node->outputs[0]].size, num_output_elements * sizeof(float));
}

// Verify that defining a fully-connected node over a dynamically quantized
// qpint8 input and a blockwise-quantized int4 (qbint4) kernel yields a node
// with compute type qp8 -> fp32, wired to the expected value IDs.
TEST_F(FullyConnectedTestQP8F32QB4W, define)
{
  const size_t block_size = 32;
  // Blockwise quantization requires the reduction dimension to be a whole
  // number of blocks.
  input_channels = round_up_po2(input_channels, block_size);

  input_dims[input_dims.size() - 1] = input_channels;
  kernel_dims[kernel_dims.size() - 1] = input_channels;

  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  uint32_t input_id = XNN_INVALID_VALUE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_dynamically_quantized_tensor_value(
                          subgraph, xnn_datatype_qpint8, input_dims.size(), /*num_nonbatch_dims=*/1, input_dims.data(),
                          /*external_id=*/0, /*flags=*/0, &input_id));
  ASSERT_NE(input_id, XNN_INVALID_VALUE_ID);

  // Adjust number of kernel elements for QB4W. input_channels should be padded to byte boundary, hence even.
  const size_t rounded_input_channels = round_up_po2(input_channels, 2);
  kernel = std::vector<uint8_t>(output_channels * rounded_input_channels);
  const uint8_t kernel_zero_point = 8;
  // One bf16 scale per (output channel, block), i.e. input_channels/block_size
  // blocks per output channel.  The previous sizing of
  // output_channels * block_size is only coincidentally correct when
  // input_channels == block_size * block_size and reads out of bounds for
  // larger input_channels.
  std::vector<uint16_t> kernel_scale(output_channels * (input_channels / block_size));
  std::generate(kernel_scale.begin(), kernel_scale.end(), [&]() { return math_cvt_bf16_fp32(scale_dist(rng)); });
  uint32_t kernel_id = XNN_INVALID_VALUE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_blockwise_quantized_tensor_value(
                          subgraph, xnn_datatype_qbint4, kernel_zero_point, kernel_scale.data(), kernel_dims.size(),
                          /*channel_dim=*/0, block_size, kernel_dims.data(), kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id));
  ASSERT_NE(kernel_id, XNN_INVALID_VALUE_ID);

  uint32_t bias_id = XNN_INVALID_VALUE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(),
                          /*external_id=*/2, /*flags=*/0, &bias_id));

  uint32_t output_id = XNN_INVALID_VALUE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr,
                          /*external_id=*/3, /*flags=*/0, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_VALUE_ID);

  ASSERT_EQ(
    xnn_status_success,
    xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0));

  // The node must be recorded exactly as defined, with the qp8->fp32 compute
  // type selected from the input/kernel/bias datatypes.
  ASSERT_EQ(subgraph->num_nodes, 1);
  const struct xnn_node* node = &subgraph->nodes[0];
  ASSERT_EQ(node->type, xnn_node_type_fully_connected);
  ASSERT_EQ(node->compute_type, xnn_compute_type_qp8_to_fp32);
  ASSERT_EQ(node->activation.output_min, output_min);
  ASSERT_EQ(node->activation.output_max, output_max);
  ASSERT_EQ(node->num_inputs, 3);
  ASSERT_EQ(node->inputs[0], input_id);
  ASSERT_EQ(node->inputs[1], kernel_id);
  ASSERT_EQ(node->inputs[2], bias_id);
  ASSERT_EQ(node->num_outputs, 1);
  ASSERT_EQ(node->outputs[0], output_id);
  ASSERT_EQ(node->flags, 0);
}

// End-to-end exercise of the QP8+QB4W path: run the operator API (explicit
// f32->qp8 convert followed by the qp8_f32_qb4w fully-connected), then build
// the equivalent subgraph whose dynamic-quantization parameters are allocated
// internally by the runtime.  Skips when the hardware lacks qp8 support.
TEST_F(FullyConnectedTestQP8F32QB4W, internally_allocated_dynamic_quantization_parameters)
{
  const size_t block_size = 32;
  // Blockwise quantization requires a whole number of blocks per channel.
  input_channels = round_up_po2(input_channels, block_size);

  input_dims[input_dims.size() - 1] = input_channels;
  kernel_dims[kernel_dims.size() - 1] = input_channels;

  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
  uint32_t input_id = XNN_INVALID_NODE_ID;
  std::vector<float> convert_input(batch_size * input_channels + XNN_EXTRA_BYTES / sizeof(float));
  std::vector<int8_t> operator_dq_data(batch_size * input_channels + XNN_EXTRA_BYTES);
  std::vector<float> subgraph_output(batch_size * output_channels);
  std::vector<float> operator_output(batch_size * output_channels);
  // Pre-fill outputs with NaN so untouched elements are detectable.
  std::fill(operator_output.begin(), operator_output.end(), nanf(""));
  std::fill(subgraph_output.begin(), subgraph_output.end(), nanf(""));
  std::vector<xnn_dynamic_quantization_params> quantization_params(batch_size + XNN_EXTRA_QUANTIZATION_PARAMS);

  // Size the int4 kernel *before* filling it.  The previous code generated
  // random kernel data and then reassigned `kernel` to a fresh zero-filled
  // vector, silently running the whole test on an all-zero kernel.
  // input_channels is padded to a byte boundary because two int4 values pack
  // into one byte.
  const size_t rounded_input_channels = round_up_po2(input_channels, 2);
  kernel = std::vector<uint8_t>(output_channels * rounded_input_channels);

  // One bf16 scale per (output channel, block); the previous
  // output_channels * block_size sizing does not match the operator's
  // expected scale count in general.
  std::vector<uint16_t> kernel_scale(output_channels * (input_channels / block_size));
  std::generate(kernel_scale.begin(), kernel_scale.end(), [&]() { return math_cvt_bf16_fp32(scale_dist(rng)); });
  std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); });
  std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
  std::generate(convert_input.begin(), convert_input.end(), [&]() { return f32dist(rng); });
  std::generate(quantization_params.begin(), quantization_params.end(), [&]() { return xnn_dynamic_quantization_params{w8dist(rng), f32dist(rng)}; });

  const float output_min = -std::numeric_limits<float>::infinity();
  const float output_max = std::numeric_limits<float>::infinity();

  const uint8_t kernel_zero_point = 8;

  // Call operator API: explicit f32 -> qp8 conversion first.
  xnn_operator_t convert_op = nullptr;
  xnn_operator_t fc_op = nullptr;
  xnn_status status = xnn_create_convert_nc_f32_qp8(
    /*flags=*/0, &convert_op);
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convert_op(convert_op, xnn_delete_operator);
  if (status == xnn_status_unsupported_hardware) {
    GTEST_SKIP();
  }
  ASSERT_EQ(xnn_status_success, status);
  ASSERT_NE(nullptr, convert_op);
  ASSERT_EQ(xnn_status_success, xnn_reshape_convert_nc_f32_qp8(convert_op, batch_size, input_channels, input_channels, /*threadpool=*/nullptr));
  ASSERT_EQ(xnn_status_success, xnn_setup_convert_nc_f32_qp8(convert_op, convert_input.data(),
                                                             operator_dq_data.data()));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(convert_op, /*threadpool=*/nullptr));

  // Then the qp8 x qb4w fully-connected on the packed data.
  status = xnn_create_fully_connected_nc_qp8_f32_qb4w(
    input_channels, output_channels, input_channels, output_channels, block_size, kernel_zero_point, kernel_scale.data(),
    kernel.data(), bias.data(), output_min, output_max,
    /*flags=*/0, nullptr, nullptr, &fc_op);
  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fc_op(fc_op, xnn_delete_operator);

  if (status == xnn_status_unsupported_hardware) {
    GTEST_SKIP();
  }

  ASSERT_EQ(xnn_status_success, status);
  ASSERT_NE(nullptr, fc_op);
  ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_qp8_f32_qb4w(fc_op, batch_size, /*threadpool=*/nullptr));
  ASSERT_EQ(xnn_status_success,
            xnn_setup_fully_connected_nc_qp8_f32_qb4w(fc_op, operator_dq_data.data(), operator_output.data()));
  ASSERT_EQ(xnn_status_success, xnn_run_operator(fc_op, /*threadpool=*/nullptr));

  // Call subgraph API: convert + fully-connected, letting the runtime
  // allocate the dynamic-quantization parameters internally.
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, /*external_id=*/0,
                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

  uint32_t dq_quantized_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_dynamically_quantized_tensor_value(
                          subgraph, xnn_datatype_qdint8, input_dims.size(), /*num_nonbatch_dims=*/1, input_dims.data(),
                          XNN_INVALID_VALUE_ID, /*flags=*/0, &dq_quantized_id));
  ASSERT_NE(dq_quantized_id, XNN_INVALID_NODE_ID);
  uint32_t kernel_id = XNN_INVALID_VALUE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_blockwise_quantized_tensor_value(
                          subgraph, xnn_datatype_qbint4, kernel_zero_point, kernel_scale.data(), kernel_dims.size(),
                          /*channel_dim=*/0, block_size, kernel_dims.data(), kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id));

  uint32_t bias_id = XNN_INVALID_VALUE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(),
                          /*external_id=*/2, /*flags=*/0, &bias_id));
  uint32_t output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(
    xnn_status_success, xnn_define_tensor_value(
                          subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr,
                          /*external_id=*/3, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  xnn_runtime_t runtime = nullptr;
  // XNN_FLAG_MAYBE_PACK_FOR_GEMM lets the runtime choose qp8 packing for the
  // dynamically quantized value when the GEMM supports it.
  ASSERT_EQ(xnn_status_success, xnn_define_convert(subgraph, input_id, dq_quantized_id, /*flags=*/XNN_FLAG_MAYBE_PACK_FOR_GEMM));
  ASSERT_EQ(xnn_status_success, xnn_define_fully_connected(subgraph, output_min, output_max, dq_quantized_id,
                                                           kernel_id, bias_id, output_id, /*flags=*/0));
  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
  ASSERT_NE(nullptr, runtime);
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
  std::array<xnn_external_value, 2> external = {
    xnn_external_value{input_id, convert_input.data()}, xnn_external_value{output_id, subgraph_output.data()}};
  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));
  // TODO(review): subgraph_output is never compared against operator_output,
  // so the test only proves both paths run without error.  Confirm whether
  // the two results are expected to match bit-exactly (the subgraph may fall
  // back to a qd8 path despite XNN_FLAG_MAYBE_PACK_FOR_GEMM) and add the
  // element-wise comparison if so.
}

0 comments on commit 84c8708

Please sign in to comment.