27 | 27 | #ifdef ARMV7 |
28 | 28 | #include <armv7_neon.h> |
29 | 29 | #endif |
| 30 | +#include "nntr_ggml_impl_common.h" |
30 | 31 | #include <fallback_internal.h> |
31 | 32 | #include <util_func.h> |
32 | 33 |
@@ -1340,6 +1341,170 @@ void clamp(const float *input, float *output, size_t length, float lower_bound, |
1340 | 1341 | } |
1341 | 1342 | } |
1342 | 1343 |
| 1344 | +/** |
| 1345 | + * @brief Optimized repacking helper - processes 4 rows simultaneously |
| 1346 | + * and writes directly to the output block, eliminating intermediate copies. |
| 1347 | + * Uses 64-bit NEON operations and prefetching for maximum throughput. |
| 1348 | + */ |
| 1349 | +inline static void neon_transform_4rows_to_q4_0x4( |
| 1350 | + const uint8_t *__restrict row0_ptr, const uint8_t *__restrict row1_ptr, |
| 1351 | + const uint8_t *__restrict row2_ptr, const uint8_t *__restrict row3_ptr, |
| 1352 | + uint16_t scale0, uint16_t scale1, uint16_t scale2, uint16_t scale3, |
| 1353 | + block_q4_0x4 *__restrict out) { |
| 1354 | + |
| 1355 | + // Prefetch next cache lines |
| 1356 | + __builtin_prefetch(row0_ptr + 64, 0, 3); |
| 1357 | + __builtin_prefetch(row1_ptr + 64, 0, 3); |
| 1358 | + __builtin_prefetch(row2_ptr + 64, 0, 3); |
| 1359 | + __builtin_prefetch(row3_ptr + 64, 0, 3); |
| 1360 | + |
| 1361 | + // Store scales directly |
| 1362 | + out->d[0] = scale0; |
| 1363 | + out->d[1] = scale1; |
| 1364 | + out->d[2] = scale2; |
| 1365 | + out->d[3] = scale3; |
| 1366 | + |
| 1367 | + // Load 16 bytes from each row (strided by 32 bytes in source) |
| 1368 | + // For each row: load bytes at offsets 0, 32, 64, ..., 480 (16 bytes, each packing two int4 values) |
| 1369 | + uint8_t r0[16], r1[16], r2[16], r3[16]; |
| 1370 | + |
| 1371 | +// Gather 16 bytes per row with stride 32 |
| 1372 | +#pragma unroll |
| 1373 | + for (int j = 0; j < 16; j++) { |
| 1374 | + r0[j] = row0_ptr[j * 32]; |
| 1375 | + r1[j] = row1_ptr[j * 32]; |
| 1376 | + r2[j] = row2_ptr[j * 32]; |
| 1377 | + r3[j] = row3_ptr[j * 32]; |
| 1378 | + } |
| 1379 | + |
| 1380 | + // Process all 4 rows with NEON |
| 1381 | + const uint8x8_t mask = vdup_n_u8(0x0F); |
| 1382 | + |
| 1383 | + // Row 0 |
| 1384 | + { |
| 1385 | + uint8x8_t lo = vld1_u8(r0); |
| 1386 | + uint8x8_t hi = vld1_u8(r0 + 8); |
| 1387 | + uint8x8_t v0 = vand_u8(lo, mask); |
| 1388 | + uint8x8_t v1 = vshr_n_u8(lo, 4); |
| 1389 | + uint8x8_t v2 = vand_u8(hi, mask); |
| 1390 | + uint8x8_t v3 = vshr_n_u8(hi, 4); |
| 1391 | + uint8x8_t even = vorr_u8(v0, vshl_n_u8(v2, 4)); |
| 1392 | + uint8x8_t odd = vorr_u8(v1, vshl_n_u8(v3, 4)); |
| 1393 | + uint8x8x2_t zip = vzip_u8(even, odd); |
| 1394 | + // First half goes to qs[0..7], second half to qs[32..39] |
| 1395 | + vst1_u8(reinterpret_cast<uint8_t *>(&out->qs[0]), zip.val[0]); |
| 1396 | + vst1_u8(reinterpret_cast<uint8_t *>(&out->qs[32]), zip.val[1]); |
| 1397 | + } |
| 1398 | + |
| 1399 | + // Row 1 |
| 1400 | + { |
| 1401 | + uint8x8_t lo = vld1_u8(r1); |
| 1402 | + uint8x8_t hi = vld1_u8(r1 + 8); |
| 1403 | + uint8x8_t v0 = vand_u8(lo, mask); |
| 1404 | + uint8x8_t v1 = vshr_n_u8(lo, 4); |
| 1405 | + uint8x8_t v2 = vand_u8(hi, mask); |
| 1406 | + uint8x8_t v3 = vshr_n_u8(hi, 4); |
| 1407 | + uint8x8_t even = vorr_u8(v0, vshl_n_u8(v2, 4)); |
| 1408 | + uint8x8_t odd = vorr_u8(v1, vshl_n_u8(v3, 4)); |
| 1409 | + uint8x8x2_t zip = vzip_u8(even, odd); |
| 1410 | + vst1_u8(reinterpret_cast<uint8_t *>(&out->qs[8]), zip.val[0]); |
| 1411 | + vst1_u8(reinterpret_cast<uint8_t *>(&out->qs[40]), zip.val[1]); |
| 1412 | + } |
| 1413 | + |
| 1414 | + // Row 2 |
| 1415 | + { |
| 1416 | + uint8x8_t lo = vld1_u8(r2); |
| 1417 | + uint8x8_t hi = vld1_u8(r2 + 8); |
| 1418 | + uint8x8_t v0 = vand_u8(lo, mask); |
| 1419 | + uint8x8_t v1 = vshr_n_u8(lo, 4); |
| 1420 | + uint8x8_t v2 = vand_u8(hi, mask); |
| 1421 | + uint8x8_t v3 = vshr_n_u8(hi, 4); |
| 1422 | + uint8x8_t even = vorr_u8(v0, vshl_n_u8(v2, 4)); |
| 1423 | + uint8x8_t odd = vorr_u8(v1, vshl_n_u8(v3, 4)); |
| 1424 | + uint8x8x2_t zip = vzip_u8(even, odd); |
| 1425 | + vst1_u8(reinterpret_cast<uint8_t *>(&out->qs[16]), zip.val[0]); |
| 1426 | + vst1_u8(reinterpret_cast<uint8_t *>(&out->qs[48]), zip.val[1]); |
| 1427 | + } |
| 1428 | + |
| 1429 | + // Row 3 |
| 1430 | + { |
| 1431 | + uint8x8_t lo = vld1_u8(r3); |
| 1432 | + uint8x8_t hi = vld1_u8(r3 + 8); |
| 1433 | + uint8x8_t v0 = vand_u8(lo, mask); |
| 1434 | + uint8x8_t v1 = vshr_n_u8(lo, 4); |
| 1435 | + uint8x8_t v2 = vand_u8(hi, mask); |
| 1436 | + uint8x8_t v3 = vshr_n_u8(hi, 4); |
| 1437 | + uint8x8_t even = vorr_u8(v0, vshl_n_u8(v2, 4)); |
| 1438 | + uint8x8_t odd = vorr_u8(v1, vshl_n_u8(v3, 4)); |
| 1439 | + uint8x8x2_t zip = vzip_u8(even, odd); |
| 1440 | + vst1_u8(reinterpret_cast<uint8_t *>(&out->qs[24]), zip.val[0]); |
| 1441 | + vst1_u8(reinterpret_cast<uint8_t *>(&out->qs[56]), zip.val[1]); |
| 1442 | + } |
| 1443 | +} |
| 1444 | + |
| 1445 | +void transform_int4_osv32_isv2_to_q4_0x4(size_t N, size_t K, |
| 1446 | + const uint8_t *osv32_weights, |
| 1447 | + const uint16_t *osv32_scales, |
| 1448 | + size_t scale_group_size, |
| 1449 | + void *dst_q4_0x4) { |
| 1450 | + NNTR_THROW_IF((!(scale_group_size == 32 || scale_group_size == 64 || |
| 1451 | + scale_group_size == 128)), |
| 1452 | + std::invalid_argument) |
| 1453 | + << "Scale group size must be 32/64/128"; |
| 1454 | + NNTR_THROW_IF(K % QK4_0 != 0, std::invalid_argument) |
| 1455 | + << "K size must be divisable by QK4_0 (32)"; |
| 1456 | + NNTR_THROW_IF(N % 4 != 0, std::invalid_argument) |
| 1457 | + << "N size must be divisable by 4"; |
| 1458 | + constexpr size_t ROW_BLOCK_SIZE = 32; |
| 1459 | + constexpr size_t Q4_0X_BLOCK_SIZE = 4; |
| 1460 | + |
| 1461 | + const size_t rows_count_pad = align(N, ROW_BLOCK_SIZE); |
| 1462 | + const size_t columns_count_pad = align(K, ROW_BLOCK_SIZE); |
| 1463 | + const size_t column_blocks_count = columns_count_pad / 2; |
| 1464 | + const size_t bytes_per_row_block_span = column_blocks_count * ROW_BLOCK_SIZE; |
| 1465 | + const size_t num_blocks_per_row = K / QK4_0; |
| 1466 | + |
| 1467 | + block_q4_0x4 *dst_ptr = reinterpret_cast<block_q4_0x4 *>(dst_q4_0x4); |
| 1468 | + |
| 1469 | +#pragma omp parallel for schedule(static) |
| 1470 | + for (size_t row_id = 0; row_id < N; row_id += Q4_0X_BLOCK_SIZE) { |
| 1471 | + const size_t row_block_id = row_id / ROW_BLOCK_SIZE; |
| 1472 | + const size_t i_in_block = row_id % ROW_BLOCK_SIZE; |
| 1473 | + const size_t row_base = |
| 1474 | + row_block_id * bytes_per_row_block_span + i_in_block; |
| 1475 | + |
| 1476 | + // Output pointer for this row group |
| 1477 | + block_q4_0x4 *out = |
| 1478 | + dst_ptr + (row_id / Q4_0X_BLOCK_SIZE) * num_blocks_per_row; |
| 1479 | + |
| 1480 | + // Precompute row pointers for fast inner loop |
| 1481 | + const uint8_t *row0_base = osv32_weights + row_base; |
| 1482 | + const uint8_t *row1_base = osv32_weights + row_base + 1; |
| 1483 | + const uint8_t *row2_base = osv32_weights + row_base + 2; |
| 1484 | + const uint8_t *row3_base = osv32_weights + row_base + 3; |
| 1485 | + |
| 1486 | + for (size_t col_idx = 0; col_idx < K; col_idx += QK4_0) { |
| 1487 | + // Calculate weight offset: (col_idx / 2) * 32 = col_idx * 16 |
| 1488 | + const size_t weight_offset = (col_idx / 2) * ROW_BLOCK_SIZE; |
| 1489 | + |
| 1490 | + // Get scales for all 4 rows |
| 1491 | + const size_t scale_col = col_idx / scale_group_size; |
| 1492 | + const size_t scale_base = scale_col * rows_count_pad; |
| 1493 | + uint16_t s0 = osv32_scales[row_id + 0 + scale_base]; |
| 1494 | + uint16_t s1 = osv32_scales[row_id + 1 + scale_base]; |
| 1495 | + uint16_t s2 = osv32_scales[row_id + 2 + scale_base]; |
| 1496 | + uint16_t s3 = osv32_scales[row_id + 3 + scale_base]; |
| 1497 | + |
| 1498 | + // Transform 4 rows directly to output |
| 1499 | + neon_transform_4rows_to_q4_0x4( |
| 1500 | + row0_base + weight_offset, row1_base + weight_offset, |
| 1501 | + row2_base + weight_offset, row3_base + weight_offset, s0, s1, s2, s3, |
| 1502 | + out); |
| 1503 | + out++; |
| 1504 | + } |
| 1505 | + } |
| 1506 | +} |
| 1507 | + |
1343 | 1508 | #if defined(__aarch64__) || defined(_M_ARM64) |
1344 | 1509 | static inline void load_fp16_4_to_chunk(const uint16_t *src, float *dst, |
1345 | 1510 | int chunk_size) { |
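For reference, the repacking that neon_transform_4rows_to_q4_0x4 performs can be written out in scalar form. The sketch below is illustrative and not part of the patch: it assumes the block_q4_0x4 layout used above (d holding raw fp16 scale bits, a 64-byte qs array) and uses a hypothetical helper name. In the OSV32/ISV2 source, the weight at (row n, col k) lives at byte offset (n / 32) * (align(K, 32) / 2) * 32 + (k / 2) * 32 + (n % 32), in the low nibble for even k and the high nibble for odd k; q4_0 instead pairs columns k and k + 16 in one byte, and block_q4_0x4 interleaves the four rows in 8-byte chunks.

#include <cstdint>
#include <cstring>

// Scalar reference for one QK4_0 (32-column) slice of four rows.
// rows[r] points at the first source byte of row r for this slice.
static void scalar_transform_4rows_to_q4_0x4(const uint8_t *rows[4],
                                             const uint16_t scales[4],
                                             block_q4_0x4 *out) {
  for (int r = 0; r < 4; ++r) {
    out->d[r] = scales[r]; // raw fp16 bits, as in the NEON path
    uint8_t v[32];         // unpacked int4 values, columns 0..31 of this row
    for (int j = 0; j < 16; ++j) {
      const uint8_t b = rows[r][j * 32]; // stride-32 gather, as in the NEON path
      v[2 * j + 0] = b & 0x0F;           // even column: low nibble
      v[2 * j + 1] = b >> 4;             // odd column: high nibble
    }
    uint8_t q[16]; // this row's q4_0 qs bytes
    for (int k = 0; k < 16; ++k)
      q[k] = v[k] | (v[k + 16] << 4); // q4_0 pairs columns k and k + 16
    // 8-byte interleave: chunk c of row r lands at qs[(c * 4 + r) * 8]
    std::memcpy(&out->qs[r * 8], q, 8);
    std::memcpy(&out->qs[32 + r * 8], q + 8, 8);
  }
}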
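A hypothetical call site, sizing the destination from the invariants the function checks (N divisible by 4, K divisible by QK4_0, one block_q4_0x4 per 4 rows by 32 columns); N, K, osv32_weights, and osv32_scales are assumed to come from the surrounding code:

#include <vector>

std::vector<block_q4_0x4> repacked((N / 4) * (K / QK4_0));
transform_int4_osv32_isv2_to_q4_0x4(N, K, osv32_weights, osv32_scales,
                                    /*scale_group_size=*/64, repacked.data());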