
Commit c55722b

Enhancing Performance of INT4 Data Transformation
This pull request speeds up the transformation of int4 data in the osv32_isv2 layout to the block_q4_0x4 layout by using ARM NEON and OpenMP. The previous version took about 7 to 8 milliseconds to transform a 3072x8192 matrix, while this patch takes only 2 to 4 milliseconds.

**Self-evaluation:**
1. Build test: [X] Passed [ ] Failed [ ] Skipped
2. Run test: [X] Passed [ ] Failed [ ] Skipped

Signed-off-by: Donghyeon Jeong <[email protected]>
1 parent b332f3f commit c55722b
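
For context, the quoted timings could be checked with a small stand-alone harness along the lines of the sketch below. It is illustrative only and not part of the commit: the forward declaration, the zero-filled inputs, the scale group size of 64, and the buffer-size arithmetic (two int4 values per byte of osv32_isv2 weights, one fp16 scale per row per scale group, one block_q4_0x4 per 4 rows x 32 columns) are assumptions inferred from the diff.

// Hypothetical micro-benchmark sketch (not part of this commit).
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

#include "nntr_ggml_impl_common.h" // assumed to provide block_q4_0x4, as in neon_impl.cpp

namespace nntrainer {
// Forward declaration copied from arm_compute_backend.cpp, used here to avoid
// guessing the name of the public header that declares it.
void transform_q4_0x_from_int4(size_t N, size_t K, const uint8_t *osv32_weights,
                               const uint16_t *osv32_scales,
                               size_t scale_group_size, void *dst_q4_0x);
} // namespace nntrainer

int main() {
  const size_t N = 3072, K = 8192;
  const size_t scale_group_size = 64; // any of 32/64/128; 64 is an arbitrary choice
  std::vector<uint8_t> weights(N * K / 2, 0);                  // two int4 values per byte
  std::vector<uint16_t> scales(N * (K / scale_group_size), 0); // fp16 scales stored as uint16_t
  std::vector<block_q4_0x4> dst((N / 4) * (K / 32));           // one block per 4 rows x 32 columns

  const auto t0 = std::chrono::steady_clock::now();
  nntrainer::transform_q4_0x_from_int4(N, K, weights.data(), scales.data(),
                                       scale_group_size, dst.data());
  const auto t1 = std::chrono::steady_clock::now();
  std::printf("transform took %.3f ms\n",
              std::chrono::duration<double, std::milli>(t1 - t0).count());
  return 0;
}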

File tree

3 files changed: +186 −0 lines changed

nntrainer/tensor/cpu_backend/arm/arm_compute_backend.cpp

Lines changed: 5 additions & 0 deletions
@@ -528,8 +528,13 @@ void create_q4_0_weights(const uint8_t *int4_weight, uint8_t *q4_0_weight) {
 void transform_q4_0x_from_int4(size_t N, size_t K, const uint8_t *osv32_weights,
                                const uint16_t *osv32_scales,
                                size_t scale_group_size, void *dst_q4_0x) {
+#if defined(__aarch64__) || defined(_M_ARM64)
+  neon::transform_int4_osv32_isv2_to_q4_0x4(N, K, osv32_weights, osv32_scales,
+                                            scale_group_size, dst_q4_0x);
+#else
   Q4_0Utils::transformQ4_0x_FromInt4(N, K, osv32_weights, osv32_scales,
                                      scale_group_size, 4, dst_q4_0x);
+#endif
 }

 } /* namespace nntrainer */

nntrainer/tensor/cpu_backend/arm/neon_impl.cpp

Lines changed: 165 additions & 0 deletions
@@ -27,6 +27,7 @@
 #ifdef ARMV7
 #include <armv7_neon.h>
 #endif
+#include "nntr_ggml_impl_common.h"
 #include <fallback_internal.h>
 #include <util_func.h>

@@ -1340,6 +1341,170 @@ void clamp(const float *input, float *output, size_t length, float lower_bound,
   }
 }

+/**
+ * @brief Highly optimized version - processes 4 rows simultaneously
+ * and writes directly to the output block, eliminating intermediate copies.
+ * Uses 128-bit NEON operations and prefetching for maximum throughput.
+ */
+inline static void neon_transform_4rows_to_q4_0x4(
+  const uint8_t *__restrict row0_ptr, const uint8_t *__restrict row1_ptr,
+  const uint8_t *__restrict row2_ptr, const uint8_t *__restrict row3_ptr,
+  uint16_t scale0, uint16_t scale1, uint16_t scale2, uint16_t scale3,
+  block_q4_0x4 *__restrict out) {
+
+  // Prefetch next cache lines
+  __builtin_prefetch(row0_ptr + 64, 0, 3);
+  __builtin_prefetch(row1_ptr + 64, 0, 3);
+  __builtin_prefetch(row2_ptr + 64, 0, 3);
+  __builtin_prefetch(row3_ptr + 64, 0, 3);
+
+  // Store scales directly
+  out->d[0] = scale0;
+  out->d[1] = scale1;
+  out->d[2] = scale2;
+  out->d[3] = scale3;
+
+  // Load 16 bytes from each row (strided by 32 bytes in the source):
+  // for each row, the bytes at offsets 0, 32, 64, ..., 480 (16 values).
+  uint8_t r0[16], r1[16], r2[16], r3[16];
+
+  // Gather 16 bytes per row with stride 32
+#pragma unroll
+  for (int j = 0; j < 16; j++) {
+    r0[j] = row0_ptr[j * 32];
+    r1[j] = row1_ptr[j * 32];
+    r2[j] = row2_ptr[j * 32];
+    r3[j] = row3_ptr[j * 32];
+  }
+
+  // Process all 4 rows with NEON
+  const uint8x8_t mask = vdup_n_u8(0x0F);
+
+  // Row 0
+  {
+    uint8x8_t lo = vld1_u8(r0);
+    uint8x8_t hi = vld1_u8(r0 + 8);
+    uint8x8_t v0 = vand_u8(lo, mask);
+    uint8x8_t v1 = vshr_n_u8(lo, 4);
+    uint8x8_t v2 = vand_u8(hi, mask);
+    uint8x8_t v3 = vshr_n_u8(hi, 4);
+    uint8x8_t even = vorr_u8(v0, vshl_n_u8(v2, 4));
+    uint8x8_t odd = vorr_u8(v1, vshl_n_u8(v3, 4));
+    uint8x8x2_t zip = vzip_u8(even, odd);
+    // First half goes to qs[0..7], second half to qs[32..39]
+    vst1_u8(reinterpret_cast<uint8_t *>(&out->qs[0]), zip.val[0]);
+    vst1_u8(reinterpret_cast<uint8_t *>(&out->qs[32]), zip.val[1]);
+  }
+
+  // Row 1
+  {
+    uint8x8_t lo = vld1_u8(r1);
+    uint8x8_t hi = vld1_u8(r1 + 8);
+    uint8x8_t v0 = vand_u8(lo, mask);
+    uint8x8_t v1 = vshr_n_u8(lo, 4);
+    uint8x8_t v2 = vand_u8(hi, mask);
+    uint8x8_t v3 = vshr_n_u8(hi, 4);
+    uint8x8_t even = vorr_u8(v0, vshl_n_u8(v2, 4));
+    uint8x8_t odd = vorr_u8(v1, vshl_n_u8(v3, 4));
+    uint8x8x2_t zip = vzip_u8(even, odd);
+    vst1_u8(reinterpret_cast<uint8_t *>(&out->qs[8]), zip.val[0]);
+    vst1_u8(reinterpret_cast<uint8_t *>(&out->qs[40]), zip.val[1]);
+  }
+
+  // Row 2
+  {
+    uint8x8_t lo = vld1_u8(r2);
+    uint8x8_t hi = vld1_u8(r2 + 8);
+    uint8x8_t v0 = vand_u8(lo, mask);
+    uint8x8_t v1 = vshr_n_u8(lo, 4);
+    uint8x8_t v2 = vand_u8(hi, mask);
+    uint8x8_t v3 = vshr_n_u8(hi, 4);
+    uint8x8_t even = vorr_u8(v0, vshl_n_u8(v2, 4));
+    uint8x8_t odd = vorr_u8(v1, vshl_n_u8(v3, 4));
+    uint8x8x2_t zip = vzip_u8(even, odd);
+    vst1_u8(reinterpret_cast<uint8_t *>(&out->qs[16]), zip.val[0]);
+    vst1_u8(reinterpret_cast<uint8_t *>(&out->qs[48]), zip.val[1]);
+  }
+
+  // Row 3
+  {
+    uint8x8_t lo = vld1_u8(r3);
+    uint8x8_t hi = vld1_u8(r3 + 8);
+    uint8x8_t v0 = vand_u8(lo, mask);
+    uint8x8_t v1 = vshr_n_u8(lo, 4);
+    uint8x8_t v2 = vand_u8(hi, mask);
+    uint8x8_t v3 = vshr_n_u8(hi, 4);
+    uint8x8_t even = vorr_u8(v0, vshl_n_u8(v2, 4));
+    uint8x8_t odd = vorr_u8(v1, vshl_n_u8(v3, 4));
+    uint8x8x2_t zip = vzip_u8(even, odd);
+    vst1_u8(reinterpret_cast<uint8_t *>(&out->qs[24]), zip.val[0]);
+    vst1_u8(reinterpret_cast<uint8_t *>(&out->qs[56]), zip.val[1]);
+  }
+}
+
+void transform_int4_osv32_isv2_to_q4_0x4(size_t N, size_t K,
+                                         const uint8_t *osv32_weights,
+                                         const uint16_t *osv32_scales,
+                                         size_t scale_group_size,
+                                         void *dst_q4_0x4) {
+  NNTR_THROW_IF((!(scale_group_size == 32 || scale_group_size == 64 ||
+                   scale_group_size == 128)),
+                std::invalid_argument)
+    << "Scale group size must be 32/64/128";
+  NNTR_THROW_IF(K % QK4_0 != 0, std::invalid_argument)
+    << "K size must be divisible by QK4_0 (32)";
+  NNTR_THROW_IF(N % 4 != 0, std::invalid_argument)
+    << "N size must be divisible by 4";
+  constexpr size_t ROW_BLOCK_SIZE = 32;
+  constexpr size_t Q4_0X_BLOCK_SIZE = 4;
+
+  const size_t rows_count_pad = align(N, ROW_BLOCK_SIZE);
+  const size_t columns_count_pad = align(K, ROW_BLOCK_SIZE);
+  const size_t column_blocks_count = columns_count_pad / 2;
+  const size_t bytes_per_row_block_span = column_blocks_count * ROW_BLOCK_SIZE;
+  const size_t num_blocks_per_row = K / QK4_0;
+
+  block_q4_0x4 *dst_ptr = reinterpret_cast<block_q4_0x4 *>(dst_q4_0x4);
+
+#pragma omp parallel for schedule(static)
+  for (size_t row_id = 0; row_id < N; row_id += Q4_0X_BLOCK_SIZE) {
+    const size_t row_block_id = row_id / ROW_BLOCK_SIZE;
+    const size_t i_in_block = row_id % ROW_BLOCK_SIZE;
+    const size_t row_base =
+      row_block_id * bytes_per_row_block_span + i_in_block;
+
+    // Output pointer for this row group
+    block_q4_0x4 *out =
+      dst_ptr + (row_id / Q4_0X_BLOCK_SIZE) * num_blocks_per_row;
+
+    // Precompute row pointers for the fast inner loop
+    const uint8_t *row0_base = osv32_weights + row_base;
+    const uint8_t *row1_base = osv32_weights + row_base + 1;
+    const uint8_t *row2_base = osv32_weights + row_base + 2;
+    const uint8_t *row3_base = osv32_weights + row_base + 3;
+
+    for (size_t col_idx = 0; col_idx < K; col_idx += QK4_0) {
+      // Calculate the weight offset: (col_idx / 2) * 32 = col_idx * 16
+      const size_t weight_offset = (col_idx / 2) * ROW_BLOCK_SIZE;
+
+      // Get scales for all 4 rows
+      const size_t scale_col = col_idx / scale_group_size;
+      const size_t scale_base = scale_col * rows_count_pad;
+      uint16_t s0 = osv32_scales[row_id + 0 + scale_base];
+      uint16_t s1 = osv32_scales[row_id + 1 + scale_base];
+      uint16_t s2 = osv32_scales[row_id + 2 + scale_base];
+      uint16_t s3 = osv32_scales[row_id + 3 + scale_base];
+
+      // Transform 4 rows directly into the output
+      neon_transform_4rows_to_q4_0x4(
+        row0_base + weight_offset, row1_base + weight_offset,
+        row2_base + weight_offset, row3_base + weight_offset, s0, s1, s2, s3,
+        out);
+      out++;
+    }
+  }
+}
+
 #if defined(__aarch64__) || defined(_M_ARM64)
 static inline void load_fp16_4_to_chunk(const uint16_t *src, float *dst,
                                         int chunk_size) {
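
For readers verifying the NEON path, the per-row repack done by neon_transform_4rows_to_q4_0x4 can be restated in scalar form. The sketch below is an explanatory aid inferred from the intrinsics in the hunk above; the helper name scalar_repack_row is hypothetical and the code is not part of the commit.

// Scalar restatement (not in the commit) of the per-row repack above: gather 16
// source bytes at stride 32, combine the low/high nibbles of bytes j and j+8,
// then interleave the "even"/"odd" results exactly as vzip_u8 does.
#include <cstdint>

static void scalar_repack_row(const uint8_t *row_ptr, int row, uint8_t *qs) {
  uint8_t r[16];
  for (int j = 0; j < 16; ++j)
    r[j] = row_ptr[j * 32]; // same stride-32 gather as the unrolled NEON loop
  for (int j = 0; j < 8; ++j) {
    const uint8_t even =
      static_cast<uint8_t>((r[j] & 0x0F) | ((r[j + 8] & 0x0F) << 4)); // low nibbles
    const uint8_t odd =
      static_cast<uint8_t>((r[j] >> 4) | ((r[j + 8] >> 4) << 4));     // high nibbles
    // vzip_u8 pairs even[j] with odd[j]; pairs 0..3 land at qs[row * 8 + ...] and
    // pairs 4..7 land 32 bytes later, matching the two vst1_u8 stores per row.
    uint8_t *dst = (j < 4) ? &qs[row * 8 + 2 * j] : &qs[32 + row * 8 + 2 * (j - 4)];
    dst[0] = even;
    dst[1] = odd;
  }
}

Calling this for rows 0..3 of a group fills all 64 bytes of a block's qs array, which is exactly what the four NEON row blocks above do in registers.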

nntrainer/tensor/cpu_backend/arm/neon_impl.h

Lines changed: 16 additions & 0 deletions
@@ -773,6 +773,22 @@ void clamp(const T *input, T *output, size_t length,
            T lower_bound = std::numeric_limits<T>::lowest(),
            T upper_bound = std::numeric_limits<T>::max());

+/**
+ * @brief Transforms data from the osv32_isv2 in-memory layout to the block_q4_0x4
+ * in-memory layout with ARM NEON optimization and OpenMP parallelization.
+ * @param N number of rows
+ * @param K number of columns
+ * @param osv32_weights uint8_t* weight data in osv32_isv2 layout
+ * @param osv32_scales fp16* scales
+ * @param scale_group_size group size (32, 64, or 128)
+ * @param dst_q4_0x4 void* output data in block_q4_0x4 layout
+ */
+void transform_int4_osv32_isv2_to_q4_0x4(size_t N, size_t K,
+                                         const uint8_t *osv32_weights,
+                                         const uint16_t *osv32_scales,
+                                         size_t scale_group_size,
+                                         void *dst_q4_0x4);
+
 /// @note The structure should later be neon_impl_aarch64 and neon_impl_armv7l
 #if defined(__aarch64__) || defined(_M_ARM64)
 /**
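
A direct call of the newly declared routine might look like the sketch below. It is illustrative only: the include path, the helper name pack_osv32_to_q4_0x4, and the buffer sizing are assumptions, with the preconditions taken from the checks in neon_impl.cpp and the nntrainer::neon qualification taken from the call site in arm_compute_backend.cpp.

// Hypothetical caller (not part of the commit).
#include <cstddef>
#include <cstdint>
#include <vector>

#include <neon_impl.h> // assumed include path; also assumed to pull in block_q4_0x4

std::vector<block_q4_0x4> pack_osv32_to_q4_0x4(const uint8_t *osv32_weights,
                                               const uint16_t *osv32_scales,
                                               size_t N, size_t K,
                                               size_t scale_group_size) {
  // Preconditions enforced by the implementation: N % 4 == 0, K % 32 == 0,
  // scale_group_size in {32, 64, 128}.
  std::vector<block_q4_0x4> packed((N / 4) * (K / 32)); // one block per 4 rows x 32 columns
  nntrainer::neon::transform_int4_osv32_isv2_to_q4_0x4(
    N, K, osv32_weights, osv32_scales, scale_group_size, packed.data());
  return packed;
}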
