From fe4449b2e4914d4a56702ba2ba1dc39d91974375 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 30 Nov 2024 17:39:17 +0000 Subject: [PATCH 001/751] Improve: `#pragma region` dashes --- include/stringzilla/stringzilla.h | 6 +++--- scripts/bench_sort.cpp | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/include/stringzilla/stringzilla.h b/include/stringzilla/stringzilla.h index 3721c5b0..90a4b7e9 100644 --- a/include/stringzilla/stringzilla.h +++ b/include/stringzilla/stringzilla.h @@ -1188,7 +1188,7 @@ SZ_PUBLIC void sz_sort_intro(sz_sequence_t *sequence, sz_sequence_comparator_t l #endif #endif // SZ_USE_ARM_SVE -#pragma region Hardware - Specific API +#pragma region Hardware Specific API #if SZ_USE_X86_AVX512 @@ -4458,7 +4458,7 @@ SZ_PUBLIC void sz_hashes_avx2(sz_cptr_t start, sz_size_t length, sz_size_t windo * * 2019 IceLake: VPOPCNTDQ, VNNI, VBMI2, BITALG, GFNI, VPCLMULQDQ, VAES * * 2020 TigerLake: VP2INTERSECT */ -#pragma region AVX - 512 Implementation +#pragma region AVX512 Implementation #if SZ_USE_X86_AVX512 #pragma GCC push_options @@ -6274,7 +6274,7 @@ SZ_PUBLIC void sz_copy_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length) /* * @brief Pick the right implementation for the string search algorithms. */ -#pragma region Compile - Time Dispatching +#pragma region Compile Time Dispatching SZ_PUBLIC sz_u64_t sz_hash(sz_cptr_t ins, sz_size_t length) { return sz_hash_serial(ins, length); } SZ_PUBLIC void sz_tolower(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) { sz_tolower_serial(ins, length, outs); } diff --git a/scripts/bench_sort.cpp b/scripts/bench_sort.cpp index b70409ca..f46be4a3 100644 --- a/scripts/bench_sort.cpp +++ b/scripts/bench_sort.cpp @@ -21,7 +21,7 @@ using strings_t = std::vector; using idx_t = sz_size_t; using permute_t = std::vector; -#pragma region - C callbacks +#pragma region C callbacks static char const *get_start(sz_sequence_t const *array_c, sz_size_t i) { strings_t const &array = *reinterpret_cast(array_c->handle); From 585f7d5dd8940a045fce616c23fbe147e1a1b3f5 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 30 Nov 2024 17:41:03 +0000 Subject: [PATCH 002/751] Fix: `sz_look_up_transform_avx512` declaration --- include/stringzilla/stringzilla.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/stringzilla/stringzilla.h b/include/stringzilla/stringzilla.h index 90a4b7e9..e1c1d910 100644 --- a/include/stringzilla/stringzilla.h +++ b/include/stringzilla/stringzilla.h @@ -1202,8 +1202,8 @@ SZ_PUBLIC void sz_copy_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t lengt SZ_PUBLIC void sz_move_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length); /** @copydoc sz_fill */ SZ_PUBLIC void sz_fill_avx512(sz_ptr_t target, sz_size_t length, sz_u8_t value); -/** @copydoc sz_look_up_tranform */ -SZ_PUBLIC void sz_look_up_tranform_avx512(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target); +/** @copydoc sz_look_up_transform */ +SZ_PUBLIC void sz_look_up_transform_avx512(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target); /** @copydoc sz_find_byte */ SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); /** @copydoc sz_rfind_byte */ From 715ad100d6e667f5c34ad60752ef6f34f90c993d Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Mon, 2 Dec 2024 22:03:45 +0000 Subject: [PATCH 003/751] 
Docs: Levenshtein tutorial in Jupyter --- scripts/test_levenshtein.ipynb | 342 +++++++++++++++++++++++++++------ 1 file changed, 283 insertions(+), 59 deletions(-) diff --git a/scripts/test_levenshtein.ipynb b/scripts/test_levenshtein.ipynb index fc8f9bf6..4718c386 100644 --- a/scripts/test_levenshtein.ipynb +++ b/scripts/test_levenshtein.ipynb @@ -1,29 +1,52 @@ { "cells": [ { - "cell_type": "code", - "execution_count": 25, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Exploring the Impact of Evaluation Order on Edit Distance Algorithms\n", + "\n", + "Removing data-dependencies in the Wagner-Fisher, Needleman-Wunsch, Smith-Waterman, and Gotoh Dynamic Programming algorithms to explain the hardware-accelerated variants in StringZilla." + ] + }, + { + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "import numpy as np\n", - "import random" + "## Levenshtein Distance\n", + "\n", + "Levenshtein edit distance is one of the most broadly studied string similarity metrics.\n", + "It is defined as the minimum number of single-character insertions, deletions, and substitutions required to change one string into another.\n", + "The Levenshtein distance between two strings is calculated using dynamic programming algorithms, such as the Wagner-Fisher algorithm, and its variations for Bioinformatics: \n", + "\n", + "- Needleman-Wunsch for global alignment with substitution matrices, \n", + "- Smith-Waterman for local alignment with substitution matrices, \n", + "- Gotoh for different penalties for gap opening and extensions.\n", + "\n", + "Given the shared nature of these algorithms, the same tricks can be applied to all of them to improve their performance." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "# Exploring the Impact of Evaluation Order on the Wagner Fisher Algorithm for Levenshtein Edit Distance" + "## Warner-Fisher Algorithm\n", + "\n", + "Wagner-Fisher algorithm, in its most naive form, has a time and space complexity of $O(NM)$, where $N$ and $M$ are the lengths of the two strings being compared.\n", + "A rectangular matrix of size $(N+1) \\times (M+1)$ is created to store the edit distances between all prefixes of the two strings.\n", + "The first row and column are, naturally, initialized with ${0, 1, 2, ..., N}$ and ${0, 1, 2, ..., M}$ respectively." ] }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ - "def algo_v0(s1, s2) -> int:\n", + "from typing import Tuple\n", + "import numpy as np # NumPy for matrices\n", + "\n", + "def wagner_fisher(s1: str, s2: str) -> Tuple[int, np.ndarray]:\n", " # Create a matrix of size (len(s1)+1) x (len(s2)+1)\n", " matrix = np.zeros((len(s1) + 1, len(s2) + 1), dtype=int)\n", "\n", @@ -38,12 +61,12 @@ " for j in range(1, len(s2) + 1):\n", " substitution_cost = s1[i - 1] != s2[j - 1]\n", " matrix[i, j] = min(\n", - " matrix[i - 1, j] + 1, # Deletion\n", - " matrix[i, j - 1] + 1, # Insertion\n", - " matrix[i - 1, j - 1] + substitution_cost, # Substitution\n", + " matrix[i - 1, j] + 1, #? Deletion cost\n", + " matrix[i, j - 1] + 1, #? Insertion cost\n", + " matrix[i - 1, j - 1] + substitution_cost, #? 
Substitution cost\n", " )\n", "\n", - " # Return the Levenshtein distance\n", + " # The distance will be placed in the bottom right corner of the matrix\n", " return matrix[len(s1), len(s2)], matrix" ] }, @@ -51,25 +74,32 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Accelerating this exact algorithm isn't trivial, is the `matrix[i, j]` value has a dependency on the `matrix[i, j-1]` value.\n", - "So we can't brute-force accelerate the inner loop.\n", - "Instead, we can show that we can evaluate the matrix in a different order, and still get the same result." + "This algorithm is almost never recommended for practical use, as it has a quadratic space complexity.\n", + "It's trivial to see that the space complexity can be reduced to $O(min(N, M))$ by only storing the last two rows of the matrix, but we want to keep the entire matrix as a reference to allow debugging and visualization." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "![](https://mathworld.wolfram.com/images/eps-svg/SkewDiagonal_1000.svg)" + "## Diagonal Evaluation Order\n", + "\n", + "Accelerating this exact algorithm with SIMD instructions isn't trivial, is the `matrix[i, j]` value has a dependency on the `matrix[i, j - 1]` value.\n", + "So we can't brute-force accelerate the inner loop.\n", + "Instead, we can show that we can evaluate the matrix in a different order, and still get the same result.\n", + "\n", + "![Skewed Diagonals Evaluation Order](https://mathworld.wolfram.com/images/eps-svg/SkewDiagonal_1000.svg)\n", + "\n", + "But before complicating things too much, let's start with a simple case - when both strings have identical lengths and the DP matrix has a square shape." ] }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ - "def algo_v1(s1, s2, verbose: bool = False) -> int:\n", + "def square_skewed_diagonals(s1: str, s2: str, verbose: bool = False) -> Tuple[int, np.ndarray]:\n", " assert len(s1) == len(s2), \"First define an algo for square matrices!\"\n", " # Create a matrix of size (len(s1)+1) x (len(s2)+1)\n", " matrix = np.zeros((len(s1) + 1, len(s2) + 1), dtype=int)\n", @@ -83,35 +113,45 @@ "\n", " # Number of rows and columns in the square matrix.\n", " n = len(s1) + 1\n", + " \n", + " # Number of diagonals and skewed diagonals in the square matrix of size (n x n).\n", " skew_diagonals_count = 2 * n - 1\n", - " # Compute Levenshtein distance\n", - " for skew_diagonal_idx in range(2, skew_diagonals_count):\n", - " skew_diagonal_length = (skew_diagonal_idx + 1) if skew_diagonal_idx < n else (2*n - skew_diagonal_idx - 1)\n", + " \n", + " # Populate the matrix in 2 separate loops: for the top left triangle and for the bottom right triangle.\n", + " for skew_diagonal_idx in range(2, n):\n", + " skew_diagonal_length = skew_diagonal_idx + 1\n", + " for offset_within_skew_diagonal in range(1, skew_diagonal_length - 1):\n", + " # If we haven't passed the main skew diagonal yet, \n", + " # then we have to skip the first and the last operation,\n", + " # as those are already pre-populated and form the first column \n", + " # and the first row of the Levenshtein matrix respectively.\n", + " i = skew_diagonal_idx - offset_within_skew_diagonal\n", + " j = offset_within_skew_diagonal\n", + " if verbose:\n", + " print(f\"top left triangle: {skew_diagonal_idx=}, {skew_diagonal_length=}, {i=}, {j=}\")\n", + " substitution_cost = s1[i - 1] != s2[j - 1]\n", + " matrix[i, j] = min(\n", + " matrix[i - 1, j] + 1, #? 
Deletion cost\n", + " matrix[i, j - 1] + 1, #? Insertion cost\n", + " matrix[i - 1, j - 1] + substitution_cost, #? Substitution cost\n", + " )\n", + " \n", + " # Now the bottom right triangle of the matrix.\n", + " for skew_diagonal_idx in range(n, skew_diagonals_count):\n", + " skew_diagonal_length = 2*n - skew_diagonal_idx - 1\n", " for offset_within_skew_diagonal in range(skew_diagonal_length):\n", - " if skew_diagonal_idx < n:\n", - " # If we passed the main skew diagonal yet, \n", - " # Then we have to skip the first and the last operation,\n", - " # as those are already pre-populated and form the first column \n", - " # and the first row of the Levenshtein matrix respectively.\n", - " if offset_within_skew_diagonal == 0 or offset_within_skew_diagonal + 1 == skew_diagonal_length:\n", - " continue \n", - " i = skew_diagonal_idx - offset_within_skew_diagonal\n", - " j = offset_within_skew_diagonal\n", - " if verbose:\n", - " print(f\"top left triangle: {skew_diagonal_idx=}, {skew_diagonal_length=}, {i=}, {j=}\")\n", - " else:\n", - " i = n - offset_within_skew_diagonal - 1\n", - " j = skew_diagonal_idx - n + offset_within_skew_diagonal + 1\n", - " if verbose:\n", - " print(f\"bottom right triangle: {skew_diagonal_idx=}, {skew_diagonal_length=}, {i=}, {j=}\")\n", + " i = n - offset_within_skew_diagonal - 1\n", + " j = skew_diagonal_idx - n + offset_within_skew_diagonal + 1\n", + " if verbose:\n", + " print(f\"bottom right triangle: {skew_diagonal_idx=}, {skew_diagonal_length=}, {i=}, {j=}\")\n", " substitution_cost = s1[i - 1] != s2[j - 1]\n", " matrix[i, j] = min(\n", - " matrix[i - 1, j] + 1, # Deletion\n", - " matrix[i, j - 1] + 1, # Insertion\n", - " matrix[i - 1, j - 1] + substitution_cost, # Substitution\n", + " matrix[i - 1, j] + 1, #? Deletion cost\n", + " matrix[i, j - 1] + 1, #? Insertion cost\n", + " matrix[i - 1, j - 1] + substitution_cost, #? 
Substitution cost\n", " )\n", "\n", - " # Return the Levenshtein distance\n", + " # Similarly, the distance will be placed in the bottom right corner of the matrix\n", " return matrix[len(s1), len(s2)], matrix" ] }, @@ -124,16 +164,17 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ + "import random\n", "for _ in range(10):\n", " s1 = ''.join(random.choices(\"ab\", k=50))\n", " s2 = ''.join(random.choices(\"ab\", k=50))\n", - " d0, _ = algo_v0(s1, s2)\n", - " d1, _ = algo_v1(s1, s2)\n", - " assert d0 == d1 " + " d0, _ = wagner_fisher(s1, s2)\n", + " d1, _ = square_skewed_diagonals(s1, s2)\n", + " assert d0 == d1" ] }, { @@ -146,7 +187,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -154,7 +195,7 @@ "text/plain": [ "('listen',\n", " 'silent',\n", - " 'distance = 4',\n", + " 'distance = np.int64(4)',\n", " array([[0, 1, 2, 3, 4, 5, 6],\n", " [1, 1, 2, 2, 3, 4, 5],\n", " [2, 2, 1, 2, 3, 4, 5],\n", @@ -164,7 +205,7 @@ " [6, 5, 5, 5, 4, 3, 4]]))" ] }, - "execution_count": 29, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -174,13 +215,13 @@ "s2 = \"silent\"\n", "# s1 = ''.join(random.choices(\"abcd\", k=100))\n", "# s2 = ''.join(random.choices(\"abcd\", k=100))\n", - "distance, baseline = algo_v0(s1, s2)\n", + "distance, baseline = wagner_fisher(s1, s2)\n", "s1, s2, f\"{distance = }\", baseline" ] }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -191,7 +232,7 @@ " array([0, 0, 0, 0, 0, 0, 0], dtype=uint64))" ] }, - "execution_count": 30, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -233,7 +274,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -244,7 +285,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -258,7 +299,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -269,7 +310,7 @@ " array([6, 4, 3, 2, 3, 4, 6], dtype=uint64))" ] }, - "execution_count": 33, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -306,7 +347,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -317,7 +358,7 @@ " array([4, 5, 4, 5, 5, 5, 6], dtype=uint64))" ] }, - "execution_count": 34, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -342,12 +383,195 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "assert distance == following[0], f\"{distance = } != {following[0] = }\"" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Generalizing to Non-Square Matrices" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "def skewed_diagonals(s1, s2, verbose: bool = False) -> int:\n", + " shorter, longer = (s1, s2) if len(s1) < len(s2) else (s2, s1) \n", + " shorter_dim = len(shorter) + 1\n", + " longer_dim = len(longer) + 1\n", + " # Create a matrix of size (len(s1)+1) x (len(s2)+1)\n", + " matrix = np.zeros((len(shorter) + 1, len(longer) + 1), dtype=int)\n", + " matrix[:, :] = 99\n", + "\n", + " # Initialize the first column and first row of the matrix\n", + " for i in range(shorter_dim):\n", 
+ " matrix[i, 0] = i\n", + " for j in range(longer_dim):\n", + " matrix[0, j] = j\n", + "\n", + " # Let's say we are dealing with 6 and 9 letter words.\n", + " # The matrix will have size 7 x 10, parameterized as (shorter_dim x longer_dim).\n", + " # It will have:\n", + " # - 8 diagonals of increasing length, at positions: 0, 1, 2, 3, 4, 5, 6, 7.\n", + " # - 2 diagonals of fixed length, at positions: 8, 9.\n", + " # - 8 diagonals of decreasing length, at positions: 10, 11, 12, 13, 14, 15, 16, 17.\n", + " skew_diagonals_count = 2 * longer_dim - 1\n", + "\n", + " # Same as with square matrices, the 0th diagonal contains - just one element - zero - skipping it.\n", + " # Same as with square matrices, the 1st diagonal contains the values 1 and 1 - skipping it.\n", + " # Now let's handle the rest of the upper triangle.\n", + " for skew_diagonal_idx in range(2, shorter_dim + 1):\n", + " skew_diagonal_length = (skew_diagonal_idx + 1)\n", + " for offset_within_skew_diagonal in range(1, skew_diagonal_length-1): #! Skip the first column & row\n", + " # If we haven't passed the main skew diagonal yet, \n", + " # then we have to skip the first and the last operation,\n", + " # as those are already pre-populated and form the first column \n", + " # and the first row of the Levenshtein matrix respectively.\n", + " i = skew_diagonal_idx - offset_within_skew_diagonal\n", + " j = offset_within_skew_diagonal\n", + " if verbose:\n", + " print(f\"top left triangle: {skew_diagonal_idx=}, {skew_diagonal_length=}, {i=}, {j=}\")\n", + " shorter_char = shorter[i - 1]\n", + " longer_char = longer[j - 1]\n", + " substitution_cost = shorter_char != longer_char\n", + " matrix[i, j] = min(\n", + " matrix[i - 1, j] + 1, # Deletion\n", + " matrix[i, j - 1] + 1, # Insertion\n", + " matrix[i - 1, j - 1] + substitution_cost, # Substitution\n", + " )\n", + " \n", + " # Now let's handle the anti-diagonal band of the matrix, between the top and bottom triangles. 
\n", + " for skew_diagonal_idx in range(shorter_dim + 1, longer_dim + 1):\n", + " skew_diagonal_length = shorter_dim\n", + " for offset_within_skew_diagonal in range(skew_diagonal_length):\n", + " i = shorter_dim - offset_within_skew_diagonal - 1\n", + " j = offset_within_skew_diagonal + 1\n", + " if verbose:\n", + " print(f\"anti-band: {skew_diagonal_idx=}, {skew_diagonal_length=}, {i=}, {j=}\")\n", + " shorter_char = shorter[i - 1]\n", + " longer_char = longer[j - 1]\n", + " substitution_cost = shorter_char != longer_char\n", + " matrix[i, j] = min(\n", + " matrix[i - 1, j] + 1, # Deletion\n", + " matrix[i, j - 1] + 1, # Insertion\n", + " matrix[i - 1, j - 1] + substitution_cost, # Substitution\n", + " )\n", + " \n", + " # Now let's handle the bottom right triangle.\n", + " for skew_diagonal_idx in range(longer_dim + 1, skew_diagonals_count):\n", + " skew_diagonal_length = 2 * longer_dim - skew_diagonal_idx - 1\n", + " for offset_within_skew_diagonal in range(skew_diagonal_length):\n", + " i = shorter_dim - offset_within_skew_diagonal - 1\n", + " j = skew_diagonal_idx - longer_dim + offset_within_skew_diagonal + 1\n", + " if verbose:\n", + " print(f\"bottom right triangle: {skew_diagonal_idx=}, {skew_diagonal_length=}, {i=}, {j=}\")\n", + " shorter_char = shorter[i - 1]\n", + " longer_char = longer[j - 1]\n", + " substitution_cost = shorter_char != longer_char\n", + " matrix[i, j] = min(\n", + " matrix[i - 1, j] + 1, # Deletion\n", + " matrix[i, j - 1] + 1, # Insertion\n", + " matrix[i - 1, j - 1] + substitution_cost, # Substitution\n", + " )\n", + "\n", + " # Return the Levenshtein distance\n", + " return matrix[len(shorter), len(longer)], matrix" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "('listeners',\n", + " 'silents',\n", + " 'distance = np.int64(5)',\n", + " array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],\n", + " [1, 1, 2, 2, 3, 4, 5, 6, 7, 8],\n", + " [2, 2, 1, 2, 3, 4, 5, 6, 7, 8],\n", + " [3, 2, 2, 2, 3, 4, 5, 6, 7, 8],\n", + " [4, 3, 3, 3, 3, 3, 4, 5, 6, 7],\n", + " [5, 4, 4, 4, 4, 4, 3, 4, 5, 6],\n", + " [6, 5, 5, 5, 4, 5, 4, 4, 5, 6],\n", + " [7, 6, 6, 5, 5, 5, 5, 5, 5, 5]]))" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "s1 = \"listeners\"\n", + "s2 = \"silents\"\n", + "distance, baseline = skewed_diagonals(s1, s2)\n", + "s1, s2, f\"{distance = }\", baseline" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "('listeners',\n", + " 'silents',\n", + " 'distance = np.int64(5)',\n", + " array([[0, 1, 2, 3, 4, 5, 6, 7],\n", + " [1, 1, 2, 2, 3, 4, 5, 6],\n", + " [2, 2, 1, 2, 3, 4, 5, 6],\n", + " [3, 2, 2, 2, 3, 4, 5, 5],\n", + " [4, 3, 3, 3, 3, 4, 4, 5],\n", + " [5, 4, 4, 4, 3, 4, 5, 5],\n", + " [6, 5, 5, 5, 4, 3, 4, 5],\n", + " [7, 6, 6, 6, 5, 4, 4, 5],\n", + " [8, 7, 7, 7, 6, 5, 5, 5],\n", + " [9, 8, 8, 8, 7, 6, 6, 5]]))" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "distance, baseline = wagner_fisher(s1, s2)\n", + "s1, s2, f\"{distance = }\", baseline" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "s1 = ''.join(random.choices(\"abcd\", k=5))\n", + "s2 = ''.join(random.choices(\"abcd\", k=6))\n", + "distance_v0, baseline_v0 = wagner_fisher(s1, s2)\n", + "distance_v2, baseline_v2 = skewed_diagonals(s1, s2, verbose=False)\n", + 
"assert distance_v0 == distance_v2, f\"{distance_v0 = } != {distance_v2 = }\"\n", + "assert np.all(baseline_v0 == baseline_v2), f\"{baseline_v0 = }\\n{baseline_v2 = }\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -366,7 +590,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" + "version": "3.12.2" } }, "nbformat": 4, From d3b423a4c647bec1c823857a4bc043b77d6c2df3 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 09:09:19 +0000 Subject: [PATCH 004/751] Improve: Levenshtein functions for unicode --- scripts/test.cpp | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/scripts/test.cpp b/scripts/test.cpp index 47ef46d2..cb7d0079 100644 --- a/scripts/test.cpp +++ b/scripts/test.cpp @@ -1394,6 +1394,20 @@ static void test_levenshtein_distances() { {"abc", "adc", 1}, // one substitution {"abc", "abc", 0}, // same string {"ggbuzgjux{}l", "gbuzgjux{}l", 1}, // one insertion (prepended) + {"apple", "aple", 1}, + // Unicode: + {"αβγδ", "αγδ", 2}, // Each Greek symbol is 2 bytes in size + {"مرحبا بالعالم", "مرحبا يا عالم", 3}, // "Hello World" vs "Welcome to the World" ? + {"école", "école", 3}, // letter "é" as a single character vs "e" + "´" + {"Schön", "Scho\u0308n", 3}, // "ö" represented as "o" + "¨" + {"💖", "💗", 1}, // 4-byte emojis: Different hearts + {"𠜎 𠜱 𠝹 𠱓", "𠜎𠜱𠝹𠱓", 3}, // Ancient Chinese characters, no spaces vs spaces + {"München", "Muenchen", 2}, // German name with umlaut vs. its transcription + {"façade", "facade", 2}, // "ç" represented as "c" with cedilla vs. plain "c" + {"こんにちは世界", "こんばんは世界", 3}, // Japanese: "Good morning world" vs "Good evening world" + {"👩‍👩‍👧‍👦", "👨‍👩‍👧‍👦", 1}, // Family emojis with different compositions + {"Data科学123", "Data科學321", 3}, + {"🙂🌍🚀", "🙂🌎✨", 5}, }; using matrix_t = std::int8_t[256][256]; @@ -1435,6 +1449,7 @@ static void test_levenshtein_distances() { std::size_t iterations; } fuzzy_cases[] = { {10, 1000}, + {64, 128}, {100, 100}, {1000, 10}, }; From 1765f334230e60c7884a1c7efc48f2227c1ed2c9 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 09:10:38 +0000 Subject: [PATCH 005/751] Add: Missing Rust interfaces `sz_checksum`, `sz_hash`, `sz_edit_distance_utf8`, `sz_edit_distance_bounded`, `sz_edit_distance_utf8_bounded`. --- rust/lib.rs | 145 +++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 139 insertions(+), 6 deletions(-) diff --git a/rust/lib.rs b/rust/lib.rs index 30150efb..08c8772a 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -8,7 +8,7 @@ pub mod sz { - use core::ffi::c_void; + use core::{ffi::c_void, usize}; // Import the functions from the StringZilla C library. extern "C" { @@ -54,6 +54,10 @@ pub mod sz { needle_length: usize, ) -> *const c_void; + fn sz_hash(text: *const c_void, length: usize) -> u64; + + fn sz_checksum(text: *const c_void, length: usize) -> u64; + fn sz_edit_distance( haystack1: *const c_void, haystack1_length: usize, @@ -98,8 +102,6 @@ pub mod sz { allocator: *const c_void, ) -> isize; - // type RandomGeneratorT = fn(*mut c_void) -> u64; - fn sz_generate( alphabet: *const c_void, alphabet_size: usize, @@ -110,6 +112,51 @@ pub mod sz { ); } + /// Computes the checksum value of unsigned bytes in a given byte slice `text`. 
+ /// This function is useful for verifying data integrity and detecting changes in + /// binary data, such as files or network packets. + /// + /// # Arguments + /// + /// * `text`: The byte slice to compute the checksum for. + /// + /// # Returns + /// + /// A `u64` representing the checksum value of the input byte slice. + pub fn checksum(text: T) -> u64 + where + T: AsRef<[u8]>, + { + let text_ref = text.as_ref(); + let text_pointer = text_ref.as_ptr() as _; + let text_length = text_ref.len(); + let result = unsafe { sz_checksum(text_pointer, text_length) }; + return result; + } + + /// Computes a 64-bit AES-based hash value for a given byte slice `text`. + /// This function is designed to provide a high-quality hash value for use in + /// hash tables, data structures, and cryptographic applications. + /// Unlike the checksum function, the hash function is order-sensitive. + /// + /// # Arguments + /// + /// * `text`: The byte slice to compute the checksum for. + /// + /// # Returns + /// + /// A `u64` representing the hash value of the input byte slice. + pub fn hash(text: T) -> u64 + where + T: AsRef<[u8]>, + { + let text_ref = text.as_ref(); + let text_pointer = text_ref.as_ptr() as _; + let text_length = text_ref.len(); + let result = unsafe { sz_hash(text_pointer, text_length) }; + return result; + } + /// Locates the first matching substring within `haystack` that equals `needle`. /// This function is similar to the `memmem()` function in LibC, but, unlike `strstr()`, /// it requires the length of both haystack and needle to be known beforehand. @@ -445,7 +492,7 @@ pub mod sz { F: AsRef<[u8]>, S: AsRef<[u8]>, { - edit_distance_bounded(first, second, 0) + edit_distance_bounded(first, second, usize::MAX) } /// Computes the Levenshtein edit distance between two UTF8 strings, using the Wagner-Fisher @@ -465,7 +512,7 @@ pub mod sz { F: AsRef<[u8]>, S: AsRef<[u8]>, { - edit_distance_utf8_bounded(first, second, 0) + edit_distance_utf8_bounded(first, second, usize::MAX) } /// Computes the Hamming edit distance between two strings, counting the number of substituted characters. @@ -987,6 +1034,34 @@ pub trait StringZilla<'a, N> where N: AsRef<[u8]> + 'a, { + /// Computes the checksum value of unsigned bytes in a given string. + /// This function is useful for verifying data integrity and detecting changes in + /// binary data, such as files or network packets. + /// + /// # Examples + /// + /// ``` + /// use stringzilla::StringZilla; + /// + /// let text = "Hello"; + /// assert_eq!(text.sz_checksum(), Some(500)); + /// ``` + fn sz_checksum(&self) -> u64; + + /// Computes a 64-bit AES-based hash value for a given string. + /// This function is designed to provide a high-quality hash value for use in + /// hash tables, data structures, and cryptographic applications. + /// Unlike the checksum function, the hash function is order-sensitive. + /// + /// # Examples + /// + /// ``` + /// use stringzilla::StringZilla; + /// + /// assert_ne!("Hello".sz_hash(), "World".sz_hash()); + /// ``` + fn sz_hash(&self) -> u64; + /// Searches for the first occurrence of `needle` in `self`. /// /// # Examples @@ -1072,6 +1147,45 @@ where /// ``` fn sz_edit_distance(&self, other: N) -> usize; + /// Computes the Levenshtein edit distance between `self` and `other`. 
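+    /// Unlike `sz_edit_distance`, both strings are interpreted as UTF-8 and edits are counted in
+    /// code points rather than bytes: "façade" and "facade" differ by one code point, but by two bytes.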
+ /// + /// # Examples + /// + /// ``` + /// use stringzilla::StringZilla; + /// + /// let first = "kitten"; + /// let second = "sitting"; + /// assert_eq!(first.sz_edit_distance_utf8(second.as_bytes()), 3); + /// ``` + fn sz_edit_distance_utf8(&self, other: N) -> usize; + + /// Computes the bounded Levenshtein edit distance between `self` and `other`. + /// + /// # Examples + /// + /// ``` + /// use stringzilla::StringZilla; + /// + /// let first = "kitten"; + /// let second = "sitting"; + /// assert_eq!(first.sz_edit_distance_bounded(second.as_bytes()), 3); + /// ``` + fn sz_edit_distance_bounded(&self, other: N, bound: usize) -> usize; + + /// Computes the bounded Levenshtein edit distance between `self` and `other`. + /// + /// # Examples + /// + /// ``` + /// use stringzilla::StringZilla; + /// + /// let first = "kitten"; + /// let second = "sitting"; + /// assert_eq!(first.sz_edit_distance_utf8_bounded(second.as_bytes()), 3); + /// ``` + fn sz_edit_distance_utf8_bounded(&self, other: N, bound: usize) -> usize; + /// Computes the alignment score between `self` and `other` using the specified /// substitution matrix and gap penalty. /// @@ -1231,7 +1345,6 @@ where /// assert_eq!(matches, vec![b"!", b"d", b"l", b"r", b"w", b" ", b",", b"l", b"l", b"H"]); /// ``` fn sz_find_last_not_of(&'a self, needles: &'a N) -> RangeRMatches<'a>; - } impl<'a, T, N> StringZilla<'a, N> for T @@ -1239,6 +1352,14 @@ where T: AsRef<[u8]> + ?Sized, N: AsRef<[u8]> + 'a, { + fn sz_checksum(&self) -> u64 { + sz::checksum(self) + } + + fn sz_hash(&self) -> u64 { + sz::hash(self) + } + fn sz_find(&self, needle: N) -> Option { sz::find(self, needle) } @@ -1267,6 +1388,18 @@ where sz::edit_distance(self, other) } + fn sz_edit_distance_utf8(&self, other: N) -> usize { + sz::edit_distance_utf8(self, other) + } + + fn sz_edit_distance_bounded(&self, other: N, bound: usize) -> usize { + sz::edit_distance_bounded(self, other, bound) + } + + fn sz_edit_distance_utf8_bounded(&self, other: N, bound: usize) -> usize { + sz::edit_distance_utf8_bounded(self, other, bound) + } + fn sz_alignment_score(&self, other: N, matrix: [[i8; 256]; 256], gap: i8) -> isize { sz::alignment_score(self, other, matrix, gap) } From 62ca6a0e4635cd251bc97530b10438aa13a08eb5 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 09:12:02 +0000 Subject: [PATCH 006/751] Fix: Default Levenshtein upper bound --- python/lib.c | 7 +++---- swift/StringProtocol+StringZilla.swift | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/python/lib.c b/python/lib.c index dcf96625..c5346772 100644 --- a/python/lib.c +++ b/python/lib.c @@ -1858,8 +1858,8 @@ static PyObject *_Str_edit_distance(PyObject *self, PyObject *args, PyObject *kw return NULL; } - Py_ssize_t bound = 0; // Default value for bound - if (bound_obj && ((bound = PyLong_AsSsize_t(bound_obj)) < 0)) { + sz_size_t bound = SZ_SIZE_MAX; // Default value for bound + if (bound_obj && ((bound = (sz_size_t)PyLong_AsSize_t(bound_obj)) == (sz_size_t)(-1))) { PyErr_Format(PyExc_ValueError, "Bound must be a non-negative integer"); return NULL; } @@ -1877,8 +1877,7 @@ static PyObject *_Str_edit_distance(PyObject *self, PyObject *args, PyObject *kw reusing_allocator.free = &temporary_memory_free; reusing_allocator.handle = &temporary_memory; - sz_size_t distance = - function(str1.start, str1.length, str2.start, str2.length, (sz_size_t)bound, &reusing_allocator); + sz_size_t distance = function(str1.start, str1.length, str2.start, 
str2.length, bound, &reusing_allocator); // Check for memory allocation issues if (distance == SZ_SIZE_MAX) { diff --git a/swift/StringProtocol+StringZilla.swift b/swift/StringProtocol+StringZilla.swift index 0f7b36bc..d90c8afc 100644 --- a/swift/StringProtocol+StringZilla.swift +++ b/swift/StringProtocol+StringZilla.swift @@ -255,7 +255,7 @@ public extension StringZillaViewable { /// - Throws: If a memory allocation error has happened. @_specialize(where Self == String, S == String) @_specialize(where Self == String.UTF8View, S == String.UTF8View) - func editDistance(from other: S, bound: UInt64 = 0) throws -> UInt64? { + func editDistance(from other: S, bound: UInt64 = UInt64.max) throws -> UInt64? { var result: UInt64? // Use a do-catch block to handle potential errors From 0ee549a106b1ee524fa8059888219c03635e11e6 Mon Sep 17 00:00:00 2001 From: Govind Date: Sat, 7 Dec 2024 12:08:15 +0100 Subject: [PATCH 007/751] Make: Inline ASM for detecting CPU features on ARM Closes #143 --- c/lib.c | 49 ++++++++++++++++++++++++------- include/stringzilla/stringzilla.h | 5 ++-- 2 files changed, 40 insertions(+), 14 deletions(-) diff --git a/c/lib.c b/c/lib.c index ee48400e..f38ac534 100644 --- a/c/lib.c +++ b/c/lib.c @@ -38,6 +38,43 @@ extern void *malloc(size_t length); #endif #endif +// On Apple Silicon, `mrs` is not allowed in user-space, so we need to use the `sysctl` API. +#if defined(__APPLE__) && defined(__MACH__) +#define SZ_APPLE 1 +#include +#endif + +#if defined(__linux__) +#define SZ_LINUX 1 +#endif + +SZ_INTERNAL sz_capability_t sz_capabilities_arm(void) { + // https://github.com/ashvardanian/SimSIMD/blob/28e536083602f85ad0c59456782c8864463ffb0e/include/simsimd/simsimd.h#L434 + // for documentation on how we detect capabilities across different ARM platforms. +#if defined(SZ_APPLE) + + // On Apple Silicon, `mrs` is not allowed in user-space, so we need to use the `sysctl` API. 
+ uint32_t supports_neon = 0; + size_t size = sizeof(supports_neon); + if (sysctlbyname("hw.optional.neon", &supports_neon, &size, NULL, 0) != 0) supports_neon = 0; + + return (sz_capability_t)( // + (sz_cap_arm_neon_k * (supports_neon)) | // + (sz_cap_serial_k)); + +#elif defined(SZ_LINUX) + unsigned supports_neon = 1; // NEON is always supported + __asm__ __volatile__("mrs %0, ID_AA64PFR0_EL1" : "=r"(id_aa64pfr0_el1)); + unsigned supports_sve = ((id_aa64pfr0_el1 >> 32) & 0xF) >= 1; + return (sz_capability_t)( // + (sz_cap_neon_k * (supports_neon)) | // + (sz_cap_sve_k * (supports_sve)) | // + (sz_cap_serial_k)); +#else // SIMSIMD_DEFINED_LINUX + return sz_cap_serial_k; +#endif +} + SZ_DYNAMIC sz_capability_t sz_capabilities(void) { #if SZ_USE_X86_AVX512 || SZ_USE_X86_AVX2 @@ -96,22 +133,12 @@ SZ_DYNAMIC sz_capability_t sz_capabilities(void) { #if SZ_USE_ARM_NEON || SZ_USE_ARM_SVE - // Every 64-bit Arm CPU supports NEON - unsigned supports_neon = 1; - unsigned supports_sve = 0; - unsigned supports_sve2 = 0; - sz_unused(supports_sve); - sz_unused(supports_sve2); - - return (sz_capability_t)( // - (sz_cap_arm_neon_k * supports_neon) | // - (sz_cap_serial_k)); + return sz_capabilities_arm(); #endif // SIMSIMD_TARGET_ARM return sz_cap_serial_k; } - typedef struct sz_implementations_t { sz_equal_t equal; sz_order_t order; diff --git a/include/stringzilla/stringzilla.h b/include/stringzilla/stringzilla.h index 7aa9e6da..588a3282 100644 --- a/include/stringzilla/stringzilla.h +++ b/include/stringzilla/stringzilla.h @@ -260,7 +260,8 @@ typedef enum sz_capability_t { sz_cap_arm_neon_k = 1 << 10, /// ARM NEON capability sz_cap_arm_sve_k = 1 << 11, /// ARM SVE capability TODO: Not yet supported or used - + sz_cap_arm_sve2_k = 1 << 12, + sz_cap_arm_sve2p1_k = 1 << 13, sz_cap_x86_avx2_k = 1 << 20, /// x86 AVX2 capability sz_cap_x86_avx512f_k = 1 << 21, /// x86 AVX512 F capability sz_cap_x86_avx512bw_k = 1 << 22, /// x86 AVX512 BW instruction capability @@ -268,8 +269,6 @@ typedef enum sz_capability_t { sz_cap_x86_avx512vbmi_k = 1 << 24, /// x86 AVX512 VBMI instruction capability sz_cap_x86_gfni_k = 1 << 25, /// x86 AVX512 GFNI instruction capability - sz_cap_x86_avx512vbmi2_k = 1 << 26, /// x86 AVX512 VBMI 2 instruction capability - } sz_capability_t; /** From 43471aa8131d17a6d6a4bf521b5a99aa2b59bd54 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 11:11:42 +0000 Subject: [PATCH 008/751] Add: New Levenshtein distance kernels --- README.md | 21 +- include/stringzilla/stringzilla.h | 501 +++++++++++++++++---- scripts/test.cpp | 23 +- scripts/test_levenshtein.ipynb | 706 +++++++++++++++++++++--------- 4 files changed, 957 insertions(+), 294 deletions(-) diff --git a/README.md b/README.md index dbfd3f9b..d5c59ff9 100644 --- a/README.md +++ b/README.md @@ -1367,6 +1367,20 @@ Other algorithms previously considered and deprecated: > [Exact String Matching Algorithms in Java](https://www-igm.univ-mlv.fr/~lecroq/string). > [SIMD-friendly algorithms for substring searching](http://0x80.pl/articles/simd-strfind.html). +### Exact Multiple Substring Search + +Few algorithms for multiple substring search are known. +Most are based on the Aho-Corasick automaton, which is a generalization of the KMP algorithm. +The naive implementation, however: + +- Allocates disjoint memory for each Trie node and Automaton state. +- Requires a lot of pointer chasing, limiting speculative execution. 
+- Has a lot of branches and conditional moves, which are hard to predict. +- Matches text a character at a time, which is slow on modern CPUs. + +There are several ways to improve the original algorithm. +One is to use sparse DFA representation, which is more cache-friendly, but would require extra processing to navigate state transitions. + ### Levenshtein Edit Distance Levenshtein distance is the best known edit-distance for strings, that checks, how many insertions, deletions, and substitutions are needed to transform one string to another. @@ -1388,10 +1402,11 @@ It's less known, than the others, derived from the Baeza-Yates-Gonnet algorithm, StringZilla introduces a different approach, extensively used in Unum's internal combinatorial optimization libraries. The approach doesn't change the number of trivial operations, but performs them in a different order, removing the data dependency, that occurs when computing the insertion costs. This results in much better vectorization for intra-core parallelism and potentially multi-core evaluation of a single request. +Moreover, it's easy to generalize to weighted edit-distances, where the cost of a substitution between two characters may not be the same for all pairs, often used in bioinformatics. Next design goals: -- [ ] Generalize fast traversals to rectangular matrices. +- [x] Generalize fast traversals to non-square matrices. - [ ] Port x86 AVX-512 solution to Arm NEON. > § Reading materials. @@ -1425,6 +1440,10 @@ With that solved, the SIMD implementation will become 5x faster than the serial [faq-dipeptide]: https://en.wikipedia.org/wiki/Dipeptide [faq-titin]: https://en.wikipedia.org/wiki/Titin +Next design goals: + +- [ ] Needleman-Wunsch Automata + ### Memory Copying, Fills, and Moves A lot has been written about the time computers spend copying memory and how that operation is implemented in LibC. diff --git a/include/stringzilla/stringzilla.h b/include/stringzilla/stringzilla.h index 7aa9e6da..b6622c27 100644 --- a/include/stringzilla/stringzilla.h +++ b/include/stringzilla/stringzilla.h @@ -1,7 +1,7 @@ /** - * @brief StringZilla is a collection of simple string algorithms, designed to be used in Big Data applications. - * It may be slower than LibC, but has a broader & cleaner interface, and a very short implementation - * targeting modern x86 CPUs with AVX-512 and Arm NEON and older CPUs with SWAR and auto-vectorization. + * @brief StringZilla is a collection of advanced string algorithms, designed to be used in Big Data applications. + * It is generally faster than LibC, and has a broader & cleaner interface, and targets modern x86 CPUs + * with AVX-512 and Arm NEON and older CPUs with SWAR and auto-vectorization. 
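+ *        The SIMD backends can be disabled individually at compile time via the `SZ_USE_*` macros,
+ *        and the shared library in `c/lib.c` also dispatches between them at runtime using `sz_capabilities`.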
* * Consider overriding the following macros to customize the library: * @@ -843,12 +843,12 @@ SZ_PUBLIC sz_cptr_t sz_rfind_charset_serial(sz_cptr_t text, sz_size_t length, sz * @see sz_hamming_distance_utf8 * @see https://en.wikipedia.org/wiki/Hamming_distance */ -SZ_DYNAMIC sz_size_t sz_hamming_distance(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, - sz_size_t bound); +SZ_DYNAMIC sz_size_t sz_hamming_distance( // + sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound); /** @copydoc sz_hamming_distance */ -SZ_PUBLIC sz_size_t sz_hamming_distance_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, - sz_size_t bound); +SZ_PUBLIC sz_size_t sz_hamming_distance_serial( // + sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound); /** * @brief Computes the Hamming distance between two @b UTF8 strings - number of not matching characters. @@ -887,10 +887,11 @@ typedef sz_size_t (*sz_hamming_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_s * @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated, * so the memory usage is linear in relation to ::a_length and ::b_length. * If SZ_NULL is passed, will initialize to the systems default `malloc`. - * @param bound Upper bound on the distance, that allows us to exit early. - * If zero is passed, the maximum possible distance will be equal to the length of the longer input. - * @return Unsigned integer for edit distance, the `bound` if was exceeded or `SZ_SIZE_MAX` - * if the memory allocation failed. + * @param bound Exclusive upper bound on the distance, that allows us to exit early. + * Pass `SZ_SIZE_MAX` or any value greater than `(max(a_length, b_length))` to ignore. + * Pass zero to check if the strings are equal. + * @return Unsigned integer for the edit distance. Zero means the strings are equal. + * Returns the `bound` if it was exceeded or `SZ_SIZE_MAX` if the memory allocation failed. * * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default * @see https://en.wikipedia.org/wiki/Levenshtein_distance @@ -1022,8 +1023,9 @@ typedef void (*sz_hashes_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_size_t, sz_hash_ * @param window_length Length of the rolling window in bytes. * @see sz_hashes, sz_hashes_intersection */ -SZ_PUBLIC void sz_hashes_fingerprint(sz_cptr_t text, sz_size_t length, sz_size_t window_length, // - sz_ptr_t fingerprint, sz_size_t fingerprint_bytes); +SZ_PUBLIC void sz_hashes_fingerprint( // + sz_cptr_t text, sz_size_t length, sz_size_t window_length, // + sz_ptr_t fingerprint, sz_size_t fingerprint_bytes); typedef void (*sz_hashes_fingerprint_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_ptr_t, sz_size_t); @@ -1041,8 +1043,9 @@ typedef void (*sz_hashes_fingerprint_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_ptr_ * @param window_length Length of the rolling window in bytes. * @see sz_hashes, sz_hashes_fingerprint */ -SZ_PUBLIC sz_size_t sz_hashes_intersection(sz_cptr_t text, sz_size_t length, sz_size_t window_length, // - sz_cptr_t fingerprint, sz_size_t fingerprint_bytes); +SZ_PUBLIC sz_size_t sz_hashes_intersection( // + sz_cptr_t text, sz_size_t length, sz_size_t window_length, // + sz_cptr_t fingerprint, sz_size_t fingerprint_bytes); typedef sz_size_t (*sz_hashes_intersection_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_cptr_t, sz_size_t); @@ -1773,8 +1776,8 @@ SZ_INTERNAL void _sz_locate_needle_anomalies(sz_cptr_t start, sz_size_t length, // TODO: Investigate alternative strategies for long needles. 
// On very long needles we have the luxury to choose! - // Often dealing with UTF8, we will likely benfit from shifting the first and second characters - // further to the right, to achieve not only uniqness within the needle, but also avoid common + // Often dealing with UTF8, we will likely benefit from shifting the first and second characters + // further to the right, to achieve not only uniqueness within the needle, but also avoid common // rune prefixes of 2-, 3-, and 4-byte codes. if (length > 8) { // Pivot the first and second points right, until we find a character, that: @@ -1788,7 +1791,7 @@ SZ_INTERNAL void _sz_locate_needle_anomalies(sz_cptr_t start, sz_size_t length, sz_u8_t const *start_u8 = (sz_u8_t const *)start; sz_size_t vibrant_first = *first, vibrant_second = *second, vibrant_third = *third; - // Let's begin with the seccond character, as the termination criterea there is more obvious + // Let's begin with the seccond character, as the termination criteria there is more obvious // and we may end up with more variants to check for the first candidate. for (; (start_u8[vibrant_second] > 191 || start_u8[vibrant_second] == start_u8[vibrant_third]) && (vibrant_second + 1 < vibrant_third); @@ -2455,18 +2458,18 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_serial( // current_distances[0] = current_distances[1] = 1; // Progress through the upper triangle of the Levenshtein matrix. - sz_size_t next_skew_diagonal_index = 2; - for (; next_skew_diagonal_index != n; ++next_skew_diagonal_index) { - sz_size_t const next_skew_diagonal_length = next_skew_diagonal_index + 1; - for (sz_size_t i = 0; i + 2 < next_skew_diagonal_length; ++i) { - sz_size_t cost_of_substitution = shorter[next_skew_diagonal_index - i - 2] != longer[i]; + sz_size_t next_diagonal_index = 2; + for (; next_diagonal_index != n; ++next_diagonal_index) { + sz_size_t const next_diagonal_length = next_diagonal_index + 1; + for (sz_size_t i = 0; i + 2 < next_diagonal_length; ++i) { + sz_size_t cost_of_substitution = shorter[next_diagonal_index - i - 2] != longer[i]; sz_size_t cost_if_substitution = previous_distances[i] + cost_of_substitution; sz_size_t cost_if_deletion_or_insertion = sz_min_of_two(current_distances[i], current_distances[i + 1]) + 1; next_distances[i + 1] = sz_min_of_two(cost_if_deletion_or_insertion, cost_if_substitution); } - // Don't forget to populate the first row and the fiest column of the Levenshtein matrix. - next_distances[0] = next_distances[next_skew_diagonal_length - 1] = next_skew_diagonal_index; - // Perform a circular rotarion of those buffers, to reuse the memory. + // Don't forget to populate the first row and the first column of the Levenshtein matrix. + next_distances[0] = next_distances[next_diagonal_length - 1] = next_diagonal_index; + // Perform a circular rotation of those buffers, to reuse the memory. sz_size_t *temporary = previous_distances; previous_distances = current_distances; current_distances = next_distances; @@ -2476,17 +2479,16 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_serial( // // By now we've scanned through the upper triangle of the matrix, where each subsequent iteration results in a // larger diagonal. From now onwards, we will be shrinking. Instead of adding value equal to the skewed diagonal // index on either side, we will be cropping those values out. 
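+    // As a concrete example, for two 3-byte strings `n == 4`, so there are `2 * 4 - 1 == 7` skewed diagonals
+    // with lengths 1, 2, 3, 4, 3, 2, 1. Diagonals 0 and 1 are pre-filled, the loop above covered diagonals
+    // 2 and 3, and the loop below shrinks through diagonals 4, 5, and 6.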
- sz_size_t total_diagonals = n + n - 1; - for (; next_skew_diagonal_index != total_diagonals; ++next_skew_diagonal_index) { - sz_size_t const next_skew_diagonal_length = total_diagonals - next_skew_diagonal_index; - for (sz_size_t i = 0; i != next_skew_diagonal_length; ++i) { - sz_size_t cost_of_substitution = - shorter[shorter_length - 1 - i] != longer[next_skew_diagonal_index - n + i]; + sz_size_t diagonals_count = n + n - 1; + for (; next_diagonal_index != diagonals_count; ++next_diagonal_index) { + sz_size_t const next_diagonal_length = diagonals_count - next_diagonal_index; + for (sz_size_t i = 0; i != next_diagonal_length; ++i) { + sz_size_t cost_of_substitution = shorter[shorter_length - 1 - i] != longer[next_diagonal_index - n + i]; sz_size_t cost_if_substitution = previous_distances[i] + cost_of_substitution; sz_size_t cost_if_deletion_or_insertion = sz_min_of_two(current_distances[i], current_distances[i + 1]) + 1; next_distances[i] = sz_min_of_two(cost_if_deletion_or_insertion, cost_if_substitution); } - // Perform a circular rotarion of those buffers, to reuse the memory, this time, with a shift, + // Perform a circular rotation of those buffers, to reuse the memory, this time, with a shift, // dropping the first element in the current array. sz_size_t *temporary = previous_distances; previous_distances = current_distances + 1; @@ -2737,7 +2739,8 @@ SZ_PUBLIC sz_size_t sz_edit_distance_serial( // --longer_length, --shorter_length); // Bounded computations may exit early. - if (bound) { + int const is_bounded = bound < longer_length; + if (is_bounded) { // If one of the strings is empty - the edit distance is equal to the length of the other one. if (longer_length == 0) return sz_min_of_two(shorter_length, bound); if (shorter_length == 0) return sz_min_of_two(longer_length, bound); @@ -2746,7 +2749,7 @@ SZ_PUBLIC sz_size_t sz_edit_distance_serial( // } if (shorter_length == 0) return longer_length; // If no mismatches were found - the distance is zero. - if (shorter_length == longer_length && !bound) + if (shorter_length == longer_length && !is_bounded) return _sz_edit_distance_skewed_diagonals_serial(longer, longer_length, shorter, shorter_length, bound, alloc); return _sz_edit_distance_wagner_fisher_serial(longer, longer_length, shorter, shorter_length, bound, sz_false_k, alloc); @@ -4555,10 +4558,10 @@ SZ_PUBLIC void sz_hashes_avx2(sz_cptr_t start, sz_size_t length, sz_size_t windo * @brief AVX-512 implementation of the string search algorithms. 
* * Different subsets of AVX-512 were introduced in different years: - * * 2017 SkyLake: F, CD, ER, PF, VL, DQ, BW - * * 2018 CannonLake: IFMA, VBMI - * * 2019 IceLake: VPOPCNTDQ, VNNI, VBMI2, BITALG, GFNI, VPCLMULQDQ, VAES - * * 2020 TigerLake: VP2INTERSECT + * - 2017 SkyLake: F, CD, ER, PF, VL, DQ, BW + * - 2018 CannonLake: IFMA, VBMI + * - 2019 IceLake: VPOPCNTDQ, VNNI, VBMI2, BITALG, GFNI, VPCLMULQDQ, VAES + * - 2020 TigerLake: VP2INTERSECT */ #pragma region AVX512 Implementation @@ -5130,11 +5133,269 @@ SZ_PUBLIC sz_cptr_t sz_rfind_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n return SZ_NULL_CHAR; } +#pragma clang attribute pop +#pragma GCC pop_options + +#pragma GCC push_options +#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vbmi", "bmi", "bmi2") +#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512dq,avx512vbmi,bmi,bmi2"))), \ + apply_to = function) + +/** + * @brief Computes the edit distance between two very short byte-strings using the AVX-512VBMI extensions. + * + * Applies to string lengths up to 63, and evaluates at most (63 * 2 + 1 = 127) diagonals, or just as many loop cycles. + * Supports an early exit, if the distance is bounded. + * Keeps all of the data and Levenshtein matrices skew diagonal in just a couple of registers. + * Benefits from the @b `vpermb` instructions, that can rotate the bytes across the entire ZMM register. + */ +SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto63_avx512( // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_cptr_t longer, sz_size_t longer_length, // + sz_size_t bound) { + + sz_size_t const max_length = 63u; + sz_assert(shorter_length <= longer_length && "The 'shorter' string is longer than the 'longer' one."); + sz_assert(shorter_length < max_length && "The length must fit into 16-bit integer. Otherwise use serial variant."); + + // We are going to store 3 diagonals of the matrix, assuming each would fit into a single ZMM register. + // The length of the longest (main) diagonal would be `shorter_dim = (shorter_length + 1)`. + sz_size_t const shorter_dim = shorter_length + 1; + sz_size_t const longer_dim = longer_length + 1; + + // The next few buffers will be swapped around. + sz_u512_vec_t previous_vec, current_vec, next_vec; + sz_u512_vec_t gaps_vec, substitutions_vec; + + // Load the strings into ZMM registers - just once. + sz_u512_vec_t longer_vec, shorter_vec, shorter_rotated_vec, rotate_left_vec, rotate_right_vec, ones_vec, bound_vec; + longer_vec.zmm = _mm512_maskz_loadu_epi8(_sz_u64_mask_until(longer_length), longer); + rotate_left_vec.zmm = _mm512_set_epi8( // + 0, 63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, // + 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, // + 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, // + 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1); + rotate_right_vec.zmm = _mm512_set_epi8( // + 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, // + 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, // + 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, // + 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 63); + ones_vec.zmm = _mm512_set1_epi8(1); + bound_vec.zmm = _mm512_set1_epi8(bound <= 255 ? (sz_u8_t)bound : 255); + + // To simplify comparisons and traversals, we want to reverse the order of bytes in the shorter string. 
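+    // For example, a 3-byte string "abc" lands with 'a' in byte 63, 'b' in byte 62, and 'c' in byte 61 of the
+    // register, so the loops below can slide it against the longer string with cheap byte-rotations.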
+ for (sz_size_t i = 0; i != shorter_length; ++i) shorter_vec.u8s[63 - i] = shorter[i]; + shorter_rotated_vec.zmm = _mm512_permutexvar_epi8(rotate_right_vec.zmm, shorter_vec.zmm); + + // Let's say we are dealing with 3 and 5 letter words. + // The matrix will have size 4 x 6, parameterized as (shorter_dim x longer_dim). + // It will have: + // - 4 diagonals of increasing length, at positions: 0, 1, 2, 3. + // - 2 diagonals of fixed length, at positions: 4, 5. + // - 3 diagonals of decreasing length, at positions: 6, 7, 8. + sz_size_t const diagonals_count = shorter_dim + longer_dim - 1; + + // Initialize the first two diagonals: + // + // previous_vec.u8s[0] = 0; + // current_vec.u8s[0] = current_vec.u8s[1] = 1; + // + // We can do a similar thing with vector ops: + previous_vec.zmm = _mm512_setzero_si512(); + current_vec.zmm = _mm512_set1_epi8(1); + + // We skip diagonals 0 and 1, as they are trivial. + // We will start with diagonal 2, which has length 3, with the first and last elements being preset, + // so we are effectively computing just one value, as will be marked by a single set bit in + // the `next_diagonal_mask` on the very first iteration. + sz_size_t next_diagonal_index = 2; + __mmask64 next_diagonal_mask = 0; + + // Progress through the upper triangle of the Levenshtein matrix. + for (; next_diagonal_index != shorter_dim; ++next_diagonal_index) { + // After this iteration, the values at offset `0` and `next_diagonal_index` in the `next_vec` + // should be set to `next_diagonal_index`, but it's easier to broadcast the value to the whole vector, + // and later merge with a mask with new values. + next_vec.zmm = _mm512_set1_epi8((sz_u8_t)next_diagonal_index); + + // The mask also adds one set bit. + next_diagonal_mask = _kor_mask64(next_diagonal_mask, 1); + next_diagonal_mask = _kshiftli_mask64(next_diagonal_mask, 1); + + // Check for equality between string slices. + __mmask64 conflict_mask = _mm512_cmpneq_epi8_mask(longer_vec.zmm, shorter_rotated_vec.zmm); + substitutions_vec.zmm = _mm512_mask_add_epi8(previous_vec.zmm, conflict_mask, previous_vec.zmm, ones_vec.zmm); + substitutions_vec.zmm = _mm512_permutexvar_epi8(rotate_right_vec.zmm, substitutions_vec.zmm); + gaps_vec.zmm = _mm512_add_epi8( + // Insertions or deletions + _mm512_min_epu8(_mm512_permutexvar_epi8(rotate_right_vec.zmm, current_vec.zmm), current_vec.zmm), + ones_vec.zmm); + next_vec.zmm = _mm512_mask_min_epu8(next_vec.zmm, next_diagonal_mask, gaps_vec.zmm, substitutions_vec.zmm); + + // Mark the current skewed diagonal as the previous one and the next one as the current one. + previous_vec.zmm = current_vec.zmm; + current_vec.zmm = next_vec.zmm; + + // Shift the shorter string + shorter_rotated_vec.zmm = _mm512_permutexvar_epi8(rotate_right_vec.zmm, shorter_rotated_vec.zmm); + + // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound. + __mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm); + if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) { // + return SZ_SIZE_MAX; + } + } + + // Now let's handle the anti-diagonal band of the matrix, between the top and bottom triangles. + for (; next_diagonal_index != longer_dim; ++next_diagonal_index) { + // After this iteration, the value `shorted_dim - 1` in the `next_vec` + // should be set to `next_diagonal_index`, but it's easier to broadcast the value to the whole vector, + // and later merge with a mask with new values. 
+ next_vec.zmm = _mm512_set1_epi8((sz_u8_t)next_diagonal_index); + + // Make sure we update the first entry. + next_diagonal_mask = _kor_mask64(next_diagonal_mask, 1); + + // Check for equality between string slices. + __mmask64 conflict_mask = _mm512_cmpneq_epi8_mask(longer_vec.zmm, shorter_rotated_vec.zmm); + substitutions_vec.zmm = _mm512_mask_add_epi8(previous_vec.zmm, conflict_mask, previous_vec.zmm, ones_vec.zmm); + gaps_vec.zmm = _mm512_add_epi8( + // Insertions or deletions + _mm512_min_epu8(current_vec.zmm, _mm512_permutexvar_epi8(rotate_left_vec.zmm, current_vec.zmm)), + ones_vec.zmm); + next_vec.zmm = _mm512_mask_min_epu8(next_vec.zmm, next_diagonal_mask, gaps_vec.zmm, substitutions_vec.zmm); + + // Mark the current skewed diagonal as the previous one and the next one as the current one. + previous_vec.zmm = _mm512_permutexvar_epi8(rotate_left_vec.zmm, current_vec.zmm); + current_vec.zmm = next_vec.zmm; + + // Let's shift the longer string now. + longer_vec.zmm = _mm512_permutexvar_epi8(rotate_left_vec.zmm, longer_vec.zmm); + + // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound. + __mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm); + if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) { // + return SZ_SIZE_MAX; + } + } + + // Now let's handle the bottom right triangle. + for (; next_diagonal_index != diagonals_count; ++next_diagonal_index) { + + // Check for equality between string slices. + __mmask64 conflict_mask = _mm512_cmpneq_epi8_mask(longer_vec.zmm, shorter_rotated_vec.zmm); + substitutions_vec.zmm = _mm512_mask_add_epi8(previous_vec.zmm, conflict_mask, previous_vec.zmm, ones_vec.zmm); + gaps_vec.zmm = _mm512_add_epi8( + // Insertions or deletions + _mm512_min_epu8(current_vec.zmm, _mm512_permutexvar_epi8(rotate_left_vec.zmm, current_vec.zmm)), + ones_vec.zmm); + next_vec.zmm = _mm512_min_epu8(gaps_vec.zmm, substitutions_vec.zmm); + + // Mark the current skewed diagonal as the previous one and the next one as the current one. + previous_vec.zmm = _mm512_permutexvar_epi8(rotate_left_vec.zmm, current_vec.zmm); + current_vec.zmm = next_vec.zmm; + + // Let's shift the longer string now. + longer_vec.zmm = _mm512_permutexvar_epi8(rotate_left_vec.zmm, longer_vec.zmm); + + // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound. + __mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm); + if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) { // + return SZ_SIZE_MAX; + } + // In every following iterations we take use a shorter prefix of each register, + // but we don't need to update the `next_diagonal_mask` anymore... except for the early exit. + next_diagonal_mask = _kshiftri_mask64(next_diagonal_mask, 1); + } + return current_vec.u8s[0]; +} + +/** + * @brief Computes the edit distance between two somewhat short bytes-strings using the AVX-512VBMI extensions. + * + * Applies to string lengths up to 127, and evaluates at most (127 * 2 + 1 = 255) diagonals. + * Supports an early exit, if the distance is bounded. + * Uses a lot more CPU registers space, than the `upto63` variant. + * Benefits from the @b `vpermi2b` instructions, that can rotate the bytes in 2 registers at once. + * + * This may be one of the most freuqently called kernels for: + * - source code analysis, assuming most lines are either under 80 or under 120 characters long. 
+ *  - DNA sequence alignment, as most short reads are 50-300 characters long.
+ */
+SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto127_avx512( //
+    sz_cptr_t shorter, sz_size_t shorter_length,                         //
+    sz_cptr_t longer, sz_size_t longer_length,                           //
+    sz_size_t bound) {
+    sz_unused(shorter && shorter_length && longer && longer_length && bound);
+    return 0;
+}
+
+/**
+ *  @brief  Computes the edit distance between two longer byte-strings using the AVX-512VBMI extensions.
+ *
+ *  Applies to string lengths up to 255, and evaluates at most (255 * 2 + 1 = 511) diagonals.
+ *  Supports an early exit, if the distance is bounded.
+ *  Uses much more CPU register space than the `upto63` variant.
+ *
+ *  Each of the 2 strings ends up occupying 4 ZMM registers, and each of the 3 diagonals uses 4 ZMM registers.
+ *  So 20 of the 32 registers are persistently occupied, and the rest are used for math temporarily.
+ *  This is the largest space-efficient variant, as strings beyond 255 characters may require
+ *  16-bit accumulators, which would be a significant bottleneck.
+ */
+SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto_avx512( //
+    sz_cptr_t shorter, sz_size_t shorter_length,                      //
+    sz_cptr_t longer, sz_size_t longer_length,                        //
+    sz_size_t bound) {
+    sz_unused(shorter && shorter_length && longer && longer_length && bound);
+    return 0;
+}
+
+/**
+ *  @brief  Computes the edit distance between two longer byte-strings using the AVX-512VBMI extensions,
+ *          assuming the upper distance bound cannot exceed 255, but the string length can be arbitrary.
+ *
+ *  Applies to string lengths up to 255, and evaluates at most (255 * 2 + 1 = 511) diagonals.
+ *  Supports an early exit, if the distance is bounded.
+ *  Uses much more CPU register space than the `upto63` variant.
+ *
+ *  Each of the 2 strings ends up occupying 4 ZMM registers, and each of the 3 diagonals uses 4 ZMM registers.
+ *  So 20 of the 32 registers are persistently occupied, and the rest are used for math temporarily.
+ *  This is the largest space-efficient variant, as strings beyond 255 characters may require
+ *  16-bit accumulators, which would be a significant bottleneck.
+ */
+SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto255bound_avx512( //
+    sz_cptr_t shorter, sz_size_t shorter_length,                             //
+    sz_cptr_t longer, sz_size_t longer_length,                               //
+    sz_size_t bound) {
+    sz_unused(shorter && shorter_length && longer && longer_length && bound);
+    return 0;
+}
+
+/**
+ *  @brief  Computes the edit distance between two mid-length UTF-8 strings using the AVX-512VBMI extensions.
+ *
+ *  Applies to string lengths up to 127, and evaluates at most (127 * 2 + 1 = 255) diagonals.
+ *  Supports an early exit, if the distance is bounded.
+ *  Benefits from the @b `valignd` instruction used to rotate the unpacked UTF-32 Unicode codepoints.
+ *
+ *  Each string is unpacked into 128 characters * 4 bytes per character / 64 bytes per register = 8 registers.
+ * + */ +SZ_INTERNAL sz_size_t _sz_edit_distance_utf8_skewed_diagonals_upto127_avx512( // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_cptr_t longer, sz_size_t longer_length, // + sz_size_t bound) { + sz_unused(shorter && shorter_length && longer && longer_length && bound); + return 0; +} + SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto65k_avx512( // sz_cptr_t shorter, sz_size_t shorter_length, // sz_cptr_t longer, sz_size_t longer_length, // sz_size_t bound, sz_memory_allocator_t *alloc) { + sz_unused(shorter && longer && bound && alloc); + // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. sz_memory_allocator_t global_alloc; if (!alloc) { @@ -5143,25 +5404,27 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto65k_avx512( // } // TODO: Generalize! - sz_size_t max_length = 256u * 256u; - sz_assert(!bound && "For bounded search the method should only evaluate one band of the matrix."); - sz_assert(shorter_length == longer_length && "The method hasn't been generalized to different length inputs yet."); + sz_size_t const max_length = 256u * 256u; + sz_assert(shorter_length <= longer_length && "The 'shorter' string is longer than the 'longer' one."); sz_assert(shorter_length < max_length && "The length must fit into 16-bit integer. Otherwise use serial variant."); sz_unused(longer_length && bound && max_length); +#if 0 // We are going to store 3 diagonals of the matrix. - // The length of the longest (main) diagonal would be `n = (shorter_length + 1)`. - sz_size_t n = shorter_length + 1; + // The length of the longest (main) diagonal would be `shorter_dim = (shorter_length + 1)`. + sz_size_t const shorter_dim = shorter_length + 1; + sz_size_t const longer_dim = longer_length + 1; // Unlike the serial version, we also want to avoid reverse-order iteration over teh shorter string. // So let's allocate a bit more memory and reverse-export our shorter string into that buffer. - sz_size_t buffer_length = sizeof(sz_u16_t) * n * 3 + shorter_length; - sz_u16_t *distances = (sz_u16_t *)alloc->allocate(buffer_length, alloc->handle); + sz_size_t const buffer_length = sizeof(sz_u16_t) * longer_dim * 3 + shorter_length; + sz_u16_t *const distances = (sz_u16_t *)alloc->allocate(buffer_length, alloc->handle); if (!distances) return SZ_SIZE_MAX; + // The next few pointers will be swapped around. sz_u16_t *previous_distances = distances; - sz_u16_t *current_distances = previous_distances + n; - sz_u16_t *next_distances = current_distances + n; - sz_ptr_t shorter_reversed = (sz_ptr_t)(next_distances + n); + sz_u16_t *current_distances = previous_distances + longer_dim; + sz_u16_t *next_distances = current_distances + longer_dim; + sz_ptr_t const shorter_reversed = (sz_ptr_t)(next_distances + longer_dim); // Export the reversed string into the buffer. for (sz_size_t i = 0; i != shorter_length; ++i) shorter_reversed[i] = shorter[shorter_length - 1 - i]; @@ -5175,47 +5438,61 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto65k_avx512( // sz_u512_vec_t insertions_vec, deletions_vec, substitutions_vec, next_vec; sz_u512_vec_t ones_u16_vec; ones_u16_vec.zmm = _mm512_set1_epi16(1); + // This is a mixed-precision implementation, using 8-bit representations for part of the operations. // Even there, in case `SZ_USE_X86_AVX2=0`, let's use the `sz_u512_vec_t` type, addressing the first YMM halfs. 
sz_u512_vec_t shorter_vec, longer_vec; sz_u512_vec_t ones_u8_vec; ones_u8_vec.ymms[0] = _mm256_set1_epi8(1); + // Let's say we are dealing with 3 and 5 letter words. + // The matrix will have size 4 x 6, parameterized as (shorter_dim x longer_dim). + // It will have: + // - 4 diagonals of increasing length, at positions: 0, 1, 2, 3. + // - 2 diagonals of fixed length, at positions: 4, 5. + // - 3 diagonals of decreasing length, at positions: 6, 7, 8. + sz_size_t const diagonals_count = shorter_dim + longer_dim - 1; + // Progress through the upper triangle of the Levenshtein matrix. - sz_size_t next_skew_diagonal_index = 2; - for (; next_skew_diagonal_index != n; ++next_skew_diagonal_index) { - sz_size_t const next_skew_diagonal_length = next_skew_diagonal_index + 1; - for (sz_size_t i = 0; i + 2 < next_skew_diagonal_length;) { - sz_u32_t remaining_length = (sz_u32_t)(next_skew_diagonal_length - i - 2); + sz_size_t next_diagonal_index = 2; + for (; next_diagonal_index != shorter_dim; ++next_diagonal_index) { + sz_size_t const next_diagonal_length = next_diagonal_index + 1; + for (sz_size_t offset_within_diagonal = 0; offset_within_diagonal + 2 < next_diagonal_length;) { + sz_u32_t remaining_length = (sz_u32_t)(next_diagonal_length - offset_within_diagonal - 2); sz_u32_t register_length = remaining_length < 32 ? remaining_length : 32; sz_u32_t remaining_length_mask = _bzhi_u32(0xFFFFFFFFu, register_length); - longer_vec.ymms[0] = _mm256_maskz_loadu_epi8(remaining_length_mask, longer + i); - // Our original code addressed the shorter string `[next_skew_diagonal_index - i - 2]` for growing `i`. - // If the `shorter` string was reversed, the `[next_skew_diagonal_index - i - 2]` would - // be equal to `[shorter_length - 1 - next_skew_diagonal_index + i + 2]`. - // Which simplified would be equal to `[shorter_length - next_skew_diagonal_index + i + 1]`. - shorter_vec.ymms[0] = _mm256_maskz_loadu_epi8( - remaining_length_mask, shorter_reversed + shorter_length - next_skew_diagonal_index + i + 1); + longer_vec.ymms[0] = _mm256_maskz_loadu_epi8(remaining_length_mask, longer + offset_within_diagonal); + // Our original code addressed the shorter string `[next_diagonal_index - offset_within_diagonal - 2]` + // for growing `offset_within_diagonal`. If the `shorter` string was reversed, the + // `[next_diagonal_index - offset_within_diagonal - 2]` would be equal to `[shorter_length - 1 - + // next_diagonal_index + offset_within_diagonal + 2]`. Which simplified would be equal to + // `[shorter_length - next_diagonal_index + offset_within_diagonal + 1]`. + shorter_vec.ymms[0] = _mm256_maskz_loadu_epi8( // + remaining_length_mask, + shorter_reversed + shorter_length - next_diagonal_index + offset_within_diagonal + 1); // For substitutions, perform the equality comparison using AVX2 instead of AVX-512 // to get the result as a vector, instead of a bitmask. Adding 1 to every scalar we can overflow // transforming from {0xFF, 0} values to {0, 1} values - exactly what we need. Then - upcast to 16-bit. 
substitutions_vec.zmm = _mm512_cvtepi8_epi16( // _mm256_add_epi8(_mm256_cmpeq_epi8(longer_vec.ymms[0], shorter_vec.ymms[0]), ones_u8_vec.ymms[0])); substitutions_vec.zmm = _mm512_add_epi16( // - substitutions_vec.zmm, _mm512_maskz_loadu_epi16(remaining_length_mask, previous_distances + i)); + substitutions_vec.zmm, + _mm512_maskz_loadu_epi16(remaining_length_mask, previous_distances + offset_within_diagonal)); // For insertions and deletions, on modern hardware, it's faster to issue two separate loads, // than rotate the bytes in the ZMM register. - insertions_vec.zmm = _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + i); - deletions_vec.zmm = _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + i + 1); + insertions_vec.zmm = + _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + offset_within_diagonal); + deletions_vec.zmm = + _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + offset_within_diagonal + 1); // First get the minimum of insertions and deletions. next_vec.zmm = _mm512_add_epi16(_mm512_min_epu16(insertions_vec.zmm, deletions_vec.zmm), ones_u16_vec.zmm); next_vec.zmm = _mm512_min_epu16(next_vec.zmm, substitutions_vec.zmm); - _mm512_mask_storeu_epi16(next_distances + i + 1, remaining_length_mask, next_vec.zmm); - i += register_length; + _mm512_mask_storeu_epi16(next_distances + offset_within_diagonal + 1, remaining_length_mask, next_vec.zmm); + offset_within_diagonal += register_length; } - // Don't forget to populate the first row and the fiest column of the Levenshtein matrix. - next_distances[0] = next_distances[next_skew_diagonal_length - 1] = (sz_u16_t)next_skew_diagonal_index; - // Perform a circular rotarion of those buffers, to reuse the memory. + // Don't forget to populate the first row and the first column of the Levenshtein matrix. + next_distances[0] = next_distances[next_diagonal_length - 1] = (sz_u16_t)next_diagonal_index; + // Perform a circular rotation (three-way swap) of those buffers, to reuse the memory. sz_u16_t *temporary = previous_distances; previous_distances = current_distances; current_distances = next_distances; @@ -5225,15 +5502,13 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto65k_avx512( // // By now we've scanned through the upper triangle of the matrix, where each subsequent iteration results in a // larger diagonal. From now onwards, we will be shrinking. Instead of adding value equal to the skewed diagonal // index on either side, we will be cropping those values out. - sz_size_t total_diagonals = n + n - 1; - for (; next_skew_diagonal_index != total_diagonals; ++next_skew_diagonal_index) { - sz_size_t const next_skew_diagonal_length = total_diagonals - next_skew_diagonal_index; - for (sz_size_t i = 0; i != next_skew_diagonal_length;) { - sz_u32_t remaining_length = (sz_u32_t)(next_skew_diagonal_length - i); + for (; next_diagonal_index != diagonals_count; ++next_diagonal_index) { + sz_size_t const next_diagonal_length = diagonals_count - next_diagonal_index; + for (sz_size_t i = 0; i != next_diagonal_length;) { + sz_u32_t remaining_length = (sz_u32_t)(next_diagonal_length - i); sz_u32_t register_length = remaining_length < 32 ? 
remaining_length : 32; sz_u32_t remaining_length_mask = _bzhi_u32(0xFFFFFFFFu, register_length); - longer_vec.ymms[0] = - _mm256_maskz_loadu_epi8(remaining_length_mask, longer + next_skew_diagonal_index - n + i); + longer_vec.ymms[0] = _mm256_maskz_loadu_epi8(remaining_length_mask, longer + next_diagonal_index - n + i); // Our original code addressed the shorter string `[shorter_length - 1 - i]` for growing `i`. // If the `shorter` string was reversed, the `[shorter_length - 1 - i]` would // be equal to `[shorter_length - 1 - shorter_length + 1 + i]`. @@ -5257,7 +5532,7 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto65k_avx512( // i += register_length; } - // Perform a circular rotarion of those buffers, to reuse the memory, this time, with a shift, + // Perform a circular rotation (three-way swap) of those buffers, to reuse the memory, this time, with a shift, // dropping the first element in the current array. sz_u16_t *temporary = previous_distances; previous_distances = current_distances + 1; @@ -5269,6 +5544,8 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto65k_avx512( // sz_size_t result = current_distances[0]; alloc->free(distances, buffer_length, alloc->handle); return result; +#endif + return 0; } SZ_INTERNAL sz_size_t sz_edit_distance_avx512( // @@ -5276,21 +5553,37 @@ SZ_INTERNAL sz_size_t sz_edit_distance_avx512( // sz_cptr_t longer, sz_size_t longer_length, // sz_size_t bound, sz_memory_allocator_t *alloc) { - if (shorter_length == longer_length && !bound && shorter_length && shorter_length < 256u * 256u) - return _sz_edit_distance_skewed_diagonals_upto65k_avx512(shorter, shorter_length, longer, longer_length, bound, - alloc); + // Bounded computations may exit early. + int const is_bounded = bound < longer_length; + if (is_bounded) { + // If one of the strings is empty - the edit distance is equal to the length of the other one. + if (longer_length == 0) return sz_min_of_two(shorter_length, bound); + if (shorter_length == 0) return sz_min_of_two(longer_length, bound); + // If the difference in length is beyond the `bound`, there is no need to check at all. + if (longer_length - shorter_length > bound) return bound; + } + + // Make sure the shorter string is actually shorter. + if (shorter_length > longer_length) { + sz_cptr_t temporary = shorter; + shorter = longer; + longer = temporary; + sz_size_t temporary_length = shorter_length; + shorter_length = longer_length; + longer_length = temporary_length; + } + + // Dispatch the right implementation based on the length of the strings. + if (longer_length < 64u) + return _sz_edit_distance_skewed_diagonals_upto63_avx512( // + shorter, shorter_length, longer, longer_length, bound); + // else if (longer_length < 256u * 256u) + // return _sz_edit_distance_skewed_diagonals_upto65k_avx512( // + // shorter, shorter_length, longer, longer_length, bound, alloc); else return sz_edit_distance_serial(shorter, shorter_length, longer, longer_length, bound, alloc); } -#pragma clang attribute pop -#pragma GCC pop_options - -#pragma GCC push_options -#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512dq", "bmi", "bmi2") -#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512dq,bmi,bmi2"))), \ - apply_to = function) - SZ_PUBLIC sz_u64_t sz_checksum_avx512(sz_cptr_t text, sz_size_t length) { // The naive implementation of this function is very simple. // It assumes the CPU is great at handling unaligned "loads". 
@@ -5671,10 +5964,11 @@ SZ_PUBLIC sz_cptr_t sz_find_charset_avx512(sz_cptr_t text, sz_size_t length, sz_ sz_u512_vec_t lower_nibbles_vec, higher_nibbles_vec; sz_u512_vec_t bitset_even_vec, bitset_odd_vec; sz_u512_vec_t bitmask_vec, bitmask_lookup_vec; - bitmask_lookup_vec.zmm = _mm512_set_epi8(-128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // - -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // - -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // - -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1); + bitmask_lookup_vec.zmm = _mm512_set_epi8( // + -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // + -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // + -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // + -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1); while (length) { // The following algorithm is a transposed equivalent of the "SIMDized check which bytes are in a set" @@ -5746,6 +6040,29 @@ SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx512(sz_cptr_t text, sz_size_t length, sz return sz_rfind_charset_serial(text, length, filter); } +SZ_PUBLIC sz_cptr_t sz_find_many_avx512( // + sz_cptr_t haystack, sz_size_t haystack_length, // + sz_cptr_t const *needles, sz_size_t const *needles_lengths, // + sz_size_t *needle_offset) { + + // When dealing with huge needles vocabularies, like in tokenization workloads, we need to construct an automaton. + // But in many cases, the vocabulary is small enough to use a simpler DFA-less approach, combining the ideas from + // the `sz_find_avx512` and `sz_find_charset_avx512` functions. + // + // Pick the offsets within needles where there is the least variance in the characters. + // Like for "the", "then", "there", "these", "those", "their", "they", "them", "that", "this", "thus", "than": + // + // 0: 't' + // 1: 'h' + // 2: 'e', 'a', 'i', 'o', 'u' + // 3: 'n', 'r', 's', 'i', 'y', 'm', 't' + // + // So depending on our "register budget", we can use a different number of pivot points: offset 0, 1, 2 make + // the most sense if we can only use 3 ZMM registers. + sz_unused(haystack && haystack_length && needles && needles_lengths && needle_offset); + return 0; +} + /** * Computes the Needleman Wunsch alignment score between two strings. * The method uses 32-bit integers to accumulate the running score for every cell in the matrix. 
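The comment inside the new `sz_find_many_avx512` stub above outlines the planned DFA-less approach: pick the needle offsets with the least character variance and use them as cheap SIMD pre-filters before a full comparison. The patch only leaves the stub, so the following Python sketch is an editorial illustration of that selection step, not part of the patch or the library API; the name `pick_pivot_offsets` and the `budget` parameter are hypothetical, and for simplicity only offsets present in every needle are considered.

def pick_pivot_offsets(needles, budget=3):
    # Only offsets covered by every needle can serve as unconditional filters.
    max_offset = min(len(needle) for needle in needles)
    # For each offset, collect the distinct characters it can take across the vocabulary.
    alphabets = [{needle[offset] for needle in needles} for offset in range(max_offset)]
    # Offsets admitting the fewest distinct characters make the cheapest filters.
    ranked = sorted(range(max_offset), key=lambda offset: len(alphabets[offset]))
    return ranked[:budget]

needles = ["the", "then", "there", "these", "those", "their",
           "they", "them", "that", "this", "thus", "than"]
print(pick_pivot_offsets(needles))  # [0, 1, 2]: 't' and 'h' are unique, offset 2 admits 5 characters

With a budget of three ZMM registers this yields offsets 0, 1, and 2, matching the example in the comment above.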
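The reworked `sz_edit_distance_avx512` dispatcher above also adds bound-related pre-checks: empty inputs are resolved immediately, and a length difference beyond the bound needs no matrix evaluation at all, since the edit distance can never be smaller than the difference in lengths. The Python below is an editorial paraphrase of that control flow, not part of the patch; it swaps the strings first for simplicity and takes the fallback kernel as a parameter instead of dispatching between the `upto63` and serial variants.

def dispatch_edit_distance(s1, s2, bound, fallback_kernel):
    # Make sure `shorter` is actually the shorter string.
    shorter, longer = (s1, s2) if len(s1) <= len(s2) else (s2, s1)
    if bound < len(longer):  # bounded computations may exit early
        # If one of the strings is empty, the distance is the other one's length.
        if len(longer) == 0:
            return min(len(shorter), bound)
        if len(shorter) == 0:
            return min(len(longer), bound)
        # The distance is at least the length difference, so a difference
        # beyond `bound` decides the outcome without touching the matrix.
        if len(longer) - len(shorter) > bound:
            return bound
    # Otherwise hand the pair to the selected kernel.
    return fallback_kernel(shorter, longer, bound)

For the "kiten" / "katerinas" pair exercised in scripts/test.cpp below, the length difference alone is 4, so any bound of 3 or less is settled by the pre-checks, while larger bounds fall through to the actual kernel.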
diff --git a/scripts/test.cpp b/scripts/test.cpp index cb7d0079..eecc97f0 100644 --- a/scripts/test.cpp +++ b/scripts/test.cpp @@ -1431,6 +1431,12 @@ static void test_levenshtein_distances() { received_score = sz::alignment_score(r, l, costs, -1); if (received != expected) print_failure("Levenshtein", r, l, expected, received); if ((std::size_t)(-received_score) != expected) print_failure("Scoring", r, l, expected, received_score); + + // Validate the bounded variants: + if (received > 1) { + assert(sz::edit_distance(l, r, received) == received); + assert(sz::edit_distance(r, l, received - 1) == SZ_SIZE_MAX); + } }; for (auto explicit_case : explicit_cases) @@ -1553,6 +1559,20 @@ static void test_stl_containers() { int main(int argc, char const **argv) { + auto dist = _sz_edit_distance_skewed_diagonals_upto63_avx512("kiten", 5, "katerinas", 9, SZ_SIZE_MAX); + sz_assert(dist == 5); + dist = _sz_edit_distance_skewed_diagonals_upto63_avx512("kiten", 5, "katerinas", 9, 3); + sz_assert(dist == SZ_SIZE_MAX); + dist = _sz_edit_distance_skewed_diagonals_upto63_avx512("kiten", 5, "katerinas", 9, 4); + sz_assert(dist == SZ_SIZE_MAX); + dist = _sz_edit_distance_skewed_diagonals_upto63_avx512("kiten", 5, "katerinas", 9, 5); + sz_assert(dist == 5); + dist = _sz_edit_distance_skewed_diagonals_upto63_avx512("kiten", 5, "katerinas", 9, 6); + sz_assert(dist == 5); + + // Similarity measures and fuzzy search + test_levenshtein_distances(); + // Let's greet the user nicely sz_unused(argc && argv); std::printf("Hi, dear tester! You look nice today!\n"); @@ -1596,9 +1616,6 @@ int main(int argc, char const **argv) { test_search_with_misaligned_repetitions(); #endif - // Similarity measures and fuzzy search - test_levenshtein_distances(); - // Sequences of strings test_sequence_algorithms(); test_stl_containers(); diff --git a/scripts/test_levenshtein.ipynb b/scripts/test_levenshtein.ipynb index 4718c386..606939ae 100644 --- a/scripts/test_levenshtein.ipynb +++ b/scripts/test_levenshtein.ipynb @@ -70,6 +70,37 @@ " return matrix[len(s1), len(s2)], matrix" ] }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "('kiten',\n", + " 'katerinas',\n", + " 'distance_wf = np.int64(5)',\n", + " array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],\n", + " [1, 0, 1, 2, 3, 4, 5, 6, 7, 8],\n", + " [2, 1, 1, 2, 3, 4, 4, 5, 6, 7],\n", + " [3, 2, 2, 1, 2, 3, 4, 5, 6, 7],\n", + " [4, 3, 3, 2, 1, 2, 3, 4, 5, 6],\n", + " [5, 4, 4, 3, 2, 2, 3, 3, 4, 5]]))" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "s1 = \"kiten\"\n", + "s2 = \"katerinas\"\n", + "distance_wf, matrix_wf = wagner_fisher(s1, s2)\n", + "s1, s2, f\"{distance_wf = }\", matrix_wf" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -78,6 +109,53 @@ "It's trivial to see that the space complexity can be reduced to $O(min(N, M))$ by only storing the last two rows of the matrix, but we want to keep the entire matrix as a reference to allow debugging and visualization." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To feel safer, while designing our alternative traversal algorithm, let's define an extraction function, that will get the values of a certain skewed diagonal." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def get_skewed_diagonal(matrix: np.ndarray, index: int):\n", + " flipped_matrix = np.fliplr(matrix)\n", + " return np.flip(np.diag(flipped_matrix, k= matrix.shape[1] - index - 1))\n", + "\n", + "# Let's test this function right away.\n", + "matrix = np.array([\n", + " [1, 2, 3],\n", + " [4, 5, 6],\n", + " [7, 8, 9]])\n", + "assert np.all(get_skewed_diagonal(matrix, 2) == [7, 5, 3])\n", + "assert np.all(get_skewed_diagonal(matrix, 1) == [4, 2])\n", + "assert np.all(get_skewed_diagonal(matrix, 4) == [9])" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([2, 3, 5, 6, 8])" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "get_skewed_diagonal(matrix_wf, 10)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -95,11 +173,17 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ - "def square_skewed_diagonals(s1: str, s2: str, verbose: bool = False) -> Tuple[int, np.ndarray]:\n", + "from typing import Optional\n", + "\n", + "def square_skewed_diagonals(\n", + " s1: str, s2: str, \n", + " verbose: bool = False, \n", + " baseline: Optional[np.ndarray] = None) -> Tuple[int, np.ndarray]:\n", + "\n", " assert len(s1) == len(s2), \"First define an algo for square matrices!\"\n", " # Create a matrix of size (len(s1)+1) x (len(s2)+1)\n", " matrix = np.zeros((len(s1) + 1, len(s2) + 1), dtype=int)\n", @@ -115,20 +199,20 @@ " n = len(s1) + 1\n", " \n", " # Number of diagonals and skewed diagonals in the square matrix of size (n x n).\n", - " skew_diagonals_count = 2 * n - 1\n", + " diagonals_count = 2 * n - 1\n", " \n", " # Populate the matrix in 2 separate loops: for the top left triangle and for the bottom right triangle.\n", - " for skew_diagonal_idx in range(2, n):\n", - " skew_diagonal_length = skew_diagonal_idx + 1\n", - " for offset_within_skew_diagonal in range(1, skew_diagonal_length - 1):\n", + " for skew_diagonal_index in range(2, n):\n", + " skew_diagonal_length = skew_diagonal_index + 1\n", + " for offset_within_diagonal in range(1, skew_diagonal_length - 1):\n", " # If we haven't passed the main skew diagonal yet, \n", " # then we have to skip the first and the last operation,\n", " # as those are already pre-populated and form the first column \n", " # and the first row of the Levenshtein matrix respectively.\n", - " i = skew_diagonal_idx - offset_within_skew_diagonal\n", - " j = offset_within_skew_diagonal\n", + " i = skew_diagonal_index - offset_within_diagonal\n", + " j = offset_within_diagonal\n", " if verbose:\n", - " print(f\"top left triangle: {skew_diagonal_idx=}, {skew_diagonal_length=}, {i=}, {j=}\")\n", + " print(f\"top left triangle: {skew_diagonal_index=}, {skew_diagonal_length=}, {i=}, {j=}\")\n", " substitution_cost = s1[i - 1] != s2[j - 1]\n", " matrix[i, j] = min(\n", " matrix[i - 1, j] + 1, #? Deletion cost\n", @@ -136,20 +220,26 @@ " matrix[i - 1, j - 1] + substitution_cost, #? 
Substitution cost\n", " )\n", " \n", + " if baseline is not None:\n", + " assert matrix[i, j] == baseline[i, j], f\"{matrix[i, j]} != {baseline[i, j]} at {i=}, {j=}\"\n", + " \n", " # Now the bottom right triangle of the matrix.\n", - " for skew_diagonal_idx in range(n, skew_diagonals_count):\n", - " skew_diagonal_length = 2*n - skew_diagonal_idx - 1\n", - " for offset_within_skew_diagonal in range(skew_diagonal_length):\n", - " i = n - offset_within_skew_diagonal - 1\n", - " j = skew_diagonal_idx - n + offset_within_skew_diagonal + 1\n", + " for skew_diagonal_index in range(n, diagonals_count):\n", + " skew_diagonal_length = 2 * n - skew_diagonal_index - 1\n", + " for offset_within_diagonal in range(skew_diagonal_length):\n", + " i = n - offset_within_diagonal - 1\n", + " j = skew_diagonal_index - n + offset_within_diagonal + 1\n", " if verbose:\n", - " print(f\"bottom right triangle: {skew_diagonal_idx=}, {skew_diagonal_length=}, {i=}, {j=}\")\n", + " print(f\"bottom right triangle: {skew_diagonal_index=}, {skew_diagonal_length=}, {i=}, {j=}\")\n", " substitution_cost = s1[i - 1] != s2[j - 1]\n", " matrix[i, j] = min(\n", " matrix[i - 1, j] + 1, #? Deletion cost\n", " matrix[i, j - 1] + 1, #? Insertion cost\n", " matrix[i - 1, j - 1] + substitution_cost, #? Substitution cost\n", " )\n", + " \n", + " if baseline is not None:\n", + " assert matrix[i, j] == baseline[i, j], f\"{matrix[i, j]} != {baseline[i, j]} at {i=}, {j=}\"\n", "\n", " # Similarly, the distance will be placed in the bottom right corner of the matrix\n", " return matrix[len(s1), len(s2)], matrix" @@ -164,75 +254,97 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "import random\n", "for _ in range(10):\n", - " s1 = ''.join(random.choices(\"ab\", k=50))\n", - " s2 = ''.join(random.choices(\"ab\", k=50))\n", - " d0, _ = wagner_fisher(s1, s2)\n", - " d1, _ = square_skewed_diagonals(s1, s2)\n", - " assert d0 == d1" + " s1 = ''.join(random.choices(\"abc\", k=50))\n", + " s2 = ''.join(random.choices(\"abc\", k=50))\n", + " distance_wf, matrix_wf = wagner_fisher(s1, s2)\n", + " distance_sd, matrix_sd = square_skewed_diagonals(s1, s2, baseline=matrix_wf)\n", + " assert distance_wf == distance_sd, f\"{distance_wf = } != {distance_sd = }\"\n", + " assert np.all(matrix_wf == matrix_sd), f\"{matrix_wf = }\\n{matrix_sd = }\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Going further, we can avoid storing the whole matrix, and only store two diagonals at a time.\n", - "The longer will never exceed N. The shorter one is always at most N-1, and is always shorter by one." + "## Vectorizing the Skewed Diagonals Algorithm\n", + "\n", + "Going further, we can avoid storing the whole matrix, and only store three diagonals at a time.\n", + "The longer will never exceed `n` in length.\n", + "The others are always at most `n-1`.\n", + "Let's try vectorizing different parts of our algorithm, validating it against the output of the naive algorithm for 2 strings: `\"BCDE\"` and `\"FKPU\"`." 
] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "('listen',\n", - " 'silent',\n", - " 'distance = np.int64(4)',\n", - " array([[0, 1, 2, 3, 4, 5, 6],\n", - " [1, 1, 2, 2, 3, 4, 5],\n", - " [2, 2, 1, 2, 3, 4, 5],\n", - " [3, 2, 2, 2, 3, 4, 5],\n", - " [4, 3, 3, 3, 3, 4, 4],\n", - " [5, 4, 4, 4, 3, 4, 5],\n", - " [6, 5, 5, 5, 4, 3, 4]]))" + "('BCDE',\n", + " 'FKPU',\n", + " 'distance_wf = np.int64(4)',\n", + " array([[0, 1, 2, 3, 4],\n", + " [1, 1, 2, 3, 4],\n", + " [2, 2, 2, 3, 4],\n", + " [3, 3, 3, 3, 4],\n", + " [4, 4, 4, 4, 4]]))" ] }, - "execution_count": 4, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "s1 = \"listen\"\n", - "s2 = \"silent\"\n", - "# s1 = ''.join(random.choices(\"abcd\", k=100))\n", - "# s2 = ''.join(random.choices(\"abcd\", k=100))\n", - "distance, baseline = wagner_fisher(s1, s2)\n", - "s1, s2, f\"{distance = }\", baseline" + "s1 = \"BCDE\"\n", + "s2 = \"FKPU\"\n", + "distance_wf, matrix_wf = wagner_fisher(s1, s2)\n", + "s1, s2, f\"{distance_wf = }\", matrix_wf" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Replacing the letters with numbers and annotating with a header row and column for `\"BCDE\"` and `\"FKPU\"`:\n", + "\n", + "| | | **B** | **C** | **D** | **E** |\n", + "| ----- | --- | ----- | ----- | ----- | ----- |\n", + "| | a | b | c | d | e |\n", + "| **F** | f | g | h | i | j |\n", + "| **K** | k | l | m | n | o |\n", + "| **P** | p | q | r | s | t |\n", + "| **U** | u | v | w | x | y |\n", + "\n", + "At any point we will be working with 3 diagonals:\n", + "\n", + "- `previous` set to `[a]` at start\n", + "- `current` set to `[f, b]` at start\n", + "- `following` set to `[k, g, c]` at start" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "(array([0, 0, 0, 0, 0, 0, 0], dtype=uint64),\n", - " array([1, 1, 0, 0, 0, 0, 0], dtype=uint64),\n", - " array([0, 0, 0, 0, 0, 0, 0], dtype=uint64))" + "(array([0, 0, 0, 0, 0], dtype=uint64),\n", + " array([1, 1, 0, 0, 0], dtype=uint64),\n", + " array([0, 0, 0, 0, 0], dtype=uint64))" ] }, - "execution_count": 5, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -242,17 +354,6 @@ "# Number of rows and columns in the square matrix.\n", "n = len(s1) + 1\n", "\n", - "# Let's use just a couple of arrays to store the previous skew diagonals.\n", - "# Let's imagine that our Levenshtein matrix is gonna have 5x5 size for two words of length 4.\n", - "# B C D E << s2 characters: BCDE\n", - "# + ---------\n", - "# | a b c d e\n", - "# F | f g h i j\n", - "# K | k l m n o\n", - "# P | p q r s t\n", - "# U | u v w x y\n", - "# ^\n", - "# ^ s1 characters: FKPU\n", "following = np.zeros(n, dtype=np.uint) # let's assume we are computing the main skew diagonal: [u, q, m, i, e]\n", "current = np.zeros(n, dtype=np.uint) # will contain: [p, l, h, e]\n", "previous = np.zeros(n, dtype=np.uint) # will contain: [k, g, c]\n", @@ -269,71 +370,46 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "To feel safer, while designing our alternative traversal algorithm, let's define an extraction function, that will get the values of a certain skewed diagonal." 
- ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "def get_skewed_diagonal(matrix: np.ndarray, index: int):\n", - " flipped_matrix = np.fliplr(matrix)\n", - " return np.flip(np.diag(flipped_matrix, k= matrix.shape[1] - index - 1))" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "matrix = np.array([[1, 2, 3],\n", - " [4, 5, 6],\n", - " [7, 8, 9]])\n", - "assert np.all(get_skewed_diagonal(matrix, 2) == [7, 5, 3])\n", - "assert np.all(get_skewed_diagonal(matrix, 1) == [4, 2])\n", - "assert np.all(get_skewed_diagonal(matrix, 4) == [9])" + "Now we can rewrite the first nested loop for the upper triangle of the matrix in NumPy primitives, using it's `np.minimum` function to calculate the minimum of three values." ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "(array([5, 3, 2, 2, 3, 5, 0], dtype=uint64),\n", - " array([6, 4, 3, 2, 3, 4, 6], dtype=uint64),\n", - " array([6, 4, 3, 2, 3, 4, 6], dtype=uint64))" + "(array([3, 2, 2, 3, 0], dtype=uint64),\n", + " array([4, 3, 2, 3, 4], dtype=uint64),\n", + " array([4, 3, 2, 3, 4], dtype=uint64))" ] }, - "execution_count": 8, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# To evaluate every subsequent entry:\n", - "following_skew_diagonal_idx = 2\n", - "while following_skew_diagonal_idx < n:\n", - " following_skew_diagonal_length = following_skew_diagonal_idx + 1\n", + "next_diagonal_index = 2\n", + "while next_diagonal_index < n:\n", + " next_skew_diagonal_length = next_diagonal_index + 1\n", "\n", - " old_substitution_costs = previous[:following_skew_diagonal_length - 2]\n", - " added_substitution_costs = [s1[following_skew_diagonal_idx - i - 2] != s2[i] for i in range(following_skew_diagonal_length - 2)]\n", + " old_substitution_costs = previous[:next_skew_diagonal_length - 2]\n", + " added_substitution_costs = [s1[next_diagonal_index - i - 2] != s2[i] for i in range(next_skew_diagonal_length - 2)]\n", " substitution_costs = old_substitution_costs + added_substitution_costs\n", "\n", - " following[1:following_skew_diagonal_length-1] = np.minimum(current[1:following_skew_diagonal_length-1] + 1, current[:following_skew_diagonal_length-2] + 1) # Insertions or deletions\n", - " following[1:following_skew_diagonal_length-1] = np.minimum(following[1:following_skew_diagonal_length-1], substitution_costs) # Substitutions\n", - " following[0] = following_skew_diagonal_idx\n", - " following[following_skew_diagonal_length-1] = following_skew_diagonal_idx\n", - " assert np.all(following[:following_skew_diagonal_length] == get_skewed_diagonal(baseline, following_skew_diagonal_idx))\n", + " following[1:next_skew_diagonal_length - 1] = np.minimum(current[1:next_skew_diagonal_length - 1] + 1, current[:next_skew_diagonal_length - 2] + 1) # Insertions or deletions\n", + " following[1:next_skew_diagonal_length - 1] = np.minimum(following[1:next_skew_diagonal_length - 1], substitution_costs) # Substitutions\n", + " following[0] = next_diagonal_index\n", + " following[next_skew_diagonal_length - 1] = next_diagonal_index\n", + " assert np.all(following[:next_skew_diagonal_length] == get_skewed_diagonal(matrix_wf, next_diagonal_index))\n", " \n", " previous[:] = current[:]\n", " current[:] = following[:]\n", - " following_skew_diagonal_idx += 1\n", + " next_diagonal_index += 1\n", "\n", "previous, current, following # 
Log the state" ] @@ -342,74 +418,107 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "By now we've scanned through the upper triangle of the matrix, where each subsequent iteration results in a larger diagonal. From now onwards, we will be shrinking. Instead of adding value equal to the skewed diagonal index on either side, we will be cropping those values out." + "By now we've scanned through the upper triangle of the matrix, where each subsequent iteration results in a larger diagonal.\n", + "From now onwards, we will be shrinking.\n", + "Instead of adding value equal to the skewed diagonal index on either side, we will be cropping those values out." ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "(array([5, 4, 5, 5, 5, 6, 0], dtype=uint64),\n", - " array([4, 5, 4, 5, 5, 5, 6], dtype=uint64),\n", - " array([4, 5, 4, 5, 5, 5, 6], dtype=uint64))" + "(array([4, 4, 4, 4, 0], dtype=uint64),\n", + " array([4, 4, 4, 4, 4], dtype=uint64),\n", + " array([4, 4, 4, 4, 4], dtype=uint64))" ] }, - "execution_count": 9, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "while following_skew_diagonal_idx < 2 * n - 1:\n", - " following_skew_diagonal_length = 2 * n - 1 - following_skew_diagonal_idx\n", - " old_substitution_costs = previous[:following_skew_diagonal_length]\n", - " added_substitution_costs = [s1[len(s1) - i - 1] != s2[following_skew_diagonal_idx - n + i] for i in range(following_skew_diagonal_length)]\n", + "while next_diagonal_index < 2 * n - 1:\n", + " next_skew_diagonal_length = 2 * n - 1 - next_diagonal_index\n", + " old_substitution_costs = previous[:next_skew_diagonal_length]\n", + " added_substitution_costs = [s1[len(s1) - i - 1] != s2[next_diagonal_index - n + i] for i in range(next_skew_diagonal_length)]\n", " substitution_costs = old_substitution_costs + added_substitution_costs\n", " \n", - " following[:following_skew_diagonal_length] = np.minimum(current[:following_skew_diagonal_length] + 1, current[1:following_skew_diagonal_length+1] + 1) # Insertions or deletions\n", - " following[:following_skew_diagonal_length] = np.minimum(following[:following_skew_diagonal_length], substitution_costs) # Substitutions\n", - " assert np.all(following[:following_skew_diagonal_length] == get_skewed_diagonal(baseline, following_skew_diagonal_idx)), f\"\\n{following[:following_skew_diagonal_length]} not equal to \\n{get_skewed_diagonal(baseline, following_skew_diagonal_idx)}\"\n", + " following[:next_skew_diagonal_length] = np.minimum(current[:next_skew_diagonal_length] + 1, current[1 : next_skew_diagonal_length + 1] + 1) # Insertions or deletions\n", + " following[:next_skew_diagonal_length] = np.minimum(following[:next_skew_diagonal_length], substitution_costs) # Substitutions\n", + " assert np.all(following[:next_skew_diagonal_length] == get_skewed_diagonal(matrix_wf, next_diagonal_index)), f\"\\n{following[:next_skew_diagonal_length]} not equal to \\n{get_skewed_diagonal(baseline, next_diagonal_index)}\"\n", " \n", - " previous[:following_skew_diagonal_length] = current[1:following_skew_diagonal_length+1]\n", - " current[:following_skew_diagonal_length] = following[:following_skew_diagonal_length]\n", - " following_skew_diagonal_idx += 1\n", + " previous[:next_skew_diagonal_length] = current[1:next_skew_diagonal_length + 1]\n", + " current[:next_skew_diagonal_length] = following[:next_skew_diagonal_length]\n", + " next_diagonal_index += 1\n", "\n", 
"previous, current, following # Log the state" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ - "assert distance == following[0], f\"{distance = } != {following[0] = }\"" + "assert distance_wf == following[0], f\"{distance_wf = } != {following[0] = }\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Generalizing to Non-Square Matrices" + "## Generalizing to Non-Square Matrices\n", + "\n", + "Let's imaging 2 inputs of length 3 and 5: `\"KPU\"` and `\"BCDEF\"`:\n", + "\n", + "| | | **B** | **C** | **D** | **E** | **F** |\n", + "| ----- | --- | ----- | ----- | ----- | ----- | ----- |\n", + "| | a | b | c | d | e | f |\n", + "| **K** | g | h | i | j | k | l |\n", + "| **P** | m | n | o | p | q | r |\n", + "| **U** | s | t | u | v | w | x |\n", + "\n", + "At any point we will be working with 3 diagonals:\n", + "\n", + "- `previous` set to `[a]` at start\n", + "- `current` set to `[g, b]` at start\n", + "- `next` set to `[m, h, c]` at start\n", + "\n", + "Once we proceed to for X cycles:\n", + "\n", + "- `previous` set to `[s, n, i, d]`\n", + "- `current` set to `[t, o, j, e]`\n", + "- `next` set to `[u, p, k, f]`\n" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ - "def skewed_diagonals(s1, s2, verbose: bool = False) -> int:\n", - " shorter, longer = (s1, s2) if len(s1) < len(s2) else (s2, s1) \n", + "from typing import Optional\n", + "\n", + "def skewed_diagonals(\n", + " s1: str, s2: str, \n", + " verbose: bool = False, \n", + " baseline: Optional[np.ndarray] = None) -> Tuple[int, np.ndarray]:\n", + " \n", + " shorter, longer = (s1, s2) if len(s1) <= len(s2) else (s2, s1) \n", + " baseline = baseline if len(s1) <= len(s2) else baseline.T\n", " shorter_dim = len(shorter) + 1\n", " longer_dim = len(longer) + 1\n", - " # Create a matrix of size (len(s1)+1) x (len(s2)+1)\n", - " matrix = np.zeros((len(shorter) + 1, len(longer) + 1), dtype=int)\n", - " matrix[:, :] = 99\n", + " if verbose:\n", + " print(f\"{shorter=}, {longer=}, {shorter_dim=}, {longer_dim=}\")\n", + " \n", + " # Create a matrix of size (shorter_dim) x (longer_dim)\n", + " matrix = np.zeros((shorter_dim, longer_dim), dtype=int)\n", + " matrix[:, :] = longer_dim + 1 # or +inf \n", "\n", " # Initialize the first column and first row of the matrix\n", " for i in range(shorter_dim):\n", @@ -417,111 +526,110 @@ " for j in range(longer_dim):\n", " matrix[0, j] = j\n", "\n", - " # Let's say we are dealing with 6 and 9 letter words.\n", - " # The matrix will have size 7 x 10, parameterized as (shorter_dim x longer_dim).\n", + " # Let's say we are dealing with 3 and 5 letter words.\n", + " # The matrix will have size 4 x 6, parameterized as (shorter_dim x longer_dim).\n", " # It will have:\n", - " # - 8 diagonals of increasing length, at positions: 0, 1, 2, 3, 4, 5, 6, 7.\n", - " # - 2 diagonals of fixed length, at positions: 8, 9.\n", - " # - 8 diagonals of decreasing length, at positions: 10, 11, 12, 13, 14, 15, 16, 17.\n", - " skew_diagonals_count = 2 * longer_dim - 1\n", + " # - 4 diagonals of increasing length, at positions: 0, 1, 2, 3.\n", + " # - 2 diagonals of fixed length, at positions: 4, 5.\n", + " # - 3 diagonals of decreasing length, at positions: 6, 7, 8.\n", + " diagonals_count = shorter_dim + longer_dim - 1\n", "\n", " # Same as with square matrices, the 0th diagonal contains - just one element - zero - skipping it.\n", " # Same as with square matrices, the 1st 
diagonal contains the values 1 and 1 - skipping it.\n", " # Now let's handle the rest of the upper triangle.\n", - " for skew_diagonal_idx in range(2, shorter_dim + 1):\n", - " skew_diagonal_length = (skew_diagonal_idx + 1)\n", - " for offset_within_skew_diagonal in range(1, skew_diagonal_length-1): #! Skip the first column & row\n", + " for skew_diagonal_index in range(2, shorter_dim):\n", + " skew_diagonal_length = (skew_diagonal_index + 1)\n", + " for offset_within_diagonal in range(1, skew_diagonal_length - 1): #! Skip the first column & row\n", " # If we haven't passed the main skew diagonal yet, \n", " # then we have to skip the first and the last operation,\n", " # as those are already pre-populated and form the first column \n", " # and the first row of the Levenshtein matrix respectively.\n", - " i = skew_diagonal_idx - offset_within_skew_diagonal\n", - " j = offset_within_skew_diagonal\n", + " i = skew_diagonal_index - offset_within_diagonal\n", + " j = offset_within_diagonal\n", " if verbose:\n", - " print(f\"top left triangle: {skew_diagonal_idx=}, {skew_diagonal_length=}, {i=}, {j=}\")\n", + " print(f\"top left triangle: {skew_diagonal_index=}, {skew_diagonal_length=}, {i=}, {j=}\")\n", " shorter_char = shorter[i - 1]\n", " longer_char = longer[j - 1]\n", " substitution_cost = shorter_char != longer_char\n", " matrix[i, j] = min(\n", - " matrix[i - 1, j] + 1, # Deletion\n", - " matrix[i, j - 1] + 1, # Insertion\n", - " matrix[i - 1, j - 1] + substitution_cost, # Substitution\n", + " matrix[i - 1, j] + 1, #? Deletion cost\n", + " matrix[i, j - 1] + 1, #? Insertion cost\n", + " matrix[i - 1, j - 1] + substitution_cost, #? Substitution cost\n", " )\n", " \n", + " if baseline is not None:\n", + " assert matrix[i, j] == baseline[i, j], f\"{matrix[i, j]} != {baseline[i, j]} at {i=}, {j=}\"\n", + " \n", " # Now let's handle the anti-diagonal band of the matrix, between the top and bottom triangles. \n", - " for skew_diagonal_idx in range(shorter_dim + 1, longer_dim + 1):\n", + " for skew_diagonal_index in range(shorter_dim, longer_dim):\n", " skew_diagonal_length = shorter_dim\n", - " for offset_within_skew_diagonal in range(skew_diagonal_length):\n", - " i = shorter_dim - offset_within_skew_diagonal - 1\n", - " j = offset_within_skew_diagonal + 1\n", + " for offset_within_diagonal in range(skew_diagonal_length - 1): #! Skip the first row\n", + " i = shorter_dim - offset_within_diagonal - 1\n", + " j = skew_diagonal_index - shorter_dim + offset_within_diagonal + 1\n", " if verbose:\n", - " print(f\"anti-band: {skew_diagonal_idx=}, {skew_diagonal_length=}, {i=}, {j=}\")\n", + " print(f\"anti-band: {skew_diagonal_index=}, {skew_diagonal_length=}, {i=}, {j=}\")\n", " shorter_char = shorter[i - 1]\n", " longer_char = longer[j - 1]\n", " substitution_cost = shorter_char != longer_char\n", " matrix[i, j] = min(\n", - " matrix[i - 1, j] + 1, # Deletion\n", - " matrix[i, j - 1] + 1, # Insertion\n", - " matrix[i - 1, j - 1] + substitution_cost, # Substitution\n", + " matrix[i - 1, j] + 1, #? Deletion cost\n", + " matrix[i, j - 1] + 1, #? Insertion cost\n", + " matrix[i - 1, j - 1] + substitution_cost, #? 
Substitution cost\n", " )\n", + " \n", + " if baseline is not None:\n", + " assert matrix[i, j] == baseline[i, j], f\"{matrix[i, j]} != {baseline[i, j]} at {i=}, {j=}\"\n", " \n", " # Now let's handle the bottom right triangle.\n", - " for skew_diagonal_idx in range(longer_dim + 1, skew_diagonals_count):\n", - " skew_diagonal_length = 2 * longer_dim - skew_diagonal_idx - 1\n", - " for offset_within_skew_diagonal in range(skew_diagonal_length):\n", - " i = shorter_dim - offset_within_skew_diagonal - 1\n", - " j = skew_diagonal_idx - longer_dim + offset_within_skew_diagonal + 1\n", + " for skew_diagonal_index in range(longer_dim, diagonals_count):\n", + " skew_diagonal_length = diagonals_count - skew_diagonal_index\n", + " for offset_within_diagonal in range(skew_diagonal_length):\n", + " i = shorter_dim - offset_within_diagonal - 1\n", + " j = skew_diagonal_index - shorter_dim + offset_within_diagonal + 1\n", " if verbose:\n", - " print(f\"bottom right triangle: {skew_diagonal_idx=}, {skew_diagonal_length=}, {i=}, {j=}\")\n", + " print(f\"bottom right triangle: {skew_diagonal_index=}, {skew_diagonal_length=}, {i=}, {j=}\")\n", + " assert (i - 1) >= 0 and (i - 1) < len(shorter), f\"{i = }\"\n", + " assert (j - 1) >= 0 and (j - 1) < len(longer), f\"{j = }\"\n", " shorter_char = shorter[i - 1]\n", " longer_char = longer[j - 1]\n", " substitution_cost = shorter_char != longer_char\n", " matrix[i, j] = min(\n", - " matrix[i - 1, j] + 1, # Deletion\n", - " matrix[i, j - 1] + 1, # Insertion\n", - " matrix[i - 1, j - 1] + substitution_cost, # Substitution\n", + " matrix[i - 1, j] + 1, #? Deletion cost\n", + " matrix[i, j - 1] + 1, #? Insertion cost\n", + " matrix[i - 1, j - 1] + substitution_cost, #? Substitution cost\n", " )\n", + " \n", + " if baseline is not None:\n", + " assert matrix[i, j] == baseline[i, j], f\"{matrix[i, j]} != {baseline[i, j]} at {i=}, {j=}\"\n", "\n", " # Return the Levenshtein distance\n", - " return matrix[len(shorter), len(longer)], matrix" + " distance = matrix[len(shorter), len(longer)]\n", + " if len(s1) > len(s2):\n", + " matrix = matrix.T\n", + " return distance, matrix" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "('listeners',\n", - " 'silents',\n", - " 'distance = np.int64(5)',\n", - " array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],\n", - " [1, 1, 2, 2, 3, 4, 5, 6, 7, 8],\n", - " [2, 2, 1, 2, 3, 4, 5, 6, 7, 8],\n", - " [3, 2, 2, 2, 3, 4, 5, 6, 7, 8],\n", - " [4, 3, 3, 3, 3, 3, 4, 5, 6, 7],\n", - " [5, 4, 4, 4, 4, 4, 3, 4, 5, 6],\n", - " [6, 5, 5, 5, 4, 5, 4, 4, 5, 6],\n", - " [7, 6, 6, 5, 5, 5, 5, 5, 5, 5]]))" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "s1 = \"listeners\"\n", - "s2 = \"silents\"\n", - "distance, baseline = skewed_diagonals(s1, s2)\n", - "s1, s2, f\"{distance = }\", baseline" + "import random\n", + "for _ in range(100):\n", + " len1 = random.randint(1, 50)\n", + " len2 = random.randint(1, 50)\n", + " s1 = ''.join(random.choices(\"abc\", k=len1))\n", + " s2 = ''.join(random.choices(\"abc\", k=len2))\n", + " distance_wf, matrix_wf = wagner_fisher(s1, s2)\n", + " distance_sd, matrix_sd = skewed_diagonals(s1, s2, baseline=matrix_wf, verbose=False)\n", + " assert distance_wf == distance_sd, f\"{distance_wf = } != {distance_sd = }\"\n", + " assert np.all(matrix_wf == matrix_sd), f\"{matrix_wf = }\\n{matrix_sd = }\"" ] }, { "cell_type": "code", - "execution_count": 13, + 
"execution_count": 14, "metadata": {}, "outputs": [ { @@ -529,7 +637,7 @@ "text/plain": [ "('listeners',\n", " 'silents',\n", - " 'distance = np.int64(5)',\n", + " 'distance_sd = np.int64(5)',\n", " array([[0, 1, 2, 3, 4, 5, 6, 7],\n", " [1, 1, 2, 2, 3, 4, 5, 6],\n", " [2, 2, 1, 2, 3, 4, 5, 6],\n", @@ -542,36 +650,238 @@ " [9, 8, 8, 8, 7, 6, 6, 5]]))" ] }, - "execution_count": 13, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "distance, baseline = wagner_fisher(s1, s2)\n", - "s1, s2, f\"{distance = }\", baseline" + "s1 = \"listeners\"\n", + "s2 = \"silents\"\n", + "distance_wf, matrix_wf = wagner_fisher(s1, s2)\n", + "distance_sd, matrix_sd = skewed_diagonals(s1, s2, baseline=matrix_wf)\n", + "s1, s2, f\"{distance_sd = }\", matrix_sd" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Bounding the Error\n", + "\n", + "It's easy to spot that the algorithm can be further optimized if we are dealing with \"bounded\" edit distances, where the maximum allowed number of edits is known in advance.\n", + "In such cases, we only need to evaluate a band around the main diagonal, and can skip the rest of the matrix.\n", + "For the bound $k$, we only need to evaluate $2k+1$ diagonals." ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ - "s1 = ''.join(random.choices(\"abcd\", k=5))\n", - "s2 = ''.join(random.choices(\"abcd\", k=6))\n", - "distance_v0, baseline_v0 = wagner_fisher(s1, s2)\n", - "distance_v2, baseline_v2 = skewed_diagonals(s1, s2, verbose=False)\n", - "assert distance_v0 == distance_v2, f\"{distance_v0 = } != {distance_v2 = }\"\n", - "assert np.all(baseline_v0 == baseline_v2), f\"{baseline_v0 = }\\n{baseline_v2 = }\"" + "from typing import Optional\n", + "\n", + "\n", + "def bounded_skewed_diagonals(\n", + " s1: str,\n", + " s2: str,\n", + " verbose: bool = False,\n", + " bound: Optional[int] = None,\n", + " baseline: Optional[np.ndarray] = None,\n", + ") -> Tuple[int, np.ndarray]:\n", + "\n", + " shorter, longer = (s1, s2) if len(s1) <= len(s2) else (s2, s1)\n", + " baseline = baseline if len(s1) <= len(s2) else baseline.T\n", + " shorter_dim = len(shorter) + 1\n", + " longer_dim = len(longer) + 1\n", + " if verbose:\n", + " print(f\"{shorter=}, {longer=}, {shorter_dim=}, {longer_dim=}\")\n", + "\n", + " # Create a matrix of size (shorter_dim) x (longer_dim)\n", + " matrix = np.zeros((shorter_dim, longer_dim), dtype=int)\n", + " matrix[:, :] = np.iinfo(matrix.dtype).max\n", + "\n", + " # Initialize the first column and first row of the matrix\n", + " for i in range(shorter_dim):\n", + " matrix[i, 0] = i\n", + " for j in range(longer_dim):\n", + " matrix[0, j] = j\n", + "\n", + " # Let's say we are dealing with 3 and 5 letter words.\n", + " # The matrix will have size 4 x 6, parameterized as (shorter_dim x longer_dim).\n", + " # It will have:\n", + " # - 4 diagonals of increasing length, at positions: 0, 1, 2, 3.\n", + " # - 2 diagonals of fixed length, at positions: 4, 5.\n", + " # - 3 diagonals of decreasing length, at positions: 6, 7, 8.\n", + " diagonals_count = shorter_dim + longer_dim - 1\n", + "\n", + " # Same as with square matrices, the 0th diagonal contains - just one element - zero - skipping it.\n", + " # Same as with square matrices, the 1st diagonal contains the values 1 and 1 - skipping it.\n", + " # In unbounded case, we the upper triangle will have `shorter_dim` rows and columns.\n", + " # In bounded case, we will have `min(bound, 
shorter_dim)` rows and columns.\n", + " upper_triangle_dim = min(bound, shorter_dim) if bound is not None else shorter_dim\n", + " for skew_diagonal_index in range(2, upper_triangle_dim):\n", + " skew_diagonal_length = skew_diagonal_index + 1\n", + " for offset_within_diagonal in range(\n", + " 1, skew_diagonal_length - 1\n", + " ): #! Skip the first column & row\n", + " # If we haven't passed the main skew diagonal yet,\n", + " # then we have to skip the first and the last operation,\n", + " # as those are already pre-populated and form the first column\n", + " # and the first row of the Levenshtein matrix respectively.\n", + " i = skew_diagonal_index - offset_within_diagonal\n", + " j = offset_within_diagonal\n", + " if verbose:\n", + " print(\n", + " f\"top left triangle: {skew_diagonal_index=}, {skew_diagonal_length=}, {i=}, {j=}\"\n", + " )\n", + " shorter_char = shorter[i - 1]\n", + " longer_char = longer[j - 1]\n", + " substitution_cost = shorter_char != longer_char\n", + " matrix[i, j] = min(\n", + " matrix[i - 1, j] + 1, # ? Deletion cost\n", + " matrix[i, j - 1] + 1, # ? Insertion cost\n", + " matrix[i - 1, j - 1] + substitution_cost, # ? Substitution cost\n", + " )\n", + "\n", + " # Validation checks:\n", + " if baseline is not None:\n", + " assert (\n", + " matrix[i, j] == baseline[i, j]\n", + " ), f\"{matrix[i, j]} != {baseline[i, j]} at {i=}, {j=}\"\n", + "\n", + " # Now let's handle the anti-diagonal band of the matrix, between the top and bottom triangles.\n", + " # In the unbounded case, we will enumerate diagonal indices from `shorter_dim` to `longer_dim`.\n", + " # In the bounded case, we go through the same \n", + " for skew_diagonal_index in range(shorter_dim, longer_dim):\n", + " skew_diagonal_length = shorter_dim\n", + " for offset_within_diagonal in range(\n", + " skew_diagonal_length - 1\n", + " ): #! Skip the first row\n", + " i = shorter_dim - offset_within_diagonal - 1\n", + " j = skew_diagonal_index - shorter_dim + offset_within_diagonal + 1\n", + " if verbose:\n", + " print(\n", + " f\"anti-band: {skew_diagonal_index=}, {skew_diagonal_length=}, {i=}, {j=}\"\n", + " )\n", + " shorter_char = shorter[i - 1]\n", + " longer_char = longer[j - 1]\n", + " substitution_cost = shorter_char != longer_char\n", + " matrix[i, j] = min(\n", + " matrix[i - 1, j] + 1, # ? Deletion cost\n", + " matrix[i, j - 1] + 1, # ? Insertion cost\n", + " matrix[i - 1, j - 1] + substitution_cost, # ? Substitution cost\n", + " )\n", + "\n", + " if baseline is not None:\n", + " assert (\n", + " matrix[i, j] == baseline[i, j]\n", + " ), f\"{matrix[i, j]} != {baseline[i, j]} at {i=}, {j=}\"\n", + "\n", + " # Now let's handle the bottom right triangle.\n", + " for skew_diagonal_index in range(longer_dim, diagonals_count):\n", + " skew_diagonal_length = diagonals_count - skew_diagonal_index\n", + " for offset_within_diagonal in range(skew_diagonal_length):\n", + " i = shorter_dim - offset_within_diagonal - 1\n", + " j = skew_diagonal_index - shorter_dim + offset_within_diagonal + 1\n", + " if verbose:\n", + " print(\n", + " f\"bottom right triangle: {skew_diagonal_index=}, {skew_diagonal_length=}, {i=}, {j=}\"\n", + " )\n", + " assert (i - 1) >= 0 and (i - 1) < len(shorter), f\"{i = }\"\n", + " assert (j - 1) >= 0 and (j - 1) < len(longer), f\"{j = }\"\n", + " shorter_char = shorter[i - 1]\n", + " longer_char = longer[j - 1]\n", + " substitution_cost = shorter_char != longer_char\n", + " matrix[i, j] = min(\n", + " matrix[i - 1, j] + 1, # ? Deletion cost\n", + " matrix[i, j - 1] + 1, # ? 
Insertion cost\n", + " matrix[i - 1, j - 1] + substitution_cost, # ? Substitution cost\n", + " )\n", + "\n", + " if baseline is not None:\n", + " assert (\n", + " matrix[i, j] == baseline[i, j]\n", + " ), f\"{matrix[i, j]} != {baseline[i, j]} at {i=}, {j=}\"\n", + "\n", + " # Return the Levenshtein distance\n", + " distance = matrix[len(shorter), len(longer)]\n", + " if len(s1) > len(s2):\n", + " matrix = matrix.T\n", + " return distance, matrix" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Putting Everything Together" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "def vectorized_skewed_diagonals(\n", + " s1: str, s2: str, \n", + " verbose: bool = False, \n", + " baseline: Optional[np.ndarray] = None) -> Tuple[int, np.ndarray]:\n", + " \n", + " shorter, longer = (s1, s2) if len(s1) <= len(s2) else (s2, s1) \n", + " baseline = baseline if len(s1) <= len(s2) else baseline.T\n", + " shorter_dim = len(shorter) + 1\n", + " longer_dim = len(longer) + 1\n", + " if verbose:\n", + " print(f\"{shorter=}, {longer=}, {shorter_dim=}, {longer_dim=}\")\n", + " \n", + " # Create a matrix of size (shorter_dim) x (longer_dim)\n", + " matrix = np.zeros((shorter_dim, longer_dim), dtype=int)\n", + " matrix[:, :] = longer_dim + 1 # or +inf \n", + "\n", + " # Initialize the first column and first row of the matrix\n", + " for i in range(shorter_dim):\n", + " matrix[i, 0] = i\n", + " for j in range(longer_dim):\n", + " matrix[0, j] = j\n", + "\n", + " # Let's say we are dealing with 3 and 5 letter words.\n", + " # The matrix will have size 4 x 6, parameterized as (shorter_dim x longer_dim).\n", + " # It will have:\n", + " # - 4 diagonals of increasing length, at positions: 0, 1, 2, 3.\n", + " # - 2 diagonals of fixed length, at positions: 4, 5.\n", + " # - 3 diagonals of decreasing length, at positions: 6, 7, 8.\n", + " diagonals_count = shorter_dim + longer_dim - 1\n", + "\n", + " # Same as with square matrices, the 0th diagonal contains - just one element - zero - skipping it.\n", + " # Same as with square matrices, the 1st diagonal contains the values 1 and 1 - skipping it.\n", + " # Now let's handle the rest of the upper triangle.\n", + " next_diagonal_index = 2\n", + " while next_diagonal_index < shorter_dim:\n", + " next_skew_diagonal_length = next_diagonal_index + 1\n", + "\n", + " old_substitution_costs = previous[:next_skew_diagonal_length - 2]\n", + " added_substitution_costs = [shorter[next_diagonal_index - offset_within_diagonal - 2] != longer[offset_within_diagonal] for offset_within_diagonal in range(next_skew_diagonal_length - 2)]\n", + " substitution_costs = old_substitution_costs + added_substitution_costs\n", + "\n", + " following[1:next_skew_diagonal_length - 1] = np.minimum(current[1:next_skew_diagonal_length - 1] + 1, current[:next_skew_diagonal_length - 2] + 1) # Insertions or deletions\n", + " following[1:next_skew_diagonal_length - 1] = np.minimum(following[1:next_skew_diagonal_length - 1], substitution_costs) # Substitutions\n", + " following[0] = next_diagonal_index\n", + " following[next_skew_diagonal_length - 1] = next_diagonal_index\n", + " assert np.all(following[:next_skew_diagonal_length] == get_skewed_diagonal(baseline, next_diagonal_index))\n", + " \n", + " previous[:] = current[:]\n", + " current[:] = following[:]\n", + " next_diagonal_index += 1\n", + " \n", + " # Now let's handle the anti-diagonal band of the matrix, between the top and bottom 
triangles. \n", + " while next_diagonal_index < longer_dim:\n", + " next_skew_diagonal_length = shorter_dim\n", + " \n", + " ..." + ] } ], "metadata": { From d0678f87cfb00a4268e584aaf07c341f2f7f7c50 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 11:13:48 +0000 Subject: [PATCH 009/751] Fix: Wrong env. variable names --- CONTRIBUTING.md | 2 +- c/lib.c | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index da369582..524d6c49 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -116,7 +116,7 @@ Replacing the default compiler is not recommended, as it may break the system, b ```bash brew install llvm -cmake -D CMAKE_BUILD_TYPE=Release -D SIMSIMD_BUILD_TESTS=1 \ +cmake -D CMAKE_BUILD_TYPE=Release -D STRINGZILLA_BUILD_TEST=1 \ -D CMAKE_C_COMPILER="$(brew --prefix llvm)/bin/clang" \ -D CMAKE_CXX_COMPILER="$(brew --prefix llvm)/bin/clang++" \ -B build_release diff --git a/c/lib.c b/c/lib.c index ee48400e..19d22ba5 100644 --- a/c/lib.c +++ b/c/lib.c @@ -92,7 +92,7 @@ SZ_DYNAMIC sz_capability_t sz_capabilities(void) { (sz_cap_x86_gfni_k * (supports_gfni)) | // (sz_cap_serial_k)); -#endif // SIMSIMD_TARGET_X86 +#endif // SZ_TARGET_X86 #if SZ_USE_ARM_NEON || SZ_USE_ARM_SVE @@ -107,7 +107,7 @@ SZ_DYNAMIC sz_capability_t sz_capabilities(void) { (sz_cap_arm_neon_k * supports_neon) | // (sz_cap_serial_k)); -#endif // SIMSIMD_TARGET_ARM +#endif // SZ_TARGET_ARM return sz_cap_serial_k; } From ecb377541d0c706cf8997faff4f026b07e3f76f3 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 11:18:34 +0000 Subject: [PATCH 010/751] Make: Split ./include/stringzilla/stringzilla.h to ./include/stringzilla/types.h --- include/stringzilla/{stringzilla.h => types.h} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename include/stringzilla/{stringzilla.h => types.h} (100%) diff --git a/include/stringzilla/stringzilla.h b/include/stringzilla/types.h similarity index 100% rename from include/stringzilla/stringzilla.h rename to include/stringzilla/types.h From 22e3d1e34d62d68c1e89df7c8bdc201faa18a9de Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 11:18:34 +0000 Subject: [PATCH 011/751] Make: Split ./include/stringzilla/stringzilla.h to ./include/stringzilla/types.h --- include/stringzilla/stringzilla.h => temp | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename include/stringzilla/stringzilla.h => temp (100%) diff --git a/include/stringzilla/stringzilla.h b/temp similarity index 100% rename from include/stringzilla/stringzilla.h rename to temp From 8cb0742b2d1b31b61fac5272f17017953c6677e6 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 11:18:34 +0000 Subject: [PATCH 012/751] Make: Split ./include/stringzilla/stringzilla.h to ./include/stringzilla/types.h --- temp => include/stringzilla/stringzilla.h | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename temp => include/stringzilla/stringzilla.h (100%) diff --git a/temp b/include/stringzilla/stringzilla.h similarity index 100% rename from temp rename to include/stringzilla/stringzilla.h From 9e577be71dcd2e20854bf55f08c54854b3e82989 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 11:19:11 +0000 Subject: [PATCH 013/751] Make: Split ./include/stringzilla/stringzilla.h to 
./include/stringzilla/find.h --- include/stringzilla/{stringzilla.h => find.h} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename include/stringzilla/{stringzilla.h => find.h} (100%) diff --git a/include/stringzilla/stringzilla.h b/include/stringzilla/find.h similarity index 100% rename from include/stringzilla/stringzilla.h rename to include/stringzilla/find.h From 14ba3bf3c43408438a7de9ad57118c747c1347b1 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 11:19:11 +0000 Subject: [PATCH 014/751] Make: Split ./include/stringzilla/stringzilla.h to ./include/stringzilla/find.h --- include/stringzilla/stringzilla.h => temp | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename include/stringzilla/stringzilla.h => temp (100%) diff --git a/include/stringzilla/stringzilla.h b/temp similarity index 100% rename from include/stringzilla/stringzilla.h rename to temp From 974ed78822dc0b519dd61bc1c4dc18d59fe4ad15 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 11:19:11 +0000 Subject: [PATCH 015/751] Make: Split ./include/stringzilla/stringzilla.h to ./include/stringzilla/find.h --- temp => include/stringzilla/stringzilla.h | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename temp => include/stringzilla/stringzilla.h (100%) diff --git a/temp b/include/stringzilla/stringzilla.h similarity index 100% rename from temp rename to include/stringzilla/stringzilla.h From 9e9f2567d052d635722921a1d70ec63d69ec6669 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 11:19:29 +0000 Subject: [PATCH 016/751] Make: Split ./include/stringzilla/stringzilla.h to ./include/stringzilla/hash.h --- include/stringzilla/{stringzilla.h => hash.h} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename include/stringzilla/{stringzilla.h => hash.h} (100%) diff --git a/include/stringzilla/stringzilla.h b/include/stringzilla/hash.h similarity index 100% rename from include/stringzilla/stringzilla.h rename to include/stringzilla/hash.h From 08d0a20d35d3b29a44b9c8a826d53435c3ef839c Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 11:19:29 +0000 Subject: [PATCH 017/751] Make: Split ./include/stringzilla/stringzilla.h to ./include/stringzilla/hash.h --- include/stringzilla/stringzilla.h => temp | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename include/stringzilla/stringzilla.h => temp (100%) diff --git a/include/stringzilla/stringzilla.h b/temp similarity index 100% rename from include/stringzilla/stringzilla.h rename to temp From 1f60e6d7c81f0e285e594eb63fee6119e05a3e69 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 11:19:29 +0000 Subject: [PATCH 018/751] Make: Split ./include/stringzilla/stringzilla.h to ./include/stringzilla/hash.h --- temp => include/stringzilla/stringzilla.h | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename temp => include/stringzilla/stringzilla.h (100%) diff --git a/temp b/include/stringzilla/stringzilla.h similarity index 100% rename from temp rename to include/stringzilla/stringzilla.h From d74e5dca2e62eb0078cb2ebacc0dac2b8bb92d54 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 11:20:12 +0000 Subject: [PATCH 019/751] Make: Split ./include/stringzilla/stringzilla.h to 
./include/stringzilla/similarity.h --- include/stringzilla/{stringzilla.h => similarity.h} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename include/stringzilla/{stringzilla.h => similarity.h} (100%) diff --git a/include/stringzilla/stringzilla.h b/include/stringzilla/similarity.h similarity index 100% rename from include/stringzilla/stringzilla.h rename to include/stringzilla/similarity.h From 10d829efcb8ed4cfa5f2db4050f8403184484423 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 11:20:12 +0000 Subject: [PATCH 020/751] Make: Split ./include/stringzilla/stringzilla.h to ./include/stringzilla/similarity.h --- include/stringzilla/stringzilla.h => temp | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename include/stringzilla/stringzilla.h => temp (100%) diff --git a/include/stringzilla/stringzilla.h b/temp similarity index 100% rename from include/stringzilla/stringzilla.h rename to temp From e23c35ff2c2d4ccb752f4ffbf9b6f39a1677b532 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 11:20:12 +0000 Subject: [PATCH 021/751] Make: Split ./include/stringzilla/stringzilla.h to ./include/stringzilla/similarity.h --- temp => include/stringzilla/stringzilla.h | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename temp => include/stringzilla/stringzilla.h (100%) diff --git a/temp b/include/stringzilla/stringzilla.h similarity index 100% rename from temp rename to include/stringzilla/stringzilla.h From 3f9c248fbf59add2246055462e8fc19dc9f1693b Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 11:21:30 +0000 Subject: [PATCH 022/751] Make: Split ./include/stringzilla/stringzilla.h to ./include/stringzilla/small_string.h --- include/stringzilla/{stringzilla.h => small_string.h} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename include/stringzilla/{stringzilla.h => small_string.h} (100%) diff --git a/include/stringzilla/stringzilla.h b/include/stringzilla/small_string.h similarity index 100% rename from include/stringzilla/stringzilla.h rename to include/stringzilla/small_string.h From 89c46810c2f9bfafa31f8592339f9a1b45dcc245 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 11:21:30 +0000 Subject: [PATCH 023/751] Make: Split ./include/stringzilla/stringzilla.h to ./include/stringzilla/small_string.h --- include/stringzilla/stringzilla.h => temp | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename include/stringzilla/stringzilla.h => temp (100%) diff --git a/include/stringzilla/stringzilla.h b/temp similarity index 100% rename from include/stringzilla/stringzilla.h rename to temp From 3464cb428ae9a8721ab82a8c4bff214aa9ce6254 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 11:21:30 +0000 Subject: [PATCH 024/751] Make: Split ./include/stringzilla/stringzilla.h to ./include/stringzilla/small_string.h --- temp => include/stringzilla/stringzilla.h | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename temp => include/stringzilla/stringzilla.h (100%) diff --git a/temp b/include/stringzilla/stringzilla.h similarity index 100% rename from temp rename to include/stringzilla/stringzilla.h From 085d2d3c8b99e0f90d320dd027040e554e410929 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 11:22:12 
+0000 Subject: [PATCH 025/751] Make: Split ./include/stringzilla/stringzilla.h to ./include/stringzilla/sort.h --- include/stringzilla/{stringzilla.h => sort.h} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename include/stringzilla/{stringzilla.h => sort.h} (100%) diff --git a/include/stringzilla/stringzilla.h b/include/stringzilla/sort.h similarity index 100% rename from include/stringzilla/stringzilla.h rename to include/stringzilla/sort.h From cbfe5c7ac6371047eae88621b092297474d0b82a Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 11:22:12 +0000 Subject: [PATCH 026/751] Make: Split ./include/stringzilla/stringzilla.h to ./include/stringzilla/sort.h --- include/stringzilla/stringzilla.h => temp | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename include/stringzilla/stringzilla.h => temp (100%) diff --git a/include/stringzilla/stringzilla.h b/temp similarity index 100% rename from include/stringzilla/stringzilla.h rename to temp From c357c3ea756523d3bcc8d8f25068ad08aef5456d Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 11:22:12 +0000 Subject: [PATCH 027/751] Make: Split ./include/stringzilla/stringzilla.h to ./include/stringzilla/sort.h --- temp => include/stringzilla/stringzilla.h | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename temp => include/stringzilla/stringzilla.h (100%) diff --git a/temp b/include/stringzilla/stringzilla.h similarity index 100% rename from temp rename to include/stringzilla/stringzilla.h From 66778d6b2b3aa0eed27e32fbdceef79b8c54eda5 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 14:14:26 +0000 Subject: [PATCH 028/751] Make: Split ./include/stringzilla/stringzilla.h to ./include/stringzilla/memory.h --- include/stringzilla/{stringzilla.h => memory.h} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename include/stringzilla/{stringzilla.h => memory.h} (100%) diff --git a/include/stringzilla/stringzilla.h b/include/stringzilla/memory.h similarity index 100% rename from include/stringzilla/stringzilla.h rename to include/stringzilla/memory.h From 45e57eefd796841cbd14ee7f75ec42b42b5bde0c Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 14:14:26 +0000 Subject: [PATCH 029/751] Make: Split ./include/stringzilla/stringzilla.h to ./include/stringzilla/memory.h --- include/stringzilla/stringzilla.h => temp | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename include/stringzilla/stringzilla.h => temp (100%) diff --git a/include/stringzilla/stringzilla.h b/temp similarity index 100% rename from include/stringzilla/stringzilla.h rename to temp From 2f7652141bd8dc3c2c38ab34321567bfcdb91d93 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 14:14:27 +0000 Subject: [PATCH 030/751] Make: Split ./include/stringzilla/stringzilla.h to ./include/stringzilla/memory.h --- temp => include/stringzilla/stringzilla.h | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename temp => include/stringzilla/stringzilla.h (100%) diff --git a/temp b/include/stringzilla/stringzilla.h similarity index 100% rename from temp rename to include/stringzilla/stringzilla.h From 2a1fcd113d217e3124f6501c38e93a318aca37f0 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 14:48:51 
+0000 Subject: [PATCH 031/751] Fix: Filter `find.h` file --- include/stringzilla/find.h | 6944 +++++------------------------------- 1 file changed, 797 insertions(+), 6147 deletions(-) diff --git a/include/stringzilla/find.h b/include/stringzilla/find.h index de7fbcac..a51bd4c6 100644 --- a/include/stringzilla/find.h +++ b/include/stringzilla/find.h @@ -1,724 +1,32 @@ /** - * @brief StringZilla is a collection of advanced string algorithms, designed to be used in Big Data applications. - * It is generally faster than LibC, and has a broader & cleaner interface, and targets modern x86 CPUs - * with AVX-512 and Arm NEON and older CPUs with SWAR and auto-vectorization. - * - * Consider overriding the following macros to customize the library: - * - * - `SZ_DEBUG=0` - whether to enable debug assertions and logging. - * - `SZ_DYNAMIC_DISPATCH=0` - whether to use runtime dispatching of the most advanced SIMD backend. - * - `SZ_USE_MISALIGNED_LOADS=0` - whether to use misaligned loads on platforms that support them. - * - `SZ_SWAR_THRESHOLD=24` - threshold for switching to SWAR backend over serial byte-level for-loops. - * - `SZ_USE_X86_AVX512=?` - whether to use AVX-512 instructions on x86_64. - * - `SZ_USE_X86_AVX2=?` - whether to use AVX2 instructions on x86_64. - * - `SZ_USE_ARM_NEON=?` - whether to use NEON instructions on ARM. - * - `SZ_USE_ARM_SVE=?` - whether to use SVE instructions on ARM. - * - * @see StringZilla: https://github.com/ashvardanian/StringZilla/blob/main/README.md - * @see LibC String: https://pubs.opengroup.org/onlinepubs/009695399/basedefs/string.h.html - * - * @file stringzilla.h + * @brief Hardware-accelerated sub-string and character-set search utilities. + * @file find.h * @author Ash Vardanian - */ -#ifndef STRINGZILLA_H_ -#define STRINGZILLA_H_ - -#define STRINGZILLA_VERSION_MAJOR 3 -#define STRINGZILLA_VERSION_MINOR 11 -#define STRINGZILLA_VERSION_PATCH 0 - -/** - * @brief When set to 1, the library will include the following LibC headers: and . - * In debug builds (SZ_DEBUG=1), the library will also include and . * - * You may want to disable this compiling for use in the kernel, or in embedded systems. - * You may also avoid them, if you are very sensitive to compilation time and avoid pre-compiled headers. - * https://artificial-mind.net/projects/compile-health/ - */ -#ifndef SZ_AVOID_LIBC -#define SZ_AVOID_LIBC (0) // true or false -#endif - -/** - * @brief A misaligned load can be - trying to fetch eight consecutive bytes from an address - * that is not divisible by eight. On x86 enabled by default. On ARM it's not. + * Includes core APIs: * - * Most platforms support it, but there is no industry standard way to check for those. - * This value will mostly affect the performance of the serial (SWAR) backend. - */ -#ifndef SZ_USE_MISALIGNED_LOADS -#if defined(__x86_64__) || defined(_M_X64) || defined(__i386__) || defined(_M_IX86) -#define SZ_USE_MISALIGNED_LOADS (1) // true or false -#else -#define SZ_USE_MISALIGNED_LOADS (0) // true or false -#endif -#endif - -/** - * @brief Removes compile-time dispatching, and replaces it with runtime dispatching. - * So the `sz_find` function will invoke the most advanced backend supported by the CPU, - * that runs the program, rather than the most advanced backend supported by the CPU - * used to compile the library or the downstream application. 
- */ -#ifndef SZ_DYNAMIC_DISPATCH -#define SZ_DYNAMIC_DISPATCH (0) // true or false -#endif - -/** - * @brief Analogous to `size_t` and `std::size_t`, unsigned integer, identical to pointer size. - * 64-bit on most platforms where pointers are 64-bit. - * 32-bit on platforms where pointers are 32-bit. - */ -#if defined(__LP64__) || defined(_LP64) || defined(__x86_64__) || defined(_WIN64) -#define SZ_DETECT_64_BIT (1) -#define SZ_SIZE_MAX (0xFFFFFFFFFFFFFFFFull) // Largest unsigned integer that fits into 64 bits. -#define SZ_SSIZE_MAX (0x7FFFFFFFFFFFFFFFull) // Largest signed integer that fits into 64 bits. -#else -#define SZ_DETECT_64_BIT (0) -#define SZ_SIZE_MAX (0xFFFFFFFFu) // Largest unsigned integer that fits into 32 bits. -#define SZ_SSIZE_MAX (0x7FFFFFFFu) // Largest signed integer that fits into 32 bits. -#endif - -/** - * @brief On Big-Endian machines StringZilla will work in compatibility mode. - * This disables SWAR hacks to minimize code duplication, assuming practically - * all modern popular platforms are Little-Endian. + * - `sz_equal` + * - `sz_find` and reverse-order `sz_rfind` + * - `sz_find_byte` and reverse-order `sz_rfind_byte` + * - `sz_find_charset` and reverse-order `sz_rfind_charset` * - * This variable is hard to infer from macros reliably. It's best to set it manually. - * For that CMake provides the `TestBigEndian` and `CMAKE__BYTE_ORDER` (from 3.20 onwards). - * In Python one can check `sys.byteorder == 'big'` in the `setup.py` script and pass the appropriate macro. - * https://stackoverflow.com/a/27054190 - */ -#ifndef SZ_DETECT_BIG_ENDIAN -#if defined(__BYTE_ORDER) && __BYTE_ORDER == __BIG_ENDIAN || defined(__BIG_ENDIAN__) || defined(__ARMEB__) || \ - defined(__THUMBEB__) || defined(__AARCH64EB__) || defined(_MIBSEB) || defined(__MIBSEB) || defined(__MIBSEB__) -#define SZ_DETECT_BIG_ENDIAN (1) //< It's a big-endian target architecture -#else -#define SZ_DETECT_BIG_ENDIAN (0) //< It's a little-endian target architecture -#endif -#endif - -/* - * Debugging and testing. - */ -#ifndef SZ_DEBUG -#if defined(DEBUG) || defined(_DEBUG) // This means "Not using DEBUG information". -#define SZ_DEBUG (1) -#else -#define SZ_DEBUG (0) -#endif -#endif - -/** - * @brief Threshold for switching to SWAR (8-bytes at a time) backend over serial byte-level for-loops. - * On very short strings, under 16 bytes long, at most a single word will be processed with SWAR. - * Assuming potentially misaligned loads, SWAR makes sense only after ~24 bytes. - */ -#ifndef SZ_SWAR_THRESHOLD -#if SZ_DEBUG -#define SZ_SWAR_THRESHOLD (8u) // 8 bytes in debug builds -#else -#define SZ_SWAR_THRESHOLD (24u) // 24 bytes in release builds -#endif -#endif - -/* Annotation for the public API symbols: + * Convenience functions for character-set matching: * - * - `SZ_PUBLIC` is used for functions that are part of the public API. - * - `SZ_INTERNAL` is used for internal helper functions with unstable APIs. - * - `SZ_DYNAMIC` is used for functions that are part of the public API, but are dispatched at runtime. 
+ * - `sz_find_char_from` + * - `sz_find_char_not_from` + * - `sz_rfind_char_from` + * - `sz_rfind_char_not_from` */ -#ifndef SZ_DYNAMIC -#if SZ_DYNAMIC_DISPATCH -#if defined(_WIN32) || defined(__CYGWIN__) -#define SZ_DYNAMIC __declspec(dllexport) -#define SZ_EXTERNAL __declspec(dllimport) -#define SZ_PUBLIC inline static -#define SZ_INTERNAL inline static -#else -#define SZ_DYNAMIC __attribute__((visibility("default"))) -#define SZ_EXTERNAL extern -#define SZ_PUBLIC __attribute__((unused)) inline static -#define SZ_INTERNAL __attribute__((always_inline)) inline static -#endif // _WIN32 || __CYGWIN__ -#else -#define SZ_DYNAMIC inline static -#define SZ_EXTERNAL extern -#define SZ_PUBLIC inline static -#define SZ_INTERNAL inline static -#endif // SZ_DYNAMIC_DISPATCH -#endif // SZ_DYNAMIC +#ifndef STRINGZILLA_FIND_H_ +#define STRINGZILLA_FIND_H_ -/** - * @brief Alignment macro for 64-byte alignment. - */ -#if defined(_MSC_VER) -#define SZ_ALIGN64 __declspec(align(64)) -#elif defined(__GNUC__) || defined(__clang__) -#define SZ_ALIGN64 __attribute__((aligned(64))) -#else -#define SZ_ALIGN64 -#endif +#include "types.h" #ifdef __cplusplus extern "C" { #endif -/* - * Let's infer the integer types or pull them from LibC, - * if that is allowed by the user. - */ -#if !SZ_AVOID_LIBC -#include // `size_t` -#include // `uint8_t` -typedef int8_t sz_i8_t; // Always 8 bits -typedef uint8_t sz_u8_t; // Always 8 bits -typedef uint16_t sz_u16_t; // Always 16 bits -typedef int32_t sz_i32_t; // Always 32 bits -typedef uint32_t sz_u32_t; // Always 32 bits -typedef uint64_t sz_u64_t; // Always 64 bits -typedef int64_t sz_i64_t; // Always 64 bits -typedef size_t sz_size_t; // Pointer-sized unsigned integer, 32 or 64 bits -typedef ptrdiff_t sz_ssize_t; // Signed version of `sz_size_t`, 32 or 64 bits - -#else // if SZ_AVOID_LIBC: - -// ! The C standard doesn't specify the signedness of char. -// ! On x86 char is signed by default while on Arm it is unsigned by default. -// ! That's why we don't define `sz_char_t` and generally use explicit `sz_i8_t` and `sz_u8_t`. -typedef signed char sz_i8_t; // Always 8 bits -typedef unsigned char sz_u8_t; // Always 8 bits -typedef unsigned short sz_u16_t; // Always 16 bits -typedef int sz_i32_t; // Always 32 bits -typedef unsigned int sz_u32_t; // Always 32 bits -typedef long long sz_i64_t; // Always 64 bits -typedef unsigned long long sz_u64_t; // Always 64 bits - -// Now we need to redefine the `size_t`. -// Microsoft Visual C++ (MSVC) typically follows LLP64 data model on 64-bit platforms, -// where integers, pointers, and long types have different sizes: -// -// > `int` is 32 bits -// > `long` is 32 bits -// > `long long` is 64 bits -// > pointer (thus, `size_t`) is 64 bits -// -// In contrast, GCC and Clang on 64-bit Unix-like systems typically follow the LP64 model, where: -// -// > `int` is 32 bits -// > `long` and pointer (thus, `size_t`) are 64 bits -// > `long long` is also 64 bits -// -// Source: https://learn.microsoft.com/en-us/windows/win32/winprog64/abstract-data-models -#if SZ_DETECT_64_BIT -typedef unsigned long long sz_size_t; // 64-bit. -typedef long long sz_ssize_t; // 64-bit. -#else -typedef unsigned sz_size_t; // 32-bit. -typedef unsigned sz_ssize_t; // 32-bit. -#endif // SZ_DETECT_64_BIT - -#endif // SZ_AVOID_LIBC - -/** - * @brief Compile-time assert macro similar to `static_assert` in C++. - */ -#define sz_static_assert(condition, name) \ - typedef struct { \ - int static_assert_##name : (condition) ? 
1 : -1; \ - } sz_static_assert_##name##_t - -sz_static_assert(sizeof(sz_size_t) == sizeof(void *), sz_size_t_must_be_pointer_size); -sz_static_assert(sizeof(sz_ssize_t) == sizeof(void *), sz_ssize_t_must_be_pointer_size); - -#pragma region Public API - -typedef char *sz_ptr_t; // A type alias for `char *` -typedef char const *sz_cptr_t; // A type alias for `char const *` -typedef sz_i8_t sz_error_cost_t; // Character mismatch cost for fuzzy matching functions - -typedef sz_u64_t sz_sorted_idx_t; // Index of a sorted string in a list of strings - -typedef enum { sz_false_k = 0, sz_true_k = 1 } sz_bool_t; // Only one relevant bit -typedef enum { sz_less_k = -1, sz_equal_k = 0, sz_greater_k = 1 } sz_ordering_t; // Only three possible states: <=> - -/** - * @brief Tiny string-view structure. It's POD type, unlike the `std::string_view`. - */ -typedef struct sz_string_view_t { - sz_cptr_t start; - sz_size_t length; -} sz_string_view_t; - -/** - * @brief Enumeration of SIMD capabilities of the target architecture. - * Used to introspect the supported functionality of the dynamic library. - */ -typedef enum sz_capability_t { - sz_cap_serial_k = 1, /// Serial (non-SIMD) capability - sz_cap_any_k = 0x7FFFFFFF, /// Mask representing any capability - - sz_cap_arm_neon_k = 1 << 10, /// ARM NEON capability - sz_cap_arm_sve_k = 1 << 11, /// ARM SVE capability TODO: Not yet supported or used - sz_cap_arm_sve2_k = 1 << 12, - sz_cap_arm_sve2p1_k = 1 << 13, - sz_cap_x86_avx2_k = 1 << 20, /// x86 AVX2 capability - sz_cap_x86_avx512f_k = 1 << 21, /// x86 AVX512 F capability - sz_cap_x86_avx512bw_k = 1 << 22, /// x86 AVX512 BW instruction capability - sz_cap_x86_avx512vl_k = 1 << 23, /// x86 AVX512 VL instruction capability - sz_cap_x86_avx512vbmi_k = 1 << 24, /// x86 AVX512 VBMI instruction capability - sz_cap_x86_gfni_k = 1 << 25, /// x86 AVX512 GFNI instruction capability - -} sz_capability_t; - -/** - * @brief Function to determine the SIMD capabilities of the current machine @b only at @b runtime. - * @return A bitmask of the SIMD capabilities represented as a `sz_capability_t` enum value. - */ -SZ_DYNAMIC sz_capability_t sz_capabilities(void); - -/** - * @brief Bit-set structure for 256 possible byte values. Useful for filtering and search. - * @see sz_charset_init, sz_charset_add, sz_charset_contains, sz_charset_invert - */ -typedef union sz_charset_t { - sz_u64_t _u64s[4]; - sz_u32_t _u32s[8]; - sz_u16_t _u16s[16]; - sz_u8_t _u8s[32]; -} sz_charset_t; - -/** @brief Initializes a bit-set to an empty collection, meaning - all characters are banned. */ -SZ_PUBLIC void sz_charset_init(sz_charset_t *s) { s->_u64s[0] = s->_u64s[1] = s->_u64s[2] = s->_u64s[3] = 0; } - -/** @brief Adds a character to the set and accepts @b unsigned integers. */ -SZ_PUBLIC void sz_charset_add_u8(sz_charset_t *s, sz_u8_t c) { s->_u64s[c >> 6] |= (1ull << (c & 63u)); } - -/** @brief Adds a character to the set. Consider @b sz_charset_add_u8. */ -SZ_PUBLIC void sz_charset_add(sz_charset_t *s, char c) { sz_charset_add_u8(s, *(sz_u8_t *)(&c)); } // bitcast - -/** @brief Checks if the set contains a given character and accepts @b unsigned integers. 
*/ -SZ_PUBLIC sz_bool_t sz_charset_contains_u8(sz_charset_t const *s, sz_u8_t c) { - // Checking the bit can be done in different ways: - // - (s->_u64s[c >> 6] & (1ull << (c & 63u))) != 0 - // - (s->_u32s[c >> 5] & (1u << (c & 31u))) != 0 - // - (s->_u16s[c >> 4] & (1u << (c & 15u))) != 0 - // - (s->_u8s[c >> 3] & (1u << (c & 7u))) != 0 - return (sz_bool_t)((s->_u64s[c >> 6] & (1ull << (c & 63u))) != 0); -} - -/** @brief Checks if the set contains a given character. Consider @b sz_charset_contains_u8. */ -SZ_PUBLIC sz_bool_t sz_charset_contains(sz_charset_t const *s, char c) { - return sz_charset_contains_u8(s, *(sz_u8_t *)(&c)); // bitcast -} - -/** @brief Inverts the contents of the set, so allowed character get disallowed, and vice versa. */ -SZ_PUBLIC void sz_charset_invert(sz_charset_t *s) { - s->_u64s[0] ^= 0xFFFFFFFFFFFFFFFFull, s->_u64s[1] ^= 0xFFFFFFFFFFFFFFFFull, // - s->_u64s[2] ^= 0xFFFFFFFFFFFFFFFFull, s->_u64s[3] ^= 0xFFFFFFFFFFFFFFFFull; -} - -typedef void *(*sz_memory_allocate_t)(sz_size_t, void *); -typedef void (*sz_memory_free_t)(void *, sz_size_t, void *); -typedef sz_u64_t (*sz_random_generator_t)(void *); - -/** - * @brief Some complex pattern matching algorithms may require memory allocations. - * This structure is used to pass the memory allocator to those functions. - * @see sz_memory_allocator_init_fixed - */ -typedef struct sz_memory_allocator_t { - sz_memory_allocate_t allocate; - sz_memory_free_t free; - void *handle; -} sz_memory_allocator_t; - -/** - * @brief Initializes a memory allocator to use the system default `malloc` and `free`. - * ! The function is not available if the library was compiled with `SZ_AVOID_LIBC`. - * - * @param alloc Memory allocator to initialize. - */ -SZ_PUBLIC void sz_memory_allocator_init_default(sz_memory_allocator_t *alloc); - -/** - * @brief Initializes a memory allocator to use a static-capacity buffer. - * No dynamic allocations will be performed. - * - * @param alloc Memory allocator to initialize. - * @param buffer Buffer to use for allocations. - * @param length Length of the buffer. @b Must be greater than 8 bytes. Different values would be optimal for - * different algorithms and input lengths, but 4096 bytes (one RAM page) is a good default. - */ -SZ_PUBLIC void sz_memory_allocator_init_fixed(sz_memory_allocator_t *alloc, void *buffer, sz_size_t length); - -/** - * @brief The number of bytes a stack-allocated string can hold, including the SZ_NULL termination character. - * ! This can't be changed from outside. Don't use the `#error` as it may already be included and set. - */ -#ifdef SZ_STRING_INTERNAL_SPACE -#undef SZ_STRING_INTERNAL_SPACE -#endif -#define SZ_STRING_INTERNAL_SPACE (sizeof(sz_size_t) * 3 - 1) // 3 pointers minus one byte for an 8-bit length - -/** - * @brief Tiny memory-owning string structure with a Small String Optimization (SSO). - * Differs in layout from Folly, Clang, GCC, and probably most other implementations. - * It's designed to avoid any branches on read-only operations, and can store up - * to 22 characters on stack on 64-bit machines, followed by the SZ_NULL-termination character. - * - * @section Changing Length - * - * One nice thing about this design, is that you can, in many cases, change the length of the string - * without any branches, invoking a `+=` or `-=` on the 64-bit `length` field. If the string is on heap, - * the solution is obvious. If it's on stack, inplace decrement wouldn't affect the top bytes of the string, - * only changing the last byte containing the length. 
- */ -typedef union sz_string_t { - -#if !SZ_DETECT_BIG_ENDIAN - - struct external { - sz_ptr_t start; - sz_size_t length; - sz_size_t space; - sz_size_t padding; - } external; - - struct internal { - sz_ptr_t start; - sz_u8_t length; - char chars[SZ_STRING_INTERNAL_SPACE]; - } internal; - -#else - - struct external { - sz_ptr_t start; - sz_size_t space; - sz_size_t padding; - sz_size_t length; - } external; - - struct internal { - sz_ptr_t start; - char chars[SZ_STRING_INTERNAL_SPACE]; - sz_u8_t length; - } internal; - -#endif - - sz_size_t words[4]; - -} sz_string_t; - -typedef sz_u64_t (*sz_hash_t)(sz_cptr_t, sz_size_t); -typedef sz_u64_t (*sz_checksum_t)(sz_cptr_t, sz_size_t); -typedef sz_bool_t (*sz_equal_t)(sz_cptr_t, sz_cptr_t, sz_size_t); -typedef sz_ordering_t (*sz_order_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); -typedef void (*sz_to_converter_t)(sz_cptr_t, sz_size_t, sz_ptr_t); - -/** - * @brief Computes the 64-bit check-sum of bytes in a string. - * Similar to `std::ranges::accumulate`. - * - * @param text String to aggregate. - * @param length Number of bytes in the text. - * @return 64-bit unsigned value. - */ -SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length); - -/** @copydoc sz_checksum */ -SZ_PUBLIC sz_u64_t sz_checksum_serial(sz_cptr_t text, sz_size_t length); - -/** - * @brief Computes the 64-bit unsigned hash of a string. Fairly fast for short strings, - * simple implementation, and supports rolling computation, reused in other APIs. - * Similar to `std::hash` in C++. - * - * @param text String to hash. - * @param length Number of bytes in the text. - * @return 64-bit hash value. - * - * @see sz_hashes, sz_hashes_fingerprint, sz_hashes_intersection - */ -SZ_PUBLIC sz_u64_t sz_hash(sz_cptr_t text, sz_size_t length); - -/** @copydoc sz_hash */ -SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t text, sz_size_t length); - -/** - * @brief Checks if two string are equal. - * Similar to `memcmp(a, b, length) == 0` in LibC and `a == b` in STL. - * - * The implementation of this function is very similar to `sz_order`, but the usage patterns are different. - * This function is more often used in parsing, while `sz_order` is often used in sorting. - * It works best on platforms with cheap - * - * @param a First string to compare. - * @param b Second string to compare. - * @param length Number of bytes in both strings. - * @return 1 if strings match, 0 otherwise. - */ -SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length); - -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_serial(sz_cptr_t a, sz_cptr_t b, sz_size_t length); - -/** - * @brief Estimates the relative order of two strings. Equivalent to `memcmp(a, b, length)` in LibC. - * Can be used on different length strings. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * @return Negative if (a < b), positive if (a > b), zero if they are equal. - */ -SZ_DYNAMIC sz_ordering_t sz_order(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); - -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); - -/** - * @brief Look Up Table @b (LUT) transformation of a string. Equivalent to `for (char & c : text) c = lut[c]`. 
- * - * Can be used to implement some form of string normalization, partially masking punctuation marks, - * or converting between different character sets, like uppercase or lowercase. Surprisingly, also has - * broad implications in image processing, where image channel transformations are often done using LUTs. - * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param lut Look Up Table to apply. Must be exactly @b 256 bytes long. - * @param result Output string, can point to the same address as ::text. - */ -SZ_DYNAMIC void sz_look_up_transform(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result); - -typedef void (*sz_look_up_transform_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_ptr_t); - -/** @copydoc sz_look_up_transform */ -SZ_PUBLIC void sz_look_up_transform_serial(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result); - -/** - * @brief Equivalent to `for (char & c : text) c = tolower(c)`. - * - * ASCII characters [A, Z] map to decimals [65, 90], and [a, z] map to [97, 122]. - * So there are 26 english letters, shifted by 32 values, meaning that a conversion - * can be done by flipping the 5th bit each inappropriate character byte. This, however, - * breaks for extended ASCII, so a different solution is needed. - * http://0x80.pl/notesen/2016-01-06-swar-swap-case.html - * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param result Output string, can point to the same address as ::text. - */ -SZ_PUBLIC void sz_tolower(sz_cptr_t text, sz_size_t length, sz_ptr_t result); - -/** - * @brief Equivalent to `for (char & c : text) c = toupper(c)`. - * - * ASCII characters [A, Z] map to decimals [65, 90], and [a, z] map to [97, 122]. - * So there are 26 english letters, shifted by 32 values, meaning that a conversion - * can be done by flipping the 5th bit each inappropriate character byte. This, however, - * breaks for extended ASCII, so a different solution is needed. - * http://0x80.pl/notesen/2016-01-06-swar-swap-case.html - * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param result Output string, can point to the same address as ::text. - */ -SZ_PUBLIC void sz_toupper(sz_cptr_t text, sz_size_t length, sz_ptr_t result); - -/** - * @brief Equivalent to `for (char & c : text) c = toascii(c)`. - * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param result Output string, can point to the same address as ::text. - */ -SZ_PUBLIC void sz_toascii(sz_cptr_t text, sz_size_t length, sz_ptr_t result); - -/** - * @brief Checks if all characters in the range are valid ASCII characters. - * - * @param text String to be analyzed. - * @param length Number of bytes in the string. - * @return Whether all characters are valid ASCII characters. - */ -SZ_PUBLIC sz_bool_t sz_isascii(sz_cptr_t text, sz_size_t length); - -/** - * @brief Generates a random string for a given alphabet, avoiding integer division and modulo operations. - * Similar to `text[i] = alphabet[rand() % cardinality]`. - * - * The modulo operation is expensive, and should be avoided in performance-critical code. - * We avoid it using small lookup tables and replacing it with a multiplication and shifts, similar to `libdivide`. 
- * Alternative algorithms would include: - * - Montgomery form: https://en.algorithmica.org/hpc/number-theory/montgomery/ - * - Barret reduction: https://www.nayuki.io/page/barrett-reduction-algorithm - * - Lemire's trick: https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/ - * - * @param alphabet Set of characters to sample from. - * @param cardinality Number of characters to sample from. - * @param text Output string, can point to the same address as ::text. - * @param generate Callback producing random numbers given the generator state. - * @param generator Generator state, can be a pointer to a seed, or a pointer to a random number generator. - */ -SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length, - sz_random_generator_t generate, void *generator); - -/** @copydoc sz_generate */ -SZ_PUBLIC void sz_generate_serial(sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length, - sz_random_generator_t generate, void *generator); - -/** - * @brief Similar to `memcpy`, copies contents of one string into another. - * The behavior is undefined if the strings overlap. - * - * @param target String to copy into. - * @param length Number of bytes to copy. - * @param source String to copy from. - */ -SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length); - -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length); - -/** - * @brief Similar to `memmove`, copies (moves) contents of one string into another. - * Unlike `sz_copy`, allows overlapping strings as arguments. - * - * @param target String to copy into. - * @param length Number of bytes to copy. - * @param source String to copy from. - */ -SZ_DYNAMIC void sz_move(sz_ptr_t target, sz_cptr_t source, sz_size_t length); - -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length); - -typedef void (*sz_move_t)(sz_ptr_t, sz_cptr_t, sz_size_t); - -/** - * @brief Similar to `memset`, fills a string with a given value. - * - * @param target String to fill. - * @param length Number of bytes to fill. - * @param value Value to fill with. - */ -SZ_DYNAMIC void sz_fill(sz_ptr_t target, sz_size_t length, sz_u8_t value); - -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_serial(sz_ptr_t target, sz_size_t length, sz_u8_t value); - -typedef void (*sz_fill_t)(sz_ptr_t, sz_size_t, sz_u8_t); - -/** - * @brief Initializes a string class instance to an empty value. - */ -SZ_PUBLIC void sz_string_init(sz_string_t *string); - -/** - * @brief Convenience function checking if the provided string is stored inside of the ::string instance itself, - * alternative being - allocated in a remote region of the heap. - */ -SZ_PUBLIC sz_bool_t sz_string_is_on_stack(sz_string_t const *string); - -/** - * @brief Unpacks the opaque instance of a string class into its components. - * Recommended to use only in read-only operations. - * - * @param string String to unpack. - * @param start Pointer to the start of the string. - * @param length Number of bytes in the string, before the SZ_NULL character. - * @param space Number of bytes allocated for the string (heap or stack), including the SZ_NULL character. - * @param is_external Whether the string is allocated on the heap externally, or fits withing ::string instance. 
- */ -SZ_PUBLIC void sz_string_unpack(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length, sz_size_t *space, - sz_bool_t *is_external); - -/** - * @brief Unpacks only the start and length of the string. - * Recommended to use only in read-only operations. - * - * @param string String to unpack. - * @param start Pointer to the start of the string. - * @param length Number of bytes in the string, before the SZ_NULL character. - */ -SZ_PUBLIC void sz_string_range(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length); - -/** - * @brief Constructs a string of a given ::length with noisy contents. - * Use the returned character pointer to populate the string. - * - * @param string String to initialize. - * @param length Number of bytes in the string, before the SZ_NULL character. - * @param allocator Memory allocator to use for the allocation. - * @return SZ_NULL if the operation failed, pointer to the start of the string otherwise. - */ -SZ_PUBLIC sz_ptr_t sz_string_init_length(sz_string_t *string, sz_size_t length, sz_memory_allocator_t *allocator); - -/** - * @brief Doesn't change the contents or the length of the string, but grows the available memory capacity. - * This is beneficial, if several insertions are expected, and we want to minimize allocations. - * - * @param string String to grow. - * @param new_capacity The number of characters to reserve space for, including existing ones. - * @param allocator Memory allocator to use for the allocation. - * @return SZ_NULL if the operation failed, pointer to the new start of the string otherwise. - */ -SZ_PUBLIC sz_ptr_t sz_string_reserve(sz_string_t *string, sz_size_t new_capacity, sz_memory_allocator_t *allocator); - -/** - * @brief Grows the string by adding an uninitialized region of ::added_length at the given ::offset. - * Would often be used in conjunction with one or more `sz_copy` calls to populate the allocated region. - * Similar to `sz_string_reserve`, but changes the length of the ::string. - * - * @param string String to grow. - * @param offset Offset of the first byte to reserve space for. - * If provided offset is larger than the length, it will be capped. - * @param added_length The number of new characters to reserve space for. - * @param allocator Memory allocator to use for the allocation. - * @return SZ_NULL if the operation failed, pointer to the new start of the string otherwise. - */ -SZ_PUBLIC sz_ptr_t sz_string_expand(sz_string_t *string, sz_size_t offset, sz_size_t added_length, - sz_memory_allocator_t *allocator); - -/** - * @brief Removes a range from a string. Changes the length, but not the capacity. - * Performs no allocations or deallocations and can't fail. - * - * @param string String to clean. - * @param offset Offset of the first byte to remove. - * @param length Number of bytes to remove. Out-of-bound ranges will be capped. - * @return Number of bytes removed. - */ -SZ_PUBLIC sz_size_t sz_string_erase(sz_string_t *string, sz_size_t offset, sz_size_t length); - -/** - * @brief Shrinks the string to fit the current length, if it's allocated on the heap. - * It's the reverse operation of ::sz_string_reserve. - * - * @param string String to shrink. - * @param allocator Memory allocator to use for the allocation. - * @return Whether the operation was successful. The only failures can come from the allocator. - * On failure, the string will remain unchanged. 
- */ -SZ_PUBLIC sz_ptr_t sz_string_shrink_to_fit(sz_string_t *string, sz_memory_allocator_t *allocator); - -/** - * @brief Frees the string, if it's allocated on the heap. - * If the string is on the stack, the function clears/resets the state. - */ -SZ_PUBLIC void sz_string_free(sz_string_t *string, sz_memory_allocator_t *allocator); - -#pragma endregion - -#pragma region Fast Substring Search API - -typedef sz_cptr_t (*sz_find_byte_t)(sz_cptr_t, sz_size_t, sz_cptr_t); -typedef sz_cptr_t (*sz_find_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); -typedef sz_cptr_t (*sz_find_set_t)(sz_cptr_t, sz_size_t, sz_charset_t const *); +#pragma region Core API /** * @brief Locates first matching byte in a string. Equivalent to `memchr(haystack, *needle, h_length)` in LibC. @@ -733,9 +41,6 @@ typedef sz_cptr_t (*sz_find_set_t)(sz_cptr_t, sz_size_t, sz_charset_t const *); */ SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); - /** * @brief Locates last matching byte in a string. Equivalent to `memrchr(haystack, *needle, h_length)` in LibC. * @@ -749,9 +54,32 @@ SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t haystack, sz_size_t h_length, */ SZ_DYNAMIC sz_cptr_t sz_rfind_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +/** @copydoc sz_find_byte */ +SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); /** @copydoc sz_rfind_byte */ SZ_PUBLIC sz_cptr_t sz_rfind_byte_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +#if SZ_USE_HASWELL +/** @copydoc sz_find_byte */ +SZ_PUBLIC sz_cptr_t sz_find_byte_haswell(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +/** @copydoc sz_rfind_byte */ +SZ_PUBLIC sz_cptr_t sz_rfind_byte_haswell(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +#endif + +#if SZ_USE_SKYLAKE +/** @copydoc sz_find_byte */ +SZ_PUBLIC sz_cptr_t sz_find_byte_skylake(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +/** @copydoc sz_rfind_byte */ +SZ_PUBLIC sz_cptr_t sz_rfind_byte_skylake(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +#endif + +#if SZ_USE_NEON +/** @copydoc sz_find_byte */ +SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +/** @copydoc sz_rfind_byte */ +SZ_PUBLIC sz_cptr_t sz_rfind_byte_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +#endif + /** * @brief Locates first matching substring. * Equivalent to `memmem(haystack, h_length, needle, n_length)` in LibC. @@ -765,9 +93,6 @@ SZ_PUBLIC sz_cptr_t sz_rfind_byte_serial(sz_cptr_t haystack, sz_size_t h_length, */ SZ_DYNAMIC sz_cptr_t sz_find(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); - /** * @brief Locates the last matching substring. 
* @@ -779,29 +104,49 @@ SZ_PUBLIC sz_cptr_t sz_find_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cp */ SZ_DYNAMIC sz_cptr_t sz_rfind(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); +/** @copydoc sz_find */ +SZ_PUBLIC sz_cptr_t sz_find_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); /** @copydoc sz_rfind */ SZ_PUBLIC sz_cptr_t sz_rfind_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** - * @brief Finds the first character present from the ::set, present in ::text. - * Equivalent to `strspn(text, accepted)` and `strcspn(text, rejected)` in LibC. - * May have identical implementation and performance to ::sz_rfind_charset. - * - * Useful for parsing, when we want to skip a set of characters. Examples: - * * 6 whitespaces: " \t\n\r\v\f". - * * 16 digits forming a float number: "0123456789,.eE+-". - * * 5 HTML reserved characters: "\"'&<>", of which "<>" can be useful for parsing. - * * 2 JSON string special characters useful to locate the end of the string: "\"\\". - * +#if SZ_USE_HASWELL +/** @copydoc sz_find */ +SZ_PUBLIC sz_cptr_t sz_find_haswell(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +/** @copydoc sz_rfind */ +SZ_PUBLIC sz_cptr_t sz_rfind_haswell(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +#endif + +#if SZ_USE_SKYLAKE +/** @copydoc sz_find */ +SZ_PUBLIC sz_cptr_t sz_find_skylake(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +/** @copydoc sz_rfind */ +SZ_PUBLIC sz_cptr_t sz_rfind_skylake(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +#endif + +#if SZ_USE_NEON +/** @copydoc sz_find */ +SZ_PUBLIC sz_cptr_t sz_find_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +/** @copydoc sz_rfind */ +SZ_PUBLIC sz_cptr_t sz_rfind_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +#endif + +/** + * @brief Finds the first character present from the ::set, present in ::text. + * Equivalent to `strspn(text, accepted)` and `strcspn(text, rejected)` in LibC. + * May have identical implementation and performance to ::sz_rfind_charset. + * + * Useful for parsing, when we want to skip a set of characters. Examples: + * * 6 whitespaces: " \t\n\r\v\f". + * * 16 digits forming a float number: "0123456789,.eE+-". + * * 5 HTML reserved characters: "\"'&<>", of which "<>" can be useful for parsing. + * * 2 JSON string special characters useful to locate the end of the string: "\"\\". + * * @param text String to be scanned. * @param set Set of relevant characters. * @return Pointer to the first matching character from ::set. */ SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); - /** * @brief Finds the last character present from the ::set, present in ::text. * Equivalent to `strspn(text, accepted)` and `strcspn(text, rejected)` in LibC. 
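[Editor's note, not part of the patch: a minimal usage sketch for the character-set search documented in this hunk. It assumes only the declarations shown above (`sz_charset_init`, `sz_charset_add`, `sz_find_charset`) and the umbrella `stringzilla.h` header; the "not found" return convention is an assumption and should be checked against the library's actual behavior.]

    #include <stringzilla/stringzilla.h>

    /* Locate the first ASCII whitespace byte in `text`, using the 256-bit `sz_charset_t` filter. */
    static sz_cptr_t find_first_whitespace(sz_cptr_t text, sz_size_t length) {
        char const *whitespaces = " \t\n\r\v\f"; /* the six whitespaces mentioned in the docs above */
        sz_charset_t set;
        sz_charset_init(&set); /* start from an empty set (every byte banned) */
        for (char const *c = whitespaces; *c; ++c) sz_charset_add(&set, *c);
        return sz_find_charset(text, length, &set); /* assumed to return a null pointer when nothing matches */
    }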
@@ -819,3406 +164,680 @@ SZ_PUBLIC sz_cptr_t sz_find_charset_serial(sz_cptr_t text, sz_size_t length, sz_ */ SZ_DYNAMIC sz_cptr_t sz_rfind_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); +/** @copydoc sz_find_charset */ +SZ_PUBLIC sz_cptr_t sz_find_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); /** @copydoc sz_rfind_charset */ SZ_PUBLIC sz_cptr_t sz_rfind_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -#pragma endregion - -#pragma region String Similarity Measures API - -/** - * @brief Computes the Hamming distance between two strings - number of not matching characters. - * Difference in length is is counted as a mismatch. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * - * @param bound Upper bound on the distance, that allows us to exit early. - * If zero is passed, the maximum possible distance will be equal to the length of the longer input. - * @return Unsigned integer for the distance, the `bound` if was exceeded. - * - * @see sz_hamming_distance_utf8 - * @see https://en.wikipedia.org/wiki/Hamming_distance - */ -SZ_DYNAMIC sz_size_t sz_hamming_distance( // - sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound); +#if SZ_USE_HASWELL +/** @copydoc sz_find_charset */ +SZ_PUBLIC sz_cptr_t sz_find_charset_haswell(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +/** @copydoc sz_rfind_charset */ +SZ_PUBLIC sz_cptr_t sz_rfind_charset_haswell(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +#endif -/** @copydoc sz_hamming_distance */ -SZ_PUBLIC sz_size_t sz_hamming_distance_serial( // - sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound); +#if SZ_USE_ICE +/** @copydoc sz_find_charset */ +SZ_PUBLIC sz_cptr_t sz_find_charset_ice(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +/** @copydoc sz_rfind_charset */ +SZ_PUBLIC sz_cptr_t sz_rfind_charset_ice(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +#endif -/** - * @brief Computes the Hamming distance between two @b UTF8 strings - number of not matching characters. - * Difference in length is is counted as a mismatch. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * - * @param bound Upper bound on the distance, that allows us to exit early. - * If zero is passed, the maximum possible distance will be equal to the length of the longer input. - * @return Unsigned integer for the distance, the `bound` if was exceeded. 
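[Editor's note, not part of the patch: a tiny sketch of the byte-level `sz_hamming_distance` call documented above. The expected value follows directly from the definition; passing `0` as `bound` relies on the documented convention that zero disables the early exit.]

    #include <assert.h>
    #include <stringzilla/stringzilla.h>

    int main(void) {
        /* "kitten" and "sitten" are equal-length and differ in exactly one byte. */
        sz_size_t distance = sz_hamming_distance("kitten", 6, "sitten", 6, 0);
        assert(distance == 1);
        return 0;
    }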
- * - * @see sz_hamming_distance - * @see https://en.wikipedia.org/wiki/Hamming_distance - */ -SZ_DYNAMIC sz_size_t sz_hamming_distance_utf8(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, - sz_size_t bound); +#if SZ_USE_NEON +/** @copydoc sz_find_charset */ +SZ_PUBLIC sz_cptr_t sz_find_charset_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +/** @copydoc sz_rfind_charset */ +SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +#endif -/** @copydoc sz_hamming_distance_utf8 */ -SZ_PUBLIC sz_size_t sz_hamming_distance_utf8_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, - sz_size_t bound); +#pragma endregion // Core API -typedef sz_size_t (*sz_hamming_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t); +#pragma region Serial Implementation /** - * @brief Computes the Levenshtein edit-distance between two strings using the Wagner-Fisher algorithm. - * Similar to the Needleman-Wunsch alignment algorithm. Often used in fuzzy string matching. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * - * @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated, - * so the memory usage is linear in relation to ::a_length and ::b_length. - * If SZ_NULL is passed, will initialize to the systems default `malloc`. - * @param bound Exclusive upper bound on the distance, that allows us to exit early. - * Pass `SZ_SIZE_MAX` or any value greater than `(max(a_length, b_length))` to ignore. - * Pass zero to check if the strings are equal. - * @return Unsigned integer for the edit distance. Zero means the strings are equal. - * Returns the `bound` if it was exceeded or `SZ_SIZE_MAX` if the memory allocation failed. - * - * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default - * @see https://en.wikipedia.org/wiki/Levenshtein_distance + * @brief Byte-level equality comparison between two strings. + * If unaligned loads are allowed, uses a switch-table to avoid loops on short strings. */ -SZ_DYNAMIC sz_size_t sz_edit_distance(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); - -/** @copydoc sz_edit_distance */ -SZ_PUBLIC sz_size_t sz_edit_distance_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); +SZ_PUBLIC sz_bool_t sz_equal_serial(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { + sz_cptr_t const a_end = a + length; +#if SZ_USE_MISALIGNED_LOADS + if (length >= SZ_SWAR_THRESHOLD) { + sz_u64_vec_t a_vec, b_vec; + for (; a + 8 <= a_end; a += 8, b += 8) { + a_vec = sz_u64_load(a); + b_vec = sz_u64_load(b); + if (a_vec.u64 != b_vec.u64) return sz_false_k; + } + } +#endif + while (a != a_end && *a == *b) a++, b++; + return (sz_bool_t)(a_end == a); +} /** - * @brief Computes the Levenshtein edit-distance between two @b UTF8 strings. - * Unlike `sz_edit_distance`, reports the distance in Unicode codepoints, and not in bytes. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. + * @brief Chooses the offsets of the most interesting characters in a search needle. * - * @param alloc Temporary memory allocator. 
Only some of the rows of the matrix will be allocated, - * so the memory usage is linear in relation to ::a_length and ::b_length. - * If SZ_NULL is passed, will initialize to the systems default `malloc`. - * @param bound Upper bound on the distance, that allows us to exit early. - * If zero is passed, the maximum possible distance will be equal to the length of the longer input. - * @return Unsigned integer for edit distance, the `bound` if was exceeded or `SZ_SIZE_MAX` - * if the memory allocation failed. + * Search throughput can significantly deteriorate if we are matching the wrong characters. + * Say the needle is "aXaYa", and we are comparing the first, second, and last character. + * If we use SIMD and compare many offsets at a time, comparing against "a" in every register is a waste. * - * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default, sz_edit_distance - * @see https://en.wikipedia.org/wiki/Levenshtein_distance + * Similarly, dealing with UTF8 inputs, we know that the lower bits of each character code carry more information. + * Cyrillic alphabet, for example, falls into [0x0410, 0x042F] code range for uppercase [А, Я], and + * into [0x0430, 0x044F] for lowercase [а, я]. Scanning through a text written in Russian, half of the + * bytes will carry absolutely no value and will be equal to 0x04. */ -SZ_DYNAMIC sz_size_t sz_edit_distance_utf8(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); - -typedef sz_size_t (*sz_edit_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t, sz_memory_allocator_t *); +SZ_INTERNAL void _sz_locate_needle_anomalies( // + sz_cptr_t start, sz_size_t length, // + sz_size_t *first, sz_size_t *second, sz_size_t *third) { -/** @copydoc sz_edit_distance_utf8 */ -SZ_PUBLIC sz_size_t sz_edit_distance_utf8_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); + *first = 0; + *second = length / 2; + *third = length - 1; -/** - * @brief Computes Needleman–Wunsch alignment score for two string. Often used in bioinformatics and cheminformatics. - * Similar to the Levenshtein edit-distance, parameterized for gap and substitution penalties. - * - * Not commutative in the general case, as the order of the strings matters, as `sz_alignment_score(a, b)` may - * not be equal to `sz_alignment_score(b, a)`. Becomes @b commutative, if the substitution costs are symmetric. - * Equivalent to the negative Levenshtein distance, if: `gap == -1` and `subs[i][j] == (i == j ? 0: -1)`. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * @param gap Penalty cost for gaps - insertions and removals. - * @param subs Substitution costs matrix with 256 x 256 values for all pairs of characters. - * - * @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated, - * so the memory usage is linear in relation to ::a_length and ::b_length. - * If SZ_NULL is passed, will initialize to the systems default `malloc`. - * @return Signed similarity score. Can be negative, depending on the substitution costs. - * If the memory allocation fails, the function returns `SZ_SSIZE_MAX`. 
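[Editor's note, not part of the patch: a hedged sketch of the `sz_alignment_score` call documented above, using the unit-cost setup under which the docs equate the score with the negative Levenshtein distance. The 4096-byte scratch buffer follows the suggested default from the `sz_memory_allocator_init_fixed` documentation and is only adequate for short inputs; treat this as an illustration, not the library's reference usage.]

    #include <stringzilla/stringzilla.h>

    /* Score two strings with gap == -1 and subs[i][j] == (i == j ? 0 : -1). */
    static sz_ssize_t negative_levenshtein(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) {
        static sz_error_cost_t subs[256][256];
        for (int i = 0; i != 256; ++i)
            for (int j = 0; j != 256; ++j) subs[i][j] = (sz_error_cost_t)(i == j ? 0 : -1);

        char buffer[4096]; /* one RAM page of scratch space, enough for short inputs */
        sz_memory_allocator_t alloc;
        sz_memory_allocator_init_fixed(&alloc, buffer, sizeof(buffer));

        /* Per the docs above, returns `SZ_SSIZE_MAX` if the allocation (here, the fixed buffer) fails. */
        return sz_alignment_score(a, a_length, b, b_length, &subs[0][0], -1, &alloc);
    }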
- * - * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default - * @see https://en.wikipedia.org/wiki/Needleman%E2%80%93Wunsch_algorithm - */ -SZ_DYNAMIC sz_ssize_t sz_alignment_score(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, // - sz_memory_allocator_t *alloc); + // + int has_duplicates = // + start[*first] == start[*second] || // + start[*first] == start[*third] || // + start[*second] == start[*third]; -/** @copydoc sz_alignment_score */ -SZ_PUBLIC sz_ssize_t sz_alignment_score_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, // - sz_memory_allocator_t *alloc); + // Loop through letters to find non-colliding variants. + if (length > 3 && has_duplicates) { + // Pivot the middle point right, until we find a character different from the first one. + while (start[*second] == start[*first] && *second + 1 < *third) ++(*second); + // Pivot the third (last) point left, until we find a different character. + while ((start[*third] == start[*second] || start[*third] == start[*first]) && *third > (*second + 1)) + --(*third); + } -typedef sz_ssize_t (*sz_alignment_score_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_error_cost_t const *, - sz_error_cost_t, sz_memory_allocator_t *); + // TODO: Investigate alternative strategies for long needles. + // On very long needles we have the luxury to choose! + // Often dealing with UTF8, we will likely benefit from shifting the first and second characters + // further to the right, to achieve not only uniqueness within the needle, but also avoid common + // rune prefixes of 2-, 3-, and 4-byte codes. + if (length > 8) { + // Pivot the first and second points right, until we find a character, that: + // > is different from others. + // > doesn't start with 0b'110x'xxxx - only 5 bits of relevant info. + // > doesn't start with 0b'1110'xxxx - only 4 bits of relevant info. + // > doesn't start with 0b'1111'0xxx - only 3 bits of relevant info. + // + // So we are practically searching for byte values that start with 0b0xxx'xxxx or 0b'10xx'xxxx. + // Meaning they fall in the range [0, 127] and [128, 191], in other words any unsigned int up to 191. + sz_u8_t const *start_u8 = (sz_u8_t const *)start; + sz_size_t vibrant_first = *first, vibrant_second = *second, vibrant_third = *third; -typedef void (*sz_hash_callback_t)(sz_cptr_t, sz_size_t, sz_u64_t, void *user); + // Let's begin with the seccond character, as the termination criteria there is more obvious + // and we may end up with more variants to check for the first candidate. + while ((start_u8[vibrant_second] > 191 || start_u8[vibrant_second] == start_u8[vibrant_third]) && + (vibrant_second + 1 < vibrant_third)) + ++vibrant_second; -/** - * @brief Computes the Karp-Rabin rolling hashes of a string supplying them to the provided `callback`. - * Can be used for similarity scores, search, ranking, etc. - * - * Rabin-Karp-like rolling hashes can have very high-level of collisions and depend - * on the choice of bases and the prime number. That's why, often two hashes from the same - * family are used with different bases. - * - * 1. Kernighan and Ritchie's function uses 31, a prime close to the size of English alphabet. - * 2. To be friendlier to byte-arrays and UTF8, we use 257 for the second function. - * - * Choosing the right ::window_length is task- and domain-dependant. 
For example, most English words are - * between 3 and 7 characters long, so a window of 4 bytes would be a good choice. For DNA sequences, - * the ::window_length might be a multiple of 3, as the codons are 3 (nucleotides) bytes long. - * With such minimalistic alphabets of just four characters (AGCT) longer windows might be needed. - * For protein sequences the alphabet is 20 characters long, so the window can be shorter, than for DNAs. - * - * @param text String to hash. - * @param length Number of bytes in the string. - * @param window_length Length of the rolling window in bytes. - * @param window_step Step of reported hashes. @b Must be power of two. Should be smaller than `window_length`. - * @param callback Function receiving the start & length of a substring, the hash, and the `callback_handle`. - * @param callback_handle Optional user-provided pointer to be passed to the `callback`. - * @see sz_hashes_fingerprint, sz_hashes_intersection - */ -SZ_DYNAMIC void sz_hashes(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // - sz_hash_callback_t callback, void *callback_handle); + // Now check if we've indeed found a good candidate or should revert the `vibrant_second` to `second`. + if (start_u8[vibrant_second] < 191) { *second = vibrant_second; } + else { vibrant_second = *second; } -/** @copydoc sz_hashes */ -SZ_PUBLIC void sz_hashes_serial(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // - sz_hash_callback_t callback, void *callback_handle); + // Now check the first character. + while ((start_u8[vibrant_first] > 191 || start_u8[vibrant_first] == start_u8[vibrant_second] || + start_u8[vibrant_first] == start_u8[vibrant_third]) && + (vibrant_first + 1 < vibrant_second)) + ++vibrant_first; -typedef void (*sz_hashes_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_size_t, sz_hash_callback_t, void *); + // Now check if we've indeed found a good candidate or should revert the `vibrant_first` to `first`. + // We don't need to shift the third one when dealing with texts as the last byte of the text is + // also the last byte of a rune and contains the most information. + if (start_u8[vibrant_first] < 191) { *first = vibrant_first; } + } +} -/** - * @brief Computes the Karp-Rabin rolling hashes of a string outputting a binary fingerprint. - * Such fingerprints can be compared with Hamming or Jaccard (Tanimoto) distance for similarity. - * - * The algorithm doesn't clear the fingerprint buffer on start, so it can be invoked multiple times - * to produce a fingerprint of a longer string, by passing the previous fingerprint as the ::fingerprint. - * It can also be reused to produce multi-resolution fingerprints by changing the ::window_length - * and calling the same function multiple times for the same input ::text. - * - * Processes large strings in parts to maximize the cache utilization, using a small on-stack buffer, - * avoiding cache-coherency penalties of remote on-heap buffers. - * - * @param text String to hash. - * @param length Number of bytes in the string. - * @param fingerprint Output fingerprint buffer. - * @param fingerprint_bytes Number of bytes in the fingerprint buffer. - * @param window_length Length of the rolling window in bytes. 
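To make the offset-picking heuristic above concrete, here is a stripped-down re-implementation of the duplicate-avoidance step only (the UTF-8 refinement for long needles is omitted); `pick_offsets` is an illustrative name, not a library symbol:

#include <stddef.h>
#include <stdio.h>

// Pick the first, middle, and last byte of the needle, then nudge the middle and
// last offsets until the three sampled bytes differ, where possible.
static void pick_offsets(char const *needle, size_t length, size_t *first, size_t *second, size_t *third) {
    *first = 0, *second = length / 2, *third = length - 1;
    int has_duplicates = needle[*first] == needle[*second] || needle[*first] == needle[*third] ||
                         needle[*second] == needle[*third];
    if (length > 3 && has_duplicates) {
        while (needle[*second] == needle[*first] && *second + 1 < *third) ++(*second);
        while ((needle[*third] == needle[*second] || needle[*third] == needle[*first]) && *third > *second + 1)
            --(*third);
    }
}

int main(void) {
    size_t first, second, third;
    pick_offsets("aXaYa", 5, &first, &second, &third);
    printf("%zu %zu %zu\n", first, second, third); // prints 0 3 4: the middle offset steps off the repeated 'a'
    return 0;
}

With offsets 0, 3, and 4 the SIMD comparisons involve two distinct byte values instead of three copies of 'a'.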
- * @see sz_hashes, sz_hashes_intersection - */ -SZ_PUBLIC void sz_hashes_fingerprint( // - sz_cptr_t text, sz_size_t length, sz_size_t window_length, // - sz_ptr_t fingerprint, sz_size_t fingerprint_bytes); +SZ_PUBLIC sz_cptr_t sz_find_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { + for (sz_cptr_t const end = text + length; text != end; ++text) + if (sz_charset_contains(set, *text)) return text; + return SZ_NULL_CHAR; +} -typedef void (*sz_hashes_fingerprint_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_ptr_t, sz_size_t); +SZ_PUBLIC sz_cptr_t sz_rfind_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Warray-bounds" + sz_cptr_t const end = text; + for (text += length; text != end;) + if (sz_charset_contains(set, *(text -= 1))) return text; + return SZ_NULL_CHAR; +#pragma GCC diagnostic pop +} /** - * @brief Given a hash-fingerprint of a textual document, computes the number of intersecting hashes - * of the incoming document. Can be used for document scoring and search. - * - * Processes large strings in parts to maximize the cache utilization, using a small on-stack buffer, - * avoiding cache-coherency penalties of remote on-heap buffers. - * - * @param text Input document. - * @param length Number of bytes in the input document. - * @param fingerprint Reference document fingerprint. - * @param fingerprint_bytes Number of bytes in the reference documents fingerprint. - * @param window_length Length of the rolling window in bytes. - * @see sz_hashes, sz_hashes_fingerprint + * @brief Byte-level equality comparison between two 64-bit integers. + * @return 64-bit integer, where every top bit in each byte signifies a match. */ -SZ_PUBLIC sz_size_t sz_hashes_intersection( // - sz_cptr_t text, sz_size_t length, sz_size_t window_length, // - sz_cptr_t fingerprint, sz_size_t fingerprint_bytes); - -typedef sz_size_t (*sz_hashes_intersection_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_cptr_t, sz_size_t); +SZ_INTERNAL sz_u64_vec_t _sz_u64_each_byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { + sz_u64_vec_t vec; + vec.u64 = ~(a.u64 ^ b.u64); + // The match is valid, if every bit within each byte is set. + // For that take the bottom 7 bits of each byte, add one to them, + // and if this sets the top bit to one, then all the 7 bits are ones as well. + vec.u64 = ((vec.u64 & 0x7F7F7F7F7F7F7F7Full) + 0x0101010101010101ull) & ((vec.u64 & 0x8080808080808080ull)); + return vec; +} -#pragma endregion +/* Find the first occurrence of a @b single-character needle in an arbitrary length haystack. + * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. + * Identical to `memchr(haystack, needle[0], haystack_length)`. + */ +SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { -#pragma region Convenience API + if (!h_length) return SZ_NULL_CHAR; + sz_cptr_t const h_end = h + h_length; -/** - * @brief Finds the first character in the haystack, that is present in the needle. - * Convenience function, reused across different language bindings. - * @see sz_find_charset - */ -SZ_DYNAMIC sz_cptr_t sz_find_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); +#if !_SZ_IS_BIG_ENDIAN // Use SWAR only on little-endian platforms for brevity. +#if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to void UB on unaligned 64-bit loads. 
+ for (; ((sz_size_t)h & 7ull) && h < h_end; ++h) + if (*h == *n) return h; +#endif -/** - * @brief Finds the first character in the haystack, that is @b not present in the needle. - * Convenience function, reused across different language bindings. - * @see sz_find_charset - */ -SZ_DYNAMIC sz_cptr_t sz_find_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); + // Broadcast the n into every byte of a 64-bit integer to use SWAR + // techniques and process eight characters at a time. + sz_u64_vec_t h_vec, n_vec, match_vec; + match_vec.u64 = 0; + n_vec.u64 = (sz_u64_t)n[0] * 0x0101010101010101ull; + for (; h + 8 <= h_end; h += 8) { + h_vec.u64 = *(sz_u64_t const *)h; + match_vec = _sz_u64_each_byte_equal(h_vec, n_vec); + if (match_vec.u64) return h + sz_u64_ctz(match_vec.u64) / 8; + } +#endif -/** - * @brief Finds the last character in the haystack, that is present in the needle. - * Convenience function, reused across different language bindings. - * @see sz_find_charset - */ -SZ_DYNAMIC sz_cptr_t sz_rfind_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); + // Handle the misaligned tail. + for (; h < h_end; ++h) + if (*h == *n) return h; + return SZ_NULL_CHAR; +} -/** - * @brief Finds the last character in the haystack, that is @b not present in the needle. - * Convenience function, reused across different language bindings. - * @see sz_find_charset +/* Find the last occurrence of a @b single-character needle in an arbitrary length haystack. + * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. + * Identical to `memrchr(haystack, needle[0], haystack_length)`. */ -SZ_DYNAMIC sz_cptr_t sz_rfind_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); - -#pragma endregion +sz_cptr_t sz_rfind_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { -#pragma region String Sequences API + if (!h_length) return SZ_NULL_CHAR; + sz_cptr_t const h_start = h; -struct sz_sequence_t; + // Reposition the `h` pointer to the end, as we will be walking backwards. + h = h + h_length - 1; -typedef sz_cptr_t (*sz_sequence_member_start_t)(struct sz_sequence_t const *, sz_size_t); -typedef sz_size_t (*sz_sequence_member_length_t)(struct sz_sequence_t const *, sz_size_t); -typedef sz_bool_t (*sz_sequence_predicate_t)(struct sz_sequence_t const *, sz_size_t); -typedef sz_bool_t (*sz_sequence_comparator_t)(struct sz_sequence_t const *, sz_size_t, sz_size_t); -typedef sz_bool_t (*sz_string_is_less_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); +#if !_SZ_IS_BIG_ENDIAN // Use SWAR only on little-endian platforms for brevity. +#if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to void UB on unaligned 64-bit loads. + for (; ((sz_size_t)(h + 1) & 7ull) && h >= h_start; --h) + if (*h == *n) return h; +#endif -typedef struct sz_sequence_t { - sz_sorted_idx_t *order; - sz_size_t count; - sz_sequence_member_start_t get_start; - sz_sequence_member_length_t get_length; - void const *handle; -} sz_sequence_t; + // Broadcast the n into every byte of a 64-bit integer to use SWAR + // techniques and process eight characters at a time. 
+ sz_u64_vec_t h_vec, n_vec, match_vec; + n_vec.u64 = (sz_u64_t)n[0] * 0x0101010101010101ull; + for (; h >= h_start + 7; h -= 8) { + h_vec.u64 = *(sz_u64_t const *)(h - 7); + match_vec = _sz_u64_each_byte_equal(h_vec, n_vec); + if (match_vec.u64) return h - sz_u64_clz(match_vec.u64) / 8; + } +#endif -/** - * @brief Initiates the sequence structure from a tape layout, used by Apache Arrow. - * Expects ::offsets to contains `count + 1` entries, the last pointing at the end - * of the last string, indicating the total length of the ::tape. - */ -SZ_PUBLIC void sz_sequence_from_u32tape(sz_cptr_t *start, sz_u32_t const *offsets, sz_size_t count, - sz_sequence_t *sequence); + for (; h >= h_start; --h) + if (*h == *n) return h; + return SZ_NULL_CHAR; +} /** - * @brief Initiates the sequence structure from a tape layout, used by Apache Arrow. - * Expects ::offsets to contains `count + 1` entries, the last pointing at the end - * of the last string, indicating the total length of the ::tape. + * @brief 2Byte-level equality comparison between two 64-bit integers. + * @return 64-bit integer, where every top bit in each 2byte signifies a match. */ -SZ_PUBLIC void sz_sequence_from_u64tape(sz_cptr_t *start, sz_u64_t const *offsets, sz_size_t count, - sz_sequence_t *sequence); +SZ_INTERNAL sz_u64_vec_t _sz_u64_each_2byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { + sz_u64_vec_t vec; + vec.u64 = ~(a.u64 ^ b.u64); + // The match is valid, if every bit within each 2byte is set. + // For that take the bottom 15 bits of each 2byte, add one to them, + // and if this sets the top bit to one, then all the 15 bits are ones as well. + vec.u64 = ((vec.u64 & 0x7FFF7FFF7FFF7FFFull) + 0x0001000100010001ull) & ((vec.u64 & 0x8000800080008000ull)); + return vec; +} /** - * @brief Similar to `std::partition`, given a predicate splits the sequence into two parts. - * The algorithm is unstable, meaning that elements may change relative order, as long - * as they are in the right partition. This is the simpler algorithm for partitioning. + * @brief Find the first occurrence of a @b two-character needle in an arbitrary length haystack. + * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. */ -SZ_PUBLIC sz_size_t sz_partition(sz_sequence_t *sequence, sz_sequence_predicate_t predicate); +SZ_INTERNAL sz_cptr_t _sz_find_2byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { -/** - * @brief Inplace `std::set_union` for two consecutive chunks forming the same continuous `sequence`. - * - * @param partition The number of elements in the first sub-sequence in `sequence`. - * @param less Comparison function, to determine the lexicographic ordering. - */ -SZ_PUBLIC void sz_merge(sz_sequence_t *sequence, sz_size_t partition, sz_sequence_comparator_t less); + // This is an internal method, and the haystack is guaranteed to be at least 2 bytes long. + sz_assert(h_length >= 2 && "The haystack is too short."); + sz_cptr_t const h_end = h + h_length; -/** - * @brief Sorting algorithm, combining Radix Sort for the first 32 bits of every word - * and a follow-up by a more conventional sorting procedure on equally prefixed parts. - */ -SZ_PUBLIC void sz_sort(sz_sequence_t *sequence); +#if !SZ_USE_MISALIGNED_LOADS + // Process the misaligned head, to void UB on unaligned 64-bit loads. 
+ for (; ((sz_size_t)h & 7ull) && h + 2 <= h_end; ++h) + if ((h[0] == n[0]) + (h[1] == n[1]) == 2) return h; +#endif -/** - * @brief Partial sorting algorithm, combining Radix Sort for the first 32 bits of every word - * and a follow-up by a more conventional sorting procedure on equally prefixed parts. - */ -SZ_PUBLIC void sz_sort_partial(sz_sequence_t *sequence, sz_size_t n); - -/** - * @brief Intro-Sort algorithm that supports custom comparators. - */ -SZ_PUBLIC void sz_sort_intro(sz_sequence_t *sequence, sz_sequence_comparator_t less); - -#pragma endregion - -/* - * Hardware feature detection. - * All of those can be controlled by the user. - */ -#ifndef SZ_USE_X86_AVX512 -#ifdef __AVX512BW__ -#define SZ_USE_X86_AVX512 1 -#else -#define SZ_USE_X86_AVX512 0 -#endif -#endif - -#ifndef SZ_USE_X86_AVX2 -#ifdef __AVX2__ -#define SZ_USE_X86_AVX2 1 -#else -#define SZ_USE_X86_AVX2 0 -#endif -#endif - -#ifndef SZ_USE_ARM_NEON -#ifdef __ARM_NEON -#define SZ_USE_ARM_NEON 1 -#else -#define SZ_USE_ARM_NEON 0 -#endif -#endif - -#ifndef SZ_USE_ARM_SVE -#ifdef __ARM_FEATURE_SVE -#define SZ_USE_ARM_SVE 1 -#else -#define SZ_USE_ARM_SVE 0 -#endif -#endif - -/* - * Include hardware-specific headers. - */ -#if SZ_USE_X86_AVX512 || SZ_USE_X86_AVX2 -#include -#endif // SZ_USE_X86... -#if SZ_USE_ARM_NEON -#if !defined(_MSC_VER) -#include -#endif -#include -#endif // SZ_USE_ARM_NEON -#if SZ_USE_ARM_SVE -#if !defined(_MSC_VER) -#include -#endif -#endif // SZ_USE_ARM_SVE - -#pragma region Hardware Specific API - -#if SZ_USE_X86_AVX512 - -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_avx512(sz_cptr_t a, sz_cptr_t b, sz_size_t length); -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_avx512(sz_ptr_t target, sz_size_t length, sz_u8_t value); -/** @copydoc sz_look_up_transform */ -SZ_PUBLIC void sz_look_up_transform_avx512(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target); -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_edit_distance */ -SZ_PUBLIC sz_size_t sz_edit_distance_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); -/** @copydoc sz_alignment_score */ -SZ_PUBLIC sz_ssize_t sz_alignment_score_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, // - sz_memory_allocator_t *alloc); -/** @copydoc sz_hashes */ 
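The same broadcast-and-compare trick drives the single-byte and two-byte finders above. A self-contained sketch of the one-byte case, using the GCC/Clang `__builtin_ctzll` in place of `sz_u64_ctz`; `each_byte_equal` is an illustrative stand-in, not the library function:

#include <stdint.h>
#include <string.h>
#include <stdio.h>

// Mark every byte of `word` that equals `needle` by setting that byte's top bit.
static uint64_t each_byte_equal(uint64_t word, unsigned char needle) {
    uint64_t broadcast = (uint64_t)needle * 0x0101010101010101ull; // copy the byte into all 8 lanes
    uint64_t x = ~(word ^ broadcast);                              // bytes equal to the needle become 0xFF
    // A byte is all-ones only if adding 1 to its low 7 bits carries into the top bit
    // and the top bit itself was already set.
    return ((x & 0x7F7F7F7F7F7F7F7Full) + 0x0101010101010101ull) & (x & 0x8080808080808080ull);
}

int main(void) {
    char const text[8] = {'s', 't', 'r', 'i', 'n', 'g', 'z', '!'};
    uint64_t word;
    memcpy(&word, text, sizeof(word)); // byte i of `text` lands in byte i of `word`
    uint64_t matches = each_byte_equal(word, 'i');
    if (matches) printf("%d\n", __builtin_ctzll(matches) / 8); // prints 3 on a little-endian machine
    return 0;
}

The reverse-order finder reuses the same mask and extracts the last match with a leading-zero count instead of a trailing-zero count.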
-SZ_PUBLIC void sz_hashes_avx512(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle); -#endif - -#if SZ_USE_X86_AVX2 -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_avx2(sz_cptr_t a, sz_cptr_t b, sz_size_t length); -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_avx2(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_avx2(sz_ptr_t target, sz_size_t length, sz_u8_t value); -/** @copydoc sz_look_up_transform */ -SZ_PUBLIC void sz_look_up_transform_avx2(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target); -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_hashes */ -SZ_PUBLIC void sz_hashes_avx2(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle); -#endif - -#if SZ_USE_ARM_NEON -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length); -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_neon(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_neon(sz_ptr_t target, sz_size_t length, sz_u8_t value); -/** @copydoc sz_look_up_transform */ -SZ_PUBLIC void sz_look_up_transform_neon(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target); -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_neon(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -#endif - -#if SZ_USE_ARM_SVE -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_sve(sz_cptr_t a, sz_cptr_t b, sz_size_t length); -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_sve(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t 
length); -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_sve(sz_ptr_t target, sz_size_t length, sz_u8_t value); -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_sve(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_sve(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -#endif - -#pragma endregion - -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wconversion" - -/* - ********************************************************************************************************************** - ********************************************************************************************************************** - ********************************************************************************************************************** - * - * This is where we the actual implementation begins. - * The rest of the file is hidden from the public API. - * - ********************************************************************************************************************** - ********************************************************************************************************************** - ********************************************************************************************************************** - */ - -#pragma region Compiler Extensions and Helper Functions - -#pragma GCC visibility push(hidden) - -/** - * @brief Helper-macro to mark potentially unused variables. - */ -#define sz_unused(x) ((void)(x)) - -/** - * @brief Helper-macro casting a variable to another type of the same size. - */ -#define sz_bitcast(type, value) (*((type *)&(value))) - -/** - * @brief Defines `SZ_NULL`, analogous to `NULL`. - * The default often comes from locale.h, stddef.h, - * stdio.h, stdlib.h, string.h, time.h, or wchar.h. - */ -#ifdef __GNUG__ -#define SZ_NULL __null -#define SZ_NULL_CHAR __null -#else -#define SZ_NULL ((void *)0) -#define SZ_NULL_CHAR ((char *)0) -#endif - -/** - * @brief Cache-line width, that will affect the execution of some algorithms, - * like equality checks and relative order computing. - */ -#define SZ_CACHE_LINE_WIDTH (64) // bytes - -/** - * @brief Similar to `assert`, the `sz_assert` is used in the SZ_DEBUG mode - * to check the invariants of the library. It's a no-op in the SZ_RELEASE mode. 
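A hypothetical stand-alone macro in the same spirit as the `sz_assert` described above: report the failing condition with its file and line, then exit. `MY_ASSERT` is an illustrative name, not the library macro:

#include <stdio.h>
#include <stdlib.h>

#define MY_ASSERT(condition)                                                 \
    do {                                                                     \
        if (!(condition)) {                                                  \
            fprintf(stderr, "Assertion failed: %s, in file %s, line %d\n",   \
                    #condition, __FILE__, __LINE__);                         \
            exit(EXIT_FAILURE);                                              \
        }                                                                    \
    } while (0)

int main(void) {
    MY_ASSERT(1 + 1 == 2); // passes silently
    return 0;
}

The `do { ... } while (0)` wrapper keeps the macro usable as a single statement, for example inside an unbraced `if`/`else`.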
- * @note If you want to catch it, put a breakpoint at @b `__GI_exit` - */ -#if SZ_DEBUG && defined(SZ_AVOID_LIBC) && !SZ_AVOID_LIBC && !defined(SZ_PIC) -#include // `fprintf` -#include // `EXIT_FAILURE` -SZ_PUBLIC void _sz_assert_failure(char const *condition, char const *file, int line) { - fprintf(stderr, "Assertion failed: %s, in file %s, line %d\n", condition, file, line); - exit(EXIT_FAILURE); -} -#define sz_assert(condition) \ - do { \ - if (!(condition)) { _sz_assert_failure(#condition, __FILE__, __LINE__); } \ - } while (0) -#else -#define sz_assert(condition) ((void)(condition)) -#endif - -/* Intrinsics aliases for MSVC, GCC, Clang, and Clang-Cl. - * The following section of compiler intrinsics comes in 2 flavors. - */ -#if defined(_MSC_VER) && !defined(__clang__) // On Clang-CL -#include - -// Sadly, when building Win32 images, we can't use the `_tzcnt_u64`, `_lzcnt_u64`, -// `_BitScanForward64`, or `_BitScanReverse64` intrinsics. For now it's a simple `for`-loop. -// TODO: In the future we can switch to a more efficient De Bruijn's algorithm. -// https://www.chessprogramming.org/BitScan -// https://www.chessprogramming.org/De_Bruijn_Sequence -// https://gist.github.com/resilar/e722d4600dbec9752771ab4c9d47044f -// -// Use the serial version on 32-bit x86 and on Arm. -#if (defined(_WIN32) && !defined(_WIN64)) || defined(_M_ARM) || defined(_M_ARM64) -SZ_INTERNAL int sz_u64_ctz(sz_u64_t x) { - sz_assert(x != 0); - int n = 0; - while ((x & 1) == 0) { n++, x >>= 1; } - return n; -} -SZ_INTERNAL int sz_u64_clz(sz_u64_t x) { - sz_assert(x != 0); - int n = 0; - while ((x & 0x8000000000000000ull) == 0) { n++, x <<= 1; } - return n; -} -SZ_INTERNAL int sz_u64_popcount(sz_u64_t x) { - x = x - ((x >> 1) & 0x5555555555555555ull); - x = (x & 0x3333333333333333ull) + ((x >> 2) & 0x3333333333333333ull); - return (((x + (x >> 4)) & 0x0F0F0F0F0F0F0F0Full) * 0x0101010101010101ull) >> 56; -} -SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) { - sz_assert(x != 0); - int n = 0; - while ((x & 1) == 0) { n++, x >>= 1; } - return n; -} -SZ_INTERNAL int sz_u32_clz(sz_u32_t x) { - sz_assert(x != 0); - int n = 0; - while ((x & 0x80000000u) == 0) { n++, x <<= 1; } - return n; -} -SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) { - x = x - ((x >> 1) & 0x55555555); - x = (x & 0x33333333) + ((x >> 2) & 0x33333333); - return (((x + (x >> 4)) & 0x0F0F0F0F) * 0x01010101) >> 24; -} -#else -SZ_INTERNAL int sz_u64_ctz(sz_u64_t x) { return (int)_tzcnt_u64(x); } -SZ_INTERNAL int sz_u64_clz(sz_u64_t x) { return (int)_lzcnt_u64(x); } -SZ_INTERNAL int sz_u64_popcount(sz_u64_t x) { return (int)__popcnt64(x); } -SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) { return (int)_tzcnt_u32(x); } -SZ_INTERNAL int sz_u32_clz(sz_u32_t x) { return (int)_lzcnt_u32(x); } -SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) { return (int)__popcnt(x); } -#endif -// Force the byteswap functions to be intrinsics, because when /Oi- is given, these will turn into CRT function calls, -// which breaks when `SZ_AVOID_LIBC` is given -#pragma intrinsic(_byteswap_uint64) -SZ_INTERNAL sz_u64_t sz_u64_bytes_reverse(sz_u64_t val) { return _byteswap_uint64(val); } -#pragma intrinsic(_byteswap_ulong) -SZ_INTERNAL sz_u32_t sz_u32_bytes_reverse(sz_u32_t val) { return _byteswap_ulong(val); } -#else -SZ_INTERNAL int sz_u64_popcount(sz_u64_t x) { return __builtin_popcountll(x); } -SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) { return __builtin_popcount(x); } -SZ_INTERNAL int sz_u64_ctz(sz_u64_t x) { return __builtin_ctzll(x); } -SZ_INTERNAL int sz_u64_clz(sz_u64_t x) { return 
__builtin_clzll(x); } -SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) { return __builtin_ctz(x); } // ! Undefined if `x == 0` -SZ_INTERNAL int sz_u32_clz(sz_u32_t x) { return __builtin_clz(x); } // ! Undefined if `x == 0` -SZ_INTERNAL sz_u64_t sz_u64_bytes_reverse(sz_u64_t val) { return __builtin_bswap64(val); } -SZ_INTERNAL sz_u32_t sz_u32_bytes_reverse(sz_u32_t val) { return __builtin_bswap32(val); } -#endif - -SZ_INTERNAL sz_u64_t sz_u64_rotl(sz_u64_t x, sz_u64_t r) { return (x << r) | (x >> (64 - r)); } - -/** - * @brief Select bits from either ::a or ::b depending on the value of ::mask bits. - * - * Similar to `_mm_blend_epi16` intrinsic on x86. - * Described in the "Bit Twiddling Hacks" by Sean Eron Anderson. - * https://graphics.stanford.edu/~seander/bithacks.html#ConditionalSetOrClearBitsWithoutBranching - */ -SZ_INTERNAL sz_u64_t sz_u64_blend(sz_u64_t a, sz_u64_t b, sz_u64_t mask) { return a ^ ((a ^ b) & mask); } - -/* - * Efficiently computing the minimum and maximum of two or three values can be tricky. - * The simple branching baseline would be: - * - * x < y ? x : y // can replace with 1 conditional move - * - * Branchless approach is well known for signed integers, but it doesn't apply to unsigned ones. - * https://stackoverflow.com/questions/514435/templatized-branchless-int-max-min-function - * https://graphics.stanford.edu/~seander/bithacks.html#IntegerMinOrMax - * Using only bit-shifts for singed integers it would be: - * - * y + ((x - y) & (x - y) >> 31) // 4 unique operations - * - * Alternatively, for any integers using multiplication: - * - * (x > y) * y + (x <= y) * x // 5 operations - * - * Alternatively, to avoid multiplication: - * - * x & ~((x < y) - 1) + y & ((x < y) - 1) // 6 unique operations - */ -#define sz_min_of_two(x, y) (x < y ? x : y) -#define sz_max_of_two(x, y) (x < y ? y : x) -#define sz_min_of_three(x, y, z) sz_min_of_two(x, sz_min_of_two(y, z)) -#define sz_max_of_three(x, y, z) sz_max_of_two(x, sz_max_of_two(y, z)) - -/** @brief Branchless minimum function for two signed 32-bit integers. */ -SZ_INTERNAL sz_i32_t sz_i32_min_of_two(sz_i32_t x, sz_i32_t y) { return y + ((x - y) & (x - y) >> 31); } - -/** @brief Branchless minimum function for two signed 32-bit integers. */ -SZ_INTERNAL sz_i32_t sz_i32_max_of_two(sz_i32_t x, sz_i32_t y) { return x - ((x - y) & (x - y) >> 31); } - -/** - * @brief Clamps signed offsets in a string to a valid range. Used for Pythonic-style slicing. - */ -SZ_INTERNAL void sz_ssize_clamp_interval(sz_size_t length, sz_ssize_t start, sz_ssize_t end, - sz_size_t *normalized_offset, sz_size_t *normalized_length) { - // TODO: Remove branches. - // Normalize negative indices - if (start < 0) start += length; - if (end < 0) end += length; - - // Clamp indices to a valid range - if (start < 0) start = 0; - if (end < 0) end = 0; - if (start > (sz_ssize_t)length) start = length; - if (end > (sz_ssize_t)length) end = length; - - // Ensure start <= end - if (start > end) start = end; - - *normalized_offset = start; - *normalized_length = end - start; -} - -/** - * @brief Compute the logarithm base 2 of a positive integer, rounding down. - */ -SZ_INTERNAL sz_size_t sz_size_log2i_nonzero(sz_size_t x) { - sz_assert(x > 0 && "Non-positive numbers have no defined logarithm"); - sz_size_t leading_zeros = sz_u64_clz(x); - return 63 - leading_zeros; -} - -/** - * @brief Compute the smallest power of two greater than or equal to ::x. 
- */ -SZ_INTERNAL sz_size_t sz_size_bit_ceil(sz_size_t x) { - // Unlike the commonly used trick with `clz` intrinsics, is valid across the whole range of `x`. - // https://stackoverflow.com/a/10143264 - x--; - x |= x >> 1; - x |= x >> 2; - x |= x >> 4; - x |= x >> 8; - x |= x >> 16; -#if SZ_DETECT_64_BIT - x |= x >> 32; -#endif - x++; - return x; -} - -/** - * @brief Transposes an 8x8 bit matrix packed in a `sz_u64_t`. - * - * There is a well known SWAR sequence for that known to chess programmers, - * willing to flip a bit-matrix of pieces along the main A1-H8 diagonal. - * https://www.chessprogramming.org/Flipping_Mirroring_and_Rotating - * https://lukas-prokop.at/articles/2021-07-23-transpose - */ -SZ_INTERNAL sz_u64_t sz_u64_transpose(sz_u64_t x) { - sz_u64_t t; - t = x ^ (x << 36); - x ^= 0xf0f0f0f00f0f0f0full & (t ^ (x >> 36)); - t = 0xcccc0000cccc0000ull & (x ^ (x << 18)); - x ^= t ^ (t >> 18); - t = 0xaa00aa00aa00aa00ull & (x ^ (x << 9)); - x ^= t ^ (t >> 9); - return x; -} - -/** - * @brief Helper, that swaps two 64-bit integers representing the order of elements in the sequence. - */ -SZ_INTERNAL void sz_u64_swap(sz_u64_t *a, sz_u64_t *b) { - sz_u64_t t = *a; - *a = *b; - *b = t; -} - -/** - * @brief Helper, that swaps two 64-bit integers representing the order of elements in the sequence. - */ -SZ_INTERNAL void sz_pointer_swap(void **a, void **b) { - void *t = *a; - *a = *b; - *b = t; -} - -/** - * @brief Helper structure to simplify work with 16-bit words. - * @see sz_u16_load - */ -typedef union sz_u16_vec_t { - sz_u16_t u16; - sz_u8_t u8s[2]; -} sz_u16_vec_t; - -/** - * @brief Load a 16-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms. - */ -SZ_INTERNAL sz_u16_vec_t sz_u16_load(sz_cptr_t ptr) { -#if !SZ_USE_MISALIGNED_LOADS - sz_u16_vec_t result; - result.u8s[0] = ptr[0]; - result.u8s[1] = ptr[1]; - return result; -#elif defined(_MSC_VER) && !defined(__clang__) -#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. - return *((sz_u16_vec_t *)ptr); -#else - return *((__unaligned sz_u16_vec_t *)ptr); -#endif -#else - __attribute__((aligned(1))) sz_u16_vec_t const *result = (sz_u16_vec_t const *)ptr; - return *result; -#endif -} - -/** - * @brief Helper structure to simplify work with 32-bit words. - * @see sz_u32_load - */ -typedef union sz_u32_vec_t { - sz_u32_t u32; - sz_u16_t u16s[2]; - sz_u8_t u8s[4]; -} sz_u32_vec_t; - -/** - * @brief Load a 32-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms. - */ -SZ_INTERNAL sz_u32_vec_t sz_u32_load(sz_cptr_t ptr) { -#if !SZ_USE_MISALIGNED_LOADS - sz_u32_vec_t result; - result.u8s[0] = ptr[0]; - result.u8s[1] = ptr[1]; - result.u8s[2] = ptr[2]; - result.u8s[3] = ptr[3]; - return result; -#elif defined(_MSC_VER) && !defined(__clang__) -#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. - return *((sz_u32_vec_t *)ptr); -#else - return *((__unaligned sz_u32_vec_t *)ptr); -#endif -#else - __attribute__((aligned(1))) sz_u32_vec_t const *result = (sz_u32_vec_t const *)ptr; - return *result; -#endif -} - -/** - * @brief Helper structure to simplify work with 64-bit words. - * @see sz_u64_load - */ -typedef union sz_u64_vec_t { - sz_u64_t u64; - sz_u32_t u32s[2]; - sz_u16_t u16s[4]; - sz_u8_t u8s[8]; -} sz_u64_vec_t; - -/** - * @brief Load a 64-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms. 
- */ -SZ_INTERNAL sz_u64_vec_t sz_u64_load(sz_cptr_t ptr) { -#if !SZ_USE_MISALIGNED_LOADS - sz_u64_vec_t result; - result.u8s[0] = ptr[0]; - result.u8s[1] = ptr[1]; - result.u8s[2] = ptr[2]; - result.u8s[3] = ptr[3]; - result.u8s[4] = ptr[4]; - result.u8s[5] = ptr[5]; - result.u8s[6] = ptr[6]; - result.u8s[7] = ptr[7]; - return result; -#elif defined(_MSC_VER) && !defined(__clang__) -#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. - return *((sz_u64_vec_t *)ptr); -#else - return *((__unaligned sz_u64_vec_t *)ptr); -#endif -#else - __attribute__((aligned(1))) sz_u64_vec_t const *result = (sz_u64_vec_t const *)ptr; - return *result; -#endif -} - -/** @brief Helper function, using the supplied fixed-capacity buffer to allocate memory. */ -SZ_INTERNAL sz_ptr_t _sz_memory_allocate_fixed(sz_size_t length, void *handle) { - sz_size_t capacity; - sz_copy((sz_ptr_t)&capacity, (sz_cptr_t)handle, sizeof(sz_size_t)); - sz_size_t consumed_capacity = sizeof(sz_size_t); - if (consumed_capacity + length > capacity) return SZ_NULL_CHAR; - return (sz_ptr_t)handle + consumed_capacity; -} - -/** @brief Helper "no-op" function, simulating memory deallocation when we use a "static" memory buffer. */ -SZ_INTERNAL void _sz_memory_free_fixed(sz_ptr_t start, sz_size_t length, void *handle) { - sz_unused(start && length && handle); -} - -/** @brief An internal callback used to set a bit in a power-of-two length binary fingerprint of a string. */ -SZ_INTERNAL void _sz_hashes_fingerprint_pow2_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, void *handle) { - sz_string_view_t *fingerprint_buffer = (sz_string_view_t *)handle; - sz_u8_t *fingerprint_u8s = (sz_u8_t *)fingerprint_buffer->start; - sz_size_t fingerprint_bytes = fingerprint_buffer->length; - fingerprint_u8s[(hash / 8) & (fingerprint_bytes - 1)] |= (1 << (hash & 7)); - sz_unused(start && length); -} - -/** @brief An internal callback used to set a bit in a @b non power-of-two length binary fingerprint of a string. */ -SZ_INTERNAL void _sz_hashes_fingerprint_non_pow2_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, - void *handle) { - sz_string_view_t *fingerprint_buffer = (sz_string_view_t *)handle; - sz_u8_t *fingerprint_u8s = (sz_u8_t *)fingerprint_buffer->start; - sz_size_t fingerprint_bytes = fingerprint_buffer->length; - fingerprint_u8s[(hash / 8) % fingerprint_bytes] |= (1 << (hash & 7)); - sz_unused(start && length); -} - -/** @brief An internal callback, used to mix all the running hashes into one pointer-size value. */ -SZ_INTERNAL void _sz_hashes_fingerprint_scalar_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, - void *scalar_handle) { - sz_unused(start && length && hash && scalar_handle); - sz_size_t *scalar_ptr = (sz_size_t *)scalar_handle; - *scalar_ptr ^= hash; -} - -/** - * @brief Chooses the offsets of the most interesting characters in a search needle. - * - * Search throughput can significantly deteriorate if we are matching the wrong characters. - * Say the needle is "aXaYa", and we are comparing the first, second, and last character. - * If we use SIMD and compare many offsets at a time, comparing against "a" in every register is a waste. - * - * Similarly, dealing with UTF8 inputs, we know that the lower bits of each character code carry more information. - * Cyrillic alphabet, for example, falls into [0x0410, 0x042F] code range for uppercase [А, Я], and - * into [0x0430, 0x044F] for lowercase [а, я]. 
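The uneven information content of UTF-8 bytes noted above is easy to see on real data. A small sketch over the bytes of the Russian word «Привет» (hard-coded so the example does not depend on source-file encoding), flagging which bytes fall into the <= 191 range the heuristic prefers:

#include <stdio.h>

int main(void) {
    // UTF-8 bytes of "Привет": each two-byte rune starts with a nearly constant
    // lead byte (0xD0 or 0xD1), while the continuation bytes vary.
    unsigned char const privet[] = {0xD0, 0x9F, 0xD1, 0x80, 0xD0, 0xB8, 0xD0, 0xB2, 0xD0, 0xB5, 0xD1, 0x82};
    for (unsigned i = 0; i != sizeof(privet); ++i)
        printf("byte %2u: 0x%02X %s\n", i, (unsigned)privet[i],
               privet[i] <= 191 ? "(informative)" : "(repetitive lead)");
    return 0;
}

Anchoring the search on the varying continuation bytes, rather than on the repeated lead bytes, is exactly what the <= 191 filter achieves.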
Scanning through a text written in Russian, half of the - * bytes will carry absolutely no value and will be equal to 0x04. - */ -SZ_INTERNAL void _sz_locate_needle_anomalies(sz_cptr_t start, sz_size_t length, // - sz_size_t *first, sz_size_t *second, sz_size_t *third) { - *first = 0; - *second = length / 2; - *third = length - 1; - - // - int has_duplicates = // - start[*first] == start[*second] || // - start[*first] == start[*third] || // - start[*second] == start[*third]; - - // Loop through letters to find non-colliding variants. - if (length > 3 && has_duplicates) { - // Pivot the middle point right, until we find a character different from the first one. - for (; start[*second] == start[*first] && *second + 1 < *third; ++(*second)) {} - // Pivot the third (last) point left, until we find a different character. - for (; (start[*third] == start[*second] || start[*third] == start[*first]) && *third > (*second + 1); - --(*third)) {} - } - - // TODO: Investigate alternative strategies for long needles. - // On very long needles we have the luxury to choose! - // Often dealing with UTF8, we will likely benefit from shifting the first and second characters - // further to the right, to achieve not only uniqueness within the needle, but also avoid common - // rune prefixes of 2-, 3-, and 4-byte codes. - if (length > 8) { - // Pivot the first and second points right, until we find a character, that: - // > is different from others. - // > doesn't start with 0b'110x'xxxx - only 5 bits of relevant info. - // > doesn't start with 0b'1110'xxxx - only 4 bits of relevant info. - // > doesn't start with 0b'1111'0xxx - only 3 bits of relevant info. - // - // So we are practically searching for byte values that start with 0b0xxx'xxxx or 0b'10xx'xxxx. - // Meaning they fall in the range [0, 127] and [128, 191], in other words any unsigned int up to 191. - sz_u8_t const *start_u8 = (sz_u8_t const *)start; - sz_size_t vibrant_first = *first, vibrant_second = *second, vibrant_third = *third; - - // Let's begin with the seccond character, as the termination criteria there is more obvious - // and we may end up with more variants to check for the first candidate. - for (; (start_u8[vibrant_second] > 191 || start_u8[vibrant_second] == start_u8[vibrant_third]) && - (vibrant_second + 1 < vibrant_third); - ++vibrant_second) {} - - // Now check if we've indeed found a good candidate or should revert the `vibrant_second` to `second`. - if (start_u8[vibrant_second] < 191) { *second = vibrant_second; } - else { vibrant_second = *second; } - - // Now check the first character. - for (; (start_u8[vibrant_first] > 191 || start_u8[vibrant_first] == start_u8[vibrant_second] || - start_u8[vibrant_first] == start_u8[vibrant_third]) && - (vibrant_first + 1 < vibrant_second); - ++vibrant_first) {} - - // Now check if we've indeed found a good candidate or should revert the `vibrant_first` to `first`. - // We don't need to shift the third one when dealing with texts as the last byte of the text is - // also the last byte of a rune and contains the most information. 
- if (start_u8[vibrant_first] < 191) { *first = vibrant_first; } - } -} - -#pragma GCC visibility pop -#pragma endregion - -#pragma region Serial Implementation - -#if !SZ_AVOID_LIBC -#include // `fprintf` -#include // `malloc`, `EXIT_FAILURE` - -SZ_PUBLIC void *_sz_memory_allocate_default(sz_size_t length, void *handle) { - sz_unused(handle); - return malloc(length); -} -SZ_PUBLIC void _sz_memory_free_default(sz_ptr_t start, sz_size_t length, void *handle) { - sz_unused(handle && length); - free(start); -} - -#endif - -SZ_PUBLIC void sz_memory_allocator_init_default(sz_memory_allocator_t *alloc) { -#if !SZ_AVOID_LIBC - alloc->allocate = (sz_memory_allocate_t)_sz_memory_allocate_default; - alloc->free = (sz_memory_free_t)_sz_memory_free_default; -#else - alloc->allocate = (sz_memory_allocate_t)SZ_NULL; - alloc->free = (sz_memory_free_t)SZ_NULL; -#endif - alloc->handle = SZ_NULL; -} - -SZ_PUBLIC void sz_memory_allocator_init_fixed(sz_memory_allocator_t *alloc, void *buffer, sz_size_t length) { - // The logic here is simple - put the buffer length in the first slots of the buffer. - // Later use it for bounds checking. - alloc->allocate = (sz_memory_allocate_t)_sz_memory_allocate_fixed; - alloc->free = (sz_memory_free_t)_sz_memory_free_fixed; - alloc->handle = &buffer; - sz_copy((sz_ptr_t)buffer, (sz_cptr_t)&length, sizeof(sz_size_t)); -} - -/** - * @brief Byte-level equality comparison between two strings. - * If unaligned loads are allowed, uses a switch-table to avoid loops on short strings. - */ -SZ_PUBLIC sz_bool_t sz_equal_serial(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - sz_cptr_t const a_end = a + length; -#if SZ_USE_MISALIGNED_LOADS - if (length >= SZ_SWAR_THRESHOLD) { - sz_u64_vec_t a_vec, b_vec; - for (; a + 8 <= a_end; a += 8, b += 8) { - a_vec = sz_u64_load(a); - b_vec = sz_u64_load(b); - if (a_vec.u64 != b_vec.u64) return sz_false_k; - } - } -#endif - while (a != a_end && *a == *b) a++, b++; - return (sz_bool_t)(a_end == a); -} - -SZ_PUBLIC sz_cptr_t sz_find_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { - for (sz_cptr_t const end = text + length; text != end; ++text) - if (sz_charset_contains(set, *text)) return text; - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Warray-bounds" - sz_cptr_t const end = text; - for (text += length; text != end;) - if (sz_charset_contains(set, *(text -= 1))) return text; - return SZ_NULL_CHAR; -#pragma GCC diagnostic pop -} - -/** - * One option to avoid branching is to use conditional moves and lookup the comparison result in a table: - * sz_ordering_t ordering_lookup[2] = {sz_greater_k, sz_less_k}; - * for (; a != min_end; ++a, ++b) - * if (*a != *b) return ordering_lookup[*a < *b]; - * That, however, introduces a data-dependency. - * A cleaner option is to perform two comparisons and a subtraction. - * One instruction more, but no data-dependency. - */ -#define _sz_order_scalars(a, b) ((sz_ordering_t)((a > b) - (a < b))) - -SZ_PUBLIC sz_ordering_t sz_order_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { - sz_bool_t a_shorter = (sz_bool_t)(a_length < b_length); - sz_size_t min_length = a_shorter ? 
a_length : b_length; - sz_cptr_t min_end = a + min_length; -#if SZ_USE_MISALIGNED_LOADS && !SZ_DETECT_BIG_ENDIAN - for (sz_u64_vec_t a_vec, b_vec; a + 8 <= min_end; a += 8, b += 8) { - a_vec = sz_u64_load(a); - b_vec = sz_u64_load(b); - if (a_vec.u64 != b_vec.u64) - return _sz_order_scalars(sz_u64_bytes_reverse(a_vec.u64), sz_u64_bytes_reverse(b_vec.u64)); - } -#endif - for (; a != min_end; ++a, ++b) - if (*a != *b) return _sz_order_scalars(*a, *b); - - // If the strings are equal up to `min_end`, then the shorter string is smaller - return _sz_order_scalars(a_length, b_length); -} - -/** - * @brief Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each byte is set. - // For that take the bottom 7 bits of each byte, add one to them, - // and if this sets the top bit to one, then all the 7 bits are ones as well. - vec.u64 = ((vec.u64 & 0x7F7F7F7F7F7F7F7Full) + 0x0101010101010101ull) & ((vec.u64 & 0x8080808080808080ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b single-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. - * Identical to `memchr(haystack, needle[0], haystack_length)`. - */ -SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - if (!h_length) return SZ_NULL_CHAR; - sz_cptr_t const h_end = h + h_length; - -#if !SZ_DETECT_BIG_ENDIAN // Use SWAR only on little-endian platforms for brevety. -#if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)h & 7ull) && h < h_end; ++h) - if (*h == *n) return h; -#endif - - // Broadcast the n into every byte of a 64-bit integer to use SWAR - // techniques and process eight characters at a time. - sz_u64_vec_t h_vec, n_vec, match_vec; - match_vec.u64 = 0; - n_vec.u64 = (sz_u64_t)n[0] * 0x0101010101010101ull; - for (; h + 8 <= h_end; h += 8) { - h_vec.u64 = *(sz_u64_t const *)h; - match_vec = _sz_u64_each_byte_equal(h_vec, n_vec); - if (match_vec.u64) return h + sz_u64_ctz(match_vec.u64) / 8; - } -#endif - - // Handle the misaligned tail. - for (; h < h_end; ++h) - if (*h == *n) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief Find the last occurrence of a @b single-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. - * Identical to `memrchr(haystack, needle[0], haystack_length)`. - */ -sz_cptr_t sz_rfind_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - if (!h_length) return SZ_NULL_CHAR; - sz_cptr_t const h_start = h; - - // Reposition the `h` pointer to the end, as we will be walking backwards. - h = h + h_length - 1; - -#if !SZ_DETECT_BIG_ENDIAN // Use SWAR only on little-endian platforms for brevety. -#if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)(h + 1) & 7ull) && h >= h_start; --h) - if (*h == *n) return h; -#endif - - // Broadcast the n into every byte of a 64-bit integer to use SWAR - // techniques and process eight characters at a time. 
- sz_u64_vec_t h_vec, n_vec, match_vec; - n_vec.u64 = (sz_u64_t)n[0] * 0x0101010101010101ull; - for (; h >= h_start + 7; h -= 8) { - h_vec.u64 = *(sz_u64_t const *)(h - 7); - match_vec = _sz_u64_each_byte_equal(h_vec, n_vec); - if (match_vec.u64) return h - sz_u64_clz(match_vec.u64) / 8; - } -#endif - - for (; h >= h_start; --h) - if (*h == *n) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief 2Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each 2byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_2byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each 2byte is set. - // For that take the bottom 15 bits of each 2byte, add one to them, - // and if this sets the top bit to one, then all the 15 bits are ones as well. - vec.u64 = ((vec.u64 & 0x7FFF7FFF7FFF7FFFull) + 0x0001000100010001ull) & ((vec.u64 & 0x8000800080008000ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b two-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. - */ -SZ_INTERNAL sz_cptr_t _sz_find_2byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - // This is an internal method, and the haystack is guaranteed to be at least 2 bytes long. - sz_assert(h_length >= 2 && "The haystack is too short."); - sz_cptr_t const h_end = h + h_length; - -#if !SZ_USE_MISALIGNED_LOADS - // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)h & 7ull) && h + 2 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) == 2) return h; -#endif - - sz_u64_vec_t h_even_vec, h_odd_vec, n_vec, matches_even_vec, matches_odd_vec; - n_vec.u64 = 0; - n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1]; - n_vec.u64 *= 0x0001000100010001ull; // broadcast - - // This code simulates hyper-scalar execution, analyzing 8 offsets at a time. - for (; h + 9 <= h_end; h += 8) { - h_even_vec.u64 = *(sz_u64_t *)h; - h_odd_vec.u64 = (h_even_vec.u64 >> 8) | ((sz_u64_t)h[8] << 56); - matches_even_vec = _sz_u64_each_2byte_equal(h_even_vec, n_vec); - matches_odd_vec = _sz_u64_each_2byte_equal(h_odd_vec, n_vec); - - matches_even_vec.u64 >>= 8; - if (matches_even_vec.u64 + matches_odd_vec.u64) { - sz_u64_t match_indicators = matches_even_vec.u64 | matches_odd_vec.u64; - return h + sz_u64_ctz(match_indicators) / 8; - } - } - - for (; h + 2 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) == 2) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief 4Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each 4byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_4byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each 4byte is set. - // For that take the bottom 31 bits of each 4byte, add one to them, - // and if this sets the top bit to one, then all the 31 bits are ones as well. - vec.u64 = ((vec.u64 & 0x7FFFFFFF7FFFFFFFull) + 0x0000000100000001ull) & ((vec.u64 & 0x8000000080000000ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b four-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. 
- */ -SZ_INTERNAL sz_cptr_t _sz_find_4byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - // This is an internal method, and the haystack is guaranteed to be at least 4 bytes long. - sz_assert(h_length >= 4 && "The haystack is too short."); - sz_cptr_t const h_end = h + h_length; - -#if !SZ_USE_MISALIGNED_LOADS - // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)h & 7ull) && h + 4 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) + (h[3] == n[3]) == 4) return h; -#endif - - sz_u64_vec_t h0_vec, h1_vec, h2_vec, h3_vec, n_vec, matches0_vec, matches1_vec, matches2_vec, matches3_vec; - n_vec.u64 = 0; - n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1], n_vec.u8s[2] = n[2], n_vec.u8s[3] = n[3]; - n_vec.u64 *= 0x0000000100000001ull; // broadcast - - // This code simulates hyper-scalar execution, analyzing 8 offsets at a time using four 64-bit words. - // We load the subsequent four-byte word as well, taking its first bytes. Think of it as a glorified prefetch :) - sz_u64_t h_page_current, h_page_next; - for (; h + sizeof(sz_u64_t) + sizeof(sz_u32_t) <= h_end; h += sizeof(sz_u64_t)) { - h_page_current = *(sz_u64_t *)h; - h_page_next = *(sz_u32_t *)(h + 8); - h0_vec.u64 = (h_page_current); - h1_vec.u64 = (h_page_current >> 8) | (h_page_next << 56); - h2_vec.u64 = (h_page_current >> 16) | (h_page_next << 48); - h3_vec.u64 = (h_page_current >> 24) | (h_page_next << 40); - matches0_vec = _sz_u64_each_4byte_equal(h0_vec, n_vec); - matches1_vec = _sz_u64_each_4byte_equal(h1_vec, n_vec); - matches2_vec = _sz_u64_each_4byte_equal(h2_vec, n_vec); - matches3_vec = _sz_u64_each_4byte_equal(h3_vec, n_vec); - - if (matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64) { - matches0_vec.u64 >>= 24; - matches1_vec.u64 >>= 16; - matches2_vec.u64 >>= 8; - sz_u64_t match_indicators = matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64; - return h + sz_u64_ctz(match_indicators) / 8; - } - } - - for (; h + 4 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) + (h[3] == n[3]) == 4) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief 3Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each 3byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_3byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each 4byte is set. - // For that take the bottom 31 bits of each 4byte, add one to them, - // and if this sets the top bit to one, then all the 31 bits are ones as well. - vec.u64 = ((vec.u64 & 0xFFFF7FFFFF7FFFFFull) + 0x0000000001000001ull) & ((vec.u64 & 0x0000800000800000ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b three-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. - */ -SZ_INTERNAL sz_cptr_t _sz_find_3byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - // This is an internal method, and the haystack is guaranteed to be at least 4 bytes long. - sz_assert(h_length >= 3 && "The haystack is too short."); - sz_cptr_t const h_end = h + h_length; - -#if !SZ_USE_MISALIGNED_LOADS - // Process the misaligned head, to void UB on unaligned 64-bit loads. 
-    for (; ((sz_size_t)h & 7ull) && h + 3 <= h_end; ++h)
-        if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) == 3) return h;
-#endif
-
-    // We fetch 10 bytes per iteration: an 8-byte word and the following 2-byte word.
-    sz_u64_vec_t h0_vec, h1_vec, h2_vec, h3_vec, h4_vec;
-    sz_u64_vec_t matches0_vec, matches1_vec, matches2_vec, matches3_vec, matches4_vec;
-    sz_u64_vec_t n_vec;
-    n_vec.u64 = 0;
-    n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1], n_vec.u8s[2] = n[2];
-    n_vec.u64 *= 0x0000000001000001ull; // broadcast
-
-    // This code simulates hyper-scalar execution, analyzing 8 offsets at a time using five 64-bit words.
-    // We load the subsequent two-byte word as well.
-    sz_u64_t h_page_current, h_page_next;
-    for (; h + sizeof(sz_u64_t) + sizeof(sz_u16_t) <= h_end; h += sizeof(sz_u64_t)) {
-        h_page_current = *(sz_u64_t *)h;
-        h_page_next = *(sz_u16_t *)(h + 8);
-        h0_vec.u64 = (h_page_current);
-        h1_vec.u64 = (h_page_current >> 8) | (h_page_next << 56);
-        h2_vec.u64 = (h_page_current >> 16) | (h_page_next << 48);
-        h3_vec.u64 = (h_page_current >> 24) | (h_page_next << 40);
-        h4_vec.u64 = (h_page_current >> 32) | (h_page_next << 32);
-        matches0_vec = _sz_u64_each_3byte_equal(h0_vec, n_vec);
-        matches1_vec = _sz_u64_each_3byte_equal(h1_vec, n_vec);
-        matches2_vec = _sz_u64_each_3byte_equal(h2_vec, n_vec);
-        matches3_vec = _sz_u64_each_3byte_equal(h3_vec, n_vec);
-        matches4_vec = _sz_u64_each_3byte_equal(h4_vec, n_vec);
-
-        if (matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64 | matches4_vec.u64) {
-            matches0_vec.u64 >>= 16;
-            matches1_vec.u64 >>= 8;
-            matches3_vec.u64 <<= 8;
-            matches4_vec.u64 <<= 16;
-            sz_u64_t match_indicators =
-                matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64 | matches4_vec.u64;
-            return h + sz_u64_ctz(match_indicators) / 8;
-        }
-    }
-
-    for (; h + 3 <= h_end; ++h)
-        if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) == 3) return h;
-    return SZ_NULL_CHAR;
-}
-
-/**
- * @brief Boyer-Moore-Horspool algorithm for exact matching of patterns up to @b 256-bytes long.
- * Uses the Raita heuristic to match the first two, the last, and the middle character of the pattern.
- */
-SZ_INTERNAL sz_cptr_t _sz_find_horspool_upto_256bytes_serial(sz_cptr_t h_chars, sz_size_t h_length, //
-                                                             sz_cptr_t n_chars, sz_size_t n_length) {
-    sz_assert(n_length <= 256 && "The pattern is too long.");
-    // Several popular string matching algorithms use a bad-character shift table.
-    // Boyer Moore: https://www-igm.univ-mlv.fr/~lecroq/string/node14.html
-    // Quick Search: https://www-igm.univ-mlv.fr/~lecroq/string/node19.html
-    // Smith: https://www-igm.univ-mlv.fr/~lecroq/string/node21.html
-    union {
-        sz_u8_t jumps[256];
-        sz_u64_vec_t vecs[64];
-    } bad_shift_table;
-
-    // Let's initialize the table using SWAR to the total length of the needle.
-    sz_u8_t const *h = (sz_u8_t const *)h_chars;
-    sz_u8_t const *n = (sz_u8_t const *)n_chars;
-    {
-        sz_u64_vec_t n_length_vec;
-        n_length_vec.u64 = n_length;
-        n_length_vec.u64 *= 0x0101010101010101ull; // broadcast
-        for (sz_size_t i = 0; i != 64; ++i) bad_shift_table.vecs[i].u64 = n_length_vec.u64;
-        for (sz_size_t i = 0; i + 1 < n_length; ++i) bad_shift_table.jumps[n[i]] = (sz_u8_t)(n_length - i - 1);
-    }
-
-    // Another common heuristic is to match a few characters from different parts of a string.
-    // Raita suggests using the first two, the last, and the middle character of the pattern.
-    sz_u32_vec_t h_vec, n_vec;
-
-    // Pick the parts of the needle that are worth comparing.
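/* For intuition, with a hypothetical needle "hello" (5 bytes) the table above defaults every byte to a
 * shift of 5 and then records jumps['h'] = 4, jumps['e'] = 3, jumps['l'] = 1, while the last character
 * 'o' keeps the full shift, exactly as in the classical Horspool bad-character rule. The four probe
 * bytes gathered just below are packed into one 32-bit word, so most candidate windows are rejected
 * with a single integer comparison before the full `sz_equal` call. */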
- sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n_chars, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into an unsigned integer. - n_vec.u8s[0] = n[offset_first]; - n_vec.u8s[1] = n[offset_first + 1]; - n_vec.u8s[2] = n[offset_mid]; - n_vec.u8s[3] = n[offset_last]; - - // Scan through the whole haystack, skipping the last `n_length - 1` bytes. - for (sz_size_t i = 0; i <= h_length - n_length;) { - h_vec.u8s[0] = h[i + offset_first]; - h_vec.u8s[1] = h[i + offset_first + 1]; - h_vec.u8s[2] = h[i + offset_mid]; - h_vec.u8s[3] = h[i + offset_last]; - if (h_vec.u32 == n_vec.u32 && sz_equal((sz_cptr_t)h + i, n_chars, n_length)) return (sz_cptr_t)h + i; - i += bad_shift_table.jumps[h[i + n_length - 1]]; - } - return SZ_NULL_CHAR; -} - -/** - * @brief Boyer-Moore-Horspool algorithm for @b reverse-order exact matching of patterns up to @b 256-bytes long. - * Uses the Raita heuristic to match the first two, the last, and the middle character of the pattern. - */ -SZ_INTERNAL sz_cptr_t _sz_rfind_horspool_upto_256bytes_serial(sz_cptr_t h_chars, sz_size_t h_length, // - sz_cptr_t n_chars, sz_size_t n_length) { - sz_assert(n_length <= 256 && "The pattern is too long."); - union { - sz_u8_t jumps[256]; - sz_u64_vec_t vecs[64]; - } bad_shift_table; - - // Let's initialize the table using SWAR to the total length of the string. - sz_u8_t const *h = (sz_u8_t const *)h_chars; - sz_u8_t const *n = (sz_u8_t const *)n_chars; - { - sz_u64_vec_t n_length_vec; - n_length_vec.u64 = n_length; - n_length_vec.u64 *= 0x0101010101010101ull; // broadcast - for (sz_size_t i = 0; i != 64; ++i) bad_shift_table.vecs[i].u64 = n_length_vec.u64; - for (sz_size_t i = 0; i + 1 < n_length; ++i) - bad_shift_table.jumps[n[n_length - i - 1]] = (sz_u8_t)(n_length - i - 1); - } - - // Another common heuristic is to match a few characters from different parts of a string. - // Raita suggests to use the first two, the last, and the middle character of the pattern. - sz_u32_vec_t h_vec, n_vec; - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n_chars, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into an unsigned integer. - n_vec.u8s[0] = n[offset_first]; - n_vec.u8s[1] = n[offset_first + 1]; - n_vec.u8s[2] = n[offset_mid]; - n_vec.u8s[3] = n[offset_last]; - - // Scan through the whole haystack, skipping the first `n_length - 1` bytes. - for (sz_size_t j = 0; j <= h_length - n_length;) { - sz_size_t i = h_length - n_length - j; - h_vec.u8s[0] = h[i + offset_first]; - h_vec.u8s[1] = h[i + offset_first + 1]; - h_vec.u8s[2] = h[i + offset_mid]; - h_vec.u8s[3] = h[i + offset_last]; - if (h_vec.u32 == n_vec.u32 && sz_equal((sz_cptr_t)h + i, n_chars, n_length)) return (sz_cptr_t)h + i; - j += bad_shift_table.jumps[h[i]]; - } - return SZ_NULL_CHAR; -} - -/** - * @brief Exact substring search helper function, that finds the first occurrence of a prefix of the needle - * using a given search function, and then verifies the remaining part of the needle. 
- */ -SZ_INTERNAL sz_cptr_t _sz_find_with_prefix(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length, - sz_find_t find_prefix, sz_size_t prefix_length) { - - sz_size_t suffix_length = n_length - prefix_length; - while (1) { - sz_cptr_t found = find_prefix(h, h_length, n, prefix_length); - if (!found) return SZ_NULL_CHAR; - - // Verify the remaining part of the needle - sz_size_t remaining = h_length - (found - h); - if (remaining < n_length) return SZ_NULL_CHAR; - if (sz_equal(found + prefix_length, n + prefix_length, suffix_length)) return found; - - // Adjust the position. - h = found + 1; - h_length = remaining - 1; - } - - // Unreachable, but helps silence compiler warnings: - return SZ_NULL_CHAR; -} - -/** - * @brief Exact reverse-order substring search helper function, that finds the last occurrence of a suffix of the - * needle using a given search function, and then verifies the remaining part of the needle. - */ -SZ_INTERNAL sz_cptr_t _sz_rfind_with_suffix(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length, - sz_find_t find_suffix, sz_size_t suffix_length) { - - sz_size_t prefix_length = n_length - suffix_length; - while (1) { - sz_cptr_t found = find_suffix(h, h_length, n + prefix_length, suffix_length); - if (!found) return SZ_NULL_CHAR; - - // Verify the remaining part of the needle - sz_size_t remaining = found - h; - if (remaining < prefix_length) return SZ_NULL_CHAR; - if (sz_equal(found - prefix_length, n, prefix_length)) return found - prefix_length; - - // Adjust the position. - h_length = remaining - 1; - } - - // Unreachable, but helps silence compiler warnings: - return SZ_NULL_CHAR; -} - -SZ_INTERNAL sz_cptr_t _sz_find_over_4bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - return _sz_find_with_prefix(h, h_length, n, n_length, (sz_find_t)_sz_find_4byte_serial, 4); -} - -SZ_INTERNAL sz_cptr_t _sz_find_horspool_over_256bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, - sz_size_t n_length) { - return _sz_find_with_prefix(h, h_length, n, n_length, _sz_find_horspool_upto_256bytes_serial, 256); -} - -SZ_INTERNAL sz_cptr_t _sz_rfind_horspool_over_256bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, - sz_size_t n_length) { - return _sz_rfind_with_suffix(h, h_length, n, n_length, _sz_rfind_horspool_upto_256bytes_serial, 256); -} - -SZ_PUBLIC sz_cptr_t sz_find_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - -#if SZ_DETECT_BIG_ENDIAN - sz_find_t backends[] = { - (sz_find_t)sz_find_byte_serial, - (sz_find_t)_sz_find_horspool_upto_256bytes_serial, - (sz_find_t)_sz_find_horspool_over_256bytes_serial, - }; - - return backends[(n_length > 1) + (n_length > 256)](h, h_length, n, n_length); -#else - sz_find_t backends[] = { - // For very short strings brute-force SWAR makes sense. - (sz_find_t)sz_find_byte_serial, - (sz_find_t)_sz_find_2byte_serial, - (sz_find_t)_sz_find_3byte_serial, - (sz_find_t)_sz_find_4byte_serial, - // To avoid constructing the skip-table, let's use the prefixed approach. - (sz_find_t)_sz_find_over_4bytes_serial, - // For longer needles - use skip tables. - (sz_find_t)_sz_find_horspool_upto_256bytes_serial, - (sz_find_t)_sz_find_horspool_over_256bytes_serial, - }; - - return backends[ - // For very short strings brute-force SWAR makes sense. 
- (n_length > 1) + (n_length > 2) + (n_length > 3) + - // To avoid constructing the skip-table, let's use the prefixed approach. - (n_length > 4) + - // For longer needles - use skip tables. - (n_length > 8) + (n_length > 256)](h, h_length, n, n_length); -#endif -} - -SZ_PUBLIC sz_cptr_t sz_rfind_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - - sz_find_t backends[] = { - // For very short strings brute-force SWAR makes sense. - (sz_find_t)sz_rfind_byte_serial, - // TODO: implement reverse-order SWAR for 2/3/4 byte variants. - // TODO: (sz_find_t)_sz_rfind_2byte_serial, - // TODO: (sz_find_t)_sz_rfind_3byte_serial, - // TODO: (sz_find_t)_sz_rfind_4byte_serial, - // To avoid constructing the skip-table, let's use the prefixed approach. - // (sz_find_t)_sz_rfind_over_4bytes_serial, - // For longer needles - use skip tables. - (sz_find_t)_sz_rfind_horspool_upto_256bytes_serial, - (sz_find_t)_sz_rfind_horspool_over_256bytes_serial, - }; - - return backends[ - // For very short strings brute-force SWAR makes sense. - 0 + - // To avoid constructing the skip-table, let's use the prefixed approach. - (n_length > 1) + - // For longer needles - use skip tables. - (n_length > 256)](h, h_length, n, n_length); -} - -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_serial( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - - // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. - sz_memory_allocator_t global_alloc; - if (!alloc) { - sz_memory_allocator_init_default(&global_alloc); - alloc = &global_alloc; - } - - // TODO: Generalize to remove the following asserts! - sz_assert(!bound && "For bounded search the method should only evaluate one band of the matrix."); - sz_assert(shorter_length == longer_length && "The method hasn't been generalized to different length inputs yet."); - sz_unused(longer_length && bound); - - // We are going to store 3 diagonals of the matrix. - // The length of the longest (main) diagonal would be `n = (shorter_length + 1)`. - sz_size_t n = shorter_length + 1; - sz_size_t buffer_length = sizeof(sz_size_t) * n * 3; - sz_size_t *distances = (sz_size_t *)alloc->allocate(buffer_length, alloc->handle); - if (!distances) return SZ_SIZE_MAX; - - sz_size_t *previous_distances = distances; - sz_size_t *current_distances = previous_distances + n; - sz_size_t *next_distances = previous_distances + n * 2; - - // Initialize the first two diagonals: - previous_distances[0] = 0; - current_distances[0] = current_distances[1] = 1; - - // Progress through the upper triangle of the Levenshtein matrix. - sz_size_t next_diagonal_index = 2; - for (; next_diagonal_index != n; ++next_diagonal_index) { - sz_size_t const next_diagonal_length = next_diagonal_index + 1; - for (sz_size_t i = 0; i + 2 < next_diagonal_length; ++i) { - sz_size_t cost_of_substitution = shorter[next_diagonal_index - i - 2] != longer[i]; - sz_size_t cost_if_substitution = previous_distances[i] + cost_of_substitution; - sz_size_t cost_if_deletion_or_insertion = sz_min_of_two(current_distances[i], current_distances[i + 1]) + 1; - next_distances[i + 1] = sz_min_of_two(cost_if_deletion_or_insertion, cost_if_substitution); - } - // Don't forget to populate the first row and the first column of the Levenshtein matrix. 
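/* The first and the last element of every new diagonal lie in the first row and the first column of the
 * classical Levenshtein matrix, where the distance is just the number of insertions or deletions, i.e.
 * the diagonal index itself. The diagonal recurrence above computes the same quantity as the textbook
 * definition; a tiny exponential reference sketch (hypothetical, usable only to cross-check very short
 * inputs, not part of the library) would look like this: */
static sz_size_t _sz_edit_distance_reference(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) {
    if (!a_length) return b_length;
    if (!b_length) return a_length;
    // Either substitute (or keep) the leading characters, or delete from one of the strings.
    sz_size_t if_substitution = _sz_edit_distance_reference(a + 1, a_length - 1, b + 1, b_length - 1) + (a[0] != b[0]);
    sz_size_t if_deletion = _sz_edit_distance_reference(a + 1, a_length - 1, b, b_length) + 1;
    sz_size_t if_insertion = _sz_edit_distance_reference(a, a_length, b + 1, b_length - 1) + 1;
    return sz_min_of_two(if_substitution, sz_min_of_two(if_deletion, if_insertion));
}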
- next_distances[0] = next_distances[next_diagonal_length - 1] = next_diagonal_index; - // Perform a circular rotation of those buffers, to reuse the memory. - sz_size_t *temporary = previous_distances; - previous_distances = current_distances; - current_distances = next_distances; - next_distances = temporary; - } - - // By now we've scanned through the upper triangle of the matrix, where each subsequent iteration results in a - // larger diagonal. From now onwards, we will be shrinking. Instead of adding value equal to the skewed diagonal - // index on either side, we will be cropping those values out. - sz_size_t diagonals_count = n + n - 1; - for (; next_diagonal_index != diagonals_count; ++next_diagonal_index) { - sz_size_t const next_diagonal_length = diagonals_count - next_diagonal_index; - for (sz_size_t i = 0; i != next_diagonal_length; ++i) { - sz_size_t cost_of_substitution = shorter[shorter_length - 1 - i] != longer[next_diagonal_index - n + i]; - sz_size_t cost_if_substitution = previous_distances[i] + cost_of_substitution; - sz_size_t cost_if_deletion_or_insertion = sz_min_of_two(current_distances[i], current_distances[i + 1]) + 1; - next_distances[i] = sz_min_of_two(cost_if_deletion_or_insertion, cost_if_substitution); - } - // Perform a circular rotation of those buffers, to reuse the memory, this time, with a shift, - // dropping the first element in the current array. - sz_size_t *temporary = previous_distances; - previous_distances = current_distances + 1; - current_distances = next_distances; - next_distances = temporary; - } - - // Cache scalar before `free` call. - sz_size_t result = current_distances[0]; - alloc->free(distances, buffer_length, alloc->handle); - return result; -} - -/** - * @brief Describes the length of a UTF8 character / codepoint / rune in bytes. - */ -typedef enum { - sz_utf8_invalid_k = 0, //!< Invalid UTF8 character. - sz_utf8_rune_1byte_k = 1, //!< 1-byte UTF8 character. - sz_utf8_rune_2bytes_k = 2, //!< 2-byte UTF8 character. - sz_utf8_rune_3bytes_k = 3, //!< 3-byte UTF8 character. - sz_utf8_rune_4bytes_k = 4, //!< 4-byte UTF8 character. -} sz_rune_length_t; - -typedef sz_u32_t sz_rune_t; - -/** - * @brief Extracts just one UTF8 codepoint from a UTF8 string into a 32-bit unsigned integer. - */ -SZ_INTERNAL void _sz_extract_utf8_rune(sz_cptr_t utf8, sz_rune_t *code, sz_rune_length_t *code_length) { - sz_u8_t const *current = (sz_u8_t const *)utf8; - sz_u8_t leading_byte = *current++; - sz_rune_t ch; - sz_rune_length_t ch_length; - - // TODO: This can be made entirely branchless using 32-bit SWAR. - if (leading_byte < 0x80) { - // Single-byte rune (0xxxxxxx) - ch = leading_byte; - ch_length = sz_utf8_rune_1byte_k; - } - else if ((leading_byte & 0xE0) == 0xC0) { - // Two-byte rune (110xxxxx 10xxxxxx) - ch = (leading_byte & 0x1F) << 6; - ch |= (*current++ & 0x3F); - ch_length = sz_utf8_rune_2bytes_k; - } - else if ((leading_byte & 0xF0) == 0xE0) { - // Three-byte rune (1110xxxx 10xxxxxx 10xxxxxx) - ch = (leading_byte & 0x0F) << 12; - ch |= (*current++ & 0x3F) << 6; - ch |= (*current++ & 0x3F); - ch_length = sz_utf8_rune_3bytes_k; - } - else if ((leading_byte & 0xF8) == 0xF0) { - // Four-byte rune (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx) - ch = (leading_byte & 0x07) << 18; - ch |= (*current++ & 0x3F) << 12; - ch |= (*current++ & 0x3F) << 6; - ch |= (*current++ & 0x3F); - ch_length = sz_utf8_rune_4bytes_k; - } - else { - // Invalid UTF8 rune. 
-        ch = 0;
-        ch_length = sz_utf8_invalid_k;
-    }
-    *code = ch;
-    *code_length = ch_length;
-}
-
-/**
- * @brief Exports a UTF8 string into a UTF32 buffer.
- * ! The result is undefined if the UTF8 string is corrupted.
- * @return The length in the number of codepoints.
- */
-SZ_INTERNAL sz_size_t _sz_export_utf8_to_utf32(sz_cptr_t utf8, sz_size_t utf8_length, sz_rune_t *utf32) {
-    sz_cptr_t const end = utf8 + utf8_length;
-    sz_size_t count = 0;
-    sz_rune_length_t rune_length;
-    for (; utf8 != end; utf8 += rune_length, utf32++, count++) _sz_extract_utf8_rune(utf8, utf32, &rune_length);
-    return count;
-}
-
-/**
- * @brief Compute the Levenshtein distance between two strings using the Wagner-Fisher algorithm.
- * Stores only 2 rows of the Levenshtein matrix, but uses 64-bit integers for the distance values,
- * and upcasts UTF8 variable-length codepoints to 32-bit integers for faster addressing.
- *
- * ! In the worst case, for 2 strings of length 100 that contain just one non-ASCII codepoint, this will result in extra:
- * + 2 rows * 100 slots * 8 bytes/slot = 1600 bytes of memory for the two rows of the Levenshtein matrix.
- * + 100 codepoints * 2 strings * 4 bytes/codepoint = 800 bytes of memory for the UTF32 buffers.
- * = 2400 bytes of memory or @b 12x memory amplification!
- */
-SZ_INTERNAL sz_size_t _sz_edit_distance_wagner_fisher_serial( //
-    sz_cptr_t longer, sz_size_t longer_length,                //
-    sz_cptr_t shorter, sz_size_t shorter_length,              //
-    sz_size_t bound, sz_bool_t can_be_unicode, sz_memory_allocator_t *alloc) {
-
-    // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome.
-    sz_memory_allocator_t global_alloc;
-    if (!alloc) {
-        sz_memory_allocator_init_default(&global_alloc);
-        alloc = &global_alloc;
-    }
-
-    // A good idea may be to dispatch different kernels for different string lengths.
-    // Like using `uint8_t` counters for strings under 255 characters long.
-    // Good in theory, this results in frequent upcasts and downcasts in serial code.
-    // On strings over 20 bytes, using `uint8` over `uint64` on 64-bit x86 CPU doubles the execution time.
-    // So one must be very cautious with such optimizations.
-    typedef sz_size_t _distance_t;
-
-    // Compute the number of columns in our Levenshtein matrix.
-    sz_size_t const n = shorter_length + 1;
-
-    // If a buffering memory-allocator is provided, this operation is practically free,
-    // and cheaper than allocating even 512 bytes (for small distance matrices) on stack.
-    sz_size_t buffer_length = sizeof(_distance_t) * (n * 2);
-
-    // If the strings contain Unicode characters, let's estimate the max character width,
-    // and use it to allocate a larger buffer to decode UTF8.
-    if ((can_be_unicode == sz_true_k) &&
-        (sz_isascii(longer, longer_length) == sz_false_k || sz_isascii(shorter, shorter_length) == sz_false_k)) {
-        buffer_length += (shorter_length + longer_length) * sizeof(sz_rune_t);
-    }
-    else { can_be_unicode = sz_false_k; }
-
-    // If the allocation fails, return the maximum distance.
-    sz_ptr_t const buffer = (sz_ptr_t)alloc->allocate(buffer_length, alloc->handle);
-    if (!buffer) return SZ_SIZE_MAX;
-
-    // Let's export the UTF8 sequence into the newly allocated buffer at the end.
-    if (can_be_unicode == sz_true_k) {
-        sz_rune_t *const longer_utf32 = (sz_rune_t *)(buffer + sizeof(_distance_t) * (n * 2));
-        sz_rune_t *const shorter_utf32 = longer_utf32 + longer_length;
-        // Export the UTF8 sequences into the newly allocated buffer.
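/* As a concrete example of the decoding above: the two-byte sequence 0xC3 0xA9 ("é") yields
 * (0xC3 & 0x1F) << 6 | (0xA9 & 0x3F) = 0x03 << 6 | 0x29 = 0xE9, i.e. U+00E9, and advances the
 * cursor by `sz_utf8_rune_2bytes_k`. After the export below every codepoint occupies a fixed
 * 4-byte `sz_rune_t` slot, so the DP loop can address the i-th character in constant time. */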
- longer_length = _sz_export_utf8_to_utf32(longer, longer_length, longer_utf32); - shorter_length = _sz_export_utf8_to_utf32(shorter, shorter_length, shorter_utf32); - longer = (sz_cptr_t)longer_utf32; - shorter = (sz_cptr_t)shorter_utf32; - } - - // Let's parameterize the core logic for different character types and distance types. -#define _wagner_fisher_unbounded(_distance_t, _char_t) \ - /* Now let's cast our pointer to avoid it in subsequent sections. */ \ - _char_t const *const longer_chars = (_char_t const *)longer; \ - _char_t const *const shorter_chars = (_char_t const *)shorter; \ - _distance_t *previous_distances = (_distance_t *)buffer; \ - _distance_t *current_distances = previous_distances + n; \ - /* Initialize the first row of the Levenshtein matrix with `iota`-style arithmetic progression. */ \ - for (_distance_t idx_shorter = 0; idx_shorter != n; ++idx_shorter) previous_distances[idx_shorter] = idx_shorter; \ - /* The main loop of the algorithm with quadratic complexity. */ \ - for (_distance_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) { \ - _char_t const longer_char = longer_chars[idx_longer]; \ - /* Using pure pointer arithmetic is faster than iterating with an index. */ \ - _char_t const *shorter_ptr = shorter_chars; \ - _distance_t const *previous_ptr = previous_distances; \ - _distance_t *current_ptr = current_distances; \ - _distance_t *const current_end = current_ptr + shorter_length; \ - current_ptr[0] = idx_longer + 1; \ - for (; current_ptr != current_end; ++previous_ptr, ++current_ptr, ++shorter_ptr) { \ - _distance_t cost_substitution = previous_ptr[0] + (_distance_t)(longer_char != shorter_ptr[0]); \ - /* We can avoid `+1` for costs here, shifting it to post-minimum computation, */ \ - /* saving one increment operation. */ \ - _distance_t cost_deletion = previous_ptr[1]; \ - _distance_t cost_insertion = current_ptr[0]; \ - /* ? It might be a good idea to enforce branchless execution here. */ \ - /* ? The caveat being that the benchmarks on longer sequences backfire and more research is needed. */ \ - current_ptr[1] = sz_min_of_two(cost_substitution, sz_min_of_two(cost_deletion, cost_insertion) + 1); \ - } \ - /* Swap `previous_distances` and `current_distances` pointers. */ \ - _distance_t *temporary = previous_distances; \ - previous_distances = current_distances; \ - current_distances = temporary; \ - } \ - /* Cache scalar before `free` call. */ \ - sz_size_t result = previous_distances[shorter_length]; \ - alloc->free(buffer, buffer_length, alloc->handle); \ - return result; - - // Let's define a separate variant for bounded distance computation. - // Practically the same as unbounded, but also collecting the running minimum within each row for early exit. 
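/* The early exit is sound because every cell of the Levenshtein matrix is derived as a minimum over
 * its top, top-left, and left neighbors plus a non-negative cost, so the per-row minimum never
 * decreases from one row to the next. Once an entire row is at or above `bound`, the bottom-right
 * cell (the final distance) must be as well, and the remaining rows can be skipped. */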
-#define _wagner_fisher_bounded(_distance_t, _char_t) \ - _char_t const *const longer_chars = (_char_t const *)longer; \ - _char_t const *const shorter_chars = (_char_t const *)shorter; \ - _distance_t *previous_distances = (_distance_t *)buffer; \ - _distance_t *current_distances = previous_distances + n; \ - for (_distance_t idx_shorter = 0; idx_shorter != n; ++idx_shorter) previous_distances[idx_shorter] = idx_shorter; \ - for (_distance_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) { \ - _char_t const longer_char = longer_chars[idx_longer]; \ - _char_t const *shorter_ptr = shorter_chars; \ - _distance_t const *previous_ptr = previous_distances; \ - _distance_t *current_ptr = current_distances; \ - _distance_t *const current_end = current_ptr + shorter_length; \ - current_ptr[0] = idx_longer + 1; \ - /* Initialize min_distance with a value greater than bound */ \ - _distance_t min_distance = bound - 1; \ - for (; current_ptr != current_end; ++previous_ptr, ++current_ptr, ++shorter_ptr) { \ - _distance_t cost_substitution = previous_ptr[0] + (_distance_t)(longer_char != shorter_ptr[0]); \ - _distance_t cost_deletion = previous_ptr[1]; \ - _distance_t cost_insertion = current_ptr[0]; \ - current_ptr[1] = sz_min_of_two(cost_substitution, sz_min_of_two(cost_deletion, cost_insertion) + 1); \ - /* Keep track of the minimum distance seen so far in this row */ \ - min_distance = sz_min_of_two(current_ptr[1], min_distance); \ - } \ - /* If the minimum distance in this row exceeded the bound, return early */ \ - if (min_distance >= bound) { \ - alloc->free(buffer, buffer_length, alloc->handle); \ - return bound; \ - } \ - _distance_t *temporary = previous_distances; \ - previous_distances = current_distances; \ - current_distances = temporary; \ - } \ - sz_size_t result = previous_distances[shorter_length]; \ - alloc->free(buffer, buffer_length, alloc->handle); \ - return sz_min_of_two(result, bound); - - // Dispatch the actual computation. - if (!bound) { - if (can_be_unicode == sz_true_k) { _wagner_fisher_unbounded(sz_size_t, sz_rune_t); } - else { _wagner_fisher_unbounded(sz_size_t, sz_u8_t); } - } - else { - if (can_be_unicode == sz_true_k) { _wagner_fisher_bounded(sz_size_t, sz_rune_t); } - else { _wagner_fisher_bounded(sz_size_t, sz_u8_t); } - } -} - -SZ_PUBLIC sz_size_t sz_edit_distance_serial( // - sz_cptr_t longer, sz_size_t longer_length, // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - - // Let's make sure that we use the amount proportional to the - // number of elements in the shorter string, not the larger. - if (shorter_length > longer_length) { - sz_pointer_swap((void **)&longer_length, (void **)&shorter_length); - sz_pointer_swap((void **)&longer, (void **)&shorter); - } - - // Skip the matching prefixes and suffixes, they won't affect the distance. - for (sz_cptr_t a_end = longer + longer_length, b_end = shorter + shorter_length; - longer != a_end && shorter != b_end && *longer == *shorter; - ++longer, ++shorter, --longer_length, --shorter_length); - for (; longer_length && shorter_length && longer[longer_length - 1] == shorter[shorter_length - 1]; - --longer_length, --shorter_length); - - // Bounded computations may exit early. - int const is_bounded = bound < longer_length; - if (is_bounded) { - // If one of the strings is empty - the edit distance is equal to the length of the other one. 
- if (longer_length == 0) return sz_min_of_two(shorter_length, bound); - if (shorter_length == 0) return sz_min_of_two(longer_length, bound); - // If the difference in length is beyond the `bound`, there is no need to check at all. - if (longer_length - shorter_length > bound) return bound; - } - - if (shorter_length == 0) return longer_length; // If no mismatches were found - the distance is zero. - if (shorter_length == longer_length && !is_bounded) - return _sz_edit_distance_skewed_diagonals_serial(longer, longer_length, shorter, shorter_length, bound, alloc); - return _sz_edit_distance_wagner_fisher_serial(longer, longer_length, shorter, shorter_length, bound, sz_false_k, - alloc); -} - -SZ_PUBLIC sz_ssize_t sz_alignment_score_serial( // - sz_cptr_t longer, sz_size_t longer_length, // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, // - sz_memory_allocator_t *alloc) { - - // If one of the strings is empty - the edit distance is equal to the length of the other one - if (longer_length == 0) return (sz_ssize_t)shorter_length * gap; - if (shorter_length == 0) return (sz_ssize_t)longer_length * gap; - - // Let's make sure that we use the amount proportional to the - // number of elements in the shorter string, not the larger. - if (shorter_length > longer_length) { - sz_pointer_swap((void **)&longer_length, (void **)&shorter_length); - sz_pointer_swap((void **)&longer, (void **)&shorter); - } - - // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. - sz_memory_allocator_t global_alloc; - if (!alloc) { - sz_memory_allocator_init_default(&global_alloc); - alloc = &global_alloc; - } - - sz_size_t n = shorter_length + 1; - sz_size_t buffer_length = sizeof(sz_ssize_t) * n * 2; - sz_ssize_t *distances = (sz_ssize_t *)alloc->allocate(buffer_length, alloc->handle); - sz_ssize_t *previous_distances = distances; - sz_ssize_t *current_distances = previous_distances + n; - - for (sz_size_t idx_shorter = 0; idx_shorter != n; ++idx_shorter) - previous_distances[idx_shorter] = (sz_ssize_t)idx_shorter * gap; - - sz_u8_t const *shorter_unsigned = (sz_u8_t const *)shorter; - sz_u8_t const *longer_unsigned = (sz_u8_t const *)longer; - for (sz_size_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) { - current_distances[0] = ((sz_ssize_t)idx_longer + 1) * gap; - - // Initialize min_distance with a value greater than bound - sz_error_cost_t const *a_subs = subs + longer_unsigned[idx_longer] * 256ul; - for (sz_size_t idx_shorter = 0; idx_shorter != shorter_length; ++idx_shorter) { - sz_ssize_t cost_deletion = previous_distances[idx_shorter + 1] + gap; - sz_ssize_t cost_insertion = current_distances[idx_shorter] + gap; - sz_ssize_t cost_substitution = previous_distances[idx_shorter] + a_subs[shorter_unsigned[idx_shorter]]; - current_distances[idx_shorter + 1] = sz_max_of_three(cost_deletion, cost_insertion, cost_substitution); - } - - // Swap previous_distances and current_distances pointers - sz_pointer_swap((void **)&previous_distances, (void **)¤t_distances); - } - - // Cache scalar before `free` call. 
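/* A hypothetical usage sketch for the scoring kernel above (the wrapper name and cost choices are
 * illustrative, not a library API): a substitution matrix with 0 on the diagonal and -1 elsewhere,
 * combined with a gap cost of -1, makes the maximal global alignment score equal to the negated
 * Levenshtein distance. */
static sz_ssize_t _sz_demo_alignment_score(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) {
    static sz_error_cost_t substitutions[256][256];
    // Re-filled on every call for brevity; a real caller would build the matrix once.
    for (sz_size_t i = 0; i != 256; ++i)
        for (sz_size_t j = 0; j != 256; ++j) substitutions[i][j] = i == j ? 0 : -1;
    // Passing a NULL allocator falls back to the default malloc-backed one, as in the code above.
    return sz_alignment_score_serial(a, a_length, b, b_length, &substitutions[0][0], -1, (sz_memory_allocator_t *)0);
}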
- sz_ssize_t result = previous_distances[shorter_length]; - alloc->free(distances, buffer_length, alloc->handle); - return result; -} - -SZ_PUBLIC sz_size_t sz_hamming_distance_serial( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - - sz_size_t const min_length = sz_min_of_two(a_length, b_length); - sz_size_t const max_length = sz_max_of_two(a_length, b_length); - sz_cptr_t const a_end = a + min_length; - bound = bound == 0 ? max_length : bound; - - // Walk through both strings using SWAR and counting the number of differing characters. - sz_size_t distance = max_length - min_length; -#if SZ_USE_MISALIGNED_LOADS && !SZ_DETECT_BIG_ENDIAN - if (min_length >= SZ_SWAR_THRESHOLD) { - sz_u64_vec_t a_vec, b_vec, match_vec; - for (; a + 8 <= a_end && distance < bound; a += 8, b += 8) { - a_vec.u64 = sz_u64_load(a).u64; - b_vec.u64 = sz_u64_load(b).u64; - match_vec = _sz_u64_each_byte_equal(a_vec, b_vec); - distance += sz_u64_popcount((~match_vec.u64) & 0x8080808080808080ull); - } - } -#endif - - for (; a != a_end && distance < bound; ++a, ++b) { distance += (*a != *b); } - return sz_min_of_two(distance, bound); -} - -SZ_PUBLIC sz_size_t sz_hamming_distance_utf8_serial( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - - sz_cptr_t const a_end = a + a_length; - sz_cptr_t const b_end = b + b_length; - sz_size_t distance = 0; - - sz_rune_t a_rune, b_rune; - sz_rune_length_t a_rune_length, b_rune_length; - - if (bound) { - for (; a < a_end && b < b_end && distance < bound; a += a_rune_length, b += b_rune_length) { - _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - distance += (a_rune != b_rune); - } - // If one string has more runes, we need to go through the tail. - if (distance < bound) { - for (; a < a_end && distance < bound; a += a_rune_length, ++distance) - _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - - for (; b < b_end && distance < bound; b += b_rune_length, ++distance) - _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - } - } - else { - for (; a < a_end && b < b_end; a += a_rune_length, b += b_rune_length) { - _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - distance += (a_rune != b_rune); - } - // If one string has more runes, we need to go through the tail. - for (; a < a_end; a += a_rune_length, ++distance) _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - for (; b < b_end; b += b_rune_length, ++distance) _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - } - return distance; -} - -SZ_PUBLIC sz_u64_t sz_checksum_serial(sz_cptr_t text, sz_size_t length) { - sz_u64_t checksum = 0; - sz_u8_t const *text_u8 = (sz_u8_t const *)text; - sz_u8_t const *text_end = text_u8 + length; - for (; text_u8 != text_end; ++text_u8) checksum += *text_u8; - return checksum; -} - -/** - * @brief Largest prime number that fits into 31 bits. - * @see https://mersenneforum.org/showthread.php?t=3471 - */ -#define SZ_U32_MAX_PRIME (2147483647u) - -/** - * @brief Largest prime number that fits into 64 bits. - * @see https://mersenneforum.org/showthread.php?t=3471 - * - * 2^64 = 18,446,744,073,709,551,616 - * this = 18,446,744,073,709,551,557 - * diff = 59 - */ -#define SZ_U64_MAX_PRIME (18446744073709551557ull) - -/* - * One hardware-accelerated way of mixing hashes can be CRC, but it's only implemented for 32-bit values. 
- * Using a Boost-like mixer works very poorly in such case: - * - * hash_first ^ (hash_second + 0x517cc1b727220a95 + (hash_first << 6) + (hash_first >> 2)); - * - * Let's stick to the Fibonacci hash trick using the golden ratio. - * https://probablydance.com/2018/06/16/fibonacci-hashing-the-optimization-that-the-world-forgot-or-a-better-alternative-to-integer-modulo/ - */ -#define _sz_hash_mix(first, second) ((first * 11400714819323198485ull) ^ (second * 11400714819323198485ull)) -#define _sz_shift_low(x) (x) -#define _sz_shift_high(x) ((x + 77ull) & 0xFFull) -#define _sz_prime_mod(x) (x % SZ_U64_MAX_PRIME) - -SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t start, sz_size_t length) { - - sz_u64_t hash_low = 0; - sz_u64_t hash_high = 0; - sz_u8_t const *text = (sz_u8_t const *)start; - sz_u8_t const *text_end = text + length; - - switch (length) { - case 0: return 0; - - // Texts under 7 bytes long are definitely below the largest prime. - case 1: - hash_low = _sz_shift_low(text[0]); - hash_high = _sz_shift_high(text[0]); - break; - case 2: - hash_low = _sz_shift_low(text[0]) * 31ull + _sz_shift_low(text[1]); - hash_high = _sz_shift_high(text[0]) * 257ull + _sz_shift_high(text[1]); - break; - case 3: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull + // - _sz_shift_low(text[2]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull + // - _sz_shift_high(text[2]); - break; - case 4: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull + // - _sz_shift_low(text[3]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull + // - _sz_shift_high(text[3]); - break; - case 5: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull * 31ull + // - _sz_shift_low(text[3]) * 31ull + // - _sz_shift_low(text[4]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull * 257ull + // - _sz_shift_high(text[3]) * 257ull + // - _sz_shift_high(text[4]); - break; - case 6: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[3]) * 31ull * 31ull + // - _sz_shift_low(text[4]) * 31ull + // - _sz_shift_low(text[5]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[3]) * 257ull * 257ull + // - _sz_shift_high(text[4]) * 257ull + // - _sz_shift_high(text[5]); - break; - case 7: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[3]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[4]) * 31ull * 31ull + // - _sz_shift_low(text[5]) * 31ull + // - _sz_shift_low(text[6]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull 
* 257ull * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[3]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[4]) * 257ull * 257ull + // - _sz_shift_high(text[5]) * 257ull + // - _sz_shift_high(text[6]); - break; - default: - // Unroll the first seven cycles: - hash_low = hash_low * 31ull + _sz_shift_low(text[0]); - hash_high = hash_high * 257ull + _sz_shift_high(text[0]); - hash_low = hash_low * 31ull + _sz_shift_low(text[1]); - hash_high = hash_high * 257ull + _sz_shift_high(text[1]); - hash_low = hash_low * 31ull + _sz_shift_low(text[2]); - hash_high = hash_high * 257ull + _sz_shift_high(text[2]); - hash_low = hash_low * 31ull + _sz_shift_low(text[3]); - hash_high = hash_high * 257ull + _sz_shift_high(text[3]); - hash_low = hash_low * 31ull + _sz_shift_low(text[4]); - hash_high = hash_high * 257ull + _sz_shift_high(text[4]); - hash_low = hash_low * 31ull + _sz_shift_low(text[5]); - hash_high = hash_high * 257ull + _sz_shift_high(text[5]); - hash_low = hash_low * 31ull + _sz_shift_low(text[6]); - hash_high = hash_high * 257ull + _sz_shift_high(text[6]); - text += 7; - - // Iterate throw the rest with the modulus: - for (; text != text_end; ++text) { - hash_low = hash_low * 31ull + _sz_shift_low(text[0]); - hash_high = hash_high * 257ull + _sz_shift_high(text[0]); - // Wrap the hashes around: - hash_low = _sz_prime_mod(hash_low); - hash_high = _sz_prime_mod(hash_high); - } - break; - } + sz_u64_vec_t h_even_vec, h_odd_vec, n_vec, matches_even_vec, matches_odd_vec; + n_vec.u64 = 0; + n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1]; + n_vec.u64 *= 0x0001000100010001ull; // broadcast - return _sz_hash_mix(hash_low, hash_high); -} + // This code simulates hyper-scalar execution, analyzing 8 offsets at a time. + for (; h + 9 <= h_end; h += 8) { + h_even_vec.u64 = *(sz_u64_t *)h; + h_odd_vec.u64 = (h_even_vec.u64 >> 8) | ((sz_u64_t)h[8] << 56); + matches_even_vec = _sz_u64_each_2byte_equal(h_even_vec, n_vec); + matches_odd_vec = _sz_u64_each_2byte_equal(h_odd_vec, n_vec); -SZ_PUBLIC void sz_hashes_serial(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle) { - - if (length < window_length || !window_length) return; - sz_u8_t const *text = (sz_u8_t const *)start; - sz_u8_t const *text_end = text + length; - - // Prepare the `prime ^ window_length` values, that we are going to use for modulo arithmetic. - sz_u64_t prime_power_low = 1, prime_power_high = 1; - for (sz_size_t i = 0; i + 1 < window_length; ++i) - prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME, - prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME; - - // Compute the initial hash value for the first window. - sz_u64_t hash_low = 0, hash_high = 0, hash_mix; - for (sz_u8_t const *first_end = text + window_length; text < first_end; ++text) - hash_low = (hash_low * 31ull + _sz_shift_low(*text)) % SZ_U64_MAX_PRIME, - hash_high = (hash_high * 257ull + _sz_shift_high(*text)) % SZ_U64_MAX_PRIME; - - // In most cases the fingerprint length will be a power of two. - hash_mix = _sz_hash_mix(hash_low, hash_high); - callback((sz_cptr_t)text, window_length, hash_mix, callback_handle); - - // Compute the hash value for every window, exporting into the fingerprint, - // using the expensive modulo operation. 
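/* Each step of the loop below implements the usual polynomial rolling-hash update: the outgoing
 * character's contribution, c_out * 31^(window_length - 1), is subtracted (that power is kept in
 * `prime_power_low`), the remainder is scaled by the base, the incoming character is added, and the
 * result is reduced modulo SZ_U64_MAX_PRIME; the 257-based high hash is updated the same way. */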
- sz_size_t cycles = 1; - sz_size_t const step_mask = step - 1; - for (; text < text_end; ++text, ++cycles) { - // Discard one character: - hash_low -= _sz_shift_low(*(text - window_length)) * prime_power_low; - hash_high -= _sz_shift_high(*(text - window_length)) * prime_power_high; - // And add a new one: - hash_low = 31ull * hash_low + _sz_shift_low(*text); - hash_high = 257ull * hash_high + _sz_shift_high(*text); - // Wrap the hashes around: - hash_low = _sz_prime_mod(hash_low); - hash_high = _sz_prime_mod(hash_high); - // Mix only if we've skipped enough hashes. - if ((cycles & step_mask) == 0) { - hash_mix = _sz_hash_mix(hash_low, hash_high); - callback((sz_cptr_t)text, window_length, hash_mix, callback_handle); + matches_even_vec.u64 >>= 8; + if (matches_even_vec.u64 + matches_odd_vec.u64) { + sz_u64_t match_indicators = matches_even_vec.u64 | matches_odd_vec.u64; + return h + sz_u64_ctz(match_indicators) / 8; } } -} - -#undef _sz_shift_low -#undef _sz_shift_high -#undef _sz_hash_mix -#undef _sz_prime_mod - -/** - * @brief Uses a small lookup-table to convert a lowercase character to uppercase. - */ -SZ_INTERNAL sz_u8_t sz_u8_tolower(sz_u8_t c) { - static sz_u8_t const lowered[256] = { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, // - 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, // - 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, // - 64, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, // - 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 91, 92, 93, 94, 95, // - 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, // - 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, // - 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, // - 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, // - 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, // - 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, // - 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // - 240, 241, 242, 243, 244, 245, 246, 215, 248, 249, 250, 251, 252, 253, 254, 223, // - 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // - 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, // - }; - return lowered[c]; -} -/** - * @brief Uses a small lookup-table to convert an uppercase character to lowercase. 
- */ -SZ_INTERNAL sz_u8_t sz_u8_toupper(sz_u8_t c) { - static sz_u8_t const upped[256] = { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, // - 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, // - 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, // - 64, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, // - 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 91, 92, 93, 94, 95, // - 96, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, // - 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 123, 124, 125, 126, 127, // - 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, // - 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, // - 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, // - 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, // - 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // - 240, 241, 242, 243, 244, 245, 246, 215, 248, 249, 250, 251, 252, 253, 254, 223, // - 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // - 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, // - }; - return upped[c]; + for (; h + 2 <= h_end; ++h) + if ((h[0] == n[0]) + (h[1] == n[1]) == 2) return h; + return SZ_NULL_CHAR; } /** - * @brief Uses two small lookup tables (768 bytes total) to accelerate division by a small - * unsigned integer. Performs two lookups, one multiplication, two shifts, and two accumulations. - * - * @param divisor Integral value @b larger than one. - * @param number Integral value to divide. + * @brief 4Byte-level equality comparison between two 64-bit integers. + * @return 64-bit integer, where every top bit in each 4byte signifies a match. 
*/ -SZ_INTERNAL sz_u8_t sz_u8_divide(sz_u8_t number, sz_u8_t divisor) { - sz_assert(divisor > 1); - static sz_u16_t const multipliers[256] = { - 0, 0, 0, 21846, 0, 39322, 21846, 9363, 0, 50973, 39322, 29790, 21846, 15124, 9363, 4370, - 0, 57826, 50973, 44841, 39322, 34329, 29790, 25645, 21846, 18351, 15124, 12137, 9363, 6780, 4370, 2115, - 0, 61565, 57826, 54302, 50973, 47824, 44841, 42011, 39322, 36765, 34329, 32006, 29790, 27671, 25645, 23705, - 21846, 20063, 18351, 16706, 15124, 13602, 12137, 10725, 9363, 8049, 6780, 5554, 4370, 3224, 2115, 1041, - 0, 63520, 61565, 59668, 57826, 56039, 54302, 52614, 50973, 49377, 47824, 46313, 44841, 43407, 42011, 40649, - 39322, 38028, 36765, 35532, 34329, 33154, 32006, 30885, 29790, 28719, 27671, 26647, 25645, 24665, 23705, 22766, - 21846, 20945, 20063, 19198, 18351, 17520, 16706, 15907, 15124, 14356, 13602, 12863, 12137, 11424, 10725, 10038, - 9363, 8700, 8049, 7409, 6780, 6162, 5554, 4957, 4370, 3792, 3224, 2665, 2115, 1573, 1041, 517, - 0, 64520, 63520, 62535, 61565, 60609, 59668, 58740, 57826, 56926, 56039, 55164, 54302, 53452, 52614, 51788, - 50973, 50169, 49377, 48595, 47824, 47063, 46313, 45572, 44841, 44120, 43407, 42705, 42011, 41326, 40649, 39982, - 39322, 38671, 38028, 37392, 36765, 36145, 35532, 34927, 34329, 33738, 33154, 32577, 32006, 31443, 30885, 30334, - 29790, 29251, 28719, 28192, 27671, 27156, 26647, 26143, 25645, 25152, 24665, 24182, 23705, 23233, 22766, 22303, - 21846, 21393, 20945, 20502, 20063, 19628, 19198, 18772, 18351, 17933, 17520, 17111, 16706, 16305, 15907, 15514, - 15124, 14738, 14356, 13977, 13602, 13231, 12863, 12498, 12137, 11779, 11424, 11073, 10725, 10380, 10038, 9699, - 9363, 9030, 8700, 8373, 8049, 7727, 7409, 7093, 6780, 6470, 6162, 5857, 5554, 5254, 4957, 4662, - 4370, 4080, 3792, 3507, 3224, 2943, 2665, 2388, 2115, 1843, 1573, 1306, 1041, 778, 517, 258, - }; - // This table can be avoided using a single addition and counting trailing zeros. 
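/* As a concrete check of this division scheme: dividing 200 by 3 uses multipliers[3] = 21846 (above)
 * and shifts[3] = 1 (below), so q = (21846 * 200) >> 16 = 66, t = ((200 - 66) >> 1) + 66 = 133, and
 * the result is t >> 1 = 66, which matches 200 / 3 exactly. */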
- static sz_u8_t const shifts[256] = { - 0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, // - 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, // - 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, // - 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, // - 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // - 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // - 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // - 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // - }; - sz_u32_t multiplier = multipliers[divisor]; - sz_u8_t shift = shifts[divisor]; - - sz_u16_t q = (sz_u16_t)((multiplier * number) >> 16); - sz_u16_t t = ((number - q) >> 1) + q; - return (sz_u8_t)(t >> shift); -} - -SZ_PUBLIC void sz_look_up_transform_serial(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result) { - sz_u8_t const *unsigned_lut = (sz_u8_t const *)lut; - sz_u8_t const *unsigned_text = (sz_u8_t const *)text; - sz_u8_t *unsigned_result = (sz_u8_t *)result; - sz_u8_t const *end = unsigned_text + length; - for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = unsigned_lut[*unsigned_text]; -} - -SZ_PUBLIC void sz_tolower_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) { - sz_u8_t *unsigned_result = (sz_u8_t *)result; - sz_u8_t const *unsigned_text = (sz_u8_t const *)text; - sz_u8_t const *end = unsigned_text + length; - for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = sz_u8_tolower(*unsigned_text); -} - -SZ_PUBLIC void sz_toupper_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) { - sz_u8_t *unsigned_result = (sz_u8_t *)result; - sz_u8_t const *unsigned_text = (sz_u8_t const *)text; - sz_u8_t const *end = unsigned_text + length; - for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = sz_u8_toupper(*unsigned_text); -} - -SZ_PUBLIC void sz_toascii_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) { - sz_u8_t *unsigned_result = (sz_u8_t *)result; - sz_u8_t const *unsigned_text = (sz_u8_t const *)text; - sz_u8_t const *end = unsigned_text + length; - for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = *unsigned_text & 0x7F; +SZ_INTERNAL sz_u64_vec_t _sz_u64_each_4byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { + sz_u64_vec_t vec; + vec.u64 = ~(a.u64 ^ b.u64); + // The match is valid, if every bit within each 4byte is set. + // For that take the bottom 31 bits of each 4byte, add one to them, + // and if this sets the top bit to one, then all the 31 bits are ones as well. + vec.u64 = ((vec.u64 & 0x7FFFFFFF7FFFFFFFull) + 0x0000000100000001ull) & ((vec.u64 & 0x8000000080000000ull)); + return vec; } /** - * @brief Check if there is a byte in this buffer, that exceeds 127 and can't be an ASCII character. - * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. + * @brief Find the first occurrence of a @b four-character needle in an arbitrary length haystack. + * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. 
*/ -SZ_PUBLIC sz_bool_t sz_isascii_serial(sz_cptr_t text, sz_size_t length) { +SZ_INTERNAL sz_cptr_t _sz_find_4byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - if (!length) return sz_true_k; - sz_u8_t const *h = (sz_u8_t const *)text; - sz_u8_t const *const h_end = h + length; + // This is an internal method, and the haystack is guaranteed to be at least 4 bytes long. + sz_assert(h_length >= 4 && "The haystack is too short."); + sz_cptr_t const h_end = h + h_length; #if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)h & 7ull) && h < h_end; ++h) - if (*h & 0x80ull) return sz_false_k; + for (; ((sz_size_t)h & 7ull) && h + 4 <= h_end; ++h) + if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) + (h[3] == n[3]) == 4) return h; #endif - // Validate eight bytes at once using SWAR. - sz_u64_vec_t text_vec; - for (; h + 8 <= h_end; h += 8) { - text_vec.u64 = *(sz_u64_t const *)h; - if (text_vec.u64 & 0x8080808080808080ull) return sz_false_k; - } - - // Handle the misaligned tail. - for (; h < h_end; ++h) - if (*h & 0x80ull) return sz_false_k; - return sz_true_k; -} - -SZ_PUBLIC void sz_generate_serial(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length, - sz_random_generator_t generator, void *generator_user_data) { - - sz_assert(alphabet_size > 0 && alphabet_size <= 256 && "Inadequate alphabet size"); + sz_u64_vec_t h0_vec, h1_vec, h2_vec, h3_vec, n_vec, matches0_vec, matches1_vec, matches2_vec, matches3_vec; + n_vec.u64 = 0; + n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1], n_vec.u8s[2] = n[2], n_vec.u8s[3] = n[3]; + n_vec.u64 *= 0x0000000100000001ull; // broadcast - if (alphabet_size == 1) sz_fill(result, result_length, *alphabet); + // This code simulates hyper-scalar execution, analyzing 8 offsets at a time using four 64-bit words. + // We load the subsequent four-byte word as well, taking its first bytes. Think of it as a glorified prefetch :) + sz_u64_t h_page_current, h_page_next; + for (; h + sizeof(sz_u64_t) + sizeof(sz_u32_t) <= h_end; h += sizeof(sz_u64_t)) { + h_page_current = *(sz_u64_t *)h; + h_page_next = *(sz_u32_t *)(h + 8); + h0_vec.u64 = (h_page_current); + h1_vec.u64 = (h_page_current >> 8) | (h_page_next << 56); + h2_vec.u64 = (h_page_current >> 16) | (h_page_next << 48); + h3_vec.u64 = (h_page_current >> 24) | (h_page_next << 40); + matches0_vec = _sz_u64_each_4byte_equal(h0_vec, n_vec); + matches1_vec = _sz_u64_each_4byte_equal(h1_vec, n_vec); + matches2_vec = _sz_u64_each_4byte_equal(h2_vec, n_vec); + matches3_vec = _sz_u64_each_4byte_equal(h3_vec, n_vec); - else { - sz_assert(generator && "Expects a valid random generator"); - sz_u8_t divisor = (sz_u8_t)alphabet_size; - for (sz_cptr_t end = result + result_length; result != end; ++result) { - sz_u8_t random = generator(generator_user_data) & 0xFF; - sz_u8_t quotient = sz_u8_divide(random, divisor); - *result = alphabet[random - quotient * divisor]; + if (matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64) { + matches0_vec.u64 >>= 24; + matches1_vec.u64 >>= 16; + matches2_vec.u64 >>= 8; + sz_u64_t match_indicators = matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64; + return h + sz_u64_ctz(match_indicators) / 8; } } -} - -#pragma endregion - -/* - * Serial implementation of string class operations. 
- */ -#pragma region Serial Implementation for the String Class - -SZ_PUBLIC sz_bool_t sz_string_is_on_stack(sz_string_t const *string) { - // It doesn't matter if it's on stack or heap, the pointer location is the same. - return (sz_bool_t)((sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0]); -} -SZ_PUBLIC void sz_string_range(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length) { - sz_size_t is_small = (sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0]; - sz_size_t is_big_mask = is_small - 1ull; - *start = string->external.start; // It doesn't matter if it's on stack or heap, the pointer location is the same. - // If the string is small, use branch-less approach to mask-out the top 7 bytes of the length. - *length = string->external.length & (0x00000000000000FFull | is_big_mask); + for (; h + 4 <= h_end; ++h) + if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) + (h[3] == n[3]) == 4) return h; + return SZ_NULL_CHAR; } -SZ_PUBLIC void sz_string_unpack(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length, sz_size_t *space, - sz_bool_t *is_external) { - sz_size_t is_small = (sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0]; - sz_size_t is_big_mask = is_small - 1ull; - *start = string->external.start; // It doesn't matter if it's on stack or heap, the pointer location is the same. - // If the string is small, use branch-less approach to mask-out the top 7 bytes of the length. - *length = string->external.length & (0x00000000000000FFull | is_big_mask); - // In case the string is small, the `is_small - 1ull` will become 0xFFFFFFFFFFFFFFFFull. - *space = sz_u64_blend(SZ_STRING_INTERNAL_SPACE, string->external.space, is_big_mask); - *is_external = (sz_bool_t)!is_small; +/** + * @brief 3Byte-level equality comparison between two 64-bit integers. + * @return 64-bit integer, where every top bit in each 3byte signifies a match. + */ +SZ_INTERNAL sz_u64_vec_t _sz_u64_each_3byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { + sz_u64_vec_t vec; + vec.u64 = ~(a.u64 ^ b.u64); + // The match is valid, if every bit within each 4byte is set. + // For that take the bottom 31 bits of each 4byte, add one to them, + // and if this sets the top bit to one, then all the 31 bits are ones as well. + vec.u64 = ((vec.u64 & 0xFFFF7FFFFF7FFFFFull) + 0x0000000001000001ull) & ((vec.u64 & 0x0000800000800000ull)); + return vec; } -SZ_PUBLIC sz_bool_t sz_string_equal(sz_string_t const *a, sz_string_t const *b) { - // Tempting to say that the external.length is bitwise the same even if it includes - // some bytes of the on-stack payload, but we don't at this writing maintain that invariant. - // (An on-stack string includes noise bytes in the high-order bits of external.length. So do this - // the hard/correct way. - -#if SZ_USE_MISALIGNED_LOADS - // Dealing with StringZilla strings, we know that the `start` pointer always points - // to a word at least 8 bytes long. Therefore, we can compare the first 8 bytes at once. - -#endif - // Alternatively, fall back to byte-by-byte comparison. - sz_ptr_t a_start, b_start; - sz_size_t a_length, b_length; - sz_string_range(a, &a_start, &a_length); - sz_string_range(b, &b_start, &b_length); - return (sz_bool_t)(a_length == b_length && sz_equal(a_start, b_start, b_length)); -} +/** + * @brief Find the first occurrence of a @b three-character needle in an arbitrary length haystack. + * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. 
+ */ +SZ_INTERNAL sz_cptr_t _sz_find_3byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { -SZ_PUBLIC sz_ordering_t sz_string_order(sz_string_t const *a, sz_string_t const *b) { -#if SZ_USE_MISALIGNED_LOADS - // Dealing with StringZilla strings, we know that the `start` pointer always points - // to a word at least 8 bytes long. Therefore, we can compare the first 8 bytes at once. + // This is an internal method, and the haystack is guaranteed to be at least 4 bytes long. + sz_assert(h_length >= 3 && "The haystack is too short."); + sz_cptr_t const h_end = h + h_length; +#if !SZ_USE_MISALIGNED_LOADS + // Process the misaligned head, to void UB on unaligned 64-bit loads. + for (; ((sz_size_t)h & 7ull) && h + 3 <= h_end; ++h) + if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) == 3) return h; #endif - // Alternatively, fall back to byte-by-byte comparison. - sz_ptr_t a_start, b_start; - sz_size_t a_length, b_length; - sz_string_range(a, &a_start, &a_length); - sz_string_range(b, &b_start, &b_length); - return sz_order(a_start, a_length, b_start, b_length); -} - -SZ_PUBLIC void sz_string_init(sz_string_t *string) { - sz_assert(string && "String can't be SZ_NULL."); - - // Only 8 + 1 + 1 need to be initialized. - string->internal.start = &string->internal.chars[0]; - // But for safety let's initialize the entire structure to zeros. - // string->internal.chars[0] = 0; - // string->internal.length = 0; - string->words[1] = 0; - string->words[2] = 0; - string->words[3] = 0; -} - -SZ_PUBLIC sz_ptr_t sz_string_init_length(sz_string_t *string, sz_size_t length, sz_memory_allocator_t *allocator) { - sz_size_t space_needed = length + 1; // space for trailing \0 - sz_assert(string && allocator && "String and allocator can't be SZ_NULL."); - // Initialize the string to zeros for safety. - string->words[1] = 0; - string->words[2] = 0; - string->words[3] = 0; - // If we are lucky, no memory allocations will be needed. - if (space_needed <= SZ_STRING_INTERNAL_SPACE) { - string->internal.start = &string->internal.chars[0]; - string->internal.length = (sz_u8_t)length; - } - else { - // If we are not lucky, we need to allocate memory. - string->external.start = (sz_ptr_t)allocator->allocate(space_needed, allocator->handle); - if (!string->external.start) return SZ_NULL_CHAR; - string->external.length = length; - string->external.space = space_needed; - } - sz_assert(&string->internal.start == &string->external.start && "Alignment confusion"); - string->external.start[length] = 0; - return string->external.start; -} - -SZ_PUBLIC sz_ptr_t sz_string_reserve(sz_string_t *string, sz_size_t new_capacity, sz_memory_allocator_t *allocator) { - - sz_assert(string && allocator && "Strings and allocators can't be SZ_NULL."); - sz_size_t new_space = new_capacity + 1; - if (new_space <= SZ_STRING_INTERNAL_SPACE) return string->external.start; - - sz_ptr_t string_start; - sz_size_t string_length; - sz_size_t string_space; - sz_bool_t string_is_external; - sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - sz_assert(new_space > string_space && "New space must be larger than current."); - - sz_ptr_t new_start = (sz_ptr_t)allocator->allocate(new_space, allocator->handle); - if (!new_start) return SZ_NULL_CHAR; - - sz_copy(new_start, string_start, string_length); - string->external.start = new_start; - string->external.space = new_space; - string->external.padding = 0; - string->external.length = string_length; - - // Deallocate the old string. 
- if (string_is_external) allocator->free(string_start, string_space, allocator->handle); - return string->external.start; -} - -SZ_PUBLIC sz_ptr_t sz_string_shrink_to_fit(sz_string_t *string, sz_memory_allocator_t *allocator) { - - sz_assert(string && allocator && "Strings and allocators can't be SZ_NULL."); - - sz_ptr_t string_start; - sz_size_t string_length; - sz_size_t string_space; - sz_bool_t string_is_external; - sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - - // We may already be space-optimal, and in that case we don't need to do anything. - sz_size_t new_space = string_length + 1; - if (string_space == new_space || !string_is_external) return string->external.start; - - sz_ptr_t new_start = (sz_ptr_t)allocator->allocate(new_space, allocator->handle); - if (!new_start) return SZ_NULL_CHAR; - - sz_copy(new_start, string_start, string_length); - string->external.start = new_start; - string->external.space = new_space; - string->external.padding = 0; - string->external.length = string_length; - - // Deallocate the old string. - if (string_is_external) allocator->free(string_start, string_space, allocator->handle); - return string->external.start; -} - -SZ_PUBLIC sz_ptr_t sz_string_expand(sz_string_t *string, sz_size_t offset, sz_size_t added_length, - sz_memory_allocator_t *allocator) { - - sz_assert(string && allocator && "String and allocator can't be SZ_NULL."); - - sz_ptr_t string_start; - sz_size_t string_length; - sz_size_t string_space; - sz_bool_t string_is_external; - sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - - // The user intended to extend the string. - offset = sz_min_of_two(offset, string_length); - - // If we are lucky, no memory allocations will be needed. - if (string_length + added_length < string_space) { - sz_move(string_start + offset + added_length, string_start + offset, string_length - offset); - string_start[string_length + added_length] = 0; - // Even if the string is on the stack, the `+=` won't affect the tail of the string. - string->external.length += added_length; - } - // If we are not lucky, we need to allocate more memory. - else { - sz_size_t next_planned_size = sz_max_of_two(SZ_CACHE_LINE_WIDTH, string_space * 2ull); - sz_size_t min_needed_space = sz_size_bit_ceil(offset + string_length + added_length + 1); - sz_size_t new_space = sz_max_of_two(min_needed_space, next_planned_size); - string_start = sz_string_reserve(string, new_space - 1, allocator); - if (!string_start) return SZ_NULL_CHAR; - - // Copy into the new buffer. - sz_move(string_start + offset + added_length, string_start + offset, string_length - offset); - string_start[string_length + added_length] = 0; - string->external.length = string_length + added_length; - } - - return string_start; -} - -SZ_PUBLIC sz_size_t sz_string_erase(sz_string_t *string, sz_size_t offset, sz_size_t length) { - - sz_assert(string && "String can't be SZ_NULL."); - - sz_ptr_t string_start; - sz_size_t string_length; - sz_size_t string_space; - sz_bool_t string_is_external; - sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - - // Normalize the offset, it can't be larger than the length. - offset = sz_min_of_two(offset, string_length); - - // We shouldn't normalize the length, to avoid overflowing on `offset + length >= string_length`, - // if receiving `length == SZ_SIZE_MAX`. 
After following expression the `length` will contain - // exactly the delta between original and final length of this `string`. - length = sz_min_of_two(length, string_length - offset); - - // There are 2 common cases, that wouldn't even require a `memmove`: - // 1. Erasing the entire contents of the string. - // In that case `length` argument will be equal or greater than `length` member. - // 2. Removing the tail of the string with something like `string.pop_back()` in C++. - // - // In both of those, regardless of the location of the string - stack or heap, - // the erasing is as easy as setting the length to the offset. - // In every other case, we must `memmove` the tail of the string to the left. - if (offset + length < string_length) - sz_move(string_start + offset, string_start + offset + length, string_length - offset - length); - - // The `string->external.length = offset` assignment would discard last characters - // of the on-the-stack string, but inplace subtraction would work. - string->external.length -= length; - string_start[string_length - length] = 0; - return length; -} - -SZ_PUBLIC void sz_string_free(sz_string_t *string, sz_memory_allocator_t *allocator) { - if (!sz_string_is_on_stack(string)) - allocator->free(string->external.start, string->external.space, allocator->handle); - sz_string_init(string); -} - -// When overriding libc, disable optimisations for this function beacuse MSVC will optimize the loops into a memset. -// Which then causes a stack overflow due to infinite recursion (memset -> sz_fill_serial -> memset). -#if defined(_MSC_VER) && defined(SZ_OVERRIDE_LIBC) && SZ_OVERRIDE_LIBC -#pragma optimize("", off) -#endif -SZ_PUBLIC void sz_fill_serial(sz_ptr_t target, sz_size_t length, sz_u8_t value) { - sz_ptr_t end = target + length; - // Dealing with short strings, a single sequential pass would be faster. - // If the size is larger than 2 words, then at least 1 of them will be aligned. - // But just one aligned word may not be worth SWAR. - if (length < SZ_SWAR_THRESHOLD) - while (target != end) *(target++) = value; - - // In case of long strings, skip unaligned bytes, and then fill the rest in 64-bit chunks. - else { - sz_u64_t value64 = (sz_u64_t)value * 0x0101010101010101ull; - while ((sz_size_t)target & 7ull) *(target++) = value; - while (target + 8 <= end) *(sz_u64_t *)target = value64, target += 8; - while (target != end) *(target++) = value; - } -} -#if defined(_MSC_VER) && defined(SZ_OVERRIDE_LIBC) && SZ_OVERRIDE_LIBC -#pragma optimize("", on) -#endif + // We fetch 12 + sz_u64_vec_t h0_vec, h1_vec, h2_vec, h3_vec, h4_vec; + sz_u64_vec_t matches0_vec, matches1_vec, matches2_vec, matches3_vec, matches4_vec; + sz_u64_vec_t n_vec; + n_vec.u64 = 0; + n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1], n_vec.u8s[2] = n[2]; + n_vec.u64 *= 0x0000000001000001ull; // broadcast -SZ_PUBLIC void sz_copy_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { -#if SZ_USE_MISALIGNED_LOADS - while (length >= 8) *(sz_u64_t *)target = *(sz_u64_t const *)source, target += 8, source += 8, length -= 8; -#endif - while (length--) *(target++) = *(source++); -} + // This code simulates hyper-scalar execution, analyzing 8 offsets at a time using three 64-bit words. + // We load the subsequent two-byte word as well. 
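For reviewers unfamiliar with the SWAR style used throughout this file, here is a minimal standalone sketch (not part of the patch; buffer size and value are arbitrary) of the byte-broadcast trick that both the removed `sz_fill_serial` and the new shift-table initialization rely on: multiplying a byte by 0x0101010101010101 replicates it across all eight lanes of a 64-bit word.

#include <stdint.h>
#include <stdio.h>
#include <string.h>

int main(void) {
    uint8_t value = 0x2A;
    uint64_t value64 = (uint64_t)value * 0x0101010101010101ull; /* 0x2A2A2A2A2A2A2A2A */

    /* A miniature word-at-a-time fill; the real serial kernel also aligns the head and handles the tail. */
    unsigned char buffer[24];
    for (size_t i = 0; i + 8 <= sizeof(buffer); i += 8) memcpy(buffer + i, &value64, 8);

    printf("0x%016llx, buffer[0] = 0x%02X, buffer[23] = 0x%02X\n", //
           (unsigned long long)value64, buffer[0], buffer[23]);
    return 0;
}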
+ sz_u64_t h_page_current, h_page_next; + for (; h + sizeof(sz_u64_t) + sizeof(sz_u16_t) <= h_end; h += sizeof(sz_u64_t)) { + h_page_current = *(sz_u64_t *)h; + h_page_next = *(sz_u16_t *)(h + 8); + h0_vec.u64 = (h_page_current); + h1_vec.u64 = (h_page_current >> 8) | (h_page_next << 56); + h2_vec.u64 = (h_page_current >> 16) | (h_page_next << 48); + h3_vec.u64 = (h_page_current >> 24) | (h_page_next << 40); + h4_vec.u64 = (h_page_current >> 32) | (h_page_next << 32); + matches0_vec = _sz_u64_each_3byte_equal(h0_vec, n_vec); + matches1_vec = _sz_u64_each_3byte_equal(h1_vec, n_vec); + matches2_vec = _sz_u64_each_3byte_equal(h2_vec, n_vec); + matches3_vec = _sz_u64_each_3byte_equal(h3_vec, n_vec); + matches4_vec = _sz_u64_each_3byte_equal(h4_vec, n_vec); -SZ_PUBLIC void sz_move_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - // Implementing `memmove` is trickier, than `memcpy`, as the ranges may overlap. - // Existing implementations often have two passes, in normal and reversed order, - // depending on the relation of `target` and `source` addresses. - // https://student.cs.uwaterloo.ca/~cs350/common/os161-src-html/doxygen/html/memmove_8c_source.html - // https://marmota.medium.com/c-language-making-memmove-def8792bb8d5 - // - // We can use the `memcpy` like left-to-right pass if we know that the `target` is before `source`. - // Or if we know that they don't intersect! In that case the traversal order is irrelevant, - // but older CPUs may predict and fetch forward-passes better. - if (target < source || target >= source + length) { -#if SZ_USE_MISALIGNED_LOADS - while (length >= 8) *(sz_u64_t *)target = *(sz_u64_t const *)(source), target += 8, source += 8, length -= 8; -#endif - while (length--) *(target++) = *(source++); - } - else { - // Jump to the end and walk backwards. - target += length, source += length; -#if SZ_USE_MISALIGNED_LOADS - while (length >= 8) *(sz_u64_t *)(target -= 8) = *(sz_u64_t const *)(source -= 8), length -= 8; -#endif - while (length--) *(--target) = *(--source); + if (matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64 | matches4_vec.u64) { + matches0_vec.u64 >>= 16; + matches1_vec.u64 >>= 8; + matches3_vec.u64 <<= 8; + matches4_vec.u64 <<= 16; + sz_u64_t match_indicators = + matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64 | matches4_vec.u64; + return h + sz_u64_ctz(match_indicators) / 8; + } } -} - -#pragma endregion - -/* - * @brief Serial implementation for strings sequence processing. - */ -#pragma region Serial Implementation for Sequences - -SZ_PUBLIC sz_size_t sz_partition(sz_sequence_t *sequence, sz_sequence_predicate_t predicate) { - - sz_size_t matches = 0; - while (matches != sequence->count && predicate(sequence, sequence->order[matches])) ++matches; - for (sz_size_t i = matches + 1; i < sequence->count; ++i) - if (predicate(sequence, sequence->order[i])) - sz_u64_swap(sequence->order + i, sequence->order + matches), ++matches; - - return matches; + for (; h + 3 <= h_end; ++h) + if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) == 3) return h; + return SZ_NULL_CHAR; } -SZ_PUBLIC void sz_merge(sz_sequence_t *sequence, sz_size_t partition, sz_sequence_comparator_t less) { - - sz_size_t start_b = partition + 1; - - // If the direct merge is already sorted - if (!less(sequence, sequence->order[start_b], sequence->order[partition])) return; +/** + * @brief Boyer-Moore-Horspool algorithm for exact matching of patterns up to @b 256-bytes long. 
+ * Uses the Raita heuristic to match the first two, the last, and the middle character of the pattern. + */ +SZ_INTERNAL sz_cptr_t _sz_find_horspool_upto_256bytes_serial( // + sz_cptr_t h_chars, sz_size_t h_length, // + sz_cptr_t n_chars, sz_size_t n_length) { + sz_assert(n_length <= 256 && "The pattern is too long."); + // Several popular string matching algorithms are using a bad-character shift table. + // Boyer Moore: https://www-igm.univ-mlv.fr/~lecroq/string/node14.html + // Quick Search: https://www-igm.univ-mlv.fr/~lecroq/string/node19.html + // Smith: https://www-igm.univ-mlv.fr/~lecroq/string/node21.html + union { + sz_u8_t jumps[256]; + sz_u64_vec_t vecs[64]; + } bad_shift_table; - sz_size_t start_a = 0; - while (start_a <= partition && start_b <= sequence->count) { + // Let's initialize the table using SWAR to the total length of the string. + sz_u8_t const *h = (sz_u8_t const *)h_chars; + sz_u8_t const *n = (sz_u8_t const *)n_chars; + { + sz_u64_vec_t n_length_vec; + n_length_vec.u64 = n_length; + n_length_vec.u64 *= 0x0101010101010101ull; // broadcast + for (sz_size_t i = 0; i != 64; ++i) bad_shift_table.vecs[i].u64 = n_length_vec.u64; + for (sz_size_t i = 0; i + 1 < n_length; ++i) bad_shift_table.jumps[n[i]] = (sz_u8_t)(n_length - i - 1); + } - // If element 1 is in right place - if (!less(sequence, sequence->order[start_b], sequence->order[start_a])) { start_a++; } - else { - sz_size_t value = sequence->order[start_b]; - sz_size_t index = start_b; + // Another common heuristic is to match a few characters from different parts of a string. + // Raita suggests to use the first two, the last, and the middle character of the pattern. + sz_u32_vec_t h_vec, n_vec; - // Shift all the elements between element 1 - // element 2, right by 1. - while (index != start_a) { sequence->order[index] = sequence->order[index - 1], index--; } - sequence->order[start_a] = value; + // Pick the parts of the needle that are worth comparing. + sz_size_t offset_first, offset_mid, offset_last; + _sz_locate_needle_anomalies(n_chars, n_length, &offset_first, &offset_mid, &offset_last); - // Update all the pointers - start_a++; - partition++; - start_b++; - } - } -} + // Broadcast those characters into an unsigned integer. + n_vec.u8s[0] = n[offset_first]; + n_vec.u8s[1] = n[offset_first + 1]; + n_vec.u8s[2] = n[offset_mid]; + n_vec.u8s[3] = n[offset_last]; -SZ_PUBLIC void sz_sort_insertion(sz_sequence_t *sequence, sz_sequence_comparator_t less) { - sz_u64_t *keys = sequence->order; - sz_size_t keys_count = sequence->count; - for (sz_size_t i = 1; i < keys_count; i++) { - sz_u64_t i_key = keys[i]; - sz_size_t j = i; - for (; j > 0 && less(sequence, i_key, keys[j - 1]); --j) keys[j] = keys[j - 1]; - keys[j] = i_key; + // Scan through the whole haystack, skipping the last `n_length - 1` bytes. 
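For comparison, the textbook Boyer-Moore-Horspool loop that `_sz_find_horspool_upto_256bytes_serial` specializes looks roughly like the standalone sketch below (not part of the patch; function and variable names are made up). The library version differs by storing 8-bit jumps, initializing the table with SWAR, and filtering candidates with the Raita anchors before the full comparison:

#include <stddef.h>
#include <stdio.h>
#include <string.h>

/* Classic Horspool: every byte gets a default shift of `n_len`; bytes that occur in the
   needle (except its last position) get the distance from their last occurrence to the end. */
static char const *horspool(char const *h, size_t h_len, char const *n, size_t n_len) {
    if (!n_len || h_len < n_len) return NULL;
    size_t jumps[256];
    for (size_t c = 0; c != 256; ++c) jumps[c] = n_len;
    for (size_t i = 0; i + 1 < n_len; ++i) jumps[(unsigned char)n[i]] = n_len - i - 1;
    for (size_t i = 0; i + n_len <= h_len;) {
        if (memcmp(h + i, n, n_len) == 0) return h + i;
        i += jumps[(unsigned char)h[i + n_len - 1]]; /* skip by the window's last character */
    }
    return NULL;
}

int main(void) {
    char const *match = horspool("abracadabra", 11, "cad", 3);
    printf("%s\n", match ? match : "(none)"); /* prints "cadabra" */
    return 0;
}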
+ for (sz_size_t i = 0; i <= h_length - n_length;) { + h_vec.u8s[0] = h[i + offset_first]; + h_vec.u8s[1] = h[i + offset_first + 1]; + h_vec.u8s[2] = h[i + offset_mid]; + h_vec.u8s[3] = h[i + offset_last]; + if (h_vec.u32 == n_vec.u32 && sz_equal_serial((sz_cptr_t)h + i, n_chars, n_length)) return (sz_cptr_t)h + i; + i += bad_shift_table.jumps[h[i + n_length - 1]]; } + return SZ_NULL_CHAR; } -SZ_INTERNAL void _sz_sift_down(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_u64_t *order, sz_size_t start, - sz_size_t end) { - sz_size_t root = start; - while (2 * root + 1 <= end) { - sz_size_t child = 2 * root + 1; - if (child + 1 <= end && less(sequence, order[child], order[child + 1])) { child++; } - if (!less(sequence, order[root], order[child])) { return; } - sz_u64_swap(order + root, order + child); - root = child; - } -} +/** + * @brief Boyer-Moore-Horspool algorithm for @b reverse-order exact matching of patterns up to @b 256-bytes long. + * Uses the Raita heuristic to match the first two, the last, and the middle character of the pattern. + */ +SZ_INTERNAL sz_cptr_t _sz_rfind_horspool_upto_256bytes_serial( // + sz_cptr_t h_chars, sz_size_t h_length, // + sz_cptr_t n_chars, sz_size_t n_length) { + sz_assert(n_length <= 256 && "The pattern is too long."); + union { + sz_u8_t jumps[256]; + sz_u64_vec_t vecs[64]; + } bad_shift_table; -SZ_INTERNAL void _sz_heapify(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_u64_t *order, sz_size_t count) { - sz_size_t start = (count - 2) / 2; - while (1) { - _sz_sift_down(sequence, less, order, start, count - 1); - if (start == 0) return; - start--; + // Let's initialize the table using SWAR to the total length of the string. + sz_u8_t const *h = (sz_u8_t const *)h_chars; + sz_u8_t const *n = (sz_u8_t const *)n_chars; + { + sz_u64_vec_t n_length_vec; + n_length_vec.u64 = n_length; + n_length_vec.u64 *= 0x0101010101010101ull; // broadcast + for (sz_size_t i = 0; i != 64; ++i) bad_shift_table.vecs[i].u64 = n_length_vec.u64; + for (sz_size_t i = 0; i + 1 < n_length; ++i) + bad_shift_table.jumps[n[n_length - i - 1]] = (sz_u8_t)(n_length - i - 1); } -} -SZ_INTERNAL void _sz_heapsort(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_size_t first, sz_size_t last) { - sz_u64_t *order = sequence->order; - sz_size_t count = last - first; - _sz_heapify(sequence, less, order + first, count); - sz_size_t end = count - 1; - while (end > 0) { - sz_u64_swap(order + first, order + first + end); - end--; - _sz_sift_down(sequence, less, order + first, 0, end); - } -} + // Another common heuristic is to match a few characters from different parts of a string. + // Raita suggests to use the first two, the last, and the middle character of the pattern. 
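The Raita-style filter mentioned above can be shown in isolation: pack a few anchor characters of the needle into one 32-bit word and compare that word before paying for a full `memcmp`. The sketch below (not part of the patch) hard-codes the first-two/middle/last choice described in the comment; `_sz_locate_needle_anomalies`, defined elsewhere in the header, chooses the offsets dynamically:

#include <stdint.h>
#include <stdio.h>
#include <string.h>

/* Pack the first two, the middle, and the last characters into one comparable 32-bit word. */
static uint32_t anchors(char const *s, size_t length) {
    return (uint32_t)(unsigned char)s[0] |                //
           (uint32_t)(unsigned char)s[1] << 8 |           //
           (uint32_t)(unsigned char)s[length / 2] << 16 | //
           (uint32_t)(unsigned char)s[length - 1] << 24;
}

int main(void) {
    char const *haystack = "the quick brown fox jumps over the lazy dog";
    char const *needle = "brown"; /* at least two characters long for this sketch */
    size_t h_length = strlen(haystack), n_length = strlen(needle);
    uint32_t n_anchors = anchors(needle, n_length);
    for (size_t i = 0; i + n_length <= h_length; ++i)
        /* The cheap 32-bit filter rejects most positions, so `memcmp` runs only on near-matches. */
        if (anchors(haystack + i, n_length) == n_anchors && memcmp(haystack + i, needle, n_length) == 0) {
            printf("match at offset %zu\n", i); /* prints "match at offset 10" */
            return 0;
        }
    printf("no match\n");
    return 0;
}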
+ sz_u32_vec_t h_vec, n_vec; -SZ_PUBLIC void sz_sort_introsort_recursion(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_size_t first, - sz_size_t last, sz_size_t depth) { - - sz_size_t length = last - first; - switch (length) { - case 0: - case 1: return; - case 2: - if (less(sequence, sequence->order[first + 1], sequence->order[first])) - sz_u64_swap(&sequence->order[first], &sequence->order[first + 1]); - return; - case 3: { - sz_u64_t a = sequence->order[first]; - sz_u64_t b = sequence->order[first + 1]; - sz_u64_t c = sequence->order[first + 2]; - if (less(sequence, b, a)) sz_u64_swap(&a, &b); - if (less(sequence, c, b)) sz_u64_swap(&c, &b); - if (less(sequence, b, a)) sz_u64_swap(&a, &b); - sequence->order[first] = a; - sequence->order[first + 1] = b; - sequence->order[first + 2] = c; - return; - } - } - // Until a certain length, the quadratic-complexity insertion-sort is fine - if (length <= 16) { - sz_sequence_t sub_seq = *sequence; - sub_seq.order += first; - sub_seq.count = length; - sz_sort_insertion(&sub_seq, less); - return; - } + // Pick the parts of the needle that are worth comparing. + sz_size_t offset_first, offset_mid, offset_last; + _sz_locate_needle_anomalies(n_chars, n_length, &offset_first, &offset_mid, &offset_last); - // Fallback to N-logN-complexity heap-sort - if (depth == 0) { - _sz_heapsort(sequence, less, first, last); - return; - } + // Broadcast those characters into an unsigned integer. + n_vec.u8s[0] = n[offset_first]; + n_vec.u8s[1] = n[offset_first + 1]; + n_vec.u8s[2] = n[offset_mid]; + n_vec.u8s[3] = n[offset_last]; - --depth; - - // Median-of-three logic to choose pivot - sz_size_t median = first + length / 2; - if (less(sequence, sequence->order[median], sequence->order[first])) - sz_u64_swap(&sequence->order[first], &sequence->order[median]); - if (less(sequence, sequence->order[last - 1], sequence->order[first])) - sz_u64_swap(&sequence->order[first], &sequence->order[last - 1]); - if (less(sequence, sequence->order[median], sequence->order[last - 1])) - sz_u64_swap(&sequence->order[median], &sequence->order[last - 1]); - - // Partition using the median-of-three as the pivot - sz_u64_t pivot = sequence->order[median]; - sz_size_t left = first; - sz_size_t right = last - 1; - while (1) { - while (less(sequence, sequence->order[left], pivot)) left++; - while (less(sequence, pivot, sequence->order[right])) right--; - if (left >= right) break; - sz_u64_swap(&sequence->order[left], &sequence->order[right]); - left++; - right--; + // Scan through the whole haystack, skipping the first `n_length - 1` bytes. 
+ for (sz_size_t j = 0; j <= h_length - n_length;) { + sz_size_t i = h_length - n_length - j; + h_vec.u8s[0] = h[i + offset_first]; + h_vec.u8s[1] = h[i + offset_first + 1]; + h_vec.u8s[2] = h[i + offset_mid]; + h_vec.u8s[3] = h[i + offset_last]; + if (h_vec.u32 == n_vec.u32 && sz_equal_serial((sz_cptr_t)h + i, n_chars, n_length)) return (sz_cptr_t)h + i; + j += bad_shift_table.jumps[h[i]]; } - - // Recursively sort the partitions - sz_sort_introsort_recursion(sequence, less, first, left, depth); - sz_sort_introsort_recursion(sequence, less, right + 1, last, depth); -} - -SZ_PUBLIC void sz_sort_introsort(sz_sequence_t *sequence, sz_sequence_comparator_t less) { - if (sequence->count == 0) return; - sz_size_t size_is_not_power_of_two = (sequence->count & (sequence->count - 1)) != 0; - sz_size_t depth_limit = sz_size_log2i_nonzero(sequence->count) + size_is_not_power_of_two; - sz_sort_introsort_recursion(sequence, less, 0, sequence->count, depth_limit); + return SZ_NULL_CHAR; } -SZ_PUBLIC void sz_sort_recursion( // - sz_sequence_t *sequence, sz_size_t bit_idx, sz_size_t bit_max, sz_sequence_comparator_t comparator, - sz_size_t partial_order_length) { - - if (!sequence->count) return; +/** + * @brief Exact substring search helper function, that finds the first occurrence of a prefix of the needle + * using a given search function, and then verifies the remaining part of the needle. + */ +SZ_INTERNAL sz_cptr_t _sz_find_with_prefix( // + sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length, sz_find_t find_prefix, sz_size_t prefix_length) { - // Array of size one doesn't need sorting - only needs the prefix to be discarded. - if (sequence->count == 1) { - sz_u32_t *order_half_words = (sz_u32_t *)sequence->order; - order_half_words[1] = 0; - return; - } + sz_size_t suffix_length = n_length - prefix_length; + while (1) { + sz_cptr_t found = find_prefix(h, h_length, n, prefix_length); + if (!found) return SZ_NULL_CHAR; - // Partition a range of integers according to a specific bit value - sz_size_t split = 0; - sz_u64_t mask = (1ull << 63) >> bit_idx; + // Verify the remaining part of the needle + sz_size_t remaining = h_length - (found - h); + if (remaining < n_length) return SZ_NULL_CHAR; + if (sz_equal_serial(found + prefix_length, n + prefix_length, suffix_length)) return found; - // The clean approach would be to perform a single pass over the sequence. - // - // while (split != sequence->count && !(sequence->order[split] & mask)) ++split; - // for (sz_size_t i = split + 1; i < sequence->count; ++i) - // if (!(sequence->order[i] & mask)) sz_u64_swap(sequence->order + i, sequence->order + split), ++split; - // - // This, however, doesn't take into account the high relative cost of writes and swaps. - // To circumvent that, we can first count the total number entries to be mapped into either part. - // And then walk through both parts, swapping the entries that are in the wrong part. - // This would often lead to ~15% performance gain. - sz_size_t count_with_bit_set = 0; - for (sz_size_t i = 0; i != sequence->count; ++i) count_with_bit_set += (sequence->order[i] & mask) != 0; - split = sequence->count - count_with_bit_set; - - // It's possible that the sequence is already partitioned. - if (split != 0 && split != sequence->count) { - // Use two pointers to efficiently reposition elements. - // On pointer walks left-to-right from the start, and the other walks right-to-left from the end. 
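`_sz_find_with_prefix` above composes two steps: search for a short, cheap-to-match prefix with one of the fast kernels, then verify the rest of the needle, resuming after any false positive. A standalone sketch of the same composition, using `memchr` as a stand-in prefix search (not part of the patch; names and test strings are made up):

#include <stddef.h>
#include <stdio.h>
#include <string.h>

static char const *find_with_prefix(char const *h, size_t h_length, char const *n, size_t n_length) {
    while (h_length >= n_length) {
        /* Only positions where the whole needle still fits are worth checking. */
        char const *found = (char const *)memchr(h, n[0], h_length - n_length + 1);
        if (!found) return NULL;
        if (memcmp(found, n, n_length) == 0) return found; /* verify the remaining characters */
        h_length -= (size_t)(found - h) + 1;                /* restart right after the candidate */
        h = found + 1;
    }
    return NULL;
}

int main(void) {
    char const *match = find_with_prefix("ananas and bananas", 18, "banana", 6);
    printf("%s\n", match ? match : "(none)"); /* prints "bananas" */
    return 0;
}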
- sz_size_t left = 0; - sz_size_t right = sequence->count - 1; - while (1) { - // Find the next element with the bit set on the left side. - while (left < split && !(sequence->order[left] & mask)) ++left; - // Find the next element without the bit set on the right side. - while (right >= split && (sequence->order[right] & mask)) --right; - // Swap the mispositioned elements. - if (left < split && right >= split) { - sz_u64_swap(sequence->order + left, sequence->order + right); - ++left; - --right; - } - else { break; } - } + // Adjust the position. + h = found + 1; + h_length = remaining - 1; } - // Go down recursively. - if (bit_idx < bit_max) { - sz_sequence_t a = *sequence; - a.count = split; - sz_sort_recursion(&a, bit_idx + 1, bit_max, comparator, partial_order_length); - - sz_sequence_t b = *sequence; - b.order += split; - b.count -= split; - sz_sort_recursion(&b, bit_idx + 1, bit_max, comparator, partial_order_length); - } - // Reached the end of recursion. - else { - // Discard the prefixes. - sz_u32_t *order_half_words = (sz_u32_t *)sequence->order; - for (sz_size_t i = 0; i != sequence->count; ++i) { order_half_words[i * 2 + 1] = 0; } - - sz_sequence_t a = *sequence; - a.count = split; - sz_sort_introsort(&a, comparator); - - sz_sequence_t b = *sequence; - b.order += split; - b.count -= split; - sz_sort_introsort(&b, comparator); - } + // Unreachable, but helps silence compiler warnings: + return SZ_NULL_CHAR; } -SZ_INTERNAL sz_bool_t _sz_sort_is_less(sz_sequence_t *sequence, sz_size_t i_key, sz_size_t j_key) { - sz_cptr_t i_str = sequence->get_start(sequence, i_key); - sz_cptr_t j_str = sequence->get_start(sequence, j_key); - sz_size_t i_len = sequence->get_length(sequence, i_key); - sz_size_t j_len = sequence->get_length(sequence, j_key); - return (sz_bool_t)(sz_order_serial(i_str, i_len, j_str, j_len) == sz_less_k); -} +/** + * @brief Exact reverse-order substring search helper function, that finds the last occurrence of a suffix of the + * needle using a given search function, and then verifies the remaining part of the needle. + */ +SZ_INTERNAL sz_cptr_t _sz_rfind_with_suffix(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length, + sz_find_t find_suffix, sz_size_t suffix_length) { -SZ_PUBLIC void sz_sort_partial(sz_sequence_t *sequence, sz_size_t partial_order_length) { + sz_size_t prefix_length = n_length - suffix_length; + while (1) { + sz_cptr_t found = find_suffix(h, h_length, n + prefix_length, suffix_length); + if (!found) return SZ_NULL_CHAR; -#if SZ_DETECT_BIG_ENDIAN - // TODO: Implement partial sort for big-endian systems. For now this sorts the whole thing. - sz_unused(partial_order_length); - sz_sort_introsort(sequence, (sz_sequence_comparator_t)_sz_sort_is_less); -#else + // Verify the remaining part of the needle + sz_size_t remaining = found - h; + if (remaining < prefix_length) return SZ_NULL_CHAR; + if (sz_equal_serial(found - prefix_length, n, prefix_length)) return found - prefix_length; - // Export up to 4 bytes into the `sequence` bits themselves - for (sz_size_t i = 0; i != sequence->count; ++i) { - sz_cptr_t begin = sequence->get_start(sequence, sequence->order[i]); - sz_size_t length = sequence->get_length(sequence, sequence->order[i]); - length = length > 4u ? 4u : length; - sz_ptr_t prefix = (sz_ptr_t)&sequence->order[i]; - for (sz_size_t j = 0; j != length; ++j) prefix[7 - j] = begin[j]; + // Adjust the position. 
+ h_length = remaining - 1; } - // Perform optionally-parallel radix sort on them - sz_sort_recursion(sequence, 0, 32, (sz_sequence_comparator_t)_sz_sort_is_less, partial_order_length); -#endif + // Unreachable, but helps silence compiler warnings: + return SZ_NULL_CHAR; } -SZ_PUBLIC void sz_sort(sz_sequence_t *sequence) { -#if SZ_DETECT_BIG_ENDIAN - sz_sort_introsort(sequence, (sz_sequence_comparator_t)_sz_sort_is_less); -#else - sz_sort_partial(sequence, sequence->count); -#endif +SZ_INTERNAL sz_cptr_t _sz_find_over_4bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { + return _sz_find_with_prefix(h, h_length, n, n_length, (sz_find_t)_sz_find_4byte_serial, 4); } -#pragma endregion - -/* - * @brief AVX2 implementation of the string search algorithms. - * Very minimalistic, but still faster than the serial implementation. - */ -#pragma region AVX2 Implementation - -#if SZ_USE_X86_AVX2 -#pragma GCC push_options -#pragma GCC target("avx2") -#pragma clang attribute push(__attribute__((target("avx2"))), apply_to = function) -#include +SZ_INTERNAL sz_cptr_t _sz_find_horspool_over_256bytes_serial( // + sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { + return _sz_find_with_prefix(h, h_length, n, n_length, _sz_find_horspool_upto_256bytes_serial, 256); +} -/** - * @brief Helper structure to simplify work with 256-bit registers. - */ -typedef union sz_u256_vec_t { - __m256i ymm; - __m128i xmms[2]; - sz_u64_t u64s[4]; - sz_u32_t u32s[8]; - sz_u16_t u16s[16]; - sz_u8_t u8s[32]; -} sz_u256_vec_t; - -SZ_PUBLIC sz_ordering_t sz_order_avx2(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { - //! Before optimizing this, read the "Operations Not Worth Optimizing" in Contributions Guide: - //! https://github.com/ashvardanian/StringZilla/blob/main/CONTRIBUTING.md#general-performance-observations - return sz_order_serial(a, a_length, b, b_length); +SZ_INTERNAL sz_cptr_t _sz_rfind_horspool_over_256bytes_serial( // + sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { + return _sz_rfind_with_suffix(h, h_length, n, n_length, _sz_rfind_horspool_upto_256bytes_serial, 256); } -SZ_PUBLIC sz_bool_t sz_equal_avx2(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - sz_u256_vec_t a_vec, b_vec; +SZ_PUBLIC sz_cptr_t sz_find_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { + // This almost never fires, but it's better to be safe than sorry. + if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - while (length >= 32) { - a_vec.ymm = _mm256_lddqu_si256((__m256i const *)a); - b_vec.ymm = _mm256_lddqu_si256((__m256i const *)b); - // One approach can be to use "movemasks", but we could also use a bitwise matching like `_mm256_testnzc_si256`. - int difference_mask = ~_mm256_movemask_epi8(_mm256_cmpeq_epi8(a_vec.ymm, b_vec.ymm)); - if (difference_mask == 0) { a += 32, b += 32, length -= 32; } - else { return sz_false_k; } - } +#if _SZ_IS_BIG_ENDIAN + sz_find_t backends[] = { + (sz_find_t)sz_find_byte_serial, + (sz_find_t)_sz_find_horspool_upto_256bytes_serial, + (sz_find_t)_sz_find_horspool_over_256bytes_serial, + }; - if (length) return sz_equal_serial(a, b, length); - return sz_true_k; -} + return backends[(n_length > 1) + (n_length > 256)](h, h_length, n, n_length); +#else + sz_find_t backends[] = { + // For very short strings brute-force SWAR makes sense. 
+ (sz_find_t)sz_find_byte_serial, + (sz_find_t)_sz_find_2byte_serial, + (sz_find_t)_sz_find_3byte_serial, + (sz_find_t)_sz_find_4byte_serial, + // To avoid constructing the skip-table, let's use the prefixed approach. + (sz_find_t)_sz_find_over_4bytes_serial, + // For longer needles - use skip tables. + (sz_find_t)_sz_find_horspool_upto_256bytes_serial, + (sz_find_t)_sz_find_horspool_over_256bytes_serial, + }; -SZ_PUBLIC void sz_fill_avx2(sz_ptr_t target, sz_size_t length, sz_u8_t value) { - char value_char = *(char *)&value; - __m256i value_vec = _mm256_set1_epi8(value_char); - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "stores". - // - // for (; length >= 32; target += 32, length -= 32) _mm256_storeu_si256(target, value_vec); - // sz_fill_serial(target, length, value); - // - // When the buffer is small, there isn't much to innovate. - if (length <= 32) sz_fill_serial(target, length, value); - // When the buffer is aligned, we can avoid any split-stores. - else { - sz_size_t head_length = (32 - ((sz_size_t)target % 32)) % 32; // 31 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 32; // 31 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 32. - sz_u16_t value16 = (sz_u16_t)value * 0x0101u; - sz_u32_t value32 = (sz_u32_t)value16 * 0x00010001u; - sz_u64_t value64 = (sz_u64_t)value32 * 0x0000000100000001ull; - - // Fill the head of the buffer. This part is much cleaner with AVX-512. - if (head_length & 1) *(sz_u8_t *)target = value, target++, head_length--; - if (head_length & 2) *(sz_u16_t *)target = value16, target += 2, head_length -= 2; - if (head_length & 4) *(sz_u32_t *)target = value32, target += 4, head_length -= 4; - if (head_length & 8) *(sz_u64_t *)target = value64, target += 8, head_length -= 8; - if (head_length & 16) - _mm_store_si128((__m128i *)target, _mm_set1_epi8(value_char)), target += 16, head_length -= 16; - sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); - - // Fill the aligned body of the buffer. - for (; body_length >= 32; target += 32, body_length -= 32) _mm256_store_si256((__m256i *)target, value_vec); - - // Fill the tail of the buffer. This part is much cleaner with AVX-512. - sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); - if (tail_length & 16) - _mm_store_si128((__m128i *)target, _mm_set1_epi8(value_char)), target += 16, tail_length -= 16; - if (tail_length & 8) *(sz_u64_t *)target = value64, target += 8, tail_length -= 8; - if (tail_length & 4) *(sz_u32_t *)target = value32, target += 4, tail_length -= 4; - if (tail_length & 2) *(sz_u16_t *)target = value16, target += 2, tail_length -= 2; - if (tail_length & 1) *(sz_u8_t *)target = value, target++, tail_length--; - } + return backends[ + // For very short strings brute-force SWAR makes sense. + (n_length > 1) + (n_length > 2) + (n_length > 3) + + // To avoid constructing the skip-table, let's use the prefixed approach. + (n_length > 4) + + // For longer needles - use skip tables. + (n_length > 8) + (n_length > 256)](h, h_length, n, n_length); +#endif } -SZ_PUBLIC void sz_copy_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "stores" and "loads". 
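The `backends[...]` expression above turns the needle-length checks into an array index: each comparison contributes 0 or 1, so their sum selects a function pointer without branches. A standalone sketch of the same dispatch pattern (not part of the patch; the backend names are made up):

#include <stddef.h>
#include <stdio.h>

typedef char const *(*find_t)(char const *, size_t, char const *, size_t);

static char const *use_byte_search(char const *h, size_t hl, char const *n, size_t nl) {
    (void)h, (void)hl, (void)n, (void)nl;
    return "single-byte backend";
}
static char const *use_short_needle_search(char const *h, size_t hl, char const *n, size_t nl) {
    (void)h, (void)hl, (void)n, (void)nl;
    return "short-needle backend";
}
static char const *use_long_needle_search(char const *h, size_t hl, char const *n, size_t nl) {
    (void)h, (void)hl, (void)n, (void)nl;
    return "long-needle backend";
}

int main(void) {
    /* Summing boolean comparisons yields 0, 1, or 2 - a branch-free index into the table. */
    find_t backends[] = {use_byte_search, use_short_needle_search, use_long_needle_search};
    size_t needle_lengths[] = {1, 42, 1000};
    for (size_t i = 0; i != 3; ++i) {
        size_t n_length = needle_lengths[i];
        printf("%zu -> %s\n", n_length, backends[(n_length > 1) + (n_length > 256)]("haystack", 8, "x", n_length));
    }
    return 0;
}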
- // - // for (; length >= 32; target += 32, source += 32, length -= 32) - // _mm256_storeu_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source)); - // sz_copy_serial(target, source, length); - // - // A typical AWS Skylake instance can have 32 KB x 2 blocks of L1 data cache per core, - // 1 MB x 2 blocks of L2 cache per core, and one shared L3 cache buffer. - // For now, let's avoid the cases beyond the L2 size. - int is_huge = length > 1ull * 1024ull * 1024ull; - if (length <= 32) { sz_copy_serial(target, source, length); } - // When dealing wirh larger arrays, the optimization is not as simple as with the `sz_fill_avx2` function, - // as both buffers may be unaligned. If we are lucky and the requested operation is some huge page transfer, - // we can use aligned loads and stores, and the performance will be great. - else if ((sz_size_t)target % 32 == 0 && (sz_size_t)source % 32 == 0 && !is_huge) { - for (; length >= 32; target += 32, source += 32, length -= 32) - _mm256_store_si256((__m256i *)target, _mm256_load_si256((__m256i const *)source)); - if (length) sz_copy_serial(target, source, length); - } - // The trickiest case is when both `source` and `target` are not aligned. - // In such and simpler cases we can copy enough bytes into `target` to reach its cacheline boundary, - // and then combine unaligned loads with aligned stores. - else { - sz_size_t head_length = (32 - ((sz_size_t)target % 32)) % 32; // 31 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 32; // 31 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 32. - - // Fill the head of the buffer. This part is much cleaner with AVX-512. - if (head_length & 1) *(sz_u8_t *)target = *(sz_u8_t *)source, target++, source++, head_length--; - if (head_length & 2) *(sz_u16_t *)target = *(sz_u16_t *)source, target += 2, source += 2, head_length -= 2; - if (head_length & 4) *(sz_u32_t *)target = *(sz_u32_t *)source, target += 4, source += 4, head_length -= 4; - if (head_length & 8) *(sz_u64_t *)target = *(sz_u64_t *)source, target += 8, source += 8, head_length -= 8; - if (head_length & 16) - _mm_store_si128((__m128i *)target, _mm_lddqu_si128((__m128i const *)source)), target += 16, source += 16, - head_length -= 16; - sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); - - // Fill the aligned body of the buffer. - if (!is_huge) { - for (; body_length >= 32; target += 32, source += 32, body_length -= 32) - _mm256_store_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source)); - } - // When the biffer is huge, we can traverse it in 2 directions. - else { - for (; body_length >= 64; target += 32, source += 32, body_length -= 64) { - _mm256_store_si256((__m256i *)(target), _mm256_lddqu_si256((__m256i const *)(source))); - _mm256_store_si256((__m256i *)(target + body_length - 32), - _mm256_lddqu_si256((__m256i const *)(source + body_length - 32))); - } - if (body_length) _mm256_store_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source)); - } +SZ_PUBLIC sz_cptr_t sz_rfind_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - // Fill the tail of the buffer. This part is much cleaner with AVX-512. 
- sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); - if (tail_length & 16) - _mm_store_si128((__m128i *)target, _mm_lddqu_si128((__m128i const *)source)), target += 16, source += 16, - tail_length -= 16; - if (tail_length & 8) *(sz_u64_t *)target = *(sz_u64_t *)source, target += 8, source += 8, tail_length -= 8; - if (tail_length & 4) *(sz_u32_t *)target = *(sz_u32_t *)source, target += 4, source += 4, tail_length -= 4; - if (tail_length & 2) *(sz_u16_t *)target = *(sz_u16_t *)source, target += 2, source += 2, tail_length -= 2; - if (tail_length & 1) *(sz_u8_t *)target = *(sz_u8_t *)source, target++, source++, tail_length--; - } -} + // This almost never fires, but it's better to be safe than sorry. + if (h_length < n_length || !n_length) return SZ_NULL_CHAR; + + sz_find_t backends[] = { + // For very short strings brute-force SWAR makes sense. + (sz_find_t)sz_rfind_byte_serial, + // TODO: implement reverse-order SWAR for 2/3/4 byte variants. + // TODO: (sz_find_t)_sz_rfind_2byte_serial, + // TODO: (sz_find_t)_sz_rfind_3byte_serial, + // TODO: (sz_find_t)_sz_rfind_4byte_serial, + // To avoid constructing the skip-table, let's use the prefixed approach. + // (sz_find_t)_sz_rfind_over_4bytes_serial, + // For longer needles - use skip tables. + (sz_find_t)_sz_rfind_horspool_upto_256bytes_serial, + (sz_find_t)_sz_rfind_horspool_over_256bytes_serial, + }; -SZ_PUBLIC void sz_move_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - if (target < source || target >= source + length) { - for (; length >= 32; target += 32, source += 32, length -= 32) - _mm256_storeu_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source)); - while (length--) *(target++) = *(source++); - } - else { - // Jump to the end and walk backwards. - for (target += length, source += length; length >= 32; length -= 32) - _mm256_storeu_si256((__m256i *)(target -= 32), _mm256_lddqu_si256((__m256i const *)(source -= 32))); - while (length--) *(--target) = *(--source); - } + return backends[ + // For very short strings brute-force SWAR makes sense. + 0 + + // To avoid constructing the skip-table, let's use the prefixed approach. + (n_length > 1) + + // For longer needles - use skip tables. + (n_length > 256)](h, h_length, n, n_length); } -SZ_PUBLIC sz_u64_t sz_checksum_avx2(sz_cptr_t text, sz_size_t length) { - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "loads". - // - // A typical AWS Skylake instance can have 32 KB x 2 blocks of L1 data cache per core, - // 1 MB x 2 blocks of L2 cache per core, and one shared L3 cache buffer. - // For now, let's avoid the cases beyond the L2 size. - int is_huge = length > 1ull * 1024ull * 1024ull; - - // When the buffer is small, there isn't much to innovate. - if (length <= 32) { return sz_checksum_serial(text, length); } - else if (!is_huge) { - sz_u256_vec_t text_vec, sums_vec; - sums_vec.ymm = _mm256_setzero_si256(); - for (; length >= 32; text += 32, length -= 32) { - text_vec.ymm = _mm256_lddqu_si256((__m256i const *)text); - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); - } - // Accumulating 256 bits is harders, as we need to extract the 128-bit sums first. 
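A note on the accumulation in the checksum code being removed here: `_mm256_sad_epu8` against a zero vector sums each group of eight bytes into a 64-bit lane, so a byte-level checksum only needs a scalar fold of four partial sums at the end. A standalone sketch of that building block (not part of the patch; compile with -mavx2):

#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>

int main(void) {
    unsigned char bytes[32];
    uint64_t expected = 0;
    for (int i = 0; i != 32; ++i) bytes[i] = (unsigned char)i, expected += (uint64_t)i;

    __m256i text = _mm256_loadu_si256((__m256i const *)bytes);
    /* Sum-of-absolute-differences against zero adds up each 8-byte group. */
    __m256i sums = _mm256_sad_epu8(text, _mm256_setzero_si256());

    uint64_t partial[4];
    _mm256_storeu_si256((__m256i *)partial, sums);
    uint64_t total = partial[0] + partial[1] + partial[2] + partial[3];
    printf("%llu == %llu\n", (unsigned long long)total, (unsigned long long)expected); /* 496 == 496 */
    return 0;
}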
- __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymm); - __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymm, 1); - __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm); - sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm); - sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1); - sz_u64_t result = low + high; - if (length) result += sz_checksum_serial(text, length); - return result; - } - // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use. - // Most notably, we can avoid populating the cache with the entire buffer, and instead traverse it in 2 directions. - else { - sz_size_t head_length = (32 - ((sz_size_t)text % 32)) % 32; // 31 or less. - sz_size_t tail_length = (sz_size_t)(text + length) % 32; // 31 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 32. - sz_u64_t result = 0; - - // Handle the head - while (head_length--) result += *text++; - - sz_u256_vec_t text_vec, sums_vec; - sums_vec.ymm = _mm256_setzero_si256(); - // Fill the aligned body of the buffer. - if (!is_huge) { - for (; body_length >= 32; text += 32, body_length -= 32) { - text_vec.ymm = _mm256_stream_load_si256((__m256i const *)text); - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); - } - } - // When the biffer is huge, we can traverse it in 2 directions. - else { - sz_u256_vec_t text_reversed_vec, sums_reversed_vec; - sums_reversed_vec.ymm = _mm256_setzero_si256(); - for (; body_length >= 64; text += 64, body_length -= 64) { - text_vec.ymm = _mm256_stream_load_si256((__m256i *)(text)); - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); - text_reversed_vec.ymm = _mm256_stream_load_si256((__m256i *)(text + body_length - 64)); - sums_reversed_vec.ymm = _mm256_add_epi64( - sums_reversed_vec.ymm, _mm256_sad_epu8(text_reversed_vec.ymm, _mm256_setzero_si256())); - } - if (body_length >= 32) { - text_vec.ymm = _mm256_stream_load_si256((__m256i *)(text)); - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); - } - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, sums_reversed_vec.ymm); - } - - // Handle the tail - while (tail_length--) result += *text++; - - // Accumulating 256 bits is harders, as we need to extract the 128-bit sums first. - __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymm); - __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymm, 1); - __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm); - sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm); - sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1); - result += low + high; - return result; - } -} +#pragma endregion // Serial Implementation -SZ_PUBLIC void sz_look_up_transform_avx2(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { +/* AVX2 implementation of the string search algorithms for Haswell processors and newer. + * Very minimalistic (compared to AVX-512), but still faster than the serial implementation. + */ +#pragma region Haswell Implementation +#if SZ_USE_HASWELL +#pragma GCC push_options +#pragma GCC target("haswell") +#pragma clang attribute push(__attribute__((target("haswell"))), apply_to = function) - // If the input is tiny (especially smaller than the look-up table itself), we may end up paying - // more for organizing the SIMD registers and changing the CPU state, than for the actual computation. 
- // But if at least 3 cache lines are touched, the AVX-2 implementation should be faster. - if (length <= 128) { - sz_look_up_transform_serial(source, length, lut, target); - return; - } +SZ_PUBLIC sz_bool_t sz_equal_haswell(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { + sz_u256_vec_t a_vec, b_vec; - // We need to pull the lookup table into 8x YMM registers. - // The biggest issue is reorganizing the data in the lookup table, as AVX2 doesn't have 256-bit shuffle, - // it only has 128-bit "within-lane" shuffle. Still, it's wiser to use full YMM registers, instead of XMM, - // so that we can at least compensate high latency with twice larger window and one more level of lookup. - sz_u256_vec_t lut_0_to_15_vec, lut_16_to_31_vec, lut_32_to_47_vec, lut_48_to_63_vec, // - lut_64_to_79_vec, lut_80_to_95_vec, lut_96_to_111_vec, lut_112_to_127_vec, // - lut_128_to_143_vec, lut_144_to_159_vec, lut_160_to_175_vec, lut_176_to_191_vec, // - lut_192_to_207_vec, lut_208_to_223_vec, lut_224_to_239_vec, lut_240_to_255_vec; - - lut_0_to_15_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut))); - lut_16_to_31_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 16))); - lut_32_to_47_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 32))); - lut_48_to_63_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 48))); - lut_64_to_79_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 64))); - lut_80_to_95_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 80))); - lut_96_to_111_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 96))); - lut_112_to_127_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 112))); - lut_128_to_143_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 128))); - lut_144_to_159_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 144))); - lut_160_to_175_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 160))); - lut_176_to_191_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 176))); - lut_192_to_207_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 192))); - lut_208_to_223_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 208))); - lut_224_to_239_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 224))); - lut_240_to_255_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 240))); - - // Assuming each lookup is performed within 16 elements of 256, we need to reduce the scope by 16x = 2^4. - sz_u256_vec_t not_first_bit_vec, not_second_bit_vec, not_third_bit_vec, not_fourth_bit_vec; - - /// Top and bottom nibbles of the source are used separately. - sz_u256_vec_t source_vec, source_bot_vec; - sz_u256_vec_t blended_0_to_31_vec, blended_32_to_63_vec, blended_64_to_95_vec, blended_96_to_127_vec, - blended_128_to_159_vec, blended_160_to_191_vec, blended_192_to_223_vec, blended_224_to_255_vec; - - // Handling the head. while (length >= 32) { - // Load and separate the nibbles of each byte in the source. - source_vec.ymm = _mm256_lddqu_si256((__m256i const *)source); - source_bot_vec.ymm = _mm256_and_si256(source_vec.ymm, _mm256_set1_epi8((char)0x0F)); - - // In the first round, we select using the 4th bit. 
- not_fourth_bit_vec.ymm = _mm256_cmpeq_epi8( // - _mm256_and_si256(_mm256_set1_epi8((char)0x10), source_vec.ymm), _mm256_setzero_si256()); - blended_0_to_31_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_16_to_31_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_0_to_15_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_32_to_63_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_48_to_63_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_32_to_47_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_64_to_95_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_80_to_95_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_64_to_79_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_96_to_127_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_112_to_127_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_96_to_111_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_128_to_159_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_144_to_159_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_128_to_143_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_160_to_191_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_176_to_191_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_160_to_175_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_192_to_223_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_208_to_223_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_192_to_207_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_224_to_255_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_240_to_255_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_224_to_239_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - - // Perform a tree-like reduction of the 8x "blended" YMM registers, depending on the "source" content. - // The first round selects using the 3rd bit. - not_third_bit_vec.ymm = _mm256_cmpeq_epi8( // - _mm256_and_si256(_mm256_set1_epi8((char)0x20), source_vec.ymm), _mm256_setzero_si256()); - blended_0_to_31_vec.ymm = _mm256_blendv_epi8( // - blended_32_to_63_vec.ymm, // - blended_0_to_31_vec.ymm, // - not_third_bit_vec.ymm); - blended_64_to_95_vec.ymm = _mm256_blendv_epi8( // - blended_96_to_127_vec.ymm, // - blended_64_to_95_vec.ymm, // - not_third_bit_vec.ymm); - blended_128_to_159_vec.ymm = _mm256_blendv_epi8( // - blended_160_to_191_vec.ymm, // - blended_128_to_159_vec.ymm, // - not_third_bit_vec.ymm); - blended_192_to_223_vec.ymm = _mm256_blendv_epi8( // - blended_224_to_255_vec.ymm, // - blended_192_to_223_vec.ymm, // - not_third_bit_vec.ymm); - - // The second round selects using the 2nd bit. - not_second_bit_vec.ymm = _mm256_cmpeq_epi8( // - _mm256_and_si256(_mm256_set1_epi8((char)0x40), source_vec.ymm), _mm256_setzero_si256()); - blended_0_to_31_vec.ymm = _mm256_blendv_epi8( // - blended_64_to_95_vec.ymm, // - blended_0_to_31_vec.ymm, // - not_second_bit_vec.ymm); - blended_128_to_159_vec.ymm = _mm256_blendv_epi8( // - blended_192_to_223_vec.ymm, // - blended_128_to_159_vec.ymm, // - not_second_bit_vec.ymm); - - // The third round selects using the 1st bit. 
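The cascade of shuffles and blends in the look-up-transform code above exists because `_mm256_shuffle_epi8` can only index a 16-entry table within each 128-bit lane. The standalone sketch below (not part of the patch; compile with -mavx2) shows that 16-entry building block on its own, mapping nibbles to hex digits:

#include <immintrin.h>
#include <stdio.h>

int main(void) {
    /* A 16-entry table replicated into both lanes, exactly like the `lut_*` registers above. */
    __m256i lut = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i const *)"0123456789ABCDEF"));
    unsigned char nibbles[32], hex[33] = {0};
    for (int i = 0; i != 32; ++i) nibbles[i] = (unsigned char)(i % 16);
    __m256i indices = _mm256_loadu_si256((__m256i const *)nibbles);
    /* Each byte of `indices` selects one of the 16 table entries within its own lane. */
    _mm256_storeu_si256((__m256i *)hex, _mm256_shuffle_epi8(lut, indices));
    printf("%s\n", (char const *)hex); /* prints "0123456789ABCDEF0123456789ABCDEF" */
    return 0;
}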
- not_first_bit_vec.ymm = _mm256_cmpeq_epi8( // - _mm256_and_si256(_mm256_set1_epi8((char)0x80), source_vec.ymm), _mm256_setzero_si256()); - blended_0_to_31_vec.ymm = _mm256_blendv_epi8( // - blended_128_to_159_vec.ymm, // - blended_0_to_31_vec.ymm, // - not_first_bit_vec.ymm); - - // And dump the result into the target. - _mm256_storeu_si256((__m256i *)target, blended_0_to_31_vec.ymm); - source += 32, target += 32, length -= 32; + a_vec.ymm = _mm256_lddqu_si256((__m256i const *)a); + b_vec.ymm = _mm256_lddqu_si256((__m256i const *)b); + // One approach can be to use "movemasks", but we could also use a bitwise matching like `_mm256_testnzc_si256`. + int difference_mask = ~_mm256_movemask_epi8(_mm256_cmpeq_epi8(a_vec.ymm, b_vec.ymm)); + if (difference_mask == 0) { a += 32, b += 32, length -= 32; } + else { return sz_false_k; } } - // Handle the tail. - if (length) sz_look_up_transform_serial(source, length, lut, target); + if (length) return sz_equal_serial(a, b, length); + return sz_true_k; } -SZ_PUBLIC sz_cptr_t sz_find_byte_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { +SZ_PUBLIC sz_cptr_t sz_find_byte_haswell(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { int mask; sz_u256_vec_t h_vec, n_vec; n_vec.ymm = _mm256_set1_epi8(n[0]); @@ -4233,7 +852,7 @@ SZ_PUBLIC sz_cptr_t sz_find_byte_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t return sz_find_byte_serial(h, h_length, n); } -SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { +SZ_PUBLIC sz_cptr_t sz_rfind_byte_haswell(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { int mask; sz_u256_vec_t h_vec, n_vec; n_vec.ymm = _mm256_set1_epi8(n[0]); @@ -4248,11 +867,11 @@ SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_ return sz_rfind_byte_serial(h, h_length, n); } -SZ_PUBLIC sz_cptr_t sz_find_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { +SZ_PUBLIC sz_cptr_t sz_find_haswell(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { // This almost never fires, but it's better to be safe than sorry. if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_find_byte_avx2(h, h_length, n); + if (n_length == 1) return sz_find_byte_haswell(h, h_length, n); // Pick the parts of the needle that are worth comparing. 
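The new `sz_equal_haswell` above leans on the `_mm256_cmpeq_epi8` + `_mm256_movemask_epi8` pair: one comparison yields a 32-bit mask with a bit per byte, and the same pair filters candidate offsets in the find functions below. A standalone sketch (not part of the patch; compile with -mavx2, and `__builtin_ctz` assumes GCC or Clang):

#include <immintrin.h>
#include <stdio.h>
#include <string.h>

int main(void) {
    char a[32] = "0123456789abcdefghijklmnopqrstu";
    char b[32];
    memcpy(b, a, 32);
    b[7] = '?'; /* introduce a single mismatch */

    __m256i a_vec = _mm256_loadu_si256((__m256i const *)a);
    __m256i b_vec = _mm256_loadu_si256((__m256i const *)b);
    /* Bit i of the mask is set when a[i] == b[i]; an all-ones mask means the blocks are equal. */
    unsigned mask = (unsigned)_mm256_movemask_epi8(_mm256_cmpeq_epi8(a_vec, b_vec));
    if (mask == 0xFFFFFFFFu) printf("equal\n");
    else printf("differ, first mismatch at byte %u\n", (unsigned)__builtin_ctz(~mask)); /* byte 7 */
    return 0;
}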
sz_size_t offset_first, offset_mid, offset_last; @@ -4270,9 +889,10 @@ SZ_PUBLIC sz_cptr_t sz_find_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, s h_first_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_first)); h_mid_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_mid)); h_last_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_last)); - matches = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_first_vec.ymm, n_first_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_mid_vec.ymm, n_mid_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_last_vec.ymm, n_last_vec.ymm)); + matches = // + _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_first_vec.ymm, n_first_vec.ymm)) & + _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_mid_vec.ymm, n_mid_vec.ymm)) & + _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_last_vec.ymm, n_last_vec.ymm)); while (matches) { int potential_offset = sz_u32_ctz(matches); if (sz_equal(h + potential_offset, n, n_length)) return h + potential_offset; @@ -4283,11 +903,11 @@ SZ_PUBLIC sz_cptr_t sz_find_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, s return sz_find_serial(h, h_length, n, n_length); } -SZ_PUBLIC sz_cptr_t sz_rfind_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { +SZ_PUBLIC sz_cptr_t sz_rfind_haswell(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { // This almost never fires, but it's better to be safe than sorry. if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_rfind_byte_avx2(h, h_length, n); + if (n_length == 1) return sz_rfind_byte_haswell(h, h_length, n); // Pick the parts of the needle that are worth comparing. sz_size_t offset_first, offset_mid, offset_last; @@ -4307,9 +927,10 @@ SZ_PUBLIC sz_cptr_t sz_rfind_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, h_first_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_first)); h_mid_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_mid)); h_last_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_last)); - matches = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_first_vec.ymm, n_first_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_mid_vec.ymm, n_mid_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_last_vec.ymm, n_last_vec.ymm)); + matches = // + _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_first_vec.ymm, n_first_vec.ymm)) & + _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_mid_vec.ymm, n_mid_vec.ymm)) & + _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_last_vec.ymm, n_last_vec.ymm)); while (matches) { int potential_offset = sz_u32_clz(matches); if (sz_equal(h + h_length - n_length - potential_offset, n, n_length)) @@ -4321,7 +942,7 @@ SZ_PUBLIC sz_cptr_t sz_rfind_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, return sz_rfind_serial(h, h_length, n, n_length); } -SZ_PUBLIC sz_cptr_t sz_find_charset_avx2(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { +SZ_PUBLIC sz_cptr_t sz_find_charset_haswell(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { // Let's unzip even and odd elements and replicate them into both lanes of the YMM register. // That way when we invoke `_mm256_shuffle_epi8` we can use the same mask for both lanes. 
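The serial fallback for character-set search is easiest to reason about as a plain 256-bit bitset indexed by byte value, which is what the shuffle-based kernel above vectorizes. A standalone sketch (not part of the patch; `byteset_t` is a made-up type, not the library's `sz_charset_t`):

#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

/* Bit `c` of the set lives in word `c / 64` at position `c % 64`. */
typedef struct { uint64_t words[4]; } byteset_t;

static void byteset_add(byteset_t *set, unsigned char c) { set->words[c >> 6] |= 1ull << (c & 63u); }
static int byteset_has(byteset_t const *set, unsigned char c) { return (int)((set->words[c >> 6] >> (c & 63u)) & 1u); }

static char const *find_byteset(char const *text, size_t length, byteset_t const *set) {
    for (size_t i = 0; i != length; ++i)
        if (byteset_has(set, (unsigned char)text[i])) return text + i;
    return NULL;
}

int main(void) {
    byteset_t vowels = {{0, 0, 0, 0}};
    for (char const *c = "aeiou"; *c; ++c) byteset_add(&vowels, (unsigned char)*c);
    char const *match = find_byteset("rhythms and drums", 17, &vowels);
    printf("%s\n", match ? match : "(none)"); /* prints "and drums" */
    return 0;
}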
@@ -4336,11 +957,12 @@ SZ_PUBLIC sz_cptr_t sz_find_charset_avx2(sz_cptr_t text, sz_size_t length, sz_ch sz_u256_vec_t lower_nibbles_vec, higher_nibbles_vec; sz_u256_vec_t bitset_even_vec, bitset_odd_vec; sz_u256_vec_t bitmask_vec, bitmask_lookup_vec; - bitmask_lookup_vec.ymm = _mm256_set_epi8(-128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // - -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1); + bitmask_lookup_vec.ymm = _mm256_set_epi8( // + -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // + -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1); while (length >= 32) { - // The following algorithm is a transposed equivalent of the "SIMDized check which bytes are in a set" + // The following algorithm is a transposed equivalent of the "SIMD-ized check which bytes are in a set" // solutions by Wojciech Muła. We populate the bitmask differently and target newer CPUs, so // StrinZilla uses a somewhat different approach. // http://0x80.pl/articles/simd-byte-lookup.html#alternative-implementation-new @@ -4408,289 +1030,27 @@ SZ_PUBLIC sz_cptr_t sz_find_charset_avx2(sz_cptr_t text, sz_size_t length, sz_ch return sz_find_charset_serial(text, length, filter); } -SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx2(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { +SZ_PUBLIC sz_cptr_t sz_rfind_charset_haswell(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { return sz_rfind_charset_serial(text, length, filter); } -/** - * @brief There is no AVX2 instruction for fast multiplication of 64-bit integers. - * This implementation is coming from Agner Fog's Vector Class Library. - */ -SZ_INTERNAL __m256i _mm256_mul_epu64(__m256i a, __m256i b) { - __m256i bswap = _mm256_shuffle_epi32(b, 0xB1); - __m256i prodlh = _mm256_mullo_epi32(a, bswap); - __m256i zero = _mm256_setzero_si256(); - __m256i prodlh2 = _mm256_hadd_epi32(prodlh, zero); - __m256i prodlh3 = _mm256_shuffle_epi32(prodlh2, 0x73); - __m256i prodll = _mm256_mul_epu32(a, b); - __m256i prod = _mm256_add_epi64(prodll, prodlh3); - return prod; -} - -SZ_PUBLIC void sz_hashes_avx2(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle) { - - if (length < window_length || !window_length) return; - if (length < 4 * window_length) { - sz_hashes_serial(start, length, window_length, step, callback, callback_handle); - return; - } - - // Using AVX2, we can perform 4 long integer multiplications and additions within one register. - // So let's slice the entire string into 4 overlapping windows, to slide over them in parallel. - sz_size_t const max_hashes = length - window_length + 1; - sz_size_t const min_hashes_per_thread = max_hashes / 4; // At most one sequence can overlap between 2 threads. - sz_u8_t const *text_first = (sz_u8_t const *)start; - sz_u8_t const *text_second = text_first + min_hashes_per_thread; - sz_u8_t const *text_third = text_first + min_hashes_per_thread * 2; - sz_u8_t const *text_fourth = text_first + min_hashes_per_thread * 3; - sz_u8_t const *text_end = text_first + length; - - // Prepare the `prime ^ window_length` values, that we are going to use for modulo arithmetic. - sz_u64_t prime_power_low = 1, prime_power_high = 1; - for (sz_size_t i = 0; i + 1 < window_length; ++i) - prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME, - prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME; - - // Broadcast the constants into the registers. 
- sz_u256_vec_t prime_vec, golden_ratio_vec; - sz_u256_vec_t base_low_vec, base_high_vec, prime_power_low_vec, prime_power_high_vec, shift_high_vec; - base_low_vec.ymm = _mm256_set1_epi64x(31ull); - base_high_vec.ymm = _mm256_set1_epi64x(257ull); - shift_high_vec.ymm = _mm256_set1_epi64x(77ull); - prime_vec.ymm = _mm256_set1_epi64x(SZ_U64_MAX_PRIME); - golden_ratio_vec.ymm = _mm256_set1_epi64x(11400714819323198485ull); - prime_power_low_vec.ymm = _mm256_set1_epi64x(prime_power_low); - prime_power_high_vec.ymm = _mm256_set1_epi64x(prime_power_high); - - // Compute the initial hash values for every one of the four windows. - sz_u256_vec_t hash_low_vec, hash_high_vec, hash_mix_vec, chars_low_vec, chars_high_vec; - hash_low_vec.ymm = _mm256_setzero_si256(); - hash_high_vec.ymm = _mm256_setzero_si256(); - for (sz_u8_t const *prefix_end = text_first + window_length; text_first < prefix_end; - ++text_first, ++text_second, ++text_third, ++text_fourth) { - - // 1. Multiply the hashes by the base. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, base_low_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, base_high_vec.ymm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. - chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); - - // 3. Add the incoming characters. - hash_low_vec.ymm = _mm256_add_epi64(hash_low_vec.ymm, chars_low_vec.ymm); - hash_high_vec.ymm = _mm256_add_epi64(hash_high_vec.ymm, chars_high_vec.ymm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_low_vec.ymm = _mm256_blendv_epi8(hash_low_vec.ymm, _mm256_sub_epi64(hash_low_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_low_vec.ymm, prime_vec.ymm)); - hash_high_vec.ymm = _mm256_blendv_epi8(hash_high_vec.ymm, _mm256_sub_epi64(hash_high_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_high_vec.ymm, prime_vec.ymm)); - } - - // 5. Compute the hash mix, that will be used to index into the fingerprint. - // This includes a serial step at the end. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, golden_ratio_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, golden_ratio_vec.ymm); - hash_mix_vec.ymm = _mm256_xor_si256(hash_low_vec.ymm, hash_high_vec.ymm); - callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); - callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); - callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); - callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); - - // Now repeat that operation for the remaining characters, discarding older characters. - sz_size_t cycle = 1; - sz_size_t const step_mask = step - 1; - for (; text_fourth != text_end; ++text_first, ++text_second, ++text_third, ++text_fourth, ++cycle) { - // 0. Load again the four characters we are dropping, shift them, and subtract. 
- chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[-window_length], text_third[-window_length], - text_second[-window_length], text_first[-window_length]); - chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); - hash_low_vec.ymm = - _mm256_sub_epi64(hash_low_vec.ymm, _mm256_mul_epu64(chars_low_vec.ymm, prime_power_low_vec.ymm)); - hash_high_vec.ymm = - _mm256_sub_epi64(hash_high_vec.ymm, _mm256_mul_epu64(chars_high_vec.ymm, prime_power_high_vec.ymm)); - - // 1. Multiply the hashes by the base. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, base_low_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, base_high_vec.ymm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. - chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); - - // 3. Add the incoming characters. - hash_low_vec.ymm = _mm256_add_epi64(hash_low_vec.ymm, chars_low_vec.ymm); - hash_high_vec.ymm = _mm256_add_epi64(hash_high_vec.ymm, chars_high_vec.ymm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_low_vec.ymm = _mm256_blendv_epi8(hash_low_vec.ymm, _mm256_sub_epi64(hash_low_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_low_vec.ymm, prime_vec.ymm)); - hash_high_vec.ymm = _mm256_blendv_epi8(hash_high_vec.ymm, _mm256_sub_epi64(hash_high_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_high_vec.ymm, prime_vec.ymm)); - - // 5. Compute the hash mix, that will be used to index into the fingerprint. - // This includes a serial step at the end. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, golden_ratio_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, golden_ratio_vec.ymm); - hash_mix_vec.ymm = _mm256_xor_si256(hash_low_vec.ymm, hash_high_vec.ymm); - if ((cycle & step_mask) == 0) { - callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); - callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); - callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); - callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); - } - } -} - #pragma clang attribute pop #pragma GCC pop_options -#endif -#pragma endregion +#endif // SZ_USE_HASWELL +#pragma endregion // Haswell Implementation -/* - * @brief AVX-512 implementation of the string search algorithms. +/* AVX512 implementation of the string search algorithms for Skylake and newer CPUs. + * Includes extensions: F, CD, ER, PF, VL, DQ, BW. * - * Different subsets of AVX-512 were introduced in different years: - * - 2017 SkyLake: F, CD, ER, PF, VL, DQ, BW - * - 2018 CannonLake: IFMA, VBMI - * - 2019 IceLake: VPOPCNTDQ, VNNI, VBMI2, BITALG, GFNI, VPCLMULQDQ, VAES - * - 2020 TigerLake: VP2INTERSECT + * This is the "starting level" for the advanced algorithms using K-mask registers on x86. 
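+ *  The 64-bit K-masks also enable predicated loads, like `_mm512_maskz_loadu_epi8`, so unaligned
+ *  heads and short string tails are handled without scalar loops.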
*/ -#pragma region AVX512 Implementation - -#if SZ_USE_X86_AVX512 +#pragma region Skylake Implementation +#if SZ_USE_SKYLAKE #pragma GCC push_options #pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "bmi", "bmi2") #pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,bmi,bmi2"))), apply_to = function) -#include - -/** - * @brief Helper structure to simplify work with 512-bit registers. - */ -typedef union sz_u512_vec_t { - __m512i zmm; - __m256i ymms[2]; - __m128i xmms[4]; - sz_u64_t u64s[8]; - sz_u32_t u32s[16]; - sz_u16_t u16s[32]; - sz_u8_t u8s[64]; - sz_i64_t i64s[8]; - sz_i32_t i32s[16]; -} sz_u512_vec_t; - -SZ_INTERNAL __mmask64 _sz_u64_clamp_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 64: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 64: - return _bzhi_u64(0xFFFFFFFFFFFFFFFF, n < 64 ? (sz_u32_t)n : 64); -} - -SZ_INTERNAL __mmask32 _sz_u32_clamp_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 32: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 32: - return _bzhi_u32(0xFFFFFFFF, n < 32 ? (sz_u32_t)n : 32); -} - -SZ_INTERNAL __mmask16 _sz_u16_clamp_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 16: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 16: - return _bzhi_u32(0xFFFFFFFF, n < 16 ? (sz_u32_t)n : 16); -} - -SZ_INTERNAL __mmask16 _sz_u16_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 16: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 16: - return (__mmask16)_bzhi_u32(0xFFFFFFFF, (sz_u32_t)n); -} - -SZ_INTERNAL __mmask32 _sz_u32_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 32: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 32: - return _bzhi_u32(0xFFFFFFFF, (sz_u32_t)n); -} - -SZ_INTERNAL __mmask64 _sz_u64_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 64: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 64: - return _bzhi_u64(0xFFFFFFFFFFFFFFFF, (sz_u32_t)n); -} - -SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { - sz_u512_vec_t a_vec, b_vec; - - // Pointer arithmetic is cheap, fetching memory is not! - // So we can use the masked loads to fetch at most one cache-line for each string, - // compare the prefixes, and only then move forward. - sz_size_t a_head_length = 64 - ((sz_size_t)a % 64); // 63 or less. - sz_size_t b_head_length = 64 - ((sz_size_t)b % 64); // 63 or less. - a_head_length = a_head_length < a_length ? a_head_length : a_length; - b_head_length = b_head_length < b_length ? b_head_length : b_length; - sz_size_t head_length = a_head_length < b_head_length ? 
a_head_length : b_head_length; - __mmask64 head_mask = _sz_u64_mask_until(head_length); - a_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, a); - b_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, b); - __mmask64 mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); - if (mask_not_equal != 0) { - sz_u64_t first_diff = _tzcnt_u64(mask_not_equal); - char a_char = a_vec.u8s[first_diff]; - char b_char = b_vec.u8s[first_diff]; - return _sz_order_scalars(a_char, b_char); - } - else if (head_length == a_length && head_length == b_length) { return sz_equal_k; } - else { a += head_length, b += head_length, a_length -= head_length, b_length -= head_length; } - - // The rare case, when both string are very long. - __mmask64 a_mask, b_mask; - while ((a_length >= 64) & (b_length >= 64)) { - a_vec.zmm = _mm512_loadu_si512(a); - b_vec.zmm = _mm512_loadu_si512(b); - mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); - if (mask_not_equal != 0) { - sz_u64_t first_diff = _tzcnt_u64(mask_not_equal); - char a_char = a_vec.u8s[first_diff]; - char b_char = b_vec.u8s[first_diff]; - return _sz_order_scalars(a_char, b_char); - } - a += 64, b += 64, a_length -= 64, b_length -= 64; - } - // In most common scenarios at least one of the strings is under 64 bytes. - if (a_length | b_length) { - a_mask = _sz_u64_clamp_mask_until(a_length); - b_mask = _sz_u64_clamp_mask_until(b_length); - a_vec.zmm = _mm512_maskz_loadu_epi8(a_mask, a); - b_vec.zmm = _mm512_maskz_loadu_epi8(b_mask, b); - // The AVX-512 `_mm512_mask_cmpneq_epi8_mask` intrinsics are generally handy in such environments. - // They, however, have latency 3 on most modern CPUs. Using AVX2: `_mm256_cmpeq_epi8` would have - // been cheaper, if we didn't have to apply `_mm256_movemask_epi8` afterwards. - mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); - if (mask_not_equal != 0) { - sz_u64_t first_diff = _tzcnt_u64(mask_not_equal); - char a_char = a_vec.u8s[first_diff]; - char b_char = b_vec.u8s[first_diff]; - return _sz_order_scalars(a_char, b_char); - } - // From logic perspective, the hardest cases are "abc\0" and "abc". - // The result must be `sz_greater_k`, as the latter is shorter. - else { return _sz_order_scalars(a_length, b_length); } - } - - return sz_equal_k; -} - -SZ_PUBLIC sz_bool_t sz_equal_avx512(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { +SZ_PUBLIC sz_bool_t sz_equal_skylake(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { __mmask64 mask; sz_u512_vec_t a_vec, b_vec; @@ -4714,219 +1074,6 @@ SZ_PUBLIC sz_bool_t sz_equal_avx512(sz_cptr_t a, sz_cptr_t b, sz_size_t length) return sz_true_k; } -SZ_PUBLIC void sz_fill_avx512(sz_ptr_t target, sz_size_t length, sz_u8_t value) { - __m512i value_vec = _mm512_set1_epi8(value); - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "stores". - // - // for (; length >= 64; target += 64, length -= 64) _mm512_storeu_si512(target, value_vec); - // _mm512_mask_storeu_epi8(target, _sz_u64_mask_until(length), value_vec); - // - // When the buffer is small, there isn't much to innovate. - if (length <= 64) { - __mmask64 mask = _sz_u64_mask_until(length); - _mm512_mask_storeu_epi8(target, mask, value_vec); - } - // When the buffer is over 64 bytes, it's guaranteed to touch at least two cache lines - the head and tail, - // and may include more cache-lines in-between. 
Knowing this, we can avoid expensive unaligned stores - // by computing 2 masks - for the head and tail, using masked stores for the head and tail, and unmasked - // for the body. - else { - sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 64. - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - _mm512_mask_storeu_epi8(target, head_mask, value_vec); - for (target += head_length; body_length >= 64; target += 64, body_length -= 64) - _mm512_store_si512(target, value_vec); - _mm512_mask_storeu_epi8(target, tail_mask, value_vec); - } -} - -SZ_PUBLIC void sz_copy_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "stores" and "loads". - // - // for (; length >= 64; target += 64, source += 64, length -= 64) - // _mm512_storeu_si512(target, _mm512_loadu_si512(source)); - // __mmask64 mask = _sz_u64_mask_until(length); - // _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source)); - // - // A typical AWS Sapphire Rapids instance can have 48 KB x 2 blocks of L1 data cache per core, - // 2 MB x 2 blocks of L2 cache per core, and one shared 60 MB buffer of L3 cache. - // With two strings, we may consider the overal workload huge, if each exceeds 1 MB in length. - int const is_huge = length >= 1ull * 1024ull * 1024ull; - - // When the buffer is small, there isn't much to innovate. - if (length <= 64) { - __mmask64 mask = _sz_u64_mask_until(length); - _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source)); - } - // When dealing wirh larger arrays, the optimization is not as simple as with the `sz_fill_avx512` function, - // as both buffers may be unaligned. If we are lucky and the requested operation is some huge page transfer, - // we can use aligned loads and stores, and the performance will be great. - else if ((sz_size_t)target % 64 == 0 && (sz_size_t)source % 64 == 0 && !is_huge) { - for (; length >= 64; target += 64, source += 64, length -= 64) - _mm512_store_si512(target, _mm512_load_si512(source)); - // At this point the length is guaranteed to be under 64. - __mmask64 mask = _sz_u64_mask_until(length); - // Aligned load and stores would work too, but it's not defined. - _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source)); - } - // The trickiest case is when both `source` and `target` are not aligned. - // In such and simpler cases we can copy enough bytes into `target` to reach its cacheline boundary, - // and then combine unaligned loads with aligned stores. - else if (!is_huge) { - sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 64. - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source)); - for (target += head_length, source += head_length; body_length >= 64; - target += 64, source += 64, body_length -= 64) - _mm512_store_si512(target, _mm512_loadu_si512(source)); // Unaligned load, but aligned store! 
- _mm512_mask_storeu_epi8(target, tail_mask, _mm512_maskz_loadu_epi8(tail_mask, source)); - } - // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use. - // - // 1. Moving in both directions to maximize the throughput, when fetching from multiple - // memory pages. Also helps with cache set-associativity issues, as we won't always - // be fetching the same entries in the lookup table. - // 2. Using non-temporal stores to avoid polluting the cache. - // 3. Prefetching the next cache line, to avoid stalling the CPU. This generally useless - // for predictable patterns, so disregard this advice. - // - // Bidirectional traversal adds about 10%, accelerating from 11 GB/s to 12 GB/s. - // Using "streaming stores" boosts us from 12 GB/s to 19 GB/s. - else { - sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; - sz_size_t tail_length = (sz_size_t)(target + length) % 64; - sz_size_t body_length = length - head_length - tail_length; - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source)); - _mm512_mask_storeu_epi8(target + head_length + body_length, tail_mask, - _mm512_maskz_loadu_epi8(tail_mask, source)); - - // Now in the main loop, we can use non-temporal loads and stores, - // performing the operation in both directions. - for (target += head_length, source += head_length; // - body_length >= 128; // - target += 64, source += 64, body_length -= 128) { - _mm512_stream_si512((__m512i *)(target), _mm512_loadu_si512(source)); - _mm512_stream_si512((__m512i *)(target + body_length - 64), _mm512_loadu_si512(source + body_length - 64)); - } - if (body_length >= 64) _mm512_stream_si512((__m512i *)target, _mm512_loadu_si512(source)); - } -} - -SZ_PUBLIC void sz_move_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - if (target == source) return; // Don't be silly, don't move the data if it's already there. - - // On very short buffers, that are one cache line in width or less, we don't need any loops. - // We can also avoid any data-dependencies between iterations, assuming we have 32 registers - // to pre-load the data, before writing it back. 
- if (length <= 64) { - __mmask64 mask = _sz_u64_mask_until(length); - _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source)); - } - else if (length <= 128) { - sz_size_t last_length = length - 64; - __mmask64 mask = _sz_u64_mask_until(last_length); - __m512i source0 = _mm512_loadu_epi8(source); - __m512i source1 = _mm512_maskz_loadu_epi8(mask, source + 64); - _mm512_storeu_epi8(target, source0); - _mm512_mask_storeu_epi8(target + 64, mask, source1); - } - else if (length <= 192) { - sz_size_t last_length = length - 128; - __mmask64 mask = _sz_u64_mask_until(last_length); - __m512i source0 = _mm512_loadu_epi8(source); - __m512i source1 = _mm512_loadu_epi8(source + 64); - __m512i source2 = _mm512_maskz_loadu_epi8(mask, source + 128); - _mm512_storeu_epi8(target, source0); - _mm512_storeu_epi8(target + 64, source1); - _mm512_mask_storeu_epi8(target + 128, mask, source2); - } - else if (length <= 256) { - sz_size_t last_length = length - 192; - __mmask64 mask = _sz_u64_mask_until(last_length); - __m512i source0 = _mm512_loadu_epi8(source); - __m512i source1 = _mm512_loadu_epi8(source + 64); - __m512i source2 = _mm512_loadu_epi8(source + 128); - __m512i source3 = _mm512_maskz_loadu_epi8(mask, source + 192); - _mm512_storeu_epi8(target, source0); - _mm512_storeu_epi8(target + 64, source1); - _mm512_storeu_epi8(target + 128, source2); - _mm512_mask_storeu_epi8(target + 192, mask, source3); - } - - // If the regions don't overlap at all, just use "copy" and save some brain cells thinking about corner cases. - else if (target + length < source || target >= source + length) { sz_copy_avx512(target, source, length); } - - // When the buffer is over 64 bytes, it's guaranteed to touch at least two cache lines - the head and tail, - // and may include more cache-lines in-between. Knowing this, we can avoid expensive unaligned stores - // by computing 2 masks - for the head and tail, using masked stores for the head and tail, and unmasked - // for the body. - else { - sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 64. - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - - // The absolute most common case of using "moves" is shifting the data within a continuous buffer - // when adding a removing some values in it. In such cases, a typical shift is by 1, 2, 4, 8, 16, - // or 32 bytes, rarely larger. For small shifts, under the size of the ZMM register, we can use shuffles. - // - // Remember: - // - if we are shifting data left, that we are traversing to the right. - // - if we are shifting data right, that we are traversing to the left. - int const left_to_right_traversal = source > target; - - // Now we guarantee, that the relative shift within registers is from 1 to 63 bytes and the output is aligned. - // Hopefully, we need to shift more than two ZMM registers, so we could consider `valignr` instruction. - // Sadly, using `_mm512_alignr_epi8` doesn't make sense, as it operates at a 128-bit granularity. - // - // - `_mm256_alignr_epi8` shifts entire 256-bit register, but we need many of them. - // - `_mm512_alignr_epi32` shifts 512-bit chunks, but only if the `shift` is a multiple of 4 bytes. - // - `_mm512_alignr_epi64` shifts 512-bit chunks by 8 bytes. 
- // - // All of those have a latency of 1 cycle, and the shift amount must be an immediate value! - // For 1-byte-shift granularity, the `_mm512_permutex2var_epi8` has a latency of 6 and needs VBMI! - // The most efficient and broadly compatible alternative could be to use a combination of align and shuffle. - // A similar approach was outlined in "Byte-wise alignr in AVX512F" by Wojciech Muła. - // http://0x80.pl/notesen/2016-10-16-avx512-byte-alignr.html - // - // That solution, is extremely mouthful, assuming we need compile time constants for the shift amount. - // A cleaner one, with a latency of 3 cycles, is to use `_mm512_permutexvar_epi8` or - // `_mm512_mask_permutexvar_epi8`, which can be seen as combination of a cross-register shuffle and blend, - // and is available with VBMI. That solution is still noticeably slower than AVX2. - // - // The GLibC implementation also uses non-temporal stores for larger buffers, we don't. - // https://codebrowser.dev/glibc/glibc/sysdeps/x86_64/multiarch/memmove-avx512-no-vzeroupper.S.html - if (left_to_right_traversal) { - // Head, body, and tail. - _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source)); - for (target += head_length, source += head_length; body_length >= 64; - target += 64, source += 64, body_length -= 64) - _mm512_store_si512(target, _mm512_loadu_si512(source)); - _mm512_mask_storeu_epi8(target, tail_mask, _mm512_maskz_loadu_epi8(tail_mask, source)); - } - else { - // Tail, body, and head. - _mm512_mask_storeu_epi8(target + head_length + body_length, tail_mask, - _mm512_maskz_loadu_epi8(tail_mask, source + head_length + body_length)); - for (; body_length >= 64; body_length -= 64) - _mm512_store_si512(target + head_length + body_length - 64, - _mm512_loadu_si512(source + head_length + body_length - 64)); - _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source)); - } - } -} - SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { __mmask64 mask; sz_u512_vec_t h_vec, n_vec; @@ -4950,7 +1097,7 @@ SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr return SZ_NULL_CHAR; } -SZ_PUBLIC sz_cptr_t sz_find_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { +SZ_PUBLIC sz_cptr_t sz_find_skylake(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { // This almost never fires, but it's better to be safe than sorry. if (h_length < n_length || !n_length) return SZ_NULL_CHAR; @@ -4969,20 +1116,21 @@ SZ_PUBLIC sz_cptr_t sz_find_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, n_last_vec.zmm = _mm512_set1_epi8(n[offset_last]); // Scan through the string. 
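     // For every 64-byte block, three characters of the needle are probed at once: `_mm512_cmpeq_epi8_mask`
     // yields a 64-bit match mask per character, `_kand_mask64` intersects those masks, and only the
     // surviving offsets are verified further.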
- // We have several optimized versions of the lagorithm for shorter strings, + // We have several optimized versions of the algorithm for shorter strings, // but they all mimic the default case for unbounded length needles if (n_length >= 64) { for (; h_length >= n_length + 64; h += 64, h_length -= 64) { h_first_vec.zmm = _mm512_loadu_si512(h + offset_first); h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid); h_last_vec.zmm = _mm512_loadu_si512(h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); + matches = _kand_mask64( // + _kand_mask64( // Intersect the masks + _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), + _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), + _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); while (matches) { int potential_offset = sz_u64_ctz(matches); - if (sz_equal_avx512(h + potential_offset, n, n_length)) return h + potential_offset; + if (sz_equal_skylake(h + potential_offset, n, n_length)) return h + potential_offset; matches &= matches - 1; } @@ -4996,10 +1144,11 @@ SZ_PUBLIC sz_cptr_t sz_find_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, h_first_vec.zmm = _mm512_loadu_si512(h + offset_first); h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid); h_last_vec.zmm = _mm512_loadu_si512(h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); + matches = _kand_mask64( // + _kand_mask64( // Intersect the masks + _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), + _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), + _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); if (matches) return h + sz_u64_ctz(matches); } } @@ -5014,10 +1163,11 @@ SZ_PUBLIC sz_cptr_t sz_find_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, h_first_vec.zmm = _mm512_loadu_si512(h + offset_first); h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid); h_last_vec.zmm = _mm512_loadu_si512(h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); + matches = _kand_mask64( // + _kand_mask64( // Intersect the masks + _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), + _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), + _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); while (matches) { int potential_offset = sz_u64_ctz(matches); h_full_vec.zmm = _mm512_maskz_loadu_epi8(n_mask, h + potential_offset); @@ -5034,893 +1184,126 @@ SZ_PUBLIC sz_cptr_t sz_find_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, h_first_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_first); h_mid_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_mid); h_last_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); + matches = _kand_mask64( // + _kand_mask64( // Intersect the masks + _mm512_cmpeq_epi8_mask(h_first_vec.zmm, 
n_first_vec.zmm), + _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), + _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); while (matches) { int potential_offset = sz_u64_ctz(matches); - if (n_length <= 3 || sz_equal_avx512(h + potential_offset, n, n_length)) return h + potential_offset; + if (n_length <= 3 || sz_equal_skylake(h + potential_offset, n, n_length)) return h + potential_offset; matches &= matches - 1; } } - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - __mmask64 mask; - sz_u512_vec_t h_vec, n_vec; - n_vec.zmm = _mm512_set1_epi8(n[0]); - - while (h_length >= 64) { - h_vec.zmm = _mm512_loadu_si512(h + h_length - 64); - mask = _mm512_cmpeq_epi8_mask(h_vec.zmm, n_vec.zmm); - if (mask) return h + h_length - 1 - sz_u64_clz(mask); - h_length -= 64; - } - - if (h_length) { - mask = _sz_u64_mask_until(h_length); - h_vec.zmm = _mm512_maskz_loadu_epi8(mask, h); - // Reuse the same `mask` variable to find the bit that doesn't match - mask = _mm512_mask_cmpeq_epu8_mask(mask, h_vec.zmm, n_vec.zmm); - if (mask) return h + 64 - sz_u64_clz(mask) - 1; - } - - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_rfind_byte_avx512(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into ZMM registers. - __mmask64 mask; - __mmask64 matches; - sz_u512_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; - n_first_vec.zmm = _mm512_set1_epi8(n[offset_first]); - n_mid_vec.zmm = _mm512_set1_epi8(n[offset_mid]); - n_last_vec.zmm = _mm512_set1_epi8(n[offset_last]); - - // Scan through the string. - sz_cptr_t h_reversed; - for (; h_length >= n_length + 64; h_length -= 64) { - h_reversed = h + h_length - n_length - 64 + 1; - h_first_vec.zmm = _mm512_loadu_si512(h_reversed + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h_reversed + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h_reversed + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_clz(matches); - if (n_length <= 3 || sz_equal_avx512(h + h_length - n_length - potential_offset, n, n_length)) - return h + h_length - n_length - potential_offset; - sz_assert((matches & ((sz_u64_t)1 << (63 - potential_offset))) != 0 && - "The bit must be set before we squash it"); - matches &= ~((sz_u64_t)1 << (63 - potential_offset)); - } - } - - // The "tail" of the function uses masked loads to process the remaining bytes. 
- { - mask = _sz_u64_mask_until(h_length - n_length + 1); - h_first_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_first); - h_mid_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_mid); - h_last_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_clz(matches); - if (n_length <= 3 || sz_equal_avx512(h + 64 - potential_offset - 1, n, n_length)) - return h + 64 - potential_offset - 1; - sz_assert((matches & ((sz_u64_t)1 << (63 - potential_offset))) != 0 && - "The bit must be set before we squash it"); - matches &= ~((sz_u64_t)1 << (63 - potential_offset)); - } - } - - return SZ_NULL_CHAR; -} - -#pragma clang attribute pop -#pragma GCC pop_options - -#pragma GCC push_options -#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vbmi", "bmi", "bmi2") -#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512dq,avx512vbmi,bmi,bmi2"))), \ - apply_to = function) - -/** - * @brief Computes the edit distance between two very short byte-strings using the AVX-512VBMI extensions. - * - * Applies to string lengths up to 63, and evaluates at most (63 * 2 + 1 = 127) diagonals, or just as many loop cycles. - * Supports an early exit, if the distance is bounded. - * Keeps all of the data and Levenshtein matrices skew diagonal in just a couple of registers. - * Benefits from the @b `vpermb` instructions, that can rotate the bytes across the entire ZMM register. - */ -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto63_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound) { - - sz_size_t const max_length = 63u; - sz_assert(shorter_length <= longer_length && "The 'shorter' string is longer than the 'longer' one."); - sz_assert(shorter_length < max_length && "The length must fit into 16-bit integer. Otherwise use serial variant."); - - // We are going to store 3 diagonals of the matrix, assuming each would fit into a single ZMM register. - // The length of the longest (main) diagonal would be `shorter_dim = (shorter_length + 1)`. - sz_size_t const shorter_dim = shorter_length + 1; - sz_size_t const longer_dim = longer_length + 1; - - // The next few buffers will be swapped around. - sz_u512_vec_t previous_vec, current_vec, next_vec; - sz_u512_vec_t gaps_vec, substitutions_vec; - - // Load the strings into ZMM registers - just once. 
- sz_u512_vec_t longer_vec, shorter_vec, shorter_rotated_vec, rotate_left_vec, rotate_right_vec, ones_vec, bound_vec; - longer_vec.zmm = _mm512_maskz_loadu_epi8(_sz_u64_mask_until(longer_length), longer); - rotate_left_vec.zmm = _mm512_set_epi8( // - 0, 63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, // - 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, // - 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, // - 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1); - rotate_right_vec.zmm = _mm512_set_epi8( // - 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, // - 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, // - 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, // - 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 63); - ones_vec.zmm = _mm512_set1_epi8(1); - bound_vec.zmm = _mm512_set1_epi8(bound <= 255 ? (sz_u8_t)bound : 255); - - // To simplify comparisons and traversals, we want to reverse the order of bytes in the shorter string. - for (sz_size_t i = 0; i != shorter_length; ++i) shorter_vec.u8s[63 - i] = shorter[i]; - shorter_rotated_vec.zmm = _mm512_permutexvar_epi8(rotate_right_vec.zmm, shorter_vec.zmm); - - // Let's say we are dealing with 3 and 5 letter words. - // The matrix will have size 4 x 6, parameterized as (shorter_dim x longer_dim). - // It will have: - // - 4 diagonals of increasing length, at positions: 0, 1, 2, 3. - // - 2 diagonals of fixed length, at positions: 4, 5. - // - 3 diagonals of decreasing length, at positions: 6, 7, 8. - sz_size_t const diagonals_count = shorter_dim + longer_dim - 1; - - // Initialize the first two diagonals: - // - // previous_vec.u8s[0] = 0; - // current_vec.u8s[0] = current_vec.u8s[1] = 1; - // - // We can do a similar thing with vector ops: - previous_vec.zmm = _mm512_setzero_si512(); - current_vec.zmm = _mm512_set1_epi8(1); - - // We skip diagonals 0 and 1, as they are trivial. - // We will start with diagonal 2, which has length 3, with the first and last elements being preset, - // so we are effectively computing just one value, as will be marked by a single set bit in - // the `next_diagonal_mask` on the very first iteration. - sz_size_t next_diagonal_index = 2; - __mmask64 next_diagonal_mask = 0; - - // Progress through the upper triangle of the Levenshtein matrix. - for (; next_diagonal_index != shorter_dim; ++next_diagonal_index) { - // After this iteration, the values at offset `0` and `next_diagonal_index` in the `next_vec` - // should be set to `next_diagonal_index`, but it's easier to broadcast the value to the whole vector, - // and later merge with a mask with new values. - next_vec.zmm = _mm512_set1_epi8((sz_u8_t)next_diagonal_index); - - // The mask also adds one set bit. - next_diagonal_mask = _kor_mask64(next_diagonal_mask, 1); - next_diagonal_mask = _kshiftli_mask64(next_diagonal_mask, 1); - - // Check for equality between string slices. 
- __mmask64 conflict_mask = _mm512_cmpneq_epi8_mask(longer_vec.zmm, shorter_rotated_vec.zmm); - substitutions_vec.zmm = _mm512_mask_add_epi8(previous_vec.zmm, conflict_mask, previous_vec.zmm, ones_vec.zmm); - substitutions_vec.zmm = _mm512_permutexvar_epi8(rotate_right_vec.zmm, substitutions_vec.zmm); - gaps_vec.zmm = _mm512_add_epi8( - // Insertions or deletions - _mm512_min_epu8(_mm512_permutexvar_epi8(rotate_right_vec.zmm, current_vec.zmm), current_vec.zmm), - ones_vec.zmm); - next_vec.zmm = _mm512_mask_min_epu8(next_vec.zmm, next_diagonal_mask, gaps_vec.zmm, substitutions_vec.zmm); - - // Mark the current skewed diagonal as the previous one and the next one as the current one. - previous_vec.zmm = current_vec.zmm; - current_vec.zmm = next_vec.zmm; - - // Shift the shorter string - shorter_rotated_vec.zmm = _mm512_permutexvar_epi8(rotate_right_vec.zmm, shorter_rotated_vec.zmm); - - // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound. - __mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm); - if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) { // - return SZ_SIZE_MAX; - } - } - - // Now let's handle the anti-diagonal band of the matrix, between the top and bottom triangles. - for (; next_diagonal_index != longer_dim; ++next_diagonal_index) { - // After this iteration, the value `shorted_dim - 1` in the `next_vec` - // should be set to `next_diagonal_index`, but it's easier to broadcast the value to the whole vector, - // and later merge with a mask with new values. - next_vec.zmm = _mm512_set1_epi8((sz_u8_t)next_diagonal_index); - - // Make sure we update the first entry. - next_diagonal_mask = _kor_mask64(next_diagonal_mask, 1); - - // Check for equality between string slices. - __mmask64 conflict_mask = _mm512_cmpneq_epi8_mask(longer_vec.zmm, shorter_rotated_vec.zmm); - substitutions_vec.zmm = _mm512_mask_add_epi8(previous_vec.zmm, conflict_mask, previous_vec.zmm, ones_vec.zmm); - gaps_vec.zmm = _mm512_add_epi8( - // Insertions or deletions - _mm512_min_epu8(current_vec.zmm, _mm512_permutexvar_epi8(rotate_left_vec.zmm, current_vec.zmm)), - ones_vec.zmm); - next_vec.zmm = _mm512_mask_min_epu8(next_vec.zmm, next_diagonal_mask, gaps_vec.zmm, substitutions_vec.zmm); - - // Mark the current skewed diagonal as the previous one and the next one as the current one. - previous_vec.zmm = _mm512_permutexvar_epi8(rotate_left_vec.zmm, current_vec.zmm); - current_vec.zmm = next_vec.zmm; - - // Let's shift the longer string now. - longer_vec.zmm = _mm512_permutexvar_epi8(rotate_left_vec.zmm, longer_vec.zmm); - - // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound. - __mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm); - if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) { // - return SZ_SIZE_MAX; - } - } - - // Now let's handle the bottom right triangle. - for (; next_diagonal_index != diagonals_count; ++next_diagonal_index) { - - // Check for equality between string slices. 
- __mmask64 conflict_mask = _mm512_cmpneq_epi8_mask(longer_vec.zmm, shorter_rotated_vec.zmm); - substitutions_vec.zmm = _mm512_mask_add_epi8(previous_vec.zmm, conflict_mask, previous_vec.zmm, ones_vec.zmm); - gaps_vec.zmm = _mm512_add_epi8( - // Insertions or deletions - _mm512_min_epu8(current_vec.zmm, _mm512_permutexvar_epi8(rotate_left_vec.zmm, current_vec.zmm)), - ones_vec.zmm); - next_vec.zmm = _mm512_min_epu8(gaps_vec.zmm, substitutions_vec.zmm); - - // Mark the current skewed diagonal as the previous one and the next one as the current one. - previous_vec.zmm = _mm512_permutexvar_epi8(rotate_left_vec.zmm, current_vec.zmm); - current_vec.zmm = next_vec.zmm; - - // Let's shift the longer string now. - longer_vec.zmm = _mm512_permutexvar_epi8(rotate_left_vec.zmm, longer_vec.zmm); - - // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound. - __mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm); - if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) { // - return SZ_SIZE_MAX; - } - // In every following iterations we take use a shorter prefix of each register, - // but we don't need to update the `next_diagonal_mask` anymore... except for the early exit. - next_diagonal_mask = _kshiftri_mask64(next_diagonal_mask, 1); - } - return current_vec.u8s[0]; -} - -/** - * @brief Computes the edit distance between two somewhat short bytes-strings using the AVX-512VBMI extensions. - * - * Applies to string lengths up to 127, and evaluates at most (127 * 2 + 1 = 255) diagonals. - * Supports an early exit, if the distance is bounded. - * Uses a lot more CPU registers space, than the `upto63` variant. - * Benefits from the @b `vpermi2b` instructions, that can rotate the bytes in 2 registers at once. - * - * This may be one of the most freuqently called kernels for: - * - source code analysis, assuming most lines are either under 80 or under 120 characters long. - * - DNA sequence alignment, as most short reads are 50-300 characters long. - */ -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto127_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound) { - sz_unused(shorter && shorter_length && longer && longer_length && bound); - return 0; -} - -/** - * @brief Computes the edit distance between two longer bytes-strings using the AVX-512VBMI extensions. - * - * Applies to string lengths up to 255, and evaluates at most (255 * 2 + 1 = 511) diagonals. - * Supports an early exit, if the distance is bounded. - * Uses a lot more CPU registers space, than the `upto63` variant. - * - * Each of 2x string ends up occupying 4 ZMM registers, and each of 3x diagonals uses 4 ZMM registers. - * So 20x of the 32x are persistently occupied, and the rest are used for math temporarily. - * This is the largest space-efficient variant, as strings beyond 255 characters may require - * 16-bit accumulators, which would be a significant bottleneck. - */ -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound) { - sz_unused(shorter && shorter_length && longer && longer_length && bound); - return 0; -} - -/** - * @brief Computes the edit distance between two longer bytes-strings using the AVX-512VBMI extensions, - * assuming the upper distance bound can not exceed 255, but the string length can be arbitrary. 
- * - * Applies to string lengths up to 255, and evaluates at most (255 * 2 + 1 = 511) diagonals. - * Supports an early exit, if the distance is bounded. - * Uses a lot more CPU registers space, than the `upto63` variant. - * - * Each of 2x string ends up occupying 4 ZMM registers, and each of 3x diagonals uses 4 ZMM registers. - * So 20x of the 32x are persistently occupied, and the rest are used for math temporarily. - * This is the largest space-efficient variant, as strings beyond 255 characters may require - * 16-bit accumulators, which would be a significant bottleneck. - */ -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto255bound_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound) { - sz_unused(shorter && shorter_length && longer && longer_length && bound); - return 0; -} - -/** - * @brief Computes the edit distance between two mid-length UTF-8-strings using the AVX-512VBMI extensions. - * - * Applies to string lengths up to 127, and evaluates at most (127 * 2 + 1 = 511) diagonals. - * Supports an early exit, if the distance is bounded. - * Benefits from the @b `valignd` instructions used to rotate UTF-32 unpacked unicode codepoints. - * - * Each string is unpacked into 128 characters * 4 bytes per character / 64 bytes per register = 8 registers. - * - */ -SZ_INTERNAL sz_size_t _sz_edit_distance_utf8_skewed_diagonals_upto127_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound) { - sz_unused(shorter && shorter_length && longer && longer_length && bound); - return 0; -} - -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto65k_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - - sz_unused(shorter && longer && bound && alloc); - - // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. - sz_memory_allocator_t global_alloc; - if (!alloc) { - sz_memory_allocator_init_default(&global_alloc); - alloc = &global_alloc; - } - - // TODO: Generalize! - sz_size_t const max_length = 256u * 256u; - sz_assert(shorter_length <= longer_length && "The 'shorter' string is longer than the 'longer' one."); - sz_assert(shorter_length < max_length && "The length must fit into 16-bit integer. Otherwise use serial variant."); - sz_unused(longer_length && bound && max_length); - -#if 0 - // We are going to store 3 diagonals of the matrix. - // The length of the longest (main) diagonal would be `shorter_dim = (shorter_length + 1)`. - sz_size_t const shorter_dim = shorter_length + 1; - sz_size_t const longer_dim = longer_length + 1; - // Unlike the serial version, we also want to avoid reverse-order iteration over teh shorter string. - // So let's allocate a bit more memory and reverse-export our shorter string into that buffer. - sz_size_t const buffer_length = sizeof(sz_u16_t) * longer_dim * 3 + shorter_length; - sz_u16_t *const distances = (sz_u16_t *)alloc->allocate(buffer_length, alloc->handle); - if (!distances) return SZ_SIZE_MAX; - - // The next few pointers will be swapped around. - sz_u16_t *previous_distances = distances; - sz_u16_t *current_distances = previous_distances + longer_dim; - sz_u16_t *next_distances = current_distances + longer_dim; - sz_ptr_t const shorter_reversed = (sz_ptr_t)(next_distances + longer_dim); - - // Export the reversed string into the buffer. 
- for (sz_size_t i = 0; i != shorter_length; ++i) shorter_reversed[i] = shorter[shorter_length - 1 - i]; - - // Initialize the first two diagonals: - previous_distances[0] = 0; - current_distances[0] = current_distances[1] = 1; - - // Using ZMM registers, we can process 32x 16-bit values at once, - // storing 16 bytes of each string in YMM registers. - sz_u512_vec_t insertions_vec, deletions_vec, substitutions_vec, next_vec; - sz_u512_vec_t ones_u16_vec; - ones_u16_vec.zmm = _mm512_set1_epi16(1); - - // This is a mixed-precision implementation, using 8-bit representations for part of the operations. - // Even there, in case `SZ_USE_X86_AVX2=0`, let's use the `sz_u512_vec_t` type, addressing the first YMM halfs. - sz_u512_vec_t shorter_vec, longer_vec; - sz_u512_vec_t ones_u8_vec; - ones_u8_vec.ymms[0] = _mm256_set1_epi8(1); - - // Let's say we are dealing with 3 and 5 letter words. - // The matrix will have size 4 x 6, parameterized as (shorter_dim x longer_dim). - // It will have: - // - 4 diagonals of increasing length, at positions: 0, 1, 2, 3. - // - 2 diagonals of fixed length, at positions: 4, 5. - // - 3 diagonals of decreasing length, at positions: 6, 7, 8. - sz_size_t const diagonals_count = shorter_dim + longer_dim - 1; - - // Progress through the upper triangle of the Levenshtein matrix. - sz_size_t next_diagonal_index = 2; - for (; next_diagonal_index != shorter_dim; ++next_diagonal_index) { - sz_size_t const next_diagonal_length = next_diagonal_index + 1; - for (sz_size_t offset_within_diagonal = 0; offset_within_diagonal + 2 < next_diagonal_length;) { - sz_u32_t remaining_length = (sz_u32_t)(next_diagonal_length - offset_within_diagonal - 2); - sz_u32_t register_length = remaining_length < 32 ? remaining_length : 32; - sz_u32_t remaining_length_mask = _bzhi_u32(0xFFFFFFFFu, register_length); - longer_vec.ymms[0] = _mm256_maskz_loadu_epi8(remaining_length_mask, longer + offset_within_diagonal); - // Our original code addressed the shorter string `[next_diagonal_index - offset_within_diagonal - 2]` - // for growing `offset_within_diagonal`. If the `shorter` string was reversed, the - // `[next_diagonal_index - offset_within_diagonal - 2]` would be equal to `[shorter_length - 1 - - // next_diagonal_index + offset_within_diagonal + 2]`. Which simplified would be equal to - // `[shorter_length - next_diagonal_index + offset_within_diagonal + 1]`. - shorter_vec.ymms[0] = _mm256_maskz_loadu_epi8( // - remaining_length_mask, - shorter_reversed + shorter_length - next_diagonal_index + offset_within_diagonal + 1); - // For substitutions, perform the equality comparison using AVX2 instead of AVX-512 - // to get the result as a vector, instead of a bitmask. Adding 1 to every scalar we can overflow - // transforming from {0xFF, 0} values to {0, 1} values - exactly what we need. Then - upcast to 16-bit. - substitutions_vec.zmm = _mm512_cvtepi8_epi16( // - _mm256_add_epi8(_mm256_cmpeq_epi8(longer_vec.ymms[0], shorter_vec.ymms[0]), ones_u8_vec.ymms[0])); - substitutions_vec.zmm = _mm512_add_epi16( // - substitutions_vec.zmm, - _mm512_maskz_loadu_epi16(remaining_length_mask, previous_distances + offset_within_diagonal)); - // For insertions and deletions, on modern hardware, it's faster to issue two separate loads, - // than rotate the bytes in the ZMM register. 
- insertions_vec.zmm = - _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + offset_within_diagonal); - deletions_vec.zmm = - _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + offset_within_diagonal + 1); - // First get the minimum of insertions and deletions. - next_vec.zmm = _mm512_add_epi16(_mm512_min_epu16(insertions_vec.zmm, deletions_vec.zmm), ones_u16_vec.zmm); - next_vec.zmm = _mm512_min_epu16(next_vec.zmm, substitutions_vec.zmm); - _mm512_mask_storeu_epi16(next_distances + offset_within_diagonal + 1, remaining_length_mask, next_vec.zmm); - offset_within_diagonal += register_length; - } - // Don't forget to populate the first row and the first column of the Levenshtein matrix. - next_distances[0] = next_distances[next_diagonal_length - 1] = (sz_u16_t)next_diagonal_index; - // Perform a circular rotation (three-way swap) of those buffers, to reuse the memory. - sz_u16_t *temporary = previous_distances; - previous_distances = current_distances; - current_distances = next_distances; - next_distances = temporary; - } - - // By now we've scanned through the upper triangle of the matrix, where each subsequent iteration results in a - // larger diagonal. From now onwards, we will be shrinking. Instead of adding value equal to the skewed diagonal - // index on either side, we will be cropping those values out. - for (; next_diagonal_index != diagonals_count; ++next_diagonal_index) { - sz_size_t const next_diagonal_length = diagonals_count - next_diagonal_index; - for (sz_size_t i = 0; i != next_diagonal_length;) { - sz_u32_t remaining_length = (sz_u32_t)(next_diagonal_length - i); - sz_u32_t register_length = remaining_length < 32 ? remaining_length : 32; - sz_u32_t remaining_length_mask = _bzhi_u32(0xFFFFFFFFu, register_length); - longer_vec.ymms[0] = _mm256_maskz_loadu_epi8(remaining_length_mask, longer + next_diagonal_index - n + i); - // Our original code addressed the shorter string `[shorter_length - 1 - i]` for growing `i`. - // If the `shorter` string was reversed, the `[shorter_length - 1 - i]` would - // be equal to `[shorter_length - 1 - shorter_length + 1 + i]`. - // Which simplified would be equal to just `[i]`. Beautiful! - shorter_vec.ymms[0] = _mm256_maskz_loadu_epi8(remaining_length_mask, shorter_reversed + i); - // For substitutions, perform the equality comparison using AVX2 instead of AVX-512 - // to get the result as a vector, instead of a bitmask. The compare it against the accumulated - // substitution costs. - substitutions_vec.zmm = _mm512_cvtepi8_epi16( // - _mm256_add_epi8(_mm256_cmpeq_epi8(longer_vec.ymms[0], shorter_vec.ymms[0]), ones_u8_vec.ymms[0])); - substitutions_vec.zmm = _mm512_add_epi16( // - substitutions_vec.zmm, _mm512_maskz_loadu_epi16(remaining_length_mask, previous_distances + i)); - // For insertions and deletions, on modern hardware, it's faster to issue two separate loads, - // than rotate the bytes in the ZMM register. - insertions_vec.zmm = _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + i); - deletions_vec.zmm = _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + i + 1); - // First get the minimum of insertions and deletions. 
- next_vec.zmm = _mm512_add_epi16(_mm512_min_epu16(insertions_vec.zmm, deletions_vec.zmm), ones_u16_vec.zmm); - next_vec.zmm = _mm512_min_epu16(next_vec.zmm, substitutions_vec.zmm); - _mm512_mask_storeu_epi16(next_distances + i, remaining_length_mask, next_vec.zmm); - i += register_length; - } - - // Perform a circular rotation (three-way swap) of those buffers, to reuse the memory, this time, with a shift, - // dropping the first element in the current array. - sz_u16_t *temporary = previous_distances; - previous_distances = current_distances + 1; - current_distances = next_distances; - next_distances = temporary; - } - - // Cache scalar before `free` call. - sz_size_t result = current_distances[0]; - alloc->free(distances, buffer_length, alloc->handle); - return result; -#endif - return 0; + return SZ_NULL_CHAR; } -SZ_INTERNAL sz_size_t sz_edit_distance_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - - // Bounded computations may exit early. - int const is_bounded = bound < longer_length; - if (is_bounded) { - // If one of the strings is empty - the edit distance is equal to the length of the other one. - if (longer_length == 0) return sz_min_of_two(shorter_length, bound); - if (shorter_length == 0) return sz_min_of_two(longer_length, bound); - // If the difference in length is beyond the `bound`, there is no need to check at all. - if (longer_length - shorter_length > bound) return bound; +SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { + __mmask64 mask; + sz_u512_vec_t h_vec, n_vec; + n_vec.zmm = _mm512_set1_epi8(n[0]); + + while (h_length >= 64) { + h_vec.zmm = _mm512_loadu_si512(h + h_length - 64); + mask = _mm512_cmpeq_epi8_mask(h_vec.zmm, n_vec.zmm); + if (mask) return h + h_length - 1 - sz_u64_clz(mask); + h_length -= 64; } - // Make sure the shorter string is actually shorter. - if (shorter_length > longer_length) { - sz_cptr_t temporary = shorter; - shorter = longer; - longer = temporary; - sz_size_t temporary_length = shorter_length; - shorter_length = longer_length; - longer_length = temporary_length; + if (h_length) { + mask = _sz_u64_mask_until(h_length); + h_vec.zmm = _mm512_maskz_loadu_epi8(mask, h); + // Reuse the same `mask` variable to find the bit that doesn't match + mask = _mm512_mask_cmpeq_epu8_mask(mask, h_vec.zmm, n_vec.zmm); + if (mask) return h + 64 - sz_u64_clz(mask) - 1; } - // Dispatch the right implementation based on the length of the strings. - if (longer_length < 64u) - return _sz_edit_distance_skewed_diagonals_upto63_avx512( // - shorter, shorter_length, longer, longer_length, bound); - // else if (longer_length < 256u * 256u) - // return _sz_edit_distance_skewed_diagonals_upto65k_avx512( // - // shorter, shorter_length, longer, longer_length, bound, alloc); - else - return sz_edit_distance_serial(shorter, shorter_length, longer, longer_length, bound, alloc); + return SZ_NULL_CHAR; } -SZ_PUBLIC sz_u64_t sz_checksum_avx512(sz_cptr_t text, sz_size_t length) { - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "loads". - // - // A typical AWS Sapphire Rapids instance can have 48 KB x 2 blocks of L1 data cache per core, - // 2 MB x 2 blocks of L2 cache per core, and one shared 60 MB buffer of L3 cache. - // With two strings, we may consider the overal workload huge, if each exceeds 1 MB in length. 
-    int const is_huge = length >= 1ull * 1024ull * 1024ull;
-    sz_u512_vec_t text_vec, sums_vec;
-
-    // When the buffer is small, there isn't much to innovate.
-    if (length <= 16) {
-        __mmask16 mask = _sz_u16_mask_until(length);
-        text_vec.xmms[0] = _mm_maskz_loadu_epi8(mask, text);
-        sums_vec.xmms[0] = _mm_sad_epu8(text_vec.xmms[0], _mm_setzero_si128());
-        sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_vec.xmms[0]);
-        sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_vec.xmms[0], 1);
-        return low + high;
-    }
-    else if (length <= 32) {
-        __mmask32 mask = _sz_u32_mask_until(length);
-        text_vec.ymms[0] = _mm256_maskz_loadu_epi8(mask, text);
-        sums_vec.ymms[0] = _mm256_sad_epu8(text_vec.ymms[0], _mm256_setzero_si256());
-        // Accumulating 256 bits is harder, as we need to extract the 128-bit sums first.
-        __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymms[0]);
-        __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymms[0], 1);
-        __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm);
-        sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm);
-        sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1);
-        return low + high;
-    }
-    else if (length <= 64) {
-        __mmask64 mask = _sz_u64_mask_until(length);
-        text_vec.zmm = _mm512_maskz_loadu_epi8(mask, text);
-        sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512());
-        return _mm512_reduce_add_epi64(sums_vec.zmm);
-    }
-    else if (!is_huge) {
-        sz_size_t head_length = (64 - ((sz_size_t)text % 64)) % 64; // 63 or less.
-        sz_size_t tail_length = (sz_size_t)(text + length) % 64;    // 63 or less.
-        sz_size_t body_length = length - head_length - tail_length; // Multiple of 64.
-        __mmask64 head_mask = _sz_u64_mask_until(head_length);
-        __mmask64 tail_mask = _sz_u64_mask_until(tail_length);
-        text_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, text);
-        sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512());
-        for (text += head_length; body_length >= 64; text += 64, body_length -= 64) {
-            text_vec.zmm = _mm512_load_si512((__m512i const *)text);
-            sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()));
-        }
-        text_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, text);
-        sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()));
-        return _mm512_reduce_add_epi64(sums_vec.zmm);
-    }
-    // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use.
-    //
-    //    1. Moving in both directions to maximize the throughput, when fetching from multiple
-    //       memory pages. Also helps with cache set-associativity issues, as we won't always
-    //       be fetching the same entries in the lookup table.
-    //    2. Using non-temporal stores to avoid polluting the cache.
-    //    3. Prefetching the next cache line, to avoid stalling the CPU. This is generally useless
-    //       for predictable patterns, so disregard this advice.
-    //
-    // Bidirectional traversal generally adds about 10% to such algorithms.
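/* A minimal standalone sketch (not part of this patch) of the reduction primitive used in every
 * branch of this function: `_mm512_sad_epu8` against a zero vector yields eight 64-bit sums of
 * the eight 8-byte groups, and `_mm512_reduce_add_epi64` collapses them into the byte-sum of a
 * full 64-byte block. Assumes AVX-512F/BW, `<immintrin.h>`, and a hypothetical `block` pointer
 * with at least 64 readable bytes. */
#include <immintrin.h>
#include <stdint.h>
static uint64_t sum_64_bytes(unsigned char const *block) {
    __m512i bytes = _mm512_loadu_si512((void const *)block);
    __m512i partial_sums = _mm512_sad_epu8(bytes, _mm512_setzero_si512()); // 8x 64-bit partial sums
    return (uint64_t)_mm512_reduce_add_epi64(partial_sums);                // total of all 64 bytes
}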
- else { - sz_u512_vec_t text_reversed_vec, sums_reversed_vec; - sz_size_t head_length = (64 - ((sz_size_t)text % 64)) % 64; - sz_size_t tail_length = (sz_size_t)(text + length) % 64; - sz_size_t body_length = length - head_length - tail_length; - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - - text_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, text); - sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()); - text_reversed_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, text + head_length + body_length); - sums_reversed_vec.zmm = _mm512_sad_epu8(text_reversed_vec.zmm, _mm512_setzero_si512()); - - // Now in the main loop, we can use non-temporal loads and stores, - // performing the operation in both directions. - for (text += head_length; body_length >= 128; text += 64, text += 64, body_length -= 128) { - text_vec.zmm = _mm512_stream_load_si512((__m512i *)(text)); - sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512())); - text_reversed_vec.zmm = _mm512_stream_load_si512((__m512i *)(text + body_length - 64)); - sums_reversed_vec.zmm = - _mm512_add_epi64(sums_reversed_vec.zmm, _mm512_sad_epu8(text_reversed_vec.zmm, _mm512_setzero_si512())); - } - if (body_length >= 64) { - text_vec.zmm = _mm512_stream_load_si512((__m512i *)(text)); - sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512())); - } +SZ_PUBLIC sz_cptr_t sz_rfind_skylake(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - return _mm512_reduce_add_epi64(_mm512_add_epi64(sums_vec.zmm, sums_reversed_vec.zmm)); - } -} + // This almost never fires, but it's better to be safe than sorry. + if (h_length < n_length || !n_length) return SZ_NULL_CHAR; + if (n_length == 1) return sz_rfind_byte_avx512(h, h_length, n); -SZ_PUBLIC void sz_hashes_avx512(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle) { + // Pick the parts of the needle that are worth comparing. + sz_size_t offset_first, offset_mid, offset_last; + _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - if (length < window_length || !window_length) return; - if (length < 4 * window_length) { - sz_hashes_serial(start, length, window_length, step, callback, callback_handle); - return; - } + // Broadcast those characters into ZMM registers. + __mmask64 mask; + __mmask64 matches; + sz_u512_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; + n_first_vec.zmm = _mm512_set1_epi8(n[offset_first]); + n_mid_vec.zmm = _mm512_set1_epi8(n[offset_mid]); + n_last_vec.zmm = _mm512_set1_epi8(n[offset_last]); - // Using AVX2, we can perform 4 long integer multiplications and additions within one register. - // So let's slice the entire string into 4 overlapping windows, to slide over them in parallel. - sz_size_t const max_hashes = length - window_length + 1; - sz_size_t const min_hashes_per_thread = max_hashes / 4; // At most one sequence can overlap between 2 threads. - sz_u8_t const *text_first = (sz_u8_t const *)start; - sz_u8_t const *text_second = text_first + min_hashes_per_thread; - sz_u8_t const *text_third = text_first + min_hashes_per_thread * 2; - sz_u8_t const *text_fourth = text_first + min_hashes_per_thread * 3; - sz_u8_t const *text_end = text_first + length; - - // Broadcast the global constants into the registers. 
- // Both high and low hashes will work with the same prime and golden ratio. - sz_u512_vec_t prime_vec, golden_ratio_vec; - prime_vec.zmm = _mm512_set1_epi64(SZ_U64_MAX_PRIME); - golden_ratio_vec.zmm = _mm512_set1_epi64(11400714819323198485ull); - - // Prepare the `prime ^ window_length` values, that we are going to use for modulo arithmetic. - sz_u64_t prime_power_low = 1, prime_power_high = 1; - for (sz_size_t i = 0; i + 1 < window_length; ++i) - prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME, - prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME; - - // We will be evaluating 4 offsets at a time with 2 different hash functions. - // We can fit all those 8 state variables in each of the following ZMM registers. - sz_u512_vec_t base_vec, prime_power_vec, shift_vec; - base_vec.zmm = _mm512_set_epi64(31ull, 31ull, 31ull, 31ull, 257ull, 257ull, 257ull, 257ull); - shift_vec.zmm = _mm512_set_epi64(0ull, 0ull, 0ull, 0ull, 77ull, 77ull, 77ull, 77ull); - prime_power_vec.zmm = _mm512_set_epi64(prime_power_low, prime_power_low, prime_power_low, prime_power_low, - prime_power_high, prime_power_high, prime_power_high, prime_power_high); - - // Compute the initial hash values for every one of the four windows. - sz_u512_vec_t hash_vec, chars_vec; - hash_vec.zmm = _mm512_setzero_si512(); - for (sz_u8_t const *prefix_end = text_first + window_length; text_first < prefix_end; - ++text_first, ++text_second, ++text_third, ++text_fourth) { - - // 1. Multiply the hashes by the base. - hash_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, base_vec.zmm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`... - chars_vec.zmm = _mm512_set_epi64(text_fourth[0], text_third[0], text_second[0], text_first[0], // - text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm); - - // 3. Add the incoming characters. - hash_vec.zmm = _mm512_add_epi64(hash_vec.zmm, chars_vec.zmm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_vec.zmm = _mm512_mask_blend_epi8(_mm512_cmpgt_epi64_mask(hash_vec.zmm, prime_vec.zmm), hash_vec.zmm, - _mm512_sub_epi64(hash_vec.zmm, prime_vec.zmm)); + // Scan through the string. + sz_cptr_t h_reversed; + for (; h_length >= n_length + 64; h_length -= 64) { + h_reversed = h + h_length - n_length - 64 + 1; + h_first_vec.zmm = _mm512_loadu_si512(h_reversed + offset_first); + h_mid_vec.zmm = _mm512_loadu_si512(h_reversed + offset_mid); + h_last_vec.zmm = _mm512_loadu_si512(h_reversed + offset_last); + matches = _kand_mask64( // + _kand_mask64( // Intersect the masks + _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), + _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), + _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); + while (matches) { + int potential_offset = sz_u64_clz(matches); + if (n_length <= 3 || sz_equal_skylake(h + h_length - n_length - potential_offset, n, n_length)) + return h + h_length - n_length - potential_offset; + sz_assert((matches & ((sz_u64_t)1 << (63 - potential_offset))) != 0 && + "The bit must be set before we squash it"); + matches &= ~((sz_u64_t)1 << (63 - potential_offset)); + } } - // 5. Compute the hash mix, that will be used to index into the fingerprint. 
- // This includes a serial step at the end. - sz_u512_vec_t hash_mix_vec; - hash_mix_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, golden_ratio_vec.zmm); - hash_mix_vec.ymms[0] = _mm256_xor_si256(_mm512_extracti64x4_epi64(hash_mix_vec.zmm, 1), // - _mm512_extracti64x4_epi64(hash_mix_vec.zmm, 0)); - - callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); - callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); - callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); - callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); - - // Now repeat that operation for the remaining characters, discarding older characters. - sz_size_t cycle = 1; - sz_size_t step_mask = step - 1; - for (; text_fourth != text_end; ++text_first, ++text_second, ++text_third, ++text_fourth, ++cycle) { - // 0. Load again the four characters we are dropping, shift them, and subtract. - chars_vec.zmm = _mm512_set_epi64(text_fourth[-window_length], text_third[-window_length], - text_second[-window_length], text_first[-window_length], // - text_fourth[-window_length], text_third[-window_length], - text_second[-window_length], text_first[-window_length]); - chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm); - hash_vec.zmm = _mm512_sub_epi64(hash_vec.zmm, _mm512_mullo_epi64(chars_vec.zmm, prime_power_vec.zmm)); - - // 1. Multiply the hashes by the base. - hash_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, base_vec.zmm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. - chars_vec.zmm = _mm512_set_epi64(text_fourth[0], text_third[0], text_second[0], text_first[0], // - text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm); - - // ... and prefetch the next four characters into Level 2 or higher. - _mm_prefetch((sz_cptr_t)text_fourth + 1, _MM_HINT_T1); - _mm_prefetch((sz_cptr_t)text_third + 1, _MM_HINT_T1); - _mm_prefetch((sz_cptr_t)text_second + 1, _MM_HINT_T1); - _mm_prefetch((sz_cptr_t)text_first + 1, _MM_HINT_T1); - - // 3. Add the incoming characters. - hash_vec.zmm = _mm512_add_epi64(hash_vec.zmm, chars_vec.zmm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_vec.zmm = _mm512_mask_blend_epi8(_mm512_cmpgt_epi64_mask(hash_vec.zmm, prime_vec.zmm), hash_vec.zmm, - _mm512_sub_epi64(hash_vec.zmm, prime_vec.zmm)); - - // 5. Compute the hash mix, that will be used to index into the fingerprint. - // This includes a serial step at the end. - hash_mix_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, golden_ratio_vec.zmm); - hash_mix_vec.ymms[0] = _mm256_xor_si256(_mm512_extracti64x4_epi64(hash_mix_vec.zmm, 1), // - _mm512_castsi512_si256(hash_mix_vec.zmm)); - - if ((cycle & step_mask) == 0) { - callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); - callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); - callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); - callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); + // The "tail" of the function uses masked loads to process the remaining bytes. 
+ { + mask = _sz_u64_mask_until(h_length - n_length + 1); + h_first_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_first); + h_mid_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_mid); + h_last_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_last); + matches = _kand_mask64( // + _kand_mask64( // Intersect the masks + _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), + _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), + _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); + while (matches) { + int potential_offset = sz_u64_clz(matches); + if (n_length <= 3 || sz_equal_skylake(h + 64 - potential_offset - 1, n, n_length)) + return h + 64 - potential_offset - 1; + sz_assert((matches & ((sz_u64_t)1 << (63 - potential_offset))) != 0 && + "The bit must be set before we squash it"); + matches &= ~((sz_u64_t)1 << (63 - potential_offset)); } } + + return SZ_NULL_CHAR; } #pragma clang attribute pop #pragma GCC pop_options +#endif // SZ_USE_SKYLAKE +#pragma endregion // Skylake Implementation +/* AVX512 implementation of the string search algorithms for Ice Lake and newer CPUs. + * Includes extensions: + * - 2017 Skylake: F, CD, ER, PF, VL, DQ, BW, + * - 2018 CannonLake: IFMA, VBMI, + * - 2019 Ice Lake: VPOPCNTDQ, VNNI, VBMI2, BITALG, GFNI, VPCLMULQDQ, VAES. + */ +#pragma region Ice Lake Implementation +#if SZ_USE_ICE #pragma GCC push_options -#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512vbmi", "avx512vbmi2", "bmi", "bmi2") -#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512vbmi,avx512vbmi2,bmi,bmi2"))), \ +#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vbmi", "bmi", "bmi2") +#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512dq,avx512vbmi,bmi,bmi2"))), \ apply_to = function) -SZ_PUBLIC void sz_look_up_transform_avx512(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { - - // If the input is tiny (especially smaller than the look-up table itself), we may end up paying - // more for organizing the SIMD registers and changing the CPU state, than for the actual computation. - // But if at least 3 cache lines are touched, the AVX-512 implementation should be faster. - if (length <= 128) { - sz_look_up_transform_serial(source, length, lut, target); - return; - } - - // When the buffer is over 64 bytes, it's guaranteed to touch at least two cache lines - the head and tail, - // and may include more cache-lines in-between. Knowing this, we can avoid expensive unaligned stores - // by computing 2 masks - for the head and tail, using masked stores for the head and tail, and unmasked - // for the body. - sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - - // We need to pull the lookup table into 4x ZMM registers. - // We can use `vpermi2b` instruction to perform the look in two ZMM registers with `_mm512_permutex2var_epi8` - // intrinsics, but it has a 6-cycle latency on Sapphire Rapids and requires AVX512-VBMI. Assuming we need to - // operate on 4 registers, it might be cleaner to use 2x separate `_mm512_permutexvar_epi8` calls. - // Combining the results with 2x `_mm512_test_epi8_mask` and 3x blends afterwards. 
- // - // - 4x `_mm512_permutexvar_epi8` maps to "VPERMB (ZMM, ZMM, ZMM)": - // - On Ice Lake: 3 cycles latency, ports: 1*p5 - // - On Genoa: 6 cycles latency, ports: 1*FP12 - // - 3x `_mm512_mask_blend_epi8` maps to "VPBLENDMB_Z (ZMM, K, ZMM, ZMM)": - // - On Ice Lake: 3 cycles latency, ports: 1*p05 - // - On Genoa: 1 cycle latency, ports: 1*FP0123 - // - 2x `_mm512_test_epi8_mask` maps to "VPTESTMB (K, ZMM, ZMM)": - // - On Ice Lake: 3 cycles latency, ports: 1*p5 - // - On Genoa: 4 cycles latency, ports: 1*FP01 - // - sz_u512_vec_t lut_0_to_63_vec, lut_64_to_127_vec, lut_128_to_191_vec, lut_192_to_255_vec; - lut_0_to_63_vec.zmm = _mm512_loadu_si512((lut)); - lut_64_to_127_vec.zmm = _mm512_loadu_si512((lut + 64)); - lut_128_to_191_vec.zmm = _mm512_loadu_si512((lut + 128)); - lut_192_to_255_vec.zmm = _mm512_loadu_si512((lut + 192)); - - sz_u512_vec_t first_bit_vec, second_bit_vec; - first_bit_vec.zmm = _mm512_set1_epi8((char)0x80); - second_bit_vec.zmm = _mm512_set1_epi8((char)0x40); - - __mmask64 first_bit_mask, second_bit_mask; - sz_u512_vec_t source_vec; - // If the top bit is set in each word of `source_vec`, than we use `lookup_128_to_191_vec` or - // `lookup_192_to_255_vec`. If the second bit is set, we use `lookup_64_to_127_vec` or `lookup_192_to_255_vec`. - sz_u512_vec_t lookup_0_to_63_vec, lookup_64_to_127_vec, lookup_128_to_191_vec, lookup_192_to_255_vec; - sz_u512_vec_t blended_0_to_127_vec, blended_128_to_255_vec, blended_0_to_255_vec; - - // Handling the head. - if (head_length) { - source_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, source); - lookup_0_to_63_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_0_to_63_vec.zmm); - lookup_64_to_127_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_64_to_127_vec.zmm); - lookup_128_to_191_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_128_to_191_vec.zmm); - lookup_192_to_255_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_192_to_255_vec.zmm); - first_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, first_bit_vec.zmm); - second_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, second_bit_vec.zmm); - blended_0_to_127_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_0_to_63_vec.zmm, lookup_64_to_127_vec.zmm); - blended_128_to_255_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_128_to_191_vec.zmm, lookup_192_to_255_vec.zmm); - blended_0_to_255_vec.zmm = - _mm512_mask_blend_epi8(first_bit_mask, blended_0_to_127_vec.zmm, blended_128_to_255_vec.zmm); - _mm512_mask_storeu_epi8(target, head_mask, blended_0_to_255_vec.zmm); - source += head_length, target += head_length, length -= head_length; - } - - // Handling the body in 64-byte chunks aligned to cache-line boundaries with respect to `target`. 
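/* A scalar sketch (not part of this patch) of what the permute-and-blend sequence above computes
 * for every input byte: the top two bits select one of the four 64-entry quarters of the table,
 * and the low six bits - the only ones VPERMB consumes - index into that quarter. */
static sz_u8_t look_up_scalar_equivalent(sz_u8_t const *lut, sz_u8_t byte) {
    sz_u8_t const *quarter = lut + (byte >> 6) * 64; // 0, 64, 128, or 192
    return quarter[byte & 63];                       // same result as lut[byte]
}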
- while (length >= 64) { - source_vec.zmm = _mm512_loadu_si512(source); - lookup_0_to_63_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_0_to_63_vec.zmm); - lookup_64_to_127_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_64_to_127_vec.zmm); - lookup_128_to_191_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_128_to_191_vec.zmm); - lookup_192_to_255_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_192_to_255_vec.zmm); - first_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, first_bit_vec.zmm); - second_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, second_bit_vec.zmm); - blended_0_to_127_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_0_to_63_vec.zmm, lookup_64_to_127_vec.zmm); - blended_128_to_255_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_128_to_191_vec.zmm, lookup_192_to_255_vec.zmm); - blended_0_to_255_vec.zmm = - _mm512_mask_blend_epi8(first_bit_mask, blended_0_to_127_vec.zmm, blended_128_to_255_vec.zmm); - _mm512_store_si512(target, blended_0_to_255_vec.zmm); //! Aligned store, our main weapon! - source += 64, target += 64, length -= 64; - } - - // Handling the tail. - if (tail_length) { - source_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, source); - lookup_0_to_63_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_0_to_63_vec.zmm); - lookup_64_to_127_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_64_to_127_vec.zmm); - lookup_128_to_191_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_128_to_191_vec.zmm); - lookup_192_to_255_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_192_to_255_vec.zmm); - first_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, first_bit_vec.zmm); - second_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, second_bit_vec.zmm); - blended_0_to_127_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_0_to_63_vec.zmm, lookup_64_to_127_vec.zmm); - blended_128_to_255_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_128_to_191_vec.zmm, lookup_192_to_255_vec.zmm); - blended_0_to_255_vec.zmm = - _mm512_mask_blend_epi8(first_bit_mask, blended_0_to_127_vec.zmm, blended_128_to_255_vec.zmm); - _mm512_mask_storeu_epi8(target, tail_mask, blended_0_to_255_vec.zmm); - source += tail_length, target += tail_length, length -= tail_length; - } -} - -SZ_PUBLIC sz_cptr_t sz_find_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { +SZ_PUBLIC sz_cptr_t sz_find_charset_ice(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { // Before initializing the AVX-512 vectors, we may want to run the sequential code for the first few bytes. // In practice, that only hurts, even when we have matches every 5-ish bytes. @@ -6035,365 +1418,30 @@ SZ_PUBLIC sz_cptr_t sz_find_charset_avx512(sz_cptr_t text, sz_size_t length, sz_ return SZ_NULL_CHAR; } -SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { +SZ_PUBLIC sz_cptr_t sz_rfind_charset_ice(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { return sz_rfind_charset_serial(text, length, filter); } -SZ_PUBLIC sz_cptr_t sz_find_many_avx512( // - sz_cptr_t haystack, sz_size_t haystack_length, // - sz_cptr_t const *needles, sz_size_t const *needles_lengths, // - sz_size_t *needle_offset) { - - // When dealing with huge needles vocabularies, like in tokenization workloads, we need to construct an automaton. 
- // But in many cases, the vocabulary is small enough to use a simpler DFA-less approach, combining the ideas from - // the `sz_find_avx512` and `sz_find_charset_avx512` functions. - // - // Pick the offsets within needles where there is the least variance in the characters. - // Like for "the", "then", "there", "these", "those", "their", "they", "them", "that", "this", "thus", "than": - // - // 0: 't' - // 1: 'h' - // 2: 'e', 'a', 'i', 'o', 'u' - // 3: 'n', 'r', 's', 'i', 'y', 'm', 't' - // - // So depending on our "register budget", we can use a different number of pivot points: offset 0, 1, 2 make - // the most sense if we can only use 3 ZMM registers. - sz_unused(haystack && haystack_length && needles && needles_lengths && needle_offset); - return 0; -} - -/** - * Computes the Needleman Wunsch alignment score between two strings. - * The method uses 32-bit integers to accumulate the running score for every cell in the matrix. - * Assuming the costs of substitutions can be arbitrary signed 8-bit integers, the method is expected to be used - * on strings not exceeding 2^24 length or 16.7 million characters. - * - * Unlike the `_sz_edit_distance_skewed_diagonals_upto65k_avx512` method, this one uses signed integers to store - * the accumulated score. Moreover, it's primary bottleneck is the latency of gathering the substitution costs - * from the substitution matrix. If we use the diagonal order, we will be comparing a slice of the first string with - * a slice of the second. If we stick to the conventional horizontal order, we will be comparing one character against - * a slice, which is much easier to optimize. In that case we are sampling costs not from arbitrary parts of - * a 256 x 256 matrix, but from a single row! - */ -SZ_INTERNAL sz_ssize_t _sz_alignment_score_wagner_fisher_upto17m_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc) { - - // If one of the strings is empty - the edit distance is equal to the length of the other one - if (longer_length == 0) return (sz_ssize_t)shorter_length * gap; - if (shorter_length == 0) return (sz_ssize_t)longer_length * gap; - - // Let's make sure that we use the amount proportional to the - // number of elements in the shorter string, not the larger. - if (shorter_length > longer_length) { - sz_pointer_swap((void **)&longer_length, (void **)&shorter_length); - sz_pointer_swap((void **)&longer, (void **)&shorter); - } - - // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. - sz_memory_allocator_t global_alloc; - if (!alloc) { - sz_memory_allocator_init_default(&global_alloc); - alloc = &global_alloc; - } - - sz_size_t const max_length = 256ull * 256ull * 256ull; - sz_size_t const n = longer_length + 1; - sz_assert(n < max_length && "The length must fit into 24-bit integer. Otherwise use serial variant."); - sz_unused(longer_length && max_length); - - sz_size_t buffer_length = sizeof(sz_i32_t) * n * 2; - sz_i32_t *distances = (sz_i32_t *)alloc->allocate(buffer_length, alloc->handle); - sz_i32_t *previous_distances = distances; - sz_i32_t *current_distances = previous_distances + n; - - // Intialize the first row of the Levenshtein matrix with `iota`. - for (sz_size_t idx_longer = 0; idx_longer != n; ++idx_longer) - previous_distances[idx_longer] = (sz_i32_t)idx_longer * gap; - - /// Contains up to 16 consecutive characters from the longer string. 
-    sz_u512_vec_t longer_vec;
-    sz_u512_vec_t cost_deletion_vec, cost_substitution_vec, lookup_substitution_vec, current_vec;
-    sz_u512_vec_t row_first_subs_vec, row_second_subs_vec, row_third_subs_vec, row_fourth_subs_vec;
-    sz_u512_vec_t shuffled_first_subs_vec, shuffled_second_subs_vec, shuffled_third_subs_vec, shuffled_fourth_subs_vec;
-
-    // Prepare constants and masks.
-    sz_u512_vec_t is_third_or_fourth_vec, is_second_or_fourth_vec, gap_vec;
-    {
-        char is_third_or_fourth_check, is_second_or_fourth_check;
-        *(sz_u8_t *)&is_third_or_fourth_check = 0x80, *(sz_u8_t *)&is_second_or_fourth_check = 0x40;
-        is_third_or_fourth_vec.zmm = _mm512_set1_epi8(is_third_or_fourth_check);
-        is_second_or_fourth_vec.zmm = _mm512_set1_epi8(is_second_or_fourth_check);
-        gap_vec.zmm = _mm512_set1_epi32(gap);
-    }
-
-    sz_u8_t const *shorter_unsigned = (sz_u8_t const *)shorter;
-    for (sz_size_t idx_shorter = 0; idx_shorter != shorter_length; ++idx_shorter) {
-        sz_i32_t last_in_row = current_distances[0] = (sz_i32_t)(idx_shorter + 1) * gap;
-
-        // Load one row of the substitution matrix into four ZMM registers.
-        sz_error_cost_t const *row_subs = subs + shorter_unsigned[idx_shorter] * 256u;
-        row_first_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 0);
-        row_second_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 1);
-        row_third_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 2);
-        row_fourth_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 3);
-
-        // In the serial version we have one forward pass that computes the deletion,
-        // insertion, and substitution costs at once.
-        //    for (sz_size_t idx_longer = 0; idx_longer < longer_length; ++idx_longer) {
-        //        sz_ssize_t cost_deletion = previous_distances[idx_longer + 1] + gap;
-        //        sz_ssize_t cost_insertion = current_distances[idx_longer] + gap;
-        //        sz_ssize_t cost_substitution = previous_distances[idx_longer] + row_subs[longer_unsigned[idx_longer]];
-        //        current_distances[idx_longer + 1] = sz_min_of_three(cost_deletion, cost_insertion, cost_substitution);
-        //    }
-        //
-        // Given the complexity of handling the data-dependency between consecutive insertion cost computations
-        // within a Levenshtein matrix, the simplest design would be to vectorize every kind of cost computation
-        // separately.
-        //    1. Compute substitution costs for up to 64 characters at once, upcasting from 8-bit integers to 32.
-        //    2. Compute the pairwise minimum with deletion costs.
-        //    3. Inclusive prefix minimum computation to combine with addition costs.
-        // Proceeding with substitutions:
-        for (sz_size_t idx_longer = 0; idx_longer < longer_length; idx_longer += 64) {
-            sz_size_t register_length = sz_min_of_two(longer_length - idx_longer, 64);
-            __mmask64 mask = _sz_u64_mask_until(register_length);
-            longer_vec.zmm = _mm512_maskz_loadu_epi8(mask, longer + idx_longer);
-
-            // Blend the `row_(first|second|third|fourth)_subs_vec` into `current_vec`, picking the right source
-            // for every character in `longer_vec`. Before that, we need to permute the substitution vectors.
-            // Only the bottom 6 bits of a byte are used in VPERMB, so we don't even need to mask.
- shuffled_first_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_first_subs_vec.zmm); - shuffled_second_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_second_subs_vec.zmm); - shuffled_third_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_third_subs_vec.zmm); - shuffled_fourth_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_fourth_subs_vec.zmm); - - // To blend we can invoke three `_mm512_cmplt_epu8_mask`, but we can also achieve the same using - // the AND logical operation, checking the top two bits of every byte. - // Continuing this thought, we can use the VPTESTMB instruction to output the mask after the AND. - __mmask64 is_third_or_fourth = _mm512_mask_test_epi8_mask(mask, longer_vec.zmm, is_third_or_fourth_vec.zmm); - __mmask64 is_second_or_fourth = - _mm512_mask_test_epi8_mask(mask, longer_vec.zmm, is_second_or_fourth_vec.zmm); - lookup_substitution_vec.zmm = _mm512_mask_blend_epi8( - is_third_or_fourth, - // Choose between the first and the second. - _mm512_mask_blend_epi8(is_second_or_fourth, shuffled_first_subs_vec.zmm, shuffled_second_subs_vec.zmm), - // Choose between the third and the fourth. - _mm512_mask_blend_epi8(is_second_or_fourth, shuffled_third_subs_vec.zmm, shuffled_fourth_subs_vec.zmm)); - - // First, sign-extend lower and upper 16 bytes to 16-bit integers. - __m512i current_0_31_vec = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(lookup_substitution_vec.zmm, 0)); - __m512i current_32_63_vec = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(lookup_substitution_vec.zmm, 1)); - - // Now extend those 16-bit integers to 32-bit. - // This isn't free, same as the subsequent store, so we only want to do that for the populated lanes. - // To minimize the number of loads and stores, we can combine our substitution costs with the previous - // distances, containing the deletion costs. - { - cost_substitution_vec.zmm = _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer); - cost_substitution_vec.zmm = _mm512_add_epi32( - cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_0_31_vec, 0))); - cost_deletion_vec.zmm = _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer); - cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm); - current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm); - - // Inclusive prefix minimum computation to combine with insertion costs. - // Simply disabling this operation results in 5x performance improvement, meaning - // that this operation is responsible for 80% of the total runtime. - // for (sz_size_t idx_longer = 0; idx_longer < longer_length; ++idx_longer) { - // current_distances[idx_longer + 1] = - // sz_max_of_two(current_distances[idx_longer] + gap, current_distances[idx_longer + 1]); - // } - // - // To perform the same operation in vectorized form, we need to perform a tree-like reduction, - // that will involve multiple steps. It's quite expensive and should be first tested in the - // "experimental" section. - // - // Another approach might be loop unrolling: - // current_vec.i32s[0] = last_in_row = sz_i32_max_of_two(current_vec.i32s[0], last_in_row + gap); - // current_vec.i32s[1] = last_in_row = sz_i32_max_of_two(current_vec.i32s[1], last_in_row + gap); - // current_vec.i32s[2] = last_in_row = sz_i32_max_of_two(current_vec.i32s[2], last_in_row + gap); - // ... yet this approach is also quite expensive. 
-                for (int i = 0; i != 16; ++i)
-                    current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap);
-                _mm512_mask_storeu_epi32(current_distances + idx_longer + 1, (__mmask16)mask, current_vec.zmm);
-            }
-
-            // Export the values from 16 to 31.
-            if (register_length > 16) {
-                mask = _kshiftri_mask64(mask, 16);
-                cost_substitution_vec.zmm =
-                    _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer + 16);
-                cost_substitution_vec.zmm = _mm512_add_epi32(
-                    cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_0_31_vec, 1)));
-                cost_deletion_vec.zmm =
-                    _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer + 16);
-                cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm);
-                current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm);
-
-                // Aggregate running insertion costs within the register.
-                for (int i = 0; i != 16; ++i)
-                    current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap);
-                _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 16, (__mmask16)mask, current_vec.zmm);
-            }
-
-            // Export the values from 32 to 47.
-            if (register_length > 32) {
-                mask = _kshiftri_mask64(mask, 16);
-                cost_substitution_vec.zmm =
-                    _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer + 32);
-                cost_substitution_vec.zmm = _mm512_add_epi32(
-                    cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_32_63_vec, 0)));
-                cost_deletion_vec.zmm =
-                    _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer + 32);
-                cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm);
-                current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm);
-
-                // Aggregate running insertion costs within the register.
-                for (int i = 0; i != 16; ++i)
-                    current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap);
-                _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 32, (__mmask16)mask, current_vec.zmm);
-            }
-
-            // Export the values from 48 to 63.
-            if (register_length > 48) {
-                mask = _kshiftri_mask64(mask, 16);
-                cost_substitution_vec.zmm =
-                    _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer + 48);
-                cost_substitution_vec.zmm = _mm512_add_epi32(
-                    cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_32_63_vec, 1)));
-                cost_deletion_vec.zmm =
-                    _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer + 48);
-                cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm);
-                current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm);
-
-                // Aggregate running insertion costs within the register.
-                for (int i = 0; i != 16; ++i)
-                    current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap);
-                _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 48, (__mmask16)mask, current_vec.zmm);
-            }
-        }
-
-        // Swap previous_distances and current_distances pointers
-        sz_pointer_swap((void **)&previous_distances, (void **)&current_distances);
-    }
-
-    // Cache scalar before `free` call.
-    sz_ssize_t result = previous_distances[longer_length];
-    alloc->free(distances, buffer_length, alloc->handle);
-    return result;
-}
-
-SZ_INTERNAL sz_ssize_t sz_alignment_score_avx512( //
-    sz_cptr_t shorter, sz_size_t shorter_length,  //
-    sz_cptr_t longer, sz_size_t longer_length,    //
-    sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc) {
-
-    if (sz_max_of_two(shorter_length, longer_length) < (256ull * 256ull * 256ull))
-        return _sz_alignment_score_wagner_fisher_upto17m_avx512(shorter, shorter_length, longer, longer_length, subs,
-                                                                gap, alloc);
-    else
-        return sz_alignment_score_serial(shorter, shorter_length, longer, longer_length, subs, gap, alloc);
-}
-
-enum sz_encoding_t {
-    sz_encoding_unknown_k = 0,
-    sz_encoding_ascii_k = 1,
-    sz_encoding_utf8_k = 2,
-    sz_encoding_utf16_k = 3,
-    sz_encoding_utf32_k = 4,
-    sz_jwt_k,
-    sz_base64_k,
-    // Low priority encodings:
-    sz_encoding_utf8bom_k = 5,
-    sz_encoding_utf16le_k = 6,
-    sz_encoding_utf16be_k = 7,
-    sz_encoding_utf32le_k = 8,
-    sz_encoding_utf32be_k = 9,
-};
-
-// Character Set Detection is one of the most commonly performed operations in data processing with
-// [Chardet](https://github.com/chardet/chardet), [Charset Normalizer](https://github.com/jawah/charset_normalizer),
-// [cChardet](https://github.com/PyYoshi/cChardet) being the most commonly used options in the Python ecosystem.
-// All of them are notoriously slow.
-//
-// Moreover, as of October 2024, UTF-8 is the dominant character encoding on the web, used by 98.4% of websites.
-// Others have minimal usage, according to [W3Techs](https://w3techs.com/technologies/overview/character_encoding):
-// - ISO-8859-1: 1.2%
-// - Windows-1252: 0.3%
-// - Windows-1251: 0.2%
-// - EUC-JP: 0.1%
-// - Shift JIS: 0.1%
-// - EUC-KR: 0.1%
-// - GB2312: 0.1%
-// - Windows-1250: 0.1%
-// Within programming language implementations and database management systems, 16-bit and 32-bit fixed-width encodings
-// are also very popular and we need a way to efficiently differentiate between the most common UTF flavors, ASCII, and
-// the rest.
-//
-// One good solution is the [simdutf](https://github.com/simdutf/simdutf) library, but it depends on the C++ runtime
-// and focuses more on incremental validation & transcoding, rather than detection.
-//
-// So we need a very fast and efficient way of determining the encoding.
-SZ_PUBLIC sz_bool_t sz_detect_encoding(sz_cptr_t text, sz_size_t length) {
-    // https://github.com/simdutf/simdutf/blob/master/src/icelake/icelake_utf8_validation.inl.cpp
-    // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_from_utf8.inl.cpp#L81
-    // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_utf8_common.inl.cpp#L661
-    // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_utf8_common.inl.cpp#L788
-
-    // We can implement this operation differently and more simply, assuming most of the time continuous chunks of
-    // memory have identical encoding. With Russian and many European languages, we generally deal with 2-byte
-    // codepoints with occasional 1-byte punctuation marks. In the case of Chinese, Japanese, and Korean, we deal
-    // with 3-byte codepoints. In the case of emojis, we deal with 4-byte codepoints.
-    // We can also use the idea that misaligned reads are quite cheap on modern CPUs.
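/* A scalar sketch (not part of this patch) of the codepoint-width idea mentioned above, using
 * only the UTF-8 lead byte and ignoring invalid lead bytes: ASCII stays 1 byte, most Cyrillic
 * and European scripts start 2-byte sequences, CJK mostly 3-byte, and emojis 4-byte, while
 * 0x80-0xBF are continuation bytes. */
static int utf8_sequence_length(sz_u8_t lead) {
    if (lead < 0x80) return 1; // ASCII
    if (lead < 0xC0) return 0; // continuation byte, not a sequence start
    if (lead < 0xE0) return 2; // 2-byte codepoints, e.g. Cyrillic
    if (lead < 0xF0) return 3; // 3-byte codepoints, e.g. CJK
    return 4;                  // 4-byte codepoints, e.g. emoji
}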
- int can_be_ascii = 1, can_be_utf8 = 1, can_be_utf16 = 1, can_be_utf32 = 1; - sz_unused(can_be_ascii + can_be_utf8 + can_be_utf16 + can_be_utf32); - sz_unused(text && length); - return sz_false_k; -} - #pragma clang attribute pop #pragma GCC pop_options -#endif +#endif // SZ_USE_ICE +#pragma endregion // Ice Lake Implementation -#pragma endregion - -/* @brief Implementation of the string search algorithms using the Arm NEON instruction set, available on 64-bit - * Arm processors. Implements: {substring search, character search, character set search} x {forward, reverse}. +/* Implementation of the string search algorithms using the Arm NEON instruction set, available on 64-bit + * Arm processors. Covers billions of mobile CPUs worldwide, including Apple's A-series, and Qualcomm's Snapdragon. */ -#pragma region ARM NEON - -#if SZ_USE_ARM_NEON +#pragma region NEON Implementation +#if SZ_USE_NEON #pragma GCC push_options #pragma GCC target("arch=armv8.2-a+simd") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd"))), apply_to = function) -/** - * @brief Helper structure to simplify work with 64-bit words. - */ -typedef union sz_u128_vec_t { - uint8x16_t u8x16; - uint16x8_t u16x8; - uint32x4_t u32x4; - uint64x2_t u64x2; - sz_u64_t u64s[2]; - sz_u32_t u32s[4]; - sz_u16_t u16s[8]; - sz_u8_t u8s[16]; -} sz_u128_vec_t; - SZ_INTERNAL sz_u64_t _sz_vreinterpretq_u8_u4(uint8x16_t vec) { // Use `vshrn` to produce a bitmask, similar to `movemask` in SSE. // https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon return vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(vec), 4)), 0) & 0x8888888888888888ull; } -SZ_PUBLIC sz_ordering_t sz_order_neon(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { - //! Before optimizing this, read the "Operations Not Worth Optimizing" in Contributions Guide: - //! https://github.com/ashvardanian/StringZilla/blob/main/CONTRIBUTING.md#general-performance-observations - return sz_order_serial(a, a_length, b, b_length); -} - SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { sz_u128_vec_t a_vec, b_vec; for (; length >= 16; a += 16, b += 16, length -= 16) { @@ -6408,131 +1456,6 @@ SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { return sz_true_k; } -SZ_PUBLIC sz_u64_t sz_checksum_neon(sz_cptr_t text, sz_size_t length) { - uint64x2_t sum_vec = vdupq_n_u64(0); - - // Process 16 bytes (128 bits) at a time - for (; length >= 16; text += 16, length -= 16) { - uint8x16_t vec = vld1q_u8((sz_u8_t const *)text); // Load 16 bytes - uint16x8_t pairwise_sum1 = vpaddlq_u8(vec); // Pairwise add lower and upper 8 bits - uint32x4_t pairwise_sum2 = vpaddlq_u16(pairwise_sum1); // Pairwise add 16-bit results - uint64x2_t pairwise_sum3 = vpaddlq_u32(pairwise_sum2); // Pairwise add 32-bit results - sum_vec = vaddq_u64(sum_vec, pairwise_sum3); // Accumulate the sum - } - - // Final reduction of `sum_vec` to a single scalar - sz_u64_t sum = vgetq_lane_u64(sum_vec, 0) + vgetq_lane_u64(sum_vec, 1); - if (length) sum += sz_checksum_serial(text, length); - return sum; -} - -SZ_PUBLIC void sz_copy_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - // In most cases the `source` and the `target` are not aligned, but we should - // at least make sure that writes don't touch many cache lines. - // NEON has an instruction to load and write 64 bytes at once. 
- // - // sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. - // sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. - // for (; head_length; target += 1, source += 1, head_length -= 1) *target = *source; - // length -= head_length; - // for (; length >= 64; target += 64, source += 64, length -= 64) - // vst4q_u8((sz_u8_t *)target, vld1q_u8_x4((sz_u8_t const *)source)); - // for (; tail_length; target += 1, source += 1, tail_length -= 1) *target = *source; - // - // Sadly, those instructions end up being 20% slower than the code processing 16 bytes at a time: - for (; length >= 16; target += 16, source += 16, length -= 16) - vst1q_u8((sz_u8_t *)target, vld1q_u8((sz_u8_t const *)source)); - if (length) sz_copy_serial(target, source, length); -} - -SZ_PUBLIC void sz_move_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - // When moving small buffers, using a small buffer on stack as a temporary storage is faster. - - if (target < source || target >= source + length) { - // Non-overlapping, proceed forward - sz_copy_neon(target, source, length); - } - else { - // Overlapping, proceed backward - target += length; - source += length; - - sz_u128_vec_t src_vec; - while (length >= 16) { - target -= 16, source -= 16, length -= 16; - src_vec.u8x16 = vld1q_u8((sz_u8_t const *)source); - vst1q_u8((sz_u8_t *)target, src_vec.u8x16); - } - while (length) { - target -= 1, source -= 1, length -= 1; - *target = *source; - } - } -} - -SZ_PUBLIC void sz_fill_neon(sz_ptr_t target, sz_size_t length, sz_u8_t value) { - uint8x16_t fill_vec = vdupq_n_u8(value); // Broadcast the value across the register - - while (length >= 16) { - vst1q_u8((sz_u8_t *)target, fill_vec); - target += 16; - length -= 16; - } - - // Handle remaining bytes - if (length) sz_fill_serial(target, length, value); -} - -SZ_PUBLIC void sz_look_up_transform_neon(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { - - // If the input is tiny (especially smaller than the look-up table itself), we may end up paying - // more for organizing the SIMD registers and changing the CPU state, than for the actual computation. - if (length <= 128) { - sz_look_up_transform_serial(source, length, lut, target); - return; - } - - sz_size_t head_length = (16 - ((sz_size_t)target % 16)) % 16; // 15 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 16; // 15 or less. - - // We need to pull the lookup table into 16x NEON registers. We have a total of 32 such registers. - // According to the Neoverse V2 manual, the 4-table lookup has a latency of 6 cycles, and 4x throughput. - uint8x16x4_t lut_0_to_63_vec, lut_64_to_127_vec, lut_128_to_191_vec, lut_192_to_255_vec; - lut_0_to_63_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 0)); - lut_64_to_127_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 64)); - lut_128_to_191_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 128)); - lut_192_to_255_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 192)); - - sz_u128_vec_t source_vec; - // If the top bit is set in each word of `source_vec`, than we use `lookup_128_to_191_vec` or - // `lookup_192_to_255_vec`. If the second bit is set, we use `lookup_64_to_127_vec` or `lookup_192_to_255_vec`. 
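/* A scalar sketch (not part of this patch) of the XOR trick used with `vqtbl4q_u8` below:
 * TBL returns 0 for indices outside the 64-byte table, so XOR-ing the input with 0x00, 0x40,
 * 0x80, and 0xC0 lets each 64-entry quarter of the table answer only for its own range of
 * the 0..255 input space, and the four partial results can simply be OR-ed together. */
static sz_u8_t look_up_quarter(sz_u8_t const *quarter, sz_u8_t index) {
    return index < 64 ? quarter[index] : 0; // what TBL does with out-of-range indices
}
static sz_u8_t look_up_neon_equivalent(sz_u8_t const *lut, sz_u8_t byte) {
    return (sz_u8_t)(look_up_quarter(lut + 0, byte ^ 0x00) | look_up_quarter(lut + 64, byte ^ 0x40) |
                     look_up_quarter(lut + 128, byte ^ 0x80) | look_up_quarter(lut + 192, byte ^ 0xC0));
}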
- sz_u128_vec_t lookup_0_to_63_vec, lookup_64_to_127_vec, lookup_128_to_191_vec, lookup_192_to_255_vec; - sz_u128_vec_t blended_0_to_255_vec; - - // Process the head with serial code - for (; head_length; target += 1, source += 1, head_length -= 1) *target = lut[*(sz_u8_t const *)source]; - - // Table lookups on Arm are much simpler to use than on x86, as we can use the `vqtbl4q_u8` instruction - // to perform a 4-table lookup in a single instruction. The XORs are used to adjust the lookup position - // within each 64-byte range of the table. - // Details on the 4-table lookup: https://lemire.me/blog/2019/07/23/arbitrary-byte-to-byte-maps-using-arm-neon/ - length -= head_length; - length -= tail_length; - for (; length >= 16; source += 16, target += 16, length -= 16) { - source_vec.u8x16 = vld1q_u8((sz_u8_t const *)source); - lookup_0_to_63_vec.u8x16 = vqtbl4q_u8(lut_0_to_63_vec, source_vec.u8x16); - lookup_64_to_127_vec.u8x16 = vqtbl4q_u8(lut_64_to_127_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0x40))); - lookup_128_to_191_vec.u8x16 = vqtbl4q_u8(lut_128_to_191_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0x80))); - lookup_192_to_255_vec.u8x16 = vqtbl4q_u8(lut_192_to_255_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0xc0))); - blended_0_to_255_vec.u8x16 = vorrq_u8(vorrq_u8(lookup_0_to_63_vec.u8x16, lookup_64_to_127_vec.u8x16), - vorrq_u8(lookup_128_to_191_vec.u8x16, lookup_192_to_255_vec.u8x16)); - vst1q_u8((sz_u8_t *)target, blended_0_to_255_vec.u8x16); - } - - // Process the tail with serial code - for (; tail_length; target += 1, source += 1, tail_length -= 1) *target = lut[*(sz_u8_t const *)source]; -} - SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { sz_u64_t matches; sz_u128_vec_t h_vec, n_vec, matches_vec; @@ -6569,8 +1492,8 @@ SZ_PUBLIC sz_cptr_t sz_rfind_byte_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_ return sz_rfind_byte_serial(h, h_length, n); } -SZ_PUBLIC sz_u64_t _sz_find_charset_neon_register(sz_u128_vec_t h_vec, uint8x16_t set_top_vec_u8x16, - uint8x16_t set_bottom_vec_u8x16) { +SZ_PUBLIC sz_u64_t _sz_find_charset_neon_register( // + sz_u128_vec_t h_vec, uint8x16_t set_top_vec_u8x16, uint8x16_t set_bottom_vec_u8x16) { // Once we've read the characters in the haystack, we want to // compare them against our bitset. The serial version of that code @@ -6744,253 +1667,36 @@ SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t h, sz_size_t h_length, sz_ch #pragma clang attribute pop #pragma GCC pop_options -#endif // Arm Neon - -#pragma endregion +#endif // SZ_USE_NEON +#pragma endregion // NEON Implementation -/* @brief Implementation of the string search algorithms using the Arm SVE variable-length registers, available - * in Arm v9 processors. - * - * Implements: - * - memory: {copy, move, fill} - * - comparisons: {equal, order} - * - search: {substring, character, character set} x {forward, reverse}. +/* Implementation of the string search algorithms using the Arm SVE variable-length registers, + * available in Arm v9 processors, like in Apple M4+ and Graviton 3+ CPUs. 
*/ -#pragma region ARM SVE - -#if SZ_USE_ARM_SVE +#pragma region SVE Implementation +#if SZ_USE_SVE #pragma GCC push_options #pragma GCC target("arch=armv8.2-a+sve") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve"))), apply_to = function) - -SZ_PUBLIC void sz_fill_sve(sz_ptr_t target, sz_size_t length, sz_u8_t value) { - svuint8_t value_vec = svdup_u8(value); - sz_size_t vec_len = svcntb(); // Vector length in bytes (scalable) - - if (length <= vec_len) { - // Small buffer case: use mask to handle small writes - svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)length); - svst1_u8(mask, (unsigned char *)target, value_vec); - } - else { - // Calculate head, body, and tail sizes - sz_size_t head_length = vec_len - ((sz_size_t)target % vec_len); - sz_size_t tail_length = (sz_size_t)(target + length) % vec_len; - sz_size_t body_length = length - head_length - tail_length; - - // Handle unaligned head - svbool_t head_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)head_length); - svst1_u8(head_mask, (unsigned char *)target, value_vec); - target += head_length; - - // Aligned body loop - for (; body_length >= vec_len; target += vec_len, body_length -= vec_len) { - svst1_u8(svptrue_b8(), (unsigned char *)target, value_vec); - } - - // Handle unaligned tail - svbool_t tail_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)tail_length); - svst1_u8(tail_mask, (unsigned char *)target, value_vec); - } -} - -SZ_PUBLIC void sz_copy_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - sz_size_t vec_len = svcntb(); // Vector length in bytes - - // Arm Neoverse V2 cores in Graviton 4, for example, come with 256 KB of L1 data cache per core, - // and 8 MB of L2 cache per core. Moreover, the L1 cache is fully associative. - // With two strings, we may consider the overal workload huge, if each exceeds 1 MB in length. - // - // int is_huge = length >= 4ull * 1024ull * 1024ull; - // - // When the buffer is small, there isn't much to innovate. - if (length <= vec_len) { - // Small buffer case: use mask to handle small writes - svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)length); - svuint8_t data = svld1_u8(mask, (unsigned char *)source); - svst1_u8(mask, (unsigned char *)target, data); - } - // When dealing with larger buffers, similar to AVX-512, we want minimize unaligned operations - // and handle the head, body, and tail separately. We can also traverse the buffer in both directions - // as Arm generally supports more simultaneous stores than x86 CPUs. - // - // For gigantic datasets, similar to AVX-512, non-temporal "loads" and "stores" can be used. - // Sadly, if the register size (16 byte or larger) is smaller than a cache-line (64 bytes) - // we will pay a huge penalty on loads, fetching the same content many times. - // It may be better to allow caching (and subsequent eviction), in favor of using four-element - // tuples, wich will be guaranteed to be a multiple of a cache line. - // - // Another approach is to use the `LD4B` instructions, which will populate four registers at once. - // This however, further decreases the performance from LibC-like 29 GB/s to 20 GB/s. - else { - // Calculating head, body, and tail sizes depends on the `vec_len`, - // but it's runtime constant, and the modulo operation is expensive! - // Instead we use the fact, that it's always a multiple of 128 bits or 16 bytes. 
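/* A small standalone sketch (not part of this patch) of the head/body/tail split used below,
 * assuming the minimal 16-byte SVE register width: after the split, `target + head_length` is
 * 16-byte aligned, the body is a multiple of 16 bytes, and the three parts add up to `length`. */
static void split_for_aligned_stores(char *target, sz_size_t length, //
                                     sz_size_t *head_length, sz_size_t *body_length, sz_size_t *tail_length) {
    *head_length = 16 - ((sz_size_t)target % 16);         // 1 to 16 bytes, mirrors the code below
    *tail_length = (sz_size_t)(target + length) % 16;     // 0 to 15 bytes
    *body_length = length - *head_length - *tail_length;  // the remainder, a multiple of 16
}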
- sz_size_t head_length = 16 - ((sz_size_t)target % 16); - sz_size_t tail_length = (sz_size_t)(target + length) % 16; - sz_size_t body_length = length - head_length - tail_length; - - // Handle unaligned parts - svbool_t head_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)head_length); - svuint8_t head_data = svld1_u8(head_mask, (unsigned char *)source); - svst1_u8(head_mask, (unsigned char *)target, head_data); - svbool_t tail_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)tail_length); - svuint8_t tail_data = svld1_u8(tail_mask, (unsigned char *)source + head_length + body_length); - svst1_u8(tail_mask, (unsigned char *)target + head_length + body_length, tail_data); - target += head_length; - source += head_length; - - // Aligned body loop, walking in two directions - for (; body_length >= vec_len * 2; target += vec_len, source += vec_len, body_length -= vec_len * 2) { - svuint8_t forward_data = svld1_u8(svptrue_b8(), (unsigned char *)source); - svuint8_t backward_data = svld1_u8(svptrue_b8(), (unsigned char *)source + body_length - vec_len); - svst1_u8(svptrue_b8(), (unsigned char *)target, forward_data); - svst1_u8(svptrue_b8(), (unsigned char *)target + body_length - vec_len, backward_data); - } - // Up to (vec_len * 2 - 1) bytes of data may be left in the body, - // so we can unroll the last two optional loop iterations. - if (body_length > vec_len) { - svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)body_length); - svuint8_t data = svld1_u8(mask, (unsigned char *)source); - svst1_u8(mask, (unsigned char *)target, data); - body_length -= vec_len; - source += body_length; - target += body_length; - } - if (body_length) { - svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)body_length); - svuint8_t data = svld1_u8(mask, (unsigned char *)source); - svst1_u8(mask, (unsigned char *)target, data); - } - } -} - #pragma clang attribute pop #pragma GCC pop_options -#endif // Arm SVE +#endif // SZ_USE_SVE +#pragma endregion // SVE Implementation -#pragma endregion - -/* - * @brief Pick the right implementation for the string search algorithms. +/* Pick the right implementation for the string search algorithms. + * To override this behavior and precompile all backends - set `SZ_DYNAMIC_DISPATCH` to 1. */ #pragma region Compile Time Dispatching - -SZ_PUBLIC sz_u64_t sz_hash(sz_cptr_t ins, sz_size_t length) { return sz_hash_serial(ins, length); } -SZ_PUBLIC void sz_tolower(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) { sz_tolower_serial(ins, length, outs); } -SZ_PUBLIC void sz_toupper(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) { sz_toupper_serial(ins, length, outs); } -SZ_PUBLIC void sz_toascii(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) { sz_toascii_serial(ins, length, outs); } -SZ_PUBLIC sz_bool_t sz_isascii(sz_cptr_t ins, sz_size_t length) { return sz_isascii_serial(ins, length); } - -SZ_PUBLIC void sz_hashes_fingerprint(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_ptr_t fingerprint, - sz_size_t fingerprint_bytes) { - - sz_bool_t fingerprint_length_is_power_of_two = (sz_bool_t)((fingerprint_bytes & (fingerprint_bytes - 1)) == 0); - sz_string_view_t fingerprint_buffer = {fingerprint, fingerprint_bytes}; - - // There are several issues related to the fingerprinting algorithm. - // First, the memory traversal order is important. - // https://blog.stuffedcow.net/2015/08/pagewalk-coherence/ - - // In most cases the fingerprint length will be a power of two. 
- if (fingerprint_length_is_power_of_two == sz_false_k) - sz_hashes(start, length, window_length, 1, _sz_hashes_fingerprint_non_pow2_callback, &fingerprint_buffer); - else - sz_hashes(start, length, window_length, 1, _sz_hashes_fingerprint_pow2_callback, &fingerprint_buffer); -} - #if !SZ_DYNAMIC_DISPATCH -SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length) { -#if SZ_USE_X86_AVX512 - return sz_checksum_avx512(text, length); -#elif SZ_USE_X86_AVX2 - return sz_checksum_avx2(text, length); -#elif SZ_USE_ARM_NEON - return sz_checksum_neon(text, length); -#else - return sz_checksum_serial(text, length); -#endif -} - -SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { -#if SZ_USE_X86_AVX512 - return sz_equal_avx512(a, b, length); -#elif SZ_USE_X86_AVX2 - return sz_equal_avx2(a, b, length); -#elif SZ_USE_ARM_NEON - return sz_equal_neon(a, b, length); -#else - return sz_equal_serial(a, b, length); -#endif -} - -SZ_DYNAMIC sz_ordering_t sz_order(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { -#if SZ_USE_X86_AVX512 - return sz_order_avx512(a, a_length, b, b_length); -#elif SZ_USE_X86_AVX2 - return sz_order_avx2(a, a_length, b, b_length); -#elif SZ_USE_ARM_NEON - return sz_order_neon(a, a_length, b, b_length); -#else - return sz_order_serial(a, a_length, b, b_length); -#endif -} - -SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { -#if SZ_USE_X86_AVX512 - sz_copy_avx512(target, source, length); -#elif SZ_USE_X86_AVX2 - sz_copy_avx2(target, source, length); -#elif SZ_USE_ARM_NEON - sz_copy_neon(target, source, length); -#else - sz_copy_serial(target, source, length); -#endif -} - -SZ_DYNAMIC void sz_move(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { -#if SZ_USE_X86_AVX512 - sz_move_avx512(target, source, length); -#elif SZ_USE_X86_AVX2 - sz_move_avx2(target, source, length); -#elif SZ_USE_ARM_NEON - sz_move_neon(target, source, length); -#else - sz_move_serial(target, source, length); -#endif -} - -SZ_DYNAMIC void sz_fill(sz_ptr_t target, sz_size_t length, sz_u8_t value) { -#if SZ_USE_X86_AVX512 - sz_fill_avx512(target, length, value); -#elif SZ_USE_X86_AVX2 - sz_fill_avx2(target, length, value); -#elif SZ_USE_ARM_NEON - sz_fill_neon(target, length, value); -#else - sz_fill_serial(target, length, value); -#endif -} - -SZ_DYNAMIC void sz_look_up_transform(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { -#if SZ_USE_X86_AVX512 - sz_look_up_transform_avx512(source, length, lut, target); -#elif SZ_USE_X86_AVX2 - sz_look_up_transform_avx2(source, length, lut, target); -#elif SZ_USE_ARM_NEON - sz_look_up_transform_neon(source, length, lut, target); -#else - sz_look_up_transform_serial(source, length, lut, target); -#endif -} +#pragma region Core Funcitonality SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) { -#if SZ_USE_X86_AVX512 +#if SZ_USE_ICE return sz_find_byte_avx512(haystack, h_length, needle); -#elif SZ_USE_X86_AVX2 - return sz_find_byte_avx2(haystack, h_length, needle); -#elif SZ_USE_ARM_NEON +#elif SZ_USE_HASWELL + return sz_find_byte_haswell(haystack, h_length, needle); +#elif SZ_USE_NEON return sz_find_byte_neon(haystack, h_length, needle); #else return sz_find_byte_serial(haystack, h_length, needle); @@ -6998,11 +1704,11 @@ SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cpt } SZ_DYNAMIC sz_cptr_t sz_rfind_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) { -#if SZ_USE_X86_AVX512 +#if 
SZ_USE_ICE return sz_rfind_byte_avx512(haystack, h_length, needle); -#elif SZ_USE_X86_AVX2 - return sz_rfind_byte_avx2(haystack, h_length, needle); -#elif SZ_USE_ARM_NEON +#elif SZ_USE_HASWELL + return sz_rfind_byte_haswell(haystack, h_length, needle); +#elif SZ_USE_NEON return sz_rfind_byte_neon(haystack, h_length, needle); #else return sz_rfind_byte_serial(haystack, h_length, needle); @@ -7010,11 +1716,11 @@ SZ_DYNAMIC sz_cptr_t sz_rfind_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cp } SZ_DYNAMIC sz_cptr_t sz_find(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length) { -#if SZ_USE_X86_AVX512 - return sz_find_avx512(haystack, h_length, needle, n_length); -#elif SZ_USE_X86_AVX2 - return sz_find_avx2(haystack, h_length, needle, n_length); -#elif SZ_USE_ARM_NEON +#if SZ_USE_ICE + return sz_find_skylake(haystack, h_length, needle, n_length); +#elif SZ_USE_HASWELL + return sz_find_haswell(haystack, h_length, needle, n_length); +#elif SZ_USE_NEON return sz_find_neon(haystack, h_length, needle, n_length); #else return sz_find_serial(haystack, h_length, needle, n_length); @@ -7022,11 +1728,11 @@ SZ_DYNAMIC sz_cptr_t sz_find(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t n } SZ_DYNAMIC sz_cptr_t sz_rfind(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length) { -#if SZ_USE_X86_AVX512 - return sz_rfind_avx512(haystack, h_length, needle, n_length); -#elif SZ_USE_X86_AVX2 - return sz_rfind_avx2(haystack, h_length, needle, n_length); -#elif SZ_USE_ARM_NEON +#if SZ_USE_ICE + return sz_rfind_skylake(haystack, h_length, needle, n_length); +#elif SZ_USE_HASWELL + return sz_rfind_haswell(haystack, h_length, needle, n_length); +#elif SZ_USE_NEON return sz_rfind_neon(haystack, h_length, needle, n_length); #else return sz_rfind_serial(haystack, h_length, needle, n_length); @@ -7034,11 +1740,11 @@ SZ_DYNAMIC sz_cptr_t sz_rfind(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t } SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { -#if SZ_USE_X86_AVX512 - return sz_find_charset_avx512(text, length, set); -#elif SZ_USE_X86_AVX2 - return sz_find_charset_avx2(text, length, set); -#elif SZ_USE_ARM_NEON +#if SZ_USE_ICE + return sz_find_charset_ice(text, length, set); +#elif SZ_USE_HASWELL + return sz_find_charset_haswell(text, length, set); +#elif SZ_USE_NEON return sz_find_charset_neon(text, length, set); #else return sz_find_charset_serial(text, length, set); @@ -7046,69 +1752,19 @@ SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charse } SZ_DYNAMIC sz_cptr_t sz_rfind_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { -#if SZ_USE_X86_AVX512 - return sz_rfind_charset_avx512(text, length, set); -#elif SZ_USE_X86_AVX2 - return sz_rfind_charset_avx2(text, length, set); -#elif SZ_USE_ARM_NEON +#if SZ_USE_ICE + return sz_rfind_charset_ice(text, length, set); +#elif SZ_USE_HASWELL + return sz_rfind_charset_haswell(text, length, set); +#elif SZ_USE_NEON return sz_rfind_charset_neon(text, length, set); #else return sz_rfind_charset_serial(text, length, set); #endif } -SZ_DYNAMIC sz_size_t sz_hamming_distance( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - return sz_hamming_distance_serial(a, a_length, b, b_length, bound); -} - -SZ_DYNAMIC sz_size_t sz_hamming_distance_utf8( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - return sz_hamming_distance_utf8_serial(a, a_length, b, 
b_length, bound); -} - -SZ_DYNAMIC sz_size_t sz_edit_distance( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { -#if SZ_USE_X86_AVX512 - return sz_edit_distance_avx512(a, a_length, b, b_length, bound, alloc); -#else - return sz_edit_distance_serial(a, a_length, b, b_length, bound, alloc); -#endif -} - -SZ_DYNAMIC sz_size_t sz_edit_distance_utf8( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - return _sz_edit_distance_wagner_fisher_serial(a, a_length, b, b_length, bound, sz_true_k, alloc); -} - -SZ_DYNAMIC sz_ssize_t sz_alignment_score(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, - sz_error_cost_t const *subs, sz_error_cost_t gap, - sz_memory_allocator_t *alloc) { -#if SZ_USE_X86_AVX512 - return sz_alignment_score_avx512(a, a_length, b, b_length, subs, gap, alloc); -#else - return sz_alignment_score_serial(a, a_length, b, b_length, subs, gap, alloc); -#endif -} - -SZ_DYNAMIC void sz_hashes(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // - sz_hash_callback_t callback, void *callback_handle) { -#if SZ_USE_X86_AVX512 - sz_hashes_avx512(text, length, window_length, window_step, callback, callback_handle); -#elif SZ_USE_X86_AVX2 - sz_hashes_avx2(text, length, window_length, window_step, callback, callback_handle); -#else - sz_hashes_serial(text, length, window_length, window_step, callback, callback_handle); -#endif -} +#pragma endregion +#pragma region Helper Shortcuts SZ_DYNAMIC sz_cptr_t sz_find_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { sz_charset_t set; @@ -7140,17 +1796,11 @@ SZ_DYNAMIC sz_cptr_t sz_rfind_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_ return sz_rfind_charset(h, h_length, &set); } -SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length, - sz_random_generator_t generator, void *generator_user_data) { - sz_generate_serial(alphabet, alphabet_size, result, result_length, generator, generator_user_data); -} - -#endif -#pragma endregion +#pragma endregion // Helper Shortcuts +#endif // !SZ_DYNAMIC_DISPATCH +#pragma endregion // Compile Time Dispatching #ifdef __cplusplus -#pragma GCC diagnostic pop } #endif // __cplusplus - -#endif // STRINGZILLA_H_ +#endif // STRINGZILLA_FIND_H_ From 295d49a38d66b08075357ac829ad66d80b5edab0 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 14:49:26 +0000 Subject: [PATCH 032/751] Fix: Filter `memory.h` file --- include/stringzilla/memory.h | 6359 ++-------------------------------- 1 file changed, 262 insertions(+), 6097 deletions(-) diff --git a/include/stringzilla/memory.h b/include/stringzilla/memory.h index de7fbcac..87957878 100644 --- a/include/stringzilla/memory.h +++ b/include/stringzilla/memory.h @@ -1,3082 +1,166 @@ /** - * @brief StringZilla is a collection of advanced string algorithms, designed to be used in Big Data applications. - * It is generally faster than LibC, and has a broader & cleaner interface, and targets modern x86 CPUs - * with AVX-512 and Arm NEON and older CPUs with SWAR and auto-vectorization. - * - * Consider overriding the following macros to customize the library: - * - * - `SZ_DEBUG=0` - whether to enable debug assertions and logging. - * - `SZ_DYNAMIC_DISPATCH=0` - whether to use runtime dispatching of the most advanced SIMD backend. 
- * - `SZ_USE_MISALIGNED_LOADS=0` - whether to use misaligned loads on platforms that support them. - * - `SZ_SWAR_THRESHOLD=24` - threshold for switching to SWAR backend over serial byte-level for-loops. - * - `SZ_USE_X86_AVX512=?` - whether to use AVX-512 instructions on x86_64. - * - `SZ_USE_X86_AVX2=?` - whether to use AVX2 instructions on x86_64. - * - `SZ_USE_ARM_NEON=?` - whether to use NEON instructions on ARM. - * - `SZ_USE_ARM_SVE=?` - whether to use SVE instructions on ARM. - * - * @see StringZilla: https://github.com/ashvardanian/StringZilla/blob/main/README.md - * @see LibC String: https://pubs.opengroup.org/onlinepubs/009695399/basedefs/string.h.html - * - * @file stringzilla.h + * @brief Hardware-accelerated memory operations. + * @file memory.h * @author Ash Vardanian - */ -#ifndef STRINGZILLA_H_ -#define STRINGZILLA_H_ - -#define STRINGZILLA_VERSION_MAJOR 3 -#define STRINGZILLA_VERSION_MINOR 11 -#define STRINGZILLA_VERSION_PATCH 0 - -/** - * @brief When set to 1, the library will include the following LibC headers: and . - * In debug builds (SZ_DEBUG=1), the library will also include and . - * - * You may want to disable this compiling for use in the kernel, or in embedded systems. - * You may also avoid them, if you are very sensitive to compilation time and avoid pre-compiled headers. - * https://artificial-mind.net/projects/compile-health/ - */ -#ifndef SZ_AVOID_LIBC -#define SZ_AVOID_LIBC (0) // true or false -#endif - -/** - * @brief A misaligned load can be - trying to fetch eight consecutive bytes from an address - * that is not divisible by eight. On x86 enabled by default. On ARM it's not. - * - * Most platforms support it, but there is no industry standard way to check for those. - * This value will mostly affect the performance of the serial (SWAR) backend. - */ -#ifndef SZ_USE_MISALIGNED_LOADS -#if defined(__x86_64__) || defined(_M_X64) || defined(__i386__) || defined(_M_IX86) -#define SZ_USE_MISALIGNED_LOADS (1) // true or false -#else -#define SZ_USE_MISALIGNED_LOADS (0) // true or false -#endif -#endif - -/** - * @brief Removes compile-time dispatching, and replaces it with runtime dispatching. - * So the `sz_find` function will invoke the most advanced backend supported by the CPU, - * that runs the program, rather than the most advanced backend supported by the CPU - * used to compile the library or the downstream application. - */ -#ifndef SZ_DYNAMIC_DISPATCH -#define SZ_DYNAMIC_DISPATCH (0) // true or false -#endif - -/** - * @brief Analogous to `size_t` and `std::size_t`, unsigned integer, identical to pointer size. - * 64-bit on most platforms where pointers are 64-bit. - * 32-bit on platforms where pointers are 32-bit. - */ -#if defined(__LP64__) || defined(_LP64) || defined(__x86_64__) || defined(_WIN64) -#define SZ_DETECT_64_BIT (1) -#define SZ_SIZE_MAX (0xFFFFFFFFFFFFFFFFull) // Largest unsigned integer that fits into 64 bits. -#define SZ_SSIZE_MAX (0x7FFFFFFFFFFFFFFFull) // Largest signed integer that fits into 64 bits. -#else -#define SZ_DETECT_64_BIT (0) -#define SZ_SIZE_MAX (0xFFFFFFFFu) // Largest unsigned integer that fits into 32 bits. -#define SZ_SSIZE_MAX (0x7FFFFFFFu) // Largest signed integer that fits into 32 bits. -#endif - -/** - * @brief On Big-Endian machines StringZilla will work in compatibility mode. - * This disables SWAR hacks to minimize code duplication, assuming practically - * all modern popular platforms are Little-Endian. - * - * This variable is hard to infer from macros reliably. It's best to set it manually. 
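Since the header stresses that endianness is hard to infer from macros alone, a tiny runtime check is a common way to validate whatever `SZ_DETECT_BIG_ENDIAN` value a build system picked. A sketch suitable for a unit test, not part of the library itself:

#include <stdint.h>
#include <stdio.h>
#include <string.h>

int main(void) {
    uint32_t word = 0x01020304u;
    uint8_t first_byte;
    memcpy(&first_byte, &word, 1); // inspect the lowest-addressed byte
    // Little-endian machines store the least significant byte first.
    printf("runtime check: %s-endian\n", first_byte == 0x04 ? "little" : "big");
    return 0;
}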
- * For that CMake provides the `TestBigEndian` and `CMAKE__BYTE_ORDER` (from 3.20 onwards). - * In Python one can check `sys.byteorder == 'big'` in the `setup.py` script and pass the appropriate macro. - * https://stackoverflow.com/a/27054190 - */ -#ifndef SZ_DETECT_BIG_ENDIAN -#if defined(__BYTE_ORDER) && __BYTE_ORDER == __BIG_ENDIAN || defined(__BIG_ENDIAN__) || defined(__ARMEB__) || \ - defined(__THUMBEB__) || defined(__AARCH64EB__) || defined(_MIBSEB) || defined(__MIBSEB) || defined(__MIBSEB__) -#define SZ_DETECT_BIG_ENDIAN (1) //< It's a big-endian target architecture -#else -#define SZ_DETECT_BIG_ENDIAN (0) //< It's a little-endian target architecture -#endif -#endif - -/* - * Debugging and testing. - */ -#ifndef SZ_DEBUG -#if defined(DEBUG) || defined(_DEBUG) // This means "Not using DEBUG information". -#define SZ_DEBUG (1) -#else -#define SZ_DEBUG (0) -#endif -#endif - -/** - * @brief Threshold for switching to SWAR (8-bytes at a time) backend over serial byte-level for-loops. - * On very short strings, under 16 bytes long, at most a single word will be processed with SWAR. - * Assuming potentially misaligned loads, SWAR makes sense only after ~24 bytes. - */ -#ifndef SZ_SWAR_THRESHOLD -#if SZ_DEBUG -#define SZ_SWAR_THRESHOLD (8u) // 8 bytes in debug builds -#else -#define SZ_SWAR_THRESHOLD (24u) // 24 bytes in release builds -#endif -#endif - -/* Annotation for the public API symbols: - * - * - `SZ_PUBLIC` is used for functions that are part of the public API. - * - `SZ_INTERNAL` is used for internal helper functions with unstable APIs. - * - `SZ_DYNAMIC` is used for functions that are part of the public API, but are dispatched at runtime. - */ -#ifndef SZ_DYNAMIC -#if SZ_DYNAMIC_DISPATCH -#if defined(_WIN32) || defined(__CYGWIN__) -#define SZ_DYNAMIC __declspec(dllexport) -#define SZ_EXTERNAL __declspec(dllimport) -#define SZ_PUBLIC inline static -#define SZ_INTERNAL inline static -#else -#define SZ_DYNAMIC __attribute__((visibility("default"))) -#define SZ_EXTERNAL extern -#define SZ_PUBLIC __attribute__((unused)) inline static -#define SZ_INTERNAL __attribute__((always_inline)) inline static -#endif // _WIN32 || __CYGWIN__ -#else -#define SZ_DYNAMIC inline static -#define SZ_EXTERNAL extern -#define SZ_PUBLIC inline static -#define SZ_INTERNAL inline static -#endif // SZ_DYNAMIC_DISPATCH -#endif // SZ_DYNAMIC - -/** - * @brief Alignment macro for 64-byte alignment. - */ -#if defined(_MSC_VER) -#define SZ_ALIGN64 __declspec(align(64)) -#elif defined(__GNUC__) || defined(__clang__) -#define SZ_ALIGN64 __attribute__((aligned(64))) -#else -#define SZ_ALIGN64 -#endif - -#ifdef __cplusplus -extern "C" { -#endif - -/* - * Let's infer the integer types or pull them from LibC, - * if that is allowed by the user. - */ -#if !SZ_AVOID_LIBC -#include // `size_t` -#include // `uint8_t` -typedef int8_t sz_i8_t; // Always 8 bits -typedef uint8_t sz_u8_t; // Always 8 bits -typedef uint16_t sz_u16_t; // Always 16 bits -typedef int32_t sz_i32_t; // Always 32 bits -typedef uint32_t sz_u32_t; // Always 32 bits -typedef uint64_t sz_u64_t; // Always 64 bits -typedef int64_t sz_i64_t; // Always 64 bits -typedef size_t sz_size_t; // Pointer-sized unsigned integer, 32 or 64 bits -typedef ptrdiff_t sz_ssize_t; // Signed version of `sz_size_t`, 32 or 64 bits - -#else // if SZ_AVOID_LIBC: - -// ! The C standard doesn't specify the signedness of char. -// ! On x86 char is signed by default while on Arm it is unsigned by default. -// ! 
That's why we don't define `sz_char_t` and generally use explicit `sz_i8_t` and `sz_u8_t`. -typedef signed char sz_i8_t; // Always 8 bits -typedef unsigned char sz_u8_t; // Always 8 bits -typedef unsigned short sz_u16_t; // Always 16 bits -typedef int sz_i32_t; // Always 32 bits -typedef unsigned int sz_u32_t; // Always 32 bits -typedef long long sz_i64_t; // Always 64 bits -typedef unsigned long long sz_u64_t; // Always 64 bits - -// Now we need to redefine the `size_t`. -// Microsoft Visual C++ (MSVC) typically follows LLP64 data model on 64-bit platforms, -// where integers, pointers, and long types have different sizes: -// -// > `int` is 32 bits -// > `long` is 32 bits -// > `long long` is 64 bits -// > pointer (thus, `size_t`) is 64 bits -// -// In contrast, GCC and Clang on 64-bit Unix-like systems typically follow the LP64 model, where: -// -// > `int` is 32 bits -// > `long` and pointer (thus, `size_t`) are 64 bits -// > `long long` is also 64 bits -// -// Source: https://learn.microsoft.com/en-us/windows/win32/winprog64/abstract-data-models -#if SZ_DETECT_64_BIT -typedef unsigned long long sz_size_t; // 64-bit. -typedef long long sz_ssize_t; // 64-bit. -#else -typedef unsigned sz_size_t; // 32-bit. -typedef unsigned sz_ssize_t; // 32-bit. -#endif // SZ_DETECT_64_BIT - -#endif // SZ_AVOID_LIBC - -/** - * @brief Compile-time assert macro similar to `static_assert` in C++. - */ -#define sz_static_assert(condition, name) \ - typedef struct { \ - int static_assert_##name : (condition) ? 1 : -1; \ - } sz_static_assert_##name##_t - -sz_static_assert(sizeof(sz_size_t) == sizeof(void *), sz_size_t_must_be_pointer_size); -sz_static_assert(sizeof(sz_ssize_t) == sizeof(void *), sz_ssize_t_must_be_pointer_size); - -#pragma region Public API - -typedef char *sz_ptr_t; // A type alias for `char *` -typedef char const *sz_cptr_t; // A type alias for `char const *` -typedef sz_i8_t sz_error_cost_t; // Character mismatch cost for fuzzy matching functions - -typedef sz_u64_t sz_sorted_idx_t; // Index of a sorted string in a list of strings - -typedef enum { sz_false_k = 0, sz_true_k = 1 } sz_bool_t; // Only one relevant bit -typedef enum { sz_less_k = -1, sz_equal_k = 0, sz_greater_k = 1 } sz_ordering_t; // Only three possible states: <=> - -/** - * @brief Tiny string-view structure. It's POD type, unlike the `std::string_view`. - */ -typedef struct sz_string_view_t { - sz_cptr_t start; - sz_size_t length; -} sz_string_view_t; - -/** - * @brief Enumeration of SIMD capabilities of the target architecture. - * Used to introspect the supported functionality of the dynamic library. - */ -typedef enum sz_capability_t { - sz_cap_serial_k = 1, /// Serial (non-SIMD) capability - sz_cap_any_k = 0x7FFFFFFF, /// Mask representing any capability - - sz_cap_arm_neon_k = 1 << 10, /// ARM NEON capability - sz_cap_arm_sve_k = 1 << 11, /// ARM SVE capability TODO: Not yet supported or used - sz_cap_arm_sve2_k = 1 << 12, - sz_cap_arm_sve2p1_k = 1 << 13, - sz_cap_x86_avx2_k = 1 << 20, /// x86 AVX2 capability - sz_cap_x86_avx512f_k = 1 << 21, /// x86 AVX512 F capability - sz_cap_x86_avx512bw_k = 1 << 22, /// x86 AVX512 BW instruction capability - sz_cap_x86_avx512vl_k = 1 << 23, /// x86 AVX512 VL instruction capability - sz_cap_x86_avx512vbmi_k = 1 << 24, /// x86 AVX512 VBMI instruction capability - sz_cap_x86_gfni_k = 1 << 25, /// x86 AVX512 GFNI instruction capability - -} sz_capability_t; - -/** - * @brief Function to determine the SIMD capabilities of the current machine @b only at @b runtime. 
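Because `sz_capability_t` is a plain bitmask, callers can probe the runtime feature set returned by `sz_capabilities()` with bitwise tests. A usage sketch; the include path is an assumption, and only the enumerators and function declared here are relied upon:

#include <stdio.h>
#include <stringzilla/stringzilla.h> // assumed include path for the declarations above

int main(void) {
    sz_capability_t caps = sz_capabilities(); // runtime mask, one bit per feature
    printf("serial: %d\n", (caps & sz_cap_serial_k) != 0);
    printf("AVX2:   %d\n", (caps & sz_cap_x86_avx2_k) != 0);
    printf("AVX512: %d\n", (caps & sz_cap_x86_avx512bw_k) != 0);
    printf("NEON:   %d\n", (caps & sz_cap_arm_neon_k) != 0);
    return 0;
}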
- * @return A bitmask of the SIMD capabilities represented as a `sz_capability_t` enum value. - */ -SZ_DYNAMIC sz_capability_t sz_capabilities(void); - -/** - * @brief Bit-set structure for 256 possible byte values. Useful for filtering and search. - * @see sz_charset_init, sz_charset_add, sz_charset_contains, sz_charset_invert - */ -typedef union sz_charset_t { - sz_u64_t _u64s[4]; - sz_u32_t _u32s[8]; - sz_u16_t _u16s[16]; - sz_u8_t _u8s[32]; -} sz_charset_t; - -/** @brief Initializes a bit-set to an empty collection, meaning - all characters are banned. */ -SZ_PUBLIC void sz_charset_init(sz_charset_t *s) { s->_u64s[0] = s->_u64s[1] = s->_u64s[2] = s->_u64s[3] = 0; } - -/** @brief Adds a character to the set and accepts @b unsigned integers. */ -SZ_PUBLIC void sz_charset_add_u8(sz_charset_t *s, sz_u8_t c) { s->_u64s[c >> 6] |= (1ull << (c & 63u)); } - -/** @brief Adds a character to the set. Consider @b sz_charset_add_u8. */ -SZ_PUBLIC void sz_charset_add(sz_charset_t *s, char c) { sz_charset_add_u8(s, *(sz_u8_t *)(&c)); } // bitcast - -/** @brief Checks if the set contains a given character and accepts @b unsigned integers. */ -SZ_PUBLIC sz_bool_t sz_charset_contains_u8(sz_charset_t const *s, sz_u8_t c) { - // Checking the bit can be done in different ways: - // - (s->_u64s[c >> 6] & (1ull << (c & 63u))) != 0 - // - (s->_u32s[c >> 5] & (1u << (c & 31u))) != 0 - // - (s->_u16s[c >> 4] & (1u << (c & 15u))) != 0 - // - (s->_u8s[c >> 3] & (1u << (c & 7u))) != 0 - return (sz_bool_t)((s->_u64s[c >> 6] & (1ull << (c & 63u))) != 0); -} - -/** @brief Checks if the set contains a given character. Consider @b sz_charset_contains_u8. */ -SZ_PUBLIC sz_bool_t sz_charset_contains(sz_charset_t const *s, char c) { - return sz_charset_contains_u8(s, *(sz_u8_t *)(&c)); // bitcast -} - -/** @brief Inverts the contents of the set, so allowed character get disallowed, and vice versa. */ -SZ_PUBLIC void sz_charset_invert(sz_charset_t *s) { - s->_u64s[0] ^= 0xFFFFFFFFFFFFFFFFull, s->_u64s[1] ^= 0xFFFFFFFFFFFFFFFFull, // - s->_u64s[2] ^= 0xFFFFFFFFFFFFFFFFull, s->_u64s[3] ^= 0xFFFFFFFFFFFFFFFFull; -} - -typedef void *(*sz_memory_allocate_t)(sz_size_t, void *); -typedef void (*sz_memory_free_t)(void *, sz_size_t, void *); -typedef sz_u64_t (*sz_random_generator_t)(void *); - -/** - * @brief Some complex pattern matching algorithms may require memory allocations. - * This structure is used to pass the memory allocator to those functions. - * @see sz_memory_allocator_init_fixed - */ -typedef struct sz_memory_allocator_t { - sz_memory_allocate_t allocate; - sz_memory_free_t free; - void *handle; -} sz_memory_allocator_t; - -/** - * @brief Initializes a memory allocator to use the system default `malloc` and `free`. - * ! The function is not available if the library was compiled with `SZ_AVOID_LIBC`. - * - * @param alloc Memory allocator to initialize. - */ -SZ_PUBLIC void sz_memory_allocator_init_default(sz_memory_allocator_t *alloc); - -/** - * @brief Initializes a memory allocator to use a static-capacity buffer. - * No dynamic allocations will be performed. - * - * @param alloc Memory allocator to initialize. - * @param buffer Buffer to use for allocations. - * @param length Length of the buffer. @b Must be greater than 8 bytes. Different values would be optimal for - * different algorithms and input lengths, but 4096 bytes (one RAM page) is a good default. 
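The charset helpers defined above are enough to build byte filters without touching the raw 256-bit words directly. A usage sketch, assuming the same include path as before, that builds the six-character whitespace set later mentioned in the search API docs:

#include <stdio.h>
#include <stringzilla/stringzilla.h> // assumed include path for the definitions above

int main(void) {
    sz_charset_t whitespace;
    sz_charset_init(&whitespace); // start with every byte banned
    char const *spaces = " \t\n\r\v\f";
    for (char const *c = spaces; *c; ++c) sz_charset_add(&whitespace, *c);

    printf("' ' in set: %d\n", sz_charset_contains(&whitespace, ' ')); // 1
    printf("'x' in set: %d\n", sz_charset_contains(&whitespace, 'x')); // 0

    sz_charset_invert(&whitespace); // now every non-whitespace byte matches
    printf("'x' after invert: %d\n", sz_charset_contains(&whitespace, 'x')); // 1
    return 0;
}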
- */ -SZ_PUBLIC void sz_memory_allocator_init_fixed(sz_memory_allocator_t *alloc, void *buffer, sz_size_t length); - -/** - * @brief The number of bytes a stack-allocated string can hold, including the SZ_NULL termination character. - * ! This can't be changed from outside. Don't use the `#error` as it may already be included and set. - */ -#ifdef SZ_STRING_INTERNAL_SPACE -#undef SZ_STRING_INTERNAL_SPACE -#endif -#define SZ_STRING_INTERNAL_SPACE (sizeof(sz_size_t) * 3 - 1) // 3 pointers minus one byte for an 8-bit length - -/** - * @brief Tiny memory-owning string structure with a Small String Optimization (SSO). - * Differs in layout from Folly, Clang, GCC, and probably most other implementations. - * It's designed to avoid any branches on read-only operations, and can store up - * to 22 characters on stack on 64-bit machines, followed by the SZ_NULL-termination character. - * - * @section Changing Length - * - * One nice thing about this design, is that you can, in many cases, change the length of the string - * without any branches, invoking a `+=` or `-=` on the 64-bit `length` field. If the string is on heap, - * the solution is obvious. If it's on stack, inplace decrement wouldn't affect the top bytes of the string, - * only changing the last byte containing the length. - */ -typedef union sz_string_t { - -#if !SZ_DETECT_BIG_ENDIAN - - struct external { - sz_ptr_t start; - sz_size_t length; - sz_size_t space; - sz_size_t padding; - } external; - - struct internal { - sz_ptr_t start; - sz_u8_t length; - char chars[SZ_STRING_INTERNAL_SPACE]; - } internal; - -#else - - struct external { - sz_ptr_t start; - sz_size_t space; - sz_size_t padding; - sz_size_t length; - } external; - - struct internal { - sz_ptr_t start; - char chars[SZ_STRING_INTERNAL_SPACE]; - sz_u8_t length; - } internal; - -#endif - - sz_size_t words[4]; - -} sz_string_t; - -typedef sz_u64_t (*sz_hash_t)(sz_cptr_t, sz_size_t); -typedef sz_u64_t (*sz_checksum_t)(sz_cptr_t, sz_size_t); -typedef sz_bool_t (*sz_equal_t)(sz_cptr_t, sz_cptr_t, sz_size_t); -typedef sz_ordering_t (*sz_order_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); -typedef void (*sz_to_converter_t)(sz_cptr_t, sz_size_t, sz_ptr_t); - -/** - * @brief Computes the 64-bit check-sum of bytes in a string. - * Similar to `std::ranges::accumulate`. - * - * @param text String to aggregate. - * @param length Number of bytes in the text. - * @return 64-bit unsigned value. - */ -SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length); - -/** @copydoc sz_checksum */ -SZ_PUBLIC sz_u64_t sz_checksum_serial(sz_cptr_t text, sz_size_t length); - -/** - * @brief Computes the 64-bit unsigned hash of a string. Fairly fast for short strings, - * simple implementation, and supports rolling computation, reused in other APIs. - * Similar to `std::hash` in C++. - * - * @param text String to hash. - * @param length Number of bytes in the text. - * @return 64-bit hash value. - * - * @see sz_hashes, sz_hashes_fingerprint, sz_hashes_intersection - */ -SZ_PUBLIC sz_u64_t sz_hash(sz_cptr_t text, sz_size_t length); - -/** @copydoc sz_hash */ -SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t text, sz_size_t length); - -/** - * @brief Checks if two string are equal. - * Similar to `memcmp(a, b, length) == 0` in LibC and `a == b` in STL. - * - * The implementation of this function is very similar to `sz_order`, but the usage patterns are different. - * This function is more often used in parsing, while `sz_order` is often used in sorting. 
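The SSO layout above packs three pointer-sized words into the union and spends one byte on the on-stack length. The arithmetic behind the "22 characters on 64-bit machines" claim, assuming 8-byte pointers:

#include <stddef.h>
#include <stdio.h>

int main(void) {
    // Mirrors SZ_STRING_INTERNAL_SPACE: three pointer-sized words minus one byte for the 8-bit length.
    size_t word = sizeof(void *);                 // 8 on 64-bit targets
    size_t internal_space = word * 3 - 1;         // 23 bytes of character storage
    size_t longest_on_stack = internal_space - 1; // 22 characters plus the NULL terminator
    printf("internal space: %zu bytes, longest stack-resident string: %zu characters\n", //
           internal_space, longest_on_stack);
    return 0;
}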
- * It works best on platforms with cheap - * - * @param a First string to compare. - * @param b Second string to compare. - * @param length Number of bytes in both strings. - * @return 1 if strings match, 0 otherwise. - */ -SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length); - -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_serial(sz_cptr_t a, sz_cptr_t b, sz_size_t length); - -/** - * @brief Estimates the relative order of two strings. Equivalent to `memcmp(a, b, length)` in LibC. - * Can be used on different length strings. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * @return Negative if (a < b), positive if (a > b), zero if they are equal. - */ -SZ_DYNAMIC sz_ordering_t sz_order(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); - -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); - -/** - * @brief Look Up Table @b (LUT) transformation of a string. Equivalent to `for (char & c : text) c = lut[c]`. - * - * Can be used to implement some form of string normalization, partially masking punctuation marks, - * or converting between different character sets, like uppercase or lowercase. Surprisingly, also has - * broad implications in image processing, where image channel transformations are often done using LUTs. - * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param lut Look Up Table to apply. Must be exactly @b 256 bytes long. - * @param result Output string, can point to the same address as ::text. - */ -SZ_DYNAMIC void sz_look_up_transform(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result); - -typedef void (*sz_look_up_transform_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_ptr_t); - -/** @copydoc sz_look_up_transform */ -SZ_PUBLIC void sz_look_up_transform_serial(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result); - -/** - * @brief Equivalent to `for (char & c : text) c = tolower(c)`. - * - * ASCII characters [A, Z] map to decimals [65, 90], and [a, z] map to [97, 122]. - * So there are 26 english letters, shifted by 32 values, meaning that a conversion - * can be done by flipping the 5th bit each inappropriate character byte. This, however, - * breaks for extended ASCII, so a different solution is needed. - * http://0x80.pl/notesen/2016-01-06-swar-swap-case.html - * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param result Output string, can point to the same address as ::text. - */ -SZ_PUBLIC void sz_tolower(sz_cptr_t text, sz_size_t length, sz_ptr_t result); - -/** - * @brief Equivalent to `for (char & c : text) c = toupper(c)`. * - * ASCII characters [A, Z] map to decimals [65, 90], and [a, z] map to [97, 122]. - * So there are 26 english letters, shifted by 32 values, meaning that a conversion - * can be done by flipping the 5th bit each inappropriate character byte. This, however, - * breaks for extended ASCII, so a different solution is needed. - * http://0x80.pl/notesen/2016-01-06-swar-swap-case.html - * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param result Output string, can point to the same address as ::text. 
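Since `sz_look_up_transform` is documented as `for (char & c : text) c = lut[c]`, case conversion is just a particular 256-byte table. A serial sketch of building such a table for ASCII lowercasing and applying it, matching the semantics described above rather than any SIMD backend:

#include <stdio.h>
#include <string.h>

int main(void) {
    // Identity table, except the 26 uppercase ASCII letters get bit 0x20 set.
    unsigned char lut[256];
    for (int b = 0; b < 256; ++b) lut[b] = (unsigned char)b;
    for (int b = 'A'; b <= 'Z'; ++b) lut[b] = (unsigned char)(b | 0x20);

    char text[] = "StringZilla, FAST strings!";
    size_t length = strlen(text);
    for (size_t i = 0; i < length; ++i) text[i] = (char)lut[(unsigned char)text[i]];
    printf("%s\n", text); // stringzilla, fast strings!
    return 0;
}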
- */ -SZ_PUBLIC void sz_toupper(sz_cptr_t text, sz_size_t length, sz_ptr_t result); - -/** - * @brief Equivalent to `for (char & c : text) c = toascii(c)`. - * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param result Output string, can point to the same address as ::text. - */ -SZ_PUBLIC void sz_toascii(sz_cptr_t text, sz_size_t length, sz_ptr_t result); - -/** - * @brief Checks if all characters in the range are valid ASCII characters. - * - * @param text String to be analyzed. - * @param length Number of bytes in the string. - * @return Whether all characters are valid ASCII characters. - */ -SZ_PUBLIC sz_bool_t sz_isascii(sz_cptr_t text, sz_size_t length); - -/** - * @brief Generates a random string for a given alphabet, avoiding integer division and modulo operations. - * Similar to `text[i] = alphabet[rand() % cardinality]`. - * - * The modulo operation is expensive, and should be avoided in performance-critical code. - * We avoid it using small lookup tables and replacing it with a multiplication and shifts, similar to `libdivide`. - * Alternative algorithms would include: - * - Montgomery form: https://en.algorithmica.org/hpc/number-theory/montgomery/ - * - Barret reduction: https://www.nayuki.io/page/barrett-reduction-algorithm - * - Lemire's trick: https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/ - * - * @param alphabet Set of characters to sample from. - * @param cardinality Number of characters to sample from. - * @param text Output string, can point to the same address as ::text. - * @param generate Callback producing random numbers given the generator state. - * @param generator Generator state, can be a pointer to a seed, or a pointer to a random number generator. - */ -SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length, - sz_random_generator_t generate, void *generator); - -/** @copydoc sz_generate */ -SZ_PUBLIC void sz_generate_serial(sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length, - sz_random_generator_t generate, void *generator); - -/** - * @brief Similar to `memcpy`, copies contents of one string into another. - * The behavior is undefined if the strings overlap. - * - * @param target String to copy into. - * @param length Number of bytes to copy. - * @param source String to copy from. - */ -SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length); - -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length); - -/** - * @brief Similar to `memmove`, copies (moves) contents of one string into another. - * Unlike `sz_copy`, allows overlapping strings as arguments. - * - * @param target String to copy into. - * @param length Number of bytes to copy. - * @param source String to copy from. - */ -SZ_DYNAMIC void sz_move(sz_ptr_t target, sz_cptr_t source, sz_size_t length); - -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length); - -typedef void (*sz_move_t)(sz_ptr_t, sz_cptr_t, sz_size_t); - -/** - * @brief Similar to `memset`, fills a string with a given value. - * - * @param target String to fill. - * @param length Number of bytes to fill. - * @param value Value to fill with. 
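The `sz_generate` docs cite Lemire's reduction as the way to avoid `rand() % cardinality`. A toy sketch of that multiply-shift mapping, with a stand-in xorshift generator playing the role of `sz_random_generator_t`; the four-letter alphabet is just an example:

#include <stdint.h>
#include <stdio.h>

// Maps a uniform 32-bit value into [0, cardinality) with one multiply and one shift,
// the alternative to integer modulo referenced in the docs above.
static uint32_t reduce_fast(uint32_t random_value, uint32_t cardinality) {
    return (uint32_t)(((uint64_t)random_value * (uint64_t)cardinality) >> 32);
}

static uint32_t xorshift32(uint32_t *state) { // toy generator for the demo
    uint32_t x = *state;
    x ^= x << 13, x ^= x >> 17, x ^= x << 5;
    return *state = x;
}

int main(void) {
    char const alphabet[] = "ACGT"; // hypothetical 4-letter alphabet
    char text[17] = {0};
    uint32_t state = 42u;
    for (int i = 0; i < 16; ++i) text[i] = alphabet[reduce_fast(xorshift32(&state), 4)];
    printf("%s\n", text);
    return 0;
}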
- */ -SZ_DYNAMIC void sz_fill(sz_ptr_t target, sz_size_t length, sz_u8_t value); - -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_serial(sz_ptr_t target, sz_size_t length, sz_u8_t value); - -typedef void (*sz_fill_t)(sz_ptr_t, sz_size_t, sz_u8_t); - -/** - * @brief Initializes a string class instance to an empty value. - */ -SZ_PUBLIC void sz_string_init(sz_string_t *string); - -/** - * @brief Convenience function checking if the provided string is stored inside of the ::string instance itself, - * alternative being - allocated in a remote region of the heap. - */ -SZ_PUBLIC sz_bool_t sz_string_is_on_stack(sz_string_t const *string); - -/** - * @brief Unpacks the opaque instance of a string class into its components. - * Recommended to use only in read-only operations. - * - * @param string String to unpack. - * @param start Pointer to the start of the string. - * @param length Number of bytes in the string, before the SZ_NULL character. - * @param space Number of bytes allocated for the string (heap or stack), including the SZ_NULL character. - * @param is_external Whether the string is allocated on the heap externally, or fits withing ::string instance. - */ -SZ_PUBLIC void sz_string_unpack(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length, sz_size_t *space, - sz_bool_t *is_external); - -/** - * @brief Unpacks only the start and length of the string. - * Recommended to use only in read-only operations. - * - * @param string String to unpack. - * @param start Pointer to the start of the string. - * @param length Number of bytes in the string, before the SZ_NULL character. - */ -SZ_PUBLIC void sz_string_range(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length); - -/** - * @brief Constructs a string of a given ::length with noisy contents. - * Use the returned character pointer to populate the string. - * - * @param string String to initialize. - * @param length Number of bytes in the string, before the SZ_NULL character. - * @param allocator Memory allocator to use for the allocation. - * @return SZ_NULL if the operation failed, pointer to the start of the string otherwise. - */ -SZ_PUBLIC sz_ptr_t sz_string_init_length(sz_string_t *string, sz_size_t length, sz_memory_allocator_t *allocator); - -/** - * @brief Doesn't change the contents or the length of the string, but grows the available memory capacity. - * This is beneficial, if several insertions are expected, and we want to minimize allocations. - * - * @param string String to grow. - * @param new_capacity The number of characters to reserve space for, including existing ones. - * @param allocator Memory allocator to use for the allocation. - * @return SZ_NULL if the operation failed, pointer to the new start of the string otherwise. - */ -SZ_PUBLIC sz_ptr_t sz_string_reserve(sz_string_t *string, sz_size_t new_capacity, sz_memory_allocator_t *allocator); - -/** - * @brief Grows the string by adding an uninitialized region of ::added_length at the given ::offset. - * Would often be used in conjunction with one or more `sz_copy` calls to populate the allocated region. - * Similar to `sz_string_reserve`, but changes the length of the ::string. - * - * @param string String to grow. - * @param offset Offset of the first byte to reserve space for. - * If provided offset is larger than the length, it will be capped. - * @param added_length The number of new characters to reserve space for. - * @param allocator Memory allocator to use for the allocation. 
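Putting the string lifecycle declared above together, a usage sketch based purely on the documented signatures: default allocator, a string of known length, read-back, and cleanup. The include path is an assumption:

#include <stdio.h>
#include <string.h>
#include <stringzilla/stringzilla.h> // assumed include path for the declarations above

int main(void) {
    sz_memory_allocator_t alloc;
    sz_memory_allocator_init_default(&alloc); // LibC malloc/free, per the docs above

    sz_string_t song;
    sz_ptr_t start = sz_string_init_length(&song, 11, &alloc); // uninitialized 11-byte string
    if (!start) return 1;                                      // allocation may fail
    memcpy(start, "hello world", 11);

    sz_ptr_t begin;
    sz_size_t length;
    sz_string_range(&song, &begin, &length); // read-only view of the contents
    printf("%.*s (%d bytes, on stack: %d)\n", (int)length, begin, (int)length, //
           (int)sz_string_is_on_stack(&song));

    sz_string_free(&song, &alloc); // frees heap storage or resets the stack buffer, per the docs
    return 0;
}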
- * @return SZ_NULL if the operation failed, pointer to the new start of the string otherwise. - */ -SZ_PUBLIC sz_ptr_t sz_string_expand(sz_string_t *string, sz_size_t offset, sz_size_t added_length, - sz_memory_allocator_t *allocator); - -/** - * @brief Removes a range from a string. Changes the length, but not the capacity. - * Performs no allocations or deallocations and can't fail. - * - * @param string String to clean. - * @param offset Offset of the first byte to remove. - * @param length Number of bytes to remove. Out-of-bound ranges will be capped. - * @return Number of bytes removed. - */ -SZ_PUBLIC sz_size_t sz_string_erase(sz_string_t *string, sz_size_t offset, sz_size_t length); - -/** - * @brief Shrinks the string to fit the current length, if it's allocated on the heap. - * It's the reverse operation of ::sz_string_reserve. - * - * @param string String to shrink. - * @param allocator Memory allocator to use for the allocation. - * @return Whether the operation was successful. The only failures can come from the allocator. - * On failure, the string will remain unchanged. - */ -SZ_PUBLIC sz_ptr_t sz_string_shrink_to_fit(sz_string_t *string, sz_memory_allocator_t *allocator); - -/** - * @brief Frees the string, if it's allocated on the heap. - * If the string is on the stack, the function clears/resets the state. - */ -SZ_PUBLIC void sz_string_free(sz_string_t *string, sz_memory_allocator_t *allocator); - -#pragma endregion - -#pragma region Fast Substring Search API - -typedef sz_cptr_t (*sz_find_byte_t)(sz_cptr_t, sz_size_t, sz_cptr_t); -typedef sz_cptr_t (*sz_find_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); -typedef sz_cptr_t (*sz_find_set_t)(sz_cptr_t, sz_size_t, sz_charset_t const *); - -/** - * @brief Locates first matching byte in a string. Equivalent to `memchr(haystack, *needle, h_length)` in LibC. - * - * X86_64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/x86_64/memchr.S - * Aarch64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/aarch64/memchr.S - * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - single-byte substring to find. - * @return Address of the first match. - */ -SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); - -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); - -/** - * @brief Locates last matching byte in a string. Equivalent to `memrchr(haystack, *needle, h_length)` in LibC. - * - * X86_64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/x86_64/memrchr.S - * Aarch64 implementation: missing - * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - single-byte substring to find. - * @return Address of the last match. - */ -SZ_DYNAMIC sz_cptr_t sz_rfind_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); - -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); - -/** - * @brief Locates first matching substring. - * Equivalent to `memmem(haystack, h_length, needle, n_length)` in LibC. - * Similar to `strstr(haystack, needle)` in LibC, but requires known length. - * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. 
- * @param needle Needle - substring to find. - * @param n_length Number of bytes in the needle. - * @return Address of the first match. - */ -SZ_DYNAMIC sz_cptr_t sz_find(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); - -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); - -/** - * @brief Locates the last matching substring. - * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - substring to find. - * @param n_length Number of bytes in the needle. - * @return Address of the last match. - */ -SZ_DYNAMIC sz_cptr_t sz_rfind(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); - -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); - -/** - * @brief Finds the first character present from the ::set, present in ::text. - * Equivalent to `strspn(text, accepted)` and `strcspn(text, rejected)` in LibC. - * May have identical implementation and performance to ::sz_rfind_charset. - * - * Useful for parsing, when we want to skip a set of characters. Examples: - * * 6 whitespaces: " \t\n\r\v\f". - * * 16 digits forming a float number: "0123456789,.eE+-". - * * 5 HTML reserved characters: "\"'&<>", of which "<>" can be useful for parsing. - * * 2 JSON string special characters useful to locate the end of the string: "\"\\". - * - * @param text String to be scanned. - * @param set Set of relevant characters. - * @return Pointer to the first matching character from ::set. - */ -SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); - -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); - -/** - * @brief Finds the last character present from the ::set, present in ::text. - * Equivalent to `strspn(text, accepted)` and `strcspn(text, rejected)` in LibC. - * May have identical implementation and performance to ::sz_find_charset. - * - * Useful for parsing, when we want to skip a set of characters. Examples: - * * 6 whitespaces: " \t\n\r\v\f". - * * 16 digits forming a float number: "0123456789,.eE+-". - * * 5 HTML reserved characters: "\"'&<>", of which "<>" can be useful for parsing. - * * 2 JSON string special characters useful to locate the end of the string: "\"\\". - * - * @param text String to be scanned. - * @param set Set of relevant characters. - * @return Pointer to the last matching character from ::set. - */ -SZ_DYNAMIC sz_cptr_t sz_rfind_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); - -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); - -#pragma endregion - -#pragma region String Similarity Measures API - -/** - * @brief Computes the Hamming distance between two strings - number of not matching characters. - * Difference in length is is counted as a mismatch. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * - * @param bound Upper bound on the distance, that allows us to exit early. - * If zero is passed, the maximum possible distance will be equal to the length of the longer input. 
- * @return Unsigned integer for the distance, the `bound` if was exceeded. - * - * @see sz_hamming_distance_utf8 - * @see https://en.wikipedia.org/wiki/Hamming_distance - */ -SZ_DYNAMIC sz_size_t sz_hamming_distance( // - sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound); - -/** @copydoc sz_hamming_distance */ -SZ_PUBLIC sz_size_t sz_hamming_distance_serial( // - sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound); - -/** - * @brief Computes the Hamming distance between two @b UTF8 strings - number of not matching characters. - * Difference in length is is counted as a mismatch. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * - * @param bound Upper bound on the distance, that allows us to exit early. - * If zero is passed, the maximum possible distance will be equal to the length of the longer input. - * @return Unsigned integer for the distance, the `bound` if was exceeded. - * - * @see sz_hamming_distance - * @see https://en.wikipedia.org/wiki/Hamming_distance - */ -SZ_DYNAMIC sz_size_t sz_hamming_distance_utf8(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, - sz_size_t bound); - -/** @copydoc sz_hamming_distance_utf8 */ -SZ_PUBLIC sz_size_t sz_hamming_distance_utf8_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, - sz_size_t bound); - -typedef sz_size_t (*sz_hamming_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t); - -/** - * @brief Computes the Levenshtein edit-distance between two strings using the Wagner-Fisher algorithm. - * Similar to the Needleman-Wunsch alignment algorithm. Often used in fuzzy string matching. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * - * @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated, - * so the memory usage is linear in relation to ::a_length and ::b_length. - * If SZ_NULL is passed, will initialize to the systems default `malloc`. - * @param bound Exclusive upper bound on the distance, that allows us to exit early. - * Pass `SZ_SIZE_MAX` or any value greater than `(max(a_length, b_length))` to ignore. - * Pass zero to check if the strings are equal. - * @return Unsigned integer for the edit distance. Zero means the strings are equal. - * Returns the `bound` if it was exceeded or `SZ_SIZE_MAX` if the memory allocation failed. - * - * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default - * @see https://en.wikipedia.org/wiki/Levenshtein_distance - */ -SZ_DYNAMIC sz_size_t sz_edit_distance(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); - -/** @copydoc sz_edit_distance */ -SZ_PUBLIC sz_size_t sz_edit_distance_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); - -/** - * @brief Computes the Levenshtein edit-distance between two @b UTF8 strings. - * Unlike `sz_edit_distance`, reports the distance in Unicode codepoints, and not in bytes. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. 
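For readers unfamiliar with Wagner-Fisher, a compact two-row serial sketch shows where the linear memory usage mentioned above comes from; it omits the early-exit `bound`, the pluggable allocator, and the UTF-8 handling of the real API:

#include <stdio.h>
#include <stdlib.h>

static size_t edit_distance(char const *a, size_t len_a, char const *b, size_t len_b) {
    // Only two rows of the (len_a + 1) x (len_b + 1) matrix are kept alive at a time.
    size_t *previous = (size_t *)malloc((len_b + 1) * sizeof(size_t));
    size_t *current = (size_t *)malloc((len_b + 1) * sizeof(size_t));
    if (!previous || !current) { free(previous), free(current); return (size_t)-1; }
    for (size_t j = 0; j <= len_b; ++j) previous[j] = j;
    for (size_t i = 1; i <= len_a; ++i) {
        current[0] = i;
        for (size_t j = 1; j <= len_b; ++j) {
            size_t substitution = previous[j - 1] + (a[i - 1] != b[j - 1]);
            size_t insertion = current[j - 1] + 1;
            size_t deletion = previous[j] + 1;
            size_t best = substitution < insertion ? substitution : insertion;
            current[j] = best < deletion ? best : deletion;
        }
        size_t *swap = previous;
        previous = current, current = swap;
    }
    size_t result = previous[len_b];
    free(previous), free(current);
    return result;
}

int main(void) {
    printf("%zu\n", edit_distance("kitten", 6, "sitting", 7)); // prints 3
    return 0;
}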
- * - * @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated, - * so the memory usage is linear in relation to ::a_length and ::b_length. - * If SZ_NULL is passed, will initialize to the systems default `malloc`. - * @param bound Upper bound on the distance, that allows us to exit early. - * If zero is passed, the maximum possible distance will be equal to the length of the longer input. - * @return Unsigned integer for edit distance, the `bound` if was exceeded or `SZ_SIZE_MAX` - * if the memory allocation failed. - * - * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default, sz_edit_distance - * @see https://en.wikipedia.org/wiki/Levenshtein_distance - */ -SZ_DYNAMIC sz_size_t sz_edit_distance_utf8(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); - -typedef sz_size_t (*sz_edit_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t, sz_memory_allocator_t *); - -/** @copydoc sz_edit_distance_utf8 */ -SZ_PUBLIC sz_size_t sz_edit_distance_utf8_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); - -/** - * @brief Computes Needleman–Wunsch alignment score for two string. Often used in bioinformatics and cheminformatics. - * Similar to the Levenshtein edit-distance, parameterized for gap and substitution penalties. - * - * Not commutative in the general case, as the order of the strings matters, as `sz_alignment_score(a, b)` may - * not be equal to `sz_alignment_score(b, a)`. Becomes @b commutative, if the substitution costs are symmetric. - * Equivalent to the negative Levenshtein distance, if: `gap == -1` and `subs[i][j] == (i == j ? 0: -1)`. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * @param gap Penalty cost for gaps - insertions and removals. - * @param subs Substitution costs matrix with 256 x 256 values for all pairs of characters. - * - * @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated, - * so the memory usage is linear in relation to ::a_length and ::b_length. - * If SZ_NULL is passed, will initialize to the systems default `malloc`. - * @return Signed similarity score. Can be negative, depending on the substitution costs. - * If the memory allocation fails, the function returns `SZ_SSIZE_MAX`. - * - * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default - * @see https://en.wikipedia.org/wiki/Needleman%E2%80%93Wunsch_algorithm - */ -SZ_DYNAMIC sz_ssize_t sz_alignment_score(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, // - sz_memory_allocator_t *alloc); - -/** @copydoc sz_alignment_score */ -SZ_PUBLIC sz_ssize_t sz_alignment_score_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, // - sz_memory_allocator_t *alloc); - -typedef sz_ssize_t (*sz_alignment_score_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_error_cost_t const *, - sz_error_cost_t, sz_memory_allocator_t *); - -typedef void (*sz_hash_callback_t)(sz_cptr_t, sz_size_t, sz_u64_t, void *user); - -/** - * @brief Computes the Karp-Rabin rolling hashes of a string supplying them to the provided `callback`. 
- * Can be used for similarity scores, search, ranking, etc. - * - * Rabin-Karp-like rolling hashes can have very high-level of collisions and depend - * on the choice of bases and the prime number. That's why, often two hashes from the same - * family are used with different bases. - * - * 1. Kernighan and Ritchie's function uses 31, a prime close to the size of English alphabet. - * 2. To be friendlier to byte-arrays and UTF8, we use 257 for the second function. - * - * Choosing the right ::window_length is task- and domain-dependant. For example, most English words are - * between 3 and 7 characters long, so a window of 4 bytes would be a good choice. For DNA sequences, - * the ::window_length might be a multiple of 3, as the codons are 3 (nucleotides) bytes long. - * With such minimalistic alphabets of just four characters (AGCT) longer windows might be needed. - * For protein sequences the alphabet is 20 characters long, so the window can be shorter, than for DNAs. - * - * @param text String to hash. - * @param length Number of bytes in the string. - * @param window_length Length of the rolling window in bytes. - * @param window_step Step of reported hashes. @b Must be power of two. Should be smaller than `window_length`. - * @param callback Function receiving the start & length of a substring, the hash, and the `callback_handle`. - * @param callback_handle Optional user-provided pointer to be passed to the `callback`. - * @see sz_hashes_fingerprint, sz_hashes_intersection - */ -SZ_DYNAMIC void sz_hashes(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // - sz_hash_callback_t callback, void *callback_handle); - -/** @copydoc sz_hashes */ -SZ_PUBLIC void sz_hashes_serial(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // - sz_hash_callback_t callback, void *callback_handle); - -typedef void (*sz_hashes_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_size_t, sz_hash_callback_t, void *); - -/** - * @brief Computes the Karp-Rabin rolling hashes of a string outputting a binary fingerprint. - * Such fingerprints can be compared with Hamming or Jaccard (Tanimoto) distance for similarity. - * - * The algorithm doesn't clear the fingerprint buffer on start, so it can be invoked multiple times - * to produce a fingerprint of a longer string, by passing the previous fingerprint as the ::fingerprint. - * It can also be reused to produce multi-resolution fingerprints by changing the ::window_length - * and calling the same function multiple times for the same input ::text. - * - * Processes large strings in parts to maximize the cache utilization, using a small on-stack buffer, - * avoiding cache-coherency penalties of remote on-heap buffers. - * - * @param text String to hash. - * @param length Number of bytes in the string. - * @param fingerprint Output fingerprint buffer. - * @param fingerprint_bytes Number of bytes in the fingerprint buffer. - * @param window_length Length of the rolling window in bytes. - * @see sz_hashes, sz_hashes_intersection - */ -SZ_PUBLIC void sz_hashes_fingerprint( // - sz_cptr_t text, sz_size_t length, sz_size_t window_length, // - sz_ptr_t fingerprint, sz_size_t fingerprint_bytes); - -typedef void (*sz_hashes_fingerprint_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_ptr_t, sz_size_t); - -/** - * @brief Given a hash-fingerprint of a textual document, computes the number of intersecting hashes - * of the incoming document. Can be used for document scoring and search. 
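To make the rolling-hash description above concrete, here is a toy single-base Karp-Rabin window using base 31, showing the constant-time slide; the real `sz_hashes` pairs two bases and a prime modulus, which this sketch does not reproduce:

#include <stdint.h>
#include <stdio.h>
#include <string.h>

int main(void) {
    char const *text = "the quick brown fox";
    size_t length = strlen(text);
    size_t window = 4;  // a word-sized window, as suggested for English text above
    uint64_t base = 31; // the K&R base mentioned above

    // base^(window - 1), used to remove the outgoing byte in constant time.
    uint64_t leading_power = 1;
    for (size_t i = 0; i + 1 < window; ++i) leading_power *= base;

    // Hash of the first window.
    uint64_t hash = 0;
    for (size_t i = 0; i < window; ++i) hash = hash * base + (uint8_t)text[i];
    printf("'%.4s' -> %llu\n", text, (unsigned long long)hash);

    // Slide one byte at a time: drop the oldest byte, append the newest one.
    for (size_t i = window; i < length; ++i) {
        hash = (hash - (uint8_t)text[i - window] * leading_power) * base + (uint8_t)text[i];
        printf("'%.4s' -> %llu\n", text + i - window + 1, (unsigned long long)hash);
    }
    return 0;
}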
- * - * Processes large strings in parts to maximize the cache utilization, using a small on-stack buffer, - * avoiding cache-coherency penalties of remote on-heap buffers. - * - * @param text Input document. - * @param length Number of bytes in the input document. - * @param fingerprint Reference document fingerprint. - * @param fingerprint_bytes Number of bytes in the reference documents fingerprint. - * @param window_length Length of the rolling window in bytes. - * @see sz_hashes, sz_hashes_fingerprint - */ -SZ_PUBLIC sz_size_t sz_hashes_intersection( // - sz_cptr_t text, sz_size_t length, sz_size_t window_length, // - sz_cptr_t fingerprint, sz_size_t fingerprint_bytes); - -typedef sz_size_t (*sz_hashes_intersection_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_cptr_t, sz_size_t); - -#pragma endregion - -#pragma region Convenience API - -/** - * @brief Finds the first character in the haystack, that is present in the needle. - * Convenience function, reused across different language bindings. - * @see sz_find_charset - */ -SZ_DYNAMIC sz_cptr_t sz_find_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); - -/** - * @brief Finds the first character in the haystack, that is @b not present in the needle. - * Convenience function, reused across different language bindings. - * @see sz_find_charset - */ -SZ_DYNAMIC sz_cptr_t sz_find_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); - -/** - * @brief Finds the last character in the haystack, that is present in the needle. - * Convenience function, reused across different language bindings. - * @see sz_find_charset - */ -SZ_DYNAMIC sz_cptr_t sz_rfind_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); - -/** - * @brief Finds the last character in the haystack, that is @b not present in the needle. - * Convenience function, reused across different language bindings. - * @see sz_find_charset - */ -SZ_DYNAMIC sz_cptr_t sz_rfind_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); - -#pragma endregion - -#pragma region String Sequences API - -struct sz_sequence_t; - -typedef sz_cptr_t (*sz_sequence_member_start_t)(struct sz_sequence_t const *, sz_size_t); -typedef sz_size_t (*sz_sequence_member_length_t)(struct sz_sequence_t const *, sz_size_t); -typedef sz_bool_t (*sz_sequence_predicate_t)(struct sz_sequence_t const *, sz_size_t); -typedef sz_bool_t (*sz_sequence_comparator_t)(struct sz_sequence_t const *, sz_size_t, sz_size_t); -typedef sz_bool_t (*sz_string_is_less_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); - -typedef struct sz_sequence_t { - sz_sorted_idx_t *order; - sz_size_t count; - sz_sequence_member_start_t get_start; - sz_sequence_member_length_t get_length; - void const *handle; -} sz_sequence_t; - -/** - * @brief Initiates the sequence structure from a tape layout, used by Apache Arrow. - * Expects ::offsets to contains `count + 1` entries, the last pointing at the end - * of the last string, indicating the total length of the ::tape. - */ -SZ_PUBLIC void sz_sequence_from_u32tape(sz_cptr_t *start, sz_u32_t const *offsets, sz_size_t count, - sz_sequence_t *sequence); - -/** - * @brief Initiates the sequence structure from a tape layout, used by Apache Arrow. - * Expects ::offsets to contains `count + 1` entries, the last pointing at the end - * of the last string, indicating the total length of the ::tape. 
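The tape layout expected by `sz_sequence_from_u32tape` is easiest to see on a tiny example: `count + 1` offsets over one concatenated buffer, with each string's start and length derived by subtraction, which is effectively what the `get_start` / `get_length` callbacks of `sz_sequence_t` provide:

#include <stdint.h>
#include <stdio.h>

int main(void) {
    // Three strings packed into one tape, plus count + 1 = 4 offsets,
    // the last one marking the end of the final string, as required above.
    char const tape[] = "catdoghorse";
    uint32_t const offsets[] = {0, 3, 6, 11};
    size_t count = 3;

    for (size_t i = 0; i < count; ++i) {
        char const *start = tape + offsets[i];        // what a start-accessor yields
        uint32_t bytes = offsets[i + 1] - offsets[i]; // what a length-accessor yields
        printf("%zu: %.*s (%u bytes)\n", i, (int)bytes, start, (unsigned)bytes);
    }
    return 0;
}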
- */
-SZ_PUBLIC void sz_sequence_from_u64tape(sz_cptr_t *start, sz_u64_t const *offsets, sz_size_t count,
-                                        sz_sequence_t *sequence);
-
-/**
- * @brief Similar to `std::partition`: given a predicate, splits the sequence into two parts.
- *        The algorithm is unstable, meaning that elements may change relative order, as long
- *        as they are in the right partition. This is the simpler algorithm for partitioning.
- */
-SZ_PUBLIC sz_size_t sz_partition(sz_sequence_t *sequence, sz_sequence_predicate_t predicate);
-
-/**
- * @brief In-place `std::set_union` for two consecutive chunks forming the same continuous `sequence`.
- *
- * @param partition The number of elements in the first sub-sequence in `sequence`.
- * @param less Comparison function to determine the lexicographic ordering.
- */
-SZ_PUBLIC void sz_merge(sz_sequence_t *sequence, sz_size_t partition, sz_sequence_comparator_t less);
-
-/**
- * @brief Sorting algorithm, combining Radix Sort for the first 32 bits of every word
- *        with a follow-up by a more conventional sorting procedure on equally prefixed parts.
- */
-SZ_PUBLIC void sz_sort(sz_sequence_t *sequence);
-
-/**
- * @brief Partial sorting algorithm, combining Radix Sort for the first 32 bits of every word
- *        with a follow-up by a more conventional sorting procedure on equally prefixed parts.
- */
-SZ_PUBLIC void sz_sort_partial(sz_sequence_t *sequence, sz_size_t n);
-
-/**
- * @brief Intro-Sort algorithm that supports custom comparators.
- */
-SZ_PUBLIC void sz_sort_intro(sz_sequence_t *sequence, sz_sequence_comparator_t less);
-
-#pragma endregion
-
-/*
- * Hardware feature detection.
- * All of those can be controlled by the user.
- */
-#ifndef SZ_USE_X86_AVX512
-#ifdef __AVX512BW__
-#define SZ_USE_X86_AVX512 1
-#else
-#define SZ_USE_X86_AVX512 0
-#endif
-#endif
-
-#ifndef SZ_USE_X86_AVX2
-#ifdef __AVX2__
-#define SZ_USE_X86_AVX2 1
-#else
-#define SZ_USE_X86_AVX2 0
-#endif
-#endif
-
-#ifndef SZ_USE_ARM_NEON
-#ifdef __ARM_NEON
-#define SZ_USE_ARM_NEON 1
-#else
-#define SZ_USE_ARM_NEON 0
-#endif
-#endif
-
-#ifndef SZ_USE_ARM_SVE
-#ifdef __ARM_FEATURE_SVE
-#define SZ_USE_ARM_SVE 1
-#else
-#define SZ_USE_ARM_SVE 0
-#endif
-#endif
-
-/*
- * Include hardware-specific headers.
- */
-#if SZ_USE_X86_AVX512 || SZ_USE_X86_AVX2
-#include <immintrin.h>
-#endif // SZ_USE_X86...
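To make the rolling-hash API documented above easier to picture, here is a minimal, hedged usage sketch in C. It assumes the `sz_hash_callback_t` signature matches the internal fingerprint callbacks defined later in this header (substring start, length, hash, user handle); the `count_ngram_hashes` and `example_count_hashes` names are hypothetical, introduced only for this illustration.

    // Counts how many rolling hashes are reported for 4-byte windows, stepping one byte at a time,
    // mirroring the `window_length` guidance above for English text. The step of 1 is a power of two,
    // as the `window_step` documentation requires.
    static void count_ngram_hashes(sz_cptr_t start, sz_size_t length, sz_u64_t hash, void *handle) {
        sz_size_t *counter = (sz_size_t *)handle;
        sz_unused(start && length && hash);
        ++*counter;
    }

    static sz_size_t example_count_hashes(sz_cptr_t text, sz_size_t length) {
        sz_size_t count = 0;
        sz_hashes(text, length, /* window_length */ 4, /* window_step */ 1, &count_ngram_hashes, &count);
        return count;
    }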
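A second hedged sketch shows how the String Sequences API above might be wired around a plain array of NUL-terminated C strings. The `demo_get_start` / `demo_get_length` helpers are hypothetical, and the assumption that `order` starts as the identity permutation and is rearranged by `sz_sort` is an illustration of the intent, not a statement of the library's exact contract.

    #include <string.h> // `strlen`, only needed for this sketch

    static sz_cptr_t demo_get_start(struct sz_sequence_t const *sequence, sz_size_t i) {
        return ((char const *const *)sequence->handle)[i];
    }
    static sz_size_t demo_get_length(struct sz_sequence_t const *sequence, sz_size_t i) {
        return strlen(((char const *const *)sequence->handle)[i]);
    }

    static void example_sort_three_strings(void) {
        char const *strings[] = {"banana", "apple", "cherry"};
        sz_sorted_idx_t order[] = {0, 1, 2}; // identity permutation, to be reordered
        sz_sequence_t sequence;
        sequence.order = order;
        sequence.count = 3;
        sequence.get_start = demo_get_start;
        sequence.get_length = demo_get_length;
        sequence.handle = strings;
        sz_sort(&sequence); // `order` should now rank the strings lexicographically
    }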
-#if SZ_USE_ARM_NEON
-#if !defined(_MSC_VER)
-#include <arm_acle.h>
-#endif
-#include <arm_neon.h>
-#endif // SZ_USE_ARM_NEON
-#if SZ_USE_ARM_SVE
-#if !defined(_MSC_VER)
-#include <arm_sve.h>
-#endif
-#endif // SZ_USE_ARM_SVE
-
-#pragma region Hardware Specific API
-
-#if SZ_USE_X86_AVX512
-
-/** @copydoc sz_equal */
-SZ_PUBLIC sz_bool_t sz_equal_avx512(sz_cptr_t a, sz_cptr_t b, sz_size_t length);
-/** @copydoc sz_order */
-SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length);
-/** @copydoc sz_copy */
-SZ_PUBLIC void sz_copy_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
-/** @copydoc sz_move */
-SZ_PUBLIC void sz_move_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
-/** @copydoc sz_fill */
-SZ_PUBLIC void sz_fill_avx512(sz_ptr_t target, sz_size_t length, sz_u8_t value);
-/** @copydoc sz_look_up_transform */
-SZ_PUBLIC void sz_look_up_transform_avx512(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target);
-/** @copydoc sz_find_byte */
-SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);
-/** @copydoc sz_rfind_byte */
-SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);
-/** @copydoc sz_find */
-SZ_PUBLIC sz_cptr_t sz_find_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length);
-/** @copydoc sz_rfind */
-SZ_PUBLIC sz_cptr_t sz_rfind_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length);
-/** @copydoc sz_find_charset */
-SZ_PUBLIC sz_cptr_t sz_find_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *set);
-/** @copydoc sz_rfind_charset */
-SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *set);
-/** @copydoc sz_edit_distance */
-SZ_PUBLIC sz_size_t sz_edit_distance_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, //
-                                            sz_size_t bound, sz_memory_allocator_t *alloc);
-/** @copydoc sz_alignment_score */
-SZ_PUBLIC sz_ssize_t sz_alignment_score_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, //
-                                               sz_error_cost_t const *subs, sz_error_cost_t gap,               //
-                                               sz_memory_allocator_t *alloc);
-/** @copydoc sz_hashes */
-SZ_PUBLIC void sz_hashes_avx512(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t step, //
-                                sz_hash_callback_t callback, void *callback_handle);
-#endif
-
-#if SZ_USE_X86_AVX2
-/** @copydoc sz_equal */
-SZ_PUBLIC sz_bool_t sz_equal_avx2(sz_cptr_t a, sz_cptr_t b, sz_size_t length);
-/** @copydoc sz_order */
-SZ_PUBLIC sz_ordering_t sz_order_avx2(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length);
-/** @copydoc sz_copy */
-SZ_PUBLIC void sz_copy_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
-/** @copydoc sz_move */
-SZ_PUBLIC void sz_move_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
-/** @copydoc sz_fill */
-SZ_PUBLIC void sz_fill_avx2(sz_ptr_t target, sz_size_t length, sz_u8_t value);
-/** @copydoc sz_look_up_transform */
-SZ_PUBLIC void sz_look_up_transform_avx2(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target);
-/** @copydoc sz_find_byte */
-SZ_PUBLIC sz_cptr_t sz_find_byte_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);
-/** @copydoc sz_rfind_byte */
-SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);
-/** @copydoc sz_find */
-SZ_PUBLIC sz_cptr_t sz_find_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t
n_length); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_hashes */ -SZ_PUBLIC void sz_hashes_avx2(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle); -#endif - -#if SZ_USE_ARM_NEON -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length); -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_neon(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_neon(sz_ptr_t target, sz_size_t length, sz_u8_t value); -/** @copydoc sz_look_up_transform */ -SZ_PUBLIC void sz_look_up_transform_neon(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target); -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_neon(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -#endif - -#if SZ_USE_ARM_SVE -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_sve(sz_cptr_t a, sz_cptr_t b, sz_size_t length); -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_sve(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_sve(sz_ptr_t target, sz_size_t length, sz_u8_t value); -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_sve(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_sve(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -#endif - -#pragma endregion - -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wconversion" - -/* - ********************************************************************************************************************** - 
 **********************************************************************************************************************
- **********************************************************************************************************************
- *
- * This is where the actual implementation begins.
- * The rest of the file is hidden from the public API.
- *
- **********************************************************************************************************************
- **********************************************************************************************************************
- **********************************************************************************************************************
- */
-
-#pragma region Compiler Extensions and Helper Functions
-
-#pragma GCC visibility push(hidden)
-
-/**
- * @brief Helper-macro to mark potentially unused variables.
- */
-#define sz_unused(x) ((void)(x))
-
-/**
- * @brief Helper-macro casting a variable to another type of the same size.
- */
-#define sz_bitcast(type, value) (*((type *)&(value)))
-
-/**
- * @brief Defines `SZ_NULL`, analogous to `NULL`.
- *        The default often comes from locale.h, stddef.h,
- *        stdio.h, stdlib.h, string.h, time.h, or wchar.h.
- */
-#ifdef __GNUG__
-#define SZ_NULL __null
-#define SZ_NULL_CHAR __null
-#else
-#define SZ_NULL ((void *)0)
-#define SZ_NULL_CHAR ((char *)0)
-#endif
-
-/**
- * @brief Cache-line width that will affect the execution of some algorithms,
- *        like equality checks and relative order computing.
- */
-#define SZ_CACHE_LINE_WIDTH (64) // bytes
-
-/**
- * @brief Similar to `assert`, the `sz_assert` is used in the SZ_DEBUG mode
- *        to check the invariants of the library. It's a no-op in the SZ_RELEASE mode.
- * @note  If you want to catch it, put a breakpoint at @b `__GI_exit`
- */
-#if SZ_DEBUG && defined(SZ_AVOID_LIBC) && !SZ_AVOID_LIBC && !defined(SZ_PIC)
-#include <stdio.h>  // `fprintf`
-#include <stdlib.h> // `EXIT_FAILURE`
-SZ_PUBLIC void _sz_assert_failure(char const *condition, char const *file, int line) {
-    fprintf(stderr, "Assertion failed: %s, in file %s, line %d\n", condition, file, line);
-    exit(EXIT_FAILURE);
-}
-#define sz_assert(condition)                                                      \
-    do {                                                                          \
-        if (!(condition)) { _sz_assert_failure(#condition, __FILE__, __LINE__); } \
-    } while (0)
-#else
-#define sz_assert(condition) ((void)(condition))
-#endif
-
-/* Intrinsics aliases for MSVC, GCC, Clang, and Clang-Cl.
- * The following section of compiler intrinsics comes in 2 flavors.
- */
-#if defined(_MSC_VER) && !defined(__clang__) // On MSVC, but not Clang-CL
-#include <intrin.h>
-
-// Sadly, when building Win32 images, we can't use the `_tzcnt_u64`, `_lzcnt_u64`,
-// `_BitScanForward64`, or `_BitScanReverse64` intrinsics. For now it's a simple `for`-loop.
-// TODO: In the future we can switch to a more efficient De Bruijn's algorithm.
-// https://www.chessprogramming.org/BitScan
-// https://www.chessprogramming.org/De_Bruijn_Sequence
-// https://gist.github.com/resilar/e722d4600dbec9752771ab4c9d47044f
-//
-// Use the serial version on 32-bit x86 and on Arm.
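The TODO above points at a branch-free alternative for these fallbacks. A hedged sketch of that De Bruijn-sequence bit-scan, not part of the library, might look as follows; the lookup table is built programmatically so correctness only relies on the De Bruijn property of the constant, and the helper names are hypothetical.

    // Maps the isolated lowest set bit (a power of two) to its index via one multiply and a table lookup.
    // 0x03F79D71B4CB0A89 is a 64-bit De Bruijn sequence with six leading zero bits, so the top six bits
    // of `debruijn << i` are distinct for every i in [0, 64).
    static int _sz_debruijn_ctz_table[64];
    static void _sz_debruijn_ctz_table_init(void) { // call once before the first lookup
        sz_u64_t const debruijn = 0x03F79D71B4CB0A89ull;
        for (int i = 0; i != 64; ++i) _sz_debruijn_ctz_table[(debruijn << i) >> 58] = i;
    }
    static int _sz_u64_ctz_debruijn(sz_u64_t x) { // like the intrinsics, undefined for x == 0
        sz_u64_t const debruijn = 0x03F79D71B4CB0A89ull;
        sz_u64_t const lowest_bit = x & ~(x - 1); // isolate the least significant set bit
        return _sz_debruijn_ctz_table[(lowest_bit * debruijn) >> 58];
    }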
-#if (defined(_WIN32) && !defined(_WIN64)) || defined(_M_ARM) || defined(_M_ARM64) -SZ_INTERNAL int sz_u64_ctz(sz_u64_t x) { - sz_assert(x != 0); - int n = 0; - while ((x & 1) == 0) { n++, x >>= 1; } - return n; -} -SZ_INTERNAL int sz_u64_clz(sz_u64_t x) { - sz_assert(x != 0); - int n = 0; - while ((x & 0x8000000000000000ull) == 0) { n++, x <<= 1; } - return n; -} -SZ_INTERNAL int sz_u64_popcount(sz_u64_t x) { - x = x - ((x >> 1) & 0x5555555555555555ull); - x = (x & 0x3333333333333333ull) + ((x >> 2) & 0x3333333333333333ull); - return (((x + (x >> 4)) & 0x0F0F0F0F0F0F0F0Full) * 0x0101010101010101ull) >> 56; -} -SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) { - sz_assert(x != 0); - int n = 0; - while ((x & 1) == 0) { n++, x >>= 1; } - return n; -} -SZ_INTERNAL int sz_u32_clz(sz_u32_t x) { - sz_assert(x != 0); - int n = 0; - while ((x & 0x80000000u) == 0) { n++, x <<= 1; } - return n; -} -SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) { - x = x - ((x >> 1) & 0x55555555); - x = (x & 0x33333333) + ((x >> 2) & 0x33333333); - return (((x + (x >> 4)) & 0x0F0F0F0F) * 0x01010101) >> 24; -} -#else -SZ_INTERNAL int sz_u64_ctz(sz_u64_t x) { return (int)_tzcnt_u64(x); } -SZ_INTERNAL int sz_u64_clz(sz_u64_t x) { return (int)_lzcnt_u64(x); } -SZ_INTERNAL int sz_u64_popcount(sz_u64_t x) { return (int)__popcnt64(x); } -SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) { return (int)_tzcnt_u32(x); } -SZ_INTERNAL int sz_u32_clz(sz_u32_t x) { return (int)_lzcnt_u32(x); } -SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) { return (int)__popcnt(x); } -#endif -// Force the byteswap functions to be intrinsics, because when /Oi- is given, these will turn into CRT function calls, -// which breaks when `SZ_AVOID_LIBC` is given -#pragma intrinsic(_byteswap_uint64) -SZ_INTERNAL sz_u64_t sz_u64_bytes_reverse(sz_u64_t val) { return _byteswap_uint64(val); } -#pragma intrinsic(_byteswap_ulong) -SZ_INTERNAL sz_u32_t sz_u32_bytes_reverse(sz_u32_t val) { return _byteswap_ulong(val); } -#else -SZ_INTERNAL int sz_u64_popcount(sz_u64_t x) { return __builtin_popcountll(x); } -SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) { return __builtin_popcount(x); } -SZ_INTERNAL int sz_u64_ctz(sz_u64_t x) { return __builtin_ctzll(x); } -SZ_INTERNAL int sz_u64_clz(sz_u64_t x) { return __builtin_clzll(x); } -SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) { return __builtin_ctz(x); } // ! Undefined if `x == 0` -SZ_INTERNAL int sz_u32_clz(sz_u32_t x) { return __builtin_clz(x); } // ! Undefined if `x == 0` -SZ_INTERNAL sz_u64_t sz_u64_bytes_reverse(sz_u64_t val) { return __builtin_bswap64(val); } -SZ_INTERNAL sz_u32_t sz_u32_bytes_reverse(sz_u32_t val) { return __builtin_bswap32(val); } -#endif - -SZ_INTERNAL sz_u64_t sz_u64_rotl(sz_u64_t x, sz_u64_t r) { return (x << r) | (x >> (64 - r)); } - -/** - * @brief Select bits from either ::a or ::b depending on the value of ::mask bits. - * - * Similar to `_mm_blend_epi16` intrinsic on x86. - * Described in the "Bit Twiddling Hacks" by Sean Eron Anderson. - * https://graphics.stanford.edu/~seander/bithacks.html#ConditionalSetOrClearBitsWithoutBranching - */ -SZ_INTERNAL sz_u64_t sz_u64_blend(sz_u64_t a, sz_u64_t b, sz_u64_t mask) { return a ^ ((a ^ b) & mask); } - -/* - * Efficiently computing the minimum and maximum of two or three values can be tricky. - * The simple branching baseline would be: - * - * x < y ? x : y // can replace with 1 conditional move - * - * Branchless approach is well known for signed integers, but it doesn't apply to unsigned ones. 
- * https://stackoverflow.com/questions/514435/templatized-branchless-int-max-min-function
- * https://graphics.stanford.edu/~seander/bithacks.html#IntegerMinOrMax
- * Using only bit-shifts for signed integers it would be:
- *
- *      y + ((x - y) & (x - y) >> 31) // 4 unique operations
- *
- * Alternatively, for any integers using multiplication:
- *
- *      (x > y) * y + (x <= y) * x // 5 operations
- *
- * Alternatively, to avoid multiplication:
- *
- *      x & ~((x < y) - 1) + y & ((x < y) - 1) // 6 unique operations
- */
-#define sz_min_of_two(x, y) (x < y ? x : y)
-#define sz_max_of_two(x, y) (x < y ? y : x)
-#define sz_min_of_three(x, y, z) sz_min_of_two(x, sz_min_of_two(y, z))
-#define sz_max_of_three(x, y, z) sz_max_of_two(x, sz_max_of_two(y, z))
-
-/** @brief Branchless minimum function for two signed 32-bit integers. */
-SZ_INTERNAL sz_i32_t sz_i32_min_of_two(sz_i32_t x, sz_i32_t y) { return y + ((x - y) & (x - y) >> 31); }
-
-/** @brief Branchless maximum function for two signed 32-bit integers. */
-SZ_INTERNAL sz_i32_t sz_i32_max_of_two(sz_i32_t x, sz_i32_t y) { return x - ((x - y) & (x - y) >> 31); }
-
-/**
- * @brief Clamps signed offsets in a string to a valid range. Used for Pythonic-style slicing.
- */
-SZ_INTERNAL void sz_ssize_clamp_interval(sz_size_t length, sz_ssize_t start, sz_ssize_t end,
-                                         sz_size_t *normalized_offset, sz_size_t *normalized_length) {
-    // TODO: Remove branches.
-    // Normalize negative indices
-    if (start < 0) start += length;
-    if (end < 0) end += length;
-
-    // Clamp indices to a valid range
-    if (start < 0) start = 0;
-    if (end < 0) end = 0;
-    if (start > (sz_ssize_t)length) start = length;
-    if (end > (sz_ssize_t)length) end = length;
-
-    // Ensure start <= end
-    if (start > end) start = end;
-
-    *normalized_offset = start;
-    *normalized_length = end - start;
-}
-
-/**
- * @brief Compute the logarithm base 2 of a positive integer, rounding down.
- */
-SZ_INTERNAL sz_size_t sz_size_log2i_nonzero(sz_size_t x) {
-    sz_assert(x > 0 && "Non-positive numbers have no defined logarithm");
-    sz_size_t leading_zeros = sz_u64_clz(x);
-    return 63 - leading_zeros;
-}
-
-/**
- * @brief Compute the smallest power of two greater than or equal to ::x.
- */
-SZ_INTERNAL sz_size_t sz_size_bit_ceil(sz_size_t x) {
-    // Unlike the commonly used trick with `clz` intrinsics, this is valid across the whole range of `x`.
-    // https://stackoverflow.com/a/10143264
-    x--;
-    x |= x >> 1;
-    x |= x >> 2;
-    x |= x >> 4;
-    x |= x >> 8;
-    x |= x >> 16;
-#if SZ_DETECT_64_BIT
-    x |= x >> 32;
-#endif
-    x++;
-    return x;
-}
-
-/**
- * @brief Transposes an 8x8 bit matrix packed in a `sz_u64_t`.
- *
- *        There is a well-known SWAR sequence for that, familiar to chess programmers
- *        who need to flip a bit-matrix of pieces along the main A1-H8 diagonal.
- *        https://www.chessprogramming.org/Flipping_Mirroring_and_Rotating
- *        https://lukas-prokop.at/articles/2021-07-23-transpose
- */
-SZ_INTERNAL sz_u64_t sz_u64_transpose(sz_u64_t x) {
-    sz_u64_t t;
-    t = x ^ (x << 36);
-    x ^= 0xf0f0f0f00f0f0f0full & (t ^ (x >> 36));
-    t = 0xcccc0000cccc0000ull & (x ^ (x << 18));
-    x ^= t ^ (t >> 18);
-    t = 0xaa00aa00aa00aa00ull & (x ^ (x << 9));
-    x ^= t ^ (t >> 9);
-    return x;
-}
-
-/**
- * @brief Helper that swaps two 64-bit integers representing the order of elements in the sequence.
- */
-SZ_INTERNAL void sz_u64_swap(sz_u64_t *a, sz_u64_t *b) {
-    sz_u64_t t = *a;
-    *a = *b;
-    *b = t;
-}
-
-/**
- * @brief Helper that swaps two pointers.
- */ -SZ_INTERNAL void sz_pointer_swap(void **a, void **b) { - void *t = *a; - *a = *b; - *b = t; -} - -/** - * @brief Helper structure to simplify work with 16-bit words. - * @see sz_u16_load - */ -typedef union sz_u16_vec_t { - sz_u16_t u16; - sz_u8_t u8s[2]; -} sz_u16_vec_t; - -/** - * @brief Load a 16-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms. - */ -SZ_INTERNAL sz_u16_vec_t sz_u16_load(sz_cptr_t ptr) { -#if !SZ_USE_MISALIGNED_LOADS - sz_u16_vec_t result; - result.u8s[0] = ptr[0]; - result.u8s[1] = ptr[1]; - return result; -#elif defined(_MSC_VER) && !defined(__clang__) -#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. - return *((sz_u16_vec_t *)ptr); -#else - return *((__unaligned sz_u16_vec_t *)ptr); -#endif -#else - __attribute__((aligned(1))) sz_u16_vec_t const *result = (sz_u16_vec_t const *)ptr; - return *result; -#endif -} - -/** - * @brief Helper structure to simplify work with 32-bit words. - * @see sz_u32_load - */ -typedef union sz_u32_vec_t { - sz_u32_t u32; - sz_u16_t u16s[2]; - sz_u8_t u8s[4]; -} sz_u32_vec_t; - -/** - * @brief Load a 32-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms. - */ -SZ_INTERNAL sz_u32_vec_t sz_u32_load(sz_cptr_t ptr) { -#if !SZ_USE_MISALIGNED_LOADS - sz_u32_vec_t result; - result.u8s[0] = ptr[0]; - result.u8s[1] = ptr[1]; - result.u8s[2] = ptr[2]; - result.u8s[3] = ptr[3]; - return result; -#elif defined(_MSC_VER) && !defined(__clang__) -#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. - return *((sz_u32_vec_t *)ptr); -#else - return *((__unaligned sz_u32_vec_t *)ptr); -#endif -#else - __attribute__((aligned(1))) sz_u32_vec_t const *result = (sz_u32_vec_t const *)ptr; - return *result; -#endif -} - -/** - * @brief Helper structure to simplify work with 64-bit words. - * @see sz_u64_load - */ -typedef union sz_u64_vec_t { - sz_u64_t u64; - sz_u32_t u32s[2]; - sz_u16_t u16s[4]; - sz_u8_t u8s[8]; -} sz_u64_vec_t; - -/** - * @brief Load a 64-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms. - */ -SZ_INTERNAL sz_u64_vec_t sz_u64_load(sz_cptr_t ptr) { -#if !SZ_USE_MISALIGNED_LOADS - sz_u64_vec_t result; - result.u8s[0] = ptr[0]; - result.u8s[1] = ptr[1]; - result.u8s[2] = ptr[2]; - result.u8s[3] = ptr[3]; - result.u8s[4] = ptr[4]; - result.u8s[5] = ptr[5]; - result.u8s[6] = ptr[6]; - result.u8s[7] = ptr[7]; - return result; -#elif defined(_MSC_VER) && !defined(__clang__) -#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. - return *((sz_u64_vec_t *)ptr); -#else - return *((__unaligned sz_u64_vec_t *)ptr); -#endif -#else - __attribute__((aligned(1))) sz_u64_vec_t const *result = (sz_u64_vec_t const *)ptr; - return *result; -#endif -} - -/** @brief Helper function, using the supplied fixed-capacity buffer to allocate memory. */ -SZ_INTERNAL sz_ptr_t _sz_memory_allocate_fixed(sz_size_t length, void *handle) { - sz_size_t capacity; - sz_copy((sz_ptr_t)&capacity, (sz_cptr_t)handle, sizeof(sz_size_t)); - sz_size_t consumed_capacity = sizeof(sz_size_t); - if (consumed_capacity + length > capacity) return SZ_NULL_CHAR; - return (sz_ptr_t)handle + consumed_capacity; -} - -/** @brief Helper "no-op" function, simulating memory deallocation when we use a "static" memory buffer. 
*/ -SZ_INTERNAL void _sz_memory_free_fixed(sz_ptr_t start, sz_size_t length, void *handle) { - sz_unused(start && length && handle); -} - -/** @brief An internal callback used to set a bit in a power-of-two length binary fingerprint of a string. */ -SZ_INTERNAL void _sz_hashes_fingerprint_pow2_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, void *handle) { - sz_string_view_t *fingerprint_buffer = (sz_string_view_t *)handle; - sz_u8_t *fingerprint_u8s = (sz_u8_t *)fingerprint_buffer->start; - sz_size_t fingerprint_bytes = fingerprint_buffer->length; - fingerprint_u8s[(hash / 8) & (fingerprint_bytes - 1)] |= (1 << (hash & 7)); - sz_unused(start && length); -} - -/** @brief An internal callback used to set a bit in a @b non power-of-two length binary fingerprint of a string. */ -SZ_INTERNAL void _sz_hashes_fingerprint_non_pow2_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, - void *handle) { - sz_string_view_t *fingerprint_buffer = (sz_string_view_t *)handle; - sz_u8_t *fingerprint_u8s = (sz_u8_t *)fingerprint_buffer->start; - sz_size_t fingerprint_bytes = fingerprint_buffer->length; - fingerprint_u8s[(hash / 8) % fingerprint_bytes] |= (1 << (hash & 7)); - sz_unused(start && length); -} - -/** @brief An internal callback, used to mix all the running hashes into one pointer-size value. */ -SZ_INTERNAL void _sz_hashes_fingerprint_scalar_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, - void *scalar_handle) { - sz_unused(start && length && hash && scalar_handle); - sz_size_t *scalar_ptr = (sz_size_t *)scalar_handle; - *scalar_ptr ^= hash; -} - -/** - * @brief Chooses the offsets of the most interesting characters in a search needle. - * - * Search throughput can significantly deteriorate if we are matching the wrong characters. - * Say the needle is "aXaYa", and we are comparing the first, second, and last character. - * If we use SIMD and compare many offsets at a time, comparing against "a" in every register is a waste. - * - * Similarly, dealing with UTF8 inputs, we know that the lower bits of each character code carry more information. - * Cyrillic alphabet, for example, falls into [0x0410, 0x042F] code range for uppercase [А, Я], and - * into [0x0430, 0x044F] for lowercase [а, я]. Scanning through a text written in Russian, half of the - * bytes will carry absolutely no value and will be equal to 0x04. - */ -SZ_INTERNAL void _sz_locate_needle_anomalies(sz_cptr_t start, sz_size_t length, // - sz_size_t *first, sz_size_t *second, sz_size_t *third) { - *first = 0; - *second = length / 2; - *third = length - 1; - - // - int has_duplicates = // - start[*first] == start[*second] || // - start[*first] == start[*third] || // - start[*second] == start[*third]; - - // Loop through letters to find non-colliding variants. - if (length > 3 && has_duplicates) { - // Pivot the middle point right, until we find a character different from the first one. - for (; start[*second] == start[*first] && *second + 1 < *third; ++(*second)) {} - // Pivot the third (last) point left, until we find a different character. - for (; (start[*third] == start[*second] || start[*third] == start[*first]) && *third > (*second + 1); - --(*third)) {} - } - - // TODO: Investigate alternative strategies for long needles. - // On very long needles we have the luxury to choose! 
- // Often dealing with UTF8, we will likely benefit from shifting the first and second characters
- // further to the right, to achieve not only uniqueness within the needle, but also avoid common
- // rune prefixes of 2-, 3-, and 4-byte codes.
- if (length > 8) {
-     // Pivot the first and second points right, until we find a character that:
-     // > is different from others.
-     // > doesn't start with 0b'110x'xxxx - only 5 bits of relevant info.
-     // > doesn't start with 0b'1110'xxxx - only 4 bits of relevant info.
-     // > doesn't start with 0b'1111'0xxx - only 3 bits of relevant info.
-     //
-     // So we are practically searching for byte values that start with 0b0xxx'xxxx or 0b'10xx'xxxx.
-     // Meaning they fall in the range [0, 127] and [128, 191], in other words any unsigned int up to 191.
-     sz_u8_t const *start_u8 = (sz_u8_t const *)start;
-     sz_size_t vibrant_first = *first, vibrant_second = *second, vibrant_third = *third;
-
-     // Let's begin with the second character, as the termination criterion there is more obvious
-     // and we may end up with more variants to check for the first candidate.
-     for (; (start_u8[vibrant_second] > 191 || start_u8[vibrant_second] == start_u8[vibrant_third]) &&
-            (vibrant_second + 1 < vibrant_third);
-          ++vibrant_second) {}
-
-     // Now check if we've indeed found a good candidate or should revert the `vibrant_second` to `second`.
-     if (start_u8[vibrant_second] < 191) { *second = vibrant_second; }
-     else { vibrant_second = *second; }
-
-     // Now check the first character.
-     for (; (start_u8[vibrant_first] > 191 || start_u8[vibrant_first] == start_u8[vibrant_second] ||
-             start_u8[vibrant_first] == start_u8[vibrant_third]) &&
-            (vibrant_first + 1 < vibrant_second);
-          ++vibrant_first) {}
-
-     // Now check if we've indeed found a good candidate or should revert the `vibrant_first` to `first`.
-     // We don't need to shift the third one when dealing with texts as the last byte of the text is
-     // also the last byte of a rune and contains the most information.
-     if (start_u8[vibrant_first] < 191) { *first = vibrant_first; }
- }
-}
-
-#pragma GCC visibility pop
-#pragma endregion
-
-#pragma region Serial Implementation
-
-#if !SZ_AVOID_LIBC
-#include <stdio.h>  // `fprintf`
-#include <stdlib.h> // `malloc`, `EXIT_FAILURE`
-
-SZ_PUBLIC void *_sz_memory_allocate_default(sz_size_t length, void *handle) {
-    sz_unused(handle);
-    return malloc(length);
-}
-SZ_PUBLIC void _sz_memory_free_default(sz_ptr_t start, sz_size_t length, void *handle) {
-    sz_unused(handle && length);
-    free(start);
-}
-
-#endif
-
-SZ_PUBLIC void sz_memory_allocator_init_default(sz_memory_allocator_t *alloc) {
-#if !SZ_AVOID_LIBC
-    alloc->allocate = (sz_memory_allocate_t)_sz_memory_allocate_default;
-    alloc->free = (sz_memory_free_t)_sz_memory_free_default;
-#else
-    alloc->allocate = (sz_memory_allocate_t)SZ_NULL;
-    alloc->free = (sz_memory_free_t)SZ_NULL;
-#endif
-    alloc->handle = SZ_NULL;
-}
-
-SZ_PUBLIC void sz_memory_allocator_init_fixed(sz_memory_allocator_t *alloc, void *buffer, sz_size_t length) {
-    // The logic here is simple - put the buffer length in the first slots of the buffer.
-    // Later use it for bounds checking.
-    alloc->allocate = (sz_memory_allocate_t)_sz_memory_allocate_fixed;
-    alloc->free = (sz_memory_free_t)_sz_memory_free_fixed;
-    alloc->handle = &buffer;
-    sz_copy((sz_ptr_t)buffer, (sz_cptr_t)&length, sizeof(sz_size_t));
-}
-
-/**
- * @brief Byte-level equality comparison between two strings.
- *        If unaligned loads are allowed, uses a switch-table to avoid loops on short strings.
- */ -SZ_PUBLIC sz_bool_t sz_equal_serial(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - sz_cptr_t const a_end = a + length; -#if SZ_USE_MISALIGNED_LOADS - if (length >= SZ_SWAR_THRESHOLD) { - sz_u64_vec_t a_vec, b_vec; - for (; a + 8 <= a_end; a += 8, b += 8) { - a_vec = sz_u64_load(a); - b_vec = sz_u64_load(b); - if (a_vec.u64 != b_vec.u64) return sz_false_k; - } - } -#endif - while (a != a_end && *a == *b) a++, b++; - return (sz_bool_t)(a_end == a); -} - -SZ_PUBLIC sz_cptr_t sz_find_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { - for (sz_cptr_t const end = text + length; text != end; ++text) - if (sz_charset_contains(set, *text)) return text; - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Warray-bounds" - sz_cptr_t const end = text; - for (text += length; text != end;) - if (sz_charset_contains(set, *(text -= 1))) return text; - return SZ_NULL_CHAR; -#pragma GCC diagnostic pop -} - -/** - * One option to avoid branching is to use conditional moves and lookup the comparison result in a table: - * sz_ordering_t ordering_lookup[2] = {sz_greater_k, sz_less_k}; - * for (; a != min_end; ++a, ++b) - * if (*a != *b) return ordering_lookup[*a < *b]; - * That, however, introduces a data-dependency. - * A cleaner option is to perform two comparisons and a subtraction. - * One instruction more, but no data-dependency. - */ -#define _sz_order_scalars(a, b) ((sz_ordering_t)((a > b) - (a < b))) - -SZ_PUBLIC sz_ordering_t sz_order_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { - sz_bool_t a_shorter = (sz_bool_t)(a_length < b_length); - sz_size_t min_length = a_shorter ? a_length : b_length; - sz_cptr_t min_end = a + min_length; -#if SZ_USE_MISALIGNED_LOADS && !SZ_DETECT_BIG_ENDIAN - for (sz_u64_vec_t a_vec, b_vec; a + 8 <= min_end; a += 8, b += 8) { - a_vec = sz_u64_load(a); - b_vec = sz_u64_load(b); - if (a_vec.u64 != b_vec.u64) - return _sz_order_scalars(sz_u64_bytes_reverse(a_vec.u64), sz_u64_bytes_reverse(b_vec.u64)); - } -#endif - for (; a != min_end; ++a, ++b) - if (*a != *b) return _sz_order_scalars(*a, *b); - - // If the strings are equal up to `min_end`, then the shorter string is smaller - return _sz_order_scalars(a_length, b_length); -} - -/** - * @brief Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each byte is set. - // For that take the bottom 7 bits of each byte, add one to them, - // and if this sets the top bit to one, then all the 7 bits are ones as well. - vec.u64 = ((vec.u64 & 0x7F7F7F7F7F7F7F7Full) + 0x0101010101010101ull) & ((vec.u64 & 0x8080808080808080ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b single-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. - * Identical to `memchr(haystack, needle[0], haystack_length)`. - */ -SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - if (!h_length) return SZ_NULL_CHAR; - sz_cptr_t const h_end = h + h_length; - -#if !SZ_DETECT_BIG_ENDIAN // Use SWAR only on little-endian platforms for brevety. 
-#if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)h & 7ull) && h < h_end; ++h) - if (*h == *n) return h; -#endif - - // Broadcast the n into every byte of a 64-bit integer to use SWAR - // techniques and process eight characters at a time. - sz_u64_vec_t h_vec, n_vec, match_vec; - match_vec.u64 = 0; - n_vec.u64 = (sz_u64_t)n[0] * 0x0101010101010101ull; - for (; h + 8 <= h_end; h += 8) { - h_vec.u64 = *(sz_u64_t const *)h; - match_vec = _sz_u64_each_byte_equal(h_vec, n_vec); - if (match_vec.u64) return h + sz_u64_ctz(match_vec.u64) / 8; - } -#endif - - // Handle the misaligned tail. - for (; h < h_end; ++h) - if (*h == *n) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief Find the last occurrence of a @b single-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. - * Identical to `memrchr(haystack, needle[0], haystack_length)`. - */ -sz_cptr_t sz_rfind_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - if (!h_length) return SZ_NULL_CHAR; - sz_cptr_t const h_start = h; - - // Reposition the `h` pointer to the end, as we will be walking backwards. - h = h + h_length - 1; - -#if !SZ_DETECT_BIG_ENDIAN // Use SWAR only on little-endian platforms for brevety. -#if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)(h + 1) & 7ull) && h >= h_start; --h) - if (*h == *n) return h; -#endif - - // Broadcast the n into every byte of a 64-bit integer to use SWAR - // techniques and process eight characters at a time. - sz_u64_vec_t h_vec, n_vec, match_vec; - n_vec.u64 = (sz_u64_t)n[0] * 0x0101010101010101ull; - for (; h >= h_start + 7; h -= 8) { - h_vec.u64 = *(sz_u64_t const *)(h - 7); - match_vec = _sz_u64_each_byte_equal(h_vec, n_vec); - if (match_vec.u64) return h - sz_u64_clz(match_vec.u64) / 8; - } -#endif - - for (; h >= h_start; --h) - if (*h == *n) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief 2Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each 2byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_2byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each 2byte is set. - // For that take the bottom 15 bits of each 2byte, add one to them, - // and if this sets the top bit to one, then all the 15 bits are ones as well. - vec.u64 = ((vec.u64 & 0x7FFF7FFF7FFF7FFFull) + 0x0001000100010001ull) & ((vec.u64 & 0x8000800080008000ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b two-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. - */ -SZ_INTERNAL sz_cptr_t _sz_find_2byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - // This is an internal method, and the haystack is guaranteed to be at least 2 bytes long. - sz_assert(h_length >= 2 && "The haystack is too short."); - sz_cptr_t const h_end = h + h_length; - -#if !SZ_USE_MISALIGNED_LOADS - // Process the misaligned head, to void UB on unaligned 64-bit loads. 
- for (; ((sz_size_t)h & 7ull) && h + 2 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) == 2) return h; -#endif - - sz_u64_vec_t h_even_vec, h_odd_vec, n_vec, matches_even_vec, matches_odd_vec; - n_vec.u64 = 0; - n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1]; - n_vec.u64 *= 0x0001000100010001ull; // broadcast - - // This code simulates hyper-scalar execution, analyzing 8 offsets at a time. - for (; h + 9 <= h_end; h += 8) { - h_even_vec.u64 = *(sz_u64_t *)h; - h_odd_vec.u64 = (h_even_vec.u64 >> 8) | ((sz_u64_t)h[8] << 56); - matches_even_vec = _sz_u64_each_2byte_equal(h_even_vec, n_vec); - matches_odd_vec = _sz_u64_each_2byte_equal(h_odd_vec, n_vec); - - matches_even_vec.u64 >>= 8; - if (matches_even_vec.u64 + matches_odd_vec.u64) { - sz_u64_t match_indicators = matches_even_vec.u64 | matches_odd_vec.u64; - return h + sz_u64_ctz(match_indicators) / 8; - } - } - - for (; h + 2 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) == 2) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief 4Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each 4byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_4byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each 4byte is set. - // For that take the bottom 31 bits of each 4byte, add one to them, - // and if this sets the top bit to one, then all the 31 bits are ones as well. - vec.u64 = ((vec.u64 & 0x7FFFFFFF7FFFFFFFull) + 0x0000000100000001ull) & ((vec.u64 & 0x8000000080000000ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b four-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. - */ -SZ_INTERNAL sz_cptr_t _sz_find_4byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - // This is an internal method, and the haystack is guaranteed to be at least 4 bytes long. - sz_assert(h_length >= 4 && "The haystack is too short."); - sz_cptr_t const h_end = h + h_length; - -#if !SZ_USE_MISALIGNED_LOADS - // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)h & 7ull) && h + 4 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) + (h[3] == n[3]) == 4) return h; -#endif - - sz_u64_vec_t h0_vec, h1_vec, h2_vec, h3_vec, n_vec, matches0_vec, matches1_vec, matches2_vec, matches3_vec; - n_vec.u64 = 0; - n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1], n_vec.u8s[2] = n[2], n_vec.u8s[3] = n[3]; - n_vec.u64 *= 0x0000000100000001ull; // broadcast - - // This code simulates hyper-scalar execution, analyzing 8 offsets at a time using four 64-bit words. - // We load the subsequent four-byte word as well, taking its first bytes. 
Think of it as a glorified prefetch :)
-    sz_u64_t h_page_current, h_page_next;
-    for (; h + sizeof(sz_u64_t) + sizeof(sz_u32_t) <= h_end; h += sizeof(sz_u64_t)) {
-        h_page_current = *(sz_u64_t *)h;
-        h_page_next = *(sz_u32_t *)(h + 8);
-        h0_vec.u64 = (h_page_current);
-        h1_vec.u64 = (h_page_current >> 8) | (h_page_next << 56);
-        h2_vec.u64 = (h_page_current >> 16) | (h_page_next << 48);
-        h3_vec.u64 = (h_page_current >> 24) | (h_page_next << 40);
-        matches0_vec = _sz_u64_each_4byte_equal(h0_vec, n_vec);
-        matches1_vec = _sz_u64_each_4byte_equal(h1_vec, n_vec);
-        matches2_vec = _sz_u64_each_4byte_equal(h2_vec, n_vec);
-        matches3_vec = _sz_u64_each_4byte_equal(h3_vec, n_vec);
-
-        if (matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64) {
-            matches0_vec.u64 >>= 24;
-            matches1_vec.u64 >>= 16;
-            matches2_vec.u64 >>= 8;
-            sz_u64_t match_indicators = matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64;
-            return h + sz_u64_ctz(match_indicators) / 8;
-        }
-    }
-
-    for (; h + 4 <= h_end; ++h)
-        if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) + (h[3] == n[3]) == 4) return h;
-    return SZ_NULL_CHAR;
-}
-
-/**
- * @brief 3Byte-level equality comparison between two 64-bit integers.
- * @return 64-bit integer, where every top bit in each 3byte signifies a match.
- */
-SZ_INTERNAL sz_u64_vec_t _sz_u64_each_3byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) {
-    sz_u64_vec_t vec;
-    vec.u64 = ~(a.u64 ^ b.u64);
-    // The match is valid, if every bit within each 3-byte chunk is set.
-    // For that take the bottom 23 bits of each chunk, add one to them,
-    // and if this sets the top bit to one, then all the 23 bits are ones as well.
-    vec.u64 = ((vec.u64 & 0xFFFF7FFFFF7FFFFFull) + 0x0000000001000001ull) & ((vec.u64 & 0x0000800000800000ull));
-    return vec;
-}
-
-/**
- * @brief Find the first occurrence of a @b three-character needle in an arbitrary length haystack.
- *        This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time.
- */
-SZ_INTERNAL sz_cptr_t _sz_find_3byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {
-
-    // This is an internal method, and the haystack is guaranteed to be at least 3 bytes long.
-    sz_assert(h_length >= 3 && "The haystack is too short.");
-    sz_cptr_t const h_end = h + h_length;
-
-#if !SZ_USE_MISALIGNED_LOADS
-    // Process the misaligned head, to avoid UB on unaligned 64-bit loads.
-    for (; ((sz_size_t)h & 7ull) && h + 3 <= h_end; ++h)
-        if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) == 3) return h;
-#endif
-
-    // We will fetch the next 8-byte word, plus the following 2-byte word.
-    sz_u64_vec_t h0_vec, h1_vec, h2_vec, h3_vec, h4_vec;
-    sz_u64_vec_t matches0_vec, matches1_vec, matches2_vec, matches3_vec, matches4_vec;
-    sz_u64_vec_t n_vec;
-    n_vec.u64 = 0;
-    n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1], n_vec.u8s[2] = n[2];
-    n_vec.u64 *= 0x0000000001000001ull; // broadcast
-
-    // This code simulates hyper-scalar execution, analyzing 8 offsets at a time using five 64-bit words.
-    // We load the subsequent two-byte word as well.
- sz_u64_t h_page_current, h_page_next; - for (; h + sizeof(sz_u64_t) + sizeof(sz_u16_t) <= h_end; h += sizeof(sz_u64_t)) { - h_page_current = *(sz_u64_t *)h; - h_page_next = *(sz_u16_t *)(h + 8); - h0_vec.u64 = (h_page_current); - h1_vec.u64 = (h_page_current >> 8) | (h_page_next << 56); - h2_vec.u64 = (h_page_current >> 16) | (h_page_next << 48); - h3_vec.u64 = (h_page_current >> 24) | (h_page_next << 40); - h4_vec.u64 = (h_page_current >> 32) | (h_page_next << 32); - matches0_vec = _sz_u64_each_3byte_equal(h0_vec, n_vec); - matches1_vec = _sz_u64_each_3byte_equal(h1_vec, n_vec); - matches2_vec = _sz_u64_each_3byte_equal(h2_vec, n_vec); - matches3_vec = _sz_u64_each_3byte_equal(h3_vec, n_vec); - matches4_vec = _sz_u64_each_3byte_equal(h4_vec, n_vec); - - if (matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64 | matches4_vec.u64) { - matches0_vec.u64 >>= 16; - matches1_vec.u64 >>= 8; - matches3_vec.u64 <<= 8; - matches4_vec.u64 <<= 16; - sz_u64_t match_indicators = - matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64 | matches4_vec.u64; - return h + sz_u64_ctz(match_indicators) / 8; - } - } - - for (; h + 3 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) == 3) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief Boyer-Moore-Horspool algorithm for exact matching of patterns up to @b 256-bytes long. - * Uses the Raita heuristic to match the first two, the last, and the middle character of the pattern. - */ -SZ_INTERNAL sz_cptr_t _sz_find_horspool_upto_256bytes_serial(sz_cptr_t h_chars, sz_size_t h_length, // - sz_cptr_t n_chars, sz_size_t n_length) { - sz_assert(n_length <= 256 && "The pattern is too long."); - // Several popular string matching algorithms are using a bad-character shift table. - // Boyer Moore: https://www-igm.univ-mlv.fr/~lecroq/string/node14.html - // Quick Search: https://www-igm.univ-mlv.fr/~lecroq/string/node19.html - // Smith: https://www-igm.univ-mlv.fr/~lecroq/string/node21.html - union { - sz_u8_t jumps[256]; - sz_u64_vec_t vecs[64]; - } bad_shift_table; - - // Let's initialize the table using SWAR to the total length of the string. - sz_u8_t const *h = (sz_u8_t const *)h_chars; - sz_u8_t const *n = (sz_u8_t const *)n_chars; - { - sz_u64_vec_t n_length_vec; - n_length_vec.u64 = n_length; - n_length_vec.u64 *= 0x0101010101010101ull; // broadcast - for (sz_size_t i = 0; i != 64; ++i) bad_shift_table.vecs[i].u64 = n_length_vec.u64; - for (sz_size_t i = 0; i + 1 < n_length; ++i) bad_shift_table.jumps[n[i]] = (sz_u8_t)(n_length - i - 1); - } - - // Another common heuristic is to match a few characters from different parts of a string. - // Raita suggests to use the first two, the last, and the middle character of the pattern. - sz_u32_vec_t h_vec, n_vec; - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n_chars, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into an unsigned integer. - n_vec.u8s[0] = n[offset_first]; - n_vec.u8s[1] = n[offset_first + 1]; - n_vec.u8s[2] = n[offset_mid]; - n_vec.u8s[3] = n[offset_last]; - - // Scan through the whole haystack, skipping the last `n_length - 1` bytes. 
- for (sz_size_t i = 0; i <= h_length - n_length;) { - h_vec.u8s[0] = h[i + offset_first]; - h_vec.u8s[1] = h[i + offset_first + 1]; - h_vec.u8s[2] = h[i + offset_mid]; - h_vec.u8s[3] = h[i + offset_last]; - if (h_vec.u32 == n_vec.u32 && sz_equal((sz_cptr_t)h + i, n_chars, n_length)) return (sz_cptr_t)h + i; - i += bad_shift_table.jumps[h[i + n_length - 1]]; - } - return SZ_NULL_CHAR; -} - -/** - * @brief Boyer-Moore-Horspool algorithm for @b reverse-order exact matching of patterns up to @b 256-bytes long. - * Uses the Raita heuristic to match the first two, the last, and the middle character of the pattern. - */ -SZ_INTERNAL sz_cptr_t _sz_rfind_horspool_upto_256bytes_serial(sz_cptr_t h_chars, sz_size_t h_length, // - sz_cptr_t n_chars, sz_size_t n_length) { - sz_assert(n_length <= 256 && "The pattern is too long."); - union { - sz_u8_t jumps[256]; - sz_u64_vec_t vecs[64]; - } bad_shift_table; - - // Let's initialize the table using SWAR to the total length of the string. - sz_u8_t const *h = (sz_u8_t const *)h_chars; - sz_u8_t const *n = (sz_u8_t const *)n_chars; - { - sz_u64_vec_t n_length_vec; - n_length_vec.u64 = n_length; - n_length_vec.u64 *= 0x0101010101010101ull; // broadcast - for (sz_size_t i = 0; i != 64; ++i) bad_shift_table.vecs[i].u64 = n_length_vec.u64; - for (sz_size_t i = 0; i + 1 < n_length; ++i) - bad_shift_table.jumps[n[n_length - i - 1]] = (sz_u8_t)(n_length - i - 1); - } - - // Another common heuristic is to match a few characters from different parts of a string. - // Raita suggests to use the first two, the last, and the middle character of the pattern. - sz_u32_vec_t h_vec, n_vec; - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n_chars, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into an unsigned integer. - n_vec.u8s[0] = n[offset_first]; - n_vec.u8s[1] = n[offset_first + 1]; - n_vec.u8s[2] = n[offset_mid]; - n_vec.u8s[3] = n[offset_last]; - - // Scan through the whole haystack, skipping the first `n_length - 1` bytes. - for (sz_size_t j = 0; j <= h_length - n_length;) { - sz_size_t i = h_length - n_length - j; - h_vec.u8s[0] = h[i + offset_first]; - h_vec.u8s[1] = h[i + offset_first + 1]; - h_vec.u8s[2] = h[i + offset_mid]; - h_vec.u8s[3] = h[i + offset_last]; - if (h_vec.u32 == n_vec.u32 && sz_equal((sz_cptr_t)h + i, n_chars, n_length)) return (sz_cptr_t)h + i; - j += bad_shift_table.jumps[h[i]]; - } - return SZ_NULL_CHAR; -} - -/** - * @brief Exact substring search helper function, that finds the first occurrence of a prefix of the needle - * using a given search function, and then verifies the remaining part of the needle. - */ -SZ_INTERNAL sz_cptr_t _sz_find_with_prefix(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length, - sz_find_t find_prefix, sz_size_t prefix_length) { - - sz_size_t suffix_length = n_length - prefix_length; - while (1) { - sz_cptr_t found = find_prefix(h, h_length, n, prefix_length); - if (!found) return SZ_NULL_CHAR; - - // Verify the remaining part of the needle - sz_size_t remaining = h_length - (found - h); - if (remaining < n_length) return SZ_NULL_CHAR; - if (sz_equal(found + prefix_length, n + prefix_length, suffix_length)) return found; - - // Adjust the position. 
- h = found + 1; - h_length = remaining - 1; - } - - // Unreachable, but helps silence compiler warnings: - return SZ_NULL_CHAR; -} - -/** - * @brief Exact reverse-order substring search helper function, that finds the last occurrence of a suffix of the - * needle using a given search function, and then verifies the remaining part of the needle. - */ -SZ_INTERNAL sz_cptr_t _sz_rfind_with_suffix(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length, - sz_find_t find_suffix, sz_size_t suffix_length) { - - sz_size_t prefix_length = n_length - suffix_length; - while (1) { - sz_cptr_t found = find_suffix(h, h_length, n + prefix_length, suffix_length); - if (!found) return SZ_NULL_CHAR; - - // Verify the remaining part of the needle - sz_size_t remaining = found - h; - if (remaining < prefix_length) return SZ_NULL_CHAR; - if (sz_equal(found - prefix_length, n, prefix_length)) return found - prefix_length; - - // Adjust the position. - h_length = remaining - 1; - } - - // Unreachable, but helps silence compiler warnings: - return SZ_NULL_CHAR; -} - -SZ_INTERNAL sz_cptr_t _sz_find_over_4bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - return _sz_find_with_prefix(h, h_length, n, n_length, (sz_find_t)_sz_find_4byte_serial, 4); -} - -SZ_INTERNAL sz_cptr_t _sz_find_horspool_over_256bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, - sz_size_t n_length) { - return _sz_find_with_prefix(h, h_length, n, n_length, _sz_find_horspool_upto_256bytes_serial, 256); -} - -SZ_INTERNAL sz_cptr_t _sz_rfind_horspool_over_256bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, - sz_size_t n_length) { - return _sz_rfind_with_suffix(h, h_length, n, n_length, _sz_rfind_horspool_upto_256bytes_serial, 256); -} - -SZ_PUBLIC sz_cptr_t sz_find_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - -#if SZ_DETECT_BIG_ENDIAN - sz_find_t backends[] = { - (sz_find_t)sz_find_byte_serial, - (sz_find_t)_sz_find_horspool_upto_256bytes_serial, - (sz_find_t)_sz_find_horspool_over_256bytes_serial, - }; - - return backends[(n_length > 1) + (n_length > 256)](h, h_length, n, n_length); -#else - sz_find_t backends[] = { - // For very short strings brute-force SWAR makes sense. - (sz_find_t)sz_find_byte_serial, - (sz_find_t)_sz_find_2byte_serial, - (sz_find_t)_sz_find_3byte_serial, - (sz_find_t)_sz_find_4byte_serial, - // To avoid constructing the skip-table, let's use the prefixed approach. - (sz_find_t)_sz_find_over_4bytes_serial, - // For longer needles - use skip tables. - (sz_find_t)_sz_find_horspool_upto_256bytes_serial, - (sz_find_t)_sz_find_horspool_over_256bytes_serial, - }; - - return backends[ - // For very short strings brute-force SWAR makes sense. - (n_length > 1) + (n_length > 2) + (n_length > 3) + - // To avoid constructing the skip-table, let's use the prefixed approach. - (n_length > 4) + - // For longer needles - use skip tables. - (n_length > 8) + (n_length > 256)](h, h_length, n, n_length); -#endif -} - -SZ_PUBLIC sz_cptr_t sz_rfind_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - - sz_find_t backends[] = { - // For very short strings brute-force SWAR makes sense. 
- (sz_find_t)sz_rfind_byte_serial, - // TODO: implement reverse-order SWAR for 2/3/4 byte variants. - // TODO: (sz_find_t)_sz_rfind_2byte_serial, - // TODO: (sz_find_t)_sz_rfind_3byte_serial, - // TODO: (sz_find_t)_sz_rfind_4byte_serial, - // To avoid constructing the skip-table, let's use the prefixed approach. - // (sz_find_t)_sz_rfind_over_4bytes_serial, - // For longer needles - use skip tables. - (sz_find_t)_sz_rfind_horspool_upto_256bytes_serial, - (sz_find_t)_sz_rfind_horspool_over_256bytes_serial, - }; - - return backends[ - // For very short strings brute-force SWAR makes sense. - 0 + - // To avoid constructing the skip-table, let's use the prefixed approach. - (n_length > 1) + - // For longer needles - use skip tables. - (n_length > 256)](h, h_length, n, n_length); -} - -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_serial( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - - // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. - sz_memory_allocator_t global_alloc; - if (!alloc) { - sz_memory_allocator_init_default(&global_alloc); - alloc = &global_alloc; - } - - // TODO: Generalize to remove the following asserts! - sz_assert(!bound && "For bounded search the method should only evaluate one band of the matrix."); - sz_assert(shorter_length == longer_length && "The method hasn't been generalized to different length inputs yet."); - sz_unused(longer_length && bound); - - // We are going to store 3 diagonals of the matrix. - // The length of the longest (main) diagonal would be `n = (shorter_length + 1)`. - sz_size_t n = shorter_length + 1; - sz_size_t buffer_length = sizeof(sz_size_t) * n * 3; - sz_size_t *distances = (sz_size_t *)alloc->allocate(buffer_length, alloc->handle); - if (!distances) return SZ_SIZE_MAX; - - sz_size_t *previous_distances = distances; - sz_size_t *current_distances = previous_distances + n; - sz_size_t *next_distances = previous_distances + n * 2; - - // Initialize the first two diagonals: - previous_distances[0] = 0; - current_distances[0] = current_distances[1] = 1; - - // Progress through the upper triangle of the Levenshtein matrix. - sz_size_t next_diagonal_index = 2; - for (; next_diagonal_index != n; ++next_diagonal_index) { - sz_size_t const next_diagonal_length = next_diagonal_index + 1; - for (sz_size_t i = 0; i + 2 < next_diagonal_length; ++i) { - sz_size_t cost_of_substitution = shorter[next_diagonal_index - i - 2] != longer[i]; - sz_size_t cost_if_substitution = previous_distances[i] + cost_of_substitution; - sz_size_t cost_if_deletion_or_insertion = sz_min_of_two(current_distances[i], current_distances[i + 1]) + 1; - next_distances[i + 1] = sz_min_of_two(cost_if_deletion_or_insertion, cost_if_substitution); - } - // Don't forget to populate the first row and the first column of the Levenshtein matrix. - next_distances[0] = next_distances[next_diagonal_length - 1] = next_diagonal_index; - // Perform a circular rotation of those buffers, to reuse the memory. - sz_size_t *temporary = previous_distances; - previous_distances = current_distances; - current_distances = next_distances; - next_distances = temporary; - } - - // By now we've scanned through the upper triangle of the matrix, where each subsequent iteration results in a - // larger diagonal. From now onwards, we will be shrinking. 
Instead of adding value equal to the skewed diagonal - // index on either side, we will be cropping those values out. - sz_size_t diagonals_count = n + n - 1; - for (; next_diagonal_index != diagonals_count; ++next_diagonal_index) { - sz_size_t const next_diagonal_length = diagonals_count - next_diagonal_index; - for (sz_size_t i = 0; i != next_diagonal_length; ++i) { - sz_size_t cost_of_substitution = shorter[shorter_length - 1 - i] != longer[next_diagonal_index - n + i]; - sz_size_t cost_if_substitution = previous_distances[i] + cost_of_substitution; - sz_size_t cost_if_deletion_or_insertion = sz_min_of_two(current_distances[i], current_distances[i + 1]) + 1; - next_distances[i] = sz_min_of_two(cost_if_deletion_or_insertion, cost_if_substitution); - } - // Perform a circular rotation of those buffers, to reuse the memory, this time, with a shift, - // dropping the first element in the current array. - sz_size_t *temporary = previous_distances; - previous_distances = current_distances + 1; - current_distances = next_distances; - next_distances = temporary; - } - - // Cache scalar before `free` call. - sz_size_t result = current_distances[0]; - alloc->free(distances, buffer_length, alloc->handle); - return result; -} - -/** - * @brief Describes the length of a UTF8 character / codepoint / rune in bytes. - */ -typedef enum { - sz_utf8_invalid_k = 0, //!< Invalid UTF8 character. - sz_utf8_rune_1byte_k = 1, //!< 1-byte UTF8 character. - sz_utf8_rune_2bytes_k = 2, //!< 2-byte UTF8 character. - sz_utf8_rune_3bytes_k = 3, //!< 3-byte UTF8 character. - sz_utf8_rune_4bytes_k = 4, //!< 4-byte UTF8 character. -} sz_rune_length_t; - -typedef sz_u32_t sz_rune_t; - -/** - * @brief Extracts just one UTF8 codepoint from a UTF8 string into a 32-bit unsigned integer. - */ -SZ_INTERNAL void _sz_extract_utf8_rune(sz_cptr_t utf8, sz_rune_t *code, sz_rune_length_t *code_length) { - sz_u8_t const *current = (sz_u8_t const *)utf8; - sz_u8_t leading_byte = *current++; - sz_rune_t ch; - sz_rune_length_t ch_length; - - // TODO: This can be made entirely branchless using 32-bit SWAR. - if (leading_byte < 0x80) { - // Single-byte rune (0xxxxxxx) - ch = leading_byte; - ch_length = sz_utf8_rune_1byte_k; - } - else if ((leading_byte & 0xE0) == 0xC0) { - // Two-byte rune (110xxxxx 10xxxxxx) - ch = (leading_byte & 0x1F) << 6; - ch |= (*current++ & 0x3F); - ch_length = sz_utf8_rune_2bytes_k; - } - else if ((leading_byte & 0xF0) == 0xE0) { - // Three-byte rune (1110xxxx 10xxxxxx 10xxxxxx) - ch = (leading_byte & 0x0F) << 12; - ch |= (*current++ & 0x3F) << 6; - ch |= (*current++ & 0x3F); - ch_length = sz_utf8_rune_3bytes_k; - } - else if ((leading_byte & 0xF8) == 0xF0) { - // Four-byte rune (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx) - ch = (leading_byte & 0x07) << 18; - ch |= (*current++ & 0x3F) << 12; - ch |= (*current++ & 0x3F) << 6; - ch |= (*current++ & 0x3F); - ch_length = sz_utf8_rune_4bytes_k; - } - else { - // Invalid UTF8 rune. - ch = 0; - ch_length = sz_utf8_invalid_k; - } - *code = ch; - *code_length = ch_length; -} - -/** - * @brief Exports a UTF8 string into a UTF32 buffer. - * ! The result is undefined id the UTF8 string is corrupted. - * @return The length in the number of codepoints. 
- */ -SZ_INTERNAL sz_size_t _sz_export_utf8_to_utf32(sz_cptr_t utf8, sz_size_t utf8_length, sz_rune_t *utf32) { - sz_cptr_t const end = utf8 + utf8_length; - sz_size_t count = 0; - sz_rune_length_t rune_length; - for (; utf8 != end; utf8 += rune_length, utf32++, count++) _sz_extract_utf8_rune(utf8, utf32, &rune_length); - return count; -} - -/** - * @brief Compute the Levenshtein distance between two strings using the Wagner-Fisher algorithm. - * Stores only 2 rows of the Levenshtein matrix, but uses 64-bit integers for the distance values, - * and upcasts UTF8 variable-length codepoints to 64-bit integers for faster addressing. + * Includes: * - * ! In the worst case for 2 strings of length 100, that contain just one 16-bit codepoint this will result in extra: - * + 2 rows * 100 slots * 8 bytes/slot = 1600 bytes of memory for the two rows of the Levenshtein matrix rows. - * + 100 codepoints * 2 strings * 4 bytes/codepoint = 800 bytes of memory for the UTF8 buffer. - * = 2400 bytes of memory or @b 12x memory amplification! + * - `sz_copy` - analog to `memcpy` + * - `sz_move` - analog to `memmove` + * - `sz_fill` - analog to `memset` + * - `sz_look_up_transform` - LUT transformation of a string, similar to OpenCV LUT + * - `sz_detect_encoding` - similar to `iconv` or `chardet` + * + * Convenience functions for character-set mapping: + * + * - `sz_tolower`, `sz_toupper`, `sz_toascii` for ASCII ranges */ -SZ_INTERNAL sz_size_t _sz_edit_distance_wagner_fisher_serial( // - sz_cptr_t longer, sz_size_t longer_length, // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_size_t bound, sz_bool_t can_be_unicode, sz_memory_allocator_t *alloc) { - - // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. - sz_memory_allocator_t global_alloc; - if (!alloc) { - sz_memory_allocator_init_default(&global_alloc); - alloc = &global_alloc; - } - - // A good idea may be to dispatch different kernels for different string lengths. - // Like using `uint8_t` counters for strings under 255 characters long. - // Good in theory, this results in frequent upcasts and downcasts in serial code. - // On strings over 20 bytes, using `uint8` over `uint64` on 64-bit x86 CPU doubles the execution time. - // So one must be very cautious with such optimizations. - typedef sz_size_t _distance_t; - - // Compute the number of columns in our Levenshtein matrix. - sz_size_t const n = shorter_length + 1; - - // If a buffering memory-allocator is provided, this operation is practically free, - // and cheaper than allocating even 512 bytes (for small distance matrices) on stack. - sz_size_t buffer_length = sizeof(_distance_t) * (n * 2); - - // If the strings contain Unicode characters, let's estimate the max character width, - // and use it to allocate a larger buffer to decode UTF8. - if ((can_be_unicode == sz_true_k) && - (sz_isascii(longer, longer_length) == sz_false_k || sz_isascii(shorter, shorter_length) == sz_false_k)) { - buffer_length += (shorter_length + longer_length) * sizeof(sz_rune_t); - } - else { can_be_unicode = sz_false_k; } - - // If the allocation fails, return the maximum distance. - sz_ptr_t const buffer = (sz_ptr_t)alloc->allocate(buffer_length, alloc->handle); - if (!buffer) return SZ_SIZE_MAX; - - // Let's export the UTF8 sequence into the newly allocated buffer at the end. 
- if (can_be_unicode == sz_true_k) { - sz_rune_t *const longer_utf32 = (sz_rune_t *)(buffer + sizeof(_distance_t) * (n * 2)); - sz_rune_t *const shorter_utf32 = longer_utf32 + longer_length; - // Export the UTF8 sequences into the newly allocated buffer. - longer_length = _sz_export_utf8_to_utf32(longer, longer_length, longer_utf32); - shorter_length = _sz_export_utf8_to_utf32(shorter, shorter_length, shorter_utf32); - longer = (sz_cptr_t)longer_utf32; - shorter = (sz_cptr_t)shorter_utf32; - } - - // Let's parameterize the core logic for different character types and distance types. -#define _wagner_fisher_unbounded(_distance_t, _char_t) \ - /* Now let's cast our pointer to avoid it in subsequent sections. */ \ - _char_t const *const longer_chars = (_char_t const *)longer; \ - _char_t const *const shorter_chars = (_char_t const *)shorter; \ - _distance_t *previous_distances = (_distance_t *)buffer; \ - _distance_t *current_distances = previous_distances + n; \ - /* Initialize the first row of the Levenshtein matrix with `iota`-style arithmetic progression. */ \ - for (_distance_t idx_shorter = 0; idx_shorter != n; ++idx_shorter) previous_distances[idx_shorter] = idx_shorter; \ - /* The main loop of the algorithm with quadratic complexity. */ \ - for (_distance_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) { \ - _char_t const longer_char = longer_chars[idx_longer]; \ - /* Using pure pointer arithmetic is faster than iterating with an index. */ \ - _char_t const *shorter_ptr = shorter_chars; \ - _distance_t const *previous_ptr = previous_distances; \ - _distance_t *current_ptr = current_distances; \ - _distance_t *const current_end = current_ptr + shorter_length; \ - current_ptr[0] = idx_longer + 1; \ - for (; current_ptr != current_end; ++previous_ptr, ++current_ptr, ++shorter_ptr) { \ - _distance_t cost_substitution = previous_ptr[0] + (_distance_t)(longer_char != shorter_ptr[0]); \ - /* We can avoid `+1` for costs here, shifting it to post-minimum computation, */ \ - /* saving one increment operation. */ \ - _distance_t cost_deletion = previous_ptr[1]; \ - _distance_t cost_insertion = current_ptr[0]; \ - /* ? It might be a good idea to enforce branchless execution here. */ \ - /* ? The caveat being that the benchmarks on longer sequences backfire and more research is needed. */ \ - current_ptr[1] = sz_min_of_two(cost_substitution, sz_min_of_two(cost_deletion, cost_insertion) + 1); \ - } \ - /* Swap `previous_distances` and `current_distances` pointers. */ \ - _distance_t *temporary = previous_distances; \ - previous_distances = current_distances; \ - current_distances = temporary; \ - } \ - /* Cache scalar before `free` call. */ \ - sz_size_t result = previous_distances[shorter_length]; \ - alloc->free(buffer, buffer_length, alloc->handle); \ - return result; - - // Let's define a separate variant for bounded distance computation. - // Practically the same as unbounded, but also collecting the running minimum within each row for early exit. 
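/* For intuition, a minimal standalone sketch of the idea above: keep only two rows of the matrix and
 * track the running minimum of each row, bailing out once it reaches the bound. The helper name, the
 * fixed 64-cell rows, and the plain `size_t` types are illustrative assumptions, not part of this header. */
#include <stddef.h>
#include <string.h>

static size_t levenshtein_two_rows_bounded(char const *a, size_t a_length, char const *b, size_t b_length,
                                           size_t bound) {
    size_t previous[64], current[64]; /* Assumes b_length < 64 to keep the sketch allocation-free. */
    for (size_t j = 0; j <= b_length; ++j) previous[j] = j;
    for (size_t i = 0; i != a_length; ++i) {
        current[0] = i + 1;
        size_t row_minimum = current[0];
        for (size_t j = 0; j != b_length; ++j) {
            size_t if_substitution = previous[j] + (a[i] != b[j]);
            size_t if_deletion = previous[j + 1] + 1;
            size_t if_insertion = current[j] + 1;
            size_t cell = if_substitution < if_deletion ? if_substitution : if_deletion;
            current[j + 1] = cell < if_insertion ? cell : if_insertion;
            if (current[j + 1] < row_minimum) row_minimum = current[j + 1];
        }
        /* Row minima never decrease, so once the whole row is at or above the bound we can stop. */
        if (bound && row_minimum >= bound) return bound;
        memcpy(previous, current, (b_length + 1) * sizeof(size_t));
    }
    return bound && previous[b_length] > bound ? bound : previous[b_length];
}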
-#define _wagner_fisher_bounded(_distance_t, _char_t) \ - _char_t const *const longer_chars = (_char_t const *)longer; \ - _char_t const *const shorter_chars = (_char_t const *)shorter; \ - _distance_t *previous_distances = (_distance_t *)buffer; \ - _distance_t *current_distances = previous_distances + n; \ - for (_distance_t idx_shorter = 0; idx_shorter != n; ++idx_shorter) previous_distances[idx_shorter] = idx_shorter; \ - for (_distance_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) { \ - _char_t const longer_char = longer_chars[idx_longer]; \ - _char_t const *shorter_ptr = shorter_chars; \ - _distance_t const *previous_ptr = previous_distances; \ - _distance_t *current_ptr = current_distances; \ - _distance_t *const current_end = current_ptr + shorter_length; \ - current_ptr[0] = idx_longer + 1; \ - /* Initialize min_distance with a value greater than bound */ \ - _distance_t min_distance = bound - 1; \ - for (; current_ptr != current_end; ++previous_ptr, ++current_ptr, ++shorter_ptr) { \ - _distance_t cost_substitution = previous_ptr[0] + (_distance_t)(longer_char != shorter_ptr[0]); \ - _distance_t cost_deletion = previous_ptr[1]; \ - _distance_t cost_insertion = current_ptr[0]; \ - current_ptr[1] = sz_min_of_two(cost_substitution, sz_min_of_two(cost_deletion, cost_insertion) + 1); \ - /* Keep track of the minimum distance seen so far in this row */ \ - min_distance = sz_min_of_two(current_ptr[1], min_distance); \ - } \ - /* If the minimum distance in this row exceeded the bound, return early */ \ - if (min_distance >= bound) { \ - alloc->free(buffer, buffer_length, alloc->handle); \ - return bound; \ - } \ - _distance_t *temporary = previous_distances; \ - previous_distances = current_distances; \ - current_distances = temporary; \ - } \ - sz_size_t result = previous_distances[shorter_length]; \ - alloc->free(buffer, buffer_length, alloc->handle); \ - return sz_min_of_two(result, bound); - - // Dispatch the actual computation. - if (!bound) { - if (can_be_unicode == sz_true_k) { _wagner_fisher_unbounded(sz_size_t, sz_rune_t); } - else { _wagner_fisher_unbounded(sz_size_t, sz_u8_t); } - } - else { - if (can_be_unicode == sz_true_k) { _wagner_fisher_bounded(sz_size_t, sz_rune_t); } - else { _wagner_fisher_bounded(sz_size_t, sz_u8_t); } - } -} - -SZ_PUBLIC sz_size_t sz_edit_distance_serial( // - sz_cptr_t longer, sz_size_t longer_length, // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - - // Let's make sure that we use the amount proportional to the - // number of elements in the shorter string, not the larger. - if (shorter_length > longer_length) { - sz_pointer_swap((void **)&longer_length, (void **)&shorter_length); - sz_pointer_swap((void **)&longer, (void **)&shorter); - } - - // Skip the matching prefixes and suffixes, they won't affect the distance. - for (sz_cptr_t a_end = longer + longer_length, b_end = shorter + shorter_length; - longer != a_end && shorter != b_end && *longer == *shorter; - ++longer, ++shorter, --longer_length, --shorter_length); - for (; longer_length && shorter_length && longer[longer_length - 1] == shorter[shorter_length - 1]; - --longer_length, --shorter_length); - - // Bounded computations may exit early. - int const is_bounded = bound < longer_length; - if (is_bounded) { - // If one of the strings is empty - the edit distance is equal to the length of the other one. 
- if (longer_length == 0) return sz_min_of_two(shorter_length, bound); - if (shorter_length == 0) return sz_min_of_two(longer_length, bound); - // If the difference in length is beyond the `bound`, there is no need to check at all. - if (longer_length - shorter_length > bound) return bound; - } - - if (shorter_length == 0) return longer_length; // If no mismatches were found - the distance is zero. - if (shorter_length == longer_length && !is_bounded) - return _sz_edit_distance_skewed_diagonals_serial(longer, longer_length, shorter, shorter_length, bound, alloc); - return _sz_edit_distance_wagner_fisher_serial(longer, longer_length, shorter, shorter_length, bound, sz_false_k, - alloc); -} - -SZ_PUBLIC sz_ssize_t sz_alignment_score_serial( // - sz_cptr_t longer, sz_size_t longer_length, // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, // - sz_memory_allocator_t *alloc) { - - // If one of the strings is empty - the edit distance is equal to the length of the other one - if (longer_length == 0) return (sz_ssize_t)shorter_length * gap; - if (shorter_length == 0) return (sz_ssize_t)longer_length * gap; - - // Let's make sure that we use the amount proportional to the - // number of elements in the shorter string, not the larger. - if (shorter_length > longer_length) { - sz_pointer_swap((void **)&longer_length, (void **)&shorter_length); - sz_pointer_swap((void **)&longer, (void **)&shorter); - } - - // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. - sz_memory_allocator_t global_alloc; - if (!alloc) { - sz_memory_allocator_init_default(&global_alloc); - alloc = &global_alloc; - } - - sz_size_t n = shorter_length + 1; - sz_size_t buffer_length = sizeof(sz_ssize_t) * n * 2; - sz_ssize_t *distances = (sz_ssize_t *)alloc->allocate(buffer_length, alloc->handle); - sz_ssize_t *previous_distances = distances; - sz_ssize_t *current_distances = previous_distances + n; +#ifndef STRINGZILLA_MEMORY_H_ +#define STRINGZILLA_MEMORY_H_ - for (sz_size_t idx_shorter = 0; idx_shorter != n; ++idx_shorter) - previous_distances[idx_shorter] = (sz_ssize_t)idx_shorter * gap; +#include "types.h" - sz_u8_t const *shorter_unsigned = (sz_u8_t const *)shorter; - sz_u8_t const *longer_unsigned = (sz_u8_t const *)longer; - for (sz_size_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) { - current_distances[0] = ((sz_ssize_t)idx_longer + 1) * gap; +#ifdef __cplusplus +extern "C" { +#endif - // Initialize min_distance with a value greater than bound - sz_error_cost_t const *a_subs = subs + longer_unsigned[idx_longer] * 256ul; - for (sz_size_t idx_shorter = 0; idx_shorter != shorter_length; ++idx_shorter) { - sz_ssize_t cost_deletion = previous_distances[idx_shorter + 1] + gap; - sz_ssize_t cost_insertion = current_distances[idx_shorter] + gap; - sz_ssize_t cost_substitution = previous_distances[idx_shorter] + a_subs[shorter_unsigned[idx_shorter]]; - current_distances[idx_shorter + 1] = sz_max_of_three(cost_deletion, cost_insertion, cost_substitution); - } +#pragma region Core API - // Swap previous_distances and current_distances pointers - sz_pointer_swap((void **)&previous_distances, (void **)¤t_distances); - } +/** + * @brief Similar to `memcpy`, copies contents of one string into another. + * The behavior is undefined if the strings overlap. + * + * @param target String to copy into. + * @param length Number of bytes to copy. + * @param source String to copy from. 
+ */
+SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
-
-    // Cache scalar before `free` call.
-    sz_ssize_t result = previous_distances[shorter_length];
-    alloc->free(distances, buffer_length, alloc->handle);
-    return result;
-}
+/**
+ * @brief Similar to `memmove`, copies (moves) contents of one string into another.
+ *        Unlike `sz_copy`, allows overlapping strings as arguments.
+ *
+ * @param target String to copy into.
+ * @param length Number of bytes to copy.
+ * @param source String to copy from.
+ */
+SZ_DYNAMIC void sz_move(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
-SZ_PUBLIC sz_size_t sz_hamming_distance_serial( //
-    sz_cptr_t a, sz_size_t a_length,            //
-    sz_cptr_t b, sz_size_t b_length,            //
-    sz_size_t bound) {
+/**
+ * @brief Similar to `memset`, fills a string with a given value.
+ *
+ * @param target String to fill.
+ * @param length Number of bytes to fill.
+ * @param value Value to fill with.
+ */
+SZ_DYNAMIC void sz_fill(sz_ptr_t target, sz_size_t length, sz_u8_t value);
-    sz_size_t const min_length = sz_min_of_two(a_length, b_length);
-    sz_size_t const max_length = sz_max_of_two(a_length, b_length);
-    sz_cptr_t const a_end = a + min_length;
-    bound = bound == 0 ? max_length : bound;
+/** @copydoc sz_copy */
+SZ_PUBLIC void sz_copy_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
+/** @copydoc sz_move */
+SZ_PUBLIC void sz_move_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
+/** @copydoc sz_fill */
+SZ_PUBLIC void sz_fill_serial(sz_ptr_t target, sz_size_t length, sz_u8_t value);
-    // Walk through both strings using SWAR and counting the number of differing characters.
-    sz_size_t distance = max_length - min_length;
-#if SZ_USE_MISALIGNED_LOADS && !SZ_DETECT_BIG_ENDIAN
-    if (min_length >= SZ_SWAR_THRESHOLD) {
-        sz_u64_vec_t a_vec, b_vec, match_vec;
-        for (; a + 8 <= a_end && distance < bound; a += 8, b += 8) {
-            a_vec.u64 = sz_u64_load(a).u64;
-            b_vec.u64 = sz_u64_load(b).u64;
-            match_vec = _sz_u64_each_byte_equal(a_vec, b_vec);
-            distance += sz_u64_popcount((~match_vec.u64) & 0x8080808080808080ull);
-        }
-    }
+#if SZ_USE_HASWELL
+/** @copydoc sz_copy */
+SZ_PUBLIC sz_cptr_t sz_copy_haswell(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
+/** @copydoc sz_move */
+SZ_PUBLIC sz_cptr_t sz_move_haswell(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
+/** @copydoc sz_fill */
+SZ_PUBLIC sz_cptr_t sz_fill_haswell(sz_ptr_t target, sz_size_t length, sz_u8_t value);
 #endif
-    for (; a != a_end && distance < bound; ++a, ++b) { distance += (*a != *b); }
-    return sz_min_of_two(distance, bound);
-}
-
-SZ_PUBLIC sz_size_t sz_hamming_distance_utf8_serial( //
-    sz_cptr_t a, sz_size_t a_length,                 //
-    sz_cptr_t b, sz_size_t b_length,                 //
-    sz_size_t bound) {
+#if SZ_USE_SKYLAKE
+/** @copydoc sz_copy */
+SZ_PUBLIC sz_cptr_t sz_copy_skylake(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
+/** @copydoc sz_move */
+SZ_PUBLIC sz_cptr_t sz_move_skylake(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
+/** @copydoc sz_fill */
+SZ_PUBLIC sz_cptr_t sz_fill_skylake(sz_ptr_t target, sz_size_t length, sz_u8_t value);
+#endif
-    sz_cptr_t const a_end = a + a_length;
-    sz_cptr_t const b_end = b + b_length;
-    sz_size_t distance = 0;
+#if SZ_USE_NEON
+/** @copydoc sz_copy */
+SZ_PUBLIC sz_cptr_t sz_copy_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
+/** @copydoc sz_move */
+SZ_PUBLIC sz_cptr_t sz_move_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
+/** @copydoc sz_fill */
+SZ_PUBLIC
sz_cptr_t sz_fill_neon(sz_ptr_t target, sz_size_t length, sz_u8_t value); +#endif - sz_rune_t a_rune, b_rune; - sz_rune_length_t a_rune_length, b_rune_length; +/** + * @brief Look Up Table @b (LUT) transformation of a string. Equivalent to `for (char & c : text) c = lut[c]`. + * + * Can be used to implement some form of string normalization, partially masking punctuation marks, + * or converting between different character sets, like uppercase or lowercase. Surprisingly, also has + * broad implications in image processing, where image channel transformations are often done using LUTs. + * + * @param text String to be normalized. + * @param length Number of bytes in the string. + * @param lut Look Up Table to apply. Must be exactly @b 256 bytes long. + * @param result Output string, can point to the same address as ::text. + */ +SZ_DYNAMIC void sz_look_up_transform(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result); - if (bound) { - for (; a < a_end && b < b_end && distance < bound; a += a_rune_length, b += b_rune_length) { - _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - distance += (a_rune != b_rune); - } - // If one string has more runes, we need to go through the tail. - if (distance < bound) { - for (; a < a_end && distance < bound; a += a_rune_length, ++distance) - _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); +/** @copydoc sz_look_up_transform */ +SZ_PUBLIC void sz_look_up_transform_serial(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result); - for (; b < b_end && distance < bound; b += b_rune_length, ++distance) - _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - } - } - else { - for (; a < a_end && b < b_end; a += a_rune_length, b += b_rune_length) { - _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - distance += (a_rune != b_rune); - } - // If one string has more runes, we need to go through the tail. - for (; a < a_end; a += a_rune_length, ++distance) _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - for (; b < b_end; b += b_rune_length, ++distance) _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - } - return distance; -} +#pragma endregion // Core API -SZ_PUBLIC sz_u64_t sz_checksum_serial(sz_cptr_t text, sz_size_t length) { - sz_u64_t checksum = 0; - sz_u8_t const *text_u8 = (sz_u8_t const *)text; - sz_u8_t const *text_end = text_u8 + length; - for (; text_u8 != text_end; ++text_u8) checksum += *text_u8; - return checksum; -} +#pragma region Helper API /** - * @brief Largest prime number that fits into 31 bits. - * @see https://mersenneforum.org/showthread.php?t=3471 + * @brief Equivalent to `for (char & c : text) c = tolower(c)`. + * + * ASCII characters [A, Z] map to decimals [65, 90], and [a, z] map to [97, 122]. + * So there are 26 english letters, shifted by 32 values, meaning that a conversion + * can be done by flipping the 5th bit each inappropriate character byte. This, however, + * breaks for extended ASCII, so a different solution is needed. + * http://0x80.pl/notesen/2016-01-06-swar-swap-case.html + * + * @param text String to be normalized. + * @param length Number of bytes in the string. + * @param result Output string, can point to the same address as ::text. */ -#define SZ_U32_MAX_PRIME (2147483647u) +SZ_PUBLIC void sz_tolower(sz_cptr_t text, sz_size_t length, sz_ptr_t result); /** - * @brief Largest prime number that fits into 64 bits. 
- * @see https://mersenneforum.org/showthread.php?t=3471 + * @brief Equivalent to `for (char & c : text) c = toupper(c)`. * - * 2^64 = 18,446,744,073,709,551,616 - * this = 18,446,744,073,709,551,557 - * diff = 59 + * ASCII characters [A, Z] map to decimals [65, 90], and [a, z] map to [97, 122]. + * So there are 26 english letters, shifted by 32 values, meaning that a conversion + * can be done by flipping the 5th bit each inappropriate character byte. This, however, + * breaks for extended ASCII, so a different solution is needed. + * http://0x80.pl/notesen/2016-01-06-swar-swap-case.html + * + * @param text String to be normalized. + * @param length Number of bytes in the string. + * @param result Output string, can point to the same address as ::text. */ -#define SZ_U64_MAX_PRIME (18446744073709551557ull) +SZ_PUBLIC void sz_toupper(sz_cptr_t text, sz_size_t length, sz_ptr_t result); -/* - * One hardware-accelerated way of mixing hashes can be CRC, but it's only implemented for 32-bit values. - * Using a Boost-like mixer works very poorly in such case: - * - * hash_first ^ (hash_second + 0x517cc1b727220a95 + (hash_first << 6) + (hash_first >> 2)); +/** + * @brief Equivalent to `for (char & c : text) c = toascii(c)`. * - * Let's stick to the Fibonacci hash trick using the golden ratio. - * https://probablydance.com/2018/06/16/fibonacci-hashing-the-optimization-that-the-world-forgot-or-a-better-alternative-to-integer-modulo/ + * @param text String to be normalized. + * @param length Number of bytes in the string. + * @param result Output string, can point to the same address as ::text. */ -#define _sz_hash_mix(first, second) ((first * 11400714819323198485ull) ^ (second * 11400714819323198485ull)) -#define _sz_shift_low(x) (x) -#define _sz_shift_high(x) ((x + 77ull) & 0xFFull) -#define _sz_prime_mod(x) (x % SZ_U64_MAX_PRIME) - -SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t start, sz_size_t length) { - - sz_u64_t hash_low = 0; - sz_u64_t hash_high = 0; - sz_u8_t const *text = (sz_u8_t const *)start; - sz_u8_t const *text_end = text + length; - - switch (length) { - case 0: return 0; - - // Texts under 7 bytes long are definitely below the largest prime. 
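/* For reference, a standalone sketch of the same dual polynomial hash in plain-loop form: the switch
 * below just unrolls the first seven steps of it. The helper name and the use of an explicit `%` on
 * every step are illustrative assumptions, not part of this header. */
#include <stddef.h>
#include <stdint.h>

static uint64_t polynomial_hash_sketch(uint8_t const *text, size_t length) {
    uint64_t const max_prime = 18446744073709551557ull;    /* Largest prime that fits into 64 bits. */
    uint64_t const golden_ratio = 11400714819323198485ull; /* Fibonacci-hashing multiplier, ~2^64 / phi. */
    uint64_t hash_low = 0, hash_high = 0;
    for (size_t i = 0; i != length; ++i) {
        hash_low = (hash_low * 31ull + text[i]) % max_prime;
        hash_high = (hash_high * 257ull + ((text[i] + 77ull) & 0xFFull)) % max_prime;
    }
    return (hash_low * golden_ratio) ^ (hash_high * golden_ratio);
}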
- case 1: - hash_low = _sz_shift_low(text[0]); - hash_high = _sz_shift_high(text[0]); - break; - case 2: - hash_low = _sz_shift_low(text[0]) * 31ull + _sz_shift_low(text[1]); - hash_high = _sz_shift_high(text[0]) * 257ull + _sz_shift_high(text[1]); - break; - case 3: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull + // - _sz_shift_low(text[2]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull + // - _sz_shift_high(text[2]); - break; - case 4: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull + // - _sz_shift_low(text[3]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull + // - _sz_shift_high(text[3]); - break; - case 5: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull * 31ull + // - _sz_shift_low(text[3]) * 31ull + // - _sz_shift_low(text[4]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull * 257ull + // - _sz_shift_high(text[3]) * 257ull + // - _sz_shift_high(text[4]); - break; - case 6: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[3]) * 31ull * 31ull + // - _sz_shift_low(text[4]) * 31ull + // - _sz_shift_low(text[5]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[3]) * 257ull * 257ull + // - _sz_shift_high(text[4]) * 257ull + // - _sz_shift_high(text[5]); - break; - case 7: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[3]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[4]) * 31ull * 31ull + // - _sz_shift_low(text[5]) * 31ull + // - _sz_shift_low(text[6]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[3]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[4]) * 257ull * 257ull + // - _sz_shift_high(text[5]) * 257ull + // - _sz_shift_high(text[6]); - break; - default: - // Unroll the first seven cycles: - hash_low = hash_low * 31ull + _sz_shift_low(text[0]); - hash_high = hash_high * 257ull + _sz_shift_high(text[0]); - hash_low = hash_low * 31ull + _sz_shift_low(text[1]); - hash_high = hash_high * 257ull + _sz_shift_high(text[1]); - hash_low = hash_low * 31ull + _sz_shift_low(text[2]); - hash_high = hash_high * 257ull + _sz_shift_high(text[2]); - hash_low = hash_low * 31ull + _sz_shift_low(text[3]); - hash_high = hash_high * 257ull + _sz_shift_high(text[3]); - hash_low = hash_low * 31ull + _sz_shift_low(text[4]); - hash_high = hash_high * 257ull + _sz_shift_high(text[4]); - hash_low = hash_low * 31ull + 
_sz_shift_low(text[5]); - hash_high = hash_high * 257ull + _sz_shift_high(text[5]); - hash_low = hash_low * 31ull + _sz_shift_low(text[6]); - hash_high = hash_high * 257ull + _sz_shift_high(text[6]); - text += 7; - - // Iterate throw the rest with the modulus: - for (; text != text_end; ++text) { - hash_low = hash_low * 31ull + _sz_shift_low(text[0]); - hash_high = hash_high * 257ull + _sz_shift_high(text[0]); - // Wrap the hashes around: - hash_low = _sz_prime_mod(hash_low); - hash_high = _sz_prime_mod(hash_high); - } - break; - } - - return _sz_hash_mix(hash_low, hash_high); -} - -SZ_PUBLIC void sz_hashes_serial(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle) { - - if (length < window_length || !window_length) return; - sz_u8_t const *text = (sz_u8_t const *)start; - sz_u8_t const *text_end = text + length; - - // Prepare the `prime ^ window_length` values, that we are going to use for modulo arithmetic. - sz_u64_t prime_power_low = 1, prime_power_high = 1; - for (sz_size_t i = 0; i + 1 < window_length; ++i) - prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME, - prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME; - - // Compute the initial hash value for the first window. - sz_u64_t hash_low = 0, hash_high = 0, hash_mix; - for (sz_u8_t const *first_end = text + window_length; text < first_end; ++text) - hash_low = (hash_low * 31ull + _sz_shift_low(*text)) % SZ_U64_MAX_PRIME, - hash_high = (hash_high * 257ull + _sz_shift_high(*text)) % SZ_U64_MAX_PRIME; +SZ_PUBLIC void sz_toascii(sz_cptr_t text, sz_size_t length, sz_ptr_t result); - // In most cases the fingerprint length will be a power of two. - hash_mix = _sz_hash_mix(hash_low, hash_high); - callback((sz_cptr_t)text, window_length, hash_mix, callback_handle); +/** + * @brief Checks if all characters in the range are valid ASCII characters. + * + * @param text String to be analyzed. + * @param length Number of bytes in the string. + * @return Whether all characters are valid ASCII characters. + */ +SZ_PUBLIC sz_bool_t sz_isascii(sz_cptr_t text, sz_size_t length); - // Compute the hash value for every window, exporting into the fingerprint, - // using the expensive modulo operation. - sz_size_t cycles = 1; - sz_size_t const step_mask = step - 1; - for (; text < text_end; ++text, ++cycles) { - // Discard one character: - hash_low -= _sz_shift_low(*(text - window_length)) * prime_power_low; - hash_high -= _sz_shift_high(*(text - window_length)) * prime_power_high; - // And add a new one: - hash_low = 31ull * hash_low + _sz_shift_low(*text); - hash_high = 257ull * hash_high + _sz_shift_high(*text); - // Wrap the hashes around: - hash_low = _sz_prime_mod(hash_low); - hash_high = _sz_prime_mod(hash_high); - // Mix only if we've skipped enough hashes. - if ((cycles & step_mask) == 0) { - hash_mix = _sz_hash_mix(hash_low, hash_high); - callback((sz_cptr_t)text, window_length, hash_mix, callback_handle); - } - } -} +#pragma endregion // Helper API -#undef _sz_shift_low -#undef _sz_shift_high -#undef _sz_hash_mix -#undef _sz_prime_mod +#pragma region Serial Implementation /** * @brief Uses a small lookup-table to convert a lowercase character to uppercase. @@ -3128,52 +212,6 @@ SZ_INTERNAL sz_u8_t sz_u8_toupper(sz_u8_t c) { return upped[c]; } -/** - * @brief Uses two small lookup tables (768 bytes total) to accelerate division by a small - * unsigned integer. Performs two lookups, one multiplication, two shifts, and two accumulations. 
- * - * @param divisor Integral value @b larger than one. - * @param number Integral value to divide. - */ -SZ_INTERNAL sz_u8_t sz_u8_divide(sz_u8_t number, sz_u8_t divisor) { - sz_assert(divisor > 1); - static sz_u16_t const multipliers[256] = { - 0, 0, 0, 21846, 0, 39322, 21846, 9363, 0, 50973, 39322, 29790, 21846, 15124, 9363, 4370, - 0, 57826, 50973, 44841, 39322, 34329, 29790, 25645, 21846, 18351, 15124, 12137, 9363, 6780, 4370, 2115, - 0, 61565, 57826, 54302, 50973, 47824, 44841, 42011, 39322, 36765, 34329, 32006, 29790, 27671, 25645, 23705, - 21846, 20063, 18351, 16706, 15124, 13602, 12137, 10725, 9363, 8049, 6780, 5554, 4370, 3224, 2115, 1041, - 0, 63520, 61565, 59668, 57826, 56039, 54302, 52614, 50973, 49377, 47824, 46313, 44841, 43407, 42011, 40649, - 39322, 38028, 36765, 35532, 34329, 33154, 32006, 30885, 29790, 28719, 27671, 26647, 25645, 24665, 23705, 22766, - 21846, 20945, 20063, 19198, 18351, 17520, 16706, 15907, 15124, 14356, 13602, 12863, 12137, 11424, 10725, 10038, - 9363, 8700, 8049, 7409, 6780, 6162, 5554, 4957, 4370, 3792, 3224, 2665, 2115, 1573, 1041, 517, - 0, 64520, 63520, 62535, 61565, 60609, 59668, 58740, 57826, 56926, 56039, 55164, 54302, 53452, 52614, 51788, - 50973, 50169, 49377, 48595, 47824, 47063, 46313, 45572, 44841, 44120, 43407, 42705, 42011, 41326, 40649, 39982, - 39322, 38671, 38028, 37392, 36765, 36145, 35532, 34927, 34329, 33738, 33154, 32577, 32006, 31443, 30885, 30334, - 29790, 29251, 28719, 28192, 27671, 27156, 26647, 26143, 25645, 25152, 24665, 24182, 23705, 23233, 22766, 22303, - 21846, 21393, 20945, 20502, 20063, 19628, 19198, 18772, 18351, 17933, 17520, 17111, 16706, 16305, 15907, 15514, - 15124, 14738, 14356, 13977, 13602, 13231, 12863, 12498, 12137, 11779, 11424, 11073, 10725, 10380, 10038, 9699, - 9363, 9030, 8700, 8373, 8049, 7727, 7409, 7093, 6780, 6470, 6162, 5857, 5554, 5254, 4957, 4662, - 4370, 4080, 3792, 3507, 3224, 2943, 2665, 2388, 2115, 1843, 1573, 1306, 1041, 778, 517, 258, - }; - // This table can be avoided using a single addition and counting trailing zeros. - static sz_u8_t const shifts[256] = { - 0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, // - 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, // - 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, // - 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, // - 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // - 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // - 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // - 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // - }; - sz_u32_t multiplier = multipliers[divisor]; - sz_u8_t shift = shifts[divisor]; - - sz_u16_t q = (sz_u16_t)((multiplier * number) >> 16); - sz_u16_t t = ((number - q) >> 1) + q; - return (sz_u8_t)(t >> shift); -} - SZ_PUBLIC void sz_look_up_transform_serial(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result) { sz_u8_t const *unsigned_lut = (sz_u8_t const *)lut; sz_u8_t const *unsigned_text = (sz_u8_t const *)text; @@ -3216,280 +254,24 @@ SZ_PUBLIC sz_bool_t sz_isascii_serial(sz_cptr_t text, sz_size_t length) { #if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to void UB on unaligned 64-bit loads. 
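/* As a standalone usage sketch of the look-up-table transform documented in the Helper API above:
 * build a 256-byte table once, remap only 'A'..'Z', and apply it byte by byte. The helper name and
 * the on-stack table are illustrative assumptions, not part of this header. */
#include <stddef.h>

static void ascii_tolower_with_lut(char const *text, size_t length, char *result) {
    unsigned char lut[256];
    for (int i = 0; i != 256; ++i) lut[i] = (unsigned char)i;
    for (int i = 'A'; i <= 'Z'; ++i) lut[i] = (unsigned char)(i + 32); /* Flip the 5th bit: 'A' (65) -> 'a' (97). */
    for (size_t i = 0; i != length; ++i) result[i] = (char)lut[(unsigned char)text[i]];
}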
for (; ((sz_size_t)h & 7ull) && h < h_end; ++h) - if (*h & 0x80ull) return sz_false_k; -#endif - - // Validate eight bytes at once using SWAR. - sz_u64_vec_t text_vec; - for (; h + 8 <= h_end; h += 8) { - text_vec.u64 = *(sz_u64_t const *)h; - if (text_vec.u64 & 0x8080808080808080ull) return sz_false_k; - } - - // Handle the misaligned tail. - for (; h < h_end; ++h) - if (*h & 0x80ull) return sz_false_k; - return sz_true_k; -} - -SZ_PUBLIC void sz_generate_serial(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length, - sz_random_generator_t generator, void *generator_user_data) { - - sz_assert(alphabet_size > 0 && alphabet_size <= 256 && "Inadequate alphabet size"); - - if (alphabet_size == 1) sz_fill(result, result_length, *alphabet); - - else { - sz_assert(generator && "Expects a valid random generator"); - sz_u8_t divisor = (sz_u8_t)alphabet_size; - for (sz_cptr_t end = result + result_length; result != end; ++result) { - sz_u8_t random = generator(generator_user_data) & 0xFF; - sz_u8_t quotient = sz_u8_divide(random, divisor); - *result = alphabet[random - quotient * divisor]; - } - } -} - -#pragma endregion - -/* - * Serial implementation of string class operations. - */ -#pragma region Serial Implementation for the String Class - -SZ_PUBLIC sz_bool_t sz_string_is_on_stack(sz_string_t const *string) { - // It doesn't matter if it's on stack or heap, the pointer location is the same. - return (sz_bool_t)((sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0]); -} - -SZ_PUBLIC void sz_string_range(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length) { - sz_size_t is_small = (sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0]; - sz_size_t is_big_mask = is_small - 1ull; - *start = string->external.start; // It doesn't matter if it's on stack or heap, the pointer location is the same. - // If the string is small, use branch-less approach to mask-out the top 7 bytes of the length. - *length = string->external.length & (0x00000000000000FFull | is_big_mask); -} - -SZ_PUBLIC void sz_string_unpack(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length, sz_size_t *space, - sz_bool_t *is_external) { - sz_size_t is_small = (sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0]; - sz_size_t is_big_mask = is_small - 1ull; - *start = string->external.start; // It doesn't matter if it's on stack or heap, the pointer location is the same. - // If the string is small, use branch-less approach to mask-out the top 7 bytes of the length. - *length = string->external.length & (0x00000000000000FFull | is_big_mask); - // In case the string is small, the `is_small - 1ull` will become 0xFFFFFFFFFFFFFFFFull. - *space = sz_u64_blend(SZ_STRING_INTERNAL_SPACE, string->external.space, is_big_mask); - *is_external = (sz_bool_t)!is_small; -} - -SZ_PUBLIC sz_bool_t sz_string_equal(sz_string_t const *a, sz_string_t const *b) { - // Tempting to say that the external.length is bitwise the same even if it includes - // some bytes of the on-stack payload, but we don't at this writing maintain that invariant. - // (An on-stack string includes noise bytes in the high-order bits of external.length. So do this - // the hard/correct way. - -#if SZ_USE_MISALIGNED_LOADS - // Dealing with StringZilla strings, we know that the `start` pointer always points - // to a word at least 8 bytes long. Therefore, we can compare the first 8 bytes at once. - -#endif - // Alternatively, fall back to byte-by-byte comparison. 
- sz_ptr_t a_start, b_start; - sz_size_t a_length, b_length; - sz_string_range(a, &a_start, &a_length); - sz_string_range(b, &b_start, &b_length); - return (sz_bool_t)(a_length == b_length && sz_equal(a_start, b_start, b_length)); -} - -SZ_PUBLIC sz_ordering_t sz_string_order(sz_string_t const *a, sz_string_t const *b) { -#if SZ_USE_MISALIGNED_LOADS - // Dealing with StringZilla strings, we know that the `start` pointer always points - // to a word at least 8 bytes long. Therefore, we can compare the first 8 bytes at once. - -#endif - // Alternatively, fall back to byte-by-byte comparison. - sz_ptr_t a_start, b_start; - sz_size_t a_length, b_length; - sz_string_range(a, &a_start, &a_length); - sz_string_range(b, &b_start, &b_length); - return sz_order(a_start, a_length, b_start, b_length); -} - -SZ_PUBLIC void sz_string_init(sz_string_t *string) { - sz_assert(string && "String can't be SZ_NULL."); - - // Only 8 + 1 + 1 need to be initialized. - string->internal.start = &string->internal.chars[0]; - // But for safety let's initialize the entire structure to zeros. - // string->internal.chars[0] = 0; - // string->internal.length = 0; - string->words[1] = 0; - string->words[2] = 0; - string->words[3] = 0; -} - -SZ_PUBLIC sz_ptr_t sz_string_init_length(sz_string_t *string, sz_size_t length, sz_memory_allocator_t *allocator) { - sz_size_t space_needed = length + 1; // space for trailing \0 - sz_assert(string && allocator && "String and allocator can't be SZ_NULL."); - // Initialize the string to zeros for safety. - string->words[1] = 0; - string->words[2] = 0; - string->words[3] = 0; - // If we are lucky, no memory allocations will be needed. - if (space_needed <= SZ_STRING_INTERNAL_SPACE) { - string->internal.start = &string->internal.chars[0]; - string->internal.length = (sz_u8_t)length; - } - else { - // If we are not lucky, we need to allocate memory. - string->external.start = (sz_ptr_t)allocator->allocate(space_needed, allocator->handle); - if (!string->external.start) return SZ_NULL_CHAR; - string->external.length = length; - string->external.space = space_needed; - } - sz_assert(&string->internal.start == &string->external.start && "Alignment confusion"); - string->external.start[length] = 0; - return string->external.start; -} - -SZ_PUBLIC sz_ptr_t sz_string_reserve(sz_string_t *string, sz_size_t new_capacity, sz_memory_allocator_t *allocator) { - - sz_assert(string && allocator && "Strings and allocators can't be SZ_NULL."); - - sz_size_t new_space = new_capacity + 1; - if (new_space <= SZ_STRING_INTERNAL_SPACE) return string->external.start; - - sz_ptr_t string_start; - sz_size_t string_length; - sz_size_t string_space; - sz_bool_t string_is_external; - sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - sz_assert(new_space > string_space && "New space must be larger than current."); - - sz_ptr_t new_start = (sz_ptr_t)allocator->allocate(new_space, allocator->handle); - if (!new_start) return SZ_NULL_CHAR; - - sz_copy(new_start, string_start, string_length); - string->external.start = new_start; - string->external.space = new_space; - string->external.padding = 0; - string->external.length = string_length; - - // Deallocate the old string. 
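/* A small standalone demonstration of the branch-free length extraction used by `sz_string_range` above:
 * `is_small - 1` becomes an all-zeros or all-ones mask, so one AND picks either the low length byte of
 * the on-stack layout or the full 64-bit length of the heap layout. Values and names are illustrative only. */
#include <stdint.h>
#include <stdio.h>

static void show_branchless_length(void) {
    uint64_t on_stack_word = 0xDEADBEEF00000007ull; /* Only the low byte (7) is the length. */
    uint64_t on_heap_word = 0x0000000000001234ull;  /* The whole word is the length. */
    for (int is_small = 1; is_small >= 0; --is_small) {
        uint64_t is_big_mask = (uint64_t)is_small - 1ull; /* 0 if small, 0xFFFF...FF if big. */
        uint64_t word = is_small ? on_stack_word : on_heap_word;
        uint64_t length = word & (0xFFull | is_big_mask); /* 7 for the small case, 0x1234 for the big one. */
        printf("is_small=%d -> length=0x%llx\n", is_small, (unsigned long long)length);
    }
}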
- if (string_is_external) allocator->free(string_start, string_space, allocator->handle); - return string->external.start; -} - -SZ_PUBLIC sz_ptr_t sz_string_shrink_to_fit(sz_string_t *string, sz_memory_allocator_t *allocator) { - - sz_assert(string && allocator && "Strings and allocators can't be SZ_NULL."); - - sz_ptr_t string_start; - sz_size_t string_length; - sz_size_t string_space; - sz_bool_t string_is_external; - sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - - // We may already be space-optimal, and in that case we don't need to do anything. - sz_size_t new_space = string_length + 1; - if (string_space == new_space || !string_is_external) return string->external.start; - - sz_ptr_t new_start = (sz_ptr_t)allocator->allocate(new_space, allocator->handle); - if (!new_start) return SZ_NULL_CHAR; - - sz_copy(new_start, string_start, string_length); - string->external.start = new_start; - string->external.space = new_space; - string->external.padding = 0; - string->external.length = string_length; - - // Deallocate the old string. - if (string_is_external) allocator->free(string_start, string_space, allocator->handle); - return string->external.start; -} - -SZ_PUBLIC sz_ptr_t sz_string_expand(sz_string_t *string, sz_size_t offset, sz_size_t added_length, - sz_memory_allocator_t *allocator) { - - sz_assert(string && allocator && "String and allocator can't be SZ_NULL."); - - sz_ptr_t string_start; - sz_size_t string_length; - sz_size_t string_space; - sz_bool_t string_is_external; - sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - - // The user intended to extend the string. - offset = sz_min_of_two(offset, string_length); - - // If we are lucky, no memory allocations will be needed. - if (string_length + added_length < string_space) { - sz_move(string_start + offset + added_length, string_start + offset, string_length - offset); - string_start[string_length + added_length] = 0; - // Even if the string is on the stack, the `+=` won't affect the tail of the string. - string->external.length += added_length; - } - // If we are not lucky, we need to allocate more memory. - else { - sz_size_t next_planned_size = sz_max_of_two(SZ_CACHE_LINE_WIDTH, string_space * 2ull); - sz_size_t min_needed_space = sz_size_bit_ceil(offset + string_length + added_length + 1); - sz_size_t new_space = sz_max_of_two(min_needed_space, next_planned_size); - string_start = sz_string_reserve(string, new_space - 1, allocator); - if (!string_start) return SZ_NULL_CHAR; - - // Copy into the new buffer. - sz_move(string_start + offset + added_length, string_start + offset, string_length - offset); - string_start[string_length + added_length] = 0; - string->external.length = string_length + added_length; - } - - return string_start; -} - -SZ_PUBLIC sz_size_t sz_string_erase(sz_string_t *string, sz_size_t offset, sz_size_t length) { - - sz_assert(string && "String can't be SZ_NULL."); - - sz_ptr_t string_start; - sz_size_t string_length; - sz_size_t string_space; - sz_bool_t string_is_external; - sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - - // Normalize the offset, it can't be larger than the length. - offset = sz_min_of_two(offset, string_length); - - // We shouldn't normalize the length, to avoid overflowing on `offset + length >= string_length`, - // if receiving `length == SZ_SIZE_MAX`. 
After following expression the `length` will contain - // exactly the delta between original and final length of this `string`. - length = sz_min_of_two(length, string_length - offset); - - // There are 2 common cases, that wouldn't even require a `memmove`: - // 1. Erasing the entire contents of the string. - // In that case `length` argument will be equal or greater than `length` member. - // 2. Removing the tail of the string with something like `string.pop_back()` in C++. - // - // In both of those, regardless of the location of the string - stack or heap, - // the erasing is as easy as setting the length to the offset. - // In every other case, we must `memmove` the tail of the string to the left. - if (offset + length < string_length) - sz_move(string_start + offset, string_start + offset + length, string_length - offset - length); + if (*h & 0x80ull) return sz_false_k; +#endif - // The `string->external.length = offset` assignment would discard last characters - // of the on-the-stack string, but inplace subtraction would work. - string->external.length -= length; - string_start[string_length - length] = 0; - return length; -} + // Validate eight bytes at once using SWAR. + sz_u64_vec_t text_vec; + for (; h + 8 <= h_end; h += 8) { + text_vec.u64 = *(sz_u64_t const *)h; + if (text_vec.u64 & 0x8080808080808080ull) return sz_false_k; + } -SZ_PUBLIC void sz_string_free(sz_string_t *string, sz_memory_allocator_t *allocator) { - if (!sz_string_is_on_stack(string)) - allocator->free(string->external.start, string->external.space, allocator->handle); - sz_string_init(string); + // Handle the misaligned tail. + for (; h < h_end; ++h) + if (*h & 0x80ull) return sz_false_k; + return sz_true_k; } -// When overriding libc, disable optimisations for this function beacuse MSVC will optimize the loops into a memset. -// Which then causes a stack overflow due to infinite recursion (memset -> sz_fill_serial -> memset). +// When overriding libc, disable optimizations for this function because MSVC will optimize the loops into a `memset`. +// Which then causes a stack overflow due to infinite recursion (`memset` -> `sz_fill_serial` -> `memset`). #if defined(_MSC_VER) && defined(SZ_OVERRIDE_LIBC) && SZ_OVERRIDE_LIBC #pragma optimize("", off) #endif @@ -3548,338 +330,17 @@ SZ_PUBLIC void sz_move_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t lengt #pragma endregion -/* - * @brief Serial implementation for strings sequence processing. 
- */ -#pragma region Serial Implementation for Sequences - -SZ_PUBLIC sz_size_t sz_partition(sz_sequence_t *sequence, sz_sequence_predicate_t predicate) { - - sz_size_t matches = 0; - while (matches != sequence->count && predicate(sequence, sequence->order[matches])) ++matches; - - for (sz_size_t i = matches + 1; i < sequence->count; ++i) - if (predicate(sequence, sequence->order[i])) - sz_u64_swap(sequence->order + i, sequence->order + matches), ++matches; - - return matches; -} - -SZ_PUBLIC void sz_merge(sz_sequence_t *sequence, sz_size_t partition, sz_sequence_comparator_t less) { - - sz_size_t start_b = partition + 1; - - // If the direct merge is already sorted - if (!less(sequence, sequence->order[start_b], sequence->order[partition])) return; - - sz_size_t start_a = 0; - while (start_a <= partition && start_b <= sequence->count) { - - // If element 1 is in right place - if (!less(sequence, sequence->order[start_b], sequence->order[start_a])) { start_a++; } - else { - sz_size_t value = sequence->order[start_b]; - sz_size_t index = start_b; - - // Shift all the elements between element 1 - // element 2, right by 1. - while (index != start_a) { sequence->order[index] = sequence->order[index - 1], index--; } - sequence->order[start_a] = value; - - // Update all the pointers - start_a++; - partition++; - start_b++; - } - } -} - -SZ_PUBLIC void sz_sort_insertion(sz_sequence_t *sequence, sz_sequence_comparator_t less) { - sz_u64_t *keys = sequence->order; - sz_size_t keys_count = sequence->count; - for (sz_size_t i = 1; i < keys_count; i++) { - sz_u64_t i_key = keys[i]; - sz_size_t j = i; - for (; j > 0 && less(sequence, i_key, keys[j - 1]); --j) keys[j] = keys[j - 1]; - keys[j] = i_key; - } -} - -SZ_INTERNAL void _sz_sift_down(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_u64_t *order, sz_size_t start, - sz_size_t end) { - sz_size_t root = start; - while (2 * root + 1 <= end) { - sz_size_t child = 2 * root + 1; - if (child + 1 <= end && less(sequence, order[child], order[child + 1])) { child++; } - if (!less(sequence, order[root], order[child])) { return; } - sz_u64_swap(order + root, order + child); - root = child; - } -} - -SZ_INTERNAL void _sz_heapify(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_u64_t *order, sz_size_t count) { - sz_size_t start = (count - 2) / 2; - while (1) { - _sz_sift_down(sequence, less, order, start, count - 1); - if (start == 0) return; - start--; - } -} - -SZ_INTERNAL void _sz_heapsort(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_size_t first, sz_size_t last) { - sz_u64_t *order = sequence->order; - sz_size_t count = last - first; - _sz_heapify(sequence, less, order + first, count); - sz_size_t end = count - 1; - while (end > 0) { - sz_u64_swap(order + first, order + first + end); - end--; - _sz_sift_down(sequence, less, order + first, 0, end); - } -} - -SZ_PUBLIC void sz_sort_introsort_recursion(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_size_t first, - sz_size_t last, sz_size_t depth) { - - sz_size_t length = last - first; - switch (length) { - case 0: - case 1: return; - case 2: - if (less(sequence, sequence->order[first + 1], sequence->order[first])) - sz_u64_swap(&sequence->order[first], &sequence->order[first + 1]); - return; - case 3: { - sz_u64_t a = sequence->order[first]; - sz_u64_t b = sequence->order[first + 1]; - sz_u64_t c = sequence->order[first + 2]; - if (less(sequence, b, a)) sz_u64_swap(&a, &b); - if (less(sequence, c, b)) sz_u64_swap(&c, &b); - if (less(sequence, b, a)) sz_u64_swap(&a, 
&b); - sequence->order[first] = a; - sequence->order[first + 1] = b; - sequence->order[first + 2] = c; - return; - } - } - // Until a certain length, the quadratic-complexity insertion-sort is fine - if (length <= 16) { - sz_sequence_t sub_seq = *sequence; - sub_seq.order += first; - sub_seq.count = length; - sz_sort_insertion(&sub_seq, less); - return; - } - - // Fallback to N-logN-complexity heap-sort - if (depth == 0) { - _sz_heapsort(sequence, less, first, last); - return; - } - - --depth; - - // Median-of-three logic to choose pivot - sz_size_t median = first + length / 2; - if (less(sequence, sequence->order[median], sequence->order[first])) - sz_u64_swap(&sequence->order[first], &sequence->order[median]); - if (less(sequence, sequence->order[last - 1], sequence->order[first])) - sz_u64_swap(&sequence->order[first], &sequence->order[last - 1]); - if (less(sequence, sequence->order[median], sequence->order[last - 1])) - sz_u64_swap(&sequence->order[median], &sequence->order[last - 1]); - - // Partition using the median-of-three as the pivot - sz_u64_t pivot = sequence->order[median]; - sz_size_t left = first; - sz_size_t right = last - 1; - while (1) { - while (less(sequence, sequence->order[left], pivot)) left++; - while (less(sequence, pivot, sequence->order[right])) right--; - if (left >= right) break; - sz_u64_swap(&sequence->order[left], &sequence->order[right]); - left++; - right--; - } - - // Recursively sort the partitions - sz_sort_introsort_recursion(sequence, less, first, left, depth); - sz_sort_introsort_recursion(sequence, less, right + 1, last, depth); -} - -SZ_PUBLIC void sz_sort_introsort(sz_sequence_t *sequence, sz_sequence_comparator_t less) { - if (sequence->count == 0) return; - sz_size_t size_is_not_power_of_two = (sequence->count & (sequence->count - 1)) != 0; - sz_size_t depth_limit = sz_size_log2i_nonzero(sequence->count) + size_is_not_power_of_two; - sz_sort_introsort_recursion(sequence, less, 0, sequence->count, depth_limit); -} - -SZ_PUBLIC void sz_sort_recursion( // - sz_sequence_t *sequence, sz_size_t bit_idx, sz_size_t bit_max, sz_sequence_comparator_t comparator, - sz_size_t partial_order_length) { - - if (!sequence->count) return; - - // Array of size one doesn't need sorting - only needs the prefix to be discarded. - if (sequence->count == 1) { - sz_u32_t *order_half_words = (sz_u32_t *)sequence->order; - order_half_words[1] = 0; - return; - } - - // Partition a range of integers according to a specific bit value - sz_size_t split = 0; - sz_u64_t mask = (1ull << 63) >> bit_idx; - - // The clean approach would be to perform a single pass over the sequence. - // - // while (split != sequence->count && !(sequence->order[split] & mask)) ++split; - // for (sz_size_t i = split + 1; i < sequence->count; ++i) - // if (!(sequence->order[i] & mask)) sz_u64_swap(sequence->order + i, sequence->order + split), ++split; - // - // This, however, doesn't take into account the high relative cost of writes and swaps. - // To circumvent that, we can first count the total number entries to be mapped into either part. - // And then walk through both parts, swapping the entries that are in the wrong part. - // This would often lead to ~15% performance gain. - sz_size_t count_with_bit_set = 0; - for (sz_size_t i = 0; i != sequence->count; ++i) count_with_bit_set += (sequence->order[i] & mask) != 0; - split = sequence->count - count_with_bit_set; - - // It's possible that the sequence is already partitioned. 
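/* For clarity, the same count-first partitioning on a plain array of 64-bit keys, as a standalone sketch:
 * count how many keys have the bit set, derive the split point, then walk from both ends and swap only
 * the mispositioned entries. The helper name is an illustrative assumption, not part of this header. */
#include <stddef.h>
#include <stdint.h>

static size_t partition_by_bit(uint64_t *keys, size_t count, uint64_t mask) {
    size_t with_bit_set = 0;
    for (size_t i = 0; i != count; ++i) with_bit_set += (keys[i] & mask) != 0;
    size_t split = count - with_bit_set;
    if (split == 0 || split == count) return split; /* Already partitioned, nothing to move. */
    size_t left = 0, right = count - 1;
    while (1) {
        while (left < split && !(keys[left] & mask)) ++left;    /* Already in the cleared-bit half. */
        while (right >= split && (keys[right] & mask)) --right; /* Already in the set-bit half. */
        if (left < split && right >= split) {
            uint64_t temporary = keys[left];
            keys[left] = keys[right], keys[right] = temporary;
            ++left, --right;
        }
        else { break; }
    }
    return split;
}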
- if (split != 0 && split != sequence->count) { - // Use two pointers to efficiently reposition elements. - // On pointer walks left-to-right from the start, and the other walks right-to-left from the end. - sz_size_t left = 0; - sz_size_t right = sequence->count - 1; - while (1) { - // Find the next element with the bit set on the left side. - while (left < split && !(sequence->order[left] & mask)) ++left; - // Find the next element without the bit set on the right side. - while (right >= split && (sequence->order[right] & mask)) --right; - // Swap the mispositioned elements. - if (left < split && right >= split) { - sz_u64_swap(sequence->order + left, sequence->order + right); - ++left; - --right; - } - else { break; } - } - } - - // Go down recursively. - if (bit_idx < bit_max) { - sz_sequence_t a = *sequence; - a.count = split; - sz_sort_recursion(&a, bit_idx + 1, bit_max, comparator, partial_order_length); - - sz_sequence_t b = *sequence; - b.order += split; - b.count -= split; - sz_sort_recursion(&b, bit_idx + 1, bit_max, comparator, partial_order_length); - } - // Reached the end of recursion. - else { - // Discard the prefixes. - sz_u32_t *order_half_words = (sz_u32_t *)sequence->order; - for (sz_size_t i = 0; i != sequence->count; ++i) { order_half_words[i * 2 + 1] = 0; } - - sz_sequence_t a = *sequence; - a.count = split; - sz_sort_introsort(&a, comparator); - - sz_sequence_t b = *sequence; - b.order += split; - b.count -= split; - sz_sort_introsort(&b, comparator); - } -} - -SZ_INTERNAL sz_bool_t _sz_sort_is_less(sz_sequence_t *sequence, sz_size_t i_key, sz_size_t j_key) { - sz_cptr_t i_str = sequence->get_start(sequence, i_key); - sz_cptr_t j_str = sequence->get_start(sequence, j_key); - sz_size_t i_len = sequence->get_length(sequence, i_key); - sz_size_t j_len = sequence->get_length(sequence, j_key); - return (sz_bool_t)(sz_order_serial(i_str, i_len, j_str, j_len) == sz_less_k); -} - -SZ_PUBLIC void sz_sort_partial(sz_sequence_t *sequence, sz_size_t partial_order_length) { - -#if SZ_DETECT_BIG_ENDIAN - // TODO: Implement partial sort for big-endian systems. For now this sorts the whole thing. - sz_unused(partial_order_length); - sz_sort_introsort(sequence, (sz_sequence_comparator_t)_sz_sort_is_less); -#else - - // Export up to 4 bytes into the `sequence` bits themselves - for (sz_size_t i = 0; i != sequence->count; ++i) { - sz_cptr_t begin = sequence->get_start(sequence, sequence->order[i]); - sz_size_t length = sequence->get_length(sequence, sequence->order[i]); - length = length > 4u ? 4u : length; - sz_ptr_t prefix = (sz_ptr_t)&sequence->order[i]; - for (sz_size_t j = 0; j != length; ++j) prefix[7 - j] = begin[j]; - } - - // Perform optionally-parallel radix sort on them - sz_sort_recursion(sequence, 0, 32, (sz_sequence_comparator_t)_sz_sort_is_less, partial_order_length); -#endif -} - -SZ_PUBLIC void sz_sort(sz_sequence_t *sequence) { -#if SZ_DETECT_BIG_ENDIAN - sz_sort_introsort(sequence, (sz_sequence_comparator_t)_sz_sort_is_less); -#else - sz_sort_partial(sequence, sequence->count); -#endif -} - -#pragma endregion - -/* - * @brief AVX2 implementation of the string search algorithms. - * Very minimalistic, but still faster than the serial implementation. +/* AVX2 implementation of the string search algorithms for Haswell processors and newer. + * Very minimalistic (compared to AVX-512), but still faster than the serial implementation. 
*/ -#pragma region AVX2 Implementation +#pragma region Haswell Implementation -#if SZ_USE_X86_AVX2 +#if SZ_USE_HASWELL #pragma GCC push_options -#pragma GCC target("avx2") -#pragma clang attribute push(__attribute__((target("avx2"))), apply_to = function) -#include - -/** - * @brief Helper structure to simplify work with 256-bit registers. - */ -typedef union sz_u256_vec_t { - __m256i ymm; - __m128i xmms[2]; - sz_u64_t u64s[4]; - sz_u32_t u32s[8]; - sz_u16_t u16s[16]; - sz_u8_t u8s[32]; -} sz_u256_vec_t; - -SZ_PUBLIC sz_ordering_t sz_order_avx2(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { - //! Before optimizing this, read the "Operations Not Worth Optimizing" in Contributions Guide: - //! https://github.com/ashvardanian/StringZilla/blob/main/CONTRIBUTING.md#general-performance-observations - return sz_order_serial(a, a_length, b, b_length); -} +#pragma GCC target("haswell") +#pragma clang attribute push(__attribute__((target("haswell"))), apply_to = function) -SZ_PUBLIC sz_bool_t sz_equal_avx2(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - sz_u256_vec_t a_vec, b_vec; - - while (length >= 32) { - a_vec.ymm = _mm256_lddqu_si256((__m256i const *)a); - b_vec.ymm = _mm256_lddqu_si256((__m256i const *)b); - // One approach can be to use "movemasks", but we could also use a bitwise matching like `_mm256_testnzc_si256`. - int difference_mask = ~_mm256_movemask_epi8(_mm256_cmpeq_epi8(a_vec.ymm, b_vec.ymm)); - if (difference_mask == 0) { a += 32, b += 32, length -= 32; } - else { return sz_false_k; } - } - - if (length) return sz_equal_serial(a, b, length); - return sz_true_k; -} - -SZ_PUBLIC void sz_fill_avx2(sz_ptr_t target, sz_size_t length, sz_u8_t value) { +SZ_PUBLIC void sz_fill_haswell(sz_ptr_t target, sz_size_t length, sz_u8_t value) { char value_char = *(char *)&value; __m256i value_vec = _mm256_set1_epi8(value_char); // The naive implementation of this function is very simple. @@ -3935,7 +396,7 @@ SZ_PUBLIC void sz_copy_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length) // For now, let's avoid the cases beyond the L2 size. int is_huge = length > 1ull * 1024ull * 1024ull; if (length <= 32) { sz_copy_serial(target, source, length); } - // When dealing wirh larger arrays, the optimization is not as simple as with the `sz_fill_avx2` function, + // When dealing wirh larger arrays, the optimization is not as simple as with the `sz_fill_haswell` function, // as both buffers may be unaligned. If we are lucky and the requested operation is some huge page transfer, // we can use aligned loads and stores, and the performance will be great. else if ((sz_size_t)target % 32 == 0 && (sz_size_t)source % 32 == 0 && !is_huge) { @@ -4002,86 +463,6 @@ SZ_PUBLIC void sz_move_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length) } } -SZ_PUBLIC sz_u64_t sz_checksum_avx2(sz_cptr_t text, sz_size_t length) { - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "loads". - // - // A typical AWS Skylake instance can have 32 KB x 2 blocks of L1 data cache per core, - // 1 MB x 2 blocks of L2 cache per core, and one shared L3 cache buffer. - // For now, let's avoid the cases beyond the L2 size. - int is_huge = length > 1ull * 1024ull * 1024ull; - - // When the buffer is small, there isn't much to innovate. 
- if (length <= 32) { return sz_checksum_serial(text, length); } - else if (!is_huge) { - sz_u256_vec_t text_vec, sums_vec; - sums_vec.ymm = _mm256_setzero_si256(); - for (; length >= 32; text += 32, length -= 32) { - text_vec.ymm = _mm256_lddqu_si256((__m256i const *)text); - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); - } - // Accumulating 256 bits is harders, as we need to extract the 128-bit sums first. - __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymm); - __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymm, 1); - __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm); - sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm); - sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1); - sz_u64_t result = low + high; - if (length) result += sz_checksum_serial(text, length); - return result; - } - // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use. - // Most notably, we can avoid populating the cache with the entire buffer, and instead traverse it in 2 directions. - else { - sz_size_t head_length = (32 - ((sz_size_t)text % 32)) % 32; // 31 or less. - sz_size_t tail_length = (sz_size_t)(text + length) % 32; // 31 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 32. - sz_u64_t result = 0; - - // Handle the head - while (head_length--) result += *text++; - - sz_u256_vec_t text_vec, sums_vec; - sums_vec.ymm = _mm256_setzero_si256(); - // Fill the aligned body of the buffer. - if (!is_huge) { - for (; body_length >= 32; text += 32, body_length -= 32) { - text_vec.ymm = _mm256_stream_load_si256((__m256i const *)text); - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); - } - } - // When the biffer is huge, we can traverse it in 2 directions. - else { - sz_u256_vec_t text_reversed_vec, sums_reversed_vec; - sums_reversed_vec.ymm = _mm256_setzero_si256(); - for (; body_length >= 64; text += 64, body_length -= 64) { - text_vec.ymm = _mm256_stream_load_si256((__m256i *)(text)); - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); - text_reversed_vec.ymm = _mm256_stream_load_si256((__m256i *)(text + body_length - 64)); - sums_reversed_vec.ymm = _mm256_add_epi64( - sums_reversed_vec.ymm, _mm256_sad_epu8(text_reversed_vec.ymm, _mm256_setzero_si256())); - } - if (body_length >= 32) { - text_vec.ymm = _mm256_stream_load_si256((__m256i *)(text)); - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); - } - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, sums_reversed_vec.ymm); - } - - // Handle the tail - while (tail_length--) result += *text++; - - // Accumulating 256 bits is harders, as we need to extract the 128-bit sums first. 
- __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymm); - __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymm, 1); - __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm); - sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm); - sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1); - result += low + high; - return result; - } -} - SZ_PUBLIC void sz_look_up_transform_avx2(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { // If the input is tiny (especially smaller than the look-up table itself), we may end up paying @@ -4218,503 +599,24 @@ SZ_PUBLIC void sz_look_up_transform_avx2(sz_cptr_t source, sz_size_t length, sz_ if (length) sz_look_up_transform_serial(source, length, lut, target); } -SZ_PUBLIC sz_cptr_t sz_find_byte_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - int mask; - sz_u256_vec_t h_vec, n_vec; - n_vec.ymm = _mm256_set1_epi8(n[0]); - - while (h_length >= 32) { - h_vec.ymm = _mm256_lddqu_si256((__m256i const *)h); - mask = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_vec.ymm, n_vec.ymm)); - if (mask) return h + sz_u32_ctz(mask); - h += 32, h_length -= 32; - } - - return sz_find_byte_serial(h, h_length, n); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - int mask; - sz_u256_vec_t h_vec, n_vec; - n_vec.ymm = _mm256_set1_epi8(n[0]); - - while (h_length >= 32) { - h_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + h_length - 32)); - mask = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_vec.ymm, n_vec.ymm)); - if (mask) return h + h_length - 1 - sz_u32_clz(mask); - h_length -= 32; - } - - return sz_rfind_byte_serial(h, h_length, n); -} - -SZ_PUBLIC sz_cptr_t sz_find_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_find_byte_avx2(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into YMM registers. - int matches; - sz_u256_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; - n_first_vec.ymm = _mm256_set1_epi8(n[offset_first]); - n_mid_vec.ymm = _mm256_set1_epi8(n[offset_mid]); - n_last_vec.ymm = _mm256_set1_epi8(n[offset_last]); - - // Scan through the string. - for (; h_length >= n_length + 32; h += 32, h_length -= 32) { - h_first_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_first)); - h_mid_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_mid)); - h_last_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_last)); - matches = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_first_vec.ymm, n_first_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_mid_vec.ymm, n_mid_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_last_vec.ymm, n_last_vec.ymm)); - while (matches) { - int potential_offset = sz_u32_ctz(matches); - if (sz_equal(h + potential_offset, n, n_length)) return h + potential_offset; - matches &= matches - 1; - } - } - - return sz_find_serial(h, h_length, n, n_length); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. 
- if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_rfind_byte_avx2(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into YMM registers. - int matches; - sz_u256_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; - n_first_vec.ymm = _mm256_set1_epi8(n[offset_first]); - n_mid_vec.ymm = _mm256_set1_epi8(n[offset_mid]); - n_last_vec.ymm = _mm256_set1_epi8(n[offset_last]); - - // Scan through the string. - sz_cptr_t h_reversed; - for (; h_length >= n_length + 32; h_length -= 32) { - h_reversed = h + h_length - n_length - 32 + 1; - h_first_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_first)); - h_mid_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_mid)); - h_last_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_last)); - matches = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_first_vec.ymm, n_first_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_mid_vec.ymm, n_mid_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_last_vec.ymm, n_last_vec.ymm)); - while (matches) { - int potential_offset = sz_u32_clz(matches); - if (sz_equal(h + h_length - n_length - potential_offset, n, n_length)) - return h + h_length - n_length - potential_offset; - matches &= ~(1 << (31 - potential_offset)); - } - } - - return sz_rfind_serial(h, h_length, n, n_length); -} - -SZ_PUBLIC sz_cptr_t sz_find_charset_avx2(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { - - // Let's unzip even and odd elements and replicate them into both lanes of the YMM register. - // That way when we invoke `_mm256_shuffle_epi8` we can use the same mask for both lanes. - sz_u256_vec_t filter_even_vec, filter_odd_vec; - for (sz_size_t i = 0; i != 16; ++i) - filter_even_vec.u8s[i] = filter->_u8s[i * 2], filter_odd_vec.u8s[i] = filter->_u8s[i * 2 + 1]; - filter_even_vec.xmms[1] = filter_even_vec.xmms[0]; - filter_odd_vec.xmms[1] = filter_odd_vec.xmms[0]; - - sz_u256_vec_t text_vec; - sz_u256_vec_t matches_vec; - sz_u256_vec_t lower_nibbles_vec, higher_nibbles_vec; - sz_u256_vec_t bitset_even_vec, bitset_odd_vec; - sz_u256_vec_t bitmask_vec, bitmask_lookup_vec; - bitmask_lookup_vec.ymm = _mm256_set_epi8(-128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // - -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1); - - while (length >= 32) { - // The following algorithm is a transposed equivalent of the "SIMDized check which bytes are in a set" - // solutions by Wojciech Muła. We populate the bitmask differently and target newer CPUs, so - // StrinZilla uses a somewhat different approach. - // http://0x80.pl/articles/simd-byte-lookup.html#alternative-implementation-new - // - // sz_u8_t input = *(sz_u8_t const *)text; - // sz_u8_t lo_nibble = input & 0x0f; - // sz_u8_t hi_nibble = input >> 4; - // sz_u8_t bitset_even = filter_even_vec.u8s[hi_nibble]; - // sz_u8_t bitset_odd = filter_odd_vec.u8s[hi_nibble]; - // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); - // sz_u8_t bitset = lo_nibble < 8 ? bitset_even : bitset_odd; - // if ((bitset & bitmask) != 0) return text; - // else { length--, text++; } - // - // The nice part about this, loading the strided data is vey easy with Arm NEON, - // while with x86 CPUs after AVX, shuffles within 256 bits shouldn't be an issue either. 
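-        // The vectorized code below repeats that scalar check for 32 input bytes at once: the low nibble of every
-        // byte picks a bit via `_mm256_shuffle_epi8`, the high nibble picks the matching bitset byte, and a final
-        // `and` + `cmpeq` + `movemask` reports which of the 32 bytes belong to the set.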
- text_vec.ymm = _mm256_lddqu_si256((__m256i const *)text); - lower_nibbles_vec.ymm = _mm256_and_si256(text_vec.ymm, _mm256_set1_epi8(0x0f)); - bitmask_vec.ymm = _mm256_shuffle_epi8(bitmask_lookup_vec.ymm, lower_nibbles_vec.ymm); - // - // At this point we can validate the `bitmask_vec` contents like this: - // - // for (sz_size_t i = 0; i != 32; ++i) { - // sz_u8_t input = *(sz_u8_t const *)(text + i); - // sz_u8_t lo_nibble = input & 0x0f; - // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); - // sz_assert(bitmask_vec.u8s[i] == bitmask); - // } - // - // Shift right every byte by 4 bits. - // There is no `_mm256_srli_epi8` intrinsic, so we have to use `_mm256_srli_epi16` - // and combine it with a mask to clear the higher bits. - higher_nibbles_vec.ymm = _mm256_and_si256(_mm256_srli_epi16(text_vec.ymm, 4), _mm256_set1_epi8(0x0f)); - bitset_even_vec.ymm = _mm256_shuffle_epi8(filter_even_vec.ymm, higher_nibbles_vec.ymm); - bitset_odd_vec.ymm = _mm256_shuffle_epi8(filter_odd_vec.ymm, higher_nibbles_vec.ymm); - // - // At this point we can validate the `bitset_even_vec` and `bitset_odd_vec` contents like this: - // - // for (sz_size_t i = 0; i != 32; ++i) { - // sz_u8_t input = *(sz_u8_t const *)(text + i); - // sz_u8_t const *bitset_ptr = &filter->_u8s[0]; - // sz_u8_t hi_nibble = input >> 4; - // sz_u8_t bitset_even = bitset_ptr[hi_nibble * 2]; - // sz_u8_t bitset_odd = bitset_ptr[hi_nibble * 2 + 1]; - // sz_assert(bitset_even_vec.u8s[i] == bitset_even); - // sz_assert(bitset_odd_vec.u8s[i] == bitset_odd); - // } - // - __m256i take_first = _mm256_cmpgt_epi8(_mm256_set1_epi8(8), lower_nibbles_vec.ymm); - bitset_even_vec.ymm = _mm256_blendv_epi8(bitset_odd_vec.ymm, bitset_even_vec.ymm, take_first); - - // It would have been great to have an instruction that tests the bits and then broadcasts - // the matching bit into all bits in that byte. But we don't have that, so we have to - // `and`, `cmpeq`, `movemask`, and then invert at the end... - matches_vec.ymm = _mm256_and_si256(bitset_even_vec.ymm, bitmask_vec.ymm); - matches_vec.ymm = _mm256_cmpeq_epi8(matches_vec.ymm, _mm256_setzero_si256()); - int matches_mask = ~_mm256_movemask_epi8(matches_vec.ymm); - if (matches_mask) { - int offset = sz_u32_ctz(matches_mask); - return text + offset; - } - else { text += 32, length -= 32; } - } - - return sz_find_charset_serial(text, length, filter); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx2(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { - return sz_rfind_charset_serial(text, length, filter); -} - -/** - * @brief There is no AVX2 instruction for fast multiplication of 64-bit integers. - * This implementation is coming from Agner Fog's Vector Class Library. 
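- *          It assembles each 64-bit product from 32-bit halves: `_mm256_mul_epu32` provides the low-by-low
- *          terms, while the two cross terms come from `_mm256_mullo_epi32` and are added into the upper 32 bits.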
- */ -SZ_INTERNAL __m256i _mm256_mul_epu64(__m256i a, __m256i b) { - __m256i bswap = _mm256_shuffle_epi32(b, 0xB1); - __m256i prodlh = _mm256_mullo_epi32(a, bswap); - __m256i zero = _mm256_setzero_si256(); - __m256i prodlh2 = _mm256_hadd_epi32(prodlh, zero); - __m256i prodlh3 = _mm256_shuffle_epi32(prodlh2, 0x73); - __m256i prodll = _mm256_mul_epu32(a, b); - __m256i prod = _mm256_add_epi64(prodll, prodlh3); - return prod; -} - -SZ_PUBLIC void sz_hashes_avx2(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle) { - - if (length < window_length || !window_length) return; - if (length < 4 * window_length) { - sz_hashes_serial(start, length, window_length, step, callback, callback_handle); - return; - } - - // Using AVX2, we can perform 4 long integer multiplications and additions within one register. - // So let's slice the entire string into 4 overlapping windows, to slide over them in parallel. - sz_size_t const max_hashes = length - window_length + 1; - sz_size_t const min_hashes_per_thread = max_hashes / 4; // At most one sequence can overlap between 2 threads. - sz_u8_t const *text_first = (sz_u8_t const *)start; - sz_u8_t const *text_second = text_first + min_hashes_per_thread; - sz_u8_t const *text_third = text_first + min_hashes_per_thread * 2; - sz_u8_t const *text_fourth = text_first + min_hashes_per_thread * 3; - sz_u8_t const *text_end = text_first + length; - - // Prepare the `prime ^ window_length` values, that we are going to use for modulo arithmetic. - sz_u64_t prime_power_low = 1, prime_power_high = 1; - for (sz_size_t i = 0; i + 1 < window_length; ++i) - prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME, - prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME; - - // Broadcast the constants into the registers. - sz_u256_vec_t prime_vec, golden_ratio_vec; - sz_u256_vec_t base_low_vec, base_high_vec, prime_power_low_vec, prime_power_high_vec, shift_high_vec; - base_low_vec.ymm = _mm256_set1_epi64x(31ull); - base_high_vec.ymm = _mm256_set1_epi64x(257ull); - shift_high_vec.ymm = _mm256_set1_epi64x(77ull); - prime_vec.ymm = _mm256_set1_epi64x(SZ_U64_MAX_PRIME); - golden_ratio_vec.ymm = _mm256_set1_epi64x(11400714819323198485ull); - prime_power_low_vec.ymm = _mm256_set1_epi64x(prime_power_low); - prime_power_high_vec.ymm = _mm256_set1_epi64x(prime_power_high); - - // Compute the initial hash values for every one of the four windows. - sz_u256_vec_t hash_low_vec, hash_high_vec, hash_mix_vec, chars_low_vec, chars_high_vec; - hash_low_vec.ymm = _mm256_setzero_si256(); - hash_high_vec.ymm = _mm256_setzero_si256(); - for (sz_u8_t const *prefix_end = text_first + window_length; text_first < prefix_end; - ++text_first, ++text_second, ++text_third, ++text_fourth) { - - // 1. Multiply the hashes by the base. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, base_low_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, base_high_vec.ymm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. - chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); - - // 3. Add the incoming characters. 
- hash_low_vec.ymm = _mm256_add_epi64(hash_low_vec.ymm, chars_low_vec.ymm); - hash_high_vec.ymm = _mm256_add_epi64(hash_high_vec.ymm, chars_high_vec.ymm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_low_vec.ymm = _mm256_blendv_epi8(hash_low_vec.ymm, _mm256_sub_epi64(hash_low_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_low_vec.ymm, prime_vec.ymm)); - hash_high_vec.ymm = _mm256_blendv_epi8(hash_high_vec.ymm, _mm256_sub_epi64(hash_high_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_high_vec.ymm, prime_vec.ymm)); - } - - // 5. Compute the hash mix, that will be used to index into the fingerprint. - // This includes a serial step at the end. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, golden_ratio_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, golden_ratio_vec.ymm); - hash_mix_vec.ymm = _mm256_xor_si256(hash_low_vec.ymm, hash_high_vec.ymm); - callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); - callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); - callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); - callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); - - // Now repeat that operation for the remaining characters, discarding older characters. - sz_size_t cycle = 1; - sz_size_t const step_mask = step - 1; - for (; text_fourth != text_end; ++text_first, ++text_second, ++text_third, ++text_fourth, ++cycle) { - // 0. Load again the four characters we are dropping, shift them, and subtract. - chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[-window_length], text_third[-window_length], - text_second[-window_length], text_first[-window_length]); - chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); - hash_low_vec.ymm = - _mm256_sub_epi64(hash_low_vec.ymm, _mm256_mul_epu64(chars_low_vec.ymm, prime_power_low_vec.ymm)); - hash_high_vec.ymm = - _mm256_sub_epi64(hash_high_vec.ymm, _mm256_mul_epu64(chars_high_vec.ymm, prime_power_high_vec.ymm)); - - // 1. Multiply the hashes by the base. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, base_low_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, base_high_vec.ymm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. - chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); - - // 3. Add the incoming characters. - hash_low_vec.ymm = _mm256_add_epi64(hash_low_vec.ymm, chars_low_vec.ymm); - hash_high_vec.ymm = _mm256_add_epi64(hash_high_vec.ymm, chars_high_vec.ymm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_low_vec.ymm = _mm256_blendv_epi8(hash_low_vec.ymm, _mm256_sub_epi64(hash_low_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_low_vec.ymm, prime_vec.ymm)); - hash_high_vec.ymm = _mm256_blendv_epi8(hash_high_vec.ymm, _mm256_sub_epi64(hash_high_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_high_vec.ymm, prime_vec.ymm)); - - // 5. 
Compute the hash mix, that will be used to index into the fingerprint. - // This includes a serial step at the end. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, golden_ratio_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, golden_ratio_vec.ymm); - hash_mix_vec.ymm = _mm256_xor_si256(hash_low_vec.ymm, hash_high_vec.ymm); - if ((cycle & step_mask) == 0) { - callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); - callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); - callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); - callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); - } - } -} - #pragma clang attribute pop #pragma GCC pop_options -#endif -#pragma endregion - -/* - * @brief AVX-512 implementation of the string search algorithms. - * - * Different subsets of AVX-512 were introduced in different years: - * - 2017 SkyLake: F, CD, ER, PF, VL, DQ, BW - * - 2018 CannonLake: IFMA, VBMI - * - 2019 IceLake: VPOPCNTDQ, VNNI, VBMI2, BITALG, GFNI, VPCLMULQDQ, VAES - * - 2020 TigerLake: VP2INTERSECT - */ -#pragma region AVX512 Implementation - -#if SZ_USE_X86_AVX512 -#pragma GCC push_options -#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "bmi", "bmi2") -#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,bmi,bmi2"))), apply_to = function) -#include +#endif // SZ_USE_HASWELL +#pragma endregion // Haswell Implementation -/** - * @brief Helper structure to simplify work with 512-bit registers. - */ -typedef union sz_u512_vec_t { - __m512i zmm; - __m256i ymms[2]; - __m128i xmms[4]; - sz_u64_t u64s[8]; - sz_u32_t u32s[16]; - sz_u16_t u16s[32]; - sz_u8_t u8s[64]; - sz_i64_t i64s[8]; - sz_i32_t i32s[16]; -} sz_u512_vec_t; - -SZ_INTERNAL __mmask64 _sz_u64_clamp_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 64: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 64: - return _bzhi_u64(0xFFFFFFFFFFFFFFFF, n < 64 ? (sz_u32_t)n : 64); -} - -SZ_INTERNAL __mmask32 _sz_u32_clamp_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 32: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 32: - return _bzhi_u32(0xFFFFFFFF, n < 32 ? (sz_u32_t)n : 32); -} - -SZ_INTERNAL __mmask16 _sz_u16_clamp_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 16: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 16: - return _bzhi_u32(0xFFFFFFFF, n < 16 ? 
(sz_u32_t)n : 16); -} - -SZ_INTERNAL __mmask16 _sz_u16_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 16: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 16: - return (__mmask16)_bzhi_u32(0xFFFFFFFF, (sz_u32_t)n); -} - -SZ_INTERNAL __mmask32 _sz_u32_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 32: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 32: - return _bzhi_u32(0xFFFFFFFF, (sz_u32_t)n); -} - -SZ_INTERNAL __mmask64 _sz_u64_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 64: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 64: - return _bzhi_u64(0xFFFFFFFFFFFFFFFF, (sz_u32_t)n); -} - -SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { - sz_u512_vec_t a_vec, b_vec; - - // Pointer arithmetic is cheap, fetching memory is not! - // So we can use the masked loads to fetch at most one cache-line for each string, - // compare the prefixes, and only then move forward. - sz_size_t a_head_length = 64 - ((sz_size_t)a % 64); // 63 or less. - sz_size_t b_head_length = 64 - ((sz_size_t)b % 64); // 63 or less. - a_head_length = a_head_length < a_length ? a_head_length : a_length; - b_head_length = b_head_length < b_length ? b_head_length : b_length; - sz_size_t head_length = a_head_length < b_head_length ? a_head_length : b_head_length; - __mmask64 head_mask = _sz_u64_mask_until(head_length); - a_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, a); - b_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, b); - __mmask64 mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); - if (mask_not_equal != 0) { - sz_u64_t first_diff = _tzcnt_u64(mask_not_equal); - char a_char = a_vec.u8s[first_diff]; - char b_char = b_vec.u8s[first_diff]; - return _sz_order_scalars(a_char, b_char); - } - else if (head_length == a_length && head_length == b_length) { return sz_equal_k; } - else { a += head_length, b += head_length, a_length -= head_length, b_length -= head_length; } - - // The rare case, when both string are very long. - __mmask64 a_mask, b_mask; - while ((a_length >= 64) & (b_length >= 64)) { - a_vec.zmm = _mm512_loadu_si512(a); - b_vec.zmm = _mm512_loadu_si512(b); - mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); - if (mask_not_equal != 0) { - sz_u64_t first_diff = _tzcnt_u64(mask_not_equal); - char a_char = a_vec.u8s[first_diff]; - char b_char = b_vec.u8s[first_diff]; - return _sz_order_scalars(a_char, b_char); - } - a += 64, b += 64, a_length -= 64, b_length -= 64; - } - - // In most common scenarios at least one of the strings is under 64 bytes. - if (a_length | b_length) { - a_mask = _sz_u64_clamp_mask_until(a_length); - b_mask = _sz_u64_clamp_mask_until(b_length); - a_vec.zmm = _mm512_maskz_loadu_epi8(a_mask, a); - b_vec.zmm = _mm512_maskz_loadu_epi8(b_mask, b); - // The AVX-512 `_mm512_mask_cmpneq_epi8_mask` intrinsics are generally handy in such environments. - // They, however, have latency 3 on most modern CPUs. Using AVX2: `_mm256_cmpeq_epi8` would have - // been cheaper, if we didn't have to apply `_mm256_movemask_epi8` afterwards. 
- mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); - if (mask_not_equal != 0) { - sz_u64_t first_diff = _tzcnt_u64(mask_not_equal); - char a_char = a_vec.u8s[first_diff]; - char b_char = b_vec.u8s[first_diff]; - return _sz_order_scalars(a_char, b_char); - } - // From logic perspective, the hardest cases are "abc\0" and "abc". - // The result must be `sz_greater_k`, as the latter is shorter. - else { return _sz_order_scalars(a_length, b_length); } - } - - return sz_equal_k; -} - -SZ_PUBLIC sz_bool_t sz_equal_avx512(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - __mmask64 mask; - sz_u512_vec_t a_vec, b_vec; - - while (length >= 64) { - a_vec.zmm = _mm512_loadu_si512(a); - b_vec.zmm = _mm512_loadu_si512(b); - mask = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); - if (mask != 0) return sz_false_k; - a += 64, b += 64, length -= 64; - } - - if (length) { - mask = _sz_u64_mask_until(length); - a_vec.zmm = _mm512_maskz_loadu_epi8(mask, a); - b_vec.zmm = _mm512_maskz_loadu_epi8(mask, b); - // Reuse the same `mask` variable to find the bit that doesn't match - mask = _mm512_mask_cmpneq_epi8_mask(mask, a_vec.zmm, b_vec.zmm); - return (sz_bool_t)(mask == 0); - } +/* AVX512 implementation of the string search algorithms for Skylake and newer CPUs. + * Includes extensions: F, CD, ER, PF, VL, DQ, BW. + * + * This is the "starting level" for the advanced algorithms using K-mask registers on x86. + */ +#pragma region Skylake Implementation - return sz_true_k; -} +#if SZ_USE_SKYLAKE +#pragma GCC push_options +#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "bmi", "bmi2") +#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,bmi,bmi2"))), apply_to = function) -SZ_PUBLIC void sz_fill_avx512(sz_ptr_t target, sz_size_t length, sz_u8_t value) { +SZ_PUBLIC void sz_fill_skylake(sz_ptr_t target, sz_size_t length, sz_u8_t value) { __m512i value_vec = _mm512_set1_epi8(value); // The naive implementation of this function is very simple. // It assumes the CPU is great at handling unaligned "stores". @@ -4763,7 +665,7 @@ SZ_PUBLIC void sz_copy_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t lengt __mmask64 mask = _sz_u64_mask_until(length); _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source)); } - // When dealing wirh larger arrays, the optimization is not as simple as with the `sz_fill_avx512` function, + // When dealing wirh larger arrays, the optimization is not as simple as with the `sz_fill_skylake` function, // as both buffers may be unaligned. If we are lucky and the requested operation is some huge page transfer, // we can use aligned loads and stores, and the performance will be great. else if ((sz_size_t)target % 64 == 0 && (sz_size_t)source % 64 == 0 && !is_huge) { @@ -4886,931 +788,66 @@ SZ_PUBLIC void sz_move_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t lengt // - if we are shifting data right, that we are traversing to the left. int const left_to_right_traversal = source > target; - // Now we guarantee, that the relative shift within registers is from 1 to 63 bytes and the output is aligned. - // Hopefully, we need to shift more than two ZMM registers, so we could consider `valignr` instruction. - // Sadly, using `_mm512_alignr_epi8` doesn't make sense, as it operates at a 128-bit granularity. - // - // - `_mm256_alignr_epi8` shifts entire 256-bit register, but we need many of them. - // - `_mm512_alignr_epi32` shifts 512-bit chunks, but only if the `shift` is a multiple of 4 bytes. 
- // - `_mm512_alignr_epi64` shifts 512-bit chunks by 8 bytes. - // - // All of those have a latency of 1 cycle, and the shift amount must be an immediate value! - // For 1-byte-shift granularity, the `_mm512_permutex2var_epi8` has a latency of 6 and needs VBMI! - // The most efficient and broadly compatible alternative could be to use a combination of align and shuffle. - // A similar approach was outlined in "Byte-wise alignr in AVX512F" by Wojciech Muła. - // http://0x80.pl/notesen/2016-10-16-avx512-byte-alignr.html - // - // That solution, is extremely mouthful, assuming we need compile time constants for the shift amount. - // A cleaner one, with a latency of 3 cycles, is to use `_mm512_permutexvar_epi8` or - // `_mm512_mask_permutexvar_epi8`, which can be seen as combination of a cross-register shuffle and blend, - // and is available with VBMI. That solution is still noticeably slower than AVX2. - // - // The GLibC implementation also uses non-temporal stores for larger buffers, we don't. - // https://codebrowser.dev/glibc/glibc/sysdeps/x86_64/multiarch/memmove-avx512-no-vzeroupper.S.html - if (left_to_right_traversal) { - // Head, body, and tail. - _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source)); - for (target += head_length, source += head_length; body_length >= 64; - target += 64, source += 64, body_length -= 64) - _mm512_store_si512(target, _mm512_loadu_si512(source)); - _mm512_mask_storeu_epi8(target, tail_mask, _mm512_maskz_loadu_epi8(tail_mask, source)); - } - else { - // Tail, body, and head. - _mm512_mask_storeu_epi8(target + head_length + body_length, tail_mask, - _mm512_maskz_loadu_epi8(tail_mask, source + head_length + body_length)); - for (; body_length >= 64; body_length -= 64) - _mm512_store_si512(target + head_length + body_length - 64, - _mm512_loadu_si512(source + head_length + body_length - 64)); - _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source)); - } - } -} - -SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - __mmask64 mask; - sz_u512_vec_t h_vec, n_vec; - n_vec.zmm = _mm512_set1_epi8(n[0]); - - while (h_length >= 64) { - h_vec.zmm = _mm512_loadu_si512(h); - mask = _mm512_cmpeq_epi8_mask(h_vec.zmm, n_vec.zmm); - if (mask) return h + sz_u64_ctz(mask); - h += 64, h_length -= 64; - } - - if (h_length) { - mask = _sz_u64_mask_until(h_length); - h_vec.zmm = _mm512_maskz_loadu_epi8(mask, h); - // Reuse the same `mask` variable to find the bit that doesn't match - mask = _mm512_mask_cmpeq_epu8_mask(mask, h_vec.zmm, n_vec.zmm); - if (mask) return h + sz_u64_ctz(mask); - } - - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_find_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_find_byte_avx512(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into ZMM registers. - __mmask64 matches; - __mmask64 mask; - sz_u512_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; - n_first_vec.zmm = _mm512_set1_epi8(n[offset_first]); - n_mid_vec.zmm = _mm512_set1_epi8(n[offset_mid]); - n_last_vec.zmm = _mm512_set1_epi8(n[offset_last]); - - // Scan through the string. 
- // We have several optimized versions of the lagorithm for shorter strings, - // but they all mimic the default case for unbounded length needles - if (n_length >= 64) { - for (; h_length >= n_length + 64; h += 64, h_length -= 64) { - h_first_vec.zmm = _mm512_loadu_si512(h + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_ctz(matches); - if (sz_equal_avx512(h + potential_offset, n, n_length)) return h + potential_offset; - matches &= matches - 1; - } - - // TODO: If the last character contains a bad byte, we can reposition the start of the next iteration. - // This will be very helpful for very long needles. - } - } - // If there are only 2 or 3 characters in the needle, we don't even need the nested loop. - else if (n_length <= 3) { - for (; h_length >= n_length + 64; h += 64, h_length -= 64) { - h_first_vec.zmm = _mm512_loadu_si512(h + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - if (matches) return h + sz_u64_ctz(matches); - } - } - // If the needle is smaller than the size of the ZMM register, we can use masked comparisons - // to avoid the the inner-most nested loop and compare the entire needle against a haystack - // slice in 3 CPU cycles. - else { - __mmask64 n_mask = _sz_u64_mask_until(n_length); - sz_u512_vec_t n_full_vec, h_full_vec; - n_full_vec.zmm = _mm512_maskz_loadu_epi8(n_mask, n); - for (; h_length >= n_length + 64; h += 64, h_length -= 64) { - h_first_vec.zmm = _mm512_loadu_si512(h + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_ctz(matches); - h_full_vec.zmm = _mm512_maskz_loadu_epi8(n_mask, h + potential_offset); - if (_mm512_mask_cmpneq_epi8_mask(n_mask, h_full_vec.zmm, n_full_vec.zmm) == 0) - return h + potential_offset; - matches &= matches - 1; - } - } - } - - // The "tail" of the function uses masked loads to process the remaining bytes. 
- { - mask = _sz_u64_mask_until(h_length - n_length + 1); - h_first_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_first); - h_mid_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_mid); - h_last_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_ctz(matches); - if (n_length <= 3 || sz_equal_avx512(h + potential_offset, n, n_length)) return h + potential_offset; - matches &= matches - 1; - } - } - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - __mmask64 mask; - sz_u512_vec_t h_vec, n_vec; - n_vec.zmm = _mm512_set1_epi8(n[0]); - - while (h_length >= 64) { - h_vec.zmm = _mm512_loadu_si512(h + h_length - 64); - mask = _mm512_cmpeq_epi8_mask(h_vec.zmm, n_vec.zmm); - if (mask) return h + h_length - 1 - sz_u64_clz(mask); - h_length -= 64; - } - - if (h_length) { - mask = _sz_u64_mask_until(h_length); - h_vec.zmm = _mm512_maskz_loadu_epi8(mask, h); - // Reuse the same `mask` variable to find the bit that doesn't match - mask = _mm512_mask_cmpeq_epu8_mask(mask, h_vec.zmm, n_vec.zmm); - if (mask) return h + 64 - sz_u64_clz(mask) - 1; - } - - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_rfind_byte_avx512(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into ZMM registers. - __mmask64 mask; - __mmask64 matches; - sz_u512_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; - n_first_vec.zmm = _mm512_set1_epi8(n[offset_first]); - n_mid_vec.zmm = _mm512_set1_epi8(n[offset_mid]); - n_last_vec.zmm = _mm512_set1_epi8(n[offset_last]); - - // Scan through the string. - sz_cptr_t h_reversed; - for (; h_length >= n_length + 64; h_length -= 64) { - h_reversed = h + h_length - n_length - 64 + 1; - h_first_vec.zmm = _mm512_loadu_si512(h_reversed + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h_reversed + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h_reversed + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_clz(matches); - if (n_length <= 3 || sz_equal_avx512(h + h_length - n_length - potential_offset, n, n_length)) - return h + h_length - n_length - potential_offset; - sz_assert((matches & ((sz_u64_t)1 << (63 - potential_offset))) != 0 && - "The bit must be set before we squash it"); - matches &= ~((sz_u64_t)1 << (63 - potential_offset)); - } - } - - // The "tail" of the function uses masked loads to process the remaining bytes. 
- { - mask = _sz_u64_mask_until(h_length - n_length + 1); - h_first_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_first); - h_mid_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_mid); - h_last_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_clz(matches); - if (n_length <= 3 || sz_equal_avx512(h + 64 - potential_offset - 1, n, n_length)) - return h + 64 - potential_offset - 1; - sz_assert((matches & ((sz_u64_t)1 << (63 - potential_offset))) != 0 && - "The bit must be set before we squash it"); - matches &= ~((sz_u64_t)1 << (63 - potential_offset)); - } - } - - return SZ_NULL_CHAR; -} - -#pragma clang attribute pop -#pragma GCC pop_options - -#pragma GCC push_options -#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vbmi", "bmi", "bmi2") -#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512dq,avx512vbmi,bmi,bmi2"))), \ - apply_to = function) - -/** - * @brief Computes the edit distance between two very short byte-strings using the AVX-512VBMI extensions. - * - * Applies to string lengths up to 63, and evaluates at most (63 * 2 + 1 = 127) diagonals, or just as many loop cycles. - * Supports an early exit, if the distance is bounded. - * Keeps all of the data and Levenshtein matrices skew diagonal in just a couple of registers. - * Benefits from the @b `vpermb` instructions, that can rotate the bytes across the entire ZMM register. - */ -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto63_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound) { - - sz_size_t const max_length = 63u; - sz_assert(shorter_length <= longer_length && "The 'shorter' string is longer than the 'longer' one."); - sz_assert(shorter_length < max_length && "The length must fit into 16-bit integer. Otherwise use serial variant."); - - // We are going to store 3 diagonals of the matrix, assuming each would fit into a single ZMM register. - // The length of the longest (main) diagonal would be `shorter_dim = (shorter_length + 1)`. - sz_size_t const shorter_dim = shorter_length + 1; - sz_size_t const longer_dim = longer_length + 1; - - // The next few buffers will be swapped around. - sz_u512_vec_t previous_vec, current_vec, next_vec; - sz_u512_vec_t gaps_vec, substitutions_vec; - - // Load the strings into ZMM registers - just once. 
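-    // The longer string stays in natural order, while the shorter one is reversed below, so a single byte
-    // rotation per iteration keeps the two aligned along the current skewed diagonal.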
- sz_u512_vec_t longer_vec, shorter_vec, shorter_rotated_vec, rotate_left_vec, rotate_right_vec, ones_vec, bound_vec; - longer_vec.zmm = _mm512_maskz_loadu_epi8(_sz_u64_mask_until(longer_length), longer); - rotate_left_vec.zmm = _mm512_set_epi8( // - 0, 63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, // - 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, // - 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, // - 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1); - rotate_right_vec.zmm = _mm512_set_epi8( // - 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, // - 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, // - 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, // - 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 63); - ones_vec.zmm = _mm512_set1_epi8(1); - bound_vec.zmm = _mm512_set1_epi8(bound <= 255 ? (sz_u8_t)bound : 255); - - // To simplify comparisons and traversals, we want to reverse the order of bytes in the shorter string. - for (sz_size_t i = 0; i != shorter_length; ++i) shorter_vec.u8s[63 - i] = shorter[i]; - shorter_rotated_vec.zmm = _mm512_permutexvar_epi8(rotate_right_vec.zmm, shorter_vec.zmm); - - // Let's say we are dealing with 3 and 5 letter words. - // The matrix will have size 4 x 6, parameterized as (shorter_dim x longer_dim). - // It will have: - // - 4 diagonals of increasing length, at positions: 0, 1, 2, 3. - // - 2 diagonals of fixed length, at positions: 4, 5. - // - 3 diagonals of decreasing length, at positions: 6, 7, 8. - sz_size_t const diagonals_count = shorter_dim + longer_dim - 1; - - // Initialize the first two diagonals: - // - // previous_vec.u8s[0] = 0; - // current_vec.u8s[0] = current_vec.u8s[1] = 1; - // - // We can do a similar thing with vector ops: - previous_vec.zmm = _mm512_setzero_si512(); - current_vec.zmm = _mm512_set1_epi8(1); - - // We skip diagonals 0 and 1, as they are trivial. - // We will start with diagonal 2, which has length 3, with the first and last elements being preset, - // so we are effectively computing just one value, as will be marked by a single set bit in - // the `next_diagonal_mask` on the very first iteration. - sz_size_t next_diagonal_index = 2; - __mmask64 next_diagonal_mask = 0; - - // Progress through the upper triangle of the Levenshtein matrix. - for (; next_diagonal_index != shorter_dim; ++next_diagonal_index) { - // After this iteration, the values at offset `0` and `next_diagonal_index` in the `next_vec` - // should be set to `next_diagonal_index`, but it's easier to broadcast the value to the whole vector, - // and later merge with a mask with new values. - next_vec.zmm = _mm512_set1_epi8((sz_u8_t)next_diagonal_index); - - // The mask also adds one set bit. - next_diagonal_mask = _kor_mask64(next_diagonal_mask, 1); - next_diagonal_mask = _kshiftli_mask64(next_diagonal_mask, 1); - - // Check for equality between string slices. 
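-        // Wherever the bytes differ, the substitution branch costs one more than the value on the diagonal two steps back.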
- __mmask64 conflict_mask = _mm512_cmpneq_epi8_mask(longer_vec.zmm, shorter_rotated_vec.zmm); - substitutions_vec.zmm = _mm512_mask_add_epi8(previous_vec.zmm, conflict_mask, previous_vec.zmm, ones_vec.zmm); - substitutions_vec.zmm = _mm512_permutexvar_epi8(rotate_right_vec.zmm, substitutions_vec.zmm); - gaps_vec.zmm = _mm512_add_epi8( - // Insertions or deletions - _mm512_min_epu8(_mm512_permutexvar_epi8(rotate_right_vec.zmm, current_vec.zmm), current_vec.zmm), - ones_vec.zmm); - next_vec.zmm = _mm512_mask_min_epu8(next_vec.zmm, next_diagonal_mask, gaps_vec.zmm, substitutions_vec.zmm); - - // Mark the current skewed diagonal as the previous one and the next one as the current one. - previous_vec.zmm = current_vec.zmm; - current_vec.zmm = next_vec.zmm; - - // Shift the shorter string - shorter_rotated_vec.zmm = _mm512_permutexvar_epi8(rotate_right_vec.zmm, shorter_rotated_vec.zmm); - - // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound. - __mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm); - if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) { // - return SZ_SIZE_MAX; - } - } - - // Now let's handle the anti-diagonal band of the matrix, between the top and bottom triangles. - for (; next_diagonal_index != longer_dim; ++next_diagonal_index) { - // After this iteration, the value `shorted_dim - 1` in the `next_vec` - // should be set to `next_diagonal_index`, but it's easier to broadcast the value to the whole vector, - // and later merge with a mask with new values. - next_vec.zmm = _mm512_set1_epi8((sz_u8_t)next_diagonal_index); - - // Make sure we update the first entry. - next_diagonal_mask = _kor_mask64(next_diagonal_mask, 1); - - // Check for equality between string slices. - __mmask64 conflict_mask = _mm512_cmpneq_epi8_mask(longer_vec.zmm, shorter_rotated_vec.zmm); - substitutions_vec.zmm = _mm512_mask_add_epi8(previous_vec.zmm, conflict_mask, previous_vec.zmm, ones_vec.zmm); - gaps_vec.zmm = _mm512_add_epi8( - // Insertions or deletions - _mm512_min_epu8(current_vec.zmm, _mm512_permutexvar_epi8(rotate_left_vec.zmm, current_vec.zmm)), - ones_vec.zmm); - next_vec.zmm = _mm512_mask_min_epu8(next_vec.zmm, next_diagonal_mask, gaps_vec.zmm, substitutions_vec.zmm); - - // Mark the current skewed diagonal as the previous one and the next one as the current one. - previous_vec.zmm = _mm512_permutexvar_epi8(rotate_left_vec.zmm, current_vec.zmm); - current_vec.zmm = next_vec.zmm; - - // Let's shift the longer string now. - longer_vec.zmm = _mm512_permutexvar_epi8(rotate_left_vec.zmm, longer_vec.zmm); - - // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound. - __mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm); - if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) { // - return SZ_SIZE_MAX; - } - } - - // Now let's handle the bottom right triangle. - for (; next_diagonal_index != diagonals_count; ++next_diagonal_index) { - - // Check for equality between string slices. 
- __mmask64 conflict_mask = _mm512_cmpneq_epi8_mask(longer_vec.zmm, shorter_rotated_vec.zmm); - substitutions_vec.zmm = _mm512_mask_add_epi8(previous_vec.zmm, conflict_mask, previous_vec.zmm, ones_vec.zmm); - gaps_vec.zmm = _mm512_add_epi8( - // Insertions or deletions - _mm512_min_epu8(current_vec.zmm, _mm512_permutexvar_epi8(rotate_left_vec.zmm, current_vec.zmm)), - ones_vec.zmm); - next_vec.zmm = _mm512_min_epu8(gaps_vec.zmm, substitutions_vec.zmm); - - // Mark the current skewed diagonal as the previous one and the next one as the current one. - previous_vec.zmm = _mm512_permutexvar_epi8(rotate_left_vec.zmm, current_vec.zmm); - current_vec.zmm = next_vec.zmm; - - // Let's shift the longer string now. - longer_vec.zmm = _mm512_permutexvar_epi8(rotate_left_vec.zmm, longer_vec.zmm); - - // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound. - __mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm); - if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) { // - return SZ_SIZE_MAX; - } - // In every following iterations we take use a shorter prefix of each register, - // but we don't need to update the `next_diagonal_mask` anymore... except for the early exit. - next_diagonal_mask = _kshiftri_mask64(next_diagonal_mask, 1); - } - return current_vec.u8s[0]; -} - -/** - * @brief Computes the edit distance between two somewhat short bytes-strings using the AVX-512VBMI extensions. - * - * Applies to string lengths up to 127, and evaluates at most (127 * 2 + 1 = 255) diagonals. - * Supports an early exit, if the distance is bounded. - * Uses a lot more CPU registers space, than the `upto63` variant. - * Benefits from the @b `vpermi2b` instructions, that can rotate the bytes in 2 registers at once. - * - * This may be one of the most freuqently called kernels for: - * - source code analysis, assuming most lines are either under 80 or under 120 characters long. - * - DNA sequence alignment, as most short reads are 50-300 characters long. - */ -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto127_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound) { - sz_unused(shorter && shorter_length && longer && longer_length && bound); - return 0; -} - -/** - * @brief Computes the edit distance between two longer bytes-strings using the AVX-512VBMI extensions. - * - * Applies to string lengths up to 255, and evaluates at most (255 * 2 + 1 = 511) diagonals. - * Supports an early exit, if the distance is bounded. - * Uses a lot more CPU registers space, than the `upto63` variant. - * - * Each of 2x string ends up occupying 4 ZMM registers, and each of 3x diagonals uses 4 ZMM registers. - * So 20x of the 32x are persistently occupied, and the rest are used for math temporarily. - * This is the largest space-efficient variant, as strings beyond 255 characters may require - * 16-bit accumulators, which would be a significant bottleneck. - */ -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound) { - sz_unused(shorter && shorter_length && longer && longer_length && bound); - return 0; -} - -/** - * @brief Computes the edit distance between two longer bytes-strings using the AVX-512VBMI extensions, - * assuming the upper distance bound can not exceed 255, but the string length can be arbitrary. 
- * - * Applies to string lengths up to 255, and evaluates at most (255 * 2 + 1 = 511) diagonals. - * Supports an early exit, if the distance is bounded. - * Uses a lot more CPU registers space, than the `upto63` variant. - * - * Each of 2x string ends up occupying 4 ZMM registers, and each of 3x diagonals uses 4 ZMM registers. - * So 20x of the 32x are persistently occupied, and the rest are used for math temporarily. - * This is the largest space-efficient variant, as strings beyond 255 characters may require - * 16-bit accumulators, which would be a significant bottleneck. - */ -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto255bound_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound) { - sz_unused(shorter && shorter_length && longer && longer_length && bound); - return 0; -} - -/** - * @brief Computes the edit distance between two mid-length UTF-8-strings using the AVX-512VBMI extensions. - * - * Applies to string lengths up to 127, and evaluates at most (127 * 2 + 1 = 511) diagonals. - * Supports an early exit, if the distance is bounded. - * Benefits from the @b `valignd` instructions used to rotate UTF-32 unpacked unicode codepoints. - * - * Each string is unpacked into 128 characters * 4 bytes per character / 64 bytes per register = 8 registers. - * - */ -SZ_INTERNAL sz_size_t _sz_edit_distance_utf8_skewed_diagonals_upto127_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound) { - sz_unused(shorter && shorter_length && longer && longer_length && bound); - return 0; -} - -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto65k_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - - sz_unused(shorter && longer && bound && alloc); - - // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. - sz_memory_allocator_t global_alloc; - if (!alloc) { - sz_memory_allocator_init_default(&global_alloc); - alloc = &global_alloc; - } - - // TODO: Generalize! - sz_size_t const max_length = 256u * 256u; - sz_assert(shorter_length <= longer_length && "The 'shorter' string is longer than the 'longer' one."); - sz_assert(shorter_length < max_length && "The length must fit into 16-bit integer. Otherwise use serial variant."); - sz_unused(longer_length && bound && max_length); - -#if 0 - // We are going to store 3 diagonals of the matrix. - // The length of the longest (main) diagonal would be `shorter_dim = (shorter_length + 1)`. - sz_size_t const shorter_dim = shorter_length + 1; - sz_size_t const longer_dim = longer_length + 1; - // Unlike the serial version, we also want to avoid reverse-order iteration over teh shorter string. - // So let's allocate a bit more memory and reverse-export our shorter string into that buffer. - sz_size_t const buffer_length = sizeof(sz_u16_t) * longer_dim * 3 + shorter_length; - sz_u16_t *const distances = (sz_u16_t *)alloc->allocate(buffer_length, alloc->handle); - if (!distances) return SZ_SIZE_MAX; - - // The next few pointers will be swapped around. - sz_u16_t *previous_distances = distances; - sz_u16_t *current_distances = previous_distances + longer_dim; - sz_u16_t *next_distances = current_distances + longer_dim; - sz_ptr_t const shorter_reversed = (sz_ptr_t)(next_distances + longer_dim); - - // Export the reversed string into the buffer. 
- for (sz_size_t i = 0; i != shorter_length; ++i) shorter_reversed[i] = shorter[shorter_length - 1 - i]; - - // Initialize the first two diagonals: - previous_distances[0] = 0; - current_distances[0] = current_distances[1] = 1; - - // Using ZMM registers, we can process 32x 16-bit values at once, - // storing 16 bytes of each string in YMM registers. - sz_u512_vec_t insertions_vec, deletions_vec, substitutions_vec, next_vec; - sz_u512_vec_t ones_u16_vec; - ones_u16_vec.zmm = _mm512_set1_epi16(1); - - // This is a mixed-precision implementation, using 8-bit representations for part of the operations. - // Even there, in case `SZ_USE_X86_AVX2=0`, let's use the `sz_u512_vec_t` type, addressing the first YMM halfs. - sz_u512_vec_t shorter_vec, longer_vec; - sz_u512_vec_t ones_u8_vec; - ones_u8_vec.ymms[0] = _mm256_set1_epi8(1); - - // Let's say we are dealing with 3 and 5 letter words. - // The matrix will have size 4 x 6, parameterized as (shorter_dim x longer_dim). - // It will have: - // - 4 diagonals of increasing length, at positions: 0, 1, 2, 3. - // - 2 diagonals of fixed length, at positions: 4, 5. - // - 3 diagonals of decreasing length, at positions: 6, 7, 8. - sz_size_t const diagonals_count = shorter_dim + longer_dim - 1; - - // Progress through the upper triangle of the Levenshtein matrix. - sz_size_t next_diagonal_index = 2; - for (; next_diagonal_index != shorter_dim; ++next_diagonal_index) { - sz_size_t const next_diagonal_length = next_diagonal_index + 1; - for (sz_size_t offset_within_diagonal = 0; offset_within_diagonal + 2 < next_diagonal_length;) { - sz_u32_t remaining_length = (sz_u32_t)(next_diagonal_length - offset_within_diagonal - 2); - sz_u32_t register_length = remaining_length < 32 ? remaining_length : 32; - sz_u32_t remaining_length_mask = _bzhi_u32(0xFFFFFFFFu, register_length); - longer_vec.ymms[0] = _mm256_maskz_loadu_epi8(remaining_length_mask, longer + offset_within_diagonal); - // Our original code addressed the shorter string `[next_diagonal_index - offset_within_diagonal - 2]` - // for growing `offset_within_diagonal`. If the `shorter` string was reversed, the - // `[next_diagonal_index - offset_within_diagonal - 2]` would be equal to `[shorter_length - 1 - - // next_diagonal_index + offset_within_diagonal + 2]`. Which simplified would be equal to - // `[shorter_length - next_diagonal_index + offset_within_diagonal + 1]`. - shorter_vec.ymms[0] = _mm256_maskz_loadu_epi8( // - remaining_length_mask, - shorter_reversed + shorter_length - next_diagonal_index + offset_within_diagonal + 1); - // For substitutions, perform the equality comparison using AVX2 instead of AVX-512 - // to get the result as a vector, instead of a bitmask. Adding 1 to every scalar we can overflow - // transforming from {0xFF, 0} values to {0, 1} values - exactly what we need. Then - upcast to 16-bit. - substitutions_vec.zmm = _mm512_cvtepi8_epi16( // - _mm256_add_epi8(_mm256_cmpeq_epi8(longer_vec.ymms[0], shorter_vec.ymms[0]), ones_u8_vec.ymms[0])); - substitutions_vec.zmm = _mm512_add_epi16( // - substitutions_vec.zmm, - _mm512_maskz_loadu_epi16(remaining_length_mask, previous_distances + offset_within_diagonal)); - // For insertions and deletions, on modern hardware, it's faster to issue two separate loads, - // than rotate the bytes in the ZMM register. 
- insertions_vec.zmm = - _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + offset_within_diagonal); - deletions_vec.zmm = - _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + offset_within_diagonal + 1); - // First get the minimum of insertions and deletions. - next_vec.zmm = _mm512_add_epi16(_mm512_min_epu16(insertions_vec.zmm, deletions_vec.zmm), ones_u16_vec.zmm); - next_vec.zmm = _mm512_min_epu16(next_vec.zmm, substitutions_vec.zmm); - _mm512_mask_storeu_epi16(next_distances + offset_within_diagonal + 1, remaining_length_mask, next_vec.zmm); - offset_within_diagonal += register_length; - } - // Don't forget to populate the first row and the first column of the Levenshtein matrix. - next_distances[0] = next_distances[next_diagonal_length - 1] = (sz_u16_t)next_diagonal_index; - // Perform a circular rotation (three-way swap) of those buffers, to reuse the memory. - sz_u16_t *temporary = previous_distances; - previous_distances = current_distances; - current_distances = next_distances; - next_distances = temporary; - } - - // By now we've scanned through the upper triangle of the matrix, where each subsequent iteration results in a - // larger diagonal. From now onwards, we will be shrinking. Instead of adding value equal to the skewed diagonal - // index on either side, we will be cropping those values out. - for (; next_diagonal_index != diagonals_count; ++next_diagonal_index) { - sz_size_t const next_diagonal_length = diagonals_count - next_diagonal_index; - for (sz_size_t i = 0; i != next_diagonal_length;) { - sz_u32_t remaining_length = (sz_u32_t)(next_diagonal_length - i); - sz_u32_t register_length = remaining_length < 32 ? remaining_length : 32; - sz_u32_t remaining_length_mask = _bzhi_u32(0xFFFFFFFFu, register_length); - longer_vec.ymms[0] = _mm256_maskz_loadu_epi8(remaining_length_mask, longer + next_diagonal_index - n + i); - // Our original code addressed the shorter string `[shorter_length - 1 - i]` for growing `i`. - // If the `shorter` string was reversed, the `[shorter_length - 1 - i]` would - // be equal to `[shorter_length - 1 - shorter_length + 1 + i]`. - // Which simplified would be equal to just `[i]`. Beautiful! - shorter_vec.ymms[0] = _mm256_maskz_loadu_epi8(remaining_length_mask, shorter_reversed + i); - // For substitutions, perform the equality comparison using AVX2 instead of AVX-512 - // to get the result as a vector, instead of a bitmask. The compare it against the accumulated - // substitution costs. - substitutions_vec.zmm = _mm512_cvtepi8_epi16( // - _mm256_add_epi8(_mm256_cmpeq_epi8(longer_vec.ymms[0], shorter_vec.ymms[0]), ones_u8_vec.ymms[0])); - substitutions_vec.zmm = _mm512_add_epi16( // - substitutions_vec.zmm, _mm512_maskz_loadu_epi16(remaining_length_mask, previous_distances + i)); - // For insertions and deletions, on modern hardware, it's faster to issue two separate loads, - // than rotate the bytes in the ZMM register. - insertions_vec.zmm = _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + i); - deletions_vec.zmm = _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + i + 1); - // First get the minimum of insertions and deletions. 
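// Scalar view of the reduction below, with all costs already folded in:
//
//     next = sz_min_of_two(sz_min_of_two(insertion, deletion) + 1, substitution);
//
// insertions and deletions pay one extra edit, while `substitution` already carries its 0/1
// mismatch penalty from the comparison trick above.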
- next_vec.zmm = _mm512_add_epi16(_mm512_min_epu16(insertions_vec.zmm, deletions_vec.zmm), ones_u16_vec.zmm); - next_vec.zmm = _mm512_min_epu16(next_vec.zmm, substitutions_vec.zmm); - _mm512_mask_storeu_epi16(next_distances + i, remaining_length_mask, next_vec.zmm); - i += register_length; - } - - // Perform a circular rotation (three-way swap) of those buffers, to reuse the memory, this time, with a shift, - // dropping the first element in the current array. - sz_u16_t *temporary = previous_distances; - previous_distances = current_distances + 1; - current_distances = next_distances; - next_distances = temporary; - } - - // Cache scalar before `free` call. - sz_size_t result = current_distances[0]; - alloc->free(distances, buffer_length, alloc->handle); - return result; -#endif - return 0; -} - -SZ_INTERNAL sz_size_t sz_edit_distance_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - - // Bounded computations may exit early. - int const is_bounded = bound < longer_length; - if (is_bounded) { - // If one of the strings is empty - the edit distance is equal to the length of the other one. - if (longer_length == 0) return sz_min_of_two(shorter_length, bound); - if (shorter_length == 0) return sz_min_of_two(longer_length, bound); - // If the difference in length is beyond the `bound`, there is no need to check at all. - if (longer_length - shorter_length > bound) return bound; - } - - // Make sure the shorter string is actually shorter. - if (shorter_length > longer_length) { - sz_cptr_t temporary = shorter; - shorter = longer; - longer = temporary; - sz_size_t temporary_length = shorter_length; - shorter_length = longer_length; - longer_length = temporary_length; - } - - // Dispatch the right implementation based on the length of the strings. - if (longer_length < 64u) - return _sz_edit_distance_skewed_diagonals_upto63_avx512( // - shorter, shorter_length, longer, longer_length, bound); - // else if (longer_length < 256u * 256u) - // return _sz_edit_distance_skewed_diagonals_upto65k_avx512( // - // shorter, shorter_length, longer, longer_length, bound, alloc); - else - return sz_edit_distance_serial(shorter, shorter_length, longer, longer_length, bound, alloc); -} - -SZ_PUBLIC sz_u64_t sz_checksum_avx512(sz_cptr_t text, sz_size_t length) { - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "loads". - // - // A typical AWS Sapphire Rapids instance can have 48 KB x 2 blocks of L1 data cache per core, - // 2 MB x 2 blocks of L2 cache per core, and one shared 60 MB buffer of L3 cache. - // With two strings, we may consider the overal workload huge, if each exceeds 1 MB in length. - int const is_huge = length >= 1ull * 1024ull * 1024ull; - sz_u512_vec_t text_vec, sums_vec; - - // When the buffer is small, there isn't much to innovate. 
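// Functionally this routine is just a byte sum, i.e. the scalar reference is:
//
//     sz_u64_t sum = 0;
//     for (sz_size_t i = 0; i != length; ++i) sum += (sz_u8_t)text[i];
//     return sum;
//
// Every SIMD path below leans on `SAD` (sum of absolute differences) against a zero vector,
// which horizontally adds each group of 8 bytes into a 64-bit lane - one of the cheapest
// horizontal byte reductions available on x86.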
- if (length <= 16) { - __mmask16 mask = _sz_u16_mask_until(length); - text_vec.xmms[0] = _mm_maskz_loadu_epi8(mask, text); - sums_vec.xmms[0] = _mm_sad_epu8(text_vec.xmms[0], _mm_setzero_si128()); - sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_vec.xmms[0]); - sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_vec.xmms[0], 1); - return low + high; - } - else if (length <= 32) { - __mmask32 mask = _sz_u32_mask_until(length); - text_vec.ymms[0] = _mm256_maskz_loadu_epi8(mask, text); - sums_vec.ymms[0] = _mm256_sad_epu8(text_vec.ymms[0], _mm256_setzero_si256()); - // Accumulating 256 bits is harders, as we need to extract the 128-bit sums first. - __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymms[0]); - __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymms[0], 1); - __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm); - sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm); - sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1); - return low + high; - } - else if (length <= 64) { - __mmask64 mask = _sz_u64_mask_until(length); - text_vec.zmm = _mm512_maskz_loadu_epi8(mask, text); - sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()); - return _mm512_reduce_add_epi64(sums_vec.zmm); - } - else if (!is_huge) { - sz_size_t head_length = (64 - ((sz_size_t)text % 64)) % 64; // 63 or less. - sz_size_t tail_length = (sz_size_t)(text + length) % 64; // 63 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 64. - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - text_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, text); - sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()); - for (text += head_length; body_length >= 64; text += 64, body_length -= 64) { - text_vec.zmm = _mm512_load_si512((__m512i const *)text); - sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512())); - } - text_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, text); - sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512())); - return _mm512_reduce_add_epi64(sums_vec.zmm); - } - // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use. - // - // 1. Moving in both directions to maximize the throughput, when fetching from multiple - // memory pages. Also helps with cache set-associativity issues, as we won't always - // be fetching the same entries in the lookup table. - // 2. Using non-temporal stores to avoid polluting the cache. - // 3. Prefetching the next cache line, to avoid stalling the CPU. This generally useless - // for predictable patterns, so disregard this advice. - // - // Bidirectional traversal generally adds about 10% to such algorithms. 
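// Worked example of the head/body/tail split used in the branch above and recomputed below:
// if `text` starts 7 bytes past a 64-byte boundary and `length == 200`, then `head_length == 57`,
// `tail_length == (7 + 200) % 64 == 15`, and `body_length == 128` - exactly two aligned ZMM words.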
- else { - sz_u512_vec_t text_reversed_vec, sums_reversed_vec; - sz_size_t head_length = (64 - ((sz_size_t)text % 64)) % 64; - sz_size_t tail_length = (sz_size_t)(text + length) % 64; - sz_size_t body_length = length - head_length - tail_length; - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - - text_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, text); - sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()); - text_reversed_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, text + head_length + body_length); - sums_reversed_vec.zmm = _mm512_sad_epu8(text_reversed_vec.zmm, _mm512_setzero_si512()); - - // Now in the main loop, we can use non-temporal loads and stores, - // performing the operation in both directions. - for (text += head_length; body_length >= 128; text += 64, text += 64, body_length -= 128) { - text_vec.zmm = _mm512_stream_load_si512((__m512i *)(text)); - sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512())); - text_reversed_vec.zmm = _mm512_stream_load_si512((__m512i *)(text + body_length - 64)); - sums_reversed_vec.zmm = - _mm512_add_epi64(sums_reversed_vec.zmm, _mm512_sad_epu8(text_reversed_vec.zmm, _mm512_setzero_si512())); - } - if (body_length >= 64) { - text_vec.zmm = _mm512_stream_load_si512((__m512i *)(text)); - sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512())); - } - - return _mm512_reduce_add_epi64(_mm512_add_epi64(sums_vec.zmm, sums_reversed_vec.zmm)); - } -} - -SZ_PUBLIC void sz_hashes_avx512(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle) { - - if (length < window_length || !window_length) return; - if (length < 4 * window_length) { - sz_hashes_serial(start, length, window_length, step, callback, callback_handle); - return; - } - - // Using AVX2, we can perform 4 long integer multiplications and additions within one register. - // So let's slice the entire string into 4 overlapping windows, to slide over them in parallel. - sz_size_t const max_hashes = length - window_length + 1; - sz_size_t const min_hashes_per_thread = max_hashes / 4; // At most one sequence can overlap between 2 threads. - sz_u8_t const *text_first = (sz_u8_t const *)start; - sz_u8_t const *text_second = text_first + min_hashes_per_thread; - sz_u8_t const *text_third = text_first + min_hashes_per_thread * 2; - sz_u8_t const *text_fourth = text_first + min_hashes_per_thread * 3; - sz_u8_t const *text_end = text_first + length; - - // Broadcast the global constants into the registers. - // Both high and low hashes will work with the same prime and golden ratio. - sz_u512_vec_t prime_vec, golden_ratio_vec; - prime_vec.zmm = _mm512_set1_epi64(SZ_U64_MAX_PRIME); - golden_ratio_vec.zmm = _mm512_set1_epi64(11400714819323198485ull); - - // Prepare the `prime ^ window_length` values, that we are going to use for modulo arithmetic. - sz_u64_t prime_power_low = 1, prime_power_high = 1; - for (sz_size_t i = 0; i + 1 < window_length; ++i) - prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME, - prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME; - - // We will be evaluating 4 offsets at a time with 2 different hash functions. - // We can fit all those 8 state variables in each of the following ZMM registers. 
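// The eight 64-bit lanes maintained below implement a textbook rolling (Rabin-Karp style) hash,
// just for four window offsets x two hash functions at a time. Ignoring the byte offset applied
// to one of the hash families and the cheaper conditional-subtraction modulo used later, each
// lane performs the classic per-character update:
//
//     hash = ((hash - text[oldest] * prime_power) * base + text[newest]) % prime;
//
// where `prime_power == base ^ (window_length - 1) % prime`, matching the `prime_power_low` and
// `prime_power_high` values computed above for the bases 31 and 257.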
- sz_u512_vec_t base_vec, prime_power_vec, shift_vec; - base_vec.zmm = _mm512_set_epi64(31ull, 31ull, 31ull, 31ull, 257ull, 257ull, 257ull, 257ull); - shift_vec.zmm = _mm512_set_epi64(0ull, 0ull, 0ull, 0ull, 77ull, 77ull, 77ull, 77ull); - prime_power_vec.zmm = _mm512_set_epi64(prime_power_low, prime_power_low, prime_power_low, prime_power_low, - prime_power_high, prime_power_high, prime_power_high, prime_power_high); - - // Compute the initial hash values for every one of the four windows. - sz_u512_vec_t hash_vec, chars_vec; - hash_vec.zmm = _mm512_setzero_si512(); - for (sz_u8_t const *prefix_end = text_first + window_length; text_first < prefix_end; - ++text_first, ++text_second, ++text_third, ++text_fourth) { - - // 1. Multiply the hashes by the base. - hash_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, base_vec.zmm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`... - chars_vec.zmm = _mm512_set_epi64(text_fourth[0], text_third[0], text_second[0], text_first[0], // - text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm); - - // 3. Add the incoming characters. - hash_vec.zmm = _mm512_add_epi64(hash_vec.zmm, chars_vec.zmm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_vec.zmm = _mm512_mask_blend_epi8(_mm512_cmpgt_epi64_mask(hash_vec.zmm, prime_vec.zmm), hash_vec.zmm, - _mm512_sub_epi64(hash_vec.zmm, prime_vec.zmm)); - } - - // 5. Compute the hash mix, that will be used to index into the fingerprint. - // This includes a serial step at the end. - sz_u512_vec_t hash_mix_vec; - hash_mix_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, golden_ratio_vec.zmm); - hash_mix_vec.ymms[0] = _mm256_xor_si256(_mm512_extracti64x4_epi64(hash_mix_vec.zmm, 1), // - _mm512_extracti64x4_epi64(hash_mix_vec.zmm, 0)); - - callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); - callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); - callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); - callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); - - // Now repeat that operation for the remaining characters, discarding older characters. - sz_size_t cycle = 1; - sz_size_t step_mask = step - 1; - for (; text_fourth != text_end; ++text_first, ++text_second, ++text_third, ++text_fourth, ++cycle) { - // 0. Load again the four characters we are dropping, shift them, and subtract. - chars_vec.zmm = _mm512_set_epi64(text_fourth[-window_length], text_third[-window_length], - text_second[-window_length], text_first[-window_length], // - text_fourth[-window_length], text_third[-window_length], - text_second[-window_length], text_first[-window_length]); - chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm); - hash_vec.zmm = _mm512_sub_epi64(hash_vec.zmm, _mm512_mullo_epi64(chars_vec.zmm, prime_power_vec.zmm)); - - // 1. Multiply the hashes by the base. - hash_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, base_vec.zmm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. 
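// The same four input bytes are broadcast into both 256-bit halves of `chars_vec` below: one half
// advances the base-31 states, the other the base-257 states, whose characters are additionally
// offset via `shift_vec`, so a single lane-wise 64-bit update step drives all eight rolling
// hashes at once.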
- chars_vec.zmm = _mm512_set_epi64(text_fourth[0], text_third[0], text_second[0], text_first[0], //
- text_fourth[0], text_third[0], text_second[0], text_first[0]);
- chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm);
-
- // ... and prefetch the next four characters into Level 2 or higher.
- _mm_prefetch((sz_cptr_t)text_fourth + 1, _MM_HINT_T1);
- _mm_prefetch((sz_cptr_t)text_third + 1, _MM_HINT_T1);
- _mm_prefetch((sz_cptr_t)text_second + 1, _MM_HINT_T1);
- _mm_prefetch((sz_cptr_t)text_first + 1, _MM_HINT_T1);
-
- // 3. Add the incoming characters.
- hash_vec.zmm = _mm512_add_epi64(hash_vec.zmm, chars_vec.zmm);
-
- // 4. Compute the modulo. Assuming there are only 59 values between our prime
- // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime.
- hash_vec.zmm = _mm512_mask_blend_epi8(_mm512_cmpgt_epi64_mask(hash_vec.zmm, prime_vec.zmm), hash_vec.zmm,
- _mm512_sub_epi64(hash_vec.zmm, prime_vec.zmm));
-
- // 5. Compute the hash mix, that will be used to index into the fingerprint.
- // This includes a serial step at the end.
- hash_mix_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, golden_ratio_vec.zmm);
- hash_mix_vec.ymms[0] = _mm256_xor_si256(_mm512_extracti64x4_epi64(hash_mix_vec.zmm, 1), //
- _mm512_castsi512_si256(hash_mix_vec.zmm));
-
- if ((cycle & step_mask) == 0) {
- callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle);
- callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle);
- callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle);
- callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle);
+ // Now we can guarantee that the relative shift within registers is from 1 to 63 bytes and the output is aligned.
+ // We may need to shift across more than two ZMM registers, so the `valignr` family of instructions is worth considering.
+ // Sadly, using `_mm512_alignr_epi8` doesn't make sense, as it operates at a 128-bit granularity.
+ //
+ // - `_mm256_alignr_epi8` shifts an entire 256-bit register, but we would need many of them.
+ // - `_mm512_alignr_epi32` shifts 512-bit chunks, but only if the `shift` is a multiple of 4 bytes.
+ // - `_mm512_alignr_epi64` shifts 512-bit chunks by 8 bytes.
+ //
+ // All of those have a latency of 1 cycle, and the shift amount must be an immediate value!
+ // For 1-byte-shift granularity, `_mm512_permutex2var_epi8` has a latency of 6 cycles and needs VBMI!
+ // The most efficient and broadly compatible alternative could be to use a combination of align and shuffle.
+ // A similar approach was outlined in "Byte-wise alignr in AVX512F" by Wojciech Muła.
+ // http://0x80.pl/notesen/2016-10-16-avx512-byte-alignr.html
+ //
+ // That solution is quite verbose, as it assumes compile-time constants for the shift amount.
+ // A cleaner one, with a latency of 3 cycles, is to use `_mm512_permutexvar_epi8` or
+ // `_mm512_mask_permutexvar_epi8`, which can be seen as a combination of a cross-register shuffle and a blend,
+ // and is available with VBMI. That solution is still noticeably slower than AVX2.
+ //
+ // The GLibC implementation also uses non-temporal stores for larger buffers; we don't.
+ // https://codebrowser.dev/glibc/glibc/sysdeps/x86_64/multiarch/memmove-avx512-no-vzeroupper.S.html
+ if (left_to_right_traversal) {
+ // Head, body, and tail.
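// As a rough illustration of the `_mm512_permutex2var_epi8` option mentioned above - a sketch of
// the alternative, not what this branch actually does - a byte-granular "alignr" over two ZMM
// registers could be assembled like this:
//
//     sz_u8_t indices[64];
//     for (int i = 0; i != 64; ++i) indices[i] = (sz_u8_t)(i + shift); // `shift` in [1; 63]
//     __m512i selector = _mm512_loadu_si512(indices);
//     // Index values below 64 pick bytes from `low`, values 64 and above pick bytes from `high`.
//     __m512i aligned = _mm512_permutex2var_epi8(low, selector, high);
//
// In practice the selector would be built once per call, outside of the copy loop.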
+ _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source)); + for (target += head_length, source += head_length; body_length >= 64; + target += 64, source += 64, body_length -= 64) + _mm512_store_si512(target, _mm512_loadu_si512(source)); + _mm512_mask_storeu_epi8(target, tail_mask, _mm512_maskz_loadu_epi8(tail_mask, source)); + } + else { + // Tail, body, and head. + _mm512_mask_storeu_epi8(target + head_length + body_length, tail_mask, + _mm512_maskz_loadu_epi8(tail_mask, source + head_length + body_length)); + for (; body_length >= 64; body_length -= 64) + _mm512_store_si512(target + head_length + body_length - 64, + _mm512_loadu_si512(source + head_length + body_length - 64)); + _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source)); } } } #pragma clang attribute pop #pragma GCC pop_options +#endif // SZ_USE_SKYLAKE +#pragma endregion // Skylake Implementation +/* AVX512 implementation of the string search algorithms for Ice Lake and newer CPUs. + * Includes extensions: + * - 2017 Skylake: F, CD, ER, PF, VL, DQ, BW, + * - 2018 CannonLake: IFMA, VBMI, + * - 2019 Ice Lake: VPOPCNTDQ, VNNI, VBMI2, BITALG, GFNI, VPCLMULQDQ, VAES. + */ +#pragma region Ice Lake Implementation +#if SZ_USE_ICE #pragma GCC push_options -#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512vbmi", "avx512vbmi2", "bmi", "bmi2") -#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512vbmi,avx512vbmi2,bmi,bmi2"))), \ +#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vbmi", "bmi", "bmi2") +#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512dq,avx512vbmi,bmi,bmi2"))), \ apply_to = function) -SZ_PUBLIC void sz_look_up_transform_avx512(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { +SZ_PUBLIC void sz_look_up_transform_ice(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { // If the input is tiny (especially smaller than the look-up table itself), we may end up paying // more for organizing the SIMD registers and changing the CPU state, than for the actual computation. @@ -5920,396 +957,20 @@ SZ_PUBLIC void sz_look_up_transform_avx512(sz_cptr_t source, sz_size_t length, s } } -SZ_PUBLIC sz_cptr_t sz_find_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { - - // Before initializing the AVX-512 vectors, we may want to run the sequential code for the first few bytes. - // In practice, that only hurts, even when we have matches every 5-ish bytes. - // - // if (length < SZ_SWAR_THRESHOLD) return sz_find_charset_serial(text, length, filter); - // sz_cptr_t early_result = sz_find_charset_serial(text, SZ_SWAR_THRESHOLD, filter); - // if (early_result) return early_result; - // text += SZ_SWAR_THRESHOLD; - // length -= SZ_SWAR_THRESHOLD; - // - // Let's unzip even and odd elements and replicate them into both lanes of the YMM register. - // That way when we invoke `_mm512_shuffle_epi8` we can use the same mask for both lanes. - sz_u512_vec_t filter_even_vec, filter_odd_vec; - __m256i filter_ymm = _mm256_lddqu_si256((__m256i const *)filter); - // There are a few way to initialize filters without having native strided loads. 
- // In the cronological order of experiments: - // - serial code initializing 128 bytes of odd and even mask - // - using several shuffles - // - using `_mm512_permutexvar_epi8` - // - using `_mm512_broadcast_i32x4(_mm256_castsi256_si128(_mm256_maskz_compress_epi8(0x55555555, filter_ymm)))` - // and `_mm512_broadcast_i32x4(_mm256_castsi256_si128(_mm256_maskz_compress_epi8(0xaaaaaaaa, filter_ymm)))` - filter_even_vec.zmm = _mm512_broadcast_i32x4(_mm256_castsi256_si128( // broadcast __m128i to __m512i - _mm256_maskz_compress_epi8(0x55555555, filter_ymm))); - filter_odd_vec.zmm = _mm512_broadcast_i32x4(_mm256_castsi256_si128( // broadcast __m128i to __m512i - _mm256_maskz_compress_epi8(0xaaaaaaaa, filter_ymm))); - // After the unzipping operation, we can validate the contents of the vectors like this: - // - // for (sz_size_t i = 0; i != 16; ++i) { - // sz_assert(filter_even_vec.u8s[i] == filter->_u8s[i * 2]); - // sz_assert(filter_odd_vec.u8s[i] == filter->_u8s[i * 2 + 1]); - // sz_assert(filter_even_vec.u8s[i + 16] == filter->_u8s[i * 2]); - // sz_assert(filter_odd_vec.u8s[i + 16] == filter->_u8s[i * 2 + 1]); - // sz_assert(filter_even_vec.u8s[i + 32] == filter->_u8s[i * 2]); - // sz_assert(filter_odd_vec.u8s[i + 32] == filter->_u8s[i * 2 + 1]); - // sz_assert(filter_even_vec.u8s[i + 48] == filter->_u8s[i * 2]); - // sz_assert(filter_odd_vec.u8s[i + 48] == filter->_u8s[i * 2 + 1]); - // } - // - sz_u512_vec_t text_vec; - sz_u512_vec_t lower_nibbles_vec, higher_nibbles_vec; - sz_u512_vec_t bitset_even_vec, bitset_odd_vec; - sz_u512_vec_t bitmask_vec, bitmask_lookup_vec; - bitmask_lookup_vec.zmm = _mm512_set_epi8( // - -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // - -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // - -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // - -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1); - - while (length) { - // The following algorithm is a transposed equivalent of the "SIMDized check which bytes are in a set" - // solutions by Wojciech Muła. We populate the bitmask differently and target newer CPUs, so - // StrinZilla uses a somewhat different approach. - // http://0x80.pl/articles/simd-byte-lookup.html#alternative-implementation-new - // - // sz_u8_t input = *(sz_u8_t const *)text; - // sz_u8_t lo_nibble = input & 0x0f; - // sz_u8_t hi_nibble = input >> 4; - // sz_u8_t bitset_even = filter_even_vec.u8s[hi_nibble]; - // sz_u8_t bitset_odd = filter_odd_vec.u8s[hi_nibble]; - // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); - // sz_u8_t bitset = lo_nibble < 8 ? bitset_even : bitset_odd; - // if ((bitset & bitmask) != 0) return text; - // else { length--, text++; } - // - // The nice part about this, loading the strided data is vey easy with Arm NEON, - // while with x86 CPUs after AVX, shuffles within 256 bits shouldn't be an issue either. 
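// Worked example for the byte 'a' (0x61): `lo_nibble == 0x1`, `hi_nibble == 0x6`, so the candidate
// bitset byte is `filter->_u8s[0x6 * 2]` (an "even" byte, since `lo_nibble < 8`) and the probe mask
// is `1 << 1` - the same bit the serial `(set->_u8s[c >> 3] & (1u << (c & 7u)))` check would test.
// The code below performs that exact lookup for up to 64 input bytes per iteration.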
- sz_size_t load_length = sz_min_of_two(length, 64); - __mmask64 load_mask = _sz_u64_mask_until(load_length); - text_vec.zmm = _mm512_maskz_loadu_epi8(load_mask, text); - lower_nibbles_vec.zmm = _mm512_and_si512(text_vec.zmm, _mm512_set1_epi8(0x0f)); - bitmask_vec.zmm = _mm512_shuffle_epi8(bitmask_lookup_vec.zmm, lower_nibbles_vec.zmm); - // - // At this point we can validate the `bitmask_vec` contents like this: - // - // for (sz_size_t i = 0; i != load_length; ++i) { - // sz_u8_t input = *(sz_u8_t const *)(text + i); - // sz_u8_t lo_nibble = input & 0x0f; - // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); - // sz_assert(bitmask_vec.u8s[i] == bitmask); - // } - // - // Shift right every byte by 4 bits. - // There is no `_mm512_srli_epi8` intrinsic, so we have to use `_mm512_srli_epi16` - // and combine it with a mask to clear the higher bits. - higher_nibbles_vec.zmm = _mm512_and_si512(_mm512_srli_epi16(text_vec.zmm, 4), _mm512_set1_epi8(0x0f)); - bitset_even_vec.zmm = _mm512_shuffle_epi8(filter_even_vec.zmm, higher_nibbles_vec.zmm); - bitset_odd_vec.zmm = _mm512_shuffle_epi8(filter_odd_vec.zmm, higher_nibbles_vec.zmm); - // - // At this point we can validate the `bitset_even_vec` and `bitset_odd_vec` contents like this: - // - // for (sz_size_t i = 0; i != load_length; ++i) { - // sz_u8_t input = *(sz_u8_t const *)(text + i); - // sz_u8_t const *bitset_ptr = &filter->_u8s[0]; - // sz_u8_t hi_nibble = input >> 4; - // sz_u8_t bitset_even = bitset_ptr[hi_nibble * 2]; - // sz_u8_t bitset_odd = bitset_ptr[hi_nibble * 2 + 1]; - // sz_assert(bitset_even_vec.u8s[i] == bitset_even); - // sz_assert(bitset_odd_vec.u8s[i] == bitset_odd); - // } - // - // TODO: Is this a good place for ternary logic? - __mmask64 take_first = _mm512_cmplt_epi8_mask(lower_nibbles_vec.zmm, _mm512_set1_epi8(8)); - bitset_even_vec.zmm = _mm512_mask_blend_epi8(take_first, bitset_odd_vec.zmm, bitset_even_vec.zmm); - __mmask64 matches_mask = _mm512_mask_test_epi8_mask(load_mask, bitset_even_vec.zmm, bitmask_vec.zmm); - if (matches_mask) { - int offset = sz_u64_ctz(matches_mask); - return text + offset; - } - else { text += load_length, length -= load_length; } - } - - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { - return sz_rfind_charset_serial(text, length, filter); -} - -SZ_PUBLIC sz_cptr_t sz_find_many_avx512( // - sz_cptr_t haystack, sz_size_t haystack_length, // - sz_cptr_t const *needles, sz_size_t const *needles_lengths, // - sz_size_t *needle_offset) { - - // When dealing with huge needles vocabularies, like in tokenization workloads, we need to construct an automaton. - // But in many cases, the vocabulary is small enough to use a simpler DFA-less approach, combining the ideas from - // the `sz_find_avx512` and `sz_find_charset_avx512` functions. - // - // Pick the offsets within needles where there is the least variance in the characters. - // Like for "the", "then", "there", "these", "those", "their", "they", "them", "that", "this", "thus", "than": - // - // 0: 't' - // 1: 'h' - // 2: 'e', 'a', 'i', 'o', 'u' - // 3: 'n', 'r', 's', 'i', 'y', 'm', 't' - // - // So depending on our "register budget", we can use a different number of pivot points: offset 0, 1, 2 make - // the most sense if we can only use 3 ZMM registers. - sz_unused(haystack && haystack_length && needles && needles_lengths && needle_offset); - return 0; -} - -/** - * Computes the Needleman Wunsch alignment score between two strings. 
- * The method uses 32-bit integers to accumulate the running score for every cell in the matrix. - * Assuming the costs of substitutions can be arbitrary signed 8-bit integers, the method is expected to be used - * on strings not exceeding 2^24 length or 16.7 million characters. - * - * Unlike the `_sz_edit_distance_skewed_diagonals_upto65k_avx512` method, this one uses signed integers to store - * the accumulated score. Moreover, it's primary bottleneck is the latency of gathering the substitution costs - * from the substitution matrix. If we use the diagonal order, we will be comparing a slice of the first string with - * a slice of the second. If we stick to the conventional horizontal order, we will be comparing one character against - * a slice, which is much easier to optimize. In that case we are sampling costs not from arbitrary parts of - * a 256 x 256 matrix, but from a single row! - */ -SZ_INTERNAL sz_ssize_t _sz_alignment_score_wagner_fisher_upto17m_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc) { - - // If one of the strings is empty - the edit distance is equal to the length of the other one - if (longer_length == 0) return (sz_ssize_t)shorter_length * gap; - if (shorter_length == 0) return (sz_ssize_t)longer_length * gap; - - // Let's make sure that we use the amount proportional to the - // number of elements in the shorter string, not the larger. - if (shorter_length > longer_length) { - sz_pointer_swap((void **)&longer_length, (void **)&shorter_length); - sz_pointer_swap((void **)&longer, (void **)&shorter); - } - - // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. - sz_memory_allocator_t global_alloc; - if (!alloc) { - sz_memory_allocator_init_default(&global_alloc); - alloc = &global_alloc; - } - - sz_size_t const max_length = 256ull * 256ull * 256ull; - sz_size_t const n = longer_length + 1; - sz_assert(n < max_length && "The length must fit into 24-bit integer. Otherwise use serial variant."); - sz_unused(longer_length && max_length); - - sz_size_t buffer_length = sizeof(sz_i32_t) * n * 2; - sz_i32_t *distances = (sz_i32_t *)alloc->allocate(buffer_length, alloc->handle); - sz_i32_t *previous_distances = distances; - sz_i32_t *current_distances = previous_distances + n; - - // Intialize the first row of the Levenshtein matrix with `iota`. - for (sz_size_t idx_longer = 0; idx_longer != n; ++idx_longer) - previous_distances[idx_longer] = (sz_i32_t)idx_longer * gap; - - /// Contains up to 16 consecutive characters from the longer string. - sz_u512_vec_t longer_vec; - sz_u512_vec_t cost_deletion_vec, cost_substitution_vec, lookup_substitution_vec, current_vec; - sz_u512_vec_t row_first_subs_vec, row_second_subs_vec, row_third_subs_vec, row_fourth_subs_vec; - sz_u512_vec_t shuffled_first_subs_vec, shuffled_second_subs_vec, shuffled_third_subs_vec, shuffled_fourth_subs_vec; - - // Prepare constants and masks. 
- sz_u512_vec_t is_third_or_fourth_vec, is_second_or_fourth_vec, gap_vec; - { - char is_third_or_fourth_check, is_second_or_fourth_check; - *(sz_u8_t *)&is_third_or_fourth_check = 0x80, *(sz_u8_t *)&is_second_or_fourth_check = 0x40; - is_third_or_fourth_vec.zmm = _mm512_set1_epi8(is_third_or_fourth_check); - is_second_or_fourth_vec.zmm = _mm512_set1_epi8(is_second_or_fourth_check); - gap_vec.zmm = _mm512_set1_epi32(gap); - } - - sz_u8_t const *shorter_unsigned = (sz_u8_t const *)shorter; - for (sz_size_t idx_shorter = 0; idx_shorter != shorter_length; ++idx_shorter) { - sz_i32_t last_in_row = current_distances[0] = (sz_i32_t)(idx_shorter + 1) * gap; - - // Load one row of the substitution matrix into four ZMM registers. - sz_error_cost_t const *row_subs = subs + shorter_unsigned[idx_shorter] * 256u; - row_first_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 0); - row_second_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 1); - row_third_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 2); - row_fourth_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 3); - - // In the serial version we have one forward pass, that computes the deletion, - // insertion, and substitution costs at once. - // for (sz_size_t idx_longer = 0; idx_longer < longer_length; ++idx_longer) { - // sz_ssize_t cost_deletion = previous_distances[idx_longer + 1] + gap; - // sz_ssize_t cost_insertion = current_distances[idx_longer] + gap; - // sz_ssize_t cost_substitution = previous_distances[idx_longer] + row_subs[longer_unsigned[idx_longer]]; - // current_distances[idx_longer + 1] = sz_min_of_three(cost_deletion, cost_insertion, cost_substitution); - // } - // - // Given the complexity of handling the data-dependency between consecutive insertion cost computations - // within a Levenshtein matrix, the simplest design would be to vectorize every kind of cost computation - // separately. - // 1. Compute substitution costs for up to 64 characters at once, upcasting from 8-bit integers to 32. - // 2. Compute the pairwise minimum with deletion costs. - // 3. Inclusive prefix minimum computation to combine with addition costs. - // Proceeding with substitutions: - for (sz_size_t idx_longer = 0; idx_longer < longer_length; idx_longer += 64) { - sz_size_t register_length = sz_min_of_two(longer_length - idx_longer, 64); - __mmask64 mask = _sz_u64_mask_until(register_length); - longer_vec.zmm = _mm512_maskz_loadu_epi8(mask, longer + idx_longer); - - // Blend the `row_(first|second|third|fourth)_subs_vec` into `current_vec`, picking the right source - // for every character in `longer_vec`. Before that, we need to permute the subsititution vectors. - // Only the bottom 6 bits of a byte are used in VPERB, so we don't even need to mask. - shuffled_first_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_first_subs_vec.zmm); - shuffled_second_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_second_subs_vec.zmm); - shuffled_third_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_third_subs_vec.zmm); - shuffled_fourth_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_fourth_subs_vec.zmm); - - // To blend we can invoke three `_mm512_cmplt_epu8_mask`, but we can also achieve the same using - // the AND logical operation, checking the top two bits of every byte. - // Continuing this thought, we can use the VPTESTMB instruction to output the mask after the AND. 
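// In scalar terms, the blend below is a 2-bit multiplexer over the four 64-byte quarters of the
// substitution row: for a character `c`, the 0x40 bit picks the second register of a pair and the
// 0x80 bit picks the second pair, so `cost = row_subs[c]` decomposes into
//
//     quarter = c >> 6;                  // 0..3, encoded by the two bits tested below
//     cost = quarters[quarter][c & 63];  // `c & 63` is exactly what VPERMB already used as index
//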
- __mmask64 is_third_or_fourth = _mm512_mask_test_epi8_mask(mask, longer_vec.zmm, is_third_or_fourth_vec.zmm); - __mmask64 is_second_or_fourth = - _mm512_mask_test_epi8_mask(mask, longer_vec.zmm, is_second_or_fourth_vec.zmm); - lookup_substitution_vec.zmm = _mm512_mask_blend_epi8( - is_third_or_fourth, - // Choose between the first and the second. - _mm512_mask_blend_epi8(is_second_or_fourth, shuffled_first_subs_vec.zmm, shuffled_second_subs_vec.zmm), - // Choose between the third and the fourth. - _mm512_mask_blend_epi8(is_second_or_fourth, shuffled_third_subs_vec.zmm, shuffled_fourth_subs_vec.zmm)); - - // First, sign-extend lower and upper 16 bytes to 16-bit integers. - __m512i current_0_31_vec = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(lookup_substitution_vec.zmm, 0)); - __m512i current_32_63_vec = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(lookup_substitution_vec.zmm, 1)); - - // Now extend those 16-bit integers to 32-bit. - // This isn't free, same as the subsequent store, so we only want to do that for the populated lanes. - // To minimize the number of loads and stores, we can combine our substitution costs with the previous - // distances, containing the deletion costs. - { - cost_substitution_vec.zmm = _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer); - cost_substitution_vec.zmm = _mm512_add_epi32( - cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_0_31_vec, 0))); - cost_deletion_vec.zmm = _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer); - cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm); - current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm); - - // Inclusive prefix minimum computation to combine with insertion costs. - // Simply disabling this operation results in 5x performance improvement, meaning - // that this operation is responsible for 80% of the total runtime. - // for (sz_size_t idx_longer = 0; idx_longer < longer_length; ++idx_longer) { - // current_distances[idx_longer + 1] = - // sz_max_of_two(current_distances[idx_longer] + gap, current_distances[idx_longer + 1]); - // } - // - // To perform the same operation in vectorized form, we need to perform a tree-like reduction, - // that will involve multiple steps. It's quite expensive and should be first tested in the - // "experimental" section. - // - // Another approach might be loop unrolling: - // current_vec.i32s[0] = last_in_row = sz_i32_max_of_two(current_vec.i32s[0], last_in_row + gap); - // current_vec.i32s[1] = last_in_row = sz_i32_max_of_two(current_vec.i32s[1], last_in_row + gap); - // current_vec.i32s[2] = last_in_row = sz_i32_max_of_two(current_vec.i32s[2], last_in_row + gap); - // ... yet this approach is also quite expensive. - for (int i = 0; i != 16; ++i) - current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap); - _mm512_mask_storeu_epi32(current_distances + idx_longer + 1, (__mmask16)mask, current_vec.zmm); - } - - // Export the values from 16 to 31. 
- if (register_length > 16) { - mask = _kshiftri_mask64(mask, 16); - cost_substitution_vec.zmm = - _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer + 16); - cost_substitution_vec.zmm = _mm512_add_epi32( - cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_0_31_vec, 1))); - cost_deletion_vec.zmm = - _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer + 16); - cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm); - current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm); - - // Aggregate running insertion costs within the register. - for (int i = 0; i != 16; ++i) - current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap); - _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 16, (__mmask16)mask, current_vec.zmm); - } - - // Export the values from 32 to 47. - if (register_length > 32) { - mask = _kshiftri_mask64(mask, 16); - cost_substitution_vec.zmm = - _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer + 32); - cost_substitution_vec.zmm = _mm512_add_epi32( - cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_32_63_vec, 0))); - cost_deletion_vec.zmm = - _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer + 32); - cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm); - current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm); - - // Aggregate running insertion costs within the register. - for (int i = 0; i != 16; ++i) - current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap); - _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 32, (__mmask16)mask, current_vec.zmm); - } - - // Export the values from 32 to 47. - if (register_length > 48) { - mask = _kshiftri_mask64(mask, 16); - cost_substitution_vec.zmm = - _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer + 48); - cost_substitution_vec.zmm = _mm512_add_epi32( - cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_32_63_vec, 1))); - cost_deletion_vec.zmm = - _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer + 48); - cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm); - current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm); - - // Aggregate running insertion costs within the register. - for (int i = 0; i != 16; ++i) - current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap); - _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 48, (__mmask16)mask, current_vec.zmm); - } - } - - // Swap previous_distances and current_distances pointers - sz_pointer_swap((void **)&previous_distances, (void **)¤t_distances); - } - - // Cache scalar before `free` call. 
- sz_ssize_t result = previous_distances[longer_length]; - alloc->free(distances, buffer_length, alloc->handle); - return result; -} - -SZ_INTERNAL sz_ssize_t sz_alignment_score_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc) { - - if (sz_max_of_two(shorter_length, longer_length) < (256ull * 256ull * 256ull)) - return _sz_alignment_score_wagner_fisher_upto17m_avx512(shorter, shorter_length, longer, longer_length, subs, - gap, alloc); - else - return sz_alignment_score_serial(shorter, shorter_length, longer, longer_length, subs, gap, alloc); -} - enum sz_encoding_t { sz_encoding_unknown_k = 0, sz_encoding_ascii_k = 1, sz_encoding_utf8_k = 2, sz_encoding_utf16_k = 3, sz_encoding_utf32_k = 4, - sz_jwt_k, - sz_base64_k, + sz_encoding_jwt_k = 5, + sz_encoding_base64_k = 6, // Low priority encodings: - sz_encoding_utf8bom_k = 5, - sz_encoding_utf16le_k = 6, - sz_encoding_utf16be_k = 7, - sz_encoding_utf32le_k = 8, - sz_encoding_utf32be_k = 9, + sz_encoding_utf8bom_k = 7, + sz_encoding_utf16le_k = 8, + sz_encoding_utf16be_k = 9, + sz_encoding_utf32le_k = 10, + sz_encoding_utf32be_k = 11, }; // Character Set Detection is one of the most commonly performed operations in data processing with @@ -6354,78 +1015,18 @@ SZ_PUBLIC sz_bool_t sz_detect_encoding(sz_cptr_t text, sz_size_t length) { #pragma clang attribute pop #pragma GCC pop_options -#endif - -#pragma endregion +#endif // SZ_USE_ICE +#pragma endregion // Ice Lake Implementation -/* @brief Implementation of the string search algorithms using the Arm NEON instruction set, available on 64-bit - * Arm processors. Implements: {substring search, character search, character set search} x {forward, reverse}. +/* Implementation of the string search algorithms using the Arm NEON instruction set, available on 64-bit + * Arm processors. Covers billions of mobile CPUs worldwide, including Apple's A-series, and Qualcomm's Snapdragon. */ -#pragma region ARM NEON - -#if SZ_USE_ARM_NEON +#pragma region NEON Implementation +#if SZ_USE_NEON #pragma GCC push_options #pragma GCC target("arch=armv8.2-a+simd") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd"))), apply_to = function) -/** - * @brief Helper structure to simplify work with 64-bit words. - */ -typedef union sz_u128_vec_t { - uint8x16_t u8x16; - uint16x8_t u16x8; - uint32x4_t u32x4; - uint64x2_t u64x2; - sz_u64_t u64s[2]; - sz_u32_t u32s[4]; - sz_u16_t u16s[8]; - sz_u8_t u8s[16]; -} sz_u128_vec_t; - -SZ_INTERNAL sz_u64_t _sz_vreinterpretq_u8_u4(uint8x16_t vec) { - // Use `vshrn` to produce a bitmask, similar to `movemask` in SSE. - // https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon - return vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(vec), 4)), 0) & 0x8888888888888888ull; -} - -SZ_PUBLIC sz_ordering_t sz_order_neon(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { - //! Before optimizing this, read the "Operations Not Worth Optimizing" in Contributions Guide: - //! 
https://github.com/ashvardanian/StringZilla/blob/main/CONTRIBUTING.md#general-performance-observations - return sz_order_serial(a, a_length, b, b_length); -} - -SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - sz_u128_vec_t a_vec, b_vec; - for (; length >= 16; a += 16, b += 16, length -= 16) { - a_vec.u8x16 = vld1q_u8((sz_u8_t const *)a); - b_vec.u8x16 = vld1q_u8((sz_u8_t const *)b); - uint8x16_t cmp = vceqq_u8(a_vec.u8x16, b_vec.u8x16); - if (vminvq_u8(cmp) != 255) { return sz_false_k; } // Check if all bytes match - } - - // Handle remaining bytes - if (length) return sz_equal_serial(a, b, length); - return sz_true_k; -} - -SZ_PUBLIC sz_u64_t sz_checksum_neon(sz_cptr_t text, sz_size_t length) { - uint64x2_t sum_vec = vdupq_n_u64(0); - - // Process 16 bytes (128 bits) at a time - for (; length >= 16; text += 16, length -= 16) { - uint8x16_t vec = vld1q_u8((sz_u8_t const *)text); // Load 16 bytes - uint16x8_t pairwise_sum1 = vpaddlq_u8(vec); // Pairwise add lower and upper 8 bits - uint32x4_t pairwise_sum2 = vpaddlq_u16(pairwise_sum1); // Pairwise add 16-bit results - uint64x2_t pairwise_sum3 = vpaddlq_u32(pairwise_sum2); // Pairwise add 32-bit results - sum_vec = vaddq_u64(sum_vec, pairwise_sum3); // Accumulate the sum - } - - // Final reduction of `sum_vec` to a single scalar - sz_u64_t sum = vgetq_lane_u64(sum_vec, 0) + vgetq_lane_u64(sum_vec, 1); - if (length) sum += sz_checksum_serial(text, length); - return sum; -} - SZ_PUBLIC void sz_copy_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { // In most cases the `source` and the `target` are not aligned, but we should // at least make sure that writes don't touch many cache lines. @@ -6524,8 +1125,9 @@ SZ_PUBLIC void sz_look_up_transform_neon(sz_cptr_t source, sz_size_t length, sz_ lookup_64_to_127_vec.u8x16 = vqtbl4q_u8(lut_64_to_127_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0x40))); lookup_128_to_191_vec.u8x16 = vqtbl4q_u8(lut_128_to_191_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0x80))); lookup_192_to_255_vec.u8x16 = vqtbl4q_u8(lut_192_to_255_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0xc0))); - blended_0_to_255_vec.u8x16 = vorrq_u8(vorrq_u8(lookup_0_to_63_vec.u8x16, lookup_64_to_127_vec.u8x16), - vorrq_u8(lookup_128_to_191_vec.u8x16, lookup_192_to_255_vec.u8x16)); + blended_0_to_255_vec.u8x16 = vorrq_u8( // + vorrq_u8(lookup_0_to_63_vec.u8x16, lookup_64_to_127_vec.u8x16), + vorrq_u8(lookup_128_to_191_vec.u8x16, lookup_192_to_255_vec.u8x16)); vst1q_u8((sz_u8_t *)target, blended_0_to_255_vec.u8x16); } @@ -6533,232 +1135,16 @@ SZ_PUBLIC void sz_look_up_transform_neon(sz_cptr_t source, sz_size_t length, sz_ for (; tail_length; target += 1, source += 1, tail_length -= 1) *target = lut[*(sz_u8_t const *)source]; } -SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - sz_u64_t matches; - sz_u128_vec_t h_vec, n_vec, matches_vec; - n_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)n); - - while (h_length >= 16) { - h_vec.u8x16 = vld1q_u8((sz_u8_t const *)h); - matches_vec.u8x16 = vceqq_u8(h_vec.u8x16, n_vec.u8x16); - // In Arm NEON we don't have a `movemask` to combine it with `ctz` and get the offset of the match. - // But assuming the `vmaxvq` is cheap, we can use it to find the first match, by blending (bitwise selecting) - // the vector with a relative offsets array. 
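// `_sz_vreinterpretq_u8_u4` (defined above) narrows every 0x00/0xFF comparison byte to a nibble
// with `vshrn`, so `matches` carries 4 bits per haystack byte, of which only the top bit of each
// nibble survives the masking. That is why byte offsets are recovered as `sz_u64_ctz(matches) / 4`
// here, and as `sz_u64_clz(matches) / 4` in the reverse-order searches further down.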
- matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - if (matches) return h + sz_u64_ctz(matches) / 4; - - h += 16, h_length -= 16; - } - - return sz_find_byte_serial(h, h_length, n); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_byte_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - sz_u64_t matches; - sz_u128_vec_t h_vec, n_vec, matches_vec; - n_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)n); - - while (h_length >= 16) { - h_vec.u8x16 = vld1q_u8((sz_u8_t const *)h + h_length - 16); - matches_vec.u8x16 = vceqq_u8(h_vec.u8x16, n_vec.u8x16); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - if (matches) return h + h_length - 1 - sz_u64_clz(matches) / 4; - h_length -= 16; - } - - return sz_rfind_byte_serial(h, h_length, n); -} - -SZ_PUBLIC sz_u64_t _sz_find_charset_neon_register(sz_u128_vec_t h_vec, uint8x16_t set_top_vec_u8x16, - uint8x16_t set_bottom_vec_u8x16) { - - // Once we've read the characters in the haystack, we want to - // compare them against our bitset. The serial version of that code - // would look like: `(set_->_u8s[c >> 3] & (1u << (c & 7u))) != 0`. - uint8x16_t byte_index_vec = vshrq_n_u8(h_vec.u8x16, 3); - uint8x16_t byte_mask_vec = vshlq_u8(vdupq_n_u8(1), vreinterpretq_s8_u8(vandq_u8(h_vec.u8x16, vdupq_n_u8(7)))); - uint8x16_t matches_top_vec = vqtbl1q_u8(set_top_vec_u8x16, byte_index_vec); - // The table lookup instruction in NEON replies to out-of-bound requests with zeros. - // The values in `byte_index_vec` all fall in [0; 32). So for values under 16, substracting 16 will underflow - // and map into interval [240, 256). Meaning that those will be populated with zeros and we can safely - // merge `matches_top_vec` and `matches_bottom_vec` with a bitwise OR. - uint8x16_t matches_bottom_vec = vqtbl1q_u8(set_bottom_vec_u8x16, vsubq_u8(byte_index_vec, vdupq_n_u8(16))); - uint8x16_t matches_vec = vorrq_u8(matches_top_vec, matches_bottom_vec); - // Istead of pure `vandq_u8`, we can immediately broadcast a match presence across each 8-bit word. - matches_vec = vtstq_u8(matches_vec, byte_mask_vec); - return _sz_vreinterpretq_u8_u4(matches_vec); -} - -SZ_PUBLIC sz_cptr_t sz_find_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_find_byte_neon(h, h_length, n); - - // Scan through the string. - // Assuming how tiny the Arm NEON registers are, we should avoid internal branches at all costs. - // That's why, for smaller needles, we use different loops. - if (n_length == 2) { - // Broadcast needle characters into SIMD registers. - sz_u64_t matches; - sz_u128_vec_t h_first_vec, h_last_vec, n_first_vec, n_last_vec, matches_vec; - // Dealing with 16-bit values, we can load 2 registers at a time and compare 31 possible offsets - // in a single loop iteration. - n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[0]); - n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[1]); - for (; h_length >= 17; h += 16, h_length -= 16) { - h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 0)); - h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 1)); - matches_vec.u8x16 = - vandq_u8(vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - if (matches) return h + sz_u64_ctz(matches) / 4; - } - } - else if (n_length == 3) { - // Broadcast needle characters into SIMD registers. 
- sz_u64_t matches; - sz_u128_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec, matches_vec; - // Comparing 24-bit values is a bumer. Being lazy, I went with the same approach - // as when searching for string over 4 characters long. I only avoid the last comparison. - n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[0]); - n_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[1]); - n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[2]); - for (; h_length >= 18; h += 16, h_length -= 16) { - h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 0)); - h_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 1)); - h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 2)); - matches_vec.u8x16 = vandq_u8( // - vandq_u8( // - vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), // - vceqq_u8(h_mid_vec.u8x16, n_mid_vec.u8x16)), - vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - if (matches) return h + sz_u64_ctz(matches) / 4; - } - } - else { - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - // Broadcast those characters into SIMD registers. - sz_u64_t matches; - sz_u128_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec, matches_vec; - n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_first]); - n_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_mid]); - n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_last]); - // Walk through the string. - for (; h_length >= n_length + 16; h += 16, h_length -= 16) { - h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_first)); - h_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_mid)); - h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_last)); - matches_vec.u8x16 = vandq_u8( // - vandq_u8( // - vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), // - vceqq_u8(h_mid_vec.u8x16, n_mid_vec.u8x16)), - vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - while (matches) { - int potential_offset = sz_u64_ctz(matches) / 4; - if (sz_equal(h + potential_offset, n, n_length)) return h + potential_offset; - matches &= matches - 1; - } - } - } - - return sz_find_serial(h, h_length, n, n_length); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_rfind_byte_neon(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Will contain 4 bits per character. 
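// The reverse search mirrors `sz_find_neon` above: compare only the three "anomaly" bytes picked
// by `_sz_locate_needle_anomalies`, and treat positions where all three agree as candidates to be
// confirmed with a full comparison. Schematically, the scalar filter for a candidate offset `i` is:
//
//     if (h[i + offset_first] == n[offset_first] && h[i + offset_mid] == n[offset_mid] &&
//         h[i + offset_last] == n[offset_last] && sz_equal(h + i, n, n_length))
//         return h + i;
//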
- sz_u64_t matches; - sz_u128_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec, matches_vec; - n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_first]); - n_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_mid]); - n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_last]); - - sz_cptr_t h_reversed; - for (; h_length >= n_length + 16; h_length -= 16) { - h_reversed = h + h_length - n_length - 16 + 1; - h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h_reversed + offset_first)); - h_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h_reversed + offset_mid)); - h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h_reversed + offset_last)); - matches_vec.u8x16 = vandq_u8( // - vandq_u8( // - vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), // - vceqq_u8(h_mid_vec.u8x16, n_mid_vec.u8x16)), - vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - while (matches) { - int potential_offset = sz_u64_clz(matches) / 4; - if (sz_equal(h + h_length - n_length - potential_offset, n, n_length)) - return h + h_length - n_length - potential_offset; - sz_assert((matches & (1ull << (63 - potential_offset * 4))) != 0 && - "The bit must be set before we squash it"); - matches &= ~(1ull << (63 - potential_offset * 4)); - } - } - - return sz_rfind_serial(h, h_length, n, n_length); -} - -SZ_PUBLIC sz_cptr_t sz_find_charset_neon(sz_cptr_t h, sz_size_t h_length, sz_charset_t const *set) { - sz_u64_t matches; - sz_u128_vec_t h_vec; - uint8x16_t set_top_vec_u8x16 = vld1q_u8(&set->_u8s[0]); - uint8x16_t set_bottom_vec_u8x16 = vld1q_u8(&set->_u8s[16]); - - for (; h_length >= 16; h += 16, h_length -= 16) { - h_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h)); - matches = _sz_find_charset_neon_register(h_vec, set_top_vec_u8x16, set_bottom_vec_u8x16); - if (matches) return h + sz_u64_ctz(matches) / 4; - } - - return sz_find_charset_serial(h, h_length, set); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t h, sz_size_t h_length, sz_charset_t const *set) { - sz_u64_t matches; - sz_u128_vec_t h_vec; - uint8x16_t set_top_vec_u8x16 = vld1q_u8(&set->_u8s[0]); - uint8x16_t set_bottom_vec_u8x16 = vld1q_u8(&set->_u8s[16]); - - // Check `sz_find_charset_neon` for explanations. - for (; h_length >= 16; h_length -= 16) { - h_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h) + h_length - 16); - matches = _sz_find_charset_neon_register(h_vec, set_top_vec_u8x16, set_bottom_vec_u8x16); - if (matches) return h + h_length - 1 - sz_u64_clz(matches) / 4; - } - - return sz_rfind_charset_serial(h, h_length, set); -} - #pragma clang attribute pop #pragma GCC pop_options -#endif // Arm Neon - -#pragma endregion +#endif // SZ_USE_NEON +#pragma endregion // NEON Implementation -/* @brief Implementation of the string search algorithms using the Arm SVE variable-length registers, available - * in Arm v9 processors. - * - * Implements: - * - memory: {copy, move, fill} - * - comparisons: {equal, order} - * - search: {substring, character, character set} x {forward, reverse}. +/* Implementation of the memory operations using the Arm SVE variable-length registers, + * available in Arm v9 processors, like in Apple M4+ and Graviton 3+ CPUs. 
*/ -#pragma region ARM SVE - -#if SZ_USE_ARM_SVE +#pragma region SVE Implementation +#if SZ_USE_SVE #pragma GCC push_options #pragma GCC target("arch=armv8.2-a+sve") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve"))), apply_to = function) @@ -6867,82 +1253,23 @@ SZ_PUBLIC void sz_copy_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length) #pragma clang attribute pop #pragma GCC pop_options -#endif // Arm SVE - -#pragma endregion +#endif // SZ_USE_SVE +#pragma endregion // SVE Implementation -/* - * @brief Pick the right implementation for the string search algorithms. +/* Pick the right implementation for the string search algorithms. + * To override this behavior and precompile all backends - set `SZ_DYNAMIC_DISPATCH` to 1. */ #pragma region Compile Time Dispatching - -SZ_PUBLIC sz_u64_t sz_hash(sz_cptr_t ins, sz_size_t length) { return sz_hash_serial(ins, length); } -SZ_PUBLIC void sz_tolower(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) { sz_tolower_serial(ins, length, outs); } -SZ_PUBLIC void sz_toupper(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) { sz_toupper_serial(ins, length, outs); } -SZ_PUBLIC void sz_toascii(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) { sz_toascii_serial(ins, length, outs); } -SZ_PUBLIC sz_bool_t sz_isascii(sz_cptr_t ins, sz_size_t length) { return sz_isascii_serial(ins, length); } - -SZ_PUBLIC void sz_hashes_fingerprint(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_ptr_t fingerprint, - sz_size_t fingerprint_bytes) { - - sz_bool_t fingerprint_length_is_power_of_two = (sz_bool_t)((fingerprint_bytes & (fingerprint_bytes - 1)) == 0); - sz_string_view_t fingerprint_buffer = {fingerprint, fingerprint_bytes}; - - // There are several issues related to the fingerprinting algorithm. - // First, the memory traversal order is important. - // https://blog.stuffedcow.net/2015/08/pagewalk-coherence/ - - // In most cases the fingerprint length will be a power of two. 
- if (fingerprint_length_is_power_of_two == sz_false_k) - sz_hashes(start, length, window_length, 1, _sz_hashes_fingerprint_non_pow2_callback, &fingerprint_buffer); - else - sz_hashes(start, length, window_length, 1, _sz_hashes_fingerprint_pow2_callback, &fingerprint_buffer); -} - #if !SZ_DYNAMIC_DISPATCH -SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length) { -#if SZ_USE_X86_AVX512 - return sz_checksum_avx512(text, length); -#elif SZ_USE_X86_AVX2 - return sz_checksum_avx2(text, length); -#elif SZ_USE_ARM_NEON - return sz_checksum_neon(text, length); -#else - return sz_checksum_serial(text, length); -#endif -} - -SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { -#if SZ_USE_X86_AVX512 - return sz_equal_avx512(a, b, length); -#elif SZ_USE_X86_AVX2 - return sz_equal_avx2(a, b, length); -#elif SZ_USE_ARM_NEON - return sz_equal_neon(a, b, length); -#else - return sz_equal_serial(a, b, length); -#endif -} - -SZ_DYNAMIC sz_ordering_t sz_order(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { -#if SZ_USE_X86_AVX512 - return sz_order_avx512(a, a_length, b, b_length); -#elif SZ_USE_X86_AVX2 - return sz_order_avx2(a, a_length, b, b_length); -#elif SZ_USE_ARM_NEON - return sz_order_neon(a, a_length, b, b_length); -#else - return sz_order_serial(a, a_length, b, b_length); -#endif -} +#pragma region Core Funcitonality SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { -#if SZ_USE_X86_AVX512 +#if SZ_USE_ICE sz_copy_avx512(target, source, length); -#elif SZ_USE_X86_AVX2 +#elif SZ_USE_HASWELL sz_copy_avx2(target, source, length); -#elif SZ_USE_ARM_NEON +#elif SZ_USE_NEON sz_copy_neon(target, source, length); #else sz_copy_serial(target, source, length); @@ -6950,11 +1277,11 @@ SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { } SZ_DYNAMIC void sz_move(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { -#if SZ_USE_X86_AVX512 +#if SZ_USE_ICE sz_move_avx512(target, source, length); -#elif SZ_USE_X86_AVX2 +#elif SZ_USE_HASWELL sz_move_avx2(target, source, length); -#elif SZ_USE_ARM_NEON +#elif SZ_USE_NEON sz_move_neon(target, source, length); #else sz_move_serial(target, source, length); @@ -6962,11 +1289,11 @@ SZ_DYNAMIC void sz_move(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { } SZ_DYNAMIC void sz_fill(sz_ptr_t target, sz_size_t length, sz_u8_t value) { -#if SZ_USE_X86_AVX512 - sz_fill_avx512(target, length, value); -#elif SZ_USE_X86_AVX2 - sz_fill_avx2(target, length, value); -#elif SZ_USE_ARM_NEON +#if SZ_USE_ICE + sz_fill_skylake(target, length, value); +#elif SZ_USE_HASWELL + sz_fill_haswell(target, length, value); +#elif SZ_USE_NEON sz_fill_neon(target, length, value); #else sz_fill_serial(target, length, value); @@ -6974,183 +1301,21 @@ SZ_DYNAMIC void sz_fill(sz_ptr_t target, sz_size_t length, sz_u8_t value) { } SZ_DYNAMIC void sz_look_up_transform(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { -#if SZ_USE_X86_AVX512 - sz_look_up_transform_avx512(source, length, lut, target); -#elif SZ_USE_X86_AVX2 - sz_look_up_transform_avx2(source, length, lut, target); -#elif SZ_USE_ARM_NEON +#if SZ_USE_ICE + sz_look_up_transform_ice(source, length, lut, target); +#elif SZ_USE_HASWELL + sz_look_up_transform_haswell(source, length, lut, target); +#elif SZ_USE_NEON sz_look_up_transform_neon(source, length, lut, target); #else sz_look_up_transform_serial(source, length, lut, target); #endif } -SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, 
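/* With `SZ_DYNAMIC_DISPATCH` disabled, the preprocessor picks one backend per
 * public symbol at build time, so a caller never names a backend explicitly.
 * An illustrative sketch of a call site; the function name is a placeholder: */
void zero_then_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length) {
    sz_fill(target, length, 0);      /* resolves to e.g. sz_fill_haswell on AVX2-only builds */
    sz_copy(target, source, length); /* same single-step resolution for sz_copy */
}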
sz_cptr_t needle) { -#if SZ_USE_X86_AVX512 - return sz_find_byte_avx512(haystack, h_length, needle); -#elif SZ_USE_X86_AVX2 - return sz_find_byte_avx2(haystack, h_length, needle); -#elif SZ_USE_ARM_NEON - return sz_find_byte_neon(haystack, h_length, needle); -#else - return sz_find_byte_serial(haystack, h_length, needle); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_rfind_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) { -#if SZ_USE_X86_AVX512 - return sz_rfind_byte_avx512(haystack, h_length, needle); -#elif SZ_USE_X86_AVX2 - return sz_rfind_byte_avx2(haystack, h_length, needle); -#elif SZ_USE_ARM_NEON - return sz_rfind_byte_neon(haystack, h_length, needle); -#else - return sz_rfind_byte_serial(haystack, h_length, needle); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_find(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length) { -#if SZ_USE_X86_AVX512 - return sz_find_avx512(haystack, h_length, needle, n_length); -#elif SZ_USE_X86_AVX2 - return sz_find_avx2(haystack, h_length, needle, n_length); -#elif SZ_USE_ARM_NEON - return sz_find_neon(haystack, h_length, needle, n_length); -#else - return sz_find_serial(haystack, h_length, needle, n_length); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_rfind(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length) { -#if SZ_USE_X86_AVX512 - return sz_rfind_avx512(haystack, h_length, needle, n_length); -#elif SZ_USE_X86_AVX2 - return sz_rfind_avx2(haystack, h_length, needle, n_length); -#elif SZ_USE_ARM_NEON - return sz_rfind_neon(haystack, h_length, needle, n_length); -#else - return sz_rfind_serial(haystack, h_length, needle, n_length); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { -#if SZ_USE_X86_AVX512 - return sz_find_charset_avx512(text, length, set); -#elif SZ_USE_X86_AVX2 - return sz_find_charset_avx2(text, length, set); -#elif SZ_USE_ARM_NEON - return sz_find_charset_neon(text, length, set); -#else - return sz_find_charset_serial(text, length, set); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_rfind_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { -#if SZ_USE_X86_AVX512 - return sz_rfind_charset_avx512(text, length, set); -#elif SZ_USE_X86_AVX2 - return sz_rfind_charset_avx2(text, length, set); -#elif SZ_USE_ARM_NEON - return sz_rfind_charset_neon(text, length, set); -#else - return sz_rfind_charset_serial(text, length, set); -#endif -} - -SZ_DYNAMIC sz_size_t sz_hamming_distance( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - return sz_hamming_distance_serial(a, a_length, b, b_length, bound); -} - -SZ_DYNAMIC sz_size_t sz_hamming_distance_utf8( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - return sz_hamming_distance_utf8_serial(a, a_length, b, b_length, bound); -} - -SZ_DYNAMIC sz_size_t sz_edit_distance( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { -#if SZ_USE_X86_AVX512 - return sz_edit_distance_avx512(a, a_length, b, b_length, bound, alloc); -#else - return sz_edit_distance_serial(a, a_length, b, b_length, bound, alloc); -#endif -} - -SZ_DYNAMIC sz_size_t sz_edit_distance_utf8( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - return _sz_edit_distance_wagner_fisher_serial(a, a_length, b, b_length, bound, sz_true_k, alloc); -} - -SZ_DYNAMIC sz_ssize_t 
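/* All of the search dispatchers follow the same shape; from the caller's side
 * the usage is uniform. A hedged sketch, with the needle chosen arbitrarily: */
sz_cptr_t find_comma_separated_key(sz_cptr_t haystack, sz_size_t haystack_length) {
    sz_cptr_t needle = "key=";
    sz_cptr_t match = sz_find(haystack, haystack_length, needle, 4);
    return match; /* a null pointer signals that the needle is absent */
}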
sz_alignment_score(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, - sz_error_cost_t const *subs, sz_error_cost_t gap, - sz_memory_allocator_t *alloc) { -#if SZ_USE_X86_AVX512 - return sz_alignment_score_avx512(a, a_length, b, b_length, subs, gap, alloc); -#else - return sz_alignment_score_serial(a, a_length, b, b_length, subs, gap, alloc); -#endif -} - -SZ_DYNAMIC void sz_hashes(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // - sz_hash_callback_t callback, void *callback_handle) { -#if SZ_USE_X86_AVX512 - sz_hashes_avx512(text, length, window_length, window_step, callback, callback_handle); -#elif SZ_USE_X86_AVX2 - sz_hashes_avx2(text, length, window_length, window_step, callback, callback_handle); -#else - sz_hashes_serial(text, length, window_length, window_step, callback, callback_handle); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_find_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - return sz_find_charset(h, h_length, &set); -} - -SZ_DYNAMIC sz_cptr_t sz_find_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - sz_charset_invert(&set); - return sz_find_charset(h, h_length, &set); -} - -SZ_DYNAMIC sz_cptr_t sz_rfind_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - return sz_rfind_charset(h, h_length, &set); -} - -SZ_DYNAMIC sz_cptr_t sz_rfind_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - sz_charset_invert(&set); - return sz_rfind_charset(h, h_length, &set); -} - -SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length, - sz_random_generator_t generator, void *generator_user_data) { - sz_generate_serial(alphabet, alphabet_size, result, result_length, generator, generator_user_data); -} - -#endif -#pragma endregion +#endif // !SZ_DYNAMIC_DISPATCH +#pragma endregion // Compile Time Dispatching #ifdef __cplusplus -#pragma GCC diagnostic pop } #endif // __cplusplus - -#endif // STRINGZILLA_H_ +#endif // STRINGZILLA_MEMORY_H_ From 8b401bd41e4bd9c29c8fad9a5b83d8232efa50c7 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 15:36:38 +0000 Subject: [PATCH 033/751] Fix: Filter `similarity.h` file --- include/stringzilla/similarity.h | 6890 +++--------------------------- 1 file changed, 554 insertions(+), 6336 deletions(-) diff --git a/include/stringzilla/similarity.h b/include/stringzilla/similarity.h index de7fbcac..e811fefe 100644 --- a/include/stringzilla/similarity.h +++ b/include/stringzilla/similarity.h @@ -1,5140 +1,607 @@ /** - * @brief StringZilla is a collection of advanced string algorithms, designed to be used in Big Data applications. - * It is generally faster than LibC, and has a broader & cleaner interface, and targets modern x86 CPUs - * with AVX-512 and Arm NEON and older CPUs with SWAR and auto-vectorization. - * - * Consider overriding the following macros to customize the library: - * - * - `SZ_DEBUG=0` - whether to enable debug assertions and logging. 
- * - `SZ_DYNAMIC_DISPATCH=0` - whether to use runtime dispatching of the most advanced SIMD backend. - * - `SZ_USE_MISALIGNED_LOADS=0` - whether to use misaligned loads on platforms that support them. - * - `SZ_SWAR_THRESHOLD=24` - threshold for switching to SWAR backend over serial byte-level for-loops. - * - `SZ_USE_X86_AVX512=?` - whether to use AVX-512 instructions on x86_64. - * - `SZ_USE_X86_AVX2=?` - whether to use AVX2 instructions on x86_64. - * - `SZ_USE_ARM_NEON=?` - whether to use NEON instructions on ARM. - * - `SZ_USE_ARM_SVE=?` - whether to use SVE instructions on ARM. - * - * @see StringZilla: https://github.com/ashvardanian/StringZilla/blob/main/README.md - * @see LibC String: https://pubs.opengroup.org/onlinepubs/009695399/basedefs/string.h.html - * - * @file stringzilla.h + * @brief Hardware-accelerated string similarity utilities. + * @file similarity.h * @author Ash Vardanian - */ -#ifndef STRINGZILLA_H_ -#define STRINGZILLA_H_ - -#define STRINGZILLA_VERSION_MAJOR 3 -#define STRINGZILLA_VERSION_MINOR 11 -#define STRINGZILLA_VERSION_PATCH 0 - -/** - * @brief When set to 1, the library will include the following LibC headers: and . - * In debug builds (SZ_DEBUG=1), the library will also include and . * - * You may want to disable this compiling for use in the kernel, or in embedded systems. - * You may also avoid them, if you are very sensitive to compilation time and avoid pre-compiled headers. - * https://artificial-mind.net/projects/compile-health/ - */ -#ifndef SZ_AVOID_LIBC -#define SZ_AVOID_LIBC (0) // true or false -#endif - -/** - * @brief A misaligned load can be - trying to fetch eight consecutive bytes from an address - * that is not divisible by eight. On x86 enabled by default. On ARM it's not. - * - * Most platforms support it, but there is no industry standard way to check for those. - * This value will mostly affect the performance of the serial (SWAR) backend. - */ -#ifndef SZ_USE_MISALIGNED_LOADS -#if defined(__x86_64__) || defined(_M_X64) || defined(__i386__) || defined(_M_IX86) -#define SZ_USE_MISALIGNED_LOADS (1) // true or false -#else -#define SZ_USE_MISALIGNED_LOADS (0) // true or false -#endif -#endif - -/** - * @brief Removes compile-time dispatching, and replaces it with runtime dispatching. - * So the `sz_find` function will invoke the most advanced backend supported by the CPU, - * that runs the program, rather than the most advanced backend supported by the CPU - * used to compile the library or the downstream application. - */ -#ifndef SZ_DYNAMIC_DISPATCH -#define SZ_DYNAMIC_DISPATCH (0) // true or false -#endif - -/** - * @brief Analogous to `size_t` and `std::size_t`, unsigned integer, identical to pointer size. - * 64-bit on most platforms where pointers are 64-bit. - * 32-bit on platforms where pointers are 32-bit. - */ -#if defined(__LP64__) || defined(_LP64) || defined(__x86_64__) || defined(_WIN64) -#define SZ_DETECT_64_BIT (1) -#define SZ_SIZE_MAX (0xFFFFFFFFFFFFFFFFull) // Largest unsigned integer that fits into 64 bits. -#define SZ_SSIZE_MAX (0x7FFFFFFFFFFFFFFFull) // Largest signed integer that fits into 64 bits. -#else -#define SZ_DETECT_64_BIT (0) -#define SZ_SIZE_MAX (0xFFFFFFFFu) // Largest unsigned integer that fits into 32 bits. -#define SZ_SSIZE_MAX (0x7FFFFFFFu) // Largest signed integer that fits into 32 bits. -#endif - -/** - * @brief On Big-Endian machines StringZilla will work in compatibility mode. 
- * This disables SWAR hacks to minimize code duplication, assuming practically - * all modern popular platforms are Little-Endian. + * Includes core APIs: * - * This variable is hard to infer from macros reliably. It's best to set it manually. - * For that CMake provides the `TestBigEndian` and `CMAKE__BYTE_ORDER` (from 3.20 onwards). - * In Python one can check `sys.byteorder == 'big'` in the `setup.py` script and pass the appropriate macro. - * https://stackoverflow.com/a/27054190 - */ -#ifndef SZ_DETECT_BIG_ENDIAN -#if defined(__BYTE_ORDER) && __BYTE_ORDER == __BIG_ENDIAN || defined(__BIG_ENDIAN__) || defined(__ARMEB__) || \ - defined(__THUMBEB__) || defined(__AARCH64EB__) || defined(_MIBSEB) || defined(__MIBSEB) || defined(__MIBSEB__) -#define SZ_DETECT_BIG_ENDIAN (1) //< It's a big-endian target architecture -#else -#define SZ_DETECT_BIG_ENDIAN (0) //< It's a little-endian target architecture -#endif -#endif - -/* - * Debugging and testing. - */ -#ifndef SZ_DEBUG -#if defined(DEBUG) || defined(_DEBUG) // This means "Not using DEBUG information". -#define SZ_DEBUG (1) -#else -#define SZ_DEBUG (0) -#endif -#endif - -/** - * @brief Threshold for switching to SWAR (8-bytes at a time) backend over serial byte-level for-loops. - * On very short strings, under 16 bytes long, at most a single word will be processed with SWAR. - * Assuming potentially misaligned loads, SWAR makes sense only after ~24 bytes. - */ -#ifndef SZ_SWAR_THRESHOLD -#if SZ_DEBUG -#define SZ_SWAR_THRESHOLD (8u) // 8 bytes in debug builds -#else -#define SZ_SWAR_THRESHOLD (24u) // 24 bytes in release builds -#endif -#endif - -/* Annotation for the public API symbols: + * - `sz_edit_distance` & `sz_edit_distance_utf8` for Levenshtein edit-distance computation. + * - `sz_alignment_score` for weighted Needleman-Wunsch global alignment. + * - `sz_hamming_distance` & `sz_hamming_distance_utf8` for Hamming distance computation. * - * - `SZ_PUBLIC` is used for functions that are part of the public API. - * - `SZ_INTERNAL` is used for internal helper functions with unstable APIs. - * - `SZ_DYNAMIC` is used for functions that are part of the public API, but are dispatched at runtime. + * The Hamming distance is rarely used in string processing, so only minimal compatibility is provided. + * The Levenshtein distance, however, is much more popular and computationally intensive. + * So a huge part of this file is focused on optimizing it for different input alphabet sizes and input lengths. */ -#ifndef SZ_DYNAMIC -#if SZ_DYNAMIC_DISPATCH -#if defined(_WIN32) || defined(__CYGWIN__) -#define SZ_DYNAMIC __declspec(dllexport) -#define SZ_EXTERNAL __declspec(dllimport) -#define SZ_PUBLIC inline static -#define SZ_INTERNAL inline static -#else -#define SZ_DYNAMIC __attribute__((visibility("default"))) -#define SZ_EXTERNAL extern -#define SZ_PUBLIC __attribute__((unused)) inline static -#define SZ_INTERNAL __attribute__((always_inline)) inline static -#endif // _WIN32 || __CYGWIN__ -#else -#define SZ_DYNAMIC inline static -#define SZ_EXTERNAL extern -#define SZ_PUBLIC inline static -#define SZ_INTERNAL inline static -#endif // SZ_DYNAMIC_DISPATCH -#endif // SZ_DYNAMIC +#ifndef STRINGZILLA_SIMILARITY_H_ +#define STRINGZILLA_SIMILARITY_H_ -/** - * @brief Alignment macro for 64-byte alignment. 
- */ -#if defined(_MSC_VER) -#define SZ_ALIGN64 __declspec(align(64)) -#elif defined(__GNUC__) || defined(__clang__) -#define SZ_ALIGN64 __attribute__((aligned(64))) -#else -#define SZ_ALIGN64 -#endif +#include "types.h" #ifdef __cplusplus extern "C" { #endif -/* - * Let's infer the integer types or pull them from LibC, - * if that is allowed by the user. - */ -#if !SZ_AVOID_LIBC -#include // `size_t` -#include // `uint8_t` -typedef int8_t sz_i8_t; // Always 8 bits -typedef uint8_t sz_u8_t; // Always 8 bits -typedef uint16_t sz_u16_t; // Always 16 bits -typedef int32_t sz_i32_t; // Always 32 bits -typedef uint32_t sz_u32_t; // Always 32 bits -typedef uint64_t sz_u64_t; // Always 64 bits -typedef int64_t sz_i64_t; // Always 64 bits -typedef size_t sz_size_t; // Pointer-sized unsigned integer, 32 or 64 bits -typedef ptrdiff_t sz_ssize_t; // Signed version of `sz_size_t`, 32 or 64 bits - -#else // if SZ_AVOID_LIBC: - -// ! The C standard doesn't specify the signedness of char. -// ! On x86 char is signed by default while on Arm it is unsigned by default. -// ! That's why we don't define `sz_char_t` and generally use explicit `sz_i8_t` and `sz_u8_t`. -typedef signed char sz_i8_t; // Always 8 bits -typedef unsigned char sz_u8_t; // Always 8 bits -typedef unsigned short sz_u16_t; // Always 16 bits -typedef int sz_i32_t; // Always 32 bits -typedef unsigned int sz_u32_t; // Always 32 bits -typedef long long sz_i64_t; // Always 64 bits -typedef unsigned long long sz_u64_t; // Always 64 bits - -// Now we need to redefine the `size_t`. -// Microsoft Visual C++ (MSVC) typically follows LLP64 data model on 64-bit platforms, -// where integers, pointers, and long types have different sizes: -// -// > `int` is 32 bits -// > `long` is 32 bits -// > `long long` is 64 bits -// > pointer (thus, `size_t`) is 64 bits -// -// In contrast, GCC and Clang on 64-bit Unix-like systems typically follow the LP64 model, where: -// -// > `int` is 32 bits -// > `long` and pointer (thus, `size_t`) are 64 bits -// > `long long` is also 64 bits -// -// Source: https://learn.microsoft.com/en-us/windows/win32/winprog64/abstract-data-models -#if SZ_DETECT_64_BIT -typedef unsigned long long sz_size_t; // 64-bit. -typedef long long sz_ssize_t; // 64-bit. -#else -typedef unsigned sz_size_t; // 32-bit. -typedef unsigned sz_ssize_t; // 32-bit. -#endif // SZ_DETECT_64_BIT - -#endif // SZ_AVOID_LIBC - -/** - * @brief Compile-time assert macro similar to `static_assert` in C++. - */ -#define sz_static_assert(condition, name) \ - typedef struct { \ - int static_assert_##name : (condition) ? 1 : -1; \ - } sz_static_assert_##name##_t - -sz_static_assert(sizeof(sz_size_t) == sizeof(void *), sz_size_t_must_be_pointer_size); -sz_static_assert(sizeof(sz_ssize_t) == sizeof(void *), sz_ssize_t_must_be_pointer_size); - -#pragma region Public API - -typedef char *sz_ptr_t; // A type alias for `char *` -typedef char const *sz_cptr_t; // A type alias for `char const *` -typedef sz_i8_t sz_error_cost_t; // Character mismatch cost for fuzzy matching functions - -typedef sz_u64_t sz_sorted_idx_t; // Index of a sorted string in a list of strings - -typedef enum { sz_false_k = 0, sz_true_k = 1 } sz_bool_t; // Only one relevant bit -typedef enum { sz_less_k = -1, sz_equal_k = 0, sz_greater_k = 1 } sz_ordering_t; // Only three possible states: <=> - -/** - * @brief Tiny string-view structure. It's POD type, unlike the `std::string_view`. 
- */ -typedef struct sz_string_view_t { - sz_cptr_t start; - sz_size_t length; -} sz_string_view_t; - -/** - * @brief Enumeration of SIMD capabilities of the target architecture. - * Used to introspect the supported functionality of the dynamic library. - */ -typedef enum sz_capability_t { - sz_cap_serial_k = 1, /// Serial (non-SIMD) capability - sz_cap_any_k = 0x7FFFFFFF, /// Mask representing any capability - - sz_cap_arm_neon_k = 1 << 10, /// ARM NEON capability - sz_cap_arm_sve_k = 1 << 11, /// ARM SVE capability TODO: Not yet supported or used - sz_cap_arm_sve2_k = 1 << 12, - sz_cap_arm_sve2p1_k = 1 << 13, - sz_cap_x86_avx2_k = 1 << 20, /// x86 AVX2 capability - sz_cap_x86_avx512f_k = 1 << 21, /// x86 AVX512 F capability - sz_cap_x86_avx512bw_k = 1 << 22, /// x86 AVX512 BW instruction capability - sz_cap_x86_avx512vl_k = 1 << 23, /// x86 AVX512 VL instruction capability - sz_cap_x86_avx512vbmi_k = 1 << 24, /// x86 AVX512 VBMI instruction capability - sz_cap_x86_gfni_k = 1 << 25, /// x86 AVX512 GFNI instruction capability - -} sz_capability_t; - -/** - * @brief Function to determine the SIMD capabilities of the current machine @b only at @b runtime. - * @return A bitmask of the SIMD capabilities represented as a `sz_capability_t` enum value. - */ -SZ_DYNAMIC sz_capability_t sz_capabilities(void); - -/** - * @brief Bit-set structure for 256 possible byte values. Useful for filtering and search. - * @see sz_charset_init, sz_charset_add, sz_charset_contains, sz_charset_invert - */ -typedef union sz_charset_t { - sz_u64_t _u64s[4]; - sz_u32_t _u32s[8]; - sz_u16_t _u16s[16]; - sz_u8_t _u8s[32]; -} sz_charset_t; - -/** @brief Initializes a bit-set to an empty collection, meaning - all characters are banned. */ -SZ_PUBLIC void sz_charset_init(sz_charset_t *s) { s->_u64s[0] = s->_u64s[1] = s->_u64s[2] = s->_u64s[3] = 0; } - -/** @brief Adds a character to the set and accepts @b unsigned integers. */ -SZ_PUBLIC void sz_charset_add_u8(sz_charset_t *s, sz_u8_t c) { s->_u64s[c >> 6] |= (1ull << (c & 63u)); } - -/** @brief Adds a character to the set. Consider @b sz_charset_add_u8. */ -SZ_PUBLIC void sz_charset_add(sz_charset_t *s, char c) { sz_charset_add_u8(s, *(sz_u8_t *)(&c)); } // bitcast - -/** @brief Checks if the set contains a given character and accepts @b unsigned integers. */ -SZ_PUBLIC sz_bool_t sz_charset_contains_u8(sz_charset_t const *s, sz_u8_t c) { - // Checking the bit can be done in different ways: - // - (s->_u64s[c >> 6] & (1ull << (c & 63u))) != 0 - // - (s->_u32s[c >> 5] & (1u << (c & 31u))) != 0 - // - (s->_u16s[c >> 4] & (1u << (c & 15u))) != 0 - // - (s->_u8s[c >> 3] & (1u << (c & 7u))) != 0 - return (sz_bool_t)((s->_u64s[c >> 6] & (1ull << (c & 63u))) != 0); -} - -/** @brief Checks if the set contains a given character. Consider @b sz_charset_contains_u8. */ -SZ_PUBLIC sz_bool_t sz_charset_contains(sz_charset_t const *s, char c) { - return sz_charset_contains_u8(s, *(sz_u8_t *)(&c)); // bitcast -} - -/** @brief Inverts the contents of the set, so allowed character get disallowed, and vice versa. */ -SZ_PUBLIC void sz_charset_invert(sz_charset_t *s) { - s->_u64s[0] ^= 0xFFFFFFFFFFFFFFFFull, s->_u64s[1] ^= 0xFFFFFFFFFFFFFFFFull, // - s->_u64s[2] ^= 0xFFFFFFFFFFFFFFFFull, s->_u64s[3] ^= 0xFFFFFFFFFFFFFFFFull; -} - -typedef void *(*sz_memory_allocate_t)(sz_size_t, void *); -typedef void (*sz_memory_free_t)(void *, sz_size_t, void *); -typedef sz_u64_t (*sz_random_generator_t)(void *); - -/** - * @brief Some complex pattern matching algorithms may require memory allocations. 
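/* The bit-set helpers above give O(1) membership tests over all 256 byte values.
 * A hedged usage sketch, with the delimiter set chosen purely for illustration: */
sz_cptr_t find_first_delimiter(sz_cptr_t text, sz_size_t length) {
    sz_charset_t delimiters;
    sz_charset_init(&delimiters);
    sz_charset_add(&delimiters, ',');
    sz_charset_add(&delimiters, ';');
    sz_charset_add(&delimiters, '\n');
    return sz_find_charset(text, length, &delimiters); /* null pointer if none present */
}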
- * This structure is used to pass the memory allocator to those functions. - * @see sz_memory_allocator_init_fixed - */ -typedef struct sz_memory_allocator_t { - sz_memory_allocate_t allocate; - sz_memory_free_t free; - void *handle; -} sz_memory_allocator_t; - -/** - * @brief Initializes a memory allocator to use the system default `malloc` and `free`. - * ! The function is not available if the library was compiled with `SZ_AVOID_LIBC`. - * - * @param alloc Memory allocator to initialize. - */ -SZ_PUBLIC void sz_memory_allocator_init_default(sz_memory_allocator_t *alloc); +#pragma region Core API /** - * @brief Initializes a memory allocator to use a static-capacity buffer. - * No dynamic allocations will be performed. + * @brief Computes the Hamming distance between two strings - number of not matching characters. + * Difference in length is is counted as a mismatch. * - * @param alloc Memory allocator to initialize. - * @param buffer Buffer to use for allocations. - * @param length Length of the buffer. @b Must be greater than 8 bytes. Different values would be optimal for - * different algorithms and input lengths, but 4096 bytes (one RAM page) is a good default. - */ -SZ_PUBLIC void sz_memory_allocator_init_fixed(sz_memory_allocator_t *alloc, void *buffer, sz_size_t length); - -/** - * @brief The number of bytes a stack-allocated string can hold, including the SZ_NULL termination character. - * ! This can't be changed from outside. Don't use the `#error` as it may already be included and set. - */ -#ifdef SZ_STRING_INTERNAL_SPACE -#undef SZ_STRING_INTERNAL_SPACE -#endif -#define SZ_STRING_INTERNAL_SPACE (sizeof(sz_size_t) * 3 - 1) // 3 pointers minus one byte for an 8-bit length - -/** - * @brief Tiny memory-owning string structure with a Small String Optimization (SSO). - * Differs in layout from Folly, Clang, GCC, and probably most other implementations. - * It's designed to avoid any branches on read-only operations, and can store up - * to 22 characters on stack on 64-bit machines, followed by the SZ_NULL-termination character. + * @param a First string to compare. + * @param a_length Number of bytes in the first string. + * @param b Second string to compare. + * @param b_length Number of bytes in the second string. * - * @section Changing Length + * @param bound Exclusive upper bound on the distance, that allows us to exit early. + * Pass `SZ_SIZE_MAX` or any value greater than `(max(a_length, b_length))` to ignore. + * Pass zero to check if the strings are equal. + * @return Returns an unsigned integer for the edit distance. Zero means the strings are equal. + * Returns the `(max(a_length, b_length)) + 1` if the distance limit was reached. * - * One nice thing about this design, is that you can, in many cases, change the length of the string - * without any branches, invoking a `+=` or `-=` on the 64-bit `length` field. If the string is on heap, - * the solution is obvious. If it's on stack, inplace decrement wouldn't affect the top bytes of the string, - * only changing the last byte containing the length. 
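/* The allocator callbacks above are what the similarity functions below thread
 * their temporary DP rows through. A hedged sketch of pairing a fixed stack
 * arena with `sz_edit_distance`; the 4096-byte size follows the "one RAM page"
 * recommendation and is otherwise arbitrary, and the function name is a placeholder: */
sz_size_t edit_distance_without_malloc(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) {
    char arena[4096];
    sz_memory_allocator_t alloc;
    sz_memory_allocator_init_fixed(&alloc, arena, sizeof(arena));
    return sz_edit_distance(a, a_length, b, b_length, SZ_SIZE_MAX, &alloc);
}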
+ * @see sz_hamming_distance_utf8 + * @see https://en.wikipedia.org/wiki/Hamming_distance */ -typedef union sz_string_t { - -#if !SZ_DETECT_BIG_ENDIAN - - struct external { - sz_ptr_t start; - sz_size_t length; - sz_size_t space; - sz_size_t padding; - } external; - - struct internal { - sz_ptr_t start; - sz_u8_t length; - char chars[SZ_STRING_INTERNAL_SPACE]; - } internal; - -#else - - struct external { - sz_ptr_t start; - sz_size_t space; - sz_size_t padding; - sz_size_t length; - } external; - - struct internal { - sz_ptr_t start; - char chars[SZ_STRING_INTERNAL_SPACE]; - sz_u8_t length; - } internal; - -#endif - - sz_size_t words[4]; - -} sz_string_t; - -typedef sz_u64_t (*sz_hash_t)(sz_cptr_t, sz_size_t); -typedef sz_u64_t (*sz_checksum_t)(sz_cptr_t, sz_size_t); -typedef sz_bool_t (*sz_equal_t)(sz_cptr_t, sz_cptr_t, sz_size_t); -typedef sz_ordering_t (*sz_order_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); -typedef void (*sz_to_converter_t)(sz_cptr_t, sz_size_t, sz_ptr_t); +SZ_DYNAMIC sz_size_t sz_hamming_distance( // + sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound); /** - * @brief Computes the 64-bit check-sum of bytes in a string. - * Similar to `std::ranges::accumulate`. + * @brief Computes the Hamming distance between two @b UTF8 strings - number of not matching characters. + * Difference in length is is counted as a mismatch. * - * @param text String to aggregate. - * @param length Number of bytes in the text. - * @return 64-bit unsigned value. - */ -SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length); - -/** @copydoc sz_checksum */ -SZ_PUBLIC sz_u64_t sz_checksum_serial(sz_cptr_t text, sz_size_t length); - -/** - * @brief Computes the 64-bit unsigned hash of a string. Fairly fast for short strings, - * simple implementation, and supports rolling computation, reused in other APIs. - * Similar to `std::hash` in C++. + * @param a First string to compare. + * @param a_length Number of bytes in the first string. + * @param b Second string to compare. + * @param b_length Number of bytes in the second string. * - * @param text String to hash. - * @param length Number of bytes in the text. - * @return 64-bit hash value. + * @param bound Exclusive upper bound on the distance, that allows us to exit early. + * Pass `SZ_SIZE_MAX` or any value greater than `(max(a_length, b_length))` to ignore. + * Pass zero to check if the strings are equal. + * @return Returns an unsigned integer for the edit distance. Zero means the strings are equal. + * Returns the `(max(a_length, b_length)) + 1` if the distance limit was reached. * - * @see sz_hashes, sz_hashes_fingerprint, sz_hashes_intersection + * @see sz_hamming_distance + * @see https://en.wikipedia.org/wiki/Hamming_distance */ -SZ_PUBLIC sz_u64_t sz_hash(sz_cptr_t text, sz_size_t length); - -/** @copydoc sz_hash */ -SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t text, sz_size_t length); +SZ_DYNAMIC sz_size_t sz_hamming_distance_utf8( // + sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound); /** - * @brief Checks if two string are equal. - * Similar to `memcmp(a, b, length) == 0` in LibC and `a == b` in STL. - * - * The implementation of this function is very similar to `sz_order`, but the usage patterns are different. - * This function is more often used in parsing, while `sz_order` is often used in sorting. - * It works best on platforms with cheap + * @brief Computes the Levenshtein edit-distance between two strings using the Wagner-Fisher algorithm. 
+ * Similar to the Needleman-Wunsch alignment algorithm. Often used in fuzzy string matching. * * @param a First string to compare. + * @param a_length Number of bytes in the first string. * @param b Second string to compare. - * @param length Number of bytes in both strings. - * @return 1 if strings match, 0 otherwise. + * @param b_length Number of bytes in the second string. + * + * @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated, + * so the memory usage is linear in relation to ::a_length and ::b_length. + * If SZ_NULL is passed, will initialize to the systems default `malloc`. + * + * @param bound Exclusive upper bound on the distance, that allows us to exit early. + * Pass `SZ_SIZE_MAX` or any value greater than `(max(a_length, b_length))` to ignore. + * Pass zero to check if the strings are equal. + * @return Returns an unsigned integer for the edit distance. Zero means the strings are equal. + * Returns the `(max(a_length, b_length)) + 1` if the distance limit was reached. + * Returns `SZ_SIZE_MAX` if the memory allocation failed. + * + * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default + * @see https://en.wikipedia.org/wiki/Levenshtein_distance */ -SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length); - -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_serial(sz_cptr_t a, sz_cptr_t b, sz_size_t length); +SZ_DYNAMIC sz_size_t sz_edit_distance( // + sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // + sz_size_t bound, sz_memory_allocator_t *alloc); /** - * @brief Estimates the relative order of two strings. Equivalent to `memcmp(a, b, length)` in LibC. - * Can be used on different length strings. + * @brief Computes the Levenshtein edit-distance between two @b UTF8 strings. + * Unlike `sz_edit_distance`, reports the distance in Unicode codepoints, and not in bytes. * * @param a First string to compare. * @param a_length Number of bytes in the first string. * @param b Second string to compare. * @param b_length Number of bytes in the second string. - * @return Negative if (a < b), positive if (a > b), zero if they are equal. - */ -SZ_DYNAMIC sz_ordering_t sz_order(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); - -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); - -/** - * @brief Look Up Table @b (LUT) transformation of a string. Equivalent to `for (char & c : text) c = lut[c]`. - * - * Can be used to implement some form of string normalization, partially masking punctuation marks, - * or converting between different character sets, like uppercase or lowercase. Surprisingly, also has - * broad implications in image processing, where image channel transformations are often done using LUTs. * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param lut Look Up Table to apply. Must be exactly @b 256 bytes long. - * @param result Output string, can point to the same address as ::text. - */ -SZ_DYNAMIC void sz_look_up_transform(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result); - -typedef void (*sz_look_up_transform_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_ptr_t); - -/** @copydoc sz_look_up_transform */ -SZ_PUBLIC void sz_look_up_transform_serial(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result); - -/** - * @brief Equivalent to `for (char & c : text) c = tolower(c)`. 
+ * @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated, + * so the memory usage is linear in relation to ::a_length and ::b_length. + * If SZ_NULL is passed, will initialize to the systems default `malloc`. * - * ASCII characters [A, Z] map to decimals [65, 90], and [a, z] map to [97, 122]. - * So there are 26 english letters, shifted by 32 values, meaning that a conversion - * can be done by flipping the 5th bit each inappropriate character byte. This, however, - * breaks for extended ASCII, so a different solution is needed. - * http://0x80.pl/notesen/2016-01-06-swar-swap-case.html + * @param bound Exclusive upper bound on the distance, that allows us to exit early. + * Pass `SZ_SIZE_MAX` or any value greater than `(max(a_length, b_length))` to ignore. + * Pass zero to check if the strings are equal. + * @return Returns an unsigned integer for the edit distance. Zero means the strings are equal. + * Returns the `(max(a_length, b_length)) + 1` if the distance limit was reached. + * Returns `SZ_SIZE_MAX` if the memory allocation failed. * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param result Output string, can point to the same address as ::text. + * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default, sz_edit_distance + * @see https://en.wikipedia.org/wiki/Levenshtein_distance */ -SZ_PUBLIC void sz_tolower(sz_cptr_t text, sz_size_t length, sz_ptr_t result); +SZ_DYNAMIC sz_size_t sz_edit_distance_utf8( // + sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // + sz_size_t bound, sz_memory_allocator_t *alloc); /** - * @brief Equivalent to `for (char & c : text) c = toupper(c)`. - * - * ASCII characters [A, Z] map to decimals [65, 90], and [a, z] map to [97, 122]. - * So there are 26 english letters, shifted by 32 values, meaning that a conversion - * can be done by flipping the 5th bit each inappropriate character byte. This, however, - * breaks for extended ASCII, so a different solution is needed. - * http://0x80.pl/notesen/2016-01-06-swar-swap-case.html + * @brief Computes Needleman–Wunsch alignment score for two string. Often used in bioinformatics and cheminformatics. + * Similar to the Levenshtein edit-distance, parameterized for gap and substitution penalties. * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param result Output string, can point to the same address as ::text. - */ -SZ_PUBLIC void sz_toupper(sz_cptr_t text, sz_size_t length, sz_ptr_t result); - -/** - * @brief Equivalent to `for (char & c : text) c = toascii(c)`. + * Not commutative in the general case, as the order of the strings matters, as `sz_alignment_score(a, b)` may + * not be equal to `sz_alignment_score(b, a)`. Becomes @b commutative, if the substitution costs are symmetric. + * Equivalent to the negative Levenshtein distance, if: `gap == -1` and `subs[i][j] == (i == j ? 0: -1)`. * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param result Output string, can point to the same address as ::text. - */ -SZ_PUBLIC void sz_toascii(sz_cptr_t text, sz_size_t length, sz_ptr_t result); - -/** - * @brief Checks if all characters in the range are valid ASCII characters. + * @param a First string to compare. + * @param a_length Number of bytes in the first string. + * @param b Second string to compare. + * @param b_length Number of bytes in the second string. 
+ * @param gap Penalty cost for gaps - insertions and removals. + * @param subs Substitution costs matrix with 256 x 256 values for all pairs of characters. * - * @param text String to be analyzed. - * @param length Number of bytes in the string. - * @return Whether all characters are valid ASCII characters. - */ -SZ_PUBLIC sz_bool_t sz_isascii(sz_cptr_t text, sz_size_t length); - -/** - * @brief Generates a random string for a given alphabet, avoiding integer division and modulo operations. - * Similar to `text[i] = alphabet[rand() % cardinality]`. + * @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated, + * so the memory usage is linear in relation to ::a_length and ::b_length. + * If SZ_NULL is passed, will initialize to the systems default `malloc`. * - * The modulo operation is expensive, and should be avoided in performance-critical code. - * We avoid it using small lookup tables and replacing it with a multiplication and shifts, similar to `libdivide`. - * Alternative algorithms would include: - * - Montgomery form: https://en.algorithmica.org/hpc/number-theory/montgomery/ - * - Barret reduction: https://www.nayuki.io/page/barrett-reduction-algorithm - * - Lemire's trick: https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/ + * @return Signed similarity score. Can be negative, depending on the substitution costs. + * Returns `SZ_SSIZE_MAX` if the memory allocation failed. * - * @param alphabet Set of characters to sample from. - * @param cardinality Number of characters to sample from. - * @param text Output string, can point to the same address as ::text. - * @param generate Callback producing random numbers given the generator state. - * @param generator Generator state, can be a pointer to a seed, or a pointer to a random number generator. + * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default + * @see https://en.wikipedia.org/wiki/Needleman%E2%80%93Wunsch_algorithm */ -SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length, - sz_random_generator_t generate, void *generator); +SZ_DYNAMIC sz_ssize_t sz_alignment_score( // + sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // + sz_error_cost_t const *subs, sz_error_cost_t gap, // + sz_memory_allocator_t *alloc); -/** @copydoc sz_generate */ -SZ_PUBLIC void sz_generate_serial(sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length, - sz_random_generator_t generate, void *generator); +/** @copydoc sz_hamming_distance */ +SZ_PUBLIC sz_size_t sz_hamming_distance_serial( // + sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound); -/** - * @brief Similar to `memcpy`, copies contents of one string into another. - * The behavior is undefined if the strings overlap. - * - * @param target String to copy into. - * @param length Number of bytes to copy. - * @param source String to copy from. 
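/* A hedged usage sketch for the scoring APIs declared here: `SZ_NULL` as the
 * allocator falls back to the default malloc-based one, and `SZ_SIZE_MAX`
 * disables the early-exit bound. The substitution-costs pointer and the
 * function name are placeholders, not part of the library: */
void score_pair(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length,
                sz_error_cost_t const *substitution_costs, sz_error_cost_t gap_cost) {
    sz_size_t distance = sz_edit_distance(a, a_length, b, b_length, SZ_SIZE_MAX, SZ_NULL);
    sz_ssize_t score = sz_alignment_score(a, a_length, b, b_length, substitution_costs, gap_cost, SZ_NULL);
    sz_unused(distance && score);
}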
- */ -SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length); +/** @copydoc sz_hamming_distance_utf8 */ +SZ_PUBLIC sz_size_t sz_hamming_distance_utf8_serial( // + sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound); -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length); +/** @copydoc sz_edit_distance */ +SZ_PUBLIC sz_size_t sz_edit_distance_serial( // + sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // + sz_size_t bound, sz_memory_allocator_t *alloc); -/** - * @brief Similar to `memmove`, copies (moves) contents of one string into another. - * Unlike `sz_copy`, allows overlapping strings as arguments. - * - * @param target String to copy into. - * @param length Number of bytes to copy. - * @param source String to copy from. - */ -SZ_DYNAMIC void sz_move(sz_ptr_t target, sz_cptr_t source, sz_size_t length); +/** @copydoc sz_edit_distance_utf8 */ +SZ_PUBLIC sz_size_t sz_edit_distance_utf8_serial( // + sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // + sz_size_t bound, sz_memory_allocator_t *alloc); -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length); +/** @copydoc sz_alignment_score */ +SZ_PUBLIC sz_ssize_t sz_alignment_score_serial( // + sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // + sz_error_cost_t const *subs, sz_error_cost_t gap, // + sz_memory_allocator_t *alloc); -typedef void (*sz_move_t)(sz_ptr_t, sz_cptr_t, sz_size_t); +#pragma endregion // Core API -/** - * @brief Similar to `memset`, fills a string with a given value. - * - * @param target String to fill. - * @param length Number of bytes to fill. - * @param value Value to fill with. - */ -SZ_DYNAMIC void sz_fill(sz_ptr_t target, sz_size_t length, sz_u8_t value); +#pragma region Serial Implementation -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_serial(sz_ptr_t target, sz_size_t length, sz_u8_t value); +SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_serial( // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_cptr_t longer, sz_size_t longer_length, // + sz_size_t bound, sz_memory_allocator_t *alloc) { -typedef void (*sz_fill_t)(sz_ptr_t, sz_size_t, sz_u8_t); + // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. + sz_memory_allocator_t global_alloc; + if (!alloc) { + sz_memory_allocator_init_default(&global_alloc); + alloc = &global_alloc; + } -/** - * @brief Initializes a string class instance to an empty value. - */ -SZ_PUBLIC void sz_string_init(sz_string_t *string); + // TODO: Generalize to remove the following asserts! + sz_assert(!bound && "For bounded search the method should only evaluate one band of the matrix."); + sz_assert(shorter_length == longer_length && "The method hasn't been generalized to different length inputs yet."); + sz_unused(longer_length && bound); -/** - * @brief Convenience function checking if the provided string is stored inside of the ::string instance itself, - * alternative being - allocated in a remote region of the heap. - */ -SZ_PUBLIC sz_bool_t sz_string_is_on_stack(sz_string_t const *string); + // We are going to store 3 diagonals of the matrix. + // The length of the longest (main) diagonal would be `n = (shorter_length + 1)`. 
+ sz_size_t n = shorter_length + 1; + sz_size_t buffer_length = sizeof(sz_size_t) * n * 3; + sz_size_t *distances = (sz_size_t *)alloc->allocate(buffer_length, alloc->handle); + if (!distances) return SZ_SIZE_MAX; -/** - * @brief Unpacks the opaque instance of a string class into its components. - * Recommended to use only in read-only operations. - * - * @param string String to unpack. - * @param start Pointer to the start of the string. - * @param length Number of bytes in the string, before the SZ_NULL character. - * @param space Number of bytes allocated for the string (heap or stack), including the SZ_NULL character. - * @param is_external Whether the string is allocated on the heap externally, or fits withing ::string instance. - */ -SZ_PUBLIC void sz_string_unpack(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length, sz_size_t *space, - sz_bool_t *is_external); + sz_size_t *previous_distances = distances; + sz_size_t *current_distances = previous_distances + n; + sz_size_t *next_distances = previous_distances + n * 2; -/** - * @brief Unpacks only the start and length of the string. - * Recommended to use only in read-only operations. - * - * @param string String to unpack. - * @param start Pointer to the start of the string. - * @param length Number of bytes in the string, before the SZ_NULL character. - */ -SZ_PUBLIC void sz_string_range(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length); + // Initialize the first two diagonals: + previous_distances[0] = 0; + current_distances[0] = current_distances[1] = 1; -/** - * @brief Constructs a string of a given ::length with noisy contents. - * Use the returned character pointer to populate the string. - * - * @param string String to initialize. - * @param length Number of bytes in the string, before the SZ_NULL character. - * @param allocator Memory allocator to use for the allocation. - * @return SZ_NULL if the operation failed, pointer to the start of the string otherwise. - */ -SZ_PUBLIC sz_ptr_t sz_string_init_length(sz_string_t *string, sz_size_t length, sz_memory_allocator_t *allocator); - -/** - * @brief Doesn't change the contents or the length of the string, but grows the available memory capacity. - * This is beneficial, if several insertions are expected, and we want to minimize allocations. - * - * @param string String to grow. - * @param new_capacity The number of characters to reserve space for, including existing ones. - * @param allocator Memory allocator to use for the allocation. - * @return SZ_NULL if the operation failed, pointer to the new start of the string otherwise. - */ -SZ_PUBLIC sz_ptr_t sz_string_reserve(sz_string_t *string, sz_size_t new_capacity, sz_memory_allocator_t *allocator); - -/** - * @brief Grows the string by adding an uninitialized region of ::added_length at the given ::offset. - * Would often be used in conjunction with one or more `sz_copy` calls to populate the allocated region. - * Similar to `sz_string_reserve`, but changes the length of the ::string. - * - * @param string String to grow. - * @param offset Offset of the first byte to reserve space for. - * If provided offset is larger than the length, it will be capped. - * @param added_length The number of new characters to reserve space for. - * @param allocator Memory allocator to use for the allocation. - * @return SZ_NULL if the operation failed, pointer to the new start of the string otherwise. 
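/* For reference, the skewed-diagonal kernel above computes the same quantity as
 * the textbook Wagner-Fischer recurrence. A minimal two-row sketch with unit
 * costs and caller-provided scratch rows (each `b_length + 1` elements), shown
 * only to make the recurrence explicit, not as the library's implementation: */
static sz_size_t levenshtein_two_rows(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length,
                                      sz_size_t *previous, sz_size_t *current) {
    for (sz_size_t j = 0; j <= b_length; ++j) previous[j] = j;
    for (sz_size_t i = 1; i <= a_length; ++i) {
        current[0] = i;
        for (sz_size_t j = 1; j <= b_length; ++j) {
            sz_size_t substitution = previous[j - 1] + (a[i - 1] != b[j - 1]);
            sz_size_t insertion = current[j - 1] + 1;
            sz_size_t deletion = previous[j] + 1;
            sz_size_t best = substitution < insertion ? substitution : insertion;
            current[j] = best < deletion ? best : deletion;
        }
        sz_size_t *swap = previous;
        previous = current;
        current = swap;
    }
    return previous[b_length]; /* after the final swap the last row lives in `previous` */
}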
- */ -SZ_PUBLIC sz_ptr_t sz_string_expand(sz_string_t *string, sz_size_t offset, sz_size_t added_length, - sz_memory_allocator_t *allocator); - -/** - * @brief Removes a range from a string. Changes the length, but not the capacity. - * Performs no allocations or deallocations and can't fail. - * - * @param string String to clean. - * @param offset Offset of the first byte to remove. - * @param length Number of bytes to remove. Out-of-bound ranges will be capped. - * @return Number of bytes removed. - */ -SZ_PUBLIC sz_size_t sz_string_erase(sz_string_t *string, sz_size_t offset, sz_size_t length); - -/** - * @brief Shrinks the string to fit the current length, if it's allocated on the heap. - * It's the reverse operation of ::sz_string_reserve. - * - * @param string String to shrink. - * @param allocator Memory allocator to use for the allocation. - * @return Whether the operation was successful. The only failures can come from the allocator. - * On failure, the string will remain unchanged. - */ -SZ_PUBLIC sz_ptr_t sz_string_shrink_to_fit(sz_string_t *string, sz_memory_allocator_t *allocator); - -/** - * @brief Frees the string, if it's allocated on the heap. - * If the string is on the stack, the function clears/resets the state. - */ -SZ_PUBLIC void sz_string_free(sz_string_t *string, sz_memory_allocator_t *allocator); - -#pragma endregion - -#pragma region Fast Substring Search API - -typedef sz_cptr_t (*sz_find_byte_t)(sz_cptr_t, sz_size_t, sz_cptr_t); -typedef sz_cptr_t (*sz_find_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); -typedef sz_cptr_t (*sz_find_set_t)(sz_cptr_t, sz_size_t, sz_charset_t const *); - -/** - * @brief Locates first matching byte in a string. Equivalent to `memchr(haystack, *needle, h_length)` in LibC. - * - * X86_64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/x86_64/memchr.S - * Aarch64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/aarch64/memchr.S - * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - single-byte substring to find. - * @return Address of the first match. - */ -SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); - -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); - -/** - * @brief Locates last matching byte in a string. Equivalent to `memrchr(haystack, *needle, h_length)` in LibC. - * - * X86_64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/x86_64/memrchr.S - * Aarch64 implementation: missing - * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - single-byte substring to find. - * @return Address of the last match. - */ -SZ_DYNAMIC sz_cptr_t sz_rfind_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); - -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); - -/** - * @brief Locates first matching substring. - * Equivalent to `memmem(haystack, h_length, needle, n_length)` in LibC. - * Similar to `strstr(haystack, needle)` in LibC, but requires known length. - * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - substring to find. - * @param n_length Number of bytes in the needle. - * @return Address of the first match. 
- */ -SZ_DYNAMIC sz_cptr_t sz_find(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); - -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); - -/** - * @brief Locates the last matching substring. - * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - substring to find. - * @param n_length Number of bytes in the needle. - * @return Address of the last match. - */ -SZ_DYNAMIC sz_cptr_t sz_rfind(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); - -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); - -/** - * @brief Finds the first character present from the ::set, present in ::text. - * Equivalent to `strspn(text, accepted)` and `strcspn(text, rejected)` in LibC. - * May have identical implementation and performance to ::sz_rfind_charset. - * - * Useful for parsing, when we want to skip a set of characters. Examples: - * * 6 whitespaces: " \t\n\r\v\f". - * * 16 digits forming a float number: "0123456789,.eE+-". - * * 5 HTML reserved characters: "\"'&<>", of which "<>" can be useful for parsing. - * * 2 JSON string special characters useful to locate the end of the string: "\"\\". - * - * @param text String to be scanned. - * @param set Set of relevant characters. - * @return Pointer to the first matching character from ::set. - */ -SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); - -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); - -/** - * @brief Finds the last character present from the ::set, present in ::text. - * Equivalent to `strspn(text, accepted)` and `strcspn(text, rejected)` in LibC. - * May have identical implementation and performance to ::sz_find_charset. - * - * Useful for parsing, when we want to skip a set of characters. Examples: - * * 6 whitespaces: " \t\n\r\v\f". - * * 16 digits forming a float number: "0123456789,.eE+-". - * * 5 HTML reserved characters: "\"'&<>", of which "<>" can be useful for parsing. - * * 2 JSON string special characters useful to locate the end of the string: "\"\\". - * - * @param text String to be scanned. - * @param set Set of relevant characters. - * @return Pointer to the last matching character from ::set. - */ -SZ_DYNAMIC sz_cptr_t sz_rfind_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); - -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); - -#pragma endregion - -#pragma region String Similarity Measures API - -/** - * @brief Computes the Hamming distance between two strings - number of not matching characters. - * Difference in length is is counted as a mismatch. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * - * @param bound Upper bound on the distance, that allows us to exit early. - * If zero is passed, the maximum possible distance will be equal to the length of the longer input. - * @return Unsigned integer for the distance, the `bound` if was exceeded. 
- * - * @see sz_hamming_distance_utf8 - * @see https://en.wikipedia.org/wiki/Hamming_distance - */ -SZ_DYNAMIC sz_size_t sz_hamming_distance( // - sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound); - -/** @copydoc sz_hamming_distance */ -SZ_PUBLIC sz_size_t sz_hamming_distance_serial( // - sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound); - -/** - * @brief Computes the Hamming distance between two @b UTF8 strings - number of not matching characters. - * Difference in length is is counted as a mismatch. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * - * @param bound Upper bound on the distance, that allows us to exit early. - * If zero is passed, the maximum possible distance will be equal to the length of the longer input. - * @return Unsigned integer for the distance, the `bound` if was exceeded. - * - * @see sz_hamming_distance - * @see https://en.wikipedia.org/wiki/Hamming_distance - */ -SZ_DYNAMIC sz_size_t sz_hamming_distance_utf8(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, - sz_size_t bound); - -/** @copydoc sz_hamming_distance_utf8 */ -SZ_PUBLIC sz_size_t sz_hamming_distance_utf8_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, - sz_size_t bound); - -typedef sz_size_t (*sz_hamming_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t); - -/** - * @brief Computes the Levenshtein edit-distance between two strings using the Wagner-Fisher algorithm. - * Similar to the Needleman-Wunsch alignment algorithm. Often used in fuzzy string matching. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * - * @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated, - * so the memory usage is linear in relation to ::a_length and ::b_length. - * If SZ_NULL is passed, will initialize to the systems default `malloc`. - * @param bound Exclusive upper bound on the distance, that allows us to exit early. - * Pass `SZ_SIZE_MAX` or any value greater than `(max(a_length, b_length))` to ignore. - * Pass zero to check if the strings are equal. - * @return Unsigned integer for the edit distance. Zero means the strings are equal. - * Returns the `bound` if it was exceeded or `SZ_SIZE_MAX` if the memory allocation failed. - * - * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default - * @see https://en.wikipedia.org/wiki/Levenshtein_distance - */ -SZ_DYNAMIC sz_size_t sz_edit_distance(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); - -/** @copydoc sz_edit_distance */ -SZ_PUBLIC sz_size_t sz_edit_distance_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); - -/** - * @brief Computes the Levenshtein edit-distance between two @b UTF8 strings. - * Unlike `sz_edit_distance`, reports the distance in Unicode codepoints, and not in bytes. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * - * @param alloc Temporary memory allocator. 
Only some of the rows of the matrix will be allocated, - * so the memory usage is linear in relation to ::a_length and ::b_length. - * If SZ_NULL is passed, will initialize to the systems default `malloc`. - * @param bound Upper bound on the distance, that allows us to exit early. - * If zero is passed, the maximum possible distance will be equal to the length of the longer input. - * @return Unsigned integer for edit distance, the `bound` if was exceeded or `SZ_SIZE_MAX` - * if the memory allocation failed. - * - * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default, sz_edit_distance - * @see https://en.wikipedia.org/wiki/Levenshtein_distance - */ -SZ_DYNAMIC sz_size_t sz_edit_distance_utf8(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); - -typedef sz_size_t (*sz_edit_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t, sz_memory_allocator_t *); - -/** @copydoc sz_edit_distance_utf8 */ -SZ_PUBLIC sz_size_t sz_edit_distance_utf8_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); - -/** - * @brief Computes Needleman–Wunsch alignment score for two string. Often used in bioinformatics and cheminformatics. - * Similar to the Levenshtein edit-distance, parameterized for gap and substitution penalties. - * - * Not commutative in the general case, as the order of the strings matters, as `sz_alignment_score(a, b)` may - * not be equal to `sz_alignment_score(b, a)`. Becomes @b commutative, if the substitution costs are symmetric. - * Equivalent to the negative Levenshtein distance, if: `gap == -1` and `subs[i][j] == (i == j ? 0: -1)`. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * @param gap Penalty cost for gaps - insertions and removals. - * @param subs Substitution costs matrix with 256 x 256 values for all pairs of characters. - * - * @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated, - * so the memory usage is linear in relation to ::a_length and ::b_length. - * If SZ_NULL is passed, will initialize to the systems default `malloc`. - * @return Signed similarity score. Can be negative, depending on the substitution costs. - * If the memory allocation fails, the function returns `SZ_SSIZE_MAX`. - * - * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default - * @see https://en.wikipedia.org/wiki/Needleman%E2%80%93Wunsch_algorithm - */ -SZ_DYNAMIC sz_ssize_t sz_alignment_score(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, // - sz_memory_allocator_t *alloc); - -/** @copydoc sz_alignment_score */ -SZ_PUBLIC sz_ssize_t sz_alignment_score_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, // - sz_memory_allocator_t *alloc); - -typedef sz_ssize_t (*sz_alignment_score_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_error_cost_t const *, - sz_error_cost_t, sz_memory_allocator_t *); - -typedef void (*sz_hash_callback_t)(sz_cptr_t, sz_size_t, sz_u64_t, void *user); - -/** - * @brief Computes the Karp-Rabin rolling hashes of a string supplying them to the provided `callback`. - * Can be used for similarity scores, search, ranking, etc. 
- *
- * Rabin-Karp-like rolling hashes can have a very high level of collisions, depending
- * on the choice of bases and the prime number. That's why, often, two hashes from the same
- * family are used with different bases.
- *
- * 1. Kernighan and Ritchie's function uses 31, a prime close to the size of the English alphabet.
- * 2. To be friendlier to byte-arrays and UTF8, we use 257 for the second function.
- *
- * Choosing the right ::window_length is task- and domain-dependent. For example, most English words are
- * between 3 and 7 characters long, so a window of 4 bytes would be a good choice. For DNA sequences,
- * the ::window_length might be a multiple of 3, as each codon is 3 bytes (nucleotides) long.
- * With such minimalistic alphabets of just four characters (AGCT), longer windows might be needed.
- * For protein sequences the alphabet is 20 characters long, so the window can be shorter than for DNA.
- *
- * @param text String to hash.
- * @param length Number of bytes in the string.
- * @param window_length Length of the rolling window in bytes.
- * @param window_step Step of reported hashes. @b Must be a power of two. Should be smaller than `window_length`.
- * @param callback Function receiving the start & length of a substring, the hash, and the `callback_handle`.
- * @param callback_handle Optional user-provided pointer to be passed to the `callback`.
- * @see sz_hashes_fingerprint, sz_hashes_intersection
- */
-SZ_DYNAMIC void sz_hashes(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, //
-                          sz_hash_callback_t callback, void *callback_handle);
-
-/** @copydoc sz_hashes */
-SZ_PUBLIC void sz_hashes_serial(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, //
-                                sz_hash_callback_t callback, void *callback_handle);
-
-typedef void (*sz_hashes_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_size_t, sz_hash_callback_t, void *);
-
-/**
- * @brief Computes the Karp-Rabin rolling hashes of a string, outputting a binary fingerprint.
- * Such fingerprints can be compared with Hamming or Jaccard (Tanimoto) distance for similarity.
- *
- * The algorithm doesn't clear the fingerprint buffer on start, so it can be invoked multiple times
- * to produce a fingerprint of a longer string, by passing the previous fingerprint as the ::fingerprint.
- * It can also be reused to produce multi-resolution fingerprints by changing the ::window_length
- * and calling the same function multiple times for the same input ::text.
- *
- * Processes large strings in parts to maximize the cache utilization, using a small on-stack buffer,
- * avoiding cache-coherency penalties of remote on-heap buffers.
- *
- * @param text String to hash.
- * @param length Number of bytes in the string.
- * @param fingerprint Output fingerprint buffer.
- * @param fingerprint_bytes Number of bytes in the fingerprint buffer.
- * @param window_length Length of the rolling window in bytes.
- * @see sz_hashes, sz_hashes_intersection
- */
-SZ_PUBLIC void sz_hashes_fingerprint( //
-    sz_cptr_t text, sz_size_t length, sz_size_t window_length, //
-    sz_ptr_t fingerprint, sz_size_t fingerprint_bytes);
-
-typedef void (*sz_hashes_fingerprint_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_ptr_t, sz_size_t);
-
-/**
- * @brief Given a hash-fingerprint of a textual document, computes the number of intersecting hashes
- * of the incoming document. Can be used for document scoring and search.
- * - * Processes large strings in parts to maximize the cache utilization, using a small on-stack buffer, - * avoiding cache-coherency penalties of remote on-heap buffers. - * - * @param text Input document. - * @param length Number of bytes in the input document. - * @param fingerprint Reference document fingerprint. - * @param fingerprint_bytes Number of bytes in the reference documents fingerprint. - * @param window_length Length of the rolling window in bytes. - * @see sz_hashes, sz_hashes_fingerprint - */ -SZ_PUBLIC sz_size_t sz_hashes_intersection( // - sz_cptr_t text, sz_size_t length, sz_size_t window_length, // - sz_cptr_t fingerprint, sz_size_t fingerprint_bytes); - -typedef sz_size_t (*sz_hashes_intersection_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_cptr_t, sz_size_t); - -#pragma endregion - -#pragma region Convenience API - -/** - * @brief Finds the first character in the haystack, that is present in the needle. - * Convenience function, reused across different language bindings. - * @see sz_find_charset - */ -SZ_DYNAMIC sz_cptr_t sz_find_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); - -/** - * @brief Finds the first character in the haystack, that is @b not present in the needle. - * Convenience function, reused across different language bindings. - * @see sz_find_charset - */ -SZ_DYNAMIC sz_cptr_t sz_find_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); - -/** - * @brief Finds the last character in the haystack, that is present in the needle. - * Convenience function, reused across different language bindings. - * @see sz_find_charset - */ -SZ_DYNAMIC sz_cptr_t sz_rfind_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); - -/** - * @brief Finds the last character in the haystack, that is @b not present in the needle. - * Convenience function, reused across different language bindings. - * @see sz_find_charset - */ -SZ_DYNAMIC sz_cptr_t sz_rfind_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); - -#pragma endregion - -#pragma region String Sequences API - -struct sz_sequence_t; - -typedef sz_cptr_t (*sz_sequence_member_start_t)(struct sz_sequence_t const *, sz_size_t); -typedef sz_size_t (*sz_sequence_member_length_t)(struct sz_sequence_t const *, sz_size_t); -typedef sz_bool_t (*sz_sequence_predicate_t)(struct sz_sequence_t const *, sz_size_t); -typedef sz_bool_t (*sz_sequence_comparator_t)(struct sz_sequence_t const *, sz_size_t, sz_size_t); -typedef sz_bool_t (*sz_string_is_less_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); - -typedef struct sz_sequence_t { - sz_sorted_idx_t *order; - sz_size_t count; - sz_sequence_member_start_t get_start; - sz_sequence_member_length_t get_length; - void const *handle; -} sz_sequence_t; - -/** - * @brief Initiates the sequence structure from a tape layout, used by Apache Arrow. - * Expects ::offsets to contains `count + 1` entries, the last pointing at the end - * of the last string, indicating the total length of the ::tape. - */ -SZ_PUBLIC void sz_sequence_from_u32tape(sz_cptr_t *start, sz_u32_t const *offsets, sz_size_t count, - sz_sequence_t *sequence); - -/** - * @brief Initiates the sequence structure from a tape layout, used by Apache Arrow. - * Expects ::offsets to contains `count + 1` entries, the last pointing at the end - * of the last string, indicating the total length of the ::tape. 
- */
-SZ_PUBLIC void sz_sequence_from_u64tape(sz_cptr_t *start, sz_u64_t const *offsets, sz_size_t count,
-                                        sz_sequence_t *sequence);
-
-/**
- * @brief Similar to `std::partition`, given a predicate, splits the sequence into two parts.
- * The algorithm is unstable, meaning that elements may change relative order, as long
- * as they are in the right partition. This is the simpler algorithm for partitioning.
- */
-SZ_PUBLIC sz_size_t sz_partition(sz_sequence_t *sequence, sz_sequence_predicate_t predicate);
-
-/**
- * @brief In-place `std::set_union` for two consecutive chunks forming the same continuous `sequence`.
- *
- * @param partition The number of elements in the first sub-sequence in `sequence`.
- * @param less Comparison function, to determine the lexicographic ordering.
- */
-SZ_PUBLIC void sz_merge(sz_sequence_t *sequence, sz_size_t partition, sz_sequence_comparator_t less);
-
-/**
- * @brief Sorting algorithm, combining Radix Sort for the first 32 bits of every word,
- * followed by a more conventional sorting procedure on equally prefixed parts.
- */
-SZ_PUBLIC void sz_sort(sz_sequence_t *sequence);
-
-/**
- * @brief Partial sorting algorithm, combining Radix Sort for the first 32 bits of every word,
- * followed by a more conventional sorting procedure on equally prefixed parts.
- */
-SZ_PUBLIC void sz_sort_partial(sz_sequence_t *sequence, sz_size_t n);
-
-/**
- * @brief Intro-Sort algorithm that supports custom comparators.
- */
-SZ_PUBLIC void sz_sort_intro(sz_sequence_t *sequence, sz_sequence_comparator_t less);
-
-#pragma endregion
-
-/*
- * Hardware feature detection.
- * All of those can be controlled by the user.
- */
-#ifndef SZ_USE_X86_AVX512
-#ifdef __AVX512BW__
-#define SZ_USE_X86_AVX512 1
-#else
-#define SZ_USE_X86_AVX512 0
-#endif
-#endif
-
-#ifndef SZ_USE_X86_AVX2
-#ifdef __AVX2__
-#define SZ_USE_X86_AVX2 1
-#else
-#define SZ_USE_X86_AVX2 0
-#endif
-#endif
-
-#ifndef SZ_USE_ARM_NEON
-#ifdef __ARM_NEON
-#define SZ_USE_ARM_NEON 1
-#else
-#define SZ_USE_ARM_NEON 0
-#endif
-#endif
-
-#ifndef SZ_USE_ARM_SVE
-#ifdef __ARM_FEATURE_SVE
-#define SZ_USE_ARM_SVE 1
-#else
-#define SZ_USE_ARM_SVE 0
-#endif
-#endif
-
-/*
- * Include hardware-specific headers.
- */
-#if SZ_USE_X86_AVX512 || SZ_USE_X86_AVX2
-#include <immintrin.h>
-#endif // SZ_USE_X86...
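Since the sequence interface above is callback-driven, a short usage sketch may help. The following standalone example is not part of the patch: it assumes the single-header `<stringzilla/stringzilla.h>` include path and that the caller pre-fills `order` with the identity permutation before `sz_sort` reorders it.

#include <stdio.h>
#include <string.h>
#include <stringzilla/stringzilla.h>

/* Callbacks exposing a plain array of NUL-terminated C strings to the sequence API. */
static sz_cptr_t demo_get_start(sz_sequence_t const *sequence, sz_size_t i) {
    return ((char const *const *)sequence->handle)[i];
}
static sz_size_t demo_get_length(sz_sequence_t const *sequence, sz_size_t i) {
    return strlen(((char const *const *)sequence->handle)[i]);
}

int main(void) {
    char const *words[] = {"banana", "apple", "cherry", "apricot"};
    sz_sorted_idx_t order[4] = {0, 1, 2, 3}; // identity permutation, assumed to be caller-initialized
    sz_sequence_t sequence;
    sequence.order = order;
    sequence.count = 4;
    sequence.get_start = demo_get_start;
    sequence.get_length = demo_get_length;
    sequence.handle = words;
    sz_sort(&sequence); // reorders `order`; the `words` array itself is left untouched
    for (sz_size_t i = 0; i != sequence.count; ++i) printf("%s\n", words[order[i]]);
    return 0;
}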
-#if SZ_USE_ARM_NEON -#if !defined(_MSC_VER) -#include -#endif -#include -#endif // SZ_USE_ARM_NEON -#if SZ_USE_ARM_SVE -#if !defined(_MSC_VER) -#include -#endif -#endif // SZ_USE_ARM_SVE - -#pragma region Hardware Specific API - -#if SZ_USE_X86_AVX512 - -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_avx512(sz_cptr_t a, sz_cptr_t b, sz_size_t length); -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_avx512(sz_ptr_t target, sz_size_t length, sz_u8_t value); -/** @copydoc sz_look_up_transform */ -SZ_PUBLIC void sz_look_up_transform_avx512(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target); -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_edit_distance */ -SZ_PUBLIC sz_size_t sz_edit_distance_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); -/** @copydoc sz_alignment_score */ -SZ_PUBLIC sz_ssize_t sz_alignment_score_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, // - sz_memory_allocator_t *alloc); -/** @copydoc sz_hashes */ -SZ_PUBLIC void sz_hashes_avx512(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle); -#endif - -#if SZ_USE_X86_AVX2 -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_avx2(sz_cptr_t a, sz_cptr_t b, sz_size_t length); -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_avx2(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_avx2(sz_ptr_t target, sz_size_t length, sz_u8_t value); -/** @copydoc sz_look_up_transform */ -SZ_PUBLIC void sz_look_up_transform_avx2(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target); -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t 
n_length); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_hashes */ -SZ_PUBLIC void sz_hashes_avx2(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle); -#endif - -#if SZ_USE_ARM_NEON -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length); -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_neon(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_neon(sz_ptr_t target, sz_size_t length, sz_u8_t value); -/** @copydoc sz_look_up_transform */ -SZ_PUBLIC void sz_look_up_transform_neon(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target); -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_neon(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -#endif - -#if SZ_USE_ARM_SVE -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_sve(sz_cptr_t a, sz_cptr_t b, sz_size_t length); -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_sve(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_sve(sz_ptr_t target, sz_size_t length, sz_u8_t value); -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_sve(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_sve(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -#endif - -#pragma endregion - -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wconversion" - -/* - ********************************************************************************************************************** - 
********************************************************************************************************************** - ********************************************************************************************************************** - * - * This is where we the actual implementation begins. - * The rest of the file is hidden from the public API. - * - ********************************************************************************************************************** - ********************************************************************************************************************** - ********************************************************************************************************************** - */ - -#pragma region Compiler Extensions and Helper Functions - -#pragma GCC visibility push(hidden) - -/** - * @brief Helper-macro to mark potentially unused variables. - */ -#define sz_unused(x) ((void)(x)) - -/** - * @brief Helper-macro casting a variable to another type of the same size. - */ -#define sz_bitcast(type, value) (*((type *)&(value))) - -/** - * @brief Defines `SZ_NULL`, analogous to `NULL`. - * The default often comes from locale.h, stddef.h, - * stdio.h, stdlib.h, string.h, time.h, or wchar.h. - */ -#ifdef __GNUG__ -#define SZ_NULL __null -#define SZ_NULL_CHAR __null -#else -#define SZ_NULL ((void *)0) -#define SZ_NULL_CHAR ((char *)0) -#endif - -/** - * @brief Cache-line width, that will affect the execution of some algorithms, - * like equality checks and relative order computing. - */ -#define SZ_CACHE_LINE_WIDTH (64) // bytes - -/** - * @brief Similar to `assert`, the `sz_assert` is used in the SZ_DEBUG mode - * to check the invariants of the library. It's a no-op in the SZ_RELEASE mode. - * @note If you want to catch it, put a breakpoint at @b `__GI_exit` - */ -#if SZ_DEBUG && defined(SZ_AVOID_LIBC) && !SZ_AVOID_LIBC && !defined(SZ_PIC) -#include // `fprintf` -#include // `EXIT_FAILURE` -SZ_PUBLIC void _sz_assert_failure(char const *condition, char const *file, int line) { - fprintf(stderr, "Assertion failed: %s, in file %s, line %d\n", condition, file, line); - exit(EXIT_FAILURE); -} -#define sz_assert(condition) \ - do { \ - if (!(condition)) { _sz_assert_failure(#condition, __FILE__, __LINE__); } \ - } while (0) -#else -#define sz_assert(condition) ((void)(condition)) -#endif - -/* Intrinsics aliases for MSVC, GCC, Clang, and Clang-Cl. - * The following section of compiler intrinsics comes in 2 flavors. - */ -#if defined(_MSC_VER) && !defined(__clang__) // On Clang-CL -#include - -// Sadly, when building Win32 images, we can't use the `_tzcnt_u64`, `_lzcnt_u64`, -// `_BitScanForward64`, or `_BitScanReverse64` intrinsics. For now it's a simple `for`-loop. -// TODO: In the future we can switch to a more efficient De Bruijn's algorithm. -// https://www.chessprogramming.org/BitScan -// https://www.chessprogramming.org/De_Bruijn_Sequence -// https://gist.github.com/resilar/e722d4600dbec9752771ab4c9d47044f -// -// Use the serial version on 32-bit x86 and on Arm. 
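The TODO above points at De Bruijn sequences as a branch-free alternative to the `for`-loop fallbacks guarded below. Purely as an illustration of that technique (hypothetical helper names, the widely used 0x03f79d71b4cb0a89 multiplier, and not part of this patch), a trailing-zero count could be sketched as:

#include <stdint.h>

/* Sketch of a De Bruijn-multiply bit-scan: the table maps each 6-bit window of the
 * constant back to the shift that produced it, so isolating the lowest set bit and
 * multiplying lands that bit's index in the top 6 bits. Undefined for `x == 0`,
 * matching the `_tzcnt_u64`-style intrinsics it would replace. */
static int debruijn_ctz_table[64];

static void debruijn_ctz_init(void) {
    uint64_t const debruijn = 0x03f79d71b4cb0a89ull;
    for (int bit = 0; bit != 64; ++bit) debruijn_ctz_table[(debruijn << bit) >> 58] = bit;
}

static int debruijn_ctz(uint64_t x) {
    uint64_t const debruijn = 0x03f79d71b4cb0a89ull;
    return debruijn_ctz_table[((x & (0 - x)) * debruijn) >> 58];
}

A one-time `debruijn_ctz_init()` call, or a pre-generated constant table, would be needed before the first lookup.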
-#if (defined(_WIN32) && !defined(_WIN64)) || defined(_M_ARM) || defined(_M_ARM64) -SZ_INTERNAL int sz_u64_ctz(sz_u64_t x) { - sz_assert(x != 0); - int n = 0; - while ((x & 1) == 0) { n++, x >>= 1; } - return n; -} -SZ_INTERNAL int sz_u64_clz(sz_u64_t x) { - sz_assert(x != 0); - int n = 0; - while ((x & 0x8000000000000000ull) == 0) { n++, x <<= 1; } - return n; -} -SZ_INTERNAL int sz_u64_popcount(sz_u64_t x) { - x = x - ((x >> 1) & 0x5555555555555555ull); - x = (x & 0x3333333333333333ull) + ((x >> 2) & 0x3333333333333333ull); - return (((x + (x >> 4)) & 0x0F0F0F0F0F0F0F0Full) * 0x0101010101010101ull) >> 56; -} -SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) { - sz_assert(x != 0); - int n = 0; - while ((x & 1) == 0) { n++, x >>= 1; } - return n; -} -SZ_INTERNAL int sz_u32_clz(sz_u32_t x) { - sz_assert(x != 0); - int n = 0; - while ((x & 0x80000000u) == 0) { n++, x <<= 1; } - return n; -} -SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) { - x = x - ((x >> 1) & 0x55555555); - x = (x & 0x33333333) + ((x >> 2) & 0x33333333); - return (((x + (x >> 4)) & 0x0F0F0F0F) * 0x01010101) >> 24; -} -#else -SZ_INTERNAL int sz_u64_ctz(sz_u64_t x) { return (int)_tzcnt_u64(x); } -SZ_INTERNAL int sz_u64_clz(sz_u64_t x) { return (int)_lzcnt_u64(x); } -SZ_INTERNAL int sz_u64_popcount(sz_u64_t x) { return (int)__popcnt64(x); } -SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) { return (int)_tzcnt_u32(x); } -SZ_INTERNAL int sz_u32_clz(sz_u32_t x) { return (int)_lzcnt_u32(x); } -SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) { return (int)__popcnt(x); } -#endif -// Force the byteswap functions to be intrinsics, because when /Oi- is given, these will turn into CRT function calls, -// which breaks when `SZ_AVOID_LIBC` is given -#pragma intrinsic(_byteswap_uint64) -SZ_INTERNAL sz_u64_t sz_u64_bytes_reverse(sz_u64_t val) { return _byteswap_uint64(val); } -#pragma intrinsic(_byteswap_ulong) -SZ_INTERNAL sz_u32_t sz_u32_bytes_reverse(sz_u32_t val) { return _byteswap_ulong(val); } -#else -SZ_INTERNAL int sz_u64_popcount(sz_u64_t x) { return __builtin_popcountll(x); } -SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) { return __builtin_popcount(x); } -SZ_INTERNAL int sz_u64_ctz(sz_u64_t x) { return __builtin_ctzll(x); } -SZ_INTERNAL int sz_u64_clz(sz_u64_t x) { return __builtin_clzll(x); } -SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) { return __builtin_ctz(x); } // ! Undefined if `x == 0` -SZ_INTERNAL int sz_u32_clz(sz_u32_t x) { return __builtin_clz(x); } // ! Undefined if `x == 0` -SZ_INTERNAL sz_u64_t sz_u64_bytes_reverse(sz_u64_t val) { return __builtin_bswap64(val); } -SZ_INTERNAL sz_u32_t sz_u32_bytes_reverse(sz_u32_t val) { return __builtin_bswap32(val); } -#endif - -SZ_INTERNAL sz_u64_t sz_u64_rotl(sz_u64_t x, sz_u64_t r) { return (x << r) | (x >> (64 - r)); } - -/** - * @brief Select bits from either ::a or ::b depending on the value of ::mask bits. - * - * Similar to `_mm_blend_epi16` intrinsic on x86. - * Described in the "Bit Twiddling Hacks" by Sean Eron Anderson. - * https://graphics.stanford.edu/~seander/bithacks.html#ConditionalSetOrClearBitsWithoutBranching - */ -SZ_INTERNAL sz_u64_t sz_u64_blend(sz_u64_t a, sz_u64_t b, sz_u64_t mask) { return a ^ ((a ^ b) & mask); } - -/* - * Efficiently computing the minimum and maximum of two or three values can be tricky. - * The simple branching baseline would be: - * - * x < y ? x : y // can replace with 1 conditional move - * - * Branchless approach is well known for signed integers, but it doesn't apply to unsigned ones. 
- * https://stackoverflow.com/questions/514435/templatized-branchless-int-max-min-function - * https://graphics.stanford.edu/~seander/bithacks.html#IntegerMinOrMax - * Using only bit-shifts for singed integers it would be: - * - * y + ((x - y) & (x - y) >> 31) // 4 unique operations - * - * Alternatively, for any integers using multiplication: - * - * (x > y) * y + (x <= y) * x // 5 operations - * - * Alternatively, to avoid multiplication: - * - * x & ~((x < y) - 1) + y & ((x < y) - 1) // 6 unique operations - */ -#define sz_min_of_two(x, y) (x < y ? x : y) -#define sz_max_of_two(x, y) (x < y ? y : x) -#define sz_min_of_three(x, y, z) sz_min_of_two(x, sz_min_of_two(y, z)) -#define sz_max_of_three(x, y, z) sz_max_of_two(x, sz_max_of_two(y, z)) - -/** @brief Branchless minimum function for two signed 32-bit integers. */ -SZ_INTERNAL sz_i32_t sz_i32_min_of_two(sz_i32_t x, sz_i32_t y) { return y + ((x - y) & (x - y) >> 31); } - -/** @brief Branchless minimum function for two signed 32-bit integers. */ -SZ_INTERNAL sz_i32_t sz_i32_max_of_two(sz_i32_t x, sz_i32_t y) { return x - ((x - y) & (x - y) >> 31); } - -/** - * @brief Clamps signed offsets in a string to a valid range. Used for Pythonic-style slicing. - */ -SZ_INTERNAL void sz_ssize_clamp_interval(sz_size_t length, sz_ssize_t start, sz_ssize_t end, - sz_size_t *normalized_offset, sz_size_t *normalized_length) { - // TODO: Remove branches. - // Normalize negative indices - if (start < 0) start += length; - if (end < 0) end += length; - - // Clamp indices to a valid range - if (start < 0) start = 0; - if (end < 0) end = 0; - if (start > (sz_ssize_t)length) start = length; - if (end > (sz_ssize_t)length) end = length; - - // Ensure start <= end - if (start > end) start = end; - - *normalized_offset = start; - *normalized_length = end - start; -} - -/** - * @brief Compute the logarithm base 2 of a positive integer, rounding down. - */ -SZ_INTERNAL sz_size_t sz_size_log2i_nonzero(sz_size_t x) { - sz_assert(x > 0 && "Non-positive numbers have no defined logarithm"); - sz_size_t leading_zeros = sz_u64_clz(x); - return 63 - leading_zeros; -} - -/** - * @brief Compute the smallest power of two greater than or equal to ::x. - */ -SZ_INTERNAL sz_size_t sz_size_bit_ceil(sz_size_t x) { - // Unlike the commonly used trick with `clz` intrinsics, is valid across the whole range of `x`. - // https://stackoverflow.com/a/10143264 - x--; - x |= x >> 1; - x |= x >> 2; - x |= x >> 4; - x |= x >> 8; - x |= x >> 16; -#if SZ_DETECT_64_BIT - x |= x >> 32; -#endif - x++; - return x; -} - -/** - * @brief Transposes an 8x8 bit matrix packed in a `sz_u64_t`. - * - * There is a well known SWAR sequence for that known to chess programmers, - * willing to flip a bit-matrix of pieces along the main A1-H8 diagonal. - * https://www.chessprogramming.org/Flipping_Mirroring_and_Rotating - * https://lukas-prokop.at/articles/2021-07-23-transpose - */ -SZ_INTERNAL sz_u64_t sz_u64_transpose(sz_u64_t x) { - sz_u64_t t; - t = x ^ (x << 36); - x ^= 0xf0f0f0f00f0f0f0full & (t ^ (x >> 36)); - t = 0xcccc0000cccc0000ull & (x ^ (x << 18)); - x ^= t ^ (t >> 18); - t = 0xaa00aa00aa00aa00ull & (x ^ (x << 9)); - x ^= t ^ (t >> 9); - return x; -} - -/** - * @brief Helper, that swaps two 64-bit integers representing the order of elements in the sequence. - */ -SZ_INTERNAL void sz_u64_swap(sz_u64_t *a, sz_u64_t *b) { - sz_u64_t t = *a; - *a = *b; - *b = t; -} - -/** - * @brief Helper, that swaps two 64-bit integers representing the order of elements in the sequence. 
- */ -SZ_INTERNAL void sz_pointer_swap(void **a, void **b) { - void *t = *a; - *a = *b; - *b = t; -} - -/** - * @brief Helper structure to simplify work with 16-bit words. - * @see sz_u16_load - */ -typedef union sz_u16_vec_t { - sz_u16_t u16; - sz_u8_t u8s[2]; -} sz_u16_vec_t; - -/** - * @brief Load a 16-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms. - */ -SZ_INTERNAL sz_u16_vec_t sz_u16_load(sz_cptr_t ptr) { -#if !SZ_USE_MISALIGNED_LOADS - sz_u16_vec_t result; - result.u8s[0] = ptr[0]; - result.u8s[1] = ptr[1]; - return result; -#elif defined(_MSC_VER) && !defined(__clang__) -#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. - return *((sz_u16_vec_t *)ptr); -#else - return *((__unaligned sz_u16_vec_t *)ptr); -#endif -#else - __attribute__((aligned(1))) sz_u16_vec_t const *result = (sz_u16_vec_t const *)ptr; - return *result; -#endif -} - -/** - * @brief Helper structure to simplify work with 32-bit words. - * @see sz_u32_load - */ -typedef union sz_u32_vec_t { - sz_u32_t u32; - sz_u16_t u16s[2]; - sz_u8_t u8s[4]; -} sz_u32_vec_t; - -/** - * @brief Load a 32-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms. - */ -SZ_INTERNAL sz_u32_vec_t sz_u32_load(sz_cptr_t ptr) { -#if !SZ_USE_MISALIGNED_LOADS - sz_u32_vec_t result; - result.u8s[0] = ptr[0]; - result.u8s[1] = ptr[1]; - result.u8s[2] = ptr[2]; - result.u8s[3] = ptr[3]; - return result; -#elif defined(_MSC_VER) && !defined(__clang__) -#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. - return *((sz_u32_vec_t *)ptr); -#else - return *((__unaligned sz_u32_vec_t *)ptr); -#endif -#else - __attribute__((aligned(1))) sz_u32_vec_t const *result = (sz_u32_vec_t const *)ptr; - return *result; -#endif -} - -/** - * @brief Helper structure to simplify work with 64-bit words. - * @see sz_u64_load - */ -typedef union sz_u64_vec_t { - sz_u64_t u64; - sz_u32_t u32s[2]; - sz_u16_t u16s[4]; - sz_u8_t u8s[8]; -} sz_u64_vec_t; - -/** - * @brief Load a 64-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms. - */ -SZ_INTERNAL sz_u64_vec_t sz_u64_load(sz_cptr_t ptr) { -#if !SZ_USE_MISALIGNED_LOADS - sz_u64_vec_t result; - result.u8s[0] = ptr[0]; - result.u8s[1] = ptr[1]; - result.u8s[2] = ptr[2]; - result.u8s[3] = ptr[3]; - result.u8s[4] = ptr[4]; - result.u8s[5] = ptr[5]; - result.u8s[6] = ptr[6]; - result.u8s[7] = ptr[7]; - return result; -#elif defined(_MSC_VER) && !defined(__clang__) -#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. - return *((sz_u64_vec_t *)ptr); -#else - return *((__unaligned sz_u64_vec_t *)ptr); -#endif -#else - __attribute__((aligned(1))) sz_u64_vec_t const *result = (sz_u64_vec_t const *)ptr; - return *result; -#endif -} - -/** @brief Helper function, using the supplied fixed-capacity buffer to allocate memory. */ -SZ_INTERNAL sz_ptr_t _sz_memory_allocate_fixed(sz_size_t length, void *handle) { - sz_size_t capacity; - sz_copy((sz_ptr_t)&capacity, (sz_cptr_t)handle, sizeof(sz_size_t)); - sz_size_t consumed_capacity = sizeof(sz_size_t); - if (consumed_capacity + length > capacity) return SZ_NULL_CHAR; - return (sz_ptr_t)handle + consumed_capacity; -} - -/** @brief Helper "no-op" function, simulating memory deallocation when we use a "static" memory buffer. 
*/ -SZ_INTERNAL void _sz_memory_free_fixed(sz_ptr_t start, sz_size_t length, void *handle) { - sz_unused(start && length && handle); -} - -/** @brief An internal callback used to set a bit in a power-of-two length binary fingerprint of a string. */ -SZ_INTERNAL void _sz_hashes_fingerprint_pow2_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, void *handle) { - sz_string_view_t *fingerprint_buffer = (sz_string_view_t *)handle; - sz_u8_t *fingerprint_u8s = (sz_u8_t *)fingerprint_buffer->start; - sz_size_t fingerprint_bytes = fingerprint_buffer->length; - fingerprint_u8s[(hash / 8) & (fingerprint_bytes - 1)] |= (1 << (hash & 7)); - sz_unused(start && length); -} - -/** @brief An internal callback used to set a bit in a @b non power-of-two length binary fingerprint of a string. */ -SZ_INTERNAL void _sz_hashes_fingerprint_non_pow2_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, - void *handle) { - sz_string_view_t *fingerprint_buffer = (sz_string_view_t *)handle; - sz_u8_t *fingerprint_u8s = (sz_u8_t *)fingerprint_buffer->start; - sz_size_t fingerprint_bytes = fingerprint_buffer->length; - fingerprint_u8s[(hash / 8) % fingerprint_bytes] |= (1 << (hash & 7)); - sz_unused(start && length); -} - -/** @brief An internal callback, used to mix all the running hashes into one pointer-size value. */ -SZ_INTERNAL void _sz_hashes_fingerprint_scalar_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, - void *scalar_handle) { - sz_unused(start && length && hash && scalar_handle); - sz_size_t *scalar_ptr = (sz_size_t *)scalar_handle; - *scalar_ptr ^= hash; -} - -/** - * @brief Chooses the offsets of the most interesting characters in a search needle. - * - * Search throughput can significantly deteriorate if we are matching the wrong characters. - * Say the needle is "aXaYa", and we are comparing the first, second, and last character. - * If we use SIMD and compare many offsets at a time, comparing against "a" in every register is a waste. - * - * Similarly, dealing with UTF8 inputs, we know that the lower bits of each character code carry more information. - * Cyrillic alphabet, for example, falls into [0x0410, 0x042F] code range for uppercase [А, Я], and - * into [0x0430, 0x044F] for lowercase [а, я]. Scanning through a text written in Russian, half of the - * bytes will carry absolutely no value and will be equal to 0x04. - */ -SZ_INTERNAL void _sz_locate_needle_anomalies(sz_cptr_t start, sz_size_t length, // - sz_size_t *first, sz_size_t *second, sz_size_t *third) { - *first = 0; - *second = length / 2; - *third = length - 1; - - // - int has_duplicates = // - start[*first] == start[*second] || // - start[*first] == start[*third] || // - start[*second] == start[*third]; - - // Loop through letters to find non-colliding variants. - if (length > 3 && has_duplicates) { - // Pivot the middle point right, until we find a character different from the first one. - for (; start[*second] == start[*first] && *second + 1 < *third; ++(*second)) {} - // Pivot the third (last) point left, until we find a different character. - for (; (start[*third] == start[*second] || start[*third] == start[*first]) && *third > (*second + 1); - --(*third)) {} - } - - // TODO: Investigate alternative strategies for long needles. - // On very long needles we have the luxury to choose! 
- // Often dealing with UTF8, we will likely benefit from shifting the first and second characters - // further to the right, to achieve not only uniqueness within the needle, but also avoid common - // rune prefixes of 2-, 3-, and 4-byte codes. - if (length > 8) { - // Pivot the first and second points right, until we find a character, that: - // > is different from others. - // > doesn't start with 0b'110x'xxxx - only 5 bits of relevant info. - // > doesn't start with 0b'1110'xxxx - only 4 bits of relevant info. - // > doesn't start with 0b'1111'0xxx - only 3 bits of relevant info. - // - // So we are practically searching for byte values that start with 0b0xxx'xxxx or 0b'10xx'xxxx. - // Meaning they fall in the range [0, 127] and [128, 191], in other words any unsigned int up to 191. - sz_u8_t const *start_u8 = (sz_u8_t const *)start; - sz_size_t vibrant_first = *first, vibrant_second = *second, vibrant_third = *third; - - // Let's begin with the seccond character, as the termination criteria there is more obvious - // and we may end up with more variants to check for the first candidate. - for (; (start_u8[vibrant_second] > 191 || start_u8[vibrant_second] == start_u8[vibrant_third]) && - (vibrant_second + 1 < vibrant_third); - ++vibrant_second) {} - - // Now check if we've indeed found a good candidate or should revert the `vibrant_second` to `second`. - if (start_u8[vibrant_second] < 191) { *second = vibrant_second; } - else { vibrant_second = *second; } - - // Now check the first character. - for (; (start_u8[vibrant_first] > 191 || start_u8[vibrant_first] == start_u8[vibrant_second] || - start_u8[vibrant_first] == start_u8[vibrant_third]) && - (vibrant_first + 1 < vibrant_second); - ++vibrant_first) {} - - // Now check if we've indeed found a good candidate or should revert the `vibrant_first` to `first`. - // We don't need to shift the third one when dealing with texts as the last byte of the text is - // also the last byte of a rune and contains the most information. - if (start_u8[vibrant_first] < 191) { *first = vibrant_first; } - } -} - -#pragma GCC visibility pop -#pragma endregion - -#pragma region Serial Implementation - -#if !SZ_AVOID_LIBC -#include // `fprintf` -#include // `malloc`, `EXIT_FAILURE` - -SZ_PUBLIC void *_sz_memory_allocate_default(sz_size_t length, void *handle) { - sz_unused(handle); - return malloc(length); -} -SZ_PUBLIC void _sz_memory_free_default(sz_ptr_t start, sz_size_t length, void *handle) { - sz_unused(handle && length); - free(start); -} - -#endif - -SZ_PUBLIC void sz_memory_allocator_init_default(sz_memory_allocator_t *alloc) { -#if !SZ_AVOID_LIBC - alloc->allocate = (sz_memory_allocate_t)_sz_memory_allocate_default; - alloc->free = (sz_memory_free_t)_sz_memory_free_default; -#else - alloc->allocate = (sz_memory_allocate_t)SZ_NULL; - alloc->free = (sz_memory_free_t)SZ_NULL; -#endif - alloc->handle = SZ_NULL; -} - -SZ_PUBLIC void sz_memory_allocator_init_fixed(sz_memory_allocator_t *alloc, void *buffer, sz_size_t length) { - // The logic here is simple - put the buffer length in the first slots of the buffer. - // Later use it for bounds checking. - alloc->allocate = (sz_memory_allocate_t)_sz_memory_allocate_fixed; - alloc->free = (sz_memory_free_t)_sz_memory_free_fixed; - alloc->handle = &buffer; - sz_copy((sz_ptr_t)buffer, (sz_cptr_t)&length, sizeof(sz_size_t)); -} - -/** - * @brief Byte-level equality comparison between two strings. - * If unaligned loads are allowed, uses a switch-table to avoid loops on short strings. 
- */ -SZ_PUBLIC sz_bool_t sz_equal_serial(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - sz_cptr_t const a_end = a + length; -#if SZ_USE_MISALIGNED_LOADS - if (length >= SZ_SWAR_THRESHOLD) { - sz_u64_vec_t a_vec, b_vec; - for (; a + 8 <= a_end; a += 8, b += 8) { - a_vec = sz_u64_load(a); - b_vec = sz_u64_load(b); - if (a_vec.u64 != b_vec.u64) return sz_false_k; - } - } -#endif - while (a != a_end && *a == *b) a++, b++; - return (sz_bool_t)(a_end == a); -} - -SZ_PUBLIC sz_cptr_t sz_find_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { - for (sz_cptr_t const end = text + length; text != end; ++text) - if (sz_charset_contains(set, *text)) return text; - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Warray-bounds" - sz_cptr_t const end = text; - for (text += length; text != end;) - if (sz_charset_contains(set, *(text -= 1))) return text; - return SZ_NULL_CHAR; -#pragma GCC diagnostic pop -} - -/** - * One option to avoid branching is to use conditional moves and lookup the comparison result in a table: - * sz_ordering_t ordering_lookup[2] = {sz_greater_k, sz_less_k}; - * for (; a != min_end; ++a, ++b) - * if (*a != *b) return ordering_lookup[*a < *b]; - * That, however, introduces a data-dependency. - * A cleaner option is to perform two comparisons and a subtraction. - * One instruction more, but no data-dependency. - */ -#define _sz_order_scalars(a, b) ((sz_ordering_t)((a > b) - (a < b))) - -SZ_PUBLIC sz_ordering_t sz_order_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { - sz_bool_t a_shorter = (sz_bool_t)(a_length < b_length); - sz_size_t min_length = a_shorter ? a_length : b_length; - sz_cptr_t min_end = a + min_length; -#if SZ_USE_MISALIGNED_LOADS && !SZ_DETECT_BIG_ENDIAN - for (sz_u64_vec_t a_vec, b_vec; a + 8 <= min_end; a += 8, b += 8) { - a_vec = sz_u64_load(a); - b_vec = sz_u64_load(b); - if (a_vec.u64 != b_vec.u64) - return _sz_order_scalars(sz_u64_bytes_reverse(a_vec.u64), sz_u64_bytes_reverse(b_vec.u64)); - } -#endif - for (; a != min_end; ++a, ++b) - if (*a != *b) return _sz_order_scalars(*a, *b); - - // If the strings are equal up to `min_end`, then the shorter string is smaller - return _sz_order_scalars(a_length, b_length); -} - -/** - * @brief Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each byte is set. - // For that take the bottom 7 bits of each byte, add one to them, - // and if this sets the top bit to one, then all the 7 bits are ones as well. - vec.u64 = ((vec.u64 & 0x7F7F7F7F7F7F7F7Full) + 0x0101010101010101ull) & ((vec.u64 & 0x8080808080808080ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b single-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. - * Identical to `memchr(haystack, needle[0], haystack_length)`. - */ -SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - if (!h_length) return SZ_NULL_CHAR; - sz_cptr_t const h_end = h + h_length; - -#if !SZ_DETECT_BIG_ENDIAN // Use SWAR only on little-endian platforms for brevety. 
-#if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)h & 7ull) && h < h_end; ++h) - if (*h == *n) return h; -#endif - - // Broadcast the n into every byte of a 64-bit integer to use SWAR - // techniques and process eight characters at a time. - sz_u64_vec_t h_vec, n_vec, match_vec; - match_vec.u64 = 0; - n_vec.u64 = (sz_u64_t)n[0] * 0x0101010101010101ull; - for (; h + 8 <= h_end; h += 8) { - h_vec.u64 = *(sz_u64_t const *)h; - match_vec = _sz_u64_each_byte_equal(h_vec, n_vec); - if (match_vec.u64) return h + sz_u64_ctz(match_vec.u64) / 8; - } -#endif - - // Handle the misaligned tail. - for (; h < h_end; ++h) - if (*h == *n) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief Find the last occurrence of a @b single-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. - * Identical to `memrchr(haystack, needle[0], haystack_length)`. - */ -sz_cptr_t sz_rfind_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - if (!h_length) return SZ_NULL_CHAR; - sz_cptr_t const h_start = h; - - // Reposition the `h` pointer to the end, as we will be walking backwards. - h = h + h_length - 1; - -#if !SZ_DETECT_BIG_ENDIAN // Use SWAR only on little-endian platforms for brevety. -#if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)(h + 1) & 7ull) && h >= h_start; --h) - if (*h == *n) return h; -#endif - - // Broadcast the n into every byte of a 64-bit integer to use SWAR - // techniques and process eight characters at a time. - sz_u64_vec_t h_vec, n_vec, match_vec; - n_vec.u64 = (sz_u64_t)n[0] * 0x0101010101010101ull; - for (; h >= h_start + 7; h -= 8) { - h_vec.u64 = *(sz_u64_t const *)(h - 7); - match_vec = _sz_u64_each_byte_equal(h_vec, n_vec); - if (match_vec.u64) return h - sz_u64_clz(match_vec.u64) / 8; - } -#endif - - for (; h >= h_start; --h) - if (*h == *n) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief 2Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each 2byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_2byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each 2byte is set. - // For that take the bottom 15 bits of each 2byte, add one to them, - // and if this sets the top bit to one, then all the 15 bits are ones as well. - vec.u64 = ((vec.u64 & 0x7FFF7FFF7FFF7FFFull) + 0x0001000100010001ull) & ((vec.u64 & 0x8000800080008000ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b two-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. - */ -SZ_INTERNAL sz_cptr_t _sz_find_2byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - // This is an internal method, and the haystack is guaranteed to be at least 2 bytes long. - sz_assert(h_length >= 2 && "The haystack is too short."); - sz_cptr_t const h_end = h + h_length; - -#if !SZ_USE_MISALIGNED_LOADS - // Process the misaligned head, to void UB on unaligned 64-bit loads. 
- for (; ((sz_size_t)h & 7ull) && h + 2 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) == 2) return h; -#endif - - sz_u64_vec_t h_even_vec, h_odd_vec, n_vec, matches_even_vec, matches_odd_vec; - n_vec.u64 = 0; - n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1]; - n_vec.u64 *= 0x0001000100010001ull; // broadcast - - // This code simulates hyper-scalar execution, analyzing 8 offsets at a time. - for (; h + 9 <= h_end; h += 8) { - h_even_vec.u64 = *(sz_u64_t *)h; - h_odd_vec.u64 = (h_even_vec.u64 >> 8) | ((sz_u64_t)h[8] << 56); - matches_even_vec = _sz_u64_each_2byte_equal(h_even_vec, n_vec); - matches_odd_vec = _sz_u64_each_2byte_equal(h_odd_vec, n_vec); - - matches_even_vec.u64 >>= 8; - if (matches_even_vec.u64 + matches_odd_vec.u64) { - sz_u64_t match_indicators = matches_even_vec.u64 | matches_odd_vec.u64; - return h + sz_u64_ctz(match_indicators) / 8; - } - } - - for (; h + 2 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) == 2) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief 4Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each 4byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_4byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each 4byte is set. - // For that take the bottom 31 bits of each 4byte, add one to them, - // and if this sets the top bit to one, then all the 31 bits are ones as well. - vec.u64 = ((vec.u64 & 0x7FFFFFFF7FFFFFFFull) + 0x0000000100000001ull) & ((vec.u64 & 0x8000000080000000ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b four-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. - */ -SZ_INTERNAL sz_cptr_t _sz_find_4byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - // This is an internal method, and the haystack is guaranteed to be at least 4 bytes long. - sz_assert(h_length >= 4 && "The haystack is too short."); - sz_cptr_t const h_end = h + h_length; - -#if !SZ_USE_MISALIGNED_LOADS - // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)h & 7ull) && h + 4 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) + (h[3] == n[3]) == 4) return h; -#endif - - sz_u64_vec_t h0_vec, h1_vec, h2_vec, h3_vec, n_vec, matches0_vec, matches1_vec, matches2_vec, matches3_vec; - n_vec.u64 = 0; - n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1], n_vec.u8s[2] = n[2], n_vec.u8s[3] = n[3]; - n_vec.u64 *= 0x0000000100000001ull; // broadcast - - // This code simulates hyper-scalar execution, analyzing 8 offsets at a time using four 64-bit words. - // We load the subsequent four-byte word as well, taking its first bytes. 
Think of it as a glorified prefetch :) - sz_u64_t h_page_current, h_page_next; - for (; h + sizeof(sz_u64_t) + sizeof(sz_u32_t) <= h_end; h += sizeof(sz_u64_t)) { - h_page_current = *(sz_u64_t *)h; - h_page_next = *(sz_u32_t *)(h + 8); - h0_vec.u64 = (h_page_current); - h1_vec.u64 = (h_page_current >> 8) | (h_page_next << 56); - h2_vec.u64 = (h_page_current >> 16) | (h_page_next << 48); - h3_vec.u64 = (h_page_current >> 24) | (h_page_next << 40); - matches0_vec = _sz_u64_each_4byte_equal(h0_vec, n_vec); - matches1_vec = _sz_u64_each_4byte_equal(h1_vec, n_vec); - matches2_vec = _sz_u64_each_4byte_equal(h2_vec, n_vec); - matches3_vec = _sz_u64_each_4byte_equal(h3_vec, n_vec); - - if (matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64) { - matches0_vec.u64 >>= 24; - matches1_vec.u64 >>= 16; - matches2_vec.u64 >>= 8; - sz_u64_t match_indicators = matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64; - return h + sz_u64_ctz(match_indicators) / 8; - } - } - - for (; h + 4 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) + (h[3] == n[3]) == 4) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief 3Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each 3byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_3byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each 4byte is set. - // For that take the bottom 31 bits of each 4byte, add one to them, - // and if this sets the top bit to one, then all the 31 bits are ones as well. - vec.u64 = ((vec.u64 & 0xFFFF7FFFFF7FFFFFull) + 0x0000000001000001ull) & ((vec.u64 & 0x0000800000800000ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b three-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. - */ -SZ_INTERNAL sz_cptr_t _sz_find_3byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - // This is an internal method, and the haystack is guaranteed to be at least 4 bytes long. - sz_assert(h_length >= 3 && "The haystack is too short."); - sz_cptr_t const h_end = h + h_length; - -#if !SZ_USE_MISALIGNED_LOADS - // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)h & 7ull) && h + 3 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) == 3) return h; -#endif - - // We fetch 12 - sz_u64_vec_t h0_vec, h1_vec, h2_vec, h3_vec, h4_vec; - sz_u64_vec_t matches0_vec, matches1_vec, matches2_vec, matches3_vec, matches4_vec; - sz_u64_vec_t n_vec; - n_vec.u64 = 0; - n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1], n_vec.u8s[2] = n[2]; - n_vec.u64 *= 0x0000000001000001ull; // broadcast - - // This code simulates hyper-scalar execution, analyzing 8 offsets at a time using three 64-bit words. - // We load the subsequent two-byte word as well. 
- sz_u64_t h_page_current, h_page_next; - for (; h + sizeof(sz_u64_t) + sizeof(sz_u16_t) <= h_end; h += sizeof(sz_u64_t)) { - h_page_current = *(sz_u64_t *)h; - h_page_next = *(sz_u16_t *)(h + 8); - h0_vec.u64 = (h_page_current); - h1_vec.u64 = (h_page_current >> 8) | (h_page_next << 56); - h2_vec.u64 = (h_page_current >> 16) | (h_page_next << 48); - h3_vec.u64 = (h_page_current >> 24) | (h_page_next << 40); - h4_vec.u64 = (h_page_current >> 32) | (h_page_next << 32); - matches0_vec = _sz_u64_each_3byte_equal(h0_vec, n_vec); - matches1_vec = _sz_u64_each_3byte_equal(h1_vec, n_vec); - matches2_vec = _sz_u64_each_3byte_equal(h2_vec, n_vec); - matches3_vec = _sz_u64_each_3byte_equal(h3_vec, n_vec); - matches4_vec = _sz_u64_each_3byte_equal(h4_vec, n_vec); - - if (matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64 | matches4_vec.u64) { - matches0_vec.u64 >>= 16; - matches1_vec.u64 >>= 8; - matches3_vec.u64 <<= 8; - matches4_vec.u64 <<= 16; - sz_u64_t match_indicators = - matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64 | matches4_vec.u64; - return h + sz_u64_ctz(match_indicators) / 8; - } - } - - for (; h + 3 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) == 3) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief Boyer-Moore-Horspool algorithm for exact matching of patterns up to @b 256-bytes long. - * Uses the Raita heuristic to match the first two, the last, and the middle character of the pattern. - */ -SZ_INTERNAL sz_cptr_t _sz_find_horspool_upto_256bytes_serial(sz_cptr_t h_chars, sz_size_t h_length, // - sz_cptr_t n_chars, sz_size_t n_length) { - sz_assert(n_length <= 256 && "The pattern is too long."); - // Several popular string matching algorithms are using a bad-character shift table. - // Boyer Moore: https://www-igm.univ-mlv.fr/~lecroq/string/node14.html - // Quick Search: https://www-igm.univ-mlv.fr/~lecroq/string/node19.html - // Smith: https://www-igm.univ-mlv.fr/~lecroq/string/node21.html - union { - sz_u8_t jumps[256]; - sz_u64_vec_t vecs[64]; - } bad_shift_table; - - // Let's initialize the table using SWAR to the total length of the string. - sz_u8_t const *h = (sz_u8_t const *)h_chars; - sz_u8_t const *n = (sz_u8_t const *)n_chars; - { - sz_u64_vec_t n_length_vec; - n_length_vec.u64 = n_length; - n_length_vec.u64 *= 0x0101010101010101ull; // broadcast - for (sz_size_t i = 0; i != 64; ++i) bad_shift_table.vecs[i].u64 = n_length_vec.u64; - for (sz_size_t i = 0; i + 1 < n_length; ++i) bad_shift_table.jumps[n[i]] = (sz_u8_t)(n_length - i - 1); - } - - // Another common heuristic is to match a few characters from different parts of a string. - // Raita suggests to use the first two, the last, and the middle character of the pattern. - sz_u32_vec_t h_vec, n_vec; - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n_chars, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into an unsigned integer. - n_vec.u8s[0] = n[offset_first]; - n_vec.u8s[1] = n[offset_first + 1]; - n_vec.u8s[2] = n[offset_mid]; - n_vec.u8s[3] = n[offset_last]; - - // Scan through the whole haystack, skipping the last `n_length - 1` bytes. 
- for (sz_size_t i = 0; i <= h_length - n_length;) { - h_vec.u8s[0] = h[i + offset_first]; - h_vec.u8s[1] = h[i + offset_first + 1]; - h_vec.u8s[2] = h[i + offset_mid]; - h_vec.u8s[3] = h[i + offset_last]; - if (h_vec.u32 == n_vec.u32 && sz_equal((sz_cptr_t)h + i, n_chars, n_length)) return (sz_cptr_t)h + i; - i += bad_shift_table.jumps[h[i + n_length - 1]]; - } - return SZ_NULL_CHAR; -} - -/** - * @brief Boyer-Moore-Horspool algorithm for @b reverse-order exact matching of patterns up to @b 256-bytes long. - * Uses the Raita heuristic to match the first two, the last, and the middle character of the pattern. - */ -SZ_INTERNAL sz_cptr_t _sz_rfind_horspool_upto_256bytes_serial(sz_cptr_t h_chars, sz_size_t h_length, // - sz_cptr_t n_chars, sz_size_t n_length) { - sz_assert(n_length <= 256 && "The pattern is too long."); - union { - sz_u8_t jumps[256]; - sz_u64_vec_t vecs[64]; - } bad_shift_table; - - // Let's initialize the table using SWAR to the total length of the string. - sz_u8_t const *h = (sz_u8_t const *)h_chars; - sz_u8_t const *n = (sz_u8_t const *)n_chars; - { - sz_u64_vec_t n_length_vec; - n_length_vec.u64 = n_length; - n_length_vec.u64 *= 0x0101010101010101ull; // broadcast - for (sz_size_t i = 0; i != 64; ++i) bad_shift_table.vecs[i].u64 = n_length_vec.u64; - for (sz_size_t i = 0; i + 1 < n_length; ++i) - bad_shift_table.jumps[n[n_length - i - 1]] = (sz_u8_t)(n_length - i - 1); - } - - // Another common heuristic is to match a few characters from different parts of a string. - // Raita suggests to use the first two, the last, and the middle character of the pattern. - sz_u32_vec_t h_vec, n_vec; - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n_chars, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into an unsigned integer. - n_vec.u8s[0] = n[offset_first]; - n_vec.u8s[1] = n[offset_first + 1]; - n_vec.u8s[2] = n[offset_mid]; - n_vec.u8s[3] = n[offset_last]; - - // Scan through the whole haystack, skipping the first `n_length - 1` bytes. - for (sz_size_t j = 0; j <= h_length - n_length;) { - sz_size_t i = h_length - n_length - j; - h_vec.u8s[0] = h[i + offset_first]; - h_vec.u8s[1] = h[i + offset_first + 1]; - h_vec.u8s[2] = h[i + offset_mid]; - h_vec.u8s[3] = h[i + offset_last]; - if (h_vec.u32 == n_vec.u32 && sz_equal((sz_cptr_t)h + i, n_chars, n_length)) return (sz_cptr_t)h + i; - j += bad_shift_table.jumps[h[i]]; - } - return SZ_NULL_CHAR; -} - -/** - * @brief Exact substring search helper function, that finds the first occurrence of a prefix of the needle - * using a given search function, and then verifies the remaining part of the needle. - */ -SZ_INTERNAL sz_cptr_t _sz_find_with_prefix(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length, - sz_find_t find_prefix, sz_size_t prefix_length) { - - sz_size_t suffix_length = n_length - prefix_length; - while (1) { - sz_cptr_t found = find_prefix(h, h_length, n, prefix_length); - if (!found) return SZ_NULL_CHAR; - - // Verify the remaining part of the needle - sz_size_t remaining = h_length - (found - h); - if (remaining < n_length) return SZ_NULL_CHAR; - if (sz_equal(found + prefix_length, n + prefix_length, suffix_length)) return found; - - // Adjust the position. 
- h = found + 1; - h_length = remaining - 1; - } - - // Unreachable, but helps silence compiler warnings: - return SZ_NULL_CHAR; -} - -/** - * @brief Exact reverse-order substring search helper function, that finds the last occurrence of a suffix of the - * needle using a given search function, and then verifies the remaining part of the needle. - */ -SZ_INTERNAL sz_cptr_t _sz_rfind_with_suffix(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length, - sz_find_t find_suffix, sz_size_t suffix_length) { - - sz_size_t prefix_length = n_length - suffix_length; - while (1) { - sz_cptr_t found = find_suffix(h, h_length, n + prefix_length, suffix_length); - if (!found) return SZ_NULL_CHAR; - - // Verify the remaining part of the needle - sz_size_t remaining = found - h; - if (remaining < prefix_length) return SZ_NULL_CHAR; - if (sz_equal(found - prefix_length, n, prefix_length)) return found - prefix_length; - - // Adjust the position. - h_length = remaining - 1; - } - - // Unreachable, but helps silence compiler warnings: - return SZ_NULL_CHAR; -} - -SZ_INTERNAL sz_cptr_t _sz_find_over_4bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - return _sz_find_with_prefix(h, h_length, n, n_length, (sz_find_t)_sz_find_4byte_serial, 4); -} - -SZ_INTERNAL sz_cptr_t _sz_find_horspool_over_256bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, - sz_size_t n_length) { - return _sz_find_with_prefix(h, h_length, n, n_length, _sz_find_horspool_upto_256bytes_serial, 256); -} - -SZ_INTERNAL sz_cptr_t _sz_rfind_horspool_over_256bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, - sz_size_t n_length) { - return _sz_rfind_with_suffix(h, h_length, n, n_length, _sz_rfind_horspool_upto_256bytes_serial, 256); -} - -SZ_PUBLIC sz_cptr_t sz_find_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - -#if SZ_DETECT_BIG_ENDIAN - sz_find_t backends[] = { - (sz_find_t)sz_find_byte_serial, - (sz_find_t)_sz_find_horspool_upto_256bytes_serial, - (sz_find_t)_sz_find_horspool_over_256bytes_serial, - }; - - return backends[(n_length > 1) + (n_length > 256)](h, h_length, n, n_length); -#else - sz_find_t backends[] = { - // For very short strings brute-force SWAR makes sense. - (sz_find_t)sz_find_byte_serial, - (sz_find_t)_sz_find_2byte_serial, - (sz_find_t)_sz_find_3byte_serial, - (sz_find_t)_sz_find_4byte_serial, - // To avoid constructing the skip-table, let's use the prefixed approach. - (sz_find_t)_sz_find_over_4bytes_serial, - // For longer needles - use skip tables. - (sz_find_t)_sz_find_horspool_upto_256bytes_serial, - (sz_find_t)_sz_find_horspool_over_256bytes_serial, - }; - - return backends[ - // For very short strings brute-force SWAR makes sense. - (n_length > 1) + (n_length > 2) + (n_length > 3) + - // To avoid constructing the skip-table, let's use the prefixed approach. - (n_length > 4) + - // For longer needles - use skip tables. - (n_length > 8) + (n_length > 256)](h, h_length, n, n_length); -#endif -} - -SZ_PUBLIC sz_cptr_t sz_rfind_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - - sz_find_t backends[] = { - // For very short strings brute-force SWAR makes sense. 
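/*
 * The helpers above compose "search a short anchor, then verify the rest of the needle".
 * The same composition expressed with libc only, as a sketch: `memchr` plays the role of the
 * fast prefix searcher and `memcmp` does the verification. Illustrative, not the library's code.
 */
#include <stddef.h>
#include <string.h>

static char const *find_by_first_byte(char const *hay, size_t hay_len, char const *needle, size_t needle_len) {
    if (!needle_len || hay_len < needle_len) return NULL;
    char const *last_start = hay + hay_len - needle_len; /* last position a match can begin at */
    for (char const *h = hay; (h = memchr(h, needle[0], (size_t)(last_start - h + 1))); ++h)
        if (memcmp(h + 1, needle + 1, needle_len - 1) == 0) return h;
    return NULL;
}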
- (sz_find_t)sz_rfind_byte_serial, - // TODO: implement reverse-order SWAR for 2/3/4 byte variants. - // TODO: (sz_find_t)_sz_rfind_2byte_serial, - // TODO: (sz_find_t)_sz_rfind_3byte_serial, - // TODO: (sz_find_t)_sz_rfind_4byte_serial, - // To avoid constructing the skip-table, let's use the prefixed approach. - // (sz_find_t)_sz_rfind_over_4bytes_serial, - // For longer needles - use skip tables. - (sz_find_t)_sz_rfind_horspool_upto_256bytes_serial, - (sz_find_t)_sz_rfind_horspool_over_256bytes_serial, - }; - - return backends[ - // For very short strings brute-force SWAR makes sense. - 0 + - // To avoid constructing the skip-table, let's use the prefixed approach. - (n_length > 1) + - // For longer needles - use skip tables. - (n_length > 256)](h, h_length, n, n_length); -} - -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_serial( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - - // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. - sz_memory_allocator_t global_alloc; - if (!alloc) { - sz_memory_allocator_init_default(&global_alloc); - alloc = &global_alloc; - } - - // TODO: Generalize to remove the following asserts! - sz_assert(!bound && "For bounded search the method should only evaluate one band of the matrix."); - sz_assert(shorter_length == longer_length && "The method hasn't been generalized to different length inputs yet."); - sz_unused(longer_length && bound); - - // We are going to store 3 diagonals of the matrix. - // The length of the longest (main) diagonal would be `n = (shorter_length + 1)`. - sz_size_t n = shorter_length + 1; - sz_size_t buffer_length = sizeof(sz_size_t) * n * 3; - sz_size_t *distances = (sz_size_t *)alloc->allocate(buffer_length, alloc->handle); - if (!distances) return SZ_SIZE_MAX; - - sz_size_t *previous_distances = distances; - sz_size_t *current_distances = previous_distances + n; - sz_size_t *next_distances = previous_distances + n * 2; - - // Initialize the first two diagonals: - previous_distances[0] = 0; - current_distances[0] = current_distances[1] = 1; - - // Progress through the upper triangle of the Levenshtein matrix. - sz_size_t next_diagonal_index = 2; - for (; next_diagonal_index != n; ++next_diagonal_index) { - sz_size_t const next_diagonal_length = next_diagonal_index + 1; - for (sz_size_t i = 0; i + 2 < next_diagonal_length; ++i) { - sz_size_t cost_of_substitution = shorter[next_diagonal_index - i - 2] != longer[i]; - sz_size_t cost_if_substitution = previous_distances[i] + cost_of_substitution; - sz_size_t cost_if_deletion_or_insertion = sz_min_of_two(current_distances[i], current_distances[i + 1]) + 1; - next_distances[i + 1] = sz_min_of_two(cost_if_deletion_or_insertion, cost_if_substitution); - } - // Don't forget to populate the first row and the first column of the Levenshtein matrix. - next_distances[0] = next_distances[next_diagonal_length - 1] = next_diagonal_index; - // Perform a circular rotation of those buffers, to reuse the memory. - sz_size_t *temporary = previous_distances; - previous_distances = current_distances; - current_distances = next_distances; - next_distances = temporary; - } - - // By now we've scanned through the upper triangle of the matrix, where each subsequent iteration results in a - // larger diagonal. From now onwards, we will be shrinking. 
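/*
 * Both dispatchers above pick a backend by summing boolean comparisons of the needle length,
 * so the selection is a single indexed call rather than a branch ladder. A sketch of the same
 * pattern with a naive stand-in kernel in every slot; all names here are illustrative.
 */
#include <stddef.h>
#include <string.h>

typedef char const *(*find_fn)(char const *, size_t, char const *, size_t);

static char const *find_naive(char const *h, size_t h_len, char const *n, size_t n_len) {
    for (size_t i = 0; i + n_len <= h_len; ++i)
        if (memcmp(h + i, n, n_len) == 0) return h + i;
    return NULL;
}

static char const *find_dispatch(char const *h, size_t h_len, char const *n, size_t n_len) {
    if (!n_len || h_len < n_len) return NULL;
    /* Real code would plug a byte-, SWAR-, or Horspool-based kernel into each slot. */
    find_fn backends[] = {find_naive, find_naive, find_naive};
    return backends[(n_len > 1) + (n_len > 8)](h, h_len, n, n_len); /* each comparison adds 0 or 1 */
}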
Instead of adding value equal to the skewed diagonal - // index on either side, we will be cropping those values out. - sz_size_t diagonals_count = n + n - 1; - for (; next_diagonal_index != diagonals_count; ++next_diagonal_index) { - sz_size_t const next_diagonal_length = diagonals_count - next_diagonal_index; - for (sz_size_t i = 0; i != next_diagonal_length; ++i) { - sz_size_t cost_of_substitution = shorter[shorter_length - 1 - i] != longer[next_diagonal_index - n + i]; - sz_size_t cost_if_substitution = previous_distances[i] + cost_of_substitution; - sz_size_t cost_if_deletion_or_insertion = sz_min_of_two(current_distances[i], current_distances[i + 1]) + 1; - next_distances[i] = sz_min_of_two(cost_if_deletion_or_insertion, cost_if_substitution); - } - // Perform a circular rotation of those buffers, to reuse the memory, this time, with a shift, - // dropping the first element in the current array. - sz_size_t *temporary = previous_distances; - previous_distances = current_distances + 1; - current_distances = next_distances; - next_distances = temporary; - } - - // Cache scalar before `free` call. - sz_size_t result = current_distances[0]; - alloc->free(distances, buffer_length, alloc->handle); - return result; -} - -/** - * @brief Describes the length of a UTF8 character / codepoint / rune in bytes. - */ -typedef enum { - sz_utf8_invalid_k = 0, //!< Invalid UTF8 character. - sz_utf8_rune_1byte_k = 1, //!< 1-byte UTF8 character. - sz_utf8_rune_2bytes_k = 2, //!< 2-byte UTF8 character. - sz_utf8_rune_3bytes_k = 3, //!< 3-byte UTF8 character. - sz_utf8_rune_4bytes_k = 4, //!< 4-byte UTF8 character. -} sz_rune_length_t; - -typedef sz_u32_t sz_rune_t; - -/** - * @brief Extracts just one UTF8 codepoint from a UTF8 string into a 32-bit unsigned integer. - */ -SZ_INTERNAL void _sz_extract_utf8_rune(sz_cptr_t utf8, sz_rune_t *code, sz_rune_length_t *code_length) { - sz_u8_t const *current = (sz_u8_t const *)utf8; - sz_u8_t leading_byte = *current++; - sz_rune_t ch; - sz_rune_length_t ch_length; - - // TODO: This can be made entirely branchless using 32-bit SWAR. - if (leading_byte < 0x80) { - // Single-byte rune (0xxxxxxx) - ch = leading_byte; - ch_length = sz_utf8_rune_1byte_k; - } - else if ((leading_byte & 0xE0) == 0xC0) { - // Two-byte rune (110xxxxx 10xxxxxx) - ch = (leading_byte & 0x1F) << 6; - ch |= (*current++ & 0x3F); - ch_length = sz_utf8_rune_2bytes_k; - } - else if ((leading_byte & 0xF0) == 0xE0) { - // Three-byte rune (1110xxxx 10xxxxxx 10xxxxxx) - ch = (leading_byte & 0x0F) << 12; - ch |= (*current++ & 0x3F) << 6; - ch |= (*current++ & 0x3F); - ch_length = sz_utf8_rune_3bytes_k; - } - else if ((leading_byte & 0xF8) == 0xF0) { - // Four-byte rune (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx) - ch = (leading_byte & 0x07) << 18; - ch |= (*current++ & 0x3F) << 12; - ch |= (*current++ & 0x3F) << 6; - ch |= (*current++ & 0x3F); - ch_length = sz_utf8_rune_4bytes_k; - } - else { - // Invalid UTF8 rune. - ch = 0; - ch_length = sz_utf8_invalid_k; - } - *code = ch; - *code_length = ch_length; -} - -/** - * @brief Exports a UTF8 string into a UTF32 buffer. - * ! The result is undefined id the UTF8 string is corrupted. - * @return The length in the number of codepoints. 
- */ -SZ_INTERNAL sz_size_t _sz_export_utf8_to_utf32(sz_cptr_t utf8, sz_size_t utf8_length, sz_rune_t *utf32) { - sz_cptr_t const end = utf8 + utf8_length; - sz_size_t count = 0; - sz_rune_length_t rune_length; - for (; utf8 != end; utf8 += rune_length, utf32++, count++) _sz_extract_utf8_rune(utf8, utf32, &rune_length); - return count; -} - -/** - * @brief Compute the Levenshtein distance between two strings using the Wagner-Fisher algorithm. - * Stores only 2 rows of the Levenshtein matrix, but uses 64-bit integers for the distance values, - * and upcasts UTF8 variable-length codepoints to 64-bit integers for faster addressing. - * - * ! In the worst case for 2 strings of length 100, that contain just one 16-bit codepoint this will result in extra: - * + 2 rows * 100 slots * 8 bytes/slot = 1600 bytes of memory for the two rows of the Levenshtein matrix rows. - * + 100 codepoints * 2 strings * 4 bytes/codepoint = 800 bytes of memory for the UTF8 buffer. - * = 2400 bytes of memory or @b 12x memory amplification! - */ -SZ_INTERNAL sz_size_t _sz_edit_distance_wagner_fisher_serial( // - sz_cptr_t longer, sz_size_t longer_length, // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_size_t bound, sz_bool_t can_be_unicode, sz_memory_allocator_t *alloc) { - - // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. - sz_memory_allocator_t global_alloc; - if (!alloc) { - sz_memory_allocator_init_default(&global_alloc); - alloc = &global_alloc; - } - - // A good idea may be to dispatch different kernels for different string lengths. - // Like using `uint8_t` counters for strings under 255 characters long. - // Good in theory, this results in frequent upcasts and downcasts in serial code. - // On strings over 20 bytes, using `uint8` over `uint64` on 64-bit x86 CPU doubles the execution time. - // So one must be very cautious with such optimizations. - typedef sz_size_t _distance_t; - - // Compute the number of columns in our Levenshtein matrix. - sz_size_t const n = shorter_length + 1; - - // If a buffering memory-allocator is provided, this operation is practically free, - // and cheaper than allocating even 512 bytes (for small distance matrices) on stack. - sz_size_t buffer_length = sizeof(_distance_t) * (n * 2); - - // If the strings contain Unicode characters, let's estimate the max character width, - // and use it to allocate a larger buffer to decode UTF8. - if ((can_be_unicode == sz_true_k) && - (sz_isascii(longer, longer_length) == sz_false_k || sz_isascii(shorter, shorter_length) == sz_false_k)) { - buffer_length += (shorter_length + longer_length) * sizeof(sz_rune_t); - } - else { can_be_unicode = sz_false_k; } - - // If the allocation fails, return the maximum distance. - sz_ptr_t const buffer = (sz_ptr_t)alloc->allocate(buffer_length, alloc->handle); - if (!buffer) return SZ_SIZE_MAX; - - // Let's export the UTF8 sequence into the newly allocated buffer at the end. - if (can_be_unicode == sz_true_k) { - sz_rune_t *const longer_utf32 = (sz_rune_t *)(buffer + sizeof(_distance_t) * (n * 2)); - sz_rune_t *const shorter_utf32 = longer_utf32 + longer_length; - // Export the UTF8 sequences into the newly allocated buffer. 
- longer_length = _sz_export_utf8_to_utf32(longer, longer_length, longer_utf32); - shorter_length = _sz_export_utf8_to_utf32(shorter, shorter_length, shorter_utf32); - longer = (sz_cptr_t)longer_utf32; - shorter = (sz_cptr_t)shorter_utf32; - } - - // Let's parameterize the core logic for different character types and distance types. -#define _wagner_fisher_unbounded(_distance_t, _char_t) \ - /* Now let's cast our pointer to avoid it in subsequent sections. */ \ - _char_t const *const longer_chars = (_char_t const *)longer; \ - _char_t const *const shorter_chars = (_char_t const *)shorter; \ - _distance_t *previous_distances = (_distance_t *)buffer; \ - _distance_t *current_distances = previous_distances + n; \ - /* Initialize the first row of the Levenshtein matrix with `iota`-style arithmetic progression. */ \ - for (_distance_t idx_shorter = 0; idx_shorter != n; ++idx_shorter) previous_distances[idx_shorter] = idx_shorter; \ - /* The main loop of the algorithm with quadratic complexity. */ \ - for (_distance_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) { \ - _char_t const longer_char = longer_chars[idx_longer]; \ - /* Using pure pointer arithmetic is faster than iterating with an index. */ \ - _char_t const *shorter_ptr = shorter_chars; \ - _distance_t const *previous_ptr = previous_distances; \ - _distance_t *current_ptr = current_distances; \ - _distance_t *const current_end = current_ptr + shorter_length; \ - current_ptr[0] = idx_longer + 1; \ - for (; current_ptr != current_end; ++previous_ptr, ++current_ptr, ++shorter_ptr) { \ - _distance_t cost_substitution = previous_ptr[0] + (_distance_t)(longer_char != shorter_ptr[0]); \ - /* We can avoid `+1` for costs here, shifting it to post-minimum computation, */ \ - /* saving one increment operation. */ \ - _distance_t cost_deletion = previous_ptr[1]; \ - _distance_t cost_insertion = current_ptr[0]; \ - /* ? It might be a good idea to enforce branchless execution here. */ \ - /* ? The caveat being that the benchmarks on longer sequences backfire and more research is needed. */ \ - current_ptr[1] = sz_min_of_two(cost_substitution, sz_min_of_two(cost_deletion, cost_insertion) + 1); \ - } \ - /* Swap `previous_distances` and `current_distances` pointers. */ \ - _distance_t *temporary = previous_distances; \ - previous_distances = current_distances; \ - current_distances = temporary; \ - } \ - /* Cache scalar before `free` call. */ \ - sz_size_t result = previous_distances[shorter_length]; \ - alloc->free(buffer, buffer_length, alloc->handle); \ - return result; - - // Let's define a separate variant for bounded distance computation. - // Practically the same as unbounded, but also collecting the running minimum within each row for early exit. 
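/*
 * The macro above is a parameterized two-row Wagner-Fischer. The same recurrence written as a
 * plain function over bytes, for reference; it is a sketch, not the library's kernel, and the
 * name and allocation strategy are illustrative.
 */
#include <stddef.h>
#include <stdlib.h>

static size_t min3(size_t a, size_t b, size_t c) { return a < b ? (a < c ? a : c) : (b < c ? b : c); }

static size_t levenshtein_two_rows(char const *a, size_t a_len, char const *b, size_t b_len) {
    size_t *prev = (size_t *)malloc((b_len + 1) * 2 * sizeof(size_t));
    if (!prev) return (size_t)-1; /* allocation failure sentinel */
    size_t *curr = prev + b_len + 1;
    for (size_t j = 0; j <= b_len; ++j) prev[j] = j; /* distance from an empty prefix */
    for (size_t i = 0; i < a_len; ++i) {
        curr[0] = i + 1;
        for (size_t j = 0; j < b_len; ++j)
            curr[j + 1] = min3(prev[j] + (a[i] != b[j]), /* substitution or match */
                               prev[j + 1] + 1,          /* deletion */
                               curr[j] + 1);             /* insertion */
        size_t *tmp = prev; prev = curr; curr = tmp;
    }
    size_t result = prev[b_len];
    free(prev < curr ? prev : curr); /* the smaller pointer is the original allocation */
    return result;
}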
-#define _wagner_fisher_bounded(_distance_t, _char_t) \ - _char_t const *const longer_chars = (_char_t const *)longer; \ - _char_t const *const shorter_chars = (_char_t const *)shorter; \ - _distance_t *previous_distances = (_distance_t *)buffer; \ - _distance_t *current_distances = previous_distances + n; \ - for (_distance_t idx_shorter = 0; idx_shorter != n; ++idx_shorter) previous_distances[idx_shorter] = idx_shorter; \ - for (_distance_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) { \ - _char_t const longer_char = longer_chars[idx_longer]; \ - _char_t const *shorter_ptr = shorter_chars; \ - _distance_t const *previous_ptr = previous_distances; \ - _distance_t *current_ptr = current_distances; \ - _distance_t *const current_end = current_ptr + shorter_length; \ - current_ptr[0] = idx_longer + 1; \ - /* Initialize min_distance with a value greater than bound */ \ - _distance_t min_distance = bound - 1; \ - for (; current_ptr != current_end; ++previous_ptr, ++current_ptr, ++shorter_ptr) { \ - _distance_t cost_substitution = previous_ptr[0] + (_distance_t)(longer_char != shorter_ptr[0]); \ - _distance_t cost_deletion = previous_ptr[1]; \ - _distance_t cost_insertion = current_ptr[0]; \ - current_ptr[1] = sz_min_of_two(cost_substitution, sz_min_of_two(cost_deletion, cost_insertion) + 1); \ - /* Keep track of the minimum distance seen so far in this row */ \ - min_distance = sz_min_of_two(current_ptr[1], min_distance); \ - } \ - /* If the minimum distance in this row exceeded the bound, return early */ \ - if (min_distance >= bound) { \ - alloc->free(buffer, buffer_length, alloc->handle); \ - return bound; \ - } \ - _distance_t *temporary = previous_distances; \ - previous_distances = current_distances; \ - current_distances = temporary; \ - } \ - sz_size_t result = previous_distances[shorter_length]; \ - alloc->free(buffer, buffer_length, alloc->handle); \ - return sz_min_of_two(result, bound); - - // Dispatch the actual computation. - if (!bound) { - if (can_be_unicode == sz_true_k) { _wagner_fisher_unbounded(sz_size_t, sz_rune_t); } - else { _wagner_fisher_unbounded(sz_size_t, sz_u8_t); } - } - else { - if (can_be_unicode == sz_true_k) { _wagner_fisher_bounded(sz_size_t, sz_rune_t); } - else { _wagner_fisher_bounded(sz_size_t, sz_u8_t); } - } -} - -SZ_PUBLIC sz_size_t sz_edit_distance_serial( // - sz_cptr_t longer, sz_size_t longer_length, // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - - // Let's make sure that we use the amount proportional to the - // number of elements in the shorter string, not the larger. - if (shorter_length > longer_length) { - sz_pointer_swap((void **)&longer_length, (void **)&shorter_length); - sz_pointer_swap((void **)&longer, (void **)&shorter); - } - - // Skip the matching prefixes and suffixes, they won't affect the distance. - for (sz_cptr_t a_end = longer + longer_length, b_end = shorter + shorter_length; - longer != a_end && shorter != b_end && *longer == *shorter; - ++longer, ++shorter, --longer_length, --shorter_length); - for (; longer_length && shorter_length && longer[longer_length - 1] == shorter[shorter_length - 1]; - --longer_length, --shorter_length); - - // Bounded computations may exit early. - int const is_bounded = bound < longer_length; - if (is_bounded) { - // If one of the strings is empty - the edit distance is equal to the length of the other one. 
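/*
 * The function above first strips the shared prefix and suffix, since identical ends never add
 * edits, shrinking the quadratic part of the work. A standalone sketch of that trimming step;
 * names are illustrative. For example, "unhappy" vs "unhappily" reduces to "" vs "il".
 */
#include <stddef.h>

static void trim_common_affixes(char const **a, size_t *a_len, char const **b, size_t *b_len) {
    while (*a_len && *b_len && (*a)[0] == (*b)[0]) { ++*a, ++*b, --*a_len, --*b_len; }
    while (*a_len && *b_len && (*a)[*a_len - 1] == (*b)[*b_len - 1]) { --*a_len, --*b_len; }
}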
- if (longer_length == 0) return sz_min_of_two(shorter_length, bound); - if (shorter_length == 0) return sz_min_of_two(longer_length, bound); - // If the difference in length is beyond the `bound`, there is no need to check at all. - if (longer_length - shorter_length > bound) return bound; - } - - if (shorter_length == 0) return longer_length; // If no mismatches were found - the distance is zero. - if (shorter_length == longer_length && !is_bounded) - return _sz_edit_distance_skewed_diagonals_serial(longer, longer_length, shorter, shorter_length, bound, alloc); - return _sz_edit_distance_wagner_fisher_serial(longer, longer_length, shorter, shorter_length, bound, sz_false_k, - alloc); -} - -SZ_PUBLIC sz_ssize_t sz_alignment_score_serial( // - sz_cptr_t longer, sz_size_t longer_length, // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, // - sz_memory_allocator_t *alloc) { - - // If one of the strings is empty - the edit distance is equal to the length of the other one - if (longer_length == 0) return (sz_ssize_t)shorter_length * gap; - if (shorter_length == 0) return (sz_ssize_t)longer_length * gap; - - // Let's make sure that we use the amount proportional to the - // number of elements in the shorter string, not the larger. - if (shorter_length > longer_length) { - sz_pointer_swap((void **)&longer_length, (void **)&shorter_length); - sz_pointer_swap((void **)&longer, (void **)&shorter); - } - - // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. - sz_memory_allocator_t global_alloc; - if (!alloc) { - sz_memory_allocator_init_default(&global_alloc); - alloc = &global_alloc; - } - - sz_size_t n = shorter_length + 1; - sz_size_t buffer_length = sizeof(sz_ssize_t) * n * 2; - sz_ssize_t *distances = (sz_ssize_t *)alloc->allocate(buffer_length, alloc->handle); - sz_ssize_t *previous_distances = distances; - sz_ssize_t *current_distances = previous_distances + n; - - for (sz_size_t idx_shorter = 0; idx_shorter != n; ++idx_shorter) - previous_distances[idx_shorter] = (sz_ssize_t)idx_shorter * gap; - - sz_u8_t const *shorter_unsigned = (sz_u8_t const *)shorter; - sz_u8_t const *longer_unsigned = (sz_u8_t const *)longer; - for (sz_size_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) { - current_distances[0] = ((sz_ssize_t)idx_longer + 1) * gap; - - // Initialize min_distance with a value greater than bound - sz_error_cost_t const *a_subs = subs + longer_unsigned[idx_longer] * 256ul; - for (sz_size_t idx_shorter = 0; idx_shorter != shorter_length; ++idx_shorter) { - sz_ssize_t cost_deletion = previous_distances[idx_shorter + 1] + gap; - sz_ssize_t cost_insertion = current_distances[idx_shorter] + gap; - sz_ssize_t cost_substitution = previous_distances[idx_shorter] + a_subs[shorter_unsigned[idx_shorter]]; - current_distances[idx_shorter + 1] = sz_max_of_three(cost_deletion, cost_insertion, cost_substitution); - } - - // Swap previous_distances and current_distances pointers - sz_pointer_swap((void **)&previous_distances, (void **)¤t_distances); - } - - // Cache scalar before `free` call. 
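/*
 * The alignment-scoring routine here is a two-row Needleman-Wunsch: each cell takes the maximum
 * of substitution, deletion, and insertion, with a 256x256 substitution matrix and a flat gap cost.
 * A compact standalone sketch follows; the signature and the error sentinel are illustrative.
 */
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>

static ptrdiff_t max3(ptrdiff_t a, ptrdiff_t b, ptrdiff_t c) {
    ptrdiff_t m = a > b ? a : b;
    return m > c ? m : c;
}

static ptrdiff_t align_score(unsigned char const *a, size_t a_len, unsigned char const *b, size_t b_len,
                             signed char const subs[256][256], ptrdiff_t gap) {
    ptrdiff_t *prev = (ptrdiff_t *)malloc((b_len + 1) * 2 * sizeof(ptrdiff_t));
    if (!prev) return PTRDIFF_MIN; /* allocation failure sentinel */
    ptrdiff_t *curr = prev + b_len + 1;
    for (size_t j = 0; j <= b_len; ++j) prev[j] = (ptrdiff_t)j * gap; /* leading gaps */
    for (size_t i = 0; i < a_len; ++i) {
        curr[0] = (ptrdiff_t)(i + 1) * gap;
        for (size_t j = 0; j < b_len; ++j)
            curr[j + 1] = max3(prev[j] + subs[a[i]][b[j]], /* substitution or match */
                               prev[j + 1] + gap,          /* deletion */
                               curr[j] + gap);             /* insertion */
        ptrdiff_t *tmp = prev; prev = curr; curr = tmp;
    }
    ptrdiff_t score = prev[b_len];
    free(prev < curr ? prev : curr);
    return score;
}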
- sz_ssize_t result = previous_distances[shorter_length]; - alloc->free(distances, buffer_length, alloc->handle); - return result; -} - -SZ_PUBLIC sz_size_t sz_hamming_distance_serial( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - - sz_size_t const min_length = sz_min_of_two(a_length, b_length); - sz_size_t const max_length = sz_max_of_two(a_length, b_length); - sz_cptr_t const a_end = a + min_length; - bound = bound == 0 ? max_length : bound; - - // Walk through both strings using SWAR and counting the number of differing characters. - sz_size_t distance = max_length - min_length; -#if SZ_USE_MISALIGNED_LOADS && !SZ_DETECT_BIG_ENDIAN - if (min_length >= SZ_SWAR_THRESHOLD) { - sz_u64_vec_t a_vec, b_vec, match_vec; - for (; a + 8 <= a_end && distance < bound; a += 8, b += 8) { - a_vec.u64 = sz_u64_load(a).u64; - b_vec.u64 = sz_u64_load(b).u64; - match_vec = _sz_u64_each_byte_equal(a_vec, b_vec); - distance += sz_u64_popcount((~match_vec.u64) & 0x8080808080808080ull); - } - } -#endif - - for (; a != a_end && distance < bound; ++a, ++b) { distance += (*a != *b); } - return sz_min_of_two(distance, bound); -} - -SZ_PUBLIC sz_size_t sz_hamming_distance_utf8_serial( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - - sz_cptr_t const a_end = a + a_length; - sz_cptr_t const b_end = b + b_length; - sz_size_t distance = 0; - - sz_rune_t a_rune, b_rune; - sz_rune_length_t a_rune_length, b_rune_length; - - if (bound) { - for (; a < a_end && b < b_end && distance < bound; a += a_rune_length, b += b_rune_length) { - _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - distance += (a_rune != b_rune); - } - // If one string has more runes, we need to go through the tail. - if (distance < bound) { - for (; a < a_end && distance < bound; a += a_rune_length, ++distance) - _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - - for (; b < b_end && distance < bound; b += b_rune_length, ++distance) - _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - } - } - else { - for (; a < a_end && b < b_end; a += a_rune_length, b += b_rune_length) { - _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - distance += (a_rune != b_rune); - } - // If one string has more runes, we need to go through the tail. - for (; a < a_end; a += a_rune_length, ++distance) _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - for (; b < b_end; b += b_rune_length, ++distance) _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - } - return distance; -} - -SZ_PUBLIC sz_u64_t sz_checksum_serial(sz_cptr_t text, sz_size_t length) { - sz_u64_t checksum = 0; - sz_u8_t const *text_u8 = (sz_u8_t const *)text; - sz_u8_t const *text_end = text_u8 + length; - for (; text_u8 != text_end; ++text_u8) checksum += *text_u8; - return checksum; -} - -/** - * @brief Largest prime number that fits into 31 bits. - * @see https://mersenneforum.org/showthread.php?t=3471 - */ -#define SZ_U32_MAX_PRIME (2147483647u) - -/** - * @brief Largest prime number that fits into 64 bits. - * @see https://mersenneforum.org/showthread.php?t=3471 - * - * 2^64 = 18,446,744,073,709,551,616 - * this = 18,446,744,073,709,551,557 - * diff = 59 - */ -#define SZ_U64_MAX_PRIME (18446744073709551557ull) - -/* - * One hardware-accelerated way of mixing hashes can be CRC, but it's only implemented for 32-bit values. 
- * Using a Boost-like mixer works very poorly in such case: - * - * hash_first ^ (hash_second + 0x517cc1b727220a95 + (hash_first << 6) + (hash_first >> 2)); - * - * Let's stick to the Fibonacci hash trick using the golden ratio. - * https://probablydance.com/2018/06/16/fibonacci-hashing-the-optimization-that-the-world-forgot-or-a-better-alternative-to-integer-modulo/ - */ -#define _sz_hash_mix(first, second) ((first * 11400714819323198485ull) ^ (second * 11400714819323198485ull)) -#define _sz_shift_low(x) (x) -#define _sz_shift_high(x) ((x + 77ull) & 0xFFull) -#define _sz_prime_mod(x) (x % SZ_U64_MAX_PRIME) - -SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t start, sz_size_t length) { - - sz_u64_t hash_low = 0; - sz_u64_t hash_high = 0; - sz_u8_t const *text = (sz_u8_t const *)start; - sz_u8_t const *text_end = text + length; - - switch (length) { - case 0: return 0; - - // Texts under 7 bytes long are definitely below the largest prime. - case 1: - hash_low = _sz_shift_low(text[0]); - hash_high = _sz_shift_high(text[0]); - break; - case 2: - hash_low = _sz_shift_low(text[0]) * 31ull + _sz_shift_low(text[1]); - hash_high = _sz_shift_high(text[0]) * 257ull + _sz_shift_high(text[1]); - break; - case 3: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull + // - _sz_shift_low(text[2]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull + // - _sz_shift_high(text[2]); - break; - case 4: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull + // - _sz_shift_low(text[3]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull + // - _sz_shift_high(text[3]); - break; - case 5: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull * 31ull + // - _sz_shift_low(text[3]) * 31ull + // - _sz_shift_low(text[4]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull * 257ull + // - _sz_shift_high(text[3]) * 257ull + // - _sz_shift_high(text[4]); - break; - case 6: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[3]) * 31ull * 31ull + // - _sz_shift_low(text[4]) * 31ull + // - _sz_shift_low(text[5]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[3]) * 257ull * 257ull + // - _sz_shift_high(text[4]) * 257ull + // - _sz_shift_high(text[5]); - break; - case 7: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[3]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[4]) * 31ull * 31ull + // - _sz_shift_low(text[5]) * 31ull + // - _sz_shift_low(text[6]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull 
* 257ull * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[3]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[4]) * 257ull * 257ull + // - _sz_shift_high(text[5]) * 257ull + // - _sz_shift_high(text[6]); - break; - default: - // Unroll the first seven cycles: - hash_low = hash_low * 31ull + _sz_shift_low(text[0]); - hash_high = hash_high * 257ull + _sz_shift_high(text[0]); - hash_low = hash_low * 31ull + _sz_shift_low(text[1]); - hash_high = hash_high * 257ull + _sz_shift_high(text[1]); - hash_low = hash_low * 31ull + _sz_shift_low(text[2]); - hash_high = hash_high * 257ull + _sz_shift_high(text[2]); - hash_low = hash_low * 31ull + _sz_shift_low(text[3]); - hash_high = hash_high * 257ull + _sz_shift_high(text[3]); - hash_low = hash_low * 31ull + _sz_shift_low(text[4]); - hash_high = hash_high * 257ull + _sz_shift_high(text[4]); - hash_low = hash_low * 31ull + _sz_shift_low(text[5]); - hash_high = hash_high * 257ull + _sz_shift_high(text[5]); - hash_low = hash_low * 31ull + _sz_shift_low(text[6]); - hash_high = hash_high * 257ull + _sz_shift_high(text[6]); - text += 7; - - // Iterate throw the rest with the modulus: - for (; text != text_end; ++text) { - hash_low = hash_low * 31ull + _sz_shift_low(text[0]); - hash_high = hash_high * 257ull + _sz_shift_high(text[0]); - // Wrap the hashes around: - hash_low = _sz_prime_mod(hash_low); - hash_high = _sz_prime_mod(hash_high); - } - break; - } - - return _sz_hash_mix(hash_low, hash_high); -} - -SZ_PUBLIC void sz_hashes_serial(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle) { - - if (length < window_length || !window_length) return; - sz_u8_t const *text = (sz_u8_t const *)start; - sz_u8_t const *text_end = text + length; - - // Prepare the `prime ^ window_length` values, that we are going to use for modulo arithmetic. - sz_u64_t prime_power_low = 1, prime_power_high = 1; - for (sz_size_t i = 0; i + 1 < window_length; ++i) - prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME, - prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME; - - // Compute the initial hash value for the first window. - sz_u64_t hash_low = 0, hash_high = 0, hash_mix; - for (sz_u8_t const *first_end = text + window_length; text < first_end; ++text) - hash_low = (hash_low * 31ull + _sz_shift_low(*text)) % SZ_U64_MAX_PRIME, - hash_high = (hash_high * 257ull + _sz_shift_high(*text)) % SZ_U64_MAX_PRIME; - - // In most cases the fingerprint length will be a power of two. - hash_mix = _sz_hash_mix(hash_low, hash_high); - callback((sz_cptr_t)text, window_length, hash_mix, callback_handle); - - // Compute the hash value for every window, exporting into the fingerprint, - // using the expensive modulo operation. - sz_size_t cycles = 1; - sz_size_t const step_mask = step - 1; - for (; text < text_end; ++text, ++cycles) { - // Discard one character: - hash_low -= _sz_shift_low(*(text - window_length)) * prime_power_low; - hash_high -= _sz_shift_high(*(text - window_length)) * prime_power_high; - // And add a new one: - hash_low = 31ull * hash_low + _sz_shift_low(*text); - hash_high = 257ull * hash_high + _sz_shift_high(*text); - // Wrap the hashes around: - hash_low = _sz_prime_mod(hash_low); - hash_high = _sz_prime_mod(hash_high); - // Mix only if we've skipped enough hashes. 
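/*
 * The rolling variant above removes the outgoing byte (scaled by a precomputed power of the base),
 * folds in the incoming byte, and reduces by a prime. A minimal Rabin-Karp-style sketch with one
 * base and plain 2^64 wrap-around instead of the prime modulus, to keep the bookkeeping visible;
 * names and the callback signature are illustrative.
 */
#include <stddef.h>
#include <stdint.h>

static void rolling_hashes(unsigned char const *text, size_t length, size_t window,
                           void (*callback)(size_t offset, uint64_t hash, void *handle), void *handle) {
    if (!window || length < window) return;
    uint64_t const base = 257;
    uint64_t top_power = 1; /* base^(window - 1), used to drop the outgoing byte */
    for (size_t i = 0; i + 1 < window; ++i) top_power *= base;
    uint64_t hash = 0;
    for (size_t i = 0; i < window; ++i) hash = hash * base + text[i];
    callback(0, hash, handle);
    for (size_t i = window; i < length; ++i) {
        hash = (hash - text[i - window] * top_power) * base + text[i]; /* slide by one byte */
        callback(i - window + 1, hash, handle);
    }
}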
- if ((cycles & step_mask) == 0) { - hash_mix = _sz_hash_mix(hash_low, hash_high); - callback((sz_cptr_t)text, window_length, hash_mix, callback_handle); - } - } -} - -#undef _sz_shift_low -#undef _sz_shift_high -#undef _sz_hash_mix -#undef _sz_prime_mod - -/** - * @brief Uses a small lookup-table to convert a lowercase character to uppercase. - */ -SZ_INTERNAL sz_u8_t sz_u8_tolower(sz_u8_t c) { - static sz_u8_t const lowered[256] = { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, // - 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, // - 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, // - 64, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, // - 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 91, 92, 93, 94, 95, // - 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, // - 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, // - 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, // - 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, // - 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, // - 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, // - 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // - 240, 241, 242, 243, 244, 245, 246, 215, 248, 249, 250, 251, 252, 253, 254, 223, // - 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // - 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, // - }; - return lowered[c]; -} - -/** - * @brief Uses a small lookup-table to convert an uppercase character to lowercase. - */ -SZ_INTERNAL sz_u8_t sz_u8_toupper(sz_u8_t c) { - static sz_u8_t const upped[256] = { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, // - 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, // - 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, // - 64, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, // - 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 91, 92, 93, 94, 95, // - 96, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, // - 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 123, 124, 125, 126, 127, // - 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, // - 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, // - 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, // - 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, // - 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // - 240, 241, 242, 243, 244, 245, 246, 215, 248, 249, 250, 251, 252, 253, 254, 223, // - 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // - 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, // - }; - return upped[c]; -} - -/** - * @brief Uses two small lookup tables (768 bytes total) to accelerate division by a small - * unsigned integer. Performs two lookups, one multiplication, two shifts, and two accumulations. - * - * @param divisor Integral value @b larger than one. - * @param number Integral value to divide. 
- */ -SZ_INTERNAL sz_u8_t sz_u8_divide(sz_u8_t number, sz_u8_t divisor) { - sz_assert(divisor > 1); - static sz_u16_t const multipliers[256] = { - 0, 0, 0, 21846, 0, 39322, 21846, 9363, 0, 50973, 39322, 29790, 21846, 15124, 9363, 4370, - 0, 57826, 50973, 44841, 39322, 34329, 29790, 25645, 21846, 18351, 15124, 12137, 9363, 6780, 4370, 2115, - 0, 61565, 57826, 54302, 50973, 47824, 44841, 42011, 39322, 36765, 34329, 32006, 29790, 27671, 25645, 23705, - 21846, 20063, 18351, 16706, 15124, 13602, 12137, 10725, 9363, 8049, 6780, 5554, 4370, 3224, 2115, 1041, - 0, 63520, 61565, 59668, 57826, 56039, 54302, 52614, 50973, 49377, 47824, 46313, 44841, 43407, 42011, 40649, - 39322, 38028, 36765, 35532, 34329, 33154, 32006, 30885, 29790, 28719, 27671, 26647, 25645, 24665, 23705, 22766, - 21846, 20945, 20063, 19198, 18351, 17520, 16706, 15907, 15124, 14356, 13602, 12863, 12137, 11424, 10725, 10038, - 9363, 8700, 8049, 7409, 6780, 6162, 5554, 4957, 4370, 3792, 3224, 2665, 2115, 1573, 1041, 517, - 0, 64520, 63520, 62535, 61565, 60609, 59668, 58740, 57826, 56926, 56039, 55164, 54302, 53452, 52614, 51788, - 50973, 50169, 49377, 48595, 47824, 47063, 46313, 45572, 44841, 44120, 43407, 42705, 42011, 41326, 40649, 39982, - 39322, 38671, 38028, 37392, 36765, 36145, 35532, 34927, 34329, 33738, 33154, 32577, 32006, 31443, 30885, 30334, - 29790, 29251, 28719, 28192, 27671, 27156, 26647, 26143, 25645, 25152, 24665, 24182, 23705, 23233, 22766, 22303, - 21846, 21393, 20945, 20502, 20063, 19628, 19198, 18772, 18351, 17933, 17520, 17111, 16706, 16305, 15907, 15514, - 15124, 14738, 14356, 13977, 13602, 13231, 12863, 12498, 12137, 11779, 11424, 11073, 10725, 10380, 10038, 9699, - 9363, 9030, 8700, 8373, 8049, 7727, 7409, 7093, 6780, 6470, 6162, 5857, 5554, 5254, 4957, 4662, - 4370, 4080, 3792, 3507, 3224, 2943, 2665, 2388, 2115, 1843, 1573, 1306, 1041, 778, 517, 258, - }; - // This table can be avoided using a single addition and counting trailing zeros. 
- static sz_u8_t const shifts[256] = { - 0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, // - 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, // - 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, // - 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, // - 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // - 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // - 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // - 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // - }; - sz_u32_t multiplier = multipliers[divisor]; - sz_u8_t shift = shifts[divisor]; - - sz_u16_t q = (sz_u16_t)((multiplier * number) >> 16); - sz_u16_t t = ((number - q) >> 1) + q; - return (sz_u8_t)(t >> shift); -} - -SZ_PUBLIC void sz_look_up_transform_serial(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result) { - sz_u8_t const *unsigned_lut = (sz_u8_t const *)lut; - sz_u8_t const *unsigned_text = (sz_u8_t const *)text; - sz_u8_t *unsigned_result = (sz_u8_t *)result; - sz_u8_t const *end = unsigned_text + length; - for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = unsigned_lut[*unsigned_text]; -} - -SZ_PUBLIC void sz_tolower_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) { - sz_u8_t *unsigned_result = (sz_u8_t *)result; - sz_u8_t const *unsigned_text = (sz_u8_t const *)text; - sz_u8_t const *end = unsigned_text + length; - for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = sz_u8_tolower(*unsigned_text); -} - -SZ_PUBLIC void sz_toupper_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) { - sz_u8_t *unsigned_result = (sz_u8_t *)result; - sz_u8_t const *unsigned_text = (sz_u8_t const *)text; - sz_u8_t const *end = unsigned_text + length; - for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = sz_u8_toupper(*unsigned_text); -} - -SZ_PUBLIC void sz_toascii_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) { - sz_u8_t *unsigned_result = (sz_u8_t *)result; - sz_u8_t const *unsigned_text = (sz_u8_t const *)text; - sz_u8_t const *end = unsigned_text + length; - for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = *unsigned_text & 0x7F; -} - -/** - * @brief Check if there is a byte in this buffer, that exceeds 127 and can't be an ASCII character. - * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. - */ -SZ_PUBLIC sz_bool_t sz_isascii_serial(sz_cptr_t text, sz_size_t length) { - - if (!length) return sz_true_k; - sz_u8_t const *h = (sz_u8_t const *)text; - sz_u8_t const *const h_end = h + length; - -#if !SZ_USE_MISALIGNED_LOADS - // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)h & 7ull) && h < h_end; ++h) - if (*h & 0x80ull) return sz_false_k; -#endif - - // Validate eight bytes at once using SWAR. - sz_u64_vec_t text_vec; - for (; h + 8 <= h_end; h += 8) { - text_vec.u64 = *(sz_u64_t const *)h; - if (text_vec.u64 & 0x8080808080808080ull) return sz_false_k; - } - - // Handle the misaligned tail. 
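/*
 * The ASCII check above tests eight bytes per iteration against the 0x80 mask. The same idea as a
 * standalone sketch with a `memcpy` load, which sidesteps alignment concerns; illustrative names.
 */
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

static bool is_ascii(char const *text, size_t length) {
    size_t i = 0;
    for (; i + 8 <= length; i += 8) {
        uint64_t word;
        memcpy(&word, text + i, 8); /* misalignment-safe 64-bit load */
        if (word & 0x8080808080808080ull) return false;
    }
    for (; i < length; ++i)
        if ((unsigned char)text[i] & 0x80u) return false;
    return true;
}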
- for (; h < h_end; ++h) - if (*h & 0x80ull) return sz_false_k; - return sz_true_k; -} - -SZ_PUBLIC void sz_generate_serial(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length, - sz_random_generator_t generator, void *generator_user_data) { - - sz_assert(alphabet_size > 0 && alphabet_size <= 256 && "Inadequate alphabet size"); - - if (alphabet_size == 1) sz_fill(result, result_length, *alphabet); - - else { - sz_assert(generator && "Expects a valid random generator"); - sz_u8_t divisor = (sz_u8_t)alphabet_size; - for (sz_cptr_t end = result + result_length; result != end; ++result) { - sz_u8_t random = generator(generator_user_data) & 0xFF; - sz_u8_t quotient = sz_u8_divide(random, divisor); - *result = alphabet[random - quotient * divisor]; - } - } -} - -#pragma endregion - -/* - * Serial implementation of string class operations. - */ -#pragma region Serial Implementation for the String Class - -SZ_PUBLIC sz_bool_t sz_string_is_on_stack(sz_string_t const *string) { - // It doesn't matter if it's on stack or heap, the pointer location is the same. - return (sz_bool_t)((sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0]); -} - -SZ_PUBLIC void sz_string_range(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length) { - sz_size_t is_small = (sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0]; - sz_size_t is_big_mask = is_small - 1ull; - *start = string->external.start; // It doesn't matter if it's on stack or heap, the pointer location is the same. - // If the string is small, use branch-less approach to mask-out the top 7 bytes of the length. - *length = string->external.length & (0x00000000000000FFull | is_big_mask); -} - -SZ_PUBLIC void sz_string_unpack(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length, sz_size_t *space, - sz_bool_t *is_external) { - sz_size_t is_small = (sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0]; - sz_size_t is_big_mask = is_small - 1ull; - *start = string->external.start; // It doesn't matter if it's on stack or heap, the pointer location is the same. - // If the string is small, use branch-less approach to mask-out the top 7 bytes of the length. - *length = string->external.length & (0x00000000000000FFull | is_big_mask); - // In case the string is small, the `is_small - 1ull` will become 0xFFFFFFFFFFFFFFFFull. - *space = sz_u64_blend(SZ_STRING_INTERNAL_SPACE, string->external.space, is_big_mask); - *is_external = (sz_bool_t)!is_small; -} - -SZ_PUBLIC sz_bool_t sz_string_equal(sz_string_t const *a, sz_string_t const *b) { - // Tempting to say that the external.length is bitwise the same even if it includes - // some bytes of the on-stack payload, but we don't at this writing maintain that invariant. - // (An on-stack string includes noise bytes in the high-order bits of external.length. So do this - // the hard/correct way. - -#if SZ_USE_MISALIGNED_LOADS - // Dealing with StringZilla strings, we know that the `start` pointer always points - // to a word at least 8 bytes long. Therefore, we can compare the first 8 bytes at once. - -#endif - // Alternatively, fall back to byte-by-byte comparison. 
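/*
 * The accessors above read the length branchlessly: `is_small - 1` becomes an all-ones mask for a
 * heap-allocated string and zero for an on-stack one, which then keeps either the whole length word
 * or just its low byte. A sketch of that mask-select idiom; names are illustrative.
 */
#include <stdint.h>

/* Returns `if_one` when `condition` is 1 and `if_zero` when it is 0, without branching. */
static uint64_t select_u64(uint64_t condition, uint64_t if_one, uint64_t if_zero) {
    uint64_t mask = condition - 1; /* 1 -> all zeros, 0 -> all ones */
    return (if_one & ~mask) | (if_zero & mask);
}

/* The same trick on a small-string-optimized length: stack strings keep the length in one byte. */
static uint64_t unpack_length(uint64_t raw_length_word, uint64_t is_small) {
    uint64_t is_big_mask = is_small - 1; /* all ones when the string lives on the heap */
    return raw_length_word & (0xFFull | is_big_mask);
}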
- sz_ptr_t a_start, b_start; - sz_size_t a_length, b_length; - sz_string_range(a, &a_start, &a_length); - sz_string_range(b, &b_start, &b_length); - return (sz_bool_t)(a_length == b_length && sz_equal(a_start, b_start, b_length)); -} - -SZ_PUBLIC sz_ordering_t sz_string_order(sz_string_t const *a, sz_string_t const *b) { -#if SZ_USE_MISALIGNED_LOADS - // Dealing with StringZilla strings, we know that the `start` pointer always points - // to a word at least 8 bytes long. Therefore, we can compare the first 8 bytes at once. - -#endif - // Alternatively, fall back to byte-by-byte comparison. - sz_ptr_t a_start, b_start; - sz_size_t a_length, b_length; - sz_string_range(a, &a_start, &a_length); - sz_string_range(b, &b_start, &b_length); - return sz_order(a_start, a_length, b_start, b_length); -} - -SZ_PUBLIC void sz_string_init(sz_string_t *string) { - sz_assert(string && "String can't be SZ_NULL."); - - // Only 8 + 1 + 1 need to be initialized. - string->internal.start = &string->internal.chars[0]; - // But for safety let's initialize the entire structure to zeros. - // string->internal.chars[0] = 0; - // string->internal.length = 0; - string->words[1] = 0; - string->words[2] = 0; - string->words[3] = 0; -} - -SZ_PUBLIC sz_ptr_t sz_string_init_length(sz_string_t *string, sz_size_t length, sz_memory_allocator_t *allocator) { - sz_size_t space_needed = length + 1; // space for trailing \0 - sz_assert(string && allocator && "String and allocator can't be SZ_NULL."); - // Initialize the string to zeros for safety. - string->words[1] = 0; - string->words[2] = 0; - string->words[3] = 0; - // If we are lucky, no memory allocations will be needed. - if (space_needed <= SZ_STRING_INTERNAL_SPACE) { - string->internal.start = &string->internal.chars[0]; - string->internal.length = (sz_u8_t)length; - } - else { - // If we are not lucky, we need to allocate memory. - string->external.start = (sz_ptr_t)allocator->allocate(space_needed, allocator->handle); - if (!string->external.start) return SZ_NULL_CHAR; - string->external.length = length; - string->external.space = space_needed; - } - sz_assert(&string->internal.start == &string->external.start && "Alignment confusion"); - string->external.start[length] = 0; - return string->external.start; -} - -SZ_PUBLIC sz_ptr_t sz_string_reserve(sz_string_t *string, sz_size_t new_capacity, sz_memory_allocator_t *allocator) { - - sz_assert(string && allocator && "Strings and allocators can't be SZ_NULL."); - - sz_size_t new_space = new_capacity + 1; - if (new_space <= SZ_STRING_INTERNAL_SPACE) return string->external.start; - - sz_ptr_t string_start; - sz_size_t string_length; - sz_size_t string_space; - sz_bool_t string_is_external; - sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - sz_assert(new_space > string_space && "New space must be larger than current."); - - sz_ptr_t new_start = (sz_ptr_t)allocator->allocate(new_space, allocator->handle); - if (!new_start) return SZ_NULL_CHAR; - - sz_copy(new_start, string_start, string_length); - string->external.start = new_start; - string->external.space = new_space; - string->external.padding = 0; - string->external.length = string_length; - - // Deallocate the old string. 
- if (string_is_external) allocator->free(string_start, string_space, allocator->handle); - return string->external.start; -} - -SZ_PUBLIC sz_ptr_t sz_string_shrink_to_fit(sz_string_t *string, sz_memory_allocator_t *allocator) { - - sz_assert(string && allocator && "Strings and allocators can't be SZ_NULL."); - - sz_ptr_t string_start; - sz_size_t string_length; - sz_size_t string_space; - sz_bool_t string_is_external; - sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - - // We may already be space-optimal, and in that case we don't need to do anything. - sz_size_t new_space = string_length + 1; - if (string_space == new_space || !string_is_external) return string->external.start; - - sz_ptr_t new_start = (sz_ptr_t)allocator->allocate(new_space, allocator->handle); - if (!new_start) return SZ_NULL_CHAR; - - sz_copy(new_start, string_start, string_length); - string->external.start = new_start; - string->external.space = new_space; - string->external.padding = 0; - string->external.length = string_length; - - // Deallocate the old string. - if (string_is_external) allocator->free(string_start, string_space, allocator->handle); - return string->external.start; -} - -SZ_PUBLIC sz_ptr_t sz_string_expand(sz_string_t *string, sz_size_t offset, sz_size_t added_length, - sz_memory_allocator_t *allocator) { - - sz_assert(string && allocator && "String and allocator can't be SZ_NULL."); - - sz_ptr_t string_start; - sz_size_t string_length; - sz_size_t string_space; - sz_bool_t string_is_external; - sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - - // The user intended to extend the string. - offset = sz_min_of_two(offset, string_length); - - // If we are lucky, no memory allocations will be needed. - if (string_length + added_length < string_space) { - sz_move(string_start + offset + added_length, string_start + offset, string_length - offset); - string_start[string_length + added_length] = 0; - // Even if the string is on the stack, the `+=` won't affect the tail of the string. - string->external.length += added_length; - } - // If we are not lucky, we need to allocate more memory. - else { - sz_size_t next_planned_size = sz_max_of_two(SZ_CACHE_LINE_WIDTH, string_space * 2ull); - sz_size_t min_needed_space = sz_size_bit_ceil(offset + string_length + added_length + 1); - sz_size_t new_space = sz_max_of_two(min_needed_space, next_planned_size); - string_start = sz_string_reserve(string, new_space - 1, allocator); - if (!string_start) return SZ_NULL_CHAR; - - // Copy into the new buffer. - sz_move(string_start + offset + added_length, string_start + offset, string_length - offset); - string_start[string_length + added_length] = 0; - string->external.length = string_length + added_length; - } - - return string_start; -} - -SZ_PUBLIC sz_size_t sz_string_erase(sz_string_t *string, sz_size_t offset, sz_size_t length) { - - sz_assert(string && "String can't be SZ_NULL."); - - sz_ptr_t string_start; - sz_size_t string_length; - sz_size_t string_space; - sz_bool_t string_is_external; - sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - - // Normalize the offset, it can't be larger than the length. - offset = sz_min_of_two(offset, string_length); - - // We shouldn't normalize the length, to avoid overflowing on `offset + length >= string_length`, - // if receiving `length == SZ_SIZE_MAX`. 
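/*
 * The growth policy above takes the larger of "double the current space" and one cache line, then
 * makes sure the result also covers the required size rounded up to a power of two. A standalone
 * sketch of that policy; the cache-line constant and names are illustrative.
 */
#include <stddef.h>

static size_t bit_ceil_size(size_t x) { /* smallest power of two >= x, for non-zero x */
    size_t p = 1;
    while (p < x) p <<= 1;
    return p;
}

static size_t next_capacity(size_t current_space, size_t needed_space) {
    size_t const cache_line = 64;
    size_t planned = current_space * 2 > cache_line ? current_space * 2 : cache_line;
    size_t minimal = bit_ceil_size(needed_space);
    return planned > minimal ? planned : minimal;
}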
After the following expression, `length` will contain
-    // exactly the delta between the original and final length of this `string`.
-    length = sz_min_of_two(length, string_length - offset);
-
-    // There are 2 common cases that wouldn't even require a `memmove`:
-    //      1. Erasing the entire contents of the string.
-    //         In that case `length` argument will be equal or greater than `length` member.
-    //      2. Removing the tail of the string with something like `string.pop_back()` in C++.
-    //
-    // In both of those, regardless of the location of the string - stack or heap,
-    // the erasing is as easy as setting the length to the offset.
-    // In every other case, we must `memmove` the tail of the string to the left.
-    if (offset + length < string_length)
-        sz_move(string_start + offset, string_start + offset + length, string_length - offset - length);
-
-    // The `string->external.length = offset` assignment would discard last characters
-    // of the on-the-stack string, but in-place subtraction would work.
-    string->external.length -= length;
-    string_start[string_length - length] = 0;
-    return length;
-}
-
-SZ_PUBLIC void sz_string_free(sz_string_t *string, sz_memory_allocator_t *allocator) {
-    if (!sz_string_is_on_stack(string))
-        allocator->free(string->external.start, string->external.space, allocator->handle);
-    sz_string_init(string);
-}
-
-// When overriding libc, disable optimisations for this function because MSVC will optimize the loops into a memset,
-// which then causes a stack overflow due to infinite recursion (memset -> sz_fill_serial -> memset).
-#if defined(_MSC_VER) && defined(SZ_OVERRIDE_LIBC) && SZ_OVERRIDE_LIBC
-#pragma optimize("", off)
-#endif
-SZ_PUBLIC void sz_fill_serial(sz_ptr_t target, sz_size_t length, sz_u8_t value) {
-    sz_ptr_t end = target + length;
-    // Dealing with short strings, a single sequential pass would be faster.
-    // If the size is larger than 2 words, then at least 1 of them will be aligned.
-    // But just one aligned word may not be worth SWAR.
-    if (length < SZ_SWAR_THRESHOLD)
-        while (target != end) *(target++) = value;
-
-    // In case of long strings, skip unaligned bytes, and then fill the rest in 64-bit chunks.
-    else {
-        sz_u64_t value64 = (sz_u64_t)value * 0x0101010101010101ull;
-        while ((sz_size_t)target & 7ull) *(target++) = value;
-        while (target + 8 <= end) *(sz_u64_t *)target = value64, target += 8;
-        while (target != end) *(target++) = value;
-    }
-}
-#if defined(_MSC_VER) && defined(SZ_OVERRIDE_LIBC) && SZ_OVERRIDE_LIBC
-#pragma optimize("", on)
-#endif
-
-SZ_PUBLIC void sz_copy_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length) {
-#if SZ_USE_MISALIGNED_LOADS
-    while (length >= 8) *(sz_u64_t *)target = *(sz_u64_t const *)source, target += 8, source += 8, length -= 8;
-#endif
-    while (length--) *(target++) = *(source++);
-}
-
-SZ_PUBLIC void sz_move_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length) {
-    // Implementing `memmove` is trickier than `memcpy`, as the ranges may overlap.
-    // Existing implementations often have two passes, in normal and reversed order,
-    // depending on the relation of `target` and `source` addresses.
-    // https://student.cs.uwaterloo.ca/~cs350/common/os161-src-html/doxygen/html/memmove_8c_source.html
-    // https://marmota.medium.com/c-language-making-memmove-def8792bb8d5
-    //
-    // We can use the `memcpy`-like left-to-right pass if we know that the `target` is before `source`.
-    // Or if we know that they don't intersect!
In that case the traversal order is irrelevant, - // but older CPUs may predict and fetch forward-passes better. - if (target < source || target >= source + length) { -#if SZ_USE_MISALIGNED_LOADS - while (length >= 8) *(sz_u64_t *)target = *(sz_u64_t const *)(source), target += 8, source += 8, length -= 8; -#endif - while (length--) *(target++) = *(source++); - } - else { - // Jump to the end and walk backwards. - target += length, source += length; -#if SZ_USE_MISALIGNED_LOADS - while (length >= 8) *(sz_u64_t *)(target -= 8) = *(sz_u64_t const *)(source -= 8), length -= 8; -#endif - while (length--) *(--target) = *(--source); - } -} - -#pragma endregion - -/* - * @brief Serial implementation for strings sequence processing. - */ -#pragma region Serial Implementation for Sequences - -SZ_PUBLIC sz_size_t sz_partition(sz_sequence_t *sequence, sz_sequence_predicate_t predicate) { - - sz_size_t matches = 0; - while (matches != sequence->count && predicate(sequence, sequence->order[matches])) ++matches; - - for (sz_size_t i = matches + 1; i < sequence->count; ++i) - if (predicate(sequence, sequence->order[i])) - sz_u64_swap(sequence->order + i, sequence->order + matches), ++matches; - - return matches; -} - -SZ_PUBLIC void sz_merge(sz_sequence_t *sequence, sz_size_t partition, sz_sequence_comparator_t less) { - - sz_size_t start_b = partition + 1; - - // If the direct merge is already sorted - if (!less(sequence, sequence->order[start_b], sequence->order[partition])) return; - - sz_size_t start_a = 0; - while (start_a <= partition && start_b <= sequence->count) { - - // If element 1 is in right place - if (!less(sequence, sequence->order[start_b], sequence->order[start_a])) { start_a++; } - else { - sz_size_t value = sequence->order[start_b]; - sz_size_t index = start_b; - - // Shift all the elements between element 1 - // element 2, right by 1. 
- while (index != start_a) { sequence->order[index] = sequence->order[index - 1], index--; } - sequence->order[start_a] = value; - - // Update all the pointers - start_a++; - partition++; - start_b++; - } - } -} - -SZ_PUBLIC void sz_sort_insertion(sz_sequence_t *sequence, sz_sequence_comparator_t less) { - sz_u64_t *keys = sequence->order; - sz_size_t keys_count = sequence->count; - for (sz_size_t i = 1; i < keys_count; i++) { - sz_u64_t i_key = keys[i]; - sz_size_t j = i; - for (; j > 0 && less(sequence, i_key, keys[j - 1]); --j) keys[j] = keys[j - 1]; - keys[j] = i_key; - } -} - -SZ_INTERNAL void _sz_sift_down(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_u64_t *order, sz_size_t start, - sz_size_t end) { - sz_size_t root = start; - while (2 * root + 1 <= end) { - sz_size_t child = 2 * root + 1; - if (child + 1 <= end && less(sequence, order[child], order[child + 1])) { child++; } - if (!less(sequence, order[root], order[child])) { return; } - sz_u64_swap(order + root, order + child); - root = child; - } -} - -SZ_INTERNAL void _sz_heapify(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_u64_t *order, sz_size_t count) { - sz_size_t start = (count - 2) / 2; - while (1) { - _sz_sift_down(sequence, less, order, start, count - 1); - if (start == 0) return; - start--; - } -} - -SZ_INTERNAL void _sz_heapsort(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_size_t first, sz_size_t last) { - sz_u64_t *order = sequence->order; - sz_size_t count = last - first; - _sz_heapify(sequence, less, order + first, count); - sz_size_t end = count - 1; - while (end > 0) { - sz_u64_swap(order + first, order + first + end); - end--; - _sz_sift_down(sequence, less, order + first, 0, end); - } -} - -SZ_PUBLIC void sz_sort_introsort_recursion(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_size_t first, - sz_size_t last, sz_size_t depth) { - - sz_size_t length = last - first; - switch (length) { - case 0: - case 1: return; - case 2: - if (less(sequence, sequence->order[first + 1], sequence->order[first])) - sz_u64_swap(&sequence->order[first], &sequence->order[first + 1]); - return; - case 3: { - sz_u64_t a = sequence->order[first]; - sz_u64_t b = sequence->order[first + 1]; - sz_u64_t c = sequence->order[first + 2]; - if (less(sequence, b, a)) sz_u64_swap(&a, &b); - if (less(sequence, c, b)) sz_u64_swap(&c, &b); - if (less(sequence, b, a)) sz_u64_swap(&a, &b); - sequence->order[first] = a; - sequence->order[first + 1] = b; - sequence->order[first + 2] = c; - return; - } - } - // Until a certain length, the quadratic-complexity insertion-sort is fine - if (length <= 16) { - sz_sequence_t sub_seq = *sequence; - sub_seq.order += first; - sub_seq.count = length; - sz_sort_insertion(&sub_seq, less); - return; - } - - // Fallback to N-logN-complexity heap-sort - if (depth == 0) { - _sz_heapsort(sequence, less, first, last); - return; - } - - --depth; - - // Median-of-three logic to choose pivot - sz_size_t median = first + length / 2; - if (less(sequence, sequence->order[median], sequence->order[first])) - sz_u64_swap(&sequence->order[first], &sequence->order[median]); - if (less(sequence, sequence->order[last - 1], sequence->order[first])) - sz_u64_swap(&sequence->order[first], &sequence->order[last - 1]); - if (less(sequence, sequence->order[median], sequence->order[last - 1])) - sz_u64_swap(&sequence->order[median], &sequence->order[last - 1]); - - // Partition using the median-of-three as the pivot - sz_u64_t pivot = sequence->order[median]; - sz_size_t left = first; - 
sz_size_t right = last - 1; - while (1) { - while (less(sequence, sequence->order[left], pivot)) left++; - while (less(sequence, pivot, sequence->order[right])) right--; - if (left >= right) break; - sz_u64_swap(&sequence->order[left], &sequence->order[right]); - left++; - right--; - } - - // Recursively sort the partitions - sz_sort_introsort_recursion(sequence, less, first, left, depth); - sz_sort_introsort_recursion(sequence, less, right + 1, last, depth); -} - -SZ_PUBLIC void sz_sort_introsort(sz_sequence_t *sequence, sz_sequence_comparator_t less) { - if (sequence->count == 0) return; - sz_size_t size_is_not_power_of_two = (sequence->count & (sequence->count - 1)) != 0; - sz_size_t depth_limit = sz_size_log2i_nonzero(sequence->count) + size_is_not_power_of_two; - sz_sort_introsort_recursion(sequence, less, 0, sequence->count, depth_limit); -} - -SZ_PUBLIC void sz_sort_recursion( // - sz_sequence_t *sequence, sz_size_t bit_idx, sz_size_t bit_max, sz_sequence_comparator_t comparator, - sz_size_t partial_order_length) { - - if (!sequence->count) return; - - // Array of size one doesn't need sorting - only needs the prefix to be discarded. - if (sequence->count == 1) { - sz_u32_t *order_half_words = (sz_u32_t *)sequence->order; - order_half_words[1] = 0; - return; - } - - // Partition a range of integers according to a specific bit value - sz_size_t split = 0; - sz_u64_t mask = (1ull << 63) >> bit_idx; - - // The clean approach would be to perform a single pass over the sequence. - // - // while (split != sequence->count && !(sequence->order[split] & mask)) ++split; - // for (sz_size_t i = split + 1; i < sequence->count; ++i) - // if (!(sequence->order[i] & mask)) sz_u64_swap(sequence->order + i, sequence->order + split), ++split; - // - // This, however, doesn't take into account the high relative cost of writes and swaps. - // To circumvent that, we can first count the total number entries to be mapped into either part. - // And then walk through both parts, swapping the entries that are in the wrong part. - // This would often lead to ~15% performance gain. - sz_size_t count_with_bit_set = 0; - for (sz_size_t i = 0; i != sequence->count; ++i) count_with_bit_set += (sequence->order[i] & mask) != 0; - split = sequence->count - count_with_bit_set; - - // It's possible that the sequence is already partitioned. - if (split != 0 && split != sequence->count) { - // Use two pointers to efficiently reposition elements. - // On pointer walks left-to-right from the start, and the other walks right-to-left from the end. - sz_size_t left = 0; - sz_size_t right = sequence->count - 1; - while (1) { - // Find the next element with the bit set on the left side. - while (left < split && !(sequence->order[left] & mask)) ++left; - // Find the next element without the bit set on the right side. - while (right >= split && (sequence->order[right] & mask)) --right; - // Swap the mispositioned elements. - if (left < split && right >= split) { - sz_u64_swap(sequence->order + left, sequence->order + right); - ++left; - --right; - } - else { break; } - } - } - - // Go down recursively. - if (bit_idx < bit_max) { - sz_sequence_t a = *sequence; - a.count = split; - sz_sort_recursion(&a, bit_idx + 1, bit_max, comparator, partial_order_length); - - sz_sequence_t b = *sequence; - b.order += split; - b.count -= split; - sz_sort_recursion(&b, bit_idx + 1, bit_max, comparator, partial_order_length); - } - // Reached the end of recursion. - else { - // Discard the prefixes. 
- sz_u32_t *order_half_words = (sz_u32_t *)sequence->order; - for (sz_size_t i = 0; i != sequence->count; ++i) { order_half_words[i * 2 + 1] = 0; } - - sz_sequence_t a = *sequence; - a.count = split; - sz_sort_introsort(&a, comparator); - - sz_sequence_t b = *sequence; - b.order += split; - b.count -= split; - sz_sort_introsort(&b, comparator); - } -} - -SZ_INTERNAL sz_bool_t _sz_sort_is_less(sz_sequence_t *sequence, sz_size_t i_key, sz_size_t j_key) { - sz_cptr_t i_str = sequence->get_start(sequence, i_key); - sz_cptr_t j_str = sequence->get_start(sequence, j_key); - sz_size_t i_len = sequence->get_length(sequence, i_key); - sz_size_t j_len = sequence->get_length(sequence, j_key); - return (sz_bool_t)(sz_order_serial(i_str, i_len, j_str, j_len) == sz_less_k); -} - -SZ_PUBLIC void sz_sort_partial(sz_sequence_t *sequence, sz_size_t partial_order_length) { - -#if SZ_DETECT_BIG_ENDIAN - // TODO: Implement partial sort for big-endian systems. For now this sorts the whole thing. - sz_unused(partial_order_length); - sz_sort_introsort(sequence, (sz_sequence_comparator_t)_sz_sort_is_less); -#else - - // Export up to 4 bytes into the `sequence` bits themselves - for (sz_size_t i = 0; i != sequence->count; ++i) { - sz_cptr_t begin = sequence->get_start(sequence, sequence->order[i]); - sz_size_t length = sequence->get_length(sequence, sequence->order[i]); - length = length > 4u ? 4u : length; - sz_ptr_t prefix = (sz_ptr_t)&sequence->order[i]; - for (sz_size_t j = 0; j != length; ++j) prefix[7 - j] = begin[j]; - } - - // Perform optionally-parallel radix sort on them - sz_sort_recursion(sequence, 0, 32, (sz_sequence_comparator_t)_sz_sort_is_less, partial_order_length); -#endif -} - -SZ_PUBLIC void sz_sort(sz_sequence_t *sequence) { -#if SZ_DETECT_BIG_ENDIAN - sz_sort_introsort(sequence, (sz_sequence_comparator_t)_sz_sort_is_less); -#else - sz_sort_partial(sequence, sequence->count); -#endif -} - -#pragma endregion - -/* - * @brief AVX2 implementation of the string search algorithms. - * Very minimalistic, but still faster than the serial implementation. - */ -#pragma region AVX2 Implementation - -#if SZ_USE_X86_AVX2 -#pragma GCC push_options -#pragma GCC target("avx2") -#pragma clang attribute push(__attribute__((target("avx2"))), apply_to = function) -#include - -/** - * @brief Helper structure to simplify work with 256-bit registers. - */ -typedef union sz_u256_vec_t { - __m256i ymm; - __m128i xmms[2]; - sz_u64_t u64s[4]; - sz_u32_t u32s[8]; - sz_u16_t u16s[16]; - sz_u8_t u8s[32]; -} sz_u256_vec_t; - -SZ_PUBLIC sz_ordering_t sz_order_avx2(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { - //! Before optimizing this, read the "Operations Not Worth Optimizing" in Contributions Guide: - //! https://github.com/ashvardanian/StringZilla/blob/main/CONTRIBUTING.md#general-performance-observations - return sz_order_serial(a, a_length, b, b_length); -} - -SZ_PUBLIC sz_bool_t sz_equal_avx2(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - sz_u256_vec_t a_vec, b_vec; - - while (length >= 32) { - a_vec.ymm = _mm256_lddqu_si256((__m256i const *)a); - b_vec.ymm = _mm256_lddqu_si256((__m256i const *)b); - // One approach can be to use "movemasks", but we could also use a bitwise matching like `_mm256_testnzc_si256`. 
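/*  Illustrative sketch (not part of this patch): packing a short string prefix into the
 *  upper bytes of a 64-bit index, the same trick `sz_sort_partial` above uses before its
 *  radix pass. On a little-endian machine the bytes are written from the top down, so
 *  comparing the resulting integers approximates lexicographic order of the prefixes.
 *  The helper name is hypothetical. */
#include <stddef.h>
#include <stdint.h>

static uint64_t pack_prefix_into_index(char const *str, size_t length, uint32_t index) {
    uint64_t packed = (uint64_t)index;      // The low 32 bits keep the original position.
    uint8_t *bytes = (uint8_t *)&packed;
    size_t prefix = length > 4u ? 4u : length;
    // Highest byte gets the first character, so integer comparison mimics `memcmp` on the prefix.
    for (size_t j = 0; j != prefix; ++j) bytes[7 - j] = (uint8_t)str[j];
    return packed;
}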
- int difference_mask = ~_mm256_movemask_epi8(_mm256_cmpeq_epi8(a_vec.ymm, b_vec.ymm)); - if (difference_mask == 0) { a += 32, b += 32, length -= 32; } - else { return sz_false_k; } - } - - if (length) return sz_equal_serial(a, b, length); - return sz_true_k; -} - -SZ_PUBLIC void sz_fill_avx2(sz_ptr_t target, sz_size_t length, sz_u8_t value) { - char value_char = *(char *)&value; - __m256i value_vec = _mm256_set1_epi8(value_char); - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "stores". - // - // for (; length >= 32; target += 32, length -= 32) _mm256_storeu_si256(target, value_vec); - // sz_fill_serial(target, length, value); - // - // When the buffer is small, there isn't much to innovate. - if (length <= 32) sz_fill_serial(target, length, value); - // When the buffer is aligned, we can avoid any split-stores. - else { - sz_size_t head_length = (32 - ((sz_size_t)target % 32)) % 32; // 31 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 32; // 31 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 32. - sz_u16_t value16 = (sz_u16_t)value * 0x0101u; - sz_u32_t value32 = (sz_u32_t)value16 * 0x00010001u; - sz_u64_t value64 = (sz_u64_t)value32 * 0x0000000100000001ull; - - // Fill the head of the buffer. This part is much cleaner with AVX-512. - if (head_length & 1) *(sz_u8_t *)target = value, target++, head_length--; - if (head_length & 2) *(sz_u16_t *)target = value16, target += 2, head_length -= 2; - if (head_length & 4) *(sz_u32_t *)target = value32, target += 4, head_length -= 4; - if (head_length & 8) *(sz_u64_t *)target = value64, target += 8, head_length -= 8; - if (head_length & 16) - _mm_store_si128((__m128i *)target, _mm_set1_epi8(value_char)), target += 16, head_length -= 16; - sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); - - // Fill the aligned body of the buffer. - for (; body_length >= 32; target += 32, body_length -= 32) _mm256_store_si256((__m256i *)target, value_vec); - - // Fill the tail of the buffer. This part is much cleaner with AVX-512. - sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); - if (tail_length & 16) - _mm_store_si128((__m128i *)target, _mm_set1_epi8(value_char)), target += 16, tail_length -= 16; - if (tail_length & 8) *(sz_u64_t *)target = value64, target += 8, tail_length -= 8; - if (tail_length & 4) *(sz_u32_t *)target = value32, target += 4, tail_length -= 4; - if (tail_length & 2) *(sz_u16_t *)target = value16, target += 2, tail_length -= 2; - if (tail_length & 1) *(sz_u8_t *)target = value, target++, tail_length--; - } -} - -SZ_PUBLIC void sz_copy_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "stores" and "loads". - // - // for (; length >= 32; target += 32, source += 32, length -= 32) - // _mm256_storeu_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source)); - // sz_copy_serial(target, source, length); - // - // A typical AWS Skylake instance can have 32 KB x 2 blocks of L1 data cache per core, - // 1 MB x 2 blocks of L2 cache per core, and one shared L3 cache buffer. - // For now, let's avoid the cases beyond the L2 size. 
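/*  A minimal sketch of the head/body/tail split used by the fill and copy kernels above:
 *  the head advances `target` to a 32-byte boundary, the body is a multiple of 32 bytes that
 *  can use aligned stores, and the tail is whatever remains. Purely illustrative arithmetic,
 *  with defensive clamping that the real kernels don't need for buffers over 32 bytes. */
#include <stddef.h>
#include <stdint.h>

typedef struct { size_t head, body, tail; } span_split_t;

static span_split_t split_for_32_byte_alignment(void const *target, size_t length) {
    span_split_t s;
    s.head = (32 - ((uintptr_t)target % 32)) % 32;  // 0..31 bytes until the next boundary.
    if (s.head > length) s.head = length;           // Tiny buffers may not even reach it.
    s.tail = ((uintptr_t)target + length) % 32;     // 0..31 trailing bytes.
    if (s.head + s.tail > length) s.tail = length - s.head;
    s.body = length - s.head - s.tail;              // Zero or a multiple of 32.
    return s;
}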
- int is_huge = length > 1ull * 1024ull * 1024ull; - if (length <= 32) { sz_copy_serial(target, source, length); } - // When dealing wirh larger arrays, the optimization is not as simple as with the `sz_fill_avx2` function, - // as both buffers may be unaligned. If we are lucky and the requested operation is some huge page transfer, - // we can use aligned loads and stores, and the performance will be great. - else if ((sz_size_t)target % 32 == 0 && (sz_size_t)source % 32 == 0 && !is_huge) { - for (; length >= 32; target += 32, source += 32, length -= 32) - _mm256_store_si256((__m256i *)target, _mm256_load_si256((__m256i const *)source)); - if (length) sz_copy_serial(target, source, length); - } - // The trickiest case is when both `source` and `target` are not aligned. - // In such and simpler cases we can copy enough bytes into `target` to reach its cacheline boundary, - // and then combine unaligned loads with aligned stores. - else { - sz_size_t head_length = (32 - ((sz_size_t)target % 32)) % 32; // 31 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 32; // 31 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 32. - - // Fill the head of the buffer. This part is much cleaner with AVX-512. - if (head_length & 1) *(sz_u8_t *)target = *(sz_u8_t *)source, target++, source++, head_length--; - if (head_length & 2) *(sz_u16_t *)target = *(sz_u16_t *)source, target += 2, source += 2, head_length -= 2; - if (head_length & 4) *(sz_u32_t *)target = *(sz_u32_t *)source, target += 4, source += 4, head_length -= 4; - if (head_length & 8) *(sz_u64_t *)target = *(sz_u64_t *)source, target += 8, source += 8, head_length -= 8; - if (head_length & 16) - _mm_store_si128((__m128i *)target, _mm_lddqu_si128((__m128i const *)source)), target += 16, source += 16, - head_length -= 16; - sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); - - // Fill the aligned body of the buffer. - if (!is_huge) { - for (; body_length >= 32; target += 32, source += 32, body_length -= 32) - _mm256_store_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source)); - } - // When the biffer is huge, we can traverse it in 2 directions. - else { - for (; body_length >= 64; target += 32, source += 32, body_length -= 64) { - _mm256_store_si256((__m256i *)(target), _mm256_lddqu_si256((__m256i const *)(source))); - _mm256_store_si256((__m256i *)(target + body_length - 32), - _mm256_lddqu_si256((__m256i const *)(source + body_length - 32))); - } - if (body_length) _mm256_store_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source)); + // Progress through the upper triangle of the Levenshtein matrix. + sz_size_t next_diagonal_index = 2; + for (; next_diagonal_index != n; ++next_diagonal_index) { + sz_size_t const next_diagonal_length = next_diagonal_index + 1; + for (sz_size_t i = 0; i + 2 < next_diagonal_length; ++i) { + sz_size_t cost_of_substitution = shorter[next_diagonal_index - i - 2] != longer[i]; + sz_size_t cost_if_substitution = previous_distances[i] + cost_of_substitution; + sz_size_t cost_if_deletion_or_insertion = sz_min_of_two(current_distances[i], current_distances[i + 1]) + 1; + next_distances[i + 1] = sz_min_of_two(cost_if_deletion_or_insertion, cost_if_substitution); } - - // Fill the tail of the buffer. This part is much cleaner with AVX-512. 
- sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); - if (tail_length & 16) - _mm_store_si128((__m128i *)target, _mm_lddqu_si128((__m128i const *)source)), target += 16, source += 16, - tail_length -= 16; - if (tail_length & 8) *(sz_u64_t *)target = *(sz_u64_t *)source, target += 8, source += 8, tail_length -= 8; - if (tail_length & 4) *(sz_u32_t *)target = *(sz_u32_t *)source, target += 4, source += 4, tail_length -= 4; - if (tail_length & 2) *(sz_u16_t *)target = *(sz_u16_t *)source, target += 2, source += 2, tail_length -= 2; - if (tail_length & 1) *(sz_u8_t *)target = *(sz_u8_t *)source, target++, source++, tail_length--; - } -} - -SZ_PUBLIC void sz_move_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - if (target < source || target >= source + length) { - for (; length >= 32; target += 32, source += 32, length -= 32) - _mm256_storeu_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source)); - while (length--) *(target++) = *(source++); - } - else { - // Jump to the end and walk backwards. - for (target += length, source += length; length >= 32; length -= 32) - _mm256_storeu_si256((__m256i *)(target -= 32), _mm256_lddqu_si256((__m256i const *)(source -= 32))); - while (length--) *(--target) = *(--source); + // Don't forget to populate the first row and the first column of the Levenshtein matrix. + next_distances[0] = next_distances[next_diagonal_length - 1] = next_diagonal_index; + // Perform a circular rotation of those buffers, to reuse the memory. + sz_size_t *temporary = previous_distances; + previous_distances = current_distances; + current_distances = next_distances; + next_distances = temporary; } -} -SZ_PUBLIC sz_u64_t sz_checksum_avx2(sz_cptr_t text, sz_size_t length) { - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "loads". - // - // A typical AWS Skylake instance can have 32 KB x 2 blocks of L1 data cache per core, - // 1 MB x 2 blocks of L2 cache per core, and one shared L3 cache buffer. - // For now, let's avoid the cases beyond the L2 size. - int is_huge = length > 1ull * 1024ull * 1024ull; - - // When the buffer is small, there isn't much to innovate. - if (length <= 32) { return sz_checksum_serial(text, length); } - else if (!is_huge) { - sz_u256_vec_t text_vec, sums_vec; - sums_vec.ymm = _mm256_setzero_si256(); - for (; length >= 32; text += 32, length -= 32) { - text_vec.ymm = _mm256_lddqu_si256((__m256i const *)text); - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); + // By now we've scanned through the upper triangle of the matrix, where each subsequent iteration results in a + // larger diagonal. From now onwards, we will be shrinking. Instead of adding value equal to the skewed diagonal + // index on either side, we will be cropping those values out. 
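/*  A small illustration of the anti-diagonal ("skewed") traversal used above. Cell (row, col)
 *  of an n-by-n Levenshtein matrix lives on diagonal `row + col`; within a diagonal the cells
 *  depend only on the two previous diagonals, which is why three rotating buffers suffice.
 *  The mapping below is illustrative and assumes the same square, equal-length setup. */
#include <stddef.h>

typedef struct { size_t row, col; } matrix_cell_t;

static matrix_cell_t cell_on_diagonal(size_t diagonal_index, size_t offset, size_t n) {
    matrix_cell_t cell;
    if (diagonal_index < n) {  // Upper triangle: diagonals grow from 1 to n cells.
        cell.row = diagonal_index - offset;
        cell.col = offset;
    }
    else {                     // Lower triangle: diagonals shrink back down to 1 cell.
        cell.row = (n - 1) - offset;
        cell.col = (diagonal_index - n + 1) + offset;
    }
    return cell;
}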
+ sz_size_t diagonals_count = n + n - 1; + for (; next_diagonal_index != diagonals_count; ++next_diagonal_index) { + sz_size_t const next_diagonal_length = diagonals_count - next_diagonal_index; + for (sz_size_t i = 0; i != next_diagonal_length; ++i) { + sz_size_t cost_of_substitution = shorter[shorter_length - 1 - i] != longer[next_diagonal_index - n + i]; + sz_size_t cost_if_substitution = previous_distances[i] + cost_of_substitution; + sz_size_t cost_if_deletion_or_insertion = sz_min_of_two(current_distances[i], current_distances[i + 1]) + 1; + next_distances[i] = sz_min_of_two(cost_if_deletion_or_insertion, cost_if_substitution); } - // Accumulating 256 bits is harders, as we need to extract the 128-bit sums first. - __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymm); - __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymm, 1); - __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm); - sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm); - sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1); - sz_u64_t result = low + high; - if (length) result += sz_checksum_serial(text, length); - return result; + // Perform a circular rotation of those buffers, to reuse the memory, this time, with a shift, + // dropping the first element in the current array. + sz_size_t *temporary = previous_distances; + previous_distances = current_distances + 1; + current_distances = next_distances; + next_distances = temporary; } - // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use. - // Most notably, we can avoid populating the cache with the entire buffer, and instead traverse it in 2 directions. - else { - sz_size_t head_length = (32 - ((sz_size_t)text % 32)) % 32; // 31 or less. - sz_size_t tail_length = (sz_size_t)(text + length) % 32; // 31 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 32. - sz_u64_t result = 0; - - // Handle the head - while (head_length--) result += *text++; - - sz_u256_vec_t text_vec, sums_vec; - sums_vec.ymm = _mm256_setzero_si256(); - // Fill the aligned body of the buffer. - if (!is_huge) { - for (; body_length >= 32; text += 32, body_length -= 32) { - text_vec.ymm = _mm256_stream_load_si256((__m256i const *)text); - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); - } - } - // When the biffer is huge, we can traverse it in 2 directions. - else { - sz_u256_vec_t text_reversed_vec, sums_reversed_vec; - sums_reversed_vec.ymm = _mm256_setzero_si256(); - for (; body_length >= 64; text += 64, body_length -= 64) { - text_vec.ymm = _mm256_stream_load_si256((__m256i *)(text)); - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); - text_reversed_vec.ymm = _mm256_stream_load_si256((__m256i *)(text + body_length - 64)); - sums_reversed_vec.ymm = _mm256_add_epi64( - sums_reversed_vec.ymm, _mm256_sad_epu8(text_reversed_vec.ymm, _mm256_setzero_si256())); - } - if (body_length >= 32) { - text_vec.ymm = _mm256_stream_load_si256((__m256i *)(text)); - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); - } - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, sums_reversed_vec.ymm); - } - // Handle the tail - while (tail_length--) result += *text++; - - // Accumulating 256 bits is harders, as we need to extract the 128-bit sums first. 
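/*  Sketch of the horizontal reduction referenced just above: collapsing a YMM register holding
 *  four 64-bit partial sums into one scalar. Assumes AVX2 and x86-64 (for the 64-bit extract
 *  intrinsics); illustrative only. */
#include <immintrin.h>
#include <stdint.h>

static inline uint64_t reduce_four_u64_lanes(__m256i sums) {
    __m128i low_xmm = _mm256_castsi256_si128(sums);        // Lanes 0 and 1.
    __m128i high_xmm = _mm256_extracti128_si256(sums, 1);  // Lanes 2 and 3.
    __m128i pair = _mm_add_epi64(low_xmm, high_xmm);       // Two partial sums left.
    return (uint64_t)_mm_cvtsi128_si64(pair) + (uint64_t)_mm_extract_epi64(pair, 1);
}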
- __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymm); - __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymm, 1); - __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm); - sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm); - sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1); - result += low + high; - return result; - } + // Cache scalar before `free` call. + sz_size_t result = current_distances[0]; + alloc->free(distances, buffer_length, alloc->handle); + return result; } -SZ_PUBLIC void sz_look_up_transform_avx2(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { +/** + * @brief Compute the Levenshtein distance between two strings using the Wagner-Fisher algorithm. + * Stores only 2 rows of the Levenshtein matrix, but uses 64-bit integers for the distance values, + * and upcasts UTF8 variable-length codepoints to 64-bit integers for faster addressing. + * + * ! In the worst case for 2 strings of length 100, that contain just one 16-bit codepoint this will result in extra: + * + 2 rows * 100 slots * 8 bytes/slot = 1600 bytes of memory for the two rows of the Levenshtein matrix rows. + * + 100 codepoints * 2 strings * 4 bytes/codepoint = 800 bytes of memory for the UTF8 buffer. + * = 2400 bytes of memory or @b 12x memory amplification! + */ +SZ_INTERNAL sz_size_t _sz_edit_distance_wagner_fisher_serial( // + sz_cptr_t longer, sz_size_t longer_length, // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_size_t bound, sz_bool_t can_be_unicode, sz_memory_allocator_t *alloc) { - // If the input is tiny (especially smaller than the look-up table itself), we may end up paying - // more for organizing the SIMD registers and changing the CPU state, than for the actual computation. - // But if at least 3 cache lines are touched, the AVX-2 implementation should be faster. - if (length <= 128) { - sz_look_up_transform_serial(source, length, lut, target); - return; + // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. + sz_memory_allocator_t global_alloc; + if (!alloc) { + sz_memory_allocator_init_default(&global_alloc); + alloc = &global_alloc; } - // We need to pull the lookup table into 8x YMM registers. - // The biggest issue is reorganizing the data in the lookup table, as AVX2 doesn't have 256-bit shuffle, - // it only has 128-bit "within-lane" shuffle. Still, it's wiser to use full YMM registers, instead of XMM, - // so that we can at least compensate high latency with twice larger window and one more level of lookup. 
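/*  Scalar reference for the transform that the nibble-shuffle code below vectorizes:
 *  every input byte is replaced by the byte stored at that index in a 256-entry table.
 *  This is the semantic contract; the AVX2 version only changes how the lookup is performed. */
#include <stddef.h>
#include <stdint.h>

static void look_up_transform_reference(uint8_t const *source, size_t length,
                                        uint8_t const *table_256, uint8_t *target) {
    for (size_t i = 0; i != length; ++i) target[i] = table_256[source[i]];
}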
- sz_u256_vec_t lut_0_to_15_vec, lut_16_to_31_vec, lut_32_to_47_vec, lut_48_to_63_vec, // - lut_64_to_79_vec, lut_80_to_95_vec, lut_96_to_111_vec, lut_112_to_127_vec, // - lut_128_to_143_vec, lut_144_to_159_vec, lut_160_to_175_vec, lut_176_to_191_vec, // - lut_192_to_207_vec, lut_208_to_223_vec, lut_224_to_239_vec, lut_240_to_255_vec; - - lut_0_to_15_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut))); - lut_16_to_31_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 16))); - lut_32_to_47_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 32))); - lut_48_to_63_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 48))); - lut_64_to_79_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 64))); - lut_80_to_95_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 80))); - lut_96_to_111_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 96))); - lut_112_to_127_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 112))); - lut_128_to_143_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 128))); - lut_144_to_159_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 144))); - lut_160_to_175_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 160))); - lut_176_to_191_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 176))); - lut_192_to_207_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 192))); - lut_208_to_223_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 208))); - lut_224_to_239_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 224))); - lut_240_to_255_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 240))); - - // Assuming each lookup is performed within 16 elements of 256, we need to reduce the scope by 16x = 2^4. - sz_u256_vec_t not_first_bit_vec, not_second_bit_vec, not_third_bit_vec, not_fourth_bit_vec; - - /// Top and bottom nibbles of the source are used separately. - sz_u256_vec_t source_vec, source_bot_vec; - sz_u256_vec_t blended_0_to_31_vec, blended_32_to_63_vec, blended_64_to_95_vec, blended_96_to_127_vec, - blended_128_to_159_vec, blended_160_to_191_vec, blended_192_to_223_vec, blended_224_to_255_vec; - - // Handling the head. - while (length >= 32) { - // Load and separate the nibbles of each byte in the source. - source_vec.ymm = _mm256_lddqu_si256((__m256i const *)source); - source_bot_vec.ymm = _mm256_and_si256(source_vec.ymm, _mm256_set1_epi8((char)0x0F)); - - // In the first round, we select using the 4th bit. 
- not_fourth_bit_vec.ymm = _mm256_cmpeq_epi8( // - _mm256_and_si256(_mm256_set1_epi8((char)0x10), source_vec.ymm), _mm256_setzero_si256()); - blended_0_to_31_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_16_to_31_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_0_to_15_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_32_to_63_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_48_to_63_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_32_to_47_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_64_to_95_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_80_to_95_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_64_to_79_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_96_to_127_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_112_to_127_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_96_to_111_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_128_to_159_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_144_to_159_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_128_to_143_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_160_to_191_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_176_to_191_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_160_to_175_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_192_to_223_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_208_to_223_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_192_to_207_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_224_to_255_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_240_to_255_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_224_to_239_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - - // Perform a tree-like reduction of the 8x "blended" YMM registers, depending on the "source" content. - // The first round selects using the 3rd bit. - not_third_bit_vec.ymm = _mm256_cmpeq_epi8( // - _mm256_and_si256(_mm256_set1_epi8((char)0x20), source_vec.ymm), _mm256_setzero_si256()); - blended_0_to_31_vec.ymm = _mm256_blendv_epi8( // - blended_32_to_63_vec.ymm, // - blended_0_to_31_vec.ymm, // - not_third_bit_vec.ymm); - blended_64_to_95_vec.ymm = _mm256_blendv_epi8( // - blended_96_to_127_vec.ymm, // - blended_64_to_95_vec.ymm, // - not_third_bit_vec.ymm); - blended_128_to_159_vec.ymm = _mm256_blendv_epi8( // - blended_160_to_191_vec.ymm, // - blended_128_to_159_vec.ymm, // - not_third_bit_vec.ymm); - blended_192_to_223_vec.ymm = _mm256_blendv_epi8( // - blended_224_to_255_vec.ymm, // - blended_192_to_223_vec.ymm, // - not_third_bit_vec.ymm); - - // The second round selects using the 2nd bit. - not_second_bit_vec.ymm = _mm256_cmpeq_epi8( // - _mm256_and_si256(_mm256_set1_epi8((char)0x40), source_vec.ymm), _mm256_setzero_si256()); - blended_0_to_31_vec.ymm = _mm256_blendv_epi8( // - blended_64_to_95_vec.ymm, // - blended_0_to_31_vec.ymm, // - not_second_bit_vec.ymm); - blended_128_to_159_vec.ymm = _mm256_blendv_epi8( // - blended_192_to_223_vec.ymm, // - blended_128_to_159_vec.ymm, // - not_second_bit_vec.ymm); - - // The third round selects using the 1st bit. 
- not_first_bit_vec.ymm = _mm256_cmpeq_epi8( // - _mm256_and_si256(_mm256_set1_epi8((char)0x80), source_vec.ymm), _mm256_setzero_si256()); - blended_0_to_31_vec.ymm = _mm256_blendv_epi8( // - blended_128_to_159_vec.ymm, // - blended_0_to_31_vec.ymm, // - not_first_bit_vec.ymm); - - // And dump the result into the target. - _mm256_storeu_si256((__m256i *)target, blended_0_to_31_vec.ymm); - source += 32, target += 32, length -= 32; - } + // A good idea may be to dispatch different kernels for different string lengths. + // Like using `uint8_t` counters for strings under 255 characters long. + // Good in theory, this results in frequent upcasts and downcasts in serial code. + // On strings over 20 bytes, using `uint8` over `uint64` on 64-bit x86 CPU doubles the execution time. + // So one must be very cautious with such optimizations. + typedef sz_size_t _distance_t; - // Handle the tail. - if (length) sz_look_up_transform_serial(source, length, lut, target); -} + // Compute the number of columns in our Levenshtein matrix. + sz_size_t const n = shorter_length + 1; -SZ_PUBLIC sz_cptr_t sz_find_byte_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - int mask; - sz_u256_vec_t h_vec, n_vec; - n_vec.ymm = _mm256_set1_epi8(n[0]); + // If a buffering memory-allocator is provided, this operation is practically free, + // and cheaper than allocating even 512 bytes (for small distance matrices) on stack. + sz_size_t buffer_length = sizeof(_distance_t) * (n * 2); - while (h_length >= 32) { - h_vec.ymm = _mm256_lddqu_si256((__m256i const *)h); - mask = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_vec.ymm, n_vec.ymm)); - if (mask) return h + sz_u32_ctz(mask); - h += 32, h_length -= 32; + // If the strings contain Unicode characters, let's estimate the max character width, + // and use it to allocate a larger buffer to decode UTF8. + if ((can_be_unicode == sz_true_k) && + (sz_isascii(longer, longer_length) == sz_false_k || sz_isascii(shorter, shorter_length) == sz_false_k)) { + buffer_length += (shorter_length + longer_length) * sizeof(sz_rune_t); } + else { can_be_unicode = sz_false_k; } - return sz_find_byte_serial(h, h_length, n); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - int mask; - sz_u256_vec_t h_vec, n_vec; - n_vec.ymm = _mm256_set1_epi8(n[0]); + // If the allocation fails, return the maximum distance. + sz_ptr_t const buffer = (sz_ptr_t)alloc->allocate(buffer_length, alloc->handle); + if (!buffer) return SZ_SIZE_MAX; - while (h_length >= 32) { - h_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + h_length - 32)); - mask = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_vec.ymm, n_vec.ymm)); - if (mask) return h + h_length - 1 - sz_u32_clz(mask); - h_length -= 32; + // Let's export the UTF8 sequence into the newly allocated buffer at the end. + if (can_be_unicode == sz_true_k) { + sz_rune_t *const longer_utf32 = (sz_rune_t *)(buffer + sizeof(_distance_t) * (n * 2)); + sz_rune_t *const shorter_utf32 = longer_utf32 + longer_length; + // Export the UTF8 sequences into the newly allocated buffer. + longer_length = _sz_export_utf8_to_utf32(longer, longer_length, longer_utf32); + shorter_length = _sz_export_utf8_to_utf32(shorter, shorter_length, shorter_utf32); + longer = (sz_cptr_t)longer_utf32; + shorter = (sz_cptr_t)shorter_utf32; } - return sz_rfind_byte_serial(h, h_length, n); -} - -SZ_PUBLIC sz_cptr_t sz_find_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. 
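/*  A scalar sketch of the filtering idea used by the vectorized substring search above:
 *  probe a few characters of the needle (here, simply the first and last) at every haystack
 *  offset, and only fall back to a full comparison when both match. The offsets StringZilla
 *  actually picks come from `_sz_locate_needle_anomalies`; the fixed choice below is a
 *  simplification. */
#include <stddef.h>
#include <string.h>

static char const *find_with_two_character_filter(char const *haystack, size_t h_length,
                                                  char const *needle, size_t n_length) {
    if (n_length == 0 || h_length < n_length) return NULL;
    char first = needle[0], last = needle[n_length - 1];
    for (size_t i = 0; i + n_length <= h_length; ++i)
        if (haystack[i] == first && haystack[i + n_length - 1] == last &&
            memcmp(haystack + i, needle, n_length) == 0)
            return haystack + i;
    return NULL;
}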
- if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_find_byte_avx2(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into YMM registers. - int matches; - sz_u256_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; - n_first_vec.ymm = _mm256_set1_epi8(n[offset_first]); - n_mid_vec.ymm = _mm256_set1_epi8(n[offset_mid]); - n_last_vec.ymm = _mm256_set1_epi8(n[offset_last]); - - // Scan through the string. - for (; h_length >= n_length + 32; h += 32, h_length -= 32) { - h_first_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_first)); - h_mid_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_mid)); - h_last_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_last)); - matches = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_first_vec.ymm, n_first_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_mid_vec.ymm, n_mid_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_last_vec.ymm, n_last_vec.ymm)); - while (matches) { - int potential_offset = sz_u32_ctz(matches); - if (sz_equal(h + potential_offset, n, n_length)) return h + potential_offset; - matches &= matches - 1; - } - } + // Let's parameterize the core logic for different character types and distance types. +#define _wagner_fisher_unbounded(_distance_t, _char_t) \ + /* Now let's cast our pointer to avoid it in subsequent sections. */ \ + _char_t const *const longer_chars = (_char_t const *)longer; \ + _char_t const *const shorter_chars = (_char_t const *)shorter; \ + _distance_t *previous_distances = (_distance_t *)buffer; \ + _distance_t *current_distances = previous_distances + n; \ + /* Initialize the first row of the Levenshtein matrix with `iota`-style arithmetic progression. */ \ + for (_distance_t idx_shorter = 0; idx_shorter != n; ++idx_shorter) previous_distances[idx_shorter] = idx_shorter; \ + /* The main loop of the algorithm with quadratic complexity. */ \ + for (_distance_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) { \ + _char_t const longer_char = longer_chars[idx_longer]; \ + /* Using pure pointer arithmetic is faster than iterating with an index. */ \ + _char_t const *shorter_ptr = shorter_chars; \ + _distance_t const *previous_ptr = previous_distances; \ + _distance_t *current_ptr = current_distances; \ + _distance_t *const current_end = current_ptr + shorter_length; \ + current_ptr[0] = idx_longer + 1; \ + for (; current_ptr != current_end; ++previous_ptr, ++current_ptr, ++shorter_ptr) { \ + _distance_t cost_substitution = previous_ptr[0] + (_distance_t)(longer_char != shorter_ptr[0]); \ + /* We can avoid `+1` for costs here, shifting it to post-minimum computation, */ \ + /* saving one increment operation. */ \ + _distance_t cost_deletion = previous_ptr[1]; \ + _distance_t cost_insertion = current_ptr[0]; \ + /* ? It might be a good idea to enforce branchless execution here. */ \ + /* ? The caveat being that the benchmarks on longer sequences backfire and more research is needed. */ \ + current_ptr[1] = sz_min_of_two(cost_substitution, sz_min_of_two(cost_deletion, cost_insertion) + 1); \ + } \ + /* Swap `previous_distances` and `current_distances` pointers. 
*/ \ + _distance_t *temporary = previous_distances; \ + previous_distances = current_distances; \ + current_distances = temporary; \ + } \ + /* Cache scalar before `free` call. */ \ + sz_size_t result = previous_distances[shorter_length]; \ + alloc->free(buffer, buffer_length, alloc->handle); \ + return result; - return sz_find_serial(h, h_length, n, n_length); -} + // Let's define a separate variant for bounded distance computation. + // Practically the same as unbounded, but also collecting the running minimum within each row for early exit. +#define _wagner_fisher_bounded(_distance_t, _char_t) \ + _char_t const *const longer_chars = (_char_t const *)longer; \ + _char_t const *const shorter_chars = (_char_t const *)shorter; \ + _distance_t *previous_distances = (_distance_t *)buffer; \ + _distance_t *current_distances = previous_distances + n; \ + for (_distance_t idx_shorter = 0; idx_shorter != n; ++idx_shorter) previous_distances[idx_shorter] = idx_shorter; \ + for (_distance_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) { \ + _char_t const longer_char = longer_chars[idx_longer]; \ + _char_t const *shorter_ptr = shorter_chars; \ + _distance_t const *previous_ptr = previous_distances; \ + _distance_t *current_ptr = current_distances; \ + _distance_t *const current_end = current_ptr + shorter_length; \ + current_ptr[0] = idx_longer + 1; \ + /* Initialize min_distance with a value greater than bound */ \ + _distance_t min_distance = bound - 1; \ + for (; current_ptr != current_end; ++previous_ptr, ++current_ptr, ++shorter_ptr) { \ + _distance_t cost_substitution = previous_ptr[0] + (_distance_t)(longer_char != shorter_ptr[0]); \ + _distance_t cost_deletion = previous_ptr[1]; \ + _distance_t cost_insertion = current_ptr[0]; \ + current_ptr[1] = sz_min_of_two(cost_substitution, sz_min_of_two(cost_deletion, cost_insertion) + 1); \ + /* Keep track of the minimum distance seen so far in this row */ \ + min_distance = sz_min_of_two(current_ptr[1], min_distance); \ + } \ + /* If the minimum distance in this row exceeded the bound, return early */ \ + if (min_distance >= bound) { \ + alloc->free(buffer, buffer_length, alloc->handle); \ + return bound; \ + } \ + _distance_t *temporary = previous_distances; \ + previous_distances = current_distances; \ + current_distances = temporary; \ + } \ + sz_size_t result = previous_distances[shorter_length]; \ + alloc->free(buffer, buffer_length, alloc->handle); \ + return sz_min_of_two(result, bound); -SZ_PUBLIC sz_cptr_t sz_rfind_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_rfind_byte_avx2(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into YMM registers. - int matches; - sz_u256_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; - n_first_vec.ymm = _mm256_set1_epi8(n[offset_first]); - n_mid_vec.ymm = _mm256_set1_epi8(n[offset_mid]); - n_last_vec.ymm = _mm256_set1_epi8(n[offset_last]); - - // Scan through the string. 
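/*  A compact, standalone version of the two-row dynamic program the macros above expand to,
 *  with fixed `size_t` distances, plain bytes, and `malloc` instead of the library's pluggable
 *  allocator. Assumes ASCII input; purely a reference sketch, not the library's API. */
#include <stddef.h>
#include <stdlib.h>

static size_t levenshtein_two_rows(char const *longer, size_t longer_length,
                                   char const *shorter, size_t shorter_length) {
    size_t n = shorter_length + 1;
    size_t *previous = (size_t *)malloc(2 * n * sizeof(size_t));
    if (!previous) return (size_t)-1;
    size_t *current = previous + n;
    for (size_t j = 0; j != n; ++j) previous[j] = j;  // First row is an arithmetic progression.
    for (size_t i = 0; i != longer_length; ++i) {
        current[0] = i + 1;
        for (size_t j = 0; j != shorter_length; ++j) {
            size_t substitution = previous[j] + (longer[i] != shorter[j]);
            size_t deletion = previous[j + 1];
            size_t insertion = current[j];
            size_t best_of_two = deletion < insertion ? deletion : insertion;
            current[j + 1] = substitution < best_of_two + 1 ? substitution : best_of_two + 1;
        }
        size_t *temporary = previous;
        previous = current, current = temporary;
    }
    size_t result = previous[shorter_length];
    free(previous < current ? previous : current);  // Free the original allocation base.
    return result;
}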
- sz_cptr_t h_reversed; - for (; h_length >= n_length + 32; h_length -= 32) { - h_reversed = h + h_length - n_length - 32 + 1; - h_first_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_first)); - h_mid_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_mid)); - h_last_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_last)); - matches = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_first_vec.ymm, n_first_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_mid_vec.ymm, n_mid_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_last_vec.ymm, n_last_vec.ymm)); - while (matches) { - int potential_offset = sz_u32_clz(matches); - if (sz_equal(h + h_length - n_length - potential_offset, n, n_length)) - return h + h_length - n_length - potential_offset; - matches &= ~(1 << (31 - potential_offset)); - } + // Dispatch the actual computation. + if (!bound) { + if (can_be_unicode == sz_true_k) { _wagner_fisher_unbounded(sz_size_t, sz_rune_t); } + else { _wagner_fisher_unbounded(sz_size_t, sz_u8_t); } } - - return sz_rfind_serial(h, h_length, n, n_length); -} - -SZ_PUBLIC sz_cptr_t sz_find_charset_avx2(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { - - // Let's unzip even and odd elements and replicate them into both lanes of the YMM register. - // That way when we invoke `_mm256_shuffle_epi8` we can use the same mask for both lanes. - sz_u256_vec_t filter_even_vec, filter_odd_vec; - for (sz_size_t i = 0; i != 16; ++i) - filter_even_vec.u8s[i] = filter->_u8s[i * 2], filter_odd_vec.u8s[i] = filter->_u8s[i * 2 + 1]; - filter_even_vec.xmms[1] = filter_even_vec.xmms[0]; - filter_odd_vec.xmms[1] = filter_odd_vec.xmms[0]; - - sz_u256_vec_t text_vec; - sz_u256_vec_t matches_vec; - sz_u256_vec_t lower_nibbles_vec, higher_nibbles_vec; - sz_u256_vec_t bitset_even_vec, bitset_odd_vec; - sz_u256_vec_t bitmask_vec, bitmask_lookup_vec; - bitmask_lookup_vec.ymm = _mm256_set_epi8(-128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // - -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1); - - while (length >= 32) { - // The following algorithm is a transposed equivalent of the "SIMDized check which bytes are in a set" - // solutions by Wojciech Muła. We populate the bitmask differently and target newer CPUs, so - // StrinZilla uses a somewhat different approach. - // http://0x80.pl/articles/simd-byte-lookup.html#alternative-implementation-new - // - // sz_u8_t input = *(sz_u8_t const *)text; - // sz_u8_t lo_nibble = input & 0x0f; - // sz_u8_t hi_nibble = input >> 4; - // sz_u8_t bitset_even = filter_even_vec.u8s[hi_nibble]; - // sz_u8_t bitset_odd = filter_odd_vec.u8s[hi_nibble]; - // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); - // sz_u8_t bitset = lo_nibble < 8 ? bitset_even : bitset_odd; - // if ((bitset & bitmask) != 0) return text; - // else { length--, text++; } - // - // The nice part about this, loading the strided data is vey easy with Arm NEON, - // while with x86 CPUs after AVX, shuffles within 256 bits shouldn't be an issue either. 
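/*  Scalar counterpart of the byte-set test that the nibble-shuffle code below vectorizes:
 *  a 256-bit set stored as 32 bytes, where byte `c` is a member when bit `c % 8` of byte
 *  `c / 8` is set. The struct here is illustrative, not the library's `sz_charset_t`. */
#include <stdint.h>

typedef struct { uint8_t bytes[32]; } byte_set_t;

static int byte_set_contains(byte_set_t const *set, uint8_t c) {
    return (set->bytes[c >> 3] >> (c & 7)) & 1;
}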
- text_vec.ymm = _mm256_lddqu_si256((__m256i const *)text); - lower_nibbles_vec.ymm = _mm256_and_si256(text_vec.ymm, _mm256_set1_epi8(0x0f)); - bitmask_vec.ymm = _mm256_shuffle_epi8(bitmask_lookup_vec.ymm, lower_nibbles_vec.ymm); - // - // At this point we can validate the `bitmask_vec` contents like this: - // - // for (sz_size_t i = 0; i != 32; ++i) { - // sz_u8_t input = *(sz_u8_t const *)(text + i); - // sz_u8_t lo_nibble = input & 0x0f; - // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); - // sz_assert(bitmask_vec.u8s[i] == bitmask); - // } - // - // Shift right every byte by 4 bits. - // There is no `_mm256_srli_epi8` intrinsic, so we have to use `_mm256_srli_epi16` - // and combine it with a mask to clear the higher bits. - higher_nibbles_vec.ymm = _mm256_and_si256(_mm256_srli_epi16(text_vec.ymm, 4), _mm256_set1_epi8(0x0f)); - bitset_even_vec.ymm = _mm256_shuffle_epi8(filter_even_vec.ymm, higher_nibbles_vec.ymm); - bitset_odd_vec.ymm = _mm256_shuffle_epi8(filter_odd_vec.ymm, higher_nibbles_vec.ymm); - // - // At this point we can validate the `bitset_even_vec` and `bitset_odd_vec` contents like this: - // - // for (sz_size_t i = 0; i != 32; ++i) { - // sz_u8_t input = *(sz_u8_t const *)(text + i); - // sz_u8_t const *bitset_ptr = &filter->_u8s[0]; - // sz_u8_t hi_nibble = input >> 4; - // sz_u8_t bitset_even = bitset_ptr[hi_nibble * 2]; - // sz_u8_t bitset_odd = bitset_ptr[hi_nibble * 2 + 1]; - // sz_assert(bitset_even_vec.u8s[i] == bitset_even); - // sz_assert(bitset_odd_vec.u8s[i] == bitset_odd); - // } - // - __m256i take_first = _mm256_cmpgt_epi8(_mm256_set1_epi8(8), lower_nibbles_vec.ymm); - bitset_even_vec.ymm = _mm256_blendv_epi8(bitset_odd_vec.ymm, bitset_even_vec.ymm, take_first); - - // It would have been great to have an instruction that tests the bits and then broadcasts - // the matching bit into all bits in that byte. But we don't have that, so we have to - // `and`, `cmpeq`, `movemask`, and then invert at the end... - matches_vec.ymm = _mm256_and_si256(bitset_even_vec.ymm, bitmask_vec.ymm); - matches_vec.ymm = _mm256_cmpeq_epi8(matches_vec.ymm, _mm256_setzero_si256()); - int matches_mask = ~_mm256_movemask_epi8(matches_vec.ymm); - if (matches_mask) { - int offset = sz_u32_ctz(matches_mask); - return text + offset; - } - else { text += 32, length -= 32; } + else { + if (can_be_unicode == sz_true_k) { _wagner_fisher_bounded(sz_size_t, sz_rune_t); } + else { _wagner_fisher_bounded(sz_size_t, sz_u8_t); } } - - return sz_find_charset_serial(text, length, filter); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx2(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { - return sz_rfind_charset_serial(text, length, filter); -} - -/** - * @brief There is no AVX2 instruction for fast multiplication of 64-bit integers. - * This implementation is coming from Agner Fog's Vector Class Library. 
- */ -SZ_INTERNAL __m256i _mm256_mul_epu64(__m256i a, __m256i b) { - __m256i bswap = _mm256_shuffle_epi32(b, 0xB1); - __m256i prodlh = _mm256_mullo_epi32(a, bswap); - __m256i zero = _mm256_setzero_si256(); - __m256i prodlh2 = _mm256_hadd_epi32(prodlh, zero); - __m256i prodlh3 = _mm256_shuffle_epi32(prodlh2, 0x73); - __m256i prodll = _mm256_mul_epu32(a, b); - __m256i prod = _mm256_add_epi64(prodll, prodlh3); - return prod; } -SZ_PUBLIC void sz_hashes_avx2(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle) { - - if (length < window_length || !window_length) return; - if (length < 4 * window_length) { - sz_hashes_serial(start, length, window_length, step, callback, callback_handle); - return; - } - - // Using AVX2, we can perform 4 long integer multiplications and additions within one register. - // So let's slice the entire string into 4 overlapping windows, to slide over them in parallel. - sz_size_t const max_hashes = length - window_length + 1; - sz_size_t const min_hashes_per_thread = max_hashes / 4; // At most one sequence can overlap between 2 threads. - sz_u8_t const *text_first = (sz_u8_t const *)start; - sz_u8_t const *text_second = text_first + min_hashes_per_thread; - sz_u8_t const *text_third = text_first + min_hashes_per_thread * 2; - sz_u8_t const *text_fourth = text_first + min_hashes_per_thread * 3; - sz_u8_t const *text_end = text_first + length; - - // Prepare the `prime ^ window_length` values, that we are going to use for modulo arithmetic. - sz_u64_t prime_power_low = 1, prime_power_high = 1; - for (sz_size_t i = 0; i + 1 < window_length; ++i) - prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME, - prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME; - - // Broadcast the constants into the registers. - sz_u256_vec_t prime_vec, golden_ratio_vec; - sz_u256_vec_t base_low_vec, base_high_vec, prime_power_low_vec, prime_power_high_vec, shift_high_vec; - base_low_vec.ymm = _mm256_set1_epi64x(31ull); - base_high_vec.ymm = _mm256_set1_epi64x(257ull); - shift_high_vec.ymm = _mm256_set1_epi64x(77ull); - prime_vec.ymm = _mm256_set1_epi64x(SZ_U64_MAX_PRIME); - golden_ratio_vec.ymm = _mm256_set1_epi64x(11400714819323198485ull); - prime_power_low_vec.ymm = _mm256_set1_epi64x(prime_power_low); - prime_power_high_vec.ymm = _mm256_set1_epi64x(prime_power_high); - - // Compute the initial hash values for every one of the four windows. - sz_u256_vec_t hash_low_vec, hash_high_vec, hash_mix_vec, chars_low_vec, chars_high_vec; - hash_low_vec.ymm = _mm256_setzero_si256(); - hash_high_vec.ymm = _mm256_setzero_si256(); - for (sz_u8_t const *prefix_end = text_first + window_length; text_first < prefix_end; - ++text_first, ++text_second, ++text_third, ++text_fourth) { - - // 1. Multiply the hashes by the base. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, base_low_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, base_high_vec.ymm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. - chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); - - // 3. Add the incoming characters. 
- hash_low_vec.ymm = _mm256_add_epi64(hash_low_vec.ymm, chars_low_vec.ymm); - hash_high_vec.ymm = _mm256_add_epi64(hash_high_vec.ymm, chars_high_vec.ymm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_low_vec.ymm = _mm256_blendv_epi8(hash_low_vec.ymm, _mm256_sub_epi64(hash_low_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_low_vec.ymm, prime_vec.ymm)); - hash_high_vec.ymm = _mm256_blendv_epi8(hash_high_vec.ymm, _mm256_sub_epi64(hash_high_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_high_vec.ymm, prime_vec.ymm)); - } +SZ_PUBLIC sz_size_t sz_edit_distance_serial( // + sz_cptr_t longer, sz_size_t longer_length, // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_size_t bound, sz_memory_allocator_t *alloc) { - // 5. Compute the hash mix, that will be used to index into the fingerprint. - // This includes a serial step at the end. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, golden_ratio_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, golden_ratio_vec.ymm); - hash_mix_vec.ymm = _mm256_xor_si256(hash_low_vec.ymm, hash_high_vec.ymm); - callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); - callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); - callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); - callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); - - // Now repeat that operation for the remaining characters, discarding older characters. - sz_size_t cycle = 1; - sz_size_t const step_mask = step - 1; - for (; text_fourth != text_end; ++text_first, ++text_second, ++text_third, ++text_fourth, ++cycle) { - // 0. Load again the four characters we are dropping, shift them, and subtract. - chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[-window_length], text_third[-window_length], - text_second[-window_length], text_first[-window_length]); - chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); - hash_low_vec.ymm = - _mm256_sub_epi64(hash_low_vec.ymm, _mm256_mul_epu64(chars_low_vec.ymm, prime_power_low_vec.ymm)); - hash_high_vec.ymm = - _mm256_sub_epi64(hash_high_vec.ymm, _mm256_mul_epu64(chars_high_vec.ymm, prime_power_high_vec.ymm)); - - // 1. Multiply the hashes by the base. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, base_low_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, base_high_vec.ymm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. - chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); - - // 3. Add the incoming characters. - hash_low_vec.ymm = _mm256_add_epi64(hash_low_vec.ymm, chars_low_vec.ymm); - hash_high_vec.ymm = _mm256_add_epi64(hash_high_vec.ymm, chars_high_vec.ymm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. 
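/*  A simplified scalar sketch of the rolling ("Rabin-Karp style") update the loop above
 *  performs per window: drop the oldest character, scale by the base, add the newest one.
 *  For brevity this variant lets 64-bit arithmetic wrap instead of reducing by the near-2^64
 *  prime that StringZilla uses, so the values differ from the library's hashes; the callback
 *  signature is hypothetical. */
#include <stddef.h>
#include <stdint.h>

static void rolling_hashes(uint8_t const *text, size_t length, size_t window,
                           void (*callback)(size_t offset, uint64_t hash, void *handle), void *handle) {
    uint64_t const base = 31ull;
    uint64_t hash = 0, top_power = 1;
    if (length < window || window == 0) return;
    for (size_t i = 0; i + 1 < window; ++i) top_power *= base;  // base^(window - 1)
    for (size_t i = 0; i != window; ++i) hash = hash * base + text[i];
    callback(0, hash, handle);
    for (size_t i = window; i != length; ++i) {
        hash = (hash - text[i - window] * top_power) * base + text[i];  // Slide by one byte.
        callback(i - window + 1, hash, handle);
    }
}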
- hash_low_vec.ymm = _mm256_blendv_epi8(hash_low_vec.ymm, _mm256_sub_epi64(hash_low_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_low_vec.ymm, prime_vec.ymm)); - hash_high_vec.ymm = _mm256_blendv_epi8(hash_high_vec.ymm, _mm256_sub_epi64(hash_high_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_high_vec.ymm, prime_vec.ymm)); - - // 5. Compute the hash mix, that will be used to index into the fingerprint. - // This includes a serial step at the end. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, golden_ratio_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, golden_ratio_vec.ymm); - hash_mix_vec.ymm = _mm256_xor_si256(hash_low_vec.ymm, hash_high_vec.ymm); - if ((cycle & step_mask) == 0) { - callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); - callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); - callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); - callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); - } + // Let's make sure that we use the amount proportional to the + // number of elements in the shorter string, not the larger. + if (shorter_length > longer_length) { + sz_pointer_swap((void **)&longer_length, (void **)&shorter_length); + sz_pointer_swap((void **)&longer, (void **)&shorter); } -} - -#pragma clang attribute pop -#pragma GCC pop_options -#endif -#pragma endregion - -/* - * @brief AVX-512 implementation of the string search algorithms. - * - * Different subsets of AVX-512 were introduced in different years: - * - 2017 SkyLake: F, CD, ER, PF, VL, DQ, BW - * - 2018 CannonLake: IFMA, VBMI - * - 2019 IceLake: VPOPCNTDQ, VNNI, VBMI2, BITALG, GFNI, VPCLMULQDQ, VAES - * - 2020 TigerLake: VP2INTERSECT - */ -#pragma region AVX512 Implementation - -#if SZ_USE_X86_AVX512 -#pragma GCC push_options -#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "bmi", "bmi2") -#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,bmi,bmi2"))), apply_to = function) -#include - -/** - * @brief Helper structure to simplify work with 512-bit registers. - */ -typedef union sz_u512_vec_t { - __m512i zmm; - __m256i ymms[2]; - __m128i xmms[4]; - sz_u64_t u64s[8]; - sz_u32_t u32s[16]; - sz_u16_t u16s[32]; - sz_u8_t u8s[64]; - sz_i64_t i64s[8]; - sz_i32_t i32s[16]; -} sz_u512_vec_t; - -SZ_INTERNAL __mmask64 _sz_u64_clamp_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 64: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 64: - return _bzhi_u64(0xFFFFFFFFFFFFFFFF, n < 64 ? (sz_u32_t)n : 64); -} - -SZ_INTERNAL __mmask32 _sz_u32_clamp_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 32: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 32: - return _bzhi_u32(0xFFFFFFFF, n < 32 ? (sz_u32_t)n : 32); -} - -SZ_INTERNAL __mmask16 _sz_u16_clamp_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 16: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 16: - return _bzhi_u32(0xFFFFFFFF, n < 16 ? 
(sz_u32_t)n : 16); -} - -SZ_INTERNAL __mmask16 _sz_u16_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 16: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 16: - return (__mmask16)_bzhi_u32(0xFFFFFFFF, (sz_u32_t)n); -} - -SZ_INTERNAL __mmask32 _sz_u32_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 32: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 32: - return _bzhi_u32(0xFFFFFFFF, (sz_u32_t)n); -} - -SZ_INTERNAL __mmask64 _sz_u64_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 64: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 64: - return _bzhi_u64(0xFFFFFFFFFFFFFFFF, (sz_u32_t)n); -} -SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { - sz_u512_vec_t a_vec, b_vec; - - // Pointer arithmetic is cheap, fetching memory is not! - // So we can use the masked loads to fetch at most one cache-line for each string, - // compare the prefixes, and only then move forward. - sz_size_t a_head_length = 64 - ((sz_size_t)a % 64); // 63 or less. - sz_size_t b_head_length = 64 - ((sz_size_t)b % 64); // 63 or less. - a_head_length = a_head_length < a_length ? a_head_length : a_length; - b_head_length = b_head_length < b_length ? b_head_length : b_length; - sz_size_t head_length = a_head_length < b_head_length ? a_head_length : b_head_length; - __mmask64 head_mask = _sz_u64_mask_until(head_length); - a_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, a); - b_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, b); - __mmask64 mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); - if (mask_not_equal != 0) { - sz_u64_t first_diff = _tzcnt_u64(mask_not_equal); - char a_char = a_vec.u8s[first_diff]; - char b_char = b_vec.u8s[first_diff]; - return _sz_order_scalars(a_char, b_char); - } - else if (head_length == a_length && head_length == b_length) { return sz_equal_k; } - else { a += head_length, b += head_length, a_length -= head_length, b_length -= head_length; } - - // The rare case, when both string are very long. - __mmask64 a_mask, b_mask; - while ((a_length >= 64) & (b_length >= 64)) { - a_vec.zmm = _mm512_loadu_si512(a); - b_vec.zmm = _mm512_loadu_si512(b); - mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); - if (mask_not_equal != 0) { - sz_u64_t first_diff = _tzcnt_u64(mask_not_equal); - char a_char = a_vec.u8s[first_diff]; - char b_char = b_vec.u8s[first_diff]; - return _sz_order_scalars(a_char, b_char); - } - a += 64, b += 64, a_length -= 64, b_length -= 64; - } + // Skip the matching prefixes and suffixes, they won't affect the distance. + for (sz_cptr_t a_end = longer + longer_length, b_end = shorter + shorter_length; + longer != a_end && shorter != b_end && *longer == *shorter; + ++longer, ++shorter, --longer_length, --shorter_length); + for (; longer_length && shorter_length && longer[longer_length - 1] == shorter[shorter_length - 1]; + --longer_length, --shorter_length); - // In most common scenarios at least one of the strings is under 64 bytes. 
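The prefix- and suffix-trimming loops added above are worth restating in isolation, since they are what makes the later bounded checks so cheap. The sketch below is a hedged stand-alone version with illustrative names; it trims the shared ends and returns the length difference, which is a lower bound on the edit distance of what remains.

    #include <stddef.h>

    // Trim shared prefixes/suffixes, then return |length difference|,
    // which can never exceed the true edit distance of the remainder.
    static size_t edit_distance_lower_bound(char const *a, size_t a_len, char const *b, size_t b_len) {
        while (a_len && b_len && *a == *b) ++a, ++b, --a_len, --b_len;           // Shared prefix.
        while (a_len && b_len && a[a_len - 1] == b[b_len - 1]) --a_len, --b_len; // Shared suffix.
        return a_len > b_len ? a_len - b_len : b_len - a_len;
    }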
- if (a_length | b_length) { - a_mask = _sz_u64_clamp_mask_until(a_length); - b_mask = _sz_u64_clamp_mask_until(b_length); - a_vec.zmm = _mm512_maskz_loadu_epi8(a_mask, a); - b_vec.zmm = _mm512_maskz_loadu_epi8(b_mask, b); - // The AVX-512 `_mm512_mask_cmpneq_epi8_mask` intrinsics are generally handy in such environments. - // They, however, have latency 3 on most modern CPUs. Using AVX2: `_mm256_cmpeq_epi8` would have - // been cheaper, if we didn't have to apply `_mm256_movemask_epi8` afterwards. - mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); - if (mask_not_equal != 0) { - sz_u64_t first_diff = _tzcnt_u64(mask_not_equal); - char a_char = a_vec.u8s[first_diff]; - char b_char = b_vec.u8s[first_diff]; - return _sz_order_scalars(a_char, b_char); - } - // From logic perspective, the hardest cases are "abc\0" and "abc". - // The result must be `sz_greater_k`, as the latter is shorter. - else { return _sz_order_scalars(a_length, b_length); } + // Bounded computations may exit early. + int const is_bounded = bound < longer_length; + if (is_bounded) { + // If one of the strings is empty - the edit distance is equal to the length of the other one. + if (longer_length == 0) return sz_min_of_two(shorter_length, bound); + if (shorter_length == 0) return sz_min_of_two(longer_length, bound); + // If the difference in length is beyond the `bound`, there is no need to check at all. + if (longer_length - shorter_length > bound) return bound; } - return sz_equal_k; + if (shorter_length == 0) return longer_length; // If no mismatches were found - the distance is zero. + if (shorter_length == longer_length && !is_bounded) + return _sz_edit_distance_skewed_diagonals_serial(longer, longer_length, shorter, shorter_length, bound, alloc); + return _sz_edit_distance_wagner_fisher_serial( // + longer, longer_length, shorter, shorter_length, bound, sz_false_k, alloc); } -SZ_PUBLIC sz_bool_t sz_equal_avx512(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - __mmask64 mask; - sz_u512_vec_t a_vec, b_vec; - - while (length >= 64) { - a_vec.zmm = _mm512_loadu_si512(a); - b_vec.zmm = _mm512_loadu_si512(b); - mask = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); - if (mask != 0) return sz_false_k; - a += 64, b += 64, length -= 64; - } - - if (length) { - mask = _sz_u64_mask_until(length); - a_vec.zmm = _mm512_maskz_loadu_epi8(mask, a); - b_vec.zmm = _mm512_maskz_loadu_epi8(mask, b); - // Reuse the same `mask` variable to find the bit that doesn't match - mask = _mm512_mask_cmpneq_epi8_mask(mask, a_vec.zmm, b_vec.zmm); - return (sz_bool_t)(mask == 0); - } +SZ_PUBLIC sz_ssize_t sz_alignment_score_serial( // + sz_cptr_t longer, sz_size_t longer_length, // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_error_cost_t const *subs, sz_error_cost_t gap, // + sz_memory_allocator_t *alloc) { - return sz_true_k; -} + // If one of the strings is empty - the edit distance is equal to the length of the other one + if (longer_length == 0) return (sz_ssize_t)shorter_length * gap; + if (shorter_length == 0) return (sz_ssize_t)longer_length * gap; -SZ_PUBLIC void sz_fill_avx512(sz_ptr_t target, sz_size_t length, sz_u8_t value) { - __m512i value_vec = _mm512_set1_epi8(value); - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "stores". 
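For readers who have not seen the Wagner-Fischer fallback named above, here is a hedged, two-row scalar sketch of that recurrence. It ignores the `bound` early exits and the custom allocator plumbing, and the helper name is illustrative rather than part of the library.

    #include <stddef.h>
    #include <stdlib.h>

    // Classic Levenshtein distance with two rolling rows;
    // O(a_len * b_len) time, O(b_len) space. Allocation checks omitted for brevity.
    static size_t levenshtein_two_rows(char const *a, size_t a_len, char const *b, size_t b_len) {
        size_t *prev = (size_t *)malloc((b_len + 1) * sizeof(size_t));
        size_t *curr = (size_t *)malloc((b_len + 1) * sizeof(size_t));
        for (size_t j = 0; j <= b_len; ++j) prev[j] = j;
        for (size_t i = 0; i < a_len; ++i) {
            curr[0] = i + 1;
            for (size_t j = 0; j < b_len; ++j) {
                size_t substitution = prev[j] + (a[i] != b[j]);
                size_t deletion = prev[j + 1] + 1, insertion = curr[j] + 1;
                size_t best = substitution < deletion ? substitution : deletion;
                curr[j + 1] = best < insertion ? best : insertion;
            }
            size_t *swap = prev;
            prev = curr, curr = swap;
        }
        size_t result = prev[b_len];
        free(prev), free(curr);
        return result;
    }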
- // - // for (; length >= 64; target += 64, length -= 64) _mm512_storeu_si512(target, value_vec); - // _mm512_mask_storeu_epi8(target, _sz_u64_mask_until(length), value_vec); - // - // When the buffer is small, there isn't much to innovate. - if (length <= 64) { - __mmask64 mask = _sz_u64_mask_until(length); - _mm512_mask_storeu_epi8(target, mask, value_vec); - } - // When the buffer is over 64 bytes, it's guaranteed to touch at least two cache lines - the head and tail, - // and may include more cache-lines in-between. Knowing this, we can avoid expensive unaligned stores - // by computing 2 masks - for the head and tail, using masked stores for the head and tail, and unmasked - // for the body. - else { - sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 64. - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - _mm512_mask_storeu_epi8(target, head_mask, value_vec); - for (target += head_length; body_length >= 64; target += 64, body_length -= 64) - _mm512_store_si512(target, value_vec); - _mm512_mask_storeu_epi8(target, tail_mask, value_vec); + // Let's make sure that we use the amount proportional to the + // number of elements in the shorter string, not the larger. + if (shorter_length > longer_length) { + sz_pointer_swap((void **)&longer_length, (void **)&shorter_length); + sz_pointer_swap((void **)&longer, (void **)&shorter); } -} -SZ_PUBLIC void sz_copy_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "stores" and "loads". - // - // for (; length >= 64; target += 64, source += 64, length -= 64) - // _mm512_storeu_si512(target, _mm512_loadu_si512(source)); - // __mmask64 mask = _sz_u64_mask_until(length); - // _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source)); - // - // A typical AWS Sapphire Rapids instance can have 48 KB x 2 blocks of L1 data cache per core, - // 2 MB x 2 blocks of L2 cache per core, and one shared 60 MB buffer of L3 cache. - // With two strings, we may consider the overal workload huge, if each exceeds 1 MB in length. - int const is_huge = length >= 1ull * 1024ull * 1024ull; - - // When the buffer is small, there isn't much to innovate. - if (length <= 64) { - __mmask64 mask = _sz_u64_mask_until(length); - _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source)); - } - // When dealing wirh larger arrays, the optimization is not as simple as with the `sz_fill_avx512` function, - // as both buffers may be unaligned. If we are lucky and the requested operation is some huge page transfer, - // we can use aligned loads and stores, and the performance will be great. - else if ((sz_size_t)target % 64 == 0 && (sz_size_t)source % 64 == 0 && !is_huge) { - for (; length >= 64; target += 64, source += 64, length -= 64) - _mm512_store_si512(target, _mm512_load_si512(source)); - // At this point the length is guaranteed to be under 64. - __mmask64 mask = _sz_u64_mask_until(length); - // Aligned load and stores would work too, but it's not defined. - _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source)); - } - // The trickiest case is when both `source` and `target` are not aligned. 
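The head/body/tail decomposition used by the masked `fill`, `copy`, and `move` kernels above follows one small computation that is easy to get wrong, so a hedged scalar sketch of it may help. It assumes the buffer is longer than one cache line, which is exactly the branch the vectorized code reserves it for.

    #include <stddef.h>
    #include <stdint.h>

    typedef struct { size_t head, body, tail; } cache_line_split_t;

    // Split a destination of `length > 64` bytes into an unaligned head,
    // a 64-byte-aligned body (a multiple of 64), and an unaligned tail.
    static cache_line_split_t split_by_cache_lines(void const *target, size_t length) {
        cache_line_split_t split;
        split.head = (64 - ((uintptr_t)target % 64)) % 64; // 0..63 bytes until the next cache line.
        split.tail = ((uintptr_t)target + length) % 64;    // 0..63 bytes past the last full line.
        split.body = length - split.head - split.tail;     // Whole cache lines in-between.
        return split;
    }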
- // In such and simpler cases we can copy enough bytes into `target` to reach its cacheline boundary, - // and then combine unaligned loads with aligned stores. - else if (!is_huge) { - sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 64. - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source)); - for (target += head_length, source += head_length; body_length >= 64; - target += 64, source += 64, body_length -= 64) - _mm512_store_si512(target, _mm512_loadu_si512(source)); // Unaligned load, but aligned store! - _mm512_mask_storeu_epi8(target, tail_mask, _mm512_maskz_loadu_epi8(tail_mask, source)); - } - // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use. - // - // 1. Moving in both directions to maximize the throughput, when fetching from multiple - // memory pages. Also helps with cache set-associativity issues, as we won't always - // be fetching the same entries in the lookup table. - // 2. Using non-temporal stores to avoid polluting the cache. - // 3. Prefetching the next cache line, to avoid stalling the CPU. This generally useless - // for predictable patterns, so disregard this advice. - // - // Bidirectional traversal adds about 10%, accelerating from 11 GB/s to 12 GB/s. - // Using "streaming stores" boosts us from 12 GB/s to 19 GB/s. - else { - sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; - sz_size_t tail_length = (sz_size_t)(target + length) % 64; - sz_size_t body_length = length - head_length - tail_length; - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source)); - _mm512_mask_storeu_epi8(target + head_length + body_length, tail_mask, - _mm512_maskz_loadu_epi8(tail_mask, source)); - - // Now in the main loop, we can use non-temporal loads and stores, - // performing the operation in both directions. - for (target += head_length, source += head_length; // - body_length >= 128; // - target += 64, source += 64, body_length -= 128) { - _mm512_stream_si512((__m512i *)(target), _mm512_loadu_si512(source)); - _mm512_stream_si512((__m512i *)(target + body_length - 64), _mm512_loadu_si512(source + body_length - 64)); - } - if (body_length >= 64) _mm512_stream_si512((__m512i *)target, _mm512_loadu_si512(source)); + // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. + sz_memory_allocator_t global_alloc; + if (!alloc) { + sz_memory_allocator_init_default(&global_alloc); + alloc = &global_alloc; } -} -SZ_PUBLIC void sz_move_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - if (target == source) return; // Don't be silly, don't move the data if it's already there. + sz_size_t n = shorter_length + 1; + sz_size_t buffer_length = sizeof(sz_ssize_t) * n * 2; + sz_ssize_t *distances = (sz_ssize_t *)alloc->allocate(buffer_length, alloc->handle); + sz_ssize_t *previous_distances = distances; + sz_ssize_t *current_distances = previous_distances + n; - // On very short buffers, that are one cache line in width or less, we don't need any loops. 
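Since the NULL-allocator fallback above exists specifically to simplify calls from higher-level code, a hedged usage sketch may be useful. It assumes `sz_error_cost_t` is a small signed integer and that the substitution matrix is a flat 256 x 256 table indexed by unsigned byte values, which is what the `subs + byte * 256` indexing in this function implies; the wrapper name is illustrative.

    // Score two strings with unit costs: 0 for a match, -1 for a mismatch or a gap.
    static sz_ssize_t score_with_unit_costs(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) {
        static sz_error_cost_t costs[256 * 256];
        for (sz_size_t i = 0; i != 256; ++i)
            for (sz_size_t j = 0; j != 256; ++j) costs[i * 256 + j] = i == j ? 0 : -1;
        // Passing NULL as the allocator triggers the default-allocator fallback above.
        return sz_alignment_score_serial(a, a_length, b, b_length, costs, -1, NULL);
    }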
- // We can also avoid any data-dependencies between iterations, assuming we have 32 registers - // to pre-load the data, before writing it back. - if (length <= 64) { - __mmask64 mask = _sz_u64_mask_until(length); - _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source)); - } - else if (length <= 128) { - sz_size_t last_length = length - 64; - __mmask64 mask = _sz_u64_mask_until(last_length); - __m512i source0 = _mm512_loadu_epi8(source); - __m512i source1 = _mm512_maskz_loadu_epi8(mask, source + 64); - _mm512_storeu_epi8(target, source0); - _mm512_mask_storeu_epi8(target + 64, mask, source1); - } - else if (length <= 192) { - sz_size_t last_length = length - 128; - __mmask64 mask = _sz_u64_mask_until(last_length); - __m512i source0 = _mm512_loadu_epi8(source); - __m512i source1 = _mm512_loadu_epi8(source + 64); - __m512i source2 = _mm512_maskz_loadu_epi8(mask, source + 128); - _mm512_storeu_epi8(target, source0); - _mm512_storeu_epi8(target + 64, source1); - _mm512_mask_storeu_epi8(target + 128, mask, source2); - } - else if (length <= 256) { - sz_size_t last_length = length - 192; - __mmask64 mask = _sz_u64_mask_until(last_length); - __m512i source0 = _mm512_loadu_epi8(source); - __m512i source1 = _mm512_loadu_epi8(source + 64); - __m512i source2 = _mm512_loadu_epi8(source + 128); - __m512i source3 = _mm512_maskz_loadu_epi8(mask, source + 192); - _mm512_storeu_epi8(target, source0); - _mm512_storeu_epi8(target + 64, source1); - _mm512_storeu_epi8(target + 128, source2); - _mm512_mask_storeu_epi8(target + 192, mask, source3); - } + for (sz_size_t idx_shorter = 0; idx_shorter != n; ++idx_shorter) + previous_distances[idx_shorter] = (sz_ssize_t)idx_shorter * gap; - // If the regions don't overlap at all, just use "copy" and save some brain cells thinking about corner cases. - else if (target + length < source || target >= source + length) { sz_copy_avx512(target, source, length); } + sz_u8_t const *shorter_unsigned = (sz_u8_t const *)shorter; + sz_u8_t const *longer_unsigned = (sz_u8_t const *)longer; + for (sz_size_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) { + current_distances[0] = ((sz_ssize_t)idx_longer + 1) * gap; - // When the buffer is over 64 bytes, it's guaranteed to touch at least two cache lines - the head and tail, - // and may include more cache-lines in-between. Knowing this, we can avoid expensive unaligned stores - // by computing 2 masks - for the head and tail, using masked stores for the head and tail, and unmasked - // for the body. - else { - sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 64. - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - - // The absolute most common case of using "moves" is shifting the data within a continuous buffer - // when adding a removing some values in it. In such cases, a typical shift is by 1, 2, 4, 8, 16, - // or 32 bytes, rarely larger. For small shifts, under the size of the ZMM register, we can use shuffles. - // - // Remember: - // - if we are shifting data left, that we are traversing to the right. - // - if we are shifting data right, that we are traversing to the left. - int const left_to_right_traversal = source > target; - - // Now we guarantee, that the relative shift within registers is from 1 to 63 bytes and the output is aligned. 
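The traversal-direction rule spelled out in the comments above is the same one any overlapping move must follow; a hedged scalar `memmove` equivalent makes it explicit.

    #include <stddef.h>

    // Copy possibly-overlapping ranges: go left-to-right when the destination
    // starts before the source, right-to-left otherwise, so no byte is
    // overwritten before it has been read.
    static void move_bytes(unsigned char *target, unsigned char const *source, size_t length) {
        if (target == source) return;
        if (target < source)
            for (size_t i = 0; i != length; ++i) target[i] = source[i];
        else
            for (size_t i = length; i != 0; --i) target[i - 1] = source[i - 1];
    }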
-        // Hopefully, we need to shift more than two ZMM registers, so we could consider `valignr` instruction.
-        // Sadly, using `_mm512_alignr_epi8` doesn't make sense, as it operates at a 128-bit granularity.
-        //
-        //    - `_mm256_alignr_epi8` shifts entire 256-bit register, but we need many of them.
-        //    - `_mm512_alignr_epi32` shifts 512-bit chunks, but only if the `shift` is a multiple of 4 bytes.
-        //    - `_mm512_alignr_epi64` shifts 512-bit chunks by 8 bytes.
-        //
-        // All of those have a latency of 1 cycle, and the shift amount must be an immediate value!
-        // For 1-byte-shift granularity, the `_mm512_permutex2var_epi8` has a latency of 6 and needs VBMI!
-        // The most efficient and broadly compatible alternative could be to use a combination of align and shuffle.
-        // A similar approach was outlined in "Byte-wise alignr in AVX512F" by Wojciech Muła.
-        // http://0x80.pl/notesen/2016-10-16-avx512-byte-alignr.html
-        //
-        // That solution, is extremely mouthful, assuming we need compile time constants for the shift amount.
-        // A cleaner one, with a latency of 3 cycles, is to use `_mm512_permutexvar_epi8` or
-        // `_mm512_mask_permutexvar_epi8`, which can be seen as combination of a cross-register shuffle and blend,
-        // and is available with VBMI. That solution is still noticeably slower than AVX2.
-        //
-        // The GLibC implementation also uses non-temporal stores for larger buffers, we don't.
-        // https://codebrowser.dev/glibc/glibc/sysdeps/x86_64/multiarch/memmove-avx512-no-vzeroupper.S.html
-        if (left_to_right_traversal) {
-            // Head, body, and tail.
-            _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source));
-            for (target += head_length, source += head_length; body_length >= 64;
-                 target += 64, source += 64, body_length -= 64)
-                _mm512_store_si512(target, _mm512_loadu_si512(source));
-            _mm512_mask_storeu_epi8(target, tail_mask, _mm512_maskz_loadu_epi8(tail_mask, source));
-        }
-        else {
-            // Tail, body, and head.
-            _mm512_mask_storeu_epi8(target + head_length + body_length, tail_mask,
-                                    _mm512_maskz_loadu_epi8(tail_mask, source + head_length + body_length));
-            for (; body_length >= 64; body_length -= 64)
-                _mm512_store_si512(target + head_length + body_length - 64,
-                                   _mm512_loadu_si512(source + head_length + body_length - 64));
-            _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source));
+        // Pre-select the substitution costs row for the current character of the longer string.
+        sz_error_cost_t const *a_subs = subs + longer_unsigned[idx_longer] * 256ul;
+        for (sz_size_t idx_shorter = 0; idx_shorter != shorter_length; ++idx_shorter) {
+            sz_ssize_t cost_deletion = previous_distances[idx_shorter + 1] + gap;
+            sz_ssize_t cost_insertion = current_distances[idx_shorter] + gap;
+            sz_ssize_t cost_substitution = previous_distances[idx_shorter] + a_subs[shorter_unsigned[idx_shorter]];
+            current_distances[idx_shorter + 1] = sz_max_of_three(cost_deletion, cost_insertion, cost_substitution);
         }
+
+        // Swap previous_distances and current_distances pointers
+        sz_pointer_swap((void **)&previous_distances, (void **)&current_distances);
     }
+
+    // Cache scalar before `free` call.
+ sz_ssize_t result = previous_distances[shorter_length]; + alloc->free(distances, buffer_length, alloc->handle); + return result; } -SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - __mmask64 mask; - sz_u512_vec_t h_vec, n_vec; - n_vec.zmm = _mm512_set1_epi8(n[0]); +SZ_PUBLIC sz_size_t sz_hamming_distance_serial( // + sz_cptr_t a, sz_size_t a_length, // + sz_cptr_t b, sz_size_t b_length, // + sz_size_t bound) { - while (h_length >= 64) { - h_vec.zmm = _mm512_loadu_si512(h); - mask = _mm512_cmpeq_epi8_mask(h_vec.zmm, n_vec.zmm); - if (mask) return h + sz_u64_ctz(mask); - h += 64, h_length -= 64; - } + sz_size_t const min_length = sz_min_of_two(a_length, b_length); + sz_size_t const max_length = sz_max_of_two(a_length, b_length); + sz_cptr_t const a_end = a + min_length; + bound = bound == 0 ? max_length : bound; - if (h_length) { - mask = _sz_u64_mask_until(h_length); - h_vec.zmm = _mm512_maskz_loadu_epi8(mask, h); - // Reuse the same `mask` variable to find the bit that doesn't match - mask = _mm512_mask_cmpeq_epu8_mask(mask, h_vec.zmm, n_vec.zmm); - if (mask) return h + sz_u64_ctz(mask); + // Walk through both strings using SWAR and counting the number of differing characters. + sz_size_t distance = max_length - min_length; +#if SZ_USE_MISALIGNED_LOADS && !_SZ_IS_BIG_ENDIAN + if (min_length >= SZ_SWAR_THRESHOLD) { + sz_u64_vec_t a_vec, b_vec, match_vec; + for (; a + 8 <= a_end && distance < bound; a += 8, b += 8) { + a_vec.u64 = sz_u64_load(a).u64; + b_vec.u64 = sz_u64_load(b).u64; + match_vec = _sz_u64_each_byte_equal(a_vec, b_vec); + distance += sz_u64_popcount((~match_vec.u64) & 0x8080808080808080ull); + } } +#endif - return SZ_NULL_CHAR; + for (; a != a_end && distance < bound; ++a, ++b) { distance += (*a != *b); } + return sz_min_of_two(distance, bound); } -SZ_PUBLIC sz_cptr_t sz_find_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_find_byte_avx512(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into ZMM registers. - __mmask64 matches; - __mmask64 mask; - sz_u512_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; - n_first_vec.zmm = _mm512_set1_epi8(n[offset_first]); - n_mid_vec.zmm = _mm512_set1_epi8(n[offset_mid]); - n_last_vec.zmm = _mm512_set1_epi8(n[offset_last]); - - // Scan through the string. 
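The SWAR loop above hides a neat trick in `_sz_u64_each_byte_equal`, and a hedged scalar re-derivation may help readers who have not met it before: the comparison works on the low seven bits and the top bit of every byte lane separately, so carries never leak between lanes, and one population count then yields the number of equal bytes.

    #include <stdint.h>

    // Count how many of the eight byte lanes of `a` and `b` are equal.
    static unsigned count_matching_bytes(uint64_t a, uint64_t b) {
        uint64_t same_bits = ~(a ^ b); // A lane of all ones means the two bytes are equal.
        // Adding 1 to the low 7 bits overflows into the top bit only when all 7 bits are set.
        uint64_t low7_all_set = ((same_bits & 0x7F7F7F7F7F7F7F7Full) + 0x0101010101010101ull) & 0x8080808080808080ull;
        uint64_t each_equal = low7_all_set & same_bits & 0x8080808080808080ull;
        return (unsigned)__builtin_popcountll(each_equal); // One high bit per fully matching lane.
    }

The Hamming contribution of such an 8-byte chunk is then simply `8 - count_matching_bytes(a, b)`.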
- // We have several optimized versions of the lagorithm for shorter strings, - // but they all mimic the default case for unbounded length needles - if (n_length >= 64) { - for (; h_length >= n_length + 64; h += 64, h_length -= 64) { - h_first_vec.zmm = _mm512_loadu_si512(h + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_ctz(matches); - if (sz_equal_avx512(h + potential_offset, n, n_length)) return h + potential_offset; - matches &= matches - 1; - } +SZ_PUBLIC sz_size_t sz_hamming_distance_utf8_serial( // + sz_cptr_t a, sz_size_t a_length, // + sz_cptr_t b, sz_size_t b_length, // + sz_size_t bound) { + + sz_cptr_t const a_end = a + a_length; + sz_cptr_t const b_end = b + b_length; + sz_size_t distance = 0; - // TODO: If the last character contains a bad byte, we can reposition the start of the next iteration. - // This will be very helpful for very long needles. + sz_rune_t a_rune, b_rune; + sz_rune_length_t a_rune_length, b_rune_length; + + if (bound) { + for (; a < a_end && b < b_end && distance < bound; a += a_rune_length, b += b_rune_length) { + _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); + _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); + distance += (a_rune != b_rune); } - } - // If there are only 2 or 3 characters in the needle, we don't even need the nested loop. - else if (n_length <= 3) { - for (; h_length >= n_length + 64; h += 64, h_length -= 64) { - h_first_vec.zmm = _mm512_loadu_si512(h + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - if (matches) return h + sz_u64_ctz(matches); + // If one string has more runes, we need to go through the tail. + if (distance < bound) { + for (; a < a_end && distance < bound; a += a_rune_length, ++distance) + _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); + + for (; b < b_end && distance < bound; b += b_rune_length, ++distance) + _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); } } - // If the needle is smaller than the size of the ZMM register, we can use masked comparisons - // to avoid the the inner-most nested loop and compare the entire needle against a haystack - // slice in 3 CPU cycles. 
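The bounded UTF-8 loop above leans on a rune extractor to advance by whole code points. As a hedged reminder of what such a helper must minimally do, the sketch below derives the sequence length from the leading byte; validation of continuation bytes is intentionally omitted and the helper name is illustrative.

    #include <stdint.h>

    // Number of bytes in a UTF-8 sequence, judging only by its leading byte.
    static unsigned utf8_rune_length(uint8_t lead) {
        if (lead < 0x80) return 1;           // 0xxxxxxx - ASCII.
        if ((lead & 0xE0) == 0xC0) return 2; // 110xxxxx.
        if ((lead & 0xF0) == 0xE0) return 3; // 1110xxxx.
        if ((lead & 0xF8) == 0xF0) return 4; // 11110xxx.
        return 1;                            // Stray continuation byte - treat as a single byte.
    }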
else { - __mmask64 n_mask = _sz_u64_mask_until(n_length); - sz_u512_vec_t n_full_vec, h_full_vec; - n_full_vec.zmm = _mm512_maskz_loadu_epi8(n_mask, n); - for (; h_length >= n_length + 64; h += 64, h_length -= 64) { - h_first_vec.zmm = _mm512_loadu_si512(h + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_ctz(matches); - h_full_vec.zmm = _mm512_maskz_loadu_epi8(n_mask, h + potential_offset); - if (_mm512_mask_cmpneq_epi8_mask(n_mask, h_full_vec.zmm, n_full_vec.zmm) == 0) - return h + potential_offset; - matches &= matches - 1; - } - } - } - - // The "tail" of the function uses masked loads to process the remaining bytes. - { - mask = _sz_u64_mask_until(h_length - n_length + 1); - h_first_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_first); - h_mid_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_mid); - h_last_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_ctz(matches); - if (n_length <= 3 || sz_equal_avx512(h + potential_offset, n, n_length)) return h + potential_offset; - matches &= matches - 1; + for (; a < a_end && b < b_end; a += a_rune_length, b += b_rune_length) { + _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); + _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); + distance += (a_rune != b_rune); } + // If one string has more runes, we need to go through the tail. + for (; a < a_end; a += a_rune_length, ++distance) _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); + for (; b < b_end; b += b_rune_length, ++distance) _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); } - return SZ_NULL_CHAR; + return distance; } -SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - __mmask64 mask; - sz_u512_vec_t h_vec, n_vec; - n_vec.zmm = _mm512_set1_epi8(n[0]); - - while (h_length >= 64) { - h_vec.zmm = _mm512_loadu_si512(h + h_length - 64); - mask = _mm512_cmpeq_epi8_mask(h_vec.zmm, n_vec.zmm); - if (mask) return h + h_length - 1 - sz_u64_clz(mask); - h_length -= 64; - } - - if (h_length) { - mask = _sz_u64_mask_until(h_length); - h_vec.zmm = _mm512_maskz_loadu_epi8(mask, h); - // Reuse the same `mask` variable to find the bit that doesn't match - mask = _mm512_mask_cmpeq_epu8_mask(mask, h_vec.zmm, n_vec.zmm); - if (mask) return h + 64 - sz_u64_clz(mask) - 1; - } - - return SZ_NULL_CHAR; -} +#pragma endregion // Serial Implementation -SZ_PUBLIC sz_cptr_t sz_rfind_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_rfind_byte_avx512(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into ZMM registers. 
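The removed forward and backward finders share one idea worth restating: compare a few informative needle positions first and verify the full needle only on the rare candidates that pass. A hedged scalar rendition of that filter, with a simplistic stand-in for the anomaly picker, looks like this.

    #include <stddef.h>
    #include <string.h>

    // Find `needle` in `haystack`, rejecting most positions on three characters
    // before paying for a full comparison.
    static char const *find_with_filter(char const *haystack, size_t h_length, //
                                        char const *needle, size_t n_length) {
        if (!n_length || h_length < n_length) return NULL;
        size_t first = 0, mid = n_length / 2, last = n_length - 1; // Stand-in for the anomaly heuristic.
        for (size_t offset = 0; offset + n_length <= h_length; ++offset) {
            char const *candidate = haystack + offset;
            if (candidate[first] != needle[first] || candidate[mid] != needle[mid] || candidate[last] != needle[last])
                continue;                                                     // Cheap three-character rejection.
            if (memcmp(candidate, needle, n_length) == 0) return candidate;   // Full verification.
        }
        return NULL;
    }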
- __mmask64 mask; - __mmask64 matches; - sz_u512_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; - n_first_vec.zmm = _mm512_set1_epi8(n[offset_first]); - n_mid_vec.zmm = _mm512_set1_epi8(n[offset_mid]); - n_last_vec.zmm = _mm512_set1_epi8(n[offset_last]); - - // Scan through the string. - sz_cptr_t h_reversed; - for (; h_length >= n_length + 64; h_length -= 64) { - h_reversed = h + h_length - n_length - 64 + 1; - h_first_vec.zmm = _mm512_loadu_si512(h_reversed + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h_reversed + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h_reversed + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_clz(matches); - if (n_length <= 3 || sz_equal_avx512(h + h_length - n_length - potential_offset, n, n_length)) - return h + h_length - n_length - potential_offset; - sz_assert((matches & ((sz_u64_t)1 << (63 - potential_offset))) != 0 && - "The bit must be set before we squash it"); - matches &= ~((sz_u64_t)1 << (63 - potential_offset)); - } - } +/* AVX2 implementation of the string similarity algorithms for Haswell processors and newer. + * Very minimalistic (compared to AVX-512), but still faster than the serial implementation. + */ +#pragma region Haswell Implementation +#if SZ_USE_HASWELL +#pragma GCC push_options +#pragma GCC target("haswell") +#pragma clang attribute push(__attribute__((target("haswell"))), apply_to = function) - // The "tail" of the function uses masked loads to process the remaining bytes. - { - mask = _sz_u64_mask_until(h_length - n_length + 1); - h_first_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_first); - h_mid_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_mid); - h_last_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_clz(matches); - if (n_length <= 3 || sz_equal_avx512(h + 64 - potential_offset - 1, n, n_length)) - return h + 64 - potential_offset - 1; - sz_assert((matches & ((sz_u64_t)1 << (63 - potential_offset))) != 0 && - "The bit must be set before we squash it"); - matches &= ~((sz_u64_t)1 << (63 - potential_offset)); - } - } +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SZ_USE_HASWELL +#pragma endregion // Haswell Implementation - return SZ_NULL_CHAR; -} +/* AVX512 implementation of the string similarity algorithms for Skylake and newer CPUs. + * Includes extensions: F, CD, ER, PF, VL, DQ, BW. + * + * This is the "starting level" for the advanced algorithms using K-mask registers on x86. + */ +#pragma region Skylake Implementation +#if SZ_USE_SKYLAKE +#pragma GCC push_options +#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "bmi", "bmi2") +#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,bmi,bmi2"))), apply_to = function) #pragma clang attribute pop #pragma GCC pop_options +#endif // SZ_USE_SKYLAKE +#pragma endregion // Skylake Implementation +/* AVX512 implementation of the string similarity algorithms for Ice Lake and newer CPUs. 
+ * Includes extensions: + * - 2017 Skylake: F, CD, ER, PF, VL, DQ, BW, + * - 2018 CannonLake: IFMA, VBMI, + * - 2019 Ice Lake: VPOPCNTDQ, VNNI, VBMI2, BITALG, GFNI, VPCLMULQDQ, VAES. + */ +#pragma region Ice Lake Implementation +#if SZ_USE_ICE #pragma GCC push_options #pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vbmi", "bmi", "bmi2") #pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512dq,avx512vbmi,bmi,bmi2"))), \ @@ -5317,7 +784,7 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto63_avx512( // * Uses a lot more CPU registers space, than the `upto63` variant. * Benefits from the @b `vpermi2b` instructions, that can rotate the bytes in 2 registers at once. * - * This may be one of the most freuqently called kernels for: + * This may be one of the most frequently called kernels for: * - source code analysis, assuming most lines are either under 80 or under 120 characters long. * - DNA sequence alignment, as most short reads are 50-300 characters long. */ @@ -5378,7 +845,6 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto255bound_avx512( // * Benefits from the @b `valignd` instructions used to rotate UTF-32 unpacked unicode codepoints. * * Each string is unpacked into 128 characters * 4 bytes per character / 64 bytes per register = 8 registers. - * */ SZ_INTERNAL sz_size_t _sz_edit_distance_utf8_skewed_diagonals_upto127_avx512( // sz_cptr_t shorter, sz_size_t shorter_length, // @@ -5439,7 +905,7 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto65k_avx512( // ones_u16_vec.zmm = _mm512_set1_epi16(1); // This is a mixed-precision implementation, using 8-bit representations for part of the operations. - // Even there, in case `SZ_USE_X86_AVX2=0`, let's use the `sz_u512_vec_t` type, addressing the first YMM halfs. + // Even there, in case `SZ_USE_HASWELL=0`, let's use the `sz_u512_vec_t` type, addressing the first YMM halfs. sz_u512_vec_t shorter_vec, longer_vec; sz_u512_vec_t ones_u8_vec; ones_u8_vec.ymms[0] = _mm256_set1_epi8(1); @@ -5527,539 +993,60 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto65k_avx512( // // First get the minimum of insertions and deletions. next_vec.zmm = _mm512_add_epi16(_mm512_min_epu16(insertions_vec.zmm, deletions_vec.zmm), ones_u16_vec.zmm); next_vec.zmm = _mm512_min_epu16(next_vec.zmm, substitutions_vec.zmm); - _mm512_mask_storeu_epi16(next_distances + i, remaining_length_mask, next_vec.zmm); - i += register_length; - } - - // Perform a circular rotation (three-way swap) of those buffers, to reuse the memory, this time, with a shift, - // dropping the first element in the current array. - sz_u16_t *temporary = previous_distances; - previous_distances = current_distances + 1; - current_distances = next_distances; - next_distances = temporary; - } - - // Cache scalar before `free` call. - sz_size_t result = current_distances[0]; - alloc->free(distances, buffer_length, alloc->handle); - return result; -#endif - return 0; -} - -SZ_INTERNAL sz_size_t sz_edit_distance_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - - // Bounded computations may exit early. - int const is_bounded = bound < longer_length; - if (is_bounded) { - // If one of the strings is empty - the edit distance is equal to the length of the other one. 
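The "skewed diagonals" kernels referenced in this region rely on one observation: all cells of the Levenshtein matrix with the same `i + j` depend only on the two previous anti-diagonals, so they can be filled independently and therefore vectorized. A hedged, deliberately scalar sketch of that traversal order, without any bounds or SIMD, is shown below; error handling is omitted.

    #include <stddef.h>
    #include <stdlib.h>

    // Fill the Levenshtein matrix one anti-diagonal (i + j == d) at a time,
    // keeping only the last three diagonals in memory.
    static size_t levenshtein_antidiagonal(char const *a, size_t n, char const *b, size_t m) {
        size_t *prev2 = (size_t *)calloc(n + 1, sizeof(size_t)); // Diagonal d - 2.
        size_t *prev1 = (size_t *)calloc(n + 1, sizeof(size_t)); // Diagonal d - 1.
        size_t *curr = (size_t *)calloc(n + 1, sizeof(size_t));  // Diagonal d.
        for (size_t d = 0; d <= n + m; ++d) {
            size_t i_min = d > m ? d - m : 0, i_max = d < n ? d : n;
            for (size_t i = i_min; i <= i_max; ++i) {
                size_t j = d - i;
                if (i == 0) { curr[i] = j; continue; } // First row: j insertions.
                if (j == 0) { curr[i] = i; continue; } // First column: i deletions.
                size_t deletion = prev1[i - 1] + 1;                          // From cell (i-1, j).
                size_t insertion = prev1[i] + 1;                             // From cell (i, j-1).
                size_t substitution = prev2[i - 1] + (a[i - 1] != b[j - 1]); // From cell (i-1, j-1).
                size_t best = deletion < insertion ? deletion : insertion;
                curr[i] = best < substitution ? best : substitution;
            }
            size_t *recycled = prev2; // Rotate the three diagonal buffers.
            prev2 = prev1, prev1 = curr, curr = recycled;
        }
        size_t result = prev1[n]; // After the last rotation `prev1` holds diagonal n + m.
        free(prev2), free(prev1), free(curr);
        return result;
    }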
- if (longer_length == 0) return sz_min_of_two(shorter_length, bound); - if (shorter_length == 0) return sz_min_of_two(longer_length, bound); - // If the difference in length is beyond the `bound`, there is no need to check at all. - if (longer_length - shorter_length > bound) return bound; - } - - // Make sure the shorter string is actually shorter. - if (shorter_length > longer_length) { - sz_cptr_t temporary = shorter; - shorter = longer; - longer = temporary; - sz_size_t temporary_length = shorter_length; - shorter_length = longer_length; - longer_length = temporary_length; - } - - // Dispatch the right implementation based on the length of the strings. - if (longer_length < 64u) - return _sz_edit_distance_skewed_diagonals_upto63_avx512( // - shorter, shorter_length, longer, longer_length, bound); - // else if (longer_length < 256u * 256u) - // return _sz_edit_distance_skewed_diagonals_upto65k_avx512( // - // shorter, shorter_length, longer, longer_length, bound, alloc); - else - return sz_edit_distance_serial(shorter, shorter_length, longer, longer_length, bound, alloc); -} - -SZ_PUBLIC sz_u64_t sz_checksum_avx512(sz_cptr_t text, sz_size_t length) { - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "loads". - // - // A typical AWS Sapphire Rapids instance can have 48 KB x 2 blocks of L1 data cache per core, - // 2 MB x 2 blocks of L2 cache per core, and one shared 60 MB buffer of L3 cache. - // With two strings, we may consider the overal workload huge, if each exceeds 1 MB in length. - int const is_huge = length >= 1ull * 1024ull * 1024ull; - sz_u512_vec_t text_vec, sums_vec; - - // When the buffer is small, there isn't much to innovate. - if (length <= 16) { - __mmask16 mask = _sz_u16_mask_until(length); - text_vec.xmms[0] = _mm_maskz_loadu_epi8(mask, text); - sums_vec.xmms[0] = _mm_sad_epu8(text_vec.xmms[0], _mm_setzero_si128()); - sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_vec.xmms[0]); - sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_vec.xmms[0], 1); - return low + high; - } - else if (length <= 32) { - __mmask32 mask = _sz_u32_mask_until(length); - text_vec.ymms[0] = _mm256_maskz_loadu_epi8(mask, text); - sums_vec.ymms[0] = _mm256_sad_epu8(text_vec.ymms[0], _mm256_setzero_si256()); - // Accumulating 256 bits is harders, as we need to extract the 128-bit sums first. - __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymms[0]); - __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymms[0], 1); - __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm); - sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm); - sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1); - return low + high; - } - else if (length <= 64) { - __mmask64 mask = _sz_u64_mask_until(length); - text_vec.zmm = _mm512_maskz_loadu_epi8(mask, text); - sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()); - return _mm512_reduce_add_epi64(sums_vec.zmm); - } - else if (!is_huge) { - sz_size_t head_length = (64 - ((sz_size_t)text % 64)) % 64; // 63 or less. - sz_size_t tail_length = (sz_size_t)(text + length) % 64; // 63 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 64. 
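The checksum kernel above reduces 64 bytes at a time with `vpsadbw` against a zero vector. The same idiom is available far below AVX-512; a hedged SSE2-only sketch for a single 16-byte chunk shows the shape of the reduction.

    #include <emmintrin.h>
    #include <stdint.h>

    // Sum 16 unsigned bytes: `psadbw` against zero yields two 64-bit partial sums.
    static uint64_t sum_16_bytes(unsigned char const *data) {
        __m128i chunk = _mm_loadu_si128((__m128i const *)data);
        __m128i partial_sums = _mm_sad_epu8(chunk, _mm_setzero_si128());
        __m128i upper_sum = _mm_unpackhi_epi64(partial_sums, partial_sums); // Move the high half down.
        return (uint64_t)(_mm_cvtsi128_si64(partial_sums) + _mm_cvtsi128_si64(upper_sum));
    }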
- __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - text_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, text); - sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()); - for (text += head_length; body_length >= 64; text += 64, body_length -= 64) { - text_vec.zmm = _mm512_load_si512((__m512i const *)text); - sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512())); - } - text_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, text); - sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512())); - return _mm512_reduce_add_epi64(sums_vec.zmm); - } - // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use. - // - // 1. Moving in both directions to maximize the throughput, when fetching from multiple - // memory pages. Also helps with cache set-associativity issues, as we won't always - // be fetching the same entries in the lookup table. - // 2. Using non-temporal stores to avoid polluting the cache. - // 3. Prefetching the next cache line, to avoid stalling the CPU. This generally useless - // for predictable patterns, so disregard this advice. - // - // Bidirectional traversal generally adds about 10% to such algorithms. - else { - sz_u512_vec_t text_reversed_vec, sums_reversed_vec; - sz_size_t head_length = (64 - ((sz_size_t)text % 64)) % 64; - sz_size_t tail_length = (sz_size_t)(text + length) % 64; - sz_size_t body_length = length - head_length - tail_length; - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - - text_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, text); - sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()); - text_reversed_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, text + head_length + body_length); - sums_reversed_vec.zmm = _mm512_sad_epu8(text_reversed_vec.zmm, _mm512_setzero_si512()); - - // Now in the main loop, we can use non-temporal loads and stores, - // performing the operation in both directions. - for (text += head_length; body_length >= 128; text += 64, text += 64, body_length -= 128) { - text_vec.zmm = _mm512_stream_load_si512((__m512i *)(text)); - sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512())); - text_reversed_vec.zmm = _mm512_stream_load_si512((__m512i *)(text + body_length - 64)); - sums_reversed_vec.zmm = - _mm512_add_epi64(sums_reversed_vec.zmm, _mm512_sad_epu8(text_reversed_vec.zmm, _mm512_setzero_si512())); - } - if (body_length >= 64) { - text_vec.zmm = _mm512_stream_load_si512((__m512i *)(text)); - sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512())); - } - - return _mm512_reduce_add_epi64(_mm512_add_epi64(sums_vec.zmm, sums_reversed_vec.zmm)); - } -} - -SZ_PUBLIC void sz_hashes_avx512(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle) { - - if (length < window_length || !window_length) return; - if (length < 4 * window_length) { - sz_hashes_serial(start, length, window_length, step, callback, callback_handle); - return; - } - - // Using AVX2, we can perform 4 long integer multiplications and additions within one register. - // So let's slice the entire string into 4 overlapping windows, to slide over them in parallel. 
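Slicing the text into four overlapping lanes, as the comment above describes, only takes a handful of pointer adjustments; a hedged sketch with illustrative names makes the overlap explicit. Each lane starts `hashes / 4` window positions after the previous one, so together the four walks cover every window position, with small overlaps when the count is not a multiple of four.

    #include <stddef.h>

    typedef struct { char const *first, *second, *third, *fourth, *end; } four_lanes_t;

    // Pick four starting points so that four parallel sliding windows
    // jointly cover all `length - window_length + 1` positions.
    static four_lanes_t slice_into_four_lanes(char const *text, size_t length, size_t window_length) {
        size_t const hashes = length - window_length + 1;
        size_t const per_lane = hashes / 4;
        four_lanes_t lanes;
        lanes.first = text;
        lanes.second = text + per_lane;
        lanes.third = text + per_lane * 2;
        lanes.fourth = text + per_lane * 3;
        lanes.end = text + length;
        return lanes;
    }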
- sz_size_t const max_hashes = length - window_length + 1; - sz_size_t const min_hashes_per_thread = max_hashes / 4; // At most one sequence can overlap between 2 threads. - sz_u8_t const *text_first = (sz_u8_t const *)start; - sz_u8_t const *text_second = text_first + min_hashes_per_thread; - sz_u8_t const *text_third = text_first + min_hashes_per_thread * 2; - sz_u8_t const *text_fourth = text_first + min_hashes_per_thread * 3; - sz_u8_t const *text_end = text_first + length; - - // Broadcast the global constants into the registers. - // Both high and low hashes will work with the same prime and golden ratio. - sz_u512_vec_t prime_vec, golden_ratio_vec; - prime_vec.zmm = _mm512_set1_epi64(SZ_U64_MAX_PRIME); - golden_ratio_vec.zmm = _mm512_set1_epi64(11400714819323198485ull); - - // Prepare the `prime ^ window_length` values, that we are going to use for modulo arithmetic. - sz_u64_t prime_power_low = 1, prime_power_high = 1; - for (sz_size_t i = 0; i + 1 < window_length; ++i) - prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME, - prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME; - - // We will be evaluating 4 offsets at a time with 2 different hash functions. - // We can fit all those 8 state variables in each of the following ZMM registers. - sz_u512_vec_t base_vec, prime_power_vec, shift_vec; - base_vec.zmm = _mm512_set_epi64(31ull, 31ull, 31ull, 31ull, 257ull, 257ull, 257ull, 257ull); - shift_vec.zmm = _mm512_set_epi64(0ull, 0ull, 0ull, 0ull, 77ull, 77ull, 77ull, 77ull); - prime_power_vec.zmm = _mm512_set_epi64(prime_power_low, prime_power_low, prime_power_low, prime_power_low, - prime_power_high, prime_power_high, prime_power_high, prime_power_high); - - // Compute the initial hash values for every one of the four windows. - sz_u512_vec_t hash_vec, chars_vec; - hash_vec.zmm = _mm512_setzero_si512(); - for (sz_u8_t const *prefix_end = text_first + window_length; text_first < prefix_end; - ++text_first, ++text_second, ++text_third, ++text_fourth) { - - // 1. Multiply the hashes by the base. - hash_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, base_vec.zmm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`... - chars_vec.zmm = _mm512_set_epi64(text_fourth[0], text_third[0], text_second[0], text_first[0], // - text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm); - - // 3. Add the incoming characters. - hash_vec.zmm = _mm512_add_epi64(hash_vec.zmm, chars_vec.zmm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_vec.zmm = _mm512_mask_blend_epi8(_mm512_cmpgt_epi64_mask(hash_vec.zmm, prime_vec.zmm), hash_vec.zmm, - _mm512_sub_epi64(hash_vec.zmm, prime_vec.zmm)); - } - - // 5. Compute the hash mix, that will be used to index into the fingerprint. - // This includes a serial step at the end. 
- sz_u512_vec_t hash_mix_vec; - hash_mix_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, golden_ratio_vec.zmm); - hash_mix_vec.ymms[0] = _mm256_xor_si256(_mm512_extracti64x4_epi64(hash_mix_vec.zmm, 1), // - _mm512_extracti64x4_epi64(hash_mix_vec.zmm, 0)); - - callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); - callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); - callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); - callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); - - // Now repeat that operation for the remaining characters, discarding older characters. - sz_size_t cycle = 1; - sz_size_t step_mask = step - 1; - for (; text_fourth != text_end; ++text_first, ++text_second, ++text_third, ++text_fourth, ++cycle) { - // 0. Load again the four characters we are dropping, shift them, and subtract. - chars_vec.zmm = _mm512_set_epi64(text_fourth[-window_length], text_third[-window_length], - text_second[-window_length], text_first[-window_length], // - text_fourth[-window_length], text_third[-window_length], - text_second[-window_length], text_first[-window_length]); - chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm); - hash_vec.zmm = _mm512_sub_epi64(hash_vec.zmm, _mm512_mullo_epi64(chars_vec.zmm, prime_power_vec.zmm)); - - // 1. Multiply the hashes by the base. - hash_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, base_vec.zmm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. - chars_vec.zmm = _mm512_set_epi64(text_fourth[0], text_third[0], text_second[0], text_first[0], // - text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm); - - // ... and prefetch the next four characters into Level 2 or higher. - _mm_prefetch((sz_cptr_t)text_fourth + 1, _MM_HINT_T1); - _mm_prefetch((sz_cptr_t)text_third + 1, _MM_HINT_T1); - _mm_prefetch((sz_cptr_t)text_second + 1, _MM_HINT_T1); - _mm_prefetch((sz_cptr_t)text_first + 1, _MM_HINT_T1); - - // 3. Add the incoming characters. - hash_vec.zmm = _mm512_add_epi64(hash_vec.zmm, chars_vec.zmm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_vec.zmm = _mm512_mask_blend_epi8(_mm512_cmpgt_epi64_mask(hash_vec.zmm, prime_vec.zmm), hash_vec.zmm, - _mm512_sub_epi64(hash_vec.zmm, prime_vec.zmm)); - - // 5. Compute the hash mix, that will be used to index into the fingerprint. - // This includes a serial step at the end. 
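Stripped of the SIMD plumbing, steps 0 through 4 above form the textbook polynomial rolling-hash update. The hedged scalar model below uses a deliberately small prime and base so that plain 64-bit arithmetic cannot overflow, and an ordinary `%` stands in for the conditional subtraction; the vectorized code works with much larger constants and wrap-around multiplication instead.

    #include <stdint.h>

    // One sliding-window step: retire the oldest byte, admit the newest one.
    // `prime_power` is base^(window_length - 1) % prime, precomputed once.
    // Assumes `prime` fits in 32 bits and `base` is small (e.g. 31 or 257).
    static uint64_t rolling_hash_update(uint64_t hash, uint8_t outgoing, uint8_t incoming, //
                                        uint64_t base, uint64_t prime, uint64_t prime_power) {
        hash = (hash + prime - (outgoing * prime_power) % prime) % prime; // Drop the outgoing byte.
        hash = (hash * base + incoming) % prime;                          // Shift and add the incoming byte.
        return hash;
    }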
- hash_mix_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, golden_ratio_vec.zmm); - hash_mix_vec.ymms[0] = _mm256_xor_si256(_mm512_extracti64x4_epi64(hash_mix_vec.zmm, 1), // - _mm512_castsi512_si256(hash_mix_vec.zmm)); - - if ((cycle & step_mask) == 0) { - callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); - callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); - callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); - callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); - } - } -} - -#pragma clang attribute pop -#pragma GCC pop_options - -#pragma GCC push_options -#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512vbmi", "avx512vbmi2", "bmi", "bmi2") -#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512vbmi,avx512vbmi2,bmi,bmi2"))), \ - apply_to = function) - -SZ_PUBLIC void sz_look_up_transform_avx512(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { - - // If the input is tiny (especially smaller than the look-up table itself), we may end up paying - // more for organizing the SIMD registers and changing the CPU state, than for the actual computation. - // But if at least 3 cache lines are touched, the AVX-512 implementation should be faster. - if (length <= 128) { - sz_look_up_transform_serial(source, length, lut, target); - return; - } - - // When the buffer is over 64 bytes, it's guaranteed to touch at least two cache lines - the head and tail, - // and may include more cache-lines in-between. Knowing this, we can avoid expensive unaligned stores - // by computing 2 masks - for the head and tail, using masked stores for the head and tail, and unmasked - // for the body. - sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - - // We need to pull the lookup table into 4x ZMM registers. - // We can use `vpermi2b` instruction to perform the look in two ZMM registers with `_mm512_permutex2var_epi8` - // intrinsics, but it has a 6-cycle latency on Sapphire Rapids and requires AVX512-VBMI. Assuming we need to - // operate on 4 registers, it might be cleaner to use 2x separate `_mm512_permutexvar_epi8` calls. - // Combining the results with 2x `_mm512_test_epi8_mask` and 3x blends afterwards. 
- // - // - 4x `_mm512_permutexvar_epi8` maps to "VPERMB (ZMM, ZMM, ZMM)": - // - On Ice Lake: 3 cycles latency, ports: 1*p5 - // - On Genoa: 6 cycles latency, ports: 1*FP12 - // - 3x `_mm512_mask_blend_epi8` maps to "VPBLENDMB_Z (ZMM, K, ZMM, ZMM)": - // - On Ice Lake: 3 cycles latency, ports: 1*p05 - // - On Genoa: 1 cycle latency, ports: 1*FP0123 - // - 2x `_mm512_test_epi8_mask` maps to "VPTESTMB (K, ZMM, ZMM)": - // - On Ice Lake: 3 cycles latency, ports: 1*p5 - // - On Genoa: 4 cycles latency, ports: 1*FP01 - // - sz_u512_vec_t lut_0_to_63_vec, lut_64_to_127_vec, lut_128_to_191_vec, lut_192_to_255_vec; - lut_0_to_63_vec.zmm = _mm512_loadu_si512((lut)); - lut_64_to_127_vec.zmm = _mm512_loadu_si512((lut + 64)); - lut_128_to_191_vec.zmm = _mm512_loadu_si512((lut + 128)); - lut_192_to_255_vec.zmm = _mm512_loadu_si512((lut + 192)); - - sz_u512_vec_t first_bit_vec, second_bit_vec; - first_bit_vec.zmm = _mm512_set1_epi8((char)0x80); - second_bit_vec.zmm = _mm512_set1_epi8((char)0x40); - - __mmask64 first_bit_mask, second_bit_mask; - sz_u512_vec_t source_vec; - // If the top bit is set in each word of `source_vec`, than we use `lookup_128_to_191_vec` or - // `lookup_192_to_255_vec`. If the second bit is set, we use `lookup_64_to_127_vec` or `lookup_192_to_255_vec`. - sz_u512_vec_t lookup_0_to_63_vec, lookup_64_to_127_vec, lookup_128_to_191_vec, lookup_192_to_255_vec; - sz_u512_vec_t blended_0_to_127_vec, blended_128_to_255_vec, blended_0_to_255_vec; - - // Handling the head. - if (head_length) { - source_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, source); - lookup_0_to_63_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_0_to_63_vec.zmm); - lookup_64_to_127_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_64_to_127_vec.zmm); - lookup_128_to_191_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_128_to_191_vec.zmm); - lookup_192_to_255_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_192_to_255_vec.zmm); - first_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, first_bit_vec.zmm); - second_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, second_bit_vec.zmm); - blended_0_to_127_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_0_to_63_vec.zmm, lookup_64_to_127_vec.zmm); - blended_128_to_255_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_128_to_191_vec.zmm, lookup_192_to_255_vec.zmm); - blended_0_to_255_vec.zmm = - _mm512_mask_blend_epi8(first_bit_mask, blended_0_to_127_vec.zmm, blended_128_to_255_vec.zmm); - _mm512_mask_storeu_epi8(target, head_mask, blended_0_to_255_vec.zmm); - source += head_length, target += head_length, length -= head_length; - } + _mm512_mask_storeu_epi16(next_distances + i, remaining_length_mask, next_vec.zmm); + i += register_length; + } - // Handling the body in 64-byte chunks aligned to cache-line boundaries with respect to `target`. 
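The blend logic described above is easier to digest in scalar form. The hedged model below treats the 256-byte table as four 64-byte quarters: the low six bits of the input index within a quarter (which is all `vpermb` looks at), bit 6 chooses between adjacent quarters, and bit 7 chooses between the lower and upper half, mirroring the two test masks and three blends.

    #include <stdint.h>

    // Scalar equivalent of the four-permute, three-blend lookup: returns table[input].
    static uint8_t lut_select(uint8_t const table[256], uint8_t input) {
        uint8_t within_quarter = input & 0x3F; // What each `vpermb` effectively uses as an index.
        uint8_t from_low_half = (input & 0x40) ? table[64 + within_quarter] : table[within_quarter];
        uint8_t from_high_half = (input & 0x40) ? table[192 + within_quarter] : table[128 + within_quarter];
        return (input & 0x80) ? from_high_half : from_low_half;
    }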
- while (length >= 64) { - source_vec.zmm = _mm512_loadu_si512(source); - lookup_0_to_63_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_0_to_63_vec.zmm); - lookup_64_to_127_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_64_to_127_vec.zmm); - lookup_128_to_191_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_128_to_191_vec.zmm); - lookup_192_to_255_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_192_to_255_vec.zmm); - first_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, first_bit_vec.zmm); - second_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, second_bit_vec.zmm); - blended_0_to_127_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_0_to_63_vec.zmm, lookup_64_to_127_vec.zmm); - blended_128_to_255_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_128_to_191_vec.zmm, lookup_192_to_255_vec.zmm); - blended_0_to_255_vec.zmm = - _mm512_mask_blend_epi8(first_bit_mask, blended_0_to_127_vec.zmm, blended_128_to_255_vec.zmm); - _mm512_store_si512(target, blended_0_to_255_vec.zmm); //! Aligned store, our main weapon! - source += 64, target += 64, length -= 64; + // Perform a circular rotation (three-way swap) of those buffers, to reuse the memory, this time, with a shift, + // dropping the first element in the current array. + sz_u16_t *temporary = previous_distances; + previous_distances = current_distances + 1; + current_distances = next_distances; + next_distances = temporary; } - // Handling the tail. - if (tail_length) { - source_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, source); - lookup_0_to_63_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_0_to_63_vec.zmm); - lookup_64_to_127_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_64_to_127_vec.zmm); - lookup_128_to_191_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_128_to_191_vec.zmm); - lookup_192_to_255_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_192_to_255_vec.zmm); - first_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, first_bit_vec.zmm); - second_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, second_bit_vec.zmm); - blended_0_to_127_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_0_to_63_vec.zmm, lookup_64_to_127_vec.zmm); - blended_128_to_255_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_128_to_191_vec.zmm, lookup_192_to_255_vec.zmm); - blended_0_to_255_vec.zmm = - _mm512_mask_blend_epi8(first_bit_mask, blended_0_to_127_vec.zmm, blended_128_to_255_vec.zmm); - _mm512_mask_storeu_epi8(target, tail_mask, blended_0_to_255_vec.zmm); - source += tail_length, target += tail_length, length -= tail_length; - } + // Cache scalar before `free` call. + sz_size_t result = current_distances[0]; + alloc->free(distances, buffer_length, alloc->handle); + return result; +#endif + return 0; } -SZ_PUBLIC sz_cptr_t sz_find_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { +SZ_INTERNAL sz_size_t sz_edit_distance_avx512( // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_cptr_t longer, sz_size_t longer_length, // + sz_size_t bound, sz_memory_allocator_t *alloc) { - // Before initializing the AVX-512 vectors, we may want to run the sequential code for the first few bytes. - // In practice, that only hurts, even when we have matches every 5-ish bytes. 
- // - // if (length < SZ_SWAR_THRESHOLD) return sz_find_charset_serial(text, length, filter); - // sz_cptr_t early_result = sz_find_charset_serial(text, SZ_SWAR_THRESHOLD, filter); - // if (early_result) return early_result; - // text += SZ_SWAR_THRESHOLD; - // length -= SZ_SWAR_THRESHOLD; - // - // Let's unzip even and odd elements and replicate them into both lanes of the YMM register. - // That way when we invoke `_mm512_shuffle_epi8` we can use the same mask for both lanes. - sz_u512_vec_t filter_even_vec, filter_odd_vec; - __m256i filter_ymm = _mm256_lddqu_si256((__m256i const *)filter); - // There are a few way to initialize filters without having native strided loads. - // In the cronological order of experiments: - // - serial code initializing 128 bytes of odd and even mask - // - using several shuffles - // - using `_mm512_permutexvar_epi8` - // - using `_mm512_broadcast_i32x4(_mm256_castsi256_si128(_mm256_maskz_compress_epi8(0x55555555, filter_ymm)))` - // and `_mm512_broadcast_i32x4(_mm256_castsi256_si128(_mm256_maskz_compress_epi8(0xaaaaaaaa, filter_ymm)))` - filter_even_vec.zmm = _mm512_broadcast_i32x4(_mm256_castsi256_si128( // broadcast __m128i to __m512i - _mm256_maskz_compress_epi8(0x55555555, filter_ymm))); - filter_odd_vec.zmm = _mm512_broadcast_i32x4(_mm256_castsi256_si128( // broadcast __m128i to __m512i - _mm256_maskz_compress_epi8(0xaaaaaaaa, filter_ymm))); - // After the unzipping operation, we can validate the contents of the vectors like this: - // - // for (sz_size_t i = 0; i != 16; ++i) { - // sz_assert(filter_even_vec.u8s[i] == filter->_u8s[i * 2]); - // sz_assert(filter_odd_vec.u8s[i] == filter->_u8s[i * 2 + 1]); - // sz_assert(filter_even_vec.u8s[i + 16] == filter->_u8s[i * 2]); - // sz_assert(filter_odd_vec.u8s[i + 16] == filter->_u8s[i * 2 + 1]); - // sz_assert(filter_even_vec.u8s[i + 32] == filter->_u8s[i * 2]); - // sz_assert(filter_odd_vec.u8s[i + 32] == filter->_u8s[i * 2 + 1]); - // sz_assert(filter_even_vec.u8s[i + 48] == filter->_u8s[i * 2]); - // sz_assert(filter_odd_vec.u8s[i + 48] == filter->_u8s[i * 2 + 1]); - // } - // - sz_u512_vec_t text_vec; - sz_u512_vec_t lower_nibbles_vec, higher_nibbles_vec; - sz_u512_vec_t bitset_even_vec, bitset_odd_vec; - sz_u512_vec_t bitmask_vec, bitmask_lookup_vec; - bitmask_lookup_vec.zmm = _mm512_set_epi8( // - -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // - -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // - -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // - -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1); - - while (length) { - // The following algorithm is a transposed equivalent of the "SIMDized check which bytes are in a set" - // solutions by Wojciech Muła. We populate the bitmask differently and target newer CPUs, so - // StrinZilla uses a somewhat different approach. - // http://0x80.pl/articles/simd-byte-lookup.html#alternative-implementation-new - // - // sz_u8_t input = *(sz_u8_t const *)text; - // sz_u8_t lo_nibble = input & 0x0f; - // sz_u8_t hi_nibble = input >> 4; - // sz_u8_t bitset_even = filter_even_vec.u8s[hi_nibble]; - // sz_u8_t bitset_odd = filter_odd_vec.u8s[hi_nibble]; - // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); - // sz_u8_t bitset = lo_nibble < 8 ? bitset_even : bitset_odd; - // if ((bitset & bitmask) != 0) return text; - // else { length--, text++; } - // - // The nice part about this, loading the strided data is vey easy with Arm NEON, - // while with x86 CPUs after AVX, shuffles within 256 bits shouldn't be an issue either. 
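    // For reference, the even/odd split is just a nibble-wise re-indexing of the 32-byte bitset:
    // for any byte `c`, the serial test `filter->_u8s[c >> 3] & (1 << (c & 7))` can be rewritten as
    //
    //      sz_u8_t hi_nibble = c >> 4, lo_nibble = c & 0x0F;
    //      sz_u8_t bitset = lo_nibble < 8 ? filter->_u8s[hi_nibble * 2] : filter->_u8s[hi_nibble * 2 + 1];
    //      sz_u8_t is_match = bitset & (1 << (lo_nibble & 0x7));
    //
    // because `c >> 3 == hi_nibble * 2 + (lo_nibble >= 8)`. For `c == 'a'` (0x61) both forms test
    // bit 1 of byte 12 of the set.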
- sz_size_t load_length = sz_min_of_two(length, 64); - __mmask64 load_mask = _sz_u64_mask_until(load_length); - text_vec.zmm = _mm512_maskz_loadu_epi8(load_mask, text); - lower_nibbles_vec.zmm = _mm512_and_si512(text_vec.zmm, _mm512_set1_epi8(0x0f)); - bitmask_vec.zmm = _mm512_shuffle_epi8(bitmask_lookup_vec.zmm, lower_nibbles_vec.zmm); - // - // At this point we can validate the `bitmask_vec` contents like this: - // - // for (sz_size_t i = 0; i != load_length; ++i) { - // sz_u8_t input = *(sz_u8_t const *)(text + i); - // sz_u8_t lo_nibble = input & 0x0f; - // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); - // sz_assert(bitmask_vec.u8s[i] == bitmask); - // } - // - // Shift right every byte by 4 bits. - // There is no `_mm512_srli_epi8` intrinsic, so we have to use `_mm512_srli_epi16` - // and combine it with a mask to clear the higher bits. - higher_nibbles_vec.zmm = _mm512_and_si512(_mm512_srli_epi16(text_vec.zmm, 4), _mm512_set1_epi8(0x0f)); - bitset_even_vec.zmm = _mm512_shuffle_epi8(filter_even_vec.zmm, higher_nibbles_vec.zmm); - bitset_odd_vec.zmm = _mm512_shuffle_epi8(filter_odd_vec.zmm, higher_nibbles_vec.zmm); - // - // At this point we can validate the `bitset_even_vec` and `bitset_odd_vec` contents like this: - // - // for (sz_size_t i = 0; i != load_length; ++i) { - // sz_u8_t input = *(sz_u8_t const *)(text + i); - // sz_u8_t const *bitset_ptr = &filter->_u8s[0]; - // sz_u8_t hi_nibble = input >> 4; - // sz_u8_t bitset_even = bitset_ptr[hi_nibble * 2]; - // sz_u8_t bitset_odd = bitset_ptr[hi_nibble * 2 + 1]; - // sz_assert(bitset_even_vec.u8s[i] == bitset_even); - // sz_assert(bitset_odd_vec.u8s[i] == bitset_odd); - // } - // - // TODO: Is this a good place for ternary logic? - __mmask64 take_first = _mm512_cmplt_epi8_mask(lower_nibbles_vec.zmm, _mm512_set1_epi8(8)); - bitset_even_vec.zmm = _mm512_mask_blend_epi8(take_first, bitset_odd_vec.zmm, bitset_even_vec.zmm); - __mmask64 matches_mask = _mm512_mask_test_epi8_mask(load_mask, bitset_even_vec.zmm, bitmask_vec.zmm); - if (matches_mask) { - int offset = sz_u64_ctz(matches_mask); - return text + offset; - } - else { text += load_length, length -= load_length; } + // Bounded computations may exit early. + int const is_bounded = bound < longer_length; + if (is_bounded) { + // If one of the strings is empty - the edit distance is equal to the length of the other one. + if (longer_length == 0) return sz_min_of_two(shorter_length, bound); + if (shorter_length == 0) return sz_min_of_two(longer_length, bound); + // If the difference in length is beyond the `bound`, there is no need to check at all. + if (longer_length - shorter_length > bound) return bound; } - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { - return sz_rfind_charset_serial(text, length, filter); -} - -SZ_PUBLIC sz_cptr_t sz_find_many_avx512( // - sz_cptr_t haystack, sz_size_t haystack_length, // - sz_cptr_t const *needles, sz_size_t const *needles_lengths, // - sz_size_t *needle_offset) { + // Make sure the shorter string is actually shorter. + if (shorter_length > longer_length) { + sz_cptr_t temporary = shorter; + shorter = longer; + longer = temporary; + sz_size_t temporary_length = shorter_length; + shorter_length = longer_length; + longer_length = temporary_length; + } - // When dealing with huge needles vocabularies, like in tokenization workloads, we need to construct an automaton. 
- // But in many cases, the vocabulary is small enough to use a simpler DFA-less approach, combining the ideas from - // the `sz_find_avx512` and `sz_find_charset_avx512` functions. - // - // Pick the offsets within needles where there is the least variance in the characters. - // Like for "the", "then", "there", "these", "those", "their", "they", "them", "that", "this", "thus", "than": - // - // 0: 't' - // 1: 'h' - // 2: 'e', 'a', 'i', 'o', 'u' - // 3: 'n', 'r', 's', 'i', 'y', 'm', 't' - // - // So depending on our "register budget", we can use a different number of pivot points: offset 0, 1, 2 make - // the most sense if we can only use 3 ZMM registers. - sz_unused(haystack && haystack_length && needles && needles_lengths && needle_offset); - return 0; + // Dispatch the right implementation based on the length of the strings. + if (longer_length < 64u) + return _sz_edit_distance_skewed_diagonals_upto63_avx512( // + shorter, shorter_length, longer, longer_length, bound); + // else if (longer_length < 256u * 256u) + // return _sz_edit_distance_skewed_diagonals_upto65k_avx512( // + // shorter, shorter_length, longer, longer_length, bound, alloc); + else + return sz_edit_distance_serial(shorter, shorter_length, longer, longer_length, bound, alloc); } /** @@ -6075,9 +1062,9 @@ SZ_PUBLIC sz_cptr_t sz_find_many_avx512( // * a slice, which is much easier to optimize. In that case we are sampling costs not from arbitrary parts of * a 256 x 256 matrix, but from a single row! */ -SZ_INTERNAL sz_ssize_t _sz_alignment_score_wagner_fisher_upto17m_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // +SZ_INTERNAL sz_ssize_t _sz_alignment_score_wagner_fisher_upto17m_ice( // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_cptr_t longer, sz_size_t longer_length, // sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc) { // If one of the strings is empty - the edit distance is equal to the length of the other one @@ -6284,779 +1271,57 @@ SZ_INTERNAL sz_ssize_t _sz_alignment_score_wagner_fisher_upto17m_avx512( // return result; } -SZ_INTERNAL sz_ssize_t sz_alignment_score_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // +SZ_INTERNAL sz_ssize_t sz_alignment_score_ice( // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_cptr_t longer, sz_size_t longer_length, // sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc) { if (sz_max_of_two(shorter_length, longer_length) < (256ull * 256ull * 256ull)) - return _sz_alignment_score_wagner_fisher_upto17m_avx512(shorter, shorter_length, longer, longer_length, subs, - gap, alloc); + return _sz_alignment_score_wagner_fisher_upto17m_ice(shorter, shorter_length, longer, longer_length, subs, gap, + alloc); else return sz_alignment_score_serial(shorter, shorter_length, longer, longer_length, subs, gap, alloc); } -enum sz_encoding_t { - sz_encoding_unknown_k = 0, - sz_encoding_ascii_k = 1, - sz_encoding_utf8_k = 2, - sz_encoding_utf16_k = 3, - sz_encoding_utf32_k = 4, - sz_jwt_k, - sz_base64_k, - // Low priority encodings: - sz_encoding_utf8bom_k = 5, - sz_encoding_utf16le_k = 6, - sz_encoding_utf16be_k = 7, - sz_encoding_utf32le_k = 8, - sz_encoding_utf32be_k = 9, -}; - -// Character Set Detection is one of the most commonly performed operations in data processing with -// [Chardet](https://github.com/chardet/chardet), [Charset Normalizer](https://github.com/jawah/charset_normalizer), -// 
[cChardet](https://github.com/PyYoshi/cChardet) being the most commonly used options in the Python ecosystem. -// All of them are notoriously slow. -// -// Moreover, as of October 2024, UTF-8 is the dominant character encoding on the web, used by 98.4% of websites. -// Other have minimal usage, according to [W3Techs](https://w3techs.com/technologies/overview/character_encoding): -// - ISO-8859-1: 1.2% -// - Windows-1252: 0.3% -// - Windows-1251: 0.2% -// - EUC-JP: 0.1% -// - Shift JIS: 0.1% -// - EUC-KR: 0.1% -// - GB2312: 0.1% -// - Windows-1250: 0.1% -// Within programming language implementations and database management systems, 16-bit and 32-bit fixed-width encodings -// are also very popular and we need a way to efficienly differentiate between the most common UTF flavors, ASCII, and -// the rest. -// -// One good solution is the [simdutf](https://github.com/simdutf/simdutf) library, but it depends on the C++ runtime -// and focuses more on incremental validation & transcoding, rather than detection. -// -// So we need a very fast and efficient way of determining -SZ_PUBLIC sz_bool_t sz_detect_encoding(sz_cptr_t text, sz_size_t length) { - // https://github.com/simdutf/simdutf/blob/master/src/icelake/icelake_utf8_validation.inl.cpp - // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_from_utf8.inl.cpp#L81 - // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_utf8_common.inl.cpp#L661 - // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_utf8_common.inl.cpp#L788 - - // We can implement this operation simpler & differently, assuming most of the time continuous chunks of memory - // have identical encoding. With Russian and many European languages, we generally deal with 2-byte codepoints - // with occasional 1-byte punctuation marks. In the case of Chinese, Japanese, and Korean, we deal with 3-byte - // codepoints. In the case of emojis, we deal with 4-byte codepoints. - // We can also use the idea, that misaligned reads are quite cheap on modern CPUs. - int can_be_ascii = 1, can_be_utf8 = 1, can_be_utf16 = 1, can_be_utf32 = 1; - sz_unused(can_be_ascii + can_be_utf8 + can_be_utf16 + can_be_utf32); - sz_unused(text && length); - return sz_false_k; -} - #pragma clang attribute pop #pragma GCC pop_options -#endif +#endif // SZ_USE_ICE +#pragma endregion // Ice Lake Implementation -#pragma endregion - -/* @brief Implementation of the string search algorithms using the Arm NEON instruction set, available on 64-bit - * Arm processors. Implements: {substring search, character search, character set search} x {forward, reverse}. +/* Implementation of the similarity algorithms using the Arm NEON instruction set, available on 64-bit + * Arm processors. Covers billions of mobile CPUs worldwide, including Apple's A-series, and Qualcomm's Snapdragon. */ -#pragma region ARM NEON - -#if SZ_USE_ARM_NEON +#pragma region NEON Implementation +#if SZ_USE_NEON #pragma GCC push_options #pragma GCC target("arch=armv8.2-a+simd") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd"))), apply_to = function) -/** - * @brief Helper structure to simplify work with 64-bit words. 
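 *
 *        For example, it lets the NEON kernels below load a whole 128-bit register and still read
 *        individual lanes as scalars, something like:
 *
 *            sz_u128_vec_t h_vec;
 *            h_vec.u8x16 = vld1q_u8((sz_u8_t const *)h);
 *            sz_u8_t first_byte = h_vec.u8s[0];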
- */ -typedef union sz_u128_vec_t { - uint8x16_t u8x16; - uint16x8_t u16x8; - uint32x4_t u32x4; - uint64x2_t u64x2; - sz_u64_t u64s[2]; - sz_u32_t u32s[4]; - sz_u16_t u16s[8]; - sz_u8_t u8s[16]; -} sz_u128_vec_t; - -SZ_INTERNAL sz_u64_t _sz_vreinterpretq_u8_u4(uint8x16_t vec) { - // Use `vshrn` to produce a bitmask, similar to `movemask` in SSE. - // https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon - return vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(vec), 4)), 0) & 0x8888888888888888ull; -} - -SZ_PUBLIC sz_ordering_t sz_order_neon(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { - //! Before optimizing this, read the "Operations Not Worth Optimizing" in Contributions Guide: - //! https://github.com/ashvardanian/StringZilla/blob/main/CONTRIBUTING.md#general-performance-observations - return sz_order_serial(a, a_length, b, b_length); -} - -SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - sz_u128_vec_t a_vec, b_vec; - for (; length >= 16; a += 16, b += 16, length -= 16) { - a_vec.u8x16 = vld1q_u8((sz_u8_t const *)a); - b_vec.u8x16 = vld1q_u8((sz_u8_t const *)b); - uint8x16_t cmp = vceqq_u8(a_vec.u8x16, b_vec.u8x16); - if (vminvq_u8(cmp) != 255) { return sz_false_k; } // Check if all bytes match - } - - // Handle remaining bytes - if (length) return sz_equal_serial(a, b, length); - return sz_true_k; -} - -SZ_PUBLIC sz_u64_t sz_checksum_neon(sz_cptr_t text, sz_size_t length) { - uint64x2_t sum_vec = vdupq_n_u64(0); - - // Process 16 bytes (128 bits) at a time - for (; length >= 16; text += 16, length -= 16) { - uint8x16_t vec = vld1q_u8((sz_u8_t const *)text); // Load 16 bytes - uint16x8_t pairwise_sum1 = vpaddlq_u8(vec); // Pairwise add lower and upper 8 bits - uint32x4_t pairwise_sum2 = vpaddlq_u16(pairwise_sum1); // Pairwise add 16-bit results - uint64x2_t pairwise_sum3 = vpaddlq_u32(pairwise_sum2); // Pairwise add 32-bit results - sum_vec = vaddq_u64(sum_vec, pairwise_sum3); // Accumulate the sum - } - - // Final reduction of `sum_vec` to a single scalar - sz_u64_t sum = vgetq_lane_u64(sum_vec, 0) + vgetq_lane_u64(sum_vec, 1); - if (length) sum += sz_checksum_serial(text, length); - return sum; -} - -SZ_PUBLIC void sz_copy_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - // In most cases the `source` and the `target` are not aligned, but we should - // at least make sure that writes don't touch many cache lines. - // NEON has an instruction to load and write 64 bytes at once. - // - // sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. - // sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. 
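    //      As a worked example of the arithmetic above: a `target` ending in 0x...30 gives
    //      `head_length == (64 - 48) % 64 == 16`, so 16 scalar bytes bring the destination to a
    //      cache-line boundary, while an already aligned `target` gives `head_length == (64 - 0) % 64 == 0`
    //      instead of a useless 64-byte prologue.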
- // for (; head_length; target += 1, source += 1, head_length -= 1) *target = *source; - // length -= head_length; - // for (; length >= 64; target += 64, source += 64, length -= 64) - // vst4q_u8((sz_u8_t *)target, vld1q_u8_x4((sz_u8_t const *)source)); - // for (; tail_length; target += 1, source += 1, tail_length -= 1) *target = *source; - // - // Sadly, those instructions end up being 20% slower than the code processing 16 bytes at a time: - for (; length >= 16; target += 16, source += 16, length -= 16) - vst1q_u8((sz_u8_t *)target, vld1q_u8((sz_u8_t const *)source)); - if (length) sz_copy_serial(target, source, length); -} - -SZ_PUBLIC void sz_move_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - // When moving small buffers, using a small buffer on stack as a temporary storage is faster. - - if (target < source || target >= source + length) { - // Non-overlapping, proceed forward - sz_copy_neon(target, source, length); - } - else { - // Overlapping, proceed backward - target += length; - source += length; - - sz_u128_vec_t src_vec; - while (length >= 16) { - target -= 16, source -= 16, length -= 16; - src_vec.u8x16 = vld1q_u8((sz_u8_t const *)source); - vst1q_u8((sz_u8_t *)target, src_vec.u8x16); - } - while (length) { - target -= 1, source -= 1, length -= 1; - *target = *source; - } - } -} - -SZ_PUBLIC void sz_fill_neon(sz_ptr_t target, sz_size_t length, sz_u8_t value) { - uint8x16_t fill_vec = vdupq_n_u8(value); // Broadcast the value across the register - - while (length >= 16) { - vst1q_u8((sz_u8_t *)target, fill_vec); - target += 16; - length -= 16; - } - - // Handle remaining bytes - if (length) sz_fill_serial(target, length, value); -} - -SZ_PUBLIC void sz_look_up_transform_neon(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { - - // If the input is tiny (especially smaller than the look-up table itself), we may end up paying - // more for organizing the SIMD registers and changing the CPU state, than for the actual computation. - if (length <= 128) { - sz_look_up_transform_serial(source, length, lut, target); - return; - } - - sz_size_t head_length = (16 - ((sz_size_t)target % 16)) % 16; // 15 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 16; // 15 or less. - - // We need to pull the lookup table into 16x NEON registers. We have a total of 32 such registers. - // According to the Neoverse V2 manual, the 4-table lookup has a latency of 6 cycles, and 4x throughput. - uint8x16x4_t lut_0_to_63_vec, lut_64_to_127_vec, lut_128_to_191_vec, lut_192_to_255_vec; - lut_0_to_63_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 0)); - lut_64_to_127_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 64)); - lut_128_to_191_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 128)); - lut_192_to_255_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 192)); - - sz_u128_vec_t source_vec; - // If the top bit is set in each word of `source_vec`, than we use `lookup_128_to_191_vec` or - // `lookup_192_to_255_vec`. If the second bit is set, we use `lookup_64_to_127_vec` or `lookup_192_to_255_vec`. - sz_u128_vec_t lookup_0_to_63_vec, lookup_64_to_127_vec, lookup_128_to_191_vec, lookup_192_to_255_vec; - sz_u128_vec_t blended_0_to_255_vec; - - // Process the head with serial code - for (; head_length; target += 1, source += 1, head_length -= 1) *target = lut[*(sz_u8_t const *)source]; - - // Table lookups on Arm are much simpler to use than on x86, as we can use the `vqtbl4q_u8` instruction - // to perform a 4-table lookup in a single instruction. 
The XORs are used to adjust the lookup position - // within each 64-byte range of the table. - // Details on the 4-table lookup: https://lemire.me/blog/2019/07/23/arbitrary-byte-to-byte-maps-using-arm-neon/ - length -= head_length; - length -= tail_length; - for (; length >= 16; source += 16, target += 16, length -= 16) { - source_vec.u8x16 = vld1q_u8((sz_u8_t const *)source); - lookup_0_to_63_vec.u8x16 = vqtbl4q_u8(lut_0_to_63_vec, source_vec.u8x16); - lookup_64_to_127_vec.u8x16 = vqtbl4q_u8(lut_64_to_127_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0x40))); - lookup_128_to_191_vec.u8x16 = vqtbl4q_u8(lut_128_to_191_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0x80))); - lookup_192_to_255_vec.u8x16 = vqtbl4q_u8(lut_192_to_255_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0xc0))); - blended_0_to_255_vec.u8x16 = vorrq_u8(vorrq_u8(lookup_0_to_63_vec.u8x16, lookup_64_to_127_vec.u8x16), - vorrq_u8(lookup_128_to_191_vec.u8x16, lookup_192_to_255_vec.u8x16)); - vst1q_u8((sz_u8_t *)target, blended_0_to_255_vec.u8x16); - } - - // Process the tail with serial code - for (; tail_length; target += 1, source += 1, tail_length -= 1) *target = lut[*(sz_u8_t const *)source]; -} - -SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - sz_u64_t matches; - sz_u128_vec_t h_vec, n_vec, matches_vec; - n_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)n); - - while (h_length >= 16) { - h_vec.u8x16 = vld1q_u8((sz_u8_t const *)h); - matches_vec.u8x16 = vceqq_u8(h_vec.u8x16, n_vec.u8x16); - // In Arm NEON we don't have a `movemask` to combine it with `ctz` and get the offset of the match. - // But assuming the `vmaxvq` is cheap, we can use it to find the first match, by blending (bitwise selecting) - // the vector with a relative offsets array. - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - if (matches) return h + sz_u64_ctz(matches) / 4; - - h += 16, h_length -= 16; - } - - return sz_find_byte_serial(h, h_length, n); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_byte_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - sz_u64_t matches; - sz_u128_vec_t h_vec, n_vec, matches_vec; - n_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)n); - - while (h_length >= 16) { - h_vec.u8x16 = vld1q_u8((sz_u8_t const *)h + h_length - 16); - matches_vec.u8x16 = vceqq_u8(h_vec.u8x16, n_vec.u8x16); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - if (matches) return h + h_length - 1 - sz_u64_clz(matches) / 4; - h_length -= 16; - } - - return sz_rfind_byte_serial(h, h_length, n); -} - -SZ_PUBLIC sz_u64_t _sz_find_charset_neon_register(sz_u128_vec_t h_vec, uint8x16_t set_top_vec_u8x16, - uint8x16_t set_bottom_vec_u8x16) { - - // Once we've read the characters in the haystack, we want to - // compare them against our bitset. The serial version of that code - // would look like: `(set_->_u8s[c >> 3] & (1u << (c & 7u))) != 0`. - uint8x16_t byte_index_vec = vshrq_n_u8(h_vec.u8x16, 3); - uint8x16_t byte_mask_vec = vshlq_u8(vdupq_n_u8(1), vreinterpretq_s8_u8(vandq_u8(h_vec.u8x16, vdupq_n_u8(7)))); - uint8x16_t matches_top_vec = vqtbl1q_u8(set_top_vec_u8x16, byte_index_vec); - // The table lookup instruction in NEON replies to out-of-bound requests with zeros. - // The values in `byte_index_vec` all fall in [0; 32). So for values under 16, substracting 16 will underflow - // and map into interval [240, 256). Meaning that those will be populated with zeros and we can safely - // merge `matches_top_vec` and `matches_bottom_vec` with a bitwise OR. 
- uint8x16_t matches_bottom_vec = vqtbl1q_u8(set_bottom_vec_u8x16, vsubq_u8(byte_index_vec, vdupq_n_u8(16))); - uint8x16_t matches_vec = vorrq_u8(matches_top_vec, matches_bottom_vec); - // Istead of pure `vandq_u8`, we can immediately broadcast a match presence across each 8-bit word. - matches_vec = vtstq_u8(matches_vec, byte_mask_vec); - return _sz_vreinterpretq_u8_u4(matches_vec); -} - -SZ_PUBLIC sz_cptr_t sz_find_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_find_byte_neon(h, h_length, n); - - // Scan through the string. - // Assuming how tiny the Arm NEON registers are, we should avoid internal branches at all costs. - // That's why, for smaller needles, we use different loops. - if (n_length == 2) { - // Broadcast needle characters into SIMD registers. - sz_u64_t matches; - sz_u128_vec_t h_first_vec, h_last_vec, n_first_vec, n_last_vec, matches_vec; - // Dealing with 16-bit values, we can load 2 registers at a time and compare 31 possible offsets - // in a single loop iteration. - n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[0]); - n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[1]); - for (; h_length >= 17; h += 16, h_length -= 16) { - h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 0)); - h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 1)); - matches_vec.u8x16 = - vandq_u8(vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - if (matches) return h + sz_u64_ctz(matches) / 4; - } - } - else if (n_length == 3) { - // Broadcast needle characters into SIMD registers. - sz_u64_t matches; - sz_u128_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec, matches_vec; - // Comparing 24-bit values is a bumer. Being lazy, I went with the same approach - // as when searching for string over 4 characters long. I only avoid the last comparison. - n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[0]); - n_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[1]); - n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[2]); - for (; h_length >= 18; h += 16, h_length -= 16) { - h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 0)); - h_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 1)); - h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 2)); - matches_vec.u8x16 = vandq_u8( // - vandq_u8( // - vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), // - vceqq_u8(h_mid_vec.u8x16, n_mid_vec.u8x16)), - vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - if (matches) return h + sz_u64_ctz(matches) / 4; - } - } - else { - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - // Broadcast those characters into SIMD registers. - sz_u64_t matches; - sz_u128_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec, matches_vec; - n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_first]); - n_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_mid]); - n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_last]); - // Walk through the string. 
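    // The loop below is a vectorized form of a simple scalar pre-filter, sketched here for clarity:
    // only offsets where all three "anomalous" bytes match pay for the full `sz_equal` verification.
    //
    //      for (sz_size_t i = 0; i + n_length <= h_length; ++i)
    //          if (h[i + offset_first] == n[offset_first] && h[i + offset_mid] == n[offset_mid] &&
    //              h[i + offset_last] == n[offset_last] && sz_equal(h + i, n, n_length))
    //              return h + i;
    //
    // The NEON version evaluates 16 such candidate offsets per iteration and uses the 4-bit-per-byte
    // mask from `_sz_vreinterpretq_u8_u4` to jump straight to the next surviving candidate.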
- for (; h_length >= n_length + 16; h += 16, h_length -= 16) { - h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_first)); - h_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_mid)); - h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_last)); - matches_vec.u8x16 = vandq_u8( // - vandq_u8( // - vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), // - vceqq_u8(h_mid_vec.u8x16, n_mid_vec.u8x16)), - vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - while (matches) { - int potential_offset = sz_u64_ctz(matches) / 4; - if (sz_equal(h + potential_offset, n, n_length)) return h + potential_offset; - matches &= matches - 1; - } - } - } - - return sz_find_serial(h, h_length, n, n_length); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_rfind_byte_neon(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Will contain 4 bits per character. - sz_u64_t matches; - sz_u128_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec, matches_vec; - n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_first]); - n_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_mid]); - n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_last]); - - sz_cptr_t h_reversed; - for (; h_length >= n_length + 16; h_length -= 16) { - h_reversed = h + h_length - n_length - 16 + 1; - h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h_reversed + offset_first)); - h_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h_reversed + offset_mid)); - h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h_reversed + offset_last)); - matches_vec.u8x16 = vandq_u8( // - vandq_u8( // - vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), // - vceqq_u8(h_mid_vec.u8x16, n_mid_vec.u8x16)), - vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - while (matches) { - int potential_offset = sz_u64_clz(matches) / 4; - if (sz_equal(h + h_length - n_length - potential_offset, n, n_length)) - return h + h_length - n_length - potential_offset; - sz_assert((matches & (1ull << (63 - potential_offset * 4))) != 0 && - "The bit must be set before we squash it"); - matches &= ~(1ull << (63 - potential_offset * 4)); - } - } - - return sz_rfind_serial(h, h_length, n, n_length); -} - -SZ_PUBLIC sz_cptr_t sz_find_charset_neon(sz_cptr_t h, sz_size_t h_length, sz_charset_t const *set) { - sz_u64_t matches; - sz_u128_vec_t h_vec; - uint8x16_t set_top_vec_u8x16 = vld1q_u8(&set->_u8s[0]); - uint8x16_t set_bottom_vec_u8x16 = vld1q_u8(&set->_u8s[16]); - - for (; h_length >= 16; h += 16, h_length -= 16) { - h_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h)); - matches = _sz_find_charset_neon_register(h_vec, set_top_vec_u8x16, set_bottom_vec_u8x16); - if (matches) return h + sz_u64_ctz(matches) / 4; - } - - return sz_find_charset_serial(h, h_length, set); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t h, sz_size_t h_length, sz_charset_t const *set) { - sz_u64_t matches; - sz_u128_vec_t h_vec; - uint8x16_t set_top_vec_u8x16 = vld1q_u8(&set->_u8s[0]); - uint8x16_t set_bottom_vec_u8x16 = vld1q_u8(&set->_u8s[16]); - - // Check `sz_find_charset_neon` for 
explanations. - for (; h_length >= 16; h_length -= 16) { - h_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h) + h_length - 16); - matches = _sz_find_charset_neon_register(h_vec, set_top_vec_u8x16, set_bottom_vec_u8x16); - if (matches) return h + h_length - 1 - sz_u64_clz(matches) / 4; - } - - return sz_rfind_charset_serial(h, h_length, set); -} - #pragma clang attribute pop #pragma GCC pop_options -#endif // Arm Neon +#endif // SZ_USE_NEON +#pragma endregion // NEON Implementation -#pragma endregion - -/* @brief Implementation of the string search algorithms using the Arm SVE variable-length registers, available - * in Arm v9 processors. - * - * Implements: - * - memory: {copy, move, fill} - * - comparisons: {equal, order} - * - search: {substring, character, character set} x {forward, reverse}. +/* Implementation of the string search algorithms using the Arm SVE variable-length registers, + * available in Arm v9 processors, like in Apple M4+ and Graviton 3+ CPUs. */ -#pragma region ARM SVE - -#if SZ_USE_ARM_SVE +#pragma region SVE Implementation +#if SZ_USE_SVE #pragma GCC push_options #pragma GCC target("arch=armv8.2-a+sve") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve"))), apply_to = function) -SZ_PUBLIC void sz_fill_sve(sz_ptr_t target, sz_size_t length, sz_u8_t value) { - svuint8_t value_vec = svdup_u8(value); - sz_size_t vec_len = svcntb(); // Vector length in bytes (scalable) - - if (length <= vec_len) { - // Small buffer case: use mask to handle small writes - svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)length); - svst1_u8(mask, (unsigned char *)target, value_vec); - } - else { - // Calculate head, body, and tail sizes - sz_size_t head_length = vec_len - ((sz_size_t)target % vec_len); - sz_size_t tail_length = (sz_size_t)(target + length) % vec_len; - sz_size_t body_length = length - head_length - tail_length; - - // Handle unaligned head - svbool_t head_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)head_length); - svst1_u8(head_mask, (unsigned char *)target, value_vec); - target += head_length; - - // Aligned body loop - for (; body_length >= vec_len; target += vec_len, body_length -= vec_len) { - svst1_u8(svptrue_b8(), (unsigned char *)target, value_vec); - } - - // Handle unaligned tail - svbool_t tail_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)tail_length); - svst1_u8(tail_mask, (unsigned char *)target, value_vec); - } -} - -SZ_PUBLIC void sz_copy_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - sz_size_t vec_len = svcntb(); // Vector length in bytes - - // Arm Neoverse V2 cores in Graviton 4, for example, come with 256 KB of L1 data cache per core, - // and 8 MB of L2 cache per core. Moreover, the L1 cache is fully associative. - // With two strings, we may consider the overal workload huge, if each exceeds 1 MB in length. - // - // int is_huge = length >= 4ull * 1024ull * 1024ull; - // - // When the buffer is small, there isn't much to innovate. - if (length <= vec_len) { - // Small buffer case: use mask to handle small writes - svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)length); - svuint8_t data = svld1_u8(mask, (unsigned char *)source); - svst1_u8(mask, (unsigned char *)target, data); - } - // When dealing with larger buffers, similar to AVX-512, we want minimize unaligned operations - // and handle the head, body, and tail separately. We can also traverse the buffer in both directions - // as Arm generally supports more simultaneous stores than x86 CPUs. 
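    //
    // As a worked example of that bidirectional walk: with a 16-byte vector and a 96-byte body, the
    // loop below copies [0, 16) and [80, 96) on the first iteration, [16, 32) and [64, 80) on the
    // second, and [32, 48) and [48, 64) on the third, converging toward the middle; odd sizes leave a
    // remainder for the two masked tails that follow.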
- // - // For gigantic datasets, similar to AVX-512, non-temporal "loads" and "stores" can be used. - // Sadly, if the register size (16 byte or larger) is smaller than a cache-line (64 bytes) - // we will pay a huge penalty on loads, fetching the same content many times. - // It may be better to allow caching (and subsequent eviction), in favor of using four-element - // tuples, wich will be guaranteed to be a multiple of a cache line. - // - // Another approach is to use the `LD4B` instructions, which will populate four registers at once. - // This however, further decreases the performance from LibC-like 29 GB/s to 20 GB/s. - else { - // Calculating head, body, and tail sizes depends on the `vec_len`, - // but it's runtime constant, and the modulo operation is expensive! - // Instead we use the fact, that it's always a multiple of 128 bits or 16 bytes. - sz_size_t head_length = 16 - ((sz_size_t)target % 16); - sz_size_t tail_length = (sz_size_t)(target + length) % 16; - sz_size_t body_length = length - head_length - tail_length; - - // Handle unaligned parts - svbool_t head_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)head_length); - svuint8_t head_data = svld1_u8(head_mask, (unsigned char *)source); - svst1_u8(head_mask, (unsigned char *)target, head_data); - svbool_t tail_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)tail_length); - svuint8_t tail_data = svld1_u8(tail_mask, (unsigned char *)source + head_length + body_length); - svst1_u8(tail_mask, (unsigned char *)target + head_length + body_length, tail_data); - target += head_length; - source += head_length; - - // Aligned body loop, walking in two directions - for (; body_length >= vec_len * 2; target += vec_len, source += vec_len, body_length -= vec_len * 2) { - svuint8_t forward_data = svld1_u8(svptrue_b8(), (unsigned char *)source); - svuint8_t backward_data = svld1_u8(svptrue_b8(), (unsigned char *)source + body_length - vec_len); - svst1_u8(svptrue_b8(), (unsigned char *)target, forward_data); - svst1_u8(svptrue_b8(), (unsigned char *)target + body_length - vec_len, backward_data); - } - // Up to (vec_len * 2 - 1) bytes of data may be left in the body, - // so we can unroll the last two optional loop iterations. - if (body_length > vec_len) { - svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)body_length); - svuint8_t data = svld1_u8(mask, (unsigned char *)source); - svst1_u8(mask, (unsigned char *)target, data); - body_length -= vec_len; - source += body_length; - target += body_length; - } - if (body_length) { - svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)body_length); - svuint8_t data = svld1_u8(mask, (unsigned char *)source); - svst1_u8(mask, (unsigned char *)target, data); - } - } -} - #pragma clang attribute pop #pragma GCC pop_options -#endif // Arm SVE +#endif // SZ_USE_SVE +#pragma endregion // SVE Implementation -#pragma endregion - -/* - * @brief Pick the right implementation for the string search algorithms. +/* Pick the right implementation for the string search algorithms. + * To override this behavior and precompile all backends - set `SZ_DYNAMIC_DISPATCH` to 1. 
*/ #pragma region Compile Time Dispatching - -SZ_PUBLIC sz_u64_t sz_hash(sz_cptr_t ins, sz_size_t length) { return sz_hash_serial(ins, length); } -SZ_PUBLIC void sz_tolower(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) { sz_tolower_serial(ins, length, outs); } -SZ_PUBLIC void sz_toupper(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) { sz_toupper_serial(ins, length, outs); } -SZ_PUBLIC void sz_toascii(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) { sz_toascii_serial(ins, length, outs); } -SZ_PUBLIC sz_bool_t sz_isascii(sz_cptr_t ins, sz_size_t length) { return sz_isascii_serial(ins, length); } - -SZ_PUBLIC void sz_hashes_fingerprint(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_ptr_t fingerprint, - sz_size_t fingerprint_bytes) { - - sz_bool_t fingerprint_length_is_power_of_two = (sz_bool_t)((fingerprint_bytes & (fingerprint_bytes - 1)) == 0); - sz_string_view_t fingerprint_buffer = {fingerprint, fingerprint_bytes}; - - // There are several issues related to the fingerprinting algorithm. - // First, the memory traversal order is important. - // https://blog.stuffedcow.net/2015/08/pagewalk-coherence/ - - // In most cases the fingerprint length will be a power of two. - if (fingerprint_length_is_power_of_two == sz_false_k) - sz_hashes(start, length, window_length, 1, _sz_hashes_fingerprint_non_pow2_callback, &fingerprint_buffer); - else - sz_hashes(start, length, window_length, 1, _sz_hashes_fingerprint_pow2_callback, &fingerprint_buffer); -} - #if !SZ_DYNAMIC_DISPATCH -SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length) { -#if SZ_USE_X86_AVX512 - return sz_checksum_avx512(text, length); -#elif SZ_USE_X86_AVX2 - return sz_checksum_avx2(text, length); -#elif SZ_USE_ARM_NEON - return sz_checksum_neon(text, length); -#else - return sz_checksum_serial(text, length); -#endif -} - -SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { -#if SZ_USE_X86_AVX512 - return sz_equal_avx512(a, b, length); -#elif SZ_USE_X86_AVX2 - return sz_equal_avx2(a, b, length); -#elif SZ_USE_ARM_NEON - return sz_equal_neon(a, b, length); -#else - return sz_equal_serial(a, b, length); -#endif -} - -SZ_DYNAMIC sz_ordering_t sz_order(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { -#if SZ_USE_X86_AVX512 - return sz_order_avx512(a, a_length, b, b_length); -#elif SZ_USE_X86_AVX2 - return sz_order_avx2(a, a_length, b, b_length); -#elif SZ_USE_ARM_NEON - return sz_order_neon(a, a_length, b, b_length); -#else - return sz_order_serial(a, a_length, b, b_length); -#endif -} - -SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { -#if SZ_USE_X86_AVX512 - sz_copy_avx512(target, source, length); -#elif SZ_USE_X86_AVX2 - sz_copy_avx2(target, source, length); -#elif SZ_USE_ARM_NEON - sz_copy_neon(target, source, length); -#else - sz_copy_serial(target, source, length); -#endif -} - -SZ_DYNAMIC void sz_move(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { -#if SZ_USE_X86_AVX512 - sz_move_avx512(target, source, length); -#elif SZ_USE_X86_AVX2 - sz_move_avx2(target, source, length); -#elif SZ_USE_ARM_NEON - sz_move_neon(target, source, length); -#else - sz_move_serial(target, source, length); -#endif -} - -SZ_DYNAMIC void sz_fill(sz_ptr_t target, sz_size_t length, sz_u8_t value) { -#if SZ_USE_X86_AVX512 - sz_fill_avx512(target, length, value); -#elif SZ_USE_X86_AVX2 - sz_fill_avx2(target, length, value); -#elif SZ_USE_ARM_NEON - sz_fill_neon(target, length, value); -#else - sz_fill_serial(target, length, value); -#endif -} - 
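// For illustration, when `SZ_DYNAMIC_DISPATCH` is disabled, these thin wrappers are expected to inline
// away, and the preprocessor picks exactly one backend per call site. On a hypothetical Arm build with
// only `SZ_USE_ARM_NEON` enabled, the calls below resolve straight to the NEON kernels:
//
//      char buffer[4096], scratch[4096];
//      sz_fill(buffer, sizeof(buffer), 0);               // compiles down to sz_fill_neon(...)
//      sz_copy(scratch, buffer, sizeof(scratch));        // compiles down to sz_copy_neon(...)
//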
-SZ_DYNAMIC void sz_look_up_transform(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { -#if SZ_USE_X86_AVX512 - sz_look_up_transform_avx512(source, length, lut, target); -#elif SZ_USE_X86_AVX2 - sz_look_up_transform_avx2(source, length, lut, target); -#elif SZ_USE_ARM_NEON - sz_look_up_transform_neon(source, length, lut, target); -#else - sz_look_up_transform_serial(source, length, lut, target); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) { -#if SZ_USE_X86_AVX512 - return sz_find_byte_avx512(haystack, h_length, needle); -#elif SZ_USE_X86_AVX2 - return sz_find_byte_avx2(haystack, h_length, needle); -#elif SZ_USE_ARM_NEON - return sz_find_byte_neon(haystack, h_length, needle); -#else - return sz_find_byte_serial(haystack, h_length, needle); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_rfind_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) { -#if SZ_USE_X86_AVX512 - return sz_rfind_byte_avx512(haystack, h_length, needle); -#elif SZ_USE_X86_AVX2 - return sz_rfind_byte_avx2(haystack, h_length, needle); -#elif SZ_USE_ARM_NEON - return sz_rfind_byte_neon(haystack, h_length, needle); -#else - return sz_rfind_byte_serial(haystack, h_length, needle); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_find(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length) { -#if SZ_USE_X86_AVX512 - return sz_find_avx512(haystack, h_length, needle, n_length); -#elif SZ_USE_X86_AVX2 - return sz_find_avx2(haystack, h_length, needle, n_length); -#elif SZ_USE_ARM_NEON - return sz_find_neon(haystack, h_length, needle, n_length); -#else - return sz_find_serial(haystack, h_length, needle, n_length); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_rfind(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length) { -#if SZ_USE_X86_AVX512 - return sz_rfind_avx512(haystack, h_length, needle, n_length); -#elif SZ_USE_X86_AVX2 - return sz_rfind_avx2(haystack, h_length, needle, n_length); -#elif SZ_USE_ARM_NEON - return sz_rfind_neon(haystack, h_length, needle, n_length); -#else - return sz_rfind_serial(haystack, h_length, needle, n_length); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { -#if SZ_USE_X86_AVX512 - return sz_find_charset_avx512(text, length, set); -#elif SZ_USE_X86_AVX2 - return sz_find_charset_avx2(text, length, set); -#elif SZ_USE_ARM_NEON - return sz_find_charset_neon(text, length, set); -#else - return sz_find_charset_serial(text, length, set); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_rfind_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { -#if SZ_USE_X86_AVX512 - return sz_rfind_charset_avx512(text, length, set); -#elif SZ_USE_X86_AVX2 - return sz_rfind_charset_avx2(text, length, set); -#elif SZ_USE_ARM_NEON - return sz_rfind_charset_neon(text, length, set); -#else - return sz_rfind_charset_serial(text, length, set); -#endif -} - SZ_DYNAMIC sz_size_t sz_hamming_distance( // sz_cptr_t a, sz_size_t a_length, // sz_cptr_t b, sz_size_t b_length, // @@ -7075,7 +1340,7 @@ SZ_DYNAMIC sz_size_t sz_edit_distance( // sz_cptr_t a, sz_size_t a_length, // sz_cptr_t b, sz_size_t b_length, // sz_size_t bound, sz_memory_allocator_t *alloc) { -#if SZ_USE_X86_AVX512 +#if SZ_USE_ICE return sz_edit_distance_avx512(a, a_length, b, b_length, bound, alloc); #else return sz_edit_distance_serial(a, a_length, b, b_length, bound, alloc); @@ -7089,68 +1354,21 @@ SZ_DYNAMIC sz_size_t sz_edit_distance_utf8( // return 
_sz_edit_distance_wagner_fisher_serial(a, a_length, b, b_length, bound, sz_true_k, alloc); } -SZ_DYNAMIC sz_ssize_t sz_alignment_score(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, - sz_error_cost_t const *subs, sz_error_cost_t gap, - sz_memory_allocator_t *alloc) { -#if SZ_USE_X86_AVX512 - return sz_alignment_score_avx512(a, a_length, b, b_length, subs, gap, alloc); +SZ_DYNAMIC sz_ssize_t sz_alignment_score( // + sz_cptr_t a, sz_size_t a_length, // + sz_cptr_t b, sz_size_t b_length, // + sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc) { +#if SZ_USE_ICE + return sz_alignment_score_ice(a, a_length, b, b_length, subs, gap, alloc); #else return sz_alignment_score_serial(a, a_length, b, b_length, subs, gap, alloc); #endif } -SZ_DYNAMIC void sz_hashes(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // - sz_hash_callback_t callback, void *callback_handle) { -#if SZ_USE_X86_AVX512 - sz_hashes_avx512(text, length, window_length, window_step, callback, callback_handle); -#elif SZ_USE_X86_AVX2 - sz_hashes_avx2(text, length, window_length, window_step, callback, callback_handle); -#else - sz_hashes_serial(text, length, window_length, window_step, callback, callback_handle); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_find_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - return sz_find_charset(h, h_length, &set); -} - -SZ_DYNAMIC sz_cptr_t sz_find_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - sz_charset_invert(&set); - return sz_find_charset(h, h_length, &set); -} - -SZ_DYNAMIC sz_cptr_t sz_rfind_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - return sz_rfind_charset(h, h_length, &set); -} - -SZ_DYNAMIC sz_cptr_t sz_rfind_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - sz_charset_invert(&set); - return sz_rfind_charset(h, h_length, &set); -} - -SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length, - sz_random_generator_t generator, void *generator_user_data) { - sz_generate_serial(alphabet, alphabet_size, result, result_length, generator, generator_user_data); -} - -#endif -#pragma endregion +#endif // !SZ_DYNAMIC_DISPATCH +#pragma endregion // Compile Time Dispatching #ifdef __cplusplus -#pragma GCC diagnostic pop } #endif // __cplusplus - -#endif // STRINGZILLA_H_ +#endif // STRINGZILLA_SIMISLARITY_H_ From be4c63d926c8628451726863e4d14dbd1ea374dd Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 15:37:01 +0000 Subject: [PATCH 034/751] Fix: Filter `hash.h` file --- include/stringzilla/hash.h | 7483 +++--------------------------------- 1 file changed, 621 insertions(+), 6862 deletions(-) diff --git a/include/stringzilla/hash.h b/include/stringzilla/hash.h index de7fbcac..bf24a5e6 100644 --- a/include/stringzilla/hash.h +++ b/include/stringzilla/hash.h @@ -1,422 +1,30 @@ /** - * @brief StringZilla is a collection of advanced string algorithms, 
designed to be used in Big Data applications. - * It is generally faster than LibC, and has a broader & cleaner interface, and targets modern x86 CPUs - * with AVX-512 and Arm NEON and older CPUs with SWAR and auto-vectorization. - * - * Consider overriding the following macros to customize the library: - * - * - `SZ_DEBUG=0` - whether to enable debug assertions and logging. - * - `SZ_DYNAMIC_DISPATCH=0` - whether to use runtime dispatching of the most advanced SIMD backend. - * - `SZ_USE_MISALIGNED_LOADS=0` - whether to use misaligned loads on platforms that support them. - * - `SZ_SWAR_THRESHOLD=24` - threshold for switching to SWAR backend over serial byte-level for-loops. - * - `SZ_USE_X86_AVX512=?` - whether to use AVX-512 instructions on x86_64. - * - `SZ_USE_X86_AVX2=?` - whether to use AVX2 instructions on x86_64. - * - `SZ_USE_ARM_NEON=?` - whether to use NEON instructions on ARM. - * - `SZ_USE_ARM_SVE=?` - whether to use SVE instructions on ARM. - * - * @see StringZilla: https://github.com/ashvardanian/StringZilla/blob/main/README.md - * @see LibC String: https://pubs.opengroup.org/onlinepubs/009695399/basedefs/string.h.html - * - * @file stringzilla.h + * @brief Hardware-accelerated string hashing and checksums. + * @file hash.h * @author Ash Vardanian - */ -#ifndef STRINGZILLA_H_ -#define STRINGZILLA_H_ - -#define STRINGZILLA_VERSION_MAJOR 3 -#define STRINGZILLA_VERSION_MINOR 11 -#define STRINGZILLA_VERSION_PATCH 0 - -/** - * @brief When set to 1, the library will include the following LibC headers: and . - * In debug builds (SZ_DEBUG=1), the library will also include and . * - * You may want to disable this compiling for use in the kernel, or in embedded systems. - * You may also avoid them, if you are very sensitive to compilation time and avoid pre-compiled headers. - * https://artificial-mind.net/projects/compile-health/ - */ -#ifndef SZ_AVOID_LIBC -#define SZ_AVOID_LIBC (0) // true or false -#endif - -/** - * @brief A misaligned load can be - trying to fetch eight consecutive bytes from an address - * that is not divisible by eight. On x86 enabled by default. On ARM it's not. + * Includes core APIs: * - * Most platforms support it, but there is no industry standard way to check for those. - * This value will mostly affect the performance of the serial (SWAR) backend. - */ -#ifndef SZ_USE_MISALIGNED_LOADS -#if defined(__x86_64__) || defined(_M_X64) || defined(__i386__) || defined(_M_IX86) -#define SZ_USE_MISALIGNED_LOADS (1) // true or false -#else -#define SZ_USE_MISALIGNED_LOADS (0) // true or false -#endif -#endif - -/** - * @brief Removes compile-time dispatching, and replaces it with runtime dispatching. - * So the `sz_find` function will invoke the most advanced backend supported by the CPU, - * that runs the program, rather than the most advanced backend supported by the CPU - * used to compile the library or the downstream application. - */ -#ifndef SZ_DYNAMIC_DISPATCH -#define SZ_DYNAMIC_DISPATCH (0) // true or false -#endif - -/** - * @brief Analogous to `size_t` and `std::size_t`, unsigned integer, identical to pointer size. - * 64-bit on most platforms where pointers are 64-bit. - * 32-bit on platforms where pointers are 32-bit. - */ -#if defined(__LP64__) || defined(_LP64) || defined(__x86_64__) || defined(_WIN64) -#define SZ_DETECT_64_BIT (1) -#define SZ_SIZE_MAX (0xFFFFFFFFFFFFFFFFull) // Largest unsigned integer that fits into 64 bits. -#define SZ_SSIZE_MAX (0x7FFFFFFFFFFFFFFFull) // Largest signed integer that fits into 64 bits. 
-#else -#define SZ_DETECT_64_BIT (0) -#define SZ_SIZE_MAX (0xFFFFFFFFu) // Largest unsigned integer that fits into 32 bits. -#define SZ_SSIZE_MAX (0x7FFFFFFFu) // Largest signed integer that fits into 32 bits. -#endif - -/** - * @brief On Big-Endian machines StringZilla will work in compatibility mode. - * This disables SWAR hacks to minimize code duplication, assuming practically - * all modern popular platforms are Little-Endian. + * - `sz_checksum` - for byte-level checksums. + * - `sz_hash` - for 64-bit single-shot hashing. + * - `sz_hashes` - producing the rolling hashes of a string. + * - `sz_generate` - populating buffers with random data. * - * This variable is hard to infer from macros reliably. It's best to set it manually. - * For that CMake provides the `TestBigEndian` and `CMAKE__BYTE_ORDER` (from 3.20 onwards). - * In Python one can check `sys.byteorder == 'big'` in the `setup.py` script and pass the appropriate macro. - * https://stackoverflow.com/a/27054190 - */ -#ifndef SZ_DETECT_BIG_ENDIAN -#if defined(__BYTE_ORDER) && __BYTE_ORDER == __BIG_ENDIAN || defined(__BIG_ENDIAN__) || defined(__ARMEB__) || \ - defined(__THUMBEB__) || defined(__AARCH64EB__) || defined(_MIBSEB) || defined(__MIBSEB) || defined(__MIBSEB__) -#define SZ_DETECT_BIG_ENDIAN (1) //< It's a big-endian target architecture -#else -#define SZ_DETECT_BIG_ENDIAN (0) //< It's a little-endian target architecture -#endif -#endif - -/* - * Debugging and testing. - */ -#ifndef SZ_DEBUG -#if defined(DEBUG) || defined(_DEBUG) // This means "Not using DEBUG information". -#define SZ_DEBUG (1) -#else -#define SZ_DEBUG (0) -#endif -#endif - -/** - * @brief Threshold for switching to SWAR (8-bytes at a time) backend over serial byte-level for-loops. - * On very short strings, under 16 bytes long, at most a single word will be processed with SWAR. - * Assuming potentially misaligned loads, SWAR makes sense only after ~24 bytes. - */ -#ifndef SZ_SWAR_THRESHOLD -#if SZ_DEBUG -#define SZ_SWAR_THRESHOLD (8u) // 8 bytes in debug builds -#else -#define SZ_SWAR_THRESHOLD (24u) // 24 bytes in release builds -#endif -#endif - -/* Annotation for the public API symbols: + * Convenience functions for character-set matching: * - * - `SZ_PUBLIC` is used for functions that are part of the public API. - * - `SZ_INTERNAL` is used for internal helper functions with unstable APIs. - * - `SZ_DYNAMIC` is used for functions that are part of the public API, but are dispatched at runtime. + * - `sz_hashes_fingerprint` + * - `sz_hashes_intersection` */ -#ifndef SZ_DYNAMIC -#if SZ_DYNAMIC_DISPATCH -#if defined(_WIN32) || defined(__CYGWIN__) -#define SZ_DYNAMIC __declspec(dllexport) -#define SZ_EXTERNAL __declspec(dllimport) -#define SZ_PUBLIC inline static -#define SZ_INTERNAL inline static -#else -#define SZ_DYNAMIC __attribute__((visibility("default"))) -#define SZ_EXTERNAL extern -#define SZ_PUBLIC __attribute__((unused)) inline static -#define SZ_INTERNAL __attribute__((always_inline)) inline static -#endif // _WIN32 || __CYGWIN__ -#else -#define SZ_DYNAMIC inline static -#define SZ_EXTERNAL extern -#define SZ_PUBLIC inline static -#define SZ_INTERNAL inline static -#endif // SZ_DYNAMIC_DISPATCH -#endif // SZ_DYNAMIC +#ifndef STRINGZILLA_HASH_H_ +#define STRINGZILLA_HASH_H_ -/** - * @brief Alignment macro for 64-byte alignment. 
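 *        Typically placed in front of a definition, e.g. for a hypothetical table:
 *
 *            SZ_ALIGN64 sz_u8_t cache_line_sized_row[64];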
- */ -#if defined(_MSC_VER) -#define SZ_ALIGN64 __declspec(align(64)) -#elif defined(__GNUC__) || defined(__clang__) -#define SZ_ALIGN64 __attribute__((aligned(64))) -#else -#define SZ_ALIGN64 -#endif +#include "types.h" #ifdef __cplusplus extern "C" { #endif -/* - * Let's infer the integer types or pull them from LibC, - * if that is allowed by the user. - */ -#if !SZ_AVOID_LIBC -#include // `size_t` -#include // `uint8_t` -typedef int8_t sz_i8_t; // Always 8 bits -typedef uint8_t sz_u8_t; // Always 8 bits -typedef uint16_t sz_u16_t; // Always 16 bits -typedef int32_t sz_i32_t; // Always 32 bits -typedef uint32_t sz_u32_t; // Always 32 bits -typedef uint64_t sz_u64_t; // Always 64 bits -typedef int64_t sz_i64_t; // Always 64 bits -typedef size_t sz_size_t; // Pointer-sized unsigned integer, 32 or 64 bits -typedef ptrdiff_t sz_ssize_t; // Signed version of `sz_size_t`, 32 or 64 bits - -#else // if SZ_AVOID_LIBC: - -// ! The C standard doesn't specify the signedness of char. -// ! On x86 char is signed by default while on Arm it is unsigned by default. -// ! That's why we don't define `sz_char_t` and generally use explicit `sz_i8_t` and `sz_u8_t`. -typedef signed char sz_i8_t; // Always 8 bits -typedef unsigned char sz_u8_t; // Always 8 bits -typedef unsigned short sz_u16_t; // Always 16 bits -typedef int sz_i32_t; // Always 32 bits -typedef unsigned int sz_u32_t; // Always 32 bits -typedef long long sz_i64_t; // Always 64 bits -typedef unsigned long long sz_u64_t; // Always 64 bits - -// Now we need to redefine the `size_t`. -// Microsoft Visual C++ (MSVC) typically follows LLP64 data model on 64-bit platforms, -// where integers, pointers, and long types have different sizes: -// -// > `int` is 32 bits -// > `long` is 32 bits -// > `long long` is 64 bits -// > pointer (thus, `size_t`) is 64 bits -// -// In contrast, GCC and Clang on 64-bit Unix-like systems typically follow the LP64 model, where: -// -// > `int` is 32 bits -// > `long` and pointer (thus, `size_t`) are 64 bits -// > `long long` is also 64 bits -// -// Source: https://learn.microsoft.com/en-us/windows/win32/winprog64/abstract-data-models -#if SZ_DETECT_64_BIT -typedef unsigned long long sz_size_t; // 64-bit. -typedef long long sz_ssize_t; // 64-bit. -#else -typedef unsigned sz_size_t; // 32-bit. -typedef unsigned sz_ssize_t; // 32-bit. -#endif // SZ_DETECT_64_BIT - -#endif // SZ_AVOID_LIBC - -/** - * @brief Compile-time assert macro similar to `static_assert` in C++. - */ -#define sz_static_assert(condition, name) \ - typedef struct { \ - int static_assert_##name : (condition) ? 1 : -1; \ - } sz_static_assert_##name##_t - -sz_static_assert(sizeof(sz_size_t) == sizeof(void *), sz_size_t_must_be_pointer_size); -sz_static_assert(sizeof(sz_ssize_t) == sizeof(void *), sz_ssize_t_must_be_pointer_size); - -#pragma region Public API - -typedef char *sz_ptr_t; // A type alias for `char *` -typedef char const *sz_cptr_t; // A type alias for `char const *` -typedef sz_i8_t sz_error_cost_t; // Character mismatch cost for fuzzy matching functions - -typedef sz_u64_t sz_sorted_idx_t; // Index of a sorted string in a list of strings - -typedef enum { sz_false_k = 0, sz_true_k = 1 } sz_bool_t; // Only one relevant bit -typedef enum { sz_less_k = -1, sz_equal_k = 0, sz_greater_k = 1 } sz_ordering_t; // Only three possible states: <=> - -/** - * @brief Tiny string-view structure. It's POD type, unlike the `std::string_view`. 
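 *        Being a plain aggregate, it can be initialized in place, e.g.:
 *
 *            sz_string_view_t haystack_view = {haystack_pointer, haystack_length};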
- */ -typedef struct sz_string_view_t { - sz_cptr_t start; - sz_size_t length; -} sz_string_view_t; - -/** - * @brief Enumeration of SIMD capabilities of the target architecture. - * Used to introspect the supported functionality of the dynamic library. - */ -typedef enum sz_capability_t { - sz_cap_serial_k = 1, /// Serial (non-SIMD) capability - sz_cap_any_k = 0x7FFFFFFF, /// Mask representing any capability - - sz_cap_arm_neon_k = 1 << 10, /// ARM NEON capability - sz_cap_arm_sve_k = 1 << 11, /// ARM SVE capability TODO: Not yet supported or used - sz_cap_arm_sve2_k = 1 << 12, - sz_cap_arm_sve2p1_k = 1 << 13, - sz_cap_x86_avx2_k = 1 << 20, /// x86 AVX2 capability - sz_cap_x86_avx512f_k = 1 << 21, /// x86 AVX512 F capability - sz_cap_x86_avx512bw_k = 1 << 22, /// x86 AVX512 BW instruction capability - sz_cap_x86_avx512vl_k = 1 << 23, /// x86 AVX512 VL instruction capability - sz_cap_x86_avx512vbmi_k = 1 << 24, /// x86 AVX512 VBMI instruction capability - sz_cap_x86_gfni_k = 1 << 25, /// x86 AVX512 GFNI instruction capability - -} sz_capability_t; - -/** - * @brief Function to determine the SIMD capabilities of the current machine @b only at @b runtime. - * @return A bitmask of the SIMD capabilities represented as a `sz_capability_t` enum value. - */ -SZ_DYNAMIC sz_capability_t sz_capabilities(void); - -/** - * @brief Bit-set structure for 256 possible byte values. Useful for filtering and search. - * @see sz_charset_init, sz_charset_add, sz_charset_contains, sz_charset_invert - */ -typedef union sz_charset_t { - sz_u64_t _u64s[4]; - sz_u32_t _u32s[8]; - sz_u16_t _u16s[16]; - sz_u8_t _u8s[32]; -} sz_charset_t; - -/** @brief Initializes a bit-set to an empty collection, meaning - all characters are banned. */ -SZ_PUBLIC void sz_charset_init(sz_charset_t *s) { s->_u64s[0] = s->_u64s[1] = s->_u64s[2] = s->_u64s[3] = 0; } - -/** @brief Adds a character to the set and accepts @b unsigned integers. */ -SZ_PUBLIC void sz_charset_add_u8(sz_charset_t *s, sz_u8_t c) { s->_u64s[c >> 6] |= (1ull << (c & 63u)); } - -/** @brief Adds a character to the set. Consider @b sz_charset_add_u8. */ -SZ_PUBLIC void sz_charset_add(sz_charset_t *s, char c) { sz_charset_add_u8(s, *(sz_u8_t *)(&c)); } // bitcast - -/** @brief Checks if the set contains a given character and accepts @b unsigned integers. */ -SZ_PUBLIC sz_bool_t sz_charset_contains_u8(sz_charset_t const *s, sz_u8_t c) { - // Checking the bit can be done in different ways: - // - (s->_u64s[c >> 6] & (1ull << (c & 63u))) != 0 - // - (s->_u32s[c >> 5] & (1u << (c & 31u))) != 0 - // - (s->_u16s[c >> 4] & (1u << (c & 15u))) != 0 - // - (s->_u8s[c >> 3] & (1u << (c & 7u))) != 0 - return (sz_bool_t)((s->_u64s[c >> 6] & (1ull << (c & 63u))) != 0); -} - -/** @brief Checks if the set contains a given character. Consider @b sz_charset_contains_u8. */ -SZ_PUBLIC sz_bool_t sz_charset_contains(sz_charset_t const *s, char c) { - return sz_charset_contains_u8(s, *(sz_u8_t *)(&c)); // bitcast -} - -/** @brief Inverts the contents of the set, so allowed character get disallowed, and vice versa. */ -SZ_PUBLIC void sz_charset_invert(sz_charset_t *s) { - s->_u64s[0] ^= 0xFFFFFFFFFFFFFFFFull, s->_u64s[1] ^= 0xFFFFFFFFFFFFFFFFull, // - s->_u64s[2] ^= 0xFFFFFFFFFFFFFFFFull, s->_u64s[3] ^= 0xFFFFFFFFFFFFFFFFull; -} - -typedef void *(*sz_memory_allocate_t)(sz_size_t, void *); -typedef void (*sz_memory_free_t)(void *, sz_size_t, void *); -typedef sz_u64_t (*sz_random_generator_t)(void *); - -/** - * @brief Some complex pattern matching algorithms may require memory allocations. 
- * This structure is used to pass the memory allocator to those functions. - * @see sz_memory_allocator_init_fixed - */ -typedef struct sz_memory_allocator_t { - sz_memory_allocate_t allocate; - sz_memory_free_t free; - void *handle; -} sz_memory_allocator_t; - -/** - * @brief Initializes a memory allocator to use the system default `malloc` and `free`. - * ! The function is not available if the library was compiled with `SZ_AVOID_LIBC`. - * - * @param alloc Memory allocator to initialize. - */ -SZ_PUBLIC void sz_memory_allocator_init_default(sz_memory_allocator_t *alloc); - -/** - * @brief Initializes a memory allocator to use a static-capacity buffer. - * No dynamic allocations will be performed. - * - * @param alloc Memory allocator to initialize. - * @param buffer Buffer to use for allocations. - * @param length Length of the buffer. @b Must be greater than 8 bytes. Different values would be optimal for - * different algorithms and input lengths, but 4096 bytes (one RAM page) is a good default. - */ -SZ_PUBLIC void sz_memory_allocator_init_fixed(sz_memory_allocator_t *alloc, void *buffer, sz_size_t length); - -/** - * @brief The number of bytes a stack-allocated string can hold, including the SZ_NULL termination character. - * ! This can't be changed from outside. Don't use the `#error` as it may already be included and set. - */ -#ifdef SZ_STRING_INTERNAL_SPACE -#undef SZ_STRING_INTERNAL_SPACE -#endif -#define SZ_STRING_INTERNAL_SPACE (sizeof(sz_size_t) * 3 - 1) // 3 pointers minus one byte for an 8-bit length - -/** - * @brief Tiny memory-owning string structure with a Small String Optimization (SSO). - * Differs in layout from Folly, Clang, GCC, and probably most other implementations. - * It's designed to avoid any branches on read-only operations, and can store up - * to 22 characters on stack on 64-bit machines, followed by the SZ_NULL-termination character. - * - * @section Changing Length - * - * One nice thing about this design, is that you can, in many cases, change the length of the string - * without any branches, invoking a `+=` or `-=` on the 64-bit `length` field. If the string is on heap, - * the solution is obvious. If it's on stack, inplace decrement wouldn't affect the top bytes of the string, - * only changing the last byte containing the length. - */ -typedef union sz_string_t { - -#if !SZ_DETECT_BIG_ENDIAN - - struct external { - sz_ptr_t start; - sz_size_t length; - sz_size_t space; - sz_size_t padding; - } external; - - struct internal { - sz_ptr_t start; - sz_u8_t length; - char chars[SZ_STRING_INTERNAL_SPACE]; - } internal; - -#else - - struct external { - sz_ptr_t start; - sz_size_t space; - sz_size_t padding; - sz_size_t length; - } external; - - struct internal { - sz_ptr_t start; - char chars[SZ_STRING_INTERNAL_SPACE]; - sz_u8_t length; - } internal; - -#endif - - sz_size_t words[4]; - -} sz_string_t; - -typedef sz_u64_t (*sz_hash_t)(sz_cptr_t, sz_size_t); -typedef sz_u64_t (*sz_checksum_t)(sz_cptr_t, sz_size_t); -typedef sz_bool_t (*sz_equal_t)(sz_cptr_t, sz_cptr_t, sz_size_t); -typedef sz_ordering_t (*sz_order_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); -typedef void (*sz_to_converter_t)(sz_cptr_t, sz_size_t, sz_ptr_t); +#pragma region Core API /** * @brief Computes the 64-bit check-sum of bytes in a string. 
@@ -428,9 +36,6 @@ typedef void (*sz_to_converter_t)(sz_cptr_t, sz_size_t, sz_ptr_t); */ SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length); -/** @copydoc sz_checksum */ -SZ_PUBLIC sz_u64_t sz_checksum_serial(sz_cptr_t text, sz_size_t length); - /** * @brief Computes the 64-bit unsigned hash of a string. Fairly fast for short strings, * simple implementation, and supports rolling computation, reused in other APIs. @@ -444,108 +49,74 @@ SZ_PUBLIC sz_u64_t sz_checksum_serial(sz_cptr_t text, sz_size_t length); */ SZ_PUBLIC sz_u64_t sz_hash(sz_cptr_t text, sz_size_t length); -/** @copydoc sz_hash */ -SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t text, sz_size_t length); - /** - * @brief Checks if two string are equal. - * Similar to `memcmp(a, b, length) == 0` in LibC and `a == b` in STL. - * - * The implementation of this function is very similar to `sz_order`, but the usage patterns are different. - * This function is more often used in parsing, while `sz_order` is often used in sorting. - * It works best on platforms with cheap + * @brief Computes the Karp-Rabin rolling hashes of a string supplying them to the provided `callback`. + * Can be used for similarity scores, search, ranking, etc. * - * @param a First string to compare. - * @param b Second string to compare. - * @param length Number of bytes in both strings. - * @return 1 if strings match, 0 otherwise. - */ -SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length); - -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_serial(sz_cptr_t a, sz_cptr_t b, sz_size_t length); - -/** - * @brief Estimates the relative order of two strings. Equivalent to `memcmp(a, b, length)` in LibC. - * Can be used on different length strings. + * Rabin-Karp-like rolling hashes can have very high-level of collisions and depend + * on the choice of bases and the prime number. That's why, often two hashes from the same + * family are used with different bases. * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * @return Negative if (a < b), positive if (a > b), zero if they are equal. - */ -SZ_DYNAMIC sz_ordering_t sz_order(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); - -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); - -/** - * @brief Look Up Table @b (LUT) transformation of a string. Equivalent to `for (char & c : text) c = lut[c]`. + * 1. Kernighan and Ritchie's function uses 31, a prime close to the size of English alphabet. + * 2. To be friendlier to byte-arrays and UTF8, we use 257 for the second function. * - * Can be used to implement some form of string normalization, partially masking punctuation marks, - * or converting between different character sets, like uppercase or lowercase. Surprisingly, also has - * broad implications in image processing, where image channel transformations are often done using LUTs. + * Choosing the right ::window_length is task- and domain-dependant. For example, most English words are + * between 3 and 7 characters long, so a window of 4 bytes would be a good choice. For DNA sequences, + * the ::window_length might be a multiple of 3, as the codons are 3 (nucleotides) bytes long. + * With such minimalistic alphabets of just four characters (AGCT) longer windows might be needed. 
+ * For protein sequences the alphabet is 20 characters long, so the window can be shorter, than for DNAs. * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param lut Look Up Table to apply. Must be exactly @b 256 bytes long. - * @param result Output string, can point to the same address as ::text. + * @param text String to hash. + * @param length Number of bytes in the string. + * @param window_length Length of the rolling window in bytes. + * @param window_step Step of reported hashes. @b Must be power of two. Should be smaller than `window_length`. + * @param callback Function receiving the start & length of a substring, the hash, and the `callback_handle`. + * @param callback_handle Optional user-provided pointer to be passed to the `callback`. + * @see sz_hashes_fingerprint, sz_hashes_intersection */ -SZ_DYNAMIC void sz_look_up_transform(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result); - -typedef void (*sz_look_up_transform_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_ptr_t); - -/** @copydoc sz_look_up_transform */ -SZ_PUBLIC void sz_look_up_transform_serial(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result); +SZ_DYNAMIC void sz_hashes(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // + sz_hash_callback_t callback, void *callback_handle); /** - * @brief Equivalent to `for (char & c : text) c = tolower(c)`. - * - * ASCII characters [A, Z] map to decimals [65, 90], and [a, z] map to [97, 122]. - * So there are 26 english letters, shifted by 32 values, meaning that a conversion - * can be done by flipping the 5th bit each inappropriate character byte. This, however, - * breaks for extended ASCII, so a different solution is needed. - * http://0x80.pl/notesen/2016-01-06-swar-swap-case.html + * @brief Computes the Karp-Rabin rolling hashes of a string outputting a binary fingerprint. + * Such fingerprints can be compared with Hamming or Jaccard (Tanimoto) distance for similarity. * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param result Output string, can point to the same address as ::text. - */ -SZ_PUBLIC void sz_tolower(sz_cptr_t text, sz_size_t length, sz_ptr_t result); - -/** - * @brief Equivalent to `for (char & c : text) c = toupper(c)`. + * The algorithm doesn't clear the fingerprint buffer on start, so it can be invoked multiple times + * to produce a fingerprint of a longer string, by passing the previous fingerprint as the ::fingerprint. + * It can also be reused to produce multi-resolution fingerprints by changing the ::window_length + * and calling the same function multiple times for the same input ::text. * - * ASCII characters [A, Z] map to decimals [65, 90], and [a, z] map to [97, 122]. - * So there are 26 english letters, shifted by 32 values, meaning that a conversion - * can be done by flipping the 5th bit each inappropriate character byte. This, however, - * breaks for extended ASCII, so a different solution is needed. - * http://0x80.pl/notesen/2016-01-06-swar-swap-case.html + * Processes large strings in parts to maximize the cache utilization, using a small on-stack buffer, + * avoiding cache-coherency penalties of remote on-heap buffers. * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param result Output string, can point to the same address as ::text. + * @param text String to hash. + * @param length Number of bytes in the string. 
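/* Usage sketch, not part of the patch: a minimal caller for `sz_hashes`, keeping the smallest
 * rolling hash of every 4-byte window as a crude MinHash-style signature. The callback signature
 * follows `sz_hash_callback_t` declared in this header family; the umbrella include path, the
 * `keep_min_hash` helper, and the window/step choices are illustrative assumptions. */
#include <stdio.h>
#include <stringzilla/stringzilla.h>

static void keep_min_hash(sz_cptr_t start, sz_size_t length, sz_u64_t hash, void *handle) {
    sz_u64_t *min_hash = (sz_u64_t *)handle;
    (void)start, (void)length;
    if (hash < *min_hash) *min_hash = hash;
}

int main(void) {
    char const text[] = "the quick brown fox jumps over the lazy dog";
    sz_u64_t min_hash = 0xFFFFFFFFFFFFFFFFull;
    // Window of 4 bytes, reporting every window (step of 1, a power of two).
    sz_hashes(text, sizeof(text) - 1, 4, 1, &keep_min_hash, &min_hash);
    printf("smallest rolling hash: 0x%016llx\n", (unsigned long long)min_hash);
    return 0;
}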
+ * @param fingerprint Output fingerprint buffer. + * @param fingerprint_bytes Number of bytes in the fingerprint buffer. + * @param window_length Length of the rolling window in bytes. + * @see sz_hashes, sz_hashes_intersection */ -SZ_PUBLIC void sz_toupper(sz_cptr_t text, sz_size_t length, sz_ptr_t result); +SZ_PUBLIC void sz_hashes_fingerprint( // + sz_cptr_t text, sz_size_t length, sz_size_t window_length, // + sz_ptr_t fingerprint, sz_size_t fingerprint_bytes); /** - * @brief Equivalent to `for (char & c : text) c = toascii(c)`. + * @brief Given a hash-fingerprint of a textual document, computes the number of intersecting hashes + * of the incoming document. Can be used for document scoring and search. * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param result Output string, can point to the same address as ::text. - */ -SZ_PUBLIC void sz_toascii(sz_cptr_t text, sz_size_t length, sz_ptr_t result); - -/** - * @brief Checks if all characters in the range are valid ASCII characters. + * Processes large strings in parts to maximize the cache utilization, using a small on-stack buffer, + * avoiding cache-coherency penalties of remote on-heap buffers. * - * @param text String to be analyzed. - * @param length Number of bytes in the string. - * @return Whether all characters are valid ASCII characters. + * @param text Input document. + * @param length Number of bytes in the input document. + * @param fingerprint Reference document fingerprint. + * @param fingerprint_bytes Number of bytes in the reference documents fingerprint. + * @param window_length Length of the rolling window in bytes. + * @see sz_hashes, sz_hashes_fingerprint */ -SZ_PUBLIC sz_bool_t sz_isascii(sz_cptr_t text, sz_size_t length); +SZ_PUBLIC sz_size_t sz_hashes_intersection( // + sz_cptr_t text, sz_size_t length, sz_size_t window_length, // + sz_cptr_t fingerprint, sz_size_t fingerprint_bytes); /** * @brief Generates a random string for a given alphabet, avoiding integer division and modulo operations. @@ -567,5118 +138,312 @@ SZ_PUBLIC sz_bool_t sz_isascii(sz_cptr_t text, sz_size_t length); SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length, sz_random_generator_t generate, void *generator); +/** @copydoc sz_checksum */ +SZ_PUBLIC sz_u64_t sz_checksum_serial(sz_cptr_t text, sz_size_t length); +/** @copydoc sz_hash */ +SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t text, sz_size_t length); /** @copydoc sz_generate */ SZ_PUBLIC void sz_generate_serial(sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length, sz_random_generator_t generate, void *generator); +/** @copydoc sz_hashes */ +SZ_PUBLIC void sz_hashes_serial(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // + sz_hash_callback_t callback, void *callback_handle); -/** - * @brief Similar to `memcpy`, copies contents of one string into another. - * The behavior is undefined if the strings overlap. - * - * @param target String to copy into. - * @param length Number of bytes to copy. - * @param source String to copy from. - */ -SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length); - -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length); - -/** - * @brief Similar to `memmove`, copies (moves) contents of one string into another. - * Unlike `sz_copy`, allows overlapping strings as arguments. - * - * @param target String to copy into. 
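/* Usage sketch, not part of the patch: fingerprinting a reference document and scoring an incoming
 * one with `sz_hashes_fingerprint` and `sz_hashes_intersection`. The 64-byte buffer and the 7-byte
 * window are illustrative choices; the `memset` is needed because, as documented above, the
 * function does not clear the fingerprint buffer itself. */
#include <string.h>
#include <stringzilla/stringzilla.h>

int main(void) {
    char const reference[] = "the quick brown fox jumps over the lazy dog";
    char const query[] = "a quick brown fox ran by";
    char fingerprint[64];
    memset(fingerprint, 0, sizeof(fingerprint));
    sz_hashes_fingerprint(reference, sizeof(reference) - 1, 7, fingerprint, sizeof(fingerprint));
    sz_size_t shared = sz_hashes_intersection(query, sizeof(query) - 1, 7, fingerprint, sizeof(fingerprint));
    return shared > 0 ? 0 : 1; // larger `shared` values suggest more overlapping 7-grams
}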
- * @param length Number of bytes to copy. - * @param source String to copy from. - */ -SZ_DYNAMIC void sz_move(sz_ptr_t target, sz_cptr_t source, sz_size_t length); - -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length); - -typedef void (*sz_move_t)(sz_ptr_t, sz_cptr_t, sz_size_t); - -/** - * @brief Similar to `memset`, fills a string with a given value. - * - * @param target String to fill. - * @param length Number of bytes to fill. - * @param value Value to fill with. - */ -SZ_DYNAMIC void sz_fill(sz_ptr_t target, sz_size_t length, sz_u8_t value); - -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_serial(sz_ptr_t target, sz_size_t length, sz_u8_t value); - -typedef void (*sz_fill_t)(sz_ptr_t, sz_size_t, sz_u8_t); +#pragma endregion // Core API -/** - * @brief Initializes a string class instance to an empty value. - */ -SZ_PUBLIC void sz_string_init(sz_string_t *string); +#pragma region Serial Implementation -/** - * @brief Convenience function checking if the provided string is stored inside of the ::string instance itself, - * alternative being - allocated in a remote region of the heap. - */ -SZ_PUBLIC sz_bool_t sz_string_is_on_stack(sz_string_t const *string); +SZ_PUBLIC sz_u64_t sz_checksum_serial(sz_cptr_t text, sz_size_t length) { + sz_u64_t checksum = 0; + sz_u8_t const *text_u8 = (sz_u8_t const *)text; + sz_u8_t const *text_end = text_u8 + length; + for (; text_u8 != text_end; ++text_u8) checksum += *text_u8; + return checksum; +} -/** - * @brief Unpacks the opaque instance of a string class into its components. - * Recommended to use only in read-only operations. +/* + * One hardware-accelerated way of mixing hashes can be CRC, but it's only implemented for 32-bit values. + * Using a Boost-like mixer works very poorly in such case: * - * @param string String to unpack. - * @param start Pointer to the start of the string. - * @param length Number of bytes in the string, before the SZ_NULL character. - * @param space Number of bytes allocated for the string (heap or stack), including the SZ_NULL character. - * @param is_external Whether the string is allocated on the heap externally, or fits withing ::string instance. - */ -SZ_PUBLIC void sz_string_unpack(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length, sz_size_t *space, - sz_bool_t *is_external); - -/** - * @brief Unpacks only the start and length of the string. - * Recommended to use only in read-only operations. + * hash_first ^ (hash_second + 0x517cc1b727220a95 + (hash_first << 6) + (hash_first >> 2)); * - * @param string String to unpack. - * @param start Pointer to the start of the string. - * @param length Number of bytes in the string, before the SZ_NULL character. + * Let's stick to the Fibonacci hash trick using the golden ratio. + * https://probablydance.com/2018/06/16/fibonacci-hashing-the-optimization-that-the-world-forgot-or-a-better-alternative-to-integer-modulo/ */ -SZ_PUBLIC void sz_string_range(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length); +#define _sz_hash_mix(first, second) ((first * 11400714819323198485ull) ^ (second * 11400714819323198485ull)) +#define _sz_shift_low(x) (x) +#define _sz_shift_high(x) ((x + 77ull) & 0xFFull) +#define _sz_prime_mod(x) (x % SZ_U64_MAX_PRIME) -/** - * @brief Constructs a string of a given ::length with noisy contents. - * Use the returned character pointer to populate the string. - * - * @param string String to initialize. 
- * @param length Number of bytes in the string, before the SZ_NULL character. - * @param allocator Memory allocator to use for the allocation. - * @return SZ_NULL if the operation failed, pointer to the start of the string otherwise. - */ -SZ_PUBLIC sz_ptr_t sz_string_init_length(sz_string_t *string, sz_size_t length, sz_memory_allocator_t *allocator); +SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t start, sz_size_t length) { -/** - * @brief Doesn't change the contents or the length of the string, but grows the available memory capacity. - * This is beneficial, if several insertions are expected, and we want to minimize allocations. - * - * @param string String to grow. - * @param new_capacity The number of characters to reserve space for, including existing ones. - * @param allocator Memory allocator to use for the allocation. - * @return SZ_NULL if the operation failed, pointer to the new start of the string otherwise. - */ -SZ_PUBLIC sz_ptr_t sz_string_reserve(sz_string_t *string, sz_size_t new_capacity, sz_memory_allocator_t *allocator); + sz_u64_t hash_low = 0; + sz_u64_t hash_high = 0; + sz_u8_t const *text = (sz_u8_t const *)start; + sz_u8_t const *text_end = text + length; -/** - * @brief Grows the string by adding an uninitialized region of ::added_length at the given ::offset. - * Would often be used in conjunction with one or more `sz_copy` calls to populate the allocated region. - * Similar to `sz_string_reserve`, but changes the length of the ::string. - * - * @param string String to grow. - * @param offset Offset of the first byte to reserve space for. - * If provided offset is larger than the length, it will be capped. - * @param added_length The number of new characters to reserve space for. - * @param allocator Memory allocator to use for the allocation. - * @return SZ_NULL if the operation failed, pointer to the new start of the string otherwise. - */ -SZ_PUBLIC sz_ptr_t sz_string_expand(sz_string_t *string, sz_size_t offset, sz_size_t added_length, - sz_memory_allocator_t *allocator); - -/** - * @brief Removes a range from a string. Changes the length, but not the capacity. - * Performs no allocations or deallocations and can't fail. - * - * @param string String to clean. - * @param offset Offset of the first byte to remove. - * @param length Number of bytes to remove. Out-of-bound ranges will be capped. - * @return Number of bytes removed. - */ -SZ_PUBLIC sz_size_t sz_string_erase(sz_string_t *string, sz_size_t offset, sz_size_t length); - -/** - * @brief Shrinks the string to fit the current length, if it's allocated on the heap. - * It's the reverse operation of ::sz_string_reserve. - * - * @param string String to shrink. - * @param allocator Memory allocator to use for the allocation. - * @return Whether the operation was successful. The only failures can come from the allocator. - * On failure, the string will remain unchanged. - */ -SZ_PUBLIC sz_ptr_t sz_string_shrink_to_fit(sz_string_t *string, sz_memory_allocator_t *allocator); - -/** - * @brief Frees the string, if it's allocated on the heap. - * If the string is on the stack, the function clears/resets the state. 
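/* Usage sketch, not part of the patch: constructing a short string with `sz_string_init_length`
 * and reading it back with `sz_string_range`. The 128-byte arena and the "hello" payload are
 * illustrative; a 5-byte string fits the on-stack Small String Optimization space, so the fixed
 * allocator is only a fallback here and nothing has to be freed afterwards. */
#include <stringzilla/stringzilla.h>

int main(void) {
    char arena[128];
    sz_memory_allocator_t alloc;
    sz_memory_allocator_init_fixed(&alloc, arena, sizeof(arena));

    sz_string_t string;
    sz_ptr_t contents = sz_string_init_length(&string, 5, &alloc);
    if (!contents) return 1;       // SZ_NULL means the allocation failed
    sz_copy(contents, "hello", 5); // populate the reserved bytes

    sz_ptr_t start;
    sz_size_t length;
    sz_string_range(&string, &start, &length); // start == contents, length == 5
    return length == 5 ? 0 : 1;
}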
- */ -SZ_PUBLIC void sz_string_free(sz_string_t *string, sz_memory_allocator_t *allocator); - -#pragma endregion - -#pragma region Fast Substring Search API - -typedef sz_cptr_t (*sz_find_byte_t)(sz_cptr_t, sz_size_t, sz_cptr_t); -typedef sz_cptr_t (*sz_find_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); -typedef sz_cptr_t (*sz_find_set_t)(sz_cptr_t, sz_size_t, sz_charset_t const *); - -/** - * @brief Locates first matching byte in a string. Equivalent to `memchr(haystack, *needle, h_length)` in LibC. - * - * X86_64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/x86_64/memchr.S - * Aarch64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/aarch64/memchr.S - * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - single-byte substring to find. - * @return Address of the first match. - */ -SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); - -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); - -/** - * @brief Locates last matching byte in a string. Equivalent to `memrchr(haystack, *needle, h_length)` in LibC. - * - * X86_64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/x86_64/memrchr.S - * Aarch64 implementation: missing - * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - single-byte substring to find. - * @return Address of the last match. - */ -SZ_DYNAMIC sz_cptr_t sz_rfind_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); - -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); - -/** - * @brief Locates first matching substring. - * Equivalent to `memmem(haystack, h_length, needle, n_length)` in LibC. - * Similar to `strstr(haystack, needle)` in LibC, but requires known length. - * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - substring to find. - * @param n_length Number of bytes in the needle. - * @return Address of the first match. - */ -SZ_DYNAMIC sz_cptr_t sz_find(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); - -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); - -/** - * @brief Locates the last matching substring. - * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - substring to find. - * @param n_length Number of bytes in the needle. - * @return Address of the last match. - */ -SZ_DYNAMIC sz_cptr_t sz_rfind(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); - -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); - -/** - * @brief Finds the first character present from the ::set, present in ::text. - * Equivalent to `strspn(text, accepted)` and `strcspn(text, rejected)` in LibC. - * May have identical implementation and performance to ::sz_rfind_charset. - * - * Useful for parsing, when we want to skip a set of characters. Examples: - * * 6 whitespaces: " \t\n\r\v\f". - * * 16 digits forming a float number: "0123456789,.eE+-". 
- * * 5 HTML reserved characters: "\"'&<>", of which "<>" can be useful for parsing. - * * 2 JSON string special characters useful to locate the end of the string: "\"\\". - * - * @param text String to be scanned. - * @param set Set of relevant characters. - * @return Pointer to the first matching character from ::set. - */ -SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); - -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); - -/** - * @brief Finds the last character present from the ::set, present in ::text. - * Equivalent to `strspn(text, accepted)` and `strcspn(text, rejected)` in LibC. - * May have identical implementation and performance to ::sz_find_charset. - * - * Useful for parsing, when we want to skip a set of characters. Examples: - * * 6 whitespaces: " \t\n\r\v\f". - * * 16 digits forming a float number: "0123456789,.eE+-". - * * 5 HTML reserved characters: "\"'&<>", of which "<>" can be useful for parsing. - * * 2 JSON string special characters useful to locate the end of the string: "\"\\". - * - * @param text String to be scanned. - * @param set Set of relevant characters. - * @return Pointer to the last matching character from ::set. - */ -SZ_DYNAMIC sz_cptr_t sz_rfind_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); - -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); - -#pragma endregion - -#pragma region String Similarity Measures API - -/** - * @brief Computes the Hamming distance between two strings - number of not matching characters. - * Difference in length is is counted as a mismatch. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * - * @param bound Upper bound on the distance, that allows us to exit early. - * If zero is passed, the maximum possible distance will be equal to the length of the longer input. - * @return Unsigned integer for the distance, the `bound` if was exceeded. - * - * @see sz_hamming_distance_utf8 - * @see https://en.wikipedia.org/wiki/Hamming_distance - */ -SZ_DYNAMIC sz_size_t sz_hamming_distance( // - sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound); - -/** @copydoc sz_hamming_distance */ -SZ_PUBLIC sz_size_t sz_hamming_distance_serial( // - sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound); - -/** - * @brief Computes the Hamming distance between two @b UTF8 strings - number of not matching characters. - * Difference in length is is counted as a mismatch. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * - * @param bound Upper bound on the distance, that allows us to exit early. - * If zero is passed, the maximum possible distance will be equal to the length of the longer input. - * @return Unsigned integer for the distance, the `bound` if was exceeded. 
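/* Usage sketch, not part of the patch: skipping to the first whitespace byte with a `sz_charset_t`
 * and `sz_find_charset`, as suggested by the parsing examples above. The helper name and the
 * assumption that a miss yields a null pointer are illustrative, not taken from the header. */
#include <stringzilla/stringzilla.h>

static sz_cptr_t find_first_whitespace(sz_cptr_t text, sz_size_t length) {
    sz_charset_t set;
    sz_charset_init(&set); // start with an empty set: every byte is banned
    sz_charset_add(&set, ' ');
    sz_charset_add(&set, '\t');
    sz_charset_add(&set, '\n');
    sz_charset_add(&set, '\r');
    return sz_find_charset(text, length, &set); // address of the first matching byte
}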
- * - * @see sz_hamming_distance - * @see https://en.wikipedia.org/wiki/Hamming_distance - */ -SZ_DYNAMIC sz_size_t sz_hamming_distance_utf8(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, - sz_size_t bound); - -/** @copydoc sz_hamming_distance_utf8 */ -SZ_PUBLIC sz_size_t sz_hamming_distance_utf8_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, - sz_size_t bound); - -typedef sz_size_t (*sz_hamming_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t); - -/** - * @brief Computes the Levenshtein edit-distance between two strings using the Wagner-Fisher algorithm. - * Similar to the Needleman-Wunsch alignment algorithm. Often used in fuzzy string matching. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * - * @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated, - * so the memory usage is linear in relation to ::a_length and ::b_length. - * If SZ_NULL is passed, will initialize to the systems default `malloc`. - * @param bound Exclusive upper bound on the distance, that allows us to exit early. - * Pass `SZ_SIZE_MAX` or any value greater than `(max(a_length, b_length))` to ignore. - * Pass zero to check if the strings are equal. - * @return Unsigned integer for the edit distance. Zero means the strings are equal. - * Returns the `bound` if it was exceeded or `SZ_SIZE_MAX` if the memory allocation failed. - * - * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default - * @see https://en.wikipedia.org/wiki/Levenshtein_distance - */ -SZ_DYNAMIC sz_size_t sz_edit_distance(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); - -/** @copydoc sz_edit_distance */ -SZ_PUBLIC sz_size_t sz_edit_distance_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); - -/** - * @brief Computes the Levenshtein edit-distance between two @b UTF8 strings. - * Unlike `sz_edit_distance`, reports the distance in Unicode codepoints, and not in bytes. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * - * @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated, - * so the memory usage is linear in relation to ::a_length and ::b_length. - * If SZ_NULL is passed, will initialize to the systems default `malloc`. - * @param bound Upper bound on the distance, that allows us to exit early. - * If zero is passed, the maximum possible distance will be equal to the length of the longer input. - * @return Unsigned integer for edit distance, the `bound` if was exceeded or `SZ_SIZE_MAX` - * if the memory allocation failed. 
- * - * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default, sz_edit_distance - * @see https://en.wikipedia.org/wiki/Levenshtein_distance - */ -SZ_DYNAMIC sz_size_t sz_edit_distance_utf8(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); - -typedef sz_size_t (*sz_edit_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t, sz_memory_allocator_t *); - -/** @copydoc sz_edit_distance_utf8 */ -SZ_PUBLIC sz_size_t sz_edit_distance_utf8_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); - -/** - * @brief Computes Needleman–Wunsch alignment score for two string. Often used in bioinformatics and cheminformatics. - * Similar to the Levenshtein edit-distance, parameterized for gap and substitution penalties. - * - * Not commutative in the general case, as the order of the strings matters, as `sz_alignment_score(a, b)` may - * not be equal to `sz_alignment_score(b, a)`. Becomes @b commutative, if the substitution costs are symmetric. - * Equivalent to the negative Levenshtein distance, if: `gap == -1` and `subs[i][j] == (i == j ? 0: -1)`. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * @param gap Penalty cost for gaps - insertions and removals. - * @param subs Substitution costs matrix with 256 x 256 values for all pairs of characters. - * - * @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated, - * so the memory usage is linear in relation to ::a_length and ::b_length. - * If SZ_NULL is passed, will initialize to the systems default `malloc`. - * @return Signed similarity score. Can be negative, depending on the substitution costs. - * If the memory allocation fails, the function returns `SZ_SSIZE_MAX`. - * - * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default - * @see https://en.wikipedia.org/wiki/Needleman%E2%80%93Wunsch_algorithm - */ -SZ_DYNAMIC sz_ssize_t sz_alignment_score(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, // - sz_memory_allocator_t *alloc); - -/** @copydoc sz_alignment_score */ -SZ_PUBLIC sz_ssize_t sz_alignment_score_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, // - sz_memory_allocator_t *alloc); - -typedef sz_ssize_t (*sz_alignment_score_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_error_cost_t const *, - sz_error_cost_t, sz_memory_allocator_t *); - -typedef void (*sz_hash_callback_t)(sz_cptr_t, sz_size_t, sz_u64_t, void *user); - -/** - * @brief Computes the Karp-Rabin rolling hashes of a string supplying them to the provided `callback`. - * Can be used for similarity scores, search, ranking, etc. - * - * Rabin-Karp-like rolling hashes can have very high-level of collisions and depend - * on the choice of bases and the prime number. That's why, often two hashes from the same - * family are used with different bases. - * - * 1. Kernighan and Ritchie's function uses 31, a prime close to the size of English alphabet. - * 2. To be friendlier to byte-arrays and UTF8, we use 257 for the second function. - * - * Choosing the right ::window_length is task- and domain-dependant. 
For example, most English words are - * between 3 and 7 characters long, so a window of 4 bytes would be a good choice. For DNA sequences, - * the ::window_length might be a multiple of 3, as the codons are 3 (nucleotides) bytes long. - * With such minimalistic alphabets of just four characters (AGCT) longer windows might be needed. - * For protein sequences the alphabet is 20 characters long, so the window can be shorter, than for DNAs. - * - * @param text String to hash. - * @param length Number of bytes in the string. - * @param window_length Length of the rolling window in bytes. - * @param window_step Step of reported hashes. @b Must be power of two. Should be smaller than `window_length`. - * @param callback Function receiving the start & length of a substring, the hash, and the `callback_handle`. - * @param callback_handle Optional user-provided pointer to be passed to the `callback`. - * @see sz_hashes_fingerprint, sz_hashes_intersection - */ -SZ_DYNAMIC void sz_hashes(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // - sz_hash_callback_t callback, void *callback_handle); - -/** @copydoc sz_hashes */ -SZ_PUBLIC void sz_hashes_serial(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // - sz_hash_callback_t callback, void *callback_handle); - -typedef void (*sz_hashes_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_size_t, sz_hash_callback_t, void *); - -/** - * @brief Computes the Karp-Rabin rolling hashes of a string outputting a binary fingerprint. - * Such fingerprints can be compared with Hamming or Jaccard (Tanimoto) distance for similarity. - * - * The algorithm doesn't clear the fingerprint buffer on start, so it can be invoked multiple times - * to produce a fingerprint of a longer string, by passing the previous fingerprint as the ::fingerprint. - * It can also be reused to produce multi-resolution fingerprints by changing the ::window_length - * and calling the same function multiple times for the same input ::text. - * - * Processes large strings in parts to maximize the cache utilization, using a small on-stack buffer, - * avoiding cache-coherency penalties of remote on-heap buffers. - * - * @param text String to hash. - * @param length Number of bytes in the string. - * @param fingerprint Output fingerprint buffer. - * @param fingerprint_bytes Number of bytes in the fingerprint buffer. - * @param window_length Length of the rolling window in bytes. - * @see sz_hashes, sz_hashes_intersection - */ -SZ_PUBLIC void sz_hashes_fingerprint( // - sz_cptr_t text, sz_size_t length, sz_size_t window_length, // - sz_ptr_t fingerprint, sz_size_t fingerprint_bytes); - -typedef void (*sz_hashes_fingerprint_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_ptr_t, sz_size_t); - -/** - * @brief Given a hash-fingerprint of a textual document, computes the number of intersecting hashes - * of the incoming document. Can be used for document scoring and search. - * - * Processes large strings in parts to maximize the cache utilization, using a small on-stack buffer, - * avoiding cache-coherency penalties of remote on-heap buffers. - * - * @param text Input document. - * @param length Number of bytes in the input document. - * @param fingerprint Reference document fingerprint. - * @param fingerprint_bytes Number of bytes in the reference documents fingerprint. - * @param window_length Length of the rolling window in bytes. 
- * @see sz_hashes, sz_hashes_fingerprint - */ -SZ_PUBLIC sz_size_t sz_hashes_intersection( // - sz_cptr_t text, sz_size_t length, sz_size_t window_length, // - sz_cptr_t fingerprint, sz_size_t fingerprint_bytes); - -typedef sz_size_t (*sz_hashes_intersection_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_cptr_t, sz_size_t); - -#pragma endregion - -#pragma region Convenience API - -/** - * @brief Finds the first character in the haystack, that is present in the needle. - * Convenience function, reused across different language bindings. - * @see sz_find_charset - */ -SZ_DYNAMIC sz_cptr_t sz_find_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); - -/** - * @brief Finds the first character in the haystack, that is @b not present in the needle. - * Convenience function, reused across different language bindings. - * @see sz_find_charset - */ -SZ_DYNAMIC sz_cptr_t sz_find_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); - -/** - * @brief Finds the last character in the haystack, that is present in the needle. - * Convenience function, reused across different language bindings. - * @see sz_find_charset - */ -SZ_DYNAMIC sz_cptr_t sz_rfind_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); - -/** - * @brief Finds the last character in the haystack, that is @b not present in the needle. - * Convenience function, reused across different language bindings. - * @see sz_find_charset - */ -SZ_DYNAMIC sz_cptr_t sz_rfind_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); - -#pragma endregion - -#pragma region String Sequences API - -struct sz_sequence_t; - -typedef sz_cptr_t (*sz_sequence_member_start_t)(struct sz_sequence_t const *, sz_size_t); -typedef sz_size_t (*sz_sequence_member_length_t)(struct sz_sequence_t const *, sz_size_t); -typedef sz_bool_t (*sz_sequence_predicate_t)(struct sz_sequence_t const *, sz_size_t); -typedef sz_bool_t (*sz_sequence_comparator_t)(struct sz_sequence_t const *, sz_size_t, sz_size_t); -typedef sz_bool_t (*sz_string_is_less_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); - -typedef struct sz_sequence_t { - sz_sorted_idx_t *order; - sz_size_t count; - sz_sequence_member_start_t get_start; - sz_sequence_member_length_t get_length; - void const *handle; -} sz_sequence_t; - -/** - * @brief Initiates the sequence structure from a tape layout, used by Apache Arrow. - * Expects ::offsets to contains `count + 1` entries, the last pointing at the end - * of the last string, indicating the total length of the ::tape. - */ -SZ_PUBLIC void sz_sequence_from_u32tape(sz_cptr_t *start, sz_u32_t const *offsets, sz_size_t count, - sz_sequence_t *sequence); - -/** - * @brief Initiates the sequence structure from a tape layout, used by Apache Arrow. - * Expects ::offsets to contains `count + 1` entries, the last pointing at the end - * of the last string, indicating the total length of the ::tape. - */ -SZ_PUBLIC void sz_sequence_from_u64tape(sz_cptr_t *start, sz_u64_t const *offsets, sz_size_t count, - sz_sequence_t *sequence); - -/** - * @brief Similar to `std::partition`, given a predicate splits the sequence into two parts. - * The algorithm is unstable, meaning that elements may change relative order, as long - * as they are in the right partition. This is the simpler algorithm for partitioning. 
- */
-SZ_PUBLIC sz_size_t sz_partition(sz_sequence_t *sequence, sz_sequence_predicate_t predicate);
-
-/**
- * @brief Inplace `std::set_union` for two consecutive chunks forming the same continuous `sequence`.
- *
- * @param partition The number of elements in the first sub-sequence in `sequence`.
- * @param less Comparison function, to determine the lexicographic ordering.
- */
-SZ_PUBLIC void sz_merge(sz_sequence_t *sequence, sz_size_t partition, sz_sequence_comparator_t less);
-
-/**
- * @brief Sorting algorithm, combining Radix Sort for the first 32 bits of every word
- * and a follow-up by a more conventional sorting procedure on equally prefixed parts.
- */
-SZ_PUBLIC void sz_sort(sz_sequence_t *sequence);
-
-/**
- * @brief Partial sorting algorithm, combining Radix Sort for the first 32 bits of every word
- * and a follow-up by a more conventional sorting procedure on equally prefixed parts.
- */
-SZ_PUBLIC void sz_sort_partial(sz_sequence_t *sequence, sz_size_t n);
-
-/**
- * @brief Intro-Sort algorithm that supports custom comparators.
- */
-SZ_PUBLIC void sz_sort_intro(sz_sequence_t *sequence, sz_sequence_comparator_t less);
-
-#pragma endregion
-
-/*
- * Hardware feature detection.
- * All of those can be controlled by the user.
- */
-#ifndef SZ_USE_X86_AVX512
-#ifdef __AVX512BW__
-#define SZ_USE_X86_AVX512 1
-#else
-#define SZ_USE_X86_AVX512 0
-#endif
-#endif
-
-#ifndef SZ_USE_X86_AVX2
-#ifdef __AVX2__
-#define SZ_USE_X86_AVX2 1
-#else
-#define SZ_USE_X86_AVX2 0
-#endif
-#endif
-
-#ifndef SZ_USE_ARM_NEON
-#ifdef __ARM_NEON
-#define SZ_USE_ARM_NEON 1
-#else
-#define SZ_USE_ARM_NEON 0
-#endif
-#endif
-
-#ifndef SZ_USE_ARM_SVE
-#ifdef __ARM_FEATURE_SVE
-#define SZ_USE_ARM_SVE 1
-#else
-#define SZ_USE_ARM_SVE 0
-#endif
-#endif
-
-/*
- * Include hardware-specific headers.
- */
-#if SZ_USE_X86_AVX512 || SZ_USE_X86_AVX2
-#include <immintrin.h>
-#endif // SZ_USE_X86...
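/* Usage sketch, not part of the patch: sorting three strings with `sz_sort` through the
 * `sz_sequence_t` accessors declared above. The demo array, the accessor names, and the
 * assumption that `order` starts as the identity permutation are illustrative. */
#include <string.h>
#include <stringzilla/stringzilla.h>

static char const *demo_strings[] = {"banana", "apple", "cherry"};

static sz_cptr_t demo_get_start(struct sz_sequence_t const *sequence, sz_size_t i) {
    (void)sequence;
    return demo_strings[i];
}
static sz_size_t demo_get_length(struct sz_sequence_t const *sequence, sz_size_t i) {
    (void)sequence;
    return strlen(demo_strings[i]);
}

int main(void) {
    sz_sorted_idx_t order[3] = {0, 1, 2}; // identity permutation
    sz_sequence_t sequence;
    sequence.order = order;
    sequence.count = 3;
    sequence.get_start = demo_get_start;
    sequence.get_length = demo_get_length;
    sequence.handle = demo_strings;
    sz_sort(&sequence);
    // `order` should now read {1, 0, 2}: "apple", "banana", "cherry".
    return 0;
}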
-#if SZ_USE_ARM_NEON
-#if !defined(_MSC_VER)
-#include <arm_acle.h>
-#endif
-#include <arm_neon.h>
-#endif // SZ_USE_ARM_NEON
-#if SZ_USE_ARM_SVE
-#if !defined(_MSC_VER)
-#include <arm_sve.h>
-#endif
-#endif // SZ_USE_ARM_SVE
-
-#pragma region Hardware Specific API
-
-#if SZ_USE_X86_AVX512
-
-/** @copydoc sz_equal */
-SZ_PUBLIC sz_bool_t sz_equal_avx512(sz_cptr_t a, sz_cptr_t b, sz_size_t length);
-/** @copydoc sz_order */
-SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length);
-/** @copydoc sz_copy */
-SZ_PUBLIC void sz_copy_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
-/** @copydoc sz_move */
-SZ_PUBLIC void sz_move_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
-/** @copydoc sz_fill */
-SZ_PUBLIC void sz_fill_avx512(sz_ptr_t target, sz_size_t length, sz_u8_t value);
-/** @copydoc sz_look_up_transform */
-SZ_PUBLIC void sz_look_up_transform_avx512(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target);
-/** @copydoc sz_find_byte */
-SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);
-/** @copydoc sz_rfind_byte */
-SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);
-/** @copydoc sz_find */
-SZ_PUBLIC sz_cptr_t sz_find_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length);
-/** @copydoc sz_rfind */
-SZ_PUBLIC sz_cptr_t sz_rfind_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length);
-/** @copydoc sz_find_charset */
-SZ_PUBLIC sz_cptr_t sz_find_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *set);
-/** @copydoc sz_rfind_charset */
-SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *set);
-/** @copydoc sz_edit_distance */
-SZ_PUBLIC sz_size_t sz_edit_distance_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, //
- sz_size_t bound, sz_memory_allocator_t *alloc);
-/** @copydoc sz_alignment_score */
-SZ_PUBLIC sz_ssize_t sz_alignment_score_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, //
- sz_error_cost_t const *subs, sz_error_cost_t gap, //
- sz_memory_allocator_t *alloc);
-/** @copydoc sz_hashes */
-SZ_PUBLIC void sz_hashes_avx512(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t step, //
- sz_hash_callback_t callback, void *callback_handle);
-#endif
-
-#if SZ_USE_X86_AVX2
-/** @copydoc sz_equal */
-SZ_PUBLIC sz_bool_t sz_equal_avx2(sz_cptr_t a, sz_cptr_t b, sz_size_t length);
-/** @copydoc sz_order */
-SZ_PUBLIC sz_ordering_t sz_order_avx2(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length);
-/** @copydoc sz_copy */
-SZ_PUBLIC void sz_copy_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
-/** @copydoc sz_move */
-SZ_PUBLIC void sz_move_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
-/** @copydoc sz_fill */
-SZ_PUBLIC void sz_fill_avx2(sz_ptr_t target, sz_size_t length, sz_u8_t value);
-/** @copydoc sz_look_up_transform */
-SZ_PUBLIC void sz_look_up_transform_avx2(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target);
-/** @copydoc sz_find_byte */
-SZ_PUBLIC sz_cptr_t sz_find_byte_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);
-/** @copydoc sz_rfind_byte */
-SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);
-/** @copydoc sz_find */
-SZ_PUBLIC sz_cptr_t sz_find_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t 
n_length); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_hashes */ -SZ_PUBLIC void sz_hashes_avx2(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle); -#endif - -#if SZ_USE_ARM_NEON -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length); -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_neon(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_neon(sz_ptr_t target, sz_size_t length, sz_u8_t value); -/** @copydoc sz_look_up_transform */ -SZ_PUBLIC void sz_look_up_transform_neon(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target); -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_neon(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -#endif - -#if SZ_USE_ARM_SVE -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_sve(sz_cptr_t a, sz_cptr_t b, sz_size_t length); -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_sve(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_sve(sz_ptr_t target, sz_size_t length, sz_u8_t value); -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_sve(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_sve(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -#endif - -#pragma endregion - -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wconversion" - -/* - ********************************************************************************************************************** - 
********************************************************************************************************************** - ********************************************************************************************************************** - * - * This is where we the actual implementation begins. - * The rest of the file is hidden from the public API. - * - ********************************************************************************************************************** - ********************************************************************************************************************** - ********************************************************************************************************************** - */ - -#pragma region Compiler Extensions and Helper Functions - -#pragma GCC visibility push(hidden) - -/** - * @brief Helper-macro to mark potentially unused variables. - */ -#define sz_unused(x) ((void)(x)) - -/** - * @brief Helper-macro casting a variable to another type of the same size. - */ -#define sz_bitcast(type, value) (*((type *)&(value))) - -/** - * @brief Defines `SZ_NULL`, analogous to `NULL`. - * The default often comes from locale.h, stddef.h, - * stdio.h, stdlib.h, string.h, time.h, or wchar.h. - */ -#ifdef __GNUG__ -#define SZ_NULL __null -#define SZ_NULL_CHAR __null -#else -#define SZ_NULL ((void *)0) -#define SZ_NULL_CHAR ((char *)0) -#endif - -/** - * @brief Cache-line width, that will affect the execution of some algorithms, - * like equality checks and relative order computing. - */ -#define SZ_CACHE_LINE_WIDTH (64) // bytes - -/** - * @brief Similar to `assert`, the `sz_assert` is used in the SZ_DEBUG mode - * to check the invariants of the library. It's a no-op in the SZ_RELEASE mode. - * @note If you want to catch it, put a breakpoint at @b `__GI_exit` - */ -#if SZ_DEBUG && defined(SZ_AVOID_LIBC) && !SZ_AVOID_LIBC && !defined(SZ_PIC) -#include // `fprintf` -#include // `EXIT_FAILURE` -SZ_PUBLIC void _sz_assert_failure(char const *condition, char const *file, int line) { - fprintf(stderr, "Assertion failed: %s, in file %s, line %d\n", condition, file, line); - exit(EXIT_FAILURE); -} -#define sz_assert(condition) \ - do { \ - if (!(condition)) { _sz_assert_failure(#condition, __FILE__, __LINE__); } \ - } while (0) -#else -#define sz_assert(condition) ((void)(condition)) -#endif - -/* Intrinsics aliases for MSVC, GCC, Clang, and Clang-Cl. - * The following section of compiler intrinsics comes in 2 flavors. - */ -#if defined(_MSC_VER) && !defined(__clang__) // On Clang-CL -#include - -// Sadly, when building Win32 images, we can't use the `_tzcnt_u64`, `_lzcnt_u64`, -// `_BitScanForward64`, or `_BitScanReverse64` intrinsics. For now it's a simple `for`-loop. -// TODO: In the future we can switch to a more efficient De Bruijn's algorithm. -// https://www.chessprogramming.org/BitScan -// https://www.chessprogramming.org/De_Bruijn_Sequence -// https://gist.github.com/resilar/e722d4600dbec9752771ab4c9d47044f -// -// Use the serial version on 32-bit x86 and on Arm. 
-#if (defined(_WIN32) && !defined(_WIN64)) || defined(_M_ARM) || defined(_M_ARM64) -SZ_INTERNAL int sz_u64_ctz(sz_u64_t x) { - sz_assert(x != 0); - int n = 0; - while ((x & 1) == 0) { n++, x >>= 1; } - return n; -} -SZ_INTERNAL int sz_u64_clz(sz_u64_t x) { - sz_assert(x != 0); - int n = 0; - while ((x & 0x8000000000000000ull) == 0) { n++, x <<= 1; } - return n; -} -SZ_INTERNAL int sz_u64_popcount(sz_u64_t x) { - x = x - ((x >> 1) & 0x5555555555555555ull); - x = (x & 0x3333333333333333ull) + ((x >> 2) & 0x3333333333333333ull); - return (((x + (x >> 4)) & 0x0F0F0F0F0F0F0F0Full) * 0x0101010101010101ull) >> 56; -} -SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) { - sz_assert(x != 0); - int n = 0; - while ((x & 1) == 0) { n++, x >>= 1; } - return n; -} -SZ_INTERNAL int sz_u32_clz(sz_u32_t x) { - sz_assert(x != 0); - int n = 0; - while ((x & 0x80000000u) == 0) { n++, x <<= 1; } - return n; -} -SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) { - x = x - ((x >> 1) & 0x55555555); - x = (x & 0x33333333) + ((x >> 2) & 0x33333333); - return (((x + (x >> 4)) & 0x0F0F0F0F) * 0x01010101) >> 24; -} -#else -SZ_INTERNAL int sz_u64_ctz(sz_u64_t x) { return (int)_tzcnt_u64(x); } -SZ_INTERNAL int sz_u64_clz(sz_u64_t x) { return (int)_lzcnt_u64(x); } -SZ_INTERNAL int sz_u64_popcount(sz_u64_t x) { return (int)__popcnt64(x); } -SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) { return (int)_tzcnt_u32(x); } -SZ_INTERNAL int sz_u32_clz(sz_u32_t x) { return (int)_lzcnt_u32(x); } -SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) { return (int)__popcnt(x); } -#endif -// Force the byteswap functions to be intrinsics, because when /Oi- is given, these will turn into CRT function calls, -// which breaks when `SZ_AVOID_LIBC` is given -#pragma intrinsic(_byteswap_uint64) -SZ_INTERNAL sz_u64_t sz_u64_bytes_reverse(sz_u64_t val) { return _byteswap_uint64(val); } -#pragma intrinsic(_byteswap_ulong) -SZ_INTERNAL sz_u32_t sz_u32_bytes_reverse(sz_u32_t val) { return _byteswap_ulong(val); } -#else -SZ_INTERNAL int sz_u64_popcount(sz_u64_t x) { return __builtin_popcountll(x); } -SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) { return __builtin_popcount(x); } -SZ_INTERNAL int sz_u64_ctz(sz_u64_t x) { return __builtin_ctzll(x); } -SZ_INTERNAL int sz_u64_clz(sz_u64_t x) { return __builtin_clzll(x); } -SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) { return __builtin_ctz(x); } // ! Undefined if `x == 0` -SZ_INTERNAL int sz_u32_clz(sz_u32_t x) { return __builtin_clz(x); } // ! Undefined if `x == 0` -SZ_INTERNAL sz_u64_t sz_u64_bytes_reverse(sz_u64_t val) { return __builtin_bswap64(val); } -SZ_INTERNAL sz_u32_t sz_u32_bytes_reverse(sz_u32_t val) { return __builtin_bswap32(val); } -#endif - -SZ_INTERNAL sz_u64_t sz_u64_rotl(sz_u64_t x, sz_u64_t r) { return (x << r) | (x >> (64 - r)); } - -/** - * @brief Select bits from either ::a or ::b depending on the value of ::mask bits. - * - * Similar to `_mm_blend_epi16` intrinsic on x86. - * Described in the "Bit Twiddling Hacks" by Sean Eron Anderson. - * https://graphics.stanford.edu/~seander/bithacks.html#ConditionalSetOrClearBitsWithoutBranching - */ -SZ_INTERNAL sz_u64_t sz_u64_blend(sz_u64_t a, sz_u64_t b, sz_u64_t mask) { return a ^ ((a ^ b) & mask); } - -/* - * Efficiently computing the minimum and maximum of two or three values can be tricky. - * The simple branching baseline would be: - * - * x < y ? x : y // can replace with 1 conditional move - * - * Branchless approach is well known for signed integers, but it doesn't apply to unsigned ones. 
- * https://stackoverflow.com/questions/514435/templatized-branchless-int-max-min-function
- * https://graphics.stanford.edu/~seander/bithacks.html#IntegerMinOrMax
- * Using only bit-shifts for signed integers it would be:
- *
- *      y + ((x - y) & (x - y) >> 31) // 4 unique operations
- *
- * Alternatively, for any integers using multiplication:
- *
- *      (x > y) * y + (x <= y) * x // 5 operations
- *
- * Alternatively, to avoid multiplication:
- *
- *      x & ~((x < y) - 1) + y & ((x < y) - 1) // 6 unique operations
- */
-#define sz_min_of_two(x, y) (x < y ? x : y)
-#define sz_max_of_two(x, y) (x < y ? y : x)
-#define sz_min_of_three(x, y, z) sz_min_of_two(x, sz_min_of_two(y, z))
-#define sz_max_of_three(x, y, z) sz_max_of_two(x, sz_max_of_two(y, z))
-
-/** @brief Branchless minimum function for two signed 32-bit integers. */
-SZ_INTERNAL sz_i32_t sz_i32_min_of_two(sz_i32_t x, sz_i32_t y) { return y + ((x - y) & (x - y) >> 31); }
-
-/** @brief Branchless maximum function for two signed 32-bit integers. */
-SZ_INTERNAL sz_i32_t sz_i32_max_of_two(sz_i32_t x, sz_i32_t y) { return x - ((x - y) & (x - y) >> 31); }
-
-/**
- * @brief Clamps signed offsets in a string to a valid range. Used for Pythonic-style slicing.
- */
-SZ_INTERNAL void sz_ssize_clamp_interval(sz_size_t length, sz_ssize_t start, sz_ssize_t end,
-                                         sz_size_t *normalized_offset, sz_size_t *normalized_length) {
-    // TODO: Remove branches.
-    // Normalize negative indices
-    if (start < 0) start += length;
-    if (end < 0) end += length;
-
-    // Clamp indices to a valid range
-    if (start < 0) start = 0;
-    if (end < 0) end = 0;
-    if (start > (sz_ssize_t)length) start = length;
-    if (end > (sz_ssize_t)length) end = length;
-
-    // Ensure start <= end
-    if (start > end) start = end;
-
-    *normalized_offset = start;
-    *normalized_length = end - start;
-}
-
-/**
- * @brief Compute the logarithm base 2 of a positive integer, rounding down.
- */
-SZ_INTERNAL sz_size_t sz_size_log2i_nonzero(sz_size_t x) {
-    sz_assert(x > 0 && "Non-positive numbers have no defined logarithm");
-    sz_size_t leading_zeros = sz_u64_clz(x);
-    return 63 - leading_zeros;
-}
-
-/**
- * @brief Compute the smallest power of two greater than or equal to ::x.
- */
-SZ_INTERNAL sz_size_t sz_size_bit_ceil(sz_size_t x) {
-    // Unlike the commonly used trick with `clz` intrinsics, this approach is valid across the whole range of `x`.
-    // https://stackoverflow.com/a/10143264
-    x--;
-    x |= x >> 1;
-    x |= x >> 2;
-    x |= x >> 4;
-    x |= x >> 8;
-    x |= x >> 16;
-#if SZ_DETECT_64_BIT
-    x |= x >> 32;
-#endif
-    x++;
-    return x;
-}
-
-/**
- * @brief Transposes an 8x8 bit matrix packed in a `sz_u64_t`.
- *
- *        There is a well known SWAR sequence for that known to chess programmers,
- *        willing to flip a bit-matrix of pieces along the main A1-H8 diagonal.
- *        https://www.chessprogramming.org/Flipping_Mirroring_and_Rotating
- *        https://lukas-prokop.at/articles/2021-07-23-transpose
- */
-SZ_INTERNAL sz_u64_t sz_u64_transpose(sz_u64_t x) {
-    sz_u64_t t;
-    t = x ^ (x << 36);
-    x ^= 0xf0f0f0f00f0f0f0full & (t ^ (x >> 36));
-    t = 0xcccc0000cccc0000ull & (x ^ (x << 18));
-    x ^= t ^ (t >> 18);
-    t = 0xaa00aa00aa00aa00ull & (x ^ (x << 9));
-    x ^= t ^ (t >> 9);
-    return x;
-}
-
-/**
- * @brief Helper, that swaps two 64-bit integers representing the order of elements in the sequence.
- */
-SZ_INTERNAL void sz_u64_swap(sz_u64_t *a, sz_u64_t *b) {
-    sz_u64_t t = *a;
-    *a = *b;
-    *b = t;
-}
-
-/**
- * @brief Helper, that swaps two pointers representing the order of elements in the sequence.
- */ -SZ_INTERNAL void sz_pointer_swap(void **a, void **b) { - void *t = *a; - *a = *b; - *b = t; -} - -/** - * @brief Helper structure to simplify work with 16-bit words. - * @see sz_u16_load - */ -typedef union sz_u16_vec_t { - sz_u16_t u16; - sz_u8_t u8s[2]; -} sz_u16_vec_t; - -/** - * @brief Load a 16-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms. - */ -SZ_INTERNAL sz_u16_vec_t sz_u16_load(sz_cptr_t ptr) { -#if !SZ_USE_MISALIGNED_LOADS - sz_u16_vec_t result; - result.u8s[0] = ptr[0]; - result.u8s[1] = ptr[1]; - return result; -#elif defined(_MSC_VER) && !defined(__clang__) -#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. - return *((sz_u16_vec_t *)ptr); -#else - return *((__unaligned sz_u16_vec_t *)ptr); -#endif -#else - __attribute__((aligned(1))) sz_u16_vec_t const *result = (sz_u16_vec_t const *)ptr; - return *result; -#endif -} - -/** - * @brief Helper structure to simplify work with 32-bit words. - * @see sz_u32_load - */ -typedef union sz_u32_vec_t { - sz_u32_t u32; - sz_u16_t u16s[2]; - sz_u8_t u8s[4]; -} sz_u32_vec_t; - -/** - * @brief Load a 32-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms. - */ -SZ_INTERNAL sz_u32_vec_t sz_u32_load(sz_cptr_t ptr) { -#if !SZ_USE_MISALIGNED_LOADS - sz_u32_vec_t result; - result.u8s[0] = ptr[0]; - result.u8s[1] = ptr[1]; - result.u8s[2] = ptr[2]; - result.u8s[3] = ptr[3]; - return result; -#elif defined(_MSC_VER) && !defined(__clang__) -#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. - return *((sz_u32_vec_t *)ptr); -#else - return *((__unaligned sz_u32_vec_t *)ptr); -#endif -#else - __attribute__((aligned(1))) sz_u32_vec_t const *result = (sz_u32_vec_t const *)ptr; - return *result; -#endif -} - -/** - * @brief Helper structure to simplify work with 64-bit words. - * @see sz_u64_load - */ -typedef union sz_u64_vec_t { - sz_u64_t u64; - sz_u32_t u32s[2]; - sz_u16_t u16s[4]; - sz_u8_t u8s[8]; -} sz_u64_vec_t; - -/** - * @brief Load a 64-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms. - */ -SZ_INTERNAL sz_u64_vec_t sz_u64_load(sz_cptr_t ptr) { -#if !SZ_USE_MISALIGNED_LOADS - sz_u64_vec_t result; - result.u8s[0] = ptr[0]; - result.u8s[1] = ptr[1]; - result.u8s[2] = ptr[2]; - result.u8s[3] = ptr[3]; - result.u8s[4] = ptr[4]; - result.u8s[5] = ptr[5]; - result.u8s[6] = ptr[6]; - result.u8s[7] = ptr[7]; - return result; -#elif defined(_MSC_VER) && !defined(__clang__) -#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. - return *((sz_u64_vec_t *)ptr); -#else - return *((__unaligned sz_u64_vec_t *)ptr); -#endif -#else - __attribute__((aligned(1))) sz_u64_vec_t const *result = (sz_u64_vec_t const *)ptr; - return *result; -#endif -} - -/** @brief Helper function, using the supplied fixed-capacity buffer to allocate memory. */ -SZ_INTERNAL sz_ptr_t _sz_memory_allocate_fixed(sz_size_t length, void *handle) { - sz_size_t capacity; - sz_copy((sz_ptr_t)&capacity, (sz_cptr_t)handle, sizeof(sz_size_t)); - sz_size_t consumed_capacity = sizeof(sz_size_t); - if (consumed_capacity + length > capacity) return SZ_NULL_CHAR; - return (sz_ptr_t)handle + consumed_capacity; -} - -/** @brief Helper "no-op" function, simulating memory deallocation when we use a "static" memory buffer. 
*/ -SZ_INTERNAL void _sz_memory_free_fixed(sz_ptr_t start, sz_size_t length, void *handle) { - sz_unused(start && length && handle); -} - -/** @brief An internal callback used to set a bit in a power-of-two length binary fingerprint of a string. */ -SZ_INTERNAL void _sz_hashes_fingerprint_pow2_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, void *handle) { - sz_string_view_t *fingerprint_buffer = (sz_string_view_t *)handle; - sz_u8_t *fingerprint_u8s = (sz_u8_t *)fingerprint_buffer->start; - sz_size_t fingerprint_bytes = fingerprint_buffer->length; - fingerprint_u8s[(hash / 8) & (fingerprint_bytes - 1)] |= (1 << (hash & 7)); - sz_unused(start && length); -} - -/** @brief An internal callback used to set a bit in a @b non power-of-two length binary fingerprint of a string. */ -SZ_INTERNAL void _sz_hashes_fingerprint_non_pow2_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, - void *handle) { - sz_string_view_t *fingerprint_buffer = (sz_string_view_t *)handle; - sz_u8_t *fingerprint_u8s = (sz_u8_t *)fingerprint_buffer->start; - sz_size_t fingerprint_bytes = fingerprint_buffer->length; - fingerprint_u8s[(hash / 8) % fingerprint_bytes] |= (1 << (hash & 7)); - sz_unused(start && length); -} - -/** @brief An internal callback, used to mix all the running hashes into one pointer-size value. */ -SZ_INTERNAL void _sz_hashes_fingerprint_scalar_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, - void *scalar_handle) { - sz_unused(start && length && hash && scalar_handle); - sz_size_t *scalar_ptr = (sz_size_t *)scalar_handle; - *scalar_ptr ^= hash; -} - -/** - * @brief Chooses the offsets of the most interesting characters in a search needle. - * - * Search throughput can significantly deteriorate if we are matching the wrong characters. - * Say the needle is "aXaYa", and we are comparing the first, second, and last character. - * If we use SIMD and compare many offsets at a time, comparing against "a" in every register is a waste. - * - * Similarly, dealing with UTF8 inputs, we know that the lower bits of each character code carry more information. - * Cyrillic alphabet, for example, falls into [0x0410, 0x042F] code range for uppercase [А, Я], and - * into [0x0430, 0x044F] for lowercase [а, я]. Scanning through a text written in Russian, half of the - * bytes will carry absolutely no value and will be equal to 0x04. - */ -SZ_INTERNAL void _sz_locate_needle_anomalies(sz_cptr_t start, sz_size_t length, // - sz_size_t *first, sz_size_t *second, sz_size_t *third) { - *first = 0; - *second = length / 2; - *third = length - 1; - - // - int has_duplicates = // - start[*first] == start[*second] || // - start[*first] == start[*third] || // - start[*second] == start[*third]; - - // Loop through letters to find non-colliding variants. - if (length > 3 && has_duplicates) { - // Pivot the middle point right, until we find a character different from the first one. - for (; start[*second] == start[*first] && *second + 1 < *third; ++(*second)) {} - // Pivot the third (last) point left, until we find a different character. - for (; (start[*third] == start[*second] || start[*third] == start[*first]) && *third > (*second + 1); - --(*third)) {} - } - - // TODO: Investigate alternative strategies for long needles. - // On very long needles we have the luxury to choose! 
- // Often dealing with UTF8, we will likely benefit from shifting the first and second characters - // further to the right, to achieve not only uniqueness within the needle, but also avoid common - // rune prefixes of 2-, 3-, and 4-byte codes. - if (length > 8) { - // Pivot the first and second points right, until we find a character, that: - // > is different from others. - // > doesn't start with 0b'110x'xxxx - only 5 bits of relevant info. - // > doesn't start with 0b'1110'xxxx - only 4 bits of relevant info. - // > doesn't start with 0b'1111'0xxx - only 3 bits of relevant info. - // - // So we are practically searching for byte values that start with 0b0xxx'xxxx or 0b'10xx'xxxx. - // Meaning they fall in the range [0, 127] and [128, 191], in other words any unsigned int up to 191. - sz_u8_t const *start_u8 = (sz_u8_t const *)start; - sz_size_t vibrant_first = *first, vibrant_second = *second, vibrant_third = *third; - - // Let's begin with the seccond character, as the termination criteria there is more obvious - // and we may end up with more variants to check for the first candidate. - for (; (start_u8[vibrant_second] > 191 || start_u8[vibrant_second] == start_u8[vibrant_third]) && - (vibrant_second + 1 < vibrant_third); - ++vibrant_second) {} - - // Now check if we've indeed found a good candidate or should revert the `vibrant_second` to `second`. - if (start_u8[vibrant_second] < 191) { *second = vibrant_second; } - else { vibrant_second = *second; } - - // Now check the first character. - for (; (start_u8[vibrant_first] > 191 || start_u8[vibrant_first] == start_u8[vibrant_second] || - start_u8[vibrant_first] == start_u8[vibrant_third]) && - (vibrant_first + 1 < vibrant_second); - ++vibrant_first) {} - - // Now check if we've indeed found a good candidate or should revert the `vibrant_first` to `first`. - // We don't need to shift the third one when dealing with texts as the last byte of the text is - // also the last byte of a rune and contains the most information. - if (start_u8[vibrant_first] < 191) { *first = vibrant_first; } - } -} - -#pragma GCC visibility pop -#pragma endregion - -#pragma region Serial Implementation - -#if !SZ_AVOID_LIBC -#include // `fprintf` -#include // `malloc`, `EXIT_FAILURE` - -SZ_PUBLIC void *_sz_memory_allocate_default(sz_size_t length, void *handle) { - sz_unused(handle); - return malloc(length); -} -SZ_PUBLIC void _sz_memory_free_default(sz_ptr_t start, sz_size_t length, void *handle) { - sz_unused(handle && length); - free(start); -} - -#endif - -SZ_PUBLIC void sz_memory_allocator_init_default(sz_memory_allocator_t *alloc) { -#if !SZ_AVOID_LIBC - alloc->allocate = (sz_memory_allocate_t)_sz_memory_allocate_default; - alloc->free = (sz_memory_free_t)_sz_memory_free_default; -#else - alloc->allocate = (sz_memory_allocate_t)SZ_NULL; - alloc->free = (sz_memory_free_t)SZ_NULL; -#endif - alloc->handle = SZ_NULL; -} - -SZ_PUBLIC void sz_memory_allocator_init_fixed(sz_memory_allocator_t *alloc, void *buffer, sz_size_t length) { - // The logic here is simple - put the buffer length in the first slots of the buffer. - // Later use it for bounds checking. - alloc->allocate = (sz_memory_allocate_t)_sz_memory_allocate_fixed; - alloc->free = (sz_memory_free_t)_sz_memory_free_fixed; - alloc->handle = &buffer; - sz_copy((sz_ptr_t)buffer, (sz_cptr_t)&length, sizeof(sz_size_t)); -} - -/** - * @brief Byte-level equality comparison between two strings. - * If unaligned loads are allowed, uses a switch-table to avoid loops on short strings. 
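 * Illustrative example (not in the upstream docs): `sz_equal_serial("stringzilla", "stringzilla", 11)` yields
 * `sz_true_k`, while changing any single byte, e.g. comparing against "stringzillA", yields `sz_false_k`.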
- */ -SZ_PUBLIC sz_bool_t sz_equal_serial(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - sz_cptr_t const a_end = a + length; -#if SZ_USE_MISALIGNED_LOADS - if (length >= SZ_SWAR_THRESHOLD) { - sz_u64_vec_t a_vec, b_vec; - for (; a + 8 <= a_end; a += 8, b += 8) { - a_vec = sz_u64_load(a); - b_vec = sz_u64_load(b); - if (a_vec.u64 != b_vec.u64) return sz_false_k; - } - } -#endif - while (a != a_end && *a == *b) a++, b++; - return (sz_bool_t)(a_end == a); -} - -SZ_PUBLIC sz_cptr_t sz_find_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { - for (sz_cptr_t const end = text + length; text != end; ++text) - if (sz_charset_contains(set, *text)) return text; - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Warray-bounds" - sz_cptr_t const end = text; - for (text += length; text != end;) - if (sz_charset_contains(set, *(text -= 1))) return text; - return SZ_NULL_CHAR; -#pragma GCC diagnostic pop -} - -/** - * One option to avoid branching is to use conditional moves and lookup the comparison result in a table: - * sz_ordering_t ordering_lookup[2] = {sz_greater_k, sz_less_k}; - * for (; a != min_end; ++a, ++b) - * if (*a != *b) return ordering_lookup[*a < *b]; - * That, however, introduces a data-dependency. - * A cleaner option is to perform two comparisons and a subtraction. - * One instruction more, but no data-dependency. - */ -#define _sz_order_scalars(a, b) ((sz_ordering_t)((a > b) - (a < b))) - -SZ_PUBLIC sz_ordering_t sz_order_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { - sz_bool_t a_shorter = (sz_bool_t)(a_length < b_length); - sz_size_t min_length = a_shorter ? a_length : b_length; - sz_cptr_t min_end = a + min_length; -#if SZ_USE_MISALIGNED_LOADS && !SZ_DETECT_BIG_ENDIAN - for (sz_u64_vec_t a_vec, b_vec; a + 8 <= min_end; a += 8, b += 8) { - a_vec = sz_u64_load(a); - b_vec = sz_u64_load(b); - if (a_vec.u64 != b_vec.u64) - return _sz_order_scalars(sz_u64_bytes_reverse(a_vec.u64), sz_u64_bytes_reverse(b_vec.u64)); - } -#endif - for (; a != min_end; ++a, ++b) - if (*a != *b) return _sz_order_scalars(*a, *b); - - // If the strings are equal up to `min_end`, then the shorter string is smaller - return _sz_order_scalars(a_length, b_length); -} - -/** - * @brief Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each byte is set. - // For that take the bottom 7 bits of each byte, add one to them, - // and if this sets the top bit to one, then all the 7 bits are ones as well. - vec.u64 = ((vec.u64 & 0x7F7F7F7F7F7F7F7Full) + 0x0101010101010101ull) & ((vec.u64 & 0x8080808080808080ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b single-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. - * Identical to `memchr(haystack, needle[0], haystack_length)`. - */ -SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - if (!h_length) return SZ_NULL_CHAR; - sz_cptr_t const h_end = h + h_length; - -#if !SZ_DETECT_BIG_ENDIAN // Use SWAR only on little-endian platforms for brevety. 
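
    // Worked illustration (not from the original sources) of the SWAR trick used below via `_sz_u64_each_byte_equal`.
    // Searching for 'b' (0x62) in the 8-byte window "abcbdbeb":
    //      broadcast needle:   n_vec.u64 = 0x6262626262626262
    //      haystack window:    h_vec.u64 = 0x6265626462636261 (little-endian load of "abcbdbeb")
    // `~(h ^ n)` turns every matching byte into 0xFF; adding 1 to the low 7 bits of each byte overflows into the
    // top bit only when all 7 are set, and AND-ing with the original top bits filters out near-misses, so only
    // fully matching bytes keep their top bit. `sz_u64_ctz(match_vec.u64) / 8` then gives the offset of the first
    // match - here 1.
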
-#if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)h & 7ull) && h < h_end; ++h) - if (*h == *n) return h; -#endif - - // Broadcast the n into every byte of a 64-bit integer to use SWAR - // techniques and process eight characters at a time. - sz_u64_vec_t h_vec, n_vec, match_vec; - match_vec.u64 = 0; - n_vec.u64 = (sz_u64_t)n[0] * 0x0101010101010101ull; - for (; h + 8 <= h_end; h += 8) { - h_vec.u64 = *(sz_u64_t const *)h; - match_vec = _sz_u64_each_byte_equal(h_vec, n_vec); - if (match_vec.u64) return h + sz_u64_ctz(match_vec.u64) / 8; - } -#endif - - // Handle the misaligned tail. - for (; h < h_end; ++h) - if (*h == *n) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief Find the last occurrence of a @b single-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. - * Identical to `memrchr(haystack, needle[0], haystack_length)`. - */ -sz_cptr_t sz_rfind_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - if (!h_length) return SZ_NULL_CHAR; - sz_cptr_t const h_start = h; - - // Reposition the `h` pointer to the end, as we will be walking backwards. - h = h + h_length - 1; - -#if !SZ_DETECT_BIG_ENDIAN // Use SWAR only on little-endian platforms for brevety. -#if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)(h + 1) & 7ull) && h >= h_start; --h) - if (*h == *n) return h; -#endif - - // Broadcast the n into every byte of a 64-bit integer to use SWAR - // techniques and process eight characters at a time. - sz_u64_vec_t h_vec, n_vec, match_vec; - n_vec.u64 = (sz_u64_t)n[0] * 0x0101010101010101ull; - for (; h >= h_start + 7; h -= 8) { - h_vec.u64 = *(sz_u64_t const *)(h - 7); - match_vec = _sz_u64_each_byte_equal(h_vec, n_vec); - if (match_vec.u64) return h - sz_u64_clz(match_vec.u64) / 8; - } -#endif - - for (; h >= h_start; --h) - if (*h == *n) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief 2Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each 2byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_2byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each 2byte is set. - // For that take the bottom 15 bits of each 2byte, add one to them, - // and if this sets the top bit to one, then all the 15 bits are ones as well. - vec.u64 = ((vec.u64 & 0x7FFF7FFF7FFF7FFFull) + 0x0001000100010001ull) & ((vec.u64 & 0x8000800080008000ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b two-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. - */ -SZ_INTERNAL sz_cptr_t _sz_find_2byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - // This is an internal method, and the haystack is guaranteed to be at least 2 bytes long. - sz_assert(h_length >= 2 && "The haystack is too short."); - sz_cptr_t const h_end = h + h_length; - -#if !SZ_USE_MISALIGNED_LOADS - // Process the misaligned head, to void UB on unaligned 64-bit loads. 
- for (; ((sz_size_t)h & 7ull) && h + 2 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) == 2) return h; -#endif - - sz_u64_vec_t h_even_vec, h_odd_vec, n_vec, matches_even_vec, matches_odd_vec; - n_vec.u64 = 0; - n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1]; - n_vec.u64 *= 0x0001000100010001ull; // broadcast - - // This code simulates hyper-scalar execution, analyzing 8 offsets at a time. - for (; h + 9 <= h_end; h += 8) { - h_even_vec.u64 = *(sz_u64_t *)h; - h_odd_vec.u64 = (h_even_vec.u64 >> 8) | ((sz_u64_t)h[8] << 56); - matches_even_vec = _sz_u64_each_2byte_equal(h_even_vec, n_vec); - matches_odd_vec = _sz_u64_each_2byte_equal(h_odd_vec, n_vec); - - matches_even_vec.u64 >>= 8; - if (matches_even_vec.u64 + matches_odd_vec.u64) { - sz_u64_t match_indicators = matches_even_vec.u64 | matches_odd_vec.u64; - return h + sz_u64_ctz(match_indicators) / 8; - } - } - - for (; h + 2 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) == 2) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief 4Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each 4byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_4byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each 4byte is set. - // For that take the bottom 31 bits of each 4byte, add one to them, - // and if this sets the top bit to one, then all the 31 bits are ones as well. - vec.u64 = ((vec.u64 & 0x7FFFFFFF7FFFFFFFull) + 0x0000000100000001ull) & ((vec.u64 & 0x8000000080000000ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b four-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. - */ -SZ_INTERNAL sz_cptr_t _sz_find_4byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - // This is an internal method, and the haystack is guaranteed to be at least 4 bytes long. - sz_assert(h_length >= 4 && "The haystack is too short."); - sz_cptr_t const h_end = h + h_length; - -#if !SZ_USE_MISALIGNED_LOADS - // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)h & 7ull) && h + 4 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) + (h[3] == n[3]) == 4) return h; -#endif - - sz_u64_vec_t h0_vec, h1_vec, h2_vec, h3_vec, n_vec, matches0_vec, matches1_vec, matches2_vec, matches3_vec; - n_vec.u64 = 0; - n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1], n_vec.u8s[2] = n[2], n_vec.u8s[3] = n[3]; - n_vec.u64 *= 0x0000000100000001ull; // broadcast - - // This code simulates hyper-scalar execution, analyzing 8 offsets at a time using four 64-bit words. - // We load the subsequent four-byte word as well, taking its first bytes. 
Think of it as a glorified prefetch :)
-    sz_u64_t h_page_current, h_page_next;
-    for (; h + sizeof(sz_u64_t) + sizeof(sz_u32_t) <= h_end; h += sizeof(sz_u64_t)) {
-        h_page_current = *(sz_u64_t *)h;
-        h_page_next = *(sz_u32_t *)(h + 8);
-        h0_vec.u64 = (h_page_current);
-        h1_vec.u64 = (h_page_current >> 8) | (h_page_next << 56);
-        h2_vec.u64 = (h_page_current >> 16) | (h_page_next << 48);
-        h3_vec.u64 = (h_page_current >> 24) | (h_page_next << 40);
-        matches0_vec = _sz_u64_each_4byte_equal(h0_vec, n_vec);
-        matches1_vec = _sz_u64_each_4byte_equal(h1_vec, n_vec);
-        matches2_vec = _sz_u64_each_4byte_equal(h2_vec, n_vec);
-        matches3_vec = _sz_u64_each_4byte_equal(h3_vec, n_vec);
-
-        if (matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64) {
-            matches0_vec.u64 >>= 24;
-            matches1_vec.u64 >>= 16;
-            matches2_vec.u64 >>= 8;
-            sz_u64_t match_indicators = matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64;
-            return h + sz_u64_ctz(match_indicators) / 8;
-        }
-    }
-
-    for (; h + 4 <= h_end; ++h)
-        if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) + (h[3] == n[3]) == 4) return h;
-    return SZ_NULL_CHAR;
-}
-
-/**
- * @brief 3Byte-level equality comparison between two 64-bit integers.
- * @return 64-bit integer, where every top bit in each 3byte signifies a match.
- */
-SZ_INTERNAL sz_u64_vec_t _sz_u64_each_3byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) {
-    sz_u64_vec_t vec;
-    vec.u64 = ~(a.u64 ^ b.u64);
-    // The match is valid, if every bit within each 3byte is set.
-    // For that take the bottom 23 bits of each 3byte, add one to them,
-    // and if this sets the top bit to one, then all the 23 bits are ones as well.
-    vec.u64 = ((vec.u64 & 0xFFFF7FFFFF7FFFFFull) + 0x0000000001000001ull) & ((vec.u64 & 0x0000800000800000ull));
-    return vec;
-}
-
-/**
- * @brief Find the first occurrence of a @b three-character needle in an arbitrary length haystack.
- *        This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time.
- */
-SZ_INTERNAL sz_cptr_t _sz_find_3byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {
-
-    // This is an internal method, and the haystack is guaranteed to be at least 3 bytes long.
-    sz_assert(h_length >= 3 && "The haystack is too short.");
-    sz_cptr_t const h_end = h + h_length;
-
-#if !SZ_USE_MISALIGNED_LOADS
-    // Process the misaligned head, to avoid UB on unaligned 64-bit loads.
-    for (; ((sz_size_t)h & 7ull) && h + 3 <= h_end; ++h)
-        if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) == 3) return h;
-#endif
-
-    // We fetch 10 bytes of the haystack at a time - an 8-byte word and the next 2 bytes.
-    sz_u64_vec_t h0_vec, h1_vec, h2_vec, h3_vec, h4_vec;
-    sz_u64_vec_t matches0_vec, matches1_vec, matches2_vec, matches3_vec, matches4_vec;
-    sz_u64_vec_t n_vec;
-    n_vec.u64 = 0;
-    n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1], n_vec.u8s[2] = n[2];
-    n_vec.u64 *= 0x0000000001000001ull; // broadcast
-
-    // This code simulates hyper-scalar execution, analyzing 8 offsets at a time using three 64-bit words.
-    // We load the subsequent two-byte word as well.
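    // Clarifying aside (not part of the original comments): each of the five shifted views `h0_vec`..`h4_vec`
    // checks two 3-byte lanes - at byte offsets 0-2 and 3-5 of that view - so the shifts by 0..4 bytes jointly
    // cover all 8 starting positions within the current step, and the re-alignment shifts applied before merging
    // the match masks make `sz_u64_ctz(match_indicators) / 8` the exact matching offset from `h`.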
- sz_u64_t h_page_current, h_page_next; - for (; h + sizeof(sz_u64_t) + sizeof(sz_u16_t) <= h_end; h += sizeof(sz_u64_t)) { - h_page_current = *(sz_u64_t *)h; - h_page_next = *(sz_u16_t *)(h + 8); - h0_vec.u64 = (h_page_current); - h1_vec.u64 = (h_page_current >> 8) | (h_page_next << 56); - h2_vec.u64 = (h_page_current >> 16) | (h_page_next << 48); - h3_vec.u64 = (h_page_current >> 24) | (h_page_next << 40); - h4_vec.u64 = (h_page_current >> 32) | (h_page_next << 32); - matches0_vec = _sz_u64_each_3byte_equal(h0_vec, n_vec); - matches1_vec = _sz_u64_each_3byte_equal(h1_vec, n_vec); - matches2_vec = _sz_u64_each_3byte_equal(h2_vec, n_vec); - matches3_vec = _sz_u64_each_3byte_equal(h3_vec, n_vec); - matches4_vec = _sz_u64_each_3byte_equal(h4_vec, n_vec); - - if (matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64 | matches4_vec.u64) { - matches0_vec.u64 >>= 16; - matches1_vec.u64 >>= 8; - matches3_vec.u64 <<= 8; - matches4_vec.u64 <<= 16; - sz_u64_t match_indicators = - matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64 | matches4_vec.u64; - return h + sz_u64_ctz(match_indicators) / 8; - } - } - - for (; h + 3 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) == 3) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief Boyer-Moore-Horspool algorithm for exact matching of patterns up to @b 256-bytes long. - * Uses the Raita heuristic to match the first two, the last, and the middle character of the pattern. - */ -SZ_INTERNAL sz_cptr_t _sz_find_horspool_upto_256bytes_serial(sz_cptr_t h_chars, sz_size_t h_length, // - sz_cptr_t n_chars, sz_size_t n_length) { - sz_assert(n_length <= 256 && "The pattern is too long."); - // Several popular string matching algorithms are using a bad-character shift table. - // Boyer Moore: https://www-igm.univ-mlv.fr/~lecroq/string/node14.html - // Quick Search: https://www-igm.univ-mlv.fr/~lecroq/string/node19.html - // Smith: https://www-igm.univ-mlv.fr/~lecroq/string/node21.html - union { - sz_u8_t jumps[256]; - sz_u64_vec_t vecs[64]; - } bad_shift_table; - - // Let's initialize the table using SWAR to the total length of the string. - sz_u8_t const *h = (sz_u8_t const *)h_chars; - sz_u8_t const *n = (sz_u8_t const *)n_chars; - { - sz_u64_vec_t n_length_vec; - n_length_vec.u64 = n_length; - n_length_vec.u64 *= 0x0101010101010101ull; // broadcast - for (sz_size_t i = 0; i != 64; ++i) bad_shift_table.vecs[i].u64 = n_length_vec.u64; - for (sz_size_t i = 0; i + 1 < n_length; ++i) bad_shift_table.jumps[n[i]] = (sz_u8_t)(n_length - i - 1); - } - - // Another common heuristic is to match a few characters from different parts of a string. - // Raita suggests to use the first two, the last, and the middle character of the pattern. - sz_u32_vec_t h_vec, n_vec; - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n_chars, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into an unsigned integer. - n_vec.u8s[0] = n[offset_first]; - n_vec.u8s[1] = n[offset_first + 1]; - n_vec.u8s[2] = n[offset_mid]; - n_vec.u8s[3] = n[offset_last]; - - // Scan through the whole haystack, skipping the last `n_length - 1` bytes. 
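    // Worked illustration (not from the original sources): for the needle "abcab" the table above starts with every
    // byte mapped to 5 and then gets jumps['a'] = 1, jumps['b'] = 3, jumps['c'] = 2 - the distance from each
    // character's last occurrence (excluding the final position) to the end of the needle. The loop below always
    // advances by the jump of the haystack byte aligned with the needle's last character, so bytes absent from the
    // needle skip a full pattern length.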
- for (sz_size_t i = 0; i <= h_length - n_length;) { - h_vec.u8s[0] = h[i + offset_first]; - h_vec.u8s[1] = h[i + offset_first + 1]; - h_vec.u8s[2] = h[i + offset_mid]; - h_vec.u8s[3] = h[i + offset_last]; - if (h_vec.u32 == n_vec.u32 && sz_equal((sz_cptr_t)h + i, n_chars, n_length)) return (sz_cptr_t)h + i; - i += bad_shift_table.jumps[h[i + n_length - 1]]; - } - return SZ_NULL_CHAR; -} - -/** - * @brief Boyer-Moore-Horspool algorithm for @b reverse-order exact matching of patterns up to @b 256-bytes long. - * Uses the Raita heuristic to match the first two, the last, and the middle character of the pattern. - */ -SZ_INTERNAL sz_cptr_t _sz_rfind_horspool_upto_256bytes_serial(sz_cptr_t h_chars, sz_size_t h_length, // - sz_cptr_t n_chars, sz_size_t n_length) { - sz_assert(n_length <= 256 && "The pattern is too long."); - union { - sz_u8_t jumps[256]; - sz_u64_vec_t vecs[64]; - } bad_shift_table; - - // Let's initialize the table using SWAR to the total length of the string. - sz_u8_t const *h = (sz_u8_t const *)h_chars; - sz_u8_t const *n = (sz_u8_t const *)n_chars; - { - sz_u64_vec_t n_length_vec; - n_length_vec.u64 = n_length; - n_length_vec.u64 *= 0x0101010101010101ull; // broadcast - for (sz_size_t i = 0; i != 64; ++i) bad_shift_table.vecs[i].u64 = n_length_vec.u64; - for (sz_size_t i = 0; i + 1 < n_length; ++i) - bad_shift_table.jumps[n[n_length - i - 1]] = (sz_u8_t)(n_length - i - 1); - } - - // Another common heuristic is to match a few characters from different parts of a string. - // Raita suggests to use the first two, the last, and the middle character of the pattern. - sz_u32_vec_t h_vec, n_vec; - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n_chars, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into an unsigned integer. - n_vec.u8s[0] = n[offset_first]; - n_vec.u8s[1] = n[offset_first + 1]; - n_vec.u8s[2] = n[offset_mid]; - n_vec.u8s[3] = n[offset_last]; - - // Scan through the whole haystack, skipping the first `n_length - 1` bytes. - for (sz_size_t j = 0; j <= h_length - n_length;) { - sz_size_t i = h_length - n_length - j; - h_vec.u8s[0] = h[i + offset_first]; - h_vec.u8s[1] = h[i + offset_first + 1]; - h_vec.u8s[2] = h[i + offset_mid]; - h_vec.u8s[3] = h[i + offset_last]; - if (h_vec.u32 == n_vec.u32 && sz_equal((sz_cptr_t)h + i, n_chars, n_length)) return (sz_cptr_t)h + i; - j += bad_shift_table.jumps[h[i]]; - } - return SZ_NULL_CHAR; -} - -/** - * @brief Exact substring search helper function, that finds the first occurrence of a prefix of the needle - * using a given search function, and then verifies the remaining part of the needle. - */ -SZ_INTERNAL sz_cptr_t _sz_find_with_prefix(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length, - sz_find_t find_prefix, sz_size_t prefix_length) { - - sz_size_t suffix_length = n_length - prefix_length; - while (1) { - sz_cptr_t found = find_prefix(h, h_length, n, prefix_length); - if (!found) return SZ_NULL_CHAR; - - // Verify the remaining part of the needle - sz_size_t remaining = h_length - (found - h); - if (remaining < n_length) return SZ_NULL_CHAR; - if (sz_equal(found + prefix_length, n + prefix_length, suffix_length)) return found; - - // Adjust the position. 
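        // Note (illustrative): a failed verification only rules out this exact position, e.g. a matching 4-byte
        // prefix whose remaining suffix differs, so the search can safely resume one byte past the candidate.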
- h = found + 1; - h_length = remaining - 1; - } - - // Unreachable, but helps silence compiler warnings: - return SZ_NULL_CHAR; -} - -/** - * @brief Exact reverse-order substring search helper function, that finds the last occurrence of a suffix of the - * needle using a given search function, and then verifies the remaining part of the needle. - */ -SZ_INTERNAL sz_cptr_t _sz_rfind_with_suffix(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length, - sz_find_t find_suffix, sz_size_t suffix_length) { - - sz_size_t prefix_length = n_length - suffix_length; - while (1) { - sz_cptr_t found = find_suffix(h, h_length, n + prefix_length, suffix_length); - if (!found) return SZ_NULL_CHAR; - - // Verify the remaining part of the needle - sz_size_t remaining = found - h; - if (remaining < prefix_length) return SZ_NULL_CHAR; - if (sz_equal(found - prefix_length, n, prefix_length)) return found - prefix_length; - - // Adjust the position. - h_length = remaining - 1; - } - - // Unreachable, but helps silence compiler warnings: - return SZ_NULL_CHAR; -} - -SZ_INTERNAL sz_cptr_t _sz_find_over_4bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - return _sz_find_with_prefix(h, h_length, n, n_length, (sz_find_t)_sz_find_4byte_serial, 4); -} - -SZ_INTERNAL sz_cptr_t _sz_find_horspool_over_256bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, - sz_size_t n_length) { - return _sz_find_with_prefix(h, h_length, n, n_length, _sz_find_horspool_upto_256bytes_serial, 256); -} - -SZ_INTERNAL sz_cptr_t _sz_rfind_horspool_over_256bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, - sz_size_t n_length) { - return _sz_rfind_with_suffix(h, h_length, n, n_length, _sz_rfind_horspool_upto_256bytes_serial, 256); -} - -SZ_PUBLIC sz_cptr_t sz_find_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - -#if SZ_DETECT_BIG_ENDIAN - sz_find_t backends[] = { - (sz_find_t)sz_find_byte_serial, - (sz_find_t)_sz_find_horspool_upto_256bytes_serial, - (sz_find_t)_sz_find_horspool_over_256bytes_serial, - }; - - return backends[(n_length > 1) + (n_length > 256)](h, h_length, n, n_length); -#else - sz_find_t backends[] = { - // For very short strings brute-force SWAR makes sense. - (sz_find_t)sz_find_byte_serial, - (sz_find_t)_sz_find_2byte_serial, - (sz_find_t)_sz_find_3byte_serial, - (sz_find_t)_sz_find_4byte_serial, - // To avoid constructing the skip-table, let's use the prefixed approach. - (sz_find_t)_sz_find_over_4bytes_serial, - // For longer needles - use skip tables. - (sz_find_t)_sz_find_horspool_upto_256bytes_serial, - (sz_find_t)_sz_find_horspool_over_256bytes_serial, - }; - - return backends[ - // For very short strings brute-force SWAR makes sense. - (n_length > 1) + (n_length > 2) + (n_length > 3) + - // To avoid constructing the skip-table, let's use the prefixed approach. - (n_length > 4) + - // For longer needles - use skip tables. - (n_length > 8) + (n_length > 256)](h, h_length, n, n_length); -#endif -} - -SZ_PUBLIC sz_cptr_t sz_rfind_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - - sz_find_t backends[] = { - // For very short strings brute-force SWAR makes sense. 
- (sz_find_t)sz_rfind_byte_serial, - // TODO: implement reverse-order SWAR for 2/3/4 byte variants. - // TODO: (sz_find_t)_sz_rfind_2byte_serial, - // TODO: (sz_find_t)_sz_rfind_3byte_serial, - // TODO: (sz_find_t)_sz_rfind_4byte_serial, - // To avoid constructing the skip-table, let's use the prefixed approach. - // (sz_find_t)_sz_rfind_over_4bytes_serial, - // For longer needles - use skip tables. - (sz_find_t)_sz_rfind_horspool_upto_256bytes_serial, - (sz_find_t)_sz_rfind_horspool_over_256bytes_serial, - }; - - return backends[ - // For very short strings brute-force SWAR makes sense. - 0 + - // To avoid constructing the skip-table, let's use the prefixed approach. - (n_length > 1) + - // For longer needles - use skip tables. - (n_length > 256)](h, h_length, n, n_length); -} - -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_serial( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - - // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. - sz_memory_allocator_t global_alloc; - if (!alloc) { - sz_memory_allocator_init_default(&global_alloc); - alloc = &global_alloc; - } - - // TODO: Generalize to remove the following asserts! - sz_assert(!bound && "For bounded search the method should only evaluate one band of the matrix."); - sz_assert(shorter_length == longer_length && "The method hasn't been generalized to different length inputs yet."); - sz_unused(longer_length && bound); - - // We are going to store 3 diagonals of the matrix. - // The length of the longest (main) diagonal would be `n = (shorter_length + 1)`. - sz_size_t n = shorter_length + 1; - sz_size_t buffer_length = sizeof(sz_size_t) * n * 3; - sz_size_t *distances = (sz_size_t *)alloc->allocate(buffer_length, alloc->handle); - if (!distances) return SZ_SIZE_MAX; - - sz_size_t *previous_distances = distances; - sz_size_t *current_distances = previous_distances + n; - sz_size_t *next_distances = previous_distances + n * 2; - - // Initialize the first two diagonals: - previous_distances[0] = 0; - current_distances[0] = current_distances[1] = 1; - - // Progress through the upper triangle of the Levenshtein matrix. - sz_size_t next_diagonal_index = 2; - for (; next_diagonal_index != n; ++next_diagonal_index) { - sz_size_t const next_diagonal_length = next_diagonal_index + 1; - for (sz_size_t i = 0; i + 2 < next_diagonal_length; ++i) { - sz_size_t cost_of_substitution = shorter[next_diagonal_index - i - 2] != longer[i]; - sz_size_t cost_if_substitution = previous_distances[i] + cost_of_substitution; - sz_size_t cost_if_deletion_or_insertion = sz_min_of_two(current_distances[i], current_distances[i + 1]) + 1; - next_distances[i + 1] = sz_min_of_two(cost_if_deletion_or_insertion, cost_if_substitution); - } - // Don't forget to populate the first row and the first column of the Levenshtein matrix. - next_distances[0] = next_distances[next_diagonal_length - 1] = next_diagonal_index; - // Perform a circular rotation of those buffers, to reuse the memory. - sz_size_t *temporary = previous_distances; - previous_distances = current_distances; - current_distances = next_distances; - next_distances = temporary; - } - - // By now we've scanned through the upper triangle of the matrix, where each subsequent iteration results in a - // larger diagonal. From now onwards, we will be shrinking. 
Instead of adding value equal to the skewed diagonal - // index on either side, we will be cropping those values out. - sz_size_t diagonals_count = n + n - 1; - for (; next_diagonal_index != diagonals_count; ++next_diagonal_index) { - sz_size_t const next_diagonal_length = diagonals_count - next_diagonal_index; - for (sz_size_t i = 0; i != next_diagonal_length; ++i) { - sz_size_t cost_of_substitution = shorter[shorter_length - 1 - i] != longer[next_diagonal_index - n + i]; - sz_size_t cost_if_substitution = previous_distances[i] + cost_of_substitution; - sz_size_t cost_if_deletion_or_insertion = sz_min_of_two(current_distances[i], current_distances[i + 1]) + 1; - next_distances[i] = sz_min_of_two(cost_if_deletion_or_insertion, cost_if_substitution); - } - // Perform a circular rotation of those buffers, to reuse the memory, this time, with a shift, - // dropping the first element in the current array. - sz_size_t *temporary = previous_distances; - previous_distances = current_distances + 1; - current_distances = next_distances; - next_distances = temporary; - } - - // Cache scalar before `free` call. - sz_size_t result = current_distances[0]; - alloc->free(distances, buffer_length, alloc->handle); - return result; -} - -/** - * @brief Describes the length of a UTF8 character / codepoint / rune in bytes. - */ -typedef enum { - sz_utf8_invalid_k = 0, //!< Invalid UTF8 character. - sz_utf8_rune_1byte_k = 1, //!< 1-byte UTF8 character. - sz_utf8_rune_2bytes_k = 2, //!< 2-byte UTF8 character. - sz_utf8_rune_3bytes_k = 3, //!< 3-byte UTF8 character. - sz_utf8_rune_4bytes_k = 4, //!< 4-byte UTF8 character. -} sz_rune_length_t; - -typedef sz_u32_t sz_rune_t; - -/** - * @brief Extracts just one UTF8 codepoint from a UTF8 string into a 32-bit unsigned integer. - */ -SZ_INTERNAL void _sz_extract_utf8_rune(sz_cptr_t utf8, sz_rune_t *code, sz_rune_length_t *code_length) { - sz_u8_t const *current = (sz_u8_t const *)utf8; - sz_u8_t leading_byte = *current++; - sz_rune_t ch; - sz_rune_length_t ch_length; - - // TODO: This can be made entirely branchless using 32-bit SWAR. - if (leading_byte < 0x80) { - // Single-byte rune (0xxxxxxx) - ch = leading_byte; - ch_length = sz_utf8_rune_1byte_k; - } - else if ((leading_byte & 0xE0) == 0xC0) { - // Two-byte rune (110xxxxx 10xxxxxx) - ch = (leading_byte & 0x1F) << 6; - ch |= (*current++ & 0x3F); - ch_length = sz_utf8_rune_2bytes_k; - } - else if ((leading_byte & 0xF0) == 0xE0) { - // Three-byte rune (1110xxxx 10xxxxxx 10xxxxxx) - ch = (leading_byte & 0x0F) << 12; - ch |= (*current++ & 0x3F) << 6; - ch |= (*current++ & 0x3F); - ch_length = sz_utf8_rune_3bytes_k; - } - else if ((leading_byte & 0xF8) == 0xF0) { - // Four-byte rune (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx) - ch = (leading_byte & 0x07) << 18; - ch |= (*current++ & 0x3F) << 12; - ch |= (*current++ & 0x3F) << 6; - ch |= (*current++ & 0x3F); - ch_length = sz_utf8_rune_4bytes_k; - } - else { - // Invalid UTF8 rune. - ch = 0; - ch_length = sz_utf8_invalid_k; - } - *code = ch; - *code_length = ch_length; -} - -/** - * @brief Exports a UTF8 string into a UTF32 buffer. - * ! The result is undefined id the UTF8 string is corrupted. - * @return The length in the number of codepoints. 
- */ -SZ_INTERNAL sz_size_t _sz_export_utf8_to_utf32(sz_cptr_t utf8, sz_size_t utf8_length, sz_rune_t *utf32) { - sz_cptr_t const end = utf8 + utf8_length; - sz_size_t count = 0; - sz_rune_length_t rune_length; - for (; utf8 != end; utf8 += rune_length, utf32++, count++) _sz_extract_utf8_rune(utf8, utf32, &rune_length); - return count; -} - -/** - * @brief Compute the Levenshtein distance between two strings using the Wagner-Fisher algorithm. - * Stores only 2 rows of the Levenshtein matrix, but uses 64-bit integers for the distance values, - * and upcasts UTF8 variable-length codepoints to 64-bit integers for faster addressing. - * - * ! In the worst case for 2 strings of length 100, that contain just one 16-bit codepoint this will result in extra: - * + 2 rows * 100 slots * 8 bytes/slot = 1600 bytes of memory for the two rows of the Levenshtein matrix rows. - * + 100 codepoints * 2 strings * 4 bytes/codepoint = 800 bytes of memory for the UTF8 buffer. - * = 2400 bytes of memory or @b 12x memory amplification! - */ -SZ_INTERNAL sz_size_t _sz_edit_distance_wagner_fisher_serial( // - sz_cptr_t longer, sz_size_t longer_length, // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_size_t bound, sz_bool_t can_be_unicode, sz_memory_allocator_t *alloc) { - - // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. - sz_memory_allocator_t global_alloc; - if (!alloc) { - sz_memory_allocator_init_default(&global_alloc); - alloc = &global_alloc; - } - - // A good idea may be to dispatch different kernels for different string lengths. - // Like using `uint8_t` counters for strings under 255 characters long. - // Good in theory, this results in frequent upcasts and downcasts in serial code. - // On strings over 20 bytes, using `uint8` over `uint64` on 64-bit x86 CPU doubles the execution time. - // So one must be very cautious with such optimizations. - typedef sz_size_t _distance_t; - - // Compute the number of columns in our Levenshtein matrix. - sz_size_t const n = shorter_length + 1; - - // If a buffering memory-allocator is provided, this operation is practically free, - // and cheaper than allocating even 512 bytes (for small distance matrices) on stack. - sz_size_t buffer_length = sizeof(_distance_t) * (n * 2); - - // If the strings contain Unicode characters, let's estimate the max character width, - // and use it to allocate a larger buffer to decode UTF8. - if ((can_be_unicode == sz_true_k) && - (sz_isascii(longer, longer_length) == sz_false_k || sz_isascii(shorter, shorter_length) == sz_false_k)) { - buffer_length += (shorter_length + longer_length) * sizeof(sz_rune_t); - } - else { can_be_unicode = sz_false_k; } - - // If the allocation fails, return the maximum distance. - sz_ptr_t const buffer = (sz_ptr_t)alloc->allocate(buffer_length, alloc->handle); - if (!buffer) return SZ_SIZE_MAX; - - // Let's export the UTF8 sequence into the newly allocated buffer at the end. - if (can_be_unicode == sz_true_k) { - sz_rune_t *const longer_utf32 = (sz_rune_t *)(buffer + sizeof(_distance_t) * (n * 2)); - sz_rune_t *const shorter_utf32 = longer_utf32 + longer_length; - // Export the UTF8 sequences into the newly allocated buffer. 
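        // Clarifying aside (not part of the original comments): at this point `buffer` holds two `n`-long rows of
        // distance counters, immediately followed by the UTF-32 scratch space - the longer string's codepoints
        // first, then the shorter's. After the exports below, `longer_length` and `shorter_length` are measured in
        // codepoints rather than bytes.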
- longer_length = _sz_export_utf8_to_utf32(longer, longer_length, longer_utf32); - shorter_length = _sz_export_utf8_to_utf32(shorter, shorter_length, shorter_utf32); - longer = (sz_cptr_t)longer_utf32; - shorter = (sz_cptr_t)shorter_utf32; - } - - // Let's parameterize the core logic for different character types and distance types. -#define _wagner_fisher_unbounded(_distance_t, _char_t) \ - /* Now let's cast our pointer to avoid it in subsequent sections. */ \ - _char_t const *const longer_chars = (_char_t const *)longer; \ - _char_t const *const shorter_chars = (_char_t const *)shorter; \ - _distance_t *previous_distances = (_distance_t *)buffer; \ - _distance_t *current_distances = previous_distances + n; \ - /* Initialize the first row of the Levenshtein matrix with `iota`-style arithmetic progression. */ \ - for (_distance_t idx_shorter = 0; idx_shorter != n; ++idx_shorter) previous_distances[idx_shorter] = idx_shorter; \ - /* The main loop of the algorithm with quadratic complexity. */ \ - for (_distance_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) { \ - _char_t const longer_char = longer_chars[idx_longer]; \ - /* Using pure pointer arithmetic is faster than iterating with an index. */ \ - _char_t const *shorter_ptr = shorter_chars; \ - _distance_t const *previous_ptr = previous_distances; \ - _distance_t *current_ptr = current_distances; \ - _distance_t *const current_end = current_ptr + shorter_length; \ - current_ptr[0] = idx_longer + 1; \ - for (; current_ptr != current_end; ++previous_ptr, ++current_ptr, ++shorter_ptr) { \ - _distance_t cost_substitution = previous_ptr[0] + (_distance_t)(longer_char != shorter_ptr[0]); \ - /* We can avoid `+1` for costs here, shifting it to post-minimum computation, */ \ - /* saving one increment operation. */ \ - _distance_t cost_deletion = previous_ptr[1]; \ - _distance_t cost_insertion = current_ptr[0]; \ - /* ? It might be a good idea to enforce branchless execution here. */ \ - /* ? The caveat being that the benchmarks on longer sequences backfire and more research is needed. */ \ - current_ptr[1] = sz_min_of_two(cost_substitution, sz_min_of_two(cost_deletion, cost_insertion) + 1); \ - } \ - /* Swap `previous_distances` and `current_distances` pointers. */ \ - _distance_t *temporary = previous_distances; \ - previous_distances = current_distances; \ - current_distances = temporary; \ - } \ - /* Cache scalar before `free` call. */ \ - sz_size_t result = previous_distances[shorter_length]; \ - alloc->free(buffer, buffer_length, alloc->handle); \ - return result; - - // Let's define a separate variant for bounded distance computation. - // Practically the same as unbounded, but also collecting the running minimum within each row for early exit. 
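    // Illustrative note: the bound relies on the fact that values never decrease along the matrix diagonals, so the
    // final distance can't drop below the smallest entry of any completed row - once a whole row reaches the bound,
    // the exact answer is no longer needed. For example, the edit distance between "kitten" and "sitting" is 3, so
    // with `bound = 2` the caller only learns that the distance is at least 2.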
-#define _wagner_fisher_bounded(_distance_t, _char_t) \ - _char_t const *const longer_chars = (_char_t const *)longer; \ - _char_t const *const shorter_chars = (_char_t const *)shorter; \ - _distance_t *previous_distances = (_distance_t *)buffer; \ - _distance_t *current_distances = previous_distances + n; \ - for (_distance_t idx_shorter = 0; idx_shorter != n; ++idx_shorter) previous_distances[idx_shorter] = idx_shorter; \ - for (_distance_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) { \ - _char_t const longer_char = longer_chars[idx_longer]; \ - _char_t const *shorter_ptr = shorter_chars; \ - _distance_t const *previous_ptr = previous_distances; \ - _distance_t *current_ptr = current_distances; \ - _distance_t *const current_end = current_ptr + shorter_length; \ - current_ptr[0] = idx_longer + 1; \ - /* Initialize min_distance with a value greater than bound */ \ - _distance_t min_distance = bound - 1; \ - for (; current_ptr != current_end; ++previous_ptr, ++current_ptr, ++shorter_ptr) { \ - _distance_t cost_substitution = previous_ptr[0] + (_distance_t)(longer_char != shorter_ptr[0]); \ - _distance_t cost_deletion = previous_ptr[1]; \ - _distance_t cost_insertion = current_ptr[0]; \ - current_ptr[1] = sz_min_of_two(cost_substitution, sz_min_of_two(cost_deletion, cost_insertion) + 1); \ - /* Keep track of the minimum distance seen so far in this row */ \ - min_distance = sz_min_of_two(current_ptr[1], min_distance); \ - } \ - /* If the minimum distance in this row exceeded the bound, return early */ \ - if (min_distance >= bound) { \ - alloc->free(buffer, buffer_length, alloc->handle); \ - return bound; \ - } \ - _distance_t *temporary = previous_distances; \ - previous_distances = current_distances; \ - current_distances = temporary; \ - } \ - sz_size_t result = previous_distances[shorter_length]; \ - alloc->free(buffer, buffer_length, alloc->handle); \ - return sz_min_of_two(result, bound); - - // Dispatch the actual computation. - if (!bound) { - if (can_be_unicode == sz_true_k) { _wagner_fisher_unbounded(sz_size_t, sz_rune_t); } - else { _wagner_fisher_unbounded(sz_size_t, sz_u8_t); } - } - else { - if (can_be_unicode == sz_true_k) { _wagner_fisher_bounded(sz_size_t, sz_rune_t); } - else { _wagner_fisher_bounded(sz_size_t, sz_u8_t); } - } -} - -SZ_PUBLIC sz_size_t sz_edit_distance_serial( // - sz_cptr_t longer, sz_size_t longer_length, // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - - // Let's make sure that we use the amount proportional to the - // number of elements in the shorter string, not the larger. - if (shorter_length > longer_length) { - sz_pointer_swap((void **)&longer_length, (void **)&shorter_length); - sz_pointer_swap((void **)&longer, (void **)&shorter); - } - - // Skip the matching prefixes and suffixes, they won't affect the distance. - for (sz_cptr_t a_end = longer + longer_length, b_end = shorter + shorter_length; - longer != a_end && shorter != b_end && *longer == *shorter; - ++longer, ++shorter, --longer_length, --shorter_length); - for (; longer_length && shorter_length && longer[longer_length - 1] == shorter[shorter_length - 1]; - --longer_length, --shorter_length); - - // Bounded computations may exit early. - int const is_bounded = bound < longer_length; - if (is_bounded) { - // If one of the strings is empty - the edit distance is equal to the length of the other one. 
- if (longer_length == 0) return sz_min_of_two(shorter_length, bound); - if (shorter_length == 0) return sz_min_of_two(longer_length, bound); - // If the difference in length is beyond the `bound`, there is no need to check at all. - if (longer_length - shorter_length > bound) return bound; - } - - if (shorter_length == 0) return longer_length; // If no mismatches were found - the distance is zero. - if (shorter_length == longer_length && !is_bounded) - return _sz_edit_distance_skewed_diagonals_serial(longer, longer_length, shorter, shorter_length, bound, alloc); - return _sz_edit_distance_wagner_fisher_serial(longer, longer_length, shorter, shorter_length, bound, sz_false_k, - alloc); -} - -SZ_PUBLIC sz_ssize_t sz_alignment_score_serial( // - sz_cptr_t longer, sz_size_t longer_length, // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, // - sz_memory_allocator_t *alloc) { - - // If one of the strings is empty - the edit distance is equal to the length of the other one - if (longer_length == 0) return (sz_ssize_t)shorter_length * gap; - if (shorter_length == 0) return (sz_ssize_t)longer_length * gap; - - // Let's make sure that we use the amount proportional to the - // number of elements in the shorter string, not the larger. - if (shorter_length > longer_length) { - sz_pointer_swap((void **)&longer_length, (void **)&shorter_length); - sz_pointer_swap((void **)&longer, (void **)&shorter); - } - - // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. - sz_memory_allocator_t global_alloc; - if (!alloc) { - sz_memory_allocator_init_default(&global_alloc); - alloc = &global_alloc; - } - - sz_size_t n = shorter_length + 1; - sz_size_t buffer_length = sizeof(sz_ssize_t) * n * 2; - sz_ssize_t *distances = (sz_ssize_t *)alloc->allocate(buffer_length, alloc->handle); - sz_ssize_t *previous_distances = distances; - sz_ssize_t *current_distances = previous_distances + n; - - for (sz_size_t idx_shorter = 0; idx_shorter != n; ++idx_shorter) - previous_distances[idx_shorter] = (sz_ssize_t)idx_shorter * gap; - - sz_u8_t const *shorter_unsigned = (sz_u8_t const *)shorter; - sz_u8_t const *longer_unsigned = (sz_u8_t const *)longer; - for (sz_size_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) { - current_distances[0] = ((sz_ssize_t)idx_longer + 1) * gap; - - // Initialize min_distance with a value greater than bound - sz_error_cost_t const *a_subs = subs + longer_unsigned[idx_longer] * 256ul; - for (sz_size_t idx_shorter = 0; idx_shorter != shorter_length; ++idx_shorter) { - sz_ssize_t cost_deletion = previous_distances[idx_shorter + 1] + gap; - sz_ssize_t cost_insertion = current_distances[idx_shorter] + gap; - sz_ssize_t cost_substitution = previous_distances[idx_shorter] + a_subs[shorter_unsigned[idx_shorter]]; - current_distances[idx_shorter + 1] = sz_max_of_three(cost_deletion, cost_insertion, cost_substitution); - } - - // Swap previous_distances and current_distances pointers - sz_pointer_swap((void **)&previous_distances, (void **)¤t_distances); - } - - // Cache scalar before `free` call. 
- sz_ssize_t result = previous_distances[shorter_length]; - alloc->free(distances, buffer_length, alloc->handle); - return result; -} - -SZ_PUBLIC sz_size_t sz_hamming_distance_serial( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - - sz_size_t const min_length = sz_min_of_two(a_length, b_length); - sz_size_t const max_length = sz_max_of_two(a_length, b_length); - sz_cptr_t const a_end = a + min_length; - bound = bound == 0 ? max_length : bound; - - // Walk through both strings using SWAR and counting the number of differing characters. - sz_size_t distance = max_length - min_length; -#if SZ_USE_MISALIGNED_LOADS && !SZ_DETECT_BIG_ENDIAN - if (min_length >= SZ_SWAR_THRESHOLD) { - sz_u64_vec_t a_vec, b_vec, match_vec; - for (; a + 8 <= a_end && distance < bound; a += 8, b += 8) { - a_vec.u64 = sz_u64_load(a).u64; - b_vec.u64 = sz_u64_load(b).u64; - match_vec = _sz_u64_each_byte_equal(a_vec, b_vec); - distance += sz_u64_popcount((~match_vec.u64) & 0x8080808080808080ull); - } - } -#endif - - for (; a != a_end && distance < bound; ++a, ++b) { distance += (*a != *b); } - return sz_min_of_two(distance, bound); -} - -SZ_PUBLIC sz_size_t sz_hamming_distance_utf8_serial( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - - sz_cptr_t const a_end = a + a_length; - sz_cptr_t const b_end = b + b_length; - sz_size_t distance = 0; - - sz_rune_t a_rune, b_rune; - sz_rune_length_t a_rune_length, b_rune_length; - - if (bound) { - for (; a < a_end && b < b_end && distance < bound; a += a_rune_length, b += b_rune_length) { - _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - distance += (a_rune != b_rune); - } - // If one string has more runes, we need to go through the tail. - if (distance < bound) { - for (; a < a_end && distance < bound; a += a_rune_length, ++distance) - _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - - for (; b < b_end && distance < bound; b += b_rune_length, ++distance) - _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - } - } - else { - for (; a < a_end && b < b_end; a += a_rune_length, b += b_rune_length) { - _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - distance += (a_rune != b_rune); - } - // If one string has more runes, we need to go through the tail. - for (; a < a_end; a += a_rune_length, ++distance) _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - for (; b < b_end; b += b_rune_length, ++distance) _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - } - return distance; -} - -SZ_PUBLIC sz_u64_t sz_checksum_serial(sz_cptr_t text, sz_size_t length) { - sz_u64_t checksum = 0; - sz_u8_t const *text_u8 = (sz_u8_t const *)text; - sz_u8_t const *text_end = text_u8 + length; - for (; text_u8 != text_end; ++text_u8) checksum += *text_u8; - return checksum; -} - -/** - * @brief Largest prime number that fits into 31 bits. - * @see https://mersenneforum.org/showthread.php?t=3471 - */ -#define SZ_U32_MAX_PRIME (2147483647u) - -/** - * @brief Largest prime number that fits into 64 bits. - * @see https://mersenneforum.org/showthread.php?t=3471 - * - * 2^64 = 18,446,744,073,709,551,616 - * this = 18,446,744,073,709,551,557 - * diff = 59 - */ -#define SZ_U64_MAX_PRIME (18446744073709551557ull) - -/* - * One hardware-accelerated way of mixing hashes can be CRC, but it's only implemented for 32-bit values. 
- * Using a Boost-like mixer works very poorly in such case: - * - * hash_first ^ (hash_second + 0x517cc1b727220a95 + (hash_first << 6) + (hash_first >> 2)); - * - * Let's stick to the Fibonacci hash trick using the golden ratio. - * https://probablydance.com/2018/06/16/fibonacci-hashing-the-optimization-that-the-world-forgot-or-a-better-alternative-to-integer-modulo/ - */ -#define _sz_hash_mix(first, second) ((first * 11400714819323198485ull) ^ (second * 11400714819323198485ull)) -#define _sz_shift_low(x) (x) -#define _sz_shift_high(x) ((x + 77ull) & 0xFFull) -#define _sz_prime_mod(x) (x % SZ_U64_MAX_PRIME) - -SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t start, sz_size_t length) { - - sz_u64_t hash_low = 0; - sz_u64_t hash_high = 0; - sz_u8_t const *text = (sz_u8_t const *)start; - sz_u8_t const *text_end = text + length; - - switch (length) { - case 0: return 0; - - // Texts under 7 bytes long are definitely below the largest prime. - case 1: - hash_low = _sz_shift_low(text[0]); - hash_high = _sz_shift_high(text[0]); - break; - case 2: - hash_low = _sz_shift_low(text[0]) * 31ull + _sz_shift_low(text[1]); - hash_high = _sz_shift_high(text[0]) * 257ull + _sz_shift_high(text[1]); - break; - case 3: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull + // - _sz_shift_low(text[2]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull + // - _sz_shift_high(text[2]); - break; - case 4: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull + // - _sz_shift_low(text[3]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull + // - _sz_shift_high(text[3]); - break; - case 5: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull * 31ull + // - _sz_shift_low(text[3]) * 31ull + // - _sz_shift_low(text[4]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull * 257ull + // - _sz_shift_high(text[3]) * 257ull + // - _sz_shift_high(text[4]); - break; - case 6: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[3]) * 31ull * 31ull + // - _sz_shift_low(text[4]) * 31ull + // - _sz_shift_low(text[5]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[3]) * 257ull * 257ull + // - _sz_shift_high(text[4]) * 257ull + // - _sz_shift_high(text[5]); - break; - case 7: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[3]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[4]) * 31ull * 31ull + // - _sz_shift_low(text[5]) * 31ull + // - _sz_shift_low(text[6]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull 
* 257ull * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[3]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[4]) * 257ull * 257ull + // - _sz_shift_high(text[5]) * 257ull + // - _sz_shift_high(text[6]); - break; - default: - // Unroll the first seven cycles: - hash_low = hash_low * 31ull + _sz_shift_low(text[0]); - hash_high = hash_high * 257ull + _sz_shift_high(text[0]); - hash_low = hash_low * 31ull + _sz_shift_low(text[1]); - hash_high = hash_high * 257ull + _sz_shift_high(text[1]); - hash_low = hash_low * 31ull + _sz_shift_low(text[2]); - hash_high = hash_high * 257ull + _sz_shift_high(text[2]); - hash_low = hash_low * 31ull + _sz_shift_low(text[3]); - hash_high = hash_high * 257ull + _sz_shift_high(text[3]); - hash_low = hash_low * 31ull + _sz_shift_low(text[4]); - hash_high = hash_high * 257ull + _sz_shift_high(text[4]); - hash_low = hash_low * 31ull + _sz_shift_low(text[5]); - hash_high = hash_high * 257ull + _sz_shift_high(text[5]); - hash_low = hash_low * 31ull + _sz_shift_low(text[6]); - hash_high = hash_high * 257ull + _sz_shift_high(text[6]); - text += 7; - - // Iterate throw the rest with the modulus: - for (; text != text_end; ++text) { - hash_low = hash_low * 31ull + _sz_shift_low(text[0]); - hash_high = hash_high * 257ull + _sz_shift_high(text[0]); - // Wrap the hashes around: - hash_low = _sz_prime_mod(hash_low); - hash_high = _sz_prime_mod(hash_high); - } - break; - } - - return _sz_hash_mix(hash_low, hash_high); -} - -SZ_PUBLIC void sz_hashes_serial(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle) { - - if (length < window_length || !window_length) return; - sz_u8_t const *text = (sz_u8_t const *)start; - sz_u8_t const *text_end = text + length; - - // Prepare the `prime ^ window_length` values, that we are going to use for modulo arithmetic. - sz_u64_t prime_power_low = 1, prime_power_high = 1; - for (sz_size_t i = 0; i + 1 < window_length; ++i) - prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME, - prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME; - - // Compute the initial hash value for the first window. - sz_u64_t hash_low = 0, hash_high = 0, hash_mix; - for (sz_u8_t const *first_end = text + window_length; text < first_end; ++text) - hash_low = (hash_low * 31ull + _sz_shift_low(*text)) % SZ_U64_MAX_PRIME, - hash_high = (hash_high * 257ull + _sz_shift_high(*text)) % SZ_U64_MAX_PRIME; - - // In most cases the fingerprint length will be a power of two. - hash_mix = _sz_hash_mix(hash_low, hash_high); - callback((sz_cptr_t)text, window_length, hash_mix, callback_handle); - - // Compute the hash value for every window, exporting into the fingerprint, - // using the expensive modulo operation. - sz_size_t cycles = 1; - sz_size_t const step_mask = step - 1; - for (; text < text_end; ++text, ++cycles) { - // Discard one character: - hash_low -= _sz_shift_low(*(text - window_length)) * prime_power_low; - hash_high -= _sz_shift_high(*(text - window_length)) * prime_power_high; - // And add a new one: - hash_low = 31ull * hash_low + _sz_shift_low(*text); - hash_high = 257ull * hash_high + _sz_shift_high(*text); - // Wrap the hashes around: - hash_low = _sz_prime_mod(hash_low); - hash_high = _sz_prime_mod(hash_high); - // Mix only if we've skipped enough hashes. 
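        // Note: `step_mask = step - 1` emulates `cycles % step == 0` only under the assumption that
        // `step` is a power of two; for other step values the callback cadence would differ.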
- if ((cycles & step_mask) == 0) { - hash_mix = _sz_hash_mix(hash_low, hash_high); - callback((sz_cptr_t)text, window_length, hash_mix, callback_handle); - } - } -} - -#undef _sz_shift_low -#undef _sz_shift_high -#undef _sz_hash_mix -#undef _sz_prime_mod - -/** - * @brief Uses a small lookup-table to convert a lowercase character to uppercase. - */ -SZ_INTERNAL sz_u8_t sz_u8_tolower(sz_u8_t c) { - static sz_u8_t const lowered[256] = { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, // - 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, // - 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, // - 64, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, // - 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 91, 92, 93, 94, 95, // - 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, // - 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, // - 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, // - 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, // - 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, // - 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, // - 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // - 240, 241, 242, 243, 244, 245, 246, 215, 248, 249, 250, 251, 252, 253, 254, 223, // - 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // - 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, // - }; - return lowered[c]; -} - -/** - * @brief Uses a small lookup-table to convert an uppercase character to lowercase. - */ -SZ_INTERNAL sz_u8_t sz_u8_toupper(sz_u8_t c) { - static sz_u8_t const upped[256] = { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, // - 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, // - 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, // - 64, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, // - 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 91, 92, 93, 94, 95, // - 96, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, // - 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 123, 124, 125, 126, 127, // - 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, // - 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, // - 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, // - 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, // - 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // - 240, 241, 242, 243, 244, 245, 246, 215, 248, 249, 250, 251, 252, 253, 254, 223, // - 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // - 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, // - }; - return upped[c]; -} - -/** - * @brief Uses two small lookup tables (768 bytes total) to accelerate division by a small - * unsigned integer. Performs two lookups, one multiplication, two shifts, and two accumulations. - * - * @param divisor Integral value @b larger than one. - * @param number Integral value to divide. 
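 * Worked example (for illustration): dividing 200 by 3 uses multipliers[3] = 21846 and shifts[3] = 1,
 * so q = (21846 * 200) >> 16 = 66, t = ((200 - 66) >> 1) + 66 = 133, and 133 >> 1 = 66 = 200 / 3.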
- */ -SZ_INTERNAL sz_u8_t sz_u8_divide(sz_u8_t number, sz_u8_t divisor) { - sz_assert(divisor > 1); - static sz_u16_t const multipliers[256] = { - 0, 0, 0, 21846, 0, 39322, 21846, 9363, 0, 50973, 39322, 29790, 21846, 15124, 9363, 4370, - 0, 57826, 50973, 44841, 39322, 34329, 29790, 25645, 21846, 18351, 15124, 12137, 9363, 6780, 4370, 2115, - 0, 61565, 57826, 54302, 50973, 47824, 44841, 42011, 39322, 36765, 34329, 32006, 29790, 27671, 25645, 23705, - 21846, 20063, 18351, 16706, 15124, 13602, 12137, 10725, 9363, 8049, 6780, 5554, 4370, 3224, 2115, 1041, - 0, 63520, 61565, 59668, 57826, 56039, 54302, 52614, 50973, 49377, 47824, 46313, 44841, 43407, 42011, 40649, - 39322, 38028, 36765, 35532, 34329, 33154, 32006, 30885, 29790, 28719, 27671, 26647, 25645, 24665, 23705, 22766, - 21846, 20945, 20063, 19198, 18351, 17520, 16706, 15907, 15124, 14356, 13602, 12863, 12137, 11424, 10725, 10038, - 9363, 8700, 8049, 7409, 6780, 6162, 5554, 4957, 4370, 3792, 3224, 2665, 2115, 1573, 1041, 517, - 0, 64520, 63520, 62535, 61565, 60609, 59668, 58740, 57826, 56926, 56039, 55164, 54302, 53452, 52614, 51788, - 50973, 50169, 49377, 48595, 47824, 47063, 46313, 45572, 44841, 44120, 43407, 42705, 42011, 41326, 40649, 39982, - 39322, 38671, 38028, 37392, 36765, 36145, 35532, 34927, 34329, 33738, 33154, 32577, 32006, 31443, 30885, 30334, - 29790, 29251, 28719, 28192, 27671, 27156, 26647, 26143, 25645, 25152, 24665, 24182, 23705, 23233, 22766, 22303, - 21846, 21393, 20945, 20502, 20063, 19628, 19198, 18772, 18351, 17933, 17520, 17111, 16706, 16305, 15907, 15514, - 15124, 14738, 14356, 13977, 13602, 13231, 12863, 12498, 12137, 11779, 11424, 11073, 10725, 10380, 10038, 9699, - 9363, 9030, 8700, 8373, 8049, 7727, 7409, 7093, 6780, 6470, 6162, 5857, 5554, 5254, 4957, 4662, - 4370, 4080, 3792, 3507, 3224, 2943, 2665, 2388, 2115, 1843, 1573, 1306, 1041, 778, 517, 258, - }; - // This table can be avoided using a single addition and counting trailing zeros. 
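    // For a divisor `d`, the shift below appears to equal ceil(log2(d)) - 1 (e.g. 1 for 3..4, 2 for 5..8).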
- static sz_u8_t const shifts[256] = { - 0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, // - 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, // - 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, // - 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, // - 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // - 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // - 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // - 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // - }; - sz_u32_t multiplier = multipliers[divisor]; - sz_u8_t shift = shifts[divisor]; - - sz_u16_t q = (sz_u16_t)((multiplier * number) >> 16); - sz_u16_t t = ((number - q) >> 1) + q; - return (sz_u8_t)(t >> shift); -} - -SZ_PUBLIC void sz_look_up_transform_serial(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result) { - sz_u8_t const *unsigned_lut = (sz_u8_t const *)lut; - sz_u8_t const *unsigned_text = (sz_u8_t const *)text; - sz_u8_t *unsigned_result = (sz_u8_t *)result; - sz_u8_t const *end = unsigned_text + length; - for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = unsigned_lut[*unsigned_text]; -} - -SZ_PUBLIC void sz_tolower_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) { - sz_u8_t *unsigned_result = (sz_u8_t *)result; - sz_u8_t const *unsigned_text = (sz_u8_t const *)text; - sz_u8_t const *end = unsigned_text + length; - for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = sz_u8_tolower(*unsigned_text); -} - -SZ_PUBLIC void sz_toupper_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) { - sz_u8_t *unsigned_result = (sz_u8_t *)result; - sz_u8_t const *unsigned_text = (sz_u8_t const *)text; - sz_u8_t const *end = unsigned_text + length; - for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = sz_u8_toupper(*unsigned_text); -} - -SZ_PUBLIC void sz_toascii_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) { - sz_u8_t *unsigned_result = (sz_u8_t *)result; - sz_u8_t const *unsigned_text = (sz_u8_t const *)text; - sz_u8_t const *end = unsigned_text + length; - for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = *unsigned_text & 0x7F; -} - -/** - * @brief Check if there is a byte in this buffer, that exceeds 127 and can't be an ASCII character. - * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. - */ -SZ_PUBLIC sz_bool_t sz_isascii_serial(sz_cptr_t text, sz_size_t length) { - - if (!length) return sz_true_k; - sz_u8_t const *h = (sz_u8_t const *)text; - sz_u8_t const *const h_end = h + length; - -#if !SZ_USE_MISALIGNED_LOADS - // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)h & 7ull) && h < h_end; ++h) - if (*h & 0x80ull) return sz_false_k; -#endif - - // Validate eight bytes at once using SWAR. - sz_u64_vec_t text_vec; - for (; h + 8 <= h_end; h += 8) { - text_vec.u64 = *(sz_u64_t const *)h; - if (text_vec.u64 & 0x8080808080808080ull) return sz_false_k; - } - - // Handle the misaligned tail. 
- for (; h < h_end; ++h) - if (*h & 0x80ull) return sz_false_k; - return sz_true_k; -} - -SZ_PUBLIC void sz_generate_serial(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length, - sz_random_generator_t generator, void *generator_user_data) { - - sz_assert(alphabet_size > 0 && alphabet_size <= 256 && "Inadequate alphabet size"); - - if (alphabet_size == 1) sz_fill(result, result_length, *alphabet); - - else { - sz_assert(generator && "Expects a valid random generator"); - sz_u8_t divisor = (sz_u8_t)alphabet_size; - for (sz_cptr_t end = result + result_length; result != end; ++result) { - sz_u8_t random = generator(generator_user_data) & 0xFF; - sz_u8_t quotient = sz_u8_divide(random, divisor); - *result = alphabet[random - quotient * divisor]; - } - } -} - -#pragma endregion - -/* - * Serial implementation of string class operations. - */ -#pragma region Serial Implementation for the String Class - -SZ_PUBLIC sz_bool_t sz_string_is_on_stack(sz_string_t const *string) { - // It doesn't matter if it's on stack or heap, the pointer location is the same. - return (sz_bool_t)((sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0]); -} - -SZ_PUBLIC void sz_string_range(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length) { - sz_size_t is_small = (sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0]; - sz_size_t is_big_mask = is_small - 1ull; - *start = string->external.start; // It doesn't matter if it's on stack or heap, the pointer location is the same. - // If the string is small, use branch-less approach to mask-out the top 7 bytes of the length. - *length = string->external.length & (0x00000000000000FFull | is_big_mask); -} - -SZ_PUBLIC void sz_string_unpack(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length, sz_size_t *space, - sz_bool_t *is_external) { - sz_size_t is_small = (sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0]; - sz_size_t is_big_mask = is_small - 1ull; - *start = string->external.start; // It doesn't matter if it's on stack or heap, the pointer location is the same. - // If the string is small, use branch-less approach to mask-out the top 7 bytes of the length. - *length = string->external.length & (0x00000000000000FFull | is_big_mask); - // In case the string is small, the `is_small - 1ull` will become 0xFFFFFFFFFFFFFFFFull. - *space = sz_u64_blend(SZ_STRING_INTERNAL_SPACE, string->external.space, is_big_mask); - *is_external = (sz_bool_t)!is_small; -} - -SZ_PUBLIC sz_bool_t sz_string_equal(sz_string_t const *a, sz_string_t const *b) { - // Tempting to say that the external.length is bitwise the same even if it includes - // some bytes of the on-stack payload, but we don't at this writing maintain that invariant. - // (An on-stack string includes noise bytes in the high-order bits of external.length. So do this - // the hard/correct way. - -#if SZ_USE_MISALIGNED_LOADS - // Dealing with StringZilla strings, we know that the `start` pointer always points - // to a word at least 8 bytes long. Therefore, we can compare the first 8 bytes at once. - -#endif - // Alternatively, fall back to byte-by-byte comparison. 
- sz_ptr_t a_start, b_start; - sz_size_t a_length, b_length; - sz_string_range(a, &a_start, &a_length); - sz_string_range(b, &b_start, &b_length); - return (sz_bool_t)(a_length == b_length && sz_equal(a_start, b_start, b_length)); -} - -SZ_PUBLIC sz_ordering_t sz_string_order(sz_string_t const *a, sz_string_t const *b) { -#if SZ_USE_MISALIGNED_LOADS - // Dealing with StringZilla strings, we know that the `start` pointer always points - // to a word at least 8 bytes long. Therefore, we can compare the first 8 bytes at once. - -#endif - // Alternatively, fall back to byte-by-byte comparison. - sz_ptr_t a_start, b_start; - sz_size_t a_length, b_length; - sz_string_range(a, &a_start, &a_length); - sz_string_range(b, &b_start, &b_length); - return sz_order(a_start, a_length, b_start, b_length); -} - -SZ_PUBLIC void sz_string_init(sz_string_t *string) { - sz_assert(string && "String can't be SZ_NULL."); - - // Only 8 + 1 + 1 need to be initialized. - string->internal.start = &string->internal.chars[0]; - // But for safety let's initialize the entire structure to zeros. - // string->internal.chars[0] = 0; - // string->internal.length = 0; - string->words[1] = 0; - string->words[2] = 0; - string->words[3] = 0; -} - -SZ_PUBLIC sz_ptr_t sz_string_init_length(sz_string_t *string, sz_size_t length, sz_memory_allocator_t *allocator) { - sz_size_t space_needed = length + 1; // space for trailing \0 - sz_assert(string && allocator && "String and allocator can't be SZ_NULL."); - // Initialize the string to zeros for safety. - string->words[1] = 0; - string->words[2] = 0; - string->words[3] = 0; - // If we are lucky, no memory allocations will be needed. - if (space_needed <= SZ_STRING_INTERNAL_SPACE) { - string->internal.start = &string->internal.chars[0]; - string->internal.length = (sz_u8_t)length; - } - else { - // If we are not lucky, we need to allocate memory. - string->external.start = (sz_ptr_t)allocator->allocate(space_needed, allocator->handle); - if (!string->external.start) return SZ_NULL_CHAR; - string->external.length = length; - string->external.space = space_needed; - } - sz_assert(&string->internal.start == &string->external.start && "Alignment confusion"); - string->external.start[length] = 0; - return string->external.start; -} - -SZ_PUBLIC sz_ptr_t sz_string_reserve(sz_string_t *string, sz_size_t new_capacity, sz_memory_allocator_t *allocator) { - - sz_assert(string && allocator && "Strings and allocators can't be SZ_NULL."); - - sz_size_t new_space = new_capacity + 1; - if (new_space <= SZ_STRING_INTERNAL_SPACE) return string->external.start; - - sz_ptr_t string_start; - sz_size_t string_length; - sz_size_t string_space; - sz_bool_t string_is_external; - sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - sz_assert(new_space > string_space && "New space must be larger than current."); - - sz_ptr_t new_start = (sz_ptr_t)allocator->allocate(new_space, allocator->handle); - if (!new_start) return SZ_NULL_CHAR; - - sz_copy(new_start, string_start, string_length); - string->external.start = new_start; - string->external.space = new_space; - string->external.padding = 0; - string->external.length = string_length; - - // Deallocate the old string. 
- if (string_is_external) allocator->free(string_start, string_space, allocator->handle); - return string->external.start; -} - -SZ_PUBLIC sz_ptr_t sz_string_shrink_to_fit(sz_string_t *string, sz_memory_allocator_t *allocator) { - - sz_assert(string && allocator && "Strings and allocators can't be SZ_NULL."); - - sz_ptr_t string_start; - sz_size_t string_length; - sz_size_t string_space; - sz_bool_t string_is_external; - sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - - // We may already be space-optimal, and in that case we don't need to do anything. - sz_size_t new_space = string_length + 1; - if (string_space == new_space || !string_is_external) return string->external.start; - - sz_ptr_t new_start = (sz_ptr_t)allocator->allocate(new_space, allocator->handle); - if (!new_start) return SZ_NULL_CHAR; - - sz_copy(new_start, string_start, string_length); - string->external.start = new_start; - string->external.space = new_space; - string->external.padding = 0; - string->external.length = string_length; - - // Deallocate the old string. - if (string_is_external) allocator->free(string_start, string_space, allocator->handle); - return string->external.start; -} - -SZ_PUBLIC sz_ptr_t sz_string_expand(sz_string_t *string, sz_size_t offset, sz_size_t added_length, - sz_memory_allocator_t *allocator) { - - sz_assert(string && allocator && "String and allocator can't be SZ_NULL."); - - sz_ptr_t string_start; - sz_size_t string_length; - sz_size_t string_space; - sz_bool_t string_is_external; - sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - - // The user intended to extend the string. - offset = sz_min_of_two(offset, string_length); - - // If we are lucky, no memory allocations will be needed. - if (string_length + added_length < string_space) { - sz_move(string_start + offset + added_length, string_start + offset, string_length - offset); - string_start[string_length + added_length] = 0; - // Even if the string is on the stack, the `+=` won't affect the tail of the string. - string->external.length += added_length; - } - // If we are not lucky, we need to allocate more memory. - else { - sz_size_t next_planned_size = sz_max_of_two(SZ_CACHE_LINE_WIDTH, string_space * 2ull); - sz_size_t min_needed_space = sz_size_bit_ceil(offset + string_length + added_length + 1); - sz_size_t new_space = sz_max_of_two(min_needed_space, next_planned_size); - string_start = sz_string_reserve(string, new_space - 1, allocator); - if (!string_start) return SZ_NULL_CHAR; - - // Copy into the new buffer. - sz_move(string_start + offset + added_length, string_start + offset, string_length - offset); - string_start[string_length + added_length] = 0; - string->external.length = string_length + added_length; - } - - return string_start; -} - -SZ_PUBLIC sz_size_t sz_string_erase(sz_string_t *string, sz_size_t offset, sz_size_t length) { - - sz_assert(string && "String can't be SZ_NULL."); - - sz_ptr_t string_start; - sz_size_t string_length; - sz_size_t string_space; - sz_bool_t string_is_external; - sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - - // Normalize the offset, it can't be larger than the length. - offset = sz_min_of_two(offset, string_length); - - // We shouldn't normalize the length, to avoid overflowing on `offset + length >= string_length`, - // if receiving `length == SZ_SIZE_MAX`. 
After following expression the `length` will contain - // exactly the delta between original and final length of this `string`. - length = sz_min_of_two(length, string_length - offset); - - // There are 2 common cases, that wouldn't even require a `memmove`: - // 1. Erasing the entire contents of the string. - // In that case `length` argument will be equal or greater than `length` member. - // 2. Removing the tail of the string with something like `string.pop_back()` in C++. - // - // In both of those, regardless of the location of the string - stack or heap, - // the erasing is as easy as setting the length to the offset. - // In every other case, we must `memmove` the tail of the string to the left. - if (offset + length < string_length) - sz_move(string_start + offset, string_start + offset + length, string_length - offset - length); - - // The `string->external.length = offset` assignment would discard last characters - // of the on-the-stack string, but inplace subtraction would work. - string->external.length -= length; - string_start[string_length - length] = 0; - return length; -} - -SZ_PUBLIC void sz_string_free(sz_string_t *string, sz_memory_allocator_t *allocator) { - if (!sz_string_is_on_stack(string)) - allocator->free(string->external.start, string->external.space, allocator->handle); - sz_string_init(string); -} - -// When overriding libc, disable optimisations for this function beacuse MSVC will optimize the loops into a memset. -// Which then causes a stack overflow due to infinite recursion (memset -> sz_fill_serial -> memset). -#if defined(_MSC_VER) && defined(SZ_OVERRIDE_LIBC) && SZ_OVERRIDE_LIBC -#pragma optimize("", off) -#endif -SZ_PUBLIC void sz_fill_serial(sz_ptr_t target, sz_size_t length, sz_u8_t value) { - sz_ptr_t end = target + length; - // Dealing with short strings, a single sequential pass would be faster. - // If the size is larger than 2 words, then at least 1 of them will be aligned. - // But just one aligned word may not be worth SWAR. - if (length < SZ_SWAR_THRESHOLD) - while (target != end) *(target++) = value; - - // In case of long strings, skip unaligned bytes, and then fill the rest in 64-bit chunks. - else { - sz_u64_t value64 = (sz_u64_t)value * 0x0101010101010101ull; - while ((sz_size_t)target & 7ull) *(target++) = value; - while (target + 8 <= end) *(sz_u64_t *)target = value64, target += 8; - while (target != end) *(target++) = value; - } -} -#if defined(_MSC_VER) && defined(SZ_OVERRIDE_LIBC) && SZ_OVERRIDE_LIBC -#pragma optimize("", on) -#endif - -SZ_PUBLIC void sz_copy_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { -#if SZ_USE_MISALIGNED_LOADS - while (length >= 8) *(sz_u64_t *)target = *(sz_u64_t const *)source, target += 8, source += 8, length -= 8; -#endif - while (length--) *(target++) = *(source++); -} - -SZ_PUBLIC void sz_move_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - // Implementing `memmove` is trickier, than `memcpy`, as the ranges may overlap. - // Existing implementations often have two passes, in normal and reversed order, - // depending on the relation of `target` and `source` addresses. - // https://student.cs.uwaterloo.ca/~cs350/common/os161-src-html/doxygen/html/memmove_8c_source.html - // https://marmota.medium.com/c-language-making-memmove-def8792bb8d5 - // - // We can use the `memcpy` like left-to-right pass if we know that the `target` is before `source`. - // Or if we know that they don't intersect! 
In that case the traversal order is irrelevant, - // but older CPUs may predict and fetch forward-passes better. - if (target < source || target >= source + length) { -#if SZ_USE_MISALIGNED_LOADS - while (length >= 8) *(sz_u64_t *)target = *(sz_u64_t const *)(source), target += 8, source += 8, length -= 8; -#endif - while (length--) *(target++) = *(source++); - } - else { - // Jump to the end and walk backwards. - target += length, source += length; -#if SZ_USE_MISALIGNED_LOADS - while (length >= 8) *(sz_u64_t *)(target -= 8) = *(sz_u64_t const *)(source -= 8), length -= 8; -#endif - while (length--) *(--target) = *(--source); - } -} - -#pragma endregion - -/* - * @brief Serial implementation for strings sequence processing. - */ -#pragma region Serial Implementation for Sequences - -SZ_PUBLIC sz_size_t sz_partition(sz_sequence_t *sequence, sz_sequence_predicate_t predicate) { - - sz_size_t matches = 0; - while (matches != sequence->count && predicate(sequence, sequence->order[matches])) ++matches; - - for (sz_size_t i = matches + 1; i < sequence->count; ++i) - if (predicate(sequence, sequence->order[i])) - sz_u64_swap(sequence->order + i, sequence->order + matches), ++matches; - - return matches; -} - -SZ_PUBLIC void sz_merge(sz_sequence_t *sequence, sz_size_t partition, sz_sequence_comparator_t less) { - - sz_size_t start_b = partition + 1; - - // If the direct merge is already sorted - if (!less(sequence, sequence->order[start_b], sequence->order[partition])) return; - - sz_size_t start_a = 0; - while (start_a <= partition && start_b <= sequence->count) { - - // If element 1 is in right place - if (!less(sequence, sequence->order[start_b], sequence->order[start_a])) { start_a++; } - else { - sz_size_t value = sequence->order[start_b]; - sz_size_t index = start_b; - - // Shift all the elements between element 1 - // element 2, right by 1. 
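            // In other words, order[start_a .. index - 1] moves one slot to the right, and the smaller
            // `value` then lands at order[start_a], keeping the already-merged prefix sorted.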
- while (index != start_a) { sequence->order[index] = sequence->order[index - 1], index--; } - sequence->order[start_a] = value; - - // Update all the pointers - start_a++; - partition++; - start_b++; - } - } -} - -SZ_PUBLIC void sz_sort_insertion(sz_sequence_t *sequence, sz_sequence_comparator_t less) { - sz_u64_t *keys = sequence->order; - sz_size_t keys_count = sequence->count; - for (sz_size_t i = 1; i < keys_count; i++) { - sz_u64_t i_key = keys[i]; - sz_size_t j = i; - for (; j > 0 && less(sequence, i_key, keys[j - 1]); --j) keys[j] = keys[j - 1]; - keys[j] = i_key; - } -} - -SZ_INTERNAL void _sz_sift_down(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_u64_t *order, sz_size_t start, - sz_size_t end) { - sz_size_t root = start; - while (2 * root + 1 <= end) { - sz_size_t child = 2 * root + 1; - if (child + 1 <= end && less(sequence, order[child], order[child + 1])) { child++; } - if (!less(sequence, order[root], order[child])) { return; } - sz_u64_swap(order + root, order + child); - root = child; - } -} - -SZ_INTERNAL void _sz_heapify(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_u64_t *order, sz_size_t count) { - sz_size_t start = (count - 2) / 2; - while (1) { - _sz_sift_down(sequence, less, order, start, count - 1); - if (start == 0) return; - start--; - } -} - -SZ_INTERNAL void _sz_heapsort(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_size_t first, sz_size_t last) { - sz_u64_t *order = sequence->order; - sz_size_t count = last - first; - _sz_heapify(sequence, less, order + first, count); - sz_size_t end = count - 1; - while (end > 0) { - sz_u64_swap(order + first, order + first + end); - end--; - _sz_sift_down(sequence, less, order + first, 0, end); - } -} - -SZ_PUBLIC void sz_sort_introsort_recursion(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_size_t first, - sz_size_t last, sz_size_t depth) { - - sz_size_t length = last - first; - switch (length) { - case 0: - case 1: return; - case 2: - if (less(sequence, sequence->order[first + 1], sequence->order[first])) - sz_u64_swap(&sequence->order[first], &sequence->order[first + 1]); - return; - case 3: { - sz_u64_t a = sequence->order[first]; - sz_u64_t b = sequence->order[first + 1]; - sz_u64_t c = sequence->order[first + 2]; - if (less(sequence, b, a)) sz_u64_swap(&a, &b); - if (less(sequence, c, b)) sz_u64_swap(&c, &b); - if (less(sequence, b, a)) sz_u64_swap(&a, &b); - sequence->order[first] = a; - sequence->order[first + 1] = b; - sequence->order[first + 2] = c; - return; - } - } - // Until a certain length, the quadratic-complexity insertion-sort is fine - if (length <= 16) { - sz_sequence_t sub_seq = *sequence; - sub_seq.order += first; - sub_seq.count = length; - sz_sort_insertion(&sub_seq, less); - return; - } - - // Fallback to N-logN-complexity heap-sort - if (depth == 0) { - _sz_heapsort(sequence, less, first, last); - return; - } - - --depth; - - // Median-of-three logic to choose pivot - sz_size_t median = first + length / 2; - if (less(sequence, sequence->order[median], sequence->order[first])) - sz_u64_swap(&sequence->order[first], &sequence->order[median]); - if (less(sequence, sequence->order[last - 1], sequence->order[first])) - sz_u64_swap(&sequence->order[first], &sequence->order[last - 1]); - if (less(sequence, sequence->order[median], sequence->order[last - 1])) - sz_u64_swap(&sequence->order[median], &sequence->order[last - 1]); - - // Partition using the median-of-three as the pivot - sz_u64_t pivot = sequence->order[median]; - sz_size_t left = first; - 
sz_size_t right = last - 1; - while (1) { - while (less(sequence, sequence->order[left], pivot)) left++; - while (less(sequence, pivot, sequence->order[right])) right--; - if (left >= right) break; - sz_u64_swap(&sequence->order[left], &sequence->order[right]); - left++; - right--; - } - - // Recursively sort the partitions - sz_sort_introsort_recursion(sequence, less, first, left, depth); - sz_sort_introsort_recursion(sequence, less, right + 1, last, depth); -} - -SZ_PUBLIC void sz_sort_introsort(sz_sequence_t *sequence, sz_sequence_comparator_t less) { - if (sequence->count == 0) return; - sz_size_t size_is_not_power_of_two = (sequence->count & (sequence->count - 1)) != 0; - sz_size_t depth_limit = sz_size_log2i_nonzero(sequence->count) + size_is_not_power_of_two; - sz_sort_introsort_recursion(sequence, less, 0, sequence->count, depth_limit); -} - -SZ_PUBLIC void sz_sort_recursion( // - sz_sequence_t *sequence, sz_size_t bit_idx, sz_size_t bit_max, sz_sequence_comparator_t comparator, - sz_size_t partial_order_length) { - - if (!sequence->count) return; - - // Array of size one doesn't need sorting - only needs the prefix to be discarded. - if (sequence->count == 1) { - sz_u32_t *order_half_words = (sz_u32_t *)sequence->order; - order_half_words[1] = 0; - return; - } - - // Partition a range of integers according to a specific bit value - sz_size_t split = 0; - sz_u64_t mask = (1ull << 63) >> bit_idx; - - // The clean approach would be to perform a single pass over the sequence. - // - // while (split != sequence->count && !(sequence->order[split] & mask)) ++split; - // for (sz_size_t i = split + 1; i < sequence->count; ++i) - // if (!(sequence->order[i] & mask)) sz_u64_swap(sequence->order + i, sequence->order + split), ++split; - // - // This, however, doesn't take into account the high relative cost of writes and swaps. - // To circumvent that, we can first count the total number entries to be mapped into either part. - // And then walk through both parts, swapping the entries that are in the wrong part. - // This would often lead to ~15% performance gain. - sz_size_t count_with_bit_set = 0; - for (sz_size_t i = 0; i != sequence->count; ++i) count_with_bit_set += (sequence->order[i] & mask) != 0; - split = sequence->count - count_with_bit_set; - - // It's possible that the sequence is already partitioned. - if (split != 0 && split != sequence->count) { - // Use two pointers to efficiently reposition elements. - // On pointer walks left-to-right from the start, and the other walks right-to-left from the end. - sz_size_t left = 0; - sz_size_t right = sequence->count - 1; - while (1) { - // Find the next element with the bit set on the left side. - while (left < split && !(sequence->order[left] & mask)) ++left; - // Find the next element without the bit set on the right side. - while (right >= split && (sequence->order[right] & mask)) --right; - // Swap the mispositioned elements. - if (left < split && right >= split) { - sz_u64_swap(sequence->order + left, sequence->order + right); - ++left; - --right; - } - else { break; } - } - } - - // Go down recursively. - if (bit_idx < bit_max) { - sz_sequence_t a = *sequence; - a.count = split; - sz_sort_recursion(&a, bit_idx + 1, bit_max, comparator, partial_order_length); - - sz_sequence_t b = *sequence; - b.order += split; - b.count -= split; - sz_sort_recursion(&b, bit_idx + 1, bit_max, comparator, partial_order_length); - } - // Reached the end of recursion. - else { - // Discard the prefixes. 
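        // On a little-endian machine, the high 32 bits of each `order[i]` held the exported 4-byte
        // string prefix during the radix passes; zeroing them leaves only the original index before
        // handing both halves to the comparison-based introsort below.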
- sz_u32_t *order_half_words = (sz_u32_t *)sequence->order; - for (sz_size_t i = 0; i != sequence->count; ++i) { order_half_words[i * 2 + 1] = 0; } - - sz_sequence_t a = *sequence; - a.count = split; - sz_sort_introsort(&a, comparator); - - sz_sequence_t b = *sequence; - b.order += split; - b.count -= split; - sz_sort_introsort(&b, comparator); - } -} - -SZ_INTERNAL sz_bool_t _sz_sort_is_less(sz_sequence_t *sequence, sz_size_t i_key, sz_size_t j_key) { - sz_cptr_t i_str = sequence->get_start(sequence, i_key); - sz_cptr_t j_str = sequence->get_start(sequence, j_key); - sz_size_t i_len = sequence->get_length(sequence, i_key); - sz_size_t j_len = sequence->get_length(sequence, j_key); - return (sz_bool_t)(sz_order_serial(i_str, i_len, j_str, j_len) == sz_less_k); -} - -SZ_PUBLIC void sz_sort_partial(sz_sequence_t *sequence, sz_size_t partial_order_length) { - -#if SZ_DETECT_BIG_ENDIAN - // TODO: Implement partial sort for big-endian systems. For now this sorts the whole thing. - sz_unused(partial_order_length); - sz_sort_introsort(sequence, (sz_sequence_comparator_t)_sz_sort_is_less); -#else - - // Export up to 4 bytes into the `sequence` bits themselves - for (sz_size_t i = 0; i != sequence->count; ++i) { - sz_cptr_t begin = sequence->get_start(sequence, sequence->order[i]); - sz_size_t length = sequence->get_length(sequence, sequence->order[i]); - length = length > 4u ? 4u : length; - sz_ptr_t prefix = (sz_ptr_t)&sequence->order[i]; - for (sz_size_t j = 0; j != length; ++j) prefix[7 - j] = begin[j]; - } - - // Perform optionally-parallel radix sort on them - sz_sort_recursion(sequence, 0, 32, (sz_sequence_comparator_t)_sz_sort_is_less, partial_order_length); -#endif -} - -SZ_PUBLIC void sz_sort(sz_sequence_t *sequence) { -#if SZ_DETECT_BIG_ENDIAN - sz_sort_introsort(sequence, (sz_sequence_comparator_t)_sz_sort_is_less); -#else - sz_sort_partial(sequence, sequence->count); -#endif -} - -#pragma endregion - -/* - * @brief AVX2 implementation of the string search algorithms. - * Very minimalistic, but still faster than the serial implementation. - */ -#pragma region AVX2 Implementation - -#if SZ_USE_X86_AVX2 -#pragma GCC push_options -#pragma GCC target("avx2") -#pragma clang attribute push(__attribute__((target("avx2"))), apply_to = function) -#include - -/** - * @brief Helper structure to simplify work with 256-bit registers. - */ -typedef union sz_u256_vec_t { - __m256i ymm; - __m128i xmms[2]; - sz_u64_t u64s[4]; - sz_u32_t u32s[8]; - sz_u16_t u16s[16]; - sz_u8_t u8s[32]; -} sz_u256_vec_t; - -SZ_PUBLIC sz_ordering_t sz_order_avx2(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { - //! Before optimizing this, read the "Operations Not Worth Optimizing" in Contributions Guide: - //! https://github.com/ashvardanian/StringZilla/blob/main/CONTRIBUTING.md#general-performance-observations - return sz_order_serial(a, a_length, b, b_length); -} - -SZ_PUBLIC sz_bool_t sz_equal_avx2(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - sz_u256_vec_t a_vec, b_vec; - - while (length >= 32) { - a_vec.ymm = _mm256_lddqu_si256((__m256i const *)a); - b_vec.ymm = _mm256_lddqu_si256((__m256i const *)b); - // One approach can be to use "movemasks", but we could also use a bitwise matching like `_mm256_testnzc_si256`. 
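        // A possible shape of that alternative, shown for illustration only (not what this build uses):
        //
        //     __m256i xored = _mm256_xor_si256(a_vec.ymm, b_vec.ymm);
        //     if (!_mm256_testz_si256(xored, xored)) return sz_false_k; // any set bit means a mismatch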
- int difference_mask = ~_mm256_movemask_epi8(_mm256_cmpeq_epi8(a_vec.ymm, b_vec.ymm)); - if (difference_mask == 0) { a += 32, b += 32, length -= 32; } - else { return sz_false_k; } - } - - if (length) return sz_equal_serial(a, b, length); - return sz_true_k; -} - -SZ_PUBLIC void sz_fill_avx2(sz_ptr_t target, sz_size_t length, sz_u8_t value) { - char value_char = *(char *)&value; - __m256i value_vec = _mm256_set1_epi8(value_char); - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "stores". - // - // for (; length >= 32; target += 32, length -= 32) _mm256_storeu_si256(target, value_vec); - // sz_fill_serial(target, length, value); - // - // When the buffer is small, there isn't much to innovate. - if (length <= 32) sz_fill_serial(target, length, value); - // When the buffer is aligned, we can avoid any split-stores. - else { - sz_size_t head_length = (32 - ((sz_size_t)target % 32)) % 32; // 31 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 32; // 31 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 32. - sz_u16_t value16 = (sz_u16_t)value * 0x0101u; - sz_u32_t value32 = (sz_u32_t)value16 * 0x00010001u; - sz_u64_t value64 = (sz_u64_t)value32 * 0x0000000100000001ull; - - // Fill the head of the buffer. This part is much cleaner with AVX-512. - if (head_length & 1) *(sz_u8_t *)target = value, target++, head_length--; - if (head_length & 2) *(sz_u16_t *)target = value16, target += 2, head_length -= 2; - if (head_length & 4) *(sz_u32_t *)target = value32, target += 4, head_length -= 4; - if (head_length & 8) *(sz_u64_t *)target = value64, target += 8, head_length -= 8; - if (head_length & 16) - _mm_store_si128((__m128i *)target, _mm_set1_epi8(value_char)), target += 16, head_length -= 16; - sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); - - // Fill the aligned body of the buffer. - for (; body_length >= 32; target += 32, body_length -= 32) _mm256_store_si256((__m256i *)target, value_vec); - - // Fill the tail of the buffer. This part is much cleaner with AVX-512. - sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); - if (tail_length & 16) - _mm_store_si128((__m128i *)target, _mm_set1_epi8(value_char)), target += 16, tail_length -= 16; - if (tail_length & 8) *(sz_u64_t *)target = value64, target += 8, tail_length -= 8; - if (tail_length & 4) *(sz_u32_t *)target = value32, target += 4, tail_length -= 4; - if (tail_length & 2) *(sz_u16_t *)target = value16, target += 2, tail_length -= 2; - if (tail_length & 1) *(sz_u8_t *)target = value, target++, tail_length--; - } -} - -SZ_PUBLIC void sz_copy_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "stores" and "loads". - // - // for (; length >= 32; target += 32, source += 32, length -= 32) - // _mm256_storeu_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source)); - // sz_copy_serial(target, source, length); - // - // A typical AWS Skylake instance can have 32 KB x 2 blocks of L1 data cache per core, - // 1 MB x 2 blocks of L2 cache per core, and one shared L3 cache buffer. - // For now, let's avoid the cases beyond the L2 size. 
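    // Here "huge" is a heuristic cut-off of 1 MB, roughly the per-core L2 capacity on the CPUs above.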
- int is_huge = length > 1ull * 1024ull * 1024ull; - if (length <= 32) { sz_copy_serial(target, source, length); } - // When dealing wirh larger arrays, the optimization is not as simple as with the `sz_fill_avx2` function, - // as both buffers may be unaligned. If we are lucky and the requested operation is some huge page transfer, - // we can use aligned loads and stores, and the performance will be great. - else if ((sz_size_t)target % 32 == 0 && (sz_size_t)source % 32 == 0 && !is_huge) { - for (; length >= 32; target += 32, source += 32, length -= 32) - _mm256_store_si256((__m256i *)target, _mm256_load_si256((__m256i const *)source)); - if (length) sz_copy_serial(target, source, length); - } - // The trickiest case is when both `source` and `target` are not aligned. - // In such and simpler cases we can copy enough bytes into `target` to reach its cacheline boundary, - // and then combine unaligned loads with aligned stores. - else { - sz_size_t head_length = (32 - ((sz_size_t)target % 32)) % 32; // 31 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 32; // 31 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 32. - - // Fill the head of the buffer. This part is much cleaner with AVX-512. - if (head_length & 1) *(sz_u8_t *)target = *(sz_u8_t *)source, target++, source++, head_length--; - if (head_length & 2) *(sz_u16_t *)target = *(sz_u16_t *)source, target += 2, source += 2, head_length -= 2; - if (head_length & 4) *(sz_u32_t *)target = *(sz_u32_t *)source, target += 4, source += 4, head_length -= 4; - if (head_length & 8) *(sz_u64_t *)target = *(sz_u64_t *)source, target += 8, source += 8, head_length -= 8; - if (head_length & 16) - _mm_store_si128((__m128i *)target, _mm_lddqu_si128((__m128i const *)source)), target += 16, source += 16, - head_length -= 16; - sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); - - // Fill the aligned body of the buffer. - if (!is_huge) { - for (; body_length >= 32; target += 32, source += 32, body_length -= 32) - _mm256_store_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source)); - } - // When the biffer is huge, we can traverse it in 2 directions. - else { - for (; body_length >= 64; target += 32, source += 32, body_length -= 64) { - _mm256_store_si256((__m256i *)(target), _mm256_lddqu_si256((__m256i const *)(source))); - _mm256_store_si256((__m256i *)(target + body_length - 32), - _mm256_lddqu_si256((__m256i const *)(source + body_length - 32))); - } - if (body_length) _mm256_store_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source)); - } - - // Fill the tail of the buffer. This part is much cleaner with AVX-512. 
- sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); - if (tail_length & 16) - _mm_store_si128((__m128i *)target, _mm_lddqu_si128((__m128i const *)source)), target += 16, source += 16, - tail_length -= 16; - if (tail_length & 8) *(sz_u64_t *)target = *(sz_u64_t *)source, target += 8, source += 8, tail_length -= 8; - if (tail_length & 4) *(sz_u32_t *)target = *(sz_u32_t *)source, target += 4, source += 4, tail_length -= 4; - if (tail_length & 2) *(sz_u16_t *)target = *(sz_u16_t *)source, target += 2, source += 2, tail_length -= 2; - if (tail_length & 1) *(sz_u8_t *)target = *(sz_u8_t *)source, target++, source++, tail_length--; - } -} - -SZ_PUBLIC void sz_move_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - if (target < source || target >= source + length) { - for (; length >= 32; target += 32, source += 32, length -= 32) - _mm256_storeu_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source)); - while (length--) *(target++) = *(source++); - } - else { - // Jump to the end and walk backwards. - for (target += length, source += length; length >= 32; length -= 32) - _mm256_storeu_si256((__m256i *)(target -= 32), _mm256_lddqu_si256((__m256i const *)(source -= 32))); - while (length--) *(--target) = *(--source); - } -} - -SZ_PUBLIC sz_u64_t sz_checksum_avx2(sz_cptr_t text, sz_size_t length) { - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "loads". - // - // A typical AWS Skylake instance can have 32 KB x 2 blocks of L1 data cache per core, - // 1 MB x 2 blocks of L2 cache per core, and one shared L3 cache buffer. - // For now, let's avoid the cases beyond the L2 size. - int is_huge = length > 1ull * 1024ull * 1024ull; - - // When the buffer is small, there isn't much to innovate. - if (length <= 32) { return sz_checksum_serial(text, length); } - else if (!is_huge) { - sz_u256_vec_t text_vec, sums_vec; - sums_vec.ymm = _mm256_setzero_si256(); - for (; length >= 32; text += 32, length -= 32) { - text_vec.ymm = _mm256_lddqu_si256((__m256i const *)text); - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); - } - // Accumulating 256 bits is harders, as we need to extract the 128-bit sums first. - __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymm); - __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymm, 1); - __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm); - sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm); - sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1); - sz_u64_t result = low + high; - if (length) result += sz_checksum_serial(text, length); - return result; - } - // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use. - // Most notably, we can avoid populating the cache with the entire buffer, and instead traverse it in 2 directions. - else { - sz_size_t head_length = (32 - ((sz_size_t)text % 32)) % 32; // 31 or less. - sz_size_t tail_length = (sz_size_t)(text + length) % 32; // 31 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 32. - sz_u64_t result = 0; - - // Handle the head - while (head_length--) result += *text++; - - sz_u256_vec_t text_vec, sums_vec; - sums_vec.ymm = _mm256_setzero_si256(); - // Fill the aligned body of the buffer. 
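        // Reminder: `_mm256_sad_epu8(x, zero)` sums each group of 8 consecutive bytes of `x` into a
        // 64-bit lane, which is what turns the byte-wise accumulation into four running 64-bit counters.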
- if (!is_huge) { - for (; body_length >= 32; text += 32, body_length -= 32) { - text_vec.ymm = _mm256_stream_load_si256((__m256i const *)text); - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); - } - } - // When the biffer is huge, we can traverse it in 2 directions. - else { - sz_u256_vec_t text_reversed_vec, sums_reversed_vec; - sums_reversed_vec.ymm = _mm256_setzero_si256(); - for (; body_length >= 64; text += 64, body_length -= 64) { - text_vec.ymm = _mm256_stream_load_si256((__m256i *)(text)); - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); - text_reversed_vec.ymm = _mm256_stream_load_si256((__m256i *)(text + body_length - 64)); - sums_reversed_vec.ymm = _mm256_add_epi64( - sums_reversed_vec.ymm, _mm256_sad_epu8(text_reversed_vec.ymm, _mm256_setzero_si256())); - } - if (body_length >= 32) { - text_vec.ymm = _mm256_stream_load_si256((__m256i *)(text)); - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); - } - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, sums_reversed_vec.ymm); - } - - // Handle the tail - while (tail_length--) result += *text++; - - // Accumulating 256 bits is harders, as we need to extract the 128-bit sums first. - __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymm); - __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymm, 1); - __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm); - sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm); - sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1); - result += low + high; - return result; - } -} - -SZ_PUBLIC void sz_look_up_transform_avx2(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { - - // If the input is tiny (especially smaller than the look-up table itself), we may end up paying - // more for organizing the SIMD registers and changing the CPU state, than for the actual computation. - // But if at least 3 cache lines are touched, the AVX-2 implementation should be faster. - if (length <= 128) { - sz_look_up_transform_serial(source, length, lut, target); - return; - } - - // We need to pull the lookup table into 8x YMM registers. - // The biggest issue is reorganizing the data in the lookup table, as AVX2 doesn't have 256-bit shuffle, - // it only has 128-bit "within-lane" shuffle. Still, it's wiser to use full YMM registers, instead of XMM, - // so that we can at least compensate high latency with twice larger window and one more level of lookup. 
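    // Each of the 16 broadcasts below replicates one 16-byte slice of the 256-byte table into both
    // 128-bit lanes, so the in-lane `_mm256_shuffle_epi8` can later index that slice by the low nibble
    // of every input byte.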
- sz_u256_vec_t lut_0_to_15_vec, lut_16_to_31_vec, lut_32_to_47_vec, lut_48_to_63_vec, // - lut_64_to_79_vec, lut_80_to_95_vec, lut_96_to_111_vec, lut_112_to_127_vec, // - lut_128_to_143_vec, lut_144_to_159_vec, lut_160_to_175_vec, lut_176_to_191_vec, // - lut_192_to_207_vec, lut_208_to_223_vec, lut_224_to_239_vec, lut_240_to_255_vec; - - lut_0_to_15_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut))); - lut_16_to_31_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 16))); - lut_32_to_47_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 32))); - lut_48_to_63_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 48))); - lut_64_to_79_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 64))); - lut_80_to_95_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 80))); - lut_96_to_111_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 96))); - lut_112_to_127_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 112))); - lut_128_to_143_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 128))); - lut_144_to_159_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 144))); - lut_160_to_175_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 160))); - lut_176_to_191_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 176))); - lut_192_to_207_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 192))); - lut_208_to_223_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 208))); - lut_224_to_239_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 224))); - lut_240_to_255_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 240))); - - // Assuming each lookup is performed within 16 elements of 256, we need to reduce the scope by 16x = 2^4. - sz_u256_vec_t not_first_bit_vec, not_second_bit_vec, not_third_bit_vec, not_fourth_bit_vec; - - /// Top and bottom nibbles of the source are used separately. - sz_u256_vec_t source_vec, source_bot_vec; - sz_u256_vec_t blended_0_to_31_vec, blended_32_to_63_vec, blended_64_to_95_vec, blended_96_to_127_vec, - blended_128_to_159_vec, blended_160_to_191_vec, blended_192_to_223_vec, blended_224_to_255_vec; - - // Handling the head. - while (length >= 32) { - // Load and separate the nibbles of each byte in the source. - source_vec.ymm = _mm256_lddqu_si256((__m256i const *)source); - source_bot_vec.ymm = _mm256_and_si256(source_vec.ymm, _mm256_set1_epi8((char)0x0F)); - - // In the first round, we select using the 4th bit. 
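        // The idea of the reduction: each of the 8 blends below picks, per byte, between two adjacent
        // 16-entry slices using bit 0x10 of the source, shrinking 16 shuffled candidates to 8; the
        // following rounds repeat this with bits 0x20, 0x40, and 0x80 until one byte per position remains.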
- not_fourth_bit_vec.ymm = _mm256_cmpeq_epi8( // - _mm256_and_si256(_mm256_set1_epi8((char)0x10), source_vec.ymm), _mm256_setzero_si256()); - blended_0_to_31_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_16_to_31_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_0_to_15_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_32_to_63_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_48_to_63_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_32_to_47_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_64_to_95_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_80_to_95_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_64_to_79_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_96_to_127_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_112_to_127_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_96_to_111_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_128_to_159_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_144_to_159_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_128_to_143_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_160_to_191_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_176_to_191_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_160_to_175_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_192_to_223_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_208_to_223_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_192_to_207_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_224_to_255_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_240_to_255_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_224_to_239_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - - // Perform a tree-like reduction of the 8x "blended" YMM registers, depending on the "source" content. - // The first round selects using the 3rd bit. - not_third_bit_vec.ymm = _mm256_cmpeq_epi8( // - _mm256_and_si256(_mm256_set1_epi8((char)0x20), source_vec.ymm), _mm256_setzero_si256()); - blended_0_to_31_vec.ymm = _mm256_blendv_epi8( // - blended_32_to_63_vec.ymm, // - blended_0_to_31_vec.ymm, // - not_third_bit_vec.ymm); - blended_64_to_95_vec.ymm = _mm256_blendv_epi8( // - blended_96_to_127_vec.ymm, // - blended_64_to_95_vec.ymm, // - not_third_bit_vec.ymm); - blended_128_to_159_vec.ymm = _mm256_blendv_epi8( // - blended_160_to_191_vec.ymm, // - blended_128_to_159_vec.ymm, // - not_third_bit_vec.ymm); - blended_192_to_223_vec.ymm = _mm256_blendv_epi8( // - blended_224_to_255_vec.ymm, // - blended_192_to_223_vec.ymm, // - not_third_bit_vec.ymm); - - // The second round selects using the 2nd bit. - not_second_bit_vec.ymm = _mm256_cmpeq_epi8( // - _mm256_and_si256(_mm256_set1_epi8((char)0x40), source_vec.ymm), _mm256_setzero_si256()); - blended_0_to_31_vec.ymm = _mm256_blendv_epi8( // - blended_64_to_95_vec.ymm, // - blended_0_to_31_vec.ymm, // - not_second_bit_vec.ymm); - blended_128_to_159_vec.ymm = _mm256_blendv_epi8( // - blended_192_to_223_vec.ymm, // - blended_128_to_159_vec.ymm, // - not_second_bit_vec.ymm); - - // The third round selects using the 1st bit. 
- not_first_bit_vec.ymm = _mm256_cmpeq_epi8( // - _mm256_and_si256(_mm256_set1_epi8((char)0x80), source_vec.ymm), _mm256_setzero_si256()); - blended_0_to_31_vec.ymm = _mm256_blendv_epi8( // - blended_128_to_159_vec.ymm, // - blended_0_to_31_vec.ymm, // - not_first_bit_vec.ymm); - - // And dump the result into the target. - _mm256_storeu_si256((__m256i *)target, blended_0_to_31_vec.ymm); - source += 32, target += 32, length -= 32; - } - - // Handle the tail. - if (length) sz_look_up_transform_serial(source, length, lut, target); -} - -SZ_PUBLIC sz_cptr_t sz_find_byte_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - int mask; - sz_u256_vec_t h_vec, n_vec; - n_vec.ymm = _mm256_set1_epi8(n[0]); - - while (h_length >= 32) { - h_vec.ymm = _mm256_lddqu_si256((__m256i const *)h); - mask = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_vec.ymm, n_vec.ymm)); - if (mask) return h + sz_u32_ctz(mask); - h += 32, h_length -= 32; - } - - return sz_find_byte_serial(h, h_length, n); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - int mask; - sz_u256_vec_t h_vec, n_vec; - n_vec.ymm = _mm256_set1_epi8(n[0]); - - while (h_length >= 32) { - h_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + h_length - 32)); - mask = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_vec.ymm, n_vec.ymm)); - if (mask) return h + h_length - 1 - sz_u32_clz(mask); - h_length -= 32; - } - - return sz_rfind_byte_serial(h, h_length, n); -} - -SZ_PUBLIC sz_cptr_t sz_find_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_find_byte_avx2(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into YMM registers. - int matches; - sz_u256_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; - n_first_vec.ymm = _mm256_set1_epi8(n[offset_first]); - n_mid_vec.ymm = _mm256_set1_epi8(n[offset_mid]); - n_last_vec.ymm = _mm256_set1_epi8(n[offset_last]); - - // Scan through the string. - for (; h_length >= n_length + 32; h += 32, h_length -= 32) { - h_first_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_first)); - h_mid_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_mid)); - h_last_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_last)); - matches = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_first_vec.ymm, n_first_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_mid_vec.ymm, n_mid_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_last_vec.ymm, n_last_vec.ymm)); - while (matches) { - int potential_offset = sz_u32_ctz(matches); - if (sz_equal(h + potential_offset, n, n_length)) return h + potential_offset; - matches &= matches - 1; - } - } - - return sz_find_serial(h, h_length, n, n_length); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_rfind_byte_avx2(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. 
- sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into YMM registers. - int matches; - sz_u256_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; - n_first_vec.ymm = _mm256_set1_epi8(n[offset_first]); - n_mid_vec.ymm = _mm256_set1_epi8(n[offset_mid]); - n_last_vec.ymm = _mm256_set1_epi8(n[offset_last]); - - // Scan through the string. - sz_cptr_t h_reversed; - for (; h_length >= n_length + 32; h_length -= 32) { - h_reversed = h + h_length - n_length - 32 + 1; - h_first_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_first)); - h_mid_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_mid)); - h_last_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_last)); - matches = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_first_vec.ymm, n_first_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_mid_vec.ymm, n_mid_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_last_vec.ymm, n_last_vec.ymm)); - while (matches) { - int potential_offset = sz_u32_clz(matches); - if (sz_equal(h + h_length - n_length - potential_offset, n, n_length)) - return h + h_length - n_length - potential_offset; - matches &= ~(1 << (31 - potential_offset)); - } - } - - return sz_rfind_serial(h, h_length, n, n_length); -} - -SZ_PUBLIC sz_cptr_t sz_find_charset_avx2(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { - - // Let's unzip even and odd elements and replicate them into both lanes of the YMM register. - // That way when we invoke `_mm256_shuffle_epi8` we can use the same mask for both lanes. - sz_u256_vec_t filter_even_vec, filter_odd_vec; - for (sz_size_t i = 0; i != 16; ++i) - filter_even_vec.u8s[i] = filter->_u8s[i * 2], filter_odd_vec.u8s[i] = filter->_u8s[i * 2 + 1]; - filter_even_vec.xmms[1] = filter_even_vec.xmms[0]; - filter_odd_vec.xmms[1] = filter_odd_vec.xmms[0]; - - sz_u256_vec_t text_vec; - sz_u256_vec_t matches_vec; - sz_u256_vec_t lower_nibbles_vec, higher_nibbles_vec; - sz_u256_vec_t bitset_even_vec, bitset_odd_vec; - sz_u256_vec_t bitmask_vec, bitmask_lookup_vec; - bitmask_lookup_vec.ymm = _mm256_set_epi8(-128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // - -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1); - - while (length >= 32) { - // The following algorithm is a transposed equivalent of the "SIMDized check which bytes are in a set" - // solutions by Wojciech Muła. We populate the bitmask differently and target newer CPUs, so - // StrinZilla uses a somewhat different approach. - // http://0x80.pl/articles/simd-byte-lookup.html#alternative-implementation-new - // - // sz_u8_t input = *(sz_u8_t const *)text; - // sz_u8_t lo_nibble = input & 0x0f; - // sz_u8_t hi_nibble = input >> 4; - // sz_u8_t bitset_even = filter_even_vec.u8s[hi_nibble]; - // sz_u8_t bitset_odd = filter_odd_vec.u8s[hi_nibble]; - // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); - // sz_u8_t bitset = lo_nibble < 8 ? bitset_even : bitset_odd; - // if ((bitset & bitmask) != 0) return text; - // else { length--, text++; } - // - // The nice part about this, loading the strided data is vey easy with Arm NEON, - // while with x86 CPUs after AVX, shuffles within 256 bits shouldn't be an issue either. 
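- // In scalar form, membership of a byte `c` is bit `c & 7` of `filter->_u8s[c >> 3]`. Since - // `c >> 3 == hi_nibble * 2 + (lo_nibble >= 8)`, the "even" vector serves low nibbles 0-7 and the - // "odd" vector serves nibbles 8-15, which is what the blend on `take_first` below selects between.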
- text_vec.ymm = _mm256_lddqu_si256((__m256i const *)text); - lower_nibbles_vec.ymm = _mm256_and_si256(text_vec.ymm, _mm256_set1_epi8(0x0f)); - bitmask_vec.ymm = _mm256_shuffle_epi8(bitmask_lookup_vec.ymm, lower_nibbles_vec.ymm); - // - // At this point we can validate the `bitmask_vec` contents like this: - // - // for (sz_size_t i = 0; i != 32; ++i) { - // sz_u8_t input = *(sz_u8_t const *)(text + i); - // sz_u8_t lo_nibble = input & 0x0f; - // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); - // sz_assert(bitmask_vec.u8s[i] == bitmask); - // } - // - // Shift right every byte by 4 bits. - // There is no `_mm256_srli_epi8` intrinsic, so we have to use `_mm256_srli_epi16` - // and combine it with a mask to clear the higher bits. - higher_nibbles_vec.ymm = _mm256_and_si256(_mm256_srli_epi16(text_vec.ymm, 4), _mm256_set1_epi8(0x0f)); - bitset_even_vec.ymm = _mm256_shuffle_epi8(filter_even_vec.ymm, higher_nibbles_vec.ymm); - bitset_odd_vec.ymm = _mm256_shuffle_epi8(filter_odd_vec.ymm, higher_nibbles_vec.ymm); - // - // At this point we can validate the `bitset_even_vec` and `bitset_odd_vec` contents like this: - // - // for (sz_size_t i = 0; i != 32; ++i) { - // sz_u8_t input = *(sz_u8_t const *)(text + i); - // sz_u8_t const *bitset_ptr = &filter->_u8s[0]; - // sz_u8_t hi_nibble = input >> 4; - // sz_u8_t bitset_even = bitset_ptr[hi_nibble * 2]; - // sz_u8_t bitset_odd = bitset_ptr[hi_nibble * 2 + 1]; - // sz_assert(bitset_even_vec.u8s[i] == bitset_even); - // sz_assert(bitset_odd_vec.u8s[i] == bitset_odd); - // } - // - __m256i take_first = _mm256_cmpgt_epi8(_mm256_set1_epi8(8), lower_nibbles_vec.ymm); - bitset_even_vec.ymm = _mm256_blendv_epi8(bitset_odd_vec.ymm, bitset_even_vec.ymm, take_first); - - // It would have been great to have an instruction that tests the bits and then broadcasts - // the matching bit into all bits in that byte. But we don't have that, so we have to - // `and`, `cmpeq`, `movemask`, and then invert at the end... - matches_vec.ymm = _mm256_and_si256(bitset_even_vec.ymm, bitmask_vec.ymm); - matches_vec.ymm = _mm256_cmpeq_epi8(matches_vec.ymm, _mm256_setzero_si256()); - int matches_mask = ~_mm256_movemask_epi8(matches_vec.ymm); - if (matches_mask) { - int offset = sz_u32_ctz(matches_mask); - return text + offset; - } - else { text += 32, length -= 32; } - } - - return sz_find_charset_serial(text, length, filter); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx2(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { - return sz_rfind_charset_serial(text, length, filter); -} - -/** - * @brief There is no AVX2 instruction for fast multiplication of 64-bit integers. - * This implementation is coming from Agner Fog's Vector Class Library. 
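- * With a = a_hi * 2^32 + a_lo and b = b_hi * 2^32 + b_lo, the truncated 64-bit product is - * a_lo * b_lo + ((a_lo * b_hi + a_hi * b_lo) << 32), as the a_hi * b_hi term overflows past bit 63. - * `_mm256_mul_epu32` supplies the full low product, while `_mm256_mullo_epi32` and a horizontal add - * assemble the cross terms.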
- */ -SZ_INTERNAL __m256i _mm256_mul_epu64(__m256i a, __m256i b) { - __m256i bswap = _mm256_shuffle_epi32(b, 0xB1); - __m256i prodlh = _mm256_mullo_epi32(a, bswap); - __m256i zero = _mm256_setzero_si256(); - __m256i prodlh2 = _mm256_hadd_epi32(prodlh, zero); - __m256i prodlh3 = _mm256_shuffle_epi32(prodlh2, 0x73); - __m256i prodll = _mm256_mul_epu32(a, b); - __m256i prod = _mm256_add_epi64(prodll, prodlh3); - return prod; -} - -SZ_PUBLIC void sz_hashes_avx2(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle) { - - if (length < window_length || !window_length) return; - if (length < 4 * window_length) { - sz_hashes_serial(start, length, window_length, step, callback, callback_handle); - return; - } - - // Using AVX2, we can perform 4 long integer multiplications and additions within one register. - // So let's slice the entire string into 4 overlapping windows, to slide over them in parallel. - sz_size_t const max_hashes = length - window_length + 1; - sz_size_t const min_hashes_per_thread = max_hashes / 4; // At most one sequence can overlap between 2 threads. - sz_u8_t const *text_first = (sz_u8_t const *)start; - sz_u8_t const *text_second = text_first + min_hashes_per_thread; - sz_u8_t const *text_third = text_first + min_hashes_per_thread * 2; - sz_u8_t const *text_fourth = text_first + min_hashes_per_thread * 3; - sz_u8_t const *text_end = text_first + length; - - // Prepare the `prime ^ window_length` values, that we are going to use for modulo arithmetic. - sz_u64_t prime_power_low = 1, prime_power_high = 1; - for (sz_size_t i = 0; i + 1 < window_length; ++i) - prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME, - prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME; - - // Broadcast the constants into the registers. - sz_u256_vec_t prime_vec, golden_ratio_vec; - sz_u256_vec_t base_low_vec, base_high_vec, prime_power_low_vec, prime_power_high_vec, shift_high_vec; - base_low_vec.ymm = _mm256_set1_epi64x(31ull); - base_high_vec.ymm = _mm256_set1_epi64x(257ull); - shift_high_vec.ymm = _mm256_set1_epi64x(77ull); - prime_vec.ymm = _mm256_set1_epi64x(SZ_U64_MAX_PRIME); - golden_ratio_vec.ymm = _mm256_set1_epi64x(11400714819323198485ull); - prime_power_low_vec.ymm = _mm256_set1_epi64x(prime_power_low); - prime_power_high_vec.ymm = _mm256_set1_epi64x(prime_power_high); - - // Compute the initial hash values for every one of the four windows. - sz_u256_vec_t hash_low_vec, hash_high_vec, hash_mix_vec, chars_low_vec, chars_high_vec; - hash_low_vec.ymm = _mm256_setzero_si256(); - hash_high_vec.ymm = _mm256_setzero_si256(); - for (sz_u8_t const *prefix_end = text_first + window_length; text_first < prefix_end; - ++text_first, ++text_second, ++text_third, ++text_fourth) { - - // 1. Multiply the hashes by the base. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, base_low_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, base_high_vec.ymm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. - chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); - - // 3. Add the incoming characters. 
- hash_low_vec.ymm = _mm256_add_epi64(hash_low_vec.ymm, chars_low_vec.ymm); - hash_high_vec.ymm = _mm256_add_epi64(hash_high_vec.ymm, chars_high_vec.ymm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_low_vec.ymm = _mm256_blendv_epi8(hash_low_vec.ymm, _mm256_sub_epi64(hash_low_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_low_vec.ymm, prime_vec.ymm)); - hash_high_vec.ymm = _mm256_blendv_epi8(hash_high_vec.ymm, _mm256_sub_epi64(hash_high_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_high_vec.ymm, prime_vec.ymm)); - } - - // 5. Compute the hash mix, that will be used to index into the fingerprint. - // This includes a serial step at the end. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, golden_ratio_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, golden_ratio_vec.ymm); - hash_mix_vec.ymm = _mm256_xor_si256(hash_low_vec.ymm, hash_high_vec.ymm); - callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); - callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); - callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); - callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); - - // Now repeat that operation for the remaining characters, discarding older characters. - sz_size_t cycle = 1; - sz_size_t const step_mask = step - 1; - for (; text_fourth != text_end; ++text_first, ++text_second, ++text_third, ++text_fourth, ++cycle) { - // 0. Load again the four characters we are dropping, shift them, and subtract. - chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[-window_length], text_third[-window_length], - text_second[-window_length], text_first[-window_length]); - chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); - hash_low_vec.ymm = - _mm256_sub_epi64(hash_low_vec.ymm, _mm256_mul_epu64(chars_low_vec.ymm, prime_power_low_vec.ymm)); - hash_high_vec.ymm = - _mm256_sub_epi64(hash_high_vec.ymm, _mm256_mul_epu64(chars_high_vec.ymm, prime_power_high_vec.ymm)); - - // 1. Multiply the hashes by the base. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, base_low_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, base_high_vec.ymm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. - chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); - - // 3. Add the incoming characters. - hash_low_vec.ymm = _mm256_add_epi64(hash_low_vec.ymm, chars_low_vec.ymm); - hash_high_vec.ymm = _mm256_add_epi64(hash_high_vec.ymm, chars_high_vec.ymm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_low_vec.ymm = _mm256_blendv_epi8(hash_low_vec.ymm, _mm256_sub_epi64(hash_low_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_low_vec.ymm, prime_vec.ymm)); - hash_high_vec.ymm = _mm256_blendv_epi8(hash_high_vec.ymm, _mm256_sub_epi64(hash_high_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_high_vec.ymm, prime_vec.ymm)); - - // 5. 
Compute the hash mix, that will be used to index into the fingerprint. - // This includes a serial step at the end. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, golden_ratio_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, golden_ratio_vec.ymm); - hash_mix_vec.ymm = _mm256_xor_si256(hash_low_vec.ymm, hash_high_vec.ymm); - if ((cycle & step_mask) == 0) { - callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); - callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); - callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); - callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); - } - } -} - -#pragma clang attribute pop -#pragma GCC pop_options -#endif -#pragma endregion - -/* - * @brief AVX-512 implementation of the string search algorithms. - * - * Different subsets of AVX-512 were introduced in different years: - * - 2017 SkyLake: F, CD, ER, PF, VL, DQ, BW - * - 2018 CannonLake: IFMA, VBMI - * - 2019 IceLake: VPOPCNTDQ, VNNI, VBMI2, BITALG, GFNI, VPCLMULQDQ, VAES - * - 2020 TigerLake: VP2INTERSECT - */ -#pragma region AVX512 Implementation - -#if SZ_USE_X86_AVX512 -#pragma GCC push_options -#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "bmi", "bmi2") -#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,bmi,bmi2"))), apply_to = function) -#include - -/** - * @brief Helper structure to simplify work with 512-bit registers. - */ -typedef union sz_u512_vec_t { - __m512i zmm; - __m256i ymms[2]; - __m128i xmms[4]; - sz_u64_t u64s[8]; - sz_u32_t u32s[16]; - sz_u16_t u16s[32]; - sz_u8_t u8s[64]; - sz_i64_t i64s[8]; - sz_i32_t i32s[16]; -} sz_u512_vec_t; - -SZ_INTERNAL __mmask64 _sz_u64_clamp_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 64: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 64: - return _bzhi_u64(0xFFFFFFFFFFFFFFFF, n < 64 ? (sz_u32_t)n : 64); -} - -SZ_INTERNAL __mmask32 _sz_u32_clamp_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 32: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 32: - return _bzhi_u32(0xFFFFFFFF, n < 32 ? (sz_u32_t)n : 32); -} - -SZ_INTERNAL __mmask16 _sz_u16_clamp_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 16: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 16: - return _bzhi_u32(0xFFFFFFFF, n < 16 ? 
(sz_u32_t)n : 16); -} - -SZ_INTERNAL __mmask16 _sz_u16_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 16: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 16: - return (__mmask16)_bzhi_u32(0xFFFFFFFF, (sz_u32_t)n); -} - -SZ_INTERNAL __mmask32 _sz_u32_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 32: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 32: - return _bzhi_u32(0xFFFFFFFF, (sz_u32_t)n); -} - -SZ_INTERNAL __mmask64 _sz_u64_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 64: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 64: - return _bzhi_u64(0xFFFFFFFFFFFFFFFF, (sz_u32_t)n); -} - -SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { - sz_u512_vec_t a_vec, b_vec; - - // Pointer arithmetic is cheap, fetching memory is not! - // So we can use the masked loads to fetch at most one cache-line for each string, - // compare the prefixes, and only then move forward. - sz_size_t a_head_length = 64 - ((sz_size_t)a % 64); // 63 or less. - sz_size_t b_head_length = 64 - ((sz_size_t)b % 64); // 63 or less. - a_head_length = a_head_length < a_length ? a_head_length : a_length; - b_head_length = b_head_length < b_length ? b_head_length : b_length; - sz_size_t head_length = a_head_length < b_head_length ? a_head_length : b_head_length; - __mmask64 head_mask = _sz_u64_mask_until(head_length); - a_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, a); - b_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, b); - __mmask64 mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); - if (mask_not_equal != 0) { - sz_u64_t first_diff = _tzcnt_u64(mask_not_equal); - char a_char = a_vec.u8s[first_diff]; - char b_char = b_vec.u8s[first_diff]; - return _sz_order_scalars(a_char, b_char); - } - else if (head_length == a_length && head_length == b_length) { return sz_equal_k; } - else { a += head_length, b += head_length, a_length -= head_length, b_length -= head_length; } - - // The rare case, when both string are very long. - __mmask64 a_mask, b_mask; - while ((a_length >= 64) & (b_length >= 64)) { - a_vec.zmm = _mm512_loadu_si512(a); - b_vec.zmm = _mm512_loadu_si512(b); - mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); - if (mask_not_equal != 0) { - sz_u64_t first_diff = _tzcnt_u64(mask_not_equal); - char a_char = a_vec.u8s[first_diff]; - char b_char = b_vec.u8s[first_diff]; - return _sz_order_scalars(a_char, b_char); - } - a += 64, b += 64, a_length -= 64, b_length -= 64; - } - - // In most common scenarios at least one of the strings is under 64 bytes. - if (a_length | b_length) { - a_mask = _sz_u64_clamp_mask_until(a_length); - b_mask = _sz_u64_clamp_mask_until(b_length); - a_vec.zmm = _mm512_maskz_loadu_epi8(a_mask, a); - b_vec.zmm = _mm512_maskz_loadu_epi8(b_mask, b); - // The AVX-512 `_mm512_mask_cmpneq_epi8_mask` intrinsics are generally handy in such environments. - // They, however, have latency 3 on most modern CPUs. Using AVX2: `_mm256_cmpeq_epi8` would have - // been cheaper, if we didn't have to apply `_mm256_movemask_epi8` afterwards. 
- mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); - if (mask_not_equal != 0) { - sz_u64_t first_diff = _tzcnt_u64(mask_not_equal); - char a_char = a_vec.u8s[first_diff]; - char b_char = b_vec.u8s[first_diff]; - return _sz_order_scalars(a_char, b_char); - } - // From logic perspective, the hardest cases are "abc\0" and "abc". - // The result must be `sz_greater_k`, as the latter is shorter. - else { return _sz_order_scalars(a_length, b_length); } - } - - return sz_equal_k; -} - -SZ_PUBLIC sz_bool_t sz_equal_avx512(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - __mmask64 mask; - sz_u512_vec_t a_vec, b_vec; - - while (length >= 64) { - a_vec.zmm = _mm512_loadu_si512(a); - b_vec.zmm = _mm512_loadu_si512(b); - mask = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); - if (mask != 0) return sz_false_k; - a += 64, b += 64, length -= 64; - } - - if (length) { - mask = _sz_u64_mask_until(length); - a_vec.zmm = _mm512_maskz_loadu_epi8(mask, a); - b_vec.zmm = _mm512_maskz_loadu_epi8(mask, b); - // Reuse the same `mask` variable to find the bit that doesn't match - mask = _mm512_mask_cmpneq_epi8_mask(mask, a_vec.zmm, b_vec.zmm); - return (sz_bool_t)(mask == 0); - } - - return sz_true_k; -} - -SZ_PUBLIC void sz_fill_avx512(sz_ptr_t target, sz_size_t length, sz_u8_t value) { - __m512i value_vec = _mm512_set1_epi8(value); - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "stores". - // - // for (; length >= 64; target += 64, length -= 64) _mm512_storeu_si512(target, value_vec); - // _mm512_mask_storeu_epi8(target, _sz_u64_mask_until(length), value_vec); - // - // When the buffer is small, there isn't much to innovate. - if (length <= 64) { - __mmask64 mask = _sz_u64_mask_until(length); - _mm512_mask_storeu_epi8(target, mask, value_vec); - } - // When the buffer is over 64 bytes, it's guaranteed to touch at least two cache lines - the head and tail, - // and may include more cache-lines in-between. Knowing this, we can avoid expensive unaligned stores - // by computing 2 masks - for the head and tail, using masked stores for the head and tail, and unmasked - // for the body. - else { - sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 64. - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - _mm512_mask_storeu_epi8(target, head_mask, value_vec); - for (target += head_length; body_length >= 64; target += 64, body_length -= 64) - _mm512_store_si512(target, value_vec); - _mm512_mask_storeu_epi8(target, tail_mask, value_vec); - } -} - -SZ_PUBLIC void sz_copy_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "stores" and "loads". - // - // for (; length >= 64; target += 64, source += 64, length -= 64) - // _mm512_storeu_si512(target, _mm512_loadu_si512(source)); - // __mmask64 mask = _sz_u64_mask_until(length); - // _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source)); - // - // A typical AWS Sapphire Rapids instance can have 48 KB x 2 blocks of L1 data cache per core, - // 2 MB x 2 blocks of L2 cache per core, and one shared 60 MB buffer of L3 cache. 
- // With two strings, we may consider the overall workload huge, if each exceeds 1 MB in length. - int const is_huge = length >= 1ull * 1024ull * 1024ull; - - // When the buffer is small, there isn't much to innovate. - if (length <= 64) { - __mmask64 mask = _sz_u64_mask_until(length); - _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source)); - } - // When dealing with larger arrays, the optimization is not as simple as with the `sz_fill_avx512` function, - // as both buffers may be unaligned. If we are lucky and the requested operation is some huge page transfer, - // we can use aligned loads and stores, and the performance will be great. - else if ((sz_size_t)target % 64 == 0 && (sz_size_t)source % 64 == 0 && !is_huge) { - for (; length >= 64; target += 64, source += 64, length -= 64) - _mm512_store_si512(target, _mm512_load_si512(source)); - // At this point the length is guaranteed to be under 64. - __mmask64 mask = _sz_u64_mask_until(length); - // Aligned loads and stores would work too, but it's not defined. - _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source)); - } - // The trickiest case is when both `source` and `target` are not aligned. - // In this and simpler cases we can copy enough bytes into `target` to reach its cacheline boundary, - // and then combine unaligned loads with aligned stores. - else if (!is_huge) { - sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 64. - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source)); - for (target += head_length, source += head_length; body_length >= 64; - target += 64, source += 64, body_length -= 64) - _mm512_store_si512(target, _mm512_loadu_si512(source)); // Unaligned load, but aligned store! - _mm512_mask_storeu_epi8(target, tail_mask, _mm512_maskz_loadu_epi8(tail_mask, source)); - } - // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use. - // - // 1. Moving in both directions to maximize the throughput, when fetching from multiple - // memory pages. Also helps with cache set-associativity issues, as we won't always - // be fetching the same entries in the lookup table. - // 2. Using non-temporal stores to avoid polluting the cache. - // 3. Prefetching the next cache line, to avoid stalling the CPU. This is generally useless - // for predictable patterns, so disregard this advice. - // - // Bidirectional traversal adds about 10%, accelerating from 11 GB/s to 12 GB/s. - // Using "streaming stores" boosts us from 12 GB/s to 19 GB/s. - else { - sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; - sz_size_t tail_length = (sz_size_t)(target + length) % 64; - sz_size_t body_length = length - head_length - tail_length; - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source)); - _mm512_mask_storeu_epi8(target + head_length + body_length, tail_mask, - _mm512_maskz_loadu_epi8(tail_mask, source)); - - // Now in the main loop, we can use non-temporal loads and stores, - // performing the operation in both directions.
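- // Each iteration below streams two cache lines - one from the front of the body and one from the - // back - so the pointers advance by 64 bytes while `body_length` shrinks by 128, and at most one - // middle cache line remains to be flushed after the loop.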
- for (target += head_length, source += head_length; // - body_length >= 128; // - target += 64, source += 64, body_length -= 128) { - _mm512_stream_si512((__m512i *)(target), _mm512_loadu_si512(source)); - _mm512_stream_si512((__m512i *)(target + body_length - 64), _mm512_loadu_si512(source + body_length - 64)); - } - if (body_length >= 64) _mm512_stream_si512((__m512i *)target, _mm512_loadu_si512(source)); - } -} - -SZ_PUBLIC void sz_move_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - if (target == source) return; // Don't be silly, don't move the data if it's already there. - - // On very short buffers, that are one cache line in width or less, we don't need any loops. - // We can also avoid any data-dependencies between iterations, assuming we have 32 registers - // to pre-load the data, before writing it back. - if (length <= 64) { - __mmask64 mask = _sz_u64_mask_until(length); - _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source)); - } - else if (length <= 128) { - sz_size_t last_length = length - 64; - __mmask64 mask = _sz_u64_mask_until(last_length); - __m512i source0 = _mm512_loadu_epi8(source); - __m512i source1 = _mm512_maskz_loadu_epi8(mask, source + 64); - _mm512_storeu_epi8(target, source0); - _mm512_mask_storeu_epi8(target + 64, mask, source1); - } - else if (length <= 192) { - sz_size_t last_length = length - 128; - __mmask64 mask = _sz_u64_mask_until(last_length); - __m512i source0 = _mm512_loadu_epi8(source); - __m512i source1 = _mm512_loadu_epi8(source + 64); - __m512i source2 = _mm512_maskz_loadu_epi8(mask, source + 128); - _mm512_storeu_epi8(target, source0); - _mm512_storeu_epi8(target + 64, source1); - _mm512_mask_storeu_epi8(target + 128, mask, source2); - } - else if (length <= 256) { - sz_size_t last_length = length - 192; - __mmask64 mask = _sz_u64_mask_until(last_length); - __m512i source0 = _mm512_loadu_epi8(source); - __m512i source1 = _mm512_loadu_epi8(source + 64); - __m512i source2 = _mm512_loadu_epi8(source + 128); - __m512i source3 = _mm512_maskz_loadu_epi8(mask, source + 192); - _mm512_storeu_epi8(target, source0); - _mm512_storeu_epi8(target + 64, source1); - _mm512_storeu_epi8(target + 128, source2); - _mm512_mask_storeu_epi8(target + 192, mask, source3); - } - - // If the regions don't overlap at all, just use "copy" and save some brain cells thinking about corner cases. - else if (target + length < source || target >= source + length) { sz_copy_avx512(target, source, length); } - - // When the buffer is over 64 bytes, it's guaranteed to touch at least two cache lines - the head and tail, - // and may include more cache-lines in-between. Knowing this, we can avoid expensive unaligned stores - // by computing 2 masks - for the head and tail, using masked stores for the head and tail, and unmasked - // for the body. - else { - sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 64. - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - - // The absolute most common case of using "moves" is shifting the data within a continuous buffer - // when adding a removing some values in it. In such cases, a typical shift is by 1, 2, 4, 8, 16, - // or 32 bytes, rarely larger. For small shifts, under the size of the ZMM register, we can use shuffles. 
- // - // Remember: - // - if we are shifting data left, that we are traversing to the right. - // - if we are shifting data right, that we are traversing to the left. - int const left_to_right_traversal = source > target; - - // Now we guarantee, that the relative shift within registers is from 1 to 63 bytes and the output is aligned. - // Hopefully, we need to shift more than two ZMM registers, so we could consider `valignr` instruction. - // Sadly, using `_mm512_alignr_epi8` doesn't make sense, as it operates at a 128-bit granularity. - // - // - `_mm256_alignr_epi8` shifts entire 256-bit register, but we need many of them. - // - `_mm512_alignr_epi32` shifts 512-bit chunks, but only if the `shift` is a multiple of 4 bytes. - // - `_mm512_alignr_epi64` shifts 512-bit chunks by 8 bytes. - // - // All of those have a latency of 1 cycle, and the shift amount must be an immediate value! - // For 1-byte-shift granularity, the `_mm512_permutex2var_epi8` has a latency of 6 and needs VBMI! - // The most efficient and broadly compatible alternative could be to use a combination of align and shuffle. - // A similar approach was outlined in "Byte-wise alignr in AVX512F" by Wojciech Muła. - // http://0x80.pl/notesen/2016-10-16-avx512-byte-alignr.html - // - // That solution, is extremely mouthful, assuming we need compile time constants for the shift amount. - // A cleaner one, with a latency of 3 cycles, is to use `_mm512_permutexvar_epi8` or - // `_mm512_mask_permutexvar_epi8`, which can be seen as combination of a cross-register shuffle and blend, - // and is available with VBMI. That solution is still noticeably slower than AVX2. - // - // The GLibC implementation also uses non-temporal stores for larger buffers, we don't. - // https://codebrowser.dev/glibc/glibc/sysdeps/x86_64/multiarch/memmove-avx512-no-vzeroupper.S.html - if (left_to_right_traversal) { - // Head, body, and tail. - _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source)); - for (target += head_length, source += head_length; body_length >= 64; - target += 64, source += 64, body_length -= 64) - _mm512_store_si512(target, _mm512_loadu_si512(source)); - _mm512_mask_storeu_epi8(target, tail_mask, _mm512_maskz_loadu_epi8(tail_mask, source)); - } - else { - // Tail, body, and head. 
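- // When the destination overlaps the source from the right, copying front-to-back would overwrite - // bytes that are still unread, so the tail is stored first and the body is walked backwards, - // one aligned cache line at a time.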
- _mm512_mask_storeu_epi8(target + head_length + body_length, tail_mask, - _mm512_maskz_loadu_epi8(tail_mask, source + head_length + body_length)); - for (; body_length >= 64; body_length -= 64) - _mm512_store_si512(target + head_length + body_length - 64, - _mm512_loadu_si512(source + head_length + body_length - 64)); - _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source)); - } - } -} - -SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - __mmask64 mask; - sz_u512_vec_t h_vec, n_vec; - n_vec.zmm = _mm512_set1_epi8(n[0]); - - while (h_length >= 64) { - h_vec.zmm = _mm512_loadu_si512(h); - mask = _mm512_cmpeq_epi8_mask(h_vec.zmm, n_vec.zmm); - if (mask) return h + sz_u64_ctz(mask); - h += 64, h_length -= 64; - } - - if (h_length) { - mask = _sz_u64_mask_until(h_length); - h_vec.zmm = _mm512_maskz_loadu_epi8(mask, h); - // Reuse the same `mask` variable to find the bit that doesn't match - mask = _mm512_mask_cmpeq_epu8_mask(mask, h_vec.zmm, n_vec.zmm); - if (mask) return h + sz_u64_ctz(mask); - } - - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_find_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_find_byte_avx512(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into ZMM registers. - __mmask64 matches; - __mmask64 mask; - sz_u512_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; - n_first_vec.zmm = _mm512_set1_epi8(n[offset_first]); - n_mid_vec.zmm = _mm512_set1_epi8(n[offset_mid]); - n_last_vec.zmm = _mm512_set1_epi8(n[offset_last]); - - // Scan through the string. - // We have several optimized versions of the lagorithm for shorter strings, - // but they all mimic the default case for unbounded length needles - if (n_length >= 64) { - for (; h_length >= n_length + 64; h += 64, h_length -= 64) { - h_first_vec.zmm = _mm512_loadu_si512(h + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_ctz(matches); - if (sz_equal_avx512(h + potential_offset, n, n_length)) return h + potential_offset; - matches &= matches - 1; - } - - // TODO: If the last character contains a bad byte, we can reposition the start of the next iteration. - // This will be very helpful for very long needles. - } - } - // If there are only 2 or 3 characters in the needle, we don't even need the nested loop. 
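- // For such short needles the probed "first", "mid", and "last" offsets cover every byte of the - // needle, so a position where all three comparisons agree is already a complete match and needs - // no extra verification.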
- else if (n_length <= 3) { - for (; h_length >= n_length + 64; h += 64, h_length -= 64) { - h_first_vec.zmm = _mm512_loadu_si512(h + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - if (matches) return h + sz_u64_ctz(matches); - } - } - // If the needle is smaller than the size of the ZMM register, we can use masked comparisons - // to avoid the the inner-most nested loop and compare the entire needle against a haystack - // slice in 3 CPU cycles. - else { - __mmask64 n_mask = _sz_u64_mask_until(n_length); - sz_u512_vec_t n_full_vec, h_full_vec; - n_full_vec.zmm = _mm512_maskz_loadu_epi8(n_mask, n); - for (; h_length >= n_length + 64; h += 64, h_length -= 64) { - h_first_vec.zmm = _mm512_loadu_si512(h + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_ctz(matches); - h_full_vec.zmm = _mm512_maskz_loadu_epi8(n_mask, h + potential_offset); - if (_mm512_mask_cmpneq_epi8_mask(n_mask, h_full_vec.zmm, n_full_vec.zmm) == 0) - return h + potential_offset; - matches &= matches - 1; - } - } - } - - // The "tail" of the function uses masked loads to process the remaining bytes. - { - mask = _sz_u64_mask_until(h_length - n_length + 1); - h_first_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_first); - h_mid_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_mid); - h_last_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_ctz(matches); - if (n_length <= 3 || sz_equal_avx512(h + potential_offset, n, n_length)) return h + potential_offset; - matches &= matches - 1; - } - } - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - __mmask64 mask; - sz_u512_vec_t h_vec, n_vec; - n_vec.zmm = _mm512_set1_epi8(n[0]); - - while (h_length >= 64) { - h_vec.zmm = _mm512_loadu_si512(h + h_length - 64); - mask = _mm512_cmpeq_epi8_mask(h_vec.zmm, n_vec.zmm); - if (mask) return h + h_length - 1 - sz_u64_clz(mask); - h_length -= 64; - } - - if (h_length) { - mask = _sz_u64_mask_until(h_length); - h_vec.zmm = _mm512_maskz_loadu_epi8(mask, h); - // Reuse the same `mask` variable to find the bit that doesn't match - mask = _mm512_mask_cmpeq_epu8_mask(mask, h_vec.zmm, n_vec.zmm); - if (mask) return h + 64 - sz_u64_clz(mask) - 1; - } - - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. 
- if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_rfind_byte_avx512(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into ZMM registers. - __mmask64 mask; - __mmask64 matches; - sz_u512_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; - n_first_vec.zmm = _mm512_set1_epi8(n[offset_first]); - n_mid_vec.zmm = _mm512_set1_epi8(n[offset_mid]); - n_last_vec.zmm = _mm512_set1_epi8(n[offset_last]); - - // Scan through the string. - sz_cptr_t h_reversed; - for (; h_length >= n_length + 64; h_length -= 64) { - h_reversed = h + h_length - n_length - 64 + 1; - h_first_vec.zmm = _mm512_loadu_si512(h_reversed + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h_reversed + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h_reversed + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_clz(matches); - if (n_length <= 3 || sz_equal_avx512(h + h_length - n_length - potential_offset, n, n_length)) - return h + h_length - n_length - potential_offset; - sz_assert((matches & ((sz_u64_t)1 << (63 - potential_offset))) != 0 && - "The bit must be set before we squash it"); - matches &= ~((sz_u64_t)1 << (63 - potential_offset)); - } - } - - // The "tail" of the function uses masked loads to process the remaining bytes. - { - mask = _sz_u64_mask_until(h_length - n_length + 1); - h_first_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_first); - h_mid_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_mid); - h_last_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_clz(matches); - if (n_length <= 3 || sz_equal_avx512(h + 64 - potential_offset - 1, n, n_length)) - return h + 64 - potential_offset - 1; - sz_assert((matches & ((sz_u64_t)1 << (63 - potential_offset))) != 0 && - "The bit must be set before we squash it"); - matches &= ~((sz_u64_t)1 << (63 - potential_offset)); - } - } - - return SZ_NULL_CHAR; -} - -#pragma clang attribute pop -#pragma GCC pop_options - -#pragma GCC push_options -#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vbmi", "bmi", "bmi2") -#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512dq,avx512vbmi,bmi,bmi2"))), \ - apply_to = function) - -/** - * @brief Computes the edit distance between two very short byte-strings using the AVX-512VBMI extensions. - * - * Applies to string lengths up to 63, and evaluates at most (63 * 2 + 1 = 127) diagonals, or just as many loop cycles. - * Supports an early exit, if the distance is bounded. - * Keeps all of the data and Levenshtein matrices skew diagonal in just a couple of registers. - * Benefits from the @b `vpermb` instructions, that can rotate the bytes across the entire ZMM register. 
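- * The traversal walks the skewed (anti-)diagonals of the Levenshtein matrix: every cell of a diagonal - * depends only on the two preceding diagonals, so all of its cells can be computed independently in - * byte-wide SIMD lanes, as long as no intermediate distance can exceed 255.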
- */ -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto63_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound) { - - sz_size_t const max_length = 63u; - sz_assert(shorter_length <= longer_length && "The 'shorter' string is longer than the 'longer' one."); - sz_assert(shorter_length < max_length && "The length must fit into 16-bit integer. Otherwise use serial variant."); - - // We are going to store 3 diagonals of the matrix, assuming each would fit into a single ZMM register. - // The length of the longest (main) diagonal would be `shorter_dim = (shorter_length + 1)`. - sz_size_t const shorter_dim = shorter_length + 1; - sz_size_t const longer_dim = longer_length + 1; - - // The next few buffers will be swapped around. - sz_u512_vec_t previous_vec, current_vec, next_vec; - sz_u512_vec_t gaps_vec, substitutions_vec; - - // Load the strings into ZMM registers - just once. - sz_u512_vec_t longer_vec, shorter_vec, shorter_rotated_vec, rotate_left_vec, rotate_right_vec, ones_vec, bound_vec; - longer_vec.zmm = _mm512_maskz_loadu_epi8(_sz_u64_mask_until(longer_length), longer); - rotate_left_vec.zmm = _mm512_set_epi8( // - 0, 63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, // - 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, // - 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, // - 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1); - rotate_right_vec.zmm = _mm512_set_epi8( // - 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, // - 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, // - 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, // - 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 63); - ones_vec.zmm = _mm512_set1_epi8(1); - bound_vec.zmm = _mm512_set1_epi8(bound <= 255 ? (sz_u8_t)bound : 255); - - // To simplify comparisons and traversals, we want to reverse the order of bytes in the shorter string. - for (sz_size_t i = 0; i != shorter_length; ++i) shorter_vec.u8s[63 - i] = shorter[i]; - shorter_rotated_vec.zmm = _mm512_permutexvar_epi8(rotate_right_vec.zmm, shorter_vec.zmm); - - // Let's say we are dealing with 3 and 5 letter words. - // The matrix will have size 4 x 6, parameterized as (shorter_dim x longer_dim). - // It will have: - // - 4 diagonals of increasing length, at positions: 0, 1, 2, 3. - // - 2 diagonals of fixed length, at positions: 4, 5. - // - 3 diagonals of decreasing length, at positions: 6, 7, 8. - sz_size_t const diagonals_count = shorter_dim + longer_dim - 1; - - // Initialize the first two diagonals: - // - // previous_vec.u8s[0] = 0; - // current_vec.u8s[0] = current_vec.u8s[1] = 1; - // - // We can do a similar thing with vector ops: - previous_vec.zmm = _mm512_setzero_si512(); - current_vec.zmm = _mm512_set1_epi8(1); - - // We skip diagonals 0 and 1, as they are trivial. - // We will start with diagonal 2, which has length 3, with the first and last elements being preset, - // so we are effectively computing just one value, as will be marked by a single set bit in - // the `next_diagonal_mask` on the very first iteration. - sz_size_t next_diagonal_index = 2; - __mmask64 next_diagonal_mask = 0; - - // Progress through the upper triangle of the Levenshtein matrix. 
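- // On every diagonal the classic recurrence applies: a substitution extends the diagonal two steps - // back (plus one on a mismatch), while insertions and deletions extend the two neighboring cells of - // the immediately preceding diagonal, each costing one extra edit; the minimum of those wins.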
- for (; next_diagonal_index != shorter_dim; ++next_diagonal_index) { - // After this iteration, the values at offset `0` and `next_diagonal_index` in the `next_vec` - // should be set to `next_diagonal_index`, but it's easier to broadcast the value to the whole vector, - // and later merge with a mask with new values. - next_vec.zmm = _mm512_set1_epi8((sz_u8_t)next_diagonal_index); - - // The mask also adds one set bit. - next_diagonal_mask = _kor_mask64(next_diagonal_mask, 1); - next_diagonal_mask = _kshiftli_mask64(next_diagonal_mask, 1); - - // Check for equality between string slices. - __mmask64 conflict_mask = _mm512_cmpneq_epi8_mask(longer_vec.zmm, shorter_rotated_vec.zmm); - substitutions_vec.zmm = _mm512_mask_add_epi8(previous_vec.zmm, conflict_mask, previous_vec.zmm, ones_vec.zmm); - substitutions_vec.zmm = _mm512_permutexvar_epi8(rotate_right_vec.zmm, substitutions_vec.zmm); - gaps_vec.zmm = _mm512_add_epi8( - // Insertions or deletions - _mm512_min_epu8(_mm512_permutexvar_epi8(rotate_right_vec.zmm, current_vec.zmm), current_vec.zmm), - ones_vec.zmm); - next_vec.zmm = _mm512_mask_min_epu8(next_vec.zmm, next_diagonal_mask, gaps_vec.zmm, substitutions_vec.zmm); - - // Mark the current skewed diagonal as the previous one and the next one as the current one. - previous_vec.zmm = current_vec.zmm; - current_vec.zmm = next_vec.zmm; - - // Shift the shorter string - shorter_rotated_vec.zmm = _mm512_permutexvar_epi8(rotate_right_vec.zmm, shorter_rotated_vec.zmm); - - // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound. - __mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm); - if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) { // - return SZ_SIZE_MAX; - } - } - - // Now let's handle the anti-diagonal band of the matrix, between the top and bottom triangles. - for (; next_diagonal_index != longer_dim; ++next_diagonal_index) { - // After this iteration, the value `shorted_dim - 1` in the `next_vec` - // should be set to `next_diagonal_index`, but it's easier to broadcast the value to the whole vector, - // and later merge with a mask with new values. - next_vec.zmm = _mm512_set1_epi8((sz_u8_t)next_diagonal_index); - - // Make sure we update the first entry. - next_diagonal_mask = _kor_mask64(next_diagonal_mask, 1); - - // Check for equality between string slices. - __mmask64 conflict_mask = _mm512_cmpneq_epi8_mask(longer_vec.zmm, shorter_rotated_vec.zmm); - substitutions_vec.zmm = _mm512_mask_add_epi8(previous_vec.zmm, conflict_mask, previous_vec.zmm, ones_vec.zmm); - gaps_vec.zmm = _mm512_add_epi8( - // Insertions or deletions - _mm512_min_epu8(current_vec.zmm, _mm512_permutexvar_epi8(rotate_left_vec.zmm, current_vec.zmm)), - ones_vec.zmm); - next_vec.zmm = _mm512_mask_min_epu8(next_vec.zmm, next_diagonal_mask, gaps_vec.zmm, substitutions_vec.zmm); - - // Mark the current skewed diagonal as the previous one and the next one as the current one. - previous_vec.zmm = _mm512_permutexvar_epi8(rotate_left_vec.zmm, current_vec.zmm); - current_vec.zmm = next_vec.zmm; - - // Let's shift the longer string now. - longer_vec.zmm = _mm512_permutexvar_epi8(rotate_left_vec.zmm, longer_vec.zmm); - - // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound. 
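- // `_ktestz_mask64_u8` returns 1 when the AND of its arguments is zero, i.e. when none of the lanes - // still participating in the diagonal holds a distance within the bound.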
- __mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm); - if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) { // - return SZ_SIZE_MAX; - } - } - - // Now let's handle the bottom right triangle. - for (; next_diagonal_index != diagonals_count; ++next_diagonal_index) { - - // Check for equality between string slices. - __mmask64 conflict_mask = _mm512_cmpneq_epi8_mask(longer_vec.zmm, shorter_rotated_vec.zmm); - substitutions_vec.zmm = _mm512_mask_add_epi8(previous_vec.zmm, conflict_mask, previous_vec.zmm, ones_vec.zmm); - gaps_vec.zmm = _mm512_add_epi8( - // Insertions or deletions - _mm512_min_epu8(current_vec.zmm, _mm512_permutexvar_epi8(rotate_left_vec.zmm, current_vec.zmm)), - ones_vec.zmm); - next_vec.zmm = _mm512_min_epu8(gaps_vec.zmm, substitutions_vec.zmm); - - // Mark the current skewed diagonal as the previous one and the next one as the current one. - previous_vec.zmm = _mm512_permutexvar_epi8(rotate_left_vec.zmm, current_vec.zmm); - current_vec.zmm = next_vec.zmm; - - // Let's shift the longer string now. - longer_vec.zmm = _mm512_permutexvar_epi8(rotate_left_vec.zmm, longer_vec.zmm); - - // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound. - __mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm); - if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) { // - return SZ_SIZE_MAX; - } - // In every following iteration we use a shorter prefix of each register, - // but we don't need to update the `next_diagonal_mask` anymore... except for the early exit. - next_diagonal_mask = _kshiftri_mask64(next_diagonal_mask, 1); - } - return current_vec.u8s[0]; -} - -/** - * @brief Computes the edit distance between two somewhat short byte-strings using the AVX-512VBMI extensions. - * - * Applies to string lengths up to 127, and evaluates at most (127 * 2 + 1 = 255) diagonals. - * Supports an early exit, if the distance is bounded. - * Uses a lot more CPU register space than the `upto63` variant. - * Benefits from the @b `vpermi2b` instructions, that can rotate the bytes in 2 registers at once. - * - * This may be one of the most frequently called kernels for: - * - source code analysis, assuming most lines are either under 80 or under 120 characters long. - * - DNA sequence alignment, as most short reads are 50-300 characters long. - */ -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto127_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound) { - sz_unused(shorter && shorter_length && longer && longer_length && bound); - return 0; -} - -/** - * @brief Computes the edit distance between two longer byte-strings using the AVX-512VBMI extensions. - * - * Applies to string lengths up to 255, and evaluates at most (255 * 2 + 1 = 511) diagonals. - * Supports an early exit, if the distance is bounded. - * Uses a lot more CPU register space than the `upto63` variant. - * - * Each of the 2x strings ends up occupying 4 ZMM registers, and each of 3x diagonals uses 4 ZMM registers. - * So 20x of the 32x are persistently occupied, and the rest are used for math temporarily. - * This is the largest space-efficient variant, as strings beyond 255 characters may require - * 16-bit accumulators, which would be a significant bottleneck.
- */ -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound) { - sz_unused(shorter && shorter_length && longer && longer_length && bound); - return 0; -} - -/** - * @brief Computes the edit distance between two longer byte-strings using the AVX-512VBMI extensions, - * assuming the upper distance bound cannot exceed 255, but the string length can be arbitrary. - * - * Applies to string lengths up to 255, and evaluates at most (255 * 2 + 1 = 511) diagonals. - * Supports an early exit, if the distance is bounded. - * Uses a lot more CPU register space than the `upto63` variant. - * - * Each of the 2x strings ends up occupying 4 ZMM registers, and each of 3x diagonals uses 4 ZMM registers. - * So 20x of the 32x are persistently occupied, and the rest are used for math temporarily. - * This is the largest space-efficient variant, as strings beyond 255 characters may require - * 16-bit accumulators, which would be a significant bottleneck. - */ -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto255bound_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound) { - sz_unused(shorter && shorter_length && longer && longer_length && bound); - return 0; -} - -/** - * @brief Computes the edit distance between two mid-length UTF-8-strings using the AVX-512VBMI extensions. - * - * Applies to string lengths up to 127, and evaluates at most (127 * 2 + 1 = 255) diagonals. - * Supports an early exit, if the distance is bounded. - * Benefits from the @b `valignd` instructions used to rotate unpacked UTF-32 Unicode codepoints. - * - * Each string is unpacked into 128 characters * 4 bytes per character / 64 bytes per register = 8 registers. - * - */ -SZ_INTERNAL sz_size_t _sz_edit_distance_utf8_skewed_diagonals_upto127_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound) { - sz_unused(shorter && shorter_length && longer && longer_length && bound); - return 0; -} - -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto65k_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { + switch (length) { + case 0: return 0; - sz_unused(shorter && longer && bound && alloc); + // Texts under 7 bytes long are definitely below the largest prime.
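+ // Even with the larger 257 base and the shifted characters, a 7-byte prefix stays far below + // the prime, which sits just under 2^64, so no `_sz_prime_mod` reduction is needed in these cases.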
+ case 1: + hash_low = _sz_shift_low(text[0]); + hash_high = _sz_shift_high(text[0]); + break; + case 2: + hash_low = _sz_shift_low(text[0]) * 31ull + _sz_shift_low(text[1]); + hash_high = _sz_shift_high(text[0]) * 257ull + _sz_shift_high(text[1]); + break; + case 3: + hash_low = _sz_shift_low(text[0]) * 31ull * 31ull + // + _sz_shift_low(text[1]) * 31ull + // + _sz_shift_low(text[2]); + hash_high = _sz_shift_high(text[0]) * 257ull * 257ull + // + _sz_shift_high(text[1]) * 257ull + // + _sz_shift_high(text[2]); + break; + case 4: + hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull + // + _sz_shift_low(text[1]) * 31ull * 31ull + // + _sz_shift_low(text[2]) * 31ull + // + _sz_shift_low(text[3]); + hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull + // + _sz_shift_high(text[1]) * 257ull * 257ull + // + _sz_shift_high(text[2]) * 257ull + // + _sz_shift_high(text[3]); + break; + case 5: + hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull + // + _sz_shift_low(text[1]) * 31ull * 31ull * 31ull + // + _sz_shift_low(text[2]) * 31ull * 31ull + // + _sz_shift_low(text[3]) * 31ull + // + _sz_shift_low(text[4]); + hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull + // + _sz_shift_high(text[1]) * 257ull * 257ull * 257ull + // + _sz_shift_high(text[2]) * 257ull * 257ull + // + _sz_shift_high(text[3]) * 257ull + // + _sz_shift_high(text[4]); + break; + case 6: + hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull * 31ull + // + _sz_shift_low(text[1]) * 31ull * 31ull * 31ull * 31ull + // + _sz_shift_low(text[2]) * 31ull * 31ull * 31ull + // + _sz_shift_low(text[3]) * 31ull * 31ull + // + _sz_shift_low(text[4]) * 31ull + // + _sz_shift_low(text[5]); + hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull * 257ull + // + _sz_shift_high(text[1]) * 257ull * 257ull * 257ull * 257ull + // + _sz_shift_high(text[2]) * 257ull * 257ull * 257ull + // + _sz_shift_high(text[3]) * 257ull * 257ull + // + _sz_shift_high(text[4]) * 257ull + // + _sz_shift_high(text[5]); + break; + case 7: + hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull * 31ull * 31ull + // + _sz_shift_low(text[1]) * 31ull * 31ull * 31ull * 31ull * 31ull + // + _sz_shift_low(text[2]) * 31ull * 31ull * 31ull * 31ull + // + _sz_shift_low(text[3]) * 31ull * 31ull * 31ull + // + _sz_shift_low(text[4]) * 31ull * 31ull + // + _sz_shift_low(text[5]) * 31ull + // + _sz_shift_low(text[6]); + hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull * 257ull * 257ull + // + _sz_shift_high(text[1]) * 257ull * 257ull * 257ull * 257ull * 257ull + // + _sz_shift_high(text[2]) * 257ull * 257ull * 257ull * 257ull + // + _sz_shift_high(text[3]) * 257ull * 257ull * 257ull + // + _sz_shift_high(text[4]) * 257ull * 257ull + // + _sz_shift_high(text[5]) * 257ull + // + _sz_shift_high(text[6]); + break; + default: + // Unroll the first seven cycles: + hash_low = hash_low * 31ull + _sz_shift_low(text[0]); + hash_high = hash_high * 257ull + _sz_shift_high(text[0]); + hash_low = hash_low * 31ull + _sz_shift_low(text[1]); + hash_high = hash_high * 257ull + _sz_shift_high(text[1]); + hash_low = hash_low * 31ull + _sz_shift_low(text[2]); + hash_high = hash_high * 257ull + _sz_shift_high(text[2]); + hash_low = hash_low * 31ull + _sz_shift_low(text[3]); + hash_high = hash_high * 257ull + _sz_shift_high(text[3]); + hash_low = hash_low * 31ull + _sz_shift_low(text[4]); + hash_high = hash_high * 257ull + _sz_shift_high(text[4]); + hash_low = hash_low * 31ull + 
_sz_shift_low(text[5]); + hash_high = hash_high * 257ull + _sz_shift_high(text[5]); + hash_low = hash_low * 31ull + _sz_shift_low(text[6]); + hash_high = hash_high * 257ull + _sz_shift_high(text[6]); + text += 7; - // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. - sz_memory_allocator_t global_alloc; - if (!alloc) { - sz_memory_allocator_init_default(&global_alloc); - alloc = &global_alloc; + // Iterate throw the rest with the modulus: + for (; text != text_end; ++text) { + hash_low = hash_low * 31ull + _sz_shift_low(text[0]); + hash_high = hash_high * 257ull + _sz_shift_high(text[0]); + // Wrap the hashes around: + hash_low = _sz_prime_mod(hash_low); + hash_high = _sz_prime_mod(hash_high); + } + break; } - // TODO: Generalize! - sz_size_t const max_length = 256u * 256u; - sz_assert(shorter_length <= longer_length && "The 'shorter' string is longer than the 'longer' one."); - sz_assert(shorter_length < max_length && "The length must fit into 16-bit integer. Otherwise use serial variant."); - sz_unused(longer_length && bound && max_length); - -#if 0 - // We are going to store 3 diagonals of the matrix. - // The length of the longest (main) diagonal would be `shorter_dim = (shorter_length + 1)`. - sz_size_t const shorter_dim = shorter_length + 1; - sz_size_t const longer_dim = longer_length + 1; - // Unlike the serial version, we also want to avoid reverse-order iteration over teh shorter string. - // So let's allocate a bit more memory and reverse-export our shorter string into that buffer. - sz_size_t const buffer_length = sizeof(sz_u16_t) * longer_dim * 3 + shorter_length; - sz_u16_t *const distances = (sz_u16_t *)alloc->allocate(buffer_length, alloc->handle); - if (!distances) return SZ_SIZE_MAX; - - // The next few pointers will be swapped around. - sz_u16_t *previous_distances = distances; - sz_u16_t *current_distances = previous_distances + longer_dim; - sz_u16_t *next_distances = current_distances + longer_dim; - sz_ptr_t const shorter_reversed = (sz_ptr_t)(next_distances + longer_dim); - - // Export the reversed string into the buffer. - for (sz_size_t i = 0; i != shorter_length; ++i) shorter_reversed[i] = shorter[shorter_length - 1 - i]; + return _sz_hash_mix(hash_low, hash_high); +} - // Initialize the first two diagonals: - previous_distances[0] = 0; - current_distances[0] = current_distances[1] = 1; +SZ_PUBLIC void sz_hashes_serial(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // + sz_hash_callback_t callback, void *callback_handle) { - // Using ZMM registers, we can process 32x 16-bit values at once, - // storing 16 bytes of each string in YMM registers. - sz_u512_vec_t insertions_vec, deletions_vec, substitutions_vec, next_vec; - sz_u512_vec_t ones_u16_vec; - ones_u16_vec.zmm = _mm512_set1_epi16(1); + if (length < window_length || !window_length) return; + sz_u8_t const *text = (sz_u8_t const *)start; + sz_u8_t const *text_end = text + length; - // This is a mixed-precision implementation, using 8-bit representations for part of the operations. - // Even there, in case `SZ_USE_X86_AVX2=0`, let's use the `sz_u512_vec_t` type, addressing the first YMM halfs. - sz_u512_vec_t shorter_vec, longer_vec; - sz_u512_vec_t ones_u8_vec; - ones_u8_vec.ymms[0] = _mm256_set1_epi8(1); + // Prepare the `prime ^ window_length` values, that we are going to use for modulo arithmetic. 
+ sz_u64_t prime_power_low = 1, prime_power_high = 1; + for (sz_size_t i = 0; i + 1 < window_length; ++i) + prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME, + prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME; - // Let's say we are dealing with 3 and 5 letter words. - // The matrix will have size 4 x 6, parameterized as (shorter_dim x longer_dim). - // It will have: - // - 4 diagonals of increasing length, at positions: 0, 1, 2, 3. - // - 2 diagonals of fixed length, at positions: 4, 5. - // - 3 diagonals of decreasing length, at positions: 6, 7, 8. - sz_size_t const diagonals_count = shorter_dim + longer_dim - 1; + // Compute the initial hash value for the first window. + sz_u64_t hash_low = 0, hash_high = 0, hash_mix; + for (sz_u8_t const *first_end = text + window_length; text < first_end; ++text) + hash_low = (hash_low * 31ull + _sz_shift_low(*text)) % SZ_U64_MAX_PRIME, + hash_high = (hash_high * 257ull + _sz_shift_high(*text)) % SZ_U64_MAX_PRIME; - // Progress through the upper triangle of the Levenshtein matrix. - sz_size_t next_diagonal_index = 2; - for (; next_diagonal_index != shorter_dim; ++next_diagonal_index) { - sz_size_t const next_diagonal_length = next_diagonal_index + 1; - for (sz_size_t offset_within_diagonal = 0; offset_within_diagonal + 2 < next_diagonal_length;) { - sz_u32_t remaining_length = (sz_u32_t)(next_diagonal_length - offset_within_diagonal - 2); - sz_u32_t register_length = remaining_length < 32 ? remaining_length : 32; - sz_u32_t remaining_length_mask = _bzhi_u32(0xFFFFFFFFu, register_length); - longer_vec.ymms[0] = _mm256_maskz_loadu_epi8(remaining_length_mask, longer + offset_within_diagonal); - // Our original code addressed the shorter string `[next_diagonal_index - offset_within_diagonal - 2]` - // for growing `offset_within_diagonal`. If the `shorter` string was reversed, the - // `[next_diagonal_index - offset_within_diagonal - 2]` would be equal to `[shorter_length - 1 - - // next_diagonal_index + offset_within_diagonal + 2]`. Which simplified would be equal to - // `[shorter_length - next_diagonal_index + offset_within_diagonal + 1]`. - shorter_vec.ymms[0] = _mm256_maskz_loadu_epi8( // - remaining_length_mask, - shorter_reversed + shorter_length - next_diagonal_index + offset_within_diagonal + 1); - // For substitutions, perform the equality comparison using AVX2 instead of AVX-512 - // to get the result as a vector, instead of a bitmask. Adding 1 to every scalar we can overflow - // transforming from {0xFF, 0} values to {0, 1} values - exactly what we need. Then - upcast to 16-bit. - substitutions_vec.zmm = _mm512_cvtepi8_epi16( // - _mm256_add_epi8(_mm256_cmpeq_epi8(longer_vec.ymms[0], shorter_vec.ymms[0]), ones_u8_vec.ymms[0])); - substitutions_vec.zmm = _mm512_add_epi16( // - substitutions_vec.zmm, - _mm512_maskz_loadu_epi16(remaining_length_mask, previous_distances + offset_within_diagonal)); - // For insertions and deletions, on modern hardware, it's faster to issue two separate loads, - // than rotate the bytes in the ZMM register. - insertions_vec.zmm = - _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + offset_within_diagonal); - deletions_vec.zmm = - _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + offset_within_diagonal + 1); - // First get the minimum of insertions and deletions. 
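/* The deleted skewed-diagonal kernels around this hunk all share one dataflow: walk the
 * Levenshtein matrix one anti-diagonal at a time, keeping only three diagonals alive, so each
 * step depends only on the previous two. A minimal scalar sketch of that traversal with plain
 * `size_t` cells instead of packed 8- or 16-bit lanes (a reference for the recurrence, not the
 * library's implementation). */
#include <stddef.h>
#include <stdlib.h>

static size_t levenshtein_antidiagonal(char const *a, size_t m, char const *b, size_t n) {
    if (m == 0) return n;
    if (n == 0) return m;
    size_t *buffer = (size_t *)malloc((m + 1) * 3 * sizeof(size_t)); // Three diagonals, indexed by `i` in `a`.
    if (!buffer) return (size_t)-1;                                  // Mirrors the SZ_SIZE_MAX failure convention.
    size_t *prev = buffer, *cur = prev + (m + 1), *next = cur + (m + 1);
    prev[0] = 0;            // Diagonal 0 holds only D(0,0).
    cur[0] = 1, cur[1] = 1; // Diagonal 1 holds D(0,1) and D(1,0).
    for (size_t d = 2; d <= m + n; ++d) {
        size_t i_min = d > n ? d - n : 0, i_max = d < m ? d : m;
        for (size_t i = i_min; i <= i_max; ++i) {
            size_t j = d - i;
            if (i == 0) { next[i] = j; continue; } // First row of the matrix.
            if (j == 0) { next[i] = i; continue; } // First column of the matrix.
            size_t deletion = cur[i - 1] + 1;      // D(i-1, j) lives on diagonal d-1.
            size_t insertion = cur[i] + 1;         // D(i, j-1) lives on diagonal d-1.
            size_t substitution = prev[i - 1] + (a[i - 1] != b[j - 1]);
            size_t best = deletion < insertion ? deletion : insertion;
            next[i] = best < substitution ? best : substitution;
        }
        size_t *temporary = prev; // Rotate the three buffers, as the SIMD kernels do with registers.
        prev = cur, cur = next, next = temporary;
    }
    size_t result = cur[m]; // The last diagonal (d = m + n) holds only D(m, n).
    free(buffer);
    return result;
}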
- next_vec.zmm = _mm512_add_epi16(_mm512_min_epu16(insertions_vec.zmm, deletions_vec.zmm), ones_u16_vec.zmm); - next_vec.zmm = _mm512_min_epu16(next_vec.zmm, substitutions_vec.zmm); - _mm512_mask_storeu_epi16(next_distances + offset_within_diagonal + 1, remaining_length_mask, next_vec.zmm); - offset_within_diagonal += register_length; - } - // Don't forget to populate the first row and the first column of the Levenshtein matrix. - next_distances[0] = next_distances[next_diagonal_length - 1] = (sz_u16_t)next_diagonal_index; - // Perform a circular rotation (three-way swap) of those buffers, to reuse the memory. - sz_u16_t *temporary = previous_distances; - previous_distances = current_distances; - current_distances = next_distances; - next_distances = temporary; - } + // In most cases the fingerprint length will be a power of two. + hash_mix = _sz_hash_mix(hash_low, hash_high); + callback((sz_cptr_t)text, window_length, hash_mix, callback_handle); - // By now we've scanned through the upper triangle of the matrix, where each subsequent iteration results in a - // larger diagonal. From now onwards, we will be shrinking. Instead of adding value equal to the skewed diagonal - // index on either side, we will be cropping those values out. - for (; next_diagonal_index != diagonals_count; ++next_diagonal_index) { - sz_size_t const next_diagonal_length = diagonals_count - next_diagonal_index; - for (sz_size_t i = 0; i != next_diagonal_length;) { - sz_u32_t remaining_length = (sz_u32_t)(next_diagonal_length - i); - sz_u32_t register_length = remaining_length < 32 ? remaining_length : 32; - sz_u32_t remaining_length_mask = _bzhi_u32(0xFFFFFFFFu, register_length); - longer_vec.ymms[0] = _mm256_maskz_loadu_epi8(remaining_length_mask, longer + next_diagonal_index - n + i); - // Our original code addressed the shorter string `[shorter_length - 1 - i]` for growing `i`. - // If the `shorter` string was reversed, the `[shorter_length - 1 - i]` would - // be equal to `[shorter_length - 1 - shorter_length + 1 + i]`. - // Which simplified would be equal to just `[i]`. Beautiful! - shorter_vec.ymms[0] = _mm256_maskz_loadu_epi8(remaining_length_mask, shorter_reversed + i); - // For substitutions, perform the equality comparison using AVX2 instead of AVX-512 - // to get the result as a vector, instead of a bitmask. The compare it against the accumulated - // substitution costs. - substitutions_vec.zmm = _mm512_cvtepi8_epi16( // - _mm256_add_epi8(_mm256_cmpeq_epi8(longer_vec.ymms[0], shorter_vec.ymms[0]), ones_u8_vec.ymms[0])); - substitutions_vec.zmm = _mm512_add_epi16( // - substitutions_vec.zmm, _mm512_maskz_loadu_epi16(remaining_length_mask, previous_distances + i)); - // For insertions and deletions, on modern hardware, it's faster to issue two separate loads, - // than rotate the bytes in the ZMM register. - insertions_vec.zmm = _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + i); - deletions_vec.zmm = _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + i + 1); - // First get the minimum of insertions and deletions. - next_vec.zmm = _mm512_add_epi16(_mm512_min_epu16(insertions_vec.zmm, deletions_vec.zmm), ones_u16_vec.zmm); - next_vec.zmm = _mm512_min_epu16(next_vec.zmm, substitutions_vec.zmm); - _mm512_mask_storeu_epi16(next_distances + i, remaining_length_mask, next_vec.zmm); - i += register_length; + // Compute the hash value for every window, exporting into the fingerprint, + // using the expensive modulo operation. 
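/* A scalar sketch of the window update performed by the loop that follows (and, four windows at
 * a time, by the AVX2 variant further down): the outgoing byte is removed with the precomputed
 * `base ^ (window_length - 1) mod prime` weight, the incoming byte is appended, and a hash is
 * reported once per `step` windows. A deliberately small stand-in prime is used so every product
 * fits in 64 bits; only the base-257 hash is shown, the base-31 one follows the same pattern.
 * The `cycle & (step - 1)` test matches the source and assumes `step` is a power of two. */
#include <assert.h>
#include <stdint.h>
#include <stddef.h>

#define DEMO_BASE 257ull
#define DEMO_PRIME 4294967291ull // Largest 32-bit prime, a stand-in for the library's 64-bit one.

static uint64_t demo_window_hash(uint8_t const *window, size_t window_length) {
    uint64_t hash = 0;
    for (size_t i = 0; i != window_length; ++i) hash = (hash * DEMO_BASE + window[i]) % DEMO_PRIME;
    return hash;
}

static void demo_rolling_hashes(uint8_t const *text, size_t length, size_t window_length, size_t step) {
    assert(step && (step & (step - 1)) == 0 && "The `& (step - 1)` trick assumes a power-of-two step.");
    uint64_t power = 1; // Weight of the outgoing byte: base ^ (window_length - 1) mod prime.
    for (size_t i = 0; i + 1 < window_length; ++i) power = power * DEMO_BASE % DEMO_PRIME;
    uint64_t hash = demo_window_hash(text, window_length);
    for (size_t i = window_length; i < length; ++i) {
        uint64_t removed = text[i - window_length] * power % DEMO_PRIME;
        hash = (hash + DEMO_PRIME - removed) % DEMO_PRIME; // Drop the oldest byte.
        hash = (hash * DEMO_BASE + text[i]) % DEMO_PRIME;  // Append the newest byte.
        assert(hash == demo_window_hash(text + i - window_length + 1, window_length));
        size_t cycle = i - window_length + 1;
        if ((cycle & (step - 1)) == 0) { /* export `hash` into the fingerprint here */ }
    }
}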
+ sz_size_t cycles = 1; + sz_size_t const step_mask = step - 1; + for (; text < text_end; ++text, ++cycles) { + // Discard one character: + hash_low -= _sz_shift_low(*(text - window_length)) * prime_power_low; + hash_high -= _sz_shift_high(*(text - window_length)) * prime_power_high; + // And add a new one: + hash_low = 31ull * hash_low + _sz_shift_low(*text); + hash_high = 257ull * hash_high + _sz_shift_high(*text); + // Wrap the hashes around: + hash_low = _sz_prime_mod(hash_low); + hash_high = _sz_prime_mod(hash_high); + // Mix only if we've skipped enough hashes. + if ((cycles & step_mask) == 0) { + hash_mix = _sz_hash_mix(hash_low, hash_high); + callback((sz_cptr_t)text, window_length, hash_mix, callback_handle); } - - // Perform a circular rotation (three-way swap) of those buffers, to reuse the memory, this time, with a shift, - // dropping the first element in the current array. - sz_u16_t *temporary = previous_distances; - previous_distances = current_distances + 1; - current_distances = next_distances; - next_distances = temporary; } - - // Cache scalar before `free` call. - sz_size_t result = current_distances[0]; - alloc->free(distances, buffer_length, alloc->handle); - return result; -#endif - return 0; } -SZ_INTERNAL sz_size_t sz_edit_distance_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - - // Bounded computations may exit early. - int const is_bounded = bound < longer_length; - if (is_bounded) { - // If one of the strings is empty - the edit distance is equal to the length of the other one. - if (longer_length == 0) return sz_min_of_two(shorter_length, bound); - if (shorter_length == 0) return sz_min_of_two(longer_length, bound); - // If the difference in length is beyond the `bound`, there is no need to check at all. - if (longer_length - shorter_length > bound) return bound; - } +#undef _sz_shift_low +#undef _sz_shift_high +#undef _sz_hash_mix +#undef _sz_prime_mod - // Make sure the shorter string is actually shorter. - if (shorter_length > longer_length) { - sz_cptr_t temporary = shorter; - shorter = longer; - longer = temporary; - sz_size_t temporary_length = shorter_length; - shorter_length = longer_length; - longer_length = temporary_length; - } +#pragma endregion // Serial Implementation - // Dispatch the right implementation based on the length of the strings. - if (longer_length < 64u) - return _sz_edit_distance_skewed_diagonals_upto63_avx512( // - shorter, shorter_length, longer, longer_length, bound); - // else if (longer_length < 256u * 256u) - // return _sz_edit_distance_skewed_diagonals_upto65k_avx512( // - // shorter, shorter_length, longer, longer_length, bound, alloc); - else - return sz_edit_distance_serial(shorter, shorter_length, longer, longer_length, bound, alloc); -} +/* AVX2 implementation of the string search algorithms for Haswell processors and newer. + * Very minimalistic (compared to AVX-512), but still faster than the serial implementation. + */ +#pragma region Haswell Implementation +#if SZ_USE_HASWELL +#pragma GCC push_options +#pragma GCC target("haswell") +#pragma clang attribute push(__attribute__((target("haswell"))), apply_to = function) -SZ_PUBLIC sz_u64_t sz_checksum_avx512(sz_cptr_t text, sz_size_t length) { +SZ_PUBLIC sz_u64_t sz_checksum_avx2(sz_cptr_t text, sz_size_t length) { // The naive implementation of this function is very simple. // It assumes the CPU is great at handling unaligned "loads". 
// - // A typical AWS Sapphire Rapids instance can have 48 KB x 2 blocks of L1 data cache per core, - // 2 MB x 2 blocks of L2 cache per core, and one shared 60 MB buffer of L3 cache. - // With two strings, we may consider the overal workload huge, if each exceeds 1 MB in length. - int const is_huge = length >= 1ull * 1024ull * 1024ull; - sz_u512_vec_t text_vec, sums_vec; + // A typical AWS Skylake instance can have 32 KB x 2 blocks of L1 data cache per core, + // 1 MB x 2 blocks of L2 cache per core, and one shared L3 cache buffer. + // For now, let's avoid the cases beyond the L2 size. + int is_huge = length > 1ull * 1024ull * 1024ull; // When the buffer is small, there isn't much to innovate. - if (length <= 16) { - __mmask16 mask = _sz_u16_mask_until(length); - text_vec.xmms[0] = _mm_maskz_loadu_epi8(mask, text); - sums_vec.xmms[0] = _mm_sad_epu8(text_vec.xmms[0], _mm_setzero_si128()); - sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_vec.xmms[0]); - sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_vec.xmms[0], 1); - return low + high; - } - else if (length <= 32) { - __mmask32 mask = _sz_u32_mask_until(length); - text_vec.ymms[0] = _mm256_maskz_loadu_epi8(mask, text); - sums_vec.ymms[0] = _mm256_sad_epu8(text_vec.ymms[0], _mm256_setzero_si256()); + if (length <= 32) { return sz_checksum_serial(text, length); } + else if (!is_huge) { + sz_u256_vec_t text_vec, sums_vec; + sums_vec.ymm = _mm256_setzero_si256(); + for (; length >= 32; text += 32, length -= 32) { + text_vec.ymm = _mm256_lddqu_si256((__m256i const *)text); + sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); + } // Accumulating 256 bits is harders, as we need to extract the 128-bit sums first. - __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymms[0]); - __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymms[0], 1); + __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymm); + __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymm, 1); __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm); sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm); sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1); - return low + high; - } - else if (length <= 64) { - __mmask64 mask = _sz_u64_mask_until(length); - text_vec.zmm = _mm512_maskz_loadu_epi8(mask, text); - sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()); - return _mm512_reduce_add_epi64(sums_vec.zmm); - } - else if (!is_huge) { - sz_size_t head_length = (64 - ((sz_size_t)text % 64)) % 64; // 63 or less. - sz_size_t tail_length = (sz_size_t)(text + length) % 64; // 63 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 64. 
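/* Both the AVX-512 code being removed here and its AVX2 replacement reduce a byte checksum to
 * `vpsadbw` against a zero vector: the sum of absolute differences between 8 unsigned bytes and
 * 8 zeros is just the sum of those bytes, delivered in each 64-bit lane. A small self-checking
 * sketch of that primitive (standalone, compile with AVX2 enabled, e.g. -mavx2): */
#include <assert.h>
#include <immintrin.h>
#include <stdint.h>

static uint64_t sum_32_bytes_avx2(uint8_t const *data) {
    __m256i text = _mm256_loadu_si256((__m256i const *)data);
    __m256i sums = _mm256_sad_epu8(text, _mm256_setzero_si256()); // Four 64-bit lanes, each the sum of 8 bytes.
    __m128i both = _mm_add_epi64(_mm256_castsi256_si128(sums), _mm256_extracti128_si256(sums, 1));
    return (uint64_t)_mm_cvtsi128_si64(both) + (uint64_t)_mm_extract_epi64(both, 1);
}

int main(void) {
    uint8_t data[32];
    uint64_t expected = 0;
    for (int i = 0; i != 32; ++i) data[i] = (uint8_t)(i * 7 + 3), expected += data[i];
    assert(sum_32_bytes_avx2(data) == expected);
    return 0;
}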
- __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - text_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, text); - sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()); - for (text += head_length; body_length >= 64; text += 64, body_length -= 64) { - text_vec.zmm = _mm512_load_si512((__m512i const *)text); - sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512())); - } - text_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, text); - sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512())); - return _mm512_reduce_add_epi64(sums_vec.zmm); + sz_u64_t result = low + high; + if (length) result += sz_checksum_serial(text, length); + return result; } // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use. - // - // 1. Moving in both directions to maximize the throughput, when fetching from multiple - // memory pages. Also helps with cache set-associativity issues, as we won't always - // be fetching the same entries in the lookup table. - // 2. Using non-temporal stores to avoid polluting the cache. - // 3. Prefetching the next cache line, to avoid stalling the CPU. This generally useless - // for predictable patterns, so disregard this advice. - // - // Bidirectional traversal generally adds about 10% to such algorithms. + // Most notably, we can avoid populating the cache with the entire buffer, and instead traverse it in 2 directions. else { - sz_u512_vec_t text_reversed_vec, sums_reversed_vec; - sz_size_t head_length = (64 - ((sz_size_t)text % 64)) % 64; - sz_size_t tail_length = (sz_size_t)(text + length) % 64; - sz_size_t body_length = length - head_length - tail_length; - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); + sz_size_t head_length = (32 - ((sz_size_t)text % 32)) % 32; // 31 or less. + sz_size_t tail_length = (sz_size_t)(text + length) % 32; // 31 or less. + sz_size_t body_length = length - head_length - tail_length; // Multiple of 32. + sz_u64_t result = 0; - text_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, text); - sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()); - text_reversed_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, text + head_length + body_length); - sums_reversed_vec.zmm = _mm512_sad_epu8(text_reversed_vec.zmm, _mm512_setzero_si512()); + // Handle the head + while (head_length--) result += *text++; - // Now in the main loop, we can use non-temporal loads and stores, - // performing the operation in both directions. - for (text += head_length; body_length >= 128; text += 64, text += 64, body_length -= 128) { - text_vec.zmm = _mm512_stream_load_si512((__m512i *)(text)); - sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512())); - text_reversed_vec.zmm = _mm512_stream_load_si512((__m512i *)(text + body_length - 64)); - sums_reversed_vec.zmm = - _mm512_add_epi64(sums_reversed_vec.zmm, _mm512_sad_epu8(text_reversed_vec.zmm, _mm512_setzero_si512())); + sz_u256_vec_t text_vec, sums_vec; + sums_vec.ymm = _mm256_setzero_si256(); + // Fill the aligned body of the buffer. 
+ if (!is_huge) { + for (; body_length >= 32; text += 32, body_length -= 32) { + text_vec.ymm = _mm256_stream_load_si256((__m256i const *)text); + sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); + } } - if (body_length >= 64) { - text_vec.zmm = _mm512_stream_load_si512((__m512i *)(text)); - sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512())); + // When the biffer is huge, we can traverse it in 2 directions. + else { + sz_u256_vec_t text_reversed_vec, sums_reversed_vec; + sums_reversed_vec.ymm = _mm256_setzero_si256(); + for (; body_length >= 64; text += 64, body_length -= 64) { + text_vec.ymm = _mm256_stream_load_si256((__m256i *)(text)); + sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); + text_reversed_vec.ymm = _mm256_stream_load_si256((__m256i *)(text + body_length - 64)); + sums_reversed_vec.ymm = _mm256_add_epi64( + sums_reversed_vec.ymm, _mm256_sad_epu8(text_reversed_vec.ymm, _mm256_setzero_si256())); + } + if (body_length >= 32) { + text_vec.ymm = _mm256_stream_load_si256((__m256i *)(text)); + sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); + } + sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, sums_reversed_vec.ymm); } - return _mm512_reduce_add_epi64(_mm512_add_epi64(sums_vec.zmm, sums_reversed_vec.zmm)); + // Handle the tail + while (tail_length--) result += *text++; + + // Accumulating 256 bits is harders, as we need to extract the 128-bit sums first. + __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymm); + __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymm, 1); + __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm); + sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm); + sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1); + result += low + high; + return result; } } -SZ_PUBLIC void sz_hashes_avx512(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle) { +/** + * @brief There is no AVX2 instruction for fast multiplication of 64-bit integers. + * This implementation is coming from Agner Fog's Vector Class Library. + */ +SZ_INTERNAL __m256i _mm256_mul_epu64(__m256i a, __m256i b) { + __m256i bswap = _mm256_shuffle_epi32(b, 0xB1); + __m256i prodlh = _mm256_mullo_epi32(a, bswap); + __m256i zero = _mm256_setzero_si256(); + __m256i prodlh2 = _mm256_hadd_epi32(prodlh, zero); + __m256i prodlh3 = _mm256_shuffle_epi32(prodlh2, 0x73); + __m256i prodll = _mm256_mul_epu32(a, b); + __m256i prod = _mm256_add_epi64(prodll, prodlh3); + return prod; +} + +SZ_PUBLIC void sz_hashes_avx2(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // + sz_hash_callback_t callback, void *callback_handle) { if (length < window_length || !window_length) return; if (length < 4 * window_length) { @@ -5696,57 +461,58 @@ SZ_PUBLIC void sz_hashes_avx512(sz_cptr_t start, sz_size_t length, sz_size_t win sz_u8_t const *text_fourth = text_first + min_hashes_per_thread * 3; sz_u8_t const *text_end = text_first + length; - // Broadcast the global constants into the registers. - // Both high and low hashes will work with the same prime and golden ratio. 
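/* The `_mm256_mul_epu64` helper introduced above builds the low 64 bits of a 64 x 64 multiply
 * out of 32-bit pieces, since AVX2 has no 64-bit integer multiply. A standalone cross-check of
 * that trick against scalar multiplication; the helper is restated here only so the snippet
 * compiles on its own (requires AVX2): */
#include <assert.h>
#include <immintrin.h>
#include <stdint.h>

static __m256i mul_epu64_avx2(__m256i a, __m256i b) {
    __m256i bswap = _mm256_shuffle_epi32(b, 0xB1);                       // Swap the 32-bit halves of every lane.
    __m256i prodlh = _mm256_mullo_epi32(a, bswap);                       // aL*bH and aH*bL as 32-bit products.
    __m256i prodlh2 = _mm256_hadd_epi32(prodlh, _mm256_setzero_si256()); // aL*bH + aH*bL per 64-bit lane.
    __m256i prodlh3 = _mm256_shuffle_epi32(prodlh2, 0x73);               // Move those sums into the high halves.
    __m256i prodll = _mm256_mul_epu32(a, b);                             // aL*bL as full 64-bit products.
    return _mm256_add_epi64(prodll, prodlh3);
}

int main(void) {
    uint64_t lhs[4] = {31ull, 257ull, 0x0123456789ABCDEFull, 18446744073709551557ull};
    uint64_t rhs[4] = {11400714819323198485ull, 3ull, 0x0FEDCBA987654321ull, 2ull};
    uint64_t out[4];
    __m256i product = mul_epu64_avx2(_mm256_loadu_si256((__m256i const *)lhs), //
                                     _mm256_loadu_si256((__m256i const *)rhs));
    _mm256_storeu_si256((__m256i *)out, product);
    for (int i = 0; i != 4; ++i) assert(out[i] == lhs[i] * rhs[i]); // Both sides wrap modulo 2^64.
    return 0;
}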
- sz_u512_vec_t prime_vec, golden_ratio_vec; - prime_vec.zmm = _mm512_set1_epi64(SZ_U64_MAX_PRIME); - golden_ratio_vec.zmm = _mm512_set1_epi64(11400714819323198485ull); - // Prepare the `prime ^ window_length` values, that we are going to use for modulo arithmetic. sz_u64_t prime_power_low = 1, prime_power_high = 1; for (sz_size_t i = 0; i + 1 < window_length; ++i) prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME, prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME; - // We will be evaluating 4 offsets at a time with 2 different hash functions. - // We can fit all those 8 state variables in each of the following ZMM registers. - sz_u512_vec_t base_vec, prime_power_vec, shift_vec; - base_vec.zmm = _mm512_set_epi64(31ull, 31ull, 31ull, 31ull, 257ull, 257ull, 257ull, 257ull); - shift_vec.zmm = _mm512_set_epi64(0ull, 0ull, 0ull, 0ull, 77ull, 77ull, 77ull, 77ull); - prime_power_vec.zmm = _mm512_set_epi64(prime_power_low, prime_power_low, prime_power_low, prime_power_low, - prime_power_high, prime_power_high, prime_power_high, prime_power_high); + // Broadcast the constants into the registers. + sz_u256_vec_t prime_vec, golden_ratio_vec; + sz_u256_vec_t base_low_vec, base_high_vec, prime_power_low_vec, prime_power_high_vec, shift_high_vec; + base_low_vec.ymm = _mm256_set1_epi64x(31ull); + base_high_vec.ymm = _mm256_set1_epi64x(257ull); + shift_high_vec.ymm = _mm256_set1_epi64x(77ull); + prime_vec.ymm = _mm256_set1_epi64x(SZ_U64_MAX_PRIME); + golden_ratio_vec.ymm = _mm256_set1_epi64x(11400714819323198485ull); + prime_power_low_vec.ymm = _mm256_set1_epi64x(prime_power_low); + prime_power_high_vec.ymm = _mm256_set1_epi64x(prime_power_high); // Compute the initial hash values for every one of the four windows. - sz_u512_vec_t hash_vec, chars_vec; - hash_vec.zmm = _mm512_setzero_si512(); + sz_u256_vec_t hash_low_vec, hash_high_vec, hash_mix_vec, chars_low_vec, chars_high_vec; + hash_low_vec.ymm = _mm256_setzero_si256(); + hash_high_vec.ymm = _mm256_setzero_si256(); for (sz_u8_t const *prefix_end = text_first + window_length; text_first < prefix_end; ++text_first, ++text_second, ++text_third, ++text_fourth) { // 1. Multiply the hashes by the base. - hash_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, base_vec.zmm); + hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, base_low_vec.ymm); + hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, base_high_vec.ymm); // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`... - chars_vec.zmm = _mm512_set_epi64(text_fourth[0], text_third[0], text_second[0], text_first[0], // - text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm); + // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. + chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[0], text_third[0], text_second[0], text_first[0]); + chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); // 3. Add the incoming characters. - hash_vec.zmm = _mm512_add_epi64(hash_vec.zmm, chars_vec.zmm); + hash_low_vec.ymm = _mm256_add_epi64(hash_low_vec.ymm, chars_low_vec.ymm); + hash_high_vec.ymm = _mm256_add_epi64(hash_high_vec.ymm, chars_high_vec.ymm); // 4. Compute the modulo. Assuming there are only 59 values between our prime // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. 
- hash_vec.zmm = _mm512_mask_blend_epi8(_mm512_cmpgt_epi64_mask(hash_vec.zmm, prime_vec.zmm), hash_vec.zmm, - _mm512_sub_epi64(hash_vec.zmm, prime_vec.zmm)); + hash_low_vec.ymm = _mm256_blendv_epi8( // + hash_low_vec.ymm, _mm256_sub_epi64(hash_low_vec.ymm, prime_vec.ymm), + _mm256_cmpgt_epi64(hash_low_vec.ymm, prime_vec.ymm)); + hash_high_vec.ymm = _mm256_blendv_epi8( // + hash_high_vec.ymm, _mm256_sub_epi64(hash_high_vec.ymm, prime_vec.ymm), + _mm256_cmpgt_epi64(hash_high_vec.ymm, prime_vec.ymm)); } // 5. Compute the hash mix, that will be used to index into the fingerprint. // This includes a serial step at the end. - sz_u512_vec_t hash_mix_vec; - hash_mix_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, golden_ratio_vec.zmm); - hash_mix_vec.ymms[0] = _mm256_xor_si256(_mm512_extracti64x4_epi64(hash_mix_vec.zmm, 1), // - _mm512_extracti64x4_epi64(hash_mix_vec.zmm, 0)); - + hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, golden_ratio_vec.ymm); + hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, golden_ratio_vec.ymm); + hash_mix_vec.ymm = _mm256_xor_si256(hash_low_vec.ymm, hash_high_vec.ymm); callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); @@ -5754,45 +520,45 @@ SZ_PUBLIC void sz_hashes_avx512(sz_cptr_t start, sz_size_t length, sz_size_t win // Now repeat that operation for the remaining characters, discarding older characters. sz_size_t cycle = 1; - sz_size_t step_mask = step - 1; + sz_size_t const step_mask = step - 1; for (; text_fourth != text_end; ++text_first, ++text_second, ++text_third, ++text_fourth, ++cycle) { // 0. Load again the four characters we are dropping, shift them, and subtract. - chars_vec.zmm = _mm512_set_epi64(text_fourth[-window_length], text_third[-window_length], - text_second[-window_length], text_first[-window_length], // - text_fourth[-window_length], text_third[-window_length], - text_second[-window_length], text_first[-window_length]); - chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm); - hash_vec.zmm = _mm512_sub_epi64(hash_vec.zmm, _mm512_mullo_epi64(chars_vec.zmm, prime_power_vec.zmm)); + chars_low_vec.ymm = _mm256_set_epi64x( // + text_fourth[-window_length], text_third[-window_length], text_second[-window_length], + text_first[-window_length]); + chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); + hash_low_vec.ymm = + _mm256_sub_epi64(hash_low_vec.ymm, _mm256_mul_epu64(chars_low_vec.ymm, prime_power_low_vec.ymm)); + hash_high_vec.ymm = + _mm256_sub_epi64(hash_high_vec.ymm, _mm256_mul_epu64(chars_high_vec.ymm, prime_power_high_vec.ymm)); // 1. Multiply the hashes by the base. - hash_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, base_vec.zmm); + hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, base_low_vec.ymm); + hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, base_high_vec.ymm); // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. - chars_vec.zmm = _mm512_set_epi64(text_fourth[0], text_third[0], text_second[0], text_first[0], // - text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm); - - // ... and prefetch the next four characters into Level 2 or higher. 
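/* The "conditionally subtract the prime" step above works because the prime sits only 59 below
 * 2^64: every 64-bit value is therefore smaller than twice the prime, so a single compare-and-
 * subtract lands it in [0, prime). The preceding multiply has of course already wrapped modulo
 * 2^64, which is why this is a fingerprint-grade reduction rather than exact modular arithmetic.
 * A scalar sketch with an explicit unsigned comparison (the prime value is assumed from the
 * comment above): */
#include <assert.h>
#include <stdint.h>

#define DEMO_U64_PRIME 18446744073709551557ull // 2^64 - 59.

static uint64_t reduce_once(uint64_t value) { return value >= DEMO_U64_PRIME ? value - DEMO_U64_PRIME : value; }

int main(void) {
    assert(reduce_once(0) == 0);
    assert(reduce_once(DEMO_U64_PRIME - 1) == DEMO_U64_PRIME - 1);
    assert(reduce_once(DEMO_U64_PRIME) == 0);
    assert(reduce_once(UINT64_MAX) == 58); // The largest remainder a single subtraction can leave.
    return 0;
}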
- _mm_prefetch((sz_cptr_t)text_fourth + 1, _MM_HINT_T1); - _mm_prefetch((sz_cptr_t)text_third + 1, _MM_HINT_T1); - _mm_prefetch((sz_cptr_t)text_second + 1, _MM_HINT_T1); - _mm_prefetch((sz_cptr_t)text_first + 1, _MM_HINT_T1); + chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[0], text_third[0], text_second[0], text_first[0]); + chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); // 3. Add the incoming characters. - hash_vec.zmm = _mm512_add_epi64(hash_vec.zmm, chars_vec.zmm); + hash_low_vec.ymm = _mm256_add_epi64(hash_low_vec.ymm, chars_low_vec.ymm); + hash_high_vec.ymm = _mm256_add_epi64(hash_high_vec.ymm, chars_high_vec.ymm); // 4. Compute the modulo. Assuming there are only 59 values between our prime // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_vec.zmm = _mm512_mask_blend_epi8(_mm512_cmpgt_epi64_mask(hash_vec.zmm, prime_vec.zmm), hash_vec.zmm, - _mm512_sub_epi64(hash_vec.zmm, prime_vec.zmm)); + hash_low_vec.ymm = _mm256_blendv_epi8( // + hash_low_vec.ymm, _mm256_sub_epi64(hash_low_vec.ymm, prime_vec.ymm), + _mm256_cmpgt_epi64(hash_low_vec.ymm, prime_vec.ymm)); + hash_high_vec.ymm = _mm256_blendv_epi8( // + hash_high_vec.ymm, _mm256_sub_epi64(hash_high_vec.ymm, prime_vec.ymm), + _mm256_cmpgt_epi64(hash_high_vec.ymm, prime_vec.ymm)); // 5. Compute the hash mix, that will be used to index into the fingerprint. // This includes a serial step at the end. - hash_mix_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, golden_ratio_vec.zmm); - hash_mix_vec.ymms[0] = _mm256_xor_si256(_mm512_extracti64x4_epi64(hash_mix_vec.zmm, 1), // - _mm512_castsi512_si256(hash_mix_vec.zmm)); - + hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, golden_ratio_vec.ymm); + hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, golden_ratio_vec.ymm); + hash_mix_vec.ymm = _mm256_xor_si256(hash_low_vec.ymm, hash_high_vec.ymm); if ((cycle & step_mask) == 0) { callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); @@ -5804,1353 +570,346 @@ SZ_PUBLIC void sz_hashes_avx512(sz_cptr_t start, sz_size_t length, sz_size_t win #pragma clang attribute pop #pragma GCC pop_options +#endif // SZ_USE_HASWELL +#pragma endregion // Haswell Implementation -#pragma GCC push_options -#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512vbmi", "avx512vbmi2", "bmi", "bmi2") -#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512vbmi,avx512vbmi2,bmi,bmi2"))), \ - apply_to = function) - -SZ_PUBLIC void sz_look_up_transform_avx512(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { - - // If the input is tiny (especially smaller than the look-up table itself), we may end up paying - // more for organizing the SIMD registers and changing the CPU state, than for the actual computation. - // But if at least 3 cache lines are touched, the AVX-512 implementation should be faster. - if (length <= 128) { - sz_look_up_transform_serial(source, length, lut, target); - return; - } - - // When the buffer is over 64 bytes, it's guaranteed to touch at least two cache lines - the head and tail, - // and may include more cache-lines in-between. Knowing this, we can avoid expensive unaligned stores - // by computing 2 masks - for the head and tail, using masked stores for the head and tail, and unmasked - // for the body. 
- sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - - // We need to pull the lookup table into 4x ZMM registers. - // We can use `vpermi2b` instruction to perform the look in two ZMM registers with `_mm512_permutex2var_epi8` - // intrinsics, but it has a 6-cycle latency on Sapphire Rapids and requires AVX512-VBMI. Assuming we need to - // operate on 4 registers, it might be cleaner to use 2x separate `_mm512_permutexvar_epi8` calls. - // Combining the results with 2x `_mm512_test_epi8_mask` and 3x blends afterwards. - // - // - 4x `_mm512_permutexvar_epi8` maps to "VPERMB (ZMM, ZMM, ZMM)": - // - On Ice Lake: 3 cycles latency, ports: 1*p5 - // - On Genoa: 6 cycles latency, ports: 1*FP12 - // - 3x `_mm512_mask_blend_epi8` maps to "VPBLENDMB_Z (ZMM, K, ZMM, ZMM)": - // - On Ice Lake: 3 cycles latency, ports: 1*p05 - // - On Genoa: 1 cycle latency, ports: 1*FP0123 - // - 2x `_mm512_test_epi8_mask` maps to "VPTESTMB (K, ZMM, ZMM)": - // - On Ice Lake: 3 cycles latency, ports: 1*p5 - // - On Genoa: 4 cycles latency, ports: 1*FP01 - // - sz_u512_vec_t lut_0_to_63_vec, lut_64_to_127_vec, lut_128_to_191_vec, lut_192_to_255_vec; - lut_0_to_63_vec.zmm = _mm512_loadu_si512((lut)); - lut_64_to_127_vec.zmm = _mm512_loadu_si512((lut + 64)); - lut_128_to_191_vec.zmm = _mm512_loadu_si512((lut + 128)); - lut_192_to_255_vec.zmm = _mm512_loadu_si512((lut + 192)); - - sz_u512_vec_t first_bit_vec, second_bit_vec; - first_bit_vec.zmm = _mm512_set1_epi8((char)0x80); - second_bit_vec.zmm = _mm512_set1_epi8((char)0x40); - - __mmask64 first_bit_mask, second_bit_mask; - sz_u512_vec_t source_vec; - // If the top bit is set in each word of `source_vec`, than we use `lookup_128_to_191_vec` or - // `lookup_192_to_255_vec`. If the second bit is set, we use `lookup_64_to_127_vec` or `lookup_192_to_255_vec`. - sz_u512_vec_t lookup_0_to_63_vec, lookup_64_to_127_vec, lookup_128_to_191_vec, lookup_192_to_255_vec; - sz_u512_vec_t blended_0_to_127_vec, blended_128_to_255_vec, blended_0_to_255_vec; - - // Handling the head. - if (head_length) { - source_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, source); - lookup_0_to_63_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_0_to_63_vec.zmm); - lookup_64_to_127_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_64_to_127_vec.zmm); - lookup_128_to_191_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_128_to_191_vec.zmm); - lookup_192_to_255_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_192_to_255_vec.zmm); - first_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, first_bit_vec.zmm); - second_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, second_bit_vec.zmm); - blended_0_to_127_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_0_to_63_vec.zmm, lookup_64_to_127_vec.zmm); - blended_128_to_255_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_128_to_191_vec.zmm, lookup_192_to_255_vec.zmm); - blended_0_to_255_vec.zmm = - _mm512_mask_blend_epi8(first_bit_mask, blended_0_to_127_vec.zmm, blended_128_to_255_vec.zmm); - _mm512_mask_storeu_epi8(target, head_mask, blended_0_to_255_vec.zmm); - source += head_length, target += head_length, length -= head_length; - } - - // Handling the body in 64-byte chunks aligned to cache-line boundaries with respect to `target`. 
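/* The kernel above is, semantically, just `target[i] = lut[source[i]]`. The 256-byte table is
 * held in four 64-byte registers, and bits 7 and 6 of each input byte pick which register's
 * permutation result survives the blends. A scalar sketch of exactly that selection (reference
 * semantics only, standalone): */
#include <assert.h>
#include <stdint.h>
#include <stddef.h>

static void look_up_transform_scalar(uint8_t const *source, size_t length, uint8_t const *lut, uint8_t *target) {
    for (size_t i = 0; i != length; ++i) {
        uint8_t byte = source[i];
        uint8_t const *chunk = lut + (byte >> 6) * 64; // Top two bits choose one of four 64-entry chunks.
        target[i] = chunk[byte & 63];                  // Low six bits index within the chunk: same as lut[byte].
    }
}

int main(void) {
    uint8_t lut[256], in[5] = {0, 63, 64, 128, 255}, out[5];
    for (size_t i = 0; i != 256; ++i) lut[i] = (uint8_t)(255 - i); // Any table will do for the check.
    look_up_transform_scalar(in, 5, lut, out);
    for (size_t i = 0; i != 5; ++i) assert(out[i] == lut[in[i]]);
    return 0;
}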
- while (length >= 64) { - source_vec.zmm = _mm512_loadu_si512(source); - lookup_0_to_63_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_0_to_63_vec.zmm); - lookup_64_to_127_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_64_to_127_vec.zmm); - lookup_128_to_191_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_128_to_191_vec.zmm); - lookup_192_to_255_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_192_to_255_vec.zmm); - first_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, first_bit_vec.zmm); - second_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, second_bit_vec.zmm); - blended_0_to_127_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_0_to_63_vec.zmm, lookup_64_to_127_vec.zmm); - blended_128_to_255_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_128_to_191_vec.zmm, lookup_192_to_255_vec.zmm); - blended_0_to_255_vec.zmm = - _mm512_mask_blend_epi8(first_bit_mask, blended_0_to_127_vec.zmm, blended_128_to_255_vec.zmm); - _mm512_store_si512(target, blended_0_to_255_vec.zmm); //! Aligned store, our main weapon! - source += 64, target += 64, length -= 64; - } - - // Handling the tail. - if (tail_length) { - source_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, source); - lookup_0_to_63_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_0_to_63_vec.zmm); - lookup_64_to_127_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_64_to_127_vec.zmm); - lookup_128_to_191_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_128_to_191_vec.zmm); - lookup_192_to_255_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_192_to_255_vec.zmm); - first_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, first_bit_vec.zmm); - second_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, second_bit_vec.zmm); - blended_0_to_127_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_0_to_63_vec.zmm, lookup_64_to_127_vec.zmm); - blended_128_to_255_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_128_to_191_vec.zmm, lookup_192_to_255_vec.zmm); - blended_0_to_255_vec.zmm = - _mm512_mask_blend_epi8(first_bit_mask, blended_0_to_127_vec.zmm, blended_128_to_255_vec.zmm); - _mm512_mask_storeu_epi8(target, tail_mask, blended_0_to_255_vec.zmm); - source += tail_length, target += tail_length, length -= tail_length; - } -} - -SZ_PUBLIC sz_cptr_t sz_find_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { - - // Before initializing the AVX-512 vectors, we may want to run the sequential code for the first few bytes. - // In practice, that only hurts, even when we have matches every 5-ish bytes. - // - // if (length < SZ_SWAR_THRESHOLD) return sz_find_charset_serial(text, length, filter); - // sz_cptr_t early_result = sz_find_charset_serial(text, SZ_SWAR_THRESHOLD, filter); - // if (early_result) return early_result; - // text += SZ_SWAR_THRESHOLD; - // length -= SZ_SWAR_THRESHOLD; - // - // Let's unzip even and odd elements and replicate them into both lanes of the YMM register. - // That way when we invoke `_mm512_shuffle_epi8` we can use the same mask for both lanes. - sz_u512_vec_t filter_even_vec, filter_odd_vec; - __m256i filter_ymm = _mm256_lddqu_si256((__m256i const *)filter); - // There are a few way to initialize filters without having native strided loads. 
- // In the cronological order of experiments: - // - serial code initializing 128 bytes of odd and even mask - // - using several shuffles - // - using `_mm512_permutexvar_epi8` - // - using `_mm512_broadcast_i32x4(_mm256_castsi256_si128(_mm256_maskz_compress_epi8(0x55555555, filter_ymm)))` - // and `_mm512_broadcast_i32x4(_mm256_castsi256_si128(_mm256_maskz_compress_epi8(0xaaaaaaaa, filter_ymm)))` - filter_even_vec.zmm = _mm512_broadcast_i32x4(_mm256_castsi256_si128( // broadcast __m128i to __m512i - _mm256_maskz_compress_epi8(0x55555555, filter_ymm))); - filter_odd_vec.zmm = _mm512_broadcast_i32x4(_mm256_castsi256_si128( // broadcast __m128i to __m512i - _mm256_maskz_compress_epi8(0xaaaaaaaa, filter_ymm))); - // After the unzipping operation, we can validate the contents of the vectors like this: - // - // for (sz_size_t i = 0; i != 16; ++i) { - // sz_assert(filter_even_vec.u8s[i] == filter->_u8s[i * 2]); - // sz_assert(filter_odd_vec.u8s[i] == filter->_u8s[i * 2 + 1]); - // sz_assert(filter_even_vec.u8s[i + 16] == filter->_u8s[i * 2]); - // sz_assert(filter_odd_vec.u8s[i + 16] == filter->_u8s[i * 2 + 1]); - // sz_assert(filter_even_vec.u8s[i + 32] == filter->_u8s[i * 2]); - // sz_assert(filter_odd_vec.u8s[i + 32] == filter->_u8s[i * 2 + 1]); - // sz_assert(filter_even_vec.u8s[i + 48] == filter->_u8s[i * 2]); - // sz_assert(filter_odd_vec.u8s[i + 48] == filter->_u8s[i * 2 + 1]); - // } - // - sz_u512_vec_t text_vec; - sz_u512_vec_t lower_nibbles_vec, higher_nibbles_vec; - sz_u512_vec_t bitset_even_vec, bitset_odd_vec; - sz_u512_vec_t bitmask_vec, bitmask_lookup_vec; - bitmask_lookup_vec.zmm = _mm512_set_epi8( // - -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // - -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // - -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // - -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1); - - while (length) { - // The following algorithm is a transposed equivalent of the "SIMDized check which bytes are in a set" - // solutions by Wojciech Muła. We populate the bitmask differently and target newer CPUs, so - // StrinZilla uses a somewhat different approach. - // http://0x80.pl/articles/simd-byte-lookup.html#alternative-implementation-new - // - // sz_u8_t input = *(sz_u8_t const *)text; - // sz_u8_t lo_nibble = input & 0x0f; - // sz_u8_t hi_nibble = input >> 4; - // sz_u8_t bitset_even = filter_even_vec.u8s[hi_nibble]; - // sz_u8_t bitset_odd = filter_odd_vec.u8s[hi_nibble]; - // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); - // sz_u8_t bitset = lo_nibble < 8 ? bitset_even : bitset_odd; - // if ((bitset & bitmask) != 0) return text; - // else { length--, text++; } - // - // The nice part about this, loading the strided data is vey easy with Arm NEON, - // while with x86 CPUs after AVX, shuffles within 256 bits shouldn't be an issue either. 
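/* A runnable version of the scalar reference sketched in the comments above. It assumes the
 * 256-bit set stores membership of byte `c` in bit `c & 7` of byte `c >> 3`, which is the layout
 * the even/odd unzipping of the filter relies on; the type and helper names are hypothetical,
 * not the library's API. */
#include <assert.h>
#include <stdint.h>
#include <string.h>

typedef struct { uint8_t bytes[32]; } demo_charset_t; // 256 bits, one per byte value.

static void demo_charset_add(demo_charset_t *set, uint8_t c) { set->bytes[c >> 3] |= (uint8_t)(1u << (c & 7)); }

static int demo_charset_contains(demo_charset_t const *set, uint8_t c) {
    uint8_t lo_nibble = (uint8_t)(c & 0x0f), hi_nibble = (uint8_t)(c >> 4);
    uint8_t bitset_even = set->bytes[hi_nibble * 2];    // Covers low-nibble values 0..7.
    uint8_t bitset_odd = set->bytes[hi_nibble * 2 + 1]; // Covers low-nibble values 8..15.
    uint8_t bitset = lo_nibble < 8 ? bitset_even : bitset_odd;
    return (bitset & (uint8_t)(1u << (lo_nibble & 7))) != 0;
}

int main(void) {
    demo_charset_t set;
    memset(&set, 0, sizeof(set));
    demo_charset_add(&set, 'a'), demo_charset_add(&set, 'Z'), demo_charset_add(&set, 0xF7);
    assert(demo_charset_contains(&set, 'a') && demo_charset_contains(&set, 'Z') && demo_charset_contains(&set, 0xF7));
    assert(!demo_charset_contains(&set, 'b') && !demo_charset_contains(&set, 0));
    return 0;
}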
- sz_size_t load_length = sz_min_of_two(length, 64); - __mmask64 load_mask = _sz_u64_mask_until(load_length); - text_vec.zmm = _mm512_maskz_loadu_epi8(load_mask, text); - lower_nibbles_vec.zmm = _mm512_and_si512(text_vec.zmm, _mm512_set1_epi8(0x0f)); - bitmask_vec.zmm = _mm512_shuffle_epi8(bitmask_lookup_vec.zmm, lower_nibbles_vec.zmm); - // - // At this point we can validate the `bitmask_vec` contents like this: - // - // for (sz_size_t i = 0; i != load_length; ++i) { - // sz_u8_t input = *(sz_u8_t const *)(text + i); - // sz_u8_t lo_nibble = input & 0x0f; - // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); - // sz_assert(bitmask_vec.u8s[i] == bitmask); - // } - // - // Shift right every byte by 4 bits. - // There is no `_mm512_srli_epi8` intrinsic, so we have to use `_mm512_srli_epi16` - // and combine it with a mask to clear the higher bits. - higher_nibbles_vec.zmm = _mm512_and_si512(_mm512_srli_epi16(text_vec.zmm, 4), _mm512_set1_epi8(0x0f)); - bitset_even_vec.zmm = _mm512_shuffle_epi8(filter_even_vec.zmm, higher_nibbles_vec.zmm); - bitset_odd_vec.zmm = _mm512_shuffle_epi8(filter_odd_vec.zmm, higher_nibbles_vec.zmm); - // - // At this point we can validate the `bitset_even_vec` and `bitset_odd_vec` contents like this: - // - // for (sz_size_t i = 0; i != load_length; ++i) { - // sz_u8_t input = *(sz_u8_t const *)(text + i); - // sz_u8_t const *bitset_ptr = &filter->_u8s[0]; - // sz_u8_t hi_nibble = input >> 4; - // sz_u8_t bitset_even = bitset_ptr[hi_nibble * 2]; - // sz_u8_t bitset_odd = bitset_ptr[hi_nibble * 2 + 1]; - // sz_assert(bitset_even_vec.u8s[i] == bitset_even); - // sz_assert(bitset_odd_vec.u8s[i] == bitset_odd); - // } - // - // TODO: Is this a good place for ternary logic? - __mmask64 take_first = _mm512_cmplt_epi8_mask(lower_nibbles_vec.zmm, _mm512_set1_epi8(8)); - bitset_even_vec.zmm = _mm512_mask_blend_epi8(take_first, bitset_odd_vec.zmm, bitset_even_vec.zmm); - __mmask64 matches_mask = _mm512_mask_test_epi8_mask(load_mask, bitset_even_vec.zmm, bitmask_vec.zmm); - if (matches_mask) { - int offset = sz_u64_ctz(matches_mask); - return text + offset; - } - else { text += load_length, length -= load_length; } - } - - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { - return sz_rfind_charset_serial(text, length, filter); -} - -SZ_PUBLIC sz_cptr_t sz_find_many_avx512( // - sz_cptr_t haystack, sz_size_t haystack_length, // - sz_cptr_t const *needles, sz_size_t const *needles_lengths, // - sz_size_t *needle_offset) { - - // When dealing with huge needles vocabularies, like in tokenization workloads, we need to construct an automaton. - // But in many cases, the vocabulary is small enough to use a simpler DFA-less approach, combining the ideas from - // the `sz_find_avx512` and `sz_find_charset_avx512` functions. - // - // Pick the offsets within needles where there is the least variance in the characters. - // Like for "the", "then", "there", "these", "those", "their", "they", "them", "that", "this", "thus", "than": - // - // 0: 't' - // 1: 'h' - // 2: 'e', 'a', 'i', 'o', 'u' - // 3: 'n', 'r', 's', 'i', 'y', 'm', 't' - // - // So depending on our "register budget", we can use a different number of pivot points: offset 0, 1, 2 make - // the most sense if we can only use 3 ZMM registers. - sz_unused(haystack && haystack_length && needles && needles_lengths && needle_offset); - return 0; -} - -/** - * Computes the Needleman Wunsch alignment score between two strings. 
- * The method uses 32-bit integers to accumulate the running score for every cell in the matrix. - * Assuming the costs of substitutions can be arbitrary signed 8-bit integers, the method is expected to be used - * on strings not exceeding 2^24 length or 16.7 million characters. +/* AVX512 implementation of the string hashing algorithms for Skylake and newer CPUs. + * Includes extensions: F, CD, ER, PF, VL, DQ, BW. * - * Unlike the `_sz_edit_distance_skewed_diagonals_upto65k_avx512` method, this one uses signed integers to store - * the accumulated score. Moreover, it's primary bottleneck is the latency of gathering the substitution costs - * from the substitution matrix. If we use the diagonal order, we will be comparing a slice of the first string with - * a slice of the second. If we stick to the conventional horizontal order, we will be comparing one character against - * a slice, which is much easier to optimize. In that case we are sampling costs not from arbitrary parts of - * a 256 x 256 matrix, but from a single row! + * This is the "starting level" for the advanced algorithms using K-mask registers on x86. */ -SZ_INTERNAL sz_ssize_t _sz_alignment_score_wagner_fisher_upto17m_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc) { - - // If one of the strings is empty - the edit distance is equal to the length of the other one - if (longer_length == 0) return (sz_ssize_t)shorter_length * gap; - if (shorter_length == 0) return (sz_ssize_t)longer_length * gap; - - // Let's make sure that we use the amount proportional to the - // number of elements in the shorter string, not the larger. - if (shorter_length > longer_length) { - sz_pointer_swap((void **)&longer_length, (void **)&shorter_length); - sz_pointer_swap((void **)&longer, (void **)&shorter); - } - - // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. - sz_memory_allocator_t global_alloc; - if (!alloc) { - sz_memory_allocator_init_default(&global_alloc); - alloc = &global_alloc; - } - - sz_size_t const max_length = 256ull * 256ull * 256ull; - sz_size_t const n = longer_length + 1; - sz_assert(n < max_length && "The length must fit into 24-bit integer. Otherwise use serial variant."); - sz_unused(longer_length && max_length); - - sz_size_t buffer_length = sizeof(sz_i32_t) * n * 2; - sz_i32_t *distances = (sz_i32_t *)alloc->allocate(buffer_length, alloc->handle); - sz_i32_t *previous_distances = distances; - sz_i32_t *current_distances = previous_distances + n; - - // Intialize the first row of the Levenshtein matrix with `iota`. - for (sz_size_t idx_longer = 0; idx_longer != n; ++idx_longer) - previous_distances[idx_longer] = (sz_i32_t)idx_longer * gap; - - /// Contains up to 16 consecutive characters from the longer string. - sz_u512_vec_t longer_vec; - sz_u512_vec_t cost_deletion_vec, cost_substitution_vec, lookup_substitution_vec, current_vec; - sz_u512_vec_t row_first_subs_vec, row_second_subs_vec, row_third_subs_vec, row_fourth_subs_vec; - sz_u512_vec_t shuffled_first_subs_vec, shuffled_second_subs_vec, shuffled_third_subs_vec, shuffled_fourth_subs_vec; - - // Prepare constants and masks. 
- sz_u512_vec_t is_third_or_fourth_vec, is_second_or_fourth_vec, gap_vec; - { - char is_third_or_fourth_check, is_second_or_fourth_check; - *(sz_u8_t *)&is_third_or_fourth_check = 0x80, *(sz_u8_t *)&is_second_or_fourth_check = 0x40; - is_third_or_fourth_vec.zmm = _mm512_set1_epi8(is_third_or_fourth_check); - is_second_or_fourth_vec.zmm = _mm512_set1_epi8(is_second_or_fourth_check); - gap_vec.zmm = _mm512_set1_epi32(gap); - } - - sz_u8_t const *shorter_unsigned = (sz_u8_t const *)shorter; - for (sz_size_t idx_shorter = 0; idx_shorter != shorter_length; ++idx_shorter) { - sz_i32_t last_in_row = current_distances[0] = (sz_i32_t)(idx_shorter + 1) * gap; - - // Load one row of the substitution matrix into four ZMM registers. - sz_error_cost_t const *row_subs = subs + shorter_unsigned[idx_shorter] * 256u; - row_first_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 0); - row_second_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 1); - row_third_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 2); - row_fourth_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 3); - - // In the serial version we have one forward pass, that computes the deletion, - // insertion, and substitution costs at once. - // for (sz_size_t idx_longer = 0; idx_longer < longer_length; ++idx_longer) { - // sz_ssize_t cost_deletion = previous_distances[idx_longer + 1] + gap; - // sz_ssize_t cost_insertion = current_distances[idx_longer] + gap; - // sz_ssize_t cost_substitution = previous_distances[idx_longer] + row_subs[longer_unsigned[idx_longer]]; - // current_distances[idx_longer + 1] = sz_min_of_three(cost_deletion, cost_insertion, cost_substitution); - // } - // - // Given the complexity of handling the data-dependency between consecutive insertion cost computations - // within a Levenshtein matrix, the simplest design would be to vectorize every kind of cost computation - // separately. - // 1. Compute substitution costs for up to 64 characters at once, upcasting from 8-bit integers to 32. - // 2. Compute the pairwise minimum with deletion costs. - // 3. Inclusive prefix minimum computation to combine with addition costs. - // Proceeding with substitutions: - for (sz_size_t idx_longer = 0; idx_longer < longer_length; idx_longer += 64) { - sz_size_t register_length = sz_min_of_two(longer_length - idx_longer, 64); - __mmask64 mask = _sz_u64_mask_until(register_length); - longer_vec.zmm = _mm512_maskz_loadu_epi8(mask, longer + idx_longer); - - // Blend the `row_(first|second|third|fourth)_subs_vec` into `current_vec`, picking the right source - // for every character in `longer_vec`. Before that, we need to permute the subsititution vectors. - // Only the bottom 6 bits of a byte are used in VPERB, so we don't even need to mask. - shuffled_first_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_first_subs_vec.zmm); - shuffled_second_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_second_subs_vec.zmm); - shuffled_third_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_third_subs_vec.zmm); - shuffled_fourth_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_fourth_subs_vec.zmm); - - // To blend we can invoke three `_mm512_cmplt_epu8_mask`, but we can also achieve the same using - // the AND logical operation, checking the top two bits of every byte. - // Continuing this thought, we can use the VPTESTMB instruction to output the mask after the AND. 
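/* A minimal scalar version of the row-wise recurrence spelled out in the commented reference
 * above, with 32-bit accumulators as the doc-comment assumes. Scores are maximized here, which
 * is the convention the vectorized loop's `max` operations follow; a distance-flavored variant
 * would minimize costs instead. A sketch of the dataflow, not the library's implementation. */
#include <stdint.h>
#include <stddef.h>
#include <stdlib.h>

static int32_t max3_i32(int32_t a, int32_t b, int32_t c) {
    int32_t m = a > b ? a : b;
    return m > c ? m : c;
}

// `subs` is a 256 x 256 matrix of signed substitution scores, `gap` a (typically negative) gap score.
static int32_t needleman_wunsch_score(uint8_t const *shorter, size_t shorter_length, //
                                      uint8_t const *longer, size_t longer_length,   //
                                      int8_t const *subs, int8_t gap) {
    size_t const columns = longer_length + 1;
    int32_t *previous = (int32_t *)malloc(columns * 2 * sizeof(int32_t));
    if (!previous) return INT32_MIN;
    int32_t *current = previous + columns;
    for (size_t j = 0; j != columns; ++j) previous[j] = (int32_t)j * gap;
    for (size_t i = 0; i != shorter_length; ++i) {
        current[0] = (int32_t)(i + 1) * gap;
        // Only one row of the substitution matrix is needed per iteration; it is sampled by the longer string's bytes.
        int8_t const *row_subs = subs + (size_t)shorter[i] * 256u;
        for (size_t j = 0; j != longer_length; ++j) {
            int32_t deletion = previous[j + 1] + gap;
            int32_t insertion = current[j] + gap;
            int32_t substitution = previous[j] + row_subs[longer[j]];
            current[j + 1] = max3_i32(deletion, insertion, substitution);
        }
        int32_t *temporary = previous; // Swap the two rows.
        previous = current, current = temporary;
    }
    int32_t result = previous[longer_length];
    free(previous < current ? previous : current); // The smaller pointer is the original allocation.
    return result;
}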
- __mmask64 is_third_or_fourth = _mm512_mask_test_epi8_mask(mask, longer_vec.zmm, is_third_or_fourth_vec.zmm); - __mmask64 is_second_or_fourth = - _mm512_mask_test_epi8_mask(mask, longer_vec.zmm, is_second_or_fourth_vec.zmm); - lookup_substitution_vec.zmm = _mm512_mask_blend_epi8( - is_third_or_fourth, - // Choose between the first and the second. - _mm512_mask_blend_epi8(is_second_or_fourth, shuffled_first_subs_vec.zmm, shuffled_second_subs_vec.zmm), - // Choose between the third and the fourth. - _mm512_mask_blend_epi8(is_second_or_fourth, shuffled_third_subs_vec.zmm, shuffled_fourth_subs_vec.zmm)); - - // First, sign-extend lower and upper 16 bytes to 16-bit integers. - __m512i current_0_31_vec = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(lookup_substitution_vec.zmm, 0)); - __m512i current_32_63_vec = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(lookup_substitution_vec.zmm, 1)); - - // Now extend those 16-bit integers to 32-bit. - // This isn't free, same as the subsequent store, so we only want to do that for the populated lanes. - // To minimize the number of loads and stores, we can combine our substitution costs with the previous - // distances, containing the deletion costs. - { - cost_substitution_vec.zmm = _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer); - cost_substitution_vec.zmm = _mm512_add_epi32( - cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_0_31_vec, 0))); - cost_deletion_vec.zmm = _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer); - cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm); - current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm); - - // Inclusive prefix minimum computation to combine with insertion costs. - // Simply disabling this operation results in 5x performance improvement, meaning - // that this operation is responsible for 80% of the total runtime. - // for (sz_size_t idx_longer = 0; idx_longer < longer_length; ++idx_longer) { - // current_distances[idx_longer + 1] = - // sz_max_of_two(current_distances[idx_longer] + gap, current_distances[idx_longer + 1]); - // } - // - // To perform the same operation in vectorized form, we need to perform a tree-like reduction, - // that will involve multiple steps. It's quite expensive and should be first tested in the - // "experimental" section. - // - // Another approach might be loop unrolling: - // current_vec.i32s[0] = last_in_row = sz_i32_max_of_two(current_vec.i32s[0], last_in_row + gap); - // current_vec.i32s[1] = last_in_row = sz_i32_max_of_two(current_vec.i32s[1], last_in_row + gap); - // current_vec.i32s[2] = last_in_row = sz_i32_max_of_two(current_vec.i32s[2], last_in_row + gap); - // ... yet this approach is also quite expensive. - for (int i = 0; i != 16; ++i) - current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap); - _mm512_mask_storeu_epi32(current_distances + idx_longer + 1, (__mmask16)mask, current_vec.zmm); - } - - // Export the values from 16 to 31. 
- if (register_length > 16) { - mask = _kshiftri_mask64(mask, 16); - cost_substitution_vec.zmm = - _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer + 16); - cost_substitution_vec.zmm = _mm512_add_epi32( - cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_0_31_vec, 1))); - cost_deletion_vec.zmm = - _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer + 16); - cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm); - current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm); - - // Aggregate running insertion costs within the register. - for (int i = 0; i != 16; ++i) - current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap); - _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 16, (__mmask16)mask, current_vec.zmm); - } - - // Export the values from 32 to 47. - if (register_length > 32) { - mask = _kshiftri_mask64(mask, 16); - cost_substitution_vec.zmm = - _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer + 32); - cost_substitution_vec.zmm = _mm512_add_epi32( - cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_32_63_vec, 0))); - cost_deletion_vec.zmm = - _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer + 32); - cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm); - current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm); - - // Aggregate running insertion costs within the register. - for (int i = 0; i != 16; ++i) - current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap); - _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 32, (__mmask16)mask, current_vec.zmm); - } - - // Export the values from 32 to 47. - if (register_length > 48) { - mask = _kshiftri_mask64(mask, 16); - cost_substitution_vec.zmm = - _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer + 48); - cost_substitution_vec.zmm = _mm512_add_epi32( - cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_32_63_vec, 1))); - cost_deletion_vec.zmm = - _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer + 48); - cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm); - current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm); - - // Aggregate running insertion costs within the register. - for (int i = 0; i != 16; ++i) - current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap); - _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 48, (__mmask16)mask, current_vec.zmm); - } - } - - // Swap previous_distances and current_distances pointers - sz_pointer_swap((void **)&previous_distances, (void **)¤t_distances); - } - - // Cache scalar before `free` call. 
- sz_ssize_t result = previous_distances[longer_length]; - alloc->free(distances, buffer_length, alloc->handle); - return result; -} - -SZ_INTERNAL sz_ssize_t sz_alignment_score_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc) { - - if (sz_max_of_two(shorter_length, longer_length) < (256ull * 256ull * 256ull)) - return _sz_alignment_score_wagner_fisher_upto17m_avx512(shorter, shorter_length, longer, longer_length, subs, - gap, alloc); - else - return sz_alignment_score_serial(shorter, shorter_length, longer, longer_length, subs, gap, alloc); -} - -enum sz_encoding_t { - sz_encoding_unknown_k = 0, - sz_encoding_ascii_k = 1, - sz_encoding_utf8_k = 2, - sz_encoding_utf16_k = 3, - sz_encoding_utf32_k = 4, - sz_jwt_k, - sz_base64_k, - // Low priority encodings: - sz_encoding_utf8bom_k = 5, - sz_encoding_utf16le_k = 6, - sz_encoding_utf16be_k = 7, - sz_encoding_utf32le_k = 8, - sz_encoding_utf32be_k = 9, -}; - -// Character Set Detection is one of the most commonly performed operations in data processing with -// [Chardet](https://github.com/chardet/chardet), [Charset Normalizer](https://github.com/jawah/charset_normalizer), -// [cChardet](https://github.com/PyYoshi/cChardet) being the most commonly used options in the Python ecosystem. -// All of them are notoriously slow. -// -// Moreover, as of October 2024, UTF-8 is the dominant character encoding on the web, used by 98.4% of websites. -// Other have minimal usage, according to [W3Techs](https://w3techs.com/technologies/overview/character_encoding): -// - ISO-8859-1: 1.2% -// - Windows-1252: 0.3% -// - Windows-1251: 0.2% -// - EUC-JP: 0.1% -// - Shift JIS: 0.1% -// - EUC-KR: 0.1% -// - GB2312: 0.1% -// - Windows-1250: 0.1% -// Within programming language implementations and database management systems, 16-bit and 32-bit fixed-width encodings -// are also very popular and we need a way to efficienly differentiate between the most common UTF flavors, ASCII, and -// the rest. -// -// One good solution is the [simdutf](https://github.com/simdutf/simdutf) library, but it depends on the C++ runtime -// and focuses more on incremental validation & transcoding, rather than detection. -// -// So we need a very fast and efficient way of determining -SZ_PUBLIC sz_bool_t sz_detect_encoding(sz_cptr_t text, sz_size_t length) { - // https://github.com/simdutf/simdutf/blob/master/src/icelake/icelake_utf8_validation.inl.cpp - // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_from_utf8.inl.cpp#L81 - // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_utf8_common.inl.cpp#L661 - // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_utf8_common.inl.cpp#L788 - - // We can implement this operation simpler & differently, assuming most of the time continuous chunks of memory - // have identical encoding. With Russian and many European languages, we generally deal with 2-byte codepoints - // with occasional 1-byte punctuation marks. In the case of Chinese, Japanese, and Korean, we deal with 3-byte - // codepoints. In the case of emojis, we deal with 4-byte codepoints. - // We can also use the idea, that misaligned reads are quite cheap on modern CPUs. 
- int can_be_ascii = 1, can_be_utf8 = 1, can_be_utf16 = 1, can_be_utf32 = 1; - sz_unused(can_be_ascii + can_be_utf8 + can_be_utf16 + can_be_utf32); - sz_unused(text && length); - return sz_false_k; -} +#pragma region Skylake Implementation +#if SZ_USE_SKYLAKE +#pragma GCC push_options +#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "bmi", "bmi2") +#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,bmi,bmi2"))), apply_to = function) #pragma clang attribute pop #pragma GCC pop_options -#endif +#endif // SZ_USE_SKYLAKE +#pragma endregion // Skylake Implementation -#pragma endregion - -/* @brief Implementation of the string search algorithms using the Arm NEON instruction set, available on 64-bit - * Arm processors. Implements: {substring search, character search, character set search} x {forward, reverse}. +/* AVX512 implementation of the string search algorithms for Ice Lake and newer CPUs. + * Includes extensions: + * - 2017 Skylake: F, CD, ER, PF, VL, DQ, BW, + * - 2018 CannonLake: IFMA, VBMI, + * - 2019 Ice Lake: VPOPCNTDQ, VNNI, VBMI2, BITALG, GFNI, VPCLMULQDQ, VAES. */ -#pragma region ARM NEON - -#if SZ_USE_ARM_NEON +#pragma region Ice Lake Implementation +#if SZ_USE_ICE #pragma GCC push_options -#pragma GCC target("arch=armv8.2-a+simd") -#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd"))), apply_to = function) - -/** - * @brief Helper structure to simplify work with 64-bit words. - */ -typedef union sz_u128_vec_t { - uint8x16_t u8x16; - uint16x8_t u16x8; - uint32x4_t u32x4; - uint64x2_t u64x2; - sz_u64_t u64s[2]; - sz_u32_t u32s[4]; - sz_u16_t u16s[8]; - sz_u8_t u8s[16]; -} sz_u128_vec_t; - -SZ_INTERNAL sz_u64_t _sz_vreinterpretq_u8_u4(uint8x16_t vec) { - // Use `vshrn` to produce a bitmask, similar to `movemask` in SSE. - // https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon - return vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(vec), 4)), 0) & 0x8888888888888888ull; -} +#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vbmi", "bmi", "bmi2") +#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512dq,avx512vbmi,bmi,bmi2"))), \ + apply_to = function) -SZ_PUBLIC sz_ordering_t sz_order_neon(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { - //! Before optimizing this, read the "Operations Not Worth Optimizing" in Contributions Guide: - //! https://github.com/ashvardanian/StringZilla/blob/main/CONTRIBUTING.md#general-performance-observations - return sz_order_serial(a, a_length, b, b_length); -} +SZ_PUBLIC sz_u64_t sz_checksum_ice(sz_cptr_t text, sz_size_t length) { + // The naive implementation of this function is very simple. + // It assumes the CPU is great at handling unaligned "loads". + // + // A typical AWS Sapphire Rapids instance can have 48 KB x 2 blocks of L1 data cache per core, + // 2 MB x 2 blocks of L2 cache per core, and one shared 60 MB buffer of L3 cache. + // With two strings, we may consider the overall workload huge, if each exceeds 1 MB in length. 
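+    //
+    // For reference, a naive scalar sketch of the same byte-wise sum (illustrative only,
+    // not part of the original implementation) would be:
+    //
+    //      sz_u64_t sum = 0;
+    //      for (sz_size_t i = 0; i != length; ++i) sum += (sz_u8_t)text[i];
+    //      return sum;
+    //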
+    int const is_huge = length >= 1ull * 1024ull * 1024ull;
+    sz_u512_vec_t text_vec, sums_vec;
 
-SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length) {
-    sz_u128_vec_t a_vec, b_vec;
-    for (; length >= 16; a += 16, b += 16, length -= 16) {
-        a_vec.u8x16 = vld1q_u8((sz_u8_t const *)a);
-        b_vec.u8x16 = vld1q_u8((sz_u8_t const *)b);
-        uint8x16_t cmp = vceqq_u8(a_vec.u8x16, b_vec.u8x16);
-        if (vminvq_u8(cmp) != 255) { return sz_false_k; } // Check if all bytes match
+    // When the buffer is small, there isn't much to innovate.
+    if (length <= 16) {
+        __mmask16 mask = _sz_u16_mask_until(length);
+        text_vec.xmms[0] = _mm_maskz_loadu_epi8(mask, text);
+        sums_vec.xmms[0] = _mm_sad_epu8(text_vec.xmms[0], _mm_setzero_si128());
+        sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_vec.xmms[0]);
+        sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_vec.xmms[0], 1);
+        return low + high;
     }
-
-    // Handle remaining bytes
-    if (length) return sz_equal_serial(a, b, length);
-    return sz_true_k;
-}
-
-SZ_PUBLIC sz_u64_t sz_checksum_neon(sz_cptr_t text, sz_size_t length) {
-    uint64x2_t sum_vec = vdupq_n_u64(0);
-
-    // Process 16 bytes (128 bits) at a time
-    for (; length >= 16; text += 16, length -= 16) {
-        uint8x16_t vec = vld1q_u8((sz_u8_t const *)text);       // Load 16 bytes
-        uint16x8_t pairwise_sum1 = vpaddlq_u8(vec);             // Pairwise add lower and upper 8 bits
-        uint32x4_t pairwise_sum2 = vpaddlq_u16(pairwise_sum1);  // Pairwise add 16-bit results
-        uint64x2_t pairwise_sum3 = vpaddlq_u32(pairwise_sum2);  // Pairwise add 32-bit results
-        sum_vec = vaddq_u64(sum_vec, pairwise_sum3);            // Accumulate the sum
+    else if (length <= 32) {
+        __mmask32 mask = _sz_u32_mask_until(length);
+        text_vec.ymms[0] = _mm256_maskz_loadu_epi8(mask, text);
+        sums_vec.ymms[0] = _mm256_sad_epu8(text_vec.ymms[0], _mm256_setzero_si256());
+        // Accumulating 256 bits is harder, as we need to extract the 128-bit sums first.
+        __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymms[0]);
+        __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymms[0], 1);
+        __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm);
+        sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm);
+        sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1);
+        return low + high;
     }
-
-    // Final reduction of `sum_vec` to a single scalar
-    sz_u64_t sum = vgetq_lane_u64(sum_vec, 0) + vgetq_lane_u64(sum_vec, 1);
-    if (length) sum += sz_checksum_serial(text, length);
-    return sum;
-}
-
-SZ_PUBLIC void sz_copy_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length) {
-    // In most cases the `source` and the `target` are not aligned, but we should
-    // at least make sure that writes don't touch many cache lines.
-    // NEON has an instruction to load and write 64 bytes at once.
+    else if (length <= 64) {
+        __mmask64 mask = _sz_u64_mask_until(length);
+        text_vec.zmm = _mm512_maskz_loadu_epi8(mask, text);
+        sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512());
+        return _mm512_reduce_add_epi64(sums_vec.zmm);
+    }
+    else if (!is_huge) {
+        sz_size_t head_length = (64 - ((sz_size_t)text % 64)) % 64; // 63 or less.
+        sz_size_t tail_length = (sz_size_t)(text + length) % 64;    // 63 or less.
+        sz_size_t body_length = length - head_length - tail_length; // Multiple of 64.
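+        // To make the split concrete (an illustrative example, not from the original source):
+        // for `text` at an address with `(sz_size_t)text % 64 == 13` and `length == 200`, we get
+        // `head_length == 51`, `tail_length == (13 + 200) % 64 == 21`, and `body_length == 128`.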
+        __mmask64 head_mask = _sz_u64_mask_until(head_length);
+        __mmask64 tail_mask = _sz_u64_mask_until(tail_length);
+        text_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, text);
+        sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512());
+        for (text += head_length; body_length >= 64; text += 64, body_length -= 64) {
+            text_vec.zmm = _mm512_load_si512((__m512i const *)text);
+            sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()));
+        }
+        text_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, text);
+        sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()));
+        return _mm512_reduce_add_epi64(sums_vec.zmm);
+    }
+    // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use.
     //
-    // sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less.
-    // sz_size_t tail_length = (sz_size_t)(target + length) % 64;    // 63 or less.
-    // for (; head_length; target += 1, source += 1, head_length -= 1) *target = *source;
-    // length -= head_length;
-    // for (; length >= 64; target += 64, source += 64, length -= 64)
-    //     vst4q_u8((sz_u8_t *)target, vld1q_u8_x4((sz_u8_t const *)source));
-    // for (; tail_length; target += 1, source += 1, tail_length -= 1) *target = *source;
+    // 1. Moving in both directions to maximize the throughput, when fetching from multiple
+    //    memory pages. Also helps with cache set-associativity issues, as we won't always
+    //    be hitting the same cache sets.
+    // 2. Using non-temporal loads to avoid polluting the cache.
+    // 3. Prefetching the next cache line, to avoid stalling the CPU. This is generally useless
+    //    for predictable access patterns, where the hardware prefetcher already does the job.
     //
-    // Sadly, those instructions end up being 20% slower than the code processing 16 bytes at a time:
-    for (; length >= 16; target += 16, source += 16, length -= 16)
-        vst1q_u8((sz_u8_t *)target, vld1q_u8((sz_u8_t const *)source));
-    if (length) sz_copy_serial(target, source, length);
-}
-
-SZ_PUBLIC void sz_move_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length) {
-    // When moving small buffers, using a small buffer on stack as a temporary storage is faster.
-
-    if (target < source || target >= source + length) {
-        // Non-overlapping, proceed forward
-        sz_copy_neon(target, source, length);
-    }
+    // Bidirectional traversal generally adds about 10% to such algorithms.
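+    //
+    // A rough scalar sketch of the bidirectional walk below (illustrative only;
+    // `consume_64_bytes` is a made-up placeholder for one 64-byte accumulation step):
+    //
+    //      while (body_length >= 128) {
+    //          consume_64_bytes(text);                    // forward cursor
+    //          consume_64_bytes(text + body_length - 64); // backward cursor
+    //          text += 64, body_length -= 128;
+    //      }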
    else {
-        // Overlapping, proceed backward
-        target += length;
-        source += length;
+        sz_u512_vec_t text_reversed_vec, sums_reversed_vec;
+        sz_size_t head_length = (64 - ((sz_size_t)text % 64)) % 64;
+        sz_size_t tail_length = (sz_size_t)(text + length) % 64;
+        sz_size_t body_length = length - head_length - tail_length;
+        __mmask64 head_mask = _sz_u64_mask_until(head_length);
+        __mmask64 tail_mask = _sz_u64_mask_until(tail_length);
 
-        sz_u128_vec_t src_vec;
-        while (length >= 16) {
-            target -= 16, source -= 16, length -= 16;
-            src_vec.u8x16 = vld1q_u8((sz_u8_t const *)source);
-            vst1q_u8((sz_u8_t *)target, src_vec.u8x16);
-        }
-        while (length) {
-            target -= 1, source -= 1, length -= 1;
-            *target = *source;
-        }
-    }
-}
+        text_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, text);
+        sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512());
+        text_reversed_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, text + head_length + body_length);
+        sums_reversed_vec.zmm = _mm512_sad_epu8(text_reversed_vec.zmm, _mm512_setzero_si512());
 
-SZ_PUBLIC void sz_fill_neon(sz_ptr_t target, sz_size_t length, sz_u8_t value) {
-    uint8x16_t fill_vec = vdupq_n_u8(value); // Broadcast the value across the register
+        // Now in the main loop, we can use non-temporal loads,
+        // performing the operation in both directions.
+        for (text += head_length; body_length >= 128; text += 64, body_length -= 128) {
+            text_vec.zmm = _mm512_stream_load_si512((__m512i *)(text));
+            sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()));
+            text_reversed_vec.zmm = _mm512_stream_load_si512((__m512i *)(text + body_length - 64));
+            sums_reversed_vec.zmm =
+                _mm512_add_epi64(sums_reversed_vec.zmm, _mm512_sad_epu8(text_reversed_vec.zmm, _mm512_setzero_si512()));
+        }
+        if (body_length >= 64) {
+            text_vec.zmm = _mm512_stream_load_si512((__m512i *)(text));
+            sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()));
+        }
 
-    while (length >= 16) {
-        vst1q_u8((sz_u8_t *)target, fill_vec);
-        target += 16;
-        length -= 16;
+        return _mm512_reduce_add_epi64(_mm512_add_epi64(sums_vec.zmm, sums_reversed_vec.zmm));
     }
-
-    // Handle remaining bytes
-    if (length) sz_fill_serial(target, length, value);
 }
 
-SZ_PUBLIC void sz_look_up_transform_neon(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) {
+SZ_PUBLIC void sz_hashes_ice(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, //
+                             sz_hash_callback_t callback, void *callback_handle) {
 
-    // If the input is tiny (especially smaller than the look-up table itself), we may end up paying
-    // more for organizing the SIMD registers and changing the CPU state, than for the actual computation.
-    if (length <= 128) {
-        sz_look_up_transform_serial(source, length, lut, target);
+    if (length < window_length || !window_length) return;
+    if (length < 4 * window_length) {
+        sz_hashes_serial(start, length, window_length, step, callback, callback_handle);
         return;
     }
 
-    sz_size_t head_length = (16 - ((sz_size_t)target % 16)) % 16; // 15 or less.
-    sz_size_t tail_length = (sz_size_t)(target + length) % 16;    // 15 or less.
-
-    // We need to pull the lookup table into 16x NEON registers. We have a total of 32 such registers.
-    // According to the Neoverse V2 manual, the 4-table lookup has a latency of 6 cycles, and 4x throughput.
- uint8x16x4_t lut_0_to_63_vec, lut_64_to_127_vec, lut_128_to_191_vec, lut_192_to_255_vec; - lut_0_to_63_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 0)); - lut_64_to_127_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 64)); - lut_128_to_191_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 128)); - lut_192_to_255_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 192)); - - sz_u128_vec_t source_vec; - // If the top bit is set in each word of `source_vec`, than we use `lookup_128_to_191_vec` or - // `lookup_192_to_255_vec`. If the second bit is set, we use `lookup_64_to_127_vec` or `lookup_192_to_255_vec`. - sz_u128_vec_t lookup_0_to_63_vec, lookup_64_to_127_vec, lookup_128_to_191_vec, lookup_192_to_255_vec; - sz_u128_vec_t blended_0_to_255_vec; - - // Process the head with serial code - for (; head_length; target += 1, source += 1, head_length -= 1) *target = lut[*(sz_u8_t const *)source]; + // Using AVX2, we can perform 4 long integer multiplications and additions within one register. + // So let's slice the entire string into 4 overlapping windows, to slide over them in parallel. + sz_size_t const max_hashes = length - window_length + 1; + sz_size_t const min_hashes_per_thread = max_hashes / 4; // At most one sequence can overlap between 2 threads. + sz_u8_t const *text_first = (sz_u8_t const *)start; + sz_u8_t const *text_second = text_first + min_hashes_per_thread; + sz_u8_t const *text_third = text_first + min_hashes_per_thread * 2; + sz_u8_t const *text_fourth = text_first + min_hashes_per_thread * 3; + sz_u8_t const *text_end = text_first + length; - // Table lookups on Arm are much simpler to use than on x86, as we can use the `vqtbl4q_u8` instruction - // to perform a 4-table lookup in a single instruction. The XORs are used to adjust the lookup position - // within each 64-byte range of the table. - // Details on the 4-table lookup: https://lemire.me/blog/2019/07/23/arbitrary-byte-to-byte-maps-using-arm-neon/ - length -= head_length; - length -= tail_length; - for (; length >= 16; source += 16, target += 16, length -= 16) { - source_vec.u8x16 = vld1q_u8((sz_u8_t const *)source); - lookup_0_to_63_vec.u8x16 = vqtbl4q_u8(lut_0_to_63_vec, source_vec.u8x16); - lookup_64_to_127_vec.u8x16 = vqtbl4q_u8(lut_64_to_127_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0x40))); - lookup_128_to_191_vec.u8x16 = vqtbl4q_u8(lut_128_to_191_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0x80))); - lookup_192_to_255_vec.u8x16 = vqtbl4q_u8(lut_192_to_255_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0xc0))); - blended_0_to_255_vec.u8x16 = vorrq_u8(vorrq_u8(lookup_0_to_63_vec.u8x16, lookup_64_to_127_vec.u8x16), - vorrq_u8(lookup_128_to_191_vec.u8x16, lookup_192_to_255_vec.u8x16)); - vst1q_u8((sz_u8_t *)target, blended_0_to_255_vec.u8x16); - } + // Broadcast the global constants into the registers. + // Both high and low hashes will work with the same prime and golden ratio. + sz_u512_vec_t prime_vec, golden_ratio_vec; + prime_vec.zmm = _mm512_set1_epi64(SZ_U64_MAX_PRIME); + golden_ratio_vec.zmm = _mm512_set1_epi64(11400714819323198485ull); - // Process the tail with serial code - for (; tail_length; target += 1, source += 1, tail_length -= 1) *target = lut[*(sz_u8_t const *)source]; -} + // Prepare the `prime ^ window_length` values, that we are going to use for modulo arithmetic. 
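+    // More precisely, after the loop below `prime_power_low` holds `31 ^ (window_length - 1) mod SZ_U64_MAX_PRIME`
+    // (and `prime_power_high` the same with base 257) - the weight of the character leaving the
+    // window in the usual rolling-hash update, roughly:
+    //
+    //      new_hash = (old_hash - outgoing * base^(window_length - 1)) * base + incoming    (mod prime)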
+ sz_u64_t prime_power_low = 1, prime_power_high = 1; + for (sz_size_t i = 0; i + 1 < window_length; ++i) + prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME, + prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME; -SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - sz_u64_t matches; - sz_u128_vec_t h_vec, n_vec, matches_vec; - n_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)n); + // We will be evaluating 4 offsets at a time with 2 different hash functions. + // We can fit all those 8 state variables in each of the following ZMM registers. + sz_u512_vec_t base_vec, prime_power_vec, shift_vec; + base_vec.zmm = _mm512_set_epi64(31ull, 31ull, 31ull, 31ull, 257ull, 257ull, 257ull, 257ull); + shift_vec.zmm = _mm512_set_epi64(0ull, 0ull, 0ull, 0ull, 77ull, 77ull, 77ull, 77ull); + prime_power_vec.zmm = _mm512_set_epi64(prime_power_low, prime_power_low, prime_power_low, prime_power_low, + prime_power_high, prime_power_high, prime_power_high, prime_power_high); - while (h_length >= 16) { - h_vec.u8x16 = vld1q_u8((sz_u8_t const *)h); - matches_vec.u8x16 = vceqq_u8(h_vec.u8x16, n_vec.u8x16); - // In Arm NEON we don't have a `movemask` to combine it with `ctz` and get the offset of the match. - // But assuming the `vmaxvq` is cheap, we can use it to find the first match, by blending (bitwise selecting) - // the vector with a relative offsets array. - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - if (matches) return h + sz_u64_ctz(matches) / 4; + // Compute the initial hash values for every one of the four windows. + sz_u512_vec_t hash_vec, chars_vec; + hash_vec.zmm = _mm512_setzero_si512(); + for (sz_u8_t const *prefix_end = text_first + window_length; text_first < prefix_end; + ++text_first, ++text_second, ++text_third, ++text_fourth) { - h += 16, h_length -= 16; - } + // 1. Multiply the hashes by the base. + hash_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, base_vec.zmm); - return sz_find_byte_serial(h, h_length, n); -} + // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, + // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`... + chars_vec.zmm = _mm512_set_epi64(text_fourth[0], text_third[0], text_second[0], text_first[0], // + text_fourth[0], text_third[0], text_second[0], text_first[0]); + chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm); -SZ_PUBLIC sz_cptr_t sz_rfind_byte_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - sz_u64_t matches; - sz_u128_vec_t h_vec, n_vec, matches_vec; - n_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)n); + // 3. Add the incoming characters. + hash_vec.zmm = _mm512_add_epi64(hash_vec.zmm, chars_vec.zmm); - while (h_length >= 16) { - h_vec.u8x16 = vld1q_u8((sz_u8_t const *)h + h_length - 16); - matches_vec.u8x16 = vceqq_u8(h_vec.u8x16, n_vec.u8x16); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - if (matches) return h + h_length - 1 - sz_u64_clz(matches) / 4; - h_length -= 16; + // 4. Compute the modulo. Assuming there are only 59 values between our prime + // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. 
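+        //    In scalar form, that conditional subtraction is roughly:
+        //
+        //        if (hash > SZ_U64_MAX_PRIME) hash -= SZ_U64_MAX_PRIME;
+        //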
+ hash_vec.zmm = _mm512_mask_blend_epi8(_mm512_cmpgt_epi64_mask(hash_vec.zmm, prime_vec.zmm), hash_vec.zmm, + _mm512_sub_epi64(hash_vec.zmm, prime_vec.zmm)); } - return sz_rfind_byte_serial(h, h_length, n); -} - -SZ_PUBLIC sz_u64_t _sz_find_charset_neon_register(sz_u128_vec_t h_vec, uint8x16_t set_top_vec_u8x16, - uint8x16_t set_bottom_vec_u8x16) { - - // Once we've read the characters in the haystack, we want to - // compare them against our bitset. The serial version of that code - // would look like: `(set_->_u8s[c >> 3] & (1u << (c & 7u))) != 0`. - uint8x16_t byte_index_vec = vshrq_n_u8(h_vec.u8x16, 3); - uint8x16_t byte_mask_vec = vshlq_u8(vdupq_n_u8(1), vreinterpretq_s8_u8(vandq_u8(h_vec.u8x16, vdupq_n_u8(7)))); - uint8x16_t matches_top_vec = vqtbl1q_u8(set_top_vec_u8x16, byte_index_vec); - // The table lookup instruction in NEON replies to out-of-bound requests with zeros. - // The values in `byte_index_vec` all fall in [0; 32). So for values under 16, substracting 16 will underflow - // and map into interval [240, 256). Meaning that those will be populated with zeros and we can safely - // merge `matches_top_vec` and `matches_bottom_vec` with a bitwise OR. - uint8x16_t matches_bottom_vec = vqtbl1q_u8(set_bottom_vec_u8x16, vsubq_u8(byte_index_vec, vdupq_n_u8(16))); - uint8x16_t matches_vec = vorrq_u8(matches_top_vec, matches_bottom_vec); - // Istead of pure `vandq_u8`, we can immediately broadcast a match presence across each 8-bit word. - matches_vec = vtstq_u8(matches_vec, byte_mask_vec); - return _sz_vreinterpretq_u8_u4(matches_vec); -} + // 5. Compute the hash mix, that will be used to index into the fingerprint. + // This includes a serial step at the end. + sz_u512_vec_t hash_mix_vec; + hash_mix_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, golden_ratio_vec.zmm); + hash_mix_vec.ymms[0] = _mm256_xor_si256(_mm512_extracti64x4_epi64(hash_mix_vec.zmm, 1), // + _mm512_extracti64x4_epi64(hash_mix_vec.zmm, 0)); -SZ_PUBLIC sz_cptr_t sz_find_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { + callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); + callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); + callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); + callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_find_byte_neon(h, h_length, n); + // Now repeat that operation for the remaining characters, discarding older characters. + sz_size_t cycle = 1; + sz_size_t step_mask = step - 1; + for (; text_fourth != text_end; ++text_first, ++text_second, ++text_third, ++text_fourth, ++cycle) { + // 0. Load again the four characters we are dropping, shift them, and subtract. + chars_vec.zmm = _mm512_set_epi64(text_fourth[-window_length], text_third[-window_length], + text_second[-window_length], text_first[-window_length], // + text_fourth[-window_length], text_third[-window_length], + text_second[-window_length], text_first[-window_length]); + chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm); + hash_vec.zmm = _mm512_sub_epi64(hash_vec.zmm, _mm512_mullo_epi64(chars_vec.zmm, prime_power_vec.zmm)); - // Scan through the string. - // Assuming how tiny the Arm NEON registers are, we should avoid internal branches at all costs. 
- // That's why, for smaller needles, we use different loops. - if (n_length == 2) { - // Broadcast needle characters into SIMD registers. - sz_u64_t matches; - sz_u128_vec_t h_first_vec, h_last_vec, n_first_vec, n_last_vec, matches_vec; - // Dealing with 16-bit values, we can load 2 registers at a time and compare 31 possible offsets - // in a single loop iteration. - n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[0]); - n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[1]); - for (; h_length >= 17; h += 16, h_length -= 16) { - h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 0)); - h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 1)); - matches_vec.u8x16 = - vandq_u8(vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - if (matches) return h + sz_u64_ctz(matches) / 4; - } - } - else if (n_length == 3) { - // Broadcast needle characters into SIMD registers. - sz_u64_t matches; - sz_u128_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec, matches_vec; - // Comparing 24-bit values is a bumer. Being lazy, I went with the same approach - // as when searching for string over 4 characters long. I only avoid the last comparison. - n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[0]); - n_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[1]); - n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[2]); - for (; h_length >= 18; h += 16, h_length -= 16) { - h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 0)); - h_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 1)); - h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 2)); - matches_vec.u8x16 = vandq_u8( // - vandq_u8( // - vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), // - vceqq_u8(h_mid_vec.u8x16, n_mid_vec.u8x16)), - vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - if (matches) return h + sz_u64_ctz(matches) / 4; - } - } - else { - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - // Broadcast those characters into SIMD registers. - sz_u64_t matches; - sz_u128_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec, matches_vec; - n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_first]); - n_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_mid]); - n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_last]); - // Walk through the string. - for (; h_length >= n_length + 16; h += 16, h_length -= 16) { - h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_first)); - h_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_mid)); - h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_last)); - matches_vec.u8x16 = vandq_u8( // - vandq_u8( // - vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), // - vceqq_u8(h_mid_vec.u8x16, n_mid_vec.u8x16)), - vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - while (matches) { - int potential_offset = sz_u64_ctz(matches) / 4; - if (sz_equal(h + potential_offset, n, n_length)) return h + potential_offset; - matches &= matches - 1; - } - } - } + // 1. Multiply the hashes by the base. + hash_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, base_vec.zmm); - return sz_find_serial(h, h_length, n, n_length); -} + // 2. 
Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, + // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. + chars_vec.zmm = _mm512_set_epi64(text_fourth[0], text_third[0], text_second[0], text_first[0], // + text_fourth[0], text_third[0], text_second[0], text_first[0]); + chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm); -SZ_PUBLIC sz_cptr_t sz_rfind_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { + // ... and prefetch the next four characters into Level 2 or higher. + _mm_prefetch((sz_cptr_t)text_fourth + 1, _MM_HINT_T1); + _mm_prefetch((sz_cptr_t)text_third + 1, _MM_HINT_T1); + _mm_prefetch((sz_cptr_t)text_second + 1, _MM_HINT_T1); + _mm_prefetch((sz_cptr_t)text_first + 1, _MM_HINT_T1); - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_rfind_byte_neon(h, h_length, n); + // 3. Add the incoming characters. + hash_vec.zmm = _mm512_add_epi64(hash_vec.zmm, chars_vec.zmm); - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); + // 4. Compute the modulo. Assuming there are only 59 values between our prime + // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. + hash_vec.zmm = _mm512_mask_blend_epi8(_mm512_cmpgt_epi64_mask(hash_vec.zmm, prime_vec.zmm), hash_vec.zmm, + _mm512_sub_epi64(hash_vec.zmm, prime_vec.zmm)); - // Will contain 4 bits per character. - sz_u64_t matches; - sz_u128_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec, matches_vec; - n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_first]); - n_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_mid]); - n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_last]); + // 5. Compute the hash mix, that will be used to index into the fingerprint. + // This includes a serial step at the end. 
+ hash_mix_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, golden_ratio_vec.zmm); + hash_mix_vec.ymms[0] = _mm256_xor_si256(_mm512_extracti64x4_epi64(hash_mix_vec.zmm, 1), // + _mm512_castsi512_si256(hash_mix_vec.zmm)); - sz_cptr_t h_reversed; - for (; h_length >= n_length + 16; h_length -= 16) { - h_reversed = h + h_length - n_length - 16 + 1; - h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h_reversed + offset_first)); - h_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h_reversed + offset_mid)); - h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h_reversed + offset_last)); - matches_vec.u8x16 = vandq_u8( // - vandq_u8( // - vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), // - vceqq_u8(h_mid_vec.u8x16, n_mid_vec.u8x16)), - vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - while (matches) { - int potential_offset = sz_u64_clz(matches) / 4; - if (sz_equal(h + h_length - n_length - potential_offset, n, n_length)) - return h + h_length - n_length - potential_offset; - sz_assert((matches & (1ull << (63 - potential_offset * 4))) != 0 && - "The bit must be set before we squash it"); - matches &= ~(1ull << (63 - potential_offset * 4)); + if ((cycle & step_mask) == 0) { + callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); + callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); + callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); + callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); } } - - return sz_rfind_serial(h, h_length, n, n_length); } -SZ_PUBLIC sz_cptr_t sz_find_charset_neon(sz_cptr_t h, sz_size_t h_length, sz_charset_t const *set) { - sz_u64_t matches; - sz_u128_vec_t h_vec; - uint8x16_t set_top_vec_u8x16 = vld1q_u8(&set->_u8s[0]); - uint8x16_t set_bottom_vec_u8x16 = vld1q_u8(&set->_u8s[16]); - - for (; h_length >= 16; h += 16, h_length -= 16) { - h_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h)); - matches = _sz_find_charset_neon_register(h_vec, set_top_vec_u8x16, set_bottom_vec_u8x16); - if (matches) return h + sz_u64_ctz(matches) / 4; - } +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SZ_USE_ICE +#pragma endregion // Ice Lake Implementation - return sz_find_charset_serial(h, h_length, set); -} +/* Implementation of the string hashing algorithms using the Arm NEON instruction set, available on 64-bit + * Arm processors. Covers billions of mobile CPUs worldwide, including Apple's A-series, and Qualcomm's Snapdragon. + */ +#pragma region NEON Implementation +#if SZ_USE_NEON +#pragma GCC push_options +#pragma GCC target("arch=armv8.2-a+simd") +#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd"))), apply_to = function) -SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t h, sz_size_t h_length, sz_charset_t const *set) { - sz_u64_t matches; - sz_u128_vec_t h_vec; - uint8x16_t set_top_vec_u8x16 = vld1q_u8(&set->_u8s[0]); - uint8x16_t set_bottom_vec_u8x16 = vld1q_u8(&set->_u8s[16]); +SZ_PUBLIC sz_u64_t sz_checksum_neon(sz_cptr_t text, sz_size_t length) { + uint64x2_t sum_vec = vdupq_n_u64(0); - // Check `sz_find_charset_neon` for explanations. 
- for (; h_length >= 16; h_length -= 16) { - h_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h) + h_length - 16); - matches = _sz_find_charset_neon_register(h_vec, set_top_vec_u8x16, set_bottom_vec_u8x16); - if (matches) return h + h_length - 1 - sz_u64_clz(matches) / 4; + // Process 16 bytes (128 bits) at a time + for (; length >= 16; text += 16, length -= 16) { + uint8x16_t vec = vld1q_u8((sz_u8_t const *)text); // Load 16 bytes + uint16x8_t pairwise_sum1 = vpaddlq_u8(vec); // Pairwise add lower and upper 8 bits + uint32x4_t pairwise_sum2 = vpaddlq_u16(pairwise_sum1); // Pairwise add 16-bit results + uint64x2_t pairwise_sum3 = vpaddlq_u32(pairwise_sum2); // Pairwise add 32-bit results + sum_vec = vaddq_u64(sum_vec, pairwise_sum3); // Accumulate the sum } - return sz_rfind_charset_serial(h, h_length, set); + // Final reduction of `sum_vec` to a single scalar + sz_u64_t sum = vgetq_lane_u64(sum_vec, 0) + vgetq_lane_u64(sum_vec, 1); + if (length) sum += sz_checksum_serial(text, length); + return sum; } #pragma clang attribute pop #pragma GCC pop_options -#endif // Arm Neon - -#pragma endregion +#endif // SZ_USE_NEON +#pragma endregion // NEON Implementation -/* @brief Implementation of the string search algorithms using the Arm SVE variable-length registers, available - * in Arm v9 processors. - * - * Implements: - * - memory: {copy, move, fill} - * - comparisons: {equal, order} - * - search: {substring, character, character set} x {forward, reverse}. +/* Implementation of the string search algorithms using the Arm SVE variable-length registers, + * available in Arm v9 processors, like in Apple M4+ and Graviton 3+ CPUs. */ -#pragma region ARM SVE - -#if SZ_USE_ARM_SVE +#pragma region SVE Implementation +#if SZ_USE_SVE #pragma GCC push_options #pragma GCC target("arch=armv8.2-a+sve") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve"))), apply_to = function) -SZ_PUBLIC void sz_fill_sve(sz_ptr_t target, sz_size_t length, sz_u8_t value) { - svuint8_t value_vec = svdup_u8(value); - sz_size_t vec_len = svcntb(); // Vector length in bytes (scalable) - - if (length <= vec_len) { - // Small buffer case: use mask to handle small writes - svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)length); - svst1_u8(mask, (unsigned char *)target, value_vec); - } - else { - // Calculate head, body, and tail sizes - sz_size_t head_length = vec_len - ((sz_size_t)target % vec_len); - sz_size_t tail_length = (sz_size_t)(target + length) % vec_len; - sz_size_t body_length = length - head_length - tail_length; - - // Handle unaligned head - svbool_t head_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)head_length); - svst1_u8(head_mask, (unsigned char *)target, value_vec); - target += head_length; - - // Aligned body loop - for (; body_length >= vec_len; target += vec_len, body_length -= vec_len) { - svst1_u8(svptrue_b8(), (unsigned char *)target, value_vec); - } - - // Handle unaligned tail - svbool_t tail_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)tail_length); - svst1_u8(tail_mask, (unsigned char *)target, value_vec); - } -} - -SZ_PUBLIC void sz_copy_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - sz_size_t vec_len = svcntb(); // Vector length in bytes - - // Arm Neoverse V2 cores in Graviton 4, for example, come with 256 KB of L1 data cache per core, - // and 8 MB of L2 cache per core. Moreover, the L1 cache is fully associative. - // With two strings, we may consider the overal workload huge, if each exceeds 1 MB in length. 
- // - // int is_huge = length >= 4ull * 1024ull * 1024ull; - // - // When the buffer is small, there isn't much to innovate. - if (length <= vec_len) { - // Small buffer case: use mask to handle small writes - svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)length); - svuint8_t data = svld1_u8(mask, (unsigned char *)source); - svst1_u8(mask, (unsigned char *)target, data); - } - // When dealing with larger buffers, similar to AVX-512, we want minimize unaligned operations - // and handle the head, body, and tail separately. We can also traverse the buffer in both directions - // as Arm generally supports more simultaneous stores than x86 CPUs. - // - // For gigantic datasets, similar to AVX-512, non-temporal "loads" and "stores" can be used. - // Sadly, if the register size (16 byte or larger) is smaller than a cache-line (64 bytes) - // we will pay a huge penalty on loads, fetching the same content many times. - // It may be better to allow caching (and subsequent eviction), in favor of using four-element - // tuples, wich will be guaranteed to be a multiple of a cache line. - // - // Another approach is to use the `LD4B` instructions, which will populate four registers at once. - // This however, further decreases the performance from LibC-like 29 GB/s to 20 GB/s. - else { - // Calculating head, body, and tail sizes depends on the `vec_len`, - // but it's runtime constant, and the modulo operation is expensive! - // Instead we use the fact, that it's always a multiple of 128 bits or 16 bytes. - sz_size_t head_length = 16 - ((sz_size_t)target % 16); - sz_size_t tail_length = (sz_size_t)(target + length) % 16; - sz_size_t body_length = length - head_length - tail_length; - - // Handle unaligned parts - svbool_t head_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)head_length); - svuint8_t head_data = svld1_u8(head_mask, (unsigned char *)source); - svst1_u8(head_mask, (unsigned char *)target, head_data); - svbool_t tail_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)tail_length); - svuint8_t tail_data = svld1_u8(tail_mask, (unsigned char *)source + head_length + body_length); - svst1_u8(tail_mask, (unsigned char *)target + head_length + body_length, tail_data); - target += head_length; - source += head_length; - - // Aligned body loop, walking in two directions - for (; body_length >= vec_len * 2; target += vec_len, source += vec_len, body_length -= vec_len * 2) { - svuint8_t forward_data = svld1_u8(svptrue_b8(), (unsigned char *)source); - svuint8_t backward_data = svld1_u8(svptrue_b8(), (unsigned char *)source + body_length - vec_len); - svst1_u8(svptrue_b8(), (unsigned char *)target, forward_data); - svst1_u8(svptrue_b8(), (unsigned char *)target + body_length - vec_len, backward_data); - } - // Up to (vec_len * 2 - 1) bytes of data may be left in the body, - // so we can unroll the last two optional loop iterations. 
- if (body_length > vec_len) { - svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)body_length); - svuint8_t data = svld1_u8(mask, (unsigned char *)source); - svst1_u8(mask, (unsigned char *)target, data); - body_length -= vec_len; - source += body_length; - target += body_length; - } - if (body_length) { - svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)body_length); - svuint8_t data = svld1_u8(mask, (unsigned char *)source); - svst1_u8(mask, (unsigned char *)target, data); - } - } -} - #pragma clang attribute pop #pragma GCC pop_options -#endif // Arm SVE +#endif // SZ_USE_SVE +#pragma endregion // SVE Implementation -#pragma endregion - -/* - * @brief Pick the right implementation for the string search algorithms. +/* Pick the right implementation for the string search algorithms. + * To override this behavior and precompile all backends - set `SZ_DYNAMIC_DISPATCH` to 1. */ #pragma region Compile Time Dispatching - -SZ_PUBLIC sz_u64_t sz_hash(sz_cptr_t ins, sz_size_t length) { return sz_hash_serial(ins, length); } -SZ_PUBLIC void sz_tolower(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) { sz_tolower_serial(ins, length, outs); } -SZ_PUBLIC void sz_toupper(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) { sz_toupper_serial(ins, length, outs); } -SZ_PUBLIC void sz_toascii(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) { sz_toascii_serial(ins, length, outs); } -SZ_PUBLIC sz_bool_t sz_isascii(sz_cptr_t ins, sz_size_t length) { return sz_isascii_serial(ins, length); } - -SZ_PUBLIC void sz_hashes_fingerprint(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_ptr_t fingerprint, - sz_size_t fingerprint_bytes) { - - sz_bool_t fingerprint_length_is_power_of_two = (sz_bool_t)((fingerprint_bytes & (fingerprint_bytes - 1)) == 0); - sz_string_view_t fingerprint_buffer = {fingerprint, fingerprint_bytes}; - - // There are several issues related to the fingerprinting algorithm. - // First, the memory traversal order is important. - // https://blog.stuffedcow.net/2015/08/pagewalk-coherence/ - - // In most cases the fingerprint length will be a power of two. 
- if (fingerprint_length_is_power_of_two == sz_false_k) - sz_hashes(start, length, window_length, 1, _sz_hashes_fingerprint_non_pow2_callback, &fingerprint_buffer); - else - sz_hashes(start, length, window_length, 1, _sz_hashes_fingerprint_pow2_callback, &fingerprint_buffer); -} - #if !SZ_DYNAMIC_DISPATCH SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length) { -#if SZ_USE_X86_AVX512 - return sz_checksum_avx512(text, length); -#elif SZ_USE_X86_AVX2 +#if SZ_USE_ICE + return sz_checksum_ice(text, length); +#elif SZ_USE_HASWELL return sz_checksum_avx2(text, length); -#elif SZ_USE_ARM_NEON +#elif SZ_USE_NEON return sz_checksum_neon(text, length); #else return sz_checksum_serial(text, length); #endif } -SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { -#if SZ_USE_X86_AVX512 - return sz_equal_avx512(a, b, length); -#elif SZ_USE_X86_AVX2 - return sz_equal_avx2(a, b, length); -#elif SZ_USE_ARM_NEON - return sz_equal_neon(a, b, length); -#else - return sz_equal_serial(a, b, length); -#endif -} - -SZ_DYNAMIC sz_ordering_t sz_order(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { -#if SZ_USE_X86_AVX512 - return sz_order_avx512(a, a_length, b, b_length); -#elif SZ_USE_X86_AVX2 - return sz_order_avx2(a, a_length, b, b_length); -#elif SZ_USE_ARM_NEON - return sz_order_neon(a, a_length, b, b_length); -#else - return sz_order_serial(a, a_length, b, b_length); -#endif -} - -SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { -#if SZ_USE_X86_AVX512 - sz_copy_avx512(target, source, length); -#elif SZ_USE_X86_AVX2 - sz_copy_avx2(target, source, length); -#elif SZ_USE_ARM_NEON - sz_copy_neon(target, source, length); -#else - sz_copy_serial(target, source, length); -#endif -} - -SZ_DYNAMIC void sz_move(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { -#if SZ_USE_X86_AVX512 - sz_move_avx512(target, source, length); -#elif SZ_USE_X86_AVX2 - sz_move_avx2(target, source, length); -#elif SZ_USE_ARM_NEON - sz_move_neon(target, source, length); -#else - sz_move_serial(target, source, length); -#endif -} - -SZ_DYNAMIC void sz_fill(sz_ptr_t target, sz_size_t length, sz_u8_t value) { -#if SZ_USE_X86_AVX512 - sz_fill_avx512(target, length, value); -#elif SZ_USE_X86_AVX2 - sz_fill_avx2(target, length, value); -#elif SZ_USE_ARM_NEON - sz_fill_neon(target, length, value); -#else - sz_fill_serial(target, length, value); -#endif -} - -SZ_DYNAMIC void sz_look_up_transform(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { -#if SZ_USE_X86_AVX512 - sz_look_up_transform_avx512(source, length, lut, target); -#elif SZ_USE_X86_AVX2 - sz_look_up_transform_avx2(source, length, lut, target); -#elif SZ_USE_ARM_NEON - sz_look_up_transform_neon(source, length, lut, target); -#else - sz_look_up_transform_serial(source, length, lut, target); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) { -#if SZ_USE_X86_AVX512 - return sz_find_byte_avx512(haystack, h_length, needle); -#elif SZ_USE_X86_AVX2 - return sz_find_byte_avx2(haystack, h_length, needle); -#elif SZ_USE_ARM_NEON - return sz_find_byte_neon(haystack, h_length, needle); -#else - return sz_find_byte_serial(haystack, h_length, needle); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_rfind_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) { -#if SZ_USE_X86_AVX512 - return sz_rfind_byte_avx512(haystack, h_length, needle); -#elif SZ_USE_X86_AVX2 - return sz_rfind_byte_avx2(haystack, h_length, needle); -#elif SZ_USE_ARM_NEON 
- return sz_rfind_byte_neon(haystack, h_length, needle); -#else - return sz_rfind_byte_serial(haystack, h_length, needle); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_find(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length) { -#if SZ_USE_X86_AVX512 - return sz_find_avx512(haystack, h_length, needle, n_length); -#elif SZ_USE_X86_AVX2 - return sz_find_avx2(haystack, h_length, needle, n_length); -#elif SZ_USE_ARM_NEON - return sz_find_neon(haystack, h_length, needle, n_length); -#else - return sz_find_serial(haystack, h_length, needle, n_length); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_rfind(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length) { -#if SZ_USE_X86_AVX512 - return sz_rfind_avx512(haystack, h_length, needle, n_length); -#elif SZ_USE_X86_AVX2 - return sz_rfind_avx2(haystack, h_length, needle, n_length); -#elif SZ_USE_ARM_NEON - return sz_rfind_neon(haystack, h_length, needle, n_length); -#else - return sz_rfind_serial(haystack, h_length, needle, n_length); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { -#if SZ_USE_X86_AVX512 - return sz_find_charset_avx512(text, length, set); -#elif SZ_USE_X86_AVX2 - return sz_find_charset_avx2(text, length, set); -#elif SZ_USE_ARM_NEON - return sz_find_charset_neon(text, length, set); -#else - return sz_find_charset_serial(text, length, set); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_rfind_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { -#if SZ_USE_X86_AVX512 - return sz_rfind_charset_avx512(text, length, set); -#elif SZ_USE_X86_AVX2 - return sz_rfind_charset_avx2(text, length, set); -#elif SZ_USE_ARM_NEON - return sz_rfind_charset_neon(text, length, set); -#else - return sz_rfind_charset_serial(text, length, set); -#endif -} - -SZ_DYNAMIC sz_size_t sz_hamming_distance( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - return sz_hamming_distance_serial(a, a_length, b, b_length, bound); -} - -SZ_DYNAMIC sz_size_t sz_hamming_distance_utf8( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - return sz_hamming_distance_utf8_serial(a, a_length, b, b_length, bound); -} - -SZ_DYNAMIC sz_size_t sz_edit_distance( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { -#if SZ_USE_X86_AVX512 - return sz_edit_distance_avx512(a, a_length, b, b_length, bound, alloc); -#else - return sz_edit_distance_serial(a, a_length, b, b_length, bound, alloc); -#endif -} - -SZ_DYNAMIC sz_size_t sz_edit_distance_utf8( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - return _sz_edit_distance_wagner_fisher_serial(a, a_length, b, b_length, bound, sz_true_k, alloc); -} - -SZ_DYNAMIC sz_ssize_t sz_alignment_score(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, - sz_error_cost_t const *subs, sz_error_cost_t gap, - sz_memory_allocator_t *alloc) { -#if SZ_USE_X86_AVX512 - return sz_alignment_score_avx512(a, a_length, b, b_length, subs, gap, alloc); -#else - return sz_alignment_score_serial(a, a_length, b, b_length, subs, gap, alloc); -#endif -} - SZ_DYNAMIC void sz_hashes(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // sz_hash_callback_t callback, void *callback_handle) { -#if SZ_USE_X86_AVX512 - sz_hashes_avx512(text, length, window_length, 
window_step, callback, callback_handle); -#elif SZ_USE_X86_AVX2 +#if SZ_USE_ICE + sz_hashes_ice(text, length, window_length, window_step, callback, callback_handle); +#elif SZ_USE_HASWELL sz_hashes_avx2(text, length, window_length, window_step, callback, callback_handle); #else sz_hashes_serial(text, length, window_length, window_step, callback, callback_handle); #endif } -SZ_DYNAMIC sz_cptr_t sz_find_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - return sz_find_charset(h, h_length, &set); -} - -SZ_DYNAMIC sz_cptr_t sz_find_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - sz_charset_invert(&set); - return sz_find_charset(h, h_length, &set); -} - -SZ_DYNAMIC sz_cptr_t sz_rfind_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - return sz_rfind_charset(h, h_length, &set); -} - -SZ_DYNAMIC sz_cptr_t sz_rfind_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - sz_charset_invert(&set); - return sz_rfind_charset(h, h_length, &set); -} - SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length, sz_random_generator_t generator, void *generator_user_data) { sz_generate_serial(alphabet, alphabet_size, result, result_length, generator, generator_user_data); } -#endif -#pragma endregion +#endif // !SZ_DYNAMIC_DISPATCH +#pragma endregion // Compile Time Dispatching #ifdef __cplusplus -#pragma GCC diagnostic pop } #endif // __cplusplus - -#endif // STRINGZILLA_H_ +#endif // STRINGZILLA_HASH_H_ From 86f53d99b93b0495b3f8f5ce81e607e1dc80e765 Mon Sep 17 00:00:00 2001 From: Alex Bondarev <44079602+alexbarev@users.noreply.github.com> Date: Sat, 7 Dec 2024 21:56:31 +0400 Subject: [PATCH 035/751] Test: Add ASCII utilities tests exposing final character exclusion bug --- scripts/test.cpp | 45 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/scripts/test.cpp b/scripts/test.cpp index eecc97f0..e8123995 100644 --- a/scripts/test.cpp +++ b/scripts/test.cpp @@ -137,6 +137,49 @@ static void test_arithmetical_utilities() { (static_cast(number) / static_cast(divisor))); } +/** + * @brief Tests various ASCII-based methods (e.g., is_alpha, is_digit) + * provided by `sz::string` and `sz::string_view`. 
+ */ +template +static void test_ascii_utilities() { + + using str = string_type; + + assert(!str("").is_alpha()); + assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ").is_alpha()); + assert(!str("abc9").is_alpha()); + + assert(!str("").is_alnum()); + assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789").is_alnum()); + assert(!str("abc!").is_alnum()); + + assert(!str("").is_ascii()); + assert(str("\x00x7F").is_ascii()); + assert(!str("abc123🔥").is_ascii()); + + assert(!str("").is_digit()); + assert(str("0123456789").is_digit()); + assert(!str("012a").is_digit()); + + assert(!str("").is_lower()); + assert(str("abcdefghijklmnopqrstuvwxyz").is_lower()); + assert(!str("abcA").is_lower()); + assert(!str("abc\n").is_lower()); + + assert(!str("").is_space()); + assert(str(" \t\n\r\f\v").is_space()); + assert(!str(" \t\r\na").is_space()); + + assert(!str("").is_upper()); + assert(str("ABCDEFGHIJKLMNOPQRSTUVWXYZ").is_upper()); + assert(!str("ABCa").is_upper()); + + assert(!str("").is_printable()); + assert(str("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!@#$%^&*()_+").is_printable()); + assert(!str("012\n").is_printable()); +} + inline void expect_equality(char const *a, char const *b, std::size_t size) { if (std::memcmp(a, b, size) == 0) return; std::size_t mismatch_position = 0; @@ -1583,6 +1626,8 @@ int main(int argc, char const **argv) { // Basic utilities test_arithmetical_utilities(); + test_ascii_utilities(); + test_ascii_utilities(); test_memory_utilities(); test_replacements(); From 8b44d6a5fe4d4ee3cf38d76d2d690bcf5b1a8a2d Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 18:26:40 +0000 Subject: [PATCH 036/751] Improve: Platform-specific equality checks --- include/stringzilla/find.h | 43 +++++++++++++------------------------- 1 file changed, 15 insertions(+), 28 deletions(-) diff --git a/include/stringzilla/find.h b/include/stringzilla/find.h index a51bd4c6..4571515d 100644 --- a/include/stringzilla/find.h +++ b/include/stringzilla/find.h @@ -305,20 +305,6 @@ SZ_PUBLIC sz_cptr_t sz_rfind_charset_serial(sz_cptr_t text, sz_size_t length, sz #pragma GCC diagnostic pop } -/** - * @brief Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each byte is set. - // For that take the bottom 7 bits of each byte, add one to them, - // and if this sets the top bit to one, then all the 7 bits are ones as well. - vec.u64 = ((vec.u64 & 0x7F7F7F7F7F7F7F7Full) + 0x0101010101010101ull) & ((vec.u64 & 0x8080808080808080ull)); - return vec; -} - /* Find the first occurrence of a @b single-character needle in an arbitrary length haystack. * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. * Identical to `memchr(haystack, needle[0], haystack_length)`. 
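The comments around the `_sz_u64_each_byte_equal` helper removed in the hunk above describe the SWAR trick behind the serial byte-search path. For readers skimming the diff, here is a minimal, self-contained sketch of that trick; `broadcast_byte`, `each_byte_equal`, and the `main` driver are illustrative names, not StringZilla API, and the lowest-set-bit indexing at the end only matches the library's behavior on little-endian targets.

#include <stdint.h>
#include <stdio.h>
#include <string.h>

/* Broadcast one byte into all eight lanes of a 64-bit word. */
static uint64_t broadcast_byte(uint8_t b) { return 0x0101010101010101ull * b; }

/* Set the top bit of every byte of `word` that equals `b`, using the same
 * "mask the low 7 bits, add one, and keep the carry" trick as the removed helper:
 * a lane carries into its top bit only if all 7 low bits of ~(word ^ pattern) are ones,
 * and the final AND keeps that carry only if the 8th bit matched as well. */
static uint64_t each_byte_equal(uint64_t word, uint8_t b) {
    uint64_t x = ~(word ^ broadcast_byte(b));
    return ((x & 0x7F7F7F7F7F7F7F7Full) + 0x0101010101010101ull) & (x & 0x8080808080808080ull);
}

int main(void) {
    char const text[] = "abcdefgh";
    uint64_t word;
    memcpy(&word, text, sizeof(word)); /* one 8-byte load, alignment handled by memcpy */
    uint64_t matches = each_byte_equal(word, 'd');
    /* On little-endian targets the first match is the lowest set bit; GCC/Clang builtin. */
    if (matches) printf("first 'd' at offset %d\n", __builtin_ctzll(matches) / 8);
    return 0;
}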
@@ -895,7 +881,7 @@ SZ_PUBLIC sz_cptr_t sz_find_haswell(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_last_vec.ymm, n_last_vec.ymm)); while (matches) { int potential_offset = sz_u32_ctz(matches); - if (sz_equal(h + potential_offset, n, n_length)) return h + potential_offset; + if (sz_equal_haswell(h + potential_offset, n, n_length)) return h + potential_offset; matches &= matches - 1; } } @@ -933,7 +919,7 @@ SZ_PUBLIC sz_cptr_t sz_rfind_haswell(sz_cptr_t h, sz_size_t h_length, sz_cptr_t _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_last_vec.ymm, n_last_vec.ymm)); while (matches) { int potential_offset = sz_u32_clz(matches); - if (sz_equal(h + h_length - n_length - potential_offset, n, n_length)) + if (sz_equal_haswell(h + h_length - n_length - potential_offset, n, n_length)) return h + h_length - n_length - potential_offset; matches &= ~(1 << (31 - potential_offset)); } @@ -1074,7 +1060,7 @@ SZ_PUBLIC sz_bool_t sz_equal_skylake(sz_cptr_t a, sz_cptr_t b, sz_size_t length) return sz_true_k; } -SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { +SZ_PUBLIC sz_cptr_t sz_find_byte_skylake(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { __mmask64 mask; sz_u512_vec_t h_vec, n_vec; n_vec.zmm = _mm512_set1_epi8(n[0]); @@ -1101,7 +1087,7 @@ SZ_PUBLIC sz_cptr_t sz_find_skylake(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n // This almost never fires, but it's better to be safe than sorry. if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_find_byte_avx512(h, h_length, n); + if (n_length == 1) return sz_find_byte_skylake(h, h_length, n); // Pick the parts of the needle that are worth comparing. sz_size_t offset_first, offset_mid, offset_last; @@ -1198,7 +1184,7 @@ SZ_PUBLIC sz_cptr_t sz_find_skylake(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n return SZ_NULL_CHAR; } -SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { +SZ_PUBLIC sz_cptr_t sz_rfind_byte_skylake(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { __mmask64 mask; sz_u512_vec_t h_vec, n_vec; n_vec.zmm = _mm512_set1_epi8(n[0]); @@ -1225,7 +1211,7 @@ SZ_PUBLIC sz_cptr_t sz_rfind_skylake(sz_cptr_t h, sz_size_t h_length, sz_cptr_t // This almost never fires, but it's better to be safe than sorry. if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_rfind_byte_avx512(h, h_length, n); + if (n_length == 1) return sz_rfind_byte_skylake(h, h_length, n); // Pick the parts of the needle that are worth comparing. 
sz_size_t offset_first, offset_mid, offset_last; @@ -1583,7 +1569,7 @@ SZ_PUBLIC sz_cptr_t sz_find_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, s matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); while (matches) { int potential_offset = sz_u64_ctz(matches) / 4; - if (sz_equal(h + potential_offset, n, n_length)) return h + potential_offset; + if (sz_equal_neon(h + potential_offset, n, n_length)) return h + potential_offset; matches &= matches - 1; } } @@ -1623,7 +1609,7 @@ SZ_PUBLIC sz_cptr_t sz_rfind_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); while (matches) { int potential_offset = sz_u64_clz(matches) / 4; - if (sz_equal(h + h_length - n_length - potential_offset, n, n_length)) + if (sz_equal_neon(h + h_length - n_length - potential_offset, n, n_length)) return h + h_length - n_length - potential_offset; sz_assert((matches & (1ull << (63 - potential_offset * 4))) != 0 && "The bit must be set before we squash it"); @@ -1678,6 +1664,7 @@ SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t h, sz_size_t h_length, sz_ch #pragma GCC push_options #pragma GCC target("arch=armv8.2-a+sve") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve"))), apply_to = function) + #pragma clang attribute pop #pragma GCC pop_options #endif // SZ_USE_SVE @@ -1692,8 +1679,8 @@ SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t h, sz_size_t h_length, sz_ch #pragma region Core Funcitonality SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) { -#if SZ_USE_ICE - return sz_find_byte_avx512(haystack, h_length, needle); +#if SZ_USE_SKYLAKE + return sz_find_byte_skylake(haystack, h_length, needle); #elif SZ_USE_HASWELL return sz_find_byte_haswell(haystack, h_length, needle); #elif SZ_USE_NEON @@ -1704,8 +1691,8 @@ SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cpt } SZ_DYNAMIC sz_cptr_t sz_rfind_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) { -#if SZ_USE_ICE - return sz_rfind_byte_avx512(haystack, h_length, needle); +#if SZ_USE_SKYLAKE + return sz_rfind_byte_skylake(haystack, h_length, needle); #elif SZ_USE_HASWELL return sz_rfind_byte_haswell(haystack, h_length, needle); #elif SZ_USE_NEON @@ -1716,7 +1703,7 @@ SZ_DYNAMIC sz_cptr_t sz_rfind_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cp } SZ_DYNAMIC sz_cptr_t sz_find(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length) { -#if SZ_USE_ICE +#if SZ_USE_SKYLAKE return sz_find_skylake(haystack, h_length, needle, n_length); #elif SZ_USE_HASWELL return sz_find_haswell(haystack, h_length, needle, n_length); @@ -1728,7 +1715,7 @@ SZ_DYNAMIC sz_cptr_t sz_find(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t n } SZ_DYNAMIC sz_cptr_t sz_rfind(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length) { -#if SZ_USE_ICE +#if SZ_USE_SKYLAKE return sz_rfind_skylake(haystack, h_length, needle, n_length); #elif SZ_USE_HASWELL return sz_rfind_haswell(haystack, h_length, needle, n_length); From 4a1f03c46b4f60be3b28e31f58b500734e23699e Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 18:32:35 +0000 Subject: [PATCH 037/751] Make: Separate builds for Skylake & Ice --- CMakeLists.txt | 24 +++++++++++++----------- README.md | 2 +- build.rs | 46 ++++++++++++++++++++++++++-------------------- c/lib.c | 41 +++++++++++++++++++++++------------------ setup.py | 27 +++++++++++++++------------ 5 files changed, 78 
insertions(+), 62 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 93a9b847..c09fd6e7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -46,7 +46,7 @@ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|AARCH64|arm64|ARM64") message(STATUS "Platform: ARM") endif() -# Determine if StringZilla is built as a subproject (using `add_subdirectory`) +# Determine if StringZilla is built as a sub-project (using `add_subdirectory`) # or if it is the main project set(STRINGZILLA_IS_MAIN_PROJECT OFF) @@ -99,7 +99,7 @@ endif() if (MSVC) # Remove /RTC* from MSVC debug flags by default (it will be added back in the set_compiler_flags function) - # Beacuse /RTC* cannot be used without the crt so it needs to be disabled for that specifc target + # Because /RTC* cannot be used without the crt so it needs to be disabled for that specific target string(REGEX REPLACE "/RTC[^ ]*" "" CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG}") string(REGEX REPLACE "/RTC[^ ]*" "" CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG}") endif() @@ -303,18 +303,20 @@ if(${STRINGZILLA_BUILD_SHARED}) endif() target_compile_definitions(${target} PRIVATE - "SZ_USE_X86_AVX512=1" - "SZ_USE_X86_AVX2=1" - "SZ_USE_ARM_NEON=0" - "SZ_USE_ARM_SVE=0") + "SZ_USE_HASWELL=1" + "SZ_USE_SKYLAKE=1" + "SZ_USE_ICE=1" + "SZ_USE_NEON=0" + "SZ_USE_SVE=0") elseif(SZ_PLATFORM_ARM) set_compiler_flags(${target} "" "armv8-a") target_compile_definitions(${target} PRIVATE - "SZ_USE_X86_AVX512=0" - "SZ_USE_X86_AVX2=0" - "SZ_USE_ARM_NEON=1" - "SZ_USE_ARM_SVE=1") + "SZ_USE_HASWELL=0" + "SZ_USE_SKYLAKE=0" + "SZ_USE_ICE=0" + "SZ_USE_NEON=1" + "SZ_USE_SVE=1") endif() if (MSVC) @@ -337,7 +339,7 @@ if(${STRINGZILLA_BUILD_SHARED}) target_compile_definitions(stringzillite PRIVATE "SZ_AVOID_LIBC=1") target_compile_definitions(stringzillite PRIVATE "SZ_OVERRIDE_LIBC=1") - # Avoid built-ins on MSVC and other compilers, as that will cause compileration errors + # Avoid built-ins on MSVC and other compilers, as that will cause compilation errors target_compile_options(stringzillite PRIVATE "$<$:-fno-builtin;-nostdlib>" "$<$:/Oi-;/GS->") diff --git a/README.md b/README.md index d5c59ff9..c4122696 100644 --- a/README.md +++ b/README.md @@ -1172,7 +1172,7 @@ __`SZ_DEBUG`__: > If you want to enable more aggressive bounds-checking, define `SZ_DEBUG` before including the header. > If not explicitly set, it will be inferred from the build type. -__`SZ_USE_X86_AVX512`, `SZ_USE_X86_AVX2`, `SZ_USE_ARM_NEON`__: +__`SZ_USE_HASWELL`, `SZ_USE_SKYLAKE`, `SZ_USE_ICE`, `SZ_USE_NEON`, `SZ_USE_SVE`__: > One can explicitly disable certain families of SIMD instructions for compatibility purposes. > Default values are inferred at compile time. 
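Since the README hunk above only names the new toggles, a hedged usage sketch may help: assuming the headers keep the usual `#ifndef` guards around these macros (as the older `SZ_USE_X86_*` ones did), a consumer can pre-define them before the include, or pass equivalent `-D` flags, to opt out of a SIMD family at compile time. The include path and the tiny `main` below are illustrative only.

#define SZ_USE_ICE 0     /* skip the Ice Lake kernels (AVX-512 VBMI2/VBMI/BW/VL paths) */
#define SZ_USE_SKYLAKE 0 /* skip the Skylake-X kernels (AVX-512 F paths) */
#define SZ_USE_HASWELL 1 /* keep the Haswell kernels (AVX2 paths) */
#include <stringzilla/stringzilla.h>

int main(void) {
    char const haystack[] = "stringzilla";
    char const needle = 'z';
    sz_cptr_t match = sz_find_byte(haystack, sizeof(haystack) - 1, &needle);
    return match ? 0 : 1;
}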
diff --git a/build.rs b/build.rs index 8f7a130d..bb5fb5cf 100644 --- a/build.rs +++ b/build.rs @@ -25,20 +25,22 @@ fn main() { // Set architecture-specific flags and macros if target_arch == "x86_64" { - build.define("SZ_USE_X86_AVX512", "1"); - build.define("SZ_USE_X86_AVX2", "1"); + build.define("SZ_USE_HASWELL", "1"); + build.define("SZ_USE_SKYLAKE", "1"); + build.define("SZ_USE_ICE", "1"); } else { - build.define("SZ_USE_X86_AVX512", "0"); - build.define("SZ_USE_X86_AVX2", "0"); + build.define("SZ_USE_HASWELL", "0"); + build.define("SZ_USE_SKYLAKE", "0"); + build.define("SZ_USE_ICE", "0"); } if target_arch == "aarch64" { build.flag_if_supported("-march=armv8-a+simd"); - build.define("SZ_USE_ARM_SVE", "1"); - build.define("SZ_USE_ARM_NEON", "1"); + build.define("SZ_USE_NEON", "1"); + build.define("SZ_USE_SVE", "1"); } else { - build.define("SZ_USE_ARM_SVE", "0"); - build.define("SZ_USE_ARM_NEON", "0"); + build.define("SZ_USE_NEON", "0"); + build.define("SZ_USE_SVE", "0"); } } else if target.contains("darwin") { build.flag_if_supported("-fcolor-diagnostics"); @@ -47,28 +49,32 @@ fn main() { if target_arch == "x86_64" { // Assuming no AVX-512 support for Darwin as per setup.py logic - build.define("SZ_USE_X86_AVX512", "0"); - build.define("SZ_USE_X86_AVX2", "1"); + build.define("SZ_USE_HASWELL", "1"); + build.define("SZ_USE_SKYLAKE", "0"); + build.define("SZ_USE_ICE", "0"); } else { - build.define("SZ_USE_X86_AVX512", "0"); - build.define("SZ_USE_X86_AVX2", "0"); + build.define("SZ_USE_HASWELL", "0"); + build.define("SZ_USE_SKYLAKE", "0"); + build.define("SZ_USE_ICE", "0"); } if target_arch == "aarch64" { - build.define("SZ_USE_ARM_SVE", "0"); // Assuming no SVE support for Darwin - build.define("SZ_USE_ARM_NEON", "1"); + build.define("SZ_USE_NEON", "1"); + build.define("SZ_USE_SVE", "0"); // Assuming no SVE support for Darwin } else { - build.define("SZ_USE_ARM_SVE", "0"); - build.define("SZ_USE_ARM_NEON", "0"); + build.define("SZ_USE_NEON", "0"); + build.define("SZ_USE_SVE", "0"); } } else if target.contains("windows") { // Set architecture-specific flags and macros if target_arch == "x86_64" { - build.define("SZ_USE_X86_AVX512", "1"); - build.define("SZ_USE_X86_AVX2", "1"); + build.define("SZ_USE_HASWELL", "1"); + build.define("SZ_USE_SKYLAKE", "1"); + build.define("SZ_USE_ICE", "1"); } else { - build.define("SZ_USE_X86_AVX512", "0"); - build.define("SZ_USE_X86_AVX2", "0"); + build.define("SZ_USE_HASWELL", "0"); + build.define("SZ_USE_SKYLAKE", "0"); + build.define("SZ_USE_ICE", "0"); } } diff --git a/c/lib.c b/c/lib.c index 2394bf59..e1d98328 100644 --- a/c/lib.c +++ b/c/lib.c @@ -77,7 +77,7 @@ SZ_INTERNAL sz_capability_t sz_capabilities_arm(void) { SZ_DYNAMIC sz_capability_t sz_capabilities(void) { -#if SZ_USE_X86_AVX512 || SZ_USE_X86_AVX2 +#if SZ_USE_HASWELL || SZ_USE_SKYLAKE || SZ_USE_ICE /// The states of 4 registers populated for a specific "cpuid" assembly call union four_registers_t { @@ -131,7 +131,7 @@ SZ_DYNAMIC sz_capability_t sz_capabilities(void) { #endif // SZ_TARGET_X86 -#if SZ_USE_ARM_NEON || SZ_USE_ARM_SVE +#if SZ_USE_NEON || SZ_USE_SVE return sz_capabilities_arm(); @@ -196,7 +196,7 @@ static void sz_dispatch_table_init(void) { impl->alignment_score = sz_alignment_score_serial; impl->hashes = sz_hashes_serial; -#if SZ_USE_X86_AVX2 +#if SZ_USE_HASWELL if (caps & sz_cap_x86_avx2_k) { impl->equal = sz_equal_avx2; impl->order = sz_order_avx2; @@ -216,34 +216,36 @@ static void sz_dispatch_table_init(void) { } #endif -#if SZ_USE_X86_AVX512 +#if SZ_USE_SKYLAKE if (caps & 
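The same renaming has to stay consistent across the CMake, Rust, and Python build scripts, which is easy to break silently. One way to confirm what a given target actually ended up with is a throwaway diagnostic like the sketch below; it is not part of the repository and assumes the umbrella `stringzilla/stringzilla.h` header defines every `SZ_USE_*` macro to 0 or 1 after applying its defaults.

#include <stdio.h>
#include <stringzilla/stringzilla.h>

int main(void) {
    /* Each macro is expected to expand to 0 or 1 once the header has applied its defaults. */
    printf("haswell=%d skylake=%d ice=%d neon=%d sve=%d\n",
           (int)(SZ_USE_HASWELL), (int)(SZ_USE_SKYLAKE), (int)(SZ_USE_ICE),
           (int)(SZ_USE_NEON), (int)(SZ_USE_SVE));
    return 0;
}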
sz_cap_x86_avx512f_k) { - impl->equal = sz_equal_avx512; + impl->equal = sz_equal_skylake; impl->order = sz_order_avx512; impl->copy = sz_copy_avx512; impl->move = sz_move_avx512; impl->fill = sz_fill_avx512; - impl->find = sz_find_avx512; - impl->rfind = sz_rfind_avx512; + impl->find = sz_find_skylake; + impl->rfind = sz_rfind_skylake; impl->find_byte = sz_find_byte_avx512; impl->rfind_byte = sz_rfind_byte_avx512; impl->edit_distance = sz_edit_distance_avx512; } +#endif +#if SZ_USE_ICE if ((caps & sz_cap_x86_avx512f_k) && (caps & sz_cap_x86_avx512vl_k) && (caps & sz_cap_x86_avx512vbmi2_k) && (caps & sz_cap_x86_avx512bw_k) && (caps & sz_cap_x86_avx512vbmi_k)) { - impl->find_from_set = sz_find_charset_avx512; - impl->rfind_from_set = sz_rfind_charset_avx512; + impl->find_from_set = sz_find_charset_ice; + impl->rfind_from_set = sz_rfind_charset_ice; impl->alignment_score = sz_alignment_score_avx512; - impl->look_up_transform = sz_look_up_transform_avx512; + impl->look_up_transform = sz_look_up_transform_ice; impl->checksum = sz_checksum_avx512; } #endif -#if SZ_USE_ARM_NEON +#if SZ_USE_NEON if (caps & sz_cap_arm_neon_k) { impl->equal = sz_equal_neon; @@ -361,14 +363,16 @@ SZ_DYNAMIC sz_size_t sz_edit_distance_utf8( // return _sz_edit_distance_wagner_fisher_serial(a, a_length, b, b_length, bound, sz_true_k, alloc); } -SZ_DYNAMIC sz_ssize_t sz_alignment_score(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, - sz_error_cost_t const *subs, sz_error_cost_t gap, - sz_memory_allocator_t *alloc) { +SZ_DYNAMIC sz_ssize_t sz_alignment_score( // + sz_cptr_t a, sz_size_t a_length, // + sz_cptr_t b, sz_size_t b_length, // + sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc) { return sz_dispatch_table.alignment_score(a, a_length, b, b_length, subs, gap, alloc); } -SZ_DYNAMIC void sz_hashes(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle) { +SZ_DYNAMIC void sz_hashes( // + sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t step, // + sz_hash_callback_t callback, void *callback_handle) { sz_dispatch_table.hashes(text, length, window_length, step, callback, callback_handle); } @@ -409,8 +413,9 @@ sz_u64_t _sz_random_generator(void *empty_state) { } #endif -SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length, - sz_random_generator_t generator, void *generator_user_data) { +SZ_DYNAMIC void sz_generate( // + sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length, + sz_random_generator_t generator, void *generator_user_data) { #if !SZ_AVOID_LIBC if (!generator) generator = _sz_random_generator; #endif diff --git a/setup.py b/setup.py index 25b769e8..27ef6be2 100644 --- a/setup.py +++ b/setup.py @@ -54,10 +54,11 @@ def linux_settings() -> Tuple[List[str], List[str], List[Tuple[str]]]: # GCC is our primary compiler, so when packaging the library, even if the current machine # doesn't support AVX-512 or SVE, still precompile those. 
macros_args = [ - ("SZ_USE_X86_AVX512", "1" if is_64bit_x86() else "0"), - ("SZ_USE_X86_AVX2", "1" if is_64bit_x86() else "0"), - ("SZ_USE_ARM_SVE", "1" if is_64bit_arm() else "0"), - ("SZ_USE_ARM_NEON", "1" if is_64bit_arm() else "0"), + ("SZ_USE_HASWELL", "1" if is_64bit_x86() else "0"), + ("SZ_USE_SKYLAKE", "1" if is_64bit_x86() else "0"), + ("SZ_USE_ICE", "1" if is_64bit_x86() else "0"), + ("SZ_USE_NEON", "1" if is_64bit_arm() else "0"), + ("SZ_USE_SVE", "1" if is_64bit_arm() else "0"), ("SZ_DETECT_BIG_ENDIAN", "1" if is_big_endian() else "0"), ] @@ -89,10 +90,11 @@ def darwin_settings() -> Tuple[List[str], List[str], List[Tuple[str]]]: # During Universal builds, however, even AVX header cause compilation errors. can_use_avx2 = is_64bit_x86() and sysconfig.get_platform().startswith("universal") macros_args = [ - ("SZ_USE_X86_AVX512", "0"), - ("SZ_USE_X86_AVX2", "1" if can_use_avx2 else "0"), - ("SZ_USE_ARM_SVE", "0"), - ("SZ_USE_ARM_NEON", "1" if is_64bit_arm() else "0"), + ("SZ_USE_HASWELL", "1" if can_use_avx2 else "0"), + ("SZ_USE_SKYLAKE", "0"), + ("SZ_USE_ICE", "0"), + ("SZ_USE_NEON", "1" if is_64bit_arm() else "0"), + ("SZ_USE_SVE", "0"), ] return compile_args, link_args, macros_args @@ -107,10 +109,11 @@ def windows_settings() -> Tuple[List[str], List[str], List[Tuple[str]]]: # When packaging the library, even if the current machine doesn't support AVX-512 or SVE, still precompile those. macros_args = [ - ("SZ_USE_X86_AVX512", "1" if is_64bit_x86() else "0"), - ("SZ_USE_X86_AVX2", "1" if is_64bit_x86() else "0"), - ("SZ_USE_ARM_SVE", "0"), - ("SZ_USE_ARM_NEON", "1" if is_64bit_arm() else "0"), + ("SZ_USE_HASWELL", "1" if is_64bit_x86() else "0"), + ("SZ_USE_SKYLAKE", "1" if is_64bit_x86() else "0"), + ("SZ_USE_ICE", "1" if is_64bit_x86() else "0"), + ("SZ_USE_NEON", "1" if is_64bit_arm() else "0"), + ("SZ_USE_SVE", "0"), ("SZ_DETECT_BIG_ENDIAN", "1" if is_big_endian() else "0"), ] From 5b55e19d1378c61da88309b30a38f9cf7c64bf79 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 18:37:32 +0000 Subject: [PATCH 038/751] Fix: Filter `small_string.h` file --- include/stringzilla/small_string.h | 7148 +--------------------------- 1 file changed, 216 insertions(+), 6932 deletions(-) diff --git a/include/stringzilla/small_string.h b/include/stringzilla/small_string.h index de7fbcac..17625700 100644 --- a/include/stringzilla/small_string.h +++ b/include/stringzilla/small_string.h @@ -1,365 +1,47 @@ /** - * @brief StringZilla is a collection of advanced string algorithms, designed to be used in Big Data applications. - * It is generally faster than LibC, and has a broader & cleaner interface, and targets modern x86 CPUs - * with AVX-512 and Arm NEON and older CPUs with SWAR and auto-vectorization. - * - * Consider overriding the following macros to customize the library: - * - * - `SZ_DEBUG=0` - whether to enable debug assertions and logging. - * - `SZ_DYNAMIC_DISPATCH=0` - whether to use runtime dispatching of the most advanced SIMD backend. - * - `SZ_USE_MISALIGNED_LOADS=0` - whether to use misaligned loads on platforms that support them. - * - `SZ_SWAR_THRESHOLD=24` - threshold for switching to SWAR backend over serial byte-level for-loops. - * - `SZ_USE_X86_AVX512=?` - whether to use AVX-512 instructions on x86_64. - * - `SZ_USE_X86_AVX2=?` - whether to use AVX2 instructions on x86_64. - * - `SZ_USE_ARM_NEON=?` - whether to use NEON instructions on ARM. - * - `SZ_USE_ARM_SVE=?` - whether to use SVE instructions on ARM. 
- * - * @see StringZilla: https://github.com/ashvardanian/StringZilla/blob/main/README.md - * @see LibC String: https://pubs.opengroup.org/onlinepubs/009695399/basedefs/string.h.html - * - * @file stringzilla.h + * @brief Small String Optimization implemented as a C 99 structure. + * @file small_string.h * @author Ash Vardanian - */ -#ifndef STRINGZILLA_H_ -#define STRINGZILLA_H_ - -#define STRINGZILLA_VERSION_MAJOR 3 -#define STRINGZILLA_VERSION_MINOR 11 -#define STRINGZILLA_VERSION_PATCH 0 - -/** - * @brief When set to 1, the library will include the following LibC headers: and . - * In debug builds (SZ_DEBUG=1), the library will also include and . * - * You may want to disable this compiling for use in the kernel, or in embedded systems. - * You may also avoid them, if you are very sensitive to compilation time and avoid pre-compiled headers. - * https://artificial-mind.net/projects/compile-health/ - */ -#ifndef SZ_AVOID_LIBC -#define SZ_AVOID_LIBC (0) // true or false -#endif - -/** - * @brief A misaligned load can be - trying to fetch eight consecutive bytes from an address - * that is not divisible by eight. On x86 enabled by default. On ARM it's not. - * - * Most platforms support it, but there is no industry standard way to check for those. - * This value will mostly affect the performance of the serial (SWAR) backend. - */ -#ifndef SZ_USE_MISALIGNED_LOADS -#if defined(__x86_64__) || defined(_M_X64) || defined(__i386__) || defined(_M_IX86) -#define SZ_USE_MISALIGNED_LOADS (1) // true or false -#else -#define SZ_USE_MISALIGNED_LOADS (0) // true or false -#endif -#endif - -/** - * @brief Removes compile-time dispatching, and replaces it with runtime dispatching. - * So the `sz_find` function will invoke the most advanced backend supported by the CPU, - * that runs the program, rather than the most advanced backend supported by the CPU - * used to compile the library or the downstream application. - */ -#ifndef SZ_DYNAMIC_DISPATCH -#define SZ_DYNAMIC_DISPATCH (0) // true or false -#endif - -/** - * @brief Analogous to `size_t` and `std::size_t`, unsigned integer, identical to pointer size. - * 64-bit on most platforms where pointers are 64-bit. - * 32-bit on platforms where pointers are 32-bit. - */ -#if defined(__LP64__) || defined(_LP64) || defined(__x86_64__) || defined(_WIN64) -#define SZ_DETECT_64_BIT (1) -#define SZ_SIZE_MAX (0xFFFFFFFFFFFFFFFFull) // Largest unsigned integer that fits into 64 bits. -#define SZ_SSIZE_MAX (0x7FFFFFFFFFFFFFFFull) // Largest signed integer that fits into 64 bits. -#else -#define SZ_DETECT_64_BIT (0) -#define SZ_SIZE_MAX (0xFFFFFFFFu) // Largest unsigned integer that fits into 32 bits. -#define SZ_SSIZE_MAX (0x7FFFFFFFu) // Largest signed integer that fits into 32 bits. -#endif - -/** - * @brief On Big-Endian machines StringZilla will work in compatibility mode. - * This disables SWAR hacks to minimize code duplication, assuming practically - * all modern popular platforms are Little-Endian. + * Includes core APIs: + * - `sz_string_init` + * - `sz_string_init_length` + * - `sz_string_free` * - * This variable is hard to infer from macros reliably. It's best to set it manually. - * For that CMake provides the `TestBigEndian` and `CMAKE__BYTE_ORDER` (from 3.20 onwards). - * In Python one can check `sys.byteorder == 'big'` in the `setup.py` script and pass the appropriate macro. 
- * https://stackoverflow.com/a/27054190 - */ -#ifndef SZ_DETECT_BIG_ENDIAN -#if defined(__BYTE_ORDER) && __BYTE_ORDER == __BIG_ENDIAN || defined(__BIG_ENDIAN__) || defined(__ARMEB__) || \ - defined(__THUMBEB__) || defined(__AARCH64EB__) || defined(_MIBSEB) || defined(__MIBSEB) || defined(__MIBSEB__) -#define SZ_DETECT_BIG_ENDIAN (1) //< It's a big-endian target architecture -#else -#define SZ_DETECT_BIG_ENDIAN (0) //< It's a little-endian target architecture -#endif -#endif - -/* - * Debugging and testing. - */ -#ifndef SZ_DEBUG -#if defined(DEBUG) || defined(_DEBUG) // This means "Not using DEBUG information". -#define SZ_DEBUG (1) -#else -#define SZ_DEBUG (0) -#endif -#endif - -/** - * @brief Threshold for switching to SWAR (8-bytes at a time) backend over serial byte-level for-loops. - * On very short strings, under 16 bytes long, at most a single word will be processed with SWAR. - * Assuming potentially misaligned loads, SWAR makes sense only after ~24 bytes. - */ -#ifndef SZ_SWAR_THRESHOLD -#if SZ_DEBUG -#define SZ_SWAR_THRESHOLD (8u) // 8 bytes in debug builds -#else -#define SZ_SWAR_THRESHOLD (24u) // 24 bytes in release builds -#endif -#endif - -/* Annotation for the public API symbols: + * Accessing the underlying string: + * - `sz_string_is_on_stack` + * - `sz_string_unpack` + * - `sz_string_range` + * - `sz_string_equal` + * - `sz_string_order` * - * - `SZ_PUBLIC` is used for functions that are part of the public API. - * - `SZ_INTERNAL` is used for internal helper functions with unstable APIs. - * - `SZ_DYNAMIC` is used for functions that are part of the public API, but are dispatched at runtime. + * Modifying the string: + * - `sz_string_reserve` + * - `sz_string_expand` + * - `sz_string_erase` + * - `sz_string_shrink_to_fit` */ -#ifndef SZ_DYNAMIC -#if SZ_DYNAMIC_DISPATCH -#if defined(_WIN32) || defined(__CYGWIN__) -#define SZ_DYNAMIC __declspec(dllexport) -#define SZ_EXTERNAL __declspec(dllimport) -#define SZ_PUBLIC inline static -#define SZ_INTERNAL inline static -#else -#define SZ_DYNAMIC __attribute__((visibility("default"))) -#define SZ_EXTERNAL extern -#define SZ_PUBLIC __attribute__((unused)) inline static -#define SZ_INTERNAL __attribute__((always_inline)) inline static -#endif // _WIN32 || __CYGWIN__ -#else -#define SZ_DYNAMIC inline static -#define SZ_EXTERNAL extern -#define SZ_PUBLIC inline static -#define SZ_INTERNAL inline static -#endif // SZ_DYNAMIC_DISPATCH -#endif // SZ_DYNAMIC +#ifndef STRINGZILLA_SMALL_STRING_H_ +#define STRINGZILLA_SMALL_STRING_H_ -/** - * @brief Alignment macro for 64-byte alignment. - */ -#if defined(_MSC_VER) -#define SZ_ALIGN64 __declspec(align(64)) -#elif defined(__GNUC__) || defined(__clang__) -#define SZ_ALIGN64 __attribute__((aligned(64))) -#else -#define SZ_ALIGN64 -#endif +#include "find.h" // `sz_equal` +#include "memory.h" // `sz_copy`, `sz_move`, `sz_fill` +#include "types.h" // `sz_size_t`, `sz_ptr_t`, `sz_cptr_t` #ifdef __cplusplus extern "C" { #endif -/* - * Let's infer the integer types or pull them from LibC, - * if that is allowed by the user. 
- */ -#if !SZ_AVOID_LIBC -#include // `size_t` -#include // `uint8_t` -typedef int8_t sz_i8_t; // Always 8 bits -typedef uint8_t sz_u8_t; // Always 8 bits -typedef uint16_t sz_u16_t; // Always 16 bits -typedef int32_t sz_i32_t; // Always 32 bits -typedef uint32_t sz_u32_t; // Always 32 bits -typedef uint64_t sz_u64_t; // Always 64 bits -typedef int64_t sz_i64_t; // Always 64 bits -typedef size_t sz_size_t; // Pointer-sized unsigned integer, 32 or 64 bits -typedef ptrdiff_t sz_ssize_t; // Signed version of `sz_size_t`, 32 or 64 bits - -#else // if SZ_AVOID_LIBC: - -// ! The C standard doesn't specify the signedness of char. -// ! On x86 char is signed by default while on Arm it is unsigned by default. -// ! That's why we don't define `sz_char_t` and generally use explicit `sz_i8_t` and `sz_u8_t`. -typedef signed char sz_i8_t; // Always 8 bits -typedef unsigned char sz_u8_t; // Always 8 bits -typedef unsigned short sz_u16_t; // Always 16 bits -typedef int sz_i32_t; // Always 32 bits -typedef unsigned int sz_u32_t; // Always 32 bits -typedef long long sz_i64_t; // Always 64 bits -typedef unsigned long long sz_u64_t; // Always 64 bits - -// Now we need to redefine the `size_t`. -// Microsoft Visual C++ (MSVC) typically follows LLP64 data model on 64-bit platforms, -// where integers, pointers, and long types have different sizes: -// -// > `int` is 32 bits -// > `long` is 32 bits -// > `long long` is 64 bits -// > pointer (thus, `size_t`) is 64 bits -// -// In contrast, GCC and Clang on 64-bit Unix-like systems typically follow the LP64 model, where: -// -// > `int` is 32 bits -// > `long` and pointer (thus, `size_t`) are 64 bits -// > `long long` is also 64 bits -// -// Source: https://learn.microsoft.com/en-us/windows/win32/winprog64/abstract-data-models -#if SZ_DETECT_64_BIT -typedef unsigned long long sz_size_t; // 64-bit. -typedef long long sz_ssize_t; // 64-bit. -#else -typedef unsigned sz_size_t; // 32-bit. -typedef unsigned sz_ssize_t; // 32-bit. -#endif // SZ_DETECT_64_BIT - -#endif // SZ_AVOID_LIBC - -/** - * @brief Compile-time assert macro similar to `static_assert` in C++. - */ -#define sz_static_assert(condition, name) \ - typedef struct { \ - int static_assert_##name : (condition) ? 1 : -1; \ - } sz_static_assert_##name##_t - -sz_static_assert(sizeof(sz_size_t) == sizeof(void *), sz_size_t_must_be_pointer_size); -sz_static_assert(sizeof(sz_ssize_t) == sizeof(void *), sz_ssize_t_must_be_pointer_size); - -#pragma region Public API - -typedef char *sz_ptr_t; // A type alias for `char *` -typedef char const *sz_cptr_t; // A type alias for `char const *` -typedef sz_i8_t sz_error_cost_t; // Character mismatch cost for fuzzy matching functions - -typedef sz_u64_t sz_sorted_idx_t; // Index of a sorted string in a list of strings - -typedef enum { sz_false_k = 0, sz_true_k = 1 } sz_bool_t; // Only one relevant bit -typedef enum { sz_less_k = -1, sz_equal_k = 0, sz_greater_k = 1 } sz_ordering_t; // Only three possible states: <=> - -/** - * @brief Tiny string-view structure. It's POD type, unlike the `std::string_view`. - */ -typedef struct sz_string_view_t { - sz_cptr_t start; - sz_size_t length; -} sz_string_view_t; - -/** - * @brief Enumeration of SIMD capabilities of the target architecture. - * Used to introspect the supported functionality of the dynamic library. 
- */ -typedef enum sz_capability_t { - sz_cap_serial_k = 1, /// Serial (non-SIMD) capability - sz_cap_any_k = 0x7FFFFFFF, /// Mask representing any capability - - sz_cap_arm_neon_k = 1 << 10, /// ARM NEON capability - sz_cap_arm_sve_k = 1 << 11, /// ARM SVE capability TODO: Not yet supported or used - sz_cap_arm_sve2_k = 1 << 12, - sz_cap_arm_sve2p1_k = 1 << 13, - sz_cap_x86_avx2_k = 1 << 20, /// x86 AVX2 capability - sz_cap_x86_avx512f_k = 1 << 21, /// x86 AVX512 F capability - sz_cap_x86_avx512bw_k = 1 << 22, /// x86 AVX512 BW instruction capability - sz_cap_x86_avx512vl_k = 1 << 23, /// x86 AVX512 VL instruction capability - sz_cap_x86_avx512vbmi_k = 1 << 24, /// x86 AVX512 VBMI instruction capability - sz_cap_x86_gfni_k = 1 << 25, /// x86 AVX512 GFNI instruction capability - -} sz_capability_t; - -/** - * @brief Function to determine the SIMD capabilities of the current machine @b only at @b runtime. - * @return A bitmask of the SIMD capabilities represented as a `sz_capability_t` enum value. - */ -SZ_DYNAMIC sz_capability_t sz_capabilities(void); - -/** - * @brief Bit-set structure for 256 possible byte values. Useful for filtering and search. - * @see sz_charset_init, sz_charset_add, sz_charset_contains, sz_charset_invert - */ -typedef union sz_charset_t { - sz_u64_t _u64s[4]; - sz_u32_t _u32s[8]; - sz_u16_t _u16s[16]; - sz_u8_t _u8s[32]; -} sz_charset_t; - -/** @brief Initializes a bit-set to an empty collection, meaning - all characters are banned. */ -SZ_PUBLIC void sz_charset_init(sz_charset_t *s) { s->_u64s[0] = s->_u64s[1] = s->_u64s[2] = s->_u64s[3] = 0; } - -/** @brief Adds a character to the set and accepts @b unsigned integers. */ -SZ_PUBLIC void sz_charset_add_u8(sz_charset_t *s, sz_u8_t c) { s->_u64s[c >> 6] |= (1ull << (c & 63u)); } - -/** @brief Adds a character to the set. Consider @b sz_charset_add_u8. */ -SZ_PUBLIC void sz_charset_add(sz_charset_t *s, char c) { sz_charset_add_u8(s, *(sz_u8_t *)(&c)); } // bitcast - -/** @brief Checks if the set contains a given character and accepts @b unsigned integers. */ -SZ_PUBLIC sz_bool_t sz_charset_contains_u8(sz_charset_t const *s, sz_u8_t c) { - // Checking the bit can be done in different ways: - // - (s->_u64s[c >> 6] & (1ull << (c & 63u))) != 0 - // - (s->_u32s[c >> 5] & (1u << (c & 31u))) != 0 - // - (s->_u16s[c >> 4] & (1u << (c & 15u))) != 0 - // - (s->_u8s[c >> 3] & (1u << (c & 7u))) != 0 - return (sz_bool_t)((s->_u64s[c >> 6] & (1ull << (c & 63u))) != 0); -} - -/** @brief Checks if the set contains a given character. Consider @b sz_charset_contains_u8. */ -SZ_PUBLIC sz_bool_t sz_charset_contains(sz_charset_t const *s, char c) { - return sz_charset_contains_u8(s, *(sz_u8_t *)(&c)); // bitcast -} - -/** @brief Inverts the contents of the set, so allowed character get disallowed, and vice versa. */ -SZ_PUBLIC void sz_charset_invert(sz_charset_t *s) { - s->_u64s[0] ^= 0xFFFFFFFFFFFFFFFFull, s->_u64s[1] ^= 0xFFFFFFFFFFFFFFFFull, // - s->_u64s[2] ^= 0xFFFFFFFFFFFFFFFFull, s->_u64s[3] ^= 0xFFFFFFFFFFFFFFFFull; -} - -typedef void *(*sz_memory_allocate_t)(sz_size_t, void *); -typedef void (*sz_memory_free_t)(void *, sz_size_t, void *); -typedef sz_u64_t (*sz_random_generator_t)(void *); - -/** - * @brief Some complex pattern matching algorithms may require memory allocations. - * This structure is used to pass the memory allocator to those functions. 
- * @see sz_memory_allocator_init_fixed - */ -typedef struct sz_memory_allocator_t { - sz_memory_allocate_t allocate; - sz_memory_free_t free; - void *handle; -} sz_memory_allocator_t; - -/** - * @brief Initializes a memory allocator to use the system default `malloc` and `free`. - * ! The function is not available if the library was compiled with `SZ_AVOID_LIBC`. - * - * @param alloc Memory allocator to initialize. - */ -SZ_PUBLIC void sz_memory_allocator_init_default(sz_memory_allocator_t *alloc); - -/** - * @brief Initializes a memory allocator to use a static-capacity buffer. - * No dynamic allocations will be performed. - * - * @param alloc Memory allocator to initialize. - * @param buffer Buffer to use for allocations. - * @param length Length of the buffer. @b Must be greater than 8 bytes. Different values would be optimal for - * different algorithms and input lengths, but 4096 bytes (one RAM page) is a good default. - */ -SZ_PUBLIC void sz_memory_allocator_init_fixed(sz_memory_allocator_t *alloc, void *buffer, sz_size_t length); +#pragma region Core Structure /** * @brief The number of bytes a stack-allocated string can hold, including the SZ_NULL termination character. * ! This can't be changed from outside. Don't use the `#error` as it may already be included and set. */ -#ifdef SZ_STRING_INTERNAL_SPACE -#undef SZ_STRING_INTERNAL_SPACE +#ifdef _SZ_STRING_INTERNAL_SPACE +#undef _SZ_STRING_INTERNAL_SPACE #endif -#define SZ_STRING_INTERNAL_SPACE (sizeof(sz_size_t) * 3 - 1) // 3 pointers minus one byte for an 8-bit length +#define _SZ_STRING_INTERNAL_SPACE (sizeof(sz_size_t) * 3 - 1) // 3 pointers minus one byte for an 8-bit length /** * @brief Tiny memory-owning string structure with a Small String Optimization (SSO). @@ -376,7 +58,7 @@ SZ_PUBLIC void sz_memory_allocator_init_fixed(sz_memory_allocator_t *alloc, void */ typedef union sz_string_t { -#if !SZ_DETECT_BIG_ENDIAN +#if !_SZ_IS_BIG_ENDIAN struct external { sz_ptr_t start; @@ -388,7 +70,7 @@ typedef union sz_string_t { struct internal { sz_ptr_t start; sz_u8_t length; - char chars[SZ_STRING_INTERNAL_SPACE]; + char chars[_SZ_STRING_INTERNAL_SPACE]; } internal; #else @@ -402,7 +84,7 @@ typedef union sz_string_t { struct internal { sz_ptr_t start; - char chars[SZ_STRING_INTERNAL_SPACE]; + char chars[_SZ_STRING_INTERNAL_SPACE]; sz_u8_t length; } internal; @@ -412,206 +94,9 @@ typedef union sz_string_t { } sz_string_t; -typedef sz_u64_t (*sz_hash_t)(sz_cptr_t, sz_size_t); -typedef sz_u64_t (*sz_checksum_t)(sz_cptr_t, sz_size_t); -typedef sz_bool_t (*sz_equal_t)(sz_cptr_t, sz_cptr_t, sz_size_t); -typedef sz_ordering_t (*sz_order_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); -typedef void (*sz_to_converter_t)(sz_cptr_t, sz_size_t, sz_ptr_t); - -/** - * @brief Computes the 64-bit check-sum of bytes in a string. - * Similar to `std::ranges::accumulate`. - * - * @param text String to aggregate. - * @param length Number of bytes in the text. - * @return 64-bit unsigned value. - */ -SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length); - -/** @copydoc sz_checksum */ -SZ_PUBLIC sz_u64_t sz_checksum_serial(sz_cptr_t text, sz_size_t length); - -/** - * @brief Computes the 64-bit unsigned hash of a string. Fairly fast for short strings, - * simple implementation, and supports rolling computation, reused in other APIs. - * Similar to `std::hash` in C++. - * - * @param text String to hash. - * @param length Number of bytes in the text. - * @return 64-bit hash value. 
- * - * @see sz_hashes, sz_hashes_fingerprint, sz_hashes_intersection - */ -SZ_PUBLIC sz_u64_t sz_hash(sz_cptr_t text, sz_size_t length); - -/** @copydoc sz_hash */ -SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t text, sz_size_t length); - -/** - * @brief Checks if two string are equal. - * Similar to `memcmp(a, b, length) == 0` in LibC and `a == b` in STL. - * - * The implementation of this function is very similar to `sz_order`, but the usage patterns are different. - * This function is more often used in parsing, while `sz_order` is often used in sorting. - * It works best on platforms with cheap - * - * @param a First string to compare. - * @param b Second string to compare. - * @param length Number of bytes in both strings. - * @return 1 if strings match, 0 otherwise. - */ -SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length); - -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_serial(sz_cptr_t a, sz_cptr_t b, sz_size_t length); - -/** - * @brief Estimates the relative order of two strings. Equivalent to `memcmp(a, b, length)` in LibC. - * Can be used on different length strings. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * @return Negative if (a < b), positive if (a > b), zero if they are equal. - */ -SZ_DYNAMIC sz_ordering_t sz_order(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); - -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); - -/** - * @brief Look Up Table @b (LUT) transformation of a string. Equivalent to `for (char & c : text) c = lut[c]`. - * - * Can be used to implement some form of string normalization, partially masking punctuation marks, - * or converting between different character sets, like uppercase or lowercase. Surprisingly, also has - * broad implications in image processing, where image channel transformations are often done using LUTs. - * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param lut Look Up Table to apply. Must be exactly @b 256 bytes long. - * @param result Output string, can point to the same address as ::text. - */ -SZ_DYNAMIC void sz_look_up_transform(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result); - -typedef void (*sz_look_up_transform_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_ptr_t); - -/** @copydoc sz_look_up_transform */ -SZ_PUBLIC void sz_look_up_transform_serial(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result); +#pragma endregion // Core Structure -/** - * @brief Equivalent to `for (char & c : text) c = tolower(c)`. - * - * ASCII characters [A, Z] map to decimals [65, 90], and [a, z] map to [97, 122]. - * So there are 26 english letters, shifted by 32 values, meaning that a conversion - * can be done by flipping the 5th bit each inappropriate character byte. This, however, - * breaks for extended ASCII, so a different solution is needed. - * http://0x80.pl/notesen/2016-01-06-swar-swap-case.html - * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param result Output string, can point to the same address as ::text. - */ -SZ_PUBLIC void sz_tolower(sz_cptr_t text, sz_size_t length, sz_ptr_t result); - -/** - * @brief Equivalent to `for (char & c : text) c = toupper(c)`. 
- * - * ASCII characters [A, Z] map to decimals [65, 90], and [a, z] map to [97, 122]. - * So there are 26 english letters, shifted by 32 values, meaning that a conversion - * can be done by flipping the 5th bit each inappropriate character byte. This, however, - * breaks for extended ASCII, so a different solution is needed. - * http://0x80.pl/notesen/2016-01-06-swar-swap-case.html - * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param result Output string, can point to the same address as ::text. - */ -SZ_PUBLIC void sz_toupper(sz_cptr_t text, sz_size_t length, sz_ptr_t result); - -/** - * @brief Equivalent to `for (char & c : text) c = toascii(c)`. - * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param result Output string, can point to the same address as ::text. - */ -SZ_PUBLIC void sz_toascii(sz_cptr_t text, sz_size_t length, sz_ptr_t result); - -/** - * @brief Checks if all characters in the range are valid ASCII characters. - * - * @param text String to be analyzed. - * @param length Number of bytes in the string. - * @return Whether all characters are valid ASCII characters. - */ -SZ_PUBLIC sz_bool_t sz_isascii(sz_cptr_t text, sz_size_t length); - -/** - * @brief Generates a random string for a given alphabet, avoiding integer division and modulo operations. - * Similar to `text[i] = alphabet[rand() % cardinality]`. - * - * The modulo operation is expensive, and should be avoided in performance-critical code. - * We avoid it using small lookup tables and replacing it with a multiplication and shifts, similar to `libdivide`. - * Alternative algorithms would include: - * - Montgomery form: https://en.algorithmica.org/hpc/number-theory/montgomery/ - * - Barret reduction: https://www.nayuki.io/page/barrett-reduction-algorithm - * - Lemire's trick: https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/ - * - * @param alphabet Set of characters to sample from. - * @param cardinality Number of characters to sample from. - * @param text Output string, can point to the same address as ::text. - * @param generate Callback producing random numbers given the generator state. - * @param generator Generator state, can be a pointer to a seed, or a pointer to a random number generator. - */ -SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length, - sz_random_generator_t generate, void *generator); - -/** @copydoc sz_generate */ -SZ_PUBLIC void sz_generate_serial(sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length, - sz_random_generator_t generate, void *generator); - -/** - * @brief Similar to `memcpy`, copies contents of one string into another. - * The behavior is undefined if the strings overlap. - * - * @param target String to copy into. - * @param length Number of bytes to copy. - * @param source String to copy from. - */ -SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length); - -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length); - -/** - * @brief Similar to `memmove`, copies (moves) contents of one string into another. - * Unlike `sz_copy`, allows overlapping strings as arguments. - * - * @param target String to copy into. - * @param length Number of bytes to copy. - * @param source String to copy from. 
- */ -SZ_DYNAMIC void sz_move(sz_ptr_t target, sz_cptr_t source, sz_size_t length); - -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length); - -typedef void (*sz_move_t)(sz_ptr_t, sz_cptr_t, sz_size_t); - -/** - * @brief Similar to `memset`, fills a string with a given value. - * - * @param target String to fill. - * @param length Number of bytes to fill. - * @param value Value to fill with. - */ -SZ_DYNAMIC void sz_fill(sz_ptr_t target, sz_size_t length, sz_u8_t value); - -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_serial(sz_ptr_t target, sz_size_t length, sz_u8_t value); - -typedef void (*sz_fill_t)(sz_ptr_t, sz_size_t, sz_u8_t); +#pragma region Core API /** * @brief Initializes a string class instance to an empty value. @@ -634,8 +119,8 @@ SZ_PUBLIC sz_bool_t sz_string_is_on_stack(sz_string_t const *string); * @param space Number of bytes allocated for the string (heap or stack), including the SZ_NULL character. * @param is_external Whether the string is allocated on the heap externally, or fits withing ::string instance. */ -SZ_PUBLIC void sz_string_unpack(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length, sz_size_t *space, - sz_bool_t *is_external); +SZ_PUBLIC void sz_string_unpack( // + sz_string_t const *string, sz_ptr_t *start, sz_size_t *length, sz_size_t *space, sz_bool_t *is_external); /** * @brief Unpacks only the start and length of the string. @@ -681,8 +166,8 @@ SZ_PUBLIC sz_ptr_t sz_string_reserve(sz_string_t *string, sz_size_t new_capacity * @param allocator Memory allocator to use for the allocation. * @return SZ_NULL if the operation failed, pointer to the new start of the string otherwise. */ -SZ_PUBLIC sz_ptr_t sz_string_expand(sz_string_t *string, sz_size_t offset, sz_size_t added_length, - sz_memory_allocator_t *allocator); +SZ_PUBLIC sz_ptr_t sz_string_expand( // + sz_string_t *string, sz_size_t offset, sz_size_t added_length, sz_memory_allocator_t *allocator); /** * @brief Removes a range from a string. Changes the length, but not the capacity. @@ -714,6443 +199,242 @@ SZ_PUBLIC void sz_string_free(sz_string_t *string, sz_memory_allocator_t *alloca #pragma endregion -#pragma region Fast Substring Search API - -typedef sz_cptr_t (*sz_find_byte_t)(sz_cptr_t, sz_size_t, sz_cptr_t); -typedef sz_cptr_t (*sz_find_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); -typedef sz_cptr_t (*sz_find_set_t)(sz_cptr_t, sz_size_t, sz_charset_t const *); - -/** - * @brief Locates first matching byte in a string. Equivalent to `memchr(haystack, *needle, h_length)` in LibC. - * - * X86_64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/x86_64/memchr.S - * Aarch64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/aarch64/memchr.S - * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - single-byte substring to find. - * @return Address of the first match. - */ -SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); - -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); - -/** - * @brief Locates last matching byte in a string. Equivalent to `memrchr(haystack, *needle, h_length)` in LibC. - * - * X86_64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/x86_64/memrchr.S - * Aarch64 implementation: missing - * - * @param haystack Haystack - the string to search in. 
- * @param h_length Number of bytes in the haystack. - * @param needle Needle - single-byte substring to find. - * @return Address of the last match. - */ -SZ_DYNAMIC sz_cptr_t sz_rfind_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); - -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +#pragma region Serial Implementation -/** - * @brief Locates first matching substring. - * Equivalent to `memmem(haystack, h_length, needle, n_length)` in LibC. - * Similar to `strstr(haystack, needle)` in LibC, but requires known length. - * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - substring to find. - * @param n_length Number of bytes in the needle. - * @return Address of the first match. - */ -SZ_DYNAMIC sz_cptr_t sz_find(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); +SZ_PUBLIC sz_bool_t sz_string_is_on_stack(sz_string_t const *string) { + // It doesn't matter if it's on stack or heap, the pointer location is the same. + return (sz_bool_t)((sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0]); +} -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); +SZ_PUBLIC void sz_string_range(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length) { + sz_size_t is_small = (sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0]; + sz_size_t is_big_mask = is_small - 1ull; + *start = string->external.start; // It doesn't matter if it's on stack or heap, the pointer location is the same. + // If the string is small, use branch-less approach to mask-out the top 7 bytes of the length. + *length = string->external.length & (0x00000000000000FFull | is_big_mask); +} -/** - * @brief Locates the last matching substring. - * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - substring to find. - * @param n_length Number of bytes in the needle. - * @return Address of the last match. - */ -SZ_DYNAMIC sz_cptr_t sz_rfind(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); +SZ_PUBLIC void sz_string_unpack( // + sz_string_t const *string, sz_ptr_t *start, sz_size_t *length, sz_size_t *space, sz_bool_t *is_external) { + sz_size_t is_small = (sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0]; + sz_size_t is_big_mask = is_small - 1ull; + *start = string->external.start; // It doesn't matter if it's on stack or heap, the pointer location is the same. + // If the string is small, use branch-less approach to mask-out the top 7 bytes of the length. + *length = string->external.length & (0x00000000000000FFull | is_big_mask); + // In case the string is small, the `is_small - 1ull` will become 0xFFFFFFFFFFFFFFFFull. + *space = sz_u64_blend(_SZ_STRING_INTERNAL_SPACE, string->external.space, is_big_mask); + *is_external = (sz_bool_t)!is_small; +} -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); +SZ_PUBLIC sz_bool_t sz_string_equal(sz_string_t const *a, sz_string_t const *b) { + // Tempting to say that the external.length is bitwise the same even if it includes + // some bytes of the on-stack payload, but we don't at this writing maintain that invariant. 
+ // (An on-stack string includes noise bytes in the high-order bits of external.length. So do this + // the hard/correct way. -/** - * @brief Finds the first character present from the ::set, present in ::text. - * Equivalent to `strspn(text, accepted)` and `strcspn(text, rejected)` in LibC. - * May have identical implementation and performance to ::sz_rfind_charset. - * - * Useful for parsing, when we want to skip a set of characters. Examples: - * * 6 whitespaces: " \t\n\r\v\f". - * * 16 digits forming a float number: "0123456789,.eE+-". - * * 5 HTML reserved characters: "\"'&<>", of which "<>" can be useful for parsing. - * * 2 JSON string special characters useful to locate the end of the string: "\"\\". - * - * @param text String to be scanned. - * @param set Set of relevant characters. - * @return Pointer to the first matching character from ::set. - */ -SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); +#if SZ_USE_MISALIGNED_LOADS + // Dealing with StringZilla strings, we know that the `start` pointer always points + // to a word at least 8 bytes long. Therefore, we can compare the first 8 bytes at once. -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); +#endif + // Alternatively, fall back to byte-by-byte comparison. + sz_ptr_t a_start, b_start; + sz_size_t a_length, b_length; + sz_string_range(a, &a_start, &a_length); + sz_string_range(b, &b_start, &b_length); + return (sz_bool_t)(a_length == b_length && sz_equal(a_start, b_start, b_length)); +} -/** - * @brief Finds the last character present from the ::set, present in ::text. - * Equivalent to `strspn(text, accepted)` and `strcspn(text, rejected)` in LibC. - * May have identical implementation and performance to ::sz_find_charset. - * - * Useful for parsing, when we want to skip a set of characters. Examples: - * * 6 whitespaces: " \t\n\r\v\f". - * * 16 digits forming a float number: "0123456789,.eE+-". - * * 5 HTML reserved characters: "\"'&<>", of which "<>" can be useful for parsing. - * * 2 JSON string special characters useful to locate the end of the string: "\"\\". - * - * @param text String to be scanned. - * @param set Set of relevant characters. - * @return Pointer to the last matching character from ::set. - */ -SZ_DYNAMIC sz_cptr_t sz_rfind_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); +SZ_PUBLIC sz_ordering_t sz_string_order(sz_string_t const *a, sz_string_t const *b) { +#if SZ_USE_MISALIGNED_LOADS + // Dealing with StringZilla strings, we know that the `start` pointer always points + // to a word at least 8 bytes long. Therefore, we can compare the first 8 bytes at once. -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); +#endif + // Alternatively, fall back to byte-by-byte comparison. + sz_ptr_t a_start, b_start; + sz_size_t a_length, b_length; + sz_string_range(a, &a_start, &a_length); + sz_string_range(b, &b_start, &b_length); + return sz_order(a_start, a_length, b_start, b_length); +} -#pragma endregion +SZ_PUBLIC void sz_string_init(sz_string_t *string) { + sz_assert(string && "String can't be SZ_NULL."); -#pragma region String Similarity Measures API + // Only 8 + 1 + 1 need to be initialized. + string->internal.start = &string->internal.chars[0]; + // But for safety let's initialize the entire structure to zeros. 
+ // string->internal.chars[0] = 0; + // string->internal.length = 0; + string->words[1] = 0; + string->words[2] = 0; + string->words[3] = 0; +} -/** - * @brief Computes the Hamming distance between two strings - number of not matching characters. - * Difference in length is is counted as a mismatch. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * - * @param bound Upper bound on the distance, that allows us to exit early. - * If zero is passed, the maximum possible distance will be equal to the length of the longer input. - * @return Unsigned integer for the distance, the `bound` if was exceeded. - * - * @see sz_hamming_distance_utf8 - * @see https://en.wikipedia.org/wiki/Hamming_distance - */ -SZ_DYNAMIC sz_size_t sz_hamming_distance( // - sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound); - -/** @copydoc sz_hamming_distance */ -SZ_PUBLIC sz_size_t sz_hamming_distance_serial( // - sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound); - -/** - * @brief Computes the Hamming distance between two @b UTF8 strings - number of not matching characters. - * Difference in length is is counted as a mismatch. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * - * @param bound Upper bound on the distance, that allows us to exit early. - * If zero is passed, the maximum possible distance will be equal to the length of the longer input. - * @return Unsigned integer for the distance, the `bound` if was exceeded. - * - * @see sz_hamming_distance - * @see https://en.wikipedia.org/wiki/Hamming_distance - */ -SZ_DYNAMIC sz_size_t sz_hamming_distance_utf8(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, - sz_size_t bound); - -/** @copydoc sz_hamming_distance_utf8 */ -SZ_PUBLIC sz_size_t sz_hamming_distance_utf8_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, - sz_size_t bound); - -typedef sz_size_t (*sz_hamming_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t); - -/** - * @brief Computes the Levenshtein edit-distance between two strings using the Wagner-Fisher algorithm. - * Similar to the Needleman-Wunsch alignment algorithm. Often used in fuzzy string matching. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * - * @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated, - * so the memory usage is linear in relation to ::a_length and ::b_length. - * If SZ_NULL is passed, will initialize to the systems default `malloc`. - * @param bound Exclusive upper bound on the distance, that allows us to exit early. - * Pass `SZ_SIZE_MAX` or any value greater than `(max(a_length, b_length))` to ignore. - * Pass zero to check if the strings are equal. - * @return Unsigned integer for the edit distance. Zero means the strings are equal. - * Returns the `bound` if it was exceeded or `SZ_SIZE_MAX` if the memory allocation failed. 
- * - * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default - * @see https://en.wikipedia.org/wiki/Levenshtein_distance - */ -SZ_DYNAMIC sz_size_t sz_edit_distance(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); - -/** @copydoc sz_edit_distance */ -SZ_PUBLIC sz_size_t sz_edit_distance_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); - -/** - * @brief Computes the Levenshtein edit-distance between two @b UTF8 strings. - * Unlike `sz_edit_distance`, reports the distance in Unicode codepoints, and not in bytes. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * - * @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated, - * so the memory usage is linear in relation to ::a_length and ::b_length. - * If SZ_NULL is passed, will initialize to the systems default `malloc`. - * @param bound Upper bound on the distance, that allows us to exit early. - * If zero is passed, the maximum possible distance will be equal to the length of the longer input. - * @return Unsigned integer for edit distance, the `bound` if was exceeded or `SZ_SIZE_MAX` - * if the memory allocation failed. - * - * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default, sz_edit_distance - * @see https://en.wikipedia.org/wiki/Levenshtein_distance - */ -SZ_DYNAMIC sz_size_t sz_edit_distance_utf8(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); - -typedef sz_size_t (*sz_edit_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t, sz_memory_allocator_t *); - -/** @copydoc sz_edit_distance_utf8 */ -SZ_PUBLIC sz_size_t sz_edit_distance_utf8_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); - -/** - * @brief Computes Needleman–Wunsch alignment score for two string. Often used in bioinformatics and cheminformatics. - * Similar to the Levenshtein edit-distance, parameterized for gap and substitution penalties. - * - * Not commutative in the general case, as the order of the strings matters, as `sz_alignment_score(a, b)` may - * not be equal to `sz_alignment_score(b, a)`. Becomes @b commutative, if the substitution costs are symmetric. - * Equivalent to the negative Levenshtein distance, if: `gap == -1` and `subs[i][j] == (i == j ? 0: -1)`. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * @param gap Penalty cost for gaps - insertions and removals. - * @param subs Substitution costs matrix with 256 x 256 values for all pairs of characters. - * - * @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated, - * so the memory usage is linear in relation to ::a_length and ::b_length. - * If SZ_NULL is passed, will initialize to the systems default `malloc`. - * @return Signed similarity score. Can be negative, depending on the substitution costs. - * If the memory allocation fails, the function returns `SZ_SSIZE_MAX`. 
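The linear-memory claim above follows from the classic two-row formulation of Wagner-Fischer: only the previous and the current row of the DP matrix are kept. A standalone sketch of that recurrence, not sz_edit_distance itself:

    #include <stdio.h>
    #include <stdlib.h>

    static size_t levenshtein_two_rows(char const *a, size_t a_len, char const *b, size_t b_len) {
        size_t *previous = (size_t *)malloc((b_len + 1) * sizeof(size_t));
        size_t *current = (size_t *)malloc((b_len + 1) * sizeof(size_t));
        for (size_t j = 0; j <= b_len; ++j) previous[j] = j;
        for (size_t i = 1; i <= a_len; ++i) {
            current[0] = i;
            for (size_t j = 1; j <= b_len; ++j) {
                size_t substitution = previous[j - 1] + (a[i - 1] != b[j - 1]);
                size_t deletion = previous[j] + 1;
                size_t insertion = current[j - 1] + 1;
                size_t best = substitution < deletion ? substitution : deletion;
                current[j] = best < insertion ? best : insertion;
            }
            size_t *swap = previous; // the current row becomes the previous one
            previous = current, current = swap;
        }
        size_t result = previous[b_len];
        free(previous), free(current);
        return result;
    }

    int main(void) {
        printf("%zu\n", levenshtein_two_rows("kitten", 6, "sitting", 7)); // 3
        return 0;
    }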
- * - * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default - * @see https://en.wikipedia.org/wiki/Needleman%E2%80%93Wunsch_algorithm - */ -SZ_DYNAMIC sz_ssize_t sz_alignment_score(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, // - sz_memory_allocator_t *alloc); - -/** @copydoc sz_alignment_score */ -SZ_PUBLIC sz_ssize_t sz_alignment_score_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, // - sz_memory_allocator_t *alloc); - -typedef sz_ssize_t (*sz_alignment_score_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_error_cost_t const *, - sz_error_cost_t, sz_memory_allocator_t *); - -typedef void (*sz_hash_callback_t)(sz_cptr_t, sz_size_t, sz_u64_t, void *user); - -/** - * @brief Computes the Karp-Rabin rolling hashes of a string supplying them to the provided `callback`. - * Can be used for similarity scores, search, ranking, etc. - * - * Rabin-Karp-like rolling hashes can have very high-level of collisions and depend - * on the choice of bases and the prime number. That's why, often two hashes from the same - * family are used with different bases. - * - * 1. Kernighan and Ritchie's function uses 31, a prime close to the size of English alphabet. - * 2. To be friendlier to byte-arrays and UTF8, we use 257 for the second function. - * - * Choosing the right ::window_length is task- and domain-dependant. For example, most English words are - * between 3 and 7 characters long, so a window of 4 bytes would be a good choice. For DNA sequences, - * the ::window_length might be a multiple of 3, as the codons are 3 (nucleotides) bytes long. - * With such minimalistic alphabets of just four characters (AGCT) longer windows might be needed. - * For protein sequences the alphabet is 20 characters long, so the window can be shorter, than for DNAs. - * - * @param text String to hash. - * @param length Number of bytes in the string. - * @param window_length Length of the rolling window in bytes. - * @param window_step Step of reported hashes. @b Must be power of two. Should be smaller than `window_length`. - * @param callback Function receiving the start & length of a substring, the hash, and the `callback_handle`. - * @param callback_handle Optional user-provided pointer to be passed to the `callback`. - * @see sz_hashes_fingerprint, sz_hashes_intersection - */ -SZ_DYNAMIC void sz_hashes(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // - sz_hash_callback_t callback, void *callback_handle); - -/** @copydoc sz_hashes */ -SZ_PUBLIC void sz_hashes_serial(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // - sz_hash_callback_t callback, void *callback_handle); - -typedef void (*sz_hashes_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_size_t, sz_hash_callback_t, void *); - -/** - * @brief Computes the Karp-Rabin rolling hashes of a string outputting a binary fingerprint. - * Such fingerprints can be compared with Hamming or Jaccard (Tanimoto) distance for similarity. - * - * The algorithm doesn't clear the fingerprint buffer on start, so it can be invoked multiple times - * to produce a fingerprint of a longer string, by passing the previous fingerprint as the ::fingerprint. - * It can also be reused to produce multi-resolution fingerprints by changing the ::window_length - * and calling the same function multiple times for the same input ::text. 
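To make the rolling-hash description above concrete, here is a standalone Karp-Rabin-style sketch over a fixed window, using the byte-friendly base 257 mentioned in the comment. The modulus and the printed values are illustrative choices, not the library's exact scheme.

    #include <stdio.h>
    #include <stdint.h>
    #include <string.h>

    int main(void) {
        char const *text = "abracadabra";
        size_t length = strlen(text), window = 4;
        uint64_t const base = 257, modulus = 1000000007ull;

        // Precompute base^(window - 1) so the outgoing byte can be dropped in O(1).
        uint64_t leading_power = 1;
        for (size_t i = 1; i < window; ++i) leading_power = leading_power * base % modulus;

        uint64_t hash = 0;
        for (size_t i = 0; i < length; ++i) {
            if (i >= window) // remove the byte that slides out of the window
                hash = (hash + modulus - (uint64_t)(unsigned char)text[i - window] * leading_power % modulus) % modulus;
            hash = (hash * base + (unsigned char)text[i]) % modulus;
            if (i + 1 >= window)
                printf("hash of \"%.*s\" = %llu\n", (int)window, text + i + 1 - window, (unsigned long long)hash);
        }
        return 0;
    }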
- * - * Processes large strings in parts to maximize the cache utilization, using a small on-stack buffer, - * avoiding cache-coherency penalties of remote on-heap buffers. - * - * @param text String to hash. - * @param length Number of bytes in the string. - * @param fingerprint Output fingerprint buffer. - * @param fingerprint_bytes Number of bytes in the fingerprint buffer. - * @param window_length Length of the rolling window in bytes. - * @see sz_hashes, sz_hashes_intersection - */ -SZ_PUBLIC void sz_hashes_fingerprint( // - sz_cptr_t text, sz_size_t length, sz_size_t window_length, // - sz_ptr_t fingerprint, sz_size_t fingerprint_bytes); - -typedef void (*sz_hashes_fingerprint_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_ptr_t, sz_size_t); - -/** - * @brief Given a hash-fingerprint of a textual document, computes the number of intersecting hashes - * of the incoming document. Can be used for document scoring and search. - * - * Processes large strings in parts to maximize the cache utilization, using a small on-stack buffer, - * avoiding cache-coherency penalties of remote on-heap buffers. - * - * @param text Input document. - * @param length Number of bytes in the input document. - * @param fingerprint Reference document fingerprint. - * @param fingerprint_bytes Number of bytes in the reference documents fingerprint. - * @param window_length Length of the rolling window in bytes. - * @see sz_hashes, sz_hashes_fingerprint - */ -SZ_PUBLIC sz_size_t sz_hashes_intersection( // - sz_cptr_t text, sz_size_t length, sz_size_t window_length, // - sz_cptr_t fingerprint, sz_size_t fingerprint_bytes); - -typedef sz_size_t (*sz_hashes_intersection_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_cptr_t, sz_size_t); - -#pragma endregion - -#pragma region Convenience API - -/** - * @brief Finds the first character in the haystack, that is present in the needle. - * Convenience function, reused across different language bindings. - * @see sz_find_charset - */ -SZ_DYNAMIC sz_cptr_t sz_find_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); - -/** - * @brief Finds the first character in the haystack, that is @b not present in the needle. - * Convenience function, reused across different language bindings. - * @see sz_find_charset - */ -SZ_DYNAMIC sz_cptr_t sz_find_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); - -/** - * @brief Finds the last character in the haystack, that is present in the needle. - * Convenience function, reused across different language bindings. - * @see sz_find_charset - */ -SZ_DYNAMIC sz_cptr_t sz_rfind_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); - -/** - * @brief Finds the last character in the haystack, that is @b not present in the needle. - * Convenience function, reused across different language bindings. 
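As a companion to the fingerprint description above, a standalone sketch of comparing two binary fingerprints the way the comment suggests: the Hamming distance counts differing bits, and the Jaccard (Tanimoto) similarity is the ratio of shared set bits to the union of set bits. The fingerprint values here are made up for the demo.

    #include <stdio.h>
    #include <stdint.h>

    static int popcount64(uint64_t x) {
        int count = 0;
        for (; x; x &= x - 1) ++count; // clear the lowest set bit on every step
        return count;
    }

    int main(void) {
        uint64_t fingerprint_a[4] = {0x0F0F0F0F00000000ull, 0x1ull, 0x0ull, 0xFFull};
        uint64_t fingerprint_b[4] = {0x0F0F000000000000ull, 0x3ull, 0x0ull, 0xF0ull};
        int hamming = 0, shared_bits = 0, union_bits = 0;
        for (int i = 0; i < 4; ++i) {
            hamming += popcount64(fingerprint_a[i] ^ fingerprint_b[i]);
            shared_bits += popcount64(fingerprint_a[i] & fingerprint_b[i]);
            union_bits += popcount64(fingerprint_a[i] | fingerprint_b[i]);
        }
        printf("Hamming: %d, Tanimoto: %.3f\n", hamming, union_bits ? (double)shared_bits / union_bits : 1.0);
        return 0;
    }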
- * @see sz_find_charset - */ -SZ_DYNAMIC sz_cptr_t sz_rfind_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); - -#pragma endregion - -#pragma region String Sequences API - -struct sz_sequence_t; - -typedef sz_cptr_t (*sz_sequence_member_start_t)(struct sz_sequence_t const *, sz_size_t); -typedef sz_size_t (*sz_sequence_member_length_t)(struct sz_sequence_t const *, sz_size_t); -typedef sz_bool_t (*sz_sequence_predicate_t)(struct sz_sequence_t const *, sz_size_t); -typedef sz_bool_t (*sz_sequence_comparator_t)(struct sz_sequence_t const *, sz_size_t, sz_size_t); -typedef sz_bool_t (*sz_string_is_less_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); - -typedef struct sz_sequence_t { - sz_sorted_idx_t *order; - sz_size_t count; - sz_sequence_member_start_t get_start; - sz_sequence_member_length_t get_length; - void const *handle; -} sz_sequence_t; - -/** - * @brief Initiates the sequence structure from a tape layout, used by Apache Arrow. - * Expects ::offsets to contains `count + 1` entries, the last pointing at the end - * of the last string, indicating the total length of the ::tape. - */ -SZ_PUBLIC void sz_sequence_from_u32tape(sz_cptr_t *start, sz_u32_t const *offsets, sz_size_t count, - sz_sequence_t *sequence); - -/** - * @brief Initiates the sequence structure from a tape layout, used by Apache Arrow. - * Expects ::offsets to contains `count + 1` entries, the last pointing at the end - * of the last string, indicating the total length of the ::tape. - */ -SZ_PUBLIC void sz_sequence_from_u64tape(sz_cptr_t *start, sz_u64_t const *offsets, sz_size_t count, - sz_sequence_t *sequence); - -/** - * @brief Similar to `std::partition`, given a predicate splits the sequence into two parts. - * The algorithm is unstable, meaning that elements may change relative order, as long - * as they are in the right partition. This is the simpler algorithm for partitioning. - */ -SZ_PUBLIC sz_size_t sz_partition(sz_sequence_t *sequence, sz_sequence_predicate_t predicate); - -/** - * @brief Inplace `std::set_union` for two consecutive chunks forming the same continuous `sequence`. - * - * @param partition The number of elements in the first sub-sequence in `sequence`. - * @param less Comparison function, to determine the lexicographic ordering. - */ -SZ_PUBLIC void sz_merge(sz_sequence_t *sequence, sz_size_t partition, sz_sequence_comparator_t less); - -/** - * @brief Sorting algorithm, combining Radix Sort for the first 32 bits of every word - * and a follow-up by a more conventional sorting procedure on equally prefixed parts. - */ -SZ_PUBLIC void sz_sort(sz_sequence_t *sequence); - -/** - * @brief Partial sorting algorithm, combining Radix Sort for the first 32 bits of every word - * and a follow-up by a more conventional sorting procedure on equally prefixed parts. - */ -SZ_PUBLIC void sz_sort_partial(sz_sequence_t *sequence, sz_size_t n); - -/** - * @brief Intro-Sort algorithm that supports custom comparators. - */ -SZ_PUBLIC void sz_sort_intro(sz_sequence_t *sequence, sz_sequence_comparator_t less); - -#pragma endregion - -/* - * Hardware feature detection. - * All of those can be controlled by the user. 
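For clarity, the tape layout expected by sz_sequence_from_u32tape can be pictured as follows: all strings are concatenated into one buffer, and offsets holds count + 1 entries, so string i spans [offsets[i], offsets[i + 1]). A standalone sketch with made-up data:

    #include <stdio.h>
    #include <stdint.h>

    int main(void) {
        char const tape[] = "catdogmouse";               // three strings back to back
        uint32_t const offsets[] = {0, 3, 6, 11};        // count + 1 = 4 entries, last one marks the tape end
        size_t const count = 3;
        for (size_t i = 0; i < count; ++i)
            printf("string %zu: \"%.*s\"\n", i, (int)(offsets[i + 1] - offsets[i]), tape + offsets[i]);
        return 0;
    }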
- */ -#ifndef SZ_USE_X86_AVX512 -#ifdef __AVX512BW__ -#define SZ_USE_X86_AVX512 1 -#else -#define SZ_USE_X86_AVX512 0 -#endif -#endif - -#ifndef SZ_USE_X86_AVX2 -#ifdef __AVX2__ -#define SZ_USE_X86_AVX2 1 -#else -#define SZ_USE_X86_AVX2 0 -#endif -#endif - -#ifndef SZ_USE_ARM_NEON -#ifdef __ARM_NEON -#define SZ_USE_ARM_NEON 1 -#else -#define SZ_USE_ARM_NEON 0 -#endif -#endif - -#ifndef SZ_USE_ARM_SVE -#ifdef __ARM_FEATURE_SVE -#define SZ_USE_ARM_SVE 1 -#else -#define SZ_USE_ARM_SVE 0 -#endif -#endif - -/* - * Include hardware-specific headers. - */ -#if SZ_USE_X86_AVX512 || SZ_USE_X86_AVX2 -#include -#endif // SZ_USE_X86... -#if SZ_USE_ARM_NEON -#if !defined(_MSC_VER) -#include -#endif -#include -#endif // SZ_USE_ARM_NEON -#if SZ_USE_ARM_SVE -#if !defined(_MSC_VER) -#include -#endif -#endif // SZ_USE_ARM_SVE - -#pragma region Hardware Specific API - -#if SZ_USE_X86_AVX512 - -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_avx512(sz_cptr_t a, sz_cptr_t b, sz_size_t length); -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_avx512(sz_ptr_t target, sz_size_t length, sz_u8_t value); -/** @copydoc sz_look_up_transform */ -SZ_PUBLIC void sz_look_up_transform_avx512(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target); -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_edit_distance */ -SZ_PUBLIC sz_size_t sz_edit_distance_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); -/** @copydoc sz_alignment_score */ -SZ_PUBLIC sz_ssize_t sz_alignment_score_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, // - sz_memory_allocator_t *alloc); -/** @copydoc sz_hashes */ -SZ_PUBLIC void sz_hashes_avx512(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle); -#endif - -#if SZ_USE_X86_AVX2 -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_avx2(sz_cptr_t a, sz_cptr_t b, sz_size_t length); -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_avx2(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** 
@copydoc sz_fill */ -SZ_PUBLIC void sz_fill_avx2(sz_ptr_t target, sz_size_t length, sz_u8_t value); -/** @copydoc sz_look_up_transform */ -SZ_PUBLIC void sz_look_up_transform_avx2(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target); -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_hashes */ -SZ_PUBLIC void sz_hashes_avx2(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle); -#endif - -#if SZ_USE_ARM_NEON -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length); -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_neon(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_neon(sz_ptr_t target, sz_size_t length, sz_u8_t value); -/** @copydoc sz_look_up_transform */ -SZ_PUBLIC void sz_look_up_transform_neon(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target); -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_neon(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -#endif - -#if SZ_USE_ARM_SVE -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_sve(sz_cptr_t a, sz_cptr_t b, sz_size_t length); -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_sve(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_sve(sz_ptr_t target, sz_size_t length, sz_u8_t value); -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t 
sz_rfind_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_sve(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_sve(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -#endif - -#pragma endregion - -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wconversion" - -/* - ********************************************************************************************************************** - ********************************************************************************************************************** - ********************************************************************************************************************** - * - * This is where we the actual implementation begins. - * The rest of the file is hidden from the public API. - * - ********************************************************************************************************************** - ********************************************************************************************************************** - ********************************************************************************************************************** - */ - -#pragma region Compiler Extensions and Helper Functions - -#pragma GCC visibility push(hidden) - -/** - * @brief Helper-macro to mark potentially unused variables. - */ -#define sz_unused(x) ((void)(x)) - -/** - * @brief Helper-macro casting a variable to another type of the same size. - */ -#define sz_bitcast(type, value) (*((type *)&(value))) - -/** - * @brief Defines `SZ_NULL`, analogous to `NULL`. - * The default often comes from locale.h, stddef.h, - * stdio.h, stdlib.h, string.h, time.h, or wchar.h. - */ -#ifdef __GNUG__ -#define SZ_NULL __null -#define SZ_NULL_CHAR __null -#else -#define SZ_NULL ((void *)0) -#define SZ_NULL_CHAR ((char *)0) -#endif - -/** - * @brief Cache-line width, that will affect the execution of some algorithms, - * like equality checks and relative order computing. - */ -#define SZ_CACHE_LINE_WIDTH (64) // bytes - -/** - * @brief Similar to `assert`, the `sz_assert` is used in the SZ_DEBUG mode - * to check the invariants of the library. It's a no-op in the SZ_RELEASE mode. - * @note If you want to catch it, put a breakpoint at @b `__GI_exit` - */ -#if SZ_DEBUG && defined(SZ_AVOID_LIBC) && !SZ_AVOID_LIBC && !defined(SZ_PIC) -#include // `fprintf` -#include // `EXIT_FAILURE` -SZ_PUBLIC void _sz_assert_failure(char const *condition, char const *file, int line) { - fprintf(stderr, "Assertion failed: %s, in file %s, line %d\n", condition, file, line); - exit(EXIT_FAILURE); -} -#define sz_assert(condition) \ - do { \ - if (!(condition)) { _sz_assert_failure(#condition, __FILE__, __LINE__); } \ - } while (0) -#else -#define sz_assert(condition) ((void)(condition)) -#endif - -/* Intrinsics aliases for MSVC, GCC, Clang, and Clang-Cl. - * The following section of compiler intrinsics comes in 2 flavors. - */ -#if defined(_MSC_VER) && !defined(__clang__) // On Clang-CL -#include - -// Sadly, when building Win32 images, we can't use the `_tzcnt_u64`, `_lzcnt_u64`, -// `_BitScanForward64`, or `_BitScanReverse64` intrinsics. For now it's a simple `for`-loop. -// TODO: In the future we can switch to a more efficient De Bruijn's algorithm. 
-// https://www.chessprogramming.org/BitScan -// https://www.chessprogramming.org/De_Bruijn_Sequence -// https://gist.github.com/resilar/e722d4600dbec9752771ab4c9d47044f -// -// Use the serial version on 32-bit x86 and on Arm. -#if (defined(_WIN32) && !defined(_WIN64)) || defined(_M_ARM) || defined(_M_ARM64) -SZ_INTERNAL int sz_u64_ctz(sz_u64_t x) { - sz_assert(x != 0); - int n = 0; - while ((x & 1) == 0) { n++, x >>= 1; } - return n; -} -SZ_INTERNAL int sz_u64_clz(sz_u64_t x) { - sz_assert(x != 0); - int n = 0; - while ((x & 0x8000000000000000ull) == 0) { n++, x <<= 1; } - return n; -} -SZ_INTERNAL int sz_u64_popcount(sz_u64_t x) { - x = x - ((x >> 1) & 0x5555555555555555ull); - x = (x & 0x3333333333333333ull) + ((x >> 2) & 0x3333333333333333ull); - return (((x + (x >> 4)) & 0x0F0F0F0F0F0F0F0Full) * 0x0101010101010101ull) >> 56; -} -SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) { - sz_assert(x != 0); - int n = 0; - while ((x & 1) == 0) { n++, x >>= 1; } - return n; -} -SZ_INTERNAL int sz_u32_clz(sz_u32_t x) { - sz_assert(x != 0); - int n = 0; - while ((x & 0x80000000u) == 0) { n++, x <<= 1; } - return n; -} -SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) { - x = x - ((x >> 1) & 0x55555555); - x = (x & 0x33333333) + ((x >> 2) & 0x33333333); - return (((x + (x >> 4)) & 0x0F0F0F0F) * 0x01010101) >> 24; -} -#else -SZ_INTERNAL int sz_u64_ctz(sz_u64_t x) { return (int)_tzcnt_u64(x); } -SZ_INTERNAL int sz_u64_clz(sz_u64_t x) { return (int)_lzcnt_u64(x); } -SZ_INTERNAL int sz_u64_popcount(sz_u64_t x) { return (int)__popcnt64(x); } -SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) { return (int)_tzcnt_u32(x); } -SZ_INTERNAL int sz_u32_clz(sz_u32_t x) { return (int)_lzcnt_u32(x); } -SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) { return (int)__popcnt(x); } -#endif -// Force the byteswap functions to be intrinsics, because when /Oi- is given, these will turn into CRT function calls, -// which breaks when `SZ_AVOID_LIBC` is given -#pragma intrinsic(_byteswap_uint64) -SZ_INTERNAL sz_u64_t sz_u64_bytes_reverse(sz_u64_t val) { return _byteswap_uint64(val); } -#pragma intrinsic(_byteswap_ulong) -SZ_INTERNAL sz_u32_t sz_u32_bytes_reverse(sz_u32_t val) { return _byteswap_ulong(val); } -#else -SZ_INTERNAL int sz_u64_popcount(sz_u64_t x) { return __builtin_popcountll(x); } -SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) { return __builtin_popcount(x); } -SZ_INTERNAL int sz_u64_ctz(sz_u64_t x) { return __builtin_ctzll(x); } -SZ_INTERNAL int sz_u64_clz(sz_u64_t x) { return __builtin_clzll(x); } -SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) { return __builtin_ctz(x); } // ! Undefined if `x == 0` -SZ_INTERNAL int sz_u32_clz(sz_u32_t x) { return __builtin_clz(x); } // ! Undefined if `x == 0` -SZ_INTERNAL sz_u64_t sz_u64_bytes_reverse(sz_u64_t val) { return __builtin_bswap64(val); } -SZ_INTERNAL sz_u32_t sz_u32_bytes_reverse(sz_u32_t val) { return __builtin_bswap32(val); } -#endif - -SZ_INTERNAL sz_u64_t sz_u64_rotl(sz_u64_t x, sz_u64_t r) { return (x << r) | (x >> (64 - r)); } - -/** - * @brief Select bits from either ::a or ::b depending on the value of ::mask bits. - * - * Similar to `_mm_blend_epi16` intrinsic on x86. - * Described in the "Bit Twiddling Hacks" by Sean Eron Anderson. - * https://graphics.stanford.edu/~seander/bithacks.html#ConditionalSetOrClearBitsWithoutBranching - */ -SZ_INTERNAL sz_u64_t sz_u64_blend(sz_u64_t a, sz_u64_t b, sz_u64_t mask) { return a ^ ((a ^ b) & mask); } - -/* - * Efficiently computing the minimum and maximum of two or three values can be tricky. - * The simple branching baseline would be: - * - * x < y ? 
x : y // can replace with 1 conditional move - * - * Branchless approach is well known for signed integers, but it doesn't apply to unsigned ones. - * https://stackoverflow.com/questions/514435/templatized-branchless-int-max-min-function - * https://graphics.stanford.edu/~seander/bithacks.html#IntegerMinOrMax - * Using only bit-shifts for singed integers it would be: - * - * y + ((x - y) & (x - y) >> 31) // 4 unique operations - * - * Alternatively, for any integers using multiplication: - * - * (x > y) * y + (x <= y) * x // 5 operations - * - * Alternatively, to avoid multiplication: - * - * x & ~((x < y) - 1) + y & ((x < y) - 1) // 6 unique operations - */ -#define sz_min_of_two(x, y) (x < y ? x : y) -#define sz_max_of_two(x, y) (x < y ? y : x) -#define sz_min_of_three(x, y, z) sz_min_of_two(x, sz_min_of_two(y, z)) -#define sz_max_of_three(x, y, z) sz_max_of_two(x, sz_max_of_two(y, z)) - -/** @brief Branchless minimum function for two signed 32-bit integers. */ -SZ_INTERNAL sz_i32_t sz_i32_min_of_two(sz_i32_t x, sz_i32_t y) { return y + ((x - y) & (x - y) >> 31); } - -/** @brief Branchless minimum function for two signed 32-bit integers. */ -SZ_INTERNAL sz_i32_t sz_i32_max_of_two(sz_i32_t x, sz_i32_t y) { return x - ((x - y) & (x - y) >> 31); } - -/** - * @brief Clamps signed offsets in a string to a valid range. Used for Pythonic-style slicing. - */ -SZ_INTERNAL void sz_ssize_clamp_interval(sz_size_t length, sz_ssize_t start, sz_ssize_t end, - sz_size_t *normalized_offset, sz_size_t *normalized_length) { - // TODO: Remove branches. - // Normalize negative indices - if (start < 0) start += length; - if (end < 0) end += length; - - // Clamp indices to a valid range - if (start < 0) start = 0; - if (end < 0) end = 0; - if (start > (sz_ssize_t)length) start = length; - if (end > (sz_ssize_t)length) end = length; - - // Ensure start <= end - if (start > end) start = end; - - *normalized_offset = start; - *normalized_length = end - start; -} - -/** - * @brief Compute the logarithm base 2 of a positive integer, rounding down. - */ -SZ_INTERNAL sz_size_t sz_size_log2i_nonzero(sz_size_t x) { - sz_assert(x > 0 && "Non-positive numbers have no defined logarithm"); - sz_size_t leading_zeros = sz_u64_clz(x); - return 63 - leading_zeros; -} - -/** - * @brief Compute the smallest power of two greater than or equal to ::x. - */ -SZ_INTERNAL sz_size_t sz_size_bit_ceil(sz_size_t x) { - // Unlike the commonly used trick with `clz` intrinsics, is valid across the whole range of `x`. - // https://stackoverflow.com/a/10143264 - x--; - x |= x >> 1; - x |= x >> 2; - x |= x >> 4; - x |= x >> 8; - x |= x >> 16; -#if SZ_DETECT_64_BIT - x |= x >> 32; -#endif - x++; - return x; -} - -/** - * @brief Transposes an 8x8 bit matrix packed in a `sz_u64_t`. - * - * There is a well known SWAR sequence for that known to chess programmers, - * willing to flip a bit-matrix of pieces along the main A1-H8 diagonal. - * https://www.chessprogramming.org/Flipping_Mirroring_and_Rotating - * https://lukas-prokop.at/articles/2021-07-23-transpose - */ -SZ_INTERNAL sz_u64_t sz_u64_transpose(sz_u64_t x) { - sz_u64_t t; - t = x ^ (x << 36); - x ^= 0xf0f0f0f00f0f0f0full & (t ^ (x >> 36)); - t = 0xcccc0000cccc0000ull & (x ^ (x << 18)); - x ^= t ^ (t >> 18); - t = 0xaa00aa00aa00aa00ull & (x ^ (x << 9)); - x ^= t ^ (t >> 9); - return x; -} - -/** - * @brief Helper, that swaps two 64-bit integers representing the order of elements in the sequence. 
- */ -SZ_INTERNAL void sz_u64_swap(sz_u64_t *a, sz_u64_t *b) { - sz_u64_t t = *a; - *a = *b; - *b = t; -} - -/** - * @brief Helper, that swaps two 64-bit integers representing the order of elements in the sequence. - */ -SZ_INTERNAL void sz_pointer_swap(void **a, void **b) { - void *t = *a; - *a = *b; - *b = t; -} - -/** - * @brief Helper structure to simplify work with 16-bit words. - * @see sz_u16_load - */ -typedef union sz_u16_vec_t { - sz_u16_t u16; - sz_u8_t u8s[2]; -} sz_u16_vec_t; - -/** - * @brief Load a 16-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms. - */ -SZ_INTERNAL sz_u16_vec_t sz_u16_load(sz_cptr_t ptr) { -#if !SZ_USE_MISALIGNED_LOADS - sz_u16_vec_t result; - result.u8s[0] = ptr[0]; - result.u8s[1] = ptr[1]; - return result; -#elif defined(_MSC_VER) && !defined(__clang__) -#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. - return *((sz_u16_vec_t *)ptr); -#else - return *((__unaligned sz_u16_vec_t *)ptr); -#endif -#else - __attribute__((aligned(1))) sz_u16_vec_t const *result = (sz_u16_vec_t const *)ptr; - return *result; -#endif -} - -/** - * @brief Helper structure to simplify work with 32-bit words. - * @see sz_u32_load - */ -typedef union sz_u32_vec_t { - sz_u32_t u32; - sz_u16_t u16s[2]; - sz_u8_t u8s[4]; -} sz_u32_vec_t; - -/** - * @brief Load a 32-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms. - */ -SZ_INTERNAL sz_u32_vec_t sz_u32_load(sz_cptr_t ptr) { -#if !SZ_USE_MISALIGNED_LOADS - sz_u32_vec_t result; - result.u8s[0] = ptr[0]; - result.u8s[1] = ptr[1]; - result.u8s[2] = ptr[2]; - result.u8s[3] = ptr[3]; - return result; -#elif defined(_MSC_VER) && !defined(__clang__) -#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. - return *((sz_u32_vec_t *)ptr); -#else - return *((__unaligned sz_u32_vec_t *)ptr); -#endif -#else - __attribute__((aligned(1))) sz_u32_vec_t const *result = (sz_u32_vec_t const *)ptr; - return *result; -#endif -} - -/** - * @brief Helper structure to simplify work with 64-bit words. - * @see sz_u64_load - */ -typedef union sz_u64_vec_t { - sz_u64_t u64; - sz_u32_t u32s[2]; - sz_u16_t u16s[4]; - sz_u8_t u8s[8]; -} sz_u64_vec_t; - -/** - * @brief Load a 64-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms. - */ -SZ_INTERNAL sz_u64_vec_t sz_u64_load(sz_cptr_t ptr) { -#if !SZ_USE_MISALIGNED_LOADS - sz_u64_vec_t result; - result.u8s[0] = ptr[0]; - result.u8s[1] = ptr[1]; - result.u8s[2] = ptr[2]; - result.u8s[3] = ptr[3]; - result.u8s[4] = ptr[4]; - result.u8s[5] = ptr[5]; - result.u8s[6] = ptr[6]; - result.u8s[7] = ptr[7]; - return result; -#elif defined(_MSC_VER) && !defined(__clang__) -#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. - return *((sz_u64_vec_t *)ptr); -#else - return *((__unaligned sz_u64_vec_t *)ptr); -#endif -#else - __attribute__((aligned(1))) sz_u64_vec_t const *result = (sz_u64_vec_t const *)ptr; - return *result; -#endif -} - -/** @brief Helper function, using the supplied fixed-capacity buffer to allocate memory. 
*/ -SZ_INTERNAL sz_ptr_t _sz_memory_allocate_fixed(sz_size_t length, void *handle) { - sz_size_t capacity; - sz_copy((sz_ptr_t)&capacity, (sz_cptr_t)handle, sizeof(sz_size_t)); - sz_size_t consumed_capacity = sizeof(sz_size_t); - if (consumed_capacity + length > capacity) return SZ_NULL_CHAR; - return (sz_ptr_t)handle + consumed_capacity; -} - -/** @brief Helper "no-op" function, simulating memory deallocation when we use a "static" memory buffer. */ -SZ_INTERNAL void _sz_memory_free_fixed(sz_ptr_t start, sz_size_t length, void *handle) { - sz_unused(start && length && handle); -} - -/** @brief An internal callback used to set a bit in a power-of-two length binary fingerprint of a string. */ -SZ_INTERNAL void _sz_hashes_fingerprint_pow2_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, void *handle) { - sz_string_view_t *fingerprint_buffer = (sz_string_view_t *)handle; - sz_u8_t *fingerprint_u8s = (sz_u8_t *)fingerprint_buffer->start; - sz_size_t fingerprint_bytes = fingerprint_buffer->length; - fingerprint_u8s[(hash / 8) & (fingerprint_bytes - 1)] |= (1 << (hash & 7)); - sz_unused(start && length); -} - -/** @brief An internal callback used to set a bit in a @b non power-of-two length binary fingerprint of a string. */ -SZ_INTERNAL void _sz_hashes_fingerprint_non_pow2_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, - void *handle) { - sz_string_view_t *fingerprint_buffer = (sz_string_view_t *)handle; - sz_u8_t *fingerprint_u8s = (sz_u8_t *)fingerprint_buffer->start; - sz_size_t fingerprint_bytes = fingerprint_buffer->length; - fingerprint_u8s[(hash / 8) % fingerprint_bytes] |= (1 << (hash & 7)); - sz_unused(start && length); -} - -/** @brief An internal callback, used to mix all the running hashes into one pointer-size value. */ -SZ_INTERNAL void _sz_hashes_fingerprint_scalar_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, - void *scalar_handle) { - sz_unused(start && length && hash && scalar_handle); - sz_size_t *scalar_ptr = (sz_size_t *)scalar_handle; - *scalar_ptr ^= hash; -} - -/** - * @brief Chooses the offsets of the most interesting characters in a search needle. - * - * Search throughput can significantly deteriorate if we are matching the wrong characters. - * Say the needle is "aXaYa", and we are comparing the first, second, and last character. - * If we use SIMD and compare many offsets at a time, comparing against "a" in every register is a waste. - * - * Similarly, dealing with UTF8 inputs, we know that the lower bits of each character code carry more information. - * Cyrillic alphabet, for example, falls into [0x0410, 0x042F] code range for uppercase [А, Я], and - * into [0x0430, 0x044F] for lowercase [а, я]. Scanning through a text written in Russian, half of the - * bytes will carry absolutely no value and will be equal to 0x04. - */ -SZ_INTERNAL void _sz_locate_needle_anomalies(sz_cptr_t start, sz_size_t length, // - sz_size_t *first, sz_size_t *second, sz_size_t *third) { - *first = 0; - *second = length / 2; - *third = length - 1; - - // - int has_duplicates = // - start[*first] == start[*second] || // - start[*first] == start[*third] || // - start[*second] == start[*third]; - - // Loop through letters to find non-colliding variants. - if (length > 3 && has_duplicates) { - // Pivot the middle point right, until we find a character different from the first one. - for (; start[*second] == start[*first] && *second + 1 < *third; ++(*second)) {} - // Pivot the third (last) point left, until we find a different character. 
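As a worked example of the anomaly-picking loops in this function, the standalone sketch below repeats the same nudging rules on the "aXaYa" needle mentioned in the comment; it illustrates the heuristic of reducing duplicate probe bytes where possible, and is not library code.

    #include <stdio.h>
    #include <string.h>

    int main(void) {
        char const *needle = "aXaYa";
        size_t length = strlen(needle);
        size_t first = 0, second = length / 2, third = length - 1; // initial picks: 'a', 'a', 'a'
        // Nudge the middle pick right until it differs from the first one.
        for (; needle[second] == needle[first] && second + 1 < third; ++second) {}
        // Nudge the last pick left until it differs from the other two, if it can.
        for (; (needle[third] == needle[second] || needle[third] == needle[first]) && third > second + 1; --third) {}
        printf("offsets: %zu %zu %zu -> '%c' '%c' '%c'\n", first, second, third,
               needle[first], needle[second], needle[third]); // 0 3 4 -> 'a' 'Y' 'a'
        return 0;
    }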
- for (; (start[*third] == start[*second] || start[*third] == start[*first]) && *third > (*second + 1); - --(*third)) {} - } - - // TODO: Investigate alternative strategies for long needles. - // On very long needles we have the luxury to choose! - // Often dealing with UTF8, we will likely benefit from shifting the first and second characters - // further to the right, to achieve not only uniqueness within the needle, but also avoid common - // rune prefixes of 2-, 3-, and 4-byte codes. - if (length > 8) { - // Pivot the first and second points right, until we find a character, that: - // > is different from others. - // > doesn't start with 0b'110x'xxxx - only 5 bits of relevant info. - // > doesn't start with 0b'1110'xxxx - only 4 bits of relevant info. - // > doesn't start with 0b'1111'0xxx - only 3 bits of relevant info. - // - // So we are practically searching for byte values that start with 0b0xxx'xxxx or 0b'10xx'xxxx. - // Meaning they fall in the range [0, 127] and [128, 191], in other words any unsigned int up to 191. - sz_u8_t const *start_u8 = (sz_u8_t const *)start; - sz_size_t vibrant_first = *first, vibrant_second = *second, vibrant_third = *third; - - // Let's begin with the seccond character, as the termination criteria there is more obvious - // and we may end up with more variants to check for the first candidate. - for (; (start_u8[vibrant_second] > 191 || start_u8[vibrant_second] == start_u8[vibrant_third]) && - (vibrant_second + 1 < vibrant_third); - ++vibrant_second) {} - - // Now check if we've indeed found a good candidate or should revert the `vibrant_second` to `second`. - if (start_u8[vibrant_second] < 191) { *second = vibrant_second; } - else { vibrant_second = *second; } - - // Now check the first character. - for (; (start_u8[vibrant_first] > 191 || start_u8[vibrant_first] == start_u8[vibrant_second] || - start_u8[vibrant_first] == start_u8[vibrant_third]) && - (vibrant_first + 1 < vibrant_second); - ++vibrant_first) {} - - // Now check if we've indeed found a good candidate or should revert the `vibrant_first` to `first`. - // We don't need to shift the third one when dealing with texts as the last byte of the text is - // also the last byte of a rune and contains the most information. - if (start_u8[vibrant_first] < 191) { *first = vibrant_first; } - } -} - -#pragma GCC visibility pop -#pragma endregion - -#pragma region Serial Implementation - -#if !SZ_AVOID_LIBC -#include // `fprintf` -#include // `malloc`, `EXIT_FAILURE` - -SZ_PUBLIC void *_sz_memory_allocate_default(sz_size_t length, void *handle) { - sz_unused(handle); - return malloc(length); -} -SZ_PUBLIC void _sz_memory_free_default(sz_ptr_t start, sz_size_t length, void *handle) { - sz_unused(handle && length); - free(start); -} - -#endif - -SZ_PUBLIC void sz_memory_allocator_init_default(sz_memory_allocator_t *alloc) { -#if !SZ_AVOID_LIBC - alloc->allocate = (sz_memory_allocate_t)_sz_memory_allocate_default; - alloc->free = (sz_memory_free_t)_sz_memory_free_default; -#else - alloc->allocate = (sz_memory_allocate_t)SZ_NULL; - alloc->free = (sz_memory_free_t)SZ_NULL; -#endif - alloc->handle = SZ_NULL; -} - -SZ_PUBLIC void sz_memory_allocator_init_fixed(sz_memory_allocator_t *alloc, void *buffer, sz_size_t length) { - // The logic here is simple - put the buffer length in the first slots of the buffer. - // Later use it for bounds checking. 
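A hedged usage sketch of this fixed-capacity allocator paired with sz_edit_distance, assuming the usual <stringzilla/stringzilla.h> include path; the 4 KB arena size is an arbitrary choice for the demo.

    #include <stdio.h>
    #include <stringzilla/stringzilla.h> // assumed include path

    int main(void) {
        char arena[4096]; // on-stack arena; size chosen only for illustration
        sz_memory_allocator_t alloc;
        sz_memory_allocator_init_fixed(&alloc, arena, sizeof(arena));
        // Temporary DP rows for the distance computation are carved out of `arena` instead of malloc.
        sz_size_t distance = sz_edit_distance("kitten", 6, "sitting", 7, SZ_SIZE_MAX, &alloc);
        printf("edit distance: %llu\n", (unsigned long long)distance); // 3
        return 0;
    }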
- alloc->allocate = (sz_memory_allocate_t)_sz_memory_allocate_fixed; - alloc->free = (sz_memory_free_t)_sz_memory_free_fixed; - alloc->handle = &buffer; - sz_copy((sz_ptr_t)buffer, (sz_cptr_t)&length, sizeof(sz_size_t)); -} - -/** - * @brief Byte-level equality comparison between two strings. - * If unaligned loads are allowed, uses a switch-table to avoid loops on short strings. - */ -SZ_PUBLIC sz_bool_t sz_equal_serial(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - sz_cptr_t const a_end = a + length; -#if SZ_USE_MISALIGNED_LOADS - if (length >= SZ_SWAR_THRESHOLD) { - sz_u64_vec_t a_vec, b_vec; - for (; a + 8 <= a_end; a += 8, b += 8) { - a_vec = sz_u64_load(a); - b_vec = sz_u64_load(b); - if (a_vec.u64 != b_vec.u64) return sz_false_k; - } - } -#endif - while (a != a_end && *a == *b) a++, b++; - return (sz_bool_t)(a_end == a); -} - -SZ_PUBLIC sz_cptr_t sz_find_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { - for (sz_cptr_t const end = text + length; text != end; ++text) - if (sz_charset_contains(set, *text)) return text; - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Warray-bounds" - sz_cptr_t const end = text; - for (text += length; text != end;) - if (sz_charset_contains(set, *(text -= 1))) return text; - return SZ_NULL_CHAR; -#pragma GCC diagnostic pop -} - -/** - * One option to avoid branching is to use conditional moves and lookup the comparison result in a table: - * sz_ordering_t ordering_lookup[2] = {sz_greater_k, sz_less_k}; - * for (; a != min_end; ++a, ++b) - * if (*a != *b) return ordering_lookup[*a < *b]; - * That, however, introduces a data-dependency. - * A cleaner option is to perform two comparisons and a subtraction. - * One instruction more, but no data-dependency. - */ -#define _sz_order_scalars(a, b) ((sz_ordering_t)((a > b) - (a < b))) - -SZ_PUBLIC sz_ordering_t sz_order_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { - sz_bool_t a_shorter = (sz_bool_t)(a_length < b_length); - sz_size_t min_length = a_shorter ? a_length : b_length; - sz_cptr_t min_end = a + min_length; -#if SZ_USE_MISALIGNED_LOADS && !SZ_DETECT_BIG_ENDIAN - for (sz_u64_vec_t a_vec, b_vec; a + 8 <= min_end; a += 8, b += 8) { - a_vec = sz_u64_load(a); - b_vec = sz_u64_load(b); - if (a_vec.u64 != b_vec.u64) - return _sz_order_scalars(sz_u64_bytes_reverse(a_vec.u64), sz_u64_bytes_reverse(b_vec.u64)); - } -#endif - for (; a != min_end; ++a, ++b) - if (*a != *b) return _sz_order_scalars(*a, *b); - - // If the strings are equal up to `min_end`, then the shorter string is smaller - return _sz_order_scalars(a_length, b_length); -} - -/** - * @brief Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each byte is set. - // For that take the bottom 7 bits of each byte, add one to them, - // and if this sets the top bit to one, then all the 7 bits are ones as well. - vec.u64 = ((vec.u64 & 0x7F7F7F7F7F7F7F7Full) + 0x0101010101010101ull) & ((vec.u64 & 0x8080808080808080ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b single-character needle in an arbitrary length haystack. 
- * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. - * Identical to `memchr(haystack, needle[0], haystack_length)`. - */ -SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - if (!h_length) return SZ_NULL_CHAR; - sz_cptr_t const h_end = h + h_length; - -#if !SZ_DETECT_BIG_ENDIAN // Use SWAR only on little-endian platforms for brevety. -#if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)h & 7ull) && h < h_end; ++h) - if (*h == *n) return h; -#endif - - // Broadcast the n into every byte of a 64-bit integer to use SWAR - // techniques and process eight characters at a time. - sz_u64_vec_t h_vec, n_vec, match_vec; - match_vec.u64 = 0; - n_vec.u64 = (sz_u64_t)n[0] * 0x0101010101010101ull; - for (; h + 8 <= h_end; h += 8) { - h_vec.u64 = *(sz_u64_t const *)h; - match_vec = _sz_u64_each_byte_equal(h_vec, n_vec); - if (match_vec.u64) return h + sz_u64_ctz(match_vec.u64) / 8; - } -#endif - - // Handle the misaligned tail. - for (; h < h_end; ++h) - if (*h == *n) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief Find the last occurrence of a @b single-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. - * Identical to `memrchr(haystack, needle[0], haystack_length)`. - */ -sz_cptr_t sz_rfind_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - if (!h_length) return SZ_NULL_CHAR; - sz_cptr_t const h_start = h; - - // Reposition the `h` pointer to the end, as we will be walking backwards. - h = h + h_length - 1; - -#if !SZ_DETECT_BIG_ENDIAN // Use SWAR only on little-endian platforms for brevety. -#if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)(h + 1) & 7ull) && h >= h_start; --h) - if (*h == *n) return h; -#endif - - // Broadcast the n into every byte of a 64-bit integer to use SWAR - // techniques and process eight characters at a time. - sz_u64_vec_t h_vec, n_vec, match_vec; - n_vec.u64 = (sz_u64_t)n[0] * 0x0101010101010101ull; - for (; h >= h_start + 7; h -= 8) { - h_vec.u64 = *(sz_u64_t const *)(h - 7); - match_vec = _sz_u64_each_byte_equal(h_vec, n_vec); - if (match_vec.u64) return h - sz_u64_clz(match_vec.u64) / 8; - } -#endif - - for (; h >= h_start; --h) - if (*h == *n) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief 2Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each 2byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_2byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each 2byte is set. - // For that take the bottom 15 bits of each 2byte, add one to them, - // and if this sets the top bit to one, then all the 15 bits are ones as well. - vec.u64 = ((vec.u64 & 0x7FFF7FFF7FFF7FFFull) + 0x0001000100010001ull) & ((vec.u64 & 0x8000800080008000ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b two-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. - */ -SZ_INTERNAL sz_cptr_t _sz_find_2byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - // This is an internal method, and the haystack is guaranteed to be at least 2 bytes long. 
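A standalone sketch of the SWAR match extraction used by sz_find_byte_serial above: broadcast the needle byte into all eight lanes, flag equal bytes by setting their top bits, and turn the lowest flagged bit into a byte offset. It assumes a little-endian load and the GCC/Clang __builtin_ctzll builtin, in line with the serial code above.

    #include <stdio.h>
    #include <stdint.h>
    #include <string.h>

    int main(void) {
        char const haystack[8] = {'s', 't', 'r', 'i', 'n', 'g', 'z', 'a'};
        uint64_t word;
        memcpy(&word, haystack, 8); // on little-endian machines byte 0 lands in the lowest lane

        uint64_t needle_broadcast = (uint64_t)'n' * 0x0101010101010101ull; // same byte in every lane
        uint64_t equality = ~(word ^ needle_broadcast);                    // 0xFF in lanes that match
        uint64_t match = ((equality & 0x7F7F7F7F7F7F7F7Full) + 0x0101010101010101ull) &
                         (equality & 0x8080808080808080ull);               // top bit set only in matching lanes
        if (match) printf("first 'n' at offset %d\n", __builtin_ctzll(match) / 8); // prints 4
        return 0;
    }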
- sz_assert(h_length >= 2 && "The haystack is too short."); - sz_cptr_t const h_end = h + h_length; - -#if !SZ_USE_MISALIGNED_LOADS - // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)h & 7ull) && h + 2 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) == 2) return h; -#endif - - sz_u64_vec_t h_even_vec, h_odd_vec, n_vec, matches_even_vec, matches_odd_vec; - n_vec.u64 = 0; - n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1]; - n_vec.u64 *= 0x0001000100010001ull; // broadcast - - // This code simulates hyper-scalar execution, analyzing 8 offsets at a time. - for (; h + 9 <= h_end; h += 8) { - h_even_vec.u64 = *(sz_u64_t *)h; - h_odd_vec.u64 = (h_even_vec.u64 >> 8) | ((sz_u64_t)h[8] << 56); - matches_even_vec = _sz_u64_each_2byte_equal(h_even_vec, n_vec); - matches_odd_vec = _sz_u64_each_2byte_equal(h_odd_vec, n_vec); - - matches_even_vec.u64 >>= 8; - if (matches_even_vec.u64 + matches_odd_vec.u64) { - sz_u64_t match_indicators = matches_even_vec.u64 | matches_odd_vec.u64; - return h + sz_u64_ctz(match_indicators) / 8; - } - } - - for (; h + 2 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) == 2) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief 4Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each 4byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_4byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each 4byte is set. - // For that take the bottom 31 bits of each 4byte, add one to them, - // and if this sets the top bit to one, then all the 31 bits are ones as well. - vec.u64 = ((vec.u64 & 0x7FFFFFFF7FFFFFFFull) + 0x0000000100000001ull) & ((vec.u64 & 0x8000000080000000ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b four-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. - */ -SZ_INTERNAL sz_cptr_t _sz_find_4byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - // This is an internal method, and the haystack is guaranteed to be at least 4 bytes long. - sz_assert(h_length >= 4 && "The haystack is too short."); - sz_cptr_t const h_end = h + h_length; - -#if !SZ_USE_MISALIGNED_LOADS - // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)h & 7ull) && h + 4 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) + (h[3] == n[3]) == 4) return h; -#endif - - sz_u64_vec_t h0_vec, h1_vec, h2_vec, h3_vec, n_vec, matches0_vec, matches1_vec, matches2_vec, matches3_vec; - n_vec.u64 = 0; - n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1], n_vec.u8s[2] = n[2], n_vec.u8s[3] = n[3]; - n_vec.u64 *= 0x0000000100000001ull; // broadcast - - // This code simulates hyper-scalar execution, analyzing 8 offsets at a time using four 64-bit words. - // We load the subsequent four-byte word as well, taking its first bytes. 
Think of it as a glorified prefetch :) - sz_u64_t h_page_current, h_page_next; - for (; h + sizeof(sz_u64_t) + sizeof(sz_u32_t) <= h_end; h += sizeof(sz_u64_t)) { - h_page_current = *(sz_u64_t *)h; - h_page_next = *(sz_u32_t *)(h + 8); - h0_vec.u64 = (h_page_current); - h1_vec.u64 = (h_page_current >> 8) | (h_page_next << 56); - h2_vec.u64 = (h_page_current >> 16) | (h_page_next << 48); - h3_vec.u64 = (h_page_current >> 24) | (h_page_next << 40); - matches0_vec = _sz_u64_each_4byte_equal(h0_vec, n_vec); - matches1_vec = _sz_u64_each_4byte_equal(h1_vec, n_vec); - matches2_vec = _sz_u64_each_4byte_equal(h2_vec, n_vec); - matches3_vec = _sz_u64_each_4byte_equal(h3_vec, n_vec); - - if (matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64) { - matches0_vec.u64 >>= 24; - matches1_vec.u64 >>= 16; - matches2_vec.u64 >>= 8; - sz_u64_t match_indicators = matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64; - return h + sz_u64_ctz(match_indicators) / 8; - } - } - - for (; h + 4 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) + (h[3] == n[3]) == 4) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief 3Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each 3byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_3byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each 4byte is set. - // For that take the bottom 31 bits of each 4byte, add one to them, - // and if this sets the top bit to one, then all the 31 bits are ones as well. - vec.u64 = ((vec.u64 & 0xFFFF7FFFFF7FFFFFull) + 0x0000000001000001ull) & ((vec.u64 & 0x0000800000800000ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b three-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. - */ -SZ_INTERNAL sz_cptr_t _sz_find_3byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - // This is an internal method, and the haystack is guaranteed to be at least 4 bytes long. - sz_assert(h_length >= 3 && "The haystack is too short."); - sz_cptr_t const h_end = h + h_length; - -#if !SZ_USE_MISALIGNED_LOADS - // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)h & 7ull) && h + 3 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) == 3) return h; -#endif - - // We fetch 12 - sz_u64_vec_t h0_vec, h1_vec, h2_vec, h3_vec, h4_vec; - sz_u64_vec_t matches0_vec, matches1_vec, matches2_vec, matches3_vec, matches4_vec; - sz_u64_vec_t n_vec; - n_vec.u64 = 0; - n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1], n_vec.u8s[2] = n[2]; - n_vec.u64 *= 0x0000000001000001ull; // broadcast - - // This code simulates hyper-scalar execution, analyzing 8 offsets at a time using three 64-bit words. - // We load the subsequent two-byte word as well. 
- sz_u64_t h_page_current, h_page_next; - for (; h + sizeof(sz_u64_t) + sizeof(sz_u16_t) <= h_end; h += sizeof(sz_u64_t)) { - h_page_current = *(sz_u64_t *)h; - h_page_next = *(sz_u16_t *)(h + 8); - h0_vec.u64 = (h_page_current); - h1_vec.u64 = (h_page_current >> 8) | (h_page_next << 56); - h2_vec.u64 = (h_page_current >> 16) | (h_page_next << 48); - h3_vec.u64 = (h_page_current >> 24) | (h_page_next << 40); - h4_vec.u64 = (h_page_current >> 32) | (h_page_next << 32); - matches0_vec = _sz_u64_each_3byte_equal(h0_vec, n_vec); - matches1_vec = _sz_u64_each_3byte_equal(h1_vec, n_vec); - matches2_vec = _sz_u64_each_3byte_equal(h2_vec, n_vec); - matches3_vec = _sz_u64_each_3byte_equal(h3_vec, n_vec); - matches4_vec = _sz_u64_each_3byte_equal(h4_vec, n_vec); - - if (matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64 | matches4_vec.u64) { - matches0_vec.u64 >>= 16; - matches1_vec.u64 >>= 8; - matches3_vec.u64 <<= 8; - matches4_vec.u64 <<= 16; - sz_u64_t match_indicators = - matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64 | matches4_vec.u64; - return h + sz_u64_ctz(match_indicators) / 8; - } - } - - for (; h + 3 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) == 3) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief Boyer-Moore-Horspool algorithm for exact matching of patterns up to @b 256-bytes long. - * Uses the Raita heuristic to match the first two, the last, and the middle character of the pattern. - */ -SZ_INTERNAL sz_cptr_t _sz_find_horspool_upto_256bytes_serial(sz_cptr_t h_chars, sz_size_t h_length, // - sz_cptr_t n_chars, sz_size_t n_length) { - sz_assert(n_length <= 256 && "The pattern is too long."); - // Several popular string matching algorithms are using a bad-character shift table. - // Boyer Moore: https://www-igm.univ-mlv.fr/~lecroq/string/node14.html - // Quick Search: https://www-igm.univ-mlv.fr/~lecroq/string/node19.html - // Smith: https://www-igm.univ-mlv.fr/~lecroq/string/node21.html - union { - sz_u8_t jumps[256]; - sz_u64_vec_t vecs[64]; - } bad_shift_table; - - // Let's initialize the table using SWAR to the total length of the string. - sz_u8_t const *h = (sz_u8_t const *)h_chars; - sz_u8_t const *n = (sz_u8_t const *)n_chars; - { - sz_u64_vec_t n_length_vec; - n_length_vec.u64 = n_length; - n_length_vec.u64 *= 0x0101010101010101ull; // broadcast - for (sz_size_t i = 0; i != 64; ++i) bad_shift_table.vecs[i].u64 = n_length_vec.u64; - for (sz_size_t i = 0; i + 1 < n_length; ++i) bad_shift_table.jumps[n[i]] = (sz_u8_t)(n_length - i - 1); - } - - // Another common heuristic is to match a few characters from different parts of a string. - // Raita suggests to use the first two, the last, and the middle character of the pattern. - sz_u32_vec_t h_vec, n_vec; - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n_chars, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into an unsigned integer. - n_vec.u8s[0] = n[offset_first]; - n_vec.u8s[1] = n[offset_first + 1]; - n_vec.u8s[2] = n[offset_mid]; - n_vec.u8s[3] = n[offset_last]; - - // Scan through the whole haystack, skipping the last `n_length - 1` bytes. 
- for (sz_size_t i = 0; i <= h_length - n_length;) { - h_vec.u8s[0] = h[i + offset_first]; - h_vec.u8s[1] = h[i + offset_first + 1]; - h_vec.u8s[2] = h[i + offset_mid]; - h_vec.u8s[3] = h[i + offset_last]; - if (h_vec.u32 == n_vec.u32 && sz_equal((sz_cptr_t)h + i, n_chars, n_length)) return (sz_cptr_t)h + i; - i += bad_shift_table.jumps[h[i + n_length - 1]]; - } - return SZ_NULL_CHAR; -} - -/** - * @brief Boyer-Moore-Horspool algorithm for @b reverse-order exact matching of patterns up to @b 256-bytes long. - * Uses the Raita heuristic to match the first two, the last, and the middle character of the pattern. - */ -SZ_INTERNAL sz_cptr_t _sz_rfind_horspool_upto_256bytes_serial(sz_cptr_t h_chars, sz_size_t h_length, // - sz_cptr_t n_chars, sz_size_t n_length) { - sz_assert(n_length <= 256 && "The pattern is too long."); - union { - sz_u8_t jumps[256]; - sz_u64_vec_t vecs[64]; - } bad_shift_table; - - // Let's initialize the table using SWAR to the total length of the string. - sz_u8_t const *h = (sz_u8_t const *)h_chars; - sz_u8_t const *n = (sz_u8_t const *)n_chars; - { - sz_u64_vec_t n_length_vec; - n_length_vec.u64 = n_length; - n_length_vec.u64 *= 0x0101010101010101ull; // broadcast - for (sz_size_t i = 0; i != 64; ++i) bad_shift_table.vecs[i].u64 = n_length_vec.u64; - for (sz_size_t i = 0; i + 1 < n_length; ++i) - bad_shift_table.jumps[n[n_length - i - 1]] = (sz_u8_t)(n_length - i - 1); - } - - // Another common heuristic is to match a few characters from different parts of a string. - // Raita suggests to use the first two, the last, and the middle character of the pattern. - sz_u32_vec_t h_vec, n_vec; - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n_chars, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into an unsigned integer. - n_vec.u8s[0] = n[offset_first]; - n_vec.u8s[1] = n[offset_first + 1]; - n_vec.u8s[2] = n[offset_mid]; - n_vec.u8s[3] = n[offset_last]; - - // Scan through the whole haystack, skipping the first `n_length - 1` bytes. - for (sz_size_t j = 0; j <= h_length - n_length;) { - sz_size_t i = h_length - n_length - j; - h_vec.u8s[0] = h[i + offset_first]; - h_vec.u8s[1] = h[i + offset_first + 1]; - h_vec.u8s[2] = h[i + offset_mid]; - h_vec.u8s[3] = h[i + offset_last]; - if (h_vec.u32 == n_vec.u32 && sz_equal((sz_cptr_t)h + i, n_chars, n_length)) return (sz_cptr_t)h + i; - j += bad_shift_table.jumps[h[i]]; - } - return SZ_NULL_CHAR; -} - -/** - * @brief Exact substring search helper function, that finds the first occurrence of a prefix of the needle - * using a given search function, and then verifies the remaining part of the needle. - */ -SZ_INTERNAL sz_cptr_t _sz_find_with_prefix(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length, - sz_find_t find_prefix, sz_size_t prefix_length) { - - sz_size_t suffix_length = n_length - prefix_length; - while (1) { - sz_cptr_t found = find_prefix(h, h_length, n, prefix_length); - if (!found) return SZ_NULL_CHAR; - - // Verify the remaining part of the needle - sz_size_t remaining = h_length - (found - h); - if (remaining < n_length) return SZ_NULL_CHAR; - if (sz_equal(found + prefix_length, n + prefix_length, suffix_length)) return found; - - // Adjust the position. 
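The `_sz_find_with_prefix` helper being defined here generalizes a simple pattern: find a short, cheap prefix with a fast primitive, then verify the tail. A reduced sketch of that pattern, with a hypothetical name and `memchr` standing in as the prefix finder:

#include <stddef.h>
#include <string.h>

/* Hypothetical reduction of "find a prefix, then verify the tail": locate the first
   byte with memchr, then confirm the remaining bytes of the needle with memcmp. */
static char const *find_with_first_byte(char const *h, size_t h_len, char const *n, size_t n_len) {
    if (!n_len || h_len < n_len) return NULL;
    char const *end = h + h_len;
    while ((h = memchr(h, n[0], (size_t)(end - h)))) {
        if ((size_t)(end - h) < n_len) return NULL;         /* not enough room left for the whole needle */
        if (memcmp(h + 1, n + 1, n_len - 1) == 0) return h; /* verify the remaining part */
        ++h;                                                /* step past the false positive */
    }
    return NULL;
}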
- h = found + 1; - h_length = remaining - 1; - } - - // Unreachable, but helps silence compiler warnings: - return SZ_NULL_CHAR; -} - -/** - * @brief Exact reverse-order substring search helper function, that finds the last occurrence of a suffix of the - * needle using a given search function, and then verifies the remaining part of the needle. - */ -SZ_INTERNAL sz_cptr_t _sz_rfind_with_suffix(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length, - sz_find_t find_suffix, sz_size_t suffix_length) { - - sz_size_t prefix_length = n_length - suffix_length; - while (1) { - sz_cptr_t found = find_suffix(h, h_length, n + prefix_length, suffix_length); - if (!found) return SZ_NULL_CHAR; - - // Verify the remaining part of the needle - sz_size_t remaining = found - h; - if (remaining < prefix_length) return SZ_NULL_CHAR; - if (sz_equal(found - prefix_length, n, prefix_length)) return found - prefix_length; - - // Adjust the position. - h_length = remaining - 1; - } - - // Unreachable, but helps silence compiler warnings: - return SZ_NULL_CHAR; -} - -SZ_INTERNAL sz_cptr_t _sz_find_over_4bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - return _sz_find_with_prefix(h, h_length, n, n_length, (sz_find_t)_sz_find_4byte_serial, 4); -} - -SZ_INTERNAL sz_cptr_t _sz_find_horspool_over_256bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, - sz_size_t n_length) { - return _sz_find_with_prefix(h, h_length, n, n_length, _sz_find_horspool_upto_256bytes_serial, 256); -} - -SZ_INTERNAL sz_cptr_t _sz_rfind_horspool_over_256bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, - sz_size_t n_length) { - return _sz_rfind_with_suffix(h, h_length, n, n_length, _sz_rfind_horspool_upto_256bytes_serial, 256); -} - -SZ_PUBLIC sz_cptr_t sz_find_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - -#if SZ_DETECT_BIG_ENDIAN - sz_find_t backends[] = { - (sz_find_t)sz_find_byte_serial, - (sz_find_t)_sz_find_horspool_upto_256bytes_serial, - (sz_find_t)_sz_find_horspool_over_256bytes_serial, - }; - - return backends[(n_length > 1) + (n_length > 256)](h, h_length, n, n_length); -#else - sz_find_t backends[] = { - // For very short strings brute-force SWAR makes sense. - (sz_find_t)sz_find_byte_serial, - (sz_find_t)_sz_find_2byte_serial, - (sz_find_t)_sz_find_3byte_serial, - (sz_find_t)_sz_find_4byte_serial, - // To avoid constructing the skip-table, let's use the prefixed approach. - (sz_find_t)_sz_find_over_4bytes_serial, - // For longer needles - use skip tables. - (sz_find_t)_sz_find_horspool_upto_256bytes_serial, - (sz_find_t)_sz_find_horspool_over_256bytes_serial, - }; - - return backends[ - // For very short strings brute-force SWAR makes sense. - (n_length > 1) + (n_length > 2) + (n_length > 3) + - // To avoid constructing the skip-table, let's use the prefixed approach. - (n_length > 4) + - // For longer needles - use skip tables. - (n_length > 8) + (n_length > 256)](h, h_length, n, n_length); -#endif -} - -SZ_PUBLIC sz_cptr_t sz_rfind_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - - sz_find_t backends[] = { - // For very short strings brute-force SWAR makes sense. 
- (sz_find_t)sz_rfind_byte_serial, - // TODO: implement reverse-order SWAR for 2/3/4 byte variants. - // TODO: (sz_find_t)_sz_rfind_2byte_serial, - // TODO: (sz_find_t)_sz_rfind_3byte_serial, - // TODO: (sz_find_t)_sz_rfind_4byte_serial, - // To avoid constructing the skip-table, let's use the prefixed approach. - // (sz_find_t)_sz_rfind_over_4bytes_serial, - // For longer needles - use skip tables. - (sz_find_t)_sz_rfind_horspool_upto_256bytes_serial, - (sz_find_t)_sz_rfind_horspool_over_256bytes_serial, - }; - - return backends[ - // For very short strings brute-force SWAR makes sense. - 0 + - // To avoid constructing the skip-table, let's use the prefixed approach. - (n_length > 1) + - // For longer needles - use skip tables. - (n_length > 256)](h, h_length, n, n_length); -} - -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_serial( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - - // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. - sz_memory_allocator_t global_alloc; - if (!alloc) { - sz_memory_allocator_init_default(&global_alloc); - alloc = &global_alloc; - } - - // TODO: Generalize to remove the following asserts! - sz_assert(!bound && "For bounded search the method should only evaluate one band of the matrix."); - sz_assert(shorter_length == longer_length && "The method hasn't been generalized to different length inputs yet."); - sz_unused(longer_length && bound); - - // We are going to store 3 diagonals of the matrix. - // The length of the longest (main) diagonal would be `n = (shorter_length + 1)`. - sz_size_t n = shorter_length + 1; - sz_size_t buffer_length = sizeof(sz_size_t) * n * 3; - sz_size_t *distances = (sz_size_t *)alloc->allocate(buffer_length, alloc->handle); - if (!distances) return SZ_SIZE_MAX; - - sz_size_t *previous_distances = distances; - sz_size_t *current_distances = previous_distances + n; - sz_size_t *next_distances = previous_distances + n * 2; - - // Initialize the first two diagonals: - previous_distances[0] = 0; - current_distances[0] = current_distances[1] = 1; - - // Progress through the upper triangle of the Levenshtein matrix. - sz_size_t next_diagonal_index = 2; - for (; next_diagonal_index != n; ++next_diagonal_index) { - sz_size_t const next_diagonal_length = next_diagonal_index + 1; - for (sz_size_t i = 0; i + 2 < next_diagonal_length; ++i) { - sz_size_t cost_of_substitution = shorter[next_diagonal_index - i - 2] != longer[i]; - sz_size_t cost_if_substitution = previous_distances[i] + cost_of_substitution; - sz_size_t cost_if_deletion_or_insertion = sz_min_of_two(current_distances[i], current_distances[i + 1]) + 1; - next_distances[i + 1] = sz_min_of_two(cost_if_deletion_or_insertion, cost_if_substitution); - } - // Don't forget to populate the first row and the first column of the Levenshtein matrix. - next_distances[0] = next_distances[next_diagonal_length - 1] = next_diagonal_index; - // Perform a circular rotation of those buffers, to reuse the memory. - sz_size_t *temporary = previous_distances; - previous_distances = current_distances; - current_distances = next_distances; - next_distances = temporary; - } - - // By now we've scanned through the upper triangle of the matrix, where each subsequent iteration results in a - // larger diagonal. From now onwards, we will be shrinking. 
Instead of adding value equal to the skewed diagonal - // index on either side, we will be cropping those values out. - sz_size_t diagonals_count = n + n - 1; - for (; next_diagonal_index != diagonals_count; ++next_diagonal_index) { - sz_size_t const next_diagonal_length = diagonals_count - next_diagonal_index; - for (sz_size_t i = 0; i != next_diagonal_length; ++i) { - sz_size_t cost_of_substitution = shorter[shorter_length - 1 - i] != longer[next_diagonal_index - n + i]; - sz_size_t cost_if_substitution = previous_distances[i] + cost_of_substitution; - sz_size_t cost_if_deletion_or_insertion = sz_min_of_two(current_distances[i], current_distances[i + 1]) + 1; - next_distances[i] = sz_min_of_two(cost_if_deletion_or_insertion, cost_if_substitution); - } - // Perform a circular rotation of those buffers, to reuse the memory, this time, with a shift, - // dropping the first element in the current array. - sz_size_t *temporary = previous_distances; - previous_distances = current_distances + 1; - current_distances = next_distances; - next_distances = temporary; - } - - // Cache scalar before `free` call. - sz_size_t result = current_distances[0]; - alloc->free(distances, buffer_length, alloc->handle); - return result; -} - -/** - * @brief Describes the length of a UTF8 character / codepoint / rune in bytes. - */ -typedef enum { - sz_utf8_invalid_k = 0, //!< Invalid UTF8 character. - sz_utf8_rune_1byte_k = 1, //!< 1-byte UTF8 character. - sz_utf8_rune_2bytes_k = 2, //!< 2-byte UTF8 character. - sz_utf8_rune_3bytes_k = 3, //!< 3-byte UTF8 character. - sz_utf8_rune_4bytes_k = 4, //!< 4-byte UTF8 character. -} sz_rune_length_t; - -typedef sz_u32_t sz_rune_t; - -/** - * @brief Extracts just one UTF8 codepoint from a UTF8 string into a 32-bit unsigned integer. - */ -SZ_INTERNAL void _sz_extract_utf8_rune(sz_cptr_t utf8, sz_rune_t *code, sz_rune_length_t *code_length) { - sz_u8_t const *current = (sz_u8_t const *)utf8; - sz_u8_t leading_byte = *current++; - sz_rune_t ch; - sz_rune_length_t ch_length; - - // TODO: This can be made entirely branchless using 32-bit SWAR. - if (leading_byte < 0x80) { - // Single-byte rune (0xxxxxxx) - ch = leading_byte; - ch_length = sz_utf8_rune_1byte_k; - } - else if ((leading_byte & 0xE0) == 0xC0) { - // Two-byte rune (110xxxxx 10xxxxxx) - ch = (leading_byte & 0x1F) << 6; - ch |= (*current++ & 0x3F); - ch_length = sz_utf8_rune_2bytes_k; - } - else if ((leading_byte & 0xF0) == 0xE0) { - // Three-byte rune (1110xxxx 10xxxxxx 10xxxxxx) - ch = (leading_byte & 0x0F) << 12; - ch |= (*current++ & 0x3F) << 6; - ch |= (*current++ & 0x3F); - ch_length = sz_utf8_rune_3bytes_k; - } - else if ((leading_byte & 0xF8) == 0xF0) { - // Four-byte rune (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx) - ch = (leading_byte & 0x07) << 18; - ch |= (*current++ & 0x3F) << 12; - ch |= (*current++ & 0x3F) << 6; - ch |= (*current++ & 0x3F); - ch_length = sz_utf8_rune_4bytes_k; - } - else { - // Invalid UTF8 rune. - ch = 0; - ch_length = sz_utf8_invalid_k; - } - *code = ch; - *code_length = ch_length; -} - -/** - * @brief Exports a UTF8 string into a UTF32 buffer. - * ! The result is undefined id the UTF8 string is corrupted. - * @return The length in the number of codepoints. 
- */ -SZ_INTERNAL sz_size_t _sz_export_utf8_to_utf32(sz_cptr_t utf8, sz_size_t utf8_length, sz_rune_t *utf32) { - sz_cptr_t const end = utf8 + utf8_length; - sz_size_t count = 0; - sz_rune_length_t rune_length; - for (; utf8 != end; utf8 += rune_length, utf32++, count++) _sz_extract_utf8_rune(utf8, utf32, &rune_length); - return count; -} - -/** - * @brief Compute the Levenshtein distance between two strings using the Wagner-Fisher algorithm. - * Stores only 2 rows of the Levenshtein matrix, but uses 64-bit integers for the distance values, - * and upcasts UTF8 variable-length codepoints to 64-bit integers for faster addressing. - * - * ! In the worst case for 2 strings of length 100, that contain just one 16-bit codepoint this will result in extra: - * + 2 rows * 100 slots * 8 bytes/slot = 1600 bytes of memory for the two rows of the Levenshtein matrix rows. - * + 100 codepoints * 2 strings * 4 bytes/codepoint = 800 bytes of memory for the UTF8 buffer. - * = 2400 bytes of memory or @b 12x memory amplification! - */ -SZ_INTERNAL sz_size_t _sz_edit_distance_wagner_fisher_serial( // - sz_cptr_t longer, sz_size_t longer_length, // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_size_t bound, sz_bool_t can_be_unicode, sz_memory_allocator_t *alloc) { - - // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. - sz_memory_allocator_t global_alloc; - if (!alloc) { - sz_memory_allocator_init_default(&global_alloc); - alloc = &global_alloc; - } - - // A good idea may be to dispatch different kernels for different string lengths. - // Like using `uint8_t` counters for strings under 255 characters long. - // Good in theory, this results in frequent upcasts and downcasts in serial code. - // On strings over 20 bytes, using `uint8` over `uint64` on 64-bit x86 CPU doubles the execution time. - // So one must be very cautious with such optimizations. - typedef sz_size_t _distance_t; - - // Compute the number of columns in our Levenshtein matrix. - sz_size_t const n = shorter_length + 1; - - // If a buffering memory-allocator is provided, this operation is practically free, - // and cheaper than allocating even 512 bytes (for small distance matrices) on stack. - sz_size_t buffer_length = sizeof(_distance_t) * (n * 2); - - // If the strings contain Unicode characters, let's estimate the max character width, - // and use it to allocate a larger buffer to decode UTF8. - if ((can_be_unicode == sz_true_k) && - (sz_isascii(longer, longer_length) == sz_false_k || sz_isascii(shorter, shorter_length) == sz_false_k)) { - buffer_length += (shorter_length + longer_length) * sizeof(sz_rune_t); - } - else { can_be_unicode = sz_false_k; } - - // If the allocation fails, return the maximum distance. - sz_ptr_t const buffer = (sz_ptr_t)alloc->allocate(buffer_length, alloc->handle); - if (!buffer) return SZ_SIZE_MAX; - - // Let's export the UTF8 sequence into the newly allocated buffer at the end. - if (can_be_unicode == sz_true_k) { - sz_rune_t *const longer_utf32 = (sz_rune_t *)(buffer + sizeof(_distance_t) * (n * 2)); - sz_rune_t *const shorter_utf32 = longer_utf32 + longer_length; - // Export the UTF8 sequences into the newly allocated buffer. 
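The decoding above follows the standard leading-byte classification for UTF-8. A compact stand-alone sketch of the same classification (hypothetical names, well-formed input assumed) that just counts codepoints:

#include <stddef.h>
#include <stdint.h>

/* Hypothetical helper: byte length of the UTF-8 rune starting with `lead`,
   assuming the input is well-formed; invalid bytes advance by one to stay bounded. */
static size_t utf8_rune_length(uint8_t lead) {
    if (lead < 0x80) return 1;           /* 0xxxxxxx */
    if ((lead & 0xE0) == 0xC0) return 2; /* 110xxxxx 10xxxxxx */
    if ((lead & 0xF0) == 0xE0) return 3; /* 1110xxxx 10xxxxxx 10xxxxxx */
    if ((lead & 0xF8) == 0xF0) return 4; /* 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx */
    return 1;
}

static size_t utf8_count_runes(char const *utf8, size_t length) {
    size_t count = 0;
    for (size_t i = 0; i < length; i += utf8_rune_length((uint8_t)utf8[i])) ++count;
    return count;
}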
- longer_length = _sz_export_utf8_to_utf32(longer, longer_length, longer_utf32); - shorter_length = _sz_export_utf8_to_utf32(shorter, shorter_length, shorter_utf32); - longer = (sz_cptr_t)longer_utf32; - shorter = (sz_cptr_t)shorter_utf32; - } - - // Let's parameterize the core logic for different character types and distance types. -#define _wagner_fisher_unbounded(_distance_t, _char_t) \ - /* Now let's cast our pointer to avoid it in subsequent sections. */ \ - _char_t const *const longer_chars = (_char_t const *)longer; \ - _char_t const *const shorter_chars = (_char_t const *)shorter; \ - _distance_t *previous_distances = (_distance_t *)buffer; \ - _distance_t *current_distances = previous_distances + n; \ - /* Initialize the first row of the Levenshtein matrix with `iota`-style arithmetic progression. */ \ - for (_distance_t idx_shorter = 0; idx_shorter != n; ++idx_shorter) previous_distances[idx_shorter] = idx_shorter; \ - /* The main loop of the algorithm with quadratic complexity. */ \ - for (_distance_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) { \ - _char_t const longer_char = longer_chars[idx_longer]; \ - /* Using pure pointer arithmetic is faster than iterating with an index. */ \ - _char_t const *shorter_ptr = shorter_chars; \ - _distance_t const *previous_ptr = previous_distances; \ - _distance_t *current_ptr = current_distances; \ - _distance_t *const current_end = current_ptr + shorter_length; \ - current_ptr[0] = idx_longer + 1; \ - for (; current_ptr != current_end; ++previous_ptr, ++current_ptr, ++shorter_ptr) { \ - _distance_t cost_substitution = previous_ptr[0] + (_distance_t)(longer_char != shorter_ptr[0]); \ - /* We can avoid `+1` for costs here, shifting it to post-minimum computation, */ \ - /* saving one increment operation. */ \ - _distance_t cost_deletion = previous_ptr[1]; \ - _distance_t cost_insertion = current_ptr[0]; \ - /* ? It might be a good idea to enforce branchless execution here. */ \ - /* ? The caveat being that the benchmarks on longer sequences backfire and more research is needed. */ \ - current_ptr[1] = sz_min_of_two(cost_substitution, sz_min_of_two(cost_deletion, cost_insertion) + 1); \ - } \ - /* Swap `previous_distances` and `current_distances` pointers. */ \ - _distance_t *temporary = previous_distances; \ - previous_distances = current_distances; \ - current_distances = temporary; \ - } \ - /* Cache scalar before `free` call. */ \ - sz_size_t result = previous_distances[shorter_length]; \ - alloc->free(buffer, buffer_length, alloc->handle); \ - return result; - - // Let's define a separate variant for bounded distance computation. - // Practically the same as unbounded, but also collecting the running minimum within each row for early exit. 
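The `_wagner_fisher_unbounded` macro above may be easier to follow in de-macro-ified form. The sketch below is a plain, byte-only, unbounded version of the same two-row recurrence, with hypothetical names and `malloc` in place of the pluggable allocator:

#include <stddef.h>
#include <stdlib.h>

static size_t size_min(size_t a, size_t b) { return a < b ? a : b; }

/* Hypothetical two-row Levenshtein distance over raw bytes, unbounded. */
static size_t levenshtein_two_rows(char const *longer, size_t longer_len, char const *shorter, size_t shorter_len) {
    size_t n = shorter_len + 1;
    size_t *buffer = (size_t *)malloc(sizeof(size_t) * n * 2);
    if (!buffer) return (size_t)-1;
    size_t *previous = buffer, *current = buffer + n;
    for (size_t j = 0; j != n; ++j) previous[j] = j; /* first row: 0, 1, 2, ... */
    for (size_t i = 0; i != longer_len; ++i) {
        current[0] = i + 1;                          /* first column of this row */
        for (size_t j = 0; j != shorter_len; ++j) {
            size_t substitution = previous[j] + (longer[i] != shorter[j]);
            size_t deletion_or_insertion = size_min(previous[j + 1], current[j]) + 1;
            current[j + 1] = size_min(substitution, deletion_or_insertion);
        }
        size_t *swap = previous; previous = current; current = swap; /* reuse the two rows */
    }
    size_t result = previous[shorter_len];
    free(buffer);
    return result;
}

The bounded macro that follows is the same loop plus a per-row running minimum, which allows returning `bound` early once no cell in a row can stay under it.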
-#define _wagner_fisher_bounded(_distance_t, _char_t) \ - _char_t const *const longer_chars = (_char_t const *)longer; \ - _char_t const *const shorter_chars = (_char_t const *)shorter; \ - _distance_t *previous_distances = (_distance_t *)buffer; \ - _distance_t *current_distances = previous_distances + n; \ - for (_distance_t idx_shorter = 0; idx_shorter != n; ++idx_shorter) previous_distances[idx_shorter] = idx_shorter; \ - for (_distance_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) { \ - _char_t const longer_char = longer_chars[idx_longer]; \ - _char_t const *shorter_ptr = shorter_chars; \ - _distance_t const *previous_ptr = previous_distances; \ - _distance_t *current_ptr = current_distances; \ - _distance_t *const current_end = current_ptr + shorter_length; \ - current_ptr[0] = idx_longer + 1; \ - /* Initialize min_distance with a value greater than bound */ \ - _distance_t min_distance = bound - 1; \ - for (; current_ptr != current_end; ++previous_ptr, ++current_ptr, ++shorter_ptr) { \ - _distance_t cost_substitution = previous_ptr[0] + (_distance_t)(longer_char != shorter_ptr[0]); \ - _distance_t cost_deletion = previous_ptr[1]; \ - _distance_t cost_insertion = current_ptr[0]; \ - current_ptr[1] = sz_min_of_two(cost_substitution, sz_min_of_two(cost_deletion, cost_insertion) + 1); \ - /* Keep track of the minimum distance seen so far in this row */ \ - min_distance = sz_min_of_two(current_ptr[1], min_distance); \ - } \ - /* If the minimum distance in this row exceeded the bound, return early */ \ - if (min_distance >= bound) { \ - alloc->free(buffer, buffer_length, alloc->handle); \ - return bound; \ - } \ - _distance_t *temporary = previous_distances; \ - previous_distances = current_distances; \ - current_distances = temporary; \ - } \ - sz_size_t result = previous_distances[shorter_length]; \ - alloc->free(buffer, buffer_length, alloc->handle); \ - return sz_min_of_two(result, bound); - - // Dispatch the actual computation. - if (!bound) { - if (can_be_unicode == sz_true_k) { _wagner_fisher_unbounded(sz_size_t, sz_rune_t); } - else { _wagner_fisher_unbounded(sz_size_t, sz_u8_t); } - } - else { - if (can_be_unicode == sz_true_k) { _wagner_fisher_bounded(sz_size_t, sz_rune_t); } - else { _wagner_fisher_bounded(sz_size_t, sz_u8_t); } - } -} - -SZ_PUBLIC sz_size_t sz_edit_distance_serial( // - sz_cptr_t longer, sz_size_t longer_length, // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - - // Let's make sure that we use the amount proportional to the - // number of elements in the shorter string, not the larger. - if (shorter_length > longer_length) { - sz_pointer_swap((void **)&longer_length, (void **)&shorter_length); - sz_pointer_swap((void **)&longer, (void **)&shorter); - } - - // Skip the matching prefixes and suffixes, they won't affect the distance. - for (sz_cptr_t a_end = longer + longer_length, b_end = shorter + shorter_length; - longer != a_end && shorter != b_end && *longer == *shorter; - ++longer, ++shorter, --longer_length, --shorter_length); - for (; longer_length && shorter_length && longer[longer_length - 1] == shorter[shorter_length - 1]; - --longer_length, --shorter_length); - - // Bounded computations may exit early. - int const is_bounded = bound < longer_length; - if (is_bounded) { - // If one of the strings is empty - the edit distance is equal to the length of the other one. 
- if (longer_length == 0) return sz_min_of_two(shorter_length, bound); - if (shorter_length == 0) return sz_min_of_two(longer_length, bound); - // If the difference in length is beyond the `bound`, there is no need to check at all. - if (longer_length - shorter_length > bound) return bound; - } - - if (shorter_length == 0) return longer_length; // If no mismatches were found - the distance is zero. - if (shorter_length == longer_length && !is_bounded) - return _sz_edit_distance_skewed_diagonals_serial(longer, longer_length, shorter, shorter_length, bound, alloc); - return _sz_edit_distance_wagner_fisher_serial(longer, longer_length, shorter, shorter_length, bound, sz_false_k, - alloc); -} - -SZ_PUBLIC sz_ssize_t sz_alignment_score_serial( // - sz_cptr_t longer, sz_size_t longer_length, // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, // - sz_memory_allocator_t *alloc) { - - // If one of the strings is empty - the edit distance is equal to the length of the other one - if (longer_length == 0) return (sz_ssize_t)shorter_length * gap; - if (shorter_length == 0) return (sz_ssize_t)longer_length * gap; - - // Let's make sure that we use the amount proportional to the - // number of elements in the shorter string, not the larger. - if (shorter_length > longer_length) { - sz_pointer_swap((void **)&longer_length, (void **)&shorter_length); - sz_pointer_swap((void **)&longer, (void **)&shorter); - } - - // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. - sz_memory_allocator_t global_alloc; - if (!alloc) { - sz_memory_allocator_init_default(&global_alloc); - alloc = &global_alloc; - } - - sz_size_t n = shorter_length + 1; - sz_size_t buffer_length = sizeof(sz_ssize_t) * n * 2; - sz_ssize_t *distances = (sz_ssize_t *)alloc->allocate(buffer_length, alloc->handle); - sz_ssize_t *previous_distances = distances; - sz_ssize_t *current_distances = previous_distances + n; - - for (sz_size_t idx_shorter = 0; idx_shorter != n; ++idx_shorter) - previous_distances[idx_shorter] = (sz_ssize_t)idx_shorter * gap; - - sz_u8_t const *shorter_unsigned = (sz_u8_t const *)shorter; - sz_u8_t const *longer_unsigned = (sz_u8_t const *)longer; - for (sz_size_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) { - current_distances[0] = ((sz_ssize_t)idx_longer + 1) * gap; - - // Initialize min_distance with a value greater than bound - sz_error_cost_t const *a_subs = subs + longer_unsigned[idx_longer] * 256ul; - for (sz_size_t idx_shorter = 0; idx_shorter != shorter_length; ++idx_shorter) { - sz_ssize_t cost_deletion = previous_distances[idx_shorter + 1] + gap; - sz_ssize_t cost_insertion = current_distances[idx_shorter] + gap; - sz_ssize_t cost_substitution = previous_distances[idx_shorter] + a_subs[shorter_unsigned[idx_shorter]]; - current_distances[idx_shorter + 1] = sz_max_of_three(cost_deletion, cost_insertion, cost_substitution); - } - - // Swap previous_distances and current_distances pointers - sz_pointer_swap((void **)&previous_distances, (void **)¤t_distances); - } - - // Cache scalar before `free` call. 
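The scoring loop above is the classic global-alignment recurrence with a dense substitution matrix and a linear gap penalty; a stand-alone two-row sketch of the same idea, with hypothetical names and a 256-by-256 cost table:

#include <stddef.h>
#include <stdlib.h>

static long max3(long a, long b, long c) { long m = a > b ? a : b; return m > c ? m : c; }

/* Hypothetical two-row global alignment score: substitution costs come from a
   256x256 table indexed by byte values, gaps cost a flat `gap` per character. */
static long align_score(char const *a, size_t a_len, char const *b, size_t b_len,
                        long const subs[256][256], long gap) {
    size_t n = b_len + 1;
    long *buffer = (long *)malloc(sizeof(long) * n * 2);
    if (!buffer) return 0; /* error handling elided in this sketch */
    long *previous = buffer, *current = buffer + n;
    for (size_t j = 0; j != n; ++j) previous[j] = (long)j * gap;
    for (size_t i = 0; i != a_len; ++i) {
        current[0] = (long)(i + 1) * gap;
        long const *row = subs[(unsigned char)a[i]];
        for (size_t j = 0; j != b_len; ++j)
            current[j + 1] = max3(previous[j + 1] + gap,                   /* deletion */
                                  current[j] + gap,                        /* insertion */
                                  previous[j] + row[(unsigned char)b[j]]); /* substitution */
        long *swap = previous; previous = current; current = swap;
    }
    long score = previous[b_len];
    free(buffer);
    return score;
}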
- sz_ssize_t result = previous_distances[shorter_length]; - alloc->free(distances, buffer_length, alloc->handle); - return result; -} - -SZ_PUBLIC sz_size_t sz_hamming_distance_serial( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - - sz_size_t const min_length = sz_min_of_two(a_length, b_length); - sz_size_t const max_length = sz_max_of_two(a_length, b_length); - sz_cptr_t const a_end = a + min_length; - bound = bound == 0 ? max_length : bound; - - // Walk through both strings using SWAR and counting the number of differing characters. - sz_size_t distance = max_length - min_length; -#if SZ_USE_MISALIGNED_LOADS && !SZ_DETECT_BIG_ENDIAN - if (min_length >= SZ_SWAR_THRESHOLD) { - sz_u64_vec_t a_vec, b_vec, match_vec; - for (; a + 8 <= a_end && distance < bound; a += 8, b += 8) { - a_vec.u64 = sz_u64_load(a).u64; - b_vec.u64 = sz_u64_load(b).u64; - match_vec = _sz_u64_each_byte_equal(a_vec, b_vec); - distance += sz_u64_popcount((~match_vec.u64) & 0x8080808080808080ull); - } - } -#endif - - for (; a != a_end && distance < bound; ++a, ++b) { distance += (*a != *b); } - return sz_min_of_two(distance, bound); -} - -SZ_PUBLIC sz_size_t sz_hamming_distance_utf8_serial( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - - sz_cptr_t const a_end = a + a_length; - sz_cptr_t const b_end = b + b_length; - sz_size_t distance = 0; - - sz_rune_t a_rune, b_rune; - sz_rune_length_t a_rune_length, b_rune_length; - - if (bound) { - for (; a < a_end && b < b_end && distance < bound; a += a_rune_length, b += b_rune_length) { - _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - distance += (a_rune != b_rune); - } - // If one string has more runes, we need to go through the tail. - if (distance < bound) { - for (; a < a_end && distance < bound; a += a_rune_length, ++distance) - _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - - for (; b < b_end && distance < bound; b += b_rune_length, ++distance) - _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - } - } - else { - for (; a < a_end && b < b_end; a += a_rune_length, b += b_rune_length) { - _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - distance += (a_rune != b_rune); - } - // If one string has more runes, we need to go through the tail. - for (; a < a_end; a += a_rune_length, ++distance) _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - for (; b < b_end; b += b_rune_length, ++distance) _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - } - return distance; -} - -SZ_PUBLIC sz_u64_t sz_checksum_serial(sz_cptr_t text, sz_size_t length) { - sz_u64_t checksum = 0; - sz_u8_t const *text_u8 = (sz_u8_t const *)text; - sz_u8_t const *text_end = text_u8 + length; - for (; text_u8 != text_end; ++text_u8) checksum += *text_u8; - return checksum; -} - -/** - * @brief Largest prime number that fits into 31 bits. - * @see https://mersenneforum.org/showthread.php?t=3471 - */ -#define SZ_U32_MAX_PRIME (2147483647u) - -/** - * @brief Largest prime number that fits into 64 bits. - * @see https://mersenneforum.org/showthread.php?t=3471 - * - * 2^64 = 18,446,744,073,709,551,616 - * this = 18,446,744,073,709,551,557 - * diff = 59 - */ -#define SZ_U64_MAX_PRIME (18446744073709551557ull) - -/* - * One hardware-accelerated way of mixing hashes can be CRC, but it's only implemented for 32-bit values. 
- * Using a Boost-like mixer works very poorly in such case: - * - * hash_first ^ (hash_second + 0x517cc1b727220a95 + (hash_first << 6) + (hash_first >> 2)); - * - * Let's stick to the Fibonacci hash trick using the golden ratio. - * https://probablydance.com/2018/06/16/fibonacci-hashing-the-optimization-that-the-world-forgot-or-a-better-alternative-to-integer-modulo/ - */ -#define _sz_hash_mix(first, second) ((first * 11400714819323198485ull) ^ (second * 11400714819323198485ull)) -#define _sz_shift_low(x) (x) -#define _sz_shift_high(x) ((x + 77ull) & 0xFFull) -#define _sz_prime_mod(x) (x % SZ_U64_MAX_PRIME) - -SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t start, sz_size_t length) { - - sz_u64_t hash_low = 0; - sz_u64_t hash_high = 0; - sz_u8_t const *text = (sz_u8_t const *)start; - sz_u8_t const *text_end = text + length; - - switch (length) { - case 0: return 0; - - // Texts under 7 bytes long are definitely below the largest prime. - case 1: - hash_low = _sz_shift_low(text[0]); - hash_high = _sz_shift_high(text[0]); - break; - case 2: - hash_low = _sz_shift_low(text[0]) * 31ull + _sz_shift_low(text[1]); - hash_high = _sz_shift_high(text[0]) * 257ull + _sz_shift_high(text[1]); - break; - case 3: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull + // - _sz_shift_low(text[2]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull + // - _sz_shift_high(text[2]); - break; - case 4: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull + // - _sz_shift_low(text[3]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull + // - _sz_shift_high(text[3]); - break; - case 5: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull * 31ull + // - _sz_shift_low(text[3]) * 31ull + // - _sz_shift_low(text[4]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull * 257ull + // - _sz_shift_high(text[3]) * 257ull + // - _sz_shift_high(text[4]); - break; - case 6: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[3]) * 31ull * 31ull + // - _sz_shift_low(text[4]) * 31ull + // - _sz_shift_low(text[5]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[3]) * 257ull * 257ull + // - _sz_shift_high(text[4]) * 257ull + // - _sz_shift_high(text[5]); - break; - case 7: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[3]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[4]) * 31ull * 31ull + // - _sz_shift_low(text[5]) * 31ull + // - _sz_shift_low(text[6]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull 
* 257ull * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[3]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[4]) * 257ull * 257ull + // - _sz_shift_high(text[5]) * 257ull + // - _sz_shift_high(text[6]); - break; - default: - // Unroll the first seven cycles: - hash_low = hash_low * 31ull + _sz_shift_low(text[0]); - hash_high = hash_high * 257ull + _sz_shift_high(text[0]); - hash_low = hash_low * 31ull + _sz_shift_low(text[1]); - hash_high = hash_high * 257ull + _sz_shift_high(text[1]); - hash_low = hash_low * 31ull + _sz_shift_low(text[2]); - hash_high = hash_high * 257ull + _sz_shift_high(text[2]); - hash_low = hash_low * 31ull + _sz_shift_low(text[3]); - hash_high = hash_high * 257ull + _sz_shift_high(text[3]); - hash_low = hash_low * 31ull + _sz_shift_low(text[4]); - hash_high = hash_high * 257ull + _sz_shift_high(text[4]); - hash_low = hash_low * 31ull + _sz_shift_low(text[5]); - hash_high = hash_high * 257ull + _sz_shift_high(text[5]); - hash_low = hash_low * 31ull + _sz_shift_low(text[6]); - hash_high = hash_high * 257ull + _sz_shift_high(text[6]); - text += 7; - - // Iterate throw the rest with the modulus: - for (; text != text_end; ++text) { - hash_low = hash_low * 31ull + _sz_shift_low(text[0]); - hash_high = hash_high * 257ull + _sz_shift_high(text[0]); - // Wrap the hashes around: - hash_low = _sz_prime_mod(hash_low); - hash_high = _sz_prime_mod(hash_high); - } - break; - } - - return _sz_hash_mix(hash_low, hash_high); -} - -SZ_PUBLIC void sz_hashes_serial(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle) { - - if (length < window_length || !window_length) return; - sz_u8_t const *text = (sz_u8_t const *)start; - sz_u8_t const *text_end = text + length; - - // Prepare the `prime ^ window_length` values, that we are going to use for modulo arithmetic. - sz_u64_t prime_power_low = 1, prime_power_high = 1; - for (sz_size_t i = 0; i + 1 < window_length; ++i) - prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME, - prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME; - - // Compute the initial hash value for the first window. - sz_u64_t hash_low = 0, hash_high = 0, hash_mix; - for (sz_u8_t const *first_end = text + window_length; text < first_end; ++text) - hash_low = (hash_low * 31ull + _sz_shift_low(*text)) % SZ_U64_MAX_PRIME, - hash_high = (hash_high * 257ull + _sz_shift_high(*text)) % SZ_U64_MAX_PRIME; - - // In most cases the fingerprint length will be a power of two. - hash_mix = _sz_hash_mix(hash_low, hash_high); - callback((sz_cptr_t)text, window_length, hash_mix, callback_handle); - - // Compute the hash value for every window, exporting into the fingerprint, - // using the expensive modulo operation. - sz_size_t cycles = 1; - sz_size_t const step_mask = step - 1; - for (; text < text_end; ++text, ++cycles) { - // Discard one character: - hash_low -= _sz_shift_low(*(text - window_length)) * prime_power_low; - hash_high -= _sz_shift_high(*(text - window_length)) * prime_power_high; - // And add a new one: - hash_low = 31ull * hash_low + _sz_shift_low(*text); - hash_high = 257ull * hash_high + _sz_shift_high(*text); - // Wrap the hashes around: - hash_low = _sz_prime_mod(hash_low); - hash_high = _sz_prime_mod(hash_high); - // Mix only if we've skipped enough hashes. 
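The drop-one-add-one update in this loop is the classic polynomial rolling hash. A compact stand-alone sketch with hypothetical names, a single base, and plain 64-bit wrap-around instead of the prime modulus used above:

#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

/* Hypothetical single-base rolling hash: prints a hash per window, updating it in O(1)
   by removing the outgoing byte's contribution and appending the incoming byte. */
static void rolling_hashes(char const *text, size_t length, size_t window, uint64_t base) {
    if (!window || length < window) return;
    uint64_t power = 1, hash = 0; /* power becomes base^(window - 1) */
    for (size_t i = 0; i + 1 < window; ++i) power *= base;
    for (size_t i = 0; i != window; ++i) hash = hash * base + (uint8_t)text[i];
    printf("window 0: %016llx\n", (unsigned long long)hash);
    for (size_t i = window; i != length; ++i) {
        hash -= (uint8_t)text[i - window] * power; /* drop the outgoing byte */
        hash = hash * base + (uint8_t)text[i];     /* append the incoming byte */
        printf("window %zu: %016llx\n", i - window + 1, (unsigned long long)hash);
    }
}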
- if ((cycles & step_mask) == 0) { - hash_mix = _sz_hash_mix(hash_low, hash_high); - callback((sz_cptr_t)text, window_length, hash_mix, callback_handle); - } - } -} - -#undef _sz_shift_low -#undef _sz_shift_high -#undef _sz_hash_mix -#undef _sz_prime_mod - -/** - * @brief Uses a small lookup-table to convert a lowercase character to uppercase. - */ -SZ_INTERNAL sz_u8_t sz_u8_tolower(sz_u8_t c) { - static sz_u8_t const lowered[256] = { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, // - 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, // - 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, // - 64, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, // - 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 91, 92, 93, 94, 95, // - 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, // - 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, // - 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, // - 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, // - 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, // - 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, // - 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // - 240, 241, 242, 243, 244, 245, 246, 215, 248, 249, 250, 251, 252, 253, 254, 223, // - 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // - 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, // - }; - return lowered[c]; -} - -/** - * @brief Uses a small lookup-table to convert an uppercase character to lowercase. - */ -SZ_INTERNAL sz_u8_t sz_u8_toupper(sz_u8_t c) { - static sz_u8_t const upped[256] = { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, // - 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, // - 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, // - 64, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, // - 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 91, 92, 93, 94, 95, // - 96, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, // - 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 123, 124, 125, 126, 127, // - 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, // - 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, // - 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, // - 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, // - 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // - 240, 241, 242, 243, 244, 245, 246, 215, 248, 249, 250, 251, 252, 253, 254, 223, // - 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // - 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, // - }; - return upped[c]; -} - -/** - * @brief Uses two small lookup tables (768 bytes total) to accelerate division by a small - * unsigned integer. Performs two lookups, one multiplication, two shifts, and two accumulations. - * - * @param divisor Integral value @b larger than one. - * @param number Integral value to divide. 
- */ -SZ_INTERNAL sz_u8_t sz_u8_divide(sz_u8_t number, sz_u8_t divisor) { - sz_assert(divisor > 1); - static sz_u16_t const multipliers[256] = { - 0, 0, 0, 21846, 0, 39322, 21846, 9363, 0, 50973, 39322, 29790, 21846, 15124, 9363, 4370, - 0, 57826, 50973, 44841, 39322, 34329, 29790, 25645, 21846, 18351, 15124, 12137, 9363, 6780, 4370, 2115, - 0, 61565, 57826, 54302, 50973, 47824, 44841, 42011, 39322, 36765, 34329, 32006, 29790, 27671, 25645, 23705, - 21846, 20063, 18351, 16706, 15124, 13602, 12137, 10725, 9363, 8049, 6780, 5554, 4370, 3224, 2115, 1041, - 0, 63520, 61565, 59668, 57826, 56039, 54302, 52614, 50973, 49377, 47824, 46313, 44841, 43407, 42011, 40649, - 39322, 38028, 36765, 35532, 34329, 33154, 32006, 30885, 29790, 28719, 27671, 26647, 25645, 24665, 23705, 22766, - 21846, 20945, 20063, 19198, 18351, 17520, 16706, 15907, 15124, 14356, 13602, 12863, 12137, 11424, 10725, 10038, - 9363, 8700, 8049, 7409, 6780, 6162, 5554, 4957, 4370, 3792, 3224, 2665, 2115, 1573, 1041, 517, - 0, 64520, 63520, 62535, 61565, 60609, 59668, 58740, 57826, 56926, 56039, 55164, 54302, 53452, 52614, 51788, - 50973, 50169, 49377, 48595, 47824, 47063, 46313, 45572, 44841, 44120, 43407, 42705, 42011, 41326, 40649, 39982, - 39322, 38671, 38028, 37392, 36765, 36145, 35532, 34927, 34329, 33738, 33154, 32577, 32006, 31443, 30885, 30334, - 29790, 29251, 28719, 28192, 27671, 27156, 26647, 26143, 25645, 25152, 24665, 24182, 23705, 23233, 22766, 22303, - 21846, 21393, 20945, 20502, 20063, 19628, 19198, 18772, 18351, 17933, 17520, 17111, 16706, 16305, 15907, 15514, - 15124, 14738, 14356, 13977, 13602, 13231, 12863, 12498, 12137, 11779, 11424, 11073, 10725, 10380, 10038, 9699, - 9363, 9030, 8700, 8373, 8049, 7727, 7409, 7093, 6780, 6470, 6162, 5857, 5554, 5254, 4957, 4662, - 4370, 4080, 3792, 3507, 3224, 2943, 2665, 2388, 2115, 1843, 1573, 1306, 1041, 778, 517, 258, - }; - // This table can be avoided using a single addition and counting trailing zeros. 
- static sz_u8_t const shifts[256] = { - 0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, // - 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, // - 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, // - 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, // - 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // - 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // - 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // - 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // - }; - sz_u32_t multiplier = multipliers[divisor]; - sz_u8_t shift = shifts[divisor]; - - sz_u16_t q = (sz_u16_t)((multiplier * number) >> 16); - sz_u16_t t = ((number - q) >> 1) + q; - return (sz_u8_t)(t >> shift); -} - -SZ_PUBLIC void sz_look_up_transform_serial(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result) { - sz_u8_t const *unsigned_lut = (sz_u8_t const *)lut; - sz_u8_t const *unsigned_text = (sz_u8_t const *)text; - sz_u8_t *unsigned_result = (sz_u8_t *)result; - sz_u8_t const *end = unsigned_text + length; - for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = unsigned_lut[*unsigned_text]; -} - -SZ_PUBLIC void sz_tolower_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) { - sz_u8_t *unsigned_result = (sz_u8_t *)result; - sz_u8_t const *unsigned_text = (sz_u8_t const *)text; - sz_u8_t const *end = unsigned_text + length; - for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = sz_u8_tolower(*unsigned_text); -} - -SZ_PUBLIC void sz_toupper_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) { - sz_u8_t *unsigned_result = (sz_u8_t *)result; - sz_u8_t const *unsigned_text = (sz_u8_t const *)text; - sz_u8_t const *end = unsigned_text + length; - for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = sz_u8_toupper(*unsigned_text); -} - -SZ_PUBLIC void sz_toascii_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) { - sz_u8_t *unsigned_result = (sz_u8_t *)result; - sz_u8_t const *unsigned_text = (sz_u8_t const *)text; - sz_u8_t const *end = unsigned_text + length; - for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = *unsigned_text & 0x7F; -} - -/** - * @brief Check if there is a byte in this buffer, that exceeds 127 and can't be an ASCII character. - * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. - */ -SZ_PUBLIC sz_bool_t sz_isascii_serial(sz_cptr_t text, sz_size_t length) { - - if (!length) return sz_true_k; - sz_u8_t const *h = (sz_u8_t const *)text; - sz_u8_t const *const h_end = h + length; - -#if !SZ_USE_MISALIGNED_LOADS - // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)h & 7ull) && h < h_end; ++h) - if (*h & 0x80ull) return sz_false_k; -#endif - - // Validate eight bytes at once using SWAR. - sz_u64_vec_t text_vec; - for (; h + 8 <= h_end; h += 8) { - text_vec.u64 = *(sz_u64_t const *)h; - if (text_vec.u64 & 0x8080808080808080ull) return sz_false_k; - } - - // Handle the misaligned tail. 
- for (; h < h_end; ++h) - if (*h & 0x80ull) return sz_false_k; - return sz_true_k; -} - -SZ_PUBLIC void sz_generate_serial(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length, - sz_random_generator_t generator, void *generator_user_data) { - - sz_assert(alphabet_size > 0 && alphabet_size <= 256 && "Inadequate alphabet size"); - - if (alphabet_size == 1) sz_fill(result, result_length, *alphabet); - - else { - sz_assert(generator && "Expects a valid random generator"); - sz_u8_t divisor = (sz_u8_t)alphabet_size; - for (sz_cptr_t end = result + result_length; result != end; ++result) { - sz_u8_t random = generator(generator_user_data) & 0xFF; - sz_u8_t quotient = sz_u8_divide(random, divisor); - *result = alphabet[random - quotient * divisor]; - } - } -} - -#pragma endregion - -/* - * Serial implementation of string class operations. - */ -#pragma region Serial Implementation for the String Class - -SZ_PUBLIC sz_bool_t sz_string_is_on_stack(sz_string_t const *string) { - // It doesn't matter if it's on stack or heap, the pointer location is the same. - return (sz_bool_t)((sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0]); -} - -SZ_PUBLIC void sz_string_range(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length) { - sz_size_t is_small = (sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0]; - sz_size_t is_big_mask = is_small - 1ull; - *start = string->external.start; // It doesn't matter if it's on stack or heap, the pointer location is the same. - // If the string is small, use branch-less approach to mask-out the top 7 bytes of the length. - *length = string->external.length & (0x00000000000000FFull | is_big_mask); -} - -SZ_PUBLIC void sz_string_unpack(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length, sz_size_t *space, - sz_bool_t *is_external) { - sz_size_t is_small = (sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0]; - sz_size_t is_big_mask = is_small - 1ull; - *start = string->external.start; // It doesn't matter if it's on stack or heap, the pointer location is the same. - // If the string is small, use branch-less approach to mask-out the top 7 bytes of the length. - *length = string->external.length & (0x00000000000000FFull | is_big_mask); - // In case the string is small, the `is_small - 1ull` will become 0xFFFFFFFFFFFFFFFFull. - *space = sz_u64_blend(SZ_STRING_INTERNAL_SPACE, string->external.space, is_big_mask); - *is_external = (sz_bool_t)!is_small; -} - -SZ_PUBLIC sz_bool_t sz_string_equal(sz_string_t const *a, sz_string_t const *b) { - // Tempting to say that the external.length is bitwise the same even if it includes - // some bytes of the on-stack payload, but we don't at this writing maintain that invariant. - // (An on-stack string includes noise bytes in the high-order bits of external.length. So do this - // the hard/correct way. - -#if SZ_USE_MISALIGNED_LOADS - // Dealing with StringZilla strings, we know that the `start` pointer always points - // to a word at least 8 bytes long. Therefore, we can compare the first 8 bytes at once. - -#endif - // Alternatively, fall back to byte-by-byte comparison. 
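The unpacking helpers above turn the "is small" predicate into an all-ones or all-zeros 64-bit mask and use it to select between the on-stack and heap interpretations of the same bytes without branching. A tiny stand-alone sketch of that mask-and-blend idiom, with made-up values:

#include <stdint.h>
#include <stdio.h>

int main(void) {
    uint64_t is_small = 1;                          /* pretend the string is in the on-stack form */
    uint64_t is_big_mask = is_small - 1ull;         /* 1 -> 0x0000...0000, 0 -> 0xFFFF...FFFF */
    uint64_t packed_length = 0xAABBCCDD00000007ull; /* low byte holds the small-string length, the rest is payload noise */
    uint64_t internal_space = 23, external_space = 64; /* illustrative values, not the library's constants */
    uint64_t length = packed_length & (0xFFull | is_big_mask);                         /* keep only the low byte when small */
    uint64_t space = (internal_space & ~is_big_mask) | (external_space & is_big_mask); /* branchless blend */
    printf("length = %llu, space = %llu\n", (unsigned long long)length, (unsigned long long)space); /* 7, 23 */
    return 0;
}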
- sz_ptr_t a_start, b_start; - sz_size_t a_length, b_length; - sz_string_range(a, &a_start, &a_length); - sz_string_range(b, &b_start, &b_length); - return (sz_bool_t)(a_length == b_length && sz_equal(a_start, b_start, b_length)); -} - -SZ_PUBLIC sz_ordering_t sz_string_order(sz_string_t const *a, sz_string_t const *b) { -#if SZ_USE_MISALIGNED_LOADS - // Dealing with StringZilla strings, we know that the `start` pointer always points - // to a word at least 8 bytes long. Therefore, we can compare the first 8 bytes at once. - -#endif - // Alternatively, fall back to byte-by-byte comparison. - sz_ptr_t a_start, b_start; - sz_size_t a_length, b_length; - sz_string_range(a, &a_start, &a_length); - sz_string_range(b, &b_start, &b_length); - return sz_order(a_start, a_length, b_start, b_length); -} - -SZ_PUBLIC void sz_string_init(sz_string_t *string) { - sz_assert(string && "String can't be SZ_NULL."); - - // Only 8 + 1 + 1 need to be initialized. - string->internal.start = &string->internal.chars[0]; - // But for safety let's initialize the entire structure to zeros. - // string->internal.chars[0] = 0; - // string->internal.length = 0; - string->words[1] = 0; - string->words[2] = 0; - string->words[3] = 0; -} - -SZ_PUBLIC sz_ptr_t sz_string_init_length(sz_string_t *string, sz_size_t length, sz_memory_allocator_t *allocator) { - sz_size_t space_needed = length + 1; // space for trailing \0 - sz_assert(string && allocator && "String and allocator can't be SZ_NULL."); - // Initialize the string to zeros for safety. - string->words[1] = 0; - string->words[2] = 0; - string->words[3] = 0; - // If we are lucky, no memory allocations will be needed. - if (space_needed <= SZ_STRING_INTERNAL_SPACE) { - string->internal.start = &string->internal.chars[0]; - string->internal.length = (sz_u8_t)length; - } - else { - // If we are not lucky, we need to allocate memory. - string->external.start = (sz_ptr_t)allocator->allocate(space_needed, allocator->handle); - if (!string->external.start) return SZ_NULL_CHAR; - string->external.length = length; - string->external.space = space_needed; - } - sz_assert(&string->internal.start == &string->external.start && "Alignment confusion"); - string->external.start[length] = 0; - return string->external.start; -} - -SZ_PUBLIC sz_ptr_t sz_string_reserve(sz_string_t *string, sz_size_t new_capacity, sz_memory_allocator_t *allocator) { - - sz_assert(string && allocator && "Strings and allocators can't be SZ_NULL."); - - sz_size_t new_space = new_capacity + 1; - if (new_space <= SZ_STRING_INTERNAL_SPACE) return string->external.start; - - sz_ptr_t string_start; - sz_size_t string_length; - sz_size_t string_space; - sz_bool_t string_is_external; - sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - sz_assert(new_space > string_space && "New space must be larger than current."); - - sz_ptr_t new_start = (sz_ptr_t)allocator->allocate(new_space, allocator->handle); - if (!new_start) return SZ_NULL_CHAR; - - sz_copy(new_start, string_start, string_length); - string->external.start = new_start; - string->external.space = new_space; - string->external.padding = 0; - string->external.length = string_length; - - // Deallocate the old string. 
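When the requested capacity does not fit, the expansion path below grows geometrically: at least one cache line, at least double the current space, and rounded up to a power of two. A small sketch of that policy in isolation, with hypothetical names and 64 bytes assumed as the cache-line width:

#include <stddef.h>

/* Hypothetical sketch of the capacity-growth policy used when expanding a string. */
static size_t bit_ceil_size(size_t x) {
    size_t v = 1;
    while (v < x) v <<= 1; /* smallest power of two >= x; good enough for a sketch */
    return v;
}

static size_t next_capacity(size_t current_space, size_t needed_space) {
    size_t const cache_line = 64; /* assumed width, for illustration only */
    size_t planned = current_space * 2 > cache_line ? current_space * 2 : cache_line;
    size_t minimal = bit_ceil_size(needed_space);
    return planned > minimal ? planned : minimal;
}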
- if (string_is_external) allocator->free(string_start, string_space, allocator->handle); - return string->external.start; -} - -SZ_PUBLIC sz_ptr_t sz_string_shrink_to_fit(sz_string_t *string, sz_memory_allocator_t *allocator) { - - sz_assert(string && allocator && "Strings and allocators can't be SZ_NULL."); - - sz_ptr_t string_start; - sz_size_t string_length; - sz_size_t string_space; - sz_bool_t string_is_external; - sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - - // We may already be space-optimal, and in that case we don't need to do anything. - sz_size_t new_space = string_length + 1; - if (string_space == new_space || !string_is_external) return string->external.start; - - sz_ptr_t new_start = (sz_ptr_t)allocator->allocate(new_space, allocator->handle); - if (!new_start) return SZ_NULL_CHAR; - - sz_copy(new_start, string_start, string_length); - string->external.start = new_start; - string->external.space = new_space; - string->external.padding = 0; - string->external.length = string_length; - - // Deallocate the old string. - if (string_is_external) allocator->free(string_start, string_space, allocator->handle); - return string->external.start; -} - -SZ_PUBLIC sz_ptr_t sz_string_expand(sz_string_t *string, sz_size_t offset, sz_size_t added_length, - sz_memory_allocator_t *allocator) { - - sz_assert(string && allocator && "String and allocator can't be SZ_NULL."); - - sz_ptr_t string_start; - sz_size_t string_length; - sz_size_t string_space; - sz_bool_t string_is_external; - sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - - // The user intended to extend the string. - offset = sz_min_of_two(offset, string_length); - - // If we are lucky, no memory allocations will be needed. - if (string_length + added_length < string_space) { - sz_move(string_start + offset + added_length, string_start + offset, string_length - offset); - string_start[string_length + added_length] = 0; - // Even if the string is on the stack, the `+=` won't affect the tail of the string. - string->external.length += added_length; - } - // If we are not lucky, we need to allocate more memory. - else { - sz_size_t next_planned_size = sz_max_of_two(SZ_CACHE_LINE_WIDTH, string_space * 2ull); - sz_size_t min_needed_space = sz_size_bit_ceil(offset + string_length + added_length + 1); - sz_size_t new_space = sz_max_of_two(min_needed_space, next_planned_size); - string_start = sz_string_reserve(string, new_space - 1, allocator); - if (!string_start) return SZ_NULL_CHAR; - - // Copy into the new buffer. - sz_move(string_start + offset + added_length, string_start + offset, string_length - offset); - string_start[string_length + added_length] = 0; - string->external.length = string_length + added_length; - } - - return string_start; -} - -SZ_PUBLIC sz_size_t sz_string_erase(sz_string_t *string, sz_size_t offset, sz_size_t length) { - - sz_assert(string && "String can't be SZ_NULL."); - - sz_ptr_t string_start; - sz_size_t string_length; - sz_size_t string_space; - sz_bool_t string_is_external; - sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - - // Normalize the offset, it can't be larger than the length. - offset = sz_min_of_two(offset, string_length); - - // We shouldn't normalize the length, to avoid overflowing on `offset + length >= string_length`, - // if receiving `length == SZ_SIZE_MAX`. 
After following expression the `length` will contain - // exactly the delta between original and final length of this `string`. - length = sz_min_of_two(length, string_length - offset); - - // There are 2 common cases, that wouldn't even require a `memmove`: - // 1. Erasing the entire contents of the string. - // In that case `length` argument will be equal or greater than `length` member. - // 2. Removing the tail of the string with something like `string.pop_back()` in C++. - // - // In both of those, regardless of the location of the string - stack or heap, - // the erasing is as easy as setting the length to the offset. - // In every other case, we must `memmove` the tail of the string to the left. - if (offset + length < string_length) - sz_move(string_start + offset, string_start + offset + length, string_length - offset - length); - - // The `string->external.length = offset` assignment would discard last characters - // of the on-the-stack string, but inplace subtraction would work. - string->external.length -= length; - string_start[string_length - length] = 0; - return length; -} - -SZ_PUBLIC void sz_string_free(sz_string_t *string, sz_memory_allocator_t *allocator) { - if (!sz_string_is_on_stack(string)) - allocator->free(string->external.start, string->external.space, allocator->handle); - sz_string_init(string); -} - -// When overriding libc, disable optimisations for this function beacuse MSVC will optimize the loops into a memset. -// Which then causes a stack overflow due to infinite recursion (memset -> sz_fill_serial -> memset). -#if defined(_MSC_VER) && defined(SZ_OVERRIDE_LIBC) && SZ_OVERRIDE_LIBC -#pragma optimize("", off) -#endif -SZ_PUBLIC void sz_fill_serial(sz_ptr_t target, sz_size_t length, sz_u8_t value) { - sz_ptr_t end = target + length; - // Dealing with short strings, a single sequential pass would be faster. - // If the size is larger than 2 words, then at least 1 of them will be aligned. - // But just one aligned word may not be worth SWAR. - if (length < SZ_SWAR_THRESHOLD) - while (target != end) *(target++) = value; - - // In case of long strings, skip unaligned bytes, and then fill the rest in 64-bit chunks. - else { - sz_u64_t value64 = (sz_u64_t)value * 0x0101010101010101ull; - while ((sz_size_t)target & 7ull) *(target++) = value; - while (target + 8 <= end) *(sz_u64_t *)target = value64, target += 8; - while (target != end) *(target++) = value; - } -} -#if defined(_MSC_VER) && defined(SZ_OVERRIDE_LIBC) && SZ_OVERRIDE_LIBC -#pragma optimize("", on) -#endif - -SZ_PUBLIC void sz_copy_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { -#if SZ_USE_MISALIGNED_LOADS - while (length >= 8) *(sz_u64_t *)target = *(sz_u64_t const *)source, target += 8, source += 8, length -= 8; -#endif - while (length--) *(target++) = *(source++); -} - -SZ_PUBLIC void sz_move_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - // Implementing `memmove` is trickier, than `memcpy`, as the ranges may overlap. - // Existing implementations often have two passes, in normal and reversed order, - // depending on the relation of `target` and `source` addresses. - // https://student.cs.uwaterloo.ca/~cs350/common/os161-src-html/doxygen/html/memmove_8c_source.html - // https://marmota.medium.com/c-language-making-memmove-def8792bb8d5 - // - // We can use the `memcpy` like left-to-right pass if we know that the `target` is before `source`. - // Or if we know that they don't intersect! 
In that case, the traversal order is irrelevant,
- // but older CPUs may predict and fetch forward-passes better.
- if (target < source || target >= source + length) {
-#if SZ_USE_MISALIGNED_LOADS
- while (length >= 8) *(sz_u64_t *)target = *(sz_u64_t const *)(source), target += 8, source += 8, length -= 8;
-#endif
- while (length--) *(target++) = *(source++);
- }
- else {
- // Jump to the end and walk backwards.
- target += length, source += length;
-#if SZ_USE_MISALIGNED_LOADS
- while (length >= 8) *(sz_u64_t *)(target -= 8) = *(sz_u64_t const *)(source -= 8), length -= 8;
-#endif
- while (length--) *(--target) = *(--source);
- }
-}
-
-#pragma endregion
-
-/*
- * @brief Serial implementation for string sequence processing.
- */
-#pragma region Serial Implementation for Sequences
-
-SZ_PUBLIC sz_size_t sz_partition(sz_sequence_t *sequence, sz_sequence_predicate_t predicate) {
-
- sz_size_t matches = 0;
- while (matches != sequence->count && predicate(sequence, sequence->order[matches])) ++matches;
-
- for (sz_size_t i = matches + 1; i < sequence->count; ++i)
- if (predicate(sequence, sequence->order[i]))
- sz_u64_swap(sequence->order + i, sequence->order + matches), ++matches;
-
- return matches;
-}
-
-SZ_PUBLIC void sz_merge(sz_sequence_t *sequence, sz_size_t partition, sz_sequence_comparator_t less) {
-
- sz_size_t start_b = partition + 1;
-
- // If the direct merge is already sorted
- if (!less(sequence, sequence->order[start_b], sequence->order[partition])) return;
-
- sz_size_t start_a = 0;
- while (start_a <= partition && start_b <= sequence->count) {
-
- // If element 1 is in the right place
- if (!less(sequence, sequence->order[start_b], sequence->order[start_a])) { start_a++; }
- else {
- sz_size_t value = sequence->order[start_b];
- sz_size_t index = start_b;
-
- // Shift all the elements between element 1
- // and element 2, right by 1.
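- // In other words, rotate that subrange one position to the right, opening a slot at `start_a`
- // for the element taken from the right half.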
- while (index != start_a) { sequence->order[index] = sequence->order[index - 1], index--; } - sequence->order[start_a] = value; - - // Update all the pointers - start_a++; - partition++; - start_b++; - } - } -} - -SZ_PUBLIC void sz_sort_insertion(sz_sequence_t *sequence, sz_sequence_comparator_t less) { - sz_u64_t *keys = sequence->order; - sz_size_t keys_count = sequence->count; - for (sz_size_t i = 1; i < keys_count; i++) { - sz_u64_t i_key = keys[i]; - sz_size_t j = i; - for (; j > 0 && less(sequence, i_key, keys[j - 1]); --j) keys[j] = keys[j - 1]; - keys[j] = i_key; - } -} - -SZ_INTERNAL void _sz_sift_down(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_u64_t *order, sz_size_t start, - sz_size_t end) { - sz_size_t root = start; - while (2 * root + 1 <= end) { - sz_size_t child = 2 * root + 1; - if (child + 1 <= end && less(sequence, order[child], order[child + 1])) { child++; } - if (!less(sequence, order[root], order[child])) { return; } - sz_u64_swap(order + root, order + child); - root = child; - } -} - -SZ_INTERNAL void _sz_heapify(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_u64_t *order, sz_size_t count) { - sz_size_t start = (count - 2) / 2; - while (1) { - _sz_sift_down(sequence, less, order, start, count - 1); - if (start == 0) return; - start--; - } -} - -SZ_INTERNAL void _sz_heapsort(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_size_t first, sz_size_t last) { - sz_u64_t *order = sequence->order; - sz_size_t count = last - first; - _sz_heapify(sequence, less, order + first, count); - sz_size_t end = count - 1; - while (end > 0) { - sz_u64_swap(order + first, order + first + end); - end--; - _sz_sift_down(sequence, less, order + first, 0, end); - } -} - -SZ_PUBLIC void sz_sort_introsort_recursion(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_size_t first, - sz_size_t last, sz_size_t depth) { - - sz_size_t length = last - first; - switch (length) { - case 0: - case 1: return; - case 2: - if (less(sequence, sequence->order[first + 1], sequence->order[first])) - sz_u64_swap(&sequence->order[first], &sequence->order[first + 1]); - return; - case 3: { - sz_u64_t a = sequence->order[first]; - sz_u64_t b = sequence->order[first + 1]; - sz_u64_t c = sequence->order[first + 2]; - if (less(sequence, b, a)) sz_u64_swap(&a, &b); - if (less(sequence, c, b)) sz_u64_swap(&c, &b); - if (less(sequence, b, a)) sz_u64_swap(&a, &b); - sequence->order[first] = a; - sequence->order[first + 1] = b; - sequence->order[first + 2] = c; - return; - } - } - // Until a certain length, the quadratic-complexity insertion-sort is fine - if (length <= 16) { - sz_sequence_t sub_seq = *sequence; - sub_seq.order += first; - sub_seq.count = length; - sz_sort_insertion(&sub_seq, less); - return; - } - - // Fallback to N-logN-complexity heap-sort - if (depth == 0) { - _sz_heapsort(sequence, less, first, last); - return; - } - - --depth; - - // Median-of-three logic to choose pivot - sz_size_t median = first + length / 2; - if (less(sequence, sequence->order[median], sequence->order[first])) - sz_u64_swap(&sequence->order[first], &sequence->order[median]); - if (less(sequence, sequence->order[last - 1], sequence->order[first])) - sz_u64_swap(&sequence->order[first], &sequence->order[last - 1]); - if (less(sequence, sequence->order[median], sequence->order[last - 1])) - sz_u64_swap(&sequence->order[median], &sequence->order[last - 1]); - - // Partition using the median-of-three as the pivot - sz_u64_t pivot = sequence->order[median]; - sz_size_t left = first; - 
sz_size_t right = last - 1;
- while (1) {
- while (less(sequence, sequence->order[left], pivot)) left++;
- while (less(sequence, pivot, sequence->order[right])) right--;
- if (left >= right) break;
- sz_u64_swap(&sequence->order[left], &sequence->order[right]);
- left++;
- right--;
- }
-
- // Recursively sort the partitions
- sz_sort_introsort_recursion(sequence, less, first, left, depth);
- sz_sort_introsort_recursion(sequence, less, right + 1, last, depth);
-}
-
-SZ_PUBLIC void sz_sort_introsort(sz_sequence_t *sequence, sz_sequence_comparator_t less) {
- if (sequence->count == 0) return;
- sz_size_t size_is_not_power_of_two = (sequence->count & (sequence->count - 1)) != 0;
- sz_size_t depth_limit = sz_size_log2i_nonzero(sequence->count) + size_is_not_power_of_two;
- sz_sort_introsort_recursion(sequence, less, 0, sequence->count, depth_limit);
-}
-
-SZ_PUBLIC void sz_sort_recursion( //
- sz_sequence_t *sequence, sz_size_t bit_idx, sz_size_t bit_max, sz_sequence_comparator_t comparator,
- sz_size_t partial_order_length) {
-
- if (!sequence->count) return;
-
- // Array of size one doesn't need sorting - only needs the prefix to be discarded.
- if (sequence->count == 1) {
- sz_u32_t *order_half_words = (sz_u32_t *)sequence->order;
- order_half_words[1] = 0;
- return;
- }
-
- // Partition a range of integers according to a specific bit value
- sz_size_t split = 0;
- sz_u64_t mask = (1ull << 63) >> bit_idx;
-
- // The clean approach would be to perform a single pass over the sequence.
- //
- // while (split != sequence->count && !(sequence->order[split] & mask)) ++split;
- // for (sz_size_t i = split + 1; i < sequence->count; ++i)
- // if (!(sequence->order[i] & mask)) sz_u64_swap(sequence->order + i, sequence->order + split), ++split;
- //
- // This, however, doesn't take into account the high relative cost of writes and swaps.
- // To circumvent that, we can first count the total number of entries to be mapped into either part.
- // And then walk through both parts, swapping the entries that are in the wrong part.
- // This would often lead to ~15% performance gain.
- sz_size_t count_with_bit_set = 0;
- for (sz_size_t i = 0; i != sequence->count; ++i) count_with_bit_set += (sequence->order[i] & mask) != 0;
- split = sequence->count - count_with_bit_set;
-
- // It's possible that the sequence is already partitioned.
- if (split != 0 && split != sequence->count) {
- // Use two pointers to efficiently reposition elements.
- // One pointer walks left-to-right from the start, and the other walks right-to-left from the end.
- sz_size_t left = 0;
- sz_size_t right = sequence->count - 1;
- while (1) {
- // Find the next element with the bit set on the left side.
- while (left < split && !(sequence->order[left] & mask)) ++left;
- // Find the next element without the bit set on the right side.
- while (right >= split && (sequence->order[right] & mask)) --right;
- // Swap the mispositioned elements.
- if (left < split && right >= split) {
- sz_u64_swap(sequence->order + left, sequence->order + right);
- ++left;
- --right;
- }
- else { break; }
- }
- }
-
- // Go down recursively.
- if (bit_idx < bit_max) {
- sz_sequence_t a = *sequence;
- a.count = split;
- sz_sort_recursion(&a, bit_idx + 1, bit_max, comparator, partial_order_length);
-
- sz_sequence_t b = *sequence;
- b.order += split;
- b.count -= split;
- sz_sort_recursion(&b, bit_idx + 1, bit_max, comparator, partial_order_length);
- }
- // Reached the end of recursion.
- else {
- // Discard the prefixes.
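- // The high 32 bits of every entry held the exported string prefix used for radix partitioning;
- // zeroing them leaves only the original indices for the comparator-based sort below.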
- sz_u32_t *order_half_words = (sz_u32_t *)sequence->order;
- for (sz_size_t i = 0; i != sequence->count; ++i) { order_half_words[i * 2 + 1] = 0; }
-
- sz_sequence_t a = *sequence;
- a.count = split;
- sz_sort_introsort(&a, comparator);
-
- sz_sequence_t b = *sequence;
- b.order += split;
- b.count -= split;
- sz_sort_introsort(&b, comparator);
- }
-}
-
-SZ_INTERNAL sz_bool_t _sz_sort_is_less(sz_sequence_t *sequence, sz_size_t i_key, sz_size_t j_key) {
- sz_cptr_t i_str = sequence->get_start(sequence, i_key);
- sz_cptr_t j_str = sequence->get_start(sequence, j_key);
- sz_size_t i_len = sequence->get_length(sequence, i_key);
- sz_size_t j_len = sequence->get_length(sequence, j_key);
- return (sz_bool_t)(sz_order_serial(i_str, i_len, j_str, j_len) == sz_less_k);
-}
-
-SZ_PUBLIC void sz_sort_partial(sz_sequence_t *sequence, sz_size_t partial_order_length) {
-
-#if SZ_DETECT_BIG_ENDIAN
- // TODO: Implement partial sort for big-endian systems. For now this sorts the whole thing.
- sz_unused(partial_order_length);
- sz_sort_introsort(sequence, (sz_sequence_comparator_t)_sz_sort_is_less);
-#else
-
- // Export up to 4 bytes into the `sequence` bits themselves
- for (sz_size_t i = 0; i != sequence->count; ++i) {
- sz_cptr_t begin = sequence->get_start(sequence, sequence->order[i]);
- sz_size_t length = sequence->get_length(sequence, sequence->order[i]);
- length = length > 4u ? 4u : length;
- sz_ptr_t prefix = (sz_ptr_t)&sequence->order[i];
- for (sz_size_t j = 0; j != length; ++j) prefix[7 - j] = begin[j];
- }
-
- // Perform optionally-parallel radix sort on them
- sz_sort_recursion(sequence, 0, 32, (sz_sequence_comparator_t)_sz_sort_is_less, partial_order_length);
-#endif
-}
-
-SZ_PUBLIC void sz_sort(sz_sequence_t *sequence) {
-#if SZ_DETECT_BIG_ENDIAN
- sz_sort_introsort(sequence, (sz_sequence_comparator_t)_sz_sort_is_less);
-#else
- sz_sort_partial(sequence, sequence->count);
-#endif
-}
-
-#pragma endregion
-
-/*
- * @brief AVX2 implementation of the string search algorithms.
- * Very minimalistic, but still faster than the serial implementation.
- */
-#pragma region AVX2 Implementation
-
-#if SZ_USE_X86_AVX2
-#pragma GCC push_options
-#pragma GCC target("avx2")
-#pragma clang attribute push(__attribute__((target("avx2"))), apply_to = function)
-#include <immintrin.h>
-
-/**
- * @brief Helper structure to simplify work with 256-bit registers.
- */
-typedef union sz_u256_vec_t {
- __m256i ymm;
- __m128i xmms[2];
- sz_u64_t u64s[4];
- sz_u32_t u32s[8];
- sz_u16_t u16s[16];
- sz_u8_t u8s[32];
-} sz_u256_vec_t;
-
-SZ_PUBLIC sz_ordering_t sz_order_avx2(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) {
- //! Before optimizing this, read the "Operations Not Worth Optimizing" in the Contributions Guide:
- //! https://github.com/ashvardanian/StringZilla/blob/main/CONTRIBUTING.md#general-performance-observations
- return sz_order_serial(a, a_length, b, b_length);
-}
-
-SZ_PUBLIC sz_bool_t sz_equal_avx2(sz_cptr_t a, sz_cptr_t b, sz_size_t length) {
- sz_u256_vec_t a_vec, b_vec;
-
- while (length >= 32) {
- a_vec.ymm = _mm256_lddqu_si256((__m256i const *)a);
- b_vec.ymm = _mm256_lddqu_si256((__m256i const *)b);
- // One approach can be to use "movemasks", but we could also use bitwise matching like `_mm256_testnzc_si256`.
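- // Here the movemask route is used: after inversion, a zero mask means all 32 bytes matched.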
- int difference_mask = ~_mm256_movemask_epi8(_mm256_cmpeq_epi8(a_vec.ymm, b_vec.ymm)); - if (difference_mask == 0) { a += 32, b += 32, length -= 32; } - else { return sz_false_k; } - } - - if (length) return sz_equal_serial(a, b, length); - return sz_true_k; -} - -SZ_PUBLIC void sz_fill_avx2(sz_ptr_t target, sz_size_t length, sz_u8_t value) { - char value_char = *(char *)&value; - __m256i value_vec = _mm256_set1_epi8(value_char); - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "stores". - // - // for (; length >= 32; target += 32, length -= 32) _mm256_storeu_si256(target, value_vec); - // sz_fill_serial(target, length, value); - // - // When the buffer is small, there isn't much to innovate. - if (length <= 32) sz_fill_serial(target, length, value); - // When the buffer is aligned, we can avoid any split-stores. - else { - sz_size_t head_length = (32 - ((sz_size_t)target % 32)) % 32; // 31 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 32; // 31 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 32. - sz_u16_t value16 = (sz_u16_t)value * 0x0101u; - sz_u32_t value32 = (sz_u32_t)value16 * 0x00010001u; - sz_u64_t value64 = (sz_u64_t)value32 * 0x0000000100000001ull; - - // Fill the head of the buffer. This part is much cleaner with AVX-512. - if (head_length & 1) *(sz_u8_t *)target = value, target++, head_length--; - if (head_length & 2) *(sz_u16_t *)target = value16, target += 2, head_length -= 2; - if (head_length & 4) *(sz_u32_t *)target = value32, target += 4, head_length -= 4; - if (head_length & 8) *(sz_u64_t *)target = value64, target += 8, head_length -= 8; - if (head_length & 16) - _mm_store_si128((__m128i *)target, _mm_set1_epi8(value_char)), target += 16, head_length -= 16; - sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); - - // Fill the aligned body of the buffer. - for (; body_length >= 32; target += 32, body_length -= 32) _mm256_store_si256((__m256i *)target, value_vec); - - // Fill the tail of the buffer. This part is much cleaner with AVX-512. - sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); - if (tail_length & 16) - _mm_store_si128((__m128i *)target, _mm_set1_epi8(value_char)), target += 16, tail_length -= 16; - if (tail_length & 8) *(sz_u64_t *)target = value64, target += 8, tail_length -= 8; - if (tail_length & 4) *(sz_u32_t *)target = value32, target += 4, tail_length -= 4; - if (tail_length & 2) *(sz_u16_t *)target = value16, target += 2, tail_length -= 2; - if (tail_length & 1) *(sz_u8_t *)target = value, target++, tail_length--; - } -} - -SZ_PUBLIC void sz_copy_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "stores" and "loads". - // - // for (; length >= 32; target += 32, source += 32, length -= 32) - // _mm256_storeu_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source)); - // sz_copy_serial(target, source, length); - // - // A typical AWS Skylake instance can have 32 KB x 2 blocks of L1 data cache per core, - // 1 MB x 2 blocks of L2 cache per core, and one shared L3 cache buffer. - // For now, let's avoid the cases beyond the L2 size. 
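- // Anything longer than 1 MB is treated as "huge" and is routed to the bidirectional body loop below.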
- int is_huge = length > 1ull * 1024ull * 1024ull;
- if (length <= 32) { sz_copy_serial(target, source, length); }
- // When dealing with larger arrays, the optimization is not as simple as with the `sz_fill_avx2` function,
- // as both buffers may be unaligned. If we are lucky and the requested operation is some huge page transfer,
- // we can use aligned loads and stores, and the performance will be great.
- else if ((sz_size_t)target % 32 == 0 && (sz_size_t)source % 32 == 0 && !is_huge) {
- for (; length >= 32; target += 32, source += 32, length -= 32)
- _mm256_store_si256((__m256i *)target, _mm256_load_si256((__m256i const *)source));
- if (length) sz_copy_serial(target, source, length);
- }
- // The trickiest case is when both `source` and `target` are not aligned.
- // In this and simpler cases we can copy enough bytes into `target` to reach its cacheline boundary,
- // and then combine unaligned loads with aligned stores.
- else {
- sz_size_t head_length = (32 - ((sz_size_t)target % 32)) % 32; // 31 or less.
- sz_size_t tail_length = (sz_size_t)(target + length) % 32; // 31 or less.
- sz_size_t body_length = length - head_length - tail_length; // Multiple of 32.
-
- // Fill the head of the buffer. This part is much cleaner with AVX-512.
- if (head_length & 1) *(sz_u8_t *)target = *(sz_u8_t *)source, target++, source++, head_length--;
- if (head_length & 2) *(sz_u16_t *)target = *(sz_u16_t *)source, target += 2, source += 2, head_length -= 2;
- if (head_length & 4) *(sz_u32_t *)target = *(sz_u32_t *)source, target += 4, source += 4, head_length -= 4;
- if (head_length & 8) *(sz_u64_t *)target = *(sz_u64_t *)source, target += 8, source += 8, head_length -= 8;
- if (head_length & 16)
- _mm_store_si128((__m128i *)target, _mm_lddqu_si128((__m128i const *)source)), target += 16, source += 16,
- head_length -= 16;
- sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size.");
-
- // Fill the aligned body of the buffer.
- if (!is_huge) {
- for (; body_length >= 32; target += 32, source += 32, body_length -= 32)
- _mm256_store_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source));
- }
- // When the buffer is huge, we can traverse it in 2 directions.
- else {
- for (; body_length >= 64; target += 32, source += 32, body_length -= 64) {
- _mm256_store_si256((__m256i *)(target), _mm256_lddqu_si256((__m256i const *)(source)));
- _mm256_store_si256((__m256i *)(target + body_length - 32),
- _mm256_lddqu_si256((__m256i const *)(source + body_length - 32)));
- }
- if (body_length) _mm256_store_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source));
- }
-
- // Fill the tail of the buffer. This part is much cleaner with AVX-512.
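- // The tail mirrors the head: the same power-of-two-sized stores, issued from largest to smallest.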
- sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size.");
- if (tail_length & 16)
- _mm_store_si128((__m128i *)target, _mm_lddqu_si128((__m128i const *)source)), target += 16, source += 16,
- tail_length -= 16;
- if (tail_length & 8) *(sz_u64_t *)target = *(sz_u64_t *)source, target += 8, source += 8, tail_length -= 8;
- if (tail_length & 4) *(sz_u32_t *)target = *(sz_u32_t *)source, target += 4, source += 4, tail_length -= 4;
- if (tail_length & 2) *(sz_u16_t *)target = *(sz_u16_t *)source, target += 2, source += 2, tail_length -= 2;
- if (tail_length & 1) *(sz_u8_t *)target = *(sz_u8_t *)source, target++, source++, tail_length--;
- }
-}
-
-SZ_PUBLIC void sz_move_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length) {
- if (target < source || target >= source + length) {
- for (; length >= 32; target += 32, source += 32, length -= 32)
- _mm256_storeu_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source));
- while (length--) *(target++) = *(source++);
- }
- else {
- // Jump to the end and walk backwards.
- for (target += length, source += length; length >= 32; length -= 32)
- _mm256_storeu_si256((__m256i *)(target -= 32), _mm256_lddqu_si256((__m256i const *)(source -= 32)));
- while (length--) *(--target) = *(--source);
- }
-}
-
-SZ_PUBLIC sz_u64_t sz_checksum_avx2(sz_cptr_t text, sz_size_t length) {
- // The naive implementation of this function is very simple.
- // It assumes the CPU is great at handling unaligned "loads".
- //
- // A typical AWS Skylake instance can have 32 KB x 2 blocks of L1 data cache per core,
- // 1 MB x 2 blocks of L2 cache per core, and one shared L3 cache buffer.
- // For now, let's avoid the cases beyond the L2 size.
- int is_huge = length > 1ull * 1024ull * 1024ull;
-
- // When the buffer is small, there isn't much to innovate.
- if (length <= 32) { return sz_checksum_serial(text, length); }
- else if (!is_huge) {
- sz_u256_vec_t text_vec, sums_vec;
- sums_vec.ymm = _mm256_setzero_si256();
- for (; length >= 32; text += 32, length -= 32) {
- text_vec.ymm = _mm256_lddqu_si256((__m256i const *)text);
- sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256()));
- }
- // Accumulating 256 bits is harder, as we need to extract the 128-bit sums first.
- __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymm);
- __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymm, 1);
- __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm);
- sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm);
- sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1);
- sz_u64_t result = low + high;
- if (length) result += sz_checksum_serial(text, length);
- return result;
- }
- // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use.
- // Most notably, we can avoid populating the cache with the entire buffer, and instead traverse it in 2 directions.
- else {
- sz_size_t head_length = (32 - ((sz_size_t)text % 32)) % 32; // 31 or less.
- sz_size_t tail_length = (sz_size_t)(text + length) % 32; // 31 or less.
- sz_size_t body_length = length - head_length - tail_length; // Multiple of 32.
- sz_u64_t result = 0;
-
- // Handle the head
- while (head_length--) result += *text++;
-
- sz_u256_vec_t text_vec, sums_vec;
- sums_vec.ymm = _mm256_setzero_si256();
- // Fill the aligned body of the buffer.
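- // Both branches below rely on `_mm256_stream_load_si256` over aligned 32-byte chunks.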
- if (!is_huge) {
- for (; body_length >= 32; text += 32, body_length -= 32) {
- text_vec.ymm = _mm256_stream_load_si256((__m256i const *)text);
- sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256()));
- }
- }
- // When the buffer is huge, we can traverse it in 2 directions.
- else {
- sz_u256_vec_t text_reversed_vec, sums_reversed_vec;
- sums_reversed_vec.ymm = _mm256_setzero_si256();
- for (; body_length >= 64; text += 64, body_length -= 64) {
- text_vec.ymm = _mm256_stream_load_si256((__m256i *)(text));
- sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256()));
- text_reversed_vec.ymm = _mm256_stream_load_si256((__m256i *)(text + body_length - 64));
- sums_reversed_vec.ymm = _mm256_add_epi64(
- sums_reversed_vec.ymm, _mm256_sad_epu8(text_reversed_vec.ymm, _mm256_setzero_si256()));
- }
- if (body_length >= 32) {
- text_vec.ymm = _mm256_stream_load_si256((__m256i *)(text));
- sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256()));
- }
- sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, sums_reversed_vec.ymm);
- }
-
- // Handle the tail
- while (tail_length--) result += *text++;
-
- // Accumulating 256 bits is harder, as we need to extract the 128-bit sums first.
- __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymm);
- __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymm, 1);
- __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm);
- sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm);
- sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1);
- result += low + high;
- return result;
- }
-}
-
-SZ_PUBLIC void sz_look_up_transform_avx2(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) {
-
- // If the input is tiny (especially smaller than the look-up table itself), we may end up paying
- // more for organizing the SIMD registers and changing the CPU state than for the actual computation.
- // But if at least 3 cache lines are touched, the AVX-2 implementation should be faster.
- if (length <= 128) {
- sz_look_up_transform_serial(source, length, lut, target);
- return;
- }
-
- // We need to pull the lookup table into 8x YMM registers.
- // The biggest issue is reorganizing the data in the lookup table, as AVX2 doesn't have 256-bit shuffle,
- // it only has 128-bit "within-lane" shuffle. Still, it's wiser to use full YMM registers, instead of XMM,
- // so that we can at least compensate high latency with twice larger window and one more level of lookup.
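- // As a scalar reference, the whole transform is just:
- //
- // for (sz_size_t i = 0; i != length; ++i) target[i] = lut[(sz_u8_t)source[i]];
- //
- // The code below emulates that byte-wise gather with nibble-indexed 16-byte shuffles and blends.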
- sz_u256_vec_t lut_0_to_15_vec, lut_16_to_31_vec, lut_32_to_47_vec, lut_48_to_63_vec, // - lut_64_to_79_vec, lut_80_to_95_vec, lut_96_to_111_vec, lut_112_to_127_vec, // - lut_128_to_143_vec, lut_144_to_159_vec, lut_160_to_175_vec, lut_176_to_191_vec, // - lut_192_to_207_vec, lut_208_to_223_vec, lut_224_to_239_vec, lut_240_to_255_vec; - - lut_0_to_15_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut))); - lut_16_to_31_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 16))); - lut_32_to_47_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 32))); - lut_48_to_63_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 48))); - lut_64_to_79_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 64))); - lut_80_to_95_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 80))); - lut_96_to_111_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 96))); - lut_112_to_127_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 112))); - lut_128_to_143_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 128))); - lut_144_to_159_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 144))); - lut_160_to_175_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 160))); - lut_176_to_191_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 176))); - lut_192_to_207_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 192))); - lut_208_to_223_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 208))); - lut_224_to_239_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 224))); - lut_240_to_255_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 240))); - - // Assuming each lookup is performed within 16 elements of 256, we need to reduce the scope by 16x = 2^4. - sz_u256_vec_t not_first_bit_vec, not_second_bit_vec, not_third_bit_vec, not_fourth_bit_vec; - - /// Top and bottom nibbles of the source are used separately. - sz_u256_vec_t source_vec, source_bot_vec; - sz_u256_vec_t blended_0_to_31_vec, blended_32_to_63_vec, blended_64_to_95_vec, blended_96_to_127_vec, - blended_128_to_159_vec, blended_160_to_191_vec, blended_192_to_223_vec, blended_224_to_255_vec; - - // Handling the head. - while (length >= 32) { - // Load and separate the nibbles of each byte in the source. - source_vec.ymm = _mm256_lddqu_si256((__m256i const *)source); - source_bot_vec.ymm = _mm256_and_si256(source_vec.ymm, _mm256_set1_epi8((char)0x0F)); - - // In the first round, we select using the 4th bit. 
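- // Each blend below picks between two adjacent 16-entry slices of the table, using the bottom nibble
- // of each byte as the shuffle index within the slice.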
- not_fourth_bit_vec.ymm = _mm256_cmpeq_epi8( // - _mm256_and_si256(_mm256_set1_epi8((char)0x10), source_vec.ymm), _mm256_setzero_si256()); - blended_0_to_31_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_16_to_31_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_0_to_15_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_32_to_63_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_48_to_63_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_32_to_47_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_64_to_95_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_80_to_95_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_64_to_79_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_96_to_127_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_112_to_127_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_96_to_111_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_128_to_159_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_144_to_159_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_128_to_143_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_160_to_191_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_176_to_191_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_160_to_175_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_192_to_223_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_208_to_223_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_192_to_207_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_224_to_255_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_240_to_255_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_224_to_239_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - - // Perform a tree-like reduction of the 8x "blended" YMM registers, depending on the "source" content. - // The first round selects using the 3rd bit. - not_third_bit_vec.ymm = _mm256_cmpeq_epi8( // - _mm256_and_si256(_mm256_set1_epi8((char)0x20), source_vec.ymm), _mm256_setzero_si256()); - blended_0_to_31_vec.ymm = _mm256_blendv_epi8( // - blended_32_to_63_vec.ymm, // - blended_0_to_31_vec.ymm, // - not_third_bit_vec.ymm); - blended_64_to_95_vec.ymm = _mm256_blendv_epi8( // - blended_96_to_127_vec.ymm, // - blended_64_to_95_vec.ymm, // - not_third_bit_vec.ymm); - blended_128_to_159_vec.ymm = _mm256_blendv_epi8( // - blended_160_to_191_vec.ymm, // - blended_128_to_159_vec.ymm, // - not_third_bit_vec.ymm); - blended_192_to_223_vec.ymm = _mm256_blendv_epi8( // - blended_224_to_255_vec.ymm, // - blended_192_to_223_vec.ymm, // - not_third_bit_vec.ymm); - - // The second round selects using the 2nd bit. - not_second_bit_vec.ymm = _mm256_cmpeq_epi8( // - _mm256_and_si256(_mm256_set1_epi8((char)0x40), source_vec.ymm), _mm256_setzero_si256()); - blended_0_to_31_vec.ymm = _mm256_blendv_epi8( // - blended_64_to_95_vec.ymm, // - blended_0_to_31_vec.ymm, // - not_second_bit_vec.ymm); - blended_128_to_159_vec.ymm = _mm256_blendv_epi8( // - blended_192_to_223_vec.ymm, // - blended_128_to_159_vec.ymm, // - not_second_bit_vec.ymm); - - // The third round selects using the 1st bit. 
- not_first_bit_vec.ymm = _mm256_cmpeq_epi8( // - _mm256_and_si256(_mm256_set1_epi8((char)0x80), source_vec.ymm), _mm256_setzero_si256()); - blended_0_to_31_vec.ymm = _mm256_blendv_epi8( // - blended_128_to_159_vec.ymm, // - blended_0_to_31_vec.ymm, // - not_first_bit_vec.ymm); - - // And dump the result into the target. - _mm256_storeu_si256((__m256i *)target, blended_0_to_31_vec.ymm); - source += 32, target += 32, length -= 32; - } - - // Handle the tail. - if (length) sz_look_up_transform_serial(source, length, lut, target); -} - -SZ_PUBLIC sz_cptr_t sz_find_byte_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - int mask; - sz_u256_vec_t h_vec, n_vec; - n_vec.ymm = _mm256_set1_epi8(n[0]); - - while (h_length >= 32) { - h_vec.ymm = _mm256_lddqu_si256((__m256i const *)h); - mask = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_vec.ymm, n_vec.ymm)); - if (mask) return h + sz_u32_ctz(mask); - h += 32, h_length -= 32; - } - - return sz_find_byte_serial(h, h_length, n); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - int mask; - sz_u256_vec_t h_vec, n_vec; - n_vec.ymm = _mm256_set1_epi8(n[0]); - - while (h_length >= 32) { - h_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + h_length - 32)); - mask = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_vec.ymm, n_vec.ymm)); - if (mask) return h + h_length - 1 - sz_u32_clz(mask); - h_length -= 32; - } - - return sz_rfind_byte_serial(h, h_length, n); -} - -SZ_PUBLIC sz_cptr_t sz_find_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_find_byte_avx2(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into YMM registers. - int matches; - sz_u256_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; - n_first_vec.ymm = _mm256_set1_epi8(n[offset_first]); - n_mid_vec.ymm = _mm256_set1_epi8(n[offset_mid]); - n_last_vec.ymm = _mm256_set1_epi8(n[offset_last]); - - // Scan through the string. - for (; h_length >= n_length + 32; h += 32, h_length -= 32) { - h_first_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_first)); - h_mid_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_mid)); - h_last_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_last)); - matches = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_first_vec.ymm, n_first_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_mid_vec.ymm, n_mid_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_last_vec.ymm, n_last_vec.ymm)); - while (matches) { - int potential_offset = sz_u32_ctz(matches); - if (sz_equal(h + potential_offset, n, n_length)) return h + potential_offset; - matches &= matches - 1; - } - } - - return sz_find_serial(h, h_length, n, n_length); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_rfind_byte_avx2(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. 
- sz_size_t offset_first, offset_mid, offset_last;
- _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last);
-
- // Broadcast those characters into YMM registers.
- int matches;
- sz_u256_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec;
- n_first_vec.ymm = _mm256_set1_epi8(n[offset_first]);
- n_mid_vec.ymm = _mm256_set1_epi8(n[offset_mid]);
- n_last_vec.ymm = _mm256_set1_epi8(n[offset_last]);
-
- // Scan through the string.
- sz_cptr_t h_reversed;
- for (; h_length >= n_length + 32; h_length -= 32) {
- h_reversed = h + h_length - n_length - 32 + 1;
- h_first_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_first));
- h_mid_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_mid));
- h_last_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_last));
- matches = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_first_vec.ymm, n_first_vec.ymm)) &
- _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_mid_vec.ymm, n_mid_vec.ymm)) &
- _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_last_vec.ymm, n_last_vec.ymm));
- while (matches) {
- int potential_offset = sz_u32_clz(matches);
- if (sz_equal(h + h_length - n_length - potential_offset, n, n_length))
- return h + h_length - n_length - potential_offset;
- matches &= ~(1 << (31 - potential_offset));
- }
- }
-
- return sz_rfind_serial(h, h_length, n, n_length);
-}
-
-SZ_PUBLIC sz_cptr_t sz_find_charset_avx2(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) {
-
- // Let's unzip even and odd elements and replicate them into both lanes of the YMM register.
- // That way when we invoke `_mm256_shuffle_epi8` we can use the same mask for both lanes.
- sz_u256_vec_t filter_even_vec, filter_odd_vec;
- for (sz_size_t i = 0; i != 16; ++i)
- filter_even_vec.u8s[i] = filter->_u8s[i * 2], filter_odd_vec.u8s[i] = filter->_u8s[i * 2 + 1];
- filter_even_vec.xmms[1] = filter_even_vec.xmms[0];
- filter_odd_vec.xmms[1] = filter_odd_vec.xmms[0];
-
- sz_u256_vec_t text_vec;
- sz_u256_vec_t matches_vec;
- sz_u256_vec_t lower_nibbles_vec, higher_nibbles_vec;
- sz_u256_vec_t bitset_even_vec, bitset_odd_vec;
- sz_u256_vec_t bitmask_vec, bitmask_lookup_vec;
- bitmask_lookup_vec.ymm = _mm256_set_epi8(-128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, //
- -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1);
-
- while (length >= 32) {
- // The following algorithm is a transposed equivalent of the "SIMDized check which bytes are in a set"
- // solutions by Wojciech Muła. We populate the bitmask differently and target newer CPUs, so
- // StringZilla uses a somewhat different approach.
- // http://0x80.pl/articles/simd-byte-lookup.html#alternative-implementation-new
- //
- // sz_u8_t input = *(sz_u8_t const *)text;
- // sz_u8_t lo_nibble = input & 0x0f;
- // sz_u8_t hi_nibble = input >> 4;
- // sz_u8_t bitset_even = filter_even_vec.u8s[hi_nibble];
- // sz_u8_t bitset_odd = filter_odd_vec.u8s[hi_nibble];
- // sz_u8_t bitmask = (1 << (lo_nibble & 0x7));
- // sz_u8_t bitset = lo_nibble < 8 ? bitset_even : bitset_odd;
- // if ((bitset & bitmask) != 0) return text;
- // else { length--, text++; }
- //
- // The nice part about this is that loading the strided data is very easy with Arm NEON,
- // while with x86 CPUs after AVX, shuffles within 256 bits shouldn't be an issue either.
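- // The loop body below performs that same check for 32 input bytes at a time.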
- text_vec.ymm = _mm256_lddqu_si256((__m256i const *)text); - lower_nibbles_vec.ymm = _mm256_and_si256(text_vec.ymm, _mm256_set1_epi8(0x0f)); - bitmask_vec.ymm = _mm256_shuffle_epi8(bitmask_lookup_vec.ymm, lower_nibbles_vec.ymm); - // - // At this point we can validate the `bitmask_vec` contents like this: - // - // for (sz_size_t i = 0; i != 32; ++i) { - // sz_u8_t input = *(sz_u8_t const *)(text + i); - // sz_u8_t lo_nibble = input & 0x0f; - // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); - // sz_assert(bitmask_vec.u8s[i] == bitmask); - // } - // - // Shift right every byte by 4 bits. - // There is no `_mm256_srli_epi8` intrinsic, so we have to use `_mm256_srli_epi16` - // and combine it with a mask to clear the higher bits. - higher_nibbles_vec.ymm = _mm256_and_si256(_mm256_srli_epi16(text_vec.ymm, 4), _mm256_set1_epi8(0x0f)); - bitset_even_vec.ymm = _mm256_shuffle_epi8(filter_even_vec.ymm, higher_nibbles_vec.ymm); - bitset_odd_vec.ymm = _mm256_shuffle_epi8(filter_odd_vec.ymm, higher_nibbles_vec.ymm); - // - // At this point we can validate the `bitset_even_vec` and `bitset_odd_vec` contents like this: - // - // for (sz_size_t i = 0; i != 32; ++i) { - // sz_u8_t input = *(sz_u8_t const *)(text + i); - // sz_u8_t const *bitset_ptr = &filter->_u8s[0]; - // sz_u8_t hi_nibble = input >> 4; - // sz_u8_t bitset_even = bitset_ptr[hi_nibble * 2]; - // sz_u8_t bitset_odd = bitset_ptr[hi_nibble * 2 + 1]; - // sz_assert(bitset_even_vec.u8s[i] == bitset_even); - // sz_assert(bitset_odd_vec.u8s[i] == bitset_odd); - // } - // - __m256i take_first = _mm256_cmpgt_epi8(_mm256_set1_epi8(8), lower_nibbles_vec.ymm); - bitset_even_vec.ymm = _mm256_blendv_epi8(bitset_odd_vec.ymm, bitset_even_vec.ymm, take_first); - - // It would have been great to have an instruction that tests the bits and then broadcasts - // the matching bit into all bits in that byte. But we don't have that, so we have to - // `and`, `cmpeq`, `movemask`, and then invert at the end... - matches_vec.ymm = _mm256_and_si256(bitset_even_vec.ymm, bitmask_vec.ymm); - matches_vec.ymm = _mm256_cmpeq_epi8(matches_vec.ymm, _mm256_setzero_si256()); - int matches_mask = ~_mm256_movemask_epi8(matches_vec.ymm); - if (matches_mask) { - int offset = sz_u32_ctz(matches_mask); - return text + offset; - } - else { text += 32, length -= 32; } - } - - return sz_find_charset_serial(text, length, filter); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx2(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { - return sz_rfind_charset_serial(text, length, filter); -} - -/** - * @brief There is no AVX2 instruction for fast multiplication of 64-bit integers. - * This implementation is coming from Agner Fog's Vector Class Library. 
- */ -SZ_INTERNAL __m256i _mm256_mul_epu64(__m256i a, __m256i b) { - __m256i bswap = _mm256_shuffle_epi32(b, 0xB1); - __m256i prodlh = _mm256_mullo_epi32(a, bswap); - __m256i zero = _mm256_setzero_si256(); - __m256i prodlh2 = _mm256_hadd_epi32(prodlh, zero); - __m256i prodlh3 = _mm256_shuffle_epi32(prodlh2, 0x73); - __m256i prodll = _mm256_mul_epu32(a, b); - __m256i prod = _mm256_add_epi64(prodll, prodlh3); - return prod; -} - -SZ_PUBLIC void sz_hashes_avx2(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle) { - - if (length < window_length || !window_length) return; - if (length < 4 * window_length) { - sz_hashes_serial(start, length, window_length, step, callback, callback_handle); - return; - } - - // Using AVX2, we can perform 4 long integer multiplications and additions within one register. - // So let's slice the entire string into 4 overlapping windows, to slide over them in parallel. - sz_size_t const max_hashes = length - window_length + 1; - sz_size_t const min_hashes_per_thread = max_hashes / 4; // At most one sequence can overlap between 2 threads. - sz_u8_t const *text_first = (sz_u8_t const *)start; - sz_u8_t const *text_second = text_first + min_hashes_per_thread; - sz_u8_t const *text_third = text_first + min_hashes_per_thread * 2; - sz_u8_t const *text_fourth = text_first + min_hashes_per_thread * 3; - sz_u8_t const *text_end = text_first + length; - - // Prepare the `prime ^ window_length` values, that we are going to use for modulo arithmetic. - sz_u64_t prime_power_low = 1, prime_power_high = 1; - for (sz_size_t i = 0; i + 1 < window_length; ++i) - prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME, - prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME; - - // Broadcast the constants into the registers. - sz_u256_vec_t prime_vec, golden_ratio_vec; - sz_u256_vec_t base_low_vec, base_high_vec, prime_power_low_vec, prime_power_high_vec, shift_high_vec; - base_low_vec.ymm = _mm256_set1_epi64x(31ull); - base_high_vec.ymm = _mm256_set1_epi64x(257ull); - shift_high_vec.ymm = _mm256_set1_epi64x(77ull); - prime_vec.ymm = _mm256_set1_epi64x(SZ_U64_MAX_PRIME); - golden_ratio_vec.ymm = _mm256_set1_epi64x(11400714819323198485ull); - prime_power_low_vec.ymm = _mm256_set1_epi64x(prime_power_low); - prime_power_high_vec.ymm = _mm256_set1_epi64x(prime_power_high); - - // Compute the initial hash values for every one of the four windows. - sz_u256_vec_t hash_low_vec, hash_high_vec, hash_mix_vec, chars_low_vec, chars_high_vec; - hash_low_vec.ymm = _mm256_setzero_si256(); - hash_high_vec.ymm = _mm256_setzero_si256(); - for (sz_u8_t const *prefix_end = text_first + window_length; text_first < prefix_end; - ++text_first, ++text_second, ++text_third, ++text_fourth) { - - // 1. Multiply the hashes by the base. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, base_low_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, base_high_vec.ymm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. - chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); - - // 3. Add the incoming characters. 
- hash_low_vec.ymm = _mm256_add_epi64(hash_low_vec.ymm, chars_low_vec.ymm); - hash_high_vec.ymm = _mm256_add_epi64(hash_high_vec.ymm, chars_high_vec.ymm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_low_vec.ymm = _mm256_blendv_epi8(hash_low_vec.ymm, _mm256_sub_epi64(hash_low_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_low_vec.ymm, prime_vec.ymm)); - hash_high_vec.ymm = _mm256_blendv_epi8(hash_high_vec.ymm, _mm256_sub_epi64(hash_high_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_high_vec.ymm, prime_vec.ymm)); - } - - // 5. Compute the hash mix, that will be used to index into the fingerprint. - // This includes a serial step at the end. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, golden_ratio_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, golden_ratio_vec.ymm); - hash_mix_vec.ymm = _mm256_xor_si256(hash_low_vec.ymm, hash_high_vec.ymm); - callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); - callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); - callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); - callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); - - // Now repeat that operation for the remaining characters, discarding older characters. - sz_size_t cycle = 1; - sz_size_t const step_mask = step - 1; - for (; text_fourth != text_end; ++text_first, ++text_second, ++text_third, ++text_fourth, ++cycle) { - // 0. Load again the four characters we are dropping, shift them, and subtract. - chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[-window_length], text_third[-window_length], - text_second[-window_length], text_first[-window_length]); - chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); - hash_low_vec.ymm = - _mm256_sub_epi64(hash_low_vec.ymm, _mm256_mul_epu64(chars_low_vec.ymm, prime_power_low_vec.ymm)); - hash_high_vec.ymm = - _mm256_sub_epi64(hash_high_vec.ymm, _mm256_mul_epu64(chars_high_vec.ymm, prime_power_high_vec.ymm)); - - // 1. Multiply the hashes by the base. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, base_low_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, base_high_vec.ymm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. - chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); - - // 3. Add the incoming characters. - hash_low_vec.ymm = _mm256_add_epi64(hash_low_vec.ymm, chars_low_vec.ymm); - hash_high_vec.ymm = _mm256_add_epi64(hash_high_vec.ymm, chars_high_vec.ymm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_low_vec.ymm = _mm256_blendv_epi8(hash_low_vec.ymm, _mm256_sub_epi64(hash_low_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_low_vec.ymm, prime_vec.ymm)); - hash_high_vec.ymm = _mm256_blendv_epi8(hash_high_vec.ymm, _mm256_sub_epi64(hash_high_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_high_vec.ymm, prime_vec.ymm)); - - // 5. 
Compute the hash mix that will be used to index into the fingerprint.
- // This includes a serial step at the end.
- hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, golden_ratio_vec.ymm);
- hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, golden_ratio_vec.ymm);
- hash_mix_vec.ymm = _mm256_xor_si256(hash_low_vec.ymm, hash_high_vec.ymm);
- if ((cycle & step_mask) == 0) {
- callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle);
- callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle);
- callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle);
- callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle);
- }
- }
-}
-
-#pragma clang attribute pop
-#pragma GCC pop_options
-#endif
-#pragma endregion
-
-/*
- * @brief AVX-512 implementation of the string search algorithms.
- *
- * Different subsets of AVX-512 were introduced in different years:
- * - 2017 SkyLake: F, CD, ER, PF, VL, DQ, BW
- * - 2018 CannonLake: IFMA, VBMI
- * - 2019 IceLake: VPOPCNTDQ, VNNI, VBMI2, BITALG, GFNI, VPCLMULQDQ, VAES
- * - 2020 TigerLake: VP2INTERSECT
- */
-#pragma region AVX512 Implementation
-
-#if SZ_USE_X86_AVX512
-#pragma GCC push_options
-#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "bmi", "bmi2")
-#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,bmi,bmi2"))), apply_to = function)
-#include <immintrin.h>
-
-/**
- * @brief Helper structure to simplify work with 512-bit registers.
- */
-typedef union sz_u512_vec_t {
- __m512i zmm;
- __m256i ymms[2];
- __m128i xmms[4];
- sz_u64_t u64s[8];
- sz_u32_t u32s[16];
- sz_u16_t u16s[32];
- sz_u8_t u8s[64];
- sz_i64_t i64s[8];
- sz_i32_t i32s[16];
-} sz_u512_vec_t;
-
-SZ_INTERNAL __mmask64 _sz_u64_clamp_mask_until(sz_size_t n) {
- // The simplest approach to compute this, if we know that `n` is below or equal to 64:
- // return (1ull << n) - 1;
- // A slightly more complex approach, if we don't know that `n` is under 64:
- return _bzhi_u64(0xFFFFFFFFFFFFFFFF, n < 64 ? (sz_u32_t)n : 64);
-}
-
-SZ_INTERNAL __mmask32 _sz_u32_clamp_mask_until(sz_size_t n) {
- // The simplest approach to compute this, if we know that `n` is below or equal to 32:
- // return (1ull << n) - 1;
- // A slightly more complex approach, if we don't know that `n` is under 32:
- return _bzhi_u32(0xFFFFFFFF, n < 32 ? (sz_u32_t)n : 32);
-}
-
-SZ_INTERNAL __mmask16 _sz_u16_clamp_mask_until(sz_size_t n) {
- // The simplest approach to compute this, if we know that `n` is below or equal to 16:
- // return (1ull << n) - 1;
- // A slightly more complex approach, if we don't know that `n` is under 16:
- return _bzhi_u32(0xFFFFFFFF, n < 16 ?
(sz_u32_t)n : 16);
-}
-
-SZ_INTERNAL __mmask16 _sz_u16_mask_until(sz_size_t n) {
- // The simplest approach to compute this, if we know that `n` is below or equal to 16:
- // return (1ull << n) - 1;
- // A slightly more complex approach, if we don't know that `n` is under 16:
- return (__mmask16)_bzhi_u32(0xFFFFFFFF, (sz_u32_t)n);
-}
-
-SZ_INTERNAL __mmask32 _sz_u32_mask_until(sz_size_t n) {
- // The simplest approach to compute this, if we know that `n` is below or equal to 32:
- // return (1ull << n) - 1;
- // A slightly more complex approach, if we don't know that `n` is under 32:
- return _bzhi_u32(0xFFFFFFFF, (sz_u32_t)n);
-}
-
-SZ_INTERNAL __mmask64 _sz_u64_mask_until(sz_size_t n) {
- // The simplest approach to compute this, if we know that `n` is below or equal to 64:
- // return (1ull << n) - 1;
- // A slightly more complex approach, if we don't know that `n` is under 64:
- return _bzhi_u64(0xFFFFFFFFFFFFFFFF, (sz_u32_t)n);
-}
-
-SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) {
- sz_u512_vec_t a_vec, b_vec;
-
- // Pointer arithmetic is cheap, fetching memory is not!
- // So we can use the masked loads to fetch at most one cache-line for each string,
- // compare the prefixes, and only then move forward.
- sz_size_t a_head_length = 64 - ((sz_size_t)a % 64); // 63 or less.
- sz_size_t b_head_length = 64 - ((sz_size_t)b % 64); // 63 or less.
- a_head_length = a_head_length < a_length ? a_head_length : a_length;
- b_head_length = b_head_length < b_length ? b_head_length : b_length;
- sz_size_t head_length = a_head_length < b_head_length ? a_head_length : b_head_length;
- __mmask64 head_mask = _sz_u64_mask_until(head_length);
- a_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, a);
- b_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, b);
- __mmask64 mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm);
- if (mask_not_equal != 0) {
- sz_u64_t first_diff = _tzcnt_u64(mask_not_equal);
- char a_char = a_vec.u8s[first_diff];
- char b_char = b_vec.u8s[first_diff];
- return _sz_order_scalars(a_char, b_char);
- }
- else if (head_length == a_length && head_length == b_length) { return sz_equal_k; }
- else { a += head_length, b += head_length, a_length -= head_length, b_length -= head_length; }
-
- // The rare case, when both strings are very long.
- __mmask64 a_mask, b_mask;
- while ((a_length >= 64) & (b_length >= 64)) {
- a_vec.zmm = _mm512_loadu_si512(a);
- b_vec.zmm = _mm512_loadu_si512(b);
- mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm);
- if (mask_not_equal != 0) {
- sz_u64_t first_diff = _tzcnt_u64(mask_not_equal);
- char a_char = a_vec.u8s[first_diff];
- char b_char = b_vec.u8s[first_diff];
- return _sz_order_scalars(a_char, b_char);
- }
- a += 64, b += 64, a_length -= 64, b_length -= 64;
- }
-
- // In most common scenarios, at least one of the strings is under 64 bytes.
- if (a_length | b_length) {
- a_mask = _sz_u64_clamp_mask_until(a_length);
- b_mask = _sz_u64_clamp_mask_until(b_length);
- a_vec.zmm = _mm512_maskz_loadu_epi8(a_mask, a);
- b_vec.zmm = _mm512_maskz_loadu_epi8(b_mask, b);
- // The AVX-512 `_mm512_mask_cmpneq_epi8_mask` intrinsics are generally handy in such environments.
- // They, however, have latency 3 on most modern CPUs. Using AVX2's `_mm256_cmpeq_epi8` would have
- // been cheaper, if we didn't have to apply `_mm256_movemask_epi8` afterwards.
- mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); - if (mask_not_equal != 0) { - sz_u64_t first_diff = _tzcnt_u64(mask_not_equal); - char a_char = a_vec.u8s[first_diff]; - char b_char = b_vec.u8s[first_diff]; - return _sz_order_scalars(a_char, b_char); - } - // From logic perspective, the hardest cases are "abc\0" and "abc". - // The result must be `sz_greater_k`, as the latter is shorter. - else { return _sz_order_scalars(a_length, b_length); } - } - - return sz_equal_k; -} - -SZ_PUBLIC sz_bool_t sz_equal_avx512(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - __mmask64 mask; - sz_u512_vec_t a_vec, b_vec; - - while (length >= 64) { - a_vec.zmm = _mm512_loadu_si512(a); - b_vec.zmm = _mm512_loadu_si512(b); - mask = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); - if (mask != 0) return sz_false_k; - a += 64, b += 64, length -= 64; - } - - if (length) { - mask = _sz_u64_mask_until(length); - a_vec.zmm = _mm512_maskz_loadu_epi8(mask, a); - b_vec.zmm = _mm512_maskz_loadu_epi8(mask, b); - // Reuse the same `mask` variable to find the bit that doesn't match - mask = _mm512_mask_cmpneq_epi8_mask(mask, a_vec.zmm, b_vec.zmm); - return (sz_bool_t)(mask == 0); - } - - return sz_true_k; -} - -SZ_PUBLIC void sz_fill_avx512(sz_ptr_t target, sz_size_t length, sz_u8_t value) { - __m512i value_vec = _mm512_set1_epi8(value); - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "stores". - // - // for (; length >= 64; target += 64, length -= 64) _mm512_storeu_si512(target, value_vec); - // _mm512_mask_storeu_epi8(target, _sz_u64_mask_until(length), value_vec); - // - // When the buffer is small, there isn't much to innovate. - if (length <= 64) { - __mmask64 mask = _sz_u64_mask_until(length); - _mm512_mask_storeu_epi8(target, mask, value_vec); - } - // When the buffer is over 64 bytes, it's guaranteed to touch at least two cache lines - the head and tail, - // and may include more cache-lines in-between. Knowing this, we can avoid expensive unaligned stores - // by computing 2 masks - for the head and tail, using masked stores for the head and tail, and unmasked - // for the body. - else { - sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 64. - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - _mm512_mask_storeu_epi8(target, head_mask, value_vec); - for (target += head_length; body_length >= 64; target += 64, body_length -= 64) - _mm512_store_si512(target, value_vec); - _mm512_mask_storeu_epi8(target, tail_mask, value_vec); - } -} - -SZ_PUBLIC void sz_copy_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "stores" and "loads". - // - // for (; length >= 64; target += 64, source += 64, length -= 64) - // _mm512_storeu_si512(target, _mm512_loadu_si512(source)); - // __mmask64 mask = _sz_u64_mask_until(length); - // _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source)); - // - // A typical AWS Sapphire Rapids instance can have 48 KB x 2 blocks of L1 data cache per core, - // 2 MB x 2 blocks of L2 cache per core, and one shared 60 MB buffer of L3 cache. 
- // With two strings, we may consider the overall workload huge if each exceeds 1 MB in length.
- int const is_huge = length >= 1ull * 1024ull * 1024ull;
-
- // When the buffer is small, there isn't much to innovate.
- if (length <= 64) {
- __mmask64 mask = _sz_u64_mask_until(length);
- _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source));
- }
- // When dealing with larger arrays, the optimization is not as simple as with the `sz_fill_avx512` function,
- // as both buffers may be unaligned. If we are lucky and the requested operation is some huge page transfer,
- // we can use aligned loads and stores, and the performance will be great.
- else if ((sz_size_t)target % 64 == 0 && (sz_size_t)source % 64 == 0 && !is_huge) {
- for (; length >= 64; target += 64, source += 64, length -= 64)
- _mm512_store_si512(target, _mm512_load_si512(source));
- // At this point the length is guaranteed to be under 64.
- __mmask64 mask = _sz_u64_mask_until(length);
- // Aligned loads and stores would work too, but it's not defined.
- _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source));
- }
- // The trickiest case is when both `source` and `target` are not aligned.
- // In this and simpler cases we can copy enough bytes into `target` to reach its cacheline boundary,
- // and then combine unaligned loads with aligned stores.
- else if (!is_huge) {
- sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less.
- sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less.
- sz_size_t body_length = length - head_length - tail_length; // Multiple of 64.
- __mmask64 head_mask = _sz_u64_mask_until(head_length);
- __mmask64 tail_mask = _sz_u64_mask_until(tail_length);
- _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source));
- for (target += head_length, source += head_length; body_length >= 64;
- target += 64, source += 64, body_length -= 64)
- _mm512_store_si512(target, _mm512_loadu_si512(source)); // Unaligned load, but aligned store!
- _mm512_mask_storeu_epi8(target, tail_mask, _mm512_maskz_loadu_epi8(tail_mask, source));
- }
- // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use.
- //
- // 1. Moving in both directions to maximize the throughput, when fetching from multiple
- // memory pages. Also helps with cache set-associativity issues, as we won't always
- // be fetching the same entries in the lookup table.
- // 2. Using non-temporal stores to avoid polluting the cache.
- // 3. Prefetching the next cache line, to avoid stalling the CPU. This is generally useless
- // for predictable patterns, so disregard this advice.
- //
- // Bidirectional traversal adds about 10%, accelerating from 11 GB/s to 12 GB/s.
- // Using "streaming stores" boosts us from 12 GB/s to 19 GB/s.
- else {
- sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64;
- sz_size_t tail_length = (sz_size_t)(target + length) % 64;
- sz_size_t body_length = length - head_length - tail_length;
- __mmask64 head_mask = _sz_u64_mask_until(head_length);
- __mmask64 tail_mask = _sz_u64_mask_until(tail_length);
- _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source));
- _mm512_mask_storeu_epi8(target + head_length + body_length, tail_mask,
- _mm512_maskz_loadu_epi8(tail_mask, source + head_length + body_length));
-
- // Now, in the main loop, we can use non-temporal stores,
- // performing the operation in both directions.
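- // Each iteration stores one cache line near the front and one near the back, so the body shrinks
- // by 128 bytes per step.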
- for (target += head_length, source += head_length; // - body_length >= 128; // - target += 64, source += 64, body_length -= 128) { - _mm512_stream_si512((__m512i *)(target), _mm512_loadu_si512(source)); - _mm512_stream_si512((__m512i *)(target + body_length - 64), _mm512_loadu_si512(source + body_length - 64)); - } - if (body_length >= 64) _mm512_stream_si512((__m512i *)target, _mm512_loadu_si512(source)); - } -} - -SZ_PUBLIC void sz_move_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - if (target == source) return; // Don't be silly, don't move the data if it's already there. - - // On very short buffers, that are one cache line in width or less, we don't need any loops. - // We can also avoid any data-dependencies between iterations, assuming we have 32 registers - // to pre-load the data, before writing it back. - if (length <= 64) { - __mmask64 mask = _sz_u64_mask_until(length); - _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source)); - } - else if (length <= 128) { - sz_size_t last_length = length - 64; - __mmask64 mask = _sz_u64_mask_until(last_length); - __m512i source0 = _mm512_loadu_epi8(source); - __m512i source1 = _mm512_maskz_loadu_epi8(mask, source + 64); - _mm512_storeu_epi8(target, source0); - _mm512_mask_storeu_epi8(target + 64, mask, source1); - } - else if (length <= 192) { - sz_size_t last_length = length - 128; - __mmask64 mask = _sz_u64_mask_until(last_length); - __m512i source0 = _mm512_loadu_epi8(source); - __m512i source1 = _mm512_loadu_epi8(source + 64); - __m512i source2 = _mm512_maskz_loadu_epi8(mask, source + 128); - _mm512_storeu_epi8(target, source0); - _mm512_storeu_epi8(target + 64, source1); - _mm512_mask_storeu_epi8(target + 128, mask, source2); - } - else if (length <= 256) { - sz_size_t last_length = length - 192; - __mmask64 mask = _sz_u64_mask_until(last_length); - __m512i source0 = _mm512_loadu_epi8(source); - __m512i source1 = _mm512_loadu_epi8(source + 64); - __m512i source2 = _mm512_loadu_epi8(source + 128); - __m512i source3 = _mm512_maskz_loadu_epi8(mask, source + 192); - _mm512_storeu_epi8(target, source0); - _mm512_storeu_epi8(target + 64, source1); - _mm512_storeu_epi8(target + 128, source2); - _mm512_mask_storeu_epi8(target + 192, mask, source3); - } - - // If the regions don't overlap at all, just use "copy" and save some brain cells thinking about corner cases. - else if (target + length < source || target >= source + length) { sz_copy_avx512(target, source, length); } - - // When the buffer is over 64 bytes, it's guaranteed to touch at least two cache lines - the head and tail, - // and may include more cache-lines in-between. Knowing this, we can avoid expensive unaligned stores - // by computing 2 masks - for the head and tail, using masked stores for the head and tail, and unmasked - // for the body. - else { - sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 64. - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - - // The absolute most common case of using "moves" is shifting the data within a continuous buffer - // when adding a removing some values in it. In such cases, a typical shift is by 1, 2, 4, 8, 16, - // or 32 bytes, rarely larger. For small shifts, under the size of the ZMM register, we can use shuffles. 
- //
- // Remember:
- // - if we are shifting data left, we are traversing to the right.
- // - if we are shifting data right, we are traversing to the left.
- int const left_to_right_traversal = source > target;
-
- // Now we guarantee that the relative shift within registers is from 1 to 63 bytes and the output is aligned.
- // Hopefully, we need to shift more than two ZMM registers, so we could consider the `valignr` instruction.
- // Sadly, using `_mm512_alignr_epi8` doesn't make sense, as it operates at a 128-bit granularity.
- //
- // - `_mm256_alignr_epi8` shifts the entire 256-bit register, but we need many of them.
- // - `_mm512_alignr_epi32` shifts 512-bit chunks, but only if the `shift` is a multiple of 4 bytes.
- // - `_mm512_alignr_epi64` shifts 512-bit chunks by 8 bytes.
- //
- // All of those have a latency of 1 cycle, and the shift amount must be an immediate value!
- // For 1-byte-shift granularity, the `_mm512_permutex2var_epi8` has a latency of 6 and needs VBMI!
- // The most efficient and broadly compatible alternative could be to use a combination of align and shuffle.
- // A similar approach was outlined in "Byte-wise alignr in AVX512F" by Wojciech Muła.
- // http://0x80.pl/notesen/2016-10-16-avx512-byte-alignr.html
- //
- // That solution is quite a mouthful, assuming we need compile-time constants for the shift amount.
- // A cleaner one, with a latency of 3 cycles, is to use `_mm512_permutexvar_epi8` or
- // `_mm512_mask_permutexvar_epi8`, which can be seen as a combination of a cross-register shuffle and blend,
- // and is available with VBMI. That solution is still noticeably slower than AVX2.
- //
- // The GLibC implementation also uses non-temporal stores for larger buffers; we don't.
- // https://codebrowser.dev/glibc/glibc/sysdeps/x86_64/multiarch/memmove-avx512-no-vzeroupper.S.html
- if (left_to_right_traversal) {
- // Head, body, and tail.
- _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source));
- for (target += head_length, source += head_length; body_length >= 64;
- target += 64, source += 64, body_length -= 64)
- _mm512_store_si512(target, _mm512_loadu_si512(source));
- _mm512_mask_storeu_epi8(target, tail_mask, _mm512_maskz_loadu_epi8(tail_mask, source));
- }
- else {
- // Tail, body, and head.
- _mm512_mask_storeu_epi8(target + head_length + body_length, tail_mask, - _mm512_maskz_loadu_epi8(tail_mask, source + head_length + body_length)); - for (; body_length >= 64; body_length -= 64) - _mm512_store_si512(target + head_length + body_length - 64, - _mm512_loadu_si512(source + head_length + body_length - 64)); - _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source)); - } - } -} - -SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - __mmask64 mask; - sz_u512_vec_t h_vec, n_vec; - n_vec.zmm = _mm512_set1_epi8(n[0]); - - while (h_length >= 64) { - h_vec.zmm = _mm512_loadu_si512(h); - mask = _mm512_cmpeq_epi8_mask(h_vec.zmm, n_vec.zmm); - if (mask) return h + sz_u64_ctz(mask); - h += 64, h_length -= 64; - } - - if (h_length) { - mask = _sz_u64_mask_until(h_length); - h_vec.zmm = _mm512_maskz_loadu_epi8(mask, h); - // Reuse the same `mask` variable to find the bit that doesn't match - mask = _mm512_mask_cmpeq_epu8_mask(mask, h_vec.zmm, n_vec.zmm); - if (mask) return h + sz_u64_ctz(mask); - } - - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_find_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_find_byte_avx512(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into ZMM registers. - __mmask64 matches; - __mmask64 mask; - sz_u512_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; - n_first_vec.zmm = _mm512_set1_epi8(n[offset_first]); - n_mid_vec.zmm = _mm512_set1_epi8(n[offset_mid]); - n_last_vec.zmm = _mm512_set1_epi8(n[offset_last]); - - // Scan through the string. - // We have several optimized versions of the lagorithm for shorter strings, - // but they all mimic the default case for unbounded length needles - if (n_length >= 64) { - for (; h_length >= n_length + 64; h += 64, h_length -= 64) { - h_first_vec.zmm = _mm512_loadu_si512(h + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_ctz(matches); - if (sz_equal_avx512(h + potential_offset, n, n_length)) return h + potential_offset; - matches &= matches - 1; - } - - // TODO: If the last character contains a bad byte, we can reposition the start of the next iteration. - // This will be very helpful for very long needles. - } - } - // If there are only 2 or 3 characters in the needle, we don't even need the nested loop. 
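- // For such short needles the three anomaly offsets cover every byte of the needle, so matching
- // the first, middle, and last characters is already a full match and no extra verification is
- // needed. Roughly, the per-position check is:
- //
- // if (h[i + offset_first] == n[offset_first] && h[i + offset_mid] == n[offset_mid] &&
- // h[i + offset_last] == n[offset_last]) return h + i;
- //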
- else if (n_length <= 3) { - for (; h_length >= n_length + 64; h += 64, h_length -= 64) { - h_first_vec.zmm = _mm512_loadu_si512(h + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - if (matches) return h + sz_u64_ctz(matches); - } - } - // If the needle is smaller than the size of the ZMM register, we can use masked comparisons - // to avoid the the inner-most nested loop and compare the entire needle against a haystack - // slice in 3 CPU cycles. - else { - __mmask64 n_mask = _sz_u64_mask_until(n_length); - sz_u512_vec_t n_full_vec, h_full_vec; - n_full_vec.zmm = _mm512_maskz_loadu_epi8(n_mask, n); - for (; h_length >= n_length + 64; h += 64, h_length -= 64) { - h_first_vec.zmm = _mm512_loadu_si512(h + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_ctz(matches); - h_full_vec.zmm = _mm512_maskz_loadu_epi8(n_mask, h + potential_offset); - if (_mm512_mask_cmpneq_epi8_mask(n_mask, h_full_vec.zmm, n_full_vec.zmm) == 0) - return h + potential_offset; - matches &= matches - 1; - } - } - } - - // The "tail" of the function uses masked loads to process the remaining bytes. - { - mask = _sz_u64_mask_until(h_length - n_length + 1); - h_first_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_first); - h_mid_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_mid); - h_last_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_ctz(matches); - if (n_length <= 3 || sz_equal_avx512(h + potential_offset, n, n_length)) return h + potential_offset; - matches &= matches - 1; - } - } - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - __mmask64 mask; - sz_u512_vec_t h_vec, n_vec; - n_vec.zmm = _mm512_set1_epi8(n[0]); - - while (h_length >= 64) { - h_vec.zmm = _mm512_loadu_si512(h + h_length - 64); - mask = _mm512_cmpeq_epi8_mask(h_vec.zmm, n_vec.zmm); - if (mask) return h + h_length - 1 - sz_u64_clz(mask); - h_length -= 64; - } - - if (h_length) { - mask = _sz_u64_mask_until(h_length); - h_vec.zmm = _mm512_maskz_loadu_epi8(mask, h); - // Reuse the same `mask` variable to find the bit that doesn't match - mask = _mm512_mask_cmpeq_epu8_mask(mask, h_vec.zmm, n_vec.zmm); - if (mask) return h + 64 - sz_u64_clz(mask) - 1; - } - - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. 
- if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_rfind_byte_avx512(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into ZMM registers. - __mmask64 mask; - __mmask64 matches; - sz_u512_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; - n_first_vec.zmm = _mm512_set1_epi8(n[offset_first]); - n_mid_vec.zmm = _mm512_set1_epi8(n[offset_mid]); - n_last_vec.zmm = _mm512_set1_epi8(n[offset_last]); - - // Scan through the string. - sz_cptr_t h_reversed; - for (; h_length >= n_length + 64; h_length -= 64) { - h_reversed = h + h_length - n_length - 64 + 1; - h_first_vec.zmm = _mm512_loadu_si512(h_reversed + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h_reversed + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h_reversed + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_clz(matches); - if (n_length <= 3 || sz_equal_avx512(h + h_length - n_length - potential_offset, n, n_length)) - return h + h_length - n_length - potential_offset; - sz_assert((matches & ((sz_u64_t)1 << (63 - potential_offset))) != 0 && - "The bit must be set before we squash it"); - matches &= ~((sz_u64_t)1 << (63 - potential_offset)); - } - } - - // The "tail" of the function uses masked loads to process the remaining bytes. - { - mask = _sz_u64_mask_until(h_length - n_length + 1); - h_first_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_first); - h_mid_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_mid); - h_last_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_clz(matches); - if (n_length <= 3 || sz_equal_avx512(h + 64 - potential_offset - 1, n, n_length)) - return h + 64 - potential_offset - 1; - sz_assert((matches & ((sz_u64_t)1 << (63 - potential_offset))) != 0 && - "The bit must be set before we squash it"); - matches &= ~((sz_u64_t)1 << (63 - potential_offset)); - } - } - - return SZ_NULL_CHAR; -} - -#pragma clang attribute pop -#pragma GCC pop_options - -#pragma GCC push_options -#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vbmi", "bmi", "bmi2") -#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512dq,avx512vbmi,bmi,bmi2"))), \ - apply_to = function) - -/** - * @brief Computes the edit distance between two very short byte-strings using the AVX-512VBMI extensions. - * - * Applies to string lengths up to 63, and evaluates at most (63 * 2 + 1 = 127) diagonals, or just as many loop cycles. - * Supports an early exit, if the distance is bounded. - * Keeps all of the data and Levenshtein matrices skew diagonal in just a couple of registers. - * Benefits from the @b `vpermb` instructions, that can rotate the bytes across the entire ZMM register. 
- */ -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto63_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound) { - - sz_size_t const max_length = 63u; - sz_assert(shorter_length <= longer_length && "The 'shorter' string is longer than the 'longer' one."); - sz_assert(shorter_length < max_length && "The length must fit into 16-bit integer. Otherwise use serial variant."); - - // We are going to store 3 diagonals of the matrix, assuming each would fit into a single ZMM register. - // The length of the longest (main) diagonal would be `shorter_dim = (shorter_length + 1)`. - sz_size_t const shorter_dim = shorter_length + 1; - sz_size_t const longer_dim = longer_length + 1; - - // The next few buffers will be swapped around. - sz_u512_vec_t previous_vec, current_vec, next_vec; - sz_u512_vec_t gaps_vec, substitutions_vec; - - // Load the strings into ZMM registers - just once. - sz_u512_vec_t longer_vec, shorter_vec, shorter_rotated_vec, rotate_left_vec, rotate_right_vec, ones_vec, bound_vec; - longer_vec.zmm = _mm512_maskz_loadu_epi8(_sz_u64_mask_until(longer_length), longer); - rotate_left_vec.zmm = _mm512_set_epi8( // - 0, 63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, // - 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, // - 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, // - 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1); - rotate_right_vec.zmm = _mm512_set_epi8( // - 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, // - 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, // - 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, // - 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 63); - ones_vec.zmm = _mm512_set1_epi8(1); - bound_vec.zmm = _mm512_set1_epi8(bound <= 255 ? (sz_u8_t)bound : 255); - - // To simplify comparisons and traversals, we want to reverse the order of bytes in the shorter string. - for (sz_size_t i = 0; i != shorter_length; ++i) shorter_vec.u8s[63 - i] = shorter[i]; - shorter_rotated_vec.zmm = _mm512_permutexvar_epi8(rotate_right_vec.zmm, shorter_vec.zmm); - - // Let's say we are dealing with 3 and 5 letter words. - // The matrix will have size 4 x 6, parameterized as (shorter_dim x longer_dim). - // It will have: - // - 4 diagonals of increasing length, at positions: 0, 1, 2, 3. - // - 2 diagonals of fixed length, at positions: 4, 5. - // - 3 diagonals of decreasing length, at positions: 6, 7, 8. - sz_size_t const diagonals_count = shorter_dim + longer_dim - 1; - - // Initialize the first two diagonals: - // - // previous_vec.u8s[0] = 0; - // current_vec.u8s[0] = current_vec.u8s[1] = 1; - // - // We can do a similar thing with vector ops: - previous_vec.zmm = _mm512_setzero_si512(); - current_vec.zmm = _mm512_set1_epi8(1); - - // We skip diagonals 0 and 1, as they are trivial. - // We will start with diagonal 2, which has length 3, with the first and last elements being preset, - // so we are effectively computing just one value, as will be marked by a single set bit in - // the `next_diagonal_mask` on the very first iteration. - sz_size_t next_diagonal_index = 2; - __mmask64 next_diagonal_mask = 0; - - // Progress through the upper triangle of the Levenshtein matrix. 
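- // Each new skewed diagonal is assembled from the previous two. Ignoring the register rotations,
- // which stand in for the `[i - 1]` shifts, the per-cell update is roughly:
- //
- // substitution = previous[i] + (longer[...] != shorter[...]);
- // gap = min(current[i - 1], current[i]) + 1; // insertion or deletion
- // next[i] = min(substitution, gap);
- //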
- for (; next_diagonal_index != shorter_dim; ++next_diagonal_index) { - // After this iteration, the values at offset `0` and `next_diagonal_index` in the `next_vec` - // should be set to `next_diagonal_index`, but it's easier to broadcast the value to the whole vector, - // and later merge with a mask with new values. - next_vec.zmm = _mm512_set1_epi8((sz_u8_t)next_diagonal_index); - - // The mask also adds one set bit. - next_diagonal_mask = _kor_mask64(next_diagonal_mask, 1); - next_diagonal_mask = _kshiftli_mask64(next_diagonal_mask, 1); - - // Check for equality between string slices. - __mmask64 conflict_mask = _mm512_cmpneq_epi8_mask(longer_vec.zmm, shorter_rotated_vec.zmm); - substitutions_vec.zmm = _mm512_mask_add_epi8(previous_vec.zmm, conflict_mask, previous_vec.zmm, ones_vec.zmm); - substitutions_vec.zmm = _mm512_permutexvar_epi8(rotate_right_vec.zmm, substitutions_vec.zmm); - gaps_vec.zmm = _mm512_add_epi8( - // Insertions or deletions - _mm512_min_epu8(_mm512_permutexvar_epi8(rotate_right_vec.zmm, current_vec.zmm), current_vec.zmm), - ones_vec.zmm); - next_vec.zmm = _mm512_mask_min_epu8(next_vec.zmm, next_diagonal_mask, gaps_vec.zmm, substitutions_vec.zmm); - - // Mark the current skewed diagonal as the previous one and the next one as the current one. - previous_vec.zmm = current_vec.zmm; - current_vec.zmm = next_vec.zmm; - - // Shift the shorter string - shorter_rotated_vec.zmm = _mm512_permutexvar_epi8(rotate_right_vec.zmm, shorter_rotated_vec.zmm); - - // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound. - __mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm); - if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) { // - return SZ_SIZE_MAX; - } - } - - // Now let's handle the anti-diagonal band of the matrix, between the top and bottom triangles. - for (; next_diagonal_index != longer_dim; ++next_diagonal_index) { - // After this iteration, the value `shorted_dim - 1` in the `next_vec` - // should be set to `next_diagonal_index`, but it's easier to broadcast the value to the whole vector, - // and later merge with a mask with new values. - next_vec.zmm = _mm512_set1_epi8((sz_u8_t)next_diagonal_index); - - // Make sure we update the first entry. - next_diagonal_mask = _kor_mask64(next_diagonal_mask, 1); - - // Check for equality between string slices. - __mmask64 conflict_mask = _mm512_cmpneq_epi8_mask(longer_vec.zmm, shorter_rotated_vec.zmm); - substitutions_vec.zmm = _mm512_mask_add_epi8(previous_vec.zmm, conflict_mask, previous_vec.zmm, ones_vec.zmm); - gaps_vec.zmm = _mm512_add_epi8( - // Insertions or deletions - _mm512_min_epu8(current_vec.zmm, _mm512_permutexvar_epi8(rotate_left_vec.zmm, current_vec.zmm)), - ones_vec.zmm); - next_vec.zmm = _mm512_mask_min_epu8(next_vec.zmm, next_diagonal_mask, gaps_vec.zmm, substitutions_vec.zmm); - - // Mark the current skewed diagonal as the previous one and the next one as the current one. - previous_vec.zmm = _mm512_permutexvar_epi8(rotate_left_vec.zmm, current_vec.zmm); - current_vec.zmm = next_vec.zmm; - - // Let's shift the longer string now. - longer_vec.zmm = _mm512_permutexvar_epi8(rotate_left_vec.zmm, longer_vec.zmm); - - // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound. 
- __mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm); - if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) { // - return SZ_SIZE_MAX; - } - } - - // Now let's handle the bottom right triangle. - for (; next_diagonal_index != diagonals_count; ++next_diagonal_index) { - - // Check for equality between string slices. - __mmask64 conflict_mask = _mm512_cmpneq_epi8_mask(longer_vec.zmm, shorter_rotated_vec.zmm); - substitutions_vec.zmm = _mm512_mask_add_epi8(previous_vec.zmm, conflict_mask, previous_vec.zmm, ones_vec.zmm); - gaps_vec.zmm = _mm512_add_epi8( - // Insertions or deletions - _mm512_min_epu8(current_vec.zmm, _mm512_permutexvar_epi8(rotate_left_vec.zmm, current_vec.zmm)), - ones_vec.zmm); - next_vec.zmm = _mm512_min_epu8(gaps_vec.zmm, substitutions_vec.zmm); - - // Mark the current skewed diagonal as the previous one and the next one as the current one. - previous_vec.zmm = _mm512_permutexvar_epi8(rotate_left_vec.zmm, current_vec.zmm); - current_vec.zmm = next_vec.zmm; - - // Let's shift the longer string now. - longer_vec.zmm = _mm512_permutexvar_epi8(rotate_left_vec.zmm, longer_vec.zmm); - - // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound. - __mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm); - if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) { // - return SZ_SIZE_MAX; - } - // In every following iterations we take use a shorter prefix of each register, - // but we don't need to update the `next_diagonal_mask` anymore... except for the early exit. - next_diagonal_mask = _kshiftri_mask64(next_diagonal_mask, 1); - } - return current_vec.u8s[0]; -} - -/** - * @brief Computes the edit distance between two somewhat short bytes-strings using the AVX-512VBMI extensions. - * - * Applies to string lengths up to 127, and evaluates at most (127 * 2 + 1 = 255) diagonals. - * Supports an early exit, if the distance is bounded. - * Uses a lot more CPU registers space, than the `upto63` variant. - * Benefits from the @b `vpermi2b` instructions, that can rotate the bytes in 2 registers at once. - * - * This may be one of the most freuqently called kernels for: - * - source code analysis, assuming most lines are either under 80 or under 120 characters long. - * - DNA sequence alignment, as most short reads are 50-300 characters long. - */ -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto127_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound) { - sz_unused(shorter && shorter_length && longer && longer_length && bound); - return 0; -} - -/** - * @brief Computes the edit distance between two longer bytes-strings using the AVX-512VBMI extensions. - * - * Applies to string lengths up to 255, and evaluates at most (255 * 2 + 1 = 511) diagonals. - * Supports an early exit, if the distance is bounded. - * Uses a lot more CPU registers space, than the `upto63` variant. - * - * Each of 2x string ends up occupying 4 ZMM registers, and each of 3x diagonals uses 4 ZMM registers. - * So 20x of the 32x are persistently occupied, and the rest are used for math temporarily. - * This is the largest space-efficient variant, as strings beyond 255 characters may require - * 16-bit accumulators, which would be a significant bottleneck. 
- */
-SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto_avx512( //
- sz_cptr_t shorter, sz_size_t shorter_length, //
- sz_cptr_t longer, sz_size_t longer_length, //
- sz_size_t bound) {
- sz_unused(shorter && shorter_length && longer && longer_length && bound);
- return 0;
-}
-
-/**
- * @brief Computes the edit distance between two longer byte-strings using the AVX-512VBMI extensions,
- * assuming the upper distance bound cannot exceed 255, but the string length can be arbitrary.
- *
- * Applies to string lengths up to 255, and evaluates at most (255 * 2 + 1 = 511) diagonals.
- * Supports an early exit, if the distance is bounded.
- * Uses a lot more CPU register space than the `upto63` variant.
- *
- * Each of the 2x strings ends up occupying 4 ZMM registers, and each of 3x diagonals uses 4 ZMM registers.
- * So 20x of the 32x are persistently occupied, and the rest are used for math temporarily.
- * This is the largest space-efficient variant, as strings beyond 255 characters may require
- * 16-bit accumulators, which would be a significant bottleneck.
- */
-SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto255bound_avx512( //
- sz_cptr_t shorter, sz_size_t shorter_length, //
- sz_cptr_t longer, sz_size_t longer_length, //
- sz_size_t bound) {
- sz_unused(shorter && shorter_length && longer && longer_length && bound);
- return 0;
-}
-
-/**
- * @brief Computes the edit distance between two mid-length UTF-8-strings using the AVX-512VBMI extensions.
- *
- * Applies to string lengths up to 127, and evaluates at most (127 * 2 + 1 = 255) diagonals.
- * Supports an early exit, if the distance is bounded.
- * Benefits from the @b `valignd` instructions used to rotate UTF-32 unpacked Unicode codepoints.
- *
- * Each string is unpacked into 128 characters * 4 bytes per character / 64 bytes per register = 8 registers.
- *
- */
-SZ_INTERNAL sz_size_t _sz_edit_distance_utf8_skewed_diagonals_upto127_avx512( //
- sz_cptr_t shorter, sz_size_t shorter_length, //
- sz_cptr_t longer, sz_size_t longer_length, //
- sz_size_t bound) {
- sz_unused(shorter && shorter_length && longer && longer_length && bound);
- return 0;
-}
-
-SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto65k_avx512( //
- sz_cptr_t shorter, sz_size_t shorter_length, //
- sz_cptr_t longer, sz_size_t longer_length, //
- sz_size_t bound, sz_memory_allocator_t *alloc) {
-
- sz_unused(shorter && longer && bound && alloc);
-
- // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome.
- sz_memory_allocator_t global_alloc;
- if (!alloc) {
- sz_memory_allocator_init_default(&global_alloc);
- alloc = &global_alloc;
- }
-
- // TODO: Generalize!
- sz_size_t const max_length = 256u * 256u;
- sz_assert(shorter_length <= longer_length && "The 'shorter' string is longer than the 'longer' one.");
- sz_assert(shorter_length < max_length && "The length must fit into 16-bit integer. Otherwise use serial variant.");
- sz_unused(longer_length && bound && max_length);
-
-#if 0
- // We are going to store 3 diagonals of the matrix.
- // The length of the longest (main) diagonal would be `shorter_dim = (shorter_length + 1)`.
- sz_size_t const shorter_dim = shorter_length + 1;
- sz_size_t const longer_dim = longer_length + 1;
- // Unlike the serial version, we also want to avoid reverse-order iteration over the shorter string.
- // So let's allocate a bit more memory and reverse-export our shorter string into that buffer.
- sz_size_t const buffer_length = sizeof(sz_u16_t) * longer_dim * 3 + shorter_length; - sz_u16_t *const distances = (sz_u16_t *)alloc->allocate(buffer_length, alloc->handle); - if (!distances) return SZ_SIZE_MAX; - - // The next few pointers will be swapped around. - sz_u16_t *previous_distances = distances; - sz_u16_t *current_distances = previous_distances + longer_dim; - sz_u16_t *next_distances = current_distances + longer_dim; - sz_ptr_t const shorter_reversed = (sz_ptr_t)(next_distances + longer_dim); - - // Export the reversed string into the buffer. - for (sz_size_t i = 0; i != shorter_length; ++i) shorter_reversed[i] = shorter[shorter_length - 1 - i]; - - // Initialize the first two diagonals: - previous_distances[0] = 0; - current_distances[0] = current_distances[1] = 1; - - // Using ZMM registers, we can process 32x 16-bit values at once, - // storing 16 bytes of each string in YMM registers. - sz_u512_vec_t insertions_vec, deletions_vec, substitutions_vec, next_vec; - sz_u512_vec_t ones_u16_vec; - ones_u16_vec.zmm = _mm512_set1_epi16(1); - - // This is a mixed-precision implementation, using 8-bit representations for part of the operations. - // Even there, in case `SZ_USE_X86_AVX2=0`, let's use the `sz_u512_vec_t` type, addressing the first YMM halfs. - sz_u512_vec_t shorter_vec, longer_vec; - sz_u512_vec_t ones_u8_vec; - ones_u8_vec.ymms[0] = _mm256_set1_epi8(1); - - // Let's say we are dealing with 3 and 5 letter words. - // The matrix will have size 4 x 6, parameterized as (shorter_dim x longer_dim). - // It will have: - // - 4 diagonals of increasing length, at positions: 0, 1, 2, 3. - // - 2 diagonals of fixed length, at positions: 4, 5. - // - 3 diagonals of decreasing length, at positions: 6, 7, 8. - sz_size_t const diagonals_count = shorter_dim + longer_dim - 1; - - // Progress through the upper triangle of the Levenshtein matrix. - sz_size_t next_diagonal_index = 2; - for (; next_diagonal_index != shorter_dim; ++next_diagonal_index) { - sz_size_t const next_diagonal_length = next_diagonal_index + 1; - for (sz_size_t offset_within_diagonal = 0; offset_within_diagonal + 2 < next_diagonal_length;) { - sz_u32_t remaining_length = (sz_u32_t)(next_diagonal_length - offset_within_diagonal - 2); - sz_u32_t register_length = remaining_length < 32 ? remaining_length : 32; - sz_u32_t remaining_length_mask = _bzhi_u32(0xFFFFFFFFu, register_length); - longer_vec.ymms[0] = _mm256_maskz_loadu_epi8(remaining_length_mask, longer + offset_within_diagonal); - // Our original code addressed the shorter string `[next_diagonal_index - offset_within_diagonal - 2]` - // for growing `offset_within_diagonal`. If the `shorter` string was reversed, the - // `[next_diagonal_index - offset_within_diagonal - 2]` would be equal to `[shorter_length - 1 - - // next_diagonal_index + offset_within_diagonal + 2]`. Which simplified would be equal to - // `[shorter_length - next_diagonal_index + offset_within_diagonal + 1]`. - shorter_vec.ymms[0] = _mm256_maskz_loadu_epi8( // - remaining_length_mask, - shorter_reversed + shorter_length - next_diagonal_index + offset_within_diagonal + 1); - // For substitutions, perform the equality comparison using AVX2 instead of AVX-512 - // to get the result as a vector, instead of a bitmask. Adding 1 to every scalar we can overflow - // transforming from {0xFF, 0} values to {0, 1} values - exactly what we need. Then - upcast to 16-bit. 
- substitutions_vec.zmm = _mm512_cvtepi8_epi16( // - _mm256_add_epi8(_mm256_cmpeq_epi8(longer_vec.ymms[0], shorter_vec.ymms[0]), ones_u8_vec.ymms[0])); - substitutions_vec.zmm = _mm512_add_epi16( // - substitutions_vec.zmm, - _mm512_maskz_loadu_epi16(remaining_length_mask, previous_distances + offset_within_diagonal)); - // For insertions and deletions, on modern hardware, it's faster to issue two separate loads, - // than rotate the bytes in the ZMM register. - insertions_vec.zmm = - _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + offset_within_diagonal); - deletions_vec.zmm = - _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + offset_within_diagonal + 1); - // First get the minimum of insertions and deletions. - next_vec.zmm = _mm512_add_epi16(_mm512_min_epu16(insertions_vec.zmm, deletions_vec.zmm), ones_u16_vec.zmm); - next_vec.zmm = _mm512_min_epu16(next_vec.zmm, substitutions_vec.zmm); - _mm512_mask_storeu_epi16(next_distances + offset_within_diagonal + 1, remaining_length_mask, next_vec.zmm); - offset_within_diagonal += register_length; - } - // Don't forget to populate the first row and the first column of the Levenshtein matrix. - next_distances[0] = next_distances[next_diagonal_length - 1] = (sz_u16_t)next_diagonal_index; - // Perform a circular rotation (three-way swap) of those buffers, to reuse the memory. - sz_u16_t *temporary = previous_distances; - previous_distances = current_distances; - current_distances = next_distances; - next_distances = temporary; - } - - // By now we've scanned through the upper triangle of the matrix, where each subsequent iteration results in a - // larger diagonal. From now onwards, we will be shrinking. Instead of adding value equal to the skewed diagonal - // index on either side, we will be cropping those values out. - for (; next_diagonal_index != diagonals_count; ++next_diagonal_index) { - sz_size_t const next_diagonal_length = diagonals_count - next_diagonal_index; - for (sz_size_t i = 0; i != next_diagonal_length;) { - sz_u32_t remaining_length = (sz_u32_t)(next_diagonal_length - i); - sz_u32_t register_length = remaining_length < 32 ? remaining_length : 32; - sz_u32_t remaining_length_mask = _bzhi_u32(0xFFFFFFFFu, register_length); - longer_vec.ymms[0] = _mm256_maskz_loadu_epi8(remaining_length_mask, longer + next_diagonal_index - n + i); - // Our original code addressed the shorter string `[shorter_length - 1 - i]` for growing `i`. - // If the `shorter` string was reversed, the `[shorter_length - 1 - i]` would - // be equal to `[shorter_length - 1 - shorter_length + 1 + i]`. - // Which simplified would be equal to just `[i]`. Beautiful! - shorter_vec.ymms[0] = _mm256_maskz_loadu_epi8(remaining_length_mask, shorter_reversed + i); - // For substitutions, perform the equality comparison using AVX2 instead of AVX-512 - // to get the result as a vector, instead of a bitmask. The compare it against the accumulated - // substitution costs. - substitutions_vec.zmm = _mm512_cvtepi8_epi16( // - _mm256_add_epi8(_mm256_cmpeq_epi8(longer_vec.ymms[0], shorter_vec.ymms[0]), ones_u8_vec.ymms[0])); - substitutions_vec.zmm = _mm512_add_epi16( // - substitutions_vec.zmm, _mm512_maskz_loadu_epi16(remaining_length_mask, previous_distances + i)); - // For insertions and deletions, on modern hardware, it's faster to issue two separate loads, - // than rotate the bytes in the ZMM register. 
- insertions_vec.zmm = _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + i); - deletions_vec.zmm = _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + i + 1); - // First get the minimum of insertions and deletions. - next_vec.zmm = _mm512_add_epi16(_mm512_min_epu16(insertions_vec.zmm, deletions_vec.zmm), ones_u16_vec.zmm); - next_vec.zmm = _mm512_min_epu16(next_vec.zmm, substitutions_vec.zmm); - _mm512_mask_storeu_epi16(next_distances + i, remaining_length_mask, next_vec.zmm); - i += register_length; - } - - // Perform a circular rotation (three-way swap) of those buffers, to reuse the memory, this time, with a shift, - // dropping the first element in the current array. - sz_u16_t *temporary = previous_distances; - previous_distances = current_distances + 1; - current_distances = next_distances; - next_distances = temporary; - } - - // Cache scalar before `free` call. - sz_size_t result = current_distances[0]; - alloc->free(distances, buffer_length, alloc->handle); - return result; -#endif - return 0; -} - -SZ_INTERNAL sz_size_t sz_edit_distance_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - - // Bounded computations may exit early. - int const is_bounded = bound < longer_length; - if (is_bounded) { - // If one of the strings is empty - the edit distance is equal to the length of the other one. - if (longer_length == 0) return sz_min_of_two(shorter_length, bound); - if (shorter_length == 0) return sz_min_of_two(longer_length, bound); - // If the difference in length is beyond the `bound`, there is no need to check at all. - if (longer_length - shorter_length > bound) return bound; - } - - // Make sure the shorter string is actually shorter. - if (shorter_length > longer_length) { - sz_cptr_t temporary = shorter; - shorter = longer; - longer = temporary; - sz_size_t temporary_length = shorter_length; - shorter_length = longer_length; - longer_length = temporary_length; - } - - // Dispatch the right implementation based on the length of the strings. - if (longer_length < 64u) - return _sz_edit_distance_skewed_diagonals_upto63_avx512( // - shorter, shorter_length, longer, longer_length, bound); - // else if (longer_length < 256u * 256u) - // return _sz_edit_distance_skewed_diagonals_upto65k_avx512( // - // shorter, shorter_length, longer, longer_length, bound, alloc); - else - return sz_edit_distance_serial(shorter, shorter_length, longer, longer_length, bound, alloc); -} - -SZ_PUBLIC sz_u64_t sz_checksum_avx512(sz_cptr_t text, sz_size_t length) { - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "loads". - // - // A typical AWS Sapphire Rapids instance can have 48 KB x 2 blocks of L1 data cache per core, - // 2 MB x 2 blocks of L2 cache per core, and one shared 60 MB buffer of L3 cache. - // With two strings, we may consider the overal workload huge, if each exceeds 1 MB in length. - int const is_huge = length >= 1ull * 1024ull * 1024ull; - sz_u512_vec_t text_vec, sums_vec; - - // When the buffer is small, there isn't much to innovate. 
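- // Every branch below computes the same value as this scalar loop, only wider:
- //
- // sz_u64_t sum = 0;
- // for (sz_size_t i = 0; i != length; ++i) sum += (sz_u8_t)text[i];
- // return sum;
- //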
- if (length <= 16) {
- __mmask16 mask = _sz_u16_mask_until(length);
- text_vec.xmms[0] = _mm_maskz_loadu_epi8(mask, text);
- sums_vec.xmms[0] = _mm_sad_epu8(text_vec.xmms[0], _mm_setzero_si128());
- sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_vec.xmms[0]);
- sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_vec.xmms[0], 1);
- return low + high;
- }
- else if (length <= 32) {
- __mmask32 mask = _sz_u32_mask_until(length);
- text_vec.ymms[0] = _mm256_maskz_loadu_epi8(mask, text);
- sums_vec.ymms[0] = _mm256_sad_epu8(text_vec.ymms[0], _mm256_setzero_si256());
- // Accumulating 256 bits is harder, as we need to extract the 128-bit sums first.
- __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymms[0]);
- __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymms[0], 1);
- __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm);
- sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm);
- sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1);
- return low + high;
- }
- else if (length <= 64) {
- __mmask64 mask = _sz_u64_mask_until(length);
- text_vec.zmm = _mm512_maskz_loadu_epi8(mask, text);
- sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512());
- return _mm512_reduce_add_epi64(sums_vec.zmm);
- }
- else if (!is_huge) {
- sz_size_t head_length = (64 - ((sz_size_t)text % 64)) % 64; // 63 or less.
- sz_size_t tail_length = (sz_size_t)(text + length) % 64; // 63 or less.
- sz_size_t body_length = length - head_length - tail_length; // Multiple of 64.
- __mmask64 head_mask = _sz_u64_mask_until(head_length);
- __mmask64 tail_mask = _sz_u64_mask_until(tail_length);
- text_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, text);
- sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512());
- for (text += head_length; body_length >= 64; text += 64, body_length -= 64) {
- text_vec.zmm = _mm512_load_si512((__m512i const *)text);
- sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()));
- }
- text_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, text);
- sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()));
- return _mm512_reduce_add_epi64(sums_vec.zmm);
- }
- // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use.
- //
- // 1. Moving in both directions to maximize the throughput, when fetching from multiple
- // memory pages. Also helps with cache set-associativity issues, as we won't always
- // be fetching the same entries in the lookup table.
- // 2. Using non-temporal stores to avoid polluting the cache.
- // 3. Prefetching the next cache line, to avoid stalling the CPU. This is generally useless
- // for predictable patterns, so disregard this advice.
- //
- // Bidirectional traversal generally adds about 10% to such algorithms.
- else {
- sz_u512_vec_t text_reversed_vec, sums_reversed_vec;
- sz_size_t head_length = (64 - ((sz_size_t)text % 64)) % 64;
- sz_size_t tail_length = (sz_size_t)(text + length) % 64;
- sz_size_t body_length = length - head_length - tail_length;
- __mmask64 head_mask = _sz_u64_mask_until(head_length);
- __mmask64 tail_mask = _sz_u64_mask_until(tail_length);
-
- text_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, text);
- sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512());
- text_reversed_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, text + head_length + body_length);
- sums_reversed_vec.zmm = _mm512_sad_epu8(text_reversed_vec.zmm, _mm512_setzero_si512());
-
- // Now in the main loop, we can use non-temporal loads,
- // performing the accumulation in both directions.
- for (text += head_length; body_length >= 128; text += 64, body_length -= 128) {
- text_vec.zmm = _mm512_stream_load_si512((__m512i *)(text));
- sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()));
- text_reversed_vec.zmm = _mm512_stream_load_si512((__m512i *)(text + body_length - 64));
- sums_reversed_vec.zmm =
- _mm512_add_epi64(sums_reversed_vec.zmm, _mm512_sad_epu8(text_reversed_vec.zmm, _mm512_setzero_si512()));
- }
- if (body_length >= 64) {
- text_vec.zmm = _mm512_stream_load_si512((__m512i *)(text));
- sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()));
- }
-
- return _mm512_reduce_add_epi64(_mm512_add_epi64(sums_vec.zmm, sums_reversed_vec.zmm));
- }
-}
-
-SZ_PUBLIC void sz_hashes_avx512(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, //
- sz_hash_callback_t callback, void *callback_handle) {
-
- if (length < window_length || !window_length) return;
- if (length < 4 * window_length) {
- sz_hashes_serial(start, length, window_length, step, callback, callback_handle);
- return;
- }
-
- // Using AVX2, we can perform 4 long integer multiplications and additions within one register.
- // So let's slice the entire string into 4 overlapping windows, to slide over them in parallel.
- sz_size_t const max_hashes = length - window_length + 1;
- sz_size_t const min_hashes_per_thread = max_hashes / 4; // At most one sequence can overlap between 2 threads.
- sz_u8_t const *text_first = (sz_u8_t const *)start;
- sz_u8_t const *text_second = text_first + min_hashes_per_thread;
- sz_u8_t const *text_third = text_first + min_hashes_per_thread * 2;
- sz_u8_t const *text_fourth = text_first + min_hashes_per_thread * 3;
- sz_u8_t const *text_end = text_first + length;
-
- // Broadcast the global constants into the registers.
- // Both high and low hashes will work with the same prime and golden ratio.
- sz_u512_vec_t prime_vec, golden_ratio_vec;
- prime_vec.zmm = _mm512_set1_epi64(SZ_U64_MAX_PRIME);
- golden_ratio_vec.zmm = _mm512_set1_epi64(11400714819323198485ull);
-
- // Prepare the `base ^ (window_length - 1)` values, that we are going to use for the rolling modulo arithmetic.
- sz_u64_t prime_power_low = 1, prime_power_high = 1;
- for (sz_size_t i = 0; i + 1 < window_length; ++i)
- prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME,
- prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME;
-
- // We will be evaluating 4 offsets at a time with 2 different hash functions.
- // We can fit all those 8 state variables in each of the following ZMM registers.
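- // Per 64-bit lane, the steady-state update in the loops below follows the classic rolling-hash
- // recurrence; with illustrative scalar names, roughly:
- //
- // hash -= (outgoing + shift) * base_to_window; // base_to_window = base ^ (window_length - 1) % prime
- // hash = hash * base + (incoming + shift);
- // if (hash > prime) hash -= prime; // cheap substitute for a full modulo
- //
- // where base is 31 or 257 and shift is 0 or 77, depending on the lane.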
- sz_u512_vec_t base_vec, prime_power_vec, shift_vec; - base_vec.zmm = _mm512_set_epi64(31ull, 31ull, 31ull, 31ull, 257ull, 257ull, 257ull, 257ull); - shift_vec.zmm = _mm512_set_epi64(0ull, 0ull, 0ull, 0ull, 77ull, 77ull, 77ull, 77ull); - prime_power_vec.zmm = _mm512_set_epi64(prime_power_low, prime_power_low, prime_power_low, prime_power_low, - prime_power_high, prime_power_high, prime_power_high, prime_power_high); - - // Compute the initial hash values for every one of the four windows. - sz_u512_vec_t hash_vec, chars_vec; - hash_vec.zmm = _mm512_setzero_si512(); - for (sz_u8_t const *prefix_end = text_first + window_length; text_first < prefix_end; - ++text_first, ++text_second, ++text_third, ++text_fourth) { - - // 1. Multiply the hashes by the base. - hash_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, base_vec.zmm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`... - chars_vec.zmm = _mm512_set_epi64(text_fourth[0], text_third[0], text_second[0], text_first[0], // - text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm); - - // 3. Add the incoming characters. - hash_vec.zmm = _mm512_add_epi64(hash_vec.zmm, chars_vec.zmm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_vec.zmm = _mm512_mask_blend_epi8(_mm512_cmpgt_epi64_mask(hash_vec.zmm, prime_vec.zmm), hash_vec.zmm, - _mm512_sub_epi64(hash_vec.zmm, prime_vec.zmm)); - } - - // 5. Compute the hash mix, that will be used to index into the fingerprint. - // This includes a serial step at the end. - sz_u512_vec_t hash_mix_vec; - hash_mix_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, golden_ratio_vec.zmm); - hash_mix_vec.ymms[0] = _mm256_xor_si256(_mm512_extracti64x4_epi64(hash_mix_vec.zmm, 1), // - _mm512_extracti64x4_epi64(hash_mix_vec.zmm, 0)); - - callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); - callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); - callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); - callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); - - // Now repeat that operation for the remaining characters, discarding older characters. - sz_size_t cycle = 1; - sz_size_t step_mask = step - 1; - for (; text_fourth != text_end; ++text_first, ++text_second, ++text_third, ++text_fourth, ++cycle) { - // 0. Load again the four characters we are dropping, shift them, and subtract. - chars_vec.zmm = _mm512_set_epi64(text_fourth[-window_length], text_third[-window_length], - text_second[-window_length], text_first[-window_length], // - text_fourth[-window_length], text_third[-window_length], - text_second[-window_length], text_first[-window_length]); - chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm); - hash_vec.zmm = _mm512_sub_epi64(hash_vec.zmm, _mm512_mullo_epi64(chars_vec.zmm, prime_power_vec.zmm)); - - // 1. Multiply the hashes by the base. - hash_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, base_vec.zmm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. 
- chars_vec.zmm = _mm512_set_epi64(text_fourth[0], text_third[0], text_second[0], text_first[0], // - text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm); - - // ... and prefetch the next four characters into Level 2 or higher. - _mm_prefetch((sz_cptr_t)text_fourth + 1, _MM_HINT_T1); - _mm_prefetch((sz_cptr_t)text_third + 1, _MM_HINT_T1); - _mm_prefetch((sz_cptr_t)text_second + 1, _MM_HINT_T1); - _mm_prefetch((sz_cptr_t)text_first + 1, _MM_HINT_T1); - - // 3. Add the incoming characters. - hash_vec.zmm = _mm512_add_epi64(hash_vec.zmm, chars_vec.zmm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_vec.zmm = _mm512_mask_blend_epi8(_mm512_cmpgt_epi64_mask(hash_vec.zmm, prime_vec.zmm), hash_vec.zmm, - _mm512_sub_epi64(hash_vec.zmm, prime_vec.zmm)); - - // 5. Compute the hash mix, that will be used to index into the fingerprint. - // This includes a serial step at the end. - hash_mix_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, golden_ratio_vec.zmm); - hash_mix_vec.ymms[0] = _mm256_xor_si256(_mm512_extracti64x4_epi64(hash_mix_vec.zmm, 1), // - _mm512_castsi512_si256(hash_mix_vec.zmm)); - - if ((cycle & step_mask) == 0) { - callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); - callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); - callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); - callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); - } - } -} - -#pragma clang attribute pop -#pragma GCC pop_options - -#pragma GCC push_options -#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512vbmi", "avx512vbmi2", "bmi", "bmi2") -#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512vbmi,avx512vbmi2,bmi,bmi2"))), \ - apply_to = function) - -SZ_PUBLIC void sz_look_up_transform_avx512(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { - - // If the input is tiny (especially smaller than the look-up table itself), we may end up paying - // more for organizing the SIMD registers and changing the CPU state, than for the actual computation. - // But if at least 3 cache lines are touched, the AVX-512 implementation should be faster. - if (length <= 128) { - sz_look_up_transform_serial(source, length, lut, target); - return; - } - - // When the buffer is over 64 bytes, it's guaranteed to touch at least two cache lines - the head and tail, - // and may include more cache-lines in-between. Knowing this, we can avoid expensive unaligned stores - // by computing 2 masks - for the head and tail, using masked stores for the head and tail, and unmasked - // for the body. - sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - - // We need to pull the lookup table into 4x ZMM registers. - // We can use `vpermi2b` instruction to perform the look in two ZMM registers with `_mm512_permutex2var_epi8` - // intrinsics, but it has a 6-cycle latency on Sapphire Rapids and requires AVX512-VBMI. 
Assuming we need to - // operate on 4 registers, it might be cleaner to use 2x separate `_mm512_permutexvar_epi8` calls. - // Combining the results with 2x `_mm512_test_epi8_mask` and 3x blends afterwards. - // - // - 4x `_mm512_permutexvar_epi8` maps to "VPERMB (ZMM, ZMM, ZMM)": - // - On Ice Lake: 3 cycles latency, ports: 1*p5 - // - On Genoa: 6 cycles latency, ports: 1*FP12 - // - 3x `_mm512_mask_blend_epi8` maps to "VPBLENDMB_Z (ZMM, K, ZMM, ZMM)": - // - On Ice Lake: 3 cycles latency, ports: 1*p05 - // - On Genoa: 1 cycle latency, ports: 1*FP0123 - // - 2x `_mm512_test_epi8_mask` maps to "VPTESTMB (K, ZMM, ZMM)": - // - On Ice Lake: 3 cycles latency, ports: 1*p5 - // - On Genoa: 4 cycles latency, ports: 1*FP01 - // - sz_u512_vec_t lut_0_to_63_vec, lut_64_to_127_vec, lut_128_to_191_vec, lut_192_to_255_vec; - lut_0_to_63_vec.zmm = _mm512_loadu_si512((lut)); - lut_64_to_127_vec.zmm = _mm512_loadu_si512((lut + 64)); - lut_128_to_191_vec.zmm = _mm512_loadu_si512((lut + 128)); - lut_192_to_255_vec.zmm = _mm512_loadu_si512((lut + 192)); - - sz_u512_vec_t first_bit_vec, second_bit_vec; - first_bit_vec.zmm = _mm512_set1_epi8((char)0x80); - second_bit_vec.zmm = _mm512_set1_epi8((char)0x40); - - __mmask64 first_bit_mask, second_bit_mask; - sz_u512_vec_t source_vec; - // If the top bit is set in each word of `source_vec`, than we use `lookup_128_to_191_vec` or - // `lookup_192_to_255_vec`. If the second bit is set, we use `lookup_64_to_127_vec` or `lookup_192_to_255_vec`. - sz_u512_vec_t lookup_0_to_63_vec, lookup_64_to_127_vec, lookup_128_to_191_vec, lookup_192_to_255_vec; - sz_u512_vec_t blended_0_to_127_vec, blended_128_to_255_vec, blended_0_to_255_vec; - - // Handling the head. - if (head_length) { - source_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, source); - lookup_0_to_63_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_0_to_63_vec.zmm); - lookup_64_to_127_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_64_to_127_vec.zmm); - lookup_128_to_191_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_128_to_191_vec.zmm); - lookup_192_to_255_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_192_to_255_vec.zmm); - first_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, first_bit_vec.zmm); - second_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, second_bit_vec.zmm); - blended_0_to_127_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_0_to_63_vec.zmm, lookup_64_to_127_vec.zmm); - blended_128_to_255_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_128_to_191_vec.zmm, lookup_192_to_255_vec.zmm); - blended_0_to_255_vec.zmm = - _mm512_mask_blend_epi8(first_bit_mask, blended_0_to_127_vec.zmm, blended_128_to_255_vec.zmm); - _mm512_mask_storeu_epi8(target, head_mask, blended_0_to_255_vec.zmm); - source += head_length, target += head_length, length -= head_length; - } - - // Handling the body in 64-byte chunks aligned to cache-line boundaries with respect to `target`. 
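- // Per byte, the permutes and blends below implement a plain 256-entry table lookup; with the four
- // 64-byte quarters of the table loaded above, the scalar equivalent is roughly:
- //
- // sz_u8_t byte = source[i];
- // sz_u8_t const *quarter = byte & 0x80 ? (byte & 0x40 ? lut + 192 : lut + 128)
- // : (byte & 0x40 ? lut + 64 : lut);
- // target[i] = quarter[byte & 0x3F];
- //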
- while (length >= 64) { - source_vec.zmm = _mm512_loadu_si512(source); - lookup_0_to_63_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_0_to_63_vec.zmm); - lookup_64_to_127_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_64_to_127_vec.zmm); - lookup_128_to_191_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_128_to_191_vec.zmm); - lookup_192_to_255_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_192_to_255_vec.zmm); - first_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, first_bit_vec.zmm); - second_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, second_bit_vec.zmm); - blended_0_to_127_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_0_to_63_vec.zmm, lookup_64_to_127_vec.zmm); - blended_128_to_255_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_128_to_191_vec.zmm, lookup_192_to_255_vec.zmm); - blended_0_to_255_vec.zmm = - _mm512_mask_blend_epi8(first_bit_mask, blended_0_to_127_vec.zmm, blended_128_to_255_vec.zmm); - _mm512_store_si512(target, blended_0_to_255_vec.zmm); //! Aligned store, our main weapon! - source += 64, target += 64, length -= 64; - } - - // Handling the tail. - if (tail_length) { - source_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, source); - lookup_0_to_63_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_0_to_63_vec.zmm); - lookup_64_to_127_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_64_to_127_vec.zmm); - lookup_128_to_191_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_128_to_191_vec.zmm); - lookup_192_to_255_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_192_to_255_vec.zmm); - first_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, first_bit_vec.zmm); - second_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, second_bit_vec.zmm); - blended_0_to_127_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_0_to_63_vec.zmm, lookup_64_to_127_vec.zmm); - blended_128_to_255_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_128_to_191_vec.zmm, lookup_192_to_255_vec.zmm); - blended_0_to_255_vec.zmm = - _mm512_mask_blend_epi8(first_bit_mask, blended_0_to_127_vec.zmm, blended_128_to_255_vec.zmm); - _mm512_mask_storeu_epi8(target, tail_mask, blended_0_to_255_vec.zmm); - source += tail_length, target += tail_length, length -= tail_length; - } -} - -SZ_PUBLIC sz_cptr_t sz_find_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { - - // Before initializing the AVX-512 vectors, we may want to run the sequential code for the first few bytes. - // In practice, that only hurts, even when we have matches every 5-ish bytes. - // - // if (length < SZ_SWAR_THRESHOLD) return sz_find_charset_serial(text, length, filter); - // sz_cptr_t early_result = sz_find_charset_serial(text, SZ_SWAR_THRESHOLD, filter); - // if (early_result) return early_result; - // text += SZ_SWAR_THRESHOLD; - // length -= SZ_SWAR_THRESHOLD; - // - // Let's unzip even and odd elements and replicate them into both lanes of the YMM register. - // That way when we invoke `_mm512_shuffle_epi8` we can use the same mask for both lanes. - sz_u512_vec_t filter_even_vec, filter_odd_vec; - __m256i filter_ymm = _mm256_lddqu_si256((__m256i const *)filter); - // There are a few way to initialize filters without having native strided loads. 
- // In the cronological order of experiments: - // - serial code initializing 128 bytes of odd and even mask - // - using several shuffles - // - using `_mm512_permutexvar_epi8` - // - using `_mm512_broadcast_i32x4(_mm256_castsi256_si128(_mm256_maskz_compress_epi8(0x55555555, filter_ymm)))` - // and `_mm512_broadcast_i32x4(_mm256_castsi256_si128(_mm256_maskz_compress_epi8(0xaaaaaaaa, filter_ymm)))` - filter_even_vec.zmm = _mm512_broadcast_i32x4(_mm256_castsi256_si128( // broadcast __m128i to __m512i - _mm256_maskz_compress_epi8(0x55555555, filter_ymm))); - filter_odd_vec.zmm = _mm512_broadcast_i32x4(_mm256_castsi256_si128( // broadcast __m128i to __m512i - _mm256_maskz_compress_epi8(0xaaaaaaaa, filter_ymm))); - // After the unzipping operation, we can validate the contents of the vectors like this: - // - // for (sz_size_t i = 0; i != 16; ++i) { - // sz_assert(filter_even_vec.u8s[i] == filter->_u8s[i * 2]); - // sz_assert(filter_odd_vec.u8s[i] == filter->_u8s[i * 2 + 1]); - // sz_assert(filter_even_vec.u8s[i + 16] == filter->_u8s[i * 2]); - // sz_assert(filter_odd_vec.u8s[i + 16] == filter->_u8s[i * 2 + 1]); - // sz_assert(filter_even_vec.u8s[i + 32] == filter->_u8s[i * 2]); - // sz_assert(filter_odd_vec.u8s[i + 32] == filter->_u8s[i * 2 + 1]); - // sz_assert(filter_even_vec.u8s[i + 48] == filter->_u8s[i * 2]); - // sz_assert(filter_odd_vec.u8s[i + 48] == filter->_u8s[i * 2 + 1]); - // } - // - sz_u512_vec_t text_vec; - sz_u512_vec_t lower_nibbles_vec, higher_nibbles_vec; - sz_u512_vec_t bitset_even_vec, bitset_odd_vec; - sz_u512_vec_t bitmask_vec, bitmask_lookup_vec; - bitmask_lookup_vec.zmm = _mm512_set_epi8( // - -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // - -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // - -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // - -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1); - - while (length) { - // The following algorithm is a transposed equivalent of the "SIMDized check which bytes are in a set" - // solutions by Wojciech Muła. We populate the bitmask differently and target newer CPUs, so - // StrinZilla uses a somewhat different approach. - // http://0x80.pl/articles/simd-byte-lookup.html#alternative-implementation-new - // - // sz_u8_t input = *(sz_u8_t const *)text; - // sz_u8_t lo_nibble = input & 0x0f; - // sz_u8_t hi_nibble = input >> 4; - // sz_u8_t bitset_even = filter_even_vec.u8s[hi_nibble]; - // sz_u8_t bitset_odd = filter_odd_vec.u8s[hi_nibble]; - // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); - // sz_u8_t bitset = lo_nibble < 8 ? bitset_even : bitset_odd; - // if ((bitset & bitmask) != 0) return text; - // else { length--, text++; } - // - // The nice part about this, loading the strided data is vey easy with Arm NEON, - // while with x86 CPUs after AVX, shuffles within 256 bits shouldn't be an issue either. 
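    // For illustration, a worked instance of the scalar check quoted above, for the input
    // byte 'a' (0x61):
    //
    //      lo_nibble = 0x61 & 0x0f = 1;    hi_nibble = 0x61 >> 4 = 6;
    //      the bitset byte holding 'a' is _u8s[0x61 >> 3] = _u8s[12] = _u8s[6 * 2],
    //      i.e. the "even" byte for hi_nibble 6, because lo_nibble < 8;
    //      the bit within that byte is 1 << (0x61 & 7) = 0x02.
    //
    // `filter_even_vec` and `filter_odd_vec` hold the 16 even and 16 odd bytes of the
    // bitset replicated across the register, so the nibble-indexed shuffles below recover
    // exactly this byte-and-bit pair for up to 64 input characters at a time.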
- sz_size_t load_length = sz_min_of_two(length, 64); - __mmask64 load_mask = _sz_u64_mask_until(load_length); - text_vec.zmm = _mm512_maskz_loadu_epi8(load_mask, text); - lower_nibbles_vec.zmm = _mm512_and_si512(text_vec.zmm, _mm512_set1_epi8(0x0f)); - bitmask_vec.zmm = _mm512_shuffle_epi8(bitmask_lookup_vec.zmm, lower_nibbles_vec.zmm); - // - // At this point we can validate the `bitmask_vec` contents like this: - // - // for (sz_size_t i = 0; i != load_length; ++i) { - // sz_u8_t input = *(sz_u8_t const *)(text + i); - // sz_u8_t lo_nibble = input & 0x0f; - // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); - // sz_assert(bitmask_vec.u8s[i] == bitmask); - // } - // - // Shift right every byte by 4 bits. - // There is no `_mm512_srli_epi8` intrinsic, so we have to use `_mm512_srli_epi16` - // and combine it with a mask to clear the higher bits. - higher_nibbles_vec.zmm = _mm512_and_si512(_mm512_srli_epi16(text_vec.zmm, 4), _mm512_set1_epi8(0x0f)); - bitset_even_vec.zmm = _mm512_shuffle_epi8(filter_even_vec.zmm, higher_nibbles_vec.zmm); - bitset_odd_vec.zmm = _mm512_shuffle_epi8(filter_odd_vec.zmm, higher_nibbles_vec.zmm); - // - // At this point we can validate the `bitset_even_vec` and `bitset_odd_vec` contents like this: - // - // for (sz_size_t i = 0; i != load_length; ++i) { - // sz_u8_t input = *(sz_u8_t const *)(text + i); - // sz_u8_t const *bitset_ptr = &filter->_u8s[0]; - // sz_u8_t hi_nibble = input >> 4; - // sz_u8_t bitset_even = bitset_ptr[hi_nibble * 2]; - // sz_u8_t bitset_odd = bitset_ptr[hi_nibble * 2 + 1]; - // sz_assert(bitset_even_vec.u8s[i] == bitset_even); - // sz_assert(bitset_odd_vec.u8s[i] == bitset_odd); - // } - // - // TODO: Is this a good place for ternary logic? - __mmask64 take_first = _mm512_cmplt_epi8_mask(lower_nibbles_vec.zmm, _mm512_set1_epi8(8)); - bitset_even_vec.zmm = _mm512_mask_blend_epi8(take_first, bitset_odd_vec.zmm, bitset_even_vec.zmm); - __mmask64 matches_mask = _mm512_mask_test_epi8_mask(load_mask, bitset_even_vec.zmm, bitmask_vec.zmm); - if (matches_mask) { - int offset = sz_u64_ctz(matches_mask); - return text + offset; - } - else { text += load_length, length -= load_length; } - } - - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { - return sz_rfind_charset_serial(text, length, filter); -} - -SZ_PUBLIC sz_cptr_t sz_find_many_avx512( // - sz_cptr_t haystack, sz_size_t haystack_length, // - sz_cptr_t const *needles, sz_size_t const *needles_lengths, // - sz_size_t *needle_offset) { - - // When dealing with huge needles vocabularies, like in tokenization workloads, we need to construct an automaton. - // But in many cases, the vocabulary is small enough to use a simpler DFA-less approach, combining the ideas from - // the `sz_find_avx512` and `sz_find_charset_avx512` functions. - // - // Pick the offsets within needles where there is the least variance in the characters. - // Like for "the", "then", "there", "these", "those", "their", "they", "them", "that", "this", "thus", "than": - // - // 0: 't' - // 1: 'h' - // 2: 'e', 'a', 'i', 'o', 'u' - // 3: 'n', 'r', 's', 'i', 'y', 'm', 't' - // - // So depending on our "register budget", we can use a different number of pivot points: offset 0, 1, 2 make - // the most sense if we can only use 3 ZMM registers. - sz_unused(haystack && haystack_length && needles && needles_lengths && needle_offset); - return 0; -} - -/** - * Computes the Needleman Wunsch alignment score between two strings. 
- * The method uses 32-bit integers to accumulate the running score for every cell in the matrix. - * Assuming the costs of substitutions can be arbitrary signed 8-bit integers, the method is expected to be used - * on strings not exceeding 2^24 length or 16.7 million characters. - * - * Unlike the `_sz_edit_distance_skewed_diagonals_upto65k_avx512` method, this one uses signed integers to store - * the accumulated score. Moreover, it's primary bottleneck is the latency of gathering the substitution costs - * from the substitution matrix. If we use the diagonal order, we will be comparing a slice of the first string with - * a slice of the second. If we stick to the conventional horizontal order, we will be comparing one character against - * a slice, which is much easier to optimize. In that case we are sampling costs not from arbitrary parts of - * a 256 x 256 matrix, but from a single row! - */ -SZ_INTERNAL sz_ssize_t _sz_alignment_score_wagner_fisher_upto17m_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc) { - - // If one of the strings is empty - the edit distance is equal to the length of the other one - if (longer_length == 0) return (sz_ssize_t)shorter_length * gap; - if (shorter_length == 0) return (sz_ssize_t)longer_length * gap; - - // Let's make sure that we use the amount proportional to the - // number of elements in the shorter string, not the larger. - if (shorter_length > longer_length) { - sz_pointer_swap((void **)&longer_length, (void **)&shorter_length); - sz_pointer_swap((void **)&longer, (void **)&shorter); - } - - // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. - sz_memory_allocator_t global_alloc; - if (!alloc) { - sz_memory_allocator_init_default(&global_alloc); - alloc = &global_alloc; - } - - sz_size_t const max_length = 256ull * 256ull * 256ull; - sz_size_t const n = longer_length + 1; - sz_assert(n < max_length && "The length must fit into 24-bit integer. Otherwise use serial variant."); - sz_unused(longer_length && max_length); - - sz_size_t buffer_length = sizeof(sz_i32_t) * n * 2; - sz_i32_t *distances = (sz_i32_t *)alloc->allocate(buffer_length, alloc->handle); - sz_i32_t *previous_distances = distances; - sz_i32_t *current_distances = previous_distances + n; - - // Intialize the first row of the Levenshtein matrix with `iota`. - for (sz_size_t idx_longer = 0; idx_longer != n; ++idx_longer) - previous_distances[idx_longer] = (sz_i32_t)idx_longer * gap; - - /// Contains up to 16 consecutive characters from the longer string. - sz_u512_vec_t longer_vec; - sz_u512_vec_t cost_deletion_vec, cost_substitution_vec, lookup_substitution_vec, current_vec; - sz_u512_vec_t row_first_subs_vec, row_second_subs_vec, row_third_subs_vec, row_fourth_subs_vec; - sz_u512_vec_t shuffled_first_subs_vec, shuffled_second_subs_vec, shuffled_third_subs_vec, shuffled_fourth_subs_vec; - - // Prepare constants and masks. 
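    // For reference, the scalar sampling that the code below vectorizes reads one byte per
    // character from a single 256-byte row of the substitution matrix:
    //
    //      sz_error_cost_t const *row_subs = subs + shorter_unsigned[idx_shorter] * 256u;
    //      sz_error_cost_t cost = row_subs[(sz_u8_t)longer[idx_longer]];
    //
    // The constants prepared next (the 0x80 and 0x40 byte tests) drive the blends that pick
    // the right quarter of that row for every character, replacing the per-character gather.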
- sz_u512_vec_t is_third_or_fourth_vec, is_second_or_fourth_vec, gap_vec; - { - char is_third_or_fourth_check, is_second_or_fourth_check; - *(sz_u8_t *)&is_third_or_fourth_check = 0x80, *(sz_u8_t *)&is_second_or_fourth_check = 0x40; - is_third_or_fourth_vec.zmm = _mm512_set1_epi8(is_third_or_fourth_check); - is_second_or_fourth_vec.zmm = _mm512_set1_epi8(is_second_or_fourth_check); - gap_vec.zmm = _mm512_set1_epi32(gap); - } - - sz_u8_t const *shorter_unsigned = (sz_u8_t const *)shorter; - for (sz_size_t idx_shorter = 0; idx_shorter != shorter_length; ++idx_shorter) { - sz_i32_t last_in_row = current_distances[0] = (sz_i32_t)(idx_shorter + 1) * gap; - - // Load one row of the substitution matrix into four ZMM registers. - sz_error_cost_t const *row_subs = subs + shorter_unsigned[idx_shorter] * 256u; - row_first_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 0); - row_second_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 1); - row_third_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 2); - row_fourth_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 3); - - // In the serial version we have one forward pass, that computes the deletion, - // insertion, and substitution costs at once. - // for (sz_size_t idx_longer = 0; idx_longer < longer_length; ++idx_longer) { - // sz_ssize_t cost_deletion = previous_distances[idx_longer + 1] + gap; - // sz_ssize_t cost_insertion = current_distances[idx_longer] + gap; - // sz_ssize_t cost_substitution = previous_distances[idx_longer] + row_subs[longer_unsigned[idx_longer]]; - // current_distances[idx_longer + 1] = sz_min_of_three(cost_deletion, cost_insertion, cost_substitution); - // } - // - // Given the complexity of handling the data-dependency between consecutive insertion cost computations - // within a Levenshtein matrix, the simplest design would be to vectorize every kind of cost computation - // separately. - // 1. Compute substitution costs for up to 64 characters at once, upcasting from 8-bit integers to 32. - // 2. Compute the pairwise minimum with deletion costs. - // 3. Inclusive prefix minimum computation to combine with addition costs. - // Proceeding with substitutions: - for (sz_size_t idx_longer = 0; idx_longer < longer_length; idx_longer += 64) { - sz_size_t register_length = sz_min_of_two(longer_length - idx_longer, 64); - __mmask64 mask = _sz_u64_mask_until(register_length); - longer_vec.zmm = _mm512_maskz_loadu_epi8(mask, longer + idx_longer); - - // Blend the `row_(first|second|third|fourth)_subs_vec` into `current_vec`, picking the right source - // for every character in `longer_vec`. Before that, we need to permute the subsititution vectors. - // Only the bottom 6 bits of a byte are used in VPERB, so we don't even need to mask. - shuffled_first_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_first_subs_vec.zmm); - shuffled_second_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_second_subs_vec.zmm); - shuffled_third_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_third_subs_vec.zmm); - shuffled_fourth_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_fourth_subs_vec.zmm); - - // To blend we can invoke three `_mm512_cmplt_epu8_mask`, but we can also achieve the same using - // the AND logical operation, checking the top two bits of every byte. - // Continuing this thought, we can use the VPTESTMB instruction to output the mask after the AND. 
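    // For illustration, a concrete instance of the two-bit selection described above: for
    // character code 200 (0xC8) both the 0x80 and 0x40 bits are set, so the blends pick
    // `shuffled_fourth_subs_vec` (row bytes 192..255), and VPERMB's 6-bit index is
    // 0xC8 & 0x3F = 8, leaving the lane with `row_subs[192 + 8] == row_subs[200]`, the same
    // cost the scalar code would have gathered.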
- __mmask64 is_third_or_fourth = _mm512_mask_test_epi8_mask(mask, longer_vec.zmm, is_third_or_fourth_vec.zmm); - __mmask64 is_second_or_fourth = - _mm512_mask_test_epi8_mask(mask, longer_vec.zmm, is_second_or_fourth_vec.zmm); - lookup_substitution_vec.zmm = _mm512_mask_blend_epi8( - is_third_or_fourth, - // Choose between the first and the second. - _mm512_mask_blend_epi8(is_second_or_fourth, shuffled_first_subs_vec.zmm, shuffled_second_subs_vec.zmm), - // Choose between the third and the fourth. - _mm512_mask_blend_epi8(is_second_or_fourth, shuffled_third_subs_vec.zmm, shuffled_fourth_subs_vec.zmm)); - - // First, sign-extend lower and upper 16 bytes to 16-bit integers. - __m512i current_0_31_vec = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(lookup_substitution_vec.zmm, 0)); - __m512i current_32_63_vec = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(lookup_substitution_vec.zmm, 1)); - - // Now extend those 16-bit integers to 32-bit. - // This isn't free, same as the subsequent store, so we only want to do that for the populated lanes. - // To minimize the number of loads and stores, we can combine our substitution costs with the previous - // distances, containing the deletion costs. - { - cost_substitution_vec.zmm = _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer); - cost_substitution_vec.zmm = _mm512_add_epi32( - cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_0_31_vec, 0))); - cost_deletion_vec.zmm = _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer); - cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm); - current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm); - - // Inclusive prefix minimum computation to combine with insertion costs. - // Simply disabling this operation results in 5x performance improvement, meaning - // that this operation is responsible for 80% of the total runtime. - // for (sz_size_t idx_longer = 0; idx_longer < longer_length; ++idx_longer) { - // current_distances[idx_longer + 1] = - // sz_max_of_two(current_distances[idx_longer] + gap, current_distances[idx_longer + 1]); - // } - // - // To perform the same operation in vectorized form, we need to perform a tree-like reduction, - // that will involve multiple steps. It's quite expensive and should be first tested in the - // "experimental" section. - // - // Another approach might be loop unrolling: - // current_vec.i32s[0] = last_in_row = sz_i32_max_of_two(current_vec.i32s[0], last_in_row + gap); - // current_vec.i32s[1] = last_in_row = sz_i32_max_of_two(current_vec.i32s[1], last_in_row + gap); - // current_vec.i32s[2] = last_in_row = sz_i32_max_of_two(current_vec.i32s[2], last_in_row + gap); - // ... yet this approach is also quite expensive. - for (int i = 0; i != 16; ++i) - current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap); - _mm512_mask_storeu_epi32(current_distances + idx_longer + 1, (__mmask16)mask, current_vec.zmm); - } - - // Export the values from 16 to 31. 
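    // Note: each 64-character chunk is exported in four groups of 16 lanes (0-15, 16-31,
    // 32-47, 48-63), since after widening the byte costs to 32-bit scores only 16 of them
    // fit into one ZMM register; `_kshiftri_mask64(mask, 16)` advances the store mask
    // before each subsequent group.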
- if (register_length > 16) { - mask = _kshiftri_mask64(mask, 16); - cost_substitution_vec.zmm = - _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer + 16); - cost_substitution_vec.zmm = _mm512_add_epi32( - cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_0_31_vec, 1))); - cost_deletion_vec.zmm = - _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer + 16); - cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm); - current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm); - - // Aggregate running insertion costs within the register. - for (int i = 0; i != 16; ++i) - current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap); - _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 16, (__mmask16)mask, current_vec.zmm); - } - - // Export the values from 32 to 47. - if (register_length > 32) { - mask = _kshiftri_mask64(mask, 16); - cost_substitution_vec.zmm = - _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer + 32); - cost_substitution_vec.zmm = _mm512_add_epi32( - cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_32_63_vec, 0))); - cost_deletion_vec.zmm = - _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer + 32); - cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm); - current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm); - - // Aggregate running insertion costs within the register. - for (int i = 0; i != 16; ++i) - current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap); - _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 32, (__mmask16)mask, current_vec.zmm); - } - - // Export the values from 32 to 47. - if (register_length > 48) { - mask = _kshiftri_mask64(mask, 16); - cost_substitution_vec.zmm = - _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer + 48); - cost_substitution_vec.zmm = _mm512_add_epi32( - cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_32_63_vec, 1))); - cost_deletion_vec.zmm = - _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer + 48); - cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm); - current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm); - - // Aggregate running insertion costs within the register. - for (int i = 0; i != 16; ++i) - current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap); - _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 48, (__mmask16)mask, current_vec.zmm); - } - } - - // Swap previous_distances and current_distances pointers - sz_pointer_swap((void **)&previous_distances, (void **)¤t_distances); - } - - // Cache scalar before `free` call. 
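    // Note: after the final pointer swap above, `previous_distances` refers to the last row
    // that was computed, so the score is simply its last element, read out below before the
    // temporary buffer is released.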
- sz_ssize_t result = previous_distances[longer_length]; - alloc->free(distances, buffer_length, alloc->handle); - return result; -} - -SZ_INTERNAL sz_ssize_t sz_alignment_score_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc) { - - if (sz_max_of_two(shorter_length, longer_length) < (256ull * 256ull * 256ull)) - return _sz_alignment_score_wagner_fisher_upto17m_avx512(shorter, shorter_length, longer, longer_length, subs, - gap, alloc); - else - return sz_alignment_score_serial(shorter, shorter_length, longer, longer_length, subs, gap, alloc); -} - -enum sz_encoding_t { - sz_encoding_unknown_k = 0, - sz_encoding_ascii_k = 1, - sz_encoding_utf8_k = 2, - sz_encoding_utf16_k = 3, - sz_encoding_utf32_k = 4, - sz_jwt_k, - sz_base64_k, - // Low priority encodings: - sz_encoding_utf8bom_k = 5, - sz_encoding_utf16le_k = 6, - sz_encoding_utf16be_k = 7, - sz_encoding_utf32le_k = 8, - sz_encoding_utf32be_k = 9, -}; - -// Character Set Detection is one of the most commonly performed operations in data processing with -// [Chardet](https://github.com/chardet/chardet), [Charset Normalizer](https://github.com/jawah/charset_normalizer), -// [cChardet](https://github.com/PyYoshi/cChardet) being the most commonly used options in the Python ecosystem. -// All of them are notoriously slow. -// -// Moreover, as of October 2024, UTF-8 is the dominant character encoding on the web, used by 98.4% of websites. -// Other have minimal usage, according to [W3Techs](https://w3techs.com/technologies/overview/character_encoding): -// - ISO-8859-1: 1.2% -// - Windows-1252: 0.3% -// - Windows-1251: 0.2% -// - EUC-JP: 0.1% -// - Shift JIS: 0.1% -// - EUC-KR: 0.1% -// - GB2312: 0.1% -// - Windows-1250: 0.1% -// Within programming language implementations and database management systems, 16-bit and 32-bit fixed-width encodings -// are also very popular and we need a way to efficienly differentiate between the most common UTF flavors, ASCII, and -// the rest. -// -// One good solution is the [simdutf](https://github.com/simdutf/simdutf) library, but it depends on the C++ runtime -// and focuses more on incremental validation & transcoding, rather than detection. -// -// So we need a very fast and efficient way of determining -SZ_PUBLIC sz_bool_t sz_detect_encoding(sz_cptr_t text, sz_size_t length) { - // https://github.com/simdutf/simdutf/blob/master/src/icelake/icelake_utf8_validation.inl.cpp - // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_from_utf8.inl.cpp#L81 - // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_utf8_common.inl.cpp#L661 - // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_utf8_common.inl.cpp#L788 - - // We can implement this operation simpler & differently, assuming most of the time continuous chunks of memory - // have identical encoding. With Russian and many European languages, we generally deal with 2-byte codepoints - // with occasional 1-byte punctuation marks. In the case of Chinese, Japanese, and Korean, we deal with 3-byte - // codepoints. In the case of emojis, we deal with 4-byte codepoints. - // We can also use the idea, that misaligned reads are quite cheap on modern CPUs. 
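    // For illustration, a minimal scalar sketch of the standard UTF-8 lead-byte
    // classification such a detector could build on (generic UTF-8 framing, not an existing
    // StringZilla routine):
    //
    //      sz_u8_t b = *(sz_u8_t const *)text;
    //      if (b < 0x80) { /* 1-byte ASCII code point */ }
    //      else if ((b & 0xE0) == 0xC0) { /* lead byte of a 2-byte sequence */ }
    //      else if ((b & 0xF0) == 0xE0) { /* lead byte of a 3-byte sequence */ }
    //      else if ((b & 0xF8) == 0xF0) { /* lead byte of a 4-byte sequence */ }
    //      else { /* 0x80..0xBF continuation byte, or an invalid lead */ }
    //
    // Counting how often each class appears in a sample can help separate ASCII, UTF-8, and
    // the fixed-width UTF-16/UTF-32 flavors mentioned above.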
- int can_be_ascii = 1, can_be_utf8 = 1, can_be_utf16 = 1, can_be_utf32 = 1; - sz_unused(can_be_ascii + can_be_utf8 + can_be_utf16 + can_be_utf32); - sz_unused(text && length); - return sz_false_k; -} - -#pragma clang attribute pop -#pragma GCC pop_options -#endif - -#pragma endregion - -/* @brief Implementation of the string search algorithms using the Arm NEON instruction set, available on 64-bit - * Arm processors. Implements: {substring search, character search, character set search} x {forward, reverse}. - */ -#pragma region ARM NEON - -#if SZ_USE_ARM_NEON -#pragma GCC push_options -#pragma GCC target("arch=armv8.2-a+simd") -#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd"))), apply_to = function) - -/** - * @brief Helper structure to simplify work with 64-bit words. - */ -typedef union sz_u128_vec_t { - uint8x16_t u8x16; - uint16x8_t u16x8; - uint32x4_t u32x4; - uint64x2_t u64x2; - sz_u64_t u64s[2]; - sz_u32_t u32s[4]; - sz_u16_t u16s[8]; - sz_u8_t u8s[16]; -} sz_u128_vec_t; - -SZ_INTERNAL sz_u64_t _sz_vreinterpretq_u8_u4(uint8x16_t vec) { - // Use `vshrn` to produce a bitmask, similar to `movemask` in SSE. - // https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon - return vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(vec), 4)), 0) & 0x8888888888888888ull; -} - -SZ_PUBLIC sz_ordering_t sz_order_neon(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { - //! Before optimizing this, read the "Operations Not Worth Optimizing" in Contributions Guide: - //! https://github.com/ashvardanian/StringZilla/blob/main/CONTRIBUTING.md#general-performance-observations - return sz_order_serial(a, a_length, b, b_length); -} - -SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - sz_u128_vec_t a_vec, b_vec; - for (; length >= 16; a += 16, b += 16, length -= 16) { - a_vec.u8x16 = vld1q_u8((sz_u8_t const *)a); - b_vec.u8x16 = vld1q_u8((sz_u8_t const *)b); - uint8x16_t cmp = vceqq_u8(a_vec.u8x16, b_vec.u8x16); - if (vminvq_u8(cmp) != 255) { return sz_false_k; } // Check if all bytes match - } - - // Handle remaining bytes - if (length) return sz_equal_serial(a, b, length); - return sz_true_k; -} - -SZ_PUBLIC sz_u64_t sz_checksum_neon(sz_cptr_t text, sz_size_t length) { - uint64x2_t sum_vec = vdupq_n_u64(0); - - // Process 16 bytes (128 bits) at a time - for (; length >= 16; text += 16, length -= 16) { - uint8x16_t vec = vld1q_u8((sz_u8_t const *)text); // Load 16 bytes - uint16x8_t pairwise_sum1 = vpaddlq_u8(vec); // Pairwise add lower and upper 8 bits - uint32x4_t pairwise_sum2 = vpaddlq_u16(pairwise_sum1); // Pairwise add 16-bit results - uint64x2_t pairwise_sum3 = vpaddlq_u32(pairwise_sum2); // Pairwise add 32-bit results - sum_vec = vaddq_u64(sum_vec, pairwise_sum3); // Accumulate the sum - } - - // Final reduction of `sum_vec` to a single scalar - sz_u64_t sum = vgetq_lane_u64(sum_vec, 0) + vgetq_lane_u64(sum_vec, 1); - if (length) sum += sz_checksum_serial(text, length); - return sum; -} - -SZ_PUBLIC void sz_copy_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - // In most cases the `source` and the `target` are not aligned, but we should - // at least make sure that writes don't touch many cache lines. - // NEON has an instruction to load and write 64 bytes at once. - // - // sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. 
- // sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. - // for (; head_length; target += 1, source += 1, head_length -= 1) *target = *source; - // length -= head_length; - // for (; length >= 64; target += 64, source += 64, length -= 64) - // vst4q_u8((sz_u8_t *)target, vld1q_u8_x4((sz_u8_t const *)source)); - // for (; tail_length; target += 1, source += 1, tail_length -= 1) *target = *source; - // - // Sadly, those instructions end up being 20% slower than the code processing 16 bytes at a time: - for (; length >= 16; target += 16, source += 16, length -= 16) - vst1q_u8((sz_u8_t *)target, vld1q_u8((sz_u8_t const *)source)); - if (length) sz_copy_serial(target, source, length); -} - -SZ_PUBLIC void sz_move_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - // When moving small buffers, using a small buffer on stack as a temporary storage is faster. - - if (target < source || target >= source + length) { - // Non-overlapping, proceed forward - sz_copy_neon(target, source, length); +SZ_PUBLIC sz_ptr_t sz_string_init_length(sz_string_t *string, sz_size_t length, sz_memory_allocator_t *allocator) { + sz_size_t space_needed = length + 1; // space for trailing \0 + sz_assert(string && allocator && "String and allocator can't be SZ_NULL."); + // Initialize the string to zeros for safety. + string->words[1] = 0; + string->words[2] = 0; + string->words[3] = 0; + // If we are lucky, no memory allocations will be needed. + if (space_needed <= _SZ_STRING_INTERNAL_SPACE) { + string->internal.start = &string->internal.chars[0]; + string->internal.length = (sz_u8_t)length; } else { - // Overlapping, proceed backward - target += length; - source += length; - - sz_u128_vec_t src_vec; - while (length >= 16) { - target -= 16, source -= 16, length -= 16; - src_vec.u8x16 = vld1q_u8((sz_u8_t const *)source); - vst1q_u8((sz_u8_t *)target, src_vec.u8x16); - } - while (length) { - target -= 1, source -= 1, length -= 1; - *target = *source; - } - } -} - -SZ_PUBLIC void sz_fill_neon(sz_ptr_t target, sz_size_t length, sz_u8_t value) { - uint8x16_t fill_vec = vdupq_n_u8(value); // Broadcast the value across the register - - while (length >= 16) { - vst1q_u8((sz_u8_t *)target, fill_vec); - target += 16; - length -= 16; + // If we are not lucky, we need to allocate memory. + string->external.start = (sz_ptr_t)allocator->allocate(space_needed, allocator->handle); + if (!string->external.start) return SZ_NULL_CHAR; + string->external.length = length; + string->external.space = space_needed; } - - // Handle remaining bytes - if (length) sz_fill_serial(target, length, value); + sz_assert(&string->internal.start == &string->external.start && "Alignment confusion"); + string->external.start[length] = 0; + return string->external.start; } -SZ_PUBLIC void sz_look_up_transform_neon(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { - - // If the input is tiny (especially smaller than the look-up table itself), we may end up paying - // more for organizing the SIMD registers and changing the CPU state, than for the actual computation. - if (length <= 128) { - sz_look_up_transform_serial(source, length, lut, target); - return; - } +SZ_PUBLIC sz_ptr_t sz_string_reserve(sz_string_t *string, sz_size_t new_capacity, sz_memory_allocator_t *allocator) { - sz_size_t head_length = (16 - ((sz_size_t)target % 16)) % 16; // 15 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 16; // 15 or less. 
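    // Note: the head/tail arithmetic above aligns the main loop's stores. For example, if
    // `target` ends in ...0x09, then head_length = (16 - 9) % 16 = 7, so after seven serial
    // bytes the destination is 16-byte aligned; the last `(target + length) % 16` bytes are
    // left for the serial tail.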
+ sz_assert(string && allocator && "Strings and allocators can't be SZ_NULL."); - // We need to pull the lookup table into 16x NEON registers. We have a total of 32 such registers. - // According to the Neoverse V2 manual, the 4-table lookup has a latency of 6 cycles, and 4x throughput. - uint8x16x4_t lut_0_to_63_vec, lut_64_to_127_vec, lut_128_to_191_vec, lut_192_to_255_vec; - lut_0_to_63_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 0)); - lut_64_to_127_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 64)); - lut_128_to_191_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 128)); - lut_192_to_255_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 192)); + sz_size_t new_space = new_capacity + 1; + if (new_space <= _SZ_STRING_INTERNAL_SPACE) return string->external.start; - sz_u128_vec_t source_vec; - // If the top bit is set in each word of `source_vec`, than we use `lookup_128_to_191_vec` or - // `lookup_192_to_255_vec`. If the second bit is set, we use `lookup_64_to_127_vec` or `lookup_192_to_255_vec`. - sz_u128_vec_t lookup_0_to_63_vec, lookup_64_to_127_vec, lookup_128_to_191_vec, lookup_192_to_255_vec; - sz_u128_vec_t blended_0_to_255_vec; + sz_ptr_t string_start; + sz_size_t string_length; + sz_size_t string_space; + sz_bool_t string_is_external; + sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); + sz_assert(new_space > string_space && "New space must be larger than current."); - // Process the head with serial code - for (; head_length; target += 1, source += 1, head_length -= 1) *target = lut[*(sz_u8_t const *)source]; + sz_ptr_t new_start = (sz_ptr_t)allocator->allocate(new_space, allocator->handle); + if (!new_start) return SZ_NULL_CHAR; - // Table lookups on Arm are much simpler to use than on x86, as we can use the `vqtbl4q_u8` instruction - // to perform a 4-table lookup in a single instruction. The XORs are used to adjust the lookup position - // within each 64-byte range of the table. - // Details on the 4-table lookup: https://lemire.me/blog/2019/07/23/arbitrary-byte-to-byte-maps-using-arm-neon/ - length -= head_length; - length -= tail_length; - for (; length >= 16; source += 16, target += 16, length -= 16) { - source_vec.u8x16 = vld1q_u8((sz_u8_t const *)source); - lookup_0_to_63_vec.u8x16 = vqtbl4q_u8(lut_0_to_63_vec, source_vec.u8x16); - lookup_64_to_127_vec.u8x16 = vqtbl4q_u8(lut_64_to_127_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0x40))); - lookup_128_to_191_vec.u8x16 = vqtbl4q_u8(lut_128_to_191_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0x80))); - lookup_192_to_255_vec.u8x16 = vqtbl4q_u8(lut_192_to_255_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0xc0))); - blended_0_to_255_vec.u8x16 = vorrq_u8(vorrq_u8(lookup_0_to_63_vec.u8x16, lookup_64_to_127_vec.u8x16), - vorrq_u8(lookup_128_to_191_vec.u8x16, lookup_192_to_255_vec.u8x16)); - vst1q_u8((sz_u8_t *)target, blended_0_to_255_vec.u8x16); - } + sz_copy(new_start, string_start, string_length); + string->external.start = new_start; + string->external.space = new_space; + string->external.padding = 0; + string->external.length = string_length; - // Process the tail with serial code - for (; tail_length; target += 1, source += 1, tail_length -= 1) *target = lut[*(sz_u8_t const *)source]; + // Deallocate the old string. 
+ if (string_is_external) allocator->free(string_start, string_space, allocator->handle); + return string->external.start; } -SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - sz_u64_t matches; - sz_u128_vec_t h_vec, n_vec, matches_vec; - n_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)n); +SZ_PUBLIC sz_ptr_t sz_string_shrink_to_fit(sz_string_t *string, sz_memory_allocator_t *allocator) { - while (h_length >= 16) { - h_vec.u8x16 = vld1q_u8((sz_u8_t const *)h); - matches_vec.u8x16 = vceqq_u8(h_vec.u8x16, n_vec.u8x16); - // In Arm NEON we don't have a `movemask` to combine it with `ctz` and get the offset of the match. - // But assuming the `vmaxvq` is cheap, we can use it to find the first match, by blending (bitwise selecting) - // the vector with a relative offsets array. - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - if (matches) return h + sz_u64_ctz(matches) / 4; + sz_assert(string && allocator && "Strings and allocators can't be SZ_NULL."); - h += 16, h_length -= 16; - } + sz_ptr_t string_start; + sz_size_t string_length; + sz_size_t string_space; + sz_bool_t string_is_external; + sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - return sz_find_byte_serial(h, h_length, n); -} + // We may already be space-optimal, and in that case we don't need to do anything. + sz_size_t new_space = string_length + 1; + if (string_space == new_space || !string_is_external) return string->external.start; -SZ_PUBLIC sz_cptr_t sz_rfind_byte_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - sz_u64_t matches; - sz_u128_vec_t h_vec, n_vec, matches_vec; - n_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)n); + sz_ptr_t new_start = (sz_ptr_t)allocator->allocate(new_space, allocator->handle); + if (!new_start) return SZ_NULL_CHAR; - while (h_length >= 16) { - h_vec.u8x16 = vld1q_u8((sz_u8_t const *)h + h_length - 16); - matches_vec.u8x16 = vceqq_u8(h_vec.u8x16, n_vec.u8x16); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - if (matches) return h + h_length - 1 - sz_u64_clz(matches) / 4; - h_length -= 16; - } + sz_copy(new_start, string_start, string_length); + string->external.start = new_start; + string->external.space = new_space; + string->external.padding = 0; + string->external.length = string_length; - return sz_rfind_byte_serial(h, h_length, n); + // Deallocate the old string. + if (string_is_external) allocator->free(string_start, string_space, allocator->handle); + return string->external.start; } -SZ_PUBLIC sz_u64_t _sz_find_charset_neon_register(sz_u128_vec_t h_vec, uint8x16_t set_top_vec_u8x16, - uint8x16_t set_bottom_vec_u8x16) { +SZ_PUBLIC sz_ptr_t sz_string_expand( // + sz_string_t *string, sz_size_t offset, sz_size_t added_length, sz_memory_allocator_t *allocator) { - // Once we've read the characters in the haystack, we want to - // compare them against our bitset. The serial version of that code - // would look like: `(set_->_u8s[c >> 3] & (1u << (c & 7u))) != 0`. - uint8x16_t byte_index_vec = vshrq_n_u8(h_vec.u8x16, 3); - uint8x16_t byte_mask_vec = vshlq_u8(vdupq_n_u8(1), vreinterpretq_s8_u8(vandq_u8(h_vec.u8x16, vdupq_n_u8(7)))); - uint8x16_t matches_top_vec = vqtbl1q_u8(set_top_vec_u8x16, byte_index_vec); - // The table lookup instruction in NEON replies to out-of-bound requests with zeros. - // The values in `byte_index_vec` all fall in [0; 32). So for values under 16, substracting 16 will underflow - // and map into interval [240, 256). 
Meaning that those will be populated with zeros and we can safely - // merge `matches_top_vec` and `matches_bottom_vec` with a bitwise OR. - uint8x16_t matches_bottom_vec = vqtbl1q_u8(set_bottom_vec_u8x16, vsubq_u8(byte_index_vec, vdupq_n_u8(16))); - uint8x16_t matches_vec = vorrq_u8(matches_top_vec, matches_bottom_vec); - // Istead of pure `vandq_u8`, we can immediately broadcast a match presence across each 8-bit word. - matches_vec = vtstq_u8(matches_vec, byte_mask_vec); - return _sz_vreinterpretq_u8_u4(matches_vec); -} + sz_assert(string && allocator && "String and allocator can't be SZ_NULL."); -SZ_PUBLIC sz_cptr_t sz_find_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { + sz_ptr_t string_start; + sz_size_t string_length; + sz_size_t string_space; + sz_bool_t string_is_external; + sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_find_byte_neon(h, h_length, n); + // The user intended to extend the string. + offset = sz_min_of_two(offset, string_length); - // Scan through the string. - // Assuming how tiny the Arm NEON registers are, we should avoid internal branches at all costs. - // That's why, for smaller needles, we use different loops. - if (n_length == 2) { - // Broadcast needle characters into SIMD registers. - sz_u64_t matches; - sz_u128_vec_t h_first_vec, h_last_vec, n_first_vec, n_last_vec, matches_vec; - // Dealing with 16-bit values, we can load 2 registers at a time and compare 31 possible offsets - // in a single loop iteration. - n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[0]); - n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[1]); - for (; h_length >= 17; h += 16, h_length -= 16) { - h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 0)); - h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 1)); - matches_vec.u8x16 = - vandq_u8(vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - if (matches) return h + sz_u64_ctz(matches) / 4; - } - } - else if (n_length == 3) { - // Broadcast needle characters into SIMD registers. - sz_u64_t matches; - sz_u128_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec, matches_vec; - // Comparing 24-bit values is a bumer. Being lazy, I went with the same approach - // as when searching for string over 4 characters long. I only avoid the last comparison. - n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[0]); - n_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[1]); - n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[2]); - for (; h_length >= 18; h += 16, h_length -= 16) { - h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 0)); - h_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 1)); - h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 2)); - matches_vec.u8x16 = vandq_u8( // - vandq_u8( // - vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), // - vceqq_u8(h_mid_vec.u8x16, n_mid_vec.u8x16)), - vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - if (matches) return h + sz_u64_ctz(matches) / 4; - } + // If we are lucky, no memory allocations will be needed. 
+ if (string_length + added_length < string_space) { + sz_move(string_start + offset + added_length, string_start + offset, string_length - offset); + string_start[string_length + added_length] = 0; + // Even if the string is on the stack, the `+=` won't affect the tail of the string. + string->external.length += added_length; } + // If we are not lucky, we need to allocate more memory. else { - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - // Broadcast those characters into SIMD registers. - sz_u64_t matches; - sz_u128_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec, matches_vec; - n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_first]); - n_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_mid]); - n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_last]); - // Walk through the string. - for (; h_length >= n_length + 16; h += 16, h_length -= 16) { - h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_first)); - h_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_mid)); - h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_last)); - matches_vec.u8x16 = vandq_u8( // - vandq_u8( // - vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), // - vceqq_u8(h_mid_vec.u8x16, n_mid_vec.u8x16)), - vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - while (matches) { - int potential_offset = sz_u64_ctz(matches) / 4; - if (sz_equal(h + potential_offset, n, n_length)) return h + potential_offset; - matches &= matches - 1; - } - } - } - - return sz_find_serial(h, h_length, n, n_length); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_rfind_byte_neon(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Will contain 4 bits per character. 
- sz_u64_t matches; - sz_u128_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec, matches_vec; - n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_first]); - n_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_mid]); - n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_last]); - - sz_cptr_t h_reversed; - for (; h_length >= n_length + 16; h_length -= 16) { - h_reversed = h + h_length - n_length - 16 + 1; - h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h_reversed + offset_first)); - h_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h_reversed + offset_mid)); - h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h_reversed + offset_last)); - matches_vec.u8x16 = vandq_u8( // - vandq_u8( // - vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), // - vceqq_u8(h_mid_vec.u8x16, n_mid_vec.u8x16)), - vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - while (matches) { - int potential_offset = sz_u64_clz(matches) / 4; - if (sz_equal(h + h_length - n_length - potential_offset, n, n_length)) - return h + h_length - n_length - potential_offset; - sz_assert((matches & (1ull << (63 - potential_offset * 4))) != 0 && - "The bit must be set before we squash it"); - matches &= ~(1ull << (63 - potential_offset * 4)); - } - } - - return sz_rfind_serial(h, h_length, n, n_length); -} - -SZ_PUBLIC sz_cptr_t sz_find_charset_neon(sz_cptr_t h, sz_size_t h_length, sz_charset_t const *set) { - sz_u64_t matches; - sz_u128_vec_t h_vec; - uint8x16_t set_top_vec_u8x16 = vld1q_u8(&set->_u8s[0]); - uint8x16_t set_bottom_vec_u8x16 = vld1q_u8(&set->_u8s[16]); - - for (; h_length >= 16; h += 16, h_length -= 16) { - h_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h)); - matches = _sz_find_charset_neon_register(h_vec, set_top_vec_u8x16, set_bottom_vec_u8x16); - if (matches) return h + sz_u64_ctz(matches) / 4; - } - - return sz_find_charset_serial(h, h_length, set); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t h, sz_size_t h_length, sz_charset_t const *set) { - sz_u64_t matches; - sz_u128_vec_t h_vec; - uint8x16_t set_top_vec_u8x16 = vld1q_u8(&set->_u8s[0]); - uint8x16_t set_bottom_vec_u8x16 = vld1q_u8(&set->_u8s[16]); + sz_size_t next_planned_size = sz_max_of_two(SZ_CACHE_LINE_WIDTH, string_space * 2ull); + sz_size_t min_needed_space = sz_size_bit_ceil(offset + string_length + added_length + 1); + sz_size_t new_space = sz_max_of_two(min_needed_space, next_planned_size); + string_start = sz_string_reserve(string, new_space - 1, allocator); + if (!string_start) return SZ_NULL_CHAR; - // Check `sz_find_charset_neon` for explanations. - for (; h_length >= 16; h_length -= 16) { - h_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h) + h_length - 16); - matches = _sz_find_charset_neon_register(h_vec, set_top_vec_u8x16, set_bottom_vec_u8x16); - if (matches) return h + h_length - 1 - sz_u64_clz(matches) / 4; + // Copy into the new buffer. + sz_move(string_start + offset + added_length, string_start + offset, string_length - offset); + string_start[string_length + added_length] = 0; + string->external.length = string_length + added_length; } - return sz_rfind_charset_serial(h, h_length, set); + return string_start; } -#pragma clang attribute pop -#pragma GCC pop_options -#endif // Arm Neon - -#pragma endregion - -/* @brief Implementation of the string search algorithms using the Arm SVE variable-length registers, available - * in Arm v9 processors. 
- * - * Implements: - * - memory: {copy, move, fill} - * - comparisons: {equal, order} - * - search: {substring, character, character set} x {forward, reverse}. - */ -#pragma region ARM SVE - -#if SZ_USE_ARM_SVE -#pragma GCC push_options -#pragma GCC target("arch=armv8.2-a+sve") -#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve"))), apply_to = function) - -SZ_PUBLIC void sz_fill_sve(sz_ptr_t target, sz_size_t length, sz_u8_t value) { - svuint8_t value_vec = svdup_u8(value); - sz_size_t vec_len = svcntb(); // Vector length in bytes (scalable) - - if (length <= vec_len) { - // Small buffer case: use mask to handle small writes - svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)length); - svst1_u8(mask, (unsigned char *)target, value_vec); - } - else { - // Calculate head, body, and tail sizes - sz_size_t head_length = vec_len - ((sz_size_t)target % vec_len); - sz_size_t tail_length = (sz_size_t)(target + length) % vec_len; - sz_size_t body_length = length - head_length - tail_length; +SZ_PUBLIC sz_size_t sz_string_erase(sz_string_t *string, sz_size_t offset, sz_size_t length) { - // Handle unaligned head - svbool_t head_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)head_length); - svst1_u8(head_mask, (unsigned char *)target, value_vec); - target += head_length; + sz_assert(string && "String can't be SZ_NULL."); - // Aligned body loop - for (; body_length >= vec_len; target += vec_len, body_length -= vec_len) { - svst1_u8(svptrue_b8(), (unsigned char *)target, value_vec); - } + sz_ptr_t string_start; + sz_size_t string_length; + sz_size_t string_space; + sz_bool_t string_is_external; + sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - // Handle unaligned tail - svbool_t tail_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)tail_length); - svst1_u8(tail_mask, (unsigned char *)target, value_vec); - } -} + // Normalize the offset, it can't be larger than the length. + offset = sz_min_of_two(offset, string_length); -SZ_PUBLIC void sz_copy_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - sz_size_t vec_len = svcntb(); // Vector length in bytes + // We shouldn't normalize the length, to avoid overflowing on `offset + length >= string_length`, + // if receiving `length == SZ_SIZE_MAX`. After following expression the `length` will contain + // exactly the delta between original and final length of this `string`. + length = sz_min_of_two(length, string_length - offset); - // Arm Neoverse V2 cores in Graviton 4, for example, come with 256 KB of L1 data cache per core, - // and 8 MB of L2 cache per core. Moreover, the L1 cache is fully associative. - // With two strings, we may consider the overal workload huge, if each exceeds 1 MB in length. - // - // int is_huge = length >= 4ull * 1024ull * 1024ull; - // - // When the buffer is small, there isn't much to innovate. - if (length <= vec_len) { - // Small buffer case: use mask to handle small writes - svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)length); - svuint8_t data = svld1_u8(mask, (unsigned char *)source); - svst1_u8(mask, (unsigned char *)target, data); - } - // When dealing with larger buffers, similar to AVX-512, we want minimize unaligned operations - // and handle the head, body, and tail separately. We can also traverse the buffer in both directions - // as Arm generally supports more simultaneous stores than x86 CPUs. - // - // For gigantic datasets, similar to AVX-512, non-temporal "loads" and "stores" can be used. 
- // Sadly, if the register size (16 byte or larger) is smaller than a cache-line (64 bytes) - // we will pay a huge penalty on loads, fetching the same content many times. - // It may be better to allow caching (and subsequent eviction), in favor of using four-element - // tuples, wich will be guaranteed to be a multiple of a cache line. + // There are 2 common cases, that wouldn't even require a `memmove`: + // 1. Erasing the entire contents of the string. + // In that case `length` argument will be equal or greater than `length` member. + // 2. Removing the tail of the string with something like `string.pop_back()` in C++. // - // Another approach is to use the `LD4B` instructions, which will populate four registers at once. - // This however, further decreases the performance from LibC-like 29 GB/s to 20 GB/s. - else { - // Calculating head, body, and tail sizes depends on the `vec_len`, - // but it's runtime constant, and the modulo operation is expensive! - // Instead we use the fact, that it's always a multiple of 128 bits or 16 bytes. - sz_size_t head_length = 16 - ((sz_size_t)target % 16); - sz_size_t tail_length = (sz_size_t)(target + length) % 16; - sz_size_t body_length = length - head_length - tail_length; - - // Handle unaligned parts - svbool_t head_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)head_length); - svuint8_t head_data = svld1_u8(head_mask, (unsigned char *)source); - svst1_u8(head_mask, (unsigned char *)target, head_data); - svbool_t tail_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)tail_length); - svuint8_t tail_data = svld1_u8(tail_mask, (unsigned char *)source + head_length + body_length); - svst1_u8(tail_mask, (unsigned char *)target + head_length + body_length, tail_data); - target += head_length; - source += head_length; - - // Aligned body loop, walking in two directions - for (; body_length >= vec_len * 2; target += vec_len, source += vec_len, body_length -= vec_len * 2) { - svuint8_t forward_data = svld1_u8(svptrue_b8(), (unsigned char *)source); - svuint8_t backward_data = svld1_u8(svptrue_b8(), (unsigned char *)source + body_length - vec_len); - svst1_u8(svptrue_b8(), (unsigned char *)target, forward_data); - svst1_u8(svptrue_b8(), (unsigned char *)target + body_length - vec_len, backward_data); - } - // Up to (vec_len * 2 - 1) bytes of data may be left in the body, - // so we can unroll the last two optional loop iterations. - if (body_length > vec_len) { - svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)body_length); - svuint8_t data = svld1_u8(mask, (unsigned char *)source); - svst1_u8(mask, (unsigned char *)target, data); - body_length -= vec_len; - source += body_length; - target += body_length; - } - if (body_length) { - svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)body_length); - svuint8_t data = svld1_u8(mask, (unsigned char *)source); - svst1_u8(mask, (unsigned char *)target, data); - } - } -} - -#pragma clang attribute pop -#pragma GCC pop_options -#endif // Arm SVE - -#pragma endregion - -/* - * @brief Pick the right implementation for the string search algorithms. 
- */ -#pragma region Compile Time Dispatching - -SZ_PUBLIC sz_u64_t sz_hash(sz_cptr_t ins, sz_size_t length) { return sz_hash_serial(ins, length); } -SZ_PUBLIC void sz_tolower(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) { sz_tolower_serial(ins, length, outs); } -SZ_PUBLIC void sz_toupper(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) { sz_toupper_serial(ins, length, outs); } -SZ_PUBLIC void sz_toascii(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) { sz_toascii_serial(ins, length, outs); } -SZ_PUBLIC sz_bool_t sz_isascii(sz_cptr_t ins, sz_size_t length) { return sz_isascii_serial(ins, length); } - -SZ_PUBLIC void sz_hashes_fingerprint(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_ptr_t fingerprint, - sz_size_t fingerprint_bytes) { - - sz_bool_t fingerprint_length_is_power_of_two = (sz_bool_t)((fingerprint_bytes & (fingerprint_bytes - 1)) == 0); - sz_string_view_t fingerprint_buffer = {fingerprint, fingerprint_bytes}; - - // There are several issues related to the fingerprinting algorithm. - // First, the memory traversal order is important. - // https://blog.stuffedcow.net/2015/08/pagewalk-coherence/ - - // In most cases the fingerprint length will be a power of two. - if (fingerprint_length_is_power_of_two == sz_false_k) - sz_hashes(start, length, window_length, 1, _sz_hashes_fingerprint_non_pow2_callback, &fingerprint_buffer); - else - sz_hashes(start, length, window_length, 1, _sz_hashes_fingerprint_pow2_callback, &fingerprint_buffer); -} - -#if !SZ_DYNAMIC_DISPATCH - -SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length) { -#if SZ_USE_X86_AVX512 - return sz_checksum_avx512(text, length); -#elif SZ_USE_X86_AVX2 - return sz_checksum_avx2(text, length); -#elif SZ_USE_ARM_NEON - return sz_checksum_neon(text, length); -#else - return sz_checksum_serial(text, length); -#endif -} - -SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { -#if SZ_USE_X86_AVX512 - return sz_equal_avx512(a, b, length); -#elif SZ_USE_X86_AVX2 - return sz_equal_avx2(a, b, length); -#elif SZ_USE_ARM_NEON - return sz_equal_neon(a, b, length); -#else - return sz_equal_serial(a, b, length); -#endif -} - -SZ_DYNAMIC sz_ordering_t sz_order(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { -#if SZ_USE_X86_AVX512 - return sz_order_avx512(a, a_length, b, b_length); -#elif SZ_USE_X86_AVX2 - return sz_order_avx2(a, a_length, b, b_length); -#elif SZ_USE_ARM_NEON - return sz_order_neon(a, a_length, b, b_length); -#else - return sz_order_serial(a, a_length, b, b_length); -#endif -} - -SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { -#if SZ_USE_X86_AVX512 - sz_copy_avx512(target, source, length); -#elif SZ_USE_X86_AVX2 - sz_copy_avx2(target, source, length); -#elif SZ_USE_ARM_NEON - sz_copy_neon(target, source, length); -#else - sz_copy_serial(target, source, length); -#endif -} - -SZ_DYNAMIC void sz_move(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { -#if SZ_USE_X86_AVX512 - sz_move_avx512(target, source, length); -#elif SZ_USE_X86_AVX2 - sz_move_avx2(target, source, length); -#elif SZ_USE_ARM_NEON - sz_move_neon(target, source, length); -#else - sz_move_serial(target, source, length); -#endif -} - -SZ_DYNAMIC void sz_fill(sz_ptr_t target, sz_size_t length, sz_u8_t value) { -#if SZ_USE_X86_AVX512 - sz_fill_avx512(target, length, value); -#elif SZ_USE_X86_AVX2 - sz_fill_avx2(target, length, value); -#elif SZ_USE_ARM_NEON - sz_fill_neon(target, length, value); -#else - sz_fill_serial(target, length, value); -#endif -} - 
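// When dynamic dispatch is disabled, these thin wrappers pick a backend with the
// preprocessor, so the caller never names one. A minimal hypothetical usage sketch,
// assuming any of the listed backends was compiled in:
//
//      char buffer[32];
//      sz_fill(buffer, sizeof(buffer), 0);  // resolves to sz_fill_avx512, sz_fill_neon,
//      sz_copy(buffer, "stringzilla", 11);  // ... or the serial fallbacks
//
// With `SZ_DYNAMIC_DISPATCH` enabled, the same symbols are expected to be provided
// elsewhere and resolved at runtime instead of by the preprocessor.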
-SZ_DYNAMIC void sz_look_up_transform(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { -#if SZ_USE_X86_AVX512 - sz_look_up_transform_avx512(source, length, lut, target); -#elif SZ_USE_X86_AVX2 - sz_look_up_transform_avx2(source, length, lut, target); -#elif SZ_USE_ARM_NEON - sz_look_up_transform_neon(source, length, lut, target); -#else - sz_look_up_transform_serial(source, length, lut, target); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) { -#if SZ_USE_X86_AVX512 - return sz_find_byte_avx512(haystack, h_length, needle); -#elif SZ_USE_X86_AVX2 - return sz_find_byte_avx2(haystack, h_length, needle); -#elif SZ_USE_ARM_NEON - return sz_find_byte_neon(haystack, h_length, needle); -#else - return sz_find_byte_serial(haystack, h_length, needle); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_rfind_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) { -#if SZ_USE_X86_AVX512 - return sz_rfind_byte_avx512(haystack, h_length, needle); -#elif SZ_USE_X86_AVX2 - return sz_rfind_byte_avx2(haystack, h_length, needle); -#elif SZ_USE_ARM_NEON - return sz_rfind_byte_neon(haystack, h_length, needle); -#else - return sz_rfind_byte_serial(haystack, h_length, needle); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_find(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length) { -#if SZ_USE_X86_AVX512 - return sz_find_avx512(haystack, h_length, needle, n_length); -#elif SZ_USE_X86_AVX2 - return sz_find_avx2(haystack, h_length, needle, n_length); -#elif SZ_USE_ARM_NEON - return sz_find_neon(haystack, h_length, needle, n_length); -#else - return sz_find_serial(haystack, h_length, needle, n_length); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_rfind(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length) { -#if SZ_USE_X86_AVX512 - return sz_rfind_avx512(haystack, h_length, needle, n_length); -#elif SZ_USE_X86_AVX2 - return sz_rfind_avx2(haystack, h_length, needle, n_length); -#elif SZ_USE_ARM_NEON - return sz_rfind_neon(haystack, h_length, needle, n_length); -#else - return sz_rfind_serial(haystack, h_length, needle, n_length); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { -#if SZ_USE_X86_AVX512 - return sz_find_charset_avx512(text, length, set); -#elif SZ_USE_X86_AVX2 - return sz_find_charset_avx2(text, length, set); -#elif SZ_USE_ARM_NEON - return sz_find_charset_neon(text, length, set); -#else - return sz_find_charset_serial(text, length, set); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_rfind_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { -#if SZ_USE_X86_AVX512 - return sz_rfind_charset_avx512(text, length, set); -#elif SZ_USE_X86_AVX2 - return sz_rfind_charset_avx2(text, length, set); -#elif SZ_USE_ARM_NEON - return sz_rfind_charset_neon(text, length, set); -#else - return sz_rfind_charset_serial(text, length, set); -#endif -} - -SZ_DYNAMIC sz_size_t sz_hamming_distance( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - return sz_hamming_distance_serial(a, a_length, b, b_length, bound); -} - -SZ_DYNAMIC sz_size_t sz_hamming_distance_utf8( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - return sz_hamming_distance_utf8_serial(a, a_length, b, b_length, bound); -} - -SZ_DYNAMIC sz_size_t sz_edit_distance( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, 
sz_memory_allocator_t *alloc) { -#if SZ_USE_X86_AVX512 - return sz_edit_distance_avx512(a, a_length, b, b_length, bound, alloc); -#else - return sz_edit_distance_serial(a, a_length, b, b_length, bound, alloc); -#endif -} - -SZ_DYNAMIC sz_size_t sz_edit_distance_utf8( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - return _sz_edit_distance_wagner_fisher_serial(a, a_length, b, b_length, bound, sz_true_k, alloc); -} - -SZ_DYNAMIC sz_ssize_t sz_alignment_score(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, - sz_error_cost_t const *subs, sz_error_cost_t gap, - sz_memory_allocator_t *alloc) { -#if SZ_USE_X86_AVX512 - return sz_alignment_score_avx512(a, a_length, b, b_length, subs, gap, alloc); -#else - return sz_alignment_score_serial(a, a_length, b, b_length, subs, gap, alloc); -#endif -} - -SZ_DYNAMIC void sz_hashes(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // - sz_hash_callback_t callback, void *callback_handle) { -#if SZ_USE_X86_AVX512 - sz_hashes_avx512(text, length, window_length, window_step, callback, callback_handle); -#elif SZ_USE_X86_AVX2 - sz_hashes_avx2(text, length, window_length, window_step, callback, callback_handle); -#else - sz_hashes_serial(text, length, window_length, window_step, callback, callback_handle); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_find_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - return sz_find_charset(h, h_length, &set); -} - -SZ_DYNAMIC sz_cptr_t sz_find_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - sz_charset_invert(&set); - return sz_find_charset(h, h_length, &set); -} - -SZ_DYNAMIC sz_cptr_t sz_rfind_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - return sz_rfind_charset(h, h_length, &set); -} + // In both of those, regardless of the location of the string - stack or heap, + // the erasing is as easy as setting the length to the offset. + // In every other case, we must `memmove` the tail of the string to the left. + if (offset + length < string_length) + sz_move(string_start + offset, string_start + offset + length, string_length - offset - length); -SZ_DYNAMIC sz_cptr_t sz_rfind_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - sz_charset_invert(&set); - return sz_rfind_charset(h, h_length, &set); + // The `string->external.length = offset` assignment would discard last characters + // of the on-the-stack string, but inplace subtraction would work. 
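+    //
+    // Worked example: erasing `length == 2` bytes at `offset == 1` from the 5-byte
+    // string "hello" first moves the two-byte tail "lo" over position 1, then the
+    // subtraction below shrinks the length from 5 to 3 and the terminator is
+    // rewritten, leaving "hlo".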
+ string->external.length -= length; + string_start[string_length - length] = 0; + return length; } -SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length, - sz_random_generator_t generator, void *generator_user_data) { - sz_generate_serial(alphabet, alphabet_size, result, result_length, generator, generator_user_data); +SZ_PUBLIC void sz_string_free(sz_string_t *string, sz_memory_allocator_t *allocator) { + if (!sz_string_is_on_stack(string)) + allocator->free(string->external.start, string->external.space, allocator->handle); + sz_string_init(string); } -#endif -#pragma endregion +#pragma endregion // Serial Implementation #ifdef __cplusplus -#pragma GCC diagnostic pop } #endif // __cplusplus - -#endif // STRINGZILLA_H_ +#endif // STRINGZILLA_SMALL_STRING_H_ From 1ba7982559111d4fc9b58caa7bc7aa1c6e64257c Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 18:37:47 +0000 Subject: [PATCH 039/751] Fix: Filter `sort.h` file --- include/stringzilla/sort.h | 7325 ++---------------------------------- 1 file changed, 256 insertions(+), 7069 deletions(-) diff --git a/include/stringzilla/sort.h b/include/stringzilla/sort.h index de7fbcac..4fe64bee 100644 --- a/include/stringzilla/sort.h +++ b/include/stringzilla/sort.h @@ -1,7156 +1,343 @@ /** - * @brief StringZilla is a collection of advanced string algorithms, designed to be used in Big Data applications. - * It is generally faster than LibC, and has a broader & cleaner interface, and targets modern x86 CPUs - * with AVX-512 and Arm NEON and older CPUs with SWAR and auto-vectorization. - * - * Consider overriding the following macros to customize the library: - * - * - `SZ_DEBUG=0` - whether to enable debug assertions and logging. - * - `SZ_DYNAMIC_DISPATCH=0` - whether to use runtime dispatching of the most advanced SIMD backend. - * - `SZ_USE_MISALIGNED_LOADS=0` - whether to use misaligned loads on platforms that support them. - * - `SZ_SWAR_THRESHOLD=24` - threshold for switching to SWAR backend over serial byte-level for-loops. - * - `SZ_USE_X86_AVX512=?` - whether to use AVX-512 instructions on x86_64. - * - `SZ_USE_X86_AVX2=?` - whether to use AVX2 instructions on x86_64. - * - `SZ_USE_ARM_NEON=?` - whether to use NEON instructions on ARM. - * - `SZ_USE_ARM_SVE=?` - whether to use SVE instructions on ARM. + * @brief Hardware-accelerated string sorting. + * @file sort.h + * @author Ash Vardanian * - * @see StringZilla: https://github.com/ashvardanian/StringZilla/blob/main/README.md - * @see LibC String: https://pubs.opengroup.org/onlinepubs/009695399/basedefs/string.h.html + * Includes core APIs: * - * @file stringzilla.h - * @author Ash Vardanian + * - `sz_partition` - to split the sequence into two parts based on a predicate. + * - `sz_merge` - to merge two consecutive sorted chunks forming the same continuous `sequence`. + * - `sz_sort` - to sort an arbitrary string sequence. + * - `sz_sort_partial` - to partially sort an arbitrary string sequence. */ -#ifndef STRINGZILLA_H_ -#define STRINGZILLA_H_ +#ifndef STRINGZILLA_SORT_H_ +#define STRINGZILLA_SORT_H_ -#define STRINGZILLA_VERSION_MAJOR 3 -#define STRINGZILLA_VERSION_MINOR 11 -#define STRINGZILLA_VERSION_PATCH 0 - -/** - * @brief When set to 1, the library will include the following LibC headers: and . - * In debug builds (SZ_DEBUG=1), the library will also include and . - * - * You may want to disable this compiling for use in the kernel, or in embedded systems. 
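A minimal caller-side sketch of the core API that the new `sort.h` declares. It assumes the `sz_sequence_t` layout (`order`, `count`, `get_start`, `get_length`, `handle`) and the `sz_u64_t` index type carried over from the declarations removed later in this diff; the dataset, callback names, and `main` harness are illustrative only, not part of the patch.

#include <stdio.h>
#include <string.h>
#include "stringzilla/sort.h"

static char const *dataset[] = {"banana", "apple", "cherry"};

static sz_cptr_t dataset_start(sz_sequence_t const *sequence, sz_size_t i) {
    return ((char const **)sequence->handle)[i];
}

static sz_size_t dataset_length(sz_sequence_t const *sequence, sz_size_t i) {
    return strlen(((char const **)sequence->handle)[i]);
}

int main(void) {
    sz_u64_t order[3] = {0, 1, 2};
    sz_sequence_t sequence;
    sequence.order = order;
    sequence.count = 3;
    sequence.get_start = dataset_start;
    sequence.get_length = dataset_length;
    sequence.handle = dataset;
    sz_sort(&sequence); // ascending lexicographic order of indices: 1 (apple), 0 (banana), 2 (cherry)
    for (sz_size_t i = 0; i != sequence.count; ++i) printf("%s\n", dataset[order[i]]);
    return 0;
}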
- * You may also avoid them, if you are very sensitive to compilation time and avoid pre-compiled headers. - * https://artificial-mind.net/projects/compile-health/ - */ -#ifndef SZ_AVOID_LIBC -#define SZ_AVOID_LIBC (0) // true or false -#endif +#include "types.h" -/** - * @brief A misaligned load can be - trying to fetch eight consecutive bytes from an address - * that is not divisible by eight. On x86 enabled by default. On ARM it's not. - * - * Most platforms support it, but there is no industry standard way to check for those. - * This value will mostly affect the performance of the serial (SWAR) backend. - */ -#ifndef SZ_USE_MISALIGNED_LOADS -#if defined(__x86_64__) || defined(_M_X64) || defined(__i386__) || defined(_M_IX86) -#define SZ_USE_MISALIGNED_LOADS (1) // true or false -#else -#define SZ_USE_MISALIGNED_LOADS (0) // true or false -#endif +#ifdef __cplusplus +extern "C" { #endif -/** - * @brief Removes compile-time dispatching, and replaces it with runtime dispatching. - * So the `sz_find` function will invoke the most advanced backend supported by the CPU, - * that runs the program, rather than the most advanced backend supported by the CPU - * used to compile the library or the downstream application. - */ -#ifndef SZ_DYNAMIC_DISPATCH -#define SZ_DYNAMIC_DISPATCH (0) // true or false -#endif +#pragma region Core API /** - * @brief Analogous to `size_t` and `std::size_t`, unsigned integer, identical to pointer size. - * 64-bit on most platforms where pointers are 64-bit. - * 32-bit on platforms where pointers are 32-bit. + * @brief Similar to `std::partition`, given a predicate splits the sequence into two parts. + * The algorithm is unstable, meaning that elements may change relative order, as long + * as they are in the right partition. This is the simpler algorithm for partitioning. */ -#if defined(__LP64__) || defined(_LP64) || defined(__x86_64__) || defined(_WIN64) -#define SZ_DETECT_64_BIT (1) -#define SZ_SIZE_MAX (0xFFFFFFFFFFFFFFFFull) // Largest unsigned integer that fits into 64 bits. -#define SZ_SSIZE_MAX (0x7FFFFFFFFFFFFFFFull) // Largest signed integer that fits into 64 bits. -#else -#define SZ_DETECT_64_BIT (0) -#define SZ_SIZE_MAX (0xFFFFFFFFu) // Largest unsigned integer that fits into 32 bits. -#define SZ_SSIZE_MAX (0x7FFFFFFFu) // Largest signed integer that fits into 32 bits. -#endif +SZ_PUBLIC sz_size_t sz_partition(sz_sequence_t *sequence, sz_sequence_predicate_t predicate); /** - * @brief On Big-Endian machines StringZilla will work in compatibility mode. - * This disables SWAR hacks to minimize code duplication, assuming practically - * all modern popular platforms are Little-Endian. + * @brief Inplace `std::set_union` for two consecutive chunks forming the same continuous `sequence`. * - * This variable is hard to infer from macros reliably. It's best to set it manually. - * For that CMake provides the `TestBigEndian` and `CMAKE__BYTE_ORDER` (from 3.20 onwards). - * In Python one can check `sys.byteorder == 'big'` in the `setup.py` script and pass the appropriate macro. 
- * https://stackoverflow.com/a/27054190 - */ -#ifndef SZ_DETECT_BIG_ENDIAN -#if defined(__BYTE_ORDER) && __BYTE_ORDER == __BIG_ENDIAN || defined(__BIG_ENDIAN__) || defined(__ARMEB__) || \ - defined(__THUMBEB__) || defined(__AARCH64EB__) || defined(_MIBSEB) || defined(__MIBSEB) || defined(__MIBSEB__) -#define SZ_DETECT_BIG_ENDIAN (1) //< It's a big-endian target architecture -#else -#define SZ_DETECT_BIG_ENDIAN (0) //< It's a little-endian target architecture -#endif -#endif - -/* - * Debugging and testing. + * @param partition The number of elements in the first sub-sequence in `sequence`. + * @param less Comparison function, to determine the lexicographic ordering. */ -#ifndef SZ_DEBUG -#if defined(DEBUG) || defined(_DEBUG) // This means "Not using DEBUG information". -#define SZ_DEBUG (1) -#else -#define SZ_DEBUG (0) -#endif -#endif +SZ_PUBLIC void sz_merge(sz_sequence_t *sequence, sz_size_t partition, sz_sequence_comparator_t less); /** - * @brief Threshold for switching to SWAR (8-bytes at a time) backend over serial byte-level for-loops. - * On very short strings, under 16 bytes long, at most a single word will be processed with SWAR. - * Assuming potentially misaligned loads, SWAR makes sense only after ~24 bytes. - */ -#ifndef SZ_SWAR_THRESHOLD -#if SZ_DEBUG -#define SZ_SWAR_THRESHOLD (8u) // 8 bytes in debug builds -#else -#define SZ_SWAR_THRESHOLD (24u) // 24 bytes in release builds -#endif -#endif - -/* Annotation for the public API symbols: - * - * - `SZ_PUBLIC` is used for functions that are part of the public API. - * - `SZ_INTERNAL` is used for internal helper functions with unstable APIs. - * - `SZ_DYNAMIC` is used for functions that are part of the public API, but are dispatched at runtime. + * @brief Sorting algorithm, combining Radix Sort for the first 32 bits of every word + * and a follow-up by a more conventional sorting procedure on equally prefixed parts. */ -#ifndef SZ_DYNAMIC -#if SZ_DYNAMIC_DISPATCH -#if defined(_WIN32) || defined(__CYGWIN__) -#define SZ_DYNAMIC __declspec(dllexport) -#define SZ_EXTERNAL __declspec(dllimport) -#define SZ_PUBLIC inline static -#define SZ_INTERNAL inline static -#else -#define SZ_DYNAMIC __attribute__((visibility("default"))) -#define SZ_EXTERNAL extern -#define SZ_PUBLIC __attribute__((unused)) inline static -#define SZ_INTERNAL __attribute__((always_inline)) inline static -#endif // _WIN32 || __CYGWIN__ -#else -#define SZ_DYNAMIC inline static -#define SZ_EXTERNAL extern -#define SZ_PUBLIC inline static -#define SZ_INTERNAL inline static -#endif // SZ_DYNAMIC_DISPATCH -#endif // SZ_DYNAMIC +SZ_PUBLIC void sz_sort(sz_sequence_t *sequence); /** - * @brief Alignment macro for 64-byte alignment. - */ -#if defined(_MSC_VER) -#define SZ_ALIGN64 __declspec(align(64)) -#elif defined(__GNUC__) || defined(__clang__) -#define SZ_ALIGN64 __attribute__((aligned(64))) -#else -#define SZ_ALIGN64 -#endif - -#ifdef __cplusplus -extern "C" { -#endif - -/* - * Let's infer the integer types or pull them from LibC, - * if that is allowed by the user. + * @brief Partial sorting algorithm, combining Radix Sort for the first 32 bits of every word + * and a follow-up by a more conventional sorting procedure on equally prefixed parts. 
*/ -#if !SZ_AVOID_LIBC -#include // `size_t` -#include // `uint8_t` -typedef int8_t sz_i8_t; // Always 8 bits -typedef uint8_t sz_u8_t; // Always 8 bits -typedef uint16_t sz_u16_t; // Always 16 bits -typedef int32_t sz_i32_t; // Always 32 bits -typedef uint32_t sz_u32_t; // Always 32 bits -typedef uint64_t sz_u64_t; // Always 64 bits -typedef int64_t sz_i64_t; // Always 64 bits -typedef size_t sz_size_t; // Pointer-sized unsigned integer, 32 or 64 bits -typedef ptrdiff_t sz_ssize_t; // Signed version of `sz_size_t`, 32 or 64 bits - -#else // if SZ_AVOID_LIBC: - -// ! The C standard doesn't specify the signedness of char. -// ! On x86 char is signed by default while on Arm it is unsigned by default. -// ! That's why we don't define `sz_char_t` and generally use explicit `sz_i8_t` and `sz_u8_t`. -typedef signed char sz_i8_t; // Always 8 bits -typedef unsigned char sz_u8_t; // Always 8 bits -typedef unsigned short sz_u16_t; // Always 16 bits -typedef int sz_i32_t; // Always 32 bits -typedef unsigned int sz_u32_t; // Always 32 bits -typedef long long sz_i64_t; // Always 64 bits -typedef unsigned long long sz_u64_t; // Always 64 bits - -// Now we need to redefine the `size_t`. -// Microsoft Visual C++ (MSVC) typically follows LLP64 data model on 64-bit platforms, -// where integers, pointers, and long types have different sizes: -// -// > `int` is 32 bits -// > `long` is 32 bits -// > `long long` is 64 bits -// > pointer (thus, `size_t`) is 64 bits -// -// In contrast, GCC and Clang on 64-bit Unix-like systems typically follow the LP64 model, where: -// -// > `int` is 32 bits -// > `long` and pointer (thus, `size_t`) are 64 bits -// > `long long` is also 64 bits -// -// Source: https://learn.microsoft.com/en-us/windows/win32/winprog64/abstract-data-models -#if SZ_DETECT_64_BIT -typedef unsigned long long sz_size_t; // 64-bit. -typedef long long sz_ssize_t; // 64-bit. -#else -typedef unsigned sz_size_t; // 32-bit. -typedef unsigned sz_ssize_t; // 32-bit. -#endif // SZ_DETECT_64_BIT - -#endif // SZ_AVOID_LIBC +SZ_PUBLIC void sz_sort_partial(sz_sequence_t *sequence, sz_size_t n); /** - * @brief Compile-time assert macro similar to `static_assert` in C++. + * @brief Intro-Sort algorithm that supports custom comparators. */ -#define sz_static_assert(condition, name) \ - typedef struct { \ - int static_assert_##name : (condition) ? 1 : -1; \ - } sz_static_assert_##name##_t - -sz_static_assert(sizeof(sz_size_t) == sizeof(void *), sz_size_t_must_be_pointer_size); -sz_static_assert(sizeof(sz_ssize_t) == sizeof(void *), sz_ssize_t_must_be_pointer_size); - -#pragma region Public API +SZ_PUBLIC void sz_sort_intro(sz_sequence_t *sequence, sz_sequence_comparator_t less); -typedef char *sz_ptr_t; // A type alias for `char *` -typedef char const *sz_cptr_t; // A type alias for `char const *` -typedef sz_i8_t sz_error_cost_t; // Character mismatch cost for fuzzy matching functions +#pragma endregion -typedef sz_u64_t sz_sorted_idx_t; // Index of a sorted string in a list of strings +#pragma region Serial Implementation -typedef enum { sz_false_k = 0, sz_true_k = 1 } sz_bool_t; // Only one relevant bit -typedef enum { sz_less_k = -1, sz_equal_k = 0, sz_greater_k = 1 } sz_ordering_t; // Only three possible states: <=> +SZ_PUBLIC sz_size_t sz_partition(sz_sequence_t *sequence, sz_sequence_predicate_t predicate) { -/** - * @brief Tiny string-view structure. It's POD type, unlike the `std::string_view`. 
- */ -typedef struct sz_string_view_t { - sz_cptr_t start; - sz_size_t length; -} sz_string_view_t; + sz_size_t matches = 0; + while (matches != sequence->count && predicate(sequence, sequence->order[matches])) ++matches; -/** - * @brief Enumeration of SIMD capabilities of the target architecture. - * Used to introspect the supported functionality of the dynamic library. - */ -typedef enum sz_capability_t { - sz_cap_serial_k = 1, /// Serial (non-SIMD) capability - sz_cap_any_k = 0x7FFFFFFF, /// Mask representing any capability + for (sz_size_t i = matches + 1; i < sequence->count; ++i) + if (predicate(sequence, sequence->order[i])) + sz_u64_swap(sequence->order + i, sequence->order + matches), ++matches; - sz_cap_arm_neon_k = 1 << 10, /// ARM NEON capability - sz_cap_arm_sve_k = 1 << 11, /// ARM SVE capability TODO: Not yet supported or used - sz_cap_arm_sve2_k = 1 << 12, - sz_cap_arm_sve2p1_k = 1 << 13, - sz_cap_x86_avx2_k = 1 << 20, /// x86 AVX2 capability - sz_cap_x86_avx512f_k = 1 << 21, /// x86 AVX512 F capability - sz_cap_x86_avx512bw_k = 1 << 22, /// x86 AVX512 BW instruction capability - sz_cap_x86_avx512vl_k = 1 << 23, /// x86 AVX512 VL instruction capability - sz_cap_x86_avx512vbmi_k = 1 << 24, /// x86 AVX512 VBMI instruction capability - sz_cap_x86_gfni_k = 1 << 25, /// x86 AVX512 GFNI instruction capability + return matches; +} -} sz_capability_t; +SZ_PUBLIC void sz_merge(sz_sequence_t *sequence, sz_size_t partition, sz_sequence_comparator_t less) { -/** - * @brief Function to determine the SIMD capabilities of the current machine @b only at @b runtime. - * @return A bitmask of the SIMD capabilities represented as a `sz_capability_t` enum value. - */ -SZ_DYNAMIC sz_capability_t sz_capabilities(void); + sz_size_t start_b = partition + 1; -/** - * @brief Bit-set structure for 256 possible byte values. Useful for filtering and search. - * @see sz_charset_init, sz_charset_add, sz_charset_contains, sz_charset_invert - */ -typedef union sz_charset_t { - sz_u64_t _u64s[4]; - sz_u32_t _u32s[8]; - sz_u16_t _u16s[16]; - sz_u8_t _u8s[32]; -} sz_charset_t; + // If the direct merge is already sorted + if (!less(sequence, sequence->order[start_b], sequence->order[partition])) return; -/** @brief Initializes a bit-set to an empty collection, meaning - all characters are banned. */ -SZ_PUBLIC void sz_charset_init(sz_charset_t *s) { s->_u64s[0] = s->_u64s[1] = s->_u64s[2] = s->_u64s[3] = 0; } + sz_size_t start_a = 0; + while (start_a <= partition && start_b <= sequence->count) { -/** @brief Adds a character to the set and accepts @b unsigned integers. */ -SZ_PUBLIC void sz_charset_add_u8(sz_charset_t *s, sz_u8_t c) { s->_u64s[c >> 6] |= (1ull << (c & 63u)); } + // If element 1 is in right place + if (!less(sequence, sequence->order[start_b], sequence->order[start_a])) { start_a++; } + else { + sz_size_t value = sequence->order[start_b]; + sz_size_t index = start_b; -/** @brief Adds a character to the set. Consider @b sz_charset_add_u8. */ -SZ_PUBLIC void sz_charset_add(sz_charset_t *s, char c) { sz_charset_add_u8(s, *(sz_u8_t *)(&c)); } // bitcast + // Shift all the elements between element 1 + // element 2, right by 1. + while (index != start_a) { sequence->order[index] = sequence->order[index - 1], index--; } + sequence->order[start_a] = value; -/** @brief Checks if the set contains a given character and accepts @b unsigned integers. 
*/ -SZ_PUBLIC sz_bool_t sz_charset_contains_u8(sz_charset_t const *s, sz_u8_t c) { - // Checking the bit can be done in different ways: - // - (s->_u64s[c >> 6] & (1ull << (c & 63u))) != 0 - // - (s->_u32s[c >> 5] & (1u << (c & 31u))) != 0 - // - (s->_u16s[c >> 4] & (1u << (c & 15u))) != 0 - // - (s->_u8s[c >> 3] & (1u << (c & 7u))) != 0 - return (sz_bool_t)((s->_u64s[c >> 6] & (1ull << (c & 63u))) != 0); + // Update all the pointers + start_a++; + partition++; + start_b++; + } + } } -/** @brief Checks if the set contains a given character. Consider @b sz_charset_contains_u8. */ -SZ_PUBLIC sz_bool_t sz_charset_contains(sz_charset_t const *s, char c) { - return sz_charset_contains_u8(s, *(sz_u8_t *)(&c)); // bitcast +SZ_PUBLIC void sz_sort_insertion(sz_sequence_t *sequence, sz_sequence_comparator_t less) { + sz_u64_t *keys = sequence->order; + sz_size_t keys_count = sequence->count; + for (sz_size_t i = 1; i < keys_count; i++) { + sz_u64_t i_key = keys[i]; + sz_size_t j = i; + for (; j > 0 && less(sequence, i_key, keys[j - 1]); --j) keys[j] = keys[j - 1]; + keys[j] = i_key; + } } -/** @brief Inverts the contents of the set, so allowed character get disallowed, and vice versa. */ -SZ_PUBLIC void sz_charset_invert(sz_charset_t *s) { - s->_u64s[0] ^= 0xFFFFFFFFFFFFFFFFull, s->_u64s[1] ^= 0xFFFFFFFFFFFFFFFFull, // - s->_u64s[2] ^= 0xFFFFFFFFFFFFFFFFull, s->_u64s[3] ^= 0xFFFFFFFFFFFFFFFFull; +SZ_INTERNAL void _sz_sift_down( // + sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_u64_t *order, sz_size_t start, sz_size_t end) { + sz_size_t root = start; + while (2 * root + 1 <= end) { + sz_size_t child = 2 * root + 1; + if (child + 1 <= end && less(sequence, order[child], order[child + 1])) { child++; } + if (!less(sequence, order[root], order[child])) { return; } + sz_u64_swap(order + root, order + child); + root = child; + } } -typedef void *(*sz_memory_allocate_t)(sz_size_t, void *); -typedef void (*sz_memory_free_t)(void *, sz_size_t, void *); -typedef sz_u64_t (*sz_random_generator_t)(void *); - -/** - * @brief Some complex pattern matching algorithms may require memory allocations. - * This structure is used to pass the memory allocator to those functions. - * @see sz_memory_allocator_init_fixed - */ -typedef struct sz_memory_allocator_t { - sz_memory_allocate_t allocate; - sz_memory_free_t free; - void *handle; -} sz_memory_allocator_t; - -/** - * @brief Initializes a memory allocator to use the system default `malloc` and `free`. - * ! The function is not available if the library was compiled with `SZ_AVOID_LIBC`. - * - * @param alloc Memory allocator to initialize. - */ -SZ_PUBLIC void sz_memory_allocator_init_default(sz_memory_allocator_t *alloc); +SZ_INTERNAL void _sz_heapify(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_u64_t *order, sz_size_t count) { + sz_size_t start = (count - 2) / 2; + while (1) { + _sz_sift_down(sequence, less, order, start, count - 1); + if (start == 0) return; + start--; + } +} -/** - * @brief Initializes a memory allocator to use a static-capacity buffer. - * No dynamic allocations will be performed. - * - * @param alloc Memory allocator to initialize. - * @param buffer Buffer to use for allocations. - * @param length Length of the buffer. @b Must be greater than 8 bytes. Different values would be optimal for - * different algorithms and input lengths, but 4096 bytes (one RAM page) is a good default. 
- */ -SZ_PUBLIC void sz_memory_allocator_init_fixed(sz_memory_allocator_t *alloc, void *buffer, sz_size_t length); +SZ_INTERNAL void _sz_heapsort(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_size_t first, sz_size_t last) { + sz_u64_t *order = sequence->order; + sz_size_t count = last - first; + _sz_heapify(sequence, less, order + first, count); + sz_size_t end = count - 1; + while (end > 0) { + sz_u64_swap(order + first, order + first + end); + end--; + _sz_sift_down(sequence, less, order + first, 0, end); + } +} -/** - * @brief The number of bytes a stack-allocated string can hold, including the SZ_NULL termination character. - * ! This can't be changed from outside. Don't use the `#error` as it may already be included and set. - */ -#ifdef SZ_STRING_INTERNAL_SPACE -#undef SZ_STRING_INTERNAL_SPACE -#endif -#define SZ_STRING_INTERNAL_SPACE (sizeof(sz_size_t) * 3 - 1) // 3 pointers minus one byte for an 8-bit length +SZ_PUBLIC void sz_sort_introsort_recursion( // + sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_size_t first, sz_size_t last, sz_size_t depth) { -/** - * @brief Tiny memory-owning string structure with a Small String Optimization (SSO). - * Differs in layout from Folly, Clang, GCC, and probably most other implementations. - * It's designed to avoid any branches on read-only operations, and can store up - * to 22 characters on stack on 64-bit machines, followed by the SZ_NULL-termination character. - * - * @section Changing Length - * - * One nice thing about this design, is that you can, in many cases, change the length of the string - * without any branches, invoking a `+=` or `-=` on the 64-bit `length` field. If the string is on heap, - * the solution is obvious. If it's on stack, inplace decrement wouldn't affect the top bytes of the string, - * only changing the last byte containing the length. 
- */ -typedef union sz_string_t { + sz_size_t length = last - first; + switch (length) { + case 0: + case 1: return; + case 2: + if (less(sequence, sequence->order[first + 1], sequence->order[first])) + sz_u64_swap(&sequence->order[first], &sequence->order[first + 1]); + return; + case 3: { + sz_u64_t a = sequence->order[first]; + sz_u64_t b = sequence->order[first + 1]; + sz_u64_t c = sequence->order[first + 2]; + if (less(sequence, b, a)) sz_u64_swap(&a, &b); + if (less(sequence, c, b)) sz_u64_swap(&c, &b); + if (less(sequence, b, a)) sz_u64_swap(&a, &b); + sequence->order[first] = a; + sequence->order[first + 1] = b; + sequence->order[first + 2] = c; + return; + } + } + // Until a certain length, the quadratic-complexity insertion-sort is fine + if (length <= 16) { + sz_sequence_t sub_seq = *sequence; + sub_seq.order += first; + sub_seq.count = length; + sz_sort_insertion(&sub_seq, less); + return; + } -#if !SZ_DETECT_BIG_ENDIAN + // Fallback to N-logN-complexity heap-sort + if (depth == 0) { + _sz_heapsort(sequence, less, first, last); + return; + } - struct external { - sz_ptr_t start; - sz_size_t length; - sz_size_t space; - sz_size_t padding; - } external; + --depth; - struct internal { - sz_ptr_t start; - sz_u8_t length; - char chars[SZ_STRING_INTERNAL_SPACE]; - } internal; + // Median-of-three logic to choose pivot + sz_size_t median = first + length / 2; + if (less(sequence, sequence->order[median], sequence->order[first])) + sz_u64_swap(&sequence->order[first], &sequence->order[median]); + if (less(sequence, sequence->order[last - 1], sequence->order[first])) + sz_u64_swap(&sequence->order[first], &sequence->order[last - 1]); + if (less(sequence, sequence->order[median], sequence->order[last - 1])) + sz_u64_swap(&sequence->order[median], &sequence->order[last - 1]); -#else + // Partition using the median-of-three as the pivot + sz_u64_t pivot = sequence->order[median]; + sz_size_t left = first; + sz_size_t right = last - 1; + while (1) { + while (less(sequence, sequence->order[left], pivot)) left++; + while (less(sequence, pivot, sequence->order[right])) right--; + if (left >= right) break; + sz_u64_swap(&sequence->order[left], &sequence->order[right]); + left++; + right--; + } - struct external { - sz_ptr_t start; - sz_size_t space; - sz_size_t padding; - sz_size_t length; - } external; + // Recursively sort the partitions + sz_sort_introsort_recursion(sequence, less, first, left, depth); + sz_sort_introsort_recursion(sequence, less, right + 1, last, depth); +} - struct internal { - sz_ptr_t start; - char chars[SZ_STRING_INTERNAL_SPACE]; - sz_u8_t length; - } internal; +SZ_PUBLIC void sz_sort_introsort(sz_sequence_t *sequence, sz_sequence_comparator_t less) { + if (sequence->count == 0) return; + sz_size_t size_is_not_power_of_two = (sequence->count & (sequence->count - 1)) != 0; + sz_size_t depth_limit = sz_size_log2i_nonzero(sequence->count) + size_is_not_power_of_two; + sz_sort_introsort_recursion(sequence, less, 0, sequence->count, depth_limit); +} -#endif +SZ_PUBLIC void sz_sort_recursion( // + sz_sequence_t *sequence, sz_size_t bit_idx, sz_size_t bit_max, sz_sequence_comparator_t comparator, + sz_size_t partial_order_length) { - sz_size_t words[4]; + if (!sequence->count) return; -} sz_string_t; + // Array of size one doesn't need sorting - only needs the prefix to be discarded. 
+ if (sequence->count == 1) { + sz_u32_t *order_half_words = (sz_u32_t *)sequence->order; + order_half_words[1] = 0; + return; + } -typedef sz_u64_t (*sz_hash_t)(sz_cptr_t, sz_size_t); -typedef sz_u64_t (*sz_checksum_t)(sz_cptr_t, sz_size_t); -typedef sz_bool_t (*sz_equal_t)(sz_cptr_t, sz_cptr_t, sz_size_t); -typedef sz_ordering_t (*sz_order_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); -typedef void (*sz_to_converter_t)(sz_cptr_t, sz_size_t, sz_ptr_t); + // Partition a range of integers according to a specific bit value + sz_size_t split = 0; + sz_u64_t mask = (1ull << 63) >> bit_idx; -/** - * @brief Computes the 64-bit check-sum of bytes in a string. - * Similar to `std::ranges::accumulate`. - * - * @param text String to aggregate. - * @param length Number of bytes in the text. - * @return 64-bit unsigned value. - */ -SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length); + // The clean approach would be to perform a single pass over the sequence. + // + // while (split != sequence->count && !(sequence->order[split] & mask)) ++split; + // for (sz_size_t i = split + 1; i < sequence->count; ++i) + // if (!(sequence->order[i] & mask)) sz_u64_swap(sequence->order + i, sequence->order + split), ++split; + // + // This, however, doesn't take into account the high relative cost of writes and swaps. + // To circumvent that, we can first count the total number entries to be mapped into either part. + // And then walk through both parts, swapping the entries that are in the wrong part. + // This would often lead to ~15% performance gain. + sz_size_t count_with_bit_set = 0; + for (sz_size_t i = 0; i != sequence->count; ++i) count_with_bit_set += (sequence->order[i] & mask) != 0; + split = sequence->count - count_with_bit_set; -/** @copydoc sz_checksum */ -SZ_PUBLIC sz_u64_t sz_checksum_serial(sz_cptr_t text, sz_size_t length); - -/** - * @brief Computes the 64-bit unsigned hash of a string. Fairly fast for short strings, - * simple implementation, and supports rolling computation, reused in other APIs. - * Similar to `std::hash` in C++. - * - * @param text String to hash. - * @param length Number of bytes in the text. - * @return 64-bit hash value. - * - * @see sz_hashes, sz_hashes_fingerprint, sz_hashes_intersection - */ -SZ_PUBLIC sz_u64_t sz_hash(sz_cptr_t text, sz_size_t length); - -/** @copydoc sz_hash */ -SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t text, sz_size_t length); - -/** - * @brief Checks if two string are equal. - * Similar to `memcmp(a, b, length) == 0` in LibC and `a == b` in STL. - * - * The implementation of this function is very similar to `sz_order`, but the usage patterns are different. - * This function is more often used in parsing, while `sz_order` is often used in sorting. - * It works best on platforms with cheap - * - * @param a First string to compare. - * @param b Second string to compare. - * @param length Number of bytes in both strings. - * @return 1 if strings match, 0 otherwise. - */ -SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length); - -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_serial(sz_cptr_t a, sz_cptr_t b, sz_size_t length); - -/** - * @brief Estimates the relative order of two strings. Equivalent to `memcmp(a, b, length)` in LibC. - * Can be used on different length strings. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. 
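The counting trick above, shown in isolation as a hedged sketch on a plain array of 64-bit keys (the helper names are illustrative, not StringZilla APIs): once the number of keys with the radix bit set is known, the final boundary `split` is fixed, and only the entries sitting on the wrong side of it ever need to be swapped.

#include <stddef.h>
#include <stdint.h>

static void swap_u64(uint64_t *a, uint64_t *b) { uint64_t t = *a; *a = *b, *b = t; }

// Partitions `keys` so that entries with the `mask` bit unset come first;
// returns the index of the first entry with the bit set.
static size_t partition_by_bit(uint64_t *keys, size_t count, uint64_t mask) {
    size_t with_bit = 0;
    for (size_t i = 0; i != count; ++i) with_bit += (keys[i] & mask) != 0;
    size_t split = count - with_bit;
    for (size_t lo = 0, hi = split; lo < split && hi < count;) {
        if ((keys[lo] & mask) == 0) { ++lo; }      // already in the lower part
        else if ((keys[hi] & mask) != 0) { ++hi; } // already in the upper part
        else { swap_u64(keys + lo, keys + hi), ++lo, ++hi; }
    }
    return split;
}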
- * @return Negative if (a < b), positive if (a > b), zero if they are equal. - */ -SZ_DYNAMIC sz_ordering_t sz_order(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); - -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); - -/** - * @brief Look Up Table @b (LUT) transformation of a string. Equivalent to `for (char & c : text) c = lut[c]`. - * - * Can be used to implement some form of string normalization, partially masking punctuation marks, - * or converting between different character sets, like uppercase or lowercase. Surprisingly, also has - * broad implications in image processing, where image channel transformations are often done using LUTs. - * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param lut Look Up Table to apply. Must be exactly @b 256 bytes long. - * @param result Output string, can point to the same address as ::text. - */ -SZ_DYNAMIC void sz_look_up_transform(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result); - -typedef void (*sz_look_up_transform_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_ptr_t); - -/** @copydoc sz_look_up_transform */ -SZ_PUBLIC void sz_look_up_transform_serial(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result); - -/** - * @brief Equivalent to `for (char & c : text) c = tolower(c)`. - * - * ASCII characters [A, Z] map to decimals [65, 90], and [a, z] map to [97, 122]. - * So there are 26 english letters, shifted by 32 values, meaning that a conversion - * can be done by flipping the 5th bit each inappropriate character byte. This, however, - * breaks for extended ASCII, so a different solution is needed. - * http://0x80.pl/notesen/2016-01-06-swar-swap-case.html - * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param result Output string, can point to the same address as ::text. - */ -SZ_PUBLIC void sz_tolower(sz_cptr_t text, sz_size_t length, sz_ptr_t result); - -/** - * @brief Equivalent to `for (char & c : text) c = toupper(c)`. - * - * ASCII characters [A, Z] map to decimals [65, 90], and [a, z] map to [97, 122]. - * So there are 26 english letters, shifted by 32 values, meaning that a conversion - * can be done by flipping the 5th bit each inappropriate character byte. This, however, - * breaks for extended ASCII, so a different solution is needed. - * http://0x80.pl/notesen/2016-01-06-swar-swap-case.html - * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param result Output string, can point to the same address as ::text. - */ -SZ_PUBLIC void sz_toupper(sz_cptr_t text, sz_size_t length, sz_ptr_t result); - -/** - * @brief Equivalent to `for (char & c : text) c = toascii(c)`. - * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param result Output string, can point to the same address as ::text. - */ -SZ_PUBLIC void sz_toascii(sz_cptr_t text, sz_size_t length, sz_ptr_t result); - -/** - * @brief Checks if all characters in the range are valid ASCII characters. - * - * @param text String to be analyzed. - * @param length Number of bytes in the string. - * @return Whether all characters are valid ASCII characters. - */ -SZ_PUBLIC sz_bool_t sz_isascii(sz_cptr_t text, sz_size_t length); - -/** - * @brief Generates a random string for a given alphabet, avoiding integer division and modulo operations. 
- * Similar to `text[i] = alphabet[rand() % cardinality]`. - * - * The modulo operation is expensive, and should be avoided in performance-critical code. - * We avoid it using small lookup tables and replacing it with a multiplication and shifts, similar to `libdivide`. - * Alternative algorithms would include: - * - Montgomery form: https://en.algorithmica.org/hpc/number-theory/montgomery/ - * - Barret reduction: https://www.nayuki.io/page/barrett-reduction-algorithm - * - Lemire's trick: https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/ - * - * @param alphabet Set of characters to sample from. - * @param cardinality Number of characters to sample from. - * @param text Output string, can point to the same address as ::text. - * @param generate Callback producing random numbers given the generator state. - * @param generator Generator state, can be a pointer to a seed, or a pointer to a random number generator. - */ -SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length, - sz_random_generator_t generate, void *generator); - -/** @copydoc sz_generate */ -SZ_PUBLIC void sz_generate_serial(sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length, - sz_random_generator_t generate, void *generator); - -/** - * @brief Similar to `memcpy`, copies contents of one string into another. - * The behavior is undefined if the strings overlap. - * - * @param target String to copy into. - * @param length Number of bytes to copy. - * @param source String to copy from. - */ -SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length); - -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length); - -/** - * @brief Similar to `memmove`, copies (moves) contents of one string into another. - * Unlike `sz_copy`, allows overlapping strings as arguments. - * - * @param target String to copy into. - * @param length Number of bytes to copy. - * @param source String to copy from. - */ -SZ_DYNAMIC void sz_move(sz_ptr_t target, sz_cptr_t source, sz_size_t length); - -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length); - -typedef void (*sz_move_t)(sz_ptr_t, sz_cptr_t, sz_size_t); - -/** - * @brief Similar to `memset`, fills a string with a given value. - * - * @param target String to fill. - * @param length Number of bytes to fill. - * @param value Value to fill with. - */ -SZ_DYNAMIC void sz_fill(sz_ptr_t target, sz_size_t length, sz_u8_t value); - -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_serial(sz_ptr_t target, sz_size_t length, sz_u8_t value); - -typedef void (*sz_fill_t)(sz_ptr_t, sz_size_t, sz_u8_t); - -/** - * @brief Initializes a string class instance to an empty value. - */ -SZ_PUBLIC void sz_string_init(sz_string_t *string); - -/** - * @brief Convenience function checking if the provided string is stored inside of the ::string instance itself, - * alternative being - allocated in a remote region of the heap. - */ -SZ_PUBLIC sz_bool_t sz_string_is_on_stack(sz_string_t const *string); - -/** - * @brief Unpacks the opaque instance of a string class into its components. - * Recommended to use only in read-only operations. - * - * @param string String to unpack. - * @param start Pointer to the start of the string. - * @param length Number of bytes in the string, before the SZ_NULL character. 
- * @param space Number of bytes allocated for the string (heap or stack), including the SZ_NULL character. - * @param is_external Whether the string is allocated on the heap externally, or fits withing ::string instance. - */ -SZ_PUBLIC void sz_string_unpack(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length, sz_size_t *space, - sz_bool_t *is_external); - -/** - * @brief Unpacks only the start and length of the string. - * Recommended to use only in read-only operations. - * - * @param string String to unpack. - * @param start Pointer to the start of the string. - * @param length Number of bytes in the string, before the SZ_NULL character. - */ -SZ_PUBLIC void sz_string_range(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length); - -/** - * @brief Constructs a string of a given ::length with noisy contents. - * Use the returned character pointer to populate the string. - * - * @param string String to initialize. - * @param length Number of bytes in the string, before the SZ_NULL character. - * @param allocator Memory allocator to use for the allocation. - * @return SZ_NULL if the operation failed, pointer to the start of the string otherwise. - */ -SZ_PUBLIC sz_ptr_t sz_string_init_length(sz_string_t *string, sz_size_t length, sz_memory_allocator_t *allocator); - -/** - * @brief Doesn't change the contents or the length of the string, but grows the available memory capacity. - * This is beneficial, if several insertions are expected, and we want to minimize allocations. - * - * @param string String to grow. - * @param new_capacity The number of characters to reserve space for, including existing ones. - * @param allocator Memory allocator to use for the allocation. - * @return SZ_NULL if the operation failed, pointer to the new start of the string otherwise. - */ -SZ_PUBLIC sz_ptr_t sz_string_reserve(sz_string_t *string, sz_size_t new_capacity, sz_memory_allocator_t *allocator); - -/** - * @brief Grows the string by adding an uninitialized region of ::added_length at the given ::offset. - * Would often be used in conjunction with one or more `sz_copy` calls to populate the allocated region. - * Similar to `sz_string_reserve`, but changes the length of the ::string. - * - * @param string String to grow. - * @param offset Offset of the first byte to reserve space for. - * If provided offset is larger than the length, it will be capped. - * @param added_length The number of new characters to reserve space for. - * @param allocator Memory allocator to use for the allocation. - * @return SZ_NULL if the operation failed, pointer to the new start of the string otherwise. - */ -SZ_PUBLIC sz_ptr_t sz_string_expand(sz_string_t *string, sz_size_t offset, sz_size_t added_length, - sz_memory_allocator_t *allocator); - -/** - * @brief Removes a range from a string. Changes the length, but not the capacity. - * Performs no allocations or deallocations and can't fail. - * - * @param string String to clean. - * @param offset Offset of the first byte to remove. - * @param length Number of bytes to remove. Out-of-bound ranges will be capped. - * @return Number of bytes removed. - */ -SZ_PUBLIC sz_size_t sz_string_erase(sz_string_t *string, sz_size_t offset, sz_size_t length); - -/** - * @brief Shrinks the string to fit the current length, if it's allocated on the heap. - * It's the reverse operation of ::sz_string_reserve. - * - * @param string String to shrink. - * @param allocator Memory allocator to use for the allocation. - * @return Whether the operation was successful. 
The only failures can come from the allocator. - * On failure, the string will remain unchanged. - */ -SZ_PUBLIC sz_ptr_t sz_string_shrink_to_fit(sz_string_t *string, sz_memory_allocator_t *allocator); - -/** - * @brief Frees the string, if it's allocated on the heap. - * If the string is on the stack, the function clears/resets the state. - */ -SZ_PUBLIC void sz_string_free(sz_string_t *string, sz_memory_allocator_t *allocator); - -#pragma endregion - -#pragma region Fast Substring Search API - -typedef sz_cptr_t (*sz_find_byte_t)(sz_cptr_t, sz_size_t, sz_cptr_t); -typedef sz_cptr_t (*sz_find_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); -typedef sz_cptr_t (*sz_find_set_t)(sz_cptr_t, sz_size_t, sz_charset_t const *); - -/** - * @brief Locates first matching byte in a string. Equivalent to `memchr(haystack, *needle, h_length)` in LibC. - * - * X86_64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/x86_64/memchr.S - * Aarch64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/aarch64/memchr.S - * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - single-byte substring to find. - * @return Address of the first match. - */ -SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); - -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); - -/** - * @brief Locates last matching byte in a string. Equivalent to `memrchr(haystack, *needle, h_length)` in LibC. - * - * X86_64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/x86_64/memrchr.S - * Aarch64 implementation: missing - * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - single-byte substring to find. - * @return Address of the last match. - */ -SZ_DYNAMIC sz_cptr_t sz_rfind_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); - -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); - -/** - * @brief Locates first matching substring. - * Equivalent to `memmem(haystack, h_length, needle, n_length)` in LibC. - * Similar to `strstr(haystack, needle)` in LibC, but requires known length. - * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - substring to find. - * @param n_length Number of bytes in the needle. - * @return Address of the first match. - */ -SZ_DYNAMIC sz_cptr_t sz_find(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); - -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); - -/** - * @brief Locates the last matching substring. - * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - substring to find. - * @param n_length Number of bytes in the needle. - * @return Address of the last match. - */ -SZ_DYNAMIC sz_cptr_t sz_rfind(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); - -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); - -/** - * @brief Finds the first character present from the ::set, present in ::text. 
- * Equivalent to `strspn(text, accepted)` and `strcspn(text, rejected)` in LibC. - * May have identical implementation and performance to ::sz_rfind_charset. - * - * Useful for parsing, when we want to skip a set of characters. Examples: - * * 6 whitespaces: " \t\n\r\v\f". - * * 16 digits forming a float number: "0123456789,.eE+-". - * * 5 HTML reserved characters: "\"'&<>", of which "<>" can be useful for parsing. - * * 2 JSON string special characters useful to locate the end of the string: "\"\\". - * - * @param text String to be scanned. - * @param set Set of relevant characters. - * @return Pointer to the first matching character from ::set. - */ -SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); - -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); - -/** - * @brief Finds the last character present from the ::set, present in ::text. - * Equivalent to `strspn(text, accepted)` and `strcspn(text, rejected)` in LibC. - * May have identical implementation and performance to ::sz_find_charset. - * - * Useful for parsing, when we want to skip a set of characters. Examples: - * * 6 whitespaces: " \t\n\r\v\f". - * * 16 digits forming a float number: "0123456789,.eE+-". - * * 5 HTML reserved characters: "\"'&<>", of which "<>" can be useful for parsing. - * * 2 JSON string special characters useful to locate the end of the string: "\"\\". - * - * @param text String to be scanned. - * @param set Set of relevant characters. - * @return Pointer to the last matching character from ::set. - */ -SZ_DYNAMIC sz_cptr_t sz_rfind_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); - -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); - -#pragma endregion - -#pragma region String Similarity Measures API - -/** - * @brief Computes the Hamming distance between two strings - number of not matching characters. - * Difference in length is is counted as a mismatch. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * - * @param bound Upper bound on the distance, that allows us to exit early. - * If zero is passed, the maximum possible distance will be equal to the length of the longer input. - * @return Unsigned integer for the distance, the `bound` if was exceeded. - * - * @see sz_hamming_distance_utf8 - * @see https://en.wikipedia.org/wiki/Hamming_distance - */ -SZ_DYNAMIC sz_size_t sz_hamming_distance( // - sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound); - -/** @copydoc sz_hamming_distance */ -SZ_PUBLIC sz_size_t sz_hamming_distance_serial( // - sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound); - -/** - * @brief Computes the Hamming distance between two @b UTF8 strings - number of not matching characters. - * Difference in length is is counted as a mismatch. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * - * @param bound Upper bound on the distance, that allows us to exit early. - * If zero is passed, the maximum possible distance will be equal to the length of the longer input. 
- * @return Unsigned integer for the distance, the `bound` if was exceeded. - * - * @see sz_hamming_distance - * @see https://en.wikipedia.org/wiki/Hamming_distance - */ -SZ_DYNAMIC sz_size_t sz_hamming_distance_utf8(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, - sz_size_t bound); - -/** @copydoc sz_hamming_distance_utf8 */ -SZ_PUBLIC sz_size_t sz_hamming_distance_utf8_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, - sz_size_t bound); - -typedef sz_size_t (*sz_hamming_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t); - -/** - * @brief Computes the Levenshtein edit-distance between two strings using the Wagner-Fisher algorithm. - * Similar to the Needleman-Wunsch alignment algorithm. Often used in fuzzy string matching. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * - * @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated, - * so the memory usage is linear in relation to ::a_length and ::b_length. - * If SZ_NULL is passed, will initialize to the systems default `malloc`. - * @param bound Exclusive upper bound on the distance, that allows us to exit early. - * Pass `SZ_SIZE_MAX` or any value greater than `(max(a_length, b_length))` to ignore. - * Pass zero to check if the strings are equal. - * @return Unsigned integer for the edit distance. Zero means the strings are equal. - * Returns the `bound` if it was exceeded or `SZ_SIZE_MAX` if the memory allocation failed. - * - * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default - * @see https://en.wikipedia.org/wiki/Levenshtein_distance - */ -SZ_DYNAMIC sz_size_t sz_edit_distance(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); - -/** @copydoc sz_edit_distance */ -SZ_PUBLIC sz_size_t sz_edit_distance_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); - -/** - * @brief Computes the Levenshtein edit-distance between two @b UTF8 strings. - * Unlike `sz_edit_distance`, reports the distance in Unicode codepoints, and not in bytes. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * - * @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated, - * so the memory usage is linear in relation to ::a_length and ::b_length. - * If SZ_NULL is passed, will initialize to the systems default `malloc`. - * @param bound Upper bound on the distance, that allows us to exit early. - * If zero is passed, the maximum possible distance will be equal to the length of the longer input. - * @return Unsigned integer for edit distance, the `bound` if was exceeded or `SZ_SIZE_MAX` - * if the memory allocation failed. 
- * - * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default, sz_edit_distance - * @see https://en.wikipedia.org/wiki/Levenshtein_distance - */ -SZ_DYNAMIC sz_size_t sz_edit_distance_utf8(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); - -typedef sz_size_t (*sz_edit_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t, sz_memory_allocator_t *); - -/** @copydoc sz_edit_distance_utf8 */ -SZ_PUBLIC sz_size_t sz_edit_distance_utf8_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); - -/** - * @brief Computes Needleman–Wunsch alignment score for two string. Often used in bioinformatics and cheminformatics. - * Similar to the Levenshtein edit-distance, parameterized for gap and substitution penalties. - * - * Not commutative in the general case, as the order of the strings matters, as `sz_alignment_score(a, b)` may - * not be equal to `sz_alignment_score(b, a)`. Becomes @b commutative, if the substitution costs are symmetric. - * Equivalent to the negative Levenshtein distance, if: `gap == -1` and `subs[i][j] == (i == j ? 0: -1)`. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * @param gap Penalty cost for gaps - insertions and removals. - * @param subs Substitution costs matrix with 256 x 256 values for all pairs of characters. - * - * @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated, - * so the memory usage is linear in relation to ::a_length and ::b_length. - * If SZ_NULL is passed, will initialize to the systems default `malloc`. - * @return Signed similarity score. Can be negative, depending on the substitution costs. - * If the memory allocation fails, the function returns `SZ_SSIZE_MAX`. - * - * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default - * @see https://en.wikipedia.org/wiki/Needleman%E2%80%93Wunsch_algorithm - */ -SZ_DYNAMIC sz_ssize_t sz_alignment_score(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, // - sz_memory_allocator_t *alloc); - -/** @copydoc sz_alignment_score */ -SZ_PUBLIC sz_ssize_t sz_alignment_score_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, // - sz_memory_allocator_t *alloc); - -typedef sz_ssize_t (*sz_alignment_score_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_error_cost_t const *, - sz_error_cost_t, sz_memory_allocator_t *); - -typedef void (*sz_hash_callback_t)(sz_cptr_t, sz_size_t, sz_u64_t, void *user); - -/** - * @brief Computes the Karp-Rabin rolling hashes of a string supplying them to the provided `callback`. - * Can be used for similarity scores, search, ranking, etc. - * - * Rabin-Karp-like rolling hashes can have very high-level of collisions and depend - * on the choice of bases and the prime number. That's why, often two hashes from the same - * family are used with different bases. - * - * 1. Kernighan and Ritchie's function uses 31, a prime close to the size of English alphabet. - * 2. To be friendlier to byte-arrays and UTF8, we use 257 for the second function. - * - * Choosing the right ::window_length is task- and domain-dependant. 
For example, most English words are - * between 3 and 7 characters long, so a window of 4 bytes would be a good choice. For DNA sequences, - * the ::window_length might be a multiple of 3, as the codons are 3 (nucleotides) bytes long. - * With such minimalistic alphabets of just four characters (AGCT) longer windows might be needed. - * For protein sequences the alphabet is 20 characters long, so the window can be shorter, than for DNAs. - * - * @param text String to hash. - * @param length Number of bytes in the string. - * @param window_length Length of the rolling window in bytes. - * @param window_step Step of reported hashes. @b Must be power of two. Should be smaller than `window_length`. - * @param callback Function receiving the start & length of a substring, the hash, and the `callback_handle`. - * @param callback_handle Optional user-provided pointer to be passed to the `callback`. - * @see sz_hashes_fingerprint, sz_hashes_intersection - */ -SZ_DYNAMIC void sz_hashes(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // - sz_hash_callback_t callback, void *callback_handle); - -/** @copydoc sz_hashes */ -SZ_PUBLIC void sz_hashes_serial(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // - sz_hash_callback_t callback, void *callback_handle); - -typedef void (*sz_hashes_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_size_t, sz_hash_callback_t, void *); - -/** - * @brief Computes the Karp-Rabin rolling hashes of a string outputting a binary fingerprint. - * Such fingerprints can be compared with Hamming or Jaccard (Tanimoto) distance for similarity. - * - * The algorithm doesn't clear the fingerprint buffer on start, so it can be invoked multiple times - * to produce a fingerprint of a longer string, by passing the previous fingerprint as the ::fingerprint. - * It can also be reused to produce multi-resolution fingerprints by changing the ::window_length - * and calling the same function multiple times for the same input ::text. - * - * Processes large strings in parts to maximize the cache utilization, using a small on-stack buffer, - * avoiding cache-coherency penalties of remote on-heap buffers. - * - * @param text String to hash. - * @param length Number of bytes in the string. - * @param fingerprint Output fingerprint buffer. - * @param fingerprint_bytes Number of bytes in the fingerprint buffer. - * @param window_length Length of the rolling window in bytes. - * @see sz_hashes, sz_hashes_intersection - */ -SZ_PUBLIC void sz_hashes_fingerprint( // - sz_cptr_t text, sz_size_t length, sz_size_t window_length, // - sz_ptr_t fingerprint, sz_size_t fingerprint_bytes); - -typedef void (*sz_hashes_fingerprint_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_ptr_t, sz_size_t); - -/** - * @brief Given a hash-fingerprint of a textual document, computes the number of intersecting hashes - * of the incoming document. Can be used for document scoring and search. - * - * Processes large strings in parts to maximize the cache utilization, using a small on-stack buffer, - * avoiding cache-coherency penalties of remote on-heap buffers. - * - * @param text Input document. - * @param length Number of bytes in the input document. - * @param fingerprint Reference document fingerprint. - * @param fingerprint_bytes Number of bytes in the reference documents fingerprint. - * @param window_length Length of the rolling window in bytes. 
- * @see sz_hashes, sz_hashes_fingerprint - */ -SZ_PUBLIC sz_size_t sz_hashes_intersection( // - sz_cptr_t text, sz_size_t length, sz_size_t window_length, // - sz_cptr_t fingerprint, sz_size_t fingerprint_bytes); - -typedef sz_size_t (*sz_hashes_intersection_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_cptr_t, sz_size_t); - -#pragma endregion - -#pragma region Convenience API - -/** - * @brief Finds the first character in the haystack, that is present in the needle. - * Convenience function, reused across different language bindings. - * @see sz_find_charset - */ -SZ_DYNAMIC sz_cptr_t sz_find_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); - -/** - * @brief Finds the first character in the haystack, that is @b not present in the needle. - * Convenience function, reused across different language bindings. - * @see sz_find_charset - */ -SZ_DYNAMIC sz_cptr_t sz_find_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); - -/** - * @brief Finds the last character in the haystack, that is present in the needle. - * Convenience function, reused across different language bindings. - * @see sz_find_charset - */ -SZ_DYNAMIC sz_cptr_t sz_rfind_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); - -/** - * @brief Finds the last character in the haystack, that is @b not present in the needle. - * Convenience function, reused across different language bindings. - * @see sz_find_charset - */ -SZ_DYNAMIC sz_cptr_t sz_rfind_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); - -#pragma endregion - -#pragma region String Sequences API - -struct sz_sequence_t; - -typedef sz_cptr_t (*sz_sequence_member_start_t)(struct sz_sequence_t const *, sz_size_t); -typedef sz_size_t (*sz_sequence_member_length_t)(struct sz_sequence_t const *, sz_size_t); -typedef sz_bool_t (*sz_sequence_predicate_t)(struct sz_sequence_t const *, sz_size_t); -typedef sz_bool_t (*sz_sequence_comparator_t)(struct sz_sequence_t const *, sz_size_t, sz_size_t); -typedef sz_bool_t (*sz_string_is_less_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); - -typedef struct sz_sequence_t { - sz_sorted_idx_t *order; - sz_size_t count; - sz_sequence_member_start_t get_start; - sz_sequence_member_length_t get_length; - void const *handle; -} sz_sequence_t; - -/** - * @brief Initiates the sequence structure from a tape layout, used by Apache Arrow. - * Expects ::offsets to contains `count + 1` entries, the last pointing at the end - * of the last string, indicating the total length of the ::tape. - */ -SZ_PUBLIC void sz_sequence_from_u32tape(sz_cptr_t *start, sz_u32_t const *offsets, sz_size_t count, - sz_sequence_t *sequence); - -/** - * @brief Initiates the sequence structure from a tape layout, used by Apache Arrow. - * Expects ::offsets to contains `count + 1` entries, the last pointing at the end - * of the last string, indicating the total length of the ::tape. - */ -SZ_PUBLIC void sz_sequence_from_u64tape(sz_cptr_t *start, sz_u64_t const *offsets, sz_size_t count, - sz_sequence_t *sequence); - -/** - * @brief Similar to `std::partition`, given a predicate splits the sequence into two parts. - * The algorithm is unstable, meaning that elements may change relative order, as long - * as they are in the right partition. This is the simpler algorithm for partitioning. 
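A minimal, hedged usage sketch for the sequence interface above: wrapping a user-owned array of `sz_string_view_t` views so the sorting routine declared just below can reorder it. The adapter and buffer names (`demo_get_start`, `demo_get_length`, `demo_views`) are illustrative, not part of the library, and the expectation that `sz_sort` reports its result by permuting the `order` array follows from the field's role rather than an explicit guarantee in this header.

#include <stringzilla/stringzilla.h>

static sz_cptr_t demo_get_start(sz_sequence_t const *sequence, sz_size_t i) {
    sz_string_view_t const *views = (sz_string_view_t const *)sequence->handle;
    return views[i].start;
}

static sz_size_t demo_get_length(sz_sequence_t const *sequence, sz_size_t i) {
    sz_string_view_t const *views = (sz_string_view_t const *)sequence->handle;
    return views[i].length;
}

static void demo_sort_views(void) {
    sz_string_view_t demo_views[3] = {{"mouse", 5}, {"cat", 3}, {"dog", 3}};
    sz_sorted_idx_t order[3] = {0, 1, 2}; // assumed to start as the identity permutation
    sz_sequence_t sequence;
    sequence.order = order;
    sequence.count = 3;
    sequence.get_start = demo_get_start;
    sequence.get_length = demo_get_length;
    sequence.handle = demo_views;
    sz_sort(&sequence); // expected to leave `order` as {1, 2, 0}: "cat", "dog", "mouse"
}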
- */ -SZ_PUBLIC sz_size_t sz_partition(sz_sequence_t *sequence, sz_sequence_predicate_t predicate); - -/** - * @brief Inplace `std::set_union` for two consecutive chunks forming the same continuous `sequence`. - * - * @param partition The number of elements in the first sub-sequence in `sequence`. - * @param less Comparison function, to determine the lexicographic ordering. - */ -SZ_PUBLIC void sz_merge(sz_sequence_t *sequence, sz_size_t partition, sz_sequence_comparator_t less); - -/** - * @brief Sorting algorithm, combining Radix Sort for the first 32 bits of every word - * and a follow-up by a more conventional sorting procedure on equally prefixed parts. - */ -SZ_PUBLIC void sz_sort(sz_sequence_t *sequence); - -/** - * @brief Partial sorting algorithm, combining Radix Sort for the first 32 bits of every word - * and a follow-up by a more conventional sorting procedure on equally prefixed parts. - */ -SZ_PUBLIC void sz_sort_partial(sz_sequence_t *sequence, sz_size_t n); - -/** - * @brief Intro-Sort algorithm that supports custom comparators. - */ -SZ_PUBLIC void sz_sort_intro(sz_sequence_t *sequence, sz_sequence_comparator_t less); - -#pragma endregion - -/* - * Hardware feature detection. - * All of those can be controlled by the user. - */ -#ifndef SZ_USE_X86_AVX512 -#ifdef __AVX512BW__ -#define SZ_USE_X86_AVX512 1 -#else -#define SZ_USE_X86_AVX512 0 -#endif -#endif - -#ifndef SZ_USE_X86_AVX2 -#ifdef __AVX2__ -#define SZ_USE_X86_AVX2 1 -#else -#define SZ_USE_X86_AVX2 0 -#endif -#endif - -#ifndef SZ_USE_ARM_NEON -#ifdef __ARM_NEON -#define SZ_USE_ARM_NEON 1 -#else -#define SZ_USE_ARM_NEON 0 -#endif -#endif - -#ifndef SZ_USE_ARM_SVE -#ifdef __ARM_FEATURE_SVE -#define SZ_USE_ARM_SVE 1 -#else -#define SZ_USE_ARM_SVE 0 -#endif -#endif - -/* - * Include hardware-specific headers. - */ -#if SZ_USE_X86_AVX512 || SZ_USE_X86_AVX2 -#include -#endif // SZ_USE_X86... 
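Because each of these detection macros honors a user-provided value, a consuming translation unit can pin the dispatch at compile time before including the header. A minimal sketch; the particular on/off combination is just one possible configuration:

// Force the portable serial kernels on AVX-512-capable builds, while keeping the AVX2 paths.
#define SZ_USE_X86_AVX512 0
#define SZ_USE_X86_AVX2 1
#include <stringzilla/stringzilla.h>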
-#if SZ_USE_ARM_NEON -#if !defined(_MSC_VER) -#include -#endif -#include -#endif // SZ_USE_ARM_NEON -#if SZ_USE_ARM_SVE -#if !defined(_MSC_VER) -#include -#endif -#endif // SZ_USE_ARM_SVE - -#pragma region Hardware Specific API - -#if SZ_USE_X86_AVX512 - -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_avx512(sz_cptr_t a, sz_cptr_t b, sz_size_t length); -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_avx512(sz_ptr_t target, sz_size_t length, sz_u8_t value); -/** @copydoc sz_look_up_transform */ -SZ_PUBLIC void sz_look_up_transform_avx512(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target); -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_edit_distance */ -SZ_PUBLIC sz_size_t sz_edit_distance_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); -/** @copydoc sz_alignment_score */ -SZ_PUBLIC sz_ssize_t sz_alignment_score_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, // - sz_memory_allocator_t *alloc); -/** @copydoc sz_hashes */ -SZ_PUBLIC void sz_hashes_avx512(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle); -#endif - -#if SZ_USE_X86_AVX2 -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_avx2(sz_cptr_t a, sz_cptr_t b, sz_size_t length); -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_avx2(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_avx2(sz_ptr_t target, sz_size_t length, sz_u8_t value); -/** @copydoc sz_look_up_transform */ -SZ_PUBLIC void sz_look_up_transform_avx2(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target); -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t 
n_length); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_hashes */ -SZ_PUBLIC void sz_hashes_avx2(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle); -#endif - -#if SZ_USE_ARM_NEON -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length); -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_neon(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_neon(sz_ptr_t target, sz_size_t length, sz_u8_t value); -/** @copydoc sz_look_up_transform */ -SZ_PUBLIC void sz_look_up_transform_neon(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target); -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_neon(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -#endif - -#if SZ_USE_ARM_SVE -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_sve(sz_cptr_t a, sz_cptr_t b, sz_size_t length); -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_sve(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_sve(sz_ptr_t target, sz_size_t length, sz_u8_t value); -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_sve(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_sve(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -#endif - -#pragma endregion - -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wconversion" - -/* - ********************************************************************************************************************** - 
********************************************************************************************************************** - ********************************************************************************************************************** - * - * This is where we the actual implementation begins. - * The rest of the file is hidden from the public API. - * - ********************************************************************************************************************** - ********************************************************************************************************************** - ********************************************************************************************************************** - */ - -#pragma region Compiler Extensions and Helper Functions - -#pragma GCC visibility push(hidden) - -/** - * @brief Helper-macro to mark potentially unused variables. - */ -#define sz_unused(x) ((void)(x)) - -/** - * @brief Helper-macro casting a variable to another type of the same size. - */ -#define sz_bitcast(type, value) (*((type *)&(value))) - -/** - * @brief Defines `SZ_NULL`, analogous to `NULL`. - * The default often comes from locale.h, stddef.h, - * stdio.h, stdlib.h, string.h, time.h, or wchar.h. - */ -#ifdef __GNUG__ -#define SZ_NULL __null -#define SZ_NULL_CHAR __null -#else -#define SZ_NULL ((void *)0) -#define SZ_NULL_CHAR ((char *)0) -#endif - -/** - * @brief Cache-line width, that will affect the execution of some algorithms, - * like equality checks and relative order computing. - */ -#define SZ_CACHE_LINE_WIDTH (64) // bytes - -/** - * @brief Similar to `assert`, the `sz_assert` is used in the SZ_DEBUG mode - * to check the invariants of the library. It's a no-op in the SZ_RELEASE mode. - * @note If you want to catch it, put a breakpoint at @b `__GI_exit` - */ -#if SZ_DEBUG && defined(SZ_AVOID_LIBC) && !SZ_AVOID_LIBC && !defined(SZ_PIC) -#include // `fprintf` -#include // `EXIT_FAILURE` -SZ_PUBLIC void _sz_assert_failure(char const *condition, char const *file, int line) { - fprintf(stderr, "Assertion failed: %s, in file %s, line %d\n", condition, file, line); - exit(EXIT_FAILURE); -} -#define sz_assert(condition) \ - do { \ - if (!(condition)) { _sz_assert_failure(#condition, __FILE__, __LINE__); } \ - } while (0) -#else -#define sz_assert(condition) ((void)(condition)) -#endif - -/* Intrinsics aliases for MSVC, GCC, Clang, and Clang-Cl. - * The following section of compiler intrinsics comes in 2 flavors. - */ -#if defined(_MSC_VER) && !defined(__clang__) // On Clang-CL -#include - -// Sadly, when building Win32 images, we can't use the `_tzcnt_u64`, `_lzcnt_u64`, -// `_BitScanForward64`, or `_BitScanReverse64` intrinsics. For now it's a simple `for`-loop. -// TODO: In the future we can switch to a more efficient De Bruijn's algorithm. -// https://www.chessprogramming.org/BitScan -// https://www.chessprogramming.org/De_Bruijn_Sequence -// https://gist.github.com/resilar/e722d4600dbec9752771ab4c9d47044f -// -// Use the serial version on 32-bit x86 and on Arm. 
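The per-bit loops below are the simplest fallback, but they cost one iteration per skipped bit. Short of the De Bruijn table lookup referenced in the TODO above, a branch-halving variant is also portable; a hedged sketch, not part of the library:

// Counts trailing zeros of a non-zero 64-bit value in six range checks
// instead of up to 63 loop iterations; purely illustrative.
static int ctz64_halving(unsigned long long x) {
    int n = 0;
    if ((x & 0xFFFFFFFFull) == 0) { n += 32, x >>= 32; }
    if ((x & 0x0000FFFFull) == 0) { n += 16, x >>= 16; }
    if ((x & 0x000000FFull) == 0) { n += 8, x >>= 8; }
    if ((x & 0x0000000Full) == 0) { n += 4, x >>= 4; }
    if ((x & 0x00000003ull) == 0) { n += 2, x >>= 2; }
    return n + (int)(~x & 1ull); // add one more if the lowest remaining bit is still zero
}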
-#if (defined(_WIN32) && !defined(_WIN64)) || defined(_M_ARM) || defined(_M_ARM64) -SZ_INTERNAL int sz_u64_ctz(sz_u64_t x) { - sz_assert(x != 0); - int n = 0; - while ((x & 1) == 0) { n++, x >>= 1; } - return n; -} -SZ_INTERNAL int sz_u64_clz(sz_u64_t x) { - sz_assert(x != 0); - int n = 0; - while ((x & 0x8000000000000000ull) == 0) { n++, x <<= 1; } - return n; -} -SZ_INTERNAL int sz_u64_popcount(sz_u64_t x) { - x = x - ((x >> 1) & 0x5555555555555555ull); - x = (x & 0x3333333333333333ull) + ((x >> 2) & 0x3333333333333333ull); - return (((x + (x >> 4)) & 0x0F0F0F0F0F0F0F0Full) * 0x0101010101010101ull) >> 56; -} -SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) { - sz_assert(x != 0); - int n = 0; - while ((x & 1) == 0) { n++, x >>= 1; } - return n; -} -SZ_INTERNAL int sz_u32_clz(sz_u32_t x) { - sz_assert(x != 0); - int n = 0; - while ((x & 0x80000000u) == 0) { n++, x <<= 1; } - return n; -} -SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) { - x = x - ((x >> 1) & 0x55555555); - x = (x & 0x33333333) + ((x >> 2) & 0x33333333); - return (((x + (x >> 4)) & 0x0F0F0F0F) * 0x01010101) >> 24; -} -#else -SZ_INTERNAL int sz_u64_ctz(sz_u64_t x) { return (int)_tzcnt_u64(x); } -SZ_INTERNAL int sz_u64_clz(sz_u64_t x) { return (int)_lzcnt_u64(x); } -SZ_INTERNAL int sz_u64_popcount(sz_u64_t x) { return (int)__popcnt64(x); } -SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) { return (int)_tzcnt_u32(x); } -SZ_INTERNAL int sz_u32_clz(sz_u32_t x) { return (int)_lzcnt_u32(x); } -SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) { return (int)__popcnt(x); } -#endif -// Force the byteswap functions to be intrinsics, because when /Oi- is given, these will turn into CRT function calls, -// which breaks when `SZ_AVOID_LIBC` is given -#pragma intrinsic(_byteswap_uint64) -SZ_INTERNAL sz_u64_t sz_u64_bytes_reverse(sz_u64_t val) { return _byteswap_uint64(val); } -#pragma intrinsic(_byteswap_ulong) -SZ_INTERNAL sz_u32_t sz_u32_bytes_reverse(sz_u32_t val) { return _byteswap_ulong(val); } -#else -SZ_INTERNAL int sz_u64_popcount(sz_u64_t x) { return __builtin_popcountll(x); } -SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) { return __builtin_popcount(x); } -SZ_INTERNAL int sz_u64_ctz(sz_u64_t x) { return __builtin_ctzll(x); } -SZ_INTERNAL int sz_u64_clz(sz_u64_t x) { return __builtin_clzll(x); } -SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) { return __builtin_ctz(x); } // ! Undefined if `x == 0` -SZ_INTERNAL int sz_u32_clz(sz_u32_t x) { return __builtin_clz(x); } // ! Undefined if `x == 0` -SZ_INTERNAL sz_u64_t sz_u64_bytes_reverse(sz_u64_t val) { return __builtin_bswap64(val); } -SZ_INTERNAL sz_u32_t sz_u32_bytes_reverse(sz_u32_t val) { return __builtin_bswap32(val); } -#endif - -SZ_INTERNAL sz_u64_t sz_u64_rotl(sz_u64_t x, sz_u64_t r) { return (x << r) | (x >> (64 - r)); } - -/** - * @brief Select bits from either ::a or ::b depending on the value of ::mask bits. - * - * Similar to `_mm_blend_epi16` intrinsic on x86. - * Described in the "Bit Twiddling Hacks" by Sean Eron Anderson. - * https://graphics.stanford.edu/~seander/bithacks.html#ConditionalSetOrClearBitsWithoutBranching - */ -SZ_INTERNAL sz_u64_t sz_u64_blend(sz_u64_t a, sz_u64_t b, sz_u64_t mask) { return a ^ ((a ^ b) & mask); } - -/* - * Efficiently computing the minimum and maximum of two or three values can be tricky. - * The simple branching baseline would be: - * - * x < y ? x : y // can replace with 1 conditional move - * - * Branchless approach is well known for signed integers, but it doesn't apply to unsigned ones. 
- * https://stackoverflow.com/questions/514435/templatized-branchless-int-max-min-function - * https://graphics.stanford.edu/~seander/bithacks.html#IntegerMinOrMax - * Using only bit-shifts for singed integers it would be: - * - * y + ((x - y) & (x - y) >> 31) // 4 unique operations - * - * Alternatively, for any integers using multiplication: - * - * (x > y) * y + (x <= y) * x // 5 operations - * - * Alternatively, to avoid multiplication: - * - * x & ~((x < y) - 1) + y & ((x < y) - 1) // 6 unique operations - */ -#define sz_min_of_two(x, y) (x < y ? x : y) -#define sz_max_of_two(x, y) (x < y ? y : x) -#define sz_min_of_three(x, y, z) sz_min_of_two(x, sz_min_of_two(y, z)) -#define sz_max_of_three(x, y, z) sz_max_of_two(x, sz_max_of_two(y, z)) - -/** @brief Branchless minimum function for two signed 32-bit integers. */ -SZ_INTERNAL sz_i32_t sz_i32_min_of_two(sz_i32_t x, sz_i32_t y) { return y + ((x - y) & (x - y) >> 31); } - -/** @brief Branchless minimum function for two signed 32-bit integers. */ -SZ_INTERNAL sz_i32_t sz_i32_max_of_two(sz_i32_t x, sz_i32_t y) { return x - ((x - y) & (x - y) >> 31); } - -/** - * @brief Clamps signed offsets in a string to a valid range. Used for Pythonic-style slicing. - */ -SZ_INTERNAL void sz_ssize_clamp_interval(sz_size_t length, sz_ssize_t start, sz_ssize_t end, - sz_size_t *normalized_offset, sz_size_t *normalized_length) { - // TODO: Remove branches. - // Normalize negative indices - if (start < 0) start += length; - if (end < 0) end += length; - - // Clamp indices to a valid range - if (start < 0) start = 0; - if (end < 0) end = 0; - if (start > (sz_ssize_t)length) start = length; - if (end > (sz_ssize_t)length) end = length; - - // Ensure start <= end - if (start > end) start = end; - - *normalized_offset = start; - *normalized_length = end - start; -} - -/** - * @brief Compute the logarithm base 2 of a positive integer, rounding down. - */ -SZ_INTERNAL sz_size_t sz_size_log2i_nonzero(sz_size_t x) { - sz_assert(x > 0 && "Non-positive numbers have no defined logarithm"); - sz_size_t leading_zeros = sz_u64_clz(x); - return 63 - leading_zeros; -} - -/** - * @brief Compute the smallest power of two greater than or equal to ::x. - */ -SZ_INTERNAL sz_size_t sz_size_bit_ceil(sz_size_t x) { - // Unlike the commonly used trick with `clz` intrinsics, is valid across the whole range of `x`. - // https://stackoverflow.com/a/10143264 - x--; - x |= x >> 1; - x |= x >> 2; - x |= x >> 4; - x |= x >> 8; - x |= x >> 16; -#if SZ_DETECT_64_BIT - x |= x >> 32; -#endif - x++; - return x; -} - -/** - * @brief Transposes an 8x8 bit matrix packed in a `sz_u64_t`. - * - * There is a well known SWAR sequence for that known to chess programmers, - * willing to flip a bit-matrix of pieces along the main A1-H8 diagonal. - * https://www.chessprogramming.org/Flipping_Mirroring_and_Rotating - * https://lukas-prokop.at/articles/2021-07-23-transpose - */ -SZ_INTERNAL sz_u64_t sz_u64_transpose(sz_u64_t x) { - sz_u64_t t; - t = x ^ (x << 36); - x ^= 0xf0f0f0f00f0f0f0full & (t ^ (x >> 36)); - t = 0xcccc0000cccc0000ull & (x ^ (x << 18)); - x ^= t ^ (t >> 18); - t = 0xaa00aa00aa00aa00ull & (x ^ (x << 9)); - x ^= t ^ (t >> 9); - return x; -} - -/** - * @brief Helper, that swaps two 64-bit integers representing the order of elements in the sequence. - */ -SZ_INTERNAL void sz_u64_swap(sz_u64_t *a, sz_u64_t *b) { - sz_u64_t t = *a; - *a = *b; - *b = t; -} - -/** - * @brief Helper, that swaps two 64-bit integers representing the order of elements in the sequence. 
- */ -SZ_INTERNAL void sz_pointer_swap(void **a, void **b) { - void *t = *a; - *a = *b; - *b = t; -} - -/** - * @brief Helper structure to simplify work with 16-bit words. - * @see sz_u16_load - */ -typedef union sz_u16_vec_t { - sz_u16_t u16; - sz_u8_t u8s[2]; -} sz_u16_vec_t; - -/** - * @brief Load a 16-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms. - */ -SZ_INTERNAL sz_u16_vec_t sz_u16_load(sz_cptr_t ptr) { -#if !SZ_USE_MISALIGNED_LOADS - sz_u16_vec_t result; - result.u8s[0] = ptr[0]; - result.u8s[1] = ptr[1]; - return result; -#elif defined(_MSC_VER) && !defined(__clang__) -#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. - return *((sz_u16_vec_t *)ptr); -#else - return *((__unaligned sz_u16_vec_t *)ptr); -#endif -#else - __attribute__((aligned(1))) sz_u16_vec_t const *result = (sz_u16_vec_t const *)ptr; - return *result; -#endif -} - -/** - * @brief Helper structure to simplify work with 32-bit words. - * @see sz_u32_load - */ -typedef union sz_u32_vec_t { - sz_u32_t u32; - sz_u16_t u16s[2]; - sz_u8_t u8s[4]; -} sz_u32_vec_t; - -/** - * @brief Load a 32-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms. - */ -SZ_INTERNAL sz_u32_vec_t sz_u32_load(sz_cptr_t ptr) { -#if !SZ_USE_MISALIGNED_LOADS - sz_u32_vec_t result; - result.u8s[0] = ptr[0]; - result.u8s[1] = ptr[1]; - result.u8s[2] = ptr[2]; - result.u8s[3] = ptr[3]; - return result; -#elif defined(_MSC_VER) && !defined(__clang__) -#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. - return *((sz_u32_vec_t *)ptr); -#else - return *((__unaligned sz_u32_vec_t *)ptr); -#endif -#else - __attribute__((aligned(1))) sz_u32_vec_t const *result = (sz_u32_vec_t const *)ptr; - return *result; -#endif -} - -/** - * @brief Helper structure to simplify work with 64-bit words. - * @see sz_u64_load - */ -typedef union sz_u64_vec_t { - sz_u64_t u64; - sz_u32_t u32s[2]; - sz_u16_t u16s[4]; - sz_u8_t u8s[8]; -} sz_u64_vec_t; - -/** - * @brief Load a 64-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms. - */ -SZ_INTERNAL sz_u64_vec_t sz_u64_load(sz_cptr_t ptr) { -#if !SZ_USE_MISALIGNED_LOADS - sz_u64_vec_t result; - result.u8s[0] = ptr[0]; - result.u8s[1] = ptr[1]; - result.u8s[2] = ptr[2]; - result.u8s[3] = ptr[3]; - result.u8s[4] = ptr[4]; - result.u8s[5] = ptr[5]; - result.u8s[6] = ptr[6]; - result.u8s[7] = ptr[7]; - return result; -#elif defined(_MSC_VER) && !defined(__clang__) -#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. - return *((sz_u64_vec_t *)ptr); -#else - return *((__unaligned sz_u64_vec_t *)ptr); -#endif -#else - __attribute__((aligned(1))) sz_u64_vec_t const *result = (sz_u64_vec_t const *)ptr; - return *result; -#endif -} - -/** @brief Helper function, using the supplied fixed-capacity buffer to allocate memory. */ -SZ_INTERNAL sz_ptr_t _sz_memory_allocate_fixed(sz_size_t length, void *handle) { - sz_size_t capacity; - sz_copy((sz_ptr_t)&capacity, (sz_cptr_t)handle, sizeof(sz_size_t)); - sz_size_t consumed_capacity = sizeof(sz_size_t); - if (consumed_capacity + length > capacity) return SZ_NULL_CHAR; - return (sz_ptr_t)handle + consumed_capacity; -} - -/** @brief Helper "no-op" function, simulating memory deallocation when we use a "static" memory buffer. 
*/ -SZ_INTERNAL void _sz_memory_free_fixed(sz_ptr_t start, sz_size_t length, void *handle) { - sz_unused(start && length && handle); -} - -/** @brief An internal callback used to set a bit in a power-of-two length binary fingerprint of a string. */ -SZ_INTERNAL void _sz_hashes_fingerprint_pow2_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, void *handle) { - sz_string_view_t *fingerprint_buffer = (sz_string_view_t *)handle; - sz_u8_t *fingerprint_u8s = (sz_u8_t *)fingerprint_buffer->start; - sz_size_t fingerprint_bytes = fingerprint_buffer->length; - fingerprint_u8s[(hash / 8) & (fingerprint_bytes - 1)] |= (1 << (hash & 7)); - sz_unused(start && length); -} - -/** @brief An internal callback used to set a bit in a @b non power-of-two length binary fingerprint of a string. */ -SZ_INTERNAL void _sz_hashes_fingerprint_non_pow2_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, - void *handle) { - sz_string_view_t *fingerprint_buffer = (sz_string_view_t *)handle; - sz_u8_t *fingerprint_u8s = (sz_u8_t *)fingerprint_buffer->start; - sz_size_t fingerprint_bytes = fingerprint_buffer->length; - fingerprint_u8s[(hash / 8) % fingerprint_bytes] |= (1 << (hash & 7)); - sz_unused(start && length); -} - -/** @brief An internal callback, used to mix all the running hashes into one pointer-size value. */ -SZ_INTERNAL void _sz_hashes_fingerprint_scalar_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, - void *scalar_handle) { - sz_unused(start && length && hash && scalar_handle); - sz_size_t *scalar_ptr = (sz_size_t *)scalar_handle; - *scalar_ptr ^= hash; -} - -/** - * @brief Chooses the offsets of the most interesting characters in a search needle. - * - * Search throughput can significantly deteriorate if we are matching the wrong characters. - * Say the needle is "aXaYa", and we are comparing the first, second, and last character. - * If we use SIMD and compare many offsets at a time, comparing against "a" in every register is a waste. - * - * Similarly, dealing with UTF8 inputs, we know that the lower bits of each character code carry more information. - * Cyrillic alphabet, for example, falls into [0x0410, 0x042F] code range for uppercase [А, Я], and - * into [0x0430, 0x044F] for lowercase [а, я]. Scanning through a text written in Russian, half of the - * bytes will carry absolutely no value and will be equal to 0x04. - */ -SZ_INTERNAL void _sz_locate_needle_anomalies(sz_cptr_t start, sz_size_t length, // - sz_size_t *first, sz_size_t *second, sz_size_t *third) { - *first = 0; - *second = length / 2; - *third = length - 1; - - // - int has_duplicates = // - start[*first] == start[*second] || // - start[*first] == start[*third] || // - start[*second] == start[*third]; - - // Loop through letters to find non-colliding variants. - if (length > 3 && has_duplicates) { - // Pivot the middle point right, until we find a character different from the first one. - for (; start[*second] == start[*first] && *second + 1 < *third; ++(*second)) {} - // Pivot the third (last) point left, until we find a different character. - for (; (start[*third] == start[*second] || start[*third] == start[*first]) && *third > (*second + 1); - --(*third)) {} - } - - // TODO: Investigate alternative strategies for long needles. - // On very long needles we have the luxury to choose! 
- // Often dealing with UTF8, we will likely benefit from shifting the first and second characters - // further to the right, to achieve not only uniqueness within the needle, but also avoid common - // rune prefixes of 2-, 3-, and 4-byte codes. - if (length > 8) { - // Pivot the first and second points right, until we find a character, that: - // > is different from others. - // > doesn't start with 0b'110x'xxxx - only 5 bits of relevant info. - // > doesn't start with 0b'1110'xxxx - only 4 bits of relevant info. - // > doesn't start with 0b'1111'0xxx - only 3 bits of relevant info. - // - // So we are practically searching for byte values that start with 0b0xxx'xxxx or 0b'10xx'xxxx. - // Meaning they fall in the range [0, 127] and [128, 191], in other words any unsigned int up to 191. - sz_u8_t const *start_u8 = (sz_u8_t const *)start; - sz_size_t vibrant_first = *first, vibrant_second = *second, vibrant_third = *third; - - // Let's begin with the seccond character, as the termination criteria there is more obvious - // and we may end up with more variants to check for the first candidate. - for (; (start_u8[vibrant_second] > 191 || start_u8[vibrant_second] == start_u8[vibrant_third]) && - (vibrant_second + 1 < vibrant_third); - ++vibrant_second) {} - - // Now check if we've indeed found a good candidate or should revert the `vibrant_second` to `second`. - if (start_u8[vibrant_second] < 191) { *second = vibrant_second; } - else { vibrant_second = *second; } - - // Now check the first character. - for (; (start_u8[vibrant_first] > 191 || start_u8[vibrant_first] == start_u8[vibrant_second] || - start_u8[vibrant_first] == start_u8[vibrant_third]) && - (vibrant_first + 1 < vibrant_second); - ++vibrant_first) {} - - // Now check if we've indeed found a good candidate or should revert the `vibrant_first` to `first`. - // We don't need to shift the third one when dealing with texts as the last byte of the text is - // also the last byte of a rune and contains the most information. - if (start_u8[vibrant_first] < 191) { *first = vibrant_first; } - } -} - -#pragma GCC visibility pop -#pragma endregion - -#pragma region Serial Implementation - -#if !SZ_AVOID_LIBC -#include // `fprintf` -#include // `malloc`, `EXIT_FAILURE` - -SZ_PUBLIC void *_sz_memory_allocate_default(sz_size_t length, void *handle) { - sz_unused(handle); - return malloc(length); -} -SZ_PUBLIC void _sz_memory_free_default(sz_ptr_t start, sz_size_t length, void *handle) { - sz_unused(handle && length); - free(start); -} - -#endif - -SZ_PUBLIC void sz_memory_allocator_init_default(sz_memory_allocator_t *alloc) { -#if !SZ_AVOID_LIBC - alloc->allocate = (sz_memory_allocate_t)_sz_memory_allocate_default; - alloc->free = (sz_memory_free_t)_sz_memory_free_default; -#else - alloc->allocate = (sz_memory_allocate_t)SZ_NULL; - alloc->free = (sz_memory_free_t)SZ_NULL; -#endif - alloc->handle = SZ_NULL; -} - -SZ_PUBLIC void sz_memory_allocator_init_fixed(sz_memory_allocator_t *alloc, void *buffer, sz_size_t length) { - // The logic here is simple - put the buffer length in the first slots of the buffer. - // Later use it for bounds checking. - alloc->allocate = (sz_memory_allocate_t)_sz_memory_allocate_fixed; - alloc->free = (sz_memory_free_t)_sz_memory_free_fixed; - alloc->handle = &buffer; - sz_copy((sz_ptr_t)buffer, (sz_cptr_t)&length, sizeof(sz_size_t)); -} - -/** - * @brief Byte-level equality comparison between two strings. - * If unaligned loads are allowed, uses a switch-table to avoid loops on short strings. 
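Given the fixed-buffer allocator initialized above, a caller can route the dynamically dispatched similarity functions through an on-stack arena. A hedged sketch: the arena size is arbitrary, and passing a zero bound is assumed to request the exact, unbounded distance.

#include <stringzilla/stringzilla.h>

static void demo_distance_on_stack(void) {
    char arena[4096]; // scratch space; any size large enough for a couple of matrix rows
    sz_memory_allocator_t alloc;
    sz_memory_allocator_init_fixed(&alloc, arena, sizeof(arena));
    // Assumption: a zero bound means "no upper bound" here.
    sz_size_t distance = sz_edit_distance("kitten", 6, "sitting", 7, 0, &alloc);
    (void)distance; // expected to be 3 for this pair, if the sketch's assumptions hold
}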
- */ -SZ_PUBLIC sz_bool_t sz_equal_serial(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - sz_cptr_t const a_end = a + length; -#if SZ_USE_MISALIGNED_LOADS - if (length >= SZ_SWAR_THRESHOLD) { - sz_u64_vec_t a_vec, b_vec; - for (; a + 8 <= a_end; a += 8, b += 8) { - a_vec = sz_u64_load(a); - b_vec = sz_u64_load(b); - if (a_vec.u64 != b_vec.u64) return sz_false_k; - } - } -#endif - while (a != a_end && *a == *b) a++, b++; - return (sz_bool_t)(a_end == a); -} - -SZ_PUBLIC sz_cptr_t sz_find_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { - for (sz_cptr_t const end = text + length; text != end; ++text) - if (sz_charset_contains(set, *text)) return text; - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Warray-bounds" - sz_cptr_t const end = text; - for (text += length; text != end;) - if (sz_charset_contains(set, *(text -= 1))) return text; - return SZ_NULL_CHAR; -#pragma GCC diagnostic pop -} - -/** - * One option to avoid branching is to use conditional moves and lookup the comparison result in a table: - * sz_ordering_t ordering_lookup[2] = {sz_greater_k, sz_less_k}; - * for (; a != min_end; ++a, ++b) - * if (*a != *b) return ordering_lookup[*a < *b]; - * That, however, introduces a data-dependency. - * A cleaner option is to perform two comparisons and a subtraction. - * One instruction more, but no data-dependency. - */ -#define _sz_order_scalars(a, b) ((sz_ordering_t)((a > b) - (a < b))) - -SZ_PUBLIC sz_ordering_t sz_order_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { - sz_bool_t a_shorter = (sz_bool_t)(a_length < b_length); - sz_size_t min_length = a_shorter ? a_length : b_length; - sz_cptr_t min_end = a + min_length; -#if SZ_USE_MISALIGNED_LOADS && !SZ_DETECT_BIG_ENDIAN - for (sz_u64_vec_t a_vec, b_vec; a + 8 <= min_end; a += 8, b += 8) { - a_vec = sz_u64_load(a); - b_vec = sz_u64_load(b); - if (a_vec.u64 != b_vec.u64) - return _sz_order_scalars(sz_u64_bytes_reverse(a_vec.u64), sz_u64_bytes_reverse(b_vec.u64)); - } -#endif - for (; a != min_end; ++a, ++b) - if (*a != *b) return _sz_order_scalars(*a, *b); - - // If the strings are equal up to `min_end`, then the shorter string is smaller - return _sz_order_scalars(a_length, b_length); -} - -/** - * @brief Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each byte is set. - // For that take the bottom 7 bits of each byte, add one to them, - // and if this sets the top bit to one, then all the 7 bits are ones as well. - vec.u64 = ((vec.u64 & 0x7F7F7F7F7F7F7F7Full) + 0x0101010101010101ull) & ((vec.u64 & 0x8080808080808080ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b single-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. - * Identical to `memchr(haystack, needle[0], haystack_length)`. - */ -SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - if (!h_length) return SZ_NULL_CHAR; - sz_cptr_t const h_end = h + h_length; - -#if !SZ_DETECT_BIG_ENDIAN // Use SWAR only on little-endian platforms for brevety. 
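    // For example, broadcasting the needle byte 'a' (0x61) gives 0x6161616161616161; against a
    // window holding "banana..", `~(h ^ n)` is 0xFF exactly in the matching bytes, the +1 carry
    // over the low 7 bits then raises the top bit only where all eight bits matched, and
    // `sz_u64_ctz(match) / 8` recovers the offset of the first 'a' inside that 8-byte window.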
-#if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)h & 7ull) && h < h_end; ++h) - if (*h == *n) return h; -#endif - - // Broadcast the n into every byte of a 64-bit integer to use SWAR - // techniques and process eight characters at a time. - sz_u64_vec_t h_vec, n_vec, match_vec; - match_vec.u64 = 0; - n_vec.u64 = (sz_u64_t)n[0] * 0x0101010101010101ull; - for (; h + 8 <= h_end; h += 8) { - h_vec.u64 = *(sz_u64_t const *)h; - match_vec = _sz_u64_each_byte_equal(h_vec, n_vec); - if (match_vec.u64) return h + sz_u64_ctz(match_vec.u64) / 8; - } -#endif - - // Handle the misaligned tail. - for (; h < h_end; ++h) - if (*h == *n) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief Find the last occurrence of a @b single-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. - * Identical to `memrchr(haystack, needle[0], haystack_length)`. - */ -sz_cptr_t sz_rfind_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - if (!h_length) return SZ_NULL_CHAR; - sz_cptr_t const h_start = h; - - // Reposition the `h` pointer to the end, as we will be walking backwards. - h = h + h_length - 1; - -#if !SZ_DETECT_BIG_ENDIAN // Use SWAR only on little-endian platforms for brevety. -#if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)(h + 1) & 7ull) && h >= h_start; --h) - if (*h == *n) return h; -#endif - - // Broadcast the n into every byte of a 64-bit integer to use SWAR - // techniques and process eight characters at a time. - sz_u64_vec_t h_vec, n_vec, match_vec; - n_vec.u64 = (sz_u64_t)n[0] * 0x0101010101010101ull; - for (; h >= h_start + 7; h -= 8) { - h_vec.u64 = *(sz_u64_t const *)(h - 7); - match_vec = _sz_u64_each_byte_equal(h_vec, n_vec); - if (match_vec.u64) return h - sz_u64_clz(match_vec.u64) / 8; - } -#endif - - for (; h >= h_start; --h) - if (*h == *n) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief 2Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each 2byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_2byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each 2byte is set. - // For that take the bottom 15 bits of each 2byte, add one to them, - // and if this sets the top bit to one, then all the 15 bits are ones as well. - vec.u64 = ((vec.u64 & 0x7FFF7FFF7FFF7FFFull) + 0x0001000100010001ull) & ((vec.u64 & 0x8000800080008000ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b two-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. - */ -SZ_INTERNAL sz_cptr_t _sz_find_2byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - // This is an internal method, and the haystack is guaranteed to be at least 2 bytes long. - sz_assert(h_length >= 2 && "The haystack is too short."); - sz_cptr_t const h_end = h + h_length; - -#if !SZ_USE_MISALIGNED_LOADS - // Process the misaligned head, to void UB on unaligned 64-bit loads. 
- for (; ((sz_size_t)h & 7ull) && h + 2 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) == 2) return h; -#endif - - sz_u64_vec_t h_even_vec, h_odd_vec, n_vec, matches_even_vec, matches_odd_vec; - n_vec.u64 = 0; - n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1]; - n_vec.u64 *= 0x0001000100010001ull; // broadcast - - // This code simulates hyper-scalar execution, analyzing 8 offsets at a time. - for (; h + 9 <= h_end; h += 8) { - h_even_vec.u64 = *(sz_u64_t *)h; - h_odd_vec.u64 = (h_even_vec.u64 >> 8) | ((sz_u64_t)h[8] << 56); - matches_even_vec = _sz_u64_each_2byte_equal(h_even_vec, n_vec); - matches_odd_vec = _sz_u64_each_2byte_equal(h_odd_vec, n_vec); - - matches_even_vec.u64 >>= 8; - if (matches_even_vec.u64 + matches_odd_vec.u64) { - sz_u64_t match_indicators = matches_even_vec.u64 | matches_odd_vec.u64; - return h + sz_u64_ctz(match_indicators) / 8; - } - } - - for (; h + 2 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) == 2) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief 4Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each 4byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_4byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each 4byte is set. - // For that take the bottom 31 bits of each 4byte, add one to them, - // and if this sets the top bit to one, then all the 31 bits are ones as well. - vec.u64 = ((vec.u64 & 0x7FFFFFFF7FFFFFFFull) + 0x0000000100000001ull) & ((vec.u64 & 0x8000000080000000ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b four-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. - */ -SZ_INTERNAL sz_cptr_t _sz_find_4byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - // This is an internal method, and the haystack is guaranteed to be at least 4 bytes long. - sz_assert(h_length >= 4 && "The haystack is too short."); - sz_cptr_t const h_end = h + h_length; - -#if !SZ_USE_MISALIGNED_LOADS - // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)h & 7ull) && h + 4 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) + (h[3] == n[3]) == 4) return h; -#endif - - sz_u64_vec_t h0_vec, h1_vec, h2_vec, h3_vec, n_vec, matches0_vec, matches1_vec, matches2_vec, matches3_vec; - n_vec.u64 = 0; - n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1], n_vec.u8s[2] = n[2], n_vec.u8s[3] = n[3]; - n_vec.u64 *= 0x0000000100000001ull; // broadcast - - // This code simulates hyper-scalar execution, analyzing 8 offsets at a time using four 64-bit words. - // We load the subsequent four-byte word as well, taking its first bytes. 
Think of it as a glorified prefetch :) - sz_u64_t h_page_current, h_page_next; - for (; h + sizeof(sz_u64_t) + sizeof(sz_u32_t) <= h_end; h += sizeof(sz_u64_t)) { - h_page_current = *(sz_u64_t *)h; - h_page_next = *(sz_u32_t *)(h + 8); - h0_vec.u64 = (h_page_current); - h1_vec.u64 = (h_page_current >> 8) | (h_page_next << 56); - h2_vec.u64 = (h_page_current >> 16) | (h_page_next << 48); - h3_vec.u64 = (h_page_current >> 24) | (h_page_next << 40); - matches0_vec = _sz_u64_each_4byte_equal(h0_vec, n_vec); - matches1_vec = _sz_u64_each_4byte_equal(h1_vec, n_vec); - matches2_vec = _sz_u64_each_4byte_equal(h2_vec, n_vec); - matches3_vec = _sz_u64_each_4byte_equal(h3_vec, n_vec); - - if (matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64) { - matches0_vec.u64 >>= 24; - matches1_vec.u64 >>= 16; - matches2_vec.u64 >>= 8; - sz_u64_t match_indicators = matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64; - return h + sz_u64_ctz(match_indicators) / 8; - } - } - - for (; h + 4 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) + (h[3] == n[3]) == 4) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief 3Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each 3byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_3byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each 4byte is set. - // For that take the bottom 31 bits of each 4byte, add one to them, - // and if this sets the top bit to one, then all the 31 bits are ones as well. - vec.u64 = ((vec.u64 & 0xFFFF7FFFFF7FFFFFull) + 0x0000000001000001ull) & ((vec.u64 & 0x0000800000800000ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b three-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. - */ -SZ_INTERNAL sz_cptr_t _sz_find_3byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - // This is an internal method, and the haystack is guaranteed to be at least 4 bytes long. - sz_assert(h_length >= 3 && "The haystack is too short."); - sz_cptr_t const h_end = h + h_length; - -#if !SZ_USE_MISALIGNED_LOADS - // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)h & 7ull) && h + 3 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) == 3) return h; -#endif - - // We fetch 12 - sz_u64_vec_t h0_vec, h1_vec, h2_vec, h3_vec, h4_vec; - sz_u64_vec_t matches0_vec, matches1_vec, matches2_vec, matches3_vec, matches4_vec; - sz_u64_vec_t n_vec; - n_vec.u64 = 0; - n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1], n_vec.u8s[2] = n[2]; - n_vec.u64 *= 0x0000000001000001ull; // broadcast - - // This code simulates hyper-scalar execution, analyzing 8 offsets at a time using three 64-bit words. - // We load the subsequent two-byte word as well. 
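    // With the three-byte needle copied to byte offsets 0 and 3 of `n_vec`, the five views shifted
    // by 0..4 bytes test candidate matches starting at offsets {0,3}, {1,4}, {2,5}, {3,6}, and
    // {4,7}, which together cover every position inside the current 8-byte window before the loop
    // advances by eight bytes.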
- sz_u64_t h_page_current, h_page_next; - for (; h + sizeof(sz_u64_t) + sizeof(sz_u16_t) <= h_end; h += sizeof(sz_u64_t)) { - h_page_current = *(sz_u64_t *)h; - h_page_next = *(sz_u16_t *)(h + 8); - h0_vec.u64 = (h_page_current); - h1_vec.u64 = (h_page_current >> 8) | (h_page_next << 56); - h2_vec.u64 = (h_page_current >> 16) | (h_page_next << 48); - h3_vec.u64 = (h_page_current >> 24) | (h_page_next << 40); - h4_vec.u64 = (h_page_current >> 32) | (h_page_next << 32); - matches0_vec = _sz_u64_each_3byte_equal(h0_vec, n_vec); - matches1_vec = _sz_u64_each_3byte_equal(h1_vec, n_vec); - matches2_vec = _sz_u64_each_3byte_equal(h2_vec, n_vec); - matches3_vec = _sz_u64_each_3byte_equal(h3_vec, n_vec); - matches4_vec = _sz_u64_each_3byte_equal(h4_vec, n_vec); - - if (matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64 | matches4_vec.u64) { - matches0_vec.u64 >>= 16; - matches1_vec.u64 >>= 8; - matches3_vec.u64 <<= 8; - matches4_vec.u64 <<= 16; - sz_u64_t match_indicators = - matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64 | matches4_vec.u64; - return h + sz_u64_ctz(match_indicators) / 8; - } - } - - for (; h + 3 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) == 3) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief Boyer-Moore-Horspool algorithm for exact matching of patterns up to @b 256-bytes long. - * Uses the Raita heuristic to match the first two, the last, and the middle character of the pattern. - */ -SZ_INTERNAL sz_cptr_t _sz_find_horspool_upto_256bytes_serial(sz_cptr_t h_chars, sz_size_t h_length, // - sz_cptr_t n_chars, sz_size_t n_length) { - sz_assert(n_length <= 256 && "The pattern is too long."); - // Several popular string matching algorithms are using a bad-character shift table. - // Boyer Moore: https://www-igm.univ-mlv.fr/~lecroq/string/node14.html - // Quick Search: https://www-igm.univ-mlv.fr/~lecroq/string/node19.html - // Smith: https://www-igm.univ-mlv.fr/~lecroq/string/node21.html - union { - sz_u8_t jumps[256]; - sz_u64_vec_t vecs[64]; - } bad_shift_table; - - // Let's initialize the table using SWAR to the total length of the string. - sz_u8_t const *h = (sz_u8_t const *)h_chars; - sz_u8_t const *n = (sz_u8_t const *)n_chars; - { - sz_u64_vec_t n_length_vec; - n_length_vec.u64 = n_length; - n_length_vec.u64 *= 0x0101010101010101ull; // broadcast - for (sz_size_t i = 0; i != 64; ++i) bad_shift_table.vecs[i].u64 = n_length_vec.u64; - for (sz_size_t i = 0; i + 1 < n_length; ++i) bad_shift_table.jumps[n[i]] = (sz_u8_t)(n_length - i - 1); - } - - // Another common heuristic is to match a few characters from different parts of a string. - // Raita suggests to use the first two, the last, and the middle character of the pattern. - sz_u32_vec_t h_vec, n_vec; - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n_chars, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into an unsigned integer. - n_vec.u8s[0] = n[offset_first]; - n_vec.u8s[1] = n[offset_first + 1]; - n_vec.u8s[2] = n[offset_mid]; - n_vec.u8s[3] = n[offset_last]; - - // Scan through the whole haystack, skipping the last `n_length - 1` bytes. 
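    // For a needle like "abcd" the table maps 'a' to 3, 'b' to 2, 'c' to 1, and every other byte,
    // including the final 'd', to the full needle length of 4; each unsuccessful alignment then
    // advances by the jump stored for the haystack byte sitting under the needle's last position.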
- for (sz_size_t i = 0; i <= h_length - n_length;) { - h_vec.u8s[0] = h[i + offset_first]; - h_vec.u8s[1] = h[i + offset_first + 1]; - h_vec.u8s[2] = h[i + offset_mid]; - h_vec.u8s[3] = h[i + offset_last]; - if (h_vec.u32 == n_vec.u32 && sz_equal((sz_cptr_t)h + i, n_chars, n_length)) return (sz_cptr_t)h + i; - i += bad_shift_table.jumps[h[i + n_length - 1]]; - } - return SZ_NULL_CHAR; -} - -/** - * @brief Boyer-Moore-Horspool algorithm for @b reverse-order exact matching of patterns up to @b 256-bytes long. - * Uses the Raita heuristic to match the first two, the last, and the middle character of the pattern. - */ -SZ_INTERNAL sz_cptr_t _sz_rfind_horspool_upto_256bytes_serial(sz_cptr_t h_chars, sz_size_t h_length, // - sz_cptr_t n_chars, sz_size_t n_length) { - sz_assert(n_length <= 256 && "The pattern is too long."); - union { - sz_u8_t jumps[256]; - sz_u64_vec_t vecs[64]; - } bad_shift_table; - - // Let's initialize the table using SWAR to the total length of the string. - sz_u8_t const *h = (sz_u8_t const *)h_chars; - sz_u8_t const *n = (sz_u8_t const *)n_chars; - { - sz_u64_vec_t n_length_vec; - n_length_vec.u64 = n_length; - n_length_vec.u64 *= 0x0101010101010101ull; // broadcast - for (sz_size_t i = 0; i != 64; ++i) bad_shift_table.vecs[i].u64 = n_length_vec.u64; - for (sz_size_t i = 0; i + 1 < n_length; ++i) - bad_shift_table.jumps[n[n_length - i - 1]] = (sz_u8_t)(n_length - i - 1); - } - - // Another common heuristic is to match a few characters from different parts of a string. - // Raita suggests to use the first two, the last, and the middle character of the pattern. - sz_u32_vec_t h_vec, n_vec; - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n_chars, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into an unsigned integer. - n_vec.u8s[0] = n[offset_first]; - n_vec.u8s[1] = n[offset_first + 1]; - n_vec.u8s[2] = n[offset_mid]; - n_vec.u8s[3] = n[offset_last]; - - // Scan through the whole haystack, skipping the first `n_length - 1` bytes. - for (sz_size_t j = 0; j <= h_length - n_length;) { - sz_size_t i = h_length - n_length - j; - h_vec.u8s[0] = h[i + offset_first]; - h_vec.u8s[1] = h[i + offset_first + 1]; - h_vec.u8s[2] = h[i + offset_mid]; - h_vec.u8s[3] = h[i + offset_last]; - if (h_vec.u32 == n_vec.u32 && sz_equal((sz_cptr_t)h + i, n_chars, n_length)) return (sz_cptr_t)h + i; - j += bad_shift_table.jumps[h[i]]; - } - return SZ_NULL_CHAR; -} - -/** - * @brief Exact substring search helper function, that finds the first occurrence of a prefix of the needle - * using a given search function, and then verifies the remaining part of the needle. - */ -SZ_INTERNAL sz_cptr_t _sz_find_with_prefix(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length, - sz_find_t find_prefix, sz_size_t prefix_length) { - - sz_size_t suffix_length = n_length - prefix_length; - while (1) { - sz_cptr_t found = find_prefix(h, h_length, n, prefix_length); - if (!found) return SZ_NULL_CHAR; - - // Verify the remaining part of the needle - sz_size_t remaining = h_length - (found - h); - if (remaining < n_length) return SZ_NULL_CHAR; - if (sz_equal(found + prefix_length, n + prefix_length, suffix_length)) return found; - - // Adjust the position. 
- h = found + 1; - h_length = remaining - 1; - } - - // Unreachable, but helps silence compiler warnings: - return SZ_NULL_CHAR; -} - -/** - * @brief Exact reverse-order substring search helper function, that finds the last occurrence of a suffix of the - * needle using a given search function, and then verifies the remaining part of the needle. - */ -SZ_INTERNAL sz_cptr_t _sz_rfind_with_suffix(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length, - sz_find_t find_suffix, sz_size_t suffix_length) { - - sz_size_t prefix_length = n_length - suffix_length; - while (1) { - sz_cptr_t found = find_suffix(h, h_length, n + prefix_length, suffix_length); - if (!found) return SZ_NULL_CHAR; - - // Verify the remaining part of the needle - sz_size_t remaining = found - h; - if (remaining < prefix_length) return SZ_NULL_CHAR; - if (sz_equal(found - prefix_length, n, prefix_length)) return found - prefix_length; - - // Adjust the position. - h_length = remaining - 1; - } - - // Unreachable, but helps silence compiler warnings: - return SZ_NULL_CHAR; -} - -SZ_INTERNAL sz_cptr_t _sz_find_over_4bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - return _sz_find_with_prefix(h, h_length, n, n_length, (sz_find_t)_sz_find_4byte_serial, 4); -} - -SZ_INTERNAL sz_cptr_t _sz_find_horspool_over_256bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, - sz_size_t n_length) { - return _sz_find_with_prefix(h, h_length, n, n_length, _sz_find_horspool_upto_256bytes_serial, 256); -} - -SZ_INTERNAL sz_cptr_t _sz_rfind_horspool_over_256bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, - sz_size_t n_length) { - return _sz_rfind_with_suffix(h, h_length, n, n_length, _sz_rfind_horspool_upto_256bytes_serial, 256); -} - -SZ_PUBLIC sz_cptr_t sz_find_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - -#if SZ_DETECT_BIG_ENDIAN - sz_find_t backends[] = { - (sz_find_t)sz_find_byte_serial, - (sz_find_t)_sz_find_horspool_upto_256bytes_serial, - (sz_find_t)_sz_find_horspool_over_256bytes_serial, - }; - - return backends[(n_length > 1) + (n_length > 256)](h, h_length, n, n_length); -#else - sz_find_t backends[] = { - // For very short strings brute-force SWAR makes sense. - (sz_find_t)sz_find_byte_serial, - (sz_find_t)_sz_find_2byte_serial, - (sz_find_t)_sz_find_3byte_serial, - (sz_find_t)_sz_find_4byte_serial, - // To avoid constructing the skip-table, let's use the prefixed approach. - (sz_find_t)_sz_find_over_4bytes_serial, - // For longer needles - use skip tables. - (sz_find_t)_sz_find_horspool_upto_256bytes_serial, - (sz_find_t)_sz_find_horspool_over_256bytes_serial, - }; - - return backends[ - // For very short strings brute-force SWAR makes sense. - (n_length > 1) + (n_length > 2) + (n_length > 3) + - // To avoid constructing the skip-table, let's use the prefixed approach. - (n_length > 4) + - // For longer needles - use skip tables. - (n_length > 8) + (n_length > 256)](h, h_length, n, n_length); -#endif -} - -SZ_PUBLIC sz_cptr_t sz_rfind_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - - sz_find_t backends[] = { - // For very short strings brute-force SWAR makes sense. 
- (sz_find_t)sz_rfind_byte_serial, - // TODO: implement reverse-order SWAR for 2/3/4 byte variants. - // TODO: (sz_find_t)_sz_rfind_2byte_serial, - // TODO: (sz_find_t)_sz_rfind_3byte_serial, - // TODO: (sz_find_t)_sz_rfind_4byte_serial, - // To avoid constructing the skip-table, let's use the prefixed approach. - // (sz_find_t)_sz_rfind_over_4bytes_serial, - // For longer needles - use skip tables. - (sz_find_t)_sz_rfind_horspool_upto_256bytes_serial, - (sz_find_t)_sz_rfind_horspool_over_256bytes_serial, - }; - - return backends[ - // For very short strings brute-force SWAR makes sense. - 0 + - // To avoid constructing the skip-table, let's use the prefixed approach. - (n_length > 1) + - // For longer needles - use skip tables. - (n_length > 256)](h, h_length, n, n_length); -} - -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_serial( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - - // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. - sz_memory_allocator_t global_alloc; - if (!alloc) { - sz_memory_allocator_init_default(&global_alloc); - alloc = &global_alloc; - } - - // TODO: Generalize to remove the following asserts! - sz_assert(!bound && "For bounded search the method should only evaluate one band of the matrix."); - sz_assert(shorter_length == longer_length && "The method hasn't been generalized to different length inputs yet."); - sz_unused(longer_length && bound); - - // We are going to store 3 diagonals of the matrix. - // The length of the longest (main) diagonal would be `n = (shorter_length + 1)`. - sz_size_t n = shorter_length + 1; - sz_size_t buffer_length = sizeof(sz_size_t) * n * 3; - sz_size_t *distances = (sz_size_t *)alloc->allocate(buffer_length, alloc->handle); - if (!distances) return SZ_SIZE_MAX; - - sz_size_t *previous_distances = distances; - sz_size_t *current_distances = previous_distances + n; - sz_size_t *next_distances = previous_distances + n * 2; - - // Initialize the first two diagonals: - previous_distances[0] = 0; - current_distances[0] = current_distances[1] = 1; - - // Progress through the upper triangle of the Levenshtein matrix. - sz_size_t next_diagonal_index = 2; - for (; next_diagonal_index != n; ++next_diagonal_index) { - sz_size_t const next_diagonal_length = next_diagonal_index + 1; - for (sz_size_t i = 0; i + 2 < next_diagonal_length; ++i) { - sz_size_t cost_of_substitution = shorter[next_diagonal_index - i - 2] != longer[i]; - sz_size_t cost_if_substitution = previous_distances[i] + cost_of_substitution; - sz_size_t cost_if_deletion_or_insertion = sz_min_of_two(current_distances[i], current_distances[i + 1]) + 1; - next_distances[i + 1] = sz_min_of_two(cost_if_deletion_or_insertion, cost_if_substitution); - } - // Don't forget to populate the first row and the first column of the Levenshtein matrix. - next_distances[0] = next_distances[next_diagonal_length - 1] = next_diagonal_index; - // Perform a circular rotation of those buffers, to reuse the memory. - sz_size_t *temporary = previous_distances; - previous_distances = current_distances; - current_distances = next_distances; - next_distances = temporary; - } - - // By now we've scanned through the upper triangle of the matrix, where each subsequent iteration results in a - // larger diagonal. From now onwards, we will be shrinking. 
Instead of adding value equal to the skewed diagonal - // index on either side, we will be cropping those values out. - sz_size_t diagonals_count = n + n - 1; - for (; next_diagonal_index != diagonals_count; ++next_diagonal_index) { - sz_size_t const next_diagonal_length = diagonals_count - next_diagonal_index; - for (sz_size_t i = 0; i != next_diagonal_length; ++i) { - sz_size_t cost_of_substitution = shorter[shorter_length - 1 - i] != longer[next_diagonal_index - n + i]; - sz_size_t cost_if_substitution = previous_distances[i] + cost_of_substitution; - sz_size_t cost_if_deletion_or_insertion = sz_min_of_two(current_distances[i], current_distances[i + 1]) + 1; - next_distances[i] = sz_min_of_two(cost_if_deletion_or_insertion, cost_if_substitution); - } - // Perform a circular rotation of those buffers, to reuse the memory, this time, with a shift, - // dropping the first element in the current array. - sz_size_t *temporary = previous_distances; - previous_distances = current_distances + 1; - current_distances = next_distances; - next_distances = temporary; - } - - // Cache scalar before `free` call. - sz_size_t result = current_distances[0]; - alloc->free(distances, buffer_length, alloc->handle); - return result; -} - -/** - * @brief Describes the length of a UTF8 character / codepoint / rune in bytes. - */ -typedef enum { - sz_utf8_invalid_k = 0, //!< Invalid UTF8 character. - sz_utf8_rune_1byte_k = 1, //!< 1-byte UTF8 character. - sz_utf8_rune_2bytes_k = 2, //!< 2-byte UTF8 character. - sz_utf8_rune_3bytes_k = 3, //!< 3-byte UTF8 character. - sz_utf8_rune_4bytes_k = 4, //!< 4-byte UTF8 character. -} sz_rune_length_t; - -typedef sz_u32_t sz_rune_t; - -/** - * @brief Extracts just one UTF8 codepoint from a UTF8 string into a 32-bit unsigned integer. - */ -SZ_INTERNAL void _sz_extract_utf8_rune(sz_cptr_t utf8, sz_rune_t *code, sz_rune_length_t *code_length) { - sz_u8_t const *current = (sz_u8_t const *)utf8; - sz_u8_t leading_byte = *current++; - sz_rune_t ch; - sz_rune_length_t ch_length; - - // TODO: This can be made entirely branchless using 32-bit SWAR. - if (leading_byte < 0x80) { - // Single-byte rune (0xxxxxxx) - ch = leading_byte; - ch_length = sz_utf8_rune_1byte_k; - } - else if ((leading_byte & 0xE0) == 0xC0) { - // Two-byte rune (110xxxxx 10xxxxxx) - ch = (leading_byte & 0x1F) << 6; - ch |= (*current++ & 0x3F); - ch_length = sz_utf8_rune_2bytes_k; - } - else if ((leading_byte & 0xF0) == 0xE0) { - // Three-byte rune (1110xxxx 10xxxxxx 10xxxxxx) - ch = (leading_byte & 0x0F) << 12; - ch |= (*current++ & 0x3F) << 6; - ch |= (*current++ & 0x3F); - ch_length = sz_utf8_rune_3bytes_k; - } - else if ((leading_byte & 0xF8) == 0xF0) { - // Four-byte rune (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx) - ch = (leading_byte & 0x07) << 18; - ch |= (*current++ & 0x3F) << 12; - ch |= (*current++ & 0x3F) << 6; - ch |= (*current++ & 0x3F); - ch_length = sz_utf8_rune_4bytes_k; - } - else { - // Invalid UTF8 rune. - ch = 0; - ch_length = sz_utf8_invalid_k; - } - *code = ch; - *code_length = ch_length; -} - -/** - * @brief Exports a UTF8 string into a UTF32 buffer. - * ! The result is undefined id the UTF8 string is corrupted. - * @return The length in the number of codepoints. 
- */ -SZ_INTERNAL sz_size_t _sz_export_utf8_to_utf32(sz_cptr_t utf8, sz_size_t utf8_length, sz_rune_t *utf32) { - sz_cptr_t const end = utf8 + utf8_length; - sz_size_t count = 0; - sz_rune_length_t rune_length; - for (; utf8 != end; utf8 += rune_length, utf32++, count++) _sz_extract_utf8_rune(utf8, utf32, &rune_length); - return count; -} - -/** - * @brief Compute the Levenshtein distance between two strings using the Wagner-Fisher algorithm. - * Stores only 2 rows of the Levenshtein matrix, but uses 64-bit integers for the distance values, - * and upcasts UTF8 variable-length codepoints to 64-bit integers for faster addressing. - * - * ! In the worst case for 2 strings of length 100, that contain just one 16-bit codepoint this will result in extra: - * + 2 rows * 100 slots * 8 bytes/slot = 1600 bytes of memory for the two rows of the Levenshtein matrix rows. - * + 100 codepoints * 2 strings * 4 bytes/codepoint = 800 bytes of memory for the UTF8 buffer. - * = 2400 bytes of memory or @b 12x memory amplification! - */ -SZ_INTERNAL sz_size_t _sz_edit_distance_wagner_fisher_serial( // - sz_cptr_t longer, sz_size_t longer_length, // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_size_t bound, sz_bool_t can_be_unicode, sz_memory_allocator_t *alloc) { - - // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. - sz_memory_allocator_t global_alloc; - if (!alloc) { - sz_memory_allocator_init_default(&global_alloc); - alloc = &global_alloc; - } - - // A good idea may be to dispatch different kernels for different string lengths. - // Like using `uint8_t` counters for strings under 255 characters long. - // Good in theory, this results in frequent upcasts and downcasts in serial code. - // On strings over 20 bytes, using `uint8` over `uint64` on 64-bit x86 CPU doubles the execution time. - // So one must be very cautious with such optimizations. - typedef sz_size_t _distance_t; - - // Compute the number of columns in our Levenshtein matrix. - sz_size_t const n = shorter_length + 1; - - // If a buffering memory-allocator is provided, this operation is practically free, - // and cheaper than allocating even 512 bytes (for small distance matrices) on stack. - sz_size_t buffer_length = sizeof(_distance_t) * (n * 2); - - // If the strings contain Unicode characters, let's estimate the max character width, - // and use it to allocate a larger buffer to decode UTF8. - if ((can_be_unicode == sz_true_k) && - (sz_isascii(longer, longer_length) == sz_false_k || sz_isascii(shorter, shorter_length) == sz_false_k)) { - buffer_length += (shorter_length + longer_length) * sizeof(sz_rune_t); - } - else { can_be_unicode = sz_false_k; } - - // If the allocation fails, return the maximum distance. - sz_ptr_t const buffer = (sz_ptr_t)alloc->allocate(buffer_length, alloc->handle); - if (!buffer) return SZ_SIZE_MAX; - - // Let's export the UTF8 sequence into the newly allocated buffer at the end. - if (can_be_unicode == sz_true_k) { - sz_rune_t *const longer_utf32 = (sz_rune_t *)(buffer + sizeof(_distance_t) * (n * 2)); - sz_rune_t *const shorter_utf32 = longer_utf32 + longer_length; - // Export the UTF8 sequences into the newly allocated buffer. 
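
As a worked illustration of the two helpers above, here is a minimal sketch decoding two common codepoints; the local names are placeholders, not part of the patch:

    sz_rune_t rune;
    sz_rune_length_t rune_length;
    _sz_extract_utf8_rune("\xC3\xA9", &rune, &rune_length);     // U+00E9: ((0xC3 & 0x1F) << 6) | (0xA9 & 0x3F), length 2
    _sz_extract_utf8_rune("\xE2\x82\xAC", &rune, &rune_length); // U+20AC: a three-byte rune, length 3
    sz_rune_t utf32[2];
    sz_size_t count = _sz_export_utf8_to_utf32("\xC3\xA9\xE2\x82\xAC", 5, utf32); // 5 bytes -> 2 codepoints
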
- longer_length = _sz_export_utf8_to_utf32(longer, longer_length, longer_utf32); - shorter_length = _sz_export_utf8_to_utf32(shorter, shorter_length, shorter_utf32); - longer = (sz_cptr_t)longer_utf32; - shorter = (sz_cptr_t)shorter_utf32; - } - - // Let's parameterize the core logic for different character types and distance types. -#define _wagner_fisher_unbounded(_distance_t, _char_t) \ - /* Now let's cast our pointer to avoid it in subsequent sections. */ \ - _char_t const *const longer_chars = (_char_t const *)longer; \ - _char_t const *const shorter_chars = (_char_t const *)shorter; \ - _distance_t *previous_distances = (_distance_t *)buffer; \ - _distance_t *current_distances = previous_distances + n; \ - /* Initialize the first row of the Levenshtein matrix with `iota`-style arithmetic progression. */ \ - for (_distance_t idx_shorter = 0; idx_shorter != n; ++idx_shorter) previous_distances[idx_shorter] = idx_shorter; \ - /* The main loop of the algorithm with quadratic complexity. */ \ - for (_distance_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) { \ - _char_t const longer_char = longer_chars[idx_longer]; \ - /* Using pure pointer arithmetic is faster than iterating with an index. */ \ - _char_t const *shorter_ptr = shorter_chars; \ - _distance_t const *previous_ptr = previous_distances; \ - _distance_t *current_ptr = current_distances; \ - _distance_t *const current_end = current_ptr + shorter_length; \ - current_ptr[0] = idx_longer + 1; \ - for (; current_ptr != current_end; ++previous_ptr, ++current_ptr, ++shorter_ptr) { \ - _distance_t cost_substitution = previous_ptr[0] + (_distance_t)(longer_char != shorter_ptr[0]); \ - /* We can avoid `+1` for costs here, shifting it to post-minimum computation, */ \ - /* saving one increment operation. */ \ - _distance_t cost_deletion = previous_ptr[1]; \ - _distance_t cost_insertion = current_ptr[0]; \ - /* ? It might be a good idea to enforce branchless execution here. */ \ - /* ? The caveat being that the benchmarks on longer sequences backfire and more research is needed. */ \ - current_ptr[1] = sz_min_of_two(cost_substitution, sz_min_of_two(cost_deletion, cost_insertion) + 1); \ - } \ - /* Swap `previous_distances` and `current_distances` pointers. */ \ - _distance_t *temporary = previous_distances; \ - previous_distances = current_distances; \ - current_distances = temporary; \ - } \ - /* Cache scalar before `free` call. */ \ - sz_size_t result = previous_distances[shorter_length]; \ - alloc->free(buffer, buffer_length, alloc->handle); \ - return result; - - // Let's define a separate variant for bounded distance computation. - // Practically the same as unbounded, but also collecting the running minimum within each row for early exit. 
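
Before the bounded variant defined next, here is a stand-alone illustration of the same two-row recurrence for plain byte strings: a minimal sketch assuming inputs of at most 64 characters, with a hypothetical name, not the macro itself:

    static sz_size_t illustrative_levenshtein(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) {
        sz_size_t previous[64 + 1], current[64 + 1]; // the real kernel allocates its rows dynamically
        for (sz_size_t j = 0; j <= b_length; ++j) previous[j] = j;
        for (sz_size_t i = 0; i != a_length; ++i) {
            current[0] = i + 1;
            for (sz_size_t j = 0; j != b_length; ++j) {
                sz_size_t if_substitution = previous[j] + (a[i] != b[j]);
                sz_size_t if_deletion_or_insertion = sz_min_of_two(previous[j + 1], current[j]) + 1;
                current[j + 1] = sz_min_of_two(if_substitution, if_deletion_or_insertion);
            }
            for (sz_size_t j = 0; j <= b_length; ++j) previous[j] = current[j]; // swap by copying, for brevity
        }
        return previous[b_length];
    }
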
-#define _wagner_fisher_bounded(_distance_t, _char_t) \ - _char_t const *const longer_chars = (_char_t const *)longer; \ - _char_t const *const shorter_chars = (_char_t const *)shorter; \ - _distance_t *previous_distances = (_distance_t *)buffer; \ - _distance_t *current_distances = previous_distances + n; \ - for (_distance_t idx_shorter = 0; idx_shorter != n; ++idx_shorter) previous_distances[idx_shorter] = idx_shorter; \ - for (_distance_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) { \ - _char_t const longer_char = longer_chars[idx_longer]; \ - _char_t const *shorter_ptr = shorter_chars; \ - _distance_t const *previous_ptr = previous_distances; \ - _distance_t *current_ptr = current_distances; \ - _distance_t *const current_end = current_ptr + shorter_length; \ - current_ptr[0] = idx_longer + 1; \ - /* Initialize min_distance with a value greater than bound */ \ - _distance_t min_distance = bound - 1; \ - for (; current_ptr != current_end; ++previous_ptr, ++current_ptr, ++shorter_ptr) { \ - _distance_t cost_substitution = previous_ptr[0] + (_distance_t)(longer_char != shorter_ptr[0]); \ - _distance_t cost_deletion = previous_ptr[1]; \ - _distance_t cost_insertion = current_ptr[0]; \ - current_ptr[1] = sz_min_of_two(cost_substitution, sz_min_of_two(cost_deletion, cost_insertion) + 1); \ - /* Keep track of the minimum distance seen so far in this row */ \ - min_distance = sz_min_of_two(current_ptr[1], min_distance); \ - } \ - /* If the minimum distance in this row exceeded the bound, return early */ \ - if (min_distance >= bound) { \ - alloc->free(buffer, buffer_length, alloc->handle); \ - return bound; \ - } \ - _distance_t *temporary = previous_distances; \ - previous_distances = current_distances; \ - current_distances = temporary; \ - } \ - sz_size_t result = previous_distances[shorter_length]; \ - alloc->free(buffer, buffer_length, alloc->handle); \ - return sz_min_of_two(result, bound); - - // Dispatch the actual computation. - if (!bound) { - if (can_be_unicode == sz_true_k) { _wagner_fisher_unbounded(sz_size_t, sz_rune_t); } - else { _wagner_fisher_unbounded(sz_size_t, sz_u8_t); } - } - else { - if (can_be_unicode == sz_true_k) { _wagner_fisher_bounded(sz_size_t, sz_rune_t); } - else { _wagner_fisher_bounded(sz_size_t, sz_u8_t); } - } -} - -SZ_PUBLIC sz_size_t sz_edit_distance_serial( // - sz_cptr_t longer, sz_size_t longer_length, // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - - // Let's make sure that we use the amount proportional to the - // number of elements in the shorter string, not the larger. - if (shorter_length > longer_length) { - sz_pointer_swap((void **)&longer_length, (void **)&shorter_length); - sz_pointer_swap((void **)&longer, (void **)&shorter); - } - - // Skip the matching prefixes and suffixes, they won't affect the distance. - for (sz_cptr_t a_end = longer + longer_length, b_end = shorter + shorter_length; - longer != a_end && shorter != b_end && *longer == *shorter; - ++longer, ++shorter, --longer_length, --shorter_length); - for (; longer_length && shorter_length && longer[longer_length - 1] == shorter[shorter_length - 1]; - --longer_length, --shorter_length); - - // Bounded computations may exit early. - int const is_bounded = bound < longer_length; - if (is_bounded) { - // If one of the strings is empty - the edit distance is equal to the length of the other one. 
- if (longer_length == 0) return sz_min_of_two(shorter_length, bound); - if (shorter_length == 0) return sz_min_of_two(longer_length, bound); - // If the difference in length is beyond the `bound`, there is no need to check at all. - if (longer_length - shorter_length > bound) return bound; - } - - if (shorter_length == 0) return longer_length; // If no mismatches were found - the distance is zero. - if (shorter_length == longer_length && !is_bounded) - return _sz_edit_distance_skewed_diagonals_serial(longer, longer_length, shorter, shorter_length, bound, alloc); - return _sz_edit_distance_wagner_fisher_serial(longer, longer_length, shorter, shorter_length, bound, sz_false_k, - alloc); -} - -SZ_PUBLIC sz_ssize_t sz_alignment_score_serial( // - sz_cptr_t longer, sz_size_t longer_length, // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, // - sz_memory_allocator_t *alloc) { - - // If one of the strings is empty - the edit distance is equal to the length of the other one - if (longer_length == 0) return (sz_ssize_t)shorter_length * gap; - if (shorter_length == 0) return (sz_ssize_t)longer_length * gap; - - // Let's make sure that we use the amount proportional to the - // number of elements in the shorter string, not the larger. - if (shorter_length > longer_length) { - sz_pointer_swap((void **)&longer_length, (void **)&shorter_length); - sz_pointer_swap((void **)&longer, (void **)&shorter); - } - - // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. - sz_memory_allocator_t global_alloc; - if (!alloc) { - sz_memory_allocator_init_default(&global_alloc); - alloc = &global_alloc; - } - - sz_size_t n = shorter_length + 1; - sz_size_t buffer_length = sizeof(sz_ssize_t) * n * 2; - sz_ssize_t *distances = (sz_ssize_t *)alloc->allocate(buffer_length, alloc->handle); - sz_ssize_t *previous_distances = distances; - sz_ssize_t *current_distances = previous_distances + n; - - for (sz_size_t idx_shorter = 0; idx_shorter != n; ++idx_shorter) - previous_distances[idx_shorter] = (sz_ssize_t)idx_shorter * gap; - - sz_u8_t const *shorter_unsigned = (sz_u8_t const *)shorter; - sz_u8_t const *longer_unsigned = (sz_u8_t const *)longer; - for (sz_size_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) { - current_distances[0] = ((sz_ssize_t)idx_longer + 1) * gap; - - // Initialize min_distance with a value greater than bound - sz_error_cost_t const *a_subs = subs + longer_unsigned[idx_longer] * 256ul; - for (sz_size_t idx_shorter = 0; idx_shorter != shorter_length; ++idx_shorter) { - sz_ssize_t cost_deletion = previous_distances[idx_shorter + 1] + gap; - sz_ssize_t cost_insertion = current_distances[idx_shorter] + gap; - sz_ssize_t cost_substitution = previous_distances[idx_shorter] + a_subs[shorter_unsigned[idx_shorter]]; - current_distances[idx_shorter + 1] = sz_max_of_three(cost_deletion, cost_insertion, cost_substitution); - } - - // Swap previous_distances and current_distances pointers - sz_pointer_swap((void **)&previous_distances, (void **)¤t_distances); - } - - // Cache scalar before `free` call. 
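
A hypothetical usage sketch tying the two serial kernels together: `illustrative_costs` and `illustrate_scores` are made-up names, the cost layout simply mirrors the `subs + character * 256` indexing above, and `NULL` falls back to the default allocator:

    static sz_error_cost_t illustrative_costs[256 * 256];
    static void illustrate_scores(void) {
        // 0 for a match, -1 for a mismatch or a gap: the best score is then the negated Levenshtein distance.
        for (sz_size_t i = 0; i != 256; ++i)
            for (sz_size_t j = 0; j != 256; ++j) illustrative_costs[i * 256 + j] = i == j ? 0 : -1;
        sz_size_t distance = sz_edit_distance_serial("kitten", 6, "sitting", 7, 0, NULL); // == 3
        sz_ssize_t score = sz_alignment_score_serial("kitten", 6, "sitting", 7, illustrative_costs, -1, NULL);
        sz_assert(score == -(sz_ssize_t)distance); // the same 3 edits: k->s, e->i, and one insertion
        sz_unused(distance && score);
    }
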
- sz_ssize_t result = previous_distances[shorter_length]; - alloc->free(distances, buffer_length, alloc->handle); - return result; -} - -SZ_PUBLIC sz_size_t sz_hamming_distance_serial( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - - sz_size_t const min_length = sz_min_of_two(a_length, b_length); - sz_size_t const max_length = sz_max_of_two(a_length, b_length); - sz_cptr_t const a_end = a + min_length; - bound = bound == 0 ? max_length : bound; - - // Walk through both strings using SWAR and counting the number of differing characters. - sz_size_t distance = max_length - min_length; -#if SZ_USE_MISALIGNED_LOADS && !SZ_DETECT_BIG_ENDIAN - if (min_length >= SZ_SWAR_THRESHOLD) { - sz_u64_vec_t a_vec, b_vec, match_vec; - for (; a + 8 <= a_end && distance < bound; a += 8, b += 8) { - a_vec.u64 = sz_u64_load(a).u64; - b_vec.u64 = sz_u64_load(b).u64; - match_vec = _sz_u64_each_byte_equal(a_vec, b_vec); - distance += sz_u64_popcount((~match_vec.u64) & 0x8080808080808080ull); - } - } -#endif - - for (; a != a_end && distance < bound; ++a, ++b) { distance += (*a != *b); } - return sz_min_of_two(distance, bound); -} - -SZ_PUBLIC sz_size_t sz_hamming_distance_utf8_serial( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - - sz_cptr_t const a_end = a + a_length; - sz_cptr_t const b_end = b + b_length; - sz_size_t distance = 0; - - sz_rune_t a_rune, b_rune; - sz_rune_length_t a_rune_length, b_rune_length; - - if (bound) { - for (; a < a_end && b < b_end && distance < bound; a += a_rune_length, b += b_rune_length) { - _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - distance += (a_rune != b_rune); - } - // If one string has more runes, we need to go through the tail. - if (distance < bound) { - for (; a < a_end && distance < bound; a += a_rune_length, ++distance) - _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - - for (; b < b_end && distance < bound; b += b_rune_length, ++distance) - _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - } - } - else { - for (; a < a_end && b < b_end; a += a_rune_length, b += b_rune_length) { - _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - distance += (a_rune != b_rune); - } - // If one string has more runes, we need to go through the tail. - for (; a < a_end; a += a_rune_length, ++distance) _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - for (; b < b_end; b += b_rune_length, ++distance) _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - } - return distance; -} - -SZ_PUBLIC sz_u64_t sz_checksum_serial(sz_cptr_t text, sz_size_t length) { - sz_u64_t checksum = 0; - sz_u8_t const *text_u8 = (sz_u8_t const *)text; - sz_u8_t const *text_end = text_u8 + length; - for (; text_u8 != text_end; ++text_u8) checksum += *text_u8; - return checksum; -} - -/** - * @brief Largest prime number that fits into 31 bits. - * @see https://mersenneforum.org/showthread.php?t=3471 - */ -#define SZ_U32_MAX_PRIME (2147483647u) - -/** - * @brief Largest prime number that fits into 64 bits. - * @see https://mersenneforum.org/showthread.php?t=3471 - * - * 2^64 = 18,446,744,073,709,551,616 - * this = 18,446,744,073,709,551,557 - * diff = 59 - */ -#define SZ_U64_MAX_PRIME (18446744073709551557ull) - -/* - * One hardware-accelerated way of mixing hashes can be CRC, but it's only implemented for 32-bit values. 
- * Using a Boost-like mixer works very poorly in such case: - * - * hash_first ^ (hash_second + 0x517cc1b727220a95 + (hash_first << 6) + (hash_first >> 2)); - * - * Let's stick to the Fibonacci hash trick using the golden ratio. - * https://probablydance.com/2018/06/16/fibonacci-hashing-the-optimization-that-the-world-forgot-or-a-better-alternative-to-integer-modulo/ - */ -#define _sz_hash_mix(first, second) ((first * 11400714819323198485ull) ^ (second * 11400714819323198485ull)) -#define _sz_shift_low(x) (x) -#define _sz_shift_high(x) ((x + 77ull) & 0xFFull) -#define _sz_prime_mod(x) (x % SZ_U64_MAX_PRIME) - -SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t start, sz_size_t length) { - - sz_u64_t hash_low = 0; - sz_u64_t hash_high = 0; - sz_u8_t const *text = (sz_u8_t const *)start; - sz_u8_t const *text_end = text + length; - - switch (length) { - case 0: return 0; - - // Texts under 7 bytes long are definitely below the largest prime. - case 1: - hash_low = _sz_shift_low(text[0]); - hash_high = _sz_shift_high(text[0]); - break; - case 2: - hash_low = _sz_shift_low(text[0]) * 31ull + _sz_shift_low(text[1]); - hash_high = _sz_shift_high(text[0]) * 257ull + _sz_shift_high(text[1]); - break; - case 3: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull + // - _sz_shift_low(text[2]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull + // - _sz_shift_high(text[2]); - break; - case 4: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull + // - _sz_shift_low(text[3]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull + // - _sz_shift_high(text[3]); - break; - case 5: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull * 31ull + // - _sz_shift_low(text[3]) * 31ull + // - _sz_shift_low(text[4]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull * 257ull + // - _sz_shift_high(text[3]) * 257ull + // - _sz_shift_high(text[4]); - break; - case 6: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[3]) * 31ull * 31ull + // - _sz_shift_low(text[4]) * 31ull + // - _sz_shift_low(text[5]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[3]) * 257ull * 257ull + // - _sz_shift_high(text[4]) * 257ull + // - _sz_shift_high(text[5]); - break; - case 7: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[3]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[4]) * 31ull * 31ull + // - _sz_shift_low(text[5]) * 31ull + // - _sz_shift_low(text[6]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull 
* 257ull * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[3]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[4]) * 257ull * 257ull + // - _sz_shift_high(text[5]) * 257ull + // - _sz_shift_high(text[6]); - break; - default: - // Unroll the first seven cycles: - hash_low = hash_low * 31ull + _sz_shift_low(text[0]); - hash_high = hash_high * 257ull + _sz_shift_high(text[0]); - hash_low = hash_low * 31ull + _sz_shift_low(text[1]); - hash_high = hash_high * 257ull + _sz_shift_high(text[1]); - hash_low = hash_low * 31ull + _sz_shift_low(text[2]); - hash_high = hash_high * 257ull + _sz_shift_high(text[2]); - hash_low = hash_low * 31ull + _sz_shift_low(text[3]); - hash_high = hash_high * 257ull + _sz_shift_high(text[3]); - hash_low = hash_low * 31ull + _sz_shift_low(text[4]); - hash_high = hash_high * 257ull + _sz_shift_high(text[4]); - hash_low = hash_low * 31ull + _sz_shift_low(text[5]); - hash_high = hash_high * 257ull + _sz_shift_high(text[5]); - hash_low = hash_low * 31ull + _sz_shift_low(text[6]); - hash_high = hash_high * 257ull + _sz_shift_high(text[6]); - text += 7; - - // Iterate throw the rest with the modulus: - for (; text != text_end; ++text) { - hash_low = hash_low * 31ull + _sz_shift_low(text[0]); - hash_high = hash_high * 257ull + _sz_shift_high(text[0]); - // Wrap the hashes around: - hash_low = _sz_prime_mod(hash_low); - hash_high = _sz_prime_mod(hash_high); - } - break; - } - - return _sz_hash_mix(hash_low, hash_high); -} - -SZ_PUBLIC void sz_hashes_serial(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle) { - - if (length < window_length || !window_length) return; - sz_u8_t const *text = (sz_u8_t const *)start; - sz_u8_t const *text_end = text + length; - - // Prepare the `prime ^ window_length` values, that we are going to use for modulo arithmetic. - sz_u64_t prime_power_low = 1, prime_power_high = 1; - for (sz_size_t i = 0; i + 1 < window_length; ++i) - prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME, - prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME; - - // Compute the initial hash value for the first window. - sz_u64_t hash_low = 0, hash_high = 0, hash_mix; - for (sz_u8_t const *first_end = text + window_length; text < first_end; ++text) - hash_low = (hash_low * 31ull + _sz_shift_low(*text)) % SZ_U64_MAX_PRIME, - hash_high = (hash_high * 257ull + _sz_shift_high(*text)) % SZ_U64_MAX_PRIME; - - // In most cases the fingerprint length will be a power of two. - hash_mix = _sz_hash_mix(hash_low, hash_high); - callback((sz_cptr_t)text, window_length, hash_mix, callback_handle); - - // Compute the hash value for every window, exporting into the fingerprint, - // using the expensive modulo operation. - sz_size_t cycles = 1; - sz_size_t const step_mask = step - 1; - for (; text < text_end; ++text, ++cycles) { - // Discard one character: - hash_low -= _sz_shift_low(*(text - window_length)) * prime_power_low; - hash_high -= _sz_shift_high(*(text - window_length)) * prime_power_high; - // And add a new one: - hash_low = 31ull * hash_low + _sz_shift_low(*text); - hash_high = 257ull * hash_high + _sz_shift_high(*text); - // Wrap the hashes around: - hash_low = _sz_prime_mod(hash_low); - hash_high = _sz_prime_mod(hash_high); - // Mix only if we've skipped enough hashes. 
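
As a usage sketch for the rolling hasher, here is a made-up consumer that keeps the smallest mixed hash, the usual building block for min-hash style fingerprints. Note that the `step_mask = step - 1` trick above assumes a power-of-two `step`, and `step == 1` reports every window via the `(cycles & step_mask) == 0` check just below; `text` and `text_length` in the commented call are hypothetical placeholders:

    static void keep_smallest_hash(sz_cptr_t start, sz_size_t window_length, sz_u64_t hash, void *handle) {
        sz_u64_t *smallest = (sz_u64_t *)handle;
        if (hash < *smallest) *smallest = hash;
        sz_unused(start && window_length);
    }
    // At the call site:
    //   sz_u64_t smallest = 0xFFFFFFFFFFFFFFFFull;
    //   sz_hashes_serial(text, text_length, 7 /* window */, 1 /* every position */, keep_smallest_hash, &smallest);
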
- if ((cycles & step_mask) == 0) { - hash_mix = _sz_hash_mix(hash_low, hash_high); - callback((sz_cptr_t)text, window_length, hash_mix, callback_handle); - } - } -} - -#undef _sz_shift_low -#undef _sz_shift_high -#undef _sz_hash_mix -#undef _sz_prime_mod - -/** - * @brief Uses a small lookup-table to convert a lowercase character to uppercase. - */ -SZ_INTERNAL sz_u8_t sz_u8_tolower(sz_u8_t c) { - static sz_u8_t const lowered[256] = { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, // - 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, // - 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, // - 64, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, // - 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 91, 92, 93, 94, 95, // - 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, // - 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, // - 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, // - 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, // - 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, // - 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, // - 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // - 240, 241, 242, 243, 244, 245, 246, 215, 248, 249, 250, 251, 252, 253, 254, 223, // - 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // - 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, // - }; - return lowered[c]; -} - -/** - * @brief Uses a small lookup-table to convert an uppercase character to lowercase. - */ -SZ_INTERNAL sz_u8_t sz_u8_toupper(sz_u8_t c) { - static sz_u8_t const upped[256] = { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, // - 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, // - 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, // - 64, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, // - 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 91, 92, 93, 94, 95, // - 96, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, // - 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 123, 124, 125, 126, 127, // - 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, // - 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, // - 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, // - 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, // - 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // - 240, 241, 242, 243, 244, 245, 246, 215, 248, 249, 250, 251, 252, 253, 254, 223, // - 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // - 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, // - }; - return upped[c]; -} - -/** - * @brief Uses two small lookup tables (768 bytes total) to accelerate division by a small - * unsigned integer. Performs two lookups, one multiplication, two shifts, and two accumulations. - * - * @param divisor Integral value @b larger than one. - * @param number Integral value to divide. 
- */ -SZ_INTERNAL sz_u8_t sz_u8_divide(sz_u8_t number, sz_u8_t divisor) { - sz_assert(divisor > 1); - static sz_u16_t const multipliers[256] = { - 0, 0, 0, 21846, 0, 39322, 21846, 9363, 0, 50973, 39322, 29790, 21846, 15124, 9363, 4370, - 0, 57826, 50973, 44841, 39322, 34329, 29790, 25645, 21846, 18351, 15124, 12137, 9363, 6780, 4370, 2115, - 0, 61565, 57826, 54302, 50973, 47824, 44841, 42011, 39322, 36765, 34329, 32006, 29790, 27671, 25645, 23705, - 21846, 20063, 18351, 16706, 15124, 13602, 12137, 10725, 9363, 8049, 6780, 5554, 4370, 3224, 2115, 1041, - 0, 63520, 61565, 59668, 57826, 56039, 54302, 52614, 50973, 49377, 47824, 46313, 44841, 43407, 42011, 40649, - 39322, 38028, 36765, 35532, 34329, 33154, 32006, 30885, 29790, 28719, 27671, 26647, 25645, 24665, 23705, 22766, - 21846, 20945, 20063, 19198, 18351, 17520, 16706, 15907, 15124, 14356, 13602, 12863, 12137, 11424, 10725, 10038, - 9363, 8700, 8049, 7409, 6780, 6162, 5554, 4957, 4370, 3792, 3224, 2665, 2115, 1573, 1041, 517, - 0, 64520, 63520, 62535, 61565, 60609, 59668, 58740, 57826, 56926, 56039, 55164, 54302, 53452, 52614, 51788, - 50973, 50169, 49377, 48595, 47824, 47063, 46313, 45572, 44841, 44120, 43407, 42705, 42011, 41326, 40649, 39982, - 39322, 38671, 38028, 37392, 36765, 36145, 35532, 34927, 34329, 33738, 33154, 32577, 32006, 31443, 30885, 30334, - 29790, 29251, 28719, 28192, 27671, 27156, 26647, 26143, 25645, 25152, 24665, 24182, 23705, 23233, 22766, 22303, - 21846, 21393, 20945, 20502, 20063, 19628, 19198, 18772, 18351, 17933, 17520, 17111, 16706, 16305, 15907, 15514, - 15124, 14738, 14356, 13977, 13602, 13231, 12863, 12498, 12137, 11779, 11424, 11073, 10725, 10380, 10038, 9699, - 9363, 9030, 8700, 8373, 8049, 7727, 7409, 7093, 6780, 6470, 6162, 5857, 5554, 5254, 4957, 4662, - 4370, 4080, 3792, 3507, 3224, 2943, 2665, 2388, 2115, 1843, 1573, 1306, 1041, 778, 517, 258, - }; - // This table can be avoided using a single addition and counting trailing zeros. 
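
For intuition, a worked example of the reciprocal trick above, taking `multipliers[7] == 9363` together with the shift of 2 that the table below assigns to divisor 7:

    // number == 200, divisor == 7:
    //   q = (9363 * 200) >> 16 = 1872600 >> 16 = 28
    //   t = ((200 - 28) >> 1) + 28 = 86 + 28 = 114
    //   t >> 2 = 28, matching 200 / 7 == 28.
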
- static sz_u8_t const shifts[256] = { - 0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, // - 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, // - 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, // - 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, // - 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // - 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // - 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // - 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // - }; - sz_u32_t multiplier = multipliers[divisor]; - sz_u8_t shift = shifts[divisor]; - - sz_u16_t q = (sz_u16_t)((multiplier * number) >> 16); - sz_u16_t t = ((number - q) >> 1) + q; - return (sz_u8_t)(t >> shift); -} - -SZ_PUBLIC void sz_look_up_transform_serial(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result) { - sz_u8_t const *unsigned_lut = (sz_u8_t const *)lut; - sz_u8_t const *unsigned_text = (sz_u8_t const *)text; - sz_u8_t *unsigned_result = (sz_u8_t *)result; - sz_u8_t const *end = unsigned_text + length; - for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = unsigned_lut[*unsigned_text]; -} - -SZ_PUBLIC void sz_tolower_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) { - sz_u8_t *unsigned_result = (sz_u8_t *)result; - sz_u8_t const *unsigned_text = (sz_u8_t const *)text; - sz_u8_t const *end = unsigned_text + length; - for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = sz_u8_tolower(*unsigned_text); -} - -SZ_PUBLIC void sz_toupper_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) { - sz_u8_t *unsigned_result = (sz_u8_t *)result; - sz_u8_t const *unsigned_text = (sz_u8_t const *)text; - sz_u8_t const *end = unsigned_text + length; - for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = sz_u8_toupper(*unsigned_text); -} - -SZ_PUBLIC void sz_toascii_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) { - sz_u8_t *unsigned_result = (sz_u8_t *)result; - sz_u8_t const *unsigned_text = (sz_u8_t const *)text; - sz_u8_t const *end = unsigned_text + length; - for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = *unsigned_text & 0x7F; -} - -/** - * @brief Check if there is a byte in this buffer, that exceeds 127 and can't be an ASCII character. - * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. - */ -SZ_PUBLIC sz_bool_t sz_isascii_serial(sz_cptr_t text, sz_size_t length) { - - if (!length) return sz_true_k; - sz_u8_t const *h = (sz_u8_t const *)text; - sz_u8_t const *const h_end = h + length; - -#if !SZ_USE_MISALIGNED_LOADS - // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)h & 7ull) && h < h_end; ++h) - if (*h & 0x80ull) return sz_false_k; -#endif - - // Validate eight bytes at once using SWAR. - sz_u64_vec_t text_vec; - for (; h + 8 <= h_end; h += 8) { - text_vec.u64 = *(sz_u64_t const *)h; - if (text_vec.u64 & 0x8080808080808080ull) return sz_false_k; - } - - // Handle the misaligned tail. 
- for (; h < h_end; ++h) - if (*h & 0x80ull) return sz_false_k; - return sz_true_k; -} - -SZ_PUBLIC void sz_generate_serial(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length, - sz_random_generator_t generator, void *generator_user_data) { - - sz_assert(alphabet_size > 0 && alphabet_size <= 256 && "Inadequate alphabet size"); - - if (alphabet_size == 1) sz_fill(result, result_length, *alphabet); - - else { - sz_assert(generator && "Expects a valid random generator"); - sz_u8_t divisor = (sz_u8_t)alphabet_size; - for (sz_cptr_t end = result + result_length; result != end; ++result) { - sz_u8_t random = generator(generator_user_data) & 0xFF; - sz_u8_t quotient = sz_u8_divide(random, divisor); - *result = alphabet[random - quotient * divisor]; - } - } -} - -#pragma endregion - -/* - * Serial implementation of string class operations. - */ -#pragma region Serial Implementation for the String Class - -SZ_PUBLIC sz_bool_t sz_string_is_on_stack(sz_string_t const *string) { - // It doesn't matter if it's on stack or heap, the pointer location is the same. - return (sz_bool_t)((sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0]); -} - -SZ_PUBLIC void sz_string_range(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length) { - sz_size_t is_small = (sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0]; - sz_size_t is_big_mask = is_small - 1ull; - *start = string->external.start; // It doesn't matter if it's on stack or heap, the pointer location is the same. - // If the string is small, use branch-less approach to mask-out the top 7 bytes of the length. - *length = string->external.length & (0x00000000000000FFull | is_big_mask); -} - -SZ_PUBLIC void sz_string_unpack(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length, sz_size_t *space, - sz_bool_t *is_external) { - sz_size_t is_small = (sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0]; - sz_size_t is_big_mask = is_small - 1ull; - *start = string->external.start; // It doesn't matter if it's on stack or heap, the pointer location is the same. - // If the string is small, use branch-less approach to mask-out the top 7 bytes of the length. - *length = string->external.length & (0x00000000000000FFull | is_big_mask); - // In case the string is small, the `is_small - 1ull` will become 0xFFFFFFFFFFFFFFFFull. - *space = sz_u64_blend(SZ_STRING_INTERNAL_SPACE, string->external.space, is_big_mask); - *is_external = (sz_bool_t)!is_small; -} - -SZ_PUBLIC sz_bool_t sz_string_equal(sz_string_t const *a, sz_string_t const *b) { - // Tempting to say that the external.length is bitwise the same even if it includes - // some bytes of the on-stack payload, but we don't at this writing maintain that invariant. - // (An on-stack string includes noise bytes in the high-order bits of external.length. So do this - // the hard/correct way. - -#if SZ_USE_MISALIGNED_LOADS - // Dealing with StringZilla strings, we know that the `start` pointer always points - // to a word at least 8 bytes long. Therefore, we can compare the first 8 bytes at once. - -#endif - // Alternatively, fall back to byte-by-byte comparison. 
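
As an aside, a hypothetical helper showing the branchless small-string layout from `sz_string_range` above in action; `illustrate_small_string_layout` is a made-up name, and the sketch assumes a 64-bit build where 6 bytes fit into the internal buffer:

    static void illustrate_small_string_layout(void) {
        sz_memory_allocator_t alloc;
        sz_memory_allocator_init_default(&alloc);
        sz_string_t string;
        sz_string_init_length(&string, 5, &alloc); // 5 + 1 bytes fit into SZ_STRING_INTERNAL_SPACE
        sz_ptr_t begin;
        sz_size_t length;
        sz_string_range(&string, &begin, &length); // is_small == 1 -> is_big_mask == 0 -> mask == 0xFF -> length == 5
        sz_string_free(&string, &alloc);           // a no-op deallocation for on-stack strings
    }
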
- sz_ptr_t a_start, b_start; - sz_size_t a_length, b_length; - sz_string_range(a, &a_start, &a_length); - sz_string_range(b, &b_start, &b_length); - return (sz_bool_t)(a_length == b_length && sz_equal(a_start, b_start, b_length)); -} - -SZ_PUBLIC sz_ordering_t sz_string_order(sz_string_t const *a, sz_string_t const *b) { -#if SZ_USE_MISALIGNED_LOADS - // Dealing with StringZilla strings, we know that the `start` pointer always points - // to a word at least 8 bytes long. Therefore, we can compare the first 8 bytes at once. - -#endif - // Alternatively, fall back to byte-by-byte comparison. - sz_ptr_t a_start, b_start; - sz_size_t a_length, b_length; - sz_string_range(a, &a_start, &a_length); - sz_string_range(b, &b_start, &b_length); - return sz_order(a_start, a_length, b_start, b_length); -} - -SZ_PUBLIC void sz_string_init(sz_string_t *string) { - sz_assert(string && "String can't be SZ_NULL."); - - // Only 8 + 1 + 1 need to be initialized. - string->internal.start = &string->internal.chars[0]; - // But for safety let's initialize the entire structure to zeros. - // string->internal.chars[0] = 0; - // string->internal.length = 0; - string->words[1] = 0; - string->words[2] = 0; - string->words[3] = 0; -} - -SZ_PUBLIC sz_ptr_t sz_string_init_length(sz_string_t *string, sz_size_t length, sz_memory_allocator_t *allocator) { - sz_size_t space_needed = length + 1; // space for trailing \0 - sz_assert(string && allocator && "String and allocator can't be SZ_NULL."); - // Initialize the string to zeros for safety. - string->words[1] = 0; - string->words[2] = 0; - string->words[3] = 0; - // If we are lucky, no memory allocations will be needed. - if (space_needed <= SZ_STRING_INTERNAL_SPACE) { - string->internal.start = &string->internal.chars[0]; - string->internal.length = (sz_u8_t)length; - } - else { - // If we are not lucky, we need to allocate memory. - string->external.start = (sz_ptr_t)allocator->allocate(space_needed, allocator->handle); - if (!string->external.start) return SZ_NULL_CHAR; - string->external.length = length; - string->external.space = space_needed; - } - sz_assert(&string->internal.start == &string->external.start && "Alignment confusion"); - string->external.start[length] = 0; - return string->external.start; -} - -SZ_PUBLIC sz_ptr_t sz_string_reserve(sz_string_t *string, sz_size_t new_capacity, sz_memory_allocator_t *allocator) { - - sz_assert(string && allocator && "Strings and allocators can't be SZ_NULL."); - - sz_size_t new_space = new_capacity + 1; - if (new_space <= SZ_STRING_INTERNAL_SPACE) return string->external.start; - - sz_ptr_t string_start; - sz_size_t string_length; - sz_size_t string_space; - sz_bool_t string_is_external; - sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - sz_assert(new_space > string_space && "New space must be larger than current."); - - sz_ptr_t new_start = (sz_ptr_t)allocator->allocate(new_space, allocator->handle); - if (!new_start) return SZ_NULL_CHAR; - - sz_copy(new_start, string_start, string_length); - string->external.start = new_start; - string->external.space = new_space; - string->external.padding = 0; - string->external.length = string_length; - - // Deallocate the old string. 
- if (string_is_external) allocator->free(string_start, string_space, allocator->handle); - return string->external.start; -} - -SZ_PUBLIC sz_ptr_t sz_string_shrink_to_fit(sz_string_t *string, sz_memory_allocator_t *allocator) { - - sz_assert(string && allocator && "Strings and allocators can't be SZ_NULL."); - - sz_ptr_t string_start; - sz_size_t string_length; - sz_size_t string_space; - sz_bool_t string_is_external; - sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - - // We may already be space-optimal, and in that case we don't need to do anything. - sz_size_t new_space = string_length + 1; - if (string_space == new_space || !string_is_external) return string->external.start; - - sz_ptr_t new_start = (sz_ptr_t)allocator->allocate(new_space, allocator->handle); - if (!new_start) return SZ_NULL_CHAR; - - sz_copy(new_start, string_start, string_length); - string->external.start = new_start; - string->external.space = new_space; - string->external.padding = 0; - string->external.length = string_length; - - // Deallocate the old string. - if (string_is_external) allocator->free(string_start, string_space, allocator->handle); - return string->external.start; -} - -SZ_PUBLIC sz_ptr_t sz_string_expand(sz_string_t *string, sz_size_t offset, sz_size_t added_length, - sz_memory_allocator_t *allocator) { - - sz_assert(string && allocator && "String and allocator can't be SZ_NULL."); - - sz_ptr_t string_start; - sz_size_t string_length; - sz_size_t string_space; - sz_bool_t string_is_external; - sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - - // The user intended to extend the string. - offset = sz_min_of_two(offset, string_length); - - // If we are lucky, no memory allocations will be needed. - if (string_length + added_length < string_space) { - sz_move(string_start + offset + added_length, string_start + offset, string_length - offset); - string_start[string_length + added_length] = 0; - // Even if the string is on the stack, the `+=` won't affect the tail of the string. - string->external.length += added_length; - } - // If we are not lucky, we need to allocate more memory. - else { - sz_size_t next_planned_size = sz_max_of_two(SZ_CACHE_LINE_WIDTH, string_space * 2ull); - sz_size_t min_needed_space = sz_size_bit_ceil(offset + string_length + added_length + 1); - sz_size_t new_space = sz_max_of_two(min_needed_space, next_planned_size); - string_start = sz_string_reserve(string, new_space - 1, allocator); - if (!string_start) return SZ_NULL_CHAR; - - // Copy into the new buffer. - sz_move(string_start + offset + added_length, string_start + offset, string_length - offset); - string_start[string_length + added_length] = 0; - string->external.length = string_length + added_length; - } - - return string_start; -} - -SZ_PUBLIC sz_size_t sz_string_erase(sz_string_t *string, sz_size_t offset, sz_size_t length) { - - sz_assert(string && "String can't be SZ_NULL."); - - sz_ptr_t string_start; - sz_size_t string_length; - sz_size_t string_space; - sz_bool_t string_is_external; - sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - - // Normalize the offset, it can't be larger than the length. - offset = sz_min_of_two(offset, string_length); - - // We shouldn't normalize the length, to avoid overflowing on `offset + length >= string_length`, - // if receiving `length == SZ_SIZE_MAX`. 
After the following expression, `length` will contain
- // exactly the delta between original and final length of this `string`.
- length = sz_min_of_two(length, string_length - offset);
-
- // There are 2 common cases that wouldn't even require a `memmove`:
- // 1. Erasing the entire contents of the string.
- // In that case the `length` argument will be equal to or greater than the `length` member.
- // 2. Removing the tail of the string with something like `string.pop_back()` in C++.
- //
- // In both of those, regardless of the location of the string - stack or heap,
- // the erasing is as easy as setting the length to the offset.
- // In every other case, we must `memmove` the tail of the string to the left.
- if (offset + length < string_length)
- sz_move(string_start + offset, string_start + offset + length, string_length - offset - length);
-
- // The `string->external.length = offset` assignment would discard the last characters
- // of the on-the-stack string, but in-place subtraction works.
- string->external.length -= length;
- string_start[string_length - length] = 0;
- return length;
-}
-
-SZ_PUBLIC void sz_string_free(sz_string_t *string, sz_memory_allocator_t *allocator) {
- if (!sz_string_is_on_stack(string))
- allocator->free(string->external.start, string->external.space, allocator->handle);
- sz_string_init(string);
-}
-
-// When overriding libc, disable optimisations for this function because MSVC will optimize the loops into a memset,
-// which then causes a stack overflow due to infinite recursion (memset -> sz_fill_serial -> memset).
-#if defined(_MSC_VER) && defined(SZ_OVERRIDE_LIBC) && SZ_OVERRIDE_LIBC
-#pragma optimize("", off)
-#endif
-SZ_PUBLIC void sz_fill_serial(sz_ptr_t target, sz_size_t length, sz_u8_t value) {
- sz_ptr_t end = target + length;
- // Dealing with short strings, a single sequential pass would be faster.
- // If the size is larger than 2 words, then at least 1 of them will be aligned.
- // But just one aligned word may not be worth SWAR.
- if (length < SZ_SWAR_THRESHOLD)
- while (target != end) *(target++) = value;
-
- // In case of long strings, skip unaligned bytes, and then fill the rest in 64-bit chunks.
- else {
- sz_u64_t value64 = (sz_u64_t)value * 0x0101010101010101ull;
- while ((sz_size_t)target & 7ull) *(target++) = value;
- while (target + 8 <= end) *(sz_u64_t *)target = value64, target += 8;
- while (target != end) *(target++) = value;
- }
-}
-#if defined(_MSC_VER) && defined(SZ_OVERRIDE_LIBC) && SZ_OVERRIDE_LIBC
-#pragma optimize("", on)
-#endif
-
-SZ_PUBLIC void sz_copy_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length) {
-#if SZ_USE_MISALIGNED_LOADS
- while (length >= 8) *(sz_u64_t *)target = *(sz_u64_t const *)source, target += 8, source += 8, length -= 8;
-#endif
- while (length--) *(target++) = *(source++);
-}
-
-SZ_PUBLIC void sz_move_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length) {
- // Implementing `memmove` is trickier than `memcpy`, as the ranges may overlap.
- // Existing implementations often have two passes, in normal and reversed order,
- // depending on the relation of `target` and `source` addresses.
- // https://student.cs.uwaterloo.ca/~cs350/common/os161-src-html/doxygen/html/memmove_8c_source.html
- // https://marmota.medium.com/c-language-making-memmove-def8792bb8d5
- //
- // We can use the `memcpy`-like left-to-right pass if we know that the `target` is before `source`.
- // Or if we know that they don't intersect!
In that case the traversal order is irrelevant,
- // but older CPUs may predict and fetch forward-passes better.
- if (target < source || target >= source + length) {
-#if SZ_USE_MISALIGNED_LOADS
- while (length >= 8) *(sz_u64_t *)target = *(sz_u64_t const *)(source), target += 8, source += 8, length -= 8;
-#endif
- while (length--) *(target++) = *(source++);
- }
- else {
- // Jump to the end and walk backwards.
- target += length, source += length;
-#if SZ_USE_MISALIGNED_LOADS
- while (length >= 8) *(sz_u64_t *)(target -= 8) = *(sz_u64_t const *)(source -= 8), length -= 8;
-#endif
- while (length--) *(--target) = *(--source);
- }
-}
-
-#pragma endregion
-
-/*
- * @brief Serial implementation for string sequence processing.
- */
-#pragma region Serial Implementation for Sequences
-
-SZ_PUBLIC sz_size_t sz_partition(sz_sequence_t *sequence, sz_sequence_predicate_t predicate) {
-
- sz_size_t matches = 0;
- while (matches != sequence->count && predicate(sequence, sequence->order[matches])) ++matches;
-
- for (sz_size_t i = matches + 1; i < sequence->count; ++i)
- if (predicate(sequence, sequence->order[i]))
- sz_u64_swap(sequence->order + i, sequence->order + matches), ++matches;
-
- return matches;
-}
-
-SZ_PUBLIC void sz_merge(sz_sequence_t *sequence, sz_size_t partition, sz_sequence_comparator_t less) {
-
- sz_size_t start_b = partition + 1;
-
- // If the direct merge is already sorted
- if (!less(sequence, sequence->order[start_b], sequence->order[partition])) return;
-
- sz_size_t start_a = 0;
- while (start_a <= partition && start_b <= sequence->count) {
-
- // If element 1 is in the right place
- if (!less(sequence, sequence->order[start_b], sequence->order[start_a])) { start_a++; }
- else {
- sz_size_t value = sequence->order[start_b];
- sz_size_t index = start_b;
-
- // Shift all the elements between element 1
- // and element 2 right by one.
- while (index != start_a) { sequence->order[index] = sequence->order[index - 1], index--; } - sequence->order[start_a] = value; - - // Update all the pointers - start_a++; - partition++; - start_b++; - } - } -} - -SZ_PUBLIC void sz_sort_insertion(sz_sequence_t *sequence, sz_sequence_comparator_t less) { - sz_u64_t *keys = sequence->order; - sz_size_t keys_count = sequence->count; - for (sz_size_t i = 1; i < keys_count; i++) { - sz_u64_t i_key = keys[i]; - sz_size_t j = i; - for (; j > 0 && less(sequence, i_key, keys[j - 1]); --j) keys[j] = keys[j - 1]; - keys[j] = i_key; - } -} - -SZ_INTERNAL void _sz_sift_down(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_u64_t *order, sz_size_t start, - sz_size_t end) { - sz_size_t root = start; - while (2 * root + 1 <= end) { - sz_size_t child = 2 * root + 1; - if (child + 1 <= end && less(sequence, order[child], order[child + 1])) { child++; } - if (!less(sequence, order[root], order[child])) { return; } - sz_u64_swap(order + root, order + child); - root = child; - } -} - -SZ_INTERNAL void _sz_heapify(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_u64_t *order, sz_size_t count) { - sz_size_t start = (count - 2) / 2; - while (1) { - _sz_sift_down(sequence, less, order, start, count - 1); - if (start == 0) return; - start--; - } -} - -SZ_INTERNAL void _sz_heapsort(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_size_t first, sz_size_t last) { - sz_u64_t *order = sequence->order; - sz_size_t count = last - first; - _sz_heapify(sequence, less, order + first, count); - sz_size_t end = count - 1; - while (end > 0) { - sz_u64_swap(order + first, order + first + end); - end--; - _sz_sift_down(sequence, less, order + first, 0, end); - } -} - -SZ_PUBLIC void sz_sort_introsort_recursion(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_size_t first, - sz_size_t last, sz_size_t depth) { - - sz_size_t length = last - first; - switch (length) { - case 0: - case 1: return; - case 2: - if (less(sequence, sequence->order[first + 1], sequence->order[first])) - sz_u64_swap(&sequence->order[first], &sequence->order[first + 1]); - return; - case 3: { - sz_u64_t a = sequence->order[first]; - sz_u64_t b = sequence->order[first + 1]; - sz_u64_t c = sequence->order[first + 2]; - if (less(sequence, b, a)) sz_u64_swap(&a, &b); - if (less(sequence, c, b)) sz_u64_swap(&c, &b); - if (less(sequence, b, a)) sz_u64_swap(&a, &b); - sequence->order[first] = a; - sequence->order[first + 1] = b; - sequence->order[first + 2] = c; - return; - } - } - // Until a certain length, the quadratic-complexity insertion-sort is fine - if (length <= 16) { - sz_sequence_t sub_seq = *sequence; - sub_seq.order += first; - sub_seq.count = length; - sz_sort_insertion(&sub_seq, less); - return; - } - - // Fallback to N-logN-complexity heap-sort - if (depth == 0) { - _sz_heapsort(sequence, less, first, last); - return; - } - - --depth; - - // Median-of-three logic to choose pivot - sz_size_t median = first + length / 2; - if (less(sequence, sequence->order[median], sequence->order[first])) - sz_u64_swap(&sequence->order[first], &sequence->order[median]); - if (less(sequence, sequence->order[last - 1], sequence->order[first])) - sz_u64_swap(&sequence->order[first], &sequence->order[last - 1]); - if (less(sequence, sequence->order[median], sequence->order[last - 1])) - sz_u64_swap(&sequence->order[median], &sequence->order[last - 1]); - - // Partition using the median-of-three as the pivot - sz_u64_t pivot = sequence->order[median]; - sz_size_t left = first; - 
sz_size_t right = last - 1;
- while (1) {
- while (less(sequence, sequence->order[left], pivot)) left++;
- while (less(sequence, pivot, sequence->order[right])) right--;
- if (left >= right) break;
- sz_u64_swap(&sequence->order[left], &sequence->order[right]);
- left++;
- right--;
- }
-
- // Recursively sort the partitions
- sz_sort_introsort_recursion(sequence, less, first, left, depth);
- sz_sort_introsort_recursion(sequence, less, right + 1, last, depth);
-}
-
-SZ_PUBLIC void sz_sort_introsort(sz_sequence_t *sequence, sz_sequence_comparator_t less) {
- if (sequence->count == 0) return;
- sz_size_t size_is_not_power_of_two = (sequence->count & (sequence->count - 1)) != 0;
- sz_size_t depth_limit = sz_size_log2i_nonzero(sequence->count) + size_is_not_power_of_two;
- sz_sort_introsort_recursion(sequence, less, 0, sequence->count, depth_limit);
-}
-
-SZ_PUBLIC void sz_sort_recursion( //
- sz_sequence_t *sequence, sz_size_t bit_idx, sz_size_t bit_max, sz_sequence_comparator_t comparator,
- sz_size_t partial_order_length) {
-
- if (!sequence->count) return;
-
- // Array of size one doesn't need sorting - only needs the prefix to be discarded.
- if (sequence->count == 1) {
- sz_u32_t *order_half_words = (sz_u32_t *)sequence->order;
- order_half_words[1] = 0;
- return;
- }
-
- // Partition a range of integers according to a specific bit value
- sz_size_t split = 0;
- sz_u64_t mask = (1ull << 63) >> bit_idx;
-
- // The clean approach would be to perform a single pass over the sequence.
- //
- // while (split != sequence->count && !(sequence->order[split] & mask)) ++split;
- // for (sz_size_t i = split + 1; i < sequence->count; ++i)
- // if (!(sequence->order[i] & mask)) sz_u64_swap(sequence->order + i, sequence->order + split), ++split;
- //
- // This, however, doesn't take into account the high relative cost of writes and swaps.
- // To circumvent that, we can first count the total number of entries to be mapped into either part.
- // And then walk through both parts, swapping the entries that are in the wrong part.
- // This would often lead to a ~15% performance gain.
- sz_size_t count_with_bit_set = 0;
- for (sz_size_t i = 0; i != sequence->count; ++i) count_with_bit_set += (sequence->order[i] & mask) != 0;
- split = sequence->count - count_with_bit_set;
-
- // It's possible that the sequence is already partitioned.
- if (split != 0 && split != sequence->count) {
- // Use two pointers to efficiently reposition elements.
- // One pointer walks left-to-right from the start, and the other walks right-to-left from the end.
- sz_size_t left = 0;
- sz_size_t right = sequence->count - 1;
- while (1) {
- // Find the next element with the bit set on the left side.
- while (left < split && !(sequence->order[left] & mask)) ++left;
- // Find the next element without the bit set on the right side.
- while (right >= split && (sequence->order[right] & mask)) --right;
- // Swap the mispositioned elements.
- if (left < split && right >= split) {
- sz_u64_swap(sequence->order + left, sequence->order + right);
- ++left;
- --right;
- }
- else { break; }
- }
- }
-
- // Go down recursively.
- if (bit_idx < bit_max) {
- sz_sequence_t a = *sequence;
- a.count = split;
- sz_sort_recursion(&a, bit_idx + 1, bit_max, comparator, partial_order_length);
-
- sz_sequence_t b = *sequence;
- b.order += split;
- b.count -= split;
- sz_sort_recursion(&b, bit_idx + 1, bit_max, comparator, partial_order_length);
- }
- // Reached the end of recursion.
- else {
- // Discard the prefixes.
- sz_u32_t *order_half_words = (sz_u32_t *)sequence->order; - for (sz_size_t i = 0; i != sequence->count; ++i) { order_half_words[i * 2 + 1] = 0; } - - sz_sequence_t a = *sequence; - a.count = split; - sz_sort_introsort(&a, comparator); - - sz_sequence_t b = *sequence; - b.order += split; - b.count -= split; - sz_sort_introsort(&b, comparator); - } -} - -SZ_INTERNAL sz_bool_t _sz_sort_is_less(sz_sequence_t *sequence, sz_size_t i_key, sz_size_t j_key) { - sz_cptr_t i_str = sequence->get_start(sequence, i_key); - sz_cptr_t j_str = sequence->get_start(sequence, j_key); - sz_size_t i_len = sequence->get_length(sequence, i_key); - sz_size_t j_len = sequence->get_length(sequence, j_key); - return (sz_bool_t)(sz_order_serial(i_str, i_len, j_str, j_len) == sz_less_k); -} - -SZ_PUBLIC void sz_sort_partial(sz_sequence_t *sequence, sz_size_t partial_order_length) { - -#if SZ_DETECT_BIG_ENDIAN - // TODO: Implement partial sort for big-endian systems. For now this sorts the whole thing. - sz_unused(partial_order_length); - sz_sort_introsort(sequence, (sz_sequence_comparator_t)_sz_sort_is_less); -#else - - // Export up to 4 bytes into the `sequence` bits themselves - for (sz_size_t i = 0; i != sequence->count; ++i) { - sz_cptr_t begin = sequence->get_start(sequence, sequence->order[i]); - sz_size_t length = sequence->get_length(sequence, sequence->order[i]); - length = length > 4u ? 4u : length; - sz_ptr_t prefix = (sz_ptr_t)&sequence->order[i]; - for (sz_size_t j = 0; j != length; ++j) prefix[7 - j] = begin[j]; - } - - // Perform optionally-parallel radix sort on them - sz_sort_recursion(sequence, 0, 32, (sz_sequence_comparator_t)_sz_sort_is_less, partial_order_length); -#endif -} - -SZ_PUBLIC void sz_sort(sz_sequence_t *sequence) { -#if SZ_DETECT_BIG_ENDIAN - sz_sort_introsort(sequence, (sz_sequence_comparator_t)_sz_sort_is_less); -#else - sz_sort_partial(sequence, sequence->count); -#endif -} - -#pragma endregion - -/* - * @brief AVX2 implementation of the string search algorithms. - * Very minimalistic, but still faster than the serial implementation. - */ -#pragma region AVX2 Implementation - -#if SZ_USE_X86_AVX2 -#pragma GCC push_options -#pragma GCC target("avx2") -#pragma clang attribute push(__attribute__((target("avx2"))), apply_to = function) -#include - -/** - * @brief Helper structure to simplify work with 256-bit registers. - */ -typedef union sz_u256_vec_t { - __m256i ymm; - __m128i xmms[2]; - sz_u64_t u64s[4]; - sz_u32_t u32s[8]; - sz_u16_t u16s[16]; - sz_u8_t u8s[32]; -} sz_u256_vec_t; - -SZ_PUBLIC sz_ordering_t sz_order_avx2(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { - //! Before optimizing this, read the "Operations Not Worth Optimizing" in Contributions Guide: - //! https://github.com/ashvardanian/StringZilla/blob/main/CONTRIBUTING.md#general-performance-observations - return sz_order_serial(a, a_length, b, b_length); -} - -SZ_PUBLIC sz_bool_t sz_equal_avx2(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - sz_u256_vec_t a_vec, b_vec; - - while (length >= 32) { - a_vec.ymm = _mm256_lddqu_si256((__m256i const *)a); - b_vec.ymm = _mm256_lddqu_si256((__m256i const *)b); - // One approach can be to use "movemasks", but we could also use a bitwise matching like `_mm256_testnzc_si256`. 
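        // A sketch of that alternative - since this function only needs a boolean result, a
        // `vptest`-style check could replace the `movemask` and negation (illustrative only,
        // not the path taken below):
        //
        //     __m256i equality = _mm256_cmpeq_epi8(a_vec.ymm, b_vec.ymm);
        //     if (!_mm256_testc_si256(equality, _mm256_set1_epi8(-1))) return sz_false_k;
        //
        // `_mm256_testc_si256(x, ones)` is non-zero only when `x` has every bit set, i.e. when
        // all 32 byte comparisons matched.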
- int difference_mask = ~_mm256_movemask_epi8(_mm256_cmpeq_epi8(a_vec.ymm, b_vec.ymm)); - if (difference_mask == 0) { a += 32, b += 32, length -= 32; } - else { return sz_false_k; } - } - - if (length) return sz_equal_serial(a, b, length); - return sz_true_k; -} - -SZ_PUBLIC void sz_fill_avx2(sz_ptr_t target, sz_size_t length, sz_u8_t value) { - char value_char = *(char *)&value; - __m256i value_vec = _mm256_set1_epi8(value_char); - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "stores". - // - // for (; length >= 32; target += 32, length -= 32) _mm256_storeu_si256(target, value_vec); - // sz_fill_serial(target, length, value); - // - // When the buffer is small, there isn't much to innovate. - if (length <= 32) sz_fill_serial(target, length, value); - // When the buffer is aligned, we can avoid any split-stores. - else { - sz_size_t head_length = (32 - ((sz_size_t)target % 32)) % 32; // 31 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 32; // 31 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 32. - sz_u16_t value16 = (sz_u16_t)value * 0x0101u; - sz_u32_t value32 = (sz_u32_t)value16 * 0x00010001u; - sz_u64_t value64 = (sz_u64_t)value32 * 0x0000000100000001ull; - - // Fill the head of the buffer. This part is much cleaner with AVX-512. - if (head_length & 1) *(sz_u8_t *)target = value, target++, head_length--; - if (head_length & 2) *(sz_u16_t *)target = value16, target += 2, head_length -= 2; - if (head_length & 4) *(sz_u32_t *)target = value32, target += 4, head_length -= 4; - if (head_length & 8) *(sz_u64_t *)target = value64, target += 8, head_length -= 8; - if (head_length & 16) - _mm_store_si128((__m128i *)target, _mm_set1_epi8(value_char)), target += 16, head_length -= 16; - sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); - - // Fill the aligned body of the buffer. - for (; body_length >= 32; target += 32, body_length -= 32) _mm256_store_si256((__m256i *)target, value_vec); - - // Fill the tail of the buffer. This part is much cleaner with AVX-512. - sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); - if (tail_length & 16) - _mm_store_si128((__m128i *)target, _mm_set1_epi8(value_char)), target += 16, tail_length -= 16; - if (tail_length & 8) *(sz_u64_t *)target = value64, target += 8, tail_length -= 8; - if (tail_length & 4) *(sz_u32_t *)target = value32, target += 4, tail_length -= 4; - if (tail_length & 2) *(sz_u16_t *)target = value16, target += 2, tail_length -= 2; - if (tail_length & 1) *(sz_u8_t *)target = value, target++, tail_length--; - } -} - -SZ_PUBLIC void sz_copy_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "stores" and "loads". - // - // for (; length >= 32; target += 32, source += 32, length -= 32) - // _mm256_storeu_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source)); - // sz_copy_serial(target, source, length); - // - // A typical AWS Skylake instance can have 32 KB x 2 blocks of L1 data cache per core, - // 1 MB x 2 blocks of L2 cache per core, and one shared L3 cache buffer. - // For now, let's avoid the cases beyond the L2 size. 
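    // The 1 MB threshold below is roughly the per-core L2 budget mentioned above: buffers larger
    // than that are unlikely to be cache-resident anyway, so they are routed to the bidirectional
    // traversal further down instead of the simple aligned loop.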
-    int is_huge = length > 1ull * 1024ull * 1024ull;
-    if (length <= 32) { sz_copy_serial(target, source, length); }
-    // When dealing with larger arrays, the optimization is not as simple as with the `sz_fill_avx2` function,
-    // as both buffers may be unaligned. If we are lucky and the requested operation is some huge page transfer,
-    // we can use aligned loads and stores, and the performance will be great.
-    else if ((sz_size_t)target % 32 == 0 && (sz_size_t)source % 32 == 0 && !is_huge) {
-        for (; length >= 32; target += 32, source += 32, length -= 32)
-            _mm256_store_si256((__m256i *)target, _mm256_load_si256((__m256i const *)source));
-        if (length) sz_copy_serial(target, source, length);
-    }
-    // The trickiest case is when both `source` and `target` are not aligned.
-    // In this and simpler cases we can copy enough bytes into `target` to reach its cacheline boundary,
-    // and then combine unaligned loads with aligned stores.
-    else {
-        sz_size_t head_length = (32 - ((sz_size_t)target % 32)) % 32; // 31 or less.
-        sz_size_t tail_length = (sz_size_t)(target + length) % 32;    // 31 or less.
-        sz_size_t body_length = length - head_length - tail_length;   // Multiple of 32.
-
-        // Fill the head of the buffer. This part is much cleaner with AVX-512.
-        if (head_length & 1) *(sz_u8_t *)target = *(sz_u8_t *)source, target++, source++, head_length--;
-        if (head_length & 2) *(sz_u16_t *)target = *(sz_u16_t *)source, target += 2, source += 2, head_length -= 2;
-        if (head_length & 4) *(sz_u32_t *)target = *(sz_u32_t *)source, target += 4, source += 4, head_length -= 4;
-        if (head_length & 8) *(sz_u64_t *)target = *(sz_u64_t *)source, target += 8, source += 8, head_length -= 8;
-        if (head_length & 16)
-            _mm_store_si128((__m128i *)target, _mm_lddqu_si128((__m128i const *)source)), target += 16, source += 16,
-            head_length -= 16;
-        sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size.");
-
-        // Fill the aligned body of the buffer.
-        if (!is_huge) {
-            for (; body_length >= 32; target += 32, source += 32, body_length -= 32)
-                _mm256_store_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source));
-        }
-        // When the buffer is huge, we can traverse it in 2 directions.
-        else {
-            for (; body_length >= 64; target += 32, source += 32, body_length -= 64) {
-                _mm256_store_si256((__m256i *)(target), _mm256_lddqu_si256((__m256i const *)(source)));
-                _mm256_store_si256((__m256i *)(target + body_length - 32),
-                                   _mm256_lddqu_si256((__m256i const *)(source + body_length - 32)));
-            }
-            if (body_length) _mm256_store_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source));
-        }
-
-        // Fill the tail of the buffer. This part is much cleaner with AVX-512.
- sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); - if (tail_length & 16) - _mm_store_si128((__m128i *)target, _mm_lddqu_si128((__m128i const *)source)), target += 16, source += 16, - tail_length -= 16; - if (tail_length & 8) *(sz_u64_t *)target = *(sz_u64_t *)source, target += 8, source += 8, tail_length -= 8; - if (tail_length & 4) *(sz_u32_t *)target = *(sz_u32_t *)source, target += 4, source += 4, tail_length -= 4; - if (tail_length & 2) *(sz_u16_t *)target = *(sz_u16_t *)source, target += 2, source += 2, tail_length -= 2; - if (tail_length & 1) *(sz_u8_t *)target = *(sz_u8_t *)source, target++, source++, tail_length--; - } -} - -SZ_PUBLIC void sz_move_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - if (target < source || target >= source + length) { - for (; length >= 32; target += 32, source += 32, length -= 32) - _mm256_storeu_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source)); - while (length--) *(target++) = *(source++); - } - else { - // Jump to the end and walk backwards. - for (target += length, source += length; length >= 32; length -= 32) - _mm256_storeu_si256((__m256i *)(target -= 32), _mm256_lddqu_si256((__m256i const *)(source -= 32))); - while (length--) *(--target) = *(--source); - } -} - -SZ_PUBLIC sz_u64_t sz_checksum_avx2(sz_cptr_t text, sz_size_t length) { - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "loads". - // - // A typical AWS Skylake instance can have 32 KB x 2 blocks of L1 data cache per core, - // 1 MB x 2 blocks of L2 cache per core, and one shared L3 cache buffer. - // For now, let's avoid the cases beyond the L2 size. - int is_huge = length > 1ull * 1024ull * 1024ull; - - // When the buffer is small, there isn't much to innovate. - if (length <= 32) { return sz_checksum_serial(text, length); } - else if (!is_huge) { - sz_u256_vec_t text_vec, sums_vec; - sums_vec.ymm = _mm256_setzero_si256(); - for (; length >= 32; text += 32, length -= 32) { - text_vec.ymm = _mm256_lddqu_si256((__m256i const *)text); - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); - } - // Accumulating 256 bits is harders, as we need to extract the 128-bit sums first. - __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymm); - __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymm, 1); - __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm); - sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm); - sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1); - sz_u64_t result = low + high; - if (length) result += sz_checksum_serial(text, length); - return result; - } - // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use. - // Most notably, we can avoid populating the cache with the entire buffer, and instead traverse it in 2 directions. - else { - sz_size_t head_length = (32 - ((sz_size_t)text % 32)) % 32; // 31 or less. - sz_size_t tail_length = (sz_size_t)(text + length) % 32; // 31 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 32. - sz_u64_t result = 0; - - // Handle the head - while (head_length--) result += *text++; - - sz_u256_vec_t text_vec, sums_vec; - sums_vec.ymm = _mm256_setzero_si256(); - // Fill the aligned body of the buffer. 
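        // `_mm256_sad_epu8(text_vec.ymm, zero)` folds every group of 8 unsigned bytes into one
        // 64-bit lane, so a single instruction reduces 32 bytes to four partial sums that cannot
        // overflow for any realistic input length. A scalar model of the first lane after one step:
        //
        //     sz_u64_t lane0 = 0;
        //     for (int i = 0; i != 8; ++i) lane0 += text_vec.u8s[i];
        //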
- if (!is_huge) { - for (; body_length >= 32; text += 32, body_length -= 32) { - text_vec.ymm = _mm256_stream_load_si256((__m256i const *)text); - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); - } - } - // When the biffer is huge, we can traverse it in 2 directions. - else { - sz_u256_vec_t text_reversed_vec, sums_reversed_vec; - sums_reversed_vec.ymm = _mm256_setzero_si256(); - for (; body_length >= 64; text += 64, body_length -= 64) { - text_vec.ymm = _mm256_stream_load_si256((__m256i *)(text)); - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); - text_reversed_vec.ymm = _mm256_stream_load_si256((__m256i *)(text + body_length - 64)); - sums_reversed_vec.ymm = _mm256_add_epi64( - sums_reversed_vec.ymm, _mm256_sad_epu8(text_reversed_vec.ymm, _mm256_setzero_si256())); - } - if (body_length >= 32) { - text_vec.ymm = _mm256_stream_load_si256((__m256i *)(text)); - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); - } - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, sums_reversed_vec.ymm); - } - - // Handle the tail - while (tail_length--) result += *text++; - - // Accumulating 256 bits is harders, as we need to extract the 128-bit sums first. - __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymm); - __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymm, 1); - __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm); - sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm); - sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1); - result += low + high; - return result; - } -} - -SZ_PUBLIC void sz_look_up_transform_avx2(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { - - // If the input is tiny (especially smaller than the look-up table itself), we may end up paying - // more for organizing the SIMD registers and changing the CPU state, than for the actual computation. - // But if at least 3 cache lines are touched, the AVX-2 implementation should be faster. - if (length <= 128) { - sz_look_up_transform_serial(source, length, lut, target); - return; - } - - // We need to pull the lookup table into 8x YMM registers. - // The biggest issue is reorganizing the data in the lookup table, as AVX2 doesn't have 256-bit shuffle, - // it only has 128-bit "within-lane" shuffle. Still, it's wiser to use full YMM registers, instead of XMM, - // so that we can at least compensate high latency with twice larger window and one more level of lookup. 
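    // Conceptually, the 256-entry table is treated as 16 rows of 16 entries, so a scalar lookup
    // decomposes into a row pick by the high nibble and a column pick by the low nibble:
    //
    //     sz_u8_t byte = *(sz_u8_t const *)source;
    //     sz_u8_t row = byte >> 4, column = byte & 0x0F;
    //     *target = lut[row * 16 + column]; // same as lut[byte]
    //
    // The byte-shuffles below implement the column pick within each 16-byte row, and the blend
    // tree that follows implements the row pick.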
- sz_u256_vec_t lut_0_to_15_vec, lut_16_to_31_vec, lut_32_to_47_vec, lut_48_to_63_vec, // - lut_64_to_79_vec, lut_80_to_95_vec, lut_96_to_111_vec, lut_112_to_127_vec, // - lut_128_to_143_vec, lut_144_to_159_vec, lut_160_to_175_vec, lut_176_to_191_vec, // - lut_192_to_207_vec, lut_208_to_223_vec, lut_224_to_239_vec, lut_240_to_255_vec; - - lut_0_to_15_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut))); - lut_16_to_31_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 16))); - lut_32_to_47_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 32))); - lut_48_to_63_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 48))); - lut_64_to_79_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 64))); - lut_80_to_95_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 80))); - lut_96_to_111_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 96))); - lut_112_to_127_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 112))); - lut_128_to_143_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 128))); - lut_144_to_159_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 144))); - lut_160_to_175_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 160))); - lut_176_to_191_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 176))); - lut_192_to_207_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 192))); - lut_208_to_223_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 208))); - lut_224_to_239_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 224))); - lut_240_to_255_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 240))); - - // Assuming each lookup is performed within 16 elements of 256, we need to reduce the scope by 16x = 2^4. - sz_u256_vec_t not_first_bit_vec, not_second_bit_vec, not_third_bit_vec, not_fourth_bit_vec; - - /// Top and bottom nibbles of the source are used separately. - sz_u256_vec_t source_vec, source_bot_vec; - sz_u256_vec_t blended_0_to_31_vec, blended_32_to_63_vec, blended_64_to_95_vec, blended_96_to_127_vec, - blended_128_to_159_vec, blended_160_to_191_vec, blended_192_to_223_vec, blended_224_to_255_vec; - - // Handling the head. - while (length >= 32) { - // Load and separate the nibbles of each byte in the source. - source_vec.ymm = _mm256_lddqu_si256((__m256i const *)source); - source_bot_vec.ymm = _mm256_and_si256(source_vec.ymm, _mm256_set1_epi8((char)0x0F)); - - // In the first round, we select using the 4th bit. 
- not_fourth_bit_vec.ymm = _mm256_cmpeq_epi8( // - _mm256_and_si256(_mm256_set1_epi8((char)0x10), source_vec.ymm), _mm256_setzero_si256()); - blended_0_to_31_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_16_to_31_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_0_to_15_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_32_to_63_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_48_to_63_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_32_to_47_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_64_to_95_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_80_to_95_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_64_to_79_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_96_to_127_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_112_to_127_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_96_to_111_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_128_to_159_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_144_to_159_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_128_to_143_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_160_to_191_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_176_to_191_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_160_to_175_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_192_to_223_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_208_to_223_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_192_to_207_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_224_to_255_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_240_to_255_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_224_to_239_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - - // Perform a tree-like reduction of the 8x "blended" YMM registers, depending on the "source" content. - // The first round selects using the 3rd bit. - not_third_bit_vec.ymm = _mm256_cmpeq_epi8( // - _mm256_and_si256(_mm256_set1_epi8((char)0x20), source_vec.ymm), _mm256_setzero_si256()); - blended_0_to_31_vec.ymm = _mm256_blendv_epi8( // - blended_32_to_63_vec.ymm, // - blended_0_to_31_vec.ymm, // - not_third_bit_vec.ymm); - blended_64_to_95_vec.ymm = _mm256_blendv_epi8( // - blended_96_to_127_vec.ymm, // - blended_64_to_95_vec.ymm, // - not_third_bit_vec.ymm); - blended_128_to_159_vec.ymm = _mm256_blendv_epi8( // - blended_160_to_191_vec.ymm, // - blended_128_to_159_vec.ymm, // - not_third_bit_vec.ymm); - blended_192_to_223_vec.ymm = _mm256_blendv_epi8( // - blended_224_to_255_vec.ymm, // - blended_192_to_223_vec.ymm, // - not_third_bit_vec.ymm); - - // The second round selects using the 2nd bit. - not_second_bit_vec.ymm = _mm256_cmpeq_epi8( // - _mm256_and_si256(_mm256_set1_epi8((char)0x40), source_vec.ymm), _mm256_setzero_si256()); - blended_0_to_31_vec.ymm = _mm256_blendv_epi8( // - blended_64_to_95_vec.ymm, // - blended_0_to_31_vec.ymm, // - not_second_bit_vec.ymm); - blended_128_to_159_vec.ymm = _mm256_blendv_epi8( // - blended_192_to_223_vec.ymm, // - blended_128_to_159_vec.ymm, // - not_second_bit_vec.ymm); - - // The third round selects using the 1st bit. 
- not_first_bit_vec.ymm = _mm256_cmpeq_epi8( // - _mm256_and_si256(_mm256_set1_epi8((char)0x80), source_vec.ymm), _mm256_setzero_si256()); - blended_0_to_31_vec.ymm = _mm256_blendv_epi8( // - blended_128_to_159_vec.ymm, // - blended_0_to_31_vec.ymm, // - not_first_bit_vec.ymm); - - // And dump the result into the target. - _mm256_storeu_si256((__m256i *)target, blended_0_to_31_vec.ymm); - source += 32, target += 32, length -= 32; - } - - // Handle the tail. - if (length) sz_look_up_transform_serial(source, length, lut, target); -} - -SZ_PUBLIC sz_cptr_t sz_find_byte_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - int mask; - sz_u256_vec_t h_vec, n_vec; - n_vec.ymm = _mm256_set1_epi8(n[0]); - - while (h_length >= 32) { - h_vec.ymm = _mm256_lddqu_si256((__m256i const *)h); - mask = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_vec.ymm, n_vec.ymm)); - if (mask) return h + sz_u32_ctz(mask); - h += 32, h_length -= 32; - } - - return sz_find_byte_serial(h, h_length, n); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - int mask; - sz_u256_vec_t h_vec, n_vec; - n_vec.ymm = _mm256_set1_epi8(n[0]); - - while (h_length >= 32) { - h_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + h_length - 32)); - mask = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_vec.ymm, n_vec.ymm)); - if (mask) return h + h_length - 1 - sz_u32_clz(mask); - h_length -= 32; - } - - return sz_rfind_byte_serial(h, h_length, n); -} - -SZ_PUBLIC sz_cptr_t sz_find_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_find_byte_avx2(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into YMM registers. - int matches; - sz_u256_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; - n_first_vec.ymm = _mm256_set1_epi8(n[offset_first]); - n_mid_vec.ymm = _mm256_set1_epi8(n[offset_mid]); - n_last_vec.ymm = _mm256_set1_epi8(n[offset_last]); - - // Scan through the string. - for (; h_length >= n_length + 32; h += 32, h_length -= 32) { - h_first_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_first)); - h_mid_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_mid)); - h_last_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_last)); - matches = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_first_vec.ymm, n_first_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_mid_vec.ymm, n_mid_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_last_vec.ymm, n_last_vec.ymm)); - while (matches) { - int potential_offset = sz_u32_ctz(matches); - if (sz_equal(h + potential_offset, n, n_length)) return h + potential_offset; - matches &= matches - 1; - } - } - - return sz_find_serial(h, h_length, n, n_length); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_rfind_byte_avx2(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. 
- sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into YMM registers. - int matches; - sz_u256_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; - n_first_vec.ymm = _mm256_set1_epi8(n[offset_first]); - n_mid_vec.ymm = _mm256_set1_epi8(n[offset_mid]); - n_last_vec.ymm = _mm256_set1_epi8(n[offset_last]); - - // Scan through the string. - sz_cptr_t h_reversed; - for (; h_length >= n_length + 32; h_length -= 32) { - h_reversed = h + h_length - n_length - 32 + 1; - h_first_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_first)); - h_mid_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_mid)); - h_last_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_last)); - matches = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_first_vec.ymm, n_first_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_mid_vec.ymm, n_mid_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_last_vec.ymm, n_last_vec.ymm)); - while (matches) { - int potential_offset = sz_u32_clz(matches); - if (sz_equal(h + h_length - n_length - potential_offset, n, n_length)) - return h + h_length - n_length - potential_offset; - matches &= ~(1 << (31 - potential_offset)); - } - } - - return sz_rfind_serial(h, h_length, n, n_length); -} - -SZ_PUBLIC sz_cptr_t sz_find_charset_avx2(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { - - // Let's unzip even and odd elements and replicate them into both lanes of the YMM register. - // That way when we invoke `_mm256_shuffle_epi8` we can use the same mask for both lanes. - sz_u256_vec_t filter_even_vec, filter_odd_vec; - for (sz_size_t i = 0; i != 16; ++i) - filter_even_vec.u8s[i] = filter->_u8s[i * 2], filter_odd_vec.u8s[i] = filter->_u8s[i * 2 + 1]; - filter_even_vec.xmms[1] = filter_even_vec.xmms[0]; - filter_odd_vec.xmms[1] = filter_odd_vec.xmms[0]; - - sz_u256_vec_t text_vec; - sz_u256_vec_t matches_vec; - sz_u256_vec_t lower_nibbles_vec, higher_nibbles_vec; - sz_u256_vec_t bitset_even_vec, bitset_odd_vec; - sz_u256_vec_t bitmask_vec, bitmask_lookup_vec; - bitmask_lookup_vec.ymm = _mm256_set_epi8(-128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // - -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1); - - while (length >= 32) { - // The following algorithm is a transposed equivalent of the "SIMDized check which bytes are in a set" - // solutions by Wojciech Muła. We populate the bitmask differently and target newer CPUs, so - // StrinZilla uses a somewhat different approach. - // http://0x80.pl/articles/simd-byte-lookup.html#alternative-implementation-new - // - // sz_u8_t input = *(sz_u8_t const *)text; - // sz_u8_t lo_nibble = input & 0x0f; - // sz_u8_t hi_nibble = input >> 4; - // sz_u8_t bitset_even = filter_even_vec.u8s[hi_nibble]; - // sz_u8_t bitset_odd = filter_odd_vec.u8s[hi_nibble]; - // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); - // sz_u8_t bitset = lo_nibble < 8 ? bitset_even : bitset_odd; - // if ((bitset & bitmask) != 0) return text; - // else { length--, text++; } - // - // The nice part about this, loading the strided data is vey easy with Arm NEON, - // while with x86 CPUs after AVX, shuffles within 256 bits shouldn't be an issue either. 
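        // As a concrete example, take the byte 'A' (0x41): `hi_nibble` is 4 and `lo_nibble` is 1,
        // so the scalar model above reads `filter->_u8s[8]` (the "even" byte of row 4) and tests
        // bit 1 in it - the same byte and bit a serial charset membership check would touch, since
        // 65 / 8 == 8 and 65 % 8 == 1.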
- text_vec.ymm = _mm256_lddqu_si256((__m256i const *)text); - lower_nibbles_vec.ymm = _mm256_and_si256(text_vec.ymm, _mm256_set1_epi8(0x0f)); - bitmask_vec.ymm = _mm256_shuffle_epi8(bitmask_lookup_vec.ymm, lower_nibbles_vec.ymm); - // - // At this point we can validate the `bitmask_vec` contents like this: - // - // for (sz_size_t i = 0; i != 32; ++i) { - // sz_u8_t input = *(sz_u8_t const *)(text + i); - // sz_u8_t lo_nibble = input & 0x0f; - // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); - // sz_assert(bitmask_vec.u8s[i] == bitmask); - // } - // - // Shift right every byte by 4 bits. - // There is no `_mm256_srli_epi8` intrinsic, so we have to use `_mm256_srli_epi16` - // and combine it with a mask to clear the higher bits. - higher_nibbles_vec.ymm = _mm256_and_si256(_mm256_srli_epi16(text_vec.ymm, 4), _mm256_set1_epi8(0x0f)); - bitset_even_vec.ymm = _mm256_shuffle_epi8(filter_even_vec.ymm, higher_nibbles_vec.ymm); - bitset_odd_vec.ymm = _mm256_shuffle_epi8(filter_odd_vec.ymm, higher_nibbles_vec.ymm); - // - // At this point we can validate the `bitset_even_vec` and `bitset_odd_vec` contents like this: - // - // for (sz_size_t i = 0; i != 32; ++i) { - // sz_u8_t input = *(sz_u8_t const *)(text + i); - // sz_u8_t const *bitset_ptr = &filter->_u8s[0]; - // sz_u8_t hi_nibble = input >> 4; - // sz_u8_t bitset_even = bitset_ptr[hi_nibble * 2]; - // sz_u8_t bitset_odd = bitset_ptr[hi_nibble * 2 + 1]; - // sz_assert(bitset_even_vec.u8s[i] == bitset_even); - // sz_assert(bitset_odd_vec.u8s[i] == bitset_odd); - // } - // - __m256i take_first = _mm256_cmpgt_epi8(_mm256_set1_epi8(8), lower_nibbles_vec.ymm); - bitset_even_vec.ymm = _mm256_blendv_epi8(bitset_odd_vec.ymm, bitset_even_vec.ymm, take_first); - - // It would have been great to have an instruction that tests the bits and then broadcasts - // the matching bit into all bits in that byte. But we don't have that, so we have to - // `and`, `cmpeq`, `movemask`, and then invert at the end... - matches_vec.ymm = _mm256_and_si256(bitset_even_vec.ymm, bitmask_vec.ymm); - matches_vec.ymm = _mm256_cmpeq_epi8(matches_vec.ymm, _mm256_setzero_si256()); - int matches_mask = ~_mm256_movemask_epi8(matches_vec.ymm); - if (matches_mask) { - int offset = sz_u32_ctz(matches_mask); - return text + offset; - } - else { text += 32, length -= 32; } - } - - return sz_find_charset_serial(text, length, filter); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx2(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { - return sz_rfind_charset_serial(text, length, filter); -} - -/** - * @brief There is no AVX2 instruction for fast multiplication of 64-bit integers. - * This implementation is coming from Agner Fog's Vector Class Library. 
- */ -SZ_INTERNAL __m256i _mm256_mul_epu64(__m256i a, __m256i b) { - __m256i bswap = _mm256_shuffle_epi32(b, 0xB1); - __m256i prodlh = _mm256_mullo_epi32(a, bswap); - __m256i zero = _mm256_setzero_si256(); - __m256i prodlh2 = _mm256_hadd_epi32(prodlh, zero); - __m256i prodlh3 = _mm256_shuffle_epi32(prodlh2, 0x73); - __m256i prodll = _mm256_mul_epu32(a, b); - __m256i prod = _mm256_add_epi64(prodll, prodlh3); - return prod; -} - -SZ_PUBLIC void sz_hashes_avx2(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle) { - - if (length < window_length || !window_length) return; - if (length < 4 * window_length) { - sz_hashes_serial(start, length, window_length, step, callback, callback_handle); - return; - } - - // Using AVX2, we can perform 4 long integer multiplications and additions within one register. - // So let's slice the entire string into 4 overlapping windows, to slide over them in parallel. - sz_size_t const max_hashes = length - window_length + 1; - sz_size_t const min_hashes_per_thread = max_hashes / 4; // At most one sequence can overlap between 2 threads. - sz_u8_t const *text_first = (sz_u8_t const *)start; - sz_u8_t const *text_second = text_first + min_hashes_per_thread; - sz_u8_t const *text_third = text_first + min_hashes_per_thread * 2; - sz_u8_t const *text_fourth = text_first + min_hashes_per_thread * 3; - sz_u8_t const *text_end = text_first + length; - - // Prepare the `prime ^ window_length` values, that we are going to use for modulo arithmetic. - sz_u64_t prime_power_low = 1, prime_power_high = 1; - for (sz_size_t i = 0; i + 1 < window_length; ++i) - prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME, - prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME; - - // Broadcast the constants into the registers. - sz_u256_vec_t prime_vec, golden_ratio_vec; - sz_u256_vec_t base_low_vec, base_high_vec, prime_power_low_vec, prime_power_high_vec, shift_high_vec; - base_low_vec.ymm = _mm256_set1_epi64x(31ull); - base_high_vec.ymm = _mm256_set1_epi64x(257ull); - shift_high_vec.ymm = _mm256_set1_epi64x(77ull); - prime_vec.ymm = _mm256_set1_epi64x(SZ_U64_MAX_PRIME); - golden_ratio_vec.ymm = _mm256_set1_epi64x(11400714819323198485ull); - prime_power_low_vec.ymm = _mm256_set1_epi64x(prime_power_low); - prime_power_high_vec.ymm = _mm256_set1_epi64x(prime_power_high); - - // Compute the initial hash values for every one of the four windows. - sz_u256_vec_t hash_low_vec, hash_high_vec, hash_mix_vec, chars_low_vec, chars_high_vec; - hash_low_vec.ymm = _mm256_setzero_si256(); - hash_high_vec.ymm = _mm256_setzero_si256(); - for (sz_u8_t const *prefix_end = text_first + window_length; text_first < prefix_end; - ++text_first, ++text_second, ++text_third, ++text_fourth) { - - // 1. Multiply the hashes by the base. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, base_low_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, base_high_vec.ymm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. - chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); - - // 3. Add the incoming characters. 
- hash_low_vec.ymm = _mm256_add_epi64(hash_low_vec.ymm, chars_low_vec.ymm); - hash_high_vec.ymm = _mm256_add_epi64(hash_high_vec.ymm, chars_high_vec.ymm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_low_vec.ymm = _mm256_blendv_epi8(hash_low_vec.ymm, _mm256_sub_epi64(hash_low_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_low_vec.ymm, prime_vec.ymm)); - hash_high_vec.ymm = _mm256_blendv_epi8(hash_high_vec.ymm, _mm256_sub_epi64(hash_high_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_high_vec.ymm, prime_vec.ymm)); - } - - // 5. Compute the hash mix, that will be used to index into the fingerprint. - // This includes a serial step at the end. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, golden_ratio_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, golden_ratio_vec.ymm); - hash_mix_vec.ymm = _mm256_xor_si256(hash_low_vec.ymm, hash_high_vec.ymm); - callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); - callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); - callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); - callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); - - // Now repeat that operation for the remaining characters, discarding older characters. - sz_size_t cycle = 1; - sz_size_t const step_mask = step - 1; - for (; text_fourth != text_end; ++text_first, ++text_second, ++text_third, ++text_fourth, ++cycle) { - // 0. Load again the four characters we are dropping, shift them, and subtract. - chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[-window_length], text_third[-window_length], - text_second[-window_length], text_first[-window_length]); - chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); - hash_low_vec.ymm = - _mm256_sub_epi64(hash_low_vec.ymm, _mm256_mul_epu64(chars_low_vec.ymm, prime_power_low_vec.ymm)); - hash_high_vec.ymm = - _mm256_sub_epi64(hash_high_vec.ymm, _mm256_mul_epu64(chars_high_vec.ymm, prime_power_high_vec.ymm)); - - // 1. Multiply the hashes by the base. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, base_low_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, base_high_vec.ymm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. - chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); - - // 3. Add the incoming characters. - hash_low_vec.ymm = _mm256_add_epi64(hash_low_vec.ymm, chars_low_vec.ymm); - hash_high_vec.ymm = _mm256_add_epi64(hash_high_vec.ymm, chars_high_vec.ymm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_low_vec.ymm = _mm256_blendv_epi8(hash_low_vec.ymm, _mm256_sub_epi64(hash_low_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_low_vec.ymm, prime_vec.ymm)); - hash_high_vec.ymm = _mm256_blendv_epi8(hash_high_vec.ymm, _mm256_sub_epi64(hash_high_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_high_vec.ymm, prime_vec.ymm)); - - // 5. 
Compute the hash mix, that will be used to index into the fingerprint. - // This includes a serial step at the end. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, golden_ratio_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, golden_ratio_vec.ymm); - hash_mix_vec.ymm = _mm256_xor_si256(hash_low_vec.ymm, hash_high_vec.ymm); - if ((cycle & step_mask) == 0) { - callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); - callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); - callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); - callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); - } - } -} - -#pragma clang attribute pop -#pragma GCC pop_options -#endif -#pragma endregion - -/* - * @brief AVX-512 implementation of the string search algorithms. - * - * Different subsets of AVX-512 were introduced in different years: - * - 2017 SkyLake: F, CD, ER, PF, VL, DQ, BW - * - 2018 CannonLake: IFMA, VBMI - * - 2019 IceLake: VPOPCNTDQ, VNNI, VBMI2, BITALG, GFNI, VPCLMULQDQ, VAES - * - 2020 TigerLake: VP2INTERSECT - */ -#pragma region AVX512 Implementation - -#if SZ_USE_X86_AVX512 -#pragma GCC push_options -#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "bmi", "bmi2") -#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,bmi,bmi2"))), apply_to = function) -#include - -/** - * @brief Helper structure to simplify work with 512-bit registers. - */ -typedef union sz_u512_vec_t { - __m512i zmm; - __m256i ymms[2]; - __m128i xmms[4]; - sz_u64_t u64s[8]; - sz_u32_t u32s[16]; - sz_u16_t u16s[32]; - sz_u8_t u8s[64]; - sz_i64_t i64s[8]; - sz_i32_t i32s[16]; -} sz_u512_vec_t; - -SZ_INTERNAL __mmask64 _sz_u64_clamp_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 64: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 64: - return _bzhi_u64(0xFFFFFFFFFFFFFFFF, n < 64 ? (sz_u32_t)n : 64); -} - -SZ_INTERNAL __mmask32 _sz_u32_clamp_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 32: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 32: - return _bzhi_u32(0xFFFFFFFF, n < 32 ? (sz_u32_t)n : 32); -} - -SZ_INTERNAL __mmask16 _sz_u16_clamp_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 16: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 16: - return _bzhi_u32(0xFFFFFFFF, n < 16 ? 
(sz_u32_t)n : 16); -} - -SZ_INTERNAL __mmask16 _sz_u16_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 16: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 16: - return (__mmask16)_bzhi_u32(0xFFFFFFFF, (sz_u32_t)n); -} - -SZ_INTERNAL __mmask32 _sz_u32_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 32: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 32: - return _bzhi_u32(0xFFFFFFFF, (sz_u32_t)n); -} - -SZ_INTERNAL __mmask64 _sz_u64_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 64: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 64: - return _bzhi_u64(0xFFFFFFFFFFFFFFFF, (sz_u32_t)n); -} - -SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { - sz_u512_vec_t a_vec, b_vec; - - // Pointer arithmetic is cheap, fetching memory is not! - // So we can use the masked loads to fetch at most one cache-line for each string, - // compare the prefixes, and only then move forward. - sz_size_t a_head_length = 64 - ((sz_size_t)a % 64); // 63 or less. - sz_size_t b_head_length = 64 - ((sz_size_t)b % 64); // 63 or less. - a_head_length = a_head_length < a_length ? a_head_length : a_length; - b_head_length = b_head_length < b_length ? b_head_length : b_length; - sz_size_t head_length = a_head_length < b_head_length ? a_head_length : b_head_length; - __mmask64 head_mask = _sz_u64_mask_until(head_length); - a_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, a); - b_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, b); - __mmask64 mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); - if (mask_not_equal != 0) { - sz_u64_t first_diff = _tzcnt_u64(mask_not_equal); - char a_char = a_vec.u8s[first_diff]; - char b_char = b_vec.u8s[first_diff]; - return _sz_order_scalars(a_char, b_char); - } - else if (head_length == a_length && head_length == b_length) { return sz_equal_k; } - else { a += head_length, b += head_length, a_length -= head_length, b_length -= head_length; } - - // The rare case, when both string are very long. - __mmask64 a_mask, b_mask; - while ((a_length >= 64) & (b_length >= 64)) { - a_vec.zmm = _mm512_loadu_si512(a); - b_vec.zmm = _mm512_loadu_si512(b); - mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); - if (mask_not_equal != 0) { - sz_u64_t first_diff = _tzcnt_u64(mask_not_equal); - char a_char = a_vec.u8s[first_diff]; - char b_char = b_vec.u8s[first_diff]; - return _sz_order_scalars(a_char, b_char); - } - a += 64, b += 64, a_length -= 64, b_length -= 64; - } - - // In most common scenarios at least one of the strings is under 64 bytes. - if (a_length | b_length) { - a_mask = _sz_u64_clamp_mask_until(a_length); - b_mask = _sz_u64_clamp_mask_until(b_length); - a_vec.zmm = _mm512_maskz_loadu_epi8(a_mask, a); - b_vec.zmm = _mm512_maskz_loadu_epi8(b_mask, b); - // The AVX-512 `_mm512_mask_cmpneq_epi8_mask` intrinsics are generally handy in such environments. - // They, however, have latency 3 on most modern CPUs. Using AVX2: `_mm256_cmpeq_epi8` would have - // been cheaper, if we didn't have to apply `_mm256_movemask_epi8` afterwards. 
- mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); - if (mask_not_equal != 0) { - sz_u64_t first_diff = _tzcnt_u64(mask_not_equal); - char a_char = a_vec.u8s[first_diff]; - char b_char = b_vec.u8s[first_diff]; - return _sz_order_scalars(a_char, b_char); - } - // From logic perspective, the hardest cases are "abc\0" and "abc". - // The result must be `sz_greater_k`, as the latter is shorter. - else { return _sz_order_scalars(a_length, b_length); } - } - - return sz_equal_k; -} - -SZ_PUBLIC sz_bool_t sz_equal_avx512(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - __mmask64 mask; - sz_u512_vec_t a_vec, b_vec; - - while (length >= 64) { - a_vec.zmm = _mm512_loadu_si512(a); - b_vec.zmm = _mm512_loadu_si512(b); - mask = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); - if (mask != 0) return sz_false_k; - a += 64, b += 64, length -= 64; - } - - if (length) { - mask = _sz_u64_mask_until(length); - a_vec.zmm = _mm512_maskz_loadu_epi8(mask, a); - b_vec.zmm = _mm512_maskz_loadu_epi8(mask, b); - // Reuse the same `mask` variable to find the bit that doesn't match - mask = _mm512_mask_cmpneq_epi8_mask(mask, a_vec.zmm, b_vec.zmm); - return (sz_bool_t)(mask == 0); - } - - return sz_true_k; -} - -SZ_PUBLIC void sz_fill_avx512(sz_ptr_t target, sz_size_t length, sz_u8_t value) { - __m512i value_vec = _mm512_set1_epi8(value); - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "stores". - // - // for (; length >= 64; target += 64, length -= 64) _mm512_storeu_si512(target, value_vec); - // _mm512_mask_storeu_epi8(target, _sz_u64_mask_until(length), value_vec); - // - // When the buffer is small, there isn't much to innovate. - if (length <= 64) { - __mmask64 mask = _sz_u64_mask_until(length); - _mm512_mask_storeu_epi8(target, mask, value_vec); - } - // When the buffer is over 64 bytes, it's guaranteed to touch at least two cache lines - the head and tail, - // and may include more cache-lines in-between. Knowing this, we can avoid expensive unaligned stores - // by computing 2 masks - for the head and tail, using masked stores for the head and tail, and unmasked - // for the body. - else { - sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 64. - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - _mm512_mask_storeu_epi8(target, head_mask, value_vec); - for (target += head_length; body_length >= 64; target += 64, body_length -= 64) - _mm512_store_si512(target, value_vec); - _mm512_mask_storeu_epi8(target, tail_mask, value_vec); - } -} - -SZ_PUBLIC void sz_copy_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "stores" and "loads". - // - // for (; length >= 64; target += 64, source += 64, length -= 64) - // _mm512_storeu_si512(target, _mm512_loadu_si512(source)); - // __mmask64 mask = _sz_u64_mask_until(length); - // _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source)); - // - // A typical AWS Sapphire Rapids instance can have 48 KB x 2 blocks of L1 data cache per core, - // 2 MB x 2 blocks of L2 cache per core, and one shared 60 MB buffer of L3 cache. 
-    // With two strings, we may consider the overall workload huge, if each exceeds 1 MB in length.
-    int const is_huge = length >= 1ull * 1024ull * 1024ull;
-
-    // When the buffer is small, there isn't much to innovate.
-    if (length <= 64) {
-        __mmask64 mask = _sz_u64_mask_until(length);
-        _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source));
-    }
-    // When dealing with larger arrays, the optimization is not as simple as with the `sz_fill_avx512` function,
-    // as both buffers may be unaligned. If we are lucky and the requested operation is some huge page transfer,
-    // we can use aligned loads and stores, and the performance will be great.
-    else if ((sz_size_t)target % 64 == 0 && (sz_size_t)source % 64 == 0 && !is_huge) {
-        for (; length >= 64; target += 64, source += 64, length -= 64)
-            _mm512_store_si512(target, _mm512_load_si512(source));
-        // At this point the length is guaranteed to be under 64.
-        __mmask64 mask = _sz_u64_mask_until(length);
-        // Aligned loads and stores would work too, but it's not defined.
-        _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source));
-    }
-    // The trickiest case is when both `source` and `target` are not aligned.
-    // In this and simpler cases we can copy enough bytes into `target` to reach its cacheline boundary,
-    // and then combine unaligned loads with aligned stores.
-    else if (!is_huge) {
-        sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less.
-        sz_size_t tail_length = (sz_size_t)(target + length) % 64;    // 63 or less.
-        sz_size_t body_length = length - head_length - tail_length;   // Multiple of 64.
-        __mmask64 head_mask = _sz_u64_mask_until(head_length);
-        __mmask64 tail_mask = _sz_u64_mask_until(tail_length);
-        _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source));
-        for (target += head_length, source += head_length; body_length >= 64;
-             target += 64, source += 64, body_length -= 64)
-            _mm512_store_si512(target, _mm512_loadu_si512(source)); // Unaligned load, but aligned store!
-        _mm512_mask_storeu_epi8(target, tail_mask, _mm512_maskz_loadu_epi8(tail_mask, source));
-    }
-    // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use.
-    //
-    //    1. Moving in both directions to maximize the throughput, when fetching from multiple
-    //       memory pages. Also helps with cache set-associativity issues, as we won't always
-    //       be fetching the same entries in the lookup table.
-    //    2. Using non-temporal stores to avoid polluting the cache.
-    //    3. Prefetching the next cache line, to avoid stalling the CPU. This is generally useless
-    //       for predictable patterns, so disregard this advice.
-    //
-    // Bidirectional traversal adds about 10%, accelerating from 11 GB/s to 12 GB/s.
-    // Using "streaming stores" boosts us from 12 GB/s to 19 GB/s.
-    else {
-        sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64;
-        sz_size_t tail_length = (sz_size_t)(target + length) % 64;
-        sz_size_t body_length = length - head_length - tail_length;
-        __mmask64 head_mask = _sz_u64_mask_until(head_length);
-        __mmask64 tail_mask = _sz_u64_mask_until(tail_length);
-        _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source));
-        _mm512_mask_storeu_epi8(target + head_length + body_length, tail_mask,
-                                _mm512_maskz_loadu_epi8(tail_mask, source));
-
-        // Now in the main loop, we can use non-temporal stores,
-        // performing the operation in both directions.
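        // Note how the loop below advances `target` and `source` by 64 bytes from the front on
        // every step, while `body_length` shrinks by 128: each iteration also writes one cache
        // line at the back, at `target + body_length - 64`. It exits once fewer than 128 bytes
        // remain in the middle, leaving at most one 64-byte block for the trailing `if`.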
- for (target += head_length, source += head_length; // - body_length >= 128; // - target += 64, source += 64, body_length -= 128) { - _mm512_stream_si512((__m512i *)(target), _mm512_loadu_si512(source)); - _mm512_stream_si512((__m512i *)(target + body_length - 64), _mm512_loadu_si512(source + body_length - 64)); - } - if (body_length >= 64) _mm512_stream_si512((__m512i *)target, _mm512_loadu_si512(source)); - } -} - -SZ_PUBLIC void sz_move_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - if (target == source) return; // Don't be silly, don't move the data if it's already there. - - // On very short buffers, that are one cache line in width or less, we don't need any loops. - // We can also avoid any data-dependencies between iterations, assuming we have 32 registers - // to pre-load the data, before writing it back. - if (length <= 64) { - __mmask64 mask = _sz_u64_mask_until(length); - _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source)); - } - else if (length <= 128) { - sz_size_t last_length = length - 64; - __mmask64 mask = _sz_u64_mask_until(last_length); - __m512i source0 = _mm512_loadu_epi8(source); - __m512i source1 = _mm512_maskz_loadu_epi8(mask, source + 64); - _mm512_storeu_epi8(target, source0); - _mm512_mask_storeu_epi8(target + 64, mask, source1); - } - else if (length <= 192) { - sz_size_t last_length = length - 128; - __mmask64 mask = _sz_u64_mask_until(last_length); - __m512i source0 = _mm512_loadu_epi8(source); - __m512i source1 = _mm512_loadu_epi8(source + 64); - __m512i source2 = _mm512_maskz_loadu_epi8(mask, source + 128); - _mm512_storeu_epi8(target, source0); - _mm512_storeu_epi8(target + 64, source1); - _mm512_mask_storeu_epi8(target + 128, mask, source2); - } - else if (length <= 256) { - sz_size_t last_length = length - 192; - __mmask64 mask = _sz_u64_mask_until(last_length); - __m512i source0 = _mm512_loadu_epi8(source); - __m512i source1 = _mm512_loadu_epi8(source + 64); - __m512i source2 = _mm512_loadu_epi8(source + 128); - __m512i source3 = _mm512_maskz_loadu_epi8(mask, source + 192); - _mm512_storeu_epi8(target, source0); - _mm512_storeu_epi8(target + 64, source1); - _mm512_storeu_epi8(target + 128, source2); - _mm512_mask_storeu_epi8(target + 192, mask, source3); - } - - // If the regions don't overlap at all, just use "copy" and save some brain cells thinking about corner cases. - else if (target + length < source || target >= source + length) { sz_copy_avx512(target, source, length); } - - // When the buffer is over 64 bytes, it's guaranteed to touch at least two cache lines - the head and tail, - // and may include more cache-lines in-between. Knowing this, we can avoid expensive unaligned stores - // by computing 2 masks - for the head and tail, using masked stores for the head and tail, and unmasked - // for the body. - else { - sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 64. - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - - // The absolute most common case of using "moves" is shifting the data within a continuous buffer - // when adding a removing some values in it. In such cases, a typical shift is by 1, 2, 4, 8, 16, - // or 32 bytes, rarely larger. For small shifts, under the size of the ZMM register, we can use shuffles. 
- // - // Remember: - // - if we are shifting data left, that we are traversing to the right. - // - if we are shifting data right, that we are traversing to the left. - int const left_to_right_traversal = source > target; - - // Now we guarantee, that the relative shift within registers is from 1 to 63 bytes and the output is aligned. - // Hopefully, we need to shift more than two ZMM registers, so we could consider `valignr` instruction. - // Sadly, using `_mm512_alignr_epi8` doesn't make sense, as it operates at a 128-bit granularity. - // - // - `_mm256_alignr_epi8` shifts entire 256-bit register, but we need many of them. - // - `_mm512_alignr_epi32` shifts 512-bit chunks, but only if the `shift` is a multiple of 4 bytes. - // - `_mm512_alignr_epi64` shifts 512-bit chunks by 8 bytes. - // - // All of those have a latency of 1 cycle, and the shift amount must be an immediate value! - // For 1-byte-shift granularity, the `_mm512_permutex2var_epi8` has a latency of 6 and needs VBMI! - // The most efficient and broadly compatible alternative could be to use a combination of align and shuffle. - // A similar approach was outlined in "Byte-wise alignr in AVX512F" by Wojciech Muła. - // http://0x80.pl/notesen/2016-10-16-avx512-byte-alignr.html - // - // That solution, is extremely mouthful, assuming we need compile time constants for the shift amount. - // A cleaner one, with a latency of 3 cycles, is to use `_mm512_permutexvar_epi8` or - // `_mm512_mask_permutexvar_epi8`, which can be seen as combination of a cross-register shuffle and blend, - // and is available with VBMI. That solution is still noticeably slower than AVX2. - // - // The GLibC implementation also uses non-temporal stores for larger buffers, we don't. - // https://codebrowser.dev/glibc/glibc/sysdeps/x86_64/multiarch/memmove-avx512-no-vzeroupper.S.html - if (left_to_right_traversal) { - // Head, body, and tail. - _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source)); - for (target += head_length, source += head_length; body_length >= 64; - target += 64, source += 64, body_length -= 64) - _mm512_store_si512(target, _mm512_loadu_si512(source)); - _mm512_mask_storeu_epi8(target, tail_mask, _mm512_maskz_loadu_epi8(tail_mask, source)); - } - else { - // Tail, body, and head. 
- _mm512_mask_storeu_epi8(target + head_length + body_length, tail_mask, - _mm512_maskz_loadu_epi8(tail_mask, source + head_length + body_length)); - for (; body_length >= 64; body_length -= 64) - _mm512_store_si512(target + head_length + body_length - 64, - _mm512_loadu_si512(source + head_length + body_length - 64)); - _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source)); - } - } -} - -SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - __mmask64 mask; - sz_u512_vec_t h_vec, n_vec; - n_vec.zmm = _mm512_set1_epi8(n[0]); - - while (h_length >= 64) { - h_vec.zmm = _mm512_loadu_si512(h); - mask = _mm512_cmpeq_epi8_mask(h_vec.zmm, n_vec.zmm); - if (mask) return h + sz_u64_ctz(mask); - h += 64, h_length -= 64; - } - - if (h_length) { - mask = _sz_u64_mask_until(h_length); - h_vec.zmm = _mm512_maskz_loadu_epi8(mask, h); - // Reuse the same `mask` variable to find the bit that doesn't match - mask = _mm512_mask_cmpeq_epu8_mask(mask, h_vec.zmm, n_vec.zmm); - if (mask) return h + sz_u64_ctz(mask); - } - - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_find_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_find_byte_avx512(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into ZMM registers. - __mmask64 matches; - __mmask64 mask; - sz_u512_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; - n_first_vec.zmm = _mm512_set1_epi8(n[offset_first]); - n_mid_vec.zmm = _mm512_set1_epi8(n[offset_mid]); - n_last_vec.zmm = _mm512_set1_epi8(n[offset_last]); - - // Scan through the string. - // We have several optimized versions of the lagorithm for shorter strings, - // but they all mimic the default case for unbounded length needles - if (n_length >= 64) { - for (; h_length >= n_length + 64; h += 64, h_length -= 64) { - h_first_vec.zmm = _mm512_loadu_si512(h + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_ctz(matches); - if (sz_equal_avx512(h + potential_offset, n, n_length)) return h + potential_offset; - matches &= matches - 1; - } - - // TODO: If the last character contains a bad byte, we can reposition the start of the next iteration. - // This will be very helpful for very long needles. - } - } - // If there are only 2 or 3 characters in the needle, we don't even need the nested loop. 
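    // With at most 3 bytes in the needle, `offset_first`, `offset_mid`, and `offset_last` together
    // cover every byte of the needle, so the intersection of the three equality masks is already an
    // exact match mask and no `sz_equal_avx512` confirmation is needed.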
- else if (n_length <= 3) { - for (; h_length >= n_length + 64; h += 64, h_length -= 64) { - h_first_vec.zmm = _mm512_loadu_si512(h + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - if (matches) return h + sz_u64_ctz(matches); - } - } - // If the needle is smaller than the size of the ZMM register, we can use masked comparisons - // to avoid the the inner-most nested loop and compare the entire needle against a haystack - // slice in 3 CPU cycles. - else { - __mmask64 n_mask = _sz_u64_mask_until(n_length); - sz_u512_vec_t n_full_vec, h_full_vec; - n_full_vec.zmm = _mm512_maskz_loadu_epi8(n_mask, n); - for (; h_length >= n_length + 64; h += 64, h_length -= 64) { - h_first_vec.zmm = _mm512_loadu_si512(h + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_ctz(matches); - h_full_vec.zmm = _mm512_maskz_loadu_epi8(n_mask, h + potential_offset); - if (_mm512_mask_cmpneq_epi8_mask(n_mask, h_full_vec.zmm, n_full_vec.zmm) == 0) - return h + potential_offset; - matches &= matches - 1; - } - } - } - - // The "tail" of the function uses masked loads to process the remaining bytes. - { - mask = _sz_u64_mask_until(h_length - n_length + 1); - h_first_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_first); - h_mid_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_mid); - h_last_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_ctz(matches); - if (n_length <= 3 || sz_equal_avx512(h + potential_offset, n, n_length)) return h + potential_offset; - matches &= matches - 1; - } - } - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - __mmask64 mask; - sz_u512_vec_t h_vec, n_vec; - n_vec.zmm = _mm512_set1_epi8(n[0]); - - while (h_length >= 64) { - h_vec.zmm = _mm512_loadu_si512(h + h_length - 64); - mask = _mm512_cmpeq_epi8_mask(h_vec.zmm, n_vec.zmm); - if (mask) return h + h_length - 1 - sz_u64_clz(mask); - h_length -= 64; - } - - if (h_length) { - mask = _sz_u64_mask_until(h_length); - h_vec.zmm = _mm512_maskz_loadu_epi8(mask, h); - // Reuse the same `mask` variable to find the bit that doesn't match - mask = _mm512_mask_cmpeq_epu8_mask(mask, h_vec.zmm, n_vec.zmm); - if (mask) return h + 64 - sz_u64_clz(mask) - 1; - } - - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. 
- if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_rfind_byte_avx512(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into ZMM registers. - __mmask64 mask; - __mmask64 matches; - sz_u512_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; - n_first_vec.zmm = _mm512_set1_epi8(n[offset_first]); - n_mid_vec.zmm = _mm512_set1_epi8(n[offset_mid]); - n_last_vec.zmm = _mm512_set1_epi8(n[offset_last]); - - // Scan through the string. - sz_cptr_t h_reversed; - for (; h_length >= n_length + 64; h_length -= 64) { - h_reversed = h + h_length - n_length - 64 + 1; - h_first_vec.zmm = _mm512_loadu_si512(h_reversed + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h_reversed + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h_reversed + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_clz(matches); - if (n_length <= 3 || sz_equal_avx512(h + h_length - n_length - potential_offset, n, n_length)) - return h + h_length - n_length - potential_offset; - sz_assert((matches & ((sz_u64_t)1 << (63 - potential_offset))) != 0 && - "The bit must be set before we squash it"); - matches &= ~((sz_u64_t)1 << (63 - potential_offset)); - } - } - - // The "tail" of the function uses masked loads to process the remaining bytes. - { - mask = _sz_u64_mask_until(h_length - n_length + 1); - h_first_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_first); - h_mid_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_mid); - h_last_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_clz(matches); - if (n_length <= 3 || sz_equal_avx512(h + 64 - potential_offset - 1, n, n_length)) - return h + 64 - potential_offset - 1; - sz_assert((matches & ((sz_u64_t)1 << (63 - potential_offset))) != 0 && - "The bit must be set before we squash it"); - matches &= ~((sz_u64_t)1 << (63 - potential_offset)); - } - } - - return SZ_NULL_CHAR; -} - -#pragma clang attribute pop -#pragma GCC pop_options - -#pragma GCC push_options -#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vbmi", "bmi", "bmi2") -#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512dq,avx512vbmi,bmi,bmi2"))), \ - apply_to = function) - -/** - * @brief Computes the edit distance between two very short byte-strings using the AVX-512VBMI extensions. - * - * Applies to string lengths up to 63, and evaluates at most (63 * 2 + 1 = 127) diagonals, or just as many loop cycles. - * Supports an early exit, if the distance is bounded. - * Keeps all of the data and Levenshtein matrices skew diagonal in just a couple of registers. - * Benefits from the @b `vpermb` instructions, that can rotate the bytes across the entire ZMM register. 
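- *
- * As a small worked example: for "kitten" (6 bytes) and "sitting" (7 bytes) the underlying
- * DP matrix is 7 x 8, so the kernel sweeps 7 + 8 - 1 = 14 skewed diagonals, keeping only
- * three of them in registers at any point in time.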
- */ -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto63_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound) { - - sz_size_t const max_length = 63u; - sz_assert(shorter_length <= longer_length && "The 'shorter' string is longer than the 'longer' one."); - sz_assert(shorter_length < max_length && "The length must fit into 16-bit integer. Otherwise use serial variant."); - - // We are going to store 3 diagonals of the matrix, assuming each would fit into a single ZMM register. - // The length of the longest (main) diagonal would be `shorter_dim = (shorter_length + 1)`. - sz_size_t const shorter_dim = shorter_length + 1; - sz_size_t const longer_dim = longer_length + 1; - - // The next few buffers will be swapped around. - sz_u512_vec_t previous_vec, current_vec, next_vec; - sz_u512_vec_t gaps_vec, substitutions_vec; - - // Load the strings into ZMM registers - just once. - sz_u512_vec_t longer_vec, shorter_vec, shorter_rotated_vec, rotate_left_vec, rotate_right_vec, ones_vec, bound_vec; - longer_vec.zmm = _mm512_maskz_loadu_epi8(_sz_u64_mask_until(longer_length), longer); - rotate_left_vec.zmm = _mm512_set_epi8( // - 0, 63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, // - 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, // - 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, // - 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1); - rotate_right_vec.zmm = _mm512_set_epi8( // - 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, // - 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, // - 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, // - 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 63); - ones_vec.zmm = _mm512_set1_epi8(1); - bound_vec.zmm = _mm512_set1_epi8(bound <= 255 ? (sz_u8_t)bound : 255); - - // To simplify comparisons and traversals, we want to reverse the order of bytes in the shorter string. - for (sz_size_t i = 0; i != shorter_length; ++i) shorter_vec.u8s[63 - i] = shorter[i]; - shorter_rotated_vec.zmm = _mm512_permutexvar_epi8(rotate_right_vec.zmm, shorter_vec.zmm); - - // Let's say we are dealing with 3 and 5 letter words. - // The matrix will have size 4 x 6, parameterized as (shorter_dim x longer_dim). - // It will have: - // - 4 diagonals of increasing length, at positions: 0, 1, 2, 3. - // - 2 diagonals of fixed length, at positions: 4, 5. - // - 3 diagonals of decreasing length, at positions: 6, 7, 8. - sz_size_t const diagonals_count = shorter_dim + longer_dim - 1; - - // Initialize the first two diagonals: - // - // previous_vec.u8s[0] = 0; - // current_vec.u8s[0] = current_vec.u8s[1] = 1; - // - // We can do a similar thing with vector ops: - previous_vec.zmm = _mm512_setzero_si512(); - current_vec.zmm = _mm512_set1_epi8(1); - - // We skip diagonals 0 and 1, as they are trivial. - // We will start with diagonal 2, which has length 3, with the first and last elements being preset, - // so we are effectively computing just one value, as will be marked by a single set bit in - // the `next_diagonal_mask` on the very first iteration. - sz_size_t next_diagonal_index = 2; - __mmask64 next_diagonal_mask = 0; - - // Progress through the upper triangle of the Levenshtein matrix. 
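- // As a scalar reference (an illustrative sketch, not shipped code), every cell of a new
- // skewed diagonal is produced from its two predecessors roughly as:
- //
- // next[i] = sz_min_of_two(sz_min_of_two(current[i - 1], current[i]) + 1, // insertion or deletion
- // previous[i - 1] + (shorter_char != longer_char)); // substitution
- //
- // The loops below evaluate a whole diagonal at once, keeping it reversed and rotated inside
- // ZMM registers, so the `i - 1` neighbor is obtained with a byte rotation instead of indexing.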
- for (; next_diagonal_index != shorter_dim; ++next_diagonal_index) { - // After this iteration, the values at offset `0` and `next_diagonal_index` in the `next_vec` - // should be set to `next_diagonal_index`, but it's easier to broadcast the value to the whole vector, - // and later merge with a mask with new values. - next_vec.zmm = _mm512_set1_epi8((sz_u8_t)next_diagonal_index); - - // The mask also adds one set bit. - next_diagonal_mask = _kor_mask64(next_diagonal_mask, 1); - next_diagonal_mask = _kshiftli_mask64(next_diagonal_mask, 1); - - // Check for equality between string slices. - __mmask64 conflict_mask = _mm512_cmpneq_epi8_mask(longer_vec.zmm, shorter_rotated_vec.zmm); - substitutions_vec.zmm = _mm512_mask_add_epi8(previous_vec.zmm, conflict_mask, previous_vec.zmm, ones_vec.zmm); - substitutions_vec.zmm = _mm512_permutexvar_epi8(rotate_right_vec.zmm, substitutions_vec.zmm); - gaps_vec.zmm = _mm512_add_epi8( - // Insertions or deletions - _mm512_min_epu8(_mm512_permutexvar_epi8(rotate_right_vec.zmm, current_vec.zmm), current_vec.zmm), - ones_vec.zmm); - next_vec.zmm = _mm512_mask_min_epu8(next_vec.zmm, next_diagonal_mask, gaps_vec.zmm, substitutions_vec.zmm); - - // Mark the current skewed diagonal as the previous one and the next one as the current one. - previous_vec.zmm = current_vec.zmm; - current_vec.zmm = next_vec.zmm; - - // Shift the shorter string - shorter_rotated_vec.zmm = _mm512_permutexvar_epi8(rotate_right_vec.zmm, shorter_rotated_vec.zmm); - - // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound. - __mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm); - if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) { // - return SZ_SIZE_MAX; - } - } - - // Now let's handle the anti-diagonal band of the matrix, between the top and bottom triangles. - for (; next_diagonal_index != longer_dim; ++next_diagonal_index) { - // After this iteration, the value `shorted_dim - 1` in the `next_vec` - // should be set to `next_diagonal_index`, but it's easier to broadcast the value to the whole vector, - // and later merge with a mask with new values. - next_vec.zmm = _mm512_set1_epi8((sz_u8_t)next_diagonal_index); - - // Make sure we update the first entry. - next_diagonal_mask = _kor_mask64(next_diagonal_mask, 1); - - // Check for equality between string slices. - __mmask64 conflict_mask = _mm512_cmpneq_epi8_mask(longer_vec.zmm, shorter_rotated_vec.zmm); - substitutions_vec.zmm = _mm512_mask_add_epi8(previous_vec.zmm, conflict_mask, previous_vec.zmm, ones_vec.zmm); - gaps_vec.zmm = _mm512_add_epi8( - // Insertions or deletions - _mm512_min_epu8(current_vec.zmm, _mm512_permutexvar_epi8(rotate_left_vec.zmm, current_vec.zmm)), - ones_vec.zmm); - next_vec.zmm = _mm512_mask_min_epu8(next_vec.zmm, next_diagonal_mask, gaps_vec.zmm, substitutions_vec.zmm); - - // Mark the current skewed diagonal as the previous one and the next one as the current one. - previous_vec.zmm = _mm512_permutexvar_epi8(rotate_left_vec.zmm, current_vec.zmm); - current_vec.zmm = next_vec.zmm; - - // Let's shift the longer string now. - longer_vec.zmm = _mm512_permutexvar_epi8(rotate_left_vec.zmm, longer_vec.zmm); - - // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound. 
- __mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm);
- if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) { //
- return SZ_SIZE_MAX;
- }
- }
-
- // Now let's handle the bottom right triangle.
- for (; next_diagonal_index != diagonals_count; ++next_diagonal_index) {
-
- // Check for equality between string slices.
- __mmask64 conflict_mask = _mm512_cmpneq_epi8_mask(longer_vec.zmm, shorter_rotated_vec.zmm);
- substitutions_vec.zmm = _mm512_mask_add_epi8(previous_vec.zmm, conflict_mask, previous_vec.zmm, ones_vec.zmm);
- gaps_vec.zmm = _mm512_add_epi8(
- // Insertions or deletions
- _mm512_min_epu8(current_vec.zmm, _mm512_permutexvar_epi8(rotate_left_vec.zmm, current_vec.zmm)),
- ones_vec.zmm);
- next_vec.zmm = _mm512_min_epu8(gaps_vec.zmm, substitutions_vec.zmm);
-
- // Mark the current skewed diagonal as the previous one and the next one as the current one.
- previous_vec.zmm = _mm512_permutexvar_epi8(rotate_left_vec.zmm, current_vec.zmm);
- current_vec.zmm = next_vec.zmm;
-
- // Let's shift the longer string now.
- longer_vec.zmm = _mm512_permutexvar_epi8(rotate_left_vec.zmm, longer_vec.zmm);
-
- // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound.
- __mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm);
- if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) { //
- return SZ_SIZE_MAX;
- }
- // In every following iteration we use a shorter prefix of each register,
- // but we don't need to update the `next_diagonal_mask` anymore... except for the early exit.
- next_diagonal_mask = _kshiftri_mask64(next_diagonal_mask, 1);
- }
- return current_vec.u8s[0];
-}
-
-/**
- * @brief Computes the edit distance between two somewhat short byte-strings using the AVX-512VBMI extensions.
- *
- * Applies to string lengths up to 127, and evaluates at most (127 * 2 + 1 = 255) diagonals.
- * Supports an early exit, if the distance is bounded.
- * Uses a lot more CPU register space than the `upto63` variant.
- * Benefits from the @b `vpermi2b` instructions, that can rotate the bytes in 2 registers at once.
- *
- * This may be one of the most frequently called kernels for:
- * - source code analysis, assuming most lines are either under 80 or under 120 characters long.
- * - DNA sequence alignment, as most short reads are 50-300 characters long.
- */
-SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto127_avx512( //
- sz_cptr_t shorter, sz_size_t shorter_length, //
- sz_cptr_t longer, sz_size_t longer_length, //
- sz_size_t bound) {
- sz_unused(shorter && shorter_length && longer && longer_length && bound);
- return 0;
-}
-
-/**
- * @brief Computes the edit distance between two longer byte-strings using the AVX-512VBMI extensions.
- *
- * Applies to string lengths up to 255, and evaluates at most (255 * 2 + 1 = 511) diagonals.
- * Supports an early exit, if the distance is bounded.
- * Uses a lot more CPU register space than the `upto63` variant.
- *
- * Each of the 2x strings ends up occupying 4 ZMM registers, and each of 3x diagonals uses 4 ZMM registers.
- * So 20x of the 32x are persistently occupied, and the rest are used for math temporarily.
- * This is the largest space-efficient variant, as strings beyond 255 characters may require
- * 16-bit accumulators, which would be a significant bottleneck.
- */
-SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto_avx512( //
- sz_cptr_t shorter, sz_size_t shorter_length, //
- sz_cptr_t longer, sz_size_t longer_length, //
- sz_size_t bound) {
- sz_unused(shorter && shorter_length && longer && longer_length && bound);
- return 0;
-}
-
-/**
- * @brief Computes the edit distance between two longer byte-strings using the AVX-512VBMI extensions,
- * assuming the upper distance bound cannot exceed 255, but the string length can be arbitrary.
- *
- * Applies to string lengths up to 255, and evaluates at most (255 * 2 + 1 = 511) diagonals.
- * Supports an early exit, if the distance is bounded.
- * Uses a lot more CPU register space than the `upto63` variant.
- *
- * Each of the 2x strings ends up occupying 4 ZMM registers, and each of 3x diagonals uses 4 ZMM registers.
- * So 20x of the 32x are persistently occupied, and the rest are used for math temporarily.
- * This is the largest space-efficient variant, as strings beyond 255 characters may require
- * 16-bit accumulators, which would be a significant bottleneck.
- */
-SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto255bound_avx512( //
- sz_cptr_t shorter, sz_size_t shorter_length, //
- sz_cptr_t longer, sz_size_t longer_length, //
- sz_size_t bound) {
- sz_unused(shorter && shorter_length && longer && longer_length && bound);
- return 0;
-}
-
-/**
- * @brief Computes the edit distance between two mid-length UTF-8-strings using the AVX-512VBMI extensions.
- *
- * Applies to string lengths up to 127, and evaluates at most (127 * 2 + 1 = 255) diagonals.
- * Supports an early exit, if the distance is bounded.
- * Benefits from the @b `valignd` instructions used to rotate UTF-32 unpacked Unicode codepoints.
- *
- * Each string is unpacked into 128 characters * 4 bytes per character / 64 bytes per register = 8 registers.
- *
- */
-SZ_INTERNAL sz_size_t _sz_edit_distance_utf8_skewed_diagonals_upto127_avx512( //
- sz_cptr_t shorter, sz_size_t shorter_length, //
- sz_cptr_t longer, sz_size_t longer_length, //
- sz_size_t bound) {
- sz_unused(shorter && shorter_length && longer && longer_length && bound);
- return 0;
-}
-
-SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto65k_avx512( //
- sz_cptr_t shorter, sz_size_t shorter_length, //
- sz_cptr_t longer, sz_size_t longer_length, //
- sz_size_t bound, sz_memory_allocator_t *alloc) {
-
- sz_unused(shorter && longer && bound && alloc);
-
- // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome.
- sz_memory_allocator_t global_alloc;
- if (!alloc) {
- sz_memory_allocator_init_default(&global_alloc);
- alloc = &global_alloc;
- }
-
- // TODO: Generalize!
- sz_size_t const max_length = 256u * 256u;
- sz_assert(shorter_length <= longer_length && "The 'shorter' string is longer than the 'longer' one.");
- sz_assert(shorter_length < max_length && "The length must fit into 16-bit integer. Otherwise use serial variant.");
- sz_unused(longer_length && bound && max_length);
-
-#if 0
- // We are going to store 3 diagonals of the matrix.
- // The length of the longest (main) diagonal would be `shorter_dim = (shorter_length + 1)`.
- sz_size_t const shorter_dim = shorter_length + 1;
- sz_size_t const longer_dim = longer_length + 1;
- // Unlike the serial version, we also want to avoid reverse-order iteration over the shorter string.
- // So let's allocate a bit more memory and reverse-export our shorter string into that buffer.
- sz_size_t const buffer_length = sizeof(sz_u16_t) * longer_dim * 3 + shorter_length; - sz_u16_t *const distances = (sz_u16_t *)alloc->allocate(buffer_length, alloc->handle); - if (!distances) return SZ_SIZE_MAX; - - // The next few pointers will be swapped around. - sz_u16_t *previous_distances = distances; - sz_u16_t *current_distances = previous_distances + longer_dim; - sz_u16_t *next_distances = current_distances + longer_dim; - sz_ptr_t const shorter_reversed = (sz_ptr_t)(next_distances + longer_dim); - - // Export the reversed string into the buffer. - for (sz_size_t i = 0; i != shorter_length; ++i) shorter_reversed[i] = shorter[shorter_length - 1 - i]; - - // Initialize the first two diagonals: - previous_distances[0] = 0; - current_distances[0] = current_distances[1] = 1; - - // Using ZMM registers, we can process 32x 16-bit values at once, - // storing 16 bytes of each string in YMM registers. - sz_u512_vec_t insertions_vec, deletions_vec, substitutions_vec, next_vec; - sz_u512_vec_t ones_u16_vec; - ones_u16_vec.zmm = _mm512_set1_epi16(1); - - // This is a mixed-precision implementation, using 8-bit representations for part of the operations. - // Even there, in case `SZ_USE_X86_AVX2=0`, let's use the `sz_u512_vec_t` type, addressing the first YMM halfs. - sz_u512_vec_t shorter_vec, longer_vec; - sz_u512_vec_t ones_u8_vec; - ones_u8_vec.ymms[0] = _mm256_set1_epi8(1); - - // Let's say we are dealing with 3 and 5 letter words. - // The matrix will have size 4 x 6, parameterized as (shorter_dim x longer_dim). - // It will have: - // - 4 diagonals of increasing length, at positions: 0, 1, 2, 3. - // - 2 diagonals of fixed length, at positions: 4, 5. - // - 3 diagonals of decreasing length, at positions: 6, 7, 8. - sz_size_t const diagonals_count = shorter_dim + longer_dim - 1; - - // Progress through the upper triangle of the Levenshtein matrix. - sz_size_t next_diagonal_index = 2; - for (; next_diagonal_index != shorter_dim; ++next_diagonal_index) { - sz_size_t const next_diagonal_length = next_diagonal_index + 1; - for (sz_size_t offset_within_diagonal = 0; offset_within_diagonal + 2 < next_diagonal_length;) { - sz_u32_t remaining_length = (sz_u32_t)(next_diagonal_length - offset_within_diagonal - 2); - sz_u32_t register_length = remaining_length < 32 ? remaining_length : 32; - sz_u32_t remaining_length_mask = _bzhi_u32(0xFFFFFFFFu, register_length); - longer_vec.ymms[0] = _mm256_maskz_loadu_epi8(remaining_length_mask, longer + offset_within_diagonal); - // Our original code addressed the shorter string `[next_diagonal_index - offset_within_diagonal - 2]` - // for growing `offset_within_diagonal`. If the `shorter` string was reversed, the - // `[next_diagonal_index - offset_within_diagonal - 2]` would be equal to `[shorter_length - 1 - - // next_diagonal_index + offset_within_diagonal + 2]`. Which simplified would be equal to - // `[shorter_length - next_diagonal_index + offset_within_diagonal + 1]`. - shorter_vec.ymms[0] = _mm256_maskz_loadu_epi8( // - remaining_length_mask, - shorter_reversed + shorter_length - next_diagonal_index + offset_within_diagonal + 1); - // For substitutions, perform the equality comparison using AVX2 instead of AVX-512 - // to get the result as a vector, instead of a bitmask. Adding 1 to every scalar we can overflow - // transforming from {0xFF, 0} values to {0, 1} values - exactly what we need. Then - upcast to 16-bit. 
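- // As a tiny worked example of that trick: `_mm256_cmpeq_epi8` yields 0xFF where the bytes
- // match and 0x00 where they differ, so adding 1 wraps 0xFF to 0x00 and turns 0x00 into 0x01,
- // which is exactly the per-character substitution cost (0 for a match, 1 for a mismatch).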
- substitutions_vec.zmm = _mm512_cvtepi8_epi16( // - _mm256_add_epi8(_mm256_cmpeq_epi8(longer_vec.ymms[0], shorter_vec.ymms[0]), ones_u8_vec.ymms[0])); - substitutions_vec.zmm = _mm512_add_epi16( // - substitutions_vec.zmm, - _mm512_maskz_loadu_epi16(remaining_length_mask, previous_distances + offset_within_diagonal)); - // For insertions and deletions, on modern hardware, it's faster to issue two separate loads, - // than rotate the bytes in the ZMM register. - insertions_vec.zmm = - _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + offset_within_diagonal); - deletions_vec.zmm = - _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + offset_within_diagonal + 1); - // First get the minimum of insertions and deletions. - next_vec.zmm = _mm512_add_epi16(_mm512_min_epu16(insertions_vec.zmm, deletions_vec.zmm), ones_u16_vec.zmm); - next_vec.zmm = _mm512_min_epu16(next_vec.zmm, substitutions_vec.zmm); - _mm512_mask_storeu_epi16(next_distances + offset_within_diagonal + 1, remaining_length_mask, next_vec.zmm); - offset_within_diagonal += register_length; - } - // Don't forget to populate the first row and the first column of the Levenshtein matrix. - next_distances[0] = next_distances[next_diagonal_length - 1] = (sz_u16_t)next_diagonal_index; - // Perform a circular rotation (three-way swap) of those buffers, to reuse the memory. - sz_u16_t *temporary = previous_distances; - previous_distances = current_distances; - current_distances = next_distances; - next_distances = temporary; - } - - // By now we've scanned through the upper triangle of the matrix, where each subsequent iteration results in a - // larger diagonal. From now onwards, we will be shrinking. Instead of adding value equal to the skewed diagonal - // index on either side, we will be cropping those values out. - for (; next_diagonal_index != diagonals_count; ++next_diagonal_index) { - sz_size_t const next_diagonal_length = diagonals_count - next_diagonal_index; - for (sz_size_t i = 0; i != next_diagonal_length;) { - sz_u32_t remaining_length = (sz_u32_t)(next_diagonal_length - i); - sz_u32_t register_length = remaining_length < 32 ? remaining_length : 32; - sz_u32_t remaining_length_mask = _bzhi_u32(0xFFFFFFFFu, register_length); - longer_vec.ymms[0] = _mm256_maskz_loadu_epi8(remaining_length_mask, longer + next_diagonal_index - n + i); - // Our original code addressed the shorter string `[shorter_length - 1 - i]` for growing `i`. - // If the `shorter` string was reversed, the `[shorter_length - 1 - i]` would - // be equal to `[shorter_length - 1 - shorter_length + 1 + i]`. - // Which simplified would be equal to just `[i]`. Beautiful! - shorter_vec.ymms[0] = _mm256_maskz_loadu_epi8(remaining_length_mask, shorter_reversed + i); - // For substitutions, perform the equality comparison using AVX2 instead of AVX-512 - // to get the result as a vector, instead of a bitmask. The compare it against the accumulated - // substitution costs. - substitutions_vec.zmm = _mm512_cvtepi8_epi16( // - _mm256_add_epi8(_mm256_cmpeq_epi8(longer_vec.ymms[0], shorter_vec.ymms[0]), ones_u8_vec.ymms[0])); - substitutions_vec.zmm = _mm512_add_epi16( // - substitutions_vec.zmm, _mm512_maskz_loadu_epi16(remaining_length_mask, previous_distances + i)); - // For insertions and deletions, on modern hardware, it's faster to issue two separate loads, - // than rotate the bytes in the ZMM register. 
- insertions_vec.zmm = _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + i); - deletions_vec.zmm = _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + i + 1); - // First get the minimum of insertions and deletions. - next_vec.zmm = _mm512_add_epi16(_mm512_min_epu16(insertions_vec.zmm, deletions_vec.zmm), ones_u16_vec.zmm); - next_vec.zmm = _mm512_min_epu16(next_vec.zmm, substitutions_vec.zmm); - _mm512_mask_storeu_epi16(next_distances + i, remaining_length_mask, next_vec.zmm); - i += register_length; - } - - // Perform a circular rotation (three-way swap) of those buffers, to reuse the memory, this time, with a shift, - // dropping the first element in the current array. - sz_u16_t *temporary = previous_distances; - previous_distances = current_distances + 1; - current_distances = next_distances; - next_distances = temporary; - } - - // Cache scalar before `free` call. - sz_size_t result = current_distances[0]; - alloc->free(distances, buffer_length, alloc->handle); - return result; -#endif - return 0; -} - -SZ_INTERNAL sz_size_t sz_edit_distance_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - - // Bounded computations may exit early. - int const is_bounded = bound < longer_length; - if (is_bounded) { - // If one of the strings is empty - the edit distance is equal to the length of the other one. - if (longer_length == 0) return sz_min_of_two(shorter_length, bound); - if (shorter_length == 0) return sz_min_of_two(longer_length, bound); - // If the difference in length is beyond the `bound`, there is no need to check at all. - if (longer_length - shorter_length > bound) return bound; - } - - // Make sure the shorter string is actually shorter. - if (shorter_length > longer_length) { - sz_cptr_t temporary = shorter; - shorter = longer; - longer = temporary; - sz_size_t temporary_length = shorter_length; - shorter_length = longer_length; - longer_length = temporary_length; - } - - // Dispatch the right implementation based on the length of the strings. - if (longer_length < 64u) - return _sz_edit_distance_skewed_diagonals_upto63_avx512( // - shorter, shorter_length, longer, longer_length, bound); - // else if (longer_length < 256u * 256u) - // return _sz_edit_distance_skewed_diagonals_upto65k_avx512( // - // shorter, shorter_length, longer, longer_length, bound, alloc); - else - return sz_edit_distance_serial(shorter, shorter_length, longer, longer_length, bound, alloc); -} - -SZ_PUBLIC sz_u64_t sz_checksum_avx512(sz_cptr_t text, sz_size_t length) { - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "loads". - // - // A typical AWS Sapphire Rapids instance can have 48 KB x 2 blocks of L1 data cache per core, - // 2 MB x 2 blocks of L2 cache per core, and one shared 60 MB buffer of L3 cache. - // With two strings, we may consider the overal workload huge, if each exceeds 1 MB in length. - int const is_huge = length >= 1ull * 1024ull * 1024ull; - sz_u512_vec_t text_vec, sums_vec; - - // When the buffer is small, there isn't much to innovate. 
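- // As a scalar reference (an illustrative sketch, not shipped code), the whole function
- // just sums the bytes:
- //
- // sz_u64_t sum = 0;
- // for (sz_size_t i = 0; i != length; ++i) sum += (sz_u8_t)text[i];
- // return sum;
- //
- // Every SIMD branch below computes the same value, using `SAD` (Sum of Absolute Differences)
- // against a zero vector to horizontally add each group of 8 bytes into a 64-bit lane.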
- if (length <= 16) {
- __mmask16 mask = _sz_u16_mask_until(length);
- text_vec.xmms[0] = _mm_maskz_loadu_epi8(mask, text);
- sums_vec.xmms[0] = _mm_sad_epu8(text_vec.xmms[0], _mm_setzero_si128());
- sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_vec.xmms[0]);
- sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_vec.xmms[0], 1);
- return low + high;
- }
- else if (length <= 32) {
- __mmask32 mask = _sz_u32_mask_until(length);
- text_vec.ymms[0] = _mm256_maskz_loadu_epi8(mask, text);
- sums_vec.ymms[0] = _mm256_sad_epu8(text_vec.ymms[0], _mm256_setzero_si256());
- // Accumulating 256 bits is harder, as we need to extract the 128-bit sums first.
- __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymms[0]);
- __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymms[0], 1);
- __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm);
- sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm);
- sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1);
- return low + high;
- }
- else if (length <= 64) {
- __mmask64 mask = _sz_u64_mask_until(length);
- text_vec.zmm = _mm512_maskz_loadu_epi8(mask, text);
- sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512());
- return _mm512_reduce_add_epi64(sums_vec.zmm);
- }
- else if (!is_huge) {
- sz_size_t head_length = (64 - ((sz_size_t)text % 64)) % 64; // 63 or less.
- sz_size_t tail_length = (sz_size_t)(text + length) % 64; // 63 or less.
- sz_size_t body_length = length - head_length - tail_length; // Multiple of 64.
- __mmask64 head_mask = _sz_u64_mask_until(head_length);
- __mmask64 tail_mask = _sz_u64_mask_until(tail_length);
- text_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, text);
- sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512());
- for (text += head_length; body_length >= 64; text += 64, body_length -= 64) {
- text_vec.zmm = _mm512_load_si512((__m512i const *)text);
- sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()));
- }
- text_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, text);
- sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()));
- return _mm512_reduce_add_epi64(sums_vec.zmm);
- }
- // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use.
- //
- // 1. Moving in both directions to maximize the throughput, when fetching from multiple
- // memory pages. Also helps with cache set-associativity issues, as we won't always
- // be fetching the same entries in the lookup table.
- // 2. Using non-temporal stores to avoid polluting the cache.
- // 3. Prefetching the next cache line, to avoid stalling the CPU. This is generally useless
- // for predictable patterns, so disregard this advice.
- //
- // Bidirectional traversal generally adds about 10% to such algorithms.
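- // Sketching the traversal below: every iteration consumes one 64-byte block from the front
- // at `text` and one from the back at `text + body_length - 64`, so the front pointer advances
- // by 64 while `body_length` shrinks by 128, until at most one aligned block remains in the middle.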
- else {
- sz_u512_vec_t text_reversed_vec, sums_reversed_vec;
- sz_size_t head_length = (64 - ((sz_size_t)text % 64)) % 64;
- sz_size_t tail_length = (sz_size_t)(text + length) % 64;
- sz_size_t body_length = length - head_length - tail_length;
- __mmask64 head_mask = _sz_u64_mask_until(head_length);
- __mmask64 tail_mask = _sz_u64_mask_until(tail_length);
-
- text_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, text);
- sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512());
- text_reversed_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, text + head_length + body_length);
- sums_reversed_vec.zmm = _mm512_sad_epu8(text_reversed_vec.zmm, _mm512_setzero_si512());
-
- // Now in the main loop, we can use non-temporal loads,
- // performing the operation in both directions.
- for (text += head_length; body_length >= 128; text += 64, body_length -= 128) {
- text_vec.zmm = _mm512_stream_load_si512((__m512i *)(text));
- sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()));
- text_reversed_vec.zmm = _mm512_stream_load_si512((__m512i *)(text + body_length - 64));
- sums_reversed_vec.zmm =
- _mm512_add_epi64(sums_reversed_vec.zmm, _mm512_sad_epu8(text_reversed_vec.zmm, _mm512_setzero_si512()));
- }
- if (body_length >= 64) {
- text_vec.zmm = _mm512_stream_load_si512((__m512i *)(text));
- sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()));
- }
-
- return _mm512_reduce_add_epi64(_mm512_add_epi64(sums_vec.zmm, sums_reversed_vec.zmm));
- }
-}
-
-SZ_PUBLIC void sz_hashes_avx512(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, //
- sz_hash_callback_t callback, void *callback_handle) {
-
- if (length < window_length || !window_length) return;
- if (length < 4 * window_length) {
- sz_hashes_serial(start, length, window_length, step, callback, callback_handle);
- return;
- }
-
- // Using AVX2, we can perform 4 long integer multiplications and additions within one register.
- // So let's slice the entire string into 4 overlapping windows, to slide over them in parallel.
- sz_size_t const max_hashes = length - window_length + 1;
- sz_size_t const min_hashes_per_thread = max_hashes / 4; // At most one sequence can overlap between 2 threads.
- sz_u8_t const *text_first = (sz_u8_t const *)start;
- sz_u8_t const *text_second = text_first + min_hashes_per_thread;
- sz_u8_t const *text_third = text_first + min_hashes_per_thread * 2;
- sz_u8_t const *text_fourth = text_first + min_hashes_per_thread * 3;
- sz_u8_t const *text_end = text_first + length;
-
- // Broadcast the global constants into the registers.
- // Both high and low hashes will work with the same prime and golden ratio.
- sz_u512_vec_t prime_vec, golden_ratio_vec;
- prime_vec.zmm = _mm512_set1_epi64(SZ_U64_MAX_PRIME);
- golden_ratio_vec.zmm = _mm512_set1_epi64(11400714819323198485ull);
-
- // Prepare the `base ^ (window_length - 1)` values, that we are going to use for modulo arithmetic.
- sz_u64_t prime_power_low = 1, prime_power_high = 1;
- for (sz_size_t i = 0; i + 1 < window_length; ++i)
- prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME,
- prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME;
-
- // We will be evaluating 4 offsets at a time with 2 different hash functions.
- // We can fit all those 8 state variables in each of the following ZMM registers.
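- // For reference, each lane effectively maintains a Rabin-Karp-style rolling hash. Ignoring the
- // per-lane character shift, the scalar recurrence for a window of length `w` over base `b`
- // (31 for the low lanes, 257 for the high ones) modulo `SZ_U64_MAX_PRIME` is roughly:
- //
- // hash = (hash - outgoing_char * pow(b, w - 1)) * b + incoming_char; // then reduce modulo the prime
- //
- // The vectorized loop below implements the same update, replacing the true modulo with a
- // conditional subtraction of the prime, which the surrounding comments argue is sufficient here.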
- sz_u512_vec_t base_vec, prime_power_vec, shift_vec; - base_vec.zmm = _mm512_set_epi64(31ull, 31ull, 31ull, 31ull, 257ull, 257ull, 257ull, 257ull); - shift_vec.zmm = _mm512_set_epi64(0ull, 0ull, 0ull, 0ull, 77ull, 77ull, 77ull, 77ull); - prime_power_vec.zmm = _mm512_set_epi64(prime_power_low, prime_power_low, prime_power_low, prime_power_low, - prime_power_high, prime_power_high, prime_power_high, prime_power_high); - - // Compute the initial hash values for every one of the four windows. - sz_u512_vec_t hash_vec, chars_vec; - hash_vec.zmm = _mm512_setzero_si512(); - for (sz_u8_t const *prefix_end = text_first + window_length; text_first < prefix_end; - ++text_first, ++text_second, ++text_third, ++text_fourth) { - - // 1. Multiply the hashes by the base. - hash_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, base_vec.zmm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`... - chars_vec.zmm = _mm512_set_epi64(text_fourth[0], text_third[0], text_second[0], text_first[0], // - text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm); - - // 3. Add the incoming characters. - hash_vec.zmm = _mm512_add_epi64(hash_vec.zmm, chars_vec.zmm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_vec.zmm = _mm512_mask_blend_epi8(_mm512_cmpgt_epi64_mask(hash_vec.zmm, prime_vec.zmm), hash_vec.zmm, - _mm512_sub_epi64(hash_vec.zmm, prime_vec.zmm)); - } - - // 5. Compute the hash mix, that will be used to index into the fingerprint. - // This includes a serial step at the end. - sz_u512_vec_t hash_mix_vec; - hash_mix_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, golden_ratio_vec.zmm); - hash_mix_vec.ymms[0] = _mm256_xor_si256(_mm512_extracti64x4_epi64(hash_mix_vec.zmm, 1), // - _mm512_extracti64x4_epi64(hash_mix_vec.zmm, 0)); - - callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); - callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); - callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); - callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); - - // Now repeat that operation for the remaining characters, discarding older characters. - sz_size_t cycle = 1; - sz_size_t step_mask = step - 1; - for (; text_fourth != text_end; ++text_first, ++text_second, ++text_third, ++text_fourth, ++cycle) { - // 0. Load again the four characters we are dropping, shift them, and subtract. - chars_vec.zmm = _mm512_set_epi64(text_fourth[-window_length], text_third[-window_length], - text_second[-window_length], text_first[-window_length], // - text_fourth[-window_length], text_third[-window_length], - text_second[-window_length], text_first[-window_length]); - chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm); - hash_vec.zmm = _mm512_sub_epi64(hash_vec.zmm, _mm512_mullo_epi64(chars_vec.zmm, prime_power_vec.zmm)); - - // 1. Multiply the hashes by the base. - hash_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, base_vec.zmm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. 
- chars_vec.zmm = _mm512_set_epi64(text_fourth[0], text_third[0], text_second[0], text_first[0], // - text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm); - - // ... and prefetch the next four characters into Level 2 or higher. - _mm_prefetch((sz_cptr_t)text_fourth + 1, _MM_HINT_T1); - _mm_prefetch((sz_cptr_t)text_third + 1, _MM_HINT_T1); - _mm_prefetch((sz_cptr_t)text_second + 1, _MM_HINT_T1); - _mm_prefetch((sz_cptr_t)text_first + 1, _MM_HINT_T1); - - // 3. Add the incoming characters. - hash_vec.zmm = _mm512_add_epi64(hash_vec.zmm, chars_vec.zmm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_vec.zmm = _mm512_mask_blend_epi8(_mm512_cmpgt_epi64_mask(hash_vec.zmm, prime_vec.zmm), hash_vec.zmm, - _mm512_sub_epi64(hash_vec.zmm, prime_vec.zmm)); - - // 5. Compute the hash mix, that will be used to index into the fingerprint. - // This includes a serial step at the end. - hash_mix_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, golden_ratio_vec.zmm); - hash_mix_vec.ymms[0] = _mm256_xor_si256(_mm512_extracti64x4_epi64(hash_mix_vec.zmm, 1), // - _mm512_castsi512_si256(hash_mix_vec.zmm)); - - if ((cycle & step_mask) == 0) { - callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); - callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); - callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); - callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); - } - } -} - -#pragma clang attribute pop -#pragma GCC pop_options - -#pragma GCC push_options -#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512vbmi", "avx512vbmi2", "bmi", "bmi2") -#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512vbmi,avx512vbmi2,bmi,bmi2"))), \ - apply_to = function) - -SZ_PUBLIC void sz_look_up_transform_avx512(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { - - // If the input is tiny (especially smaller than the look-up table itself), we may end up paying - // more for organizing the SIMD registers and changing the CPU state, than for the actual computation. - // But if at least 3 cache lines are touched, the AVX-512 implementation should be faster. - if (length <= 128) { - sz_look_up_transform_serial(source, length, lut, target); - return; - } - - // When the buffer is over 64 bytes, it's guaranteed to touch at least two cache lines - the head and tail, - // and may include more cache-lines in-between. Knowing this, we can avoid expensive unaligned stores - // by computing 2 masks - for the head and tail, using masked stores for the head and tail, and unmasked - // for the body. - sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - - // We need to pull the lookup table into 4x ZMM registers. - // We can use `vpermi2b` instruction to perform the look in two ZMM registers with `_mm512_permutex2var_epi8` - // intrinsics, but it has a 6-cycle latency on Sapphire Rapids and requires AVX512-VBMI. 
Assuming we need to
- // operate on 4 registers, it might be cleaner to use 4x separate `_mm512_permutexvar_epi8` calls,
- // combining the results with 2x `_mm512_test_epi8_mask` and 3x blends afterwards.
- //
- // - 4x `_mm512_permutexvar_epi8` maps to "VPERMB (ZMM, ZMM, ZMM)":
- // - On Ice Lake: 3 cycles latency, ports: 1*p5
- // - On Genoa: 6 cycles latency, ports: 1*FP12
- // - 3x `_mm512_mask_blend_epi8` maps to "VPBLENDMB_Z (ZMM, K, ZMM, ZMM)":
- // - On Ice Lake: 3 cycles latency, ports: 1*p05
- // - On Genoa: 1 cycle latency, ports: 1*FP0123
- // - 2x `_mm512_test_epi8_mask` maps to "VPTESTMB (K, ZMM, ZMM)":
- // - On Ice Lake: 3 cycles latency, ports: 1*p5
- // - On Genoa: 4 cycles latency, ports: 1*FP01
- //
- sz_u512_vec_t lut_0_to_63_vec, lut_64_to_127_vec, lut_128_to_191_vec, lut_192_to_255_vec;
- lut_0_to_63_vec.zmm = _mm512_loadu_si512((lut));
- lut_64_to_127_vec.zmm = _mm512_loadu_si512((lut + 64));
- lut_128_to_191_vec.zmm = _mm512_loadu_si512((lut + 128));
- lut_192_to_255_vec.zmm = _mm512_loadu_si512((lut + 192));
-
- sz_u512_vec_t first_bit_vec, second_bit_vec;
- first_bit_vec.zmm = _mm512_set1_epi8((char)0x80);
- second_bit_vec.zmm = _mm512_set1_epi8((char)0x40);
-
- __mmask64 first_bit_mask, second_bit_mask;
- sz_u512_vec_t source_vec;
- // If the top bit is set in a byte of `source_vec`, then we use `lookup_128_to_191_vec` or
- // `lookup_192_to_255_vec`. If the second bit is set, we use `lookup_64_to_127_vec` or `lookup_192_to_255_vec`.
- sz_u512_vec_t lookup_0_to_63_vec, lookup_64_to_127_vec, lookup_128_to_191_vec, lookup_192_to_255_vec;
- sz_u512_vec_t blended_0_to_127_vec, blended_128_to_255_vec, blended_0_to_255_vec;
-
- // Handling the head.
- if (head_length) {
- source_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, source);
- lookup_0_to_63_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_0_to_63_vec.zmm);
- lookup_64_to_127_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_64_to_127_vec.zmm);
- lookup_128_to_191_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_128_to_191_vec.zmm);
- lookup_192_to_255_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_192_to_255_vec.zmm);
- first_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, first_bit_vec.zmm);
- second_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, second_bit_vec.zmm);
- blended_0_to_127_vec.zmm =
- _mm512_mask_blend_epi8(second_bit_mask, lookup_0_to_63_vec.zmm, lookup_64_to_127_vec.zmm);
- blended_128_to_255_vec.zmm =
- _mm512_mask_blend_epi8(second_bit_mask, lookup_128_to_191_vec.zmm, lookup_192_to_255_vec.zmm);
- blended_0_to_255_vec.zmm =
- _mm512_mask_blend_epi8(first_bit_mask, blended_0_to_127_vec.zmm, blended_128_to_255_vec.zmm);
- _mm512_mask_storeu_epi8(target, head_mask, blended_0_to_255_vec.zmm);
- source += head_length, target += head_length, length -= head_length;
- }
-
- // Handling the body in 64-byte chunks aligned to cache-line boundaries with respect to `target`.
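- // As a scalar reference for the loop below (an illustrative sketch, not shipped code),
- // the whole transform is just a 256-entry table lookup per byte:
- //
- // for (sz_size_t i = 0; i != length; ++i) target[i] = lut[(sz_u8_t)source[i]];
- //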
- while (length >= 64) { - source_vec.zmm = _mm512_loadu_si512(source); - lookup_0_to_63_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_0_to_63_vec.zmm); - lookup_64_to_127_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_64_to_127_vec.zmm); - lookup_128_to_191_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_128_to_191_vec.zmm); - lookup_192_to_255_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_192_to_255_vec.zmm); - first_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, first_bit_vec.zmm); - second_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, second_bit_vec.zmm); - blended_0_to_127_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_0_to_63_vec.zmm, lookup_64_to_127_vec.zmm); - blended_128_to_255_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_128_to_191_vec.zmm, lookup_192_to_255_vec.zmm); - blended_0_to_255_vec.zmm = - _mm512_mask_blend_epi8(first_bit_mask, blended_0_to_127_vec.zmm, blended_128_to_255_vec.zmm); - _mm512_store_si512(target, blended_0_to_255_vec.zmm); //! Aligned store, our main weapon! - source += 64, target += 64, length -= 64; - } - - // Handling the tail. - if (tail_length) { - source_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, source); - lookup_0_to_63_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_0_to_63_vec.zmm); - lookup_64_to_127_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_64_to_127_vec.zmm); - lookup_128_to_191_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_128_to_191_vec.zmm); - lookup_192_to_255_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_192_to_255_vec.zmm); - first_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, first_bit_vec.zmm); - second_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, second_bit_vec.zmm); - blended_0_to_127_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_0_to_63_vec.zmm, lookup_64_to_127_vec.zmm); - blended_128_to_255_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_128_to_191_vec.zmm, lookup_192_to_255_vec.zmm); - blended_0_to_255_vec.zmm = - _mm512_mask_blend_epi8(first_bit_mask, blended_0_to_127_vec.zmm, blended_128_to_255_vec.zmm); - _mm512_mask_storeu_epi8(target, tail_mask, blended_0_to_255_vec.zmm); - source += tail_length, target += tail_length, length -= tail_length; - } -} - -SZ_PUBLIC sz_cptr_t sz_find_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { - - // Before initializing the AVX-512 vectors, we may want to run the sequential code for the first few bytes. - // In practice, that only hurts, even when we have matches every 5-ish bytes. - // - // if (length < SZ_SWAR_THRESHOLD) return sz_find_charset_serial(text, length, filter); - // sz_cptr_t early_result = sz_find_charset_serial(text, SZ_SWAR_THRESHOLD, filter); - // if (early_result) return early_result; - // text += SZ_SWAR_THRESHOLD; - // length -= SZ_SWAR_THRESHOLD; - // - // Let's unzip even and odd elements and replicate them into both lanes of the YMM register. - // That way when we invoke `_mm512_shuffle_epi8` we can use the same mask for both lanes. - sz_u512_vec_t filter_even_vec, filter_odd_vec; - __m256i filter_ymm = _mm256_lddqu_si256((__m256i const *)filter); - // There are a few way to initialize filters without having native strided loads. 
- // In the chronological order of experiments:
- // - serial code initializing 128 bytes of odd and even mask
- // - using several shuffles
- // - using `_mm512_permutexvar_epi8`
- // - using `_mm512_broadcast_i32x4(_mm256_castsi256_si128(_mm256_maskz_compress_epi8(0x55555555, filter_ymm)))`
- // and `_mm512_broadcast_i32x4(_mm256_castsi256_si128(_mm256_maskz_compress_epi8(0xaaaaaaaa, filter_ymm)))`
- filter_even_vec.zmm = _mm512_broadcast_i32x4(_mm256_castsi256_si128( // broadcast __m128i to __m512i
- _mm256_maskz_compress_epi8(0x55555555, filter_ymm)));
- filter_odd_vec.zmm = _mm512_broadcast_i32x4(_mm256_castsi256_si128( // broadcast __m128i to __m512i
- _mm256_maskz_compress_epi8(0xaaaaaaaa, filter_ymm)));
- // After the unzipping operation, we can validate the contents of the vectors like this:
- //
- // for (sz_size_t i = 0; i != 16; ++i) {
- // sz_assert(filter_even_vec.u8s[i] == filter->_u8s[i * 2]);
- // sz_assert(filter_odd_vec.u8s[i] == filter->_u8s[i * 2 + 1]);
- // sz_assert(filter_even_vec.u8s[i + 16] == filter->_u8s[i * 2]);
- // sz_assert(filter_odd_vec.u8s[i + 16] == filter->_u8s[i * 2 + 1]);
- // sz_assert(filter_even_vec.u8s[i + 32] == filter->_u8s[i * 2]);
- // sz_assert(filter_odd_vec.u8s[i + 32] == filter->_u8s[i * 2 + 1]);
- // sz_assert(filter_even_vec.u8s[i + 48] == filter->_u8s[i * 2]);
- // sz_assert(filter_odd_vec.u8s[i + 48] == filter->_u8s[i * 2 + 1]);
- // }
- //
- sz_u512_vec_t text_vec;
- sz_u512_vec_t lower_nibbles_vec, higher_nibbles_vec;
- sz_u512_vec_t bitset_even_vec, bitset_odd_vec;
- sz_u512_vec_t bitmask_vec, bitmask_lookup_vec;
- bitmask_lookup_vec.zmm = _mm512_set_epi8( //
- -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, //
- -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, //
- -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, //
- -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1);
-
- while (length) {
- // The following algorithm is a transposed equivalent of the "SIMDized check which bytes are in a set"
- // solutions by Wojciech Muła. We populate the bitmask differently and target newer CPUs, so
- // StringZilla uses a somewhat different approach.
- // http://0x80.pl/articles/simd-byte-lookup.html#alternative-implementation-new
- //
- // sz_u8_t input = *(sz_u8_t const *)text;
- // sz_u8_t lo_nibble = input & 0x0f;
- // sz_u8_t hi_nibble = input >> 4;
- // sz_u8_t bitset_even = filter_even_vec.u8s[hi_nibble];
- // sz_u8_t bitset_odd = filter_odd_vec.u8s[hi_nibble];
- // sz_u8_t bitmask = (1 << (lo_nibble & 0x7));
- // sz_u8_t bitset = lo_nibble < 8 ? bitset_even : bitset_odd;
- // if ((bitset & bitmask) != 0) return text;
- // else { length--, text++; }
- //
- // The nice part about this is that loading the strided data is very easy with Arm NEON,
- // while with x86 CPUs after AVX, shuffles within 256 bits shouldn't be an issue either.
- sz_size_t load_length = sz_min_of_two(length, 64); - __mmask64 load_mask = _sz_u64_mask_until(load_length); - text_vec.zmm = _mm512_maskz_loadu_epi8(load_mask, text); - lower_nibbles_vec.zmm = _mm512_and_si512(text_vec.zmm, _mm512_set1_epi8(0x0f)); - bitmask_vec.zmm = _mm512_shuffle_epi8(bitmask_lookup_vec.zmm, lower_nibbles_vec.zmm); - // - // At this point we can validate the `bitmask_vec` contents like this: - // - // for (sz_size_t i = 0; i != load_length; ++i) { - // sz_u8_t input = *(sz_u8_t const *)(text + i); - // sz_u8_t lo_nibble = input & 0x0f; - // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); - // sz_assert(bitmask_vec.u8s[i] == bitmask); - // } - // - // Shift right every byte by 4 bits. - // There is no `_mm512_srli_epi8` intrinsic, so we have to use `_mm512_srli_epi16` - // and combine it with a mask to clear the higher bits. - higher_nibbles_vec.zmm = _mm512_and_si512(_mm512_srli_epi16(text_vec.zmm, 4), _mm512_set1_epi8(0x0f)); - bitset_even_vec.zmm = _mm512_shuffle_epi8(filter_even_vec.zmm, higher_nibbles_vec.zmm); - bitset_odd_vec.zmm = _mm512_shuffle_epi8(filter_odd_vec.zmm, higher_nibbles_vec.zmm); - // - // At this point we can validate the `bitset_even_vec` and `bitset_odd_vec` contents like this: - // - // for (sz_size_t i = 0; i != load_length; ++i) { - // sz_u8_t input = *(sz_u8_t const *)(text + i); - // sz_u8_t const *bitset_ptr = &filter->_u8s[0]; - // sz_u8_t hi_nibble = input >> 4; - // sz_u8_t bitset_even = bitset_ptr[hi_nibble * 2]; - // sz_u8_t bitset_odd = bitset_ptr[hi_nibble * 2 + 1]; - // sz_assert(bitset_even_vec.u8s[i] == bitset_even); - // sz_assert(bitset_odd_vec.u8s[i] == bitset_odd); - // } - // - // TODO: Is this a good place for ternary logic? - __mmask64 take_first = _mm512_cmplt_epi8_mask(lower_nibbles_vec.zmm, _mm512_set1_epi8(8)); - bitset_even_vec.zmm = _mm512_mask_blend_epi8(take_first, bitset_odd_vec.zmm, bitset_even_vec.zmm); - __mmask64 matches_mask = _mm512_mask_test_epi8_mask(load_mask, bitset_even_vec.zmm, bitmask_vec.zmm); - if (matches_mask) { - int offset = sz_u64_ctz(matches_mask); - return text + offset; - } - else { text += load_length, length -= load_length; } - } - - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { - return sz_rfind_charset_serial(text, length, filter); -} - -SZ_PUBLIC sz_cptr_t sz_find_many_avx512( // - sz_cptr_t haystack, sz_size_t haystack_length, // - sz_cptr_t const *needles, sz_size_t const *needles_lengths, // - sz_size_t *needle_offset) { - - // When dealing with huge needles vocabularies, like in tokenization workloads, we need to construct an automaton. - // But in many cases, the vocabulary is small enough to use a simpler DFA-less approach, combining the ideas from - // the `sz_find_avx512` and `sz_find_charset_avx512` functions. - // - // Pick the offsets within needles where there is the least variance in the characters. - // Like for "the", "then", "there", "these", "those", "their", "they", "them", "that", "this", "thus", "than": - // - // 0: 't' - // 1: 'h' - // 2: 'e', 'a', 'i', 'o', 'u' - // 3: 'n', 'r', 's', 'i', 'y', 'm', 't' - // - // So depending on our "register budget", we can use a different number of pivot points: offset 0, 1, 2 make - // the most sense if we can only use 3 ZMM registers. - sz_unused(haystack && haystack_length && needles && needles_lengths && needle_offset); - return 0; -} - -/** - * Computes the Needleman Wunsch alignment score between two strings. 
- * The method uses 32-bit integers to accumulate the running score for every cell in the matrix.
- * Assuming the costs of substitutions can be arbitrary signed 8-bit integers, the method is expected to be used
- * on strings not exceeding 2^24 length or 16.7 million characters.
- *
- * Unlike the `_sz_edit_distance_skewed_diagonals_upto65k_avx512` method, this one uses signed integers to store
- * the accumulated score. Moreover, its primary bottleneck is the latency of gathering the substitution costs
- * from the substitution matrix. If we use the diagonal order, we will be comparing a slice of the first string with
- * a slice of the second. If we stick to the conventional horizontal order, we will be comparing one character against
- * a slice, which is much easier to optimize. In that case we are sampling costs not from arbitrary parts of
- * a 256 x 256 matrix, but from a single row!
- */
-SZ_INTERNAL sz_ssize_t _sz_alignment_score_wagner_fisher_upto17m_avx512( //
- sz_cptr_t shorter, sz_size_t shorter_length, //
- sz_cptr_t longer, sz_size_t longer_length, //
- sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc) {
-
- // If one of the strings is empty - the score is the length of the other one, multiplied by the gap penalty
- if (longer_length == 0) return (sz_ssize_t)shorter_length * gap;
- if (shorter_length == 0) return (sz_ssize_t)longer_length * gap;
-
- // Let's make sure that we use the amount of memory proportional to the
- // number of elements in the shorter string, not the longer.
- if (shorter_length > longer_length) {
- sz_pointer_swap((void **)&longer_length, (void **)&shorter_length);
- sz_pointer_swap((void **)&longer, (void **)&shorter);
- }
-
- // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome.
- sz_memory_allocator_t global_alloc;
- if (!alloc) {
- sz_memory_allocator_init_default(&global_alloc);
- alloc = &global_alloc;
- }
-
- sz_size_t const max_length = 256ull * 256ull * 256ull;
- sz_size_t const n = longer_length + 1;
- sz_assert(n < max_length && "The length must fit into 24-bit integer. Otherwise use serial variant.");
- sz_unused(longer_length && max_length);
-
- sz_size_t buffer_length = sizeof(sz_i32_t) * n * 2;
- sz_i32_t *distances = (sz_i32_t *)alloc->allocate(buffer_length, alloc->handle);
- sz_i32_t *previous_distances = distances;
- sz_i32_t *current_distances = previous_distances + n;
-
- // Initialize the first row of the Levenshtein matrix with `iota`.
- for (sz_size_t idx_longer = 0; idx_longer != n; ++idx_longer)
- previous_distances[idx_longer] = (sz_i32_t)idx_longer * gap;
-
- /// Contains up to 64 consecutive characters from the longer string.
- sz_u512_vec_t longer_vec;
- sz_u512_vec_t cost_deletion_vec, cost_substitution_vec, lookup_substitution_vec, current_vec;
- sz_u512_vec_t row_first_subs_vec, row_second_subs_vec, row_third_subs_vec, row_fourth_subs_vec;
- sz_u512_vec_t shuffled_first_subs_vec, shuffled_second_subs_vec, shuffled_third_subs_vec, shuffled_fourth_subs_vec;
-
- // Prepare constants and masks.
- sz_u512_vec_t is_third_or_fourth_vec, is_second_or_fourth_vec, gap_vec;
- {
- char is_third_or_fourth_check, is_second_or_fourth_check;
- *(sz_u8_t *)&is_third_or_fourth_check = 0x80, *(sz_u8_t *)&is_second_or_fourth_check = 0x40;
- is_third_or_fourth_vec.zmm = _mm512_set1_epi8(is_third_or_fourth_check);
- is_second_or_fourth_vec.zmm = _mm512_set1_epi8(is_second_or_fourth_check);
- gap_vec.zmm = _mm512_set1_epi32(gap);
- }
-
- sz_u8_t const *shorter_unsigned = (sz_u8_t const *)shorter;
- for (sz_size_t idx_shorter = 0; idx_shorter != shorter_length; ++idx_shorter) {
- sz_i32_t last_in_row = current_distances[0] = (sz_i32_t)(idx_shorter + 1) * gap;
-
- // Load one row of the substitution matrix into four ZMM registers.
- sz_error_cost_t const *row_subs = subs + shorter_unsigned[idx_shorter] * 256u;
- row_first_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 0);
- row_second_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 1);
- row_third_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 2);
- row_fourth_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 3);
-
- // In the serial version we have one forward pass that computes the deletion,
- // insertion, and substitution costs at once.
- // for (sz_size_t idx_longer = 0; idx_longer < longer_length; ++idx_longer) {
- // sz_ssize_t cost_deletion = previous_distances[idx_longer + 1] + gap;
- // sz_ssize_t cost_insertion = current_distances[idx_longer] + gap;
- // sz_ssize_t cost_substitution = previous_distances[idx_longer] + row_subs[longer_unsigned[idx_longer]];
- // current_distances[idx_longer + 1] = sz_min_of_three(cost_deletion, cost_insertion, cost_substitution);
- // }
- //
- // Given the complexity of handling the data-dependency between consecutive insertion cost computations
- // within a Levenshtein matrix, the simplest design would be to vectorize every kind of cost computation
- // separately.
- // 1. Compute substitution costs for up to 64 characters at once, upcasting from 8-bit integers to 32.
- // 2. Compute the pairwise minimum with deletion costs.
- // 3. Inclusive prefix minimum computation to combine with insertion costs.
- // Proceeding with substitutions:
- for (sz_size_t idx_longer = 0; idx_longer < longer_length; idx_longer += 64) {
- sz_size_t register_length = sz_min_of_two(longer_length - idx_longer, 64);
- __mmask64 mask = _sz_u64_mask_until(register_length);
- longer_vec.zmm = _mm512_maskz_loadu_epi8(mask, longer + idx_longer);
-
- // Blend the `row_(first|second|third|fourth)_subs_vec` into `current_vec`, picking the right source
- // for every character in `longer_vec`. Before that, we need to permute the substitution vectors.
- // Only the bottom 6 bits of a byte are used in VPERMB, so we don't even need to mask.
- shuffled_first_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_first_subs_vec.zmm);
- shuffled_second_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_second_subs_vec.zmm);
- shuffled_third_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_third_subs_vec.zmm);
- shuffled_fourth_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_fourth_subs_vec.zmm);
-
- // To blend we can invoke three `_mm512_cmplt_epu8_mask`, but we can also achieve the same using
- // the AND logical operation, checking the top two bits of every byte.
- // Continuing this thought, we can use the VPTESTMB instruction to output the mask after the AND.
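// For reference, the masked blend below is just a vectorized way of indexing one 256-byte row,
// with the top two bits of every character selecting one of the four 64-byte quarters
// (an illustrative scalar sketch, not part of the library):
//
//     sz_u8_t c = ((sz_u8_t const *)longer)[idx_longer + i];
//     int is_third_or_fourth = (c & 0x80) != 0;  // quarters 2 or 3
//     int is_second_or_fourth = (c & 0x40) != 0; // quarters 1 or 3
//     sz_error_cost_t cost = row_subs[(is_third_or_fourth ? 128 : 0) + (is_second_or_fourth ? 64 : 0) + (c & 0x3F)];
//
// ...which is equivalent to plain `row_subs[c]`, but maps directly onto the two bit-test masks
// and the two nested blends that follow.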
- __mmask64 is_third_or_fourth = _mm512_mask_test_epi8_mask(mask, longer_vec.zmm, is_third_or_fourth_vec.zmm); - __mmask64 is_second_or_fourth = - _mm512_mask_test_epi8_mask(mask, longer_vec.zmm, is_second_or_fourth_vec.zmm); - lookup_substitution_vec.zmm = _mm512_mask_blend_epi8( - is_third_or_fourth, - // Choose between the first and the second. - _mm512_mask_blend_epi8(is_second_or_fourth, shuffled_first_subs_vec.zmm, shuffled_second_subs_vec.zmm), - // Choose between the third and the fourth. - _mm512_mask_blend_epi8(is_second_or_fourth, shuffled_third_subs_vec.zmm, shuffled_fourth_subs_vec.zmm)); - - // First, sign-extend lower and upper 16 bytes to 16-bit integers. - __m512i current_0_31_vec = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(lookup_substitution_vec.zmm, 0)); - __m512i current_32_63_vec = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(lookup_substitution_vec.zmm, 1)); - - // Now extend those 16-bit integers to 32-bit. - // This isn't free, same as the subsequent store, so we only want to do that for the populated lanes. - // To minimize the number of loads and stores, we can combine our substitution costs with the previous - // distances, containing the deletion costs. - { - cost_substitution_vec.zmm = _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer); - cost_substitution_vec.zmm = _mm512_add_epi32( - cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_0_31_vec, 0))); - cost_deletion_vec.zmm = _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer); - cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm); - current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm); - - // Inclusive prefix minimum computation to combine with insertion costs. - // Simply disabling this operation results in 5x performance improvement, meaning - // that this operation is responsible for 80% of the total runtime. - // for (sz_size_t idx_longer = 0; idx_longer < longer_length; ++idx_longer) { - // current_distances[idx_longer + 1] = - // sz_max_of_two(current_distances[idx_longer] + gap, current_distances[idx_longer + 1]); - // } - // - // To perform the same operation in vectorized form, we need to perform a tree-like reduction, - // that will involve multiple steps. It's quite expensive and should be first tested in the - // "experimental" section. - // - // Another approach might be loop unrolling: - // current_vec.i32s[0] = last_in_row = sz_i32_max_of_two(current_vec.i32s[0], last_in_row + gap); - // current_vec.i32s[1] = last_in_row = sz_i32_max_of_two(current_vec.i32s[1], last_in_row + gap); - // current_vec.i32s[2] = last_in_row = sz_i32_max_of_two(current_vec.i32s[2], last_in_row + gap); - // ... yet this approach is also quite expensive. - for (int i = 0; i != 16; ++i) - current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap); - _mm512_mask_storeu_epi32(current_distances + idx_longer + 1, (__mmask16)mask, current_vec.zmm); - } - - // Export the values from 16 to 31. 
- if (register_length > 16) {
- mask = _kshiftri_mask64(mask, 16);
- cost_substitution_vec.zmm =
- _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer + 16);
- cost_substitution_vec.zmm = _mm512_add_epi32(
- cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_0_31_vec, 1)));
- cost_deletion_vec.zmm =
- _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer + 16);
- cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm);
- current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm);
-
- // Aggregate running insertion costs within the register.
- for (int i = 0; i != 16; ++i)
- current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap);
- _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 16, (__mmask16)mask, current_vec.zmm);
- }
-
- // Export the values from 32 to 47.
- if (register_length > 32) {
- mask = _kshiftri_mask64(mask, 16);
- cost_substitution_vec.zmm =
- _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer + 32);
- cost_substitution_vec.zmm = _mm512_add_epi32(
- cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_32_63_vec, 0)));
- cost_deletion_vec.zmm =
- _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer + 32);
- cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm);
- current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm);
-
- // Aggregate running insertion costs within the register.
- for (int i = 0; i != 16; ++i)
- current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap);
- _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 32, (__mmask16)mask, current_vec.zmm);
- }
-
- // Export the values from 48 to 63.
- if (register_length > 48) {
- mask = _kshiftri_mask64(mask, 16);
- cost_substitution_vec.zmm =
- _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer + 48);
- cost_substitution_vec.zmm = _mm512_add_epi32(
- cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_32_63_vec, 1)));
- cost_deletion_vec.zmm =
- _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer + 48);
- cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm);
- current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm);
-
- // Aggregate running insertion costs within the register.
- for (int i = 0; i != 16; ++i)
- current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap);
- _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 48, (__mmask16)mask, current_vec.zmm);
- }
- }
-
- // Swap previous_distances and current_distances pointers
- sz_pointer_swap((void **)&previous_distances, (void **)&current_distances);
- }
-
- // Cache scalar before `free` call.
- sz_ssize_t result = previous_distances[longer_length]; - alloc->free(distances, buffer_length, alloc->handle); - return result; -} - -SZ_INTERNAL sz_ssize_t sz_alignment_score_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc) { - - if (sz_max_of_two(shorter_length, longer_length) < (256ull * 256ull * 256ull)) - return _sz_alignment_score_wagner_fisher_upto17m_avx512(shorter, shorter_length, longer, longer_length, subs, - gap, alloc); - else - return sz_alignment_score_serial(shorter, shorter_length, longer, longer_length, subs, gap, alloc); -} - -enum sz_encoding_t { - sz_encoding_unknown_k = 0, - sz_encoding_ascii_k = 1, - sz_encoding_utf8_k = 2, - sz_encoding_utf16_k = 3, - sz_encoding_utf32_k = 4, - sz_jwt_k, - sz_base64_k, - // Low priority encodings: - sz_encoding_utf8bom_k = 5, - sz_encoding_utf16le_k = 6, - sz_encoding_utf16be_k = 7, - sz_encoding_utf32le_k = 8, - sz_encoding_utf32be_k = 9, -}; - -// Character Set Detection is one of the most commonly performed operations in data processing with -// [Chardet](https://github.com/chardet/chardet), [Charset Normalizer](https://github.com/jawah/charset_normalizer), -// [cChardet](https://github.com/PyYoshi/cChardet) being the most commonly used options in the Python ecosystem. -// All of them are notoriously slow. -// -// Moreover, as of October 2024, UTF-8 is the dominant character encoding on the web, used by 98.4% of websites. -// Other have minimal usage, according to [W3Techs](https://w3techs.com/technologies/overview/character_encoding): -// - ISO-8859-1: 1.2% -// - Windows-1252: 0.3% -// - Windows-1251: 0.2% -// - EUC-JP: 0.1% -// - Shift JIS: 0.1% -// - EUC-KR: 0.1% -// - GB2312: 0.1% -// - Windows-1250: 0.1% -// Within programming language implementations and database management systems, 16-bit and 32-bit fixed-width encodings -// are also very popular and we need a way to efficienly differentiate between the most common UTF flavors, ASCII, and -// the rest. -// -// One good solution is the [simdutf](https://github.com/simdutf/simdutf) library, but it depends on the C++ runtime -// and focuses more on incremental validation & transcoding, rather than detection. -// -// So we need a very fast and efficient way of determining -SZ_PUBLIC sz_bool_t sz_detect_encoding(sz_cptr_t text, sz_size_t length) { - // https://github.com/simdutf/simdutf/blob/master/src/icelake/icelake_utf8_validation.inl.cpp - // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_from_utf8.inl.cpp#L81 - // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_utf8_common.inl.cpp#L661 - // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_utf8_common.inl.cpp#L788 - - // We can implement this operation simpler & differently, assuming most of the time continuous chunks of memory - // have identical encoding. With Russian and many European languages, we generally deal with 2-byte codepoints - // with occasional 1-byte punctuation marks. In the case of Chinese, Japanese, and Korean, we deal with 3-byte - // codepoints. In the case of emojis, we deal with 4-byte codepoints. - // We can also use the idea, that misaligned reads are quite cheap on modern CPUs. 
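// A minimal serial sketch of such a detector could start with BOMs and an ASCII scan before
// attempting UTF-8 validation (purely illustrative; the eventual SIMD design may look different):
//
//     if (length >= 3 && sz_equal(text, "\xEF\xBB\xBF", 3)) return /* UTF-8 with BOM */;
//     if (length >= 4 && sz_equal(text, "\xFF\xFE\x00\x00", 4)) return /* UTF-32 LE */;
//     if (length >= 2 && sz_equal(text, "\xFF\xFE", 2)) return /* UTF-16 LE */;
//     if (length >= 2 && sz_equal(text, "\xFE\xFF", 2)) return /* UTF-16 BE */;
//     if (sz_isascii(text, length)) return /* ASCII */;
//     // ...otherwise validate UTF-8, falling back to fixed-width guesses.
//
// Note that the UTF-32 LE check must precede the UTF-16 LE one, since their BOMs share a prefix.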
- int can_be_ascii = 1, can_be_utf8 = 1, can_be_utf16 = 1, can_be_utf32 = 1; - sz_unused(can_be_ascii + can_be_utf8 + can_be_utf16 + can_be_utf32); - sz_unused(text && length); - return sz_false_k; -} - -#pragma clang attribute pop -#pragma GCC pop_options -#endif - -#pragma endregion - -/* @brief Implementation of the string search algorithms using the Arm NEON instruction set, available on 64-bit - * Arm processors. Implements: {substring search, character search, character set search} x {forward, reverse}. - */ -#pragma region ARM NEON - -#if SZ_USE_ARM_NEON -#pragma GCC push_options -#pragma GCC target("arch=armv8.2-a+simd") -#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd"))), apply_to = function) - -/** - * @brief Helper structure to simplify work with 64-bit words. - */ -typedef union sz_u128_vec_t { - uint8x16_t u8x16; - uint16x8_t u16x8; - uint32x4_t u32x4; - uint64x2_t u64x2; - sz_u64_t u64s[2]; - sz_u32_t u32s[4]; - sz_u16_t u16s[8]; - sz_u8_t u8s[16]; -} sz_u128_vec_t; - -SZ_INTERNAL sz_u64_t _sz_vreinterpretq_u8_u4(uint8x16_t vec) { - // Use `vshrn` to produce a bitmask, similar to `movemask` in SSE. - // https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon - return vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(vec), 4)), 0) & 0x8888888888888888ull; -} - -SZ_PUBLIC sz_ordering_t sz_order_neon(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { - //! Before optimizing this, read the "Operations Not Worth Optimizing" in Contributions Guide: - //! https://github.com/ashvardanian/StringZilla/blob/main/CONTRIBUTING.md#general-performance-observations - return sz_order_serial(a, a_length, b, b_length); -} - -SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - sz_u128_vec_t a_vec, b_vec; - for (; length >= 16; a += 16, b += 16, length -= 16) { - a_vec.u8x16 = vld1q_u8((sz_u8_t const *)a); - b_vec.u8x16 = vld1q_u8((sz_u8_t const *)b); - uint8x16_t cmp = vceqq_u8(a_vec.u8x16, b_vec.u8x16); - if (vminvq_u8(cmp) != 255) { return sz_false_k; } // Check if all bytes match - } - - // Handle remaining bytes - if (length) return sz_equal_serial(a, b, length); - return sz_true_k; -} - -SZ_PUBLIC sz_u64_t sz_checksum_neon(sz_cptr_t text, sz_size_t length) { - uint64x2_t sum_vec = vdupq_n_u64(0); - - // Process 16 bytes (128 bits) at a time - for (; length >= 16; text += 16, length -= 16) { - uint8x16_t vec = vld1q_u8((sz_u8_t const *)text); // Load 16 bytes - uint16x8_t pairwise_sum1 = vpaddlq_u8(vec); // Pairwise add lower and upper 8 bits - uint32x4_t pairwise_sum2 = vpaddlq_u16(pairwise_sum1); // Pairwise add 16-bit results - uint64x2_t pairwise_sum3 = vpaddlq_u32(pairwise_sum2); // Pairwise add 32-bit results - sum_vec = vaddq_u64(sum_vec, pairwise_sum3); // Accumulate the sum - } - - // Final reduction of `sum_vec` to a single scalar - sz_u64_t sum = vgetq_lane_u64(sum_vec, 0) + vgetq_lane_u64(sum_vec, 1); - if (length) sum += sz_checksum_serial(text, length); - return sum; -} - -SZ_PUBLIC void sz_copy_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - // In most cases the `source` and the `target` are not aligned, but we should - // at least make sure that writes don't touch many cache lines. - // NEON has an instruction to load and write 64 bytes at once. - // - // sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. 
- // sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. - // for (; head_length; target += 1, source += 1, head_length -= 1) *target = *source; - // length -= head_length; - // for (; length >= 64; target += 64, source += 64, length -= 64) - // vst4q_u8((sz_u8_t *)target, vld1q_u8_x4((sz_u8_t const *)source)); - // for (; tail_length; target += 1, source += 1, tail_length -= 1) *target = *source; - // - // Sadly, those instructions end up being 20% slower than the code processing 16 bytes at a time: - for (; length >= 16; target += 16, source += 16, length -= 16) - vst1q_u8((sz_u8_t *)target, vld1q_u8((sz_u8_t const *)source)); - if (length) sz_copy_serial(target, source, length); -} - -SZ_PUBLIC void sz_move_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - // When moving small buffers, using a small buffer on stack as a temporary storage is faster. - - if (target < source || target >= source + length) { - // Non-overlapping, proceed forward - sz_copy_neon(target, source, length); - } - else { - // Overlapping, proceed backward - target += length; - source += length; - - sz_u128_vec_t src_vec; - while (length >= 16) { - target -= 16, source -= 16, length -= 16; - src_vec.u8x16 = vld1q_u8((sz_u8_t const *)source); - vst1q_u8((sz_u8_t *)target, src_vec.u8x16); - } - while (length) { - target -= 1, source -= 1, length -= 1; - *target = *source; - } - } -} - -SZ_PUBLIC void sz_fill_neon(sz_ptr_t target, sz_size_t length, sz_u8_t value) { - uint8x16_t fill_vec = vdupq_n_u8(value); // Broadcast the value across the register - - while (length >= 16) { - vst1q_u8((sz_u8_t *)target, fill_vec); - target += 16; - length -= 16; - } - - // Handle remaining bytes - if (length) sz_fill_serial(target, length, value); -} - -SZ_PUBLIC void sz_look_up_transform_neon(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { - - // If the input is tiny (especially smaller than the look-up table itself), we may end up paying - // more for organizing the SIMD registers and changing the CPU state, than for the actual computation. - if (length <= 128) { - sz_look_up_transform_serial(source, length, lut, target); - return; - } - - sz_size_t head_length = (16 - ((sz_size_t)target % 16)) % 16; // 15 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 16; // 15 or less. - - // We need to pull the lookup table into 16x NEON registers. We have a total of 32 such registers. - // According to the Neoverse V2 manual, the 4-table lookup has a latency of 6 cycles, and 4x throughput. - uint8x16x4_t lut_0_to_63_vec, lut_64_to_127_vec, lut_128_to_191_vec, lut_192_to_255_vec; - lut_0_to_63_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 0)); - lut_64_to_127_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 64)); - lut_128_to_191_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 128)); - lut_192_to_255_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 192)); - - sz_u128_vec_t source_vec; - // If the top bit is set in each word of `source_vec`, than we use `lookup_128_to_191_vec` or - // `lookup_192_to_255_vec`. If the second bit is set, we use `lookup_64_to_127_vec` or `lookup_192_to_255_vec`. 
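// In scalar form, the 4-table lookup below boils down to the following
// (an illustrative sketch, not part of the library):
//
//     sz_u8_t c = ((sz_u8_t const *)source)[i];
//     sz_u8_t from_0_to_63    = (c ^ 0x00) < 64 ? lut[0   + (c ^ 0x00)] : 0;
//     sz_u8_t from_64_to_127  = (c ^ 0x40) < 64 ? lut[64  + (c ^ 0x40)] : 0;
//     sz_u8_t from_128_to_191 = (c ^ 0x80) < 64 ? lut[128 + (c ^ 0x80)] : 0;
//     sz_u8_t from_192_to_255 = (c ^ 0xc0) < 64 ? lut[192 + (c ^ 0xc0)] : 0;
//     target[i] = from_0_to_63 | from_64_to_127 | from_128_to_191 | from_192_to_255;
//
// Exactly one of the four terms is non-zero, because each XOR maps only "its own" 64-byte range
// into valid [0; 64) indices, and `vqtbl4q_u8` yields zeros for all out-of-range indices.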
- sz_u128_vec_t lookup_0_to_63_vec, lookup_64_to_127_vec, lookup_128_to_191_vec, lookup_192_to_255_vec; - sz_u128_vec_t blended_0_to_255_vec; - - // Process the head with serial code - for (; head_length; target += 1, source += 1, head_length -= 1) *target = lut[*(sz_u8_t const *)source]; - - // Table lookups on Arm are much simpler to use than on x86, as we can use the `vqtbl4q_u8` instruction - // to perform a 4-table lookup in a single instruction. The XORs are used to adjust the lookup position - // within each 64-byte range of the table. - // Details on the 4-table lookup: https://lemire.me/blog/2019/07/23/arbitrary-byte-to-byte-maps-using-arm-neon/ - length -= head_length; - length -= tail_length; - for (; length >= 16; source += 16, target += 16, length -= 16) { - source_vec.u8x16 = vld1q_u8((sz_u8_t const *)source); - lookup_0_to_63_vec.u8x16 = vqtbl4q_u8(lut_0_to_63_vec, source_vec.u8x16); - lookup_64_to_127_vec.u8x16 = vqtbl4q_u8(lut_64_to_127_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0x40))); - lookup_128_to_191_vec.u8x16 = vqtbl4q_u8(lut_128_to_191_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0x80))); - lookup_192_to_255_vec.u8x16 = vqtbl4q_u8(lut_192_to_255_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0xc0))); - blended_0_to_255_vec.u8x16 = vorrq_u8(vorrq_u8(lookup_0_to_63_vec.u8x16, lookup_64_to_127_vec.u8x16), - vorrq_u8(lookup_128_to_191_vec.u8x16, lookup_192_to_255_vec.u8x16)); - vst1q_u8((sz_u8_t *)target, blended_0_to_255_vec.u8x16); - } - - // Process the tail with serial code - for (; tail_length; target += 1, source += 1, tail_length -= 1) *target = lut[*(sz_u8_t const *)source]; -} - -SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - sz_u64_t matches; - sz_u128_vec_t h_vec, n_vec, matches_vec; - n_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)n); - - while (h_length >= 16) { - h_vec.u8x16 = vld1q_u8((sz_u8_t const *)h); - matches_vec.u8x16 = vceqq_u8(h_vec.u8x16, n_vec.u8x16); - // In Arm NEON we don't have a `movemask` to combine it with `ctz` and get the offset of the match. - // But assuming the `vmaxvq` is cheap, we can use it to find the first match, by blending (bitwise selecting) - // the vector with a relative offsets array. - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - if (matches) return h + sz_u64_ctz(matches) / 4; - - h += 16, h_length -= 16; - } - - return sz_find_byte_serial(h, h_length, n); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_byte_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - sz_u64_t matches; - sz_u128_vec_t h_vec, n_vec, matches_vec; - n_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)n); - - while (h_length >= 16) { - h_vec.u8x16 = vld1q_u8((sz_u8_t const *)h + h_length - 16); - matches_vec.u8x16 = vceqq_u8(h_vec.u8x16, n_vec.u8x16); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - if (matches) return h + h_length - 1 - sz_u64_clz(matches) / 4; - h_length -= 16; - } - - return sz_rfind_byte_serial(h, h_length, n); -} - -SZ_PUBLIC sz_u64_t _sz_find_charset_neon_register(sz_u128_vec_t h_vec, uint8x16_t set_top_vec_u8x16, - uint8x16_t set_bottom_vec_u8x16) { - - // Once we've read the characters in the haystack, we want to - // compare them against our bitset. The serial version of that code - // would look like: `(set_->_u8s[c >> 3] & (1u << (c & 7u))) != 0`. 
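// Spelled out per byte, the NEON code below computes exactly that predicate, split across the
// two 16-byte halves of the bitset (an illustrative scalar sketch, not part of the library):
//
//     sz_u8_t c = h_vec.u8s[i];
//     sz_u8_t byte_index = c >> 3;                   // which of the 32 bitset bytes to fetch
//     sz_u8_t byte_mask = (sz_u8_t)(1u << (c & 7u)); // which bit within that byte
//     sz_u8_t top = byte_index < 16 ? set_u8s[byte_index] : 0;     // covered by `set_top_vec_u8x16`
//     sz_u8_t bottom = byte_index >= 16 ? set_u8s[byte_index] : 0; // covered by `set_bottom_vec_u8x16`
//     sz_u8_t is_match = ((top | bottom) & byte_mask) ? 0xFF : 0x00; // broadcast like `vtstq_u8`
//
// ...where `set_u8s` stands for the original `sz_charset_t::_u8s` array the two halves were loaded from.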
- uint8x16_t byte_index_vec = vshrq_n_u8(h_vec.u8x16, 3);
- uint8x16_t byte_mask_vec = vshlq_u8(vdupq_n_u8(1), vreinterpretq_s8_u8(vandq_u8(h_vec.u8x16, vdupq_n_u8(7))));
- uint8x16_t matches_top_vec = vqtbl1q_u8(set_top_vec_u8x16, byte_index_vec);
- // The table lookup instruction in NEON replies to out-of-bound requests with zeros.
- // The values in `byte_index_vec` all fall in [0; 32). So for values under 16, subtracting 16 will underflow
- // and map into interval [240, 256). Meaning that those will be populated with zeros and we can safely
- // merge `matches_top_vec` and `matches_bottom_vec` with a bitwise OR.
- uint8x16_t matches_bottom_vec = vqtbl1q_u8(set_bottom_vec_u8x16, vsubq_u8(byte_index_vec, vdupq_n_u8(16)));
- uint8x16_t matches_vec = vorrq_u8(matches_top_vec, matches_bottom_vec);
- // Instead of pure `vandq_u8`, we can immediately broadcast a match presence across each 8-bit word.
- matches_vec = vtstq_u8(matches_vec, byte_mask_vec);
- return _sz_vreinterpretq_u8_u4(matches_vec);
-}
-
-SZ_PUBLIC sz_cptr_t sz_find_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) {
-
- // This almost never fires, but it's better to be safe than sorry.
- if (h_length < n_length || !n_length) return SZ_NULL_CHAR;
- if (n_length == 1) return sz_find_byte_neon(h, h_length, n);
-
- // Scan through the string.
- // Given how tiny the Arm NEON registers are, we should avoid internal branches at all costs.
- // That's why, for smaller needles, we use different loops.
- if (n_length == 2) {
- // Broadcast needle characters into SIMD registers.
- sz_u64_t matches;
- sz_u128_vec_t h_first_vec, h_last_vec, n_first_vec, n_last_vec, matches_vec;
- // Dealing with 16-bit values, we can load 2 registers at a time and compare 31 possible offsets
- // in a single loop iteration.
- n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[0]);
- n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[1]);
- for (; h_length >= 17; h += 16, h_length -= 16) {
- h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 0));
- h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 1));
- matches_vec.u8x16 =
- vandq_u8(vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16));
- matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16);
- if (matches) return h + sz_u64_ctz(matches) / 4;
- }
- }
- else if (n_length == 3) {
- // Broadcast needle characters into SIMD registers.
- sz_u64_t matches;
- sz_u128_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec, matches_vec;
- // Comparing 24-bit values is a bummer. Being lazy, I went with the same approach
- // as when searching for strings over 4 characters long. I only avoid the last comparison.
- n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[0]);
- n_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[1]);
- n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[2]);
- for (; h_length >= 18; h += 16, h_length -= 16) {
- h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 0));
- h_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 1));
- h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 2));
- matches_vec.u8x16 = vandq_u8( //
- vandq_u8( //
- vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), //
- vceqq_u8(h_mid_vec.u8x16, n_mid_vec.u8x16)),
- vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16));
- matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16);
- if (matches) return h + sz_u64_ctz(matches) / 4;
- }
- }
- else {
- // Pick the parts of the needle that are worth comparing.
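// `_sz_locate_needle_anomalies` is defined elsewhere in this header; conceptually it just picks
// three probe offsets within the needle. A trivial stand-in with the same contract would be
// (a hypothetical sketch, not the actual heuristic):
//
//     *offset_first = 0, *offset_mid = n_length / 2, *offset_last = n_length - 1;
//
// The real routine presumably favors more distinctive bytes at those positions, so that the cheap
// three-byte filter below rejects most candidate positions before the full `sz_equal` confirmation.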
- sz_size_t offset_first, offset_mid, offset_last;
- _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last);
- // Broadcast those characters into SIMD registers.
- sz_u64_t matches;
- sz_u128_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec, matches_vec;
- n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_first]);
- n_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_mid]);
- n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_last]);
- // Walk through the string.
- for (; h_length >= n_length + 16; h += 16, h_length -= 16) {
- h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_first));
- h_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_mid));
- h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_last));
- matches_vec.u8x16 = vandq_u8( //
- vandq_u8( //
- vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), //
- vceqq_u8(h_mid_vec.u8x16, n_mid_vec.u8x16)),
- vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16));
- matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16);
- while (matches) {
- int potential_offset = sz_u64_ctz(matches) / 4;
- if (sz_equal(h + potential_offset, n, n_length)) return h + potential_offset;
- matches &= matches - 1;
+ // It's possible that the sequence is already partitioned.
+ if (split != 0 && split != sequence->count) {
+ // Use two pointers to efficiently reposition elements.
+ // One pointer walks left-to-right from the start, and the other walks right-to-left from the end.
+ sz_size_t left = 0;
+ sz_size_t right = sequence->count - 1;
+ while (1) {
+ // Find the next element with the bit set on the left side.
+ while (left < split && !(sequence->order[left] & mask)) ++left;
+ // Find the next element without the bit set on the right side.
+ while (right >= split && (sequence->order[right] & mask)) --right;
+ // Swap the mispositioned elements.
+ if (left < split && right >= split) {
+ sz_u64_swap(sequence->order + left, sequence->order + right);
+ ++left;
+ --right;
 }
+ else { break; }
 }
 }
- return sz_find_serial(h, h_length, n, n_length);
-}
-
-SZ_PUBLIC sz_cptr_t sz_rfind_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) {
-
- // This almost never fires, but it's better to be safe than sorry.
- if (h_length < n_length || !n_length) return SZ_NULL_CHAR;
- if (n_length == 1) return sz_rfind_byte_neon(h, h_length, n);
-
- // Pick the parts of the needle that are worth comparing.
- sz_size_t offset_first, offset_mid, offset_last;
- _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last);
-
- // Will contain 4 bits per character.
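// That is the `vshrn`-based emulation of `movemask`: every one of the 16 comparison bytes
// contributes one nibble, of which only the top bit survives the 0x8888... mask. For example
// (illustrative only), if byte 2 is the sole match, `matches == 0x0000000000000800`, so
// `sz_u64_ctz(matches) / 4 == 2` recovers the offset from the front, and `sz_u64_clz(matches) / 4`
// counts matching positions from the back, as needed for this reverse-order search.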
- sz_u64_t matches; - sz_u128_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec, matches_vec; - n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_first]); - n_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_mid]); - n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_last]); - - sz_cptr_t h_reversed; - for (; h_length >= n_length + 16; h_length -= 16) { - h_reversed = h + h_length - n_length - 16 + 1; - h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h_reversed + offset_first)); - h_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h_reversed + offset_mid)); - h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h_reversed + offset_last)); - matches_vec.u8x16 = vandq_u8( // - vandq_u8( // - vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), // - vceqq_u8(h_mid_vec.u8x16, n_mid_vec.u8x16)), - vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - while (matches) { - int potential_offset = sz_u64_clz(matches) / 4; - if (sz_equal(h + h_length - n_length - potential_offset, n, n_length)) - return h + h_length - n_length - potential_offset; - sz_assert((matches & (1ull << (63 - potential_offset * 4))) != 0 && - "The bit must be set before we squash it"); - matches &= ~(1ull << (63 - potential_offset * 4)); - } - } - - return sz_rfind_serial(h, h_length, n, n_length); -} - -SZ_PUBLIC sz_cptr_t sz_find_charset_neon(sz_cptr_t h, sz_size_t h_length, sz_charset_t const *set) { - sz_u64_t matches; - sz_u128_vec_t h_vec; - uint8x16_t set_top_vec_u8x16 = vld1q_u8(&set->_u8s[0]); - uint8x16_t set_bottom_vec_u8x16 = vld1q_u8(&set->_u8s[16]); - - for (; h_length >= 16; h += 16, h_length -= 16) { - h_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h)); - matches = _sz_find_charset_neon_register(h_vec, set_top_vec_u8x16, set_bottom_vec_u8x16); - if (matches) return h + sz_u64_ctz(matches) / 4; - } - - return sz_find_charset_serial(h, h_length, set); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t h, sz_size_t h_length, sz_charset_t const *set) { - sz_u64_t matches; - sz_u128_vec_t h_vec; - uint8x16_t set_top_vec_u8x16 = vld1q_u8(&set->_u8s[0]); - uint8x16_t set_bottom_vec_u8x16 = vld1q_u8(&set->_u8s[16]); - - // Check `sz_find_charset_neon` for explanations. - for (; h_length >= 16; h_length -= 16) { - h_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h) + h_length - 16); - matches = _sz_find_charset_neon_register(h_vec, set_top_vec_u8x16, set_bottom_vec_u8x16); - if (matches) return h + h_length - 1 - sz_u64_clz(matches) / 4; - } - - return sz_rfind_charset_serial(h, h_length, set); -} - -#pragma clang attribute pop -#pragma GCC pop_options -#endif // Arm Neon - -#pragma endregion - -/* @brief Implementation of the string search algorithms using the Arm SVE variable-length registers, available - * in Arm v9 processors. - * - * Implements: - * - memory: {copy, move, fill} - * - comparisons: {equal, order} - * - search: {substring, character, character set} x {forward, reverse}. 
- */ -#pragma region ARM SVE - -#if SZ_USE_ARM_SVE -#pragma GCC push_options -#pragma GCC target("arch=armv8.2-a+sve") -#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve"))), apply_to = function) - -SZ_PUBLIC void sz_fill_sve(sz_ptr_t target, sz_size_t length, sz_u8_t value) { - svuint8_t value_vec = svdup_u8(value); - sz_size_t vec_len = svcntb(); // Vector length in bytes (scalable) - - if (length <= vec_len) { - // Small buffer case: use mask to handle small writes - svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)length); - svst1_u8(mask, (unsigned char *)target, value_vec); - } - else { - // Calculate head, body, and tail sizes - sz_size_t head_length = vec_len - ((sz_size_t)target % vec_len); - sz_size_t tail_length = (sz_size_t)(target + length) % vec_len; - sz_size_t body_length = length - head_length - tail_length; - - // Handle unaligned head - svbool_t head_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)head_length); - svst1_u8(head_mask, (unsigned char *)target, value_vec); - target += head_length; - - // Aligned body loop - for (; body_length >= vec_len; target += vec_len, body_length -= vec_len) { - svst1_u8(svptrue_b8(), (unsigned char *)target, value_vec); - } - - // Handle unaligned tail - svbool_t tail_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)tail_length); - svst1_u8(tail_mask, (unsigned char *)target, value_vec); - } -} - -SZ_PUBLIC void sz_copy_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - sz_size_t vec_len = svcntb(); // Vector length in bytes + // Go down recursively. + if (bit_idx < bit_max) { + sz_sequence_t a = *sequence; + a.count = split; + sz_sort_recursion(&a, bit_idx + 1, bit_max, comparator, partial_order_length); - // Arm Neoverse V2 cores in Graviton 4, for example, come with 256 KB of L1 data cache per core, - // and 8 MB of L2 cache per core. Moreover, the L1 cache is fully associative. - // With two strings, we may consider the overal workload huge, if each exceeds 1 MB in length. - // - // int is_huge = length >= 4ull * 1024ull * 1024ull; - // - // When the buffer is small, there isn't much to innovate. - if (length <= vec_len) { - // Small buffer case: use mask to handle small writes - svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)length); - svuint8_t data = svld1_u8(mask, (unsigned char *)source); - svst1_u8(mask, (unsigned char *)target, data); + sz_sequence_t b = *sequence; + b.order += split; + b.count -= split; + sz_sort_recursion(&b, bit_idx + 1, bit_max, comparator, partial_order_length); } - // When dealing with larger buffers, similar to AVX-512, we want minimize unaligned operations - // and handle the head, body, and tail separately. We can also traverse the buffer in both directions - // as Arm generally supports more simultaneous stores than x86 CPUs. - // - // For gigantic datasets, similar to AVX-512, non-temporal "loads" and "stores" can be used. - // Sadly, if the register size (16 byte or larger) is smaller than a cache-line (64 bytes) - // we will pay a huge penalty on loads, fetching the same content many times. - // It may be better to allow caching (and subsequent eviction), in favor of using four-element - // tuples, wich will be guaranteed to be a multiple of a cache line. - // - // Another approach is to use the `LD4B` instructions, which will populate four registers at once. - // This however, further decreases the performance from LibC-like 29 GB/s to 20 GB/s. + // Reached the end of recursion. 
else { - // Calculating head, body, and tail sizes depends on the `vec_len`, - // but it's runtime constant, and the modulo operation is expensive! - // Instead we use the fact, that it's always a multiple of 128 bits or 16 bytes. - sz_size_t head_length = 16 - ((sz_size_t)target % 16); - sz_size_t tail_length = (sz_size_t)(target + length) % 16; - sz_size_t body_length = length - head_length - tail_length; + // Discard the prefixes. + sz_u32_t *order_half_words = (sz_u32_t *)sequence->order; + for (sz_size_t i = 0; i != sequence->count; ++i) { order_half_words[i * 2 + 1] = 0; } - // Handle unaligned parts - svbool_t head_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)head_length); - svuint8_t head_data = svld1_u8(head_mask, (unsigned char *)source); - svst1_u8(head_mask, (unsigned char *)target, head_data); - svbool_t tail_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)tail_length); - svuint8_t tail_data = svld1_u8(tail_mask, (unsigned char *)source + head_length + body_length); - svst1_u8(tail_mask, (unsigned char *)target + head_length + body_length, tail_data); - target += head_length; - source += head_length; + sz_sequence_t a = *sequence; + a.count = split; + sz_sort_introsort(&a, comparator); - // Aligned body loop, walking in two directions - for (; body_length >= vec_len * 2; target += vec_len, source += vec_len, body_length -= vec_len * 2) { - svuint8_t forward_data = svld1_u8(svptrue_b8(), (unsigned char *)source); - svuint8_t backward_data = svld1_u8(svptrue_b8(), (unsigned char *)source + body_length - vec_len); - svst1_u8(svptrue_b8(), (unsigned char *)target, forward_data); - svst1_u8(svptrue_b8(), (unsigned char *)target + body_length - vec_len, backward_data); - } - // Up to (vec_len * 2 - 1) bytes of data may be left in the body, - // so we can unroll the last two optional loop iterations. - if (body_length > vec_len) { - svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)body_length); - svuint8_t data = svld1_u8(mask, (unsigned char *)source); - svst1_u8(mask, (unsigned char *)target, data); - body_length -= vec_len; - source += body_length; - target += body_length; - } - if (body_length) { - svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)body_length); - svuint8_t data = svld1_u8(mask, (unsigned char *)source); - svst1_u8(mask, (unsigned char *)target, data); - } + sz_sequence_t b = *sequence; + b.order += split; + b.count -= split; + sz_sort_introsort(&b, comparator); } } -#pragma clang attribute pop -#pragma GCC pop_options -#endif // Arm SVE - -#pragma endregion - -/* - * @brief Pick the right implementation for the string search algorithms. 
- */ -#pragma region Compile Time Dispatching - -SZ_PUBLIC sz_u64_t sz_hash(sz_cptr_t ins, sz_size_t length) { return sz_hash_serial(ins, length); } -SZ_PUBLIC void sz_tolower(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) { sz_tolower_serial(ins, length, outs); } -SZ_PUBLIC void sz_toupper(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) { sz_toupper_serial(ins, length, outs); } -SZ_PUBLIC void sz_toascii(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) { sz_toascii_serial(ins, length, outs); } -SZ_PUBLIC sz_bool_t sz_isascii(sz_cptr_t ins, sz_size_t length) { return sz_isascii_serial(ins, length); } - -SZ_PUBLIC void sz_hashes_fingerprint(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_ptr_t fingerprint, - sz_size_t fingerprint_bytes) { - - sz_bool_t fingerprint_length_is_power_of_two = (sz_bool_t)((fingerprint_bytes & (fingerprint_bytes - 1)) == 0); - sz_string_view_t fingerprint_buffer = {fingerprint, fingerprint_bytes}; - - // There are several issues related to the fingerprinting algorithm. - // First, the memory traversal order is important. - // https://blog.stuffedcow.net/2015/08/pagewalk-coherence/ - - // In most cases the fingerprint length will be a power of two. - if (fingerprint_length_is_power_of_two == sz_false_k) - sz_hashes(start, length, window_length, 1, _sz_hashes_fingerprint_non_pow2_callback, &fingerprint_buffer); - else - sz_hashes(start, length, window_length, 1, _sz_hashes_fingerprint_pow2_callback, &fingerprint_buffer); -} - -#if !SZ_DYNAMIC_DISPATCH - -SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length) { -#if SZ_USE_X86_AVX512 - return sz_checksum_avx512(text, length); -#elif SZ_USE_X86_AVX2 - return sz_checksum_avx2(text, length); -#elif SZ_USE_ARM_NEON - return sz_checksum_neon(text, length); -#else - return sz_checksum_serial(text, length); -#endif -} - -SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { -#if SZ_USE_X86_AVX512 - return sz_equal_avx512(a, b, length); -#elif SZ_USE_X86_AVX2 - return sz_equal_avx2(a, b, length); -#elif SZ_USE_ARM_NEON - return sz_equal_neon(a, b, length); -#else - return sz_equal_serial(a, b, length); -#endif -} - -SZ_DYNAMIC sz_ordering_t sz_order(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { -#if SZ_USE_X86_AVX512 - return sz_order_avx512(a, a_length, b, b_length); -#elif SZ_USE_X86_AVX2 - return sz_order_avx2(a, a_length, b, b_length); -#elif SZ_USE_ARM_NEON - return sz_order_neon(a, a_length, b, b_length); -#else - return sz_order_serial(a, a_length, b, b_length); -#endif -} - -SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { -#if SZ_USE_X86_AVX512 - sz_copy_avx512(target, source, length); -#elif SZ_USE_X86_AVX2 - sz_copy_avx2(target, source, length); -#elif SZ_USE_ARM_NEON - sz_copy_neon(target, source, length); -#else - sz_copy_serial(target, source, length); -#endif -} - -SZ_DYNAMIC void sz_move(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { -#if SZ_USE_X86_AVX512 - sz_move_avx512(target, source, length); -#elif SZ_USE_X86_AVX2 - sz_move_avx2(target, source, length); -#elif SZ_USE_ARM_NEON - sz_move_neon(target, source, length); -#else - sz_move_serial(target, source, length); -#endif -} - -SZ_DYNAMIC void sz_fill(sz_ptr_t target, sz_size_t length, sz_u8_t value) { -#if SZ_USE_X86_AVX512 - sz_fill_avx512(target, length, value); -#elif SZ_USE_X86_AVX2 - sz_fill_avx2(target, length, value); -#elif SZ_USE_ARM_NEON - sz_fill_neon(target, length, value); -#else - sz_fill_serial(target, length, value); -#endif -} - 
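// Note that with `SZ_DYNAMIC_DISPATCH` disabled, each of these wrappers is resolved entirely at
// compile time: a build with AVX-512 enabled (so `SZ_USE_X86_AVX512` is non-zero) forwards straight
// to the `_avx512` kernels with no function-pointer indirection. A hypothetical caller
// (illustrative only):
//
//     char buffer[4096];
//     sz_fill(buffer, sizeof(buffer), 0); // compiles down to `sz_fill_avx512` on such a build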
-SZ_DYNAMIC void sz_look_up_transform(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { -#if SZ_USE_X86_AVX512 - sz_look_up_transform_avx512(source, length, lut, target); -#elif SZ_USE_X86_AVX2 - sz_look_up_transform_avx2(source, length, lut, target); -#elif SZ_USE_ARM_NEON - sz_look_up_transform_neon(source, length, lut, target); -#else - sz_look_up_transform_serial(source, length, lut, target); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) { -#if SZ_USE_X86_AVX512 - return sz_find_byte_avx512(haystack, h_length, needle); -#elif SZ_USE_X86_AVX2 - return sz_find_byte_avx2(haystack, h_length, needle); -#elif SZ_USE_ARM_NEON - return sz_find_byte_neon(haystack, h_length, needle); -#else - return sz_find_byte_serial(haystack, h_length, needle); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_rfind_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) { -#if SZ_USE_X86_AVX512 - return sz_rfind_byte_avx512(haystack, h_length, needle); -#elif SZ_USE_X86_AVX2 - return sz_rfind_byte_avx2(haystack, h_length, needle); -#elif SZ_USE_ARM_NEON - return sz_rfind_byte_neon(haystack, h_length, needle); -#else - return sz_rfind_byte_serial(haystack, h_length, needle); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_find(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length) { -#if SZ_USE_X86_AVX512 - return sz_find_avx512(haystack, h_length, needle, n_length); -#elif SZ_USE_X86_AVX2 - return sz_find_avx2(haystack, h_length, needle, n_length); -#elif SZ_USE_ARM_NEON - return sz_find_neon(haystack, h_length, needle, n_length); -#else - return sz_find_serial(haystack, h_length, needle, n_length); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_rfind(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length) { -#if SZ_USE_X86_AVX512 - return sz_rfind_avx512(haystack, h_length, needle, n_length); -#elif SZ_USE_X86_AVX2 - return sz_rfind_avx2(haystack, h_length, needle, n_length); -#elif SZ_USE_ARM_NEON - return sz_rfind_neon(haystack, h_length, needle, n_length); -#else - return sz_rfind_serial(haystack, h_length, needle, n_length); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { -#if SZ_USE_X86_AVX512 - return sz_find_charset_avx512(text, length, set); -#elif SZ_USE_X86_AVX2 - return sz_find_charset_avx2(text, length, set); -#elif SZ_USE_ARM_NEON - return sz_find_charset_neon(text, length, set); -#else - return sz_find_charset_serial(text, length, set); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_rfind_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { -#if SZ_USE_X86_AVX512 - return sz_rfind_charset_avx512(text, length, set); -#elif SZ_USE_X86_AVX2 - return sz_rfind_charset_avx2(text, length, set); -#elif SZ_USE_ARM_NEON - return sz_rfind_charset_neon(text, length, set); -#else - return sz_rfind_charset_serial(text, length, set); -#endif -} - -SZ_DYNAMIC sz_size_t sz_hamming_distance( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - return sz_hamming_distance_serial(a, a_length, b, b_length, bound); +SZ_INTERNAL sz_bool_t _sz_sort_is_less(sz_sequence_t *sequence, sz_size_t i_key, sz_size_t j_key) { + sz_cptr_t i_str = sequence->get_start(sequence, i_key); + sz_cptr_t j_str = sequence->get_start(sequence, j_key); + sz_size_t i_len = sequence->get_length(sequence, i_key); + sz_size_t j_len = sequence->get_length(sequence, j_key); + return (sz_bool_t)(sz_order_serial(i_str, i_len, 
j_str, j_len) == sz_less_k); } -SZ_DYNAMIC sz_size_t sz_hamming_distance_utf8( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - return sz_hamming_distance_utf8_serial(a, a_length, b, b_length, bound); -} +SZ_PUBLIC void sz_sort_partial(sz_sequence_t *sequence, sz_size_t partial_order_length) { -SZ_DYNAMIC sz_size_t sz_edit_distance( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { -#if SZ_USE_X86_AVX512 - return sz_edit_distance_avx512(a, a_length, b, b_length, bound, alloc); +#if _SZ_IS_BIG_ENDIAN + // TODO: Implement partial sort for big-endian systems. For now this sorts the whole thing. + sz_unused(partial_order_length); + sz_sort_introsort(sequence, (sz_sequence_comparator_t)_sz_sort_is_less); #else - return sz_edit_distance_serial(a, a_length, b, b_length, bound, alloc); -#endif -} -SZ_DYNAMIC sz_size_t sz_edit_distance_utf8( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - return _sz_edit_distance_wagner_fisher_serial(a, a_length, b, b_length, bound, sz_true_k, alloc); -} + // Export up to 4 bytes into the `sequence` bits themselves + for (sz_size_t i = 0; i != sequence->count; ++i) { + sz_cptr_t begin = sequence->get_start(sequence, sequence->order[i]); + sz_size_t length = sequence->get_length(sequence, sequence->order[i]); + length = length > 4u ? 4u : length; + sz_ptr_t prefix = (sz_ptr_t)&sequence->order[i]; + for (sz_size_t j = 0; j != length; ++j) prefix[7 - j] = begin[j]; + } -SZ_DYNAMIC sz_ssize_t sz_alignment_score(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, - sz_error_cost_t const *subs, sz_error_cost_t gap, - sz_memory_allocator_t *alloc) { -#if SZ_USE_X86_AVX512 - return sz_alignment_score_avx512(a, a_length, b, b_length, subs, gap, alloc); -#else - return sz_alignment_score_serial(a, a_length, b, b_length, subs, gap, alloc); + // Perform optionally-parallel radix sort on them + sz_sort_recursion(sequence, 0, 32, (sz_sequence_comparator_t)_sz_sort_is_less, partial_order_length); #endif } -SZ_DYNAMIC void sz_hashes(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // - sz_hash_callback_t callback, void *callback_handle) { -#if SZ_USE_X86_AVX512 - sz_hashes_avx512(text, length, window_length, window_step, callback, callback_handle); -#elif SZ_USE_X86_AVX2 - sz_hashes_avx2(text, length, window_length, window_step, callback, callback_handle); +SZ_PUBLIC void sz_sort(sz_sequence_t *sequence) { +#if _SZ_IS_BIG_ENDIAN + sz_sort_introsort(sequence, (sz_sequence_comparator_t)_sz_sort_is_less); #else - sz_hashes_serial(text, length, window_length, window_step, callback, callback_handle); + sz_sort_partial(sequence, sequence->count); #endif } -SZ_DYNAMIC sz_cptr_t sz_find_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - return sz_find_charset(h, h_length, &set); -} - -SZ_DYNAMIC sz_cptr_t sz_find_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - sz_charset_invert(&set); - return sz_find_charset(h, h_length, &set); -} - -SZ_DYNAMIC sz_cptr_t sz_rfind_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t 
set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - return sz_rfind_charset(h, h_length, &set); -} - -SZ_DYNAMIC sz_cptr_t sz_rfind_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - sz_charset_invert(&set); - return sz_rfind_charset(h, h_length, &set); -} - -SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length, - sz_random_generator_t generator, void *generator_user_data) { - sz_generate_serial(alphabet, alphabet_size, result, result_length, generator, generator_user_data); -} - -#endif -#pragma endregion +#pragma endregion // Serial Implementation #ifdef __cplusplus -#pragma GCC diagnostic pop } #endif // __cplusplus - -#endif // STRINGZILLA_H_ +#endif // STRINGZILLA_SORT_H_ From b835051c09a0ecfc420932de444f3c6839610764 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 18:50:17 +0000 Subject: [PATCH 040/751] Fix: Filter `types.h` file --- include/stringzilla/types.h | 7139 +++-------------------------------- 1 file changed, 547 insertions(+), 6592 deletions(-) diff --git a/include/stringzilla/types.h b/include/stringzilla/types.h index de7fbcac..a39620e6 100644 --- a/include/stringzilla/types.h +++ b/include/stringzilla/types.h @@ -1,31 +1,34 @@ /** - * @brief StringZilla is a collection of advanced string algorithms, designed to be used in Big Data applications. - * It is generally faster than LibC, and has a broader & cleaner interface, and targets modern x86 CPUs - * with AVX-512 and Arm NEON and older CPUs with SWAR and auto-vectorization. + * @brief Shared definitions for the StringZilla library. + * @file types.h + * @author Ash Vardanian * * Consider overriding the following macros to customize the library: * * - `SZ_DEBUG=0` - whether to enable debug assertions and logging. + * - `SZ_AVOID_LIBC=0` - whether to avoid including the standard C library headers. * - `SZ_DYNAMIC_DISPATCH=0` - whether to use runtime dispatching of the most advanced SIMD backend. * - `SZ_USE_MISALIGNED_LOADS=0` - whether to use misaligned loads on platforms that support them. * - `SZ_SWAR_THRESHOLD=24` - threshold for switching to SWAR backend over serial byte-level for-loops. - * - `SZ_USE_X86_AVX512=?` - whether to use AVX-512 instructions on x86_64. - * - `SZ_USE_X86_AVX2=?` - whether to use AVX2 instructions on x86_64. - * - `SZ_USE_ARM_NEON=?` - whether to use NEON instructions on ARM. - * - `SZ_USE_ARM_SVE=?` - whether to use SVE instructions on ARM. - * - * @see StringZilla: https://github.com/ashvardanian/StringZilla/blob/main/README.md - * @see LibC String: https://pubs.opengroup.org/onlinepubs/009695399/basedefs/string.h.html - * - * @file stringzilla.h - * @author Ash Vardanian + * - `SZ_USE_HASWELL=?` - whether to use AVX2 instructions on x86_64. + * - `SZ_USE_SKYLAKE=?` - whether to use AVX-512 instructions on x86_64. + * - `SZ_USE_ICE=?` - whether to use AVX-512 VBMI instructions on x86_64. + * - `SZ_USE_NEON=?` - whether to use NEON instructions on ARM. + * - `SZ_USE_SVE=?` - whether to use SVE and SVE2 instructions on ARM. */ -#ifndef STRINGZILLA_H_ -#define STRINGZILLA_H_ +#ifndef STRINGZILLA_TYPES_H_ +#define STRINGZILLA_TYPES_H_ -#define STRINGZILLA_VERSION_MAJOR 3 -#define STRINGZILLA_VERSION_MINOR 11 -#define STRINGZILLA_VERSION_PATCH 0 +/* + * Debugging and testing. 
+ */ +#ifndef SZ_DEBUG +#if defined(DEBUG) || defined(_DEBUG) // This means "Not using DEBUG information". +#define SZ_DEBUG (1) +#else +#define SZ_DEBUG (0) +#endif +#endif /** * @brief When set to 1, the library will include the following LibC headers: and . @@ -39,6 +42,16 @@ #define SZ_AVOID_LIBC (0) // true or false #endif +/** + * @brief Removes compile-time dispatching, and replaces it with runtime dispatching. + * So the `sz_find` function will invoke the most advanced backend supported by the CPU, + * that runs the program, rather than the most advanced backend supported by the CPU + * used to compile the library or the downstream application. + */ +#ifndef SZ_DYNAMIC_DISPATCH +#define SZ_DYNAMIC_DISPATCH (0) // true or false +#endif + /** * @brief A misaligned load can be - trying to fetch eight consecutive bytes from an address * that is not divisible by eight. On x86 enabled by default. On ARM it's not. @@ -54,27 +67,17 @@ #endif #endif -/** - * @brief Removes compile-time dispatching, and replaces it with runtime dispatching. - * So the `sz_find` function will invoke the most advanced backend supported by the CPU, - * that runs the program, rather than the most advanced backend supported by the CPU - * used to compile the library or the downstream application. - */ -#ifndef SZ_DYNAMIC_DISPATCH -#define SZ_DYNAMIC_DISPATCH (0) // true or false -#endif - /** * @brief Analogous to `size_t` and `std::size_t`, unsigned integer, identical to pointer size. * 64-bit on most platforms where pointers are 64-bit. * 32-bit on platforms where pointers are 32-bit. */ #if defined(__LP64__) || defined(_LP64) || defined(__x86_64__) || defined(_WIN64) -#define SZ_DETECT_64_BIT (1) +#define _SZ_IS_64_BIT (1) #define SZ_SIZE_MAX (0xFFFFFFFFFFFFFFFFull) // Largest unsigned integer that fits into 64 bits. #define SZ_SSIZE_MAX (0x7FFFFFFFFFFFFFFFull) // Largest signed integer that fits into 64 bits. #else -#define SZ_DETECT_64_BIT (0) +#define _SZ_IS_64_BIT (0) #define SZ_SIZE_MAX (0xFFFFFFFFu) // Largest unsigned integer that fits into 32 bits. #define SZ_SSIZE_MAX (0x7FFFFFFFu) // Largest signed integer that fits into 32 bits. #endif @@ -89,23 +92,12 @@ * In Python one can check `sys.byteorder == 'big'` in the `setup.py` script and pass the appropriate macro. * https://stackoverflow.com/a/27054190 */ -#ifndef SZ_DETECT_BIG_ENDIAN +#ifndef _SZ_IS_BIG_ENDIAN #if defined(__BYTE_ORDER) && __BYTE_ORDER == __BIG_ENDIAN || defined(__BIG_ENDIAN__) || defined(__ARMEB__) || \ defined(__THUMBEB__) || defined(__AARCH64EB__) || defined(_MIBSEB) || defined(__MIBSEB) || defined(__MIBSEB__) -#define SZ_DETECT_BIG_ENDIAN (1) //< It's a big-endian target architecture -#else -#define SZ_DETECT_BIG_ENDIAN (0) //< It's a little-endian target architecture -#endif -#endif - -/* - * Debugging and testing. - */ -#ifndef SZ_DEBUG -#if defined(DEBUG) || defined(_DEBUG) // This means "Not using DEBUG information". -#define SZ_DEBUG (1) +#define _SZ_IS_BIG_ENDIAN (1) //< It's a big-endian target architecture #else -#define SZ_DEBUG (0) +#define _SZ_IS_BIG_ENDIAN (0) //< It's a little-endian target architecture #endif #endif @@ -153,12 +145,93 @@ * @brief Alignment macro for 64-byte alignment. 
*/ #if defined(_MSC_VER) -#define SZ_ALIGN64 __declspec(align(64)) +#define _SZ_ALIGN64 __declspec(align(64)) #elif defined(__GNUC__) || defined(__clang__) -#define SZ_ALIGN64 __attribute__((aligned(64))) +#define _SZ_ALIGN64 __attribute__((aligned(64))) +#else +#define _SZ_ALIGN64 +#endif + +/** + * @brief Largest prime number that fits into 31 bits. + * @see https://mersenneforum.org/showthread.php?t=3471 + */ +#define SZ_U32_MAX_PRIME (2147483647u) + +/** + * @brief Largest prime number that fits into 64 bits. + * @see https://mersenneforum.org/showthread.php?t=3471 + * + * 2^64 = 18,446,744,073,709,551,616 + * this = 18,446,744,073,709,551,557 + * diff = 59 + */ +#define SZ_U64_MAX_PRIME (18446744073709551557ull) + +#if !SZ_AVOID_LIBC +#include // `size_t` +#include // `uint8_t` +#endif + +/* Compile-time hardware features detection. + * All of those can be controlled by the user. + */ +#ifndef SZ_USE_HASWELL +#ifdef __AVX2__ +#define SZ_USE_HASWELL 1 +#else +#define SZ_USE_HASWELL 0 +#endif +#endif + +#ifndef SZ_USE_SKYLAKE +#ifdef __AVX512F__ +#define SZ_USE_SKYLAKE 1 +#else +#define SZ_USE_SKYLAKE 0 +#endif +#endif + +#ifndef SZ_USE_ICE +#ifdef __AVX512BW__ +#define SZ_USE_ICE 1 +#else +#define SZ_USE_ICE 0 +#endif +#endif + +#ifndef SZ_USE_NEON +#ifdef __ARM_NEON +#define SZ_USE_NEON 1 +#else +#define SZ_USE_NEON 0 +#endif +#endif + +#ifndef SZ_USE_SVE +#ifdef __ARM_FEATURE_SVE +#define SZ_USE_SVE 1 #else -#define SZ_ALIGN64 +#define SZ_USE_SVE 0 +#endif +#endif + +/* Hardware-specific headers for different SIMD intrinsics and register wrappers. + */ +#if SZ_USE_HASWELL || SZ_USE_SKYLAKE || SZ_USE_ICE +#include +#endif // SZ_USE_X86... +#if SZ_USE_NEON +#if !defined(_MSC_VER) +#include +#endif +#include +#endif // SZ_USE_NEON +#if SZ_USE_SVE +#if !defined(_MSC_VER) +#include #endif +#endif // SZ_USE_SVE #ifdef __cplusplus extern "C" { @@ -169,8 +242,6 @@ extern "C" { * if that is allowed by the user. */ #if !SZ_AVOID_LIBC -#include // `size_t` -#include // `uint8_t` typedef int8_t sz_i8_t; // Always 8 bits typedef uint8_t sz_u8_t; // Always 8 bits typedef uint16_t sz_u16_t; // Always 16 bits @@ -210,13 +281,13 @@ typedef unsigned long long sz_u64_t; // Always 64 bits // > `long long` is also 64 bits // // Source: https://learn.microsoft.com/en-us/windows/win32/winprog64/abstract-data-models -#if SZ_DETECT_64_BIT +#if _SZ_IS_64_BIT typedef unsigned long long sz_size_t; // 64-bit. typedef long long sz_ssize_t; // 64-bit. #else typedef unsigned sz_size_t; // 32-bit. typedef unsigned sz_ssize_t; // 32-bit. -#endif // SZ_DETECT_64_BIT +#endif // _SZ_IS_64_BIT #endif // SZ_AVOID_LIBC @@ -231,8 +302,6 @@ typedef unsigned sz_ssize_t; // 32-bit. sz_static_assert(sizeof(sz_size_t) == sizeof(void *), sz_size_t_must_be_pointer_size); sz_static_assert(sizeof(sz_ssize_t) == sizeof(void *), sz_ssize_t_must_be_pointer_size); -#pragma region Public API - typedef char *sz_ptr_t; // A type alias for `char *` typedef char const *sz_cptr_t; // A type alias for `char const *` typedef sz_i8_t sz_error_cost_t; // Character mismatch cost for fuzzy matching functions @@ -242,6 +311,19 @@ typedef sz_u64_t sz_sorted_idx_t; // Index of a sorted string in a list of strin typedef enum { sz_false_k = 0, sz_true_k = 1 } sz_bool_t; // Only one relevant bit typedef enum { sz_less_k = -1, sz_equal_k = 0, sz_greater_k = 1 } sz_ordering_t; // Only three possible states: <=> +/** + * @brief Describes the length of a UTF8 @b rune / character / codepoint in bytes. 
+ */ +typedef enum { + sz_utf8_invalid_k = 0, //!< Invalid UTF8 character. + sz_utf8_rune_1byte_k = 1, //!< 1-byte UTF8 character. + sz_utf8_rune_2bytes_k = 2, //!< 2-byte UTF8 character. + sz_utf8_rune_3bytes_k = 3, //!< 3-byte UTF8 character. + sz_utf8_rune_4bytes_k = 4, //!< 4-byte UTF8 character. +} sz_rune_length_t; + +typedef sz_u32_t sz_rune_t; + /** * @brief Tiny string-view structure. It's POD type, unlike the `std::string_view`. */ @@ -250,32 +332,7 @@ typedef struct sz_string_view_t { sz_size_t length; } sz_string_view_t; -/** - * @brief Enumeration of SIMD capabilities of the target architecture. - * Used to introspect the supported functionality of the dynamic library. - */ -typedef enum sz_capability_t { - sz_cap_serial_k = 1, /// Serial (non-SIMD) capability - sz_cap_any_k = 0x7FFFFFFF, /// Mask representing any capability - - sz_cap_arm_neon_k = 1 << 10, /// ARM NEON capability - sz_cap_arm_sve_k = 1 << 11, /// ARM SVE capability TODO: Not yet supported or used - sz_cap_arm_sve2_k = 1 << 12, - sz_cap_arm_sve2p1_k = 1 << 13, - sz_cap_x86_avx2_k = 1 << 20, /// x86 AVX2 capability - sz_cap_x86_avx512f_k = 1 << 21, /// x86 AVX512 F capability - sz_cap_x86_avx512bw_k = 1 << 22, /// x86 AVX512 BW instruction capability - sz_cap_x86_avx512vl_k = 1 << 23, /// x86 AVX512 VL instruction capability - sz_cap_x86_avx512vbmi_k = 1 << 24, /// x86 AVX512 VBMI instruction capability - sz_cap_x86_gfni_k = 1 << 25, /// x86 AVX512 GFNI instruction capability - -} sz_capability_t; - -/** - * @brief Function to determine the SIMD capabilities of the current machine @b only at @b runtime. - * @return A bitmask of the SIMD capabilities represented as a `sz_capability_t` enum value. - */ -SZ_DYNAMIC sz_capability_t sz_capabilities(void); +#pragma region Character Sets /** * @brief Bit-set structure for 256 possible byte values. Useful for filtering and search. @@ -318,6 +375,10 @@ SZ_PUBLIC void sz_charset_invert(sz_charset_t *s) { s->_u64s[2] ^= 0xFFFFFFFFFFFFFFFFull, s->_u64s[3] ^= 0xFFFFFFFFFFFFFFFFull; } +#pragma endregion + +#pragma region Memory Management + typedef void *(*sz_memory_allocate_t)(sz_size_t, void *); typedef void (*sz_memory_free_t)(void *, sz_size_t, void *); typedef sz_u64_t (*sz_random_generator_t)(void *); @@ -352,65 +413,9 @@ SZ_PUBLIC void sz_memory_allocator_init_default(sz_memory_allocator_t *alloc); */ SZ_PUBLIC void sz_memory_allocator_init_fixed(sz_memory_allocator_t *alloc, void *buffer, sz_size_t length); -/** - * @brief The number of bytes a stack-allocated string can hold, including the SZ_NULL termination character. - * ! This can't be changed from outside. Don't use the `#error` as it may already be included and set. - */ -#ifdef SZ_STRING_INTERNAL_SPACE -#undef SZ_STRING_INTERNAL_SPACE -#endif -#define SZ_STRING_INTERNAL_SPACE (sizeof(sz_size_t) * 3 - 1) // 3 pointers minus one byte for an 8-bit length - -/** - * @brief Tiny memory-owning string structure with a Small String Optimization (SSO). - * Differs in layout from Folly, Clang, GCC, and probably most other implementations. - * It's designed to avoid any branches on read-only operations, and can store up - * to 22 characters on stack on 64-bit machines, followed by the SZ_NULL-termination character. - * - * @section Changing Length - * - * One nice thing about this design, is that you can, in many cases, change the length of the string - * without any branches, invoking a `+=` or `-=` on the 64-bit `length` field. If the string is on heap, - * the solution is obvious. 
If it's on stack, inplace decrement wouldn't affect the top bytes of the string, - * only changing the last byte containing the length. - */ -typedef union sz_string_t { - -#if !SZ_DETECT_BIG_ENDIAN - - struct external { - sz_ptr_t start; - sz_size_t length; - sz_size_t space; - sz_size_t padding; - } external; - - struct internal { - sz_ptr_t start; - sz_u8_t length; - char chars[SZ_STRING_INTERNAL_SPACE]; - } internal; - -#else - - struct external { - sz_ptr_t start; - sz_size_t space; - sz_size_t padding; - sz_size_t length; - } external; - - struct internal { - sz_ptr_t start; - char chars[SZ_STRING_INTERNAL_SPACE]; - sz_u8_t length; - } internal; - -#endif - - sz_size_t words[4]; +#pragma endregion -} sz_string_t; +#pragma region API Signature Types typedef sz_u64_t (*sz_hash_t)(sz_cptr_t, sz_size_t); typedef sz_u64_t (*sz_checksum_t)(sz_cptr_t, sz_size_t); @@ -418,667 +423,184 @@ typedef sz_bool_t (*sz_equal_t)(sz_cptr_t, sz_cptr_t, sz_size_t); typedef sz_ordering_t (*sz_order_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); typedef void (*sz_to_converter_t)(sz_cptr_t, sz_size_t, sz_ptr_t); -/** - * @brief Computes the 64-bit check-sum of bytes in a string. - * Similar to `std::ranges::accumulate`. - * - * @param text String to aggregate. - * @param length Number of bytes in the text. - * @return 64-bit unsigned value. - */ -SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length); +typedef void (*sz_look_up_transform_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_ptr_t); -/** @copydoc sz_checksum */ -SZ_PUBLIC sz_u64_t sz_checksum_serial(sz_cptr_t text, sz_size_t length); +typedef void (*sz_move_t)(sz_ptr_t, sz_cptr_t, sz_size_t); -/** - * @brief Computes the 64-bit unsigned hash of a string. Fairly fast for short strings, - * simple implementation, and supports rolling computation, reused in other APIs. - * Similar to `std::hash` in C++. - * - * @param text String to hash. - * @param length Number of bytes in the text. - * @return 64-bit hash value. - * - * @see sz_hashes, sz_hashes_fingerprint, sz_hashes_intersection - */ -SZ_PUBLIC sz_u64_t sz_hash(sz_cptr_t text, sz_size_t length); +typedef void (*sz_fill_t)(sz_ptr_t, sz_size_t, sz_u8_t); -/** @copydoc sz_hash */ -SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t text, sz_size_t length); +typedef sz_cptr_t (*sz_find_byte_t)(sz_cptr_t, sz_size_t, sz_cptr_t); +typedef sz_cptr_t (*sz_find_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); +typedef sz_cptr_t (*sz_find_set_t)(sz_cptr_t, sz_size_t, sz_charset_t const *); -/** - * @brief Checks if two string are equal. - * Similar to `memcmp(a, b, length) == 0` in LibC and `a == b` in STL. - * - * The implementation of this function is very similar to `sz_order`, but the usage patterns are different. - * This function is more often used in parsing, while `sz_order` is often used in sorting. - * It works best on platforms with cheap - * - * @param a First string to compare. - * @param b Second string to compare. - * @param length Number of bytes in both strings. - * @return 1 if strings match, 0 otherwise. - */ -SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length); +typedef sz_size_t (*sz_hamming_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t); -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_serial(sz_cptr_t a, sz_cptr_t b, sz_size_t length); +typedef sz_size_t (*sz_edit_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t, sz_memory_allocator_t *); -/** - * @brief Estimates the relative order of two strings. 
Equivalent to `memcmp(a, b, length)` in LibC. - * Can be used on different length strings. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * @return Negative if (a < b), positive if (a > b), zero if they are equal. - */ -SZ_DYNAMIC sz_ordering_t sz_order(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); +typedef sz_ssize_t (*sz_alignment_score_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_error_cost_t const *, + sz_error_cost_t, sz_memory_allocator_t *); -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); +typedef void (*sz_hash_callback_t)(sz_cptr_t, sz_size_t, sz_u64_t, void *user); -/** - * @brief Look Up Table @b (LUT) transformation of a string. Equivalent to `for (char & c : text) c = lut[c]`. - * - * Can be used to implement some form of string normalization, partially masking punctuation marks, - * or converting between different character sets, like uppercase or lowercase. Surprisingly, also has - * broad implications in image processing, where image channel transformations are often done using LUTs. - * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param lut Look Up Table to apply. Must be exactly @b 256 bytes long. - * @param result Output string, can point to the same address as ::text. - */ -SZ_DYNAMIC void sz_look_up_transform(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result); +typedef void (*sz_hashes_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_size_t, sz_hash_callback_t, void *); -typedef void (*sz_look_up_transform_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_ptr_t); +typedef void (*sz_hashes_fingerprint_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_ptr_t, sz_size_t); -/** @copydoc sz_look_up_transform */ -SZ_PUBLIC void sz_look_up_transform_serial(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result); +typedef sz_size_t (*sz_hashes_intersection_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_cptr_t, sz_size_t); -/** - * @brief Equivalent to `for (char & c : text) c = tolower(c)`. - * - * ASCII characters [A, Z] map to decimals [65, 90], and [a, z] map to [97, 122]. - * So there are 26 english letters, shifted by 32 values, meaning that a conversion - * can be done by flipping the 5th bit each inappropriate character byte. This, however, - * breaks for extended ASCII, so a different solution is needed. - * http://0x80.pl/notesen/2016-01-06-swar-swap-case.html - * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param result Output string, can point to the same address as ::text. - */ -SZ_PUBLIC void sz_tolower(sz_cptr_t text, sz_size_t length, sz_ptr_t result); +#pragma endregion -/** - * @brief Equivalent to `for (char & c : text) c = toupper(c)`. - * - * ASCII characters [A, Z] map to decimals [65, 90], and [a, z] map to [97, 122]. - * So there are 26 english letters, shifted by 32 values, meaning that a conversion - * can be done by flipping the 5th bit each inappropriate character byte. This, however, - * breaks for extended ASCII, so a different solution is needed. - * http://0x80.pl/notesen/2016-01-06-swar-swap-case.html - * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param result Output string, can point to the same address as ::text. 
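 *
 * A hedged illustration of the bit trick described above, not necessarily the exact kernel used here:
 *
 * @code{.c}
 *     char c = 'a';               // 0x61
 *     char u = (char)(c & ~0x20); // 0x41 == 'A': clearing the 5th bit uppercases an ASCII letter
 * @endcode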
- */ -SZ_PUBLIC void sz_toupper(sz_cptr_t text, sz_size_t length, sz_ptr_t result); +#pragma region Helper Structures /** - * @brief Equivalent to `for (char & c : text) c = toascii(c)`. - * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param result Output string, can point to the same address as ::text. + * @brief Helper structure to simplify work with 16-bit words. + * @see sz_u16_load */ -SZ_PUBLIC void sz_toascii(sz_cptr_t text, sz_size_t length, sz_ptr_t result); +typedef union sz_u16_vec_t { + sz_u16_t u16; + sz_u8_t u8s[2]; +} sz_u16_vec_t; /** - * @brief Checks if all characters in the range are valid ASCII characters. - * - * @param text String to be analyzed. - * @param length Number of bytes in the string. - * @return Whether all characters are valid ASCII characters. + * @brief Helper structure to simplify work with 32-bit words. + * @see sz_u32_load */ -SZ_PUBLIC sz_bool_t sz_isascii(sz_cptr_t text, sz_size_t length); +typedef union sz_u32_vec_t { + sz_u32_t u32; + sz_u16_t u16s[2]; + sz_u8_t u8s[4]; +} sz_u32_vec_t; /** - * @brief Generates a random string for a given alphabet, avoiding integer division and modulo operations. - * Similar to `text[i] = alphabet[rand() % cardinality]`. - * - * The modulo operation is expensive, and should be avoided in performance-critical code. - * We avoid it using small lookup tables and replacing it with a multiplication and shifts, similar to `libdivide`. - * Alternative algorithms would include: - * - Montgomery form: https://en.algorithmica.org/hpc/number-theory/montgomery/ - * - Barret reduction: https://www.nayuki.io/page/barrett-reduction-algorithm - * - Lemire's trick: https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/ - * - * @param alphabet Set of characters to sample from. - * @param cardinality Number of characters to sample from. - * @param text Output string, can point to the same address as ::text. - * @param generate Callback producing random numbers given the generator state. - * @param generator Generator state, can be a pointer to a seed, or a pointer to a random number generator. + * @brief Helper structure to simplify work with 64-bit words. + * @see sz_u64_load */ -SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length, - sz_random_generator_t generate, void *generator); - -/** @copydoc sz_generate */ -SZ_PUBLIC void sz_generate_serial(sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length, - sz_random_generator_t generate, void *generator); +typedef union sz_u64_vec_t { + sz_u64_t u64; + sz_u32_t u32s[2]; + sz_u16_t u16s[4]; + sz_u8_t u8s[8]; +} sz_u64_vec_t; /** - * @brief Similar to `memcpy`, copies contents of one string into another. - * The behavior is undefined if the strings overlap. - * - * @param target String to copy into. - * @param length Number of bytes to copy. - * @param source String to copy from. + * @brief Helper structure to simplify work with @b 128-bit registers. + * It can help view the contents as 8-bit, 16-bit, 32-bit, or 64-bit integers, + * as well as 1x XMM register. 
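 *
 * A minimal sketch of the intended use, assuming an x86 target with SSE2 and an illustrative `text` pointer:
 *
 * @code{.c}
 *     sz_u128_vec_t text_vec;
 *     text_vec.xmm = _mm_loadu_si128((__m128i const *)text); // one 16-byte SIMD load
 *     sz_u8_t first_byte = text_vec.u8s[0];                  // reinterpret the same bytes scalar-wise
 * @endcode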
*/ -SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length); - -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length); +typedef union sz_u128_vec_t { +#if SZ_USE_HASWELL + __m128i xmm; +#endif +#if SZ_USE_NEON + uint8x16_t u8x16; + uint16x8_t u16x8; + uint32x4_t u32x4; + uint64x2_t u64x2; +#endif + sz_u64_t u64s[2]; + sz_u32_t u32s[4]; + sz_u16_t u16s[8]; + sz_u8_t u8s[16]; +} sz_u128_vec_t; /** - * @brief Similar to `memmove`, copies (moves) contents of one string into another. - * Unlike `sz_copy`, allows overlapping strings as arguments. - * - * @param target String to copy into. - * @param length Number of bytes to copy. - * @param source String to copy from. + * @brief Helper structure to simplify work with @b 256-bit registers. + * It can help view the contents as 8-bit, 16-bit, 32-bit, or 64-bit integers, + * as well as 2x XMM registers or 1x YMM register. */ -SZ_DYNAMIC void sz_move(sz_ptr_t target, sz_cptr_t source, sz_size_t length); - -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length); - -typedef void (*sz_move_t)(sz_ptr_t, sz_cptr_t, sz_size_t); +typedef union sz_u256_vec_t { +#if SZ_USE_HASWELL + __m256i ymm; + __m128i xmms[2]; +#endif + sz_u64_t u64s[4]; + sz_u32_t u32s[8]; + sz_u16_t u16s[16]; + sz_u8_t u8s[32]; +} sz_u256_vec_t; /** - * @brief Similar to `memset`, fills a string with a given value. - * - * @param target String to fill. - * @param length Number of bytes to fill. - * @param value Value to fill with. + * @brief Helper structure to simplify work with @b 512-bit registers. + * It can help view the contents as 8-bit, 16-bit, 32-bit, or 64-bit integers, + * as well as 4x XMM registers or 2x YMM registers or 1x ZMM register. */ -SZ_DYNAMIC void sz_fill(sz_ptr_t target, sz_size_t length, sz_u8_t value); +typedef union sz_u512_vec_t { +#if SZ_USE_ICE + __m512i zmm; +#endif +#if SZ_USE_HASWELL + __m256i ymms[2]; + __m128i xmms[4]; +#endif + sz_u64_t u64s[8]; + sz_i64_t i64s[8]; + sz_u32_t u32s[16]; + sz_i32_t i32s[16]; + sz_u16_t u16s[32]; + sz_u8_t u8s[64]; +} sz_u512_vec_t; -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_serial(sz_ptr_t target, sz_size_t length, sz_u8_t value); +#pragma endregion -typedef void (*sz_fill_t)(sz_ptr_t, sz_size_t, sz_u8_t); +#pragma region UTF8 /** - * @brief Initializes a string class instance to an empty value. + * @brief Extracts just one UTF8 codepoint from a UTF8 string into a 32-bit unsigned integer. */ -SZ_PUBLIC void sz_string_init(sz_string_t *string); - -/** - * @brief Convenience function checking if the provided string is stored inside of the ::string instance itself, - * alternative being - allocated in a remote region of the heap. - */ -SZ_PUBLIC sz_bool_t sz_string_is_on_stack(sz_string_t const *string); - -/** - * @brief Unpacks the opaque instance of a string class into its components. - * Recommended to use only in read-only operations. - * - * @param string String to unpack. - * @param start Pointer to the start of the string. - * @param length Number of bytes in the string, before the SZ_NULL character. - * @param space Number of bytes allocated for the string (heap or stack), including the SZ_NULL character. - * @param is_external Whether the string is allocated on the heap externally, or fits withing ::string instance. 
- */ -SZ_PUBLIC void sz_string_unpack(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length, sz_size_t *space, - sz_bool_t *is_external); - -/** - * @brief Unpacks only the start and length of the string. - * Recommended to use only in read-only operations. - * - * @param string String to unpack. - * @param start Pointer to the start of the string. - * @param length Number of bytes in the string, before the SZ_NULL character. - */ -SZ_PUBLIC void sz_string_range(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length); - -/** - * @brief Constructs a string of a given ::length with noisy contents. - * Use the returned character pointer to populate the string. - * - * @param string String to initialize. - * @param length Number of bytes in the string, before the SZ_NULL character. - * @param allocator Memory allocator to use for the allocation. - * @return SZ_NULL if the operation failed, pointer to the start of the string otherwise. - */ -SZ_PUBLIC sz_ptr_t sz_string_init_length(sz_string_t *string, sz_size_t length, sz_memory_allocator_t *allocator); - -/** - * @brief Doesn't change the contents or the length of the string, but grows the available memory capacity. - * This is beneficial, if several insertions are expected, and we want to minimize allocations. - * - * @param string String to grow. - * @param new_capacity The number of characters to reserve space for, including existing ones. - * @param allocator Memory allocator to use for the allocation. - * @return SZ_NULL if the operation failed, pointer to the new start of the string otherwise. - */ -SZ_PUBLIC sz_ptr_t sz_string_reserve(sz_string_t *string, sz_size_t new_capacity, sz_memory_allocator_t *allocator); - -/** - * @brief Grows the string by adding an uninitialized region of ::added_length at the given ::offset. - * Would often be used in conjunction with one or more `sz_copy` calls to populate the allocated region. - * Similar to `sz_string_reserve`, but changes the length of the ::string. - * - * @param string String to grow. - * @param offset Offset of the first byte to reserve space for. - * If provided offset is larger than the length, it will be capped. - * @param added_length The number of new characters to reserve space for. - * @param allocator Memory allocator to use for the allocation. - * @return SZ_NULL if the operation failed, pointer to the new start of the string otherwise. - */ -SZ_PUBLIC sz_ptr_t sz_string_expand(sz_string_t *string, sz_size_t offset, sz_size_t added_length, - sz_memory_allocator_t *allocator); - -/** - * @brief Removes a range from a string. Changes the length, but not the capacity. - * Performs no allocations or deallocations and can't fail. - * - * @param string String to clean. - * @param offset Offset of the first byte to remove. - * @param length Number of bytes to remove. Out-of-bound ranges will be capped. - * @return Number of bytes removed. - */ -SZ_PUBLIC sz_size_t sz_string_erase(sz_string_t *string, sz_size_t offset, sz_size_t length); - -/** - * @brief Shrinks the string to fit the current length, if it's allocated on the heap. - * It's the reverse operation of ::sz_string_reserve. - * - * @param string String to shrink. - * @param allocator Memory allocator to use for the allocation. - * @return Whether the operation was successful. The only failures can come from the allocator. - * On failure, the string will remain unchanged. 
- */ -SZ_PUBLIC sz_ptr_t sz_string_shrink_to_fit(sz_string_t *string, sz_memory_allocator_t *allocator); - -/** - * @brief Frees the string, if it's allocated on the heap. - * If the string is on the stack, the function clears/resets the state. - */ -SZ_PUBLIC void sz_string_free(sz_string_t *string, sz_memory_allocator_t *allocator); - -#pragma endregion - -#pragma region Fast Substring Search API - -typedef sz_cptr_t (*sz_find_byte_t)(sz_cptr_t, sz_size_t, sz_cptr_t); -typedef sz_cptr_t (*sz_find_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); -typedef sz_cptr_t (*sz_find_set_t)(sz_cptr_t, sz_size_t, sz_charset_t const *); - -/** - * @brief Locates first matching byte in a string. Equivalent to `memchr(haystack, *needle, h_length)` in LibC. - * - * X86_64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/x86_64/memchr.S - * Aarch64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/aarch64/memchr.S - * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - single-byte substring to find. - * @return Address of the first match. - */ -SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); - -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); - -/** - * @brief Locates last matching byte in a string. Equivalent to `memrchr(haystack, *needle, h_length)` in LibC. - * - * X86_64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/x86_64/memrchr.S - * Aarch64 implementation: missing - * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - single-byte substring to find. - * @return Address of the last match. - */ -SZ_DYNAMIC sz_cptr_t sz_rfind_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); - -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); - -/** - * @brief Locates first matching substring. - * Equivalent to `memmem(haystack, h_length, needle, n_length)` in LibC. - * Similar to `strstr(haystack, needle)` in LibC, but requires known length. - * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - substring to find. - * @param n_length Number of bytes in the needle. - * @return Address of the first match. - */ -SZ_DYNAMIC sz_cptr_t sz_find(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); - -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); - -/** - * @brief Locates the last matching substring. - * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - substring to find. - * @param n_length Number of bytes in the needle. - * @return Address of the last match. - */ -SZ_DYNAMIC sz_cptr_t sz_rfind(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); - -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); - -/** - * @brief Finds the first character present from the ::set, present in ::text. - * Equivalent to `strspn(text, accepted)` and `strcspn(text, rejected)` in LibC. 
- * May have identical implementation and performance to ::sz_rfind_charset. - * - * Useful for parsing, when we want to skip a set of characters. Examples: - * * 6 whitespaces: " \t\n\r\v\f". - * * 16 digits forming a float number: "0123456789,.eE+-". - * * 5 HTML reserved characters: "\"'&<>", of which "<>" can be useful for parsing. - * * 2 JSON string special characters useful to locate the end of the string: "\"\\". - * - * @param text String to be scanned. - * @param set Set of relevant characters. - * @return Pointer to the first matching character from ::set. - */ -SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); - -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); - -/** - * @brief Finds the last character present from the ::set, present in ::text. - * Equivalent to `strspn(text, accepted)` and `strcspn(text, rejected)` in LibC. - * May have identical implementation and performance to ::sz_find_charset. - * - * Useful for parsing, when we want to skip a set of characters. Examples: - * * 6 whitespaces: " \t\n\r\v\f". - * * 16 digits forming a float number: "0123456789,.eE+-". - * * 5 HTML reserved characters: "\"'&<>", of which "<>" can be useful for parsing. - * * 2 JSON string special characters useful to locate the end of the string: "\"\\". - * - * @param text String to be scanned. - * @param set Set of relevant characters. - * @return Pointer to the last matching character from ::set. - */ -SZ_DYNAMIC sz_cptr_t sz_rfind_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); - -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); - -#pragma endregion - -#pragma region String Similarity Measures API - -/** - * @brief Computes the Hamming distance between two strings - number of not matching characters. - * Difference in length is is counted as a mismatch. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * - * @param bound Upper bound on the distance, that allows us to exit early. - * If zero is passed, the maximum possible distance will be equal to the length of the longer input. - * @return Unsigned integer for the distance, the `bound` if was exceeded. - * - * @see sz_hamming_distance_utf8 - * @see https://en.wikipedia.org/wiki/Hamming_distance - */ -SZ_DYNAMIC sz_size_t sz_hamming_distance( // - sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound); - -/** @copydoc sz_hamming_distance */ -SZ_PUBLIC sz_size_t sz_hamming_distance_serial( // - sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound); - -/** - * @brief Computes the Hamming distance between two @b UTF8 strings - number of not matching characters. - * Difference in length is is counted as a mismatch. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * - * @param bound Upper bound on the distance, that allows us to exit early. - * If zero is passed, the maximum possible distance will be equal to the length of the longer input. - * @return Unsigned integer for the distance, the `bound` if was exceeded. 
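 *
 * For example, "karolin" and "kathrin" disagree in three positions, so per the contract above:
 *
 * @code{.c}
 *     sz_size_t dist = sz_hamming_distance("karolin", 7, "kathrin", 7, 0); // == 3, no early exit
 * @endcode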
- * - * @see sz_hamming_distance - * @see https://en.wikipedia.org/wiki/Hamming_distance - */ -SZ_DYNAMIC sz_size_t sz_hamming_distance_utf8(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, - sz_size_t bound); - -/** @copydoc sz_hamming_distance_utf8 */ -SZ_PUBLIC sz_size_t sz_hamming_distance_utf8_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, - sz_size_t bound); - -typedef sz_size_t (*sz_hamming_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t); - -/** - * @brief Computes the Levenshtein edit-distance between two strings using the Wagner-Fisher algorithm. - * Similar to the Needleman-Wunsch alignment algorithm. Often used in fuzzy string matching. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * - * @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated, - * so the memory usage is linear in relation to ::a_length and ::b_length. - * If SZ_NULL is passed, will initialize to the systems default `malloc`. - * @param bound Exclusive upper bound on the distance, that allows us to exit early. - * Pass `SZ_SIZE_MAX` or any value greater than `(max(a_length, b_length))` to ignore. - * Pass zero to check if the strings are equal. - * @return Unsigned integer for the edit distance. Zero means the strings are equal. - * Returns the `bound` if it was exceeded or `SZ_SIZE_MAX` if the memory allocation failed. - * - * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default - * @see https://en.wikipedia.org/wiki/Levenshtein_distance - */ -SZ_DYNAMIC sz_size_t sz_edit_distance(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); - -/** @copydoc sz_edit_distance */ -SZ_PUBLIC sz_size_t sz_edit_distance_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); - -/** - * @brief Computes the Levenshtein edit-distance between two @b UTF8 strings. - * Unlike `sz_edit_distance`, reports the distance in Unicode codepoints, and not in bytes. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * - * @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated, - * so the memory usage is linear in relation to ::a_length and ::b_length. - * If SZ_NULL is passed, will initialize to the systems default `malloc`. - * @param bound Upper bound on the distance, that allows us to exit early. - * If zero is passed, the maximum possible distance will be equal to the length of the longer input. - * @return Unsigned integer for edit distance, the `bound` if was exceeded or `SZ_SIZE_MAX` - * if the memory allocation failed. 
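 *
 * For example, turning "kitten" into "sitting" takes two substitutions and one insertion, so for this
 * ASCII-only pair both the byte-level and the codepoint-level distance are 3.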
- * - * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default, sz_edit_distance - * @see https://en.wikipedia.org/wiki/Levenshtein_distance - */ -SZ_DYNAMIC sz_size_t sz_edit_distance_utf8(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); - -typedef sz_size_t (*sz_edit_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t, sz_memory_allocator_t *); - -/** @copydoc sz_edit_distance_utf8 */ -SZ_PUBLIC sz_size_t sz_edit_distance_utf8_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); - -/** - * @brief Computes Needleman–Wunsch alignment score for two string. Often used in bioinformatics and cheminformatics. - * Similar to the Levenshtein edit-distance, parameterized for gap and substitution penalties. - * - * Not commutative in the general case, as the order of the strings matters, as `sz_alignment_score(a, b)` may - * not be equal to `sz_alignment_score(b, a)`. Becomes @b commutative, if the substitution costs are symmetric. - * Equivalent to the negative Levenshtein distance, if: `gap == -1` and `subs[i][j] == (i == j ? 0: -1)`. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * @param gap Penalty cost for gaps - insertions and removals. - * @param subs Substitution costs matrix with 256 x 256 values for all pairs of characters. - * - * @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated, - * so the memory usage is linear in relation to ::a_length and ::b_length. - * If SZ_NULL is passed, will initialize to the systems default `malloc`. - * @return Signed similarity score. Can be negative, depending on the substitution costs. - * If the memory allocation fails, the function returns `SZ_SSIZE_MAX`. - * - * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default - * @see https://en.wikipedia.org/wiki/Needleman%E2%80%93Wunsch_algorithm - */ -SZ_DYNAMIC sz_ssize_t sz_alignment_score(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, // - sz_memory_allocator_t *alloc); - -/** @copydoc sz_alignment_score */ -SZ_PUBLIC sz_ssize_t sz_alignment_score_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, // - sz_memory_allocator_t *alloc); - -typedef sz_ssize_t (*sz_alignment_score_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_error_cost_t const *, - sz_error_cost_t, sz_memory_allocator_t *); - -typedef void (*sz_hash_callback_t)(sz_cptr_t, sz_size_t, sz_u64_t, void *user); - -/** - * @brief Computes the Karp-Rabin rolling hashes of a string supplying them to the provided `callback`. - * Can be used for similarity scores, search, ranking, etc. - * - * Rabin-Karp-like rolling hashes can have very high-level of collisions and depend - * on the choice of bases and the prime number. That's why, often two hashes from the same - * family are used with different bases. - * - * 1. Kernighan and Ritchie's function uses 31, a prime close to the size of English alphabet. - * 2. To be friendlier to byte-arrays and UTF8, we use 257 for the second function. - * - * Choosing the right ::window_length is task- and domain-dependant. 
For example, most English words are - * between 3 and 7 characters long, so a window of 4 bytes would be a good choice. For DNA sequences, - * the ::window_length might be a multiple of 3, as the codons are 3 (nucleotides) bytes long. - * With such minimalistic alphabets of just four characters (AGCT) longer windows might be needed. - * For protein sequences the alphabet is 20 characters long, so the window can be shorter, than for DNAs. - * - * @param text String to hash. - * @param length Number of bytes in the string. - * @param window_length Length of the rolling window in bytes. - * @param window_step Step of reported hashes. @b Must be power of two. Should be smaller than `window_length`. - * @param callback Function receiving the start & length of a substring, the hash, and the `callback_handle`. - * @param callback_handle Optional user-provided pointer to be passed to the `callback`. - * @see sz_hashes_fingerprint, sz_hashes_intersection - */ -SZ_DYNAMIC void sz_hashes(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // - sz_hash_callback_t callback, void *callback_handle); - -/** @copydoc sz_hashes */ -SZ_PUBLIC void sz_hashes_serial(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // - sz_hash_callback_t callback, void *callback_handle); - -typedef void (*sz_hashes_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_size_t, sz_hash_callback_t, void *); - -/** - * @brief Computes the Karp-Rabin rolling hashes of a string outputting a binary fingerprint. - * Such fingerprints can be compared with Hamming or Jaccard (Tanimoto) distance for similarity. - * - * The algorithm doesn't clear the fingerprint buffer on start, so it can be invoked multiple times - * to produce a fingerprint of a longer string, by passing the previous fingerprint as the ::fingerprint. - * It can also be reused to produce multi-resolution fingerprints by changing the ::window_length - * and calling the same function multiple times for the same input ::text. - * - * Processes large strings in parts to maximize the cache utilization, using a small on-stack buffer, - * avoiding cache-coherency penalties of remote on-heap buffers. - * - * @param text String to hash. - * @param length Number of bytes in the string. - * @param fingerprint Output fingerprint buffer. - * @param fingerprint_bytes Number of bytes in the fingerprint buffer. - * @param window_length Length of the rolling window in bytes. - * @see sz_hashes, sz_hashes_intersection - */ -SZ_PUBLIC void sz_hashes_fingerprint( // - sz_cptr_t text, sz_size_t length, sz_size_t window_length, // - sz_ptr_t fingerprint, sz_size_t fingerprint_bytes); - -typedef void (*sz_hashes_fingerprint_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_ptr_t, sz_size_t); - -/** - * @brief Given a hash-fingerprint of a textual document, computes the number of intersecting hashes - * of the incoming document. Can be used for document scoring and search. - * - * Processes large strings in parts to maximize the cache utilization, using a small on-stack buffer, - * avoiding cache-coherency penalties of remote on-heap buffers. - * - * @param text Input document. - * @param length Number of bytes in the input document. - * @param fingerprint Reference document fingerprint. - * @param fingerprint_bytes Number of bytes in the reference documents fingerprint. - * @param window_length Length of the rolling window in bytes. 
- * @see sz_hashes, sz_hashes_fingerprint - */ -SZ_PUBLIC sz_size_t sz_hashes_intersection( // - sz_cptr_t text, sz_size_t length, sz_size_t window_length, // - sz_cptr_t fingerprint, sz_size_t fingerprint_bytes); - -typedef sz_size_t (*sz_hashes_intersection_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_cptr_t, sz_size_t); - -#pragma endregion - -#pragma region Convenience API - -/** - * @brief Finds the first character in the haystack, that is present in the needle. - * Convenience function, reused across different language bindings. - * @see sz_find_charset - */ -SZ_DYNAMIC sz_cptr_t sz_find_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); - -/** - * @brief Finds the first character in the haystack, that is @b not present in the needle. - * Convenience function, reused across different language bindings. - * @see sz_find_charset - */ -SZ_DYNAMIC sz_cptr_t sz_find_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); +SZ_INTERNAL void _sz_extract_utf8_rune(sz_cptr_t utf8, sz_rune_t *code, sz_rune_length_t *code_length) { + sz_u8_t const *current = (sz_u8_t const *)utf8; + sz_u8_t leading_byte = *current++; + sz_rune_t ch; + sz_rune_length_t ch_length; -/** - * @brief Finds the last character in the haystack, that is present in the needle. - * Convenience function, reused across different language bindings. - * @see sz_find_charset - */ -SZ_DYNAMIC sz_cptr_t sz_rfind_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); + // TODO: This can be made entirely branchless using 32-bit SWAR. + if (leading_byte < 0x80) { + // Single-byte rune (0xxxxxxx) + ch = leading_byte; + ch_length = sz_utf8_rune_1byte_k; + } + else if ((leading_byte & 0xE0) == 0xC0) { + // Two-byte rune (110xxxxx 10xxxxxx) + ch = (leading_byte & 0x1F) << 6; + ch |= (*current++ & 0x3F); + ch_length = sz_utf8_rune_2bytes_k; + } + else if ((leading_byte & 0xF0) == 0xE0) { + // Three-byte rune (1110xxxx 10xxxxxx 10xxxxxx) + ch = (leading_byte & 0x0F) << 12; + ch |= (*current++ & 0x3F) << 6; + ch |= (*current++ & 0x3F); + ch_length = sz_utf8_rune_3bytes_k; + } + else if ((leading_byte & 0xF8) == 0xF0) { + // Four-byte rune (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx) + ch = (leading_byte & 0x07) << 18; + ch |= (*current++ & 0x3F) << 12; + ch |= (*current++ & 0x3F) << 6; + ch |= (*current++ & 0x3F); + ch_length = sz_utf8_rune_4bytes_k; + } + else { + // Invalid UTF8 rune. + ch = 0; + ch_length = sz_utf8_invalid_k; + } + *code = ch; + *code_length = ch_length; +} /** - * @brief Finds the last character in the haystack, that is @b not present in the needle. - * Convenience function, reused across different language bindings. - * @see sz_find_charset + * @brief Exports a UTF8 string into a UTF32 buffer. + * ! The result is undefined id the UTF8 string is corrupted. + * @return The length in the number of codepoints. 
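 *
 * A minimal usage sketch; the buffer size is an assumption that is always safe, since the UTF32 output
 * never holds more codepoints than the UTF8 input has bytes:
 *
 * @code{.c}
 *     char const *text = "héllo";                                 // 6 bytes, 5 codepoints
 *     sz_rune_t runes[6];                                         // one slot per input byte suffices
 *     sz_size_t count = _sz_export_utf8_to_utf32(text, 6, runes); // count == 5
 * @endcode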
*/ -SZ_DYNAMIC sz_cptr_t sz_rfind_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); +SZ_INTERNAL sz_size_t _sz_export_utf8_to_utf32(sz_cptr_t utf8, sz_size_t utf8_length, sz_rune_t *utf32) { + sz_cptr_t const end = utf8 + utf8_length; + sz_size_t count = 0; + sz_rune_length_t rune_length; + for (; utf8 != end; utf8 += rune_length, utf32++, count++) _sz_extract_utf8_rune(utf8, utf32, &rune_length); + return count; +} #pragma endregion @@ -1105,273 +627,66 @@ typedef struct sz_sequence_t { * Expects ::offsets to contains `count + 1` entries, the last pointing at the end * of the last string, indicating the total length of the ::tape. */ -SZ_PUBLIC void sz_sequence_from_u32tape(sz_cptr_t *start, sz_u32_t const *offsets, sz_size_t count, - sz_sequence_t *sequence); +SZ_PUBLIC void sz_sequence_from_u32tape( // + sz_cptr_t *start, sz_u32_t const *offsets, sz_size_t count, sz_sequence_t *sequence); /** * @brief Initiates the sequence structure from a tape layout, used by Apache Arrow. * Expects ::offsets to contains `count + 1` entries, the last pointing at the end * of the last string, indicating the total length of the ::tape. */ -SZ_PUBLIC void sz_sequence_from_u64tape(sz_cptr_t *start, sz_u64_t const *offsets, sz_size_t count, - sz_sequence_t *sequence); +SZ_PUBLIC void sz_sequence_from_u64tape( // + sz_cptr_t *start, sz_u64_t const *offsets, sz_size_t count, sz_sequence_t *sequence); -/** - * @brief Similar to `std::partition`, given a predicate splits the sequence into two parts. - * The algorithm is unstable, meaning that elements may change relative order, as long - * as they are in the right partition. This is the simpler algorithm for partitioning. - */ -SZ_PUBLIC sz_size_t sz_partition(sz_sequence_t *sequence, sz_sequence_predicate_t predicate); +#pragma endregion -/** - * @brief Inplace `std::set_union` for two consecutive chunks forming the same continuous `sequence`. +#pragma region Helper Functions + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wconversion" +#pragma GCC visibility push(hidden) + +/* + ********************************************************************************************************************** + ********************************************************************************************************************** + ********************************************************************************************************************** + * + * This is where we the actual implementation begins. + * The rest of the file is hidden from the public API. * - * @param partition The number of elements in the first sub-sequence in `sequence`. - * @param less Comparison function, to determine the lexicographic ordering. + ********************************************************************************************************************** + ********************************************************************************************************************** + ********************************************************************************************************************** */ -SZ_PUBLIC void sz_merge(sz_sequence_t *sequence, sz_size_t partition, sz_sequence_comparator_t less); /** - * @brief Sorting algorithm, combining Radix Sort for the first 32 bits of every word - * and a follow-up by a more conventional sorting procedure on equally prefixed parts. + * @brief Helper-macro to mark potentially unused variables. 
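 * Typically used to silence unused-argument warnings in callbacks, e.g. `sz_unused(start && length);`
 * inside a hash callback that only consumes the hash itself.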
*/ -SZ_PUBLIC void sz_sort(sz_sequence_t *sequence); +#define sz_unused(x) ((void)(x)) /** - * @brief Partial sorting algorithm, combining Radix Sort for the first 32 bits of every word - * and a follow-up by a more conventional sorting procedure on equally prefixed parts. + * @brief Helper-macro casting a variable to another type of the same size. */ -SZ_PUBLIC void sz_sort_partial(sz_sequence_t *sequence, sz_size_t n); +#define sz_bitcast(type, value) (*((type *)&(value))) /** - * @brief Intro-Sort algorithm that supports custom comparators. - */ -SZ_PUBLIC void sz_sort_intro(sz_sequence_t *sequence, sz_sequence_comparator_t less); - -#pragma endregion - -/* - * Hardware feature detection. - * All of those can be controlled by the user. + * @brief Defines `SZ_NULL`, analogous to `NULL`. + * The default often comes from locale.h, stddef.h, + * stdio.h, stdlib.h, string.h, time.h, or wchar.h. */ -#ifndef SZ_USE_X86_AVX512 -#ifdef __AVX512BW__ -#define SZ_USE_X86_AVX512 1 +#ifdef __GNUG__ +#define SZ_NULL __null +#define SZ_NULL_CHAR __null #else -#define SZ_USE_X86_AVX512 0 -#endif +#define SZ_NULL ((void *)0) +#define SZ_NULL_CHAR ((char *)0) #endif -#ifndef SZ_USE_X86_AVX2 -#ifdef __AVX2__ -#define SZ_USE_X86_AVX2 1 -#else -#define SZ_USE_X86_AVX2 0 -#endif -#endif - -#ifndef SZ_USE_ARM_NEON -#ifdef __ARM_NEON -#define SZ_USE_ARM_NEON 1 -#else -#define SZ_USE_ARM_NEON 0 -#endif -#endif - -#ifndef SZ_USE_ARM_SVE -#ifdef __ARM_FEATURE_SVE -#define SZ_USE_ARM_SVE 1 -#else -#define SZ_USE_ARM_SVE 0 -#endif -#endif - -/* - * Include hardware-specific headers. - */ -#if SZ_USE_X86_AVX512 || SZ_USE_X86_AVX2 -#include -#endif // SZ_USE_X86... -#if SZ_USE_ARM_NEON -#if !defined(_MSC_VER) -#include -#endif -#include -#endif // SZ_USE_ARM_NEON -#if SZ_USE_ARM_SVE -#if !defined(_MSC_VER) -#include -#endif -#endif // SZ_USE_ARM_SVE - -#pragma region Hardware Specific API - -#if SZ_USE_X86_AVX512 - -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_avx512(sz_cptr_t a, sz_cptr_t b, sz_size_t length); -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_avx512(sz_ptr_t target, sz_size_t length, sz_u8_t value); -/** @copydoc sz_look_up_transform */ -SZ_PUBLIC void sz_look_up_transform_avx512(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target); -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_edit_distance */ -SZ_PUBLIC sz_size_t sz_edit_distance_avx512(sz_cptr_t a, 
sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); -/** @copydoc sz_alignment_score */ -SZ_PUBLIC sz_ssize_t sz_alignment_score_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, // - sz_memory_allocator_t *alloc); -/** @copydoc sz_hashes */ -SZ_PUBLIC void sz_hashes_avx512(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle); -#endif - -#if SZ_USE_X86_AVX2 -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_avx2(sz_cptr_t a, sz_cptr_t b, sz_size_t length); -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_avx2(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_avx2(sz_ptr_t target, sz_size_t length, sz_u8_t value); -/** @copydoc sz_look_up_transform */ -SZ_PUBLIC void sz_look_up_transform_avx2(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target); -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_hashes */ -SZ_PUBLIC void sz_hashes_avx2(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle); -#endif - -#if SZ_USE_ARM_NEON -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length); -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_neon(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_neon(sz_ptr_t target, sz_size_t length, sz_u8_t value); -/** @copydoc sz_look_up_transform */ -SZ_PUBLIC void sz_look_up_transform_neon(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target); -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_neon(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t text, sz_size_t length, sz_charset_t const 
*set); -#endif - -#if SZ_USE_ARM_SVE -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_sve(sz_cptr_t a, sz_cptr_t b, sz_size_t length); -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_sve(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_sve(sz_ptr_t target, sz_size_t length, sz_u8_t value); -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_sve(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_sve(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -#endif - -#pragma endregion - -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wconversion" - -/* - ********************************************************************************************************************** - ********************************************************************************************************************** - ********************************************************************************************************************** - * - * This is where we the actual implementation begins. - * The rest of the file is hidden from the public API. - * - ********************************************************************************************************************** - ********************************************************************************************************************** - ********************************************************************************************************************** - */ - -#pragma region Compiler Extensions and Helper Functions - -#pragma GCC visibility push(hidden) - -/** - * @brief Helper-macro to mark potentially unused variables. - */ -#define sz_unused(x) ((void)(x)) - -/** - * @brief Helper-macro casting a variable to another type of the same size. - */ -#define sz_bitcast(type, value) (*((type *)&(value))) - -/** - * @brief Defines `SZ_NULL`, analogous to `NULL`. - * The default often comes from locale.h, stddef.h, - * stdio.h, stdlib.h, string.h, time.h, or wchar.h. - */ -#ifdef __GNUG__ -#define SZ_NULL __null -#define SZ_NULL_CHAR __null -#else -#define SZ_NULL ((void *)0) -#define SZ_NULL_CHAR ((char *)0) -#endif - -/** - * @brief Cache-line width, that will affect the execution of some algorithms, - * like equality checks and relative order computing. - */ -#define SZ_CACHE_LINE_WIDTH (64) // bytes +/** + * @brief Cache-line width, that will affect the execution of some algorithms, + * like equality checks and relative order computing. 
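 *
 * A hedged illustration of how such a constant is typically consumed; `length` is an assumed variable:
 *
 * @code{.c}
 *     sz_size_t whole_lines = length / SZ_CACHE_LINE_WIDTH; // candidates for full-width comparisons
 *     sz_size_t tail_bytes = length % SZ_CACHE_LINE_WIDTH;  // handled by a scalar or masked tail
 * @endcode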
+ */ +#define SZ_CACHE_LINE_WIDTH (64) // bytes /** * @brief Similar to `assert`, the `sz_assert` is used in the SZ_DEBUG mode @@ -1467,6 +782,17 @@ SZ_INTERNAL sz_u64_t sz_u64_bytes_reverse(sz_u64_t val) { return __builtin_bswap SZ_INTERNAL sz_u32_t sz_u32_bytes_reverse(sz_u32_t val) { return __builtin_bswap32(val); } #endif +/* + */ +SZ_INTERNAL sz_u16_t _sz_u16_mask_until(sz_size_t n) { return (0x0001u << n) - 1u; } +SZ_INTERNAL sz_u32_t _sz_u32_mask_until(sz_size_t n) { return (0x00000001u << n) - 1u; } +SZ_INTERNAL sz_u64_t _sz_u64_mask_until(sz_size_t n) { return (0x0000000000000001ull << n) - 1ull; } +SZ_INTERNAL sz_u16_t _sz_u16_clamp_mask_until(sz_size_t n) { return n < 16 ? _sz_u16_mask_until(n) : 0xFFFFu; } +SZ_INTERNAL sz_u32_t _sz_u32_clamp_mask_until(sz_size_t n) { return n < 32 ? _sz_u32_mask_until(n) : 0xFFFFFFFFu; } +SZ_INTERNAL sz_u64_t _sz_u64_clamp_mask_until(sz_size_t n) { + return n < 64 ? _sz_u64_mask_until(n) : 0xFFFFFFFFFFFFFFFFull; +} + SZ_INTERNAL sz_u64_t sz_u64_rotl(sz_u64_t x, sz_u64_t r) { return (x << r) | (x >> (64 - r)); } /** @@ -1497,5655 +823,284 @@ SZ_INTERNAL sz_u64_t sz_u64_blend(sz_u64_t a, sz_u64_t b, sz_u64_t mask) { retur * * Alternatively, to avoid multiplication: * - * x & ~((x < y) - 1) + y & ((x < y) - 1) // 6 unique operations - */ -#define sz_min_of_two(x, y) (x < y ? x : y) -#define sz_max_of_two(x, y) (x < y ? y : x) -#define sz_min_of_three(x, y, z) sz_min_of_two(x, sz_min_of_two(y, z)) -#define sz_max_of_three(x, y, z) sz_max_of_two(x, sz_max_of_two(y, z)) - -/** @brief Branchless minimum function for two signed 32-bit integers. */ -SZ_INTERNAL sz_i32_t sz_i32_min_of_two(sz_i32_t x, sz_i32_t y) { return y + ((x - y) & (x - y) >> 31); } - -/** @brief Branchless minimum function for two signed 32-bit integers. */ -SZ_INTERNAL sz_i32_t sz_i32_max_of_two(sz_i32_t x, sz_i32_t y) { return x - ((x - y) & (x - y) >> 31); } - -/** - * @brief Clamps signed offsets in a string to a valid range. Used for Pythonic-style slicing. - */ -SZ_INTERNAL void sz_ssize_clamp_interval(sz_size_t length, sz_ssize_t start, sz_ssize_t end, - sz_size_t *normalized_offset, sz_size_t *normalized_length) { - // TODO: Remove branches. - // Normalize negative indices - if (start < 0) start += length; - if (end < 0) end += length; - - // Clamp indices to a valid range - if (start < 0) start = 0; - if (end < 0) end = 0; - if (start > (sz_ssize_t)length) start = length; - if (end > (sz_ssize_t)length) end = length; - - // Ensure start <= end - if (start > end) start = end; - - *normalized_offset = start; - *normalized_length = end - start; -} - -/** - * @brief Compute the logarithm base 2 of a positive integer, rounding down. - */ -SZ_INTERNAL sz_size_t sz_size_log2i_nonzero(sz_size_t x) { - sz_assert(x > 0 && "Non-positive numbers have no defined logarithm"); - sz_size_t leading_zeros = sz_u64_clz(x); - return 63 - leading_zeros; -} - -/** - * @brief Compute the smallest power of two greater than or equal to ::x. - */ -SZ_INTERNAL sz_size_t sz_size_bit_ceil(sz_size_t x) { - // Unlike the commonly used trick with `clz` intrinsics, is valid across the whole range of `x`. - // https://stackoverflow.com/a/10143264 - x--; - x |= x >> 1; - x |= x >> 2; - x |= x >> 4; - x |= x >> 8; - x |= x >> 16; -#if SZ_DETECT_64_BIT - x |= x >> 32; -#endif - x++; - return x; -} - -/** - * @brief Transposes an 8x8 bit matrix packed in a `sz_u64_t`. 
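 *
 * For example, an input with exactly one byte equal to 0xFF (one full row of the matrix) transposes
 * into a value where every byte has the same single bit set (one full column), and vice versa.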
- * - * There is a well known SWAR sequence for that known to chess programmers, - * willing to flip a bit-matrix of pieces along the main A1-H8 diagonal. - * https://www.chessprogramming.org/Flipping_Mirroring_and_Rotating - * https://lukas-prokop.at/articles/2021-07-23-transpose - */ -SZ_INTERNAL sz_u64_t sz_u64_transpose(sz_u64_t x) { - sz_u64_t t; - t = x ^ (x << 36); - x ^= 0xf0f0f0f00f0f0f0full & (t ^ (x >> 36)); - t = 0xcccc0000cccc0000ull & (x ^ (x << 18)); - x ^= t ^ (t >> 18); - t = 0xaa00aa00aa00aa00ull & (x ^ (x << 9)); - x ^= t ^ (t >> 9); - return x; -} - -/** - * @brief Helper, that swaps two 64-bit integers representing the order of elements in the sequence. - */ -SZ_INTERNAL void sz_u64_swap(sz_u64_t *a, sz_u64_t *b) { - sz_u64_t t = *a; - *a = *b; - *b = t; -} - -/** - * @brief Helper, that swaps two 64-bit integers representing the order of elements in the sequence. - */ -SZ_INTERNAL void sz_pointer_swap(void **a, void **b) { - void *t = *a; - *a = *b; - *b = t; -} - -/** - * @brief Helper structure to simplify work with 16-bit words. - * @see sz_u16_load - */ -typedef union sz_u16_vec_t { - sz_u16_t u16; - sz_u8_t u8s[2]; -} sz_u16_vec_t; - -/** - * @brief Load a 16-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms. - */ -SZ_INTERNAL sz_u16_vec_t sz_u16_load(sz_cptr_t ptr) { -#if !SZ_USE_MISALIGNED_LOADS - sz_u16_vec_t result; - result.u8s[0] = ptr[0]; - result.u8s[1] = ptr[1]; - return result; -#elif defined(_MSC_VER) && !defined(__clang__) -#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. - return *((sz_u16_vec_t *)ptr); -#else - return *((__unaligned sz_u16_vec_t *)ptr); -#endif -#else - __attribute__((aligned(1))) sz_u16_vec_t const *result = (sz_u16_vec_t const *)ptr; - return *result; -#endif -} - -/** - * @brief Helper structure to simplify work with 32-bit words. - * @see sz_u32_load - */ -typedef union sz_u32_vec_t { - sz_u32_t u32; - sz_u16_t u16s[2]; - sz_u8_t u8s[4]; -} sz_u32_vec_t; - -/** - * @brief Load a 32-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms. - */ -SZ_INTERNAL sz_u32_vec_t sz_u32_load(sz_cptr_t ptr) { -#if !SZ_USE_MISALIGNED_LOADS - sz_u32_vec_t result; - result.u8s[0] = ptr[0]; - result.u8s[1] = ptr[1]; - result.u8s[2] = ptr[2]; - result.u8s[3] = ptr[3]; - return result; -#elif defined(_MSC_VER) && !defined(__clang__) -#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. - return *((sz_u32_vec_t *)ptr); -#else - return *((__unaligned sz_u32_vec_t *)ptr); -#endif -#else - __attribute__((aligned(1))) sz_u32_vec_t const *result = (sz_u32_vec_t const *)ptr; - return *result; -#endif -} - -/** - * @brief Helper structure to simplify work with 64-bit words. - * @see sz_u64_load - */ -typedef union sz_u64_vec_t { - sz_u64_t u64; - sz_u32_t u32s[2]; - sz_u16_t u16s[4]; - sz_u8_t u8s[8]; -} sz_u64_vec_t; - -/** - * @brief Load a 64-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms. - */ -SZ_INTERNAL sz_u64_vec_t sz_u64_load(sz_cptr_t ptr) { -#if !SZ_USE_MISALIGNED_LOADS - sz_u64_vec_t result; - result.u8s[0] = ptr[0]; - result.u8s[1] = ptr[1]; - result.u8s[2] = ptr[2]; - result.u8s[3] = ptr[3]; - result.u8s[4] = ptr[4]; - result.u8s[5] = ptr[5]; - result.u8s[6] = ptr[6]; - result.u8s[7] = ptr[7]; - return result; -#elif defined(_MSC_VER) && !defined(__clang__) -#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. 
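/*  Illustrative aside: the union-based loaders above exist to avoid undefined behavior
 *  on targets that cannot dereference misaligned pointers. A minimal portable sketch of
 *  the same idea (hypothetical helper, not part of the library API) is a `memcpy` into
 *  a local word, which mainstream compilers fold into a single load where the ISA allows it.
 */
#include <stdint.h> /* uint64_t - needed only for this sketch */
#include <string.h> /* memcpy   - needed only for this sketch */

static inline uint64_t load_u64_unaligned(char const *ptr) {
    uint64_t word;
    memcpy(&word, ptr, sizeof(word)); /* byte-wise copy, no alignment assumptions */
    return word;
}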
- return *((sz_u64_vec_t *)ptr); -#else - return *((__unaligned sz_u64_vec_t *)ptr); -#endif -#else - __attribute__((aligned(1))) sz_u64_vec_t const *result = (sz_u64_vec_t const *)ptr; - return *result; -#endif -} - -/** @brief Helper function, using the supplied fixed-capacity buffer to allocate memory. */ -SZ_INTERNAL sz_ptr_t _sz_memory_allocate_fixed(sz_size_t length, void *handle) { - sz_size_t capacity; - sz_copy((sz_ptr_t)&capacity, (sz_cptr_t)handle, sizeof(sz_size_t)); - sz_size_t consumed_capacity = sizeof(sz_size_t); - if (consumed_capacity + length > capacity) return SZ_NULL_CHAR; - return (sz_ptr_t)handle + consumed_capacity; -} - -/** @brief Helper "no-op" function, simulating memory deallocation when we use a "static" memory buffer. */ -SZ_INTERNAL void _sz_memory_free_fixed(sz_ptr_t start, sz_size_t length, void *handle) { - sz_unused(start && length && handle); -} - -/** @brief An internal callback used to set a bit in a power-of-two length binary fingerprint of a string. */ -SZ_INTERNAL void _sz_hashes_fingerprint_pow2_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, void *handle) { - sz_string_view_t *fingerprint_buffer = (sz_string_view_t *)handle; - sz_u8_t *fingerprint_u8s = (sz_u8_t *)fingerprint_buffer->start; - sz_size_t fingerprint_bytes = fingerprint_buffer->length; - fingerprint_u8s[(hash / 8) & (fingerprint_bytes - 1)] |= (1 << (hash & 7)); - sz_unused(start && length); -} - -/** @brief An internal callback used to set a bit in a @b non power-of-two length binary fingerprint of a string. */ -SZ_INTERNAL void _sz_hashes_fingerprint_non_pow2_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, - void *handle) { - sz_string_view_t *fingerprint_buffer = (sz_string_view_t *)handle; - sz_u8_t *fingerprint_u8s = (sz_u8_t *)fingerprint_buffer->start; - sz_size_t fingerprint_bytes = fingerprint_buffer->length; - fingerprint_u8s[(hash / 8) % fingerprint_bytes] |= (1 << (hash & 7)); - sz_unused(start && length); -} - -/** @brief An internal callback, used to mix all the running hashes into one pointer-size value. */ -SZ_INTERNAL void _sz_hashes_fingerprint_scalar_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, - void *scalar_handle) { - sz_unused(start && length && hash && scalar_handle); - sz_size_t *scalar_ptr = (sz_size_t *)scalar_handle; - *scalar_ptr ^= hash; -} - -/** - * @brief Chooses the offsets of the most interesting characters in a search needle. - * - * Search throughput can significantly deteriorate if we are matching the wrong characters. - * Say the needle is "aXaYa", and we are comparing the first, second, and last character. - * If we use SIMD and compare many offsets at a time, comparing against "a" in every register is a waste. - * - * Similarly, dealing with UTF8 inputs, we know that the lower bits of each character code carry more information. - * Cyrillic alphabet, for example, falls into [0x0410, 0x042F] code range for uppercase [А, Я], and - * into [0x0430, 0x044F] for lowercase [а, я]. Scanning through a text written in Russian, half of the - * bytes will carry absolutely no value and will be equal to 0x04. - */ -SZ_INTERNAL void _sz_locate_needle_anomalies(sz_cptr_t start, sz_size_t length, // - sz_size_t *first, sz_size_t *second, sz_size_t *third) { - *first = 0; - *second = length / 2; - *third = length - 1; - - // - int has_duplicates = // - start[*first] == start[*second] || // - start[*first] == start[*third] || // - start[*second] == start[*third]; - - // Loop through letters to find non-colliding variants. 
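/*  Illustrative aside: a worked example of the pivoting below. For the needle "aXaYa"
 *  (length 5) the initial picks are offsets 0, 2, and 4 - all equal to 'a'. The middle
 *  offset walks right from 2 to 3 and lands on 'Y', while the last offset cannot move
 *  left past `second + 1` and stays at 4. The compared bytes become 'a', 'Y', 'a',
 *  which already collide far less often than three identical 'a' comparisons.
 */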
- if (length > 3 && has_duplicates) { - // Pivot the middle point right, until we find a character different from the first one. - for (; start[*second] == start[*first] && *second + 1 < *third; ++(*second)) {} - // Pivot the third (last) point left, until we find a different character. - for (; (start[*third] == start[*second] || start[*third] == start[*first]) && *third > (*second + 1); - --(*third)) {} - } - - // TODO: Investigate alternative strategies for long needles. - // On very long needles we have the luxury to choose! - // Often dealing with UTF8, we will likely benefit from shifting the first and second characters - // further to the right, to achieve not only uniqueness within the needle, but also avoid common - // rune prefixes of 2-, 3-, and 4-byte codes. - if (length > 8) { - // Pivot the first and second points right, until we find a character, that: - // > is different from others. - // > doesn't start with 0b'110x'xxxx - only 5 bits of relevant info. - // > doesn't start with 0b'1110'xxxx - only 4 bits of relevant info. - // > doesn't start with 0b'1111'0xxx - only 3 bits of relevant info. - // - // So we are practically searching for byte values that start with 0b0xxx'xxxx or 0b'10xx'xxxx. - // Meaning they fall in the range [0, 127] and [128, 191], in other words any unsigned int up to 191. - sz_u8_t const *start_u8 = (sz_u8_t const *)start; - sz_size_t vibrant_first = *first, vibrant_second = *second, vibrant_third = *third; - - // Let's begin with the seccond character, as the termination criteria there is more obvious - // and we may end up with more variants to check for the first candidate. - for (; (start_u8[vibrant_second] > 191 || start_u8[vibrant_second] == start_u8[vibrant_third]) && - (vibrant_second + 1 < vibrant_third); - ++vibrant_second) {} - - // Now check if we've indeed found a good candidate or should revert the `vibrant_second` to `second`. - if (start_u8[vibrant_second] < 191) { *second = vibrant_second; } - else { vibrant_second = *second; } - - // Now check the first character. - for (; (start_u8[vibrant_first] > 191 || start_u8[vibrant_first] == start_u8[vibrant_second] || - start_u8[vibrant_first] == start_u8[vibrant_third]) && - (vibrant_first + 1 < vibrant_second); - ++vibrant_first) {} - - // Now check if we've indeed found a good candidate or should revert the `vibrant_first` to `first`. - // We don't need to shift the third one when dealing with texts as the last byte of the text is - // also the last byte of a rune and contains the most information. 
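/*  Illustrative aside: a small named predicate for the `> 191` checks used here
 *  (hypothetical helper, not part of the library API). In UTF8, values 0x00-0x7F are
 *  ASCII and 0x80-0xBF are continuation bytes; anything from 0xC0 (192) upward is a
 *  multi-byte lead byte whose fixed high bits carry little entropy for filtering.
 */
#include <stdint.h> /* uint8_t - needed only for this sketch */

static inline int utf8_byte_is_informative(uint8_t byte) {
    return byte < 0xC0; /* ASCII or continuation byte, i.e. a value up to 191 */
}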
- if (start_u8[vibrant_first] < 191) { *first = vibrant_first; } - } -} - -#pragma GCC visibility pop -#pragma endregion - -#pragma region Serial Implementation - -#if !SZ_AVOID_LIBC -#include // `fprintf` -#include // `malloc`, `EXIT_FAILURE` - -SZ_PUBLIC void *_sz_memory_allocate_default(sz_size_t length, void *handle) { - sz_unused(handle); - return malloc(length); -} -SZ_PUBLIC void _sz_memory_free_default(sz_ptr_t start, sz_size_t length, void *handle) { - sz_unused(handle && length); - free(start); -} - -#endif - -SZ_PUBLIC void sz_memory_allocator_init_default(sz_memory_allocator_t *alloc) { -#if !SZ_AVOID_LIBC - alloc->allocate = (sz_memory_allocate_t)_sz_memory_allocate_default; - alloc->free = (sz_memory_free_t)_sz_memory_free_default; -#else - alloc->allocate = (sz_memory_allocate_t)SZ_NULL; - alloc->free = (sz_memory_free_t)SZ_NULL; -#endif - alloc->handle = SZ_NULL; -} - -SZ_PUBLIC void sz_memory_allocator_init_fixed(sz_memory_allocator_t *alloc, void *buffer, sz_size_t length) { - // The logic here is simple - put the buffer length in the first slots of the buffer. - // Later use it for bounds checking. - alloc->allocate = (sz_memory_allocate_t)_sz_memory_allocate_fixed; - alloc->free = (sz_memory_free_t)_sz_memory_free_fixed; - alloc->handle = &buffer; - sz_copy((sz_ptr_t)buffer, (sz_cptr_t)&length, sizeof(sz_size_t)); -} - -/** - * @brief Byte-level equality comparison between two strings. - * If unaligned loads are allowed, uses a switch-table to avoid loops on short strings. - */ -SZ_PUBLIC sz_bool_t sz_equal_serial(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - sz_cptr_t const a_end = a + length; -#if SZ_USE_MISALIGNED_LOADS - if (length >= SZ_SWAR_THRESHOLD) { - sz_u64_vec_t a_vec, b_vec; - for (; a + 8 <= a_end; a += 8, b += 8) { - a_vec = sz_u64_load(a); - b_vec = sz_u64_load(b); - if (a_vec.u64 != b_vec.u64) return sz_false_k; - } - } -#endif - while (a != a_end && *a == *b) a++, b++; - return (sz_bool_t)(a_end == a); -} - -SZ_PUBLIC sz_cptr_t sz_find_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { - for (sz_cptr_t const end = text + length; text != end; ++text) - if (sz_charset_contains(set, *text)) return text; - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Warray-bounds" - sz_cptr_t const end = text; - for (text += length; text != end;) - if (sz_charset_contains(set, *(text -= 1))) return text; - return SZ_NULL_CHAR; -#pragma GCC diagnostic pop -} - -/** - * One option to avoid branching is to use conditional moves and lookup the comparison result in a table: - * sz_ordering_t ordering_lookup[2] = {sz_greater_k, sz_less_k}; - * for (; a != min_end; ++a, ++b) - * if (*a != *b) return ordering_lookup[*a < *b]; - * That, however, introduces a data-dependency. - * A cleaner option is to perform two comparisons and a subtraction. - * One instruction more, but no data-dependency. - */ -#define _sz_order_scalars(a, b) ((sz_ordering_t)((a > b) - (a < b))) - -SZ_PUBLIC sz_ordering_t sz_order_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { - sz_bool_t a_shorter = (sz_bool_t)(a_length < b_length); - sz_size_t min_length = a_shorter ? 
a_length : b_length; - sz_cptr_t min_end = a + min_length; -#if SZ_USE_MISALIGNED_LOADS && !SZ_DETECT_BIG_ENDIAN - for (sz_u64_vec_t a_vec, b_vec; a + 8 <= min_end; a += 8, b += 8) { - a_vec = sz_u64_load(a); - b_vec = sz_u64_load(b); - if (a_vec.u64 != b_vec.u64) - return _sz_order_scalars(sz_u64_bytes_reverse(a_vec.u64), sz_u64_bytes_reverse(b_vec.u64)); - } -#endif - for (; a != min_end; ++a, ++b) - if (*a != *b) return _sz_order_scalars(*a, *b); - - // If the strings are equal up to `min_end`, then the shorter string is smaller - return _sz_order_scalars(a_length, b_length); -} - -/** - * @brief Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each byte is set. - // For that take the bottom 7 bits of each byte, add one to them, - // and if this sets the top bit to one, then all the 7 bits are ones as well. - vec.u64 = ((vec.u64 & 0x7F7F7F7F7F7F7F7Full) + 0x0101010101010101ull) & ((vec.u64 & 0x8080808080808080ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b single-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. - * Identical to `memchr(haystack, needle[0], haystack_length)`. - */ -SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - if (!h_length) return SZ_NULL_CHAR; - sz_cptr_t const h_end = h + h_length; - -#if !SZ_DETECT_BIG_ENDIAN // Use SWAR only on little-endian platforms for brevety. -#if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)h & 7ull) && h < h_end; ++h) - if (*h == *n) return h; -#endif - - // Broadcast the n into every byte of a 64-bit integer to use SWAR - // techniques and process eight characters at a time. - sz_u64_vec_t h_vec, n_vec, match_vec; - match_vec.u64 = 0; - n_vec.u64 = (sz_u64_t)n[0] * 0x0101010101010101ull; - for (; h + 8 <= h_end; h += 8) { - h_vec.u64 = *(sz_u64_t const *)h; - match_vec = _sz_u64_each_byte_equal(h_vec, n_vec); - if (match_vec.u64) return h + sz_u64_ctz(match_vec.u64) / 8; - } -#endif - - // Handle the misaligned tail. - for (; h < h_end; ++h) - if (*h == *n) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief Find the last occurrence of a @b single-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. - * Identical to `memrchr(haystack, needle[0], haystack_length)`. - */ -sz_cptr_t sz_rfind_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - if (!h_length) return SZ_NULL_CHAR; - sz_cptr_t const h_start = h; - - // Reposition the `h` pointer to the end, as we will be walking backwards. - h = h + h_length - 1; - -#if !SZ_DETECT_BIG_ENDIAN // Use SWAR only on little-endian platforms for brevety. -#if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)(h + 1) & 7ull) && h >= h_start; --h) - if (*h == *n) return h; -#endif - - // Broadcast the n into every byte of a 64-bit integer to use SWAR - // techniques and process eight characters at a time. 
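/*  Illustrative aside: a standalone sketch of the SWAR byte-equality trick used by
 *  `sz_find_byte_serial` and `sz_rfind_byte_serial` (hypothetical helper, not part of
 *  the library API). The needle byte is broadcast with a 0x01...01 multiplication, the
 *  XOR with the haystack word is inverted so matching bytes become 0xFF, and the carry
 *  trick converts every fully-set byte into a single flag in its top bit. The offset of
 *  the first match within the word is then `ctz(mask) / 8`.
 */
#include <stdint.h> /* uint8_t, uint64_t - needed only for this sketch */

static inline uint64_t swar_bytes_equal_mask(uint64_t word, uint8_t needle) {
    uint64_t broadcast = 0x0101010101010101ull * needle;
    uint64_t inverted_xor = ~(word ^ broadcast); /* 0xFF in every byte that matches */
    return ((inverted_xor & 0x7F7F7F7F7F7F7F7Full) + 0x0101010101010101ull) &
           (inverted_xor & 0x8080808080808080ull); /* one flag bit per matching byte */
}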
- sz_u64_vec_t h_vec, n_vec, match_vec; - n_vec.u64 = (sz_u64_t)n[0] * 0x0101010101010101ull; - for (; h >= h_start + 7; h -= 8) { - h_vec.u64 = *(sz_u64_t const *)(h - 7); - match_vec = _sz_u64_each_byte_equal(h_vec, n_vec); - if (match_vec.u64) return h - sz_u64_clz(match_vec.u64) / 8; - } -#endif - - for (; h >= h_start; --h) - if (*h == *n) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief 2Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each 2byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_2byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each 2byte is set. - // For that take the bottom 15 bits of each 2byte, add one to them, - // and if this sets the top bit to one, then all the 15 bits are ones as well. - vec.u64 = ((vec.u64 & 0x7FFF7FFF7FFF7FFFull) + 0x0001000100010001ull) & ((vec.u64 & 0x8000800080008000ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b two-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. - */ -SZ_INTERNAL sz_cptr_t _sz_find_2byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - // This is an internal method, and the haystack is guaranteed to be at least 2 bytes long. - sz_assert(h_length >= 2 && "The haystack is too short."); - sz_cptr_t const h_end = h + h_length; - -#if !SZ_USE_MISALIGNED_LOADS - // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)h & 7ull) && h + 2 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) == 2) return h; -#endif - - sz_u64_vec_t h_even_vec, h_odd_vec, n_vec, matches_even_vec, matches_odd_vec; - n_vec.u64 = 0; - n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1]; - n_vec.u64 *= 0x0001000100010001ull; // broadcast - - // This code simulates hyper-scalar execution, analyzing 8 offsets at a time. - for (; h + 9 <= h_end; h += 8) { - h_even_vec.u64 = *(sz_u64_t *)h; - h_odd_vec.u64 = (h_even_vec.u64 >> 8) | ((sz_u64_t)h[8] << 56); - matches_even_vec = _sz_u64_each_2byte_equal(h_even_vec, n_vec); - matches_odd_vec = _sz_u64_each_2byte_equal(h_odd_vec, n_vec); - - matches_even_vec.u64 >>= 8; - if (matches_even_vec.u64 + matches_odd_vec.u64) { - sz_u64_t match_indicators = matches_even_vec.u64 | matches_odd_vec.u64; - return h + sz_u64_ctz(match_indicators) / 8; - } - } - - for (; h + 2 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) == 2) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief 4Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each 4byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_4byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each 4byte is set. - // For that take the bottom 31 bits of each 4byte, add one to them, - // and if this sets the top bit to one, then all the 31 bits are ones as well. - vec.u64 = ((vec.u64 & 0x7FFFFFFF7FFFFFFFull) + 0x0000000100000001ull) & ((vec.u64 & 0x8000000080000000ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b four-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. 
- */ -SZ_INTERNAL sz_cptr_t _sz_find_4byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - // This is an internal method, and the haystack is guaranteed to be at least 4 bytes long. - sz_assert(h_length >= 4 && "The haystack is too short."); - sz_cptr_t const h_end = h + h_length; - -#if !SZ_USE_MISALIGNED_LOADS - // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)h & 7ull) && h + 4 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) + (h[3] == n[3]) == 4) return h; -#endif - - sz_u64_vec_t h0_vec, h1_vec, h2_vec, h3_vec, n_vec, matches0_vec, matches1_vec, matches2_vec, matches3_vec; - n_vec.u64 = 0; - n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1], n_vec.u8s[2] = n[2], n_vec.u8s[3] = n[3]; - n_vec.u64 *= 0x0000000100000001ull; // broadcast - - // This code simulates hyper-scalar execution, analyzing 8 offsets at a time using four 64-bit words. - // We load the subsequent four-byte word as well, taking its first bytes. Think of it as a glorified prefetch :) - sz_u64_t h_page_current, h_page_next; - for (; h + sizeof(sz_u64_t) + sizeof(sz_u32_t) <= h_end; h += sizeof(sz_u64_t)) { - h_page_current = *(sz_u64_t *)h; - h_page_next = *(sz_u32_t *)(h + 8); - h0_vec.u64 = (h_page_current); - h1_vec.u64 = (h_page_current >> 8) | (h_page_next << 56); - h2_vec.u64 = (h_page_current >> 16) | (h_page_next << 48); - h3_vec.u64 = (h_page_current >> 24) | (h_page_next << 40); - matches0_vec = _sz_u64_each_4byte_equal(h0_vec, n_vec); - matches1_vec = _sz_u64_each_4byte_equal(h1_vec, n_vec); - matches2_vec = _sz_u64_each_4byte_equal(h2_vec, n_vec); - matches3_vec = _sz_u64_each_4byte_equal(h3_vec, n_vec); - - if (matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64) { - matches0_vec.u64 >>= 24; - matches1_vec.u64 >>= 16; - matches2_vec.u64 >>= 8; - sz_u64_t match_indicators = matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64; - return h + sz_u64_ctz(match_indicators) / 8; - } - } - - for (; h + 4 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) + (h[3] == n[3]) == 4) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief 3Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each 3byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_3byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each 4byte is set. - // For that take the bottom 31 bits of each 4byte, add one to them, - // and if this sets the top bit to one, then all the 31 bits are ones as well. - vec.u64 = ((vec.u64 & 0xFFFF7FFFFF7FFFFFull) + 0x0000000001000001ull) & ((vec.u64 & 0x0000800000800000ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b three-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. - */ -SZ_INTERNAL sz_cptr_t _sz_find_3byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - // This is an internal method, and the haystack is guaranteed to be at least 4 bytes long. - sz_assert(h_length >= 3 && "The haystack is too short."); - sz_cptr_t const h_end = h + h_length; - -#if !SZ_USE_MISALIGNED_LOADS - // Process the misaligned head, to void UB on unaligned 64-bit loads. 
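/*  Illustrative aside: the `(sz_size_t)h & 7ull` guards peel the misaligned head
 *  byte-by-byte so that the main loop only issues naturally aligned 64-bit loads.
 *  A hypothetical helper expressing the same check, assuming a power-of-two alignment:
 */
#include <stdint.h> /* uintptr_t - needed only for this sketch */
#include <stddef.h> /* size_t    - needed only for this sketch */

static inline size_t bytes_until_aligned(void const *pointer, size_t alignment) {
    size_t misalignment = (size_t)((uintptr_t)pointer & (alignment - 1));
    return misalignment ? alignment - misalignment : 0;
}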
- for (; ((sz_size_t)h & 7ull) && h + 3 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) == 3) return h; -#endif - - // We fetch 12 - sz_u64_vec_t h0_vec, h1_vec, h2_vec, h3_vec, h4_vec; - sz_u64_vec_t matches0_vec, matches1_vec, matches2_vec, matches3_vec, matches4_vec; - sz_u64_vec_t n_vec; - n_vec.u64 = 0; - n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1], n_vec.u8s[2] = n[2]; - n_vec.u64 *= 0x0000000001000001ull; // broadcast - - // This code simulates hyper-scalar execution, analyzing 8 offsets at a time using three 64-bit words. - // We load the subsequent two-byte word as well. - sz_u64_t h_page_current, h_page_next; - for (; h + sizeof(sz_u64_t) + sizeof(sz_u16_t) <= h_end; h += sizeof(sz_u64_t)) { - h_page_current = *(sz_u64_t *)h; - h_page_next = *(sz_u16_t *)(h + 8); - h0_vec.u64 = (h_page_current); - h1_vec.u64 = (h_page_current >> 8) | (h_page_next << 56); - h2_vec.u64 = (h_page_current >> 16) | (h_page_next << 48); - h3_vec.u64 = (h_page_current >> 24) | (h_page_next << 40); - h4_vec.u64 = (h_page_current >> 32) | (h_page_next << 32); - matches0_vec = _sz_u64_each_3byte_equal(h0_vec, n_vec); - matches1_vec = _sz_u64_each_3byte_equal(h1_vec, n_vec); - matches2_vec = _sz_u64_each_3byte_equal(h2_vec, n_vec); - matches3_vec = _sz_u64_each_3byte_equal(h3_vec, n_vec); - matches4_vec = _sz_u64_each_3byte_equal(h4_vec, n_vec); - - if (matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64 | matches4_vec.u64) { - matches0_vec.u64 >>= 16; - matches1_vec.u64 >>= 8; - matches3_vec.u64 <<= 8; - matches4_vec.u64 <<= 16; - sz_u64_t match_indicators = - matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64 | matches4_vec.u64; - return h + sz_u64_ctz(match_indicators) / 8; - } - } - - for (; h + 3 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) == 3) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief Boyer-Moore-Horspool algorithm for exact matching of patterns up to @b 256-bytes long. - * Uses the Raita heuristic to match the first two, the last, and the middle character of the pattern. - */ -SZ_INTERNAL sz_cptr_t _sz_find_horspool_upto_256bytes_serial(sz_cptr_t h_chars, sz_size_t h_length, // - sz_cptr_t n_chars, sz_size_t n_length) { - sz_assert(n_length <= 256 && "The pattern is too long."); - // Several popular string matching algorithms are using a bad-character shift table. - // Boyer Moore: https://www-igm.univ-mlv.fr/~lecroq/string/node14.html - // Quick Search: https://www-igm.univ-mlv.fr/~lecroq/string/node19.html - // Smith: https://www-igm.univ-mlv.fr/~lecroq/string/node21.html - union { - sz_u8_t jumps[256]; - sz_u64_vec_t vecs[64]; - } bad_shift_table; - - // Let's initialize the table using SWAR to the total length of the string. - sz_u8_t const *h = (sz_u8_t const *)h_chars; - sz_u8_t const *n = (sz_u8_t const *)n_chars; - { - sz_u64_vec_t n_length_vec; - n_length_vec.u64 = n_length; - n_length_vec.u64 *= 0x0101010101010101ull; // broadcast - for (sz_size_t i = 0; i != 64; ++i) bad_shift_table.vecs[i].u64 = n_length_vec.u64; - for (sz_size_t i = 0; i + 1 < n_length; ++i) bad_shift_table.jumps[n[i]] = (sz_u8_t)(n_length - i - 1); - } - - // Another common heuristic is to match a few characters from different parts of a string. - // Raita suggests to use the first two, the last, and the middle character of the pattern. - sz_u32_vec_t h_vec, n_vec; - - // Pick the parts of the needle that are worth comparing. 
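/*  Illustrative aside: a textbook Horspool search without the SWAR table initialization
 *  and without the Raita-style filtering used in this function, to make the shift logic
 *  easier to follow. The helper is hypothetical and not part of the library API.
 */
#include <stddef.h> /* size_t  - needed only for this sketch */
#include <stdint.h> /* uint8_t - needed only for this sketch */
#include <string.h> /* memcmp  - needed only for this sketch */

static char const *horspool_find(char const *haystack, size_t haystack_length, //
                                 char const *needle, size_t needle_length) {
    if (!needle_length || haystack_length < needle_length) return NULL;
    size_t shifts[256];
    for (size_t i = 0; i != 256; ++i) shifts[i] = needle_length;
    for (size_t i = 0; i + 1 < needle_length; ++i) shifts[(uint8_t)needle[i]] = needle_length - i - 1;
    for (size_t position = 0; position + needle_length <= haystack_length;) {
        if (memcmp(haystack + position, needle, needle_length) == 0) return haystack + position;
        // Jump by the shift of the haystack byte aligned with the last needle position - always at least one.
        position += shifts[(uint8_t)haystack[position + needle_length - 1]];
    }
    return NULL;
}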
- sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n_chars, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into an unsigned integer. - n_vec.u8s[0] = n[offset_first]; - n_vec.u8s[1] = n[offset_first + 1]; - n_vec.u8s[2] = n[offset_mid]; - n_vec.u8s[3] = n[offset_last]; - - // Scan through the whole haystack, skipping the last `n_length - 1` bytes. - for (sz_size_t i = 0; i <= h_length - n_length;) { - h_vec.u8s[0] = h[i + offset_first]; - h_vec.u8s[1] = h[i + offset_first + 1]; - h_vec.u8s[2] = h[i + offset_mid]; - h_vec.u8s[3] = h[i + offset_last]; - if (h_vec.u32 == n_vec.u32 && sz_equal((sz_cptr_t)h + i, n_chars, n_length)) return (sz_cptr_t)h + i; - i += bad_shift_table.jumps[h[i + n_length - 1]]; - } - return SZ_NULL_CHAR; -} - -/** - * @brief Boyer-Moore-Horspool algorithm for @b reverse-order exact matching of patterns up to @b 256-bytes long. - * Uses the Raita heuristic to match the first two, the last, and the middle character of the pattern. - */ -SZ_INTERNAL sz_cptr_t _sz_rfind_horspool_upto_256bytes_serial(sz_cptr_t h_chars, sz_size_t h_length, // - sz_cptr_t n_chars, sz_size_t n_length) { - sz_assert(n_length <= 256 && "The pattern is too long."); - union { - sz_u8_t jumps[256]; - sz_u64_vec_t vecs[64]; - } bad_shift_table; - - // Let's initialize the table using SWAR to the total length of the string. - sz_u8_t const *h = (sz_u8_t const *)h_chars; - sz_u8_t const *n = (sz_u8_t const *)n_chars; - { - sz_u64_vec_t n_length_vec; - n_length_vec.u64 = n_length; - n_length_vec.u64 *= 0x0101010101010101ull; // broadcast - for (sz_size_t i = 0; i != 64; ++i) bad_shift_table.vecs[i].u64 = n_length_vec.u64; - for (sz_size_t i = 0; i + 1 < n_length; ++i) - bad_shift_table.jumps[n[n_length - i - 1]] = (sz_u8_t)(n_length - i - 1); - } - - // Another common heuristic is to match a few characters from different parts of a string. - // Raita suggests to use the first two, the last, and the middle character of the pattern. - sz_u32_vec_t h_vec, n_vec; - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n_chars, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into an unsigned integer. - n_vec.u8s[0] = n[offset_first]; - n_vec.u8s[1] = n[offset_first + 1]; - n_vec.u8s[2] = n[offset_mid]; - n_vec.u8s[3] = n[offset_last]; - - // Scan through the whole haystack, skipping the first `n_length - 1` bytes. - for (sz_size_t j = 0; j <= h_length - n_length;) { - sz_size_t i = h_length - n_length - j; - h_vec.u8s[0] = h[i + offset_first]; - h_vec.u8s[1] = h[i + offset_first + 1]; - h_vec.u8s[2] = h[i + offset_mid]; - h_vec.u8s[3] = h[i + offset_last]; - if (h_vec.u32 == n_vec.u32 && sz_equal((sz_cptr_t)h + i, n_chars, n_length)) return (sz_cptr_t)h + i; - j += bad_shift_table.jumps[h[i]]; - } - return SZ_NULL_CHAR; -} - -/** - * @brief Exact substring search helper function, that finds the first occurrence of a prefix of the needle - * using a given search function, and then verifies the remaining part of the needle. 
- */ -SZ_INTERNAL sz_cptr_t _sz_find_with_prefix(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length, - sz_find_t find_prefix, sz_size_t prefix_length) { - - sz_size_t suffix_length = n_length - prefix_length; - while (1) { - sz_cptr_t found = find_prefix(h, h_length, n, prefix_length); - if (!found) return SZ_NULL_CHAR; - - // Verify the remaining part of the needle - sz_size_t remaining = h_length - (found - h); - if (remaining < n_length) return SZ_NULL_CHAR; - if (sz_equal(found + prefix_length, n + prefix_length, suffix_length)) return found; - - // Adjust the position. - h = found + 1; - h_length = remaining - 1; - } - - // Unreachable, but helps silence compiler warnings: - return SZ_NULL_CHAR; -} - -/** - * @brief Exact reverse-order substring search helper function, that finds the last occurrence of a suffix of the - * needle using a given search function, and then verifies the remaining part of the needle. - */ -SZ_INTERNAL sz_cptr_t _sz_rfind_with_suffix(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length, - sz_find_t find_suffix, sz_size_t suffix_length) { - - sz_size_t prefix_length = n_length - suffix_length; - while (1) { - sz_cptr_t found = find_suffix(h, h_length, n + prefix_length, suffix_length); - if (!found) return SZ_NULL_CHAR; - - // Verify the remaining part of the needle - sz_size_t remaining = found - h; - if (remaining < prefix_length) return SZ_NULL_CHAR; - if (sz_equal(found - prefix_length, n, prefix_length)) return found - prefix_length; - - // Adjust the position. - h_length = remaining - 1; - } - - // Unreachable, but helps silence compiler warnings: - return SZ_NULL_CHAR; -} - -SZ_INTERNAL sz_cptr_t _sz_find_over_4bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - return _sz_find_with_prefix(h, h_length, n, n_length, (sz_find_t)_sz_find_4byte_serial, 4); -} - -SZ_INTERNAL sz_cptr_t _sz_find_horspool_over_256bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, - sz_size_t n_length) { - return _sz_find_with_prefix(h, h_length, n, n_length, _sz_find_horspool_upto_256bytes_serial, 256); -} - -SZ_INTERNAL sz_cptr_t _sz_rfind_horspool_over_256bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, - sz_size_t n_length) { - return _sz_rfind_with_suffix(h, h_length, n, n_length, _sz_rfind_horspool_upto_256bytes_serial, 256); -} - -SZ_PUBLIC sz_cptr_t sz_find_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - -#if SZ_DETECT_BIG_ENDIAN - sz_find_t backends[] = { - (sz_find_t)sz_find_byte_serial, - (sz_find_t)_sz_find_horspool_upto_256bytes_serial, - (sz_find_t)_sz_find_horspool_over_256bytes_serial, - }; - - return backends[(n_length > 1) + (n_length > 256)](h, h_length, n, n_length); -#else - sz_find_t backends[] = { - // For very short strings brute-force SWAR makes sense. - (sz_find_t)sz_find_byte_serial, - (sz_find_t)_sz_find_2byte_serial, - (sz_find_t)_sz_find_3byte_serial, - (sz_find_t)_sz_find_4byte_serial, - // To avoid constructing the skip-table, let's use the prefixed approach. - (sz_find_t)_sz_find_over_4bytes_serial, - // For longer needles - use skip tables. - (sz_find_t)_sz_find_horspool_upto_256bytes_serial, - (sz_find_t)_sz_find_horspool_over_256bytes_serial, - }; - - return backends[ - // For very short strings brute-force SWAR makes sense. 
- (n_length > 1) + (n_length > 2) + (n_length > 3) + - // To avoid constructing the skip-table, let's use the prefixed approach. - (n_length > 4) + - // For longer needles - use skip tables. - (n_length > 8) + (n_length > 256)](h, h_length, n, n_length); -#endif -} - -SZ_PUBLIC sz_cptr_t sz_rfind_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - - sz_find_t backends[] = { - // For very short strings brute-force SWAR makes sense. - (sz_find_t)sz_rfind_byte_serial, - // TODO: implement reverse-order SWAR for 2/3/4 byte variants. - // TODO: (sz_find_t)_sz_rfind_2byte_serial, - // TODO: (sz_find_t)_sz_rfind_3byte_serial, - // TODO: (sz_find_t)_sz_rfind_4byte_serial, - // To avoid constructing the skip-table, let's use the prefixed approach. - // (sz_find_t)_sz_rfind_over_4bytes_serial, - // For longer needles - use skip tables. - (sz_find_t)_sz_rfind_horspool_upto_256bytes_serial, - (sz_find_t)_sz_rfind_horspool_over_256bytes_serial, - }; - - return backends[ - // For very short strings brute-force SWAR makes sense. - 0 + - // To avoid constructing the skip-table, let's use the prefixed approach. - (n_length > 1) + - // For longer needles - use skip tables. - (n_length > 256)](h, h_length, n, n_length); -} - -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_serial( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - - // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. - sz_memory_allocator_t global_alloc; - if (!alloc) { - sz_memory_allocator_init_default(&global_alloc); - alloc = &global_alloc; - } - - // TODO: Generalize to remove the following asserts! - sz_assert(!bound && "For bounded search the method should only evaluate one band of the matrix."); - sz_assert(shorter_length == longer_length && "The method hasn't been generalized to different length inputs yet."); - sz_unused(longer_length && bound); - - // We are going to store 3 diagonals of the matrix. - // The length of the longest (main) diagonal would be `n = (shorter_length + 1)`. - sz_size_t n = shorter_length + 1; - sz_size_t buffer_length = sizeof(sz_size_t) * n * 3; - sz_size_t *distances = (sz_size_t *)alloc->allocate(buffer_length, alloc->handle); - if (!distances) return SZ_SIZE_MAX; - - sz_size_t *previous_distances = distances; - sz_size_t *current_distances = previous_distances + n; - sz_size_t *next_distances = previous_distances + n * 2; - - // Initialize the first two diagonals: - previous_distances[0] = 0; - current_distances[0] = current_distances[1] = 1; - - // Progress through the upper triangle of the Levenshtein matrix. - sz_size_t next_diagonal_index = 2; - for (; next_diagonal_index != n; ++next_diagonal_index) { - sz_size_t const next_diagonal_length = next_diagonal_index + 1; - for (sz_size_t i = 0; i + 2 < next_diagonal_length; ++i) { - sz_size_t cost_of_substitution = shorter[next_diagonal_index - i - 2] != longer[i]; - sz_size_t cost_if_substitution = previous_distances[i] + cost_of_substitution; - sz_size_t cost_if_deletion_or_insertion = sz_min_of_two(current_distances[i], current_distances[i + 1]) + 1; - next_distances[i + 1] = sz_min_of_two(cost_if_deletion_or_insertion, cost_if_substitution); - } - // Don't forget to populate the first row and the first column of the Levenshtein matrix. 
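/*  Illustrative aside: why three diagonals are enough. In this anti-diagonal ordering
 *  the substitution cost of a cell comes from the diagonal two steps back, while the
 *  deletion and insertion costs come from two adjacent cells on the immediately
 *  preceding diagonal. No older diagonals are ever read, so the three buffers can be
 *  rotated in place instead of storing the whole matrix.
 */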
- next_distances[0] = next_distances[next_diagonal_length - 1] = next_diagonal_index; - // Perform a circular rotation of those buffers, to reuse the memory. - sz_size_t *temporary = previous_distances; - previous_distances = current_distances; - current_distances = next_distances; - next_distances = temporary; - } - - // By now we've scanned through the upper triangle of the matrix, where each subsequent iteration results in a - // larger diagonal. From now onwards, we will be shrinking. Instead of adding value equal to the skewed diagonal - // index on either side, we will be cropping those values out. - sz_size_t diagonals_count = n + n - 1; - for (; next_diagonal_index != diagonals_count; ++next_diagonal_index) { - sz_size_t const next_diagonal_length = diagonals_count - next_diagonal_index; - for (sz_size_t i = 0; i != next_diagonal_length; ++i) { - sz_size_t cost_of_substitution = shorter[shorter_length - 1 - i] != longer[next_diagonal_index - n + i]; - sz_size_t cost_if_substitution = previous_distances[i] + cost_of_substitution; - sz_size_t cost_if_deletion_or_insertion = sz_min_of_two(current_distances[i], current_distances[i + 1]) + 1; - next_distances[i] = sz_min_of_two(cost_if_deletion_or_insertion, cost_if_substitution); - } - // Perform a circular rotation of those buffers, to reuse the memory, this time, with a shift, - // dropping the first element in the current array. - sz_size_t *temporary = previous_distances; - previous_distances = current_distances + 1; - current_distances = next_distances; - next_distances = temporary; - } - - // Cache scalar before `free` call. - sz_size_t result = current_distances[0]; - alloc->free(distances, buffer_length, alloc->handle); - return result; -} - -/** - * @brief Describes the length of a UTF8 character / codepoint / rune in bytes. - */ -typedef enum { - sz_utf8_invalid_k = 0, //!< Invalid UTF8 character. - sz_utf8_rune_1byte_k = 1, //!< 1-byte UTF8 character. - sz_utf8_rune_2bytes_k = 2, //!< 2-byte UTF8 character. - sz_utf8_rune_3bytes_k = 3, //!< 3-byte UTF8 character. - sz_utf8_rune_4bytes_k = 4, //!< 4-byte UTF8 character. -} sz_rune_length_t; - -typedef sz_u32_t sz_rune_t; - -/** - * @brief Extracts just one UTF8 codepoint from a UTF8 string into a 32-bit unsigned integer. - */ -SZ_INTERNAL void _sz_extract_utf8_rune(sz_cptr_t utf8, sz_rune_t *code, sz_rune_length_t *code_length) { - sz_u8_t const *current = (sz_u8_t const *)utf8; - sz_u8_t leading_byte = *current++; - sz_rune_t ch; - sz_rune_length_t ch_length; - - // TODO: This can be made entirely branchless using 32-bit SWAR. - if (leading_byte < 0x80) { - // Single-byte rune (0xxxxxxx) - ch = leading_byte; - ch_length = sz_utf8_rune_1byte_k; - } - else if ((leading_byte & 0xE0) == 0xC0) { - // Two-byte rune (110xxxxx 10xxxxxx) - ch = (leading_byte & 0x1F) << 6; - ch |= (*current++ & 0x3F); - ch_length = sz_utf8_rune_2bytes_k; - } - else if ((leading_byte & 0xF0) == 0xE0) { - // Three-byte rune (1110xxxx 10xxxxxx 10xxxxxx) - ch = (leading_byte & 0x0F) << 12; - ch |= (*current++ & 0x3F) << 6; - ch |= (*current++ & 0x3F); - ch_length = sz_utf8_rune_3bytes_k; - } - else if ((leading_byte & 0xF8) == 0xF0) { - // Four-byte rune (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx) - ch = (leading_byte & 0x07) << 18; - ch |= (*current++ & 0x3F) << 12; - ch |= (*current++ & 0x3F) << 6; - ch |= (*current++ & 0x3F); - ch_length = sz_utf8_rune_4bytes_k; - } - else { - // Invalid UTF8 rune. 
- ch = 0; - ch_length = sz_utf8_invalid_k; - } - *code = ch; - *code_length = ch_length; -} - -/** - * @brief Exports a UTF8 string into a UTF32 buffer. - * ! The result is undefined id the UTF8 string is corrupted. - * @return The length in the number of codepoints. - */ -SZ_INTERNAL sz_size_t _sz_export_utf8_to_utf32(sz_cptr_t utf8, sz_size_t utf8_length, sz_rune_t *utf32) { - sz_cptr_t const end = utf8 + utf8_length; - sz_size_t count = 0; - sz_rune_length_t rune_length; - for (; utf8 != end; utf8 += rune_length, utf32++, count++) _sz_extract_utf8_rune(utf8, utf32, &rune_length); - return count; -} - -/** - * @brief Compute the Levenshtein distance between two strings using the Wagner-Fisher algorithm. - * Stores only 2 rows of the Levenshtein matrix, but uses 64-bit integers for the distance values, - * and upcasts UTF8 variable-length codepoints to 64-bit integers for faster addressing. - * - * ! In the worst case for 2 strings of length 100, that contain just one 16-bit codepoint this will result in extra: - * + 2 rows * 100 slots * 8 bytes/slot = 1600 bytes of memory for the two rows of the Levenshtein matrix rows. - * + 100 codepoints * 2 strings * 4 bytes/codepoint = 800 bytes of memory for the UTF8 buffer. - * = 2400 bytes of memory or @b 12x memory amplification! - */ -SZ_INTERNAL sz_size_t _sz_edit_distance_wagner_fisher_serial( // - sz_cptr_t longer, sz_size_t longer_length, // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_size_t bound, sz_bool_t can_be_unicode, sz_memory_allocator_t *alloc) { - - // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. - sz_memory_allocator_t global_alloc; - if (!alloc) { - sz_memory_allocator_init_default(&global_alloc); - alloc = &global_alloc; - } - - // A good idea may be to dispatch different kernels for different string lengths. - // Like using `uint8_t` counters for strings under 255 characters long. - // Good in theory, this results in frequent upcasts and downcasts in serial code. - // On strings over 20 bytes, using `uint8` over `uint64` on 64-bit x86 CPU doubles the execution time. - // So one must be very cautious with such optimizations. - typedef sz_size_t _distance_t; - - // Compute the number of columns in our Levenshtein matrix. - sz_size_t const n = shorter_length + 1; - - // If a buffering memory-allocator is provided, this operation is practically free, - // and cheaper than allocating even 512 bytes (for small distance matrices) on stack. - sz_size_t buffer_length = sizeof(_distance_t) * (n * 2); - - // If the strings contain Unicode characters, let's estimate the max character width, - // and use it to allocate a larger buffer to decode UTF8. - if ((can_be_unicode == sz_true_k) && - (sz_isascii(longer, longer_length) == sz_false_k || sz_isascii(shorter, shorter_length) == sz_false_k)) { - buffer_length += (shorter_length + longer_length) * sizeof(sz_rune_t); - } - else { can_be_unicode = sz_false_k; } - - // If the allocation fails, return the maximum distance. - sz_ptr_t const buffer = (sz_ptr_t)alloc->allocate(buffer_length, alloc->handle); - if (!buffer) return SZ_SIZE_MAX; - - // Let's export the UTF8 sequence into the newly allocated buffer at the end. - if (can_be_unicode == sz_true_k) { - sz_rune_t *const longer_utf32 = (sz_rune_t *)(buffer + sizeof(_distance_t) * (n * 2)); - sz_rune_t *const shorter_utf32 = longer_utf32 + longer_length; - // Export the UTF8 sequences into the newly allocated buffer. 
- longer_length = _sz_export_utf8_to_utf32(longer, longer_length, longer_utf32); - shorter_length = _sz_export_utf8_to_utf32(shorter, shorter_length, shorter_utf32); - longer = (sz_cptr_t)longer_utf32; - shorter = (sz_cptr_t)shorter_utf32; - } - - // Let's parameterize the core logic for different character types and distance types. -#define _wagner_fisher_unbounded(_distance_t, _char_t) \ - /* Now let's cast our pointer to avoid it in subsequent sections. */ \ - _char_t const *const longer_chars = (_char_t const *)longer; \ - _char_t const *const shorter_chars = (_char_t const *)shorter; \ - _distance_t *previous_distances = (_distance_t *)buffer; \ - _distance_t *current_distances = previous_distances + n; \ - /* Initialize the first row of the Levenshtein matrix with `iota`-style arithmetic progression. */ \ - for (_distance_t idx_shorter = 0; idx_shorter != n; ++idx_shorter) previous_distances[idx_shorter] = idx_shorter; \ - /* The main loop of the algorithm with quadratic complexity. */ \ - for (_distance_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) { \ - _char_t const longer_char = longer_chars[idx_longer]; \ - /* Using pure pointer arithmetic is faster than iterating with an index. */ \ - _char_t const *shorter_ptr = shorter_chars; \ - _distance_t const *previous_ptr = previous_distances; \ - _distance_t *current_ptr = current_distances; \ - _distance_t *const current_end = current_ptr + shorter_length; \ - current_ptr[0] = idx_longer + 1; \ - for (; current_ptr != current_end; ++previous_ptr, ++current_ptr, ++shorter_ptr) { \ - _distance_t cost_substitution = previous_ptr[0] + (_distance_t)(longer_char != shorter_ptr[0]); \ - /* We can avoid `+1` for costs here, shifting it to post-minimum computation, */ \ - /* saving one increment operation. */ \ - _distance_t cost_deletion = previous_ptr[1]; \ - _distance_t cost_insertion = current_ptr[0]; \ - /* ? It might be a good idea to enforce branchless execution here. */ \ - /* ? The caveat being that the benchmarks on longer sequences backfire and more research is needed. */ \ - current_ptr[1] = sz_min_of_two(cost_substitution, sz_min_of_two(cost_deletion, cost_insertion) + 1); \ - } \ - /* Swap `previous_distances` and `current_distances` pointers. */ \ - _distance_t *temporary = previous_distances; \ - previous_distances = current_distances; \ - current_distances = temporary; \ - } \ - /* Cache scalar before `free` call. */ \ - sz_size_t result = previous_distances[shorter_length]; \ - alloc->free(buffer, buffer_length, alloc->handle); \ - return result; - - // Let's define a separate variant for bounded distance computation. - // Practically the same as unbounded, but also collecting the running minimum within each row for early exit. 
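/*  Illustrative aside: before that bounded variant, a plain-function sketch of the same
 *  two-row recurrence without the macro plumbing, the custom allocator, or the UTF8
 *  handling. Names are hypothetical; `malloc` is used only for brevity.
 */
#include <stddef.h> /* size_t       - needed only for this sketch */
#include <stdint.h> /* SIZE_MAX     - needed only for this sketch */
#include <stdlib.h> /* malloc, free - needed only for this sketch */

static size_t levenshtein_two_rows(char const *a, size_t a_length, char const *b, size_t b_length) {
    size_t *previous = (size_t *)malloc((b_length + 1) * sizeof(size_t));
    size_t *current = (size_t *)malloc((b_length + 1) * sizeof(size_t));
    if (!previous || !current) { free(previous), free(current); return SIZE_MAX; }
    for (size_t j = 0; j <= b_length; ++j) previous[j] = j;
    for (size_t i = 0; i != a_length; ++i) {
        current[0] = i + 1;
        for (size_t j = 0; j != b_length; ++j) {
            size_t if_substitution = previous[j] + (a[i] != b[j]);
            size_t if_deletion = previous[j + 1] + 1;
            size_t if_insertion = current[j] + 1;
            size_t best = if_substitution < if_deletion ? if_substitution : if_deletion;
            current[j + 1] = best < if_insertion ? best : if_insertion;
        }
        size_t *temporary = previous;
        previous = current, current = temporary;
    }
    size_t result = previous[b_length];
    free(previous), free(current);
    return result;
}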
-#define _wagner_fisher_bounded(_distance_t, _char_t) \ - _char_t const *const longer_chars = (_char_t const *)longer; \ - _char_t const *const shorter_chars = (_char_t const *)shorter; \ - _distance_t *previous_distances = (_distance_t *)buffer; \ - _distance_t *current_distances = previous_distances + n; \ - for (_distance_t idx_shorter = 0; idx_shorter != n; ++idx_shorter) previous_distances[idx_shorter] = idx_shorter; \ - for (_distance_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) { \ - _char_t const longer_char = longer_chars[idx_longer]; \ - _char_t const *shorter_ptr = shorter_chars; \ - _distance_t const *previous_ptr = previous_distances; \ - _distance_t *current_ptr = current_distances; \ - _distance_t *const current_end = current_ptr + shorter_length; \ - current_ptr[0] = idx_longer + 1; \ - /* Initialize min_distance with a value greater than bound */ \ - _distance_t min_distance = bound - 1; \ - for (; current_ptr != current_end; ++previous_ptr, ++current_ptr, ++shorter_ptr) { \ - _distance_t cost_substitution = previous_ptr[0] + (_distance_t)(longer_char != shorter_ptr[0]); \ - _distance_t cost_deletion = previous_ptr[1]; \ - _distance_t cost_insertion = current_ptr[0]; \ - current_ptr[1] = sz_min_of_two(cost_substitution, sz_min_of_two(cost_deletion, cost_insertion) + 1); \ - /* Keep track of the minimum distance seen so far in this row */ \ - min_distance = sz_min_of_two(current_ptr[1], min_distance); \ - } \ - /* If the minimum distance in this row exceeded the bound, return early */ \ - if (min_distance >= bound) { \ - alloc->free(buffer, buffer_length, alloc->handle); \ - return bound; \ - } \ - _distance_t *temporary = previous_distances; \ - previous_distances = current_distances; \ - current_distances = temporary; \ - } \ - sz_size_t result = previous_distances[shorter_length]; \ - alloc->free(buffer, buffer_length, alloc->handle); \ - return sz_min_of_two(result, bound); - - // Dispatch the actual computation. - if (!bound) { - if (can_be_unicode == sz_true_k) { _wagner_fisher_unbounded(sz_size_t, sz_rune_t); } - else { _wagner_fisher_unbounded(sz_size_t, sz_u8_t); } - } - else { - if (can_be_unicode == sz_true_k) { _wagner_fisher_bounded(sz_size_t, sz_rune_t); } - else { _wagner_fisher_bounded(sz_size_t, sz_u8_t); } - } -} - -SZ_PUBLIC sz_size_t sz_edit_distance_serial( // - sz_cptr_t longer, sz_size_t longer_length, // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - - // Let's make sure that we use the amount proportional to the - // number of elements in the shorter string, not the larger. - if (shorter_length > longer_length) { - sz_pointer_swap((void **)&longer_length, (void **)&shorter_length); - sz_pointer_swap((void **)&longer, (void **)&shorter); - } - - // Skip the matching prefixes and suffixes, they won't affect the distance. - for (sz_cptr_t a_end = longer + longer_length, b_end = shorter + shorter_length; - longer != a_end && shorter != b_end && *longer == *shorter; - ++longer, ++shorter, --longer_length, --shorter_length); - for (; longer_length && shorter_length && longer[longer_length - 1] == shorter[shorter_length - 1]; - --longer_length, --shorter_length); - - // Bounded computations may exit early. - int const is_bounded = bound < longer_length; - if (is_bounded) { - // If one of the strings is empty - the edit distance is equal to the length of the other one. 
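/*  Illustrative aside: a minimal usage sketch, assuming an unbounded search (bound == 0)
 *  and the default libc-backed allocator (passing SZ_NULL for the allocator):
 *
 *      sz_size_t distance = sz_edit_distance_serial("kitten", 6, "sitting", 7, 0, SZ_NULL);
 *      // distance == 3 for this classic pair
 */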
- if (longer_length == 0) return sz_min_of_two(shorter_length, bound); - if (shorter_length == 0) return sz_min_of_two(longer_length, bound); - // If the difference in length is beyond the `bound`, there is no need to check at all. - if (longer_length - shorter_length > bound) return bound; - } - - if (shorter_length == 0) return longer_length; // If no mismatches were found - the distance is zero. - if (shorter_length == longer_length && !is_bounded) - return _sz_edit_distance_skewed_diagonals_serial(longer, longer_length, shorter, shorter_length, bound, alloc); - return _sz_edit_distance_wagner_fisher_serial(longer, longer_length, shorter, shorter_length, bound, sz_false_k, - alloc); -} - -SZ_PUBLIC sz_ssize_t sz_alignment_score_serial( // - sz_cptr_t longer, sz_size_t longer_length, // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, // - sz_memory_allocator_t *alloc) { - - // If one of the strings is empty - the edit distance is equal to the length of the other one - if (longer_length == 0) return (sz_ssize_t)shorter_length * gap; - if (shorter_length == 0) return (sz_ssize_t)longer_length * gap; - - // Let's make sure that we use the amount proportional to the - // number of elements in the shorter string, not the larger. - if (shorter_length > longer_length) { - sz_pointer_swap((void **)&longer_length, (void **)&shorter_length); - sz_pointer_swap((void **)&longer, (void **)&shorter); - } - - // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. - sz_memory_allocator_t global_alloc; - if (!alloc) { - sz_memory_allocator_init_default(&global_alloc); - alloc = &global_alloc; - } - - sz_size_t n = shorter_length + 1; - sz_size_t buffer_length = sizeof(sz_ssize_t) * n * 2; - sz_ssize_t *distances = (sz_ssize_t *)alloc->allocate(buffer_length, alloc->handle); - sz_ssize_t *previous_distances = distances; - sz_ssize_t *current_distances = previous_distances + n; - - for (sz_size_t idx_shorter = 0; idx_shorter != n; ++idx_shorter) - previous_distances[idx_shorter] = (sz_ssize_t)idx_shorter * gap; - - sz_u8_t const *shorter_unsigned = (sz_u8_t const *)shorter; - sz_u8_t const *longer_unsigned = (sz_u8_t const *)longer; - for (sz_size_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) { - current_distances[0] = ((sz_ssize_t)idx_longer + 1) * gap; - - // Initialize min_distance with a value greater than bound - sz_error_cost_t const *a_subs = subs + longer_unsigned[idx_longer] * 256ul; - for (sz_size_t idx_shorter = 0; idx_shorter != shorter_length; ++idx_shorter) { - sz_ssize_t cost_deletion = previous_distances[idx_shorter + 1] + gap; - sz_ssize_t cost_insertion = current_distances[idx_shorter] + gap; - sz_ssize_t cost_substitution = previous_distances[idx_shorter] + a_subs[shorter_unsigned[idx_shorter]]; - current_distances[idx_shorter + 1] = sz_max_of_three(cost_deletion, cost_insertion, cost_substitution); - } - - // Swap previous_distances and current_distances pointers - sz_pointer_swap((void **)&previous_distances, (void **)¤t_distances); - } - - // Cache scalar before `free` call. 
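/*  Illustrative aside: the `subs` argument is read as a row-major 256 x 256 matrix of
 *  costs, `subs[row_byte * 256 + column_byte]`. A hypothetical helper that fills such a
 *  matrix with 0 for matches and -1 otherwise; together with `gap = -1` the maximized
 *  score then coincides with the negated Levenshtein distance.
 */
static void fill_unit_substitution_costs(signed char *costs /* 256 * 256 entries, assuming the cost type is a signed byte */) {
    for (int row = 0; row != 256; ++row)
        for (int column = 0; column != 256; ++column)
            costs[row * 256 + column] = (signed char)(row == column ? 0 : -1);
}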
- sz_ssize_t result = previous_distances[shorter_length]; - alloc->free(distances, buffer_length, alloc->handle); - return result; -} - -SZ_PUBLIC sz_size_t sz_hamming_distance_serial( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - - sz_size_t const min_length = sz_min_of_two(a_length, b_length); - sz_size_t const max_length = sz_max_of_two(a_length, b_length); - sz_cptr_t const a_end = a + min_length; - bound = bound == 0 ? max_length : bound; - - // Walk through both strings using SWAR and counting the number of differing characters. - sz_size_t distance = max_length - min_length; -#if SZ_USE_MISALIGNED_LOADS && !SZ_DETECT_BIG_ENDIAN - if (min_length >= SZ_SWAR_THRESHOLD) { - sz_u64_vec_t a_vec, b_vec, match_vec; - for (; a + 8 <= a_end && distance < bound; a += 8, b += 8) { - a_vec.u64 = sz_u64_load(a).u64; - b_vec.u64 = sz_u64_load(b).u64; - match_vec = _sz_u64_each_byte_equal(a_vec, b_vec); - distance += sz_u64_popcount((~match_vec.u64) & 0x8080808080808080ull); - } - } -#endif - - for (; a != a_end && distance < bound; ++a, ++b) { distance += (*a != *b); } - return sz_min_of_two(distance, bound); -} - -SZ_PUBLIC sz_size_t sz_hamming_distance_utf8_serial( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - - sz_cptr_t const a_end = a + a_length; - sz_cptr_t const b_end = b + b_length; - sz_size_t distance = 0; - - sz_rune_t a_rune, b_rune; - sz_rune_length_t a_rune_length, b_rune_length; - - if (bound) { - for (; a < a_end && b < b_end && distance < bound; a += a_rune_length, b += b_rune_length) { - _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - distance += (a_rune != b_rune); - } - // If one string has more runes, we need to go through the tail. - if (distance < bound) { - for (; a < a_end && distance < bound; a += a_rune_length, ++distance) - _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - - for (; b < b_end && distance < bound; b += b_rune_length, ++distance) - _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - } - } - else { - for (; a < a_end && b < b_end; a += a_rune_length, b += b_rune_length) { - _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - distance += (a_rune != b_rune); - } - // If one string has more runes, we need to go through the tail. - for (; a < a_end; a += a_rune_length, ++distance) _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - for (; b < b_end; b += b_rune_length, ++distance) _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - } - return distance; -} - -SZ_PUBLIC sz_u64_t sz_checksum_serial(sz_cptr_t text, sz_size_t length) { - sz_u64_t checksum = 0; - sz_u8_t const *text_u8 = (sz_u8_t const *)text; - sz_u8_t const *text_end = text_u8 + length; - for (; text_u8 != text_end; ++text_u8) checksum += *text_u8; - return checksum; -} - -/** - * @brief Largest prime number that fits into 31 bits. - * @see https://mersenneforum.org/showthread.php?t=3471 - */ -#define SZ_U32_MAX_PRIME (2147483647u) - -/** - * @brief Largest prime number that fits into 64 bits. - * @see https://mersenneforum.org/showthread.php?t=3471 - * - * 2^64 = 18,446,744,073,709,551,616 - * this = 18,446,744,073,709,551,557 - * diff = 59 - */ -#define SZ_U64_MAX_PRIME (18446744073709551557ull) - -/* - * One hardware-accelerated way of mixing hashes can be CRC, but it's only implemented for 32-bit values. 
- * Using a Boost-like mixer works very poorly in such case: - * - * hash_first ^ (hash_second + 0x517cc1b727220a95 + (hash_first << 6) + (hash_first >> 2)); - * - * Let's stick to the Fibonacci hash trick using the golden ratio. - * https://probablydance.com/2018/06/16/fibonacci-hashing-the-optimization-that-the-world-forgot-or-a-better-alternative-to-integer-modulo/ - */ -#define _sz_hash_mix(first, second) ((first * 11400714819323198485ull) ^ (second * 11400714819323198485ull)) -#define _sz_shift_low(x) (x) -#define _sz_shift_high(x) ((x + 77ull) & 0xFFull) -#define _sz_prime_mod(x) (x % SZ_U64_MAX_PRIME) - -SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t start, sz_size_t length) { - - sz_u64_t hash_low = 0; - sz_u64_t hash_high = 0; - sz_u8_t const *text = (sz_u8_t const *)start; - sz_u8_t const *text_end = text + length; - - switch (length) { - case 0: return 0; - - // Texts under 7 bytes long are definitely below the largest prime. - case 1: - hash_low = _sz_shift_low(text[0]); - hash_high = _sz_shift_high(text[0]); - break; - case 2: - hash_low = _sz_shift_low(text[0]) * 31ull + _sz_shift_low(text[1]); - hash_high = _sz_shift_high(text[0]) * 257ull + _sz_shift_high(text[1]); - break; - case 3: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull + // - _sz_shift_low(text[2]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull + // - _sz_shift_high(text[2]); - break; - case 4: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull + // - _sz_shift_low(text[3]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull + // - _sz_shift_high(text[3]); - break; - case 5: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull * 31ull + // - _sz_shift_low(text[3]) * 31ull + // - _sz_shift_low(text[4]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull * 257ull + // - _sz_shift_high(text[3]) * 257ull + // - _sz_shift_high(text[4]); - break; - case 6: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[3]) * 31ull * 31ull + // - _sz_shift_low(text[4]) * 31ull + // - _sz_shift_low(text[5]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[3]) * 257ull * 257ull + // - _sz_shift_high(text[4]) * 257ull + // - _sz_shift_high(text[5]); - break; - case 7: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[3]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[4]) * 31ull * 31ull + // - _sz_shift_low(text[5]) * 31ull + // - _sz_shift_low(text[6]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull 
* 257ull * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[3]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[4]) * 257ull * 257ull + // - _sz_shift_high(text[5]) * 257ull + // - _sz_shift_high(text[6]); - break; - default: - // Unroll the first seven cycles: - hash_low = hash_low * 31ull + _sz_shift_low(text[0]); - hash_high = hash_high * 257ull + _sz_shift_high(text[0]); - hash_low = hash_low * 31ull + _sz_shift_low(text[1]); - hash_high = hash_high * 257ull + _sz_shift_high(text[1]); - hash_low = hash_low * 31ull + _sz_shift_low(text[2]); - hash_high = hash_high * 257ull + _sz_shift_high(text[2]); - hash_low = hash_low * 31ull + _sz_shift_low(text[3]); - hash_high = hash_high * 257ull + _sz_shift_high(text[3]); - hash_low = hash_low * 31ull + _sz_shift_low(text[4]); - hash_high = hash_high * 257ull + _sz_shift_high(text[4]); - hash_low = hash_low * 31ull + _sz_shift_low(text[5]); - hash_high = hash_high * 257ull + _sz_shift_high(text[5]); - hash_low = hash_low * 31ull + _sz_shift_low(text[6]); - hash_high = hash_high * 257ull + _sz_shift_high(text[6]); - text += 7; - - // Iterate throw the rest with the modulus: - for (; text != text_end; ++text) { - hash_low = hash_low * 31ull + _sz_shift_low(text[0]); - hash_high = hash_high * 257ull + _sz_shift_high(text[0]); - // Wrap the hashes around: - hash_low = _sz_prime_mod(hash_low); - hash_high = _sz_prime_mod(hash_high); - } - break; - } - - return _sz_hash_mix(hash_low, hash_high); -} - -SZ_PUBLIC void sz_hashes_serial(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle) { - - if (length < window_length || !window_length) return; - sz_u8_t const *text = (sz_u8_t const *)start; - sz_u8_t const *text_end = text + length; - - // Prepare the `prime ^ window_length` values, that we are going to use for modulo arithmetic. - sz_u64_t prime_power_low = 1, prime_power_high = 1; - for (sz_size_t i = 0; i + 1 < window_length; ++i) - prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME, - prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME; - - // Compute the initial hash value for the first window. - sz_u64_t hash_low = 0, hash_high = 0, hash_mix; - for (sz_u8_t const *first_end = text + window_length; text < first_end; ++text) - hash_low = (hash_low * 31ull + _sz_shift_low(*text)) % SZ_U64_MAX_PRIME, - hash_high = (hash_high * 257ull + _sz_shift_high(*text)) % SZ_U64_MAX_PRIME; - - // In most cases the fingerprint length will be a power of two. - hash_mix = _sz_hash_mix(hash_low, hash_high); - callback((sz_cptr_t)text, window_length, hash_mix, callback_handle); - - // Compute the hash value for every window, exporting into the fingerprint, - // using the expensive modulo operation. - sz_size_t cycles = 1; - sz_size_t const step_mask = step - 1; - for (; text < text_end; ++text, ++cycles) { - // Discard one character: - hash_low -= _sz_shift_low(*(text - window_length)) * prime_power_low; - hash_high -= _sz_shift_high(*(text - window_length)) * prime_power_high; - // And add a new one: - hash_low = 31ull * hash_low + _sz_shift_low(*text); - hash_high = 257ull * hash_high + _sz_shift_high(*text); - // Wrap the hashes around: - hash_low = _sz_prime_mod(hash_low); - hash_high = _sz_prime_mod(hash_high); - // Mix only if we've skipped enough hashes. 
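// Note on the masking below: `step_mask = step - 1` turns the modulo check into a single AND,
// which behaves as "every `step`-th window" only when `step` is a power of two. An illustrative
// trace, assuming `step == 4` (an example value, not a requirement of the API):
//
//     cycles:           1  2  3  4  5  6  7  8 ...
//     cycles & (4 - 1): 1  2  3  0  1  2  3  0 ...
//
// so the callback fires on every fourth window. For a non-power-of-two `step`, the same AND fires
// at a different, typically denser, cadence.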
- if ((cycles & step_mask) == 0) { - hash_mix = _sz_hash_mix(hash_low, hash_high); - callback((sz_cptr_t)text, window_length, hash_mix, callback_handle); - } - } -} - -#undef _sz_shift_low -#undef _sz_shift_high -#undef _sz_hash_mix -#undef _sz_prime_mod - -/** - * @brief Uses a small lookup-table to convert a lowercase character to uppercase. - */ -SZ_INTERNAL sz_u8_t sz_u8_tolower(sz_u8_t c) { - static sz_u8_t const lowered[256] = { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, // - 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, // - 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, // - 64, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, // - 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 91, 92, 93, 94, 95, // - 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, // - 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, // - 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, // - 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, // - 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, // - 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, // - 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // - 240, 241, 242, 243, 244, 245, 246, 215, 248, 249, 250, 251, 252, 253, 254, 223, // - 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // - 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, // - }; - return lowered[c]; -} - -/** - * @brief Uses a small lookup-table to convert an uppercase character to lowercase. - */ -SZ_INTERNAL sz_u8_t sz_u8_toupper(sz_u8_t c) { - static sz_u8_t const upped[256] = { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, // - 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, // - 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, // - 64, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, // - 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 91, 92, 93, 94, 95, // - 96, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, // - 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 123, 124, 125, 126, 127, // - 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, // - 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, // - 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, // - 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, // - 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // - 240, 241, 242, 243, 244, 245, 246, 215, 248, 249, 250, 251, 252, 253, 254, 223, // - 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // - 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, // - }; - return upped[c]; -} - -/** - * @brief Uses two small lookup tables (768 bytes total) to accelerate division by a small - * unsigned integer. Performs two lookups, one multiplication, two shifts, and two accumulations. - * - * @param divisor Integral value @b larger than one. - * @param number Integral value to divide. 
- */ -SZ_INTERNAL sz_u8_t sz_u8_divide(sz_u8_t number, sz_u8_t divisor) { - sz_assert(divisor > 1); - static sz_u16_t const multipliers[256] = { - 0, 0, 0, 21846, 0, 39322, 21846, 9363, 0, 50973, 39322, 29790, 21846, 15124, 9363, 4370, - 0, 57826, 50973, 44841, 39322, 34329, 29790, 25645, 21846, 18351, 15124, 12137, 9363, 6780, 4370, 2115, - 0, 61565, 57826, 54302, 50973, 47824, 44841, 42011, 39322, 36765, 34329, 32006, 29790, 27671, 25645, 23705, - 21846, 20063, 18351, 16706, 15124, 13602, 12137, 10725, 9363, 8049, 6780, 5554, 4370, 3224, 2115, 1041, - 0, 63520, 61565, 59668, 57826, 56039, 54302, 52614, 50973, 49377, 47824, 46313, 44841, 43407, 42011, 40649, - 39322, 38028, 36765, 35532, 34329, 33154, 32006, 30885, 29790, 28719, 27671, 26647, 25645, 24665, 23705, 22766, - 21846, 20945, 20063, 19198, 18351, 17520, 16706, 15907, 15124, 14356, 13602, 12863, 12137, 11424, 10725, 10038, - 9363, 8700, 8049, 7409, 6780, 6162, 5554, 4957, 4370, 3792, 3224, 2665, 2115, 1573, 1041, 517, - 0, 64520, 63520, 62535, 61565, 60609, 59668, 58740, 57826, 56926, 56039, 55164, 54302, 53452, 52614, 51788, - 50973, 50169, 49377, 48595, 47824, 47063, 46313, 45572, 44841, 44120, 43407, 42705, 42011, 41326, 40649, 39982, - 39322, 38671, 38028, 37392, 36765, 36145, 35532, 34927, 34329, 33738, 33154, 32577, 32006, 31443, 30885, 30334, - 29790, 29251, 28719, 28192, 27671, 27156, 26647, 26143, 25645, 25152, 24665, 24182, 23705, 23233, 22766, 22303, - 21846, 21393, 20945, 20502, 20063, 19628, 19198, 18772, 18351, 17933, 17520, 17111, 16706, 16305, 15907, 15514, - 15124, 14738, 14356, 13977, 13602, 13231, 12863, 12498, 12137, 11779, 11424, 11073, 10725, 10380, 10038, 9699, - 9363, 9030, 8700, 8373, 8049, 7727, 7409, 7093, 6780, 6470, 6162, 5857, 5554, 5254, 4957, 4662, - 4370, 4080, 3792, 3507, 3224, 2943, 2665, 2388, 2115, 1843, 1573, 1306, 1041, 778, 517, 258, - }; - // This table can be avoided using a single addition and counting trailing zeros. 
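// For reference, the `shifts` table below is a precomputed `floor(log2(divisor - 1))`, i.e. the
// shift for a divisor `d` in [2, 256] equals `ceil(log2(d)) - 1`. One branch-free way to compute it
// on the fly (a sketch, not necessarily the exact trick meant above) is a leading-zero count:
//
//     sz_u8_t shift_for_divisor(sz_u8_t divisor) { // expects divisor > 1
//         return (sz_u8_t)(63 - sz_u64_clz((sz_u64_t)divisor - 1));
//     }
//
// where `sz_u64_clz` stands for any helper counting the leading zeros of a non-zero 64-bit value.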
- static sz_u8_t const shifts[256] = { - 0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, // - 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, // - 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, // - 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, // - 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // - 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // - 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // - 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // - }; - sz_u32_t multiplier = multipliers[divisor]; - sz_u8_t shift = shifts[divisor]; - - sz_u16_t q = (sz_u16_t)((multiplier * number) >> 16); - sz_u16_t t = ((number - q) >> 1) + q; - return (sz_u8_t)(t >> shift); -} - -SZ_PUBLIC void sz_look_up_transform_serial(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result) { - sz_u8_t const *unsigned_lut = (sz_u8_t const *)lut; - sz_u8_t const *unsigned_text = (sz_u8_t const *)text; - sz_u8_t *unsigned_result = (sz_u8_t *)result; - sz_u8_t const *end = unsigned_text + length; - for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = unsigned_lut[*unsigned_text]; -} - -SZ_PUBLIC void sz_tolower_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) { - sz_u8_t *unsigned_result = (sz_u8_t *)result; - sz_u8_t const *unsigned_text = (sz_u8_t const *)text; - sz_u8_t const *end = unsigned_text + length; - for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = sz_u8_tolower(*unsigned_text); -} - -SZ_PUBLIC void sz_toupper_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) { - sz_u8_t *unsigned_result = (sz_u8_t *)result; - sz_u8_t const *unsigned_text = (sz_u8_t const *)text; - sz_u8_t const *end = unsigned_text + length; - for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = sz_u8_toupper(*unsigned_text); -} - -SZ_PUBLIC void sz_toascii_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) { - sz_u8_t *unsigned_result = (sz_u8_t *)result; - sz_u8_t const *unsigned_text = (sz_u8_t const *)text; - sz_u8_t const *end = unsigned_text + length; - for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = *unsigned_text & 0x7F; -} - -/** - * @brief Check if there is a byte in this buffer, that exceeds 127 and can't be an ASCII character. - * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. - */ -SZ_PUBLIC sz_bool_t sz_isascii_serial(sz_cptr_t text, sz_size_t length) { - - if (!length) return sz_true_k; - sz_u8_t const *h = (sz_u8_t const *)text; - sz_u8_t const *const h_end = h + length; - -#if !SZ_USE_MISALIGNED_LOADS - // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)h & 7ull) && h < h_end; ++h) - if (*h & 0x80ull) return sz_false_k; -#endif - - // Validate eight bytes at once using SWAR. - sz_u64_vec_t text_vec; - for (; h + 8 <= h_end; h += 8) { - text_vec.u64 = *(sz_u64_t const *)h; - if (text_vec.u64 & 0x8080808080808080ull) return sz_false_k; - } - - // Handle the misaligned tail. 
- for (; h < h_end; ++h) - if (*h & 0x80ull) return sz_false_k; - return sz_true_k; -} - -SZ_PUBLIC void sz_generate_serial(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length, - sz_random_generator_t generator, void *generator_user_data) { - - sz_assert(alphabet_size > 0 && alphabet_size <= 256 && "Inadequate alphabet size"); - - if (alphabet_size == 1) sz_fill(result, result_length, *alphabet); - - else { - sz_assert(generator && "Expects a valid random generator"); - sz_u8_t divisor = (sz_u8_t)alphabet_size; - for (sz_cptr_t end = result + result_length; result != end; ++result) { - sz_u8_t random = generator(generator_user_data) & 0xFF; - sz_u8_t quotient = sz_u8_divide(random, divisor); - *result = alphabet[random - quotient * divisor]; - } - } -} - -#pragma endregion - -/* - * Serial implementation of string class operations. - */ -#pragma region Serial Implementation for the String Class - -SZ_PUBLIC sz_bool_t sz_string_is_on_stack(sz_string_t const *string) { - // It doesn't matter if it's on stack or heap, the pointer location is the same. - return (sz_bool_t)((sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0]); -} - -SZ_PUBLIC void sz_string_range(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length) { - sz_size_t is_small = (sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0]; - sz_size_t is_big_mask = is_small - 1ull; - *start = string->external.start; // It doesn't matter if it's on stack or heap, the pointer location is the same. - // If the string is small, use branch-less approach to mask-out the top 7 bytes of the length. - *length = string->external.length & (0x00000000000000FFull | is_big_mask); -} - -SZ_PUBLIC void sz_string_unpack(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length, sz_size_t *space, - sz_bool_t *is_external) { - sz_size_t is_small = (sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0]; - sz_size_t is_big_mask = is_small - 1ull; - *start = string->external.start; // It doesn't matter if it's on stack or heap, the pointer location is the same. - // If the string is small, use branch-less approach to mask-out the top 7 bytes of the length. - *length = string->external.length & (0x00000000000000FFull | is_big_mask); - // In case the string is small, the `is_small - 1ull` will become 0xFFFFFFFFFFFFFFFFull. - *space = sz_u64_blend(SZ_STRING_INTERNAL_SPACE, string->external.space, is_big_mask); - *is_external = (sz_bool_t)!is_small; -} - -SZ_PUBLIC sz_bool_t sz_string_equal(sz_string_t const *a, sz_string_t const *b) { - // Tempting to say that the external.length is bitwise the same even if it includes - // some bytes of the on-stack payload, but we don't at this writing maintain that invariant. - // (An on-stack string includes noise bytes in the high-order bits of external.length. So do this - // the hard/correct way. - -#if SZ_USE_MISALIGNED_LOADS - // Dealing with StringZilla strings, we know that the `start` pointer always points - // to a word at least 8 bytes long. Therefore, we can compare the first 8 bytes at once. - -#endif - // Alternatively, fall back to byte-by-byte comparison. 
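// The `sz_string_range` calls below rely on the branch-less small-string unpacking defined above.
// A worked example of that masking, assuming the little-endian layout where `internal.length`
// aliases the lowest byte of `external.length`:
//
//     on-stack string: is_small = 1, is_big_mask = 0x0000000000000000,
//                      so `external.length & (0xFF | 0)` keeps only the low byte, the on-stack length;
//     heap string:     is_small = 0, is_big_mask = 0xFFFFFFFFFFFFFFFF,
//                      so the mask is all ones and the full 64-bit length survives.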
- sz_ptr_t a_start, b_start; - sz_size_t a_length, b_length; - sz_string_range(a, &a_start, &a_length); - sz_string_range(b, &b_start, &b_length); - return (sz_bool_t)(a_length == b_length && sz_equal(a_start, b_start, b_length)); -} - -SZ_PUBLIC sz_ordering_t sz_string_order(sz_string_t const *a, sz_string_t const *b) { -#if SZ_USE_MISALIGNED_LOADS - // Dealing with StringZilla strings, we know that the `start` pointer always points - // to a word at least 8 bytes long. Therefore, we can compare the first 8 bytes at once. - -#endif - // Alternatively, fall back to byte-by-byte comparison. - sz_ptr_t a_start, b_start; - sz_size_t a_length, b_length; - sz_string_range(a, &a_start, &a_length); - sz_string_range(b, &b_start, &b_length); - return sz_order(a_start, a_length, b_start, b_length); -} - -SZ_PUBLIC void sz_string_init(sz_string_t *string) { - sz_assert(string && "String can't be SZ_NULL."); - - // Only 8 + 1 + 1 need to be initialized. - string->internal.start = &string->internal.chars[0]; - // But for safety let's initialize the entire structure to zeros. - // string->internal.chars[0] = 0; - // string->internal.length = 0; - string->words[1] = 0; - string->words[2] = 0; - string->words[3] = 0; -} - -SZ_PUBLIC sz_ptr_t sz_string_init_length(sz_string_t *string, sz_size_t length, sz_memory_allocator_t *allocator) { - sz_size_t space_needed = length + 1; // space for trailing \0 - sz_assert(string && allocator && "String and allocator can't be SZ_NULL."); - // Initialize the string to zeros for safety. - string->words[1] = 0; - string->words[2] = 0; - string->words[3] = 0; - // If we are lucky, no memory allocations will be needed. - if (space_needed <= SZ_STRING_INTERNAL_SPACE) { - string->internal.start = &string->internal.chars[0]; - string->internal.length = (sz_u8_t)length; - } - else { - // If we are not lucky, we need to allocate memory. - string->external.start = (sz_ptr_t)allocator->allocate(space_needed, allocator->handle); - if (!string->external.start) return SZ_NULL_CHAR; - string->external.length = length; - string->external.space = space_needed; - } - sz_assert(&string->internal.start == &string->external.start && "Alignment confusion"); - string->external.start[length] = 0; - return string->external.start; -} - -SZ_PUBLIC sz_ptr_t sz_string_reserve(sz_string_t *string, sz_size_t new_capacity, sz_memory_allocator_t *allocator) { - - sz_assert(string && allocator && "Strings and allocators can't be SZ_NULL."); - - sz_size_t new_space = new_capacity + 1; - if (new_space <= SZ_STRING_INTERNAL_SPACE) return string->external.start; - - sz_ptr_t string_start; - sz_size_t string_length; - sz_size_t string_space; - sz_bool_t string_is_external; - sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - sz_assert(new_space > string_space && "New space must be larger than current."); - - sz_ptr_t new_start = (sz_ptr_t)allocator->allocate(new_space, allocator->handle); - if (!new_start) return SZ_NULL_CHAR; - - sz_copy(new_start, string_start, string_length); - string->external.start = new_start; - string->external.space = new_space; - string->external.padding = 0; - string->external.length = string_length; - - // Deallocate the old string. 
- if (string_is_external) allocator->free(string_start, string_space, allocator->handle); - return string->external.start; -} - -SZ_PUBLIC sz_ptr_t sz_string_shrink_to_fit(sz_string_t *string, sz_memory_allocator_t *allocator) { - - sz_assert(string && allocator && "Strings and allocators can't be SZ_NULL."); - - sz_ptr_t string_start; - sz_size_t string_length; - sz_size_t string_space; - sz_bool_t string_is_external; - sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - - // We may already be space-optimal, and in that case we don't need to do anything. - sz_size_t new_space = string_length + 1; - if (string_space == new_space || !string_is_external) return string->external.start; - - sz_ptr_t new_start = (sz_ptr_t)allocator->allocate(new_space, allocator->handle); - if (!new_start) return SZ_NULL_CHAR; - - sz_copy(new_start, string_start, string_length); - string->external.start = new_start; - string->external.space = new_space; - string->external.padding = 0; - string->external.length = string_length; - - // Deallocate the old string. - if (string_is_external) allocator->free(string_start, string_space, allocator->handle); - return string->external.start; -} - -SZ_PUBLIC sz_ptr_t sz_string_expand(sz_string_t *string, sz_size_t offset, sz_size_t added_length, - sz_memory_allocator_t *allocator) { - - sz_assert(string && allocator && "String and allocator can't be SZ_NULL."); - - sz_ptr_t string_start; - sz_size_t string_length; - sz_size_t string_space; - sz_bool_t string_is_external; - sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - - // The user intended to extend the string. - offset = sz_min_of_two(offset, string_length); - - // If we are lucky, no memory allocations will be needed. - if (string_length + added_length < string_space) { - sz_move(string_start + offset + added_length, string_start + offset, string_length - offset); - string_start[string_length + added_length] = 0; - // Even if the string is on the stack, the `+=` won't affect the tail of the string. - string->external.length += added_length; - } - // If we are not lucky, we need to allocate more memory. - else { - sz_size_t next_planned_size = sz_max_of_two(SZ_CACHE_LINE_WIDTH, string_space * 2ull); - sz_size_t min_needed_space = sz_size_bit_ceil(offset + string_length + added_length + 1); - sz_size_t new_space = sz_max_of_two(min_needed_space, next_planned_size); - string_start = sz_string_reserve(string, new_space - 1, allocator); - if (!string_start) return SZ_NULL_CHAR; - - // Copy into the new buffer. - sz_move(string_start + offset + added_length, string_start + offset, string_length - offset); - string_start[string_length + added_length] = 0; - string->external.length = string_length + added_length; - } - - return string_start; -} - -SZ_PUBLIC sz_size_t sz_string_erase(sz_string_t *string, sz_size_t offset, sz_size_t length) { - - sz_assert(string && "String can't be SZ_NULL."); - - sz_ptr_t string_start; - sz_size_t string_length; - sz_size_t string_space; - sz_bool_t string_is_external; - sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - - // Normalize the offset, it can't be larger than the length. - offset = sz_min_of_two(offset, string_length); - - // We shouldn't normalize the length, to avoid overflowing on `offset + length >= string_length`, - // if receiving `length == SZ_SIZE_MAX`. 
After the following expression, the `length` will contain - // exactly the delta between the original and final length of this `string`. - length = sz_min_of_two(length, string_length - offset); - - // There are 2 common cases that wouldn't even require a `memmove`: - // 1. Erasing the entire contents of the string. - // In that case the `length` argument will be equal to or greater than the `length` member. - // 2. Removing the tail of the string with something like `string.pop_back()` in C++. - // - // In both of those, regardless of the location of the string - stack or heap, - // the erasing is as easy as setting the length to the offset. - // In every other case, we must `memmove` the tail of the string to the left. - if (offset + length < string_length) - sz_move(string_start + offset, string_start + offset + length, string_length - offset - length); - - // The `string->external.length = offset` assignment would discard the last characters - // of the on-the-stack string, but an in-place subtraction works. - string->external.length -= length; - string_start[string_length - length] = 0; - return length; -} - -SZ_PUBLIC void sz_string_free(sz_string_t *string, sz_memory_allocator_t *allocator) { - if (!sz_string_is_on_stack(string)) - allocator->free(string->external.start, string->external.space, allocator->handle); - sz_string_init(string); -} - -// When overriding libc, disable optimisations for this function because MSVC will optimize the loops into a memset, -// which then causes a stack overflow due to infinite recursion (memset -> sz_fill_serial -> memset). -#if defined(_MSC_VER) && defined(SZ_OVERRIDE_LIBC) && SZ_OVERRIDE_LIBC -#pragma optimize("", off) -#endif -SZ_PUBLIC void sz_fill_serial(sz_ptr_t target, sz_size_t length, sz_u8_t value) { - sz_ptr_t end = target + length; - // Dealing with short strings, a single sequential pass would be faster. - // If the size is larger than 2 words, then at least 1 of them will be aligned. - // But just one aligned word may not be worth SWAR. - if (length < SZ_SWAR_THRESHOLD) - while (target != end) *(target++) = value; - - // In case of long strings, skip unaligned bytes, and then fill the rest in 64-bit chunks. - else { - sz_u64_t value64 = (sz_u64_t)value * 0x0101010101010101ull; - while ((sz_size_t)target & 7ull) *(target++) = value; - while (target + 8 <= end) *(sz_u64_t *)target = value64, target += 8; - while (target != end) *(target++) = value; - } -} -#if defined(_MSC_VER) && defined(SZ_OVERRIDE_LIBC) && SZ_OVERRIDE_LIBC -#pragma optimize("", on) -#endif - -SZ_PUBLIC void sz_copy_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { -#if SZ_USE_MISALIGNED_LOADS - while (length >= 8) *(sz_u64_t *)target = *(sz_u64_t const *)source, target += 8, source += 8, length -= 8; -#endif - while (length--) *(target++) = *(source++); -} - -SZ_PUBLIC void sz_move_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - // Implementing `memmove` is trickier than `memcpy`, as the ranges may overlap. - // Existing implementations often have two passes, in normal and reversed order, - // depending on the relation of `target` and `source` addresses. - // https://student.cs.uwaterloo.ca/~cs350/common/os161-src-html/doxygen/html/memmove_8c_source.html - // https://marmota.medium.com/c-language-making-memmove-def8792bb8d5 - // - // We can use the `memcpy`-like left-to-right pass if we know that the `target` is before `source`. - // Or if we know that they don't intersect!
In that case the traversal order is irrelevant, - // but older CPUs may predict and fetch forward-passes better. - if (target < source || target >= source + length) { -#if SZ_USE_MISALIGNED_LOADS - while (length >= 8) *(sz_u64_t *)target = *(sz_u64_t const *)(source), target += 8, source += 8, length -= 8; -#endif - while (length--) *(target++) = *(source++); - } - else { - // Jump to the end and walk backwards. - target += length, source += length; -#if SZ_USE_MISALIGNED_LOADS - while (length >= 8) *(sz_u64_t *)(target -= 8) = *(sz_u64_t const *)(source -= 8), length -= 8; -#endif - while (length--) *(--target) = *(--source); - } -} - -#pragma endregion - -/* - * @brief Serial implementation for strings sequence processing. - */ -#pragma region Serial Implementation for Sequences - -SZ_PUBLIC sz_size_t sz_partition(sz_sequence_t *sequence, sz_sequence_predicate_t predicate) { - - sz_size_t matches = 0; - while (matches != sequence->count && predicate(sequence, sequence->order[matches])) ++matches; - - for (sz_size_t i = matches + 1; i < sequence->count; ++i) - if (predicate(sequence, sequence->order[i])) - sz_u64_swap(sequence->order + i, sequence->order + matches), ++matches; - - return matches; -} - -SZ_PUBLIC void sz_merge(sz_sequence_t *sequence, sz_size_t partition, sz_sequence_comparator_t less) { - - sz_size_t start_b = partition + 1; - - // If the direct merge is already sorted - if (!less(sequence, sequence->order[start_b], sequence->order[partition])) return; - - sz_size_t start_a = 0; - while (start_a <= partition && start_b <= sequence->count) { - - // If element 1 is in right place - if (!less(sequence, sequence->order[start_b], sequence->order[start_a])) { start_a++; } - else { - sz_size_t value = sequence->order[start_b]; - sz_size_t index = start_b; - - // Shift all the elements between element 1 - // element 2, right by 1. 
- while (index != start_a) { sequence->order[index] = sequence->order[index - 1], index--; } - sequence->order[start_a] = value; - - // Update all the pointers - start_a++; - partition++; - start_b++; - } - } -} - -SZ_PUBLIC void sz_sort_insertion(sz_sequence_t *sequence, sz_sequence_comparator_t less) { - sz_u64_t *keys = sequence->order; - sz_size_t keys_count = sequence->count; - for (sz_size_t i = 1; i < keys_count; i++) { - sz_u64_t i_key = keys[i]; - sz_size_t j = i; - for (; j > 0 && less(sequence, i_key, keys[j - 1]); --j) keys[j] = keys[j - 1]; - keys[j] = i_key; - } -} - -SZ_INTERNAL void _sz_sift_down(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_u64_t *order, sz_size_t start, - sz_size_t end) { - sz_size_t root = start; - while (2 * root + 1 <= end) { - sz_size_t child = 2 * root + 1; - if (child + 1 <= end && less(sequence, order[child], order[child + 1])) { child++; } - if (!less(sequence, order[root], order[child])) { return; } - sz_u64_swap(order + root, order + child); - root = child; - } -} - -SZ_INTERNAL void _sz_heapify(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_u64_t *order, sz_size_t count) { - sz_size_t start = (count - 2) / 2; - while (1) { - _sz_sift_down(sequence, less, order, start, count - 1); - if (start == 0) return; - start--; - } -} - -SZ_INTERNAL void _sz_heapsort(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_size_t first, sz_size_t last) { - sz_u64_t *order = sequence->order; - sz_size_t count = last - first; - _sz_heapify(sequence, less, order + first, count); - sz_size_t end = count - 1; - while (end > 0) { - sz_u64_swap(order + first, order + first + end); - end--; - _sz_sift_down(sequence, less, order + first, 0, end); - } -} - -SZ_PUBLIC void sz_sort_introsort_recursion(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_size_t first, - sz_size_t last, sz_size_t depth) { - - sz_size_t length = last - first; - switch (length) { - case 0: - case 1: return; - case 2: - if (less(sequence, sequence->order[first + 1], sequence->order[first])) - sz_u64_swap(&sequence->order[first], &sequence->order[first + 1]); - return; - case 3: { - sz_u64_t a = sequence->order[first]; - sz_u64_t b = sequence->order[first + 1]; - sz_u64_t c = sequence->order[first + 2]; - if (less(sequence, b, a)) sz_u64_swap(&a, &b); - if (less(sequence, c, b)) sz_u64_swap(&c, &b); - if (less(sequence, b, a)) sz_u64_swap(&a, &b); - sequence->order[first] = a; - sequence->order[first + 1] = b; - sequence->order[first + 2] = c; - return; - } - } - // Until a certain length, the quadratic-complexity insertion-sort is fine - if (length <= 16) { - sz_sequence_t sub_seq = *sequence; - sub_seq.order += first; - sub_seq.count = length; - sz_sort_insertion(&sub_seq, less); - return; - } - - // Fallback to N-logN-complexity heap-sort - if (depth == 0) { - _sz_heapsort(sequence, less, first, last); - return; - } - - --depth; - - // Median-of-three logic to choose pivot - sz_size_t median = first + length / 2; - if (less(sequence, sequence->order[median], sequence->order[first])) - sz_u64_swap(&sequence->order[first], &sequence->order[median]); - if (less(sequence, sequence->order[last - 1], sequence->order[first])) - sz_u64_swap(&sequence->order[first], &sequence->order[last - 1]); - if (less(sequence, sequence->order[median], sequence->order[last - 1])) - sz_u64_swap(&sequence->order[median], &sequence->order[last - 1]); - - // Partition using the median-of-three as the pivot - sz_u64_t pivot = sequence->order[median]; - sz_size_t left = first; - 
sz_size_t right = last - 1; - while (1) { - while (less(sequence, sequence->order[left], pivot)) left++; - while (less(sequence, pivot, sequence->order[right])) right--; - if (left >= right) break; - sz_u64_swap(&sequence->order[left], &sequence->order[right]); - left++; - right--; - } - - // Recursively sort the partitions - sz_sort_introsort_recursion(sequence, less, first, left, depth); - sz_sort_introsort_recursion(sequence, less, right + 1, last, depth); -} - -SZ_PUBLIC void sz_sort_introsort(sz_sequence_t *sequence, sz_sequence_comparator_t less) { - if (sequence->count == 0) return; - sz_size_t size_is_not_power_of_two = (sequence->count & (sequence->count - 1)) != 0; - sz_size_t depth_limit = sz_size_log2i_nonzero(sequence->count) + size_is_not_power_of_two; - sz_sort_introsort_recursion(sequence, less, 0, sequence->count, depth_limit); -} - -SZ_PUBLIC void sz_sort_recursion( // - sz_sequence_t *sequence, sz_size_t bit_idx, sz_size_t bit_max, sz_sequence_comparator_t comparator, - sz_size_t partial_order_length) { - - if (!sequence->count) return; - - // Array of size one doesn't need sorting - only needs the prefix to be discarded. - if (sequence->count == 1) { - sz_u32_t *order_half_words = (sz_u32_t *)sequence->order; - order_half_words[1] = 0; - return; - } - - // Partition a range of integers according to a specific bit value - sz_size_t split = 0; - sz_u64_t mask = (1ull << 63) >> bit_idx; - - // The clean approach would be to perform a single pass over the sequence. - // - // while (split != sequence->count && !(sequence->order[split] & mask)) ++split; - // for (sz_size_t i = split + 1; i < sequence->count; ++i) - // if (!(sequence->order[i] & mask)) sz_u64_swap(sequence->order + i, sequence->order + split), ++split; - // - // This, however, doesn't take into account the high relative cost of writes and swaps. - // To circumvent that, we can first count the total number of entries to be mapped into either part, - // and then walk through both parts, swapping the entries that are in the wrong part. - // This would often lead to a ~15% performance gain. - sz_size_t count_with_bit_set = 0; - for (sz_size_t i = 0; i != sequence->count; ++i) count_with_bit_set += (sequence->order[i] & mask) != 0; - split = sequence->count - count_with_bit_set; - - // It's possible that the sequence is already partitioned. - if (split != 0 && split != sequence->count) { - // Use two pointers to efficiently reposition elements. - // One pointer walks left-to-right from the start, and the other walks right-to-left from the end. - sz_size_t left = 0; - sz_size_t right = sequence->count - 1; - while (1) { - // Find the next element with the bit set on the left side. - while (left < split && !(sequence->order[left] & mask)) ++left; - // Find the next element without the bit set on the right side. - while (right >= split && (sequence->order[right] & mask)) --right; - // Swap the mispositioned elements. - if (left < split && right >= split) { - sz_u64_swap(sequence->order + left, sequence->order + right); - ++left; - --right; - } - else { break; } - } - } - - // Go down recursively. - if (bit_idx < bit_max) { - sz_sequence_t a = *sequence; - a.count = split; - sz_sort_recursion(&a, bit_idx + 1, bit_max, comparator, partial_order_length); - - sz_sequence_t b = *sequence; - b.order += split; - b.count -= split; - sz_sort_recursion(&b, bit_idx + 1, bit_max, comparator, partial_order_length); - } - // Reached the end of recursion. - else { - // Discard the prefixes.
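// To make sense of the zeroing below: at this point every 64-bit `order` entry packs up to four
// leading characters of its string into the high 32 bits (exported in `sz_sort_partial`), while the
// low 32 bits still hold the original index into the sequence. A hypothetical little-endian entry
// for the string "banana" sitting at index 7 would look like:
//
//     0x62616E6100000007   // 'b', 'a', 'n', 'a' in the top bytes, index 7 in the bottom half.
//
// Once all 32 prefix bits are consumed by the radix partitioning, the prefixes are zeroed, so the
// comparator-based introsort below sees plain indices again.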
- sz_u32_t *order_half_words = (sz_u32_t *)sequence->order; - for (sz_size_t i = 0; i != sequence->count; ++i) { order_half_words[i * 2 + 1] = 0; } - - sz_sequence_t a = *sequence; - a.count = split; - sz_sort_introsort(&a, comparator); - - sz_sequence_t b = *sequence; - b.order += split; - b.count -= split; - sz_sort_introsort(&b, comparator); - } -} - -SZ_INTERNAL sz_bool_t _sz_sort_is_less(sz_sequence_t *sequence, sz_size_t i_key, sz_size_t j_key) { - sz_cptr_t i_str = sequence->get_start(sequence, i_key); - sz_cptr_t j_str = sequence->get_start(sequence, j_key); - sz_size_t i_len = sequence->get_length(sequence, i_key); - sz_size_t j_len = sequence->get_length(sequence, j_key); - return (sz_bool_t)(sz_order_serial(i_str, i_len, j_str, j_len) == sz_less_k); -} - -SZ_PUBLIC void sz_sort_partial(sz_sequence_t *sequence, sz_size_t partial_order_length) { - -#if SZ_DETECT_BIG_ENDIAN - // TODO: Implement partial sort for big-endian systems. For now this sorts the whole thing. - sz_unused(partial_order_length); - sz_sort_introsort(sequence, (sz_sequence_comparator_t)_sz_sort_is_less); -#else - - // Export up to 4 bytes into the `sequence` bits themselves - for (sz_size_t i = 0; i != sequence->count; ++i) { - sz_cptr_t begin = sequence->get_start(sequence, sequence->order[i]); - sz_size_t length = sequence->get_length(sequence, sequence->order[i]); - length = length > 4u ? 4u : length; - sz_ptr_t prefix = (sz_ptr_t)&sequence->order[i]; - for (sz_size_t j = 0; j != length; ++j) prefix[7 - j] = begin[j]; - } - - // Perform optionally-parallel radix sort on them - sz_sort_recursion(sequence, 0, 32, (sz_sequence_comparator_t)_sz_sort_is_less, partial_order_length); -#endif -} - -SZ_PUBLIC void sz_sort(sz_sequence_t *sequence) { -#if SZ_DETECT_BIG_ENDIAN - sz_sort_introsort(sequence, (sz_sequence_comparator_t)_sz_sort_is_less); -#else - sz_sort_partial(sequence, sequence->count); -#endif -} - -#pragma endregion - -/* - * @brief AVX2 implementation of the string search algorithms. - * Very minimalistic, but still faster than the serial implementation. - */ -#pragma region AVX2 Implementation - -#if SZ_USE_X86_AVX2 -#pragma GCC push_options -#pragma GCC target("avx2") -#pragma clang attribute push(__attribute__((target("avx2"))), apply_to = function) -#include - -/** - * @brief Helper structure to simplify work with 256-bit registers. - */ -typedef union sz_u256_vec_t { - __m256i ymm; - __m128i xmms[2]; - sz_u64_t u64s[4]; - sz_u32_t u32s[8]; - sz_u16_t u16s[16]; - sz_u8_t u8s[32]; -} sz_u256_vec_t; - -SZ_PUBLIC sz_ordering_t sz_order_avx2(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { - //! Before optimizing this, read the "Operations Not Worth Optimizing" in Contributions Guide: - //! https://github.com/ashvardanian/StringZilla/blob/main/CONTRIBUTING.md#general-performance-observations - return sz_order_serial(a, a_length, b, b_length); -} - -SZ_PUBLIC sz_bool_t sz_equal_avx2(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - sz_u256_vec_t a_vec, b_vec; - - while (length >= 32) { - a_vec.ymm = _mm256_lddqu_si256((__m256i const *)a); - b_vec.ymm = _mm256_lddqu_si256((__m256i const *)b); - // One approach can be to use "movemasks", but we could also use a bitwise matching like `_mm256_testnzc_si256`. 
- int difference_mask = ~_mm256_movemask_epi8(_mm256_cmpeq_epi8(a_vec.ymm, b_vec.ymm)); - if (difference_mask == 0) { a += 32, b += 32, length -= 32; } - else { return sz_false_k; } - } - - if (length) return sz_equal_serial(a, b, length); - return sz_true_k; -} - -SZ_PUBLIC void sz_fill_avx2(sz_ptr_t target, sz_size_t length, sz_u8_t value) { - char value_char = *(char *)&value; - __m256i value_vec = _mm256_set1_epi8(value_char); - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "stores". - // - // for (; length >= 32; target += 32, length -= 32) _mm256_storeu_si256(target, value_vec); - // sz_fill_serial(target, length, value); - // - // When the buffer is small, there isn't much to innovate. - if (length <= 32) sz_fill_serial(target, length, value); - // When the buffer is aligned, we can avoid any split-stores. - else { - sz_size_t head_length = (32 - ((sz_size_t)target % 32)) % 32; // 31 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 32; // 31 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 32. - sz_u16_t value16 = (sz_u16_t)value * 0x0101u; - sz_u32_t value32 = (sz_u32_t)value16 * 0x00010001u; - sz_u64_t value64 = (sz_u64_t)value32 * 0x0000000100000001ull; - - // Fill the head of the buffer. This part is much cleaner with AVX-512. - if (head_length & 1) *(sz_u8_t *)target = value, target++, head_length--; - if (head_length & 2) *(sz_u16_t *)target = value16, target += 2, head_length -= 2; - if (head_length & 4) *(sz_u32_t *)target = value32, target += 4, head_length -= 4; - if (head_length & 8) *(sz_u64_t *)target = value64, target += 8, head_length -= 8; - if (head_length & 16) - _mm_store_si128((__m128i *)target, _mm_set1_epi8(value_char)), target += 16, head_length -= 16; - sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); - - // Fill the aligned body of the buffer. - for (; body_length >= 32; target += 32, body_length -= 32) _mm256_store_si256((__m256i *)target, value_vec); - - // Fill the tail of the buffer. This part is much cleaner with AVX-512. - sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); - if (tail_length & 16) - _mm_store_si128((__m128i *)target, _mm_set1_epi8(value_char)), target += 16, tail_length -= 16; - if (tail_length & 8) *(sz_u64_t *)target = value64, target += 8, tail_length -= 8; - if (tail_length & 4) *(sz_u32_t *)target = value32, target += 4, tail_length -= 4; - if (tail_length & 2) *(sz_u16_t *)target = value16, target += 2, tail_length -= 2; - if (tail_length & 1) *(sz_u8_t *)target = value, target++, tail_length--; - } -} - -SZ_PUBLIC void sz_copy_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "stores" and "loads". - // - // for (; length >= 32; target += 32, source += 32, length -= 32) - // _mm256_storeu_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source)); - // sz_copy_serial(target, source, length); - // - // A typical AWS Skylake instance can have 32 KB x 2 blocks of L1 data cache per core, - // 1 MB x 2 blocks of L2 cache per core, and one shared L3 cache buffer. - // For now, let's avoid the cases beyond the L2 size. 
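// A note on the head/tail peeling used in `sz_fill_avx2` above and in `sz_copy_avx2` below: the
// misaligned edges are consumed with progressively larger power-of-two stores. A short trace for a
// hypothetical `target` with `target % 32 == 13`:
//
//     head_length = (32 - 13) % 32 = 19 = 16 + 2 + 1
//     1-byte store  -> target % 32 == 14
//     2-byte store  -> target % 32 == 16
//     16-byte store -> target % 32 == 0, and the aligned 32-byte stores can begin.
//
// Peeling the smallest chunks first is what lets the 16-byte `_mm_store_si128` land on a
// 16-byte-aligned address.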
- int is_huge = length > 1ull * 1024ull * 1024ull; - if (length <= 32) { sz_copy_serial(target, source, length); } - // When dealing with larger arrays, the optimization is not as simple as with the `sz_fill_avx2` function, - // as both buffers may be unaligned. If we are lucky and the requested operation is some huge page transfer, - // we can use aligned loads and stores, and the performance will be great. - else if ((sz_size_t)target % 32 == 0 && (sz_size_t)source % 32 == 0 && !is_huge) { - for (; length >= 32; target += 32, source += 32, length -= 32) - _mm256_store_si256((__m256i *)target, _mm256_load_si256((__m256i const *)source)); - if (length) sz_copy_serial(target, source, length); - } - // The trickiest case is when both `source` and `target` are not aligned. - // In this and simpler cases we can copy enough bytes into `target` to reach its cacheline boundary, - // and then combine unaligned loads with aligned stores. - else { - sz_size_t head_length = (32 - ((sz_size_t)target % 32)) % 32; // 31 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 32; // 31 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 32. - - // Fill the head of the buffer. This part is much cleaner with AVX-512. - if (head_length & 1) *(sz_u8_t *)target = *(sz_u8_t *)source, target++, source++, head_length--; - if (head_length & 2) *(sz_u16_t *)target = *(sz_u16_t *)source, target += 2, source += 2, head_length -= 2; - if (head_length & 4) *(sz_u32_t *)target = *(sz_u32_t *)source, target += 4, source += 4, head_length -= 4; - if (head_length & 8) *(sz_u64_t *)target = *(sz_u64_t *)source, target += 8, source += 8, head_length -= 8; - if (head_length & 16) - _mm_store_si128((__m128i *)target, _mm_lddqu_si128((__m128i const *)source)), target += 16, source += 16, - head_length -= 16; - sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); - - // Fill the aligned body of the buffer. - if (!is_huge) { - for (; body_length >= 32; target += 32, source += 32, body_length -= 32) - _mm256_store_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source)); - } - // When the buffer is huge, we can traverse it in 2 directions. - else { - for (; body_length >= 64; target += 32, source += 32, body_length -= 64) { - _mm256_store_si256((__m256i *)(target), _mm256_lddqu_si256((__m256i const *)(source))); - _mm256_store_si256((__m256i *)(target + body_length - 32), - _mm256_lddqu_si256((__m256i const *)(source + body_length - 32))); - } - if (body_length) _mm256_store_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source)); - } - - // Fill the tail of the buffer. This part is much cleaner with AVX-512.
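// In the huge-buffer branch above, the front cursor advances by 32 bytes per iteration while
// `body_length` shrinks by 64, so the mirrored store walks backwards and the two meet in the middle.
// A trace for a 128-byte body (an illustrative size, not a requirement):
//
//     iteration 1: copies bytes [0, 32)  and [96, 128)
//     iteration 2: copies bytes [32, 64) and [64, 96)
//
// and the trailing `if (body_length)` covers the single 32-byte chunk left over whenever the body is
// an odd multiple of 32.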
- sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); - if (tail_length & 16) - _mm_store_si128((__m128i *)target, _mm_lddqu_si128((__m128i const *)source)), target += 16, source += 16, - tail_length -= 16; - if (tail_length & 8) *(sz_u64_t *)target = *(sz_u64_t *)source, target += 8, source += 8, tail_length -= 8; - if (tail_length & 4) *(sz_u32_t *)target = *(sz_u32_t *)source, target += 4, source += 4, tail_length -= 4; - if (tail_length & 2) *(sz_u16_t *)target = *(sz_u16_t *)source, target += 2, source += 2, tail_length -= 2; - if (tail_length & 1) *(sz_u8_t *)target = *(sz_u8_t *)source, target++, source++, tail_length--; - } -} - -SZ_PUBLIC void sz_move_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - if (target < source || target >= source + length) { - for (; length >= 32; target += 32, source += 32, length -= 32) - _mm256_storeu_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source)); - while (length--) *(target++) = *(source++); - } - else { - // Jump to the end and walk backwards. - for (target += length, source += length; length >= 32; length -= 32) - _mm256_storeu_si256((__m256i *)(target -= 32), _mm256_lddqu_si256((__m256i const *)(source -= 32))); - while (length--) *(--target) = *(--source); - } -} - -SZ_PUBLIC sz_u64_t sz_checksum_avx2(sz_cptr_t text, sz_size_t length) { - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "loads". - // - // A typical AWS Skylake instance can have 32 KB x 2 blocks of L1 data cache per core, - // 1 MB x 2 blocks of L2 cache per core, and one shared L3 cache buffer. - // For now, let's avoid the cases beyond the L2 size. - int is_huge = length > 1ull * 1024ull * 1024ull; - - // When the buffer is small, there isn't much to innovate. - if (length <= 32) { return sz_checksum_serial(text, length); } - else if (!is_huge) { - sz_u256_vec_t text_vec, sums_vec; - sums_vec.ymm = _mm256_setzero_si256(); - for (; length >= 32; text += 32, length -= 32) { - text_vec.ymm = _mm256_lddqu_si256((__m256i const *)text); - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); - } - // Accumulating 256 bits is harders, as we need to extract the 128-bit sums first. - __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymm); - __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymm, 1); - __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm); - sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm); - sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1); - sz_u64_t result = low + high; - if (length) result += sz_checksum_serial(text, length); - return result; - } - // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use. - // Most notably, we can avoid populating the cache with the entire buffer, and instead traverse it in 2 directions. - else { - sz_size_t head_length = (32 - ((sz_size_t)text % 32)) % 32; // 31 or less. - sz_size_t tail_length = (sz_size_t)(text + length) % 32; // 31 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 32. - sz_u64_t result = 0; - - // Handle the head - while (head_length--) result += *text++; - - sz_u256_vec_t text_vec, sums_vec; - sums_vec.ymm = _mm256_setzero_si256(); - // Fill the aligned body of the buffer. 
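// A note on the reduction used in this function: `_mm256_sad_epu8(x, zero)` sums every group of
// eight unsigned bytes into the low 16 bits of the corresponding 64-bit lane, so accumulating it
// with `_mm256_add_epi64` maintains four independent 64-bit byte-sum counters that will not
// overflow for any realistic input. A scalar equivalent of one such step, for illustration only:
//
//     for (int lane = 0; lane != 4; ++lane) {
//         sz_u64_t partial = 0;
//         for (int byte = 0; byte != 8; ++byte) partial += text[lane * 8 + byte];
//         sums[lane] += partial;
//     }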
- if (!is_huge) { - for (; body_length >= 32; text += 32, body_length -= 32) { - text_vec.ymm = _mm256_stream_load_si256((__m256i const *)text); - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); - } - } - // When the buffer is huge, we can traverse it in 2 directions. - else { - sz_u256_vec_t text_reversed_vec, sums_reversed_vec; - sums_reversed_vec.ymm = _mm256_setzero_si256(); - for (; body_length >= 64; text += 64, body_length -= 64) { - text_vec.ymm = _mm256_stream_load_si256((__m256i *)(text)); - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); - text_reversed_vec.ymm = _mm256_stream_load_si256((__m256i *)(text + body_length - 64)); - sums_reversed_vec.ymm = _mm256_add_epi64( - sums_reversed_vec.ymm, _mm256_sad_epu8(text_reversed_vec.ymm, _mm256_setzero_si256())); - } - if (body_length >= 32) { - text_vec.ymm = _mm256_stream_load_si256((__m256i *)(text)); - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); - } - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, sums_reversed_vec.ymm); - } - - // Handle the tail - while (tail_length--) result += *text++; - - // Accumulating 256 bits is harder, as we need to extract the 128-bit sums first. - __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymm); - __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymm, 1); - __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm); - sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm); - sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1); - result += low + high; - return result; - } -} - -SZ_PUBLIC void sz_look_up_transform_avx2(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { - - // If the input is tiny (especially smaller than the look-up table itself), we may end up paying - // more for organizing the SIMD registers and changing the CPU state than for the actual computation. - // But if at least 3 cache lines are touched, the AVX-2 implementation should be faster. - if (length <= 128) { - sz_look_up_transform_serial(source, length, lut, target); - return; - } - - // We need to pull the lookup table into 16x YMM registers. - // The biggest issue is reorganizing the data in the lookup table, as AVX2 doesn't have 256-bit shuffle, - // it only has 128-bit "within-lane" shuffle. Still, it's wiser to use full YMM registers, instead of XMM, - // so that we can at least compensate for the high latency with a twice larger window and one more level of lookup.
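// In scalar terms, the 256-entry table is treated as 16 rows of 16 entries: the low nibble of every
// input byte picks the column (that is what `_mm256_shuffle_epi8` does against each row), and the
// high nibble picks the row (that is what the blend tree below does, one bit at a time):
//
//     sz_u8_t input = ((sz_u8_t const *)source)[i];
//     sz_u8_t row = input >> 4, column = input & 0x0F;
//     ((sz_u8_t *)target)[i] = ((sz_u8_t const *)lut)[row * 16 + column];
//
// The register names below simply spell out which 16-entry row each YMM register holds.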
- sz_u256_vec_t lut_0_to_15_vec, lut_16_to_31_vec, lut_32_to_47_vec, lut_48_to_63_vec, // - lut_64_to_79_vec, lut_80_to_95_vec, lut_96_to_111_vec, lut_112_to_127_vec, // - lut_128_to_143_vec, lut_144_to_159_vec, lut_160_to_175_vec, lut_176_to_191_vec, // - lut_192_to_207_vec, lut_208_to_223_vec, lut_224_to_239_vec, lut_240_to_255_vec; - - lut_0_to_15_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut))); - lut_16_to_31_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 16))); - lut_32_to_47_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 32))); - lut_48_to_63_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 48))); - lut_64_to_79_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 64))); - lut_80_to_95_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 80))); - lut_96_to_111_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 96))); - lut_112_to_127_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 112))); - lut_128_to_143_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 128))); - lut_144_to_159_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 144))); - lut_160_to_175_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 160))); - lut_176_to_191_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 176))); - lut_192_to_207_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 192))); - lut_208_to_223_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 208))); - lut_224_to_239_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 224))); - lut_240_to_255_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 240))); - - // Assuming each lookup is performed within 16 elements of 256, we need to reduce the scope by 16x = 2^4. - sz_u256_vec_t not_first_bit_vec, not_second_bit_vec, not_third_bit_vec, not_fourth_bit_vec; - - /// Top and bottom nibbles of the source are used separately. - sz_u256_vec_t source_vec, source_bot_vec; - sz_u256_vec_t blended_0_to_31_vec, blended_32_to_63_vec, blended_64_to_95_vec, blended_96_to_127_vec, - blended_128_to_159_vec, blended_160_to_191_vec, blended_192_to_223_vec, blended_224_to_255_vec; - - // Handling the head. - while (length >= 32) { - // Load and separate the nibbles of each byte in the source. - source_vec.ymm = _mm256_lddqu_si256((__m256i const *)source); - source_bot_vec.ymm = _mm256_and_si256(source_vec.ymm, _mm256_set1_epi8((char)0x0F)); - - // In the first round, we select using the 4th bit. 
- not_fourth_bit_vec.ymm = _mm256_cmpeq_epi8( // - _mm256_and_si256(_mm256_set1_epi8((char)0x10), source_vec.ymm), _mm256_setzero_si256()); - blended_0_to_31_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_16_to_31_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_0_to_15_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_32_to_63_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_48_to_63_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_32_to_47_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_64_to_95_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_80_to_95_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_64_to_79_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_96_to_127_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_112_to_127_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_96_to_111_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_128_to_159_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_144_to_159_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_128_to_143_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_160_to_191_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_176_to_191_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_160_to_175_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_192_to_223_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_208_to_223_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_192_to_207_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_224_to_255_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_240_to_255_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_224_to_239_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - - // Perform a tree-like reduction of the 8x "blended" YMM registers, depending on the "source" content. - // The first round selects using the 3rd bit. - not_third_bit_vec.ymm = _mm256_cmpeq_epi8( // - _mm256_and_si256(_mm256_set1_epi8((char)0x20), source_vec.ymm), _mm256_setzero_si256()); - blended_0_to_31_vec.ymm = _mm256_blendv_epi8( // - blended_32_to_63_vec.ymm, // - blended_0_to_31_vec.ymm, // - not_third_bit_vec.ymm); - blended_64_to_95_vec.ymm = _mm256_blendv_epi8( // - blended_96_to_127_vec.ymm, // - blended_64_to_95_vec.ymm, // - not_third_bit_vec.ymm); - blended_128_to_159_vec.ymm = _mm256_blendv_epi8( // - blended_160_to_191_vec.ymm, // - blended_128_to_159_vec.ymm, // - not_third_bit_vec.ymm); - blended_192_to_223_vec.ymm = _mm256_blendv_epi8( // - blended_224_to_255_vec.ymm, // - blended_192_to_223_vec.ymm, // - not_third_bit_vec.ymm); - - // The second round selects using the 2nd bit. - not_second_bit_vec.ymm = _mm256_cmpeq_epi8( // - _mm256_and_si256(_mm256_set1_epi8((char)0x40), source_vec.ymm), _mm256_setzero_si256()); - blended_0_to_31_vec.ymm = _mm256_blendv_epi8( // - blended_64_to_95_vec.ymm, // - blended_0_to_31_vec.ymm, // - not_second_bit_vec.ymm); - blended_128_to_159_vec.ymm = _mm256_blendv_epi8( // - blended_192_to_223_vec.ymm, // - blended_128_to_159_vec.ymm, // - not_second_bit_vec.ymm); - - // The third round selects using the 1st bit. 
- not_first_bit_vec.ymm = _mm256_cmpeq_epi8( // - _mm256_and_si256(_mm256_set1_epi8((char)0x80), source_vec.ymm), _mm256_setzero_si256()); - blended_0_to_31_vec.ymm = _mm256_blendv_epi8( // - blended_128_to_159_vec.ymm, // - blended_0_to_31_vec.ymm, // - not_first_bit_vec.ymm); - - // And dump the result into the target. - _mm256_storeu_si256((__m256i *)target, blended_0_to_31_vec.ymm); - source += 32, target += 32, length -= 32; - } - - // Handle the tail. - if (length) sz_look_up_transform_serial(source, length, lut, target); -} - -SZ_PUBLIC sz_cptr_t sz_find_byte_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - int mask; - sz_u256_vec_t h_vec, n_vec; - n_vec.ymm = _mm256_set1_epi8(n[0]); - - while (h_length >= 32) { - h_vec.ymm = _mm256_lddqu_si256((__m256i const *)h); - mask = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_vec.ymm, n_vec.ymm)); - if (mask) return h + sz_u32_ctz(mask); - h += 32, h_length -= 32; - } - - return sz_find_byte_serial(h, h_length, n); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - int mask; - sz_u256_vec_t h_vec, n_vec; - n_vec.ymm = _mm256_set1_epi8(n[0]); - - while (h_length >= 32) { - h_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + h_length - 32)); - mask = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_vec.ymm, n_vec.ymm)); - if (mask) return h + h_length - 1 - sz_u32_clz(mask); - h_length -= 32; - } - - return sz_rfind_byte_serial(h, h_length, n); -} - -SZ_PUBLIC sz_cptr_t sz_find_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_find_byte_avx2(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into YMM registers. - int matches; - sz_u256_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; - n_first_vec.ymm = _mm256_set1_epi8(n[offset_first]); - n_mid_vec.ymm = _mm256_set1_epi8(n[offset_mid]); - n_last_vec.ymm = _mm256_set1_epi8(n[offset_last]); - - // Scan through the string. - for (; h_length >= n_length + 32; h += 32, h_length -= 32) { - h_first_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_first)); - h_mid_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_mid)); - h_last_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_last)); - matches = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_first_vec.ymm, n_first_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_mid_vec.ymm, n_mid_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_last_vec.ymm, n_last_vec.ymm)); - while (matches) { - int potential_offset = sz_u32_ctz(matches); - if (sz_equal(h + potential_offset, n, n_length)) return h + potential_offset; - matches &= matches - 1; - } - } - - return sz_find_serial(h, h_length, n, n_length); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_rfind_byte_avx2(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. 
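// The same filtering idea as in `sz_find_avx2` above: three bytes of the needle, at the offsets
// chosen by `_sz_locate_needle_anomalies` (presumably its first, middle, and last bytes, or similarly
// distinctive positions), are broadcast and compared against 32 haystack positions at once, and only
// the positions where all three masks agree are verified with a full `sz_equal`. For a hypothetical
// needle "needle", if the chosen offsets were 0, 3, and 5, the broadcast characters would be 'n',
// 'd', and 'e', so a candidate position survives the cheap test only when those three line up.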
- sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into YMM registers. - int matches; - sz_u256_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; - n_first_vec.ymm = _mm256_set1_epi8(n[offset_first]); - n_mid_vec.ymm = _mm256_set1_epi8(n[offset_mid]); - n_last_vec.ymm = _mm256_set1_epi8(n[offset_last]); - - // Scan through the string. - sz_cptr_t h_reversed; - for (; h_length >= n_length + 32; h_length -= 32) { - h_reversed = h + h_length - n_length - 32 + 1; - h_first_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_first)); - h_mid_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_mid)); - h_last_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_last)); - matches = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_first_vec.ymm, n_first_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_mid_vec.ymm, n_mid_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_last_vec.ymm, n_last_vec.ymm)); - while (matches) { - int potential_offset = sz_u32_clz(matches); - if (sz_equal(h + h_length - n_length - potential_offset, n, n_length)) - return h + h_length - n_length - potential_offset; - matches &= ~(1 << (31 - potential_offset)); - } - } - - return sz_rfind_serial(h, h_length, n, n_length); -} - -SZ_PUBLIC sz_cptr_t sz_find_charset_avx2(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { - - // Let's unzip even and odd elements and replicate them into both lanes of the YMM register. - // That way when we invoke `_mm256_shuffle_epi8` we can use the same mask for both lanes. - sz_u256_vec_t filter_even_vec, filter_odd_vec; - for (sz_size_t i = 0; i != 16; ++i) - filter_even_vec.u8s[i] = filter->_u8s[i * 2], filter_odd_vec.u8s[i] = filter->_u8s[i * 2 + 1]; - filter_even_vec.xmms[1] = filter_even_vec.xmms[0]; - filter_odd_vec.xmms[1] = filter_odd_vec.xmms[0]; - - sz_u256_vec_t text_vec; - sz_u256_vec_t matches_vec; - sz_u256_vec_t lower_nibbles_vec, higher_nibbles_vec; - sz_u256_vec_t bitset_even_vec, bitset_odd_vec; - sz_u256_vec_t bitmask_vec, bitmask_lookup_vec; - bitmask_lookup_vec.ymm = _mm256_set_epi8(-128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // - -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1); - - while (length >= 32) { - // The following algorithm is a transposed equivalent of the "SIMDized check which bytes are in a set" - // solutions by Wojciech Muła. We populate the bitmask differently and target newer CPUs, so - // StringZilla uses a somewhat different approach. - // http://0x80.pl/articles/simd-byte-lookup.html#alternative-implementation-new - // - // sz_u8_t input = *(sz_u8_t const *)text; - // sz_u8_t lo_nibble = input & 0x0f; - // sz_u8_t hi_nibble = input >> 4; - // sz_u8_t bitset_even = filter_even_vec.u8s[hi_nibble]; - // sz_u8_t bitset_odd = filter_odd_vec.u8s[hi_nibble]; - // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); - // sz_u8_t bitset = lo_nibble < 8 ? bitset_even : bitset_odd; - // if ((bitset & bitmask) != 0) return text; - // else { length--, text++; } - // - // The nice part about this is that loading the strided data is very easy with Arm NEON, - // while on x86 CPUs after AVX, shuffles within 256 bits shouldn't be an issue either.
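- // For reference, the same membership test can be restated as a standalone scalar helper. The
- // name `_sz_charset_contains_scalar` is hypothetical, and the sketch assumes the `sz_charset_t`
- // layout implied above, where byte `c` maps to bit `c % 8` of `filter->_u8s[c / 8]`, so that
- // `c / 8 == hi_nibble * 2 + (lo_nibble >= 8)` and `c % 8 == lo_nibble & 0x7`:
- //
- // SZ_INTERNAL int _sz_charset_contains_scalar(sz_charset_t const *filter, sz_u8_t c) {
- //     sz_u8_t lo_nibble = (sz_u8_t)(c & 0x0f), hi_nibble = (sz_u8_t)(c >> 4);
- //     // The even/odd split above is what lets a 16-entry shuffle index this table by the high nibble alone.
- //     sz_u8_t bitset = lo_nibble < 8 ? filter->_u8s[hi_nibble * 2] : filter->_u8s[hi_nibble * 2 + 1];
- //     return (bitset & (sz_u8_t)(1 << (lo_nibble & 0x7))) != 0;
- // }
- //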
- text_vec.ymm = _mm256_lddqu_si256((__m256i const *)text); - lower_nibbles_vec.ymm = _mm256_and_si256(text_vec.ymm, _mm256_set1_epi8(0x0f)); - bitmask_vec.ymm = _mm256_shuffle_epi8(bitmask_lookup_vec.ymm, lower_nibbles_vec.ymm); - // - // At this point we can validate the `bitmask_vec` contents like this: - // - // for (sz_size_t i = 0; i != 32; ++i) { - // sz_u8_t input = *(sz_u8_t const *)(text + i); - // sz_u8_t lo_nibble = input & 0x0f; - // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); - // sz_assert(bitmask_vec.u8s[i] == bitmask); - // } - // - // Shift right every byte by 4 bits. - // There is no `_mm256_srli_epi8` intrinsic, so we have to use `_mm256_srli_epi16` - // and combine it with a mask to clear the higher bits. - higher_nibbles_vec.ymm = _mm256_and_si256(_mm256_srli_epi16(text_vec.ymm, 4), _mm256_set1_epi8(0x0f)); - bitset_even_vec.ymm = _mm256_shuffle_epi8(filter_even_vec.ymm, higher_nibbles_vec.ymm); - bitset_odd_vec.ymm = _mm256_shuffle_epi8(filter_odd_vec.ymm, higher_nibbles_vec.ymm); - // - // At this point we can validate the `bitset_even_vec` and `bitset_odd_vec` contents like this: - // - // for (sz_size_t i = 0; i != 32; ++i) { - // sz_u8_t input = *(sz_u8_t const *)(text + i); - // sz_u8_t const *bitset_ptr = &filter->_u8s[0]; - // sz_u8_t hi_nibble = input >> 4; - // sz_u8_t bitset_even = bitset_ptr[hi_nibble * 2]; - // sz_u8_t bitset_odd = bitset_ptr[hi_nibble * 2 + 1]; - // sz_assert(bitset_even_vec.u8s[i] == bitset_even); - // sz_assert(bitset_odd_vec.u8s[i] == bitset_odd); - // } - // - __m256i take_first = _mm256_cmpgt_epi8(_mm256_set1_epi8(8), lower_nibbles_vec.ymm); - bitset_even_vec.ymm = _mm256_blendv_epi8(bitset_odd_vec.ymm, bitset_even_vec.ymm, take_first); - - // It would have been great to have an instruction that tests the bits and then broadcasts - // the matching bit into all bits in that byte. But we don't have that, so we have to - // `and`, `cmpeq`, `movemask`, and then invert at the end... - matches_vec.ymm = _mm256_and_si256(bitset_even_vec.ymm, bitmask_vec.ymm); - matches_vec.ymm = _mm256_cmpeq_epi8(matches_vec.ymm, _mm256_setzero_si256()); - int matches_mask = ~_mm256_movemask_epi8(matches_vec.ymm); - if (matches_mask) { - int offset = sz_u32_ctz(matches_mask); - return text + offset; - } - else { text += 32, length -= 32; } - } - - return sz_find_charset_serial(text, length, filter); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx2(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { - return sz_rfind_charset_serial(text, length, filter); -} - -/** - * @brief There is no AVX2 instruction for fast multiplication of 64-bit integers. - * This implementation is coming from Agner Fog's Vector Class Library. 
- */ -SZ_INTERNAL __m256i _mm256_mul_epu64(__m256i a, __m256i b) { - __m256i bswap = _mm256_shuffle_epi32(b, 0xB1); - __m256i prodlh = _mm256_mullo_epi32(a, bswap); - __m256i zero = _mm256_setzero_si256(); - __m256i prodlh2 = _mm256_hadd_epi32(prodlh, zero); - __m256i prodlh3 = _mm256_shuffle_epi32(prodlh2, 0x73); - __m256i prodll = _mm256_mul_epu32(a, b); - __m256i prod = _mm256_add_epi64(prodll, prodlh3); - return prod; -} - -SZ_PUBLIC void sz_hashes_avx2(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle) { - - if (length < window_length || !window_length) return; - if (length < 4 * window_length) { - sz_hashes_serial(start, length, window_length, step, callback, callback_handle); - return; - } - - // Using AVX2, we can perform 4 long integer multiplications and additions within one register. - // So let's slice the entire string into 4 overlapping windows, to slide over them in parallel. - sz_size_t const max_hashes = length - window_length + 1; - sz_size_t const min_hashes_per_thread = max_hashes / 4; // At most one sequence can overlap between 2 threads. - sz_u8_t const *text_first = (sz_u8_t const *)start; - sz_u8_t const *text_second = text_first + min_hashes_per_thread; - sz_u8_t const *text_third = text_first + min_hashes_per_thread * 2; - sz_u8_t const *text_fourth = text_first + min_hashes_per_thread * 3; - sz_u8_t const *text_end = text_first + length; - - // Prepare the `prime ^ window_length` values, that we are going to use for modulo arithmetic. - sz_u64_t prime_power_low = 1, prime_power_high = 1; - for (sz_size_t i = 0; i + 1 < window_length; ++i) - prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME, - prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME; - - // Broadcast the constants into the registers. - sz_u256_vec_t prime_vec, golden_ratio_vec; - sz_u256_vec_t base_low_vec, base_high_vec, prime_power_low_vec, prime_power_high_vec, shift_high_vec; - base_low_vec.ymm = _mm256_set1_epi64x(31ull); - base_high_vec.ymm = _mm256_set1_epi64x(257ull); - shift_high_vec.ymm = _mm256_set1_epi64x(77ull); - prime_vec.ymm = _mm256_set1_epi64x(SZ_U64_MAX_PRIME); - golden_ratio_vec.ymm = _mm256_set1_epi64x(11400714819323198485ull); - prime_power_low_vec.ymm = _mm256_set1_epi64x(prime_power_low); - prime_power_high_vec.ymm = _mm256_set1_epi64x(prime_power_high); - - // Compute the initial hash values for every one of the four windows. - sz_u256_vec_t hash_low_vec, hash_high_vec, hash_mix_vec, chars_low_vec, chars_high_vec; - hash_low_vec.ymm = _mm256_setzero_si256(); - hash_high_vec.ymm = _mm256_setzero_si256(); - for (sz_u8_t const *prefix_end = text_first + window_length; text_first < prefix_end; - ++text_first, ++text_second, ++text_third, ++text_fourth) { - - // 1. Multiply the hashes by the base. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, base_low_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, base_high_vec.ymm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. - chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); - - // 3. Add the incoming characters. 
- hash_low_vec.ymm = _mm256_add_epi64(hash_low_vec.ymm, chars_low_vec.ymm); - hash_high_vec.ymm = _mm256_add_epi64(hash_high_vec.ymm, chars_high_vec.ymm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_low_vec.ymm = _mm256_blendv_epi8(hash_low_vec.ymm, _mm256_sub_epi64(hash_low_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_low_vec.ymm, prime_vec.ymm)); - hash_high_vec.ymm = _mm256_blendv_epi8(hash_high_vec.ymm, _mm256_sub_epi64(hash_high_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_high_vec.ymm, prime_vec.ymm)); - } - - // 5. Compute the hash mix, that will be used to index into the fingerprint. - // This includes a serial step at the end. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, golden_ratio_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, golden_ratio_vec.ymm); - hash_mix_vec.ymm = _mm256_xor_si256(hash_low_vec.ymm, hash_high_vec.ymm); - callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); - callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); - callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); - callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); - - // Now repeat that operation for the remaining characters, discarding older characters. - sz_size_t cycle = 1; - sz_size_t const step_mask = step - 1; - for (; text_fourth != text_end; ++text_first, ++text_second, ++text_third, ++text_fourth, ++cycle) { - // 0. Load again the four characters we are dropping, shift them, and subtract. - chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[-window_length], text_third[-window_length], - text_second[-window_length], text_first[-window_length]); - chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); - hash_low_vec.ymm = - _mm256_sub_epi64(hash_low_vec.ymm, _mm256_mul_epu64(chars_low_vec.ymm, prime_power_low_vec.ymm)); - hash_high_vec.ymm = - _mm256_sub_epi64(hash_high_vec.ymm, _mm256_mul_epu64(chars_high_vec.ymm, prime_power_high_vec.ymm)); - - // 1. Multiply the hashes by the base. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, base_low_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, base_high_vec.ymm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. - chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); - - // 3. Add the incoming characters. - hash_low_vec.ymm = _mm256_add_epi64(hash_low_vec.ymm, chars_low_vec.ymm); - hash_high_vec.ymm = _mm256_add_epi64(hash_high_vec.ymm, chars_high_vec.ymm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_low_vec.ymm = _mm256_blendv_epi8(hash_low_vec.ymm, _mm256_sub_epi64(hash_low_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_low_vec.ymm, prime_vec.ymm)); - hash_high_vec.ymm = _mm256_blendv_epi8(hash_high_vec.ymm, _mm256_sub_epi64(hash_high_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_high_vec.ymm, prime_vec.ymm)); - - // 5. 
Compute the hash mix, that will be used to index into the fingerprint. - // This includes a serial step at the end. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, golden_ratio_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, golden_ratio_vec.ymm); - hash_mix_vec.ymm = _mm256_xor_si256(hash_low_vec.ymm, hash_high_vec.ymm); - if ((cycle & step_mask) == 0) { - callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); - callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); - callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); - callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); - } - } -} - -#pragma clang attribute pop -#pragma GCC pop_options -#endif -#pragma endregion - -/* - * @brief AVX-512 implementation of the string search algorithms. - * - * Different subsets of AVX-512 were introduced in different years: - * - 2017 SkyLake: F, CD, ER, PF, VL, DQ, BW - * - 2018 CannonLake: IFMA, VBMI - * - 2019 IceLake: VPOPCNTDQ, VNNI, VBMI2, BITALG, GFNI, VPCLMULQDQ, VAES - * - 2020 TigerLake: VP2INTERSECT - */ -#pragma region AVX512 Implementation - -#if SZ_USE_X86_AVX512 -#pragma GCC push_options -#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "bmi", "bmi2") -#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,bmi,bmi2"))), apply_to = function) -#include - -/** - * @brief Helper structure to simplify work with 512-bit registers. - */ -typedef union sz_u512_vec_t { - __m512i zmm; - __m256i ymms[2]; - __m128i xmms[4]; - sz_u64_t u64s[8]; - sz_u32_t u32s[16]; - sz_u16_t u16s[32]; - sz_u8_t u8s[64]; - sz_i64_t i64s[8]; - sz_i32_t i32s[16]; -} sz_u512_vec_t; - -SZ_INTERNAL __mmask64 _sz_u64_clamp_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 64: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 64: - return _bzhi_u64(0xFFFFFFFFFFFFFFFF, n < 64 ? (sz_u32_t)n : 64); -} - -SZ_INTERNAL __mmask32 _sz_u32_clamp_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 32: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 32: - return _bzhi_u32(0xFFFFFFFF, n < 32 ? (sz_u32_t)n : 32); -} - -SZ_INTERNAL __mmask16 _sz_u16_clamp_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 16: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 16: - return _bzhi_u32(0xFFFFFFFF, n < 16 ? 
(sz_u32_t)n : 16); -} - -SZ_INTERNAL __mmask16 _sz_u16_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 16: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 16: - return (__mmask16)_bzhi_u32(0xFFFFFFFF, (sz_u32_t)n); -} - -SZ_INTERNAL __mmask32 _sz_u32_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 32: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 32: - return _bzhi_u32(0xFFFFFFFF, (sz_u32_t)n); -} - -SZ_INTERNAL __mmask64 _sz_u64_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 64: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 64: - return _bzhi_u64(0xFFFFFFFFFFFFFFFF, (sz_u32_t)n); -} - -SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { - sz_u512_vec_t a_vec, b_vec; - - // Pointer arithmetic is cheap, fetching memory is not! - // So we can use the masked loads to fetch at most one cache-line for each string, - // compare the prefixes, and only then move forward. - sz_size_t a_head_length = 64 - ((sz_size_t)a % 64); // 63 or less. - sz_size_t b_head_length = 64 - ((sz_size_t)b % 64); // 63 or less. - a_head_length = a_head_length < a_length ? a_head_length : a_length; - b_head_length = b_head_length < b_length ? b_head_length : b_length; - sz_size_t head_length = a_head_length < b_head_length ? a_head_length : b_head_length; - __mmask64 head_mask = _sz_u64_mask_until(head_length); - a_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, a); - b_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, b); - __mmask64 mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); - if (mask_not_equal != 0) { - sz_u64_t first_diff = _tzcnt_u64(mask_not_equal); - char a_char = a_vec.u8s[first_diff]; - char b_char = b_vec.u8s[first_diff]; - return _sz_order_scalars(a_char, b_char); - } - else if (head_length == a_length && head_length == b_length) { return sz_equal_k; } - else { a += head_length, b += head_length, a_length -= head_length, b_length -= head_length; } - - // The rare case, when both string are very long. - __mmask64 a_mask, b_mask; - while ((a_length >= 64) & (b_length >= 64)) { - a_vec.zmm = _mm512_loadu_si512(a); - b_vec.zmm = _mm512_loadu_si512(b); - mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); - if (mask_not_equal != 0) { - sz_u64_t first_diff = _tzcnt_u64(mask_not_equal); - char a_char = a_vec.u8s[first_diff]; - char b_char = b_vec.u8s[first_diff]; - return _sz_order_scalars(a_char, b_char); - } - a += 64, b += 64, a_length -= 64, b_length -= 64; - } - - // In most common scenarios at least one of the strings is under 64 bytes. - if (a_length | b_length) { - a_mask = _sz_u64_clamp_mask_until(a_length); - b_mask = _sz_u64_clamp_mask_until(b_length); - a_vec.zmm = _mm512_maskz_loadu_epi8(a_mask, a); - b_vec.zmm = _mm512_maskz_loadu_epi8(b_mask, b); - // The AVX-512 `_mm512_mask_cmpneq_epi8_mask` intrinsics are generally handy in such environments. - // They, however, have latency 3 on most modern CPUs. Using AVX2: `_mm256_cmpeq_epi8` would have - // been cheaper, if we didn't have to apply `_mm256_movemask_epi8` afterwards. 
- mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); - if (mask_not_equal != 0) { - sz_u64_t first_diff = _tzcnt_u64(mask_not_equal); - char a_char = a_vec.u8s[first_diff]; - char b_char = b_vec.u8s[first_diff]; - return _sz_order_scalars(a_char, b_char); - } - // From logic perspective, the hardest cases are "abc\0" and "abc". - // The result must be `sz_greater_k`, as the latter is shorter. - else { return _sz_order_scalars(a_length, b_length); } - } - - return sz_equal_k; -} - -SZ_PUBLIC sz_bool_t sz_equal_avx512(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - __mmask64 mask; - sz_u512_vec_t a_vec, b_vec; - - while (length >= 64) { - a_vec.zmm = _mm512_loadu_si512(a); - b_vec.zmm = _mm512_loadu_si512(b); - mask = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); - if (mask != 0) return sz_false_k; - a += 64, b += 64, length -= 64; - } - - if (length) { - mask = _sz_u64_mask_until(length); - a_vec.zmm = _mm512_maskz_loadu_epi8(mask, a); - b_vec.zmm = _mm512_maskz_loadu_epi8(mask, b); - // Reuse the same `mask` variable to find the bit that doesn't match - mask = _mm512_mask_cmpneq_epi8_mask(mask, a_vec.zmm, b_vec.zmm); - return (sz_bool_t)(mask == 0); - } - - return sz_true_k; -} - -SZ_PUBLIC void sz_fill_avx512(sz_ptr_t target, sz_size_t length, sz_u8_t value) { - __m512i value_vec = _mm512_set1_epi8(value); - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "stores". - // - // for (; length >= 64; target += 64, length -= 64) _mm512_storeu_si512(target, value_vec); - // _mm512_mask_storeu_epi8(target, _sz_u64_mask_until(length), value_vec); - // - // When the buffer is small, there isn't much to innovate. - if (length <= 64) { - __mmask64 mask = _sz_u64_mask_until(length); - _mm512_mask_storeu_epi8(target, mask, value_vec); - } - // When the buffer is over 64 bytes, it's guaranteed to touch at least two cache lines - the head and tail, - // and may include more cache-lines in-between. Knowing this, we can avoid expensive unaligned stores - // by computing 2 masks - for the head and tail, using masked stores for the head and tail, and unmasked - // for the body. - else { - sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 64. - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - _mm512_mask_storeu_epi8(target, head_mask, value_vec); - for (target += head_length; body_length >= 64; target += 64, body_length -= 64) - _mm512_store_si512(target, value_vec); - _mm512_mask_storeu_epi8(target, tail_mask, value_vec); - } -} - -SZ_PUBLIC void sz_copy_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "stores" and "loads". - // - // for (; length >= 64; target += 64, source += 64, length -= 64) - // _mm512_storeu_si512(target, _mm512_loadu_si512(source)); - // __mmask64 mask = _sz_u64_mask_until(length); - // _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source)); - // - // A typical AWS Sapphire Rapids instance can have 48 KB x 2 blocks of L1 data cache per core, - // 2 MB x 2 blocks of L2 cache per core, and one shared 60 MB buffer of L3 cache. 
- // With two strings, we may consider the overall workload huge, if each exceeds 1 MB in length. - int const is_huge = length >= 1ull * 1024ull * 1024ull; - - // When the buffer is small, there isn't much to innovate. - if (length <= 64) { - __mmask64 mask = _sz_u64_mask_until(length); - _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source)); - } - // When dealing with larger arrays, the optimization is not as simple as with the `sz_fill_avx512` function, - // as both buffers may be unaligned. If we are lucky and the requested operation is some huge page transfer, - // we can use aligned loads and stores, and the performance will be great. - else if ((sz_size_t)target % 64 == 0 && (sz_size_t)source % 64 == 0 && !is_huge) { - for (; length >= 64; target += 64, source += 64, length -= 64) - _mm512_store_si512(target, _mm512_load_si512(source)); - // At this point the length is guaranteed to be under 64. - __mmask64 mask = _sz_u64_mask_until(length); - // Aligned loads and stores would work too, but it's not defined. - _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source)); - } - // The trickiest case is when both `source` and `target` are not aligned. - // In this and simpler cases we can copy enough bytes into `target` to reach its cacheline boundary, - // and then combine unaligned loads with aligned stores. - else if (!is_huge) { - sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 64. - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source)); - for (target += head_length, source += head_length; body_length >= 64; - target += 64, source += 64, body_length -= 64) - _mm512_store_si512(target, _mm512_loadu_si512(source)); // Unaligned load, but aligned store! - _mm512_mask_storeu_epi8(target, tail_mask, _mm512_maskz_loadu_epi8(tail_mask, source)); - } - // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use. - // - // 1. Moving in both directions to maximize the throughput, when fetching from multiple - // memory pages. Also helps with cache set-associativity issues, as we won't always - // be fetching the same entries in the lookup table. - // 2. Using non-temporal stores to avoid polluting the cache. - // 3. Prefetching the next cache line, to avoid stalling the CPU. This is generally useless - // for predictable patterns, so disregard this advice. - // - // Bidirectional traversal adds about 10%, accelerating from 11 GB/s to 12 GB/s. - // Using "streaming stores" boosts us from 12 GB/s to 19 GB/s. - else { - sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; - sz_size_t tail_length = (sz_size_t)(target + length) % 64; - sz_size_t body_length = length - head_length - tail_length; - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source)); - _mm512_mask_storeu_epi8(target + head_length + body_length, tail_mask, - _mm512_maskz_loadu_epi8(tail_mask, source)); - - // Now in the main loop, we can use non-temporal loads and stores, - // performing the operation in both directions.
- for (target += head_length, source += head_length; // - body_length >= 128; // - target += 64, source += 64, body_length -= 128) { - _mm512_stream_si512((__m512i *)(target), _mm512_loadu_si512(source)); - _mm512_stream_si512((__m512i *)(target + body_length - 64), _mm512_loadu_si512(source + body_length - 64)); - } - if (body_length >= 64) _mm512_stream_si512((__m512i *)target, _mm512_loadu_si512(source)); - } -} - -SZ_PUBLIC void sz_move_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - if (target == source) return; // Don't be silly, don't move the data if it's already there. - - // On very short buffers, that are one cache line in width or less, we don't need any loops. - // We can also avoid any data-dependencies between iterations, assuming we have 32 registers - // to pre-load the data, before writing it back. - if (length <= 64) { - __mmask64 mask = _sz_u64_mask_until(length); - _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source)); - } - else if (length <= 128) { - sz_size_t last_length = length - 64; - __mmask64 mask = _sz_u64_mask_until(last_length); - __m512i source0 = _mm512_loadu_epi8(source); - __m512i source1 = _mm512_maskz_loadu_epi8(mask, source + 64); - _mm512_storeu_epi8(target, source0); - _mm512_mask_storeu_epi8(target + 64, mask, source1); - } - else if (length <= 192) { - sz_size_t last_length = length - 128; - __mmask64 mask = _sz_u64_mask_until(last_length); - __m512i source0 = _mm512_loadu_epi8(source); - __m512i source1 = _mm512_loadu_epi8(source + 64); - __m512i source2 = _mm512_maskz_loadu_epi8(mask, source + 128); - _mm512_storeu_epi8(target, source0); - _mm512_storeu_epi8(target + 64, source1); - _mm512_mask_storeu_epi8(target + 128, mask, source2); - } - else if (length <= 256) { - sz_size_t last_length = length - 192; - __mmask64 mask = _sz_u64_mask_until(last_length); - __m512i source0 = _mm512_loadu_epi8(source); - __m512i source1 = _mm512_loadu_epi8(source + 64); - __m512i source2 = _mm512_loadu_epi8(source + 128); - __m512i source3 = _mm512_maskz_loadu_epi8(mask, source + 192); - _mm512_storeu_epi8(target, source0); - _mm512_storeu_epi8(target + 64, source1); - _mm512_storeu_epi8(target + 128, source2); - _mm512_mask_storeu_epi8(target + 192, mask, source3); - } - - // If the regions don't overlap at all, just use "copy" and save some brain cells thinking about corner cases. - else if (target + length < source || target >= source + length) { sz_copy_avx512(target, source, length); } - - // When the buffer is over 64 bytes, it's guaranteed to touch at least two cache lines - the head and tail, - // and may include more cache-lines in-between. Knowing this, we can avoid expensive unaligned stores - // by computing 2 masks - for the head and tail, using masked stores for the head and tail, and unmasked - // for the body. - else { - sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 64. - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - - // The absolute most common case of using "moves" is shifting the data within a continuous buffer - // when adding a removing some values in it. In such cases, a typical shift is by 1, 2, 4, 8, 16, - // or 32 bytes, rarely larger. For small shifts, under the size of the ZMM register, we can use shuffles. 
- // - // Remember: - // - if we are shifting data left, we are traversing to the right. - // - if we are shifting data right, we are traversing to the left. - int const left_to_right_traversal = source > target; - - // Now we guarantee that the relative shift within registers is from 1 to 63 bytes and the output is aligned. - // Hopefully, we need to shift more than two ZMM registers, so we could consider the `valignr` instruction. - // Sadly, using `_mm512_alignr_epi8` doesn't make sense, as it operates at a 128-bit granularity. - // - // - `_mm256_alignr_epi8` shifts an entire 256-bit register, but we need many of them. - // - `_mm512_alignr_epi32` shifts 512-bit chunks, but only if the `shift` is a multiple of 4 bytes. - // - `_mm512_alignr_epi64` shifts 512-bit chunks by 8 bytes. - // - // All of those have a latency of 1 cycle, and the shift amount must be an immediate value! - // For 1-byte-shift granularity, the `_mm512_permutex2var_epi8` has a latency of 6 and needs VBMI! - // The most efficient and broadly compatible alternative could be to use a combination of align and shuffle. - // A similar approach was outlined in "Byte-wise alignr in AVX512F" by Wojciech Muła. - // http://0x80.pl/notesen/2016-10-16-avx512-byte-alignr.html - // - // That solution is quite a mouthful, assuming we need compile-time constants for the shift amount. - // A cleaner one, with a latency of 3 cycles, is to use `_mm512_permutexvar_epi8` or - // `_mm512_mask_permutexvar_epi8`, which can be seen as a combination of a cross-register shuffle and blend, - // and is available with VBMI. That solution is still noticeably slower than AVX2. - // - // The GLibC implementation also uses non-temporal stores for larger buffers; we don't. - // https://codebrowser.dev/glibc/glibc/sysdeps/x86_64/multiarch/memmove-avx512-no-vzeroupper.S.html - if (left_to_right_traversal) { - // Head, body, and tail. - _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source)); - for (target += head_length, source += head_length; body_length >= 64; - target += 64, source += 64, body_length -= 64) - _mm512_store_si512(target, _mm512_loadu_si512(source)); - _mm512_mask_storeu_epi8(target, tail_mask, _mm512_maskz_loadu_epi8(tail_mask, source)); - } - else { - // Tail, body, and head.
- _mm512_mask_storeu_epi8(target + head_length + body_length, tail_mask, - _mm512_maskz_loadu_epi8(tail_mask, source + head_length + body_length)); - for (; body_length >= 64; body_length -= 64) - _mm512_store_si512(target + head_length + body_length - 64, - _mm512_loadu_si512(source + head_length + body_length - 64)); - _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source)); - } - } -} - -SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - __mmask64 mask; - sz_u512_vec_t h_vec, n_vec; - n_vec.zmm = _mm512_set1_epi8(n[0]); - - while (h_length >= 64) { - h_vec.zmm = _mm512_loadu_si512(h); - mask = _mm512_cmpeq_epi8_mask(h_vec.zmm, n_vec.zmm); - if (mask) return h + sz_u64_ctz(mask); - h += 64, h_length -= 64; - } - - if (h_length) { - mask = _sz_u64_mask_until(h_length); - h_vec.zmm = _mm512_maskz_loadu_epi8(mask, h); - // Reuse the same `mask` variable to find the bit that doesn't match - mask = _mm512_mask_cmpeq_epu8_mask(mask, h_vec.zmm, n_vec.zmm); - if (mask) return h + sz_u64_ctz(mask); - } - - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_find_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_find_byte_avx512(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into ZMM registers. - __mmask64 matches; - __mmask64 mask; - sz_u512_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; - n_first_vec.zmm = _mm512_set1_epi8(n[offset_first]); - n_mid_vec.zmm = _mm512_set1_epi8(n[offset_mid]); - n_last_vec.zmm = _mm512_set1_epi8(n[offset_last]); - - // Scan through the string. - // We have several optimized versions of the lagorithm for shorter strings, - // but they all mimic the default case for unbounded length needles - if (n_length >= 64) { - for (; h_length >= n_length + 64; h += 64, h_length -= 64) { - h_first_vec.zmm = _mm512_loadu_si512(h + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_ctz(matches); - if (sz_equal_avx512(h + potential_offset, n, n_length)) return h + potential_offset; - matches &= matches - 1; - } - - // TODO: If the last character contains a bad byte, we can reposition the start of the next iteration. - // This will be very helpful for very long needles. - } - } - // If there are only 2 or 3 characters in the needle, we don't even need the nested loop. 
- else if (n_length <= 3) { - for (; h_length >= n_length + 64; h += 64, h_length -= 64) { - h_first_vec.zmm = _mm512_loadu_si512(h + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - if (matches) return h + sz_u64_ctz(matches); - } - } - // If the needle is smaller than the size of the ZMM register, we can use masked comparisons - // to avoid the the inner-most nested loop and compare the entire needle against a haystack - // slice in 3 CPU cycles. - else { - __mmask64 n_mask = _sz_u64_mask_until(n_length); - sz_u512_vec_t n_full_vec, h_full_vec; - n_full_vec.zmm = _mm512_maskz_loadu_epi8(n_mask, n); - for (; h_length >= n_length + 64; h += 64, h_length -= 64) { - h_first_vec.zmm = _mm512_loadu_si512(h + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_ctz(matches); - h_full_vec.zmm = _mm512_maskz_loadu_epi8(n_mask, h + potential_offset); - if (_mm512_mask_cmpneq_epi8_mask(n_mask, h_full_vec.zmm, n_full_vec.zmm) == 0) - return h + potential_offset; - matches &= matches - 1; - } - } - } - - // The "tail" of the function uses masked loads to process the remaining bytes. - { - mask = _sz_u64_mask_until(h_length - n_length + 1); - h_first_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_first); - h_mid_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_mid); - h_last_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_ctz(matches); - if (n_length <= 3 || sz_equal_avx512(h + potential_offset, n, n_length)) return h + potential_offset; - matches &= matches - 1; - } - } - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - __mmask64 mask; - sz_u512_vec_t h_vec, n_vec; - n_vec.zmm = _mm512_set1_epi8(n[0]); - - while (h_length >= 64) { - h_vec.zmm = _mm512_loadu_si512(h + h_length - 64); - mask = _mm512_cmpeq_epi8_mask(h_vec.zmm, n_vec.zmm); - if (mask) return h + h_length - 1 - sz_u64_clz(mask); - h_length -= 64; - } - - if (h_length) { - mask = _sz_u64_mask_until(h_length); - h_vec.zmm = _mm512_maskz_loadu_epi8(mask, h); - // Reuse the same `mask` variable to find the bit that doesn't match - mask = _mm512_mask_cmpeq_epu8_mask(mask, h_vec.zmm, n_vec.zmm); - if (mask) return h + 64 - sz_u64_clz(mask) - 1; - } - - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. 
- if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_rfind_byte_avx512(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into ZMM registers. - __mmask64 mask; - __mmask64 matches; - sz_u512_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; - n_first_vec.zmm = _mm512_set1_epi8(n[offset_first]); - n_mid_vec.zmm = _mm512_set1_epi8(n[offset_mid]); - n_last_vec.zmm = _mm512_set1_epi8(n[offset_last]); - - // Scan through the string. - sz_cptr_t h_reversed; - for (; h_length >= n_length + 64; h_length -= 64) { - h_reversed = h + h_length - n_length - 64 + 1; - h_first_vec.zmm = _mm512_loadu_si512(h_reversed + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h_reversed + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h_reversed + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_clz(matches); - if (n_length <= 3 || sz_equal_avx512(h + h_length - n_length - potential_offset, n, n_length)) - return h + h_length - n_length - potential_offset; - sz_assert((matches & ((sz_u64_t)1 << (63 - potential_offset))) != 0 && - "The bit must be set before we squash it"); - matches &= ~((sz_u64_t)1 << (63 - potential_offset)); - } - } - - // The "tail" of the function uses masked loads to process the remaining bytes. - { - mask = _sz_u64_mask_until(h_length - n_length + 1); - h_first_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_first); - h_mid_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_mid); - h_last_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_clz(matches); - if (n_length <= 3 || sz_equal_avx512(h + 64 - potential_offset - 1, n, n_length)) - return h + 64 - potential_offset - 1; - sz_assert((matches & ((sz_u64_t)1 << (63 - potential_offset))) != 0 && - "The bit must be set before we squash it"); - matches &= ~((sz_u64_t)1 << (63 - potential_offset)); - } - } - - return SZ_NULL_CHAR; -} - -#pragma clang attribute pop -#pragma GCC pop_options - -#pragma GCC push_options -#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vbmi", "bmi", "bmi2") -#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512dq,avx512vbmi,bmi,bmi2"))), \ - apply_to = function) - -/** - * @brief Computes the edit distance between two very short byte-strings using the AVX-512VBMI extensions. - * - * Applies to string lengths up to 63, and evaluates at most (63 * 2 + 1 = 127) diagonals, or just as many loop cycles. - * Supports an early exit, if the distance is bounded. - * Keeps all of the data and Levenshtein matrices skew diagonal in just a couple of registers. - * Benefits from the @b `vpermb` instructions, that can rotate the bytes across the entire ZMM register. 
- */ -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto63_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound) { - - sz_size_t const max_length = 63u; - sz_assert(shorter_length <= longer_length && "The 'shorter' string is longer than the 'longer' one."); - sz_assert(shorter_length < max_length && "The length must fit into 16-bit integer. Otherwise use serial variant."); - - // We are going to store 3 diagonals of the matrix, assuming each would fit into a single ZMM register. - // The length of the longest (main) diagonal would be `shorter_dim = (shorter_length + 1)`. - sz_size_t const shorter_dim = shorter_length + 1; - sz_size_t const longer_dim = longer_length + 1; - - // The next few buffers will be swapped around. - sz_u512_vec_t previous_vec, current_vec, next_vec; - sz_u512_vec_t gaps_vec, substitutions_vec; - - // Load the strings into ZMM registers - just once. - sz_u512_vec_t longer_vec, shorter_vec, shorter_rotated_vec, rotate_left_vec, rotate_right_vec, ones_vec, bound_vec; - longer_vec.zmm = _mm512_maskz_loadu_epi8(_sz_u64_mask_until(longer_length), longer); - rotate_left_vec.zmm = _mm512_set_epi8( // - 0, 63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, // - 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, // - 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, // - 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1); - rotate_right_vec.zmm = _mm512_set_epi8( // - 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, // - 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, // - 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, // - 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 63); - ones_vec.zmm = _mm512_set1_epi8(1); - bound_vec.zmm = _mm512_set1_epi8(bound <= 255 ? (sz_u8_t)bound : 255); - - // To simplify comparisons and traversals, we want to reverse the order of bytes in the shorter string. - for (sz_size_t i = 0; i != shorter_length; ++i) shorter_vec.u8s[63 - i] = shorter[i]; - shorter_rotated_vec.zmm = _mm512_permutexvar_epi8(rotate_right_vec.zmm, shorter_vec.zmm); - - // Let's say we are dealing with 3 and 5 letter words. - // The matrix will have size 4 x 6, parameterized as (shorter_dim x longer_dim). - // It will have: - // - 4 diagonals of increasing length, at positions: 0, 1, 2, 3. - // - 2 diagonals of fixed length, at positions: 4, 5. - // - 3 diagonals of decreasing length, at positions: 6, 7, 8. - sz_size_t const diagonals_count = shorter_dim + longer_dim - 1; - - // Initialize the first two diagonals: - // - // previous_vec.u8s[0] = 0; - // current_vec.u8s[0] = current_vec.u8s[1] = 1; - // - // We can do a similar thing with vector ops: - previous_vec.zmm = _mm512_setzero_si512(); - current_vec.zmm = _mm512_set1_epi8(1); - - // We skip diagonals 0 and 1, as they are trivial. - // We will start with diagonal 2, which has length 3, with the first and last elements being preset, - // so we are effectively computing just one value, as will be marked by a single set bit in - // the `next_diagonal_mask` on the very first iteration. - sz_size_t next_diagonal_index = 2; - __mmask64 next_diagonal_mask = 0; - - // Progress through the upper triangle of the Levenshtein matrix. 
- for (; next_diagonal_index != shorter_dim; ++next_diagonal_index) { - // After this iteration, the values at offset `0` and `next_diagonal_index` in the `next_vec` - // should be set to `next_diagonal_index`, but it's easier to broadcast the value to the whole vector, - // and later merge with a mask with new values. - next_vec.zmm = _mm512_set1_epi8((sz_u8_t)next_diagonal_index); - - // The mask also adds one set bit. - next_diagonal_mask = _kor_mask64(next_diagonal_mask, 1); - next_diagonal_mask = _kshiftli_mask64(next_diagonal_mask, 1); - - // Check for equality between string slices. - __mmask64 conflict_mask = _mm512_cmpneq_epi8_mask(longer_vec.zmm, shorter_rotated_vec.zmm); - substitutions_vec.zmm = _mm512_mask_add_epi8(previous_vec.zmm, conflict_mask, previous_vec.zmm, ones_vec.zmm); - substitutions_vec.zmm = _mm512_permutexvar_epi8(rotate_right_vec.zmm, substitutions_vec.zmm); - gaps_vec.zmm = _mm512_add_epi8( - // Insertions or deletions - _mm512_min_epu8(_mm512_permutexvar_epi8(rotate_right_vec.zmm, current_vec.zmm), current_vec.zmm), - ones_vec.zmm); - next_vec.zmm = _mm512_mask_min_epu8(next_vec.zmm, next_diagonal_mask, gaps_vec.zmm, substitutions_vec.zmm); - - // Mark the current skewed diagonal as the previous one and the next one as the current one. - previous_vec.zmm = current_vec.zmm; - current_vec.zmm = next_vec.zmm; - - // Shift the shorter string - shorter_rotated_vec.zmm = _mm512_permutexvar_epi8(rotate_right_vec.zmm, shorter_rotated_vec.zmm); - - // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound. - __mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm); - if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) { // - return SZ_SIZE_MAX; - } - } - - // Now let's handle the anti-diagonal band of the matrix, between the top and bottom triangles. - for (; next_diagonal_index != longer_dim; ++next_diagonal_index) { - // After this iteration, the value `shorted_dim - 1` in the `next_vec` - // should be set to `next_diagonal_index`, but it's easier to broadcast the value to the whole vector, - // and later merge with a mask with new values. - next_vec.zmm = _mm512_set1_epi8((sz_u8_t)next_diagonal_index); - - // Make sure we update the first entry. - next_diagonal_mask = _kor_mask64(next_diagonal_mask, 1); - - // Check for equality between string slices. - __mmask64 conflict_mask = _mm512_cmpneq_epi8_mask(longer_vec.zmm, shorter_rotated_vec.zmm); - substitutions_vec.zmm = _mm512_mask_add_epi8(previous_vec.zmm, conflict_mask, previous_vec.zmm, ones_vec.zmm); - gaps_vec.zmm = _mm512_add_epi8( - // Insertions or deletions - _mm512_min_epu8(current_vec.zmm, _mm512_permutexvar_epi8(rotate_left_vec.zmm, current_vec.zmm)), - ones_vec.zmm); - next_vec.zmm = _mm512_mask_min_epu8(next_vec.zmm, next_diagonal_mask, gaps_vec.zmm, substitutions_vec.zmm); - - // Mark the current skewed diagonal as the previous one and the next one as the current one. - previous_vec.zmm = _mm512_permutexvar_epi8(rotate_left_vec.zmm, current_vec.zmm); - current_vec.zmm = next_vec.zmm; - - // Let's shift the longer string now. - longer_vec.zmm = _mm512_permutexvar_epi8(rotate_left_vec.zmm, longer_vec.zmm); - - // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound. 
- __mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm); - if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) { // - return SZ_SIZE_MAX; - } - } - - // Now let's handle the bottom right triangle. - for (; next_diagonal_index != diagonals_count; ++next_diagonal_index) { - - // Check for equality between string slices. - __mmask64 conflict_mask = _mm512_cmpneq_epi8_mask(longer_vec.zmm, shorter_rotated_vec.zmm); - substitutions_vec.zmm = _mm512_mask_add_epi8(previous_vec.zmm, conflict_mask, previous_vec.zmm, ones_vec.zmm); - gaps_vec.zmm = _mm512_add_epi8( - // Insertions or deletions - _mm512_min_epu8(current_vec.zmm, _mm512_permutexvar_epi8(rotate_left_vec.zmm, current_vec.zmm)), - ones_vec.zmm); - next_vec.zmm = _mm512_min_epu8(gaps_vec.zmm, substitutions_vec.zmm); - - // Mark the current skewed diagonal as the previous one and the next one as the current one. - previous_vec.zmm = _mm512_permutexvar_epi8(rotate_left_vec.zmm, current_vec.zmm); - current_vec.zmm = next_vec.zmm; - - // Let's shift the longer string now. - longer_vec.zmm = _mm512_permutexvar_epi8(rotate_left_vec.zmm, longer_vec.zmm); - - // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound. - __mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm); - if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) { // - return SZ_SIZE_MAX; - } - // In every following iteration we use a shorter prefix of each register, - // but we don't need to update the `next_diagonal_mask` anymore... except for the early exit. - next_diagonal_mask = _kshiftri_mask64(next_diagonal_mask, 1); - } - return current_vec.u8s[0]; -} - -/** - * @brief Computes the edit distance between two somewhat short byte-strings using the AVX-512VBMI extensions. - * - * Applies to string lengths up to 127, and evaluates at most (127 * 2 + 1 = 255) diagonals. - * Supports an early exit, if the distance is bounded. - * Uses a lot more CPU register space than the `upto63` variant. - * Benefits from the @b `vpermi2b` instructions, that can rotate the bytes in 2 registers at once. - * - * This may be one of the most frequently called kernels for: - * - source code analysis, assuming most lines are either under 80 or under 120 characters long. - * - DNA sequence alignment, as most short reads are 50-300 characters long. - */ -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto127_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound) { - sz_unused(shorter && shorter_length && longer && longer_length && bound); - return 0; -} - -/** - * @brief Computes the edit distance between two longer byte-strings using the AVX-512VBMI extensions. - * - * Applies to string lengths up to 255, and evaluates at most (255 * 2 + 1 = 511) diagonals. - * Supports an early exit, if the distance is bounded. - * Uses a lot more CPU register space than the `upto63` variant. - * - * Each of 2x string ends up occupying 4 ZMM registers, and each of 3x diagonals uses 4 ZMM registers. - * So 20x of the 32x are persistently occupied, and the rest are used for math temporarily. - * This is the largest space-efficient variant, as strings beyond 255 characters may require - * 16-bit accumulators, which would be a significant bottleneck.
- */ -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound) { - sz_unused(shorter && shorter_length && longer && longer_length && bound); - return 0; -} - -/** - * @brief Computes the edit distance between two longer bytes-strings using the AVX-512VBMI extensions, - * assuming the upper distance bound can not exceed 255, but the string length can be arbitrary. - * - * Applies to string lengths up to 255, and evaluates at most (255 * 2 + 1 = 511) diagonals. - * Supports an early exit, if the distance is bounded. - * Uses a lot more CPU registers space, than the `upto63` variant. - * - * Each of 2x string ends up occupying 4 ZMM registers, and each of 3x diagonals uses 4 ZMM registers. - * So 20x of the 32x are persistently occupied, and the rest are used for math temporarily. - * This is the largest space-efficient variant, as strings beyond 255 characters may require - * 16-bit accumulators, which would be a significant bottleneck. - */ -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto255bound_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound) { - sz_unused(shorter && shorter_length && longer && longer_length && bound); - return 0; -} - -/** - * @brief Computes the edit distance between two mid-length UTF-8-strings using the AVX-512VBMI extensions. - * - * Applies to string lengths up to 127, and evaluates at most (127 * 2 + 1 = 511) diagonals. - * Supports an early exit, if the distance is bounded. - * Benefits from the @b `valignd` instructions used to rotate UTF-32 unpacked unicode codepoints. - * - * Each string is unpacked into 128 characters * 4 bytes per character / 64 bytes per register = 8 registers. - * - */ -SZ_INTERNAL sz_size_t _sz_edit_distance_utf8_skewed_diagonals_upto127_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound) { - sz_unused(shorter && shorter_length && longer && longer_length && bound); - return 0; -} - -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto65k_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - - sz_unused(shorter && longer && bound && alloc); - - // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. - sz_memory_allocator_t global_alloc; - if (!alloc) { - sz_memory_allocator_init_default(&global_alloc); - alloc = &global_alloc; - } - - // TODO: Generalize! - sz_size_t const max_length = 256u * 256u; - sz_assert(shorter_length <= longer_length && "The 'shorter' string is longer than the 'longer' one."); - sz_assert(shorter_length < max_length && "The length must fit into 16-bit integer. Otherwise use serial variant."); - sz_unused(longer_length && bound && max_length); - -#if 0 - // We are going to store 3 diagonals of the matrix. - // The length of the longest (main) diagonal would be `shorter_dim = (shorter_length + 1)`. - sz_size_t const shorter_dim = shorter_length + 1; - sz_size_t const longer_dim = longer_length + 1; - // Unlike the serial version, we also want to avoid reverse-order iteration over teh shorter string. - // So let's allocate a bit more memory and reverse-export our shorter string into that buffer. 
- sz_size_t const buffer_length = sizeof(sz_u16_t) * longer_dim * 3 + shorter_length; - sz_u16_t *const distances = (sz_u16_t *)alloc->allocate(buffer_length, alloc->handle); - if (!distances) return SZ_SIZE_MAX; - - // The next few pointers will be swapped around. - sz_u16_t *previous_distances = distances; - sz_u16_t *current_distances = previous_distances + longer_dim; - sz_u16_t *next_distances = current_distances + longer_dim; - sz_ptr_t const shorter_reversed = (sz_ptr_t)(next_distances + longer_dim); - - // Export the reversed string into the buffer. - for (sz_size_t i = 0; i != shorter_length; ++i) shorter_reversed[i] = shorter[shorter_length - 1 - i]; - - // Initialize the first two diagonals: - previous_distances[0] = 0; - current_distances[0] = current_distances[1] = 1; - - // Using ZMM registers, we can process 32x 16-bit values at once, - // storing 16 bytes of each string in YMM registers. - sz_u512_vec_t insertions_vec, deletions_vec, substitutions_vec, next_vec; - sz_u512_vec_t ones_u16_vec; - ones_u16_vec.zmm = _mm512_set1_epi16(1); - - // This is a mixed-precision implementation, using 8-bit representations for part of the operations. - // Even there, in case `SZ_USE_X86_AVX2=0`, let's use the `sz_u512_vec_t` type, addressing the first YMM halfs. - sz_u512_vec_t shorter_vec, longer_vec; - sz_u512_vec_t ones_u8_vec; - ones_u8_vec.ymms[0] = _mm256_set1_epi8(1); - - // Let's say we are dealing with 3 and 5 letter words. - // The matrix will have size 4 x 6, parameterized as (shorter_dim x longer_dim). - // It will have: - // - 4 diagonals of increasing length, at positions: 0, 1, 2, 3. - // - 2 diagonals of fixed length, at positions: 4, 5. - // - 3 diagonals of decreasing length, at positions: 6, 7, 8. - sz_size_t const diagonals_count = shorter_dim + longer_dim - 1; - - // Progress through the upper triangle of the Levenshtein matrix. - sz_size_t next_diagonal_index = 2; - for (; next_diagonal_index != shorter_dim; ++next_diagonal_index) { - sz_size_t const next_diagonal_length = next_diagonal_index + 1; - for (sz_size_t offset_within_diagonal = 0; offset_within_diagonal + 2 < next_diagonal_length;) { - sz_u32_t remaining_length = (sz_u32_t)(next_diagonal_length - offset_within_diagonal - 2); - sz_u32_t register_length = remaining_length < 32 ? remaining_length : 32; - sz_u32_t remaining_length_mask = _bzhi_u32(0xFFFFFFFFu, register_length); - longer_vec.ymms[0] = _mm256_maskz_loadu_epi8(remaining_length_mask, longer + offset_within_diagonal); - // Our original code addressed the shorter string `[next_diagonal_index - offset_within_diagonal - 2]` - // for growing `offset_within_diagonal`. If the `shorter` string was reversed, the - // `[next_diagonal_index - offset_within_diagonal - 2]` would be equal to `[shorter_length - 1 - - // next_diagonal_index + offset_within_diagonal + 2]`. Which simplified would be equal to - // `[shorter_length - next_diagonal_index + offset_within_diagonal + 1]`. - shorter_vec.ymms[0] = _mm256_maskz_loadu_epi8( // - remaining_length_mask, - shorter_reversed + shorter_length - next_diagonal_index + offset_within_diagonal + 1); - // For substitutions, perform the equality comparison using AVX2 instead of AVX-512 - // to get the result as a vector, instead of a bitmask. Adding 1 to every scalar we can overflow - // transforming from {0xFF, 0} values to {0, 1} values - exactly what we need. Then - upcast to 16-bit. 
- substitutions_vec.zmm = _mm512_cvtepi8_epi16( // - _mm256_add_epi8(_mm256_cmpeq_epi8(longer_vec.ymms[0], shorter_vec.ymms[0]), ones_u8_vec.ymms[0])); - substitutions_vec.zmm = _mm512_add_epi16( // - substitutions_vec.zmm, - _mm512_maskz_loadu_epi16(remaining_length_mask, previous_distances + offset_within_diagonal)); - // For insertions and deletions, on modern hardware, it's faster to issue two separate loads, - // than rotate the bytes in the ZMM register. - insertions_vec.zmm = - _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + offset_within_diagonal); - deletions_vec.zmm = - _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + offset_within_diagonal + 1); - // First get the minimum of insertions and deletions. - next_vec.zmm = _mm512_add_epi16(_mm512_min_epu16(insertions_vec.zmm, deletions_vec.zmm), ones_u16_vec.zmm); - next_vec.zmm = _mm512_min_epu16(next_vec.zmm, substitutions_vec.zmm); - _mm512_mask_storeu_epi16(next_distances + offset_within_diagonal + 1, remaining_length_mask, next_vec.zmm); - offset_within_diagonal += register_length; - } - // Don't forget to populate the first row and the first column of the Levenshtein matrix. - next_distances[0] = next_distances[next_diagonal_length - 1] = (sz_u16_t)next_diagonal_index; - // Perform a circular rotation (three-way swap) of those buffers, to reuse the memory. - sz_u16_t *temporary = previous_distances; - previous_distances = current_distances; - current_distances = next_distances; - next_distances = temporary; - } - - // By now we've scanned through the upper triangle of the matrix, where each subsequent iteration results in a - // larger diagonal. From now onwards, we will be shrinking. Instead of adding value equal to the skewed diagonal - // index on either side, we will be cropping those values out. - for (; next_diagonal_index != diagonals_count; ++next_diagonal_index) { - sz_size_t const next_diagonal_length = diagonals_count - next_diagonal_index; - for (sz_size_t i = 0; i != next_diagonal_length;) { - sz_u32_t remaining_length = (sz_u32_t)(next_diagonal_length - i); - sz_u32_t register_length = remaining_length < 32 ? remaining_length : 32; - sz_u32_t remaining_length_mask = _bzhi_u32(0xFFFFFFFFu, register_length); - longer_vec.ymms[0] = _mm256_maskz_loadu_epi8(remaining_length_mask, longer + next_diagonal_index - n + i); - // Our original code addressed the shorter string `[shorter_length - 1 - i]` for growing `i`. - // If the `shorter` string was reversed, the `[shorter_length - 1 - i]` would - // be equal to `[shorter_length - 1 - shorter_length + 1 + i]`. - // Which simplified would be equal to just `[i]`. Beautiful! - shorter_vec.ymms[0] = _mm256_maskz_loadu_epi8(remaining_length_mask, shorter_reversed + i); - // For substitutions, perform the equality comparison using AVX2 instead of AVX-512 - // to get the result as a vector, instead of a bitmask. The compare it against the accumulated - // substitution costs. - substitutions_vec.zmm = _mm512_cvtepi8_epi16( // - _mm256_add_epi8(_mm256_cmpeq_epi8(longer_vec.ymms[0], shorter_vec.ymms[0]), ones_u8_vec.ymms[0])); - substitutions_vec.zmm = _mm512_add_epi16( // - substitutions_vec.zmm, _mm512_maskz_loadu_epi16(remaining_length_mask, previous_distances + i)); - // For insertions and deletions, on modern hardware, it's faster to issue two separate loads, - // than rotate the bytes in the ZMM register. 
- insertions_vec.zmm = _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + i); - deletions_vec.zmm = _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + i + 1); - // First get the minimum of insertions and deletions. - next_vec.zmm = _mm512_add_epi16(_mm512_min_epu16(insertions_vec.zmm, deletions_vec.zmm), ones_u16_vec.zmm); - next_vec.zmm = _mm512_min_epu16(next_vec.zmm, substitutions_vec.zmm); - _mm512_mask_storeu_epi16(next_distances + i, remaining_length_mask, next_vec.zmm); - i += register_length; - } - - // Perform a circular rotation (three-way swap) of those buffers, to reuse the memory, this time, with a shift, - // dropping the first element in the current array. - sz_u16_t *temporary = previous_distances; - previous_distances = current_distances + 1; - current_distances = next_distances; - next_distances = temporary; - } - - // Cache scalar before `free` call. - sz_size_t result = current_distances[0]; - alloc->free(distances, buffer_length, alloc->handle); - return result; -#endif - return 0; -} - -SZ_INTERNAL sz_size_t sz_edit_distance_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - - // Bounded computations may exit early. - int const is_bounded = bound < longer_length; - if (is_bounded) { - // If one of the strings is empty - the edit distance is equal to the length of the other one. - if (longer_length == 0) return sz_min_of_two(shorter_length, bound); - if (shorter_length == 0) return sz_min_of_two(longer_length, bound); - // If the difference in length is beyond the `bound`, there is no need to check at all. - if (longer_length - shorter_length > bound) return bound; - } - - // Make sure the shorter string is actually shorter. - if (shorter_length > longer_length) { - sz_cptr_t temporary = shorter; - shorter = longer; - longer = temporary; - sz_size_t temporary_length = shorter_length; - shorter_length = longer_length; - longer_length = temporary_length; - } - - // Dispatch the right implementation based on the length of the strings. - if (longer_length < 64u) - return _sz_edit_distance_skewed_diagonals_upto63_avx512( // - shorter, shorter_length, longer, longer_length, bound); - // else if (longer_length < 256u * 256u) - // return _sz_edit_distance_skewed_diagonals_upto65k_avx512( // - // shorter, shorter_length, longer, longer_length, bound, alloc); - else - return sz_edit_distance_serial(shorter, shorter_length, longer, longer_length, bound, alloc); -} - -SZ_PUBLIC sz_u64_t sz_checksum_avx512(sz_cptr_t text, sz_size_t length) { - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "loads". - // - // A typical AWS Sapphire Rapids instance can have 48 KB x 2 blocks of L1 data cache per core, - // 2 MB x 2 blocks of L2 cache per core, and one shared 60 MB buffer of L3 cache. - // With two strings, we may consider the overal workload huge, if each exceeds 1 MB in length. - int const is_huge = length >= 1ull * 1024ull * 1024ull; - sz_u512_vec_t text_vec, sums_vec; - - // When the buffer is small, there isn't much to innovate. 
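- // The `_mm_sad_epu8` / `_mm512_sad_epu8` reductions below rely on the fact that a sum of absolute
- // differences against zero is just a sum of bytes: every 8-byte lane collapses into one 64-bit
- // partial sum. A scalar sketch of what the whole function computes, regardless of the branch taken:
- //
- //    sz_u64_t sum = 0;
- //    for (sz_size_t i = 0; i != length; ++i) sum += (sz_u8_t)text[i];
- //    return sum;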
- if (length <= 16) { - __mmask16 mask = _sz_u16_mask_until(length); - text_vec.xmms[0] = _mm_maskz_loadu_epi8(mask, text); - sums_vec.xmms[0] = _mm_sad_epu8(text_vec.xmms[0], _mm_setzero_si128()); - sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_vec.xmms[0]); - sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_vec.xmms[0], 1); - return low + high; - } - else if (length <= 32) { - __mmask32 mask = _sz_u32_mask_until(length); - text_vec.ymms[0] = _mm256_maskz_loadu_epi8(mask, text); - sums_vec.ymms[0] = _mm256_sad_epu8(text_vec.ymms[0], _mm256_setzero_si256()); - // Accumulating 256 bits is harders, as we need to extract the 128-bit sums first. - __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymms[0]); - __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymms[0], 1); - __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm); - sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm); - sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1); - return low + high; - } - else if (length <= 64) { - __mmask64 mask = _sz_u64_mask_until(length); - text_vec.zmm = _mm512_maskz_loadu_epi8(mask, text); - sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()); - return _mm512_reduce_add_epi64(sums_vec.zmm); - } - else if (!is_huge) { - sz_size_t head_length = (64 - ((sz_size_t)text % 64)) % 64; // 63 or less. - sz_size_t tail_length = (sz_size_t)(text + length) % 64; // 63 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 64. - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - text_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, text); - sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()); - for (text += head_length; body_length >= 64; text += 64, body_length -= 64) { - text_vec.zmm = _mm512_load_si512((__m512i const *)text); - sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512())); - } - text_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, text); - sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512())); - return _mm512_reduce_add_epi64(sums_vec.zmm); - } - // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use. - // - // 1. Moving in both directions to maximize the throughput, when fetching from multiple - // memory pages. Also helps with cache set-associativity issues, as we won't always - // be fetching the same entries in the lookup table. - // 2. Using non-temporal stores to avoid polluting the cache. - // 3. Prefetching the next cache line, to avoid stalling the CPU. This generally useless - // for predictable patterns, so disregard this advice. - // - // Bidirectional traversal generally adds about 10% to such algorithms. 
- else {
- sz_u512_vec_t text_reversed_vec, sums_reversed_vec;
- sz_size_t head_length = (64 - ((sz_size_t)text % 64)) % 64;
- sz_size_t tail_length = (sz_size_t)(text + length) % 64;
- sz_size_t body_length = length - head_length - tail_length;
- __mmask64 head_mask = _sz_u64_mask_until(head_length);
- __mmask64 tail_mask = _sz_u64_mask_until(tail_length);
-
- text_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, text);
- sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512());
- text_reversed_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, text + head_length + body_length);
- sums_reversed_vec.zmm = _mm512_sad_epu8(text_reversed_vec.zmm, _mm512_setzero_si512());
-
- // Now in the main loop, we can use non-temporal loads and stores,
- // performing the operation in both directions.
- for (text += head_length; body_length >= 128; text += 64, body_length -= 128) {
- text_vec.zmm = _mm512_stream_load_si512((__m512i *)(text));
- sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()));
- text_reversed_vec.zmm = _mm512_stream_load_si512((__m512i *)(text + body_length - 64));
- sums_reversed_vec.zmm =
- _mm512_add_epi64(sums_reversed_vec.zmm, _mm512_sad_epu8(text_reversed_vec.zmm, _mm512_setzero_si512()));
- }
- if (body_length >= 64) {
- text_vec.zmm = _mm512_stream_load_si512((__m512i *)(text));
- sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()));
- }
-
- return _mm512_reduce_add_epi64(_mm512_add_epi64(sums_vec.zmm, sums_reversed_vec.zmm));
- }
-}
-
-SZ_PUBLIC void sz_hashes_avx512(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, //
- sz_hash_callback_t callback, void *callback_handle) {
-
- if (length < window_length || !window_length) return;
- if (length < 4 * window_length) {
- sz_hashes_serial(start, length, window_length, step, callback, callback_handle);
- return;
- }
-
- // Using AVX2, we can perform 4 long integer multiplications and additions within one register.
- // So let's slice the entire string into 4 overlapping windows, to slide over them in parallel.
- sz_size_t const max_hashes = length - window_length + 1;
- sz_size_t const min_hashes_per_thread = max_hashes / 4; // At most one sequence can overlap between 2 threads.
- sz_u8_t const *text_first = (sz_u8_t const *)start;
- sz_u8_t const *text_second = text_first + min_hashes_per_thread;
- sz_u8_t const *text_third = text_first + min_hashes_per_thread * 2;
- sz_u8_t const *text_fourth = text_first + min_hashes_per_thread * 3;
- sz_u8_t const *text_end = text_first + length;
-
- // Broadcast the global constants into the registers.
- // Both high and low hashes will work with the same prime and golden ratio.
- sz_u512_vec_t prime_vec, golden_ratio_vec;
- prime_vec.zmm = _mm512_set1_epi64(SZ_U64_MAX_PRIME);
- golden_ratio_vec.zmm = _mm512_set1_epi64(11400714819323198485ull);
-
- // Prepare the `base ^ (window_length - 1) % prime` values that we are going to use to subtract the outgoing characters.
- sz_u64_t prime_power_low = 1, prime_power_high = 1;
- for (sz_size_t i = 0; i + 1 < window_length; ++i)
- prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME,
- prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME;
-
- // We will be evaluating 4 offsets at a time with 2 different hash functions.
- // We can fit all those 8 state variables in each of the following ZMM registers.
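- // For reference, every one of those 8 lanes effectively maintains a scalar polynomial rolling hash.
- // A rough per-lane sketch, assuming `base` is 31 or 257, `shift` is 0 or 77, and
- // `power == base ^ (window_length - 1) % prime` (the names here are illustrative, not the actual state below):
- //
- //    hash -= (text[i - window_length] + shift) * power; // drop the outgoing character
- //    hash = hash * base + (text[i] + shift);            // admit the incoming character
- //    if (hash > prime) hash -= prime;                   // cheap conditional "modulo"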
- sz_u512_vec_t base_vec, prime_power_vec, shift_vec; - base_vec.zmm = _mm512_set_epi64(31ull, 31ull, 31ull, 31ull, 257ull, 257ull, 257ull, 257ull); - shift_vec.zmm = _mm512_set_epi64(0ull, 0ull, 0ull, 0ull, 77ull, 77ull, 77ull, 77ull); - prime_power_vec.zmm = _mm512_set_epi64(prime_power_low, prime_power_low, prime_power_low, prime_power_low, - prime_power_high, prime_power_high, prime_power_high, prime_power_high); - - // Compute the initial hash values for every one of the four windows. - sz_u512_vec_t hash_vec, chars_vec; - hash_vec.zmm = _mm512_setzero_si512(); - for (sz_u8_t const *prefix_end = text_first + window_length; text_first < prefix_end; - ++text_first, ++text_second, ++text_third, ++text_fourth) { - - // 1. Multiply the hashes by the base. - hash_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, base_vec.zmm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`... - chars_vec.zmm = _mm512_set_epi64(text_fourth[0], text_third[0], text_second[0], text_first[0], // - text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm); - - // 3. Add the incoming characters. - hash_vec.zmm = _mm512_add_epi64(hash_vec.zmm, chars_vec.zmm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_vec.zmm = _mm512_mask_blend_epi8(_mm512_cmpgt_epi64_mask(hash_vec.zmm, prime_vec.zmm), hash_vec.zmm, - _mm512_sub_epi64(hash_vec.zmm, prime_vec.zmm)); - } - - // 5. Compute the hash mix, that will be used to index into the fingerprint. - // This includes a serial step at the end. - sz_u512_vec_t hash_mix_vec; - hash_mix_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, golden_ratio_vec.zmm); - hash_mix_vec.ymms[0] = _mm256_xor_si256(_mm512_extracti64x4_epi64(hash_mix_vec.zmm, 1), // - _mm512_extracti64x4_epi64(hash_mix_vec.zmm, 0)); - - callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); - callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); - callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); - callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); - - // Now repeat that operation for the remaining characters, discarding older characters. - sz_size_t cycle = 1; - sz_size_t step_mask = step - 1; - for (; text_fourth != text_end; ++text_first, ++text_second, ++text_third, ++text_fourth, ++cycle) { - // 0. Load again the four characters we are dropping, shift them, and subtract. - chars_vec.zmm = _mm512_set_epi64(text_fourth[-window_length], text_third[-window_length], - text_second[-window_length], text_first[-window_length], // - text_fourth[-window_length], text_third[-window_length], - text_second[-window_length], text_first[-window_length]); - chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm); - hash_vec.zmm = _mm512_sub_epi64(hash_vec.zmm, _mm512_mullo_epi64(chars_vec.zmm, prime_power_vec.zmm)); - - // 1. Multiply the hashes by the base. - hash_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, base_vec.zmm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. 
- chars_vec.zmm = _mm512_set_epi64(text_fourth[0], text_third[0], text_second[0], text_first[0], // - text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm); - - // ... and prefetch the next four characters into Level 2 or higher. - _mm_prefetch((sz_cptr_t)text_fourth + 1, _MM_HINT_T1); - _mm_prefetch((sz_cptr_t)text_third + 1, _MM_HINT_T1); - _mm_prefetch((sz_cptr_t)text_second + 1, _MM_HINT_T1); - _mm_prefetch((sz_cptr_t)text_first + 1, _MM_HINT_T1); - - // 3. Add the incoming characters. - hash_vec.zmm = _mm512_add_epi64(hash_vec.zmm, chars_vec.zmm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_vec.zmm = _mm512_mask_blend_epi8(_mm512_cmpgt_epi64_mask(hash_vec.zmm, prime_vec.zmm), hash_vec.zmm, - _mm512_sub_epi64(hash_vec.zmm, prime_vec.zmm)); - - // 5. Compute the hash mix, that will be used to index into the fingerprint. - // This includes a serial step at the end. - hash_mix_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, golden_ratio_vec.zmm); - hash_mix_vec.ymms[0] = _mm256_xor_si256(_mm512_extracti64x4_epi64(hash_mix_vec.zmm, 1), // - _mm512_castsi512_si256(hash_mix_vec.zmm)); - - if ((cycle & step_mask) == 0) { - callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); - callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); - callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); - callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); - } - } -} - -#pragma clang attribute pop -#pragma GCC pop_options - -#pragma GCC push_options -#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512vbmi", "avx512vbmi2", "bmi", "bmi2") -#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512vbmi,avx512vbmi2,bmi,bmi2"))), \ - apply_to = function) - -SZ_PUBLIC void sz_look_up_transform_avx512(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { - - // If the input is tiny (especially smaller than the look-up table itself), we may end up paying - // more for organizing the SIMD registers and changing the CPU state, than for the actual computation. - // But if at least 3 cache lines are touched, the AVX-512 implementation should be faster. - if (length <= 128) { - sz_look_up_transform_serial(source, length, lut, target); - return; - } - - // When the buffer is over 64 bytes, it's guaranteed to touch at least two cache lines - the head and tail, - // and may include more cache-lines in-between. Knowing this, we can avoid expensive unaligned stores - // by computing 2 masks - for the head and tail, using masked stores for the head and tail, and unmasked - // for the body. - sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - - // We need to pull the lookup table into 4x ZMM registers. - // We can use `vpermi2b` instruction to perform the look in two ZMM registers with `_mm512_permutex2var_epi8` - // intrinsics, but it has a 6-cycle latency on Sapphire Rapids and requires AVX512-VBMI. 
Assuming we need to - // operate on 4 registers, it might be cleaner to use 2x separate `_mm512_permutexvar_epi8` calls. - // Combining the results with 2x `_mm512_test_epi8_mask` and 3x blends afterwards. - // - // - 4x `_mm512_permutexvar_epi8` maps to "VPERMB (ZMM, ZMM, ZMM)": - // - On Ice Lake: 3 cycles latency, ports: 1*p5 - // - On Genoa: 6 cycles latency, ports: 1*FP12 - // - 3x `_mm512_mask_blend_epi8` maps to "VPBLENDMB_Z (ZMM, K, ZMM, ZMM)": - // - On Ice Lake: 3 cycles latency, ports: 1*p05 - // - On Genoa: 1 cycle latency, ports: 1*FP0123 - // - 2x `_mm512_test_epi8_mask` maps to "VPTESTMB (K, ZMM, ZMM)": - // - On Ice Lake: 3 cycles latency, ports: 1*p5 - // - On Genoa: 4 cycles latency, ports: 1*FP01 - // - sz_u512_vec_t lut_0_to_63_vec, lut_64_to_127_vec, lut_128_to_191_vec, lut_192_to_255_vec; - lut_0_to_63_vec.zmm = _mm512_loadu_si512((lut)); - lut_64_to_127_vec.zmm = _mm512_loadu_si512((lut + 64)); - lut_128_to_191_vec.zmm = _mm512_loadu_si512((lut + 128)); - lut_192_to_255_vec.zmm = _mm512_loadu_si512((lut + 192)); - - sz_u512_vec_t first_bit_vec, second_bit_vec; - first_bit_vec.zmm = _mm512_set1_epi8((char)0x80); - second_bit_vec.zmm = _mm512_set1_epi8((char)0x40); - - __mmask64 first_bit_mask, second_bit_mask; - sz_u512_vec_t source_vec; - // If the top bit is set in each word of `source_vec`, than we use `lookup_128_to_191_vec` or - // `lookup_192_to_255_vec`. If the second bit is set, we use `lookup_64_to_127_vec` or `lookup_192_to_255_vec`. - sz_u512_vec_t lookup_0_to_63_vec, lookup_64_to_127_vec, lookup_128_to_191_vec, lookup_192_to_255_vec; - sz_u512_vec_t blended_0_to_127_vec, blended_128_to_255_vec, blended_0_to_255_vec; - - // Handling the head. - if (head_length) { - source_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, source); - lookup_0_to_63_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_0_to_63_vec.zmm); - lookup_64_to_127_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_64_to_127_vec.zmm); - lookup_128_to_191_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_128_to_191_vec.zmm); - lookup_192_to_255_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_192_to_255_vec.zmm); - first_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, first_bit_vec.zmm); - second_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, second_bit_vec.zmm); - blended_0_to_127_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_0_to_63_vec.zmm, lookup_64_to_127_vec.zmm); - blended_128_to_255_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_128_to_191_vec.zmm, lookup_192_to_255_vec.zmm); - blended_0_to_255_vec.zmm = - _mm512_mask_blend_epi8(first_bit_mask, blended_0_to_127_vec.zmm, blended_128_to_255_vec.zmm); - _mm512_mask_storeu_epi8(target, head_mask, blended_0_to_255_vec.zmm); - source += head_length, target += head_length, length -= head_length; - } - - // Handling the body in 64-byte chunks aligned to cache-line boundaries with respect to `target`. 
- while (length >= 64) { - source_vec.zmm = _mm512_loadu_si512(source); - lookup_0_to_63_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_0_to_63_vec.zmm); - lookup_64_to_127_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_64_to_127_vec.zmm); - lookup_128_to_191_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_128_to_191_vec.zmm); - lookup_192_to_255_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_192_to_255_vec.zmm); - first_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, first_bit_vec.zmm); - second_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, second_bit_vec.zmm); - blended_0_to_127_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_0_to_63_vec.zmm, lookup_64_to_127_vec.zmm); - blended_128_to_255_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_128_to_191_vec.zmm, lookup_192_to_255_vec.zmm); - blended_0_to_255_vec.zmm = - _mm512_mask_blend_epi8(first_bit_mask, blended_0_to_127_vec.zmm, blended_128_to_255_vec.zmm); - _mm512_store_si512(target, blended_0_to_255_vec.zmm); //! Aligned store, our main weapon! - source += 64, target += 64, length -= 64; - } - - // Handling the tail. - if (tail_length) { - source_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, source); - lookup_0_to_63_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_0_to_63_vec.zmm); - lookup_64_to_127_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_64_to_127_vec.zmm); - lookup_128_to_191_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_128_to_191_vec.zmm); - lookup_192_to_255_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_192_to_255_vec.zmm); - first_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, first_bit_vec.zmm); - second_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, second_bit_vec.zmm); - blended_0_to_127_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_0_to_63_vec.zmm, lookup_64_to_127_vec.zmm); - blended_128_to_255_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_128_to_191_vec.zmm, lookup_192_to_255_vec.zmm); - blended_0_to_255_vec.zmm = - _mm512_mask_blend_epi8(first_bit_mask, blended_0_to_127_vec.zmm, blended_128_to_255_vec.zmm); - _mm512_mask_storeu_epi8(target, tail_mask, blended_0_to_255_vec.zmm); - source += tail_length, target += tail_length, length -= tail_length; - } -} - -SZ_PUBLIC sz_cptr_t sz_find_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { - - // Before initializing the AVX-512 vectors, we may want to run the sequential code for the first few bytes. - // In practice, that only hurts, even when we have matches every 5-ish bytes. - // - // if (length < SZ_SWAR_THRESHOLD) return sz_find_charset_serial(text, length, filter); - // sz_cptr_t early_result = sz_find_charset_serial(text, SZ_SWAR_THRESHOLD, filter); - // if (early_result) return early_result; - // text += SZ_SWAR_THRESHOLD; - // length -= SZ_SWAR_THRESHOLD; - // - // Let's unzip even and odd elements and replicate them into both lanes of the YMM register. - // That way when we invoke `_mm512_shuffle_epi8` we can use the same mask for both lanes. - sz_u512_vec_t filter_even_vec, filter_odd_vec; - __m256i filter_ymm = _mm256_lddqu_si256((__m256i const *)filter); - // There are a few way to initialize filters without having native strided loads. 
- // In the chronological order of experiments:
- // - serial code initializing 128 bytes of odd and even mask
- // - using several shuffles
- // - using `_mm512_permutexvar_epi8`
- // - using `_mm512_broadcast_i32x4(_mm256_castsi256_si128(_mm256_maskz_compress_epi8(0x55555555, filter_ymm)))`
- // and `_mm512_broadcast_i32x4(_mm256_castsi256_si128(_mm256_maskz_compress_epi8(0xaaaaaaaa, filter_ymm)))`
- filter_even_vec.zmm = _mm512_broadcast_i32x4(_mm256_castsi256_si128( // broadcast __m128i to __m512i
- _mm256_maskz_compress_epi8(0x55555555, filter_ymm)));
- filter_odd_vec.zmm = _mm512_broadcast_i32x4(_mm256_castsi256_si128( // broadcast __m128i to __m512i
- _mm256_maskz_compress_epi8(0xaaaaaaaa, filter_ymm)));
- // After the unzipping operation, we can validate the contents of the vectors like this:
- //
- // for (sz_size_t i = 0; i != 16; ++i) {
- // sz_assert(filter_even_vec.u8s[i] == filter->_u8s[i * 2]);
- // sz_assert(filter_odd_vec.u8s[i] == filter->_u8s[i * 2 + 1]);
- // sz_assert(filter_even_vec.u8s[i + 16] == filter->_u8s[i * 2]);
- // sz_assert(filter_odd_vec.u8s[i + 16] == filter->_u8s[i * 2 + 1]);
- // sz_assert(filter_even_vec.u8s[i + 32] == filter->_u8s[i * 2]);
- // sz_assert(filter_odd_vec.u8s[i + 32] == filter->_u8s[i * 2 + 1]);
- // sz_assert(filter_even_vec.u8s[i + 48] == filter->_u8s[i * 2]);
- // sz_assert(filter_odd_vec.u8s[i + 48] == filter->_u8s[i * 2 + 1]);
- // }
- //
- sz_u512_vec_t text_vec;
- sz_u512_vec_t lower_nibbles_vec, higher_nibbles_vec;
- sz_u512_vec_t bitset_even_vec, bitset_odd_vec;
- sz_u512_vec_t bitmask_vec, bitmask_lookup_vec;
- bitmask_lookup_vec.zmm = _mm512_set_epi8( //
- -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, //
- -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, //
- -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, //
- -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1);
-
- while (length) {
- // The following algorithm is a transposed equivalent of the "SIMDized check which bytes are in a set"
- // solutions by Wojciech Muła. We populate the bitmask differently and target newer CPUs, so
- // StringZilla uses a somewhat different approach.
- // http://0x80.pl/articles/simd-byte-lookup.html#alternative-implementation-new
- //
- // sz_u8_t input = *(sz_u8_t const *)text;
- // sz_u8_t lo_nibble = input & 0x0f;
- // sz_u8_t hi_nibble = input >> 4;
- // sz_u8_t bitset_even = filter_even_vec.u8s[hi_nibble];
- // sz_u8_t bitset_odd = filter_odd_vec.u8s[hi_nibble];
- // sz_u8_t bitmask = (1 << (lo_nibble & 0x7));
- // sz_u8_t bitset = lo_nibble < 8 ? bitset_even : bitset_odd;
- // if ((bitset & bitmask) != 0) return text;
- // else { length--, text++; }
- //
- // The nice part about this is that loading the strided data is very easy with Arm NEON,
- // while with x86 CPUs after AVX, shuffles within 256 bits shouldn't be an issue either.
- sz_size_t load_length = sz_min_of_two(length, 64); - __mmask64 load_mask = _sz_u64_mask_until(load_length); - text_vec.zmm = _mm512_maskz_loadu_epi8(load_mask, text); - lower_nibbles_vec.zmm = _mm512_and_si512(text_vec.zmm, _mm512_set1_epi8(0x0f)); - bitmask_vec.zmm = _mm512_shuffle_epi8(bitmask_lookup_vec.zmm, lower_nibbles_vec.zmm); - // - // At this point we can validate the `bitmask_vec` contents like this: - // - // for (sz_size_t i = 0; i != load_length; ++i) { - // sz_u8_t input = *(sz_u8_t const *)(text + i); - // sz_u8_t lo_nibble = input & 0x0f; - // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); - // sz_assert(bitmask_vec.u8s[i] == bitmask); - // } - // - // Shift right every byte by 4 bits. - // There is no `_mm512_srli_epi8` intrinsic, so we have to use `_mm512_srli_epi16` - // and combine it with a mask to clear the higher bits. - higher_nibbles_vec.zmm = _mm512_and_si512(_mm512_srli_epi16(text_vec.zmm, 4), _mm512_set1_epi8(0x0f)); - bitset_even_vec.zmm = _mm512_shuffle_epi8(filter_even_vec.zmm, higher_nibbles_vec.zmm); - bitset_odd_vec.zmm = _mm512_shuffle_epi8(filter_odd_vec.zmm, higher_nibbles_vec.zmm); - // - // At this point we can validate the `bitset_even_vec` and `bitset_odd_vec` contents like this: - // - // for (sz_size_t i = 0; i != load_length; ++i) { - // sz_u8_t input = *(sz_u8_t const *)(text + i); - // sz_u8_t const *bitset_ptr = &filter->_u8s[0]; - // sz_u8_t hi_nibble = input >> 4; - // sz_u8_t bitset_even = bitset_ptr[hi_nibble * 2]; - // sz_u8_t bitset_odd = bitset_ptr[hi_nibble * 2 + 1]; - // sz_assert(bitset_even_vec.u8s[i] == bitset_even); - // sz_assert(bitset_odd_vec.u8s[i] == bitset_odd); - // } - // - // TODO: Is this a good place for ternary logic? - __mmask64 take_first = _mm512_cmplt_epi8_mask(lower_nibbles_vec.zmm, _mm512_set1_epi8(8)); - bitset_even_vec.zmm = _mm512_mask_blend_epi8(take_first, bitset_odd_vec.zmm, bitset_even_vec.zmm); - __mmask64 matches_mask = _mm512_mask_test_epi8_mask(load_mask, bitset_even_vec.zmm, bitmask_vec.zmm); - if (matches_mask) { - int offset = sz_u64_ctz(matches_mask); - return text + offset; - } - else { text += load_length, length -= load_length; } - } - - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { - return sz_rfind_charset_serial(text, length, filter); -} - -SZ_PUBLIC sz_cptr_t sz_find_many_avx512( // - sz_cptr_t haystack, sz_size_t haystack_length, // - sz_cptr_t const *needles, sz_size_t const *needles_lengths, // - sz_size_t *needle_offset) { - - // When dealing with huge needles vocabularies, like in tokenization workloads, we need to construct an automaton. - // But in many cases, the vocabulary is small enough to use a simpler DFA-less approach, combining the ideas from - // the `sz_find_avx512` and `sz_find_charset_avx512` functions. - // - // Pick the offsets within needles where there is the least variance in the characters. - // Like for "the", "then", "there", "these", "those", "their", "they", "them", "that", "this", "thus", "than": - // - // 0: 't' - // 1: 'h' - // 2: 'e', 'a', 'i', 'o', 'u' - // 3: 'n', 'r', 's', 'i', 'y', 'm', 't' - // - // So depending on our "register budget", we can use a different number of pivot points: offset 0, 1, 2 make - // the most sense if we can only use 3 ZMM registers. - sz_unused(haystack && haystack_length && needles && needles_lengths && needle_offset); - return 0; -} - -/** - * Computes the Needleman Wunsch alignment score between two strings. 
- * The method uses 32-bit integers to accumulate the running score for every cell in the matrix.
- * Assuming the costs of substitutions can be arbitrary signed 8-bit integers, the method is expected to be used
- * on strings not exceeding 2^24 (~16.7 million) characters in length.
- *
- * Unlike the `_sz_edit_distance_skewed_diagonals_upto65k_avx512` method, this one uses signed integers to store
- * the accumulated score. Moreover, its primary bottleneck is the latency of gathering the substitution costs
- * from the substitution matrix. If we use the diagonal order, we will be comparing a slice of the first string with
- * a slice of the second. If we stick to the conventional horizontal order, we will be comparing one character against
- * a slice, which is much easier to optimize. In that case we are sampling costs not from arbitrary parts of
- * a 256 x 256 matrix, but from a single row!
- */
-SZ_INTERNAL sz_ssize_t _sz_alignment_score_wagner_fisher_upto17m_avx512( //
- sz_cptr_t shorter, sz_size_t shorter_length, //
- sz_cptr_t longer, sz_size_t longer_length, //
- sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc) {
-
- // If one of the strings is empty - the score is the length of the other one multiplied by the gap penalty.
- if (longer_length == 0) return (sz_ssize_t)shorter_length * gap;
- if (shorter_length == 0) return (sz_ssize_t)longer_length * gap;
-
- // Let's make sure that we use the amount of memory proportional to the
- // number of elements in the shorter string, not the larger.
- if (shorter_length > longer_length) {
- sz_pointer_swap((void **)&longer_length, (void **)&shorter_length);
- sz_pointer_swap((void **)&longer, (void **)&shorter);
- }
-
- // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome.
- sz_memory_allocator_t global_alloc;
- if (!alloc) {
- sz_memory_allocator_init_default(&global_alloc);
- alloc = &global_alloc;
- }
-
- sz_size_t const max_length = 256ull * 256ull * 256ull;
- sz_size_t const n = longer_length + 1;
- sz_assert(n < max_length && "The length must fit into 24-bit integer. Otherwise use serial variant.");
- sz_unused(longer_length && max_length);
-
- sz_size_t buffer_length = sizeof(sz_i32_t) * n * 2;
- sz_i32_t *distances = (sz_i32_t *)alloc->allocate(buffer_length, alloc->handle);
- sz_i32_t *previous_distances = distances;
- sz_i32_t *current_distances = previous_distances + n;
-
- // Initialize the first row of the Levenshtein matrix with `iota`.
- for (sz_size_t idx_longer = 0; idx_longer != n; ++idx_longer)
- previous_distances[idx_longer] = (sz_i32_t)idx_longer * gap;
-
- /// Contains up to 64 consecutive characters from the longer string.
- sz_u512_vec_t longer_vec;
- sz_u512_vec_t cost_deletion_vec, cost_substitution_vec, lookup_substitution_vec, current_vec;
- sz_u512_vec_t row_first_subs_vec, row_second_subs_vec, row_third_subs_vec, row_fourth_subs_vec;
- sz_u512_vec_t shuffled_first_subs_vec, shuffled_second_subs_vec, shuffled_third_subs_vec, shuffled_fourth_subs_vec;
-
- // Prepare constants and masks.
- sz_u512_vec_t is_third_or_fourth_vec, is_second_or_fourth_vec, gap_vec; - { - char is_third_or_fourth_check, is_second_or_fourth_check; - *(sz_u8_t *)&is_third_or_fourth_check = 0x80, *(sz_u8_t *)&is_second_or_fourth_check = 0x40; - is_third_or_fourth_vec.zmm = _mm512_set1_epi8(is_third_or_fourth_check); - is_second_or_fourth_vec.zmm = _mm512_set1_epi8(is_second_or_fourth_check); - gap_vec.zmm = _mm512_set1_epi32(gap); - } - - sz_u8_t const *shorter_unsigned = (sz_u8_t const *)shorter; - for (sz_size_t idx_shorter = 0; idx_shorter != shorter_length; ++idx_shorter) { - sz_i32_t last_in_row = current_distances[0] = (sz_i32_t)(idx_shorter + 1) * gap; - - // Load one row of the substitution matrix into four ZMM registers. - sz_error_cost_t const *row_subs = subs + shorter_unsigned[idx_shorter] * 256u; - row_first_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 0); - row_second_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 1); - row_third_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 2); - row_fourth_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 3); - - // In the serial version we have one forward pass, that computes the deletion, - // insertion, and substitution costs at once. - // for (sz_size_t idx_longer = 0; idx_longer < longer_length; ++idx_longer) { - // sz_ssize_t cost_deletion = previous_distances[idx_longer + 1] + gap; - // sz_ssize_t cost_insertion = current_distances[idx_longer] + gap; - // sz_ssize_t cost_substitution = previous_distances[idx_longer] + row_subs[longer_unsigned[idx_longer]]; - // current_distances[idx_longer + 1] = sz_min_of_three(cost_deletion, cost_insertion, cost_substitution); - // } - // - // Given the complexity of handling the data-dependency between consecutive insertion cost computations - // within a Levenshtein matrix, the simplest design would be to vectorize every kind of cost computation - // separately. - // 1. Compute substitution costs for up to 64 characters at once, upcasting from 8-bit integers to 32. - // 2. Compute the pairwise minimum with deletion costs. - // 3. Inclusive prefix minimum computation to combine with addition costs. - // Proceeding with substitutions: - for (sz_size_t idx_longer = 0; idx_longer < longer_length; idx_longer += 64) { - sz_size_t register_length = sz_min_of_two(longer_length - idx_longer, 64); - __mmask64 mask = _sz_u64_mask_until(register_length); - longer_vec.zmm = _mm512_maskz_loadu_epi8(mask, longer + idx_longer); - - // Blend the `row_(first|second|third|fourth)_subs_vec` into `current_vec`, picking the right source - // for every character in `longer_vec`. Before that, we need to permute the subsititution vectors. - // Only the bottom 6 bits of a byte are used in VPERB, so we don't even need to mask. - shuffled_first_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_first_subs_vec.zmm); - shuffled_second_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_second_subs_vec.zmm); - shuffled_third_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_third_subs_vec.zmm); - shuffled_fourth_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_fourth_subs_vec.zmm); - - // To blend we can invoke three `_mm512_cmplt_epu8_mask`, but we can also achieve the same using - // the AND logical operation, checking the top two bits of every byte. - // Continuing this thought, we can use the VPTESTMB instruction to output the mask after the AND. 
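- // In scalar form the blend below collapses to a single table lookup, `row_subs[c]`, where `c` is a
- // character of the longer string. The vectorized path has to split the 256-entry row into four
- // 64-entry quarters, so bits 6 and 7 of `c` select which of the four permuted results is valid.
- // A sketch of that selection logic only:
- //
- //    int quarter = ((c & 0x80) ? 2 : 0) + ((c & 0x40) ? 1 : 0);  // 0 to 3
- //    sz_error_cost_t cost = row_subs[quarter * 64 + (c & 0x3F)]; // same as `row_subs[c]`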
- __mmask64 is_third_or_fourth = _mm512_mask_test_epi8_mask(mask, longer_vec.zmm, is_third_or_fourth_vec.zmm); - __mmask64 is_second_or_fourth = - _mm512_mask_test_epi8_mask(mask, longer_vec.zmm, is_second_or_fourth_vec.zmm); - lookup_substitution_vec.zmm = _mm512_mask_blend_epi8( - is_third_or_fourth, - // Choose between the first and the second. - _mm512_mask_blend_epi8(is_second_or_fourth, shuffled_first_subs_vec.zmm, shuffled_second_subs_vec.zmm), - // Choose between the third and the fourth. - _mm512_mask_blend_epi8(is_second_or_fourth, shuffled_third_subs_vec.zmm, shuffled_fourth_subs_vec.zmm)); - - // First, sign-extend lower and upper 16 bytes to 16-bit integers. - __m512i current_0_31_vec = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(lookup_substitution_vec.zmm, 0)); - __m512i current_32_63_vec = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(lookup_substitution_vec.zmm, 1)); - - // Now extend those 16-bit integers to 32-bit. - // This isn't free, same as the subsequent store, so we only want to do that for the populated lanes. - // To minimize the number of loads and stores, we can combine our substitution costs with the previous - // distances, containing the deletion costs. - { - cost_substitution_vec.zmm = _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer); - cost_substitution_vec.zmm = _mm512_add_epi32( - cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_0_31_vec, 0))); - cost_deletion_vec.zmm = _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer); - cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm); - current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm); - - // Inclusive prefix minimum computation to combine with insertion costs. - // Simply disabling this operation results in 5x performance improvement, meaning - // that this operation is responsible for 80% of the total runtime. - // for (sz_size_t idx_longer = 0; idx_longer < longer_length; ++idx_longer) { - // current_distances[idx_longer + 1] = - // sz_max_of_two(current_distances[idx_longer] + gap, current_distances[idx_longer + 1]); - // } - // - // To perform the same operation in vectorized form, we need to perform a tree-like reduction, - // that will involve multiple steps. It's quite expensive and should be first tested in the - // "experimental" section. - // - // Another approach might be loop unrolling: - // current_vec.i32s[0] = last_in_row = sz_i32_max_of_two(current_vec.i32s[0], last_in_row + gap); - // current_vec.i32s[1] = last_in_row = sz_i32_max_of_two(current_vec.i32s[1], last_in_row + gap); - // current_vec.i32s[2] = last_in_row = sz_i32_max_of_two(current_vec.i32s[2], last_in_row + gap); - // ... yet this approach is also quite expensive. - for (int i = 0; i != 16; ++i) - current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap); - _mm512_mask_storeu_epi32(current_distances + idx_longer + 1, (__mmask16)mask, current_vec.zmm); - } - - // Export the values from 16 to 31. 
- if (register_length > 16) {
- mask = _kshiftri_mask64(mask, 16);
- cost_substitution_vec.zmm =
- _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer + 16);
- cost_substitution_vec.zmm = _mm512_add_epi32(
- cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_0_31_vec, 1)));
- cost_deletion_vec.zmm =
- _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer + 16);
- cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm);
- current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm);
-
- // Aggregate running insertion costs within the register.
- for (int i = 0; i != 16; ++i)
- current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap);
- _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 16, (__mmask16)mask, current_vec.zmm);
- }
-
- // Export the values from 32 to 47.
- if (register_length > 32) {
- mask = _kshiftri_mask64(mask, 16);
- cost_substitution_vec.zmm =
- _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer + 32);
- cost_substitution_vec.zmm = _mm512_add_epi32(
- cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_32_63_vec, 0)));
- cost_deletion_vec.zmm =
- _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer + 32);
- cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm);
- current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm);
-
- // Aggregate running insertion costs within the register.
- for (int i = 0; i != 16; ++i)
- current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap);
- _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 32, (__mmask16)mask, current_vec.zmm);
- }
-
- // Export the values from 48 to 63.
- if (register_length > 48) {
- mask = _kshiftri_mask64(mask, 16);
- cost_substitution_vec.zmm =
- _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer + 48);
- cost_substitution_vec.zmm = _mm512_add_epi32(
- cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_32_63_vec, 1)));
- cost_deletion_vec.zmm =
- _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer + 48);
- cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm);
- current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm);
-
- // Aggregate running insertion costs within the register.
- for (int i = 0; i != 16; ++i)
- current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap);
- _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 48, (__mmask16)mask, current_vec.zmm);
- }
- }
-
- // Swap previous_distances and current_distances pointers
- sz_pointer_swap((void **)&previous_distances, (void **)&current_distances);
- }
-
- // Cache scalar before `free` call.
- sz_ssize_t result = previous_distances[longer_length];
- alloc->free(distances, buffer_length, alloc->handle);
- return result;
-}
-
-SZ_INTERNAL sz_ssize_t sz_alignment_score_avx512( //
- sz_cptr_t shorter, sz_size_t shorter_length, //
- sz_cptr_t longer, sz_size_t longer_length, //
- sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc) {
-
- if (sz_max_of_two(shorter_length, longer_length) < (256ull * 256ull * 256ull))
- return _sz_alignment_score_wagner_fisher_upto17m_avx512(shorter, shorter_length, longer, longer_length, subs,
- gap, alloc);
- else
- return sz_alignment_score_serial(shorter, shorter_length, longer, longer_length, subs, gap, alloc);
-}
-
-enum sz_encoding_t {
- sz_encoding_unknown_k = 0,
- sz_encoding_ascii_k = 1,
- sz_encoding_utf8_k = 2,
- sz_encoding_utf16_k = 3,
- sz_encoding_utf32_k = 4,
- sz_jwt_k,
- sz_base64_k,
- // Low priority encodings:
- sz_encoding_utf8bom_k = 5,
- sz_encoding_utf16le_k = 6,
- sz_encoding_utf16be_k = 7,
- sz_encoding_utf32le_k = 8,
- sz_encoding_utf32be_k = 9,
-};
-
-// Character Set Detection is one of the most commonly performed operations in data processing, with
-// [Chardet](https://github.com/chardet/chardet), [Charset Normalizer](https://github.com/jawah/charset_normalizer),
-// and [cChardet](https://github.com/PyYoshi/cChardet) being the most popular options in the Python ecosystem.
-// All of them are notoriously slow.
-//
-// Moreover, as of October 2024, UTF-8 is the dominant character encoding on the web, used by 98.4% of websites.
-// Others have minimal usage, according to [W3Techs](https://w3techs.com/technologies/overview/character_encoding):
-// - ISO-8859-1: 1.2%
-// - Windows-1252: 0.3%
-// - Windows-1251: 0.2%
-// - EUC-JP: 0.1%
-// - Shift JIS: 0.1%
-// - EUC-KR: 0.1%
-// - GB2312: 0.1%
-// - Windows-1250: 0.1%
-// Within programming language implementations and database management systems, 16-bit and 32-bit fixed-width encodings
-// are also very popular and we need a way to efficiently differentiate between the most common UTF flavors, ASCII, and
-// the rest.
-//
-// One good solution is the [simdutf](https://github.com/simdutf/simdutf) library, but it depends on the C++ runtime
-// and focuses more on incremental validation & transcoding, rather than detection.
-//
-// So we need a very fast and efficient way of determining the encoding ourselves.
-SZ_PUBLIC sz_bool_t sz_detect_encoding(sz_cptr_t text, sz_size_t length) {
- // https://github.com/simdutf/simdutf/blob/master/src/icelake/icelake_utf8_validation.inl.cpp
- // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_from_utf8.inl.cpp#L81
- // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_utf8_common.inl.cpp#L661
- // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_utf8_common.inl.cpp#L788
-
- // We can implement this operation in a simpler & different way, assuming that most of the time continuous chunks
- // of memory have identical encoding. With Russian and many European languages, we generally deal with 2-byte
- // codepoints with occasional 1-byte punctuation marks. In the case of Chinese, Japanese, and Korean, we deal with
- // 3-byte codepoints. In the case of emojis, we deal with 4-byte codepoints.
- // We can also use the idea that misaligned reads are quite cheap on modern CPUs.
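- // The cheapest of those checks is the pure-ASCII one: a buffer is ASCII if and only if no byte has
- // the top bit set. A scalar sketch of just that sub-problem (the UTF-16/UTF-32 heuristics would need
- // more state than this):
- //
- //    sz_bool_t is_ascii = sz_true_k;
- //    for (sz_size_t i = 0; i != length; ++i)
- //        if (((sz_u8_t)text[i]) & 0x80) { is_ascii = sz_false_k; break; }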
- int can_be_ascii = 1, can_be_utf8 = 1, can_be_utf16 = 1, can_be_utf32 = 1; - sz_unused(can_be_ascii + can_be_utf8 + can_be_utf16 + can_be_utf32); - sz_unused(text && length); - return sz_false_k; -} - -#pragma clang attribute pop -#pragma GCC pop_options -#endif - -#pragma endregion - -/* @brief Implementation of the string search algorithms using the Arm NEON instruction set, available on 64-bit - * Arm processors. Implements: {substring search, character search, character set search} x {forward, reverse}. - */ -#pragma region ARM NEON - -#if SZ_USE_ARM_NEON -#pragma GCC push_options -#pragma GCC target("arch=armv8.2-a+simd") -#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd"))), apply_to = function) - -/** - * @brief Helper structure to simplify work with 64-bit words. - */ -typedef union sz_u128_vec_t { - uint8x16_t u8x16; - uint16x8_t u16x8; - uint32x4_t u32x4; - uint64x2_t u64x2; - sz_u64_t u64s[2]; - sz_u32_t u32s[4]; - sz_u16_t u16s[8]; - sz_u8_t u8s[16]; -} sz_u128_vec_t; - -SZ_INTERNAL sz_u64_t _sz_vreinterpretq_u8_u4(uint8x16_t vec) { - // Use `vshrn` to produce a bitmask, similar to `movemask` in SSE. - // https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon - return vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(vec), 4)), 0) & 0x8888888888888888ull; -} - -SZ_PUBLIC sz_ordering_t sz_order_neon(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { - //! Before optimizing this, read the "Operations Not Worth Optimizing" in Contributions Guide: - //! https://github.com/ashvardanian/StringZilla/blob/main/CONTRIBUTING.md#general-performance-observations - return sz_order_serial(a, a_length, b, b_length); -} - -SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - sz_u128_vec_t a_vec, b_vec; - for (; length >= 16; a += 16, b += 16, length -= 16) { - a_vec.u8x16 = vld1q_u8((sz_u8_t const *)a); - b_vec.u8x16 = vld1q_u8((sz_u8_t const *)b); - uint8x16_t cmp = vceqq_u8(a_vec.u8x16, b_vec.u8x16); - if (vminvq_u8(cmp) != 255) { return sz_false_k; } // Check if all bytes match - } - - // Handle remaining bytes - if (length) return sz_equal_serial(a, b, length); - return sz_true_k; -} - -SZ_PUBLIC sz_u64_t sz_checksum_neon(sz_cptr_t text, sz_size_t length) { - uint64x2_t sum_vec = vdupq_n_u64(0); - - // Process 16 bytes (128 bits) at a time - for (; length >= 16; text += 16, length -= 16) { - uint8x16_t vec = vld1q_u8((sz_u8_t const *)text); // Load 16 bytes - uint16x8_t pairwise_sum1 = vpaddlq_u8(vec); // Pairwise add lower and upper 8 bits - uint32x4_t pairwise_sum2 = vpaddlq_u16(pairwise_sum1); // Pairwise add 16-bit results - uint64x2_t pairwise_sum3 = vpaddlq_u32(pairwise_sum2); // Pairwise add 32-bit results - sum_vec = vaddq_u64(sum_vec, pairwise_sum3); // Accumulate the sum - } - - // Final reduction of `sum_vec` to a single scalar - sz_u64_t sum = vgetq_lane_u64(sum_vec, 0) + vgetq_lane_u64(sum_vec, 1); - if (length) sum += sz_checksum_serial(text, length); - return sum; -} - -SZ_PUBLIC void sz_copy_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - // In most cases the `source` and the `target` are not aligned, but we should - // at least make sure that writes don't touch many cache lines. - // NEON has an instruction to load and write 64 bytes at once. - // - // sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. 
- // sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. - // for (; head_length; target += 1, source += 1, head_length -= 1) *target = *source; - // length -= head_length; - // for (; length >= 64; target += 64, source += 64, length -= 64) - // vst4q_u8((sz_u8_t *)target, vld1q_u8_x4((sz_u8_t const *)source)); - // for (; tail_length; target += 1, source += 1, tail_length -= 1) *target = *source; - // - // Sadly, those instructions end up being 20% slower than the code processing 16 bytes at a time: - for (; length >= 16; target += 16, source += 16, length -= 16) - vst1q_u8((sz_u8_t *)target, vld1q_u8((sz_u8_t const *)source)); - if (length) sz_copy_serial(target, source, length); -} - -SZ_PUBLIC void sz_move_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - // When moving small buffers, using a small buffer on stack as a temporary storage is faster. - - if (target < source || target >= source + length) { - // Non-overlapping, proceed forward - sz_copy_neon(target, source, length); - } - else { - // Overlapping, proceed backward - target += length; - source += length; - - sz_u128_vec_t src_vec; - while (length >= 16) { - target -= 16, source -= 16, length -= 16; - src_vec.u8x16 = vld1q_u8((sz_u8_t const *)source); - vst1q_u8((sz_u8_t *)target, src_vec.u8x16); - } - while (length) { - target -= 1, source -= 1, length -= 1; - *target = *source; - } - } -} - -SZ_PUBLIC void sz_fill_neon(sz_ptr_t target, sz_size_t length, sz_u8_t value) { - uint8x16_t fill_vec = vdupq_n_u8(value); // Broadcast the value across the register - - while (length >= 16) { - vst1q_u8((sz_u8_t *)target, fill_vec); - target += 16; - length -= 16; - } - - // Handle remaining bytes - if (length) sz_fill_serial(target, length, value); -} - -SZ_PUBLIC void sz_look_up_transform_neon(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { - - // If the input is tiny (especially smaller than the look-up table itself), we may end up paying - // more for organizing the SIMD registers and changing the CPU state, than for the actual computation. - if (length <= 128) { - sz_look_up_transform_serial(source, length, lut, target); - return; - } - - sz_size_t head_length = (16 - ((sz_size_t)target % 16)) % 16; // 15 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 16; // 15 or less. - - // We need to pull the lookup table into 16x NEON registers. We have a total of 32 such registers. - // According to the Neoverse V2 manual, the 4-table lookup has a latency of 6 cycles, and 4x throughput. - uint8x16x4_t lut_0_to_63_vec, lut_64_to_127_vec, lut_128_to_191_vec, lut_192_to_255_vec; - lut_0_to_63_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 0)); - lut_64_to_127_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 64)); - lut_128_to_191_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 128)); - lut_192_to_255_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 192)); - - sz_u128_vec_t source_vec; - // If the top bit is set in each word of `source_vec`, than we use `lookup_128_to_191_vec` or - // `lookup_192_to_255_vec`. If the second bit is set, we use `lookup_64_to_127_vec` or `lookup_192_to_255_vec`. 
- sz_u128_vec_t lookup_0_to_63_vec, lookup_64_to_127_vec, lookup_128_to_191_vec, lookup_192_to_255_vec; - sz_u128_vec_t blended_0_to_255_vec; - - // Process the head with serial code - for (; head_length; target += 1, source += 1, head_length -= 1) *target = lut[*(sz_u8_t const *)source]; - - // Table lookups on Arm are much simpler to use than on x86, as we can use the `vqtbl4q_u8` instruction - // to perform a 4-table lookup in a single instruction. The XORs are used to adjust the lookup position - // within each 64-byte range of the table. - // Details on the 4-table lookup: https://lemire.me/blog/2019/07/23/arbitrary-byte-to-byte-maps-using-arm-neon/ - length -= head_length; - length -= tail_length; - for (; length >= 16; source += 16, target += 16, length -= 16) { - source_vec.u8x16 = vld1q_u8((sz_u8_t const *)source); - lookup_0_to_63_vec.u8x16 = vqtbl4q_u8(lut_0_to_63_vec, source_vec.u8x16); - lookup_64_to_127_vec.u8x16 = vqtbl4q_u8(lut_64_to_127_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0x40))); - lookup_128_to_191_vec.u8x16 = vqtbl4q_u8(lut_128_to_191_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0x80))); - lookup_192_to_255_vec.u8x16 = vqtbl4q_u8(lut_192_to_255_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0xc0))); - blended_0_to_255_vec.u8x16 = vorrq_u8(vorrq_u8(lookup_0_to_63_vec.u8x16, lookup_64_to_127_vec.u8x16), - vorrq_u8(lookup_128_to_191_vec.u8x16, lookup_192_to_255_vec.u8x16)); - vst1q_u8((sz_u8_t *)target, blended_0_to_255_vec.u8x16); - } - - // Process the tail with serial code - for (; tail_length; target += 1, source += 1, tail_length -= 1) *target = lut[*(sz_u8_t const *)source]; -} - -SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - sz_u64_t matches; - sz_u128_vec_t h_vec, n_vec, matches_vec; - n_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)n); - - while (h_length >= 16) { - h_vec.u8x16 = vld1q_u8((sz_u8_t const *)h); - matches_vec.u8x16 = vceqq_u8(h_vec.u8x16, n_vec.u8x16); - // In Arm NEON we don't have a `movemask` to combine it with `ctz` and get the offset of the match. - // But assuming the `vmaxvq` is cheap, we can use it to find the first match, by blending (bitwise selecting) - // the vector with a relative offsets array. - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - if (matches) return h + sz_u64_ctz(matches) / 4; - - h += 16, h_length -= 16; - } - - return sz_find_byte_serial(h, h_length, n); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_byte_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - sz_u64_t matches; - sz_u128_vec_t h_vec, n_vec, matches_vec; - n_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)n); - - while (h_length >= 16) { - h_vec.u8x16 = vld1q_u8((sz_u8_t const *)h + h_length - 16); - matches_vec.u8x16 = vceqq_u8(h_vec.u8x16, n_vec.u8x16); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - if (matches) return h + h_length - 1 - sz_u64_clz(matches) / 4; - h_length -= 16; - } - - return sz_rfind_byte_serial(h, h_length, n); -} - -SZ_PUBLIC sz_u64_t _sz_find_charset_neon_register(sz_u128_vec_t h_vec, uint8x16_t set_top_vec_u8x16, - uint8x16_t set_bottom_vec_u8x16) { - - // Once we've read the characters in the haystack, we want to - // compare them against our bitset. The serial version of that code - // would look like: `(set_->_u8s[c >> 3] & (1u << (c & 7u))) != 0`. 
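- // Expanded into scalar form, the mask this helper returns has the following meaning (a sketch of the
- // semantics only, reusing the `set_->_u8s` notation from the comment above):
- //
- //    sz_u64_t result = 0;
- //    for (sz_size_t i = 0; i != 16; ++i) {
- //        sz_u8_t c = h_vec.u8s[i];
- //        int is_in_set = (set_->_u8s[c >> 3] & (1u << (c & 7u))) != 0;
- //        result |= (sz_u64_t)is_in_set << (i * 4 + 3); // 4 bits per character, bit 3 of each nibble
- //    }
- //    return result;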
- uint8x16_t byte_index_vec = vshrq_n_u8(h_vec.u8x16, 3); - uint8x16_t byte_mask_vec = vshlq_u8(vdupq_n_u8(1), vreinterpretq_s8_u8(vandq_u8(h_vec.u8x16, vdupq_n_u8(7)))); - uint8x16_t matches_top_vec = vqtbl1q_u8(set_top_vec_u8x16, byte_index_vec); - // The table lookup instruction in NEON replies to out-of-bound requests with zeros. - // The values in `byte_index_vec` all fall in [0; 32). So for values under 16, substracting 16 will underflow - // and map into interval [240, 256). Meaning that those will be populated with zeros and we can safely - // merge `matches_top_vec` and `matches_bottom_vec` with a bitwise OR. - uint8x16_t matches_bottom_vec = vqtbl1q_u8(set_bottom_vec_u8x16, vsubq_u8(byte_index_vec, vdupq_n_u8(16))); - uint8x16_t matches_vec = vorrq_u8(matches_top_vec, matches_bottom_vec); - // Istead of pure `vandq_u8`, we can immediately broadcast a match presence across each 8-bit word. - matches_vec = vtstq_u8(matches_vec, byte_mask_vec); - return _sz_vreinterpretq_u8_u4(matches_vec); -} - -SZ_PUBLIC sz_cptr_t sz_find_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_find_byte_neon(h, h_length, n); - - // Scan through the string. - // Assuming how tiny the Arm NEON registers are, we should avoid internal branches at all costs. - // That's why, for smaller needles, we use different loops. - if (n_length == 2) { - // Broadcast needle characters into SIMD registers. - sz_u64_t matches; - sz_u128_vec_t h_first_vec, h_last_vec, n_first_vec, n_last_vec, matches_vec; - // Dealing with 16-bit values, we can load 2 registers at a time and compare 31 possible offsets - // in a single loop iteration. - n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[0]); - n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[1]); - for (; h_length >= 17; h += 16, h_length -= 16) { - h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 0)); - h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 1)); - matches_vec.u8x16 = - vandq_u8(vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - if (matches) return h + sz_u64_ctz(matches) / 4; - } - } - else if (n_length == 3) { - // Broadcast needle characters into SIMD registers. - sz_u64_t matches; - sz_u128_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec, matches_vec; - // Comparing 24-bit values is a bumer. Being lazy, I went with the same approach - // as when searching for string over 4 characters long. I only avoid the last comparison. - n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[0]); - n_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[1]); - n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[2]); - for (; h_length >= 18; h += 16, h_length -= 16) { - h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 0)); - h_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 1)); - h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 2)); - matches_vec.u8x16 = vandq_u8( // - vandq_u8( // - vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), // - vceqq_u8(h_mid_vec.u8x16, n_mid_vec.u8x16)), - vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - if (matches) return h + sz_u64_ctz(matches) / 4; - } - } - else { - // Pick the parts of the needle that are worth comparing. 
- sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - // Broadcast those characters into SIMD registers. - sz_u64_t matches; - sz_u128_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec, matches_vec; - n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_first]); - n_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_mid]); - n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_last]); - // Walk through the string. - for (; h_length >= n_length + 16; h += 16, h_length -= 16) { - h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_first)); - h_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_mid)); - h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_last)); - matches_vec.u8x16 = vandq_u8( // - vandq_u8( // - vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), // - vceqq_u8(h_mid_vec.u8x16, n_mid_vec.u8x16)), - vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - while (matches) { - int potential_offset = sz_u64_ctz(matches) / 4; - if (sz_equal(h + potential_offset, n, n_length)) return h + potential_offset; - matches &= matches - 1; - } - } - } - - return sz_find_serial(h, h_length, n, n_length); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_rfind_byte_neon(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Will contain 4 bits per character. 
- sz_u64_t matches; - sz_u128_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec, matches_vec; - n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_first]); - n_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_mid]); - n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_last]); - - sz_cptr_t h_reversed; - for (; h_length >= n_length + 16; h_length -= 16) { - h_reversed = h + h_length - n_length - 16 + 1; - h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h_reversed + offset_first)); - h_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h_reversed + offset_mid)); - h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h_reversed + offset_last)); - matches_vec.u8x16 = vandq_u8( // - vandq_u8( // - vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), // - vceqq_u8(h_mid_vec.u8x16, n_mid_vec.u8x16)), - vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - while (matches) { - int potential_offset = sz_u64_clz(matches) / 4; - if (sz_equal(h + h_length - n_length - potential_offset, n, n_length)) - return h + h_length - n_length - potential_offset; - sz_assert((matches & (1ull << (63 - potential_offset * 4))) != 0 && - "The bit must be set before we squash it"); - matches &= ~(1ull << (63 - potential_offset * 4)); - } - } - - return sz_rfind_serial(h, h_length, n, n_length); -} - -SZ_PUBLIC sz_cptr_t sz_find_charset_neon(sz_cptr_t h, sz_size_t h_length, sz_charset_t const *set) { - sz_u64_t matches; - sz_u128_vec_t h_vec; - uint8x16_t set_top_vec_u8x16 = vld1q_u8(&set->_u8s[0]); - uint8x16_t set_bottom_vec_u8x16 = vld1q_u8(&set->_u8s[16]); - - for (; h_length >= 16; h += 16, h_length -= 16) { - h_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h)); - matches = _sz_find_charset_neon_register(h_vec, set_top_vec_u8x16, set_bottom_vec_u8x16); - if (matches) return h + sz_u64_ctz(matches) / 4; - } - - return sz_find_charset_serial(h, h_length, set); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t h, sz_size_t h_length, sz_charset_t const *set) { - sz_u64_t matches; - sz_u128_vec_t h_vec; - uint8x16_t set_top_vec_u8x16 = vld1q_u8(&set->_u8s[0]); - uint8x16_t set_bottom_vec_u8x16 = vld1q_u8(&set->_u8s[16]); - - // Check `sz_find_charset_neon` for explanations. - for (; h_length >= 16; h_length -= 16) { - h_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h) + h_length - 16); - matches = _sz_find_charset_neon_register(h_vec, set_top_vec_u8x16, set_bottom_vec_u8x16); - if (matches) return h + h_length - 1 - sz_u64_clz(matches) / 4; - } - - return sz_rfind_charset_serial(h, h_length, set); -} - -#pragma clang attribute pop -#pragma GCC pop_options -#endif // Arm Neon - -#pragma endregion - -/* @brief Implementation of the string search algorithms using the Arm SVE variable-length registers, available - * in Arm v9 processors. - * - * Implements: - * - memory: {copy, move, fill} - * - comparisons: {equal, order} - * - search: {substring, character, character set} x {forward, reverse}. 
- */ -#pragma region ARM SVE - -#if SZ_USE_ARM_SVE -#pragma GCC push_options -#pragma GCC target("arch=armv8.2-a+sve") -#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve"))), apply_to = function) - -SZ_PUBLIC void sz_fill_sve(sz_ptr_t target, sz_size_t length, sz_u8_t value) { - svuint8_t value_vec = svdup_u8(value); - sz_size_t vec_len = svcntb(); // Vector length in bytes (scalable) - - if (length <= vec_len) { - // Small buffer case: use mask to handle small writes - svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)length); - svst1_u8(mask, (unsigned char *)target, value_vec); - } - else { - // Calculate head, body, and tail sizes - sz_size_t head_length = vec_len - ((sz_size_t)target % vec_len); - sz_size_t tail_length = (sz_size_t)(target + length) % vec_len; - sz_size_t body_length = length - head_length - tail_length; - - // Handle unaligned head - svbool_t head_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)head_length); - svst1_u8(head_mask, (unsigned char *)target, value_vec); - target += head_length; - - // Aligned body loop - for (; body_length >= vec_len; target += vec_len, body_length -= vec_len) { - svst1_u8(svptrue_b8(), (unsigned char *)target, value_vec); - } - - // Handle unaligned tail - svbool_t tail_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)tail_length); - svst1_u8(tail_mask, (unsigned char *)target, value_vec); - } -} + * x & ~((x < y) - 1) + y & ((x < y) - 1) // 6 unique operations + */ +#define sz_min_of_two(x, y) (x < y ? x : y) +#define sz_max_of_two(x, y) (x < y ? y : x) +#define sz_min_of_three(x, y, z) sz_min_of_two(x, sz_min_of_two(y, z)) +#define sz_max_of_three(x, y, z) sz_max_of_two(x, sz_max_of_two(y, z)) -SZ_PUBLIC void sz_copy_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - sz_size_t vec_len = svcntb(); // Vector length in bytes +/** + * One option to avoid branching is to use conditional moves and lookup the comparison result in a table: + * sz_ordering_t ordering_lookup[2] = {sz_greater_k, sz_less_k}; + * for (; a != min_end; ++a, ++b) + * if (*a != *b) return ordering_lookup[*a < *b]; + * That, however, introduces a data-dependency. + * A cleaner option is to perform two comparisons and a subtraction. + * One instruction more, but no data-dependency. + */ +#define _sz_order_scalars(a, b) ((sz_ordering_t)((a > b) - (a < b))) - // Arm Neoverse V2 cores in Graviton 4, for example, come with 256 KB of L1 data cache per core, - // and 8 MB of L2 cache per core. Moreover, the L1 cache is fully associative. - // With two strings, we may consider the overal workload huge, if each exceeds 1 MB in length. - // - // int is_huge = length >= 4ull * 1024ull * 1024ull; - // - // When the buffer is small, there isn't much to innovate. - if (length <= vec_len) { - // Small buffer case: use mask to handle small writes - svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)length); - svuint8_t data = svld1_u8(mask, (unsigned char *)source); - svst1_u8(mask, (unsigned char *)target, data); - } - // When dealing with larger buffers, similar to AVX-512, we want minimize unaligned operations - // and handle the head, body, and tail separately. We can also traverse the buffer in both directions - // as Arm generally supports more simultaneous stores than x86 CPUs. - // - // For gigantic datasets, similar to AVX-512, non-temporal "loads" and "stores" can be used. 
-    // Sadly, if the register size (16 byte or larger) is smaller than a cache-line (64 bytes)
-    // we will pay a huge penalty on loads, fetching the same content many times.
-    // It may be better to allow caching (and subsequent eviction), in favor of using four-element
-    // tuples, which will be guaranteed to be a multiple of a cache line.
-    //
-    // Another approach is to use the `LD4B` instructions, which will populate four registers at once.
-    // This, however, further decreases the performance from LibC-like 29 GB/s to 20 GB/s.
-    else {
-        // Calculating head, body, and tail sizes depends on the `vec_len`,
-        // but it's a runtime constant, and the modulo operation is expensive!
-        // Instead we use the fact that it's always a multiple of 128 bits or 16 bytes.
-        sz_size_t head_length = 16 - ((sz_size_t)target % 16);
-        sz_size_t tail_length = (sz_size_t)(target + length) % 16;
-        sz_size_t body_length = length - head_length - tail_length;

+/** @brief Branchless minimum function for two signed 32-bit integers. */
+SZ_INTERNAL sz_i32_t sz_i32_min_of_two(sz_i32_t x, sz_i32_t y) { return y + ((x - y) & (x - y) >> 31); }

-        // Handle unaligned parts
-        svbool_t head_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)head_length);
-        svuint8_t head_data = svld1_u8(head_mask, (unsigned char *)source);
-        svst1_u8(head_mask, (unsigned char *)target, head_data);
-        svbool_t tail_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)tail_length);
-        svuint8_t tail_data = svld1_u8(tail_mask, (unsigned char *)source + head_length + body_length);
-        svst1_u8(tail_mask, (unsigned char *)target + head_length + body_length, tail_data);
-        target += head_length;
-        source += head_length;

+/** @brief Branchless maximum function for two signed 32-bit integers. */
+SZ_INTERNAL sz_i32_t sz_i32_max_of_two(sz_i32_t x, sz_i32_t y) { return x - ((x - y) & (x - y) >> 31); }

-        // Aligned body loop, walking in two directions
-        for (; body_length >= vec_len * 2; target += vec_len, source += vec_len, body_length -= vec_len * 2) {
-            svuint8_t forward_data = svld1_u8(svptrue_b8(), (unsigned char *)source);
-            svuint8_t backward_data = svld1_u8(svptrue_b8(), (unsigned char *)source + body_length - vec_len);
-            svst1_u8(svptrue_b8(), (unsigned char *)target, forward_data);
-            svst1_u8(svptrue_b8(), (unsigned char *)target + body_length - vec_len, backward_data);
-        }
-        // Up to (vec_len * 2 - 1) bytes of data may be left in the body,
-        // so we can unroll the last two optional loop iterations.
-        if (body_length > vec_len) {
-            svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)body_length);
-            svuint8_t data = svld1_u8(mask, (unsigned char *)source);
-            svst1_u8(mask, (unsigned char *)target, data);
-            body_length -= vec_len;
-            source += body_length;
-            target += body_length;
-        }
-        if (body_length) {
-            svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)body_length);
-            svuint8_t data = svld1_u8(mask, (unsigned char *)source);
-            svst1_u8(mask, (unsigned char *)target, data);
-        }
-    }

+/**
+ * @brief Byte-level equality comparison between two 64-bit integers.
+ * @return 64-bit integer, where every top bit in each byte signifies a match.
+ */
+SZ_INTERNAL sz_u64_vec_t _sz_u64_each_byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) {
+    sz_u64_vec_t vec;
+    vec.u64 = ~(a.u64 ^ b.u64);
+    // The match is valid if every bit within each byte is set.
+    // For that take the bottom 7 bits of each byte, add one to them,
+    // and if this sets the top bit to one, then all the 7 bits are ones as well.
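+    // As a worked example for a single byte: if the bytes match, the negated XOR is 0xFF,
+    // so (0xFF & 0x7F) + 0x01 == 0x80, and ANDing with the kept top bit (0xFF & 0x80) leaves 0x80 set - a match.
+    // If only the lowest bit differs, the byte is 0xFE: (0x7E + 0x01) == 0x7F never reaches the top bit - no match.
+    // The masked addends never exceed 0x7F, so the +1 can not carry into the neighboring byte.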
+ vec.u64 = ((vec.u64 & 0x7F7F7F7F7F7F7F7Full) + 0x0101010101010101ull) & ((vec.u64 & 0x8080808080808080ull)); + return vec; } -#pragma clang attribute pop -#pragma GCC pop_options -#endif // Arm SVE - -#pragma endregion - -/* - * @brief Pick the right implementation for the string search algorithms. +/** + * @brief Clamps signed offsets in a string to a valid range. Used for Pythonic-style slicing. */ -#pragma region Compile Time Dispatching - -SZ_PUBLIC sz_u64_t sz_hash(sz_cptr_t ins, sz_size_t length) { return sz_hash_serial(ins, length); } -SZ_PUBLIC void sz_tolower(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) { sz_tolower_serial(ins, length, outs); } -SZ_PUBLIC void sz_toupper(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) { sz_toupper_serial(ins, length, outs); } -SZ_PUBLIC void sz_toascii(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) { sz_toascii_serial(ins, length, outs); } -SZ_PUBLIC sz_bool_t sz_isascii(sz_cptr_t ins, sz_size_t length) { return sz_isascii_serial(ins, length); } - -SZ_PUBLIC void sz_hashes_fingerprint(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_ptr_t fingerprint, - sz_size_t fingerprint_bytes) { +SZ_INTERNAL void sz_ssize_clamp_interval(sz_size_t length, sz_ssize_t start, sz_ssize_t end, + sz_size_t *normalized_offset, sz_size_t *normalized_length) { + // TODO: Remove branches. + // Normalize negative indices + if (start < 0) start += length; + if (end < 0) end += length; - sz_bool_t fingerprint_length_is_power_of_two = (sz_bool_t)((fingerprint_bytes & (fingerprint_bytes - 1)) == 0); - sz_string_view_t fingerprint_buffer = {fingerprint, fingerprint_bytes}; + // Clamp indices to a valid range + if (start < 0) start = 0; + if (end < 0) end = 0; + if (start > (sz_ssize_t)length) start = length; + if (end > (sz_ssize_t)length) end = length; - // There are several issues related to the fingerprinting algorithm. - // First, the memory traversal order is important. - // https://blog.stuffedcow.net/2015/08/pagewalk-coherence/ + // Ensure start <= end + if (start > end) start = end; - // In most cases the fingerprint length will be a power of two. - if (fingerprint_length_is_power_of_two == sz_false_k) - sz_hashes(start, length, window_length, 1, _sz_hashes_fingerprint_non_pow2_callback, &fingerprint_buffer); - else - sz_hashes(start, length, window_length, 1, _sz_hashes_fingerprint_pow2_callback, &fingerprint_buffer); + *normalized_offset = start; + *normalized_length = end - start; } -#if !SZ_DYNAMIC_DISPATCH - -SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length) { -#if SZ_USE_X86_AVX512 - return sz_checksum_avx512(text, length); -#elif SZ_USE_X86_AVX2 - return sz_checksum_avx2(text, length); -#elif SZ_USE_ARM_NEON - return sz_checksum_neon(text, length); -#else - return sz_checksum_serial(text, length); -#endif +/** + * @brief Compute the logarithm base 2 of a positive integer, rounding down. + */ +SZ_INTERNAL sz_size_t sz_size_log2i_nonzero(sz_size_t x) { + sz_assert(x > 0 && "Non-positive numbers have no defined logarithm"); + sz_size_t leading_zeros = sz_u64_clz(x); + return 63 - leading_zeros; } -SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { -#if SZ_USE_X86_AVX512 - return sz_equal_avx512(a, b, length); -#elif SZ_USE_X86_AVX2 - return sz_equal_avx2(a, b, length); -#elif SZ_USE_ARM_NEON - return sz_equal_neon(a, b, length); -#else - return sz_equal_serial(a, b, length); +/** + * @brief Compute the smallest power of two greater than or equal to ::x. 
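+ * For example, 5 rounds up to 8, 8 stays 8, and 9 rounds up to 16.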
+ */ +SZ_INTERNAL sz_size_t sz_size_bit_ceil(sz_size_t x) { + // Unlike the commonly used trick with `clz` intrinsics, is valid across the whole range of `x`. + // https://stackoverflow.com/a/10143264 + x--; + x |= x >> 1; + x |= x >> 2; + x |= x >> 4; + x |= x >> 8; + x |= x >> 16; +#if _SZ_IS_64_BIT + x |= x >> 32; #endif + x++; + return x; } -SZ_DYNAMIC sz_ordering_t sz_order(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { -#if SZ_USE_X86_AVX512 - return sz_order_avx512(a, a_length, b, b_length); -#elif SZ_USE_X86_AVX2 - return sz_order_avx2(a, a_length, b, b_length); -#elif SZ_USE_ARM_NEON - return sz_order_neon(a, a_length, b, b_length); -#else - return sz_order_serial(a, a_length, b, b_length); -#endif +/** + * @brief Transposes an 8x8 bit matrix packed in a `sz_u64_t`. + * + * There is a well known SWAR sequence for that known to chess programmers, + * willing to flip a bit-matrix of pieces along the main A1-H8 diagonal. + * https://www.chessprogramming.org/Flipping_Mirroring_and_Rotating + * https://lukas-prokop.at/articles/2021-07-23-transpose + */ +SZ_INTERNAL sz_u64_t sz_u64_transpose(sz_u64_t x) { + sz_u64_t t; + t = x ^ (x << 36); + x ^= 0xf0f0f0f00f0f0f0full & (t ^ (x >> 36)); + t = 0xcccc0000cccc0000ull & (x ^ (x << 18)); + x ^= t ^ (t >> 18); + t = 0xaa00aa00aa00aa00ull & (x ^ (x << 9)); + x ^= t ^ (t >> 9); + return x; } -SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { -#if SZ_USE_X86_AVX512 - sz_copy_avx512(target, source, length); -#elif SZ_USE_X86_AVX2 - sz_copy_avx2(target, source, length); -#elif SZ_USE_ARM_NEON - sz_copy_neon(target, source, length); -#else - sz_copy_serial(target, source, length); -#endif +/** + * @brief Helper, that swaps two 64-bit integers representing the order of elements in the sequence. + */ +SZ_INTERNAL void sz_u64_swap(sz_u64_t *a, sz_u64_t *b) { + sz_u64_t t = *a; + *a = *b; + *b = t; } -SZ_DYNAMIC void sz_move(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { -#if SZ_USE_X86_AVX512 - sz_move_avx512(target, source, length); -#elif SZ_USE_X86_AVX2 - sz_move_avx2(target, source, length); -#elif SZ_USE_ARM_NEON - sz_move_neon(target, source, length); -#else - sz_move_serial(target, source, length); -#endif +/** + * @brief Helper, that swaps two 64-bit integers representing the order of elements in the sequence. + */ +SZ_INTERNAL void sz_pointer_swap(void **a, void **b) { + void *t = *a; + *a = *b; + *b = t; } -SZ_DYNAMIC void sz_fill(sz_ptr_t target, sz_size_t length, sz_u8_t value) { -#if SZ_USE_X86_AVX512 - sz_fill_avx512(target, length, value); -#elif SZ_USE_X86_AVX2 - sz_fill_avx2(target, length, value); -#elif SZ_USE_ARM_NEON - sz_fill_neon(target, length, value); +/** + * @brief Load a 16-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms. + */ +SZ_INTERNAL sz_u16_vec_t sz_u16_load(sz_cptr_t ptr) { +#if !SZ_USE_MISALIGNED_LOADS + sz_u16_vec_t result; + result.u8s[0] = ptr[0]; + result.u8s[1] = ptr[1]; + return result; +#elif defined(_MSC_VER) && !defined(__clang__) +#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. 
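+    // A plain dereference is enough here, as 32-bit x86 tolerates misaligned scalar loads in hardware.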
+ return *((sz_u16_vec_t *)ptr); #else - sz_fill_serial(target, length, value); + return *((__unaligned sz_u16_vec_t *)ptr); #endif -} - -SZ_DYNAMIC void sz_look_up_transform(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { -#if SZ_USE_X86_AVX512 - sz_look_up_transform_avx512(source, length, lut, target); -#elif SZ_USE_X86_AVX2 - sz_look_up_transform_avx2(source, length, lut, target); -#elif SZ_USE_ARM_NEON - sz_look_up_transform_neon(source, length, lut, target); #else - sz_look_up_transform_serial(source, length, lut, target); + __attribute__((aligned(1))) sz_u16_vec_t const *result = (sz_u16_vec_t const *)ptr; + return *result; #endif } -SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) { -#if SZ_USE_X86_AVX512 - return sz_find_byte_avx512(haystack, h_length, needle); -#elif SZ_USE_X86_AVX2 - return sz_find_byte_avx2(haystack, h_length, needle); -#elif SZ_USE_ARM_NEON - return sz_find_byte_neon(haystack, h_length, needle); +/** + * @brief Load a 32-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms. + */ +SZ_INTERNAL sz_u32_vec_t sz_u32_load(sz_cptr_t ptr) { +#if !SZ_USE_MISALIGNED_LOADS + sz_u32_vec_t result; + result.u8s[0] = ptr[0]; + result.u8s[1] = ptr[1]; + result.u8s[2] = ptr[2]; + result.u8s[3] = ptr[3]; + return result; +#elif defined(_MSC_VER) && !defined(__clang__) +#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. + return *((sz_u32_vec_t *)ptr); #else - return sz_find_byte_serial(haystack, h_length, needle); + return *((__unaligned sz_u32_vec_t *)ptr); #endif -} - -SZ_DYNAMIC sz_cptr_t sz_rfind_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) { -#if SZ_USE_X86_AVX512 - return sz_rfind_byte_avx512(haystack, h_length, needle); -#elif SZ_USE_X86_AVX2 - return sz_rfind_byte_avx2(haystack, h_length, needle); -#elif SZ_USE_ARM_NEON - return sz_rfind_byte_neon(haystack, h_length, needle); #else - return sz_rfind_byte_serial(haystack, h_length, needle); + __attribute__((aligned(1))) sz_u32_vec_t const *result = (sz_u32_vec_t const *)ptr; + return *result; #endif } -SZ_DYNAMIC sz_cptr_t sz_find(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length) { -#if SZ_USE_X86_AVX512 - return sz_find_avx512(haystack, h_length, needle, n_length); -#elif SZ_USE_X86_AVX2 - return sz_find_avx2(haystack, h_length, needle, n_length); -#elif SZ_USE_ARM_NEON - return sz_find_neon(haystack, h_length, needle, n_length); +/** + * @brief Load a 64-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms. + */ +SZ_INTERNAL sz_u64_vec_t sz_u64_load(sz_cptr_t ptr) { +#if !SZ_USE_MISALIGNED_LOADS + sz_u64_vec_t result; + result.u8s[0] = ptr[0]; + result.u8s[1] = ptr[1]; + result.u8s[2] = ptr[2]; + result.u8s[3] = ptr[3]; + result.u8s[4] = ptr[4]; + result.u8s[5] = ptr[5]; + result.u8s[6] = ptr[6]; + result.u8s[7] = ptr[7]; + return result; +#elif defined(_MSC_VER) && !defined(__clang__) +#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. 
+ return *((sz_u64_vec_t *)ptr); #else - return sz_find_serial(haystack, h_length, needle, n_length); + return *((__unaligned sz_u64_vec_t *)ptr); #endif -} - -SZ_DYNAMIC sz_cptr_t sz_rfind(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length) { -#if SZ_USE_X86_AVX512 - return sz_rfind_avx512(haystack, h_length, needle, n_length); -#elif SZ_USE_X86_AVX2 - return sz_rfind_avx2(haystack, h_length, needle, n_length); -#elif SZ_USE_ARM_NEON - return sz_rfind_neon(haystack, h_length, needle, n_length); #else - return sz_rfind_serial(haystack, h_length, needle, n_length); + __attribute__((aligned(1))) sz_u64_vec_t const *result = (sz_u64_vec_t const *)ptr; + return *result; #endif } -SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { -#if SZ_USE_X86_AVX512 - return sz_find_charset_avx512(text, length, set); -#elif SZ_USE_X86_AVX2 - return sz_find_charset_avx2(text, length, set); -#elif SZ_USE_ARM_NEON - return sz_find_charset_neon(text, length, set); -#else - return sz_find_charset_serial(text, length, set); -#endif +/** @brief Helper function, using the supplied fixed-capacity buffer to allocate memory. */ +SZ_INTERNAL sz_ptr_t _sz_memory_allocate_fixed(sz_size_t length, void *handle) { + sz_size_t capacity; + sz_copy((sz_ptr_t)&capacity, (sz_cptr_t)handle, sizeof(sz_size_t)); + sz_size_t consumed_capacity = sizeof(sz_size_t); + if (consumed_capacity + length > capacity) return SZ_NULL_CHAR; + return (sz_ptr_t)handle + consumed_capacity; } -SZ_DYNAMIC sz_cptr_t sz_rfind_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { -#if SZ_USE_X86_AVX512 - return sz_rfind_charset_avx512(text, length, set); -#elif SZ_USE_X86_AVX2 - return sz_rfind_charset_avx2(text, length, set); -#elif SZ_USE_ARM_NEON - return sz_rfind_charset_neon(text, length, set); -#else - return sz_rfind_charset_serial(text, length, set); -#endif +/** @brief Helper "no-op" function, simulating memory deallocation when we use a "static" memory buffer. */ +SZ_INTERNAL void _sz_memory_free_fixed(sz_ptr_t start, sz_size_t length, void *handle) { + sz_unused(start && length && handle); } -SZ_DYNAMIC sz_size_t sz_hamming_distance( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - return sz_hamming_distance_serial(a, a_length, b, b_length, bound); +/** @brief An internal callback used to set a bit in a power-of-two length binary fingerprint of a string. */ +SZ_INTERNAL void _sz_hashes_fingerprint_pow2_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, void *handle) { + sz_string_view_t *fingerprint_buffer = (sz_string_view_t *)handle; + sz_u8_t *fingerprint_u8s = (sz_u8_t *)fingerprint_buffer->start; + sz_size_t fingerprint_bytes = fingerprint_buffer->length; + fingerprint_u8s[(hash / 8) & (fingerprint_bytes - 1)] |= (1 << (hash & 7)); + sz_unused(start && length); } -SZ_DYNAMIC sz_size_t sz_hamming_distance_utf8( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - return sz_hamming_distance_utf8_serial(a, a_length, b, b_length, bound); +/** @brief An internal callback used to set a bit in a @b non power-of-two length binary fingerprint of a string. 
*/ +SZ_INTERNAL void _sz_hashes_fingerprint_non_pow2_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, + void *handle) { + sz_string_view_t *fingerprint_buffer = (sz_string_view_t *)handle; + sz_u8_t *fingerprint_u8s = (sz_u8_t *)fingerprint_buffer->start; + sz_size_t fingerprint_bytes = fingerprint_buffer->length; + fingerprint_u8s[(hash / 8) % fingerprint_bytes] |= (1 << (hash & 7)); + sz_unused(start && length); } -SZ_DYNAMIC sz_size_t sz_edit_distance( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { -#if SZ_USE_X86_AVX512 - return sz_edit_distance_avx512(a, a_length, b, b_length, bound, alloc); -#else - return sz_edit_distance_serial(a, a_length, b, b_length, bound, alloc); -#endif +/** @brief An internal callback, used to mix all the running hashes into one pointer-size value. */ +SZ_INTERNAL void _sz_hashes_fingerprint_scalar_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, + void *scalar_handle) { + sz_unused(start && length && hash && scalar_handle); + sz_size_t *scalar_ptr = (sz_size_t *)scalar_handle; + *scalar_ptr ^= hash; } -SZ_DYNAMIC sz_size_t sz_edit_distance_utf8( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - return _sz_edit_distance_wagner_fisher_serial(a, a_length, b, b_length, bound, sz_true_k, alloc); -} +#pragma GCC visibility pop +#pragma endregion -SZ_DYNAMIC sz_ssize_t sz_alignment_score(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, - sz_error_cost_t const *subs, sz_error_cost_t gap, - sz_memory_allocator_t *alloc) { -#if SZ_USE_X86_AVX512 - return sz_alignment_score_avx512(a, a_length, b, b_length, subs, gap, alloc); -#else - return sz_alignment_score_serial(a, a_length, b, b_length, subs, gap, alloc); -#endif -} +#pragma region Serial Implementation -SZ_DYNAMIC void sz_hashes(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // - sz_hash_callback_t callback, void *callback_handle) { -#if SZ_USE_X86_AVX512 - sz_hashes_avx512(text, length, window_length, window_step, callback, callback_handle); -#elif SZ_USE_X86_AVX2 - sz_hashes_avx2(text, length, window_length, window_step, callback, callback_handle); -#else - sz_hashes_serial(text, length, window_length, window_step, callback, callback_handle); -#endif -} +#if !SZ_AVOID_LIBC +#include // `fprintf` +#include // `malloc`, `EXIT_FAILURE` -SZ_DYNAMIC sz_cptr_t sz_find_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - return sz_find_charset(h, h_length, &set); +SZ_PUBLIC void *_sz_memory_allocate_default(sz_size_t length, void *handle) { + sz_unused(handle); + return malloc(length); } - -SZ_DYNAMIC sz_cptr_t sz_find_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - sz_charset_invert(&set); - return sz_find_charset(h, h_length, &set); +SZ_PUBLIC void _sz_memory_free_default(sz_ptr_t start, sz_size_t length, void *handle) { + sz_unused(handle && length); + free(start); } -SZ_DYNAMIC sz_cptr_t sz_rfind_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - return sz_rfind_charset(h, 
h_length, &set); -} +#endif -SZ_DYNAMIC sz_cptr_t sz_rfind_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - sz_charset_invert(&set); - return sz_rfind_charset(h, h_length, &set); +SZ_PUBLIC void sz_memory_allocator_init_default(sz_memory_allocator_t *alloc) { +#if !SZ_AVOID_LIBC + alloc->allocate = (sz_memory_allocate_t)_sz_memory_allocate_default; + alloc->free = (sz_memory_free_t)_sz_memory_free_default; +#else + alloc->allocate = (sz_memory_allocate_t)SZ_NULL; + alloc->free = (sz_memory_free_t)SZ_NULL; +#endif + alloc->handle = SZ_NULL; } -SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length, - sz_random_generator_t generator, void *generator_user_data) { - sz_generate_serial(alphabet, alphabet_size, result, result_length, generator, generator_user_data); +SZ_PUBLIC void sz_memory_allocator_init_fixed(sz_memory_allocator_t *alloc, void *buffer, sz_size_t length) { + // The logic here is simple - put the buffer length in the first slots of the buffer. + // Later use it for bounds checking. + alloc->allocate = (sz_memory_allocate_t)_sz_memory_allocate_fixed; + alloc->free = (sz_memory_free_t)_sz_memory_free_fixed; + alloc->handle = &buffer; + sz_copy((sz_ptr_t)buffer, (sz_cptr_t)&length, sizeof(sz_size_t)); } -#endif #pragma endregion #ifdef __cplusplus @@ -7153,4 +1108,4 @@ SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_ } #endif // __cplusplus -#endif // STRINGZILLA_H_ +#endif // STRINGZILLA_TYPES_H_ From 5f7ca590428e13f6a92e2341b64506f535266ba1 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 19:24:05 +0000 Subject: [PATCH 041/751] Fix: Minor macro mismatches --- .vscode/settings.json | 6 ++++ include/stringzilla/drafts.h | 4 +-- include/stringzilla/hash.h | 18 ++++++---- include/stringzilla/memory.h | 11 +----- include/stringzilla/similarity.h | 22 +++++++----- include/stringzilla/types.h | 38 ++++++++++++-------- scripts/bench_memory.cpp | 30 ++++++++-------- scripts/bench_search.cpp | 60 +++++++++++++++++--------------- scripts/bench_similarity.cpp | 2 +- scripts/bench_token.cpp | 27 +++++++------- scripts/test.cpp | 16 ++++----- 11 files changed, 125 insertions(+), 109 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 980956d1..ee1f1d3b 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -27,6 +27,7 @@ "Baeza", "basicsize", "bigram", + "bigrams", "bioinformaticians", "bioinformatics", "Bitap", @@ -50,6 +51,7 @@ "getslice", "Giancarlo", "Gonnet", + "Haswell", "Heikki", "hexdigits", "Hirschberg's", @@ -102,6 +104,7 @@ "readlines", "releasebuffer", "rfind", + "rfinds", "richcompare", "Ritchie", "rmatcher", @@ -111,11 +114,13 @@ "rsplits", "rstrip", "SIMD", + "Skylake", "splitlines", "ssize", "startswith", "STL", "stringzilla", + "stringzillite", "Strs", "strzl", "substr", @@ -129,6 +134,7 @@ "unpoison", "usecases", "Vardanian", + "VBMI", "vectorcallfunc", "Wagner", "whitespaces", diff --git a/include/stringzilla/drafts.h b/include/stringzilla/drafts.h index bcba2233..1817a81e 100644 --- a/include/stringzilla/drafts.h +++ b/include/stringzilla/drafts.h @@ -476,7 +476,7 @@ SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx512(sz_cptr_t text, sz_size_t length, sz #endif // SZ_USE_AVX512 -#if SZ_USE_ARM_NEON +#if SZ_USE_NEON SZ_PUBLIC sz_cptr_t 
sz_find_neon_too_smart(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { @@ -946,7 +946,7 @@ SZ_PUBLIC void sz_hashes_neon_readahead(sz_cptr_t start, sz_size_t length, sz_si } } -#endif // SZ_USE_ARM_NEON +#endif // SZ_USE_NEON #ifdef __cplusplus } // extern "C" diff --git a/include/stringzilla/hash.h b/include/stringzilla/hash.h index bf24a5e6..d8f4a05e 100644 --- a/include/stringzilla/hash.h +++ b/include/stringzilla/hash.h @@ -74,8 +74,9 @@ SZ_PUBLIC sz_u64_t sz_hash(sz_cptr_t text, sz_size_t length); * @param callback_handle Optional user-provided pointer to be passed to the `callback`. * @see sz_hashes_fingerprint, sz_hashes_intersection */ -SZ_DYNAMIC void sz_hashes(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // - sz_hash_callback_t callback, void *callback_handle); +SZ_DYNAMIC void sz_hashes( // + sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // + sz_hash_callback_t callback, void *callback_handle); /** * @brief Computes the Karp-Rabin rolling hashes of a string outputting a binary fingerprint. @@ -140,14 +141,19 @@ SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t /** @copydoc sz_checksum */ SZ_PUBLIC sz_u64_t sz_checksum_serial(sz_cptr_t text, sz_size_t length); + /** @copydoc sz_hash */ SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t text, sz_size_t length); + /** @copydoc sz_generate */ -SZ_PUBLIC void sz_generate_serial(sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length, - sz_random_generator_t generate, void *generator); +SZ_PUBLIC void sz_generate_serial( // + sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length, sz_random_generator_t generate, + void *generator); + /** @copydoc sz_hashes */ -SZ_PUBLIC void sz_hashes_serial(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // - sz_hash_callback_t callback, void *callback_handle); +SZ_PUBLIC void sz_hashes_serial( // + sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // + sz_hash_callback_t callback, void *callback_handle); #pragma endregion // Core API diff --git a/include/stringzilla/memory.h b/include/stringzilla/memory.h index 87957878..32106a82 100644 --- a/include/stringzilla/memory.h +++ b/include/stringzilla/memory.h @@ -9,7 +9,7 @@ * - `sz_move` - analog to `memmove` * - `sz_fill` - analog to `memset` * - `sz_look_up_transform` - LUT transformation of a string, similar to OpenCV LUT - * - `sz_detect_encoding` - similar to `iconv` or `chardet` + * - TODO: `sz_detect_encoding` - similar to `iconv` or `chardet` * * Convenience functions for character-set mapping: * @@ -149,15 +149,6 @@ SZ_PUBLIC void sz_toupper(sz_cptr_t text, sz_size_t length, sz_ptr_t result); */ SZ_PUBLIC void sz_toascii(sz_cptr_t text, sz_size_t length, sz_ptr_t result); -/** - * @brief Checks if all characters in the range are valid ASCII characters. - * - * @param text String to be analyzed. - * @param length Number of bytes in the string. - * @return Whether all characters are valid ASCII characters. 
- */ -SZ_PUBLIC sz_bool_t sz_isascii(sz_cptr_t text, sz_size_t length); - #pragma endregion // Helper API #pragma region Serial Implementation diff --git a/include/stringzilla/similarity.h b/include/stringzilla/similarity.h index e811fefe..ef34b824 100644 --- a/include/stringzilla/similarity.h +++ b/include/stringzilla/similarity.h @@ -150,6 +150,15 @@ SZ_DYNAMIC sz_ssize_t sz_alignment_score( // sz_error_cost_t const *subs, sz_error_cost_t gap, // sz_memory_allocator_t *alloc); +/** + * @brief Checks if all characters in the range are valid ASCII characters. + * + * @param text String to be analyzed. + * @param length Number of bytes in the string. + * @return Whether all characters are valid ASCII characters. + */ +SZ_PUBLIC sz_bool_t sz_isascii(sz_cptr_t text, sz_size_t length); + /** @copydoc sz_hamming_distance */ SZ_PUBLIC sz_size_t sz_hamming_distance_serial( // sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound); @@ -707,9 +716,7 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto63_avx512( // // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound. __mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm); - if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) { // - return SZ_SIZE_MAX; - } + if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) return SZ_SIZE_MAX; } // Now let's handle the anti-diagonal band of the matrix, between the top and bottom triangles. @@ -740,9 +747,7 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto63_avx512( // // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound. __mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm); - if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) { // - return SZ_SIZE_MAX; - } + if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) return SZ_SIZE_MAX; } // Now let's handle the bottom right triangle. @@ -766,9 +771,8 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto63_avx512( // // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound. __mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm); - if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) { // - return SZ_SIZE_MAX; - } + if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) return SZ_SIZE_MAX; + // In every following iterations we take use a shorter prefix of each register, // but we don't need to update the `next_diagonal_mask` anymore... except for the early exit. next_diagonal_mask = _kshiftri_mask64(next_diagonal_mask, 1); diff --git a/include/stringzilla/types.h b/include/stringzilla/types.h index a39620e6..be4a3e0d 100644 --- a/include/stringzilla/types.h +++ b/include/stringzilla/types.h @@ -3,18 +3,26 @@ * @file types.h * @author Ash Vardanian * - * Consider overriding the following macros to customize the library: + * Includes the following types: * - * - `SZ_DEBUG=0` - whether to enable debug assertions and logging. - * - `SZ_AVOID_LIBC=0` - whether to avoid including the standard C library headers. - * - `SZ_DYNAMIC_DISPATCH=0` - whether to use runtime dispatching of the most advanced SIMD backend. - * - `SZ_USE_MISALIGNED_LOADS=0` - whether to use misaligned loads on platforms that support them. 
- * - `SZ_SWAR_THRESHOLD=24` - threshold for switching to SWAR backend over serial byte-level for-loops. - * - `SZ_USE_HASWELL=?` - whether to use AVX2 instructions on x86_64. - * - `SZ_USE_SKYLAKE=?` - whether to use AVX-512 instructions on x86_64. - * - `SZ_USE_ICE=?` - whether to use AVX-512 VBMI instructions on x86_64. - * - `SZ_USE_NEON=?` - whether to use NEON instructions on ARM. - * - `SZ_USE_SVE=?` - whether to use SVE and SVE2 instructions on ARM. + * - `sz_u8_t`, `sz_u16_t`, `sz_u32_t`, `sz_u64_t` - unsigned integers of 8, 16, 32, and 64 bits. + * - `sz_i8_t`, `sz_i16_t`, `sz_i32_t`, `sz_i64_t` - signed integers of 8, 16, 32, and 64 bits. + * - `sz_size_t`, `sz_ssize_t` - unsigned and signed integers of the same size as a pointer. + * - `sz_ptr_t`, `sz_cptr_t` - pointer and constant pointer to a C-style string. + * - `sz_bool_t` - boolean type, `sz_true_k` and `sz_false_k` constants. + * - `sz_ordering_t` - for comparison results, `sz_less_k`, `sz_equal_k`, `sz_greater_k`. + * - @b `sz_u8_vec_t`, `sz_u16_vec_t`, `sz_u32_vec_t`, `sz_u64_vec_t` - @b SWAR vector types. + * - @b `sz_u128_vec_t`, `sz_u256_vec_t`, `sz_u512_vec_t` - @b SIMD vector types for x86 and Arm. + * - @b `sz_rune_t` - for 32-bit Unicode code points ~ @b runes. + * - `sz_rune_length_t` - to describe the number of bytes in a UTF8-encoded rune. + * - `sz_error_cost_t` - for substitution costs in string alignment and scoring algorithms. + * + * The library also defines the following higher-level structures: + * + * - `sz_string_view_t` - for a C-style `std::string_view`-like structure. + * - `sz_memory_allocator_t` - a wrapper for memory-management functions. + * - `sz_sequence_t` - a wrapper to access strings forming a sequential container. + * - `sz_charset_t` - a bitset for 256 possible byte values. */ #ifndef STRINGZILLA_TYPES_H_ #define STRINGZILLA_TYPES_H_ @@ -864,8 +872,8 @@ SZ_INTERNAL sz_u64_vec_t _sz_u64_each_byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) /** * @brief Clamps signed offsets in a string to a valid range. Used for Pythonic-style slicing. */ -SZ_INTERNAL void sz_ssize_clamp_interval(sz_size_t length, sz_ssize_t start, sz_ssize_t end, - sz_size_t *normalized_offset, sz_size_t *normalized_length) { +SZ_INTERNAL void sz_ssize_clamp_interval( // + sz_size_t length, sz_ssize_t start, sz_ssize_t end, sz_size_t *normalized_offset, sz_size_t *normalized_length) { // TODO: Remove branches. // Normalize negative indices if (start < 0) start += length; @@ -1023,7 +1031,7 @@ SZ_INTERNAL sz_u64_vec_t sz_u64_load(sz_cptr_t ptr) { /** @brief Helper function, using the supplied fixed-capacity buffer to allocate memory. 
*/ SZ_INTERNAL sz_ptr_t _sz_memory_allocate_fixed(sz_size_t length, void *handle) { sz_size_t capacity; - sz_copy((sz_ptr_t)&capacity, (sz_cptr_t)handle, sizeof(sz_size_t)); + *(sz_ptr_t)&capacity = *(sz_cptr_t)handle; sz_size_t consumed_capacity = sizeof(sz_size_t); if (consumed_capacity + length > capacity) return SZ_NULL_CHAR; return (sz_ptr_t)handle + consumed_capacity; @@ -1098,7 +1106,7 @@ SZ_PUBLIC void sz_memory_allocator_init_fixed(sz_memory_allocator_t *alloc, void alloc->allocate = (sz_memory_allocate_t)_sz_memory_allocate_fixed; alloc->free = (sz_memory_free_t)_sz_memory_free_fixed; alloc->handle = &buffer; - sz_copy((sz_ptr_t)buffer, (sz_cptr_t)&length, sizeof(sz_size_t)); + *(sz_ptr_t)buffer = *(sz_cptr_t)&length; } #pragma endregion diff --git a/scripts/bench_memory.cpp b/scripts/bench_memory.cpp index d8131102..ee6ae03b 100644 --- a/scripts/bench_memory.cpp +++ b/scripts/bench_memory.cpp @@ -69,16 +69,16 @@ tracked_unary_functions_t copy_functions(sz_cptr_t dataset_start_ptr, sz_ptr_t o tracked_unary_functions_t result = { {"memcpy" + suffix, wrap_sz(memcpy)}, {"sz_copy_serial" + suffix, wrap_sz(sz_copy_serial)}, -#if SZ_USE_X86_AVX512 +#if SZ_USE_ICE {"sz_copy_avx512" + suffix, wrap_sz(sz_copy_avx512)}, #endif -#if SZ_USE_X86_AVX2 +#if SZ_USE_HASWELL {"sz_copy_avx2" + suffix, wrap_sz(sz_copy_avx2)}, #endif -#if SZ_USE_ARM_SVE +#if SZ_USE_SVE {"sz_copy_sve" + suffix, wrap_sz(sz_copy_sve)}, #endif -#if SZ_USE_ARM_NEON +#if SZ_USE_NEON {"sz_copy_neon" + suffix, wrap_sz(sz_copy_neon)}, #endif }; @@ -109,16 +109,16 @@ tracked_unary_functions_t fill_functions(sz_cptr_t dataset_start_ptr, sz_ptr_t o return slice.size(); })}, {"sz_fill_serial", wrap_sz(sz_fill_serial)}, -#if SZ_USE_X86_AVX512 +#if SZ_USE_ICE {"sz_fill_avx512", wrap_sz(sz_fill_avx512)}, #endif -#if SZ_USE_X86_AVX2 +#if SZ_USE_HASWELL {"sz_fill_avx2", wrap_sz(sz_fill_avx2)}, #endif -#if SZ_USE_ARM_SVE +#if SZ_USE_SVE {"sz_fill_sve", wrap_sz(sz_fill_sve)}, #endif -#if SZ_USE_ARM_NEON +#if SZ_USE_NEON {"sz_fill_neon", wrap_sz(sz_fill_neon)}, #endif }; @@ -149,13 +149,13 @@ tracked_unary_functions_t move_functions(sz_cptr_t dataset_start_ptr, sz_ptr_t o tracked_unary_functions_t result = { {"memmove" + suffix, wrap_sz(memmove)}, {"sz_move_serial" + suffix, wrap_sz(sz_move_serial)}, -#if SZ_USE_X86_AVX512 +#if SZ_USE_ICE {"sz_move_avx512" + suffix, wrap_sz(sz_move_avx512)}, #endif -#if SZ_USE_X86_AVX2 +#if SZ_USE_HASWELL {"sz_move_avx2" + suffix, wrap_sz(sz_move_avx2)}, #endif -#if SZ_USE_ARM_NEON +#if SZ_USE_NEON {"sz_move_neon" + suffix, wrap_sz(sz_move_neon)}, #endif }; @@ -192,13 +192,13 @@ tracked_unary_functions_t transform_functions() { return slice.size(); })}, {"sz_look_up_transform_serial", wrap_sz(sz_look_up_transform_serial)}, -#if SZ_USE_X86_AVX512 - {"sz_look_up_transform_avx512", wrap_sz(sz_look_up_transform_avx512)}, +#if SZ_USE_ICE + {"sz_look_up_transform_ice", wrap_sz(sz_look_up_transform_ice)}, #endif -#if SZ_USE_X86_AVX2 +#if SZ_USE_HASWELL {"sz_look_up_transform_avx2", wrap_sz(sz_look_up_transform_avx2)}, #endif -#if SZ_USE_ARM_NEON +#if SZ_USE_NEON {"sz_look_up_transform_neon", wrap_sz(sz_look_up_transform_neon)}, #endif }; diff --git a/scripts/bench_search.cpp b/scripts/bench_search.cpp index ada4ded4..7380a697 100644 --- a/scripts/bench_search.cpp +++ b/scripts/bench_search.cpp @@ -29,13 +29,13 @@ tracked_binary_functions_t find_functions() { return (match == std::string_view::npos ? 
h.size() : match); }}, {"sz_find_serial", wrap_sz(sz_find_serial), true}, -#if SZ_USE_X86_AVX512 - {"sz_find_avx512", wrap_sz(sz_find_avx512), true}, +#if SZ_USE_SKYLAKE + {"sz_find_skylake", wrap_sz(sz_find_skylake), true}, #endif -#if SZ_USE_X86_AVX2 - {"sz_find_avx2", wrap_sz(sz_find_avx2), true}, +#if SZ_USE_HASWELL + {"sz_find_haswell", wrap_sz(sz_find_haswell), true}, #endif -#if SZ_USE_ARM_NEON +#if SZ_USE_NEON {"sz_find_neon", wrap_sz(sz_find_neon), true}, #endif {"strstr/strchr", @@ -90,13 +90,13 @@ tracked_binary_functions_t rfind_functions() { return (match == std::string_view::npos ? 0 : match); }}, {"sz_rfind_serial", wrap_sz(sz_rfind_serial), true}, -#if SZ_USE_X86_AVX512 - {"sz_rfind_avx512", wrap_sz(sz_rfind_avx512), true}, +#if SZ_USE_SKYLAKE + {"sz_rfind_skylake", wrap_sz(sz_rfind_skylake), true}, #endif -#if SZ_USE_X86_AVX2 - {"sz_rfind_avx2", wrap_sz(sz_rfind_avx2), true}, +#if SZ_USE_HASWELL + {"sz_rfind_haswell", wrap_sz(sz_rfind_haswell), true}, #endif -#if SZ_USE_ARM_NEON +#if SZ_USE_NEON {"sz_rfind_neon", wrap_sz(sz_rfind_neon), true}, #endif {"std::search", @@ -140,13 +140,13 @@ tracked_binary_functions_t find_charset_functions() { return (match == std::string_view::npos ? h.size() : match); }}, {"sz_find_charset_serial", wrap_sz(sz_find_charset_serial), true}, -#if SZ_USE_X86_AVX2 - {"sz_find_charset_avx2", wrap_sz(sz_find_charset_avx2), true}, +#if SZ_USE_HASWELL + {"sz_find_charset_haswell", wrap_sz(sz_find_charset_haswell), true}, #endif -#if SZ_USE_X86_AVX512 - {"sz_find_charset_avx512", wrap_sz(sz_find_charset_avx512), true}, +#if SZ_USE_ICE + {"sz_find_charset_ice", wrap_sz(sz_find_charset_ice), true}, #endif -#if SZ_USE_ARM_NEON +#if SZ_USE_NEON {"sz_find_charset_neon", wrap_sz(sz_find_charset_neon), true}, #endif {"strcspn", [](std::string_view h, std::string_view n) { return strcspn(h.data(), n.data()); }}, @@ -171,10 +171,10 @@ tracked_binary_functions_t rfind_charset_functions() { return (match == std::string_view::npos ? 0 : match); }}, {"sz_rfind_charset_serial", wrap_sz(sz_rfind_charset_serial), true}, -#if SZ_USE_X86_AVX512 - {"sz_rfind_charset_avx512", wrap_sz(sz_rfind_charset_avx512), true}, +#if SZ_USE_ICE + {"sz_rfind_charset_ice", wrap_sz(sz_rfind_charset_ice), true}, #endif -#if SZ_USE_ARM_NEON +#if SZ_USE_NEON {"sz_rfind_charset_neon", wrap_sz(sz_rfind_charset_neon), true}, #endif }; @@ -184,8 +184,8 @@ tracked_binary_functions_t rfind_charset_functions() { /** * @brief Evaluation for search string operations: find. */ -void bench_finds(std::string const &haystack, std::vector const &strings, - tracked_binary_functions_t &&variants) { +void bench_finds( // + std::string const &haystack, std::vector const &strings, tracked_binary_functions_t &&variants) { for (std::size_t variant_idx = 0; variant_idx != variants.size(); ++variant_idx) { auto &variant = variants[variant_idx]; @@ -234,8 +234,8 @@ void bench_finds(std::string const &haystack, std::vector const &st /** * @brief Evaluation for reverse order search string operations: find. 
*/ -void bench_rfinds(std::string const &haystack, std::vector const &strings, - tracked_binary_functions_t &&variants) { +void bench_rfinds( // + std::string const &haystack, std::vector const &strings, tracked_binary_functions_t &&variants) { for (std::size_t variant_idx = 0; variant_idx != variants.size(); ++variant_idx) { auto &variant = variants[variant_idx]; @@ -336,15 +336,17 @@ int main(int argc, char const **argv) { bench_search(dataset.text, filter_by_length(dataset.tokens, token_length)); } - // Run bechnmarks on abstract tokens of different length + // Run benchmarks on abstract tokens of different length for (std::size_t token_length : {1, 2, 3, 4, 5, 6, 7, 8, 16, 32}) { std::printf("Benchmarking for missing tokens of length %zu:\n", token_length); - bench_search(dataset.text, std::vector { - std::string(token_length, '\1'), - std::string(token_length, '\2'), - std::string(token_length, '\3'), - std::string(token_length, '\4'), - }); + bench_search( // + dataset.text, // + std::vector { + std::string(token_length, '\1'), + std::string(token_length, '\2'), + std::string(token_length, '\3'), + std::string(token_length, '\4'), + }); } std::printf("All benchmarks passed.\n"); diff --git a/scripts/bench_similarity.cpp b/scripts/bench_similarity.cpp index b2c36a60..140433e2 100644 --- a/scripts/bench_similarity.cpp +++ b/scripts/bench_similarity.cpp @@ -54,7 +54,7 @@ tracked_binary_functions_t distance_functions() { {"naive", wrap_baseline}, {"sz_edit_distance", wrap_sz_distance(sz_edit_distance_serial), true}, {"sz_alignment_score", wrap_sz_scoring(sz_alignment_score_serial), true}, -#if SZ_USE_X86_AVX512 +#if SZ_USE_ICE {"sz_edit_distance_avx512", wrap_sz_distance(sz_edit_distance_avx512), true}, {"sz_alignment_score_avx512", wrap_sz_scoring(sz_alignment_score_avx512), true}, #endif diff --git a/scripts/bench_token.cpp b/scripts/bench_token.cpp index f699f459..1120ad52 100644 --- a/scripts/bench_token.cpp +++ b/scripts/bench_token.cpp @@ -22,13 +22,12 @@ tracked_unary_functions_t checksum_functions() { [](std::size_t sum, char c) { return sum + static_cast(c); }); }}, {"sz_checksum_serial", wrap_sz(sz_checksum_serial), true}, -#if SZ_USE_X86_AVX2 - {"sz_checksum_avx2", wrap_sz(sz_checksum_avx2), true}, +#if SZ_USE_HASWELL + {"sz_checksum_haswell", wrap_sz(sz_checksum_haswell), true}, #endif -#if SZ_USE_X86_AVX512 - {"sz_checksum_avx512", wrap_sz(sz_checksum_avx512), true}, +#if SZ_USE_ICE #endif -#if SZ_USE_ARM_NEON +#if SZ_USE_NEON {"sz_checksum_neon", wrap_sz(sz_checksum_neon), true}, #endif }; @@ -56,11 +55,11 @@ tracked_unary_functions_t sliding_hashing_functions(std::size_t window_width, st }; std::string suffix = std::to_string(window_width) + ":step" + std::to_string(step); tracked_unary_functions_t result = { -#if SZ_USE_X86_AVX512 +#if SZ_USE_ICE {"sz_hashes_avx512:" + suffix, wrap_sz(sz_hashes_avx512)}, #endif -#if SZ_USE_X86_AVX2 - {"sz_hashes_avx2:" + suffix, wrap_sz(sz_hashes_avx2)}, +#if SZ_USE_HASWELL + {"sz_hashes_haswell:" + suffix, wrap_sz(sz_hashes_haswell)}, #endif {"sz_hashes_serial:" + suffix, wrap_sz(sz_hashes_serial)}, }; @@ -118,10 +117,10 @@ tracked_binary_functions_t equality_functions() { tracked_binary_functions_t result = { {"std::string_view.==", [](std::string_view a, std::string_view b) { return (a == b); }}, {"sz_equal_serial", wrap_sz(sz_equal_serial), true}, -#if SZ_USE_X86_AVX2 - {"sz_equal_avx2", wrap_sz(sz_equal_avx2), true}, +#if SZ_USE_HASWELL + {"sz_equal_haswell", wrap_sz(sz_equal_haswell), true}, #endif -#if SZ_USE_X86_AVX512 +#if 
SZ_USE_ICE {"sz_equal_avx512", wrap_sz(sz_equal_avx512), true}, #endif {"memcmp", @@ -145,10 +144,10 @@ tracked_binary_functions_t ordering_functions() { return (order == 0 ? sz_equal_k : (order < 0 ? sz_less_k : sz_greater_k)); }}, {"sz_order_serial", wrap_sz(sz_order_serial), true}, -#if SZ_USE_X86_AVX2 - {"sz_order_avx2", wrap_sz(sz_order_avx2), true}, +#if SZ_USE_HASWELL + {"sz_order_haswell", wrap_sz(sz_order_haswell), true}, #endif -#if SZ_USE_X86_AVX512 +#if SZ_USE_ICE {"sz_order_avx512", wrap_sz(sz_order_avx512), true}, #endif {"memcmp", diff --git a/scripts/test.cpp b/scripts/test.cpp index eecc97f0..db856a8e 100644 --- a/scripts/test.cpp +++ b/scripts/test.cpp @@ -11,10 +11,10 @@ // Those parameters must never be explicitly set during releases, // but they come handy during development, if you want to validate // different ISA-specific implementations. -// #define SZ_USE_X86_AVX2 0 -// #define SZ_USE_X86_AVX512 0 -// #define SZ_USE_ARM_NEON 0 -// #define SZ_USE_ARM_SVE 0 +// #define SZ_USE_HASWELL 0 +// #define SZ_USE_ICE 0 +// #define SZ_USE_NEON 0 +// #define SZ_USE_SVE 0 #define SZ_DEBUG 1 // Enforce aggressive logging for this unit. // Put this at the top to make sure it pulls all the right dependencies @@ -1576,10 +1576,10 @@ int main(int argc, char const **argv) { // Let's greet the user nicely sz_unused(argc && argv); std::printf("Hi, dear tester! You look nice today!\n"); - std::printf("- Uses AVX2: %s \n", SZ_USE_X86_AVX2 ? "yes" : "no"); - std::printf("- Uses AVX512: %s \n", SZ_USE_X86_AVX512 ? "yes" : "no"); - std::printf("- Uses NEON: %s \n", SZ_USE_ARM_NEON ? "yes" : "no"); - std::printf("- Uses SVE: %s \n", SZ_USE_ARM_SVE ? "yes" : "no"); + std::printf("- Uses AVX2: %s \n", SZ_USE_HASWELL ? "yes" : "no"); + std::printf("- Uses AVX512: %s \n", SZ_USE_ICE ? "yes" : "no"); + std::printf("- Uses NEON: %s \n", SZ_USE_NEON ? "yes" : "no"); + std::printf("- Uses SVE: %s \n", SZ_USE_SVE ? "yes" : "no"); // Basic utilities test_arithmetical_utilities(); From 41e59179629f20657c510d8cbe48b8ceaf92be39 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 19:24:25 +0000 Subject: [PATCH 042/751] Fix: Partially filter `stringzilla.h` file --- include/stringzilla/stringzilla.h | 909 ++++-------------------------- 1 file changed, 125 insertions(+), 784 deletions(-) diff --git a/include/stringzilla/stringzilla.h b/include/stringzilla/stringzilla.h index de7fbcac..c0b1b369 100644 --- a/include/stringzilla/stringzilla.h +++ b/include/stringzilla/stringzilla.h @@ -2,23 +2,37 @@ * @brief StringZilla is a collection of advanced string algorithms, designed to be used in Big Data applications. * It is generally faster than LibC, and has a broader & cleaner interface, and targets modern x86 CPUs * with AVX-512 and Arm NEON and older CPUs with SWAR and auto-vectorization. + * @file stringzilla.h + * @author Ash Vardanian + * + * @see StringZilla docs: https://github.com/ashvardanian/StringZilla/blob/main/README.md + * @see LibC string docs: https://pubs.opengroup.org/onlinepubs/009695399/basedefs/string.h.html + * + * @section Introduction + * + * + * @section Compilation Settings * * Consider overriding the following macros to customize the library: * * - `SZ_DEBUG=0` - whether to enable debug assertions and logging. + * - `SZ_AVOID_LIBC=0` - whether to avoid including the standard C library headers. * - `SZ_DYNAMIC_DISPATCH=0` - whether to use runtime dispatching of the most advanced SIMD backend. 
* - `SZ_USE_MISALIGNED_LOADS=0` - whether to use misaligned loads on platforms that support them. + * + * Performance tuning: + * * - `SZ_SWAR_THRESHOLD=24` - threshold for switching to SWAR backend over serial byte-level for-loops. - * - `SZ_USE_X86_AVX512=?` - whether to use AVX-512 instructions on x86_64. - * - `SZ_USE_X86_AVX2=?` - whether to use AVX2 instructions on x86_64. - * - `SZ_USE_ARM_NEON=?` - whether to use NEON instructions on ARM. - * - `SZ_USE_ARM_SVE=?` - whether to use SVE instructions on ARM. + * - `SZ_CACHE_LINE_WIDTH=64` - cache-line width that affects the execution of some algorithms. + * - `SZ_CACHE_SIZE=1048576` - the combined size of L1d and L2 caches in bytes, affecting temporal loads. * - * @see StringZilla: https://github.com/ashvardanian/StringZilla/blob/main/README.md - * @see LibC String: https://pubs.opengroup.org/onlinepubs/009695399/basedefs/string.h.html + * Different generations of CPUs and SIMD capabilities can be enabled or disabled with the following macros: * - * @file stringzilla.h - * @author Ash Vardanian + * - `SZ_USE_HASWELL=?` - whether to use AVX2 instructions on x86_64. + * - `SZ_USE_SKYLAKE=?` - whether to use AVX-512 instructions on x86_64. + * - `SZ_USE_ICE=?` - whether to use AVX-512 VBMI instructions on x86_64. + * - `SZ_USE_NEON=?` - whether to use NEON instructions on ARM. + * - `SZ_USE_SVE=?` - whether to use SVE and SVE2 instructions on ARM. */ #ifndef STRINGZILLA_H_ #define STRINGZILLA_H_ @@ -27,229 +41,10 @@ #define STRINGZILLA_VERSION_MINOR 11 #define STRINGZILLA_VERSION_PATCH 0 -/** - * @brief When set to 1, the library will include the following LibC headers: and . - * In debug builds (SZ_DEBUG=1), the library will also include and . - * - * You may want to disable this compiling for use in the kernel, or in embedded systems. - * You may also avoid them, if you are very sensitive to compilation time and avoid pre-compiled headers. - * https://artificial-mind.net/projects/compile-health/ - */ -#ifndef SZ_AVOID_LIBC -#define SZ_AVOID_LIBC (0) // true or false -#endif - -/** - * @brief A misaligned load can be - trying to fetch eight consecutive bytes from an address - * that is not divisible by eight. On x86 enabled by default. On ARM it's not. - * - * Most platforms support it, but there is no industry standard way to check for those. - * This value will mostly affect the performance of the serial (SWAR) backend. - */ -#ifndef SZ_USE_MISALIGNED_LOADS -#if defined(__x86_64__) || defined(_M_X64) || defined(__i386__) || defined(_M_IX86) -#define SZ_USE_MISALIGNED_LOADS (1) // true or false -#else -#define SZ_USE_MISALIGNED_LOADS (0) // true or false -#endif -#endif - -/** - * @brief Removes compile-time dispatching, and replaces it with runtime dispatching. - * So the `sz_find` function will invoke the most advanced backend supported by the CPU, - * that runs the program, rather than the most advanced backend supported by the CPU - * used to compile the library or the downstream application. - */ -#ifndef SZ_DYNAMIC_DISPATCH -#define SZ_DYNAMIC_DISPATCH (0) // true or false -#endif - -/** - * @brief Analogous to `size_t` and `std::size_t`, unsigned integer, identical to pointer size. - * 64-bit on most platforms where pointers are 64-bit. - * 32-bit on platforms where pointers are 32-bit. - */ -#if defined(__LP64__) || defined(_LP64) || defined(__x86_64__) || defined(_WIN64) -#define SZ_DETECT_64_BIT (1) -#define SZ_SIZE_MAX (0xFFFFFFFFFFFFFFFFull) // Largest unsigned integer that fits into 64 bits. 
-#define SZ_SSIZE_MAX (0x7FFFFFFFFFFFFFFFull) // Largest signed integer that fits into 64 bits. -#else -#define SZ_DETECT_64_BIT (0) -#define SZ_SIZE_MAX (0xFFFFFFFFu) // Largest unsigned integer that fits into 32 bits. -#define SZ_SSIZE_MAX (0x7FFFFFFFu) // Largest signed integer that fits into 32 bits. -#endif - -/** - * @brief On Big-Endian machines StringZilla will work in compatibility mode. - * This disables SWAR hacks to minimize code duplication, assuming practically - * all modern popular platforms are Little-Endian. - * - * This variable is hard to infer from macros reliably. It's best to set it manually. - * For that CMake provides the `TestBigEndian` and `CMAKE__BYTE_ORDER` (from 3.20 onwards). - * In Python one can check `sys.byteorder == 'big'` in the `setup.py` script and pass the appropriate macro. - * https://stackoverflow.com/a/27054190 - */ -#ifndef SZ_DETECT_BIG_ENDIAN -#if defined(__BYTE_ORDER) && __BYTE_ORDER == __BIG_ENDIAN || defined(__BIG_ENDIAN__) || defined(__ARMEB__) || \ - defined(__THUMBEB__) || defined(__AARCH64EB__) || defined(_MIBSEB) || defined(__MIBSEB) || defined(__MIBSEB__) -#define SZ_DETECT_BIG_ENDIAN (1) //< It's a big-endian target architecture -#else -#define SZ_DETECT_BIG_ENDIAN (0) //< It's a little-endian target architecture -#endif -#endif - -/* - * Debugging and testing. - */ -#ifndef SZ_DEBUG -#if defined(DEBUG) || defined(_DEBUG) // This means "Not using DEBUG information". -#define SZ_DEBUG (1) -#else -#define SZ_DEBUG (0) -#endif -#endif - -/** - * @brief Threshold for switching to SWAR (8-bytes at a time) backend over serial byte-level for-loops. - * On very short strings, under 16 bytes long, at most a single word will be processed with SWAR. - * Assuming potentially misaligned loads, SWAR makes sense only after ~24 bytes. - */ -#ifndef SZ_SWAR_THRESHOLD -#if SZ_DEBUG -#define SZ_SWAR_THRESHOLD (8u) // 8 bytes in debug builds -#else -#define SZ_SWAR_THRESHOLD (24u) // 24 bytes in release builds -#endif -#endif - -/* Annotation for the public API symbols: - * - * - `SZ_PUBLIC` is used for functions that are part of the public API. - * - `SZ_INTERNAL` is used for internal helper functions with unstable APIs. - * - `SZ_DYNAMIC` is used for functions that are part of the public API, but are dispatched at runtime. - */ -#ifndef SZ_DYNAMIC -#if SZ_DYNAMIC_DISPATCH -#if defined(_WIN32) || defined(__CYGWIN__) -#define SZ_DYNAMIC __declspec(dllexport) -#define SZ_EXTERNAL __declspec(dllimport) -#define SZ_PUBLIC inline static -#define SZ_INTERNAL inline static -#else -#define SZ_DYNAMIC __attribute__((visibility("default"))) -#define SZ_EXTERNAL extern -#define SZ_PUBLIC __attribute__((unused)) inline static -#define SZ_INTERNAL __attribute__((always_inline)) inline static -#endif // _WIN32 || __CYGWIN__ -#else -#define SZ_DYNAMIC inline static -#define SZ_EXTERNAL extern -#define SZ_PUBLIC inline static -#define SZ_INTERNAL inline static -#endif // SZ_DYNAMIC_DISPATCH -#endif // SZ_DYNAMIC - -/** - * @brief Alignment macro for 64-byte alignment. - */ -#if defined(_MSC_VER) -#define SZ_ALIGN64 __declspec(align(64)) -#elif defined(__GNUC__) || defined(__clang__) -#define SZ_ALIGN64 __attribute__((aligned(64))) -#else -#define SZ_ALIGN64 -#endif - #ifdef __cplusplus extern "C" { #endif -/* - * Let's infer the integer types or pull them from LibC, - * if that is allowed by the user. 
- */ -#if !SZ_AVOID_LIBC -#include // `size_t` -#include // `uint8_t` -typedef int8_t sz_i8_t; // Always 8 bits -typedef uint8_t sz_u8_t; // Always 8 bits -typedef uint16_t sz_u16_t; // Always 16 bits -typedef int32_t sz_i32_t; // Always 32 bits -typedef uint32_t sz_u32_t; // Always 32 bits -typedef uint64_t sz_u64_t; // Always 64 bits -typedef int64_t sz_i64_t; // Always 64 bits -typedef size_t sz_size_t; // Pointer-sized unsigned integer, 32 or 64 bits -typedef ptrdiff_t sz_ssize_t; // Signed version of `sz_size_t`, 32 or 64 bits - -#else // if SZ_AVOID_LIBC: - -// ! The C standard doesn't specify the signedness of char. -// ! On x86 char is signed by default while on Arm it is unsigned by default. -// ! That's why we don't define `sz_char_t` and generally use explicit `sz_i8_t` and `sz_u8_t`. -typedef signed char sz_i8_t; // Always 8 bits -typedef unsigned char sz_u8_t; // Always 8 bits -typedef unsigned short sz_u16_t; // Always 16 bits -typedef int sz_i32_t; // Always 32 bits -typedef unsigned int sz_u32_t; // Always 32 bits -typedef long long sz_i64_t; // Always 64 bits -typedef unsigned long long sz_u64_t; // Always 64 bits - -// Now we need to redefine the `size_t`. -// Microsoft Visual C++ (MSVC) typically follows LLP64 data model on 64-bit platforms, -// where integers, pointers, and long types have different sizes: -// -// > `int` is 32 bits -// > `long` is 32 bits -// > `long long` is 64 bits -// > pointer (thus, `size_t`) is 64 bits -// -// In contrast, GCC and Clang on 64-bit Unix-like systems typically follow the LP64 model, where: -// -// > `int` is 32 bits -// > `long` and pointer (thus, `size_t`) are 64 bits -// > `long long` is also 64 bits -// -// Source: https://learn.microsoft.com/en-us/windows/win32/winprog64/abstract-data-models -#if SZ_DETECT_64_BIT -typedef unsigned long long sz_size_t; // 64-bit. -typedef long long sz_ssize_t; // 64-bit. -#else -typedef unsigned sz_size_t; // 32-bit. -typedef unsigned sz_ssize_t; // 32-bit. -#endif // SZ_DETECT_64_BIT - -#endif // SZ_AVOID_LIBC - -/** - * @brief Compile-time assert macro similar to `static_assert` in C++. - */ -#define sz_static_assert(condition, name) \ - typedef struct { \ - int static_assert_##name : (condition) ? 1 : -1; \ - } sz_static_assert_##name##_t - -sz_static_assert(sizeof(sz_size_t) == sizeof(void *), sz_size_t_must_be_pointer_size); -sz_static_assert(sizeof(sz_ssize_t) == sizeof(void *), sz_ssize_t_must_be_pointer_size); - -#pragma region Public API - -typedef char *sz_ptr_t; // A type alias for `char *` -typedef char const *sz_cptr_t; // A type alias for `char const *` -typedef sz_i8_t sz_error_cost_t; // Character mismatch cost for fuzzy matching functions - -typedef sz_u64_t sz_sorted_idx_t; // Index of a sorted string in a list of strings - -typedef enum { sz_false_k = 0, sz_true_k = 1 } sz_bool_t; // Only one relevant bit -typedef enum { sz_less_k = -1, sz_equal_k = 0, sz_greater_k = 1 } sz_ordering_t; // Only three possible states: <=> - -/** - * @brief Tiny string-view structure. It's POD type, unlike the `std::string_view`. - */ -typedef struct sz_string_view_t { - sz_cptr_t start; - sz_size_t length; -} sz_string_view_t; - /** * @brief Enumeration of SIMD capabilities of the target architecture. * Used to introspect the supported functionality of the dynamic library. @@ -277,176 +72,6 @@ typedef enum sz_capability_t { */ SZ_DYNAMIC sz_capability_t sz_capabilities(void); -/** - * @brief Bit-set structure for 256 possible byte values. Useful for filtering and search. 
- * @see sz_charset_init, sz_charset_add, sz_charset_contains, sz_charset_invert - */ -typedef union sz_charset_t { - sz_u64_t _u64s[4]; - sz_u32_t _u32s[8]; - sz_u16_t _u16s[16]; - sz_u8_t _u8s[32]; -} sz_charset_t; - -/** @brief Initializes a bit-set to an empty collection, meaning - all characters are banned. */ -SZ_PUBLIC void sz_charset_init(sz_charset_t *s) { s->_u64s[0] = s->_u64s[1] = s->_u64s[2] = s->_u64s[3] = 0; } - -/** @brief Adds a character to the set and accepts @b unsigned integers. */ -SZ_PUBLIC void sz_charset_add_u8(sz_charset_t *s, sz_u8_t c) { s->_u64s[c >> 6] |= (1ull << (c & 63u)); } - -/** @brief Adds a character to the set. Consider @b sz_charset_add_u8. */ -SZ_PUBLIC void sz_charset_add(sz_charset_t *s, char c) { sz_charset_add_u8(s, *(sz_u8_t *)(&c)); } // bitcast - -/** @brief Checks if the set contains a given character and accepts @b unsigned integers. */ -SZ_PUBLIC sz_bool_t sz_charset_contains_u8(sz_charset_t const *s, sz_u8_t c) { - // Checking the bit can be done in different ways: - // - (s->_u64s[c >> 6] & (1ull << (c & 63u))) != 0 - // - (s->_u32s[c >> 5] & (1u << (c & 31u))) != 0 - // - (s->_u16s[c >> 4] & (1u << (c & 15u))) != 0 - // - (s->_u8s[c >> 3] & (1u << (c & 7u))) != 0 - return (sz_bool_t)((s->_u64s[c >> 6] & (1ull << (c & 63u))) != 0); -} - -/** @brief Checks if the set contains a given character. Consider @b sz_charset_contains_u8. */ -SZ_PUBLIC sz_bool_t sz_charset_contains(sz_charset_t const *s, char c) { - return sz_charset_contains_u8(s, *(sz_u8_t *)(&c)); // bitcast -} - -/** @brief Inverts the contents of the set, so allowed character get disallowed, and vice versa. */ -SZ_PUBLIC void sz_charset_invert(sz_charset_t *s) { - s->_u64s[0] ^= 0xFFFFFFFFFFFFFFFFull, s->_u64s[1] ^= 0xFFFFFFFFFFFFFFFFull, // - s->_u64s[2] ^= 0xFFFFFFFFFFFFFFFFull, s->_u64s[3] ^= 0xFFFFFFFFFFFFFFFFull; -} - -typedef void *(*sz_memory_allocate_t)(sz_size_t, void *); -typedef void (*sz_memory_free_t)(void *, sz_size_t, void *); -typedef sz_u64_t (*sz_random_generator_t)(void *); - -/** - * @brief Some complex pattern matching algorithms may require memory allocations. - * This structure is used to pass the memory allocator to those functions. - * @see sz_memory_allocator_init_fixed - */ -typedef struct sz_memory_allocator_t { - sz_memory_allocate_t allocate; - sz_memory_free_t free; - void *handle; -} sz_memory_allocator_t; - -/** - * @brief Initializes a memory allocator to use the system default `malloc` and `free`. - * ! The function is not available if the library was compiled with `SZ_AVOID_LIBC`. - * - * @param alloc Memory allocator to initialize. - */ -SZ_PUBLIC void sz_memory_allocator_init_default(sz_memory_allocator_t *alloc); - -/** - * @brief Initializes a memory allocator to use a static-capacity buffer. - * No dynamic allocations will be performed. - * - * @param alloc Memory allocator to initialize. - * @param buffer Buffer to use for allocations. - * @param length Length of the buffer. @b Must be greater than 8 bytes. Different values would be optimal for - * different algorithms and input lengths, but 4096 bytes (one RAM page) is a good default. - */ -SZ_PUBLIC void sz_memory_allocator_init_fixed(sz_memory_allocator_t *alloc, void *buffer, sz_size_t length); - -/** - * @brief The number of bytes a stack-allocated string can hold, including the SZ_NULL termination character. - * ! This can't be changed from outside. Don't use the `#error` as it may already be included and set. 
- */ -#ifdef SZ_STRING_INTERNAL_SPACE -#undef SZ_STRING_INTERNAL_SPACE -#endif -#define SZ_STRING_INTERNAL_SPACE (sizeof(sz_size_t) * 3 - 1) // 3 pointers minus one byte for an 8-bit length - -/** - * @brief Tiny memory-owning string structure with a Small String Optimization (SSO). - * Differs in layout from Folly, Clang, GCC, and probably most other implementations. - * It's designed to avoid any branches on read-only operations, and can store up - * to 22 characters on stack on 64-bit machines, followed by the SZ_NULL-termination character. - * - * @section Changing Length - * - * One nice thing about this design, is that you can, in many cases, change the length of the string - * without any branches, invoking a `+=` or `-=` on the 64-bit `length` field. If the string is on heap, - * the solution is obvious. If it's on stack, inplace decrement wouldn't affect the top bytes of the string, - * only changing the last byte containing the length. - */ -typedef union sz_string_t { - -#if !SZ_DETECT_BIG_ENDIAN - - struct external { - sz_ptr_t start; - sz_size_t length; - sz_size_t space; - sz_size_t padding; - } external; - - struct internal { - sz_ptr_t start; - sz_u8_t length; - char chars[SZ_STRING_INTERNAL_SPACE]; - } internal; - -#else - - struct external { - sz_ptr_t start; - sz_size_t space; - sz_size_t padding; - sz_size_t length; - } external; - - struct internal { - sz_ptr_t start; - char chars[SZ_STRING_INTERNAL_SPACE]; - sz_u8_t length; - } internal; - -#endif - - sz_size_t words[4]; - -} sz_string_t; - -typedef sz_u64_t (*sz_hash_t)(sz_cptr_t, sz_size_t); -typedef sz_u64_t (*sz_checksum_t)(sz_cptr_t, sz_size_t); -typedef sz_bool_t (*sz_equal_t)(sz_cptr_t, sz_cptr_t, sz_size_t); -typedef sz_ordering_t (*sz_order_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); -typedef void (*sz_to_converter_t)(sz_cptr_t, sz_size_t, sz_ptr_t); - -/** - * @brief Computes the 64-bit check-sum of bytes in a string. - * Similar to `std::ranges::accumulate`. - * - * @param text String to aggregate. - * @param length Number of bytes in the text. - * @return 64-bit unsigned value. - */ -SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length); - -/** @copydoc sz_checksum */ -SZ_PUBLIC sz_u64_t sz_checksum_serial(sz_cptr_t text, sz_size_t length); - -/** - * @brief Computes the 64-bit unsigned hash of a string. Fairly fast for short strings, - * simple implementation, and supports rolling computation, reused in other APIs. - * Similar to `std::hash` in C++. - * - * @param text String to hash. - * @param length Number of bytes in the text. - * @return 64-bit hash value. - * - * @see sz_hashes, sz_hashes_fingerprint, sz_hashes_intersection - */ -SZ_PUBLIC sz_u64_t sz_hash(sz_cptr_t text, sz_size_t length); - -/** @copydoc sz_hash */ -SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t text, sz_size_t length); - /** * @brief Checks if two string are equal. * Similar to `memcmp(a, b, length) == 0` in LibC and `a == b` in STL. @@ -480,139 +105,6 @@ SZ_DYNAMIC sz_ordering_t sz_order(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, /** @copydoc sz_order */ SZ_PUBLIC sz_ordering_t sz_order_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); -/** - * @brief Look Up Table @b (LUT) transformation of a string. Equivalent to `for (char & c : text) c = lut[c]`. - * - * Can be used to implement some form of string normalization, partially masking punctuation marks, - * or converting between different character sets, like uppercase or lowercase. 
Surprisingly, also has - * broad implications in image processing, where image channel transformations are often done using LUTs. - * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param lut Look Up Table to apply. Must be exactly @b 256 bytes long. - * @param result Output string, can point to the same address as ::text. - */ -SZ_DYNAMIC void sz_look_up_transform(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result); - -typedef void (*sz_look_up_transform_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_ptr_t); - -/** @copydoc sz_look_up_transform */ -SZ_PUBLIC void sz_look_up_transform_serial(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result); - -/** - * @brief Equivalent to `for (char & c : text) c = tolower(c)`. - * - * ASCII characters [A, Z] map to decimals [65, 90], and [a, z] map to [97, 122]. - * So there are 26 english letters, shifted by 32 values, meaning that a conversion - * can be done by flipping the 5th bit each inappropriate character byte. This, however, - * breaks for extended ASCII, so a different solution is needed. - * http://0x80.pl/notesen/2016-01-06-swar-swap-case.html - * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param result Output string, can point to the same address as ::text. - */ -SZ_PUBLIC void sz_tolower(sz_cptr_t text, sz_size_t length, sz_ptr_t result); - -/** - * @brief Equivalent to `for (char & c : text) c = toupper(c)`. - * - * ASCII characters [A, Z] map to decimals [65, 90], and [a, z] map to [97, 122]. - * So there are 26 english letters, shifted by 32 values, meaning that a conversion - * can be done by flipping the 5th bit each inappropriate character byte. This, however, - * breaks for extended ASCII, so a different solution is needed. - * http://0x80.pl/notesen/2016-01-06-swar-swap-case.html - * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param result Output string, can point to the same address as ::text. - */ -SZ_PUBLIC void sz_toupper(sz_cptr_t text, sz_size_t length, sz_ptr_t result); - -/** - * @brief Equivalent to `for (char & c : text) c = toascii(c)`. - * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param result Output string, can point to the same address as ::text. - */ -SZ_PUBLIC void sz_toascii(sz_cptr_t text, sz_size_t length, sz_ptr_t result); - -/** - * @brief Checks if all characters in the range are valid ASCII characters. - * - * @param text String to be analyzed. - * @param length Number of bytes in the string. - * @return Whether all characters are valid ASCII characters. - */ -SZ_PUBLIC sz_bool_t sz_isascii(sz_cptr_t text, sz_size_t length); - -/** - * @brief Generates a random string for a given alphabet, avoiding integer division and modulo operations. - * Similar to `text[i] = alphabet[rand() % cardinality]`. - * - * The modulo operation is expensive, and should be avoided in performance-critical code. - * We avoid it using small lookup tables and replacing it with a multiplication and shifts, similar to `libdivide`. - * Alternative algorithms would include: - * - Montgomery form: https://en.algorithmica.org/hpc/number-theory/montgomery/ - * - Barret reduction: https://www.nayuki.io/page/barrett-reduction-algorithm - * - Lemire's trick: https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/ - * - * @param alphabet Set of characters to sample from. 
- * @param cardinality Number of characters to sample from. - * @param text Output string, can point to the same address as ::text. - * @param generate Callback producing random numbers given the generator state. - * @param generator Generator state, can be a pointer to a seed, or a pointer to a random number generator. - */ -SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length, - sz_random_generator_t generate, void *generator); - -/** @copydoc sz_generate */ -SZ_PUBLIC void sz_generate_serial(sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length, - sz_random_generator_t generate, void *generator); - -/** - * @brief Similar to `memcpy`, copies contents of one string into another. - * The behavior is undefined if the strings overlap. - * - * @param target String to copy into. - * @param length Number of bytes to copy. - * @param source String to copy from. - */ -SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length); - -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length); - -/** - * @brief Similar to `memmove`, copies (moves) contents of one string into another. - * Unlike `sz_copy`, allows overlapping strings as arguments. - * - * @param target String to copy into. - * @param length Number of bytes to copy. - * @param source String to copy from. - */ -SZ_DYNAMIC void sz_move(sz_ptr_t target, sz_cptr_t source, sz_size_t length); - -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length); - -typedef void (*sz_move_t)(sz_ptr_t, sz_cptr_t, sz_size_t); - -/** - * @brief Similar to `memset`, fills a string with a given value. - * - * @param target String to fill. - * @param length Number of bytes to fill. - * @param value Value to fill with. - */ -SZ_DYNAMIC void sz_fill(sz_ptr_t target, sz_size_t length, sz_u8_t value); - -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_serial(sz_ptr_t target, sz_size_t length, sz_u8_t value); - -typedef void (*sz_fill_t)(sz_ptr_t, sz_size_t, sz_u8_t); - /** * @brief Initializes a string class instance to an empty value. */ @@ -1154,62 +646,62 @@ SZ_PUBLIC void sz_sort_intro(sz_sequence_t *sequence, sz_sequence_comparator_t l * Hardware feature detection. * All of those can be controlled by the user. */ -#ifndef SZ_USE_X86_AVX512 +#ifndef SZ_USE_ICE #ifdef __AVX512BW__ -#define SZ_USE_X86_AVX512 1 +#define SZ_USE_ICE 1 #else -#define SZ_USE_X86_AVX512 0 +#define SZ_USE_ICE 0 #endif #endif -#ifndef SZ_USE_X86_AVX2 +#ifndef SZ_USE_HASWELL #ifdef __AVX2__ -#define SZ_USE_X86_AVX2 1 +#define SZ_USE_HASWELL 1 #else -#define SZ_USE_X86_AVX2 0 +#define SZ_USE_HASWELL 0 #endif #endif -#ifndef SZ_USE_ARM_NEON +#ifndef SZ_USE_NEON #ifdef __ARM_NEON -#define SZ_USE_ARM_NEON 1 +#define SZ_USE_NEON 1 #else -#define SZ_USE_ARM_NEON 0 +#define SZ_USE_NEON 0 #endif #endif -#ifndef SZ_USE_ARM_SVE +#ifndef SZ_USE_SVE #ifdef __ARM_FEATURE_SVE -#define SZ_USE_ARM_SVE 1 +#define SZ_USE_SVE 1 #else -#define SZ_USE_ARM_SVE 0 +#define SZ_USE_SVE 0 #endif #endif /* * Include hardware-specific headers. */ -#if SZ_USE_X86_AVX512 || SZ_USE_X86_AVX2 +#if SZ_USE_ICE || SZ_USE_HASWELL #include #endif // SZ_USE_X86... 
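/* A minimal usage sketch (illustrative only, with hypothetical values): with the per-ISA
 * toggles renamed above (SZ_USE_HASWELL, SZ_USE_ICE, SZ_USE_NEON, SZ_USE_SVE), a
 * translation unit could pin the SIMD backends explicitly before including the header,
 * instead of relying on compiler-provided feature macros like `__AVX2__`:
 *
 *     #define SZ_USE_HASWELL 0   // opt out of AVX2 code paths
 *     #define SZ_USE_ICE 1       // keep AVX-512 (Ice Lake) code paths
 *     #define SZ_USE_NEON 0
 *     #define SZ_USE_SVE 0
 *     #include <stringzilla/stringzilla.h>
 *
 * The macro names follow the renaming in this hunk; the chosen values are hypothetical
 * and only show the intended override pattern.
 */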
-#if SZ_USE_ARM_NEON +#if SZ_USE_NEON #if !defined(_MSC_VER) #include #endif #include -#endif // SZ_USE_ARM_NEON -#if SZ_USE_ARM_SVE +#endif // SZ_USE_NEON +#if SZ_USE_SVE #if !defined(_MSC_VER) #include #endif -#endif // SZ_USE_ARM_SVE +#endif // SZ_USE_SVE #pragma region Hardware Specific API -#if SZ_USE_X86_AVX512 +#if SZ_USE_ICE /** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_avx512(sz_cptr_t a, sz_cptr_t b, sz_size_t length); +SZ_PUBLIC sz_bool_t sz_equal_skylake(sz_cptr_t a, sz_cptr_t b, sz_size_t length); /** @copydoc sz_order */ SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); /** @copydoc sz_copy */ @@ -1219,19 +711,19 @@ SZ_PUBLIC void sz_move_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t lengt /** @copydoc sz_fill */ SZ_PUBLIC void sz_fill_avx512(sz_ptr_t target, sz_size_t length, sz_u8_t value); /** @copydoc sz_look_up_transform */ -SZ_PUBLIC void sz_look_up_transform_avx512(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target); +SZ_PUBLIC void sz_look_up_transform_ice(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target); /** @copydoc sz_find_byte */ SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); /** @copydoc sz_rfind_byte */ SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); /** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); +SZ_PUBLIC sz_cptr_t sz_find_skylake(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); /** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); +SZ_PUBLIC sz_cptr_t sz_rfind_skylake(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); /** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); +SZ_PUBLIC sz_cptr_t sz_find_charset_ice(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); /** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); +SZ_PUBLIC sz_cptr_t sz_rfind_charset_ice(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); /** @copydoc sz_edit_distance */ SZ_PUBLIC sz_size_t sz_edit_distance_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // sz_size_t bound, sz_memory_allocator_t *alloc); @@ -1244,7 +736,7 @@ SZ_PUBLIC void sz_hashes_avx512(sz_cptr_t text, sz_size_t length, sz_size_t wind sz_hash_callback_t callback, void *callback_handle); #endif -#if SZ_USE_X86_AVX2 +#if SZ_USE_HASWELL /** @copydoc sz_equal */ SZ_PUBLIC sz_bool_t sz_equal_avx2(sz_cptr_t a, sz_cptr_t b, sz_size_t length); /** @copydoc sz_order */ @@ -1270,7 +762,7 @@ SZ_PUBLIC void sz_hashes_avx2(sz_cptr_t text, sz_size_t length, sz_size_t window sz_hash_callback_t callback, void *callback_handle); #endif -#if SZ_USE_ARM_NEON +#if SZ_USE_NEON /** @copydoc sz_equal */ SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length); /** @copydoc sz_order */ @@ -1297,7 +789,7 @@ SZ_PUBLIC sz_cptr_t sz_find_charset_neon(sz_cptr_t text, sz_size_t length, sz_ch SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); #endif -#if SZ_USE_ARM_SVE +#if SZ_USE_SVE /** @copydoc sz_equal */ SZ_PUBLIC sz_bool_t 
sz_equal_sve(sz_cptr_t a, sz_cptr_t b, sz_size_t length); /** @copydoc sz_order */ @@ -1554,7 +1046,7 @@ SZ_INTERNAL sz_size_t sz_size_bit_ceil(sz_size_t x) { x |= x >> 4; x |= x >> 8; x |= x >> 16; -#if SZ_DETECT_64_BIT +#if _SZ_IS_64_BIT x |= x >> 32; #endif x++; @@ -1740,79 +1232,6 @@ SZ_INTERNAL void _sz_hashes_fingerprint_scalar_callback(sz_cptr_t start, sz_size *scalar_ptr ^= hash; } -/** - * @brief Chooses the offsets of the most interesting characters in a search needle. - * - * Search throughput can significantly deteriorate if we are matching the wrong characters. - * Say the needle is "aXaYa", and we are comparing the first, second, and last character. - * If we use SIMD and compare many offsets at a time, comparing against "a" in every register is a waste. - * - * Similarly, dealing with UTF8 inputs, we know that the lower bits of each character code carry more information. - * Cyrillic alphabet, for example, falls into [0x0410, 0x042F] code range for uppercase [А, Я], and - * into [0x0430, 0x044F] for lowercase [а, я]. Scanning through a text written in Russian, half of the - * bytes will carry absolutely no value and will be equal to 0x04. - */ -SZ_INTERNAL void _sz_locate_needle_anomalies(sz_cptr_t start, sz_size_t length, // - sz_size_t *first, sz_size_t *second, sz_size_t *third) { - *first = 0; - *second = length / 2; - *third = length - 1; - - // - int has_duplicates = // - start[*first] == start[*second] || // - start[*first] == start[*third] || // - start[*second] == start[*third]; - - // Loop through letters to find non-colliding variants. - if (length > 3 && has_duplicates) { - // Pivot the middle point right, until we find a character different from the first one. - for (; start[*second] == start[*first] && *second + 1 < *third; ++(*second)) {} - // Pivot the third (last) point left, until we find a different character. - for (; (start[*third] == start[*second] || start[*third] == start[*first]) && *third > (*second + 1); - --(*third)) {} - } - - // TODO: Investigate alternative strategies for long needles. - // On very long needles we have the luxury to choose! - // Often dealing with UTF8, we will likely benefit from shifting the first and second characters - // further to the right, to achieve not only uniqueness within the needle, but also avoid common - // rune prefixes of 2-, 3-, and 4-byte codes. - if (length > 8) { - // Pivot the first and second points right, until we find a character, that: - // > is different from others. - // > doesn't start with 0b'110x'xxxx - only 5 bits of relevant info. - // > doesn't start with 0b'1110'xxxx - only 4 bits of relevant info. - // > doesn't start with 0b'1111'0xxx - only 3 bits of relevant info. - // - // So we are practically searching for byte values that start with 0b0xxx'xxxx or 0b'10xx'xxxx. - // Meaning they fall in the range [0, 127] and [128, 191], in other words any unsigned int up to 191. - sz_u8_t const *start_u8 = (sz_u8_t const *)start; - sz_size_t vibrant_first = *first, vibrant_second = *second, vibrant_third = *third; - - // Let's begin with the seccond character, as the termination criteria there is more obvious - // and we may end up with more variants to check for the first candidate. - for (; (start_u8[vibrant_second] > 191 || start_u8[vibrant_second] == start_u8[vibrant_third]) && - (vibrant_second + 1 < vibrant_third); - ++vibrant_second) {} - - // Now check if we've indeed found a good candidate or should revert the `vibrant_second` to `second`. 
- if (start_u8[vibrant_second] < 191) { *second = vibrant_second; } - else { vibrant_second = *second; } - - // Now check the first character. - for (; (start_u8[vibrant_first] > 191 || start_u8[vibrant_first] == start_u8[vibrant_second] || - start_u8[vibrant_first] == start_u8[vibrant_third]) && - (vibrant_first + 1 < vibrant_second); - ++vibrant_first) {} - - // Now check if we've indeed found a good candidate or should revert the `vibrant_first` to `first`. - // We don't need to shift the third one when dealing with texts as the last byte of the text is - // also the last byte of a rune and contains the most information. - if (start_u8[vibrant_first] < 191) { *first = vibrant_first; } - } -} - #pragma GCC visibility pop #pragma endregion @@ -1853,26 +1272,6 @@ SZ_PUBLIC void sz_memory_allocator_init_fixed(sz_memory_allocator_t *alloc, void sz_copy((sz_ptr_t)buffer, (sz_cptr_t)&length, sizeof(sz_size_t)); } -/** - * @brief Byte-level equality comparison between two strings. - * If unaligned loads are allowed, uses a switch-table to avoid loops on short strings. - */ -SZ_PUBLIC sz_bool_t sz_equal_serial(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - sz_cptr_t const a_end = a + length; -#if SZ_USE_MISALIGNED_LOADS - if (length >= SZ_SWAR_THRESHOLD) { - sz_u64_vec_t a_vec, b_vec; - for (; a + 8 <= a_end; a += 8, b += 8) { - a_vec = sz_u64_load(a); - b_vec = sz_u64_load(b); - if (a_vec.u64 != b_vec.u64) return sz_false_k; - } - } -#endif - while (a != a_end && *a == *b) a++, b++; - return (sz_bool_t)(a_end == a); -} - SZ_PUBLIC sz_cptr_t sz_find_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { for (sz_cptr_t const end = text + length; text != end; ++text) if (sz_charset_contains(set, *text)) return text; @@ -1904,7 +1303,7 @@ SZ_PUBLIC sz_ordering_t sz_order_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr sz_bool_t a_shorter = (sz_bool_t)(a_length < b_length); sz_size_t min_length = a_shorter ? a_length : b_length; sz_cptr_t min_end = a + min_length; -#if SZ_USE_MISALIGNED_LOADS && !SZ_DETECT_BIG_ENDIAN +#if SZ_USE_MISALIGNED_LOADS && !_SZ_IS_BIG_ENDIAN for (sz_u64_vec_t a_vec, b_vec; a + 8 <= min_end; a += 8, b += 8) { a_vec = sz_u64_load(a); b_vec = sz_u64_load(b); @@ -1943,7 +1342,7 @@ SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr if (!h_length) return SZ_NULL_CHAR; sz_cptr_t const h_end = h + h_length; -#if !SZ_DETECT_BIG_ENDIAN // Use SWAR only on little-endian platforms for brevety. +#if !_SZ_IS_BIG_ENDIAN // Use SWAR only on little-endian platforms for brevity. #if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to void UB on unaligned 64-bit loads. for (; ((sz_size_t)h & 7ull) && h < h_end; ++h) if (*h == *n) return h; @@ -1980,7 +1379,7 @@ sz_cptr_t sz_rfind_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { // Reposition the `h` pointer to the end, as we will be walking backwards. h = h + h_length - 1; -#if !SZ_DETECT_BIG_ENDIAN // Use SWAR only on little-endian platforms for brevety. +#if !_SZ_IS_BIG_ENDIAN // Use SWAR only on little-endian platforms for brevity. #if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to void UB on unaligned 64-bit loads. for (; ((sz_size_t)(h + 1) & 7ull) && h >= h_start; --h) if (*h == *n) return h; @@ -2364,7 +1763,7 @@ SZ_PUBLIC sz_cptr_t sz_find_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, // This almost never fires, but it's better to be safe than sorry. 
if (h_length < n_length || !n_length) return SZ_NULL_CHAR; -#if SZ_DETECT_BIG_ENDIAN +#if _SZ_IS_BIG_ENDIAN sz_find_t backends[] = { (sz_find_t)sz_find_byte_serial, (sz_find_t)_sz_find_horspool_upto_256bytes_serial, @@ -2823,7 +2222,7 @@ SZ_PUBLIC sz_size_t sz_hamming_distance_serial( // // Walk through both strings using SWAR and counting the number of differing characters. sz_size_t distance = max_length - min_length; -#if SZ_USE_MISALIGNED_LOADS && !SZ_DETECT_BIG_ENDIAN +#if SZ_USE_MISALIGNED_LOADS && !_SZ_IS_BIG_ENDIAN if (min_length >= SZ_SWAR_THRESHOLD) { sz_u64_vec_t a_vec, b_vec, match_vec; for (; a + 8 <= a_end && distance < bound; a += 8, b += 8) { @@ -3278,7 +2677,7 @@ SZ_PUBLIC void sz_string_unpack(sz_string_t const *string, sz_ptr_t *start, sz_s // If the string is small, use branch-less approach to mask-out the top 7 bytes of the length. *length = string->external.length & (0x00000000000000FFull | is_big_mask); // In case the string is small, the `is_small - 1ull` will become 0xFFFFFFFFFFFFFFFFull. - *space = sz_u64_blend(SZ_STRING_INTERNAL_SPACE, string->external.space, is_big_mask); + *space = sz_u64_blend(_SZ_STRING_INTERNAL_SPACE, string->external.space, is_big_mask); *is_external = (sz_bool_t)!is_small; } @@ -3336,7 +2735,7 @@ SZ_PUBLIC sz_ptr_t sz_string_init_length(sz_string_t *string, sz_size_t length, string->words[2] = 0; string->words[3] = 0; // If we are lucky, no memory allocations will be needed. - if (space_needed <= SZ_STRING_INTERNAL_SPACE) { + if (space_needed <= _SZ_STRING_INTERNAL_SPACE) { string->internal.start = &string->internal.chars[0]; string->internal.length = (sz_u8_t)length; } @@ -3357,7 +2756,7 @@ SZ_PUBLIC sz_ptr_t sz_string_reserve(sz_string_t *string, sz_size_t new_capacity sz_assert(string && allocator && "Strings and allocators can't be SZ_NULL."); sz_size_t new_space = new_capacity + 1; - if (new_space <= SZ_STRING_INTERNAL_SPACE) return string->external.start; + if (new_space <= _SZ_STRING_INTERNAL_SPACE) return string->external.start; sz_ptr_t string_start; sz_size_t string_length; @@ -3488,64 +2887,6 @@ SZ_PUBLIC void sz_string_free(sz_string_t *string, sz_memory_allocator_t *alloca sz_string_init(string); } -// When overriding libc, disable optimisations for this function beacuse MSVC will optimize the loops into a memset. -// Which then causes a stack overflow due to infinite recursion (memset -> sz_fill_serial -> memset). -#if defined(_MSC_VER) && defined(SZ_OVERRIDE_LIBC) && SZ_OVERRIDE_LIBC -#pragma optimize("", off) -#endif -SZ_PUBLIC void sz_fill_serial(sz_ptr_t target, sz_size_t length, sz_u8_t value) { - sz_ptr_t end = target + length; - // Dealing with short strings, a single sequential pass would be faster. - // If the size is larger than 2 words, then at least 1 of them will be aligned. - // But just one aligned word may not be worth SWAR. - if (length < SZ_SWAR_THRESHOLD) - while (target != end) *(target++) = value; - - // In case of long strings, skip unaligned bytes, and then fill the rest in 64-bit chunks. 
- else { - sz_u64_t value64 = (sz_u64_t)value * 0x0101010101010101ull; - while ((sz_size_t)target & 7ull) *(target++) = value; - while (target + 8 <= end) *(sz_u64_t *)target = value64, target += 8; - while (target != end) *(target++) = value; - } -} -#if defined(_MSC_VER) && defined(SZ_OVERRIDE_LIBC) && SZ_OVERRIDE_LIBC -#pragma optimize("", on) -#endif - -SZ_PUBLIC void sz_copy_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { -#if SZ_USE_MISALIGNED_LOADS - while (length >= 8) *(sz_u64_t *)target = *(sz_u64_t const *)source, target += 8, source += 8, length -= 8; -#endif - while (length--) *(target++) = *(source++); -} - -SZ_PUBLIC void sz_move_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - // Implementing `memmove` is trickier, than `memcpy`, as the ranges may overlap. - // Existing implementations often have two passes, in normal and reversed order, - // depending on the relation of `target` and `source` addresses. - // https://student.cs.uwaterloo.ca/~cs350/common/os161-src-html/doxygen/html/memmove_8c_source.html - // https://marmota.medium.com/c-language-making-memmove-def8792bb8d5 - // - // We can use the `memcpy` like left-to-right pass if we know that the `target` is before `source`. - // Or if we know that they don't intersect! In that case the traversal order is irrelevant, - // but older CPUs may predict and fetch forward-passes better. - if (target < source || target >= source + length) { -#if SZ_USE_MISALIGNED_LOADS - while (length >= 8) *(sz_u64_t *)target = *(sz_u64_t const *)(source), target += 8, source += 8, length -= 8; -#endif - while (length--) *(target++) = *(source++); - } - else { - // Jump to the end and walk backwards. - target += length, source += length; -#if SZ_USE_MISALIGNED_LOADS - while (length >= 8) *(sz_u64_t *)(target -= 8) = *(sz_u64_t const *)(source -= 8), length -= 8; -#endif - while (length--) *(--target) = *(--source); - } -} - #pragma endregion /* @@ -3803,7 +3144,7 @@ SZ_INTERNAL sz_bool_t _sz_sort_is_less(sz_sequence_t *sequence, sz_size_t i_key, SZ_PUBLIC void sz_sort_partial(sz_sequence_t *sequence, sz_size_t partial_order_length) { -#if SZ_DETECT_BIG_ENDIAN +#if _SZ_IS_BIG_ENDIAN // TODO: Implement partial sort for big-endian systems. For now this sorts the whole thing. 
sz_unused(partial_order_length); sz_sort_introsort(sequence, (sz_sequence_comparator_t)_sz_sort_is_less); @@ -3824,7 +3165,7 @@ SZ_PUBLIC void sz_sort_partial(sz_sequence_t *sequence, sz_size_t partial_order_ } SZ_PUBLIC void sz_sort(sz_sequence_t *sequence) { -#if SZ_DETECT_BIG_ENDIAN +#if _SZ_IS_BIG_ENDIAN sz_sort_introsort(sequence, (sz_sequence_comparator_t)_sz_sort_is_less); #else sz_sort_partial(sequence, sequence->count); @@ -3839,7 +3180,7 @@ SZ_PUBLIC void sz_sort(sz_sequence_t *sequence) { */ #pragma region AVX2 Implementation -#if SZ_USE_X86_AVX2 +#if SZ_USE_HASWELL #pragma GCC push_options #pragma GCC target("avx2") #pragma clang attribute push(__attribute__((target("avx2"))), apply_to = function) @@ -4564,7 +3905,7 @@ SZ_PUBLIC void sz_hashes_avx2(sz_cptr_t start, sz_size_t length, sz_size_t windo */ #pragma region AVX512 Implementation -#if SZ_USE_X86_AVX512 +#if SZ_USE_ICE #pragma GCC push_options #pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "bmi", "bmi2") #pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,bmi,bmi2"))), apply_to = function) @@ -4690,7 +4031,7 @@ SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr return sz_equal_k; } -SZ_PUBLIC sz_bool_t sz_equal_avx512(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { +SZ_PUBLIC sz_bool_t sz_equal_skylake(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { __mmask64 mask; sz_u512_vec_t a_vec, b_vec; @@ -4950,7 +4291,7 @@ SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr return SZ_NULL_CHAR; } -SZ_PUBLIC sz_cptr_t sz_find_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { +SZ_PUBLIC sz_cptr_t sz_find_skylake(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { // This almost never fires, but it's better to be safe than sorry. if (h_length < n_length || !n_length) return SZ_NULL_CHAR; @@ -4982,7 +4323,7 @@ SZ_PUBLIC sz_cptr_t sz_find_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); while (matches) { int potential_offset = sz_u64_ctz(matches); - if (sz_equal_avx512(h + potential_offset, n, n_length)) return h + potential_offset; + if (sz_equal_skylake(h + potential_offset, n, n_length)) return h + potential_offset; matches &= matches - 1; } @@ -5040,7 +4381,7 @@ SZ_PUBLIC sz_cptr_t sz_find_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); while (matches) { int potential_offset = sz_u64_ctz(matches); - if (n_length <= 3 || sz_equal_avx512(h + potential_offset, n, n_length)) return h + potential_offset; + if (n_length <= 3 || sz_equal_skylake(h + potential_offset, n, n_length)) return h + potential_offset; matches &= matches - 1; } } @@ -5070,7 +4411,7 @@ SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cpt return SZ_NULL_CHAR; } -SZ_PUBLIC sz_cptr_t sz_rfind_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { +SZ_PUBLIC sz_cptr_t sz_rfind_skylake(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { // This almost never fires, but it's better to be safe than sorry. 
if (h_length < n_length || !n_length) return SZ_NULL_CHAR; @@ -5101,7 +4442,7 @@ SZ_PUBLIC sz_cptr_t sz_rfind_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); while (matches) { int potential_offset = sz_u64_clz(matches); - if (n_length <= 3 || sz_equal_avx512(h + h_length - n_length - potential_offset, n, n_length)) + if (n_length <= 3 || sz_equal_skylake(h + h_length - n_length - potential_offset, n, n_length)) return h + h_length - n_length - potential_offset; sz_assert((matches & ((sz_u64_t)1 << (63 - potential_offset))) != 0 && "The bit must be set before we squash it"); @@ -5121,7 +4462,7 @@ SZ_PUBLIC sz_cptr_t sz_rfind_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); while (matches) { int potential_offset = sz_u64_clz(matches); - if (n_length <= 3 || sz_equal_avx512(h + 64 - potential_offset - 1, n, n_length)) + if (n_length <= 3 || sz_equal_skylake(h + 64 - potential_offset - 1, n, n_length)) return h + 64 - potential_offset - 1; sz_assert((matches & ((sz_u64_t)1 << (63 - potential_offset))) != 0 && "The bit must be set before we squash it"); @@ -5439,7 +4780,7 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto65k_avx512( // ones_u16_vec.zmm = _mm512_set1_epi16(1); // This is a mixed-precision implementation, using 8-bit representations for part of the operations. - // Even there, in case `SZ_USE_X86_AVX2=0`, let's use the `sz_u512_vec_t` type, addressing the first YMM halfs. + // Even there, in case `SZ_USE_HASWELL=0`, let's use the `sz_u512_vec_t` type, addressing the first YMM halfs. sz_u512_vec_t shorter_vec, longer_vec; sz_u512_vec_t ones_u8_vec; ones_u8_vec.ymms[0] = _mm256_set1_epi8(1); @@ -5810,7 +5151,7 @@ SZ_PUBLIC void sz_hashes_avx512(sz_cptr_t start, sz_size_t length, sz_size_t win #pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512vbmi,avx512vbmi2,bmi,bmi2"))), \ apply_to = function) -SZ_PUBLIC void sz_look_up_transform_avx512(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { +SZ_PUBLIC void sz_look_up_transform_ice(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { // If the input is tiny (especially smaller than the look-up table itself), we may end up paying // more for organizing the SIMD registers and changing the CPU state, than for the actual computation. @@ -5920,7 +5261,7 @@ SZ_PUBLIC void sz_look_up_transform_avx512(sz_cptr_t source, sz_size_t length, s } } -SZ_PUBLIC sz_cptr_t sz_find_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { +SZ_PUBLIC sz_cptr_t sz_find_charset_ice(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { // Before initializing the AVX-512 vectors, we may want to run the sequential code for the first few bytes. // In practice, that only hurts, even when we have matches every 5-ish bytes. @@ -6035,7 +5376,7 @@ SZ_PUBLIC sz_cptr_t sz_find_charset_avx512(sz_cptr_t text, sz_size_t length, sz_ return SZ_NULL_CHAR; } -SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { +SZ_PUBLIC sz_cptr_t sz_rfind_charset_ice(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { return sz_rfind_charset_serial(text, length, filter); } @@ -6046,7 +5387,7 @@ SZ_PUBLIC sz_cptr_t sz_find_many_avx512( // // When dealing with huge needles vocabularies, like in tokenization workloads, we need to construct an automaton. 
// But in many cases, the vocabulary is small enough to use a simpler DFA-less approach, combining the ideas from - // the `sz_find_avx512` and `sz_find_charset_avx512` functions. + // the `sz_find_skylake` and `sz_find_charset_ice` functions. // // Pick the offsets within needles where there is the least variance in the characters. // Like for "the", "then", "there", "these", "those", "their", "they", "them", "that", "this", "thus", "than": @@ -6363,7 +5704,7 @@ SZ_PUBLIC sz_bool_t sz_detect_encoding(sz_cptr_t text, sz_size_t length) { */ #pragma region ARM NEON -#if SZ_USE_ARM_NEON +#if SZ_USE_NEON #pragma GCC push_options #pragma GCC target("arch=armv8.2-a+simd") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd"))), apply_to = function) @@ -6758,7 +6099,7 @@ SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t h, sz_size_t h_length, sz_ch */ #pragma region ARM SVE -#if SZ_USE_ARM_SVE +#if SZ_USE_SVE #pragma GCC push_options #pragma GCC target("arch=armv8.2-a+sve") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve"))), apply_to = function) @@ -6902,11 +6243,11 @@ SZ_PUBLIC void sz_hashes_fingerprint(sz_cptr_t start, sz_size_t length, sz_size_ #if !SZ_DYNAMIC_DISPATCH SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length) { -#if SZ_USE_X86_AVX512 +#if SZ_USE_ICE return sz_checksum_avx512(text, length); -#elif SZ_USE_X86_AVX2 +#elif SZ_USE_HASWELL return sz_checksum_avx2(text, length); -#elif SZ_USE_ARM_NEON +#elif SZ_USE_NEON return sz_checksum_neon(text, length); #else return sz_checksum_serial(text, length); @@ -6914,11 +6255,11 @@ SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length) { } SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { -#if SZ_USE_X86_AVX512 - return sz_equal_avx512(a, b, length); -#elif SZ_USE_X86_AVX2 +#if SZ_USE_ICE + return sz_equal_skylake(a, b, length); +#elif SZ_USE_HASWELL return sz_equal_avx2(a, b, length); -#elif SZ_USE_ARM_NEON +#elif SZ_USE_NEON return sz_equal_neon(a, b, length); #else return sz_equal_serial(a, b, length); @@ -6926,11 +6267,11 @@ SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { } SZ_DYNAMIC sz_ordering_t sz_order(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { -#if SZ_USE_X86_AVX512 +#if SZ_USE_ICE return sz_order_avx512(a, a_length, b, b_length); -#elif SZ_USE_X86_AVX2 +#elif SZ_USE_HASWELL return sz_order_avx2(a, a_length, b, b_length); -#elif SZ_USE_ARM_NEON +#elif SZ_USE_NEON return sz_order_neon(a, a_length, b, b_length); #else return sz_order_serial(a, a_length, b, b_length); @@ -6938,11 +6279,11 @@ SZ_DYNAMIC sz_ordering_t sz_order(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, } SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { -#if SZ_USE_X86_AVX512 +#if SZ_USE_ICE sz_copy_avx512(target, source, length); -#elif SZ_USE_X86_AVX2 +#elif SZ_USE_HASWELL sz_copy_avx2(target, source, length); -#elif SZ_USE_ARM_NEON +#elif SZ_USE_NEON sz_copy_neon(target, source, length); #else sz_copy_serial(target, source, length); @@ -6950,11 +6291,11 @@ SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { } SZ_DYNAMIC void sz_move(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { -#if SZ_USE_X86_AVX512 +#if SZ_USE_ICE sz_move_avx512(target, source, length); -#elif SZ_USE_X86_AVX2 +#elif SZ_USE_HASWELL sz_move_avx2(target, source, length); -#elif SZ_USE_ARM_NEON +#elif SZ_USE_NEON sz_move_neon(target, source, length); #else sz_move_serial(target, source, 
length); @@ -6962,11 +6303,11 @@ SZ_DYNAMIC void sz_move(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { } SZ_DYNAMIC void sz_fill(sz_ptr_t target, sz_size_t length, sz_u8_t value) { -#if SZ_USE_X86_AVX512 +#if SZ_USE_ICE sz_fill_avx512(target, length, value); -#elif SZ_USE_X86_AVX2 +#elif SZ_USE_HASWELL sz_fill_avx2(target, length, value); -#elif SZ_USE_ARM_NEON +#elif SZ_USE_NEON sz_fill_neon(target, length, value); #else sz_fill_serial(target, length, value); @@ -6974,11 +6315,11 @@ SZ_DYNAMIC void sz_fill(sz_ptr_t target, sz_size_t length, sz_u8_t value) { } SZ_DYNAMIC void sz_look_up_transform(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { -#if SZ_USE_X86_AVX512 - sz_look_up_transform_avx512(source, length, lut, target); -#elif SZ_USE_X86_AVX2 +#if SZ_USE_ICE + sz_look_up_transform_ice(source, length, lut, target); +#elif SZ_USE_HASWELL sz_look_up_transform_avx2(source, length, lut, target); -#elif SZ_USE_ARM_NEON +#elif SZ_USE_NEON sz_look_up_transform_neon(source, length, lut, target); #else sz_look_up_transform_serial(source, length, lut, target); @@ -6986,11 +6327,11 @@ SZ_DYNAMIC void sz_look_up_transform(sz_cptr_t source, sz_size_t length, sz_cptr } SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) { -#if SZ_USE_X86_AVX512 +#if SZ_USE_ICE return sz_find_byte_avx512(haystack, h_length, needle); -#elif SZ_USE_X86_AVX2 +#elif SZ_USE_HASWELL return sz_find_byte_avx2(haystack, h_length, needle); -#elif SZ_USE_ARM_NEON +#elif SZ_USE_NEON return sz_find_byte_neon(haystack, h_length, needle); #else return sz_find_byte_serial(haystack, h_length, needle); @@ -6998,11 +6339,11 @@ SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cpt } SZ_DYNAMIC sz_cptr_t sz_rfind_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) { -#if SZ_USE_X86_AVX512 +#if SZ_USE_ICE return sz_rfind_byte_avx512(haystack, h_length, needle); -#elif SZ_USE_X86_AVX2 +#elif SZ_USE_HASWELL return sz_rfind_byte_avx2(haystack, h_length, needle); -#elif SZ_USE_ARM_NEON +#elif SZ_USE_NEON return sz_rfind_byte_neon(haystack, h_length, needle); #else return sz_rfind_byte_serial(haystack, h_length, needle); @@ -7010,11 +6351,11 @@ SZ_DYNAMIC sz_cptr_t sz_rfind_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cp } SZ_DYNAMIC sz_cptr_t sz_find(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length) { -#if SZ_USE_X86_AVX512 - return sz_find_avx512(haystack, h_length, needle, n_length); -#elif SZ_USE_X86_AVX2 +#if SZ_USE_ICE + return sz_find_skylake(haystack, h_length, needle, n_length); +#elif SZ_USE_HASWELL return sz_find_avx2(haystack, h_length, needle, n_length); -#elif SZ_USE_ARM_NEON +#elif SZ_USE_NEON return sz_find_neon(haystack, h_length, needle, n_length); #else return sz_find_serial(haystack, h_length, needle, n_length); @@ -7022,11 +6363,11 @@ SZ_DYNAMIC sz_cptr_t sz_find(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t n } SZ_DYNAMIC sz_cptr_t sz_rfind(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length) { -#if SZ_USE_X86_AVX512 - return sz_rfind_avx512(haystack, h_length, needle, n_length); -#elif SZ_USE_X86_AVX2 +#if SZ_USE_ICE + return sz_rfind_skylake(haystack, h_length, needle, n_length); +#elif SZ_USE_HASWELL return sz_rfind_avx2(haystack, h_length, needle, n_length); -#elif SZ_USE_ARM_NEON +#elif SZ_USE_NEON return sz_rfind_neon(haystack, h_length, needle, n_length); #else return sz_rfind_serial(haystack, h_length, needle, n_length); @@ -7034,11 +6375,11 @@ 
SZ_DYNAMIC sz_cptr_t sz_rfind(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t } SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { -#if SZ_USE_X86_AVX512 - return sz_find_charset_avx512(text, length, set); -#elif SZ_USE_X86_AVX2 +#if SZ_USE_ICE + return sz_find_charset_ice(text, length, set); +#elif SZ_USE_HASWELL return sz_find_charset_avx2(text, length, set); -#elif SZ_USE_ARM_NEON +#elif SZ_USE_NEON return sz_find_charset_neon(text, length, set); #else return sz_find_charset_serial(text, length, set); @@ -7046,11 +6387,11 @@ SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charse } SZ_DYNAMIC sz_cptr_t sz_rfind_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { -#if SZ_USE_X86_AVX512 - return sz_rfind_charset_avx512(text, length, set); -#elif SZ_USE_X86_AVX2 +#if SZ_USE_ICE + return sz_rfind_charset_ice(text, length, set); +#elif SZ_USE_HASWELL return sz_rfind_charset_avx2(text, length, set); -#elif SZ_USE_ARM_NEON +#elif SZ_USE_NEON return sz_rfind_charset_neon(text, length, set); #else return sz_rfind_charset_serial(text, length, set); @@ -7075,7 +6416,7 @@ SZ_DYNAMIC sz_size_t sz_edit_distance( // sz_cptr_t a, sz_size_t a_length, // sz_cptr_t b, sz_size_t b_length, // sz_size_t bound, sz_memory_allocator_t *alloc) { -#if SZ_USE_X86_AVX512 +#if SZ_USE_ICE return sz_edit_distance_avx512(a, a_length, b, b_length, bound, alloc); #else return sz_edit_distance_serial(a, a_length, b, b_length, bound, alloc); @@ -7092,7 +6433,7 @@ SZ_DYNAMIC sz_size_t sz_edit_distance_utf8( // SZ_DYNAMIC sz_ssize_t sz_alignment_score(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc) { -#if SZ_USE_X86_AVX512 +#if SZ_USE_ICE return sz_alignment_score_avx512(a, a_length, b, b_length, subs, gap, alloc); #else return sz_alignment_score_serial(a, a_length, b, b_length, subs, gap, alloc); @@ -7101,9 +6442,9 @@ SZ_DYNAMIC sz_ssize_t sz_alignment_score(sz_cptr_t a, sz_size_t a_length, sz_cpt SZ_DYNAMIC void sz_hashes(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // sz_hash_callback_t callback, void *callback_handle) { -#if SZ_USE_X86_AVX512 +#if SZ_USE_ICE sz_hashes_avx512(text, length, window_length, window_step, callback, callback_handle); -#elif SZ_USE_X86_AVX2 +#elif SZ_USE_HASWELL sz_hashes_avx2(text, length, window_length, window_step, callback, callback_handle); #else sz_hashes_serial(text, length, window_length, window_step, callback, callback_handle); From fc408fa0a0f2d947c610568bd7a5c4a60ecca443 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 19:24:49 +0000 Subject: [PATCH 043/751] Make: Split ./include/stringzilla/find.h to ./include/stringzilla/compare.h --- include/stringzilla/{find.h => compare.h} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename include/stringzilla/{find.h => compare.h} (100%) diff --git a/include/stringzilla/find.h b/include/stringzilla/compare.h similarity index 100% rename from include/stringzilla/find.h rename to include/stringzilla/compare.h From 49e8d9d240993bdf68715a9c87824a032752798d Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 19:24:50 +0000 Subject: [PATCH 044/751] Make: Split ./include/stringzilla/find.h to ./include/stringzilla/compare.h --- include/stringzilla/find.h => temp | 0 1 file changed, 0 
insertions(+), 0 deletions(-) rename include/stringzilla/find.h => temp (100%) diff --git a/include/stringzilla/find.h b/temp similarity index 100% rename from include/stringzilla/find.h rename to temp From fc9e5d61e5fb1c5031f6f10920f6b50e2530de1e Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 19:24:50 +0000 Subject: [PATCH 045/751] Make: Split ./include/stringzilla/find.h to ./include/stringzilla/compare.h --- temp => include/stringzilla/find.h | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename temp => include/stringzilla/find.h (100%) diff --git a/temp b/include/stringzilla/find.h similarity index 100% rename from temp rename to include/stringzilla/find.h From 6512f1d129aeddc8601c9df7332c135038914b68 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 7 Dec 2024 19:54:45 +0000 Subject: [PATCH 046/751] Fix: Filter `compare.h` file --- include/stringzilla/compare.h | 1716 +++------------------------------ include/stringzilla/find.h | 82 +- 2 files changed, 150 insertions(+), 1648 deletions(-) diff --git a/include/stringzilla/compare.h b/include/stringzilla/compare.h index 4571515d..9f2e276d 100644 --- a/include/stringzilla/compare.h +++ b/include/stringzilla/compare.h @@ -1,24 +1,17 @@ /** - * @brief Hardware-accelerated sub-string and character-set search utilities. - * @file find.h + * @brief Hardware-accelerated string comparison utilities. + * @file compare.h * @author Ash Vardanian * * Includes core APIs: * - * - `sz_equal` - * - `sz_find` and reverse-order `sz_rfind` - * - `sz_find_byte` and reverse-order `sz_rfind_byte` - * - `sz_find_charset` and reverse-order `sz_rfind_charset` - * - * Convenience functions for character-set matching: - * - * - `sz_find_char_from` - * - `sz_find_char_not_from` - * - `sz_rfind_char_from` - * - `sz_rfind_char_not_from` + * - `sz_equal` - for equality comparison of two strings. + * - `sz_order` - for the relative order of two strings, similar to `memcmp`. + * - TODO: `sz_mismatch`, `sz_rmismatch` - to supersede `sz_equal`. + * - TODO: `sz_order_utf8` - for the relative order of two UTF-8 strings. */ -#ifndef STRINGZILLA_FIND_H_ -#define STRINGZILLA_FIND_H_ +#ifndef STRINGZILLA_COMPARE_H_ +#define STRINGZILLA_COMPARE_H_ #include "types.h" @@ -29,165 +22,56 @@ extern "C" { #pragma region Core API /** - * @brief Locates first matching byte in a string. Equivalent to `memchr(haystack, *needle, h_length)` in LibC. - * - * X86_64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/x86_64/memchr.S - * Aarch64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/aarch64/memchr.S - * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - single-byte substring to find. - * @return Address of the first match. - */ -SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); - -/** - * @brief Locates last matching byte in a string. Equivalent to `memrchr(haystack, *needle, h_length)` in LibC. - * - * X86_64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/x86_64/memrchr.S - * Aarch64 implementation: missing + * @brief Checks if two string are equal. + * Similar to `memcmp(a, b, length) == 0` in LibC and `a == b` in STL. * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. 
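Looking back at the dispatching hunks earlier in this series: every `SZ_DYNAMIC` entry point cascades through the same ISA toggles. Below is a minimal, purely illustrative sketch of that pattern; the macro names match the patch, but the `fill_*` backends are hypothetical stand-ins and the defaults are chosen only so the sketch compiles anywhere.

    #include <stddef.h>

    /* Normally defined by the build system; default them off so this sketch
       always falls back to the portable serial path. */
    #ifndef SZ_USE_ICE
    #define SZ_USE_ICE 0
    #endif
    #ifndef SZ_USE_HASWELL
    #define SZ_USE_HASWELL 0
    #endif
    #ifndef SZ_USE_NEON
    #define SZ_USE_NEON 0
    #endif

    static void fill_serial(char *target, size_t length, unsigned char value) {
        for (size_t i = 0; i != length; ++i) target[i] = (char)value;
    }

    void fill_dispatch(char *target, size_t length, unsigned char value) {
    #if SZ_USE_ICE
        fill_ice(target, length, value);     /* hypothetical AVX-512 backend */
    #elif SZ_USE_HASWELL
        fill_haswell(target, length, value); /* hypothetical AVX2 backend */
    #elif SZ_USE_NEON
        fill_neon(target, length, value);    /* hypothetical NEON backend */
    #else
        fill_serial(target, length, value);
    #endif
    }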
- * @param needle Needle - single-byte substring to find. - * @return Address of the last match. - */ -SZ_DYNAMIC sz_cptr_t sz_rfind_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); - -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); - -#if SZ_USE_HASWELL -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_haswell(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_haswell(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -#endif - -#if SZ_USE_SKYLAKE -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_skylake(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_skylake(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -#endif - -#if SZ_USE_NEON -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -#endif - -/** - * @brief Locates first matching substring. - * Equivalent to `memmem(haystack, h_length, needle, n_length)` in LibC. - * Similar to `strstr(haystack, needle)` in LibC, but requires known length. + * The implementation of this function is very similar to `sz_order`, but the usage patterns are different. + * This function is more often used in parsing, while `sz_order` is often used in sorting. + * It works best on platforms with cheap * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - substring to find. - * @param n_length Number of bytes in the needle. - * @return Address of the first match. + * @param a First string to compare. + * @param b Second string to compare. + * @param length Number of bytes in both strings. + * @return 1 if strings match, 0 otherwise. */ -SZ_DYNAMIC sz_cptr_t sz_find(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); +SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length); /** - * @brief Locates the last matching substring. + * @brief Estimates the relative order of two strings. Equivalent to `memcmp(a, b, length)` in LibC. + * Can be used on different length strings. * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - substring to find. - * @param n_length Number of bytes in the needle. - * @return Address of the last match. + * @param a First string to compare. + * @param a_length Number of bytes in the first string. + * @param b Second string to compare. + * @param b_length Number of bytes in the second string. + * @return Negative if (a < b), positive if (a > b), zero if they are equal. 
*/ -SZ_DYNAMIC sz_cptr_t sz_rfind(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); +SZ_DYNAMIC sz_ordering_t sz_order(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); +/** @copydoc sz_equal */ +SZ_PUBLIC sz_bool_t sz_equal_serial(sz_cptr_t a, sz_cptr_t b, sz_size_t length); +/** @copydoc sz_order */ +SZ_PUBLIC sz_ordering_t sz_order_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); #if SZ_USE_HASWELL -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_haswell(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_haswell(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +/** @copydoc sz_equal */ +SZ_PUBLIC sz_bool_t sz_equal_haswell(sz_cptr_t a, sz_cptr_t b, sz_size_t length); +/** @copydoc sz_order */ +SZ_PUBLIC sz_ordering_t sz_order_haswell(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); #endif #if SZ_USE_SKYLAKE -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_skylake(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_skylake(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -#endif - -#if SZ_USE_NEON -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -#endif - -/** - * @brief Finds the first character present from the ::set, present in ::text. - * Equivalent to `strspn(text, accepted)` and `strcspn(text, rejected)` in LibC. - * May have identical implementation and performance to ::sz_rfind_charset. - * - * Useful for parsing, when we want to skip a set of characters. Examples: - * * 6 whitespaces: " \t\n\r\v\f". - * * 16 digits forming a float number: "0123456789,.eE+-". - * * 5 HTML reserved characters: "\"'&<>", of which "<>" can be useful for parsing. - * * 2 JSON string special characters useful to locate the end of the string: "\"\\". - * - * @param text String to be scanned. - * @param set Set of relevant characters. - * @return Pointer to the first matching character from ::set. - */ -SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); - -/** - * @brief Finds the last character present from the ::set, present in ::text. - * Equivalent to `strspn(text, accepted)` and `strcspn(text, rejected)` in LibC. - * May have identical implementation and performance to ::sz_find_charset. - * - * Useful for parsing, when we want to skip a set of characters. Examples: - * * 6 whitespaces: " \t\n\r\v\f". - * * 16 digits forming a float number: "0123456789,.eE+-". - * * 5 HTML reserved characters: "\"'&<>", of which "<>" can be useful for parsing. - * * 2 JSON string special characters useful to locate the end of the string: "\"\\". - * - * @param text String to be scanned. - * @param set Set of relevant characters. - * @return Pointer to the last matching character from ::set. 
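A small usage sketch for the two entry points declared in the new `compare.h` above. It assumes the header assembled in this patch is on the include path; only the sign convention documented above and the `sz_true_k` / `sz_equal_k` enumerators that appear later in the file are relied upon.

    #include <string.h> /* strlen, only for the demo */
    #include "stringzilla/compare.h"

    /* `sz_equal` compares a fixed number of bytes; `sz_order` also handles
       different lengths, treating the shorter string as the smaller one when
       it is a prefix of the longer one. */
    int compare_demo(void) {
        sz_cptr_t a = "abc", b = "abcd";
        sz_bool_t same_prefix = sz_equal(a, b, 3);                  /* true: first 3 bytes match */
        sz_ordering_t order = sz_order(a, strlen(a), b, strlen(b)); /* "abc" precedes "abcd" */
        return same_prefix == sz_true_k && order != sz_equal_k;
    }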
- */ -SZ_DYNAMIC sz_cptr_t sz_rfind_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); - -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); - -#if SZ_USE_HASWELL -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_haswell(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_haswell(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -#endif - -#if SZ_USE_ICE -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_ice(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_ice(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +/** @copydoc sz_equal */ +SZ_PUBLIC sz_bool_t sz_equal_skylake(sz_cptr_t a, sz_cptr_t b, sz_size_t length); +/** @copydoc sz_order */ +SZ_PUBLIC sz_ordering_t sz_order_skylake(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); #endif #if SZ_USE_NEON -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +/** @copydoc sz_equal */ +SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length); +/** @copydoc sz_order */ +SZ_PUBLIC sz_ordering_t sz_order_neon(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); #endif #pragma endregion // Core API @@ -214,586 +98,23 @@ SZ_PUBLIC sz_bool_t sz_equal_serial(sz_cptr_t a, sz_cptr_t b, sz_size_t length) return (sz_bool_t)(a_end == a); } -/** - * @brief Chooses the offsets of the most interesting characters in a search needle. - * - * Search throughput can significantly deteriorate if we are matching the wrong characters. - * Say the needle is "aXaYa", and we are comparing the first, second, and last character. - * If we use SIMD and compare many offsets at a time, comparing against "a" in every register is a waste. - * - * Similarly, dealing with UTF8 inputs, we know that the lower bits of each character code carry more information. - * Cyrillic alphabet, for example, falls into [0x0410, 0x042F] code range for uppercase [А, Я], and - * into [0x0430, 0x044F] for lowercase [а, я]. Scanning through a text written in Russian, half of the - * bytes will carry absolutely no value and will be equal to 0x04. - */ -SZ_INTERNAL void _sz_locate_needle_anomalies( // - sz_cptr_t start, sz_size_t length, // - sz_size_t *first, sz_size_t *second, sz_size_t *third) { - - *first = 0; - *second = length / 2; - *third = length - 1; - - // - int has_duplicates = // - start[*first] == start[*second] || // - start[*first] == start[*third] || // - start[*second] == start[*third]; - - // Loop through letters to find non-colliding variants. - if (length > 3 && has_duplicates) { - // Pivot the middle point right, until we find a character different from the first one. - while (start[*second] == start[*first] && *second + 1 < *third) ++(*second); - // Pivot the third (last) point left, until we find a different character. 
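The pivoting described above is easier to see on a concrete needle. Here is a simplified restatement of the duplicate-avoidance step only (ignoring the UTF-8 refinement that follows), with a worked example in the trailing comment:

    #include <stddef.h>

    /* Simplified sketch: start from {0, len/2, len-1} and nudge the middle and
       last offsets until the three probed bytes stop colliding. */
    static void pick_probe_offsets(char const *n, size_t len, size_t *first, size_t *second, size_t *third) {
        *first = 0, *second = len / 2, *third = len - 1;
        int has_duplicates = n[*first] == n[*second] || n[*first] == n[*third] || n[*second] == n[*third];
        if (len > 3 && has_duplicates) {
            while (n[*second] == n[*first] && *second + 1 < *third) ++*second;
            while ((n[*third] == n[*second] || n[*third] == n[*first]) && *third > *second + 1) --*third;
        }
    }

    /* For "aaXYa" (length 5) the naive picks are {0, 2, 4} = 'a', 'X', 'a';
       the last offset pivots left to 3, giving 'a', 'X', 'Y' - three distinct bytes. */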
- while ((start[*third] == start[*second] || start[*third] == start[*first]) && *third > (*second + 1)) - --(*third); - } - - // TODO: Investigate alternative strategies for long needles. - // On very long needles we have the luxury to choose! - // Often dealing with UTF8, we will likely benefit from shifting the first and second characters - // further to the right, to achieve not only uniqueness within the needle, but also avoid common - // rune prefixes of 2-, 3-, and 4-byte codes. - if (length > 8) { - // Pivot the first and second points right, until we find a character, that: - // > is different from others. - // > doesn't start with 0b'110x'xxxx - only 5 bits of relevant info. - // > doesn't start with 0b'1110'xxxx - only 4 bits of relevant info. - // > doesn't start with 0b'1111'0xxx - only 3 bits of relevant info. - // - // So we are practically searching for byte values that start with 0b0xxx'xxxx or 0b'10xx'xxxx. - // Meaning they fall in the range [0, 127] and [128, 191], in other words any unsigned int up to 191. - sz_u8_t const *start_u8 = (sz_u8_t const *)start; - sz_size_t vibrant_first = *first, vibrant_second = *second, vibrant_third = *third; - - // Let's begin with the seccond character, as the termination criteria there is more obvious - // and we may end up with more variants to check for the first candidate. - while ((start_u8[vibrant_second] > 191 || start_u8[vibrant_second] == start_u8[vibrant_third]) && - (vibrant_second + 1 < vibrant_third)) - ++vibrant_second; - - // Now check if we've indeed found a good candidate or should revert the `vibrant_second` to `second`. - if (start_u8[vibrant_second] < 191) { *second = vibrant_second; } - else { vibrant_second = *second; } - - // Now check the first character. - while ((start_u8[vibrant_first] > 191 || start_u8[vibrant_first] == start_u8[vibrant_second] || - start_u8[vibrant_first] == start_u8[vibrant_third]) && - (vibrant_first + 1 < vibrant_second)) - ++vibrant_first; - - // Now check if we've indeed found a good candidate or should revert the `vibrant_first` to `first`. - // We don't need to shift the third one when dealing with texts as the last byte of the text is - // also the last byte of a rune and contains the most information. - if (start_u8[vibrant_first] < 191) { *first = vibrant_first; } - } -} - -SZ_PUBLIC sz_cptr_t sz_find_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { - for (sz_cptr_t const end = text + length; text != end; ++text) - if (sz_charset_contains(set, *text)) return text; - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Warray-bounds" - sz_cptr_t const end = text; - for (text += length; text != end;) - if (sz_charset_contains(set, *(text -= 1))) return text; - return SZ_NULL_CHAR; -#pragma GCC diagnostic pop -} - -/* Find the first occurrence of a @b single-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. - * Identical to `memchr(haystack, needle[0], haystack_length)`. - */ -SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - if (!h_length) return SZ_NULL_CHAR; - sz_cptr_t const h_end = h + h_length; - -#if !_SZ_IS_BIG_ENDIAN // Use SWAR only on little-endian platforms for brevity. -#if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to void UB on unaligned 64-bit loads. 
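The byte search below broadcasts the needle byte with a `0x0101...01` multiplier and then looks for zero bytes in the XOR of haystack and needle. One standard scalar formulation of that zero-byte test (not necessarily the exact helper used in this file) looks like this:

    #include <stdint.h>

    /* Returns the index of the first byte of `haystack8` equal to `needle`, or -1.
       Assumes the 8 haystack bytes were loaded little-endian, as in the code below. */
    static int first_equal_byte(uint64_t haystack8, uint8_t needle) {
        uint64_t broadcast = (uint64_t)needle * 0x0101010101010101ull; /* needle copied into all 8 lanes */
        uint64_t x = haystack8 ^ broadcast;                            /* matching bytes become 0x00 */
        uint64_t found = (x - 0x0101010101010101ull) & ~x & 0x8080808080808080ull;
        if (!found) return -1;
        return (int)(__builtin_ctzll(found) / 8); /* GCC/Clang builtin; lowest set bit marks the first match */
    }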
- for (; ((sz_size_t)h & 7ull) && h < h_end; ++h) - if (*h == *n) return h; -#endif - - // Broadcast the n into every byte of a 64-bit integer to use SWAR - // techniques and process eight characters at a time. - sz_u64_vec_t h_vec, n_vec, match_vec; - match_vec.u64 = 0; - n_vec.u64 = (sz_u64_t)n[0] * 0x0101010101010101ull; - for (; h + 8 <= h_end; h += 8) { - h_vec.u64 = *(sz_u64_t const *)h; - match_vec = _sz_u64_each_byte_equal(h_vec, n_vec); - if (match_vec.u64) return h + sz_u64_ctz(match_vec.u64) / 8; +SZ_PUBLIC sz_ordering_t sz_order_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { + sz_bool_t a_shorter = (sz_bool_t)(a_length < b_length); + sz_size_t min_length = a_shorter ? a_length : b_length; + sz_cptr_t min_end = a + min_length; +#if SZ_USE_MISALIGNED_LOADS && !_SZ_IS_BIG_ENDIAN + for (sz_u64_vec_t a_vec, b_vec; a + 8 <= min_end; a += 8, b += 8) { + a_vec = sz_u64_load(a); + b_vec = sz_u64_load(b); + if (a_vec.u64 != b_vec.u64) + return _sz_order_scalars(sz_u64_bytes_reverse(a_vec.u64), sz_u64_bytes_reverse(b_vec.u64)); } #endif + for (; a != min_end; ++a, ++b) + if (*a != *b) return _sz_order_scalars(*a, *b); - // Handle the misaligned tail. - for (; h < h_end; ++h) - if (*h == *n) return h; - return SZ_NULL_CHAR; -} - -/* Find the last occurrence of a @b single-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. - * Identical to `memrchr(haystack, needle[0], haystack_length)`. - */ -sz_cptr_t sz_rfind_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - if (!h_length) return SZ_NULL_CHAR; - sz_cptr_t const h_start = h; - - // Reposition the `h` pointer to the end, as we will be walking backwards. - h = h + h_length - 1; - -#if !_SZ_IS_BIG_ENDIAN // Use SWAR only on little-endian platforms for brevity. -#if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)(h + 1) & 7ull) && h >= h_start; --h) - if (*h == *n) return h; -#endif - - // Broadcast the n into every byte of a 64-bit integer to use SWAR - // techniques and process eight characters at a time. - sz_u64_vec_t h_vec, n_vec, match_vec; - n_vec.u64 = (sz_u64_t)n[0] * 0x0101010101010101ull; - for (; h >= h_start + 7; h -= 8) { - h_vec.u64 = *(sz_u64_t const *)(h - 7); - match_vec = _sz_u64_each_byte_equal(h_vec, n_vec); - if (match_vec.u64) return h - sz_u64_clz(match_vec.u64) / 8; - } -#endif - - for (; h >= h_start; --h) - if (*h == *n) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief 2Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each 2byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_2byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each 2byte is set. - // For that take the bottom 15 bits of each 2byte, add one to them, - // and if this sets the top bit to one, then all the 15 bits are ones as well. - vec.u64 = ((vec.u64 & 0x7FFF7FFF7FFF7FFFull) + 0x0001000100010001ull) & ((vec.u64 & 0x8000800080008000ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b two-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. 
- */ -SZ_INTERNAL sz_cptr_t _sz_find_2byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - // This is an internal method, and the haystack is guaranteed to be at least 2 bytes long. - sz_assert(h_length >= 2 && "The haystack is too short."); - sz_cptr_t const h_end = h + h_length; - -#if !SZ_USE_MISALIGNED_LOADS - // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)h & 7ull) && h + 2 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) == 2) return h; -#endif - - sz_u64_vec_t h_even_vec, h_odd_vec, n_vec, matches_even_vec, matches_odd_vec; - n_vec.u64 = 0; - n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1]; - n_vec.u64 *= 0x0001000100010001ull; // broadcast - - // This code simulates hyper-scalar execution, analyzing 8 offsets at a time. - for (; h + 9 <= h_end; h += 8) { - h_even_vec.u64 = *(sz_u64_t *)h; - h_odd_vec.u64 = (h_even_vec.u64 >> 8) | ((sz_u64_t)h[8] << 56); - matches_even_vec = _sz_u64_each_2byte_equal(h_even_vec, n_vec); - matches_odd_vec = _sz_u64_each_2byte_equal(h_odd_vec, n_vec); - - matches_even_vec.u64 >>= 8; - if (matches_even_vec.u64 + matches_odd_vec.u64) { - sz_u64_t match_indicators = matches_even_vec.u64 | matches_odd_vec.u64; - return h + sz_u64_ctz(match_indicators) / 8; - } - } - - for (; h + 2 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) == 2) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief 4Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each 4byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_4byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each 4byte is set. - // For that take the bottom 31 bits of each 4byte, add one to them, - // and if this sets the top bit to one, then all the 31 bits are ones as well. - vec.u64 = ((vec.u64 & 0x7FFFFFFF7FFFFFFFull) + 0x0000000100000001ull) & ((vec.u64 & 0x8000000080000000ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b four-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. - */ -SZ_INTERNAL sz_cptr_t _sz_find_4byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - // This is an internal method, and the haystack is guaranteed to be at least 4 bytes long. - sz_assert(h_length >= 4 && "The haystack is too short."); - sz_cptr_t const h_end = h + h_length; - -#if !SZ_USE_MISALIGNED_LOADS - // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)h & 7ull) && h + 4 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) + (h[3] == n[3]) == 4) return h; -#endif - - sz_u64_vec_t h0_vec, h1_vec, h2_vec, h3_vec, n_vec, matches0_vec, matches1_vec, matches2_vec, matches3_vec; - n_vec.u64 = 0; - n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1], n_vec.u8s[2] = n[2], n_vec.u8s[3] = n[3]; - n_vec.u64 *= 0x0000000100000001ull; // broadcast - - // This code simulates hyper-scalar execution, analyzing 8 offsets at a time using four 64-bit words. - // We load the subsequent four-byte word as well, taking its first bytes. 
Think of it as a glorified prefetch :) - sz_u64_t h_page_current, h_page_next; - for (; h + sizeof(sz_u64_t) + sizeof(sz_u32_t) <= h_end; h += sizeof(sz_u64_t)) { - h_page_current = *(sz_u64_t *)h; - h_page_next = *(sz_u32_t *)(h + 8); - h0_vec.u64 = (h_page_current); - h1_vec.u64 = (h_page_current >> 8) | (h_page_next << 56); - h2_vec.u64 = (h_page_current >> 16) | (h_page_next << 48); - h3_vec.u64 = (h_page_current >> 24) | (h_page_next << 40); - matches0_vec = _sz_u64_each_4byte_equal(h0_vec, n_vec); - matches1_vec = _sz_u64_each_4byte_equal(h1_vec, n_vec); - matches2_vec = _sz_u64_each_4byte_equal(h2_vec, n_vec); - matches3_vec = _sz_u64_each_4byte_equal(h3_vec, n_vec); - - if (matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64) { - matches0_vec.u64 >>= 24; - matches1_vec.u64 >>= 16; - matches2_vec.u64 >>= 8; - sz_u64_t match_indicators = matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64; - return h + sz_u64_ctz(match_indicators) / 8; - } - } - - for (; h + 4 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) + (h[3] == n[3]) == 4) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief 3Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each 3byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_3byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each 4byte is set. - // For that take the bottom 31 bits of each 4byte, add one to them, - // and if this sets the top bit to one, then all the 31 bits are ones as well. - vec.u64 = ((vec.u64 & 0xFFFF7FFFFF7FFFFFull) + 0x0000000001000001ull) & ((vec.u64 & 0x0000800000800000ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b three-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. - */ -SZ_INTERNAL sz_cptr_t _sz_find_3byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - // This is an internal method, and the haystack is guaranteed to be at least 4 bytes long. - sz_assert(h_length >= 3 && "The haystack is too short."); - sz_cptr_t const h_end = h + h_length; - -#if !SZ_USE_MISALIGNED_LOADS - // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)h & 7ull) && h + 3 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) == 3) return h; -#endif - - // We fetch 12 - sz_u64_vec_t h0_vec, h1_vec, h2_vec, h3_vec, h4_vec; - sz_u64_vec_t matches0_vec, matches1_vec, matches2_vec, matches3_vec, matches4_vec; - sz_u64_vec_t n_vec; - n_vec.u64 = 0; - n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1], n_vec.u8s[2] = n[2]; - n_vec.u64 *= 0x0000000001000001ull; // broadcast - - // This code simulates hyper-scalar execution, analyzing 8 offsets at a time using three 64-bit words. - // We load the subsequent two-byte word as well. 
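The needle broadcasts above rely on multiplication by sparse constants: `0x0001000100010001` replicates a 2-byte needle four times, and `0x0000000001000001` lays a 3-byte needle down twice at a 3-byte stride. A tiny self-contained check of both, assuming little-endian loads as in the surrounding code:

    #include <assert.h>
    #include <stdint.h>

    int main(void) {
        uint64_t two_byte = 0x6261ull;     /* "ab" read little-endian */
        uint64_t three_byte = 0x636261ull; /* "abc" read little-endian */
        assert(two_byte * 0x0001000100010001ull == 0x6261626162616261ull);   /* bytes: "abababab" */
        assert(three_byte * 0x0000000001000001ull == 0x0000636261636261ull); /* bytes: "abcabc\0\0" */
        return 0;
    }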
- sz_u64_t h_page_current, h_page_next; - for (; h + sizeof(sz_u64_t) + sizeof(sz_u16_t) <= h_end; h += sizeof(sz_u64_t)) { - h_page_current = *(sz_u64_t *)h; - h_page_next = *(sz_u16_t *)(h + 8); - h0_vec.u64 = (h_page_current); - h1_vec.u64 = (h_page_current >> 8) | (h_page_next << 56); - h2_vec.u64 = (h_page_current >> 16) | (h_page_next << 48); - h3_vec.u64 = (h_page_current >> 24) | (h_page_next << 40); - h4_vec.u64 = (h_page_current >> 32) | (h_page_next << 32); - matches0_vec = _sz_u64_each_3byte_equal(h0_vec, n_vec); - matches1_vec = _sz_u64_each_3byte_equal(h1_vec, n_vec); - matches2_vec = _sz_u64_each_3byte_equal(h2_vec, n_vec); - matches3_vec = _sz_u64_each_3byte_equal(h3_vec, n_vec); - matches4_vec = _sz_u64_each_3byte_equal(h4_vec, n_vec); - - if (matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64 | matches4_vec.u64) { - matches0_vec.u64 >>= 16; - matches1_vec.u64 >>= 8; - matches3_vec.u64 <<= 8; - matches4_vec.u64 <<= 16; - sz_u64_t match_indicators = - matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64 | matches4_vec.u64; - return h + sz_u64_ctz(match_indicators) / 8; - } - } - - for (; h + 3 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) == 3) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief Boyer-Moore-Horspool algorithm for exact matching of patterns up to @b 256-bytes long. - * Uses the Raita heuristic to match the first two, the last, and the middle character of the pattern. - */ -SZ_INTERNAL sz_cptr_t _sz_find_horspool_upto_256bytes_serial( // - sz_cptr_t h_chars, sz_size_t h_length, // - sz_cptr_t n_chars, sz_size_t n_length) { - sz_assert(n_length <= 256 && "The pattern is too long."); - // Several popular string matching algorithms are using a bad-character shift table. - // Boyer Moore: https://www-igm.univ-mlv.fr/~lecroq/string/node14.html - // Quick Search: https://www-igm.univ-mlv.fr/~lecroq/string/node19.html - // Smith: https://www-igm.univ-mlv.fr/~lecroq/string/node21.html - union { - sz_u8_t jumps[256]; - sz_u64_vec_t vecs[64]; - } bad_shift_table; - - // Let's initialize the table using SWAR to the total length of the string. - sz_u8_t const *h = (sz_u8_t const *)h_chars; - sz_u8_t const *n = (sz_u8_t const *)n_chars; - { - sz_u64_vec_t n_length_vec; - n_length_vec.u64 = n_length; - n_length_vec.u64 *= 0x0101010101010101ull; // broadcast - for (sz_size_t i = 0; i != 64; ++i) bad_shift_table.vecs[i].u64 = n_length_vec.u64; - for (sz_size_t i = 0; i + 1 < n_length; ++i) bad_shift_table.jumps[n[i]] = (sz_u8_t)(n_length - i - 1); - } - - // Another common heuristic is to match a few characters from different parts of a string. - // Raita suggests to use the first two, the last, and the middle character of the pattern. - sz_u32_vec_t h_vec, n_vec; - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n_chars, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into an unsigned integer. - n_vec.u8s[0] = n[offset_first]; - n_vec.u8s[1] = n[offset_first + 1]; - n_vec.u8s[2] = n[offset_mid]; - n_vec.u8s[3] = n[offset_last]; - - // Scan through the whole haystack, skipping the last `n_length - 1` bytes. 
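For reference, the bad-character table above is the classic Horspool construction; the version in this file adds the Raita-style four-byte fingerprint on top of it. A textbook sketch without that fingerprint, for comparison:

    #include <stddef.h>
    #include <string.h>

    /* Plain Horspool: shift by the distance of the window's last byte from the
       end of the needle, verify candidates with a full comparison. */
    static char const *find_horspool(char const *h, size_t h_len, char const *n, size_t n_len) {
        if (!n_len || h_len < n_len) return NULL;
        size_t jumps[256];
        for (size_t i = 0; i != 256; ++i) jumps[i] = n_len;
        for (size_t i = 0; i + 1 < n_len; ++i) jumps[(unsigned char)n[i]] = n_len - i - 1;
        for (size_t i = 0; i + n_len <= h_len;) {
            if (memcmp(h + i, n, n_len) == 0) return h + i;
            i += jumps[(unsigned char)h[i + n_len - 1]];
        }
        return NULL;
    }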
- for (sz_size_t i = 0; i <= h_length - n_length;) { - h_vec.u8s[0] = h[i + offset_first]; - h_vec.u8s[1] = h[i + offset_first + 1]; - h_vec.u8s[2] = h[i + offset_mid]; - h_vec.u8s[3] = h[i + offset_last]; - if (h_vec.u32 == n_vec.u32 && sz_equal_serial((sz_cptr_t)h + i, n_chars, n_length)) return (sz_cptr_t)h + i; - i += bad_shift_table.jumps[h[i + n_length - 1]]; - } - return SZ_NULL_CHAR; -} - -/** - * @brief Boyer-Moore-Horspool algorithm for @b reverse-order exact matching of patterns up to @b 256-bytes long. - * Uses the Raita heuristic to match the first two, the last, and the middle character of the pattern. - */ -SZ_INTERNAL sz_cptr_t _sz_rfind_horspool_upto_256bytes_serial( // - sz_cptr_t h_chars, sz_size_t h_length, // - sz_cptr_t n_chars, sz_size_t n_length) { - sz_assert(n_length <= 256 && "The pattern is too long."); - union { - sz_u8_t jumps[256]; - sz_u64_vec_t vecs[64]; - } bad_shift_table; - - // Let's initialize the table using SWAR to the total length of the string. - sz_u8_t const *h = (sz_u8_t const *)h_chars; - sz_u8_t const *n = (sz_u8_t const *)n_chars; - { - sz_u64_vec_t n_length_vec; - n_length_vec.u64 = n_length; - n_length_vec.u64 *= 0x0101010101010101ull; // broadcast - for (sz_size_t i = 0; i != 64; ++i) bad_shift_table.vecs[i].u64 = n_length_vec.u64; - for (sz_size_t i = 0; i + 1 < n_length; ++i) - bad_shift_table.jumps[n[n_length - i - 1]] = (sz_u8_t)(n_length - i - 1); - } - - // Another common heuristic is to match a few characters from different parts of a string. - // Raita suggests to use the first two, the last, and the middle character of the pattern. - sz_u32_vec_t h_vec, n_vec; - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n_chars, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into an unsigned integer. - n_vec.u8s[0] = n[offset_first]; - n_vec.u8s[1] = n[offset_first + 1]; - n_vec.u8s[2] = n[offset_mid]; - n_vec.u8s[3] = n[offset_last]; - - // Scan through the whole haystack, skipping the first `n_length - 1` bytes. - for (sz_size_t j = 0; j <= h_length - n_length;) { - sz_size_t i = h_length - n_length - j; - h_vec.u8s[0] = h[i + offset_first]; - h_vec.u8s[1] = h[i + offset_first + 1]; - h_vec.u8s[2] = h[i + offset_mid]; - h_vec.u8s[3] = h[i + offset_last]; - if (h_vec.u32 == n_vec.u32 && sz_equal_serial((sz_cptr_t)h + i, n_chars, n_length)) return (sz_cptr_t)h + i; - j += bad_shift_table.jumps[h[i]]; - } - return SZ_NULL_CHAR; -} - -/** - * @brief Exact substring search helper function, that finds the first occurrence of a prefix of the needle - * using a given search function, and then verifies the remaining part of the needle. - */ -SZ_INTERNAL sz_cptr_t _sz_find_with_prefix( // - sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length, sz_find_t find_prefix, sz_size_t prefix_length) { - - sz_size_t suffix_length = n_length - prefix_length; - while (1) { - sz_cptr_t found = find_prefix(h, h_length, n, prefix_length); - if (!found) return SZ_NULL_CHAR; - - // Verify the remaining part of the needle - sz_size_t remaining = h_length - (found - h); - if (remaining < n_length) return SZ_NULL_CHAR; - if (sz_equal_serial(found + prefix_length, n + prefix_length, suffix_length)) return found; - - // Adjust the position. 
- h = found + 1; - h_length = remaining - 1; - } - - // Unreachable, but helps silence compiler warnings: - return SZ_NULL_CHAR; -} - -/** - * @brief Exact reverse-order substring search helper function, that finds the last occurrence of a suffix of the - * needle using a given search function, and then verifies the remaining part of the needle. - */ -SZ_INTERNAL sz_cptr_t _sz_rfind_with_suffix(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length, - sz_find_t find_suffix, sz_size_t suffix_length) { - - sz_size_t prefix_length = n_length - suffix_length; - while (1) { - sz_cptr_t found = find_suffix(h, h_length, n + prefix_length, suffix_length); - if (!found) return SZ_NULL_CHAR; - - // Verify the remaining part of the needle - sz_size_t remaining = found - h; - if (remaining < prefix_length) return SZ_NULL_CHAR; - if (sz_equal_serial(found - prefix_length, n, prefix_length)) return found - prefix_length; - - // Adjust the position. - h_length = remaining - 1; - } - - // Unreachable, but helps silence compiler warnings: - return SZ_NULL_CHAR; -} - -SZ_INTERNAL sz_cptr_t _sz_find_over_4bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - return _sz_find_with_prefix(h, h_length, n, n_length, (sz_find_t)_sz_find_4byte_serial, 4); -} - -SZ_INTERNAL sz_cptr_t _sz_find_horspool_over_256bytes_serial( // - sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - return _sz_find_with_prefix(h, h_length, n, n_length, _sz_find_horspool_upto_256bytes_serial, 256); -} - -SZ_INTERNAL sz_cptr_t _sz_rfind_horspool_over_256bytes_serial( // - sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - return _sz_rfind_with_suffix(h, h_length, n, n_length, _sz_rfind_horspool_upto_256bytes_serial, 256); -} - -SZ_PUBLIC sz_cptr_t sz_find_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - -#if _SZ_IS_BIG_ENDIAN - sz_find_t backends[] = { - (sz_find_t)sz_find_byte_serial, - (sz_find_t)_sz_find_horspool_upto_256bytes_serial, - (sz_find_t)_sz_find_horspool_over_256bytes_serial, - }; - - return backends[(n_length > 1) + (n_length > 256)](h, h_length, n, n_length); -#else - sz_find_t backends[] = { - // For very short strings brute-force SWAR makes sense. - (sz_find_t)sz_find_byte_serial, - (sz_find_t)_sz_find_2byte_serial, - (sz_find_t)_sz_find_3byte_serial, - (sz_find_t)_sz_find_4byte_serial, - // To avoid constructing the skip-table, let's use the prefixed approach. - (sz_find_t)_sz_find_over_4bytes_serial, - // For longer needles - use skip tables. - (sz_find_t)_sz_find_horspool_upto_256bytes_serial, - (sz_find_t)_sz_find_horspool_over_256bytes_serial, - }; - - return backends[ - // For very short strings brute-force SWAR makes sense. - (n_length > 1) + (n_length > 2) + (n_length > 3) + - // To avoid constructing the skip-table, let's use the prefixed approach. - (n_length > 4) + - // For longer needles - use skip tables. - (n_length > 8) + (n_length > 256)](h, h_length, n, n_length); -#endif -} - -SZ_PUBLIC sz_cptr_t sz_rfind_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - - sz_find_t backends[] = { - // For very short strings brute-force SWAR makes sense. 
- (sz_find_t)sz_rfind_byte_serial, - // TODO: implement reverse-order SWAR for 2/3/4 byte variants. - // TODO: (sz_find_t)_sz_rfind_2byte_serial, - // TODO: (sz_find_t)_sz_rfind_3byte_serial, - // TODO: (sz_find_t)_sz_rfind_4byte_serial, - // To avoid constructing the skip-table, let's use the prefixed approach. - // (sz_find_t)_sz_rfind_over_4bytes_serial, - // For longer needles - use skip tables. - (sz_find_t)_sz_rfind_horspool_upto_256bytes_serial, - (sz_find_t)_sz_rfind_horspool_over_256bytes_serial, - }; - - return backends[ - // For very short strings brute-force SWAR makes sense. - 0 + - // To avoid constructing the skip-table, let's use the prefixed approach. - (n_length > 1) + - // For longer needles - use skip tables. - (n_length > 256)](h, h_length, n, n_length); + // If the strings are equal up to `min_end`, then the shorter string is smaller + return _sz_order_scalars(a_length, b_length); } #pragma endregion // Serial Implementation @@ -804,8 +125,14 @@ SZ_PUBLIC sz_cptr_t sz_rfind_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n #pragma region Haswell Implementation #if SZ_USE_HASWELL #pragma GCC push_options -#pragma GCC target("haswell") -#pragma clang attribute push(__attribute__((target("haswell"))), apply_to = function) +#pragma GCC target("avx2") +#pragma clang attribute push(__attribute__((target("avx2"))), apply_to = function) + +SZ_PUBLIC sz_ordering_t sz_order_haswell(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { + //! Before optimizing this, read the "Operations Not Worth Optimizing" in Contributions Guide: + //! https://github.com/ashvardanian/StringZilla/blob/main/CONTRIBUTING.md#general-performance-observations + return sz_order_serial(a, a_length, b, b_length); +} SZ_PUBLIC sz_bool_t sz_equal_haswell(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { sz_u256_vec_t a_vec, b_vec; @@ -823,203 +150,6 @@ SZ_PUBLIC sz_bool_t sz_equal_haswell(sz_cptr_t a, sz_cptr_t b, sz_size_t length) return sz_true_k; } -SZ_PUBLIC sz_cptr_t sz_find_byte_haswell(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - int mask; - sz_u256_vec_t h_vec, n_vec; - n_vec.ymm = _mm256_set1_epi8(n[0]); - - while (h_length >= 32) { - h_vec.ymm = _mm256_lddqu_si256((__m256i const *)h); - mask = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_vec.ymm, n_vec.ymm)); - if (mask) return h + sz_u32_ctz(mask); - h += 32, h_length -= 32; - } - - return sz_find_byte_serial(h, h_length, n); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_byte_haswell(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - int mask; - sz_u256_vec_t h_vec, n_vec; - n_vec.ymm = _mm256_set1_epi8(n[0]); - - while (h_length >= 32) { - h_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + h_length - 32)); - mask = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_vec.ymm, n_vec.ymm)); - if (mask) return h + h_length - 1 - sz_u32_clz(mask); - h_length -= 32; - } - - return sz_rfind_byte_serial(h, h_length, n); -} - -SZ_PUBLIC sz_cptr_t sz_find_haswell(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_find_byte_haswell(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into YMM registers. 
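Circling back to the `sz_order_serial` hunk above: the byte-reversal before comparing 8-byte chunks is what keeps the comparison lexicographic on little-endian machines, where the first byte of the string lands in the least significant position of the loaded word. A small check of that, assuming a little-endian host and using the GCC/Clang `__builtin_bswap64` as a stand-in for `sz_u64_bytes_reverse`:

    #include <assert.h>
    #include <stdint.h>
    #include <string.h>

    int main(void) {
        /* "abcdefgz" precedes "bbcdefga" lexicographically, but the raw little-endian
           integers compare the other way around, because the last byte is the most
           significant one; reversing the bytes restores the expected order. */
        char a[] = "abcdefgz", b[] = "bbcdefga";
        uint64_t a_word, b_word;
        memcpy(&a_word, a, 8);
        memcpy(&b_word, b, 8);
        assert(memcmp(a, b, 8) < 0);
        assert(a_word > b_word);
        assert(__builtin_bswap64(a_word) < __builtin_bswap64(b_word));
        return 0;
    }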
- int matches; - sz_u256_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; - n_first_vec.ymm = _mm256_set1_epi8(n[offset_first]); - n_mid_vec.ymm = _mm256_set1_epi8(n[offset_mid]); - n_last_vec.ymm = _mm256_set1_epi8(n[offset_last]); - - // Scan through the string. - for (; h_length >= n_length + 32; h += 32, h_length -= 32) { - h_first_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_first)); - h_mid_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_mid)); - h_last_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_last)); - matches = // - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_first_vec.ymm, n_first_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_mid_vec.ymm, n_mid_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_last_vec.ymm, n_last_vec.ymm)); - while (matches) { - int potential_offset = sz_u32_ctz(matches); - if (sz_equal_haswell(h + potential_offset, n, n_length)) return h + potential_offset; - matches &= matches - 1; - } - } - - return sz_find_serial(h, h_length, n, n_length); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_haswell(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_rfind_byte_haswell(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into YMM registers. - int matches; - sz_u256_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; - n_first_vec.ymm = _mm256_set1_epi8(n[offset_first]); - n_mid_vec.ymm = _mm256_set1_epi8(n[offset_mid]); - n_last_vec.ymm = _mm256_set1_epi8(n[offset_last]); - - // Scan through the string. - sz_cptr_t h_reversed; - for (; h_length >= n_length + 32; h_length -= 32) { - h_reversed = h + h_length - n_length - 32 + 1; - h_first_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_first)); - h_mid_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_mid)); - h_last_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_last)); - matches = // - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_first_vec.ymm, n_first_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_mid_vec.ymm, n_mid_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_last_vec.ymm, n_last_vec.ymm)); - while (matches) { - int potential_offset = sz_u32_clz(matches); - if (sz_equal_haswell(h + h_length - n_length - potential_offset, n, n_length)) - return h + h_length - n_length - potential_offset; - matches &= ~(1 << (31 - potential_offset)); - } - } - - return sz_rfind_serial(h, h_length, n, n_length); -} - -SZ_PUBLIC sz_cptr_t sz_find_charset_haswell(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { - - // Let's unzip even and odd elements and replicate them into both lanes of the YMM register. - // That way when we invoke `_mm256_shuffle_epi8` we can use the same mask for both lanes. 
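The `while (matches)` loops above use two small bit tricks worth spelling out: count-trailing-zeros turns the lowest set bit of a movemask into a byte offset, and `matches &= matches - 1` clears that bit to move on to the next candidate. A scalar illustration of the same iteration, using the GCC/Clang builtin:

    #include <stdint.h>

    /* Visits every candidate offset encoded in a 32-bit match mask, lowest first. */
    static int visit_candidates(uint32_t matches) {
        int visited = 0;
        while (matches) {
            int offset = __builtin_ctz(matches); /* lowest set bit = earliest candidate byte */
            (void)offset;                        /* a real caller verifies the full needle at this offset */
            matches &= matches - 1;              /* clear the lowest set bit */
            ++visited;
        }
        return visited;
    }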
- sz_u256_vec_t filter_even_vec, filter_odd_vec; - for (sz_size_t i = 0; i != 16; ++i) - filter_even_vec.u8s[i] = filter->_u8s[i * 2], filter_odd_vec.u8s[i] = filter->_u8s[i * 2 + 1]; - filter_even_vec.xmms[1] = filter_even_vec.xmms[0]; - filter_odd_vec.xmms[1] = filter_odd_vec.xmms[0]; - - sz_u256_vec_t text_vec; - sz_u256_vec_t matches_vec; - sz_u256_vec_t lower_nibbles_vec, higher_nibbles_vec; - sz_u256_vec_t bitset_even_vec, bitset_odd_vec; - sz_u256_vec_t bitmask_vec, bitmask_lookup_vec; - bitmask_lookup_vec.ymm = _mm256_set_epi8( // - -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // - -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1); - - while (length >= 32) { - // The following algorithm is a transposed equivalent of the "SIMD-ized check which bytes are in a set" - // solutions by Wojciech Muła. We populate the bitmask differently and target newer CPUs, so - // StrinZilla uses a somewhat different approach. - // http://0x80.pl/articles/simd-byte-lookup.html#alternative-implementation-new - // - // sz_u8_t input = *(sz_u8_t const *)text; - // sz_u8_t lo_nibble = input & 0x0f; - // sz_u8_t hi_nibble = input >> 4; - // sz_u8_t bitset_even = filter_even_vec.u8s[hi_nibble]; - // sz_u8_t bitset_odd = filter_odd_vec.u8s[hi_nibble]; - // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); - // sz_u8_t bitset = lo_nibble < 8 ? bitset_even : bitset_odd; - // if ((bitset & bitmask) != 0) return text; - // else { length--, text++; } - // - // The nice part about this, loading the strided data is vey easy with Arm NEON, - // while with x86 CPUs after AVX, shuffles within 256 bits shouldn't be an issue either. - text_vec.ymm = _mm256_lddqu_si256((__m256i const *)text); - lower_nibbles_vec.ymm = _mm256_and_si256(text_vec.ymm, _mm256_set1_epi8(0x0f)); - bitmask_vec.ymm = _mm256_shuffle_epi8(bitmask_lookup_vec.ymm, lower_nibbles_vec.ymm); - // - // At this point we can validate the `bitmask_vec` contents like this: - // - // for (sz_size_t i = 0; i != 32; ++i) { - // sz_u8_t input = *(sz_u8_t const *)(text + i); - // sz_u8_t lo_nibble = input & 0x0f; - // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); - // sz_assert(bitmask_vec.u8s[i] == bitmask); - // } - // - // Shift right every byte by 4 bits. - // There is no `_mm256_srli_epi8` intrinsic, so we have to use `_mm256_srli_epi16` - // and combine it with a mask to clear the higher bits. - higher_nibbles_vec.ymm = _mm256_and_si256(_mm256_srli_epi16(text_vec.ymm, 4), _mm256_set1_epi8(0x0f)); - bitset_even_vec.ymm = _mm256_shuffle_epi8(filter_even_vec.ymm, higher_nibbles_vec.ymm); - bitset_odd_vec.ymm = _mm256_shuffle_epi8(filter_odd_vec.ymm, higher_nibbles_vec.ymm); - // - // At this point we can validate the `bitset_even_vec` and `bitset_odd_vec` contents like this: - // - // for (sz_size_t i = 0; i != 32; ++i) { - // sz_u8_t input = *(sz_u8_t const *)(text + i); - // sz_u8_t const *bitset_ptr = &filter->_u8s[0]; - // sz_u8_t hi_nibble = input >> 4; - // sz_u8_t bitset_even = bitset_ptr[hi_nibble * 2]; - // sz_u8_t bitset_odd = bitset_ptr[hi_nibble * 2 + 1]; - // sz_assert(bitset_even_vec.u8s[i] == bitset_even); - // sz_assert(bitset_odd_vec.u8s[i] == bitset_odd); - // } - // - __m256i take_first = _mm256_cmpgt_epi8(_mm256_set1_epi8(8), lower_nibbles_vec.ymm); - bitset_even_vec.ymm = _mm256_blendv_epi8(bitset_odd_vec.ymm, bitset_even_vec.ymm, take_first); - - // It would have been great to have an instruction that tests the bits and then broadcasts - // the matching bit into all bits in that byte. 
But we don't have that, so we have to - // `and`, `cmpeq`, `movemask`, and then invert at the end... - matches_vec.ymm = _mm256_and_si256(bitset_even_vec.ymm, bitmask_vec.ymm); - matches_vec.ymm = _mm256_cmpeq_epi8(matches_vec.ymm, _mm256_setzero_si256()); - int matches_mask = ~_mm256_movemask_epi8(matches_vec.ymm); - if (matches_mask) { - int offset = sz_u32_ctz(matches_mask); - return text + offset; - } - else { text += 32, length -= 32; } - } - - return sz_find_charset_serial(text, length, filter); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_charset_haswell(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { - return sz_rfind_charset_serial(text, length, filter); -} - #pragma clang attribute pop #pragma GCC pop_options #endif // SZ_USE_HASWELL @@ -1036,6 +166,69 @@ SZ_PUBLIC sz_cptr_t sz_rfind_charset_haswell(sz_cptr_t text, sz_size_t length, s #pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "bmi", "bmi2") #pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,bmi,bmi2"))), apply_to = function) +SZ_PUBLIC sz_ordering_t sz_order_skylake(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { + sz_u512_vec_t a_vec, b_vec; + + // Pointer arithmetic is cheap, fetching memory is not! + // So we can use the masked loads to fetch at most one cache-line for each string, + // compare the prefixes, and only then move forward. + sz_size_t a_head_length = 64 - ((sz_size_t)a % 64); // 63 or less. + sz_size_t b_head_length = 64 - ((sz_size_t)b % 64); // 63 or less. + a_head_length = a_head_length < a_length ? a_head_length : a_length; + b_head_length = b_head_length < b_length ? b_head_length : b_length; + sz_size_t head_length = a_head_length < b_head_length ? a_head_length : b_head_length; + __mmask64 head_mask = _sz_u64_mask_until(head_length); + a_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, a); + b_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, b); + __mmask64 mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); + if (mask_not_equal != 0) { + sz_u64_t first_diff = _tzcnt_u64(mask_not_equal); + char a_char = a_vec.u8s[first_diff]; + char b_char = b_vec.u8s[first_diff]; + return _sz_order_scalars(a_char, b_char); + } + else if (head_length == a_length && head_length == b_length) { return sz_equal_k; } + else { a += head_length, b += head_length, a_length -= head_length, b_length -= head_length; } + + // The rare case, when both string are very long. + __mmask64 a_mask, b_mask; + while ((a_length >= 64) & (b_length >= 64)) { + a_vec.zmm = _mm512_loadu_si512(a); + b_vec.zmm = _mm512_loadu_si512(b); + mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); + if (mask_not_equal != 0) { + sz_u64_t first_diff = _tzcnt_u64(mask_not_equal); + char a_char = a_vec.u8s[first_diff]; + char b_char = b_vec.u8s[first_diff]; + return _sz_order_scalars(a_char, b_char); + } + a += 64, b += 64, a_length -= 64, b_length -= 64; + } + + // In most common scenarios at least one of the strings is under 64 bytes. + if (a_length | b_length) { + a_mask = _sz_u64_clamp_mask_until(a_length); + b_mask = _sz_u64_clamp_mask_until(b_length); + a_vec.zmm = _mm512_maskz_loadu_epi8(a_mask, a); + b_vec.zmm = _mm512_maskz_loadu_epi8(b_mask, b); + // The AVX-512 `_mm512_mask_cmpneq_epi8_mask` intrinsics are generally handy in such environments. + // They, however, have latency 3 on most modern CPUs. Using AVX2: `_mm256_cmpeq_epi8` would have + // been cheaper, if we didn't have to apply `_mm256_movemask_epi8` afterwards. 
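The head handling above (and the tail handling that follows) leans on `_sz_u64_mask_until`, which is not shown in this hunk; it presumably yields a 64-bit mask with the lowest `n` bits set, so a masked AVX-512 load touches only the first `n` bytes and zero-fills the remaining lanes. A plausible scalar equivalent, purely illustrative:

    #include <stdint.h>

    /* Presumed behavior of `_sz_u64_mask_until(n)` for n in [0, 64]. */
    static uint64_t mask_until(uint64_t n) {
        return n < 64 ? (1ull << n) - 1 : ~0ull;
    }

    /* mask_until(5) == 0x1F, so `_mm512_maskz_loadu_epi8(mask, ptr)` would read
       exactly 5 bytes starting at `ptr` and leave the other 59 lanes zeroed. */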
+ mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); + if (mask_not_equal != 0) { + sz_u64_t first_diff = _tzcnt_u64(mask_not_equal); + char a_char = a_vec.u8s[first_diff]; + char b_char = b_vec.u8s[first_diff]; + return _sz_order_scalars(a_char, b_char); + } + // From logic perspective, the hardest cases are "abc\0" and "abc". + // The result must be `sz_greater_k`, as the latter is shorter. + else { return _sz_order_scalars(a_length, b_length); } + } + + return sz_equal_k; +} + SZ_PUBLIC sz_bool_t sz_equal_skylake(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { __mmask64 mask; sz_u512_vec_t a_vec, b_vec; @@ -1060,217 +253,6 @@ SZ_PUBLIC sz_bool_t sz_equal_skylake(sz_cptr_t a, sz_cptr_t b, sz_size_t length) return sz_true_k; } -SZ_PUBLIC sz_cptr_t sz_find_byte_skylake(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - __mmask64 mask; - sz_u512_vec_t h_vec, n_vec; - n_vec.zmm = _mm512_set1_epi8(n[0]); - - while (h_length >= 64) { - h_vec.zmm = _mm512_loadu_si512(h); - mask = _mm512_cmpeq_epi8_mask(h_vec.zmm, n_vec.zmm); - if (mask) return h + sz_u64_ctz(mask); - h += 64, h_length -= 64; - } - - if (h_length) { - mask = _sz_u64_mask_until(h_length); - h_vec.zmm = _mm512_maskz_loadu_epi8(mask, h); - // Reuse the same `mask` variable to find the bit that doesn't match - mask = _mm512_mask_cmpeq_epu8_mask(mask, h_vec.zmm, n_vec.zmm); - if (mask) return h + sz_u64_ctz(mask); - } - - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_find_skylake(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_find_byte_skylake(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into ZMM registers. - __mmask64 matches; - __mmask64 mask; - sz_u512_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; - n_first_vec.zmm = _mm512_set1_epi8(n[offset_first]); - n_mid_vec.zmm = _mm512_set1_epi8(n[offset_mid]); - n_last_vec.zmm = _mm512_set1_epi8(n[offset_last]); - - // Scan through the string. - // We have several optimized versions of the algorithm for shorter strings, - // but they all mimic the default case for unbounded length needles - if (n_length >= 64) { - for (; h_length >= n_length + 64; h += 64, h_length -= 64) { - h_first_vec.zmm = _mm512_loadu_si512(h + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h + offset_last); - matches = _kand_mask64( // - _kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_ctz(matches); - if (sz_equal_skylake(h + potential_offset, n, n_length)) return h + potential_offset; - matches &= matches - 1; - } - - // TODO: If the last character contains a bad byte, we can reposition the start of the next iteration. - // This will be very helpful for very long needles. - } - } - // If there are only 2 or 3 characters in the needle, we don't even need the nested loop. 
- else if (n_length <= 3) { - for (; h_length >= n_length + 64; h += 64, h_length -= 64) { - h_first_vec.zmm = _mm512_loadu_si512(h + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h + offset_last); - matches = _kand_mask64( // - _kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - if (matches) return h + sz_u64_ctz(matches); - } - } - // If the needle is smaller than the size of the ZMM register, we can use masked comparisons - // to avoid the the inner-most nested loop and compare the entire needle against a haystack - // slice in 3 CPU cycles. - else { - __mmask64 n_mask = _sz_u64_mask_until(n_length); - sz_u512_vec_t n_full_vec, h_full_vec; - n_full_vec.zmm = _mm512_maskz_loadu_epi8(n_mask, n); - for (; h_length >= n_length + 64; h += 64, h_length -= 64) { - h_first_vec.zmm = _mm512_loadu_si512(h + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h + offset_last); - matches = _kand_mask64( // - _kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_ctz(matches); - h_full_vec.zmm = _mm512_maskz_loadu_epi8(n_mask, h + potential_offset); - if (_mm512_mask_cmpneq_epi8_mask(n_mask, h_full_vec.zmm, n_full_vec.zmm) == 0) - return h + potential_offset; - matches &= matches - 1; - } - } - } - - // The "tail" of the function uses masked loads to process the remaining bytes. - { - mask = _sz_u64_mask_until(h_length - n_length + 1); - h_first_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_first); - h_mid_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_mid); - h_last_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_last); - matches = _kand_mask64( // - _kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_ctz(matches); - if (n_length <= 3 || sz_equal_skylake(h + potential_offset, n, n_length)) return h + potential_offset; - matches &= matches - 1; - } - } - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_byte_skylake(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - __mmask64 mask; - sz_u512_vec_t h_vec, n_vec; - n_vec.zmm = _mm512_set1_epi8(n[0]); - - while (h_length >= 64) { - h_vec.zmm = _mm512_loadu_si512(h + h_length - 64); - mask = _mm512_cmpeq_epi8_mask(h_vec.zmm, n_vec.zmm); - if (mask) return h + h_length - 1 - sz_u64_clz(mask); - h_length -= 64; - } - - if (h_length) { - mask = _sz_u64_mask_until(h_length); - h_vec.zmm = _mm512_maskz_loadu_epi8(mask, h); - // Reuse the same `mask` variable to find the bit that doesn't match - mask = _mm512_mask_cmpeq_epu8_mask(mask, h_vec.zmm, n_vec.zmm); - if (mask) return h + 64 - sz_u64_clz(mask) - 1; - } - - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_skylake(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. 
- if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_rfind_byte_skylake(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into ZMM registers. - __mmask64 mask; - __mmask64 matches; - sz_u512_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; - n_first_vec.zmm = _mm512_set1_epi8(n[offset_first]); - n_mid_vec.zmm = _mm512_set1_epi8(n[offset_mid]); - n_last_vec.zmm = _mm512_set1_epi8(n[offset_last]); - - // Scan through the string. - sz_cptr_t h_reversed; - for (; h_length >= n_length + 64; h_length -= 64) { - h_reversed = h + h_length - n_length - 64 + 1; - h_first_vec.zmm = _mm512_loadu_si512(h_reversed + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h_reversed + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h_reversed + offset_last); - matches = _kand_mask64( // - _kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_clz(matches); - if (n_length <= 3 || sz_equal_skylake(h + h_length - n_length - potential_offset, n, n_length)) - return h + h_length - n_length - potential_offset; - sz_assert((matches & ((sz_u64_t)1 << (63 - potential_offset))) != 0 && - "The bit must be set before we squash it"); - matches &= ~((sz_u64_t)1 << (63 - potential_offset)); - } - } - - // The "tail" of the function uses masked loads to process the remaining bytes. - { - mask = _sz_u64_mask_until(h_length - n_length + 1); - h_first_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_first); - h_mid_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_mid); - h_last_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_last); - matches = _kand_mask64( // - _kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_clz(matches); - if (n_length <= 3 || sz_equal_skylake(h + 64 - potential_offset - 1, n, n_length)) - return h + 64 - potential_offset - 1; - sz_assert((matches & ((sz_u64_t)1 << (63 - potential_offset))) != 0 && - "The bit must be set before we squash it"); - matches &= ~((sz_u64_t)1 << (63 - potential_offset)); - } - } - - return SZ_NULL_CHAR; -} - #pragma clang attribute pop #pragma GCC pop_options #endif // SZ_USE_SKYLAKE @@ -1289,124 +271,7 @@ SZ_PUBLIC sz_cptr_t sz_rfind_skylake(sz_cptr_t h, sz_size_t h_length, sz_cptr_t #pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512dq,avx512vbmi,bmi,bmi2"))), \ apply_to = function) -SZ_PUBLIC sz_cptr_t sz_find_charset_ice(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { - - // Before initializing the AVX-512 vectors, we may want to run the sequential code for the first few bytes. - // In practice, that only hurts, even when we have matches every 5-ish bytes. 
- // - // if (length < SZ_SWAR_THRESHOLD) return sz_find_charset_serial(text, length, filter); - // sz_cptr_t early_result = sz_find_charset_serial(text, SZ_SWAR_THRESHOLD, filter); - // if (early_result) return early_result; - // text += SZ_SWAR_THRESHOLD; - // length -= SZ_SWAR_THRESHOLD; - // - // Let's unzip even and odd elements and replicate them into both lanes of the YMM register. - // That way when we invoke `_mm512_shuffle_epi8` we can use the same mask for both lanes. - sz_u512_vec_t filter_even_vec, filter_odd_vec; - __m256i filter_ymm = _mm256_lddqu_si256((__m256i const *)filter); - // There are a few way to initialize filters without having native strided loads. - // In the cronological order of experiments: - // - serial code initializing 128 bytes of odd and even mask - // - using several shuffles - // - using `_mm512_permutexvar_epi8` - // - using `_mm512_broadcast_i32x4(_mm256_castsi256_si128(_mm256_maskz_compress_epi8(0x55555555, filter_ymm)))` - // and `_mm512_broadcast_i32x4(_mm256_castsi256_si128(_mm256_maskz_compress_epi8(0xaaaaaaaa, filter_ymm)))` - filter_even_vec.zmm = _mm512_broadcast_i32x4(_mm256_castsi256_si128( // broadcast __m128i to __m512i - _mm256_maskz_compress_epi8(0x55555555, filter_ymm))); - filter_odd_vec.zmm = _mm512_broadcast_i32x4(_mm256_castsi256_si128( // broadcast __m128i to __m512i - _mm256_maskz_compress_epi8(0xaaaaaaaa, filter_ymm))); - // After the unzipping operation, we can validate the contents of the vectors like this: - // - // for (sz_size_t i = 0; i != 16; ++i) { - // sz_assert(filter_even_vec.u8s[i] == filter->_u8s[i * 2]); - // sz_assert(filter_odd_vec.u8s[i] == filter->_u8s[i * 2 + 1]); - // sz_assert(filter_even_vec.u8s[i + 16] == filter->_u8s[i * 2]); - // sz_assert(filter_odd_vec.u8s[i + 16] == filter->_u8s[i * 2 + 1]); - // sz_assert(filter_even_vec.u8s[i + 32] == filter->_u8s[i * 2]); - // sz_assert(filter_odd_vec.u8s[i + 32] == filter->_u8s[i * 2 + 1]); - // sz_assert(filter_even_vec.u8s[i + 48] == filter->_u8s[i * 2]); - // sz_assert(filter_odd_vec.u8s[i + 48] == filter->_u8s[i * 2 + 1]); - // } - // - sz_u512_vec_t text_vec; - sz_u512_vec_t lower_nibbles_vec, higher_nibbles_vec; - sz_u512_vec_t bitset_even_vec, bitset_odd_vec; - sz_u512_vec_t bitmask_vec, bitmask_lookup_vec; - bitmask_lookup_vec.zmm = _mm512_set_epi8( // - -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // - -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // - -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // - -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1); - - while (length) { - // The following algorithm is a transposed equivalent of the "SIMDized check which bytes are in a set" - // solutions by Wojciech Muła. We populate the bitmask differently and target newer CPUs, so - // StrinZilla uses a somewhat different approach. - // http://0x80.pl/articles/simd-byte-lookup.html#alternative-implementation-new - // - // sz_u8_t input = *(sz_u8_t const *)text; - // sz_u8_t lo_nibble = input & 0x0f; - // sz_u8_t hi_nibble = input >> 4; - // sz_u8_t bitset_even = filter_even_vec.u8s[hi_nibble]; - // sz_u8_t bitset_odd = filter_odd_vec.u8s[hi_nibble]; - // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); - // sz_u8_t bitset = lo_nibble < 8 ? bitset_even : bitset_odd; - // if ((bitset & bitmask) != 0) return text; - // else { length--, text++; } - // - // The nice part about this, loading the strided data is vey easy with Arm NEON, - // while with x86 CPUs after AVX, shuffles within 256 bits shouldn't be an issue either. 
- sz_size_t load_length = sz_min_of_two(length, 64); - __mmask64 load_mask = _sz_u64_mask_until(load_length); - text_vec.zmm = _mm512_maskz_loadu_epi8(load_mask, text); - lower_nibbles_vec.zmm = _mm512_and_si512(text_vec.zmm, _mm512_set1_epi8(0x0f)); - bitmask_vec.zmm = _mm512_shuffle_epi8(bitmask_lookup_vec.zmm, lower_nibbles_vec.zmm); - // - // At this point we can validate the `bitmask_vec` contents like this: - // - // for (sz_size_t i = 0; i != load_length; ++i) { - // sz_u8_t input = *(sz_u8_t const *)(text + i); - // sz_u8_t lo_nibble = input & 0x0f; - // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); - // sz_assert(bitmask_vec.u8s[i] == bitmask); - // } - // - // Shift right every byte by 4 bits. - // There is no `_mm512_srli_epi8` intrinsic, so we have to use `_mm512_srli_epi16` - // and combine it with a mask to clear the higher bits. - higher_nibbles_vec.zmm = _mm512_and_si512(_mm512_srli_epi16(text_vec.zmm, 4), _mm512_set1_epi8(0x0f)); - bitset_even_vec.zmm = _mm512_shuffle_epi8(filter_even_vec.zmm, higher_nibbles_vec.zmm); - bitset_odd_vec.zmm = _mm512_shuffle_epi8(filter_odd_vec.zmm, higher_nibbles_vec.zmm); - // - // At this point we can validate the `bitset_even_vec` and `bitset_odd_vec` contents like this: - // - // for (sz_size_t i = 0; i != load_length; ++i) { - // sz_u8_t input = *(sz_u8_t const *)(text + i); - // sz_u8_t const *bitset_ptr = &filter->_u8s[0]; - // sz_u8_t hi_nibble = input >> 4; - // sz_u8_t bitset_even = bitset_ptr[hi_nibble * 2]; - // sz_u8_t bitset_odd = bitset_ptr[hi_nibble * 2 + 1]; - // sz_assert(bitset_even_vec.u8s[i] == bitset_even); - // sz_assert(bitset_odd_vec.u8s[i] == bitset_odd); - // } - // - // TODO: Is this a good place for ternary logic? - __mmask64 take_first = _mm512_cmplt_epi8_mask(lower_nibbles_vec.zmm, _mm512_set1_epi8(8)); - bitset_even_vec.zmm = _mm512_mask_blend_epi8(take_first, bitset_odd_vec.zmm, bitset_even_vec.zmm); - __mmask64 matches_mask = _mm512_mask_test_epi8_mask(load_mask, bitset_even_vec.zmm, bitmask_vec.zmm); - if (matches_mask) { - int offset = sz_u64_ctz(matches_mask); - return text + offset; - } - else { text += load_length, length -= load_length; } - } - - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_charset_ice(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { - return sz_rfind_charset_serial(text, length, filter); -} +/* Nothing here for now. */ #pragma clang attribute pop #pragma GCC pop_options @@ -1422,10 +287,10 @@ SZ_PUBLIC sz_cptr_t sz_rfind_charset_ice(sz_cptr_t text, sz_size_t length, sz_ch #pragma GCC target("arch=armv8.2-a+simd") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd"))), apply_to = function) -SZ_INTERNAL sz_u64_t _sz_vreinterpretq_u8_u4(uint8x16_t vec) { - // Use `vshrn` to produce a bitmask, similar to `movemask` in SSE. - // https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon - return vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(vec), 4)), 0) & 0x8888888888888888ull; +SZ_PUBLIC sz_ordering_t sz_order_neon(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { + //! Before optimizing this, read the "Operations Not Worth Optimizing" in Contributions Guide: + //! 
https://github.com/ashvardanian/StringZilla/blob/main/CONTRIBUTING.md#general-performance-observations + return sz_order_serial(a, a_length, b, b_length); } SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { @@ -1442,215 +307,6 @@ SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { return sz_true_k; } -SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - sz_u64_t matches; - sz_u128_vec_t h_vec, n_vec, matches_vec; - n_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)n); - - while (h_length >= 16) { - h_vec.u8x16 = vld1q_u8((sz_u8_t const *)h); - matches_vec.u8x16 = vceqq_u8(h_vec.u8x16, n_vec.u8x16); - // In Arm NEON we don't have a `movemask` to combine it with `ctz` and get the offset of the match. - // But assuming the `vmaxvq` is cheap, we can use it to find the first match, by blending (bitwise selecting) - // the vector with a relative offsets array. - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - if (matches) return h + sz_u64_ctz(matches) / 4; - - h += 16, h_length -= 16; - } - - return sz_find_byte_serial(h, h_length, n); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_byte_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - sz_u64_t matches; - sz_u128_vec_t h_vec, n_vec, matches_vec; - n_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)n); - - while (h_length >= 16) { - h_vec.u8x16 = vld1q_u8((sz_u8_t const *)h + h_length - 16); - matches_vec.u8x16 = vceqq_u8(h_vec.u8x16, n_vec.u8x16); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - if (matches) return h + h_length - 1 - sz_u64_clz(matches) / 4; - h_length -= 16; - } - - return sz_rfind_byte_serial(h, h_length, n); -} - -SZ_PUBLIC sz_u64_t _sz_find_charset_neon_register( // - sz_u128_vec_t h_vec, uint8x16_t set_top_vec_u8x16, uint8x16_t set_bottom_vec_u8x16) { - - // Once we've read the characters in the haystack, we want to - // compare them against our bitset. The serial version of that code - // would look like: `(set_->_u8s[c >> 3] & (1u << (c & 7u))) != 0`. - uint8x16_t byte_index_vec = vshrq_n_u8(h_vec.u8x16, 3); - uint8x16_t byte_mask_vec = vshlq_u8(vdupq_n_u8(1), vreinterpretq_s8_u8(vandq_u8(h_vec.u8x16, vdupq_n_u8(7)))); - uint8x16_t matches_top_vec = vqtbl1q_u8(set_top_vec_u8x16, byte_index_vec); - // The table lookup instruction in NEON replies to out-of-bound requests with zeros. - // The values in `byte_index_vec` all fall in [0; 32). So for values under 16, substracting 16 will underflow - // and map into interval [240, 256). Meaning that those will be populated with zeros and we can safely - // merge `matches_top_vec` and `matches_bottom_vec` with a bitwise OR. - uint8x16_t matches_bottom_vec = vqtbl1q_u8(set_bottom_vec_u8x16, vsubq_u8(byte_index_vec, vdupq_n_u8(16))); - uint8x16_t matches_vec = vorrq_u8(matches_top_vec, matches_bottom_vec); - // Istead of pure `vandq_u8`, we can immediately broadcast a match presence across each 8-bit word. - matches_vec = vtstq_u8(matches_vec, byte_mask_vec); - return _sz_vreinterpretq_u8_u4(matches_vec); -} - -SZ_PUBLIC sz_cptr_t sz_find_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_find_byte_neon(h, h_length, n); - - // Scan through the string. - // Assuming how tiny the Arm NEON registers are, we should avoid internal branches at all costs. - // That's why, for smaller needles, we use different loops. 
- if (n_length == 2) { - // Broadcast needle characters into SIMD registers. - sz_u64_t matches; - sz_u128_vec_t h_first_vec, h_last_vec, n_first_vec, n_last_vec, matches_vec; - // Dealing with 16-bit values, we can load 2 registers at a time and compare 31 possible offsets - // in a single loop iteration. - n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[0]); - n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[1]); - for (; h_length >= 17; h += 16, h_length -= 16) { - h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 0)); - h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 1)); - matches_vec.u8x16 = - vandq_u8(vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - if (matches) return h + sz_u64_ctz(matches) / 4; - } - } - else if (n_length == 3) { - // Broadcast needle characters into SIMD registers. - sz_u64_t matches; - sz_u128_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec, matches_vec; - // Comparing 24-bit values is a bumer. Being lazy, I went with the same approach - // as when searching for string over 4 characters long. I only avoid the last comparison. - n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[0]); - n_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[1]); - n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[2]); - for (; h_length >= 18; h += 16, h_length -= 16) { - h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 0)); - h_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 1)); - h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 2)); - matches_vec.u8x16 = vandq_u8( // - vandq_u8( // - vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), // - vceqq_u8(h_mid_vec.u8x16, n_mid_vec.u8x16)), - vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - if (matches) return h + sz_u64_ctz(matches) / 4; - } - } - else { - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - // Broadcast those characters into SIMD registers. - sz_u64_t matches; - sz_u128_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec, matches_vec; - n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_first]); - n_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_mid]); - n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_last]); - // Walk through the string. - for (; h_length >= n_length + 16; h += 16, h_length -= 16) { - h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_first)); - h_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_mid)); - h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_last)); - matches_vec.u8x16 = vandq_u8( // - vandq_u8( // - vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), // - vceqq_u8(h_mid_vec.u8x16, n_mid_vec.u8x16)), - vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - while (matches) { - int potential_offset = sz_u64_ctz(matches) / 4; - if (sz_equal_neon(h + potential_offset, n, n_length)) return h + potential_offset; - matches &= matches - 1; - } - } - } - - return sz_find_serial(h, h_length, n, n_length); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. 
- if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_rfind_byte_neon(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Will contain 4 bits per character. - sz_u64_t matches; - sz_u128_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec, matches_vec; - n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_first]); - n_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_mid]); - n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_last]); - - sz_cptr_t h_reversed; - for (; h_length >= n_length + 16; h_length -= 16) { - h_reversed = h + h_length - n_length - 16 + 1; - h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h_reversed + offset_first)); - h_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h_reversed + offset_mid)); - h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h_reversed + offset_last)); - matches_vec.u8x16 = vandq_u8( // - vandq_u8( // - vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), // - vceqq_u8(h_mid_vec.u8x16, n_mid_vec.u8x16)), - vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - while (matches) { - int potential_offset = sz_u64_clz(matches) / 4; - if (sz_equal_neon(h + h_length - n_length - potential_offset, n, n_length)) - return h + h_length - n_length - potential_offset; - sz_assert((matches & (1ull << (63 - potential_offset * 4))) != 0 && - "The bit must be set before we squash it"); - matches &= ~(1ull << (63 - potential_offset * 4)); - } - } - - return sz_rfind_serial(h, h_length, n, n_length); -} - -SZ_PUBLIC sz_cptr_t sz_find_charset_neon(sz_cptr_t h, sz_size_t h_length, sz_charset_t const *set) { - sz_u64_t matches; - sz_u128_vec_t h_vec; - uint8x16_t set_top_vec_u8x16 = vld1q_u8(&set->_u8s[0]); - uint8x16_t set_bottom_vec_u8x16 = vld1q_u8(&set->_u8s[16]); - - for (; h_length >= 16; h += 16, h_length -= 16) { - h_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h)); - matches = _sz_find_charset_neon_register(h_vec, set_top_vec_u8x16, set_bottom_vec_u8x16); - if (matches) return h + sz_u64_ctz(matches) / 4; - } - - return sz_find_charset_serial(h, h_length, set); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t h, sz_size_t h_length, sz_charset_t const *set) { - sz_u64_t matches; - sz_u128_vec_t h_vec; - uint8x16_t set_top_vec_u8x16 = vld1q_u8(&set->_u8s[0]); - uint8x16_t set_bottom_vec_u8x16 = vld1q_u8(&set->_u8s[16]); - - // Check `sz_find_charset_neon` for explanations. - for (; h_length >= 16; h_length -= 16) { - h_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h) + h_length - 16); - matches = _sz_find_charset_neon_register(h_vec, set_top_vec_u8x16, set_bottom_vec_u8x16); - if (matches) return h + h_length - 1 - sz_u64_clz(matches) / 4; - } - - return sz_rfind_charset_serial(h, h_length, set); -} - #pragma clang attribute pop #pragma GCC pop_options #endif // SZ_USE_NEON @@ -1665,6 +321,8 @@ SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t h, sz_size_t h_length, sz_ch #pragma GCC target("arch=armv8.2-a+sve") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve"))), apply_to = function) +/* Nothing here for now. 
*/ + #pragma clang attribute pop #pragma GCC pop_options #endif // SZ_USE_SVE @@ -1676,118 +334,34 @@ SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t h, sz_size_t h_length, sz_ch #pragma region Compile Time Dispatching #if !SZ_DYNAMIC_DISPATCH -#pragma region Core Funcitonality - -SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) { -#if SZ_USE_SKYLAKE - return sz_find_byte_skylake(haystack, h_length, needle); -#elif SZ_USE_HASWELL - return sz_find_byte_haswell(haystack, h_length, needle); -#elif SZ_USE_NEON - return sz_find_byte_neon(haystack, h_length, needle); -#else - return sz_find_byte_serial(haystack, h_length, needle); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_rfind_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) { -#if SZ_USE_SKYLAKE - return sz_rfind_byte_skylake(haystack, h_length, needle); -#elif SZ_USE_HASWELL - return sz_rfind_byte_haswell(haystack, h_length, needle); -#elif SZ_USE_NEON - return sz_rfind_byte_neon(haystack, h_length, needle); -#else - return sz_rfind_byte_serial(haystack, h_length, needle); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_find(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length) { +SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { #if SZ_USE_SKYLAKE - return sz_find_skylake(haystack, h_length, needle, n_length); + return sz_equal_skylake(a, b, length); #elif SZ_USE_HASWELL - return sz_find_haswell(haystack, h_length, needle, n_length); + return sz_equal_haswell(a, b, length); #elif SZ_USE_NEON - return sz_find_neon(haystack, h_length, needle, n_length); + return sz_equal_neon(a, b, length); #else - return sz_find_serial(haystack, h_length, needle, n_length); + return sz_equal_serial(a, b, length); #endif } -SZ_DYNAMIC sz_cptr_t sz_rfind(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length) { +SZ_DYNAMIC sz_ordering_t sz_order(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { #if SZ_USE_SKYLAKE - return sz_rfind_skylake(haystack, h_length, needle, n_length); + return sz_order_skylake(a, a_length, b, b_length); #elif SZ_USE_HASWELL - return sz_rfind_haswell(haystack, h_length, needle, n_length); + return sz_order_haswell(a, a_length, b, b_length); #elif SZ_USE_NEON - return sz_rfind_neon(haystack, h_length, needle, n_length); + return sz_order_neon(a, a_length, b, b_length); #else - return sz_rfind_serial(haystack, h_length, needle, n_length); + return sz_order_serial(a, a_length, b, b_length); #endif } -SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { -#if SZ_USE_ICE - return sz_find_charset_ice(text, length, set); -#elif SZ_USE_HASWELL - return sz_find_charset_haswell(text, length, set); -#elif SZ_USE_NEON - return sz_find_charset_neon(text, length, set); -#else - return sz_find_charset_serial(text, length, set); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_rfind_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { -#if SZ_USE_ICE - return sz_rfind_charset_ice(text, length, set); -#elif SZ_USE_HASWELL - return sz_rfind_charset_haswell(text, length, set); -#elif SZ_USE_NEON - return sz_rfind_charset_neon(text, length, set); -#else - return sz_rfind_charset_serial(text, length, set); -#endif -} - -#pragma endregion -#pragma region Helper Shortcuts - -SZ_DYNAMIC sz_cptr_t sz_find_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) 
sz_charset_add(&set, *n); - return sz_find_charset(h, h_length, &set); -} - -SZ_DYNAMIC sz_cptr_t sz_find_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - sz_charset_invert(&set); - return sz_find_charset(h, h_length, &set); -} - -SZ_DYNAMIC sz_cptr_t sz_rfind_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - return sz_rfind_charset(h, h_length, &set); -} - -SZ_DYNAMIC sz_cptr_t sz_rfind_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - sz_charset_invert(&set); - return sz_rfind_charset(h, h_length, &set); -} - -#pragma endregion // Helper Shortcuts #endif // !SZ_DYNAMIC_DISPATCH #pragma endregion // Compile Time Dispatching #ifdef __cplusplus } #endif // __cplusplus -#endif // STRINGZILLA_FIND_H_ +#endif // STRINGZILLA_COMPARE_H_ diff --git a/include/stringzilla/find.h b/include/stringzilla/find.h index 4571515d..91892a0f 100644 --- a/include/stringzilla/find.h +++ b/include/stringzilla/find.h @@ -22,6 +22,8 @@ #include "types.h" +#include "compare.h" // `sz_equal` + #ifdef __cplusplus extern "C" { #endif @@ -194,26 +196,6 @@ SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t haystack, sz_size_t h_length #pragma region Serial Implementation -/** - * @brief Byte-level equality comparison between two strings. - * If unaligned loads are allowed, uses a switch-table to avoid loops on short strings. - */ -SZ_PUBLIC sz_bool_t sz_equal_serial(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - sz_cptr_t const a_end = a + length; -#if SZ_USE_MISALIGNED_LOADS - if (length >= SZ_SWAR_THRESHOLD) { - sz_u64_vec_t a_vec, b_vec; - for (; a + 8 <= a_end; a += 8, b += 8) { - a_vec = sz_u64_load(a); - b_vec = sz_u64_load(b); - if (a_vec.u64 != b_vec.u64) return sz_false_k; - } - } -#endif - while (a != a_end && *a == *b) a++, b++; - return (sz_bool_t)(a_end == a); -} - /** * @brief Chooses the offsets of the most interesting characters in a search needle. * @@ -804,24 +786,8 @@ SZ_PUBLIC sz_cptr_t sz_rfind_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n #pragma region Haswell Implementation #if SZ_USE_HASWELL #pragma GCC push_options -#pragma GCC target("haswell") -#pragma clang attribute push(__attribute__((target("haswell"))), apply_to = function) - -SZ_PUBLIC sz_bool_t sz_equal_haswell(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - sz_u256_vec_t a_vec, b_vec; - - while (length >= 32) { - a_vec.ymm = _mm256_lddqu_si256((__m256i const *)a); - b_vec.ymm = _mm256_lddqu_si256((__m256i const *)b); - // One approach can be to use "movemasks", but we could also use a bitwise matching like `_mm256_testnzc_si256`. 
- int difference_mask = ~_mm256_movemask_epi8(_mm256_cmpeq_epi8(a_vec.ymm, b_vec.ymm)); - if (difference_mask == 0) { a += 32, b += 32, length -= 32; } - else { return sz_false_k; } - } - - if (length) return sz_equal_serial(a, b, length); - return sz_true_k; -} +#pragma GCC target("avx2") +#pragma clang attribute push(__attribute__((target("avx2"))), apply_to = function) SZ_PUBLIC sz_cptr_t sz_find_byte_haswell(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { int mask; @@ -1036,30 +1002,6 @@ SZ_PUBLIC sz_cptr_t sz_rfind_charset_haswell(sz_cptr_t text, sz_size_t length, s #pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "bmi", "bmi2") #pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,bmi,bmi2"))), apply_to = function) -SZ_PUBLIC sz_bool_t sz_equal_skylake(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - __mmask64 mask; - sz_u512_vec_t a_vec, b_vec; - - while (length >= 64) { - a_vec.zmm = _mm512_loadu_si512(a); - b_vec.zmm = _mm512_loadu_si512(b); - mask = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); - if (mask != 0) return sz_false_k; - a += 64, b += 64, length -= 64; - } - - if (length) { - mask = _sz_u64_mask_until(length); - a_vec.zmm = _mm512_maskz_loadu_epi8(mask, a); - b_vec.zmm = _mm512_maskz_loadu_epi8(mask, b); - // Reuse the same `mask` variable to find the bit that doesn't match - mask = _mm512_mask_cmpneq_epi8_mask(mask, a_vec.zmm, b_vec.zmm); - return (sz_bool_t)(mask == 0); - } - - return sz_true_k; -} - SZ_PUBLIC sz_cptr_t sz_find_byte_skylake(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { __mmask64 mask; sz_u512_vec_t h_vec, n_vec; @@ -1428,20 +1370,6 @@ SZ_INTERNAL sz_u64_t _sz_vreinterpretq_u8_u4(uint8x16_t vec) { return vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(vec), 4)), 0) & 0x8888888888888888ull; } -SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - sz_u128_vec_t a_vec, b_vec; - for (; length >= 16; a += 16, b += 16, length -= 16) { - a_vec.u8x16 = vld1q_u8((sz_u8_t const *)a); - b_vec.u8x16 = vld1q_u8((sz_u8_t const *)b); - uint8x16_t cmp = vceqq_u8(a_vec.u8x16, b_vec.u8x16); - if (vminvq_u8(cmp) != 255) { return sz_false_k; } // Check if all bytes match - } - - // Handle remaining bytes - if (length) return sz_equal_serial(a, b, length); - return sz_true_k; -} - SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { sz_u64_t matches; sz_u128_vec_t h_vec, n_vec, matches_vec; @@ -1676,7 +1604,7 @@ SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t h, sz_size_t h_length, sz_ch #pragma region Compile Time Dispatching #if !SZ_DYNAMIC_DISPATCH -#pragma region Core Funcitonality +#pragma region Core Functionality SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) { #if SZ_USE_SKYLAKE From 00f27f62c0767838f11dee34359a4aefd55977bd Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 8 Dec 2024 16:21:44 +0000 Subject: [PATCH 047/751] Fix: Haswell compilation flag --- include/stringzilla/memory.h | 6 +++--- include/stringzilla/similarity.h | 4 ++-- include/stringzilla/small_string.h | 3 ++- include/stringzilla/sort.h | 2 ++ 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/include/stringzilla/memory.h b/include/stringzilla/memory.h index 32106a82..06a3dc60 100644 --- a/include/stringzilla/memory.h +++ b/include/stringzilla/memory.h @@ -328,8 +328,8 @@ SZ_PUBLIC void sz_move_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t lengt #if 
SZ_USE_HASWELL #pragma GCC push_options -#pragma GCC target("haswell") -#pragma clang attribute push(__attribute__((target("haswell"))), apply_to = function) +#pragma GCC target("avx2") +#pragma clang attribute push(__attribute__((target("avx2"))), apply_to = function) SZ_PUBLIC void sz_fill_haswell(sz_ptr_t target, sz_size_t length, sz_u8_t value) { char value_char = *(char *)&value; @@ -1253,7 +1253,7 @@ SZ_PUBLIC void sz_copy_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length) #pragma region Compile Time Dispatching #if !SZ_DYNAMIC_DISPATCH -#pragma region Core Funcitonality +#pragma region Core Functionality SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { #if SZ_USE_ICE diff --git a/include/stringzilla/similarity.h b/include/stringzilla/similarity.h index ef34b824..5451c95f 100644 --- a/include/stringzilla/similarity.h +++ b/include/stringzilla/similarity.h @@ -579,8 +579,8 @@ SZ_PUBLIC sz_size_t sz_hamming_distance_utf8_serial( // #pragma region Haswell Implementation #if SZ_USE_HASWELL #pragma GCC push_options -#pragma GCC target("haswell") -#pragma clang attribute push(__attribute__((target("haswell"))), apply_to = function) +#pragma GCC target("avx2") +#pragma clang attribute push(__attribute__((target("avx2"))), apply_to = function) #pragma clang attribute pop #pragma GCC pop_options diff --git a/include/stringzilla/small_string.h b/include/stringzilla/small_string.h index 17625700..ba823901 100644 --- a/include/stringzilla/small_string.h +++ b/include/stringzilla/small_string.h @@ -24,9 +24,10 @@ #ifndef STRINGZILLA_SMALL_STRING_H_ #define STRINGZILLA_SMALL_STRING_H_ +#include "types.h" + #include "find.h" // `sz_equal` #include "memory.h" // `sz_copy`, `sz_move`, `sz_fill` -#include "types.h" // `sz_size_t`, `sz_ptr_t`, `sz_cptr_t` #ifdef __cplusplus extern "C" { diff --git a/include/stringzilla/sort.h b/include/stringzilla/sort.h index 4fe64bee..7a8de124 100644 --- a/include/stringzilla/sort.h +++ b/include/stringzilla/sort.h @@ -15,6 +15,8 @@ #include "types.h" +#include "compare.h" // `sz_compare` + #ifdef __cplusplus extern "C" { #endif From 406bf0f2befc379c17372e4871e62cc13d6f5ad8 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 8 Dec 2024 19:21:58 +0000 Subject: [PATCH 048/751] Fix: Symbols names & visibility --- README.md | 5 +- c/lib.c | 216 +- include/stringzilla/drafts.h | 28 +- include/stringzilla/find.h | 85 +- include/stringzilla/hash.h | 70 +- include/stringzilla/memory.h | 50 +- include/stringzilla/similarity.h | 115 +- include/stringzilla/small_string.h | 16 +- include/stringzilla/stringzilla.h | 6451 +-------------------------- include/stringzilla/stringzilla.hpp | 100 +- include/stringzilla/types.h | 49 +- scripts/bench_memory.cpp | 20 +- scripts/bench_similarity.cpp | 8 +- scripts/bench_token.cpp | 12 +- scripts/test.cpp | 29 +- 15 files changed, 458 insertions(+), 6796 deletions(-) diff --git a/README.md b/README.md index c4122696..c07050a3 100644 --- a/README.md +++ b/README.md @@ -624,7 +624,8 @@ sz_string_view_t needle = {your_subtext, your_subtext_length}; // Perform string-level operations sz_size_t substring_position = sz_find(haystack.start, haystack.length, needle.start, needle.length); -sz_size_t substring_position = sz_find_avx512(haystack.start, haystack.length, needle.start, needle.length); +sz_size_t substring_position = sz_find_skylake(haystack.start, haystack.length, needle.start, needle.length); +sz_size_t substring_position = 
sz_find_haswell(haystack.start, haystack.length, needle.start, needle.length); sz_size_t substring_position = sz_find_neon(haystack.start, haystack.length, needle.start, needle.length); // Hash strings @@ -747,7 +748,7 @@ typedef union sz_string_t { struct internal { sz_ptr_t start; sz_u8_t length; - char chars[SZ_STRING_INTERNAL_SPACE]; /// Ends with a null-terminator. + char chars[_SZ_STRING_INTERNAL_SPACE]; /// Ends with a null-terminator. } internal; struct external { diff --git a/c/lib.c b/c/lib.c index e1d98328..8a0a75b9 100644 --- a/c/lib.c +++ b/c/lib.c @@ -3,10 +3,20 @@ * @brief StringZilla C library with dynamic backed dispatch for the most appropriate implementation. * @author Ash Vardanian * @date January 16, 2024 - * @copyright Copyright (c) 2024 */ -#if defined(_WIN32) || defined(__CYGWIN__) -#include // `DllMain` +#if SZ_AVOID_LIBC +// If we don't have the LibC, the `malloc` definition in `stringzilla.h` will be illformed. +#ifdef _MSC_VER +typedef sz_size_t size_t; // Reuse the type definition we've inferred from `stringzilla.h` +extern __declspec(dllimport) int rand(void); +extern __declspec(dllimport) void free(void *start); +extern __declspec(dllimport) void *malloc(size_t length); +#else +typedef __SIZE_TYPE__ size_t; // For GCC/Clang +extern int rand(void); +extern void free(void *start); +extern void *malloc(size_t length); +#endif #endif // When enabled, this library will override the symbols usually provided by the C standard library. @@ -23,35 +33,32 @@ #define SZ_DYNAMIC_DISPATCH 1 #include -#if SZ_AVOID_LIBC -// If we don't have the LibC, the `malloc` definition in `stringzilla.h` will be illformed. -#ifdef _MSC_VER -typedef sz_size_t size_t; // Reuse the type definition we've inferred from `stringzilla.h` -extern __declspec(dllimport) int rand(void); -extern __declspec(dllimport) void free(void *start); -extern __declspec(dllimport) void *malloc(size_t length); -#else -typedef __SIZE_TYPE__ size_t; // For GCC/Clang -extern int rand(void); -extern void free(void *start); -extern void *malloc(size_t length); -#endif +// Inferring target OS: Windows, MacOS, or Linux +#if defined(WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(__NT__) || defined(__CYGWIN__) +#define _SZ_IS_WINDOWS 1 +#elif defined(__APPLE__) && defined(__MACH__) +#define _SZ_IS_APPLE 1 +#elif defined(__linux__) +#define _SZ_IS_LINUX 1 #endif // On Apple Silicon, `mrs` is not allowed in user-space, so we need to use the `sysctl` API. -#if defined(__APPLE__) && defined(__MACH__) -#define SZ_APPLE 1 +#if defined(_SZ_IS_APPLE) #include #endif -#if defined(__linux__) -#define SZ_LINUX 1 +#if defined(_SZ_IS_WINDOWS) +#include // `DllMain` #endif -SZ_INTERNAL sz_capability_t sz_capabilities_arm(void) { +/** + * @brief Function to determine the SIMD capabilities of the current 64-bit Arm machine at @b runtime. + * @return A bitmask of the SIMD capabilities represented as a `sz_capability_t` enum value. + */ +SZ_INTERNAL sz_capability_t _sz_capabilities_arm(void) { // https://github.com/ashvardanian/SimSIMD/blob/28e536083602f85ad0c59456782c8864463ffb0e/include/simsimd/simsimd.h#L434 // for documentation on how we detect capabilities across different ARM platforms. -#if defined(SZ_APPLE) +#if defined(_SZ_IS_APPLE) // On Apple Silicon, `mrs` is not allowed in user-space, so we need to use the `sysctl` API. 
uint32_t supports_neon = 0; @@ -62,20 +69,47 @@ SZ_INTERNAL sz_capability_t sz_capabilities_arm(void) { (sz_cap_arm_neon_k * (supports_neon)) | // (sz_cap_serial_k)); -#elif defined(SZ_LINUX) - unsigned supports_neon = 1; // NEON is always supported +#elif defined(_SZ_IS_LINUX) + + // Read CPUID registers directly + unsigned long id_aa64isar0_el1 = 0, id_aa64isar1_el1 = 0, id_aa64pfr0_el1 = 0, id_aa64zfr0_el1 = 0; + + // Now let's unpack the status flags from ID_AA64ISAR0_EL1 + // https://developer.arm.com/documentation/ddi0601/2024-03/AArch64-Registers/ID-AA64ISAR0-EL1--AArch64-Instruction-Set-Attribute-Register-0?lang=en + __asm__ __volatile__("mrs %0, ID_AA64ISAR0_EL1" : "=r"(id_aa64isar0_el1)); + // Now let's unpack the status flags from ID_AA64ISAR1_EL1 + // https://developer.arm.com/documentation/ddi0601/2024-03/AArch64-Registers/ID-AA64ISAR1-EL1--AArch64-Instruction-Set-Attribute-Register-1?lang=en + __asm__ __volatile__("mrs %0, ID_AA64ISAR1_EL1" : "=r"(id_aa64isar1_el1)); + // Now let's unpack the status flags from ID_AA64PFR0_EL1 + // https://developer.arm.com/documentation/ddi0601/2024-03/AArch64-Registers/ID-AA64PFR0-EL1--AArch64-Processor-Feature-Register-0?lang=en __asm__ __volatile__("mrs %0, ID_AA64PFR0_EL1" : "=r"(id_aa64pfr0_el1)); + // SVE, bits [35:32] of ID_AA64PFR0_EL1 unsigned supports_sve = ((id_aa64pfr0_el1 >> 32) & 0xF) >= 1; - return (sz_capability_t)( // - (sz_cap_neon_k * (supports_neon)) | // - (sz_cap_sve_k * (supports_sve)) | // + // Now let's unpack the status flags from ID_AA64ZFR0_EL1 + // https://developer.arm.com/documentation/ddi0601/2024-03/AArch64-Registers/ID-AA64ZFR0-EL1--SVE-Feature-ID-Register-0?lang=en + if (supports_sve) __asm__ __volatile__("mrs %0, ID_AA64ZFR0_EL1" : "=r"(id_aa64zfr0_el1)); + // SVEver, bits [3:0] can be used to check for capability levels: + // - 0b0000: SVE is implemented + // - 0b0001: SVE2 is implemented + // - 0b0010: SVE2.1 is implemented + // This value must match the existing indicator obtained from ID_AA64PFR0_EL1: + unsigned supports_sve2 = ((id_aa64zfr0_el1) & 0xF) >= 1; + unsigned supports_sve2p1 = ((id_aa64zfr0_el1) & 0xF) >= 2; + unsigned supports_neon = 1; // NEON is always supported + + return (sz_capability_t)( // + (sz_cap_neon_k * (supports_neon)) | // + (sz_cap_sve_k * (supports_sve)) | // + (sz_cap_sve2_k * (supports_sve2)) | // + (sz_cap_sve2p1_k * (supports_sve2p1)) | // (sz_cap_serial_k)); -#else // SIMSIMD_DEFINED_LINUX + +#else // if !defined(_SZ_IS_APPLE) && !defined(_SZ_IS_LINUX) return sz_cap_serial_k; #endif } -SZ_DYNAMIC sz_capability_t sz_capabilities(void) { +SZ_INTERNAL sz_capability_t _sz_capabilities_x86(void) { #if SZ_USE_HASWELL || SZ_USE_SKYLAKE || SZ_USE_ICE @@ -91,54 +125,50 @@ SZ_DYNAMIC sz_capability_t sz_capabilities(void) { __cpuidex(info1.array, 1, 0); __cpuidex(info7.array, 7, 0); #else - __asm__ __volatile__("cpuid" - : "=a"(info1.named.eax), "=b"(info1.named.ebx), "=c"(info1.named.ecx), "=d"(info1.named.edx) - : "a"(1), "c"(0)); - __asm__ __volatile__("cpuid" - : "=a"(info7.named.eax), "=b"(info7.named.ebx), "=c"(info7.named.ecx), "=d"(info7.named.edx) - : "a"(7), "c"(0)); + __asm__ __volatile__( // + "cpuid" + : "=a"(info1.named.eax), "=b"(info1.named.ebx), "=c"(info1.named.ecx), "=d"(info1.named.edx) + : "a"(1), "c"(0)); + __asm__ __volatile__( // + "cpuid" + : "=a"(info7.named.eax), "=b"(info7.named.ebx), "=c"(info7.named.ecx), "=d"(info7.named.edx) + : "a"(7), "c"(0)); #endif - // Check for AVX2 (Function ID 7, EBX register) + // Check for AVX2 (Function ID 7, EBX 
register), you can take the relevant flags from the LLVM implementation: // https://github.com/llvm/llvm-project/blob/50598f0ff44f3a4e75706f8c53f3380fe7faa896/clang/lib/Headers/cpuid.h#L148 unsigned supports_avx2 = (info7.named.ebx & 0x00000020) != 0; - // Check for AVX512F (Function ID 7, EBX register) - // https://github.com/llvm/llvm-project/blob/50598f0ff44f3a4e75706f8c53f3380fe7faa896/clang/lib/Headers/cpuid.h#L155 unsigned supports_avx512f = (info7.named.ebx & 0x00010000) != 0; - // Check for AVX512BW (Function ID 7, EBX register) - // https://github.com/llvm/llvm-project/blob/50598f0ff44f3a4e75706f8c53f3380fe7faa896/clang/lib/Headers/cpuid.h#L166 unsigned supports_avx512bw = (info7.named.ebx & 0x40000000) != 0; - // Check for AVX512VL (Function ID 7, EBX register) - // https://github.com/llvm/llvm-project/blob/50598f0ff44f3a4e75706f8c53f3380fe7faa896/clang/lib/Headers/cpuid.h#L167C25-L167C35 unsigned supports_avx512vl = (info7.named.ebx & 0x80000000) != 0; - // Check for GFNI (Function ID 1, ECX register) - // https://github.com/llvm/llvm-project/blob/50598f0ff44f3a4e75706f8c53f3380fe7faa896/clang/lib/Headers/cpuid.h#L171C30-L171C40 unsigned supports_avx512vbmi = (info7.named.ecx & 0x00000002) != 0; unsigned supports_avx512vbmi2 = (info7.named.ecx & 0x00000040) != 0; - // Check for GFNI (Function ID 1, ECX register) - // https://github.com/llvm/llvm-project/blob/50598f0ff44f3a4e75706f8c53f3380fe7faa896/clang/lib/Headers/cpuid.h#L177C30-L177C40 - unsigned supports_gfni = (info7.named.ecx & 0x00000100) != 0; - - return (sz_capability_t)( // - (sz_cap_x86_avx2_k * supports_avx2) | // - (sz_cap_x86_avx512f_k * supports_avx512f) | // - (sz_cap_x86_avx512vl_k * supports_avx512vl) | // - (sz_cap_x86_avx512bw_k * supports_avx512bw) | // - (sz_cap_x86_avx512vbmi_k * supports_avx512vbmi) | // - (sz_cap_x86_avx512vbmi2_k * supports_avx512vbmi2) | // - (sz_cap_x86_gfni_k * (supports_gfni)) | // - (sz_cap_serial_k)); - -#endif // SZ_TARGET_X86 - -#if SZ_USE_NEON || SZ_USE_SVE + unsigned supports_vaes = (info7.named.ecx & 0x00000200) != 0; - return sz_capabilities_arm(); - -#endif // SZ_TARGET_ARM + return (sz_capability_t)( // + (sz_cap_haswell_k * supports_avx2) | // + (sz_cap_skylake_k * (supports_avx512f && supports_avx512vl && supports_avx512bw && supports_vaes)) | // + (sz_cap_ice_k * (supports_avx512vbmi && supports_avx512vbmi2)) | // + (sz_cap_serial_k)); +#else + return sz_cap_serial_k; +#endif +} +/** + * @brief Function to determine the SIMD capabilities of the current 64-bit x86 machine at @b runtime. + * @return A bitmask of the SIMD capabilities represented as a `sz_capability_t` enum value. 
+ */ +SZ_DYNAMIC sz_capability_t sz_capabilities(void) { +#if _SZ_IS_X86 + return _sz_capabilities_x86(); +#elif _SZ_IS_ARM + return _sz_capabilities_arm(); +#else return sz_cap_serial_k; +#endif } + typedef struct sz_implementations_t { sz_equal_t equal; sz_order_t order; @@ -197,56 +227,54 @@ static void sz_dispatch_table_init(void) { impl->hashes = sz_hashes_serial; #if SZ_USE_HASWELL - if (caps & sz_cap_x86_avx2_k) { - impl->equal = sz_equal_avx2; - impl->order = sz_order_avx2; - - impl->copy = sz_copy_avx2; - impl->move = sz_move_avx2; - impl->fill = sz_fill_avx2; - impl->look_up_transform = sz_look_up_transform_avx2; - impl->checksum = sz_checksum_avx2; - - impl->find_byte = sz_find_byte_avx2; - impl->rfind_byte = sz_rfind_byte_avx2; - impl->find = sz_find_avx2; - impl->rfind = sz_rfind_avx2; - impl->find_from_set = sz_find_charset_avx2; - impl->rfind_from_set = sz_rfind_charset_avx2; + if (caps & sz_cap_haswell_k) { + impl->equal = sz_equal_haswell; + impl->order = sz_order_haswell; + + impl->copy = sz_copy_haswell; + impl->move = sz_move_haswell; + impl->fill = sz_fill_haswell; + impl->look_up_transform = sz_look_up_transform_haswell; + impl->checksum = sz_checksum_haswell; + + impl->find_byte = sz_find_byte_haswell; + impl->rfind_byte = sz_rfind_byte_haswell; + impl->find = sz_find_haswell; + impl->rfind = sz_rfind_haswell; + impl->find_from_set = sz_find_charset_haswell; + impl->rfind_from_set = sz_rfind_charset_haswell; } #endif #if SZ_USE_SKYLAKE - if (caps & sz_cap_x86_avx512f_k) { + if (caps & sz_cap_skylake_k) { impl->equal = sz_equal_skylake; - impl->order = sz_order_avx512; + impl->order = sz_order_skylake; - impl->copy = sz_copy_avx512; - impl->move = sz_move_avx512; - impl->fill = sz_fill_avx512; + impl->copy = sz_copy_skylake; + impl->move = sz_move_skylake; + impl->fill = sz_fill_skylake; impl->find = sz_find_skylake; impl->rfind = sz_rfind_skylake; - impl->find_byte = sz_find_byte_avx512; - impl->rfind_byte = sz_rfind_byte_avx512; - - impl->edit_distance = sz_edit_distance_avx512; + impl->find_byte = sz_find_byte_skylake; + impl->rfind_byte = sz_rfind_byte_skylake; } #endif #if SZ_USE_ICE - if ((caps & sz_cap_x86_avx512f_k) && (caps & sz_cap_x86_avx512vl_k) && (caps & sz_cap_x86_avx512vbmi2_k) && - (caps & sz_cap_x86_avx512bw_k) && (caps & sz_cap_x86_avx512vbmi_k)) { + if (caps & sz_cap_ice_k) { impl->find_from_set = sz_find_charset_ice; impl->rfind_from_set = sz_rfind_charset_ice; - impl->alignment_score = sz_alignment_score_avx512; + impl->edit_distance = sz_edit_distance_ice; + impl->alignment_score = sz_alignment_score_ice; impl->look_up_transform = sz_look_up_transform_ice; - impl->checksum = sz_checksum_avx512; + impl->checksum = sz_checksum_ice; } #endif #if SZ_USE_NEON - if (caps & sz_cap_arm_neon_k) { + if (caps & sz_cap_neon_k) { impl->equal = sz_equal_neon; impl->copy = sz_copy_neon; diff --git a/include/stringzilla/drafts.h b/include/stringzilla/drafts.h index 1817a81e..49099cbe 100644 --- a/include/stringzilla/drafts.h +++ b/include/stringzilla/drafts.h @@ -342,24 +342,24 @@ sz_u512_vec_t sz_inclusive_min(sz_i32_t previous, sz_error_cost_t gap, sz_u512_v shifted_vec.i32s[0] = previous; shifted_vec.zmm = _mm512_add_epi32(shifted_vec.zmm, gap_vec.zmm); new_vec.zmm = _mm512_mask_max_epi32(new_vec.zmm, mask_skip_one, new_vec.zmm, shifted_vec.zmm); - sz_assert(new_vec.i32s[0] == max(previous + gap, base_vec.i32s[0])); + _sz_assert(new_vec.i32s[0] == max(previous + gap, base_vec.i32s[0])); shifted_vec.zmm = _mm512_permutexvar_epi32(shift_by_two_vec.zmm, 
new_vec.zmm); shifted_vec.zmm = _mm512_add_epi32(shifted_vec.zmm, gap_double_vec.zmm); new_vec.zmm = _mm512_mask_max_epi32(new_vec.zmm, mask_skip_two, new_vec.zmm, shifted_vec.zmm); - sz_assert(new_vec.i32s[0] == max(previous + gap, base_vec.i32s[0])); + _sz_assert(new_vec.i32s[0] == max(previous + gap, base_vec.i32s[0])); shifted_vec.zmm = _mm512_permutexvar_epi32(shift_by_four_vec.zmm, new_vec.zmm); shifted_vec.zmm = _mm512_add_epi32(shifted_vec.zmm, gap_quad_vec.zmm); new_vec.zmm = _mm512_mask_max_epi32(new_vec.zmm, mask_skip_four, new_vec.zmm, shifted_vec.zmm); - sz_assert(new_vec.i32s[0] == max(previous + gap, base_vec.i32s[0])); + _sz_assert(new_vec.i32s[0] == max(previous + gap, base_vec.i32s[0])); shifted_vec.zmm = _mm512_permutexvar_epi32(shift_by_eight_vec.zmm, new_vec.zmm); shifted_vec.zmm = _mm512_add_epi32(shifted_vec.zmm, gap_octa_vec.zmm); new_vec.zmm = _mm512_mask_max_epi32(new_vec.zmm, mask_skip_eight, new_vec.zmm, shifted_vec.zmm); - sz_assert(new_vec.i32s[0] == max(previous + gap, base_vec.i32s[0])); - for (sz_size_t i = 1; i < 16; i++) sz_assert(new_vec.i32s[i] == max(new_vec.i32s[i - 1] + gap, new_vec.i32s[i])); + _sz_assert(new_vec.i32s[0] == max(previous + gap, base_vec.i32s[0])); + for (sz_size_t i = 1; i < 16; i++) _sz_assert(new_vec.i32s[i] == max(new_vec.i32s[i - 1] + gap, new_vec.i32s[i])); return new_vec; } @@ -1015,7 +1015,7 @@ SZ_PUBLIC sz_ordering_t sz_order_avx2(sz_cptr_t a, sz_size_t a_length, sz_cptr_t return sz_order_serial(a, a_length, b, b_length); } -SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { +SZ_PUBLIC sz_ordering_t sz_order_skylake(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { sz_u512_vec_t a_vec, b_vec; // The rare case, when both string are very long surves as a great example to understand @@ -1124,8 +1124,8 @@ SZ_PUBLIC void sz_move_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t lengt for (; length >= 128; target += 64, source += 64, length -= 64) { second_vec.zmm = _mm512_load_si512(target + 64); combined_vec.zmm = _mm512_permutex2var_epi8(first_vec.zmm, selector_vec.zmm, second_vec.zmm); - sz_assert(combined_vec.u8s[0] == source[0]); - sz_assert(combined_vec.u8s[63] == source[63]); + _sz_assert(combined_vec.u8s[0] == source[0]); + _sz_assert(combined_vec.u8s[63] == source[63]); _mm512_store_si512(target, combined_vec.zmm); first_vec.zmm = second_vec.zmm; } @@ -1147,8 +1147,8 @@ SZ_PUBLIC void sz_move_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t lengt second_vec.zmm = _mm512_load_si512(target + 64); first_shuffled_vec.zmm = _mm512_shuffle_epi8(first_vec.zmm, first_byte_permute_vec.zmm); second_shuffled_vec.zmm = _mm512_shuffle_epi8(second_vec.zmm, second_byte_permute_vec.zmm); - sz_assert(first_shuffled_vec.u8s[0] == source[0]); - sz_assert(second_shuffled_vec.u8s[63] == source[63]); + _sz_assert(first_shuffled_vec.u8s[0] == source[0]); + _sz_assert(second_shuffled_vec.u8s[63] == source[63]); combined_vec.zmm = _mm512_or_si512(first_shuffled_vec.zmm, second_shuffled_vec.zmm); _mm512_store_si512(target, combined_vec.zmm); first_vec.zmm = second_vec.zmm; @@ -1279,8 +1279,8 @@ SZ_PUBLIC void sz_move_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t lengt second_vec.zmm = _mm512_load_si512(source_page + 64); second_vec.zmm = _mm512_permutexvar_epi8(selector_vec.zmm, second_vec.zmm); combined_vec.zmm = _mm512_mask_blend_epi8(blend_mask, second_vec.zmm, first_vec.zmm); - sz_assert(combined_vec.u8s[0] == source[0]); - sz_assert(combined_vec.u8s[63] == 
source[63]); + _sz_assert(combined_vec.u8s[0] == source[0]); + _sz_assert(combined_vec.u8s[63] == source[63]); _mm512_store_si512(target, combined_vec.zmm); first_vec.zmm = second_vec.zmm; } @@ -1313,8 +1313,8 @@ SZ_PUBLIC void sz_move_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t lengt second_vec.zmm = _mm512_load_si512(source_second_page - 64); second_vec.zmm = _mm512_permutexvar_epi8(selector_vec.zmm, second_vec.zmm); combined_vec.zmm = _mm512_mask_blend_epi8(blend_mask, second_vec.zmm, first_vec.zmm); - sz_assert(combined_vec.u8s[0] == source[0]); - sz_assert(combined_vec.u8s[63] == source[63]); + _sz_assert(combined_vec.u8s[0] == source[0]); + _sz_assert(combined_vec.u8s[63] == source[63]); _mm512_store_si512(target + head_length + body_length, combined_vec.zmm); first_vec.zmm = second_vec.zmm; } diff --git a/include/stringzilla/find.h b/include/stringzilla/find.h index 91892a0f..b5740429 100644 --- a/include/stringzilla/find.h +++ b/include/stringzilla/find.h @@ -113,23 +113,23 @@ SZ_PUBLIC sz_cptr_t sz_rfind_serial(sz_cptr_t haystack, sz_size_t h_length, sz_c #if SZ_USE_HASWELL /** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_haswell(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +SZ_PUBLIC sz_cptr_t sz_find_haswell(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); /** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_haswell(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +SZ_PUBLIC sz_cptr_t sz_rfind_haswell(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); #endif #if SZ_USE_SKYLAKE /** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_skylake(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +SZ_PUBLIC sz_cptr_t sz_find_skylake(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); /** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_skylake(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +SZ_PUBLIC sz_cptr_t sz_rfind_skylake(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); #endif #if SZ_USE_NEON /** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +SZ_PUBLIC sz_cptr_t sz_find_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); /** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +SZ_PUBLIC sz_cptr_t sz_rfind_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); #endif /** @@ -173,23 +173,23 @@ SZ_PUBLIC sz_cptr_t sz_rfind_charset_serial(sz_cptr_t text, sz_size_t length, sz #if SZ_USE_HASWELL /** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_haswell(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +SZ_PUBLIC sz_cptr_t sz_find_charset_haswell(sz_cptr_t haystack, sz_size_t length, sz_charset_t const *set); /** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_haswell(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +SZ_PUBLIC sz_cptr_t sz_rfind_charset_haswell(sz_cptr_t haystack, sz_size_t length, sz_charset_t const *set); #endif #if SZ_USE_ICE /** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_ice(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +SZ_PUBLIC sz_cptr_t sz_find_charset_ice(sz_cptr_t haystack, sz_size_t length, sz_charset_t const *set); /** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_ice(sz_cptr_t haystack, 
sz_size_t h_length, sz_cptr_t needle); +SZ_PUBLIC sz_cptr_t sz_rfind_charset_ice(sz_cptr_t haystack, sz_size_t length, sz_charset_t const *set); #endif #if SZ_USE_NEON /** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +SZ_PUBLIC sz_cptr_t sz_find_charset_neon(sz_cptr_t haystack, sz_size_t length, sz_charset_t const *set); /** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t haystack, sz_size_t length, sz_charset_t const *set); #endif #pragma endregion // Core API @@ -375,7 +375,7 @@ SZ_INTERNAL sz_u64_vec_t _sz_u64_each_2byte_equal(sz_u64_vec_t a, sz_u64_vec_t b SZ_INTERNAL sz_cptr_t _sz_find_2byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { // This is an internal method, and the haystack is guaranteed to be at least 2 bytes long. - sz_assert(h_length >= 2 && "The haystack is too short."); + _sz_assert(h_length >= 2 && "The haystack is too short."); sz_cptr_t const h_end = h + h_length; #if !SZ_USE_MISALIGNED_LOADS @@ -429,7 +429,7 @@ SZ_INTERNAL sz_u64_vec_t _sz_u64_each_4byte_equal(sz_u64_vec_t a, sz_u64_vec_t b SZ_INTERNAL sz_cptr_t _sz_find_4byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { // This is an internal method, and the haystack is guaranteed to be at least 4 bytes long. - sz_assert(h_length >= 4 && "The haystack is too short."); + _sz_assert(h_length >= 4 && "The haystack is too short."); sz_cptr_t const h_end = h + h_length; #if !SZ_USE_MISALIGNED_LOADS @@ -493,7 +493,7 @@ SZ_INTERNAL sz_u64_vec_t _sz_u64_each_3byte_equal(sz_u64_vec_t a, sz_u64_vec_t b SZ_INTERNAL sz_cptr_t _sz_find_3byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { // This is an internal method, and the haystack is guaranteed to be at least 4 bytes long. - sz_assert(h_length >= 3 && "The haystack is too short."); + _sz_assert(h_length >= 3 && "The haystack is too short."); sz_cptr_t const h_end = h + h_length; #if !SZ_USE_MISALIGNED_LOADS @@ -550,7 +550,7 @@ SZ_INTERNAL sz_cptr_t _sz_find_3byte_serial(sz_cptr_t h, sz_size_t h_length, sz_ SZ_INTERNAL sz_cptr_t _sz_find_horspool_upto_256bytes_serial( // sz_cptr_t h_chars, sz_size_t h_length, // sz_cptr_t n_chars, sz_size_t n_length) { - sz_assert(n_length <= 256 && "The pattern is too long."); + _sz_assert(n_length <= 256 && "The pattern is too long."); // Several popular string matching algorithms are using a bad-character shift table. // Boyer Moore: https://www-igm.univ-mlv.fr/~lecroq/string/node14.html // Quick Search: https://www-igm.univ-mlv.fr/~lecroq/string/node19.html @@ -604,7 +604,7 @@ SZ_INTERNAL sz_cptr_t _sz_find_horspool_upto_256bytes_serial( // SZ_INTERNAL sz_cptr_t _sz_rfind_horspool_upto_256bytes_serial( // sz_cptr_t h_chars, sz_size_t h_length, // sz_cptr_t n_chars, sz_size_t n_length) { - sz_assert(n_length <= 256 && "The pattern is too long."); + _sz_assert(n_length <= 256 && "The pattern is too long."); union { sz_u8_t jumps[256]; sz_u64_vec_t vecs[64]; @@ -941,7 +941,7 @@ SZ_PUBLIC sz_cptr_t sz_find_charset_haswell(sz_cptr_t text, sz_size_t length, sz // sz_u8_t input = *(sz_u8_t const *)(text + i); // sz_u8_t lo_nibble = input & 0x0f; // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); - // sz_assert(bitmask_vec.u8s[i] == bitmask); + // _sz_assert(bitmask_vec.u8s[i] == bitmask); // } // // Shift right every byte by 4 bits. 
@@ -959,8 +959,8 @@ SZ_PUBLIC sz_cptr_t sz_find_charset_haswell(sz_cptr_t text, sz_size_t length, sz // sz_u8_t hi_nibble = input >> 4; // sz_u8_t bitset_even = bitset_ptr[hi_nibble * 2]; // sz_u8_t bitset_odd = bitset_ptr[hi_nibble * 2 + 1]; - // sz_assert(bitset_even_vec.u8s[i] == bitset_even); - // sz_assert(bitset_odd_vec.u8s[i] == bitset_odd); + // _sz_assert(bitset_even_vec.u8s[i] == bitset_even); + // _sz_assert(bitset_odd_vec.u8s[i] == bitset_odd); // } // __m256i take_first = _mm256_cmpgt_epi8(_mm256_set1_epi8(8), lower_nibbles_vec.ymm); @@ -1183,8 +1183,8 @@ SZ_PUBLIC sz_cptr_t sz_rfind_skylake(sz_cptr_t h, sz_size_t h_length, sz_cptr_t int potential_offset = sz_u64_clz(matches); if (n_length <= 3 || sz_equal_skylake(h + h_length - n_length - potential_offset, n, n_length)) return h + h_length - n_length - potential_offset; - sz_assert((matches & ((sz_u64_t)1 << (63 - potential_offset))) != 0 && - "The bit must be set before we squash it"); + _sz_assert((matches & ((sz_u64_t)1 << (63 - potential_offset))) != 0 && + "The bit must be set before we squash it"); matches &= ~((sz_u64_t)1 << (63 - potential_offset)); } } @@ -1204,8 +1204,8 @@ SZ_PUBLIC sz_cptr_t sz_rfind_skylake(sz_cptr_t h, sz_size_t h_length, sz_cptr_t int potential_offset = sz_u64_clz(matches); if (n_length <= 3 || sz_equal_skylake(h + 64 - potential_offset - 1, n, n_length)) return h + 64 - potential_offset - 1; - sz_assert((matches & ((sz_u64_t)1 << (63 - potential_offset))) != 0 && - "The bit must be set before we squash it"); + _sz_assert((matches & ((sz_u64_t)1 << (63 - potential_offset))) != 0 && + "The bit must be set before we squash it"); matches &= ~((sz_u64_t)1 << (63 - potential_offset)); } } @@ -1223,13 +1223,16 @@ SZ_PUBLIC sz_cptr_t sz_rfind_skylake(sz_cptr_t h, sz_size_t h_length, sz_cptr_t * - 2017 Skylake: F, CD, ER, PF, VL, DQ, BW, * - 2018 CannonLake: IFMA, VBMI, * - 2019 Ice Lake: VPOPCNTDQ, VNNI, VBMI2, BITALG, GFNI, VPCLMULQDQ, VAES. + * + * We are going to use VBMI2 for `_mm256_maskz_compress_epi8`. */ #pragma region Ice Lake Implementation #if SZ_USE_ICE #pragma GCC push_options -#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vbmi", "bmi", "bmi2") -#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512dq,avx512vbmi,bmi,bmi2"))), \ - apply_to = function) +#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vbmi", "avx512vbmi2", "bmi", "bmi2") +#pragma clang attribute push( \ + __attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512dq,avx512vbmi,avx512vbmi2,bmi,bmi2"))), \ + apply_to = function) SZ_PUBLIC sz_cptr_t sz_find_charset_ice(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { @@ -1247,7 +1250,7 @@ SZ_PUBLIC sz_cptr_t sz_find_charset_ice(sz_cptr_t text, sz_size_t length, sz_cha sz_u512_vec_t filter_even_vec, filter_odd_vec; __m256i filter_ymm = _mm256_lddqu_si256((__m256i const *)filter); // There are a few way to initialize filters without having native strided loads. 
- // In the cronological order of experiments: + // In the chronological order of experiments: // - serial code initializing 128 bytes of odd and even mask // - using several shuffles // - using `_mm512_permutexvar_epi8` @@ -1260,14 +1263,14 @@ SZ_PUBLIC sz_cptr_t sz_find_charset_ice(sz_cptr_t text, sz_size_t length, sz_cha // After the unzipping operation, we can validate the contents of the vectors like this: // // for (sz_size_t i = 0; i != 16; ++i) { - // sz_assert(filter_even_vec.u8s[i] == filter->_u8s[i * 2]); - // sz_assert(filter_odd_vec.u8s[i] == filter->_u8s[i * 2 + 1]); - // sz_assert(filter_even_vec.u8s[i + 16] == filter->_u8s[i * 2]); - // sz_assert(filter_odd_vec.u8s[i + 16] == filter->_u8s[i * 2 + 1]); - // sz_assert(filter_even_vec.u8s[i + 32] == filter->_u8s[i * 2]); - // sz_assert(filter_odd_vec.u8s[i + 32] == filter->_u8s[i * 2 + 1]); - // sz_assert(filter_even_vec.u8s[i + 48] == filter->_u8s[i * 2]); - // sz_assert(filter_odd_vec.u8s[i + 48] == filter->_u8s[i * 2 + 1]); + // _sz_assert(filter_even_vec.u8s[i] == filter->_u8s[i * 2]); + // _sz_assert(filter_odd_vec.u8s[i] == filter->_u8s[i * 2 + 1]); + // _sz_assert(filter_even_vec.u8s[i + 16] == filter->_u8s[i * 2]); + // _sz_assert(filter_odd_vec.u8s[i + 16] == filter->_u8s[i * 2 + 1]); + // _sz_assert(filter_even_vec.u8s[i + 32] == filter->_u8s[i * 2]); + // _sz_assert(filter_odd_vec.u8s[i + 32] == filter->_u8s[i * 2 + 1]); + // _sz_assert(filter_even_vec.u8s[i + 48] == filter->_u8s[i * 2]); + // _sz_assert(filter_odd_vec.u8s[i + 48] == filter->_u8s[i * 2 + 1]); // } // sz_u512_vec_t text_vec; @@ -1310,7 +1313,7 @@ SZ_PUBLIC sz_cptr_t sz_find_charset_ice(sz_cptr_t text, sz_size_t length, sz_cha // sz_u8_t input = *(sz_u8_t const *)(text + i); // sz_u8_t lo_nibble = input & 0x0f; // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); - // sz_assert(bitmask_vec.u8s[i] == bitmask); + // _sz_assert(bitmask_vec.u8s[i] == bitmask); // } // // Shift right every byte by 4 bits. @@ -1328,8 +1331,8 @@ SZ_PUBLIC sz_cptr_t sz_find_charset_ice(sz_cptr_t text, sz_size_t length, sz_cha // sz_u8_t hi_nibble = input >> 4; // sz_u8_t bitset_even = bitset_ptr[hi_nibble * 2]; // sz_u8_t bitset_odd = bitset_ptr[hi_nibble * 2 + 1]; - // sz_assert(bitset_even_vec.u8s[i] == bitset_even); - // sz_assert(bitset_odd_vec.u8s[i] == bitset_odd); + // _sz_assert(bitset_even_vec.u8s[i] == bitset_even); + // _sz_assert(bitset_odd_vec.u8s[i] == bitset_odd); // } // // TODO: Is this a good place for ternary logic? 
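Both the Skylake and NEON reverse searchers in this file share the same candidate-filtering pattern around `sz_u64_clz`: take the highest set bit of the match mask, verify the full needle at that offset, and squash the bit if the verification fails. A hedged scalar sketch of that loop, not part of the patch, with `clz64` standing in for `sz_u64_clz` and `full_match_at` as a hypothetical verification helper:

    while (matches) {
        int potential_offset = clz64(matches);                // offset of the candidate, counted from the tail
        if (full_match_at(h, h_length, n, n_length, potential_offset))
            return h + h_length - n_length - potential_offset;
        matches &= ~((sz_u64_t)1 << (63 - potential_offset)); // the bit must be set before we squash it
    }

The NEON variant additionally divides the offset by four, as each byte match occupies a whole nibble of its mask, and the `_sz_assert` calls in the surrounding hunks guard precisely the "bit must be set" invariant before the squash.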
@@ -1539,8 +1542,8 @@ SZ_PUBLIC sz_cptr_t sz_rfind_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, int potential_offset = sz_u64_clz(matches) / 4; if (sz_equal_neon(h + h_length - n_length - potential_offset, n, n_length)) return h + h_length - n_length - potential_offset; - sz_assert((matches & (1ull << (63 - potential_offset * 4))) != 0 && - "The bit must be set before we squash it"); + _sz_assert((matches & (1ull << (63 - potential_offset * 4))) != 0 && + "The bit must be set before we squash it"); matches &= ~(1ull << (63 - potential_offset * 4)); } } diff --git a/include/stringzilla/hash.h b/include/stringzilla/hash.h index d8f4a05e..0e5e883e 100644 --- a/include/stringzilla/hash.h +++ b/include/stringzilla/hash.h @@ -47,7 +47,10 @@ SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length); * * @see sz_hashes, sz_hashes_fingerprint, sz_hashes_intersection */ -SZ_PUBLIC sz_u64_t sz_hash(sz_cptr_t text, sz_size_t length); +SZ_PUBLIC sz_u64_t sz_hash(sz_cptr_t text, sz_size_t length) { + sz_unused(text && length); + return 0; +} /** * @brief Computes the Karp-Rabin rolling hashes of a string supplying them to the provided `callback`. @@ -99,7 +102,9 @@ SZ_DYNAMIC void sz_hashes( */ SZ_PUBLIC void sz_hashes_fingerprint( // sz_cptr_t text, sz_size_t length, sz_size_t window_length, // - sz_ptr_t fingerprint, sz_size_t fingerprint_bytes); + sz_ptr_t fingerprint, sz_size_t fingerprint_bytes) { + sz_unused(text && length && window_length && fingerprint && fingerprint_bytes); +} /** * @brief Given a hash-fingerprint of a textual document, computes the number of intersecting hashes @@ -145,16 +150,18 @@ SZ_PUBLIC sz_u64_t sz_checksum_serial(sz_cptr_t text, sz_size_t length); /** @copydoc sz_hash */ SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t text, sz_size_t length); -/** @copydoc sz_generate */ -SZ_PUBLIC void sz_generate_serial( // - sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length, sz_random_generator_t generate, - void *generator); - /** @copydoc sz_hashes */ SZ_PUBLIC void sz_hashes_serial( // sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // sz_hash_callback_t callback, void *callback_handle); +/** @copydoc sz_generate */ +SZ_PUBLIC void sz_generate_serial( // + sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length, sz_random_generator_t generate, + void *generator) { + sz_unused(alphabet && cardinality && text && length && generate && generator); +} + #pragma endregion // Core API #pragma region Serial Implementation @@ -337,6 +344,33 @@ SZ_PUBLIC void sz_hashes_serial(sz_cptr_t start, sz_size_t length, sz_size_t win } } +/** @brief An internal callback used to set a bit in a power-of-two length binary fingerprint of a string. */ +SZ_INTERNAL void _sz_hashes_fingerprint_pow2_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, void *handle) { + sz_string_view_t *fingerprint_buffer = (sz_string_view_t *)handle; + sz_u8_t *fingerprint_u8s = (sz_u8_t *)fingerprint_buffer->start; + sz_size_t fingerprint_bytes = fingerprint_buffer->length; + fingerprint_u8s[(hash / 8) & (fingerprint_bytes - 1)] |= (1 << (hash & 7)); + sz_unused(start && length); +} + +/** @brief An internal callback used to set a bit in a @b non power-of-two length binary fingerprint of a string. 
*/ +SZ_INTERNAL void _sz_hashes_fingerprint_non_pow2_callback( // + sz_cptr_t start, sz_size_t length, sz_u64_t hash, void *handle) { + sz_string_view_t *fingerprint_buffer = (sz_string_view_t *)handle; + sz_u8_t *fingerprint_u8s = (sz_u8_t *)fingerprint_buffer->start; + sz_size_t fingerprint_bytes = fingerprint_buffer->length; + fingerprint_u8s[(hash / 8) % fingerprint_bytes] |= (1 << (hash & 7)); + sz_unused(start && length); +} + +/** @brief An internal callback, used to mix all the running hashes into one pointer-size value. */ +SZ_INTERNAL void _sz_hashes_fingerprint_scalar_callback( // + sz_cptr_t start, sz_size_t length, sz_u64_t hash, void *scalar_handle) { + sz_unused(start && length && hash && scalar_handle); + sz_size_t *scalar_ptr = (sz_size_t *)scalar_handle; + *scalar_ptr ^= hash; +} + #undef _sz_shift_low #undef _sz_shift_high #undef _sz_hash_mix @@ -350,10 +384,10 @@ SZ_PUBLIC void sz_hashes_serial(sz_cptr_t start, sz_size_t length, sz_size_t win #pragma region Haswell Implementation #if SZ_USE_HASWELL #pragma GCC push_options -#pragma GCC target("haswell") -#pragma clang attribute push(__attribute__((target("haswell"))), apply_to = function) +#pragma GCC target("avx2") +#pragma clang attribute push(__attribute__((target("avx2"))), apply_to = function) -SZ_PUBLIC sz_u64_t sz_checksum_avx2(sz_cptr_t text, sz_size_t length) { +SZ_PUBLIC sz_u64_t sz_checksum_haswell(sz_cptr_t text, sz_size_t length) { // The naive implementation of this function is very simple. // It assumes the CPU is great at handling unaligned "loads". // @@ -448,8 +482,8 @@ SZ_INTERNAL __m256i _mm256_mul_epu64(__m256i a, __m256i b) { return prod; } -SZ_PUBLIC void sz_hashes_avx2(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle) { +SZ_PUBLIC void sz_hashes_haswell(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // + sz_hash_callback_t callback, void *callback_handle) { if (length < window_length || !window_length) return; if (length < 4 * window_length) { @@ -702,8 +736,8 @@ SZ_PUBLIC sz_u64_t sz_checksum_ice(sz_cptr_t text, sz_size_t length) { } } -SZ_PUBLIC void sz_hashes_ice(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle) { +SZ_PUBLIC void sz_hashes_skylake(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // + sz_hash_callback_t callback, void *callback_handle) { if (length < window_length || !window_length) return; if (length < 4 * window_length) { @@ -888,7 +922,7 @@ SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length) { #if SZ_USE_ICE return sz_checksum_ice(text, length); #elif SZ_USE_HASWELL - return sz_checksum_avx2(text, length); + return sz_checksum_haswell(text, length); #elif SZ_USE_NEON return sz_checksum_neon(text, length); #else @@ -898,10 +932,10 @@ SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length) { SZ_DYNAMIC void sz_hashes(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // sz_hash_callback_t callback, void *callback_handle) { -#if SZ_USE_ICE - sz_hashes_ice(text, length, window_length, window_step, callback, callback_handle); +#if SZ_USE_SKYLAKE + sz_hashes_skylake(text, length, window_length, window_step, callback, callback_handle); #elif SZ_USE_HASWELL - sz_hashes_avx2(text, length, window_length, window_step, callback, callback_handle); + sz_hashes_haswell(text, length, window_length, window_step, 
callback, callback_handle); #else sz_hashes_serial(text, length, window_length, window_step, callback, callback_handle); #endif diff --git a/include/stringzilla/memory.h b/include/stringzilla/memory.h index 06a3dc60..c17f031f 100644 --- a/include/stringzilla/memory.h +++ b/include/stringzilla/memory.h @@ -64,29 +64,29 @@ SZ_PUBLIC void sz_fill_serial(sz_ptr_t target, sz_size_t length, sz_u8_t value); #if SZ_USE_HASWELL /** @copydoc sz_copy */ -SZ_PUBLIC sz_cptr_t sz_copy_haswell(sz_ptr_t target, sz_cptr_t source, sz_size_t length); +SZ_PUBLIC void sz_copy_haswell(sz_ptr_t target, sz_cptr_t source, sz_size_t length); /** @copydoc sz_move */ -SZ_PUBLIC sz_cptr_t sz_move_haswell(sz_ptr_t target, sz_cptr_t source, sz_size_t length); +SZ_PUBLIC void sz_move_haswell(sz_ptr_t target, sz_cptr_t source, sz_size_t length); /** @copydoc sz_rfind_fill */ -SZ_PUBLIC sz_cptr_t sz_fill_haswell(sz_ptr_t target, sz_size_t length, sz_u8_t value); +SZ_PUBLIC void sz_fill_haswell(sz_ptr_t target, sz_size_t length, sz_u8_t value); #endif #if SZ_USE_SKYLAKE /** @copydoc sz_copy */ -SZ_PUBLIC sz_cptr_t sz_copy_skylake(sz_ptr_t target, sz_cptr_t source, sz_size_t length); +SZ_PUBLIC void sz_copy_skylake(sz_ptr_t target, sz_cptr_t source, sz_size_t length); /** @copydoc sz_move */ -SZ_PUBLIC sz_cptr_t sz_move_skylake(sz_ptr_t target, sz_cptr_t source, sz_size_t length); +SZ_PUBLIC void sz_move_skylake(sz_ptr_t target, sz_cptr_t source, sz_size_t length); /** @copydoc sz_rfind_fill */ -SZ_PUBLIC sz_cptr_t sz_fill_skylake(sz_ptr_t target, sz_size_t length, sz_u8_t value); +SZ_PUBLIC void sz_fill_skylake(sz_ptr_t target, sz_size_t length, sz_u8_t value); #endif #if SZ_USE_NEON /** @copydoc sz_copy */ -SZ_PUBLIC sz_cptr_t sz_copy_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length); +SZ_PUBLIC void sz_copy_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length); /** @copydoc sz_move */ -SZ_PUBLIC sz_cptr_t sz_move_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length); +SZ_PUBLIC void sz_move_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length); /** @copydoc sz_rfind_fill */ -SZ_PUBLIC sz_cptr_t sz_fill_neon(sz_ptr_t target, sz_size_t length, sz_u8_t value); +SZ_PUBLIC void sz_fill_neon(sz_ptr_t target, sz_size_t length, sz_u8_t value); #endif /** @@ -358,13 +358,13 @@ SZ_PUBLIC void sz_fill_haswell(sz_ptr_t target, sz_size_t length, sz_u8_t value) if (head_length & 8) *(sz_u64_t *)target = value64, target += 8, head_length -= 8; if (head_length & 16) _mm_store_si128((__m128i *)target, _mm_set1_epi8(value_char)), target += 16, head_length -= 16; - sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); + _sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); // Fill the aligned body of the buffer. for (; body_length >= 32; target += 32, body_length -= 32) _mm256_store_si256((__m256i *)target, value_vec); // Fill the tail of the buffer. This part is much cleaner with AVX-512. 
- sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); + _sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); if (tail_length & 16) _mm_store_si128((__m128i *)target, _mm_set1_epi8(value_char)), target += 16, tail_length -= 16; if (tail_length & 8) *(sz_u64_t *)target = value64, target += 8, tail_length -= 8; @@ -374,7 +374,7 @@ SZ_PUBLIC void sz_fill_haswell(sz_ptr_t target, sz_size_t length, sz_u8_t value) } } -SZ_PUBLIC void sz_copy_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { +SZ_PUBLIC void sz_copy_haswell(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { // The naive implementation of this function is very simple. // It assumes the CPU is great at handling unaligned "stores" and "loads". // @@ -387,7 +387,7 @@ SZ_PUBLIC void sz_copy_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length) // For now, let's avoid the cases beyond the L2 size. int is_huge = length > 1ull * 1024ull * 1024ull; if (length <= 32) { sz_copy_serial(target, source, length); } - // When dealing wirh larger arrays, the optimization is not as simple as with the `sz_fill_haswell` function, + // When dealing with larger arrays, the optimization is not as simple as with the `sz_fill_haswell` function, // as both buffers may be unaligned. If we are lucky and the requested operation is some huge page transfer, // we can use aligned loads and stores, and the performance will be great. else if ((sz_size_t)target % 32 == 0 && (sz_size_t)source % 32 == 0 && !is_huge) { @@ -411,7 +411,7 @@ SZ_PUBLIC void sz_copy_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length) if (head_length & 16) _mm_store_si128((__m128i *)target, _mm_lddqu_si128((__m128i const *)source)), target += 16, source += 16, head_length -= 16; - sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); + _sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); // Fill the aligned body of the buffer. if (!is_huge) { @@ -429,7 +429,7 @@ SZ_PUBLIC void sz_copy_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length) } // Fill the tail of the buffer. This part is much cleaner with AVX-512. 
- sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); + _sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); if (tail_length & 16) _mm_store_si128((__m128i *)target, _mm_lddqu_si128((__m128i const *)source)), target += 16, source += 16, tail_length -= 16; @@ -440,7 +440,7 @@ SZ_PUBLIC void sz_copy_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length) } } -SZ_PUBLIC void sz_move_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { +SZ_PUBLIC void sz_move_haswell(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { if (target < source || target >= source + length) { for (; length >= 32; target += 32, source += 32, length -= 32) _mm256_storeu_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source)); @@ -454,7 +454,7 @@ SZ_PUBLIC void sz_move_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length) } } -SZ_PUBLIC void sz_look_up_transform_avx2(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { +SZ_PUBLIC void sz_look_up_transform_haswell(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { // If the input is tiny (especially smaller than the look-up table itself), we may end up paying // more for organizing the SIMD registers and changing the CPU state, than for the actual computation. @@ -637,7 +637,7 @@ SZ_PUBLIC void sz_fill_skylake(sz_ptr_t target, sz_size_t length, sz_u8_t value) } } -SZ_PUBLIC void sz_copy_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { +SZ_PUBLIC void sz_copy_skylake(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { // The naive implementation of this function is very simple. // It assumes the CPU is great at handling unaligned "stores" and "loads". // @@ -656,7 +656,7 @@ SZ_PUBLIC void sz_copy_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t lengt __mmask64 mask = _sz_u64_mask_until(length); _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source)); } - // When dealing wirh larger arrays, the optimization is not as simple as with the `sz_fill_skylake` function, + // When dealing with larger arrays, the optimization is not as simple as with the `sz_fill_skylake` function, // as both buffers may be unaligned. If we are lucky and the requested operation is some huge page transfer, // we can use aligned loads and stores, and the performance will be great. else if ((sz_size_t)target % 64 == 0 && (sz_size_t)source % 64 == 0 && !is_huge) { @@ -715,7 +715,7 @@ SZ_PUBLIC void sz_copy_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t lengt } } -SZ_PUBLIC void sz_move_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { +SZ_PUBLIC void sz_move_skylake(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { if (target == source) return; // Don't be silly, don't move the data if it's already there. // On very short buffers, that are one cache line in width or less, we don't need any loops. @@ -757,7 +757,7 @@ SZ_PUBLIC void sz_move_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t lengt } // If the regions don't overlap at all, just use "copy" and save some brain cells thinking about corner cases. 
- else if (target + length < source || target >= source + length) { sz_copy_avx512(target, source, length); } + else if (target + length < source || target >= source + length) { sz_copy_skylake(target, source, length); } // When the buffer is over 64 bytes, it's guaranteed to touch at least two cache lines - the head and tail, // and may include more cache-lines in-between. Knowing this, we can avoid expensive unaligned stores @@ -1257,9 +1257,9 @@ SZ_PUBLIC void sz_copy_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length) SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { #if SZ_USE_ICE - sz_copy_avx512(target, source, length); + sz_copy_skylake(target, source, length); #elif SZ_USE_HASWELL - sz_copy_avx2(target, source, length); + sz_copy_haswell(target, source, length); #elif SZ_USE_NEON sz_copy_neon(target, source, length); #else @@ -1269,9 +1269,9 @@ SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { SZ_DYNAMIC void sz_move(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { #if SZ_USE_ICE - sz_move_avx512(target, source, length); + sz_move_skylake(target, source, length); #elif SZ_USE_HASWELL - sz_move_avx2(target, source, length); + sz_move_haswell(target, source, length); #elif SZ_USE_NEON sz_move_neon(target, source, length); #else diff --git a/include/stringzilla/similarity.h b/include/stringzilla/similarity.h index 5451c95f..943f7f35 100644 --- a/include/stringzilla/similarity.h +++ b/include/stringzilla/similarity.h @@ -16,6 +16,7 @@ #ifndef STRINGZILLA_SIMILARITY_H_ #define STRINGZILLA_SIMILARITY_H_ +#include "find.h" #include "types.h" #ifdef __cplusplus @@ -183,6 +184,20 @@ SZ_PUBLIC sz_ssize_t sz_alignment_score_serial( // sz_error_cost_t const *subs, sz_error_cost_t gap, // sz_memory_allocator_t *alloc); +#if SZ_USE_ICE + +SZ_INTERNAL sz_size_t sz_edit_distance_ice( // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_cptr_t longer, sz_size_t longer_length, // + sz_size_t bound, sz_memory_allocator_t *alloc); + +SZ_INTERNAL sz_ssize_t sz_alignment_score_ice( // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_cptr_t longer, sz_size_t longer_length, // + sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc); + +#endif + #pragma endregion // Core API #pragma region Serial Implementation @@ -200,8 +215,8 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_serial( // } // TODO: Generalize to remove the following asserts! - sz_assert(!bound && "For bounded search the method should only evaluate one band of the matrix."); - sz_assert(shorter_length == longer_length && "The method hasn't been generalized to different length inputs yet."); + _sz_assert(!bound && "For bounded search the method should only evaluate one band of the matrix."); + _sz_assert(shorter_length == longer_length && "The method hasn't been generalized to different length inputs yet."); sz_unused(longer_length && bound); // We are going to store 3 diagonals of the matrix. @@ -269,7 +284,8 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_serial( // * Stores only 2 rows of the Levenshtein matrix, but uses 64-bit integers for the distance values, * and upcasts UTF8 variable-length codepoints to 64-bit integers for faster addressing. * - * ! In the worst case for 2 strings of length 100, that contain just one 16-bit codepoint this will result in extra: + * ! 
In the worst case for 2 strings of length 100, that contain just one 16-bit codepoint this will result in + * extra: * + 2 rows * 100 slots * 8 bytes/slot = 1600 bytes of memory for the two rows of the Levenshtein matrix rows. * + 100 codepoints * 2 strings * 4 bytes/codepoint = 800 bytes of memory for the UTF8 buffer. * = 2400 bytes of memory or @b 12x memory amplification! @@ -302,10 +318,13 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_wagner_fisher_serial( // // If the strings contain Unicode characters, let's estimate the max character width, // and use it to allocate a larger buffer to decode UTF8. - if ((can_be_unicode == sz_true_k) && - (sz_isascii(longer, longer_length) == sz_false_k || sz_isascii(shorter, shorter_length) == sz_false_k)) { - buffer_length += (shorter_length + longer_length) * sizeof(sz_rune_t); - } + sz_charset_t ascii_charset; + sz_charset_init_ascii(&ascii_charset); + sz_charset_invert(&ascii_charset); + int const longer_is_ascii = sz_find_charset_serial(longer, longer_length, &ascii_charset) == SZ_NULL_CHAR; + int const shorter_is_ascii = sz_find_charset_serial(shorter, shorter_length, &ascii_charset) == SZ_NULL_CHAR; + int const will_convert_to_unicode = can_be_unicode == sz_true_k && (!longer_is_ascii || !shorter_is_ascii); + if (will_convert_to_unicode) { buffer_length += (shorter_length + longer_length) * sizeof(sz_rune_t); } else { can_be_unicode = sz_false_k; } // If the allocation fails, return the maximum distance. @@ -619,19 +638,19 @@ SZ_PUBLIC sz_size_t sz_hamming_distance_utf8_serial( // /** * @brief Computes the edit distance between two very short byte-strings using the AVX-512VBMI extensions. * - * Applies to string lengths up to 63, and evaluates at most (63 * 2 + 1 = 127) diagonals, or just as many loop cycles. - * Supports an early exit, if the distance is bounded. - * Keeps all of the data and Levenshtein matrices skew diagonal in just a couple of registers. - * Benefits from the @b `vpermb` instructions, that can rotate the bytes across the entire ZMM register. + * Applies to string lengths up to 63, and evaluates at most (63 * 2 + 1 = 127) diagonals, or just as many loop + * cycles. Supports an early exit, if the distance is bounded. Keeps all of the data and Levenshtein matrices skew + * diagonal in just a couple of registers. Benefits from the @b `vpermb` instructions, that can rotate the bytes + * across the entire ZMM register. */ -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto63_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // +SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto63_ice( // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_cptr_t longer, sz_size_t longer_length, // sz_size_t bound) { sz_size_t const max_length = 63u; - sz_assert(shorter_length <= longer_length && "The 'shorter' string is longer than the 'longer' one."); - sz_assert(shorter_length < max_length && "The length must fit into 16-bit integer. Otherwise use serial variant."); + _sz_assert(shorter_length <= longer_length && "The 'shorter' string is longer than the 'longer' one."); + _sz_assert(shorter_length < max_length && "The length must fit into 16-bit integer. Otherwise use serial variant."); // We are going to store 3 diagonals of the matrix, assuming each would fit into a single ZMM register. // The length of the longest (main) diagonal would be `shorter_dim = (shorter_length + 1)`. 
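For context on what these `_sz_edit_distance_skewed_diagonals_*` kernels vectorize: each anti-diagonal k = i + j of the Wagner-Fischer matrix depends only on the two previous diagonals, so three row-indexed buffers are enough. Below is a hedged plain-C sketch of that traversal, a standalone illustration that is not part of the patch, with allocation error handling omitted:

    #include <stddef.h>
    #include <stdlib.h>

    size_t levenshtein_by_diagonals_sketch(char const *a, size_t m, char const *b, size_t n) {
        if (m == 0) return n;
        if (n == 0) return m;
        size_t *prev2 = (size_t *)malloc((m + 1) * sizeof(size_t)); // diagonal k - 2
        size_t *prev = (size_t *)malloc((m + 1) * sizeof(size_t));  // diagonal k - 1
        size_t *cur = (size_t *)malloc((m + 1) * sizeof(size_t));   // diagonal k
        prev2[0] = 0;             // D[0][0]
        prev[0] = 1, prev[1] = 1; // D[0][1] and D[1][0]
        for (size_t k = 2; k <= m + n; ++k) {
            size_t i_min = k > n ? k - n : 0, i_max = k < m ? k : m;
            for (size_t i = i_min; i <= i_max; ++i) {
                size_t j = k - i; // the cell (i, j) sits on diagonal k
                if (i == 0) { cur[i] = j; }      // first row: D[0][j] = j
                else if (j == 0) { cur[i] = i; } // first column: D[i][0] = i
                else {
                    size_t deletion = prev[i - 1] + 1;
                    size_t insertion = prev[i] + 1;
                    size_t substitution = prev2[i - 1] + (a[i - 1] != b[j - 1]);
                    size_t best = deletion < insertion ? deletion : insertion;
                    cur[i] = best < substitution ? best : substitution;
                }
            }
            size_t *rotated = prev2;
            prev2 = prev, prev = cur, cur = rotated; // rotate the three diagonal buffers
        }
        size_t result = prev[m]; // after the last rotation, diagonal m + n sits in `prev`
        free(prev2), free(prev), free(cur);
        return result;
    }

The SIMD variants follow the same dependency structure, but keep an entire diagonal in a ZMM register and update all of its cells in one step, rotating the string bytes with `vpermb` instead of re-indexing the buffers.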
@@ -792,9 +811,9 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto63_avx512( // * - source code analysis, assuming most lines are either under 80 or under 120 characters long. * - DNA sequence alignment, as most short reads are 50-300 characters long. */ -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto127_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // +SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto127_ice( // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_cptr_t longer, sz_size_t longer_length, // sz_size_t bound) { sz_unused(shorter && shorter_length && longer && longer_length && bound); return 0; @@ -812,9 +831,9 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto127_avx512( // * This is the largest space-efficient variant, as strings beyond 255 characters may require * 16-bit accumulators, which would be a significant bottleneck. */ -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // +SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto_ice( // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_cptr_t longer, sz_size_t longer_length, // sz_size_t bound) { sz_unused(shorter && shorter_length && longer && longer_length && bound); return 0; @@ -833,9 +852,9 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto_avx512( // * This is the largest space-efficient variant, as strings beyond 255 characters may require * 16-bit accumulators, which would be a significant bottleneck. */ -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto255bound_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // +SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto255bound_ice( // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_cptr_t longer, sz_size_t longer_length, // sz_size_t bound) { sz_unused(shorter && shorter_length && longer && longer_length && bound); return 0; @@ -850,17 +869,17 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto255bound_avx512( // * * Each string is unpacked into 128 characters * 4 bytes per character / 64 bytes per register = 8 registers. */ -SZ_INTERNAL sz_size_t _sz_edit_distance_utf8_skewed_diagonals_upto127_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // +SZ_INTERNAL sz_size_t _sz_edit_distance_utf8_skewed_diagonals_upto127_ice( // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_cptr_t longer, sz_size_t longer_length, // sz_size_t bound) { sz_unused(shorter && shorter_length && longer && longer_length && bound); return 0; } -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto65k_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // +SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto65k_ice( // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_cptr_t longer, sz_size_t longer_length, // sz_size_t bound, sz_memory_allocator_t *alloc) { sz_unused(shorter && longer && bound && alloc); @@ -874,8 +893,8 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto65k_avx512( // // TODO: Generalize! 
sz_size_t const max_length = 256u * 256u; - sz_assert(shorter_length <= longer_length && "The 'shorter' string is longer than the 'longer' one."); - sz_assert(shorter_length < max_length && "The length must fit into 16-bit integer. Otherwise use serial variant."); + _sz_assert(shorter_length <= longer_length && "The 'shorter' string is longer than the 'longer' one."); + _sz_assert(shorter_length < max_length && "The length must fit into 16-bit integer. Otherwise use serial variant."); sz_unused(longer_length && bound && max_length); #if 0 @@ -1017,7 +1036,7 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto65k_avx512( // return 0; } -SZ_INTERNAL sz_size_t sz_edit_distance_avx512( // +SZ_INTERNAL sz_size_t sz_edit_distance_ice( // sz_cptr_t shorter, sz_size_t shorter_length, // sz_cptr_t longer, sz_size_t longer_length, // sz_size_t bound, sz_memory_allocator_t *alloc) { @@ -1044,10 +1063,10 @@ SZ_INTERNAL sz_size_t sz_edit_distance_avx512( // // Dispatch the right implementation based on the length of the strings. if (longer_length < 64u) - return _sz_edit_distance_skewed_diagonals_upto63_avx512( // + return _sz_edit_distance_skewed_diagonals_upto63_ice( // shorter, shorter_length, longer, longer_length, bound); // else if (longer_length < 256u * 256u) - // return _sz_edit_distance_skewed_diagonals_upto65k_avx512( // + // return _sz_edit_distance_skewed_diagonals_upto65k_ice( // // shorter, shorter_length, longer, longer_length, bound, alloc); else return sz_edit_distance_serial(shorter, shorter_length, longer, longer_length, bound, alloc); @@ -1061,9 +1080,9 @@ SZ_INTERNAL sz_size_t sz_edit_distance_avx512( // * * Unlike the `_sz_edit_distance_skewed_diagonals_upto65k_avx512` method, this one uses signed integers to store * the accumulated score. Moreover, it's primary bottleneck is the latency of gathering the substitution costs - * from the substitution matrix. If we use the diagonal order, we will be comparing a slice of the first string with - * a slice of the second. If we stick to the conventional horizontal order, we will be comparing one character against - * a slice, which is much easier to optimize. In that case we are sampling costs not from arbitrary parts of + * from the substitution matrix. If we use the diagonal order, we will be comparing a slice of the first string + * with a slice of the second. If we stick to the conventional horizontal order, we will be comparing one character + * against a slice, which is much easier to optimize. In that case we are sampling costs not from arbitrary parts of * a 256 x 256 matrix, but from a single row! */ SZ_INTERNAL sz_ssize_t _sz_alignment_score_wagner_fisher_upto17m_ice( // @@ -1091,7 +1110,7 @@ SZ_INTERNAL sz_ssize_t _sz_alignment_score_wagner_fisher_upto17m_ice( // sz_size_t const max_length = 256ull * 256ull * 256ull; sz_size_t const n = longer_length + 1; - sz_assert(n < max_length && "The length must fit into 24-bit integer. Otherwise use serial variant."); + _sz_assert(n < max_length && "The length must fit into 24-bit integer. Otherwise use serial variant."); sz_unused(longer_length && max_length); sz_size_t buffer_length = sizeof(sz_i32_t) * n * 2; @@ -1099,7 +1118,7 @@ SZ_INTERNAL sz_ssize_t _sz_alignment_score_wagner_fisher_upto17m_ice( // sz_i32_t *previous_distances = distances; sz_i32_t *current_distances = previous_distances + n; - // Intialize the first row of the Levenshtein matrix with `iota`. + // Initialize the first row of the Levenshtein matrix with `iota`. 
for (sz_size_t idx_longer = 0; idx_longer != n; ++idx_longer) previous_distances[idx_longer] = (sz_i32_t)idx_longer * gap; @@ -1135,8 +1154,9 @@ SZ_INTERNAL sz_ssize_t _sz_alignment_score_wagner_fisher_upto17m_ice( // // for (sz_size_t idx_longer = 0; idx_longer < longer_length; ++idx_longer) { // sz_ssize_t cost_deletion = previous_distances[idx_longer + 1] + gap; // sz_ssize_t cost_insertion = current_distances[idx_longer] + gap; - // sz_ssize_t cost_substitution = previous_distances[idx_longer] + row_subs[longer_unsigned[idx_longer]]; - // current_distances[idx_longer + 1] = sz_min_of_three(cost_deletion, cost_insertion, cost_substitution); + // sz_ssize_t cost_substitution = previous_distances[idx_longer] + + // row_subs[longer_unsigned[idx_longer]]; current_distances[idx_longer + 1] = + // sz_min_of_three(cost_deletion, cost_insertion, cost_substitution); // } // // Given the complexity of handling the data-dependency between consecutive insertion cost computations @@ -1201,9 +1221,10 @@ SZ_INTERNAL sz_ssize_t _sz_alignment_score_wagner_fisher_upto17m_ice( // // "experimental" section. // // Another approach might be loop unrolling: - // current_vec.i32s[0] = last_in_row = sz_i32_max_of_two(current_vec.i32s[0], last_in_row + gap); - // current_vec.i32s[1] = last_in_row = sz_i32_max_of_two(current_vec.i32s[1], last_in_row + gap); - // current_vec.i32s[2] = last_in_row = sz_i32_max_of_two(current_vec.i32s[2], last_in_row + gap); + // current_vec.i32s[0] = last_in_row = sz_i32_max_of_two(current_vec.i32s[0], last_in_row + + // gap); current_vec.i32s[1] = last_in_row = sz_i32_max_of_two(current_vec.i32s[1], last_in_row + // + gap); current_vec.i32s[2] = last_in_row = sz_i32_max_of_two(current_vec.i32s[2], + // last_in_row + gap); // ... yet this approach is also quite expensive. for (int i = 0; i != 16; ++i) current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap); @@ -1345,7 +1366,7 @@ SZ_DYNAMIC sz_size_t sz_edit_distance( // sz_cptr_t b, sz_size_t b_length, // sz_size_t bound, sz_memory_allocator_t *alloc) { #if SZ_USE_ICE - return sz_edit_distance_avx512(a, a_length, b, b_length, bound, alloc); + return sz_edit_distance_ice(a, a_length, b, b_length, bound, alloc); #else return sz_edit_distance_serial(a, a_length, b, b_length, bound, alloc); #endif diff --git a/include/stringzilla/small_string.h b/include/stringzilla/small_string.h index ba823901..c5c70773 100644 --- a/include/stringzilla/small_string.h +++ b/include/stringzilla/small_string.h @@ -261,7 +261,7 @@ SZ_PUBLIC sz_ordering_t sz_string_order(sz_string_t const *a, sz_string_t const } SZ_PUBLIC void sz_string_init(sz_string_t *string) { - sz_assert(string && "String can't be SZ_NULL."); + _sz_assert(string && "String can't be SZ_NULL."); // Only 8 + 1 + 1 need to be initialized. string->internal.start = &string->internal.chars[0]; @@ -275,7 +275,7 @@ SZ_PUBLIC void sz_string_init(sz_string_t *string) { SZ_PUBLIC sz_ptr_t sz_string_init_length(sz_string_t *string, sz_size_t length, sz_memory_allocator_t *allocator) { sz_size_t space_needed = length + 1; // space for trailing \0 - sz_assert(string && allocator && "String and allocator can't be SZ_NULL."); + _sz_assert(string && allocator && "String and allocator can't be SZ_NULL."); // Initialize the string to zeros for safety. 
string->words[1] = 0; string->words[2] = 0; @@ -292,14 +292,14 @@ SZ_PUBLIC sz_ptr_t sz_string_init_length(sz_string_t *string, sz_size_t length, string->external.length = length; string->external.space = space_needed; } - sz_assert(&string->internal.start == &string->external.start && "Alignment confusion"); + _sz_assert(&string->internal.start == &string->external.start && "Alignment confusion"); string->external.start[length] = 0; return string->external.start; } SZ_PUBLIC sz_ptr_t sz_string_reserve(sz_string_t *string, sz_size_t new_capacity, sz_memory_allocator_t *allocator) { - sz_assert(string && allocator && "Strings and allocators can't be SZ_NULL."); + _sz_assert(string && allocator && "Strings and allocators can't be SZ_NULL."); sz_size_t new_space = new_capacity + 1; if (new_space <= _SZ_STRING_INTERNAL_SPACE) return string->external.start; @@ -309,7 +309,7 @@ SZ_PUBLIC sz_ptr_t sz_string_reserve(sz_string_t *string, sz_size_t new_capacity sz_size_t string_space; sz_bool_t string_is_external; sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - sz_assert(new_space > string_space && "New space must be larger than current."); + _sz_assert(new_space > string_space && "New space must be larger than current."); sz_ptr_t new_start = (sz_ptr_t)allocator->allocate(new_space, allocator->handle); if (!new_start) return SZ_NULL_CHAR; @@ -327,7 +327,7 @@ SZ_PUBLIC sz_ptr_t sz_string_reserve(sz_string_t *string, sz_size_t new_capacity SZ_PUBLIC sz_ptr_t sz_string_shrink_to_fit(sz_string_t *string, sz_memory_allocator_t *allocator) { - sz_assert(string && allocator && "Strings and allocators can't be SZ_NULL."); + _sz_assert(string && allocator && "Strings and allocators can't be SZ_NULL."); sz_ptr_t string_start; sz_size_t string_length; @@ -356,7 +356,7 @@ SZ_PUBLIC sz_ptr_t sz_string_shrink_to_fit(sz_string_t *string, sz_memory_alloca SZ_PUBLIC sz_ptr_t sz_string_expand( // sz_string_t *string, sz_size_t offset, sz_size_t added_length, sz_memory_allocator_t *allocator) { - sz_assert(string && allocator && "String and allocator can't be SZ_NULL."); + _sz_assert(string && allocator && "String and allocator can't be SZ_NULL."); sz_ptr_t string_start; sz_size_t string_length; @@ -393,7 +393,7 @@ SZ_PUBLIC sz_ptr_t sz_string_expand( // SZ_PUBLIC sz_size_t sz_string_erase(sz_string_t *string, sz_size_t offset, sz_size_t length) { - sz_assert(string && "String can't be SZ_NULL."); + _sz_assert(string && "String can't be SZ_NULL."); sz_ptr_t string_start; sz_size_t string_length; diff --git a/include/stringzilla/stringzilla.h b/include/stringzilla/stringzilla.h index c0b1b369..ba36975c 100644 --- a/include/stringzilla/stringzilla.h +++ b/include/stringzilla/stringzilla.h @@ -41,6 +41,15 @@ #define STRINGZILLA_VERSION_MINOR 11 #define STRINGZILLA_VERSION_PATCH 0 +#include "compare.h" // `sz_equal`, `sz_order` +#include "find.h" // `sz_find`, `sz_find_charset`, `sz_rfind` +#include "hash.h" // `sz_checksum`, `sz_hash`, `sz_hashes` +#include "memory.h" // `sz_copy`, `sz_move`, `sz_fill` +#include "similarity.h" // `sz_edit_distance`, `sz_alignment_score` +#include "small_string.h" // `sz_string_t`, `sz_string_init`, `sz_string_free` +#include "sort.h" // `sz_sort`, `sz_sort_partial`, `sz_partition` +#include "types.h" // `sz_size_t`, `sz_bool_t`, `sz_ordering_t` + #ifdef __cplusplus extern "C" { #endif @@ -49,20 +58,18 @@ extern "C" { * @brief Enumeration of SIMD capabilities of the target architecture. 
* Used to introspect the supported functionality of the dynamic library. */ -typedef enum sz_capability_t { - sz_cap_serial_k = 1, /// Serial (non-SIMD) capability - sz_cap_any_k = 0x7FFFFFFF, /// Mask representing any capability +typedef enum { + sz_cap_serial_k = 1, ///< Serial (non-SIMD) capability + sz_cap_any_k = 0x7FFFFFFF, ///< Mask representing any capability with `INT_MAX` - sz_cap_arm_neon_k = 1 << 10, /// ARM NEON capability - sz_cap_arm_sve_k = 1 << 11, /// ARM SVE capability TODO: Not yet supported or used - sz_cap_arm_sve2_k = 1 << 12, - sz_cap_arm_sve2p1_k = 1 << 13, - sz_cap_x86_avx2_k = 1 << 20, /// x86 AVX2 capability - sz_cap_x86_avx512f_k = 1 << 21, /// x86 AVX512 F capability - sz_cap_x86_avx512bw_k = 1 << 22, /// x86 AVX512 BW instruction capability - sz_cap_x86_avx512vl_k = 1 << 23, /// x86 AVX512 VL instruction capability - sz_cap_x86_avx512vbmi_k = 1 << 24, /// x86 AVX512 VBMI instruction capability - sz_cap_x86_gfni_k = 1 << 25, /// x86 AVX512 GFNI instruction capability + sz_cap_haswell_k = 1 << 10, ///< x86 AVX2 capability with FMA and F16C extensions + sz_cap_skylake_k = 1 << 11, ///< x86 AVX512 baseline capability + sz_cap_ice_k = 1 << 12, ///< x86 AVX512 capability with advanced integer algos + + sz_cap_neon_k = 1 << 20, ///< ARM NEON baseline capability + sz_cap_sve_k = 1 << 21, ///< ARM SVE baseline capability + sz_cap_sve2_k = 1 << 22, ///< ARM SVE2 capability + sz_cap_sve2p1_k = 1 << 23, ///< ARM SVE2p1 capability } sz_capability_t; @@ -72,6425 +79,7 @@ typedef enum sz_capability_t { */ SZ_DYNAMIC sz_capability_t sz_capabilities(void); -/** - * @brief Checks if two string are equal. - * Similar to `memcmp(a, b, length) == 0` in LibC and `a == b` in STL. - * - * The implementation of this function is very similar to `sz_order`, but the usage patterns are different. - * This function is more often used in parsing, while `sz_order` is often used in sorting. - * It works best on platforms with cheap - * - * @param a First string to compare. - * @param b Second string to compare. - * @param length Number of bytes in both strings. - * @return 1 if strings match, 0 otherwise. - */ -SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length); - -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_serial(sz_cptr_t a, sz_cptr_t b, sz_size_t length); - -/** - * @brief Estimates the relative order of two strings. Equivalent to `memcmp(a, b, length)` in LibC. - * Can be used on different length strings. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * @return Negative if (a < b), positive if (a > b), zero if they are equal. - */ -SZ_DYNAMIC sz_ordering_t sz_order(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); - -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); - -/** - * @brief Initializes a string class instance to an empty value. - */ -SZ_PUBLIC void sz_string_init(sz_string_t *string); - -/** - * @brief Convenience function checking if the provided string is stored inside of the ::string instance itself, - * alternative being - allocated in a remote region of the heap. - */ -SZ_PUBLIC sz_bool_t sz_string_is_on_stack(sz_string_t const *string); - -/** - * @brief Unpacks the opaque instance of a string class into its components. - * Recommended to use only in read-only operations. 
- * - * @param string String to unpack. - * @param start Pointer to the start of the string. - * @param length Number of bytes in the string, before the SZ_NULL character. - * @param space Number of bytes allocated for the string (heap or stack), including the SZ_NULL character. - * @param is_external Whether the string is allocated on the heap externally, or fits withing ::string instance. - */ -SZ_PUBLIC void sz_string_unpack(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length, sz_size_t *space, - sz_bool_t *is_external); - -/** - * @brief Unpacks only the start and length of the string. - * Recommended to use only in read-only operations. - * - * @param string String to unpack. - * @param start Pointer to the start of the string. - * @param length Number of bytes in the string, before the SZ_NULL character. - */ -SZ_PUBLIC void sz_string_range(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length); - -/** - * @brief Constructs a string of a given ::length with noisy contents. - * Use the returned character pointer to populate the string. - * - * @param string String to initialize. - * @param length Number of bytes in the string, before the SZ_NULL character. - * @param allocator Memory allocator to use for the allocation. - * @return SZ_NULL if the operation failed, pointer to the start of the string otherwise. - */ -SZ_PUBLIC sz_ptr_t sz_string_init_length(sz_string_t *string, sz_size_t length, sz_memory_allocator_t *allocator); - -/** - * @brief Doesn't change the contents or the length of the string, but grows the available memory capacity. - * This is beneficial, if several insertions are expected, and we want to minimize allocations. - * - * @param string String to grow. - * @param new_capacity The number of characters to reserve space for, including existing ones. - * @param allocator Memory allocator to use for the allocation. - * @return SZ_NULL if the operation failed, pointer to the new start of the string otherwise. - */ -SZ_PUBLIC sz_ptr_t sz_string_reserve(sz_string_t *string, sz_size_t new_capacity, sz_memory_allocator_t *allocator); - -/** - * @brief Grows the string by adding an uninitialized region of ::added_length at the given ::offset. - * Would often be used in conjunction with one or more `sz_copy` calls to populate the allocated region. - * Similar to `sz_string_reserve`, but changes the length of the ::string. - * - * @param string String to grow. - * @param offset Offset of the first byte to reserve space for. - * If provided offset is larger than the length, it will be capped. - * @param added_length The number of new characters to reserve space for. - * @param allocator Memory allocator to use for the allocation. - * @return SZ_NULL if the operation failed, pointer to the new start of the string otherwise. - */ -SZ_PUBLIC sz_ptr_t sz_string_expand(sz_string_t *string, sz_size_t offset, sz_size_t added_length, - sz_memory_allocator_t *allocator); - -/** - * @brief Removes a range from a string. Changes the length, but not the capacity. - * Performs no allocations or deallocations and can't fail. - * - * @param string String to clean. - * @param offset Offset of the first byte to remove. - * @param length Number of bytes to remove. Out-of-bound ranges will be capped. - * @return Number of bytes removed. - */ -SZ_PUBLIC sz_size_t sz_string_erase(sz_string_t *string, sz_size_t offset, sz_size_t length); - -/** - * @brief Shrinks the string to fit the current length, if it's allocated on the heap. 
- * It's the reverse operation of ::sz_string_reserve. - * - * @param string String to shrink. - * @param allocator Memory allocator to use for the allocation. - * @return Whether the operation was successful. The only failures can come from the allocator. - * On failure, the string will remain unchanged. - */ -SZ_PUBLIC sz_ptr_t sz_string_shrink_to_fit(sz_string_t *string, sz_memory_allocator_t *allocator); - -/** - * @brief Frees the string, if it's allocated on the heap. - * If the string is on the stack, the function clears/resets the state. - */ -SZ_PUBLIC void sz_string_free(sz_string_t *string, sz_memory_allocator_t *allocator); - -#pragma endregion - -#pragma region Fast Substring Search API - -typedef sz_cptr_t (*sz_find_byte_t)(sz_cptr_t, sz_size_t, sz_cptr_t); -typedef sz_cptr_t (*sz_find_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); -typedef sz_cptr_t (*sz_find_set_t)(sz_cptr_t, sz_size_t, sz_charset_t const *); - -/** - * @brief Locates first matching byte in a string. Equivalent to `memchr(haystack, *needle, h_length)` in LibC. - * - * X86_64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/x86_64/memchr.S - * Aarch64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/aarch64/memchr.S - * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - single-byte substring to find. - * @return Address of the first match. - */ -SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); - -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); - -/** - * @brief Locates last matching byte in a string. Equivalent to `memrchr(haystack, *needle, h_length)` in LibC. - * - * X86_64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/x86_64/memrchr.S - * Aarch64 implementation: missing - * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - single-byte substring to find. - * @return Address of the last match. - */ -SZ_DYNAMIC sz_cptr_t sz_rfind_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); - -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); - -/** - * @brief Locates first matching substring. - * Equivalent to `memmem(haystack, h_length, needle, n_length)` in LibC. - * Similar to `strstr(haystack, needle)` in LibC, but requires known length. - * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - substring to find. - * @param n_length Number of bytes in the needle. - * @return Address of the first match. - */ -SZ_DYNAMIC sz_cptr_t sz_find(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); - -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); - -/** - * @brief Locates the last matching substring. - * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - substring to find. - * @param n_length Number of bytes in the needle. - * @return Address of the last match. 
- */ -SZ_DYNAMIC sz_cptr_t sz_rfind(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); - -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); - -/** - * @brief Finds the first character present from the ::set, present in ::text. - * Equivalent to `strspn(text, accepted)` and `strcspn(text, rejected)` in LibC. - * May have identical implementation and performance to ::sz_rfind_charset. - * - * Useful for parsing, when we want to skip a set of characters. Examples: - * * 6 whitespaces: " \t\n\r\v\f". - * * 16 digits forming a float number: "0123456789,.eE+-". - * * 5 HTML reserved characters: "\"'&<>", of which "<>" can be useful for parsing. - * * 2 JSON string special characters useful to locate the end of the string: "\"\\". - * - * @param text String to be scanned. - * @param set Set of relevant characters. - * @return Pointer to the first matching character from ::set. - */ -SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); - -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); - -/** - * @brief Finds the last character present from the ::set, present in ::text. - * Equivalent to `strspn(text, accepted)` and `strcspn(text, rejected)` in LibC. - * May have identical implementation and performance to ::sz_find_charset. - * - * Useful for parsing, when we want to skip a set of characters. Examples: - * * 6 whitespaces: " \t\n\r\v\f". - * * 16 digits forming a float number: "0123456789,.eE+-". - * * 5 HTML reserved characters: "\"'&<>", of which "<>" can be useful for parsing. - * * 2 JSON string special characters useful to locate the end of the string: "\"\\". - * - * @param text String to be scanned. - * @param set Set of relevant characters. - * @return Pointer to the last matching character from ::set. - */ -SZ_DYNAMIC sz_cptr_t sz_rfind_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); - -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); - -#pragma endregion - -#pragma region String Similarity Measures API - -/** - * @brief Computes the Hamming distance between two strings - number of not matching characters. - * Difference in length is is counted as a mismatch. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * - * @param bound Upper bound on the distance, that allows us to exit early. - * If zero is passed, the maximum possible distance will be equal to the length of the longer input. - * @return Unsigned integer for the distance, the `bound` if was exceeded. - * - * @see sz_hamming_distance_utf8 - * @see https://en.wikipedia.org/wiki/Hamming_distance - */ -SZ_DYNAMIC sz_size_t sz_hamming_distance( // - sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound); - -/** @copydoc sz_hamming_distance */ -SZ_PUBLIC sz_size_t sz_hamming_distance_serial( // - sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound); - -/** - * @brief Computes the Hamming distance between two @b UTF8 strings - number of not matching characters. - * Difference in length is is counted as a mismatch. - * - * @param a First string to compare. 
- * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * - * @param bound Upper bound on the distance, that allows us to exit early. - * If zero is passed, the maximum possible distance will be equal to the length of the longer input. - * @return Unsigned integer for the distance, the `bound` if was exceeded. - * - * @see sz_hamming_distance - * @see https://en.wikipedia.org/wiki/Hamming_distance - */ -SZ_DYNAMIC sz_size_t sz_hamming_distance_utf8(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, - sz_size_t bound); - -/** @copydoc sz_hamming_distance_utf8 */ -SZ_PUBLIC sz_size_t sz_hamming_distance_utf8_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, - sz_size_t bound); - -typedef sz_size_t (*sz_hamming_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t); - -/** - * @brief Computes the Levenshtein edit-distance between two strings using the Wagner-Fisher algorithm. - * Similar to the Needleman-Wunsch alignment algorithm. Often used in fuzzy string matching. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * - * @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated, - * so the memory usage is linear in relation to ::a_length and ::b_length. - * If SZ_NULL is passed, will initialize to the systems default `malloc`. - * @param bound Exclusive upper bound on the distance, that allows us to exit early. - * Pass `SZ_SIZE_MAX` or any value greater than `(max(a_length, b_length))` to ignore. - * Pass zero to check if the strings are equal. - * @return Unsigned integer for the edit distance. Zero means the strings are equal. - * Returns the `bound` if it was exceeded or `SZ_SIZE_MAX` if the memory allocation failed. - * - * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default - * @see https://en.wikipedia.org/wiki/Levenshtein_distance - */ -SZ_DYNAMIC sz_size_t sz_edit_distance(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); - -/** @copydoc sz_edit_distance */ -SZ_PUBLIC sz_size_t sz_edit_distance_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); - -/** - * @brief Computes the Levenshtein edit-distance between two @b UTF8 strings. - * Unlike `sz_edit_distance`, reports the distance in Unicode codepoints, and not in bytes. - * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * - * @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated, - * so the memory usage is linear in relation to ::a_length and ::b_length. - * If SZ_NULL is passed, will initialize to the systems default `malloc`. - * @param bound Upper bound on the distance, that allows us to exit early. - * If zero is passed, the maximum possible distance will be equal to the length of the longer input. - * @return Unsigned integer for edit distance, the `bound` if was exceeded or `SZ_SIZE_MAX` - * if the memory allocation failed. 
- *
- * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default, sz_edit_distance
- * @see https://en.wikipedia.org/wiki/Levenshtein_distance
- */
-SZ_DYNAMIC sz_size_t sz_edit_distance_utf8(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, //
-    sz_size_t bound, sz_memory_allocator_t *alloc);
-
-typedef sz_size_t (*sz_edit_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t, sz_memory_allocator_t *);
-
-/** @copydoc sz_edit_distance_utf8 */
-SZ_PUBLIC sz_size_t sz_edit_distance_utf8_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, //
-    sz_size_t bound, sz_memory_allocator_t *alloc);
-
-/**
- * @brief Computes the Needleman–Wunsch alignment score for two strings. Often used in bioinformatics and cheminformatics.
- * Similar to the Levenshtein edit-distance, parameterized for gap and substitution penalties.
- *
- * Not commutative in the general case, as the order of the strings matters: `sz_alignment_score(a, b)` may
- * not be equal to `sz_alignment_score(b, a)`. Becomes @b commutative if the substitution costs are symmetric.
- * Equivalent to the negative Levenshtein distance if `gap == -1` and `subs[i][j] == (i == j ? 0: -1)`.
- *
- * @param a First string to compare.
- * @param a_length Number of bytes in the first string.
- * @param b Second string to compare.
- * @param b_length Number of bytes in the second string.
- * @param gap Penalty cost for gaps - insertions and removals.
- * @param subs Substitution costs matrix with 256 x 256 values for all pairs of characters.
- *
- * @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated,
- * so the memory usage is linear in relation to ::a_length and ::b_length.
- * If SZ_NULL is passed, will initialize to the system's default `malloc`.
- * @return Signed similarity score. Can be negative, depending on the substitution costs.
- * If the memory allocation fails, the function returns `SZ_SSIZE_MAX`.
- *
- * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default
- * @see https://en.wikipedia.org/wiki/Needleman%E2%80%93Wunsch_algorithm
- */
-SZ_DYNAMIC sz_ssize_t sz_alignment_score(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, //
-    sz_error_cost_t const *subs, sz_error_cost_t gap, //
-    sz_memory_allocator_t *alloc);
-
-/** @copydoc sz_alignment_score */
-SZ_PUBLIC sz_ssize_t sz_alignment_score_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, //
-    sz_error_cost_t const *subs, sz_error_cost_t gap, //
-    sz_memory_allocator_t *alloc);
-
-typedef sz_ssize_t (*sz_alignment_score_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_error_cost_t const *,
-    sz_error_cost_t, sz_memory_allocator_t *);
-
-typedef void (*sz_hash_callback_t)(sz_cptr_t, sz_size_t, sz_u64_t, void *user);
-
-/**
- * @brief Computes the Karp-Rabin rolling hashes of a string, supplying them to the provided `callback`.
- * Can be used for similarity scores, search, ranking, etc.
- *
- * Rabin-Karp-like rolling hashes can have a very high level of collisions and depend
- * on the choice of bases and the prime number. That's why often two hashes from the same
- * family are used with different bases.
- *
- * 1. Kernighan and Ritchie's function uses 31, a prime close to the size of the English alphabet.
- * 2. To be friendlier to byte-arrays and UTF8, we use 257 for the second function.
- *
- * Choosing the right ::window_length is task- and domain-dependent.
For example, most English words are - * between 3 and 7 characters long, so a window of 4 bytes would be a good choice. For DNA sequences, - * the ::window_length might be a multiple of 3, as the codons are 3 (nucleotides) bytes long. - * With such minimalistic alphabets of just four characters (AGCT) longer windows might be needed. - * For protein sequences the alphabet is 20 characters long, so the window can be shorter, than for DNAs. - * - * @param text String to hash. - * @param length Number of bytes in the string. - * @param window_length Length of the rolling window in bytes. - * @param window_step Step of reported hashes. @b Must be power of two. Should be smaller than `window_length`. - * @param callback Function receiving the start & length of a substring, the hash, and the `callback_handle`. - * @param callback_handle Optional user-provided pointer to be passed to the `callback`. - * @see sz_hashes_fingerprint, sz_hashes_intersection - */ -SZ_DYNAMIC void sz_hashes(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // - sz_hash_callback_t callback, void *callback_handle); - -/** @copydoc sz_hashes */ -SZ_PUBLIC void sz_hashes_serial(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // - sz_hash_callback_t callback, void *callback_handle); - -typedef void (*sz_hashes_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_size_t, sz_hash_callback_t, void *); - -/** - * @brief Computes the Karp-Rabin rolling hashes of a string outputting a binary fingerprint. - * Such fingerprints can be compared with Hamming or Jaccard (Tanimoto) distance for similarity. - * - * The algorithm doesn't clear the fingerprint buffer on start, so it can be invoked multiple times - * to produce a fingerprint of a longer string, by passing the previous fingerprint as the ::fingerprint. - * It can also be reused to produce multi-resolution fingerprints by changing the ::window_length - * and calling the same function multiple times for the same input ::text. - * - * Processes large strings in parts to maximize the cache utilization, using a small on-stack buffer, - * avoiding cache-coherency penalties of remote on-heap buffers. - * - * @param text String to hash. - * @param length Number of bytes in the string. - * @param fingerprint Output fingerprint buffer. - * @param fingerprint_bytes Number of bytes in the fingerprint buffer. - * @param window_length Length of the rolling window in bytes. - * @see sz_hashes, sz_hashes_intersection - */ -SZ_PUBLIC void sz_hashes_fingerprint( // - sz_cptr_t text, sz_size_t length, sz_size_t window_length, // - sz_ptr_t fingerprint, sz_size_t fingerprint_bytes); - -typedef void (*sz_hashes_fingerprint_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_ptr_t, sz_size_t); - -/** - * @brief Given a hash-fingerprint of a textual document, computes the number of intersecting hashes - * of the incoming document. Can be used for document scoring and search. - * - * Processes large strings in parts to maximize the cache utilization, using a small on-stack buffer, - * avoiding cache-coherency penalties of remote on-heap buffers. - * - * @param text Input document. - * @param length Number of bytes in the input document. - * @param fingerprint Reference document fingerprint. - * @param fingerprint_bytes Number of bytes in the reference documents fingerprint. - * @param window_length Length of the rolling window in bytes. 
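As a rough sketch of how the rolling-hash callback machinery described above might be driven: the sample sentence, the 4-byte window, the unit step, and the `count_even_hashes` helper are illustrative choices of this sketch, not something the header prescribes.

#include <stdio.h>
#include <stringzilla/stringzilla.h>

// Matches `sz_hash_callback_t`: receives the window start, its length, the rolling hash,
// and the user-provided handle - here a simple counter.
static void count_even_hashes(sz_cptr_t start, sz_size_t length, sz_u64_t hash, void *handle) {
    (void)start, (void)length;
    sz_size_t *counter = (sz_size_t *)handle;
    *counter += (hash & 1) == 0;
}

int main(void) {
    char const *text = "the quick brown fox jumps over the lazy dog";
    sz_size_t count = 0;
    // Hash every 4-byte window; a step of 1 is a power of two and smaller than the window.
    sz_hashes(text, 43, 4, 1, &count_even_hashes, &count);
    printf("windows with an even hash: %zu\n", (size_t)count);
    return 0;
}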
- * @see sz_hashes, sz_hashes_fingerprint - */ -SZ_PUBLIC sz_size_t sz_hashes_intersection( // - sz_cptr_t text, sz_size_t length, sz_size_t window_length, // - sz_cptr_t fingerprint, sz_size_t fingerprint_bytes); - -typedef sz_size_t (*sz_hashes_intersection_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_cptr_t, sz_size_t); - -#pragma endregion - -#pragma region Convenience API - -/** - * @brief Finds the first character in the haystack, that is present in the needle. - * Convenience function, reused across different language bindings. - * @see sz_find_charset - */ -SZ_DYNAMIC sz_cptr_t sz_find_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); - -/** - * @brief Finds the first character in the haystack, that is @b not present in the needle. - * Convenience function, reused across different language bindings. - * @see sz_find_charset - */ -SZ_DYNAMIC sz_cptr_t sz_find_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); - -/** - * @brief Finds the last character in the haystack, that is present in the needle. - * Convenience function, reused across different language bindings. - * @see sz_find_charset - */ -SZ_DYNAMIC sz_cptr_t sz_rfind_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); - -/** - * @brief Finds the last character in the haystack, that is @b not present in the needle. - * Convenience function, reused across different language bindings. - * @see sz_find_charset - */ -SZ_DYNAMIC sz_cptr_t sz_rfind_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); - -#pragma endregion - -#pragma region String Sequences API - -struct sz_sequence_t; - -typedef sz_cptr_t (*sz_sequence_member_start_t)(struct sz_sequence_t const *, sz_size_t); -typedef sz_size_t (*sz_sequence_member_length_t)(struct sz_sequence_t const *, sz_size_t); -typedef sz_bool_t (*sz_sequence_predicate_t)(struct sz_sequence_t const *, sz_size_t); -typedef sz_bool_t (*sz_sequence_comparator_t)(struct sz_sequence_t const *, sz_size_t, sz_size_t); -typedef sz_bool_t (*sz_string_is_less_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); - -typedef struct sz_sequence_t { - sz_sorted_idx_t *order; - sz_size_t count; - sz_sequence_member_start_t get_start; - sz_sequence_member_length_t get_length; - void const *handle; -} sz_sequence_t; - -/** - * @brief Initiates the sequence structure from a tape layout, used by Apache Arrow. - * Expects ::offsets to contains `count + 1` entries, the last pointing at the end - * of the last string, indicating the total length of the ::tape. - */ -SZ_PUBLIC void sz_sequence_from_u32tape(sz_cptr_t *start, sz_u32_t const *offsets, sz_size_t count, - sz_sequence_t *sequence); - -/** - * @brief Initiates the sequence structure from a tape layout, used by Apache Arrow. - * Expects ::offsets to contains `count + 1` entries, the last pointing at the end - * of the last string, indicating the total length of the ::tape. - */ -SZ_PUBLIC void sz_sequence_from_u64tape(sz_cptr_t *start, sz_u64_t const *offsets, sz_size_t count, - sz_sequence_t *sequence); - -/** - * @brief Similar to `std::partition`, given a predicate splits the sequence into two parts. - * The algorithm is unstable, meaning that elements may change relative order, as long - * as they are in the right partition. This is the simpler algorithm for partitioning. 
- */ -SZ_PUBLIC sz_size_t sz_partition(sz_sequence_t *sequence, sz_sequence_predicate_t predicate); - -/** - * @brief Inplace `std::set_union` for two consecutive chunks forming the same continuous `sequence`. - * - * @param partition The number of elements in the first sub-sequence in `sequence`. - * @param less Comparison function, to determine the lexicographic ordering. - */ -SZ_PUBLIC void sz_merge(sz_sequence_t *sequence, sz_size_t partition, sz_sequence_comparator_t less); - -/** - * @brief Sorting algorithm, combining Radix Sort for the first 32 bits of every word - * and a follow-up by a more conventional sorting procedure on equally prefixed parts. - */ -SZ_PUBLIC void sz_sort(sz_sequence_t *sequence); - -/** - * @brief Partial sorting algorithm, combining Radix Sort for the first 32 bits of every word - * and a follow-up by a more conventional sorting procedure on equally prefixed parts. - */ -SZ_PUBLIC void sz_sort_partial(sz_sequence_t *sequence, sz_size_t n); - -/** - * @brief Intro-Sort algorithm that supports custom comparators. - */ -SZ_PUBLIC void sz_sort_intro(sz_sequence_t *sequence, sz_sequence_comparator_t less); - -#pragma endregion - -/* - * Hardware feature detection. - * All of those can be controlled by the user. - */ -#ifndef SZ_USE_ICE -#ifdef __AVX512BW__ -#define SZ_USE_ICE 1 -#else -#define SZ_USE_ICE 0 -#endif -#endif - -#ifndef SZ_USE_HASWELL -#ifdef __AVX2__ -#define SZ_USE_HASWELL 1 -#else -#define SZ_USE_HASWELL 0 -#endif -#endif - -#ifndef SZ_USE_NEON -#ifdef __ARM_NEON -#define SZ_USE_NEON 1 -#else -#define SZ_USE_NEON 0 -#endif -#endif - -#ifndef SZ_USE_SVE -#ifdef __ARM_FEATURE_SVE -#define SZ_USE_SVE 1 -#else -#define SZ_USE_SVE 0 -#endif -#endif - -/* - * Include hardware-specific headers. - */ -#if SZ_USE_ICE || SZ_USE_HASWELL -#include -#endif // SZ_USE_X86... 
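To illustrate how the sequence interface above is meant to be wired up, here is a hedged sketch that sorts a small array of NUL-terminated strings. The accessor helpers, the sample data, the include path, and the choice to pre-fill `order` with the identity permutation are assumptions of this sketch rather than requirements spelled out by the header at this point.

#include <stdio.h>
#include <string.h>
#include <stringzilla/stringzilla.h>

// Expose a plain `char const *` array through the `sz_sequence_t` accessor callbacks.
static sz_cptr_t get_start(sz_sequence_t const *sequence, sz_size_t i) {
    return ((char const *const *)sequence->handle)[i];
}
static sz_size_t get_length(sz_sequence_t const *sequence, sz_size_t i) {
    return strlen(((char const *const *)sequence->handle)[i]);
}

int main(void) {
    char const *strings[] = {"banana", "apple", "cherry", "apricot"};
    sz_sorted_idx_t order[4];
    for (sz_size_t i = 0; i != 4; ++i) order[i] = (sz_sorted_idx_t)i; // start from the identity permutation

    sz_sequence_t sequence;
    sequence.order = order;
    sequence.count = 4;
    sequence.get_start = get_start;
    sequence.get_length = get_length;
    sequence.handle = strings;

    sz_sort(&sequence); // reorders `order`, so it enumerates the strings lexicographically
    for (sz_size_t i = 0; i != 4; ++i) printf("%s\n", strings[order[i]]);
    return 0;
}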
-#if SZ_USE_NEON -#if !defined(_MSC_VER) -#include -#endif -#include -#endif // SZ_USE_NEON -#if SZ_USE_SVE -#if !defined(_MSC_VER) -#include -#endif -#endif // SZ_USE_SVE - -#pragma region Hardware Specific API - -#if SZ_USE_ICE - -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_skylake(sz_cptr_t a, sz_cptr_t b, sz_size_t length); -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_avx512(sz_ptr_t target, sz_size_t length, sz_u8_t value); -/** @copydoc sz_look_up_transform */ -SZ_PUBLIC void sz_look_up_transform_ice(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target); -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_skylake(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_skylake(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_ice(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_ice(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_edit_distance */ -SZ_PUBLIC sz_size_t sz_edit_distance_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); -/** @copydoc sz_alignment_score */ -SZ_PUBLIC sz_ssize_t sz_alignment_score_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, // - sz_memory_allocator_t *alloc); -/** @copydoc sz_hashes */ -SZ_PUBLIC void sz_hashes_avx512(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle); -#endif - -#if SZ_USE_HASWELL -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_avx2(sz_cptr_t a, sz_cptr_t b, sz_size_t length); -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_avx2(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_avx2(sz_ptr_t target, sz_size_t length, sz_u8_t value); -/** @copydoc sz_look_up_transform */ -SZ_PUBLIC void sz_look_up_transform_avx2(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target); -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc 
sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_hashes */ -SZ_PUBLIC void sz_hashes_avx2(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle); -#endif - -#if SZ_USE_NEON -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length); -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_neon(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_neon(sz_ptr_t target, sz_size_t length, sz_u8_t value); -/** @copydoc sz_look_up_transform */ -SZ_PUBLIC void sz_look_up_transform_neon(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target); -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_neon(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -#endif - -#if SZ_USE_SVE -/** @copydoc sz_equal */ -SZ_PUBLIC sz_bool_t sz_equal_sve(sz_cptr_t a, sz_cptr_t b, sz_size_t length); -/** @copydoc sz_order */ -SZ_PUBLIC sz_ordering_t sz_order_sve(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); -/** @copydoc sz_copy */ -SZ_PUBLIC void sz_copy_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_move */ -SZ_PUBLIC void sz_move_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length); -/** @copydoc sz_fill */ -SZ_PUBLIC void sz_fill_sve(sz_ptr_t target, sz_size_t length, sz_u8_t value); -/** @copydoc sz_find_byte */ -SZ_PUBLIC sz_cptr_t sz_find_byte_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_rfind_byte */ -SZ_PUBLIC sz_cptr_t sz_rfind_byte_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); -/** @copydoc sz_find */ -SZ_PUBLIC sz_cptr_t sz_find_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_rfind */ -SZ_PUBLIC sz_cptr_t sz_rfind_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_sve(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_sve(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -#endif - -#pragma endregion - -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wconversion" - -/* - ********************************************************************************************************************** - 
********************************************************************************************************************** - ********************************************************************************************************************** - * - * This is where we the actual implementation begins. - * The rest of the file is hidden from the public API. - * - ********************************************************************************************************************** - ********************************************************************************************************************** - ********************************************************************************************************************** - */ - -#pragma region Compiler Extensions and Helper Functions - -#pragma GCC visibility push(hidden) - -/** - * @brief Helper-macro to mark potentially unused variables. - */ -#define sz_unused(x) ((void)(x)) - -/** - * @brief Helper-macro casting a variable to another type of the same size. - */ -#define sz_bitcast(type, value) (*((type *)&(value))) - -/** - * @brief Defines `SZ_NULL`, analogous to `NULL`. - * The default often comes from locale.h, stddef.h, - * stdio.h, stdlib.h, string.h, time.h, or wchar.h. - */ -#ifdef __GNUG__ -#define SZ_NULL __null -#define SZ_NULL_CHAR __null -#else -#define SZ_NULL ((void *)0) -#define SZ_NULL_CHAR ((char *)0) -#endif - -/** - * @brief Cache-line width, that will affect the execution of some algorithms, - * like equality checks and relative order computing. - */ -#define SZ_CACHE_LINE_WIDTH (64) // bytes - -/** - * @brief Similar to `assert`, the `sz_assert` is used in the SZ_DEBUG mode - * to check the invariants of the library. It's a no-op in the SZ_RELEASE mode. - * @note If you want to catch it, put a breakpoint at @b `__GI_exit` - */ -#if SZ_DEBUG && defined(SZ_AVOID_LIBC) && !SZ_AVOID_LIBC && !defined(SZ_PIC) -#include // `fprintf` -#include // `EXIT_FAILURE` -SZ_PUBLIC void _sz_assert_failure(char const *condition, char const *file, int line) { - fprintf(stderr, "Assertion failed: %s, in file %s, line %d\n", condition, file, line); - exit(EXIT_FAILURE); -} -#define sz_assert(condition) \ - do { \ - if (!(condition)) { _sz_assert_failure(#condition, __FILE__, __LINE__); } \ - } while (0) -#else -#define sz_assert(condition) ((void)(condition)) -#endif - -/* Intrinsics aliases for MSVC, GCC, Clang, and Clang-Cl. - * The following section of compiler intrinsics comes in 2 flavors. - */ -#if defined(_MSC_VER) && !defined(__clang__) // On Clang-CL -#include - -// Sadly, when building Win32 images, we can't use the `_tzcnt_u64`, `_lzcnt_u64`, -// `_BitScanForward64`, or `_BitScanReverse64` intrinsics. For now it's a simple `for`-loop. -// TODO: In the future we can switch to a more efficient De Bruijn's algorithm. -// https://www.chessprogramming.org/BitScan -// https://www.chessprogramming.org/De_Bruijn_Sequence -// https://gist.github.com/resilar/e722d4600dbec9752771ab4c9d47044f -// -// Use the serial version on 32-bit x86 and on Arm. 
-#if (defined(_WIN32) && !defined(_WIN64)) || defined(_M_ARM) || defined(_M_ARM64) -SZ_INTERNAL int sz_u64_ctz(sz_u64_t x) { - sz_assert(x != 0); - int n = 0; - while ((x & 1) == 0) { n++, x >>= 1; } - return n; -} -SZ_INTERNAL int sz_u64_clz(sz_u64_t x) { - sz_assert(x != 0); - int n = 0; - while ((x & 0x8000000000000000ull) == 0) { n++, x <<= 1; } - return n; -} -SZ_INTERNAL int sz_u64_popcount(sz_u64_t x) { - x = x - ((x >> 1) & 0x5555555555555555ull); - x = (x & 0x3333333333333333ull) + ((x >> 2) & 0x3333333333333333ull); - return (((x + (x >> 4)) & 0x0F0F0F0F0F0F0F0Full) * 0x0101010101010101ull) >> 56; -} -SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) { - sz_assert(x != 0); - int n = 0; - while ((x & 1) == 0) { n++, x >>= 1; } - return n; -} -SZ_INTERNAL int sz_u32_clz(sz_u32_t x) { - sz_assert(x != 0); - int n = 0; - while ((x & 0x80000000u) == 0) { n++, x <<= 1; } - return n; -} -SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) { - x = x - ((x >> 1) & 0x55555555); - x = (x & 0x33333333) + ((x >> 2) & 0x33333333); - return (((x + (x >> 4)) & 0x0F0F0F0F) * 0x01010101) >> 24; -} -#else -SZ_INTERNAL int sz_u64_ctz(sz_u64_t x) { return (int)_tzcnt_u64(x); } -SZ_INTERNAL int sz_u64_clz(sz_u64_t x) { return (int)_lzcnt_u64(x); } -SZ_INTERNAL int sz_u64_popcount(sz_u64_t x) { return (int)__popcnt64(x); } -SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) { return (int)_tzcnt_u32(x); } -SZ_INTERNAL int sz_u32_clz(sz_u32_t x) { return (int)_lzcnt_u32(x); } -SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) { return (int)__popcnt(x); } -#endif -// Force the byteswap functions to be intrinsics, because when /Oi- is given, these will turn into CRT function calls, -// which breaks when `SZ_AVOID_LIBC` is given -#pragma intrinsic(_byteswap_uint64) -SZ_INTERNAL sz_u64_t sz_u64_bytes_reverse(sz_u64_t val) { return _byteswap_uint64(val); } -#pragma intrinsic(_byteswap_ulong) -SZ_INTERNAL sz_u32_t sz_u32_bytes_reverse(sz_u32_t val) { return _byteswap_ulong(val); } -#else -SZ_INTERNAL int sz_u64_popcount(sz_u64_t x) { return __builtin_popcountll(x); } -SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) { return __builtin_popcount(x); } -SZ_INTERNAL int sz_u64_ctz(sz_u64_t x) { return __builtin_ctzll(x); } -SZ_INTERNAL int sz_u64_clz(sz_u64_t x) { return __builtin_clzll(x); } -SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) { return __builtin_ctz(x); } // ! Undefined if `x == 0` -SZ_INTERNAL int sz_u32_clz(sz_u32_t x) { return __builtin_clz(x); } // ! Undefined if `x == 0` -SZ_INTERNAL sz_u64_t sz_u64_bytes_reverse(sz_u64_t val) { return __builtin_bswap64(val); } -SZ_INTERNAL sz_u32_t sz_u32_bytes_reverse(sz_u32_t val) { return __builtin_bswap32(val); } -#endif - -SZ_INTERNAL sz_u64_t sz_u64_rotl(sz_u64_t x, sz_u64_t r) { return (x << r) | (x >> (64 - r)); } - -/** - * @brief Select bits from either ::a or ::b depending on the value of ::mask bits. - * - * Similar to `_mm_blend_epi16` intrinsic on x86. - * Described in the "Bit Twiddling Hacks" by Sean Eron Anderson. - * https://graphics.stanford.edu/~seander/bithacks.html#ConditionalSetOrClearBitsWithoutBranching - */ -SZ_INTERNAL sz_u64_t sz_u64_blend(sz_u64_t a, sz_u64_t b, sz_u64_t mask) { return a ^ ((a ^ b) & mask); } - -/* - * Efficiently computing the minimum and maximum of two or three values can be tricky. - * The simple branching baseline would be: - * - * x < y ? x : y // can replace with 1 conditional move - * - * Branchless approach is well known for signed integers, but it doesn't apply to unsigned ones. 
- * https://stackoverflow.com/questions/514435/templatized-branchless-int-max-min-function - * https://graphics.stanford.edu/~seander/bithacks.html#IntegerMinOrMax - * Using only bit-shifts for singed integers it would be: - * - * y + ((x - y) & (x - y) >> 31) // 4 unique operations - * - * Alternatively, for any integers using multiplication: - * - * (x > y) * y + (x <= y) * x // 5 operations - * - * Alternatively, to avoid multiplication: - * - * x & ~((x < y) - 1) + y & ((x < y) - 1) // 6 unique operations - */ -#define sz_min_of_two(x, y) (x < y ? x : y) -#define sz_max_of_two(x, y) (x < y ? y : x) -#define sz_min_of_three(x, y, z) sz_min_of_two(x, sz_min_of_two(y, z)) -#define sz_max_of_three(x, y, z) sz_max_of_two(x, sz_max_of_two(y, z)) - -/** @brief Branchless minimum function for two signed 32-bit integers. */ -SZ_INTERNAL sz_i32_t sz_i32_min_of_two(sz_i32_t x, sz_i32_t y) { return y + ((x - y) & (x - y) >> 31); } - -/** @brief Branchless minimum function for two signed 32-bit integers. */ -SZ_INTERNAL sz_i32_t sz_i32_max_of_two(sz_i32_t x, sz_i32_t y) { return x - ((x - y) & (x - y) >> 31); } - -/** - * @brief Clamps signed offsets in a string to a valid range. Used for Pythonic-style slicing. - */ -SZ_INTERNAL void sz_ssize_clamp_interval(sz_size_t length, sz_ssize_t start, sz_ssize_t end, - sz_size_t *normalized_offset, sz_size_t *normalized_length) { - // TODO: Remove branches. - // Normalize negative indices - if (start < 0) start += length; - if (end < 0) end += length; - - // Clamp indices to a valid range - if (start < 0) start = 0; - if (end < 0) end = 0; - if (start > (sz_ssize_t)length) start = length; - if (end > (sz_ssize_t)length) end = length; - - // Ensure start <= end - if (start > end) start = end; - - *normalized_offset = start; - *normalized_length = end - start; -} - -/** - * @brief Compute the logarithm base 2 of a positive integer, rounding down. - */ -SZ_INTERNAL sz_size_t sz_size_log2i_nonzero(sz_size_t x) { - sz_assert(x > 0 && "Non-positive numbers have no defined logarithm"); - sz_size_t leading_zeros = sz_u64_clz(x); - return 63 - leading_zeros; -} - -/** - * @brief Compute the smallest power of two greater than or equal to ::x. - */ -SZ_INTERNAL sz_size_t sz_size_bit_ceil(sz_size_t x) { - // Unlike the commonly used trick with `clz` intrinsics, is valid across the whole range of `x`. - // https://stackoverflow.com/a/10143264 - x--; - x |= x >> 1; - x |= x >> 2; - x |= x >> 4; - x |= x >> 8; - x |= x >> 16; -#if _SZ_IS_64_BIT - x |= x >> 32; -#endif - x++; - return x; -} - -/** - * @brief Transposes an 8x8 bit matrix packed in a `sz_u64_t`. - * - * There is a well known SWAR sequence for that known to chess programmers, - * willing to flip a bit-matrix of pieces along the main A1-H8 diagonal. - * https://www.chessprogramming.org/Flipping_Mirroring_and_Rotating - * https://lukas-prokop.at/articles/2021-07-23-transpose - */ -SZ_INTERNAL sz_u64_t sz_u64_transpose(sz_u64_t x) { - sz_u64_t t; - t = x ^ (x << 36); - x ^= 0xf0f0f0f00f0f0f0full & (t ^ (x >> 36)); - t = 0xcccc0000cccc0000ull & (x ^ (x << 18)); - x ^= t ^ (t >> 18); - t = 0xaa00aa00aa00aa00ull & (x ^ (x << 9)); - x ^= t ^ (t >> 9); - return x; -} - -/** - * @brief Helper, that swaps two 64-bit integers representing the order of elements in the sequence. - */ -SZ_INTERNAL void sz_u64_swap(sz_u64_t *a, sz_u64_t *b) { - sz_u64_t t = *a; - *a = *b; - *b = t; -} - -/** - * @brief Helper, that swaps two 64-bit integers representing the order of elements in the sequence. 
- */ -SZ_INTERNAL void sz_pointer_swap(void **a, void **b) { - void *t = *a; - *a = *b; - *b = t; -} - -/** - * @brief Helper structure to simplify work with 16-bit words. - * @see sz_u16_load - */ -typedef union sz_u16_vec_t { - sz_u16_t u16; - sz_u8_t u8s[2]; -} sz_u16_vec_t; - -/** - * @brief Load a 16-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms. - */ -SZ_INTERNAL sz_u16_vec_t sz_u16_load(sz_cptr_t ptr) { -#if !SZ_USE_MISALIGNED_LOADS - sz_u16_vec_t result; - result.u8s[0] = ptr[0]; - result.u8s[1] = ptr[1]; - return result; -#elif defined(_MSC_VER) && !defined(__clang__) -#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. - return *((sz_u16_vec_t *)ptr); -#else - return *((__unaligned sz_u16_vec_t *)ptr); -#endif -#else - __attribute__((aligned(1))) sz_u16_vec_t const *result = (sz_u16_vec_t const *)ptr; - return *result; -#endif -} - -/** - * @brief Helper structure to simplify work with 32-bit words. - * @see sz_u32_load - */ -typedef union sz_u32_vec_t { - sz_u32_t u32; - sz_u16_t u16s[2]; - sz_u8_t u8s[4]; -} sz_u32_vec_t; - -/** - * @brief Load a 32-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms. - */ -SZ_INTERNAL sz_u32_vec_t sz_u32_load(sz_cptr_t ptr) { -#if !SZ_USE_MISALIGNED_LOADS - sz_u32_vec_t result; - result.u8s[0] = ptr[0]; - result.u8s[1] = ptr[1]; - result.u8s[2] = ptr[2]; - result.u8s[3] = ptr[3]; - return result; -#elif defined(_MSC_VER) && !defined(__clang__) -#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. - return *((sz_u32_vec_t *)ptr); -#else - return *((__unaligned sz_u32_vec_t *)ptr); -#endif -#else - __attribute__((aligned(1))) sz_u32_vec_t const *result = (sz_u32_vec_t const *)ptr; - return *result; -#endif -} - -/** - * @brief Helper structure to simplify work with 64-bit words. - * @see sz_u64_load - */ -typedef union sz_u64_vec_t { - sz_u64_t u64; - sz_u32_t u32s[2]; - sz_u16_t u16s[4]; - sz_u8_t u8s[8]; -} sz_u64_vec_t; - -/** - * @brief Load a 64-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms. - */ -SZ_INTERNAL sz_u64_vec_t sz_u64_load(sz_cptr_t ptr) { -#if !SZ_USE_MISALIGNED_LOADS - sz_u64_vec_t result; - result.u8s[0] = ptr[0]; - result.u8s[1] = ptr[1]; - result.u8s[2] = ptr[2]; - result.u8s[3] = ptr[3]; - result.u8s[4] = ptr[4]; - result.u8s[5] = ptr[5]; - result.u8s[6] = ptr[6]; - result.u8s[7] = ptr[7]; - return result; -#elif defined(_MSC_VER) && !defined(__clang__) -#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. - return *((sz_u64_vec_t *)ptr); -#else - return *((__unaligned sz_u64_vec_t *)ptr); -#endif -#else - __attribute__((aligned(1))) sz_u64_vec_t const *result = (sz_u64_vec_t const *)ptr; - return *result; -#endif -} - -/** @brief Helper function, using the supplied fixed-capacity buffer to allocate memory. */ -SZ_INTERNAL sz_ptr_t _sz_memory_allocate_fixed(sz_size_t length, void *handle) { - sz_size_t capacity; - sz_copy((sz_ptr_t)&capacity, (sz_cptr_t)handle, sizeof(sz_size_t)); - sz_size_t consumed_capacity = sizeof(sz_size_t); - if (consumed_capacity + length > capacity) return SZ_NULL_CHAR; - return (sz_ptr_t)handle + consumed_capacity; -} - -/** @brief Helper "no-op" function, simulating memory deallocation when we use a "static" memory buffer. 
*/ -SZ_INTERNAL void _sz_memory_free_fixed(sz_ptr_t start, sz_size_t length, void *handle) { - sz_unused(start && length && handle); -} - -/** @brief An internal callback used to set a bit in a power-of-two length binary fingerprint of a string. */ -SZ_INTERNAL void _sz_hashes_fingerprint_pow2_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, void *handle) { - sz_string_view_t *fingerprint_buffer = (sz_string_view_t *)handle; - sz_u8_t *fingerprint_u8s = (sz_u8_t *)fingerprint_buffer->start; - sz_size_t fingerprint_bytes = fingerprint_buffer->length; - fingerprint_u8s[(hash / 8) & (fingerprint_bytes - 1)] |= (1 << (hash & 7)); - sz_unused(start && length); -} - -/** @brief An internal callback used to set a bit in a @b non power-of-two length binary fingerprint of a string. */ -SZ_INTERNAL void _sz_hashes_fingerprint_non_pow2_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, - void *handle) { - sz_string_view_t *fingerprint_buffer = (sz_string_view_t *)handle; - sz_u8_t *fingerprint_u8s = (sz_u8_t *)fingerprint_buffer->start; - sz_size_t fingerprint_bytes = fingerprint_buffer->length; - fingerprint_u8s[(hash / 8) % fingerprint_bytes] |= (1 << (hash & 7)); - sz_unused(start && length); -} - -/** @brief An internal callback, used to mix all the running hashes into one pointer-size value. */ -SZ_INTERNAL void _sz_hashes_fingerprint_scalar_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, - void *scalar_handle) { - sz_unused(start && length && hash && scalar_handle); - sz_size_t *scalar_ptr = (sz_size_t *)scalar_handle; - *scalar_ptr ^= hash; -} - -#pragma GCC visibility pop -#pragma endregion - -#pragma region Serial Implementation - -#if !SZ_AVOID_LIBC -#include // `fprintf` -#include // `malloc`, `EXIT_FAILURE` - -SZ_PUBLIC void *_sz_memory_allocate_default(sz_size_t length, void *handle) { - sz_unused(handle); - return malloc(length); -} -SZ_PUBLIC void _sz_memory_free_default(sz_ptr_t start, sz_size_t length, void *handle) { - sz_unused(handle && length); - free(start); -} - -#endif - -SZ_PUBLIC void sz_memory_allocator_init_default(sz_memory_allocator_t *alloc) { -#if !SZ_AVOID_LIBC - alloc->allocate = (sz_memory_allocate_t)_sz_memory_allocate_default; - alloc->free = (sz_memory_free_t)_sz_memory_free_default; -#else - alloc->allocate = (sz_memory_allocate_t)SZ_NULL; - alloc->free = (sz_memory_free_t)SZ_NULL; -#endif - alloc->handle = SZ_NULL; -} - -SZ_PUBLIC void sz_memory_allocator_init_fixed(sz_memory_allocator_t *alloc, void *buffer, sz_size_t length) { - // The logic here is simple - put the buffer length in the first slots of the buffer. - // Later use it for bounds checking. 
- alloc->allocate = (sz_memory_allocate_t)_sz_memory_allocate_fixed; - alloc->free = (sz_memory_free_t)_sz_memory_free_fixed; - alloc->handle = &buffer; - sz_copy((sz_ptr_t)buffer, (sz_cptr_t)&length, sizeof(sz_size_t)); -} - -SZ_PUBLIC sz_cptr_t sz_find_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { - for (sz_cptr_t const end = text + length; text != end; ++text) - if (sz_charset_contains(set, *text)) return text; - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Warray-bounds" - sz_cptr_t const end = text; - for (text += length; text != end;) - if (sz_charset_contains(set, *(text -= 1))) return text; - return SZ_NULL_CHAR; -#pragma GCC diagnostic pop -} - -/** - * One option to avoid branching is to use conditional moves and lookup the comparison result in a table: - * sz_ordering_t ordering_lookup[2] = {sz_greater_k, sz_less_k}; - * for (; a != min_end; ++a, ++b) - * if (*a != *b) return ordering_lookup[*a < *b]; - * That, however, introduces a data-dependency. - * A cleaner option is to perform two comparisons and a subtraction. - * One instruction more, but no data-dependency. - */ -#define _sz_order_scalars(a, b) ((sz_ordering_t)((a > b) - (a < b))) - -SZ_PUBLIC sz_ordering_t sz_order_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { - sz_bool_t a_shorter = (sz_bool_t)(a_length < b_length); - sz_size_t min_length = a_shorter ? a_length : b_length; - sz_cptr_t min_end = a + min_length; -#if SZ_USE_MISALIGNED_LOADS && !_SZ_IS_BIG_ENDIAN - for (sz_u64_vec_t a_vec, b_vec; a + 8 <= min_end; a += 8, b += 8) { - a_vec = sz_u64_load(a); - b_vec = sz_u64_load(b); - if (a_vec.u64 != b_vec.u64) - return _sz_order_scalars(sz_u64_bytes_reverse(a_vec.u64), sz_u64_bytes_reverse(b_vec.u64)); - } -#endif - for (; a != min_end; ++a, ++b) - if (*a != *b) return _sz_order_scalars(*a, *b); - - // If the strings are equal up to `min_end`, then the shorter string is smaller - return _sz_order_scalars(a_length, b_length); -} - -/** - * @brief Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each byte is set. - // For that take the bottom 7 bits of each byte, add one to them, - // and if this sets the top bit to one, then all the 7 bits are ones as well. - vec.u64 = ((vec.u64 & 0x7F7F7F7F7F7F7F7Full) + 0x0101010101010101ull) & ((vec.u64 & 0x8080808080808080ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b single-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. - * Identical to `memchr(haystack, needle[0], haystack_length)`. - */ -SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - if (!h_length) return SZ_NULL_CHAR; - sz_cptr_t const h_end = h + h_length; - -#if !_SZ_IS_BIG_ENDIAN // Use SWAR only on little-endian platforms for brevity. -#if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to void UB on unaligned 64-bit loads. 
- for (; ((sz_size_t)h & 7ull) && h < h_end; ++h) - if (*h == *n) return h; -#endif - - // Broadcast the n into every byte of a 64-bit integer to use SWAR - // techniques and process eight characters at a time. - sz_u64_vec_t h_vec, n_vec, match_vec; - match_vec.u64 = 0; - n_vec.u64 = (sz_u64_t)n[0] * 0x0101010101010101ull; - for (; h + 8 <= h_end; h += 8) { - h_vec.u64 = *(sz_u64_t const *)h; - match_vec = _sz_u64_each_byte_equal(h_vec, n_vec); - if (match_vec.u64) return h + sz_u64_ctz(match_vec.u64) / 8; - } -#endif - - // Handle the misaligned tail. - for (; h < h_end; ++h) - if (*h == *n) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief Find the last occurrence of a @b single-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. - * Identical to `memrchr(haystack, needle[0], haystack_length)`. - */ -sz_cptr_t sz_rfind_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - if (!h_length) return SZ_NULL_CHAR; - sz_cptr_t const h_start = h; - - // Reposition the `h` pointer to the end, as we will be walking backwards. - h = h + h_length - 1; - -#if !_SZ_IS_BIG_ENDIAN // Use SWAR only on little-endian platforms for brevity. -#if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)(h + 1) & 7ull) && h >= h_start; --h) - if (*h == *n) return h; -#endif - - // Broadcast the n into every byte of a 64-bit integer to use SWAR - // techniques and process eight characters at a time. - sz_u64_vec_t h_vec, n_vec, match_vec; - n_vec.u64 = (sz_u64_t)n[0] * 0x0101010101010101ull; - for (; h >= h_start + 7; h -= 8) { - h_vec.u64 = *(sz_u64_t const *)(h - 7); - match_vec = _sz_u64_each_byte_equal(h_vec, n_vec); - if (match_vec.u64) return h - sz_u64_clz(match_vec.u64) / 8; - } -#endif - - for (; h >= h_start; --h) - if (*h == *n) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief 2Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each 2byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_2byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each 2byte is set. - // For that take the bottom 15 bits of each 2byte, add one to them, - // and if this sets the top bit to one, then all the 15 bits are ones as well. - vec.u64 = ((vec.u64 & 0x7FFF7FFF7FFF7FFFull) + 0x0001000100010001ull) & ((vec.u64 & 0x8000800080008000ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b two-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. - */ -SZ_INTERNAL sz_cptr_t _sz_find_2byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - // This is an internal method, and the haystack is guaranteed to be at least 2 bytes long. - sz_assert(h_length >= 2 && "The haystack is too short."); - sz_cptr_t const h_end = h + h_length; - -#if !SZ_USE_MISALIGNED_LOADS - // Process the misaligned head, to void UB on unaligned 64-bit loads. 
- for (; ((sz_size_t)h & 7ull) && h + 2 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) == 2) return h; -#endif - - sz_u64_vec_t h_even_vec, h_odd_vec, n_vec, matches_even_vec, matches_odd_vec; - n_vec.u64 = 0; - n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1]; - n_vec.u64 *= 0x0001000100010001ull; // broadcast - - // This code simulates hyper-scalar execution, analyzing 8 offsets at a time. - for (; h + 9 <= h_end; h += 8) { - h_even_vec.u64 = *(sz_u64_t *)h; - h_odd_vec.u64 = (h_even_vec.u64 >> 8) | ((sz_u64_t)h[8] << 56); - matches_even_vec = _sz_u64_each_2byte_equal(h_even_vec, n_vec); - matches_odd_vec = _sz_u64_each_2byte_equal(h_odd_vec, n_vec); - - matches_even_vec.u64 >>= 8; - if (matches_even_vec.u64 + matches_odd_vec.u64) { - sz_u64_t match_indicators = matches_even_vec.u64 | matches_odd_vec.u64; - return h + sz_u64_ctz(match_indicators) / 8; - } - } - - for (; h + 2 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) == 2) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief 4Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each 4byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_4byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each 4byte is set. - // For that take the bottom 31 bits of each 4byte, add one to them, - // and if this sets the top bit to one, then all the 31 bits are ones as well. - vec.u64 = ((vec.u64 & 0x7FFFFFFF7FFFFFFFull) + 0x0000000100000001ull) & ((vec.u64 & 0x8000000080000000ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b four-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. - */ -SZ_INTERNAL sz_cptr_t _sz_find_4byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - // This is an internal method, and the haystack is guaranteed to be at least 4 bytes long. - sz_assert(h_length >= 4 && "The haystack is too short."); - sz_cptr_t const h_end = h + h_length; - -#if !SZ_USE_MISALIGNED_LOADS - // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)h & 7ull) && h + 4 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) + (h[3] == n[3]) == 4) return h; -#endif - - sz_u64_vec_t h0_vec, h1_vec, h2_vec, h3_vec, n_vec, matches0_vec, matches1_vec, matches2_vec, matches3_vec; - n_vec.u64 = 0; - n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1], n_vec.u8s[2] = n[2], n_vec.u8s[3] = n[3]; - n_vec.u64 *= 0x0000000100000001ull; // broadcast - - // This code simulates hyper-scalar execution, analyzing 8 offsets at a time using four 64-bit words. - // We load the subsequent four-byte word as well, taking its first bytes. 
Think of it as a glorified prefetch :) - sz_u64_t h_page_current, h_page_next; - for (; h + sizeof(sz_u64_t) + sizeof(sz_u32_t) <= h_end; h += sizeof(sz_u64_t)) { - h_page_current = *(sz_u64_t *)h; - h_page_next = *(sz_u32_t *)(h + 8); - h0_vec.u64 = (h_page_current); - h1_vec.u64 = (h_page_current >> 8) | (h_page_next << 56); - h2_vec.u64 = (h_page_current >> 16) | (h_page_next << 48); - h3_vec.u64 = (h_page_current >> 24) | (h_page_next << 40); - matches0_vec = _sz_u64_each_4byte_equal(h0_vec, n_vec); - matches1_vec = _sz_u64_each_4byte_equal(h1_vec, n_vec); - matches2_vec = _sz_u64_each_4byte_equal(h2_vec, n_vec); - matches3_vec = _sz_u64_each_4byte_equal(h3_vec, n_vec); - - if (matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64) { - matches0_vec.u64 >>= 24; - matches1_vec.u64 >>= 16; - matches2_vec.u64 >>= 8; - sz_u64_t match_indicators = matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64; - return h + sz_u64_ctz(match_indicators) / 8; - } - } - - for (; h + 4 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) + (h[3] == n[3]) == 4) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief 3Byte-level equality comparison between two 64-bit integers. - * @return 64-bit integer, where every top bit in each 3byte signifies a match. - */ -SZ_INTERNAL sz_u64_vec_t _sz_u64_each_3byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { - sz_u64_vec_t vec; - vec.u64 = ~(a.u64 ^ b.u64); - // The match is valid, if every bit within each 4byte is set. - // For that take the bottom 31 bits of each 4byte, add one to them, - // and if this sets the top bit to one, then all the 31 bits are ones as well. - vec.u64 = ((vec.u64 & 0xFFFF7FFFFF7FFFFFull) + 0x0000000001000001ull) & ((vec.u64 & 0x0000800000800000ull)); - return vec; -} - -/** - * @brief Find the first occurrence of a @b three-character needle in an arbitrary length haystack. - * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. - */ -SZ_INTERNAL sz_cptr_t _sz_find_3byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - - // This is an internal method, and the haystack is guaranteed to be at least 4 bytes long. - sz_assert(h_length >= 3 && "The haystack is too short."); - sz_cptr_t const h_end = h + h_length; - -#if !SZ_USE_MISALIGNED_LOADS - // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)h & 7ull) && h + 3 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) == 3) return h; -#endif - - // We fetch 12 - sz_u64_vec_t h0_vec, h1_vec, h2_vec, h3_vec, h4_vec; - sz_u64_vec_t matches0_vec, matches1_vec, matches2_vec, matches3_vec, matches4_vec; - sz_u64_vec_t n_vec; - n_vec.u64 = 0; - n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1], n_vec.u8s[2] = n[2]; - n_vec.u64 *= 0x0000000001000001ull; // broadcast - - // This code simulates hyper-scalar execution, analyzing 8 offsets at a time using three 64-bit words. - // We load the subsequent two-byte word as well. 
- sz_u64_t h_page_current, h_page_next; - for (; h + sizeof(sz_u64_t) + sizeof(sz_u16_t) <= h_end; h += sizeof(sz_u64_t)) { - h_page_current = *(sz_u64_t *)h; - h_page_next = *(sz_u16_t *)(h + 8); - h0_vec.u64 = (h_page_current); - h1_vec.u64 = (h_page_current >> 8) | (h_page_next << 56); - h2_vec.u64 = (h_page_current >> 16) | (h_page_next << 48); - h3_vec.u64 = (h_page_current >> 24) | (h_page_next << 40); - h4_vec.u64 = (h_page_current >> 32) | (h_page_next << 32); - matches0_vec = _sz_u64_each_3byte_equal(h0_vec, n_vec); - matches1_vec = _sz_u64_each_3byte_equal(h1_vec, n_vec); - matches2_vec = _sz_u64_each_3byte_equal(h2_vec, n_vec); - matches3_vec = _sz_u64_each_3byte_equal(h3_vec, n_vec); - matches4_vec = _sz_u64_each_3byte_equal(h4_vec, n_vec); - - if (matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64 | matches4_vec.u64) { - matches0_vec.u64 >>= 16; - matches1_vec.u64 >>= 8; - matches3_vec.u64 <<= 8; - matches4_vec.u64 <<= 16; - sz_u64_t match_indicators = - matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64 | matches4_vec.u64; - return h + sz_u64_ctz(match_indicators) / 8; - } - } - - for (; h + 3 <= h_end; ++h) - if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) == 3) return h; - return SZ_NULL_CHAR; -} - -/** - * @brief Boyer-Moore-Horspool algorithm for exact matching of patterns up to @b 256-bytes long. - * Uses the Raita heuristic to match the first two, the last, and the middle character of the pattern. - */ -SZ_INTERNAL sz_cptr_t _sz_find_horspool_upto_256bytes_serial(sz_cptr_t h_chars, sz_size_t h_length, // - sz_cptr_t n_chars, sz_size_t n_length) { - sz_assert(n_length <= 256 && "The pattern is too long."); - // Several popular string matching algorithms are using a bad-character shift table. - // Boyer Moore: https://www-igm.univ-mlv.fr/~lecroq/string/node14.html - // Quick Search: https://www-igm.univ-mlv.fr/~lecroq/string/node19.html - // Smith: https://www-igm.univ-mlv.fr/~lecroq/string/node21.html - union { - sz_u8_t jumps[256]; - sz_u64_vec_t vecs[64]; - } bad_shift_table; - - // Let's initialize the table using SWAR to the total length of the string. - sz_u8_t const *h = (sz_u8_t const *)h_chars; - sz_u8_t const *n = (sz_u8_t const *)n_chars; - { - sz_u64_vec_t n_length_vec; - n_length_vec.u64 = n_length; - n_length_vec.u64 *= 0x0101010101010101ull; // broadcast - for (sz_size_t i = 0; i != 64; ++i) bad_shift_table.vecs[i].u64 = n_length_vec.u64; - for (sz_size_t i = 0; i + 1 < n_length; ++i) bad_shift_table.jumps[n[i]] = (sz_u8_t)(n_length - i - 1); - } - - // Another common heuristic is to match a few characters from different parts of a string. - // Raita suggests to use the first two, the last, and the middle character of the pattern. - sz_u32_vec_t h_vec, n_vec; - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n_chars, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into an unsigned integer. - n_vec.u8s[0] = n[offset_first]; - n_vec.u8s[1] = n[offset_first + 1]; - n_vec.u8s[2] = n[offset_mid]; - n_vec.u8s[3] = n[offset_last]; - - // Scan through the whole haystack, skipping the last `n_length - 1` bytes. 
- for (sz_size_t i = 0; i <= h_length - n_length;) { - h_vec.u8s[0] = h[i + offset_first]; - h_vec.u8s[1] = h[i + offset_first + 1]; - h_vec.u8s[2] = h[i + offset_mid]; - h_vec.u8s[3] = h[i + offset_last]; - if (h_vec.u32 == n_vec.u32 && sz_equal((sz_cptr_t)h + i, n_chars, n_length)) return (sz_cptr_t)h + i; - i += bad_shift_table.jumps[h[i + n_length - 1]]; - } - return SZ_NULL_CHAR; -} - -/** - * @brief Boyer-Moore-Horspool algorithm for @b reverse-order exact matching of patterns up to @b 256-bytes long. - * Uses the Raita heuristic to match the first two, the last, and the middle character of the pattern. - */ -SZ_INTERNAL sz_cptr_t _sz_rfind_horspool_upto_256bytes_serial(sz_cptr_t h_chars, sz_size_t h_length, // - sz_cptr_t n_chars, sz_size_t n_length) { - sz_assert(n_length <= 256 && "The pattern is too long."); - union { - sz_u8_t jumps[256]; - sz_u64_vec_t vecs[64]; - } bad_shift_table; - - // Let's initialize the table using SWAR to the total length of the string. - sz_u8_t const *h = (sz_u8_t const *)h_chars; - sz_u8_t const *n = (sz_u8_t const *)n_chars; - { - sz_u64_vec_t n_length_vec; - n_length_vec.u64 = n_length; - n_length_vec.u64 *= 0x0101010101010101ull; // broadcast - for (sz_size_t i = 0; i != 64; ++i) bad_shift_table.vecs[i].u64 = n_length_vec.u64; - for (sz_size_t i = 0; i + 1 < n_length; ++i) - bad_shift_table.jumps[n[n_length - i - 1]] = (sz_u8_t)(n_length - i - 1); - } - - // Another common heuristic is to match a few characters from different parts of a string. - // Raita suggests to use the first two, the last, and the middle character of the pattern. - sz_u32_vec_t h_vec, n_vec; - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n_chars, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into an unsigned integer. - n_vec.u8s[0] = n[offset_first]; - n_vec.u8s[1] = n[offset_first + 1]; - n_vec.u8s[2] = n[offset_mid]; - n_vec.u8s[3] = n[offset_last]; - - // Scan through the whole haystack, skipping the first `n_length - 1` bytes. - for (sz_size_t j = 0; j <= h_length - n_length;) { - sz_size_t i = h_length - n_length - j; - h_vec.u8s[0] = h[i + offset_first]; - h_vec.u8s[1] = h[i + offset_first + 1]; - h_vec.u8s[2] = h[i + offset_mid]; - h_vec.u8s[3] = h[i + offset_last]; - if (h_vec.u32 == n_vec.u32 && sz_equal((sz_cptr_t)h + i, n_chars, n_length)) return (sz_cptr_t)h + i; - j += bad_shift_table.jumps[h[i]]; - } - return SZ_NULL_CHAR; -} - -/** - * @brief Exact substring search helper function, that finds the first occurrence of a prefix of the needle - * using a given search function, and then verifies the remaining part of the needle. - */ -SZ_INTERNAL sz_cptr_t _sz_find_with_prefix(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length, - sz_find_t find_prefix, sz_size_t prefix_length) { - - sz_size_t suffix_length = n_length - prefix_length; - while (1) { - sz_cptr_t found = find_prefix(h, h_length, n, prefix_length); - if (!found) return SZ_NULL_CHAR; - - // Verify the remaining part of the needle - sz_size_t remaining = h_length - (found - h); - if (remaining < n_length) return SZ_NULL_CHAR; - if (sz_equal(found + prefix_length, n + prefix_length, suffix_length)) return found; - - // Adjust the position. 
- h = found + 1; - h_length = remaining - 1; - } - - // Unreachable, but helps silence compiler warnings: - return SZ_NULL_CHAR; -} - -/** - * @brief Exact reverse-order substring search helper function, that finds the last occurrence of a suffix of the - * needle using a given search function, and then verifies the remaining part of the needle. - */ -SZ_INTERNAL sz_cptr_t _sz_rfind_with_suffix(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length, - sz_find_t find_suffix, sz_size_t suffix_length) { - - sz_size_t prefix_length = n_length - suffix_length; - while (1) { - sz_cptr_t found = find_suffix(h, h_length, n + prefix_length, suffix_length); - if (!found) return SZ_NULL_CHAR; - - // Verify the remaining part of the needle - sz_size_t remaining = found - h; - if (remaining < prefix_length) return SZ_NULL_CHAR; - if (sz_equal(found - prefix_length, n, prefix_length)) return found - prefix_length; - - // Adjust the position. - h_length = remaining - 1; - } - - // Unreachable, but helps silence compiler warnings: - return SZ_NULL_CHAR; -} - -SZ_INTERNAL sz_cptr_t _sz_find_over_4bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - return _sz_find_with_prefix(h, h_length, n, n_length, (sz_find_t)_sz_find_4byte_serial, 4); -} - -SZ_INTERNAL sz_cptr_t _sz_find_horspool_over_256bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, - sz_size_t n_length) { - return _sz_find_with_prefix(h, h_length, n, n_length, _sz_find_horspool_upto_256bytes_serial, 256); -} - -SZ_INTERNAL sz_cptr_t _sz_rfind_horspool_over_256bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, - sz_size_t n_length) { - return _sz_rfind_with_suffix(h, h_length, n, n_length, _sz_rfind_horspool_upto_256bytes_serial, 256); -} - -SZ_PUBLIC sz_cptr_t sz_find_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - -#if _SZ_IS_BIG_ENDIAN - sz_find_t backends[] = { - (sz_find_t)sz_find_byte_serial, - (sz_find_t)_sz_find_horspool_upto_256bytes_serial, - (sz_find_t)_sz_find_horspool_over_256bytes_serial, - }; - - return backends[(n_length > 1) + (n_length > 256)](h, h_length, n, n_length); -#else - sz_find_t backends[] = { - // For very short strings brute-force SWAR makes sense. - (sz_find_t)sz_find_byte_serial, - (sz_find_t)_sz_find_2byte_serial, - (sz_find_t)_sz_find_3byte_serial, - (sz_find_t)_sz_find_4byte_serial, - // To avoid constructing the skip-table, let's use the prefixed approach. - (sz_find_t)_sz_find_over_4bytes_serial, - // For longer needles - use skip tables. - (sz_find_t)_sz_find_horspool_upto_256bytes_serial, - (sz_find_t)_sz_find_horspool_over_256bytes_serial, - }; - - return backends[ - // For very short strings brute-force SWAR makes sense. - (n_length > 1) + (n_length > 2) + (n_length > 3) + - // To avoid constructing the skip-table, let's use the prefixed approach. - (n_length > 4) + - // For longer needles - use skip tables. - (n_length > 8) + (n_length > 256)](h, h_length, n, n_length); -#endif -} - -SZ_PUBLIC sz_cptr_t sz_rfind_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - - sz_find_t backends[] = { - // For very short strings brute-force SWAR makes sense. 
- (sz_find_t)sz_rfind_byte_serial, - // TODO: implement reverse-order SWAR for 2/3/4 byte variants. - // TODO: (sz_find_t)_sz_rfind_2byte_serial, - // TODO: (sz_find_t)_sz_rfind_3byte_serial, - // TODO: (sz_find_t)_sz_rfind_4byte_serial, - // To avoid constructing the skip-table, let's use the prefixed approach. - // (sz_find_t)_sz_rfind_over_4bytes_serial, - // For longer needles - use skip tables. - (sz_find_t)_sz_rfind_horspool_upto_256bytes_serial, - (sz_find_t)_sz_rfind_horspool_over_256bytes_serial, - }; - - return backends[ - // For very short strings brute-force SWAR makes sense. - 0 + - // To avoid constructing the skip-table, let's use the prefixed approach. - (n_length > 1) + - // For longer needles - use skip tables. - (n_length > 256)](h, h_length, n, n_length); -} - -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_serial( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - - // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. - sz_memory_allocator_t global_alloc; - if (!alloc) { - sz_memory_allocator_init_default(&global_alloc); - alloc = &global_alloc; - } - - // TODO: Generalize to remove the following asserts! - sz_assert(!bound && "For bounded search the method should only evaluate one band of the matrix."); - sz_assert(shorter_length == longer_length && "The method hasn't been generalized to different length inputs yet."); - sz_unused(longer_length && bound); - - // We are going to store 3 diagonals of the matrix. - // The length of the longest (main) diagonal would be `n = (shorter_length + 1)`. - sz_size_t n = shorter_length + 1; - sz_size_t buffer_length = sizeof(sz_size_t) * n * 3; - sz_size_t *distances = (sz_size_t *)alloc->allocate(buffer_length, alloc->handle); - if (!distances) return SZ_SIZE_MAX; - - sz_size_t *previous_distances = distances; - sz_size_t *current_distances = previous_distances + n; - sz_size_t *next_distances = previous_distances + n * 2; - - // Initialize the first two diagonals: - previous_distances[0] = 0; - current_distances[0] = current_distances[1] = 1; - - // Progress through the upper triangle of the Levenshtein matrix. - sz_size_t next_diagonal_index = 2; - for (; next_diagonal_index != n; ++next_diagonal_index) { - sz_size_t const next_diagonal_length = next_diagonal_index + 1; - for (sz_size_t i = 0; i + 2 < next_diagonal_length; ++i) { - sz_size_t cost_of_substitution = shorter[next_diagonal_index - i - 2] != longer[i]; - sz_size_t cost_if_substitution = previous_distances[i] + cost_of_substitution; - sz_size_t cost_if_deletion_or_insertion = sz_min_of_two(current_distances[i], current_distances[i + 1]) + 1; - next_distances[i + 1] = sz_min_of_two(cost_if_deletion_or_insertion, cost_if_substitution); - } - // Don't forget to populate the first row and the first column of the Levenshtein matrix. - next_distances[0] = next_distances[next_diagonal_length - 1] = next_diagonal_index; - // Perform a circular rotation of those buffers, to reuse the memory. - sz_size_t *temporary = previous_distances; - previous_distances = current_distances; - current_distances = next_distances; - next_distances = temporary; - } - - // By now we've scanned through the upper triangle of the matrix, where each subsequent iteration results in a - // larger diagonal. From now onwards, we will be shrinking. 
Instead of adding value equal to the skewed diagonal - // index on either side, we will be cropping those values out. - sz_size_t diagonals_count = n + n - 1; - for (; next_diagonal_index != diagonals_count; ++next_diagonal_index) { - sz_size_t const next_diagonal_length = diagonals_count - next_diagonal_index; - for (sz_size_t i = 0; i != next_diagonal_length; ++i) { - sz_size_t cost_of_substitution = shorter[shorter_length - 1 - i] != longer[next_diagonal_index - n + i]; - sz_size_t cost_if_substitution = previous_distances[i] + cost_of_substitution; - sz_size_t cost_if_deletion_or_insertion = sz_min_of_two(current_distances[i], current_distances[i + 1]) + 1; - next_distances[i] = sz_min_of_two(cost_if_deletion_or_insertion, cost_if_substitution); - } - // Perform a circular rotation of those buffers, to reuse the memory, this time, with a shift, - // dropping the first element in the current array. - sz_size_t *temporary = previous_distances; - previous_distances = current_distances + 1; - current_distances = next_distances; - next_distances = temporary; - } - - // Cache scalar before `free` call. - sz_size_t result = current_distances[0]; - alloc->free(distances, buffer_length, alloc->handle); - return result; -} - -/** - * @brief Describes the length of a UTF8 character / codepoint / rune in bytes. - */ -typedef enum { - sz_utf8_invalid_k = 0, //!< Invalid UTF8 character. - sz_utf8_rune_1byte_k = 1, //!< 1-byte UTF8 character. - sz_utf8_rune_2bytes_k = 2, //!< 2-byte UTF8 character. - sz_utf8_rune_3bytes_k = 3, //!< 3-byte UTF8 character. - sz_utf8_rune_4bytes_k = 4, //!< 4-byte UTF8 character. -} sz_rune_length_t; - -typedef sz_u32_t sz_rune_t; - -/** - * @brief Extracts just one UTF8 codepoint from a UTF8 string into a 32-bit unsigned integer. - */ -SZ_INTERNAL void _sz_extract_utf8_rune(sz_cptr_t utf8, sz_rune_t *code, sz_rune_length_t *code_length) { - sz_u8_t const *current = (sz_u8_t const *)utf8; - sz_u8_t leading_byte = *current++; - sz_rune_t ch; - sz_rune_length_t ch_length; - - // TODO: This can be made entirely branchless using 32-bit SWAR. - if (leading_byte < 0x80) { - // Single-byte rune (0xxxxxxx) - ch = leading_byte; - ch_length = sz_utf8_rune_1byte_k; - } - else if ((leading_byte & 0xE0) == 0xC0) { - // Two-byte rune (110xxxxx 10xxxxxx) - ch = (leading_byte & 0x1F) << 6; - ch |= (*current++ & 0x3F); - ch_length = sz_utf8_rune_2bytes_k; - } - else if ((leading_byte & 0xF0) == 0xE0) { - // Three-byte rune (1110xxxx 10xxxxxx 10xxxxxx) - ch = (leading_byte & 0x0F) << 12; - ch |= (*current++ & 0x3F) << 6; - ch |= (*current++ & 0x3F); - ch_length = sz_utf8_rune_3bytes_k; - } - else if ((leading_byte & 0xF8) == 0xF0) { - // Four-byte rune (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx) - ch = (leading_byte & 0x07) << 18; - ch |= (*current++ & 0x3F) << 12; - ch |= (*current++ & 0x3F) << 6; - ch |= (*current++ & 0x3F); - ch_length = sz_utf8_rune_4bytes_k; - } - else { - // Invalid UTF8 rune. - ch = 0; - ch_length = sz_utf8_invalid_k; - } - *code = ch; - *code_length = ch_length; -} - -/** - * @brief Exports a UTF8 string into a UTF32 buffer. - * ! The result is undefined id the UTF8 string is corrupted. - * @return The length in the number of codepoints. 
- */ -SZ_INTERNAL sz_size_t _sz_export_utf8_to_utf32(sz_cptr_t utf8, sz_size_t utf8_length, sz_rune_t *utf32) { - sz_cptr_t const end = utf8 + utf8_length; - sz_size_t count = 0; - sz_rune_length_t rune_length; - for (; utf8 != end; utf8 += rune_length, utf32++, count++) _sz_extract_utf8_rune(utf8, utf32, &rune_length); - return count; -} - -/** - * @brief Compute the Levenshtein distance between two strings using the Wagner-Fisher algorithm. - * Stores only 2 rows of the Levenshtein matrix, but uses 64-bit integers for the distance values, - * and upcasts UTF8 variable-length codepoints to 64-bit integers for faster addressing. - * - * ! In the worst case for 2 strings of length 100, that contain just one 16-bit codepoint this will result in extra: - * + 2 rows * 100 slots * 8 bytes/slot = 1600 bytes of memory for the two rows of the Levenshtein matrix rows. - * + 100 codepoints * 2 strings * 4 bytes/codepoint = 800 bytes of memory for the UTF8 buffer. - * = 2400 bytes of memory or @b 12x memory amplification! - */ -SZ_INTERNAL sz_size_t _sz_edit_distance_wagner_fisher_serial( // - sz_cptr_t longer, sz_size_t longer_length, // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_size_t bound, sz_bool_t can_be_unicode, sz_memory_allocator_t *alloc) { - - // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. - sz_memory_allocator_t global_alloc; - if (!alloc) { - sz_memory_allocator_init_default(&global_alloc); - alloc = &global_alloc; - } - - // A good idea may be to dispatch different kernels for different string lengths. - // Like using `uint8_t` counters for strings under 255 characters long. - // Good in theory, this results in frequent upcasts and downcasts in serial code. - // On strings over 20 bytes, using `uint8` over `uint64` on 64-bit x86 CPU doubles the execution time. - // So one must be very cautious with such optimizations. - typedef sz_size_t _distance_t; - - // Compute the number of columns in our Levenshtein matrix. - sz_size_t const n = shorter_length + 1; - - // If a buffering memory-allocator is provided, this operation is practically free, - // and cheaper than allocating even 512 bytes (for small distance matrices) on stack. - sz_size_t buffer_length = sizeof(_distance_t) * (n * 2); - - // If the strings contain Unicode characters, let's estimate the max character width, - // and use it to allocate a larger buffer to decode UTF8. - if ((can_be_unicode == sz_true_k) && - (sz_isascii(longer, longer_length) == sz_false_k || sz_isascii(shorter, shorter_length) == sz_false_k)) { - buffer_length += (shorter_length + longer_length) * sizeof(sz_rune_t); - } - else { can_be_unicode = sz_false_k; } - - // If the allocation fails, return the maximum distance. - sz_ptr_t const buffer = (sz_ptr_t)alloc->allocate(buffer_length, alloc->handle); - if (!buffer) return SZ_SIZE_MAX; - - // Let's export the UTF8 sequence into the newly allocated buffer at the end. - if (can_be_unicode == sz_true_k) { - sz_rune_t *const longer_utf32 = (sz_rune_t *)(buffer + sizeof(_distance_t) * (n * 2)); - sz_rune_t *const shorter_utf32 = longer_utf32 + longer_length; - // Export the UTF8 sequences into the newly allocated buffer. 
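As a quick sanity check of the rune extraction defined above, here is a minimal usage sketch. It assumes `stringzilla.h` is included so that the internal `_sz_extract_utf8_rune` helper is visible in the translation unit; the sample string mixes 1-, 2-, 3-, and 4-byte runes:

#include <stdio.h>
#include <stringzilla/stringzilla.h> // assumption: the single-header library is on the include path

int main(void) {
    char const *text = "a\xC3\xA9\xE2\x82\xAC\xF0\x9F\x98\x80"; // "a", "é", "€", "😀"
    char const *end = text + 10;
    while (text < end) {
        sz_rune_t rune;
        sz_rune_length_t rune_length;
        _sz_extract_utf8_rune(text, &rune, &rune_length);
        if (rune_length == sz_utf8_invalid_k) break; // the result is undefined for corrupted input
        printf("U+%04X spans %d byte(s)\n", (unsigned)rune, (int)rune_length);
        text += rune_length;
    }
    return 0;
}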
- longer_length = _sz_export_utf8_to_utf32(longer, longer_length, longer_utf32); - shorter_length = _sz_export_utf8_to_utf32(shorter, shorter_length, shorter_utf32); - longer = (sz_cptr_t)longer_utf32; - shorter = (sz_cptr_t)shorter_utf32; - } - - // Let's parameterize the core logic for different character types and distance types. -#define _wagner_fisher_unbounded(_distance_t, _char_t) \ - /* Now let's cast our pointer to avoid it in subsequent sections. */ \ - _char_t const *const longer_chars = (_char_t const *)longer; \ - _char_t const *const shorter_chars = (_char_t const *)shorter; \ - _distance_t *previous_distances = (_distance_t *)buffer; \ - _distance_t *current_distances = previous_distances + n; \ - /* Initialize the first row of the Levenshtein matrix with `iota`-style arithmetic progression. */ \ - for (_distance_t idx_shorter = 0; idx_shorter != n; ++idx_shorter) previous_distances[idx_shorter] = idx_shorter; \ - /* The main loop of the algorithm with quadratic complexity. */ \ - for (_distance_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) { \ - _char_t const longer_char = longer_chars[idx_longer]; \ - /* Using pure pointer arithmetic is faster than iterating with an index. */ \ - _char_t const *shorter_ptr = shorter_chars; \ - _distance_t const *previous_ptr = previous_distances; \ - _distance_t *current_ptr = current_distances; \ - _distance_t *const current_end = current_ptr + shorter_length; \ - current_ptr[0] = idx_longer + 1; \ - for (; current_ptr != current_end; ++previous_ptr, ++current_ptr, ++shorter_ptr) { \ - _distance_t cost_substitution = previous_ptr[0] + (_distance_t)(longer_char != shorter_ptr[0]); \ - /* We can avoid `+1` for costs here, shifting it to post-minimum computation, */ \ - /* saving one increment operation. */ \ - _distance_t cost_deletion = previous_ptr[1]; \ - _distance_t cost_insertion = current_ptr[0]; \ - /* ? It might be a good idea to enforce branchless execution here. */ \ - /* ? The caveat being that the benchmarks on longer sequences backfire and more research is needed. */ \ - current_ptr[1] = sz_min_of_two(cost_substitution, sz_min_of_two(cost_deletion, cost_insertion) + 1); \ - } \ - /* Swap `previous_distances` and `current_distances` pointers. */ \ - _distance_t *temporary = previous_distances; \ - previous_distances = current_distances; \ - current_distances = temporary; \ - } \ - /* Cache scalar before `free` call. */ \ - sz_size_t result = previous_distances[shorter_length]; \ - alloc->free(buffer, buffer_length, alloc->handle); \ - return result; - - // Let's define a separate variant for bounded distance computation. - // Practically the same as unbounded, but also collecting the running minimum within each row for early exit. 
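Before the bounded variant, readers who find the macro form dense may prefer this minimal standalone sketch of the same two-row recurrence on plain `size_t` buffers (no allocator, no UTF-8 handling); it should print a distance of 3 for "kitten" vs "sitting":

#include <stdio.h>
#include <string.h>

static size_t min2(size_t a, size_t b) { return a < b ? a : b; }

int main(void) {
    char const *longer = "sitting", *shorter = "kitten";
    size_t longer_length = strlen(longer), shorter_length = strlen(shorter);
    size_t previous[32], current[32]; // enough for this toy example
    for (size_t j = 0; j <= shorter_length; ++j) previous[j] = j;
    for (size_t i = 0; i != longer_length; ++i) {
        current[0] = i + 1;
        for (size_t j = 0; j != shorter_length; ++j) {
            size_t substitution = previous[j] + (longer[i] != shorter[j]);
            size_t deletion = previous[j + 1];
            size_t insertion = current[j];
            current[j + 1] = min2(substitution, min2(deletion, insertion) + 1);
        }
        memcpy(previous, current, (shorter_length + 1) * sizeof(size_t)); // pointer swap in the real code
    }
    printf("distance = %zu\n", previous[shorter_length]); // 3
    return 0;
}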
-#define _wagner_fisher_bounded(_distance_t, _char_t) \ - _char_t const *const longer_chars = (_char_t const *)longer; \ - _char_t const *const shorter_chars = (_char_t const *)shorter; \ - _distance_t *previous_distances = (_distance_t *)buffer; \ - _distance_t *current_distances = previous_distances + n; \ - for (_distance_t idx_shorter = 0; idx_shorter != n; ++idx_shorter) previous_distances[idx_shorter] = idx_shorter; \ - for (_distance_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) { \ - _char_t const longer_char = longer_chars[idx_longer]; \ - _char_t const *shorter_ptr = shorter_chars; \ - _distance_t const *previous_ptr = previous_distances; \ - _distance_t *current_ptr = current_distances; \ - _distance_t *const current_end = current_ptr + shorter_length; \ - current_ptr[0] = idx_longer + 1; \ - /* Initialize min_distance with a value greater than bound */ \ - _distance_t min_distance = bound - 1; \ - for (; current_ptr != current_end; ++previous_ptr, ++current_ptr, ++shorter_ptr) { \ - _distance_t cost_substitution = previous_ptr[0] + (_distance_t)(longer_char != shorter_ptr[0]); \ - _distance_t cost_deletion = previous_ptr[1]; \ - _distance_t cost_insertion = current_ptr[0]; \ - current_ptr[1] = sz_min_of_two(cost_substitution, sz_min_of_two(cost_deletion, cost_insertion) + 1); \ - /* Keep track of the minimum distance seen so far in this row */ \ - min_distance = sz_min_of_two(current_ptr[1], min_distance); \ - } \ - /* If the minimum distance in this row exceeded the bound, return early */ \ - if (min_distance >= bound) { \ - alloc->free(buffer, buffer_length, alloc->handle); \ - return bound; \ - } \ - _distance_t *temporary = previous_distances; \ - previous_distances = current_distances; \ - current_distances = temporary; \ - } \ - sz_size_t result = previous_distances[shorter_length]; \ - alloc->free(buffer, buffer_length, alloc->handle); \ - return sz_min_of_two(result, bound); - - // Dispatch the actual computation. - if (!bound) { - if (can_be_unicode == sz_true_k) { _wagner_fisher_unbounded(sz_size_t, sz_rune_t); } - else { _wagner_fisher_unbounded(sz_size_t, sz_u8_t); } - } - else { - if (can_be_unicode == sz_true_k) { _wagner_fisher_bounded(sz_size_t, sz_rune_t); } - else { _wagner_fisher_bounded(sz_size_t, sz_u8_t); } - } -} - -SZ_PUBLIC sz_size_t sz_edit_distance_serial( // - sz_cptr_t longer, sz_size_t longer_length, // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - - // Let's make sure that we use the amount proportional to the - // number of elements in the shorter string, not the larger. - if (shorter_length > longer_length) { - sz_pointer_swap((void **)&longer_length, (void **)&shorter_length); - sz_pointer_swap((void **)&longer, (void **)&shorter); - } - - // Skip the matching prefixes and suffixes, they won't affect the distance. - for (sz_cptr_t a_end = longer + longer_length, b_end = shorter + shorter_length; - longer != a_end && shorter != b_end && *longer == *shorter; - ++longer, ++shorter, --longer_length, --shorter_length); - for (; longer_length && shorter_length && longer[longer_length - 1] == shorter[shorter_length - 1]; - --longer_length, --shorter_length); - - // Bounded computations may exit early. - int const is_bounded = bound < longer_length; - if (is_bounded) { - // If one of the strings is empty - the edit distance is equal to the length of the other one. 
- if (longer_length == 0) return sz_min_of_two(shorter_length, bound);
- if (shorter_length == 0) return sz_min_of_two(longer_length, bound);
- // If the difference in length is beyond the `bound`, there is no need to check at all.
- if (longer_length - shorter_length > bound) return bound;
- }
-
- if (shorter_length == 0) return longer_length; // The shorter string is fully consumed after trimming, so the remaining length of the longer one is the distance (zero if they matched exactly).
- if (shorter_length == longer_length && !is_bounded)
- return _sz_edit_distance_skewed_diagonals_serial(longer, longer_length, shorter, shorter_length, bound, alloc);
- return _sz_edit_distance_wagner_fisher_serial(longer, longer_length, shorter, shorter_length, bound, sz_false_k,
- alloc);
-}
-
-SZ_PUBLIC sz_ssize_t sz_alignment_score_serial( //
- sz_cptr_t longer, sz_size_t longer_length, //
- sz_cptr_t shorter, sz_size_t shorter_length, //
- sz_error_cost_t const *subs, sz_error_cost_t gap, //
- sz_memory_allocator_t *alloc) {
-
- // If one of the strings is empty, the score is the length of the other one multiplied by the gap penalty.
- if (longer_length == 0) return (sz_ssize_t)shorter_length * gap;
- if (shorter_length == 0) return (sz_ssize_t)longer_length * gap;
-
- // Let's make sure that we use the amount of memory proportional to the
- // number of elements in the shorter string, not the larger.
- if (shorter_length > longer_length) {
- sz_pointer_swap((void **)&longer_length, (void **)&shorter_length);
- sz_pointer_swap((void **)&longer, (void **)&shorter);
- }
-
- // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome.
- sz_memory_allocator_t global_alloc;
- if (!alloc) {
- sz_memory_allocator_init_default(&global_alloc);
- alloc = &global_alloc;
- }
-
- sz_size_t n = shorter_length + 1;
- sz_size_t buffer_length = sizeof(sz_ssize_t) * n * 2;
- sz_ssize_t *distances = (sz_ssize_t *)alloc->allocate(buffer_length, alloc->handle);
- sz_ssize_t *previous_distances = distances;
- sz_ssize_t *current_distances = previous_distances + n;
-
- for (sz_size_t idx_shorter = 0; idx_shorter != n; ++idx_shorter)
- previous_distances[idx_shorter] = (sz_ssize_t)idx_shorter * gap;
-
- sz_u8_t const *shorter_unsigned = (sz_u8_t const *)shorter;
- sz_u8_t const *longer_unsigned = (sz_u8_t const *)longer;
- for (sz_size_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) {
- current_distances[0] = ((sz_ssize_t)idx_longer + 1) * gap;
-
- // Select the row of substitution costs for the current character of the longer string.
- sz_error_cost_t const *a_subs = subs + longer_unsigned[idx_longer] * 256ul;
- for (sz_size_t idx_shorter = 0; idx_shorter != shorter_length; ++idx_shorter) {
- sz_ssize_t cost_deletion = previous_distances[idx_shorter + 1] + gap;
- sz_ssize_t cost_insertion = current_distances[idx_shorter] + gap;
- sz_ssize_t cost_substitution = previous_distances[idx_shorter] + a_subs[shorter_unsigned[idx_shorter]];
- current_distances[idx_shorter + 1] = sz_max_of_three(cost_deletion, cost_insertion, cost_substitution);
- }
-
- // Swap the `previous_distances` and `current_distances` pointers.
- sz_pointer_swap((void **)&previous_distances, (void **)&current_distances);
- }
-
- // Cache scalar before `free` call.
- sz_ssize_t result = previous_distances[shorter_length]; - alloc->free(distances, buffer_length, alloc->handle); - return result; -} - -SZ_PUBLIC sz_size_t sz_hamming_distance_serial( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - - sz_size_t const min_length = sz_min_of_two(a_length, b_length); - sz_size_t const max_length = sz_max_of_two(a_length, b_length); - sz_cptr_t const a_end = a + min_length; - bound = bound == 0 ? max_length : bound; - - // Walk through both strings using SWAR and counting the number of differing characters. - sz_size_t distance = max_length - min_length; -#if SZ_USE_MISALIGNED_LOADS && !_SZ_IS_BIG_ENDIAN - if (min_length >= SZ_SWAR_THRESHOLD) { - sz_u64_vec_t a_vec, b_vec, match_vec; - for (; a + 8 <= a_end && distance < bound; a += 8, b += 8) { - a_vec.u64 = sz_u64_load(a).u64; - b_vec.u64 = sz_u64_load(b).u64; - match_vec = _sz_u64_each_byte_equal(a_vec, b_vec); - distance += sz_u64_popcount((~match_vec.u64) & 0x8080808080808080ull); - } - } -#endif - - for (; a != a_end && distance < bound; ++a, ++b) { distance += (*a != *b); } - return sz_min_of_two(distance, bound); -} - -SZ_PUBLIC sz_size_t sz_hamming_distance_utf8_serial( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - - sz_cptr_t const a_end = a + a_length; - sz_cptr_t const b_end = b + b_length; - sz_size_t distance = 0; - - sz_rune_t a_rune, b_rune; - sz_rune_length_t a_rune_length, b_rune_length; - - if (bound) { - for (; a < a_end && b < b_end && distance < bound; a += a_rune_length, b += b_rune_length) { - _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - distance += (a_rune != b_rune); - } - // If one string has more runes, we need to go through the tail. - if (distance < bound) { - for (; a < a_end && distance < bound; a += a_rune_length, ++distance) - _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - - for (; b < b_end && distance < bound; b += b_rune_length, ++distance) - _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - } - } - else { - for (; a < a_end && b < b_end; a += a_rune_length, b += b_rune_length) { - _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - distance += (a_rune != b_rune); - } - // If one string has more runes, we need to go through the tail. - for (; a < a_end; a += a_rune_length, ++distance) _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); - for (; b < b_end; b += b_rune_length, ++distance) _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); - } - return distance; -} - -SZ_PUBLIC sz_u64_t sz_checksum_serial(sz_cptr_t text, sz_size_t length) { - sz_u64_t checksum = 0; - sz_u8_t const *text_u8 = (sz_u8_t const *)text; - sz_u8_t const *text_end = text_u8 + length; - for (; text_u8 != text_end; ++text_u8) checksum += *text_u8; - return checksum; -} - -/** - * @brief Largest prime number that fits into 31 bits. - * @see https://mersenneforum.org/showthread.php?t=3471 - */ -#define SZ_U32_MAX_PRIME (2147483647u) - -/** - * @brief Largest prime number that fits into 64 bits. - * @see https://mersenneforum.org/showthread.php?t=3471 - * - * 2^64 = 18,446,744,073,709,551,616 - * this = 18,446,744,073,709,551,557 - * diff = 59 - */ -#define SZ_U64_MAX_PRIME (18446744073709551557ull) - -/* - * One hardware-accelerated way of mixing hashes can be CRC, but it's only implemented for 32-bit values. 
- * Using a Boost-like mixer works very poorly in such case: - * - * hash_first ^ (hash_second + 0x517cc1b727220a95 + (hash_first << 6) + (hash_first >> 2)); - * - * Let's stick to the Fibonacci hash trick using the golden ratio. - * https://probablydance.com/2018/06/16/fibonacci-hashing-the-optimization-that-the-world-forgot-or-a-better-alternative-to-integer-modulo/ - */ -#define _sz_hash_mix(first, second) ((first * 11400714819323198485ull) ^ (second * 11400714819323198485ull)) -#define _sz_shift_low(x) (x) -#define _sz_shift_high(x) ((x + 77ull) & 0xFFull) -#define _sz_prime_mod(x) (x % SZ_U64_MAX_PRIME) - -SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t start, sz_size_t length) { - - sz_u64_t hash_low = 0; - sz_u64_t hash_high = 0; - sz_u8_t const *text = (sz_u8_t const *)start; - sz_u8_t const *text_end = text + length; - - switch (length) { - case 0: return 0; - - // Texts under 7 bytes long are definitely below the largest prime. - case 1: - hash_low = _sz_shift_low(text[0]); - hash_high = _sz_shift_high(text[0]); - break; - case 2: - hash_low = _sz_shift_low(text[0]) * 31ull + _sz_shift_low(text[1]); - hash_high = _sz_shift_high(text[0]) * 257ull + _sz_shift_high(text[1]); - break; - case 3: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull + // - _sz_shift_low(text[2]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull + // - _sz_shift_high(text[2]); - break; - case 4: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull + // - _sz_shift_low(text[3]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull + // - _sz_shift_high(text[3]); - break; - case 5: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull * 31ull + // - _sz_shift_low(text[3]) * 31ull + // - _sz_shift_low(text[4]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull * 257ull + // - _sz_shift_high(text[3]) * 257ull + // - _sz_shift_high(text[4]); - break; - case 6: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[3]) * 31ull * 31ull + // - _sz_shift_low(text[4]) * 31ull + // - _sz_shift_low(text[5]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[3]) * 257ull * 257ull + // - _sz_shift_high(text[4]) * 257ull + // - _sz_shift_high(text[5]); - break; - case 7: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[3]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[4]) * 31ull * 31ull + // - _sz_shift_low(text[5]) * 31ull + // - _sz_shift_low(text[6]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull 
* 257ull * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[3]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[4]) * 257ull * 257ull + // - _sz_shift_high(text[5]) * 257ull + // - _sz_shift_high(text[6]); - break; - default: - // Unroll the first seven cycles: - hash_low = hash_low * 31ull + _sz_shift_low(text[0]); - hash_high = hash_high * 257ull + _sz_shift_high(text[0]); - hash_low = hash_low * 31ull + _sz_shift_low(text[1]); - hash_high = hash_high * 257ull + _sz_shift_high(text[1]); - hash_low = hash_low * 31ull + _sz_shift_low(text[2]); - hash_high = hash_high * 257ull + _sz_shift_high(text[2]); - hash_low = hash_low * 31ull + _sz_shift_low(text[3]); - hash_high = hash_high * 257ull + _sz_shift_high(text[3]); - hash_low = hash_low * 31ull + _sz_shift_low(text[4]); - hash_high = hash_high * 257ull + _sz_shift_high(text[4]); - hash_low = hash_low * 31ull + _sz_shift_low(text[5]); - hash_high = hash_high * 257ull + _sz_shift_high(text[5]); - hash_low = hash_low * 31ull + _sz_shift_low(text[6]); - hash_high = hash_high * 257ull + _sz_shift_high(text[6]); - text += 7; - - // Iterate throw the rest with the modulus: - for (; text != text_end; ++text) { - hash_low = hash_low * 31ull + _sz_shift_low(text[0]); - hash_high = hash_high * 257ull + _sz_shift_high(text[0]); - // Wrap the hashes around: - hash_low = _sz_prime_mod(hash_low); - hash_high = _sz_prime_mod(hash_high); - } - break; - } - - return _sz_hash_mix(hash_low, hash_high); -} - -SZ_PUBLIC void sz_hashes_serial(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle) { - - if (length < window_length || !window_length) return; - sz_u8_t const *text = (sz_u8_t const *)start; - sz_u8_t const *text_end = text + length; - - // Prepare the `prime ^ window_length` values, that we are going to use for modulo arithmetic. - sz_u64_t prime_power_low = 1, prime_power_high = 1; - for (sz_size_t i = 0; i + 1 < window_length; ++i) - prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME, - prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME; - - // Compute the initial hash value for the first window. - sz_u64_t hash_low = 0, hash_high = 0, hash_mix; - for (sz_u8_t const *first_end = text + window_length; text < first_end; ++text) - hash_low = (hash_low * 31ull + _sz_shift_low(*text)) % SZ_U64_MAX_PRIME, - hash_high = (hash_high * 257ull + _sz_shift_high(*text)) % SZ_U64_MAX_PRIME; - - // In most cases the fingerprint length will be a power of two. - hash_mix = _sz_hash_mix(hash_low, hash_high); - callback((sz_cptr_t)text, window_length, hash_mix, callback_handle); - - // Compute the hash value for every window, exporting into the fingerprint, - // using the expensive modulo operation. - sz_size_t cycles = 1; - sz_size_t const step_mask = step - 1; - for (; text < text_end; ++text, ++cycles) { - // Discard one character: - hash_low -= _sz_shift_low(*(text - window_length)) * prime_power_low; - hash_high -= _sz_shift_high(*(text - window_length)) * prime_power_high; - // And add a new one: - hash_low = 31ull * hash_low + _sz_shift_low(*text); - hash_high = 257ull * hash_high + _sz_shift_high(*text); - // Wrap the hashes around: - hash_low = _sz_prime_mod(hash_low); - hash_high = _sz_prime_mod(hash_high); - // Mix only if we've skipped enough hashes. 
- if ((cycles & step_mask) == 0) { - hash_mix = _sz_hash_mix(hash_low, hash_high); - callback((sz_cptr_t)text, window_length, hash_mix, callback_handle); - } - } -} - -#undef _sz_shift_low -#undef _sz_shift_high -#undef _sz_hash_mix -#undef _sz_prime_mod - -/** - * @brief Uses a small lookup-table to convert a lowercase character to uppercase. - */ -SZ_INTERNAL sz_u8_t sz_u8_tolower(sz_u8_t c) { - static sz_u8_t const lowered[256] = { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, // - 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, // - 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, // - 64, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, // - 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 91, 92, 93, 94, 95, // - 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, // - 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, // - 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, // - 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, // - 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, // - 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, // - 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // - 240, 241, 242, 243, 244, 245, 246, 215, 248, 249, 250, 251, 252, 253, 254, 223, // - 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // - 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, // - }; - return lowered[c]; -} - -/** - * @brief Uses a small lookup-table to convert an uppercase character to lowercase. - */ -SZ_INTERNAL sz_u8_t sz_u8_toupper(sz_u8_t c) { - static sz_u8_t const upped[256] = { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, // - 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, // - 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, // - 64, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, // - 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 91, 92, 93, 94, 95, // - 96, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, // - 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 123, 124, 125, 126, 127, // - 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, // - 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, // - 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, // - 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, // - 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // - 240, 241, 242, 243, 244, 245, 246, 215, 248, 249, 250, 251, 252, 253, 254, 223, // - 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // - 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, // - }; - return upped[c]; -} - -/** - * @brief Uses two small lookup tables (768 bytes total) to accelerate division by a small - * unsigned integer. Performs two lookups, one multiplication, two shifts, and two accumulations. - * - * @param divisor Integral value @b larger than one. - * @param number Integral value to divide. 
- */ -SZ_INTERNAL sz_u8_t sz_u8_divide(sz_u8_t number, sz_u8_t divisor) { - sz_assert(divisor > 1); - static sz_u16_t const multipliers[256] = { - 0, 0, 0, 21846, 0, 39322, 21846, 9363, 0, 50973, 39322, 29790, 21846, 15124, 9363, 4370, - 0, 57826, 50973, 44841, 39322, 34329, 29790, 25645, 21846, 18351, 15124, 12137, 9363, 6780, 4370, 2115, - 0, 61565, 57826, 54302, 50973, 47824, 44841, 42011, 39322, 36765, 34329, 32006, 29790, 27671, 25645, 23705, - 21846, 20063, 18351, 16706, 15124, 13602, 12137, 10725, 9363, 8049, 6780, 5554, 4370, 3224, 2115, 1041, - 0, 63520, 61565, 59668, 57826, 56039, 54302, 52614, 50973, 49377, 47824, 46313, 44841, 43407, 42011, 40649, - 39322, 38028, 36765, 35532, 34329, 33154, 32006, 30885, 29790, 28719, 27671, 26647, 25645, 24665, 23705, 22766, - 21846, 20945, 20063, 19198, 18351, 17520, 16706, 15907, 15124, 14356, 13602, 12863, 12137, 11424, 10725, 10038, - 9363, 8700, 8049, 7409, 6780, 6162, 5554, 4957, 4370, 3792, 3224, 2665, 2115, 1573, 1041, 517, - 0, 64520, 63520, 62535, 61565, 60609, 59668, 58740, 57826, 56926, 56039, 55164, 54302, 53452, 52614, 51788, - 50973, 50169, 49377, 48595, 47824, 47063, 46313, 45572, 44841, 44120, 43407, 42705, 42011, 41326, 40649, 39982, - 39322, 38671, 38028, 37392, 36765, 36145, 35532, 34927, 34329, 33738, 33154, 32577, 32006, 31443, 30885, 30334, - 29790, 29251, 28719, 28192, 27671, 27156, 26647, 26143, 25645, 25152, 24665, 24182, 23705, 23233, 22766, 22303, - 21846, 21393, 20945, 20502, 20063, 19628, 19198, 18772, 18351, 17933, 17520, 17111, 16706, 16305, 15907, 15514, - 15124, 14738, 14356, 13977, 13602, 13231, 12863, 12498, 12137, 11779, 11424, 11073, 10725, 10380, 10038, 9699, - 9363, 9030, 8700, 8373, 8049, 7727, 7409, 7093, 6780, 6470, 6162, 5857, 5554, 5254, 4957, 4662, - 4370, 4080, 3792, 3507, 3224, 2943, 2665, 2388, 2115, 1843, 1573, 1306, 1041, 778, 517, 258, - }; - // This table can be avoided using a single addition and counting trailing zeros. 
- static sz_u8_t const shifts[256] = { - 0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, // - 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, // - 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, // - 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, // - 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // - 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // - 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // - 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // - }; - sz_u32_t multiplier = multipliers[divisor]; - sz_u8_t shift = shifts[divisor]; - - sz_u16_t q = (sz_u16_t)((multiplier * number) >> 16); - sz_u16_t t = ((number - q) >> 1) + q; - return (sz_u8_t)(t >> shift); -} - -SZ_PUBLIC void sz_look_up_transform_serial(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result) { - sz_u8_t const *unsigned_lut = (sz_u8_t const *)lut; - sz_u8_t const *unsigned_text = (sz_u8_t const *)text; - sz_u8_t *unsigned_result = (sz_u8_t *)result; - sz_u8_t const *end = unsigned_text + length; - for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = unsigned_lut[*unsigned_text]; -} - -SZ_PUBLIC void sz_tolower_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) { - sz_u8_t *unsigned_result = (sz_u8_t *)result; - sz_u8_t const *unsigned_text = (sz_u8_t const *)text; - sz_u8_t const *end = unsigned_text + length; - for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = sz_u8_tolower(*unsigned_text); -} - -SZ_PUBLIC void sz_toupper_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) { - sz_u8_t *unsigned_result = (sz_u8_t *)result; - sz_u8_t const *unsigned_text = (sz_u8_t const *)text; - sz_u8_t const *end = unsigned_text + length; - for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = sz_u8_toupper(*unsigned_text); -} - -SZ_PUBLIC void sz_toascii_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) { - sz_u8_t *unsigned_result = (sz_u8_t *)result; - sz_u8_t const *unsigned_text = (sz_u8_t const *)text; - sz_u8_t const *end = unsigned_text + length; - for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = *unsigned_text & 0x7F; -} - -/** - * @brief Check if there is a byte in this buffer, that exceeds 127 and can't be an ASCII character. - * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. - */ -SZ_PUBLIC sz_bool_t sz_isascii_serial(sz_cptr_t text, sz_size_t length) { - - if (!length) return sz_true_k; - sz_u8_t const *h = (sz_u8_t const *)text; - sz_u8_t const *const h_end = h + length; - -#if !SZ_USE_MISALIGNED_LOADS - // Process the misaligned head, to void UB on unaligned 64-bit loads. - for (; ((sz_size_t)h & 7ull) && h < h_end; ++h) - if (*h & 0x80ull) return sz_false_k; -#endif - - // Validate eight bytes at once using SWAR. - sz_u64_vec_t text_vec; - for (; h + 8 <= h_end; h += 8) { - text_vec.u64 = *(sz_u64_t const *)h; - if (text_vec.u64 & 0x8080808080808080ull) return sz_false_k; - } - - // Handle the misaligned tail. 
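Stepping back to `sz_u8_divide` above: because the domain is tiny, the multiply-and-shift approximation can be checked exhaustively against plain integer division. A small sketch, assuming `stringzilla.h` is included so the internal helper is visible; it is expected to report zero mismatches:

#include <stdio.h>
#include <stringzilla/stringzilla.h> // assumption: the single-header library is on the include path

int main(void) {
    unsigned mismatches = 0;
    for (unsigned divisor = 2; divisor < 256; ++divisor)     // the helper requires `divisor > 1`
        for (unsigned number = 0; number < 256; ++number)
            mismatches += sz_u8_divide((sz_u8_t)number, (sz_u8_t)divisor) != number / divisor;
    printf("mismatches: %u\n", mismatches); // expected: 0
    return 0;
}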
- for (; h < h_end; ++h)
- if (*h & 0x80ull) return sz_false_k;
- return sz_true_k;
-}
-
-SZ_PUBLIC void sz_generate_serial(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length,
- sz_random_generator_t generator, void *generator_user_data) {
-
- sz_assert(alphabet_size > 0 && alphabet_size <= 256 && "Inadequate alphabet size");
-
- if (alphabet_size == 1) sz_fill(result, result_length, *alphabet);
-
- else {
- sz_assert(generator && "Expects a valid random generator");
- sz_u8_t divisor = (sz_u8_t)alphabet_size;
- for (sz_cptr_t end = result + result_length; result != end; ++result) {
- sz_u8_t random = generator(generator_user_data) & 0xFF;
- sz_u8_t quotient = sz_u8_divide(random, divisor);
- *result = alphabet[random - quotient * divisor];
- }
- }
-}
-
-#pragma endregion
-
-/*
- * Serial implementation of string class operations.
- */
-#pragma region Serial Implementation for the String Class
-
-SZ_PUBLIC sz_bool_t sz_string_is_on_stack(sz_string_t const *string) {
- // It doesn't matter if it's on stack or heap, the pointer location is the same.
- return (sz_bool_t)((sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0]);
-}
-
-SZ_PUBLIC void sz_string_range(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length) {
- sz_size_t is_small = (sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0];
- sz_size_t is_big_mask = is_small - 1ull;
- *start = string->external.start; // It doesn't matter if it's on stack or heap, the pointer location is the same.
- // If the string is small, use a branch-less approach to mask-out the top 7 bytes of the length.
- *length = string->external.length & (0x00000000000000FFull | is_big_mask);
-}
-
-SZ_PUBLIC void sz_string_unpack(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length, sz_size_t *space,
- sz_bool_t *is_external) {
- sz_size_t is_small = (sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0];
- sz_size_t is_big_mask = is_small - 1ull;
- *start = string->external.start; // It doesn't matter if it's on stack or heap, the pointer location is the same.
- // If the string is small, use a branch-less approach to mask-out the top 7 bytes of the length.
- *length = string->external.length & (0x00000000000000FFull | is_big_mask);
- // In case the string is big (heap-allocated), the `is_small - 1ull` mask becomes 0xFFFFFFFFFFFFFFFFull.
- *space = sz_u64_blend(_SZ_STRING_INTERNAL_SPACE, string->external.space, is_big_mask);
- *is_external = (sz_bool_t)!is_small;
-}
-
-SZ_PUBLIC sz_bool_t sz_string_equal(sz_string_t const *a, sz_string_t const *b) {
- // It is tempting to compare the `external.length` words directly, but for an on-stack string the
- // high-order bytes of that word contain noise from the in-place payload, and we don't currently
- // maintain an invariant that would zero them out. So do this the hard (correct) way.
-
-#if SZ_USE_MISALIGNED_LOADS
- // Dealing with StringZilla strings, we know that the `start` pointer always points
- // to a word at least 8 bytes long. Therefore, we can compare the first 8 bytes at once.
-
-#endif
- // Alternatively, fall back to byte-by-byte comparison.
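The `is_small - 1ull` trick used in `sz_string_range` and `sz_string_unpack` above deserves a standalone look: subtracting one from a 0/1 flag yields an all-zeros or all-ones 64-bit mask, which then selects between the 8-bit on-stack length and the full 64-bit heap length without a branch. A tiny sketch:

#include <stdint.h>
#include <stdio.h>

int main(void) {
    uint64_t raw_length = 0xAABBCCDDEEFF1107ull; // top 7 bytes are noise when the string is on-stack
    for (int is_small = 0; is_small <= 1; ++is_small) {
        uint64_t is_big_mask = (uint64_t)is_small - 1ull; // small -> 0, big -> 0xFFFFFFFFFFFFFFFF
        uint64_t length = raw_length & (0xFFull | is_big_mask);
        printf("is_small=%d -> length=0x%016llX\n", is_small, (unsigned long long)length);
    }
    return 0; // prints the full 64-bit value for the heap case and just 0x07 for the on-stack case
}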
- sz_ptr_t a_start, b_start; - sz_size_t a_length, b_length; - sz_string_range(a, &a_start, &a_length); - sz_string_range(b, &b_start, &b_length); - return (sz_bool_t)(a_length == b_length && sz_equal(a_start, b_start, b_length)); -} - -SZ_PUBLIC sz_ordering_t sz_string_order(sz_string_t const *a, sz_string_t const *b) { -#if SZ_USE_MISALIGNED_LOADS - // Dealing with StringZilla strings, we know that the `start` pointer always points - // to a word at least 8 bytes long. Therefore, we can compare the first 8 bytes at once. - -#endif - // Alternatively, fall back to byte-by-byte comparison. - sz_ptr_t a_start, b_start; - sz_size_t a_length, b_length; - sz_string_range(a, &a_start, &a_length); - sz_string_range(b, &b_start, &b_length); - return sz_order(a_start, a_length, b_start, b_length); -} - -SZ_PUBLIC void sz_string_init(sz_string_t *string) { - sz_assert(string && "String can't be SZ_NULL."); - - // Only 8 + 1 + 1 need to be initialized. - string->internal.start = &string->internal.chars[0]; - // But for safety let's initialize the entire structure to zeros. - // string->internal.chars[0] = 0; - // string->internal.length = 0; - string->words[1] = 0; - string->words[2] = 0; - string->words[3] = 0; -} - -SZ_PUBLIC sz_ptr_t sz_string_init_length(sz_string_t *string, sz_size_t length, sz_memory_allocator_t *allocator) { - sz_size_t space_needed = length + 1; // space for trailing \0 - sz_assert(string && allocator && "String and allocator can't be SZ_NULL."); - // Initialize the string to zeros for safety. - string->words[1] = 0; - string->words[2] = 0; - string->words[3] = 0; - // If we are lucky, no memory allocations will be needed. - if (space_needed <= _SZ_STRING_INTERNAL_SPACE) { - string->internal.start = &string->internal.chars[0]; - string->internal.length = (sz_u8_t)length; - } - else { - // If we are not lucky, we need to allocate memory. - string->external.start = (sz_ptr_t)allocator->allocate(space_needed, allocator->handle); - if (!string->external.start) return SZ_NULL_CHAR; - string->external.length = length; - string->external.space = space_needed; - } - sz_assert(&string->internal.start == &string->external.start && "Alignment confusion"); - string->external.start[length] = 0; - return string->external.start; -} - -SZ_PUBLIC sz_ptr_t sz_string_reserve(sz_string_t *string, sz_size_t new_capacity, sz_memory_allocator_t *allocator) { - - sz_assert(string && allocator && "Strings and allocators can't be SZ_NULL."); - - sz_size_t new_space = new_capacity + 1; - if (new_space <= _SZ_STRING_INTERNAL_SPACE) return string->external.start; - - sz_ptr_t string_start; - sz_size_t string_length; - sz_size_t string_space; - sz_bool_t string_is_external; - sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - sz_assert(new_space > string_space && "New space must be larger than current."); - - sz_ptr_t new_start = (sz_ptr_t)allocator->allocate(new_space, allocator->handle); - if (!new_start) return SZ_NULL_CHAR; - - sz_copy(new_start, string_start, string_length); - string->external.start = new_start; - string->external.space = new_space; - string->external.padding = 0; - string->external.length = string_length; - - // Deallocate the old string. 
- if (string_is_external) allocator->free(string_start, string_space, allocator->handle); - return string->external.start; -} - -SZ_PUBLIC sz_ptr_t sz_string_shrink_to_fit(sz_string_t *string, sz_memory_allocator_t *allocator) { - - sz_assert(string && allocator && "Strings and allocators can't be SZ_NULL."); - - sz_ptr_t string_start; - sz_size_t string_length; - sz_size_t string_space; - sz_bool_t string_is_external; - sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - - // We may already be space-optimal, and in that case we don't need to do anything. - sz_size_t new_space = string_length + 1; - if (string_space == new_space || !string_is_external) return string->external.start; - - sz_ptr_t new_start = (sz_ptr_t)allocator->allocate(new_space, allocator->handle); - if (!new_start) return SZ_NULL_CHAR; - - sz_copy(new_start, string_start, string_length); - string->external.start = new_start; - string->external.space = new_space; - string->external.padding = 0; - string->external.length = string_length; - - // Deallocate the old string. - if (string_is_external) allocator->free(string_start, string_space, allocator->handle); - return string->external.start; -} - -SZ_PUBLIC sz_ptr_t sz_string_expand(sz_string_t *string, sz_size_t offset, sz_size_t added_length, - sz_memory_allocator_t *allocator) { - - sz_assert(string && allocator && "String and allocator can't be SZ_NULL."); - - sz_ptr_t string_start; - sz_size_t string_length; - sz_size_t string_space; - sz_bool_t string_is_external; - sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - - // The user intended to extend the string. - offset = sz_min_of_two(offset, string_length); - - // If we are lucky, no memory allocations will be needed. - if (string_length + added_length < string_space) { - sz_move(string_start + offset + added_length, string_start + offset, string_length - offset); - string_start[string_length + added_length] = 0; - // Even if the string is on the stack, the `+=` won't affect the tail of the string. - string->external.length += added_length; - } - // If we are not lucky, we need to allocate more memory. - else { - sz_size_t next_planned_size = sz_max_of_two(SZ_CACHE_LINE_WIDTH, string_space * 2ull); - sz_size_t min_needed_space = sz_size_bit_ceil(offset + string_length + added_length + 1); - sz_size_t new_space = sz_max_of_two(min_needed_space, next_planned_size); - string_start = sz_string_reserve(string, new_space - 1, allocator); - if (!string_start) return SZ_NULL_CHAR; - - // Copy into the new buffer. - sz_move(string_start + offset + added_length, string_start + offset, string_length - offset); - string_start[string_length + added_length] = 0; - string->external.length = string_length + added_length; - } - - return string_start; -} - -SZ_PUBLIC sz_size_t sz_string_erase(sz_string_t *string, sz_size_t offset, sz_size_t length) { - - sz_assert(string && "String can't be SZ_NULL."); - - sz_ptr_t string_start; - sz_size_t string_length; - sz_size_t string_space; - sz_bool_t string_is_external; - sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); - - // Normalize the offset, it can't be larger than the length. - offset = sz_min_of_two(offset, string_length); - - // We shouldn't normalize the length, to avoid overflowing on `offset + length >= string_length`, - // if receiving `length == SZ_SIZE_MAX`. 
After following expression the `length` will contain - // exactly the delta between original and final length of this `string`. - length = sz_min_of_two(length, string_length - offset); - - // There are 2 common cases, that wouldn't even require a `memmove`: - // 1. Erasing the entire contents of the string. - // In that case `length` argument will be equal or greater than `length` member. - // 2. Removing the tail of the string with something like `string.pop_back()` in C++. - // - // In both of those, regardless of the location of the string - stack or heap, - // the erasing is as easy as setting the length to the offset. - // In every other case, we must `memmove` the tail of the string to the left. - if (offset + length < string_length) - sz_move(string_start + offset, string_start + offset + length, string_length - offset - length); - - // The `string->external.length = offset` assignment would discard last characters - // of the on-the-stack string, but inplace subtraction would work. - string->external.length -= length; - string_start[string_length - length] = 0; - return length; -} - -SZ_PUBLIC void sz_string_free(sz_string_t *string, sz_memory_allocator_t *allocator) { - if (!sz_string_is_on_stack(string)) - allocator->free(string->external.start, string->external.space, allocator->handle); - sz_string_init(string); -} - -#pragma endregion - -/* - * @brief Serial implementation for strings sequence processing. - */ -#pragma region Serial Implementation for Sequences - -SZ_PUBLIC sz_size_t sz_partition(sz_sequence_t *sequence, sz_sequence_predicate_t predicate) { - - sz_size_t matches = 0; - while (matches != sequence->count && predicate(sequence, sequence->order[matches])) ++matches; - - for (sz_size_t i = matches + 1; i < sequence->count; ++i) - if (predicate(sequence, sequence->order[i])) - sz_u64_swap(sequence->order + i, sequence->order + matches), ++matches; - - return matches; -} - -SZ_PUBLIC void sz_merge(sz_sequence_t *sequence, sz_size_t partition, sz_sequence_comparator_t less) { - - sz_size_t start_b = partition + 1; - - // If the direct merge is already sorted - if (!less(sequence, sequence->order[start_b], sequence->order[partition])) return; - - sz_size_t start_a = 0; - while (start_a <= partition && start_b <= sequence->count) { - - // If element 1 is in right place - if (!less(sequence, sequence->order[start_b], sequence->order[start_a])) { start_a++; } - else { - sz_size_t value = sequence->order[start_b]; - sz_size_t index = start_b; - - // Shift all the elements between element 1 - // element 2, right by 1. 
- while (index != start_a) { sequence->order[index] = sequence->order[index - 1], index--; } - sequence->order[start_a] = value; - - // Update all the pointers - start_a++; - partition++; - start_b++; - } - } -} - -SZ_PUBLIC void sz_sort_insertion(sz_sequence_t *sequence, sz_sequence_comparator_t less) { - sz_u64_t *keys = sequence->order; - sz_size_t keys_count = sequence->count; - for (sz_size_t i = 1; i < keys_count; i++) { - sz_u64_t i_key = keys[i]; - sz_size_t j = i; - for (; j > 0 && less(sequence, i_key, keys[j - 1]); --j) keys[j] = keys[j - 1]; - keys[j] = i_key; - } -} - -SZ_INTERNAL void _sz_sift_down(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_u64_t *order, sz_size_t start, - sz_size_t end) { - sz_size_t root = start; - while (2 * root + 1 <= end) { - sz_size_t child = 2 * root + 1; - if (child + 1 <= end && less(sequence, order[child], order[child + 1])) { child++; } - if (!less(sequence, order[root], order[child])) { return; } - sz_u64_swap(order + root, order + child); - root = child; - } -} - -SZ_INTERNAL void _sz_heapify(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_u64_t *order, sz_size_t count) { - sz_size_t start = (count - 2) / 2; - while (1) { - _sz_sift_down(sequence, less, order, start, count - 1); - if (start == 0) return; - start--; - } -} - -SZ_INTERNAL void _sz_heapsort(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_size_t first, sz_size_t last) { - sz_u64_t *order = sequence->order; - sz_size_t count = last - first; - _sz_heapify(sequence, less, order + first, count); - sz_size_t end = count - 1; - while (end > 0) { - sz_u64_swap(order + first, order + first + end); - end--; - _sz_sift_down(sequence, less, order + first, 0, end); - } -} - -SZ_PUBLIC void sz_sort_introsort_recursion(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_size_t first, - sz_size_t last, sz_size_t depth) { - - sz_size_t length = last - first; - switch (length) { - case 0: - case 1: return; - case 2: - if (less(sequence, sequence->order[first + 1], sequence->order[first])) - sz_u64_swap(&sequence->order[first], &sequence->order[first + 1]); - return; - case 3: { - sz_u64_t a = sequence->order[first]; - sz_u64_t b = sequence->order[first + 1]; - sz_u64_t c = sequence->order[first + 2]; - if (less(sequence, b, a)) sz_u64_swap(&a, &b); - if (less(sequence, c, b)) sz_u64_swap(&c, &b); - if (less(sequence, b, a)) sz_u64_swap(&a, &b); - sequence->order[first] = a; - sequence->order[first + 1] = b; - sequence->order[first + 2] = c; - return; - } - } - // Until a certain length, the quadratic-complexity insertion-sort is fine - if (length <= 16) { - sz_sequence_t sub_seq = *sequence; - sub_seq.order += first; - sub_seq.count = length; - sz_sort_insertion(&sub_seq, less); - return; - } - - // Fallback to N-logN-complexity heap-sort - if (depth == 0) { - _sz_heapsort(sequence, less, first, last); - return; - } - - --depth; - - // Median-of-three logic to choose pivot - sz_size_t median = first + length / 2; - if (less(sequence, sequence->order[median], sequence->order[first])) - sz_u64_swap(&sequence->order[first], &sequence->order[median]); - if (less(sequence, sequence->order[last - 1], sequence->order[first])) - sz_u64_swap(&sequence->order[first], &sequence->order[last - 1]); - if (less(sequence, sequence->order[median], sequence->order[last - 1])) - sz_u64_swap(&sequence->order[median], &sequence->order[last - 1]); - - // Partition using the median-of-three as the pivot - sz_u64_t pivot = sequence->order[median]; - sz_size_t left = first; - 
sz_size_t right = last - 1; - while (1) { - while (less(sequence, sequence->order[left], pivot)) left++; - while (less(sequence, pivot, sequence->order[right])) right--; - if (left >= right) break; - sz_u64_swap(&sequence->order[left], &sequence->order[right]); - left++; - right--; - } - - // Recursively sort the partitions - sz_sort_introsort_recursion(sequence, less, first, left, depth); - sz_sort_introsort_recursion(sequence, less, right + 1, last, depth); -} - -SZ_PUBLIC void sz_sort_introsort(sz_sequence_t *sequence, sz_sequence_comparator_t less) { - if (sequence->count == 0) return; - sz_size_t size_is_not_power_of_two = (sequence->count & (sequence->count - 1)) != 0; - sz_size_t depth_limit = sz_size_log2i_nonzero(sequence->count) + size_is_not_power_of_two; - sz_sort_introsort_recursion(sequence, less, 0, sequence->count, depth_limit); -} - -SZ_PUBLIC void sz_sort_recursion( // - sz_sequence_t *sequence, sz_size_t bit_idx, sz_size_t bit_max, sz_sequence_comparator_t comparator, - sz_size_t partial_order_length) { - - if (!sequence->count) return; - - // Array of size one doesn't need sorting - only needs the prefix to be discarded. - if (sequence->count == 1) { - sz_u32_t *order_half_words = (sz_u32_t *)sequence->order; - order_half_words[1] = 0; - return; - } - - // Partition a range of integers according to a specific bit value - sz_size_t split = 0; - sz_u64_t mask = (1ull << 63) >> bit_idx; - - // The clean approach would be to perform a single pass over the sequence. - // - // while (split != sequence->count && !(sequence->order[split] & mask)) ++split; - // for (sz_size_t i = split + 1; i < sequence->count; ++i) - // if (!(sequence->order[i] & mask)) sz_u64_swap(sequence->order + i, sequence->order + split), ++split; - // - // This, however, doesn't take into account the high relative cost of writes and swaps. - // To circumvent that, we can first count the total number entries to be mapped into either part. - // And then walk through both parts, swapping the entries that are in the wrong part. - // This would often lead to ~15% performance gain. - sz_size_t count_with_bit_set = 0; - for (sz_size_t i = 0; i != sequence->count; ++i) count_with_bit_set += (sequence->order[i] & mask) != 0; - split = sequence->count - count_with_bit_set; - - // It's possible that the sequence is already partitioned. - if (split != 0 && split != sequence->count) { - // Use two pointers to efficiently reposition elements. - // On pointer walks left-to-right from the start, and the other walks right-to-left from the end. - sz_size_t left = 0; - sz_size_t right = sequence->count - 1; - while (1) { - // Find the next element with the bit set on the left side. - while (left < split && !(sequence->order[left] & mask)) ++left; - // Find the next element without the bit set on the right side. - while (right >= split && (sequence->order[right] & mask)) --right; - // Swap the mispositioned elements. - if (left < split && right >= split) { - sz_u64_swap(sequence->order + left, sequence->order + right); - ++left; - --right; - } - else { break; } - } - } - - // Go down recursively. - if (bit_idx < bit_max) { - sz_sequence_t a = *sequence; - a.count = split; - sz_sort_recursion(&a, bit_idx + 1, bit_max, comparator, partial_order_length); - - sz_sequence_t b = *sequence; - b.order += split; - b.count -= split; - sz_sort_recursion(&b, bit_idx + 1, bit_max, comparator, partial_order_length); - } - // Reached the end of recursion. - else { - // Discard the prefixes. 
- sz_u32_t *order_half_words = (sz_u32_t *)sequence->order; - for (sz_size_t i = 0; i != sequence->count; ++i) { order_half_words[i * 2 + 1] = 0; } - - sz_sequence_t a = *sequence; - a.count = split; - sz_sort_introsort(&a, comparator); - - sz_sequence_t b = *sequence; - b.order += split; - b.count -= split; - sz_sort_introsort(&b, comparator); - } -} - -SZ_INTERNAL sz_bool_t _sz_sort_is_less(sz_sequence_t *sequence, sz_size_t i_key, sz_size_t j_key) { - sz_cptr_t i_str = sequence->get_start(sequence, i_key); - sz_cptr_t j_str = sequence->get_start(sequence, j_key); - sz_size_t i_len = sequence->get_length(sequence, i_key); - sz_size_t j_len = sequence->get_length(sequence, j_key); - return (sz_bool_t)(sz_order_serial(i_str, i_len, j_str, j_len) == sz_less_k); -} - -SZ_PUBLIC void sz_sort_partial(sz_sequence_t *sequence, sz_size_t partial_order_length) { - -#if _SZ_IS_BIG_ENDIAN - // TODO: Implement partial sort for big-endian systems. For now this sorts the whole thing. - sz_unused(partial_order_length); - sz_sort_introsort(sequence, (sz_sequence_comparator_t)_sz_sort_is_less); -#else - - // Export up to 4 bytes into the `sequence` bits themselves - for (sz_size_t i = 0; i != sequence->count; ++i) { - sz_cptr_t begin = sequence->get_start(sequence, sequence->order[i]); - sz_size_t length = sequence->get_length(sequence, sequence->order[i]); - length = length > 4u ? 4u : length; - sz_ptr_t prefix = (sz_ptr_t)&sequence->order[i]; - for (sz_size_t j = 0; j != length; ++j) prefix[7 - j] = begin[j]; - } - - // Perform optionally-parallel radix sort on them - sz_sort_recursion(sequence, 0, 32, (sz_sequence_comparator_t)_sz_sort_is_less, partial_order_length); -#endif -} - -SZ_PUBLIC void sz_sort(sz_sequence_t *sequence) { -#if _SZ_IS_BIG_ENDIAN - sz_sort_introsort(sequence, (sz_sequence_comparator_t)_sz_sort_is_less); -#else - sz_sort_partial(sequence, sequence->count); -#endif -} - -#pragma endregion - -/* - * @brief AVX2 implementation of the string search algorithms. - * Very minimalistic, but still faster than the serial implementation. - */ -#pragma region AVX2 Implementation - -#if SZ_USE_HASWELL -#pragma GCC push_options -#pragma GCC target("avx2") -#pragma clang attribute push(__attribute__((target("avx2"))), apply_to = function) -#include - -/** - * @brief Helper structure to simplify work with 256-bit registers. - */ -typedef union sz_u256_vec_t { - __m256i ymm; - __m128i xmms[2]; - sz_u64_t u64s[4]; - sz_u32_t u32s[8]; - sz_u16_t u16s[16]; - sz_u8_t u8s[32]; -} sz_u256_vec_t; - -SZ_PUBLIC sz_ordering_t sz_order_avx2(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { - //! Before optimizing this, read the "Operations Not Worth Optimizing" in Contributions Guide: - //! https://github.com/ashvardanian/StringZilla/blob/main/CONTRIBUTING.md#general-performance-observations - return sz_order_serial(a, a_length, b, b_length); -} - -SZ_PUBLIC sz_bool_t sz_equal_avx2(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - sz_u256_vec_t a_vec, b_vec; - - while (length >= 32) { - a_vec.ymm = _mm256_lddqu_si256((__m256i const *)a); - b_vec.ymm = _mm256_lddqu_si256((__m256i const *)b); - // One approach can be to use "movemasks", but we could also use a bitwise matching like `_mm256_testnzc_si256`. 
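// A small sketch of the prefix-packing step used by the partial sort above: the first (up to) four
// bytes of each string are copied into the top bytes of its 64-bit order entry, so that on a
// little-endian machine comparing the packed integers agrees with comparing the string prefixes
// lexicographically, while the low half keeps the original position that the recursion later
// re-exposes by zeroing the prefix. Illustrative helper, not part of the library.
#include <stddef.h>
#include <stdint.h>

static uint64_t pack_prefix_into_high_bytes_example(char const *begin, size_t length, uint32_t original_index) {
    uint64_t entry = original_index;       // low 32 bits are assumed to hold the original position
    uint8_t *bytes = (uint8_t *)&entry;
    size_t prefix = length > 4u ? 4u : length;
    for (size_t j = 0; j != prefix; ++j) bytes[7 - j] = (uint8_t)begin[j]; // most significant byte first
    return entry;
}
// Sorting such entries by their top 32 bits orders the strings by their 4-byte prefixes; ties are
// resolved afterwards with the comparison-based introsort.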
- int difference_mask = ~_mm256_movemask_epi8(_mm256_cmpeq_epi8(a_vec.ymm, b_vec.ymm)); - if (difference_mask == 0) { a += 32, b += 32, length -= 32; } - else { return sz_false_k; } - } - - if (length) return sz_equal_serial(a, b, length); - return sz_true_k; -} - -SZ_PUBLIC void sz_fill_avx2(sz_ptr_t target, sz_size_t length, sz_u8_t value) { - char value_char = *(char *)&value; - __m256i value_vec = _mm256_set1_epi8(value_char); - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "stores". - // - // for (; length >= 32; target += 32, length -= 32) _mm256_storeu_si256(target, value_vec); - // sz_fill_serial(target, length, value); - // - // When the buffer is small, there isn't much to innovate. - if (length <= 32) sz_fill_serial(target, length, value); - // When the buffer is aligned, we can avoid any split-stores. - else { - sz_size_t head_length = (32 - ((sz_size_t)target % 32)) % 32; // 31 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 32; // 31 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 32. - sz_u16_t value16 = (sz_u16_t)value * 0x0101u; - sz_u32_t value32 = (sz_u32_t)value16 * 0x00010001u; - sz_u64_t value64 = (sz_u64_t)value32 * 0x0000000100000001ull; - - // Fill the head of the buffer. This part is much cleaner with AVX-512. - if (head_length & 1) *(sz_u8_t *)target = value, target++, head_length--; - if (head_length & 2) *(sz_u16_t *)target = value16, target += 2, head_length -= 2; - if (head_length & 4) *(sz_u32_t *)target = value32, target += 4, head_length -= 4; - if (head_length & 8) *(sz_u64_t *)target = value64, target += 8, head_length -= 8; - if (head_length & 16) - _mm_store_si128((__m128i *)target, _mm_set1_epi8(value_char)), target += 16, head_length -= 16; - sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); - - // Fill the aligned body of the buffer. - for (; body_length >= 32; target += 32, body_length -= 32) _mm256_store_si256((__m256i *)target, value_vec); - - // Fill the tail of the buffer. This part is much cleaner with AVX-512. - sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); - if (tail_length & 16) - _mm_store_si128((__m128i *)target, _mm_set1_epi8(value_char)), target += 16, tail_length -= 16; - if (tail_length & 8) *(sz_u64_t *)target = value64, target += 8, tail_length -= 8; - if (tail_length & 4) *(sz_u32_t *)target = value32, target += 4, tail_length -= 4; - if (tail_length & 2) *(sz_u16_t *)target = value16, target += 2, tail_length -= 2; - if (tail_length & 1) *(sz_u8_t *)target = value, target++, tail_length--; - } -} - -SZ_PUBLIC void sz_copy_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "stores" and "loads". - // - // for (; length >= 32; target += 32, source += 32, length -= 32) - // _mm256_storeu_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source)); - // sz_copy_serial(target, source, length); - // - // A typical AWS Skylake instance can have 32 KB x 2 blocks of L1 data cache per core, - // 1 MB x 2 blocks of L2 cache per core, and one shared L3 cache buffer. - // For now, let's avoid the cases beyond the L2 size. 
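// The fill and copy kernels in this region keep reusing one pattern: peel off just enough bytes to
// reach the next aligned address, stream full aligned blocks, then finish the unaligned remainder.
// A hedged scalar sketch of that arithmetic for a 32-byte block size, with illustrative names:
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

static void split_for_alignment_example(uintptr_t address, size_t length, //
                                        size_t *head, size_t *body, size_t *tail) {
    assert(length > 32); // the kernels above only take this path for more than one block
    *head = (32 - (address % 32)) % 32; // 31 or less, brings us to a 32-byte boundary
    *tail = (address + length) % 32;    // 31 or less, bytes past the last aligned block
    *body = length - *head - *tail;     // whole aligned blocks in the middle
    assert((address + *head) % 32 == 0);
    assert(*body % 32 == 0);
}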
- int is_huge = length > 1ull * 1024ull * 1024ull;
- if (length <= 32) { sz_copy_serial(target, source, length); }
- // When dealing with larger arrays, the optimization is not as simple as with the `sz_fill_avx2` function,
- // as both buffers may be unaligned. If we are lucky and the requested operation is some huge page transfer,
- // we can use aligned loads and stores, and the performance will be great.
- else if ((sz_size_t)target % 32 == 0 && (sz_size_t)source % 32 == 0 && !is_huge) {
- for (; length >= 32; target += 32, source += 32, length -= 32)
- _mm256_store_si256((__m256i *)target, _mm256_load_si256((__m256i const *)source));
- if (length) sz_copy_serial(target, source, length);
- }
- // The trickiest case is when both `source` and `target` are not aligned.
- // In this and simpler cases we can copy enough bytes into `target` to reach its cacheline boundary,
- // and then combine unaligned loads with aligned stores.
- else {
- sz_size_t head_length = (32 - ((sz_size_t)target % 32)) % 32; // 31 or less.
- sz_size_t tail_length = (sz_size_t)(target + length) % 32; // 31 or less.
- sz_size_t body_length = length - head_length - tail_length; // Multiple of 32.
-
- // Fill the head of the buffer. This part is much cleaner with AVX-512.
- if (head_length & 1) *(sz_u8_t *)target = *(sz_u8_t *)source, target++, source++, head_length--;
- if (head_length & 2) *(sz_u16_t *)target = *(sz_u16_t *)source, target += 2, source += 2, head_length -= 2;
- if (head_length & 4) *(sz_u32_t *)target = *(sz_u32_t *)source, target += 4, source += 4, head_length -= 4;
- if (head_length & 8) *(sz_u64_t *)target = *(sz_u64_t *)source, target += 8, source += 8, head_length -= 8;
- if (head_length & 16)
- _mm_store_si128((__m128i *)target, _mm_lddqu_si128((__m128i const *)source)), target += 16, source += 16,
- head_length -= 16;
- sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size.");
-
- // Fill the aligned body of the buffer.
- if (!is_huge) {
- for (; body_length >= 32; target += 32, source += 32, body_length -= 32)
- _mm256_store_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source));
- }
- // When the buffer is huge, we can traverse it in 2 directions.
- else {
- for (; body_length >= 64; target += 32, source += 32, body_length -= 64) {
- _mm256_store_si256((__m256i *)(target), _mm256_lddqu_si256((__m256i const *)(source)));
- _mm256_store_si256((__m256i *)(target + body_length - 32),
- _mm256_lddqu_si256((__m256i const *)(source + body_length - 32)));
- }
- if (body_length) _mm256_store_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source));
- }
-
- // Fill the tail of the buffer. This part is much cleaner with AVX-512.
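// A scalar sketch of the "walk from both ends" idea used for huge copies above: every iteration
// moves one block from the front and one from the back, so the remaining body shrinks by two
// blocks. memcpy stands in for the 32-byte vector load/store pair; names are illustrative only.
#include <stddef.h>
#include <string.h>

static void copy_from_both_ends_example(char *target, char const *source, size_t body_length) {
    while (body_length >= 64) {
        memcpy(target, source, 32);                                       // front block
        memcpy(target + body_length - 32, source + body_length - 32, 32); // back block
        target += 32, source += 32, body_length -= 64;
    }
    if (body_length) memcpy(target, source, 32); // at most one 32-byte block remains
}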
- sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); - if (tail_length & 16) - _mm_store_si128((__m128i *)target, _mm_lddqu_si128((__m128i const *)source)), target += 16, source += 16, - tail_length -= 16; - if (tail_length & 8) *(sz_u64_t *)target = *(sz_u64_t *)source, target += 8, source += 8, tail_length -= 8; - if (tail_length & 4) *(sz_u32_t *)target = *(sz_u32_t *)source, target += 4, source += 4, tail_length -= 4; - if (tail_length & 2) *(sz_u16_t *)target = *(sz_u16_t *)source, target += 2, source += 2, tail_length -= 2; - if (tail_length & 1) *(sz_u8_t *)target = *(sz_u8_t *)source, target++, source++, tail_length--; - } -} - -SZ_PUBLIC void sz_move_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - if (target < source || target >= source + length) { - for (; length >= 32; target += 32, source += 32, length -= 32) - _mm256_storeu_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source)); - while (length--) *(target++) = *(source++); - } - else { - // Jump to the end and walk backwards. - for (target += length, source += length; length >= 32; length -= 32) - _mm256_storeu_si256((__m256i *)(target -= 32), _mm256_lddqu_si256((__m256i const *)(source -= 32))); - while (length--) *(--target) = *(--source); - } -} - -SZ_PUBLIC sz_u64_t sz_checksum_avx2(sz_cptr_t text, sz_size_t length) { - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "loads". - // - // A typical AWS Skylake instance can have 32 KB x 2 blocks of L1 data cache per core, - // 1 MB x 2 blocks of L2 cache per core, and one shared L3 cache buffer. - // For now, let's avoid the cases beyond the L2 size. - int is_huge = length > 1ull * 1024ull * 1024ull; - - // When the buffer is small, there isn't much to innovate. - if (length <= 32) { return sz_checksum_serial(text, length); } - else if (!is_huge) { - sz_u256_vec_t text_vec, sums_vec; - sums_vec.ymm = _mm256_setzero_si256(); - for (; length >= 32; text += 32, length -= 32) { - text_vec.ymm = _mm256_lddqu_si256((__m256i const *)text); - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); - } - // Accumulating 256 bits is harders, as we need to extract the 128-bit sums first. - __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymm); - __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymm, 1); - __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm); - sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm); - sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1); - sz_u64_t result = low + high; - if (length) result += sz_checksum_serial(text, length); - return result; - } - // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use. - // Most notably, we can avoid populating the cache with the entire buffer, and instead traverse it in 2 directions. - else { - sz_size_t head_length = (32 - ((sz_size_t)text % 32)) % 32; // 31 or less. - sz_size_t tail_length = (sz_size_t)(text + length) % 32; // 31 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 32. - sz_u64_t result = 0; - - // Handle the head - while (head_length--) result += *text++; - - sz_u256_vec_t text_vec, sums_vec; - sums_vec.ymm = _mm256_setzero_si256(); - // Fill the aligned body of the buffer. 
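// What the vectorized checksum above relies on: `_mm256_sad_epu8(v, zero)` adds each group of
// eight unsigned bytes into a 64-bit lane, so four such lanes accumulate a 32-byte block with no
// risk of overflow. A hedged scalar sketch of the same bookkeeping, with illustrative names:
#include <stddef.h>
#include <stdint.h>

static uint64_t checksum_bytes_example(uint8_t const *text, size_t length) {
    uint64_t lanes[4] = {0, 0, 0, 0}; // stand-ins for the four 64-bit lanes of the YMM accumulator
    size_t i = 0;
    for (; i + 32 <= length; i += 32)
        for (size_t lane = 0; lane != 4; ++lane)
            for (size_t byte = 0; byte != 8; ++byte) lanes[lane] += text[i + lane * 8 + byte];
    uint64_t result = lanes[0] + lanes[1] + lanes[2] + lanes[3]; // the 256 -> 128 -> 64 bit reduction
    while (i < length) result += text[i++];                      // serial tail
    return result;
}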
- if (!is_huge) { - for (; body_length >= 32; text += 32, body_length -= 32) { - text_vec.ymm = _mm256_stream_load_si256((__m256i const *)text); - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); - } - } - // When the biffer is huge, we can traverse it in 2 directions. - else { - sz_u256_vec_t text_reversed_vec, sums_reversed_vec; - sums_reversed_vec.ymm = _mm256_setzero_si256(); - for (; body_length >= 64; text += 64, body_length -= 64) { - text_vec.ymm = _mm256_stream_load_si256((__m256i *)(text)); - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); - text_reversed_vec.ymm = _mm256_stream_load_si256((__m256i *)(text + body_length - 64)); - sums_reversed_vec.ymm = _mm256_add_epi64( - sums_reversed_vec.ymm, _mm256_sad_epu8(text_reversed_vec.ymm, _mm256_setzero_si256())); - } - if (body_length >= 32) { - text_vec.ymm = _mm256_stream_load_si256((__m256i *)(text)); - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); - } - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, sums_reversed_vec.ymm); - } - - // Handle the tail - while (tail_length--) result += *text++; - - // Accumulating 256 bits is harders, as we need to extract the 128-bit sums first. - __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymm); - __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymm, 1); - __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm); - sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm); - sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1); - result += low + high; - return result; - } -} - -SZ_PUBLIC void sz_look_up_transform_avx2(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { - - // If the input is tiny (especially smaller than the look-up table itself), we may end up paying - // more for organizing the SIMD registers and changing the CPU state, than for the actual computation. - // But if at least 3 cache lines are touched, the AVX-2 implementation should be faster. - if (length <= 128) { - sz_look_up_transform_serial(source, length, lut, target); - return; - } - - // We need to pull the lookup table into 8x YMM registers. - // The biggest issue is reorganizing the data in the lookup table, as AVX2 doesn't have 256-bit shuffle, - // it only has 128-bit "within-lane" shuffle. Still, it's wiser to use full YMM registers, instead of XMM, - // so that we can at least compensate high latency with twice larger window and one more level of lookup. 
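// The serial behaviour the AVX2 kernel here reproduces: every output byte is a plain lookup into a
// 256-byte table. Since the vector shuffle only indexes within 16-byte chunks, the index is split
// into a chunk number (high nibble) and a position within the chunk (low nibble). A hedged scalar
// sketch with illustrative names:
#include <stddef.h>
#include <stdint.h>

static void look_up_transform_scalar_example(uint8_t const *source, size_t length, uint8_t const *lut,
                                             uint8_t *target) {
    for (size_t i = 0; i != length; ++i) {
        uint8_t byte = source[i];
        uint8_t chunk = byte >> 4;            // which of the 16 sixteen-byte slices of the table
        uint8_t within = byte & 0x0F;         // position inside that slice, fed to the byte shuffle
        target[i] = lut[chunk * 16 + within]; // identical to lut[byte]
    }
}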
- sz_u256_vec_t lut_0_to_15_vec, lut_16_to_31_vec, lut_32_to_47_vec, lut_48_to_63_vec, // - lut_64_to_79_vec, lut_80_to_95_vec, lut_96_to_111_vec, lut_112_to_127_vec, // - lut_128_to_143_vec, lut_144_to_159_vec, lut_160_to_175_vec, lut_176_to_191_vec, // - lut_192_to_207_vec, lut_208_to_223_vec, lut_224_to_239_vec, lut_240_to_255_vec; - - lut_0_to_15_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut))); - lut_16_to_31_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 16))); - lut_32_to_47_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 32))); - lut_48_to_63_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 48))); - lut_64_to_79_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 64))); - lut_80_to_95_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 80))); - lut_96_to_111_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 96))); - lut_112_to_127_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 112))); - lut_128_to_143_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 128))); - lut_144_to_159_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 144))); - lut_160_to_175_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 160))); - lut_176_to_191_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 176))); - lut_192_to_207_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 192))); - lut_208_to_223_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 208))); - lut_224_to_239_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 224))); - lut_240_to_255_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 240))); - - // Assuming each lookup is performed within 16 elements of 256, we need to reduce the scope by 16x = 2^4. - sz_u256_vec_t not_first_bit_vec, not_second_bit_vec, not_third_bit_vec, not_fourth_bit_vec; - - /// Top and bottom nibbles of the source are used separately. - sz_u256_vec_t source_vec, source_bot_vec; - sz_u256_vec_t blended_0_to_31_vec, blended_32_to_63_vec, blended_64_to_95_vec, blended_96_to_127_vec, - blended_128_to_159_vec, blended_160_to_191_vec, blended_192_to_223_vec, blended_224_to_255_vec; - - // Handling the head. - while (length >= 32) { - // Load and separate the nibbles of each byte in the source. - source_vec.ymm = _mm256_lddqu_si256((__m256i const *)source); - source_bot_vec.ymm = _mm256_and_si256(source_vec.ymm, _mm256_set1_epi8((char)0x0F)); - - // In the first round, we select using the 4th bit. 
- not_fourth_bit_vec.ymm = _mm256_cmpeq_epi8( // - _mm256_and_si256(_mm256_set1_epi8((char)0x10), source_vec.ymm), _mm256_setzero_si256()); - blended_0_to_31_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_16_to_31_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_0_to_15_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_32_to_63_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_48_to_63_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_32_to_47_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_64_to_95_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_80_to_95_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_64_to_79_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_96_to_127_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_112_to_127_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_96_to_111_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_128_to_159_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_144_to_159_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_128_to_143_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_160_to_191_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_176_to_191_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_160_to_175_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_192_to_223_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_208_to_223_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_192_to_207_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - blended_224_to_255_vec.ymm = _mm256_blendv_epi8( // - _mm256_shuffle_epi8(lut_240_to_255_vec.ymm, source_bot_vec.ymm), // - _mm256_shuffle_epi8(lut_224_to_239_vec.ymm, source_bot_vec.ymm), // - not_fourth_bit_vec.ymm); - - // Perform a tree-like reduction of the 8x "blended" YMM registers, depending on the "source" content. - // The first round selects using the 3rd bit. - not_third_bit_vec.ymm = _mm256_cmpeq_epi8( // - _mm256_and_si256(_mm256_set1_epi8((char)0x20), source_vec.ymm), _mm256_setzero_si256()); - blended_0_to_31_vec.ymm = _mm256_blendv_epi8( // - blended_32_to_63_vec.ymm, // - blended_0_to_31_vec.ymm, // - not_third_bit_vec.ymm); - blended_64_to_95_vec.ymm = _mm256_blendv_epi8( // - blended_96_to_127_vec.ymm, // - blended_64_to_95_vec.ymm, // - not_third_bit_vec.ymm); - blended_128_to_159_vec.ymm = _mm256_blendv_epi8( // - blended_160_to_191_vec.ymm, // - blended_128_to_159_vec.ymm, // - not_third_bit_vec.ymm); - blended_192_to_223_vec.ymm = _mm256_blendv_epi8( // - blended_224_to_255_vec.ymm, // - blended_192_to_223_vec.ymm, // - not_third_bit_vec.ymm); - - // The second round selects using the 2nd bit. - not_second_bit_vec.ymm = _mm256_cmpeq_epi8( // - _mm256_and_si256(_mm256_set1_epi8((char)0x40), source_vec.ymm), _mm256_setzero_si256()); - blended_0_to_31_vec.ymm = _mm256_blendv_epi8( // - blended_64_to_95_vec.ymm, // - blended_0_to_31_vec.ymm, // - not_second_bit_vec.ymm); - blended_128_to_159_vec.ymm = _mm256_blendv_epi8( // - blended_192_to_223_vec.ymm, // - blended_128_to_159_vec.ymm, // - not_second_bit_vec.ymm); - - // The third round selects using the 1st bit. 
- not_first_bit_vec.ymm = _mm256_cmpeq_epi8( // - _mm256_and_si256(_mm256_set1_epi8((char)0x80), source_vec.ymm), _mm256_setzero_si256()); - blended_0_to_31_vec.ymm = _mm256_blendv_epi8( // - blended_128_to_159_vec.ymm, // - blended_0_to_31_vec.ymm, // - not_first_bit_vec.ymm); - - // And dump the result into the target. - _mm256_storeu_si256((__m256i *)target, blended_0_to_31_vec.ymm); - source += 32, target += 32, length -= 32; - } - - // Handle the tail. - if (length) sz_look_up_transform_serial(source, length, lut, target); -} - -SZ_PUBLIC sz_cptr_t sz_find_byte_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - int mask; - sz_u256_vec_t h_vec, n_vec; - n_vec.ymm = _mm256_set1_epi8(n[0]); - - while (h_length >= 32) { - h_vec.ymm = _mm256_lddqu_si256((__m256i const *)h); - mask = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_vec.ymm, n_vec.ymm)); - if (mask) return h + sz_u32_ctz(mask); - h += 32, h_length -= 32; - } - - return sz_find_byte_serial(h, h_length, n); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - int mask; - sz_u256_vec_t h_vec, n_vec; - n_vec.ymm = _mm256_set1_epi8(n[0]); - - while (h_length >= 32) { - h_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + h_length - 32)); - mask = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_vec.ymm, n_vec.ymm)); - if (mask) return h + h_length - 1 - sz_u32_clz(mask); - h_length -= 32; - } - - return sz_rfind_byte_serial(h, h_length, n); -} - -SZ_PUBLIC sz_cptr_t sz_find_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_find_byte_avx2(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into YMM registers. - int matches; - sz_u256_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; - n_first_vec.ymm = _mm256_set1_epi8(n[offset_first]); - n_mid_vec.ymm = _mm256_set1_epi8(n[offset_mid]); - n_last_vec.ymm = _mm256_set1_epi8(n[offset_last]); - - // Scan through the string. - for (; h_length >= n_length + 32; h += 32, h_length -= 32) { - h_first_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_first)); - h_mid_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_mid)); - h_last_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_last)); - matches = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_first_vec.ymm, n_first_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_mid_vec.ymm, n_mid_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_last_vec.ymm, n_last_vec.ymm)); - while (matches) { - int potential_offset = sz_u32_ctz(matches); - if (sz_equal(h + potential_offset, n, n_length)) return h + potential_offset; - matches &= matches - 1; - } - } - - return sz_find_serial(h, h_length, n, n_length); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_rfind_byte_avx2(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. 
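// The match-candidate loop above leans on two bit tricks: count-trailing-zeros finds the lowest
// candidate offset, and `matches &= matches - 1` clears that bit so the next candidate can be
// examined. A hedged scalar sketch of filtering with three probe bytes and verifying with memcmp:
#include <stddef.h>
#include <string.h>

static char const *find_with_three_probes_example(char const *h, size_t h_length, char const *n, size_t n_length,
                                                  size_t first, size_t mid, size_t last) {
    for (size_t start = 0; start + n_length <= h_length; ++start) {
        int candidate = h[start + first] == n[first] && //
                        h[start + mid] == n[mid] &&     //
                        h[start + last] == n[last];
        // The vector kernel evaluates 32 such `candidate` bits at once and walks them with
        // `ctz` + `matches &= matches - 1`; here each candidate is verified immediately.
        if (candidate && memcmp(h + start, n, n_length) == 0) return h + start;
    }
    return NULL;
}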
- sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into YMM registers. - int matches; - sz_u256_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; - n_first_vec.ymm = _mm256_set1_epi8(n[offset_first]); - n_mid_vec.ymm = _mm256_set1_epi8(n[offset_mid]); - n_last_vec.ymm = _mm256_set1_epi8(n[offset_last]); - - // Scan through the string. - sz_cptr_t h_reversed; - for (; h_length >= n_length + 32; h_length -= 32) { - h_reversed = h + h_length - n_length - 32 + 1; - h_first_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_first)); - h_mid_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_mid)); - h_last_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_last)); - matches = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_first_vec.ymm, n_first_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_mid_vec.ymm, n_mid_vec.ymm)) & - _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_last_vec.ymm, n_last_vec.ymm)); - while (matches) { - int potential_offset = sz_u32_clz(matches); - if (sz_equal(h + h_length - n_length - potential_offset, n, n_length)) - return h + h_length - n_length - potential_offset; - matches &= ~(1 << (31 - potential_offset)); - } - } - - return sz_rfind_serial(h, h_length, n, n_length); -} - -SZ_PUBLIC sz_cptr_t sz_find_charset_avx2(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { - - // Let's unzip even and odd elements and replicate them into both lanes of the YMM register. - // That way when we invoke `_mm256_shuffle_epi8` we can use the same mask for both lanes. - sz_u256_vec_t filter_even_vec, filter_odd_vec; - for (sz_size_t i = 0; i != 16; ++i) - filter_even_vec.u8s[i] = filter->_u8s[i * 2], filter_odd_vec.u8s[i] = filter->_u8s[i * 2 + 1]; - filter_even_vec.xmms[1] = filter_even_vec.xmms[0]; - filter_odd_vec.xmms[1] = filter_odd_vec.xmms[0]; - - sz_u256_vec_t text_vec; - sz_u256_vec_t matches_vec; - sz_u256_vec_t lower_nibbles_vec, higher_nibbles_vec; - sz_u256_vec_t bitset_even_vec, bitset_odd_vec; - sz_u256_vec_t bitmask_vec, bitmask_lookup_vec; - bitmask_lookup_vec.ymm = _mm256_set_epi8(-128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // - -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1); - - while (length >= 32) { - // The following algorithm is a transposed equivalent of the "SIMDized check which bytes are in a set" - // solutions by Wojciech Muła. We populate the bitmask differently and target newer CPUs, so - // StrinZilla uses a somewhat different approach. - // http://0x80.pl/articles/simd-byte-lookup.html#alternative-implementation-new - // - // sz_u8_t input = *(sz_u8_t const *)text; - // sz_u8_t lo_nibble = input & 0x0f; - // sz_u8_t hi_nibble = input >> 4; - // sz_u8_t bitset_even = filter_even_vec.u8s[hi_nibble]; - // sz_u8_t bitset_odd = filter_odd_vec.u8s[hi_nibble]; - // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); - // sz_u8_t bitset = lo_nibble < 8 ? bitset_even : bitset_odd; - // if ((bitset & bitmask) != 0) return text; - // else { length--, text++; } - // - // The nice part about this, loading the strided data is vey easy with Arm NEON, - // while with x86 CPUs after AVX, shuffles within 256 bits shouldn't be an issue either. 
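// Putting the two halves of the commented reference together: for a byte `b`, the relevant filter
// byte is `_u8s[(b >> 4) * 2 + ((b >> 3) & 1)]`, which simplifies to `_u8s[b >> 3]`, and the bit
// inside it is `b & 7`. A hedged scalar membership check, assuming exactly the layout that the
// reference comments here imply:
#include <stdint.h>

static int charset_contains_example(uint8_t const bitset[32], uint8_t byte) {
    uint8_t relevant_byte = bitset[byte >> 3]; // the same byte the even/odd shuffles end up selecting
    uint8_t relevant_bit = (uint8_t)(1u << (byte & 7u));
    return (relevant_byte & relevant_bit) != 0;
}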
- text_vec.ymm = _mm256_lddqu_si256((__m256i const *)text); - lower_nibbles_vec.ymm = _mm256_and_si256(text_vec.ymm, _mm256_set1_epi8(0x0f)); - bitmask_vec.ymm = _mm256_shuffle_epi8(bitmask_lookup_vec.ymm, lower_nibbles_vec.ymm); - // - // At this point we can validate the `bitmask_vec` contents like this: - // - // for (sz_size_t i = 0; i != 32; ++i) { - // sz_u8_t input = *(sz_u8_t const *)(text + i); - // sz_u8_t lo_nibble = input & 0x0f; - // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); - // sz_assert(bitmask_vec.u8s[i] == bitmask); - // } - // - // Shift right every byte by 4 bits. - // There is no `_mm256_srli_epi8` intrinsic, so we have to use `_mm256_srli_epi16` - // and combine it with a mask to clear the higher bits. - higher_nibbles_vec.ymm = _mm256_and_si256(_mm256_srli_epi16(text_vec.ymm, 4), _mm256_set1_epi8(0x0f)); - bitset_even_vec.ymm = _mm256_shuffle_epi8(filter_even_vec.ymm, higher_nibbles_vec.ymm); - bitset_odd_vec.ymm = _mm256_shuffle_epi8(filter_odd_vec.ymm, higher_nibbles_vec.ymm); - // - // At this point we can validate the `bitset_even_vec` and `bitset_odd_vec` contents like this: - // - // for (sz_size_t i = 0; i != 32; ++i) { - // sz_u8_t input = *(sz_u8_t const *)(text + i); - // sz_u8_t const *bitset_ptr = &filter->_u8s[0]; - // sz_u8_t hi_nibble = input >> 4; - // sz_u8_t bitset_even = bitset_ptr[hi_nibble * 2]; - // sz_u8_t bitset_odd = bitset_ptr[hi_nibble * 2 + 1]; - // sz_assert(bitset_even_vec.u8s[i] == bitset_even); - // sz_assert(bitset_odd_vec.u8s[i] == bitset_odd); - // } - // - __m256i take_first = _mm256_cmpgt_epi8(_mm256_set1_epi8(8), lower_nibbles_vec.ymm); - bitset_even_vec.ymm = _mm256_blendv_epi8(bitset_odd_vec.ymm, bitset_even_vec.ymm, take_first); - - // It would have been great to have an instruction that tests the bits and then broadcasts - // the matching bit into all bits in that byte. But we don't have that, so we have to - // `and`, `cmpeq`, `movemask`, and then invert at the end... - matches_vec.ymm = _mm256_and_si256(bitset_even_vec.ymm, bitmask_vec.ymm); - matches_vec.ymm = _mm256_cmpeq_epi8(matches_vec.ymm, _mm256_setzero_si256()); - int matches_mask = ~_mm256_movemask_epi8(matches_vec.ymm); - if (matches_mask) { - int offset = sz_u32_ctz(matches_mask); - return text + offset; - } - else { text += 32, length -= 32; } - } - - return sz_find_charset_serial(text, length, filter); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx2(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { - return sz_rfind_charset_serial(text, length, filter); -} - -/** - * @brief There is no AVX2 instruction for fast multiplication of 64-bit integers. - * This implementation is coming from Agner Fog's Vector Class Library. 
- */ -SZ_INTERNAL __m256i _mm256_mul_epu64(__m256i a, __m256i b) { - __m256i bswap = _mm256_shuffle_epi32(b, 0xB1); - __m256i prodlh = _mm256_mullo_epi32(a, bswap); - __m256i zero = _mm256_setzero_si256(); - __m256i prodlh2 = _mm256_hadd_epi32(prodlh, zero); - __m256i prodlh3 = _mm256_shuffle_epi32(prodlh2, 0x73); - __m256i prodll = _mm256_mul_epu32(a, b); - __m256i prod = _mm256_add_epi64(prodll, prodlh3); - return prod; -} - -SZ_PUBLIC void sz_hashes_avx2(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle) { - - if (length < window_length || !window_length) return; - if (length < 4 * window_length) { - sz_hashes_serial(start, length, window_length, step, callback, callback_handle); - return; - } - - // Using AVX2, we can perform 4 long integer multiplications and additions within one register. - // So let's slice the entire string into 4 overlapping windows, to slide over them in parallel. - sz_size_t const max_hashes = length - window_length + 1; - sz_size_t const min_hashes_per_thread = max_hashes / 4; // At most one sequence can overlap between 2 threads. - sz_u8_t const *text_first = (sz_u8_t const *)start; - sz_u8_t const *text_second = text_first + min_hashes_per_thread; - sz_u8_t const *text_third = text_first + min_hashes_per_thread * 2; - sz_u8_t const *text_fourth = text_first + min_hashes_per_thread * 3; - sz_u8_t const *text_end = text_first + length; - - // Prepare the `prime ^ window_length` values, that we are going to use for modulo arithmetic. - sz_u64_t prime_power_low = 1, prime_power_high = 1; - for (sz_size_t i = 0; i + 1 < window_length; ++i) - prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME, - prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME; - - // Broadcast the constants into the registers. - sz_u256_vec_t prime_vec, golden_ratio_vec; - sz_u256_vec_t base_low_vec, base_high_vec, prime_power_low_vec, prime_power_high_vec, shift_high_vec; - base_low_vec.ymm = _mm256_set1_epi64x(31ull); - base_high_vec.ymm = _mm256_set1_epi64x(257ull); - shift_high_vec.ymm = _mm256_set1_epi64x(77ull); - prime_vec.ymm = _mm256_set1_epi64x(SZ_U64_MAX_PRIME); - golden_ratio_vec.ymm = _mm256_set1_epi64x(11400714819323198485ull); - prime_power_low_vec.ymm = _mm256_set1_epi64x(prime_power_low); - prime_power_high_vec.ymm = _mm256_set1_epi64x(prime_power_high); - - // Compute the initial hash values for every one of the four windows. - sz_u256_vec_t hash_low_vec, hash_high_vec, hash_mix_vec, chars_low_vec, chars_high_vec; - hash_low_vec.ymm = _mm256_setzero_si256(); - hash_high_vec.ymm = _mm256_setzero_si256(); - for (sz_u8_t const *prefix_end = text_first + window_length; text_first < prefix_end; - ++text_first, ++text_second, ++text_third, ++text_fourth) { - - // 1. Multiply the hashes by the base. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, base_low_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, base_high_vec.ymm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. - chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); - - // 3. Add the incoming characters. 
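// A scalar sketch of the per-window recurrence that the numbered steps above and the sliding loop
// below implement for each of the four lanes. The "modulo" is the same trick as in the vector
// code: the prime is assumed to sit close enough to 2^64 that a single conditional subtraction
// keeps the value in range. Illustrative names and constants, not the library's API.
#include <stddef.h>
#include <stdint.h>

enum { EXAMPLE_BASE = 31 };
#define EXAMPLE_MAX_PRIME 18446744073709551557ull /* plays the role of SZ_U64_MAX_PRIME, 2^64 - 59 */

static uint64_t push_char_example(uint64_t hash, uint8_t incoming) {
    hash = hash * EXAMPLE_BASE + incoming;               // steps 1 and 3: scale by the base, add the new byte
    if (hash > EXAMPLE_MAX_PRIME) hash -= EXAMPLE_MAX_PRIME; // step 4: cheap conditional "modulo"
    return hash;
}

static uint64_t slide_window_example(uint64_t hash, uint8_t outgoing, uint8_t incoming, uint64_t base_power) {
    hash -= outgoing * base_power;                       // step 0: drop the oldest byte, scaled by BASE^(window-1)
    hash = hash * EXAMPLE_BASE + incoming;               // steps 1 and 3 again
    if (hash > EXAMPLE_MAX_PRIME) hash -= EXAMPLE_MAX_PRIME; // step 4
    return hash;
}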
- hash_low_vec.ymm = _mm256_add_epi64(hash_low_vec.ymm, chars_low_vec.ymm); - hash_high_vec.ymm = _mm256_add_epi64(hash_high_vec.ymm, chars_high_vec.ymm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_low_vec.ymm = _mm256_blendv_epi8(hash_low_vec.ymm, _mm256_sub_epi64(hash_low_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_low_vec.ymm, prime_vec.ymm)); - hash_high_vec.ymm = _mm256_blendv_epi8(hash_high_vec.ymm, _mm256_sub_epi64(hash_high_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_high_vec.ymm, prime_vec.ymm)); - } - - // 5. Compute the hash mix, that will be used to index into the fingerprint. - // This includes a serial step at the end. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, golden_ratio_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, golden_ratio_vec.ymm); - hash_mix_vec.ymm = _mm256_xor_si256(hash_low_vec.ymm, hash_high_vec.ymm); - callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); - callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); - callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); - callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); - - // Now repeat that operation for the remaining characters, discarding older characters. - sz_size_t cycle = 1; - sz_size_t const step_mask = step - 1; - for (; text_fourth != text_end; ++text_first, ++text_second, ++text_third, ++text_fourth, ++cycle) { - // 0. Load again the four characters we are dropping, shift them, and subtract. - chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[-window_length], text_third[-window_length], - text_second[-window_length], text_first[-window_length]); - chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); - hash_low_vec.ymm = - _mm256_sub_epi64(hash_low_vec.ymm, _mm256_mul_epu64(chars_low_vec.ymm, prime_power_low_vec.ymm)); - hash_high_vec.ymm = - _mm256_sub_epi64(hash_high_vec.ymm, _mm256_mul_epu64(chars_high_vec.ymm, prime_power_high_vec.ymm)); - - // 1. Multiply the hashes by the base. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, base_low_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, base_high_vec.ymm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. - chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); - - // 3. Add the incoming characters. - hash_low_vec.ymm = _mm256_add_epi64(hash_low_vec.ymm, chars_low_vec.ymm); - hash_high_vec.ymm = _mm256_add_epi64(hash_high_vec.ymm, chars_high_vec.ymm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_low_vec.ymm = _mm256_blendv_epi8(hash_low_vec.ymm, _mm256_sub_epi64(hash_low_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_low_vec.ymm, prime_vec.ymm)); - hash_high_vec.ymm = _mm256_blendv_epi8(hash_high_vec.ymm, _mm256_sub_epi64(hash_high_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_high_vec.ymm, prime_vec.ymm)); - - // 5. 
Compute the hash mix, that will be used to index into the fingerprint. - // This includes a serial step at the end. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, golden_ratio_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, golden_ratio_vec.ymm); - hash_mix_vec.ymm = _mm256_xor_si256(hash_low_vec.ymm, hash_high_vec.ymm); - if ((cycle & step_mask) == 0) { - callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); - callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); - callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); - callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); - } - } -} - -#pragma clang attribute pop -#pragma GCC pop_options -#endif -#pragma endregion - -/* - * @brief AVX-512 implementation of the string search algorithms. - * - * Different subsets of AVX-512 were introduced in different years: - * - 2017 SkyLake: F, CD, ER, PF, VL, DQ, BW - * - 2018 CannonLake: IFMA, VBMI - * - 2019 IceLake: VPOPCNTDQ, VNNI, VBMI2, BITALG, GFNI, VPCLMULQDQ, VAES - * - 2020 TigerLake: VP2INTERSECT - */ -#pragma region AVX512 Implementation - -#if SZ_USE_ICE -#pragma GCC push_options -#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "bmi", "bmi2") -#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,bmi,bmi2"))), apply_to = function) -#include - -/** - * @brief Helper structure to simplify work with 512-bit registers. - */ -typedef union sz_u512_vec_t { - __m512i zmm; - __m256i ymms[2]; - __m128i xmms[4]; - sz_u64_t u64s[8]; - sz_u32_t u32s[16]; - sz_u16_t u16s[32]; - sz_u8_t u8s[64]; - sz_i64_t i64s[8]; - sz_i32_t i32s[16]; -} sz_u512_vec_t; - -SZ_INTERNAL __mmask64 _sz_u64_clamp_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 64: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 64: - return _bzhi_u64(0xFFFFFFFFFFFFFFFF, n < 64 ? (sz_u32_t)n : 64); -} - -SZ_INTERNAL __mmask32 _sz_u32_clamp_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 32: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 32: - return _bzhi_u32(0xFFFFFFFF, n < 32 ? (sz_u32_t)n : 32); -} - -SZ_INTERNAL __mmask16 _sz_u16_clamp_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 16: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 16: - return _bzhi_u32(0xFFFFFFFF, n < 16 ? 
(sz_u32_t)n : 16); -} - -SZ_INTERNAL __mmask16 _sz_u16_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 16: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 16: - return (__mmask16)_bzhi_u32(0xFFFFFFFF, (sz_u32_t)n); -} - -SZ_INTERNAL __mmask32 _sz_u32_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 32: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 32: - return _bzhi_u32(0xFFFFFFFF, (sz_u32_t)n); -} - -SZ_INTERNAL __mmask64 _sz_u64_mask_until(sz_size_t n) { - // The simplest approach to compute this if we know that `n` is blow or equal 64: - // return (1ull << n) - 1; - // A slightly more complex approach, if we don't know that `n` is under 64: - return _bzhi_u64(0xFFFFFFFFFFFFFFFF, (sz_u32_t)n); -} - -SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { - sz_u512_vec_t a_vec, b_vec; - - // Pointer arithmetic is cheap, fetching memory is not! - // So we can use the masked loads to fetch at most one cache-line for each string, - // compare the prefixes, and only then move forward. - sz_size_t a_head_length = 64 - ((sz_size_t)a % 64); // 63 or less. - sz_size_t b_head_length = 64 - ((sz_size_t)b % 64); // 63 or less. - a_head_length = a_head_length < a_length ? a_head_length : a_length; - b_head_length = b_head_length < b_length ? b_head_length : b_length; - sz_size_t head_length = a_head_length < b_head_length ? a_head_length : b_head_length; - __mmask64 head_mask = _sz_u64_mask_until(head_length); - a_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, a); - b_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, b); - __mmask64 mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); - if (mask_not_equal != 0) { - sz_u64_t first_diff = _tzcnt_u64(mask_not_equal); - char a_char = a_vec.u8s[first_diff]; - char b_char = b_vec.u8s[first_diff]; - return _sz_order_scalars(a_char, b_char); - } - else if (head_length == a_length && head_length == b_length) { return sz_equal_k; } - else { a += head_length, b += head_length, a_length -= head_length, b_length -= head_length; } - - // The rare case, when both string are very long. - __mmask64 a_mask, b_mask; - while ((a_length >= 64) & (b_length >= 64)) { - a_vec.zmm = _mm512_loadu_si512(a); - b_vec.zmm = _mm512_loadu_si512(b); - mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); - if (mask_not_equal != 0) { - sz_u64_t first_diff = _tzcnt_u64(mask_not_equal); - char a_char = a_vec.u8s[first_diff]; - char b_char = b_vec.u8s[first_diff]; - return _sz_order_scalars(a_char, b_char); - } - a += 64, b += 64, a_length -= 64, b_length -= 64; - } - - // In most common scenarios at least one of the strings is under 64 bytes. - if (a_length | b_length) { - a_mask = _sz_u64_clamp_mask_until(a_length); - b_mask = _sz_u64_clamp_mask_until(b_length); - a_vec.zmm = _mm512_maskz_loadu_epi8(a_mask, a); - b_vec.zmm = _mm512_maskz_loadu_epi8(b_mask, b); - // The AVX-512 `_mm512_mask_cmpneq_epi8_mask` intrinsics are generally handy in such environments. - // They, however, have latency 3 on most modern CPUs. Using AVX2: `_mm256_cmpeq_epi8` would have - // been cheaper, if we didn't have to apply `_mm256_movemask_epi8` afterwards. 
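// The mask helpers above all reduce to "a word with the lowest n bits set". Feeding such a mask
// into an AVX-512 zero-masking load reads only the first n bytes architecturally (masked-off
// elements cannot fault) and zeroes the rest of the register, which is how the kernels in this
// region handle short strings and unaligned tails without a scalar loop. A hedged scalar model:
#include <assert.h>
#include <stdint.h>

static uint64_t mask_until_model_example(uint64_t n) {
    assert(n <= 64);
    return n == 64 ? ~0ull : ((1ull << n) - 1); // _bzhi_u64 produces the same result without the n == 64 branch
}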
- mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); - if (mask_not_equal != 0) { - sz_u64_t first_diff = _tzcnt_u64(mask_not_equal); - char a_char = a_vec.u8s[first_diff]; - char b_char = b_vec.u8s[first_diff]; - return _sz_order_scalars(a_char, b_char); - } - // From logic perspective, the hardest cases are "abc\0" and "abc". - // The result must be `sz_greater_k`, as the latter is shorter. - else { return _sz_order_scalars(a_length, b_length); } - } - - return sz_equal_k; -} - -SZ_PUBLIC sz_bool_t sz_equal_skylake(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - __mmask64 mask; - sz_u512_vec_t a_vec, b_vec; - - while (length >= 64) { - a_vec.zmm = _mm512_loadu_si512(a); - b_vec.zmm = _mm512_loadu_si512(b); - mask = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); - if (mask != 0) return sz_false_k; - a += 64, b += 64, length -= 64; - } - - if (length) { - mask = _sz_u64_mask_until(length); - a_vec.zmm = _mm512_maskz_loadu_epi8(mask, a); - b_vec.zmm = _mm512_maskz_loadu_epi8(mask, b); - // Reuse the same `mask` variable to find the bit that doesn't match - mask = _mm512_mask_cmpneq_epi8_mask(mask, a_vec.zmm, b_vec.zmm); - return (sz_bool_t)(mask == 0); - } - - return sz_true_k; -} - -SZ_PUBLIC void sz_fill_avx512(sz_ptr_t target, sz_size_t length, sz_u8_t value) { - __m512i value_vec = _mm512_set1_epi8(value); - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "stores". - // - // for (; length >= 64; target += 64, length -= 64) _mm512_storeu_si512(target, value_vec); - // _mm512_mask_storeu_epi8(target, _sz_u64_mask_until(length), value_vec); - // - // When the buffer is small, there isn't much to innovate. - if (length <= 64) { - __mmask64 mask = _sz_u64_mask_until(length); - _mm512_mask_storeu_epi8(target, mask, value_vec); - } - // When the buffer is over 64 bytes, it's guaranteed to touch at least two cache lines - the head and tail, - // and may include more cache-lines in-between. Knowing this, we can avoid expensive unaligned stores - // by computing 2 masks - for the head and tail, using masked stores for the head and tail, and unmasked - // for the body. - else { - sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 64. - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - _mm512_mask_storeu_epi8(target, head_mask, value_vec); - for (target += head_length; body_length >= 64; target += 64, body_length -= 64) - _mm512_store_si512(target, value_vec); - _mm512_mask_storeu_epi8(target, tail_mask, value_vec); - } -} - -SZ_PUBLIC void sz_copy_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "stores" and "loads". - // - // for (; length >= 64; target += 64, source += 64, length -= 64) - // _mm512_storeu_si512(target, _mm512_loadu_si512(source)); - // __mmask64 mask = _sz_u64_mask_until(length); - // _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source)); - // - // A typical AWS Sapphire Rapids instance can have 48 KB x 2 blocks of L1 data cache per core, - // 2 MB x 2 blocks of L2 cache per core, and one shared 60 MB buffer of L3 cache. 
- // With two strings, we may consider the overall workload huge, if each exceeds 1 MB in length.
- int const is_huge = length >= 1ull * 1024ull * 1024ull;
-
- // When the buffer is small, there isn't much to innovate.
- if (length <= 64) {
- __mmask64 mask = _sz_u64_mask_until(length);
- _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source));
- }
- // When dealing with larger arrays, the optimization is not as simple as with the `sz_fill_avx512` function,
- // as both buffers may be unaligned. If we are lucky and the requested operation is some huge page transfer,
- // we can use aligned loads and stores, and the performance will be great.
- else if ((sz_size_t)target % 64 == 0 && (sz_size_t)source % 64 == 0 && !is_huge) {
- for (; length >= 64; target += 64, source += 64, length -= 64)
- _mm512_store_si512(target, _mm512_load_si512(source));
- // At this point the length is guaranteed to be under 64.
- __mmask64 mask = _sz_u64_mask_until(length);
- // Aligned loads and stores would work too, but it's not defined.
- _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source));
- }
- // The trickiest case is when both `source` and `target` are not aligned.
- // In this and simpler cases we can copy enough bytes into `target` to reach its cacheline boundary,
- // and then combine unaligned loads with aligned stores.
- else if (!is_huge) {
- sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less.
- sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less.
- sz_size_t body_length = length - head_length - tail_length; // Multiple of 64.
- __mmask64 head_mask = _sz_u64_mask_until(head_length);
- __mmask64 tail_mask = _sz_u64_mask_until(tail_length);
- _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source));
- for (target += head_length, source += head_length; body_length >= 64;
- target += 64, source += 64, body_length -= 64)
- _mm512_store_si512(target, _mm512_loadu_si512(source)); // Unaligned load, but aligned store!
- _mm512_mask_storeu_epi8(target, tail_mask, _mm512_maskz_loadu_epi8(tail_mask, source));
- }
- // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use.
- //
- // 1. Moving in both directions to maximize the throughput, when fetching from multiple
- // memory pages. Also helps with cache set-associativity issues, as we won't always
- // be fetching the same entries in the lookup table.
- // 2. Using non-temporal stores to avoid polluting the cache.
- // 3. Prefetching the next cache line, to avoid stalling the CPU. This is generally useless
- // for predictable patterns, so disregard this advice.
- //
- // Bidirectional traversal adds about 10%, accelerating from 11 GB/s to 12 GB/s.
- // Using "streaming stores" boosts us from 12 GB/s to 19 GB/s.
- else {
- sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64;
- sz_size_t tail_length = (sz_size_t)(target + length) % 64;
- sz_size_t body_length = length - head_length - tail_length;
- __mmask64 head_mask = _sz_u64_mask_until(head_length);
- __mmask64 tail_mask = _sz_u64_mask_until(tail_length);
- _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source));
- _mm512_mask_storeu_epi8(target + head_length + body_length, tail_mask,
- _mm512_maskz_loadu_epi8(tail_mask, source));
-
- // Now in the main loop, we can use non-temporal loads and stores,
- // performing the operation in both directions.
- for (target += head_length, source += head_length; // - body_length >= 128; // - target += 64, source += 64, body_length -= 128) { - _mm512_stream_si512((__m512i *)(target), _mm512_loadu_si512(source)); - _mm512_stream_si512((__m512i *)(target + body_length - 64), _mm512_loadu_si512(source + body_length - 64)); - } - if (body_length >= 64) _mm512_stream_si512((__m512i *)target, _mm512_loadu_si512(source)); - } -} - -SZ_PUBLIC void sz_move_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - if (target == source) return; // Don't be silly, don't move the data if it's already there. - - // On very short buffers, that are one cache line in width or less, we don't need any loops. - // We can also avoid any data-dependencies between iterations, assuming we have 32 registers - // to pre-load the data, before writing it back. - if (length <= 64) { - __mmask64 mask = _sz_u64_mask_until(length); - _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source)); - } - else if (length <= 128) { - sz_size_t last_length = length - 64; - __mmask64 mask = _sz_u64_mask_until(last_length); - __m512i source0 = _mm512_loadu_epi8(source); - __m512i source1 = _mm512_maskz_loadu_epi8(mask, source + 64); - _mm512_storeu_epi8(target, source0); - _mm512_mask_storeu_epi8(target + 64, mask, source1); - } - else if (length <= 192) { - sz_size_t last_length = length - 128; - __mmask64 mask = _sz_u64_mask_until(last_length); - __m512i source0 = _mm512_loadu_epi8(source); - __m512i source1 = _mm512_loadu_epi8(source + 64); - __m512i source2 = _mm512_maskz_loadu_epi8(mask, source + 128); - _mm512_storeu_epi8(target, source0); - _mm512_storeu_epi8(target + 64, source1); - _mm512_mask_storeu_epi8(target + 128, mask, source2); - } - else if (length <= 256) { - sz_size_t last_length = length - 192; - __mmask64 mask = _sz_u64_mask_until(last_length); - __m512i source0 = _mm512_loadu_epi8(source); - __m512i source1 = _mm512_loadu_epi8(source + 64); - __m512i source2 = _mm512_loadu_epi8(source + 128); - __m512i source3 = _mm512_maskz_loadu_epi8(mask, source + 192); - _mm512_storeu_epi8(target, source0); - _mm512_storeu_epi8(target + 64, source1); - _mm512_storeu_epi8(target + 128, source2); - _mm512_mask_storeu_epi8(target + 192, mask, source3); - } - - // If the regions don't overlap at all, just use "copy" and save some brain cells thinking about corner cases. - else if (target + length < source || target >= source + length) { sz_copy_avx512(target, source, length); } - - // When the buffer is over 64 bytes, it's guaranteed to touch at least two cache lines - the head and tail, - // and may include more cache-lines in-between. Knowing this, we can avoid expensive unaligned stores - // by computing 2 masks - for the head and tail, using masked stores for the head and tail, and unmasked - // for the body. - else { - sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 64. - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - - // The absolute most common case of using "moves" is shifting the data within a continuous buffer - // when adding a removing some values in it. In such cases, a typical shift is by 1, 2, 4, 8, 16, - // or 32 bytes, rarely larger. For small shifts, under the size of the ZMM register, we can use shuffles. 
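// Two ideas from the move kernel around this point, in scalar form: (1) for small buffers, loading
// everything into temporaries before storing anything makes overlap irrelevant; (2) for larger
// overlapping buffers, the copy direction must run away from the overlap - forward when the target
// sits below the source, backward otherwise. Illustrative sketch, not the library's API.
#include <stddef.h>
#include <string.h>

static void move_small_example(char *target, char const *source, size_t length) {
    char scratch[256];               // stand-in for the pre-loaded ZMM registers; assumes length <= 256
    memcpy(scratch, source, length); // read everything first...
    memcpy(target, scratch, length); // ...then write, so overlap cannot clobber unread bytes
}

static void move_any_example(char *target, char const *source, size_t length) {
    if (target < source) // shifting data left: traverse forward
        for (size_t i = 0; i != length; ++i) target[i] = source[i];
    else // shifting data right: traverse backward
        for (size_t i = length; i != 0; --i) target[i - 1] = source[i - 1];
}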
- // - // Remember: - // - if we are shifting data left, that we are traversing to the right. - // - if we are shifting data right, that we are traversing to the left. - int const left_to_right_traversal = source > target; - - // Now we guarantee, that the relative shift within registers is from 1 to 63 bytes and the output is aligned. - // Hopefully, we need to shift more than two ZMM registers, so we could consider `valignr` instruction. - // Sadly, using `_mm512_alignr_epi8` doesn't make sense, as it operates at a 128-bit granularity. - // - // - `_mm256_alignr_epi8` shifts entire 256-bit register, but we need many of them. - // - `_mm512_alignr_epi32` shifts 512-bit chunks, but only if the `shift` is a multiple of 4 bytes. - // - `_mm512_alignr_epi64` shifts 512-bit chunks by 8 bytes. - // - // All of those have a latency of 1 cycle, and the shift amount must be an immediate value! - // For 1-byte-shift granularity, the `_mm512_permutex2var_epi8` has a latency of 6 and needs VBMI! - // The most efficient and broadly compatible alternative could be to use a combination of align and shuffle. - // A similar approach was outlined in "Byte-wise alignr in AVX512F" by Wojciech Muła. - // http://0x80.pl/notesen/2016-10-16-avx512-byte-alignr.html - // - // That solution, is extremely mouthful, assuming we need compile time constants for the shift amount. - // A cleaner one, with a latency of 3 cycles, is to use `_mm512_permutexvar_epi8` or - // `_mm512_mask_permutexvar_epi8`, which can be seen as combination of a cross-register shuffle and blend, - // and is available with VBMI. That solution is still noticeably slower than AVX2. - // - // The GLibC implementation also uses non-temporal stores for larger buffers, we don't. - // https://codebrowser.dev/glibc/glibc/sysdeps/x86_64/multiarch/memmove-avx512-no-vzeroupper.S.html - if (left_to_right_traversal) { - // Head, body, and tail. - _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source)); - for (target += head_length, source += head_length; body_length >= 64; - target += 64, source += 64, body_length -= 64) - _mm512_store_si512(target, _mm512_loadu_si512(source)); - _mm512_mask_storeu_epi8(target, tail_mask, _mm512_maskz_loadu_epi8(tail_mask, source)); - } - else { - // Tail, body, and head. 
- _mm512_mask_storeu_epi8(target + head_length + body_length, tail_mask, - _mm512_maskz_loadu_epi8(tail_mask, source + head_length + body_length)); - for (; body_length >= 64; body_length -= 64) - _mm512_store_si512(target + head_length + body_length - 64, - _mm512_loadu_si512(source + head_length + body_length - 64)); - _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source)); - } - } -} - -SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - __mmask64 mask; - sz_u512_vec_t h_vec, n_vec; - n_vec.zmm = _mm512_set1_epi8(n[0]); - - while (h_length >= 64) { - h_vec.zmm = _mm512_loadu_si512(h); - mask = _mm512_cmpeq_epi8_mask(h_vec.zmm, n_vec.zmm); - if (mask) return h + sz_u64_ctz(mask); - h += 64, h_length -= 64; - } - - if (h_length) { - mask = _sz_u64_mask_until(h_length); - h_vec.zmm = _mm512_maskz_loadu_epi8(mask, h); - // Reuse the same `mask` variable to find the bit that doesn't match - mask = _mm512_mask_cmpeq_epu8_mask(mask, h_vec.zmm, n_vec.zmm); - if (mask) return h + sz_u64_ctz(mask); - } - - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_find_skylake(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_find_byte_avx512(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into ZMM registers. - __mmask64 matches; - __mmask64 mask; - sz_u512_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; - n_first_vec.zmm = _mm512_set1_epi8(n[offset_first]); - n_mid_vec.zmm = _mm512_set1_epi8(n[offset_mid]); - n_last_vec.zmm = _mm512_set1_epi8(n[offset_last]); - - // Scan through the string. - // We have several optimized versions of the lagorithm for shorter strings, - // but they all mimic the default case for unbounded length needles - if (n_length >= 64) { - for (; h_length >= n_length + 64; h += 64, h_length -= 64) { - h_first_vec.zmm = _mm512_loadu_si512(h + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_ctz(matches); - if (sz_equal_skylake(h + potential_offset, n, n_length)) return h + potential_offset; - matches &= matches - 1; - } - - // TODO: If the last character contains a bad byte, we can reposition the start of the next iteration. - // This will be very helpful for very long needles. - } - } - // If there are only 2 or 3 characters in the needle, we don't even need the nested loop. 
- else if (n_length <= 3) { - for (; h_length >= n_length + 64; h += 64, h_length -= 64) { - h_first_vec.zmm = _mm512_loadu_si512(h + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - if (matches) return h + sz_u64_ctz(matches); - } - } - // If the needle is smaller than the size of the ZMM register, we can use masked comparisons - // to avoid the the inner-most nested loop and compare the entire needle against a haystack - // slice in 3 CPU cycles. - else { - __mmask64 n_mask = _sz_u64_mask_until(n_length); - sz_u512_vec_t n_full_vec, h_full_vec; - n_full_vec.zmm = _mm512_maskz_loadu_epi8(n_mask, n); - for (; h_length >= n_length + 64; h += 64, h_length -= 64) { - h_first_vec.zmm = _mm512_loadu_si512(h + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_ctz(matches); - h_full_vec.zmm = _mm512_maskz_loadu_epi8(n_mask, h + potential_offset); - if (_mm512_mask_cmpneq_epi8_mask(n_mask, h_full_vec.zmm, n_full_vec.zmm) == 0) - return h + potential_offset; - matches &= matches - 1; - } - } - } - - // The "tail" of the function uses masked loads to process the remaining bytes. - { - mask = _sz_u64_mask_until(h_length - n_length + 1); - h_first_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_first); - h_mid_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_mid); - h_last_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_ctz(matches); - if (n_length <= 3 || sz_equal_skylake(h + potential_offset, n, n_length)) return h + potential_offset; - matches &= matches - 1; - } - } - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - __mmask64 mask; - sz_u512_vec_t h_vec, n_vec; - n_vec.zmm = _mm512_set1_epi8(n[0]); - - while (h_length >= 64) { - h_vec.zmm = _mm512_loadu_si512(h + h_length - 64); - mask = _mm512_cmpeq_epi8_mask(h_vec.zmm, n_vec.zmm); - if (mask) return h + h_length - 1 - sz_u64_clz(mask); - h_length -= 64; - } - - if (h_length) { - mask = _sz_u64_mask_until(h_length); - h_vec.zmm = _mm512_maskz_loadu_epi8(mask, h); - // Reuse the same `mask` variable to find the bit that doesn't match - mask = _mm512_mask_cmpeq_epu8_mask(mask, h_vec.zmm, n_vec.zmm); - if (mask) return h + 64 - sz_u64_clz(mask) - 1; - } - - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_skylake(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. 
- if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_rfind_byte_avx512(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Broadcast those characters into ZMM registers. - __mmask64 mask; - __mmask64 matches; - sz_u512_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; - n_first_vec.zmm = _mm512_set1_epi8(n[offset_first]); - n_mid_vec.zmm = _mm512_set1_epi8(n[offset_mid]); - n_last_vec.zmm = _mm512_set1_epi8(n[offset_last]); - - // Scan through the string. - sz_cptr_t h_reversed; - for (; h_length >= n_length + 64; h_length -= 64) { - h_reversed = h + h_length - n_length - 64 + 1; - h_first_vec.zmm = _mm512_loadu_si512(h_reversed + offset_first); - h_mid_vec.zmm = _mm512_loadu_si512(h_reversed + offset_mid); - h_last_vec.zmm = _mm512_loadu_si512(h_reversed + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_clz(matches); - if (n_length <= 3 || sz_equal_skylake(h + h_length - n_length - potential_offset, n, n_length)) - return h + h_length - n_length - potential_offset; - sz_assert((matches & ((sz_u64_t)1 << (63 - potential_offset))) != 0 && - "The bit must be set before we squash it"); - matches &= ~((sz_u64_t)1 << (63 - potential_offset)); - } - } - - // The "tail" of the function uses masked loads to process the remaining bytes. - { - mask = _sz_u64_mask_until(h_length - n_length + 1); - h_first_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_first); - h_mid_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_mid); - h_last_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_last); - matches = _kand_mask64(_kand_mask64( // Intersect the masks - _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), - _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), - _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); - while (matches) { - int potential_offset = sz_u64_clz(matches); - if (n_length <= 3 || sz_equal_skylake(h + 64 - potential_offset - 1, n, n_length)) - return h + 64 - potential_offset - 1; - sz_assert((matches & ((sz_u64_t)1 << (63 - potential_offset))) != 0 && - "The bit must be set before we squash it"); - matches &= ~((sz_u64_t)1 << (63 - potential_offset)); - } - } - - return SZ_NULL_CHAR; -} - -#pragma clang attribute pop -#pragma GCC pop_options - -#pragma GCC push_options -#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vbmi", "bmi", "bmi2") -#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512dq,avx512vbmi,bmi,bmi2"))), \ - apply_to = function) - -/** - * @brief Computes the edit distance between two very short byte-strings using the AVX-512VBMI extensions. - * - * Applies to string lengths up to 63, and evaluates at most (63 * 2 + 1 = 127) diagonals, or just as many loop cycles. - * Supports an early exit, if the distance is bounded. - * Keeps all of the data and Levenshtein matrices skew diagonal in just a couple of registers. - * Benefits from the @b `vpermb` instructions, that can rotate the bytes across the entire ZMM register. 
- */ -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto63_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound) { - - sz_size_t const max_length = 63u; - sz_assert(shorter_length <= longer_length && "The 'shorter' string is longer than the 'longer' one."); - sz_assert(shorter_length < max_length && "The length must fit into 16-bit integer. Otherwise use serial variant."); - - // We are going to store 3 diagonals of the matrix, assuming each would fit into a single ZMM register. - // The length of the longest (main) diagonal would be `shorter_dim = (shorter_length + 1)`. - sz_size_t const shorter_dim = shorter_length + 1; - sz_size_t const longer_dim = longer_length + 1; - - // The next few buffers will be swapped around. - sz_u512_vec_t previous_vec, current_vec, next_vec; - sz_u512_vec_t gaps_vec, substitutions_vec; - - // Load the strings into ZMM registers - just once. - sz_u512_vec_t longer_vec, shorter_vec, shorter_rotated_vec, rotate_left_vec, rotate_right_vec, ones_vec, bound_vec; - longer_vec.zmm = _mm512_maskz_loadu_epi8(_sz_u64_mask_until(longer_length), longer); - rotate_left_vec.zmm = _mm512_set_epi8( // - 0, 63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, // - 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, // - 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, // - 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1); - rotate_right_vec.zmm = _mm512_set_epi8( // - 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, // - 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, // - 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, // - 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 63); - ones_vec.zmm = _mm512_set1_epi8(1); - bound_vec.zmm = _mm512_set1_epi8(bound <= 255 ? (sz_u8_t)bound : 255); - - // To simplify comparisons and traversals, we want to reverse the order of bytes in the shorter string. - for (sz_size_t i = 0; i != shorter_length; ++i) shorter_vec.u8s[63 - i] = shorter[i]; - shorter_rotated_vec.zmm = _mm512_permutexvar_epi8(rotate_right_vec.zmm, shorter_vec.zmm); - - // Let's say we are dealing with 3 and 5 letter words. - // The matrix will have size 4 x 6, parameterized as (shorter_dim x longer_dim). - // It will have: - // - 4 diagonals of increasing length, at positions: 0, 1, 2, 3. - // - 2 diagonals of fixed length, at positions: 4, 5. - // - 3 diagonals of decreasing length, at positions: 6, 7, 8. - sz_size_t const diagonals_count = shorter_dim + longer_dim - 1; - - // Initialize the first two diagonals: - // - // previous_vec.u8s[0] = 0; - // current_vec.u8s[0] = current_vec.u8s[1] = 1; - // - // We can do a similar thing with vector ops: - previous_vec.zmm = _mm512_setzero_si512(); - current_vec.zmm = _mm512_set1_epi8(1); - - // We skip diagonals 0 and 1, as they are trivial. - // We will start with diagonal 2, which has length 3, with the first and last elements being preset, - // so we are effectively computing just one value, as will be marked by a single set bit in - // the `next_diagonal_mask` on the very first iteration. - sz_size_t next_diagonal_index = 2; - __mmask64 next_diagonal_mask = 0; - - // Progress through the upper triangle of the Levenshtein matrix. 
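    // As a scalar sketch, each of the vectorized loops below evaluates the same per-cell recurrence along
    // one skewed diagonal (`previous`, `current`, and `next` are illustrative names; the exact `k`/`k - 1`
    // offsets differ between the three triangles, which is what the byte rotations encode):
    //
    //      next[k] = sz_min_of_two(
    //          sz_min_of_two(current[k], current[k - 1]) + 1,            // insertion or deletion
    //          previous[k - 1] + (shorter_char == longer_char ? 0 : 1)); // substitution
    //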
- for (; next_diagonal_index != shorter_dim; ++next_diagonal_index) { - // After this iteration, the values at offset `0` and `next_diagonal_index` in the `next_vec` - // should be set to `next_diagonal_index`, but it's easier to broadcast the value to the whole vector, - // and later merge with a mask with new values. - next_vec.zmm = _mm512_set1_epi8((sz_u8_t)next_diagonal_index); - - // The mask also adds one set bit. - next_diagonal_mask = _kor_mask64(next_diagonal_mask, 1); - next_diagonal_mask = _kshiftli_mask64(next_diagonal_mask, 1); - - // Check for equality between string slices. - __mmask64 conflict_mask = _mm512_cmpneq_epi8_mask(longer_vec.zmm, shorter_rotated_vec.zmm); - substitutions_vec.zmm = _mm512_mask_add_epi8(previous_vec.zmm, conflict_mask, previous_vec.zmm, ones_vec.zmm); - substitutions_vec.zmm = _mm512_permutexvar_epi8(rotate_right_vec.zmm, substitutions_vec.zmm); - gaps_vec.zmm = _mm512_add_epi8( - // Insertions or deletions - _mm512_min_epu8(_mm512_permutexvar_epi8(rotate_right_vec.zmm, current_vec.zmm), current_vec.zmm), - ones_vec.zmm); - next_vec.zmm = _mm512_mask_min_epu8(next_vec.zmm, next_diagonal_mask, gaps_vec.zmm, substitutions_vec.zmm); - - // Mark the current skewed diagonal as the previous one and the next one as the current one. - previous_vec.zmm = current_vec.zmm; - current_vec.zmm = next_vec.zmm; - - // Shift the shorter string - shorter_rotated_vec.zmm = _mm512_permutexvar_epi8(rotate_right_vec.zmm, shorter_rotated_vec.zmm); - - // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound. - __mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm); - if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) { // - return SZ_SIZE_MAX; - } - } - - // Now let's handle the anti-diagonal band of the matrix, between the top and bottom triangles. - for (; next_diagonal_index != longer_dim; ++next_diagonal_index) { - // After this iteration, the value `shorted_dim - 1` in the `next_vec` - // should be set to `next_diagonal_index`, but it's easier to broadcast the value to the whole vector, - // and later merge with a mask with new values. - next_vec.zmm = _mm512_set1_epi8((sz_u8_t)next_diagonal_index); - - // Make sure we update the first entry. - next_diagonal_mask = _kor_mask64(next_diagonal_mask, 1); - - // Check for equality between string slices. - __mmask64 conflict_mask = _mm512_cmpneq_epi8_mask(longer_vec.zmm, shorter_rotated_vec.zmm); - substitutions_vec.zmm = _mm512_mask_add_epi8(previous_vec.zmm, conflict_mask, previous_vec.zmm, ones_vec.zmm); - gaps_vec.zmm = _mm512_add_epi8( - // Insertions or deletions - _mm512_min_epu8(current_vec.zmm, _mm512_permutexvar_epi8(rotate_left_vec.zmm, current_vec.zmm)), - ones_vec.zmm); - next_vec.zmm = _mm512_mask_min_epu8(next_vec.zmm, next_diagonal_mask, gaps_vec.zmm, substitutions_vec.zmm); - - // Mark the current skewed diagonal as the previous one and the next one as the current one. - previous_vec.zmm = _mm512_permutexvar_epi8(rotate_left_vec.zmm, current_vec.zmm); - current_vec.zmm = next_vec.zmm; - - // Let's shift the longer string now. - longer_vec.zmm = _mm512_permutexvar_epi8(rotate_left_vec.zmm, longer_vec.zmm); - - // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound. 
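        // In scalar form, the early-exit test below is roughly the following sketch, with illustrative
        // names for the populated cells of the current diagonal:
        //
        //      int any_within_bound = 0;
        //      for (sz_size_t k = 0; k != diagonal_length; ++k) any_within_bound |= next[k] <= bound;
        //      if (!any_within_bound) return SZ_SIZE_MAX;
        //
        // `_ktestz_mask64_u8` returns 1 exactly when the bitwise AND of its two mask arguments is zero,
        // i.e. when no populated cell is still within the bound.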
- __mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm); - if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) { // - return SZ_SIZE_MAX; - } - } - - // Now let's handle the bottom right triangle. - for (; next_diagonal_index != diagonals_count; ++next_diagonal_index) { - - // Check for equality between string slices. - __mmask64 conflict_mask = _mm512_cmpneq_epi8_mask(longer_vec.zmm, shorter_rotated_vec.zmm); - substitutions_vec.zmm = _mm512_mask_add_epi8(previous_vec.zmm, conflict_mask, previous_vec.zmm, ones_vec.zmm); - gaps_vec.zmm = _mm512_add_epi8( - // Insertions or deletions - _mm512_min_epu8(current_vec.zmm, _mm512_permutexvar_epi8(rotate_left_vec.zmm, current_vec.zmm)), - ones_vec.zmm); - next_vec.zmm = _mm512_min_epu8(gaps_vec.zmm, substitutions_vec.zmm); - - // Mark the current skewed diagonal as the previous one and the next one as the current one. - previous_vec.zmm = _mm512_permutexvar_epi8(rotate_left_vec.zmm, current_vec.zmm); - current_vec.zmm = next_vec.zmm; - - // Let's shift the longer string now. - longer_vec.zmm = _mm512_permutexvar_epi8(rotate_left_vec.zmm, longer_vec.zmm); - - // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound. - __mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm); - if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) { // - return SZ_SIZE_MAX; - } - // In every following iterations we take use a shorter prefix of each register, - // but we don't need to update the `next_diagonal_mask` anymore... except for the early exit. - next_diagonal_mask = _kshiftri_mask64(next_diagonal_mask, 1); - } - return current_vec.u8s[0]; -} - -/** - * @brief Computes the edit distance between two somewhat short bytes-strings using the AVX-512VBMI extensions. - * - * Applies to string lengths up to 127, and evaluates at most (127 * 2 + 1 = 255) diagonals. - * Supports an early exit, if the distance is bounded. - * Uses a lot more CPU registers space, than the `upto63` variant. - * Benefits from the @b `vpermi2b` instructions, that can rotate the bytes in 2 registers at once. - * - * This may be one of the most freuqently called kernels for: - * - source code analysis, assuming most lines are either under 80 or under 120 characters long. - * - DNA sequence alignment, as most short reads are 50-300 characters long. - */ -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto127_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound) { - sz_unused(shorter && shorter_length && longer && longer_length && bound); - return 0; -} - -/** - * @brief Computes the edit distance between two longer bytes-strings using the AVX-512VBMI extensions. - * - * Applies to string lengths up to 255, and evaluates at most (255 * 2 + 1 = 511) diagonals. - * Supports an early exit, if the distance is bounded. - * Uses a lot more CPU registers space, than the `upto63` variant. - * - * Each of 2x string ends up occupying 4 ZMM registers, and each of 3x diagonals uses 4 ZMM registers. - * So 20x of the 32x are persistently occupied, and the rest are used for math temporarily. - * This is the largest space-efficient variant, as strings beyond 255 characters may require - * 16-bit accumulators, which would be a significant bottleneck. 
- */
-SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto_avx512( //
-    sz_cptr_t shorter, sz_size_t shorter_length,                      //
-    sz_cptr_t longer, sz_size_t longer_length,                        //
-    sz_size_t bound) {
-    sz_unused(shorter && shorter_length && longer && longer_length && bound);
-    return 0;
-}
-
-/**
- * @brief Computes the edit distance between two longer byte-strings using the AVX-512VBMI extensions,
- *        assuming the upper distance bound cannot exceed 255, but the string length can be arbitrary.
- *
- * Applies to string lengths up to 255, and evaluates at most (255 * 2 + 1 = 511) diagonals.
- * Supports an early exit, if the distance is bounded.
- * Uses a lot more CPU register space than the `upto63` variant.
- *
- * Each of the 2x strings ends up occupying 4 ZMM registers, and each of 3x diagonals uses 4 ZMM registers.
- * So 20x of the 32x are persistently occupied, and the rest are used for math temporarily.
- * This is the largest space-efficient variant, as strings beyond 255 characters may require
- * 16-bit accumulators, which would be a significant bottleneck.
- */
-SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto255bound_avx512( //
-    sz_cptr_t shorter, sz_size_t shorter_length,                              //
-    sz_cptr_t longer, sz_size_t longer_length,                                //
-    sz_size_t bound) {
-    sz_unused(shorter && shorter_length && longer && longer_length && bound);
-    return 0;
-}
-
-/**
- * @brief Computes the edit distance between two mid-length UTF-8 strings using the AVX-512VBMI extensions.
- *
- * Applies to string lengths up to 127, and evaluates at most (127 * 2 + 1 = 255) diagonals.
- * Supports an early exit, if the distance is bounded.
- * Benefits from the @b `valignd` instructions used to rotate UTF-32 unpacked Unicode codepoints.
- *
- * Each string is unpacked into 128 characters * 4 bytes per character / 64 bytes per register = 8 registers.
- *
- */
-SZ_INTERNAL sz_size_t _sz_edit_distance_utf8_skewed_diagonals_upto127_avx512( //
-    sz_cptr_t shorter, sz_size_t shorter_length,                              //
-    sz_cptr_t longer, sz_size_t longer_length,                                //
-    sz_size_t bound) {
-    sz_unused(shorter && shorter_length && longer && longer_length && bound);
-    return 0;
-}
-
-SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto65k_avx512( //
-    sz_cptr_t shorter, sz_size_t shorter_length,                        //
-    sz_cptr_t longer, sz_size_t longer_length,                          //
-    sz_size_t bound, sz_memory_allocator_t *alloc) {
-
-    sz_unused(shorter && longer && bound && alloc);
-
-    // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome.
-    sz_memory_allocator_t global_alloc;
-    if (!alloc) {
-        sz_memory_allocator_init_default(&global_alloc);
-        alloc = &global_alloc;
-    }
-
-    // TODO: Generalize!
-    sz_size_t const max_length = 256u * 256u;
-    sz_assert(shorter_length <= longer_length && "The 'shorter' string is longer than the 'longer' one.");
-    sz_assert(shorter_length < max_length && "The length must fit into 16-bit integer. Otherwise use serial variant.");
-    sz_unused(longer_length && bound && max_length);
-
-#if 0
-    // We are going to store 3 diagonals of the matrix.
-    // The length of the longest (main) diagonal would be `shorter_dim = (shorter_length + 1)`.
-    sz_size_t const shorter_dim = shorter_length + 1;
-    sz_size_t const longer_dim = longer_length + 1;
-    // Unlike the serial version, we also want to avoid reverse-order iteration over the shorter string.
-    // So let's allocate a bit more memory and reverse-export our shorter string into that buffer.
- sz_size_t const buffer_length = sizeof(sz_u16_t) * longer_dim * 3 + shorter_length; - sz_u16_t *const distances = (sz_u16_t *)alloc->allocate(buffer_length, alloc->handle); - if (!distances) return SZ_SIZE_MAX; - - // The next few pointers will be swapped around. - sz_u16_t *previous_distances = distances; - sz_u16_t *current_distances = previous_distances + longer_dim; - sz_u16_t *next_distances = current_distances + longer_dim; - sz_ptr_t const shorter_reversed = (sz_ptr_t)(next_distances + longer_dim); - - // Export the reversed string into the buffer. - for (sz_size_t i = 0; i != shorter_length; ++i) shorter_reversed[i] = shorter[shorter_length - 1 - i]; - - // Initialize the first two diagonals: - previous_distances[0] = 0; - current_distances[0] = current_distances[1] = 1; - - // Using ZMM registers, we can process 32x 16-bit values at once, - // storing 16 bytes of each string in YMM registers. - sz_u512_vec_t insertions_vec, deletions_vec, substitutions_vec, next_vec; - sz_u512_vec_t ones_u16_vec; - ones_u16_vec.zmm = _mm512_set1_epi16(1); - - // This is a mixed-precision implementation, using 8-bit representations for part of the operations. - // Even there, in case `SZ_USE_HASWELL=0`, let's use the `sz_u512_vec_t` type, addressing the first YMM halfs. - sz_u512_vec_t shorter_vec, longer_vec; - sz_u512_vec_t ones_u8_vec; - ones_u8_vec.ymms[0] = _mm256_set1_epi8(1); - - // Let's say we are dealing with 3 and 5 letter words. - // The matrix will have size 4 x 6, parameterized as (shorter_dim x longer_dim). - // It will have: - // - 4 diagonals of increasing length, at positions: 0, 1, 2, 3. - // - 2 diagonals of fixed length, at positions: 4, 5. - // - 3 diagonals of decreasing length, at positions: 6, 7, 8. - sz_size_t const diagonals_count = shorter_dim + longer_dim - 1; - - // Progress through the upper triangle of the Levenshtein matrix. - sz_size_t next_diagonal_index = 2; - for (; next_diagonal_index != shorter_dim; ++next_diagonal_index) { - sz_size_t const next_diagonal_length = next_diagonal_index + 1; - for (sz_size_t offset_within_diagonal = 0; offset_within_diagonal + 2 < next_diagonal_length;) { - sz_u32_t remaining_length = (sz_u32_t)(next_diagonal_length - offset_within_diagonal - 2); - sz_u32_t register_length = remaining_length < 32 ? remaining_length : 32; - sz_u32_t remaining_length_mask = _bzhi_u32(0xFFFFFFFFu, register_length); - longer_vec.ymms[0] = _mm256_maskz_loadu_epi8(remaining_length_mask, longer + offset_within_diagonal); - // Our original code addressed the shorter string `[next_diagonal_index - offset_within_diagonal - 2]` - // for growing `offset_within_diagonal`. If the `shorter` string was reversed, the - // `[next_diagonal_index - offset_within_diagonal - 2]` would be equal to `[shorter_length - 1 - - // next_diagonal_index + offset_within_diagonal + 2]`. Which simplified would be equal to - // `[shorter_length - next_diagonal_index + offset_within_diagonal + 1]`. - shorter_vec.ymms[0] = _mm256_maskz_loadu_epi8( // - remaining_length_mask, - shorter_reversed + shorter_length - next_diagonal_index + offset_within_diagonal + 1); - // For substitutions, perform the equality comparison using AVX2 instead of AVX-512 - // to get the result as a vector, instead of a bitmask. Adding 1 to every scalar we can overflow - // transforming from {0xFF, 0} values to {0, 1} values - exactly what we need. Then - upcast to 16-bit. 
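            // A worked example of that byte trick: `_mm256_cmpeq_epi8` yields 0xFF for equal bytes and
            // 0x00 for different ones, so after adding 1 each lane holds exactly the substitution cost:
            //
            //      equal:     0xFF + 1 == 0x00   // no penalty, the addition wraps around
            //      different: 0x00 + 1 == 0x01   // one substitution
            //
            // which is then widened to 16 bits and added to `previous_distances` below.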
- substitutions_vec.zmm = _mm512_cvtepi8_epi16( // - _mm256_add_epi8(_mm256_cmpeq_epi8(longer_vec.ymms[0], shorter_vec.ymms[0]), ones_u8_vec.ymms[0])); - substitutions_vec.zmm = _mm512_add_epi16( // - substitutions_vec.zmm, - _mm512_maskz_loadu_epi16(remaining_length_mask, previous_distances + offset_within_diagonal)); - // For insertions and deletions, on modern hardware, it's faster to issue two separate loads, - // than rotate the bytes in the ZMM register. - insertions_vec.zmm = - _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + offset_within_diagonal); - deletions_vec.zmm = - _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + offset_within_diagonal + 1); - // First get the minimum of insertions and deletions. - next_vec.zmm = _mm512_add_epi16(_mm512_min_epu16(insertions_vec.zmm, deletions_vec.zmm), ones_u16_vec.zmm); - next_vec.zmm = _mm512_min_epu16(next_vec.zmm, substitutions_vec.zmm); - _mm512_mask_storeu_epi16(next_distances + offset_within_diagonal + 1, remaining_length_mask, next_vec.zmm); - offset_within_diagonal += register_length; - } - // Don't forget to populate the first row and the first column of the Levenshtein matrix. - next_distances[0] = next_distances[next_diagonal_length - 1] = (sz_u16_t)next_diagonal_index; - // Perform a circular rotation (three-way swap) of those buffers, to reuse the memory. - sz_u16_t *temporary = previous_distances; - previous_distances = current_distances; - current_distances = next_distances; - next_distances = temporary; - } - - // By now we've scanned through the upper triangle of the matrix, where each subsequent iteration results in a - // larger diagonal. From now onwards, we will be shrinking. Instead of adding value equal to the skewed diagonal - // index on either side, we will be cropping those values out. - for (; next_diagonal_index != diagonals_count; ++next_diagonal_index) { - sz_size_t const next_diagonal_length = diagonals_count - next_diagonal_index; - for (sz_size_t i = 0; i != next_diagonal_length;) { - sz_u32_t remaining_length = (sz_u32_t)(next_diagonal_length - i); - sz_u32_t register_length = remaining_length < 32 ? remaining_length : 32; - sz_u32_t remaining_length_mask = _bzhi_u32(0xFFFFFFFFu, register_length); - longer_vec.ymms[0] = _mm256_maskz_loadu_epi8(remaining_length_mask, longer + next_diagonal_index - n + i); - // Our original code addressed the shorter string `[shorter_length - 1 - i]` for growing `i`. - // If the `shorter` string was reversed, the `[shorter_length - 1 - i]` would - // be equal to `[shorter_length - 1 - shorter_length + 1 + i]`. - // Which simplified would be equal to just `[i]`. Beautiful! - shorter_vec.ymms[0] = _mm256_maskz_loadu_epi8(remaining_length_mask, shorter_reversed + i); - // For substitutions, perform the equality comparison using AVX2 instead of AVX-512 - // to get the result as a vector, instead of a bitmask. The compare it against the accumulated - // substitution costs. - substitutions_vec.zmm = _mm512_cvtepi8_epi16( // - _mm256_add_epi8(_mm256_cmpeq_epi8(longer_vec.ymms[0], shorter_vec.ymms[0]), ones_u8_vec.ymms[0])); - substitutions_vec.zmm = _mm512_add_epi16( // - substitutions_vec.zmm, _mm512_maskz_loadu_epi16(remaining_length_mask, previous_distances + i)); - // For insertions and deletions, on modern hardware, it's faster to issue two separate loads, - // than rotate the bytes in the ZMM register. 
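            // The two loads below fetch overlapping slices of the same diagonal: the slice at offset `i`
            // supplies one neighbor of every cell and the slice at `i + 1` the other, so no cross-lane
            // rotation of an already-loaded register is required. As a scalar sketch (illustrative names):
            //
            //      insertion_cost[k] = current_distances[i + k];
            //      deletion_cost[k] = current_distances[i + k + 1];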
- insertions_vec.zmm = _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + i); - deletions_vec.zmm = _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + i + 1); - // First get the minimum of insertions and deletions. - next_vec.zmm = _mm512_add_epi16(_mm512_min_epu16(insertions_vec.zmm, deletions_vec.zmm), ones_u16_vec.zmm); - next_vec.zmm = _mm512_min_epu16(next_vec.zmm, substitutions_vec.zmm); - _mm512_mask_storeu_epi16(next_distances + i, remaining_length_mask, next_vec.zmm); - i += register_length; - } - - // Perform a circular rotation (three-way swap) of those buffers, to reuse the memory, this time, with a shift, - // dropping the first element in the current array. - sz_u16_t *temporary = previous_distances; - previous_distances = current_distances + 1; - current_distances = next_distances; - next_distances = temporary; - } - - // Cache scalar before `free` call. - sz_size_t result = current_distances[0]; - alloc->free(distances, buffer_length, alloc->handle); - return result; -#endif - return 0; -} - -SZ_INTERNAL sz_size_t sz_edit_distance_avx512( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - - // Bounded computations may exit early. - int const is_bounded = bound < longer_length; - if (is_bounded) { - // If one of the strings is empty - the edit distance is equal to the length of the other one. - if (longer_length == 0) return sz_min_of_two(shorter_length, bound); - if (shorter_length == 0) return sz_min_of_two(longer_length, bound); - // If the difference in length is beyond the `bound`, there is no need to check at all. - if (longer_length - shorter_length > bound) return bound; - } - - // Make sure the shorter string is actually shorter. - if (shorter_length > longer_length) { - sz_cptr_t temporary = shorter; - shorter = longer; - longer = temporary; - sz_size_t temporary_length = shorter_length; - shorter_length = longer_length; - longer_length = temporary_length; - } - - // Dispatch the right implementation based on the length of the strings. - if (longer_length < 64u) - return _sz_edit_distance_skewed_diagonals_upto63_avx512( // - shorter, shorter_length, longer, longer_length, bound); - // else if (longer_length < 256u * 256u) - // return _sz_edit_distance_skewed_diagonals_upto65k_avx512( // - // shorter, shorter_length, longer, longer_length, bound, alloc); - else - return sz_edit_distance_serial(shorter, shorter_length, longer, longer_length, bound, alloc); -} - -SZ_PUBLIC sz_u64_t sz_checksum_avx512(sz_cptr_t text, sz_size_t length) { - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "loads". - // - // A typical AWS Sapphire Rapids instance can have 48 KB x 2 blocks of L1 data cache per core, - // 2 MB x 2 blocks of L2 cache per core, and one shared 60 MB buffer of L3 cache. - // With two strings, we may consider the overal workload huge, if each exceeds 1 MB in length. - int const is_huge = length >= 1ull * 1024ull * 1024ull; - sz_u512_vec_t text_vec, sums_vec; - - // When the buffer is small, there isn't much to innovate. 
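    // All of the branches below compute the same scalar quantity, sketched here for reference:
    //
    //      sz_u64_t checksum = 0;
    //      for (sz_size_t i = 0; i != length; ++i) checksum += (sz_u8_t)text[i];
    //      return checksum;
    //
    // The vectorized paths lean on `_mm*_sad_epu8(x, zero)`, which sums every group of 8 bytes against
    // zero into one 64-bit lane, so a single instruction folds 16, 32, or 64 bytes into a handful of
    // partial sums that are reduced at the very end.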
- if (length <= 16) { - __mmask16 mask = _sz_u16_mask_until(length); - text_vec.xmms[0] = _mm_maskz_loadu_epi8(mask, text); - sums_vec.xmms[0] = _mm_sad_epu8(text_vec.xmms[0], _mm_setzero_si128()); - sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_vec.xmms[0]); - sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_vec.xmms[0], 1); - return low + high; - } - else if (length <= 32) { - __mmask32 mask = _sz_u32_mask_until(length); - text_vec.ymms[0] = _mm256_maskz_loadu_epi8(mask, text); - sums_vec.ymms[0] = _mm256_sad_epu8(text_vec.ymms[0], _mm256_setzero_si256()); - // Accumulating 256 bits is harders, as we need to extract the 128-bit sums first. - __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymms[0]); - __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymms[0], 1); - __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm); - sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm); - sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1); - return low + high; - } - else if (length <= 64) { - __mmask64 mask = _sz_u64_mask_until(length); - text_vec.zmm = _mm512_maskz_loadu_epi8(mask, text); - sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()); - return _mm512_reduce_add_epi64(sums_vec.zmm); - } - else if (!is_huge) { - sz_size_t head_length = (64 - ((sz_size_t)text % 64)) % 64; // 63 or less. - sz_size_t tail_length = (sz_size_t)(text + length) % 64; // 63 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 64. - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - text_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, text); - sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()); - for (text += head_length; body_length >= 64; text += 64, body_length -= 64) { - text_vec.zmm = _mm512_load_si512((__m512i const *)text); - sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512())); - } - text_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, text); - sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512())); - return _mm512_reduce_add_epi64(sums_vec.zmm); - } - // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use. - // - // 1. Moving in both directions to maximize the throughput, when fetching from multiple - // memory pages. Also helps with cache set-associativity issues, as we won't always - // be fetching the same entries in the lookup table. - // 2. Using non-temporal stores to avoid polluting the cache. - // 3. Prefetching the next cache line, to avoid stalling the CPU. This generally useless - // for predictable patterns, so disregard this advice. - // - // Bidirectional traversal generally adds about 10% to such algorithms. 
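    // The shape of that bidirectional pass, ignoring vectorization, is roughly the following sketch,
    // where `body` stands for the aligned middle of the buffer and `sum_of_64_bytes` is a stand-in for
    // the `_mm512_sad_epu8` reduction used above:
    //
    //      sz_u64_t forward = 0, backward = 0;
    //      for (sz_size_t i = 0, j = body_length; i + 128 <= j; i += 64, j -= 64) {
    //          forward += sum_of_64_bytes(body + i);        // streaming load near the front
    //          backward += sum_of_64_bytes(body + j - 64);  // streaming load near the back
    //      }
    //      // ... plus at most one leftover 64-byte block and the masked head/tail sums.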
-    else {
-        sz_u512_vec_t text_reversed_vec, sums_reversed_vec;
-        sz_size_t head_length = (64 - ((sz_size_t)text % 64)) % 64;
-        sz_size_t tail_length = (sz_size_t)(text + length) % 64;
-        sz_size_t body_length = length - head_length - tail_length;
-        __mmask64 head_mask = _sz_u64_mask_until(head_length);
-        __mmask64 tail_mask = _sz_u64_mask_until(tail_length);
-
-        text_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, text);
-        sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512());
-        text_reversed_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, text + head_length + body_length);
-        sums_reversed_vec.zmm = _mm512_sad_epu8(text_reversed_vec.zmm, _mm512_setzero_si512());
-
-        // Now in the main loop, we can use non-temporal loads and stores,
-        // performing the operation in both directions.
-        for (text += head_length; body_length >= 128; text += 64, body_length -= 128) {
-            text_vec.zmm = _mm512_stream_load_si512((__m512i *)(text));
-            sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()));
-            text_reversed_vec.zmm = _mm512_stream_load_si512((__m512i *)(text + body_length - 64));
-            sums_reversed_vec.zmm =
-                _mm512_add_epi64(sums_reversed_vec.zmm, _mm512_sad_epu8(text_reversed_vec.zmm, _mm512_setzero_si512()));
-        }
-        if (body_length >= 64) {
-            text_vec.zmm = _mm512_stream_load_si512((__m512i *)(text));
-            sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()));
-        }
-
-        return _mm512_reduce_add_epi64(_mm512_add_epi64(sums_vec.zmm, sums_reversed_vec.zmm));
-    }
-}
-
-SZ_PUBLIC void sz_hashes_avx512(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, //
-                                sz_hash_callback_t callback, void *callback_handle) {
-
-    if (length < window_length || !window_length) return;
-    if (length < 4 * window_length) {
-        sz_hashes_serial(start, length, window_length, step, callback, callback_handle);
-        return;
-    }
-
-    // Using AVX-512, we can perform 8 long integer multiplications and additions within one ZMM register.
-    // So let's slice the entire string into 4 overlapping windows, to slide over them in parallel.
-    sz_size_t const max_hashes = length - window_length + 1;
-    sz_size_t const min_hashes_per_thread = max_hashes / 4; // At most one sequence can overlap between 2 threads.
-    sz_u8_t const *text_first = (sz_u8_t const *)start;
-    sz_u8_t const *text_second = text_first + min_hashes_per_thread;
-    sz_u8_t const *text_third = text_first + min_hashes_per_thread * 2;
-    sz_u8_t const *text_fourth = text_first + min_hashes_per_thread * 3;
-    sz_u8_t const *text_end = text_first + length;
-
-    // Broadcast the global constants into the registers.
-    // Both high and low hashes will work with the same prime and golden ratio.
-    sz_u512_vec_t prime_vec, golden_ratio_vec;
-    prime_vec.zmm = _mm512_set1_epi64(SZ_U64_MAX_PRIME);
-    golden_ratio_vec.zmm = _mm512_set1_epi64(11400714819323198485ull);
-
-    // Prepare the `base ^ (window_length - 1) % prime` values, that we are going to use to subtract
-    // the outgoing characters during the modulo arithmetic.
-    sz_u64_t prime_power_low = 1, prime_power_high = 1;
-    for (sz_size_t i = 0; i + 1 < window_length; ++i)
-        prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME,
-        prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME;
-
-    // We will be evaluating 4 offsets at a time with 2 different hash functions.
-    // We can fit all those 8 state variables in each of the following ZMM registers.
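    // Per lane, the rolling hashes below follow the textbook Rabin-Karp recurrences, sketched here with
    // `base` standing for 31 or 257 and `prime` for `SZ_U64_MAX_PRIME`:
    //
    //      // appending the incoming byte:
    //      hash = (hash * base + incoming) % prime;
    //      // dropping the outgoing byte once the window is full:
    //      hash = (hash + prime - (outgoing * prime_power) % prime) % prime;
    //
    // where `prime_power == base ^ (window_length - 1) % prime`. The vectorized code below replaces the
    // full `%` with the single conditional subtraction described in step 4.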
- sz_u512_vec_t base_vec, prime_power_vec, shift_vec; - base_vec.zmm = _mm512_set_epi64(31ull, 31ull, 31ull, 31ull, 257ull, 257ull, 257ull, 257ull); - shift_vec.zmm = _mm512_set_epi64(0ull, 0ull, 0ull, 0ull, 77ull, 77ull, 77ull, 77ull); - prime_power_vec.zmm = _mm512_set_epi64(prime_power_low, prime_power_low, prime_power_low, prime_power_low, - prime_power_high, prime_power_high, prime_power_high, prime_power_high); - - // Compute the initial hash values for every one of the four windows. - sz_u512_vec_t hash_vec, chars_vec; - hash_vec.zmm = _mm512_setzero_si512(); - for (sz_u8_t const *prefix_end = text_first + window_length; text_first < prefix_end; - ++text_first, ++text_second, ++text_third, ++text_fourth) { - - // 1. Multiply the hashes by the base. - hash_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, base_vec.zmm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`... - chars_vec.zmm = _mm512_set_epi64(text_fourth[0], text_third[0], text_second[0], text_first[0], // - text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm); - - // 3. Add the incoming characters. - hash_vec.zmm = _mm512_add_epi64(hash_vec.zmm, chars_vec.zmm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_vec.zmm = _mm512_mask_blend_epi8(_mm512_cmpgt_epi64_mask(hash_vec.zmm, prime_vec.zmm), hash_vec.zmm, - _mm512_sub_epi64(hash_vec.zmm, prime_vec.zmm)); - } - - // 5. Compute the hash mix, that will be used to index into the fingerprint. - // This includes a serial step at the end. - sz_u512_vec_t hash_mix_vec; - hash_mix_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, golden_ratio_vec.zmm); - hash_mix_vec.ymms[0] = _mm256_xor_si256(_mm512_extracti64x4_epi64(hash_mix_vec.zmm, 1), // - _mm512_extracti64x4_epi64(hash_mix_vec.zmm, 0)); - - callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); - callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); - callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); - callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); - - // Now repeat that operation for the remaining characters, discarding older characters. - sz_size_t cycle = 1; - sz_size_t step_mask = step - 1; - for (; text_fourth != text_end; ++text_first, ++text_second, ++text_third, ++text_fourth, ++cycle) { - // 0. Load again the four characters we are dropping, shift them, and subtract. - chars_vec.zmm = _mm512_set_epi64(text_fourth[-window_length], text_third[-window_length], - text_second[-window_length], text_first[-window_length], // - text_fourth[-window_length], text_third[-window_length], - text_second[-window_length], text_first[-window_length]); - chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm); - hash_vec.zmm = _mm512_sub_epi64(hash_vec.zmm, _mm512_mullo_epi64(chars_vec.zmm, prime_power_vec.zmm)); - - // 1. Multiply the hashes by the base. - hash_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, base_vec.zmm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. 
- chars_vec.zmm = _mm512_set_epi64(text_fourth[0], text_third[0], text_second[0], text_first[0], // - text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm); - - // ... and prefetch the next four characters into Level 2 or higher. - _mm_prefetch((sz_cptr_t)text_fourth + 1, _MM_HINT_T1); - _mm_prefetch((sz_cptr_t)text_third + 1, _MM_HINT_T1); - _mm_prefetch((sz_cptr_t)text_second + 1, _MM_HINT_T1); - _mm_prefetch((sz_cptr_t)text_first + 1, _MM_HINT_T1); - - // 3. Add the incoming characters. - hash_vec.zmm = _mm512_add_epi64(hash_vec.zmm, chars_vec.zmm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_vec.zmm = _mm512_mask_blend_epi8(_mm512_cmpgt_epi64_mask(hash_vec.zmm, prime_vec.zmm), hash_vec.zmm, - _mm512_sub_epi64(hash_vec.zmm, prime_vec.zmm)); - - // 5. Compute the hash mix, that will be used to index into the fingerprint. - // This includes a serial step at the end. - hash_mix_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, golden_ratio_vec.zmm); - hash_mix_vec.ymms[0] = _mm256_xor_si256(_mm512_extracti64x4_epi64(hash_mix_vec.zmm, 1), // - _mm512_castsi512_si256(hash_mix_vec.zmm)); - - if ((cycle & step_mask) == 0) { - callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); - callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); - callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); - callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); - } - } -} - -#pragma clang attribute pop -#pragma GCC pop_options - -#pragma GCC push_options -#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512vbmi", "avx512vbmi2", "bmi", "bmi2") -#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512vbmi,avx512vbmi2,bmi,bmi2"))), \ - apply_to = function) - -SZ_PUBLIC void sz_look_up_transform_ice(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { - - // If the input is tiny (especially smaller than the look-up table itself), we may end up paying - // more for organizing the SIMD registers and changing the CPU state, than for the actual computation. - // But if at least 3 cache lines are touched, the AVX-512 implementation should be faster. - if (length <= 128) { - sz_look_up_transform_serial(source, length, lut, target); - return; - } - - // When the buffer is over 64 bytes, it's guaranteed to touch at least two cache lines - the head and tail, - // and may include more cache-lines in-between. Knowing this, we can avoid expensive unaligned stores - // by computing 2 masks - for the head and tail, using masked stores for the head and tail, and unmasked - // for the body. - sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - - // We need to pull the lookup table into 4x ZMM registers. - // We can use `vpermi2b` instruction to perform the look in two ZMM registers with `_mm512_permutex2var_epi8` - // intrinsics, but it has a 6-cycle latency on Sapphire Rapids and requires AVX512-VBMI. 
Assuming we need to - // operate on 4 registers, it might be cleaner to use 2x separate `_mm512_permutexvar_epi8` calls. - // Combining the results with 2x `_mm512_test_epi8_mask` and 3x blends afterwards. - // - // - 4x `_mm512_permutexvar_epi8` maps to "VPERMB (ZMM, ZMM, ZMM)": - // - On Ice Lake: 3 cycles latency, ports: 1*p5 - // - On Genoa: 6 cycles latency, ports: 1*FP12 - // - 3x `_mm512_mask_blend_epi8` maps to "VPBLENDMB_Z (ZMM, K, ZMM, ZMM)": - // - On Ice Lake: 3 cycles latency, ports: 1*p05 - // - On Genoa: 1 cycle latency, ports: 1*FP0123 - // - 2x `_mm512_test_epi8_mask` maps to "VPTESTMB (K, ZMM, ZMM)": - // - On Ice Lake: 3 cycles latency, ports: 1*p5 - // - On Genoa: 4 cycles latency, ports: 1*FP01 - // - sz_u512_vec_t lut_0_to_63_vec, lut_64_to_127_vec, lut_128_to_191_vec, lut_192_to_255_vec; - lut_0_to_63_vec.zmm = _mm512_loadu_si512((lut)); - lut_64_to_127_vec.zmm = _mm512_loadu_si512((lut + 64)); - lut_128_to_191_vec.zmm = _mm512_loadu_si512((lut + 128)); - lut_192_to_255_vec.zmm = _mm512_loadu_si512((lut + 192)); - - sz_u512_vec_t first_bit_vec, second_bit_vec; - first_bit_vec.zmm = _mm512_set1_epi8((char)0x80); - second_bit_vec.zmm = _mm512_set1_epi8((char)0x40); - - __mmask64 first_bit_mask, second_bit_mask; - sz_u512_vec_t source_vec; - // If the top bit is set in each word of `source_vec`, than we use `lookup_128_to_191_vec` or - // `lookup_192_to_255_vec`. If the second bit is set, we use `lookup_64_to_127_vec` or `lookup_192_to_255_vec`. - sz_u512_vec_t lookup_0_to_63_vec, lookup_64_to_127_vec, lookup_128_to_191_vec, lookup_192_to_255_vec; - sz_u512_vec_t blended_0_to_127_vec, blended_128_to_255_vec, blended_0_to_255_vec; - - // Handling the head. - if (head_length) { - source_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, source); - lookup_0_to_63_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_0_to_63_vec.zmm); - lookup_64_to_127_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_64_to_127_vec.zmm); - lookup_128_to_191_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_128_to_191_vec.zmm); - lookup_192_to_255_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_192_to_255_vec.zmm); - first_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, first_bit_vec.zmm); - second_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, second_bit_vec.zmm); - blended_0_to_127_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_0_to_63_vec.zmm, lookup_64_to_127_vec.zmm); - blended_128_to_255_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_128_to_191_vec.zmm, lookup_192_to_255_vec.zmm); - blended_0_to_255_vec.zmm = - _mm512_mask_blend_epi8(first_bit_mask, blended_0_to_127_vec.zmm, blended_128_to_255_vec.zmm); - _mm512_mask_storeu_epi8(target, head_mask, blended_0_to_255_vec.zmm); - source += head_length, target += head_length, length -= head_length; - } - - // Handling the body in 64-byte chunks aligned to cache-line boundaries with respect to `target`. 
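    // The scalar equivalent of each iteration below is a plain table lookup: the top two bits of a byte
    // select one of the four 64-entry quarters, and the low six bits index inside it (a sketch):
    //
    //      sz_u8_t c = (sz_u8_t)source[i];
    //      sz_u8_t quarter = c >> 6; // 0 -> lut_0_to_63, 1 -> lut_64_to_127, 2 -> lut_128_to_191, 3 -> lut_192_to_255
    //      target[i] = lut[(sz_size_t)quarter * 64 + (c & 63)]; // same as lut[c]
    //
    // Since `_mm512_permutexvar_epi8` only honors the low 6 bits of every index byte, all four lookups can
    // be issued unconditionally, and the two `_mm512_test_epi8_mask` bits pick the right one per byte.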
- while (length >= 64) { - source_vec.zmm = _mm512_loadu_si512(source); - lookup_0_to_63_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_0_to_63_vec.zmm); - lookup_64_to_127_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_64_to_127_vec.zmm); - lookup_128_to_191_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_128_to_191_vec.zmm); - lookup_192_to_255_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_192_to_255_vec.zmm); - first_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, first_bit_vec.zmm); - second_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, second_bit_vec.zmm); - blended_0_to_127_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_0_to_63_vec.zmm, lookup_64_to_127_vec.zmm); - blended_128_to_255_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_128_to_191_vec.zmm, lookup_192_to_255_vec.zmm); - blended_0_to_255_vec.zmm = - _mm512_mask_blend_epi8(first_bit_mask, blended_0_to_127_vec.zmm, blended_128_to_255_vec.zmm); - _mm512_store_si512(target, blended_0_to_255_vec.zmm); //! Aligned store, our main weapon! - source += 64, target += 64, length -= 64; - } - - // Handling the tail. - if (tail_length) { - source_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, source); - lookup_0_to_63_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_0_to_63_vec.zmm); - lookup_64_to_127_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_64_to_127_vec.zmm); - lookup_128_to_191_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_128_to_191_vec.zmm); - lookup_192_to_255_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_192_to_255_vec.zmm); - first_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, first_bit_vec.zmm); - second_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, second_bit_vec.zmm); - blended_0_to_127_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_0_to_63_vec.zmm, lookup_64_to_127_vec.zmm); - blended_128_to_255_vec.zmm = - _mm512_mask_blend_epi8(second_bit_mask, lookup_128_to_191_vec.zmm, lookup_192_to_255_vec.zmm); - blended_0_to_255_vec.zmm = - _mm512_mask_blend_epi8(first_bit_mask, blended_0_to_127_vec.zmm, blended_128_to_255_vec.zmm); - _mm512_mask_storeu_epi8(target, tail_mask, blended_0_to_255_vec.zmm); - source += tail_length, target += tail_length, length -= tail_length; - } -} - -SZ_PUBLIC sz_cptr_t sz_find_charset_ice(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { - - // Before initializing the AVX-512 vectors, we may want to run the sequential code for the first few bytes. - // In practice, that only hurts, even when we have matches every 5-ish bytes. - // - // if (length < SZ_SWAR_THRESHOLD) return sz_find_charset_serial(text, length, filter); - // sz_cptr_t early_result = sz_find_charset_serial(text, SZ_SWAR_THRESHOLD, filter); - // if (early_result) return early_result; - // text += SZ_SWAR_THRESHOLD; - // length -= SZ_SWAR_THRESHOLD; - // - // Let's unzip even and odd elements and replicate them into both lanes of the YMM register. - // That way when we invoke `_mm512_shuffle_epi8` we can use the same mask for both lanes. - sz_u512_vec_t filter_even_vec, filter_odd_vec; - __m256i filter_ymm = _mm256_lddqu_si256((__m256i const *)filter); - // There are a few way to initialize filters without having native strided loads. 
-    // In the chronological order of experiments:
-    // - serial code initializing 128 bytes of odd and even mask
-    // - using several shuffles
-    // - using `_mm512_permutexvar_epi8`
-    // - using `_mm512_broadcast_i32x4(_mm256_castsi256_si128(_mm256_maskz_compress_epi8(0x55555555, filter_ymm)))`
-    //   and `_mm512_broadcast_i32x4(_mm256_castsi256_si128(_mm256_maskz_compress_epi8(0xaaaaaaaa, filter_ymm)))`
-    filter_even_vec.zmm = _mm512_broadcast_i32x4(_mm256_castsi256_si128( // broadcast __m128i to __m512i
-        _mm256_maskz_compress_epi8(0x55555555, filter_ymm)));
-    filter_odd_vec.zmm = _mm512_broadcast_i32x4(_mm256_castsi256_si128( // broadcast __m128i to __m512i
-        _mm256_maskz_compress_epi8(0xaaaaaaaa, filter_ymm)));
-    // After the unzipping operation, we can validate the contents of the vectors like this:
-    //
-    //      for (sz_size_t i = 0; i != 16; ++i) {
-    //          sz_assert(filter_even_vec.u8s[i] == filter->_u8s[i * 2]);
-    //          sz_assert(filter_odd_vec.u8s[i] == filter->_u8s[i * 2 + 1]);
-    //          sz_assert(filter_even_vec.u8s[i + 16] == filter->_u8s[i * 2]);
-    //          sz_assert(filter_odd_vec.u8s[i + 16] == filter->_u8s[i * 2 + 1]);
-    //          sz_assert(filter_even_vec.u8s[i + 32] == filter->_u8s[i * 2]);
-    //          sz_assert(filter_odd_vec.u8s[i + 32] == filter->_u8s[i * 2 + 1]);
-    //          sz_assert(filter_even_vec.u8s[i + 48] == filter->_u8s[i * 2]);
-    //          sz_assert(filter_odd_vec.u8s[i + 48] == filter->_u8s[i * 2 + 1]);
-    //      }
-    //
-    sz_u512_vec_t text_vec;
-    sz_u512_vec_t lower_nibbles_vec, higher_nibbles_vec;
-    sz_u512_vec_t bitset_even_vec, bitset_odd_vec;
-    sz_u512_vec_t bitmask_vec, bitmask_lookup_vec;
-    bitmask_lookup_vec.zmm = _mm512_set_epi8( //
-        -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, //
-        -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, //
-        -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, //
-        -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1);
-
-    while (length) {
-        // The following algorithm is a transposed equivalent of the "SIMDized check which bytes are in a set"
-        // solutions by Wojciech Muła. We populate the bitmask differently and target newer CPUs, so
-        // StringZilla uses a somewhat different approach.
-        // http://0x80.pl/articles/simd-byte-lookup.html#alternative-implementation-new
-        //
-        //      sz_u8_t input = *(sz_u8_t const *)text;
-        //      sz_u8_t lo_nibble = input & 0x0f;
-        //      sz_u8_t hi_nibble = input >> 4;
-        //      sz_u8_t bitset_even = filter_even_vec.u8s[hi_nibble];
-        //      sz_u8_t bitset_odd = filter_odd_vec.u8s[hi_nibble];
-        //      sz_u8_t bitmask = (1 << (lo_nibble & 0x7));
-        //      sz_u8_t bitset = lo_nibble < 8 ? bitset_even : bitset_odd;
-        //      if ((bitset & bitmask) != 0) return text;
-        //      else { length--, text++; }
-        //
-        // The nice part about this is that loading the strided data is very easy with Arm NEON,
-        // while with x86 CPUs after AVX, shuffles within 256 bits shouldn't be an issue either.
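        // A worked example for a single byte: take input 'a' == 0x61. Its low nibble is 0x1 and its high
        // nibble is 0x6, so `bitmask == 1 << 0x1 == 0x02` and the candidate bitsets are `filter->_u8s[12]`
        // (even) and `filter->_u8s[13]` (odd). Since the low nibble is below 8, the even one is tested,
        // and bit 1 of byte 12 is exactly the bit that `sz_charset_add` sets for character code 97.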
- sz_size_t load_length = sz_min_of_two(length, 64); - __mmask64 load_mask = _sz_u64_mask_until(load_length); - text_vec.zmm = _mm512_maskz_loadu_epi8(load_mask, text); - lower_nibbles_vec.zmm = _mm512_and_si512(text_vec.zmm, _mm512_set1_epi8(0x0f)); - bitmask_vec.zmm = _mm512_shuffle_epi8(bitmask_lookup_vec.zmm, lower_nibbles_vec.zmm); - // - // At this point we can validate the `bitmask_vec` contents like this: - // - // for (sz_size_t i = 0; i != load_length; ++i) { - // sz_u8_t input = *(sz_u8_t const *)(text + i); - // sz_u8_t lo_nibble = input & 0x0f; - // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); - // sz_assert(bitmask_vec.u8s[i] == bitmask); - // } - // - // Shift right every byte by 4 bits. - // There is no `_mm512_srli_epi8` intrinsic, so we have to use `_mm512_srli_epi16` - // and combine it with a mask to clear the higher bits. - higher_nibbles_vec.zmm = _mm512_and_si512(_mm512_srli_epi16(text_vec.zmm, 4), _mm512_set1_epi8(0x0f)); - bitset_even_vec.zmm = _mm512_shuffle_epi8(filter_even_vec.zmm, higher_nibbles_vec.zmm); - bitset_odd_vec.zmm = _mm512_shuffle_epi8(filter_odd_vec.zmm, higher_nibbles_vec.zmm); - // - // At this point we can validate the `bitset_even_vec` and `bitset_odd_vec` contents like this: - // - // for (sz_size_t i = 0; i != load_length; ++i) { - // sz_u8_t input = *(sz_u8_t const *)(text + i); - // sz_u8_t const *bitset_ptr = &filter->_u8s[0]; - // sz_u8_t hi_nibble = input >> 4; - // sz_u8_t bitset_even = bitset_ptr[hi_nibble * 2]; - // sz_u8_t bitset_odd = bitset_ptr[hi_nibble * 2 + 1]; - // sz_assert(bitset_even_vec.u8s[i] == bitset_even); - // sz_assert(bitset_odd_vec.u8s[i] == bitset_odd); - // } - // - // TODO: Is this a good place for ternary logic? - __mmask64 take_first = _mm512_cmplt_epi8_mask(lower_nibbles_vec.zmm, _mm512_set1_epi8(8)); - bitset_even_vec.zmm = _mm512_mask_blend_epi8(take_first, bitset_odd_vec.zmm, bitset_even_vec.zmm); - __mmask64 matches_mask = _mm512_mask_test_epi8_mask(load_mask, bitset_even_vec.zmm, bitmask_vec.zmm); - if (matches_mask) { - int offset = sz_u64_ctz(matches_mask); - return text + offset; - } - else { text += load_length, length -= load_length; } - } - - return SZ_NULL_CHAR; -} - -SZ_PUBLIC sz_cptr_t sz_rfind_charset_ice(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { - return sz_rfind_charset_serial(text, length, filter); -} - -SZ_PUBLIC sz_cptr_t sz_find_many_avx512( // - sz_cptr_t haystack, sz_size_t haystack_length, // - sz_cptr_t const *needles, sz_size_t const *needles_lengths, // - sz_size_t *needle_offset) { - - // When dealing with huge needles vocabularies, like in tokenization workloads, we need to construct an automaton. - // But in many cases, the vocabulary is small enough to use a simpler DFA-less approach, combining the ideas from - // the `sz_find_skylake` and `sz_find_charset_ice` functions. - // - // Pick the offsets within needles where there is the least variance in the characters. - // Like for "the", "then", "there", "these", "those", "their", "they", "them", "that", "this", "thus", "than": - // - // 0: 't' - // 1: 'h' - // 2: 'e', 'a', 'i', 'o', 'u' - // 3: 'n', 'r', 's', 'i', 'y', 'm', 't' - // - // So depending on our "register budget", we can use a different number of pivot points: offset 0, 1, 2 make - // the most sense if we can only use 3 ZMM registers. - sz_unused(haystack && haystack_length && needles && needles_lengths && needle_offset); - return 0; -} - -/** - * Computes the Needleman Wunsch alignment score between two strings. 
- * The method uses 32-bit integers to accumulate the running score for every cell in the matrix.
- * Assuming the costs of substitutions can be arbitrary signed 8-bit integers, the method is expected to be used
- * on strings not exceeding 2^24 (~16.7 million) characters in length.
- *
- * Unlike the `_sz_edit_distance_skewed_diagonals_upto65k_avx512` method, this one uses signed integers to store
- * the accumulated score. Moreover, its primary bottleneck is the latency of gathering the substitution costs
- * from the substitution matrix. If we use the diagonal order, we will be comparing a slice of the first string with
- * a slice of the second. If we stick to the conventional horizontal order, we will be comparing one character against
- * a slice, which is much easier to optimize. In that case we are sampling costs not from arbitrary parts of
- * a 256 x 256 matrix, but from a single row!
- */
-SZ_INTERNAL sz_ssize_t _sz_alignment_score_wagner_fisher_upto17m_avx512( //
-    sz_cptr_t shorter, sz_size_t shorter_length,                        //
-    sz_cptr_t longer, sz_size_t longer_length,                          //
-    sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc) {
-
-    // If one of the strings is empty - the score is the gap penalty times the length of the other one
-    if (longer_length == 0) return (sz_ssize_t)shorter_length * gap;
-    if (shorter_length == 0) return (sz_ssize_t)longer_length * gap;
-
-    // Let's make sure that we use an amount of memory proportional to the
-    // number of elements in the shorter string, not the longer.
-    if (shorter_length > longer_length) {
-        sz_pointer_swap((void **)&longer_length, (void **)&shorter_length);
-        sz_pointer_swap((void **)&longer, (void **)&shorter);
-    }
-
-    // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome.
-    sz_memory_allocator_t global_alloc;
-    if (!alloc) {
-        sz_memory_allocator_init_default(&global_alloc);
-        alloc = &global_alloc;
-    }
-
-    sz_size_t const max_length = 256ull * 256ull * 256ull;
-    sz_size_t const n = longer_length + 1;
-    sz_assert(n < max_length && "The length must fit into 24-bit integer. Otherwise use serial variant.");
-    sz_unused(longer_length && max_length);
-
-    sz_size_t buffer_length = sizeof(sz_i32_t) * n * 2;
-    sz_i32_t *distances = (sz_i32_t *)alloc->allocate(buffer_length, alloc->handle);
-    sz_i32_t *previous_distances = distances;
-    sz_i32_t *current_distances = previous_distances + n;
-
-    // Initialize the first row of the Levenshtein matrix with `iota`.
-    for (sz_size_t idx_longer = 0; idx_longer != n; ++idx_longer)
-        previous_distances[idx_longer] = (sz_i32_t)idx_longer * gap;
-
-    /// Contains up to 64 consecutive characters from the longer string.
-    sz_u512_vec_t longer_vec;
-    sz_u512_vec_t cost_deletion_vec, cost_substitution_vec, lookup_substitution_vec, current_vec;
-    sz_u512_vec_t row_first_subs_vec, row_second_subs_vec, row_third_subs_vec, row_fourth_subs_vec;
-    sz_u512_vec_t shuffled_first_subs_vec, shuffled_second_subs_vec, shuffled_third_subs_vec, shuffled_fourth_subs_vec;
-
-    // Prepare constants and masks.
-    sz_u512_vec_t is_third_or_fourth_vec, is_second_or_fourth_vec, gap_vec;
-    {
-        char is_third_or_fourth_check, is_second_or_fourth_check;
-        *(sz_u8_t *)&is_third_or_fourth_check = 0x80, *(sz_u8_t *)&is_second_or_fourth_check = 0x40;
-        is_third_or_fourth_vec.zmm = _mm512_set1_epi8(is_third_or_fourth_check);
-        is_second_or_fourth_vec.zmm = _mm512_set1_epi8(is_second_or_fourth_check);
-        gap_vec.zmm = _mm512_set1_epi32(gap);
-    }
-
-    sz_u8_t const *shorter_unsigned = (sz_u8_t const *)shorter;
-    for (sz_size_t idx_shorter = 0; idx_shorter != shorter_length; ++idx_shorter) {
-        sz_i32_t last_in_row = current_distances[0] = (sz_i32_t)(idx_shorter + 1) * gap;
-
-        // Load one row of the substitution matrix into four ZMM registers.
-        sz_error_cost_t const *row_subs = subs + shorter_unsigned[idx_shorter] * 256u;
-        row_first_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 0);
-        row_second_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 1);
-        row_third_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 2);
-        row_fourth_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 3);
-
-        // In the serial version we have one forward pass that computes the deletion,
-        // insertion, and substitution costs at once.
-        //     for (sz_size_t idx_longer = 0; idx_longer < longer_length; ++idx_longer) {
-        //         sz_ssize_t cost_deletion = previous_distances[idx_longer + 1] + gap;
-        //         sz_ssize_t cost_insertion = current_distances[idx_longer] + gap;
-        //         sz_ssize_t cost_substitution = previous_distances[idx_longer] + row_subs[longer_unsigned[idx_longer]];
-        //         current_distances[idx_longer + 1] = sz_min_of_three(cost_deletion, cost_insertion, cost_substitution);
-        //     }
-        //
-        // Given the complexity of handling the data-dependency between consecutive insertion cost computations
-        // within a Levenshtein matrix, the simplest design would be to vectorize every kind of cost computation
-        // separately.
-        //     1. Compute substitution costs for up to 64 characters at once, upcasting from 8-bit integers to 32.
-        //     2. Compute the pairwise minimum with deletion costs.
-        //     3. Inclusive prefix minimum computation to combine with insertion costs.
-        // Proceeding with substitutions:
-        for (sz_size_t idx_longer = 0; idx_longer < longer_length; idx_longer += 64) {
-            sz_size_t register_length = sz_min_of_two(longer_length - idx_longer, 64);
-            __mmask64 mask = _sz_u64_mask_until(register_length);
-            longer_vec.zmm = _mm512_maskz_loadu_epi8(mask, longer + idx_longer);
-
-            // Blend the `row_(first|second|third|fourth)_subs_vec` into `current_vec`, picking the right source
-            // for every character in `longer_vec`. Before that, we need to permute the substitution vectors.
-            // Only the bottom 6 bits of a byte are used in VPERMB, so we don't even need to mask.
-            shuffled_first_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_first_subs_vec.zmm);
-            shuffled_second_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_second_subs_vec.zmm);
-            shuffled_third_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_third_subs_vec.zmm);
-            shuffled_fourth_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_fourth_subs_vec.zmm);
-
-            // To blend we can invoke three `_mm512_cmplt_epu8_mask`, but we can also achieve the same using
-            // the AND logical operation, checking the top two bits of every byte.
-            // Continuing this thought, we can use the VPTESTMB instruction to output the mask after the AND.
- __mmask64 is_third_or_fourth = _mm512_mask_test_epi8_mask(mask, longer_vec.zmm, is_third_or_fourth_vec.zmm); - __mmask64 is_second_or_fourth = - _mm512_mask_test_epi8_mask(mask, longer_vec.zmm, is_second_or_fourth_vec.zmm); - lookup_substitution_vec.zmm = _mm512_mask_blend_epi8( - is_third_or_fourth, - // Choose between the first and the second. - _mm512_mask_blend_epi8(is_second_or_fourth, shuffled_first_subs_vec.zmm, shuffled_second_subs_vec.zmm), - // Choose between the third and the fourth. - _mm512_mask_blend_epi8(is_second_or_fourth, shuffled_third_subs_vec.zmm, shuffled_fourth_subs_vec.zmm)); - - // First, sign-extend lower and upper 16 bytes to 16-bit integers. - __m512i current_0_31_vec = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(lookup_substitution_vec.zmm, 0)); - __m512i current_32_63_vec = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(lookup_substitution_vec.zmm, 1)); - - // Now extend those 16-bit integers to 32-bit. - // This isn't free, same as the subsequent store, so we only want to do that for the populated lanes. - // To minimize the number of loads and stores, we can combine our substitution costs with the previous - // distances, containing the deletion costs. - { - cost_substitution_vec.zmm = _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer); - cost_substitution_vec.zmm = _mm512_add_epi32( - cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_0_31_vec, 0))); - cost_deletion_vec.zmm = _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer); - cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm); - current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm); - - // Inclusive prefix minimum computation to combine with insertion costs. - // Simply disabling this operation results in 5x performance improvement, meaning - // that this operation is responsible for 80% of the total runtime. - // for (sz_size_t idx_longer = 0; idx_longer < longer_length; ++idx_longer) { - // current_distances[idx_longer + 1] = - // sz_max_of_two(current_distances[idx_longer] + gap, current_distances[idx_longer + 1]); - // } - // - // To perform the same operation in vectorized form, we need to perform a tree-like reduction, - // that will involve multiple steps. It's quite expensive and should be first tested in the - // "experimental" section. - // - // Another approach might be loop unrolling: - // current_vec.i32s[0] = last_in_row = sz_i32_max_of_two(current_vec.i32s[0], last_in_row + gap); - // current_vec.i32s[1] = last_in_row = sz_i32_max_of_two(current_vec.i32s[1], last_in_row + gap); - // current_vec.i32s[2] = last_in_row = sz_i32_max_of_two(current_vec.i32s[2], last_in_row + gap); - // ... yet this approach is also quite expensive. - for (int i = 0; i != 16; ++i) - current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap); - _mm512_mask_storeu_epi32(current_distances + idx_longer + 1, (__mmask16)mask, current_vec.zmm); - } - - // Export the values from 16 to 31. 
-            if (register_length > 16) {
-                mask = _kshiftri_mask64(mask, 16);
-                cost_substitution_vec.zmm =
-                    _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer + 16);
-                cost_substitution_vec.zmm = _mm512_add_epi32(
-                    cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_0_31_vec, 1)));
-                cost_deletion_vec.zmm =
-                    _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer + 16);
-                cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm);
-                current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm);
-
-                // Aggregate running insertion costs within the register.
-                for (int i = 0; i != 16; ++i)
-                    current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap);
-                _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 16, (__mmask16)mask, current_vec.zmm);
-            }
-
-            // Export the values from 32 to 47.
-            if (register_length > 32) {
-                mask = _kshiftri_mask64(mask, 16);
-                cost_substitution_vec.zmm =
-                    _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer + 32);
-                cost_substitution_vec.zmm = _mm512_add_epi32(
-                    cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_32_63_vec, 0)));
-                cost_deletion_vec.zmm =
-                    _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer + 32);
-                cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm);
-                current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm);
-
-                // Aggregate running insertion costs within the register.
-                for (int i = 0; i != 16; ++i)
-                    current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap);
-                _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 32, (__mmask16)mask, current_vec.zmm);
-            }
-
-            // Export the values from 48 to 63.
-            if (register_length > 48) {
-                mask = _kshiftri_mask64(mask, 16);
-                cost_substitution_vec.zmm =
-                    _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer + 48);
-                cost_substitution_vec.zmm = _mm512_add_epi32(
-                    cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_32_63_vec, 1)));
-                cost_deletion_vec.zmm =
-                    _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer + 48);
-                cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm);
-                current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm);
-
-                // Aggregate running insertion costs within the register.
-                for (int i = 0; i != 16; ++i)
-                    current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap);
-                _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 48, (__mmask16)mask, current_vec.zmm);
-            }
-        }
-
-        // Swap previous_distances and current_distances pointers
-        sz_pointer_swap((void **)&previous_distances, (void **)&current_distances);
-    }
-
-    // Cache scalar before `free` call.
-    sz_ssize_t result = previous_distances[longer_length];
-    alloc->free(distances, buffer_length, alloc->handle);
-    return result;
-}
-
-SZ_INTERNAL sz_ssize_t sz_alignment_score_avx512( //
-    sz_cptr_t shorter, sz_size_t shorter_length,  //
-    sz_cptr_t longer, sz_size_t longer_length,    //
-    sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc) {
-
-    if (sz_max_of_two(shorter_length, longer_length) < (256ull * 256ull * 256ull))
-        return _sz_alignment_score_wagner_fisher_upto17m_avx512(shorter, shorter_length, longer, longer_length, subs,
-                                                                gap, alloc);
-    else
-        return sz_alignment_score_serial(shorter, shorter_length, longer, longer_length, subs, gap, alloc);
-}
-
-enum sz_encoding_t {
-    sz_encoding_unknown_k = 0,
-    sz_encoding_ascii_k = 1,
-    sz_encoding_utf8_k = 2,
-    sz_encoding_utf16_k = 3,
-    sz_encoding_utf32_k = 4,
-    sz_jwt_k,
-    sz_base64_k,
-    // Low priority encodings:
-    sz_encoding_utf8bom_k = 5,
-    sz_encoding_utf16le_k = 6,
-    sz_encoding_utf16be_k = 7,
-    sz_encoding_utf32le_k = 8,
-    sz_encoding_utf32be_k = 9,
-};
-
-// Character Set Detection is one of the most commonly performed operations in data processing, with
-// [Chardet](https://github.com/chardet/chardet), [Charset Normalizer](https://github.com/jawah/charset_normalizer),
-// and [cChardet](https://github.com/PyYoshi/cChardet) being the most commonly used options in the Python ecosystem.
-// All of them are notoriously slow.
-//
-// Moreover, as of October 2024, UTF-8 is the dominant character encoding on the web, used by 98.4% of websites.
-// Others have minimal usage, according to [W3Techs](https://w3techs.com/technologies/overview/character_encoding):
-// - ISO-8859-1: 1.2%
-// - Windows-1252: 0.3%
-// - Windows-1251: 0.2%
-// - EUC-JP: 0.1%
-// - Shift JIS: 0.1%
-// - EUC-KR: 0.1%
-// - GB2312: 0.1%
-// - Windows-1250: 0.1%
-// Within programming language implementations and database management systems, 16-bit and 32-bit fixed-width encodings
-// are also very popular and we need a way to efficiently differentiate between the most common UTF flavors, ASCII, and
-// the rest.
-//
-// One good solution is the [simdutf](https://github.com/simdutf/simdutf) library, but it depends on the C++ runtime
-// and focuses more on incremental validation & transcoding, rather than detection.
-//
-// So we need a very fast and efficient way of determining the most likely encoding of a given buffer.
-SZ_PUBLIC sz_bool_t sz_detect_encoding(sz_cptr_t text, sz_size_t length) {
-    // https://github.com/simdutf/simdutf/blob/master/src/icelake/icelake_utf8_validation.inl.cpp
-    // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_from_utf8.inl.cpp#L81
-    // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_utf8_common.inl.cpp#L661
-    // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_utf8_common.inl.cpp#L788
-
-    // We can implement this operation in a simpler & different way, assuming most of the time continuous chunks of memory
-    // have identical encoding. With Russian and many European languages, we generally deal with 2-byte codepoints
-    // with occasional 1-byte punctuation marks. In the case of Chinese, Japanese, and Korean, we deal with 3-byte
-    // codepoints. In the case of emojis, we deal with 4-byte codepoints.
-    // We can also use the idea that misaligned reads are quite cheap on modern CPUs.
- int can_be_ascii = 1, can_be_utf8 = 1, can_be_utf16 = 1, can_be_utf32 = 1; - sz_unused(can_be_ascii + can_be_utf8 + can_be_utf16 + can_be_utf32); - sz_unused(text && length); - return sz_false_k; -} - -#pragma clang attribute pop -#pragma GCC pop_options -#endif - -#pragma endregion - -/* @brief Implementation of the string search algorithms using the Arm NEON instruction set, available on 64-bit - * Arm processors. Implements: {substring search, character search, character set search} x {forward, reverse}. - */ -#pragma region ARM NEON - -#if SZ_USE_NEON -#pragma GCC push_options -#pragma GCC target("arch=armv8.2-a+simd") -#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd"))), apply_to = function) - -/** - * @brief Helper structure to simplify work with 64-bit words. - */ -typedef union sz_u128_vec_t { - uint8x16_t u8x16; - uint16x8_t u16x8; - uint32x4_t u32x4; - uint64x2_t u64x2; - sz_u64_t u64s[2]; - sz_u32_t u32s[4]; - sz_u16_t u16s[8]; - sz_u8_t u8s[16]; -} sz_u128_vec_t; - -SZ_INTERNAL sz_u64_t _sz_vreinterpretq_u8_u4(uint8x16_t vec) { - // Use `vshrn` to produce a bitmask, similar to `movemask` in SSE. - // https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon - return vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(vec), 4)), 0) & 0x8888888888888888ull; -} - -SZ_PUBLIC sz_ordering_t sz_order_neon(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { - //! Before optimizing this, read the "Operations Not Worth Optimizing" in Contributions Guide: - //! https://github.com/ashvardanian/StringZilla/blob/main/CONTRIBUTING.md#general-performance-observations - return sz_order_serial(a, a_length, b, b_length); -} - -SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - sz_u128_vec_t a_vec, b_vec; - for (; length >= 16; a += 16, b += 16, length -= 16) { - a_vec.u8x16 = vld1q_u8((sz_u8_t const *)a); - b_vec.u8x16 = vld1q_u8((sz_u8_t const *)b); - uint8x16_t cmp = vceqq_u8(a_vec.u8x16, b_vec.u8x16); - if (vminvq_u8(cmp) != 255) { return sz_false_k; } // Check if all bytes match - } - - // Handle remaining bytes - if (length) return sz_equal_serial(a, b, length); - return sz_true_k; -} - -SZ_PUBLIC sz_u64_t sz_checksum_neon(sz_cptr_t text, sz_size_t length) { - uint64x2_t sum_vec = vdupq_n_u64(0); - - // Process 16 bytes (128 bits) at a time - for (; length >= 16; text += 16, length -= 16) { - uint8x16_t vec = vld1q_u8((sz_u8_t const *)text); // Load 16 bytes - uint16x8_t pairwise_sum1 = vpaddlq_u8(vec); // Pairwise add lower and upper 8 bits - uint32x4_t pairwise_sum2 = vpaddlq_u16(pairwise_sum1); // Pairwise add 16-bit results - uint64x2_t pairwise_sum3 = vpaddlq_u32(pairwise_sum2); // Pairwise add 32-bit results - sum_vec = vaddq_u64(sum_vec, pairwise_sum3); // Accumulate the sum - } - - // Final reduction of `sum_vec` to a single scalar - sz_u64_t sum = vgetq_lane_u64(sum_vec, 0) + vgetq_lane_u64(sum_vec, 1); - if (length) sum += sz_checksum_serial(text, length); - return sum; -} - -SZ_PUBLIC void sz_copy_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - // In most cases the `source` and the `target` are not aligned, but we should - // at least make sure that writes don't touch many cache lines. - // NEON has an instruction to load and write 64 bytes at once. - // - // sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. 
- // sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. - // for (; head_length; target += 1, source += 1, head_length -= 1) *target = *source; - // length -= head_length; - // for (; length >= 64; target += 64, source += 64, length -= 64) - // vst4q_u8((sz_u8_t *)target, vld1q_u8_x4((sz_u8_t const *)source)); - // for (; tail_length; target += 1, source += 1, tail_length -= 1) *target = *source; - // - // Sadly, those instructions end up being 20% slower than the code processing 16 bytes at a time: - for (; length >= 16; target += 16, source += 16, length -= 16) - vst1q_u8((sz_u8_t *)target, vld1q_u8((sz_u8_t const *)source)); - if (length) sz_copy_serial(target, source, length); -} - -SZ_PUBLIC void sz_move_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - // When moving small buffers, using a small buffer on stack as a temporary storage is faster. - - if (target < source || target >= source + length) { - // Non-overlapping, proceed forward - sz_copy_neon(target, source, length); - } - else { - // Overlapping, proceed backward - target += length; - source += length; - - sz_u128_vec_t src_vec; - while (length >= 16) { - target -= 16, source -= 16, length -= 16; - src_vec.u8x16 = vld1q_u8((sz_u8_t const *)source); - vst1q_u8((sz_u8_t *)target, src_vec.u8x16); - } - while (length) { - target -= 1, source -= 1, length -= 1; - *target = *source; - } - } -} - -SZ_PUBLIC void sz_fill_neon(sz_ptr_t target, sz_size_t length, sz_u8_t value) { - uint8x16_t fill_vec = vdupq_n_u8(value); // Broadcast the value across the register - - while (length >= 16) { - vst1q_u8((sz_u8_t *)target, fill_vec); - target += 16; - length -= 16; - } - - // Handle remaining bytes - if (length) sz_fill_serial(target, length, value); -} - -SZ_PUBLIC void sz_look_up_transform_neon(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { - - // If the input is tiny (especially smaller than the look-up table itself), we may end up paying - // more for organizing the SIMD registers and changing the CPU state, than for the actual computation. - if (length <= 128) { - sz_look_up_transform_serial(source, length, lut, target); - return; - } - - sz_size_t head_length = (16 - ((sz_size_t)target % 16)) % 16; // 15 or less. - sz_size_t tail_length = (sz_size_t)(target + length) % 16; // 15 or less. - - // We need to pull the lookup table into 16x NEON registers. We have a total of 32 such registers. - // According to the Neoverse V2 manual, the 4-table lookup has a latency of 6 cycles, and 4x throughput. - uint8x16x4_t lut_0_to_63_vec, lut_64_to_127_vec, lut_128_to_191_vec, lut_192_to_255_vec; - lut_0_to_63_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 0)); - lut_64_to_127_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 64)); - lut_128_to_191_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 128)); - lut_192_to_255_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 192)); - - sz_u128_vec_t source_vec; - // If the top bit is set in each word of `source_vec`, than we use `lookup_128_to_191_vec` or - // `lookup_192_to_255_vec`. If the second bit is set, we use `lookup_64_to_127_vec` or `lookup_192_to_255_vec`. 
- sz_u128_vec_t lookup_0_to_63_vec, lookup_64_to_127_vec, lookup_128_to_191_vec, lookup_192_to_255_vec; - sz_u128_vec_t blended_0_to_255_vec; - - // Process the head with serial code - for (; head_length; target += 1, source += 1, head_length -= 1) *target = lut[*(sz_u8_t const *)source]; - - // Table lookups on Arm are much simpler to use than on x86, as we can use the `vqtbl4q_u8` instruction - // to perform a 4-table lookup in a single instruction. The XORs are used to adjust the lookup position - // within each 64-byte range of the table. - // Details on the 4-table lookup: https://lemire.me/blog/2019/07/23/arbitrary-byte-to-byte-maps-using-arm-neon/ - length -= head_length; - length -= tail_length; - for (; length >= 16; source += 16, target += 16, length -= 16) { - source_vec.u8x16 = vld1q_u8((sz_u8_t const *)source); - lookup_0_to_63_vec.u8x16 = vqtbl4q_u8(lut_0_to_63_vec, source_vec.u8x16); - lookup_64_to_127_vec.u8x16 = vqtbl4q_u8(lut_64_to_127_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0x40))); - lookup_128_to_191_vec.u8x16 = vqtbl4q_u8(lut_128_to_191_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0x80))); - lookup_192_to_255_vec.u8x16 = vqtbl4q_u8(lut_192_to_255_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0xc0))); - blended_0_to_255_vec.u8x16 = vorrq_u8(vorrq_u8(lookup_0_to_63_vec.u8x16, lookup_64_to_127_vec.u8x16), - vorrq_u8(lookup_128_to_191_vec.u8x16, lookup_192_to_255_vec.u8x16)); - vst1q_u8((sz_u8_t *)target, blended_0_to_255_vec.u8x16); - } - - // Process the tail with serial code - for (; tail_length; target += 1, source += 1, tail_length -= 1) *target = lut[*(sz_u8_t const *)source]; -} - -SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - sz_u64_t matches; - sz_u128_vec_t h_vec, n_vec, matches_vec; - n_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)n); - - while (h_length >= 16) { - h_vec.u8x16 = vld1q_u8((sz_u8_t const *)h); - matches_vec.u8x16 = vceqq_u8(h_vec.u8x16, n_vec.u8x16); - // In Arm NEON we don't have a `movemask` to combine it with `ctz` and get the offset of the match. - // But assuming the `vmaxvq` is cheap, we can use it to find the first match, by blending (bitwise selecting) - // the vector with a relative offsets array. - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - if (matches) return h + sz_u64_ctz(matches) / 4; - - h += 16, h_length -= 16; - } - - return sz_find_byte_serial(h, h_length, n); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_byte_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { - sz_u64_t matches; - sz_u128_vec_t h_vec, n_vec, matches_vec; - n_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)n); - - while (h_length >= 16) { - h_vec.u8x16 = vld1q_u8((sz_u8_t const *)h + h_length - 16); - matches_vec.u8x16 = vceqq_u8(h_vec.u8x16, n_vec.u8x16); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - if (matches) return h + h_length - 1 - sz_u64_clz(matches) / 4; - h_length -= 16; - } - - return sz_rfind_byte_serial(h, h_length, n); -} - -SZ_PUBLIC sz_u64_t _sz_find_charset_neon_register(sz_u128_vec_t h_vec, uint8x16_t set_top_vec_u8x16, - uint8x16_t set_bottom_vec_u8x16) { - - // Once we've read the characters in the haystack, we want to - // compare them against our bitset. The serial version of that code - // would look like: `(set_->_u8s[c >> 3] & (1u << (c & 7u))) != 0`. 
-    uint8x16_t byte_index_vec = vshrq_n_u8(h_vec.u8x16, 3);
-    uint8x16_t byte_mask_vec = vshlq_u8(vdupq_n_u8(1), vreinterpretq_s8_u8(vandq_u8(h_vec.u8x16, vdupq_n_u8(7))));
-    uint8x16_t matches_top_vec = vqtbl1q_u8(set_top_vec_u8x16, byte_index_vec);
-    // The table lookup instruction in NEON replies to out-of-bound requests with zeros.
-    // The values in `byte_index_vec` all fall in [0; 32). So for values under 16, subtracting 16 will underflow
-    // and map into interval [240, 256). Meaning that those will be populated with zeros and we can safely
-    // merge `matches_top_vec` and `matches_bottom_vec` with a bitwise OR.
-    uint8x16_t matches_bottom_vec = vqtbl1q_u8(set_bottom_vec_u8x16, vsubq_u8(byte_index_vec, vdupq_n_u8(16)));
-    uint8x16_t matches_vec = vorrq_u8(matches_top_vec, matches_bottom_vec);
-    // Instead of a pure `vandq_u8`, we can immediately broadcast a match presence across each 8-bit word.
-    matches_vec = vtstq_u8(matches_vec, byte_mask_vec);
-    return _sz_vreinterpretq_u8_u4(matches_vec);
-}
-
-SZ_PUBLIC sz_cptr_t sz_find_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) {
-
-    // This almost never fires, but it's better to be safe than sorry.
-    if (h_length < n_length || !n_length) return SZ_NULL_CHAR;
-    if (n_length == 1) return sz_find_byte_neon(h, h_length, n);
-
-    // Scan through the string.
-    // Assuming how tiny the Arm NEON registers are, we should avoid internal branches at all costs.
-    // That's why, for smaller needles, we use different loops.
-    if (n_length == 2) {
-        // Broadcast needle characters into SIMD registers.
-        sz_u64_t matches;
-        sz_u128_vec_t h_first_vec, h_last_vec, n_first_vec, n_last_vec, matches_vec;
-        // Dealing with 16-bit values, we can load 2 registers at a time and compare 31 possible offsets
-        // in a single loop iteration.
-        n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[0]);
-        n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[1]);
-        for (; h_length >= 17; h += 16, h_length -= 16) {
-            h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 0));
-            h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 1));
-            matches_vec.u8x16 =
-                vandq_u8(vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16));
-            matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16);
-            if (matches) return h + sz_u64_ctz(matches) / 4;
-        }
-    }
-    else if (n_length == 3) {
-        // Broadcast needle characters into SIMD registers.
-        sz_u64_t matches;
-        sz_u128_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec, matches_vec;
-        // Comparing 24-bit values is a bummer. Being lazy, I went with the same approach
-        // as when searching for strings over 4 characters long. I only avoid the last comparison.
-        n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[0]);
-        n_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[1]);
-        n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[2]);
-        for (; h_length >= 18; h += 16, h_length -= 16) {
-            h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 0));
-            h_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 1));
-            h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 2));
-            matches_vec.u8x16 = vandq_u8(                           //
-                vandq_u8(                                           //
-                    vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), //
-                    vceqq_u8(h_mid_vec.u8x16, n_mid_vec.u8x16)),
-                vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16));
-            matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16);
-            if (matches) return h + sz_u64_ctz(matches) / 4;
-        }
-    }
-    else {
-        // Pick the parts of the needle that are worth comparing.
- sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - // Broadcast those characters into SIMD registers. - sz_u64_t matches; - sz_u128_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec, matches_vec; - n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_first]); - n_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_mid]); - n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_last]); - // Walk through the string. - for (; h_length >= n_length + 16; h += 16, h_length -= 16) { - h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_first)); - h_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_mid)); - h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_last)); - matches_vec.u8x16 = vandq_u8( // - vandq_u8( // - vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), // - vceqq_u8(h_mid_vec.u8x16, n_mid_vec.u8x16)), - vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - while (matches) { - int potential_offset = sz_u64_ctz(matches) / 4; - if (sz_equal(h + potential_offset, n, n_length)) return h + potential_offset; - matches &= matches - 1; - } - } - } - - return sz_find_serial(h, h_length, n, n_length); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - - // This almost never fires, but it's better to be safe than sorry. - if (h_length < n_length || !n_length) return SZ_NULL_CHAR; - if (n_length == 1) return sz_rfind_byte_neon(h, h_length, n); - - // Pick the parts of the needle that are worth comparing. - sz_size_t offset_first, offset_mid, offset_last; - _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); - - // Will contain 4 bits per character. 
- sz_u64_t matches; - sz_u128_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec, matches_vec; - n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_first]); - n_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_mid]); - n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_last]); - - sz_cptr_t h_reversed; - for (; h_length >= n_length + 16; h_length -= 16) { - h_reversed = h + h_length - n_length - 16 + 1; - h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h_reversed + offset_first)); - h_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h_reversed + offset_mid)); - h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h_reversed + offset_last)); - matches_vec.u8x16 = vandq_u8( // - vandq_u8( // - vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), // - vceqq_u8(h_mid_vec.u8x16, n_mid_vec.u8x16)), - vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); - matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); - while (matches) { - int potential_offset = sz_u64_clz(matches) / 4; - if (sz_equal(h + h_length - n_length - potential_offset, n, n_length)) - return h + h_length - n_length - potential_offset; - sz_assert((matches & (1ull << (63 - potential_offset * 4))) != 0 && - "The bit must be set before we squash it"); - matches &= ~(1ull << (63 - potential_offset * 4)); - } - } - - return sz_rfind_serial(h, h_length, n, n_length); -} - -SZ_PUBLIC sz_cptr_t sz_find_charset_neon(sz_cptr_t h, sz_size_t h_length, sz_charset_t const *set) { - sz_u64_t matches; - sz_u128_vec_t h_vec; - uint8x16_t set_top_vec_u8x16 = vld1q_u8(&set->_u8s[0]); - uint8x16_t set_bottom_vec_u8x16 = vld1q_u8(&set->_u8s[16]); - - for (; h_length >= 16; h += 16, h_length -= 16) { - h_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h)); - matches = _sz_find_charset_neon_register(h_vec, set_top_vec_u8x16, set_bottom_vec_u8x16); - if (matches) return h + sz_u64_ctz(matches) / 4; - } - - return sz_find_charset_serial(h, h_length, set); -} - -SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t h, sz_size_t h_length, sz_charset_t const *set) { - sz_u64_t matches; - sz_u128_vec_t h_vec; - uint8x16_t set_top_vec_u8x16 = vld1q_u8(&set->_u8s[0]); - uint8x16_t set_bottom_vec_u8x16 = vld1q_u8(&set->_u8s[16]); - - // Check `sz_find_charset_neon` for explanations. - for (; h_length >= 16; h_length -= 16) { - h_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h) + h_length - 16); - matches = _sz_find_charset_neon_register(h_vec, set_top_vec_u8x16, set_bottom_vec_u8x16); - if (matches) return h + h_length - 1 - sz_u64_clz(matches) / 4; - } - - return sz_rfind_charset_serial(h, h_length, set); -} - -#pragma clang attribute pop -#pragma GCC pop_options -#endif // Arm Neon - -#pragma endregion - -/* @brief Implementation of the string search algorithms using the Arm SVE variable-length registers, available - * in Arm v9 processors. - * - * Implements: - * - memory: {copy, move, fill} - * - comparisons: {equal, order} - * - search: {substring, character, character set} x {forward, reverse}. 
- */
-#pragma region ARM SVE
-
-#if SZ_USE_SVE
-#pragma GCC push_options
-#pragma GCC target("arch=armv8.2-a+sve")
-#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve"))), apply_to = function)
-
-SZ_PUBLIC void sz_fill_sve(sz_ptr_t target, sz_size_t length, sz_u8_t value) {
-    svuint8_t value_vec = svdup_u8(value);
-    sz_size_t vec_len = svcntb(); // Vector length in bytes (scalable)
-
-    if (length <= vec_len) {
-        // Small buffer case: use mask to handle small writes
-        svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)length);
-        svst1_u8(mask, (unsigned char *)target, value_vec);
-    }
-    else {
-        // Calculate head, body, and tail sizes
-        sz_size_t head_length = vec_len - ((sz_size_t)target % vec_len);
-        sz_size_t tail_length = (sz_size_t)(target + length) % vec_len;
-        sz_size_t body_length = length - head_length - tail_length;
-
-        // Handle unaligned head
-        svbool_t head_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)head_length);
-        svst1_u8(head_mask, (unsigned char *)target, value_vec);
-        target += head_length;
-
-        // Aligned body loop
-        for (; body_length >= vec_len; target += vec_len, body_length -= vec_len) {
-            svst1_u8(svptrue_b8(), (unsigned char *)target, value_vec);
-        }
-
-        // Handle unaligned tail
-        svbool_t tail_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)tail_length);
-        svst1_u8(tail_mask, (unsigned char *)target, value_vec);
-    }
-}
-
-SZ_PUBLIC void sz_copy_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length) {
-    sz_size_t vec_len = svcntb(); // Vector length in bytes
-
-    // Arm Neoverse V2 cores in Graviton 4, for example, come with 256 KB of L1 data cache per core,
-    // and 8 MB of L2 cache per core. Moreover, the L1 cache is fully associative.
-    // With two strings, we may consider the overall workload huge, if each exceeds 1 MB in length.
-    //
-    //      int is_huge = length >= 4ull * 1024ull * 1024ull;
-    //
-    // When the buffer is small, there isn't much to innovate.
-    if (length <= vec_len) {
-        // Small buffer case: use mask to handle small writes
-        svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)length);
-        svuint8_t data = svld1_u8(mask, (unsigned char *)source);
-        svst1_u8(mask, (unsigned char *)target, data);
-    }
-    // When dealing with larger buffers, similar to AVX-512, we want to minimize unaligned operations
-    // and handle the head, body, and tail separately. We can also traverse the buffer in both directions,
-    // as Arm generally supports more simultaneous stores than x86 CPUs.
-    //
-    // For gigantic datasets, similar to AVX-512, non-temporal "loads" and "stores" can be used.
-    // Sadly, if the register size (16 bytes or larger) is smaller than a cache line (64 bytes),
-    // we will pay a huge penalty on loads, fetching the same content many times.
-    // It may be better to allow caching (and subsequent eviction), in favor of using four-element
-    // tuples, which will be guaranteed to be a multiple of a cache line.
-    //
-    // Another approach is to use the `LD4B` instructions, which will populate four registers at once.
-    // This, however, further decreases the performance from LibC-like 29 GB/s to 20 GB/s.
-    else {
-        // Calculating head, body, and tail sizes depends on the `vec_len`,
-        // but it's a runtime constant, and the modulo operation is expensive!
-        // Instead we use the fact that it's always a multiple of 128 bits or 16 bytes.
- sz_size_t head_length = 16 - ((sz_size_t)target % 16); - sz_size_t tail_length = (sz_size_t)(target + length) % 16; - sz_size_t body_length = length - head_length - tail_length; - - // Handle unaligned parts - svbool_t head_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)head_length); - svuint8_t head_data = svld1_u8(head_mask, (unsigned char *)source); - svst1_u8(head_mask, (unsigned char *)target, head_data); - svbool_t tail_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)tail_length); - svuint8_t tail_data = svld1_u8(tail_mask, (unsigned char *)source + head_length + body_length); - svst1_u8(tail_mask, (unsigned char *)target + head_length + body_length, tail_data); - target += head_length; - source += head_length; - - // Aligned body loop, walking in two directions - for (; body_length >= vec_len * 2; target += vec_len, source += vec_len, body_length -= vec_len * 2) { - svuint8_t forward_data = svld1_u8(svptrue_b8(), (unsigned char *)source); - svuint8_t backward_data = svld1_u8(svptrue_b8(), (unsigned char *)source + body_length - vec_len); - svst1_u8(svptrue_b8(), (unsigned char *)target, forward_data); - svst1_u8(svptrue_b8(), (unsigned char *)target + body_length - vec_len, backward_data); - } - // Up to (vec_len * 2 - 1) bytes of data may be left in the body, - // so we can unroll the last two optional loop iterations. - if (body_length > vec_len) { - svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)body_length); - svuint8_t data = svld1_u8(mask, (unsigned char *)source); - svst1_u8(mask, (unsigned char *)target, data); - body_length -= vec_len; - source += body_length; - target += body_length; - } - if (body_length) { - svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)body_length); - svuint8_t data = svld1_u8(mask, (unsigned char *)source); - svst1_u8(mask, (unsigned char *)target, data); - } - } -} - -#pragma clang attribute pop -#pragma GCC pop_options -#endif // Arm SVE - -#pragma endregion - -/* - * @brief Pick the right implementation for the string search algorithms. - */ -#pragma region Compile Time Dispatching - -SZ_PUBLIC sz_u64_t sz_hash(sz_cptr_t ins, sz_size_t length) { return sz_hash_serial(ins, length); } -SZ_PUBLIC void sz_tolower(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) { sz_tolower_serial(ins, length, outs); } -SZ_PUBLIC void sz_toupper(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) { sz_toupper_serial(ins, length, outs); } -SZ_PUBLIC void sz_toascii(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) { sz_toascii_serial(ins, length, outs); } -SZ_PUBLIC sz_bool_t sz_isascii(sz_cptr_t ins, sz_size_t length) { return sz_isascii_serial(ins, length); } - -SZ_PUBLIC void sz_hashes_fingerprint(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_ptr_t fingerprint, - sz_size_t fingerprint_bytes) { - - sz_bool_t fingerprint_length_is_power_of_two = (sz_bool_t)((fingerprint_bytes & (fingerprint_bytes - 1)) == 0); - sz_string_view_t fingerprint_buffer = {fingerprint, fingerprint_bytes}; - - // There are several issues related to the fingerprinting algorithm. - // First, the memory traversal order is important. - // https://blog.stuffedcow.net/2015/08/pagewalk-coherence/ - - // In most cases the fingerprint length will be a power of two. 
- if (fingerprint_length_is_power_of_two == sz_false_k) - sz_hashes(start, length, window_length, 1, _sz_hashes_fingerprint_non_pow2_callback, &fingerprint_buffer); - else - sz_hashes(start, length, window_length, 1, _sz_hashes_fingerprint_pow2_callback, &fingerprint_buffer); -} - -#if !SZ_DYNAMIC_DISPATCH - -SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length) { -#if SZ_USE_ICE - return sz_checksum_avx512(text, length); -#elif SZ_USE_HASWELL - return sz_checksum_avx2(text, length); -#elif SZ_USE_NEON - return sz_checksum_neon(text, length); -#else - return sz_checksum_serial(text, length); -#endif -} - -SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { -#if SZ_USE_ICE - return sz_equal_skylake(a, b, length); -#elif SZ_USE_HASWELL - return sz_equal_avx2(a, b, length); -#elif SZ_USE_NEON - return sz_equal_neon(a, b, length); -#else - return sz_equal_serial(a, b, length); -#endif -} - -SZ_DYNAMIC sz_ordering_t sz_order(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { -#if SZ_USE_ICE - return sz_order_avx512(a, a_length, b, b_length); -#elif SZ_USE_HASWELL - return sz_order_avx2(a, a_length, b, b_length); -#elif SZ_USE_NEON - return sz_order_neon(a, a_length, b, b_length); -#else - return sz_order_serial(a, a_length, b, b_length); -#endif -} - -SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { -#if SZ_USE_ICE - sz_copy_avx512(target, source, length); -#elif SZ_USE_HASWELL - sz_copy_avx2(target, source, length); -#elif SZ_USE_NEON - sz_copy_neon(target, source, length); -#else - sz_copy_serial(target, source, length); -#endif -} - -SZ_DYNAMIC void sz_move(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { -#if SZ_USE_ICE - sz_move_avx512(target, source, length); -#elif SZ_USE_HASWELL - sz_move_avx2(target, source, length); -#elif SZ_USE_NEON - sz_move_neon(target, source, length); -#else - sz_move_serial(target, source, length); -#endif -} - -SZ_DYNAMIC void sz_fill(sz_ptr_t target, sz_size_t length, sz_u8_t value) { -#if SZ_USE_ICE - sz_fill_avx512(target, length, value); -#elif SZ_USE_HASWELL - sz_fill_avx2(target, length, value); -#elif SZ_USE_NEON - sz_fill_neon(target, length, value); -#else - sz_fill_serial(target, length, value); -#endif -} - -SZ_DYNAMIC void sz_look_up_transform(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { -#if SZ_USE_ICE - sz_look_up_transform_ice(source, length, lut, target); -#elif SZ_USE_HASWELL - sz_look_up_transform_avx2(source, length, lut, target); -#elif SZ_USE_NEON - sz_look_up_transform_neon(source, length, lut, target); -#else - sz_look_up_transform_serial(source, length, lut, target); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) { -#if SZ_USE_ICE - return sz_find_byte_avx512(haystack, h_length, needle); -#elif SZ_USE_HASWELL - return sz_find_byte_avx2(haystack, h_length, needle); -#elif SZ_USE_NEON - return sz_find_byte_neon(haystack, h_length, needle); -#else - return sz_find_byte_serial(haystack, h_length, needle); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_rfind_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) { -#if SZ_USE_ICE - return sz_rfind_byte_avx512(haystack, h_length, needle); -#elif SZ_USE_HASWELL - return sz_rfind_byte_avx2(haystack, h_length, needle); -#elif SZ_USE_NEON - return sz_rfind_byte_neon(haystack, h_length, needle); -#else - return sz_rfind_byte_serial(haystack, h_length, needle); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_find(sz_cptr_t haystack, 
sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length) { -#if SZ_USE_ICE - return sz_find_skylake(haystack, h_length, needle, n_length); -#elif SZ_USE_HASWELL - return sz_find_avx2(haystack, h_length, needle, n_length); -#elif SZ_USE_NEON - return sz_find_neon(haystack, h_length, needle, n_length); -#else - return sz_find_serial(haystack, h_length, needle, n_length); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_rfind(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length) { -#if SZ_USE_ICE - return sz_rfind_skylake(haystack, h_length, needle, n_length); -#elif SZ_USE_HASWELL - return sz_rfind_avx2(haystack, h_length, needle, n_length); -#elif SZ_USE_NEON - return sz_rfind_neon(haystack, h_length, needle, n_length); -#else - return sz_rfind_serial(haystack, h_length, needle, n_length); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { -#if SZ_USE_ICE - return sz_find_charset_ice(text, length, set); -#elif SZ_USE_HASWELL - return sz_find_charset_avx2(text, length, set); -#elif SZ_USE_NEON - return sz_find_charset_neon(text, length, set); -#else - return sz_find_charset_serial(text, length, set); -#endif -} - -SZ_DYNAMIC sz_cptr_t sz_rfind_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { -#if SZ_USE_ICE - return sz_rfind_charset_ice(text, length, set); -#elif SZ_USE_HASWELL - return sz_rfind_charset_avx2(text, length, set); -#elif SZ_USE_NEON - return sz_rfind_charset_neon(text, length, set); -#else - return sz_rfind_charset_serial(text, length, set); -#endif -} - -SZ_DYNAMIC sz_size_t sz_hamming_distance( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - return sz_hamming_distance_serial(a, a_length, b, b_length, bound); -} - -SZ_DYNAMIC sz_size_t sz_hamming_distance_utf8( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - return sz_hamming_distance_utf8_serial(a, a_length, b, b_length, bound); -} - -SZ_DYNAMIC sz_size_t sz_edit_distance( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { -#if SZ_USE_ICE - return sz_edit_distance_avx512(a, a_length, b, b_length, bound, alloc); -#else - return sz_edit_distance_serial(a, a_length, b, b_length, bound, alloc); -#endif -} - -SZ_DYNAMIC sz_size_t sz_edit_distance_utf8( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - return _sz_edit_distance_wagner_fisher_serial(a, a_length, b, b_length, bound, sz_true_k, alloc); -} - -SZ_DYNAMIC sz_ssize_t sz_alignment_score(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, - sz_error_cost_t const *subs, sz_error_cost_t gap, - sz_memory_allocator_t *alloc) { -#if SZ_USE_ICE - return sz_alignment_score_avx512(a, a_length, b, b_length, subs, gap, alloc); -#else - return sz_alignment_score_serial(a, a_length, b, b_length, subs, gap, alloc); -#endif -} - -SZ_DYNAMIC void sz_hashes(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // - sz_hash_callback_t callback, void *callback_handle) { -#if SZ_USE_ICE - sz_hashes_avx512(text, length, window_length, window_step, callback, callback_handle); -#elif SZ_USE_HASWELL - sz_hashes_avx2(text, length, window_length, window_step, callback, callback_handle); -#else - sz_hashes_serial(text, length, window_length, window_step, callback, callback_handle); -#endif -} - 
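As an illustration of how the dispatch layer above is consumed, here is a minimal caller-side sketch. It assumes the single-header `<stringzilla/stringzilla.h>` include used elsewhere in this repository and a build with `SZ_DYNAMIC_DISPATCH` disabled, so every call resolves at compile time to the best enabled backend (AVX-512, AVX2, NEON, or serial):

#include <string.h>                  // strlen
#include <stringzilla/stringzilla.h> // sz_find, sz_find_charset, sz_charset_*

int main(void) {
    char const *haystack = "the quick brown fox jumps over the lazy dog";
    sz_size_t haystack_length = strlen(haystack);

    // Substring search: the preprocessor cascade picks the backend, the caller never names it.
    sz_cptr_t word = sz_find(haystack, haystack_length, "lazy", 4);

    // Character-set search: same dispatch logic, different primitive.
    sz_charset_t vowels;
    sz_charset_init(&vowels);
    sz_charset_add(&vowels, 'a'), sz_charset_add(&vowels, 'e'), sz_charset_add(&vowels, 'i');
    sz_charset_add(&vowels, 'o'), sz_charset_add(&vowels, 'u');
    sz_cptr_t first_vowel = sz_find_charset(haystack, haystack_length, &vowels);

    return (word && first_vowel) ? 0 : 1;
}

The same pattern applies to every `SZ_DYNAMIC` symbol in this region: only the `SZ_USE_*` flags decide which `*_serial`, `*_avx2`, or `*_neon` body gets compiled in.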
-SZ_DYNAMIC sz_cptr_t sz_find_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - return sz_find_charset(h, h_length, &set); -} - -SZ_DYNAMIC sz_cptr_t sz_find_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - sz_charset_invert(&set); - return sz_find_charset(h, h_length, &set); -} - -SZ_DYNAMIC sz_cptr_t sz_rfind_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - return sz_rfind_charset(h, h_length, &set); -} - -SZ_DYNAMIC sz_cptr_t sz_rfind_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - sz_charset_invert(&set); - return sz_rfind_charset(h, h_length, &set); -} - -SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length, - sz_random_generator_t generator, void *generator_user_data) { - sz_generate_serial(alphabet, alphabet_size, result, result_length, generator, generator_user_data); -} - -#endif -#pragma endregion - #ifdef __cplusplus -#pragma GCC diagnostic pop } #endif // __cplusplus diff --git a/include/stringzilla/stringzilla.hpp b/include/stringzilla/stringzilla.hpp index a80da804..f65b0212 100644 --- a/include/stringzilla/stringzilla.hpp +++ b/include/stringzilla/stringzilla.hpp @@ -20,6 +20,7 @@ * @brief When set to 1, the library will include the C++ STL headers and implement * automatic conversion from and to `std::stirng_view` and `std::basic_string`. */ +#include "types.h" #ifndef SZ_AVOID_STL #define SZ_AVOID_STL (0) // true or false #endif @@ -2069,7 +2070,7 @@ class basic_string { * @brief The number of characters that can be stored in the internal buffer. * Depends on the size of the internal buffer for the "Small String Optimization". */ - static constexpr size_type min_capacity = SZ_STRING_INTERNAL_SPACE - 1; + static constexpr size_type min_capacity = _SZ_STRING_INTERNAL_SPACE - 1; #pragma region Constructors and STL Utilities @@ -3663,8 +3664,9 @@ bool basic_string::try_assign(concatenation -bool basic_string::try_preparing_replacement(size_type offset, size_type length, - size_type replacement_length) noexcept { +bool basic_string::try_preparing_replacement( // + size_type offset, size_type length, size_type replacement_length) noexcept { + // There are three cases: // 1. The replacement is the same length as the replaced range. // 2. The replacement is shorter than the replaced range. 
@@ -3759,10 +3761,11 @@ typename concatenation_result::type // std::string result; // result.reserve(total_size); // (result.append(strings), ...); - return ashvardanian::stringzilla::concatenate( + return ashvardanian::stringzilla::concatenate( // std::forward(first), - ashvardanian::stringzilla::concatenate(std::forward(second), - std::forward(following)...)); + ashvardanian::stringzilla::concatenate( // + std::forward(second), // + std::forward(following)...)); } /** @@ -3770,8 +3773,9 @@ typename concatenation_result::type * @see sz_edit_distance */ template -std::size_t hamming_distance(basic_string_slice const &a, basic_string_slice const &b, - std::size_t bound = 0) noexcept { +std::size_t hamming_distance( // + basic_string_slice const &a, basic_string_slice const &b, // + std::size_t bound = 0) noexcept { return sz_hamming_distance(a.data(), a.size(), b.data(), b.size(), bound); } @@ -3780,8 +3784,9 @@ std::size_t hamming_distance(basic_string_slice const &a, basic_stri * @see sz_edit_distance */ template ::type>> -std::size_t hamming_distance(basic_string const &a, - basic_string const &b, std::size_t bound = 0) noexcept { +std::size_t hamming_distance( // + basic_string const &a, basic_string const &b, // + std::size_t bound = 0) noexcept { return ashvardanian::stringzilla::hamming_distance(a.view(), b.view(), bound); } @@ -3790,8 +3795,8 @@ std::size_t hamming_distance(basic_string const &a, * @see sz_hamming_distance_utf8 */ template -std::size_t hamming_distance_utf8(basic_string_slice const &a, basic_string_slice const &b, - std::size_t bound = 0) noexcept { +std::size_t hamming_distance_utf8( // + basic_string_slice const &a, basic_string_slice const &b, std::size_t bound = 0) noexcept { return sz_hamming_distance_utf8(a.data(), a.size(), b.data(), b.size(), bound); } @@ -3800,8 +3805,9 @@ std::size_t hamming_distance_utf8(basic_string_slice const &a, basic * @see sz_edit_distance */ template ::type>> -std::size_t hamming_distance_utf8(basic_string const &a, - basic_string const &b, std::size_t bound = 0) noexcept { +std::size_t hamming_distance_utf8( // + basic_string const &a, basic_string const &b, + std::size_t bound = 0) noexcept { return ashvardanian::stringzilla::hamming_distance_utf8(a.view(), b.view(), bound); } @@ -3810,8 +3816,9 @@ std::size_t hamming_distance_utf8(basic_string cons * @see sz_edit_distance */ template ::type>> -std::size_t edit_distance(basic_string_slice const &a, basic_string_slice const &b, - std::size_t bound = 0, allocator_type_ &&allocator = allocator_type_ {}) noexcept(false) { +std::size_t edit_distance( // + basic_string_slice const &a, basic_string_slice const &b, std::size_t bound = SZ_SIZE_MAX, + allocator_type_ &&allocator = allocator_type_ {}) noexcept(false) { std::size_t result; if (!_with_alloc(allocator, [&](sz_memory_allocator_t &alloc) { result = sz_edit_distance(a.data(), a.size(), b.data(), b.size(), bound, &alloc); @@ -3826,8 +3833,9 @@ std::size_t edit_distance(basic_string_slice const &a, basic_string_ * @see sz_edit_distance */ template > -std::size_t edit_distance(basic_string const &a, - basic_string const &b, std::size_t bound = 0) noexcept(false) { +std::size_t edit_distance( // + basic_string const &a, basic_string const &b, // + std::size_t bound = SZ_SIZE_MAX) noexcept(false) { return ashvardanian::stringzilla::edit_distance(a.view(), b.view(), bound, a.get_allocator()); } @@ -3836,9 +3844,9 @@ std::size_t edit_distance(basic_string const &a, * @see sz_edit_distance_utf8 */ template ::type>> -std::size_t 
edit_distance_utf8(basic_string_slice const &a, basic_string_slice const &b, - std::size_t bound = 0, - allocator_type_ &&allocator = allocator_type_ {}) noexcept(false) { +std::size_t edit_distance_utf8( // + basic_string_slice const &a, basic_string_slice const &b, // + std::size_t bound = SZ_SIZE_MAX, allocator_type_ &&allocator = allocator_type_ {}) noexcept(false) { std::size_t result; if (!_with_alloc(allocator, [&](sz_memory_allocator_t &alloc) { result = sz_edit_distance_utf8(a.data(), a.size(), b.data(), b.size(), bound, &alloc); @@ -3853,9 +3861,9 @@ std::size_t edit_distance_utf8(basic_string_slice const &a, basic_st * @see sz_edit_distance_utf8 */ template > -std::size_t edit_distance_utf8(basic_string const &a, - basic_string const &b, - std::size_t bound = 0) noexcept(false) { +std::size_t edit_distance_utf8( // + basic_string const &a, basic_string const &b, // + std::size_t bound = SZ_SIZE_MAX) noexcept(false) { return ashvardanian::stringzilla::edit_distance_utf8(a.view(), b.view(), bound, a.get_allocator()); } @@ -3864,9 +3872,10 @@ std::size_t edit_distance_utf8(basic_string const & * @see sz_alignment_score */ template ::type>> -std::ptrdiff_t alignment_score(basic_string_slice const &a, basic_string_slice const &b, - std::int8_t const (&subs)[256][256], std::int8_t gap = -1, - allocator_type_ &&allocator = allocator_type_ {}) noexcept(false) { +std::ptrdiff_t alignment_score( // + basic_string_slice const &a, basic_string_slice const &b, // + std::int8_t const (&subs)[256][256], std::int8_t gap = -1, + allocator_type_ &&allocator = allocator_type_ {}) noexcept(false) { static_assert(sizeof(sz_error_cost_t) == sizeof(std::int8_t), "sz_error_cost_t must be 8-bit."); static_assert(std::is_signed() == std::is_signed(), @@ -3886,9 +3895,9 @@ std::ptrdiff_t alignment_score(basic_string_slice const &a, basic_st * @see sz_alignment_score */ template > -std::ptrdiff_t alignment_score(basic_string const &a, - basic_string const &b, // - std::int8_t const (&subs)[256][256], std::int8_t gap = -1) noexcept(false) { +std::ptrdiff_t alignment_score( // + basic_string const &a, basic_string const &b, // + std::int8_t const (&subs)[256][256], std::int8_t gap = -1) noexcept(false) { return ashvardanian::stringzilla::alignment_score(a.view(), b.view(), subs, gap, a.get_allocator()); } @@ -3900,8 +3909,9 @@ std::ptrdiff_t alignment_score(basic_string const & * @param alphabet A string of characters to choose from. 
*/ template -void randomize(basic_string_slice string, generator_type_ &generator, - string_view alphabet = "abcdefghijklmnopqrstuvwxyz") noexcept { +void randomize( // + basic_string_slice string, generator_type_ &generator, + string_view alphabet = "abcdefghijklmnopqrstuvwxyz") noexcept { static_assert(!std::is_const::value, "The string must be mutable."); sz_random_generator_t generator_callback = &_call_random_generator; sz_generate(alphabet.data(), alphabet.size(), string.data(), string.size(), generator_callback, &generator); @@ -3921,8 +3931,9 @@ void transform(basic_string_slice string, basic_look_up_table -void transform(basic_string_slice source, basic_look_up_table const &table, - char_type_ *target) noexcept { +void transform( // + basic_string_slice source, basic_look_up_table const &table, + char_type_ *target) noexcept { static_assert(sizeof(char_type_) == 1, "The character type must be 1 byte long."); sz_look_up_transform((sz_cptr_t)source.data(), (sz_size_t)source.size(), (sz_cptr_t)table.raw(), (sz_ptr_t)target); } @@ -4007,8 +4018,9 @@ void sorted_order(objects_type_ const *begin, objects_type_ const *end, sorted_i * @see sz_hashes */ template -void hashes_fingerprint(basic_string_slice const &str, std::size_t window_length, - std::bitset &fingerprint) noexcept { +void hashes_fingerprint( // + basic_string_slice const &str, std::size_t window_length, + std::bitset &fingerprint) noexcept { constexpr std::size_t fingerprint_bytes = sizeof(std::bitset); return sz_hashes_fingerprint(str.data(), str.size(), window_length, (sz_ptr_t)&fingerprint, fingerprint_bytes); } @@ -4018,8 +4030,8 @@ void hashes_fingerprint(basic_string_slice const &str, std::size_t w * @see sz_hashes */ template -std::bitset hashes_fingerprint(basic_string_slice const &str, - std::size_t window_length) noexcept { +std::bitset hashes_fingerprint( // + basic_string_slice const &str, std::size_t window_length) noexcept { std::bitset fingerprint; ashvardanian::stringzilla::hashes_fingerprint(str, window_length, fingerprint); return fingerprint; @@ -4040,8 +4052,8 @@ std::bitset hashes_fingerprint(basic_string const &str * @throw `std::bad_alloc` if the allocation fails. 
*/ template -std::vector sorted_order(objects_type_ const *begin, objects_type_ const *end, - string_extractor_ &&extractor) noexcept(false) { +std::vector sorted_order( // + objects_type_ const *begin, objects_type_ const *end, string_extractor_ &&extractor) noexcept(false) { std::vector order(end - begin); sorted_order(begin, end, order.data(), std::forward(extractor)); return order; @@ -4054,8 +4066,8 @@ std::vector sorted_order(objects_type_ const *begin, objects_type_ */ template std::vector sorted_order(string_like_type_ const *begin, string_like_type_ const *end) noexcept(false) { - static_assert(std::is_convertible::value, - "The type must be convertible to string_view."); + static_assert( // + std::is_convertible::value, "The type must be convertible to string_view."); return sorted_order(begin, end, [](string_like_type_ const &s) -> string_view { return s; }); } @@ -4066,8 +4078,8 @@ std::vector sorted_order(string_like_type_ const *begin, string_li */ template std::vector sorted_order(std::vector const &array) noexcept(false) { - static_assert(std::is_convertible::value, - "The type must be convertible to string_view."); + static_assert( // + std::is_convertible::value, "The type must be convertible to string_view."); return sorted_order(array.data(), array.data() + array.size(), [](string_like_type_ const &s) -> string_view { return s; }); } diff --git a/include/stringzilla/types.h b/include/stringzilla/types.h index be4a3e0d..8002b8a0 100644 --- a/include/stringzilla/types.h +++ b/include/stringzilla/types.h @@ -356,6 +356,12 @@ typedef union sz_charset_t { /** @brief Initializes a bit-set to an empty collection, meaning - all characters are banned. */ SZ_PUBLIC void sz_charset_init(sz_charset_t *s) { s->_u64s[0] = s->_u64s[1] = s->_u64s[2] = s->_u64s[3] = 0; } +/** @brief Initializes a bit-set to all ASCII character. */ +SZ_PUBLIC void sz_charset_init_ascii(sz_charset_t *s) { + s->_u64s[0] = s->_u64s[1] = 0xFFFFFFFFFFFFFFFFull; + s->_u64s[2] = s->_u64s[3] = 0; +} + /** @brief Adds a character to the set and accepts @b unsigned integers. */ SZ_PUBLIC void sz_charset_add_u8(sz_charset_t *s, sz_u8_t c) { s->_u64s[c >> 6] |= (1ull << (c & 63u)); } @@ -697,7 +703,7 @@ SZ_PUBLIC void sz_sequence_from_u64tape( // #define SZ_CACHE_LINE_WIDTH (64) // bytes /** - * @brief Similar to `assert`, the `sz_assert` is used in the SZ_DEBUG mode + * @brief Similar to `assert`, the `_sz_assert` is used in the SZ_DEBUG mode * to check the invariants of the library. It's a no-op in the SZ_RELEASE mode. * @note If you want to catch it, put a breakpoint at @b `__GI_exit` */ @@ -708,12 +714,12 @@ SZ_PUBLIC void _sz_assert_failure(char const *condition, char const *file, int l fprintf(stderr, "Assertion failed: %s, in file %s, line %d\n", condition, file, line); exit(EXIT_FAILURE); } -#define sz_assert(condition) \ +#define _sz_assert(condition) \ do { \ if (!(condition)) { _sz_assert_failure(#condition, __FILE__, __LINE__); } \ } while (0) #else -#define sz_assert(condition) ((void)(condition)) +#define _sz_assert(condition) ((void)(condition)) #endif /* Intrinsics aliases for MSVC, GCC, Clang, and Clang-Cl. @@ -732,13 +738,13 @@ SZ_PUBLIC void _sz_assert_failure(char const *condition, char const *file, int l // Use the serial version on 32-bit x86 and on Arm. 
#if (defined(_WIN32) && !defined(_WIN64)) || defined(_M_ARM) || defined(_M_ARM64) SZ_INTERNAL int sz_u64_ctz(sz_u64_t x) { - sz_assert(x != 0); + _sz_assert(x != 0); int n = 0; while ((x & 1) == 0) { n++, x >>= 1; } return n; } SZ_INTERNAL int sz_u64_clz(sz_u64_t x) { - sz_assert(x != 0); + _sz_assert(x != 0); int n = 0; while ((x & 0x8000000000000000ull) == 0) { n++, x <<= 1; } return n; @@ -749,13 +755,13 @@ SZ_INTERNAL int sz_u64_popcount(sz_u64_t x) { return (((x + (x >> 4)) & 0x0F0F0F0F0F0F0F0Full) * 0x0101010101010101ull) >> 56; } SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) { - sz_assert(x != 0); + _sz_assert(x != 0); int n = 0; while ((x & 1) == 0) { n++, x >>= 1; } return n; } SZ_INTERNAL int sz_u32_clz(sz_u32_t x) { - sz_assert(x != 0); + _sz_assert(x != 0); int n = 0; while ((x & 0x80000000u) == 0) { n++, x <<= 1; } return n; @@ -896,7 +902,7 @@ SZ_INTERNAL void sz_ssize_clamp_interval( // * @brief Compute the logarithm base 2 of a positive integer, rounding down. */ SZ_INTERNAL sz_size_t sz_size_log2i_nonzero(sz_size_t x) { - sz_assert(x > 0 && "Non-positive numbers have no defined logarithm"); + _sz_assert(x > 0 && "Non-positive numbers have no defined logarithm"); sz_size_t leading_zeros = sz_u64_clz(x); return 63 - leading_zeros; } @@ -1042,33 +1048,6 @@ SZ_INTERNAL void _sz_memory_free_fixed(sz_ptr_t start, sz_size_t length, void *h sz_unused(start && length && handle); } -/** @brief An internal callback used to set a bit in a power-of-two length binary fingerprint of a string. */ -SZ_INTERNAL void _sz_hashes_fingerprint_pow2_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, void *handle) { - sz_string_view_t *fingerprint_buffer = (sz_string_view_t *)handle; - sz_u8_t *fingerprint_u8s = (sz_u8_t *)fingerprint_buffer->start; - sz_size_t fingerprint_bytes = fingerprint_buffer->length; - fingerprint_u8s[(hash / 8) & (fingerprint_bytes - 1)] |= (1 << (hash & 7)); - sz_unused(start && length); -} - -/** @brief An internal callback used to set a bit in a @b non power-of-two length binary fingerprint of a string. */ -SZ_INTERNAL void _sz_hashes_fingerprint_non_pow2_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, - void *handle) { - sz_string_view_t *fingerprint_buffer = (sz_string_view_t *)handle; - sz_u8_t *fingerprint_u8s = (sz_u8_t *)fingerprint_buffer->start; - sz_size_t fingerprint_bytes = fingerprint_buffer->length; - fingerprint_u8s[(hash / 8) % fingerprint_bytes] |= (1 << (hash & 7)); - sz_unused(start && length); -} - -/** @brief An internal callback, used to mix all the running hashes into one pointer-size value. 
*/ -SZ_INTERNAL void _sz_hashes_fingerprint_scalar_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, - void *scalar_handle) { - sz_unused(start && length && hash && scalar_handle); - sz_size_t *scalar_ptr = (sz_size_t *)scalar_handle; - *scalar_ptr ^= hash; -} - #pragma GCC visibility pop #pragma endregion diff --git a/scripts/bench_memory.cpp b/scripts/bench_memory.cpp index ee6ae03b..93d7ab2d 100644 --- a/scripts/bench_memory.cpp +++ b/scripts/bench_memory.cpp @@ -69,11 +69,11 @@ tracked_unary_functions_t copy_functions(sz_cptr_t dataset_start_ptr, sz_ptr_t o tracked_unary_functions_t result = { {"memcpy" + suffix, wrap_sz(memcpy)}, {"sz_copy_serial" + suffix, wrap_sz(sz_copy_serial)}, -#if SZ_USE_ICE - {"sz_copy_avx512" + suffix, wrap_sz(sz_copy_avx512)}, +#if SZ_USE_SKYLAKE + {"sz_copy_skylake" + suffix, wrap_sz(sz_copy_skylake)}, #endif #if SZ_USE_HASWELL - {"sz_copy_avx2" + suffix, wrap_sz(sz_copy_avx2)}, + {"sz_copy_haswell" + suffix, wrap_sz(sz_copy_haswell)}, #endif #if SZ_USE_SVE {"sz_copy_sve" + suffix, wrap_sz(sz_copy_sve)}, @@ -109,11 +109,11 @@ tracked_unary_functions_t fill_functions(sz_cptr_t dataset_start_ptr, sz_ptr_t o return slice.size(); })}, {"sz_fill_serial", wrap_sz(sz_fill_serial)}, -#if SZ_USE_ICE - {"sz_fill_avx512", wrap_sz(sz_fill_avx512)}, +#if SZ_USE_SKYLAKE + {"sz_fill_avx512", wrap_sz(sz_fill_skylake)}, #endif #if SZ_USE_HASWELL - {"sz_fill_avx2", wrap_sz(sz_fill_avx2)}, + {"sz_fill_haswell", wrap_sz(sz_fill_haswell)}, #endif #if SZ_USE_SVE {"sz_fill_sve", wrap_sz(sz_fill_sve)}, @@ -149,11 +149,11 @@ tracked_unary_functions_t move_functions(sz_cptr_t dataset_start_ptr, sz_ptr_t o tracked_unary_functions_t result = { {"memmove" + suffix, wrap_sz(memmove)}, {"sz_move_serial" + suffix, wrap_sz(sz_move_serial)}, -#if SZ_USE_ICE - {"sz_move_avx512" + suffix, wrap_sz(sz_move_avx512)}, +#if SZ_USE_SKYLAKE + {"sz_move_skylake" + suffix, wrap_sz(sz_move_skylake)}, #endif #if SZ_USE_HASWELL - {"sz_move_avx2" + suffix, wrap_sz(sz_move_avx2)}, + {"sz_move_haswell" + suffix, wrap_sz(sz_move_haswell)}, #endif #if SZ_USE_NEON {"sz_move_neon" + suffix, wrap_sz(sz_move_neon)}, @@ -196,7 +196,7 @@ tracked_unary_functions_t transform_functions() { {"sz_look_up_transform_ice", wrap_sz(sz_look_up_transform_ice)}, #endif #if SZ_USE_HASWELL - {"sz_look_up_transform_avx2", wrap_sz(sz_look_up_transform_avx2)}, + {"sz_look_up_transform_haswell", wrap_sz(sz_look_up_transform_haswell)}, #endif #if SZ_USE_NEON {"sz_look_up_transform_neon", wrap_sz(sz_look_up_transform_neon)}, diff --git a/scripts/bench_similarity.cpp b/scripts/bench_similarity.cpp index 140433e2..9aa964c3 100644 --- a/scripts/bench_similarity.cpp +++ b/scripts/bench_similarity.cpp @@ -52,11 +52,11 @@ tracked_binary_functions_t distance_functions() { }; tracked_binary_functions_t result = { {"naive", wrap_baseline}, - {"sz_edit_distance", wrap_sz_distance(sz_edit_distance_serial), true}, - {"sz_alignment_score", wrap_sz_scoring(sz_alignment_score_serial), true}, + {"sz_edit_distance_serial", wrap_sz_distance(sz_edit_distance_serial), true}, + {"sz_alignment_score_serial", wrap_sz_scoring(sz_alignment_score_serial), true}, #if SZ_USE_ICE - {"sz_edit_distance_avx512", wrap_sz_distance(sz_edit_distance_avx512), true}, - {"sz_alignment_score_avx512", wrap_sz_scoring(sz_alignment_score_avx512), true}, + {"sz_edit_distance_ice", wrap_sz_distance(sz_edit_distance_ice), true}, + {"sz_alignment_score_ice", wrap_sz_scoring(sz_alignment_score_ice), true}, #endif }; return result; diff --git a/scripts/bench_token.cpp 
b/scripts/bench_token.cpp index 1120ad52..492f93f4 100644 --- a/scripts/bench_token.cpp +++ b/scripts/bench_token.cpp @@ -55,8 +55,8 @@ tracked_unary_functions_t sliding_hashing_functions(std::size_t window_width, st }; std::string suffix = std::to_string(window_width) + ":step" + std::to_string(step); tracked_unary_functions_t result = { -#if SZ_USE_ICE - {"sz_hashes_avx512:" + suffix, wrap_sz(sz_hashes_avx512)}, +#if SZ_USE_SKYLAKE + {"sz_hashes_skylake:" + suffix, wrap_sz(sz_hashes_skylake)}, #endif #if SZ_USE_HASWELL {"sz_hashes_haswell:" + suffix, wrap_sz(sz_hashes_haswell)}, @@ -120,8 +120,8 @@ tracked_binary_functions_t equality_functions() { #if SZ_USE_HASWELL {"sz_equal_haswell", wrap_sz(sz_equal_haswell), true}, #endif -#if SZ_USE_ICE - {"sz_equal_avx512", wrap_sz(sz_equal_avx512), true}, +#if SZ_USE_SKYLAKE + {"sz_equal_skylake", wrap_sz(sz_equal_skylake), true}, #endif {"memcmp", [](std::string_view a, std::string_view b) { @@ -147,8 +147,8 @@ tracked_binary_functions_t ordering_functions() { #if SZ_USE_HASWELL {"sz_order_haswell", wrap_sz(sz_order_haswell), true}, #endif -#if SZ_USE_ICE - {"sz_order_avx512", wrap_sz(sz_order_avx512), true}, +#if SZ_USE_SKYLAKE + {"sz_order_skylake", wrap_sz(sz_order_skylake), true}, #endif {"memcmp", [](std::string_view a, std::string_view b) { diff --git a/scripts/test.cpp b/scripts/test.cpp index db856a8e..3f9add3b 100644 --- a/scripts/test.cpp +++ b/scripts/test.cpp @@ -121,7 +121,7 @@ static void test_arithmetical_utilities() { assert(sz_size_bit_ceil(1000000000ull) == (1ull << 30)); assert(sz_size_bit_ceil(2000000000ull) == (1ull << 31)); -#if SZ_DETECT_64_BIT +#if _SZ_IS_64_BIT assert(sz_size_bit_ceil(4000000000ull) == (1ull << 32)); assert(sz_size_bit_ceil(8000000000ull) == (1ull << 33)); assert(sz_size_bit_ceil(16000000000ull) == (1ull << 34)); @@ -130,11 +130,6 @@ static void test_arithmetical_utilities() { assert(sz_size_bit_ceil((1ull << 62) + 1) == (1ull << 63)); assert(sz_size_bit_ceil((1ull << 63)) == (1ull << 63)); #endif - - for (sz_u16_t number = 0; number != 256; ++number) - for (sz_u16_t divisor = 2; divisor != 256; ++divisor) - assert(sz_u8_divide(static_cast(number), static_cast(divisor)) == - (static_cast(number) / static_cast(divisor))); } inline void expect_equality(char const *a, char const *b, std::size_t size) { @@ -571,7 +566,7 @@ static void test_stl_compatibility_for_updates() { // On 32-bit systems the base capacity can be larger than our `z::string::min_capacity`. // It's true for MSVC: https://github.com/ashvardanian/StringZilla/issues/168 - if (SZ_DETECT_64_BIT) assert_scoped(str s = "hello", s.shrink_to_fit(), s.capacity() <= sz::string::min_capacity); + if (_SZ_IS_64_BIT) assert_scoped(str s = "hello", s.shrink_to_fit(), s.capacity() <= sz::string::min_capacity); // Concatenation. // Following are missing in strings, but are present in vectors. 
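
Editorial note on the `sz_size_bit_ceil` expectations exercised earlier in this `test.cpp` diff: the asserts treat it as "the smallest power of two that is not smaller than the argument". A minimal, intrinsics-free reference sketch of that contract, illustrative only and not the library's implementation (the helper name is hypothetical, and overflow above 2^63 is ignored):

    SZ_INTERNAL sz_size_t naive_bit_ceil(sz_size_t x) {
        sz_size_t result = 1;            /* by this convention, the ceiling of 0 and 1 is 1 */
        while (result < x) result <<= 1; /* double until we reach or pass `x` */
        return result;
    }
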
@@ -1559,16 +1554,16 @@ static void test_stl_containers() { int main(int argc, char const **argv) { - auto dist = _sz_edit_distance_skewed_diagonals_upto63_avx512("kiten", 5, "katerinas", 9, SZ_SIZE_MAX); - sz_assert(dist == 5); - dist = _sz_edit_distance_skewed_diagonals_upto63_avx512("kiten", 5, "katerinas", 9, 3); - sz_assert(dist == SZ_SIZE_MAX); - dist = _sz_edit_distance_skewed_diagonals_upto63_avx512("kiten", 5, "katerinas", 9, 4); - sz_assert(dist == SZ_SIZE_MAX); - dist = _sz_edit_distance_skewed_diagonals_upto63_avx512("kiten", 5, "katerinas", 9, 5); - sz_assert(dist == 5); - dist = _sz_edit_distance_skewed_diagonals_upto63_avx512("kiten", 5, "katerinas", 9, 6); - sz_assert(dist == 5); + auto dist = _sz_edit_distance_skewed_diagonals_upto63_ice("kiten", 5, "katerinas", 9, SZ_SIZE_MAX); + _sz_assert(dist == 5); + dist = _sz_edit_distance_skewed_diagonals_upto63_ice("kiten", 5, "katerinas", 9, 3); + _sz_assert(dist == SZ_SIZE_MAX); + dist = _sz_edit_distance_skewed_diagonals_upto63_ice("kiten", 5, "katerinas", 9, 4); + _sz_assert(dist == SZ_SIZE_MAX); + dist = _sz_edit_distance_skewed_diagonals_upto63_ice("kiten", 5, "katerinas", 9, 5); + _sz_assert(dist == 5); + dist = _sz_edit_distance_skewed_diagonals_upto63_ice("kiten", 5, "katerinas", 9, 6); + _sz_assert(dist == 5); // Similarity measures and fuzzy search test_levenshtein_distances(); From 364e2ca4908fa3eae21689dc3afc7a1460a1a023 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 8 Dec 2024 19:48:31 +0000 Subject: [PATCH 049/751] Make: Rename `stringzillite` to `stringzilla_bare` --- .github/workflows/prerelease.yml | 2 +- .github/workflows/release.yml | 24 +++++------ CMakeLists.txt | 68 ++++++++++++++++++++++++-------- CONTRIBUTING.md | 4 +- 4 files changed, 66 insertions(+), 32 deletions(-) diff --git a/.github/workflows/prerelease.yml b/.github/workflows/prerelease.yml index 8f8b7803..57514b79 100644 --- a/.github/workflows/prerelease.yml +++ b/.github/workflows/prerelease.yml @@ -275,7 +275,7 @@ jobs: # We can't run the produced builds, but we can make sure they exist - name: Test artifacts presense run: | - test -e build_artifacts/libstringzillite.so + test -e build_artifacts/libstringzilla_bare.so test -e build_artifacts/libstringzilla_shared.so test -e build_artifacts/stringzilla_test_cpp20 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 80a8f989..6a726b14 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -253,15 +253,15 @@ jobs: cmake --build build_release --config Release - cp build_release/libstringzillite.so "stringzillite_linux_${{ matrix.arch }}_${{ steps.set_version.outputs.version }}.so" - mkdir -p "stringzillite_linux_${{ matrix.arch }}_${{ steps.set_version.outputs.version }}/DEBIAN" - touch "stringzillite_linux_${{ matrix.arch }}_${{ steps.set_version.outputs.version }}/DEBIAN/control" - mkdir -p "stringzillite_linux_${{ matrix.arch }}_${{ steps.set_version.outputs.version }}/usr/local/lib" - mkdir "stringzillite_linux_${{ matrix.arch }}_${{ steps.set_version.outputs.version }}/usr/local/include" - cp include/stringzilla/stringzilla.h "stringzillite_linux_${{ matrix.arch }}_${{ steps.set_version.outputs.version }}/usr/local/include/" - cp build_release/libstringzillite.so "stringzillite_linux_${{ matrix.arch }}_${{ steps.set_version.outputs.version }}/usr/local/lib/" - echo -e "Package: stringzilla\nVersion: ${{ steps.set_version.outputs.version }}\nMaintainer: Ash Vardanian\nArchitecture: 
${{ matrix.arch }}\nDescription: SIMD-accelerated string search, sort, hashes, fingerprints, & edit distances" > "stringzillite_linux_${{ matrix.arch }}_${{ steps.set_version.outputs.version }}/DEBIAN/control" - dpkg-deb --build "stringzillite_linux_${{ matrix.arch }}_${{ steps.set_version.outputs.version }}" + cp build_release/libstringzilla_bare.so "stringzilla_bare_linux_${{ matrix.arch }}_${{ steps.set_version.outputs.version }}.so" + mkdir -p "stringzilla_bare_linux_${{ matrix.arch }}_${{ steps.set_version.outputs.version }}/DEBIAN" + touch "stringzilla_bare_linux_${{ matrix.arch }}_${{ steps.set_version.outputs.version }}/DEBIAN/control" + mkdir -p "stringzilla_bare_linux_${{ matrix.arch }}_${{ steps.set_version.outputs.version }}/usr/local/lib" + mkdir "stringzilla_bare_linux_${{ matrix.arch }}_${{ steps.set_version.outputs.version }}/usr/local/include" + cp include/stringzilla/stringzilla.h "stringzilla_bare_linux_${{ matrix.arch }}_${{ steps.set_version.outputs.version }}/usr/local/include/" + cp build_release/libstringzilla_bare.so "stringzilla_bare_linux_${{ matrix.arch }}_${{ steps.set_version.outputs.version }}/usr/local/lib/" + echo -e "Package: stringzilla\nVersion: ${{ steps.set_version.outputs.version }}\nMaintainer: Ash Vardanian\nArchitecture: ${{ matrix.arch }}\nDescription: SIMD-accelerated string search, sort, hashes, fingerprints, & edit distances" > "stringzilla_bare_linux_${{ matrix.arch }}_${{ steps.set_version.outputs.version }}/DEBIAN/control" + dpkg-deb --build "stringzilla_bare_linux_${{ matrix.arch }}_${{ steps.set_version.outputs.version }}" - name: Upload library uses: xresloader/upload-to-github-release@v1 @@ -314,14 +314,14 @@ jobs: run: | cmake -DCMAKE_BUILD_TYPE=Release -B build_release cmake --build build_release --config Release - tar -cvf "stringzillite_windows_${{ matrix.arch }}_${{ steps.set_version.outputs.version }}.tar" "build_release/stringzillite.dll" "./include/stringzilla/stringzilla.h" + tar -cvf "stringzilla_bare_windows_${{ matrix.arch }}_${{ steps.set_version.outputs.version }}.tar" "build_release/stringzilla_bare.dll" "./include/stringzilla/stringzilla.h" - name: Upload archive uses: xresloader/upload-to-github-release@v1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: - file: "stringzillite_windows_${{ matrix.arch }}_${{ steps.set_version.outputs.version }}.tar" + file: "stringzilla_bare_windows_${{ matrix.arch }}_${{ steps.set_version.outputs.version }}.tar" update_latest_release: true create_macos_library: @@ -349,7 +349,7 @@ jobs: run: | cmake -DCMAKE_BUILD_TYPE=Release -B build_release cmake --build build_release --config Release - zip -r stringzillite_macos_${{ matrix.arch }}_${{ steps.set_version.outputs.version }}.zip build_release/libstringzillite.dylib include/stringzilla/stringzilla.h + zip -r stringzilla_bare_macos_${{ matrix.arch }}_${{ steps.set_version.outputs.version }}.zip build_release/libstringzilla_bare.dylib include/stringzilla/stringzilla.h - name: Upload archive uses: xresloader/upload-to-github-release@v1 diff --git a/CMakeLists.txt b/CMakeLists.txt index c09fd6e7..6b931960 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,37 @@ -cmake_minimum_required(VERSION 3.1) +# StringZilla CMakeLists.txt +# +# This file defines several library build & installation targets: +# +# - stringzilla_header: A header-only library with the StringZilla C and C++ headers. +# - stringzilla_shared: A shared library with the StringZilla C and C++ headers and dynamic SIMD dispatch. 
+# - stringzilla_bare: A shared library with the StringZilla headers, but without linking the standard C library. +# +# Tests for different C++ standards: +# +# - stringzilla_test_cpp11: A test executable for C++11. +# - stringzilla_test_cpp14: A test executable for C++14. +# - stringzilla_test_cpp17: A test executable for C++17. +# - stringzilla_test_cpp20: A test executable for C++20. +# +# Tests for different SIMD architectures: +# +# - stringzilla_test_cpp20_serial: A test executable for serial execution. +# - stringzilla_test_cpp20_haswell: A test executable for AVX2. +# - stringzilla_test_cpp20_ice: A test executable for AVX-512. +# - stringzilla_test_cpp20_neon: A test executable for ARM Neon. +# - stringzilla_test_cpp20_sve: A test executable for ARM Scalable Vector Extension. +# +# Benchmarks: +# +# - stringzilla_bench_search: A benchmark for substring search operations. +# - stringzilla_bench_similarity: A benchmark for similarity operations. +# - stringzilla_bench_sort: A benchmark for sorting operations. +# - stringzilla_bench_token: A benchmark for comparators and hash functions. +# - stringzilla_bench_container: A benchmark for STL containers powered by StringZilla. +# - stringzilla_bench_memory: A benchmark for LibC-style low-level memory operations. +# +# For higher-level language bindings separate build scripts are provided, native to each toolchain. +cmake_minimum_required(VERSION 3.14 FATAL_ERROR) project( stringzilla VERSION 3.11.0 @@ -7,7 +40,7 @@ project( HOMEPAGE_URL "https://github.com/ashvardanian/stringzilla") set(CMAKE_C_STANDARD 99) -set(CMAKE_CXX_STANDARD 17) # This gives many issues for msvc and clang-cl, especially if later on you set it to std-c++11 later on in the tests... +set(CMAKE_CXX_STANDARD 11) set(CMAKE_C_EXTENSIONS OFF) set(CMAKE_CXX_EXTENSIONS OFF) @@ -270,18 +303,19 @@ if(${STRINGZILLA_BUILD_TEST}) if(SZ_PLATFORM_X86) # x86 specific backends if (MSVC) - define_launcher(stringzilla_test_cpp20_x86_serial scripts/test.cpp 20 "AVX") - define_launcher(stringzilla_test_cpp20_x86_avx2 scripts/test.cpp 20 "AVX2") - define_launcher(stringzilla_test_cpp20_x86_avx512 scripts/test.cpp 20 "AVX512") + define_launcher(stringzilla_test_cpp20_serial scripts/test.cpp 20 "AVX") + define_launcher(stringzilla_test_cpp20_haswell scripts/test.cpp 20 "AVX2") + define_launcher(stringzilla_test_cpp20_ice scripts/test.cpp 20 "AVX512") else() - define_launcher(stringzilla_test_cpp20_x86_serial scripts/test.cpp 20 "ivybridge") - define_launcher(stringzilla_test_cpp20_x86_avx2 scripts/test.cpp 20 "haswell") - define_launcher(stringzilla_test_cpp20_x86_avx512 scripts/test.cpp 20 "sapphirerapids") + define_launcher(stringzilla_test_cpp20_serial scripts/test.cpp 20 "ivybridge") + define_launcher(stringzilla_test_cpp20_haswell scripts/test.cpp 20 "haswell") + define_launcher(stringzilla_test_cpp20_ice scripts/test.cpp 20 "sapphirerapids") endif() elseif(SZ_PLATFORM_ARM) # ARM specific backends - define_launcher(stringzilla_test_cpp20_arm_serial scripts/test.cpp 20 "armv8-a") - define_launcher(stringzilla_test_cpp20_arm_neon scripts/test.cpp 20 "armv8-a+simd") + define_launcher(stringzilla_test_cpp20_serial scripts/test.cpp 20 "armv8-a") + define_launcher(stringzilla_test_cpp20_neon scripts/test.cpp 20 "armv8-a+simd") + define_launcher(stringzilla_test_cpp20_sve scripts/test.cpp 20 "armv8.2-a+sve") endif() endif() @@ -335,16 +369,16 @@ if(${STRINGZILLA_BUILD_SHARED}) target_compile_definitions(stringzilla_shared PRIVATE "SZ_OVERRIDE_LIBC=1") # Try compiling a version without 
linking the LibC - define_shared(stringzillite) - target_compile_definitions(stringzillite PRIVATE "SZ_AVOID_LIBC=1") - target_compile_definitions(stringzillite PRIVATE "SZ_OVERRIDE_LIBC=1") + define_shared(stringzilla_bare) + target_compile_definitions(stringzilla_bare PRIVATE "SZ_AVOID_LIBC=1") + target_compile_definitions(stringzilla_bare PRIVATE "SZ_OVERRIDE_LIBC=1") # Avoid built-ins on MSVC and other compilers, as that will cause compilation errors - target_compile_options(stringzillite PRIVATE + target_compile_options(stringzilla_bare PRIVATE "$<$:-fno-builtin;-nostdlib>" "$<$:/Oi-;/GS->") - target_link_options(stringzillite PRIVATE "$<$:-nostdlib>") - target_link_options(stringzillite PRIVATE "$<$:/NODEFAULTLIB>") + target_link_options(stringzilla_bare PRIVATE "$<$:-nostdlib>") + target_link_options(stringzilla_bare PRIVATE "$<$:/NODEFAULTLIB>") endif() @@ -362,7 +396,7 @@ if(STRINGZILLA_INSTALL) RESOURCE RUNTIME) install( - TARGETS stringzillite + TARGETS stringzilla_bare ARCHIVE BUNDLE FRAMEWORK diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 524d6c49..231291c8 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -131,8 +131,8 @@ Using modern syntax, this is how you build and run the test suite: cmake -D STRINGZILLA_BUILD_TEST=1 -D CMAKE_BUILD_TYPE=Debug -B build_debug cmake --build build_debug --config Debug # Which will produce the following targets: build_debug/stringzilla_test_cpp20 # Unit test for the entire library compiled for current hardware -build_debug/stringzilla_test_cpp20_x86_serial # x86 variant compiled for IvyBridge - last arch. before AVX2 -build_debug/stringzilla_test_cpp20_arm_serial # Arm variant compiled without Neon +build_debug/stringzilla_test_cpp20_serial # x86 variant compiled for IvyBridge - last arch. before AVX2 +build_debug/stringzilla_test_cpp20_serial # Arm variant compiled without Neon ``` To use CppCheck for static analysis make sure to export the compilation commands. From 6d61c2166671ebecc57dba2a5016c5404872b02a Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 8 Dec 2024 20:02:43 +0000 Subject: [PATCH 050/751] Make: Detect Apple Universal builds Imported from #169 Co-authored-by: ashbob999 <32575256+ashbob999@users.noreply.github.com> --- pyproject.toml | 7 +++++++ setup.py | 7 ++++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e93355ae..a8dd42e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -117,5 +117,12 @@ select = "*-macos*_arm64" inherit.environment = "append" environment.SZ_ARM64 = "1" +# Detect MacOS Universal2 builds +[[tool.cibuildwheel.overrides]] +select = "*-macos*_universal2" +inherit.environment = "append" +environment.SZ_X86_64 = "1" +environment.SZ_ARM64 = "1" + [tool.cibuildwheel.macos.environment] MACOSX_DEPLOYMENT_TARGET = "10.11" diff --git a/setup.py b/setup.py index 27ef6be2..a1bce8ad 100644 --- a/setup.py +++ b/setup.py @@ -88,12 +88,13 @@ def darwin_settings() -> Tuple[List[str], List[str], List[Tuple[str]]]: # so we must pre-set the CPU generation. Technically the last Intel-based Apple # product was the 2021 MacBook Pro, which had the "Coffee Lake" architecture. # During Universal builds, however, even AVX header cause compilation errors. 
- can_use_avx2 = is_64bit_x86() and sysconfig.get_platform().startswith("universal") + is_building_x86 = is_64bit_x86() or "universal" in sysconfig.get_platform() + is_building_arm = is_64bit_arm() or "universal" in sysconfig.get_platform() macros_args = [ - ("SZ_USE_HASWELL", "1" if can_use_avx2 else "0"), + ("SZ_USE_HASWELL", "1" if is_building_x86 else "0"), ("SZ_USE_SKYLAKE", "0"), ("SZ_USE_ICE", "0"), - ("SZ_USE_NEON", "1" if is_64bit_arm() else "0"), + ("SZ_USE_NEON", "1" if is_building_arm else "0"), ("SZ_USE_SVE", "0"), ] From 19c2ae9743ca9e9b767cda2eb5e828553fd850cd Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 8 Dec 2024 20:04:27 +0000 Subject: [PATCH 051/751] Improve: C++ version macros naming --- .vscode/settings.json | 4 +++- CMakeLists.txt | 8 ++++---- include/stringzilla/stringzilla.hpp | 26 +++++++++++++------------- include/stringzilla/types.h | 6 +++--- scripts/test.cpp | 20 ++++++++++---------- 5 files changed, 33 insertions(+), 31 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index ee1f1d3b..9d0e1b53 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -39,6 +39,7 @@ "cheminformatics", "cibuildwheel", "CONCAT", + "constexpr", "copydoc", "Corasick", "cptr", @@ -82,6 +83,7 @@ "Merkle-Damgård", "Mersenne", "MODINIT", + "MSVC", "napi", "nargsf", "ndim", @@ -120,7 +122,7 @@ "startswith", "STL", "stringzilla", - "stringzillite", + "stringzilla_bare", "Strs", "strzl", "substr", diff --git a/CMakeLists.txt b/CMakeLists.txt index 6b931960..81e9bbaa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,10 +8,10 @@ # # Tests for different C++ standards: # -# - stringzilla_test_cpp11: A test executable for C++11. -# - stringzilla_test_cpp14: A test executable for C++14. -# - stringzilla_test_cpp17: A test executable for C++17. -# - stringzilla_test_cpp20: A test executable for C++20. +# - stringzilla_test_cpp11: C++11 baseline support. +# - stringzilla_test_cpp14: C++14 support with `std::less`-like function objects. +# - stringzilla_test_cpp17: C++17 support with `std::string_view` compatibility. +# - stringzilla_test_cpp20: C++20 support with `<=>` operator and more `constexpr` features. # # Tests for different SIMD architectures: # diff --git a/include/stringzilla/stringzilla.hpp b/include/stringzilla/stringzilla.hpp index f65b0212..c705dae6 100644 --- a/include/stringzilla/stringzilla.hpp +++ b/include/stringzilla/stringzilla.hpp @@ -28,18 +28,18 @@ /* We need to detect the version of the C++ language we are compiled with. * This will affect recent features like `operator<=>` and tests against STL. */ -#define SZ_DETECT_CPP_23 (__cplusplus >= 202101L) -#define SZ_DETECT_CPP20 (__cplusplus >= 202002L) -#define SZ_DETECT_CPP_17 (__cplusplus >= 201703L) -#define SZ_DETECT_CPP14 (__cplusplus >= 201402L) -#define SZ_DETECT_CPP_11 (__cplusplus >= 201103L) -#define SZ_DETECT_CPP_98 (__cplusplus >= 199711L) +#define _SZ_IS_CPP23 (__cplusplus >= 202101L) +#define _SZ_IS_CPP20 (__cplusplus >= 202002L) +#define _SZ_IS_CPP17 (__cplusplus >= 201703L) +#define _SZ_IS_CPP14 (__cplusplus >= 201402L) +#define _SZ_IS_CPP11 (__cplusplus >= 201103L) +#define _SZ_IS_CPP98 (__cplusplus >= 199711L) /** * @brief The `constexpr` keyword has different applicability scope in different C++ versions. * Useful for STL conversion operators, as several `std::string` members are `constexpr` in C++20. 
*/ -#if SZ_DETECT_CPP20 +#if _SZ_IS_CPP20 #define sz_constexpr_if_cpp20 constexpr #else #define sz_constexpr_if_cpp20 @@ -50,7 +50,7 @@ #include #include #include -#if SZ_DETECT_CPP_17 && __cpp_lib_string_view +#if _SZ_IS_CPP17 && __cpp_lib_string_view #include #endif #endif @@ -398,7 +398,7 @@ struct end_sentinel_type {}; struct include_overlaps_type {}; struct exclude_overlaps_type {}; -#if SZ_DETECT_CPP_17 +#if _SZ_IS_CPP17 inline static constexpr end_sentinel_type end_sentinel; inline static constexpr include_overlaps_type include_overlaps; inline static constexpr exclude_overlaps_type exclude_overlaps; @@ -1265,7 +1265,7 @@ class basic_string_slice { return os.write(str.data(), str.size()); } -#if SZ_DETECT_CPP_17 && __cpp_lib_string_view +#if _SZ_IS_CPP17 && __cpp_lib_string_view template ::value, int>::type = 0> sz_constexpr_if_cpp20 basic_string_slice(std::string_view const &other) noexcept @@ -1496,7 +1496,7 @@ class basic_string_slice { sz_equal(data() + other.first.size(), other.second.data(), other.second.size()) == sz_true_k; } -#if SZ_DETECT_CPP20 +#if _SZ_IS_CPP20 /** @brief Computes the lexicographic ordering between this and the ::other string. */ std::strong_ordering operator<=>(string_view other) const noexcept { @@ -2175,7 +2175,7 @@ class basic_string { return os.write(str.data(), str.size()); } -#if SZ_DETECT_CPP_17 && __cpp_lib_string_view +#if _SZ_IS_CPP17 && __cpp_lib_string_view basic_string(std::string_view other) noexcept(false) : basic_string(other.data(), other.size()) {} basic_string &operator=(std::string_view other) noexcept(false) { return assign({other.data(), other.size()}); } @@ -2421,7 +2421,7 @@ class basic_string { bool operator==(string_view other) const noexcept { return view() == other; } bool operator==(const_pointer other) const noexcept { return view() == string_view(other); } -#if SZ_DETECT_CPP20 +#if _SZ_IS_CPP20 /** @brief Computes the lexicographic ordering between this and the ::other string. */ std::strong_ordering operator<=>(basic_string const &other) const noexcept { return view() <=> other.view(); } diff --git a/include/stringzilla/types.h b/include/stringzilla/types.h index 8002b8a0..c34289fd 100644 --- a/include/stringzilla/types.h +++ b/include/stringzilla/types.h @@ -703,12 +703,12 @@ SZ_PUBLIC void sz_sequence_from_u64tape( // #define SZ_CACHE_LINE_WIDTH (64) // bytes /** - * @brief Similar to `assert`, the `_sz_assert` is used in the SZ_DEBUG mode - * to check the invariants of the library. It's a no-op in the SZ_RELEASE mode. + * @brief Similar to `assert`, the `_sz_assert` is used in the `SZ_DEBUG` mode + * to check the invariants of the library. It's a no-op in the "Release" mode. * @note If you want to catch it, put a breakpoint at @b `__GI_exit` */ #if SZ_DEBUG && defined(SZ_AVOID_LIBC) && !SZ_AVOID_LIBC && !defined(SZ_PIC) -#include // `fprintf` +#include // `fprintf`, `stderr` #include // `EXIT_FAILURE` SZ_PUBLIC void _sz_assert_failure(char const *condition, char const *file, int line) { fprintf(stderr, "Assertion failed: %s, in file %s, line %d\n", condition, file, line); diff --git a/scripts/test.cpp b/scripts/test.cpp index 3f9add3b..9ae7e14c 100644 --- a/scripts/test.cpp +++ b/scripts/test.cpp @@ -38,7 +38,7 @@ #include // Baseline #include // Baseline -#if !SZ_DETECT_CPP_11 +#if !_SZ_IS_CPP11 #error "This test requires C++11 or later." 
#endif @@ -52,7 +52,7 @@ using sz::literals::operator""_sz; * Instantiate all the templates to make the symbols visible and also check * for weird compilation errors on uncommon paths. */ -#if SZ_DETECT_CPP_17 && __cpp_lib_string_view +#if _SZ_IS_CPP17 && __cpp_lib_string_view template class std::basic_string_view; #endif template class sz::basic_string_slice; @@ -412,7 +412,7 @@ static void test_stl_compatibility_for_reads() { assert(str("b") >= str("a")); assert(str("a") < str("aa")); -#if SZ_DETECT_CPP20 && __cpp_lib_three_way_comparison +#if _SZ_IS_CPP20 && __cpp_lib_three_way_comparison // Spaceship operator instead of conventional comparions. assert((str("a") <=> str("b")) == std::strong_ordering::less); assert((str("b") <=> str("a")) == std::strong_ordering::greater); @@ -455,7 +455,7 @@ static void test_stl_compatibility_for_reads() { assert(str("hello world").compare(6, 5, "worlds", 5) == 0); // Substring "world" in both strings assert(str("hello world").compare(6, 5, "worlds", 6) < 0); // Substring "world" is less than "worlds" -#if SZ_DETECT_CPP20 && __cpp_lib_starts_ends_with +#if _SZ_IS_CPP20 && __cpp_lib_starts_ends_with // Prefix and suffix checks against strings. assert(str("https://cppreference.com").starts_with(str("http")) == true); assert(str("https://cppreference.com").starts_with(str("ftp")) == false); @@ -475,7 +475,7 @@ static void test_stl_compatibility_for_reads() { assert(str("string_view").ends_with("View") == false); #endif -#if SZ_DETECT_CPP_23 && __cpp_lib_string_contains +#if _SZ_IS_CPP23 && __cpp_lib_string_contains // Checking basic substring presence. assert(str("hello").contains(str("ell")) == true); assert(str("hello").contains(str("oll")) == false); @@ -506,7 +506,7 @@ static void test_stl_compatibility_for_reads() { assert(std::hash {}("hello") != 0); assert_scoped(std::ostringstream os, os << str("hello"), os.str() == "hello"); -#if SZ_DETECT_CPP14 +#if _SZ_IS_CPP14 // Comparison function objects are a C++14 feature. assert(std::equal_to {}("hello", "world") == false); assert(std::less {}("hello", "world") == true); @@ -660,7 +660,7 @@ static void test_stl_conversions() { sz_unused(sz); sz_unused(szv); } -#if SZ_DETECT_CPP_17 && __cpp_lib_string_view +#if _SZ_IS_CPP17 && __cpp_lib_string_view // From STL `string_view` to StringZilla and vice-versa. 
{ std::string_view stl {"hello"};
@@ -1179,7 +1179,7 @@ static void test_search() {
assert(rsplits[4] == "");
}
-#if SZ_DETECT_CPP_17 && __cpp_lib_string_view
+#if _SZ_IS_CPP17 && __cpp_lib_string_view
/**
* Evaluates the correctness of a "matcher", searching for all the occurrences of the `needle_stl`
@@ -1582,7 +1582,7 @@ int main(int argc, char const **argv) {
test_replacements();
// Compatibility with STL
-#if SZ_DETECT_CPP_17 && __cpp_lib_string_view
+#if _SZ_IS_CPP17 && __cpp_lib_string_view
test_stl_compatibility_for_reads();
#endif
test_stl_compatibility_for_reads();
@@ -1607,7 +1607,7 @@ int main(int argc, char const **argv) {
test_stl_conversions();
test_comparisons();
test_search();
-#if SZ_DETECT_CPP_17 && __cpp_lib_string_view
+#if _SZ_IS_CPP17 && __cpp_lib_string_view
test_search_with_misaligned_repetitions();
#endif

From 645539b468f3c2902061425684d9b002c43a14f7 Mon Sep 17 00:00:00 2001
From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com>
Date: Sun, 8 Dec 2024 20:12:39 +0000
Subject: [PATCH 052/751] Fix: Overriding LibC in 32-bit Windows

Imported from #169

Co-authored-by: ashbob999 <32575256+ashbob999@users.noreply.github.com>
---
 c/lib.c | 49 ++++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 42 insertions(+), 7 deletions(-)

diff --git a/c/lib.c b/c/lib.c
index 8a0a75b9..d829e379 100644
--- a/c/lib.c
+++ b/c/lib.c
@@ -202,7 +202,7 @@ __attribute__((aligned(64))) static sz_implementations_t sz_dispatch_table;
 * @brief Initializes a global static "virtual table" of supported backends
 * Run it just once to avoiding unnecessary `if`-s.
 */
-static void sz_dispatch_table_init(void) {
+SZ_DYNAMIC void sz_dispatch_table_init(void) {
 sz_implementations_t *impl = &sz_dispatch_table;
 sz_capability_t caps = sz_capabilities();
 sz_unused(caps); //< Unused when compiling on pre-SIMD machines.
@@ -294,9 +294,17 @@ static void sz_dispatch_table_init(void) {
 }
 #if defined(_MSC_VER)
-#pragma section(".CRT$XCU", read)
-__declspec(allocate(".CRT$XCU")) void (*_sz_dispatch_table_init)() = sz_dispatch_table_init;
+/*
+ * Makes sure the `sz_dispatch_table_init` function is called at startup, from either an executable or when loading
+ * a DLL. The section name must be no more than 8 characters long, and must be between .CRT$XCA and .CRT$XCZ
+ * alphabetically (exclusive). The Microsoft C++ compiler puts C++ initialisation code in .CRT$XCU, so avoid that
+ * section: https://learn.microsoft.com/en-us/cpp/c-runtime-library/crt-initialization?view=msvc-170
+ */
+#pragma comment(linker, "/INCLUDE:_sz_dispatch_table_init")
+#pragma section(".CRT$XCS", read)
+__declspec(allocate(".CRT$XCS")) void (*_sz_dispatch_table_init)() = sz_dispatch_table_init;
+/* Called either from CRT code or our own `_DLLMainCRTStartup`, when a DLL is loaded. */
 BOOL WINAPI DllMain(HINSTANCE hints, DWORD forward_reason, LPVOID lp) {
 switch (forward_reason) {
 case DLL_PROCESS_ATTACH:
@@ -309,6 +317,14 @@ BOOL WINAPI DllMain(HINSTANCE hints, DWORD forward_reason, LPVOID lp) {
 return TRUE;
 }
+#if SZ_AVOID_LIBC
+/* Called when the DLL is loaded, and there is no CRT code.
*/ +BOOL WINAPI _DllMainCRTStartup(HINSTANCE hints, DWORD forward_reason, LPVOID lp) { + DllMain(hints, forward_reason, lp); + return TRUE; +} +#endif + #else __attribute__((constructor)) static void sz_dispatch_table_init_on_gcc_or_clang(void) { sz_dispatch_table_init(); } #endif @@ -451,14 +467,20 @@ SZ_DYNAMIC void sz_generate( // } // Provide overrides for the libc mem* functions -#if SZ_OVERRIDE_LIBC && !(defined(__CYGWIN__)) +#if SZ_OVERRIDE_LIBC && !defined(__CYGWIN__) -// SZ_DYNAMIC can't be use here for MSVC, because MSVC complains about different linkage (C2375), probably due to to the -// CRT headers specifying the function as __declspec(dllimport), there might be a combination of defines that works. But -// for now they will be manually exported using linker flags +// SZ_DYNAMIC can't be use here for MSVC, because MSVC complains about different linkage (C2375), probably due +// to to the CRT headers specifying the function as `__declspec(dllimport)`, there might be a combination of +// defines that works. But for now they will be manually exported using linker flags. +// Also when building for 32-bit we must add an underscore to the exported function name, because that's +// how `__cdecl` functions are decorated in MSVC: https://stackoverflow.com/questions/62753691) #if defined(_MSC_VER) +#if SZ_DETECT_64_BIT #pragma comment(linker, "/export:memchr") +#else +#pragma comment(linker, "/export:_memchr") +#endif void *__cdecl memchr(void const *s, int c_wide, size_t n) { #else SZ_DYNAMIC void *memchr(void const *s, int c_wide, size_t n) { @@ -468,7 +490,11 @@ SZ_DYNAMIC void *memchr(void const *s, int c_wide, size_t n) { } #if defined(_MSC_VER) +#if SZ_DETECT_64_BIT #pragma comment(linker, "/export:memcpy") +#else +#pragma comment(linker, "/export:_memcpy") +#endif void *__cdecl memcpy(void *dest, void const *src, size_t n) { #else SZ_DYNAMIC void *memcpy(void *dest, void const *src, size_t n) { @@ -478,7 +504,11 @@ SZ_DYNAMIC void *memcpy(void *dest, void const *src, size_t n) { } #if defined(_MSC_VER) +#if SZ_DETECT_64_BIT #pragma comment(linker, "/export:memmove") +#else +#pragma comment(linker, "/export:_memmove") +#endif void *__cdecl memmove(void *dest, void const *src, size_t n) { #else SZ_DYNAMIC void *memmove(void *dest, void const *src, size_t n) { @@ -488,7 +518,11 @@ SZ_DYNAMIC void *memmove(void *dest, void const *src, size_t n) { } #if defined(_MSC_VER) +#if SZ_DETECT_64_BIT #pragma comment(linker, "/export:memset") +#else +#pragma comment(linker, "/export:_memset") +#endif void *__cdecl memset(void *s, int c, size_t n) { #else SZ_DYNAMIC void *memset(void *s, int c, size_t n) { @@ -511,5 +545,6 @@ SZ_DYNAMIC void memfrob(void *s, size_t n) { char const *base64 = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; sz_generate(base64, 64, s, n, SZ_NULL, SZ_NULL); } + #endif #endif // SZ_OVERRIDE_LIBC From 660923e6d1be94a0cf0e2e97a8d0cebf3af2462f Mon Sep 17 00:00:00 2001 From: Alex Bondarev <44079602+alexbarev@users.noreply.github.com> Date: Mon, 9 Dec 2024 04:30:41 +0400 Subject: [PATCH 053/751] Test: Correct edge cases in ASCII tests --- scripts/test.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/test.cpp b/scripts/test.cpp index e8123995..4aa46766 100644 --- a/scripts/test.cpp +++ b/scripts/test.cpp @@ -154,7 +154,7 @@ static void test_ascii_utilities() { assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789").is_alnum()); assert(!str("abc!").is_alnum()); - assert(!str("").is_ascii()); + 
assert(str("").is_ascii()); assert(str("\x00x7F").is_ascii()); assert(!str("abc123🔥").is_ascii()); @@ -175,9 +175,9 @@ static void test_ascii_utilities() { assert(str("ABCDEFGHIJKLMNOPQRSTUVWXYZ").is_upper()); assert(!str("ABCa").is_upper()); - assert(!str("").is_printable()); + assert(str("").is_printable()); assert(str("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!@#$%^&*()_+").is_printable()); - assert(!str("012\n").is_printable()); + assert(!str("012🔥").is_printable()); } inline void expect_equality(char const *a, char const *b, std::size_t size) { From 064829ae0ff2501aba404afaacfd7826586377bf Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Mon, 9 Dec 2024 07:55:59 +0000 Subject: [PATCH 054/751] Improve: Ignore 40 commits in blame --- .git-blame-ignore-revs | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 .git-blame-ignore-revs diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 00000000..3d26edb4 --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,40 @@ +6512f1d129aeddc8601c9df7332c135038914b68 +fc9e5d61e5fb1c5031f6f10920f6b50e2530de1e +ad2af78f8651870727c5b39e1fea2eff26d71d2f +49e8d9d240993bdf68715a9c87824a032752798d +fc408fa0a0f2d947c610568bd7a5c4a60ecca443 +b835051c09a0ecfc420932de444f3c6839610764 +1ba7982559111d4fc9b58caa7bc7aa1c6e64257c +5b55e19d1378c61da88309b30a38f9cf7c64bf79 +be4c63d926c8628451726863e4d14dbd1ea374dd +8b401bd41e4bd9c29c8fad9a5b83d8232efa50c7 +295d49a38d66b08075357ac829ad66d80b5edab0 +2a1fcd113d217e3124f6501c38e93a318aca37f0 +2f7652141bd8dc3c2c38ab34321567bfcdb91d93 +9e3180019acffe5261f0a1713b4ea324dca79ea0 +45e57eefd796841cbd14ee7f75ec42b42b5bde0c +66778d6b2b3aa0eed27e32fbdceef79b8c54eda5 +c357c3ea756523d3bcc8d8f25068ad08aef5456d +9b1948b3771c21dd56954e5f43301ca8a0b8b1a9 +cbfe5c7ac6371047eae88621b092297474d0b82a +085d2d3c8b99e0f90d320dd027040e554e410929 +3464cb428ae9a8721ab82a8c4bff214aa9ce6254 +5d0d2da422c7df96f9613ec843cd47c579a2edce +89c46810c2f9bfafa31f8592339f9a1b45dcc245 +3f9c248fbf59add2246055462e8fc19dc9f1693b +e23c35ff2c2d4ccb752f4ffbf9b6f39a1677b532 +7fdc58fd26e06c41052287d47a9c729c068a95ca +10d829efcb8ed4cfa5f2db4050f8403184484423 +d74e5dca2e62eb0078cb2ebacc0dac2b8bb92d54 +1f60e6d7c81f0e285e594eb63fee6119e05a3e69 +a6768af38b40307fe66364403f141c285b3e164c +08d0a20d35d3b29a44b9c8a826d53435c3ef839c +9e9f2567d052d635722921a1d70ec63d69ec6669 +974ed78822dc0b519dd61bc1c4dc18d59fe4ad15 +b007ba571860e1d3737d1478c7f8d66ae1839e36 +14ba3bf3c43408438a7de9ad57118c747c1347b1 +9e577be71dcd2e20854bf55f08c54854b3e82989 +8cb0742b2d1b31b61fac5272f17017953c6677e6 +bd547453122e9f8565e5be15f137e7b0de37caca +22e3d1e34d62d68c1e89df7c8bdc201faa18a9de +ecb377541d0c706cf8997faff4f026b07e3f76f3 From e20d207ce5bf7b25da4740ca511a0e5ea44af41f Mon Sep 17 00:00:00 2001 From: Alex <44079602+alexbarev@users.noreply.github.com> Date: Mon, 9 Dec 2024 11:56:46 +0400 Subject: [PATCH 055/751] Fix: Correct `basic_charset` operator (#203) --- include/stringzilla/stringzilla.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/stringzilla/stringzilla.hpp b/include/stringzilla/stringzilla.hpp index c705dae6..85589909 100644 --- a/include/stringzilla/stringzilla.hpp +++ b/include/stringzilla/stringzilla.hpp @@ -309,7 +309,7 @@ class basic_charset { basic_charset result = *this; result.bitset_._u64s[0] |= other.bitset_._u64s[0], result.bitset_._u64s[1] |= other.bitset_._u64s[1], result.bitset_._u64s[2] |= 
other.bitset_._u64s[2], result.bitset_._u64s[3] |= other.bitset_._u64s[3]; - return *this; + return result; } inline basic_charset &add(char_type c) noexcept { From 864ee03fdbbba0b71b982cd9c6a206b9a7f96dee Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Mon, 9 Dec 2024 08:05:00 +0000 Subject: [PATCH 056/751] Fix: Initializing `basic_charset` Closes #200 --- include/stringzilla/stringzilla.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/stringzilla/stringzilla.hpp b/include/stringzilla/stringzilla.hpp index a80da804..43869f08 100644 --- a/include/stringzilla/stringzilla.hpp +++ b/include/stringzilla/stringzilla.hpp @@ -283,7 +283,7 @@ class basic_charset { template explicit basic_charset(char_type const (&chars)[count_characters]) noexcept : basic_charset() { static_assert(count_characters > 0, "Character array cannot be empty"); - for (std::size_t i = 0; i < count_characters - 1; ++i) { // count_characters - 1 to exclude the null terminator + for (std::size_t i = 0; i != count_characters; ++i) { char_type c = chars[i]; bitset_._u64s[sz_bitcast(sz_u8_t, c) >> 6] |= (1ull << (sz_bitcast(sz_u8_t, c) & 63u)); } @@ -292,7 +292,7 @@ class basic_charset { template explicit basic_charset(std::array const &chars) noexcept : basic_charset() { static_assert(count_characters > 0, "Character array cannot be empty"); - for (std::size_t i = 0; i < count_characters - 1; ++i) { // count_characters - 1 to exclude the null terminator + for (std::size_t i = 0; i != count_characters; ++i) { char_type c = chars[i]; bitset_._u64s[sz_bitcast(sz_u8_t, c) >> 6] |= (1ull << (sz_bitcast(sz_u8_t, c) & 63u)); } From c99daf3fe04b6dd5dc2ac74803e868b2df056b31 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Mon, 9 Dec 2024 08:06:54 +0000 Subject: [PATCH 057/751] Docs: Formatting docstring --- scripts/test.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/test.cpp b/scripts/test.cpp index 4aa46766..72379f78 100644 --- a/scripts/test.cpp +++ b/scripts/test.cpp @@ -138,8 +138,8 @@ static void test_arithmetical_utilities() { } /** - * @brief Tests various ASCII-based methods (e.g., is_alpha, is_digit) - * provided by `sz::string` and `sz::string_view`. + * @brief Tests various ASCII-based methods (e.g., `is_alpha`, `is_digit`) + * provided by `sz::string` and `sz::string_view`. */ template static void test_ascii_utilities() { From 084d6534d30d668edc6d7790f8aaec438832f12b Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Tue, 10 Dec 2024 10:34:34 +0000 Subject: [PATCH 058/751] Fix: Linking `stderr` Co-authored-by: Alex Bondarev <44079602+alexbarev@users.noreply.github.com> --- include/stringzilla/types.h | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/include/stringzilla/types.h b/include/stringzilla/types.h index c34289fd..b9e202ae 100644 --- a/include/stringzilla/types.h +++ b/include/stringzilla/types.h @@ -181,6 +181,12 @@ #include // `uint8_t` #endif +/* The headers needed for the `_sz_assert_failure` function. */ +#if SZ_DEBUG && defined(SZ_AVOID_LIBC) && !SZ_AVOID_LIBC && !defined(SZ_PIC) +#include // `fprintf`, `stderr` +#include // `EXIT_FAILURE` +#endif + /* Compile-time hardware features detection. * All of those can be controlled by the user. 
*/ @@ -708,8 +714,6 @@ SZ_PUBLIC void sz_sequence_from_u64tape( // * @note If you want to catch it, put a breakpoint at @b `__GI_exit` */ #if SZ_DEBUG && defined(SZ_AVOID_LIBC) && !SZ_AVOID_LIBC && !defined(SZ_PIC) -#include // `fprintf`, `stderr` -#include // `EXIT_FAILURE` SZ_PUBLIC void _sz_assert_failure(char const *condition, char const *file, int line) { fprintf(stderr, "Assertion failed: %s, in file %s, line %d\n", condition, file, line); exit(EXIT_FAILURE); From 48e0913944f109703ebbfcb7f14e1b7398544af7 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Tue, 10 Dec 2024 11:21:17 +0000 Subject: [PATCH 059/751] Fix: Skylake dispatch --- include/stringzilla/memory.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/include/stringzilla/memory.h b/include/stringzilla/memory.h index c17f031f..d8db210b 100644 --- a/include/stringzilla/memory.h +++ b/include/stringzilla/memory.h @@ -1256,7 +1256,7 @@ SZ_PUBLIC void sz_copy_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length) #pragma region Core Functionality SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { -#if SZ_USE_ICE +#if SZ_USE_SKYLAKE sz_copy_skylake(target, source, length); #elif SZ_USE_HASWELL sz_copy_haswell(target, source, length); @@ -1268,7 +1268,7 @@ SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { } SZ_DYNAMIC void sz_move(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { -#if SZ_USE_ICE +#if SZ_USE_SKYLAKE sz_move_skylake(target, source, length); #elif SZ_USE_HASWELL sz_move_haswell(target, source, length); @@ -1280,7 +1280,7 @@ SZ_DYNAMIC void sz_move(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { } SZ_DYNAMIC void sz_fill(sz_ptr_t target, sz_size_t length, sz_u8_t value) { -#if SZ_USE_ICE +#if SZ_USE_SKYLAKE sz_fill_skylake(target, length, value); #elif SZ_USE_HASWELL sz_fill_haswell(target, length, value); From 749b0d86e5cd41df053ea214ef95000bbb90543f Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Tue, 10 Dec 2024 11:22:55 +0000 Subject: [PATCH 060/751] Fix: Bounded Levenshtein returns The new uniform behavior across the project is to return a value different from `SZ_SIZE_MAX` when the limit is reached, to differentiate memory allocation and other errors. --- include/stringzilla/similarity.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/include/stringzilla/similarity.h b/include/stringzilla/similarity.h index 943f7f35..0b119127 100644 --- a/include/stringzilla/similarity.h +++ b/include/stringzilla/similarity.h @@ -408,7 +408,7 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_wagner_fisher_serial( // /* If the minimum distance in this row exceeded the bound, return early */ \ if (min_distance >= bound) { \ alloc->free(buffer, buffer_length, alloc->handle); \ - return bound; \ + return longer_length + 1; \ } \ _distance_t *temporary = previous_distances; \ previous_distances = current_distances; \ @@ -416,7 +416,7 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_wagner_fisher_serial( // } \ sz_size_t result = previous_distances[shorter_length]; \ alloc->free(buffer, buffer_length, alloc->handle); \ - return sz_min_of_two(result, bound); + return result; // Dispatch the actual computation. if (!bound) { @@ -735,7 +735,7 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto63_ice( // // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound. 
__mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm); - if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) return SZ_SIZE_MAX; + if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) return longer_length + 1; } // Now let's handle the anti-diagonal band of the matrix, between the top and bottom triangles. @@ -766,7 +766,7 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto63_ice( // // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound. __mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm); - if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) return SZ_SIZE_MAX; + if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) return longer_length + 1; } // Now let's handle the bottom right triangle. @@ -790,7 +790,7 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto63_ice( // // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound. __mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm); - if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) return SZ_SIZE_MAX; + if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) return longer_length + 1; // In every following iterations we take use a shorter prefix of each register, // but we don't need to update the `next_diagonal_mask` anymore... except for the early exit. From 2007d494c019448d440bd8a548c62246bab82f94 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Tue, 10 Dec 2024 11:23:37 +0000 Subject: [PATCH 061/751] Fix: `sz_u512_vec_t` members visibility --- include/stringzilla/types.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/stringzilla/types.h b/include/stringzilla/types.h index b9e202ae..f8fe0c9a 100644 --- a/include/stringzilla/types.h +++ b/include/stringzilla/types.h @@ -545,10 +545,10 @@ typedef union sz_u256_vec_t { * as well as 4x XMM registers or 2x YMM registers or 1x ZMM register. 
*/ typedef union sz_u512_vec_t { -#if SZ_USE_ICE +#if SZ_USE_SKYLAKE || SZ_USE_ICE __m512i zmm; #endif -#if SZ_USE_HASWELL +#if SZ_USE_HASWELL || SZ_USE_SKYLAKE || SZ_USE_ICE __m256i ymms[2]; __m128i xmms[4]; #endif From f3811d70ee0725ea4d2395a1c7fd1125dac3bc3d Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Tue, 10 Dec 2024 11:24:37 +0000 Subject: [PATCH 062/751] Make: Library namespaced aliases --- CMakeLists.txt | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 81e9bbaa..7914aa0e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -111,19 +111,9 @@ endif() # Configuration include(GNUInstallDirs) -set(STRINGZILLA_TARGET_NAME ${PROJECT_NAME}) set(STRINGZILLA_INCLUDE_BUILD_DIR "${PROJECT_SOURCE_DIR}/include/") set(STRINGZILLA_INCLUDE_INSTALL_DIR "${CMAKE_INSTALL_INCLUDEDIR}") -# Define our library -add_library(${STRINGZILLA_TARGET_NAME} INTERFACE) -add_library(${PROJECT_NAME}::${STRINGZILLA_TARGET_NAME} ALIAS ${STRINGZILLA_TARGET_NAME}) - -target_include_directories( - ${STRINGZILLA_TARGET_NAME} - INTERFACE $ - $) - if(${CMAKE_VERSION} VERSION_EQUAL 3.13 OR ${CMAKE_VERSION} VERSION_GREATER 3.13) include(CTest) @@ -142,7 +132,6 @@ function(set_compiler_flags target cpp_standard target_arch) get_target_property(target_type ${target} TYPE) target_include_directories(${target} PRIVATE scripts) - target_link_libraries(${target} PRIVATE ${STRINGZILLA_TARGET_NAME}) # Set output directory for single-configuration generators (like Make) set_target_properties(${target} PROPERTIES @@ -278,6 +267,7 @@ endfunction() function(define_launcher exec_name source cpp_standard target_arch) add_executable(${exec_name} ${source}) set_compiler_flags(${exec_name} ${cpp_standard} "${target_arch}") + target_link_libraries(${exec_name} PRIVATE stringzilla_header) add_test(NAME ${exec_name} COMMAND ${exec_name}) endfunction() @@ -319,10 +309,20 @@ if(${STRINGZILLA_BUILD_TEST}) endif() endif() +# Define our libraries, first the header-only version +add_library(stringzilla_header INTERFACE) +add_library(${PROJECT_NAME}::stringzilla_header ALIAS stringzilla_header) +target_include_directories( + stringzilla_header + INTERFACE $ + $) + + if(${STRINGZILLA_BUILD_SHARED}) function(define_shared target) add_library(${target} SHARED c/lib.c) + add_library(${PROJECT_NAME}::${target} ALIAS ${target}) set_target_properties(${target} PROPERTIES VERSION ${PROJECT_VERSION} From bd7054ea9d5d0810f303a4a56fa0ee25a53410dd Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Tue, 10 Dec 2024 11:46:57 +0000 Subject: [PATCH 063/751] Fix: Masks back to using `BZHI` --- include/stringzilla/types.h | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/include/stringzilla/types.h b/include/stringzilla/types.h index f8fe0c9a..57ff7124 100644 --- a/include/stringzilla/types.h +++ b/include/stringzilla/types.h @@ -800,17 +800,6 @@ SZ_INTERNAL sz_u64_t sz_u64_bytes_reverse(sz_u64_t val) { return __builtin_bswap SZ_INTERNAL sz_u32_t sz_u32_bytes_reverse(sz_u32_t val) { return __builtin_bswap32(val); } #endif -/* - */ -SZ_INTERNAL sz_u16_t _sz_u16_mask_until(sz_size_t n) { return (0x0001u << n) - 1u; } -SZ_INTERNAL sz_u32_t _sz_u32_mask_until(sz_size_t n) { return (0x00000001u << n) - 1u; } -SZ_INTERNAL sz_u64_t _sz_u64_mask_until(sz_size_t n) { return (0x0000000000000001ull << n) - 1ull; } -SZ_INTERNAL sz_u16_t 
_sz_u16_clamp_mask_until(sz_size_t n) { return n < 16 ? _sz_u16_mask_until(n) : 0xFFFFu; } -SZ_INTERNAL sz_u32_t _sz_u32_clamp_mask_until(sz_size_t n) { return n < 32 ? _sz_u32_mask_until(n) : 0xFFFFFFFFu; } -SZ_INTERNAL sz_u64_t _sz_u64_clamp_mask_until(sz_size_t n) { - return n < 64 ? _sz_u64_mask_until(n) : 0xFFFFFFFFFFFFFFFFull; -} - SZ_INTERNAL sz_u64_t sz_u64_rotl(sz_u64_t x, sz_u64_t r) { return (x << r) | (x >> (64 - r)); } /** @@ -865,6 +854,22 @@ SZ_INTERNAL sz_i32_t sz_i32_min_of_two(sz_i32_t x, sz_i32_t y) { return y + ((x /** @brief Branchless minimum function for two signed 32-bit integers. */ SZ_INTERNAL sz_i32_t sz_i32_max_of_two(sz_i32_t x, sz_i32_t y) { return x - ((x - y) & (x - y) >> 31); } +/* In AVX-512 we actively use masked operations and the "K mask registers". + * Producing a mask for the first N elements of a sequence can be done using the `1 << N - 1` idiom. + * It, however, induces undefined behavior if `N == 64` or `N == 32` on 64-bit or 32-bit systems respectively. + * Alternatively, the BZHI instruction can be used to clear the bits above N. + */ +#if SZ_USE_SKYLAKE || SZ_USE_ICE +SZ_INTERNAL __mmask16 _sz_u16_mask_until(sz_size_t n) { return (__mmask16)_bzhi_u32(0xFFFFu, n); } +SZ_INTERNAL __mmask32 _sz_u32_mask_until(sz_size_t n) { return (__mmask32)_bzhi_u64(0xFFFFFFFFu, n); } +SZ_INTERNAL __mmask64 _sz_u64_mask_until(sz_size_t n) { return (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFFull, n); } +SZ_INTERNAL __mmask16 _sz_u16_clamp_mask_until(sz_size_t n) { return n < 16 ? _sz_u16_mask_until(n) : 0xFFFFu; } +SZ_INTERNAL __mmask32 _sz_u32_clamp_mask_until(sz_size_t n) { return n < 32 ? _sz_u32_mask_until(n) : 0xFFFFFFFFu; } +SZ_INTERNAL __mmask64 _sz_u64_clamp_mask_until(sz_size_t n) { + return n < 64 ? _sz_u64_mask_until(n) : 0xFFFFFFFFFFFFFFFFull; +} +#endif + /** * @brief Byte-level equality comparison between two 64-bit integers. * @return 64-bit integer, where every top bit in each byte signifies a match. From fa47debf2f70d526be4625309f52d5a8a5d5643f Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Wed, 11 Dec 2024 14:18:27 +0000 Subject: [PATCH 064/751] Fix: BMI flags for `BZHI` --- include/stringzilla/types.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/include/stringzilla/types.h b/include/stringzilla/types.h index 57ff7124..a170b6b0 100644 --- a/include/stringzilla/types.h +++ b/include/stringzilla/types.h @@ -860,6 +860,9 @@ SZ_INTERNAL sz_i32_t sz_i32_max_of_two(sz_i32_t x, sz_i32_t y) { return x - ((x * Alternatively, the BZHI instruction can be used to clear the bits above N. */ #if SZ_USE_SKYLAKE || SZ_USE_ICE +#pragma GCC push_options +#pragma GCC target("bmi", "bmi2") +#pragma clang attribute push(__attribute__((target("bmi,bmi2"))), apply_to = function) SZ_INTERNAL __mmask16 _sz_u16_mask_until(sz_size_t n) { return (__mmask16)_bzhi_u32(0xFFFFu, n); } SZ_INTERNAL __mmask32 _sz_u32_mask_until(sz_size_t n) { return (__mmask32)_bzhi_u64(0xFFFFFFFFu, n); } SZ_INTERNAL __mmask64 _sz_u64_mask_until(sz_size_t n) { return (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFFull, n); } @@ -868,6 +871,8 @@ SZ_INTERNAL __mmask32 _sz_u32_clamp_mask_until(sz_size_t n) { return n < 32 ? _s SZ_INTERNAL __mmask64 _sz_u64_clamp_mask_until(sz_size_t n) { return n < 64 ? 
_sz_u64_mask_until(n) : 0xFFFFFFFFFFFFFFFFull; } +#pragma GCC pop_options +#pragma clang attribute pop #endif /** From d9557d35078e06e61e3667d2f01b367940189a96 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Wed, 11 Dec 2024 14:18:43 +0000 Subject: [PATCH 065/751] Improve: Faster `levenshtein_baseline` --- scripts/test.hpp | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/scripts/test.hpp b/scripts/test.hpp index 9f9abe6b..261f90a5 100644 --- a/scripts/test.hpp +++ b/scripts/test.hpp @@ -67,29 +67,29 @@ inline std::string random_string(std::size_t length, char const *alphabet, std:: * Allocates a new matrix on every call, with rows potentially scattered around memory. */ inline std::size_t levenshtein_baseline(char const *s1, std::size_t len1, char const *s2, std::size_t len2) { - std::vector> dp(len1 + 1, std::vector(len2 + 1)); + std::size_t const rows = len1 + 1; + std::size_t const cols = len2 + 1; + std::vector matrix_buffer(rows * cols); // Initialize the borders of the matrix. - for (std::size_t i = 0; i <= len1; ++i) dp[i][0] = i; - for (std::size_t j = 0; j <= len2; ++j) dp[0][j] = j; + for (std::size_t i = 0; i < rows; ++i) matrix_buffer[i * cols + 0] /* [i][0] in 2D */ = i; + for (std::size_t j = 0; j < cols; ++j) matrix_buffer[0 * cols + j] /* [0][j] in 2D */ = j; - for (std::size_t i = 1; i <= len1; ++i) { - for (std::size_t j = 1; j <= len2; ++j) { + for (std::size_t i = 1; i < rows; ++i) { + std::size_t const *last_row = &matrix_buffer[(i - 1) * cols]; + std::size_t *row = &matrix_buffer[i * cols]; + for (std::size_t j = 1; j < cols; ++j) { std::size_t cost = (s1[i - 1] == s2[j - 1]) ? 0 : 1; - // dp[i][j] is the minimum of deletion, insertion, or substitution - dp[i][j] = std::min({ - dp[i - 1][j] + 1, // Deletion - dp[i][j - 1] + 1, // Insertion - dp[i - 1][j - 1] + cost // Substitution - }); + std::size_t deletion_or_insertion = std::min(last_row[j], row[j - 1]) + 1; + row[j] = std::min(deletion_or_insertion, last_row[j - 1] + cost); } } - return dp[len1][len2]; + return matrix_buffer.back(); } /** - * @brief Produces a substitution cost matrix for the Needlemann-Wunsch alignment score, + * @brief Produces a substitution cost matrix for the Needleman-Wunsch alignment score, * that would yield the same result as the negative Levenshtein distance. 
*/ inline std::vector unary_substitution_costs() { From d20e589a56921679d0642350df8026df3240b54f Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Wed, 11 Dec 2024 14:24:55 +0000 Subject: [PATCH 066/751] Fix: Minor dispatch issues --- .vscode/settings.json | 3 ++- CMakeLists.txt | 5 ++++- include/stringzilla/hash.h | 8 ++++---- include/stringzilla/similarity.h | 10 +++++++--- scripts/bench_memory.cpp | 2 +- scripts/bench_similarity.cpp | 22 ++++++++++++++++++---- scripts/bench_token.cpp | 4 ++-- scripts/test.cpp | 13 ++++++++++--- 8 files changed, 48 insertions(+), 19 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 9d0e1b53..051fc5c8 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -268,7 +268,8 @@ "xtree": "cpp", "xutility": "cpp", "errno.h": "c", - "text_encoding": "cpp" + "text_encoding": "cpp", + "ranges": "cpp" }, "python.pythonPath": "~/miniconda3/bin/python" } \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 7914aa0e..df90ad80 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -367,11 +367,14 @@ if(${STRINGZILLA_BUILD_SHARED}) define_shared(stringzilla_shared) target_compile_definitions(stringzilla_shared PRIVATE "SZ_AVOID_LIBC=0") target_compile_definitions(stringzilla_shared PRIVATE "SZ_OVERRIDE_LIBC=1") - + target_include_directories(stringzilla_shared PUBLIC include) + + # Try compiling a version without linking the LibC define_shared(stringzilla_bare) target_compile_definitions(stringzilla_bare PRIVATE "SZ_AVOID_LIBC=1") target_compile_definitions(stringzilla_bare PRIVATE "SZ_OVERRIDE_LIBC=1") + target_include_directories(stringzilla_bare PUBLIC include) # Avoid built-ins on MSVC and other compilers, as that will cause compilation errors target_compile_options(stringzilla_bare PRIVATE diff --git a/include/stringzilla/hash.h b/include/stringzilla/hash.h index 0e5e883e..262cbdc9 100644 --- a/include/stringzilla/hash.h +++ b/include/stringzilla/hash.h @@ -736,8 +736,8 @@ SZ_PUBLIC sz_u64_t sz_checksum_ice(sz_cptr_t text, sz_size_t length) { } } -SZ_PUBLIC void sz_hashes_skylake(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle) { +SZ_PUBLIC void sz_hashes_ice(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // + sz_hash_callback_t callback, void *callback_handle) { if (length < window_length || !window_length) return; if (length < 4 * window_length) { @@ -932,8 +932,8 @@ SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length) { SZ_DYNAMIC void sz_hashes(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // sz_hash_callback_t callback, void *callback_handle) { -#if SZ_USE_SKYLAKE - sz_hashes_skylake(text, length, window_length, window_step, callback, callback_handle); +#if SZ_USE_ICE + sz_hashes_ice(text, length, window_length, window_step, callback, callback_handle); #elif SZ_USE_HASWELL sz_hashes_haswell(text, length, window_length, window_step, callback, callback_handle); #else diff --git a/include/stringzilla/similarity.h b/include/stringzilla/similarity.h index 0b119127..5c521a40 100644 --- a/include/stringzilla/similarity.h +++ b/include/stringzilla/similarity.h @@ -639,9 +639,12 @@ SZ_PUBLIC sz_size_t sz_hamming_distance_utf8_serial( // * @brief Computes the edit distance between two very short byte-strings using the AVX-512VBMI extensions. 
* * Applies to string lengths up to 63, and evaluates at most (63 * 2 + 1 = 127) diagonals, or just as many loop - * cycles. Supports an early exit, if the distance is bounded. Keeps all of the data and Levenshtein matrices skew - * diagonal in just a couple of registers. Benefits from the @b `vpermb` instructions, that can rotate the bytes - * across the entire ZMM register. + * cycles. Supports an early exit, if the distance is bounded. Keeps all of the data and Levenshtein matrices skew + * diagonal in just a couple of registers. Benefits from the @b `vpermb` instructions, that can rotate the bytes + * across the entire ZMM register. + * + *? Bounds check, for inputs ranging from 33 to 64 bytes doesn't affect the performance at all. + *? It's also worth exploring `_mm512_alignr_epi8` and `_mm512_maskz_compress_epi8` for the shift. */ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto63_ice( // sz_cptr_t shorter, sz_size_t shorter_length, // @@ -678,6 +681,7 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto63_ice( // bound_vec.zmm = _mm512_set1_epi8(bound <= 255 ? (sz_u8_t)bound : 255); // To simplify comparisons and traversals, we want to reverse the order of bytes in the shorter string. + shorter_vec.zmm = _mm512_setzero_si512(); //? To simplify debugging. for (sz_size_t i = 0; i != shorter_length; ++i) shorter_vec.u8s[63 - i] = shorter[i]; shorter_rotated_vec.zmm = _mm512_permutexvar_epi8(rotate_right_vec.zmm, shorter_vec.zmm); diff --git a/scripts/bench_memory.cpp b/scripts/bench_memory.cpp index 93d7ab2d..7a9acf25 100644 --- a/scripts/bench_memory.cpp +++ b/scripts/bench_memory.cpp @@ -110,7 +110,7 @@ tracked_unary_functions_t fill_functions(sz_cptr_t dataset_start_ptr, sz_ptr_t o })}, {"sz_fill_serial", wrap_sz(sz_fill_serial)}, #if SZ_USE_SKYLAKE - {"sz_fill_avx512", wrap_sz(sz_fill_skylake)}, + {"sz_fill_skylake", wrap_sz(sz_fill_skylake)}, #endif #if SZ_USE_HASWELL {"sz_fill_haswell", wrap_sz(sz_fill_haswell)}, diff --git a/scripts/bench_similarity.cpp b/scripts/bench_similarity.cpp index 9aa964c3..ca901a5f 100644 --- a/scripts/bench_similarity.cpp +++ b/scripts/bench_similarity.cpp @@ -38,7 +38,7 @@ tracked_binary_functions_t distance_functions() { }); auto wrap_sz_distance = [alloc](auto function) mutable -> binary_function_t { return binary_function_t([function, alloc](std::string_view a, std::string_view b) mutable -> std::size_t { - return function(a.data(), a.length(), b.data(), b.length(), (sz_size_t)0, &alloc); + return function(a.data(), a.length(), b.data(), b.length(), SZ_SIZE_MAX, &alloc); }); }; auto wrap_sz_scoring = [alloc, costs_ptr](auto function) mutable -> binary_function_t { @@ -113,10 +113,24 @@ void bench_similarity_on_input_data(int argc, char const **argv) { std::printf("Benchmarking on real words:\n"); bench_similarity(dataset.tokens); + struct size_range_t { + std::size_t min_length; + std::size_t max_length; + }; + // Run benchmarks on tokens of different length - for (std::size_t token_length : {20}) { - std::printf("Benchmarking on real words of length %zu and longer:\n", token_length); - bench_similarity(filter_by_length(dataset.tokens, token_length, std::greater_equal {})); + for (size_range_t size : { + size_range_t {1, 16}, + size_range_t {17, 32}, + size_range_t {33, 64}, + size_range_t {65, 128}, + }) { + auto filtered_dataset = filter_by_length(dataset.tokens, size.min_length, std::greater_equal {}); + filtered_dataset = filter_by_length(filtered_dataset, size.max_length, std::greater_equal {}); + if 
(filtered_dataset.size() < 3) continue; + std::printf("Benchmarking on %zu real words of length %zu to %zu:\n", filtered_dataset.size(), size.min_length, + size.max_length); + bench_similarity(std::move(filtered_dataset)); } } diff --git a/scripts/bench_token.cpp b/scripts/bench_token.cpp index 492f93f4..eb82dfd4 100644 --- a/scripts/bench_token.cpp +++ b/scripts/bench_token.cpp @@ -55,8 +55,8 @@ tracked_unary_functions_t sliding_hashing_functions(std::size_t window_width, st }; std::string suffix = std::to_string(window_width) + ":step" + std::to_string(step); tracked_unary_functions_t result = { -#if SZ_USE_SKYLAKE - {"sz_hashes_skylake:" + suffix, wrap_sz(sz_hashes_skylake)}, +#if SZ_USE_ICE + {"sz_hashes_ice:" + suffix, wrap_sz(sz_hashes_ice)}, #endif #if SZ_USE_HASWELL {"sz_hashes_haswell:" + suffix, wrap_sz(sz_hashes_haswell)}, diff --git a/scripts/test.cpp b/scripts/test.cpp index 181e0648..e9bcf3c7 100644 --- a/scripts/test.cpp +++ b/scripts/test.cpp @@ -173,6 +173,10 @@ static void test_ascii_utilities() { assert(str("").is_printable()); assert(str("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!@#$%^&*()_+").is_printable()); assert(!str("012🔥").is_printable()); + + assert(str("").contains_only(sz::char_set("abc"))); + assert(str("abc").contains_only(sz::char_set("abc"))); + assert(!str("abcd").contains_only(sz::char_set("abc"))); } inline void expect_equality(char const *a, char const *b, std::size_t size) { @@ -1423,6 +1427,8 @@ static void test_levenshtein_distances() { char const *right; std::size_t distance; } explicit_cases[] = { + {"a", "a", 0}, + {"A", "=", 1}, {"listen", "silent", 4}, {"", "", 0}, {"", "abc", 3}, @@ -1473,7 +1479,7 @@ static void test_levenshtein_distances() { // Validate the bounded variants: if (received > 1) { assert(sz::edit_distance(l, r, received) == received); - assert(sz::edit_distance(r, l, received - 1) == SZ_SIZE_MAX); + assert(sz::edit_distance(r, l, received - 1) >= (std::max)(l.size(), r.size())); } }; @@ -1614,8 +1620,9 @@ int main(int argc, char const **argv) { // Let's greet the user nicely sz_unused(argc && argv); std::printf("Hi, dear tester! You look nice today!\n"); - std::printf("- Uses AVX2: %s \n", SZ_USE_HASWELL ? "yes" : "no"); - std::printf("- Uses AVX512: %s \n", SZ_USE_ICE ? "yes" : "no"); + std::printf("- Uses Haswell: %s \n", SZ_USE_HASWELL ? "yes" : "no"); + std::printf("- Uses Skylake: %s \n", SZ_USE_SKYLAKE ? "yes" : "no"); + std::printf("- Uses Ice Lake: %s \n", SZ_USE_ICE ? "yes" : "no"); std::printf("- Uses NEON: %s \n", SZ_USE_NEON ? "yes" : "no"); std::printf("- Uses SVE: %s \n", SZ_USE_SVE ? 
"yes" : "no"); From 821d19ed73c5a163f1fcd76a7d10de29be4b1b88 Mon Sep 17 00:00:00 2001 From: ashbob999 Date: Sun, 5 Jan 2025 20:08:25 +0000 Subject: [PATCH 067/751] Fix: stable sort bench tests failing --- scripts/bench_sort.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/bench_sort.cpp b/scripts/bench_sort.cpp index f46be4a3..1cd99e49 100644 --- a/scripts/bench_sort.cpp +++ b/scripts/bench_sort.cpp @@ -232,9 +232,9 @@ int main(int argc, char const **argv) { }); expect_sorted(strings, permute_base); - bench_permute( - "hybrid_stable_sort_cpp", strings, permute_base, - [](strings_t const &strings, permute_t &permute) { hybrid_stable_sort_cpp(strings, permute.data()); }); + bench_permute("hybrid_stable_sort_cpp", strings, permute_new, [](strings_t const &strings, permute_t &permute) { + hybrid_stable_sort_cpp(strings, permute.data()); + }); expect_sorted(strings, permute_new); expect_same(permute_base, permute_new); } From 455508f9aad42c295248fb6482711f32359d5521 Mon Sep 17 00:00:00 2001 From: ashbob999 Date: Sun, 5 Jan 2025 21:40:44 +0000 Subject: [PATCH 068/751] Fix: hybrid bench sorts loading initial stirng bytes incorrectly --- scripts/bench_sort.cpp | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/scripts/bench_sort.cpp b/scripts/bench_sort.cpp index 1cd99e49..9b484baa 100644 --- a/scripts/bench_sort.cpp +++ b/scripts/bench_sort.cpp @@ -66,14 +66,18 @@ void populate_from_file(std::string path, strings_t &strings, while (strings.size() < limit && std::getline(f, s, ' ')) strings.push_back(s); } -constexpr size_t offset_in_word = 0; +constexpr size_t offset_in_word = 4; static idx_t hybrid_sort_cpp(strings_t const &strings, sz_u64_t *order) { // What if we take up-to 4 first characters and the index - for (size_t i = 0; i != strings.size(); ++i) - std::memcpy((char *)&order[i] + offset_in_word, strings[order[i]].c_str(), - std::min(strings[order[i]].size(), 4ul)); + for (size_t i = 0; i != strings.size(); ++i) { + size_t index = order[i]; + + for (size_t j = 0; j < std::min(strings[(sz_size_t)index].size(), 4ul); ++j) { + std::memcpy((char *)&order[i] + offset_in_word + 3 - j, strings[(sz_size_t)index].c_str() + j, 1ul); + } + } std::sort(order, order + strings.size(), [&](sz_u64_t i, sz_u64_t j) { char *i_bytes = (char *)&i; @@ -91,9 +95,13 @@ static idx_t hybrid_sort_cpp(strings_t const &strings, sz_u64_t *order) { static idx_t hybrid_stable_sort_cpp(strings_t const &strings, sz_u64_t *order) { // What if we take up-to 4 first characters and the index - for (size_t i = 0; i != strings.size(); ++i) - std::memcpy((char *)&order[i] + offset_in_word, strings[order[i]].c_str(), - std::min(strings[order[i]].size(), 4ull)); + for (size_t i = 0; i != strings.size(); ++i) { + size_t index = order[i]; + + for (size_t j = 0; j < std::min(strings[(sz_size_t)index].size(), 4ul); ++j) { + std::memcpy((char *)&order[i] + offset_in_word + 3 - j, strings[(sz_size_t)index].c_str() + j, 1ul); + } + } std::stable_sort(order, order + strings.size(), [&](sz_u64_t i, sz_u64_t j) { char *i_bytes = (char *)&i; @@ -196,7 +204,7 @@ int main(int argc, char const **argv) { }); expect_sorted(strings, permute_new); -#if __linux__ && defined(_GNU_SOURCE) +#if __linux__ && defined(_GNU_SOURCE) & !defined(__BIONIC__) bench_permute("qsort_r", strings, permute_new, [](strings_t const &strings, permute_t &permute) { sz_sequence_t array; array.order = permute.data(); From 9880f266f6d9b440db494aab744aaf0f474232a3 Mon Sep 17 00:00:00 2001 From: 
ashbob999 Date: Sun, 5 Jan 2025 21:47:56 +0000 Subject: [PATCH 069/751] Improve: hybrid bench sort performance --- scripts/bench_sort.cpp | 68 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 62 insertions(+), 6 deletions(-) diff --git a/scripts/bench_sort.cpp b/scripts/bench_sort.cpp index 9b484baa..91734c1b 100644 --- a/scripts/bench_sort.cpp +++ b/scripts/bench_sort.cpp @@ -85,9 +85,37 @@ static idx_t hybrid_sort_cpp(strings_t const &strings, sz_u64_t *order) { return *(uint32_t *)(i_bytes + offset_in_word) < *(uint32_t *)(j_bytes + offset_in_word); }); - for (size_t i = 0; i != strings.size(); ++i) std::memset((char *)&order[i] + offset_in_word, 0, 4ul); + const auto extract_bytes = [](sz_u64_t v) -> uint32_t { + char *bytes = (char *)&v; + return *(uint32_t *)(bytes + offset_in_word); + }; + + if (strings.size() >= 2) { + size_t prev_index = 0; + uint64_t prev_bytes = extract_bytes(order[0]); + + for (size_t i = 1; i < strings.size(); ++i) { + uint32_t bytes = extract_bytes(order[i]); + if (bytes != prev_bytes) { + std::sort(order + prev_index, order + i, [&](sz_u64_t i, sz_u64_t j) { + // Assumes: offset_in_word==4 + sz_size_t i_index = i & 0xFFFF'FFFF; + sz_size_t j_index = j & 0xFFFF'FFFF; + return strings[i_index] < strings[j_index]; + }); + prev_index = i; + prev_bytes = bytes; + } + } - std::sort(order, order + strings.size(), [&](sz_u64_t i, sz_u64_t j) { return strings[i] < strings[j]; }); + std::sort(order + prev_index, order + strings.size(), [&](sz_u64_t i, sz_u64_t j) { + sz_size_t i_index = i & 0xFFFF'FFFF; + sz_size_t j_index = j & 0xFFFF'FFFF; + return strings[i_index] < strings[j_index]; + }); + } + + for (size_t i = 0; i != strings.size(); ++i) std::memset((char *)&order[i] + offset_in_word, 0, 4ul); return strings.size(); } @@ -109,9 +137,37 @@ static idx_t hybrid_stable_sort_cpp(strings_t const &strings, sz_u64_t *order) { return *(uint32_t *)(i_bytes + offset_in_word) < *(uint32_t *)(j_bytes + offset_in_word); }); - for (size_t i = 0; i != strings.size(); ++i) std::memset((char *)&order[i] + offset_in_word, 0, 4ul); + const auto extract_bytes = [](sz_u64_t v) -> uint32_t { + char *bytes = (char *)&v; + return *(uint32_t *)(bytes + offset_in_word); + }; + + if (strings.size() >= 2) { + size_t prev_index = 0; + uint64_t prev_bytes = extract_bytes(order[0]); + + for (size_t i = 1; i < strings.size(); ++i) { + uint32_t bytes = extract_bytes(order[i]); + if (bytes != prev_bytes) { + std::stable_sort(order + prev_index, order + i, [&](sz_u64_t i, sz_u64_t j) { + // Assumes: offset_in_word==4 + sz_size_t i_index = i & 0xFFFF'FFFF; + sz_size_t j_index = j & 0xFFFF'FFFF; + return strings[i_index] < strings[j_index]; + }); + prev_index = i; + prev_bytes = bytes; + } + } + + std::stable_sort(order + prev_index, order + strings.size(), [&](sz_u64_t i, sz_u64_t j) { + sz_size_t i_index = i & 0xFFFF'FFFF; + sz_size_t j_index = j & 0xFFFF'FFFF; + return strings[i_index] < strings[j_index]; + }); + } - std::stable_sort(order, order + strings.size(), [&](sz_u64_t i, sz_u64_t j) { return strings[i] < strings[j]; }); + for (size_t i = 0; i != strings.size(); ++i) std::memset((char *)&order[i] + offset_in_word, 0, 4ul); return strings.size(); } @@ -204,7 +260,7 @@ int main(int argc, char const **argv) { }); expect_sorted(strings, permute_new); -#if __linux__ && defined(_GNU_SOURCE) & !defined(__BIONIC__) +#if __linux__ && defined(_GNU_SOURCE) && !defined(__BIONIC__) bench_permute("qsort_r", strings, permute_new, [](strings_t const &strings, permute_t &permute) { 
sz_sequence_t array; array.order = permute.data(); @@ -248,4 +304,4 @@ int main(int argc, char const **argv) { } return 0; -} \ No newline at end of file +} From 2c49eaed742ec055a51b9ba398ed54395ee73707 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Wed, 12 Feb 2025 21:51:08 +0000 Subject: [PATCH 070/751] Break: Replace `char_set` constructor with literals --- README.md | 6 +- include/stringzilla/stringzilla.hpp | 73 +++++---- scripts/test.cpp | 227 +++++++++++++++++----------- 3 files changed, 185 insertions(+), 121 deletions(-) diff --git a/README.md b/README.md index bf3872a7..453b3bf6 100644 --- a/README.md +++ b/README.md @@ -724,12 +724,12 @@ haystack.compare(needle) == 1; // Or `haystack <=> needle` in C++ 20 and beyond StringZilla also provides string literals for automatic type resolution, [similar to STL][stl-literal]: ```cpp -using sz::literals::operator""_sz; +using sz::literals::operator""_sv; using std::literals::operator""sv; auto a = "some string"; // char const * auto b = "some string"sv; // std::string_view -auto b = "some string"_sz; // sz::string_view +auto b = "some string"_sv; // sz::string_view ``` [stl-literal]: https://en.cppreference.com/w/cpp/string/basic_string_view/operator%22%22sv @@ -887,7 +887,7 @@ str("a:b").back(-2) == ":b"; // similar to Python's `"a:b"[-2:]` str("a:b").sub(1, -1) == ":"; // similar to Python's `"a:b"[1:-1]` str("a:b").sub(-2, -1) == ":"; // similar to Python's `"a:b"[-2:-1]` str("a:b").sub(-2, 1) == ""; // similar to Python's `"a:b"[-2:1]` -"a:b"_sz[{-2, -1}] == ":"; // works on views and overloads `operator[]` +"a:b"_sv[{-2, -1}] == ":"; // works on views and overloads `operator[]` ``` Assuming StringZilla is a header-only library you can use the full API in some translation units and gradually transition to safer restricted API in others. diff --git a/include/stringzilla/stringzilla.hpp b/include/stringzilla/stringzilla.hpp index a80da804..94c75cba 100644 --- a/include/stringzilla/stringzilla.hpp +++ b/include/stringzilla/stringzilla.hpp @@ -67,7 +67,7 @@ namespace ashvardanian { namespace stringzilla { template -class basic_charset; +class basic_char_set; template class basic_string_slice; template @@ -266,79 +266,85 @@ inline carray<64> const &base64() noexcept { * @brief A set of characters represented as a bitset with 256 slots. */ template -class basic_charset { +class basic_char_set { sz_charset_t bitset_; public: using char_type = char_type_; - basic_charset() noexcept { + constexpr basic_char_set() noexcept { // ! Instead of relying on the `sz_charset_init`, we have to reimplement it to support `constexpr`. bitset_._u64s[0] = 0, bitset_._u64s[1] = 0, bitset_._u64s[2] = 0, bitset_._u64s[3] = 0; } - explicit basic_charset(std::initializer_list chars) noexcept : basic_charset() { + explicit constexpr basic_char_set(std::initializer_list chars) noexcept : basic_char_set() { // ! Instead of relying on the `sz_charset_add(&bitset_, c)`, we have to reimplement it to support `constexpr`. 
for (auto c : chars) bitset_._u64s[sz_bitcast(sz_u8_t, c) >> 6] |= (1ull << (sz_bitcast(sz_u8_t, c) & 63u)); } - template - explicit basic_charset(char_type const (&chars)[count_characters]) noexcept : basic_charset() { - static_assert(count_characters > 0, "Character array cannot be empty"); - for (std::size_t i = 0; i < count_characters - 1; ++i) { // count_characters - 1 to exclude the null terminator + + explicit constexpr basic_char_set(char_type const *chars, std::size_t count_characters) noexcept + : basic_char_set() { + for (std::size_t i = 0; i < count_characters; ++i) { char_type c = chars[i]; bitset_._u64s[sz_bitcast(sz_u8_t, c) >> 6] |= (1ull << (sz_bitcast(sz_u8_t, c) & 63u)); } } template - explicit basic_charset(std::array const &chars) noexcept : basic_charset() { + explicit constexpr basic_char_set(std::array const &chars) noexcept + : basic_char_set() { static_assert(count_characters > 0, "Character array cannot be empty"); - for (std::size_t i = 0; i < count_characters - 1; ++i) { // count_characters - 1 to exclude the null terminator + for (std::size_t i = 0; i < count_characters; ++i) { char_type c = chars[i]; bitset_._u64s[sz_bitcast(sz_u8_t, c) >> 6] |= (1ull << (sz_bitcast(sz_u8_t, c) & 63u)); } } - basic_charset(basic_charset const &other) noexcept : bitset_(other.bitset_) {} - basic_charset &operator=(basic_charset const &other) noexcept { + constexpr basic_char_set(basic_char_set const &other) noexcept : bitset_(other.bitset_) {} + constexpr basic_char_set &operator=(basic_char_set const &other) noexcept { bitset_ = other.bitset_; return *this; } - basic_charset operator|(basic_charset other) const noexcept { - basic_charset result = *this; + constexpr basic_char_set operator|(basic_char_set other) const noexcept { + basic_char_set result = *this; result.bitset_._u64s[0] |= other.bitset_._u64s[0], result.bitset_._u64s[1] |= other.bitset_._u64s[1], result.bitset_._u64s[2] |= other.bitset_._u64s[2], result.bitset_._u64s[3] |= other.bitset_._u64s[3]; - return *this; + return result; } - inline basic_charset &add(char_type c) noexcept { + inline basic_char_set &add(char_type c) noexcept { sz_charset_add(&bitset_, sz_bitcast(sz_u8_t, c)); return *this; } + inline std::size_t size() const noexcept { + return // + sz_u64_popcount(bitset_._u64s[0]) + sz_u64_popcount(bitset_._u64s[1]) + // + sz_u64_popcount(bitset_._u64s[2]) + sz_u64_popcount(bitset_._u64s[3]); + } inline sz_charset_t &raw() noexcept { return bitset_; } inline sz_charset_t const &raw() const noexcept { return bitset_; } inline bool contains(char_type c) const noexcept { return sz_charset_contains(&bitset_, sz_bitcast(sz_u8_t, c)); } - inline basic_charset inverted() const noexcept { - basic_charset result = *this; + inline basic_char_set inverted() const noexcept { + basic_char_set result = *this; sz_charset_invert(&result.bitset_); return result; } }; -using char_set = basic_charset; - -inline char_set ascii_letters_set() { return char_set {ascii_letters()}; } -inline char_set ascii_lowercase_set() { return char_set {ascii_lowercase()}; } -inline char_set ascii_uppercase_set() { return char_set {ascii_uppercase()}; } -inline char_set ascii_printables_set() { return char_set {ascii_printables()}; } -inline char_set ascii_controls_set() { return char_set {ascii_controls()}; } -inline char_set digits_set() { return char_set {digits()}; } -inline char_set hexdigits_set() { return char_set {hexdigits()}; } -inline char_set octdigits_set() { return char_set {octdigits()}; } -inline char_set punctuation_set() { 
return char_set {punctuation()}; } -inline char_set whitespaces_set() { return char_set {whitespaces()}; } -inline char_set newlines_set() { return char_set {newlines()}; } -inline char_set base64_set() { return char_set {base64()}; } +using char_set = basic_char_set; + +inline char_set ascii_letters_set() { return char_set {ascii_letters(), sizeof(ascii_letters())}; } +inline char_set ascii_lowercase_set() { return char_set {ascii_lowercase(), sizeof(ascii_lowercase())}; } +inline char_set ascii_uppercase_set() { return char_set {ascii_uppercase(), sizeof(ascii_uppercase())}; } +inline char_set ascii_printables_set() { return char_set {ascii_printables(), sizeof(ascii_printables())}; } +inline char_set ascii_controls_set() { return char_set {ascii_controls(), sizeof(ascii_controls())}; } +inline char_set digits_set() { return char_set {digits(), sizeof(digits())}; } +inline char_set hexdigits_set() { return char_set {hexdigits(), sizeof(hexdigits())}; } +inline char_set octdigits_set() { return char_set {octdigits(), sizeof(octdigits())}; } +inline char_set punctuation_set() { return char_set {punctuation(), sizeof(punctuation())}; } +inline char_set whitespaces_set() { return char_set {whitespaces(), sizeof(whitespaces())}; } +inline char_set newlines_set() { return char_set {newlines(), sizeof(newlines())}; } +inline char_set base64_set() { return char_set {base64(), sizeof(base64())}; } /** * @brief A look-up table for character replacement operations. @@ -3446,7 +3452,8 @@ using string = basic_string>; static_assert(sizeof(string) == 4 * sizeof(void *), "String size must be 4 pointers."); namespace literals { -constexpr string_view operator""_sz(char const *str, std::size_t length) noexcept { return {str, length}; } +constexpr string_view operator""_sv(char const *str, std::size_t length) noexcept { return {str, length}; } +constexpr char_set operator""_cs(char const *str, std::size_t length) noexcept { return char_set {str, length}; } } // namespace literals template diff --git a/scripts/test.cpp b/scripts/test.cpp index ead0c88d..87db34c8 100644 --- a/scripts/test.cpp +++ b/scripts/test.cpp @@ -46,7 +46,8 @@ namespace sz = ashvardanian::stringzilla; using namespace sz::scripts; -using sz::literals::operator""_sz; +using sz::literals::operator""_sv; // for `sz::string_view` +using sz::literals::operator""_cs; // for `sz::char_set` /* * Instantiate all the templates to make the symbols visible and also check @@ -58,7 +59,7 @@ template class std::basic_string_view; template class sz::basic_string_slice; template class std::basic_string; template class sz::basic_string; -template class sz::basic_charset; +template class sz::basic_char_set; template class std::vector; template class std::map; @@ -137,6 +138,61 @@ static void test_arithmetical_utilities() { (static_cast(number) / static_cast(divisor))); } +/** + * @brief Tests various ASCII-based methods (e.g., `is_alpha`, `is_digit`) + * provided by `sz::string` and `sz::string_view`. 
+ */ +template +static void test_ascii_utilities() { + + using str = string_type; + + assert("aaa"_cs.size() == 1ull); + assert("\0\0"_cs.size() == 1ull); + assert("abc"_cs.size() == 3ull); + assert("a\0bc"_cs.size() == 4ull); + + assert(!"abc"_cs.contains('\0')); + assert(str("bca").contains_only("abc"_cs)); + + assert(!str("").is_alpha()); + assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ").is_alpha()); + assert(!str("abc9").is_alpha()); + + assert(!str("").is_alnum()); + assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789").is_alnum()); + assert(!str("abc!").is_alnum()); + + assert(str("").is_ascii()); + assert(str("\x00x7F").is_ascii()); + assert(!str("abc123🔥").is_ascii()); + + assert(!str("").is_digit()); + assert(str("0123456789").is_digit()); + assert(!str("012a").is_digit()); + + assert(!str("").is_lower()); + assert(str("abcdefghijklmnopqrstuvwxyz").is_lower()); + assert(!str("abcA").is_lower()); + assert(!str("abc\n").is_lower()); + + assert(!str("").is_space()); + assert(str(" \t\n\r\f\v").is_space()); + assert(!str(" \t\r\na").is_space()); + + assert(!str("").is_upper()); + assert(str("ABCDEFGHIJKLMNOPQRSTUVWXYZ").is_upper()); + assert(!str("ABCa").is_upper()); + + assert(str("").is_printable()); + assert(str("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!@#$%^&*()_+").is_printable()); + assert(!str("012🔥").is_printable()); + + assert(str("").contains_only("abc"_cs)); + assert(str("abc").contains_only("abc"_cs)); + assert(!str("abcd").contains_only("abc"_cs)); +} + inline void expect_equality(char const *a, char const *b, std::size_t size) { if (std::memcmp(a, b, size) == 0) return; std::size_t mismatch_position = 0; @@ -838,9 +894,9 @@ void test_non_stl_extensions_for_updates() { assert_scoped(str s = "hello", s.replace_all("xx", "xx"), s == "hello"); assert_scoped(str s = "hello", s.replace_all("l", "1"), s == "he11o"); assert_scoped(str s = "hello", s.replace_all("he", "al"), s == "alllo"); - assert_scoped(str s = "hello", s.replace_all(sz::char_set("x"), "!"), s == "hello"); - assert_scoped(str s = "hello", s.replace_all(sz::char_set("o"), "!"), s == "hell!"); - assert_scoped(str s = "hello", s.replace_all(sz::char_set("ho"), "!"), s == "!ell!"); + assert_scoped(str s = "hello", s.replace_all("x"_cs, "!"), s == "hello"); + assert_scoped(str s = "hello", s.replace_all("o"_cs, "!"), s == "hell!"); + assert_scoped(str s = "hello", s.replace_all("ho"_cs, "!"), s == "!ell!"); // Shorter replacements. assert_scoped(str s = "hello", s.replace_all("xx", "x"), s == "hello"); @@ -848,8 +904,8 @@ void test_non_stl_extensions_for_updates() { assert_scoped(str s = "hello", s.replace_all("h", ""), s == "ello"); assert_scoped(str s = "hello", s.replace_all("o", ""), s == "hell"); assert_scoped(str s = "hello", s.replace_all("llo", "!"), s == "he!"); - assert_scoped(str s = "hello", s.replace_all(sz::char_set("x"), ""), s == "hello"); - assert_scoped(str s = "hello", s.replace_all(sz::char_set("lo"), ""), s == "he"); + assert_scoped(str s = "hello", s.replace_all("x"_cs, ""), s == "hello"); + assert_scoped(str s = "hello", s.replace_all("lo"_cs, ""), s == "he"); // Longer replacements. 
assert_scoped(str s = "hello", s.replace_all("xx", "xxx"), s == "hello"); @@ -857,8 +913,8 @@ void test_non_stl_extensions_for_updates() { assert_scoped(str s = "hello", s.replace_all("h", "hh"), s == "hhello"); assert_scoped(str s = "hello", s.replace_all("o", "oo"), s == "helloo"); assert_scoped(str s = "hello", s.replace_all("llo", "llo!"), s == "hello!"); - assert_scoped(str s = "hello", s.replace_all(sz::char_set("x"), "xx"), s == "hello"); - assert_scoped(str s = "hello", s.replace_all(sz::char_set("lo"), "lo"), s == "helololo"); + assert_scoped(str s = "hello", s.replace_all("x"_cs, "xx"), s == "hello"); + assert_scoped(str s = "hello", s.replace_all("lo"_cs, "lo"), s == "helololo"); // Directly mapping bytes using a Look-Up Table. sz::look_up_table invert_case = sz::look_up_table::identity(); @@ -872,8 +928,8 @@ void test_non_stl_extensions_for_updates() { assert(str(str("a") | str("b")) == "ab"); assert(str(str("a") | str("b") | str("ab")) == "abab"); - assert(str(sz::concatenate("a"_sz, "b"_sz)) == "ab"); - assert(str(sz::concatenate("a"_sz, "b"_sz, "c"_sz)) == "abc"); + assert(str(sz::concatenate("a"_sv, "b"_sv)) == "ab"); + assert(str(sz::concatenate("a"_sv, "b"_sv, "c"_sv)) == "abc"); // Randomization. assert(str::random(0).empty()); @@ -1062,15 +1118,15 @@ static void test_updates(std::size_t repetitions = 1024) { */ static void test_comparisons() { // Comparing relative order of the strings - assert("a"_sz.compare("a") == 0); - assert("a"_sz.compare("ab") == -1); - assert("ab"_sz.compare("a") == 1); - assert("a"_sz.compare("a\0"_sz) == -1); - assert("a\0"_sz.compare("a") == 1); - assert("a\0"_sz.compare("a\0"_sz) == 0); - assert("a"_sz == "a"_sz); - assert("a"_sz != "a\0"_sz); - assert("a\0"_sz == "a\0"_sz); + assert("a"_sv.compare("a") == 0); + assert("a"_sv.compare("ab") == -1); + assert("ab"_sv.compare("a") == 1); + assert("a"_sv.compare("a\0"_sv) == -1); + assert("a\0"_sv.compare("a") == 1); + assert("a\0"_sv.compare("a\0"_sv) == 0); + assert("a"_sv == "a"_sv); + assert("a"_sv != "a\0"_sv); + assert("a\0"_sv == "a\0"_sv); } /** @@ -1099,57 +1155,57 @@ static void test_search() { assert(sz::string_view(sz::ascii_printables(), sizeof(sz::ascii_printables())).find_first_of("~") != sz::string_view::npos); - assert("aabaa"_sz.remove_prefix("a") == "abaa"); - assert("aabaa"_sz.remove_suffix("a") == "aaba"); - assert("aabaa"_sz.lstrip(sz::char_set {"a"}) == "baa"); - assert("aabaa"_sz.rstrip(sz::char_set {"a"}) == "aab"); - assert("aabaa"_sz.strip(sz::char_set {"a"}) == "b"); + assert("aabaa"_sv.remove_prefix("a") == "abaa"); + assert("aabaa"_sv.remove_suffix("a") == "aaba"); + assert("aabaa"_sv.lstrip("a"_cs) == "baa"); + assert("aabaa"_sv.rstrip("a"_cs) == "aab"); + assert("aabaa"_sv.strip("a"_cs) == "b"); // Check more advanced composite operations - assert("abbccc"_sz.partition('b').before.size() == 1); - assert("abbccc"_sz.partition("bb").before.size() == 1); - assert("abbccc"_sz.partition("bb").match.size() == 2); - assert("abbccc"_sz.partition("bb").after.size() == 3); - assert("abbccc"_sz.partition("bb").before == "a"); - assert("abbccc"_sz.partition("bb").match == "bb"); - assert("abbccc"_sz.partition("bb").after == "ccc"); - assert("abb ccc"_sz.partition(sz::whitespaces_set()).after == "ccc"); + assert("abbccc"_sv.partition('b').before.size() == 1); + assert("abbccc"_sv.partition("bb").before.size() == 1); + assert("abbccc"_sv.partition("bb").match.size() == 2); + assert("abbccc"_sv.partition("bb").after.size() == 3); + assert("abbccc"_sv.partition("bb").before == 
"a"); + assert("abbccc"_sv.partition("bb").match == "bb"); + assert("abbccc"_sv.partition("bb").after == "ccc"); + assert("abb ccc"_sv.partition(sz::whitespaces_set()).after == "ccc"); // Check ranges of search matches - assert("hello"_sz.find_all("l").size() == 2); - assert("hello"_sz.rfind_all("l").size() == 2); - - assert(""_sz.find_all(".", sz::include_overlaps_type {}).size() == 0); - assert(""_sz.find_all(".", sz::exclude_overlaps_type {}).size() == 0); - assert("."_sz.find_all(".", sz::include_overlaps_type {}).size() == 1); - assert("."_sz.find_all(".", sz::exclude_overlaps_type {}).size() == 1); - assert(".."_sz.find_all(".", sz::include_overlaps_type {}).size() == 2); - assert(".."_sz.find_all(".", sz::exclude_overlaps_type {}).size() == 2); - assert(""_sz.rfind_all(".", sz::include_overlaps_type {}).size() == 0); - assert(""_sz.rfind_all(".", sz::exclude_overlaps_type {}).size() == 0); - assert("."_sz.rfind_all(".", sz::include_overlaps_type {}).size() == 1); - assert("."_sz.rfind_all(".", sz::exclude_overlaps_type {}).size() == 1); - assert(".."_sz.rfind_all(".", sz::include_overlaps_type {}).size() == 2); - assert(".."_sz.rfind_all(".", sz::exclude_overlaps_type {}).size() == 2); - - assert("a.b.c.d"_sz.find_all(".").size() == 3); - assert("a.,b.,c.,d"_sz.find_all(".,").size() == 3); - assert("a.,b.,c.,d"_sz.rfind_all(".,").size() == 3); - assert("a.b,c.d"_sz.find_all(sz::char_set(".,")).size() == 3); - assert("a...b...c"_sz.rfind_all("..").size() == 4); - assert("a...b...c"_sz.rfind_all("..", sz::include_overlaps_type {}).size() == 4); - assert("a...b...c"_sz.rfind_all("..", sz::exclude_overlaps_type {}).size() == 2); - - auto finds = "a.b.c"_sz.find_all(sz::char_set("abcd")).template to>(); + assert("hello"_sv.find_all("l").size() == 2); + assert("hello"_sv.rfind_all("l").size() == 2); + + assert(""_sv.find_all(".", sz::include_overlaps_type {}).size() == 0); + assert(""_sv.find_all(".", sz::exclude_overlaps_type {}).size() == 0); + assert("."_sv.find_all(".", sz::include_overlaps_type {}).size() == 1); + assert("."_sv.find_all(".", sz::exclude_overlaps_type {}).size() == 1); + assert(".."_sv.find_all(".", sz::include_overlaps_type {}).size() == 2); + assert(".."_sv.find_all(".", sz::exclude_overlaps_type {}).size() == 2); + assert(""_sv.rfind_all(".", sz::include_overlaps_type {}).size() == 0); + assert(""_sv.rfind_all(".", sz::exclude_overlaps_type {}).size() == 0); + assert("."_sv.rfind_all(".", sz::include_overlaps_type {}).size() == 1); + assert("."_sv.rfind_all(".", sz::exclude_overlaps_type {}).size() == 1); + assert(".."_sv.rfind_all(".", sz::include_overlaps_type {}).size() == 2); + assert(".."_sv.rfind_all(".", sz::exclude_overlaps_type {}).size() == 2); + + assert("a.b.c.d"_sv.find_all(".").size() == 3); + assert("a.,b.,c.,d"_sv.find_all(".,").size() == 3); + assert("a.,b.,c.,d"_sv.rfind_all(".,").size() == 3); + assert("a.b,c.d"_sv.find_all(".,"_cs).size() == 3); + assert("a...b...c"_sv.rfind_all("..").size() == 4); + assert("a...b...c"_sv.rfind_all("..", sz::include_overlaps_type {}).size() == 4); + assert("a...b...c"_sv.rfind_all("..", sz::exclude_overlaps_type {}).size() == 2); + + auto finds = "a.b.c"_sv.find_all("abcd"_cs).template to>(); assert(finds.size() == 3); assert(finds[0] == "a"); - auto rfinds = "a.b.c"_sz.rfind_all(sz::char_set("abcd")).template to>(); + auto rfinds = "a.b.c"_sv.rfind_all("abcd"_cs).template to>(); assert(rfinds.size() == 3); assert(rfinds[0] == "c"); { - auto splits = ".a..c."_sz.split(sz::char_set(".")).template to>(); + auto 
splits = ".a..c."_sv.split("."_cs).template to>(); assert(splits.size() == 5); assert(splits[0] == ""); assert(splits[1] == "a"); @@ -1157,36 +1213,36 @@ static void test_search() { } { - auto splits = "line1\nline2\nline3"_sz.split("line3").template to>(); + auto splits = "line1\nline2\nline3"_sv.split("line3").template to>(); assert(splits.size() == 2); assert(splits[0] == "line1\nline2\n"); assert(splits[1] == ""); } - assert(""_sz.split(".").size() == 1); - assert(""_sz.rsplit(".").size() == 1); - - assert("hello"_sz.split("l").size() == 3); - assert("hello"_sz.rsplit("l").size() == 3); - assert(*advanced("hello"_sz.split("l").begin(), 0) == "he"); - assert(*advanced("hello"_sz.rsplit("l").begin(), 0) == "o"); - assert(*advanced("hello"_sz.split("l").begin(), 1) == ""); - assert(*advanced("hello"_sz.rsplit("l").begin(), 1) == ""); - assert(*advanced("hello"_sz.split("l").begin(), 2) == "o"); - assert(*advanced("hello"_sz.rsplit("l").begin(), 2) == "he"); - - assert("a.b.c.d"_sz.split(".").size() == 4); - assert("a.b.c.d"_sz.rsplit(".").size() == 4); - assert(*("a.b.c.d"_sz.split(".").begin()) == "a"); - assert(*("a.b.c.d"_sz.rsplit(".").begin()) == "d"); - assert(*advanced("a.b.c.d"_sz.split(".").begin(), 1) == "b"); - assert(*advanced("a.b.c.d"_sz.rsplit(".").begin(), 1) == "c"); - assert(*advanced("a.b.c.d"_sz.split(".").begin(), 3) == "d"); - assert(*advanced("a.b.c.d"_sz.rsplit(".").begin(), 3) == "a"); - assert("a.b.,c,d"_sz.split(".,").size() == 2); - assert("a.b,c.d"_sz.split(sz::char_set(".,")).size() == 4); - - auto rsplits = ".a..c."_sz.rsplit(sz::char_set(".")).template to>(); + assert(""_sv.split(".").size() == 1); + assert(""_sv.rsplit(".").size() == 1); + + assert("hello"_sv.split("l").size() == 3); + assert("hello"_sv.rsplit("l").size() == 3); + assert(*advanced("hello"_sv.split("l").begin(), 0) == "he"); + assert(*advanced("hello"_sv.rsplit("l").begin(), 0) == "o"); + assert(*advanced("hello"_sv.split("l").begin(), 1) == ""); + assert(*advanced("hello"_sv.rsplit("l").begin(), 1) == ""); + assert(*advanced("hello"_sv.split("l").begin(), 2) == "o"); + assert(*advanced("hello"_sv.rsplit("l").begin(), 2) == "he"); + + assert("a.b.c.d"_sv.split(".").size() == 4); + assert("a.b.c.d"_sv.rsplit(".").size() == 4); + assert(*("a.b.c.d"_sv.split(".").begin()) == "a"); + assert(*("a.b.c.d"_sv.rsplit(".").begin()) == "d"); + assert(*advanced("a.b.c.d"_sv.split(".").begin(), 1) == "b"); + assert(*advanced("a.b.c.d"_sv.rsplit(".").begin(), 1) == "c"); + assert(*advanced("a.b.c.d"_sv.split(".").begin(), 3) == "d"); + assert(*advanced("a.b.c.d"_sv.rsplit(".").begin(), 3) == "a"); + assert("a.b.,c,d"_sv.split(".,").size() == 2); + assert("a.b,c.d"_sv.split(".,"_cs).size() == 4); + + auto rsplits = ".a..c."_sv.rsplit("."_cs).template to>(); assert(rsplits.size() == 5); assert(rsplits[0] == ""); assert(rsplits[1] == "c"); @@ -1557,6 +1613,7 @@ int main(int argc, char const **argv) { // Basic utilities test_arithmetical_utilities(); + test_ascii_utilities(); test_memory_utilities(); test_replacements(); From d18a1591c226692a53adabf53418a53cca85bb95 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Wed, 12 Feb 2025 21:51:56 +0000 Subject: [PATCH 071/751] Docs: Spelling `usnigned` --- .vscode/settings.json | 18 +++++++++--------- include/stringzilla/stringzilla.hpp | 8 ++++---- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 980956d1..87e4d065 100644 --- a/.vscode/settings.json 
+++ b/.vscode/settings.json @@ -188,6 +188,7 @@ "cwchar": "cpp", "cwctype": "cpp", "deque": "cpp", + "errno.h": "c", "exception": "cpp", "filesystem": "cpp", "format": "cpp", @@ -195,6 +196,7 @@ "fstream": "cpp", "functional": "cpp", "future": "cpp", + "immintrin.h": "c", "initializer_list": "cpp", "iomanip": "cpp", "ios": "cpp", @@ -216,6 +218,7 @@ "ostream": "cpp", "queue": "cpp", "random": "cpp", + "ranges": "cpp", "ratio": "cpp", "semaphore": "cpp", "set": "cpp", @@ -224,6 +227,7 @@ "span": "cpp", "sstream": "cpp", "stack": "cpp", + "stddef.h": "c", "stdexcept": "cpp", "stop_token": "cpp", "streambuf": "cpp", @@ -232,6 +236,7 @@ "stringzilla.h": "c", "strstream": "cpp", "system_error": "cpp", + "text_encoding": "cpp", "thread": "cpp", "tuple": "cpp", "type_traits": "cpp", @@ -242,12 +247,9 @@ "utility": "cpp", "variant": "cpp", "vector": "cpp", - "stddef.h": "c", - "immintrin.h": "c", - "xiosbase": "cpp", - "xstring": "cpp", "xfacet": "cpp", "xhash": "cpp", + "xiosbase": "cpp", "xlocale": "cpp", "xlocbuf": "cpp", "xlocinfo": "cpp", @@ -256,11 +258,9 @@ "xlocnum": "cpp", "xloctime": "cpp", "xmemory": "cpp", + "xstring": "cpp", "xtr1common": "cpp", "xtree": "cpp", - "xutility": "cpp", - "errno.h": "c", - "text_encoding": "cpp" - }, - "python.pythonPath": "~/miniconda3/bin/python" + "xutility": "cpp" + } } \ No newline at end of file diff --git a/include/stringzilla/stringzilla.hpp b/include/stringzilla/stringzilla.hpp index 94c75cba..c5918005 100644 --- a/include/stringzilla/stringzilla.hpp +++ b/include/stringzilla/stringzilla.hpp @@ -359,7 +359,7 @@ class basic_look_up_table { : sizeof(char_type_) == 2 ? 65536ul : 4294967296ul; static constexpr std::size_t bytes_k = size_k * sizeof(char_type_); - using usnigned_type_ = typename std::make_unsigned::type; + using unsigned_type_ = typename std::make_unsigned::type; char_type_ lut_[size_k]; @@ -384,13 +384,13 @@ class basic_look_up_table { */ static basic_look_up_table identity() noexcept { basic_look_up_table result; - for (std::size_t i = 0; i < size_k; ++i) { result.lut_[i] = static_cast(i); } + for (std::size_t i = 0; i < size_k; ++i) { result.lut_[i] = static_cast(i); } return result; } inline sz_cptr_t raw() const noexcept { return reinterpret_cast(&lut_[0]); } - inline char_type &operator[](char_type c) noexcept { return lut_[sz_bitcast(usnigned_type_, c)]; } - inline char_type const &operator[](char_type c) const noexcept { return lut_[sz_bitcast(usnigned_type_, c)]; } + inline char_type &operator[](char_type c) noexcept { return lut_[sz_bitcast(unsigned_type_, c)]; } + inline char_type const &operator[](char_type c) const noexcept { return lut_[sz_bitcast(unsigned_type_, c)]; } }; using look_up_table = basic_look_up_table; From 0ef7cf1ea268c3a85b2d6bb023769d64b7217713 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Wed, 12 Feb 2025 23:38:19 +0000 Subject: [PATCH 072/751] Make: Renamed include/stringzilla/hash.h -> include/stringzilla/fingerprint.h --- include/stringzilla/{hash.h => fingerprint.h} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename include/stringzilla/{hash.h => fingerprint.h} (100%) diff --git a/include/stringzilla/hash.h b/include/stringzilla/fingerprint.h similarity index 100% rename from include/stringzilla/hash.h rename to include/stringzilla/fingerprint.h From 70522662130c34dc631e927d882489fe68c00592 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Wed, 12 Feb 2025 23:38:19 +0000 Subject: [PATCH 
073/751] Make: Renamed include/stringzilla/hash.h -> temp-git-split-file --- include/stringzilla/hash.h => temp-git-split-file | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename include/stringzilla/hash.h => temp-git-split-file (100%) diff --git a/include/stringzilla/hash.h b/temp-git-split-file similarity index 100% rename from include/stringzilla/hash.h rename to temp-git-split-file From 5a36cb7dfc7b6f7e6cfe3c6c59b63a3b4cee98ec Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Wed, 12 Feb 2025 23:38:19 +0000 Subject: [PATCH 074/751] Make: Renamed temp-git-split-file -> include/stringzilla/hash.h --- temp-git-split-file => include/stringzilla/hash.h | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename temp-git-split-file => include/stringzilla/hash.h (100%) diff --git a/temp-git-split-file b/include/stringzilla/hash.h similarity index 100% rename from temp-git-split-file rename to include/stringzilla/hash.h From 38014ee288f64a715362fcfea7a40bad05af23e7 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Wed, 12 Feb 2025 23:46:05 +0000 Subject: [PATCH 075/751] Break: Deprecate old fingerprinting --- include/stringzilla/fingerprint.h | 386 ------------------------- include/stringzilla/hash.h | 421 ---------------------------- include/stringzilla/stringzilla.hpp | 3 +- scripts/test.cpp | 51 +--- 4 files changed, 5 insertions(+), 856 deletions(-) diff --git a/include/stringzilla/fingerprint.h b/include/stringzilla/fingerprint.h index 262cbdc9..9cdfcc5e 100644 --- a/include/stringzilla/fingerprint.h +++ b/include/stringzilla/fingerprint.h @@ -26,32 +26,6 @@ extern "C" { #pragma region Core API -/** - * @brief Computes the 64-bit check-sum of bytes in a string. - * Similar to `std::ranges::accumulate`. - * - * @param text String to aggregate. - * @param length Number of bytes in the text. - * @return 64-bit unsigned value. - */ -SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length); - -/** - * @brief Computes the 64-bit unsigned hash of a string. Fairly fast for short strings, - * simple implementation, and supports rolling computation, reused in other APIs. - * Similar to `std::hash` in C++. - * - * @param text String to hash. - * @param length Number of bytes in the text. - * @return 64-bit hash value. - * - * @see sz_hashes, sz_hashes_fingerprint, sz_hashes_intersection - */ -SZ_PUBLIC sz_u64_t sz_hash(sz_cptr_t text, sz_size_t length) { - sz_unused(text && length); - return 0; -} - /** * @brief Computes the Karp-Rabin rolling hashes of a string supplying them to the provided `callback`. * Can be used for similarity scores, search, ranking, etc. @@ -124,56 +98,16 @@ SZ_PUBLIC sz_size_t sz_hashes_intersection( // sz_cptr_t text, sz_size_t length, sz_size_t window_length, // sz_cptr_t fingerprint, sz_size_t fingerprint_bytes); -/** - * @brief Generates a random string for a given alphabet, avoiding integer division and modulo operations. - * Similar to `text[i] = alphabet[rand() % cardinality]`. - * - * The modulo operation is expensive, and should be avoided in performance-critical code. - * We avoid it using small lookup tables and replacing it with a multiplication and shifts, similar to `libdivide`. 
- * Alternative algorithms would include: - * - Montgomery form: https://en.algorithmica.org/hpc/number-theory/montgomery/ - * - Barret reduction: https://www.nayuki.io/page/barrett-reduction-algorithm - * - Lemire's trick: https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/ - * - * @param alphabet Set of characters to sample from. - * @param cardinality Number of characters to sample from. - * @param text Output string, can point to the same address as ::text. - * @param generate Callback producing random numbers given the generator state. - * @param generator Generator state, can be a pointer to a seed, or a pointer to a random number generator. - */ -SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length, - sz_random_generator_t generate, void *generator); - -/** @copydoc sz_checksum */ -SZ_PUBLIC sz_u64_t sz_checksum_serial(sz_cptr_t text, sz_size_t length); - -/** @copydoc sz_hash */ -SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t text, sz_size_t length); - /** @copydoc sz_hashes */ SZ_PUBLIC void sz_hashes_serial( // sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // sz_hash_callback_t callback, void *callback_handle); - -/** @copydoc sz_generate */ -SZ_PUBLIC void sz_generate_serial( // - sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length, sz_random_generator_t generate, - void *generator) { - sz_unused(alphabet && cardinality && text && length && generate && generator); } #pragma endregion // Core API #pragma region Serial Implementation -SZ_PUBLIC sz_u64_t sz_checksum_serial(sz_cptr_t text, sz_size_t length) { - sz_u64_t checksum = 0; - sz_u8_t const *text_u8 = (sz_u8_t const *)text; - sz_u8_t const *text_end = text_u8 + length; - for (; text_u8 != text_end; ++text_u8) checksum += *text_u8; - return checksum; -} - /* * One hardware-accelerated way of mixing hashes can be CRC, but it's only implemented for 32-bit values. * Using a Boost-like mixer works very poorly in such case: @@ -188,117 +122,6 @@ SZ_PUBLIC sz_u64_t sz_checksum_serial(sz_cptr_t text, sz_size_t length) { #define _sz_shift_high(x) ((x + 77ull) & 0xFFull) #define _sz_prime_mod(x) (x % SZ_U64_MAX_PRIME) -SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t start, sz_size_t length) { - - sz_u64_t hash_low = 0; - sz_u64_t hash_high = 0; - sz_u8_t const *text = (sz_u8_t const *)start; - sz_u8_t const *text_end = text + length; - - switch (length) { - case 0: return 0; - - // Texts under 7 bytes long are definitely below the largest prime. 
- case 1: - hash_low = _sz_shift_low(text[0]); - hash_high = _sz_shift_high(text[0]); - break; - case 2: - hash_low = _sz_shift_low(text[0]) * 31ull + _sz_shift_low(text[1]); - hash_high = _sz_shift_high(text[0]) * 257ull + _sz_shift_high(text[1]); - break; - case 3: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull + // - _sz_shift_low(text[2]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull + // - _sz_shift_high(text[2]); - break; - case 4: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull + // - _sz_shift_low(text[3]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull + // - _sz_shift_high(text[3]); - break; - case 5: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull * 31ull + // - _sz_shift_low(text[3]) * 31ull + // - _sz_shift_low(text[4]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull * 257ull + // - _sz_shift_high(text[3]) * 257ull + // - _sz_shift_high(text[4]); - break; - case 6: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[3]) * 31ull * 31ull + // - _sz_shift_low(text[4]) * 31ull + // - _sz_shift_low(text[5]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[3]) * 257ull * 257ull + // - _sz_shift_high(text[4]) * 257ull + // - _sz_shift_high(text[5]); - break; - case 7: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[3]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[4]) * 31ull * 31ull + // - _sz_shift_low(text[5]) * 31ull + // - _sz_shift_low(text[6]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[3]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[4]) * 257ull * 257ull + // - _sz_shift_high(text[5]) * 257ull + // - _sz_shift_high(text[6]); - break; - default: - // Unroll the first seven cycles: - hash_low = hash_low * 31ull + _sz_shift_low(text[0]); - hash_high = hash_high * 257ull + _sz_shift_high(text[0]); - hash_low = hash_low * 31ull + _sz_shift_low(text[1]); - hash_high = hash_high * 257ull + _sz_shift_high(text[1]); - hash_low = hash_low * 31ull + _sz_shift_low(text[2]); - hash_high = hash_high * 257ull + _sz_shift_high(text[2]); - hash_low = hash_low * 31ull + _sz_shift_low(text[3]); - hash_high = hash_high * 257ull + _sz_shift_high(text[3]); - hash_low = hash_low * 31ull + _sz_shift_low(text[4]); - hash_high = hash_high * 257ull + _sz_shift_high(text[4]); - hash_low = hash_low * 31ull + 
_sz_shift_low(text[5]); - hash_high = hash_high * 257ull + _sz_shift_high(text[5]); - hash_low = hash_low * 31ull + _sz_shift_low(text[6]); - hash_high = hash_high * 257ull + _sz_shift_high(text[6]); - text += 7; - - // Iterate throw the rest with the modulus: - for (; text != text_end; ++text) { - hash_low = hash_low * 31ull + _sz_shift_low(text[0]); - hash_high = hash_high * 257ull + _sz_shift_high(text[0]); - // Wrap the hashes around: - hash_low = _sz_prime_mod(hash_low); - hash_high = _sz_prime_mod(hash_high); - } - break; - } - - return _sz_hash_mix(hash_low, hash_high); -} - SZ_PUBLIC void sz_hashes_serial(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // sz_hash_callback_t callback, void *callback_handle) { @@ -387,86 +210,6 @@ SZ_INTERNAL void _sz_hashes_fingerprint_scalar_callback( // #pragma GCC target("avx2") #pragma clang attribute push(__attribute__((target("avx2"))), apply_to = function) -SZ_PUBLIC sz_u64_t sz_checksum_haswell(sz_cptr_t text, sz_size_t length) { - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "loads". - // - // A typical AWS Skylake instance can have 32 KB x 2 blocks of L1 data cache per core, - // 1 MB x 2 blocks of L2 cache per core, and one shared L3 cache buffer. - // For now, let's avoid the cases beyond the L2 size. - int is_huge = length > 1ull * 1024ull * 1024ull; - - // When the buffer is small, there isn't much to innovate. - if (length <= 32) { return sz_checksum_serial(text, length); } - else if (!is_huge) { - sz_u256_vec_t text_vec, sums_vec; - sums_vec.ymm = _mm256_setzero_si256(); - for (; length >= 32; text += 32, length -= 32) { - text_vec.ymm = _mm256_lddqu_si256((__m256i const *)text); - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); - } - // Accumulating 256 bits is harders, as we need to extract the 128-bit sums first. - __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymm); - __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymm, 1); - __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm); - sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm); - sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1); - sz_u64_t result = low + high; - if (length) result += sz_checksum_serial(text, length); - return result; - } - // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use. - // Most notably, we can avoid populating the cache with the entire buffer, and instead traverse it in 2 directions. - else { - sz_size_t head_length = (32 - ((sz_size_t)text % 32)) % 32; // 31 or less. - sz_size_t tail_length = (sz_size_t)(text + length) % 32; // 31 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 32. - sz_u64_t result = 0; - - // Handle the head - while (head_length--) result += *text++; - - sz_u256_vec_t text_vec, sums_vec; - sums_vec.ymm = _mm256_setzero_si256(); - // Fill the aligned body of the buffer. - if (!is_huge) { - for (; body_length >= 32; text += 32, body_length -= 32) { - text_vec.ymm = _mm256_stream_load_si256((__m256i const *)text); - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); - } - } - // When the biffer is huge, we can traverse it in 2 directions. 
- else { - sz_u256_vec_t text_reversed_vec, sums_reversed_vec; - sums_reversed_vec.ymm = _mm256_setzero_si256(); - for (; body_length >= 64; text += 64, body_length -= 64) { - text_vec.ymm = _mm256_stream_load_si256((__m256i *)(text)); - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); - text_reversed_vec.ymm = _mm256_stream_load_si256((__m256i *)(text + body_length - 64)); - sums_reversed_vec.ymm = _mm256_add_epi64( - sums_reversed_vec.ymm, _mm256_sad_epu8(text_reversed_vec.ymm, _mm256_setzero_si256())); - } - if (body_length >= 32) { - text_vec.ymm = _mm256_stream_load_si256((__m256i *)(text)); - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); - } - sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, sums_reversed_vec.ymm); - } - - // Handle the tail - while (tail_length--) result += *text++; - - // Accumulating 256 bits is harders, as we need to extract the 128-bit sums first. - __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymm); - __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymm, 1); - __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm); - sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm); - sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1); - result += low + high; - return result; - } -} - /** * @brief There is no AVX2 instruction for fast multiplication of 64-bit integers. * This implementation is coming from Agner Fog's Vector Class Library. @@ -642,100 +385,6 @@ SZ_PUBLIC void sz_hashes_haswell(sz_cptr_t start, sz_size_t length, sz_size_t wi #pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512dq,avx512vbmi,bmi,bmi2"))), \ apply_to = function) -SZ_PUBLIC sz_u64_t sz_checksum_ice(sz_cptr_t text, sz_size_t length) { - // The naive implementation of this function is very simple. - // It assumes the CPU is great at handling unaligned "loads". - // - // A typical AWS Sapphire Rapids instance can have 48 KB x 2 blocks of L1 data cache per core, - // 2 MB x 2 blocks of L2 cache per core, and one shared 60 MB buffer of L3 cache. - // With two strings, we may consider the overall workload huge, if each exceeds 1 MB in length. - int const is_huge = length >= 1ull * 1024ull * 1024ull; - sz_u512_vec_t text_vec, sums_vec; - - // When the buffer is small, there isn't much to innovate. - if (length <= 16) { - __mmask16 mask = _sz_u16_mask_until(length); - text_vec.xmms[0] = _mm_maskz_loadu_epi8(mask, text); - sums_vec.xmms[0] = _mm_sad_epu8(text_vec.xmms[0], _mm_setzero_si128()); - sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_vec.xmms[0]); - sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_vec.xmms[0], 1); - return low + high; - } - else if (length <= 32) { - __mmask32 mask = _sz_u32_mask_until(length); - text_vec.ymms[0] = _mm256_maskz_loadu_epi8(mask, text); - sums_vec.ymms[0] = _mm256_sad_epu8(text_vec.ymms[0], _mm256_setzero_si256()); - // Accumulating 256 bits is harders, as we need to extract the 128-bit sums first. 
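Both checksum kernels being removed here lean on `_mm256_sad_epu8`, which sums each group of eight unsigned bytes against zero into a 64-bit lane, so a 32-byte chunk collapses to four partial sums that are only folded together at the very end. A self-contained sketch of that pattern for an AVX2 build (assuming `-mavx2` or equivalent; this is a simplified model without the cache-size and alignment branching of the real code):

```cpp
#include <immintrin.h> // build with AVX2 enabled, e.g. -mavx2
#include <cstddef>
#include <cstdint>

// Sum all bytes of `data` as unsigned values; the last <32 bytes are handled in scalar code.
std::uint64_t bytesum_avx2(unsigned char const *data, std::size_t length) {
    __m256i sums = _mm256_setzero_si256();
    std::size_t i = 0;
    for (; i + 32 <= length; i += 32) {
        __m256i chunk = _mm256_loadu_si256((__m256i const *)(data + i));
        // SAD against zero yields four 64-bit partial sums of eight bytes each.
        sums = _mm256_add_epi64(sums, _mm256_sad_epu8(chunk, _mm256_setzero_si256()));
    }
    // Fold the two 128-bit halves, then the two remaining 64-bit lanes.
    __m128i low = _mm256_castsi256_si128(sums);
    __m128i high = _mm256_extracti128_si256(sums, 1);
    __m128i both = _mm_add_epi64(low, high);
    std::uint64_t result = (std::uint64_t)_mm_cvtsi128_si64(both) + (std::uint64_t)_mm_extract_epi64(both, 1);
    for (; i < length; ++i) result += data[i]; // scalar tail
    return result;
}
```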
- __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymms[0]); - __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymms[0], 1); - __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm); - sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm); - sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1); - return low + high; - } - else if (length <= 64) { - __mmask64 mask = _sz_u64_mask_until(length); - text_vec.zmm = _mm512_maskz_loadu_epi8(mask, text); - sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()); - return _mm512_reduce_add_epi64(sums_vec.zmm); - } - else if (!is_huge) { - sz_size_t head_length = (64 - ((sz_size_t)text % 64)) % 64; // 63 or less. - sz_size_t tail_length = (sz_size_t)(text + length) % 64; // 63 or less. - sz_size_t body_length = length - head_length - tail_length; // Multiple of 64. - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - text_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, text); - sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()); - for (text += head_length; body_length >= 64; text += 64, body_length -= 64) { - text_vec.zmm = _mm512_load_si512((__m512i const *)text); - sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512())); - } - text_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, text); - sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512())); - return _mm512_reduce_add_epi64(sums_vec.zmm); - } - // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use. - // - // 1. Moving in both directions to maximize the throughput, when fetching from multiple - // memory pages. Also helps with cache set-associativity issues, as we won't always - // be fetching the same entries in the lookup table. - // 2. Using non-temporal stores to avoid polluting the cache. - // 3. Prefetching the next cache line, to avoid stalling the CPU. This generally useless - // for predictable patterns, so disregard this advice. - // - // Bidirectional traversal generally adds about 10% to such algorithms. - else { - sz_u512_vec_t text_reversed_vec, sums_reversed_vec; - sz_size_t head_length = (64 - ((sz_size_t)text % 64)) % 64; - sz_size_t tail_length = (sz_size_t)(text + length) % 64; - sz_size_t body_length = length - head_length - tail_length; - __mmask64 head_mask = _sz_u64_mask_until(head_length); - __mmask64 tail_mask = _sz_u64_mask_until(tail_length); - - text_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, text); - sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()); - text_reversed_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, text + head_length + body_length); - sums_reversed_vec.zmm = _mm512_sad_epu8(text_reversed_vec.zmm, _mm512_setzero_si512()); - - // Now in the main loop, we can use non-temporal loads and stores, - // performing the operation in both directions. 
- for (text += head_length; body_length >= 128; text += 64, text += 64, body_length -= 128) { - text_vec.zmm = _mm512_stream_load_si512((__m512i *)(text)); - sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512())); - text_reversed_vec.zmm = _mm512_stream_load_si512((__m512i *)(text + body_length - 64)); - sums_reversed_vec.zmm = - _mm512_add_epi64(sums_reversed_vec.zmm, _mm512_sad_epu8(text_reversed_vec.zmm, _mm512_setzero_si512())); - } - if (body_length >= 64) { - text_vec.zmm = _mm512_stream_load_si512((__m512i *)(text)); - sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512())); - } - - return _mm512_reduce_add_epi64(_mm512_add_epi64(sums_vec.zmm, sums_reversed_vec.zmm)); - } -} - SZ_PUBLIC void sz_hashes_ice(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // sz_hash_callback_t callback, void *callback_handle) { @@ -875,24 +524,6 @@ SZ_PUBLIC void sz_hashes_ice(sz_cptr_t start, sz_size_t length, sz_size_t window #pragma GCC target("arch=armv8.2-a+simd") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd"))), apply_to = function) -SZ_PUBLIC sz_u64_t sz_checksum_neon(sz_cptr_t text, sz_size_t length) { - uint64x2_t sum_vec = vdupq_n_u64(0); - - // Process 16 bytes (128 bits) at a time - for (; length >= 16; text += 16, length -= 16) { - uint8x16_t vec = vld1q_u8((sz_u8_t const *)text); // Load 16 bytes - uint16x8_t pairwise_sum1 = vpaddlq_u8(vec); // Pairwise add lower and upper 8 bits - uint32x4_t pairwise_sum2 = vpaddlq_u16(pairwise_sum1); // Pairwise add 16-bit results - uint64x2_t pairwise_sum3 = vpaddlq_u32(pairwise_sum2); // Pairwise add 32-bit results - sum_vec = vaddq_u64(sum_vec, pairwise_sum3); // Accumulate the sum - } - - // Final reduction of `sum_vec` to a single scalar - sz_u64_t sum = vgetq_lane_u64(sum_vec, 0) + vgetq_lane_u64(sum_vec, 1); - if (length) sum += sz_checksum_serial(text, length); - return sum; -} - #pragma clang attribute pop #pragma GCC pop_options #endif // SZ_USE_NEON @@ -918,18 +549,6 @@ SZ_PUBLIC sz_u64_t sz_checksum_neon(sz_cptr_t text, sz_size_t length) { #pragma region Compile Time Dispatching #if !SZ_DYNAMIC_DISPATCH -SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length) { -#if SZ_USE_ICE - return sz_checksum_ice(text, length); -#elif SZ_USE_HASWELL - return sz_checksum_haswell(text, length); -#elif SZ_USE_NEON - return sz_checksum_neon(text, length); -#else - return sz_checksum_serial(text, length); -#endif -} - SZ_DYNAMIC void sz_hashes(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // sz_hash_callback_t callback, void *callback_handle) { #if SZ_USE_ICE @@ -941,11 +560,6 @@ SZ_DYNAMIC void sz_hashes(sz_cptr_t text, sz_size_t length, sz_size_t window_len #endif } -SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length, - sz_random_generator_t generator, void *generator_user_data) { - sz_generate_serial(alphabet, alphabet_size, result, result_length, generator, generator_user_data); -} - #endif // !SZ_DYNAMIC_DISPATCH #pragma endregion // Compile Time Dispatching diff --git a/include/stringzilla/hash.h b/include/stringzilla/hash.h index 262cbdc9..4afe9572 100644 --- a/include/stringzilla/hash.h +++ b/include/stringzilla/hash.h @@ -52,78 +52,6 @@ SZ_PUBLIC sz_u64_t sz_hash(sz_cptr_t text, sz_size_t length) { return 0; } -/** - * @brief Computes the Karp-Rabin rolling hashes of a string supplying them to the 
provided `callback`. - * Can be used for similarity scores, search, ranking, etc. - * - * Rabin-Karp-like rolling hashes can have very high-level of collisions and depend - * on the choice of bases and the prime number. That's why, often two hashes from the same - * family are used with different bases. - * - * 1. Kernighan and Ritchie's function uses 31, a prime close to the size of English alphabet. - * 2. To be friendlier to byte-arrays and UTF8, we use 257 for the second function. - * - * Choosing the right ::window_length is task- and domain-dependant. For example, most English words are - * between 3 and 7 characters long, so a window of 4 bytes would be a good choice. For DNA sequences, - * the ::window_length might be a multiple of 3, as the codons are 3 (nucleotides) bytes long. - * With such minimalistic alphabets of just four characters (AGCT) longer windows might be needed. - * For protein sequences the alphabet is 20 characters long, so the window can be shorter, than for DNAs. - * - * @param text String to hash. - * @param length Number of bytes in the string. - * @param window_length Length of the rolling window in bytes. - * @param window_step Step of reported hashes. @b Must be power of two. Should be smaller than `window_length`. - * @param callback Function receiving the start & length of a substring, the hash, and the `callback_handle`. - * @param callback_handle Optional user-provided pointer to be passed to the `callback`. - * @see sz_hashes_fingerprint, sz_hashes_intersection - */ -SZ_DYNAMIC void sz_hashes( // - sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // - sz_hash_callback_t callback, void *callback_handle); - -/** - * @brief Computes the Karp-Rabin rolling hashes of a string outputting a binary fingerprint. - * Such fingerprints can be compared with Hamming or Jaccard (Tanimoto) distance for similarity. - * - * The algorithm doesn't clear the fingerprint buffer on start, so it can be invoked multiple times - * to produce a fingerprint of a longer string, by passing the previous fingerprint as the ::fingerprint. - * It can also be reused to produce multi-resolution fingerprints by changing the ::window_length - * and calling the same function multiple times for the same input ::text. - * - * Processes large strings in parts to maximize the cache utilization, using a small on-stack buffer, - * avoiding cache-coherency penalties of remote on-heap buffers. - * - * @param text String to hash. - * @param length Number of bytes in the string. - * @param fingerprint Output fingerprint buffer. - * @param fingerprint_bytes Number of bytes in the fingerprint buffer. - * @param window_length Length of the rolling window in bytes. - * @see sz_hashes, sz_hashes_intersection - */ -SZ_PUBLIC void sz_hashes_fingerprint( // - sz_cptr_t text, sz_size_t length, sz_size_t window_length, // - sz_ptr_t fingerprint, sz_size_t fingerprint_bytes) { - sz_unused(text && length && window_length && fingerprint && fingerprint_bytes); -} - -/** - * @brief Given a hash-fingerprint of a textual document, computes the number of intersecting hashes - * of the incoming document. Can be used for document scoring and search. - * - * Processes large strings in parts to maximize the cache utilization, using a small on-stack buffer, - * avoiding cache-coherency penalties of remote on-heap buffers. - * - * @param text Input document. - * @param length Number of bytes in the input document. - * @param fingerprint Reference document fingerprint. 
- * @param fingerprint_bytes Number of bytes in the reference documents fingerprint. - * @param window_length Length of the rolling window in bytes. - * @see sz_hashes, sz_hashes_fingerprint - */ -SZ_PUBLIC sz_size_t sz_hashes_intersection( // - sz_cptr_t text, sz_size_t length, sz_size_t window_length, // - sz_cptr_t fingerprint, sz_size_t fingerprint_bytes); - /** * @brief Generates a random string for a given alphabet, avoiding integer division and modulo operations. * Similar to `text[i] = alphabet[rand() % cardinality]`. @@ -299,78 +227,6 @@ SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t start, sz_size_t length) { return _sz_hash_mix(hash_low, hash_high); } -SZ_PUBLIC void sz_hashes_serial(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle) { - - if (length < window_length || !window_length) return; - sz_u8_t const *text = (sz_u8_t const *)start; - sz_u8_t const *text_end = text + length; - - // Prepare the `prime ^ window_length` values, that we are going to use for modulo arithmetic. - sz_u64_t prime_power_low = 1, prime_power_high = 1; - for (sz_size_t i = 0; i + 1 < window_length; ++i) - prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME, - prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME; - - // Compute the initial hash value for the first window. - sz_u64_t hash_low = 0, hash_high = 0, hash_mix; - for (sz_u8_t const *first_end = text + window_length; text < first_end; ++text) - hash_low = (hash_low * 31ull + _sz_shift_low(*text)) % SZ_U64_MAX_PRIME, - hash_high = (hash_high * 257ull + _sz_shift_high(*text)) % SZ_U64_MAX_PRIME; - - // In most cases the fingerprint length will be a power of two. - hash_mix = _sz_hash_mix(hash_low, hash_high); - callback((sz_cptr_t)text, window_length, hash_mix, callback_handle); - - // Compute the hash value for every window, exporting into the fingerprint, - // using the expensive modulo operation. - sz_size_t cycles = 1; - sz_size_t const step_mask = step - 1; - for (; text < text_end; ++text, ++cycles) { - // Discard one character: - hash_low -= _sz_shift_low(*(text - window_length)) * prime_power_low; - hash_high -= _sz_shift_high(*(text - window_length)) * prime_power_high; - // And add a new one: - hash_low = 31ull * hash_low + _sz_shift_low(*text); - hash_high = 257ull * hash_high + _sz_shift_high(*text); - // Wrap the hashes around: - hash_low = _sz_prime_mod(hash_low); - hash_high = _sz_prime_mod(hash_high); - // Mix only if we've skipped enough hashes. - if ((cycles & step_mask) == 0) { - hash_mix = _sz_hash_mix(hash_low, hash_high); - callback((sz_cptr_t)text, window_length, hash_mix, callback_handle); - } - } -} - -/** @brief An internal callback used to set a bit in a power-of-two length binary fingerprint of a string. */ -SZ_INTERNAL void _sz_hashes_fingerprint_pow2_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, void *handle) { - sz_string_view_t *fingerprint_buffer = (sz_string_view_t *)handle; - sz_u8_t *fingerprint_u8s = (sz_u8_t *)fingerprint_buffer->start; - sz_size_t fingerprint_bytes = fingerprint_buffer->length; - fingerprint_u8s[(hash / 8) & (fingerprint_bytes - 1)] |= (1 << (hash & 7)); - sz_unused(start && length); -} - -/** @brief An internal callback used to set a bit in a @b non power-of-two length binary fingerprint of a string. 
*/ -SZ_INTERNAL void _sz_hashes_fingerprint_non_pow2_callback( // - sz_cptr_t start, sz_size_t length, sz_u64_t hash, void *handle) { - sz_string_view_t *fingerprint_buffer = (sz_string_view_t *)handle; - sz_u8_t *fingerprint_u8s = (sz_u8_t *)fingerprint_buffer->start; - sz_size_t fingerprint_bytes = fingerprint_buffer->length; - fingerprint_u8s[(hash / 8) % fingerprint_bytes] |= (1 << (hash & 7)); - sz_unused(start && length); -} - -/** @brief An internal callback, used to mix all the running hashes into one pointer-size value. */ -SZ_INTERNAL void _sz_hashes_fingerprint_scalar_callback( // - sz_cptr_t start, sz_size_t length, sz_u64_t hash, void *scalar_handle) { - sz_unused(start && length && hash && scalar_handle); - sz_size_t *scalar_ptr = (sz_size_t *)scalar_handle; - *scalar_ptr ^= hash; -} - #undef _sz_shift_low #undef _sz_shift_high #undef _sz_hash_mix @@ -467,147 +323,6 @@ SZ_PUBLIC sz_u64_t sz_checksum_haswell(sz_cptr_t text, sz_size_t length) { } } -/** - * @brief There is no AVX2 instruction for fast multiplication of 64-bit integers. - * This implementation is coming from Agner Fog's Vector Class Library. - */ -SZ_INTERNAL __m256i _mm256_mul_epu64(__m256i a, __m256i b) { - __m256i bswap = _mm256_shuffle_epi32(b, 0xB1); - __m256i prodlh = _mm256_mullo_epi32(a, bswap); - __m256i zero = _mm256_setzero_si256(); - __m256i prodlh2 = _mm256_hadd_epi32(prodlh, zero); - __m256i prodlh3 = _mm256_shuffle_epi32(prodlh2, 0x73); - __m256i prodll = _mm256_mul_epu32(a, b); - __m256i prod = _mm256_add_epi64(prodll, prodlh3); - return prod; -} - -SZ_PUBLIC void sz_hashes_haswell(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle) { - - if (length < window_length || !window_length) return; - if (length < 4 * window_length) { - sz_hashes_serial(start, length, window_length, step, callback, callback_handle); - return; - } - - // Using AVX2, we can perform 4 long integer multiplications and additions within one register. - // So let's slice the entire string into 4 overlapping windows, to slide over them in parallel. - sz_size_t const max_hashes = length - window_length + 1; - sz_size_t const min_hashes_per_thread = max_hashes / 4; // At most one sequence can overlap between 2 threads. - sz_u8_t const *text_first = (sz_u8_t const *)start; - sz_u8_t const *text_second = text_first + min_hashes_per_thread; - sz_u8_t const *text_third = text_first + min_hashes_per_thread * 2; - sz_u8_t const *text_fourth = text_first + min_hashes_per_thread * 3; - sz_u8_t const *text_end = text_first + length; - - // Prepare the `prime ^ window_length` values, that we are going to use for modulo arithmetic. - sz_u64_t prime_power_low = 1, prime_power_high = 1; - for (sz_size_t i = 0; i + 1 < window_length; ++i) - prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME, - prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME; - - // Broadcast the constants into the registers. 
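The `_mm256_mul_epu64` helper removed in this hunk exists because AVX2 has no 64-bit lane multiply, so the product is rebuilt from 32-bit halves following Agner Fog's Vector Class Library. A scalar model of the identity each lane computes (the high×high term falls entirely outside the low 64 bits and is dropped):

```cpp
#include <cassert>
#include <cstdint>

// Reassemble a*b (mod 2^64) from 32-bit halves, mirroring what the vectorized
// shuffle / mullo / hadd sequence computes lane by lane.
std::uint64_t mul_from_halves(std::uint64_t a, std::uint64_t b) {
    std::uint64_t a_lo = a & 0xFFFFFFFFull, a_hi = a >> 32;
    std::uint64_t b_lo = b & 0xFFFFFFFFull, b_hi = b >> 32;
    std::uint64_t low = a_lo * b_lo;                 // the `_mm256_mul_epu32` contribution
    std::uint64_t cross = a_lo * b_hi + a_hi * b_lo; // the two cross terms
    return low + (cross << 32);                      // a_hi * b_hi vanishes modulo 2^64
}

int main() {
    assert(mul_from_halves(31ull, 11400714819323198485ull) == 31ull * 11400714819323198485ull);
    assert(mul_from_halves(0xDEADBEEFCAFEBABEull, 257ull) == 0xDEADBEEFCAFEBABEull * 257ull);
    return 0;
}
```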
- sz_u256_vec_t prime_vec, golden_ratio_vec; - sz_u256_vec_t base_low_vec, base_high_vec, prime_power_low_vec, prime_power_high_vec, shift_high_vec; - base_low_vec.ymm = _mm256_set1_epi64x(31ull); - base_high_vec.ymm = _mm256_set1_epi64x(257ull); - shift_high_vec.ymm = _mm256_set1_epi64x(77ull); - prime_vec.ymm = _mm256_set1_epi64x(SZ_U64_MAX_PRIME); - golden_ratio_vec.ymm = _mm256_set1_epi64x(11400714819323198485ull); - prime_power_low_vec.ymm = _mm256_set1_epi64x(prime_power_low); - prime_power_high_vec.ymm = _mm256_set1_epi64x(prime_power_high); - - // Compute the initial hash values for every one of the four windows. - sz_u256_vec_t hash_low_vec, hash_high_vec, hash_mix_vec, chars_low_vec, chars_high_vec; - hash_low_vec.ymm = _mm256_setzero_si256(); - hash_high_vec.ymm = _mm256_setzero_si256(); - for (sz_u8_t const *prefix_end = text_first + window_length; text_first < prefix_end; - ++text_first, ++text_second, ++text_third, ++text_fourth) { - - // 1. Multiply the hashes by the base. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, base_low_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, base_high_vec.ymm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. - chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); - - // 3. Add the incoming characters. - hash_low_vec.ymm = _mm256_add_epi64(hash_low_vec.ymm, chars_low_vec.ymm); - hash_high_vec.ymm = _mm256_add_epi64(hash_high_vec.ymm, chars_high_vec.ymm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_low_vec.ymm = _mm256_blendv_epi8( // - hash_low_vec.ymm, _mm256_sub_epi64(hash_low_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_low_vec.ymm, prime_vec.ymm)); - hash_high_vec.ymm = _mm256_blendv_epi8( // - hash_high_vec.ymm, _mm256_sub_epi64(hash_high_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_high_vec.ymm, prime_vec.ymm)); - } - - // 5. Compute the hash mix, that will be used to index into the fingerprint. - // This includes a serial step at the end. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, golden_ratio_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, golden_ratio_vec.ymm); - hash_mix_vec.ymm = _mm256_xor_si256(hash_low_vec.ymm, hash_high_vec.ymm); - callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); - callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); - callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); - callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); - - // Now repeat that operation for the remaining characters, discarding older characters. - sz_size_t cycle = 1; - sz_size_t const step_mask = step - 1; - for (; text_fourth != text_end; ++text_first, ++text_second, ++text_third, ++text_fourth, ++cycle) { - // 0. Load again the four characters we are dropping, shift them, and subtract. 
- chars_low_vec.ymm = _mm256_set_epi64x( // - text_fourth[-window_length], text_third[-window_length], text_second[-window_length], - text_first[-window_length]); - chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); - hash_low_vec.ymm = - _mm256_sub_epi64(hash_low_vec.ymm, _mm256_mul_epu64(chars_low_vec.ymm, prime_power_low_vec.ymm)); - hash_high_vec.ymm = - _mm256_sub_epi64(hash_high_vec.ymm, _mm256_mul_epu64(chars_high_vec.ymm, prime_power_high_vec.ymm)); - - // 1. Multiply the hashes by the base. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, base_low_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, base_high_vec.ymm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. - chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); - - // 3. Add the incoming characters. - hash_low_vec.ymm = _mm256_add_epi64(hash_low_vec.ymm, chars_low_vec.ymm); - hash_high_vec.ymm = _mm256_add_epi64(hash_high_vec.ymm, chars_high_vec.ymm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_low_vec.ymm = _mm256_blendv_epi8( // - hash_low_vec.ymm, _mm256_sub_epi64(hash_low_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_low_vec.ymm, prime_vec.ymm)); - hash_high_vec.ymm = _mm256_blendv_epi8( // - hash_high_vec.ymm, _mm256_sub_epi64(hash_high_vec.ymm, prime_vec.ymm), - _mm256_cmpgt_epi64(hash_high_vec.ymm, prime_vec.ymm)); - - // 5. Compute the hash mix, that will be used to index into the fingerprint. - // This includes a serial step at the end. - hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, golden_ratio_vec.ymm); - hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, golden_ratio_vec.ymm); - hash_mix_vec.ymm = _mm256_xor_si256(hash_low_vec.ymm, hash_high_vec.ymm); - if ((cycle & step_mask) == 0) { - callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); - callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); - callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); - callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); - } - } -} - #pragma clang attribute pop #pragma GCC pop_options #endif // SZ_USE_HASWELL @@ -736,131 +451,6 @@ SZ_PUBLIC sz_u64_t sz_checksum_ice(sz_cptr_t text, sz_size_t length) { } } -SZ_PUBLIC void sz_hashes_ice(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle) { - - if (length < window_length || !window_length) return; - if (length < 4 * window_length) { - sz_hashes_serial(start, length, window_length, step, callback, callback_handle); - return; - } - - // Using AVX2, we can perform 4 long integer multiplications and additions within one register. - // So let's slice the entire string into 4 overlapping windows, to slide over them in parallel. - sz_size_t const max_hashes = length - window_length + 1; - sz_size_t const min_hashes_per_thread = max_hashes / 4; // At most one sequence can overlap between 2 threads. 
- sz_u8_t const *text_first = (sz_u8_t const *)start; - sz_u8_t const *text_second = text_first + min_hashes_per_thread; - sz_u8_t const *text_third = text_first + min_hashes_per_thread * 2; - sz_u8_t const *text_fourth = text_first + min_hashes_per_thread * 3; - sz_u8_t const *text_end = text_first + length; - - // Broadcast the global constants into the registers. - // Both high and low hashes will work with the same prime and golden ratio. - sz_u512_vec_t prime_vec, golden_ratio_vec; - prime_vec.zmm = _mm512_set1_epi64(SZ_U64_MAX_PRIME); - golden_ratio_vec.zmm = _mm512_set1_epi64(11400714819323198485ull); - - // Prepare the `prime ^ window_length` values, that we are going to use for modulo arithmetic. - sz_u64_t prime_power_low = 1, prime_power_high = 1; - for (sz_size_t i = 0; i + 1 < window_length; ++i) - prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME, - prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME; - - // We will be evaluating 4 offsets at a time with 2 different hash functions. - // We can fit all those 8 state variables in each of the following ZMM registers. - sz_u512_vec_t base_vec, prime_power_vec, shift_vec; - base_vec.zmm = _mm512_set_epi64(31ull, 31ull, 31ull, 31ull, 257ull, 257ull, 257ull, 257ull); - shift_vec.zmm = _mm512_set_epi64(0ull, 0ull, 0ull, 0ull, 77ull, 77ull, 77ull, 77ull); - prime_power_vec.zmm = _mm512_set_epi64(prime_power_low, prime_power_low, prime_power_low, prime_power_low, - prime_power_high, prime_power_high, prime_power_high, prime_power_high); - - // Compute the initial hash values for every one of the four windows. - sz_u512_vec_t hash_vec, chars_vec; - hash_vec.zmm = _mm512_setzero_si512(); - for (sz_u8_t const *prefix_end = text_first + window_length; text_first < prefix_end; - ++text_first, ++text_second, ++text_third, ++text_fourth) { - - // 1. Multiply the hashes by the base. - hash_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, base_vec.zmm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`... - chars_vec.zmm = _mm512_set_epi64(text_fourth[0], text_third[0], text_second[0], text_first[0], // - text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm); - - // 3. Add the incoming characters. - hash_vec.zmm = _mm512_add_epi64(hash_vec.zmm, chars_vec.zmm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_vec.zmm = _mm512_mask_blend_epi8(_mm512_cmpgt_epi64_mask(hash_vec.zmm, prime_vec.zmm), hash_vec.zmm, - _mm512_sub_epi64(hash_vec.zmm, prime_vec.zmm)); - } - - // 5. Compute the hash mix, that will be used to index into the fingerprint. - // This includes a serial step at the end. 
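The step-4 comments above subtract the prime at most once per update, via `_mm256_blendv_epi8` / `_mm512_mask_blend_epi8` with a compare mask, instead of issuing a full modulo; that shortcut is exact whenever the running value is below twice the modulus. A scalar sketch of the reduction identity being exploited, with a tiny stand-in prime:

```cpp
#include <cassert>
#include <cstdint>

// For 0 <= value < 2 * modulus a single conditional subtraction is a complete reduction.
std::uint64_t reduce_once(std::uint64_t value, std::uint64_t modulus) {
    return value >= modulus ? value - modulus : value;
}

int main() {
    constexpr std::uint64_t prime = 59; // illustrative; the real code uses a prime near 2^64
    for (std::uint64_t value = 0; value < 2 * prime; ++value)
        assert(reduce_once(value, prime) == value % prime);
    return 0;
}
```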
- sz_u512_vec_t hash_mix_vec; - hash_mix_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, golden_ratio_vec.zmm); - hash_mix_vec.ymms[0] = _mm256_xor_si256(_mm512_extracti64x4_epi64(hash_mix_vec.zmm, 1), // - _mm512_extracti64x4_epi64(hash_mix_vec.zmm, 0)); - - callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); - callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); - callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); - callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); - - // Now repeat that operation for the remaining characters, discarding older characters. - sz_size_t cycle = 1; - sz_size_t step_mask = step - 1; - for (; text_fourth != text_end; ++text_first, ++text_second, ++text_third, ++text_fourth, ++cycle) { - // 0. Load again the four characters we are dropping, shift them, and subtract. - chars_vec.zmm = _mm512_set_epi64(text_fourth[-window_length], text_third[-window_length], - text_second[-window_length], text_first[-window_length], // - text_fourth[-window_length], text_third[-window_length], - text_second[-window_length], text_first[-window_length]); - chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm); - hash_vec.zmm = _mm512_sub_epi64(hash_vec.zmm, _mm512_mullo_epi64(chars_vec.zmm, prime_power_vec.zmm)); - - // 1. Multiply the hashes by the base. - hash_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, base_vec.zmm); - - // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, - // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. - chars_vec.zmm = _mm512_set_epi64(text_fourth[0], text_third[0], text_second[0], text_first[0], // - text_fourth[0], text_third[0], text_second[0], text_first[0]); - chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm); - - // ... and prefetch the next four characters into Level 2 or higher. - _mm_prefetch((sz_cptr_t)text_fourth + 1, _MM_HINT_T1); - _mm_prefetch((sz_cptr_t)text_third + 1, _MM_HINT_T1); - _mm_prefetch((sz_cptr_t)text_second + 1, _MM_HINT_T1); - _mm_prefetch((sz_cptr_t)text_first + 1, _MM_HINT_T1); - - // 3. Add the incoming characters. - hash_vec.zmm = _mm512_add_epi64(hash_vec.zmm, chars_vec.zmm); - - // 4. Compute the modulo. Assuming there are only 59 values between our prime - // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. - hash_vec.zmm = _mm512_mask_blend_epi8(_mm512_cmpgt_epi64_mask(hash_vec.zmm, prime_vec.zmm), hash_vec.zmm, - _mm512_sub_epi64(hash_vec.zmm, prime_vec.zmm)); - - // 5. Compute the hash mix, that will be used to index into the fingerprint. - // This includes a serial step at the end. 
- hash_mix_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, golden_ratio_vec.zmm); - hash_mix_vec.ymms[0] = _mm256_xor_si256(_mm512_extracti64x4_epi64(hash_mix_vec.zmm, 1), // - _mm512_castsi512_si256(hash_mix_vec.zmm)); - - if ((cycle & step_mask) == 0) { - callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); - callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); - callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); - callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); - } - } -} - #pragma clang attribute pop #pragma GCC pop_options #endif // SZ_USE_ICE @@ -930,17 +520,6 @@ SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length) { #endif } -SZ_DYNAMIC void sz_hashes(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // - sz_hash_callback_t callback, void *callback_handle) { -#if SZ_USE_ICE - sz_hashes_ice(text, length, window_length, window_step, callback, callback_handle); -#elif SZ_USE_HASWELL - sz_hashes_haswell(text, length, window_length, window_step, callback, callback_handle); -#else - sz_hashes_serial(text, length, window_length, window_step, callback, callback_handle); -#endif -} - SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length, sz_random_generator_t generator, void *generator_user_data) { sz_generate_serial(alphabet, alphabet_size, result, result_length, generator, generator_user_data); diff --git a/include/stringzilla/stringzilla.hpp b/include/stringzilla/stringzilla.hpp index 0f0aaef9..af677aad 100644 --- a/include/stringzilla/stringzilla.hpp +++ b/include/stringzilla/stringzilla.hpp @@ -4019,7 +4019,7 @@ void sorted_order(objects_type_ const *begin, objects_type_ const *end, sorted_i } #if !SZ_AVOID_STL - +#if _SZ_DEPRECATED_FINGERPRINTS /** * @brief Computes the Rabin-Karp-like rolling binary fingerprint of a string. * @see sz_hashes @@ -4052,6 +4052,7 @@ template std::bitset hashes_fingerprint(basic_string const &str, std::size_t window_length) noexcept { return ashvardanian::stringzilla::hashes_fingerprint(str.view(), window_length); } +#endif /** * @brief Computes the permutation of an array, that would lead to sorted order. diff --git a/scripts/test.cpp b/scripts/test.cpp index dc13a656..2bf886d1 100644 --- a/scripts/test.cpp +++ b/scripts/test.cpp @@ -133,53 +133,6 @@ static void test_arithmetical_utilities() { #endif } -/** - * @brief Tests various ASCII-based methods (e.g., `is_alpha`, `is_digit`) - * provided by `sz::string` and `sz::string_view`. 
- */ -template -static void test_ascii_utilities() { - - using str = string_type; - - assert(!str("").is_alpha()); - assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ").is_alpha()); - assert(!str("abc9").is_alpha()); - - assert(!str("").is_alnum()); - assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789").is_alnum()); - assert(!str("abc!").is_alnum()); - - assert(str("").is_ascii()); - assert(str("\x00x7F").is_ascii()); - assert(!str("abc123🔥").is_ascii()); - - assert(!str("").is_digit()); - assert(str("0123456789").is_digit()); - assert(!str("012a").is_digit()); - - assert(!str("").is_lower()); - assert(str("abcdefghijklmnopqrstuvwxyz").is_lower()); - assert(!str("abcA").is_lower()); - assert(!str("abc\n").is_lower()); - - assert(!str("").is_space()); - assert(str(" \t\n\r\f\v").is_space()); - assert(!str(" \t\r\na").is_space()); - - assert(!str("").is_upper()); - assert(str("ABCDEFGHIJKLMNOPQRSTUVWXYZ").is_upper()); - assert(!str("ABCa").is_upper()); - - assert(str("").is_printable()); - assert(str("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!@#$%^&*()_+").is_printable()); - assert(!str("012🔥").is_printable()); - - assert(str("").contains_only(sz::char_set("abc"))); - assert(str("abc").contains_only(sz::char_set("abc"))); - assert(!str("abcd").contains_only(sz::char_set("abc"))); -} - /** * @brief Tests various ASCII-based methods (e.g., `is_alpha`, `is_digit`) * provided by `sz::string` and `sz::string_view`. @@ -892,6 +845,8 @@ static void test_non_stl_extensions_for_reads() { assert(str("abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz").checksum() == arithmetic_sum('a', 'z') * 3); +#if _SZ_DEPRECATED_FINGERPRINTS + // Computing rolling fingerprints. assert(sz::hashes_fingerprint<512>(str("aaaa"), 3).count() == 1); assert(sz::hashes_fingerprint<512>(str("hello"), 4).count() == 2); @@ -903,7 +858,7 @@ static void test_non_stl_extensions_for_reads() { assert(sz::hashes_fingerprint<512>(str("aaa"), 3).count() == 1); assert(sz::hashes_fingerprint<512>(str("aaaa"), 3).count() == 1); assert(sz::hashes_fingerprint<512>(str("aaaaa"), 3).count() == 1); - +#endif // Computing fuzzy search results. 
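The fingerprint helpers removed from `hash.h` (and the `hashes_fingerprint` tests now guarded above) set one bit per rolling hash: for a power-of-two buffer the byte index is `(hash / 8) & (bytes - 1)`, the bit within it is `hash & 7`, and two documents can then be compared by the popcount of the AND of their bitmaps. A small sketch of that indexing; the helper names here are invented, not the library API:

```cpp
#include <bitset>
#include <cstdint>
#include <cstdio>
#include <vector>

constexpr std::size_t fingerprint_bytes = 64; // must stay a power of two for the `&` trick

void set_hash_bit(std::vector<std::uint8_t> &fingerprint, std::uint64_t hash) {
    fingerprint[(hash / 8) & (fingerprint.size() - 1)] |= (std::uint8_t)(1u << (hash & 7u));
}

std::size_t intersection_popcount(std::vector<std::uint8_t> const &a, std::vector<std::uint8_t> const &b) {
    std::size_t count = 0;
    for (std::size_t i = 0; i != a.size(); ++i) count += std::bitset<8>(a[i] & b[i]).count();
    return count;
}

int main() {
    std::vector<std::uint8_t> first(fingerprint_bytes, 0), second(fingerprint_bytes, 0);
    for (std::uint64_t hash : {42ull, 1337ull, 98765ull}) set_hash_bit(first, hash);
    for (std::uint64_t hash : {42ull, 1337ull, 11111ull}) set_hash_bit(second, hash);
    std::printf("shared bits: %zu\n", intersection_popcount(first, second));
    return 0;
}
```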
} From 1de3166344e817e91132371e17e1560809a55194 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Wed, 12 Feb 2025 23:47:28 +0000 Subject: [PATCH 076/751] Make: Move drafts --- CONTRIBUTING.md | 4 ++-- c/lib.c | 2 +- include/stringzilla/drafts.h => drafts/bitap.h | 0 {include/stringzilla => drafts}/fingerprint.h | 0 include/stringzilla/hash.h | 11 +++-------- scripts/bench_token.cpp | 4 ++++ 6 files changed, 10 insertions(+), 11 deletions(-) rename include/stringzilla/drafts.h => drafts/bitap.h (100%) rename {include/stringzilla => drafts}/fingerprint.h (100%) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 231291c8..dfb4fb2f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -129,8 +129,8 @@ Using modern syntax, this is how you build and run the test suite: ```bash cmake -D STRINGZILLA_BUILD_TEST=1 -D CMAKE_BUILD_TYPE=Debug -B build_debug -cmake --build build_debug --config Debug # Which will produce the following targets: -build_debug/stringzilla_test_cpp20 # Unit test for the entire library compiled for current hardware +cmake --build build_debug --config Debug # Which will produce the following targets: +build_debug/stringzilla_test_cpp20 # Unit test for the entire library compiled for current hardware build_debug/stringzilla_test_cpp20_serial # x86 variant compiled for IvyBridge - last arch. before AVX2 build_debug/stringzilla_test_cpp20_serial # Arm variant compiled without Neon ``` diff --git a/c/lib.c b/c/lib.c index d829e379..52a6ce7a 100644 --- a/c/lib.c +++ b/c/lib.c @@ -224,7 +224,7 @@ SZ_DYNAMIC void sz_dispatch_table_init(void) { impl->edit_distance = sz_edit_distance_serial; impl->alignment_score = sz_alignment_score_serial; - impl->hashes = sz_hashes_serial; + impl->hashes = 0; #if SZ_USE_HASWELL if (caps & sz_cap_haswell_k) { diff --git a/include/stringzilla/drafts.h b/drafts/bitap.h similarity index 100% rename from include/stringzilla/drafts.h rename to drafts/bitap.h diff --git a/include/stringzilla/fingerprint.h b/drafts/fingerprint.h similarity index 100% rename from include/stringzilla/fingerprint.h rename to drafts/fingerprint.h diff --git a/include/stringzilla/hash.h b/include/stringzilla/hash.h index 4afe9572..52b6d372 100644 --- a/include/stringzilla/hash.h +++ b/include/stringzilla/hash.h @@ -78,11 +78,6 @@ SZ_PUBLIC sz_u64_t sz_checksum_serial(sz_cptr_t text, sz_size_t length); /** @copydoc sz_hash */ SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t text, sz_size_t length); -/** @copydoc sz_hashes */ -SZ_PUBLIC void sz_hashes_serial( // - sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // - sz_hash_callback_t callback, void *callback_handle); - /** @copydoc sz_generate */ SZ_PUBLIC void sz_generate_serial( // sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length, sz_random_generator_t generate, @@ -261,7 +256,7 @@ SZ_PUBLIC sz_u64_t sz_checksum_haswell(sz_cptr_t text, sz_size_t length) { text_vec.ymm = _mm256_lddqu_si256((__m256i const *)text); sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); } - // Accumulating 256 bits is harders, as we need to extract the 128-bit sums first. + // Accumulating 256 bits is harder, as we need to extract the 128-bit sums first. 
__m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymm); __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymm, 1); __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm); @@ -291,7 +286,7 @@ SZ_PUBLIC sz_u64_t sz_checksum_haswell(sz_cptr_t text, sz_size_t length) { sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); } } - // When the biffer is huge, we can traverse it in 2 directions. + // When the buffer is huge, we can traverse it in 2 directions. else { sz_u256_vec_t text_reversed_vec, sums_reversed_vec; sums_reversed_vec.ymm = _mm256_setzero_si256(); @@ -312,7 +307,7 @@ SZ_PUBLIC sz_u64_t sz_checksum_haswell(sz_cptr_t text, sz_size_t length) { // Handle the tail while (tail_length--) result += *text++; - // Accumulating 256 bits is harders, as we need to extract the 128-bit sums first. + // Accumulating 256 bits is harder, as we need to extract the 128-bit sums first. __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymm); __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymm, 1); __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm); diff --git a/scripts/bench_token.cpp b/scripts/bench_token.cpp index eb82dfd4..35369ac0 100644 --- a/scripts/bench_token.cpp +++ b/scripts/bench_token.cpp @@ -46,6 +46,7 @@ tracked_unary_functions_t hashing_functions() { } tracked_unary_functions_t sliding_hashing_functions(std::size_t window_width, std::size_t step) { +#if _SZ_DEPRECATED_FINGERPRINTS auto wrap_sz = [=](auto function) -> unary_function_t { return unary_function_t([function, window_width, step](std::string_view s) { sz_size_t mixed_hash = 0; @@ -53,8 +54,10 @@ tracked_unary_functions_t sliding_hashing_functions(std::size_t window_width, st return mixed_hash; }); }; +#endif std::string suffix = std::to_string(window_width) + ":step" + std::to_string(step); tracked_unary_functions_t result = { +#if _SZ_DEPRECATED_FINGERPRINTS #if SZ_USE_ICE {"sz_hashes_ice:" + suffix, wrap_sz(sz_hashes_ice)}, #endif @@ -62,6 +65,7 @@ tracked_unary_functions_t sliding_hashing_functions(std::size_t window_width, st {"sz_hashes_haswell:" + suffix, wrap_sz(sz_hashes_haswell)}, #endif {"sz_hashes_serial:" + suffix, wrap_sz(sz_hashes_serial)}, +#endif }; return result; } From 0a3e363a4439db233758e944727efef72e373243 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Thu, 13 Feb 2025 13:21:45 +0000 Subject: [PATCH 077/751] Improve: Relax many `constexpr`s from C++20 to C++14 --- include/stringzilla/stringzilla.hpp | 33 ++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/include/stringzilla/stringzilla.hpp b/include/stringzilla/stringzilla.hpp index af677aad..126211c4 100644 --- a/include/stringzilla/stringzilla.hpp +++ b/include/stringzilla/stringzilla.hpp @@ -36,9 +36,20 @@ #define _SZ_IS_CPP98 (__cplusplus >= 199711L) /** - * @brief The `constexpr` keyword has different applicability scope in different C++ versions. + * @brief Expands to `constexpr` in C++20 and later, and to nothing in older C++ versions. * Useful for STL conversion operators, as several `std::string` members are `constexpr` in C++20. + * + * The `constexpr` keyword has different applicability scope in different C++ versions. + * - C++11: Introduced `constexpr`, but no loops or multiple `return` statements were allowed. + * - C++14: Allowed loops, multiple statements, and local variables in `constexpr` functions. + * - C++17: Added the `if constexpr` construct for compile-time branching. 
+ * - C++20: Added some dynamic memory allocations, `virtual` functions, and `try`/`catch` blocks. */ +#if _SZ_IS_CPP14 +#define sz_constexpr_if_cpp14 constexpr +#else +#define sz_constexpr_if_cpp14 +#endif #if _SZ_IS_CPP20 #define sz_constexpr_if_cpp20 constexpr #else @@ -277,12 +288,12 @@ class basic_char_set { // ! Instead of relying on the `sz_charset_init`, we have to reimplement it to support `constexpr`. bitset_._u64s[0] = 0, bitset_._u64s[1] = 0, bitset_._u64s[2] = 0, bitset_._u64s[3] = 0; } - explicit constexpr basic_char_set(std::initializer_list chars) noexcept : basic_char_set() { + explicit sz_constexpr_if_cpp14 basic_char_set(std::initializer_list chars) noexcept : basic_char_set() { // ! Instead of relying on the `sz_charset_add(&bitset_, c)`, we have to reimplement it to support `constexpr`. for (auto c : chars) bitset_._u64s[sz_bitcast(sz_u8_t, c) >> 6] |= (1ull << (sz_bitcast(sz_u8_t, c) & 63u)); } - explicit constexpr basic_char_set(char_type const *chars, std::size_t count_characters) noexcept + explicit sz_constexpr_if_cpp14 basic_char_set(char_type const *chars, std::size_t count_characters) noexcept : basic_char_set() { for (std::size_t i = 0; i < count_characters; ++i) { char_type c = chars[i]; @@ -291,7 +302,7 @@ class basic_char_set { } template - explicit constexpr basic_char_set(std::array const &chars) noexcept + explicit sz_constexpr_if_cpp14 basic_char_set(std::array const &chars) noexcept : basic_char_set() { static_assert(count_characters > 0, "Character array cannot be empty"); for (std::size_t i = 0; i < count_characters; ++i) { @@ -1232,8 +1243,8 @@ class basic_string_slice { : start_(c_string), length_(null_terminated_length(c_string)) {} constexpr basic_string_slice(pointer c_string, size_type length) noexcept : start_(c_string), length_(length) {} - sz_constexpr_if_cpp20 basic_string_slice(basic_string_slice const &other) noexcept = default; - sz_constexpr_if_cpp20 basic_string_slice &operator=(basic_string_slice const &other) noexcept = default; + constexpr basic_string_slice(basic_string_slice const &other) noexcept = default; + constexpr basic_string_slice &operator=(basic_string_slice const &other) noexcept = default; basic_string_slice(std::nullptr_t) = delete; /** @brief Exchanges the view with that of the `other`. */ @@ -1927,13 +1938,13 @@ class basic_string_slice { } private: - sz_constexpr_if_cpp20 string_slice &assign(string_view const &other) noexcept { + sz_constexpr_if_cpp14 string_slice &assign(string_view const &other) noexcept { start_ = (pointer)other.data(); length_ = other.size(); return *this; } - sz_constexpr_if_cpp20 static size_type null_terminated_length(const_pointer s) noexcept { + sz_constexpr_if_cpp14 static size_type null_terminated_length(const_pointer s) noexcept { const_pointer p = s; while (*p) ++p; return p - s; @@ -2080,7 +2091,7 @@ class basic_string { #pragma region Constructors and STL Utilities - sz_constexpr_if_cpp20 basic_string() noexcept { + sz_constexpr_if_cpp14 basic_string() noexcept { // ! Instead of relying on the `sz_string_init`, we have to reimplement it to support `constexpr`. 
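The new `sz_constexpr_if_cpp14` macro reflects the language history spelled out in the comment above: C++11 `constexpr` bodies could not contain loops or mutable locals, C++14 allowed them, so members like `null_terminated_length` need nothing newer than the C++14 level. A stand-alone sketch of the same gating pattern; the macro and function below are illustrative, not the header's:

```cpp
#include <cassert>

#if __cplusplus >= 201402L
#define my_constexpr_if_cpp14 constexpr
#else
#define my_constexpr_if_cpp14
#endif

// A loop and a mutable local in a `constexpr` body require C++14 or newer.
my_constexpr_if_cpp14 unsigned long long null_terminated_length(char const *s) {
    char const *p = s;
    while (*p) ++p;
    return (unsigned long long)(p - s);
}

int main() {
#if __cplusplus >= 201402L
    static_assert(null_terminated_length("hello") == 5, "evaluated at compile time in C++14 and newer");
#endif
    assert(null_terminated_length("hello") == 5); // still works as a plain run-time call in C++11
    return 0;
}
```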
string_.internal.start = &string_.internal.chars[0]; string_.words[1] = 0; @@ -3454,7 +3465,9 @@ static_assert(sizeof(string) == 4 * sizeof(void *), "String size must be 4 point namespace literals { constexpr string_view operator""_sv(char const *str, std::size_t length) noexcept { return {str, length}; } -constexpr char_set operator""_cs(char const *str, std::size_t length) noexcept { return char_set {str, length}; } +sz_constexpr_if_cpp14 char_set operator""_cs(char const *str, std::size_t length) noexcept { + return char_set {str, length}; +} } // namespace literals template From 554f50d5e3601c6ef2984e287e7a84f91eebf1ae Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Thu, 13 Feb 2025 14:05:53 +0000 Subject: [PATCH 078/751] Add: Separate Skylake-X & Ice Lake checksums --- c/lib.c | 1 + include/stringzilla/hash.h | 185 +++++++++++++++++++++++++++++++------ scripts/bench_token.cpp | 10 +- 3 files changed, 164 insertions(+), 32 deletions(-) diff --git a/c/lib.c b/c/lib.c index 52a6ce7a..3a447d99 100644 --- a/c/lib.c +++ b/c/lib.c @@ -259,6 +259,7 @@ SZ_DYNAMIC void sz_dispatch_table_init(void) { impl->rfind = sz_rfind_skylake; impl->find_byte = sz_find_byte_skylake; impl->rfind_byte = sz_rfind_byte_skylake; + impl->checksum = sz_checksum_skylake; } #endif diff --git a/include/stringzilla/hash.h b/include/stringzilla/hash.h index 52b6d372..bdffd583 100644 --- a/include/stringzilla/hash.h +++ b/include/stringzilla/hash.h @@ -5,15 +5,9 @@ * * Includes core APIs: * - * - `sz_checksum` - for byte-level checksums. + * - `sz_checksum` - for byte-level 64-bit unsigned checksums. * - `sz_hash` - for 64-bit single-shot hashing. - * - `sz_hashes` - producing the rolling hashes of a string. * - `sz_generate` - populating buffers with random data. - * - * Convenience functions for character-set matching: - * - * - `sz_hashes_fingerprint` - * - `sz_hashes_intersection` */ #ifndef STRINGZILLA_HASH_H_ #define STRINGZILLA_HASH_H_ @@ -334,6 +328,106 @@ SZ_PUBLIC sz_u64_t sz_checksum_haswell(sz_cptr_t text, sz_size_t length) { #pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "bmi", "bmi2") #pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,bmi,bmi2"))), apply_to = function) +SZ_PUBLIC sz_u64_t sz_checksum_skylake(sz_cptr_t text, sz_size_t length) { + // The naive implementation of this function is very simple. + // It assumes the CPU is great at handling unaligned "loads". + // + // A typical AWS Sapphire Rapids instance can have 48 KB x 2 blocks of L1 data cache per core, + // 2 MB x 2 blocks of L2 cache per core, and one shared 60 MB buffer of L3 cache. + // With two strings, we may consider the overall workload huge, if each exceeds 1 MB in length. + int const is_huge = length >= 1ull * 1024ull * 1024ull; + sz_u512_vec_t text_vec, sums_vec; + + // When the buffer is small, there isn't much to innovate. + // Separately handling even smaller payloads doesn't increase performance even on synthetic benchmarks. 
+ if (length <= 16) { + __mmask16 mask = _sz_u16_mask_until(length); + text_vec.xmms[0] = _mm_maskz_loadu_epi8(mask, text); + sums_vec.xmms[0] = _mm_sad_epu8(text_vec.xmms[0], _mm_setzero_si128()); + sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_vec.xmms[0]); + sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_vec.xmms[0], 1); + return low + high; + } + else if (length <= 32) { + __mmask32 mask = _sz_u32_mask_until(length); + text_vec.ymms[0] = _mm256_maskz_loadu_epi8(mask, text); + sums_vec.ymms[0] = _mm256_sad_epu8(text_vec.ymms[0], _mm256_setzero_si256()); + // Accumulating 256 bits is harder, as we need to extract the 128-bit sums first. + __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymms[0]); + __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymms[0], 1); + __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm); + sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm); + sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1); + return low + high; + } + else if (length <= 64) { + __mmask64 mask = _sz_u64_mask_until(length); + text_vec.zmm = _mm512_maskz_loadu_epi8(mask, text); + sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()); + return _mm512_reduce_add_epi64(sums_vec.zmm); + } + // For large buffers, fitting into L1 cache sizes, there are other tricks we can use. + // + // 1. Moving in both directions to maximize the throughput, when fetching from multiple + // memory pages. Also helps with cache set-associativity issues, as we won't always + // be fetching the same buckets in the lookup table. + // + // Bidirectional traversal generally adds about 10% to such algorithms. + else if (!is_huge) { + sz_size_t head_length = (64 - ((sz_size_t)text % 64)) % 64; // 63 or less. + sz_size_t tail_length = (sz_size_t)(text + length) % 64; // 63 or less. + sz_size_t body_length = length - head_length - tail_length; // Multiple of 64. + _sz_assert(body_length % 64 == 0 && head_length < 64 && tail_length < 64); + __mmask64 head_mask = _sz_u64_mask_until(head_length); + __mmask64 tail_mask = _sz_u64_mask_until(tail_length); + + text_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, text); + sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()); + for (text += head_length; body_length >= 64; text += 64, body_length -= 64) { + text_vec.zmm = _mm512_load_si512((__m512i const *)text); + sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512())); + } + text_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, text); + sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512())); + return _mm512_reduce_add_epi64(sums_vec.zmm); + } + // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use. + // + // 1. Using non-temporal loads to avoid polluting the cache. + // 2. Prefetching the next cache line, to avoid stalling the CPU. This generally useless + // for predictable patterns, so disregard this advice. + // + // Bidirectional traversal generally adds about 10% to such algorithms. 
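The short-input branches above build a `__mmask16` / `__mmask32` / `__mmask64` covering only the first `length` bytes (`_sz_u16_mask_until` and friends) and pair it with `maskz` loads so nothing past the buffer is touched. A scalar model of that "lowest N bits set" mask and of the masked byte sum it enables; the names below are made up, standing in for the library's helpers:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Mask with the lowest `n` bits set; `n == 64` must not shift by 64, hence the branch.
std::uint64_t mask_until(std::size_t n) { return n >= 64 ? ~0ull : ((1ull << n) - 1ull); }

// Scalar equivalent of a masked-zero load followed by a byte sum over one 64-byte block:
// bytes whose mask bit is clear are never even read.
std::uint64_t masked_bytesum(unsigned char const *data, std::uint64_t mask) {
    std::uint64_t sum = 0;
    for (std::size_t i = 0; i < 64; ++i)
        if (mask & (1ull << i)) sum += data[i];
    return sum;
}

int main() {
    unsigned char buffer[64];
    for (std::size_t i = 0; i < 64; ++i) buffer[i] = (unsigned char)i;
    assert(mask_until(3) == 0x7);
    assert(masked_bytesum(buffer, mask_until(4)) == 0 + 1 + 2 + 3);
    assert(masked_bytesum(buffer, mask_until(64)) == 63 * 64 / 2);
    return 0;
}
```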
+ else { + sz_u512_vec_t text_reversed_vec, sums_reversed_vec; + sz_size_t head_length = (64 - ((sz_size_t)text % 64)) % 64; + sz_size_t tail_length = (sz_size_t)(text + length) % 64; + sz_size_t body_length = length - head_length - tail_length; + __mmask64 head_mask = _sz_u64_mask_until(head_length); + __mmask64 tail_mask = _sz_u64_mask_until(tail_length); + + text_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, text); + sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()); + text_reversed_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, text + head_length + body_length); + sums_reversed_vec.zmm = _mm512_sad_epu8(text_reversed_vec.zmm, _mm512_setzero_si512()); + + // Now in the main loop, we can use non-temporal loads, performing the operation in both directions. + for (text += head_length; body_length >= 128; text += 64, text += 64, body_length -= 128) { + text_vec.zmm = _mm512_stream_load_si512((__m512i *)(text)); + sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512())); + text_reversed_vec.zmm = _mm512_stream_load_si512((__m512i *)(text + body_length - 64)); + sums_reversed_vec.zmm = + _mm512_add_epi64(sums_reversed_vec.zmm, _mm512_sad_epu8(text_reversed_vec.zmm, _mm512_setzero_si512())); + } + if (body_length >= 64) { + text_vec.zmm = _mm512_stream_load_si512((__m512i *)(text)); + sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512())); + } + + return _mm512_reduce_add_epi64(_mm512_add_epi64(sums_vec.zmm, sums_reversed_vec.zmm)); + } +} + #pragma clang attribute pop #pragma GCC pop_options #endif // SZ_USE_SKYLAKE @@ -341,16 +435,17 @@ SZ_PUBLIC sz_u64_t sz_checksum_haswell(sz_cptr_t text, sz_size_t length) { /* AVX512 implementation of the string search algorithms for Ice Lake and newer CPUs. * Includes extensions: - * - 2017 Skylake: F, CD, ER, PF, VL, DQ, BW, - * - 2018 CannonLake: IFMA, VBMI, - * - 2019 Ice Lake: VPOPCNTDQ, VNNI, VBMI2, BITALG, GFNI, VPCLMULQDQ, VAES. + * - 2017 Skylake: F, CD, ER, PF, VL, DQ, BW, + * - 2018 CannonLake: IFMA, VBMI, + * - 2019 Ice Lake: VPOPCNTDQ, VNNI, VBMI2, BITALG, GFNI, VPCLMULQDQ, VAES. */ #pragma region Ice Lake Implementation #if SZ_USE_ICE #pragma GCC push_options -#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vbmi", "bmi", "bmi2") -#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512dq,avx512vbmi,bmi,bmi2"))), \ - apply_to = function) +#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vbmi", "avx512vnni", "bmi", "bmi2") +#pragma clang attribute push( \ + __attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512dq,avx512vbmi,avx512vnni,bmi,bmi2"))), \ + apply_to = function) SZ_PUBLIC sz_u64_t sz_checksum_ice(sz_cptr_t text, sz_size_t length) { // The naive implementation of this function is very simple. @@ -363,6 +458,7 @@ SZ_PUBLIC sz_u64_t sz_checksum_ice(sz_cptr_t text, sz_size_t length) { sz_u512_vec_t text_vec, sums_vec; // When the buffer is small, there isn't much to innovate. + // Separately handling even smaller payloads doesn't increase performance even on synthetic benchmarks. 
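Both new AVX-512 paths split the input into an unaligned head, a 64-byte-aligned body, and an unaligned tail with the `(64 - address % 64) % 64` arithmetic used above, asserting that the body stays a whole number of aligned blocks. A tiny stand-alone check of those invariants, illustrative only:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
    std::vector<unsigned char> buffer(1024 + 7);
    for (std::size_t offset = 0; offset < 7; ++offset) {
        unsigned char const *text = buffer.data() + offset;
        std::size_t length = 1000;
        std::size_t head_length = (64 - ((std::uintptr_t)text % 64)) % 64; // 63 or less
        std::size_t tail_length = (std::uintptr_t)(text + length) % 64;    // 63 or less
        if (length < head_length + tail_length) continue; // only relevant for tiny buffers
        std::size_t body_length = length - head_length - tail_length;
        assert(head_length < 64 && tail_length < 64);
        assert(body_length % 64 == 0);                             // body is whole aligned blocks
        assert((((std::uintptr_t)text + head_length) % 64) == 0);  // body starts on a 64-byte boundary
        assert(head_length + body_length + tail_length == length);
    }
    return 0;
}
```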
if (length <= 16) { __mmask16 mask = _sz_u16_mask_until(length); text_vec.xmms[0] = _mm_maskz_loadu_epi8(mask, text); @@ -375,7 +471,7 @@ SZ_PUBLIC sz_u64_t sz_checksum_ice(sz_cptr_t text, sz_size_t length) { __mmask32 mask = _sz_u32_mask_until(length); text_vec.ymms[0] = _mm256_maskz_loadu_epi8(mask, text); sums_vec.ymms[0] = _mm256_sad_epu8(text_vec.ymms[0], _mm256_setzero_si256()); - // Accumulating 256 bits is harders, as we need to extract the 128-bit sums first. + // Accumulating 256 bits is harder, as we need to extract the 128-bit sums first. __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymms[0]); __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymms[0], 1); __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm); @@ -389,30 +485,60 @@ SZ_PUBLIC sz_u64_t sz_checksum_ice(sz_cptr_t text, sz_size_t length) { sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()); return _mm512_reduce_add_epi64(sums_vec.zmm); } + // For large buffers, fitting into L1 cache sizes, there are other tricks we can use. + // + // 1. Moving in both directions to maximize the throughput, when fetching from multiple + // memory pages. Also helps with cache set-associativity issues, as we won't always + // be fetching the same buckets in the lookup table. + // 2. Port-level parallelism, can be used to hide the latency of expensive SIMD instructions. + // - `VPSADBW (ZMM, ZMM, ZMM)` combination with `VPADDQ (ZMM, ZMM, ZMM)`: + // - On Ice Lake, the `VPSADBW` is 3 cycles on port 5; the `VPADDQ` is 1 cycle on ports 0/5. + // - On Zen 4, the `VPSADBW` is 3 cycles on ports 0/1; the `VPADDQ` is 1 cycle on ports 0/1/2/3. + // - `VPDPBUSDS (ZMM, ZMM, ZMM)`: + // - On Ice Lake, the `VPDPBUSDS` is 5 cycles on port 0. + // - On Zen 4, the `VPDPBUSDS` is 4 cycles on ports 0/1. + // + // Bidirectional traversal generally adds about 10% to such algorithms. else if (!is_huge) { sz_size_t head_length = (64 - ((sz_size_t)text % 64)) % 64; // 63 or less. sz_size_t tail_length = (sz_size_t)(text + length) % 64; // 63 or less. sz_size_t body_length = length - head_length - tail_length; // Multiple of 64. + _sz_assert(body_length % 64 == 0 && head_length < 64 && tail_length < 64); __mmask64 head_mask = _sz_u64_mask_until(head_length); __mmask64 tail_mask = _sz_u64_mask_until(tail_length); + + sz_u512_vec_t zeros_vec, ones_vec; + zeros_vec.zmm = _mm512_setzero_si512(); + ones_vec.zmm = _mm512_set1_epi8(1); + + // Take care of the unaligned head and tail! + sz_u512_vec_t text_reversed_vec, sums_reversed_vec; text_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, text); - sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()); - for (text += head_length; body_length >= 64; text += 64, body_length -= 64) { - text_vec.zmm = _mm512_load_si512((__m512i const *)text); - sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512())); + sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, zeros_vec.zmm); + text_reversed_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, text + head_length + body_length); + sums_reversed_vec.zmm = _mm512_dpbusds_epi32(zeros_vec.zmm, text_reversed_vec.zmm, ones_vec.zmm); + + // Now in the main loop, we can use aligned loads, performing the operation in both directions. 
+ for (text += head_length; body_length >= 128; text += 64, text += 64, body_length -= 128) { + text_reversed_vec.zmm = _mm512_load_si512((__m512i *)(text + body_length - 64)); + sums_reversed_vec.zmm = _mm512_dpbusds_epi32(sums_reversed_vec.zmm, text_reversed_vec.zmm, ones_vec.zmm); + text_vec.zmm = _mm512_load_si512((__m512i *)(text)); + sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, zeros_vec.zmm)); } - text_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, text); - sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512())); - return _mm512_reduce_add_epi64(sums_vec.zmm); + // There may be an aligned chunk of 64 bytes left. + if (body_length >= 64) { + _sz_assert(body_length == 64); + text_vec.zmm = _mm512_load_si512((__m512i *)(text)); + sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, zeros_vec.zmm)); + } + + return _mm512_reduce_add_epi64(sums_vec.zmm) + _mm512_reduce_add_epi32(sums_reversed_vec.zmm); } // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use. // - // 1. Moving in both directions to maximize the throughput, when fetching from multiple - // memory pages. Also helps with cache set-associativity issues, as we won't always - // be fetching the same entries in the lookup table. - // 2. Using non-temporal stores to avoid polluting the cache. - // 3. Prefetching the next cache line, to avoid stalling the CPU. This generally useless - // for predictable patterns, so disregard this advice. + // 1. Using non-temporal loads to avoid polluting the cache. + // 2. Prefetching the next cache line, to avoid stalling the CPU. This generally useless + // for predictable patterns, so disregard this advice. // // Bidirectional traversal generally adds about 10% to such algorithms. else { @@ -428,8 +554,7 @@ SZ_PUBLIC sz_u64_t sz_checksum_ice(sz_cptr_t text, sz_size_t length) { text_reversed_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, text + head_length + body_length); sums_reversed_vec.zmm = _mm512_sad_epu8(text_reversed_vec.zmm, _mm512_setzero_si512()); - // Now in the main loop, we can use non-temporal loads and stores, - // performing the operation in both directions. + // Now in the main loop, we can use non-temporal loads, performing the operation in both directions. 
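Both the aligned and the streaming loops rely on the same head/body/tail split, computed from the raw pointer value, to guarantee that every 64-byte load in the body is aligned. A quick worked example with made-up numbers, not from the patch itself:

#include <assert.h>
#include <stddef.h>

static void head_body_tail_example(void) {
    size_t address = 4103, length = 300;                      // Any unaligned buffer placement.
    size_t head_length = (64 - address % 64) % 64;            // 4103 % 64 == 7, so 57 bytes to the next boundary.
    size_t tail_length = (address + length) % 64;             // 4403 % 64 == 51 trailing bytes.
    size_t body_length = length - head_length - tail_length;  // 192 bytes, i.e. exactly three aligned ZMM loads.
    assert(head_length == 57 && tail_length == 51 && body_length == 192);
}

The head and tail are then covered by masked unaligned loads, exactly as in the branches above.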
for (text += head_length; body_length >= 128; text += 64, text += 64, body_length -= 128) { text_vec.zmm = _mm512_stream_load_si512((__m512i *)(text)); sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512())); @@ -506,6 +631,8 @@ SZ_PUBLIC sz_u64_t sz_checksum_neon(sz_cptr_t text, sz_size_t length) { SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length) { #if SZ_USE_ICE return sz_checksum_ice(text, length); +#elif SZ_USE_SKYLAKE + return sz_checksum_skylake(text, length); #elif SZ_USE_HASWELL return sz_checksum_haswell(text, length); #elif SZ_USE_NEON diff --git a/scripts/bench_token.cpp b/scripts/bench_token.cpp index 35369ac0..684adb05 100644 --- a/scripts/bench_token.cpp +++ b/scripts/bench_token.cpp @@ -21,14 +21,18 @@ tracked_unary_functions_t checksum_functions() { return std::accumulate(s.begin(), s.end(), (std::size_t)0, [](std::size_t sum, char c) { return sum + static_cast(c); }); }}, - {"sz_checksum_serial", wrap_sz(sz_checksum_serial), true}, + {"sz_checksum_serial", wrap_sz(sz_checksum_serial), false}, #if SZ_USE_HASWELL - {"sz_checksum_haswell", wrap_sz(sz_checksum_haswell), true}, + {"sz_checksum_haswell", wrap_sz(sz_checksum_haswell), false}, +#endif +#if SZ_USE_SKYLAKE + {"sz_checksum_skylake", wrap_sz(sz_checksum_skylake), false}, #endif #if SZ_USE_ICE + {"sz_checksum_ice", wrap_sz(sz_checksum_ice), false}, #endif #if SZ_USE_NEON - {"sz_checksum_neon", wrap_sz(sz_checksum_neon), true}, + {"sz_checksum_neon", wrap_sz(sz_checksum_neon), false}, #endif }; return result; From 509b58b754431bad3be050aaa90fd34641596994 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Thu, 13 Feb 2025 15:03:51 +0000 Subject: [PATCH 079/751] Fix: Loop in `sz_checksum_haswell` --- include/stringzilla/hash.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/include/stringzilla/hash.h b/include/stringzilla/hash.h index bdffd583..a69d85bf 100644 --- a/include/stringzilla/hash.h +++ b/include/stringzilla/hash.h @@ -284,16 +284,18 @@ SZ_PUBLIC sz_u64_t sz_checksum_haswell(sz_cptr_t text, sz_size_t length) { else { sz_u256_vec_t text_reversed_vec, sums_reversed_vec; sums_reversed_vec.ymm = _mm256_setzero_si256(); - for (; body_length >= 64; text += 64, body_length -= 64) { + for (; body_length >= 64; text += 32, body_length -= 64) { text_vec.ymm = _mm256_stream_load_si256((__m256i *)(text)); sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); - text_reversed_vec.ymm = _mm256_stream_load_si256((__m256i *)(text + body_length - 64)); + text_reversed_vec.ymm = _mm256_stream_load_si256((__m256i *)(text + body_length - 32)); sums_reversed_vec.ymm = _mm256_add_epi64( sums_reversed_vec.ymm, _mm256_sad_epu8(text_reversed_vec.ymm, _mm256_setzero_si256())); } if (body_length >= 32) { + _sz_assert(body_length == 32); text_vec.ymm = _mm256_stream_load_si256((__m256i *)(text)); sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); + text += 32; } sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, sums_reversed_vec.ymm); } From 4044855653b9571eff857ea303e3286086d502ee Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Thu, 13 Feb 2025 15:04:12 +0000 Subject: [PATCH 080/751] Fix: Loops in AVX-512 checksums --- include/stringzilla/hash.h | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/include/stringzilla/hash.h 
b/include/stringzilla/hash.h index a69d85bf..00200f29 100644 --- a/include/stringzilla/hash.h +++ b/include/stringzilla/hash.h @@ -414,7 +414,7 @@ SZ_PUBLIC sz_u64_t sz_checksum_skylake(sz_cptr_t text, sz_size_t length) { sums_reversed_vec.zmm = _mm512_sad_epu8(text_reversed_vec.zmm, _mm512_setzero_si512()); // Now in the main loop, we can use non-temporal loads, performing the operation in both directions. - for (text += head_length; body_length >= 128; text += 64, text += 64, body_length -= 128) { + for (text += head_length; body_length >= 128; text += 64, body_length -= 128) { text_vec.zmm = _mm512_stream_load_si512((__m512i *)(text)); sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512())); text_reversed_vec.zmm = _mm512_stream_load_si512((__m512i *)(text + body_length - 64)); @@ -501,6 +501,8 @@ SZ_PUBLIC sz_u64_t sz_checksum_ice(sz_cptr_t text, sz_size_t length) { // - On Zen 4, the `VPDPBUSDS` is 4 cycles on ports 0/1. // // Bidirectional traversal generally adds about 10% to such algorithms. + // Port level parallelism can yield more, but remember that one of the instructions accumulates + // with 32-bit integers and the other one will be using 64-bit integers. else if (!is_huge) { sz_size_t head_length = (64 - ((sz_size_t)text % 64)) % 64; // 63 or less. sz_size_t tail_length = (sz_size_t)(text + length) % 64; // 63 or less. @@ -521,7 +523,7 @@ SZ_PUBLIC sz_u64_t sz_checksum_ice(sz_cptr_t text, sz_size_t length) { sums_reversed_vec.zmm = _mm512_dpbusds_epi32(zeros_vec.zmm, text_reversed_vec.zmm, ones_vec.zmm); // Now in the main loop, we can use aligned loads, performing the operation in both directions. - for (text += head_length; body_length >= 128; text += 64, text += 64, body_length -= 128) { + for (text += head_length; body_length >= 128; text += 64, body_length -= 128) { text_reversed_vec.zmm = _mm512_load_si512((__m512i *)(text + body_length - 64)); sums_reversed_vec.zmm = _mm512_dpbusds_epi32(sums_reversed_vec.zmm, text_reversed_vec.zmm, ones_vec.zmm); text_vec.zmm = _mm512_load_si512((__m512i *)(text)); @@ -557,7 +559,7 @@ SZ_PUBLIC sz_u64_t sz_checksum_ice(sz_cptr_t text, sz_size_t length) { sums_reversed_vec.zmm = _mm512_sad_epu8(text_reversed_vec.zmm, _mm512_setzero_si512()); // Now in the main loop, we can use non-temporal loads, performing the operation in both directions. 
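The loop-header fixes in this commit may look asymmetric at first glance: each iteration consumes 128 bytes, yet the pointer only advances by 64. The second block of every iteration is taken from the far end of the still-unprocessed body, so the span shrinks from both sides at once. A scalar model of the same invariant, not from the patch itself (the helper name is made up; assumes `body_length` is a multiple of 64, as in the surrounding code):

#include <stddef.h>
#include <stdint.h>

static uint64_t bidirectional_body_sum(unsigned char const *body, size_t body_length) {
    uint64_t forward = 0, backward = 0;                                           // Two independent dependency chains.
    while (body_length >= 128) {
        for (size_t i = 0; i != 64; ++i) forward += body[i];                      // Front block.
        for (size_t i = 0; i != 64; ++i) backward += body[body_length - 64 + i];  // Back block.
        body += 64;         // The front pointer advances by one block only...
        body_length -= 128; // ...while the unprocessed span shrinks by two blocks.
    }
    if (body_length >= 64)  // At most one 64-byte block can remain in the middle.
        for (size_t i = 0; i != 64; ++i) forward += body[i];
    return forward + backward;
}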
- for (text += head_length; body_length >= 128; text += 64, text += 64, body_length -= 128) { + for (text += head_length; body_length >= 128; text += 64, body_length -= 128) { text_vec.zmm = _mm512_stream_load_si512((__m512i *)(text)); sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512())); text_reversed_vec.zmm = _mm512_stream_load_si512((__m512i *)(text + body_length - 64)); From 84cb4c8ab6bcd22bef935bcb350584dfdd4b6de7 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Fri, 14 Feb 2025 15:45:35 +0000 Subject: [PATCH 081/751] Fix: Tail handling in `sz_checksum_haswell` --- include/stringzilla/hash.h | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/include/stringzilla/hash.h b/include/stringzilla/hash.h index 00200f29..e0c0447f 100644 --- a/include/stringzilla/hash.h +++ b/include/stringzilla/hash.h @@ -270,6 +270,8 @@ SZ_PUBLIC sz_u64_t sz_checksum_haswell(sz_cptr_t text, sz_size_t length) { // Handle the head while (head_length--) result += *text++; + // Handle the tail + while (tail_length) result += text[length - (tail_length--) - 1]; sz_u256_vec_t text_vec, sums_vec; sums_vec.ymm = _mm256_setzero_si256(); @@ -300,9 +302,6 @@ SZ_PUBLIC sz_u64_t sz_checksum_haswell(sz_cptr_t text, sz_size_t length) { sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, sums_reversed_vec.ymm); } - // Handle the tail - while (tail_length--) result += *text++; - // Accumulating 256 bits is harder, as we need to extract the 128-bit sums first. __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymm); __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymm, 1); From 5bbd9715a3521c150e5ec62b6334224e192a7d2a Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Fri, 14 Feb 2025 23:47:33 +0000 Subject: [PATCH 082/751] Fix: Infer allocators `value_type` --- include/stringzilla/stringzilla.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/include/stringzilla/stringzilla.hpp b/include/stringzilla/stringzilla.hpp index 126211c4..d64b0c03 100644 --- a/include/stringzilla/stringzilla.hpp +++ b/include/stringzilla/stringzilla.hpp @@ -1058,7 +1058,8 @@ static void *_call_allocate(sz_size_t n, void *allocator_state) noexcept { template static void _call_free(void *ptr, sz_size_t n, void *allocator_state) noexcept { - return reinterpret_cast(allocator_state)->deallocate(reinterpret_cast(ptr), n); + using value_type_ = typename allocator_type_::value_type; + return reinterpret_cast(allocator_state)->deallocate(reinterpret_cast(ptr), n); } template From ec81663483539fefbd3fe496bda62733f4408b99 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 15 Feb 2025 00:02:35 +0000 Subject: [PATCH 083/751] Break: `sz_sort` now takes allocators --- README.md | 2 +- c/lib.c | 5 +- include/stringzilla/sort.h | 521 +++++++++++++--------------- include/stringzilla/stringzilla.hpp | 5 +- include/stringzilla/types.h | 9 +- 5 files changed, 256 insertions(+), 286 deletions(-) diff --git a/README.md b/README.md index fb2a0384..52f80d41 100644 --- a/README.md +++ b/README.md @@ -632,7 +632,7 @@ sz_size_t substring_position = sz_find_neon(haystack.start, haystack.length, nee sz_u64_t hash = sz_hash(haystack.start, haystack.length); // Perform collection level operations -sz_sequence_t array = {your_order, your_count, your_get_start, your_get_length, your_handle}; +sz_sequence_t array = {your_handle, your_count, 
your_get_start, your_get_length}; sz_sort(&array, &your_config); ``` diff --git a/c/lib.c b/c/lib.c index 3a447d99..5a4183cd 100644 --- a/c/lib.c +++ b/c/lib.c @@ -188,7 +188,8 @@ typedef struct sz_implementations_t { sz_edit_distance_t edit_distance; sz_alignment_score_t alignment_score; - sz_hashes_t hashes; + + sz_sort_t sort; } sz_implementations_t; @@ -224,7 +225,7 @@ SZ_DYNAMIC void sz_dispatch_table_init(void) { impl->edit_distance = sz_edit_distance_serial; impl->alignment_score = sz_alignment_score_serial; - impl->hashes = 0; + impl->sort = sz_sort_serial; #if SZ_USE_HASWELL if (caps & sz_cap_haswell_k) { diff --git a/include/stringzilla/sort.h b/include/stringzilla/sort.h index 7a8de124..e517159d 100644 --- a/include/stringzilla/sort.h +++ b/include/stringzilla/sort.h @@ -1,14 +1,12 @@ /** - * @brief Hardware-accelerated string sorting. + * @brief Hardware-accelerated string collection sorting and intersections. * @file sort.h * @author Ash Vardanian * * Includes core APIs: * - * - `sz_partition` - to split the sequence into two parts based on a predicate. - * - `sz_merge` - to merge two consecutive sorted chunks forming the same continuous `sequence`. - * - `sz_sort` - to sort an arbitrary string sequence. - * - `sz_sort_partial` - to partially sort an arbitrary string sequence. + * - `sz_sort` - to sort an arbitrary string collection. + * - TODO: `sz_stable_sort` - to sort a string collection while preserving the relative order of equal elements. */ #ifndef STRINGZILLA_SORT_H_ #define STRINGZILLA_SORT_H_ @@ -24,320 +22,293 @@ extern "C" { #pragma region Core API /** - * @brief Similar to `std::partition`, given a predicate splits the sequence into two parts. - * The algorithm is unstable, meaning that elements may change relative order, as long - * as they are in the right partition. This is the simpler algorithm for partitioning. - */ -SZ_PUBLIC sz_size_t sz_partition(sz_sequence_t *sequence, sz_sequence_predicate_t predicate); - -/** - * @brief Inplace `std::set_union` for two consecutive chunks forming the same continuous `sequence`. + * @brief Faster `std::sort` for an arbitrary string sequence. * - * @param partition The number of elements in the first sub-sequence in `sequence`. - * @param less Comparison function, to determine the lexicographic ordering. + * @param collection The collection of strings to sort. + * @param alloc Memory allocator for temporary storage. + * @param order The output - indices of the sorted collection elements. + * @return Whether the operation was successful. */ -SZ_PUBLIC void sz_merge(sz_sequence_t *sequence, sz_size_t partition, sz_sequence_comparator_t less); +SZ_PUBLIC sz_bool_t sz_sort(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, sz_sorted_idx_t *order); -/** - * @brief Sorting algorithm, combining Radix Sort for the first 32 bits of every word - * and a follow-up by a more conventional sorting procedure on equally prefixed parts. - */ -SZ_PUBLIC void sz_sort(sz_sequence_t *sequence); +/** @copydoc sz_sort */ +SZ_PUBLIC sz_bool_t sz_sort_serial(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); -/** - * @brief Partial sorting algorithm, combining Radix Sort for the first 32 bits of every word - * and a follow-up by a more conventional sorting procedure on equally prefixed parts. 
- */ -SZ_PUBLIC void sz_sort_partial(sz_sequence_t *sequence, sz_size_t n); +/** @copydoc sz_sort */ +SZ_PUBLIC sz_bool_t sz_sort_skylake(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); -/** - * @brief Intro-Sort algorithm that supports custom comparators. - */ -SZ_PUBLIC void sz_sort_intro(sz_sequence_t *sequence, sz_sequence_comparator_t less); +/** @copydoc sz_sort */ +SZ_PUBLIC sz_bool_t sz_sort_sve(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, sz_sorted_idx_t *order); #pragma endregion #pragma region Serial Implementation -SZ_PUBLIC sz_size_t sz_partition(sz_sequence_t *sequence, sz_sequence_predicate_t predicate) { - - sz_size_t matches = 0; - while (matches != sequence->count && predicate(sequence, sequence->order[matches])) ++matches; - - for (sz_size_t i = matches + 1; i < sequence->count; ++i) - if (predicate(sequence, sequence->order[i])) - sz_u64_swap(sequence->order + i, sequence->order + matches), ++matches; - - return matches; +typedef sz_size_t _sz_sorting_window_t; + +SZ_PUBLIC void _sz_sort_serial_export_prefixes( // + sz_sequence_t const *const collection, // + _sz_sorting_window_t *const global_windows, // + sz_size_t const start_in_collection, sz_size_t const end_in_collection, // + sz_size_t const start_character) { + + // Depending on the architecture, we will export a different number of bytes. + // On 32-bit architectures, we will export 3 bytes, and on 64-bit architectures - 7 bytes. + sz_size_t const window_capacity = sizeof(_sz_sorting_window_t) - 1; + + // Perform the same operation for every string. + for (sz_size_t i = start_in_collection; i < end_in_collection; ++i) { + // Get the string slice in global memory. + sz_cptr_t const source_str = collection->get_start(collection, i); + sz_size_t const length = collection->get_length(collection, i); + sz_size_t const remaining_length = length > start_character ? length - start_character : 0; + sz_size_t const exported_length = remaining_length > window_capacity ? window_capacity : remaining_length; + // Fill with zeros, export a slice, and mark the exported length. + sz_size_t *target_integer = &global_windows[i]; + sz_ptr_t target_str = (sz_ptr_t)target_integer; + *target_integer = 0; + for (sz_size_t j = 0; j < exported_length; ++j) target_str[j] = source_str[j + start_character]; + target_str[window_capacity] = exported_length; +#if defined(_SZ_IS_64_BIT) + *target_integer = sz_u64_bytes_reverse(*target_integer); +#else + *target_integer = sz_u32_bytes_reverse(*target_integer); +#endif + } } -SZ_PUBLIC void sz_merge(sz_sequence_t *sequence, sz_size_t partition, sz_sequence_comparator_t less) { - - sz_size_t start_b = partition + 1; - - // If the direct merge is already sorted - if (!less(sequence, sequence->order[start_b], sequence->order[partition])) return; - - sz_size_t start_a = 0; - while (start_a <= partition && start_b <= sequence->count) { - - // If element 1 is in right place - if (!less(sequence, sequence->order[start_b], sequence->order[start_a])) { start_a++; } +/** + * @brief Helper function of the serial QuickSort algorithm, that rearranges the elements in + * such a way, that all entries around the pivot are less than the pivot. + * + * It means that no relative order among the elements on the left or right side of the pivot is preserved. + * We chose the pivot point using Robert Sedgewick's method - the median of three elements - the first, + * the middle, and the last element of the given range. 
+ */ +SZ_PUBLIC sz_size_t _sz_sort_serial_partition( // + _sz_sorting_window_t *const global_windows, sz_sorted_idx_t *const global_order, // + sz_size_t const start_in_collection, sz_size_t const end_in_collection) { + + // Chose the pivot offset. + sz_size_t pivot_offset; + _sz_sorting_window_t pivot_window; + { + sz_size_t const middle_offset = start_in_collection + (end_in_collection - start_in_collection) / 2; + sz_size_t const last_offset = end_in_collection - 1; + sz_size_t const first_offset = start_in_collection; + _sz_sorting_window_t const first_window = global_windows[first_offset]; + _sz_sorting_window_t const middle_window = global_windows[middle_offset]; + _sz_sorting_window_t const last_window = global_windows[last_offset]; + if (first_window < middle_window) { + if (middle_window < last_window) { pivot_offset = middle_offset, pivot_window = middle_window; } + else if (first_window < last_window) { pivot_offset = last_offset, pivot_window = last_window; } + else { pivot_offset = first_offset, pivot_window = first_window; } + } else { - sz_size_t value = sequence->order[start_b]; - sz_size_t index = start_b; - - // Shift all the elements between element 1 - // element 2, right by 1. - while (index != start_a) { sequence->order[index] = sequence->order[index - 1], index--; } - sequence->order[start_a] = value; - - // Update all the pointers - start_a++; - partition++; - start_b++; + if (first_window < last_window) { pivot_offset = first_offset, pivot_window = first_window; } + else if (middle_window < last_window) { pivot_offset = last_offset, pivot_window = last_window; } + else { pivot_offset = middle_offset, pivot_window = middle_window; } } } -} -SZ_PUBLIC void sz_sort_insertion(sz_sequence_t *sequence, sz_sequence_comparator_t less) { - sz_u64_t *keys = sequence->order; - sz_size_t keys_count = sequence->count; - for (sz_size_t i = 1; i < keys_count; i++) { - sz_u64_t i_key = keys[i]; - sz_size_t j = i; - for (; j > 0 && less(sequence, i_key, keys[j - 1]); --j) keys[j] = keys[j - 1]; - keys[j] = i_key; + // Loop through the collection and move the elements around the pivot. + sz_size_t left_offset = start_in_collection; + sz_size_t right_offset = end_in_collection - 1; + while (left_offset <= right_offset) { + // Find the first element on the left that is greater than the pivot. + while (global_windows[left_offset] < pivot_window) ++left_offset; + // Find the first element on the right that is less than the pivot. + while (global_windows[right_offset] > pivot_window) --right_offset; + // Swap the elements if they are in the wrong order. 
+ if (left_offset <= right_offset) { +#if defined(_SZ_IS_64_BIT) + sz_u64_swap(&global_order[left_offset], &global_order[right_offset]); + sz_u64_swap(&global_windows[left_offset], &global_windows[right_offset]); +#else + sz_u32_swap(&global_order[left_offset], &global_order[right_offset]); + sz_u32_swap(&global_windows[left_offset], &global_windows[right_offset]); +#endif + ++left_offset; + --right_offset; + } } -} -SZ_INTERNAL void _sz_sift_down( // - sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_u64_t *order, sz_size_t start, sz_size_t end) { - sz_size_t root = start; - while (2 * root + 1 <= end) { - sz_size_t child = 2 * root + 1; - if (child + 1 <= end && less(sequence, order[child], order[child + 1])) { child++; } - if (!less(sequence, order[root], order[child])) { return; } - sz_u64_swap(order + root, order + child); - root = child; - } + return pivot_offset; } -SZ_INTERNAL void _sz_heapify(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_u64_t *order, sz_size_t count) { - sz_size_t start = (count - 2) / 2; - while (1) { - _sz_sift_down(sequence, less, order, start, count - 1); - if (start == 0) return; - start--; +SZ_PUBLIC void _sz_sort_serial_recursively( // + sz_sequence_t const *const collection, // + _sz_sorting_window_t *const global_windows, sz_size_t *const global_order, // + sz_size_t const start_in_collection, sz_size_t const end_in_collection, // + sz_size_t const start_character) { + // Partition the collection around some pivot + sz_size_t pivot_index = + _sz_sort_serial_partition(global_windows, global_order, start_in_collection, end_in_collection); + + // Recursively sort the left partition + if (start_in_collection < pivot_index) { + _sz_sort_serial_recursively(collection, global_windows, global_order, start_in_collection, pivot_index, + start_character); } -} -SZ_INTERNAL void _sz_heapsort(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_size_t first, sz_size_t last) { - sz_u64_t *order = sequence->order; - sz_size_t count = last - first; - _sz_heapify(sequence, less, order + first, count); - sz_size_t end = count - 1; - while (end > 0) { - sz_u64_swap(order + first, order + first + end); - end--; - _sz_sift_down(sequence, less, order + first, 0, end); + // Recursively sort the right partition + if (pivot_index + 1 < end_in_collection) { + _sz_sort_serial_recursively(collection, global_windows, global_order, pivot_index + 1, end_in_collection, + start_character); } } -SZ_PUBLIC void sz_sort_introsort_recursion( // - sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_size_t first, sz_size_t last, sz_size_t depth) { - - sz_size_t length = last - first; - switch (length) { - case 0: - case 1: return; - case 2: - if (less(sequence, sequence->order[first + 1], sequence->order[first])) - sz_u64_swap(&sequence->order[first], &sequence->order[first + 1]); - return; - case 3: { - sz_u64_t a = sequence->order[first]; - sz_u64_t b = sequence->order[first + 1]; - sz_u64_t c = sequence->order[first + 2]; - if (less(sequence, b, a)) sz_u64_swap(&a, &b); - if (less(sequence, c, b)) sz_u64_swap(&c, &b); - if (less(sequence, b, a)) sz_u64_swap(&a, &b); - sequence->order[first] = a; - sequence->order[first + 1] = b; - sequence->order[first + 2] = c; - return; - } - } - // Until a certain length, the quadratic-complexity insertion-sort is fine - if (length <= 16) { - sz_sequence_t sub_seq = *sequence; - sub_seq.order += first; - sub_seq.count = length; - sz_sort_insertion(&sub_seq, less); - return; +SZ_PUBLIC void _sz_sort_serial_next_window( 
// + sz_sequence_t const *const collection, // + _sz_sorting_window_t *const global_windows, sz_size_t *const global_order, // + sz_size_t const start_in_collection, sz_size_t const end_in_collection, // + sz_size_t const start_character) { + + // Prepare the new range of windows + _sz_sort_serial_export_prefixes(collection, global_windows, start_in_collection, end_in_collection, + start_character); + + // Sort current windows with a quicksort + _sz_sort_serial_recursively(collection, global_windows, global_order, start_in_collection, end_in_collection, + start_character); + + // Depending on the architecture, we will export a different number of bytes. + // On 32-bit architectures, we will export 3 bytes, and on 64-bit architectures - 7 bytes. + sz_size_t const window_capacity = sizeof(_sz_sorting_window_t) - 1; + + // Repeat the procedure for the identical windows + sz_size_t nested_start = start_in_collection; + sz_size_t nested_end = start_in_collection; + while (nested_end != end_in_collection) { + // Find the end of the identical windows + _sz_sorting_window_t current_window_integer = global_windows[nested_start]; + while (nested_end != end_in_collection && current_window_integer == global_windows[nested_end]) ++nested_end; + + // If the identical windows are not trivial and each string has more characters, sort them recursively + sz_cptr_t current_window_str = (sz_cptr_t)&current_window_integer; + int current_window_length = current_window_str[window_capacity]; + if (nested_end - nested_start > 1 && current_window_length == window_capacity) { + _sz_sort_serial_next_window(collection, global_windows, global_order, nested_start, nested_end, + start_character + window_capacity); + } + // Move to the next + nested_start = nested_end; + } +} + +SZ_PUBLIC void _sz_sort_serial_insertion(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order) { + // This algorithm needs no memory allocations: + sz_unused(alloc); + + // Assume `order` is already initialized with 0, 1, 2, ... N. + for (sz_size_t i = 1; i < collection->count; ++i) { + sz_sorted_idx_t current_idx = order[i]; + sz_size_t j = i; + while (j > 0) { + // Get the two strings to compare.
+ sz_sorted_idx_t previous_idx = order[j - 1]; + sz_cptr_t previous_start = collection->get_start(collection, previous_idx); + sz_cptr_t current_start = collection->get_start(collection, current_idx); + sz_size_t previous_length = collection->get_length(collection, previous_idx); + sz_size_t current_length = collection->get_length(collection, current_idx); + + // Use the provided sz_order to compare. + sz_ordering_t ordering = sz_order(previous_start, previous_length, current_start, current_length); + + // If the previous string is not greater than current_idx, we're done. + if (ordering != sz_greater_k) break; + + // Otherwise, shift the previous element to the right. + order[j] = order[j - 1]; + --j; + } + order[j] = current_idx; } - - // Recursively sort the partitions - sz_sort_introsort_recursion(sequence, less, first, left, depth); - sz_sort_introsort_recursion(sequence, less, right + 1, last, depth); -} - -SZ_PUBLIC void sz_sort_introsort(sz_sequence_t *sequence, sz_sequence_comparator_t less) { - if (sequence->count == 0) return; - sz_size_t size_is_not_power_of_two = (sequence->count & (sequence->count - 1)) != 0; - sz_size_t depth_limit = sz_size_log2i_nonzero(sequence->count) + size_is_not_power_of_two; - sz_sort_introsort_recursion(sequence, less, 0, sequence->count, depth_limit); } -SZ_PUBLIC void sz_sort_recursion( // - sz_sequence_t *sequence, sz_size_t bit_idx, sz_size_t bit_max, sz_sequence_comparator_t comparator, - sz_size_t partial_order_length) { +SZ_PUBLIC sz_bool_t sz_sort_serial(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order) { - if (!sequence->count) return; + // First, initialize the `order` with `std::iota`-like behavior. + for (sz_size_t i = 0; i != collection->count; ++i) order[i] = i; - // Array of size one doesn't need sorting - only needs the prefix to be discarded. - if (sequence->count == 1) { - sz_u32_t *order_half_words = (sz_u32_t *)sequence->order; - order_half_words[1] = 0; - return; + // On very small collections - just use the quadratic-complexity insertion sort + // without any smart optimizations or memory allocations. + if (collection->count <= 32) { + _sz_sort_serial_insertion(collection, alloc, order); + return sz_true_k; } - // Partition a range of integers according to a specific bit value - sz_size_t split = 0; - sz_u64_t mask = (1ull << 63) >> bit_idx; - - // The clean approach would be to perform a single pass over the sequence. + // One of the reasons for slow string operations is the significant overhead of branching when performing + // individual string comparisons. // - // while (split != sequence->count && !(sequence->order[split] & mask)) ++split; - // for (sz_size_t i = split + 1; i < sequence->count; ++i) - // if (!(sequence->order[i] & mask)) sz_u64_swap(sequence->order + i, sequence->order + split), ++split; + // The core idea of our algorithm is to minimize character-level loops in string comparisons and + // instead operate on larger integer words - 4 or 8 bytes at once, on 32-bit or 64-bit architectures, respectively. + // Let's say we have N strings and the pointer size is P. // - // This, however, doesn't take into account the high relative cost of writes and swaps. - // To circumvent that, we can first count the total number entries to be mapped into either part. - // And then walk through both parts, swapping the entries that are in the wrong part. - // This would often lead to ~15% performance gain. 
- sz_size_t count_with_bit_set = 0; - for (sz_size_t i = 0; i != sequence->count; ++i) count_with_bit_set += (sequence->order[i] & mask) != 0; - split = sequence->count - count_with_bit_set; - - // It's possible that the sequence is already partitioned. - if (split != 0 && split != sequence->count) { - // Use two pointers to efficiently reposition elements. - // On pointer walks left-to-right from the start, and the other walks right-to-left from the end. - sz_size_t left = 0; - sz_size_t right = sequence->count - 1; - while (1) { - // Find the next element with the bit set on the left side. - while (left < split && !(sequence->order[left] & mask)) ++left; - // Find the next element without the bit set on the right side. - while (right >= split && (sequence->order[right] & mask)) --right; - // Swap the mispositioned elements. - if (left < split && right >= split) { - sz_u64_swap(sequence->order + left, sequence->order + right); - ++left; - --right; - } - else { break; } - } - } - - // Go down recursively. - if (bit_idx < bit_max) { - sz_sequence_t a = *sequence; - a.count = split; - sz_sort_recursion(&a, bit_idx + 1, bit_max, comparator, partial_order_length); - - sz_sequence_t b = *sequence; - b.order += split; - b.count -= split; - sz_sort_recursion(&b, bit_idx + 1, bit_max, comparator, partial_order_length); - } - // Reached the end of recursion. - else { - // Discard the prefixes. - sz_u32_t *order_half_words = (sz_u32_t *)sequence->order; - for (sz_size_t i = 0; i != sequence->count; ++i) { order_half_words[i * 2 + 1] = 0; } - - sz_sequence_t a = *sequence; - a.count = split; - sz_sort_introsort(&a, comparator); - - sz_sequence_t b = *sequence; - b.order += split; - b.count -= split; - sz_sort_introsort(&b, comparator); - } + // Our recursive algorithm will take the first P bytes of each string and sort them as integers. + // Assuming that some strings may contain or even end with NULL bytes, we need to make sure, that their length + // is included in those P-long words. So, in reality, we will be taking (P-1) bytes from each string on every + // iteration of a recursive algorithm. + _sz_sorting_window_t *windows = + (_sz_sorting_window_t *)alloc->allocate(collection->count * sizeof(_sz_sorting_window_t), alloc); + if (!windows) return sz_false_k; + + // Recursively sort the whole collection. + _sz_sort_serial_recursively(collection, windows, order, 0, collection->count, 0); + + // Free temporary storage. + alloc->free(windows, collection->count * sizeof(_sz_sorting_window_t), alloc); + return sz_true_k; } -SZ_INTERNAL sz_bool_t _sz_sort_is_less(sz_sequence_t *sequence, sz_size_t i_key, sz_size_t j_key) { - sz_cptr_t i_str = sequence->get_start(sequence, i_key); - sz_cptr_t j_str = sequence->get_start(sequence, j_key); - sz_size_t i_len = sequence->get_length(sequence, i_key); - sz_size_t j_len = sequence->get_length(sequence, j_key); - return (sz_bool_t)(sz_order_serial(i_str, i_len, j_str, j_len) == sz_less_k); -} - -SZ_PUBLIC void sz_sort_partial(sz_sequence_t *sequence, sz_size_t partial_order_length) { - -#if _SZ_IS_BIG_ENDIAN - // TODO: Implement partial sort for big-endian systems. For now this sorts the whole thing. 
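To illustrate the "sorting window" used above: `_sz_sort_serial_export_prefixes` packs up to 7 leading bytes of every string plus the count of packed bytes into one machine word, ordered so that a plain integer comparison ranks the windows exactly like `memcmp` would rank the prefixes, with shorter strings sorting before longer ones that share the same prefix. A standalone sketch of the same packing, not from the patch itself (the helper name is made up; the shifts build the byte-reversed value directly, so the sketch is endianness-agnostic):

#include <stddef.h>
#include <stdint.h>

static uint64_t pack_window(char const *str, size_t length) {
    size_t const capacity = sizeof(uint64_t) - 1;  // 7 payload bytes + 1 byte for the packed length.
    size_t const taken = length < capacity ? length : capacity;
    uint64_t word = 0;
    for (size_t i = 0; i != taken; ++i)            // The first character lands in the most significant byte.
        word |= (uint64_t)(unsigned char)str[i] << (8 * (capacity - i));
    return word | (uint64_t)taken;                 // The packed-byte count becomes the least significant byte.
}

/* pack_window("app", 3) < pack_window("apple", 5) < pack_window("apply", 5),
   matching the lexicographic order of the underlying strings. */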
- sz_unused(partial_order_length); - sz_sort_introsort(sequence, (sz_sequence_comparator_t)_sz_sort_is_less); -#else +#pragma endregion // Serial Implementation - // Export up to 4 bytes into the `sequence` bits themselves - for (sz_size_t i = 0; i != sequence->count; ++i) { - sz_cptr_t begin = sequence->get_start(sequence, sequence->order[i]); - sz_size_t length = sequence->get_length(sequence, sequence->order[i]); - length = length > 4u ? 4u : length; - sz_ptr_t prefix = (sz_ptr_t)&sequence->order[i]; - for (sz_size_t j = 0; j != length; ++j) prefix[7 - j] = begin[j]; +#pragma region Ice Lake Implementation + +SZ_PUBLIC void _sz_sort_ice_recursively( // + sz_sequence_t const *const collection, // + _sz_sorting_window_t *const global_windows, sz_size_t *const global_order, // + sz_size_t const start_in_collection, sz_size_t const end_in_collection, // + sz_size_t const start_character) { + + // Prepare the new range of windows + _sz_sort_serial_export_prefixes(collection, global_windows, start_in_collection, end_in_collection, + start_character); + + // We can implement a form of a Radix sort here, that will count the number of elements with + // a certain bit set. The naive approach may require too many loops over data. A more "vectorized" + // approach would be to maintain a histogram for several bits at once. For 4 bits we will + // need 2^4 = 16 counters. + sz_size_t histogram[16] = {0}; + for (sz_size_t byte_in_window = 0; byte_in_window != sizeof(_sz_sorting_window_t); ++byte_in_window) { + // First sort based on the low nibble of each byte. + for (sz_size_t i = start_in_collection; i < end_in_collection; ++i) { + sz_size_t const byte = (global_windows[i] >> (byte_in_window * 8)) & 0xFF; + ++histogram[byte]; + } + sz_size_t offset = start_in_collection; + for (sz_size_t i = 0; i != 16; ++i) { + sz_size_t const count = histogram[i]; + histogram[i] = offset; + offset += count; + } + for (sz_size_t i = start_in_collection; i < end_in_collection; ++i) { + sz_size_t const byte = (global_windows[i] >> (byte_in_window * 8)) & 0xFF; + global_order[histogram[byte]] = i; + ++histogram[byte]; + } } - - // Perform optionally-parallel radix sort on them - sz_sort_recursion(sequence, 0, 32, (sz_sequence_comparator_t)_sz_sort_is_less, partial_order_length); -#endif } -SZ_PUBLIC void sz_sort(sz_sequence_t *sequence) { -#if _SZ_IS_BIG_ENDIAN - sz_sort_introsort(sequence, (sz_sequence_comparator_t)_sz_sort_is_less); -#else - sz_sort_partial(sequence, sequence->count); -#endif -} +#pragma endregion // Ice Lake Implementation -#pragma endregion // Serial Implementation +SZ_PUBLIC sz_bool_t sz_sort(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, sz_sorted_idx_t *order) { + return sz_sort_serial(collection, alloc, order); +} #ifdef __cplusplus } diff --git a/include/stringzilla/stringzilla.hpp b/include/stringzilla/stringzilla.hpp index d64b0c03..89fbd39b 100644 --- a/include/stringzilla/stringzilla.hpp +++ b/include/stringzilla/stringzilla.hpp @@ -4024,12 +4024,13 @@ void sorted_order(objects_type_ const *begin, objects_type_ const *end, sorted_i for (std::size_t i = 0; i != args.count; ++i) order[i] = static_cast(i); sz_sequence_t array; - array.order = reinterpret_cast(order); array.count = args.count; array.handle = &args; array.get_start = _call_sequence_member_start; array.get_length = _call_sequence_member_length; - sz_sort(&array); + + using sz_alloc_type = sz_memory_allocator_t; + _with_alloc>([&](sz_alloc_type &alloc) { return sz_sort(&array, &alloc, order); }); } #if 
!SZ_AVOID_STL diff --git a/include/stringzilla/types.h b/include/stringzilla/types.h index a170b6b0..9fb67112 100644 --- a/include/stringzilla/types.h +++ b/include/stringzilla/types.h @@ -320,7 +320,8 @@ typedef char *sz_ptr_t; // A type alias for `char *` typedef char const *sz_cptr_t; // A type alias for `char const *` typedef sz_i8_t sz_error_cost_t; // Character mismatch cost for fuzzy matching functions -typedef sz_u64_t sz_sorted_idx_t; // Index of a sorted string in a list of strings +struct sz_sequence_t; // Forward declaration of an ordered collection of strings +typedef sz_size_t sz_sorted_idx_t; // Index of a sorted string in a list of strings typedef enum { sz_false_k = 0, sz_true_k = 1 } sz_bool_t; // Only one relevant bit typedef enum { sz_less_k = -1, sz_equal_k = 0, sz_greater_k = 1 } sz_ordering_t; // Only three possible states: <=> @@ -626,20 +627,16 @@ SZ_INTERNAL sz_size_t _sz_export_utf8_to_utf32(sz_cptr_t utf8, sz_size_t utf8_le #pragma region String Sequences API -struct sz_sequence_t; - typedef sz_cptr_t (*sz_sequence_member_start_t)(struct sz_sequence_t const *, sz_size_t); typedef sz_size_t (*sz_sequence_member_length_t)(struct sz_sequence_t const *, sz_size_t); typedef sz_bool_t (*sz_sequence_predicate_t)(struct sz_sequence_t const *, sz_size_t); -typedef sz_bool_t (*sz_sequence_comparator_t)(struct sz_sequence_t const *, sz_size_t, sz_size_t); typedef sz_bool_t (*sz_string_is_less_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); typedef struct sz_sequence_t { - sz_sorted_idx_t *order; + void const *handle; sz_size_t count; sz_sequence_member_start_t get_start; sz_sequence_member_length_t get_length; - void const *handle; } sz_sequence_t; /** From b20d7cdcd70dfc1c702cd18525b3f454d6efbaf6 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 15 Feb 2025 00:20:27 +0000 Subject: [PATCH 084/751] Fix: Tail sum order in `checksum_haswell` --- include/stringzilla/hash.h | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/include/stringzilla/hash.h b/include/stringzilla/hash.h index e0c0447f..415b4b67 100644 --- a/include/stringzilla/hash.h +++ b/include/stringzilla/hash.h @@ -38,8 +38,6 @@ SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length); * @param text String to hash. * @param length Number of bytes in the text. * @return 64-bit hash value. - * - * @see sz_hashes, sz_hashes_fingerprint, sz_hashes_intersection */ SZ_PUBLIC sz_u64_t sz_hash(sz_cptr_t text, sz_size_t length) { sz_unused(text && length); @@ -268,10 +266,10 @@ SZ_PUBLIC sz_u64_t sz_checksum_haswell(sz_cptr_t text, sz_size_t length) { sz_size_t body_length = length - head_length - tail_length; // Multiple of 32. 
sz_u64_t result = 0; + // Handle the tail before we start updating the `text` pointer + while (tail_length) result += text[length - (tail_length--)]; // Handle the head while (head_length--) result += *text++; - // Handle the tail - while (tail_length) result += text[length - (tail_length--) - 1]; sz_u256_vec_t text_vec, sums_vec; sums_vec.ymm = _mm256_setzero_si256(); From abe8d07cc84d62038e1001ec67204102a1b955b0 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 15 Feb 2025 00:21:17 +0000 Subject: [PATCH 085/751] Improve: Validate checksums in benchmark --- scripts/bench_token.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/scripts/bench_token.cpp b/scripts/bench_token.cpp index 684adb05..2e694588 100644 --- a/scripts/bench_token.cpp +++ b/scripts/bench_token.cpp @@ -21,18 +21,18 @@ tracked_unary_functions_t checksum_functions() { return std::accumulate(s.begin(), s.end(), (std::size_t)0, [](std::size_t sum, char c) { return sum + static_cast(c); }); }}, - {"sz_checksum_serial", wrap_sz(sz_checksum_serial), false}, + {"sz_checksum_serial", wrap_sz(sz_checksum_serial), true}, #if SZ_USE_HASWELL - {"sz_checksum_haswell", wrap_sz(sz_checksum_haswell), false}, + {"sz_checksum_haswell", wrap_sz(sz_checksum_haswell), true}, #endif #if SZ_USE_SKYLAKE - {"sz_checksum_skylake", wrap_sz(sz_checksum_skylake), false}, + {"sz_checksum_skylake", wrap_sz(sz_checksum_skylake), true}, #endif #if SZ_USE_ICE - {"sz_checksum_ice", wrap_sz(sz_checksum_ice), false}, + {"sz_checksum_ice", wrap_sz(sz_checksum_ice), true}, #endif #if SZ_USE_NEON - {"sz_checksum_neon", wrap_sz(sz_checksum_neon), false}, + {"sz_checksum_neon", wrap_sz(sz_checksum_neon), true}, #endif }; return result; @@ -242,6 +242,7 @@ void bench_on_synthetic_data() { int main(int argc, char const **argv) { std::printf("StringZilla. Starting token-level benchmarks.\n"); + std::printf("- Seconds per benchmark: %zu\n", seconds_per_benchmark); if (argc < 2) { bench_on_synthetic_data(); } else { bench_on_input_data(argc, argv); } From bce107af19ec7894904ab763064986336771b3e4 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 15 Feb 2025 00:21:55 +0000 Subject: [PATCH 086/751] Improve: Wrap `std::accumulate` for checksums --- scripts/test.cpp | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/scripts/test.cpp b/scripts/test.cpp index 2bf886d1..0cf11552 100644 --- a/scripts/test.cpp +++ b/scripts/test.cpp @@ -797,6 +797,23 @@ static void test_non_stl_extensions_for_reads() { assert((str("hello")[{100, -100}] == "")); assert((str("hello")[{-100, -100}] == "")); + // Checksums + auto accumulate_bytes = [](str const &s) -> std::size_t { + return std::accumulate(s.begin(), s.end(), (std::size_t)0, + [](std::size_t sum, char c) { return sum + static_cast(c); }); + }; + assert(str("a").checksum() == (std::size_t)'a'); + assert(str("0").checksum() == (std::size_t)'0'); + assert(str("0123456789").checksum() == arithmetic_sum('0', '9')); + assert(str("abcdefghijklmnopqrstuvwxyz").checksum() == arithmetic_sum('a', 'z')); + assert(str("abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz").checksum() == + arithmetic_sum('a', 'z') * 3); + assert_scoped( + str s = + "近来,加文出席微博之夜时对着镜头频繁摆出假笑表情、一度累瘫睡倒在沙发上的照片被广泛转发,引发对他失去童年、" + "被过度消费的担忧。八岁的加文,已当网红近六年了,可以说,自懂事以来,他没有过过一天没有名气的日子。", + (void)0, s.checksum() == accumulate_bytes(s)); + // Computing edit-distances. 
assert(sz::hamming_distance(str("hello"), str("hello")) == 0); assert(sz::hamming_distance(str("hello"), str("hell")) == 1); @@ -837,14 +854,6 @@ static void test_non_stl_extensions_for_reads() { assert(sz::alignment_score(str("hello"), str("hello"), costs, -1) == 0); assert(sz::alignment_score(str("hello"), str("hell"), costs, -1) == -1); - // Checksums - assert(str("a").checksum() == (std::size_t)'a'); - assert(str("0").checksum() == (std::size_t)'0'); - assert(str("0123456789").checksum() == arithmetic_sum('0', '9')); - assert(str("abcdefghijklmnopqrstuvwxyz").checksum() == arithmetic_sum('a', 'z')); - assert(str("abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz").checksum() == - arithmetic_sum('a', 'z') * 3); - #if _SZ_DEPRECATED_FINGERPRINTS // Computing rolling fingerprints. From 982dd4d3c5459c309c3b8426191e1d730ec4133d Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 15 Feb 2025 00:22:18 +0000 Subject: [PATCH 087/751] Docs: Signatures and typos --- include/stringzilla/types.h | 96 ++++++++++++++++++++++--------------- scripts/bench.hpp | 16 +++---- scripts/bench_sort.cpp | 10 ++-- 3 files changed, 70 insertions(+), 52 deletions(-) diff --git a/include/stringzilla/types.h b/include/stringzilla/types.h index 9fb67112..89cf1ce9 100644 --- a/include/stringzilla/types.h +++ b/include/stringzilla/types.h @@ -267,10 +267,11 @@ typedef size_t sz_size_t; // Pointer-sized unsigned integer, 32 or 64 bits typedef ptrdiff_t sz_ssize_t; // Signed version of `sz_size_t`, 32 or 64 bits #else // if SZ_AVOID_LIBC: - -// ! The C standard doesn't specify the signedness of char. -// ! On x86 char is signed by default while on Arm it is unsigned by default. -// ! That's why we don't define `sz_char_t` and generally use explicit `sz_i8_t` and `sz_u8_t`. +/** + * ! The C standard doesn't specify the signedness of char. + * ! On x86 char is signed by default while on Arm it is unsigned by default. + * ! That's why we don't define `sz_char_t` and generally use explicit `sz_i8_t` and `sz_u8_t`. + */ typedef signed char sz_i8_t; // Always 8 bits typedef unsigned char sz_u8_t; // Always 8 bits typedef unsigned short sz_u16_t; // Always 16 bits @@ -279,22 +280,24 @@ typedef unsigned int sz_u32_t; // Always 32 bits typedef long long sz_i64_t; // Always 64 bits typedef unsigned long long sz_u64_t; // Always 64 bits -// Now we need to redefine the `size_t`. -// Microsoft Visual C++ (MSVC) typically follows LLP64 data model on 64-bit platforms, -// where integers, pointers, and long types have different sizes: -// -// > `int` is 32 bits -// > `long` is 32 bits -// > `long long` is 64 bits -// > pointer (thus, `size_t`) is 64 bits -// -// In contrast, GCC and Clang on 64-bit Unix-like systems typically follow the LP64 model, where: -// -// > `int` is 32 bits -// > `long` and pointer (thus, `size_t`) are 64 bits -// > `long long` is also 64 bits -// -// Source: https://learn.microsoft.com/en-us/windows/win32/winprog64/abstract-data-models +/** + * Now we need to redefine the `size_t`. 
+ * Microsoft Visual C++ (MSVC) typically follows LLP64 data model on 64-bit platforms, + * where integers, pointers, and long types have different sizes: + * + * > `int` is 32 bits + * > `long` is 32 bits + * > `long long` is 64 bits + * > pointer (thus, `size_t`) is 64 bits + * + * In contrast, GCC and Clang on 64-bit Unix-like systems typically follow the LP64 model, where: + * + * > `int` is 32 bits + * > `long` and pointer (thus, `size_t`) are 64 bits + * > `long long` is also 64 bits + * + * Source: https://learn.microsoft.com/en-us/windows/win32/winprog64/abstract-data-models + */ #if _SZ_IS_64_BIT typedef unsigned long long sz_size_t; // 64-bit. typedef long long sz_ssize_t; // 64-bit. @@ -438,36 +441,48 @@ SZ_PUBLIC void sz_memory_allocator_init_fixed(sz_memory_allocator_t *alloc, void #pragma region API Signature Types +/** @brief Signature of ::sz_hash. */ typedef sz_u64_t (*sz_hash_t)(sz_cptr_t, sz_size_t); + +/** @brief Signature of ::sz_checksum. */ typedef sz_u64_t (*sz_checksum_t)(sz_cptr_t, sz_size_t); + +/** @brief Signature of ::sz_equal. */ typedef sz_bool_t (*sz_equal_t)(sz_cptr_t, sz_cptr_t, sz_size_t); + +/** @brief Signature of ::sz_order. */ typedef sz_ordering_t (*sz_order_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); -typedef void (*sz_to_converter_t)(sz_cptr_t, sz_size_t, sz_ptr_t); +/** @brief Signature of ::sz_look_up_transform. */ typedef void (*sz_look_up_transform_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_ptr_t); +/** @brief Signature of ::sz_move. */ typedef void (*sz_move_t)(sz_ptr_t, sz_cptr_t, sz_size_t); +/** @brief Signature of ::sz_fill. */ typedef void (*sz_fill_t)(sz_ptr_t, sz_size_t, sz_u8_t); +/** @brief Signature of ::sz_find_byte. */ typedef sz_cptr_t (*sz_find_byte_t)(sz_cptr_t, sz_size_t, sz_cptr_t); + +/** @brief Signature of ::sz_find. */ typedef sz_cptr_t (*sz_find_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); + +/** @brief Signature of ::sz_find_set. */ typedef sz_cptr_t (*sz_find_set_t)(sz_cptr_t, sz_size_t, sz_charset_t const *); +/** @brief Signature of ::sz_hamming_distance. */ typedef sz_size_t (*sz_hamming_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t); +/** @brief Signature of ::sz_edit_distance. */ typedef sz_size_t (*sz_edit_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t, sz_memory_allocator_t *); +/** @brief Signature of ::sz_alignment_score. */ typedef sz_ssize_t (*sz_alignment_score_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_error_cost_t const *, sz_error_cost_t, sz_memory_allocator_t *); -typedef void (*sz_hash_callback_t)(sz_cptr_t, sz_size_t, sz_u64_t, void *user); - -typedef void (*sz_hashes_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_size_t, sz_hash_callback_t, void *); - -typedef void (*sz_hashes_fingerprint_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_ptr_t, sz_size_t); - -typedef sz_size_t (*sz_hashes_intersection_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_cptr_t, sz_size_t); +/** @brief Signature of ::sz_sort. */ +typedef sz_bool_t (*sz_sort_t)(sz_sequence_t const *, sz_memory_allocator_t *, sz_sorted_idx_t *); #pragma endregion @@ -728,15 +743,16 @@ SZ_PUBLIC void _sz_assert_failure(char const *condition, char const *file, int l */ #if defined(_MSC_VER) && !defined(__clang__) // On Clang-CL #include - -// Sadly, when building Win32 images, we can't use the `_tzcnt_u64`, `_lzcnt_u64`, -// `_BitScanForward64`, or `_BitScanReverse64` intrinsics. For now it's a simple `for`-loop. -// TODO: In the future we can switch to a more efficient De Bruijn's algorithm. 
-// https://www.chessprogramming.org/BitScan -// https://www.chessprogramming.org/De_Bruijn_Sequence -// https://gist.github.com/resilar/e722d4600dbec9752771ab4c9d47044f -// -// Use the serial version on 32-bit x86 and on Arm. +/* + * Sadly, when building Win32 images, we can't use the `_tzcnt_u64`, `_lzcnt_u64`, + * `_BitScanForward64`, or `_BitScanReverse64` intrinsics. For now it's a simple `for`-loop. + * TODO: In the future we can switch to a more efficient De Bruijn's algorithm. + * https://www.chessprogramming.org/BitScan + * https://www.chessprogramming.org/De_Bruijn_Sequence + * https://gist.github.com/resilar/e722d4600dbec9752771ab4c9d47044f + * + * Use the serial version on 32-bit x86 and on Arm. + */ #if (defined(_WIN32) && !defined(_WIN64)) || defined(_M_ARM) || defined(_M_ARM64) SZ_INTERNAL int sz_u64_ctz(sz_u64_t x) { _sz_assert(x != 0); @@ -780,8 +796,10 @@ SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) { return (int)_tzcnt_u32(x); } SZ_INTERNAL int sz_u32_clz(sz_u32_t x) { return (int)_lzcnt_u32(x); } SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) { return (int)__popcnt(x); } #endif -// Force the byteswap functions to be intrinsics, because when /Oi- is given, these will turn into CRT function calls, -// which breaks when `SZ_AVOID_LIBC` is given +/* + * Force the byteswap functions to be intrinsics, because when `/Oi-` is given, + * these will turn into CRT function calls, which breaks when `SZ_AVOID_LIBC` is given. + */ #pragma intrinsic(_byteswap_uint64) SZ_INTERNAL sz_u64_t sz_u64_bytes_reverse(sz_u64_t val) { return _byteswap_uint64(val); } #pragma intrinsic(_byteswap_ulong) diff --git a/scripts/bench.hpp b/scripts/bench.hpp index ecdf3bb2..b321fa7e 100644 --- a/scripts/bench.hpp +++ b/scripts/bench.hpp @@ -63,7 +63,7 @@ struct tracked_function_gt { void print() const { bool is_binary = std::is_same(); - // If failures have occured, output them to file tos implify the debugging process. + // If failures have occurred, output them to file to simplify the debugging process. bool contains_failures = !failed_strings.empty(); if (contains_failures) { // The file name is made of the string hash and the function name. 
@@ -161,7 +161,7 @@ inline std::vector filter_by_length(std::vector benchmark_result_t bench_on_tokens(strings_type &&strings, function_type &&function) { namespace stdc = std::chrono; - using stdcc = stdc::high_resolution_clock; - stdcc::time_point t1 = stdcc::now(); + using clock_t = stdc::high_resolution_clock; + clock_t::time_point t1 = clock_t::now(); benchmark_result_t result; std::size_t lookup_mask = bit_floor(strings.size()) - 1; @@ -254,7 +254,7 @@ benchmark_result_t bench_on_tokens(strings_type &&strings, function_type &&funct result.iterations += 4; } - stdcc::time_point t2 = stdcc::now(); + clock_t::time_point t2 = clock_t::now(); result.seconds = stdc::duration_cast(t2 - t1).count() / 1.e9; if (result.seconds > seconds_per_benchmark) break; } @@ -273,8 +273,8 @@ template benchmark_result_t bench_on_token_pairs(strings_type &&strings, function_type &&function) { namespace stdc = std::chrono; - using stdcc = stdc::high_resolution_clock; - stdcc::time_point t1 = stdcc::now(); + using clock_t = stdc::high_resolution_clock; + clock_t::time_point t1 = clock_t::now(); benchmark_result_t result; std::size_t lookup_mask = bit_floor(strings.size()) - 1; std::size_t largest_prime = static_cast(18446744073709551557ull); @@ -290,7 +290,7 @@ benchmark_result_t bench_on_token_pairs(strings_type &&strings, function_type && result.iterations += 4; } - stdcc::time_point t2 = stdcc::now(); + clock_t::time_point t2 = clock_t::now(); result.seconds = stdc::duration_cast(t2 - t1).count() / 1.e9; if (result.seconds > seconds_per_benchmark) break; } diff --git a/scripts/bench_sort.cpp b/scripts/bench_sort.cpp index f46be4a3..742d1b9b 100644 --- a/scripts/bench_sort.cpp +++ b/scripts/bench_sort.cpp @@ -127,9 +127,9 @@ void expect_same(permute_t const &permute_base, permute_t const &permute_new) { template void bench_permute(char const *name, strings_t &strings, permute_t &permute, algo_at &&algo) { namespace stdc = std::chrono; - using stdcc = stdc::high_resolution_clock; + using clock_t = stdc::high_resolution_clock; constexpr std::size_t iterations = 3; - stdcc::time_point t1 = stdcc::now(); + clock_t::time_point t1 = clock_t::now(); // Run multiple iterations for (std::size_t i = 0; i != iterations; ++i) { @@ -138,10 +138,10 @@ void bench_permute(char const *name, strings_t &strings, permute_t &permute, alg } // Measure elapsed time - stdcc::time_point t2 = stdcc::now(); + clock_t::time_point t2 = clock_t::now(); double dif = stdc::duration_cast(t2 - t1).count() * 1.0; - double milisecs = dif / (iterations * 1e6); - std::printf("Elapsed time is %.2lf miliseconds/iteration for %s.\n", milisecs, name); + double millisecs = dif / (iterations * 1e6); + std::printf("Elapsed time is %.2lf milliseconds/iteration for %s.\n", millisecs, name); } int main(int argc, char const **argv) { From a0318eb4e3e61547a927877df92f405e299ebb3f Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 15 Feb 2025 00:23:05 +0000 Subject: [PATCH 088/751] Make: Renamed scripts/bench_token.cpp -> scripts/bench_fingerprint.cpp --- scripts/{bench_token.cpp => bench_fingerprint.cpp} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename scripts/{bench_token.cpp => bench_fingerprint.cpp} (100%) diff --git a/scripts/bench_token.cpp b/scripts/bench_fingerprint.cpp similarity index 100% rename from scripts/bench_token.cpp rename to scripts/bench_fingerprint.cpp From 07d2239431c2cba513f951803b0ea0f7fc942366 Mon Sep 17 00:00:00 2001 From: Ash Vardanian 
<1983160+ashvardanian@users.noreply.github.com> Date: Sat, 15 Feb 2025 00:23:05 +0000 Subject: [PATCH 089/751] Make: Renamed scripts/bench_token.cpp -> temp-git-split-file --- scripts/bench_token.cpp => temp-git-split-file | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename scripts/bench_token.cpp => temp-git-split-file (100%) diff --git a/scripts/bench_token.cpp b/temp-git-split-file similarity index 100% rename from scripts/bench_token.cpp rename to temp-git-split-file From 031bedfca6ec7a120e83eba196d8a174b6872796 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 15 Feb 2025 00:23:05 +0000 Subject: [PATCH 090/751] Make: Renamed temp-git-split-file -> scripts/bench_token.cpp --- temp-git-split-file => scripts/bench_token.cpp | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename temp-git-split-file => scripts/bench_token.cpp (100%) diff --git a/temp-git-split-file b/scripts/bench_token.cpp similarity index 100% rename from temp-git-split-file rename to scripts/bench_token.cpp From 187e0bdbeab85634e6a655225ebc54089da11bef Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 15 Feb 2025 00:25:50 +0000 Subject: [PATCH 091/751] Improve: Separate fingerprinting benchmarks --- scripts/bench_fingerprint.cpp | 113 ---------------------------------- scripts/bench_token.cpp | 68 +------------------- 2 files changed, 1 insertion(+), 180 deletions(-) diff --git a/scripts/bench_fingerprint.cpp b/scripts/bench_fingerprint.cpp index 2e694588..82064a29 100644 --- a/scripts/bench_fingerprint.cpp +++ b/scripts/bench_fingerprint.cpp @@ -11,44 +11,6 @@ using namespace ashvardanian::stringzilla::scripts; -tracked_unary_functions_t checksum_functions() { - auto wrap_sz = [](auto function) -> unary_function_t { - return unary_function_t([function](std::string_view s) { return function(s.data(), s.size()); }); - }; - tracked_unary_functions_t result = { - {"std::accumulate", - [](std::string_view s) { - return std::accumulate(s.begin(), s.end(), (std::size_t)0, - [](std::size_t sum, char c) { return sum + static_cast(c); }); - }}, - {"sz_checksum_serial", wrap_sz(sz_checksum_serial), true}, -#if SZ_USE_HASWELL - {"sz_checksum_haswell", wrap_sz(sz_checksum_haswell), true}, -#endif -#if SZ_USE_SKYLAKE - {"sz_checksum_skylake", wrap_sz(sz_checksum_skylake), true}, -#endif -#if SZ_USE_ICE - {"sz_checksum_ice", wrap_sz(sz_checksum_ice), true}, -#endif -#if SZ_USE_NEON - {"sz_checksum_neon", wrap_sz(sz_checksum_neon), true}, -#endif - }; - return result; -} - -tracked_unary_functions_t hashing_functions() { - auto wrap_sz = [](auto function) -> unary_function_t { - return unary_function_t([function](std::string_view s) { return function(s.data(), s.size()); }); - }; - tracked_unary_functions_t result = { - {"sz_hash_serial", wrap_sz(sz_hash_serial)}, - {"std::hash", [](std::string_view s) { return std::hash {}(s); }}, - }; - return result; -} - tracked_unary_functions_t sliding_hashing_functions(std::size_t window_width, std::size_t step) { #if _SZ_DEPRECATED_FINGERPRINTS auto wrap_sz = [=](auto function) -> unary_function_t { @@ -116,59 +78,6 @@ tracked_unary_functions_t random_generation_functions(std::size_t token_length) return result; } -tracked_binary_functions_t equality_functions() { - auto wrap_sz = [](auto function) -> binary_function_t { - return binary_function_t([function](std::string_view a, std::string_view b) { - return (a.size() == b.size() && function(a.data(), b.data(), a.size())); - 
}); - }; - tracked_binary_functions_t result = { - {"std::string_view.==", [](std::string_view a, std::string_view b) { return (a == b); }}, - {"sz_equal_serial", wrap_sz(sz_equal_serial), true}, -#if SZ_USE_HASWELL - {"sz_equal_haswell", wrap_sz(sz_equal_haswell), true}, -#endif -#if SZ_USE_SKYLAKE - {"sz_equal_skylake", wrap_sz(sz_equal_skylake), true}, -#endif - {"memcmp", - [](std::string_view a, std::string_view b) { - return (a.size() == b.size() && memcmp(a.data(), b.data(), a.size()) == 0); - }}, - }; - return result; -} - -tracked_binary_functions_t ordering_functions() { - auto wrap_sz = [](auto function) -> binary_function_t { - return binary_function_t([function](std::string_view a, std::string_view b) { - return function(a.data(), a.size(), b.data(), b.size()); - }); - }; - tracked_binary_functions_t result = { - {"std::string_view.compare", - [](std::string_view a, std::string_view b) { - auto order = a.compare(b); - return (order == 0 ? sz_equal_k : (order < 0 ? sz_less_k : sz_greater_k)); - }}, - {"sz_order_serial", wrap_sz(sz_order_serial), true}, -#if SZ_USE_HASWELL - {"sz_order_haswell", wrap_sz(sz_order_haswell), true}, -#endif -#if SZ_USE_SKYLAKE - {"sz_order_skylake", wrap_sz(sz_order_skylake), true}, -#endif - {"memcmp", - [](std::string_view a, std::string_view b) { - auto order = memcmp(a.data(), b.data(), a.size() < b.size() ? a.size() : b.size()); - return order != 0 ? (a.size() == b.size() ? (order < 0 ? sz_less_k : sz_greater_k) - : (a.size() < b.size() ? sz_less_k : sz_greater_k)) - : sz_equal_k; - }}, - }; - return result; -} - template void bench_dereferencing(std::string name, std::vector strings) { auto func = unary_function_t([](std::string_view s) { return s.size(); }); @@ -183,8 +92,6 @@ void bench(strings_type &&strings) { // Benchmark logical operations bench_unary_functions(strings, checksum_functions()); bench_unary_functions(strings, hashing_functions()); - bench_unary_functions(strings, sliding_hashing_functions(8, 1)); - bench_unary_functions(strings, fingerprinting_functions()); bench_binary_functions(strings, equality_functions()); bench_binary_functions(strings, ordering_functions()); @@ -198,11 +105,7 @@ void bench(strings_type &&strings) { void bench_on_input_data(int argc, char const **argv) { dataset_t dataset = prepare_benchmark_environment(argc, argv); -#if 0 std::printf("Benchmarking on the entire dataset:\n"); - bench_unary_functions(dataset.tokens, random_generation_functions(100)); - bench_unary_functions(dataset.tokens, random_generation_functions(20)); - bench_unary_functions(dataset.tokens, random_generation_functions(5)); // When performing fingerprinting, it's extremely important to: // 1. Have small output fingerprints that fit the cache. 
@@ -215,25 +118,9 @@ void bench_on_input_data(int argc, char const **argv) { bench_unary_functions>({dataset.text}, sliding_hashing_functions(33, 8)); bench_unary_functions>({dataset.text}, sliding_hashing_functions(127, 16)); - bench_unary_functions>({dataset.text}, hashing_functions()); - bench_unary_functions>({dataset.text}, fingerprinting_functions(128, 4 * 1024)); bench_unary_functions>({dataset.text}, fingerprinting_functions(128, 64 * 1024)); bench_unary_functions>({dataset.text}, fingerprinting_functions(128, 1024 * 1024)); -#endif - // Baseline benchmarks for real words, coming in all lengths - std::printf("Benchmarking on real words:\n"); - bench(dataset.tokens); - std::printf("Benchmarking on real lines:\n"); - bench(dataset.lines); - std::printf("Benchmarking on entire dataset:\n"); - bench>({dataset.text}); - - // Run benchmarks on tokens of different length - for (std::size_t token_length : {1, 2, 3, 4, 5, 6, 7, 8, 16, 32}) { - std::printf("Benchmarking on real words of length %zu:\n", token_length); - bench(filter_by_length(dataset.tokens, token_length)); - } } void bench_on_synthetic_data() { diff --git a/scripts/bench_token.cpp b/scripts/bench_token.cpp index 2e694588..749daa85 100644 --- a/scripts/bench_token.cpp +++ b/scripts/bench_token.cpp @@ -49,48 +49,6 @@ tracked_unary_functions_t hashing_functions() { return result; } -tracked_unary_functions_t sliding_hashing_functions(std::size_t window_width, std::size_t step) { -#if _SZ_DEPRECATED_FINGERPRINTS - auto wrap_sz = [=](auto function) -> unary_function_t { - return unary_function_t([function, window_width, step](std::string_view s) { - sz_size_t mixed_hash = 0; - function(s.data(), s.size(), window_width, step, _sz_hashes_fingerprint_scalar_callback, &mixed_hash); - return mixed_hash; - }); - }; -#endif - std::string suffix = std::to_string(window_width) + ":step" + std::to_string(step); - tracked_unary_functions_t result = { -#if _SZ_DEPRECATED_FINGERPRINTS -#if SZ_USE_ICE - {"sz_hashes_ice:" + suffix, wrap_sz(sz_hashes_ice)}, -#endif -#if SZ_USE_HASWELL - {"sz_hashes_haswell:" + suffix, wrap_sz(sz_hashes_haswell)}, -#endif - {"sz_hashes_serial:" + suffix, wrap_sz(sz_hashes_serial)}, -#endif - }; - return result; -} - -tracked_unary_functions_t fingerprinting_functions(std::size_t window_width = 8, std::size_t fingerprint_bytes = 4096) { - using fingerprint_slot_t = std::uint8_t; - static std::vector fingerprint; - fingerprint.resize(fingerprint_bytes / sizeof(fingerprint_slot_t)); - auto wrap_sz = [](auto function) -> unary_function_t { - return unary_function_t([function](std::string_view s) { - sz_size_t mixed_hash = 0; - sz_unused(s); - return mixed_hash; - }); - }; - tracked_unary_functions_t result = {}; - sz_unused(window_width && fingerprint_bytes); - sz_unused(wrap_sz); - return result; -} - tracked_unary_functions_t random_generation_functions(std::size_t token_length) { static std::vector buffer; if (buffer.size() < token_length) buffer.resize(token_length); @@ -183,8 +141,6 @@ void bench(strings_type &&strings) { // Benchmark logical operations bench_unary_functions(strings, checksum_functions()); bench_unary_functions(strings, hashing_functions()); - bench_unary_functions(strings, sliding_hashing_functions(8, 1)); - bench_unary_functions(strings, fingerprinting_functions()); bench_binary_functions(strings, equality_functions()); bench_binary_functions(strings, ordering_functions()); @@ -198,29 +154,7 @@ void bench(strings_type &&strings) { void bench_on_input_data(int argc, char const **argv) { 
dataset_t dataset = prepare_benchmark_environment(argc, argv); -#if 0 - std::printf("Benchmarking on the entire dataset:\n"); - bench_unary_functions(dataset.tokens, random_generation_functions(100)); - bench_unary_functions(dataset.tokens, random_generation_functions(20)); - bench_unary_functions(dataset.tokens, random_generation_functions(5)); - - // When performing fingerprinting, it's extremely important to: - // 1. Have small output fingerprints that fit the cache. - // 2. Have that memory in close affinity to the core, ideally on stack, to avoid cache coherency problems. - // This introduces an additional challenge for efficient fingerprinting, as the CPU caches vary a lot. - // On the Intel Sapphire Rapids 6455B Gold CPU they are 96 KiB x2 for L1d, 4 MiB x2 for L2. - // Spilling into the L3 is a bad idea. - bench_unary_functions>({dataset.text}, sliding_hashing_functions(7, 1)); - bench_unary_functions>({dataset.text}, sliding_hashing_functions(17, 4)); - bench_unary_functions>({dataset.text}, sliding_hashing_functions(33, 8)); - bench_unary_functions>({dataset.text}, sliding_hashing_functions(127, 16)); - - bench_unary_functions>({dataset.text}, hashing_functions()); - - bench_unary_functions>({dataset.text}, fingerprinting_functions(128, 4 * 1024)); - bench_unary_functions>({dataset.text}, fingerprinting_functions(128, 64 * 1024)); - bench_unary_functions>({dataset.text}, fingerprinting_functions(128, 1024 * 1024)); -#endif + // Baseline benchmarks for real words, coming in all lengths std::printf("Benchmarking on real words:\n"); bench(dataset.tokens); From 66f2ac91cff0793490bc690d1911c21976c62f82 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 15 Feb 2025 00:27:25 +0000 Subject: [PATCH 092/751] Fix: Sorting benchmarks for new API --- scripts/bench_sort.cpp | 126 ++++++++++++++++------------------------- 1 file changed, 48 insertions(+), 78 deletions(-) diff --git a/scripts/bench_sort.cpp b/scripts/bench_sort.cpp index 742d1b9b..a2c9817f 100644 --- a/scripts/bench_sort.cpp +++ b/scripts/bench_sort.cpp @@ -153,91 +153,61 @@ int main(int argc, char const **argv) { permute_base.resize(strings.size()); permute_new.resize(strings.size()); - // Partitioning - { - std::printf("---- Partitioning:\n"); - bench_permute("std::partition", strings, permute_base, [](strings_t const &strings, permute_t &permute) { - std::partition(permute.begin(), permute.end(), [&](size_t i) { return strings[i].size() < 4; }); - }); - expect_partitioned_by_length(strings, permute_base); - - bench_permute("std::stable_partition", strings, permute_base, [](strings_t const &strings, permute_t &permute) { - std::stable_partition(permute.begin(), permute.end(), [&](size_t i) { return strings[i].size() < 4; }); - }); - expect_partitioned_by_length(strings, permute_base); - - bench_permute("sz_partition", strings, permute_new, [](strings_t const &strings, permute_t &permute) { - sz_sequence_t array; - array.order = permute.data(); - array.count = strings.size(); - array.handle = &strings; - sz_partition(&array, &has_under_four_chars); - }); - expect_partitioned_by_length(strings, permute_new); - // TODO: expect_same(permute_base, permute_new); - } - // Sorting - { - std::printf("---- Sorting:\n"); - bench_permute("std::sort", strings, permute_base, [](strings_t const &strings, permute_t &permute) { - std::sort(permute.begin(), permute.end(), [&](idx_t i, idx_t j) { return strings[i] < strings[j]; }); - }); - expect_sorted(strings, permute_base); - - 
bench_permute("sz_sort", strings, permute_new, [](strings_t const &strings, permute_t &permute) { - sz_sequence_t array; - array.order = permute.data(); - array.count = strings.size(); - array.handle = &strings; - array.get_start = get_start; - array.get_length = get_length; - sz_sort(&array); - }); - expect_sorted(strings, permute_new); + bench_permute("std::sort", strings, permute_base, [](strings_t const &strings, permute_t &permute) { + std::sort(permute.begin(), permute.end(), [&](idx_t i, idx_t j) { return strings[i] < strings[j]; }); + }); + expect_sorted(strings, permute_base); + + bench_permute("sz_sort", strings, permute_new, [](strings_t const &strings, permute_t &permute) { + sz_sequence_t array; + array.count = strings.size(); + array.handle = &strings; + array.get_start = get_start; + array.get_length = get_length; + sz_sort(&array, NULL, permute.data()); + }); + expect_sorted(strings, permute_new); #if __linux__ && defined(_GNU_SOURCE) - bench_permute("qsort_r", strings, permute_new, [](strings_t const &strings, permute_t &permute) { - sz_sequence_t array; - array.order = permute.data(); - array.count = strings.size(); - array.handle = &strings; - array.get_start = get_start; - array.get_length = get_length; - qsort_r(array.order, array.count, sizeof(sz_u64_t), _get_qsort_order, &array); - }); - expect_sorted(strings, permute_new); + bench_permute("qsort_r", strings, permute_new, [](strings_t const &strings, permute_t &permute) { + sz_sequence_t array; + array.count = strings.size(); + array.handle = &strings; + array.get_start = get_start; + array.get_length = get_length; + qsort_r(permute.data(), array.count, sizeof(sz_u64_t), _get_qsort_order, &array); + }); + expect_sorted(strings, permute_new); #elif defined(_MSC_VER) - bench_permute("qsort_s", strings, permute_new, [](strings_t const &strings, permute_t &permute) { - sz_sequence_t array; - array.order = permute.data(); - array.count = strings.size(); - array.handle = &strings; - array.get_start = get_start; - array.get_length = get_length; - qsort_s(array.order, array.count, sizeof(sz_u64_t), _get_qsort_order, &array); - }); - expect_sorted(strings, permute_new); + bench_permute("qsort_s", strings, permute_new, [](strings_t const &strings, permute_t &permute) { + sz_sequence_t array; + array.count = strings.size(); + array.handle = &strings; + array.get_start = get_start; + array.get_length = get_length; + qsort_s(permute.data(), array.count, sizeof(sz_u64_t), _get_qsort_order, &array); + }); + expect_sorted(strings, permute_new); #else - sz_unused(_get_qsort_order); + sz_unused(_get_qsort_order); #endif - bench_permute("hybrid_sort_cpp", strings, permute_new, - [](strings_t const &strings, permute_t &permute) { hybrid_sort_cpp(strings, permute.data()); }); - expect_sorted(strings, permute_new); - - std::printf("---- Stable Sorting:\n"); - bench_permute("std::stable_sort", strings, permute_base, [](strings_t const &strings, permute_t &permute) { - std::stable_sort(permute.begin(), permute.end(), [&](idx_t i, idx_t j) { return strings[i] < strings[j]; }); - }); - expect_sorted(strings, permute_base); - - bench_permute( - "hybrid_stable_sort_cpp", strings, permute_base, - [](strings_t const &strings, permute_t &permute) { hybrid_stable_sort_cpp(strings, permute.data()); }); - expect_sorted(strings, permute_new); - expect_same(permute_base, permute_new); - } + bench_permute("hybrid_sort_cpp", strings, permute_new, + [](strings_t const &strings, permute_t &permute) { hybrid_sort_cpp(strings, permute.data()); }); + 
expect_sorted(strings, permute_new); + + std::printf("---- Stable Sorting:\n"); + bench_permute("std::stable_sort", strings, permute_base, [](strings_t const &strings, permute_t &permute) { + std::stable_sort(permute.begin(), permute.end(), [&](idx_t i, idx_t j) { return strings[i] < strings[j]; }); + }); + expect_sorted(strings, permute_base); + + bench_permute("hybrid_stable_sort_cpp", strings, permute_base, [](strings_t const &strings, permute_t &permute) { + hybrid_stable_sort_cpp(strings, permute.data()); + }); + expect_sorted(strings, permute_new); + expect_same(permute_base, permute_new); return 0; } \ No newline at end of file From 13bace253bc112cdfd41aa7ede824b0a2bc96790 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 15 Feb 2025 00:36:22 +0000 Subject: [PATCH 093/751] Fix: In C++11 `constexpr` constructor must be empty --- include/stringzilla/stringzilla.hpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/include/stringzilla/stringzilla.hpp b/include/stringzilla/stringzilla.hpp index 89fbd39b..664ce607 100644 --- a/include/stringzilla/stringzilla.hpp +++ b/include/stringzilla/stringzilla.hpp @@ -284,7 +284,7 @@ class basic_char_set { public: using char_type = char_type_; - constexpr basic_char_set() noexcept { + sz_constexpr_if_cpp14 basic_char_set() noexcept { // ! Instead of relying on the `sz_charset_init`, we have to reimplement it to support `constexpr`. bitset_._u64s[0] = 0, bitset_._u64s[1] = 0, bitset_._u64s[2] = 0, bitset_._u64s[3] = 0; } @@ -311,8 +311,8 @@ class basic_char_set { } } - constexpr basic_char_set(basic_char_set const &other) noexcept : bitset_(other.bitset_) {} - constexpr basic_char_set &operator=(basic_char_set const &other) noexcept { + sz_constexpr_if_cpp14 basic_char_set(basic_char_set const &other) noexcept : bitset_(other.bitset_) {} + sz_constexpr_if_cpp14 basic_char_set &operator=(basic_char_set const &other) noexcept { bitset_ = other.bitset_; return *this; } @@ -1244,8 +1244,8 @@ class basic_string_slice { : start_(c_string), length_(null_terminated_length(c_string)) {} constexpr basic_string_slice(pointer c_string, size_type length) noexcept : start_(c_string), length_(length) {} - constexpr basic_string_slice(basic_string_slice const &other) noexcept = default; - constexpr basic_string_slice &operator=(basic_string_slice const &other) noexcept = default; + basic_string_slice(basic_string_slice const &other) noexcept = default; + basic_string_slice &operator=(basic_string_slice const &other) noexcept = default; basic_string_slice(std::nullptr_t) = delete; /** @brief Exchanges the view with that of the `other`. 
*/ From eab31371a28d07a552ab542cf371c837ab946905 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 15 Feb 2025 00:38:31 +0000 Subject: [PATCH 094/751] Fox: C library build --- c/lib.c | 6 ++---- include/stringzilla/sort.h | 13 +++++++++++-- include/stringzilla/types.h | 2 +- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/c/lib.c b/c/lib.c index 5a4183cd..361cd049 100644 --- a/c/lib.c +++ b/c/lib.c @@ -416,10 +416,8 @@ SZ_DYNAMIC sz_ssize_t sz_alignment_score( // return sz_dispatch_table.alignment_score(a, a_length, b, b_length, subs, gap, alloc); } -SZ_DYNAMIC void sz_hashes( // - sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t step, // - sz_hash_callback_t callback, void *callback_handle) { - sz_dispatch_table.hashes(text, length, window_length, step, callback, callback_handle); +SZ_DYNAMIC sz_bool_t sz_sort(sz_sequence_t const *array, sz_memory_allocator_t *alloc, sz_size_t *order) { + return sz_dispatch_table.sort(array, alloc, order); } SZ_DYNAMIC sz_cptr_t sz_find_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { diff --git a/include/stringzilla/sort.h b/include/stringzilla/sort.h index e517159d..3ab89737 100644 --- a/include/stringzilla/sort.h +++ b/include/stringzilla/sort.h @@ -29,7 +29,7 @@ extern "C" { * @param order The output - indices of the sorted collection elements. * @return Whether the operation was successful. */ -SZ_PUBLIC sz_bool_t sz_sort(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, sz_sorted_idx_t *order); +SZ_DYNAMIC sz_bool_t sz_sort(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, sz_sorted_idx_t *order); /** @copydoc sz_sort */ SZ_PUBLIC sz_bool_t sz_sort_serial(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, @@ -306,10 +306,19 @@ SZ_PUBLIC void _sz_sort_ice_recursively( / #pragma endregion // Ice Lake Implementation -SZ_PUBLIC sz_bool_t sz_sort(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, sz_sorted_idx_t *order) { +/* Pick the right implementation for the string search algorithms. + * To override this behavior and precompile all backends - set `SZ_DYNAMIC_DISPATCH` to 1. + */ +#pragma region Compile Time Dispatching +#if !SZ_DYNAMIC_DISPATCH + +SZ_DYNAMIC sz_bool_t sz_sort(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, sz_sorted_idx_t *order) { return sz_sort_serial(collection, alloc, order); } +#endif // !SZ_DYNAMIC_DISPATCH +#pragma endregion // Compile Time Dispatching + #ifdef __cplusplus } #endif // __cplusplus diff --git a/include/stringzilla/types.h b/include/stringzilla/types.h index 89cf1ce9..d241b69f 100644 --- a/include/stringzilla/types.h +++ b/include/stringzilla/types.h @@ -482,7 +482,7 @@ typedef sz_ssize_t (*sz_alignment_score_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_s sz_error_cost_t, sz_memory_allocator_t *); /** @brief Signature of ::sz_sort. 
*/ -typedef sz_bool_t (*sz_sort_t)(sz_sequence_t const *, sz_memory_allocator_t *, sz_sorted_idx_t *); +typedef sz_bool_t (*sz_sort_t)(struct sz_sequence_t const *, sz_memory_allocator_t *, sz_sorted_idx_t *); #pragma endregion From 17f28a33f5c4d0ae6ddaf4ebf71ac76235df2231 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 15 Feb 2025 11:45:40 +0000 Subject: [PATCH 095/751] Fix: `uniform_int_distribution` upper bound --- scripts/test.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/test.hpp b/scripts/test.hpp index 261f90a5..6d85dec2 100644 --- a/scripts/test.hpp +++ b/scripts/test.hpp @@ -52,7 +52,7 @@ struct uniform_uint8_distribution_t { }; inline void randomize_string(char *string, std::size_t length, char const *alphabet, std::size_t cardinality) { - uniform_uint8_distribution_t distribution(cardinality); + uniform_uint8_distribution_t distribution(cardinality - 1); std::generate(string, string + length, [&]() -> char { return alphabet[distribution(global_random_generator())]; }); } From a818f978ccd1aad769c5268560929b04c2836b6b Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 15 Feb 2025 11:47:05 +0000 Subject: [PATCH 096/751] Make: Recommend pretty-printing GDB symbols --- .vscode/launch.json | 18 ++++++++++++++++-- CONTRIBUTING.md | 8 ++++++-- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 71d59186..34ec245d 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -19,7 +19,14 @@ ], "stopAtEntry": false, "linux": { - "MIMode": "gdb" + "MIMode": "gdb", + "setupCommands": [ + { + "description": "Enable pretty-printing for GDB", + "text": "-enable-pretty-printing", + "ignoreFailures": true + } + ] }, "osx": { "MIMode": "lldb" @@ -48,7 +55,14 @@ "stopAtEntry": false, "preLaunchTask": "Build Benchmarks: Debug", "linux": { - "MIMode": "gdb" + "MIMode": "gdb", + "setupCommands": [ + { + "description": "Enable pretty-printing for GDB", + "text": "-enable-pretty-printing", + "ignoreFailures": true + } + ] }, "osx": { "MIMode": "lldb" diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index dfb4fb2f..d6009a30 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -104,11 +104,15 @@ For Python code: The primary C implementation and the C++ wrapper are built with CMake. Assuming the extensive use of new SIMD intrinsics and recent C++ language features, using a recent compiler is recommended. -We prefer GCC 12, which is available from default Ubuntu repositories with Ubuntu 22.04 LTS onwards. +We prefer GCC 12 or newer, which is available from default Ubuntu repositories with Ubuntu 22.04 LTS onwards. If this is your first experience with CMake, use the following commands to get started on Ubuntu: ```bash -sudo apt-get update && sudo apt-get install cmake build-essential libjemalloc-dev g++-12 gcc-12 +sudo apt-get update +sudo apt-get install build-essential +sudo apt-get install cmake # Consider pulling a newer version from PyPI +sudo apt-get install g++-12 gcc-12 # You may already have a newer version on Ubuntu 24 +sudo apt install libstdc++6-12-dbg # STL debugging symbols for GCC 12 ``` On MacOS it's recommended to use Homebrew and install Clang, as opposed to "Apple Clang". 
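Both bounds of `std::uniform_int_distribution` are inclusive, which is what the `cardinality - 1` change to `randomize_string` above relies on: an alphabet of `cardinality` characters must be indexed with values in [0, cardinality - 1], otherwise the highest sample lands one element past the intended range, and for a string-literal alphabet that silently injects the trailing '\0'. Below is a minimal sketch of the same idea with the standard distribution; the project's `uniform_uint8_distribution_t` wrapper is assumed to follow the same inclusive convention, and the helper name is illustrative rather than part of the library.

#include <algorithm>
#include <cstddef>
#include <random>
#include <string>

// Draw every character of `text` uniformly from `alphabet[0 .. cardinality - 1]`.
inline void fill_with_random_chars(std::string &text, char const *alphabet, std::size_t cardinality) {
    static std::mt19937_64 generator {42};                                        // any seeded engine works for a sketch
    std::uniform_int_distribution<std::size_t> distribution(0, cardinality - 1);  // both ends are inclusive
    std::generate(text.begin(), text.end(), [&] { return alphabet[distribution(generator)]; });
}

// Usage: a 16-byte string over a 4-character alphabet; indices never reach `alphabet[4]`.
// std::string text(16, '\0');
// fill_with_random_chars(text, "abcd", 4);

Spelling both bounds out explicitly also sidesteps any ambiguity about what a single-argument constructor means, which is how the lower-bound fix later in this series settles the same question.
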
From 5970fa40abe7e16d6a82436ec43bbf8e978db3ba Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 15 Feb 2025 11:47:23 +0000 Subject: [PATCH 097/751] Fix: Underflow in serial sorting --- include/stringzilla/sort.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/stringzilla/sort.h b/include/stringzilla/sort.h index 3ab89737..c96757e6 100644 --- a/include/stringzilla/sort.h +++ b/include/stringzilla/sort.h @@ -116,7 +116,7 @@ SZ_PUBLIC sz_size_t _sz_sort_serial_partition( // Loop through the collection and move the elements around the pivot. sz_size_t left_offset = start_in_collection; sz_size_t right_offset = end_in_collection - 1; - while (left_offset <= right_offset) { + while (left_offset < right_offset) { // Find the first element on the left that is greater than the pivot. while (global_windows[left_offset] < pivot_window) ++left_offset; // Find the first element on the right that is less than the pivot. @@ -188,7 +188,7 @@ SZ_PUBLIC void _sz_sort_serial_next_window( / // If the identical windows are not trivial and each string has more characters, sort them recursively sz_cptr_t current_window_str = (sz_cptr_t)¤t_window_integer; - int current_window_length = current_window_str[window_capacity]; + sz_size_t current_window_length = (sz_size_t)current_window_str[window_capacity]; if (nested_end - nested_start > 1 && current_window_length == window_capacity) { _sz_sort_serial_next_window(collection, global_windows, global_order, nested_start, nested_end, start_character + window_capacity); From 50d82910b2300377b01fb1b7044eb1d2d78387ca Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 15 Feb 2025 11:47:37 +0000 Subject: [PATCH 098/751] Improve: Drop hybrid sort code --- scripts/bench_sort.cpp | 147 ++--------------------------------------- 1 file changed, 5 insertions(+), 142 deletions(-) diff --git a/scripts/bench_sort.cpp b/scripts/bench_sort.cpp index 9b4ee90d..ac81e233 100644 --- a/scripts/bench_sort.cpp +++ b/scripts/bench_sort.cpp @@ -16,6 +16,7 @@ #include using namespace ashvardanian::stringzilla::scripts; +namespace sz = ashvardanian::stringzilla; using strings_t = std::vector; using idx_t = sz_size_t; @@ -33,11 +34,6 @@ static sz_size_t get_length(sz_sequence_t const *array_c, sz_size_t i) { return array[i].size(); } -static sz_bool_t has_under_four_chars(sz_sequence_t const *array_c, sz_size_t i) { - strings_t const &array = *reinterpret_cast(array_c->handle); - return (sz_bool_t)(array[i].size() < 4); -} - #if defined(_MSC_VER) static int _get_qsort_order(void *arg, const void *a, const void *b) { #else @@ -47,8 +43,8 @@ static int _get_qsort_order(const void *a, const void *b, void *arg) { sz_size_t idx_a = *(sz_size_t *)a; sz_size_t idx_b = *(sz_size_t *)b; - const char *str_a = sequence->get_start(sequence, idx_a); - const char *str_b = sequence->get_start(sequence, idx_b); + char const *str_a = sequence->get_start(sequence, idx_a); + char const *str_b = sequence->get_start(sequence, idx_b); sz_size_t len_a = sequence->get_length(sequence, idx_a); sz_size_t len_b = sequence->get_length(sequence, idx_b); @@ -58,136 +54,12 @@ static int _get_qsort_order(const void *a, const void *b, void *arg) { #pragma endregion -void populate_from_file(std::string path, strings_t &strings, - std::size_t limit = std::numeric_limits::max()) { - - std::ifstream f(path, std::ios::in); - std::string s; - while (strings.size() < limit && std::getline(f, s, ' ')) 
strings.push_back(s); -} - -constexpr size_t offset_in_word = 4; - -static idx_t hybrid_sort_cpp(strings_t const &strings, sz_u64_t *order) { - - // What if we take up-to 4 first characters and the index - for (size_t i = 0; i != strings.size(); ++i) { - size_t index = order[i]; - - for (size_t j = 0; j < std::min(strings[(sz_size_t)index].size(), 4ul); ++j) { - std::memcpy((char *)&order[i] + offset_in_word + 3 - j, strings[(sz_size_t)index].c_str() + j, 1ul); - } - } - - std::sort(order, order + strings.size(), [&](sz_u64_t i, sz_u64_t j) { - char *i_bytes = (char *)&i; - char *j_bytes = (char *)&j; - return *(uint32_t *)(i_bytes + offset_in_word) < *(uint32_t *)(j_bytes + offset_in_word); - }); - - const auto extract_bytes = [](sz_u64_t v) -> uint32_t { - char *bytes = (char *)&v; - return *(uint32_t *)(bytes + offset_in_word); - }; - - if (strings.size() >= 2) { - size_t prev_index = 0; - uint64_t prev_bytes = extract_bytes(order[0]); - - for (size_t i = 1; i < strings.size(); ++i) { - uint32_t bytes = extract_bytes(order[i]); - if (bytes != prev_bytes) { - std::sort(order + prev_index, order + i, [&](sz_u64_t i, sz_u64_t j) { - // Assumes: offset_in_word==4 - sz_size_t i_index = i & 0xFFFF'FFFF; - sz_size_t j_index = j & 0xFFFF'FFFF; - return strings[i_index] < strings[j_index]; - }); - prev_index = i; - prev_bytes = bytes; - } - } - - std::sort(order + prev_index, order + strings.size(), [&](sz_u64_t i, sz_u64_t j) { - sz_size_t i_index = i & 0xFFFF'FFFF; - sz_size_t j_index = j & 0xFFFF'FFFF; - return strings[i_index] < strings[j_index]; - }); - } - - for (size_t i = 0; i != strings.size(); ++i) std::memset((char *)&order[i] + offset_in_word, 0, 4ul); - - return strings.size(); -} - -static idx_t hybrid_stable_sort_cpp(strings_t const &strings, sz_u64_t *order) { - - // What if we take up-to 4 first characters and the index - for (size_t i = 0; i != strings.size(); ++i) { - size_t index = order[i]; - - for (size_t j = 0; j < std::min(strings[(sz_size_t)index].size(), 4ul); ++j) { - std::memcpy((char *)&order[i] + offset_in_word + 3 - j, strings[(sz_size_t)index].c_str() + j, 1ul); - } - } - - std::stable_sort(order, order + strings.size(), [&](sz_u64_t i, sz_u64_t j) { - char *i_bytes = (char *)&i; - char *j_bytes = (char *)&j; - return *(uint32_t *)(i_bytes + offset_in_word) < *(uint32_t *)(j_bytes + offset_in_word); - }); - - const auto extract_bytes = [](sz_u64_t v) -> uint32_t { - char *bytes = (char *)&v; - return *(uint32_t *)(bytes + offset_in_word); - }; - - if (strings.size() >= 2) { - size_t prev_index = 0; - uint64_t prev_bytes = extract_bytes(order[0]); - - for (size_t i = 1; i < strings.size(); ++i) { - uint32_t bytes = extract_bytes(order[i]); - if (bytes != prev_bytes) { - std::stable_sort(order + prev_index, order + i, [&](sz_u64_t i, sz_u64_t j) { - // Assumes: offset_in_word==4 - sz_size_t i_index = i & 0xFFFF'FFFF; - sz_size_t j_index = j & 0xFFFF'FFFF; - return strings[i_index] < strings[j_index]; - }); - prev_index = i; - prev_bytes = bytes; - } - } - - std::stable_sort(order + prev_index, order + strings.size(), [&](sz_u64_t i, sz_u64_t j) { - sz_size_t i_index = i & 0xFFFF'FFFF; - sz_size_t j_index = j & 0xFFFF'FFFF; - return strings[i_index] < strings[j_index]; - }); - } - - for (size_t i = 0; i != strings.size(); ++i) std::memset((char *)&order[i] + offset_in_word, 0, 4ul); - - return strings.size(); -} - -void expect_partitioned_by_length(strings_t const &strings, permute_t const &permute) { - if (!std::is_partitioned(permute.begin(), permute.end(), 
[&](size_t i) { return strings[i].size() < 4; })) - throw std::runtime_error("Partitioning failed!"); -} - void expect_sorted(strings_t const &strings, permute_t const &permute) { if (!std::is_sorted(permute.begin(), permute.end(), [&](std::size_t i, std::size_t j) { return strings[i] < strings[j]; })) throw std::runtime_error("Sorting failed!"); } -void expect_same(permute_t const &permute_base, permute_t const &permute_new) { - if (!std::equal(permute_base.begin(), permute_base.end(), permute_new.begin())) - throw std::runtime_error("Permutations differ!"); -} - template void bench_permute(char const *name, strings_t &strings, permute_t &permute, algo_at &&algo) { namespace stdc = std::chrono; @@ -229,7 +101,8 @@ int main(int argc, char const **argv) { array.handle = &strings; array.get_start = get_start; array.get_length = get_length; - sz_sort(&array, NULL, permute.data()); + sz::_with_alloc>( + [&](sz_memory_allocator_t &alloc) { return sz_sort(&array, &alloc, permute.data()); }); }); expect_sorted(strings, permute_new); @@ -257,21 +130,11 @@ int main(int argc, char const **argv) { sz_unused(_get_qsort_order); #endif - bench_permute("hybrid_sort_cpp", strings, permute_new, - [](strings_t const &strings, permute_t &permute) { hybrid_sort_cpp(strings, permute.data()); }); - expect_sorted(strings, permute_new); - std::printf("---- Stable Sorting:\n"); bench_permute("std::stable_sort", strings, permute_base, [](strings_t const &strings, permute_t &permute) { std::stable_sort(permute.begin(), permute.end(), [&](idx_t i, idx_t j) { return strings[i] < strings[j]; }); }); expect_sorted(strings, permute_base); - bench_permute("hybrid_stable_sort_cpp", strings, permute_new, [](strings_t const &strings, permute_t &permute) { - hybrid_stable_sort_cpp(strings, permute.data()); - }); - expect_sorted(strings, permute_new); - expect_same(permute_base, permute_new); - return 0; } From c670ccd62c36bcf8d09e8874ebc1e41d4bae93ff Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 15 Feb 2025 11:48:04 +0000 Subject: [PATCH 099/751] Add: String sorting tests for different lengths --- scripts/test.cpp | 41 ++++++++++++++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/scripts/test.cpp b/scripts/test.cpp index 0cf11552..d4df88eb 100644 --- a/scripts/test.cpp +++ b/scripts/test.cpp @@ -1595,22 +1595,49 @@ static void test_sequence_algorithms() { using strs_t = std::vector; using order_t = std::vector; + // Basic tests with predetermined orders. assert_scoped(strs_t x({"a", "b", "c", "d"}), (void)0, sz::sorted_order(x) == order_t({0u, 1u, 2u, 3u})); assert_scoped(strs_t x({"b", "c", "d", "a"}), (void)0, sz::sorted_order(x) == order_t({3u, 0u, 1u, 2u})); assert_scoped(strs_t x({"b", "a", "d", "c"}), (void)0, sz::sorted_order(x) == order_t({1u, 0u, 3u, 2u})); - // Generate random strings of different lengths. - for (std::size_t dataset_size : {10, 100, 1000, 10000}) { - // Build the dataset. + // Test on long strings of identical length. 
+ for (std::size_t dataset_size : {10u, 40u, 1000u, 10000u}) { strs_t dataset; - for (std::size_t i = 0; i != dataset_size; ++i) - dataset.push_back(sz::scripts::random_string(i % 32, "abcdefghijklmnopqrstuvwxyz", 26)); + constexpr std::size_t long_length = 20; + dataset.reserve(dataset_size); + for (std::size_t i = 0; i < dataset_size; ++i) + dataset.push_back(sz::scripts::random_string(long_length, "abcd", 4)); + + auto order = sz::sorted_order(dataset); + for (std::size_t i = 1; i < dataset.size(); ++i) assert(dataset[order[i - 1]] <= dataset[order[i]]); + } + + // Test on random strings of varying (but small) lengths. + for (std::size_t dataset_size : {10u, 40u, 1000u, 10000u}) { + strs_t dataset; + dataset.reserve(dataset_size); + for (std::size_t i = 0; i < dataset_size; ++i) dataset.push_back(sz::scripts::random_string(i % 32, "abcd", 4)); + + // Run several iterations of fuzzy tests. + for (std::size_t experiment_idx = 0; experiment_idx < 10; ++experiment_idx) { + std::shuffle(dataset.begin(), dataset.end(), global_random_generator()); + auto order = sz::sorted_order(dataset); + for (std::size_t i = 1; i < dataset_size; ++i) { assert(dataset[order[i - 1]] <= dataset[order[i]]); } + } + } + + // Test on random strings of varying lengths with zero characters. + for (std::size_t dataset_size : {10u, 100u, 1000u, 10000u}) { + strs_t dataset; + dataset.reserve(dataset_size); + for (std::size_t i = 0; i < dataset_size; ++i) + dataset.push_back(sz::scripts::random_string(i % 32, "abcd\0", 5)); // Run several iterations of fuzzy tests. - for (std::size_t experiment_idx = 0; experiment_idx != 10; ++experiment_idx) { + for (std::size_t experiment_idx = 0; experiment_idx < 10; ++experiment_idx) { std::shuffle(dataset.begin(), dataset.end(), global_random_generator()); auto order = sz::sorted_order(dataset); - for (std::size_t i = 1; i != dataset_size; ++i) { assert(dataset[order[i - 1]] <= dataset[order[i]]); } + for (std::size_t i = 1; i < dataset_size; ++i) { assert(dataset[order[i - 1]] <= dataset[order[i]]); } } } } From 0fda5a55523d66bb1bfed3c4cd81635e9df582da Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 15 Feb 2025 12:23:09 +0000 Subject: [PATCH 100/751] Fix: `sz_sort_serial` passes for same length inputs --- include/stringzilla/sort.h | 30 +++++++++++++++++++++++++----- scripts/test.cpp | 26 ++++++-------------------- 2 files changed, 31 insertions(+), 25 deletions(-) diff --git a/include/stringzilla/sort.h b/include/stringzilla/sort.h index c96757e6..d72429ce 100644 --- a/include/stringzilla/sort.h +++ b/include/stringzilla/sort.h @@ -76,7 +76,28 @@ SZ_PUBLIC void _sz_sort_serial_export_prefixes( // #else *target_integer = sz_u32_bytes_reverse(*target_integer); #endif + _sz_assert( // + (length <= start_in_collection) == (*target_integer == 0) && // + "We can have a zero value if only the string is shorter than other strings at this position."); } + + // As our goal is to sort the strings using the exported integer "windows", + // this is a good place to validate the correctness of the exported data. 
+ if (SZ_DEBUG && start_character == 0) + for (sz_size_t i = start_in_collection + 1; i < end_in_collection; ++i) { + _sz_sorting_window_t const previous_window = global_windows[i - 1]; + _sz_sorting_window_t const current_window = global_windows[i]; + sz_cptr_t const previous_str = collection->get_start(collection, i - 1); + sz_size_t const previous_length = collection->get_length(collection, i - 1); + sz_cptr_t const current_str = collection->get_start(collection, i); + sz_size_t const current_length = collection->get_length(collection, i); + sz_ordering_t const ordering = sz_order( // + previous_str, previous_length > window_capacity ? window_capacity : previous_length, // + current_str, current_length > window_capacity ? window_capacity : current_length); + _sz_assert( // + (previous_window < current_window) == (ordering == sz_less_k) && // + "The exported windows should be in the same order as the original strings."); + } } /** @@ -143,21 +164,20 @@ SZ_PUBLIC void _sz_sort_serial_recursively( / _sz_sorting_window_t *const global_windows, sz_size_t *const global_order, // sz_size_t const start_in_collection, sz_size_t const end_in_collection, // sz_size_t const start_character) { + // Partition the collection around some pivot sz_size_t pivot_index = _sz_sort_serial_partition(global_windows, global_order, start_in_collection, end_in_collection); // Recursively sort the left partition - if (start_in_collection < pivot_index) { + if (start_in_collection < pivot_index) _sz_sort_serial_recursively(collection, global_windows, global_order, start_in_collection, pivot_index, start_character); - } // Recursively sort the right partition - if (pivot_index + 1 < end_in_collection) { + if (pivot_index + 1 < end_in_collection) _sz_sort_serial_recursively(collection, global_windows, global_order, pivot_index + 1, end_in_collection, start_character); - } } SZ_PUBLIC void _sz_sort_serial_next_window( // @@ -258,7 +278,7 @@ SZ_PUBLIC sz_bool_t sz_sort_serial(sz_sequence_t const *collection, sz_memory_al if (!windows) return sz_false_k; // Recursively sort the whole collection. - _sz_sort_serial_recursively(collection, windows, order, 0, collection->count, 0); + _sz_sort_serial_next_window(collection, windows, order, 0, collection->count, 0); // Free temporary storage. alloc->free(windows, collection->count * sizeof(_sz_sorting_window_t), alloc); diff --git a/scripts/test.cpp b/scripts/test.cpp index d4df88eb..04471495 100644 --- a/scripts/test.cpp +++ b/scripts/test.cpp @@ -1601,22 +1601,22 @@ static void test_sequence_algorithms() { assert_scoped(strs_t x({"b", "a", "d", "c"}), (void)0, sz::sorted_order(x) == order_t({1u, 0u, 3u, 2u})); // Test on long strings of identical length. - for (std::size_t dataset_size : {10u, 40u, 1000u, 10000u}) { + for (std::size_t dataset_size : {10u, 100u, 1000u, 10000u}) { strs_t dataset; constexpr std::size_t long_length = 20; dataset.reserve(dataset_size); for (std::size_t i = 0; i < dataset_size; ++i) - dataset.push_back(sz::scripts::random_string(long_length, "abcd", 4)); + dataset.push_back(sz::scripts::random_string(long_length, "ab", 2)); auto order = sz::sorted_order(dataset); for (std::size_t i = 1; i < dataset.size(); ++i) assert(dataset[order[i - 1]] <= dataset[order[i]]); } // Test on random strings of varying (but small) lengths. 
- for (std::size_t dataset_size : {10u, 40u, 1000u, 10000u}) { + for (std::size_t dataset_size : {10u, 100u, 1000u, 10000u}) { strs_t dataset; dataset.reserve(dataset_size); - for (std::size_t i = 0; i < dataset_size; ++i) dataset.push_back(sz::scripts::random_string(i % 32, "abcd", 4)); + for (std::size_t i = 0; i < dataset_size; ++i) dataset.push_back(sz::scripts::random_string(i % 32, "ab", 2)); // Run several iterations of fuzzy tests. for (std::size_t experiment_idx = 0; experiment_idx < 10; ++experiment_idx) { @@ -1630,8 +1630,7 @@ static void test_sequence_algorithms() { for (std::size_t dataset_size : {10u, 100u, 1000u, 10000u}) { strs_t dataset; dataset.reserve(dataset_size); - for (std::size_t i = 0; i < dataset_size; ++i) - dataset.push_back(sz::scripts::random_string(i % 32, "abcd\0", 5)); + for (std::size_t i = 0; i < dataset_size; ++i) dataset.push_back(sz::scripts::random_string(i % 32, "ab\0", 3)); // Run several iterations of fuzzy tests. for (std::size_t experiment_idx = 0; experiment_idx < 10; ++experiment_idx) { @@ -1658,20 +1657,7 @@ static void test_stl_containers() { } int main(int argc, char const **argv) { - - auto dist = _sz_edit_distance_skewed_diagonals_upto63_ice("kiten", 5, "katerinas", 9, SZ_SIZE_MAX); - _sz_assert(dist == 5); - dist = _sz_edit_distance_skewed_diagonals_upto63_ice("kiten", 5, "katerinas", 9, 3); - _sz_assert(dist == SZ_SIZE_MAX); - dist = _sz_edit_distance_skewed_diagonals_upto63_ice("kiten", 5, "katerinas", 9, 4); - _sz_assert(dist == SZ_SIZE_MAX); - dist = _sz_edit_distance_skewed_diagonals_upto63_ice("kiten", 5, "katerinas", 9, 5); - _sz_assert(dist == 5); - dist = _sz_edit_distance_skewed_diagonals_upto63_ice("kiten", 5, "katerinas", 9, 6); - _sz_assert(dist == 5); - - // Similarity measures and fuzzy search - test_levenshtein_distances(); + test_sequence_algorithms(); // Let's greet the user nicely sz_unused(argc && argv); From bdee11181a0e86263b91da0705b6a4bb45489eab Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 15 Feb 2025 20:41:43 +0000 Subject: [PATCH 101/751] Fix: `uniform_int_distribution` lower bound --- scripts/test.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/test.hpp b/scripts/test.hpp index 6d85dec2..6c37e9f6 100644 --- a/scripts/test.hpp +++ b/scripts/test.hpp @@ -52,7 +52,7 @@ struct uniform_uint8_distribution_t { }; inline void randomize_string(char *string, std::size_t length, char const *alphabet, std::size_t cardinality) { - uniform_uint8_distribution_t distribution(cardinality - 1); + uniform_uint8_distribution_t distribution(0, cardinality - 1); std::generate(string, string + length, [&]() -> char { return alphabet[distribution(global_random_generator())]; }); } From 8bad799b72ba38728ad0829e1fba25dff9cd6746 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 15 Feb 2025 22:04:29 +0000 Subject: [PATCH 102/751] Fix: `sz_sort_serial` passes tests Benchmarks on Sapphire Rapids suggest: - For 8.3 M words in Leipzig1M.txt of length ~5 -- `std::sort` is 2 seconds -- `sz_sort_serial` is 0.6 seconds -- `qsort_r` is 3.2 seconds - For 268 M words in XLSum.csv of length ~8 -- `std::sort` is 147 seconds -- `sz_sort_serial` is 29 seconds -- `qsort_r` is 192 seconds --- include/stringzilla/sort.h | 120 +++++++++++++++++++++++-------------- scripts/bench_sort.cpp | 4 +- scripts/test.cpp | 36 ++++++++--- 3 files changed, 104 insertions(+), 56 deletions(-) diff --git a/include/stringzilla/sort.h 
b/include/stringzilla/sort.h index d72429ce..2e35282f 100644 --- a/include/stringzilla/sort.h +++ b/include/stringzilla/sort.h @@ -48,10 +48,10 @@ SZ_PUBLIC sz_bool_t sz_sort_sve(sz_sequence_t const *collection, sz_memory_alloc typedef sz_size_t _sz_sorting_window_t; -SZ_PUBLIC void _sz_sort_serial_export_prefixes( // - sz_sequence_t const *const collection, // - _sz_sorting_window_t *const global_windows, // - sz_size_t const start_in_collection, sz_size_t const end_in_collection, // +SZ_PUBLIC void _sz_sort_serial_export_prefixes( // + sz_sequence_t const *const collection, // + _sz_sorting_window_t *const global_windows, sz_sorted_idx_t const *const global_order, // + sz_size_t const start_in_collection, sz_size_t const end_in_collection, // sz_size_t const start_character) { // Depending on the architecture, we will export a different number of bytes. @@ -60,11 +60,18 @@ SZ_PUBLIC void _sz_sort_serial_export_prefixes( // // Perform the same operation for every string. for (sz_size_t i = start_in_collection; i < end_in_collection; ++i) { + + // On the first recursion level, the `global_order` is the identity permutation. + sz_sorted_idx_t const partial_order_index = global_order[i]; + if (SZ_DEBUG && start_character == 0) + _sz_assert(partial_order_index == i && "At start this must be an identity permutation."); + // Get the string slice in global memory. - sz_cptr_t const source_str = collection->get_start(collection, i); - sz_size_t const length = collection->get_length(collection, i); + sz_cptr_t const source_str = collection->get_start(collection, partial_order_index); + sz_size_t const length = collection->get_length(collection, partial_order_index); sz_size_t const remaining_length = length > start_character ? length - start_character : 0; sz_size_t const exported_length = remaining_length > window_capacity ? window_capacity : remaining_length; + // Fill with zeros, export a slice, and mark the exported length. sz_size_t *target_integer = &global_windows[i]; sz_ptr_t target_str = (sz_ptr_t)target_integer; @@ -76,8 +83,8 @@ SZ_PUBLIC void _sz_sort_serial_export_prefixes( // #else *target_integer = sz_u32_bytes_reverse(*target_integer); #endif - _sz_assert( // - (length <= start_in_collection) == (*target_integer == 0) && // + _sz_assert( // + (length <= start_character) == (*target_integer == 0) && // "We can have a zero value if only the string is shorter than other strings at this position."); } @@ -101,19 +108,25 @@ SZ_PUBLIC void _sz_sort_serial_export_prefixes( // } /** - * @brief Helper function of the serial QuickSort algorithm, that rearranges the elements in + * @brief The most important part of the QuickSort algorithm, that rearranges the elements in * such a way, that all entries around the pivot are less than the pivot. * * It means that no relative order among the elements on the left or right side of the pivot is preserved. * We chose the pivot point using Robert Sedgewick's method - the median of three elements - the first, * the middle, and the last element of the given range. + * + * Moreover, considering our iterative refinement procedure, we can't just use the normal 2-way partitioning, + * as it will scatter the values equal to the pivot into the left and right partitions. Instead we use the + * Dutch National Flag @b 3-way partitioning, outputting the range of values equal to the pivot. 
+ * + * @see https://en.wikipedia.org/wiki/Dutch_national_flag_problem */ -SZ_PUBLIC sz_size_t _sz_sort_serial_partition( // +SZ_PUBLIC void _sz_sort_serial_3way_partition( // _sz_sorting_window_t *const global_windows, sz_sorted_idx_t *const global_order, // - sz_size_t const start_in_collection, sz_size_t const end_in_collection) { + sz_size_t const start_in_collection, sz_size_t const end_in_collection, // + sz_size_t *first_pivot_offset, sz_size_t *last_pivot_offset) { - // Chose the pivot offset. - sz_size_t pivot_offset; + // Chose the pivot offset with Sedgewick's method. _sz_sorting_window_t pivot_window; { sz_size_t const middle_offset = start_in_collection + (end_in_collection - start_in_collection) / 2; @@ -123,40 +136,52 @@ SZ_PUBLIC sz_size_t _sz_sort_serial_partition( _sz_sorting_window_t const middle_window = global_windows[middle_offset]; _sz_sorting_window_t const last_window = global_windows[last_offset]; if (first_window < middle_window) { - if (middle_window < last_window) { pivot_offset = middle_offset, pivot_window = middle_window; } - else if (first_window < last_window) { pivot_offset = last_offset, pivot_window = last_window; } - else { pivot_offset = first_offset, pivot_window = first_window; } + if (middle_window < last_window) { pivot_window = middle_window; } + else if (first_window < last_window) { pivot_window = last_window; } + else { pivot_window = first_window; } } else { - if (first_window < last_window) { pivot_offset = first_offset, pivot_window = first_window; } - else if (middle_window < last_window) { pivot_offset = last_offset, pivot_window = last_window; } - else { pivot_offset = middle_offset, pivot_window = middle_window; } + if (first_window < last_window) { pivot_window = first_window; } + else if (middle_window < last_window) { pivot_window = last_window; } + else { pivot_window = middle_window; } } } - // Loop through the collection and move the elements around the pivot. - sz_size_t left_offset = start_in_collection; - sz_size_t right_offset = end_in_collection - 1; - while (left_offset < right_offset) { - // Find the first element on the left that is greater than the pivot. - while (global_windows[left_offset] < pivot_window) ++left_offset; - // Find the first element on the right that is less than the pivot. - while (global_windows[right_offset] > pivot_window) --right_offset; - // Swap the elements if they are in the wrong order. - if (left_offset <= right_offset) { + // Loop through the collection and move the elements around the pivot with the 3-way partitioning. + sz_size_t partitioning_progress = start_in_collection; // Current index. + sz_size_t less_than_pivot_offset = start_in_collection; // Boundary for elements < pivot_window. + sz_size_t greater_than_pivot_offset = end_in_collection - 1; // Boundary for elements > pivot_window. + + while (partitioning_progress <= greater_than_pivot_offset) { + // Element is less than pivot: swap into the < pivot region. 
+ if (global_windows[partitioning_progress] < pivot_window) { +#if defined(_SZ_IS_64_BIT) + sz_u64_swap(&global_order[partitioning_progress], &global_order[less_than_pivot_offset]); + sz_u64_swap(&global_windows[partitioning_progress], &global_windows[less_than_pivot_offset]); +#else + sz_u32_swap(&global_order[partitioning_progress], &global_order[less_than_pivot_offset]); + sz_u32_swap(&global_windows[partitioning_progress], &global_windows[less_than_pivot_offset]); +#endif + ++partitioning_progress; + ++less_than_pivot_offset; + } + // Element is greater than pivot: swap into the > pivot region. + else if (global_windows[partitioning_progress] > pivot_window) { #if defined(_SZ_IS_64_BIT) - sz_u64_swap(&global_order[left_offset], &global_order[right_offset]); - sz_u64_swap(&global_windows[left_offset], &global_windows[right_offset]); + sz_u64_swap(&global_order[partitioning_progress], &global_order[greater_than_pivot_offset]); + sz_u64_swap(&global_windows[partitioning_progress], &global_windows[greater_than_pivot_offset]); #else - sz_u32_swap(&global_order[left_offset], &global_order[right_offset]); - sz_u32_swap(&global_windows[left_offset], &global_windows[right_offset]); + sz_u32_swap(&global_order[partitioning_progress], &global_order[greater_than_pivot_offset]); + sz_u32_swap(&global_windows[partitioning_progress], &global_windows[greater_than_pivot_offset]); #endif - ++left_offset; - --right_offset; + --greater_than_pivot_offset; } + // Element equals pivot_window: leave it in place. + else { ++partitioning_progress; } } - return pivot_offset; + *first_pivot_offset = less_than_pivot_offset; + *last_pivot_offset = greater_than_pivot_offset; } SZ_PUBLIC void _sz_sort_serial_recursively( // @@ -165,18 +190,21 @@ SZ_PUBLIC void _sz_sort_serial_recursively( / sz_size_t const start_in_collection, sz_size_t const end_in_collection, // sz_size_t const start_character) { - // Partition the collection around some pivot - sz_size_t pivot_index = - _sz_sort_serial_partition(global_windows, global_order, start_in_collection, end_in_collection); + // Partition the collection around some pivot or 2 pivots in a 3-way partitioning + sz_size_t first_pivot_index, last_pivot_index; + _sz_sort_serial_3way_partition( // + global_windows, global_order, // + start_in_collection, end_in_collection, // + &first_pivot_index, &last_pivot_index); // Recursively sort the left partition - if (start_in_collection < pivot_index) - _sz_sort_serial_recursively(collection, global_windows, global_order, start_in_collection, pivot_index, + if (start_in_collection < first_pivot_index) + _sz_sort_serial_recursively(collection, global_windows, global_order, start_in_collection, first_pivot_index, start_character); // Recursively sort the right partition - if (pivot_index + 1 < end_in_collection) - _sz_sort_serial_recursively(collection, global_windows, global_order, pivot_index + 1, end_in_collection, + if (last_pivot_index + 1 < end_in_collection) + _sz_sort_serial_recursively(collection, global_windows, global_order, last_pivot_index + 1, end_in_collection, start_character); } @@ -187,7 +215,7 @@ SZ_PUBLIC void _sz_sort_serial_next_window( / sz_size_t const start_character) { // Prepare the new range of windows - _sz_sort_serial_export_prefixes(collection, global_windows, start_in_collection, end_in_collection, + _sz_sort_serial_export_prefixes(collection, global_windows, global_order, start_in_collection, end_in_collection, start_character); // Sort current windows with a quicksort @@ -208,7 +236,7 @@ SZ_PUBLIC void 
_sz_sort_serial_next_window( / // If the identical windows are not trivial and each string has more characters, sort them recursively sz_cptr_t current_window_str = (sz_cptr_t)¤t_window_integer; - sz_size_t current_window_length = (sz_size_t)current_window_str[window_capacity]; + sz_size_t current_window_length = (sz_size_t)current_window_str[0]; //! The byte order was swapped if (nested_end - nested_start > 1 && current_window_length == window_capacity) { _sz_sort_serial_next_window(collection, global_windows, global_order, nested_start, nested_end, start_character + window_capacity); @@ -296,7 +324,7 @@ SZ_PUBLIC void _sz_sort_ice_recursively( / sz_size_t const start_character) { // Prepare the new range of windows - _sz_sort_serial_export_prefixes(collection, global_windows, start_in_collection, end_in_collection, + _sz_sort_serial_export_prefixes(collection, global_windows, global_order, start_in_collection, end_in_collection, start_character); // We can implement a form of a Radix sort here, that will count the number of elements with diff --git a/scripts/bench_sort.cpp b/scripts/bench_sort.cpp index ac81e233..75800582 100644 --- a/scripts/bench_sort.cpp +++ b/scripts/bench_sort.cpp @@ -95,14 +95,14 @@ int main(int argc, char const **argv) { }); expect_sorted(strings, permute_base); - bench_permute("sz_sort", strings, permute_new, [](strings_t const &strings, permute_t &permute) { + bench_permute("sz_sort_serial", strings, permute_new, [](strings_t const &strings, permute_t &permute) { sz_sequence_t array; array.count = strings.size(); array.handle = &strings; array.get_start = get_start; array.get_length = get_length; sz::_with_alloc>( - [&](sz_memory_allocator_t &alloc) { return sz_sort(&array, &alloc, permute.data()); }); + [&](sz_memory_allocator_t &alloc) { return sz_sort_serial(&array, &alloc, permute.data()); }); }); expect_sorted(strings, permute_new); diff --git a/scripts/test.cpp b/scripts/test.cpp index 04471495..d8f0cdd6 100644 --- a/scripts/test.cpp +++ b/scripts/test.cpp @@ -1601,22 +1601,43 @@ static void test_sequence_algorithms() { assert_scoped(strs_t x({"b", "a", "d", "c"}), (void)0, sz::sorted_order(x) == order_t({1u, 0u, 3u, 2u})); // Test on long strings of identical length. + for (std::size_t string_length : {5u, 25u}) { + for (std::size_t dataset_size : {10u, 100u, 1000u, 10000u}) { + strs_t dataset; + dataset.reserve(dataset_size); + for (std::size_t i = 0; i < dataset_size; ++i) + dataset.push_back(sz::scripts::random_string(string_length, "ab", 2)); + + // Run several iterations of fuzzy tests. + for (std::size_t experiment_idx = 0; experiment_idx < 10; ++experiment_idx) { + std::shuffle(dataset.begin(), dataset.end(), global_random_generator()); + auto order = sz::sorted_order(dataset); + for (std::size_t i = 1; i < dataset.size(); ++i) assert(dataset[order[i - 1]] <= dataset[order[i]]); + } + } + } + + // Test on random very small strings of varying lengths, likely with many equal inputs. for (std::size_t dataset_size : {10u, 100u, 1000u, 10000u}) { strs_t dataset; - constexpr std::size_t long_length = 20; dataset.reserve(dataset_size); - for (std::size_t i = 0; i < dataset_size; ++i) - dataset.push_back(sz::scripts::random_string(long_length, "ab", 2)); + for (std::size_t i = 0; i < dataset_size; ++i) dataset.push_back(sz::scripts::random_string(i % 6, "ab", 2)); - auto order = sz::sorted_order(dataset); - for (std::size_t i = 1; i < dataset.size(); ++i) assert(dataset[order[i - 1]] <= dataset[order[i]]); + // Run several iterations of fuzzy tests. 
+ for (std::size_t experiment_idx = 0; experiment_idx < 10; ++experiment_idx) { + std::shuffle(dataset.begin(), dataset.end(), global_random_generator()); + auto order = sz::sorted_order(dataset); + for (std::size_t i = 1; i < dataset_size; ++i) { assert(dataset[order[i - 1]] <= dataset[order[i]]); } + } } - // Test on random strings of varying (but small) lengths. + // Test on random strings of varying lengths. for (std::size_t dataset_size : {10u, 100u, 1000u, 10000u}) { strs_t dataset; dataset.reserve(dataset_size); - for (std::size_t i = 0; i < dataset_size; ++i) dataset.push_back(sz::scripts::random_string(i % 32, "ab", 2)); + constexpr std::size_t min_length = 6; + for (std::size_t i = 0; i < dataset_size; ++i) + dataset.push_back(sz::scripts::random_string(min_length + i % 32, "ab", 2)); // Run several iterations of fuzzy tests. for (std::size_t experiment_idx = 0; experiment_idx < 10; ++experiment_idx) { @@ -1657,7 +1678,6 @@ static void test_stl_containers() { } int main(int argc, char const **argv) { - test_sequence_algorithms(); // Let's greet the user nicely sz_unused(argc && argv); From 6191cc6173b4867223e086599342d876c5bf09ec Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 16 Feb 2025 10:55:42 +0000 Subject: [PATCH 103/751] Improve: Rename `sz_sort` to `sz_qsort` Makes it easier to differentiate stable `sz_msort` --- include/stringzilla/sort.h | 140 ++++++++++++++++++++----------------- 1 file changed, 75 insertions(+), 65 deletions(-) diff --git a/include/stringzilla/sort.h b/include/stringzilla/sort.h index 2e35282f..9af46a20 100644 --- a/include/stringzilla/sort.h +++ b/include/stringzilla/sort.h @@ -5,8 +5,13 @@ * * Includes core APIs: * - * - `sz_sort` - to sort an arbitrary string collection. - * - TODO: `sz_stable_sort` - to sort a string collection while preserving the relative order of equal elements. + * - `sz_qsort` - to sort an arbitrary string collection with QuickSort-like algorithm. + * - `sz_qsort_addresses` - to sort a collection of continuous pointer-sized integers with QuickSort-like algorithm. + * - `sz_msort` - to sort an arbitrary string collection with a MergeSort-like algorithm. + * - `sz_msort_addresses` - to sort a collection of continuous pointer-sized integers with a MergeSort-like algorithm. + * + * The `qsort` variants are not guaranteed to be stable. + * The `msort` variants are guaranteed to be stable. */ #ifndef STRINGZILLA_SORT_H_ #define STRINGZILLA_SORT_H_ @@ -29,34 +34,39 @@ extern "C" { * @param order The output - indices of the sorted collection elements. * @return Whether the operation was successful. 
*/ -SZ_DYNAMIC sz_bool_t sz_sort(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, sz_sorted_idx_t *order); - -/** @copydoc sz_sort */ -SZ_PUBLIC sz_bool_t sz_sort_serial(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order); +SZ_DYNAMIC sz_bool_t sz_qsort(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, sz_sorted_idx_t *order); -/** @copydoc sz_sort */ -SZ_PUBLIC sz_bool_t sz_sort_skylake(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, +/** @copydoc sz_qsort */ +SZ_PUBLIC sz_bool_t sz_qsort_serial(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, sz_sorted_idx_t *order); -/** @copydoc sz_sort */ -SZ_PUBLIC sz_bool_t sz_sort_sve(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, sz_sorted_idx_t *order); +/** @copydoc sz_qsort */ +SZ_PUBLIC sz_bool_t sz_qsort_skylake(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); + +/** @copydoc sz_qsort */ +SZ_PUBLIC sz_bool_t sz_qsort_sve(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, sz_sorted_idx_t *order); #pragma endregion #pragma region Serial Implementation -typedef sz_size_t _sz_sorting_window_t; +/** + * The core idea of all following string algorithms is to sort strings not based on 1 character at a time, + * but on a larger integer word fitting in 4 or 8 bytes at once, on 32-bit or 64-bit architectures, respectively. + * That word is pointer-sized, but it may contain extra information aside from N characters. + */ +typedef sz_size_t _sz_sort_ngram_t; -SZ_PUBLIC void _sz_sort_serial_export_prefixes( // - sz_sequence_t const *const collection, // - _sz_sorting_window_t *const global_windows, sz_sorted_idx_t const *const global_order, // - sz_size_t const start_in_collection, sz_size_t const end_in_collection, // +SZ_PUBLIC void _sz_qsort_serial_export_prefixes( // + sz_sequence_t const *const collection, // + _sz_sort_ngram_t *const global_windows, sz_sorted_idx_t const *const global_order, // + sz_size_t const start_in_collection, sz_size_t const end_in_collection, // sz_size_t const start_character) { // Depending on the architecture, we will export a different number of bytes. // On 32-bit architectures, we will export 3 bytes, and on 64-bit architectures - 7 bytes. - sz_size_t const window_capacity = sizeof(_sz_sorting_window_t) - 1; + sz_size_t const window_capacity = sizeof(_sz_sort_ngram_t) - 1; // Perform the same operation for every string. for (sz_size_t i = start_in_collection; i < end_in_collection; ++i) { @@ -92,8 +102,8 @@ SZ_PUBLIC void _sz_sort_serial_export_prefixes( // this is a good place to validate the correctness of the exported data. 
if (SZ_DEBUG && start_character == 0) for (sz_size_t i = start_in_collection + 1; i < end_in_collection; ++i) { - _sz_sorting_window_t const previous_window = global_windows[i - 1]; - _sz_sorting_window_t const current_window = global_windows[i]; + _sz_sort_ngram_t const previous_window = global_windows[i - 1]; + _sz_sort_ngram_t const current_window = global_windows[i]; sz_cptr_t const previous_str = collection->get_start(collection, i - 1); sz_size_t const previous_length = collection->get_length(collection, i - 1); sz_cptr_t const current_str = collection->get_start(collection, i); @@ -121,20 +131,20 @@ SZ_PUBLIC void _sz_sort_serial_export_prefixes( * * @see https://en.wikipedia.org/wiki/Dutch_national_flag_problem */ -SZ_PUBLIC void _sz_sort_serial_3way_partition( // - _sz_sorting_window_t *const global_windows, sz_sorted_idx_t *const global_order, // - sz_size_t const start_in_collection, sz_size_t const end_in_collection, // +SZ_PUBLIC void _sz_qsort_serial_3way_partition( // + _sz_sort_ngram_t *const global_windows, sz_sorted_idx_t *const global_order, // + sz_size_t const start_in_collection, sz_size_t const end_in_collection, // sz_size_t *first_pivot_offset, sz_size_t *last_pivot_offset) { // Chose the pivot offset with Sedgewick's method. - _sz_sorting_window_t pivot_window; + _sz_sort_ngram_t pivot_window; { sz_size_t const middle_offset = start_in_collection + (end_in_collection - start_in_collection) / 2; sz_size_t const last_offset = end_in_collection - 1; sz_size_t const first_offset = start_in_collection; - _sz_sorting_window_t const first_window = global_windows[first_offset]; - _sz_sorting_window_t const middle_window = global_windows[middle_offset]; - _sz_sorting_window_t const last_window = global_windows[last_offset]; + _sz_sort_ngram_t const first_window = global_windows[first_offset]; + _sz_sort_ngram_t const middle_window = global_windows[middle_offset]; + _sz_sort_ngram_t const last_window = global_windows[last_offset]; if (first_window < middle_window) { if (middle_window < last_window) { pivot_window = middle_window; } else if (first_window < last_window) { pivot_window = last_window; } @@ -184,70 +194,70 @@ SZ_PUBLIC void _sz_sort_serial_3way_partition( *last_pivot_offset = greater_than_pivot_offset; } -SZ_PUBLIC void _sz_sort_serial_recursively( // - sz_sequence_t const *const collection, // - _sz_sorting_window_t *const global_windows, sz_size_t *const global_order, // - sz_size_t const start_in_collection, sz_size_t const end_in_collection, // +SZ_PUBLIC void _sz_qsort_serial_recursively( // + sz_sequence_t const *const collection, // + _sz_sort_ngram_t *const global_windows, sz_size_t *const global_order, // + sz_size_t const start_in_collection, sz_size_t const end_in_collection, // sz_size_t const start_character) { // Partition the collection around some pivot or 2 pivots in a 3-way partitioning sz_size_t first_pivot_index, last_pivot_index; - _sz_sort_serial_3way_partition( // + _sz_qsort_serial_3way_partition( // global_windows, global_order, // start_in_collection, end_in_collection, // &first_pivot_index, &last_pivot_index); // Recursively sort the left partition if (start_in_collection < first_pivot_index) - _sz_sort_serial_recursively(collection, global_windows, global_order, start_in_collection, first_pivot_index, - start_character); + _sz_qsort_serial_recursively(collection, global_windows, global_order, start_in_collection, first_pivot_index, + start_character); // Recursively sort the right partition if (last_pivot_index + 1 < 
end_in_collection) - _sz_sort_serial_recursively(collection, global_windows, global_order, last_pivot_index + 1, end_in_collection, - start_character); + _sz_qsort_serial_recursively(collection, global_windows, global_order, last_pivot_index + 1, end_in_collection, + start_character); } -SZ_PUBLIC void _sz_sort_serial_next_window( // - sz_sequence_t const *const collection, // - _sz_sorting_window_t *const global_windows, sz_size_t *const global_order, // - sz_size_t const start_in_collection, sz_size_t const end_in_collection, // +SZ_PUBLIC void _sz_qsort_serial_next_window( // + sz_sequence_t const *const collection, // + _sz_sort_ngram_t *const global_windows, sz_size_t *const global_order, // + sz_size_t const start_in_collection, sz_size_t const end_in_collection, // sz_size_t const start_character) { // Prepare the new range of windows - _sz_sort_serial_export_prefixes(collection, global_windows, global_order, start_in_collection, end_in_collection, - start_character); + _sz_qsort_serial_export_prefixes(collection, global_windows, global_order, start_in_collection, end_in_collection, + start_character); // Sort current windows with a quicksort - _sz_sort_serial_recursively(collection, global_windows, global_order, start_in_collection, end_in_collection, - start_character); + _sz_qsort_serial_recursively(collection, global_windows, global_order, start_in_collection, end_in_collection, + start_character); // Depending on the architecture, we will export a different number of bytes. // On 32-bit architectures, we will export 3 bytes, and on 64-bit architectures - 7 bytes. - sz_size_t const window_capacity = sizeof(_sz_sorting_window_t) - 1; + sz_size_t const window_capacity = sizeof(_sz_sort_ngram_t) - 1; // Repeat the procedure for the identical windows sz_size_t nested_start = start_in_collection; sz_size_t nested_end = start_in_collection; while (nested_end != end_in_collection) { // Find the end of the identical windows - _sz_sorting_window_t current_window_integer = global_windows[nested_start]; + _sz_sort_ngram_t current_window_integer = global_windows[nested_start]; while (nested_end != end_in_collection && current_window_integer == global_windows[nested_end]) ++nested_end; // If the identical windows are not trivial and each string has more characters, sort them recursively sz_cptr_t current_window_str = (sz_cptr_t)¤t_window_integer; sz_size_t current_window_length = (sz_size_t)current_window_str[0]; //! 
The byte order was swapped if (nested_end - nested_start > 1 && current_window_length == window_capacity) { - _sz_sort_serial_next_window(collection, global_windows, global_order, nested_start, nested_end, - start_character + window_capacity); + _sz_qsort_serial_next_window(collection, global_windows, global_order, nested_start, nested_end, + start_character + window_capacity); } // Move to the next nested_start = nested_end; } } -SZ_PUBLIC void _sz_sort_serial_insertion(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order) { +SZ_PUBLIC void _sz_qsort_serial_insertion(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order) { // This algorithm needs no memory allocations: sz_unused(alloc); @@ -277,8 +287,8 @@ SZ_PUBLIC void _sz_sort_serial_insertion(sz_sequence_t const *collection, sz_mem } } -SZ_PUBLIC sz_bool_t sz_sort_serial(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order) { +SZ_PUBLIC sz_bool_t sz_qsort_serial(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order) { // First, initialize the `order` with `std::iota`-like behavior. for (sz_size_t i = 0; i != collection->count; ++i) order[i] = i; @@ -286,7 +296,7 @@ SZ_PUBLIC sz_bool_t sz_sort_serial(sz_sequence_t const *collection, sz_memory_al // On very small collections - just use the quadratic-complexity insertion sort // without any smart optimizations or memory allocations. if (collection->count <= 32) { - _sz_sort_serial_insertion(collection, alloc, order); + _sz_qsort_serial_insertion(collection, alloc, order); return sz_true_k; } @@ -301,15 +311,15 @@ SZ_PUBLIC sz_bool_t sz_sort_serial(sz_sequence_t const *collection, sz_memory_al // Assuming that some strings may contain or even end with NULL bytes, we need to make sure, that their length // is included in those P-long words. So, in reality, we will be taking (P-1) bytes from each string on every // iteration of a recursive algorithm. - _sz_sorting_window_t *windows = - (_sz_sorting_window_t *)alloc->allocate(collection->count * sizeof(_sz_sorting_window_t), alloc); + _sz_sort_ngram_t *windows = + (_sz_sort_ngram_t *)alloc->allocate(collection->count * sizeof(_sz_sort_ngram_t), alloc); if (!windows) return sz_false_k; // Recursively sort the whole collection. - _sz_sort_serial_next_window(collection, windows, order, 0, collection->count, 0); + _sz_qsort_serial_next_window(collection, windows, order, 0, collection->count, 0); // Free temporary storage. 
- alloc->free(windows, collection->count * sizeof(_sz_sorting_window_t), alloc); + alloc->free(windows, collection->count * sizeof(_sz_sort_ngram_t), alloc); return sz_true_k; } @@ -317,22 +327,22 @@ SZ_PUBLIC sz_bool_t sz_sort_serial(sz_sequence_t const *collection, sz_memory_al #pragma region Ice Lake Implementation -SZ_PUBLIC void _sz_sort_ice_recursively( // - sz_sequence_t const *const collection, // - _sz_sorting_window_t *const global_windows, sz_size_t *const global_order, // - sz_size_t const start_in_collection, sz_size_t const end_in_collection, // +SZ_PUBLIC void _sz_qsort_ice_recursively( // + sz_sequence_t const *const collection, // + _sz_sort_ngram_t *const global_windows, sz_size_t *const global_order, // + sz_size_t const start_in_collection, sz_size_t const end_in_collection, // sz_size_t const start_character) { // Prepare the new range of windows - _sz_sort_serial_export_prefixes(collection, global_windows, global_order, start_in_collection, end_in_collection, - start_character); + _sz_qsort_serial_export_prefixes(collection, global_windows, global_order, start_in_collection, end_in_collection, + start_character); // We can implement a form of a Radix sort here, that will count the number of elements with // a certain bit set. The naive approach may require too many loops over data. A more "vectorized" // approach would be to maintain a histogram for several bits at once. For 4 bits we will // need 2^4 = 16 counters. sz_size_t histogram[16] = {0}; - for (sz_size_t byte_in_window = 0; byte_in_window != sizeof(_sz_sorting_window_t); ++byte_in_window) { + for (sz_size_t byte_in_window = 0; byte_in_window != sizeof(_sz_sort_ngram_t); ++byte_in_window) { // First sort based on the low nibble of each byte. for (sz_size_t i = start_in_collection; i < end_in_collection; ++i) { sz_size_t const byte = (global_windows[i] >> (byte_in_window * 8)) & 0xFF; @@ -360,8 +370,8 @@ SZ_PUBLIC void _sz_sort_ice_recursively( / #pragma region Compile Time Dispatching #if !SZ_DYNAMIC_DISPATCH -SZ_DYNAMIC sz_bool_t sz_sort(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, sz_sorted_idx_t *order) { - return sz_sort_serial(collection, alloc, order); +SZ_DYNAMIC sz_bool_t sz_qsort(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, sz_sorted_idx_t *order) { + return sz_qsort_serial(collection, alloc, order); } #endif // !SZ_DYNAMIC_DISPATCH From dcf6c653931b19df6379ebc5476d1e302076c259 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 16 Feb 2025 22:59:16 +0000 Subject: [PATCH 104/751] Improve: Introduce typed `_sz_swap` macro --- include/stringzilla/similarity.h | 16 ++++++++-------- include/stringzilla/types.h | 10 ++++++++++ 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/include/stringzilla/similarity.h b/include/stringzilla/similarity.h index 5c521a40..188169ff 100644 --- a/include/stringzilla/similarity.h +++ b/include/stringzilla/similarity.h @@ -437,8 +437,8 @@ SZ_PUBLIC sz_size_t sz_edit_distance_serial( // // Let's make sure that we use the amount proportional to the // number of elements in the shorter string, not the larger. if (shorter_length > longer_length) { - sz_pointer_swap((void **)&longer_length, (void **)&shorter_length); - sz_pointer_swap((void **)&longer, (void **)&shorter); + _sz_swap(sz_size_t, longer_length, shorter_length); + _sz_swap(sz_cptr_t, longer, shorter); } // Skip the matching prefixes and suffixes, they won't affect the distance. 
@@ -478,8 +478,8 @@ SZ_PUBLIC sz_ssize_t sz_alignment_score_serial( // // Let's make sure that we use the amount proportional to the // number of elements in the shorter string, not the larger. if (shorter_length > longer_length) { - sz_pointer_swap((void **)&longer_length, (void **)&shorter_length); - sz_pointer_swap((void **)&longer, (void **)&shorter); + _sz_swap(sz_size_t, longer_length, shorter_length); + _sz_swap(sz_cptr_t, longer, shorter); } // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. @@ -513,7 +513,7 @@ SZ_PUBLIC sz_ssize_t sz_alignment_score_serial( // } // Swap previous_distances and current_distances pointers - sz_pointer_swap((void **)&previous_distances, (void **)¤t_distances); + _sz_swap(sz_ssize_t *, previous_distances, current_distances); } // Cache scalar before `free` call. @@ -1101,8 +1101,8 @@ SZ_INTERNAL sz_ssize_t _sz_alignment_score_wagner_fisher_upto17m_ice( // // Let's make sure that we use the amount proportional to the // number of elements in the shorter string, not the larger. if (shorter_length > longer_length) { - sz_pointer_swap((void **)&longer_length, (void **)&shorter_length); - sz_pointer_swap((void **)&longer, (void **)&shorter); + _sz_swap(sz_size_t, longer_length, shorter_length); + _sz_swap(sz_cptr_t, longer, shorter); } // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. @@ -1291,7 +1291,7 @@ SZ_INTERNAL sz_ssize_t _sz_alignment_score_wagner_fisher_upto17m_ice( // } // Swap previous_distances and current_distances pointers - sz_pointer_swap((void **)&previous_distances, (void **)¤t_distances); + _sz_swap(sz_i32_t *, previous_distances, current_distances); } // Cache scalar before `free` call. diff --git a/include/stringzilla/types.h b/include/stringzilla/types.h index d241b69f..6d8086f2 100644 --- a/include/stringzilla/types.h +++ b/include/stringzilla/types.h @@ -863,6 +863,16 @@ SZ_INTERNAL sz_u64_t sz_u64_blend(sz_u64_t a, sz_u64_t b, sz_u64_t mask) { retur */ #define _sz_order_scalars(a, b) ((sz_ordering_t)((a > b) - (a < b))) +/** + * Convenience macro to swap two values of the same type. + */ +#define _sz_swap(type, a, b) \ + do { \ + type _tmp = (a); \ + (a) = (b); \ + (b) = _tmp; \ + } while (0) + /** @brief Branchless minimum function for two signed 32-bit integers. */ SZ_INTERNAL sz_i32_t sz_i32_min_of_two(sz_i32_t x, sz_i32_t y) { return y + ((x - y) & (x - y) >> 31); } From 0c38bff7399f28494fa39593488f77105ab3646a Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 16 Feb 2025 23:01:56 +0000 Subject: [PATCH 105/751] Break: Pointer-sized N-gram Sorting This huge commit brings many new sorting APIs, as well as a new naming convention to differentiate inplace sorting helpers from "argsort" operations. Also refactors the testing and micro-benchmarking helpers. --- README.md | 12 +- c/lib.c | 8 +- include/stringzilla/sort.h | 582 +++++++++++++++++++++------- include/stringzilla/stringzilla.h | 2 +- include/stringzilla/stringzilla.hpp | 29 +- include/stringzilla/types.h | 25 +- python/lib.c | 2 +- scripts/bench.hpp | 98 ++--- scripts/bench_sort.cpp | 118 ++++-- scripts/test.cpp | 14 +- 10 files changed, 602 insertions(+), 288 deletions(-) diff --git a/README.md b/README.md index 52f80d41..c5253c4c 100644 --- a/README.md +++ b/README.md @@ -229,7 +229,7 @@ __Who is this for?__ arm: 13.00 s - sz_sort
+ sz_sequence_argsort
x86: 1.91 · arm: 2.37 s @@ -429,7 +429,7 @@ lines: Strs = text.split(separator='\n') # 4 bytes per line overhead for under 4 batch: Strs = lines.sample(seed=42) # 10x faster than `random.choices` lines.shuffle(seed=42) # or shuffle all lines in place and shard with slices # WIP: lines.sort() # explodes to 16 bytes per line overhead for any length text -# WIP: sorted_order: tuple = lines.argsort() # similar to `numpy.argsort` +# WIP: argsort: tuple = lines.argsort() # similar to `numpy.argsort` ``` Working on [RedPajama][redpajama], addressing 20 Billion annotated english documents, one will need only 160 GB of RAM instead of Terabytes. @@ -633,7 +633,7 @@ sz_u64_t hash = sz_hash(haystack.start, haystack.length); // Perform collection level operations sz_sequence_t array = {your_handle, your_count, your_get_start, your_get_length}; -sz_sort(&array, &your_config); +sz_sequence_argsort(&array, &your_config); ```
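For readers adapting their own containers to the C interface shown in the snippet above, here is a minimal sketch of the two callbacks a `sz_sequence_t` needs before calling `sz_sequence_argsort`. The adapter functions, the sample data, and the use of `sz_memory_allocator_init_default` for a default heap-backed allocator are illustrative assumptions, not part of the patch itself.

```c
#include <string.h> // `strlen`
#include <stringzilla/stringzilla.h>

// Hypothetical adapter callbacks exposing a plain array of C strings.
static sz_cptr_t demo_get_start(sz_sequence_t const *sequence, sz_size_t i) {
    return ((char const *const *)sequence->handle)[i];
}
static sz_size_t demo_get_length(sz_sequence_t const *sequence, sz_size_t i) {
    return strlen(((char const *const *)sequence->handle)[i]);
}

int main(void) {
    char const *words[] = {"banana", "apple", "cherry"};
    sz_sorted_idx_t order[3];

    sz_sequence_t sequence;
    sequence.handle = (void *)words;
    sequence.count = 3;
    sequence.get_start = demo_get_start;
    sequence.get_length = demo_get_length;

    sz_memory_allocator_t alloc;
    sz_memory_allocator_init_default(&alloc); // assumed default `malloc`-backed allocator
    sz_sequence_argsort(&sequence, &alloc, order); // `order` becomes {1, 0, 2}
    return 0;
}
```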
@@ -1129,14 +1129,14 @@ C++ generic algorithm is not perfect either. There is no guarantee in the standard that `std::sort` won't allocate any memory. If you are running on embedded, in real-time or on 100+ CPU cores per node, you may want to avoid that. StringZilla doesn't solve the general case, but hopes to improve the performance for strings. -Use `sz_sort`, or the high-level `sz::sorted_order`, which can be used sort any collection of elements convertible to `sz::string_view`. +Use `sz_sequence_argsort`, or the high-level `sz::argsort`, which can be used sort any collection of elements convertible to `sz::string_view`. ```cpp std::vector data({"c", "b", "a"}); -std::vector order = sz::sorted_order(data); //< Simple shortcut +std::vector order = sz::argsort(data); //< Simple shortcut // Or, taking care of memory allocation: -sz::sorted_order(data.begin(), data.end(), order.data(), [](auto const &x) -> sz::string_view { return x; }); +sz::argsort(data.begin(), data.end(), order.data(), [](auto const &x) -> sz::string_view { return x; }); ``` ### Standard C++ Containers with String Keys diff --git a/c/lib.c b/c/lib.c index 361cd049..64e7b61a 100644 --- a/c/lib.c +++ b/c/lib.c @@ -189,7 +189,7 @@ typedef struct sz_implementations_t { sz_edit_distance_t edit_distance; sz_alignment_score_t alignment_score; - sz_sort_t sort; + sz_sequence_argsort_t sequence_argsort; } sz_implementations_t; @@ -225,7 +225,7 @@ SZ_DYNAMIC void sz_dispatch_table_init(void) { impl->edit_distance = sz_edit_distance_serial; impl->alignment_score = sz_alignment_score_serial; - impl->sort = sz_sort_serial; + impl->sequence_argsort = sz_sequence_argsort_serial; #if SZ_USE_HASWELL if (caps & sz_cap_haswell_k) { @@ -416,8 +416,8 @@ SZ_DYNAMIC sz_ssize_t sz_alignment_score( // return sz_dispatch_table.alignment_score(a, a_length, b, b_length, subs, gap, alloc); } -SZ_DYNAMIC sz_bool_t sz_sort(sz_sequence_t const *array, sz_memory_allocator_t *alloc, sz_size_t *order) { - return sz_dispatch_table.sort(array, alloc, order); +SZ_DYNAMIC sz_bool_t sz_sequence_argsort(sz_sequence_t const *array, sz_memory_allocator_t *alloc, sz_size_t *order) { + return sz_dispatch_table.sequence_argsort(array, alloc, order); } SZ_DYNAMIC sz_cptr_t sz_find_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { diff --git a/include/stringzilla/sort.h b/include/stringzilla/sort.h index 9af46a20..9ea19e8d 100644 --- a/include/stringzilla/sort.h +++ b/include/stringzilla/sort.h @@ -1,17 +1,24 @@ /** - * @brief Hardware-accelerated string collection sorting and intersections. + * @brief Hardware-accelerated string collection sorting. * @file sort.h * @author Ash Vardanian * - * Includes core APIs: + * Includes core APIs for `sz_sequence_t` string collections: * - * - `sz_qsort` - to sort an arbitrary string collection with QuickSort-like algorithm. - * - `sz_qsort_addresses` - to sort a collection of continuous pointer-sized integers with QuickSort-like algorithm. - * - `sz_msort` - to sort an arbitrary string collection with a MergeSort-like algorithm. - * - `sz_msort_addresses` - to sort a collection of continuous pointer-sized integers with a MergeSort-like algorithm. + * - `sz_sequence_argsort` - to get the sorting permutation of a string collection with QuickSort. + * - `sz_sequence_argsort_stable` - to get the stable-sorting permutation of a string collection with a MergeSort. 
+ * + * The core idea of all following string algorithms is to sort strings not based on 1 character at a time, + * but on a larger "Pointer-sized N-grams" fitting in 4 or 8 bytes at once, on 32-bit or 64-bit architectures, + * respectively. In reality we may not use the full pointer size, but only a few bytes from it, and keep the rest + * for some metadata. + * + * That, however, means, that unsigned integer sorting is a constituent part of our string sorting and we can + * expose it as an additional set of APIs for the users: + * + * - `sz_pgrams_sort` - to inplace sort continuous pointer-sized integers with QuickSort. + * - `sz_pgrams_sort_stable` - to inplace stable-sort continuous pointer-sized integers with a MergeSort. * - * The `qsort` variants are not guaranteed to be stable. - * The `msort` variants are guaranteed to be stable. */ #ifndef STRINGZILLA_SORT_H_ #define STRINGZILLA_SORT_H_ @@ -27,49 +34,131 @@ extern "C" { #pragma region Core API /** - * @brief Faster `std::sort` for an arbitrary string sequence. + * @brief Faster @b arg-sort for an arbitrary @b string sequence, using QuickSort. + * Outputs the ::order of elements in the immutable ::sequence, that would sort it. + * The algorithm doesn't guarantee stability, meaning that the relative order of equal elements + * may not be preserved. * - * @param collection The collection of strings to sort. + * @param sequence The sequence of strings to sort. * @param alloc Memory allocator for temporary storage. - * @param order The output - indices of the sorted collection elements. + * @param order The output - indices of the sorted sequence elements. * @return Whether the operation was successful. */ -SZ_DYNAMIC sz_bool_t sz_qsort(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, sz_sorted_idx_t *order); +SZ_DYNAMIC sz_bool_t sz_sequence_argsort(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); -/** @copydoc sz_qsort */ -SZ_PUBLIC sz_bool_t sz_qsort_serial(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, +/** + * @brief Faster @b inplace `std::sort` for a continuous @b unsigned-integer sequence, using QuickSort. + * Overwrites the input ::sequence with the sorted sequence and exports the permutation ::order. + * The algorithm doesn't guarantee stability, meaning that the relative order of equal elements + * may not be preserved. + * + * @param pgrams The continuous buffer of unsigned integers to sort in place. + * @param count The number of elements in the sequence. + * @param alloc Memory allocator for temporary storage. + * @param order The output - indices of the sorted sequence elements. + * @return Whether the operation was successful. + */ +SZ_DYNAMIC sz_bool_t sz_pgrams_sort(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, sz_sorted_idx_t *order); -/** @copydoc sz_qsort */ -SZ_PUBLIC sz_bool_t sz_qsort_skylake(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order); +/** + * @brief Faster @b arg-sort for an arbitrary @b string sequence, using MergeSort. + * Outputs the ::order of elements in the immutable ::sequence, that would sort it. + * The algorithm guarantees stability, meaning that the relative order of equal elements is preserved. + * + * This algorithm uses more memory than `sz_sequence_argsort`, but it's performance is more predictable. + * It's also preferred for very large inputs, as most memory access happens in a predictable sequential order. 
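To make the "pointer-sized N-gram" idea above concrete, here is a small standalone sketch for a 64-bit little-endian build, intended to mirror the packing performed by the serial export routine in this patch: up to 7 leading characters plus the exported length are folded into one word. The helper name is made up, and `__builtin_bswap64` stands in for the library's own byte-reversal routine.

```c
#include <stdint.h>
#include <string.h>

// Illustrative packer: a single unsigned comparison of the resulting words
// orders the string prefixes lexicographically, with the exported length
// acting as the least-significant tiebreaker among equal prefixes.
static uint64_t pack_pgram(char const *str, size_t length) {
    size_t const capacity = sizeof(uint64_t) - 1; // 7 character slots
    size_t const exported = length < capacity ? length : capacity;
    uint8_t bytes[sizeof(uint64_t)] = {0};
    memcpy(bytes, str, exported);        // characters in the low bytes
    bytes[capacity] = (uint8_t)exported; // exported length in the last byte
    uint64_t word;
    memcpy(&word, bytes, sizeof(word));
    // Byte-reverse so the first character lands in the most significant byte,
    // as the `sz_u64_bytes_reverse` step does in the serial kernel.
    return __builtin_bswap64(word);
}

// pack_pgram("apple", 5) < pack_pgram("apply", 5)  - the words differ in the 5th character
// pack_pgram("app", 3)   < pack_pgram("apple", 5)  - a proper prefix sorts first
```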
+ * + * @param sequence The sequence of strings to sort. + * @param alloc Memory allocator for temporary storage. + * @param order The output - indices of the sorted sequence elements. + * @return Whether the operation was successful. + */ +SZ_DYNAMIC sz_bool_t sz_sequence_argsort_stable(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); -/** @copydoc sz_qsort */ -SZ_PUBLIC sz_bool_t sz_qsort_sve(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, sz_sorted_idx_t *order); +/** + * @brief Faster @b inplace `std::stable_sort sort` for a continuous @b unsigned-integer sequence, using MergeSort. + * Overwrites the input ::sequence with the sorted sequence and exports the permutation ::order. + * The algorithm guarantees stability, meaning that the relative order of equal elements is preserved. + * + * This algorithm uses more memory than `sz_pgrams_sort`, but it's performance is more predictable. + * It's also preferred for very large inputs, as most memory access happens in a predictable sequential order. + * + * @param pgrams The continuous buffer of unsigned integers to sort in place. + * @param count The number of elements in the sequence. + * @param alloc Memory allocator for temporary storage. + * @param order The output - indices of the sorted sequence elements. + * @return Whether the operation was successful. + */ +SZ_DYNAMIC sz_bool_t sz_pgrams_sort_stable(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); -#pragma endregion +/** @copydoc sz_sequence_argsort */ +SZ_PUBLIC sz_bool_t sz_sequence_argsort_serial(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); -#pragma region Serial Implementation +/** @copydoc sz_pgrams_sort */ +SZ_PUBLIC sz_bool_t sz_pgrams_sort_serial(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); -/** - * The core idea of all following string algorithms is to sort strings not based on 1 character at a time, - * but on a larger integer word fitting in 4 or 8 bytes at once, on 32-bit or 64-bit architectures, respectively. - * That word is pointer-sized, but it may contain extra information aside from N characters. 
- */ -typedef sz_size_t _sz_sort_ngram_t; +/** @copydoc sz_sequence_argsort */ +SZ_PUBLIC sz_bool_t sz_sequence_argsort_ice(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); -SZ_PUBLIC void _sz_qsort_serial_export_prefixes( // - sz_sequence_t const *const collection, // - _sz_sort_ngram_t *const global_windows, sz_sorted_idx_t const *const global_order, // - sz_size_t const start_in_collection, sz_size_t const end_in_collection, // +/** @copydoc sz_pgrams_sort */ +SZ_PUBLIC sz_bool_t sz_pgrams_sort_ice(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); + +/** @copydoc sz_sequence_argsort */ +SZ_PUBLIC sz_bool_t sz_sequence_argsort_sve(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); + +/** @copydoc sz_pgrams_sort */ +SZ_PUBLIC sz_bool_t sz_pgrams_sort_sve(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); + +/** @copydoc sz_sequence_argsort_stable */ +SZ_PUBLIC sz_bool_t sz_sequence_argsort_stable_serial(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); + +/** @copydoc sz_pgrams_sort_stable */ +SZ_PUBLIC sz_bool_t sz_pgrams_sort_stable_serial(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); + +/** @copydoc sz_sequence_argsort_stable */ +SZ_PUBLIC sz_bool_t sz_sequence_argsort_stable_ice(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); + +/** @copydoc sz_pgrams_sort_stable */ +SZ_PUBLIC sz_bool_t sz_pgrams_sort_stable_ice(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); + +/** @copydoc sz_sequence_argsort_stable */ +SZ_PUBLIC sz_bool_t sz_sequence_argsort_stable_sve(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); + +/** @copydoc sz_pgrams_sort_stable */ +SZ_PUBLIC sz_bool_t sz_pgrams_sort_stable_sve(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); + +#pragma endregion + +#pragma region Serial QuickSort Implementation + +SZ_PUBLIC void _sz_sequence_argsort_serial_export_next_pgrams( // + sz_sequence_t const *const sequence, // + sz_pgram_t *const global_pgrams, sz_sorted_idx_t const *const global_order, // + sz_size_t const start_in_sequence, sz_size_t const end_in_sequence, // sz_size_t const start_character) { // Depending on the architecture, we will export a different number of bytes. // On 32-bit architectures, we will export 3 bytes, and on 64-bit architectures - 7 bytes. - sz_size_t const window_capacity = sizeof(_sz_sort_ngram_t) - 1; + sz_size_t const window_capacity = sizeof(sz_pgram_t) - 1; // Perform the same operation for every string. - for (sz_size_t i = start_in_collection; i < end_in_collection; ++i) { + for (sz_size_t i = start_in_sequence; i < end_in_sequence; ++i) { // On the first recursion level, the `global_order` is the identity permutation. sz_sorted_idx_t const partial_order_index = global_order[i]; @@ -77,37 +166,37 @@ SZ_PUBLIC void _sz_qsort_serial_export_prefixes( _sz_assert(partial_order_index == i && "At start this must be an identity permutation."); // Get the string slice in global memory. 
- sz_cptr_t const source_str = collection->get_start(collection, partial_order_index); - sz_size_t const length = collection->get_length(collection, partial_order_index); + sz_cptr_t const source_str = sequence->get_start(sequence, partial_order_index); + sz_size_t const length = sequence->get_length(sequence, partial_order_index); sz_size_t const remaining_length = length > start_character ? length - start_character : 0; sz_size_t const exported_length = remaining_length > window_capacity ? window_capacity : remaining_length; // Fill with zeros, export a slice, and mark the exported length. - sz_size_t *target_integer = &global_windows[i]; - sz_ptr_t target_str = (sz_ptr_t)target_integer; - *target_integer = 0; + sz_pgram_t *target_pgram = &global_pgrams[i]; + sz_ptr_t target_str = (sz_ptr_t)target_pgram; + *target_pgram = 0; for (sz_size_t j = 0; j < exported_length; ++j) target_str[j] = source_str[j + start_character]; target_str[window_capacity] = exported_length; #if defined(_SZ_IS_64_BIT) - *target_integer = sz_u64_bytes_reverse(*target_integer); + *target_pgram = sz_u64_bytes_reverse(*target_pgram); #else - *target_integer = sz_u32_bytes_reverse(*target_integer); + *target_pgram = sz_u32_bytes_reverse(*target_pgram); #endif - _sz_assert( // - (length <= start_character) == (*target_integer == 0) && // + _sz_assert( // + (length <= start_character) == (*target_pgram == 0) && // "We can have a zero value if only the string is shorter than other strings at this position."); } // As our goal is to sort the strings using the exported integer "windows", // this is a good place to validate the correctness of the exported data. if (SZ_DEBUG && start_character == 0) - for (sz_size_t i = start_in_collection + 1; i < end_in_collection; ++i) { - _sz_sort_ngram_t const previous_window = global_windows[i - 1]; - _sz_sort_ngram_t const current_window = global_windows[i]; - sz_cptr_t const previous_str = collection->get_start(collection, i - 1); - sz_size_t const previous_length = collection->get_length(collection, i - 1); - sz_cptr_t const current_str = collection->get_start(collection, i); - sz_size_t const current_length = collection->get_length(collection, i); + for (sz_size_t i = start_in_sequence + 1; i < end_in_sequence; ++i) { + sz_pgram_t const previous_window = global_pgrams[i - 1]; + sz_pgram_t const current_window = global_pgrams[i]; + sz_cptr_t const previous_str = sequence->get_start(sequence, i - 1); + sz_size_t const previous_length = sequence->get_length(sequence, i - 1); + sz_cptr_t const current_str = sequence->get_start(sequence, i); + sz_size_t const current_length = sequence->get_length(sequence, i); sz_ordering_t const ordering = sz_order( // previous_str, previous_length > window_capacity ? window_capacity : previous_length, // current_str, current_length > window_capacity ? window_capacity : current_length); @@ -131,20 +220,20 @@ SZ_PUBLIC void _sz_qsort_serial_export_prefixes( * * @see https://en.wikipedia.org/wiki/Dutch_national_flag_problem */ -SZ_PUBLIC void _sz_qsort_serial_3way_partition( // - _sz_sort_ngram_t *const global_windows, sz_sorted_idx_t *const global_order, // - sz_size_t const start_in_collection, sz_size_t const end_in_collection, // +SZ_PUBLIC void _sz_sequence_argsort_serial_3way_partition( // + sz_pgram_t *const global_pgrams, sz_sorted_idx_t *const global_order, // + sz_size_t const start_in_sequence, sz_size_t const end_in_sequence, // sz_size_t *first_pivot_offset, sz_size_t *last_pivot_offset) { // Chose the pivot offset with Sedgewick's method. 
- _sz_sort_ngram_t pivot_window; + sz_pgram_t pivot_window; { - sz_size_t const middle_offset = start_in_collection + (end_in_collection - start_in_collection) / 2; - sz_size_t const last_offset = end_in_collection - 1; - sz_size_t const first_offset = start_in_collection; - _sz_sort_ngram_t const first_window = global_windows[first_offset]; - _sz_sort_ngram_t const middle_window = global_windows[middle_offset]; - _sz_sort_ngram_t const last_window = global_windows[last_offset]; + sz_size_t const middle_offset = start_in_sequence + (end_in_sequence - start_in_sequence) / 2; + sz_size_t const last_offset = end_in_sequence - 1; + sz_size_t const first_offset = start_in_sequence; + sz_pgram_t const first_window = global_pgrams[first_offset]; + sz_pgram_t const middle_window = global_pgrams[middle_offset]; + sz_pgram_t const last_window = global_pgrams[last_offset]; if (first_window < middle_window) { if (middle_window < last_window) { pivot_window = middle_window; } else if (first_window < last_window) { pivot_window = last_window; } @@ -158,120 +247,116 @@ SZ_PUBLIC void _sz_qsort_serial_3way_partition( } // Loop through the collection and move the elements around the pivot with the 3-way partitioning. - sz_size_t partitioning_progress = start_in_collection; // Current index. - sz_size_t less_than_pivot_offset = start_in_collection; // Boundary for elements < pivot_window. - sz_size_t greater_than_pivot_offset = end_in_collection - 1; // Boundary for elements > pivot_window. + sz_size_t partitioning_progress = start_in_sequence; // Current index. + sz_size_t smaller_offset = start_in_sequence; // Boundary for elements < pivot_window. + sz_size_t greater_offset = end_in_sequence - 1; // Boundary for elements > pivot_window. - while (partitioning_progress <= greater_than_pivot_offset) { + while (partitioning_progress <= greater_offset) { // Element is less than pivot: swap into the < pivot region. - if (global_windows[partitioning_progress] < pivot_window) { -#if defined(_SZ_IS_64_BIT) - sz_u64_swap(&global_order[partitioning_progress], &global_order[less_than_pivot_offset]); - sz_u64_swap(&global_windows[partitioning_progress], &global_windows[less_than_pivot_offset]); -#else - sz_u32_swap(&global_order[partitioning_progress], &global_order[less_than_pivot_offset]); - sz_u32_swap(&global_windows[partitioning_progress], &global_windows[less_than_pivot_offset]); -#endif + if (global_pgrams[partitioning_progress] < pivot_window) { + _sz_swap(sz_sorted_idx_t, global_order[partitioning_progress], global_order[smaller_offset]); + _sz_swap(sz_pgram_t, global_pgrams[partitioning_progress], global_pgrams[smaller_offset]); ++partitioning_progress; - ++less_than_pivot_offset; + ++smaller_offset; } // Element is greater than pivot: swap into the > pivot region. 
- else if (global_windows[partitioning_progress] > pivot_window) { -#if defined(_SZ_IS_64_BIT) - sz_u64_swap(&global_order[partitioning_progress], &global_order[greater_than_pivot_offset]); - sz_u64_swap(&global_windows[partitioning_progress], &global_windows[greater_than_pivot_offset]); -#else - sz_u32_swap(&global_order[partitioning_progress], &global_order[greater_than_pivot_offset]); - sz_u32_swap(&global_windows[partitioning_progress], &global_windows[greater_than_pivot_offset]); -#endif - --greater_than_pivot_offset; + else if (global_pgrams[partitioning_progress] > pivot_window) { + _sz_swap(sz_sorted_idx_t, global_order[partitioning_progress], global_order[greater_offset]); + _sz_swap(sz_pgram_t, global_pgrams[partitioning_progress], global_pgrams[greater_offset]); + --greater_offset; } // Element equals pivot_window: leave it in place. else { ++partitioning_progress; } } - *first_pivot_offset = less_than_pivot_offset; - *last_pivot_offset = greater_than_pivot_offset; + *first_pivot_offset = smaller_offset; + *last_pivot_offset = greater_offset; } -SZ_PUBLIC void _sz_qsort_serial_recursively( // - sz_sequence_t const *const collection, // - _sz_sort_ngram_t *const global_windows, sz_size_t *const global_order, // - sz_size_t const start_in_collection, sz_size_t const end_in_collection, // - sz_size_t const start_character) { +/** + * @brief Recursive Quick-Sort implementation backing both the `sz_sequence_argsort` and `sz_pgrams_sort`, + * and using the `_sz_sequence_argsort_serial_3way_partition` under the hood. + */ +SZ_PUBLIC void _sz_sequence_argsort_serial_recursively( // + sz_pgram_t *const global_pgrams, sz_sorted_idx_t *const global_order, // + sz_size_t const start_in_sequence, sz_size_t const end_in_sequence) { // Partition the collection around some pivot or 2 pivots in a 3-way partitioning sz_size_t first_pivot_index, last_pivot_index; - _sz_qsort_serial_3way_partition( // - global_windows, global_order, // - start_in_collection, end_in_collection, // + _sz_sequence_argsort_serial_3way_partition( // + global_pgrams, global_order, // + start_in_sequence, end_in_sequence, // &first_pivot_index, &last_pivot_index); // Recursively sort the left partition - if (start_in_collection < first_pivot_index) - _sz_qsort_serial_recursively(collection, global_windows, global_order, start_in_collection, first_pivot_index, - start_character); + if (start_in_sequence < first_pivot_index) + _sz_sequence_argsort_serial_recursively(global_pgrams, global_order, start_in_sequence, first_pivot_index); // Recursively sort the right partition - if (last_pivot_index + 1 < end_in_collection) - _sz_qsort_serial_recursively(collection, global_windows, global_order, last_pivot_index + 1, end_in_collection, - start_character); + if (last_pivot_index + 1 < end_in_sequence) + _sz_sequence_argsort_serial_recursively(global_pgrams, global_order, last_pivot_index + 1, end_in_sequence); } -SZ_PUBLIC void _sz_qsort_serial_next_window( // - sz_sequence_t const *const collection, // - _sz_sort_ngram_t *const global_windows, sz_size_t *const global_order, // - sz_size_t const start_in_collection, sz_size_t const end_in_collection, // +/** + * @brief Recursive Quick-Sort adaptation for strings, that processes the strings a few N-grams at a time. + * It combines `_sz_sequence_argsort_serial_export_next_pgrams` and `_sz_sequence_argsort_serial_recursively`, + * recursively diving into the identical windows. 
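As a side note on the partitioning step above: the 3-way split is the classic Dutch-national-flag scheme. Below is a compact standalone sketch of it over plain 64-bit keys, with illustrative names and signed indices to keep the boundary arithmetic simple; the permutation array that the library co-moves alongside the keys is omitted here.

```c
#include <stddef.h>
#include <stdint.h>

// Illustrative 3-way partition: on return, [0, *first_equal) holds keys < pivot,
// [*first_equal, *last_equal] holds keys == pivot, (*last_equal, count) holds keys > pivot.
static void three_way_partition(uint64_t *keys, ptrdiff_t count, uint64_t pivot,
                                ptrdiff_t *first_equal, ptrdiff_t *last_equal) {
    ptrdiff_t progress = 0, smaller = 0, greater = count - 1;
    while (progress <= greater) {
        if (keys[progress] < pivot) {
            uint64_t tmp = keys[progress];
            keys[progress] = keys[smaller];
            keys[smaller] = tmp;
            ++progress, ++smaller;
        }
        else if (keys[progress] > pivot) {
            uint64_t tmp = keys[progress];
            keys[progress] = keys[greater];
            keys[greater] = tmp;
            --greater; // the swapped-in key has not been examined yet
        }
        else { ++progress; } // equal to the pivot: leave in place
    }
    *first_equal = smaller;
    *last_equal = greater;
}
```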
+ */ +SZ_PUBLIC void _sz_sequence_argsort_serial_next_pgrams( // + sz_sequence_t const *const sequence, // + sz_pgram_t *const global_pgrams, sz_sorted_idx_t *const global_order, // + sz_size_t const start_in_sequence, sz_size_t const end_in_sequence, // sz_size_t const start_character) { // Prepare the new range of windows - _sz_qsort_serial_export_prefixes(collection, global_windows, global_order, start_in_collection, end_in_collection, - start_character); + _sz_sequence_argsort_serial_export_next_pgrams(sequence, global_pgrams, global_order, start_in_sequence, + end_in_sequence, start_character); // Sort current windows with a quicksort - _sz_qsort_serial_recursively(collection, global_windows, global_order, start_in_collection, end_in_collection, - start_character); + _sz_sequence_argsort_serial_recursively(global_pgrams, global_order, start_in_sequence, end_in_sequence); // Depending on the architecture, we will export a different number of bytes. // On 32-bit architectures, we will export 3 bytes, and on 64-bit architectures - 7 bytes. - sz_size_t const window_capacity = sizeof(_sz_sort_ngram_t) - 1; + sz_size_t const window_capacity = sizeof(sz_pgram_t) - 1; // Repeat the procedure for the identical windows - sz_size_t nested_start = start_in_collection; - sz_size_t nested_end = start_in_collection; - while (nested_end != end_in_collection) { + sz_size_t nested_start = start_in_sequence; + sz_size_t nested_end = start_in_sequence; + while (nested_end != end_in_sequence) { // Find the end of the identical windows - _sz_sort_ngram_t current_window_integer = global_windows[nested_start]; - while (nested_end != end_in_collection && current_window_integer == global_windows[nested_end]) ++nested_end; + sz_pgram_t current_window_integer = global_pgrams[nested_start]; + while (nested_end != end_in_sequence && current_window_integer == global_pgrams[nested_end]) ++nested_end; // If the identical windows are not trivial and each string has more characters, sort them recursively sz_cptr_t current_window_str = (sz_cptr_t)¤t_window_integer; sz_size_t current_window_length = (sz_size_t)current_window_str[0]; //! The byte order was swapped - if (nested_end - nested_start > 1 && current_window_length == window_capacity) { - _sz_qsort_serial_next_window(collection, global_windows, global_order, nested_start, nested_end, - start_character + window_capacity); + int has_multiple_strings = nested_end - nested_start > 1; + int has_more_characters_in_each = current_window_length == window_capacity; + if (has_multiple_strings && has_more_characters_in_each) { + _sz_sequence_argsort_serial_next_pgrams(sequence, global_pgrams, global_order, nested_start, nested_end, + start_character + window_capacity); } // Move to the next nested_start = nested_end; } } -SZ_PUBLIC void _sz_qsort_serial_insertion(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order) { +SZ_PUBLIC void _sz_sequence_argsort_serial_insertion(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order) { // This algorithm needs no memory allocations: sz_unused(alloc); // Assume `order` is already initialized with 0, 1, 2, ... N. - for (sz_size_t i = 1; i < collection->count; ++i) { + for (sz_size_t i = 1; i < sequence->count; ++i) { sz_sorted_idx_t current_idx = order[i]; sz_size_t j = i; while (j > 0) { // Get the two strings to compare. 
sz_sorted_idx_t previous_idx = order[j - 1]; - sz_cptr_t previous_start = collection->get_start(collection, previous_idx); - sz_cptr_t current_start = collection->get_start(collection, current_idx); - sz_size_t previous_length = collection->get_length(collection, previous_idx); - sz_size_t current_length = collection->get_length(collection, current_idx); + sz_cptr_t previous_start = sequence->get_start(sequence, previous_idx); + sz_cptr_t current_start = sequence->get_start(sequence, current_idx); + sz_size_t previous_length = sequence->get_length(sequence, previous_idx); + sz_size_t current_length = sequence->get_length(sequence, current_idx); // Use the provided sz_order to compare. sz_ordering_t ordering = sz_order(previous_start, previous_length, current_start, current_length); @@ -287,16 +372,16 @@ SZ_PUBLIC void _sz_qsort_serial_insertion(sz_sequence_t const *collection, sz_me } } -SZ_PUBLIC sz_bool_t sz_qsort_serial(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order) { +SZ_PUBLIC sz_bool_t sz_sequence_argsort_serial(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order) { // First, initialize the `order` with `std::iota`-like behavior. - for (sz_size_t i = 0; i != collection->count; ++i) order[i] = i; + for (sz_size_t i = 0; i != sequence->count; ++i) order[i] = i; // On very small collections - just use the quadratic-complexity insertion sort // without any smart optimizations or memory allocations. - if (collection->count <= 32) { - _sz_qsort_serial_insertion(collection, alloc, order); + if (sequence->count <= 32) { + _sz_sequence_argsort_serial_insertion(sequence, alloc, order); return sz_true_k; } @@ -311,51 +396,255 @@ SZ_PUBLIC sz_bool_t sz_qsort_serial(sz_sequence_t const *collection, sz_memory_a // Assuming that some strings may contain or even end with NULL bytes, we need to make sure, that their length // is included in those P-long words. So, in reality, we will be taking (P-1) bytes from each string on every // iteration of a recursive algorithm. - _sz_sort_ngram_t *windows = - (_sz_sort_ngram_t *)alloc->allocate(collection->count * sizeof(_sz_sort_ngram_t), alloc); + sz_size_t memory_usage = sequence->count * sizeof(sz_pgram_t); + sz_pgram_t *windows = (sz_pgram_t *)alloc->allocate(memory_usage, alloc); if (!windows) return sz_false_k; - // Recursively sort the whole collection. - _sz_qsort_serial_next_window(collection, windows, order, 0, collection->count, 0); + // Recursively sort the whole sequence. + _sz_sequence_argsort_serial_next_pgrams(sequence, windows, order, 0, sequence->count, 0); // Free temporary storage. - alloc->free(windows, collection->count * sizeof(_sz_sort_ngram_t), alloc); + alloc->free(windows, memory_usage, alloc); + return sz_true_k; +} + +SZ_PUBLIC sz_bool_t sz_pgrams_sort_serial(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order) { + sz_unused(alloc); + // First, initialize the `order` with `std::iota`-like behavior. + for (sz_size_t i = 0; i != count; ++i) order[i] = i; + // Reuse the string sorting algorithm for sorting the "pgrams". + _sz_sequence_argsort_serial_recursively((sz_pgram_t *)pgrams, order, 0, count); + return sz_true_k; +} + +#pragma endregion // Serial QuickSort Implementation + +#pragma region Serial MergeSort Implementation + +/** + * @brief A scalar sorting network for 8 elements that reorders both the keys + * and their corresponding offsets in only 19 comparisons, the most efficient + * variant currently known. 
+ * @see https://en.wikipedia.org/wiki/Sorting_network + * + * The network consists of 6 stages with the following compare–swap pairs: + * + * Stage 1: (0,1), (2,3), (4,5), (6,7) + * Stage 2: (0,2), (1,3), (4,6), (5,7) + * Stage 3: (1,2), (5,6) + * Stage 4: (0,4), (1,5), (2,6), (3,7) + * Stage 5: (2,4), (3,5) + * Stage 6: (1,2), (3,4), (5,6) + */ +void _sz_sequence_argsort_stable_serial_8x_network(sz_pgram_t *keys, sz_sorted_idx_t *offsets) { + +#define _sz_sequence_argsort_stable_8x_conditional_swap(i, j) \ + do { \ + if (keys[i] > keys[j]) { \ + _sz_swap(sz_pgram_t, keys[i], keys[j]); \ + _sz_swap(sz_sorted_idx_t, offsets[i], offsets[j]); \ + } \ + } while (0) + + // Stage 1: Compare–swap adjacent pairs. + _sz_sequence_argsort_stable_8x_conditional_swap(0, 1); + _sz_sequence_argsort_stable_8x_conditional_swap(2, 3); + _sz_sequence_argsort_stable_8x_conditional_swap(4, 5); + _sz_sequence_argsort_stable_8x_conditional_swap(6, 7); + + // Stage 2: Compare–swap with stride 2. + _sz_sequence_argsort_stable_8x_conditional_swap(0, 2); + _sz_sequence_argsort_stable_8x_conditional_swap(1, 3); + _sz_sequence_argsort_stable_8x_conditional_swap(4, 6); + _sz_sequence_argsort_stable_8x_conditional_swap(5, 7); + + // Stage 3: Compare–swap between middle elements. + _sz_sequence_argsort_stable_8x_conditional_swap(1, 2); + _sz_sequence_argsort_stable_8x_conditional_swap(5, 6); + + // Stage 4: Compare–swap across the two halves. + _sz_sequence_argsort_stable_8x_conditional_swap(0, 4); + _sz_sequence_argsort_stable_8x_conditional_swap(1, 5); + _sz_sequence_argsort_stable_8x_conditional_swap(2, 6); + _sz_sequence_argsort_stable_8x_conditional_swap(3, 7); + + // Stage 5: Compare–swap within each half. + _sz_sequence_argsort_stable_8x_conditional_swap(2, 4); + _sz_sequence_argsort_stable_8x_conditional_swap(3, 5); + + // Stage 6: Final compare–swap of adjacent elements. + _sz_sequence_argsort_stable_8x_conditional_swap(1, 2); + _sz_sequence_argsort_stable_8x_conditional_swap(3, 4); + _sz_sequence_argsort_stable_8x_conditional_swap(5, 6); + +#undef _sz_sequence_argsort_stable_8x_conditional_swap + + // Validate the sorting network. + if (SZ_DEBUG) + for (sz_size_t i = 1; i < 8; ++i) + _sz_assert(keys[i - 1] <= keys[i] && "The sorting network must sort the keys in ascending order."); +} + +/** + * @brief Helper function similar to `std::set_union` over pairs of integers and their original indices. 
+ * @see https://en.cppreference.com/w/cpp/algorithm/set_union + */ +void _sz_sequence_argsort_stable_serial_merge( // + sz_pgram_t const *first_pgrams, sz_sorted_idx_t const *first_indices, sz_size_t first_count, // + sz_pgram_t const *second_pgrams, sz_sorted_idx_t const *second_indices, sz_size_t second_count, // + sz_pgram_t *result_pgrams, sz_sorted_idx_t *result_indices) { + + // Compute the end pointers for each input array + sz_pgram_t const *const first_end = first_pgrams + first_count; + sz_pgram_t const *const second_end = second_pgrams + second_count; + + // Merge until one array is exhausted + while (first_pgrams < first_end && second_pgrams < second_end) { + if (*first_pgrams < *second_pgrams) { + *result_pgrams++ = *first_pgrams++; + *result_indices++ = *first_indices++; + } + else if (*second_pgrams < *first_pgrams) { + *result_pgrams++ = *second_pgrams++; + *result_indices++ = *second_indices++; + } + else { + // Equal keys: for stability, choose the one from the first array + *result_pgrams++ = *first_pgrams; + *result_indices++ = *first_indices; + ++first_pgrams; + ++first_indices; + ++second_pgrams; + ++second_indices; + } + } + + // Copy any remaining elements from the first array + while (first_pgrams < first_end) { + *result_pgrams++ = *first_pgrams++; + *result_indices++ = *first_indices++; + } + + // Copy any remaining elements from the second array + while (second_pgrams < second_end) { + *result_pgrams++ = *second_pgrams++; + *result_indices++ = *second_indices++; + } +} + +SZ_PUBLIC sz_bool_t sz_pgrams_sort_stable_serial(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order) { + + // First, initialize the `order` with `std::iota`-like behavior. + for (sz_size_t i = 0; i != count; ++i) order[i] = i; + + // Go through short chunks of 8 elements and sort them with a sorting network. + for (sz_size_t i = 0; i + 8 <= count; i += 8) _sz_sequence_argsort_stable_serial_8x_network(pgrams + i, order + i); + + // For the tail of the array, sort it with insertion sort. + for (sz_size_t i = count & ~7; i < count; i++) { + sz_pgram_t current_address = pgrams[i]; + sz_sorted_idx_t current_idx = order[i]; + sz_size_t j = i; + while (j > 0 && pgrams[j - 1] > current_address) { + pgrams[j] = pgrams[j - 1]; + order[j] = order[j - 1]; + --j; + } + pgrams[j] = current_address; + order[j] = current_idx; + } + + // At this point, the array is partitioned into sorted runs. + // We'll now merge these runs until the whole array is sorted. + // Allocate temporary memory to hold merged results: + // - one block for keys (`sz_pgram_t`) + // - one block for indices (`sz_sorted_idx_t`) + sz_size_t memory_usage = sizeof(sz_pgram_t) * count + sizeof(sz_sorted_idx_t) * count; + sz_pgram_t *pgrams_temporary = (sz_pgram_t *)alloc->allocate(memory_usage, alloc); + sz_sorted_idx_t *order_temporary = (sz_sorted_idx_t *)(pgrams_temporary + count); + if (!pgrams_temporary) return sz_false_k; + + // Set initial run size (the sorted chunks). + sz_size_t run_size = 8; + + // Pointers for current source and destination arrays. + sz_pgram_t *src_pgrams = pgrams; + sz_sorted_idx_t *src_order = order; + sz_pgram_t *dst_pgrams = pgrams_temporary; + sz_sorted_idx_t *dst_order = order_temporary; + + // Merge sorted runs in a bottom-up manner until the run size covers the whole array. + while (run_size < count) { + // Process adjacent runs. + for (sz_size_t i = 0; i < count; i += run_size * 2) { + // Determine the number of elements in the left run. 
+ sz_size_t left_count = run_size; + if (i + left_count > count) { left_count = count - i; } + + // Determine the number of elements in the right run. + sz_size_t right_count = run_size; + if (i + run_size >= count) { right_count = 0; } + else if (i + run_size + right_count > count) { right_count = count - (i + run_size); } + + // Merge the two runs: + _sz_sequence_argsort_stable_serial_merge( // + src_pgrams + i, src_order + i, left_count, // + src_pgrams + i + run_size, src_order + i + run_size, right_count, // + dst_pgrams + i, dst_order + i); + } + + // Swap the roles of the source and destination arrays. + _sz_swap(sz_pgram_t *, src_pgrams, dst_pgrams); + _sz_swap(sz_sorted_idx_t *, src_order, dst_order); + + // Double the run size for the next pass. + run_size *= 2; + } + + // If the final sorted result is not in the original array, copy the sorted results back. + if (src_pgrams != pgrams) + for (sz_size_t i = 0; i < count; ++i) pgrams[i] = src_pgrams[i], order[i] = src_order[i]; + + // Free the temporary memory used for merging. + alloc->free(pgrams_temporary, memory_usage, alloc); return sz_true_k; } -#pragma endregion // Serial Implementation +#pragma endregion // Serial MergeSort Implementation #pragma region Ice Lake Implementation -SZ_PUBLIC void _sz_qsort_ice_recursively( // - sz_sequence_t const *const collection, // - _sz_sort_ngram_t *const global_windows, sz_size_t *const global_order, // - sz_size_t const start_in_collection, sz_size_t const end_in_collection, // +SZ_PUBLIC void _sz_sequence_argsort_ice_recursively( // + sz_sequence_t const *const collection, // + sz_pgram_t *const global_pgrams, sz_size_t *const global_order, // + sz_size_t const start_in_sequence, sz_size_t const end_in_sequence, // sz_size_t const start_character) { // Prepare the new range of windows - _sz_qsort_serial_export_prefixes(collection, global_windows, global_order, start_in_collection, end_in_collection, - start_character); + _sz_sequence_argsort_serial_export_next_pgrams(collection, global_pgrams, global_order, start_in_sequence, + end_in_sequence, start_character); // We can implement a form of a Radix sort here, that will count the number of elements with // a certain bit set. The naive approach may require too many loops over data. A more "vectorized" // approach would be to maintain a histogram for several bits at once. For 4 bits we will // need 2^4 = 16 counters. sz_size_t histogram[16] = {0}; - for (sz_size_t byte_in_window = 0; byte_in_window != sizeof(_sz_sort_ngram_t); ++byte_in_window) { + for (sz_size_t byte_in_window = 0; byte_in_window != sizeof(sz_pgram_t); ++byte_in_window) { // First sort based on the low nibble of each byte. 
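The Ice Lake draft above builds per-digit histograms for a radix-style pass. As a reference point, here is a minimal serial counting-sort pass over a single 8-bit digit, which needs 256 counters (a 4-bit nibble would need 16, as in the histogram sketched above); the names and the scratch-buffer contract are illustrative. Repeating the pass for every byte, least significant first, yields a fully sorted array, and an arg-sorting variant would carry an index array through the same permutation, much like the `global_order` scatter in the function above.

```c
#include <stddef.h>
#include <stdint.h>

// One stable counting-sort pass: reorder `count` keys by the byte at `byte_index`,
// writing the result into `scratch`.
static void counting_sort_pass(uint64_t const *keys, uint64_t *scratch, size_t count, size_t byte_index) {
    size_t histogram[256] = {0};
    for (size_t i = 0; i < count; ++i) ++histogram[(keys[i] >> (byte_index * 8)) & 0xFF];
    size_t offset = 0;
    for (size_t digit = 0; digit < 256; ++digit) {
        size_t const bucket_size = histogram[digit];
        histogram[digit] = offset; // exclusive prefix sum: start of each bucket
        offset += bucket_size;
    }
    for (size_t i = 0; i < count; ++i) {
        size_t const digit = (keys[i] >> (byte_index * 8)) & 0xFF;
        scratch[histogram[digit]++] = keys[i];
    }
}
```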
- for (sz_size_t i = start_in_collection; i < end_in_collection; ++i) { - sz_size_t const byte = (global_windows[i] >> (byte_in_window * 8)) & 0xFF; + for (sz_size_t i = start_in_sequence; i < end_in_sequence; ++i) { + sz_size_t const byte = (global_pgrams[i] >> (byte_in_window * 8)) & 0xFF; ++histogram[byte]; } - sz_size_t offset = start_in_collection; + sz_size_t offset = start_in_sequence; for (sz_size_t i = 0; i != 16; ++i) { sz_size_t const count = histogram[i]; histogram[i] = offset; offset += count; } - for (sz_size_t i = start_in_collection; i < end_in_collection; ++i) { - sz_size_t const byte = (global_windows[i] >> (byte_in_window * 8)) & 0xFF; + for (sz_size_t i = start_in_sequence; i < end_in_sequence; ++i) { + sz_size_t const byte = (global_pgrams[i] >> (byte_in_window * 8)) & 0xFF; global_order[histogram[byte]] = i; ++histogram[byte]; } @@ -370,8 +659,9 @@ SZ_PUBLIC void _sz_qsort_ice_recursively( // #pragma region Compile Time Dispatching #if !SZ_DYNAMIC_DISPATCH -SZ_DYNAMIC sz_bool_t sz_qsort(sz_sequence_t const *collection, sz_memory_allocator_t *alloc, sz_sorted_idx_t *order) { - return sz_qsort_serial(collection, alloc, order); +SZ_DYNAMIC sz_bool_t sz_sequence_argsort(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order) { + return sz_sequence_argsort_serial(sequence, alloc, order); } #endif // !SZ_DYNAMIC_DISPATCH diff --git a/include/stringzilla/stringzilla.h b/include/stringzilla/stringzilla.h index ba405700..0b23b33b 100644 --- a/include/stringzilla/stringzilla.h +++ b/include/stringzilla/stringzilla.h @@ -47,7 +47,7 @@ #include "memory.h" // `sz_copy`, `sz_move`, `sz_fill` #include "similarity.h" // `sz_edit_distance`, `sz_alignment_score` #include "small_string.h" // `sz_string_t`, `sz_string_init`, `sz_string_free` -#include "sort.h" // `sz_sort`, `sz_sort_partial`, `sz_partition` +#include "sort.h" // `sz_sequence_argsort`, `sz_pgrams_sort`, `sz_pgrams_sort_stable` #include "types.h" // `sz_size_t`, `sz_bool_t`, `sz_ordering_t` #ifdef __cplusplus diff --git a/include/stringzilla/stringzilla.hpp b/include/stringzilla/stringzilla.hpp index 664ce607..0a4737ad 100644 --- a/include/stringzilla/stringzilla.hpp +++ b/include/stringzilla/stringzilla.hpp @@ -3974,8 +3974,8 @@ void randomize(basic_string_slice string, string_view alphabet = "ab using sorted_idx_t = sz_sorted_idx_t; /** - * @brief Internal data-structure used to forward the arguments to the `sz_sort` function. - * @see sorted_order + * @brief Internal data-structure used to forward the arguments to the `sz_sequence_argsort` function. + * @see argsort */ template struct _sequence_args { @@ -4004,18 +4004,18 @@ sz_size_t _call_sequence_member_length(struct sz_sequence_t const *sequence, sz_ /** * @brief Computes the permutation of an array, that would lead to sorted order. * The elements of the array must be convertible to a `string_view` with the given extractor. - * Unlike the `sz_sort` C interface, overwrites the output array. + * Unlike the `sz_sequence_argsort` C interface, overwrites the output array. * * @param[in] begin The pointer to the first element of the array. * @param[in] end The pointer to the element after the last element of the array. * @param[out] order The pointer to the output array of indices, that will be populated with the permutation. * @param[in] extractor The function object that extracts the string from the object. 
* - * @see sz_sort + * @see sz_sequence_argsort */ template -void sorted_order(objects_type_ const *begin, objects_type_ const *end, sorted_idx_t *order, - string_extractor_ &&extractor) noexcept { +void argsort(objects_type_ const *begin, objects_type_ const *end, sorted_idx_t *order, + string_extractor_ &&extractor) noexcept { // Pack the arguments into a single structure to reference it from the callback. _sequence_args args = {begin, static_cast(end - begin), order, @@ -4030,7 +4030,8 @@ void sorted_order(objects_type_ const *begin, objects_type_ const *end, sorted_i array.get_length = _call_sequence_member_length; using sz_alloc_type = sz_memory_allocator_t; - _with_alloc>([&](sz_alloc_type &alloc) { return sz_sort(&array, &alloc, order); }); + _with_alloc>( + [&](sz_alloc_type &alloc) { return sz_sequence_argsort(&array, &alloc, order); }); } #if !SZ_AVOID_STL @@ -4075,10 +4076,10 @@ std::bitset hashes_fingerprint(basic_string const &str * @throw `std::bad_alloc` if the allocation fails. */ template -std::vector sorted_order( // +std::vector argsort( // objects_type_ const *begin, objects_type_ const *end, string_extractor_ &&extractor) noexcept(false) { std::vector order(end - begin); - sorted_order(begin, end, order.data(), std::forward(extractor)); + argsort(begin, end, order.data(), std::forward(extractor)); return order; } @@ -4088,10 +4089,10 @@ std::vector sorted_order( // * @throw `std::bad_alloc` if the allocation fails. */ template -std::vector sorted_order(string_like_type_ const *begin, string_like_type_ const *end) noexcept(false) { +std::vector argsort(string_like_type_ const *begin, string_like_type_ const *end) noexcept(false) { static_assert( // std::is_convertible::value, "The type must be convertible to string_view."); - return sorted_order(begin, end, [](string_like_type_ const &s) -> string_view { return s; }); + return argsort(begin, end, [](string_like_type_ const &s) -> string_view { return s; }); } /** @@ -4100,11 +4101,11 @@ std::vector sorted_order(string_like_type_ const *begin, string_li * @throw `std::bad_alloc` if the allocation fails. 
*/ template -std::vector sorted_order(std::vector const &array) noexcept(false) { +std::vector argsort(std::vector const &array) noexcept(false) { static_assert( // std::is_convertible::value, "The type must be convertible to string_view."); - return sorted_order(array.data(), array.data() + array.size(), - [](string_like_type_ const &s) -> string_view { return s; }); + return argsort(array.data(), array.data() + array.size(), + [](string_like_type_ const &s) -> string_view { return s; }); } #endif diff --git a/include/stringzilla/types.h b/include/stringzilla/types.h index 6d8086f2..01d090b2 100644 --- a/include/stringzilla/types.h +++ b/include/stringzilla/types.h @@ -325,6 +325,7 @@ typedef sz_i8_t sz_error_cost_t; // Character mismatch cost for fuzzy matching f struct sz_sequence_t; // Forward declaration of an ordered collection of strings typedef sz_size_t sz_sorted_idx_t; // Index of a sorted string in a list of strings +typedef sz_size_t sz_pgram_t; // "Pointer-sized N-gram" of a string typedef enum { sz_false_k = 0, sz_true_k = 1 } sz_bool_t; // Only one relevant bit typedef enum { sz_less_k = -1, sz_equal_k = 0, sz_greater_k = 1 } sz_ordering_t; // Only three possible states: <=> @@ -481,8 +482,8 @@ typedef sz_size_t (*sz_edit_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size typedef sz_ssize_t (*sz_alignment_score_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_error_cost_t const *, sz_error_cost_t, sz_memory_allocator_t *); -/** @brief Signature of ::sz_sort. */ -typedef sz_bool_t (*sz_sort_t)(struct sz_sequence_t const *, sz_memory_allocator_t *, sz_sorted_idx_t *); +/** @brief Signature of ::sz_sequence_argsort. */ +typedef sz_bool_t (*sz_sequence_argsort_t)(struct sz_sequence_t const *, sz_memory_allocator_t *, sz_sorted_idx_t *); #pragma endregion @@ -644,8 +645,6 @@ SZ_INTERNAL sz_size_t _sz_export_utf8_to_utf32(sz_cptr_t utf8, sz_size_t utf8_le typedef sz_cptr_t (*sz_sequence_member_start_t)(struct sz_sequence_t const *, sz_size_t); typedef sz_size_t (*sz_sequence_member_length_t)(struct sz_sequence_t const *, sz_size_t); -typedef sz_bool_t (*sz_sequence_predicate_t)(struct sz_sequence_t const *, sz_size_t); -typedef sz_bool_t (*sz_string_is_less_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); typedef struct sz_sequence_t { void const *handle; @@ -984,24 +983,6 @@ SZ_INTERNAL sz_u64_t sz_u64_transpose(sz_u64_t x) { return x; } -/** - * @brief Helper, that swaps two 64-bit integers representing the order of elements in the sequence. - */ -SZ_INTERNAL void sz_u64_swap(sz_u64_t *a, sz_u64_t *b) { - sz_u64_t t = *a; - *a = *b; - *b = t; -} - -/** - * @brief Helper, that swaps two 64-bit integers representing the order of elements in the sequence. - */ -SZ_INTERNAL void sz_pointer_swap(void **a, void **b) { - void *t = *a; - *a = *b; - *b = t; -} - /** * @brief Load a 16-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms. 
*/ diff --git a/python/lib.c b/python/lib.c index c5346772..1406691c 100644 --- a/python/lib.c +++ b/python/lib.c @@ -3214,7 +3214,7 @@ static sz_bool_t Strs_sort_(Strs *self, sz_string_view_t **parts_output, sz_sort sequence.get_start = parts_get_start; sequence.get_length = parts_get_length; for (sz_sorted_idx_t i = 0; i != sequence.count; ++i) sequence.order[i] = i; - sz_sort(&sequence); + sz_sequence_argsort(&sequence); // Export results *parts_output = parts; diff --git a/scripts/bench.hpp b/scripts/bench.hpp index b321fa7e..cbec9bf5 100644 --- a/scripts/bench.hpp +++ b/scripts/bench.hpp @@ -44,24 +44,24 @@ using binary_function_t = std::function +template struct tracked_function_gt { std::string name {""}; - function_type function {nullptr}; + function_type_ function {nullptr}; bool needs_testing {false}; std::size_t failed_count; std::vector failed_strings; benchmark_result_t results; - tracked_function_gt(std::string name = "", function_type function = nullptr, bool needs_testing = false) + tracked_function_gt(std::string name = "", function_type_ function = nullptr, bool needs_testing = false) : name(name), function(function), needs_testing(needs_testing), failed_count(0), failed_strings(), results() {} tracked_function_gt(tracked_function_gt const &) = default; tracked_function_gt &operator=(tracked_function_gt const &) = default; void print() const { - bool is_binary = std::is_same(); + bool is_binary = std::is_same(); // If failures have occurred, output them to file to simplify the debugging process. bool contains_failures = !failed_strings.empty(); @@ -229,35 +229,46 @@ inline sz_string_view_t to_c(sz::string_view str) noexcept { return {str.data(), inline sz_string_view_t to_c(sz::string const &str) noexcept { return {str.data(), str.size()}; } inline sz_string_view_t to_c(sz_string_view_t str) noexcept { return str; } +/** + * @brief Invoke the same function many times, until the total time elapsed exceeds the limit. + * @return Total seconds elapsed. + */ +template +seconds_t repeat_until_limit(function_type_ &&function) { + + namespace stdc = std::chrono; + using clock_t = stdc::high_resolution_clock; + clock_t::time_point start_time = clock_t::now(); + seconds_t seconds = 0; + + while (seconds < seconds_per_benchmark) { + function(); + clock_t::time_point current_time = clock_t::now(); + seconds = stdc::duration_cast(current_time - start_time).count() / 1.e9; + } + return seconds; +} + /** * @brief Loop over all elements in a dataset in somewhat random order, benchmarking the function cost. * @param strings Strings to loop over. Length must be a power of two. * @param function Function to be applied to each `sz_string_view_t`. Must return the number of bytes processed. * @return Number of seconds per iteration. 
*/ -template -benchmark_result_t bench_on_tokens(strings_type &&strings, function_type &&function) { +template +benchmark_result_t bench_on_tokens(strings_type_ &&strings, function_type_ &&function) { - namespace stdc = std::chrono; - using clock_t = stdc::high_resolution_clock; - clock_t::time_point t1 = clock_t::now(); benchmark_result_t result; - std::size_t lookup_mask = bit_floor(strings.size()) - 1; - - while (true) { + std::size_t const lookup_mask = bit_floor(strings.size()) - 1; + result.seconds = repeat_until_limit([&]() { // Unroll a few iterations, to avoid some for-loops overhead and minimize impact of time-tracking - { - result.bytes_passed += function(strings[(result.iterations + 0) & lookup_mask]) + - function(strings[(result.iterations + 1) & lookup_mask]) + - function(strings[(result.iterations + 2) & lookup_mask]) + - function(strings[(result.iterations + 3) & lookup_mask]); - result.iterations += 4; - } - - clock_t::time_point t2 = clock_t::now(); - result.seconds = stdc::duration_cast(t2 - t1).count() / 1.e9; - if (result.seconds > seconds_per_benchmark) break; - } + result.bytes_passed += // + function(strings[(result.iterations + 0) & lookup_mask]) + + function(strings[(result.iterations + 1) & lookup_mask]) + + function(strings[(result.iterations + 2) & lookup_mask]) + + function(strings[(result.iterations + 3) & lookup_mask]); + result.iterations += 4; + }); return result; } @@ -269,31 +280,22 @@ benchmark_result_t bench_on_tokens(strings_type &&strings, function_type &&funct * Must return the number of bytes processed. * @return Number of seconds per iteration. */ -template -benchmark_result_t bench_on_token_pairs(strings_type &&strings, function_type &&function) { +template +benchmark_result_t bench_on_token_pairs(strings_type_ &&strings, function_type_ &&function) { - namespace stdc = std::chrono; - using clock_t = stdc::high_resolution_clock; - clock_t::time_point t1 = clock_t::now(); benchmark_result_t result; std::size_t lookup_mask = bit_floor(strings.size()) - 1; std::size_t largest_prime = static_cast(18446744073709551557ull); - - while (true) { + result.seconds = repeat_until_limit([&]() { // Unroll a few iterations, to avoid some for-loops overhead and minimize impact of time-tracking - { - auto second = (result.iterations * largest_prime) & lookup_mask; - result.bytes_passed += function(strings[(result.iterations + 0) & lookup_mask], strings[second]) + - function(strings[(result.iterations + 1) & lookup_mask], strings[second]) + - function(strings[(result.iterations + 2) & lookup_mask], strings[second]) + - function(strings[(result.iterations + 3) & lookup_mask], strings[second]); - result.iterations += 4; - } - - clock_t::time_point t2 = clock_t::now(); - result.seconds = stdc::duration_cast(t2 - t1).count() / 1.e9; - if (result.seconds > seconds_per_benchmark) break; - } + auto second_index = (result.iterations * largest_prime) & lookup_mask; + result.bytes_passed += // + function(strings[(result.iterations + 0) & lookup_mask], strings[second_index]) + + function(strings[(result.iterations + 1) & lookup_mask], strings[second_index]) + + function(strings[(result.iterations + 2) & lookup_mask], strings[second_index]) + + function(strings[(result.iterations + 3) & lookup_mask], strings[second_index]); + result.iterations += 4; + }); return result; } @@ -301,8 +303,8 @@ benchmark_result_t bench_on_token_pairs(strings_type &&strings, function_type && /** * @brief Evaluation for unary string operations: hashing. 
*/ -template -void bench_unary_functions(strings_type &&strings, functions_type &&variants) { +template +void bench_unary_functions(strings_type_ &&strings, functions_type &&variants) { for (std::size_t variant_idx = 0; variant_idx != variants.size(); ++variant_idx) { auto &variant = variants[variant_idx]; @@ -337,8 +339,8 @@ void bench_unary_functions(strings_type &&strings, functions_type &&variants) { /** * @brief Evaluation for binary string operations: equality, ordering, prefix, suffix, distance. */ -template -void bench_binary_functions(strings_type &&strings, functions_type &&variants) { +template +void bench_binary_functions(strings_type_ &&strings, functions_type &&variants) { for (std::size_t variant_idx = 0; variant_idx != variants.size(); ++variant_idx) { auto &variant = variants[variant_idx]; diff --git a/scripts/bench_sort.cpp b/scripts/bench_sort.cpp index 75800582..729ac856 100644 --- a/scripts/bench_sort.cpp +++ b/scripts/bench_sort.cpp @@ -19,8 +19,7 @@ using namespace ashvardanian::stringzilla::scripts; namespace sz = ashvardanian::stringzilla; using strings_t = std::vector; -using idx_t = sz_size_t; -using permute_t = std::vector; +using permute_t = std::vector; #pragma region C callbacks @@ -54,87 +53,128 @@ static int _get_qsort_order(const void *a, const void *b, void *arg) { #pragma endregion -void expect_sorted(strings_t const &strings, permute_t const &permute) { +template +void expect_sorted(strings_type_ const &strings, permute_t const &permute) { if (!std::is_sorted(permute.begin(), permute.end(), [&](std::size_t i, std::size_t j) { return strings[i] < strings[j]; })) throw std::runtime_error("Sorting failed!"); } -template -void bench_permute(char const *name, strings_t &strings, permute_t &permute, algo_at &&algo) { - namespace stdc = std::chrono; - using clock_t = stdc::high_resolution_clock; - constexpr std::size_t iterations = 3; - clock_t::time_point t1 = clock_t::now(); +template +void bench_permute(char const *name, callback_type_ &&callback) { // Run multiple iterations - for (std::size_t i = 0; i != iterations; ++i) { - std::iota(permute.begin(), permute.end(), 0); - algo(strings, permute); - } + std::size_t iterations = 0; + seconds_t duration = repeat_until_limit([&]() { + callback(); + iterations++; + }); // Measure elapsed time - clock_t::time_point t2 = clock_t::now(); - double dif = stdc::duration_cast(t2 - t1).count() * 1.0; - double millisecs = dif / (iterations * 1e6); - std::printf("Elapsed time is %.2lf milliseconds/iteration for %s.\n", millisecs, name); + duration /= iterations; + if (duration >= 0.1) { std::printf("Elapsed time is %.2lf seconds for %s.\n", duration, name); } + else if (duration >= 0.001) { std::printf("Elapsed time is %.2lf milliseconds for %s.\n", duration * 1e3, name); } + else { std::printf("Elapsed time is %.2lf microseconds for %s.\n", duration * 1e6, name); } } int main(int argc, char const **argv) { std::printf("StringZilla. Starting sorting benchmarks.\n"); - dataset_t dataset = prepare_benchmark_environment(argc, argv); - strings_t strings {dataset.tokens.begin(), dataset.tokens.end()}; + dataset_t const dataset = prepare_benchmark_environment(argc, argv); + strings_t const strings {dataset.tokens.begin(), dataset.tokens.end()}; + permute_t permute(strings.size()); + using allocator_t = std::allocator; + + // Before sorting the strings themselves, which is a heavy operation, let's sort some prefixes + // to understand how the sorting algorithm behaves. 
+ std::vector pgrams(strings.size()); + std::transform(strings.begin(), strings.end(), pgrams.begin(), [](std::string const &str) { + sz_pgram_t pgram = 0; + std::memcpy(&pgram, str.c_str(), (std::min)(sizeof(pgram), str.size())); + return pgram; + }); + + // Sorting P-grams + bench_permute("std::sort(pgrams)", [&]() { + std::iota(permute.begin(), permute.end(), 0); + std::sort(permute.begin(), permute.end(), + [&](sz_sorted_idx_t i, sz_sorted_idx_t j) { return pgrams[i] < pgrams[j]; }); + }); + expect_sorted(pgrams, permute); + + // Unlike the `std::sort` adaptation above, the `sz_pgrams_sort_serial` also sorts the input array inplace + std::vector pgrams_sorted(strings.size()); + bench_permute("sz_pgrams_sort_serial", [&]() { + std::copy(pgrams.begin(), pgrams.end(), pgrams_sorted.begin()); + std::iota(permute.begin(), permute.end(), 0); + sz::_with_alloc([&](sz_memory_allocator_t &alloc) { + return sz_pgrams_sort_serial(pgrams_sorted.data(), pgrams_sorted.size(), &alloc, permute.data()); + }); + }); + expect_sorted(pgrams, permute); - permute_t permute_base, permute_new; - permute_base.resize(strings.size()); - permute_new.resize(strings.size()); + // Unlike the `std::sort` adaptation above, the `sz_pgrams_sort_stable_serial` also sorts the input array inplace + bench_permute("sz_pgrams_sort_stable_serial", [&]() { + std::copy(pgrams.begin(), pgrams.end(), pgrams_sorted.begin()); + std::iota(permute.begin(), permute.end(), 0); + sz::_with_alloc([&](sz_memory_allocator_t &alloc) { + return sz_pgrams_sort_stable_serial(pgrams_sorted.data(), pgrams_sorted.size(), &alloc, permute.data()); + }); + }); + expect_sorted(pgrams, permute); - // Sorting - bench_permute("std::sort", strings, permute_base, [](strings_t const &strings, permute_t &permute) { - std::sort(permute.begin(), permute.end(), [&](idx_t i, idx_t j) { return strings[i] < strings[j]; }); + // Sorting strings + bench_permute("std::sort(positions)", [&]() { + std::iota(permute.begin(), permute.end(), 0); + std::sort(permute.begin(), permute.end(), + [&](sz_sorted_idx_t i, sz_sorted_idx_t j) { return strings[i] < strings[j]; }); }); - expect_sorted(strings, permute_base); + expect_sorted(strings, permute); - bench_permute("sz_sort_serial", strings, permute_new, [](strings_t const &strings, permute_t &permute) { + bench_permute("sz_sequence_argsort", [&]() { + std::iota(permute.begin(), permute.end(), 0); sz_sequence_t array; array.count = strings.size(); array.handle = &strings; array.get_start = get_start; array.get_length = get_length; - sz::_with_alloc>( - [&](sz_memory_allocator_t &alloc) { return sz_sort_serial(&array, &alloc, permute.data()); }); + sz::_with_alloc( + [&](sz_memory_allocator_t &alloc) { return sz_sequence_argsort(&array, &alloc, permute.data()); }); }); - expect_sorted(strings, permute_new); + expect_sorted(strings, permute); #if __linux__ && defined(_GNU_SOURCE) && !defined(__BIONIC__) - bench_permute("qsort_r", strings, permute_new, [](strings_t const &strings, permute_t &permute) { + bench_permute("qsort_r", [&]() { + std::iota(permute.begin(), permute.end(), 0); sz_sequence_t array; array.count = strings.size(); array.handle = &strings; array.get_start = get_start; array.get_length = get_length; - qsort_r(permute.data(), array.count, sizeof(sz_u64_t), _get_qsort_order, &array); + qsort_r(permute.data(), array.count, sizeof(sz_sorted_idx_t), _get_qsort_order, &array); }); - expect_sorted(strings, permute_new); + expect_sorted(strings, permute); #elif defined(_MSC_VER) - bench_permute("qsort_s", strings, 
permute_new, [](strings_t const &strings, permute_t &permute) { + bench_permute("qsort_s", [&]() { + std::iota(permute.begin(), permute.end(), 0); sz_sequence_t array; array.count = strings.size(); array.handle = &strings; array.get_start = get_start; array.get_length = get_length; - qsort_s(permute.data(), array.count, sizeof(sz_u64_t), _get_qsort_order, &array); + qsort_s(permute.data(), array.count, sizeof(sz_sorted_idx_t), _get_qsort_order, &array); }); - expect_sorted(strings, permute_new); + expect_sorted(strings, permute); #else sz_unused(_get_qsort_order); #endif std::printf("---- Stable Sorting:\n"); - bench_permute("std::stable_sort", strings, permute_base, [](strings_t const &strings, permute_t &permute) { - std::stable_sort(permute.begin(), permute.end(), [&](idx_t i, idx_t j) { return strings[i] < strings[j]; }); + bench_permute("std::stable_sort", [&]() { + std::iota(permute.begin(), permute.end(), 0); + std::stable_sort(permute.begin(), permute.end(), + [&](sz_sorted_idx_t i, sz_sorted_idx_t j) { return strings[i] < strings[j]; }); }); - expect_sorted(strings, permute_base); + expect_sorted(strings, permute); return 0; } diff --git a/scripts/test.cpp b/scripts/test.cpp index d8f0cdd6..7b3fe4db 100644 --- a/scripts/test.cpp +++ b/scripts/test.cpp @@ -1596,9 +1596,9 @@ static void test_sequence_algorithms() { using order_t = std::vector; // Basic tests with predetermined orders. - assert_scoped(strs_t x({"a", "b", "c", "d"}), (void)0, sz::sorted_order(x) == order_t({0u, 1u, 2u, 3u})); - assert_scoped(strs_t x({"b", "c", "d", "a"}), (void)0, sz::sorted_order(x) == order_t({3u, 0u, 1u, 2u})); - assert_scoped(strs_t x({"b", "a", "d", "c"}), (void)0, sz::sorted_order(x) == order_t({1u, 0u, 3u, 2u})); + assert_scoped(strs_t x({"a", "b", "c", "d"}), (void)0, sz::argsort(x) == order_t({0u, 1u, 2u, 3u})); + assert_scoped(strs_t x({"b", "c", "d", "a"}), (void)0, sz::argsort(x) == order_t({3u, 0u, 1u, 2u})); + assert_scoped(strs_t x({"b", "a", "d", "c"}), (void)0, sz::argsort(x) == order_t({1u, 0u, 3u, 2u})); // Test on long strings of identical length. for (std::size_t string_length : {5u, 25u}) { @@ -1611,7 +1611,7 @@ static void test_sequence_algorithms() { // Run several iterations of fuzzy tests. for (std::size_t experiment_idx = 0; experiment_idx < 10; ++experiment_idx) { std::shuffle(dataset.begin(), dataset.end(), global_random_generator()); - auto order = sz::sorted_order(dataset); + auto order = sz::argsort(dataset); for (std::size_t i = 1; i < dataset.size(); ++i) assert(dataset[order[i - 1]] <= dataset[order[i]]); } } @@ -1626,7 +1626,7 @@ static void test_sequence_algorithms() { // Run several iterations of fuzzy tests. for (std::size_t experiment_idx = 0; experiment_idx < 10; ++experiment_idx) { std::shuffle(dataset.begin(), dataset.end(), global_random_generator()); - auto order = sz::sorted_order(dataset); + auto order = sz::argsort(dataset); for (std::size_t i = 1; i < dataset_size; ++i) { assert(dataset[order[i - 1]] <= dataset[order[i]]); } } } @@ -1642,7 +1642,7 @@ static void test_sequence_algorithms() { // Run several iterations of fuzzy tests. 
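The `sz::argsort` calls in these fuzzy tests can also be exercised on a handful of hand-written strings; a minimal sketch in the same spirit (the values are made up, and any element type convertible to `sz::string_view` should behave the same way):

    std::vector<std::string> words {"banana", "apple", "cherry"};
    std::vector<sz_sorted_idx_t> order = sz::argsort(words);
    // Expected permutation: {1, 0, 2}, since "apple" < "banana" < "cherry"
    for (std::size_t i = 1; i < order.size(); ++i) assert(words[order[i - 1]] <= words[order[i]]);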
for (std::size_t experiment_idx = 0; experiment_idx < 10; ++experiment_idx) { std::shuffle(dataset.begin(), dataset.end(), global_random_generator()); - auto order = sz::sorted_order(dataset); + auto order = sz::argsort(dataset); for (std::size_t i = 1; i < dataset_size; ++i) { assert(dataset[order[i - 1]] <= dataset[order[i]]); } } } @@ -1656,7 +1656,7 @@ static void test_sequence_algorithms() { // Run several iterations of fuzzy tests. for (std::size_t experiment_idx = 0; experiment_idx < 10; ++experiment_idx) { std::shuffle(dataset.begin(), dataset.end(), global_random_generator()); - auto order = sz::sorted_order(dataset); + auto order = sz::argsort(dataset); for (std::size_t i = 1; i < dataset_size; ++i) { assert(dataset[order[i - 1]] <= dataset[order[i]]); } } } From db61d93a1c47f5de0faba07221446cb1a1021510 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 16 Feb 2025 23:31:57 +0000 Subject: [PATCH 106/751] Fix: Merge-step bug in stable sort --- include/stringzilla/sort.h | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/include/stringzilla/sort.h b/include/stringzilla/sort.h index 9ea19e8d..e25876cf 100644 --- a/include/stringzilla/sort.h +++ b/include/stringzilla/sort.h @@ -498,6 +498,7 @@ void _sz_sequence_argsort_stable_serial_merge( // Compute the end pointers for each input array sz_pgram_t const *const first_end = first_pgrams + first_count; sz_pgram_t const *const second_end = second_pgrams + second_count; + sz_pgram_t *const merged_begin = result_pgrams; // Merge until one array is exhausted while (first_pgrams < first_end && second_pgrams < second_end) { @@ -510,13 +511,11 @@ void _sz_sequence_argsort_stable_serial_merge( *result_indices++ = *second_indices++; } else { - // Equal keys: for stability, choose the one from the first array + // Equal keys: for stability, choose the one from the first array, and don't increment the second array *result_pgrams++ = *first_pgrams; *result_indices++ = *first_indices; ++first_pgrams; ++first_indices; - ++second_pgrams; - ++second_indices; } } @@ -531,6 +530,11 @@ void _sz_sequence_argsort_stable_serial_merge( *result_pgrams++ = *second_pgrams++; *result_indices++ = *second_indices++; } + + // Validate the merged result. + if (SZ_DEBUG) + for (sz_size_t i = 1; i < first_count + second_count; ++i) + _sz_assert(merged_begin[i - 1] <= merged_begin[i] && "The merged pgrams must be in ascending order."); } SZ_PUBLIC sz_bool_t sz_pgrams_sort_stable_serial(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, From a38867fdccf76ac00926c0cb90077eb9be1c5e44 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 16 Feb 2025 23:33:42 +0000 Subject: [PATCH 107/751] Improve: Expose Insertion-sort helpers --- include/stringzilla/sort.h | 133 ++++++++++++++++++++++++------------- 1 file changed, 86 insertions(+), 47 deletions(-) diff --git a/include/stringzilla/sort.h b/include/stringzilla/sort.h index e25876cf..b6fcdbc9 100644 --- a/include/stringzilla/sort.h +++ b/include/stringzilla/sort.h @@ -19,6 +19,11 @@ * - `sz_pgrams_sort` - to inplace sort continuous pointer-sized integers with QuickSort. * - `sz_pgrams_sort_stable` - to inplace stable-sort continuous pointer-sized integers with a MergeSort. * + * For cases, when the input is known to be tiny, we provide quadratic-complexity insertion sort adaptations: + * + * - `sz_sequence_argsort_with_insertion` - for string collections. 
+ * - `sz_pgrams_sort_stable_with_insertion` - for continuous unsigned integers. + * */ #ifndef STRINGZILLA_SORT_H_ #define STRINGZILLA_SORT_H_ @@ -145,6 +150,73 @@ SZ_PUBLIC sz_bool_t sz_pgrams_sort_stable_sve(sz_pgram_t *pgrams, sz_size_t coun #pragma endregion +#pragma region Generic Helpers + +/** + * @brief Quadratic complexity insertion sort adjust for our @b argsort usecase. + * Needs no extra memory and is used as a fallback for small inputs. + */ +SZ_PUBLIC void sz_sequence_argsort_with_insertion(sz_sequence_t const *sequence, sz_sorted_idx_t *order) { + // Assume `order` is already initialized with 0, 1, 2, ... N. + for (sz_size_t i = 1; i < sequence->count; ++i) { + sz_sorted_idx_t current_idx = order[i]; + sz_size_t j = i; + while (j > 0) { + // Get the two strings to compare. + sz_sorted_idx_t previous_idx = order[j - 1]; + sz_cptr_t previous_start = sequence->get_start(sequence, previous_idx); + sz_cptr_t current_start = sequence->get_start(sequence, current_idx); + sz_size_t previous_length = sequence->get_length(sequence, previous_idx); + sz_size_t current_length = sequence->get_length(sequence, current_idx); + + // Use the provided sz_order to compare. + sz_ordering_t ordering = sz_order(previous_start, previous_length, current_start, current_length); + + // If the previous string is not greater than current_idx, we're done. + if (ordering != sz_greater_k) break; + + // Otherwise, shift the previous element to the right. + order[j] = order[j - 1]; + --j; + } + order[j] = current_idx; + } +} + +/** + * @brief Quadratic complexity insertion sort adjust for our @b pgram-sorting usecase. + * Needs no extra memory and is used as a fallback for small inputs. + */ + +SZ_PUBLIC void sz_pgrams_sort_stable_with_insertion(sz_pgram_t *pgrams, sz_size_t count, sz_sorted_idx_t *order) { + + // Assume `order` is already initialized with 0, 1, 2, ... N. + for (sz_size_t i = 1; i < count; ++i) { + // Save the current key and corresponding index. + sz_pgram_t current_key = pgrams[i]; + sz_sorted_idx_t current_idx = order[i]; + sz_size_t j = i; + + // Shift elements of the sorted region that are greater than the current key + // to the right. This loop stops as soon as the correct insertion point is found. + while (j > 0 && pgrams[j - 1] > current_key) { + pgrams[j] = pgrams[j - 1]; + order[j] = order[j - 1]; + --j; + } + + // Insert the current key and index into their proper location. + pgrams[j] = current_key; + order[j] = current_idx; + } + + if (SZ_DEBUG) + for (sz_size_t i = 1; i < count; ++i) + _sz_assert(pgrams[i - 1] <= pgrams[i] && "The pgrams should be sorted in ascending order."); +} + +#pragma endregion + #pragma region Serial QuickSort Implementation SZ_PUBLIC void _sz_sequence_argsort_serial_export_next_pgrams( // @@ -341,37 +413,6 @@ SZ_PUBLIC void _sz_sequence_argsort_serial_next_pgrams( // } } -SZ_PUBLIC void _sz_sequence_argsort_serial_insertion(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order) { - // This algorithm needs no memory allocations: - sz_unused(alloc); - - // Assume `order` is already initialized with 0, 1, 2, ... N. - for (sz_size_t i = 1; i < sequence->count; ++i) { - sz_sorted_idx_t current_idx = order[i]; - sz_size_t j = i; - while (j > 0) { - // Get the two strings to compare. 
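Since both helpers are exported as `SZ_PUBLIC`, they can be called directly on tiny inputs without going through the allocator-backed entry points; a minimal sketch with made-up values:

    sz_pgram_t pgrams[4] = {42, 7, 7, 3};
    sz_sorted_idx_t order[4] = {0, 1, 2, 3}; // the helper assumes `order` is pre-initialized
    sz_pgrams_sort_stable_with_insertion(pgrams, 4, order);
    // pgrams becomes {3, 7, 7, 42} and order becomes {3, 1, 2, 0},
    // keeping the two equal 7s in their original relative order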
- sz_sorted_idx_t previous_idx = order[j - 1]; - sz_cptr_t previous_start = sequence->get_start(sequence, previous_idx); - sz_cptr_t current_start = sequence->get_start(sequence, current_idx); - sz_size_t previous_length = sequence->get_length(sequence, previous_idx); - sz_size_t current_length = sequence->get_length(sequence, current_idx); - - // Use the provided sz_order to compare. - sz_ordering_t ordering = sz_order(previous_start, previous_length, current_start, current_length); - - // If the previous string is not greater than current_idx, we're done. - if (ordering != sz_greater_k) break; - - // Otherwise, shift the previous element to the right. - order[j] = order[j - 1]; - --j; - } - order[j] = current_idx; - } -} - SZ_PUBLIC sz_bool_t sz_sequence_argsort_serial(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, sz_sorted_idx_t *order) { @@ -381,7 +422,7 @@ SZ_PUBLIC sz_bool_t sz_sequence_argsort_serial(sz_sequence_t const *sequence, sz // On very small collections - just use the quadratic-complexity insertion sort // without any smart optimizations or memory allocations. if (sequence->count <= 32) { - _sz_sequence_argsort_serial_insertion(sequence, alloc, order); + sz_sequence_argsort_with_insertion(sequence, order); return sz_true_k; } @@ -543,22 +584,20 @@ SZ_PUBLIC sz_bool_t sz_pgrams_sort_stable_serial(sz_pgram_t *pgrams, sz_size_t c // First, initialize the `order` with `std::iota`-like behavior. for (sz_size_t i = 0; i != count; ++i) order[i] = i; + // On very small collections - just use the quadratic-complexity insertion sort + // without any smart optimizations or memory allocations. + if (count <= 32) { + sz_pgrams_sort_stable_with_insertion(pgrams, count, order); + return sz_true_k; + } + // Go through short chunks of 8 elements and sort them with a sorting network. - for (sz_size_t i = 0; i + 8 <= count; i += 8) _sz_sequence_argsort_stable_serial_8x_network(pgrams + i, order + i); + for (sz_size_t i = 0; i + 8u <= count; i += 8u) + _sz_sequence_argsort_stable_serial_8x_network(pgrams + i, order + i); // For the tail of the array, sort it with insertion sort. - for (sz_size_t i = count & ~7; i < count; i++) { - sz_pgram_t current_address = pgrams[i]; - sz_sorted_idx_t current_idx = order[i]; - sz_size_t j = i; - while (j > 0 && pgrams[j - 1] > current_address) { - pgrams[j] = pgrams[j - 1]; - order[j] = order[j - 1]; - --j; - } - pgrams[j] = current_address; - order[j] = current_idx; - } + sz_size_t const tail_count = count & 7u; + sz_pgrams_sort_stable_with_insertion(pgrams + count - tail_count, tail_count, order + count - tail_count); // At this point, the array is partitioned into sorted runs. // We'll now merge these runs until the whole array is sorted. @@ -589,8 +628,8 @@ SZ_PUBLIC sz_bool_t sz_pgrams_sort_stable_serial(sz_pgram_t *pgrams, sz_size_t c // Determine the number of elements in the right run. 
sz_size_t right_count = run_size; - if (i + run_size >= count) { right_count = 0; } - else if (i + run_size + right_count > count) { right_count = count - (i + run_size); } + if (i + left_count >= count) { right_count = 0; } + else if (i + left_count + right_count > count) { right_count = count - (i + left_count); } // Merge the two runs: _sz_sequence_argsort_stable_serial_merge( // From cd6859a56f7cdfba698d7389733871ec9cf45e8b Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Mon, 17 Feb 2025 11:01:02 +0000 Subject: [PATCH 108/751] Add: Smaller Sorting Networks It yields no noticeable performance improvements --- include/stringzilla/sort.h | 352 ++++++++++++++++++++++--------------- 1 file changed, 215 insertions(+), 137 deletions(-) diff --git a/include/stringzilla/sort.h b/include/stringzilla/sort.h index b6fcdbc9..977d29e1 100644 --- a/include/stringzilla/sort.h +++ b/include/stringzilla/sort.h @@ -150,7 +150,7 @@ SZ_PUBLIC sz_bool_t sz_pgrams_sort_stable_sve(sz_pgram_t *pgrams, sz_size_t coun #pragma endregion -#pragma region Generic Helpers +#pragma region Generic Public Helpers /** * @brief Quadratic complexity insertion sort adjust for our @b argsort usecase. @@ -210,16 +210,148 @@ SZ_PUBLIC void sz_pgrams_sort_stable_with_insertion(sz_pgram_t *pgrams, sz_size_ order[j] = current_idx; } - if (SZ_DEBUG) - for (sz_size_t i = 1; i < count; ++i) - _sz_assert(pgrams[i - 1] <= pgrams[i] && "The pgrams should be sorted in ascending order."); +#if SZ_DEBUG + for (sz_size_t i = 1; i < count; ++i) + _sz_assert(pgrams[i - 1] <= pgrams[i] && "The pgrams should be sorted in ascending order."); +#endif } -#pragma endregion +#pragma endregion // Generic Public Helpers + +#pragma region Generic Internal Helpers + +/** + * @brief Convenience macro for of conditional swap of "pgrams" and their indices for a sorting network. + * @see https://en.wikipedia.org/wiki/Sorting_network + */ +#define _sz_sequence_sorting_network_conditional_swap(i, j) \ + do { \ + if (pgrams[i] > pgrams[j]) { \ + _sz_swap(sz_pgram_t, pgrams[i], pgrams[j]); \ + _sz_swap(sz_sorted_idx_t, offsets[i], offsets[j]); \ + } \ + } while (0) + +/** + * @brief Sorting network for 2 elements is just a single compare–swap. + */ +SZ_INTERNAL void _sz_sequence_sorting_network_2x(sz_pgram_t *pgrams, sz_sorted_idx_t *offsets) { + _sz_sequence_sorting_network_conditional_swap(0, 1); +} + +/** + * @brief Sorting network for 3 elements. + * + * The network uses 3 compare–swap operations: + * + * Stage 1: (0, 1) + * Stage 2: (0, 2) + * Stage 3: (1, 2) + */ +SZ_INTERNAL void _sz_sequence_sorting_network_3x(sz_pgram_t *pgrams, sz_sorted_idx_t *offsets) { + + _sz_sequence_sorting_network_conditional_swap(0, 1); + _sz_sequence_sorting_network_conditional_swap(0, 2); + _sz_sequence_sorting_network_conditional_swap(1, 2); + +#if SZ_DEBUG + for (sz_size_t i = 1; i < 3; ++i) + _sz_assert(pgrams[i - 1] <= pgrams[i] && "Sorting network for 3 elements failed."); +#endif +} + +/** + * @brief Sorting network for 4 elements. + * + * The network uses 5 compare–swap operations: + * + * Stage 1: (0, 1) and (2, 3) + * Stage 2: (0, 2) + * Stage 3: (1, 3) + * Stage 4: (1, 2) + */ +SZ_INTERNAL void _sz_sequence_sorting_network_4x(sz_pgram_t *pgrams, sz_sorted_idx_t *offsets) { + + // Stage 1: Compare–swap adjacent pairs. 
+ _sz_sequence_sorting_network_conditional_swap(0, 1); + _sz_sequence_sorting_network_conditional_swap(2, 3); + + // Stage 2: Compare–swap (0, 2) + _sz_sequence_sorting_network_conditional_swap(0, 2); + + // Stage 3: Compare–swap (1, 3) + _sz_sequence_sorting_network_conditional_swap(1, 3); + + // Stage 4: Final compare–swap (1, 2) + _sz_sequence_sorting_network_conditional_swap(1, 2); + +#if SZ_DEBUG + for (sz_size_t i = 1; i < 4; ++i) + _sz_assert(pgrams[i - 1] <= pgrams[i] && "Sorting network for 4 elements failed."); +#endif +} + +/** + * @brief A scalar sorting network for 8 elements that reorders both the pgrams + * and their corresponding offsets in only 19 comparisons, the most efficient + * variant currently known. + * + * The network consists of 6 stages with the following compare–swap pairs: + * + * Stage 1: (0,1), (2,3), (4,5), (6,7) + * Stage 2: (0,2), (1,3), (4,6), (5,7) + * Stage 3: (1,2), (5,6) + * Stage 4: (0,4), (1,5), (2,6), (3,7) + * Stage 5: (2,4), (3,5) + * Stage 6: (1,2), (3,4), (5,6) + */ +SZ_INTERNAL void _sz_sequence_sorting_network_8x(sz_pgram_t *pgrams, sz_sorted_idx_t *offsets) { + + // Stage 1: Compare–swap adjacent pairs. + _sz_sequence_sorting_network_conditional_swap(0, 1); + _sz_sequence_sorting_network_conditional_swap(2, 3); + _sz_sequence_sorting_network_conditional_swap(4, 5); + _sz_sequence_sorting_network_conditional_swap(6, 7); + + // Stage 2: Compare–swap with stride 2. + _sz_sequence_sorting_network_conditional_swap(0, 2); + _sz_sequence_sorting_network_conditional_swap(1, 3); + _sz_sequence_sorting_network_conditional_swap(4, 6); + _sz_sequence_sorting_network_conditional_swap(5, 7); + + // Stage 3: Compare–swap between middle elements. + _sz_sequence_sorting_network_conditional_swap(1, 2); + _sz_sequence_sorting_network_conditional_swap(5, 6); + + // Stage 4: Compare–swap across the two halves. + _sz_sequence_sorting_network_conditional_swap(0, 4); + _sz_sequence_sorting_network_conditional_swap(1, 5); + _sz_sequence_sorting_network_conditional_swap(2, 6); + _sz_sequence_sorting_network_conditional_swap(3, 7); + + // Stage 5: Compare–swap within each half. + _sz_sequence_sorting_network_conditional_swap(2, 4); + _sz_sequence_sorting_network_conditional_swap(3, 5); + + // Stage 6: Final compare–swap of adjacent elements. + _sz_sequence_sorting_network_conditional_swap(1, 2); + _sz_sequence_sorting_network_conditional_swap(3, 4); + _sz_sequence_sorting_network_conditional_swap(5, 6); + +#if SZ_DEBUG + // Validate the sorting network. + for (sz_size_t i = 1; i < 8; ++i) + _sz_assert(pgrams[i - 1] <= pgrams[i] && "The sorting network must sort the pgrams in ascending order."); +#endif +} + +#undef _sz_sequence_sorting_network_conditional_swap + +#pragma endregion // Generic Internal Helpers #pragma region Serial QuickSort Implementation -SZ_PUBLIC void _sz_sequence_argsort_serial_export_next_pgrams( // +SZ_INTERNAL void _sz_sequence_argsort_serial_export_next_pgrams( // sz_sequence_t const *const sequence, // sz_pgram_t *const global_pgrams, sz_sorted_idx_t const *const global_order, // sz_size_t const start_in_sequence, sz_size_t const end_in_sequence, // @@ -227,7 +359,7 @@ SZ_PUBLIC void _sz_sequence_argsort_serial_export_next_pgrams( // Depending on the architecture, we will export a different number of bytes. // On 32-bit architectures, we will export 3 bytes, and on 64-bit architectures - 7 bytes. 
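To make the export below concrete, a worked example assuming a 64-bit little-endian build, where every `sz_pgram_t` carries up to 7 characters plus a length byte, and the byte reversal puts the first character into the most significant byte so that plain integer comparison follows lexicographic order:

    // "cat" -> bytes {'c', 'a', 't', 0, 0, 0, 0} + length byte 3 -> 0x6361740000000003 after reversal
    // "car" -> 0x6361720000000003 and "ca" -> 0x6361000000000002
    // hence "ca" < "car" < "cat" both as strings and as pgram integers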
- sz_size_t const window_capacity = sizeof(sz_pgram_t) - 1; + sz_size_t const pgram_capacity = sizeof(sz_pgram_t) - 1; // Perform the same operation for every string. for (sz_size_t i = start_in_sequence; i < end_in_sequence; ++i) { @@ -241,14 +373,14 @@ SZ_PUBLIC void _sz_sequence_argsort_serial_export_next_pgrams( sz_cptr_t const source_str = sequence->get_start(sequence, partial_order_index); sz_size_t const length = sequence->get_length(sequence, partial_order_index); sz_size_t const remaining_length = length > start_character ? length - start_character : 0; - sz_size_t const exported_length = remaining_length > window_capacity ? window_capacity : remaining_length; + sz_size_t const exported_length = remaining_length > pgram_capacity ? pgram_capacity : remaining_length; // Fill with zeros, export a slice, and mark the exported length. sz_pgram_t *target_pgram = &global_pgrams[i]; sz_ptr_t target_str = (sz_ptr_t)target_pgram; *target_pgram = 0; for (sz_size_t j = 0; j < exported_length; ++j) target_str[j] = source_str[j + start_character]; - target_str[window_capacity] = exported_length; + target_str[pgram_capacity] = exported_length; #if defined(_SZ_IS_64_BIT) *target_pgram = sz_u64_bytes_reverse(*target_pgram); #else @@ -259,36 +391,52 @@ SZ_PUBLIC void _sz_sequence_argsort_serial_export_next_pgrams( "We can have a zero value if only the string is shorter than other strings at this position."); } - // As our goal is to sort the strings using the exported integer "windows", + // As our goal is to sort the strings using the exported integer "pgrams", // this is a good place to validate the correctness of the exported data. if (SZ_DEBUG && start_character == 0) for (sz_size_t i = start_in_sequence + 1; i < end_in_sequence; ++i) { - sz_pgram_t const previous_window = global_pgrams[i - 1]; - sz_pgram_t const current_window = global_pgrams[i]; + sz_pgram_t const previous_pgram = global_pgrams[i - 1]; + sz_pgram_t const current_pgram = global_pgrams[i]; sz_cptr_t const previous_str = sequence->get_start(sequence, i - 1); sz_size_t const previous_length = sequence->get_length(sequence, i - 1); sz_cptr_t const current_str = sequence->get_start(sequence, i); sz_size_t const current_length = sequence->get_length(sequence, i); - sz_ordering_t const ordering = sz_order( // - previous_str, previous_length > window_capacity ? window_capacity : previous_length, // - current_str, current_length > window_capacity ? window_capacity : current_length); - _sz_assert( // - (previous_window < current_window) == (ordering == sz_less_k) && // - "The exported windows should be in the same order as the original strings."); + sz_ordering_t const ordering = sz_order( // + previous_str, previous_length > pgram_capacity ? pgram_capacity : previous_length, // + current_str, current_length > pgram_capacity ? pgram_capacity : current_length); + _sz_assert( // + (previous_pgram < current_pgram) == (ordering == sz_less_k) && // + "The exported pgrams should be in the same order as the original strings."); } } /** - * @brief The most important part of the QuickSort algorithm, that rearranges the elements in - * such a way, that all entries around the pivot are less than the pivot. - * - * It means that no relative order among the elements on the left or right side of the pivot is preserved. - * We chose the pivot point using Robert Sedgewick's method - the median of three elements - the first, - * the middle, and the last element of the given range. 
+ * @brief Picks the "pivot" value for the QuickSort algorithm's partitioning step using Robert Sedgewick's method, + * the median of three elements - the first, the middle, and the last element of the given range. + */ +SZ_INTERNAL sz_pgram_t _sz_sequence_partitioning_pivot(sz_pgram_t const *pgrams, sz_size_t count) { + sz_size_t const middle_offset = count / 2; + sz_pgram_t const first_pgram = pgrams[0]; + sz_pgram_t const middle_pgram = pgrams[middle_offset]; + sz_pgram_t const last_pgram = pgrams[count - 1]; + if (first_pgram < middle_pgram) { + if (middle_pgram < last_pgram) { return middle_pgram; } + else if (first_pgram < last_pgram) { return last_pgram; } + else { return first_pgram; } + } + else { + if (first_pgram < last_pgram) { return first_pgram; } + else if (middle_pgram < last_pgram) { return last_pgram; } + else { return middle_pgram; } + } +} + +/** + * @brief The most important part of the QuickSort algorithm partitioning the elements around the pivot. * - * Moreover, considering our iterative refinement procedure, we can't just use the normal 2-way partitioning, - * as it will scatter the values equal to the pivot into the left and right partitions. Instead we use the - * Dutch National Flag @b 3-way partitioning, outputting the range of values equal to the pivot. + * The classical variant uses the normal 2-way partitioning, but it will scatter the values equal to the pivot + * into the left and right partitions. Instead we use the Dutch National Flag @b 3-way partitioning, outputting + * the range of values equal to the pivot. * * @see https://en.wikipedia.org/wiki/Dutch_national_flag_problem */ @@ -297,47 +445,42 @@ SZ_PUBLIC void _sz_sequence_argsort_serial_3way_partition( // sz_size_t const start_in_sequence, sz_size_t const end_in_sequence, // sz_size_t *first_pivot_offset, sz_size_t *last_pivot_offset) { - // Chose the pivot offset with Sedgewick's method. - sz_pgram_t pivot_window; - { - sz_size_t const middle_offset = start_in_sequence + (end_in_sequence - start_in_sequence) / 2; - sz_size_t const last_offset = end_in_sequence - 1; - sz_size_t const first_offset = start_in_sequence; - sz_pgram_t const first_window = global_pgrams[first_offset]; - sz_pgram_t const middle_window = global_pgrams[middle_offset]; - sz_pgram_t const last_window = global_pgrams[last_offset]; - if (first_window < middle_window) { - if (middle_window < last_window) { pivot_window = middle_window; } - else if (first_window < last_window) { pivot_window = last_window; } - else { pivot_window = first_window; } - } - else { - if (first_window < last_window) { pivot_window = first_window; } - else if (middle_window < last_window) { pivot_window = last_window; } - else { pivot_window = middle_window; } - } + // On very small inputs this procedure is rudimentary. + sz_size_t const count = end_in_sequence - start_in_sequence; + if (count <= 4) { + sz_pgram_t *const pgrams = global_pgrams + start_in_sequence; + sz_sorted_idx_t *const offsets = global_order + start_in_sequence; + if (count == 2) { _sz_sequence_sorting_network_2x(pgrams, offsets); } + else if (count == 3) { _sz_sequence_sorting_network_3x(pgrams, offsets); } + else if (count == 4) { _sz_sequence_sorting_network_4x(pgrams, offsets); } + *first_pivot_offset = start_in_sequence; + *last_pivot_offset = end_in_sequence; + return; } + // Chose the pivot offset with Sedgewick's method. 
+ sz_pgram_t const pivot_pgram = _sz_sequence_partitioning_pivot(global_pgrams + start_in_sequence, count); + // Loop through the collection and move the elements around the pivot with the 3-way partitioning. sz_size_t partitioning_progress = start_in_sequence; // Current index. - sz_size_t smaller_offset = start_in_sequence; // Boundary for elements < pivot_window. - sz_size_t greater_offset = end_in_sequence - 1; // Boundary for elements > pivot_window. + sz_size_t smaller_offset = start_in_sequence; // Boundary for elements < `pivot_pgram`. + sz_size_t greater_offset = end_in_sequence - 1; // Boundary for elements > `pivot_pgram`. while (partitioning_progress <= greater_offset) { // Element is less than pivot: swap into the < pivot region. - if (global_pgrams[partitioning_progress] < pivot_window) { + if (global_pgrams[partitioning_progress] < pivot_pgram) { _sz_swap(sz_sorted_idx_t, global_order[partitioning_progress], global_order[smaller_offset]); _sz_swap(sz_pgram_t, global_pgrams[partitioning_progress], global_pgrams[smaller_offset]); ++partitioning_progress; ++smaller_offset; } // Element is greater than pivot: swap into the > pivot region. - else if (global_pgrams[partitioning_progress] > pivot_window) { + else if (global_pgrams[partitioning_progress] > pivot_pgram) { _sz_swap(sz_sorted_idx_t, global_order[partitioning_progress], global_order[greater_offset]); _sz_swap(sz_pgram_t, global_pgrams[partitioning_progress], global_pgrams[greater_offset]); --greater_offset; } - // Element equals pivot_window: leave it in place. + // Element equals `pivot_pgram`: leave it in place. else { ++partitioning_progress; } } @@ -349,7 +492,7 @@ SZ_PUBLIC void _sz_sequence_argsort_serial_3way_partition( // * @brief Recursive Quick-Sort implementation backing both the `sz_sequence_argsort` and `sz_pgrams_sort`, * and using the `_sz_sequence_argsort_serial_3way_partition` under the hood. */ -SZ_PUBLIC void _sz_sequence_argsort_serial_recursively( // +SZ_INTERNAL void _sz_sequence_argsort_serial_recursively( // sz_pgram_t *const global_pgrams, sz_sorted_idx_t *const global_order, // sz_size_t const start_in_sequence, sz_size_t const end_in_sequence) { @@ -372,41 +515,41 @@ SZ_PUBLIC void _sz_sequence_argsort_serial_recursively( // /** * @brief Recursive Quick-Sort adaptation for strings, that processes the strings a few N-grams at a time. * It combines `_sz_sequence_argsort_serial_export_next_pgrams` and `_sz_sequence_argsort_serial_recursively`, - * recursively diving into the identical windows. + * recursively diving into the identical pgrams. */ -SZ_PUBLIC void _sz_sequence_argsort_serial_next_pgrams( // +SZ_INTERNAL void _sz_sequence_argsort_serial_next_pgrams( // sz_sequence_t const *const sequence, // sz_pgram_t *const global_pgrams, sz_sorted_idx_t *const global_order, // sz_size_t const start_in_sequence, sz_size_t const end_in_sequence, // sz_size_t const start_character) { - // Prepare the new range of windows + // Prepare the new range of pgrams _sz_sequence_argsort_serial_export_next_pgrams(sequence, global_pgrams, global_order, start_in_sequence, end_in_sequence, start_character); - // Sort current windows with a quicksort + // Sort current pgrams with a quicksort _sz_sequence_argsort_serial_recursively(global_pgrams, global_order, start_in_sequence, end_in_sequence); // Depending on the architecture, we will export a different number of bytes. // On 32-bit architectures, we will export 3 bytes, and on 64-bit architectures - 7 bytes. 
-    sz_size_t const window_capacity = sizeof(sz_pgram_t) - 1;
+    sz_size_t const pgram_capacity = sizeof(sz_pgram_t) - 1;
-    // Repeat the procedure for the identical windows
+    // Repeat the procedure for the identical pgrams
     sz_size_t nested_start = start_in_sequence;
     sz_size_t nested_end = start_in_sequence;
     while (nested_end != end_in_sequence) {
-        // Find the end of the identical windows
-        sz_pgram_t current_window_integer = global_pgrams[nested_start];
-        while (nested_end != end_in_sequence && current_window_integer == global_pgrams[nested_end]) ++nested_end;
+        // Find the end of the identical pgrams
+        sz_pgram_t current_pgram = global_pgrams[nested_start];
+        while (nested_end != end_in_sequence && current_pgram == global_pgrams[nested_end]) ++nested_end;
-        // If the identical windows are not trivial and each string has more characters, sort them recursively
-        sz_cptr_t current_window_str = (sz_cptr_t)&current_window_integer;
-        sz_size_t current_window_length = (sz_size_t)current_window_str[0]; //! The byte order was swapped
+        // If the identical pgrams are not trivial and each string has more characters, sort them recursively
+        sz_cptr_t current_pgram_str = (sz_cptr_t)&current_pgram;
+        sz_size_t current_pgram_length = (sz_size_t)current_pgram_str[0]; //! The byte order was swapped
         int has_multiple_strings = nested_end - nested_start > 1;
-        int has_more_characters_in_each = current_window_length == window_capacity;
+        int has_more_characters_in_each = current_pgram_length == pgram_capacity;
         if (has_multiple_strings && has_more_characters_in_each) {
             _sz_sequence_argsort_serial_next_pgrams(sequence, global_pgrams, global_order, nested_start, nested_end,
-                                                    start_character + window_capacity);
+                                                    start_character + pgram_capacity);
         }
         // Move to the next
         nested_start = nested_end;
@@ -438,14 +581,14 @@ SZ_PUBLIC sz_bool_t sz_sequence_argsort_serial(sz_sequence_t const *sequence, sz
     // is included in those P-long words. So, in reality, we will be taking (P-1) bytes from each string on every
     // iteration of a recursive algorithm.
     sz_size_t memory_usage = sequence->count * sizeof(sz_pgram_t);
-    sz_pgram_t *windows = (sz_pgram_t *)alloc->allocate(memory_usage, alloc);
-    if (!windows) return sz_false_k;
+    sz_pgram_t *pgrams = (sz_pgram_t *)alloc->allocate(memory_usage, alloc);
+    if (!pgrams) return sz_false_k;
     // Recursively sort the whole sequence.
-    _sz_sequence_argsort_serial_next_pgrams(sequence, windows, order, 0, sequence->count, 0);
+    _sz_sequence_argsort_serial_next_pgrams(sequence, pgrams, order, 0, sequence->count, 0);
     // Free temporary storage.
-    alloc->free(windows, memory_usage, alloc);
+    alloc->free(pgrams, memory_usage, alloc);
     return sz_true_k;
 }
@@ -463,75 +606,11 @@ SZ_PUBLIC sz_bool_t sz_pgrams_sort_serial(sz_pgram_t *pgrams, sz_size_t count, s
 #pragma region Serial MergeSort Implementation
-/**
- * @brief A scalar sorting network for 8 elements that reorders both the keys
- * and their corresponding offsets in only 19 comparisons, the most efficient
- * variant currently known.
- * @see https://en.wikipedia.org/wiki/Sorting_network - * - * The network consists of 6 stages with the following compare–swap pairs: - * - * Stage 1: (0,1), (2,3), (4,5), (6,7) - * Stage 2: (0,2), (1,3), (4,6), (5,7) - * Stage 3: (1,2), (5,6) - * Stage 4: (0,4), (1,5), (2,6), (3,7) - * Stage 5: (2,4), (3,5) - * Stage 6: (1,2), (3,4), (5,6) - */ -void _sz_sequence_argsort_stable_serial_8x_network(sz_pgram_t *keys, sz_sorted_idx_t *offsets) { - -#define _sz_sequence_argsort_stable_8x_conditional_swap(i, j) \ - do { \ - if (keys[i] > keys[j]) { \ - _sz_swap(sz_pgram_t, keys[i], keys[j]); \ - _sz_swap(sz_sorted_idx_t, offsets[i], offsets[j]); \ - } \ - } while (0) - - // Stage 1: Compare–swap adjacent pairs. - _sz_sequence_argsort_stable_8x_conditional_swap(0, 1); - _sz_sequence_argsort_stable_8x_conditional_swap(2, 3); - _sz_sequence_argsort_stable_8x_conditional_swap(4, 5); - _sz_sequence_argsort_stable_8x_conditional_swap(6, 7); - - // Stage 2: Compare–swap with stride 2. - _sz_sequence_argsort_stable_8x_conditional_swap(0, 2); - _sz_sequence_argsort_stable_8x_conditional_swap(1, 3); - _sz_sequence_argsort_stable_8x_conditional_swap(4, 6); - _sz_sequence_argsort_stable_8x_conditional_swap(5, 7); - - // Stage 3: Compare–swap between middle elements. - _sz_sequence_argsort_stable_8x_conditional_swap(1, 2); - _sz_sequence_argsort_stable_8x_conditional_swap(5, 6); - - // Stage 4: Compare–swap across the two halves. - _sz_sequence_argsort_stable_8x_conditional_swap(0, 4); - _sz_sequence_argsort_stable_8x_conditional_swap(1, 5); - _sz_sequence_argsort_stable_8x_conditional_swap(2, 6); - _sz_sequence_argsort_stable_8x_conditional_swap(3, 7); - - // Stage 5: Compare–swap within each half. - _sz_sequence_argsort_stable_8x_conditional_swap(2, 4); - _sz_sequence_argsort_stable_8x_conditional_swap(3, 5); - - // Stage 6: Final compare–swap of adjacent elements. - _sz_sequence_argsort_stable_8x_conditional_swap(1, 2); - _sz_sequence_argsort_stable_8x_conditional_swap(3, 4); - _sz_sequence_argsort_stable_8x_conditional_swap(5, 6); - -#undef _sz_sequence_argsort_stable_8x_conditional_swap - - // Validate the sorting network. - if (SZ_DEBUG) - for (sz_size_t i = 1; i < 8; ++i) - _sz_assert(keys[i - 1] <= keys[i] && "The sorting network must sort the keys in ascending order."); -} - /** * @brief Helper function similar to `std::set_union` over pairs of integers and their original indices. * @see https://en.cppreference.com/w/cpp/algorithm/set_union */ -void _sz_sequence_argsort_stable_serial_merge( // +SZ_INTERNAL void _sz_sequence_argsort_stable_serial_merge( // sz_pgram_t const *first_pgrams, sz_sorted_idx_t const *first_indices, sz_size_t first_count, // sz_pgram_t const *second_pgrams, sz_sorted_idx_t const *second_indices, sz_size_t second_count, // sz_pgram_t *result_pgrams, sz_sorted_idx_t *result_indices) { @@ -592,8 +671,7 @@ SZ_PUBLIC sz_bool_t sz_pgrams_sort_stable_serial(sz_pgram_t *pgrams, sz_size_t c } // Go through short chunks of 8 elements and sort them with a sorting network. - for (sz_size_t i = 0; i + 8u <= count; i += 8u) - _sz_sequence_argsort_stable_serial_8x_network(pgrams + i, order + i); + for (sz_size_t i = 0; i + 8u <= count; i += 8u) _sz_sequence_sorting_network_8x(pgrams + i, order + i); // For the tail of the array, sort it with insertion sort. 
    sz_size_t const tail_count = count & 7u;

From 71f1f4baecca3e3692acc561c74f376a351478f6 Mon Sep 17 00:00:00 2001
From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com>
Date: Thu, 20 Feb 2025 13:35:39 +0000
Subject: [PATCH 109/751] Break: `checksum` to `bytesum`, new hash, and PRNG

---
 README.md                           |  2 +-
 c/lib.c                             | 63 +++++++++++++++++++++++++----
 include/stringzilla/stringzilla.h   |  2 +-
 include/stringzilla/stringzilla.hpp | 43 ++++++++-----------
 include/stringzilla/types.h         | 16 ++++++--
 python/lib.c                        | 10 ++---
 rust/lib.rs                         | 20 ++++-----
 scripts/bench_fingerprint.cpp       |  2 +-
 scripts/bench_token.cpp             | 14 +++----
 scripts/test.cpp                    | 12 +++---
 scripts/test.py                     |  4 +-
 11 files changed, 119 insertions(+), 69 deletions(-)

diff --git a/README.md b/README.md
index c5253c4c..c657decf 100644
--- a/README.md
+++ b/README.md
@@ -629,7 +629,7 @@ sz_size_t substring_position = sz_find_haswell(haystack.start, haystack.length,
 sz_size_t substring_position = sz_find_neon(haystack.start, haystack.length, needle.start, needle.length);
 
 // Hash strings
-sz_u64_t hash = sz_hash(haystack.start, haystack.length);
+sz_u64_t hash = sz_hash(haystack.start, haystack.length, 42); // or any other seed ;)
 
 // Perform collection level operations
 sz_sequence_t array = {your_handle, your_count, your_get_start, your_get_length};
diff --git a/c/lib.c b/c/lib.c
index 64e7b61a..c68e7a1f 100644
--- a/c/lib.c
+++ b/c/lib.c
@@ -177,7 +177,12 @@ typedef struct sz_implementations_t {
     sz_move_t move;
     sz_fill_t fill;
     sz_look_up_transform_t look_up_transform;
-    sz_checksum_t checksum;
+
+    sz_bytesum_t bytesum;
+    sz_hash_t hash;
+    sz_hash_state_init_t hash_state_init;
+    sz_hash_state_stream_t hash_state_stream;
+    sz_hash_state_fold_t hash_state_fold;
 
     sz_find_byte_t find_byte;
     sz_find_byte_t rfind_byte;
@@ -214,7 +219,12 @@ SZ_DYNAMIC void sz_dispatch_table_init(void) {
     impl->move = sz_move_serial;
     impl->fill = sz_fill_serial;
     impl->look_up_transform = sz_look_up_transform_serial;
-    impl->checksum = sz_checksum_serial;
+
+    impl->bytesum = sz_bytesum_serial;
+    impl->hash = sz_hash_serial;
+    impl->hash_state_init = sz_hash_state_init_serial;
+    impl->hash_state_stream = sz_hash_state_stream_serial;
+    impl->hash_state_fold = sz_hash_state_fold_serial;
 
     impl->find = sz_find_serial;
     impl->rfind = sz_rfind_serial;
@@ -236,7 +246,12 @@ SZ_DYNAMIC void sz_dispatch_table_init(void) {
         impl->move = sz_move_haswell;
         impl->fill = sz_fill_haswell;
         impl->look_up_transform = sz_look_up_transform_haswell;
-        impl->checksum = sz_checksum_haswell;
+
+        impl->bytesum = sz_bytesum_haswell;
+        impl->hash = sz_hash_haswell;
+        impl->hash_state_init = sz_hash_state_init_haswell;
+        impl->hash_state_stream = sz_hash_state_stream_haswell;
+        impl->hash_state_fold = sz_hash_state_fold_haswell;
 
         impl->find_byte = sz_find_byte_haswell;
         impl->rfind_byte = sz_rfind_byte_haswell;
@@ -256,11 +271,17 @@ SZ_DYNAMIC void sz_dispatch_table_init(void) {
         impl->move = sz_move_skylake;
         impl->fill = sz_fill_skylake;
 
+        impl->bytesum = sz_bytesum_skylake;
+        impl->hash = sz_hash_skylake;
+        impl->hash_state_init = sz_hash_state_init_skylake;
+        impl->hash_state_stream = sz_hash_state_stream_skylake;
+        impl->hash_state_fold = sz_hash_state_fold_skylake;
+
         impl->find = sz_find_skylake;
         impl->rfind = sz_rfind_skylake;
         impl->find_byte = sz_find_byte_skylake;
         impl->rfind_byte = sz_rfind_byte_skylake;
-        impl->checksum = sz_checksum_skylake;
+        impl->bytesum = sz_bytesum_skylake;
     }
 #endif
 
@@ -268,10 +289,17 @@ SZ_DYNAMIC void sz_dispatch_table_init(void) {
     if (caps & sz_cap_ice_k) {
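        // Ice Lake adds AVX-512 VBMI, VNNI, and VAES, so charset search, similarity scoring, LUT transforms, and hashing get dedicated kernels below.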
impl->find_from_set = sz_find_charset_ice; impl->rfind_from_set = sz_rfind_charset_ice; + impl->edit_distance = sz_edit_distance_ice; impl->alignment_score = sz_alignment_score_ice; + impl->look_up_transform = sz_look_up_transform_ice; - impl->checksum = sz_checksum_ice; + + impl->bytesum = sz_bytesum_ice; + impl->hash = sz_hash_ice; + impl->hash_state_init = sz_hash_state_init_ice; + impl->hash_state_stream = sz_hash_state_stream_ice; + impl->hash_state_fold = sz_hash_state_fold_ice; } #endif @@ -283,7 +311,12 @@ SZ_DYNAMIC void sz_dispatch_table_init(void) { impl->move = sz_move_neon; impl->fill = sz_fill_neon; impl->look_up_transform = sz_look_up_transform_neon; - impl->checksum = sz_checksum_neon; + + impl->bytesum = sz_bytesum_neon; + impl->hash = sz_hash_neon; + impl->hash_state_init = sz_hash_state_init_neon; + impl->hash_state_stream = sz_hash_state_stream_neon; + impl->hash_state_fold = sz_hash_state_fold_neon; impl->find = sz_find_neon; impl->rfind = sz_rfind_neon; @@ -331,7 +364,23 @@ BOOL WINAPI _DllMainCRTStartup(HINSTANCE hints, DWORD forward_reason, LPVOID lp) __attribute__((constructor)) static void sz_dispatch_table_init_on_gcc_or_clang(void) { sz_dispatch_table_init(); } #endif -SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length) { return sz_dispatch_table.checksum(text, length); } +SZ_DYNAMIC sz_u64_t sz_bytesum(sz_cptr_t text, sz_size_t length) { return sz_dispatch_table.bytesum(text, length); } + +SZ_DYNAMIC sz_u64_t sz_hash(sz_cptr_t text, sz_size_t length, sz_u64_t seed) { + return sz_dispatch_table.hash(text, length, seed); +} + +SZ_DYNAMIC void sz_hash_state_init(sz_hash_state_t *state, sz_u64_t seed) { + sz_dispatch_table.hash_state_init(state, seed); +} + +SZ_DYNAMIC void sz_hash_state_stream(sz_hash_state_t *state, sz_cptr_t text, sz_size_t length) { + sz_dispatch_table.hash_state_stream(state, text, length); +} + +SZ_DYNAMIC sz_u64_t sz_hash_state_fold(sz_hash_state_t const *state) { + return sz_dispatch_table.hash_state_fold(state); +} SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { return sz_dispatch_table.equal(a, b, length); diff --git a/include/stringzilla/stringzilla.h b/include/stringzilla/stringzilla.h index 0b23b33b..349aba79 100644 --- a/include/stringzilla/stringzilla.h +++ b/include/stringzilla/stringzilla.h @@ -43,7 +43,7 @@ #include "compare.h" // `sz_equal`, `sz_order` #include "find.h" // `sz_find`, `sz_find_charset`, `sz_rfind` -#include "hash.h" // `sz_checksum`, `sz_hash`, `sz_hashes` +#include "hash.h" // `sz_bytesum`, `sz_hash`, `sz_state_init`, `sz_state_stream`, `sz_state_fold` #include "memory.h" // `sz_copy`, `sz_move`, `sz_fill` #include "similarity.h" // `sz_edit_distance`, `sz_alignment_score` #include "small_string.h" // `sz_string_t`, `sz_string_init`, `sz_string_free` diff --git a/include/stringzilla/stringzilla.hpp b/include/stringzilla/stringzilla.hpp index 0a4737ad..12f65265 100644 --- a/include/stringzilla/stringzilla.hpp +++ b/include/stringzilla/stringzilla.hpp @@ -1929,7 +1929,7 @@ class basic_string_slice { size_type hash() const noexcept { return static_cast(sz_hash(start_, length_)); } /** @brief Aggregates the values of individual bytes of a string. */ - size_type checksum() const noexcept { return static_cast(sz_checksum(start_, length_)); } + size_type bytesum() const noexcept { return static_cast(sz_bytesum(start_, length_)); } /** @brief Populate a character set with characters present in this string. 
*/ char_set as_set() const noexcept { @@ -3326,33 +3326,30 @@ class basic_string { size_type hash() const noexcept { return view().hash(); } /** @brief Aggregates the values of individual bytes of a string. */ - size_type checksum() const noexcept { return view().checksum(); } + size_type bytesum() const noexcept { return view().bytesum(); } /** - * @brief Overwrites the string with random characters from the given alphabet using the random generator. + * @brief Overwrites the string with random binary data. * - * @param generator A random generator function object that returns a random number in the range [0, 2^64). - * @param alphabet A string of characters to choose from. + * @param nonce "Number used ONCE" to initialize the random number generator, @b don't repeat it! + * @param key A 128-bit key to initialize the AES-CTR block-cypher, zeros by default. */ - template - basic_string &randomize(generator_type &generator, string_view alphabet = "abcdefghijklmnopqrstuvwxyz") noexcept { + basic_string &randomize(sz_u64_t nonce, sz_aes128_block_t key = {}) noexcept { sz_ptr_t start; sz_size_t length; sz_string_range(&string_, &start, &length); - sz_random_generator_t generator_callback = &_call_random_generator; - sz_generate(alphabet.data(), alphabet.size(), start, length, generator_callback, &generator); + sz_generate(start, length, nonce, &key); return *this; } /** - * @brief Overwrites the string with random characters from the given alphabet - * using `std::rand` as the random generator. - * - * @param alphabet A string of characters to choose from. + * @brief Overwrites the string with random binary data. + * Produces the nonce from a static variable, incrementing it each time. + * In this case the undefined behaviour in concurrent environments plays in our favor. */ - basic_string &randomize(string_view alphabet = "abcdefghijklmnopqrstuvwxyz") noexcept { - auto generator = []() { return static_cast(std::rand()); }; - return randomize(generator, alphabet); + basic_string &randomize() noexcept { + static sz_u64_t nonce = 42; + return randomize(nonce++, {}); } /** @@ -3360,25 +3357,19 @@ class basic_string { * May throw exceptions if the memory allocation fails. * * @param length The length of the generated string. - * @param alphabet A string of characters to choose from. + * @param nonce "Number used ONCE" to initialize the random number generator, @b don't repeat it! */ - static basic_string random(size_type length, string_view alphabet = "abcdefghijklmnopqrstuvwxyz") noexcept(false) { - return basic_string(length, '\0').randomize(alphabet); + static basic_string random(size_type length, sz_u64_t nonce) noexcept(false) { + return basic_string(length, '\0').randomize(nonce); } /** * @brief Generate a new random string of given length using the provided random number generator. * May throw exceptions if the memory allocation fails. * - * @param generator A random generator function object that returns a random number in the range [0, 2^64). * @param length The length of the generated string. - * @param alphabet A string of characters to choose from. 
*/ - template - static basic_string random(generator_type &generator, size_type length, - string_view alphabet = "abcdefghijklmnopqrstuvwxyz") noexcept(false) { - return basic_string(length, '\0').randomize(generator, alphabet); - } + static basic_string random(size_type length) noexcept(false) { return basic_string(length, '\0').randomize(); } /** * @brief Replaces ( @b in-place ) all occurrences of a given string with the ::replacement string. diff --git a/include/stringzilla/types.h b/include/stringzilla/types.h index 01d090b2..a3b9d62e 100644 --- a/include/stringzilla/types.h +++ b/include/stringzilla/types.h @@ -443,10 +443,19 @@ SZ_PUBLIC void sz_memory_allocator_init_fixed(sz_memory_allocator_t *alloc, void #pragma region API Signature Types /** @brief Signature of ::sz_hash. */ -typedef sz_u64_t (*sz_hash_t)(sz_cptr_t, sz_size_t); +typedef sz_u64_t (*sz_hash_t)(sz_cptr_t, sz_size_t, sz_u64_t); -/** @brief Signature of ::sz_checksum. */ -typedef sz_u64_t (*sz_checksum_t)(sz_cptr_t, sz_size_t); +/** @brief Signature of ::sz_hash_state_init. */ +typedef void (*sz_hash_state_init_t)(struct sz_hash_state_t *, sz_u64_t); + +/** @brief Signature of ::sz_hash_state_stream. */ +typedef void (*sz_hash_state_stream_t)(struct sz_hash_state_t *, sz_cptr_t, sz_size_t); + +/** @brief Signature of ::sz_hash_state_fold. */ +typedef sz_u64_t (*sz_hash_state_fold_t)(struct sz_hash_state_t const *); + +/** @brief Signature of ::sz_bytesum. */ +typedef sz_u64_t (*sz_bytesum_t)(sz_cptr_t, sz_size_t); /** @brief Signature of ::sz_equal. */ typedef sz_bool_t (*sz_equal_t)(sz_cptr_t, sz_cptr_t, sz_size_t); @@ -887,6 +896,7 @@ SZ_INTERNAL sz_i32_t sz_i32_max_of_two(sz_i32_t x, sz_i32_t y) { return x - ((x #pragma GCC push_options #pragma GCC target("bmi", "bmi2") #pragma clang attribute push(__attribute__((target("bmi,bmi2"))), apply_to = function) +SZ_INTERNAL __mmask8 _sz_u8_mask_until(sz_size_t n) { return (__mmask8)_bzhi_u32(0xFFu, n); } SZ_INTERNAL __mmask16 _sz_u16_mask_until(sz_size_t n) { return (__mmask16)_bzhi_u32(0xFFFFu, n); } SZ_INTERNAL __mmask32 _sz_u32_mask_until(sz_size_t n) { return (__mmask32)_bzhi_u64(0xFFFFFFFFu, n); } SZ_INTERNAL __mmask64 _sz_u64_mask_until(sz_size_t n) { return (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFFull, n); } diff --git a/python/lib.c b/python/lib.c index 1406691c..6e334719 100644 --- a/python/lib.c +++ b/python/lib.c @@ -717,7 +717,7 @@ static PyObject *Str_like_hash(PyObject *self, PyObject *args, PyObject *kwargs) return PyLong_FromUnsignedLongLong((unsigned long long)result); } -static char const doc_like_checksum[] = // +static char const doc_like_bytesum[] = // "Compute the checksum of individual byte values in a string.\n" "\n" "This function can be called as a method on a Str object or as a standalone function.\n" @@ -728,12 +728,12 @@ static char const doc_like_checksum[] = // "Raises:\n" " TypeError: If the argument is not string-like or incorrect number of arguments is provided."; -static PyObject *Str_like_checksum(PyObject *self, PyObject *args, PyObject *kwargs) { +static PyObject *Str_like_bytesum(PyObject *self, PyObject *args, PyObject *kwargs) { // Check minimum arguments int is_member = self != NULL && PyObject_TypeCheck(self, &StrType); Py_ssize_t nargs = PyTuple_Size(args); if (nargs < !is_member || nargs > !is_member + 1 || kwargs) { - PyErr_SetString(PyExc_TypeError, "checksum() expects exactly one positional argument"); + PyErr_SetString(PyExc_TypeError, "bytesum() expects exactly one positional argument"); return NULL; } @@ -746,7 +746,7 @@ 
static PyObject *Str_like_checksum(PyObject *self, PyObject *args, PyObject *kwa return NULL; } - sz_u64_t result = sz_checksum(text.start, text.length); + sz_u64_t result = sz_bytesum(text.start, text.length); return PyLong_FromUnsignedLongLong((unsigned long long)result); } @@ -3684,7 +3684,7 @@ static PyMethodDef stringzilla_methods[] = { // Global unary extensions {"hash", Str_like_hash, SZ_METHOD_FLAGS, doc_like_hash}, - {"checksum", Str_like_checksum, SZ_METHOD_FLAGS, doc_like_checksum}, + {"bytesum", Str_like_bytesum, SZ_METHOD_FLAGS, doc_like_bytesum}, {NULL, NULL, 0, NULL}}; diff --git a/rust/lib.rs b/rust/lib.rs index 08c8772a..07db0a32 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -56,7 +56,7 @@ pub mod sz { fn sz_hash(text: *const c_void, length: usize) -> u64; - fn sz_checksum(text: *const c_void, length: usize) -> u64; + fn sz_bytesum(text: *const c_void, length: usize) -> u64; fn sz_edit_distance( haystack1: *const c_void, @@ -123,21 +123,21 @@ pub mod sz { /// # Returns /// /// A `u64` representing the checksum value of the input byte slice. - pub fn checksum(text: T) -> u64 + pub fn bytesum(text: T) -> u64 where T: AsRef<[u8]>, { let text_ref = text.as_ref(); let text_pointer = text_ref.as_ptr() as _; let text_length = text_ref.len(); - let result = unsafe { sz_checksum(text_pointer, text_length) }; + let result = unsafe { sz_bytesum(text_pointer, text_length) }; return result; } /// Computes a 64-bit AES-based hash value for a given byte slice `text`. /// This function is designed to provide a high-quality hash value for use in /// hash tables, data structures, and cryptographic applications. - /// Unlike the checksum function, the hash function is order-sensitive. + /// Unlike the bytesum function, the hash function is order-sensitive. /// /// # Arguments /// @@ -1034,7 +1034,7 @@ pub trait StringZilla<'a, N> where N: AsRef<[u8]> + 'a, { - /// Computes the checksum value of unsigned bytes in a given string. + /// Computes the bytesum value of unsigned bytes in a given string. /// This function is useful for verifying data integrity and detecting changes in /// binary data, such as files or network packets. /// @@ -1044,14 +1044,14 @@ where /// use stringzilla::StringZilla; /// /// let text = "Hello"; - /// assert_eq!(text.sz_checksum(), Some(500)); + /// assert_eq!(text.sz_bytesum(), Some(500)); /// ``` - fn sz_checksum(&self) -> u64; + fn sz_bytesum(&self) -> u64; /// Computes a 64-bit AES-based hash value for a given string. /// This function is designed to provide a high-quality hash value for use in /// hash tables, data structures, and cryptographic applications. - /// Unlike the checksum function, the hash function is order-sensitive. + /// Unlike the bytesum function, the hash function is order-sensitive. 
/// /// # Examples /// @@ -1352,8 +1352,8 @@ where T: AsRef<[u8]> + ?Sized, N: AsRef<[u8]> + 'a, { - fn sz_checksum(&self) -> u64 { - sz::checksum(self) + fn sz_bytesum(&self) -> u64 { + sz::bytesum(self) } fn sz_hash(&self) -> u64 { diff --git a/scripts/bench_fingerprint.cpp b/scripts/bench_fingerprint.cpp index 82064a29..cbc2812c 100644 --- a/scripts/bench_fingerprint.cpp +++ b/scripts/bench_fingerprint.cpp @@ -90,7 +90,7 @@ void bench(strings_type &&strings) { if (strings.size() == 0) return; // Benchmark logical operations - bench_unary_functions(strings, checksum_functions()); + bench_unary_functions(strings, bytesum_functions()); bench_unary_functions(strings, hashing_functions()); bench_binary_functions(strings, equality_functions()); bench_binary_functions(strings, ordering_functions()); diff --git a/scripts/bench_token.cpp b/scripts/bench_token.cpp index 749daa85..64ba2f96 100644 --- a/scripts/bench_token.cpp +++ b/scripts/bench_token.cpp @@ -11,7 +11,7 @@ using namespace ashvardanian::stringzilla::scripts; -tracked_unary_functions_t checksum_functions() { +tracked_unary_functions_t bytesum_functions() { auto wrap_sz = [](auto function) -> unary_function_t { return unary_function_t([function](std::string_view s) { return function(s.data(), s.size()); }); }; @@ -21,18 +21,18 @@ tracked_unary_functions_t checksum_functions() { return std::accumulate(s.begin(), s.end(), (std::size_t)0, [](std::size_t sum, char c) { return sum + static_cast(c); }); }}, - {"sz_checksum_serial", wrap_sz(sz_checksum_serial), true}, + {"sz_bytesum_serial", wrap_sz(sz_bytesum_serial), true}, #if SZ_USE_HASWELL - {"sz_checksum_haswell", wrap_sz(sz_checksum_haswell), true}, + {"sz_bytesum_haswell", wrap_sz(sz_bytesum_haswell), true}, #endif #if SZ_USE_SKYLAKE - {"sz_checksum_skylake", wrap_sz(sz_checksum_skylake), true}, + {"sz_bytesum_skylake", wrap_sz(sz_bytesum_skylake), true}, #endif #if SZ_USE_ICE - {"sz_checksum_ice", wrap_sz(sz_checksum_ice), true}, + {"sz_bytesum_ice", wrap_sz(sz_bytesum_ice), true}, #endif #if SZ_USE_NEON - {"sz_checksum_neon", wrap_sz(sz_checksum_neon), true}, + {"sz_bytesum_neon", wrap_sz(sz_bytesum_neon), true}, #endif }; return result; @@ -139,7 +139,7 @@ void bench(strings_type &&strings) { if (strings.size() == 0) return; // Benchmark logical operations - bench_unary_functions(strings, checksum_functions()); + bench_unary_functions(strings, bytesum_functions()); bench_unary_functions(strings, hashing_functions()); bench_binary_functions(strings, equality_functions()); bench_binary_functions(strings, ordering_functions()); diff --git a/scripts/test.cpp b/scripts/test.cpp index 7b3fe4db..58752a35 100644 --- a/scripts/test.cpp +++ b/scripts/test.cpp @@ -802,17 +802,17 @@ static void test_non_stl_extensions_for_reads() { return std::accumulate(s.begin(), s.end(), (std::size_t)0, [](std::size_t sum, char c) { return sum + static_cast(c); }); }; - assert(str("a").checksum() == (std::size_t)'a'); - assert(str("0").checksum() == (std::size_t)'0'); - assert(str("0123456789").checksum() == arithmetic_sum('0', '9')); - assert(str("abcdefghijklmnopqrstuvwxyz").checksum() == arithmetic_sum('a', 'z')); - assert(str("abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz").checksum() == + assert(str("a").bytesum() == (std::size_t)'a'); + assert(str("0").bytesum() == (std::size_t)'0'); + assert(str("0123456789").bytesum() == arithmetic_sum('0', '9')); + assert(str("abcdefghijklmnopqrstuvwxyz").bytesum() == arithmetic_sum('a', 'z')); + 
assert(str("abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz").bytesum() == arithmetic_sum('a', 'z') * 3); assert_scoped( str s = "近来,加文出席微博之夜时对着镜头频繁摆出假笑表情、一度累瘫睡倒在沙发上的照片被广泛转发,引发对他失去童年、" "被过度消费的担忧。八岁的加文,已当网红近六年了,可以说,自懂事以来,他没有过过一天没有名气的日子。", - (void)0, s.checksum() == accumulate_bytes(s)); + (void)0, s.bytesum() == accumulate_bytes(s)); // Computing edit-distances. assert(sz::hamming_distance(str("hello"), str("hello")) == 0); diff --git a/scripts/test.py b/scripts/test.py index 93a01706..ea95e8d4 100644 --- a/scripts/test.py +++ b/scripts/test.py @@ -777,12 +777,12 @@ def test_translations_random(length: int): @pytest.mark.repeat(3) @pytest.mark.parametrize("length", list(range(0, 300)) + [1024, 4096, 100000]) -def test_checksums_random(length: int): +def test_bytesums_random(length: int): def sum_bytes(body: str) -> int: return sum([ord(c) for c in body]) body = get_random_string(length=length) - assert sum_bytes(body) == sz.checksum(body) + assert sum_bytes(body) == sz.bytesum(body) @pytest.mark.parametrize("list_length", [10, 20, 30, 40, 50]) From cb18c787c0b413a8ce1422222f2a0a68f9e102dc Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Thu, 20 Feb 2025 22:46:36 +0000 Subject: [PATCH 110/751] Add: AES-based hash placeholders --- README.md | 14 +- c/lib.c | 30 +- drafts/fingerprint.h | 13 +- drafts/sort.h | 124 +++++ include/stringzilla/find.h | 17 +- include/stringzilla/hash.h | 760 +++++++++++++++++++++------- include/stringzilla/sort.h | 217 ++++++-- include/stringzilla/stringzilla.hpp | 85 ++-- include/stringzilla/types.h | 5 +- rust/lib.rs | 14 +- scripts/bench_token.cpp | 26 +- scripts/test.cpp | 7 +- 12 files changed, 978 insertions(+), 334 deletions(-) create mode 100644 drafts/sort.h diff --git a/README.md b/README.md index c657decf..22c8e2b0 100644 --- a/README.md +++ b/README.md @@ -622,14 +622,22 @@ Both are companions of the `sz_find`, first for x86 CPUs with AVX-512 support, a sz_string_view_t haystack = {your_text, your_text_length}; sz_string_view_t needle = {your_subtext, your_subtext_length}; -// Perform string-level operations +// Perform string-level operations auto-picking the backend or dispatching manually sz_size_t substring_position = sz_find(haystack.start, haystack.length, needle.start, needle.length); sz_size_t substring_position = sz_find_skylake(haystack.start, haystack.length, needle.start, needle.length); sz_size_t substring_position = sz_find_haswell(haystack.start, haystack.length, needle.start, needle.length); sz_size_t substring_position = sz_find_neon(haystack.start, haystack.length, needle.start, needle.length); -// Hash strings -sz_u64_t hash = sz_hash(haystack.start, haystack.length, 42); // or any other seed ;) +// Hash strings at once +sz_u64_t hash = sz_hash(haystack.start, haystack.length, 42); // 42 is the seed +sz_u64_t checksum = sz_bytesum(haystack.start, haystack.length); // or accumulate byte values + +// Hash strings incrementally with "init", "stream", and "fold": +sz_hash_state_t state; +sz_hash_state_init(&state, 42); +sz_hash_state_stream(&state, haystack.start, 1); // first char +sz_hash_state_stream(&state, haystack.start + 1, haystack.length - 1); // rest of the string +sz_u64_t hash = sz_hash_state_fold(&state); // Perform collection level operations sz_sequence_t array = {your_handle, your_count, your_get_start, your_get_length}; diff --git a/c/lib.c b/c/lib.c index c68e7a1f..559062ba 100644 --- a/c/lib.c +++ b/c/lib.c @@ -183,6 +183,7 @@ typedef struct 
sz_implementations_t { sz_hash_state_init_t hash_state_init; sz_hash_state_stream_t hash_state_stream; sz_hash_state_fold_t hash_state_fold; + sz_generate_t generate; sz_find_byte_t find_byte; sz_find_byte_t rfind_byte; @@ -225,6 +226,7 @@ SZ_DYNAMIC void sz_dispatch_table_init(void) { impl->hash_state_init = sz_hash_state_init_serial; impl->hash_state_stream = sz_hash_state_stream_serial; impl->hash_state_fold = sz_hash_state_fold_serial; + impl->generate = sz_generate_serial; impl->find = sz_find_serial; impl->rfind = sz_rfind_serial; @@ -252,6 +254,7 @@ SZ_DYNAMIC void sz_dispatch_table_init(void) { impl->hash_state_init = sz_hash_state_init_haswell; impl->hash_state_stream = sz_hash_state_stream_haswell; impl->hash_state_fold = sz_hash_state_fold_haswell; + impl->generate = sz_generate_haswell; impl->find_byte = sz_find_byte_haswell; impl->rfind_byte = sz_rfind_byte_haswell; @@ -276,6 +279,7 @@ SZ_DYNAMIC void sz_dispatch_table_init(void) { impl->hash_state_init = sz_hash_state_init_skylake; impl->hash_state_stream = sz_hash_state_stream_skylake; impl->hash_state_fold = sz_hash_state_fold_skylake; + impl->generate = sz_generate_skylake; impl->find = sz_find_skylake; impl->rfind = sz_rfind_skylake; @@ -300,6 +304,7 @@ SZ_DYNAMIC void sz_dispatch_table_init(void) { impl->hash_state_init = sz_hash_state_init_ice; impl->hash_state_stream = sz_hash_state_stream_ice; impl->hash_state_fold = sz_hash_state_fold_ice; + impl->generate = sz_generate_ice; } #endif @@ -317,6 +322,7 @@ SZ_DYNAMIC void sz_dispatch_table_init(void) { impl->hash_state_init = sz_hash_state_init_neon; impl->hash_state_stream = sz_hash_state_stream_neon; impl->hash_state_fold = sz_hash_state_fold_neon; + impl->generate = sz_generate_neon; impl->find = sz_find_neon; impl->rfind = sz_rfind_neon; @@ -382,6 +388,10 @@ SZ_DYNAMIC sz_u64_t sz_hash_state_fold(sz_hash_state_t const *state) { return sz_dispatch_table.hash_state_fold(state); } +SZ_DYNAMIC void sz_generate(sz_ptr_t result, sz_size_t result_length, sz_u64_t nonce) { + sz_dispatch_table.generate(result, result_length, nonce); +} + SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { return sz_dispatch_table.equal(a, b, length); } @@ -499,22 +509,6 @@ SZ_DYNAMIC sz_cptr_t sz_rfind_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_ return sz_rfind_charset(h, h_length, &set); } -#if !SZ_AVOID_LIBC -sz_u64_t _sz_random_generator(void *empty_state) { - sz_unused(empty_state); - return (sz_u64_t)rand(); -} -#endif - -SZ_DYNAMIC void sz_generate( // - sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length, - sz_random_generator_t generator, void *generator_user_data) { -#if !SZ_AVOID_LIBC - if (!generator) generator = _sz_random_generator; -#endif - sz_generate_serial(alphabet, alphabet_size, result, result_length, generator, generator_user_data); -} - // Provide overrides for the libc mem* functions #if SZ_OVERRIDE_LIBC && !defined(__CYGWIN__) @@ -591,8 +585,8 @@ SZ_DYNAMIC void *memrchr(void const *s, int c_wide, size_t n) { } SZ_DYNAMIC void memfrob(void *s, size_t n) { - char const *base64 = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; - sz_generate(base64, 64, s, n, SZ_NULL, SZ_NULL); + static sz_u64_t nonce = 42; + sz_generate(s, n, nonce++); } #endif diff --git a/drafts/fingerprint.h b/drafts/fingerprint.h index 9cdfcc5e..a442d7a0 100644 --- a/drafts/fingerprint.h +++ b/drafts/fingerprint.h @@ -1,17 +1,8 @@ /** - * @brief Hardware-accelerated string hashing and checksums. 
+ * @brief Hardware-accelerated rolling string hashes or fingerprints. * @file hash.h * @author Ash Vardanian - * - * Includes core APIs: - * - * - `sz_checksum` - for byte-level checksums. - * - `sz_hash` - for 64-bit single-shot hashing. - * - `sz_hashes` - producing the rolling hashes of a string. - * - `sz_generate` - populating buffers with random data. - * - * Convenience functions for character-set matching: - * + * - `sz_hashes_fingerprint` * - `sz_hashes_intersection` */ diff --git a/drafts/sort.h b/drafts/sort.h new file mode 100644 index 00000000..bc1bb34e --- /dev/null +++ b/drafts/sort.h @@ -0,0 +1,124 @@ + + +/** + * @brief Perform a compare–exchange (compare–swap) on two 8‑lane vectors, + * updating both the keys and their associated offsets. + * + * @param keys Pointer to a __m512i containing 8 keys. + * @param offsets Pointer to a __m512i containing 8 offsets. + * @param perm Permutation vector (as __m512i) that maps each lane + * to its “partner” in the compare–exchange. + * @param fixed_mask An 8‑bit immediate mask (as __mmask8) that indicates, + * for each pair, which lane is designated as the “upper” + * element. For that lane the max is chosen, while for the + * complementary (“lower”) lane the min is chosen. + * + * This helper function “mirrors” the scalar operation: + * + * if (keys[i] > keys[j]) { + * swap(keys[i], keys[j]); + * swap(offsets[i], offsets[j]); + * } + * + * for each pair (i,j) defined by the permutation vector. + * + * The keys are updated by computing the unsigned min and max between each + * element and its partner, and then blending them into the designated positions + * using the fixed_mask. In order to update the offsets in a stable manner, + * we first compute the partner offsets (using the same permutation), then for each + * pair we choose: + * + * - For the lane designated as lower (mask bit = 0): + * if (orig_key <= partner_key) then keep self’s offset, + * else take the partner’s offset. + * + * - For the lane designated as upper (mask bit = 1): + * if (orig_key > partner_key) then keep self’s offset, + * else take the partner’s offset. + * + * This ensures that if keys are equal (thus stable), no swap is done. + */ +SZ_INTERNAL void cswap_argsort_avx512(__m512i *pgrams, __m512i *offsets, __m512i perm, __mmask8 fixed_mask) { + // Save original pgrams and offsets for condition computation. + __m512i orig_pgrams = *pgrams; + __m512i orig_offsets = *offsets; + + // Compute partner vectors using the permutation vector. + __m512i partner_pgrams = _mm512_permutexvar_epi64(perm, orig_pgrams); + __m512i partner_offsets = _mm512_permutexvar_epi64(perm, orig_offsets); + + // Compute new pgrams: for each pair, choose the unsigned min for the lower lane + // and the unsigned max for the upper lane. + __m512i pgrams_min = _mm512_min_epu64(orig_pgrams, partner_pgrams); + __m512i pgrams_max = _mm512_max_epu64(orig_pgrams, partner_pgrams); + *pgrams = _mm512_mask_blend_epi64(fixed_mask, pgrams_min, pgrams_max); + + // For offsets, we want to mimic the swap decision used for pgrams. + // For each pair (i,j) (with i < j), if orig_pgrams[i] <= partner_pgrams[i] then + // the lower key came from the current lane (i) and the upper from the partner (j); + // otherwise the lower key came from the partner. + __mmask8 lower_cond = + _mm512_cmp_epu64_mask(orig_pgrams, partner_pgrams, _MM_CMPINT_LE); // true if no swap needed for lower lane. 
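+    // These masks later drive `_mm512_mask_blend_epi64`, which takes the second operand where a mask bit is set and the first where it is clear.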
+ __mmask8 upper_cond = + _mm512_cmp_epu64_mask(orig_pgrams, partner_pgrams, _MM_CMPINT_GT); // true if swap needed for upper lane. + + // Compute offsets for lower positions (fixed_mask bit = 0): + // If lower_cond is true, then the current lane’s offset is correct; + // otherwise, use the partner’s offset. + __m512i offsets_lower = _mm512_mask_blend_epi64(lower_cond, partner_offsets, orig_offsets); + + // Compute offsets for upper positions (fixed_mask bit = 1): + // If upper_cond is true, then keep the current lane’s offset; + // otherwise, use the partner’s offset. + __m512i offsets_upper = _mm512_mask_blend_epi64(upper_cond, orig_offsets, partner_offsets); + + // Combine the two sets: for lanes designated as lower (mask bit = 0) use offsets_lower; + // for lanes designated as upper (mask bit = 1) use offsets_upper. + *offsets = _mm512_mask_blend_epi64(fixed_mask, offsets_lower, offsets_upper); + + // Validate the sorting network. + if (SZ_DEBUG) { + sz_pgram_t pgrams_array[8]; + sz_sorted_idx_t offsets_array[8]; + _mm512_storeu_si512(pgrams_array, *pgrams); + _mm512_storeu_si512(offsets_array, *offsets); + for (sz_size_t i = 1; i < 8; ++i) + _sz_assert(pgrams_array[i - 1] <= pgrams_array[i] && + "The sorting network must sort the pgrams in ascending order."); + } +} + +SZ_PUBLIC void _sz_sequence_argsort_ice_recursively( // + sz_sequence_t const *const collection, // + sz_pgram_t *const global_pgrams, sz_size_t *const global_order, // + sz_size_t const start_in_sequence, sz_size_t const end_in_sequence, // + sz_size_t const start_character) { + + // Prepare the new range of windows + _sz_sequence_argsort_serial_export_next_pgrams(collection, global_pgrams, global_order, start_in_sequence, + end_in_sequence, start_character); + + // We can implement a form of a Radix sort here, that will count the number of elements with + // a certain bit set. The naive approach may require too many loops over data. A more "vectorized" + // approach would be to maintain a histogram for several bits at once. For 4 bits we will + // need 2^4 = 16 counters. + sz_size_t histogram[16] = {0}; + for (sz_size_t byte_in_window = 0; byte_in_window != sizeof(sz_pgram_t); ++byte_in_window) { + // First sort based on the low nibble of each byte. + for (sz_size_t i = start_in_sequence; i < end_in_sequence; ++i) { + sz_size_t const byte = (global_pgrams[i] >> (byte_in_window * 8)) & 0xFF; + ++histogram[byte]; + } + sz_size_t offset = start_in_sequence; + for (sz_size_t i = 0; i != 16; ++i) { + sz_size_t const count = histogram[i]; + histogram[i] = offset; + offset += count; + } + for (sz_size_t i = start_in_sequence; i < end_in_sequence; ++i) { + sz_size_t const byte = (global_pgrams[i] >> (byte_in_window * 8)) & 0xFF; + global_order[histogram[byte]] = i; + ++histogram[byte]; + } + } +} diff --git a/include/stringzilla/find.h b/include/stringzilla/find.h index b5740429..90b6a16f 100644 --- a/include/stringzilla/find.h +++ b/include/stringzilla/find.h @@ -5,7 +5,6 @@ * * Includes core APIs: * - * - `sz_equal` * - `sz_find` and reverse-order `sz_rfind` * - `sz_find_byte` and reverse-order `sz_rfind_byte` * - `sz_find_charset` and reverse-order `sz_rfind_charset` @@ -138,10 +137,10 @@ SZ_PUBLIC sz_cptr_t sz_rfind_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cpt * May have identical implementation and performance to ::sz_rfind_charset. * * Useful for parsing, when we want to skip a set of characters. Examples: - * * 6 whitespaces: " \t\n\r\v\f". - * * 16 digits forming a float number: "0123456789,.eE+-". 
- * * 5 HTML reserved characters: "\"'&<>", of which "<>" can be useful for parsing.
- * * 2 JSON string special characters useful to locate the end of the string: "\"\\".
+ * - 6 whitespaces: " \t\n\r\v\f".
+ * - 16 digits forming a float number: "0123456789,.eE+-".
+ * - 5 HTML reserved characters: "\"'&<>", of which "<>" can be useful for parsing.
+ * - 2 JSON string special characters useful to locate the end of the string: "\"\\".
  *
  * @param text String to be scanned.
  * @param set Set of relevant characters.
@@ -155,10 +154,10 @@ SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charse
  * May have identical implementation and performance to ::sz_find_charset.
  *
  * Useful for parsing, when we want to skip a set of characters. Examples:
- * * 6 whitespaces: " \t\n\r\v\f".
- * * 16 digits forming a float number: "0123456789,.eE+-".
- * * 5 HTML reserved characters: "\"'&<>", of which "<>" can be useful for parsing.
- * * 2 JSON string special characters useful to locate the end of the string: "\"\\".
+ * - 6 whitespaces: " \t\n\r\v\f".
+ * - 16 digits forming a float number: "0123456789,.eE+-".
+ * - 5 HTML reserved characters: "\"'&<>", of which "<>" can be useful for parsing.
+ * - 2 JSON string special characters useful to locate the end of the string: "\"\\".
  *
  * @param text String to be scanned.
  * @param set Set of relevant characters.
diff --git a/include/stringzilla/hash.h b/include/stringzilla/hash.h
index 415b4b67..2094a0d3 100644
--- a/include/stringzilla/hash.h
+++ b/include/stringzilla/hash.h
@@ -1,13 +1,55 @@
 /**
- * @brief  Hardware-accelerated string hashing and checksums.
+ * @brief  Hardware-accelerated non-cryptographic string hashing and checksums.
  * @file   hash.h
 *  @author Ash Vardanian
 *
 *  Includes core APIs:
 *
- * - `sz_checksum` - for byte-level 64-bit unsigned checksums.
- * - `sz_hash` - for 64-bit single-shot hashing.
- * - `sz_generate` - populating buffers with random data.
+ * - `sz_bytesum` - for byte-level 64-bit unsigned checksums.
+ * - `sz_hash` - for 64-bit single-shot hashing using AES instructions.
+ * - `sz_hash_state_init`, `sz_hash_state_stream`, `sz_hash_state_fold` - for incremental hashing.
+ * - `sz_generate` - for populating buffers with pseudo-random noise using AES instructions.
+ *
+ * Why the hell do we need yet another hashing library?!
+ * Turns out, most existing libraries have noticeable constraints. Try finding a library that:
+ *
+ * - Outputs 64-bit or 128-bit hashes and passes the SMHasher test suite.
+ * - Is fast for both short and long strings.
+ * - Supports incremental @b (streaming) hashing, when the data arrives in chunks.
+ * - Supports custom seeds for hashes and secret strings for security.
+ * - Provides dynamic dispatch for different architectures to simplify deployment.
+ * - Uses modern SIMD, including not just AVX2 and NEON, but also AVX-512 and SVE2.
+ * - Documents its logic and guarantees the same output across different platforms.
+ *
+ * This includes projects like "MurmurHash", "CityHash", "SpookyHash", "FarmHash", "MetroHash", "HighwayHash", etc.
+ * There are 2 libraries that are close to meeting these requirements: "xxHash" in C++ and "aHash" in Rust:
+ *
+ * - "aHash" is fast, but written in Rust, has no dynamic dispatch, and lacks AVX-512 and SVE2 support.
+ *   It also does not adhere to a fixed output, and can't be used in applications like computing packet checksums
+ *   in network traffic or implementing persistent data structures.
+ * + * - "xxHash" is implemented in C, has an extremely wide set of third-party language bindings, and provides both + * 32-, 64-, and 128-bit hashes. It is fast, but its dynamic dispatch is limited to x86 with `xxh_x86dispatch.c`. + * + * StringZilla uses a scheme more similar to the "aHash" library, utilizing the AES extensions, that provide + * a remarkable level of "mixing per cycle" and are broadly available on modern CPUs. Similar to "aHash", they + * are combined with "shuffle & add" instructions to provide a high level of entropy in the output. That operation + * is practically free, as many modern CPUs will dispatch them on different ports. On x86, for example: + * + * - `VAESDEC` (ZMM, ZMM, ZMM)`: + * - on Intel Ice Lake: 5 cycles on port 0. + * - On AMD Zen4: 4 cycles on ports 0 or 1. + * - `VPSHUFB_Z (ZMM, K, ZMM, ZMM)` + * - on Intel Ice Lake: 3 cycles on port 5. + * - On AMD Zen4: 2 cycles on ports 1 or 2. + * - `VPADDQ (ZMM, ZMM, ZMM)`: + * - on Intel Ice Lake: 1 cycle on ports 0 or 5. + * - On AMD Zen4: 1 cycle on ports 0, 1, 2, 3. + * + * Unlike "aHash", on long inputs, we use a procedure that is more vector-friendly on modern servers. + * Unlike "aHash", we don't load interleaved memory regions, making vectorized variant more similar to sequential. + * On platforms like Skylake-X or newer, we also benefit from masked loads. + * */ #ifndef STRINGZILLA_HASH_H_ #define STRINGZILLA_HASH_H_ @@ -28,196 +70,205 @@ extern "C" { * @param length Number of bytes in the text. * @return 64-bit unsigned value. */ -SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length); +SZ_DYNAMIC sz_u64_t sz_bytesum(sz_cptr_t text, sz_size_t length); /** - * @brief Computes the 64-bit unsigned hash of a string. Fairly fast for short strings, - * simple implementation, and supports rolling computation, reused in other APIs. - * Similar to `std::hash` in C++. + * @brief Computes the 64-bit unsigned hash of a string similar to @b `std::hash` in C++. + * It's not cryptographically secure, but it's fast and provides a good distribution. + * It passes the SMHasher suite by Austin Appleby with no collisions, even with `--extra` flag. + * @see HASH.md for a detailed explanation of the algorithm. * * @param text String to hash. * @param length Number of bytes in the text. + * @param seed 64-bit unsigned seed for the hash. * @return 64-bit hash value. */ -SZ_PUBLIC sz_u64_t sz_hash(sz_cptr_t text, sz_size_t length) { - sz_unused(text && length); - return 0; -} +SZ_DYNAMIC sz_u64_t sz_hash(sz_cptr_t text, sz_size_t length, sz_u64_t seed); + +/** + * @brief A Pseudorandom Number Generator (PRNG), inspired the AES-CTR-128 algorithm, + * but using only one round of AES mixing as opposed to "NIST SP 800-90A". + * + * CTR_DRBG (CounTeR mode Deterministic Random Bit Generator) appears secure and indistinguishable from a true + * random source when AES is used as the underlying block cipher and 112 bits are taken from this PRNG. + * When AES is used as the underlying block cipher and 128 bits are taken from each instantiation, + * the required security level is delivered with the caveat that a 128-bit cipher's output in + * counter mode can be distinguished from a true RNG. + * + * In this case, it doesn't apply, as we only use one round of AES mixing. We also don't expose a separate "key", + * only a "nonce", to keep the API simple. + * + * @param text Output string buffer to be populated. + * @param length Number of bytes in the string. 
+ * @param nonce "Number used ONCE" to ensure uniqueness of produced blocks. + */ +SZ_DYNAMIC void sz_generate(sz_ptr_t text, sz_size_t length, sz_u64_t nonce); + +/** + * @brief The state for incremental construction of a hash. + * @see sz_hash_state_init, sz_hash_state_stream, sz_hash_state_fold. + */ +typedef struct sz_hash_state_t { + sz_u512_vec_t aes; + sz_u512_vec_t sum; + sz_u512_vec_t key; + + sz_u512_vec_t ins; + sz_size_t ins_length; +} sz_hash_state_t; + +typedef struct _sz_hash_minimal_t { + sz_u128_vec_t aes; + sz_u128_vec_t sum; + sz_u128_vec_t key; +} _sz_hash_minimal_t; + +/** + * @brief Initializes the state for incremental construction of a hash. + * + * @param state The state to initialize. + * @param seed The 64-bit unsigned seed for the hash. + */ +SZ_DYNAMIC void sz_hash_state_init(sz_hash_state_t *state, sz_u64_t seed); /** - * @brief Generates a random string for a given alphabet, avoiding integer division and modulo operations. - * Similar to `text[i] = alphabet[rand() % cardinality]`. + * @brief Updates the state with new data. * - * The modulo operation is expensive, and should be avoided in performance-critical code. - * We avoid it using small lookup tables and replacing it with a multiplication and shifts, similar to `libdivide`. - * Alternative algorithms would include: - * - Montgomery form: https://en.algorithmica.org/hpc/number-theory/montgomery/ - * - Barret reduction: https://www.nayuki.io/page/barrett-reduction-algorithm - * - Lemire's trick: https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/ + * @param state The state to stream. + * @param text The new data to include in the hash. + * @param length The number of bytes in the new data. + */ +SZ_DYNAMIC void sz_hash_state_stream(sz_hash_state_t *state, sz_cptr_t text, sz_size_t length); + +/** + * @brief Finalizes the state and returns the hash. * - * @param alphabet Set of characters to sample from. - * @param cardinality Number of characters to sample from. - * @param text Output string, can point to the same address as ::text. - * @param generate Callback producing random numbers given the generator state. - * @param generator Generator state, can be a pointer to a seed, or a pointer to a random number generator. + * @param state The state to fold. + * @return The 64-bit hash value. 
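+ *
+ * A minimal usage sketch (illustrative only, with an arbitrary seed of 42): streaming the same bytes
+ * in chunks with the same seed is expected to match the single-shot `sz_hash` of the whole string.
+ *
+ *      sz_hash_state_t state;
+ *      sz_hash_state_init(&state, 42);
+ *      sz_hash_state_stream(&state, "Hello, ", 7);
+ *      sz_hash_state_stream(&state, "world!", 6);
+ *      sz_u64_t streamed = sz_hash_state_fold(&state); // expected to equal sz_hash("Hello, world!", 13, 42)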
*/ -SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length, - sz_random_generator_t generate, void *generator); +SZ_DYNAMIC sz_u64_t sz_hash_state_fold(sz_hash_state_t const *state); -/** @copydoc sz_checksum */ -SZ_PUBLIC sz_u64_t sz_checksum_serial(sz_cptr_t text, sz_size_t length); +/** @copydoc sz_bytesum */ +SZ_PUBLIC sz_u64_t sz_bytesum_serial(sz_cptr_t text, sz_size_t length); /** @copydoc sz_hash */ -SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t text, sz_size_t length); +SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t text, sz_size_t length, sz_u64_t seed); /** @copydoc sz_generate */ -SZ_PUBLIC void sz_generate_serial( // - sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length, sz_random_generator_t generate, - void *generator) { - sz_unused(alphabet && cardinality && text && length && generate && generator); -} +SZ_PUBLIC void sz_generate_serial(sz_ptr_t text, sz_size_t length, sz_u64_t nonce); + +/** @copydoc sz_hash_state_init */ +SZ_PUBLIC void sz_hash_state_init_serial(sz_hash_state_t *state, sz_u64_t seed); + +/** @copydoc sz_hash_state_stream */ +SZ_PUBLIC void sz_hash_state_stream_serial(sz_hash_state_t *state, sz_cptr_t text, sz_size_t length); + +/** @copydoc sz_hash_state_fold */ +SZ_PUBLIC sz_u64_t sz_hash_state_fold_serial(sz_hash_state_t const *state); + +/** @copydoc sz_bytesum */ +SZ_PUBLIC sz_u64_t sz_bytesum_haswell(sz_cptr_t text, sz_size_t length); + +/** @copydoc sz_hash */ +SZ_PUBLIC sz_u64_t sz_hash_haswell(sz_cptr_t text, sz_size_t length, sz_u64_t seed); + +/** @copydoc sz_generate */ +SZ_PUBLIC void sz_generate_haswell(sz_ptr_t text, sz_size_t length, sz_u64_t nonce); + +/** @copydoc sz_hash_state_init */ +SZ_PUBLIC void sz_hash_state_init_haswell(sz_hash_state_t *state, sz_u64_t seed); + +/** @copydoc sz_hash_state_stream */ +SZ_PUBLIC void sz_hash_state_stream_haswell(sz_hash_state_t *state, sz_cptr_t text, sz_size_t length); + +/** @copydoc sz_hash_state_fold */ +SZ_PUBLIC sz_u64_t sz_hash_state_fold_haswell(sz_hash_state_t const *state); + +/** @copydoc sz_bytesum */ +SZ_PUBLIC sz_u64_t sz_bytesum_skylake(sz_cptr_t text, sz_size_t length); + +/** @copydoc sz_hash */ +SZ_PUBLIC sz_u64_t sz_hash_skylake(sz_cptr_t text, sz_size_t length, sz_u64_t seed); + +/** @copydoc sz_generate */ +SZ_PUBLIC void sz_generate_skylake(sz_ptr_t text, sz_size_t length, sz_u64_t nonce); + +/** @copydoc sz_hash_state_init */ +SZ_PUBLIC void sz_hash_state_init_skylake(sz_hash_state_t *state, sz_u64_t seed); + +/** @copydoc sz_hash_state_stream */ +SZ_PUBLIC void sz_hash_state_stream_skylake(sz_hash_state_t *state, sz_cptr_t text, sz_size_t length); + +/** @copydoc sz_hash_state_fold */ +SZ_PUBLIC sz_u64_t sz_hash_state_fold_skylake(sz_hash_state_t const *state); + +/** @copydoc sz_bytesum */ +SZ_PUBLIC sz_u64_t sz_bytesum_ice(sz_cptr_t text, sz_size_t length); + +/** @copydoc sz_hash */ +SZ_PUBLIC sz_u64_t sz_hash_ice(sz_cptr_t text, sz_size_t length, sz_u64_t seed); + +/** @copydoc sz_generate */ +SZ_PUBLIC void sz_generate_ice(sz_ptr_t text, sz_size_t length, sz_u64_t nonce); + +/** @copydoc sz_hash_state_init */ +SZ_PUBLIC void sz_hash_state_init_ice(sz_hash_state_t *state, sz_u64_t seed); + +/** @copydoc sz_hash_state_stream */ +SZ_PUBLIC void sz_hash_state_stream_ice(sz_hash_state_t *state, sz_cptr_t text, sz_size_t length); + +/** @copydoc sz_hash_state_fold */ +SZ_PUBLIC sz_u64_t sz_hash_state_fold_ice(sz_hash_state_t const *state); + +/** @copydoc sz_bytesum */ +SZ_PUBLIC sz_u64_t 
sz_bytesum_neon(sz_cptr_t text, sz_size_t length); + +/** @copydoc sz_hash */ +SZ_PUBLIC sz_u64_t sz_hash_neon(sz_cptr_t text, sz_size_t length, sz_u64_t seed); + +/** @copydoc sz_generate */ +SZ_PUBLIC void sz_generate_neon(sz_ptr_t text, sz_size_t length, sz_u64_t nonce); + +/** @copydoc sz_hash_state_init */ +SZ_PUBLIC void sz_hash_state_init_neon(sz_hash_state_t *state, sz_u64_t seed); + +/** @copydoc sz_hash_state_stream */ +SZ_PUBLIC void sz_hash_state_stream_neon(sz_hash_state_t *state, sz_cptr_t text, sz_size_t length); + +/** @copydoc sz_hash_state_fold */ +SZ_PUBLIC sz_u64_t sz_hash_state_fold_neon(sz_hash_state_t const *state); #pragma endregion // Core API #pragma region Serial Implementation -SZ_PUBLIC sz_u64_t sz_checksum_serial(sz_cptr_t text, sz_size_t length) { - sz_u64_t checksum = 0; +SZ_PUBLIC sz_u64_t sz_bytesum_serial(sz_cptr_t text, sz_size_t length) { + sz_u64_t bytesum = 0; sz_u8_t const *text_u8 = (sz_u8_t const *)text; sz_u8_t const *text_end = text_u8 + length; - for (; text_u8 != text_end; ++text_u8) checksum += *text_u8; - return checksum; + for (; text_u8 != text_end; ++text_u8) bytesum += *text_u8; + return bytesum; } -/* - * One hardware-accelerated way of mixing hashes can be CRC, but it's only implemented for 32-bit values. - * Using a Boost-like mixer works very poorly in such case: - * - * hash_first ^ (hash_second + 0x517cc1b727220a95 + (hash_first << 6) + (hash_first >> 2)); - * - * Let's stick to the Fibonacci hash trick using the golden ratio. - * https://probablydance.com/2018/06/16/fibonacci-hashing-the-optimization-that-the-world-forgot-or-a-better-alternative-to-integer-modulo/ - */ -#define _sz_hash_mix(first, second) ((first * 11400714819323198485ull) ^ (second * 11400714819323198485ull)) -#define _sz_shift_low(x) (x) -#define _sz_shift_high(x) ((x + 77ull) & 0xFFull) -#define _sz_prime_mod(x) (x % SZ_U64_MAX_PRIME) - -SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t start, sz_size_t length) { - - sz_u64_t hash_low = 0; - sz_u64_t hash_high = 0; - sz_u8_t const *text = (sz_u8_t const *)start; - sz_u8_t const *text_end = text + length; - - switch (length) { - case 0: return 0; - - // Texts under 7 bytes long are definitely below the largest prime. 
- case 1: - hash_low = _sz_shift_low(text[0]); - hash_high = _sz_shift_high(text[0]); - break; - case 2: - hash_low = _sz_shift_low(text[0]) * 31ull + _sz_shift_low(text[1]); - hash_high = _sz_shift_high(text[0]) * 257ull + _sz_shift_high(text[1]); - break; - case 3: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull + // - _sz_shift_low(text[2]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull + // - _sz_shift_high(text[2]); - break; - case 4: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull + // - _sz_shift_low(text[3]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull + // - _sz_shift_high(text[3]); - break; - case 5: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull * 31ull + // - _sz_shift_low(text[3]) * 31ull + // - _sz_shift_low(text[4]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull * 257ull + // - _sz_shift_high(text[3]) * 257ull + // - _sz_shift_high(text[4]); - break; - case 6: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[3]) * 31ull * 31ull + // - _sz_shift_low(text[4]) * 31ull + // - _sz_shift_low(text[5]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[3]) * 257ull * 257ull + // - _sz_shift_high(text[4]) * 257ull + // - _sz_shift_high(text[5]); - break; - case 7: - hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[1]) * 31ull * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[2]) * 31ull * 31ull * 31ull * 31ull + // - _sz_shift_low(text[3]) * 31ull * 31ull * 31ull + // - _sz_shift_low(text[4]) * 31ull * 31ull + // - _sz_shift_low(text[5]) * 31ull + // - _sz_shift_low(text[6]); - hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[1]) * 257ull * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[2]) * 257ull * 257ull * 257ull * 257ull + // - _sz_shift_high(text[3]) * 257ull * 257ull * 257ull + // - _sz_shift_high(text[4]) * 257ull * 257ull + // - _sz_shift_high(text[5]) * 257ull + // - _sz_shift_high(text[6]); - break; - default: - // Unroll the first seven cycles: - hash_low = hash_low * 31ull + _sz_shift_low(text[0]); - hash_high = hash_high * 257ull + _sz_shift_high(text[0]); - hash_low = hash_low * 31ull + _sz_shift_low(text[1]); - hash_high = hash_high * 257ull + _sz_shift_high(text[1]); - hash_low = hash_low * 31ull + _sz_shift_low(text[2]); - hash_high = hash_high * 257ull + _sz_shift_high(text[2]); - hash_low = hash_low * 31ull + _sz_shift_low(text[3]); - hash_high = hash_high * 257ull + _sz_shift_high(text[3]); - hash_low = hash_low * 31ull + _sz_shift_low(text[4]); - hash_high = hash_high * 257ull + _sz_shift_high(text[4]); - hash_low = hash_low * 31ull + 
_sz_shift_low(text[5]);
-        hash_high = hash_high * 257ull + _sz_shift_high(text[5]);
-        hash_low = hash_low * 31ull + _sz_shift_low(text[6]);
-        hash_high = hash_high * 257ull + _sz_shift_high(text[6]);
-        text += 7;
-
-        // Iterate throw the rest with the modulus:
-        for (; text != text_end; ++text) {
-            hash_low = hash_low * 31ull + _sz_shift_low(text[0]);
-            hash_high = hash_high * 257ull + _sz_shift_high(text[0]);
-            // Wrap the hashes around:
-            hash_low = _sz_prime_mod(hash_low);
-            hash_high = _sz_prime_mod(hash_high);
-        }
-        break;
-    }
+SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t start, sz_size_t length, sz_u64_t seed) {
+    sz_unused(start && length && seed);
+    return 0;
+}
+
+SZ_PUBLIC void sz_generate_serial(sz_ptr_t text, sz_size_t length, sz_u64_t nonce) {
+    sz_unused(text && length && nonce);
+}
+
+SZ_PUBLIC void sz_hash_state_init_serial(sz_hash_state_t *state, sz_u64_t seed) { sz_unused(state && seed); }
 
-    return _sz_hash_mix(hash_low, hash_high);
+SZ_PUBLIC void sz_hash_state_stream_serial(sz_hash_state_t *state, sz_cptr_t text, sz_size_t length) {
+    sz_unused(state && text && length);
 }
 
-#undef _sz_shift_low
-#undef _sz_shift_high
-#undef _sz_hash_mix
-#undef _sz_prime_mod
+SZ_PUBLIC sz_u64_t sz_hash_state_fold_serial(sz_hash_state_t const *state) {
+    sz_unused(state);
+    return 0;
+}
 
 #pragma endregion // Serial Implementation
 
@@ -228,9 +279,9 @@ SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t start, sz_size_t length) {
 #if SZ_USE_HASWELL
 #pragma GCC push_options
 #pragma GCC target("avx2")
-#pragma clang attribute push(__attribute__((target("avx2"))), apply_to = function)
+#pragma clang attribute push(__attribute__((target("avx2"))), apply_to = function)
 
-SZ_PUBLIC sz_u64_t sz_checksum_haswell(sz_cptr_t text, sz_size_t length) {
+SZ_PUBLIC sz_u64_t sz_bytesum_haswell(sz_cptr_t text, sz_size_t length) {
     // The naive implementation of this function is very simple.
     // It assumes the CPU is great at handling unaligned "loads".
     //
@@ -240,7 +291,7 @@ SZ_PUBLIC sz_u64_t sz_checksum_haswell(sz_cptr_t text, sz_size_t length) {
     int is_huge = length > 1ull * 1024ull * 1024ull;
 
     // When the buffer is small, there isn't much to innovate.
-    if (length <= 32) { return sz_checksum_serial(text, length); }
+    if (length <= 32) { return sz_bytesum_serial(text, length); }
     else if (!is_huge) {
         sz_u256_vec_t text_vec, sums_vec;
         sums_vec.ymm = _mm256_setzero_si256();
@@ -248,6 +299,9 @@ SZ_PUBLIC sz_u64_t sz_checksum_haswell(sz_cptr_t text, sz_size_t length) {
             text_vec.ymm = _mm256_lddqu_si256((__m256i const *)text);
             sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256()));
         }
+        // We can also avoid the final serial loop by fetching 32 bytes from end, in reverse direction,
+        // and shifting the data within the register to zero-out the duplicate bytes.
+
         // Accumulating 256 bits is harder, as we need to extract the 128-bit sums first.
         __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymm);
         __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymm, 1);
@@ -255,7 +309,7 @@ SZ_PUBLIC sz_u64_t sz_checksum_haswell(sz_cptr_t text, sz_size_t length) {
         sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm);
         sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1);
         sz_u64_t result = low + high;
-        if (length) result += sz_checksum_serial(text, length);
+        if (length) result += sz_bytesum_serial(text, length);
         return result;
     }
     // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use.
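/*  An illustrative scalar model, not part of this patch, of what the AVX2 loop above accumulates:
 *  `_mm256_sad_epu8(v, zero)` sums the absolute differences against zero, i.e. the raw byte values,
 *  within every 8-byte group, so each 256-bit register yields four 64-bit partial sums that are
 *  folded together at the end. The name `bytesum_sad_model` is hypothetical and used only here.
 */
#include <stddef.h>
#include <stdint.h>

static uint64_t bytesum_sad_model(unsigned char const *text, size_t length) {
    uint64_t lanes[4] = {0, 0, 0, 0}; // mirrors the four 64-bit lanes of the YMM accumulator
    size_t progress = 0;
    for (; progress + 32 <= length; progress += 32)
        for (size_t lane = 0; lane != 4; ++lane)
            for (size_t byte = 0; byte != 8; ++byte)
                lanes[lane] += text[progress + lane * 8 + byte]; // what SAD-against-zero adds per lane
    uint64_t sum = lanes[0] + lanes[1] + lanes[2] + lanes[3];
    for (; progress != length; ++progress) sum += text[progress]; // serial tail, like `sz_bytesum_serial`
    return sum;
}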
@@ -311,6 +365,24 @@ SZ_PUBLIC sz_u64_t sz_checksum_haswell(sz_cptr_t text, sz_size_t length) { } } +SZ_PUBLIC sz_u64_t sz_hash_haswell(sz_cptr_t text, sz_size_t length, sz_u64_t seed) { + return sz_hash_serial(text, length, seed); +} + +SZ_PUBLIC void sz_generate_haswell(sz_ptr_t text, sz_size_t length, sz_u64_t nonce) { + sz_generate_serial(text, length, nonce); +} + +SZ_PUBLIC void sz_hash_state_init_haswell(sz_hash_state_t *state, sz_u64_t seed) { + sz_hash_state_init_serial(state, seed); +} + +SZ_PUBLIC void sz_hash_state_stream_haswell(sz_hash_state_t *state, sz_cptr_t text, sz_size_t length) { + sz_hash_state_stream_serial(state, text, length); +} + +SZ_PUBLIC sz_u64_t sz_hash_state_fold_haswell(sz_hash_state_t const *state) { return sz_hash_state_fold_serial(state); } + #pragma clang attribute pop #pragma GCC pop_options #endif // SZ_USE_HASWELL @@ -327,7 +399,7 @@ SZ_PUBLIC sz_u64_t sz_checksum_haswell(sz_cptr_t text, sz_size_t length) { #pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "bmi", "bmi2") #pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,bmi,bmi2"))), apply_to = function) -SZ_PUBLIC sz_u64_t sz_checksum_skylake(sz_cptr_t text, sz_size_t length) { +SZ_PUBLIC sz_u64_t sz_bytesum_skylake(sz_cptr_t text, sz_size_t length) { // The naive implementation of this function is very simple. // It assumes the CPU is great at handling unaligned "loads". // @@ -427,6 +499,24 @@ SZ_PUBLIC sz_u64_t sz_checksum_skylake(sz_cptr_t text, sz_size_t length) { } } +SZ_PUBLIC sz_u64_t sz_hash_skylake(sz_cptr_t text, sz_size_t length, sz_u64_t seed) { + return sz_hash_serial(text, length, seed); +} + +SZ_PUBLIC void sz_generate_skylake(sz_ptr_t text, sz_size_t length, sz_u64_t nonce) { + sz_generate_serial(text, length, nonce); +} + +SZ_PUBLIC void sz_hash_state_init_skylake(sz_hash_state_t *state, sz_u64_t seed) { + sz_hash_state_init_serial(state, seed); +} + +SZ_PUBLIC void sz_hash_state_stream_skylake(sz_hash_state_t *state, sz_cptr_t text, sz_size_t length) { + sz_hash_state_stream_serial(state, text, length); +} + +SZ_PUBLIC sz_u64_t sz_hash_state_fold_skylake(sz_hash_state_t const *state) { return sz_hash_state_fold_serial(state); } + #pragma clang attribute pop #pragma GCC pop_options #endif // SZ_USE_SKYLAKE @@ -441,12 +531,13 @@ SZ_PUBLIC sz_u64_t sz_checksum_skylake(sz_cptr_t text, sz_size_t length) { #pragma region Ice Lake Implementation #if SZ_USE_ICE #pragma GCC push_options -#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vbmi", "avx512vnni", "bmi", "bmi2") -#pragma clang attribute push( \ - __attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512dq,avx512vbmi,avx512vnni,bmi,bmi2"))), \ +#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vbmi", "avx512vnni", "bmi", "bmi2", \ + "aes", "vaes") +#pragma clang attribute push( \ + __attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512dq,avx512vbmi,avx512vnni,bmi,bmi2,aes,vaes"))), \ apply_to = function) -SZ_PUBLIC sz_u64_t sz_checksum_ice(sz_cptr_t text, sz_size_t length) { +SZ_PUBLIC sz_u64_t sz_bytesum_ice(sz_cptr_t text, sz_size_t length) { // The naive implementation of this function is very simple. // It assumes the CPU is great at handling unaligned "loads". 
// @@ -572,6 +663,230 @@ SZ_PUBLIC sz_u64_t sz_checksum_ice(sz_cptr_t text, sz_size_t length) { } } +SZ_INTERNAL void _sz_hash_minimal_init_haswell(_sz_hash_minimal_t *state, sz_u64_t seed) { + __m128i seed_vec = _mm_set1_epi64x(seed); + __m128i pi0 = _mm_set_epi64x(0x13198a2e03707344ull, 0x243f6a8885a308d3ull); + __m128i pi1 = _mm_set_epi64x(0x082efa98ec4e6c89ull, 0xa4093822299f31d0ull); + // XOR the user-supplied keys with the two "pi" constants + __m128i k1 = _mm_xor_si128(seed_vec, pi0); + __m128i k2 = _mm_xor_si128(seed_vec, pi1); + // Export the keys to the state + state->aes.xmm = k1; + state->sum.xmm = k2; + state->key.xmm = _mm_xor_si128(pi0, pi1); +} + +SZ_INTERNAL sz_u64_t _sz_hash_minimal_finalize_haswell(_sz_hash_minimal_t const *state) { + // Combine the sum and the AES block + __m128i mixed_registers = _mm_aesenc_si128(state->sum.xmm, state->aes.xmm); + // Make sure the "key" mixes enough with the state, + // as with less than 2 rounds - SMHasher fails + __m128i mixed_within_register = + _mm_aesdec_si128(_mm_aesdec_si128(mixed_registers, state->key.xmm), mixed_registers); + // Extract the low 64 bits + return _mm_cvtsi128_si64(mixed_within_register); +} + +SZ_INTERNAL void _sz_hash_minimal_update_haswell(_sz_hash_minimal_t *state, __m128i block) { + // This shuffle mask is identical to "aHash": + __m128i const shuffle_mask = _mm_set_epi8( // + 0x04, 0x0b, 0x09, 0x06, 0x08, 0x0d, 0x0f, 0x05, // + 0x0e, 0x03, 0x01, 0x0c, 0x00, 0x07, 0x0a, 0x02); + state->aes.xmm = _mm_aesdec_si128(state->aes.xmm, block); + state->sum.xmm = _mm_add_epi64(_mm_shuffle_epi8(state->sum.xmm, shuffle_mask), block); +} + +SZ_PUBLIC void sz_hash_state_init_ice(sz_hash_state_t *state, sz_u64_t seed) { + __m512i seed_vec = _mm512_set1_epi64(seed); + __m512i pi0 = _mm512_set_epi64( // + 0x13198a2e03707344ull, 0x243f6a8885a308d3ull, 0x13198a2e03707344ull, 0x243f6a8885a308d3ull, + 0x13198a2e03707344ull, 0x243f6a8885a308d3ull, 0x13198a2e03707344ull, 0x243f6a8885a308d3ull); + __m512i pi1 = _mm512_set_epi64( // + 0x082efa98ec4e6c89ull, 0xa4093822299f31d0ull, 0x082efa98ec4e6c89ull, 0xa4093822299f31d0ull, + 0x082efa98ec4e6c89ull, 0xa4093822299f31d0ull, 0x082efa98ec4e6c89ull, 0xa4093822299f31d0ull); + // XOR the user-supplied keys with the two "pi" constants + __m512i k1 = _mm512_xor_si512(seed_vec, pi0); + __m512i k2 = _mm512_xor_si512(seed_vec, pi1); + // Export the keys to the state + state->aes.zmm = k1; + state->sum.zmm = k2; + state->key.zmm = _mm512_xor_si512(pi0, pi1); + state->ins_length = 0; +} + +SZ_INTERNAL void _sz_hash_state_update_ice(sz_hash_state_t *state, __m512i block) { + // This shuffle mask is identical to "aHash": + __m512i const shuffle_mask = _mm512_set_epi8( // + 0x04, 0x0b, 0x09, 0x06, 0x08, 0x0d, 0x0f, 0x05, // + 0x0e, 0x03, 0x01, 0x0c, 0x00, 0x07, 0x0a, 0x02, // + 0x04, 0x0b, 0x09, 0x06, 0x08, 0x0d, 0x0f, 0x05, // + 0x0e, 0x03, 0x01, 0x0c, 0x00, 0x07, 0x0a, 0x02, // + 0x04, 0x0b, 0x09, 0x06, 0x08, 0x0d, 0x0f, 0x05, // + 0x0e, 0x03, 0x01, 0x0c, 0x00, 0x07, 0x0a, 0x02, // + 0x04, 0x0b, 0x09, 0x06, 0x08, 0x0d, 0x0f, 0x05, // + 0x0e, 0x03, 0x01, 0x0c, 0x00, 0x07, 0x0a, 0x02 // + ); + state->aes.zmm = _mm512_aesdec_epi128(state->aes.zmm, block); + state->sum.zmm = _mm512_add_epi64(_mm512_shuffle_epi8(state->sum.zmm, shuffle_mask), block); +} + +SZ_INTERNAL sz_u64_t _sz_hash_state_finalize_ice(sz_hash_state_t const *state) { + // Combine the sum and the AES block + __m128i mixed_registers0 = _mm_aesenc_si128(state->sum.xmms[0], state->aes.xmms[0]); + __m128i mixed_registers1 = 
_mm_aesenc_si128(state->sum.xmms[1], state->aes.xmms[1]); + __m128i mixed_registers2 = _mm_aesenc_si128(state->sum.xmms[2], state->aes.xmms[2]); + __m128i mixed_registers3 = _mm_aesenc_si128(state->sum.xmms[3], state->aes.xmms[3]); + // Combine the mixed registers + __m128i mixed_registers01 = _mm_aesenc_si128(mixed_registers0, mixed_registers1); + __m128i mixed_registers23 = _mm_aesenc_si128(mixed_registers2, mixed_registers3); + __m128i mixed_registers = _mm_aesenc_si128(mixed_registers01, mixed_registers23); + // Make sure the "key" mixes enough with the state, + // as with less than 2 rounds - SMHasher fails + __m128i mixed_within_register = _mm_aesdec_si128( // + _mm_aesdec_si128(mixed_registers, state->key.xmms[0]), mixed_registers); + // Extract the low 64 bits + return _mm_cvtsi128_si64(mixed_within_register); +} + +SZ_PUBLIC sz_u64_t sz_hash_ice(sz_cptr_t start, sz_size_t length, sz_u64_t seed) { + + if (length <= 16) { + // Initialize the AES block with a given seed and update with the input length + _sz_hash_minimal_t state; + _sz_hash_minimal_init_haswell(&state, seed); + state.aes.xmm = _mm_add_epi64(state.aes.xmm, _mm_set_epi64x(0, length)); + // Load the data and update the state + sz_u128_vec_t data_vec; + data_vec.xmm = _mm_maskz_loadu_epi8(_sz_u16_mask_until(length), start); + _sz_hash_minimal_update_haswell(&state, data_vec.xmm); + return _sz_hash_minimal_finalize_haswell(&state); + } + else if (length <= 32) { + // Initialize the AES block with a given seed and update with the input length + _sz_hash_minimal_t state; + _sz_hash_minimal_init_haswell(&state, seed); + state.aes.xmm = _mm_add_epi64(state.aes.xmm, _mm_set_epi64x(0, length)); + // Load the data and update the state + sz_u128_vec_t data0_vec, data1_vec; + data0_vec.xmm = _mm_loadu_epi8(start); + data1_vec.xmm = _mm_maskz_loadu_epi8(_sz_u16_mask_until(length - 16), start + 16); + _sz_hash_minimal_update_haswell(&state, data0_vec.xmm); + _sz_hash_minimal_update_haswell(&state, data1_vec.xmm); + return _sz_hash_minimal_finalize_haswell(&state); + } + else if (length <= 48) { + // Initialize the AES block with a given seed and update with the input length + _sz_hash_minimal_t state; + _sz_hash_minimal_init_haswell(&state, seed); + state.aes.xmm = _mm_add_epi64(state.aes.xmm, _mm_set_epi64x(0, length)); + // Load the data and update the state + sz_u128_vec_t data0_vec, data1_vec, data2_vec; + data0_vec.xmm = _mm_loadu_epi8(start); + data1_vec.xmm = _mm_loadu_epi8(start + 16); + data2_vec.xmm = _mm_maskz_loadu_epi8(_sz_u16_mask_until(length - 32), start + 32); + _sz_hash_minimal_update_haswell(&state, data0_vec.xmm); + _sz_hash_minimal_update_haswell(&state, data1_vec.xmm); + _sz_hash_minimal_update_haswell(&state, data2_vec.xmm); + return _sz_hash_minimal_finalize_haswell(&state); + } + else if (length <= 64) { + // Initialize the AES block with a given seed and update with the input length + _sz_hash_minimal_t state; + _sz_hash_minimal_init_haswell(&state, seed); + state.aes.xmm = _mm_add_epi64(state.aes.xmm, _mm_set_epi64x(0, length)); + // Load the data and update the state + sz_u128_vec_t data0_vec, data1_vec, data2_vec, data3_vec; + data0_vec.xmm = _mm_loadu_epi8(start); + data1_vec.xmm = _mm_loadu_epi8(start + 16); + data2_vec.xmm = _mm_loadu_epi8(start + 32); + data3_vec.xmm = _mm_maskz_loadu_epi8(_sz_u16_mask_until(length - 48), start + 48); + _sz_hash_minimal_update_haswell(&state, data0_vec.xmm); + _sz_hash_minimal_update_haswell(&state, data1_vec.xmm); + _sz_hash_minimal_update_haswell(&state, 
data2_vec.xmm); + _sz_hash_minimal_update_haswell(&state, data3_vec.xmm); + return _sz_hash_minimal_finalize_haswell(&state); + } + else { + // Use a larger state to handle the main loop and add different offsets + // to different lanes of the register + sz_hash_state_t state; + sz_hash_state_init_ice(&state, seed); + state.aes.zmm = _mm512_add_epi64( // + state.aes.zmm, // + _mm512_set_epi64(0, length, 16, length, 32, length, 48, length)); + + for (; state.ins_length + 64 <= length; state.ins_length += 64) { + state.ins.zmm = _mm512_loadu_epi8(start + state.ins_length); + _sz_hash_state_update_ice(&state, state.ins.zmm); + } + if (state.ins_length < length) { + state.ins.zmm = _mm512_maskz_loadu_epi8( // + _sz_u64_mask_until(length - state.ins_length), start + state.ins_length); + _sz_hash_state_update_ice(&state, state.ins.zmm); + } + return _sz_hash_state_finalize_ice(&state); + } +} + +SZ_PUBLIC void sz_generate_ice(sz_ptr_t output, sz_size_t length, sz_u64_t nonce) { + // We can use `_mm512_broadcast_i32x4` and the `vbroadcasti32x4` instruction, but its latency is freaking 8 cycles. + // The `_mm512_shuffle_i32x4` and the `vshufi32x4` instruction has a latency of 3 cycles, somewhat better. + // The `_mm512_permutex_epi64` and the `vpermq` instruction also has a latency of 3 cycles. + // So we want to avoid that, if possible. + __m128i nonce_vec = _mm_set1_epi64x(nonce); + __m128i key128 = _mm_xor_si128(nonce_vec, _mm_set_epi64x(0x13198a2e03707344ull, 0x243f6a8885a308d3ull)); + if (length <= 16) { + __mmask16 mask = _sz_u16_mask_until(length); + __m128i input = _mm_set1_epi64x(nonce); + __m128i generated = _mm_aesenc_si128(input, key128); + _mm_mask_storeu_epi8((void *)output, mask, generated); + } + // Assuming the YMM register contains two 128-bit blocks, the input to the generator + // will be more complex, containing the sum of the nonce and the block number. + else if (length <= 32) { + __mmask32 mask = _sz_u32_mask_until(length); + __m256i input = _mm256_set_epi64x(nonce + 1, nonce + 1, nonce, nonce); + __m256i key256 = + _mm256_permute2x128_si256(_mm256_castsi128_si256(key128), _mm256_castsi128_si256(key128), 0x00); + __m256i generated = _mm256_aesenc_epi128(input, key256); + _mm256_mask_storeu_epi8((void *)output, mask, generated); + } + // The last special case we handle outside of the primary loop is for buffers up to 64 bytes long. + else if (length <= 64) { + __mmask64 mask = _sz_u64_mask_until(length); + __m512i input = _mm512_set_epi64( // + nonce + 3, nonce + 3, nonce + 2, nonce + 2, // + nonce + 1, nonce + 1, nonce, nonce); + __m512i key512 = _mm512_permutex_epi64(_mm512_castsi128_si512(key128), 0x00); + __m512i generated = _mm512_aesenc_epi128(input, key512); + _mm512_mask_storeu_epi8((void *)output, mask, generated); + } + // The final part of the function is the primary loop, which processes the buffer in 64-byte chunks. + else { + __m512i increment = _mm512_set1_epi64(4); + __m512i input = _mm512_set_epi64( // + nonce + 3, nonce + 3, nonce + 2, nonce + 2, // + nonce + 1, nonce + 1, nonce, nonce); + __m512i key512 = _mm512_permutex_epi64(_mm512_castsi128_si512(key128), 0x00); + sz_size_t i = 0; + for (; i + 64 <= length; i += 64) { + __m512i generated = _mm512_aesenc_epi128(input, key512); + _mm512_storeu_epi8((void *)(output + i), generated); + input = _mm512_add_epi64(input, increment); + } + // Handle the tail of the buffer. 
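+        // The counter in `input` has already advanced past the last full block, so the masked
+        // store below simply emits the next `length - i` bytes of the same AES-CTR stream.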
+ __mmask64 mask = _sz_u64_mask_until(length - i); + __m512i generated = _mm512_aesenc_epi128(input, key512); + _mm512_mask_storeu_epi8((void *)(output + i), mask, generated); + } +} + +SZ_PUBLIC void sz_hash_state_stream_ice(sz_hash_state_t *state, sz_cptr_t text, sz_size_t length) { + sz_hash_state_stream_serial(state, text, length); +} + +SZ_PUBLIC sz_u64_t sz_hash_state_fold_ice(sz_hash_state_t const *state) { return sz_hash_state_fold_serial(state); } + #pragma clang attribute pop #pragma GCC pop_options #endif // SZ_USE_ICE @@ -586,7 +901,7 @@ SZ_PUBLIC sz_u64_t sz_checksum_ice(sz_cptr_t text, sz_size_t length) { #pragma GCC target("arch=armv8.2-a+simd") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd"))), apply_to = function) -SZ_PUBLIC sz_u64_t sz_checksum_neon(sz_cptr_t text, sz_size_t length) { +SZ_PUBLIC sz_u64_t sz_bytesum_neon(sz_cptr_t text, sz_size_t length) { uint64x2_t sum_vec = vdupq_n_u64(0); // Process 16 bytes (128 bits) at a time @@ -600,10 +915,20 @@ SZ_PUBLIC sz_u64_t sz_checksum_neon(sz_cptr_t text, sz_size_t length) { // Final reduction of `sum_vec` to a single scalar sz_u64_t sum = vgetq_lane_u64(sum_vec, 0) + vgetq_lane_u64(sum_vec, 1); - if (length) sum += sz_checksum_serial(text, length); + if (length) sum += sz_bytesum_serial(text, length); return sum; } +SZ_PUBLIC void sz_hash_state_init_neon(sz_hash_state_t *state, sz_u64_t seed) { + sz_hash_state_init_serial(state, seed); +} + +SZ_PUBLIC void sz_hash_state_stream_neon(sz_hash_state_t *state, sz_cptr_t text, sz_size_t length) { + sz_hash_state_stream_serial(state, text, length); +} + +SZ_PUBLIC sz_u64_t sz_hash_state_fold_neon(sz_hash_state_t const *state) { return sz_hash_state_fold_serial(state); } + #pragma clang attribute pop #pragma GCC pop_options #endif // SZ_USE_NEON @@ -629,23 +954,88 @@ SZ_PUBLIC sz_u64_t sz_checksum_neon(sz_cptr_t text, sz_size_t length) { #pragma region Compile Time Dispatching #if !SZ_DYNAMIC_DISPATCH -SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length) { +SZ_DYNAMIC sz_u64_t sz_bytesum(sz_cptr_t text, sz_size_t length) { +#if SZ_USE_ICE + return sz_bytesum_ice(text, length); +#elif SZ_USE_SKYLAKE + return sz_bytesum_skylake(text, length); +#elif SZ_USE_HASWELL + return sz_bytesum_haswell(text, length); +#elif SZ_USE_NEON + return sz_bytesum_neon(text, length); +#else + return sz_bytesum_serial(text, length); +#endif +} + +SZ_DYNAMIC sz_u64_t sz_hash(sz_cptr_t text, sz_size_t length, sz_u64_t seed) { +#if SZ_USE_ICE + return sz_hash_ice(text, length, seed); +#elif SZ_USE_SKYLAKE + return sz_hash_skylake(text, length, seed); +#elif SZ_USE_HASWELL + return sz_hash_haswell(text, length, seed); +#elif SZ_USE_NEON + return sz_hash_neon(text, length, seed); +#else + return sz_hash_serial(text, length, seed); +#endif +} + +SZ_DYNAMIC void sz_generate(sz_ptr_t text, sz_size_t length, sz_u64_t nonce) { #if SZ_USE_ICE - return sz_checksum_ice(text, length); + sz_generate_ice(text, length, nonce); #elif SZ_USE_SKYLAKE - return sz_checksum_skylake(text, length); + sz_generate_skylake(text, length, nonce); #elif SZ_USE_HASWELL - return sz_checksum_haswell(text, length); + sz_generate_haswell(text, length, nonce); #elif SZ_USE_NEON - return sz_checksum_neon(text, length); + sz_generate_neon(text, length, nonce); #else - return sz_checksum_serial(text, length); + sz_generate_serial(text, length, nonce); #endif } -SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length, - sz_random_generator_t 
generator, void *generator_user_data) { - sz_generate_serial(alphabet, alphabet_size, result, result_length, generator, generator_user_data); +SZ_DYNAMIC void sz_hash_state_init(sz_hash_state_t *state, sz_u64_t seed) { +#if SZ_USE_ICE + sz_hash_state_init_ice(state, seed); +#elif SZ_USE_SKYLAKE + sz_hash_state_init_skylake(state, seed); +#elif SZ_USE_HASWELL + sz_hash_state_init_haswell(state, seed); +#elif SZ_USE_NEON + sz_hash_state_init_neon(state, seed); +#else + sz_hash_state_init_serial(state, seed); +#endif +} + +SZ_DYNAMIC void sz_hash_state_stream(sz_hash_state_t *state, sz_cptr_t text, sz_size_t length) { +#if SZ_USE_ICE + sz_hash_state_stream_ice(state, text, length); +#elif SZ_USE_SKYLAKE + sz_hash_state_stream_skylake(state, text, length); +#elif SZ_USE_HASWELL + sz_hash_state_stream_haswell(state, text, length); +#elif SZ_USE_NEON + sz_hash_state_stream_neon(state, text, length); +#else + sz_hash_state_stream_serial(state, text, length); +#endif +} + +SZ_DYNAMIC sz_u64_t sz_hash_state_fold(sz_hash_state_t const *state) { +#if SZ_USE_ICE + return sz_hash_state_fold_ice(state); +#elif SZ_USE_SKYLAKE + return sz_hash_state_fold_skylake(state); +#elif SZ_USE_HASWELL + return sz_hash_state_fold_haswell(state); +#elif SZ_USE_NEON + return sz_hash_state_fold_neon(state); +#else + return sz_hash_state_fold_serial(state); +#endif } #endif // !SZ_DYNAMIC_DISPATCH diff --git a/include/stringzilla/sort.h b/include/stringzilla/sort.h index 977d29e1..4e1a6377 100644 --- a/include/stringzilla/sort.h +++ b/include/stringzilla/sort.h @@ -31,6 +31,7 @@ #include "types.h" #include "compare.h" // `sz_compare` +#include "memory.h" // `sz_copy` #ifdef __cplusplus extern "C" { @@ -414,19 +415,19 @@ SZ_INTERNAL void _sz_sequence_argsort_serial_export_next_pgrams( * @brief Picks the "pivot" value for the QuickSort algorithm's partitioning step using Robert Sedgewick's method, * the median of three elements - the first, the middle, and the last element of the given range. 
*/ -SZ_INTERNAL sz_pgram_t _sz_sequence_partitioning_pivot(sz_pgram_t const *pgrams, sz_size_t count) { +SZ_INTERNAL sz_pgram_t const *_sz_sequence_partitioning_pivot(sz_pgram_t const *pgrams, sz_size_t count) { sz_size_t const middle_offset = count / 2; - sz_pgram_t const first_pgram = pgrams[0]; - sz_pgram_t const middle_pgram = pgrams[middle_offset]; - sz_pgram_t const last_pgram = pgrams[count - 1]; - if (first_pgram < middle_pgram) { - if (middle_pgram < last_pgram) { return middle_pgram; } - else if (first_pgram < last_pgram) { return last_pgram; } + sz_pgram_t const *first_pgram = &pgrams[0]; + sz_pgram_t const *middle_pgram = &pgrams[middle_offset]; + sz_pgram_t const *last_pgram = &pgrams[count - 1]; + if (*first_pgram < *middle_pgram) { + if (*middle_pgram < *last_pgram) { return middle_pgram; } + else if (*first_pgram < *last_pgram) { return last_pgram; } else { return first_pgram; } } else { - if (first_pgram < last_pgram) { return first_pgram; } - else if (middle_pgram < last_pgram) { return last_pgram; } + if (*first_pgram < *last_pgram) { return first_pgram; } + else if (*middle_pgram < *last_pgram) { return last_pgram; } else { return middle_pgram; } } } @@ -440,7 +441,7 @@ SZ_INTERNAL sz_pgram_t _sz_sequence_partitioning_pivot(sz_pgram_t const *pgrams, * * @see https://en.wikipedia.org/wiki/Dutch_national_flag_problem */ -SZ_PUBLIC void _sz_sequence_argsort_serial_3way_partition( // +SZ_INTERNAL void _sz_sequence_argsort_serial_3way_partition( // sz_pgram_t *const global_pgrams, sz_sorted_idx_t *const global_order, // sz_size_t const start_in_sequence, sz_size_t const end_in_sequence, // sz_size_t *first_pivot_offset, sz_size_t *last_pivot_offset) { @@ -459,7 +460,7 @@ SZ_PUBLIC void _sz_sequence_argsort_serial_3way_partition( // } // Chose the pivot offset with Sedgewick's method. - sz_pgram_t const pivot_pgram = _sz_sequence_partitioning_pivot(global_pgrams + start_in_sequence, count); + sz_pgram_t const pivot_pgram = *_sz_sequence_partitioning_pivot(global_pgrams + start_in_sequence, count); // Loop through the collection and move the elements around the pivot with the 3-way partitioning. sz_size_t partitioning_progress = start_in_sequence; // Current index. @@ -492,7 +493,7 @@ SZ_PUBLIC void _sz_sequence_argsort_serial_3way_partition( // * @brief Recursive Quick-Sort implementation backing both the `sz_sequence_argsort` and `sz_pgrams_sort`, * and using the `_sz_sequence_argsort_serial_3way_partition` under the hood. */ -SZ_INTERNAL void _sz_sequence_argsort_serial_recursively( // +SZ_PUBLIC void _sz_sequence_argsort_serial_recursively( // sz_pgram_t *const global_pgrams, sz_sorted_idx_t *const global_order, // sz_size_t const start_in_sequence, sz_size_t const end_in_sequence) { @@ -517,7 +518,7 @@ SZ_INTERNAL void _sz_sequence_argsort_serial_recursively( // * It combines `_sz_sequence_argsort_serial_export_next_pgrams` and `_sz_sequence_argsort_serial_recursively`, * recursively diving into the identical pgrams. */ -SZ_INTERNAL void _sz_sequence_argsort_serial_next_pgrams( // +SZ_PUBLIC void _sz_sequence_argsort_serial_next_pgrams( // sz_sequence_t const *const sequence, // sz_pgram_t *const global_pgrams, sz_sorted_idx_t *const global_order, // sz_size_t const start_in_sequence, sz_size_t const end_in_sequence, // @@ -735,43 +736,177 @@ SZ_PUBLIC sz_bool_t sz_pgrams_sort_stable_serial(sz_pgram_t *pgrams, sz_size_t c #pragma endregion // Serial MergeSort Implementation +/* AVX512 implementation of the string search algorithms for Ice Lake and newer CPUs. 
+ * Includes extensions: + * - 2017 Skylake: F, CD, ER, PF, VL, DQ, BW, + * - 2018 CannonLake: IFMA, VBMI, + * - 2019 Ice Lake: VPOPCNTDQ, VNNI, VBMI2, BITALG, GFNI, VPCLMULQDQ, VAES. + * + * We are going to use VBMI2 for `_mm256_maskz_compress_epi8`. + */ #pragma region Ice Lake Implementation +#if SZ_USE_ICE +#pragma GCC push_options +#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vbmi", "avx512vbmi2", "bmi", "bmi2") +#pragma clang attribute push( \ + __attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512dq,avx512vbmi,avx512vbmi2,bmi,bmi2"))), \ + apply_to = function) -SZ_PUBLIC void _sz_sequence_argsort_ice_recursively( // - sz_sequence_t const *const collection, // - sz_pgram_t *const global_pgrams, sz_size_t *const global_order, // - sz_size_t const start_in_sequence, sz_size_t const end_in_sequence, // - sz_size_t const start_character) { +/** + * @brief The most important part of the QuickSort algorithm partitioning the elements around the pivot. + * Unlike the serial algorithm, uses compressed stores to filter and move the elements around the pivot. + * Assuming the extreme cost of shuffling between 2 ZMM registers based on 2 different masks, we use + * extra memory to store the elements smaller and greater than the pivot somewhere else. + */ +SZ_INTERNAL void _sz_sequence_argsort_ice_2way_partition( // + sz_pgram_t *const initial_pgrams, sz_sorted_idx_t *const initial_order, // + sz_pgram_t *const partitioned_pgrams, sz_sorted_idx_t *const partitioned_order, // + sz_size_t const start_in_sequence, sz_size_t const end_in_sequence, // + sz_size_t *const first_pivot_offset, sz_size_t *const last_pivot_offset) { - // Prepare the new range of windows - _sz_sequence_argsort_serial_export_next_pgrams(collection, global_pgrams, global_order, start_in_sequence, - end_in_sequence, start_character); + sz_size_t const count = end_in_sequence - start_in_sequence; + sz_size_t const pgrams_per_register = sizeof(sz_u512_vec_t) / sizeof(sz_pgram_t); - // We can implement a form of a Radix sort here, that will count the number of elements with - // a certain bit set. The naive approach may require too many loops over data. A more "vectorized" - // approach would be to maintain a histogram for several bits at once. For 4 bits we will - // need 2^4 = 16 counters. - sz_size_t histogram[16] = {0}; - for (sz_size_t byte_in_window = 0; byte_in_window != sizeof(sz_pgram_t); ++byte_in_window) { - // First sort based on the low nibble of each byte. - for (sz_size_t i = start_in_sequence; i < end_in_sequence; ++i) { - sz_size_t const byte = (global_pgrams[i] >> (byte_in_window * 8)) & 0xFF; - ++histogram[byte]; - } - sz_size_t offset = start_in_sequence; - for (sz_size_t i = 0; i != 16; ++i) { - sz_size_t const count = histogram[i]; - histogram[i] = offset; - offset += count; + // Choose the pivot offset with Sedgewick's method. + sz_pgram_t const *pivot_pgram_ptr = _sz_sequence_partitioning_pivot(initial_order + start_in_sequence, count); + sz_pgram_t const pivot_pgram = *pivot_pgram_ptr; + sz_u512_vec_t pivot_vec; + pivot_vec.zmm = _mm512_set1_epi64(pivot_pgram); + + // Reading data is always cheaper than writing, so we can further minimize the writes, if + // we know exactly, how many elements are smaller or greater than the pivot. 
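+    // For example, in a 20-element range with 11 entries below the pivot and 6 above it,
+    // the "equal" run will occupy offsets [11, 14) and the "greater" run offsets [14, 20).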
+ sz_size_t count_smaller = 0, count_greater = 0; + sz_size_t const tail_count = count & 7u; + __mmask8 const tail_mask = _sz_u8_mask_until(tail_count); + + sz_u512_vec_t pgrams_vec, order_vec; + for (sz_size_t i = start_in_sequence; i < end_in_sequence; i += pgrams_per_register) { + pgrams_vec.zmm = // + i + pgrams_per_register <= end_in_sequence // + ? _mm512_loadu_si512(initial_pgrams + i) + : _mm512_maskz_loadu_epi64(tail_mask, initial_pgrams + i); + count_smaller += sz_u32_popcount(_mm512_cmplt_epu64_mask(pgrams_vec.zmm, pivot_vec.zmm)); + count_greater += sz_u32_popcount(_mm512_cmpgt_epu64_mask(pgrams_vec.zmm, pivot_vec.zmm)); + } + + // Now all we need to do is to loop through the collection and export them into the temporary buffer + // in 3 separate segments - smaller, equal, and greater than the pivot. + sz_size_t const count_equal = count - count_smaller - count_greater; + sz_size_t smaller_offset = start_in_sequence; + sz_size_t equal_offset = start_in_sequence + count_smaller; + sz_size_t greater_offset = start_in_sequence + count_smaller + count_equal; + + // The naive algorithm - unzip the elements into 3 separate buffers. + for (sz_size_t i = start_in_sequence; i < end_in_sequence; i += pgrams_per_register) { + if (i + pgrams_per_register <= end_in_sequence) { + pgrams_vec.zmm = _mm512_loadu_si512(initial_pgrams + i); + order_vec.zmm = _mm512_loadu_si512(initial_order + i); } - for (sz_size_t i = start_in_sequence; i < end_in_sequence; ++i) { - sz_size_t const byte = (global_pgrams[i] >> (byte_in_window * 8)) & 0xFF; - global_order[histogram[byte]] = i; - ++histogram[byte]; + else { + pgrams_vec.zmm = _mm512_maskz_loadu_epi64(tail_count, initial_pgrams + i); + order_vec.zmm = _mm512_maskz_loadu_epi64(tail_count, initial_order + i); } + pgrams_vec.zmm = _mm512_loadu_si512(initial_pgrams + i); + order_vec.zmm = _mm512_loadu_si512(initial_order + i); + __mmask8 const smaller_mask = _mm512_cmplt_epu64_mask(pgrams_vec.zmm, pivot_vec.zmm); + __mmask8 const equal_mask = _mm512_cmpeq_epu64_mask(pgrams_vec.zmm, pivot_vec.zmm); + __mmask8 const greater_mask = _mm512_cmpgt_epu64_mask(pgrams_vec.zmm, pivot_vec.zmm); + + // Compress the elements into the temporary buffer. + _mm512_mask_compressstoreu_epi64(partitioned_pgrams + smaller_offset, smaller_mask, pgrams_vec.zmm); + _mm512_mask_compressstoreu_epi64(partitioned_order + smaller_offset, smaller_mask, order_vec.zmm); + smaller_offset += _mm_popcnt_u32(smaller_mask); + + _mm512_mask_compressstoreu_epi64(partitioned_pgrams + equal_offset, equal_mask, pgrams_vec.zmm); + _mm512_mask_compressstoreu_epi64(partitioned_order + equal_offset, equal_mask, order_vec.zmm); + equal_offset += _mm_popcnt_u32(equal_mask); + + _mm512_mask_compressstoreu_epi64(partitioned_pgrams + greater_offset, greater_mask, pgrams_vec.zmm); + _mm512_mask_compressstoreu_epi64(partitioned_order + greater_offset, greater_mask, order_vec.zmm); + greater_offset += _mm_popcnt_u32(greater_mask); + } + + // Copy back. 
+ sz_copy((sz_ptr_t)(initial_pgrams), (sz_cptr_t)(partitioned_pgrams), count_smaller * sizeof(sz_pgram_t)); + sz_copy((sz_ptr_t)(initial_order), (sz_cptr_t)(partitioned_order), count_smaller * sizeof(sz_pgram_t)); + sz_copy((sz_ptr_t)(initial_pgrams + count_smaller), // + (sz_cptr_t)(partitioned_pgrams + count_smaller), // + count_equal * sizeof(sz_pgram_t)); + sz_copy((sz_ptr_t)(initial_order + count_smaller), // + (sz_cptr_t)(partitioned_order + count_smaller), // + count_equal * sizeof(sz_pgram_t)); + sz_copy((sz_ptr_t)(initial_pgrams + count_smaller + count_equal), // + (sz_cptr_t)(partitioned_pgrams + count_smaller + count_equal), // + count_greater); + sz_copy((sz_ptr_t)(initial_order + count_smaller + count_equal), // + (sz_cptr_t)(partitioned_order + count_smaller + count_equal), // + count_greater); + + // Return the offsets of the equal elements. + *first_pivot_offset = count_smaller; + *last_pivot_offset = count_smaller + count_equal; +} + +/** + * @brief Recursive Quick-Sort implementation backing both the `sz_sequence_argsort_ice` and `sz_pgrams_sort_ice`, + * and using the `_sz_sequence_argsort_ice_2way_partition` under the hood. + */ +SZ_INTERNAL void _sz_sequence_argsort_ice_recursively( // + sz_pgram_t *initial_pgrams, sz_sorted_idx_t *initial_order, // + sz_pgram_t *temporary_pgrams, sz_sorted_idx_t *temporary_order, // + sz_size_t const start_in_sequence, sz_size_t const end_in_sequence) { + + // On very small inputs, when we don't even have enough input for a single ZMM register, + // use simple insertion sort without any extra memory. + sz_size_t const count = end_in_sequence - start_in_sequence; + sz_size_t const pgrams_per_register = sizeof(sz_u512_vec_t) / sizeof(sz_pgram_t); + if (count <= pgrams_per_register) { + sz_pgrams_sort_stable_with_insertion( // + initial_pgrams + start_in_sequence, count, initial_order + start_in_sequence); + return; } + + // Partition the collection around some pivot + sz_size_t first_pivot_index, last_pivot_index; + _sz_sequence_argsort_ice_2way_partition( // + initial_pgrams, initial_order, temporary_pgrams, temporary_order, // + start_in_sequence, end_in_sequence, // + &first_pivot_index, &last_pivot_index); + + // Recursively sort the left and right partitions, tracking where the output goes + if (start_in_sequence < first_pivot_index) + _sz_sequence_argsort_ice_recursively( // + initial_pgrams, initial_order, temporary_pgrams, temporary_order, // + start_in_sequence, first_pivot_index); + if (last_pivot_index + 1 < end_in_sequence) + _sz_sequence_argsort_ice_recursively( // + initial_pgrams, initial_order, temporary_pgrams, temporary_order, // + last_pivot_index + 1, end_in_sequence); +} + +SZ_PUBLIC sz_bool_t sz_pgrams_sort_ice(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order) { + + // First, initialize the `order` with `std::iota`-like behavior. + for (sz_size_t i = 0; i != count; ++i) order[i] = i; + + // Allocate memory for partitioning the elements around the pivot. + sz_size_t memory_usage = sizeof(sz_pgram_t) * count + sizeof(sz_sorted_idx_t) * count; + sz_pgram_t *temporary_pgrams = (sz_pgram_t *)alloc->allocate(memory_usage, alloc); + sz_sorted_idx_t *temporary_order = (sz_sorted_idx_t *)(temporary_pgrams + count); + if (!temporary_pgrams) return sz_false_k; + + // Reuse the string sorting algorithm for sorting the "pgrams". 
+ _sz_sequence_argsort_ice_recursively(pgrams, order, temporary_pgrams, temporary_order, 0, count); + + // Deallocate the temporary memory used for partitioning. + alloc->free(temporary_pgrams, memory_usage, alloc); + return sz_true_k; } +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SZ_USE_ICE #pragma endregion // Ice Lake Implementation /* Pick the right implementation for the string search algorithms. diff --git a/include/stringzilla/stringzilla.hpp b/include/stringzilla/stringzilla.hpp index 12f65265..3f5466b6 100644 --- a/include/stringzilla/stringzilla.hpp +++ b/include/stringzilla/stringzilla.hpp @@ -1926,7 +1926,9 @@ class basic_string_slice { #pragma endregion /** @brief Hashes the string, equivalent to `std::hash{}(str)`. */ - size_type hash() const noexcept { return static_cast(sz_hash(start_, length_)); } + size_type hash(std::uint64_t seed = 42) const noexcept { + return static_cast(sz_hash(start_, length_, static_cast(seed))); + } /** @brief Aggregates the values of individual bytes of a string. */ size_type bytesum() const noexcept { return static_cast(sz_bytesum(start_, length_)); } @@ -2795,7 +2797,7 @@ class basic_string { } /** - * @brief Erases ( @b in-place ) a range of characters defined with signed offsets. + * @brief Erases @b (in-place) a range of characters defined with signed offsets. * @return Number of characters removed. */ size_type try_erase(difference_type signed_start_offset = 0, difference_type signed_end_offset = npos) noexcept { @@ -2807,7 +2809,7 @@ class basic_string { } /** - * @brief Inserts ( @b in-place ) a range of characters at a given signed offset. + * @brief Inserts @b (in-place) a range of characters at a given signed offset. * @return `true` if the insertion was successful, `false` otherwise. */ bool try_insert(difference_type signed_offset, string_view string) noexcept { @@ -2823,7 +2825,7 @@ class basic_string { } /** - * @brief Replaces ( @b in-place ) a range of characters with a given string. + * @brief Replaces @b (in-place) a range of characters with a given string. * @return `true` if the replacement was successful, `false` otherwise. */ bool try_replace(difference_type signed_start_offset, difference_type signed_end_offset, @@ -2874,7 +2876,7 @@ class basic_string { } /** - * @brief Inserts ( @b in-place ) a ::character multiple times at the given offset. + * @brief Inserts @b (in-place) a ::character multiple times at the given offset. * @throw `std::out_of_range` if `offset > size()`. * @throw `std::length_error` if the string is too long. * @throw `std::bad_alloc` if the allocation fails. @@ -2890,7 +2892,7 @@ class basic_string { } /** - * @brief Inserts ( @b in-place ) a range of characters at the given offset. + * @brief Inserts @b (in-place) a range of characters at the given offset. * @throw `std::out_of_range` if `offset > size()`. * @throw `std::length_error` if the string is too long. * @throw `std::bad_alloc` if the allocation fails. @@ -2907,7 +2909,7 @@ class basic_string { } /** - * @brief Inserts ( @b in-place ) a range of characters at the given offset. + * @brief Inserts @b (in-place) a range of characters at the given offset. * @throw `std::out_of_range` if `offset > size()`. * @throw `std::length_error` if the string is too long. * @throw `std::bad_alloc` if the allocation fails. @@ -2917,7 +2919,7 @@ class basic_string { } /** - * @brief Inserts ( @b in-place ) a slice of another string at the given offset. + * @brief Inserts @b (in-place) a slice of another string at the given offset. 
* @throw `std::out_of_range` if `offset > size()` or `other_index > other.size()`. * @throw `std::length_error` if the string is too long. * @throw `std::bad_alloc` if the allocation fails. @@ -2928,7 +2930,7 @@ class basic_string { } /** - * @brief Inserts ( @b in-place ) one ::character at the given iterator position. + * @brief Inserts @b (in-place) one ::character at the given iterator position. * @throw `std::out_of_range` if `pos > size()` or `other_index > other.size()`. * @throw `std::length_error` if the string is too long. * @throw `std::bad_alloc` if the allocation fails. @@ -2940,7 +2942,7 @@ class basic_string { } /** - * @brief Inserts ( @b in-place ) a ::character multiple times at the given iterator position. + * @brief Inserts @b (in-place) a ::character multiple times at the given iterator position. * @throw `std::out_of_range` if `pos > size()` or `other_index > other.size()`. * @throw `std::length_error` if the string is too long. * @throw `std::bad_alloc` if the allocation fails. @@ -2952,7 +2954,7 @@ class basic_string { } /** - * @brief Inserts ( @b in-place ) a range at the given iterator position. + * @brief Inserts @b (in-place) a range at the given iterator position. * @throw `std::out_of_range` if `pos > size()` or `other_index > other.size()`. * @throw `std::length_error` if the string is too long. * @throw `std::bad_alloc` if the allocation fails. @@ -2975,7 +2977,7 @@ class basic_string { } /** - * @brief Inserts ( @b in-place ) an initializer list of characters. + * @brief Inserts @b (in-place) an initializer list of characters. * @throw `std::out_of_range` if `pos > size()` or `other_index > other.size()`. * @throw `std::length_error` if the string is too long. * @throw `std::bad_alloc` if the allocation fails. @@ -2985,7 +2987,7 @@ class basic_string { } /** - * @brief Erases ( @b in-place ) the given range of characters. + * @brief Erases @b (in-place) the given range of characters. * @throws `std::out_of_range` if `pos > size()`. * @see `try_erase_slice` for a cleaner exception-less alternative. */ @@ -2997,7 +2999,7 @@ class basic_string { } /** - * @brief Erases ( @b in-place ) the given range of characters. + * @brief Erases @b (in-place) the given range of characters. * @return Iterator pointing following the erased character, or end() if no such character exists. */ iterator erase(const_iterator first, const_iterator last) noexcept { @@ -3008,13 +3010,13 @@ class basic_string { } /** - * @brief Erases ( @b in-place ) the one character at a given postion. + * @brief Erases @b (in-place) the one character at a given postion. * @return Iterator pointing following the erased character, or end() if no such character exists. */ iterator erase(const_iterator pos) noexcept { return erase(pos, pos + 1); } /** - * @brief Replaces ( @b in-place ) a range of characters with a given string. + * @brief Replaces @b (in-place) a range of characters with a given string. * @throws `std::out_of_range` if `pos > size()`. * @throws `std::length_error` if the string is too long. * @see `try_replace` for a cleaner exception-less alternative. @@ -3028,7 +3030,7 @@ class basic_string { } /** - * @brief Replaces ( @b in-place ) a range of characters with a given string. + * @brief Replaces @b (in-place) a range of characters with a given string. * @throws `std::out_of_range` if `pos > size()`. * @throws `std::length_error` if the string is too long. * @see `try_replace` for a cleaner exception-less alternative. 
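Since the Doxygen comments above repeatedly point at the exception-free `try_*` editing API, a short usage sketch follows. It assumes the conventional `sz::string` alias for `basic_string` and literal-constructible strings, neither of which is shown in this patch; treat it as an illustration rather than the library's documented example.

#include <stringzilla/stringzilla.hpp>
#include <cassert>
#include <cstddef>

int main() {
    sz::string text = "hello, world";            // the `sz::string` alias is an assumption here
    if (!text.try_insert(5, " there")) return 1; // "hello there, world"; reports failure instead of throwing
    std::size_t removed = text.try_erase(5, 11); // signed offsets; returns the number of erased characters
    assert(removed == 6);                        // back to "hello, world"
    return 0;
}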
@@ -3038,7 +3040,7 @@ class basic_string { } /** - * @brief Replaces ( @b in-place ) a range of characters with a given string. + * @brief Replaces @b (in-place) a range of characters with a given string. * @throws `std::out_of_range` if `pos > size()` or `pos2 > str.size()`. * @throws `std::length_error` if the string is too long. * @see `try_replace` for a cleaner exception-less alternative. @@ -3049,7 +3051,7 @@ class basic_string { } /** - * @brief Replaces ( @b in-place ) a range of characters with a given string. + * @brief Replaces @b (in-place) a range of characters with a given string. * @throws `std::out_of_range` if `pos > size()`. * @throws `std::length_error` if the string is too long. * @see `try_replace` for a cleaner exception-less alternative. @@ -3059,7 +3061,7 @@ class basic_string { } /** - * @brief Replaces ( @b in-place ) a range of characters with a given string. + * @brief Replaces @b (in-place) a range of characters with a given string. * @throws `std::out_of_range` if `pos > size()`. * @throws `std::length_error` if the string is too long. * @see `try_replace` for a cleaner exception-less alternative. @@ -3070,7 +3072,7 @@ class basic_string { } /** - * @brief Replaces ( @b in-place ) a range of characters with a given string. + * @brief Replaces @b (in-place) a range of characters with a given string. * @throws `std::out_of_range` if `pos > size()`. * @throws `std::length_error` if the string is too long. * @see `try_replace` for a cleaner exception-less alternative. @@ -3080,7 +3082,7 @@ class basic_string { } /** - * @brief Replaces ( @b in-place ) a range of characters with a given string. + * @brief Replaces @b (in-place) a range of characters with a given string. * @throws `std::out_of_range` if `pos > size()`. * @throws `std::length_error` if the string is too long. * @see `try_replace` for a cleaner exception-less alternative. @@ -3090,7 +3092,7 @@ class basic_string { } /** - * @brief Replaces ( @b in-place ) a range of characters with a repetition of given characters. + * @brief Replaces @b (in-place) a range of characters with a repetition of given characters. * @throws `std::out_of_range` if `pos > size()`. * @throws `std::length_error` if the string is too long. * @see `try_replace` for a cleaner exception-less alternative. @@ -3104,7 +3106,7 @@ class basic_string { } /** - * @brief Replaces ( @b in-place ) a range of characters with a repetition of given characters. + * @brief Replaces @b (in-place) a range of characters with a repetition of given characters. * @throws `std::out_of_range` if `pos > size()`. * @throws `std::length_error` if the string is too long. * @see `try_replace` for a cleaner exception-less alternative. @@ -3115,7 +3117,7 @@ class basic_string { } /** - * @brief Replaces ( @b in-place ) a range of characters with a given string. + * @brief Replaces @b (in-place) a range of characters with a given string. * @throws `std::out_of_range` if `pos > size()`. * @throws `std::length_error` if the string is too long. * @see `try_replace` for a cleaner exception-less alternative. @@ -3134,7 +3136,7 @@ class basic_string { } /** - * @brief Replaces ( @b in-place ) a range of characters with a given initializer list. + * @brief Replaces @b (in-place) a range of characters with a given initializer list. * @throws `std::out_of_range` if `pos > size()`. * @throws `std::length_error` if the string is too long. * @see `try_replace` for a cleaner exception-less alternative. 
@@ -3332,13 +3334,12 @@ class basic_string { * @brief Overwrites the string with random binary data. * * @param nonce "Number used ONCE" to initialize the random number generator, @b don't repeat it! - * @param key A 128-bit key to initialize the AES-CTR block-cypher, zeros by default. */ - basic_string &randomize(sz_u64_t nonce, sz_aes128_block_t key = {}) noexcept { + basic_string &randomize(sz_u64_t nonce) noexcept { sz_ptr_t start; sz_size_t length; sz_string_range(&string_, &start, &length); - sz_generate(start, length, nonce, &key); + sz_generate(start, length, nonce); return *this; } @@ -3349,7 +3350,7 @@ class basic_string { */ basic_string &randomize() noexcept { static sz_u64_t nonce = 42; - return randomize(nonce++, {}); + return randomize(nonce++); } /** @@ -3372,7 +3373,7 @@ class basic_string { static basic_string random(size_type length) noexcept(false) { return basic_string(length, '\0').randomize(); } /** - * @brief Replaces ( @b in-place ) all occurrences of a given string with the ::replacement string. + * @brief Replaces @b (in-place) all occurrences of a given string with the ::replacement string. * Similar to `boost::algorithm::replace_all` and Python's `str.replace`. * * The implementation is not as composable, as using search ranges combined with a replacing mapping for matches, @@ -3385,7 +3386,7 @@ class basic_string { } /** - * @brief Replaces ( @b in-place ) all occurrences of a given character set with the ::replacement string. + * @brief Replaces @b (in-place) all occurrences of a given character set with the ::replacement string. * Similar to `boost::algorithm::replace_all` and Python's `str.replace`. * * The implementation is not as composable, as using search ranges combined with a replacing mapping for matches, @@ -3398,7 +3399,7 @@ class basic_string { } /** - * @brief Replaces ( @b in-place ) all occurrences of a given string with the ::replacement string. + * @brief Replaces @b (in-place) all occurrences of a given string with the ::replacement string. * Similar to `boost::algorithm::replace_all` and Python's `str.replace`. * * The implementation is not as composable, as using search ranges combined with a replacing mapping for matches, @@ -3410,7 +3411,7 @@ class basic_string { } /** - * @brief Replaces ( @b in-place ) all occurrences of a given character set with the ::replacement string. + * @brief Replaces @b (in-place) all occurrences of a given character set with the ::replacement string. * Similar to `boost::algorithm::replace_all` and Python's `str.replace`. * * The implementation is not as composable, as using search ranges combined with a replacing mapping for matches, @@ -3422,7 +3423,7 @@ class basic_string { } /** - * @brief Replaces ( @b in-place ) all characters in the string using the provided lookup table. + * @brief Replaces @b (in-place) all characters in the string using the provided lookup table. */ basic_string &transform(look_up_table const &table) noexcept { transform(table, data()); @@ -3917,20 +3918,16 @@ std::ptrdiff_t alignment_score( * @brief Overwrites the string slice with random characters from the given alphabet using the random generator. * * @param string The string to overwrite. - * @param generator A random generator function object that returns a random number in the range [0, 2^64). - * @param alphabet A string of characters to choose from. + * @param nonce "Number used ONCE" to initialize the random number generator, @b don't repeat it! 
*/ -template -void randomize( // - basic_string_slice string, generator_type_ &generator, - string_view alphabet = "abcdefghijklmnopqrstuvwxyz") noexcept { +template +void randomize(basic_string_slice string, sz_u64_t nonce) noexcept { static_assert(!std::is_const::value, "The string must be mutable."); - sz_random_generator_t generator_callback = &_call_random_generator; - sz_generate(alphabet.data(), alphabet.size(), string.data(), string.size(), generator_callback, &generator); + sz_generate(string.data(), string.size(), nonce); } /** - * @brief Replaces ( @b in-place ) all characters in the string using the provided lookup table. + * @brief Replaces @b (in-place) all characters in the string using the provided lookup table. */ template void transform(basic_string_slice string, basic_look_up_table const &table) noexcept { diff --git a/include/stringzilla/types.h b/include/stringzilla/types.h index a3b9d62e..b10f57a1 100644 --- a/include/stringzilla/types.h +++ b/include/stringzilla/types.h @@ -323,6 +323,7 @@ typedef char *sz_ptr_t; // A type alias for `char *` typedef char const *sz_cptr_t; // A type alias for `char const *` typedef sz_i8_t sz_error_cost_t; // Character mismatch cost for fuzzy matching functions +struct sz_hash_state_t; // Forward declaration of a hash state structure struct sz_sequence_t; // Forward declaration of an ordered collection of strings typedef sz_size_t sz_sorted_idx_t; // Index of a sorted string in a list of strings typedef sz_size_t sz_pgram_t; // "Pointer-sized N-gram" of a string @@ -406,7 +407,6 @@ SZ_PUBLIC void sz_charset_invert(sz_charset_t *s) { typedef void *(*sz_memory_allocate_t)(sz_size_t, void *); typedef void (*sz_memory_free_t)(void *, sz_size_t, void *); -typedef sz_u64_t (*sz_random_generator_t)(void *); /** * @brief Some complex pattern matching algorithms may require memory allocations. @@ -457,6 +457,9 @@ typedef sz_u64_t (*sz_hash_state_fold_t)(struct sz_hash_state_t const *); /** @brief Signature of ::sz_bytesum. */ typedef sz_u64_t (*sz_bytesum_t)(sz_cptr_t, sz_size_t); +/** @brief Signature of ::sz_generate. */ +typedef void (*sz_generate_t)(sz_ptr_t, sz_size_t, sz_u64_t); + /** @brief Signature of ::sz_equal. */ typedef sz_bool_t (*sz_equal_t)(sz_cptr_t, sz_cptr_t, sz_size_t); diff --git a/rust/lib.rs b/rust/lib.rs index 07db0a32..d9d4e237 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -54,10 +54,12 @@ pub mod sz { needle_length: usize, ) -> *const c_void; - fn sz_hash(text: *const c_void, length: usize) -> u64; - fn sz_bytesum(text: *const c_void, length: usize) -> u64; + fn sz_hash(text: *const c_void, length: usize, seed: u64) -> u64; + + fn sz_generate(text: *mut c_void, length: usize, seed: u64) -> u64; + fn sz_edit_distance( haystack1: *const c_void, haystack1_length: usize, @@ -102,14 +104,6 @@ pub mod sz { allocator: *const c_void, ) -> isize; - fn sz_generate( - alphabet: *const c_void, - alphabet_size: usize, - text: *mut c_void, - length: usize, - generate: *const c_void, - generator: *mut c_void, - ); } /// Computes the checksum value of unsigned bytes in a given byte slice `text`. 
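To tie the new C-level signatures together (seeded `sz_hash`, the `sz_hash_state_t` stream/fold pair, and nonce-driven `sz_generate`), here is a hedged sketch of the intended call pattern. Note that at this point in the series several backends are still placeholders returning zeros, and whether the folded streaming hash must equal the one-shot hash of the same bytes is an assumption of the sketch, not something the diff states.

#include <stringzilla/stringzilla.h>
#include <cstdio>

int main() {
    char buffer[128];
    sz_generate(buffer, sizeof(buffer), /*nonce=*/42u);               // fill with AES-CTR pseudo-random bytes
    sz_u64_t one_shot = sz_hash(buffer, sizeof(buffer), /*seed=*/7u); // one-shot, seeded hash

    sz_hash_state_t state; // the same bytes, hashed incrementally in 32-byte chunks
    sz_hash_state_init(&state, 7u);
    for (sz_size_t i = 0; i != sizeof(buffer); i += 32) sz_hash_state_stream(&state, buffer + i, 32);
    sz_u64_t streamed = sz_hash_state_fold(&state);

    std::printf("bytesum=%llu one-shot=%llu streamed=%llu\n",
                (unsigned long long)sz_bytesum(buffer, sizeof(buffer)), //
                (unsigned long long)one_shot, (unsigned long long)streamed);
    return 0;
}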
diff --git a/scripts/bench_token.cpp b/scripts/bench_token.cpp index 64ba2f96..93ae2b7e 100644 --- a/scripts/bench_token.cpp +++ b/scripts/bench_token.cpp @@ -40,11 +40,23 @@ tracked_unary_functions_t bytesum_functions() { tracked_unary_functions_t hashing_functions() { auto wrap_sz = [](auto function) -> unary_function_t { - return unary_function_t([function](std::string_view s) { return function(s.data(), s.size()); }); + return unary_function_t([function](std::string_view s) { return function(s.data(), s.size(), 42); }); }; tracked_unary_functions_t result = { - {"sz_hash_serial", wrap_sz(sz_hash_serial)}, {"std::hash", [](std::string_view s) { return std::hash {}(s); }}, + {"sz_hash_serial", wrap_sz(sz_hash_serial)}, +#if SZ_USE_HASWELL + {"sz_hash_haswell", wrap_sz(sz_hash_haswell)}, +#endif +#if SZ_USE_SKYLAKE + {"sz_hash_skylake", wrap_sz(sz_hash_skylake)}, +#endif +#if SZ_USE_ICE + {"sz_hash_ice", wrap_sz(sz_hash_ice)}, +#endif +#if SZ_USE_NEON + {"sz_hash_neon", wrap_sz(sz_hash_neon)}, +#endif }; return result; } @@ -65,11 +77,11 @@ tracked_unary_functions_t random_generation_functions(std::size_t token_length) randomize_string(buffer.data(), token_length, alphabet.data(), alphabet.size()); return token_length; })}, - {"sz::randomize" + suffix, unary_function_t([token_length](std::string_view alphabet) -> std::size_t { - sz::string_span span(buffer.data(), token_length); - sz::randomize(span, global_random_generator(), alphabet); - return token_length; - })}, + // {"sz::randomize" + suffix, unary_function_t([token_length](std::string_view alphabet) -> std::size_t { + // sz::string_span span(buffer.data(), token_length); + // sz::randomize(span, global_random_generator(), alphabet); + // return token_length; + // })}, }; return result; } diff --git a/scripts/test.cpp b/scripts/test.cpp index 58752a35..74282523 100644 --- a/scripts/test.cpp +++ b/scripts/test.cpp @@ -939,11 +939,8 @@ void test_non_stl_extensions_for_updates() { // Randomization. assert(str::random(0).empty()); - assert(str::random(4, "a") == "aaaa"); - assert(str::random(4, "aaaa") == "aaaa"); - assert(str::random(global_random_generator(), 4, "aaaa") == "aaaa"); - assert_scoped(str s = str::random(128, "ACGT"), (void)s, - s.contains('A') && s.contains('C') && s.contains('G') && s.contains('T')); + assert(str::random(4).size() == 4); + assert(str::random(4, 42).size() == 4); } /** From 1da0e2b7944914e12ae7d563d2031af21da08952 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Fri, 21 Feb 2025 07:25:30 +0000 Subject: [PATCH 111/751] Fix: Ice Lake partitioning logic --- include/stringzilla/sort.h | 88 +++++++++++++++++++------------------- scripts/bench_sort.cpp | 9 ++++ 2 files changed, 54 insertions(+), 43 deletions(-) diff --git a/include/stringzilla/sort.h b/include/stringzilla/sort.h index 4e1a6377..55a3677e 100644 --- a/include/stringzilla/sort.h +++ b/include/stringzilla/sort.h @@ -755,10 +755,8 @@ SZ_PUBLIC sz_bool_t sz_pgrams_sort_stable_serial(sz_pgram_t *pgrams, sz_size_t c /** * @brief The most important part of the QuickSort algorithm partitioning the elements around the pivot. * Unlike the serial algorithm, uses compressed stores to filter and move the elements around the pivot. - * Assuming the extreme cost of shuffling between 2 ZMM registers based on 2 different masks, we use - * extra memory to store the elements smaller and greater than the pivot somewhere else. 
*/ -SZ_INTERNAL void _sz_sequence_argsort_ice_2way_partition( // +SZ_INTERNAL void _sz_sequence_argsort_ice_3way_partition( // sz_pgram_t *const initial_pgrams, sz_sorted_idx_t *const initial_order, // sz_pgram_t *const partitioned_pgrams, sz_sorted_idx_t *const partitioned_order, // sz_size_t const start_in_sequence, sz_size_t const end_in_sequence, // @@ -768,7 +766,7 @@ SZ_INTERNAL void _sz_sequence_argsort_ice_2way_partition( sz_size_t const pgrams_per_register = sizeof(sz_u512_vec_t) / sizeof(sz_pgram_t); // Choose the pivot offset with Sedgewick's method. - sz_pgram_t const *pivot_pgram_ptr = _sz_sequence_partitioning_pivot(initial_order + start_in_sequence, count); + sz_pgram_t const *pivot_pgram_ptr = _sz_sequence_partitioning_pivot(initial_pgrams + start_in_sequence, count); sz_pgram_t const pivot_pgram = *pivot_pgram_ptr; sz_u512_vec_t pivot_vec; pivot_vec.zmm = _mm512_set1_epi64(pivot_pgram); @@ -780,37 +778,35 @@ SZ_INTERNAL void _sz_sequence_argsort_ice_2way_partition( __mmask8 const tail_mask = _sz_u8_mask_until(tail_count); sz_u512_vec_t pgrams_vec, order_vec; - for (sz_size_t i = start_in_sequence; i < end_in_sequence; i += pgrams_per_register) { - pgrams_vec.zmm = // - i + pgrams_per_register <= end_in_sequence // - ? _mm512_loadu_si512(initial_pgrams + i) - : _mm512_maskz_loadu_epi64(tail_mask, initial_pgrams + i); + for (sz_size_t i = start_in_sequence; i + pgrams_per_register <= end_in_sequence; i += pgrams_per_register) { + pgrams_vec.zmm = _mm512_loadu_si512(initial_pgrams + i); count_smaller += sz_u32_popcount(_mm512_cmplt_epu64_mask(pgrams_vec.zmm, pivot_vec.zmm)); count_greater += sz_u32_popcount(_mm512_cmpgt_epu64_mask(pgrams_vec.zmm, pivot_vec.zmm)); } + if (tail_count) { + pgrams_vec.zmm = _mm512_maskz_loadu_epi64(tail_mask, initial_pgrams + end_in_sequence - tail_count); + count_smaller += sz_u32_popcount(_mm512_mask_cmplt_epu64_mask(tail_mask, pgrams_vec.zmm, pivot_vec.zmm)); + count_greater += sz_u32_popcount(_mm512_mask_cmpgt_epu64_mask(tail_mask, pgrams_vec.zmm, pivot_vec.zmm)); + } // Now all we need to do is to loop through the collection and export them into the temporary buffer // in 3 separate segments - smaller, equal, and greater than the pivot. sz_size_t const count_equal = count - count_smaller - count_greater; + _sz_assert(count_equal >= 1 && "The pivot must be present in the collection."); + _sz_assert(count_smaller + count_equal + count_greater == count && "The partitioning must be exhaustive."); sz_size_t smaller_offset = start_in_sequence; sz_size_t equal_offset = start_in_sequence + count_smaller; sz_size_t greater_offset = start_in_sequence + count_smaller + count_equal; // The naive algorithm - unzip the elements into 3 separate buffers. 
for (sz_size_t i = start_in_sequence; i < end_in_sequence; i += pgrams_per_register) { - if (i + pgrams_per_register <= end_in_sequence) { - pgrams_vec.zmm = _mm512_loadu_si512(initial_pgrams + i); - order_vec.zmm = _mm512_loadu_si512(initial_order + i); - } - else { - pgrams_vec.zmm = _mm512_maskz_loadu_epi64(tail_count, initial_pgrams + i); - order_vec.zmm = _mm512_maskz_loadu_epi64(tail_count, initial_order + i); - } - pgrams_vec.zmm = _mm512_loadu_si512(initial_pgrams + i); - order_vec.zmm = _mm512_loadu_si512(initial_order + i); - __mmask8 const smaller_mask = _mm512_cmplt_epu64_mask(pgrams_vec.zmm, pivot_vec.zmm); - __mmask8 const equal_mask = _mm512_cmpeq_epu64_mask(pgrams_vec.zmm, pivot_vec.zmm); - __mmask8 const greater_mask = _mm512_cmpgt_epu64_mask(pgrams_vec.zmm, pivot_vec.zmm); + __mmask8 const load_mask = i + pgrams_per_register <= end_in_sequence ? 0xFF : tail_mask; + pgrams_vec.zmm = _mm512_maskz_loadu_epi64(load_mask, initial_pgrams + i); + order_vec.zmm = _mm512_maskz_loadu_epi64(load_mask, initial_order + i); + + __mmask8 const smaller_mask = _mm512_mask_cmplt_epu64_mask(load_mask, pgrams_vec.zmm, pivot_vec.zmm); + __mmask8 const equal_mask = _mm512_mask_cmpeq_epu64_mask(load_mask, pgrams_vec.zmm, pivot_vec.zmm); + __mmask8 const greater_mask = _mm512_mask_cmpgt_epu64_mask(load_mask, pgrams_vec.zmm, pivot_vec.zmm); // Compress the elements into the temporary buffer. _mm512_mask_compressstoreu_epi64(partitioned_pgrams + smaller_offset, smaller_mask, pgrams_vec.zmm); @@ -827,29 +823,35 @@ SZ_INTERNAL void _sz_sequence_argsort_ice_2way_partition( } // Copy back. - sz_copy((sz_ptr_t)(initial_pgrams), (sz_cptr_t)(partitioned_pgrams), count_smaller * sizeof(sz_pgram_t)); - sz_copy((sz_ptr_t)(initial_order), (sz_cptr_t)(partitioned_order), count_smaller * sizeof(sz_pgram_t)); - sz_copy((sz_ptr_t)(initial_pgrams + count_smaller), // - (sz_cptr_t)(partitioned_pgrams + count_smaller), // - count_equal * sizeof(sz_pgram_t)); - sz_copy((sz_ptr_t)(initial_order + count_smaller), // - (sz_cptr_t)(partitioned_order + count_smaller), // - count_equal * sizeof(sz_pgram_t)); - sz_copy((sz_ptr_t)(initial_pgrams + count_smaller + count_equal), // - (sz_cptr_t)(partitioned_pgrams + count_smaller + count_equal), // - count_greater); - sz_copy((sz_ptr_t)(initial_order + count_smaller + count_equal), // - (sz_cptr_t)(partitioned_order + count_smaller + count_equal), // - count_greater); + sz_copy_skylake((sz_ptr_t)(initial_pgrams + start_in_sequence), // + (sz_cptr_t)(partitioned_pgrams + start_in_sequence), // + count_smaller * sizeof(sz_pgram_t)); + sz_copy_skylake((sz_ptr_t)(initial_order + start_in_sequence), // + (sz_cptr_t)(partitioned_order + start_in_sequence), // + count_smaller * sizeof(sz_sorted_idx_t)); + + sz_copy_skylake((sz_ptr_t)(initial_pgrams + start_in_sequence + count_smaller), // + (sz_cptr_t)(partitioned_pgrams + start_in_sequence + count_smaller), // + count_equal * sizeof(sz_pgram_t)); + sz_copy_skylake((sz_ptr_t)(initial_order + start_in_sequence + count_smaller), // + (sz_cptr_t)(partitioned_order + start_in_sequence + count_smaller), // + count_equal * sizeof(sz_sorted_idx_t)); + + sz_copy_skylake((sz_ptr_t)(initial_pgrams + start_in_sequence + count_smaller + count_equal), // + (sz_cptr_t)(partitioned_pgrams + start_in_sequence + count_smaller + count_equal), // + count_greater * sizeof(sz_pgram_t)); + sz_copy_skylake((sz_ptr_t)(initial_order + start_in_sequence + count_smaller + count_equal), // + (sz_cptr_t)(partitioned_order + start_in_sequence + count_smaller + 
count_equal), // + count_greater * sizeof(sz_sorted_idx_t)); // Return the offsets of the equal elements. - *first_pivot_offset = count_smaller; - *last_pivot_offset = count_smaller + count_equal; + *first_pivot_offset = start_in_sequence + count_smaller; + *last_pivot_offset = start_in_sequence + count_smaller + count_equal - 1; } /** * @brief Recursive Quick-Sort implementation backing both the `sz_sequence_argsort_ice` and `sz_pgrams_sort_ice`, - * and using the `_sz_sequence_argsort_ice_2way_partition` under the hood. + * and using the `_sz_sequence_argsort_ice_3way_partition` under the hood. */ SZ_INTERNAL void _sz_sequence_argsort_ice_recursively( // sz_pgram_t *initial_pgrams, sz_sorted_idx_t *initial_order, // @@ -868,17 +870,17 @@ SZ_INTERNAL void _sz_sequence_argsort_ice_recursively( // // Partition the collection around some pivot sz_size_t first_pivot_index, last_pivot_index; - _sz_sequence_argsort_ice_2way_partition( // + _sz_sequence_argsort_ice_3way_partition( // initial_pgrams, initial_order, temporary_pgrams, temporary_order, // start_in_sequence, end_in_sequence, // &first_pivot_index, &last_pivot_index); - // Recursively sort the left and right partitions, tracking where the output goes - if (start_in_sequence < first_pivot_index) + // Recursively sort the left and right partitions, if there are at least 2 elements in each + if (start_in_sequence + 1 < first_pivot_index) _sz_sequence_argsort_ice_recursively( // initial_pgrams, initial_order, temporary_pgrams, temporary_order, // start_in_sequence, first_pivot_index); - if (last_pivot_index + 1 < end_in_sequence) + if (last_pivot_index + 2 < end_in_sequence) _sz_sequence_argsort_ice_recursively( // initial_pgrams, initial_order, temporary_pgrams, temporary_order, // last_pivot_index + 1, end_in_sequence); diff --git a/scripts/bench_sort.cpp b/scripts/bench_sort.cpp index 729ac856..6bc0dd11 100644 --- a/scripts/bench_sort.cpp +++ b/scripts/bench_sort.cpp @@ -112,6 +112,15 @@ int main(int argc, char const **argv) { }); expect_sorted(pgrams, permute); + bench_permute("sz_pgrams_sort_ice", [&]() { + std::copy(pgrams.begin(), pgrams.end(), pgrams_sorted.begin()); + std::iota(permute.begin(), permute.end(), 0); + sz::_with_alloc([&](sz_memory_allocator_t &alloc) { + return sz_pgrams_sort_ice(pgrams_sorted.data(), pgrams_sorted.size(), &alloc, permute.data()); + }); + }); + expect_sorted(pgrams, permute); + // Unlike the `std::sort` adaptation above, the `sz_pgrams_sort_stable_serial` also sorts the input array inplace bench_permute("sz_pgrams_sort_stable_serial", [&]() { std::copy(pgrams.begin(), pgrams.end(), pgrams_sorted.begin()); From 69d4ecb656ed697eb0209931397587d0a8c165b1 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Fri, 21 Feb 2025 07:44:39 +0000 Subject: [PATCH 112/751] Add: `sz_sequence_argsort_ice` --- include/stringzilla/sort.h | 77 ++++++++++++++++++++++++++++++++++++++ scripts/bench_sort.cpp | 16 +++++++- 2 files changed, 91 insertions(+), 2 deletions(-) diff --git a/include/stringzilla/sort.h b/include/stringzilla/sort.h index 55a3677e..0273467a 100644 --- a/include/stringzilla/sort.h +++ b/include/stringzilla/sort.h @@ -906,6 +906,83 @@ SZ_PUBLIC sz_bool_t sz_pgrams_sort_ice(sz_pgram_t *pgrams, sz_size_t count, sz_m return sz_true_k; } +/** + * @brief Recursive Quick-Sort adaptation for strings, that processes the strings a few N-grams at a time. 
+ * It combines `_sz_sequence_argsort_serial_export_next_pgrams` and `_sz_sequence_argsort_serial_recursively`,
+ * recursively diving into the identical pgrams.
+ */
+SZ_PUBLIC void _sz_sequence_argsort_ice_next_pgrams(                            //
+    sz_sequence_t const *const sequence,                                        //
+    sz_pgram_t *const global_pgrams, sz_sorted_idx_t *const global_order,       //
+    sz_pgram_t *const temporary_pgrams, sz_sorted_idx_t *const temporary_order, //
+    sz_size_t const start_in_sequence, sz_size_t const end_in_sequence,         //
+    sz_size_t const start_character) {
+
+    // Prepare the new range of pgrams
+    _sz_sequence_argsort_serial_export_next_pgrams(sequence, global_pgrams, global_order, start_in_sequence,
+                                                   end_in_sequence, start_character);
+
+    // Sort current pgrams with a quicksort
+    _sz_sequence_argsort_ice_recursively(global_pgrams, global_order, temporary_pgrams, temporary_order,
+                                         start_in_sequence, end_in_sequence);
+
+    // Depending on the architecture, we will export a different number of bytes.
+    // On 32-bit architectures, we will export 3 bytes, and on 64-bit architectures - 7 bytes.
+    sz_size_t const pgram_capacity = sizeof(sz_pgram_t) - 1;
+
+    // Repeat the procedure for the identical pgrams
+    sz_size_t nested_start = start_in_sequence;
+    sz_size_t nested_end = start_in_sequence;
+    while (nested_end != end_in_sequence) {
+        // Find the end of the identical pgrams
+        sz_pgram_t current_pgram = global_pgrams[nested_start];
+        while (nested_end != end_in_sequence && current_pgram == global_pgrams[nested_end]) ++nested_end;
+
+        // If the identical pgrams are not trivial and each string has more characters, sort them recursively
+        sz_cptr_t current_pgram_str = (sz_cptr_t)&current_pgram;
+        sz_size_t current_pgram_length = (sz_size_t)current_pgram_str[0]; //! The byte order was swapped
+        int has_multiple_strings = nested_end - nested_start > 1;
+        int has_more_characters_in_each = current_pgram_length == pgram_capacity;
+        if (has_multiple_strings && has_more_characters_in_each) {
+            _sz_sequence_argsort_ice_next_pgrams(sequence, global_pgrams, global_order, temporary_pgrams,
+                                                 temporary_order, nested_start, nested_end,
+                                                 start_character + pgram_capacity);
+        }
+        // Move to the next
+        nested_start = nested_end;
+    }
+}
+
+SZ_PUBLIC sz_bool_t sz_sequence_argsort_ice(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc,
+                                            sz_sorted_idx_t *order) {
+
+    // First, initialize the `order` with `std::iota`-like behavior.
+    sz_size_t count = sequence->count;
+    for (sz_size_t i = 0; i != count; ++i) order[i] = i;
+
+    // On very small collections - just use the quadratic-complexity insertion sort
+    // without any smart optimizations or memory allocations.
+    if (count <= 32) {
+        sz_sequence_argsort_with_insertion(sequence, order);
+        return sz_true_k;
+    }
+
+    // Allocate memory for partitioning the elements around the pivot.
+    sz_size_t memory_usage = sizeof(sz_pgram_t) * count * 2 + sizeof(sz_sorted_idx_t) * count;
+    sz_pgram_t *global_pgrams = (sz_pgram_t *)alloc->allocate(memory_usage, alloc);
+    sz_pgram_t *temporary_pgrams = global_pgrams + count;
+    sz_sorted_idx_t *temporary_order = (sz_sorted_idx_t *)(temporary_pgrams + count);
+    if (!global_pgrams) return sz_false_k;
+
+    // Recursively sort the whole sequence.
+    _sz_sequence_argsort_ice_next_pgrams(sequence, global_pgrams, order, temporary_pgrams, temporary_order, //
+                                         0, count, 0);
+
+    // Free temporary storage.
+ alloc->free(global_pgrams, memory_usage, alloc); + return sz_true_k; +} + #pragma clang attribute pop #pragma GCC pop_options #endif // SZ_USE_ICE diff --git a/scripts/bench_sort.cpp b/scripts/bench_sort.cpp index 6bc0dd11..22758d95 100644 --- a/scripts/bench_sort.cpp +++ b/scripts/bench_sort.cpp @@ -139,7 +139,7 @@ int main(int argc, char const **argv) { }); expect_sorted(strings, permute); - bench_permute("sz_sequence_argsort", [&]() { + bench_permute("sz_sequence_argsort_serial", [&]() { std::iota(permute.begin(), permute.end(), 0); sz_sequence_t array; array.count = strings.size(); @@ -147,7 +147,19 @@ int main(int argc, char const **argv) { array.get_start = get_start; array.get_length = get_length; sz::_with_alloc( - [&](sz_memory_allocator_t &alloc) { return sz_sequence_argsort(&array, &alloc, permute.data()); }); + [&](sz_memory_allocator_t &alloc) { return sz_sequence_argsort_serial(&array, &alloc, permute.data()); }); + }); + expect_sorted(strings, permute); + + bench_permute("sz_sequence_argsort_ice", [&]() { + std::iota(permute.begin(), permute.end(), 0); + sz_sequence_t array; + array.count = strings.size(); + array.handle = &strings; + array.get_start = get_start; + array.get_length = get_length; + sz::_with_alloc( + [&](sz_memory_allocator_t &alloc) { return sz_sequence_argsort_ice(&array, &alloc, permute.data()); }); }); expect_sorted(strings, permute); From cc98389e82005783ffe8c0f28974207ad02ad83a Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Fri, 21 Feb 2025 07:44:53 +0000 Subject: [PATCH 113/751] Add: Sorting placeholders & dispatch --- c/lib.c | 36 ++++++++++++++++++++++++++++++++++ include/stringzilla/sort.h | 39 +++++++++++++++++++++++++++++++++++++ include/stringzilla/types.h | 9 +++++++++ 3 files changed, 84 insertions(+) diff --git a/c/lib.c b/c/lib.c index 559062ba..9d79fd98 100644 --- a/c/lib.c +++ b/c/lib.c @@ -196,6 +196,9 @@ typedef struct sz_implementations_t { sz_alignment_score_t alignment_score; sz_sequence_argsort_t sequence_argsort; + sz_pgrams_sort_t pgrams_sort; + sz_sequence_argsort_stable_t sequence_argsort_stable; + sz_pgrams_sort_stable_t pgrams_sort_stable; } sz_implementations_t; @@ -237,7 +240,11 @@ SZ_DYNAMIC void sz_dispatch_table_init(void) { impl->edit_distance = sz_edit_distance_serial; impl->alignment_score = sz_alignment_score_serial; + impl->sequence_argsort = sz_sequence_argsort_serial; + impl->pgrams_sort = sz_pgrams_sort_serial; + impl->sequence_argsort_stable = sz_sequence_argsort_stable_serial; + impl->pgrams_sort_stable = sz_pgrams_sort_stable_serial; #if SZ_USE_HASWELL if (caps & sz_cap_haswell_k) { @@ -305,6 +312,11 @@ SZ_DYNAMIC void sz_dispatch_table_init(void) { impl->hash_state_stream = sz_hash_state_stream_ice; impl->hash_state_fold = sz_hash_state_fold_ice; impl->generate = sz_generate_ice; + + impl->sequence_argsort = sz_sequence_argsort_ice; + impl->pgrams_sort = sz_pgrams_sort_ice; + impl->sequence_argsort_stable = sz_sequence_argsort_stable_ice; + impl->pgrams_sort_stable = sz_pgrams_sort_stable_ice; } #endif @@ -332,6 +344,15 @@ SZ_DYNAMIC void sz_dispatch_table_init(void) { impl->rfind_from_set = sz_rfind_charset_neon; } #endif + +#if SZ_USE_SVE + if (caps & sz_cap_sve_k) { + impl->sequence_argsort = sz_sequence_argsort_sve; + impl->pgrams_sort = sz_pgrams_sort_sve; + impl->sequence_argsort_stable = sz_sequence_argsort_stable_sve; + impl->pgrams_sort_stable = sz_pgrams_sort_stable_sve; + } +#endif } #if defined(_MSC_VER) @@ -479,6 +500,21 @@ SZ_DYNAMIC sz_bool_t 
sz_sequence_argsort(sz_sequence_t const *array, sz_memory_a return sz_dispatch_table.sequence_argsort(array, alloc, order); } +SZ_DYNAMIC sz_bool_t sz_pgrams_sort(sz_pgram_t *array, sz_size_t count, sz_memory_allocator_t *alloc, + sz_size_t *order) { + return sz_dispatch_table.pgrams_sort(array, count, alloc, order); +} + +SZ_DYNAMIC sz_bool_t sz_sequence_argsort_stable(sz_sequence_t const *array, sz_memory_allocator_t *alloc, + sz_size_t *order) { + return sz_dispatch_table.sequence_argsort_stable(array, alloc, order); +} + +SZ_DYNAMIC sz_bool_t sz_pgrams_sort_stable(sz_pgram_t *array, sz_size_t count, sz_memory_allocator_t *alloc, + sz_size_t *order) { + return sz_dispatch_table.pgrams_sort_stable(array, count, alloc, order); +} + SZ_DYNAMIC sz_cptr_t sz_find_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { sz_charset_t set; sz_charset_init(&set); diff --git a/include/stringzilla/sort.h b/include/stringzilla/sort.h index 0273467a..190074a3 100644 --- a/include/stringzilla/sort.h +++ b/include/stringzilla/sort.h @@ -996,7 +996,46 @@ SZ_PUBLIC sz_bool_t sz_sequence_argsort_ice(sz_sequence_t const *sequence, sz_me SZ_DYNAMIC sz_bool_t sz_sequence_argsort(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, sz_sorted_idx_t *order) { +#if SZ_USE_ICE + return sz_sequence_argsort_ice(sequence, alloc, order); +#elif SZ_USE_SVE + return sz_sequence_argsort_sve(sequence, alloc, order); +#else + return sz_sequence_argsort_serial(sequence, alloc, order); +#endif +} + +SZ_DYNAMIC sz_bool_t sz_pgrams_sort(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order) { +#if SZ_USE_ICE + return sz_pgrams_sort_ice(pgrams, count, alloc, order); +#elif SZ_USE_SVE + return sz_pgrams_sort_sve(pgrams, count, alloc, order); +#else + return sz_pgrams_sort_serial(pgrams, count, alloc, order); +#endif +} + +SZ_DYNAMIC sz_bool_t sz_sequence_argsort_stable(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order) { +#if SZ_USE_ICE + return sz_sequence_argsort_ice(sequence, alloc, order); +#elif SZ_USE_SVE + return sz_sequence_argsort_sve(sequence, alloc, order); +#else return sz_sequence_argsort_serial(sequence, alloc, order); +#endif +} + +SZ_DYNAMIC sz_bool_t sz_pgrams_sort_stable(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order) { +#if SZ_USE_ICE + return sz_pgrams_sort_ice(pgrams, count, alloc, order); +#elif SZ_USE_SVE + return sz_pgrams_sort_sve(pgrams, count, alloc, order); +#else + return sz_pgrams_sort_serial(pgrams, count, alloc, order); +#endif } #endif // !SZ_DYNAMIC_DISPATCH diff --git a/include/stringzilla/types.h b/include/stringzilla/types.h index b10f57a1..825e36f9 100644 --- a/include/stringzilla/types.h +++ b/include/stringzilla/types.h @@ -497,6 +497,15 @@ typedef sz_ssize_t (*sz_alignment_score_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_s /** @brief Signature of ::sz_sequence_argsort. */ typedef sz_bool_t (*sz_sequence_argsort_t)(struct sz_sequence_t const *, sz_memory_allocator_t *, sz_sorted_idx_t *); +/** @brief Signature of ::sz_pgrams_sort. */ +typedef sz_bool_t (*sz_pgrams_sort_t)(sz_pgram_t *, sz_size_t, sz_memory_allocator_t *, sz_sorted_idx_t *); + +/** @brief Signature of ::sz_sequence_argsort_stable. */ +typedef sz_sequence_argsort_t sz_sequence_argsort_stable_t; + +/** @brief Signature of ::sz_pgrams_sort_stable. 
*/ +typedef sz_pgrams_sort_t sz_pgrams_sort_stable_t; + #pragma endregion #pragma region Helper Structures From 8bc161f4256c50d00a6c5f22e11612a00c47a12a Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Fri, 21 Feb 2025 13:27:23 +0000 Subject: [PATCH 114/751] Docs: Disable sorting includes --- .clang-format | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.clang-format b/.clang-format index c1418bae..c97feb6f 100644 --- a/.clang-format +++ b/.clang-format @@ -44,8 +44,8 @@ BraceWrapping: SplitEmptyNamespace: false IndentBraces: false -SortIncludes: true -SortUsingDeclarations: true +SortIncludes: false +SortUsingDeclarations: false SpaceAfterCStyleCast: false SpaceAfterLogicalNot: false From e0055d5d88828aba726c4baf87605abcb3865d39 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Fri, 21 Feb 2025 13:29:31 +0000 Subject: [PATCH 115/751] Break: `look_up_transform` to `lookup` API --- c/lib.c | 60 +++++----- include/stringzilla/memory.h | 163 ++++++++++++++++++++-------- include/stringzilla/stringzilla.hpp | 11 +- scripts/bench_memory.cpp | 8 +- 4 files changed, 157 insertions(+), 85 deletions(-) diff --git a/c/lib.c b/c/lib.c index 9d79fd98..b65784cc 100644 --- a/c/lib.c +++ b/c/lib.c @@ -176,7 +176,7 @@ typedef struct sz_implementations_t { sz_move_t copy; sz_move_t move; sz_fill_t fill; - sz_look_up_transform_t look_up_transform; + sz_lookup_t lookup; sz_bytesum_t bytesum; sz_hash_t hash; @@ -192,8 +192,8 @@ typedef struct sz_implementations_t { sz_find_set_t find_from_set; sz_find_set_t rfind_from_set; - sz_edit_distance_t edit_distance; - sz_alignment_score_t alignment_score; + sz_levenshtein_distance_t edit_distance; + sz_needleman_wunsch_score_t alignment_score; sz_sequence_argsort_t sequence_argsort; sz_pgrams_sort_t pgrams_sort; @@ -222,7 +222,7 @@ SZ_DYNAMIC void sz_dispatch_table_init(void) { impl->copy = sz_copy_serial; impl->move = sz_move_serial; impl->fill = sz_fill_serial; - impl->look_up_transform = sz_look_up_transform_serial; + impl->lookup = sz_lookup_serial; impl->bytesum = sz_bytesum_serial; impl->hash = sz_hash_serial; @@ -238,8 +238,8 @@ SZ_DYNAMIC void sz_dispatch_table_init(void) { impl->find_from_set = sz_find_charset_serial; impl->rfind_from_set = sz_rfind_charset_serial; - impl->edit_distance = sz_edit_distance_serial; - impl->alignment_score = sz_alignment_score_serial; + impl->edit_distance = sz_levenshtein_distance_serial; + impl->alignment_score = sz_needleman_wunsch_score_serial; impl->sequence_argsort = sz_sequence_argsort_serial; impl->pgrams_sort = sz_pgrams_sort_serial; @@ -254,7 +254,7 @@ SZ_DYNAMIC void sz_dispatch_table_init(void) { impl->copy = sz_copy_haswell; impl->move = sz_move_haswell; impl->fill = sz_fill_haswell; - impl->look_up_transform = sz_look_up_transform_haswell; + impl->lookup = sz_lookup_haswell; impl->bytesum = sz_bytesum_haswell; impl->hash = sz_hash_haswell; @@ -301,10 +301,10 @@ SZ_DYNAMIC void sz_dispatch_table_init(void) { impl->find_from_set = sz_find_charset_ice; impl->rfind_from_set = sz_rfind_charset_ice; - impl->edit_distance = sz_edit_distance_ice; - impl->alignment_score = sz_alignment_score_ice; + impl->edit_distance = sz_levenshtein_distance_ice; + impl->alignment_score = sz_needleman_wunsch_score_ice; - impl->look_up_transform = sz_look_up_transform_ice; + impl->lookup = sz_lookup_ice; impl->bytesum = sz_bytesum_ice; impl->hash = sz_hash_ice; @@ -327,7 +327,7 @@ SZ_DYNAMIC void 
sz_dispatch_table_init(void) { impl->copy = sz_copy_neon; impl->move = sz_move_neon; impl->fill = sz_fill_neon; - impl->look_up_transform = sz_look_up_transform_neon; + impl->lookup = sz_lookup_neon; impl->bytesum = sz_bytesum_neon; impl->hash = sz_hash_neon; @@ -433,8 +433,8 @@ SZ_DYNAMIC void sz_fill(sz_ptr_t target, sz_size_t length, sz_u8_t value) { sz_dispatch_table.fill(target, length, value); } -SZ_DYNAMIC void sz_look_up_transform(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { - sz_dispatch_table.look_up_transform(source, length, lut, target); +SZ_DYNAMIC void sz_lookup(sz_ptr_t target, sz_size_t length, sz_cptr_t source, sz_cptr_t lut) { + sz_dispatch_table.lookup(target, length, source, lut); } SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) { @@ -475,43 +475,43 @@ SZ_DYNAMIC sz_size_t sz_hamming_distance_utf8( // return sz_hamming_distance_utf8_serial(a, a_length, b, b_length, bound); } -SZ_DYNAMIC sz_size_t sz_edit_distance( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // +SZ_DYNAMIC sz_size_t sz_levenshtein_distance( // + sz_cptr_t a, sz_size_t a_length, // + sz_cptr_t b, sz_size_t b_length, // sz_size_t bound, sz_memory_allocator_t *alloc) { return sz_dispatch_table.edit_distance(a, a_length, b, b_length, bound, alloc); } -SZ_DYNAMIC sz_size_t sz_edit_distance_utf8( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // +SZ_DYNAMIC sz_size_t sz_levenshtein_distance_utf8( // + sz_cptr_t a, sz_size_t a_length, // + sz_cptr_t b, sz_size_t b_length, // sz_size_t bound, sz_memory_allocator_t *alloc) { - return _sz_edit_distance_wagner_fisher_serial(a, a_length, b, b_length, bound, sz_true_k, alloc); + return _sz_levenshtein_distance_wagner_fisher_serial(a, a_length, b, b_length, bound, sz_true_k, alloc); } -SZ_DYNAMIC sz_ssize_t sz_alignment_score( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // +SZ_DYNAMIC sz_ssize_t sz_needleman_wunsch_score( // + sz_cptr_t a, sz_size_t a_length, // + sz_cptr_t b, sz_size_t b_length, // sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc) { return sz_dispatch_table.alignment_score(a, a_length, b, b_length, subs, gap, alloc); } -SZ_DYNAMIC sz_bool_t sz_sequence_argsort(sz_sequence_t const *array, sz_memory_allocator_t *alloc, sz_size_t *order) { +SZ_DYNAMIC sz_status_t sz_sequence_argsort(sz_sequence_t const *array, sz_memory_allocator_t *alloc, sz_size_t *order) { return sz_dispatch_table.sequence_argsort(array, alloc, order); } -SZ_DYNAMIC sz_bool_t sz_pgrams_sort(sz_pgram_t *array, sz_size_t count, sz_memory_allocator_t *alloc, - sz_size_t *order) { +SZ_DYNAMIC sz_status_t sz_pgrams_sort(sz_pgram_t *array, sz_size_t count, sz_memory_allocator_t *alloc, + sz_size_t *order) { return sz_dispatch_table.pgrams_sort(array, count, alloc, order); } -SZ_DYNAMIC sz_bool_t sz_sequence_argsort_stable(sz_sequence_t const *array, sz_memory_allocator_t *alloc, - sz_size_t *order) { +SZ_DYNAMIC sz_status_t sz_sequence_argsort_stable(sz_sequence_t const *array, sz_memory_allocator_t *alloc, + sz_size_t *order) { return sz_dispatch_table.sequence_argsort_stable(array, alloc, order); } -SZ_DYNAMIC sz_bool_t sz_pgrams_sort_stable(sz_pgram_t *array, sz_size_t count, sz_memory_allocator_t *alloc, - sz_size_t *order) { +SZ_DYNAMIC sz_status_t sz_pgrams_sort_stable(sz_pgram_t *array, sz_size_t count, sz_memory_allocator_t *alloc, + sz_size_t *order) { return 
sz_dispatch_table.pgrams_sort_stable(array, count, alloc, order); } diff --git a/include/stringzilla/memory.h b/include/stringzilla/memory.h index d8db210b..de739f22 100644 --- a/include/stringzilla/memory.h +++ b/include/stringzilla/memory.h @@ -3,12 +3,12 @@ * @file memory.h * @author Ash Vardanian * - * Includes: + * Includes core APIs for contiguous memory operations: * * - `sz_copy` - analog to `memcpy` * - `sz_move` - analog to `memmove` * - `sz_fill` - analog to `memset` - * - `sz_look_up_transform` - LUT transformation of a string, similar to OpenCV LUT + * - `sz_lookup` - LUT transformation of a string, similar to OpenCV LUT * - TODO: `sz_detect_encoding` - similar to `iconv` or `chardet` * * Convenience functions for character-set mapping: @@ -28,11 +28,27 @@ extern "C" { /** * @brief Similar to `memcpy`, copies contents of one string into another. - * The behavior is undefined if the strings overlap. * - * @param target String to copy into. - * @param length Number of bytes to copy. - * @param source String to copy from. + * @param[out] target String to copy into. Can be `NULL`, if the @p length is zero. + * @param[in] length Number of bytes to copy. Can be a zero. + * @param[in] source String to copy from. Can be `NULL`, if the @p length is zero. + * + * Example usage: + * + * @code{.c} + * #include + * int main() { + * char output[2]; + * sz_copy(output, "hi", 2); + * return output[0] == 'h' && output[1] == 'i' ? 0 : 1; + * } + * @endcode + * + * @pre The @p target and @p source must not overlap. + * @sa sz_move + * + * @note Selects the fastest implementation at compile- or run-time based on `SZ_DYNAMIC_DISPATCH`. + * @sa sz_copy_serial, sz_copy_haswell, sz_copy_skylake, sz_copy_neon */ SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length); @@ -40,27 +56,92 @@ SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length); * @brief Similar to `memmove`, copies (moves) contents of one string into another. * Unlike `sz_copy`, allows overlapping strings as arguments. * - * @param target String to copy into. - * @param length Number of bytes to copy. - * @param source String to copy from. + * @param[out] target String to copy into. Can be `NULL`, if the @p length is zero. + * @param[in] length Number of bytes to copy. Can be a zero. + * @param[in] source String to copy from. Can be `NULL`, if the @p length is zero. + * + * Example usage: + * + * @code{.c} + * #include + * int main() { + * char buffer[3] = {'a', 'b', 'c'}; + * sz_move(buffer, buffer + 1, 2); + * return buffer[0] == 'b' && buffer[1] == 'c' && buffer[2] == 'c' ? 0 : 1; + * } + * @endcode + * + * @note Selects the fastest implementation at compile- or run-time based on `SZ_DYNAMIC_DISPATCH`. + * @sa sz_move_serial, sz_move_haswell, sz_move_skylake, sz_move_neon */ SZ_DYNAMIC void sz_move(sz_ptr_t target, sz_cptr_t source, sz_size_t length); /** * @brief Similar to `memset`, fills a string with a given value. * - * @param target String to fill. - * @param length Number of bytes to fill. - * @param value Value to fill with. + * @param[out] target String to fill. Can be `NULL`, if the @p length is zero. + * @param[in] length Number of bytes to fill. Can be a zero. + * @param[in] value Value to fill with. + * + * Example usage: + * + * @code{.c} + * #include + * int main() { + * char buffer[2]; + * sz_fill(buffer, 2, 'x'); + * return buffer[0] == 'x' && buffer[1] == 'x' ? 
0 : 1; + * } + * @endcode + * + * @note Selects the fastest implementation at compile- or run-time based on `SZ_DYNAMIC_DISPATCH`. + * @sa sz_fill_serial, sz_fill_haswell, sz_fill_skylake, sz_fill_neon */ SZ_DYNAMIC void sz_fill(sz_ptr_t target, sz_size_t length, sz_u8_t value); +/** + * @brief Look Up Table @b (LUT) transformation of a @p source string. Same as `for (char &c : text) c = lut[c]`. + * @see https://en.wikipedia.org/wiki/Lookup_table + * + * Can be used to implement some form of string normalization, partially masking punctuation marks, + * or converting between different character sets, like uppercase or lowercase. Surprisingly, also has + * broad implications in image processing, where image channel transformations are often done using LUTs. + * + * @param[out] target Output string, can point to the same address as @p source. + * @param[in] length Number of bytes in the string. + * @param[in] source String to be mapped using the @p lut table into the @p target. + * @param[in] lut Look Up Table to apply. Must be exactly @b 256 bytes long. + * + * Example usage: + * + * @code{.c} + * #include // for `tolower` + * #include + * int main() { + * char to_lower_lut[256]; + * for (int i = 0; i < 256; ++i) to_lower_lut[i] = tolower(i); + * char buffer[3] = {'A', 'B', 'C'}; + * sz_lookup(buffer, 3, buffer, to_lower_lut); + * return buffer[0] == 'a' && buffer[1] == 'b' && buffer[2] == 'c' ? 0 : 1; + * } + * @endcode + * + * @pre The @p lut must be exactly 256 bytes long, even if the @p source string has no characters in the top range. + * @pre The @p target and @p source can be the same, but must not overlap. + * + * @note Selects the fastest implementation at compile- or run-time based on `SZ_DYNAMIC_DISPATCH`. + * @sa sz_lookup_serial, sz_lookup_haswell, sz_lookup_ice, sz_lookup_neon + */ +SZ_DYNAMIC void sz_lookup(sz_ptr_t target, sz_size_t length, sz_cptr_t source, sz_cptr_t lut); + /** @copydoc sz_copy */ SZ_PUBLIC void sz_copy_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length); /** @copydoc sz_move */ SZ_PUBLIC void sz_move_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length); /** @copydoc sz_fill */ SZ_PUBLIC void sz_fill_serial(sz_ptr_t target, sz_size_t length, sz_u8_t value); +/** @copydoc sz_lookup */ +SZ_PUBLIC void sz_lookup_serial(sz_ptr_t target, sz_size_t length, sz_cptr_t source, sz_cptr_t lut); #if SZ_USE_HASWELL /** @copydoc sz_copy */ @@ -69,6 +150,8 @@ SZ_PUBLIC void sz_copy_haswell(sz_ptr_t target, sz_cptr_t source, sz_size_t leng SZ_PUBLIC void sz_move_haswell(sz_ptr_t target, sz_cptr_t source, sz_size_t length); /** @copydoc sz_rfind_fill */ SZ_PUBLIC void sz_fill_haswell(sz_ptr_t target, sz_size_t length, sz_u8_t value); +/** @copydoc sz_lookup */ +SZ_PUBLIC void sz_lookup_haswell(sz_ptr_t target, sz_size_t length, sz_cptr_t source, sz_cptr_t lut); #endif #if SZ_USE_SKYLAKE @@ -80,6 +163,11 @@ SZ_PUBLIC void sz_move_skylake(sz_ptr_t target, sz_cptr_t source, sz_size_t leng SZ_PUBLIC void sz_fill_skylake(sz_ptr_t target, sz_size_t length, sz_u8_t value); #endif +#if SZ_USE_ICE +/** @copydoc sz_lookup */ +SZ_PUBLIC void sz_lookup_ice(sz_ptr_t target, sz_size_t length, sz_cptr_t source, sz_cptr_t lut); +#endif + #if SZ_USE_NEON /** @copydoc sz_copy */ SZ_PUBLIC void sz_copy_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length); @@ -87,25 +175,10 @@ SZ_PUBLIC void sz_copy_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length) SZ_PUBLIC void sz_move_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length); /** @copydoc sz_rfind_fill */ 
SZ_PUBLIC void sz_fill_neon(sz_ptr_t target, sz_size_t length, sz_u8_t value); +/** @copydoc sz_lookup */ +SZ_PUBLIC void sz_lookup_neon(sz_ptr_t target, sz_size_t length, sz_cptr_t source, sz_cptr_t lut); #endif -/** - * @brief Look Up Table @b (LUT) transformation of a string. Equivalent to `for (char & c : text) c = lut[c]`. - * - * Can be used to implement some form of string normalization, partially masking punctuation marks, - * or converting between different character sets, like uppercase or lowercase. Surprisingly, also has - * broad implications in image processing, where image channel transformations are often done using LUTs. - * - * @param text String to be normalized. - * @param length Number of bytes in the string. - * @param lut Look Up Table to apply. Must be exactly @b 256 bytes long. - * @param result Output string, can point to the same address as ::text. - */ -SZ_DYNAMIC void sz_look_up_transform(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result); - -/** @copydoc sz_look_up_transform */ -SZ_PUBLIC void sz_look_up_transform_serial(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result); - #pragma endregion // Core API #pragma region Helper API @@ -120,7 +193,7 @@ SZ_PUBLIC void sz_look_up_transform_serial(sz_cptr_t text, sz_size_t length, sz_ * http://0x80.pl/notesen/2016-01-06-swar-swap-case.html * * @param text String to be normalized. - * @param length Number of bytes in the string. + * @param[in] length Number of bytes in the string. * @param result Output string, can point to the same address as ::text. */ SZ_PUBLIC void sz_tolower(sz_cptr_t text, sz_size_t length, sz_ptr_t result); @@ -135,7 +208,7 @@ SZ_PUBLIC void sz_tolower(sz_cptr_t text, sz_size_t length, sz_ptr_t result); * http://0x80.pl/notesen/2016-01-06-swar-swap-case.html * * @param text String to be normalized. - * @param length Number of bytes in the string. + * @param[in] length Number of bytes in the string. * @param result Output string, can point to the same address as ::text. */ SZ_PUBLIC void sz_toupper(sz_cptr_t text, sz_size_t length, sz_ptr_t result); @@ -144,7 +217,7 @@ SZ_PUBLIC void sz_toupper(sz_cptr_t text, sz_size_t length, sz_ptr_t result); * @brief Equivalent to `for (char & c : text) c = toascii(c)`. * * @param text String to be normalized. - * @param length Number of bytes in the string. + * @param[in] length Number of bytes in the string. * @param result Output string, can point to the same address as ::text. */ SZ_PUBLIC void sz_toascii(sz_cptr_t text, sz_size_t length, sz_ptr_t result); @@ -203,7 +276,7 @@ SZ_INTERNAL sz_u8_t sz_u8_toupper(sz_u8_t c) { return upped[c]; } -SZ_PUBLIC void sz_look_up_transform_serial(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result) { +SZ_PUBLIC void sz_lookup_serial(sz_ptr_t result, sz_size_t length, sz_cptr_t text, sz_cptr_t lut) { sz_u8_t const *unsigned_lut = (sz_u8_t const *)lut; sz_u8_t const *unsigned_text = (sz_u8_t const *)text; sz_u8_t *unsigned_result = (sz_u8_t *)result; @@ -454,13 +527,13 @@ SZ_PUBLIC void sz_move_haswell(sz_ptr_t target, sz_cptr_t source, sz_size_t leng } } -SZ_PUBLIC void sz_look_up_transform_haswell(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { +SZ_PUBLIC void sz_lookup_haswell(sz_ptr_t target, sz_size_t length, sz_cptr_t source, sz_cptr_t lut) { // If the input is tiny (especially smaller than the look-up table itself), we may end up paying // more for organizing the SIMD registers and changing the CPU state, than for the actual computation. 
// But if at least 3 cache lines are touched, the AVX-2 implementation should be faster. if (length <= 128) { - sz_look_up_transform_serial(source, length, lut, target); + sz_lookup_serial(target, length, source, lut); return; } @@ -587,7 +660,7 @@ SZ_PUBLIC void sz_look_up_transform_haswell(sz_cptr_t source, sz_size_t length, } // Handle the tail. - if (length) sz_look_up_transform_serial(source, length, lut, target); + if (length) sz_lookup_serial(target, length, source, lut); } #pragma clang attribute pop @@ -838,13 +911,13 @@ SZ_PUBLIC void sz_move_skylake(sz_ptr_t target, sz_cptr_t source, sz_size_t leng #pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512dq,avx512vbmi,bmi,bmi2"))), \ apply_to = function) -SZ_PUBLIC void sz_look_up_transform_ice(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { +SZ_PUBLIC void sz_lookup_ice(sz_ptr_t target, sz_size_t length, sz_cptr_t source, sz_cptr_t lut) { // If the input is tiny (especially smaller than the look-up table itself), we may end up paying // more for organizing the SIMD registers and changing the CPU state, than for the actual computation. // But if at least 3 cache lines are touched, the AVX-512 implementation should be faster. if (length <= 128) { - sz_look_up_transform_serial(source, length, lut, target); + sz_lookup_serial(target, length, source, lut); return; } @@ -1075,12 +1148,12 @@ SZ_PUBLIC void sz_fill_neon(sz_ptr_t target, sz_size_t length, sz_u8_t value) { if (length) sz_fill_serial(target, length, value); } -SZ_PUBLIC void sz_look_up_transform_neon(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { +SZ_PUBLIC void sz_lookup_neon(sz_ptr_t target, sz_size_t length, sz_cptr_t source, sz_cptr_t lut) { // If the input is tiny (especially smaller than the look-up table itself), we may end up paying // more for organizing the SIMD registers and changing the CPU state, than for the actual computation. 
if (length <= 128) { - sz_look_up_transform_serial(source, length, lut, target); + sz_lookup_serial(target, length, source, lut); return; } @@ -1291,15 +1364,15 @@ SZ_DYNAMIC void sz_fill(sz_ptr_t target, sz_size_t length, sz_u8_t value) { #endif } -SZ_DYNAMIC void sz_look_up_transform(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { +SZ_DYNAMIC void sz_lookup(sz_ptr_t target, sz_size_t length, sz_cptr_t source, sz_cptr_t lut) { #if SZ_USE_ICE - sz_look_up_transform_ice(source, length, lut, target); + sz_lookup_ice(target, length, source, lut); #elif SZ_USE_HASWELL - sz_look_up_transform_haswell(source, length, lut, target); + sz_lookup_haswell(target, length, source, lut); #elif SZ_USE_NEON - sz_look_up_transform_neon(source, length, lut, target); + sz_lookup_neon(target, length, source, lut); #else - sz_look_up_transform_serial(source, length, lut, target); + sz_lookup_serial(target, length, source, lut); #endif } diff --git a/include/stringzilla/stringzilla.hpp b/include/stringzilla/stringzilla.hpp index 3f5466b6..24f8fc94 100644 --- a/include/stringzilla/stringzilla.hpp +++ b/include/stringzilla/stringzilla.hpp @@ -3438,7 +3438,7 @@ class basic_string { sz_ptr_t start; sz_size_t length; sz_string_range(&string_, &start, &length); - sz_look_up_transform((sz_cptr_t)start, (sz_size_t)length, (sz_cptr_t)table.raw(), (sz_ptr_t)output); + sz_lookup((sz_ptr_t)output, (sz_size_t)length, (sz_cptr_t)start, (sz_cptr_t)table.raw()); } private: @@ -3930,21 +3930,20 @@ void randomize(basic_string_slice string, sz_u64_t nonce) noexcept { * @brief Replaces @b (in-place) all characters in the string using the provided lookup table. */ template -void transform(basic_string_slice string, basic_look_up_table const &table) noexcept { +void lookup(basic_string_slice string, basic_look_up_table const &table) noexcept { static_assert(sizeof(char_type_) == 1, "The character type must be 1 byte long."); - sz_look_up_transform((sz_cptr_t)string.data(), (sz_size_t)string.size(), (sz_cptr_t)table.raw(), - (sz_ptr_t)string.data()); + sz_lookup((sz_ptr_t)string.data(), (sz_size_t)string.size(), (sz_cptr_t)string.data(), (sz_cptr_t)table.raw()); } /** * @brief Maps all characters in the current string into another buffer using the provided lookup table. 
*/ template -void transform( // +void lookup( // basic_string_slice source, basic_look_up_table const &table, char_type_ *target) noexcept { static_assert(sizeof(char_type_) == 1, "The character type must be 1 byte long."); - sz_look_up_transform((sz_cptr_t)source.data(), (sz_size_t)source.size(), (sz_cptr_t)table.raw(), (sz_ptr_t)target); + sz_lookup((sz_ptr_t)target, (sz_size_t)source.size(), (sz_cptr_t)source.data(), (sz_cptr_t)table.raw()); } /** diff --git a/scripts/bench_memory.cpp b/scripts/bench_memory.cpp index 7a9acf25..4f52c282 100644 --- a/scripts/bench_memory.cpp +++ b/scripts/bench_memory.cpp @@ -191,15 +191,15 @@ tracked_unary_functions_t transform_functions() { std::transform(slice.begin(), slice.end(), output, [](char c) { return c + 1; }); return slice.size(); })}, - {"sz_look_up_transform_serial", wrap_sz(sz_look_up_transform_serial)}, + {"sz_lookup_serial", wrap_sz(sz_lookup_serial)}, #if SZ_USE_ICE - {"sz_look_up_transform_ice", wrap_sz(sz_look_up_transform_ice)}, + {"sz_lookup_ice", wrap_sz(sz_lookup_ice)}, #endif #if SZ_USE_HASWELL - {"sz_look_up_transform_haswell", wrap_sz(sz_look_up_transform_haswell)}, + {"sz_lookup_haswell", wrap_sz(sz_lookup_haswell)}, #endif #if SZ_USE_NEON - {"sz_look_up_transform_neon", wrap_sz(sz_look_up_transform_neon)}, + {"sz_lookup_neon", wrap_sz(sz_lookup_neon)}, #endif }; return result; From 944804e32588e5a785207644e4a942c18d43a13b Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Fri, 21 Feb 2025 13:34:41 +0000 Subject: [PATCH 116/751] Break: Return error-codes in sort functions --- include/stringzilla/sort.h | 353 +++++++++++++++++++++++------------- include/stringzilla/types.h | 19 ++ 2 files changed, 245 insertions(+), 127 deletions(-) diff --git a/include/stringzilla/sort.h b/include/stringzilla/sort.h index 190074a3..a394b646 100644 --- a/include/stringzilla/sort.h +++ b/include/stringzilla/sort.h @@ -23,7 +23,6 @@ * * - `sz_sequence_argsort_with_insertion` - for string collections. * - `sz_pgrams_sort_stable_with_insertion` - for continuous unsigned integers. - * */ #ifndef STRINGZILLA_SORT_H_ #define STRINGZILLA_SORT_H_ @@ -41,113 +40,209 @@ extern "C" { /** * @brief Faster @b arg-sort for an arbitrary @b string sequence, using QuickSort. - * Outputs the ::order of elements in the immutable ::sequence, that would sort it. - * The algorithm doesn't guarantee stability, meaning that the relative order of equal elements - * may not be preserved. + * Outputs the @p order of elements in the immutable @p sequence, that would sort it. + * + * @param[in] sequence Immutable sequence of strings to sort. + * @param[in] alloc Optional memory allocator for temporary storage. + * @param[out] order Output permutation that sorts the elements. Must fit at least `sequence->count` integers. + * + * @retval `sz_success_k` if the operation was successful. + * @retval `sz_bad_alloc_k` if the operation failed due to memory allocation failure. + * @post The @p order array will contain a valid permutation of `[0, sequence->count - 1]`. + * + * Example usage: + * + * @code{.c} + * #include + * int main() { + * char const *strings[] = {"banana", "apple", "cherry"}; + * sz_sequence_t sequence; + * sz_sequence_from_null_terminated_strings(strings, 3, &sequence); + * sz_sorted_idx_t order[3]; + * sz_sequence_argsort(&sequence, NULL, order); + * return order[0] == 1 && order[1] == 0 && order[2] == 2 ? 
0 : 1; + * } + * @endcode + * + * @note The algorithm has linear memory complexity, quadratic worst-case and log-linear average time complexity. + * @see https://en.wikipedia.org/wiki/Quicksort + * + * @note This algorithm is @b unstable: equal elements may change relative order. + * @sa sz_sequence_argsort_stable * - * @param sequence The sequence of strings to sort. - * @param alloc Memory allocator for temporary storage. - * @param order The output - indices of the sorted sequence elements. - * @return Whether the operation was successful. + * @note Selects the fastest implementation at compile- or run-time based on `SZ_DYNAMIC_DISPATCH`. + * @sa sz_sequence_argsort_serial, sz_sequence_argsort_skylake, sz_sequence_argsort_sve */ -SZ_DYNAMIC sz_bool_t sz_sequence_argsort(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order); +SZ_DYNAMIC sz_status_t sz_sequence_argsort(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); /** * @brief Faster @b inplace `std::sort` for a continuous @b unsigned-integer sequence, using QuickSort. - * Overwrites the input ::sequence with the sorted sequence and exports the permutation ::order. - * The algorithm doesn't guarantee stability, meaning that the relative order of equal elements - * may not be preserved. + * Overwrites the input @p pgrams with the sorted sequence and exports the @p order permutation. + * + * @param[inout] pgrams Continuous buffer of unsigned integers to sort in place. + * @param[in] count Number of elements in the sequence. + * @param[in] alloc Optional memory allocator for temporary storage. + * @param[out] order Output permutation that sorts the elements. Must fit at least @p count integers. + * + * @retval `sz_success_k` if the operation was successful. + * @retval `sz_bad_alloc_k` if the operation failed due to memory allocation failure. + * @post The @p order array will contain a valid permutation of `[0, count - 1]`. + * + * Example usage: + * + * @code{.c} + * #include + * int main() { + * sz_pgram_t pgrams[] = {42, 17, 99, 8}; + * sz_sorted_idx_t order[4]; + * sz_pgrams_sort(pgrams, 4, NULL, order); + * return order[0] == 3 && order[1] == 1 && order[2] == 0 && order[3] == 2 ? 0 : 1; + * } + * @endcode * - * @param pgrams The continuous buffer of unsigned integers to sort in place. - * @param count The number of elements in the sequence. - * @param alloc Memory allocator for temporary storage. - * @param order The output - indices of the sorted sequence elements. - * @return Whether the operation was successful. + * @note The algorithm has linear memory complexity, quadratic worst-case and log-linear average time complexity. + * @see https://en.wikipedia.org/wiki/Quicksort + * + * @note This algorithm is @b unstable: equal elements may change relative order. + * @sa sz_pgrams_sort_stable + * + * @note Selects the fastest implementation at compile- or run-time based on `SZ_DYNAMIC_DISPATCH`. + * @sa sz_pgrams_sort_serial, sz_pgrams_sort_skylake, sz_pgrams_sort_sve */ -SZ_DYNAMIC sz_bool_t sz_pgrams_sort(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order); +SZ_DYNAMIC sz_status_t sz_pgrams_sort(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); /** * @brief Faster @b arg-sort for an arbitrary @b string sequence, using MergeSort. - * Outputs the ::order of elements in the immutable ::sequence, that would sort it. 
- * The algorithm guarantees stability, meaning that the relative order of equal elements is preserved. + * Outputs the @p order of elements in the immutable @p sequence, that would sort it. + * + * This algorithm guarantees stability, ensuring that the relative order of equal elements is preserved. + * It uses more memory than `sz_sequence_argsort`, but its performance is more predictable. + * It's preferred for very large inputs, as most memory access happens in a sequential pattern. + * + * @param[in] sequence Immutable sequence of strings to sort. + * @param[in] alloc Optional memory allocator for temporary storage. + * @param[out] order Output permutation that sorts the elements. Must fit at least `sequence->count` integers. + * + * @retval `sz_success_k` if the operation was successful. + * @retval `sz_bad_alloc_k` if the operation failed due to memory allocation failure. + * @post The @p order array will contain a valid permutation of `[0, sequence->count - 1]`. + * + * Example usage: + * + * @code{.c} + * #include + * int main() { + * char const *strings[] = {"banana", "apple", "cherry"}; + * sz_sequence_t sequence; + * sz_sequence_from_null_terminated_strings(strings, 3, &sequence); + * sz_sorted_idx_t order[3]; + * sz_sequence_argsort_stable(&sequence, NULL, order); + * return order[0] == 1 && order[1] == 0 && order[2] == 2 ? 0 : 1; + * } + * @endcode * - * This algorithm uses more memory than `sz_sequence_argsort`, but it's performance is more predictable. - * It's also preferred for very large inputs, as most memory access happens in a predictable sequential order. + * @note The algorithm has linear memory complexity and log-linear time complexity. + * @see https://en.wikipedia.org/wiki/Merge_sort * - * @param sequence The sequence of strings to sort. - * @param alloc Memory allocator for temporary storage. - * @param order The output - indices of the sorted sequence elements. - * @return Whether the operation was successful. + * @note This algorithm is @b stable: equal elements maintain their relative order. + * @sa sz_sequence_argsort + * + * @note Selects the fastest implementation at compile- or run-time based on `SZ_DYNAMIC_DISPATCH`. + * @sa sz_sequence_argsort_stable_serial, sz_sequence_argsort_stable_skylake, sz_sequence_argsort_stable_sve */ -SZ_DYNAMIC sz_bool_t sz_sequence_argsort_stable(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order); +SZ_DYNAMIC sz_status_t sz_sequence_argsort_stable(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); /** - * @brief Faster @b inplace `std::stable_sort sort` for a continuous @b unsigned-integer sequence, using MergeSort. - * Overwrites the input ::sequence with the sorted sequence and exports the permutation ::order. - * The algorithm guarantees stability, meaning that the relative order of equal elements is preserved. + * @brief Faster @b inplace `std::stable_sort` for a continuous @b unsigned-integer sequence, using MergeSort. + * Overwrites the input @p pgrams with the sorted sequence and exports the @p order permutation. + * + * This algorithm guarantees stability, ensuring that the relative order of equal elements is preserved. + * It uses more memory than `sz_pgrams_sort`, but its performance is more predictable. + * It's preferred for very large inputs, as most memory access happens in a sequential pattern. + * + * @param[inout] pgrams Continuous buffer of unsigned integers to sort in place. + * @param[in] count Number of elements in the sequence. 
+ * @param[in] alloc Optional memory allocator for temporary storage. + * @param[out] order Output permutation that sorts the elements. Must fit at least @p count integers. + * + * @retval `sz_success_k` if the operation was successful. + * @retval `sz_bad_alloc_k` if the operation failed due to memory allocation failure. + * @post The @p order array will contain a valid permutation of `[0, count - 1]`. + * + * Example usage: * - * This algorithm uses more memory than `sz_pgrams_sort`, but it's performance is more predictable. - * It's also preferred for very large inputs, as most memory access happens in a predictable sequential order. + * @code{.c} + * #include + * int main() { + * sz_pgram_t pgrams[] = {42, 17, 99, 8}; + * sz_sorted_idx_t order[4]; + * sz_pgrams_sort_stable(pgrams, 4, NULL, order); + * return order[0] == 3 && order[1] == 1 && order[2] == 0 && order[3] == 2 ? 0 : 1; + * } + * @endcode * - * @param pgrams The continuous buffer of unsigned integers to sort in place. - * @param count The number of elements in the sequence. - * @param alloc Memory allocator for temporary storage. - * @param order The output - indices of the sorted sequence elements. - * @return Whether the operation was successful. + * @note The algorithm has linear memory complexity and log-linear time complexity. + * @see [MergeSort Algorithm](https://en.wikipedia.org/wiki/Merge_sort) + * + * @note This algorithm is @b stable: equal elements maintain their relative order. + * @sa sz_pgrams_sort + * + * @note Selects the fastest implementation at compile- or run-time based on `SZ_DYNAMIC_DISPATCH`. + * @sa sz_pgrams_sort_stable_serial, sz_pgrams_sort_stable_skylake, sz_pgrams_sort_stable_sve */ -SZ_DYNAMIC sz_bool_t sz_pgrams_sort_stable(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order); +SZ_DYNAMIC sz_status_t sz_pgrams_sort_stable(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); /** @copydoc sz_sequence_argsort */ -SZ_PUBLIC sz_bool_t sz_sequence_argsort_serial(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order); +SZ_PUBLIC sz_status_t sz_sequence_argsort_serial(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); /** @copydoc sz_pgrams_sort */ -SZ_PUBLIC sz_bool_t sz_pgrams_sort_serial(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order); +SZ_PUBLIC sz_status_t sz_pgrams_sort_serial(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); /** @copydoc sz_sequence_argsort */ -SZ_PUBLIC sz_bool_t sz_sequence_argsort_ice(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order); +SZ_PUBLIC sz_status_t sz_sequence_argsort_skylake(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); /** @copydoc sz_pgrams_sort */ -SZ_PUBLIC sz_bool_t sz_pgrams_sort_ice(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order); +SZ_PUBLIC sz_status_t sz_pgrams_sort_skylake(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); /** @copydoc sz_sequence_argsort */ -SZ_PUBLIC sz_bool_t sz_sequence_argsort_sve(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order); +SZ_PUBLIC sz_status_t sz_sequence_argsort_sve(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); /** @copydoc sz_pgrams_sort */ 
-SZ_PUBLIC sz_bool_t sz_pgrams_sort_sve(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order); +SZ_PUBLIC sz_status_t sz_pgrams_sort_sve(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); /** @copydoc sz_sequence_argsort_stable */ -SZ_PUBLIC sz_bool_t sz_sequence_argsort_stable_serial(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order); +SZ_PUBLIC sz_status_t sz_sequence_argsort_stable_serial(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); /** @copydoc sz_pgrams_sort_stable */ -SZ_PUBLIC sz_bool_t sz_pgrams_sort_stable_serial(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order); +SZ_PUBLIC sz_status_t sz_pgrams_sort_stable_serial(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); /** @copydoc sz_sequence_argsort_stable */ -SZ_PUBLIC sz_bool_t sz_sequence_argsort_stable_ice(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order); +SZ_PUBLIC sz_status_t sz_sequence_argsort_stable_skylake(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); /** @copydoc sz_pgrams_sort_stable */ -SZ_PUBLIC sz_bool_t sz_pgrams_sort_stable_ice(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order); +SZ_PUBLIC sz_status_t sz_pgrams_sort_stable_skylake(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); /** @copydoc sz_sequence_argsort_stable */ -SZ_PUBLIC sz_bool_t sz_sequence_argsort_stable_sve(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order); +SZ_PUBLIC sz_status_t sz_sequence_argsort_stable_sve(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); /** @copydoc sz_pgrams_sort_stable */ -SZ_PUBLIC sz_bool_t sz_pgrams_sort_stable_sve(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order); +SZ_PUBLIC sz_status_t sz_pgrams_sort_stable_sve(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); #pragma endregion @@ -557,8 +652,8 @@ SZ_PUBLIC void _sz_sequence_argsort_serial_next_pgrams( // } } -SZ_PUBLIC sz_bool_t sz_sequence_argsort_serial(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order) { +SZ_PUBLIC sz_status_t sz_sequence_argsort_serial(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order) { // First, initialize the `order` with `std::iota`-like behavior. for (sz_size_t i = 0; i != sequence->count; ++i) order[i] = i; @@ -567,7 +662,7 @@ SZ_PUBLIC sz_bool_t sz_sequence_argsort_serial(sz_sequence_t const *sequence, sz // without any smart optimizations or memory allocations. if (sequence->count <= 32) { sz_sequence_argsort_with_insertion(sequence, order); - return sz_true_k; + return sz_success_k; } // One of the reasons for slow string operations is the significant overhead of branching when performing @@ -583,24 +678,24 @@ SZ_PUBLIC sz_bool_t sz_sequence_argsort_serial(sz_sequence_t const *sequence, sz // iteration of a recursive algorithm. sz_size_t memory_usage = sequence->count * sizeof(sz_pgram_t); sz_pgram_t *pgrams = (sz_pgram_t *)alloc->allocate(memory_usage, alloc); - if (!pgrams) return sz_false_k; + if (!pgrams) return sz_bad_alloc_k; // Recursively sort the whole sequence. 
_sz_sequence_argsort_serial_next_pgrams(sequence, pgrams, order, 0, sequence->count, 0); // Free temporary storage. alloc->free(pgrams, memory_usage, alloc); - return sz_true_k; + return sz_success_k; } -SZ_PUBLIC sz_bool_t sz_pgrams_sort_serial(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order) { +SZ_PUBLIC sz_status_t sz_pgrams_sort_serial(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order) { sz_unused(alloc); // First, initialize the `order` with `std::iota`-like behavior. for (sz_size_t i = 0; i != count; ++i) order[i] = i; // Reuse the string sorting algorithm for sorting the "pgrams". _sz_sequence_argsort_serial_recursively((sz_pgram_t *)pgrams, order, 0, count); - return sz_true_k; + return sz_success_k; } #pragma endregion // Serial QuickSort Implementation @@ -658,8 +753,8 @@ SZ_INTERNAL void _sz_sequence_argsort_stable_serial_merge( _sz_assert(merged_begin[i - 1] <= merged_begin[i] && "The merged pgrams must be in ascending order."); } -SZ_PUBLIC sz_bool_t sz_pgrams_sort_stable_serial(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order) { +SZ_PUBLIC sz_status_t sz_pgrams_sort_stable_serial(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order) { // First, initialize the `order` with `std::iota`-like behavior. for (sz_size_t i = 0; i != count; ++i) order[i] = i; @@ -668,7 +763,7 @@ SZ_PUBLIC sz_bool_t sz_pgrams_sort_stable_serial(sz_pgram_t *pgrams, sz_size_t c // without any smart optimizations or memory allocations. if (count <= 32) { sz_pgrams_sort_stable_with_insertion(pgrams, count, order); - return sz_true_k; + return sz_success_k; } // Go through short chunks of 8 elements and sort them with a sorting network. @@ -686,7 +781,7 @@ SZ_PUBLIC sz_bool_t sz_pgrams_sort_stable_serial(sz_pgram_t *pgrams, sz_size_t c sz_size_t memory_usage = sizeof(sz_pgram_t) * count + sizeof(sz_sorted_idx_t) * count; sz_pgram_t *pgrams_temporary = (sz_pgram_t *)alloc->allocate(memory_usage, alloc); sz_sorted_idx_t *order_temporary = (sz_sorted_idx_t *)(pgrams_temporary + count); - if (!pgrams_temporary) return sz_false_k; + if (!pgrams_temporary) return sz_bad_alloc_k; // Set initial run size (the sorted chunks). sz_size_t run_size = 8; @@ -731,7 +826,13 @@ SZ_PUBLIC sz_bool_t sz_pgrams_sort_stable_serial(sz_pgram_t *pgrams, sz_size_t c // Free the temporary memory used for merging. alloc->free(pgrams_temporary, memory_usage, alloc); - return sz_true_k; + return sz_success_k; +} + +SZ_PUBLIC sz_status_t sz_sequence_argsort_stable_serial(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order) { + + return sz_success_k; } #pragma endregion // Serial MergeSort Implementation @@ -744,19 +845,17 @@ SZ_PUBLIC sz_bool_t sz_pgrams_sort_stable_serial(sz_pgram_t *pgrams, sz_size_t c * * We are going to use VBMI2 for `_mm256_maskz_compress_epi8`. 
*/ -#pragma region Ice Lake Implementation -#if SZ_USE_ICE +#pragma region Skylake Implementation +#if SZ_USE_SKYLAKE #pragma GCC push_options -#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vbmi", "avx512vbmi2", "bmi", "bmi2") -#pragma clang attribute push( \ - __attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512dq,avx512vbmi,avx512vbmi2,bmi,bmi2"))), \ - apply_to = function) +#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "bmi", "bmi2") +#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,bmi,bmi2"))), apply_to = function) /** * @brief The most important part of the QuickSort algorithm partitioning the elements around the pivot. * Unlike the serial algorithm, uses compressed stores to filter and move the elements around the pivot. */ -SZ_INTERNAL void _sz_sequence_argsort_ice_3way_partition( // +SZ_INTERNAL void _sz_sequence_argsort_skylake_3way_partition( // sz_pgram_t *const initial_pgrams, sz_sorted_idx_t *const initial_order, // sz_pgram_t *const partitioned_pgrams, sz_sorted_idx_t *const partitioned_order, // sz_size_t const start_in_sequence, sz_size_t const end_in_sequence, // @@ -850,10 +949,10 @@ SZ_INTERNAL void _sz_sequence_argsort_ice_3way_partition( } /** - * @brief Recursive Quick-Sort implementation backing both the `sz_sequence_argsort_ice` and `sz_pgrams_sort_ice`, - * and using the `_sz_sequence_argsort_ice_3way_partition` under the hood. + * @brief Recursive Quick-Sort implementation backing both the `sz_sequence_argsort_skylake` and + * `sz_pgrams_sort_skylake`, and using the `_sz_sequence_argsort_skylake_3way_partition` under the hood. */ -SZ_INTERNAL void _sz_sequence_argsort_ice_recursively( // +SZ_INTERNAL void _sz_sequence_argsort_skylake_recursively( // sz_pgram_t *initial_pgrams, sz_sorted_idx_t *initial_order, // sz_pgram_t *temporary_pgrams, sz_sorted_idx_t *temporary_order, // sz_size_t const start_in_sequence, sz_size_t const end_in_sequence) { @@ -870,24 +969,24 @@ SZ_INTERNAL void _sz_sequence_argsort_ice_recursively( // // Partition the collection around some pivot sz_size_t first_pivot_index, last_pivot_index; - _sz_sequence_argsort_ice_3way_partition( // + _sz_sequence_argsort_skylake_3way_partition( // initial_pgrams, initial_order, temporary_pgrams, temporary_order, // start_in_sequence, end_in_sequence, // &first_pivot_index, &last_pivot_index); // Recursively sort the left and right partitions, if there are at least 2 elements in each if (start_in_sequence + 1 < first_pivot_index) - _sz_sequence_argsort_ice_recursively( // + _sz_sequence_argsort_skylake_recursively( // initial_pgrams, initial_order, temporary_pgrams, temporary_order, // start_in_sequence, first_pivot_index); if (last_pivot_index + 2 < end_in_sequence) - _sz_sequence_argsort_ice_recursively( // + _sz_sequence_argsort_skylake_recursively( // initial_pgrams, initial_order, temporary_pgrams, temporary_order, // last_pivot_index + 1, end_in_sequence); } -SZ_PUBLIC sz_bool_t sz_pgrams_sort_ice(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order) { +SZ_PUBLIC sz_status_t sz_pgrams_sort_skylake(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order) { // First, initialize the `order` with `std::iota`-like behavior. 
for (sz_size_t i = 0; i != count; ++i) order[i] = i; @@ -896,14 +995,14 @@ SZ_PUBLIC sz_bool_t sz_pgrams_sort_ice(sz_pgram_t *pgrams, sz_size_t count, sz_m sz_size_t memory_usage = sizeof(sz_pgram_t) * count + sizeof(sz_sorted_idx_t) * count; sz_pgram_t *temporary_pgrams = (sz_pgram_t *)alloc->allocate(memory_usage, alloc); sz_sorted_idx_t *temporary_order = (sz_sorted_idx_t *)(temporary_pgrams + count); - if (!temporary_pgrams) return sz_false_k; + if (!temporary_pgrams) return sz_bad_alloc_k; // Reuse the string sorting algorithm for sorting the "pgrams". - _sz_sequence_argsort_ice_recursively(pgrams, order, temporary_pgrams, temporary_order, 0, count); + _sz_sequence_argsort_skylake_recursively(pgrams, order, temporary_pgrams, temporary_order, 0, count); // Deallocate the temporary memory used for partitioning. alloc->free(temporary_pgrams, memory_usage, alloc); - return sz_true_k; + return sz_success_k; } /** @@ -911,7 +1010,7 @@ SZ_PUBLIC sz_bool_t sz_pgrams_sort_ice(sz_pgram_t *pgrams, sz_size_t count, sz_m * It combines `_sz_sequence_argsort_serial_export_next_pgrams` and `_sz_sequence_argsort_serial_recursively`, * recursively diving into the identical pgrams. */ -SZ_PUBLIC void _sz_sequence_argsort_ice_next_pgrams( // +SZ_PUBLIC void _sz_sequence_argsort_skylake_next_pgrams( // sz_sequence_t const *const sequence, // sz_pgram_t *const global_pgrams, sz_sorted_idx_t *const global_order, // sz_pgram_t *const temporary_pgrams, sz_sorted_idx_t *const temporary_order, // @@ -923,8 +1022,8 @@ SZ_PUBLIC void _sz_sequence_argsort_ice_next_pgrams( end_in_sequence, start_character); // Sort current pgrams with a quicksort - _sz_sequence_argsort_ice_recursively(global_pgrams, global_order, temporary_pgrams, temporary_order, - start_in_sequence, end_in_sequence); + _sz_sequence_argsort_skylake_recursively(global_pgrams, global_order, temporary_pgrams, temporary_order, + start_in_sequence, end_in_sequence); // Depending on the architecture, we will export a different number of bytes. // On 32-bit architectures, we will export 3 bytes, and on 64-bit architectures - 7 bytes. @@ -944,17 +1043,17 @@ SZ_PUBLIC void _sz_sequence_argsort_ice_next_pgrams( int has_multiple_strings = nested_end - nested_start > 1; int has_more_characters_in_each = current_pgram_length == pgram_capacity; if (has_multiple_strings && has_more_characters_in_each) { - _sz_sequence_argsort_ice_next_pgrams(sequence, global_pgrams, global_order, temporary_pgrams, - temporary_order, nested_start, nested_end, - start_character + pgram_capacity); + _sz_sequence_argsort_skylake_next_pgrams(sequence, global_pgrams, global_order, temporary_pgrams, + temporary_order, nested_start, nested_end, + start_character + pgram_capacity); } // Move to the next nested_start = nested_end; } } -SZ_PUBLIC sz_bool_t sz_sequence_argsort_ice(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order) { +SZ_PUBLIC sz_status_t sz_sequence_argsort_skylake(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order) { // First, initialize the `order` with `std::iota`-like behavior. sz_size_t count = sequence->count; @@ -964,7 +1063,7 @@ SZ_PUBLIC sz_bool_t sz_sequence_argsort_ice(sz_sequence_t const *sequence, sz_me // without any smart optimizations or memory allocations. if (count <= 32) { sz_sequence_argsort_with_insertion(sequence, order); - return sz_true_k; + return sz_success_k; } // Allocate memory for partitioning the elements around the pivot. 
@@ -972,20 +1071,20 @@ SZ_PUBLIC sz_bool_t sz_sequence_argsort_ice(sz_sequence_t const *sequence, sz_me sz_pgram_t *global_pgrams = (sz_pgram_t *)alloc->allocate(memory_usage, alloc); sz_pgram_t *temporary_pgrams = global_pgrams + count; sz_sorted_idx_t *temporary_order = (sz_sorted_idx_t *)(temporary_pgrams + count); - if (!global_pgrams) return sz_false_k; + if (!global_pgrams) return sz_bad_alloc_k; // Recursively sort the whole sequence. - _sz_sequence_argsort_ice_next_pgrams(sequence, global_pgrams, order, temporary_pgrams, temporary_order, // - 0, count, 0); + _sz_sequence_argsort_skylake_next_pgrams(sequence, global_pgrams, order, temporary_pgrams, temporary_order, // + 0, count, 0); // Free temporary storage. alloc->free(global_pgrams, memory_usage, alloc); - return sz_true_k; + return sz_success_k; } #pragma clang attribute pop #pragma GCC pop_options -#endif // SZ_USE_ICE +#endif // SZ_USE_SKYLAKE #pragma endregion // Ice Lake Implementation /* Pick the right implementation for the string search algorithms. @@ -994,10 +1093,10 @@ SZ_PUBLIC sz_bool_t sz_sequence_argsort_ice(sz_sequence_t const *sequence, sz_me #pragma region Compile Time Dispatching #if !SZ_DYNAMIC_DISPATCH -SZ_DYNAMIC sz_bool_t sz_sequence_argsort(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order) { -#if SZ_USE_ICE - return sz_sequence_argsort_ice(sequence, alloc, order); +SZ_DYNAMIC sz_status_t sz_sequence_argsort(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order) { +#if SZ_USE_SKYLAKE + return sz_sequence_argsort_skylake(sequence, alloc, order); #elif SZ_USE_SVE return sz_sequence_argsort_sve(sequence, alloc, order); #else @@ -1005,10 +1104,10 @@ SZ_DYNAMIC sz_bool_t sz_sequence_argsort(sz_sequence_t const *sequence, sz_memor #endif } -SZ_DYNAMIC sz_bool_t sz_pgrams_sort(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order) { -#if SZ_USE_ICE - return sz_pgrams_sort_ice(pgrams, count, alloc, order); +SZ_DYNAMIC sz_status_t sz_pgrams_sort(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order) { +#if SZ_USE_SKYLAKE + return sz_pgrams_sort_skylake(pgrams, count, alloc, order); #elif SZ_USE_SVE return sz_pgrams_sort_sve(pgrams, count, alloc, order); #else @@ -1016,10 +1115,10 @@ SZ_DYNAMIC sz_bool_t sz_pgrams_sort(sz_pgram_t *pgrams, sz_size_t count, sz_memo #endif } -SZ_DYNAMIC sz_bool_t sz_sequence_argsort_stable(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order) { -#if SZ_USE_ICE - return sz_sequence_argsort_ice(sequence, alloc, order); +SZ_DYNAMIC sz_status_t sz_sequence_argsort_stable(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order) { +#if SZ_USE_SKYLAKE + return sz_sequence_argsort_skylake(sequence, alloc, order); #elif SZ_USE_SVE return sz_sequence_argsort_sve(sequence, alloc, order); #else @@ -1027,10 +1126,10 @@ SZ_DYNAMIC sz_bool_t sz_sequence_argsort_stable(sz_sequence_t const *sequence, s #endif } -SZ_DYNAMIC sz_bool_t sz_pgrams_sort_stable(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order) { -#if SZ_USE_ICE - return sz_pgrams_sort_ice(pgrams, count, alloc, order); +SZ_DYNAMIC sz_status_t sz_pgrams_sort_stable(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order) { +#if SZ_USE_SKYLAKE + return sz_pgrams_sort_skylake(pgrams, count, alloc, order); #elif SZ_USE_SVE return sz_pgrams_sort_sve(pgrams, count, 
alloc, order); #else diff --git a/include/stringzilla/types.h b/include/stringzilla/types.h index 825e36f9..09b44eff 100644 --- a/include/stringzilla/types.h +++ b/include/stringzilla/types.h @@ -331,6 +331,25 @@ typedef sz_size_t sz_pgram_t; // "Pointer-sized N-gram" of a string typedef enum { sz_false_k = 0, sz_true_k = 1 } sz_bool_t; // Only one relevant bit typedef enum { sz_less_k = -1, sz_equal_k = 0, sz_greater_k = 1 } sz_ordering_t; // Only three possible states: <=> +/** + * @brief Describes an error status of a function. + */ +typedef enum { + /** + * For algorithms that return a status, this status indicates that the operation was successful. + */ + sz_success_k = 0, + /** + * For algorithms that require memory allocation, this status indicates that the allocation failed. + */ + sz_bad_alloc_k = -1, + /** + * For algorithms that have an upper bound on some parameter, like the maximum number of iterations, + * or the maximum edit distance, this status indicates that the limit was reached. + */ + sz_reached_limit_k = -2, +} sz_status_t; + /** * @brief Describes the length of a UTF8 @b rune / character / codepoint in bytes. */ From dc7c109ef088faa015bf2135211c9d7978d8ff54 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Fri, 21 Feb 2025 13:57:00 +0000 Subject: [PATCH 117/751] Add: Missing `sz_sequence_t` helpers --- include/stringzilla/types.h | 112 +++++++++++++++++++++--------------- scripts/test.cpp | 15 ++++- 2 files changed, 80 insertions(+), 47 deletions(-) diff --git a/include/stringzilla/types.h b/include/stringzilla/types.h index 09b44eff..75d76e61 100644 --- a/include/stringzilla/types.h +++ b/include/stringzilla/types.h @@ -461,68 +461,71 @@ SZ_PUBLIC void sz_memory_allocator_init_fixed(sz_memory_allocator_t *alloc, void #pragma region API Signature Types -/** @brief Signature of ::sz_hash. */ +/** @brief Signature of `sz_hash`. */ typedef sz_u64_t (*sz_hash_t)(sz_cptr_t, sz_size_t, sz_u64_t); -/** @brief Signature of ::sz_hash_state_init. */ +/** @brief Signature of `sz_hash_state_init`. */ typedef void (*sz_hash_state_init_t)(struct sz_hash_state_t *, sz_u64_t); -/** @brief Signature of ::sz_hash_state_stream. */ +/** @brief Signature of `sz_hash_state_stream`. */ typedef void (*sz_hash_state_stream_t)(struct sz_hash_state_t *, sz_cptr_t, sz_size_t); -/** @brief Signature of ::sz_hash_state_fold. */ +/** @brief Signature of `sz_hash_state_fold`. */ typedef sz_u64_t (*sz_hash_state_fold_t)(struct sz_hash_state_t const *); -/** @brief Signature of ::sz_bytesum. */ +/** @brief Signature of `sz_bytesum`. */ typedef sz_u64_t (*sz_bytesum_t)(sz_cptr_t, sz_size_t); -/** @brief Signature of ::sz_generate. */ +/** @brief Signature of `sz_generate`. */ typedef void (*sz_generate_t)(sz_ptr_t, sz_size_t, sz_u64_t); -/** @brief Signature of ::sz_equal. */ +/** @brief Signature of `sz_equal`. */ typedef sz_bool_t (*sz_equal_t)(sz_cptr_t, sz_cptr_t, sz_size_t); -/** @brief Signature of ::sz_order. */ +/** @brief Signature of `sz_order`. */ typedef sz_ordering_t (*sz_order_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); -/** @brief Signature of ::sz_look_up_transform. */ -typedef void (*sz_look_up_transform_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_ptr_t); +/** @brief Signature of `sz_lookup`. */ +typedef void (*sz_lookup_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_ptr_t); -/** @brief Signature of ::sz_move. */ +/** @brief Signature of `sz_move`. */ typedef void (*sz_move_t)(sz_ptr_t, sz_cptr_t, sz_size_t); -/** @brief Signature of ::sz_fill. 
*/ +/** @brief Signature of `sz_fill`. */ typedef void (*sz_fill_t)(sz_ptr_t, sz_size_t, sz_u8_t); -/** @brief Signature of ::sz_find_byte. */ +/** @brief Signature of `sz_find_byte`. */ typedef sz_cptr_t (*sz_find_byte_t)(sz_cptr_t, sz_size_t, sz_cptr_t); -/** @brief Signature of ::sz_find. */ +/** @brief Signature of `sz_find`. */ typedef sz_cptr_t (*sz_find_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); -/** @brief Signature of ::sz_find_set. */ +/** @brief Signature of `sz_find_set`. */ typedef sz_cptr_t (*sz_find_set_t)(sz_cptr_t, sz_size_t, sz_charset_t const *); -/** @brief Signature of ::sz_hamming_distance. */ -typedef sz_size_t (*sz_hamming_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t); +/** @brief Signature of `sz_hamming_distance`. */ +typedef sz_status_t (*sz_hamming_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t, sz_size_t *); -/** @brief Signature of ::sz_edit_distance. */ -typedef sz_size_t (*sz_edit_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t, sz_memory_allocator_t *); +/** @brief Signature of `sz_levenshtein_distance`. */ +typedef sz_status_t (*sz_levenshtein_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t, + sz_memory_allocator_t *, sz_size_t *); -/** @brief Signature of ::sz_alignment_score. */ -typedef sz_ssize_t (*sz_alignment_score_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_error_cost_t const *, - sz_error_cost_t, sz_memory_allocator_t *); +/** @brief Signature of `sz_needleman_wunsch_score`. */ +typedef sz_status_t (*sz_needleman_wunsch_score_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_error_cost_t const *, + sz_error_cost_t, sz_memory_allocator_t *, sz_ssize_t *); -/** @brief Signature of ::sz_sequence_argsort. */ -typedef sz_bool_t (*sz_sequence_argsort_t)(struct sz_sequence_t const *, sz_memory_allocator_t *, sz_sorted_idx_t *); +/** @brief Signature of `sz_sequence_argsort`. */ +typedef sz_status_t (*sz_sequence_argsort_t)(struct sz_sequence_t const *, sz_memory_allocator_t *, sz_sorted_idx_t *, + sz_bool_t *); -/** @brief Signature of ::sz_pgrams_sort. */ -typedef sz_bool_t (*sz_pgrams_sort_t)(sz_pgram_t *, sz_size_t, sz_memory_allocator_t *, sz_sorted_idx_t *); +/** @brief Signature of `sz_pgrams_sort`. */ +typedef sz_status_t (*sz_pgrams_sort_t)(sz_pgram_t *, sz_size_t, sz_memory_allocator_t *, sz_sorted_idx_t *, + sz_bool_t *); -/** @brief Signature of ::sz_sequence_argsort_stable. */ +/** @brief Signature of `sz_sequence_argsort_stable`. */ typedef sz_sequence_argsort_t sz_sequence_argsort_stable_t; -/** @brief Signature of ::sz_pgrams_sort_stable. */ +/** @brief Signature of `sz_pgrams_sort_stable`. */ typedef sz_pgrams_sort_t sz_pgrams_sort_stable_t; #pragma endregion @@ -683,9 +686,17 @@ SZ_INTERNAL sz_size_t _sz_export_utf8_to_utf32(sz_cptr_t utf8, sz_size_t utf8_le #pragma region String Sequences API -typedef sz_cptr_t (*sz_sequence_member_start_t)(struct sz_sequence_t const *, sz_size_t); -typedef sz_size_t (*sz_sequence_member_length_t)(struct sz_sequence_t const *, sz_size_t); +/** @brief Signature of `sz_sequence_t::get_start` used to get the start of a member string at a given index. */ +typedef sz_cptr_t (*sz_sequence_member_start_t)(void const *, sz_size_t); +/** @brief Signature of `sz_sequence_t::get_length` used to get the length of a member string at a given index. */ +typedef sz_size_t (*sz_sequence_member_length_t)(void const *, sz_size_t); +/** + * @brief Structure to represent an ordered collection of strings. 
+ * It's a generic structure that can be used to represent a sequence of strings in different layouts. + * It can be easily combined with Apache Arrow and its tape-like concatenated strings. + * @sa sz_sequence_from_null_terminated_strings + */ typedef struct sz_sequence_t { void const *handle; sz_size_t count; @@ -694,20 +705,12 @@ typedef struct sz_sequence_t { } sz_sequence_t; /** - * @brief Initiates the sequence structure from a tape layout, used by Apache Arrow. - * Expects ::offsets to contains `count + 1` entries, the last pointing at the end - * of the last string, indicating the total length of the ::tape. - */ -SZ_PUBLIC void sz_sequence_from_u32tape( // - sz_cptr_t *start, sz_u32_t const *offsets, sz_size_t count, sz_sequence_t *sequence); - -/** - * @brief Initiates the sequence structure from a tape layout, used by Apache Arrow. - * Expects ::offsets to contains `count + 1` entries, the last pointing at the end - * of the last string, indicating the total length of the ::tape. + * @brief Initiates the sequence structure from a typical C-style strings array, like `char *[]`. + * @param[in] start Pointer to the array of strings. + * @param[in] count Number of strings in the array. + * @param[out] sequence Sequence structure to initialize. */ -SZ_PUBLIC void sz_sequence_from_u64tape( // - sz_cptr_t *start, sz_u64_t const *offsets, sz_size_t count, sz_sequence_t *sequence); +SZ_PUBLIC void sz_sequence_from_null_terminated_strings(sz_cptr_t *start, sz_size_t count, sz_sequence_t *sequence); #pragma endregion @@ -857,7 +860,7 @@ SZ_INTERNAL sz_u32_t sz_u32_bytes_reverse(sz_u32_t val) { return __builtin_bswap SZ_INTERNAL sz_u64_t sz_u64_rotl(sz_u64_t x, sz_u64_t r) { return (x << r) | (x >> (64 - r)); } /** - * @brief Select bits from either ::a or ::b depending on the value of ::mask bits. + * @brief Select bits from either @p a or @p b depending on the value of @p mask bits. * * Similar to `_mm_blend_epi16` intrinsic on x86. * Described in the "Bit Twiddling Hacks" by Sean Eron Anderson. @@ -987,7 +990,7 @@ SZ_INTERNAL sz_size_t sz_size_log2i_nonzero(sz_size_t x) { } /** - * @brief Compute the smallest power of two greater than or equal to ::x. + * @brief Compute the smallest power of two greater than or equal to @p x. */ SZ_INTERNAL sz_size_t sz_size_bit_ceil(sz_size_t x) { // Unlike the commonly used trick with `clz` intrinsics, is valid across the whole range of `x`. 
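
A minimal usage sketch of the sequence view introduced above, assuming only the `sz_sequence_t` layout and the `sz_sequence_from_null_terminated_strings` declaration shown in this hunk; the function and variable names below are illustrative, not part of the library:

    #include <stringzilla/types.h> /* assumed include path for the declarations above */

    static void example_wrap_c_strings(void) {
        /* Three NULL-terminated strings owned by the caller. */
        sz_cptr_t strings[3] = {"banana", "apple", "cherry"};

        /* Build a non-owning view; nothing is allocated or copied. */
        sz_sequence_t sequence;
        sz_sequence_from_null_terminated_strings(strings, 3, &sequence);

        /* Members are reached through the `get_start` / `get_length` callbacks. */
        sz_cptr_t first = sequence.get_start(sequence.handle, 0);         /* -> "banana" */
        sz_size_t first_length = sequence.get_length(sequence.handle, 0); /* -> 6 */
        (void)first;
        (void)first_length;
    }
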
@@ -1149,6 +1152,25 @@ SZ_PUBLIC void sz_memory_allocator_init_fixed(sz_memory_allocator_t *alloc, void
     *(sz_ptr_t)buffer = *(sz_cptr_t)&length;
 }
 
+SZ_PUBLIC sz_cptr_t _sz_sequence_from_null_terminated_strings_get_start(void const *handle, sz_size_t i) {
+    sz_cptr_t const *start = (sz_cptr_t const *)handle;
+    return start[i];
+}
+
+SZ_PUBLIC sz_size_t _sz_sequence_from_null_terminated_strings_get_length(void const *handle, sz_size_t i) {
+    sz_cptr_t const *start = (sz_cptr_t const *)handle;
+    sz_size_t length = 0;
+    for (sz_cptr_t ptr = start[i]; *ptr; ptr++) length++;
+    return length;
+}
+
+SZ_PUBLIC void sz_sequence_from_null_terminated_strings(sz_cptr_t *start, sz_size_t count, sz_sequence_t *sequence) {
+    sequence->handle = start;
+    sequence->count = count;
+    sequence->get_start = _sz_sequence_from_null_terminated_strings_get_start;
+    sequence->get_length = _sz_sequence_from_null_terminated_strings_get_length;
+}
+
 #pragma endregion
 
 #ifdef __cplusplus
diff --git a/scripts/test.cpp b/scripts/test.cpp
index 74282523..e3a62f3d 100644
--- a/scripts/test.cpp
+++ b/scripts/test.cpp
@@ -1576,8 +1576,8 @@ void test_replacements(std::size_t lookup_tables_to_try = 128, std::size_t slice
         std::size_t slice_offset = std::rand() % (body.length());
         std::size_t slice_length = std::rand() % (body.length() - slice_offset);
 
-        sz::transform(sz::string_view(body.data() + slice_offset, slice_length), lut,
-                      const_cast(transformed.data()) + slice_offset);
+        sz::lookup(sz::string_view(body.data() + slice_offset, slice_length), lut,
+                   const_cast(transformed.data()) + slice_offset);
         for (std::size_t i = 0; i != slice_length; ++i) {
             assert(transformed[slice_offset + i] == lut[body[slice_offset + i]]);
         }
@@ -1592,6 +1592,17 @@ static void test_sequence_algorithms() {
     using strs_t = std::vector;
     using order_t = std::vector;
 
+    // Make sure the helper functions work as expected.
+    {
+        sz_sequence_t sequence;
+        sz_cptr_t strings[] = {"banana", "apple", "cherry"};
+        sz_sequence_from_null_terminated_strings(strings, 3, &sequence);
+        assert(sequence.count == 3);
+        assert(sequence.get_start(sequence.handle, 0) == "banana"_sv);
+        assert(sequence.get_start(sequence.handle, 1) == "apple"_sv);
+        assert(sequence.get_start(sequence.handle, 2) == "cherry"_sv);
+    }
+
     // Basic tests with predetermined orders.
     assert_scoped(strs_t x({"a", "b", "c", "d"}), (void)0, sz::argsort(x) == order_t({0u, 1u, 2u, 3u}));
     assert_scoped(strs_t x({"b", "c", "d", "a"}), (void)0, sz::argsort(x) == order_t({3u, 0u, 1u, 2u}));

From 5a12c00d1dd6794da06f5267f01519394a1b49cf Mon Sep 17 00:00:00 2001
From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com>
Date: Fri, 21 Feb 2025 14:09:31 +0000
Subject: [PATCH 118/751] Improve: Use default allocator, when not provided

---
 include/stringzilla/sort.h | 28 ++++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/include/stringzilla/sort.h b/include/stringzilla/sort.h
index a394b646..c422e3c3 100644
--- a/include/stringzilla/sort.h
+++ b/include/stringzilla/sort.h
@@ -665,6 +665,13 @@ SZ_PUBLIC sz_status_t sz_sequence_argsort_serial(sz_sequence_t const *sequence,
         return sz_success_k;
     }
 
+    // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome.
+    sz_memory_allocator_t global_alloc;
+    if (!alloc) {
+        sz_memory_allocator_init_default(&global_alloc);
+        alloc = &global_alloc;
+    }
+
     // One of the reasons for slow string operations is the significant overhead of branching when performing
    // individual string comparisons.
// @@ -773,6 +780,13 @@ SZ_PUBLIC sz_status_t sz_pgrams_sort_stable_serial(sz_pgram_t *pgrams, sz_size_t sz_size_t const tail_count = count & 7u; sz_pgrams_sort_stable_with_insertion(pgrams + count - tail_count, tail_count, order + count - tail_count); + // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. + sz_memory_allocator_t global_alloc; + if (!alloc) { + sz_memory_allocator_init_default(&global_alloc); + alloc = &global_alloc; + } + // At this point, the array is partitioned into sorted runs. // We'll now merge these runs until the whole array is sorted. // Allocate temporary memory to hold merged results: @@ -991,6 +1005,13 @@ SZ_PUBLIC sz_status_t sz_pgrams_sort_skylake(sz_pgram_t *pgrams, sz_size_t count // First, initialize the `order` with `std::iota`-like behavior. for (sz_size_t i = 0; i != count; ++i) order[i] = i; + // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. + sz_memory_allocator_t global_alloc; + if (!alloc) { + sz_memory_allocator_init_default(&global_alloc); + alloc = &global_alloc; + } + // Allocate memory for partitioning the elements around the pivot. sz_size_t memory_usage = sizeof(sz_pgram_t) * count + sizeof(sz_sorted_idx_t) * count; sz_pgram_t *temporary_pgrams = (sz_pgram_t *)alloc->allocate(memory_usage, alloc); @@ -1066,6 +1087,13 @@ SZ_PUBLIC sz_status_t sz_sequence_argsort_skylake(sz_sequence_t const *sequence, return sz_success_k; } + // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. + sz_memory_allocator_t global_alloc; + if (!alloc) { + sz_memory_allocator_init_default(&global_alloc); + alloc = &global_alloc; + } + // Allocate memory for partitioning the elements around the pivot. 
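
With the fallback above in place, higher-level callers can simply pass `NULL` for the allocator and branch on the returned `sz_status_t`. A rough sketch, assuming the dispatching `sz_pgrams_sort` seen in the earlier hunks and that the selected backend follows the same `NULL`-allocator convention; `example_sort_pgrams` is an illustrative caller, not a library function:

    #include <stringzilla/sort.h> /* assumed include path for the sorting API */

    static int example_sort_pgrams(sz_pgram_t *pgrams, sz_sorted_idx_t *order, sz_size_t count) {
        /* `NULL` now means "fall back to the default malloc-based allocator". */
        sz_status_t status = sz_pgrams_sort(pgrams, count, NULL, order);
        if (status == sz_bad_alloc_k) return -1; /* temporary buffers could not be allocated */
        return status == sz_success_k ? 0 : -1;
    }
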
sz_size_t memory_usage = sizeof(sz_pgram_t) * count * 2 + sizeof(sz_sorted_idx_t) * count; sz_pgram_t *global_pgrams = (sz_pgram_t *)alloc->allocate(memory_usage, alloc); From 7698392efcfab4bb1dee856ceb88f9e618449638 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 22 Feb 2025 12:27:24 +0000 Subject: [PATCH 119/751] Improve: Clean `memory.h` header --- include/stringzilla/memory.h | 138 +++++++++++++++-------------------- 1 file changed, 59 insertions(+), 79 deletions(-) diff --git a/include/stringzilla/memory.h b/include/stringzilla/memory.h index de739f22..cc5cc6d7 100644 --- a/include/stringzilla/memory.h +++ b/include/stringzilla/memory.h @@ -5,15 +5,27 @@ * * Includes core APIs for contiguous memory operations: * - * - `sz_copy` - analog to `memcpy` - * - `sz_move` - analog to `memmove` - * - `sz_fill` - analog to `memset` - * - `sz_lookup` - LUT transformation of a string, similar to OpenCV LUT - * - TODO: `sz_detect_encoding` - similar to `iconv` or `chardet` + * - @b `sz_copy` - analog to `memcpy`, probably the most common operation in a computer + * - @b `sz_move` - analog to `memmove`, allowing overlapping memory regions, often used in string manipulation + * - @b `sz_fill` - analog to `memset`, often used to initialize memory with a constant value, like zero + * - @b `sz_lookup` - Look-Up Table @b (LUT) transformation of a string, mapping each byte to a new value + * - TODO: @b `sz_lookup_utf8` - LUT transformation of a UTF8 string, which can be used for normalization + * - TODO: @b `sz_detect_encoding` - detects the character encoding similar to "iconv" or "chardet" tools * - * Convenience functions for character-set mapping: + * All of the core APIs receive the target output buffer as the first argument, + * and aim to minimize the number of "store" instructions, especially unaligned ones, + * that can invalidate 2 cache lines. * - * - `sz_tolower`, `sz_toupper`, `sz_toascii` for ASCII ranges + * Unlike many other libraries focusing on trivial SIMD transformations, like converting + * lowercase to uppercase, StringZilla generalizes those to basic lookup table transforms. + * For typical ASCII conversions, you can use the following @b LUT initialization functions: + * + * - `sz_lookup_init_lower` for transforms like `tolower` + * - `sz_lookup_init_upper` for transforms like `toupper` + * - `sz_lookup_init_ascii` for transforms like `isascii` + * + * The header also exposes a minimalistic @b `sz_isascii` which can be used in UTF-8 capable + * methods to select a simpler execution path for ASCII characters. */ #ifndef STRINGZILLA_MEMORY_H_ #define STRINGZILLA_MEMORY_H_ @@ -28,6 +40,7 @@ extern "C" { /** * @brief Similar to `memcpy`, copies contents of one string into another. + * @see https://en.cppreference.com/w/c/string/byte/memcpy * * @param[out] target String to copy into. Can be `NULL`, if the @p length is zero. * @param[in] length Number of bytes to copy. Can be a zero. @@ -55,6 +68,7 @@ SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length); /** * @brief Similar to `memmove`, copies (moves) contents of one string into another. * Unlike `sz_copy`, allows overlapping strings as arguments. + * @see https://en.cppreference.com/w/c/string/byte/memmove * * @param[out] target String to copy into. Can be `NULL`, if the @p length is zero. * @param[in] length Number of bytes to copy. Can be a zero. 
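
As a rough illustration of the LUT-centric design this header describes, assuming only that a table is a plain 256-byte byte-to-byte mapping as the initializers further below define it, a caller could uppercase a buffer in place. The dispatching `sz_lookup` performs the same transform with SIMD, so the explicit loop here only spells out the semantics; `example_uppercase_in_place` is an illustrative caller, not a library function:

    static void example_uppercase_in_place(char *text, sz_size_t length) {
        char lut[256];
        sz_lookup_init_upper(lut); /* fills the 256-entry mapping, defined later in this header */
        for (sz_size_t i = 0; i != length; ++i) text[i] = lut[(unsigned char)text[i]];
    }
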
@@ -78,6 +92,7 @@ SZ_DYNAMIC void sz_move(sz_ptr_t target, sz_cptr_t source, sz_size_t length); /** * @brief Similar to `memset`, fills a string with a given value. + * @see https://en.cppreference.com/w/c/string/byte/memset * * @param[out] target String to fill. Can be `NULL`, if the @p length is zero. * @param[in] length Number of bytes to fill. Can be a zero. @@ -184,52 +199,17 @@ SZ_PUBLIC void sz_lookup_neon(sz_ptr_t target, sz_size_t length, sz_cptr_t sourc #pragma region Helper API /** - * @brief Equivalent to `for (char & c : text) c = tolower(c)`. - * - * ASCII characters [A, Z] map to decimals [65, 90], and [a, z] map to [97, 122]. - * So there are 26 english letters, shifted by 32 values, meaning that a conversion - * can be done by flipping the 5th bit each inappropriate character byte. This, however, - * breaks for extended ASCII, so a different solution is needed. - * http://0x80.pl/notesen/2016-01-06-swar-swap-case.html - * - * @param text String to be normalized. - * @param[in] length Number of bytes in the string. - * @param result Output string, can point to the same address as ::text. - */ -SZ_PUBLIC void sz_tolower(sz_cptr_t text, sz_size_t length, sz_ptr_t result); - -/** - * @brief Equivalent to `for (char & c : text) c = toupper(c)`. + * @brief Initializes a lookup table for converting ASCII characters to lowercase. * * ASCII characters [A, Z] map to decimals [65, 90], and [a, z] map to [97, 122]. * So there are 26 english letters, shifted by 32 values, meaning that a conversion - * can be done by flipping the 5th bit each inappropriate character byte. This, however, - * breaks for extended ASCII, so a different solution is needed. + * can be done by flipping the 5th bit each inappropriate character byte. + * This, however, breaks for extended ASCII, so a different solution is needed. * http://0x80.pl/notesen/2016-01-06-swar-swap-case.html * - * @param text String to be normalized. - * @param[in] length Number of bytes in the string. - * @param result Output string, can point to the same address as ::text. - */ -SZ_PUBLIC void sz_toupper(sz_cptr_t text, sz_size_t length, sz_ptr_t result); - -/** - * @brief Equivalent to `for (char & c : text) c = toascii(c)`. - * - * @param text String to be normalized. - * @param[in] length Number of bytes in the string. - * @param result Output string, can point to the same address as ::text. + * @param[out] lut Lookup table to be initialized. Must be exactly 256 bytes long. */ -SZ_PUBLIC void sz_toascii(sz_cptr_t text, sz_size_t length, sz_ptr_t result); - -#pragma endregion // Helper API - -#pragma region Serial Implementation - -/** - * @brief Uses a small lookup-table to convert a lowercase character to uppercase. - */ -SZ_INTERNAL sz_u8_t sz_u8_tolower(sz_u8_t c) { +SZ_PUBLIC void sz_lookup_init_lower(sz_ptr_t lut) { static sz_u8_t const lowered[256] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, // @@ -248,13 +228,21 @@ SZ_INTERNAL sz_u8_t sz_u8_tolower(sz_u8_t c) { 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, // }; - return lowered[c]; + for (sz_size_t i = 0; i < 256; ++i) lut[i] = lowered[i]; } /** - * @brief Uses a small lookup-table to convert an uppercase character to lowercase. + * @brief Initializes a lookup table for converting ASCII characters to uppercase. 
+ * + * ASCII characters [A, Z] map to decimals [65, 90], and [a, z] map to [97, 122]. + * So there are 26 english letters, shifted by 32 values, meaning that a conversion + * can be done by flipping the 5th bit each inappropriate character byte. + * This, however, breaks for extended ASCII, so a different solution is needed. + * http://0x80.pl/notesen/2016-01-06-swar-swap-case.html + * + * @param[out] lut Lookup table to be initialized. Must be exactly 256 bytes long. */ -SZ_INTERNAL sz_u8_t sz_u8_toupper(sz_u8_t c) { +SZ_PUBLIC void sz_lookup_init_upper(sz_ptr_t lut) { static sz_u8_t const upped[256] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, // @@ -273,43 +261,23 @@ SZ_INTERNAL sz_u8_t sz_u8_toupper(sz_u8_t c) { 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, // }; - return upped[c]; -} - -SZ_PUBLIC void sz_lookup_serial(sz_ptr_t result, sz_size_t length, sz_cptr_t text, sz_cptr_t lut) { - sz_u8_t const *unsigned_lut = (sz_u8_t const *)lut; - sz_u8_t const *unsigned_text = (sz_u8_t const *)text; - sz_u8_t *unsigned_result = (sz_u8_t *)result; - sz_u8_t const *end = unsigned_text + length; - for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = unsigned_lut[*unsigned_text]; -} - -SZ_PUBLIC void sz_tolower_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) { - sz_u8_t *unsigned_result = (sz_u8_t *)result; - sz_u8_t const *unsigned_text = (sz_u8_t const *)text; - sz_u8_t const *end = unsigned_text + length; - for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = sz_u8_tolower(*unsigned_text); + for (sz_size_t i = 0; i < 256; ++i) lut[i] = upped[i]; } -SZ_PUBLIC void sz_toupper_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) { - sz_u8_t *unsigned_result = (sz_u8_t *)result; - sz_u8_t const *unsigned_text = (sz_u8_t const *)text; - sz_u8_t const *end = unsigned_text + length; - for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = sz_u8_toupper(*unsigned_text); -} - -SZ_PUBLIC void sz_toascii_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) { - sz_u8_t *unsigned_result = (sz_u8_t *)result; - sz_u8_t const *unsigned_text = (sz_u8_t const *)text; - sz_u8_t const *end = unsigned_text + length; - for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = *unsigned_text & 0x7F; +/** + * @brief Initializes a lookup table for converting bytes to ASCII characters. + * + * @param[out] lut Lookup table to be initialized. Must be exactly 256 bytes long. + */ +SZ_PUBLIC void sz_lookup_init_ascii(sz_ptr_t lut) { + for (sz_size_t i = 0; i < 256; ++i) lut[i] = (sz_u8_t)(i & 0x7F); } /** * @brief Check if there is a byte in this buffer, that exceeds 127 and can't be an ASCII character. * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. 
*/ -SZ_PUBLIC sz_bool_t sz_isascii_serial(sz_cptr_t text, sz_size_t length) { +SZ_PUBLIC sz_bool_t sz_isascii(sz_cptr_t text, sz_size_t length) { if (!length) return sz_true_k; sz_u8_t const *h = (sz_u8_t const *)text; @@ -334,6 +302,18 @@ SZ_PUBLIC sz_bool_t sz_isascii_serial(sz_cptr_t text, sz_size_t length) { return sz_true_k; } +#pragma endregion // Helper API + +#pragma region Serial Implementation + +SZ_PUBLIC void sz_lookup_serial(sz_ptr_t result, sz_size_t length, sz_cptr_t text, sz_cptr_t lut) { + sz_u8_t const *unsigned_lut = (sz_u8_t const *)lut; + sz_u8_t const *unsigned_text = (sz_u8_t const *)text; + sz_u8_t *unsigned_result = (sz_u8_t *)result; + sz_u8_t const *end = unsigned_text + length; + for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = unsigned_lut[*unsigned_text]; +} + // When overriding libc, disable optimizations for this function because MSVC will optimize the loops into a `memset`. // Which then causes a stack overflow due to infinite recursion (`memset` -> `sz_fill_serial` -> `memset`). #if defined(_MSC_VER) && defined(SZ_OVERRIDE_LIBC) && SZ_OVERRIDE_LIBC From 095bc2da38574f8a92ec053efa4c7261eaf1a730 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 22 Feb 2025 12:28:20 +0000 Subject: [PATCH 120/751] Break: New calling convention in `similarity.h` --- include/stringzilla/similarity.h | 588 +++++++++++++++++----------- include/stringzilla/stringzilla.hpp | 54 +-- 2 files changed, 394 insertions(+), 248 deletions(-) diff --git a/include/stringzilla/similarity.h b/include/stringzilla/similarity.h index 188169ff..60540b33 100644 --- a/include/stringzilla/similarity.h +++ b/include/stringzilla/similarity.h @@ -5,9 +5,9 @@ * * Includes core APIs: * - * - `sz_edit_distance` & `sz_edit_distance_utf8` for Levenshtein edit-distance computation. - * - `sz_alignment_score` for weighted Needleman-Wunsch global alignment. * - `sz_hamming_distance` & `sz_hamming_distance_utf8` for Hamming distance computation. + * - `sz_levenshtein_distance` & `sz_levenshtein_distance_utf8` for Levenshtein edit-distance computation. + * - `sz_needleman_wunsch_score` for weighted Needleman-Wunsch global alignment. * * The Hamming distance is rarely used in string processing, so only minimal compatibility is provided. * The Levenshtein distance, however, is much more popular and computationally intensive. @@ -26,130 +26,220 @@ extern "C" { #pragma region Core API /** - * @brief Computes the Hamming distance between two strings - number of not matching characters. - * Difference in length is is counted as a mismatch. + * @brief Computes the Hamming distance between two strings. + * @see https://en.wikipedia.org/wiki/Hamming_distance * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. + * The Hamming distance is defined as the number of positions at which the corresponding bytes differ. + * If the strings have different lengths, the extra characters in the longer string are treated as mismatches. * - * @param bound Exclusive upper bound on the distance, that allows us to exit early. - * Pass `SZ_SIZE_MAX` or any value greater than `(max(a_length, b_length))` to ignore. - * Pass zero to check if the strings are equal. - * @return Returns an unsigned integer for the edit distance. Zero means the strings are equal. 
- * Returns the `(max(a_length, b_length)) + 1` if the distance limit was reached. + * If the running distance reaches the @p bound, the computation aborts early. If the @p bound is zero, + * the function merely checks for equality. If the @p bound is larger than the maximum length of the strings, + * the function will compute the full "unbounded" distance. * - * @see sz_hamming_distance_utf8 - * @see https://en.wikipedia.org/wiki/Hamming_distance + * @param[in] a Pointer to the first string. + * @param[in] a_length Number of bytes in the first string. + * @param[in] b Pointer to the second string. + * @param[in] b_length Number of bytes in the second string. + * @param[in] bound Exclusive upper bound on the computed distance. + * @param[out] result On success, the computed byte-level Hamming distance is stored here. + * @retval `sz_success_k` if the operation was successful. + * @retval `sz_bad_alloc_k` if the operation failed due to memory allocation failure. + * + * Example usage: + * + * @code{.c} + * #include + * int main(void) { + * char const *s1 = "1011101"; + * char const *s2 = "1001001"; + * sz_size_t result, length = 7, bound = 10; + * sz_status_t status = sz_hamming_distance(s1, length, s2, length, bound, &result); + * return (status == sz_success_k && result == 2) ? 0 : 1; + * } + * @endcode + * + * @note This function isn't intended for UTF-8 texts and is not heavily optimized. + * @sa sz_hamming_distance_utf8 */ -SZ_DYNAMIC sz_size_t sz_hamming_distance( // - sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound); +SZ_DYNAMIC sz_status_t sz_hamming_distance( // + sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound, sz_size_t *result); /** - * @brief Computes the Hamming distance between two @b UTF8 strings - number of not matching characters. - * Difference in length is is counted as a mismatch. + * @brief Computes the Hamming distance between two @b UTF-8 encoded strings. + * @see https://en.wikipedia.org/wiki/Hamming_distance * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. + * The Hamming distance is defined as the number of positions at which the corresponding Unicode runes differ. + * If the strings have different lengths, the extra characters in the longer string are treated as mismatches. * - * @param bound Exclusive upper bound on the distance, that allows us to exit early. - * Pass `SZ_SIZE_MAX` or any value greater than `(max(a_length, b_length))` to ignore. - * Pass zero to check if the strings are equal. - * @return Returns an unsigned integer for the edit distance. Zero means the strings are equal. - * Returns the `(max(a_length, b_length)) + 1` if the distance limit was reached. + * If the running distance reaches the @p bound, the computation aborts early. If the @p bound is zero, + * the function merely checks for equality. If the @p bound is larger than the maximum length of the strings, + * the function will compute the full "unbounded" distance. * - * @see sz_hamming_distance - * @see https://en.wikipedia.org/wiki/Hamming_distance + * @param[in] a Pointer to the first string. + * @param[in] a_length Number of bytes in the first string. + * @param[in] b Pointer to the second string. + * @param[in] b_length Number of bytes in the second string. + * @param[in] bound Exclusive upper bound on the computed distance. 
+ * @param[out] result On success, the computed Unicode character-level Hamming distance is stored here. + * @retval `sz_success_k` if the operation was successful. + * @retval `sz_bad_alloc_k` if the operation failed due to memory allocation failure. + * @retval `sz_invalid_utf8_k` if the input strings are not valid UTF-8. + * + * Example usage: + * + * @code{.c} + * #include + * int main(void) { + * char const *s1 = "café"; + * char const *s2 = "cafe"; + * sz_size_t result, length1 = 5, length2 = 4, bound = 10; + * sz_status_t status = sz_hamming_distance_utf8(s1, length1, s2, length2, bound, &result); + * return (status == sz_success_k && result == 1) ? 0 : 1; + * } + * @endcode + * + * @note This function isn't heavily optimized. + * @sa sz_hamming_distance */ -SZ_DYNAMIC sz_size_t sz_hamming_distance_utf8( // - sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound); +SZ_DYNAMIC sz_status_t sz_hamming_distance_utf8( // + sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound, sz_size_t *result); /** * @brief Computes the Levenshtein edit-distance between two strings using the Wagner-Fisher algorithm. * Similar to the Needleman-Wunsch alignment algorithm. Often used in fuzzy string matching. * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. + * If the running distance reaches the @p bound, the computation aborts early. If the @p bound is zero, + * the function merely checks for equality. If the @p bound is larger than the maximum length of the strings, + * the function will compute the full "unbounded" distance. + * + * @param[in] a Pointer to the first string. + * @param[in] a_length Number of bytes in the first string. + * @param[in] b Pointer to the second string. + * @param[in] b_length Number of bytes in the second string. + * @param[in] bound Exclusive upper bound on the computed distance. + * @param[in] alloc Optional memory allocator. If `NULL` is passed, will use to the systems default `malloc`. + * + * @param[out] result On success, the computed byte-level Levenshtein distance is stored here. + * @retval `sz_success_k` if the operation was successful. + * @retval `sz_bad_alloc_k` if the operation failed due to memory allocation failure. + * @retval `sz_invalid_utf8_k` if the input strings are not valid UTF-8. * - * @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated, - * so the memory usage is linear in relation to ::a_length and ::b_length. - * If SZ_NULL is passed, will initialize to the systems default `malloc`. + * Example usage: * - * @param bound Exclusive upper bound on the distance, that allows us to exit early. - * Pass `SZ_SIZE_MAX` or any value greater than `(max(a_length, b_length))` to ignore. - * Pass zero to check if the strings are equal. - * @return Returns an unsigned integer for the edit distance. Zero means the strings are equal. - * Returns the `(max(a_length, b_length)) + 1` if the distance limit was reached. - * Returns `SZ_SIZE_MAX` if the memory allocation failed. + * @code{.c} + * #include + * int main(void) { + * char const *s1 = "kitten"; + * char const *s2 = "sitting"; + * sz_size_t result, length1 = 6, length2 = 7, bound = 10; + * sz_status_t status = sz_levenshtein_distance(s1, length1, s2, length2, bound, NULL, &result); + * return (status == sz_success_k && result == 3) ? 
0 : 1;
+ * }
+ * @endcode
 *
- * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default
+ * @note The algorithm has linear memory complexity and @p a_length * @p b_length time complexity.
 * @see https://en.wikipedia.org/wiki/Levenshtein_distance
+ *
+ * @note This function isn't intended for UTF-8 texts.
+ * @sa sz_levenshtein_distance_utf8
+ *
+ * @note Selects the fastest implementation at compile- or run-time based on `SZ_DYNAMIC_DISPATCH`.
+ * @sa sz_levenshtein_distance_serial, sz_levenshtein_distance_ice
 */
-SZ_DYNAMIC sz_size_t sz_edit_distance( //
+SZ_DYNAMIC sz_status_t sz_levenshtein_distance( //
     sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, //
-    sz_size_t bound, sz_memory_allocator_t *alloc);
+    sz_size_t bound, sz_memory_allocator_t *alloc, sz_size_t *result);
 
 /**
- * @brief Computes the Levenshtein edit-distance between two @b UTF8 strings.
- *        Unlike `sz_edit_distance`, reports the distance in Unicode codepoints, and not in bytes.
+ * @brief  Computes the Levenshtein edit-distance between two @b UTF-8 strings using the Wagner-Fisher algorithm.
+ *         Similar to the Needleman-Wunsch alignment algorithm. Often used in fuzzy string matching.
+ *
+ * If the running distance reaches the @p bound, the computation aborts early. If the @p bound is zero,
+ * the function merely checks for equality. If the @p bound is larger than the maximum length of the strings,
+ * the function will compute the full "unbounded" distance.
 *
- * @param a First string to compare.
- * @param a_length Number of bytes in the first string.
- * @param b Second string to compare.
- * @param b_length Number of bytes in the second string.
+ * @param[in] a         Pointer to the first string.
+ * @param[in] a_length  Number of bytes in the first string.
+ * @param[in] b         Pointer to the second string.
+ * @param[in] b_length  Number of bytes in the second string.
+ * @param[in] bound     Exclusive upper bound on the computed distance.
+ * @param[in] alloc     Optional memory allocator. If `NULL` is passed, will use the system's default `malloc`.
 *
- * @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated,
- *              so the memory usage is linear in relation to ::a_length and ::b_length.
- *              If SZ_NULL is passed, will initialize to the systems default `malloc`.
+ * @param[out] result   On success, the computed Levenshtein distance in Unicode code points is stored here.
+ * @retval `sz_success_k` if the operation was successful.
+ * @retval `sz_bad_alloc_k` if the operation failed due to memory allocation failure.
+ * @retval `sz_invalid_utf8_k` if the input strings are not valid UTF-8.
 *
- * @param bound Exclusive upper bound on the distance, that allows us to exit early.
- *              Pass `SZ_SIZE_MAX` or any value greater than `(max(a_length, b_length))` to ignore.
- *              Pass zero to check if the strings are equal.
- * @return Returns an unsigned integer for the edit distance. Zero means the strings are equal.
- *         Returns the `(max(a_length, b_length)) + 1` if the distance limit was reached.
- *         Returns `SZ_SIZE_MAX` if the memory allocation failed.
+ * Example usage:
 *
- * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default, sz_edit_distance
+ * @code{.c}
+ * #include 
+ * int main(void) {
+ *     char const *s1 = "café";
+ *     char const *s2 = "cafe";
+ *     sz_size_t result, length1 = 5, length2 = 4, bound = 10;
+ *     sz_status_t status = sz_levenshtein_distance_utf8(s1, length1, s2, length2, bound, NULL, &result);
+ *     return (status == sz_success_k && result == 1) ? 
0 : 1;
+ * }
+ * @endcode
+ *
+ * @note The algorithm has linear memory complexity and @p a_length * @p b_length time complexity.
 * @see https://en.wikipedia.org/wiki/Levenshtein_distance
+ *
+ * @sa sz_levenshtein_distance
+ *
+ * @note Selects the fastest implementation at compile- or run-time based on `SZ_DYNAMIC_DISPATCH`.
+ * @sa sz_levenshtein_distance_utf8_serial, sz_levenshtein_distance_utf8_ice
 */
-SZ_DYNAMIC sz_size_t sz_edit_distance_utf8( //
+SZ_DYNAMIC sz_status_t sz_levenshtein_distance_utf8( //
     sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, //
-    sz_size_t bound, sz_memory_allocator_t *alloc);
+    sz_size_t bound, sz_memory_allocator_t *alloc, sz_size_t *result);
 
 /**
- * @brief Computes Needleman–Wunsch alignment score for two string. Often used in bioinformatics and cheminformatics.
- *        Similar to the Levenshtein edit-distance, parameterized for gap and substitution penalties.
+ * @brief  Computes the Needleman–Wunsch alignment score for two strings.
+ *         Often used in bioinformatics for sequence alignment.
 *
- * Not commutative in the general case, as the order of the strings matters, as `sz_alignment_score(a, b)` may
- * not be equal to `sz_alignment_score(b, a)`. Becomes @b commutative, if the substitution costs are symmetric.
- * Equivalent to the negative Levenshtein distance, if: `gap == -1` and `subs[i][j] == (i == j ? 0: -1)`.
+ * This function calculates a similarity score by applying gap and substitution penalties,
+ * following the Needleman–Wunsch algorithm. Note that the result is generally @b not-commutative —
+ * that is, `sz_needleman_wunsch_score(a, b)` may differ from `sz_needleman_wunsch_score(b, a)`
+ * unless the @p subs matrix is symmetric. With a @p gap penalty of -1 and substitution costs defined
+ * as 0 for matches and -1 for mismatches, the score is equivalent to the negative Levenshtein distance.
 *
- * @param a First string to compare.
- * @param a_length Number of bytes in the first string.
- * @param b Second string to compare.
- * @param b_length Number of bytes in the second string.
- * @param gap Penalty cost for gaps - insertions and removals.
- * @param subs Substitution costs matrix with 256 x 256 values for all pairs of characters.
+ * @param[in] a         Pointer to the first string.
+ * @param[in] a_length  Number of bytes in the first string.
+ * @param[in] b         Pointer to the second string.
+ * @param[in] b_length  Number of bytes in the second string.
+ * @param[in] subs      Substitution cost matrix (256×256) for all pairs of characters.
+ * @param[in] gap       Penalty cost for gaps (insertions and deletions).
+ * @param[in] alloc     Optional memory allocator. If `NULL` is passed, the system default `malloc` is used.
+ * @param[out] result   On success, the computed alignment score (possibly negative) is stored here.
+ * @retval `sz_success_k` if the operation was successful.
+ * @retval `sz_bad_alloc_k` if the operation failed due to memory allocation failure.
 *
- * @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated,
- *              so the memory usage is linear in relation to ::a_length and ::b_length.
- *              If SZ_NULL is passed, will initialize to the systems default `malloc`.
+ * Example usage:
 *
- * @return Signed similarity score. Can be negative, depending on the substitution costs.
- *         Returns `SZ_SSIZE_MAX` if the memory allocation failed.
+ * @code{.c}
+ * #include 
+ * int main(void) {
+ *     char const *s1 = "GATTACA";
+ *     char const *s2 = "GCATGCU";
+ *     sz_error_cost_t subs[256 * 256] = { ... 
}; + * sz_error_cost_t gap = -1; + * sz_ssize_t score; + * sz_status_t status = sz_needleman_wunsch_score(s1, 7, s2, 7, subs, gap, NULL, &score); + * return (status == sz_success_k) ? 0 : 1; + * } + * @endcode * - * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default + * @note Algorithm has @p a_length * @p b_length worst-case time complexity and linear memory complexity. * @see https://en.wikipedia.org/wiki/Needleman%E2%80%93Wunsch_algorithm + * @note Selects the fastest implementation at compile- or run-time based on `SZ_DYNAMIC_DISPATCH`. + * @sa sz_needleman_wunsch_score_serial, sz_needleman_wunsch_score_ice */ -SZ_DYNAMIC sz_ssize_t sz_alignment_score( // +SZ_DYNAMIC sz_status_t sz_needleman_wunsch_score( // sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // sz_error_cost_t const *subs, sz_error_cost_t gap, // - sz_memory_allocator_t *alloc); + sz_memory_allocator_t *alloc, sz_ssize_t *result); /** * @brief Checks if all characters in the range are valid ASCII characters. @@ -161,40 +251,44 @@ SZ_DYNAMIC sz_ssize_t sz_alignment_score( // SZ_PUBLIC sz_bool_t sz_isascii(sz_cptr_t text, sz_size_t length); /** @copydoc sz_hamming_distance */ -SZ_PUBLIC sz_size_t sz_hamming_distance_serial( // - sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound); +SZ_PUBLIC sz_status_t sz_hamming_distance_serial( // + sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // + sz_size_t bound, sz_size_t *result); /** @copydoc sz_hamming_distance_utf8 */ -SZ_PUBLIC sz_size_t sz_hamming_distance_utf8_serial( // - sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound); +SZ_PUBLIC sz_status_t sz_hamming_distance_utf8_serial( // + sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // + sz_size_t bound, sz_size_t *result); -/** @copydoc sz_edit_distance */ -SZ_PUBLIC sz_size_t sz_edit_distance_serial( // +/** @copydoc sz_levenshtein_distance */ +SZ_PUBLIC sz_status_t sz_levenshtein_distance_serial( // sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); + sz_size_t bound, sz_memory_allocator_t *alloc, sz_size_t *result); -/** @copydoc sz_edit_distance_utf8 */ -SZ_PUBLIC sz_size_t sz_edit_distance_utf8_serial( // +/** @copydoc sz_levenshtein_distance_utf8 */ +SZ_PUBLIC sz_status_t sz_levenshtein_distance_utf8_serial( // sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); + sz_size_t bound, sz_memory_allocator_t *alloc, sz_size_t *result); -/** @copydoc sz_alignment_score */ -SZ_PUBLIC sz_ssize_t sz_alignment_score_serial( // +/** @copydoc sz_needleman_wunsch_score */ +SZ_PUBLIC sz_status_t sz_needleman_wunsch_score_serial( // sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // sz_error_cost_t const *subs, sz_error_cost_t gap, // - sz_memory_allocator_t *alloc); + sz_memory_allocator_t *alloc, sz_ssize_t *result); #if SZ_USE_ICE -SZ_INTERNAL sz_size_t sz_edit_distance_ice( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound, sz_memory_allocator_t *alloc); +/** @copydoc sz_levenshtein_distance */ +SZ_PUBLIC sz_status_t sz_levenshtein_distance_ice( // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_cptr_t longer, sz_size_t longer_length, // + sz_size_t bound, sz_memory_allocator_t *alloc, sz_size_t *result); -SZ_INTERNAL sz_ssize_t sz_alignment_score_ice( // - sz_cptr_t 
shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc); +/** @copydoc sz_needleman_wunsch_score */ +SZ_PUBLIC sz_status_t sz_needleman_wunsch_score_ice( // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_cptr_t longer, sz_size_t longer_length, // + sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc, sz_ssize_t *result); #endif @@ -202,10 +296,10 @@ SZ_INTERNAL sz_ssize_t sz_alignment_score_ice( // #pragma region Serial Implementation -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_serial( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { +SZ_INTERNAL sz_status_t _sz_levenshtein_distance_skewed_diagonals_serial( // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_cptr_t longer, sz_size_t longer_length, // + sz_size_t bound, sz_memory_allocator_t *alloc, sz_size_t *result_ptr) { // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. sz_memory_allocator_t global_alloc; @@ -224,7 +318,7 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_serial( // sz_size_t n = shorter_length + 1; sz_size_t buffer_length = sizeof(sz_size_t) * n * 3; sz_size_t *distances = (sz_size_t *)alloc->allocate(buffer_length, alloc->handle); - if (!distances) return SZ_SIZE_MAX; + if (!distances) return sz_bad_alloc_k; sz_size_t *previous_distances = distances; sz_size_t *current_distances = previous_distances + n; @@ -276,7 +370,8 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_serial( // // Cache scalar before `free` call. sz_size_t result = current_distances[0]; alloc->free(distances, buffer_length, alloc->handle); - return result; + *result_ptr = result; + return sz_success_k; } /** @@ -290,10 +385,10 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_serial( // * + 100 codepoints * 2 strings * 4 bytes/codepoint = 800 bytes of memory for the UTF8 buffer. * = 2400 bytes of memory or @b 12x memory amplification! */ -SZ_INTERNAL sz_size_t _sz_edit_distance_wagner_fisher_serial( // - sz_cptr_t longer, sz_size_t longer_length, // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_size_t bound, sz_bool_t can_be_unicode, sz_memory_allocator_t *alloc) { +SZ_INTERNAL sz_status_t _sz_levenshtein_distance_wagner_fisher_serial( // + sz_cptr_t longer, sz_size_t longer_length, // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_size_t bound, sz_bool_t can_be_unicode, sz_memory_allocator_t *alloc, sz_size_t *result_ptr) { // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. sz_memory_allocator_t global_alloc; @@ -329,7 +424,7 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_wagner_fisher_serial( // // If the allocation fails, return the maximum distance. sz_ptr_t const buffer = (sz_ptr_t)alloc->allocate(buffer_length, alloc->handle); - if (!buffer) return SZ_SIZE_MAX; + if (!buffer) return sz_bad_alloc_k; // Let's export the UTF8 sequence into the newly allocated buffer at the end. if (can_be_unicode == sz_true_k) { @@ -378,7 +473,8 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_wagner_fisher_serial( // /* Cache scalar before `free` call. 
*/ \ sz_size_t result = previous_distances[shorter_length]; \ alloc->free(buffer, buffer_length, alloc->handle); \ - return result; + *result_ptr = result; \ + return sz_success_k; // Let's define a separate variant for bounded distance computation. // Practically the same as unbounded, but also collecting the running minimum within each row for early exit. @@ -408,7 +504,8 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_wagner_fisher_serial( // /* If the minimum distance in this row exceeded the bound, return early */ \ if (min_distance >= bound) { \ alloc->free(buffer, buffer_length, alloc->handle); \ - return longer_length + 1; \ + *result_ptr = bound; \ + return sz_success_k; \ } \ _distance_t *temporary = previous_distances; \ previous_distances = current_distances; \ @@ -416,7 +513,8 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_wagner_fisher_serial( // } \ sz_size_t result = previous_distances[shorter_length]; \ alloc->free(buffer, buffer_length, alloc->handle); \ - return result; + *result_ptr = result; \ + return sz_success_k; // Dispatch the actual computation. if (!bound) { @@ -429,10 +527,10 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_wagner_fisher_serial( // } } -SZ_PUBLIC sz_size_t sz_edit_distance_serial( // - sz_cptr_t longer, sz_size_t longer_length, // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { +SZ_PUBLIC sz_status_t sz_levenshtein_distance_serial( // + sz_cptr_t longer, sz_size_t longer_length, // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_size_t bound, sz_memory_allocator_t *alloc, sz_size_t *result_ptr) { // Let's make sure that we use the amount proportional to the // number of elements in the shorter string, not the larger. @@ -452,28 +550,48 @@ SZ_PUBLIC sz_size_t sz_edit_distance_serial( // int const is_bounded = bound < longer_length; if (is_bounded) { // If one of the strings is empty - the edit distance is equal to the length of the other one. - if (longer_length == 0) return sz_min_of_two(shorter_length, bound); - if (shorter_length == 0) return sz_min_of_two(longer_length, bound); + if (longer_length == 0) { + *result_ptr = sz_min_of_two(shorter_length, bound); + return sz_success_k; + } + if (shorter_length == 0) { + *result_ptr = sz_min_of_two(longer_length, bound); + return sz_success_k; + } // If the difference in length is beyond the `bound`, there is no need to check at all. - if (longer_length - shorter_length > bound) return bound; + if (longer_length - shorter_length > bound) { + *result_ptr = bound; + return sz_success_k; + } } - if (shorter_length == 0) return longer_length; // If no mismatches were found - the distance is zero. + // If no mismatches were found - the distance is zero. 
+ if (shorter_length == 0) { + *result_ptr = longer_length; + return sz_success_k; + } if (shorter_length == longer_length && !is_bounded) - return _sz_edit_distance_skewed_diagonals_serial(longer, longer_length, shorter, shorter_length, bound, alloc); - return _sz_edit_distance_wagner_fisher_serial( // - longer, longer_length, shorter, shorter_length, bound, sz_false_k, alloc); + return _sz_levenshtein_distance_skewed_diagonals_serial(longer, longer_length, shorter, shorter_length, bound, + alloc, result_ptr); + return _sz_levenshtein_distance_wagner_fisher_serial( // + longer, longer_length, shorter, shorter_length, bound, sz_false_k, alloc, result_ptr); } -SZ_PUBLIC sz_ssize_t sz_alignment_score_serial( // - sz_cptr_t longer, sz_size_t longer_length, // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, // - sz_memory_allocator_t *alloc) { +SZ_PUBLIC sz_status_t sz_needleman_wunsch_score_serial( // + sz_cptr_t longer, sz_size_t longer_length, // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_error_cost_t const *subs, sz_error_cost_t gap, // + sz_memory_allocator_t *alloc, sz_ssize_t *result_ptr) { // If one of the strings is empty - the edit distance is equal to the length of the other one - if (longer_length == 0) return (sz_ssize_t)shorter_length * gap; - if (shorter_length == 0) return (sz_ssize_t)longer_length * gap; + if (longer_length == 0) { + *result_ptr = (sz_ssize_t)shorter_length * gap; + return sz_success_k; + } + if (shorter_length == 0) { + *result_ptr = (sz_ssize_t)longer_length * gap; + return sz_success_k; + } // Let's make sure that we use the amount proportional to the // number of elements in the shorter string, not the larger. @@ -519,13 +637,14 @@ SZ_PUBLIC sz_ssize_t sz_alignment_score_serial( // // Cache scalar before `free` call. 
sz_ssize_t result = previous_distances[shorter_length]; alloc->free(distances, buffer_length, alloc->handle); - return result; + *result_ptr = result; + return sz_success_k; } -SZ_PUBLIC sz_size_t sz_hamming_distance_serial( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { +SZ_PUBLIC sz_status_t sz_hamming_distance_serial( // + sz_cptr_t a, sz_size_t a_length, // + sz_cptr_t b, sz_size_t b_length, // + sz_size_t bound, sz_size_t *result_ptr) { sz_size_t const min_length = sz_min_of_two(a_length, b_length); sz_size_t const max_length = sz_max_of_two(a_length, b_length); @@ -547,13 +666,14 @@ SZ_PUBLIC sz_size_t sz_hamming_distance_serial( // #endif for (; a != a_end && distance < bound; ++a, ++b) { distance += (*a != *b); } - return sz_min_of_two(distance, bound); + *result_ptr = sz_min_of_two(distance, bound); + return sz_success_k; } -SZ_PUBLIC sz_size_t sz_hamming_distance_utf8_serial( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { +SZ_PUBLIC sz_status_t sz_hamming_distance_utf8_serial( // + sz_cptr_t a, sz_size_t a_length, // + sz_cptr_t b, sz_size_t b_length, // + sz_size_t bound, sz_size_t *result_ptr) { sz_cptr_t const a_end = a + a_length; sz_cptr_t const b_end = b + b_length; @@ -587,7 +707,8 @@ SZ_PUBLIC sz_size_t sz_hamming_distance_utf8_serial( // for (; a < a_end; a += a_rune_length, ++distance) _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); for (; b < b_end; b += b_rune_length, ++distance) _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); } - return distance; + *result_ptr = distance; + return sz_success_k; } #pragma endregion // Serial Implementation @@ -646,9 +767,9 @@ SZ_PUBLIC sz_size_t sz_hamming_distance_utf8_serial( // *? Bounds check, for inputs ranging from 33 to 64 bytes doesn't affect the performance at all. *? It's also worth exploring `_mm512_alignr_epi8` and `_mm512_maskz_compress_epi8` for the shift. */ -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto63_ice( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // +SZ_INTERNAL sz_size_t _sz_levenshtein_distance_skewed_diagonals_upto63_ice( // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_cptr_t longer, sz_size_t longer_length, // sz_size_t bound) { sz_size_t const max_length = 63u; @@ -815,9 +936,9 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto63_ice( // * - source code analysis, assuming most lines are either under 80 or under 120 characters long. * - DNA sequence alignment, as most short reads are 50-300 characters long. */ -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto127_ice( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // +SZ_INTERNAL sz_size_t _sz_levenshtein_distance_skewed_diagonals_upto127_ice( // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_cptr_t longer, sz_size_t longer_length, // sz_size_t bound) { sz_unused(shorter && shorter_length && longer && longer_length && bound); return 0; @@ -835,9 +956,9 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto127_ice( // * This is the largest space-efficient variant, as strings beyond 255 characters may require * 16-bit accumulators, which would be a significant bottleneck. 
*/ -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto_ice( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // +SZ_INTERNAL sz_size_t _sz_levenshtein_distance_skewed_diagonals_upto_ice( // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_cptr_t longer, sz_size_t longer_length, // sz_size_t bound) { sz_unused(shorter && shorter_length && longer && longer_length && bound); return 0; @@ -856,9 +977,9 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto_ice( // * This is the largest space-efficient variant, as strings beyond 255 characters may require * 16-bit accumulators, which would be a significant bottleneck. */ -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto255bound_ice( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // +SZ_INTERNAL sz_size_t _sz_levenshtein_distance_skewed_diagonals_upto255bound_ice( // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_cptr_t longer, sz_size_t longer_length, // sz_size_t bound) { sz_unused(shorter && shorter_length && longer && longer_length && bound); return 0; @@ -873,20 +994,20 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto255bound_ice( // * * Each string is unpacked into 128 characters * 4 bytes per character / 64 bytes per register = 8 registers. */ -SZ_INTERNAL sz_size_t _sz_edit_distance_utf8_skewed_diagonals_upto127_ice( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // +SZ_INTERNAL sz_size_t _sz_levenshtein_distance_utf8_skewed_diagonals_upto127_ice( // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_cptr_t longer, sz_size_t longer_length, // sz_size_t bound) { sz_unused(shorter && shorter_length && longer && longer_length && bound); return 0; } -SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto65k_ice( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { +SZ_INTERNAL sz_status_t _sz_levenshtein_distance_skewed_diagonals_upto65k_ice( // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_cptr_t longer, sz_size_t longer_length, // + sz_size_t bound, sz_memory_allocator_t *alloc, sz_size_t *result_ptr) { - sz_unused(shorter && longer && bound && alloc); + sz_unused(shorter && longer && bound && alloc && result_ptr); // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. sz_memory_allocator_t global_alloc; @@ -1037,22 +1158,31 @@ SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto65k_ice( // alloc->free(distances, buffer_length, alloc->handle); return result; #endif - return 0; + return sz_success_k; } -SZ_INTERNAL sz_size_t sz_edit_distance_ice( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { +SZ_PUBLIC sz_status_t sz_levenshtein_distance_ice( // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_cptr_t longer, sz_size_t longer_length, // + sz_size_t bound, sz_memory_allocator_t *alloc, sz_size_t *result_ptr) { // Bounded computations may exit early. int const is_bounded = bound < longer_length; if (is_bounded) { // If one of the strings is empty - the edit distance is equal to the length of the other one. 
- if (longer_length == 0) return sz_min_of_two(shorter_length, bound); - if (shorter_length == 0) return sz_min_of_two(longer_length, bound); + if (longer_length == 0) { + *result_ptr = sz_min_of_two(shorter_length, bound); + return sz_success_k; + } + if (shorter_length == 0) { + *result_ptr = sz_min_of_two(longer_length, bound); + return sz_success_k; + } // If the difference in length is beyond the `bound`, there is no need to check at all. - if (longer_length - shorter_length > bound) return bound; + if (longer_length - shorter_length > bound) { + *result_ptr = bound; + return sz_success_k; + } } // Make sure the shorter string is actually shorter. @@ -1066,14 +1196,17 @@ SZ_INTERNAL sz_size_t sz_edit_distance_ice( // } // Dispatch the right implementation based on the length of the strings. - if (longer_length < 64u) - return _sz_edit_distance_skewed_diagonals_upto63_ice( // + if (longer_length < 64u) { + *result_ptr = _sz_levenshtein_distance_skewed_diagonals_upto63_ice( // shorter, shorter_length, longer, longer_length, bound); + return sz_success_k; + } + // else if (longer_length < 256u * 256u) - // return _sz_edit_distance_skewed_diagonals_upto65k_ice( // + // return _sz_levenshtein_distance_skewed_diagonals_upto65k_ice( // // shorter, shorter_length, longer, longer_length, bound, alloc); else - return sz_edit_distance_serial(shorter, shorter_length, longer, longer_length, bound, alloc); + return sz_levenshtein_distance_serial(shorter, shorter_length, longer, longer_length, bound, alloc, result_ptr); } /** @@ -1082,21 +1215,27 @@ SZ_INTERNAL sz_size_t sz_edit_distance_ice( // * Assuming the costs of substitutions can be arbitrary signed 8-bit integers, the method is expected to be used * on strings not exceeding 2^24 length or 16.7 million characters. * - * Unlike the `_sz_edit_distance_skewed_diagonals_upto65k_avx512` method, this one uses signed integers to store + * Unlike the `_sz_levenshtein_distance_skewed_diagonals_upto65k_avx512` method, this one uses signed integers to store * the accumulated score. Moreover, it's primary bottleneck is the latency of gathering the substitution costs * from the substitution matrix. If we use the diagonal order, we will be comparing a slice of the first string * with a slice of the second. If we stick to the conventional horizontal order, we will be comparing one character * against a slice, which is much easier to optimize. In that case we are sampling costs not from arbitrary parts of * a 256 x 256 matrix, but from a single row! 
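 *
 * An illustrative sketch of that inner loop (not the exact kernel below, variable names are hypothetical):
 *
 * @code{.c}
 *     sz_error_cost_t const *row = subs + 256 * (sz_u8_t)longer[i]; //? one row of the 256 x 256 matrix
 *     for (sz_size_t j = 0; j != shorter_length; ++j)
 *         cost = row[(sz_u8_t)shorter[j]];                          //? sequential row lookups, no gathers
 * @endcode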
*/ -SZ_INTERNAL sz_ssize_t _sz_alignment_score_wagner_fisher_upto17m_ice( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc) { +SZ_INTERNAL sz_status_t _sz_needleman_wunsch_score_wagner_fisher_upto17m_ice( // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_cptr_t longer, sz_size_t longer_length, // + sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc, sz_ssize_t *result_ptr) { // If one of the strings is empty - the edit distance is equal to the length of the other one - if (longer_length == 0) return (sz_ssize_t)shorter_length * gap; - if (shorter_length == 0) return (sz_ssize_t)longer_length * gap; + if (longer_length == 0) { + *result_ptr = (sz_ssize_t)shorter_length * gap; + return sz_success_k; + } + if (shorter_length == 0) { + *result_ptr = (sz_ssize_t)longer_length * gap; + return sz_success_k; + } // Let's make sure that we use the amount proportional to the // number of elements in the shorter string, not the larger. @@ -1119,6 +1258,7 @@ SZ_INTERNAL sz_ssize_t _sz_alignment_score_wagner_fisher_upto17m_ice( // sz_size_t buffer_length = sizeof(sz_i32_t) * n * 2; sz_i32_t *distances = (sz_i32_t *)alloc->allocate(buffer_length, alloc->handle); + if (!distances) return sz_bad_alloc_k; sz_i32_t *previous_distances = distances; sz_i32_t *current_distances = previous_distances + n; @@ -1297,19 +1437,21 @@ SZ_INTERNAL sz_ssize_t _sz_alignment_score_wagner_fisher_upto17m_ice( // // Cache scalar before `free` call. sz_ssize_t result = previous_distances[longer_length]; alloc->free(distances, buffer_length, alloc->handle); - return result; + *result_ptr = result; + return sz_success_k; } -SZ_INTERNAL sz_ssize_t sz_alignment_score_ice( // - sz_cptr_t shorter, sz_size_t shorter_length, // - sz_cptr_t longer, sz_size_t longer_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc) { +SZ_PUBLIC sz_status_t sz_needleman_wunsch_score_ice( // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_cptr_t longer, sz_size_t longer_length, // + sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc, sz_ssize_t *result_ptr) { if (sz_max_of_two(shorter_length, longer_length) < (256ull * 256ull * 256ull)) - return _sz_alignment_score_wagner_fisher_upto17m_ice(shorter, shorter_length, longer, longer_length, subs, gap, - alloc); + return _sz_needleman_wunsch_score_wagner_fisher_upto17m_ice(shorter, shorter_length, longer, longer_length, + subs, gap, alloc, result_ptr); else - return sz_alignment_score_serial(shorter, shorter_length, longer, longer_length, subs, gap, alloc); + return sz_needleman_wunsch_score_serial(shorter, shorter_length, longer, longer_length, subs, gap, alloc, + result_ptr); } #pragma clang attribute pop @@ -1351,46 +1493,46 @@ SZ_INTERNAL sz_ssize_t sz_alignment_score_ice( // #pragma region Compile Time Dispatching #if !SZ_DYNAMIC_DISPATCH -SZ_DYNAMIC sz_size_t sz_hamming_distance( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - return sz_hamming_distance_serial(a, a_length, b, b_length, bound); +SZ_DYNAMIC sz_status_t sz_hamming_distance( // + sz_cptr_t a, sz_size_t a_length, // + sz_cptr_t b, sz_size_t b_length, // + sz_size_t bound, sz_size_t *result_ptr) { + return sz_hamming_distance_serial(a, a_length, b, b_length, bound, result_ptr); } -SZ_DYNAMIC sz_size_t sz_hamming_distance_utf8( // - sz_cptr_t a, 
sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - return sz_hamming_distance_utf8_serial(a, a_length, b, b_length, bound); +SZ_DYNAMIC sz_status_t sz_hamming_distance_utf8( // + sz_cptr_t a, sz_size_t a_length, // + sz_cptr_t b, sz_size_t b_length, // + sz_size_t bound, sz_size_t *result_ptr) { + return sz_hamming_distance_utf8_serial(a, a_length, b, b_length, bound, result_ptr); } -SZ_DYNAMIC sz_size_t sz_edit_distance( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { +SZ_DYNAMIC sz_status_t sz_levenshtein_distance( // + sz_cptr_t a, sz_size_t a_length, // + sz_cptr_t b, sz_size_t b_length, // + sz_size_t bound, sz_memory_allocator_t *alloc, sz_size_t *result_ptr) { #if SZ_USE_ICE - return sz_edit_distance_ice(a, a_length, b, b_length, bound, alloc); + return sz_levenshtein_distance_ice(a, a_length, b, b_length, bound, alloc, result_ptr); #else - return sz_edit_distance_serial(a, a_length, b, b_length, bound, alloc); + return sz_levenshtein_distance_serial(a, a_length, b, b_length, bound, alloc, result_ptr); #endif } -SZ_DYNAMIC sz_size_t sz_edit_distance_utf8( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - return _sz_edit_distance_wagner_fisher_serial(a, a_length, b, b_length, bound, sz_true_k, alloc); +SZ_DYNAMIC sz_status_t sz_levenshtein_distance_utf8( // + sz_cptr_t a, sz_size_t a_length, // + sz_cptr_t b, sz_size_t b_length, // + sz_size_t bound, sz_memory_allocator_t *alloc, sz_size_t *result_ptr) { + return _sz_levenshtein_distance_wagner_fisher_serial(a, a_length, b, b_length, bound, sz_true_k, alloc, result_ptr); } -SZ_DYNAMIC sz_ssize_t sz_alignment_score( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc) { +SZ_DYNAMIC sz_status_t sz_needleman_wunsch_score( // + sz_cptr_t a, sz_size_t a_length, // + sz_cptr_t b, sz_size_t b_length, // + sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc, sz_ssize_t *result_ptr) { #if SZ_USE_ICE - return sz_alignment_score_ice(a, a_length, b, b_length, subs, gap, alloc); + return sz_needleman_wunsch_score_ice(a, a_length, b, b_length, subs, gap, alloc, result_ptr); #else - return sz_alignment_score_serial(a, a_length, b, b_length, subs, gap, alloc); + return sz_needleman_wunsch_score_serial(a, a_length, b, b_length, subs, gap, alloc, result_ptr); #endif } diff --git a/include/stringzilla/stringzilla.hpp b/include/stringzilla/stringzilla.hpp index 24f8fc94..143f252e 100644 --- a/include/stringzilla/stringzilla.hpp +++ b/include/stringzilla/stringzilla.hpp @@ -3316,12 +3316,12 @@ class basic_string { concatenation operator|(string_view other) const noexcept { return {view(), other}; } size_type edit_distance(string_view other, size_type bound = 0) const noexcept { - size_type distance; + size_type result; _with_alloc([&](sz_alloc_type &alloc) { - distance = sz_edit_distance(data(), size(), other.data(), other.size(), bound, &alloc); - return true; + return sz_levenshtein_distance(data(), size(), other.data(), other.size(), bound, &alloc, &result) != + sz_bad_alloc_k; }); - return distance; + return result; } /** @brief Hashes the string, equivalent to `std::hash{}(str)`. */ @@ -3783,18 +3783,20 @@ typename concatenation_result::type /** * @brief Calculates the Hamming edit distance in @b bytes between two strings. 
- * @see sz_edit_distance + * @see sz_levenshtein_distance */ template std::size_t hamming_distance( // basic_string_slice const &a, basic_string_slice const &b, // std::size_t bound = 0) noexcept { - return sz_hamming_distance(a.data(), a.size(), b.data(), b.size(), bound); + std::size_t result; + sz_hamming_distance(a.data(), a.size(), b.data(), b.size(), bound, &result); + return result; } /** * @brief Calculates the Hamming edit distance in @b bytes between two strings. - * @see sz_edit_distance + * @see sz_levenshtein_distance */ template ::type>> std::size_t hamming_distance( // @@ -3810,12 +3812,14 @@ std::size_t hamming_distance( template std::size_t hamming_distance_utf8( // basic_string_slice const &a, basic_string_slice const &b, std::size_t bound = 0) noexcept { - return sz_hamming_distance_utf8(a.data(), a.size(), b.data(), b.size(), bound); + std::size_t result; + sz_hamming_distance_utf8(a.data(), a.size(), b.data(), b.size(), bound, &result); + return result; } /** * @brief Calculates the Hamming edit distance in @b unicode codepoints between two strings. - * @see sz_edit_distance + * @see sz_levenshtein_distance */ template ::type>> std::size_t hamming_distance_utf8( // @@ -3826,7 +3830,7 @@ std::size_t hamming_distance_utf8( // /** * @brief Calculates the Levenshtein edit distance in @b bytes between two strings. - * @see sz_edit_distance + * @see sz_levenshtein_distance */ template ::type>> std::size_t edit_distance( // @@ -3834,8 +3838,8 @@ std::size_t edit_distance( // allocator_type_ &&allocator = allocator_type_ {}) noexcept(false) { std::size_t result; if (!_with_alloc(allocator, [&](sz_memory_allocator_t &alloc) { - result = sz_edit_distance(a.data(), a.size(), b.data(), b.size(), bound, &alloc); - return result != SZ_SIZE_MAX; + return sz_levenshtein_distance(a.data(), a.size(), b.data(), b.size(), bound, &alloc, &result) != + sz_bad_alloc_k; })) throw std::bad_alloc(); return result; @@ -3843,7 +3847,7 @@ std::size_t edit_distance( // /** * @brief Calculates the Levenshtein edit distance in @b bytes between two strings. - * @see sz_edit_distance + * @see sz_levenshtein_distance */ template > std::size_t edit_distance( // @@ -3854,7 +3858,7 @@ std::size_t edit_distance( /** * @brief Calculates the Levenshtein edit distance in @b unicode codepoints between two strings. - * @see sz_edit_distance_utf8 + * @see sz_levenshtein_distance_utf8 */ template ::type>> std::size_t edit_distance_utf8( // @@ -3862,8 +3866,8 @@ std::size_t edit_distance_utf8( std::size_t bound = SZ_SIZE_MAX, allocator_type_ &&allocator = allocator_type_ {}) noexcept(false) { std::size_t result; if (!_with_alloc(allocator, [&](sz_memory_allocator_t &alloc) { - result = sz_edit_distance_utf8(a.data(), a.size(), b.data(), b.size(), bound, &alloc); - return result != SZ_SIZE_MAX; + return sz_levenshtein_distance_utf8(a.data(), a.size(), b.data(), b.size(), bound, &alloc, &result) != + sz_bad_alloc_k; })) throw std::bad_alloc(); return result; @@ -3871,7 +3875,7 @@ std::size_t edit_distance_utf8( /** * @brief Calculates the Levenshtein edit distance in @b unicode codepoints between two strings. - * @see sz_edit_distance_utf8 + * @see sz_levenshtein_distance_utf8 */ template > std::size_t edit_distance_utf8( // @@ -3882,7 +3886,7 @@ std::size_t edit_distance_utf8( /** * @brief Calculates the Needleman-Wunsch alignment score between two strings. 
- * @see sz_alignment_score + * @see sz_needleman_wunsch_score */ template ::type>> std::ptrdiff_t alignment_score( // @@ -3896,8 +3900,8 @@ std::ptrdiff_t alignment_score( std::ptrdiff_t result; if (!_with_alloc(allocator, [&](sz_memory_allocator_t &alloc) { - result = sz_alignment_score(a.data(), a.size(), b.data(), b.size(), &subs[0][0], gap, &alloc); - return result != SZ_SSIZE_MAX; + return sz_needleman_wunsch_score(a.data(), a.size(), b.data(), b.size(), &subs[0][0], gap, &alloc, + &result) != sz_bad_alloc_k; })) throw std::bad_alloc(); return result; @@ -3905,7 +3909,7 @@ std::ptrdiff_t alignment_score( /** * @brief Calculates the Needleman-Wunsch alignment score between two strings. - * @see sz_alignment_score + * @see sz_needleman_wunsch_score */ template > std::ptrdiff_t alignment_score( // @@ -3973,17 +3977,17 @@ struct _sequence_args { }; template -sz_cptr_t _call_sequence_member_start(struct sz_sequence_t const *sequence, sz_size_t i) { +sz_cptr_t _call_sequence_member_start(void const *sequence, sz_size_t i) { using handle_type = _sequence_args; - handle_type const *args = reinterpret_cast(sequence->handle); + handle_type const *args = reinterpret_cast(sequence); string_view member = args->extractor(args->begin[i]); return member.data(); } template -sz_size_t _call_sequence_member_length(struct sz_sequence_t const *sequence, sz_size_t i) { +sz_size_t _call_sequence_member_length(void const *sequence, sz_size_t i) { using handle_type = _sequence_args; - handle_type const *args = reinterpret_cast(sequence->handle); + handle_type const *args = reinterpret_cast(sequence); string_view member = args->extractor(args->begin[i]); return static_cast(member.size()); } From d7bab8d58fc7d9e77e6553f7e39aa9cd34a29f93 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sat, 22 Feb 2025 12:36:44 +0000 Subject: [PATCH 121/751] Docs: Explaining `compare.h` operations --- include/stringzilla/compare.h | 73 ++++++++++++++++++++++++++--------- 1 file changed, 55 insertions(+), 18 deletions(-) diff --git a/include/stringzilla/compare.h b/include/stringzilla/compare.h index 9f2e276d..4d0a7cb5 100644 --- a/include/stringzilla/compare.h +++ b/include/stringzilla/compare.h @@ -7,8 +7,13 @@ * * - `sz_equal` - for equality comparison of two strings. * - `sz_order` - for the relative order of two strings, similar to `memcmp`. - * - TODO: `sz_mismatch`, `sz_rmismatch` - to supersede `sz_equal`. - * - TODO: `sz_order_utf8` - for the relative order of two UTF-8 strings. + * + * A valid suggestion may be to add an `sz_mismatch`, as the shared part of the `sz_order` and `sz_equal`. + * That would be great for a general-purpose library, but has little practical use for string processing. + * + * The functions in this file can be used for both UTF-8 and other inputs. + * On platforms without masked loads they use interleaved prefix and suffix vector-loads + * to avoid scalar code, similar to the kernels in `memory.h`. */ #ifndef STRINGZILLA_COMPARE_H_ #define STRINGZILLA_COMPARE_H_ @@ -22,29 +27,61 @@ extern "C" { #pragma region Core API /** - * @brief Checks if two string are equal. - * Similar to `memcmp(a, b, length) == 0` in LibC and `a == b` in STL. + * @brief Checks if two strings are equal. Equivalent to `memcmp(a, b, length) == 0` in LibC and `a == b` in STL. + * @see https://en.cppreference.com/w/c/string/byte/memcmp + * + * @param[in] a First string to compare. + * @param[in] b Second string to compare. 
+ * @param[in] length Number of bytes to compare in both strings. + * + * @retval `sz_true_k` if strings are equal. + * @retval `sz_false_k` if strings are different. + * + * Example usage: * - * The implementation of this function is very similar to `sz_order`, but the usage patterns are different. - * This function is more often used in parsing, while `sz_order` is often used in sorting. - * It works best on platforms with cheap + * @code{.c} + * #include + * int main() { + * return sz_equal("hello", "hello", 5) && !sz_equal("hello", "world", 5); + * } + * @endcode * - * @param a First string to compare. - * @param b Second string to compare. - * @param length Number of bytes in both strings. - * @return 1 if strings match, 0 otherwise. + * @note Selects the fastest implementation at compile- or run-time based on `SZ_DYNAMIC_DISPATCH`. + * @sa sz_equal_serial, sz_equal_haswell, sz_equal_skylake, sz_equal_neon, sz_equal_sve */ SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length); /** - * @brief Estimates the relative order of two strings. Equivalent to `memcmp(a, b, length)` in LibC. - * Can be used on different length strings. + * @brief Compares two strings lexicographically. Equivalent to `memcmp(a, b, length)` in LibC. + * Mostly used in sorting and associative containers. Can be used for @b UTF-8 inputs. + * @see https://en.cppreference.com/w/c/string/byte/memcmp + * + * This function uses scalar code on most platforms, as in the majority of cases the strings that + * differ - will have differences among the very first characters and fetching more than one cache + * line may not be justified. + * + * @param[in] a First string to compare. + * @param[in] a_length Number of bytes in the first string. + * @param[in] b Second string to compare. + * @param[in] b_length Number of bytes in the second string. + * + * @retval `sz_less_k` if @p a is lexicographically smaller than @p b. + * @retval `sz_greater_k` if @p a is lexicographically greater than @p b. + * @retval `sz_equal_k` if strings @p a and @p b are identical. + * + * Example usage: + * + * @code{.c} + * #include + * int main() { + * return sz_order("apple", 5, "banana", 6) < 0 && + * sz_order("grape", 5, "grape", 5) == 0 && + * sz_order("zebra", 5, "apple", 5) > 0; + * } + * @endcode * - * @param a First string to compare. - * @param a_length Number of bytes in the first string. - * @param b Second string to compare. - * @param b_length Number of bytes in the second string. - * @return Negative if (a < b), positive if (a > b), zero if they are equal. + * @note Selects the fastest implementation at compile- or run-time based on `SZ_DYNAMIC_DISPATCH`. + * @sa sz_order_serial, sz_order_haswell, sz_order_skylake, sz_order_neon, sz_order_sve */ SZ_DYNAMIC sz_ordering_t sz_order(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); From 7aad4bb681d1f1d7aae031dd657ed53d82291b17 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 23 Feb 2025 13:07:44 +0000 Subject: [PATCH 122/751] Improve: Vectorize `sz_equal_haswell` --- include/stringzilla/compare.h | 69 ++++++++++++++++++++++++++++------- 1 file changed, 55 insertions(+), 14 deletions(-) diff --git a/include/stringzilla/compare.h b/include/stringzilla/compare.h index 4d0a7cb5..494d1442 100644 --- a/include/stringzilla/compare.h +++ b/include/stringzilla/compare.h @@ -10,7 +10,7 @@ * * A valid suggestion may be to add an `sz_mismatch`, as the shared part of the `sz_order` and `sz_equal`. 
* That would be great for a general-purpose library, but has little practical use for string processing. - * + * * The functions in this file can be used for both UTF-8 and other inputs. * On platforms without masked loads they use interleaved prefix and suffix vector-loads * to avoid scalar code, similar to the kernels in `memory.h`. @@ -59,7 +59,7 @@ SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length); * This function uses scalar code on most platforms, as in the majority of cases the strings that * differ - will have differences among the very first characters and fetching more than one cache * line may not be justified. - * + * * @param[in] a First string to compare. * @param[in] a_length Number of bytes in the first string. * @param[in] b Second string to compare. @@ -172,19 +172,60 @@ SZ_PUBLIC sz_ordering_t sz_order_haswell(sz_cptr_t a, sz_size_t a_length, sz_cpt } SZ_PUBLIC sz_bool_t sz_equal_haswell(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { - sz_u256_vec_t a_vec, b_vec; - - while (length >= 32) { - a_vec.ymm = _mm256_lddqu_si256((__m256i const *)a); - b_vec.ymm = _mm256_lddqu_si256((__m256i const *)b); - // One approach can be to use "movemasks", but we could also use a bitwise matching like `_mm256_testnzc_si256`. - int difference_mask = ~_mm256_movemask_epi8(_mm256_cmpeq_epi8(a_vec.ymm, b_vec.ymm)); - if (difference_mask == 0) { a += 32, b += 32, length -= 32; } - else { return sz_false_k; } - } - if (length) return sz_equal_serial(a, b, length); - return sz_true_k; + if (length < 8) { + sz_cptr_t const a_end = a + length; + while (a != a_end && *a == *b) a++, b++; + return (sz_bool_t)(a_end == a); + } + // We can use 2x 64-bit interleaving loads for each string, and then compare them for equality. + // The same approach is used in GLibC and was suggest by Denis Yaroshevskiy. + // https://codebrowser.dev/glibc/glibc/sysdeps/x86_64/multiarch/memcmp-avx2-movbe.S.html#518 + // It shouldn't improve performance on microbenchmarks, but should be better in practice. + else if (length <= 16) { + sz_u64_t a_first_word = sz_u64_load(a).u64; + sz_u64_t b_first_word = sz_u64_load(b).u64; + sz_u64_t a_second_word = sz_u64_load(a + length - 8).u64; + sz_u64_t b_second_word = sz_u64_load(b + length - 8).u64; + return (sz_bool_t)((a_first_word == b_first_word) & (a_second_word == b_second_word)); + } + // We can use 2x 128-bit interleaving loads for each string, and then compare them for equality. + else if (length <= 32) { + sz_u128_vec_t a_first_vec, b_first_vec, a_second_vec, b_second_vec; + a_first_vec.xmm = _mm_lddqu_si128((__m128i const *)(a)); + b_first_vec.xmm = _mm_lddqu_si128((__m128i const *)(b)); + a_second_vec.xmm = _mm_lddqu_si128((__m128i const *)(a + length - 16)); + b_second_vec.xmm = _mm_lddqu_si128((__m128i const *)(b + length - 16)); + return (sz_bool_t)(_mm_movemask_epi8(_mm_and_si128( // + _mm_cmpeq_epi8(a_first_vec.xmm, b_first_vec.xmm), + _mm_cmpeq_epi8(a_second_vec.xmm, b_second_vec.xmm))) == 0xFFFF); + } + // We can use 2x 256-bit interleaving loads for each string, and then compare them for equality. 
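// For instance, with a 40-byte input the first 32-byte load covers bytes [0, 32) and the second
// covers [8, 40): every byte participates in at least one comparison, the 24-byte overlap is simply
// compared twice, and no scalar tail loop is needed.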
+ else if (length <= 64) { + sz_u256_vec_t a_first_vec, b_first_vec, a_second_vec, b_second_vec; + a_first_vec.ymm = _mm256_lddqu_si256((__m256i const *)(a)); + b_first_vec.ymm = _mm256_lddqu_si256((__m256i const *)(b)); + a_second_vec.ymm = _mm256_lddqu_si256((__m256i const *)(a + length - 32)); + b_second_vec.ymm = _mm256_lddqu_si256((__m256i const *)(b + length - 32)); + return (sz_bool_t)(_mm256_movemask_epi8(_mm256_and_si256( // + _mm256_cmpeq_epi8(a_first_vec.ymm, b_first_vec.ymm), + _mm256_cmpeq_epi8(a_second_vec.ymm, b_second_vec.ymm))) == (int)0xFFFFFFFF); + } + else { + sz_size_t i = 0; + sz_u256_vec_t a_vec, b_vec; + do { + a_vec.ymm = _mm256_lddqu_si256((__m256i const *)(a + i)); + b_vec.ymm = _mm256_lddqu_si256((__m256i const *)(b + i)); + // One approach can be to use "movemasks", but we could also use a bitwise + // matching like `_mm256_testnzc_si256`. + if (_mm256_movemask_epi8(_mm256_cmpeq_epi8(a_vec.ymm, b_vec.ymm)) != (int)0xFFFFFFFF) return sz_false_k; + i += 32; + } while (i + 32 <= length); + a_vec.ymm = _mm256_lddqu_si256((__m256i const *)(a + length - 32)); + b_vec.ymm = _mm256_lddqu_si256((__m256i const *)(b + length - 32)); + return (sz_bool_t)(_mm256_movemask_epi8(_mm256_cmpeq_epi8(a_vec.ymm, b_vec.ymm)) == (int)0xFFFFFFFF); + } } #pragma clang attribute pop From 6e715362a7f47bfa35ae01690a314e0bf4baefc8 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 23 Feb 2025 13:09:44 +0000 Subject: [PATCH 123/751] Improve: Ordering includes --- include/stringzilla/stringzilla.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/include/stringzilla/stringzilla.h b/include/stringzilla/stringzilla.h index 349aba79..660ffa6c 100644 --- a/include/stringzilla/stringzilla.h +++ b/include/stringzilla/stringzilla.h @@ -30,7 +30,7 @@ * * - `SZ_USE_HASWELL=?` - whether to use AVX2 instructions on x86_64. * - `SZ_USE_SKYLAKE=?` - whether to use AVX-512 instructions on x86_64. - * - `SZ_USE_ICE=?` - whether to use AVX-512 VBMI instructions on x86_64. + * - `SZ_USE_ICE=?` - whether to use AVX-512 VBMI & wider AES instructions on x86_64. * - `SZ_USE_NEON=?` - whether to use NEON instructions on ARM. * - `SZ_USE_SVE=?` - whether to use SVE and SVE2 instructions on ARM. 
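 *
 * For example, a translation unit that wants to opt out of the AVX-512 code paths could, in principle,
 * define the corresponding macros before including the header (an illustrative sketch, not a tested setup):
 *
 * @code{.c}
 *     #define SZ_USE_SKYLAKE 0
 *     #define SZ_USE_ICE 0
 *     #include <stringzilla/stringzilla.h>
 * @endcode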
*/ @@ -41,14 +41,14 @@ #define STRINGZILLA_VERSION_MINOR 11 #define STRINGZILLA_VERSION_PATCH 3 +#include "types.h" // `sz_size_t`, `sz_bool_t`, `sz_ordering_t` #include "compare.h" // `sz_equal`, `sz_order` -#include "find.h" // `sz_find`, `sz_find_charset`, `sz_rfind` -#include "hash.h" // `sz_bytesum`, `sz_hash`, `sz_state_init`, `sz_state_stream`, `sz_state_fold` #include "memory.h" // `sz_copy`, `sz_move`, `sz_fill` -#include "similarity.h" // `sz_edit_distance`, `sz_alignment_score` +#include "hash.h" // `sz_bytesum`, `sz_hash`, `sz_state_init`, `sz_state_stream`, `sz_state_fold` +#include "find.h" // `sz_find`, `sz_find_charset`, `sz_rfind` #include "small_string.h" // `sz_string_t`, `sz_string_init`, `sz_string_free` +#include "similarity.h" // `sz_levenshtein_distance`, `sz_needleman_wunsch_score` #include "sort.h" // `sz_sequence_argsort`, `sz_pgrams_sort`, `sz_pgrams_sort_stable` -#include "types.h" // `sz_size_t`, `sz_bool_t`, `sz_ordering_t` #ifdef __cplusplus extern "C" { From 2225488a0e035a01e50cc78e8c2aa4d3ff8c33ef Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 23 Feb 2025 20:19:15 +0000 Subject: [PATCH 124/751] Docs: Announce JOINs --- include/stringzilla/sort.h | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/include/stringzilla/sort.h b/include/stringzilla/sort.h index c422e3c3..8e387b70 100644 --- a/include/stringzilla/sort.h +++ b/include/stringzilla/sort.h @@ -1,28 +1,29 @@ /** - * @brief Hardware-accelerated string collection sorting. + * @brief Hardware-accelerated string collection sorting & joins. * @file sort.h * @author Ash Vardanian * * Includes core APIs for `sz_sequence_t` string collections: * - * - `sz_sequence_argsort` - to get the sorting permutation of a string collection with QuickSort. - * - `sz_sequence_argsort_stable` - to get the stable-sorting permutation of a string collection with a MergeSort. + * - `sz_sequence_argsort` - to get the sorting permutation of a string collection. + * - `sz_sequence_join` - to compute the intersection of two arbitrary string collections. * - * The core idea of all following string algorithms is to sort strings not based on 1 character at a time, + * The core idea of all following string algorithms is to process strings not based on 1 character at a time, * but on a larger "Pointer-sized N-grams" fitting in 4 or 8 bytes at once, on 32-bit or 64-bit architectures, - * respectively. In reality we may not use the full pointer size, but only a few bytes from it, and keep the rest - * for some metadata. + * respectively. In reality we may not use the full pointer size, but only a few bytes from it, and keep the + * rest for some metadata. * - * That, however, means, that unsigned integer sorting is a constituent part of our string sorting and we can - * expose it as an additional set of APIs for the users: + * That, however, means, that unsigned integer sorting & matching is a constituent part of our sequence + * algorithms and we can expose them as an additional set of APIs for the users: * - * - `sz_pgrams_sort` - to inplace sort continuous pointer-sized integers with QuickSort. - * - `sz_pgrams_sort_stable` - to inplace stable-sort continuous pointer-sized integers with a MergeSort. + * - `sz_pgrams_sort` - to inplace sort continuous pointer-sized integers. + * - `sz_pgrams_join` - to compute the intersection of two arbitrary integer collections. 
* - * For cases, when the input is known to be tiny, we provide quadratic-complexity insertion sort adaptations: + * Other helpers include: * - * - `sz_sequence_argsort_with_insertion` - for string collections. - * - `sz_pgrams_sort_stable_with_insertion` - for continuous unsigned integers. + * - `sz_pgrams_sort_stable_with_insertion` - for quadratic-complexity sorting of small continuous integer arrays. + * - `sz_sequence_argsort_with_insertion` - for quadratic-complexity sorting of small string collections. + * - `sz_sequence_argsort_stabilize` - updates the sorting permutation to be stable. */ #ifndef STRINGZILLA_SORT_H_ #define STRINGZILLA_SORT_H_ @@ -845,7 +846,7 @@ SZ_PUBLIC sz_status_t sz_pgrams_sort_stable_serial(sz_pgram_t *pgrams, sz_size_t SZ_PUBLIC sz_status_t sz_sequence_argsort_stable_serial(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, sz_sorted_idx_t *order) { - + sz_unused(sequence && alloc && order); return sz_success_k; } From 3c345bcbeefa4bc5dcf3745dc2663d0b35168327 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 23 Feb 2025 20:19:53 +0000 Subject: [PATCH 125/751] Add: Hashing on Haswell & Skylake-X --- include/stringzilla/hash.h | 444 ++++++++++++++++++++++++++++--------- 1 file changed, 341 insertions(+), 103 deletions(-) diff --git a/include/stringzilla/hash.h b/include/stringzilla/hash.h index 2094a0d3..539f016a 100644 --- a/include/stringzilla/hash.h +++ b/include/stringzilla/hash.h @@ -66,9 +66,21 @@ extern "C" { * @brief Computes the 64-bit check-sum of bytes in a string. * Similar to `std::ranges::accumulate`. * - * @param text String to aggregate. - * @param length Number of bytes in the text. - * @return 64-bit unsigned value. + * @param[in] text String to aggregate. + * @param[in] length Number of bytes in the text. + * @return 64-bit unsigned value. + * + * Example usage: + * + * @code{.c} + * #include + * int main() { + * return sz_bytesum("hi", 2) == 209 ? 0 : 1; + * } + * @endcode + * + * @note Selects the fastest implementation at compile- or run-time based on `SZ_DYNAMIC_DISPATCH`. + * @sa sz_bytesum_serial, sz_bytesum_haswell, sz_bytesum_skylake, sz_bytesum_ice, sz_bytesum_neon */ SZ_DYNAMIC sz_u64_t sz_bytesum(sz_cptr_t text, sz_size_t length); @@ -78,10 +90,25 @@ SZ_DYNAMIC sz_u64_t sz_bytesum(sz_cptr_t text, sz_size_t length); * It passes the SMHasher suite by Austin Appleby with no collisions, even with `--extra` flag. * @see HASH.md for a detailed explanation of the algorithm. * - * @param text String to hash. - * @param length Number of bytes in the text. - * @param seed 64-bit unsigned seed for the hash. - * @return 64-bit hash value. + * @param[in] text String to hash. + * @param[in] length Number of bytes in the text. + * @param[in] seed 64-bit unsigned seed for the hash. + * @return 64-bit hash value. + * + * Example usage: + * + * @code{.c} + * #include + * int main() { + * return sz_hash("hello", 5, 0) != sz_hash("world", 5, 0) ? 0 : 1; + * } + * @endcode + * + * @note Selects the fastest implementation at compile- or run-time based on `SZ_DYNAMIC_DISPATCH`. + * @sa sz_hash_serial, sz_hash_haswell, sz_hash_skylake, sz_hash_ice, sz_hash_neon + * + * @note The algorithm must provide the same output on all platforms in both single-shot and incremental modes. 
+ * @sa sz_hash_state_init, sz_hash_state_stream, sz_hash_state_fold */ SZ_DYNAMIC sz_u64_t sz_hash(sz_cptr_t text, sz_size_t length, sz_u64_t seed); @@ -98,9 +125,24 @@ SZ_DYNAMIC sz_u64_t sz_hash(sz_cptr_t text, sz_size_t length, sz_u64_t seed); * In this case, it doesn't apply, as we only use one round of AES mixing. We also don't expose a separate "key", * only a "nonce", to keep the API simple. * - * @param text Output string buffer to be populated. - * @param length Number of bytes in the string. - * @param nonce "Number used ONCE" to ensure uniqueness of produced blocks. + * @param[out] text Output string buffer to be populated. + * @param[in] length Number of bytes in the string. + * @param[in] nonce "Number used ONCE" to ensure uniqueness of produced blocks. + * + * Example usage: + * + * @code{.c} + * #include + * int main() { + * char first_buffer[5], second_buffer[5]; + * sz_generate(first_buffer, 5, 0); + * sz_generate(second_buffer, 5, 0); //? Same nonce will produce the same output + * return sz_bytesum(first_buffer, 5) == sz_bytesum(second_buffer, 5) ? 0 : 1; + * } + * @endcode + * + * @note Selects the fastest implementation at compile- or run-time based on `SZ_DYNAMIC_DISPATCH`. + * @sa sz_generate_serial, sz_generate_haswell, sz_generate_skylake, sz_generate_ice, sz_generate_neon */ SZ_DYNAMIC void sz_generate(sz_ptr_t text, sz_size_t length, sz_u64_t nonce); @@ -126,25 +168,25 @@ typedef struct _sz_hash_minimal_t { /** * @brief Initializes the state for incremental construction of a hash. * - * @param state The state to initialize. - * @param seed The 64-bit unsigned seed for the hash. + * @param[out] state The state to initialize. + * @param[in] seed The 64-bit unsigned seed for the hash. */ SZ_DYNAMIC void sz_hash_state_init(sz_hash_state_t *state, sz_u64_t seed); /** * @brief Updates the state with new data. * - * @param state The state to stream. - * @param text The new data to include in the hash. - * @param length The number of bytes in the new data. + * @param[inout] state The state to stream. + * @param[in] text The new data to include in the hash. + * @param[in] length The number of bytes in the new data. */ SZ_DYNAMIC void sz_hash_state_stream(sz_hash_state_t *state, sz_cptr_t text, sz_size_t length); /** - * @brief Finalizes the state and returns the hash. + * @brief Finalizes the immutable state and returns the hash. * - * @param state The state to fold. - * @return The 64-bit hash value. + * @param[in] state The state to fold. + * @return The 64-bit hash value. 
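 *
 * An illustrative usage sketch, assuming the umbrella header; per the `sz_hash` note above, the folded
 * value is expected to match the single-shot hash of the concatenated input:
 *
 * @code{.c}
 *     #include <stringzilla/stringzilla.h>
 *     int main() {
 *         sz_hash_state_t state;
 *         sz_hash_state_init(&state, 42);
 *         sz_hash_state_stream(&state, "hello ", 6);
 *         sz_hash_state_stream(&state, "world", 5);
 *         return sz_hash_state_fold(&state) == sz_hash("hello world", 11, 42) ? 0 : 1;
 *     }
 * @endcode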
*/ SZ_DYNAMIC sz_u64_t sz_hash_state_fold(sz_hash_state_t const *state); @@ -365,8 +407,182 @@ SZ_PUBLIC sz_u64_t sz_bytesum_haswell(sz_cptr_t text, sz_size_t length) { } } -SZ_PUBLIC sz_u64_t sz_hash_haswell(sz_cptr_t text, sz_size_t length, sz_u64_t seed) { - return sz_hash_serial(text, length, seed); +SZ_INTERNAL void _sz_hash_minimal_init_haswell(_sz_hash_minimal_t *state, sz_u64_t seed) { + __m128i seed_vec = _mm_set1_epi64x(seed); + __m128i pi0 = _mm_set_epi64x(0x13198a2e03707344ull, 0x243f6a8885a308d3ull); + __m128i pi1 = _mm_set_epi64x(0x082efa98ec4e6c89ull, 0xa4093822299f31d0ull); + // XOR the user-supplied keys with the two "pi" constants + __m128i k1 = _mm_xor_si128(seed_vec, pi0); + __m128i k2 = _mm_xor_si128(seed_vec, pi1); + // Export the keys to the state + state->aes.xmm = k1; + state->sum.xmm = k2; + state->key.xmm = _mm_xor_si128(pi0, pi1); +} + +SZ_INTERNAL sz_u64_t _sz_hash_minimal_finalize_haswell(_sz_hash_minimal_t const *state) { + // Combine the sum and the AES block + __m128i mixed_registers = _mm_aesenc_si128(state->sum.xmm, state->aes.xmm); + // Make sure the "key" mixes enough with the state, + // as with less than 2 rounds - SMHasher fails + __m128i mixed_within_register = + _mm_aesdec_si128(_mm_aesdec_si128(mixed_registers, state->key.xmm), mixed_registers); + // Extract the low 64 bits + return _mm_cvtsi128_si64(mixed_within_register); +} + +SZ_INTERNAL void _sz_hash_minimal_update_haswell(_sz_hash_minimal_t *state, __m128i block) { + // This shuffle mask is identical to "aHash": + __m128i const shuffle_mask = _mm_set_epi8( // + 0x04, 0x0b, 0x09, 0x06, 0x08, 0x0d, 0x0f, 0x05, // + 0x0e, 0x03, 0x01, 0x0c, 0x00, 0x07, 0x0a, 0x02); + state->aes.xmm = _mm_aesdec_si128(state->aes.xmm, block); + state->sum.xmm = _mm_add_epi64(_mm_shuffle_epi8(state->sum.xmm, shuffle_mask), block); +} + +SZ_PUBLIC void sz_hash_state_init_haswell(sz_hash_state_t *state, sz_u64_t seed) { + __m128i seed_vec = _mm_set1_epi64x(seed); + __m128i pi0 = _mm_set_epi64x(0x13198a2e03707344ull, 0x243f6a8885a308d3ull); + __m128i pi1 = _mm_set_epi64x(0x082efa98ec4e6c89ull, 0xa4093822299f31d0ull); + // XOR the user-supplied keys with the two "pi" constants + __m128i k1 = _mm_xor_si128(seed_vec, pi0); + __m128i k2 = _mm_xor_si128(seed_vec, pi1); + // Export the keys to the state + state->aes.xmms[0] = state->aes.xmms[1] = state->aes.xmms[2] = state->aes.xmms[3] = k1; + state->sum.xmms[0] = state->sum.xmms[1] = state->sum.xmms[2] = state->sum.xmms[3] = k2; + state->key.xmms[0] = state->key.xmms[1] = state->key.xmms[2] = state->key.xmms[3] = _mm_xor_si128(pi0, pi1); + state->ins_length = 0; +} + +SZ_INTERNAL void _sz_hash_state_update_haswell(sz_hash_state_t *state, __m128i block0, __m128i block1, __m128i block2, + __m128i block3) { + // This shuffle mask is identical to "aHash": + __m128i const shuffle_mask = _mm_set_epi8( // + 0x04, 0x0b, 0x09, 0x06, 0x08, 0x0d, 0x0f, 0x05, // + 0x0e, 0x03, 0x01, 0x0c, 0x00, 0x07, 0x0a, 0x02); + state->aes.xmms[0] = _mm_aesdec_si128(state->aes.xmms[0], block0); + state->sum.xmms[0] = _mm_add_epi64(_mm_shuffle_epi8(state->sum.xmms[0], shuffle_mask), block0); + state->aes.xmms[1] = _mm_aesdec_si128(state->aes.xmms[1], block1); + state->sum.xmms[1] = _mm_add_epi64(_mm_shuffle_epi8(state->sum.xmms[1], shuffle_mask), block1); + state->aes.xmms[2] = _mm_aesdec_si128(state->aes.xmms[2], block2); + state->sum.xmms[2] = _mm_add_epi64(_mm_shuffle_epi8(state->sum.xmms[2], shuffle_mask), block2); + state->aes.xmms[3] = _mm_aesdec_si128(state->aes.xmms[3], block3); + 
state->sum.xmms[3] = _mm_add_epi64(_mm_shuffle_epi8(state->sum.xmms[3], shuffle_mask), block3); +} + +SZ_INTERNAL sz_u64_t _sz_hash_state_finalize_haswell(sz_hash_state_t const *state) { + // Combine the sum and the AES block + __m128i mixed_registers0 = _mm_aesenc_si128(state->sum.xmms[0], state->aes.xmms[0]); + __m128i mixed_registers1 = _mm_aesenc_si128(state->sum.xmms[1], state->aes.xmms[1]); + __m128i mixed_registers2 = _mm_aesenc_si128(state->sum.xmms[2], state->aes.xmms[2]); + __m128i mixed_registers3 = _mm_aesenc_si128(state->sum.xmms[3], state->aes.xmms[3]); + // Combine the mixed registers + __m128i mixed_registers01 = _mm_aesenc_si128(mixed_registers0, mixed_registers1); + __m128i mixed_registers23 = _mm_aesenc_si128(mixed_registers2, mixed_registers3); + __m128i mixed_registers = _mm_aesenc_si128(mixed_registers01, mixed_registers23); + // Make sure the "key" mixes enough with the state, + // as with less than 2 rounds - SMHasher fails + __m128i mixed_within_register = _mm_aesdec_si128( // + _mm_aesdec_si128(mixed_registers, state->key.xmms[0]), mixed_registers); + // Extract the low 64 bits + return _mm_cvtsi128_si64(mixed_within_register); +} + +SZ_PUBLIC sz_u64_t sz_hash_haswell(sz_cptr_t start, sz_size_t length, sz_u64_t seed) { + + if (length <= 16) { + // Initialize the AES block with a given seed and update with the input length + _sz_hash_minimal_t state; + _sz_hash_minimal_init_haswell(&state, seed); + state.aes.xmm = _mm_add_epi64(state.aes.xmm, _mm_set_epi64x(0, length)); + // Load the data and update the state + sz_u128_vec_t data_vec; + data_vec.xmm = _mm_setzero_si128(); + for (sz_size_t i = 0; i < length; ++i) data_vec.u8s[i] = start[i]; + _sz_hash_minimal_update_haswell(&state, data_vec.xmm); + return _sz_hash_minimal_finalize_haswell(&state); + } + else if (length <= 32) { + // Initialize the AES block with a given seed and update with the input length + _sz_hash_minimal_t state; + _sz_hash_minimal_init_haswell(&state, seed); + state.aes.xmm = _mm_add_epi64(state.aes.xmm, _mm_set_epi64x(0, length)); + // Load the data and update the state + sz_u128_vec_t data0_vec, data1_vec; + data0_vec.xmm = _mm_lddqu_si128(start); + data1_vec.xmm = _mm_lddqu_si128(start + length - 16); + // Let's shift the data within the register to de-interleave the bytes. + data1_vec.xmm = _mm_bsrli_si128(data1_vec.xmm, 32 - length); + _sz_hash_minimal_update_haswell(&state, data0_vec.xmm); + _sz_hash_minimal_update_haswell(&state, data1_vec.xmm); + return _sz_hash_minimal_finalize_haswell(&state); + } + else if (length <= 48) { + // Initialize the AES block with a given seed and update with the input length + _sz_hash_minimal_t state; + _sz_hash_minimal_init_haswell(&state, seed); + state.aes.xmm = _mm_add_epi64(state.aes.xmm, _mm_set_epi64x(0, length)); + // Load the data and update the state + sz_u128_vec_t data0_vec, data1_vec, data2_vec; + data0_vec.xmm = _mm_lddqu_si128(start); + data1_vec.xmm = _mm_lddqu_si128(start + 16); + data2_vec.xmm = _mm_lddqu_si128(start + length - 16); + // Let's shift the data within the register to de-interleave the bytes. 
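// The third load starts at `start + length - 16`, so its lowest `48 - length` bytes duplicate data
// already covered by the previous 16-byte block; shifting right by that amount drops the duplicates.
// For a 40-byte input, for example, the load covers bytes [24, 40) and its first 8 bytes are discarded.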
+ data2_vec.xmm = _mm_bsrli_si128(data2_vec.xmm, 48 - length); + _sz_hash_minimal_update_haswell(&state, data0_vec.xmm); + _sz_hash_minimal_update_haswell(&state, data1_vec.xmm); + _sz_hash_minimal_update_haswell(&state, data2_vec.xmm); + return _sz_hash_minimal_finalize_haswell(&state); + } + else if (length <= 64) { + // Initialize the AES block with a given seed and update with the input length + _sz_hash_minimal_t state; + _sz_hash_minimal_init_haswell(&state, seed); + state.aes.xmm = _mm_add_epi64(state.aes.xmm, _mm_set_epi64x(0, length)); + // Load the data and update the state + sz_u128_vec_t data0_vec, data1_vec, data2_vec, data3_vec; + data0_vec.xmm = _mm_lddqu_si128(start); + data1_vec.xmm = _mm_lddqu_si128(start + 16); + data2_vec.xmm = _mm_lddqu_si128(start + 32); + data3_vec.xmm = _mm_lddqu_si128(start + length - 16); + // Let's shift the data within the register to de-interleave the bytes. + data3_vec.xmm = _mm_bsrli_si128(data3_vec.xmm, 64 - length); + _sz_hash_minimal_update_haswell(&state, data0_vec.xmm); + _sz_hash_minimal_update_haswell(&state, data1_vec.xmm); + _sz_hash_minimal_update_haswell(&state, data2_vec.xmm); + _sz_hash_minimal_update_haswell(&state, data3_vec.xmm); + return _sz_hash_minimal_finalize_haswell(&state); + } + else { + // Use a larger state to handle the main loop and add different offsets + // to different lanes of the register + sz_hash_state_t state; + sz_hash_state_init_haswell(&state, seed); + state.aes.xmms[0] = _mm_add_epi64(state.aes.xmms[0], _mm_set_epi64x(0, length)); + state.aes.xmms[1] = _mm_add_epi64(state.aes.xmms[1], _mm_set_epi64x(16, length)); + state.aes.xmms[2] = _mm_add_epi64(state.aes.xmms[2], _mm_set_epi64x(32, length)); + state.aes.xmms[3] = _mm_add_epi64(state.aes.xmms[3], _mm_set_epi64x(48, length)); + + for (; state.ins_length + 64 <= length; state.ins_length += 64) { + state.ins.xmms[0] = _mm_lddqu_si128(start + state.ins_length); + state.ins.xmms[1] = _mm_lddqu_si128(start + state.ins_length + 16); + state.ins.xmms[2] = _mm_lddqu_si128(start + state.ins_length + 32); + state.ins.xmms[3] = _mm_lddqu_si128(start + state.ins_length + 48); + _sz_hash_state_update_haswell(&state, state.ins.xmms[0], state.ins.xmms[1], state.ins.xmms[2], + state.ins.xmms[3]); + } + if (state.ins_length < length) { + state.ins.xmms[0] = _mm_setzero_si128(); + state.ins.xmms[1] = _mm_setzero_si128(); + state.ins.xmms[2] = _mm_setzero_si128(); + state.ins.xmms[3] = _mm_setzero_si128(); + for (sz_size_t i = 0; state.ins_length < length; ++i, ++state.ins_length) + state.ins.u8s[i] = start[state.ins_length]; + _sz_hash_state_update_haswell(&state, state.ins.xmms[0], state.ins.xmms[1], state.ins.xmms[2], + state.ins.xmms[3]); + } + return _sz_hash_state_finalize_haswell(&state); + } } SZ_PUBLIC void sz_generate_haswell(sz_ptr_t text, sz_size_t length, sz_u64_t nonce) { @@ -499,16 +715,107 @@ SZ_PUBLIC sz_u64_t sz_bytesum_skylake(sz_cptr_t text, sz_size_t length) { } } -SZ_PUBLIC sz_u64_t sz_hash_skylake(sz_cptr_t text, sz_size_t length, sz_u64_t seed) { - return sz_hash_serial(text, length, seed); +SZ_PUBLIC void sz_hash_state_init_skylake(sz_hash_state_t *state, sz_u64_t seed) { + __m512i seed_vec = _mm512_set1_epi64(seed); + __m512i pi0 = _mm512_set_epi64( // + 0x13198a2e03707344ull, 0x243f6a8885a308d3ull, 0x13198a2e03707344ull, 0x243f6a8885a308d3ull, + 0x13198a2e03707344ull, 0x243f6a8885a308d3ull, 0x13198a2e03707344ull, 0x243f6a8885a308d3ull); + __m512i pi1 = _mm512_set_epi64( // + 0x082efa98ec4e6c89ull, 0xa4093822299f31d0ull, 0x082efa98ec4e6c89ull, 
0xa4093822299f31d0ull, + 0x082efa98ec4e6c89ull, 0xa4093822299f31d0ull, 0x082efa98ec4e6c89ull, 0xa4093822299f31d0ull); + // XOR the user-supplied keys with the two "pi" constants + __m512i k1 = _mm512_xor_si512(seed_vec, pi0); + __m512i k2 = _mm512_xor_si512(seed_vec, pi1); + // Export the keys to the state + state->aes.zmm = k1; + state->sum.zmm = k2; + state->key.zmm = _mm512_xor_si512(pi0, pi1); + state->ins_length = 0; } -SZ_PUBLIC void sz_generate_skylake(sz_ptr_t text, sz_size_t length, sz_u64_t nonce) { - sz_generate_serial(text, length, nonce); +SZ_PUBLIC sz_u64_t sz_hash_skylake(sz_cptr_t start, sz_size_t length, sz_u64_t seed) { + + if (length <= 16) { + // Initialize the AES block with a given seed and update with the input length + _sz_hash_minimal_t state; + _sz_hash_minimal_init_haswell(&state, seed); + state.aes.xmm = _mm_add_epi64(state.aes.xmm, _mm_set_epi64x(0, length)); + // Load the data and update the state + sz_u128_vec_t data_vec; + data_vec.xmm = _mm_maskz_loadu_epi8(_sz_u16_mask_until(length), start); + _sz_hash_minimal_update_haswell(&state, data_vec.xmm); + return _sz_hash_minimal_finalize_haswell(&state); + } + else if (length <= 32) { + // Initialize the AES block with a given seed and update with the input length + _sz_hash_minimal_t state; + _sz_hash_minimal_init_haswell(&state, seed); + state.aes.xmm = _mm_add_epi64(state.aes.xmm, _mm_set_epi64x(0, length)); + // Load the data and update the state + sz_u128_vec_t data0_vec, data1_vec; + data0_vec.xmm = _mm_lddqu_si128(start); + data1_vec.xmm = _mm_maskz_loadu_epi8(_sz_u16_mask_until(length - 16), start + 16); + _sz_hash_minimal_update_haswell(&state, data0_vec.xmm); + _sz_hash_minimal_update_haswell(&state, data1_vec.xmm); + return _sz_hash_minimal_finalize_haswell(&state); + } + else if (length <= 48) { + // Initialize the AES block with a given seed and update with the input length + _sz_hash_minimal_t state; + _sz_hash_minimal_init_haswell(&state, seed); + state.aes.xmm = _mm_add_epi64(state.aes.xmm, _mm_set_epi64x(0, length)); + // Load the data and update the state + sz_u128_vec_t data0_vec, data1_vec, data2_vec; + data0_vec.xmm = _mm_lddqu_si128(start); + data1_vec.xmm = _mm_lddqu_si128(start + 16); + data2_vec.xmm = _mm_maskz_loadu_epi8(_sz_u16_mask_until(length - 32), start + 32); + _sz_hash_minimal_update_haswell(&state, data0_vec.xmm); + _sz_hash_minimal_update_haswell(&state, data1_vec.xmm); + _sz_hash_minimal_update_haswell(&state, data2_vec.xmm); + return _sz_hash_minimal_finalize_haswell(&state); + } + else if (length <= 64) { + // Initialize the AES block with a given seed and update with the input length + _sz_hash_minimal_t state; + _sz_hash_minimal_init_haswell(&state, seed); + state.aes.xmm = _mm_add_epi64(state.aes.xmm, _mm_set_epi64x(0, length)); + // Load the data and update the state + sz_u128_vec_t data0_vec, data1_vec, data2_vec, data3_vec; + data0_vec.xmm = _mm_lddqu_si128(start); + data1_vec.xmm = _mm_lddqu_si128(start + 16); + data2_vec.xmm = _mm_lddqu_si128(start + 32); + data3_vec.xmm = _mm_maskz_loadu_epi8(_sz_u16_mask_until(length - 48), start + 48); + _sz_hash_minimal_update_haswell(&state, data0_vec.xmm); + _sz_hash_minimal_update_haswell(&state, data1_vec.xmm); + _sz_hash_minimal_update_haswell(&state, data2_vec.xmm); + _sz_hash_minimal_update_haswell(&state, data3_vec.xmm); + return _sz_hash_minimal_finalize_haswell(&state); + } + else { + // Use a larger state to handle the main loop and add different offsets + // to different lanes of the register + sz_hash_state_t state; 
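// Folding each 128-bit lane's starting offset (0, 16, 32, 48) together with the total length into the
// `aes` lanes below presumably keeps identical 16-byte blocks at different positions from contributing
// identically; the Haswell and Ice Lake variants apply the same per-lane offsets.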
+ sz_hash_state_init_skylake(&state, seed); + state.aes.zmm = _mm512_add_epi64( // + state.aes.zmm, // + _mm512_set_epi64(0, length, 16, length, 32, length, 48, length)); + + for (; state.ins_length + 64 <= length; state.ins_length += 64) { + state.ins.zmm = _mm512_loadu_epi8(start + state.ins_length); + _sz_hash_state_update_haswell(&state, state.ins.xmms[0], state.ins.xmms[1], state.ins.xmms[2], + state.ins.xmms[3]); + } + if (state.ins_length < length) { + state.ins.zmm = _mm512_maskz_loadu_epi8( // + _sz_u64_mask_until(length - state.ins_length), start + state.ins_length); + _sz_hash_state_update_skylake(&state, state.ins.zmm); + } + return _sz_hash_state_finalize_haswell(&state); + } } -SZ_PUBLIC void sz_hash_state_init_skylake(sz_hash_state_t *state, sz_u64_t seed) { - sz_hash_state_init_serial(state, seed); +SZ_PUBLIC void sz_generate_skylake(sz_ptr_t text, sz_size_t length, sz_u64_t nonce) { + sz_generate_serial(text, length, nonce); } SZ_PUBLIC void sz_hash_state_stream_skylake(sz_hash_state_t *state, sz_cptr_t text, sz_size_t length) { @@ -663,57 +970,6 @@ SZ_PUBLIC sz_u64_t sz_bytesum_ice(sz_cptr_t text, sz_size_t length) { } } -SZ_INTERNAL void _sz_hash_minimal_init_haswell(_sz_hash_minimal_t *state, sz_u64_t seed) { - __m128i seed_vec = _mm_set1_epi64x(seed); - __m128i pi0 = _mm_set_epi64x(0x13198a2e03707344ull, 0x243f6a8885a308d3ull); - __m128i pi1 = _mm_set_epi64x(0x082efa98ec4e6c89ull, 0xa4093822299f31d0ull); - // XOR the user-supplied keys with the two "pi" constants - __m128i k1 = _mm_xor_si128(seed_vec, pi0); - __m128i k2 = _mm_xor_si128(seed_vec, pi1); - // Export the keys to the state - state->aes.xmm = k1; - state->sum.xmm = k2; - state->key.xmm = _mm_xor_si128(pi0, pi1); -} - -SZ_INTERNAL sz_u64_t _sz_hash_minimal_finalize_haswell(_sz_hash_minimal_t const *state) { - // Combine the sum and the AES block - __m128i mixed_registers = _mm_aesenc_si128(state->sum.xmm, state->aes.xmm); - // Make sure the "key" mixes enough with the state, - // as with less than 2 rounds - SMHasher fails - __m128i mixed_within_register = - _mm_aesdec_si128(_mm_aesdec_si128(mixed_registers, state->key.xmm), mixed_registers); - // Extract the low 64 bits - return _mm_cvtsi128_si64(mixed_within_register); -} - -SZ_INTERNAL void _sz_hash_minimal_update_haswell(_sz_hash_minimal_t *state, __m128i block) { - // This shuffle mask is identical to "aHash": - __m128i const shuffle_mask = _mm_set_epi8( // - 0x04, 0x0b, 0x09, 0x06, 0x08, 0x0d, 0x0f, 0x05, // - 0x0e, 0x03, 0x01, 0x0c, 0x00, 0x07, 0x0a, 0x02); - state->aes.xmm = _mm_aesdec_si128(state->aes.xmm, block); - state->sum.xmm = _mm_add_epi64(_mm_shuffle_epi8(state->sum.xmm, shuffle_mask), block); -} - -SZ_PUBLIC void sz_hash_state_init_ice(sz_hash_state_t *state, sz_u64_t seed) { - __m512i seed_vec = _mm512_set1_epi64(seed); - __m512i pi0 = _mm512_set_epi64( // - 0x13198a2e03707344ull, 0x243f6a8885a308d3ull, 0x13198a2e03707344ull, 0x243f6a8885a308d3ull, - 0x13198a2e03707344ull, 0x243f6a8885a308d3ull, 0x13198a2e03707344ull, 0x243f6a8885a308d3ull); - __m512i pi1 = _mm512_set_epi64( // - 0x082efa98ec4e6c89ull, 0xa4093822299f31d0ull, 0x082efa98ec4e6c89ull, 0xa4093822299f31d0ull, - 0x082efa98ec4e6c89ull, 0xa4093822299f31d0ull, 0x082efa98ec4e6c89ull, 0xa4093822299f31d0ull); - // XOR the user-supplied keys with the two "pi" constants - __m512i k1 = _mm512_xor_si512(seed_vec, pi0); - __m512i k2 = _mm512_xor_si512(seed_vec, pi1); - // Export the keys to the state - state->aes.zmm = k1; - state->sum.zmm = k2; - state->key.zmm = _mm512_xor_si512(pi0, pi1); 
- state->ins_length = 0; -} - SZ_INTERNAL void _sz_hash_state_update_ice(sz_hash_state_t *state, __m512i block) { // This shuffle mask is identical to "aHash": __m512i const shuffle_mask = _mm512_set_epi8( // @@ -730,24 +986,6 @@ SZ_INTERNAL void _sz_hash_state_update_ice(sz_hash_state_t *state, __m512i block state->sum.zmm = _mm512_add_epi64(_mm512_shuffle_epi8(state->sum.zmm, shuffle_mask), block); } -SZ_INTERNAL sz_u64_t _sz_hash_state_finalize_ice(sz_hash_state_t const *state) { - // Combine the sum and the AES block - __m128i mixed_registers0 = _mm_aesenc_si128(state->sum.xmms[0], state->aes.xmms[0]); - __m128i mixed_registers1 = _mm_aesenc_si128(state->sum.xmms[1], state->aes.xmms[1]); - __m128i mixed_registers2 = _mm_aesenc_si128(state->sum.xmms[2], state->aes.xmms[2]); - __m128i mixed_registers3 = _mm_aesenc_si128(state->sum.xmms[3], state->aes.xmms[3]); - // Combine the mixed registers - __m128i mixed_registers01 = _mm_aesenc_si128(mixed_registers0, mixed_registers1); - __m128i mixed_registers23 = _mm_aesenc_si128(mixed_registers2, mixed_registers3); - __m128i mixed_registers = _mm_aesenc_si128(mixed_registers01, mixed_registers23); - // Make sure the "key" mixes enough with the state, - // as with less than 2 rounds - SMHasher fails - __m128i mixed_within_register = _mm_aesdec_si128( // - _mm_aesdec_si128(mixed_registers, state->key.xmms[0]), mixed_registers); - // Extract the low 64 bits - return _mm_cvtsi128_si64(mixed_within_register); -} - SZ_PUBLIC sz_u64_t sz_hash_ice(sz_cptr_t start, sz_size_t length, sz_u64_t seed) { if (length <= 16) { @@ -768,7 +1006,7 @@ SZ_PUBLIC sz_u64_t sz_hash_ice(sz_cptr_t start, sz_size_t length, sz_u64_t seed) state.aes.xmm = _mm_add_epi64(state.aes.xmm, _mm_set_epi64x(0, length)); // Load the data and update the state sz_u128_vec_t data0_vec, data1_vec; - data0_vec.xmm = _mm_loadu_epi8(start); + data0_vec.xmm = _mm_lddqu_si128(start); data1_vec.xmm = _mm_maskz_loadu_epi8(_sz_u16_mask_until(length - 16), start + 16); _sz_hash_minimal_update_haswell(&state, data0_vec.xmm); _sz_hash_minimal_update_haswell(&state, data1_vec.xmm); @@ -781,8 +1019,8 @@ SZ_PUBLIC sz_u64_t sz_hash_ice(sz_cptr_t start, sz_size_t length, sz_u64_t seed) state.aes.xmm = _mm_add_epi64(state.aes.xmm, _mm_set_epi64x(0, length)); // Load the data and update the state sz_u128_vec_t data0_vec, data1_vec, data2_vec; - data0_vec.xmm = _mm_loadu_epi8(start); - data1_vec.xmm = _mm_loadu_epi8(start + 16); + data0_vec.xmm = _mm_lddqu_si128(start); + data1_vec.xmm = _mm_lddqu_si128(start + 16); data2_vec.xmm = _mm_maskz_loadu_epi8(_sz_u16_mask_until(length - 32), start + 32); _sz_hash_minimal_update_haswell(&state, data0_vec.xmm); _sz_hash_minimal_update_haswell(&state, data1_vec.xmm); @@ -796,9 +1034,9 @@ SZ_PUBLIC sz_u64_t sz_hash_ice(sz_cptr_t start, sz_size_t length, sz_u64_t seed) state.aes.xmm = _mm_add_epi64(state.aes.xmm, _mm_set_epi64x(0, length)); // Load the data and update the state sz_u128_vec_t data0_vec, data1_vec, data2_vec, data3_vec; - data0_vec.xmm = _mm_loadu_epi8(start); - data1_vec.xmm = _mm_loadu_epi8(start + 16); - data2_vec.xmm = _mm_loadu_epi8(start + 32); + data0_vec.xmm = _mm_lddqu_si128(start); + data1_vec.xmm = _mm_lddqu_si128(start + 16); + data2_vec.xmm = _mm_lddqu_si128(start + 32); data3_vec.xmm = _mm_maskz_loadu_epi8(_sz_u16_mask_until(length - 48), start + 48); _sz_hash_minimal_update_haswell(&state, data0_vec.xmm); _sz_hash_minimal_update_haswell(&state, data1_vec.xmm); @@ -810,7 +1048,7 @@ SZ_PUBLIC sz_u64_t sz_hash_ice(sz_cptr_t start, sz_size_t 
length, sz_u64_t seed) // Use a larger state to handle the main loop and add different offsets // to different lanes of the register sz_hash_state_t state; - sz_hash_state_init_ice(&state, seed); + sz_hash_state_init_skylake(&state, seed); state.aes.zmm = _mm512_add_epi64( // state.aes.zmm, // _mm512_set_epi64(0, length, 16, length, 32, length, 48, length)); @@ -824,7 +1062,7 @@ SZ_PUBLIC sz_u64_t sz_hash_ice(sz_cptr_t start, sz_size_t length, sz_u64_t seed) _sz_u64_mask_until(length - state.ins_length), start + state.ins_length); _sz_hash_state_update_ice(&state, state.ins.zmm); } - return _sz_hash_state_finalize_ice(&state); + return _sz_hash_state_finalize_haswell(&state); } } From 69dfa10cc9ca9e9f8c65876369223a82fb91d7f7 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 23 Feb 2025 20:20:24 +0000 Subject: [PATCH 126/751] Improve: `copy`/`move` on Haswell with interleaving --- include/stringzilla/memory.h | 91 ++++++++++++++++++++++++++++++++++-- 1 file changed, 87 insertions(+), 4 deletions(-) diff --git a/include/stringzilla/memory.h b/include/stringzilla/memory.h index cc5cc6d7..79cd840c 100644 --- a/include/stringzilla/memory.h +++ b/include/stringzilla/memory.h @@ -439,7 +439,44 @@ SZ_PUBLIC void sz_copy_haswell(sz_ptr_t target, sz_cptr_t source, sz_size_t leng // 1 MB x 2 blocks of L2 cache per core, and one shared L3 cache buffer. // For now, let's avoid the cases beyond the L2 size. int is_huge = length > 1ull * 1024ull * 1024ull; - if (length <= 32) { sz_copy_serial(target, source, length); } + if (length < 8) { + while (length--) *(target++) = *(source++); + } + // The next few sections are identical here and in the `sz_move_haswell` function. + // We can use 2x 64-bit interleaving loads for each string, and then compare them for equality. + // The same approach is used in GLibC and was suggest by Denis Yaroshevskiy. + // https://codebrowser.dev/glibc/glibc/sysdeps/x86_64/multiarch/memcmp-avx2-movbe.S.html#518 + // It shouldn't improve performance on microbenchmarks, but should be better in practice. + else if (length <= 16) { + sz_u64_t source_first_word = *(sz_u64_t const *)(source); + sz_u64_t source_second_word = *(sz_u64_t const *)(source + length - 8); + sz_u64_t *target_first_word_ptr = (sz_u64_t *)(target); + sz_u64_t *target_second_word_ptr = (sz_u64_t *)(target + length - 8); + *target_first_word_ptr = source_first_word; + *target_second_word_ptr = source_second_word; + } + // We can use 2x 128-bit interleaving loads for each string, and then compare them for equality. + else if (length <= 32) { + sz_u128_vec_t source_first_vec, source_second_vec; + sz_u128_vec_t *target_first_word_ptr, *target_second_word_ptr; + source_first_vec.xmm = _mm_lddqu_si128((__m128i const *)(source)); + source_second_vec.xmm = _mm_lddqu_si128((__m128i const *)(source + length - 16)); + target_first_word_ptr = (sz_u128_vec_t *)(target); + target_second_word_ptr = (sz_u128_vec_t *)(target + length - 16); + _mm_storeu_si128(&target_first_word_ptr->xmm, source_first_vec.xmm); + _mm_storeu_si128(&target_second_word_ptr->xmm, source_second_vec.xmm); + } + // We can use 2x 256-bit interleaving loads for each string, and then compare them for equality. 
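+    // All of these small branches rely on the same trick: one block is loaded from the front of
+    // the buffer and one block ends exactly at `length`, so the two may overlap. The overlapping
+    // bytes are simply written twice with identical contents - e.g. for `length == 11` the two
+    // 64-bit words above cover bytes [0, 8) and [3, 11) - which avoids any per-byte tail loop.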
+ else if (length <= 64) { + sz_u256_vec_t source_first_vec, source_second_vec; + sz_u256_vec_t *target_first_word_ptr, *target_second_word_ptr; + source_first_vec.ymm = _mm256_lddqu_si256((__m256i const *)(source)); + source_second_vec.ymm = _mm256_lddqu_si256((__m256i const *)(source + length - 32)); + target_first_word_ptr = (sz_u256_vec_t *)(target); + target_second_word_ptr = (sz_u256_vec_t *)(target + length - 32); + _mm256_storeu_si256(&target_first_word_ptr->ymm, source_first_vec.ymm); + _mm256_storeu_si256(&target_second_word_ptr->ymm, source_second_vec.ymm); + } // When dealing with larger arrays, the optimization is not as simple as with the `sz_fill_haswell` function, // as both buffers may be unaligned. If we are lucky and the requested operation is some huge page transfer, // we can use aligned loads and stores, and the performance will be great. @@ -471,7 +508,7 @@ SZ_PUBLIC void sz_copy_haswell(sz_ptr_t target, sz_cptr_t source, sz_size_t leng for (; body_length >= 32; target += 32, source += 32, body_length -= 32) _mm256_store_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source)); } - // When the biffer is huge, we can traverse it in 2 directions. + // When the buffer is huge, we can traverse it in 2 directions. else { for (; body_length >= 64; target += 32, source += 32, body_length -= 64) { _mm256_store_si256((__m256i *)(target), _mm256_lddqu_si256((__m256i const *)(source))); @@ -494,13 +531,59 @@ SZ_PUBLIC void sz_copy_haswell(sz_ptr_t target, sz_cptr_t source, sz_size_t leng } SZ_PUBLIC void sz_move_haswell(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { - if (target < source || target >= source + length) { + + if (length < 8) { + if (target < source) + while (length--) *(target++) = *(source++); + else { + // Jump to the end and walk backwards: + target += length, source += length; + while (length--) *(--target) = *(--source); + } + } + // The next few sections are identical here and in the `sz_copy_haswell` function. + // We can use 2x 64-bit interleaving loads for each string, and then compare them for equality. + // The same approach is used in GLibC and was suggest by Denis Yaroshevskiy. + // https://codebrowser.dev/glibc/glibc/sysdeps/x86_64/multiarch/memcmp-avx2-movbe.S.html#518 + // It shouldn't improve performance on microbenchmarks, but should be better in practice. + else if (length <= 16) { + sz_u64_t source_first_word = *(sz_u64_t const *)(source); + sz_u64_t source_second_word = *(sz_u64_t const *)(source + length - 8); + sz_u64_t *target_first_word_ptr = (sz_u64_t *)(target); + sz_u64_t *target_second_word_ptr = (sz_u64_t *)(target + length - 8); + *target_first_word_ptr = source_first_word; + *target_second_word_ptr = source_second_word; + } + // We can use 2x 128-bit interleaving loads for each string, and then compare them for equality. + else if (length <= 32) { + sz_u128_vec_t source_first_vec, source_second_vec; + sz_u128_vec_t *target_first_word_ptr, *target_second_word_ptr; + source_first_vec.xmm = _mm_lddqu_si128((__m128i const *)(source)); + source_second_vec.xmm = _mm_lddqu_si128((__m128i const *)(source + length - 16)); + target_first_word_ptr = (sz_u128_vec_t *)(target); + target_second_word_ptr = (sz_u128_vec_t *)(target + length - 16); + _mm_storeu_si128(&target_first_word_ptr->xmm, source_first_vec.xmm); + _mm_storeu_si128(&target_second_word_ptr->xmm, source_second_vec.xmm); + } + // We can use 2x 256-bit interleaving loads for each string, and then compare them for equality. 
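+    // These small branches are shared verbatim with `sz_copy_haswell` and need no direction
+    // check here: both loads complete before either store is issued, so the result stays
+    // correct even when `source` and `target` overlap in either direction.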
+ else if (length <= 64) { + sz_u256_vec_t source_first_vec, source_second_vec; + sz_u256_vec_t *target_first_word_ptr, *target_second_word_ptr; + source_first_vec.ymm = _mm256_lddqu_si256((__m256i const *)(source)); + source_second_vec.ymm = _mm256_lddqu_si256((__m256i const *)(source + length - 32)); + target_first_word_ptr = (sz_u256_vec_t *)(target); + target_second_word_ptr = (sz_u256_vec_t *)(target + length - 32); + _mm256_storeu_si256(&target_first_word_ptr->ymm, source_first_vec.ymm); + _mm256_storeu_si256(&target_second_word_ptr->ymm, source_second_vec.ymm); + } + // When dealing with larger arrays, we keep things simple: + else if (target < source || target >= source + length) { for (; length >= 32; target += 32, source += 32, length -= 32) _mm256_storeu_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source)); while (length--) *(target++) = *(source++); } else { - // Jump to the end and walk backwards. + // Jump to the end and walk backwards: for (target += length, source += length; length >= 32; length -= 32) _mm256_storeu_si256((__m256i *)(target -= 32), _mm256_lddqu_si256((__m256i const *)(source -= 32))); while (length--) *(--target) = *(--source); From 268af531e29a0077272c3bf023944c2c92ec4ccc Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Tue, 25 Feb 2025 23:39:26 +0000 Subject: [PATCH 127/751] Fix: Passing new hashing tests --- include/stringzilla/hash.h | 781 ++++++++++++++++++++++++++++--------- 1 file changed, 605 insertions(+), 176 deletions(-) diff --git a/include/stringzilla/hash.h b/include/stringzilla/hash.h index 539f016a..105ffabe 100644 --- a/include/stringzilla/hash.h +++ b/include/stringzilla/hash.h @@ -36,7 +36,7 @@ * are combined with "shuffle & add" instructions to provide a high level of entropy in the output. That operation * is practically free, as many modern CPUs will dispatch them on different ports. On x86, for example: * - * - `VAESDEC` (ZMM, ZMM, ZMM)`: + * - `VAESENC` (ZMM, ZMM, ZMM)`: * - on Intel Ice Lake: 5 cycles on port 0. * - On AMD Zen4: 4 cycles on ports 0 or 1. * - `VPSHUFB_Z (ZMM, K, ZMM, ZMM)` @@ -46,10 +46,19 @@ * - on Intel Ice Lake: 1 cycle on ports 0 or 5. * - On AMD Zen4: 1 cycle on ports 0, 1, 2, 3. * - * Unlike "aHash", on long inputs, we use a procedure that is more vector-friendly on modern servers. + * Unlike "aHash", the length is not mixed into "AES" block at start to allow incremental construction. + * Unlike "aHash", on long inputs, we use a heavier procedure that is more vector-friendly on modern servers. * Unlike "aHash", we don't load interleaved memory regions, making vectorized variant more similar to sequential. - * On platforms like Skylake-X or newer, we also benefit from masked loads. + * Unlike "aHash", on platforms like Intel Skylake-X or AWS Graviton 3, we use masked loads. + * Unlike "aHash", in final folding procedure, we use the same `VAESENC` instead of `VAESDEC`, which + * still provides the same level of mixing, but allows us to have a lighter serial fallback implementation. 
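+ *
+ * The streaming interface buffers input in 64-byte blocks, so hashing a string in chunks is
+ * designed to produce the same value as hashing it in one call. A minimal usage sketch, assuming
+ * the dispatched `sz_hash`, `sz_hash_state_init`, `sz_hash_state_stream`, and `sz_hash_state_fold`
+ * entry points that mirror the per-ISA variants declared in this header:
+ *
+ * @code{.c}
+ *     sz_u64_t one_shot = sz_hash("hello world", 11, 42);
+ *     sz_hash_state_t state;
+ *     sz_hash_state_init(&state, 42);
+ *     sz_hash_state_stream(&state, "hello ", 6);
+ *     sz_hash_state_stream(&state, "world", 5);
+ *     sz_u64_t streamed = sz_hash_state_fold(&state); // Expected to match `one_shot`
+ * @endcode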
* + * @see Reini Urban's more active fork of SMHasher by Austin Appleby: https://github.com/rurban/smhasher + * @see The serial AES routines are based on Morten Jensen's "tiny-AES-c": https://github.com/kokke/tiny-AES-c + * @see The "xxHash" C implementation by Yann Collet: https://github.com/Cyan4973/xxHash + * @see The "aHash" Rust implementation by Tom Kaitchuck: https://github.com/tkaitchuck/aHash + * @see "Emulating x86 AES Intrinsics on ARMv8-A" by Michael Brase: + * https://blog.michaelbrase.com/2018/05/08/emulating-x86-aes-intrinsics-on-armv8-a/ */ #ifndef STRINGZILLA_HASH_H_ #define STRINGZILLA_HASH_H_ @@ -153,9 +162,8 @@ SZ_DYNAMIC void sz_generate(sz_ptr_t text, sz_size_t length, sz_u64_t nonce); typedef struct sz_hash_state_t { sz_u512_vec_t aes; sz_u512_vec_t sum; - sz_u512_vec_t key; - sz_u512_vec_t ins; + sz_u128_vec_t key; sz_size_t ins_length; } sz_hash_state_t; @@ -282,6 +290,24 @@ SZ_PUBLIC sz_u64_t sz_hash_state_fold_neon(sz_hash_state_t const *state); #pragma endregion // Core API +#pragma region Helper Methods + +/** + * @brief Compares the state of two running hashes. + * @note The current content of the `ins` buffer and its length is ignored. + */ +SZ_PUBLIC sz_bool_t sz_hash_state_equal(sz_hash_state_t const *lhs, sz_hash_state_t const *rhs) { + return lhs->aes.u64s[0] == rhs->aes.u64s[0] && lhs->aes.u64s[1] == rhs->aes.u64s[1] && + lhs->aes.u64s[2] == rhs->aes.u64s[2] && lhs->aes.u64s[3] == rhs->aes.u64s[3] && + lhs->sum.u64s[0] == rhs->sum.u64s[0] && lhs->sum.u64s[1] == rhs->sum.u64s[1] && + lhs->sum.u64s[2] == rhs->sum.u64s[2] && lhs->sum.u64s[3] == rhs->sum.u64s[3] && + lhs->key.u64s[0] == rhs->key.u64s[0] && lhs->key.u64s[1] == rhs->key.u64s[1] + ? sz_true_k + : sz_false_k; +} + +#pragma endregion // Helper Methods + #pragma region Serial Implementation SZ_PUBLIC sz_u64_t sz_bytesum_serial(sz_cptr_t text, sz_size_t length) { @@ -292,24 +318,392 @@ SZ_PUBLIC sz_u64_t sz_bytesum_serial(sz_cptr_t text, sz_size_t length) { return bytesum; } +/** + * @brief Emulates the behaviour of `_mm_aesenc_si128` for a single round. + * This function is used as a fallback when the hardware-accelerated version is not available. + * @return Result of `MixColumns(SubBytes(ShiftRows(state))) ^ round_key`. 
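+ * @note The `_sz_gf2_double` helper used below is the AES "xtime" primitive: multiplication by 2
+ * in GF(2^8) modulo the x^8 + x^4 + x^3 + x + 1 polynomial, i.e. a left shift followed by a
+ * conditional XOR with 0x1B when the top bit spills over.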
+ * @see Based on Jean-Philippe Aumasson's reference implementation: https://github.com/veorq/aesenc-noNI + */ +SZ_INTERNAL sz_u128_vec_t _sz_emulate_aesenc_si128_serial(sz_u128_vec_t state_vec, sz_u128_vec_t round_key_vec) { + static sz_u8_t const sbox[256] = { + // 0 1 2 3 4 5 6 7 8 9 A B C D E F + 0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76, // + 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, // + 0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15, // + 0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75, // + 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, // + 0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf, // + 0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8, // + 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, // + 0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73, // + 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb, // + 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, // + 0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08, // + 0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a, // + 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, // + 0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf, // + 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16}; + + // Combine `ShiftRows` and `SubBytes` + sz_u8_t state_2d[4][4]; + for (int i = 0; i < 16; ++i) state_2d[((i / 4) + 4 - (i % 4)) % 4][i % 4] = sbox[state_vec.u8s[i]]; +#define _sz_gf2_double(x) (((x) << 1) ^ ((((x) >> 7) & 1) * 0x1b)) + // Perform `MixColumns` using GF2 multiplication by 2 + for (int i = 0; i < 4; ++i) { + sz_u8_t t = state_2d[i][0]; + sz_u8_t u = state_2d[i][0] ^ state_2d[i][1] ^ state_2d[i][2] ^ state_2d[i][3]; + state_2d[i][0] ^= u ^ _sz_gf2_double(state_2d[i][0] ^ state_2d[i][1]); + state_2d[i][1] ^= u ^ _sz_gf2_double(state_2d[i][1] ^ state_2d[i][2]); + state_2d[i][2] ^= u ^ _sz_gf2_double(state_2d[i][2] ^ state_2d[i][3]); + state_2d[i][3] ^= u ^ _sz_gf2_double(state_2d[i][3] ^ t); + } +#undef _sz_gf2_double + // Export `XOR`-ing with the round key + sz_u128_vec_t result; + for (int i = 0; i < 16; ++i) result.u8s[i] = state_2d[i / 4][i % 4] ^ round_key_vec.u8s[i]; + return result; +} + +SZ_INTERNAL sz_u128_vec_t _sz_emulate_shuffle_epi8_serial(sz_u128_vec_t state_vec, sz_u8_t const order[16]) { + sz_u128_vec_t result; + for (int i = 0; i < 16; ++i) result.u8s[i] = state_vec.u8s[order[i]]; + return result; +} + +/** + * @brief Provides 1024 bits worth of precomputed Pi constants for the hash. + * @return Pointer aligned to 64 bytes on SIMD-capable platforms. + * + * Bailey-Borwein-Plouffe @b (BBP) formula is used to compute the hexadecimal digits of Pi. 
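+ * The series itself is:
+ *
+ *     pi = sum_{k >= 0} 16^(-k) * ( 4/(8k+1) - 2/(8k+4) - 1/(8k+5) - 1/(8k+6) )
+ *
+ * and the bracketed term collapses into a single fraction, (120k^2 + 151k + 47) /
+ * (512k^4 + 1024k^3 + 712k^2 + 194k + 15), which is exactly the `xn / xd` ratio used in the
+ * snippet below.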
+ * It can be easily implemented in just 10 lines of Python and for 1024 bits requires 256 digits: + * + * @code{.py} + * def pi(digits: int) -> str: + * n, d = 0, 1 + * HEX = "0123456789ABCDEF" + * result = ["3."] + * for i in range(digits): + * xn = 120 * i**2 + 151 * i + 47 + * xd = 512 * i**4 + 1024 * i**3 + 712 * i**2 + 194 * i + 15 + * n = ((16 * n * xd) + (xn * d)) % (d * xd) + * d *= xd + * result.append(HEX[(16 * n) // d]) + * return "".join(result) + * @endcode + * + * For `pi(16)` the result is `3.243F6A8885A308D3` and you can find the digits after the dot in + * the first element of output array. + * + * @see Bailey-Borwein-Plouffe @b (BBP) formula explanation by Mosè Giordano: + * https://giordano.github.io/blog/2017-11-21-hexadecimal-pi/ + * + */ +SZ_INTERNAL sz_u64_t const *_sz_hash_pi_constants(void) { + static _SZ_ALIGN64 sz_u64_t const pi[16] = { + 0x243F6A8885A308D3ull, 0x13198A2E03707344ull, 0xA4093822299F31D0ull, 0x082EFA98EC4E6C89ull, + 0x452821E638D01377ull, 0xBE5466CF34E90C6Cull, 0xC0AC29B7C97C50DDull, 0x3F84D5B5B5470917ull, + 0x9216D5D98979FB1Bull, 0xD1310BA698DFB5ACull, 0x2FFD72DBD01ADFB7ull, 0xB8E1AFED6A267E96ull, + 0xBA7C9045F12C7F99ull, 0x24A19947B3916CF7ull, 0x0801F2E2858EFC16ull, 0x636920D871574E69ull, + }; + return &pi[0]; +} + +/** + * @brief Provides a shuffle mask for the additive part, identical to "aHash" in a single lane. + * @return Pointer aligned to 64 bytes on SIMD-capable platforms. + */ +SZ_INTERNAL sz_u8_t const *_sz_hash_u8x16x4_shuffle(void) { + static _SZ_ALIGN64 sz_u8_t const shuffle[64] = { + 0x04, 0x0b, 0x09, 0x06, 0x08, 0x0d, 0x0f, 0x05, // + 0x0e, 0x03, 0x01, 0x0c, 0x00, 0x07, 0x0a, 0x02, // + 0x04, 0x0b, 0x09, 0x06, 0x08, 0x0d, 0x0f, 0x05, // + 0x0e, 0x03, 0x01, 0x0c, 0x00, 0x07, 0x0a, 0x02, // + 0x04, 0x0b, 0x09, 0x06, 0x08, 0x0d, 0x0f, 0x05, // + 0x0e, 0x03, 0x01, 0x0c, 0x00, 0x07, 0x0a, 0x02, // + 0x04, 0x0b, 0x09, 0x06, 0x08, 0x0d, 0x0f, 0x05, // + 0x0e, 0x03, 0x01, 0x0c, 0x00, 0x07, 0x0a, 0x02 // + }; + return &shuffle[0]; +} + +SZ_INTERNAL void _sz_hash_minimal_init_serial(_sz_hash_minimal_t *state, sz_u64_t seed) { + + // The key is made from the seed and half of it will be mixed with the length in the end + state->key.u64s[1] = seed; + state->key.u64s[0] = seed; + + // XOR the user-supplied keys with the two "pi" constants + sz_u64_t const *pi = _sz_hash_pi_constants(); + state->aes.u64s[0] = seed ^ pi[0]; + state->aes.u64s[1] = seed ^ pi[1]; + state->sum.u64s[0] = seed ^ pi[8]; + state->sum.u64s[1] = seed ^ pi[9]; +} + +SZ_INTERNAL void _sz_hash_minimal_update_serial(_sz_hash_minimal_t *state, sz_u128_vec_t block) { + sz_u8_t const *shuffle = _sz_hash_u8x16x4_shuffle(); + state->aes = _sz_emulate_aesenc_si128_serial(state->aes, block); + state->sum = _sz_emulate_shuffle_epi8_serial(state->sum, shuffle); + state->sum.u64s[0] += block.u64s[0], state->sum.u64s[1] += block.u64s[1]; +} + +SZ_INTERNAL sz_u64_t _sz_hash_minimal_finalize_serial(_sz_hash_minimal_t const *state, sz_size_t length) { + // Mix the length into the key + sz_u128_vec_t key_with_length = state->key; + key_with_length.u64s[0] += length; + // Combine the "sum" and the "AES" blocks + sz_u128_vec_t mixed_registers = _sz_emulate_aesenc_si128_serial(state->sum, state->aes); + // Make sure the "key" mixes enough with the state, + // as with less than 2 rounds - SMHasher fails + sz_u128_vec_t mixed_within_register = _sz_emulate_aesenc_si128_serial( + _sz_emulate_aesenc_si128_serial(mixed_registers, key_with_length), mixed_registers); + // Extract the low 64 bits + 
return mixed_within_register.u64s[0]; +} + +SZ_INTERNAL void _sz_hash_shift_in_register_serial(sz_u128_vec_t *vec, int shift_bytes) { + // One of the ridiculous things about x86, the `bsrli` instruction requires its operand to be an immediate. + // On GCC and Clang, we could use the provided `__int128` type, but MSVC doesn't support it. + // So we need to emulate it with 2x 64-bit shifts. + if (shift_bytes >= 8) { + vec->u64s[0] = (vec->u64s[1] >> (shift_bytes - 8) * 8); + vec->u64s[1] = (0); + } + else if (shift_bytes) { //! If `shift_bytes == 0`, the shift would cause UB. + vec->u64s[0] = (vec->u64s[0] >> shift_bytes * 8) | (vec->u64s[1] << (8 - shift_bytes) * 8); + vec->u64s[1] = (vec->u64s[1] >> shift_bytes * 8); + } +} + +SZ_PUBLIC void sz_hash_state_init_serial(sz_hash_state_t *state, sz_u64_t seed) { + + // The key is made from the seed and half of it will be mixed with the length in the end + state->key.u64s[0] = seed; + state->key.u64s[1] = seed; + + // XOR the user-supplied keys with the two "pi" constants + sz_u64_t const *pi = _sz_hash_pi_constants(); + for (int i = 0; i < 8; ++i) state->aes.u64s[i] = seed ^ pi[i]; + for (int i = 0; i < 8; ++i) state->sum.u64s[i] = seed ^ pi[i + 8]; + + // The inputs are zeroed out at the beginning + for (int i = 0; i < 8; ++i) state->ins.u64s[i] = 0; + state->ins_length = 0; +} + +SZ_INTERNAL void _sz_hash_state_update_serial(sz_hash_state_t *state) { + sz_u8_t const *shuffle = _sz_hash_u8x16x4_shuffle(); + + // To reuse the snippets above, let's cast to our familiar 128-bit vectors + sz_u128_vec_t *aes_vecs = (sz_u128_vec_t *)&state->aes.u64s[0]; + sz_u128_vec_t *sum_vecs = (sz_u128_vec_t *)&state->sum.u64s[0]; + sz_u128_vec_t *ins_vecs = (sz_u128_vec_t *)&state->ins.u64s[0]; + + // First 128-bit block + aes_vecs[0] = _sz_emulate_aesenc_si128_serial(aes_vecs[0], ins_vecs[0]); + sum_vecs[0] = _sz_emulate_shuffle_epi8_serial(sum_vecs[0], shuffle); + sum_vecs[0].u64s[0] += ins_vecs[0].u64s[0], sum_vecs[0].u64s[1] += ins_vecs[0].u64s[1]; + + // Second 128-bit block + aes_vecs[1] = _sz_emulate_aesenc_si128_serial(aes_vecs[1], ins_vecs[1]); + sum_vecs[1] = _sz_emulate_shuffle_epi8_serial(sum_vecs[1], shuffle); + sum_vecs[1].u64s[0] += ins_vecs[1].u64s[0], sum_vecs[1].u64s[1] += ins_vecs[1].u64s[1]; + + // Third 128-bit block + aes_vecs[2] = _sz_emulate_aesenc_si128_serial(aes_vecs[2], ins_vecs[2]); + sum_vecs[2] = _sz_emulate_shuffle_epi8_serial(sum_vecs[2], shuffle); + sum_vecs[2].u64s[0] += ins_vecs[2].u64s[0], sum_vecs[2].u64s[1] += ins_vecs[2].u64s[1]; + + // Fourth 128-bit block + aes_vecs[3] = _sz_emulate_aesenc_si128_serial(aes_vecs[3], ins_vecs[3]); + sum_vecs[3] = _sz_emulate_shuffle_epi8_serial(sum_vecs[3], shuffle); + sum_vecs[3].u64s[0] += ins_vecs[3].u64s[0], sum_vecs[3].u64s[1] += ins_vecs[3].u64s[1]; +} + +SZ_INTERNAL sz_u64_t _sz_hash_state_finalize_serial(sz_hash_state_t const *state) { + + // Mix the length into the key + sz_u128_vec_t key_with_length = state->key; + key_with_length.u64s[0] += state->ins_length; + + // To reuse the snippets above, let's cast to our familiar 128-bit vectors + sz_u128_vec_t *aes_vecs = (sz_u128_vec_t *)&state->aes.u64s[0]; + sz_u128_vec_t *sum_vecs = (sz_u128_vec_t *)&state->sum.u64s[0]; + + // Combine the "sum" and the "AES" blocks + sz_u128_vec_t mixed_registers0 = _sz_emulate_aesenc_si128_serial(sum_vecs[0], aes_vecs[0]); + sz_u128_vec_t mixed_registers1 = _sz_emulate_aesenc_si128_serial(sum_vecs[1], aes_vecs[1]); + sz_u128_vec_t mixed_registers2 = _sz_emulate_aesenc_si128_serial(sum_vecs[2], 
aes_vecs[2]); + sz_u128_vec_t mixed_registers3 = _sz_emulate_aesenc_si128_serial(sum_vecs[3], aes_vecs[3]); + + // Combine the mixed registers + sz_u128_vec_t mixed_registers01 = _sz_emulate_aesenc_si128_serial(mixed_registers0, mixed_registers1); + sz_u128_vec_t mixed_registers23 = _sz_emulate_aesenc_si128_serial(mixed_registers2, mixed_registers3); + sz_u128_vec_t mixed_registers = _sz_emulate_aesenc_si128_serial(mixed_registers01, mixed_registers23); + + // Make sure the "key" mixes enough with the state, + // as with less than 2 rounds - SMHasher fails + sz_u128_vec_t mixed_within_register = _sz_emulate_aesenc_si128_serial( + _sz_emulate_aesenc_si128_serial(mixed_registers, key_with_length), mixed_registers); + + // Extract the low 64 bits + return mixed_within_register.u64s[0]; +} + SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t start, sz_size_t length, sz_u64_t seed) { - sz_unused(start && length && seed); - return 0; + if (length <= 16) { + // Initialize the AES block with a given seed + _sz_hash_minimal_t state; + _sz_hash_minimal_init_serial(&state, seed); + // Load the data and update the state + sz_u128_vec_t data_vec; + data_vec.u64s[0] = data_vec.u64s[1] = 0; + for (sz_size_t i = 0; i < length; ++i) data_vec.u8s[i] = start[i]; + _sz_hash_minimal_update_serial(&state, data_vec); + return _sz_hash_minimal_finalize_serial(&state, length); + } + else if (length <= 32) { + // Initialize the AES block with a given seed + _sz_hash_minimal_t state; + _sz_hash_minimal_init_serial(&state, seed); + // Load the data and update the state + sz_u128_vec_t data0_vec, data1_vec; + data0_vec.u64s[0] = *(sz_u64_t const *)(start); + data0_vec.u64s[1] = *(sz_u64_t const *)(start + 8); + data1_vec.u64s[0] = *(sz_u64_t const *)(start + length - 16); + data1_vec.u64s[1] = *(sz_u64_t const *)(start + length - 8); + // Let's shift the data within the register to de-interleave the bytes. + _sz_hash_shift_in_register_serial(&data1_vec, 32 - length); + _sz_hash_minimal_update_serial(&state, data0_vec); + _sz_hash_minimal_update_serial(&state, data1_vec); + return _sz_hash_minimal_finalize_serial(&state, length); + } + else if (length <= 48) { + // Initialize the AES block with a given seed + _sz_hash_minimal_t state; + _sz_hash_minimal_init_serial(&state, seed); + // Load the data and update the state + sz_u128_vec_t data0_vec, data1_vec, data2_vec; + data0_vec.u64s[0] = *(sz_u64_t const *)(start); + data0_vec.u64s[1] = *(sz_u64_t const *)(start + 8); + data1_vec.u64s[0] = *(sz_u64_t const *)(start + 16); + data1_vec.u64s[1] = *(sz_u64_t const *)(start + 24); + data2_vec.u64s[0] = *(sz_u64_t const *)(start + length - 16); + data2_vec.u64s[1] = *(sz_u64_t const *)(start + length - 8); + // Let's shift the data within the register to de-interleave the bytes. 
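+        // The trailing 16-byte block starts at `length - 16`, so its low `48 - length` bytes
+        // repeat data already folded in from the previous block; the shift below discards exactly
+        // those bytes, leaving only the not-yet-hashed tail, zero-padded at the top.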
+ _sz_hash_shift_in_register_serial(&data2_vec, 48 - length); + _sz_hash_minimal_update_serial(&state, data0_vec); + _sz_hash_minimal_update_serial(&state, data1_vec); + _sz_hash_minimal_update_serial(&state, data2_vec); + return _sz_hash_minimal_finalize_serial(&state, length); + } + else if (length <= 64) { + // Initialize the AES block with a given seed + _sz_hash_minimal_t state; + _sz_hash_minimal_init_serial(&state, seed); + // Load the data and update the state + sz_u128_vec_t data0_vec, data1_vec, data2_vec, data3_vec; + data0_vec.u64s[0] = *(sz_u64_t const *)(start); + data0_vec.u64s[1] = *(sz_u64_t const *)(start + 8); + data1_vec.u64s[0] = *(sz_u64_t const *)(start + 16); + data1_vec.u64s[1] = *(sz_u64_t const *)(start + 24); + data2_vec.u64s[0] = *(sz_u64_t const *)(start + 32); + data2_vec.u64s[1] = *(sz_u64_t const *)(start + 40); + data3_vec.u64s[0] = *(sz_u64_t const *)(start + length - 16); + data3_vec.u64s[1] = *(sz_u64_t const *)(start + length - 8); + // Let's shift the data within the register to de-interleave the bytes. + _sz_hash_shift_in_register_serial(&data3_vec, 64 - length); + _sz_hash_minimal_update_serial(&state, data0_vec); + _sz_hash_minimal_update_serial(&state, data1_vec); + _sz_hash_minimal_update_serial(&state, data2_vec); + _sz_hash_minimal_update_serial(&state, data3_vec); + return _sz_hash_minimal_finalize_serial(&state, length); + } + else { + // Use a larger state to handle the main loop and add different offsets + // to different lanes of the register + sz_hash_state_t state; + sz_hash_state_init_serial(&state, seed); + + for (; state.ins_length + 64 <= length; state.ins_length += 64) { + state.ins.u64s[0] = *(sz_u64_t const *)(start + state.ins_length); + state.ins.u64s[1] = *(sz_u64_t const *)(start + state.ins_length + 8); + state.ins.u64s[2] = *(sz_u64_t const *)(start + state.ins_length + 16); + state.ins.u64s[3] = *(sz_u64_t const *)(start + state.ins_length + 24); + state.ins.u64s[4] = *(sz_u64_t const *)(start + state.ins_length + 32); + state.ins.u64s[5] = *(sz_u64_t const *)(start + state.ins_length + 40); + state.ins.u64s[6] = *(sz_u64_t const *)(start + state.ins_length + 48); + state.ins.u64s[7] = *(sz_u64_t const *)(start + state.ins_length + 56); + _sz_hash_state_update_serial(&state); + } + if (state.ins_length < length) { + for (sz_size_t i = 0; i != 8; ++i) state.ins.u64s[i] = 0; + for (sz_size_t i = 0; state.ins_length < length; ++i, ++state.ins_length) + state.ins.u8s[i] = start[state.ins_length]; + _sz_hash_state_update_serial(&state); + state.ins_length = length; + } + return _sz_hash_state_finalize_serial(&state); + } } -SZ_PUBLIC void sz_generate_serial(sz_ptr_t text, sz_size_t length, sz_u64_t nonce) { - sz_unused(text && length && nonce); +SZ_PUBLIC void sz_hash_state_stream_serial(sz_hash_state_t *state_ptr, sz_cptr_t text, sz_size_t length) { + while (length) { + sz_size_t progress_in_block = state_ptr->ins_length % 64; + sz_size_t to_copy = sz_min_of_two(length, 64 - progress_in_block); + int const will_fill_block = progress_in_block + to_copy == 64; + // Update the metadata before we modify the `to_copy` variable + state_ptr->ins_length += to_copy; + length -= to_copy; + // Append to the internal buffer until it's full + while (to_copy--) state_ptr->ins.u8s[progress_in_block++] = *text++; + // If we've reached the end of the buffer, update the state + if (will_fill_block) { + _sz_hash_state_update_serial(state_ptr); + // Reset to zeros now, so we don't have to overwrite an immutable buffer in the folding state + for 
(int i = 0; i < 8; ++i) state_ptr->ins.u64s[i] = 0; + } + } } -SZ_PUBLIC void sz_hash_state_init_serial(sz_hash_state_t *state, sz_u64_t seed) { sz_unused(state && seed); } +SZ_PUBLIC sz_u64_t sz_hash_state_fold_serial(sz_hash_state_t const *state_ptr) { + sz_size_t length = state_ptr->ins_length; + if (length >= 64) return _sz_hash_state_finalize_serial(state_ptr); + + // Switch back to a smaller "minimal" state for small inputs + _sz_hash_minimal_t state; + state.key = state_ptr->key; + state.aes = *(sz_u128_vec_t const *)&state_ptr->aes.u64s[0]; + state.sum = *(sz_u128_vec_t const *)&state_ptr->sum.u64s[0]; -SZ_PUBLIC void sz_hash_state_stream_serial(sz_hash_state_t *state, sz_cptr_t text, sz_size_t length) { - sz_unused(state && text && length); + // The logic is different depending on the length of the input + sz_u128_vec_t const *ins_vecs = (sz_u128_vec_t const *)&state_ptr->ins.u64s[0]; + if (length <= 16) { + _sz_hash_minimal_update_serial(&state, ins_vecs[0]); + return _sz_hash_minimal_finalize_serial(&state, length); + } + else if (length <= 32) { + _sz_hash_minimal_update_serial(&state, ins_vecs[0]); + _sz_hash_minimal_update_serial(&state, ins_vecs[1]); + return _sz_hash_minimal_finalize_serial(&state, length); + } + else if (length <= 48) { + _sz_hash_minimal_update_serial(&state, ins_vecs[0]); + _sz_hash_minimal_update_serial(&state, ins_vecs[1]); + _sz_hash_minimal_update_serial(&state, ins_vecs[2]); + return _sz_hash_minimal_finalize_serial(&state, length); + } + else { + _sz_hash_minimal_update_serial(&state, ins_vecs[0]); + _sz_hash_minimal_update_serial(&state, ins_vecs[1]); + _sz_hash_minimal_update_serial(&state, ins_vecs[2]); + _sz_hash_minimal_update_serial(&state, ins_vecs[3]); + return _sz_hash_minimal_finalize_serial(&state, length); + } } -SZ_PUBLIC sz_u64_t sz_hash_state_fold_serial(sz_hash_state_t const *state) { - sz_unused(state); - return 0; +SZ_PUBLIC void sz_generate_serial(sz_ptr_t text, sz_size_t length, sz_u64_t nonce) { + sz_unused(text && length && nonce); } #pragma endregion // Serial Implementation @@ -321,7 +715,7 @@ SZ_PUBLIC sz_u64_t sz_hash_state_fold_serial(sz_hash_state_t const *state) { #if SZ_USE_HASWELL #pragma GCC push_options #pragma GCC target("avx2") -#pragma clang attribute push(__attribute__((target("avx3332"))), apply_to = function) +#pragma clang attribute push(__attribute__((target("avx2"))), apply_to = function) SZ_PUBLIC sz_u64_t sz_bytesum_haswell(sz_cptr_t text, sz_size_t length) { // The naive implementation of this function is very simple. 
@@ -408,82 +802,87 @@ SZ_PUBLIC sz_u64_t sz_bytesum_haswell(sz_cptr_t text, sz_size_t length) { } SZ_INTERNAL void _sz_hash_minimal_init_haswell(_sz_hash_minimal_t *state, sz_u64_t seed) { + sz_u64_t const *pi = _sz_hash_pi_constants(); + __m128i const pi0 = _mm_load_si128((__m128i const *)(pi)); + __m128i const pi1 = _mm_load_si128((__m128i const *)(pi + 8)); + + // The key is made from the seed and half of it will be mixed with the length in the end __m128i seed_vec = _mm_set1_epi64x(seed); - __m128i pi0 = _mm_set_epi64x(0x13198a2e03707344ull, 0x243f6a8885a308d3ull); - __m128i pi1 = _mm_set_epi64x(0x082efa98ec4e6c89ull, 0xa4093822299f31d0ull); + state->key.xmm = seed_vec; + // XOR the user-supplied keys with the two "pi" constants __m128i k1 = _mm_xor_si128(seed_vec, pi0); __m128i k2 = _mm_xor_si128(seed_vec, pi1); - // Export the keys to the state + + // The first 128 bits of the "sum" and "AES" blocks are the same state->aes.xmm = k1; state->sum.xmm = k2; - state->key.xmm = _mm_xor_si128(pi0, pi1); } -SZ_INTERNAL sz_u64_t _sz_hash_minimal_finalize_haswell(_sz_hash_minimal_t const *state) { - // Combine the sum and the AES block +SZ_INTERNAL sz_u64_t _sz_hash_minimal_finalize_haswell(_sz_hash_minimal_t const *state, sz_size_t length) { + // Mix the length into the key + __m128i key_with_length = _mm_add_epi64(state->key.xmm, _mm_set_epi64x(0, length)); + // Combine the "sum" and the "AES" blocks __m128i mixed_registers = _mm_aesenc_si128(state->sum.xmm, state->aes.xmm); // Make sure the "key" mixes enough with the state, // as with less than 2 rounds - SMHasher fails __m128i mixed_within_register = - _mm_aesdec_si128(_mm_aesdec_si128(mixed_registers, state->key.xmm), mixed_registers); + _mm_aesenc_si128(_mm_aesenc_si128(mixed_registers, key_with_length), mixed_registers); // Extract the low 64 bits return _mm_cvtsi128_si64(mixed_within_register); } SZ_INTERNAL void _sz_hash_minimal_update_haswell(_sz_hash_minimal_t *state, __m128i block) { - // This shuffle mask is identical to "aHash": - __m128i const shuffle_mask = _mm_set_epi8( // - 0x04, 0x0b, 0x09, 0x06, 0x08, 0x0d, 0x0f, 0x05, // - 0x0e, 0x03, 0x01, 0x0c, 0x00, 0x07, 0x0a, 0x02); - state->aes.xmm = _mm_aesdec_si128(state->aes.xmm, block); + __m128i const shuffle_mask = _mm_load_si128((__m128i const *)_sz_hash_u8x16x4_shuffle()); + state->aes.xmm = _mm_aesenc_si128(state->aes.xmm, block); state->sum.xmm = _mm_add_epi64(_mm_shuffle_epi8(state->sum.xmm, shuffle_mask), block); } SZ_PUBLIC void sz_hash_state_init_haswell(sz_hash_state_t *state, sz_u64_t seed) { + // The key is made from the seed and half of it will be mixed with the length in the end __m128i seed_vec = _mm_set1_epi64x(seed); - __m128i pi0 = _mm_set_epi64x(0x13198a2e03707344ull, 0x243f6a8885a308d3ull); - __m128i pi1 = _mm_set_epi64x(0x082efa98ec4e6c89ull, 0xa4093822299f31d0ull); + state->key.xmm = seed_vec; + // XOR the user-supplied keys with the two "pi" constants - __m128i k1 = _mm_xor_si128(seed_vec, pi0); - __m128i k2 = _mm_xor_si128(seed_vec, pi1); - // Export the keys to the state - state->aes.xmms[0] = state->aes.xmms[1] = state->aes.xmms[2] = state->aes.xmms[3] = k1; - state->sum.xmms[0] = state->sum.xmms[1] = state->sum.xmms[2] = state->sum.xmms[3] = k2; - state->key.xmms[0] = state->key.xmms[1] = state->key.xmms[2] = state->key.xmms[3] = _mm_xor_si128(pi0, pi1); + sz_u64_t const *pi = _sz_hash_pi_constants(); + for (int i = 0; i < 4; ++i) + state->aes.xmms[i] = _mm_xor_si128(seed_vec, _mm_load_si128((__m128i const *)(pi + i * 2))); + for (int i = 0; i < 4; 
++i) + state->sum.xmms[i] = _mm_xor_si128(seed_vec, _mm_load_si128((__m128i const *)(pi + i * 2 + 8))); + + // The inputs are zeroed out at the beginning + state->ins.xmms[0] = state->ins.xmms[1] = state->ins.xmms[2] = state->ins.xmms[3] = _mm_setzero_si128(); state->ins_length = 0; } -SZ_INTERNAL void _sz_hash_state_update_haswell(sz_hash_state_t *state, __m128i block0, __m128i block1, __m128i block2, - __m128i block3) { - // This shuffle mask is identical to "aHash": - __m128i const shuffle_mask = _mm_set_epi8( // - 0x04, 0x0b, 0x09, 0x06, 0x08, 0x0d, 0x0f, 0x05, // - 0x0e, 0x03, 0x01, 0x0c, 0x00, 0x07, 0x0a, 0x02); - state->aes.xmms[0] = _mm_aesdec_si128(state->aes.xmms[0], block0); - state->sum.xmms[0] = _mm_add_epi64(_mm_shuffle_epi8(state->sum.xmms[0], shuffle_mask), block0); - state->aes.xmms[1] = _mm_aesdec_si128(state->aes.xmms[1], block1); - state->sum.xmms[1] = _mm_add_epi64(_mm_shuffle_epi8(state->sum.xmms[1], shuffle_mask), block1); - state->aes.xmms[2] = _mm_aesdec_si128(state->aes.xmms[2], block2); - state->sum.xmms[2] = _mm_add_epi64(_mm_shuffle_epi8(state->sum.xmms[2], shuffle_mask), block2); - state->aes.xmms[3] = _mm_aesdec_si128(state->aes.xmms[3], block3); - state->sum.xmms[3] = _mm_add_epi64(_mm_shuffle_epi8(state->sum.xmms[3], shuffle_mask), block3); +SZ_INTERNAL void _sz_hash_state_update_haswell(sz_hash_state_t *state) { + __m128i const shuffle_mask = _mm_load_si128((__m128i const *)_sz_hash_u8x16x4_shuffle()); + state->aes.xmms[0] = _mm_aesenc_si128(state->aes.xmms[0], state->ins.xmms[0]); + state->sum.xmms[0] = _mm_add_epi64(_mm_shuffle_epi8(state->sum.xmms[0], shuffle_mask), state->ins.xmms[0]); + state->aes.xmms[1] = _mm_aesenc_si128(state->aes.xmms[1], state->ins.xmms[1]); + state->sum.xmms[1] = _mm_add_epi64(_mm_shuffle_epi8(state->sum.xmms[1], shuffle_mask), state->ins.xmms[1]); + state->aes.xmms[2] = _mm_aesenc_si128(state->aes.xmms[2], state->ins.xmms[2]); + state->sum.xmms[2] = _mm_add_epi64(_mm_shuffle_epi8(state->sum.xmms[2], shuffle_mask), state->ins.xmms[2]); + state->aes.xmms[3] = _mm_aesenc_si128(state->aes.xmms[3], state->ins.xmms[3]); + state->sum.xmms[3] = _mm_add_epi64(_mm_shuffle_epi8(state->sum.xmms[3], shuffle_mask), state->ins.xmms[3]); } -SZ_INTERNAL sz_u64_t _sz_hash_state_finalize_haswell(sz_hash_state_t const *state) { - // Combine the sum and the AES block - __m128i mixed_registers0 = _mm_aesenc_si128(state->sum.xmms[0], state->aes.xmms[0]); - __m128i mixed_registers1 = _mm_aesenc_si128(state->sum.xmms[1], state->aes.xmms[1]); - __m128i mixed_registers2 = _mm_aesenc_si128(state->sum.xmms[2], state->aes.xmms[2]); - __m128i mixed_registers3 = _mm_aesenc_si128(state->sum.xmms[3], state->aes.xmms[3]); +SZ_INTERNAL sz_u64_t _sz_hash_state_finalize_haswell(sz_hash_state_t const *state_ptr) { + // Mix the length into the key + __m128i key_with_length = _mm_add_epi64(state_ptr->key.xmm, _mm_set_epi64x(0, state_ptr->ins_length)); + // Combine the "sum" and the "AES" blocks + __m128i mixed_registers0 = _mm_aesenc_si128(state_ptr->sum.xmms[0], state_ptr->aes.xmms[0]); + __m128i mixed_registers1 = _mm_aesenc_si128(state_ptr->sum.xmms[1], state_ptr->aes.xmms[1]); + __m128i mixed_registers2 = _mm_aesenc_si128(state_ptr->sum.xmms[2], state_ptr->aes.xmms[2]); + __m128i mixed_registers3 = _mm_aesenc_si128(state_ptr->sum.xmms[3], state_ptr->aes.xmms[3]); // Combine the mixed registers __m128i mixed_registers01 = _mm_aesenc_si128(mixed_registers0, mixed_registers1); __m128i mixed_registers23 = _mm_aesenc_si128(mixed_registers2, mixed_registers3); __m128i 
mixed_registers = _mm_aesenc_si128(mixed_registers01, mixed_registers23); // Make sure the "key" mixes enough with the state, // as with less than 2 rounds - SMHasher fails - __m128i mixed_within_register = _mm_aesdec_si128( // - _mm_aesdec_si128(mixed_registers, state->key.xmms[0]), mixed_registers); + __m128i mixed_within_register = + _mm_aesenc_si128(_mm_aesenc_si128(mixed_registers, key_with_length), mixed_registers); // Extract the low 64 bits return _mm_cvtsi128_si64(mixed_within_register); } @@ -491,86 +890,77 @@ SZ_INTERNAL sz_u64_t _sz_hash_state_finalize_haswell(sz_hash_state_t const *stat SZ_PUBLIC sz_u64_t sz_hash_haswell(sz_cptr_t start, sz_size_t length, sz_u64_t seed) { if (length <= 16) { - // Initialize the AES block with a given seed and update with the input length + // Initialize the AES block with a given seed _sz_hash_minimal_t state; _sz_hash_minimal_init_haswell(&state, seed); - state.aes.xmm = _mm_add_epi64(state.aes.xmm, _mm_set_epi64x(0, length)); // Load the data and update the state sz_u128_vec_t data_vec; data_vec.xmm = _mm_setzero_si128(); for (sz_size_t i = 0; i < length; ++i) data_vec.u8s[i] = start[i]; _sz_hash_minimal_update_haswell(&state, data_vec.xmm); - return _sz_hash_minimal_finalize_haswell(&state); + return _sz_hash_minimal_finalize_haswell(&state, length); } else if (length <= 32) { - // Initialize the AES block with a given seed and update with the input length + // Initialize the AES block with a given seed _sz_hash_minimal_t state; _sz_hash_minimal_init_haswell(&state, seed); - state.aes.xmm = _mm_add_epi64(state.aes.xmm, _mm_set_epi64x(0, length)); // Load the data and update the state sz_u128_vec_t data0_vec, data1_vec; - data0_vec.xmm = _mm_lddqu_si128(start); - data1_vec.xmm = _mm_lddqu_si128(start + length - 16); + data0_vec.xmm = _mm_lddqu_si128((__m128i const *)(start)); + data1_vec.xmm = _mm_lddqu_si128((__m128i const *)(start + length - 16)); // Let's shift the data within the register to de-interleave the bytes. - data1_vec.xmm = _mm_bsrli_si128(data1_vec.xmm, 32 - length); + _sz_hash_shift_in_register_serial(&data1_vec, 32 - length); _sz_hash_minimal_update_haswell(&state, data0_vec.xmm); _sz_hash_minimal_update_haswell(&state, data1_vec.xmm); - return _sz_hash_minimal_finalize_haswell(&state); + return _sz_hash_minimal_finalize_haswell(&state, length); } else if (length <= 48) { - // Initialize the AES block with a given seed and update with the input length + // Initialize the AES block with a given seed _sz_hash_minimal_t state; _sz_hash_minimal_init_haswell(&state, seed); - state.aes.xmm = _mm_add_epi64(state.aes.xmm, _mm_set_epi64x(0, length)); // Load the data and update the state sz_u128_vec_t data0_vec, data1_vec, data2_vec; - data0_vec.xmm = _mm_lddqu_si128(start); - data1_vec.xmm = _mm_lddqu_si128(start + 16); - data2_vec.xmm = _mm_lddqu_si128(start + length - 16); + data0_vec.xmm = _mm_lddqu_si128((__m128i const *)(start)); + data1_vec.xmm = _mm_lddqu_si128((__m128i const *)(start + 16)); + data2_vec.xmm = _mm_lddqu_si128((__m128i const *)(start + length - 16)); // Let's shift the data within the register to de-interleave the bytes. 
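+        // `_mm_bsrli_si128` takes its byte count as a compile-time immediate, so it cannot shift
+        // by the runtime value `48 - length`; the scalar helper below performs the same right
+        // shift with two 64-bit shifts instead.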
- data2_vec.xmm = _mm_bsrli_si128(data2_vec.xmm, 48 - length); + _sz_hash_shift_in_register_serial(&data2_vec, 48 - length); _sz_hash_minimal_update_haswell(&state, data0_vec.xmm); _sz_hash_minimal_update_haswell(&state, data1_vec.xmm); _sz_hash_minimal_update_haswell(&state, data2_vec.xmm); - return _sz_hash_minimal_finalize_haswell(&state); + return _sz_hash_minimal_finalize_haswell(&state, length); } else if (length <= 64) { - // Initialize the AES block with a given seed and update with the input length + // Initialize the AES block with a given seed _sz_hash_minimal_t state; _sz_hash_minimal_init_haswell(&state, seed); - state.aes.xmm = _mm_add_epi64(state.aes.xmm, _mm_set_epi64x(0, length)); // Load the data and update the state sz_u128_vec_t data0_vec, data1_vec, data2_vec, data3_vec; - data0_vec.xmm = _mm_lddqu_si128(start); - data1_vec.xmm = _mm_lddqu_si128(start + 16); - data2_vec.xmm = _mm_lddqu_si128(start + 32); - data3_vec.xmm = _mm_lddqu_si128(start + length - 16); + data0_vec.xmm = _mm_lddqu_si128((__m128i const *)(start)); + data1_vec.xmm = _mm_lddqu_si128((__m128i const *)(start + 16)); + data2_vec.xmm = _mm_lddqu_si128((__m128i const *)(start + 32)); + data3_vec.xmm = _mm_lddqu_si128((__m128i const *)(start + length - 16)); // Let's shift the data within the register to de-interleave the bytes. - data3_vec.xmm = _mm_bsrli_si128(data3_vec.xmm, 64 - length); + _sz_hash_shift_in_register_serial(&data3_vec, 64 - length); _sz_hash_minimal_update_haswell(&state, data0_vec.xmm); _sz_hash_minimal_update_haswell(&state, data1_vec.xmm); _sz_hash_minimal_update_haswell(&state, data2_vec.xmm); _sz_hash_minimal_update_haswell(&state, data3_vec.xmm); - return _sz_hash_minimal_finalize_haswell(&state); + return _sz_hash_minimal_finalize_haswell(&state, length); } else { // Use a larger state to handle the main loop and add different offsets // to different lanes of the register sz_hash_state_t state; sz_hash_state_init_haswell(&state, seed); - state.aes.xmms[0] = _mm_add_epi64(state.aes.xmms[0], _mm_set_epi64x(0, length)); - state.aes.xmms[1] = _mm_add_epi64(state.aes.xmms[1], _mm_set_epi64x(16, length)); - state.aes.xmms[2] = _mm_add_epi64(state.aes.xmms[2], _mm_set_epi64x(32, length)); - state.aes.xmms[3] = _mm_add_epi64(state.aes.xmms[3], _mm_set_epi64x(48, length)); - for (; state.ins_length + 64 <= length; state.ins_length += 64) { - state.ins.xmms[0] = _mm_lddqu_si128(start + state.ins_length); - state.ins.xmms[1] = _mm_lddqu_si128(start + state.ins_length + 16); - state.ins.xmms[2] = _mm_lddqu_si128(start + state.ins_length + 32); - state.ins.xmms[3] = _mm_lddqu_si128(start + state.ins_length + 48); - _sz_hash_state_update_haswell(&state, state.ins.xmms[0], state.ins.xmms[1], state.ins.xmms[2], - state.ins.xmms[3]); + state.ins.xmms[0] = _mm_lddqu_si128((__m128i const *)(start + state.ins_length)); + state.ins.xmms[1] = _mm_lddqu_si128((__m128i const *)(start + state.ins_length + 16)); + state.ins.xmms[2] = _mm_lddqu_si128((__m128i const *)(start + state.ins_length + 32)); + state.ins.xmms[3] = _mm_lddqu_si128((__m128i const *)(start + state.ins_length + 48)); + _sz_hash_state_update_haswell(&state); } + // Handle the tail, resetting the registers to zero first if (state.ins_length < length) { state.ins.xmms[0] = _mm_setzero_si128(); state.ins.xmms[1] = _mm_setzero_si128(); @@ -578,26 +968,85 @@ SZ_PUBLIC sz_u64_t sz_hash_haswell(sz_cptr_t start, sz_size_t length, sz_u64_t s state.ins.xmms[3] = _mm_setzero_si128(); for (sz_size_t i = 0; state.ins_length < length; ++i, 
++state.ins_length) state.ins.u8s[i] = start[state.ins_length]; - _sz_hash_state_update_haswell(&state, state.ins.xmms[0], state.ins.xmms[1], state.ins.xmms[2], - state.ins.xmms[3]); + _sz_hash_state_update_haswell(&state); + state.ins_length = length; } return _sz_hash_state_finalize_haswell(&state); } } -SZ_PUBLIC void sz_generate_haswell(sz_ptr_t text, sz_size_t length, sz_u64_t nonce) { - sz_generate_serial(text, length, nonce); +SZ_PUBLIC void sz_hash_state_stream_haswell(sz_hash_state_t *state_ptr, sz_cptr_t text, sz_size_t length) { + while (length) { + // Append to the internal buffer until it's full + if (state_ptr->ins_length % 64 == 0 && length >= 64) { + state_ptr->ins.xmms[0] = _mm_lddqu_si128((__m128i const *)text); + state_ptr->ins.xmms[1] = _mm_lddqu_si128((__m128i const *)(text + 16)); + state_ptr->ins.xmms[2] = _mm_lddqu_si128((__m128i const *)(text + 32)); + state_ptr->ins.xmms[3] = _mm_lddqu_si128((__m128i const *)(text + 48)); + _sz_hash_state_update_haswell(state_ptr); + state_ptr->ins_length += 64; + text += 64; + length -= 64; + } + // If vectorization isn't that trivial - fall back to the serial implementation + else { + sz_size_t progress_in_block = state_ptr->ins_length % 64; + sz_size_t to_copy = sz_min_of_two(length, 64 - progress_in_block); + int const will_fill_block = progress_in_block + to_copy == 64; + // Update the metadata before we modify the `to_copy` variable + state_ptr->ins_length += to_copy; + length -= to_copy; + // Append to the internal buffer until it's full + while (to_copy--) state_ptr->ins.u8s[progress_in_block++] = *text++; + // If we've reached the end of the buffer, update the state + if (will_fill_block) { + _sz_hash_state_update_haswell(state_ptr); + // Reset to zeros now, so we don't have to overwrite an immutable buffer in the folding state + for (int i = 0; i < 4; ++i) state_ptr->ins.xmms[i] = _mm_setzero_si128(); + } + } + } } -SZ_PUBLIC void sz_hash_state_init_haswell(sz_hash_state_t *state, sz_u64_t seed) { - sz_hash_state_init_serial(state, seed); -} +SZ_PUBLIC sz_u64_t sz_hash_state_fold_haswell(sz_hash_state_t const *state_ptr) { + sz_size_t length = state_ptr->ins_length; + if (length >= 64) return _sz_hash_state_finalize_haswell(state_ptr); -SZ_PUBLIC void sz_hash_state_stream_haswell(sz_hash_state_t *state, sz_cptr_t text, sz_size_t length) { - sz_hash_state_stream_serial(state, text, length); + // Switch back to a smaller "minimal" state for small inputs + _sz_hash_minimal_t state; + state.key.xmm = state_ptr->key.xmm; + state.aes.xmm = state_ptr->aes.xmms[0]; + state.sum.xmm = state_ptr->sum.xmms[0]; + + // The logic is different depending on the length of the input + __m128i const *ins_vecs = (__m128i const *)&state_ptr->ins.xmms[0]; + if (length <= 16) { + _sz_hash_minimal_update_haswell(&state, ins_vecs[0]); + return _sz_hash_minimal_finalize_haswell(&state, length); + } + else if (length <= 32) { + _sz_hash_minimal_update_haswell(&state, ins_vecs[0]); + _sz_hash_minimal_update_haswell(&state, ins_vecs[1]); + return _sz_hash_minimal_finalize_haswell(&state, length); + } + else if (length <= 48) { + _sz_hash_minimal_update_haswell(&state, ins_vecs[0]); + _sz_hash_minimal_update_haswell(&state, ins_vecs[1]); + _sz_hash_minimal_update_haswell(&state, ins_vecs[2]); + return _sz_hash_minimal_finalize_haswell(&state, length); + } + else { + _sz_hash_minimal_update_haswell(&state, ins_vecs[0]); + _sz_hash_minimal_update_haswell(&state, ins_vecs[1]); + _sz_hash_minimal_update_haswell(&state, ins_vecs[2]); + 
_sz_hash_minimal_update_haswell(&state, ins_vecs[3]); + return _sz_hash_minimal_finalize_haswell(&state, length); + } } -SZ_PUBLIC sz_u64_t sz_hash_state_fold_haswell(sz_hash_state_t const *state) { return sz_hash_state_fold_serial(state); } +SZ_PUBLIC void sz_generate_haswell(sz_ptr_t text, sz_size_t length, sz_u64_t nonce) { + sz_generate_serial(text, length, nonce); +} #pragma clang attribute pop #pragma GCC pop_options @@ -716,99 +1165,91 @@ SZ_PUBLIC sz_u64_t sz_bytesum_skylake(sz_cptr_t text, sz_size_t length) { } SZ_PUBLIC void sz_hash_state_init_skylake(sz_hash_state_t *state, sz_u64_t seed) { + // The key is made from the seed and half of it will be mixed with the length in the end __m512i seed_vec = _mm512_set1_epi64(seed); - __m512i pi0 = _mm512_set_epi64( // - 0x13198a2e03707344ull, 0x243f6a8885a308d3ull, 0x13198a2e03707344ull, 0x243f6a8885a308d3ull, - 0x13198a2e03707344ull, 0x243f6a8885a308d3ull, 0x13198a2e03707344ull, 0x243f6a8885a308d3ull); - __m512i pi1 = _mm512_set_epi64( // - 0x082efa98ec4e6c89ull, 0xa4093822299f31d0ull, 0x082efa98ec4e6c89ull, 0xa4093822299f31d0ull, - 0x082efa98ec4e6c89ull, 0xa4093822299f31d0ull, 0x082efa98ec4e6c89ull, 0xa4093822299f31d0ull); + state->key.xmm = _mm512_castsi512_si128(seed_vec); + // XOR the user-supplied keys with the two "pi" constants - __m512i k1 = _mm512_xor_si512(seed_vec, pi0); - __m512i k2 = _mm512_xor_si512(seed_vec, pi1); - // Export the keys to the state - state->aes.zmm = k1; - state->sum.zmm = k2; - state->key.zmm = _mm512_xor_si512(pi0, pi1); + sz_u64_t const *pi = _sz_hash_pi_constants(); + __m512i const pi0 = _mm512_load_epi64((__m512i const *)(pi)); + __m512i const pi1 = _mm512_load_epi64((__m512i const *)(pi + 8)); + state->aes.zmm = _mm512_xor_si512(seed_vec, pi0); + state->sum.zmm = _mm512_xor_si512(seed_vec, pi1); + + // The inputs are zeroed out at the beginning + state->ins.zmm = _mm512_setzero_si512(); state->ins_length = 0; } SZ_PUBLIC sz_u64_t sz_hash_skylake(sz_cptr_t start, sz_size_t length, sz_u64_t seed) { if (length <= 16) { - // Initialize the AES block with a given seed and update with the input length + // Initialize the AES block with a given seed _sz_hash_minimal_t state; _sz_hash_minimal_init_haswell(&state, seed); - state.aes.xmm = _mm_add_epi64(state.aes.xmm, _mm_set_epi64x(0, length)); // Load the data and update the state sz_u128_vec_t data_vec; data_vec.xmm = _mm_maskz_loadu_epi8(_sz_u16_mask_until(length), start); _sz_hash_minimal_update_haswell(&state, data_vec.xmm); - return _sz_hash_minimal_finalize_haswell(&state); + return _sz_hash_minimal_finalize_haswell(&state, length); } else if (length <= 32) { - // Initialize the AES block with a given seed and update with the input length + // Initialize the AES block with a given seed _sz_hash_minimal_t state; _sz_hash_minimal_init_haswell(&state, seed); - state.aes.xmm = _mm_add_epi64(state.aes.xmm, _mm_set_epi64x(0, length)); // Load the data and update the state sz_u128_vec_t data0_vec, data1_vec; - data0_vec.xmm = _mm_lddqu_si128(start); + data0_vec.xmm = _mm_lddqu_si128((__m128i const *)(start)); data1_vec.xmm = _mm_maskz_loadu_epi8(_sz_u16_mask_until(length - 16), start + 16); _sz_hash_minimal_update_haswell(&state, data0_vec.xmm); _sz_hash_minimal_update_haswell(&state, data1_vec.xmm); - return _sz_hash_minimal_finalize_haswell(&state); + return _sz_hash_minimal_finalize_haswell(&state, length); } else if (length <= 48) { - // Initialize the AES block with a given seed and update with the input length + // Initialize the AES block with a given 
seed _sz_hash_minimal_t state; _sz_hash_minimal_init_haswell(&state, seed); - state.aes.xmm = _mm_add_epi64(state.aes.xmm, _mm_set_epi64x(0, length)); // Load the data and update the state sz_u128_vec_t data0_vec, data1_vec, data2_vec; - data0_vec.xmm = _mm_lddqu_si128(start); - data1_vec.xmm = _mm_lddqu_si128(start + 16); + data0_vec.xmm = _mm_lddqu_si128((__m128i const *)(start)); + data1_vec.xmm = _mm_lddqu_si128((__m128i const *)(start + 16)); data2_vec.xmm = _mm_maskz_loadu_epi8(_sz_u16_mask_until(length - 32), start + 32); _sz_hash_minimal_update_haswell(&state, data0_vec.xmm); _sz_hash_minimal_update_haswell(&state, data1_vec.xmm); _sz_hash_minimal_update_haswell(&state, data2_vec.xmm); - return _sz_hash_minimal_finalize_haswell(&state); + return _sz_hash_minimal_finalize_haswell(&state, length); } else if (length <= 64) { - // Initialize the AES block with a given seed and update with the input length + // Initialize the AES block with a given seed _sz_hash_minimal_t state; _sz_hash_minimal_init_haswell(&state, seed); - state.aes.xmm = _mm_add_epi64(state.aes.xmm, _mm_set_epi64x(0, length)); // Load the data and update the state sz_u128_vec_t data0_vec, data1_vec, data2_vec, data3_vec; - data0_vec.xmm = _mm_lddqu_si128(start); - data1_vec.xmm = _mm_lddqu_si128(start + 16); - data2_vec.xmm = _mm_lddqu_si128(start + 32); + data0_vec.xmm = _mm_lddqu_si128((__m128i const *)(start)); + data1_vec.xmm = _mm_lddqu_si128((__m128i const *)(start + 16)); + data2_vec.xmm = _mm_lddqu_si128((__m128i const *)(start + 32)); data3_vec.xmm = _mm_maskz_loadu_epi8(_sz_u16_mask_until(length - 48), start + 48); _sz_hash_minimal_update_haswell(&state, data0_vec.xmm); _sz_hash_minimal_update_haswell(&state, data1_vec.xmm); _sz_hash_minimal_update_haswell(&state, data2_vec.xmm); _sz_hash_minimal_update_haswell(&state, data3_vec.xmm); - return _sz_hash_minimal_finalize_haswell(&state); + return _sz_hash_minimal_finalize_haswell(&state, length); } else { // Use a larger state to handle the main loop and add different offsets // to different lanes of the register sz_hash_state_t state; sz_hash_state_init_skylake(&state, seed); - state.aes.zmm = _mm512_add_epi64( // - state.aes.zmm, // - _mm512_set_epi64(0, length, 16, length, 32, length, 48, length)); for (; state.ins_length + 64 <= length; state.ins_length += 64) { state.ins.zmm = _mm512_loadu_epi8(start + state.ins_length); - _sz_hash_state_update_haswell(&state, state.ins.xmms[0], state.ins.xmms[1], state.ins.xmms[2], - state.ins.xmms[3]); + _sz_hash_state_update_haswell(&state); } if (state.ins_length < length) { state.ins.zmm = _mm512_maskz_loadu_epi8( // _sz_u64_mask_until(length - state.ins_length), start + state.ins_length); - _sz_hash_state_update_skylake(&state, state.ins.zmm); + _sz_hash_state_update_haswell(&state); + state.ins_length = length; } return _sz_hash_state_finalize_haswell(&state); } @@ -970,97 +1411,81 @@ SZ_PUBLIC sz_u64_t sz_bytesum_ice(sz_cptr_t text, sz_size_t length) { } } -SZ_INTERNAL void _sz_hash_state_update_ice(sz_hash_state_t *state, __m512i block) { - // This shuffle mask is identical to "aHash": - __m512i const shuffle_mask = _mm512_set_epi8( // - 0x04, 0x0b, 0x09, 0x06, 0x08, 0x0d, 0x0f, 0x05, // - 0x0e, 0x03, 0x01, 0x0c, 0x00, 0x07, 0x0a, 0x02, // - 0x04, 0x0b, 0x09, 0x06, 0x08, 0x0d, 0x0f, 0x05, // - 0x0e, 0x03, 0x01, 0x0c, 0x00, 0x07, 0x0a, 0x02, // - 0x04, 0x0b, 0x09, 0x06, 0x08, 0x0d, 0x0f, 0x05, // - 0x0e, 0x03, 0x01, 0x0c, 0x00, 0x07, 0x0a, 0x02, // - 0x04, 0x0b, 0x09, 0x06, 0x08, 0x0d, 0x0f, 0x05, // - 0x0e, 
0x03, 0x01, 0x0c, 0x00, 0x07, 0x0a, 0x02 // - ); - state->aes.zmm = _mm512_aesdec_epi128(state->aes.zmm, block); - state->sum.zmm = _mm512_add_epi64(_mm512_shuffle_epi8(state->sum.zmm, shuffle_mask), block); +SZ_INTERNAL void _sz_hash_state_update_ice(sz_hash_state_t *state) { + __m512i const shuffle_mask = _mm512_load_si512((__m512i const *)_sz_hash_u8x16x4_shuffle()); + state->aes.zmm = _mm512_aesenc_epi128(state->aes.zmm, state->ins.zmm); + state->sum.zmm = _mm512_add_epi64(_mm512_shuffle_epi8(state->sum.zmm, shuffle_mask), state->ins.zmm); } SZ_PUBLIC sz_u64_t sz_hash_ice(sz_cptr_t start, sz_size_t length, sz_u64_t seed) { if (length <= 16) { - // Initialize the AES block with a given seed and update with the input length + // Initialize the AES block with a given seed _sz_hash_minimal_t state; _sz_hash_minimal_init_haswell(&state, seed); - state.aes.xmm = _mm_add_epi64(state.aes.xmm, _mm_set_epi64x(0, length)); // Load the data and update the state sz_u128_vec_t data_vec; data_vec.xmm = _mm_maskz_loadu_epi8(_sz_u16_mask_until(length), start); _sz_hash_minimal_update_haswell(&state, data_vec.xmm); - return _sz_hash_minimal_finalize_haswell(&state); + return _sz_hash_minimal_finalize_haswell(&state, length); } else if (length <= 32) { - // Initialize the AES block with a given seed and update with the input length + // Initialize the AES block with a given seed _sz_hash_minimal_t state; _sz_hash_minimal_init_haswell(&state, seed); - state.aes.xmm = _mm_add_epi64(state.aes.xmm, _mm_set_epi64x(0, length)); // Load the data and update the state sz_u128_vec_t data0_vec, data1_vec; - data0_vec.xmm = _mm_lddqu_si128(start); + data0_vec.xmm = _mm_lddqu_si128((__m128i const *)(start)); data1_vec.xmm = _mm_maskz_loadu_epi8(_sz_u16_mask_until(length - 16), start + 16); _sz_hash_minimal_update_haswell(&state, data0_vec.xmm); _sz_hash_minimal_update_haswell(&state, data1_vec.xmm); - return _sz_hash_minimal_finalize_haswell(&state); + return _sz_hash_minimal_finalize_haswell(&state, length); } else if (length <= 48) { - // Initialize the AES block with a given seed and update with the input length + // Initialize the AES block with a given seed _sz_hash_minimal_t state; _sz_hash_minimal_init_haswell(&state, seed); - state.aes.xmm = _mm_add_epi64(state.aes.xmm, _mm_set_epi64x(0, length)); // Load the data and update the state sz_u128_vec_t data0_vec, data1_vec, data2_vec; - data0_vec.xmm = _mm_lddqu_si128(start); - data1_vec.xmm = _mm_lddqu_si128(start + 16); + data0_vec.xmm = _mm_lddqu_si128((__m128i const *)(start)); + data1_vec.xmm = _mm_lddqu_si128((__m128i const *)(start + 16)); data2_vec.xmm = _mm_maskz_loadu_epi8(_sz_u16_mask_until(length - 32), start + 32); _sz_hash_minimal_update_haswell(&state, data0_vec.xmm); _sz_hash_minimal_update_haswell(&state, data1_vec.xmm); _sz_hash_minimal_update_haswell(&state, data2_vec.xmm); - return _sz_hash_minimal_finalize_haswell(&state); + return _sz_hash_minimal_finalize_haswell(&state, length); } else if (length <= 64) { - // Initialize the AES block with a given seed and update with the input length + // Initialize the AES block with a given seed _sz_hash_minimal_t state; _sz_hash_minimal_init_haswell(&state, seed); - state.aes.xmm = _mm_add_epi64(state.aes.xmm, _mm_set_epi64x(0, length)); // Load the data and update the state sz_u128_vec_t data0_vec, data1_vec, data2_vec, data3_vec; - data0_vec.xmm = _mm_lddqu_si128(start); - data1_vec.xmm = _mm_lddqu_si128(start + 16); - data2_vec.xmm = _mm_lddqu_si128(start + 32); + data0_vec.xmm = 
_mm_lddqu_si128((__m128i const *)(start)); + data1_vec.xmm = _mm_lddqu_si128((__m128i const *)(start + 16)); + data2_vec.xmm = _mm_lddqu_si128((__m128i const *)(start + 32)); data3_vec.xmm = _mm_maskz_loadu_epi8(_sz_u16_mask_until(length - 48), start + 48); _sz_hash_minimal_update_haswell(&state, data0_vec.xmm); _sz_hash_minimal_update_haswell(&state, data1_vec.xmm); _sz_hash_minimal_update_haswell(&state, data2_vec.xmm); _sz_hash_minimal_update_haswell(&state, data3_vec.xmm); - return _sz_hash_minimal_finalize_haswell(&state); + return _sz_hash_minimal_finalize_haswell(&state, length); } else { // Use a larger state to handle the main loop and add different offsets // to different lanes of the register sz_hash_state_t state; sz_hash_state_init_skylake(&state, seed); - state.aes.zmm = _mm512_add_epi64( // - state.aes.zmm, // - _mm512_set_epi64(0, length, 16, length, 32, length, 48, length)); for (; state.ins_length + 64 <= length; state.ins_length += 64) { state.ins.zmm = _mm512_loadu_epi8(start + state.ins_length); - _sz_hash_state_update_ice(&state, state.ins.zmm); + _sz_hash_state_update_ice(&state); } if (state.ins_length < length) { state.ins.zmm = _mm512_maskz_loadu_epi8( // _sz_u64_mask_until(length - state.ins_length), start + state.ins_length); - _sz_hash_state_update_ice(&state, state.ins.zmm); + _sz_hash_state_update_ice(&state); + state.ins_length = length; } return _sz_hash_state_finalize_haswell(&state); } @@ -1119,6 +1544,10 @@ SZ_PUBLIC void sz_generate_ice(sz_ptr_t output, sz_size_t length, sz_u64_t nonce } } +SZ_PUBLIC void sz_hash_state_init_ice(sz_hash_state_t *state, sz_u64_t seed) { + sz_hash_state_init_skylake(state, seed); +} + SZ_PUBLIC void sz_hash_state_stream_ice(sz_hash_state_t *state, sz_cptr_t text, sz_size_t length) { sz_hash_state_stream_serial(state, text, length); } From 8ac3a23a8db20a44f883475f6b865df45d883308 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Wed, 26 Feb 2025 10:46:07 +0000 Subject: [PATCH 128/751] Add: Streaming hash benchmarks --- scripts/bench_token.cpp | 51 +++++++++++++++++++++++++++++++++++------ 1 file changed, 44 insertions(+), 7 deletions(-) diff --git a/scripts/bench_token.cpp b/scripts/bench_token.cpp index 93ae2b7e..378ad4f0 100644 --- a/scripts/bench_token.cpp +++ b/scripts/bench_token.cpp @@ -38,24 +38,60 @@ tracked_unary_functions_t bytesum_functions() { return result; } -tracked_unary_functions_t hashing_functions() { +tracked_unary_functions_t hash_functions() { auto wrap_sz = [](auto function) -> unary_function_t { return unary_function_t([function](std::string_view s) { return function(s.data(), s.size(), 42); }); }; tracked_unary_functions_t result = { - {"std::hash", [](std::string_view s) { return std::hash {}(s); }}, {"sz_hash_serial", wrap_sz(sz_hash_serial)}, #if SZ_USE_HASWELL - {"sz_hash_haswell", wrap_sz(sz_hash_haswell)}, + {"sz_hash_haswell", wrap_sz(sz_hash_haswell), true}, +#endif +#if SZ_USE_SKYLAKE + {"sz_hash_skylake", wrap_sz(sz_hash_skylake), true}, +#endif +#if SZ_USE_ICE + {"sz_hash_ice", wrap_sz(sz_hash_ice), true}, +#endif +#if SZ_USE_NEON + {"sz_hash_neon", wrap_sz(sz_hash_neon), true}, +#endif + {"std::hash", [](std::string_view s) { return std::hash {}(s); }}, + }; + return result; +} + +struct wrapped_incremental_hash { + sz_hash_state_t state; + sz_hash_state_stream_t stream; + sz_hash_state_fold_t fold; + + wrapped_incremental_hash(sz_hash_state_stream_t s, sz_hash_state_fold_t f) : stream(s), fold(f) { + sz_hash_state_init(&state, 42); + } + 
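+    // Note: the state persists across calls, so each invocation streams another string into the same running hash before folding it.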
+ std::size_t operator()(std::string_view s) noexcept { + stream(&state, s.data(), s.size()); + return fold(&state); + } +}; + +tracked_unary_functions_t hash_stream_functions() { + tracked_unary_functions_t result = { + {"sz_hash_stream_serial", wrapped_incremental_hash(sz_hash_state_stream_serial, sz_hash_state_fold_serial)}, +#if SZ_USE_HASWELL + {"sz_hash_stream_haswell", wrapped_incremental_hash(sz_hash_state_stream_haswell, sz_hash_state_fold_haswell), + true}, #endif #if SZ_USE_SKYLAKE - {"sz_hash_skylake", wrap_sz(sz_hash_skylake)}, + {"sz_hash_stream_skylake", wrapped_incremental_hash(sz_hash_state_stream_skylake, sz_hash_state_fold_skylake), + true}, #endif #if SZ_USE_ICE - {"sz_hash_ice", wrap_sz(sz_hash_ice)}, + {"sz_hash_stream_ice", wrapped_incremental_hash(sz_hash_state_stream_ice, sz_hash_state_fold_ice), true}, #endif #if SZ_USE_NEON - {"sz_hash_neon", wrap_sz(sz_hash_neon)}, + {"sz_hash_stream_neon", wrapped_incremental_hash(sz_hash_state_stream_neon, sz_hash_state_fold_neon), true}, #endif }; return result; @@ -152,7 +188,8 @@ void bench(strings_type &&strings) { // Benchmark logical operations bench_unary_functions(strings, bytesum_functions()); - bench_unary_functions(strings, hashing_functions()); + bench_unary_functions(strings, hash_functions()); + bench_unary_functions(strings, hash_stream_functions()); bench_binary_functions(strings, equality_functions()); bench_binary_functions(strings, ordering_functions()); From 2607d4513f97b2dd55ac6301e7dc956ad4a5d9ad Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Wed, 26 Feb 2025 11:54:45 +0000 Subject: [PATCH 129/751] Add: Streaming hashing on Ice Lake & Skylake X --- include/stringzilla/hash.h | 228 ++++++++++++++++++++++--------------- 1 file changed, 134 insertions(+), 94 deletions(-) diff --git a/include/stringzilla/hash.h b/include/stringzilla/hash.h index 105ffabe..50deb776 100644 --- a/include/stringzilla/hash.h +++ b/include/stringzilla/hash.h @@ -297,13 +297,15 @@ SZ_PUBLIC sz_u64_t sz_hash_state_fold_neon(sz_hash_state_t const *state); * @note The current content of the `ins` buffer and its length is ignored. */ SZ_PUBLIC sz_bool_t sz_hash_state_equal(sz_hash_state_t const *lhs, sz_hash_state_t const *rhs) { - return lhs->aes.u64s[0] == rhs->aes.u64s[0] && lhs->aes.u64s[1] == rhs->aes.u64s[1] && - lhs->aes.u64s[2] == rhs->aes.u64s[2] && lhs->aes.u64s[3] == rhs->aes.u64s[3] && - lhs->sum.u64s[0] == rhs->sum.u64s[0] && lhs->sum.u64s[1] == rhs->sum.u64s[1] && - lhs->sum.u64s[2] == rhs->sum.u64s[2] && lhs->sum.u64s[3] == rhs->sum.u64s[3] && - lhs->key.u64s[0] == rhs->key.u64s[0] && lhs->key.u64s[1] == rhs->key.u64s[1] - ? sz_true_k - : sz_false_k; + int same_aes = // + lhs->aes.u64s[0] == rhs->aes.u64s[0] && lhs->aes.u64s[1] == rhs->aes.u64s[1] && + lhs->aes.u64s[2] == rhs->aes.u64s[2] && lhs->aes.u64s[3] == rhs->aes.u64s[3]; + int same_sum = // + lhs->sum.u64s[0] == rhs->sum.u64s[0] && lhs->sum.u64s[1] == rhs->sum.u64s[1] && + lhs->sum.u64s[2] == rhs->sum.u64s[2] && lhs->sum.u64s[3] == rhs->sum.u64s[3]; + int same_key = // + lhs->key.u64s[0] == rhs->key.u64s[0] && lhs->key.u64s[1] == rhs->key.u64s[1]; + return same_aes && same_sum && same_key ? 
sz_true_k : sz_false_k; } #pragma endregion // Helper Methods @@ -647,58 +649,58 @@ SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t start, sz_size_t length, sz_u64_t se } } -SZ_PUBLIC void sz_hash_state_stream_serial(sz_hash_state_t *state_ptr, sz_cptr_t text, sz_size_t length) { +SZ_PUBLIC void sz_hash_state_stream_serial(sz_hash_state_t *state, sz_cptr_t text, sz_size_t length) { while (length) { - sz_size_t progress_in_block = state_ptr->ins_length % 64; + sz_size_t progress_in_block = state->ins_length % 64; sz_size_t to_copy = sz_min_of_two(length, 64 - progress_in_block); int const will_fill_block = progress_in_block + to_copy == 64; // Update the metadata before we modify the `to_copy` variable - state_ptr->ins_length += to_copy; + state->ins_length += to_copy; length -= to_copy; // Append to the internal buffer until it's full - while (to_copy--) state_ptr->ins.u8s[progress_in_block++] = *text++; + while (to_copy--) state->ins.u8s[progress_in_block++] = *text++; // If we've reached the end of the buffer, update the state if (will_fill_block) { - _sz_hash_state_update_serial(state_ptr); + _sz_hash_state_update_serial(state); // Reset to zeros now, so we don't have to overwrite an immutable buffer in the folding state - for (int i = 0; i < 8; ++i) state_ptr->ins.u64s[i] = 0; + for (int i = 0; i < 8; ++i) state->ins.u64s[i] = 0; } } } -SZ_PUBLIC sz_u64_t sz_hash_state_fold_serial(sz_hash_state_t const *state_ptr) { - sz_size_t length = state_ptr->ins_length; - if (length >= 64) return _sz_hash_state_finalize_serial(state_ptr); +SZ_PUBLIC sz_u64_t sz_hash_state_fold_serial(sz_hash_state_t const *state) { + sz_size_t length = state->ins_length; + if (length >= 64) return _sz_hash_state_finalize_serial(state); // Switch back to a smaller "minimal" state for small inputs - _sz_hash_minimal_t state; - state.key = state_ptr->key; - state.aes = *(sz_u128_vec_t const *)&state_ptr->aes.u64s[0]; - state.sum = *(sz_u128_vec_t const *)&state_ptr->sum.u64s[0]; + _sz_hash_minimal_t minimal_state; + minimal_state.key = state->key; + minimal_state.aes = *(sz_u128_vec_t const *)&state->aes.u64s[0]; + minimal_state.sum = *(sz_u128_vec_t const *)&state->sum.u64s[0]; // The logic is different depending on the length of the input - sz_u128_vec_t const *ins_vecs = (sz_u128_vec_t const *)&state_ptr->ins.u64s[0]; + sz_u128_vec_t const *ins_vecs = (sz_u128_vec_t const *)&state->ins.u64s[0]; if (length <= 16) { - _sz_hash_minimal_update_serial(&state, ins_vecs[0]); - return _sz_hash_minimal_finalize_serial(&state, length); + _sz_hash_minimal_update_serial(&minimal_state, ins_vecs[0]); + return _sz_hash_minimal_finalize_serial(&minimal_state, length); } else if (length <= 32) { - _sz_hash_minimal_update_serial(&state, ins_vecs[0]); - _sz_hash_minimal_update_serial(&state, ins_vecs[1]); - return _sz_hash_minimal_finalize_serial(&state, length); + _sz_hash_minimal_update_serial(&minimal_state, ins_vecs[0]); + _sz_hash_minimal_update_serial(&minimal_state, ins_vecs[1]); + return _sz_hash_minimal_finalize_serial(&minimal_state, length); } else if (length <= 48) { - _sz_hash_minimal_update_serial(&state, ins_vecs[0]); - _sz_hash_minimal_update_serial(&state, ins_vecs[1]); - _sz_hash_minimal_update_serial(&state, ins_vecs[2]); - return _sz_hash_minimal_finalize_serial(&state, length); + _sz_hash_minimal_update_serial(&minimal_state, ins_vecs[0]); + _sz_hash_minimal_update_serial(&minimal_state, ins_vecs[1]); + _sz_hash_minimal_update_serial(&minimal_state, ins_vecs[2]); + return 
_sz_hash_minimal_finalize_serial(&minimal_state, length); } else { - _sz_hash_minimal_update_serial(&state, ins_vecs[0]); - _sz_hash_minimal_update_serial(&state, ins_vecs[1]); - _sz_hash_minimal_update_serial(&state, ins_vecs[2]); - _sz_hash_minimal_update_serial(&state, ins_vecs[3]); - return _sz_hash_minimal_finalize_serial(&state, length); + _sz_hash_minimal_update_serial(&minimal_state, ins_vecs[0]); + _sz_hash_minimal_update_serial(&minimal_state, ins_vecs[1]); + _sz_hash_minimal_update_serial(&minimal_state, ins_vecs[2]); + _sz_hash_minimal_update_serial(&minimal_state, ins_vecs[3]); + return _sz_hash_minimal_finalize_serial(&minimal_state, length); } } @@ -802,15 +804,15 @@ SZ_PUBLIC sz_u64_t sz_bytesum_haswell(sz_cptr_t text, sz_size_t length) { } SZ_INTERNAL void _sz_hash_minimal_init_haswell(_sz_hash_minimal_t *state, sz_u64_t seed) { - sz_u64_t const *pi = _sz_hash_pi_constants(); - __m128i const pi0 = _mm_load_si128((__m128i const *)(pi)); - __m128i const pi1 = _mm_load_si128((__m128i const *)(pi + 8)); // The key is made from the seed and half of it will be mixed with the length in the end __m128i seed_vec = _mm_set1_epi64x(seed); state->key.xmm = seed_vec; // XOR the user-supplied keys with the two "pi" constants + sz_u64_t const *pi = _sz_hash_pi_constants(); + __m128i const pi0 = _mm_load_si128((__m128i const *)(pi)); + __m128i const pi1 = _mm_load_si128((__m128i const *)(pi + 8)); __m128i k1 = _mm_xor_si128(seed_vec, pi0); __m128i k2 = _mm_xor_si128(seed_vec, pi1); @@ -867,14 +869,14 @@ SZ_INTERNAL void _sz_hash_state_update_haswell(sz_hash_state_t *state) { state->sum.xmms[3] = _mm_add_epi64(_mm_shuffle_epi8(state->sum.xmms[3], shuffle_mask), state->ins.xmms[3]); } -SZ_INTERNAL sz_u64_t _sz_hash_state_finalize_haswell(sz_hash_state_t const *state_ptr) { +SZ_INTERNAL sz_u64_t _sz_hash_state_finalize_haswell(sz_hash_state_t const *state) { // Mix the length into the key - __m128i key_with_length = _mm_add_epi64(state_ptr->key.xmm, _mm_set_epi64x(0, state_ptr->ins_length)); + __m128i key_with_length = _mm_add_epi64(state->key.xmm, _mm_set_epi64x(0, state->ins_length)); // Combine the "sum" and the "AES" blocks - __m128i mixed_registers0 = _mm_aesenc_si128(state_ptr->sum.xmms[0], state_ptr->aes.xmms[0]); - __m128i mixed_registers1 = _mm_aesenc_si128(state_ptr->sum.xmms[1], state_ptr->aes.xmms[1]); - __m128i mixed_registers2 = _mm_aesenc_si128(state_ptr->sum.xmms[2], state_ptr->aes.xmms[2]); - __m128i mixed_registers3 = _mm_aesenc_si128(state_ptr->sum.xmms[3], state_ptr->aes.xmms[3]); + __m128i mixed_registers0 = _mm_aesenc_si128(state->sum.xmms[0], state->aes.xmms[0]); + __m128i mixed_registers1 = _mm_aesenc_si128(state->sum.xmms[1], state->aes.xmms[1]); + __m128i mixed_registers2 = _mm_aesenc_si128(state->sum.xmms[2], state->aes.xmms[2]); + __m128i mixed_registers3 = _mm_aesenc_si128(state->sum.xmms[3], state->aes.xmms[3]); // Combine the mixed registers __m128i mixed_registers01 = _mm_aesenc_si128(mixed_registers0, mixed_registers1); __m128i mixed_registers23 = _mm_aesenc_si128(mixed_registers2, mixed_registers3); @@ -975,72 +977,72 @@ SZ_PUBLIC sz_u64_t sz_hash_haswell(sz_cptr_t start, sz_size_t length, sz_u64_t s } } -SZ_PUBLIC void sz_hash_state_stream_haswell(sz_hash_state_t *state_ptr, sz_cptr_t text, sz_size_t length) { +SZ_PUBLIC void sz_hash_state_stream_haswell(sz_hash_state_t *state, sz_cptr_t text, sz_size_t length) { while (length) { // Append to the internal buffer until it's full - if (state_ptr->ins_length % 64 == 0 && length >= 64) { - 
state_ptr->ins.xmms[0] = _mm_lddqu_si128((__m128i const *)text); - state_ptr->ins.xmms[1] = _mm_lddqu_si128((__m128i const *)(text + 16)); - state_ptr->ins.xmms[2] = _mm_lddqu_si128((__m128i const *)(text + 32)); - state_ptr->ins.xmms[3] = _mm_lddqu_si128((__m128i const *)(text + 48)); - _sz_hash_state_update_haswell(state_ptr); - state_ptr->ins_length += 64; + if (state->ins_length % 64 == 0 && length >= 64) { + state->ins.xmms[0] = _mm_lddqu_si128((__m128i const *)text); + state->ins.xmms[1] = _mm_lddqu_si128((__m128i const *)(text + 16)); + state->ins.xmms[2] = _mm_lddqu_si128((__m128i const *)(text + 32)); + state->ins.xmms[3] = _mm_lddqu_si128((__m128i const *)(text + 48)); + _sz_hash_state_update_haswell(state); + state->ins_length += 64; text += 64; length -= 64; } // If vectorization isn't that trivial - fall back to the serial implementation else { - sz_size_t progress_in_block = state_ptr->ins_length % 64; + sz_size_t progress_in_block = state->ins_length % 64; sz_size_t to_copy = sz_min_of_two(length, 64 - progress_in_block); int const will_fill_block = progress_in_block + to_copy == 64; // Update the metadata before we modify the `to_copy` variable - state_ptr->ins_length += to_copy; + state->ins_length += to_copy; length -= to_copy; // Append to the internal buffer until it's full - while (to_copy--) state_ptr->ins.u8s[progress_in_block++] = *text++; + while (to_copy--) state->ins.u8s[progress_in_block++] = *text++; // If we've reached the end of the buffer, update the state if (will_fill_block) { - _sz_hash_state_update_haswell(state_ptr); + _sz_hash_state_update_haswell(state); // Reset to zeros now, so we don't have to overwrite an immutable buffer in the folding state - for (int i = 0; i < 4; ++i) state_ptr->ins.xmms[i] = _mm_setzero_si128(); + for (int i = 0; i < 4; ++i) state->ins.xmms[i] = _mm_setzero_si128(); } } } } -SZ_PUBLIC sz_u64_t sz_hash_state_fold_haswell(sz_hash_state_t const *state_ptr) { - sz_size_t length = state_ptr->ins_length; - if (length >= 64) return _sz_hash_state_finalize_haswell(state_ptr); +SZ_PUBLIC sz_u64_t sz_hash_state_fold_haswell(sz_hash_state_t const *state) { + sz_size_t length = state->ins_length; + if (length >= 64) return _sz_hash_state_finalize_haswell(state); // Switch back to a smaller "minimal" state for small inputs - _sz_hash_minimal_t state; - state.key.xmm = state_ptr->key.xmm; - state.aes.xmm = state_ptr->aes.xmms[0]; - state.sum.xmm = state_ptr->sum.xmms[0]; + _sz_hash_minimal_t minimal_state; + minimal_state.key.xmm = state->key.xmm; + minimal_state.aes.xmm = state->aes.xmms[0]; + minimal_state.sum.xmm = state->sum.xmms[0]; // The logic is different depending on the length of the input - __m128i const *ins_vecs = (__m128i const *)&state_ptr->ins.xmms[0]; + __m128i const *ins_vecs = (__m128i const *)&state->ins.xmms[0]; if (length <= 16) { - _sz_hash_minimal_update_haswell(&state, ins_vecs[0]); - return _sz_hash_minimal_finalize_haswell(&state, length); + _sz_hash_minimal_update_haswell(&minimal_state, ins_vecs[0]); + return _sz_hash_minimal_finalize_haswell(&minimal_state, length); } else if (length <= 32) { - _sz_hash_minimal_update_haswell(&state, ins_vecs[0]); - _sz_hash_minimal_update_haswell(&state, ins_vecs[1]); - return _sz_hash_minimal_finalize_haswell(&state, length); + _sz_hash_minimal_update_haswell(&minimal_state, ins_vecs[0]); + _sz_hash_minimal_update_haswell(&minimal_state, ins_vecs[1]); + return _sz_hash_minimal_finalize_haswell(&minimal_state, length); } else if (length <= 48) { - 
_sz_hash_minimal_update_haswell(&state, ins_vecs[0]); - _sz_hash_minimal_update_haswell(&state, ins_vecs[1]); - _sz_hash_minimal_update_haswell(&state, ins_vecs[2]); - return _sz_hash_minimal_finalize_haswell(&state, length); + _sz_hash_minimal_update_haswell(&minimal_state, ins_vecs[0]); + _sz_hash_minimal_update_haswell(&minimal_state, ins_vecs[1]); + _sz_hash_minimal_update_haswell(&minimal_state, ins_vecs[2]); + return _sz_hash_minimal_finalize_haswell(&minimal_state, length); } else { - _sz_hash_minimal_update_haswell(&state, ins_vecs[0]); - _sz_hash_minimal_update_haswell(&state, ins_vecs[1]); - _sz_hash_minimal_update_haswell(&state, ins_vecs[2]); - _sz_hash_minimal_update_haswell(&state, ins_vecs[3]); - return _sz_hash_minimal_finalize_haswell(&state, length); + _sz_hash_minimal_update_haswell(&minimal_state, ins_vecs[0]); + _sz_hash_minimal_update_haswell(&minimal_state, ins_vecs[1]); + _sz_hash_minimal_update_haswell(&minimal_state, ins_vecs[2]); + _sz_hash_minimal_update_haswell(&minimal_state, ins_vecs[3]); + return _sz_hash_minimal_finalize_haswell(&minimal_state, length); } } @@ -1255,15 +1257,35 @@ SZ_PUBLIC sz_u64_t sz_hash_skylake(sz_cptr_t start, sz_size_t length, sz_u64_t s } } -SZ_PUBLIC void sz_generate_skylake(sz_ptr_t text, sz_size_t length, sz_u64_t nonce) { - sz_generate_serial(text, length, nonce); +SZ_PUBLIC void sz_hash_state_stream_skylake(sz_hash_state_t *state, sz_cptr_t text, sz_size_t length) { + while (length) { + sz_size_t const progress_in_block = state->ins_length % 64; + sz_size_t const to_copy = sz_min_of_two(length, 64 - progress_in_block); + int const will_fill_block = progress_in_block + to_copy == 64; + // Update the metadata before we modify the `to_copy` variable + state->ins_length += to_copy; + length -= to_copy; + // Append to the internal buffer until it's full + __mmask64 to_copy_mask = _sz_u64_mask_until(to_copy); + _mm512_mask_storeu_epi8(&state->ins.u8s[0] + progress_in_block, to_copy_mask, + _mm512_maskz_loadu_epi8(to_copy_mask, text)); + text += to_copy; + // If we've reached the end of the buffer, update the state + if (will_fill_block) { + _sz_hash_state_update_haswell(state); + // Reset to zeros now, so we don't have to overwrite an immutable buffer in the folding state + state->ins.zmm = _mm512_setzero_si512(); + } + } } -SZ_PUBLIC void sz_hash_state_stream_skylake(sz_hash_state_t *state, sz_cptr_t text, sz_size_t length) { - sz_hash_state_stream_serial(state, text, length); +SZ_PUBLIC sz_u64_t sz_hash_state_fold_skylake(sz_hash_state_t const *state) { + return sz_hash_state_fold_haswell(state); } -SZ_PUBLIC sz_u64_t sz_hash_state_fold_skylake(sz_hash_state_t const *state) { return sz_hash_state_fold_serial(state); } +SZ_PUBLIC void sz_generate_skylake(sz_ptr_t text, sz_size_t length, sz_u64_t nonce) { + sz_generate_serial(text, length, nonce); +} #pragma clang attribute pop #pragma GCC pop_options @@ -1491,6 +1513,34 @@ SZ_PUBLIC sz_u64_t sz_hash_ice(sz_cptr_t start, sz_size_t length, sz_u64_t seed) } } +SZ_PUBLIC void sz_hash_state_init_ice(sz_hash_state_t *state, sz_u64_t seed) { + sz_hash_state_init_skylake(state, seed); +} + +SZ_PUBLIC void sz_hash_state_stream_ice(sz_hash_state_t *state, sz_cptr_t text, sz_size_t length) { + while (length) { + sz_size_t progress_in_block = state->ins_length % 64; + sz_size_t to_copy = sz_min_of_two(length, 64 - progress_in_block); + int const will_fill_block = progress_in_block + to_copy == 64; + // Update the metadata before we modify the `to_copy` variable + state->ins_length += to_copy; + 
length -= to_copy; + // Append to the internal buffer until it's full + __mmask64 to_copy_mask = _sz_u64_mask_until(to_copy); + _mm512_mask_storeu_epi8(state->ins.u8s + progress_in_block, to_copy_mask, + _mm512_maskz_loadu_epi8(to_copy_mask, text)); + text += to_copy; + // If we've reached the end of the buffer, update the state + if (will_fill_block) { + _sz_hash_state_update_ice(state); + // Reset to zeros now, so we don't have to overwrite an immutable buffer in the folding state + state->ins.zmm = _mm512_setzero_si512(); + } + } +} + +SZ_PUBLIC sz_u64_t sz_hash_state_fold_ice(sz_hash_state_t const *state) { return sz_hash_state_fold_haswell(state); } + SZ_PUBLIC void sz_generate_ice(sz_ptr_t output, sz_size_t length, sz_u64_t nonce) { // We can use `_mm512_broadcast_i32x4` and the `vbroadcasti32x4` instruction, but its latency is freaking 8 cycles. // The `_mm512_shuffle_i32x4` and the `vshufi32x4` instruction has a latency of 3 cycles, somewhat better. @@ -1544,16 +1594,6 @@ SZ_PUBLIC void sz_generate_ice(sz_ptr_t output, sz_size_t length, sz_u64_t nonce } } -SZ_PUBLIC void sz_hash_state_init_ice(sz_hash_state_t *state, sz_u64_t seed) { - sz_hash_state_init_skylake(state, seed); -} - -SZ_PUBLIC void sz_hash_state_stream_ice(sz_hash_state_t *state, sz_cptr_t text, sz_size_t length) { - sz_hash_state_stream_serial(state, text, length); -} - -SZ_PUBLIC sz_u64_t sz_hash_state_fold_ice(sz_hash_state_t const *state) { return sz_hash_state_fold_serial(state); } - #pragma clang attribute pop #pragma GCC pop_options #endif // SZ_USE_ICE From 80688bb11e136bb6227d1bc9089606620ea142c8 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Wed, 26 Feb 2025 17:25:22 +0000 Subject: [PATCH 130/751] Improve: Testing hash functions --- scripts/test.cpp | 145 ++++++++++++++++++++++++++++++++++++++++------- scripts/test.hpp | 26 ++++++++- 2 files changed, 151 insertions(+), 20 deletions(-) diff --git a/scripts/test.cpp b/scripts/test.cpp index e3a62f3d..13090d8c 100644 --- a/scripts/test.cpp +++ b/scripts/test.cpp @@ -1,27 +1,43 @@ +/** + * @brief Extensive @b unit-testing suite for StringZilla, written in C++. + * @note It tests one target hardware platform at a time and should be compiled and run separately for each. + * To override the default hardware platform, overrides the @b `SZ_USE_*` flags at the top of this file. + * + * @see Stress-tests on real-world and synthetic data are integrated into the @b `scripts/bench*.cpp` benchmarks. + * + * @file test.cpp + * @author Ash Vardanian + */ + +#include #undef NDEBUG // Enable all assertions -// Enable assertions for iterators +/* The Visual C++ run-time library detects incorrect iterator use, + * and asserts and displays a dialog box at run time on Windows. + */ #if !defined(_ITERATOR_DEBUG_LEVEL) || _ITERATOR_DEBUG_LEVEL == 0 #define _ITERATOR_DEBUG_LEVEL 1 #endif #include // assertions -// Overload the following with caution. -// Those parameters must never be explicitly set during releases, -// but they come handy during development, if you want to validate -// different ISA-specific implementations. +/** + * ! Overload the following with caution. + * ! Those parameters must never be explicitly set during releases, + * ! but they come handy during development, if you want to validate + * ! different ISA-specific implementations. 
+ */ // #define SZ_USE_HASWELL 0 +// #define SZ_USE_SKYLAKE 0 // #define SZ_USE_ICE 0 // #define SZ_USE_NEON 0 // #define SZ_USE_SVE 0 #define SZ_DEBUG 1 // Enforce aggressive logging for this unit. -// Put this at the top to make sure it pulls all the right dependencies #include #if defined(__SANITIZE_ADDRESS__) -#include // ASAN +#include // We use ASAN API to poison memory addresses #endif #include // `std::transform` @@ -133,6 +149,84 @@ static void test_arithmetical_utilities() { #endif } +/** + * @brief Several string processing operations rely on computing integer logarithms. + * Failures in such operations will result in wrong `resize` outcomes and heap corruption. + */ +static void test_hashing_on_platform( // + sz_hash_t hash_base, sz_hash_state_init_t init_base, // + sz_hash_state_stream_t stream_base, sz_hash_state_fold_t fold_base, // + sz_hash_t hash_simd, sz_hash_state_init_t init_simd, // + sz_hash_state_stream_t stream_simd, sz_hash_state_fold_t fold_simd) { + + auto test_on_seed = [&](std::string text, sz_u64_t seed) { + // Compute the entire hash at once, expecting the same output + sz_u64_t result_base = hash_base(text.data(), text.size(), seed); + sz_u64_t result_simd = hash_simd(text.data(), text.size(), seed); + assert(result_base == result_simd); + + // Compare incremental hashing across platforms + sz_hash_state_t state_base, state_simd; + init_base(&state_base, seed); + init_simd(&state_simd, seed); + assert(sz_hash_state_equal(&state_base, &state_base) == sz_true_k); // Self-equality + assert(sz_hash_state_equal(&state_simd, &state_simd) == sz_true_k); // Self-equality + assert(sz_hash_state_equal(&state_base, &state_simd) == sz_true_k); // Same across platforms + + // Try breaking those strings into arbitrary chunks, expecting the same output in the streaming mode. + // The length of each chunk and the number of chunks will be determined with a coin toss. 
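+        // Folding takes the state by `const` pointer, so it is safe to fold after every slice while the stream continues.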
+ iterate_in_random_slices(text, [&](std::string slice) { + stream_base(&state_base, slice.data(), slice.size()); + stream_simd(&state_simd, slice.data(), slice.size()); + assert(sz_hash_state_equal(&state_base, &state_simd) == sz_true_k); // Same across platforms + result_base = fold_base(&state_base); + result_simd = fold_simd(&state_simd); + assert(result_base == result_simd); + }); + }; + + // Let's try different-length strings repeating a "abc" pattern: + std::vector seeds = { + 0u, 42u, // + std::numeric_limits::max(), // + std::numeric_limits::max(), // + }; + for (auto seed : seeds) + for (std::size_t copies = 1; copies != 100; ++copies) // + test_on_seed(repeat("abc", copies), seed); +} + +static void test_hashing_across_platforms() { +#if SZ_USE_HASWELL + test_hashing_on_platform( // + sz_hash_serial, sz_hash_state_init_serial, // + sz_hash_state_stream_serial, sz_hash_state_fold_serial, // + sz_hash_haswell, sz_hash_state_init_haswell, // + sz_hash_state_stream_haswell, sz_hash_state_fold_haswell); +#endif +#if SZ_USE_SKYLAKE + test_hashing_on_platform( // + sz_hash_serial, sz_hash_state_init_serial, // + sz_hash_state_stream_serial, sz_hash_state_fold_serial, // + sz_hash_skylake, sz_hash_state_init_skylake, // + sz_hash_state_stream_skylake, sz_hash_state_fold_skylake); +#endif +#if SZ_USE_ICE + test_hashing_on_platform( // + sz_hash_serial, sz_hash_state_init_serial, // + sz_hash_state_stream_serial, sz_hash_state_fold_serial, // + sz_hash_ice, sz_hash_state_init_ice, // + sz_hash_state_stream_ice, sz_hash_state_fold_ice); +#endif +#if SZ_USE_NEON + test_hashing_on_platform( // + sz_hash_serial, sz_hash_state_init_serial, // + sz_hash_state_stream_serial, sz_hash_state_fold_serial, // + sz_hash_neon, sz_hash_state_init_neon, // + sz_hash_state_stream_neon, sz_hash_state_fold_neon); +#endif +}; + /** * @brief Tests various ASCII-based methods (e.g., `is_alpha`, `is_digit`) * provided by `sz::string` and `sz::string_view`. @@ -291,10 +385,10 @@ static void test_memory_utilities( // #if 0 // TODO: // We are going to randomly select the "source" and "target" slices of the strings. - // For `memcpy` and `memset` the offsets should have uniform ditribution, + // For `memcpy` and `memset` the offsets should have uniform distribution, // while the length should decay with an exponential distribution. // For `memmove` the offset should be uniform, but the "shift" and "length" should - // be exponenetial. The exponential distributions should be functions of the cache line width. + // be exponential. The exponential distributions should be functions of the cache line width. 
// https://en.cppreference.com/w/cpp/numeric/random/exponential_distribution std::string dataset(max_l2_size, '-'); auto &gen = global_random_generator(); @@ -953,13 +1047,13 @@ static void test_constructors() { strings.push_back(alphabet.substr(0, alphabet_slice)); std::vector copies {strings}; assert(copies.size() == strings.size()); - for (size_t i = 0; i < copies.size(); i++) { + for (size_t i = 0; i < copies.size(); ++i) { assert(copies[i].size() == strings[i].size()); assert(copies[i] == strings[i]); for (size_t j = 0; j < strings[i].size(); j++) { assert(copies[i][j] == strings[i][j]); } } std::vector assignments = strings; - for (size_t i = 0; i < assignments.size(); i++) { + for (size_t i = 0; i < assignments.size(); ++i) { assert(assignments[i].size() == strings[i].size()); assert(assignments[i] == strings[i]); for (size_t j = 0; j < strings[i].size(); j++) { assert(assignments[i][j] == strings[i][j]); } @@ -1027,12 +1121,12 @@ static void test_memory_stability_for_length(std::size_t len = 1ull << 10) { using string = sz::basic_string; string base; - for (std::size_t i = 0; i < len; i++) base.push_back('c'); + for (std::size_t i = 0; i < len; ++i) base.push_back('c'); assert(base.length() == len); // Do copies leak? assert_balanced_memory([&]() { - for (std::size_t i = 0; i < iterations; i++) { + for (std::size_t i = 0; i < iterations; ++i) { string copy(base); assert(copy.length() == len); assert(copy == base); @@ -1041,7 +1135,7 @@ static void test_memory_stability_for_length(std::size_t len = 1ull << 10) { // How about assignments? assert_balanced_memory([&]() { - for (std::size_t i = 0; i < iterations; i++) { + for (std::size_t i = 0; i < iterations; ++i) { string copy; copy = base; assert(copy.length() == len); @@ -1051,7 +1145,7 @@ static void test_memory_stability_for_length(std::size_t len = 1ull << 10) { // How about the move constructor? assert_balanced_memory([&]() { - for (std::size_t i = 0; i < iterations; i++) { + for (std::size_t i = 0; i < iterations; ++i) { string unique_item(base); assert(unique_item.length() == len); assert(unique_item == base); @@ -1063,7 +1157,7 @@ static void test_memory_stability_for_length(std::size_t len = 1ull << 10) { // And the move assignment operator with an empty target payload? assert_balanced_memory([&]() { - for (std::size_t i = 0; i < iterations; i++) { + for (std::size_t i = 0; i < iterations; ++i) { string unique_item(base); string copy; copy = std::move(unique_item); @@ -1074,7 +1168,7 @@ static void test_memory_stability_for_length(std::size_t len = 1ull << 10) { // And move assignment where the target had a payload? 
assert_balanced_memory([&]() { - for (std::size_t i = 0; i < iterations; i++) { + for (std::size_t i = 0; i < iterations; ++i) { string unique_item(base); string copy; for (std::size_t j = 0; j < 317; j++) copy.push_back('q'); @@ -1570,7 +1664,7 @@ void test_replacements(std::size_t lookup_tables_to_try = 128, std::size_t slice for (std::size_t lookup_table_variation = 0; lookup_table_variation != lookup_tables_to_try; ++lookup_table_variation) { sz::look_up_table lut; - for (std::size_t i = 0; i < 256; i++) lut[(char)i] = (char)(std::rand() % 256); + for (std::size_t i = 0; i < 256; ++i) lut[(char)i] = (char)(std::rand() % 256); for (std::size_t slice_idx = 0; slice_idx != slices_per_table; ++slice_idx) { std::size_t slice_offset = std::rand() % (body.length()); @@ -1597,7 +1691,7 @@ static void test_sequence_algorithms() { sz_sequence_t sequence; sz_cptr_t strings[] = {"banana", "apple", "cherry"}; sz_sequence_from_null_terminated_strings(strings, 3, &sequence); - assert(sequence.size == 3); + assert(sequence.count == 3); assert(sequence.get_start(sequence.handle, 0) == "banana"_sv); assert(sequence.get_start(sequence.handle, 1) == "apple"_sv); assert(sequence.get_start(sequence.handle, 2) == "cherry"_sv); @@ -1687,6 +1781,14 @@ static void test_stl_containers() { int main(int argc, char const **argv) { + sz_u128_vec_t some_state, some_key; + randomize_string((char *)&some_state.u8s[0], 16); + randomize_string((char *)&some_key.u8s[0], 16); + sz_u128_vec_t emulated_result = _sz_emulate_aesenc_si128_serial(some_state, some_key); + sz_u128_vec_t hardware_result; + hardware_result.xmm = _mm_aesenc_si128(some_state.xmm, some_key.xmm); + assert(memcmp(&emulated_result, &hardware_result, sizeof(sz_u128_vec_t)) == 0); + // Let's greet the user nicely sz_unused(argc && argv); std::printf("Hi, dear tester! You look nice today!\n"); @@ -1698,6 +1800,11 @@ int main(int argc, char const **argv) { // Basic utilities test_arithmetical_utilities(); + + // Compatibility across hardware-specific implementations + test_hashing_across_platforms(); + + // Core APIs test_ascii_utilities(); test_ascii_utilities(); test_memory_utilities(); diff --git a/scripts/test.hpp b/scripts/test.hpp index 6c37e9f6..ede983bc 100644 --- a/scripts/test.hpp +++ b/scripts/test.hpp @@ -1,5 +1,5 @@ /** - * @brief Helper structures and functions for C++ tests. + * @brief Helper structures and functions for C++ unit- and stress-tests. */ #pragma once #include // `std::ifstream` @@ -56,12 +56,36 @@ inline void randomize_string(char *string, std::size_t length, char const *alpha std::generate(string, string + length, [&]() -> char { return alphabet[distribution(global_random_generator())]; }); } +inline void randomize_string(char *string, std::size_t length) { + uniform_uint8_distribution_t distribution; + std::generate(string, string + length, [&]() -> char { return distribution(global_random_generator()); }); +} + inline std::string random_string(std::size_t length, char const *alphabet, std::size_t cardinality) { std::string result(length, '\0'); randomize_string(&result[0], length, alphabet, cardinality); return result; } +inline std::string repeat(std::string const &patten, std::size_t count) { + std::string result(patten.size() * count, '\0'); + for (std::size_t i = 0; i < count; ++i) std::copy(patten.begin(), patten.end(), result.begin() + i * patten.size()); + return result; +} + +/** + * @brief A callback type for iterating over consecutive random-length slices of a string. 
+ */ +template +inline void iterate_in_random_slices(std::string const &text, slice_callback_type_ &&slice_callback) { + std::size_t remaining = text.size(); + while (remaining > 0) { + std::size_t slice_length = std::uniform_int_distribution(1, remaining)(global_random_generator()); + slice_callback({text.data() + text.size() - remaining, slice_length}); + remaining -= slice_length; + } +} + /** * @brief Inefficient baseline Levenshtein distance computation, as implemented in most codebases. * Allocates a new matrix on every call, with rows potentially scattered around memory. From 6659aa0869582df0e611fbdcb457fd83aca0917e Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Wed, 26 Feb 2025 17:42:41 +0000 Subject: [PATCH 131/751] Add: PRNG for Haswell & serial backend --- include/stringzilla/hash.h | 169 ++++++++++++++++++++++++++++++------- scripts/test.cpp | 42 +++++++-- 2 files changed, 174 insertions(+), 37 deletions(-) diff --git a/include/stringzilla/hash.h b/include/stringzilla/hash.h index 50deb776..6ef11e3d 100644 --- a/include/stringzilla/hash.h +++ b/include/stringzilla/hash.h @@ -125,14 +125,14 @@ SZ_DYNAMIC sz_u64_t sz_hash(sz_cptr_t text, sz_size_t length, sz_u64_t seed); * @brief A Pseudorandom Number Generator (PRNG), inspired the AES-CTR-128 algorithm, * but using only one round of AES mixing as opposed to "NIST SP 800-90A". * - * CTR_DRBG (CounTeR mode Deterministic Random Bit Generator) appears secure and indistinguishable from a true - * random source when AES is used as the underlying block cipher and 112 bits are taken from this PRNG. + * CTR_DRBG (CounTeR mode Deterministic Random Bit Generator) appears secure and indistinguishable from a + * true random source when AES is used as the underlying block cipher and 112 bits are taken from this PRNG. * When AES is used as the underlying block cipher and 128 bits are taken from each instantiation, * the required security level is delivered with the caveat that a 128-bit cipher's output in * counter mode can be distinguished from a true RNG. * * In this case, it doesn't apply, as we only use one round of AES mixing. We also don't expose a separate "key", - * only a "nonce", to keep the API simple. + * only a "nonce", to keep the API simple, but we mix it with 512 bits of Pi constants to increase randomness. * * @param[out] text Output string buffer to be populated. * @param[in] length Number of bytes in the string. @@ -145,7 +145,7 @@ SZ_DYNAMIC sz_u64_t sz_hash(sz_cptr_t text, sz_size_t length, sz_u64_t seed); * int main() { * char first_buffer[5], second_buffer[5]; * sz_generate(first_buffer, 5, 0); - * sz_generate(second_buffer, 5, 0); //? Same nonce will produce the same output + * sz_generate(second_buffer, 5, 0); //? Same nonce must produce the same output * return sz_bytesum(first_buffer, 5) == sz_bytesum(second_buffer, 5) ? 
0 : 1; * } * @endcode @@ -705,7 +705,19 @@ SZ_PUBLIC sz_u64_t sz_hash_state_fold_serial(sz_hash_state_t const *state) { } SZ_PUBLIC void sz_generate_serial(sz_ptr_t text, sz_size_t length, sz_u64_t nonce) { - sz_unused(text && length && nonce); + sz_u64_t const *pi_ptr = _sz_hash_pi_constants(); + sz_u128_vec_t input_vec, pi_vec, key_vec, generated_vec; + for (sz_size_t lane_index = 0; length; ++lane_index) { + // Each 128-bit block is initialized with the same nonce + input_vec.u64s[0] = input_vec.u64s[1] = nonce + lane_index; + // We rotate the first 512-bits of the Pi to mix with the nonce + pi_vec = ((sz_u128_vec_t const *)pi_ptr)[lane_index % 4]; + key_vec.u64s[0] = nonce ^ pi_vec.u64s[0]; + key_vec.u64s[1] = nonce ^ pi_vec.u64s[1]; + generated_vec = _sz_emulate_aesenc_si128_serial(input_vec, key_vec); + // Export back to the user-supplied buffer + for (int i = 0; i < 16 && length; ++i, --length) *text++ = generated_vec.u8s[i]; + } } #pragma endregion // Serial Implementation @@ -1047,7 +1059,97 @@ SZ_PUBLIC sz_u64_t sz_hash_state_fold_haswell(sz_hash_state_t const *state) { } SZ_PUBLIC void sz_generate_haswell(sz_ptr_t text, sz_size_t length, sz_u64_t nonce) { - sz_generate_serial(text, length, nonce); + sz_u64_t const *pi_ptr = _sz_hash_pi_constants(); + if (length <= 16) { + __m128i input = _mm_set1_epi64x(nonce); + __m128i pi = _mm_load_si128((__m128i const *)pi_ptr); + __m128i key = _mm_xor_si128(_mm_set1_epi64x(nonce), pi); + __m128i generated = _mm_aesenc_si128(input, key); + // Now the tricky part is outputting this data to the user-supplied buffer + // without masked writes, like in AVX-512. + for (sz_size_t i = 0; i < length; ++i) text[i] = ((sz_u8_t *)&generated)[i]; + } + // Assuming the YMM register contains two 128-bit blocks, the input to the generator + // will be more complex, containing the sum of the nonce and the block number. + else if (length <= 32) { + __m128i inputs[2], pis[2], keys[2], generated[2]; + inputs[0] = _mm_set1_epi64x(nonce); + inputs[1] = _mm_set1_epi64x(nonce + 1); + pis[0] = _mm_load_si128((__m128i const *)(pi_ptr)); + pis[1] = _mm_load_si128((__m128i const *)(pi_ptr + 2)); + keys[0] = _mm_xor_si128(_mm_set1_epi64x(nonce), pis[0]); + keys[1] = _mm_xor_si128(_mm_set1_epi64x(nonce), pis[1]); + generated[0] = _mm_aesenc_si128(inputs[0], keys[0]); + generated[1] = _mm_aesenc_si128(inputs[1], keys[1]); + // The first store can easily be vectorized, but the second can be serial for now + _mm_storeu_si128((__m128i *)text, generated[0]); + for (sz_size_t i = 16; i < length; ++i) text[i] = ((sz_u8_t *)&generated[1])[i - 16]; + } + // The last special case we handle outside of the primary loop is for buffers up to 64 bytes long. 
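+    // Three 128-bit blocks are generated in this branch, covering inputs of up to 48 bytes; longer inputs fall through to the 64-byte main loop below.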
+ else if (length <= 48) { + __m128i inputs[3], pis[3], keys[3], generated[3]; + inputs[0] = _mm_set1_epi64x(nonce); + inputs[1] = _mm_set1_epi64x(nonce + 1); + inputs[2] = _mm_set1_epi64x(nonce + 2); + pis[0] = _mm_load_si128((__m128i const *)(pi_ptr)); + pis[1] = _mm_load_si128((__m128i const *)(pi_ptr + 2)); + pis[2] = _mm_load_si128((__m128i const *)(pi_ptr + 4)); + keys[0] = _mm_xor_si128(_mm_set1_epi64x(nonce), pis[0]); + keys[1] = _mm_xor_si128(_mm_set1_epi64x(nonce), pis[1]); + keys[2] = _mm_xor_si128(_mm_set1_epi64x(nonce), pis[2]); + generated[0] = _mm_aesenc_si128(inputs[0], keys[0]); + generated[1] = _mm_aesenc_si128(inputs[1], keys[1]); + generated[2] = _mm_aesenc_si128(inputs[2], keys[2]); + // The first store can easily be vectorized, but the second can be serial for now + _mm_storeu_si128((__m128i *)text, generated[0]); + _mm_storeu_si128((__m128i *)(text + 16), generated[1]); + for (sz_size_t i = 32; i < length; ++i) text[i] = ((sz_u8_t *)generated)[i]; + } + // The final part of the function is the primary loop, which processes the buffer in 64-byte chunks. + else { + __m128i inputs[4], pis[4], keys[4], generated[4]; + inputs[0] = _mm_set1_epi64x(nonce); + inputs[1] = _mm_set1_epi64x(nonce + 1); + inputs[2] = _mm_set1_epi64x(nonce + 2); + inputs[3] = _mm_set1_epi64x(nonce + 3); + // Load parts of PI into the registers + pis[0] = _mm_load_si128((__m128i const *)(pi_ptr)); + pis[1] = _mm_load_si128((__m128i const *)(pi_ptr + 2)); + pis[2] = _mm_load_si128((__m128i const *)(pi_ptr + 4)); + pis[3] = _mm_load_si128((__m128i const *)(pi_ptr + 6)); + // XOR the nonce with the PI constants + keys[0] = _mm_xor_si128(_mm_set1_epi64x(nonce), pis[0]); + keys[1] = _mm_xor_si128(_mm_set1_epi64x(nonce), pis[1]); + keys[2] = _mm_xor_si128(_mm_set1_epi64x(nonce), pis[2]); + keys[3] = _mm_xor_si128(_mm_set1_epi64x(nonce), pis[3]); + + // Produce the output, fixing the key and enumerating input chunks. + sz_size_t i = 0; + __m128i const increment = _mm_set1_epi64x(4); + for (; i + 64 <= length; i += 64) { + generated[0] = _mm_aesenc_si128(inputs[0], keys[0]); + generated[1] = _mm_aesenc_si128(inputs[1], keys[1]); + generated[2] = _mm_aesenc_si128(inputs[2], keys[2]); + generated[3] = _mm_aesenc_si128(inputs[3], keys[3]); + _mm_storeu_si128((__m128i *)(text + i), generated[0]); + _mm_storeu_si128((__m128i *)(text + i + 16), generated[1]); + _mm_storeu_si128((__m128i *)(text + i + 32), generated[2]); + _mm_storeu_si128((__m128i *)(text + i + 48), generated[3]); + inputs[0] = _mm_add_epi64(inputs[0], increment); + inputs[1] = _mm_add_epi64(inputs[1], increment); + inputs[2] = _mm_add_epi64(inputs[2], increment); + inputs[3] = _mm_add_epi64(inputs[3], increment); + } + + // Handle the tail of the buffer. + { + generated[0] = _mm_aesenc_si128(inputs[0], keys[0]); + generated[1] = _mm_aesenc_si128(inputs[1], keys[1]); + generated[2] = _mm_aesenc_si128(inputs[2], keys[2]); + generated[3] = _mm_aesenc_si128(inputs[3], keys[3]); + for (sz_size_t j = 0; i < length; ++i, ++j) text[i] = ((sz_u8_t *)generated)[j]; + } + } } #pragma clang attribute pop @@ -1280,6 +1382,7 @@ SZ_PUBLIC void sz_hash_state_stream_skylake(sz_hash_state_t *state, sz_cptr_t te } SZ_PUBLIC sz_u64_t sz_hash_state_fold_skylake(sz_hash_state_t const *state) { + //? We don't know a better way to fold the state on Ice Lake, than to use the Haswell implementation. 
return sz_hash_state_fold_haswell(state); } @@ -1539,58 +1642,62 @@ SZ_PUBLIC void sz_hash_state_stream_ice(sz_hash_state_t *state, sz_cptr_t text, } } -SZ_PUBLIC sz_u64_t sz_hash_state_fold_ice(sz_hash_state_t const *state) { return sz_hash_state_fold_haswell(state); } +SZ_PUBLIC sz_u64_t sz_hash_state_fold_ice(sz_hash_state_t const *state) { + //? We don't know a better way to fold the state on Ice Lake, than to use the Haswell implementation. + return sz_hash_state_fold_haswell(state); +} SZ_PUBLIC void sz_generate_ice(sz_ptr_t output, sz_size_t length, sz_u64_t nonce) { - // We can use `_mm512_broadcast_i32x4` and the `vbroadcasti32x4` instruction, but its latency is freaking 8 cycles. - // The `_mm512_shuffle_i32x4` and the `vshufi32x4` instruction has a latency of 3 cycles, somewhat better. - // The `_mm512_permutex_epi64` and the `vpermq` instruction also has a latency of 3 cycles. - // So we want to avoid that, if possible. - __m128i nonce_vec = _mm_set1_epi64x(nonce); - __m128i key128 = _mm_xor_si128(nonce_vec, _mm_set_epi64x(0x13198a2e03707344ull, 0x243f6a8885a308d3ull)); if (length <= 16) { - __mmask16 mask = _sz_u16_mask_until(length); __m128i input = _mm_set1_epi64x(nonce); - __m128i generated = _mm_aesenc_si128(input, key128); - _mm_mask_storeu_epi8((void *)output, mask, generated); + __m128i pi = _mm_load_si128((__m128i const *)_sz_hash_pi_constants()); + __m128i key = _mm_xor_si128(_mm_set1_epi64x(nonce), pi); + __m128i generated = _mm_aesenc_si128(input, key); + __mmask16 store_mask = _sz_u16_mask_until(length); + _mm_mask_storeu_epi8((void *)output, store_mask, generated); } // Assuming the YMM register contains two 128-bit blocks, the input to the generator // will be more complex, containing the sum of the nonce and the block number. else if (length <= 32) { - __mmask32 mask = _sz_u32_mask_until(length); __m256i input = _mm256_set_epi64x(nonce + 1, nonce + 1, nonce, nonce); - __m256i key256 = - _mm256_permute2x128_si256(_mm256_castsi128_si256(key128), _mm256_castsi128_si256(key128), 0x00); - __m256i generated = _mm256_aesenc_epi128(input, key256); - _mm256_mask_storeu_epi8((void *)output, mask, generated); + __m256i pi = _mm256_load_si256((__m256i const *)_sz_hash_pi_constants()); + __m256i key = _mm256_xor_si256(_mm256_set1_epi64x(nonce), pi); + __m256i generated = _mm256_aesenc_epi128(input, key); + __mmask32 store_mask = _sz_u32_mask_until(length); + _mm256_mask_storeu_epi8((void *)output, store_mask, generated); } // The last special case we handle outside of the primary loop is for buffers up to 64 bytes long. else if (length <= 64) { - __mmask64 mask = _sz_u64_mask_until(length); __m512i input = _mm512_set_epi64( // nonce + 3, nonce + 3, nonce + 2, nonce + 2, // nonce + 1, nonce + 1, nonce, nonce); - __m512i key512 = _mm512_permutex_epi64(_mm512_castsi128_si512(key128), 0x00); - __m512i generated = _mm512_aesenc_epi128(input, key512); - _mm512_mask_storeu_epi8((void *)output, mask, generated); + __m512i pi = _mm512_load_si512((__m512i const *)_sz_hash_pi_constants()); + __m512i key = _mm512_xor_si512(_mm512_set1_epi64(nonce), pi); + __m512i generated = _mm512_aesenc_epi128(input, key); + __mmask64 store_mask = _sz_u64_mask_until(length); + _mm512_mask_storeu_epi8((void *)output, store_mask, generated); } // The final part of the function is the primary loop, which processes the buffer in 64-byte chunks. 
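    // Each 512-bit register packs four 128-bit counter blocks, so one AES round per iteration yields 64 bytes of output.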
else { - __m512i increment = _mm512_set1_epi64(4); + __m512i const increment = _mm512_set1_epi64(4); __m512i input = _mm512_set_epi64( // nonce + 3, nonce + 3, nonce + 2, nonce + 2, // nonce + 1, nonce + 1, nonce, nonce); - __m512i key512 = _mm512_permutex_epi64(_mm512_castsi128_si512(key128), 0x00); + __m512i const pi = _mm512_load_si512((__m512i const *)_sz_hash_pi_constants()); + __m512i const key = _mm512_xor_si512(_mm512_set1_epi64(nonce), pi); + + // Produce the output, fixing the key and enumerating input chunks. sz_size_t i = 0; for (; i + 64 <= length; i += 64) { - __m512i generated = _mm512_aesenc_epi128(input, key512); + __m512i generated = _mm512_aesenc_epi128(input, key); _mm512_storeu_epi8((void *)(output + i), generated); input = _mm512_add_epi64(input, increment); } + // Handle the tail of the buffer. - __mmask64 mask = _sz_u64_mask_until(length - i); - __m512i generated = _mm512_aesenc_epi128(input, key512); - _mm512_mask_storeu_epi8((void *)(output + i), mask, generated); + __m512i generated = _mm512_aesenc_epi128(input, key); + __mmask64 store_mask = _sz_u64_mask_until(length - i); + _mm512_mask_storeu_epi8((void *)(output + i), store_mask, generated); } } diff --git a/scripts/test.cpp b/scripts/test.cpp index 13090d8c..8465cf15 100644 --- a/scripts/test.cpp +++ b/scripts/test.cpp @@ -150,8 +150,10 @@ static void test_arithmetical_utilities() { } /** - * @brief Several string processing operations rely on computing integer logarithms. - * Failures in such operations will result in wrong `resize` outcomes and heap corruption. + * @brief Hashes a string and compares the output between a serial and hardware-specific SIMD backend. + * + * The test covers increasingly long and complex strings, starting with "abcabc..." repetitions and + * progressing towards corner cases like empty strings, all-zero inputs, zero seeds, and so on. */ static void test_hashing_on_platform( // sz_hash_t hash_base, sz_hash_state_init_t init_base, // @@ -196,13 +198,40 @@ static void test_hashing_on_platform( // test_on_seed(repeat("abc", copies), seed); } -static void test_hashing_across_platforms() { +/** + * @brief Tests Pseudo-Random Number Generators (PRNGs) ensuring that the same nonce + * produces exactly the same output across different SIMD implementations. 
+ */ +static void test_random_generator_on_platform(sz_generate_t generate_base, sz_generate_t generate_simd) { + + auto test_on_nonce = [&](std::size_t length, sz_u64_t nonce) { + std::string text_base(length, '\0'); + std::string text_simd(length, '\0'); + generate_base(&text_base[0], static_cast(length), nonce); + generate_simd(&text_simd[0], static_cast(length), nonce); + assert(text_base == text_simd); + }; + + // Let's try different nonces: + std::vector nonces = { + 0u, 42u, // + std::numeric_limits::max(), // + std::numeric_limits::max(), // + }; + std::vector lengths = {1, 11, 23, 37, 40, 51, 64, 128, 1000}; + for (auto nonce : nonces) + for (auto length : lengths) // + test_on_nonce(length, nonce); +} + +static void test_simd_against_serial() { #if SZ_USE_HASWELL test_hashing_on_platform( // sz_hash_serial, sz_hash_state_init_serial, // sz_hash_state_stream_serial, sz_hash_state_fold_serial, // sz_hash_haswell, sz_hash_state_init_haswell, // sz_hash_state_stream_haswell, sz_hash_state_fold_haswell); + test_random_generator_on_platform(sz_generate_serial, sz_generate_haswell); #endif #if SZ_USE_SKYLAKE test_hashing_on_platform( // @@ -210,6 +239,7 @@ static void test_hashing_across_platforms() { sz_hash_state_stream_serial, sz_hash_state_fold_serial, // sz_hash_skylake, sz_hash_state_init_skylake, // sz_hash_state_stream_skylake, sz_hash_state_fold_skylake); + test_random_generator_on_platform(sz_generate_serial, sz_generate_skylake); #endif #if SZ_USE_ICE test_hashing_on_platform( // @@ -217,6 +247,7 @@ static void test_hashing_across_platforms() { sz_hash_state_stream_serial, sz_hash_state_fold_serial, // sz_hash_ice, sz_hash_state_init_ice, // sz_hash_state_stream_ice, sz_hash_state_fold_ice); + test_random_generator_on_platform(sz_generate_serial, sz_generate_ice); #endif #if SZ_USE_NEON test_hashing_on_platform( // @@ -224,6 +255,7 @@ static void test_hashing_across_platforms() { sz_hash_state_stream_serial, sz_hash_state_fold_serial, // sz_hash_neon, sz_hash_state_init_neon, // sz_hash_state_stream_neon, sz_hash_state_fold_neon); + test_random_generator_on_platform(sz_generate_serial, sz_generate_neon); #endif }; @@ -1800,9 +1832,7 @@ int main(int argc, char const **argv) { // Basic utilities test_arithmetical_utilities(); - - // Compatibility across hardware-specific implementations - test_hashing_across_platforms(); + test_simd_against_serial(); // Core APIs test_ascii_utilities(); From 2bbafa111426cab613053665a67cd41d44154920 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Thu, 27 Feb 2025 22:55:19 +0000 Subject: [PATCH 132/751] Make: Drop unused `build.sh` --- scripts/build.sh | 65 ------------------------------------------------ 1 file changed, 65 deletions(-) delete mode 100755 scripts/build.sh diff --git a/scripts/build.sh b/scripts/build.sh deleted file mode 100755 index 600e5758..00000000 --- a/scripts/build.sh +++ /dev/null @@ -1,65 +0,0 @@ -#!/bin/bash -# This Bash script compiles the CMake-based project with different compilers for different verrsions of C++ -# This is what should happen if only GCC 12 is installed and we are running on Sapphire Rapids. 
-# -# cmake -DCMAKE_BUILD_TYPE=Release -DSTRINGZILLA_BUILD_BENCHMARK=1 \ -# -DCMAKE_CXX_COMPILER=g++-12 -DCMAKE_C_COMPILER=gcc-12 \ -# -DSTRINGZILLA_TARGET_ARCH="sandybridge" -B build_release/gcc-12-sandybridge && \ -# cmake --build build_release/gcc-12-sandybridge --config Release -# cmake -DCMAKE_BUILD_TYPE=Release -DSTRINGZILLA_BUILD_BENCHMARK=1 \ -# -DCMAKE_CXX_COMPILER=g++-12 -DCMAKE_C_COMPILER=gcc-12 \ -# -DSTRINGZILLA_TARGET_ARCH="haswell" -B build_release/gcc-12-haswell && \ -# cmake --build build_release/gcc-12-haswell --config Release -# cmake -DCMAKE_BUILD_TYPE=Release -DSTRINGZILLA_BUILD_BENCHMARK=1 \ -# -DCMAKE_CXX_COMPILER=g++-12 -DCMAKE_C_COMPILER=gcc-12 \ -# -DSTRINGZILLA_TARGET_ARCH="sapphirerapids" -B build_release/gcc-12-sapphirerapids && \ -# cmake --build build_release/gcc-12-sapphirerapids --config Release - -# Array of target architectures -declare -a architectures=("sandybridge" "haswell" "sapphirerapids") - -# Function to get installed versions of a compiler -get_versions() { - local compiler_prefix=$1 - local versions=() - - echo "Checking for compilers in /usr/bin with prefix: $compiler_prefix" - - # Check if the directory /usr/bin exists and is a directory - if [ -d "/usr/bin" ]; then - for version in /usr/bin/${compiler_prefix}-*; do - echo "Checking: $version" - if [[ -x "$version" ]]; then - local ver=${version##*-} - echo "Found compiler version: $ver" - versions+=("$ver") - fi - done - else - echo "/usr/bin does not exist or is not a directory" - fi - - echo ${versions[@]} -} - -# Get installed versions of GCC and Clang -gcc_versions=$(get_versions gcc) -clang_versions=$(get_versions clang) - -# Compile for each combination of compiler and architecture -for arch in "${ARCHS[@]}"; do - for gcc_version in $gcc_versions; do - cmake -DCMAKE_BUILD_TYPE=Release -DSTRINGZILLA_BUILD_BENCHMARK=1 \ - -DCMAKE_CXX_COMPILER=g++-$gcc_version -DCMAKE_C_COMPILER=gcc-$gcc_version \ - -DSTRINGZILLA_TARGET_ARCH="$arch" -B "build_release/gcc-$gcc_version-$arch" && \ - cmake --build "build_release/gcc-$gcc_version-$arch" --config Release - done - - for clang_version in $clang_versions; do - cmake -DCMAKE_BUILD_TYPE=Release -DSTRINGZILLA_BUILD_BENCHMARK=1 \ - -DCMAKE_CXX_COMPILER=clang++-$clang_version -DCMAKE_C_COMPILER=clang-$clang_version \ - -DSTRINGZILLA_TARGET_ARCH="$arch" -B "build_release/clang-$clang_version-$arch" && \ - cmake --build "build_release/clang-$clang_version-$arch" --config Release - done -done - From 3538e970e09d9b19dc52ed93bd1508cd0bb9d6d6 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Fri, 28 Feb 2025 14:43:27 +0000 Subject: [PATCH 133/751] Add: Fetching dynamic library version in C Added `sz_version_major`, `sz_version_minor`, and `sz_version_patch` APIs --- .github/workflows/prerelease.yml | 6 +++--- .github/workflows/release.yml | 6 +++--- c/lib.c | 4 ++++ include/stringzilla/stringzilla.h | 20 +++++++++++++++++--- 4 files changed, 27 insertions(+), 9 deletions(-) diff --git a/.github/workflows/prerelease.yml b/.github/workflows/prerelease.yml index 57514b79..82cace09 100644 --- a/.github/workflows/prerelease.yml +++ b/.github/workflows/prerelease.yml @@ -37,11 +37,11 @@ jobs: package.json:"version": "(\d+\.\d+\.\d+)" CMakeLists.txt:VERSION (\d+\.\d+\.\d+) update-major-version-in: | - include/stringzilla/stringzilla.h:^#define STRINGZILLA_VERSION_MAJOR (\d+) + include/stringzilla/stringzilla.h:^#define STRINGZILLA_H_VERSION_MAJOR (\d+) update-minor-version-in: | - 
include/stringzilla/stringzilla.h:^#define STRINGZILLA_VERSION_MINOR (\d+) + include/stringzilla/stringzilla.h:^#define STRINGZILLA_H_VERSION_MINOR (\d+) update-patch-version-in: | - include/stringzilla/stringzilla.h:^#define STRINGZILLA_VERSION_PATCH (\d+) + include/stringzilla/stringzilla.h:^#define STRINGZILLA_H_VERSION_PATCH (\d+) dry-run: "true" test_ubuntu_gcc: diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 6a726b14..1c95500b 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -37,11 +37,11 @@ jobs: package.json:"version": "(\d+\.\d+\.\d+)" CMakeLists.txt:VERSION (\d+\.\d+\.\d+) update-major-version-in: | - include/stringzilla/stringzilla.h:^#define STRINGZILLA_VERSION_MAJOR (\d+) + include/stringzilla/stringzilla.h:^#define STRINGZILLA_H_VERSION_MAJOR (\d+) update-minor-version-in: | - include/stringzilla/stringzilla.h:^#define STRINGZILLA_VERSION_MINOR (\d+) + include/stringzilla/stringzilla.h:^#define STRINGZILLA_H_VERSION_MINOR (\d+) update-patch-version-in: | - include/stringzilla/stringzilla.h:^#define STRINGZILLA_VERSION_PATCH (\d+) + include/stringzilla/stringzilla.h:^#define STRINGZILLA_H_VERSION_PATCH (\d+) dry-run: "false" push: "true" create-release: "true" diff --git a/c/lib.c b/c/lib.c index b65784cc..555f995a 100644 --- a/c/lib.c +++ b/c/lib.c @@ -391,6 +391,10 @@ BOOL WINAPI _DllMainCRTStartup(HINSTANCE hints, DWORD forward_reason, LPVOID lp) __attribute__((constructor)) static void sz_dispatch_table_init_on_gcc_or_clang(void) { sz_dispatch_table_init(); } #endif +SZ_DYNAMIC int sz_version_major(void) { return STRINGZILLA_H_VERSION_MAJOR; } +SZ_DYNAMIC int sz_version_minor(void) { return STRINGZILLA_H_VERSION_MINOR; } +SZ_DYNAMIC int sz_version_patch(void) { return STRINGZILLA_H_VERSION_PATCH; } + SZ_DYNAMIC sz_u64_t sz_bytesum(sz_cptr_t text, sz_size_t length) { return sz_dispatch_table.bytesum(text, length); } SZ_DYNAMIC sz_u64_t sz_hash(sz_cptr_t text, sz_size_t length, sz_u64_t seed) { diff --git a/include/stringzilla/stringzilla.h b/include/stringzilla/stringzilla.h index 660ffa6c..284754bd 100644 --- a/include/stringzilla/stringzilla.h +++ b/include/stringzilla/stringzilla.h @@ -37,9 +37,9 @@ #ifndef STRINGZILLA_H_ #define STRINGZILLA_H_ -#define STRINGZILLA_VERSION_MAJOR 3 -#define STRINGZILLA_VERSION_MINOR 11 -#define STRINGZILLA_VERSION_PATCH 3 +#define STRINGZILLA_H_VERSION_MAJOR 3 +#define STRINGZILLA_H_VERSION_MINOR 11 +#define STRINGZILLA_H_VERSION_PATCH 3 #include "types.h" // `sz_size_t`, `sz_bool_t`, `sz_ordering_t` #include "compare.h" // `sz_equal`, `sz_order` @@ -79,6 +79,20 @@ typedef enum { */ SZ_DYNAMIC sz_capability_t sz_capabilities(void); +#if defined(SZ_DYNAMIC_DISPATCH) + +SZ_DYNAMIC int sz_version_major(void); +SZ_DYNAMIC int sz_version_minor(void); +SZ_DYNAMIC int sz_version_patch(void); + +#else + +SZ_PUBLIC int sz_version_major(void) { return STRINGZILLA_H_VERSION_MAJOR; } +SZ_PUBLIC int sz_version_minor(void) { return STRINGZILLA_H_VERSION_MINOR; } +SZ_PUBLIC int sz_version_patch(void) { return STRINGZILLA_H_VERSION_PATCH; } + +#endif + #ifdef __cplusplus } #endif // __cplusplus From 2ce2b49efd696992571074954942d0ad7ec85847 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Fri, 28 Feb 2025 14:47:49 +0000 Subject: [PATCH 134/751] Break: `charset`/`generate` -> `byteset`/`fill_random` --- c/lib.c | 163 ++++++++------------ include/stringzilla/find.h | 231 ++++++++++++++-------------- include/stringzilla/hash.h | 95 +++++++----- 
include/stringzilla/similarity.h | 10 +- include/stringzilla/stringzilla.h | 2 +- include/stringzilla/stringzilla.hpp | 178 ++++++++++----------- include/stringzilla/types.h | 75 +++++---- scripts/bench_search.cpp | 42 ++--- scripts/test.cpp | 78 +++++----- 9 files changed, 428 insertions(+), 446 deletions(-) diff --git a/c/lib.c b/c/lib.c index 555f995a..f742ad2b 100644 --- a/c/lib.c +++ b/c/lib.c @@ -94,14 +94,12 @@ SZ_INTERNAL sz_capability_t _sz_capabilities_arm(void) { // - 0b0010: SVE2.1 is implemented // This value must match the existing indicator obtained from ID_AA64PFR0_EL1: unsigned supports_sve2 = ((id_aa64zfr0_el1) & 0xF) >= 1; - unsigned supports_sve2p1 = ((id_aa64zfr0_el1) & 0xF) >= 2; unsigned supports_neon = 1; // NEON is always supported - return (sz_capability_t)( // - (sz_cap_neon_k * (supports_neon)) | // - (sz_cap_sve_k * (supports_sve)) | // - (sz_cap_sve2_k * (supports_sve2)) | // - (sz_cap_sve2p1_k * (supports_sve2p1)) | // + return (sz_capability_t)( // + (sz_cap_neon_k * (supports_neon)) | // + (sz_cap_sve_k * (supports_sve)) | // + (sz_cap_sve2_k * (supports_sve2)) | // (sz_cap_serial_k)); #else // if !defined(_SZ_IS_APPLE) && !defined(_SZ_IS_LINUX) @@ -183,7 +181,7 @@ typedef struct sz_implementations_t { sz_hash_state_init_t hash_state_init; sz_hash_state_stream_t hash_state_stream; sz_hash_state_fold_t hash_state_fold; - sz_generate_t generate; + sz_fill_random_t fill_random; sz_find_byte_t find_byte; sz_find_byte_t rfind_byte; @@ -196,9 +194,8 @@ typedef struct sz_implementations_t { sz_needleman_wunsch_score_t alignment_score; sz_sequence_argsort_t sequence_argsort; + sz_sequence_join_t sequence_join; sz_pgrams_sort_t pgrams_sort; - sz_sequence_argsort_stable_t sequence_argsort_stable; - sz_pgrams_sort_stable_t pgrams_sort_stable; } sz_implementations_t; @@ -229,22 +226,21 @@ SZ_DYNAMIC void sz_dispatch_table_init(void) { impl->hash_state_init = sz_hash_state_init_serial; impl->hash_state_stream = sz_hash_state_stream_serial; impl->hash_state_fold = sz_hash_state_fold_serial; - impl->generate = sz_generate_serial; + impl->fill_random = sz_fill_random_serial; impl->find = sz_find_serial; impl->rfind = sz_rfind_serial; impl->find_byte = sz_find_byte_serial; impl->rfind_byte = sz_rfind_byte_serial; - impl->find_from_set = sz_find_charset_serial; - impl->rfind_from_set = sz_rfind_charset_serial; + impl->find_from_set = sz_find_byteset_serial; + impl->rfind_from_set = sz_rfind_byteset_serial; impl->edit_distance = sz_levenshtein_distance_serial; impl->alignment_score = sz_needleman_wunsch_score_serial; impl->sequence_argsort = sz_sequence_argsort_serial; + impl->sequence_join = sz_sequence_join_serial; impl->pgrams_sort = sz_pgrams_sort_serial; - impl->sequence_argsort_stable = sz_sequence_argsort_stable_serial; - impl->pgrams_sort_stable = sz_pgrams_sort_stable_serial; #if SZ_USE_HASWELL if (caps & sz_cap_haswell_k) { @@ -261,14 +257,14 @@ SZ_DYNAMIC void sz_dispatch_table_init(void) { impl->hash_state_init = sz_hash_state_init_haswell; impl->hash_state_stream = sz_hash_state_stream_haswell; impl->hash_state_fold = sz_hash_state_fold_haswell; - impl->generate = sz_generate_haswell; + impl->fill_random = sz_fill_random_haswell; impl->find_byte = sz_find_byte_haswell; impl->rfind_byte = sz_rfind_byte_haswell; impl->find = sz_find_haswell; impl->rfind = sz_rfind_haswell; - impl->find_from_set = sz_find_charset_haswell; - impl->rfind_from_set = sz_rfind_charset_haswell; + impl->find_from_set = sz_find_byteset_haswell; + impl->rfind_from_set = 
sz_rfind_byteset_haswell; } #endif @@ -286,20 +282,24 @@ SZ_DYNAMIC void sz_dispatch_table_init(void) { impl->hash_state_init = sz_hash_state_init_skylake; impl->hash_state_stream = sz_hash_state_stream_skylake; impl->hash_state_fold = sz_hash_state_fold_skylake; - impl->generate = sz_generate_skylake; + impl->fill_random = sz_fill_random_skylake; impl->find = sz_find_skylake; impl->rfind = sz_rfind_skylake; impl->find_byte = sz_find_byte_skylake; impl->rfind_byte = sz_rfind_byte_skylake; impl->bytesum = sz_bytesum_skylake; + + impl->sequence_argsort = sz_sequence_argsort_skylake; + impl->sequence_join = sz_sequence_join_skylake; + impl->pgrams_sort = sz_pgrams_sort_skylake; } #endif #if SZ_USE_ICE if (caps & sz_cap_ice_k) { - impl->find_from_set = sz_find_charset_ice; - impl->rfind_from_set = sz_rfind_charset_ice; + impl->find_from_set = sz_find_byteset_ice; + impl->rfind_from_set = sz_rfind_byteset_ice; impl->edit_distance = sz_levenshtein_distance_ice; impl->alignment_score = sz_needleman_wunsch_score_ice; @@ -311,12 +311,7 @@ SZ_DYNAMIC void sz_dispatch_table_init(void) { impl->hash_state_init = sz_hash_state_init_ice; impl->hash_state_stream = sz_hash_state_stream_ice; impl->hash_state_fold = sz_hash_state_fold_ice; - impl->generate = sz_generate_ice; - - impl->sequence_argsort = sz_sequence_argsort_ice; - impl->pgrams_sort = sz_pgrams_sort_ice; - impl->sequence_argsort_stable = sz_sequence_argsort_stable_ice; - impl->pgrams_sort_stable = sz_pgrams_sort_stable_ice; + impl->fill_random = sz_fill_random_ice; } #endif @@ -334,23 +329,22 @@ SZ_DYNAMIC void sz_dispatch_table_init(void) { impl->hash_state_init = sz_hash_state_init_neon; impl->hash_state_stream = sz_hash_state_stream_neon; impl->hash_state_fold = sz_hash_state_fold_neon; - impl->generate = sz_generate_neon; + impl->fill_random = sz_fill_random_neon; impl->find = sz_find_neon; impl->rfind = sz_rfind_neon; impl->find_byte = sz_find_byte_neon; impl->rfind_byte = sz_rfind_byte_neon; - impl->find_from_set = sz_find_charset_neon; - impl->rfind_from_set = sz_rfind_charset_neon; + impl->find_from_set = sz_find_byteset_neon; + impl->rfind_from_set = sz_rfind_byteset_neon; } #endif #if SZ_USE_SVE if (caps & sz_cap_sve_k) { impl->sequence_argsort = sz_sequence_argsort_sve; + impl->sequence_join = sz_sequence_join_sve; impl->pgrams_sort = sz_pgrams_sort_sve; - impl->sequence_argsort_stable = sz_sequence_argsort_stable_sve; - impl->pgrams_sort_stable = sz_pgrams_sort_stable_sve; } #endif } @@ -413,8 +407,8 @@ SZ_DYNAMIC sz_u64_t sz_hash_state_fold(sz_hash_state_t const *state) { return sz_dispatch_table.hash_state_fold(state); } -SZ_DYNAMIC void sz_generate(sz_ptr_t result, sz_size_t result_length, sz_u64_t nonce) { - sz_dispatch_table.generate(result, result_length, nonce); +SZ_DYNAMIC void sz_fill_random(sz_ptr_t result, sz_size_t result_length, sz_u64_t nonce) { + sz_dispatch_table.fill_random(result, result_length, nonce); } SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { @@ -457,51 +451,47 @@ SZ_DYNAMIC sz_cptr_t sz_rfind(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t return sz_dispatch_table.rfind(haystack, h_length, needle, n_length); } -SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { +SZ_DYNAMIC sz_cptr_t sz_find_byteset(sz_cptr_t text, sz_size_t length, sz_byteset_t const *set) { return sz_dispatch_table.find_from_set(text, length, set); } -SZ_DYNAMIC sz_cptr_t sz_rfind_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { +SZ_DYNAMIC 
sz_cptr_t sz_rfind_byteset(sz_cptr_t text, sz_size_t length, sz_byteset_t const *set) { return sz_dispatch_table.rfind_from_set(text, length, set); } -SZ_DYNAMIC sz_size_t sz_hamming_distance( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - return sz_hamming_distance_serial(a, a_length, b, b_length, bound); +SZ_DYNAMIC sz_status_t sz_hamming_distance( // + sz_cptr_t a, sz_size_t a_length, // + sz_cptr_t b, sz_size_t b_length, // + sz_size_t bound, sz_size_t *result) { + return sz_hamming_distance_serial(a, a_length, b, b_length, bound, result); } -SZ_DYNAMIC sz_size_t sz_hamming_distance_utf8( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound) { - return sz_hamming_distance_utf8_serial(a, a_length, b, b_length, bound); -} - -SZ_DYNAMIC sz_size_t sz_levenshtein_distance( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - return sz_dispatch_table.edit_distance(a, a_length, b, b_length, bound, alloc); +SZ_DYNAMIC sz_status_t sz_hamming_distance_utf8( // + sz_cptr_t a, sz_size_t a_length, // + sz_cptr_t b, sz_size_t b_length, // + sz_size_t bound, sz_size_t *result) { + return sz_hamming_distance_utf8_serial(a, a_length, b, b_length, bound, result); } -SZ_DYNAMIC sz_size_t sz_levenshtein_distance_utf8( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t *alloc) { - return _sz_levenshtein_distance_wagner_fisher_serial(a, a_length, b, b_length, bound, sz_true_k, alloc); +SZ_DYNAMIC sz_status_t sz_levenshtein_distance( // + sz_cptr_t a, sz_size_t a_length, // + sz_cptr_t b, sz_size_t b_length, // + sz_size_t bound, sz_memory_allocator_t *alloc, sz_size_t *result) { + return sz_dispatch_table.edit_distance(a, a_length, b, b_length, bound, alloc, result); } -SZ_DYNAMIC sz_ssize_t sz_needleman_wunsch_score( // - sz_cptr_t a, sz_size_t a_length, // - sz_cptr_t b, sz_size_t b_length, // - sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc) { - return sz_dispatch_table.alignment_score(a, a_length, b, b_length, subs, gap, alloc); +SZ_DYNAMIC sz_status_t sz_levenshtein_distance_utf8( // + sz_cptr_t a, sz_size_t a_length, // + sz_cptr_t b, sz_size_t b_length, // + sz_size_t bound, sz_memory_allocator_t *alloc, sz_size_t *result) { + return _sz_levenshtein_distance_wagner_fisher_serial(a, a_length, b, b_length, bound, sz_true_k, alloc, result); } -SZ_DYNAMIC sz_status_t sz_sequence_argsort(sz_sequence_t const *array, sz_memory_allocator_t *alloc, sz_size_t *order) { - return sz_dispatch_table.sequence_argsort(array, alloc, order); +SZ_DYNAMIC sz_status_t sz_needleman_wunsch_score( // + sz_cptr_t a, sz_size_t a_length, // + sz_cptr_t b, sz_size_t b_length, // + sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc, sz_ssize_t *result) { + return sz_dispatch_table.alignment_score(a, a_length, b, b_length, subs, gap, alloc, result); } SZ_DYNAMIC sz_status_t sz_pgrams_sort(sz_pgram_t *array, sz_size_t count, sz_memory_allocator_t *alloc, @@ -509,44 +499,15 @@ SZ_DYNAMIC sz_status_t sz_pgrams_sort(sz_pgram_t *array, sz_size_t count, sz_mem return sz_dispatch_table.pgrams_sort(array, count, alloc, order); } -SZ_DYNAMIC sz_status_t sz_sequence_argsort_stable(sz_sequence_t const *array, sz_memory_allocator_t *alloc, - sz_size_t *order) { - return sz_dispatch_table.sequence_argsort_stable(array, alloc, order); -} - 
-SZ_DYNAMIC sz_status_t sz_pgrams_sort_stable(sz_pgram_t *array, sz_size_t count, sz_memory_allocator_t *alloc, - sz_size_t *order) { - return sz_dispatch_table.pgrams_sort_stable(array, count, alloc, order); -} - -SZ_DYNAMIC sz_cptr_t sz_find_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - return sz_find_charset(h, h_length, &set); -} - -SZ_DYNAMIC sz_cptr_t sz_find_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - sz_charset_invert(&set); - return sz_find_charset(h, h_length, &set); -} - -SZ_DYNAMIC sz_cptr_t sz_rfind_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - return sz_rfind_charset(h, h_length, &set); +SZ_DYNAMIC sz_status_t sz_sequence_argsort(sz_sequence_t const *array, sz_memory_allocator_t *alloc, sz_size_t *order) { + return sz_dispatch_table.sequence_argsort(array, alloc, order); } -SZ_DYNAMIC sz_cptr_t sz_rfind_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - sz_charset_invert(&set); - return sz_rfind_charset(h, h_length, &set); +SZ_DYNAMIC sz_status_t sz_sequence_join(sz_sequence_t const *first_array, sz_sequence_t const *second_array, + sz_memory_allocator_t *alloc, sz_size_t *intersection_size, + sz_size_t *first_positions, sz_size_t *second_positions) { + return sz_dispatch_table.sequence_join(first_array, second_array, alloc, intersection_size, first_positions, + second_positions); } // Provide overrides for the libc mem* functions @@ -626,7 +587,7 @@ SZ_DYNAMIC void *memrchr(void const *s, int c_wide, size_t n) { SZ_DYNAMIC void memfrob(void *s, size_t n) { static sz_u64_t nonce = 42; - sz_generate(s, n, nonce++); + sz_fill_random(s, n, nonce++); } #endif diff --git a/include/stringzilla/find.h b/include/stringzilla/find.h index 90b6a16f..d3db653e 100644 --- a/include/stringzilla/find.h +++ b/include/stringzilla/find.h @@ -7,14 +7,14 @@ * * - `sz_find` and reverse-order `sz_rfind` * - `sz_find_byte` and reverse-order `sz_rfind_byte` - * - `sz_find_charset` and reverse-order `sz_rfind_charset` + * - `sz_find_byteset` and reverse-order `sz_rfind_byteset` * * Convenience functions for character-set matching: * - * - `sz_find_char_from` - * - `sz_find_char_not_from` - * - `sz_rfind_char_from` - * - `sz_rfind_char_not_from` + * - `sz_find_byte_from` shortcut for `sz_find_byteset` + * - `sz_find_byte_not_from` shortcut for `sz_find_byteset` with inverted set + * - `sz_rfind_byte_from` shortcut for `sz_rfind_byteset` + * - `sz_rfind_byte_not_from` shortcut for `sz_rfind_byteset` with inverted set */ #ifndef STRINGZILLA_FIND_H_ #define STRINGZILLA_FIND_H_ @@ -35,10 +35,10 @@ extern "C" { * X86_64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/x86_64/memchr.S * Aarch64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/aarch64/memchr.S * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - single-byte substring to find. - * @return Address of the first match. 
+ * @param[in] haystack Haystack - the string to search in. + * @param[in] h_length Number of bytes in the haystack. + * @param[in] needle Needle - single-byte substring to find. + * @return Address of the first match. */ SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); @@ -48,10 +48,10 @@ SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cpt * X86_64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/x86_64/memrchr.S * Aarch64 implementation: missing * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - single-byte substring to find. - * @return Address of the last match. + * @param[in] haystack Haystack - the string to search in. + * @param[in] h_length Number of bytes in the haystack. + * @param[in] needle Needle - single-byte substring to find. + * @return Address of the last match. */ SZ_DYNAMIC sz_cptr_t sz_rfind_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); @@ -86,22 +86,22 @@ SZ_PUBLIC sz_cptr_t sz_rfind_byte_neon(sz_cptr_t haystack, sz_size_t h_length, s * Equivalent to `memmem(haystack, h_length, needle, n_length)` in LibC. * Similar to `strstr(haystack, needle)` in LibC, but requires known length. * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - substring to find. - * @param n_length Number of bytes in the needle. - * @return Address of the first match. + * @param[in] haystack Haystack - the string to search in. + * @param[in] h_length Number of bytes in the haystack. + * @param[in] needle Needle - substring to find. + * @param[in] n_length Number of bytes in the needle. + * @return Address of the first match. */ SZ_DYNAMIC sz_cptr_t sz_find(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); /** * @brief Locates the last matching substring. * - * @param haystack Haystack - the string to search in. - * @param h_length Number of bytes in the haystack. - * @param needle Needle - substring to find. - * @param n_length Number of bytes in the needle. - * @return Address of the last match. + * @param[in] haystack Haystack - the string to search in. + * @param[in] h_length Number of bytes in the haystack. + * @param[in] needle Needle - substring to find. + * @param[in] n_length Number of bytes in the needle. + * @return Address of the last match. */ SZ_DYNAMIC sz_cptr_t sz_rfind(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); @@ -132,9 +132,9 @@ SZ_PUBLIC sz_cptr_t sz_rfind_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cpt #endif /** - * @brief Finds the first character present from the ::set, present in ::text. + * @brief Finds the first character present from the @p set, present in @p text. * Equivalent to `strspn(text, accepted)` and `strcspn(text, rejected)` in LibC. - * May have identical implementation and performance to ::sz_rfind_charset. + * May have identical implementation and performance to ::sz_rfind_byteset. * * Useful for parsing, when we want to skip a set of characters. Examples: * - 6 whitespaces: " \t\n\r\v\f". @@ -142,16 +142,16 @@ SZ_PUBLIC sz_cptr_t sz_rfind_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cpt * - 5 HTML reserved characters: "\"'&<>", of which "<>" can be useful for parsing. * - 2 JSON string special characters useful to locate the end of the string: "\"\\". * - * @param text String to be scanned. 
- * @param set Set of relevant characters. - * @return Pointer to the first matching character from ::set. + * @param[in] text String to be scanned. + * @param[in] set Set of relevant characters. + * @return Pointer to the first matching character from @p set. */ -SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); +SZ_DYNAMIC sz_cptr_t sz_find_byteset(sz_cptr_t text, sz_size_t length, sz_byteset_t const *set); /** - * @brief Finds the last character present from the ::set, present in ::text. + * @brief Finds the last character present from the @p set, present in @p text. * Equivalent to `strspn(text, accepted)` and `strcspn(text, rejected)` in LibC. - * May have identical implementation and performance to ::sz_find_charset. + * May have identical implementation and performance to ::sz_find_byteset. * * Useful for parsing, when we want to skip a set of characters. Examples: * - 6 whitespaces: " \t\n\r\v\f". @@ -159,40 +159,74 @@ SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charse * - 5 HTML reserved characters: "\"'&<>", of which "<>" can be useful for parsing. * - 2 JSON string special characters useful to locate the end of the string: "\"\\". * - * @param text String to be scanned. - * @param set Set of relevant characters. - * @return Pointer to the last matching character from ::set. + * @param[in] text String to be scanned. + * @param[in] set Set of relevant characters. + * @return Pointer to the last matching character from @p set. */ -SZ_DYNAMIC sz_cptr_t sz_rfind_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); +SZ_DYNAMIC sz_cptr_t sz_rfind_byteset(sz_cptr_t text, sz_size_t length, sz_byteset_t const *set); -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); +/** @copydoc sz_find_byteset */ +SZ_PUBLIC sz_cptr_t sz_find_byteset_serial(sz_cptr_t text, sz_size_t length, sz_byteset_t const *set); +/** @copydoc sz_rfind_byteset */ +SZ_PUBLIC sz_cptr_t sz_rfind_byteset_serial(sz_cptr_t text, sz_size_t length, sz_byteset_t const *set); #if SZ_USE_HASWELL -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_haswell(sz_cptr_t haystack, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_haswell(sz_cptr_t haystack, sz_size_t length, sz_charset_t const *set); +/** @copydoc sz_find_byteset */ +SZ_PUBLIC sz_cptr_t sz_find_byteset_haswell(sz_cptr_t haystack, sz_size_t length, sz_byteset_t const *set); +/** @copydoc sz_rfind_byteset */ +SZ_PUBLIC sz_cptr_t sz_rfind_byteset_haswell(sz_cptr_t haystack, sz_size_t length, sz_byteset_t const *set); #endif #if SZ_USE_ICE -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t sz_find_charset_ice(sz_cptr_t haystack, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_ice(sz_cptr_t haystack, sz_size_t length, sz_charset_t const *set); +/** @copydoc sz_find_byteset */ +SZ_PUBLIC sz_cptr_t sz_find_byteset_ice(sz_cptr_t haystack, sz_size_t length, sz_byteset_t const *set); +/** @copydoc sz_rfind_byteset */ +SZ_PUBLIC sz_cptr_t sz_rfind_byteset_ice(sz_cptr_t haystack, sz_size_t length, sz_byteset_t const *set); #endif #if SZ_USE_NEON -/** @copydoc sz_find_charset */ -SZ_PUBLIC sz_cptr_t 
sz_find_charset_neon(sz_cptr_t haystack, sz_size_t length, sz_charset_t const *set); -/** @copydoc sz_rfind_charset */ -SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t haystack, sz_size_t length, sz_charset_t const *set); +/** @copydoc sz_find_byteset */ +SZ_PUBLIC sz_cptr_t sz_find_byteset_neon(sz_cptr_t haystack, sz_size_t length, sz_byteset_t const *set); +/** @copydoc sz_rfind_byteset */ +SZ_PUBLIC sz_cptr_t sz_rfind_byteset_neon(sz_cptr_t haystack, sz_size_t length, sz_byteset_t const *set); #endif #pragma endregion // Core API +#pragma region Helper Shortcuts + +SZ_PUBLIC sz_cptr_t sz_find_byte_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { + sz_byteset_t set; + sz_byteset_init(&set); + for (; n_length; ++n, --n_length) sz_byteset_add(&set, *n); + return sz_find_byteset(h, h_length, &set); +} + +SZ_PUBLIC sz_cptr_t sz_find_byte_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { + sz_byteset_t set; + sz_byteset_init(&set); + for (; n_length; ++n, --n_length) sz_byteset_add(&set, *n); + sz_byteset_invert(&set); + return sz_find_byteset(h, h_length, &set); +} + +SZ_PUBLIC sz_cptr_t sz_rfind_byte_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { + sz_byteset_t set; + sz_byteset_init(&set); + for (; n_length; ++n, --n_length) sz_byteset_add(&set, *n); + return sz_rfind_byteset(h, h_length, &set); +} + +SZ_PUBLIC sz_cptr_t sz_rfind_byte_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { + sz_byteset_t set; + sz_byteset_init(&set); + for (; n_length; ++n, --n_length) sz_byteset_add(&set, *n); + sz_byteset_invert(&set); + return sz_rfind_byteset(h, h_length, &set); +} + +#pragma endregion // Helper Shortcuts + #pragma region Serial Implementation /** @@ -270,18 +304,18 @@ SZ_INTERNAL void _sz_locate_needle_anomalies( // } } -SZ_PUBLIC sz_cptr_t sz_find_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { +SZ_PUBLIC sz_cptr_t sz_find_byteset_serial(sz_cptr_t text, sz_size_t length, sz_byteset_t const *set) { for (sz_cptr_t const end = text + length; text != end; ++text) - if (sz_charset_contains(set, *text)) return text; + if (sz_byteset_contains(set, *text)) return text; return SZ_NULL_CHAR; } -SZ_PUBLIC sz_cptr_t sz_rfind_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { +SZ_PUBLIC sz_cptr_t sz_rfind_byteset_serial(sz_cptr_t text, sz_size_t length, sz_byteset_t const *set) { #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Warray-bounds" sz_cptr_t const end = text; for (text += length; text != end;) - if (sz_charset_contains(set, *(text -= 1))) return text; + if (sz_byteset_contains(set, *(text -= 1))) return text; return SZ_NULL_CHAR; #pragma GCC diagnostic pop } @@ -893,7 +927,7 @@ SZ_PUBLIC sz_cptr_t sz_rfind_haswell(sz_cptr_t h, sz_size_t h_length, sz_cptr_t return sz_rfind_serial(h, h_length, n, n_length); } -SZ_PUBLIC sz_cptr_t sz_find_charset_haswell(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { +SZ_PUBLIC sz_cptr_t sz_find_byteset_haswell(sz_cptr_t text, sz_size_t length, sz_byteset_t const *filter) { // Let's unzip even and odd elements and replicate them into both lanes of the YMM register. // That way when we invoke `_mm256_shuffle_epi8` we can use the same mask for both lanes. 
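
To make the renamed byteset API concrete, here is a small illustrative sketch (not part of this patch) that builds the six-whitespace set mentioned in the comments above and scans a buffer with it. It uses only functions visible in this diff (sz_byteset_init, sz_byteset_add, sz_find_byteset); the helper name is hypothetical.

#include <stringzilla/stringzilla.h>

/* Returns a pointer to the first ASCII whitespace byte, or SZ_NULL_CHAR if none is found. */
static char const *find_first_whitespace(char const *text, sz_size_t length) {
    char const whitespaces[] = " \t\n\r\v\f"; /* the six whitespaces listed in the docs above */
    sz_byteset_t set;
    sz_byteset_init(&set);
    for (sz_size_t i = 0; i != 6; ++i) sz_byteset_add(&set, whitespaces[i]);
    return sz_find_byteset(text, length, &set);
}

The sz_find_byte_from shortcut introduced in the Helper Shortcuts region performs the same set construction internally, so the call above could also be written as sz_find_byte_from(text, length, " \t\n\r\v\f", 6).
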
@@ -978,11 +1012,11 @@ SZ_PUBLIC sz_cptr_t sz_find_charset_haswell(sz_cptr_t text, sz_size_t length, sz else { text += 32, length -= 32; } } - return sz_find_charset_serial(text, length, filter); + return sz_find_byteset_serial(text, length, filter); } -SZ_PUBLIC sz_cptr_t sz_rfind_charset_haswell(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { - return sz_rfind_charset_serial(text, length, filter); +SZ_PUBLIC sz_cptr_t sz_rfind_byteset_haswell(sz_cptr_t text, sz_size_t length, sz_byteset_t const *filter) { + return sz_rfind_byteset_serial(text, length, filter); } #pragma clang attribute pop @@ -1233,13 +1267,13 @@ SZ_PUBLIC sz_cptr_t sz_rfind_skylake(sz_cptr_t h, sz_size_t h_length, sz_cptr_t __attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512dq,avx512vbmi,avx512vbmi2,bmi,bmi2"))), \ apply_to = function) -SZ_PUBLIC sz_cptr_t sz_find_charset_ice(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { +SZ_PUBLIC sz_cptr_t sz_find_byteset_ice(sz_cptr_t text, sz_size_t length, sz_byteset_t const *filter) { // Before initializing the AVX-512 vectors, we may want to run the sequential code for the first few bytes. // In practice, that only hurts, even when we have matches every 5-ish bytes. // - // if (length < SZ_SWAR_THRESHOLD) return sz_find_charset_serial(text, length, filter); - // sz_cptr_t early_result = sz_find_charset_serial(text, SZ_SWAR_THRESHOLD, filter); + // if (length < SZ_SWAR_THRESHOLD) return sz_find_byteset_serial(text, length, filter); + // sz_cptr_t early_result = sz_find_byteset_serial(text, SZ_SWAR_THRESHOLD, filter); // if (early_result) return early_result; // text += SZ_SWAR_THRESHOLD; // length -= SZ_SWAR_THRESHOLD; @@ -1348,8 +1382,8 @@ SZ_PUBLIC sz_cptr_t sz_find_charset_ice(sz_cptr_t text, sz_size_t length, sz_cha return SZ_NULL_CHAR; } -SZ_PUBLIC sz_cptr_t sz_rfind_charset_ice(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { - return sz_rfind_charset_serial(text, length, filter); +SZ_PUBLIC sz_cptr_t sz_rfind_byteset_ice(sz_cptr_t text, sz_size_t length, sz_byteset_t const *filter) { + return sz_rfind_byteset_serial(text, length, filter); } #pragma clang attribute pop @@ -1408,7 +1442,7 @@ SZ_PUBLIC sz_cptr_t sz_rfind_byte_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_ return sz_rfind_byte_serial(h, h_length, n); } -SZ_PUBLIC sz_u64_t _sz_find_charset_neon_register( // +SZ_PUBLIC sz_u64_t _sz_find_byteset_neon_register( // sz_u128_vec_t h_vec, uint8x16_t set_top_vec_u8x16, uint8x16_t set_bottom_vec_u8x16) { // Once we've read the characters in the haystack, we want to @@ -1550,7 +1584,7 @@ SZ_PUBLIC sz_cptr_t sz_rfind_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, return sz_rfind_serial(h, h_length, n, n_length); } -SZ_PUBLIC sz_cptr_t sz_find_charset_neon(sz_cptr_t h, sz_size_t h_length, sz_charset_t const *set) { +SZ_PUBLIC sz_cptr_t sz_find_byteset_neon(sz_cptr_t h, sz_size_t h_length, sz_byteset_t const *set) { sz_u64_t matches; sz_u128_vec_t h_vec; uint8x16_t set_top_vec_u8x16 = vld1q_u8(&set->_u8s[0]); @@ -1558,27 +1592,27 @@ SZ_PUBLIC sz_cptr_t sz_find_charset_neon(sz_cptr_t h, sz_size_t h_length, sz_cha for (; h_length >= 16; h += 16, h_length -= 16) { h_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h)); - matches = _sz_find_charset_neon_register(h_vec, set_top_vec_u8x16, set_bottom_vec_u8x16); + matches = _sz_find_byteset_neon_register(h_vec, set_top_vec_u8x16, set_bottom_vec_u8x16); if (matches) return h + sz_u64_ctz(matches) / 4; } - return sz_find_charset_serial(h, h_length, set); + return 
sz_find_byteset_serial(h, h_length, set); } -SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t h, sz_size_t h_length, sz_charset_t const *set) { +SZ_PUBLIC sz_cptr_t sz_rfind_byteset_neon(sz_cptr_t h, sz_size_t h_length, sz_byteset_t const *set) { sz_u64_t matches; sz_u128_vec_t h_vec; uint8x16_t set_top_vec_u8x16 = vld1q_u8(&set->_u8s[0]); uint8x16_t set_bottom_vec_u8x16 = vld1q_u8(&set->_u8s[16]); - // Check `sz_find_charset_neon` for explanations. + // Check `sz_find_byteset_neon` for explanations. for (; h_length >= 16; h_length -= 16) { h_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h) + h_length - 16); - matches = _sz_find_charset_neon_register(h_vec, set_top_vec_u8x16, set_bottom_vec_u8x16); + matches = _sz_find_byteset_neon_register(h_vec, set_top_vec_u8x16, set_bottom_vec_u8x16); if (matches) return h + h_length - 1 - sz_u64_clz(matches) / 4; } - return sz_rfind_charset_serial(h, h_length, set); + return sz_rfind_byteset_serial(h, h_length, set); } #pragma clang attribute pop @@ -1656,64 +1690,31 @@ SZ_DYNAMIC sz_cptr_t sz_rfind(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t #endif } -SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { +SZ_DYNAMIC sz_cptr_t sz_find_byteset(sz_cptr_t text, sz_size_t length, sz_byteset_t const *set) { #if SZ_USE_ICE - return sz_find_charset_ice(text, length, set); + return sz_find_byteset_ice(text, length, set); #elif SZ_USE_HASWELL - return sz_find_charset_haswell(text, length, set); + return sz_find_byteset_haswell(text, length, set); #elif SZ_USE_NEON - return sz_find_charset_neon(text, length, set); + return sz_find_byteset_neon(text, length, set); #else - return sz_find_charset_serial(text, length, set); + return sz_find_byteset_serial(text, length, set); #endif } -SZ_DYNAMIC sz_cptr_t sz_rfind_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { +SZ_DYNAMIC sz_cptr_t sz_rfind_byteset(sz_cptr_t text, sz_size_t length, sz_byteset_t const *set) { #if SZ_USE_ICE - return sz_rfind_charset_ice(text, length, set); + return sz_rfind_byteset_ice(text, length, set); #elif SZ_USE_HASWELL - return sz_rfind_charset_haswell(text, length, set); + return sz_rfind_byteset_haswell(text, length, set); #elif SZ_USE_NEON - return sz_rfind_charset_neon(text, length, set); + return sz_rfind_byteset_neon(text, length, set); #else - return sz_rfind_charset_serial(text, length, set); + return sz_rfind_byteset_serial(text, length, set); #endif } #pragma endregion -#pragma region Helper Shortcuts - -SZ_DYNAMIC sz_cptr_t sz_find_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - return sz_find_charset(h, h_length, &set); -} - -SZ_DYNAMIC sz_cptr_t sz_find_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - sz_charset_invert(&set); - return sz_find_charset(h, h_length, &set); -} - -SZ_DYNAMIC sz_cptr_t sz_rfind_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); - return sz_rfind_charset(h, h_length, &set); -} - -SZ_DYNAMIC sz_cptr_t sz_rfind_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { - sz_charset_t set; - sz_charset_init(&set); - for (; n_length; ++n, --n_length) 
sz_charset_add(&set, *n); - sz_charset_invert(&set); - return sz_rfind_charset(h, h_length, &set); -} - -#pragma endregion // Helper Shortcuts #endif // !SZ_DYNAMIC_DISPATCH #pragma endregion // Compile Time Dispatching diff --git a/include/stringzilla/hash.h b/include/stringzilla/hash.h index 6ef11e3d..e23b700a 100644 --- a/include/stringzilla/hash.h +++ b/include/stringzilla/hash.h @@ -3,12 +3,12 @@ * @file hash.h * @author Ash Vardanian * - * Includes core APIs: + * Includes core APIs with hardware-specific backends: * * - `sz_bytesum` - for byte-level 64-bit unsigned byte-level checksums. * - `sz_hash` - for 64-bit single-shot hashing using AES instructions. * - `sz_hash_state_init`, `sz_hash_state_stream`, `sz_hash_state_fold` - for incremental hashing. - * - `sz_generate` - for populating buffers with pseudo-random noise using AES instructions. + * - `sz_fill_random` - for populating buffers with pseudo-random noise using AES instructions. * * Why the hell do we need a yet another hashing library?! * Turns out, most existing libraries have noticeable constraints. Try finding a library that: @@ -31,12 +31,12 @@ * - "xxHash" is implemented in C, has an extremely wide set of third-party language bindings, and provides both * 32-, 64-, and 128-bit hashes. It is fast, but its dynamic dispatch is limited to x86 with `xxh_x86dispatch.c`. * - * StringZilla uses a scheme more similar to the "aHash" library, utilizing the AES extensions, that provide + * StringZilla uses a scheme more similar to "aHash" and "GxHash", utilizing the AES extensions, that provide * a remarkable level of "mixing per cycle" and are broadly available on modern CPUs. Similar to "aHash", they * are combined with "shuffle & add" instructions to provide a high level of entropy in the output. That operation * is practically free, as many modern CPUs will dispatch them on different ports. On x86, for example: * - * - `VAESENC` (ZMM, ZMM, ZMM)`: + * - `VAESENC (ZMM, ZMM, ZMM)` and `VAESDEC (ZMM, ZMM, ZMM)`: * - on Intel Ice Lake: 5 cycles on port 0. * - On AMD Zen4: 4 cycles on ports 0 or 1. * - `VPSHUFB_Z (ZMM, K, ZMM, ZMM)` @@ -46,12 +46,16 @@ * - on Intel Ice Lake: 1 cycle on ports 0 or 5. * - On AMD Zen4: 1 cycle on ports 0, 1, 2, 3. * - * Unlike "aHash", the length is not mixed into "AES" block at start to allow incremental construction. - * Unlike "aHash", on long inputs, we use a heavier procedure that is more vector-friendly on modern servers. - * Unlike "aHash", we don't load interleaved memory regions, making vectorized variant more similar to sequential. - * Unlike "aHash", on platforms like Intel Skylake-X or AWS Graviton 3, we use masked loads. - * Unlike "aHash", in final folding procedure, we use the same `VAESENC` instead of `VAESDEC`, which - * still provides the same level of mixing, but allows us to have a lighter serial fallback implementation. + * But there several key differences: + * + * - A larger state and a larger block size is used for inputs over 64 bytes longs, benefiting from wider registers + * on current CPUs. Like many other hash functions, the state is initialized with the seed and a set of Pi constants. + * Unlike others, we pull more Pi bits (1024), but only 64-bits of the seed, to keep the API sane. + * - The length of the input is not mixed into the AES block at the start to allow incremental construction, + * when the final length is not known in advance. + * - The vector-loads are not interleaved, meaning that each byte of input has exactly the same weight in the hash. 
+ * On the implementation side it require some extra shuffling on older platforms, but on newer platforms it + * can be done with "masked" loads in AVX-512 and "predicated" instructions in SVE2. * * @see Reini Urban's more active fork of SMHasher by Austin Appleby: https://github.com/rurban/smhasher * @see The serial AES routines are based on Morten Jensen's "tiny-AES-c": https://github.com/kokke/tiny-AES-c @@ -59,6 +63,16 @@ * @see The "aHash" Rust implementation by Tom Kaitchuck: https://github.com/tkaitchuck/aHash * @see "Emulating x86 AES Intrinsics on ARMv8-A" by Michael Brase: * https://blog.michaelbrase.com/2018/05/08/emulating-x86-aes-intrinsics-on-armv8-a/ + * + * Moreover, the same AES primitives are reused to implement a fast Pseudo-Random Number Generator @b (PRNG) that + * is consistent between different implementation backends and has reproducible output with the same "nonce". + * Originally, the PRNG was designed to produce random byte sequences, but combining it with @b `sz_lookup`, + * one can produce random strings with a given byteset. + * + * Other helpers include: TODO: + * + * - `sz_fill_alphabet` - combines `sz_fill_random` & `sz_lookup` to fill buffers with random ASCII characters. + * - `sz_fill_alphabet_utf8` - combines `sz_fill_random` & `sz_lookup` to fill buffers with random UTF-8 characters. */ #ifndef STRINGZILLA_HASH_H_ #define STRINGZILLA_HASH_H_ @@ -114,7 +128,7 @@ SZ_DYNAMIC sz_u64_t sz_bytesum(sz_cptr_t text, sz_size_t length); * @endcode * * @note Selects the fastest implementation at compile- or run-time based on `SZ_DYNAMIC_DISPATCH`. - * @sa sz_hash_serial, sz_hash_haswell, sz_hash_skylake, sz_hash_ice, sz_hash_neon + * @sa sz_hash_serial, sz_hash_haswell, sz_hash_skylake, sz_hash_ice, sz_hash_neon, sz_hash_sve * * @note The algorithm must provide the same output on all platforms in both single-shot and incremental modes. * @sa sz_hash_state_init, sz_hash_state_stream, sz_hash_state_fold @@ -144,16 +158,17 @@ SZ_DYNAMIC sz_u64_t sz_hash(sz_cptr_t text, sz_size_t length, sz_u64_t seed); * #include * int main() { * char first_buffer[5], second_buffer[5]; - * sz_generate(first_buffer, 5, 0); - * sz_generate(second_buffer, 5, 0); //? Same nonce must produce the same output + * sz_fill_random(first_buffer, 5, 0); + * sz_fill_random(second_buffer, 5, 0); //? Same nonce must produce the same output * return sz_bytesum(first_buffer, 5) == sz_bytesum(second_buffer, 5) ? 0 : 1; * } * @endcode * * @note Selects the fastest implementation at compile- or run-time based on `SZ_DYNAMIC_DISPATCH`. - * @sa sz_generate_serial, sz_generate_haswell, sz_generate_skylake, sz_generate_ice, sz_generate_neon + * @sa sz_fill_random_serial, sz_fill_random_haswell, sz_fill_random_skylake, sz_fill_random_ice, + * sz_fill_random_neon, sz_fill_random_sve */ -SZ_DYNAMIC void sz_generate(sz_ptr_t text, sz_size_t length, sz_u64_t nonce); +SZ_DYNAMIC void sz_fill_random(sz_ptr_t text, sz_size_t length, sz_u64_t nonce); /** * @brief The state for incremental construction of a hash. 
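
As a rough usage sketch of the incremental interface referenced above, the snippet below feeds the same bytes through the streaming state in two chunks and compares the folded result against the single-shot hash. The (state, data, length) signature of sz_hash_state_stream is an assumption here, since only its dispatch wiring appears in this diff.

#include <assert.h>
#include <string.h>
#include <stringzilla/stringzilla.h>

int main(void) {
    char const *text = "streaming must match single-shot hashing";
    sz_size_t length = (sz_size_t)strlen(text);
    sz_u64_t seed = 42;

    /* Hash the whole buffer at once. */
    sz_u64_t single_shot = sz_hash(text, length, seed);

    /* Feed the same bytes in two arbitrary chunks and fold the state. */
    sz_hash_state_t state;
    sz_hash_state_init(&state, seed);
    sz_hash_state_stream(&state, text, 10);
    sz_hash_state_stream(&state, text + 10, length - 10);
    assert(sz_hash_state_fold(&state) == single_shot);
    return 0;
}
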
@@ -204,8 +219,8 @@ SZ_PUBLIC sz_u64_t sz_bytesum_serial(sz_cptr_t text, sz_size_t length); /** @copydoc sz_hash */ SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t text, sz_size_t length, sz_u64_t seed); -/** @copydoc sz_generate */ -SZ_PUBLIC void sz_generate_serial(sz_ptr_t text, sz_size_t length, sz_u64_t nonce); +/** @copydoc sz_fill_random */ +SZ_PUBLIC void sz_fill_random_serial(sz_ptr_t text, sz_size_t length, sz_u64_t nonce); /** @copydoc sz_hash_state_init */ SZ_PUBLIC void sz_hash_state_init_serial(sz_hash_state_t *state, sz_u64_t seed); @@ -222,8 +237,8 @@ SZ_PUBLIC sz_u64_t sz_bytesum_haswell(sz_cptr_t text, sz_size_t length); /** @copydoc sz_hash */ SZ_PUBLIC sz_u64_t sz_hash_haswell(sz_cptr_t text, sz_size_t length, sz_u64_t seed); -/** @copydoc sz_generate */ -SZ_PUBLIC void sz_generate_haswell(sz_ptr_t text, sz_size_t length, sz_u64_t nonce); +/** @copydoc sz_fill_random */ +SZ_PUBLIC void sz_fill_random_haswell(sz_ptr_t text, sz_size_t length, sz_u64_t nonce); /** @copydoc sz_hash_state_init */ SZ_PUBLIC void sz_hash_state_init_haswell(sz_hash_state_t *state, sz_u64_t seed); @@ -240,8 +255,8 @@ SZ_PUBLIC sz_u64_t sz_bytesum_skylake(sz_cptr_t text, sz_size_t length); /** @copydoc sz_hash */ SZ_PUBLIC sz_u64_t sz_hash_skylake(sz_cptr_t text, sz_size_t length, sz_u64_t seed); -/** @copydoc sz_generate */ -SZ_PUBLIC void sz_generate_skylake(sz_ptr_t text, sz_size_t length, sz_u64_t nonce); +/** @copydoc sz_fill_random */ +SZ_PUBLIC void sz_fill_random_skylake(sz_ptr_t text, sz_size_t length, sz_u64_t nonce); /** @copydoc sz_hash_state_init */ SZ_PUBLIC void sz_hash_state_init_skylake(sz_hash_state_t *state, sz_u64_t seed); @@ -258,8 +273,8 @@ SZ_PUBLIC sz_u64_t sz_bytesum_ice(sz_cptr_t text, sz_size_t length); /** @copydoc sz_hash */ SZ_PUBLIC sz_u64_t sz_hash_ice(sz_cptr_t text, sz_size_t length, sz_u64_t seed); -/** @copydoc sz_generate */ -SZ_PUBLIC void sz_generate_ice(sz_ptr_t text, sz_size_t length, sz_u64_t nonce); +/** @copydoc sz_fill_random */ +SZ_PUBLIC void sz_fill_random_ice(sz_ptr_t text, sz_size_t length, sz_u64_t nonce); /** @copydoc sz_hash_state_init */ SZ_PUBLIC void sz_hash_state_init_ice(sz_hash_state_t *state, sz_u64_t seed); @@ -276,8 +291,8 @@ SZ_PUBLIC sz_u64_t sz_bytesum_neon(sz_cptr_t text, sz_size_t length); /** @copydoc sz_hash */ SZ_PUBLIC sz_u64_t sz_hash_neon(sz_cptr_t text, sz_size_t length, sz_u64_t seed); -/** @copydoc sz_generate */ -SZ_PUBLIC void sz_generate_neon(sz_ptr_t text, sz_size_t length, sz_u64_t nonce); +/** @copydoc sz_fill_random */ +SZ_PUBLIC void sz_fill_random_neon(sz_ptr_t text, sz_size_t length, sz_u64_t nonce); /** @copydoc sz_hash_state_init */ SZ_PUBLIC void sz_hash_state_init_neon(sz_hash_state_t *state, sz_u64_t seed); @@ -704,7 +719,7 @@ SZ_PUBLIC sz_u64_t sz_hash_state_fold_serial(sz_hash_state_t const *state) { } } -SZ_PUBLIC void sz_generate_serial(sz_ptr_t text, sz_size_t length, sz_u64_t nonce) { +SZ_PUBLIC void sz_fill_random_serial(sz_ptr_t text, sz_size_t length, sz_u64_t nonce) { sz_u64_t const *pi_ptr = _sz_hash_pi_constants(); sz_u128_vec_t input_vec, pi_vec, key_vec, generated_vec; for (sz_size_t lane_index = 0; length; ++lane_index) { @@ -728,8 +743,8 @@ SZ_PUBLIC void sz_generate_serial(sz_ptr_t text, sz_size_t length, sz_u64_t nonc #pragma region Haswell Implementation #if SZ_USE_HASWELL #pragma GCC push_options -#pragma GCC target("avx2") -#pragma clang attribute push(__attribute__((target("avx2"))), apply_to = function) +#pragma GCC target("avx2", "aes") +#pragma clang attribute 
push(__attribute__((target("avx2,aes"))), apply_to = function) SZ_PUBLIC sz_u64_t sz_bytesum_haswell(sz_cptr_t text, sz_size_t length) { // The naive implementation of this function is very simple. @@ -1058,7 +1073,7 @@ SZ_PUBLIC sz_u64_t sz_hash_state_fold_haswell(sz_hash_state_t const *state) { } } -SZ_PUBLIC void sz_generate_haswell(sz_ptr_t text, sz_size_t length, sz_u64_t nonce) { +SZ_PUBLIC void sz_fill_random_haswell(sz_ptr_t text, sz_size_t length, sz_u64_t nonce) { sz_u64_t const *pi_ptr = _sz_hash_pi_constants(); if (length <= 16) { __m128i input = _mm_set1_epi64x(nonce); @@ -1165,8 +1180,8 @@ SZ_PUBLIC void sz_generate_haswell(sz_ptr_t text, sz_size_t length, sz_u64_t non #pragma region Skylake Implementation #if SZ_USE_SKYLAKE #pragma GCC push_options -#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "bmi", "bmi2") -#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,bmi,bmi2"))), apply_to = function) +#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "bmi", "bmi2", "aes") +#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,bmi,bmi2,aes"))), apply_to = function) SZ_PUBLIC sz_u64_t sz_bytesum_skylake(sz_cptr_t text, sz_size_t length) { // The naive implementation of this function is very simple. @@ -1386,8 +1401,8 @@ SZ_PUBLIC sz_u64_t sz_hash_state_fold_skylake(sz_hash_state_t const *state) { return sz_hash_state_fold_haswell(state); } -SZ_PUBLIC void sz_generate_skylake(sz_ptr_t text, sz_size_t length, sz_u64_t nonce) { - sz_generate_serial(text, length, nonce); +SZ_PUBLIC void sz_fill_random_skylake(sz_ptr_t text, sz_size_t length, sz_u64_t nonce) { + sz_fill_random_serial(text, length, nonce); } #pragma clang attribute pop @@ -1647,7 +1662,7 @@ SZ_PUBLIC sz_u64_t sz_hash_state_fold_ice(sz_hash_state_t const *state) { return sz_hash_state_fold_haswell(state); } -SZ_PUBLIC void sz_generate_ice(sz_ptr_t output, sz_size_t length, sz_u64_t nonce) { +SZ_PUBLIC void sz_fill_random_ice(sz_ptr_t output, sz_size_t length, sz_u64_t nonce) { if (length <= 16) { __m128i input = _mm_set1_epi64x(nonce); __m128i pi = _mm_load_si128((__m128i const *)_sz_hash_pi_constants()); @@ -1796,17 +1811,17 @@ SZ_DYNAMIC sz_u64_t sz_hash(sz_cptr_t text, sz_size_t length, sz_u64_t seed) { #endif } -SZ_DYNAMIC void sz_generate(sz_ptr_t text, sz_size_t length, sz_u64_t nonce) { +SZ_DYNAMIC void sz_fill_random(sz_ptr_t text, sz_size_t length, sz_u64_t nonce) { #if SZ_USE_ICE - sz_generate_ice(text, length, nonce); + sz_fill_random_ice(text, length, nonce); #elif SZ_USE_SKYLAKE - sz_generate_skylake(text, length, nonce); + sz_fill_random_skylake(text, length, nonce); #elif SZ_USE_HASWELL - sz_generate_haswell(text, length, nonce); + sz_fill_random_haswell(text, length, nonce); #elif SZ_USE_NEON - sz_generate_neon(text, length, nonce); + sz_fill_random_neon(text, length, nonce); #else - sz_generate_serial(text, length, nonce); + sz_fill_random_serial(text, length, nonce); #endif } diff --git a/include/stringzilla/similarity.h b/include/stringzilla/similarity.h index 60540b33..058b1313 100644 --- a/include/stringzilla/similarity.h +++ b/include/stringzilla/similarity.h @@ -413,11 +413,11 @@ SZ_INTERNAL sz_status_t _sz_levenshtein_distance_wagner_fisher_serial( // // If the strings contain Unicode characters, let's estimate the max character width, // and use it to allocate a larger buffer to decode UTF8. 
- sz_charset_t ascii_charset; - sz_charset_init_ascii(&ascii_charset); - sz_charset_invert(&ascii_charset); - int const longer_is_ascii = sz_find_charset_serial(longer, longer_length, &ascii_charset) == SZ_NULL_CHAR; - int const shorter_is_ascii = sz_find_charset_serial(shorter, shorter_length, &ascii_charset) == SZ_NULL_CHAR; + sz_byteset_t ascii_byteset; + sz_byteset_init_ascii(&ascii_byteset); + sz_byteset_invert(&ascii_byteset); + int const longer_is_ascii = sz_find_byteset_serial(longer, longer_length, &ascii_byteset) == SZ_NULL_CHAR; + int const shorter_is_ascii = sz_find_byteset_serial(shorter, shorter_length, &ascii_byteset) == SZ_NULL_CHAR; int const will_convert_to_unicode = can_be_unicode == sz_true_k && (!longer_is_ascii || !shorter_is_ascii); if (will_convert_to_unicode) { buffer_length += (shorter_length + longer_length) * sizeof(sz_rune_t); } else { can_be_unicode = sz_false_k; } diff --git a/include/stringzilla/stringzilla.h b/include/stringzilla/stringzilla.h index 284754bd..7642f5ae 100644 --- a/include/stringzilla/stringzilla.h +++ b/include/stringzilla/stringzilla.h @@ -45,7 +45,7 @@ #include "compare.h" // `sz_equal`, `sz_order` #include "memory.h" // `sz_copy`, `sz_move`, `sz_fill` #include "hash.h" // `sz_bytesum`, `sz_hash`, `sz_state_init`, `sz_state_stream`, `sz_state_fold` -#include "find.h" // `sz_find`, `sz_find_charset`, `sz_rfind` +#include "find.h" // `sz_find`, `sz_find_byteset`, `sz_rfind` #include "small_string.h" // `sz_string_t`, `sz_string_init`, `sz_string_free` #include "similarity.h" // `sz_levenshtein_distance`, `sz_needleman_wunsch_score` #include "sort.h" // `sz_sequence_argsort`, `sz_pgrams_sort`, `sz_pgrams_sort_stable` diff --git a/include/stringzilla/stringzilla.hpp b/include/stringzilla/stringzilla.hpp index 143f252e..a1b2de28 100644 --- a/include/stringzilla/stringzilla.hpp +++ b/include/stringzilla/stringzilla.hpp @@ -79,7 +79,7 @@ namespace ashvardanian { namespace stringzilla { template -class basic_char_set; +class basic_byteset; template class basic_string_slice; template @@ -278,23 +278,23 @@ inline carray<64> const &base64() noexcept { * @brief A set of characters represented as a bitset with 256 slots. */ template -class basic_char_set { - sz_charset_t bitset_; +class basic_byteset { + sz_byteset_t bitset_; public: using char_type = char_type_; - sz_constexpr_if_cpp14 basic_char_set() noexcept { - // ! Instead of relying on the `sz_charset_init`, we have to reimplement it to support `constexpr`. + sz_constexpr_if_cpp14 basic_byteset() noexcept { + // ! Instead of relying on the `sz_byteset_init`, we have to reimplement it to support `constexpr`. bitset_._u64s[0] = 0, bitset_._u64s[1] = 0, bitset_._u64s[2] = 0, bitset_._u64s[3] = 0; } - explicit sz_constexpr_if_cpp14 basic_char_set(std::initializer_list chars) noexcept : basic_char_set() { - // ! Instead of relying on the `sz_charset_add(&bitset_, c)`, we have to reimplement it to support `constexpr`. + explicit sz_constexpr_if_cpp14 basic_byteset(std::initializer_list chars) noexcept : basic_byteset() { + // ! Instead of relying on the `sz_byteset_add(&bitset_, c)`, we have to reimplement it to support `constexpr`. 
for (auto c : chars) bitset_._u64s[sz_bitcast(sz_u8_t, c) >> 6] |= (1ull << (sz_bitcast(sz_u8_t, c) & 63u)); } - explicit sz_constexpr_if_cpp14 basic_char_set(char_type const *chars, std::size_t count_characters) noexcept - : basic_char_set() { + explicit sz_constexpr_if_cpp14 basic_byteset(char_type const *chars, std::size_t count_characters) noexcept + : basic_byteset() { for (std::size_t i = 0; i < count_characters; ++i) { char_type c = chars[i]; bitset_._u64s[sz_bitcast(sz_u8_t, c) >> 6] |= (1ull << (sz_bitcast(sz_u8_t, c) & 63u)); @@ -302,8 +302,8 @@ class basic_char_set { } template - explicit sz_constexpr_if_cpp14 basic_char_set(std::array const &chars) noexcept - : basic_char_set() { + explicit sz_constexpr_if_cpp14 basic_byteset(std::array const &chars) noexcept + : basic_byteset() { static_assert(count_characters > 0, "Character array cannot be empty"); for (std::size_t i = 0; i < count_characters; ++i) { char_type c = chars[i]; @@ -311,21 +311,21 @@ class basic_char_set { } } - sz_constexpr_if_cpp14 basic_char_set(basic_char_set const &other) noexcept : bitset_(other.bitset_) {} - sz_constexpr_if_cpp14 basic_char_set &operator=(basic_char_set const &other) noexcept { + sz_constexpr_if_cpp14 basic_byteset(basic_byteset const &other) noexcept : bitset_(other.bitset_) {} + sz_constexpr_if_cpp14 basic_byteset &operator=(basic_byteset const &other) noexcept { bitset_ = other.bitset_; return *this; } - constexpr basic_char_set operator|(basic_char_set other) const noexcept { - basic_char_set result = *this; + constexpr basic_byteset operator|(basic_byteset other) const noexcept { + basic_byteset result = *this; result.bitset_._u64s[0] |= other.bitset_._u64s[0], result.bitset_._u64s[1] |= other.bitset_._u64s[1], result.bitset_._u64s[2] |= other.bitset_._u64s[2], result.bitset_._u64s[3] |= other.bitset_._u64s[3]; return result; } - inline basic_char_set &add(char_type c) noexcept { - sz_charset_add(&bitset_, sz_bitcast(sz_u8_t, c)); + inline basic_byteset &add(char_type c) noexcept { + sz_byteset_add(&bitset_, sz_bitcast(sz_u8_t, c)); return *this; } inline std::size_t size() const noexcept { @@ -333,30 +333,30 @@ class basic_char_set { sz_u64_popcount(bitset_._u64s[0]) + sz_u64_popcount(bitset_._u64s[1]) + // sz_u64_popcount(bitset_._u64s[2]) + sz_u64_popcount(bitset_._u64s[3]); } - inline sz_charset_t &raw() noexcept { return bitset_; } - inline sz_charset_t const &raw() const noexcept { return bitset_; } - inline bool contains(char_type c) const noexcept { return sz_charset_contains(&bitset_, sz_bitcast(sz_u8_t, c)); } - inline basic_char_set inverted() const noexcept { - basic_char_set result = *this; - sz_charset_invert(&result.bitset_); + inline sz_byteset_t &raw() noexcept { return bitset_; } + inline sz_byteset_t const &raw() const noexcept { return bitset_; } + inline bool contains(char_type c) const noexcept { return sz_byteset_contains(&bitset_, sz_bitcast(sz_u8_t, c)); } + inline basic_byteset inverted() const noexcept { + basic_byteset result = *this; + sz_byteset_invert(&result.bitset_); return result; } }; -using char_set = basic_char_set; - -inline char_set ascii_letters_set() { return char_set {ascii_letters(), sizeof(ascii_letters())}; } -inline char_set ascii_lowercase_set() { return char_set {ascii_lowercase(), sizeof(ascii_lowercase())}; } -inline char_set ascii_uppercase_set() { return char_set {ascii_uppercase(), sizeof(ascii_uppercase())}; } -inline char_set ascii_printables_set() { return char_set {ascii_printables(), sizeof(ascii_printables())}; } -inline 
char_set ascii_controls_set() { return char_set {ascii_controls(), sizeof(ascii_controls())}; } -inline char_set digits_set() { return char_set {digits(), sizeof(digits())}; } -inline char_set hexdigits_set() { return char_set {hexdigits(), sizeof(hexdigits())}; } -inline char_set octdigits_set() { return char_set {octdigits(), sizeof(octdigits())}; } -inline char_set punctuation_set() { return char_set {punctuation(), sizeof(punctuation())}; } -inline char_set whitespaces_set() { return char_set {whitespaces(), sizeof(whitespaces())}; } -inline char_set newlines_set() { return char_set {newlines(), sizeof(newlines())}; } -inline char_set base64_set() { return char_set {base64(), sizeof(base64())}; } +using byteset = basic_byteset; + +inline byteset ascii_letters_set() { return byteset {ascii_letters(), sizeof(ascii_letters())}; } +inline byteset ascii_lowercase_set() { return byteset {ascii_lowercase(), sizeof(ascii_lowercase())}; } +inline byteset ascii_uppercase_set() { return byteset {ascii_uppercase(), sizeof(ascii_uppercase())}; } +inline byteset ascii_printables_set() { return byteset {ascii_printables(), sizeof(ascii_printables())}; } +inline byteset ascii_controls_set() { return byteset {ascii_controls(), sizeof(ascii_controls())}; } +inline byteset digits_set() { return byteset {digits(), sizeof(digits())}; } +inline byteset hexdigits_set() { return byteset {hexdigits(), sizeof(hexdigits())}; } +inline byteset octdigits_set() { return byteset {octdigits(), sizeof(octdigits())}; } +inline byteset punctuation_set() { return byteset {punctuation(), sizeof(punctuation())}; } +inline byteset whitespaces_set() { return byteset {whitespaces(), sizeof(whitespaces())}; } +inline byteset newlines_set() { return byteset {newlines(), sizeof(newlines())}; } +inline byteset base64_set() { return byteset {base64(), sizeof(base64())}; } /** * @brief A look-up table for character replacement operations. @@ -1667,10 +1667,10 @@ class basic_string_slice { } /** @brief Find the first occurrence of a character from a set. */ - size_type find(char_set set) const noexcept { return find_first_of(set); } + size_type find(byteset set) const noexcept { return find_first_of(set); } /** @brief Find the last occurrence of a character from a set. */ - size_type rfind(char_set set) const noexcept { return find_last_of(set); } + size_type rfind(byteset set) const noexcept { return find_last_of(set); } #pragma endregion #pragma region Returning Partitions @@ -1682,7 +1682,7 @@ class basic_string_slice { partition_type partition(value_type pattern) const noexcept { return partition_(string_view(&pattern, 1), 1); } /** @brief Split the string into three parts, before the match, the match itself, and after it. */ - partition_type partition(char_set pattern) const noexcept { return partition_(pattern, 1); } + partition_type partition(byteset pattern) const noexcept { return partition_(pattern, 1); } /** @brief Split the string into three parts, before the @b last match, the last match itself, and after it. */ partition_type rpartition(string_view pattern) const noexcept { return rpartition_(pattern, pattern.length()); } @@ -1691,7 +1691,7 @@ class basic_string_slice { partition_type rpartition(value_type pattern) const noexcept { return rpartition_(string_view(&pattern, 1), 1); } /** @brief Split the string into three parts, before the @b last match, the last match itself, and after it. 
*/ - partition_type rpartition(char_set pattern) const noexcept { return rpartition_(pattern, 1); } + partition_type rpartition(byteset pattern) const noexcept { return rpartition_(pattern, 1); } #pragma endregion #pragma endregion @@ -1699,7 +1699,7 @@ class basic_string_slice { #pragma region Matching Character Sets // `isascii` is a macro in MSVC headers - bool contains_only(char_set set) const noexcept { return find_first_not_of(set) == npos; } + bool contains_only(byteset set) const noexcept { return find_first_not_of(set) == npos; } bool is_alpha() const noexcept { return !empty() && contains_only(ascii_letters_set()); } bool is_alnum() const noexcept { return !empty() && contains_only(ascii_letters_set() | digits_set()); } bool is_ascii() const noexcept { return empty() || contains_only(ascii_controls_set() | ascii_printables_set()); } @@ -1715,8 +1715,8 @@ class basic_string_slice { * @param skip Number of characters to skip before the search. * @warning The behavior is @b undefined if `skip > size()`. */ - size_type find_first_of(char_set set, size_type skip = 0) const noexcept { - auto ptr = sz_find_charset(start_ + skip, length_ - skip, &set.raw()); + size_type find_first_of(byteset set, size_type skip = 0) const noexcept { + auto ptr = sz_find_byteset(start_ + skip, length_ - skip, &set.raw()); return ptr ? ptr - start_ : npos; } @@ -1725,30 +1725,30 @@ class basic_string_slice { * @param skip The number of first characters to be skipped. * @warning The behavior is @b undefined if `skip > size()`. */ - size_type find_first_not_of(char_set set, size_type skip = 0) const noexcept { + size_type find_first_not_of(byteset set, size_type skip = 0) const noexcept { return find_first_of(set.inverted(), skip); } /** * @brief Find the last occurrence of a character from a set. */ - size_type find_last_of(char_set set) const noexcept { - auto ptr = sz_rfind_charset(start_, length_, &set.raw()); + size_type find_last_of(byteset set) const noexcept { + auto ptr = sz_rfind_byteset(start_, length_, &set.raw()); return ptr ? ptr - start_ : npos; } /** * @brief Find the last occurrence of a character outside a set. */ - size_type find_last_not_of(char_set set) const noexcept { return find_last_of(set.inverted()); } + size_type find_last_not_of(byteset set) const noexcept { return find_last_of(set.inverted()); } /** * @brief Find the last occurrence of a character from a set. * @param until The offset of the last character to be considered. */ - size_type find_last_of(char_set set, size_type until) const noexcept { + size_type find_last_of(byteset set, size_type until) const noexcept { auto len = sz_min_of_two(until + 1, length_); - auto ptr = sz_rfind_charset(start_, len, &set.raw()); + auto ptr = sz_rfind_byteset(start_, len, &set.raw()); return ptr ? ptr - start_ : npos; } @@ -1756,7 +1756,7 @@ class basic_string_slice { * @brief Find the last occurrence of a character outside a set. * @param until The offset of the last character to be considered. */ - size_type find_last_not_of(char_set set, size_type until) const noexcept { + size_type find_last_not_of(byteset set, size_type until) const noexcept { return find_last_of(set.inverted(), until); } @@ -1839,9 +1839,9 @@ class basic_string_slice { * @brief Python-like convenience function, dropping prefix formed of given characters. * Similar to `boost::algorithm::trim_left_if(str, is_any_of(set))`. 
*/ - string_slice lstrip(char_set set) const noexcept { + string_slice lstrip(byteset set) const noexcept { set = set.inverted(); - auto new_start = (pointer)sz_find_charset(start_, length_, &set.raw()); + auto new_start = (pointer)sz_find_byteset(start_, length_, &set.raw()); return new_start ? string_slice {new_start, length_ - static_cast(new_start - start_)} : string_slice(); } @@ -1850,9 +1850,9 @@ class basic_string_slice { * @brief Python-like convenience function, dropping suffix formed of given characters. * Similar to `boost::algorithm::trim_right_if(str, is_any_of(set))`. */ - string_slice rstrip(char_set set) const noexcept { + string_slice rstrip(byteset set) const noexcept { set = set.inverted(); - auto new_end = (pointer)sz_rfind_charset(start_, length_, &set.raw()); + auto new_end = (pointer)sz_rfind_byteset(start_, length_, &set.raw()); return new_end ? string_slice {start_, static_cast(new_end - start_ + 1)} : string_slice(); } @@ -1860,12 +1860,12 @@ class basic_string_slice { * @brief Python-like convenience function, dropping both the prefix & the suffix formed of given characters. * Similar to `boost::algorithm::trim_if(str, is_any_of(set))`. */ - string_slice strip(char_set set) const noexcept { + string_slice strip(byteset set) const noexcept { set = set.inverted(); - auto new_start = (pointer)sz_find_charset(start_, length_, &set.raw()); + auto new_start = (pointer)sz_find_byteset(start_, length_, &set.raw()); return new_start ? string_slice {new_start, static_cast( - sz_rfind_charset(new_start, length_ - (new_start - start_), &set.raw()) - + sz_rfind_byteset(new_start, length_ - (new_start - start_), &set.raw()) - new_start + 1)} : string_slice(); } @@ -1881,8 +1881,8 @@ class basic_string_slice { using find_disjoint_type = range_matches>; using rfind_disjoint_type = range_rmatches>; - using find_all_chars_type = range_matches>; - using rfind_all_chars_type = range_rmatches>; + using find_all_chars_type = range_matches>; + using rfind_all_chars_type = range_rmatches>; /** @brief Find all potentially @b overlapping occurrences of a given string. */ find_all_type find_all(string_view needle, include_overlaps_type = {}) const noexcept { return {*this, needle}; } @@ -1897,16 +1897,16 @@ class basic_string_slice { rfind_disjoint_type rfind_all(string_view needle, exclude_overlaps_type) const noexcept { return {*this, needle}; } /** @brief Find all occurrences of given characters. */ - find_all_chars_type find_all(char_set set) const noexcept { return {*this, {set}}; } + find_all_chars_type find_all(byteset set) const noexcept { return {*this, {set}}; } /** @brief Find all occurrences of given characters in @b reverse order. */ - rfind_all_chars_type rfind_all(char_set set) const noexcept { return {*this, {set}}; } + rfind_all_chars_type rfind_all(byteset set) const noexcept { return {*this, {set}}; } using split_type = range_splits>; using rsplit_type = range_rsplits>; - using split_chars_type = range_splits>; - using rsplit_chars_type = range_rsplits>; + using split_chars_type = range_splits>; + using rsplit_chars_type = range_rsplits>; /** @brief Split around occurrences of a given string. */ split_type split(string_view delimiter) const noexcept { return {*this, delimiter}; } @@ -1915,10 +1915,10 @@ class basic_string_slice { rsplit_type rsplit(string_view delimiter) const noexcept { return {*this, delimiter}; } /** @brief Split around occurrences of given characters. 
*/ - split_chars_type split(char_set set = whitespaces_set()) const noexcept { return {*this, {set}}; } + split_chars_type split(byteset set = whitespaces_set()) const noexcept { return {*this, {set}}; } /** @brief Split around occurrences of given characters in @b reverse order. */ - rsplit_chars_type rsplit(char_set set = whitespaces_set()) const noexcept { return {*this, {set}}; } + rsplit_chars_type rsplit(byteset set = whitespaces_set()) const noexcept { return {*this, {set}}; } /** @brief Split around the occurrences of all newline characters. */ split_chars_type splitlines() const noexcept { return split(newlines_set()); } @@ -1934,8 +1934,8 @@ class basic_string_slice { size_type bytesum() const noexcept { return static_cast(sz_bytesum(start_, length_)); } /** @brief Populate a character set with characters present in this string. */ - char_set as_set() const noexcept { - char_set set; + byteset as_set() const noexcept { + byteset set; for (auto c : *this) set.add(c); return set; } @@ -2555,17 +2555,17 @@ class basic_string { } /** @brief Find the first occurrence of a character from a set. */ - size_type find(char_set set) const noexcept { return view().find(set); } + size_type find(byteset set) const noexcept { return view().find(set); } /** @brief Find the last occurrence of a character from a set. */ - size_type rfind(char_set set) const noexcept { return view().rfind(set); } + size_type rfind(byteset set) const noexcept { return view().rfind(set); } #pragma endregion #pragma endregion #pragma region Matching Character Sets - bool contains_only(char_set set) const noexcept { return find_first_not_of(set) == npos; } + bool contains_only(byteset set) const noexcept { return find_first_not_of(set) == npos; } bool is_alpha() const noexcept { return !empty() && contains_only(ascii_letters_set()); } bool is_alnum() const noexcept { return !empty() && contains_only(ascii_letters_set() | digits_set()); } bool is_ascii() const noexcept { return empty() || contains_only(ascii_controls_set() | ascii_printables_set()); } @@ -2583,38 +2583,38 @@ class basic_string { * @param skip Number of characters to skip before the search. * @warning The behavior is @b undefined if `skip > size()`. */ - size_type find_first_of(char_set set, size_type skip = 0) const noexcept { return view().find_first_of(set, skip); } + size_type find_first_of(byteset set, size_type skip = 0) const noexcept { return view().find_first_of(set, skip); } /** * @brief Find the first occurrence of a character outside a set. * @param skip The number of first characters to be skipped. * @warning The behavior is @b undefined if `skip > size()`. */ - size_type find_first_not_of(char_set set, size_type skip = 0) const noexcept { + size_type find_first_not_of(byteset set, size_type skip = 0) const noexcept { return view().find_first_not_of(set, skip); } /** * @brief Find the last occurrence of a character from a set. */ - size_type find_last_of(char_set set) const noexcept { return view().find_last_of(set); } + size_type find_last_of(byteset set) const noexcept { return view().find_last_of(set); } /** * @brief Find the last occurrence of a character outside a set. */ - size_type find_last_not_of(char_set set) const noexcept { return view().find_last_not_of(set); } + size_type find_last_not_of(byteset set) const noexcept { return view().find_last_not_of(set); } /** * @brief Find the last occurrence of a character from a set. * @param until The offset of the last character to be considered. 
*/ - size_type find_last_of(char_set set, size_type until) const noexcept { return view().find_last_of(set, until); } + size_type find_last_of(byteset set, size_type until) const noexcept { return view().find_last_of(set, until); } /** * @brief Find the last occurrence of a character outside a set. * @param until The offset of the last character to be considered. */ - size_type find_last_not_of(char_set set, size_type until) const noexcept { + size_type find_last_not_of(byteset set, size_type until) const noexcept { return view().find_last_not_of(set, until); } @@ -2697,7 +2697,7 @@ class basic_string { * @brief Python-like convenience function, dropping prefix formed of given characters. * Similar to `boost::algorithm::trim_left_if(str, is_any_of(set))`. */ - basic_string &lstrip(char_set set) noexcept { + basic_string &lstrip(byteset set) noexcept { auto remaining = view().lstrip(set); remove_prefix(size() - remaining.size()); return *this; @@ -2707,7 +2707,7 @@ class basic_string { * @brief Python-like convenience function, dropping suffix formed of given characters. * Similar to `boost::algorithm::trim_right_if(str, is_any_of(set))`. */ - basic_string &rstrip(char_set set) noexcept { + basic_string &rstrip(byteset set) noexcept { auto remaining = view().rstrip(set); remove_suffix(size() - remaining.size()); return *this; @@ -2717,7 +2717,7 @@ class basic_string { * @brief Python-like convenience function, dropping both the prefix & the suffix formed of given characters. * Similar to `boost::algorithm::trim_if(str, is_any_of(set))`. */ - basic_string &strip(char_set set) noexcept { return lstrip(set).rstrip(set); } + basic_string &strip(byteset set) noexcept { return lstrip(set).rstrip(set); } #pragma endregion #pragma endregion @@ -3339,7 +3339,7 @@ class basic_string { sz_ptr_t start; sz_size_t length; sz_string_range(&string_, &start, &length); - sz_generate(start, length, nonce); + sz_fill_random(start, length, nonce); return *this; } @@ -3393,7 +3393,7 @@ class basic_string { * and might be suboptimal, if you are exporting the cleaned-up string to another buffer. * The algorithm is suboptimal when this string is made exclusively of the pattern. */ - basic_string &replace_all(char_set pattern, string_view replacement) noexcept(false) { + basic_string &replace_all(byteset pattern, string_view replacement) noexcept(false) { if (!try_replace_all(pattern, replacement)) throw std::bad_alloc(); return *this; } @@ -3418,8 +3418,8 @@ class basic_string { * and might be suboptimal, if you are exporting the cleaned-up string to another buffer. * The algorithm is suboptimal when this string is made exclusively of the pattern. */ - bool try_replace_all(char_set pattern, string_view replacement) noexcept { - return try_replace_all_(pattern, replacement); + bool try_replace_all(byteset pattern, string_view replacement) noexcept { + return try_replace_all_(pattern, replacement); } /** @@ -3458,8 +3458,8 @@ static_assert(sizeof(string) == 4 * sizeof(void *), "String size must be 4 point namespace literals { constexpr string_view operator""_sv(char const *str, std::size_t length) noexcept { return {str, length}; } -sz_constexpr_if_cpp14 char_set operator""_cs(char const *str, std::size_t length) noexcept { - return char_set {str, length}; +sz_constexpr_if_cpp14 byteset operator""_bs(char const *str, std::size_t length) noexcept { + return byteset {str, length}; } } // namespace literals @@ -3565,7 +3565,7 @@ bool basic_string::try_replace_all_(pattern_type pattern // 1. 
The pattern and the replacement are of the same length. Piece of cake! // 2. The pattern is longer than the replacement. We need to compact the strings. // 3. The pattern is shorter than the replacement. We may have to allocate more memory. - using matcher_type = typename std::conditional::value, + using matcher_type = typename std::conditional::value, matcher_find_first_of, matcher_find>::type; matcher_type matcher({pattern}); @@ -3611,7 +3611,7 @@ bool basic_string::try_replace_all_(pattern_type pattern // 3. The pattern is shorter than the replacement. We may have to allocate more memory. else { - using rmatcher_type = typename std::conditional::value, + using rmatcher_type = typename std::conditional::value, matcher_find_last_of, matcher_rfind>::type; using rmatches_type = range_rmatches; @@ -3927,7 +3927,7 @@ std::ptrdiff_t alignment_score( template void randomize(basic_string_slice string, sz_u64_t nonce) noexcept { static_assert(!std::is_const::value, "The string must be mutable."); - sz_generate(string.data(), string.size(), nonce); + sz_fill_random(string.data(), string.size(), nonce); } /** diff --git a/include/stringzilla/types.h b/include/stringzilla/types.h index 75d76e61..39a6352b 100644 --- a/include/stringzilla/types.h +++ b/include/stringzilla/types.h @@ -22,7 +22,7 @@ * - `sz_string_view_t` - for a C-style `std::string_view`-like structure. * - `sz_memory_allocator_t` - a wrapper for memory-management functions. * - `sz_sequence_t` - a wrapper to access strings forming a sequential container. - * - `sz_charset_t` - a bitset for 256 possible byte values. + * - `sz_byteset_t` - a bitset for 256 possible byte values. */ #ifndef STRINGZILLA_TYPES_H_ #define STRINGZILLA_TYPES_H_ @@ -344,10 +344,13 @@ typedef enum { */ sz_bad_alloc_k = -1, /** - * For algorithms that have an upper bound on some parameter, like the maximum number of iterations, - * or the maximum edit distance, this status indicates that the limit was reached. + * For algorithms that require UTF8 input, this status indicates that the input is invalid. */ - sz_reached_limit_k = -2, + sz_invalid_utf8_k = -2, + /** + * For algorithms that take collections of unique elements, this status indicates presence of duplicates. + */ + sz_contains_duplicates_k = -3, } sz_status_t; /** @@ -374,33 +377,47 @@ typedef struct sz_string_view_t { #pragma region Character Sets /** - * @brief Bit-set structure for 256 possible byte values. Useful for filtering and search. - * @see sz_charset_init, sz_charset_add, sz_charset_contains, sz_charset_invert + * @brief Bit-set semi-opaque structure for 256 possible byte values. Useful for filtering and search. + * @sa sz_byteset_init, sz_byteset_add, sz_byteset_contains, sz_byteset_invert + * + * Example usage: + * + * @code{.c} + * #include + * int main() { + * char const *alphabet = "abcdefghijklmnopqrstuvwxyz"; + * sz_byteset_t byteset; + * sz_byteset_init(&byteset); + * for (sz_size_t i = 0; i < 26; ++i) + * sz_byteset_add(&byteset, alphabet[i]); + * return sz_byteset_contains(&byteset, 'a') && !sz_byteset_contains(&byteset, 'A') ? 0 : 1; + * } + * @endcode */ -typedef union sz_charset_t { +typedef union sz_byteset_t { sz_u64_t _u64s[4]; sz_u32_t _u32s[8]; sz_u16_t _u16s[16]; sz_u8_t _u8s[32]; -} sz_charset_t; +} sz_byteset_t; /** @brief Initializes a bit-set to an empty collection, meaning - all characters are banned. 
*/ -SZ_PUBLIC void sz_charset_init(sz_charset_t *s) { s->_u64s[0] = s->_u64s[1] = s->_u64s[2] = s->_u64s[3] = 0; } +SZ_PUBLIC void sz_byteset_init(sz_byteset_t *s) { s->_u64s[0] = s->_u64s[1] = s->_u64s[2] = s->_u64s[3] = 0; } /** @brief Initializes a bit-set to all ASCII character. */ -SZ_PUBLIC void sz_charset_init_ascii(sz_charset_t *s) { +SZ_PUBLIC void sz_byteset_init_ascii(sz_byteset_t *s) { s->_u64s[0] = s->_u64s[1] = 0xFFFFFFFFFFFFFFFFull; s->_u64s[2] = s->_u64s[3] = 0; } /** @brief Adds a character to the set and accepts @b unsigned integers. */ -SZ_PUBLIC void sz_charset_add_u8(sz_charset_t *s, sz_u8_t c) { s->_u64s[c >> 6] |= (1ull << (c & 63u)); } +SZ_PUBLIC void sz_byteset_add_u8(sz_byteset_t *s, sz_u8_t c) { s->_u64s[c >> 6] |= (1ull << (c & 63u)); } -/** @brief Adds a character to the set. Consider @b sz_charset_add_u8. */ -SZ_PUBLIC void sz_charset_add(sz_charset_t *s, char c) { sz_charset_add_u8(s, *(sz_u8_t *)(&c)); } // bitcast +/** @brief Adds a character to the set. Consider @b sz_byteset_add_u8. */ +SZ_PUBLIC void sz_byteset_add(sz_byteset_t *s, char c) { sz_byteset_add_u8(s, *(sz_u8_t *)(&c)); } // bitcast /** @brief Checks if the set contains a given character and accepts @b unsigned integers. */ -SZ_PUBLIC sz_bool_t sz_charset_contains_u8(sz_charset_t const *s, sz_u8_t c) { +SZ_PUBLIC sz_bool_t sz_byteset_contains_u8(sz_byteset_t const *s, sz_u8_t c) { // Checking the bit can be done in different ways: // - (s->_u64s[c >> 6] & (1ull << (c & 63u))) != 0 // - (s->_u32s[c >> 5] & (1u << (c & 31u))) != 0 @@ -409,13 +426,13 @@ SZ_PUBLIC sz_bool_t sz_charset_contains_u8(sz_charset_t const *s, sz_u8_t c) { return (sz_bool_t)((s->_u64s[c >> 6] & (1ull << (c & 63u))) != 0); } -/** @brief Checks if the set contains a given character. Consider @b sz_charset_contains_u8. */ -SZ_PUBLIC sz_bool_t sz_charset_contains(sz_charset_t const *s, char c) { - return sz_charset_contains_u8(s, *(sz_u8_t *)(&c)); // bitcast +/** @brief Checks if the set contains a given character. Consider @b sz_byteset_contains_u8. */ +SZ_PUBLIC sz_bool_t sz_byteset_contains(sz_byteset_t const *s, char c) { + return sz_byteset_contains_u8(s, *(sz_u8_t *)(&c)); // bitcast } /** @brief Inverts the contents of the set, so allowed character get disallowed, and vice versa. */ -SZ_PUBLIC void sz_charset_invert(sz_charset_t *s) { +SZ_PUBLIC void sz_byteset_invert(sz_byteset_t *s) { s->_u64s[0] ^= 0xFFFFFFFFFFFFFFFFull, s->_u64s[1] ^= 0xFFFFFFFFFFFFFFFFull, // s->_u64s[2] ^= 0xFFFFFFFFFFFFFFFFull, s->_u64s[3] ^= 0xFFFFFFFFFFFFFFFFull; } @@ -476,8 +493,8 @@ typedef sz_u64_t (*sz_hash_state_fold_t)(struct sz_hash_state_t const *); /** @brief Signature of `sz_bytesum`. */ typedef sz_u64_t (*sz_bytesum_t)(sz_cptr_t, sz_size_t); -/** @brief Signature of `sz_generate`. */ -typedef void (*sz_generate_t)(sz_ptr_t, sz_size_t, sz_u64_t); +/** @brief Signature of `sz_fill_random`. */ +typedef void (*sz_fill_random_t)(sz_ptr_t, sz_size_t, sz_u64_t); /** @brief Signature of `sz_equal`. */ typedef sz_bool_t (*sz_equal_t)(sz_cptr_t, sz_cptr_t, sz_size_t); @@ -486,7 +503,7 @@ typedef sz_bool_t (*sz_equal_t)(sz_cptr_t, sz_cptr_t, sz_size_t); typedef sz_ordering_t (*sz_order_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); /** @brief Signature of `sz_lookup`. */ -typedef void (*sz_lookup_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_ptr_t); +typedef void (*sz_lookup_t)(sz_ptr_t, sz_size_t, sz_cptr_t, sz_cptr_t); /** @brief Signature of `sz_move`. 
*/ typedef void (*sz_move_t)(sz_ptr_t, sz_cptr_t, sz_size_t); @@ -501,7 +518,7 @@ typedef sz_cptr_t (*sz_find_byte_t)(sz_cptr_t, sz_size_t, sz_cptr_t); typedef sz_cptr_t (*sz_find_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); /** @brief Signature of `sz_find_set`. */ -typedef sz_cptr_t (*sz_find_set_t)(sz_cptr_t, sz_size_t, sz_charset_t const *); +typedef sz_cptr_t (*sz_find_set_t)(sz_cptr_t, sz_size_t, sz_byteset_t const *); /** @brief Signature of `sz_hamming_distance`. */ typedef sz_status_t (*sz_hamming_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t, sz_size_t *); @@ -515,18 +532,14 @@ typedef sz_status_t (*sz_needleman_wunsch_score_t)(sz_cptr_t, sz_size_t, sz_cptr sz_error_cost_t, sz_memory_allocator_t *, sz_ssize_t *); /** @brief Signature of `sz_sequence_argsort`. */ -typedef sz_status_t (*sz_sequence_argsort_t)(struct sz_sequence_t const *, sz_memory_allocator_t *, sz_sorted_idx_t *, - sz_bool_t *); +typedef sz_status_t (*sz_sequence_argsort_t)(struct sz_sequence_t const *, sz_memory_allocator_t *, sz_sorted_idx_t *); /** @brief Signature of `sz_pgrams_sort`. */ -typedef sz_status_t (*sz_pgrams_sort_t)(sz_pgram_t *, sz_size_t, sz_memory_allocator_t *, sz_sorted_idx_t *, - sz_bool_t *); - -/** @brief Signature of `sz_sequence_argsort_stable`. */ -typedef sz_sequence_argsort_t sz_sequence_argsort_stable_t; +typedef sz_status_t (*sz_pgrams_sort_t)(sz_pgram_t *, sz_size_t, sz_memory_allocator_t *, sz_sorted_idx_t *); -/** @brief Signature of `sz_pgrams_sort_stable`. */ -typedef sz_pgrams_sort_t sz_pgrams_sort_stable_t; +/** @brief Signature of `sz_sequence_join`. */ +typedef sz_status_t (*sz_sequence_join_t)(struct sz_sequence_t const *, struct sz_sequence_t const *, + sz_memory_allocator_t *, sz_size_t *, sz_sorted_idx_t *, sz_sorted_idx_t *); #pragma endregion diff --git a/scripts/bench_search.cpp b/scripts/bench_search.cpp index 7380a697..6ffd9790 100644 --- a/scripts/bench_search.cpp +++ b/scripts/bench_search.cpp @@ -123,11 +123,11 @@ tracked_binary_functions_t rfind_functions() { return result; } -tracked_binary_functions_t find_charset_functions() { +tracked_binary_functions_t find_byteset_functions() { // ! Despite receiving string-views, following functions are assuming the strings are null-terminated. auto wrap_sz = [](auto function) -> binary_function_t { return binary_function_t([function](std::string_view h, std::string_view n) { - sz::char_set set; + sz::byteset set; for (auto c : n) set.add(c); sz_cptr_t match = function(h.data(), h.size(), &set.raw()); return (match ? match - h.data() : h.size()); @@ -139,26 +139,26 @@ tracked_binary_functions_t find_charset_functions() { auto match = h.find_first_of(n); return (match == std::string_view::npos ? 
h.size() : match); }}, - {"sz_find_charset_serial", wrap_sz(sz_find_charset_serial), true}, + {"sz_find_byteset_serial", wrap_sz(sz_find_byteset_serial), true}, #if SZ_USE_HASWELL - {"sz_find_charset_haswell", wrap_sz(sz_find_charset_haswell), true}, + {"sz_find_byteset_haswell", wrap_sz(sz_find_byteset_haswell), true}, #endif #if SZ_USE_ICE - {"sz_find_charset_ice", wrap_sz(sz_find_charset_ice), true}, + {"sz_find_byteset_ice", wrap_sz(sz_find_byteset_ice), true}, #endif #if SZ_USE_NEON - {"sz_find_charset_neon", wrap_sz(sz_find_charset_neon), true}, + {"sz_find_byteset_neon", wrap_sz(sz_find_byteset_neon), true}, #endif {"strcspn", [](std::string_view h, std::string_view n) { return strcspn(h.data(), n.data()); }}, }; return result; } -tracked_binary_functions_t rfind_charset_functions() { +tracked_binary_functions_t rfind_byteset_functions() { // ! Despite receiving string-views, following functions are assuming the strings are null-terminated. auto wrap_sz = [](auto function) -> binary_function_t { return binary_function_t([function](std::string_view h, std::string_view n) { - sz::char_set set; + sz::byteset set; for (auto c : n) set.add(c); sz_cptr_t match = function(h.data(), h.size(), &set.raw()); return (match ? match - h.data() : 0); @@ -170,12 +170,12 @@ tracked_binary_functions_t rfind_charset_functions() { auto match = h.find_last_of(n); return (match == std::string_view::npos ? 0 : match); }}, - {"sz_rfind_charset_serial", wrap_sz(sz_rfind_charset_serial), true}, + {"sz_rfind_byteset_serial", wrap_sz(sz_rfind_byteset_serial), true}, #if SZ_USE_ICE - {"sz_rfind_charset_ice", wrap_sz(sz_rfind_charset_ice), true}, + {"sz_rfind_byteset_ice", wrap_sz(sz_rfind_byteset_ice), true}, #endif #if SZ_USE_NEON - {"sz_rfind_charset_neon", wrap_sz(sz_rfind_charset_neon), true}, + {"sz_rfind_byteset_neon", wrap_sz(sz_rfind_byteset_neon), true}, #endif }; return result; @@ -304,25 +304,25 @@ int main(int argc, char const **argv) { bench_rfinds(dataset.text, {" "}, rfind_functions()); std::printf("Benchmarking for an [\\n\\r\\v\\f] RegEx:\n"); - bench_finds(dataset.text, {"\n\r\v\f"}, find_charset_functions()); - bench_rfinds(dataset.text, {"\n\r\v\f"}, rfind_charset_functions()); + bench_finds(dataset.text, {"\n\r\v\f"}, find_byteset_functions()); + bench_rfinds(dataset.text, {"\n\r\v\f"}, rfind_byteset_functions()); // Typical ASCII tokenization and validation benchmarks std::printf("Benchmarking for all whitespaces:\n"); - bench_finds(dataset.text, {{sz::whitespaces(), sizeof(sz::whitespaces())}}, find_charset_functions()); - bench_rfinds(dataset.text, {{sz::whitespaces(), sizeof(sz::whitespaces())}}, rfind_charset_functions()); + bench_finds(dataset.text, {{sz::whitespaces(), sizeof(sz::whitespaces())}}, find_byteset_functions()); + bench_rfinds(dataset.text, {{sz::whitespaces(), sizeof(sz::whitespaces())}}, rfind_byteset_functions()); std::printf("Benchmarking for HTML tag start/end:\n"); - bench_finds(dataset.text, {"<>"}, find_charset_functions()); - bench_rfinds(dataset.text, {"<>"}, rfind_charset_functions()); + bench_finds(dataset.text, {"<>"}, find_byteset_functions()); + bench_rfinds(dataset.text, {"<>"}, rfind_byteset_functions()); std::printf("Benchmarking for punctuation marks:\n"); - bench_finds(dataset.text, {{sz::punctuation(), sizeof(sz::punctuation())}}, find_charset_functions()); - bench_rfinds(dataset.text, {{sz::punctuation(), sizeof(sz::punctuation())}}, rfind_charset_functions()); + bench_finds(dataset.text, {{sz::punctuation(), sizeof(sz::punctuation())}}, 
find_byteset_functions()); + bench_rfinds(dataset.text, {{sz::punctuation(), sizeof(sz::punctuation())}}, rfind_byteset_functions()); std::printf("Benchmarking for non-printable characters:\n"); - bench_finds(dataset.text, {{sz::ascii_controls(), sizeof(sz::ascii_controls())}}, find_charset_functions()); - bench_rfinds(dataset.text, {{sz::ascii_controls(), sizeof(sz::ascii_controls())}}, rfind_charset_functions()); + bench_finds(dataset.text, {{sz::ascii_controls(), sizeof(sz::ascii_controls())}}, find_byteset_functions()); + bench_rfinds(dataset.text, {{sz::ascii_controls(), sizeof(sz::ascii_controls())}}, rfind_byteset_functions()); // Baseline benchmarks for present tokens, coming in all lengths std::printf("Benchmarking on present lines:\n"); diff --git a/scripts/test.cpp b/scripts/test.cpp index 8465cf15..a0eac08e 100644 --- a/scripts/test.cpp +++ b/scripts/test.cpp @@ -63,7 +63,7 @@ namespace sz = ashvardanian::stringzilla; using namespace sz::scripts; using sz::literals::operator""_sv; // for `sz::string_view` -using sz::literals::operator""_cs; // for `sz::char_set` +using sz::literals::operator""_bs; // for `sz::byteset` /* * Instantiate all the templates to make the symbols visible and also check @@ -75,7 +75,7 @@ template class std::basic_string_view; template class sz::basic_string_slice; template class std::basic_string; template class sz::basic_string; -template class sz::basic_char_set; +template class sz::basic_byteset; template class std::vector; template class std::map; @@ -202,7 +202,7 @@ static void test_hashing_on_platform( // * @brief Tests Pseudo-Random Number Generators (PRNGs) ensuring that the same nonce * produces exactly the same output across different SIMD implementations. */ -static void test_random_generator_on_platform(sz_generate_t generate_base, sz_generate_t generate_simd) { +static void test_random_generator_on_platform(sz_fill_random_t generate_base, sz_fill_random_t generate_simd) { auto test_on_nonce = [&](std::size_t length, sz_u64_t nonce) { std::string text_base(length, '\0'); @@ -231,7 +231,7 @@ static void test_simd_against_serial() { sz_hash_state_stream_serial, sz_hash_state_fold_serial, // sz_hash_haswell, sz_hash_state_init_haswell, // sz_hash_state_stream_haswell, sz_hash_state_fold_haswell); - test_random_generator_on_platform(sz_generate_serial, sz_generate_haswell); + test_random_generator_on_platform(sz_fill_random_serial, sz_fill_random_haswell); #endif #if SZ_USE_SKYLAKE test_hashing_on_platform( // @@ -239,7 +239,7 @@ static void test_simd_against_serial() { sz_hash_state_stream_serial, sz_hash_state_fold_serial, // sz_hash_skylake, sz_hash_state_init_skylake, // sz_hash_state_stream_skylake, sz_hash_state_fold_skylake); - test_random_generator_on_platform(sz_generate_serial, sz_generate_skylake); + test_random_generator_on_platform(sz_fill_random_serial, sz_fill_random_skylake); #endif #if SZ_USE_ICE test_hashing_on_platform( // @@ -247,7 +247,7 @@ static void test_simd_against_serial() { sz_hash_state_stream_serial, sz_hash_state_fold_serial, // sz_hash_ice, sz_hash_state_init_ice, // sz_hash_state_stream_ice, sz_hash_state_fold_ice); - test_random_generator_on_platform(sz_generate_serial, sz_generate_ice); + test_random_generator_on_platform(sz_fill_random_serial, sz_fill_random_ice); #endif #if SZ_USE_NEON test_hashing_on_platform( // @@ -255,7 +255,7 @@ static void test_simd_against_serial() { sz_hash_state_stream_serial, sz_hash_state_fold_serial, // sz_hash_neon, sz_hash_state_init_neon, // sz_hash_state_stream_neon, 
sz_hash_state_fold_neon); - test_random_generator_on_platform(sz_generate_serial, sz_generate_neon); + test_random_generator_on_platform(sz_fill_random_serial, sz_fill_random_neon); #endif }; @@ -268,13 +268,13 @@ static void test_ascii_utilities() { using str = string_type; - assert("aaa"_cs.size() == 1ull); - assert("\0\0"_cs.size() == 1ull); - assert("abc"_cs.size() == 3ull); - assert("a\0bc"_cs.size() == 4ull); + assert("aaa"_bs.size() == 1ull); + assert("\0\0"_bs.size() == 1ull); + assert("abc"_bs.size() == 3ull); + assert("a\0bc"_bs.size() == 4ull); - assert(!"abc"_cs.contains('\0')); - assert(str("bca").contains_only("abc"_cs)); + assert(!"abc"_bs.contains('\0')); + assert(str("bca").contains_only("abc"_bs)); assert(!str("").is_alpha()); assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ").is_alpha()); @@ -309,9 +309,9 @@ static void test_ascii_utilities() { assert(str("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!@#$%^&*()_+").is_printable()); assert(!str("012🔥").is_printable()); - assert(str("").contains_only("abc"_cs)); - assert(str("abc").contains_only("abc"_cs)); - assert(!str("abcd").contains_only("abc"_cs)); + assert(str("").contains_only("abc"_bs)); + assert(str("abc").contains_only("abc"_bs)); + assert(!str("abcd").contains_only("abc"_bs)); } inline void expect_equality(char const *a, char const *b, std::size_t size) { @@ -1026,9 +1026,9 @@ void test_non_stl_extensions_for_updates() { assert_scoped(str s = "hello", s.replace_all("xx", "xx"), s == "hello"); assert_scoped(str s = "hello", s.replace_all("l", "1"), s == "he11o"); assert_scoped(str s = "hello", s.replace_all("he", "al"), s == "alllo"); - assert_scoped(str s = "hello", s.replace_all("x"_cs, "!"), s == "hello"); - assert_scoped(str s = "hello", s.replace_all("o"_cs, "!"), s == "hell!"); - assert_scoped(str s = "hello", s.replace_all("ho"_cs, "!"), s == "!ell!"); + assert_scoped(str s = "hello", s.replace_all("x"_bs, "!"), s == "hello"); + assert_scoped(str s = "hello", s.replace_all("o"_bs, "!"), s == "hell!"); + assert_scoped(str s = "hello", s.replace_all("ho"_bs, "!"), s == "!ell!"); // Shorter replacements. assert_scoped(str s = "hello", s.replace_all("xx", "x"), s == "hello"); @@ -1036,8 +1036,8 @@ void test_non_stl_extensions_for_updates() { assert_scoped(str s = "hello", s.replace_all("h", ""), s == "ello"); assert_scoped(str s = "hello", s.replace_all("o", ""), s == "hell"); assert_scoped(str s = "hello", s.replace_all("llo", "!"), s == "he!"); - assert_scoped(str s = "hello", s.replace_all("x"_cs, ""), s == "hello"); - assert_scoped(str s = "hello", s.replace_all("lo"_cs, ""), s == "he"); + assert_scoped(str s = "hello", s.replace_all("x"_bs, ""), s == "hello"); + assert_scoped(str s = "hello", s.replace_all("lo"_bs, ""), s == "he"); // Longer replacements. assert_scoped(str s = "hello", s.replace_all("xx", "xxx"), s == "hello"); @@ -1045,8 +1045,8 @@ void test_non_stl_extensions_for_updates() { assert_scoped(str s = "hello", s.replace_all("h", "hh"), s == "hhello"); assert_scoped(str s = "hello", s.replace_all("o", "oo"), s == "helloo"); assert_scoped(str s = "hello", s.replace_all("llo", "llo!"), s == "hello!"); - assert_scoped(str s = "hello", s.replace_all("x"_cs, "xx"), s == "hello"); - assert_scoped(str s = "hello", s.replace_all("lo"_cs, "lo"), s == "helololo"); + assert_scoped(str s = "hello", s.replace_all("x"_bs, "xx"), s == "hello"); + assert_scoped(str s = "hello", s.replace_all("lo"_bs, "lo"), s == "helololo"); // Directly mapping bytes using a Look-Up Table. 
sz::look_up_table invert_case = sz::look_up_table::identity(); @@ -1286,9 +1286,9 @@ static void test_search() { assert("aabaa"_sv.remove_prefix("a") == "abaa"); assert("aabaa"_sv.remove_suffix("a") == "aaba"); - assert("aabaa"_sv.lstrip("a"_cs) == "baa"); - assert("aabaa"_sv.rstrip("a"_cs) == "aab"); - assert("aabaa"_sv.strip("a"_cs) == "b"); + assert("aabaa"_sv.lstrip("a"_bs) == "baa"); + assert("aabaa"_sv.rstrip("a"_bs) == "aab"); + assert("aabaa"_sv.strip("a"_bs) == "b"); // Check more advanced composite operations assert("abbccc"_sv.partition('b').before.size() == 1); @@ -1320,21 +1320,21 @@ static void test_search() { assert("a.b.c.d"_sv.find_all(".").size() == 3); assert("a.,b.,c.,d"_sv.find_all(".,").size() == 3); assert("a.,b.,c.,d"_sv.rfind_all(".,").size() == 3); - assert("a.b,c.d"_sv.find_all(".,"_cs).size() == 3); + assert("a.b,c.d"_sv.find_all(".,"_bs).size() == 3); assert("a...b...c"_sv.rfind_all("..").size() == 4); assert("a...b...c"_sv.rfind_all("..", sz::include_overlaps_type {}).size() == 4); assert("a...b...c"_sv.rfind_all("..", sz::exclude_overlaps_type {}).size() == 2); - auto finds = "a.b.c"_sv.find_all("abcd"_cs).template to>(); + auto finds = "a.b.c"_sv.find_all("abcd"_bs).template to>(); assert(finds.size() == 3); assert(finds[0] == "a"); - auto rfinds = "a.b.c"_sv.rfind_all("abcd"_cs).template to>(); + auto rfinds = "a.b.c"_sv.rfind_all("abcd"_bs).template to>(); assert(rfinds.size() == 3); assert(rfinds[0] == "c"); { - auto splits = ".a..c."_sv.split("."_cs).template to>(); + auto splits = ".a..c."_sv.split("."_bs).template to>(); assert(splits.size() == 5); assert(splits[0] == ""); assert(splits[1] == "a"); @@ -1369,9 +1369,9 @@ static void test_search() { assert(*advanced("a.b.c.d"_sv.split(".").begin(), 3) == "d"); assert(*advanced("a.b.c.d"_sv.rsplit(".").begin(), 3) == "a"); assert("a.b.,c,d"_sv.split(".,").size() == 2); - assert("a.b,c.d"_sv.split(".,"_cs).size() == 4); + assert("a.b,c.d"_sv.split(".,"_bs).size() == 4); - auto rsplits = ".a..c."_sv.rsplit("."_cs).template to>(); + auto rsplits = ".a..c."_sv.rsplit("."_bs).template to>(); assert(rsplits.size() == 5); assert(rsplits[0] == ""); assert(rsplits[1] == "c"); @@ -1724,9 +1724,9 @@ static void test_sequence_algorithms() { sz_cptr_t strings[] = {"banana", "apple", "cherry"}; sz_sequence_from_null_terminated_strings(strings, 3, &sequence); assert(sequence.count == 3); - assert(sequence.get_start(sequence.handle, 0) == "banana"_sv); - assert(sequence.get_start(sequence.handle, 1) == "apple"_sv); - assert(sequence.get_start(sequence.handle, 2) == "cherry"_sv); + assert("banana"_sv == sequence.get_start(sequence.handle, 0)); + assert("apple"_sv == sequence.get_start(sequence.handle, 1)); + assert("cherry"_sv == sequence.get_start(sequence.handle, 2)); } // Basic tests with predetermined orders. @@ -1813,14 +1813,6 @@ static void test_stl_containers() { int main(int argc, char const **argv) { - sz_u128_vec_t some_state, some_key; - randomize_string((char *)&some_state.u8s[0], 16); - randomize_string((char *)&some_key.u8s[0], 16); - sz_u128_vec_t emulated_result = _sz_emulate_aesenc_si128_serial(some_state, some_key); - sz_u128_vec_t hardware_result; - hardware_result.xmm = _mm_aesenc_si128(some_state.xmm, some_key.xmm); - assert(memcmp(&emulated_result, &hardware_result, sizeof(sz_u128_vec_t)) == 0); - // Let's greet the user nicely sz_unused(argc && argv); std::printf("Hi, dear tester! 
You look nice today!\n"); From 2caefac64aeae2c3b5bee2adebfb76dc8acf38f4 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Fri, 28 Feb 2025 15:50:37 +0000 Subject: [PATCH 135/751] Fix: Compilation of all bindings --- CONTRIBUTING.md | 14 + Package.swift | 12 +- README.md | 60 +-- include/stringzilla/sort.h | 210 ++++---- include/stringzilla/stringzilla.h | 11 +- include/stringzilla/types.h | 14 +- python/lib.c | 183 +++---- rust/lib.rs | 657 ++++++++++++------------- rustfmt.toml | 1 + scripts/bench_memory.cpp | 2 +- scripts/bench_similarity.cpp | 17 +- scripts/bench_sort.cpp | 26 +- scripts/bench_token.cpp | 16 +- scripts/test.py | 30 +- swift/StringProtocol+StringZilla.swift | 109 ++-- swift/Test.swift | 25 +- 16 files changed, 690 insertions(+), 697 deletions(-) create mode 100644 rustfmt.toml diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d6009a30..a8e825af 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -432,6 +432,13 @@ npm ci && npm test swift build && swift test ``` +To format, consider using [SwiftFormat](https://github.com/nicklockwood/SwiftFormat): + +```bash +brew install swiftformat +swiftformat . +``` + Running Swift on Linux requires a couple of extra steps, as the Swift compiler is not available in the default repositories. Please get the most recent Swift tarball from the [official website](https://www.swift.org/install/). At the time of writing, for 64-bit Arm CPU running Ubuntu 22.04, the following commands would work: @@ -467,6 +474,13 @@ sudo docker run --rm -v "$PWD:/workspace" -w /workspace swift:5.9 /bin/bash -cl cargo test ``` +If you need to isolate a failing test: + +```bash +export RUST_BACKTRACE=full +cargo test -- --test-threads=1 --nocapture +``` + If you are updating the package contents, you can validate the list of included files using the following command: ```bash diff --git a/Package.swift b/Package.swift index c5c15fbb..ec3fe103 100644 --- a/Package.swift +++ b/Package.swift @@ -5,16 +5,16 @@ let package = Package( name: "StringZilla", platforms: [ // Linux doesn't have to be explicitly listed - .iOS(.v13), // For iOS, version 13 and later - .tvOS(.v13), // For tvOS, version 13 and later + .iOS(.v13), // For iOS, version 13 and later + .tvOS(.v13), // For tvOS, version 13 and later .macOS(.v10_15), // For macOS, version 10.15 (Catalina) and later - .watchOS(.v6) // For watchOS, version 6 and later + .watchOS(.v6), // For watchOS, version 6 and later ], products: [ .library( name: "StringZilla", targets: ["StringZillaC", "StringZilla"] - ) + ), ], targets: [ .target( @@ -27,7 +27,7 @@ let package = Package( .define("SZ_AVOID_LIBC", to: "0"), // We need `malloc` from LibC .define("SZ_DEBUG", to: "0"), // We don't need any extra assertions in the C layer .headerSearchPath("include/stringzilla"), // Specify header search paths - .unsafeFlags(["-Wall"]) // Use with caution: specify custom compiler flags + .unsafeFlags(["-Wall"]), // Use with caution: specify custom compiler flags ] ), .target( @@ -41,7 +41,7 @@ let package = Package( dependencies: ["StringZilla"], path: "swift", sources: ["Test.swift"] - ) + ), ], cLanguageStandard: CLanguageStandard.c99 ) diff --git a/README.md b/README.md index 22c8e2b0..18aea8e2 100644 --- a/README.md +++ b/README.md @@ -137,7 +137,7 @@ __Who is this for?__ arm: 0.02 GB/s - sz_find_charset
+ sz_find_byteset
x86: 4.08 · arm: 3.22 GB/s @@ -155,7 +155,7 @@ __Who is this for?__ ⚪ - sz_rfind_charset
+ sz_rfind_byteset
x86: 0.43 · arm: 0.23 GB/s @@ -181,7 +181,7 @@ __Who is this for?__ arm: 5.9 MB/s - sz_generate
+ sz_fill_random
x86: 56.2 · arm: 25.8 MB/s @@ -203,7 +203,7 @@ __Who is this for?__ arm: 140.0 MB/s - sz_look_up_transform
+ sz_lookup
x86: 21.2 · arm: 8.5 GB/s @@ -247,7 +247,7 @@ __Who is this for?__ arm: 2,220 ns - sz_edit_distance
+ sz_levenshtein_distance
x86: 99 · arm: 180 ns @@ -265,7 +265,7 @@ __Who is this for?__ arm: 367 ms - sz_alignment_score
+ sz_needleman_wunsch_score
x86: 73 · arm: 177 ms @@ -396,8 +396,8 @@ x: int = text.find_first_of('chars', start=0, end=sys.maxsize) x: int = text.find_last_of('chars', start=0, end=sys.maxsize) x: int = text.find_first_not_of('chars', start=0, end=sys.maxsize) x: int = text.find_last_not_of('chars', start=0, end=sys.maxsize) -x: Strs = text.split_charset(separator='chars', maxsplit=sys.maxsize, keepseparator=False) -x: Strs = text.rsplit_charset(separator='chars', maxsplit=sys.maxsize, keepseparator=False) +x: Strs = text.split_byteset(separator='chars', maxsplit=sys.maxsize, keepseparator=False) +x: Strs = text.rsplit_byteset(separator='chars', maxsplit=sys.maxsize, keepseparator=False) ``` You can also transform the string using Look-Up Tables (LUTs), mapping it to a different character set. @@ -453,8 +453,8 @@ StringZilla saves a lot of memory by viewing existing memory regions as substrin ```py x: SplitIterator[Str] = text.split_iter(separator=' ', keepseparator=False) x: SplitIterator[Str] = text.rsplit_iter(separator=' ', keepseparator=False) -x: SplitIterator[Str] = text.split_charset_iter(separator='chars', keepseparator=False) -x: SplitIterator[Str] = text.rsplit_charset_iter(separator='chars', keepseparator=False) +x: SplitIterator[Str] = text.split_byteset_iter(separator='chars', keepseparator=False) +x: SplitIterator[Str] = text.rsplit_byteset_iter(separator='chars', keepseparator=False) ``` StringZilla can easily be 10x more memory efficient than native Python classes for tokenization. @@ -654,7 +654,7 @@ By design, StringZilla has a couple of notable differences from LibC: That way `sz_find` and `sz_rfind` are similar to `strstr` and `strrstr` in LibC. Similarly, `sz_find_byte` and `sz_rfind_byte` replace `memchr` and `memrchr`. -The `sz_find_charset` maps to `strspn` and `strcspn`, while `sz_rfind_charset` has no sibling in LibC. +The `sz_find_byteset` maps to `strspn` and `strcspn`, while `sz_rfind_byteset` has no sibling in LibC. @@ -679,11 +679,11 @@ The `sz_find_charset` maps to `strspn` and `strcspn`, while `sz_rfind_charset` h - + - + @@ -923,7 +923,7 @@ StringZilla provides a convenient `partition` function, which returns a tuple of ```cpp auto parts = haystack.partition(':'); // Matching a character auto [before, match, after] = haystack.partition(':'); // Structure unpacking -auto [before, match, after] = haystack.partition(sz::char_set(":;")); // Character-set argument +auto [before, match, after] = haystack.partition(sz::byteset(":;")); // Character-set argument auto [before, match, after] = haystack.partition(" : "); // String argument auto [before, match, after] = haystack.rpartition(sz::whitespaces_set()); // Split around the last whitespace ``` @@ -951,8 +951,8 @@ Here is a sneak peek of the most useful ones. 
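For illustration only (an editorial aside, not part of the patch itself): a minimal C sketch of how the renamed byte-set search can stand in for `strcspn` on a buffer that is not NUL-terminated. The wrapper name `first_delimiter_offset` and the include path are assumptions; `sz_byteset_init`, `sz_byteset_add`, and `sz_find_byteset` are the renamed APIs this change introduces.

```c
#include <stringzilla/stringzilla.h> // assumed include path

// Returns the offset of the first ',' or ';' in `text`, or `length` if none is found,
// mirroring `strcspn` semantics without requiring a NUL terminator.
sz_size_t first_delimiter_offset(sz_cptr_t text, sz_size_t length) {
    sz_byteset_t delimiters;
    sz_byteset_init(&delimiters);     // start with an empty set
    sz_byteset_add(&delimiters, ','); // add the bytes we are looking for
    sz_byteset_add(&delimiters, ';');
    sz_cptr_t match = sz_find_byteset(text, length, &delimiters);
    return match ? (sz_size_t)(match - text) : length;
}
```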
```cpp text.hash(); // -> 64 bit unsigned integer text.ssize(); // -> 64 bit signed length to avoid `static_cast(text.size())` -text.contains_only(" \w\t"); // == text.find_first_not_of(sz::char_set(" \w\t")) == npos; -text.contains(sz::whitespaces_set()); // == text.find(sz::char_set(sz::whitespaces_set())) != npos; +text.contains_only(" \w\t"); // == text.find_first_not_of(sz::byteset(" \w\t")) == npos; +text.contains(sz::whitespaces_set()); // == text.find(sz::byteset(sz::whitespaces_set())) != npos; // Simpler slicing than `substr` text.front(10); // -> sz::string_view @@ -997,7 +997,7 @@ To avoid those, StringZilla provides lazily-evaluated ranges, compatible with th ```cpp for (auto line : haystack.split("\r\n")) - for (auto word : line.split(sz::char_set(" \w\t.,;:!?"))) + for (auto word : line.split(sz::byteset(" \w\t.,;:!?"))) std::cout << word << std::endl; ``` @@ -1006,9 +1006,9 @@ It also allows interleaving matches, if you want both inclusions of `xx` in `xxx Debugging pointer offsets is not a pleasant exercise, so keep the following functions in mind. - `haystack.[r]find_all(needle, interleaving)` -- `haystack.[r]find_all(sz::char_set(""))` +- `haystack.[r]find_all(sz::byteset(""))` - `haystack.[r]split(needle)` -- `haystack.[r]split(sz::char_set(""))` +- `haystack.[r]split(sz::byteset(""))` For $N$ matches the split functions will report $N+1$ matches, potentially including empty strings. Ranges have a few convenience methods as well: @@ -1065,7 +1065,7 @@ sz::string random_string(std::size_t length, char const *alphabet, std::size_t c ``` Mouthful and slow. -StringZilla provides a C native method - `sz_generate` and a convenient C++ wrapper - `sz::generate`. +StringZilla provides a C native method - `sz_fill_random` and a convenient C++ wrapper - `sz::generate`. Similar to Python it also defines the commonly used character sets. ```cpp @@ -1085,9 +1085,9 @@ In text processing, it's often necessary to replace all occurrences of a specifi Standard library functions may not offer the most efficient or convenient methods for performing bulk replacements, especially when dealing with large strings or performance-critical applications. - `haystack.replace_all(needle_string, replacement_string)` -- `haystack.replace_all(sz::char_set(""), replacement_string)` +- `haystack.replace_all(sz::byteset(""), replacement_string)` - `haystack.try_replace_all(needle_string, replacement_string)` -- `haystack.try_replace_all(sz::char_set(""), replacement_string)` +- `haystack.try_replace_all(sz::byteset(""), replacement_string)` - `haystack.transform(sz::look_up_table::identity())` - `haystack.transform(sz::look_up_table::identity(), haystack.data())` @@ -1250,8 +1250,8 @@ sz::find("Hello, world!", "world") // 7 sz::rfind("Hello, world!", "world") // 7 // Generalizations of `memchr::memrchr[123]` -sz::find_char_from("Hello, world!", "world") // 2 -sz::rfind_char_from("Hello, world!", "world") // 11 +sz::find_byte_from("Hello, world!", "world") // 2 +sz::rfind_byte_from("Hello, world!", "world") // 11 ``` Unlike `memchr`, the throughput of `stringzilla` is [high in both normal and reverse-order searches][memchr-benchmarks]. 
@@ -1268,10 +1268,10 @@ let my_cow_str = Cow::from(&my_string);
 // Use the generic function with a String
 assert_eq!(my_string.sz_find("world"), Some(7));
 assert_eq!(my_string.sz_rfind("world"), Some(7));
-assert_eq!(my_string.sz_find_char_from("world"), Some(2));
-assert_eq!(my_string.sz_rfind_char_from("world"), Some(11));
-assert_eq!(my_string.sz_find_char_not_from("world"), Some(0));
-assert_eq!(my_string.sz_rfind_char_not_from("world"), Some(12));
+assert_eq!(my_string.sz_find_byte_from("world"), Some(2));
+assert_eq!(my_string.sz_rfind_byte_from("world"), Some(11));
+assert_eq!(my_string.sz_find_byte_not_from("world"), Some(0));
+assert_eq!(my_string.sz_rfind_byte_not_from("world"), Some(12));
 
 // Same works for &str and Cow<'_, str>
 assert_eq!(my_str.sz_find("world"), Some(7));
@@ -1315,7 +1315,7 @@ s[s.findLast(substring: "o")!...] // "o StringZilla. 👋")
 s[s.findFirst(characterFrom: "aeiou")!...] // "ello, world! Welcome to StringZilla. 👋")
 s[s.findLast(characterFrom: "aeiou")!...] // "a. 👋")
 s[s.findFirst(characterNotFrom: "aeiou")!...] // "Hello, world! Welcome to StringZilla. 👋"
-s.editDistance(from: "Hello, world!")! // 29
+s.levenshteinDistance(from: "Hello, world!")! // 29
 ```
 
 ## Algorithms & Design Decisions 📚
@@ -1561,7 +1561,7 @@ Most StringZilla operations are byte-level, so they work well with ASCII and UTF
 In some cases, like edit-distance computation, the result of byte-level evaluation and character-level evaluation may differ.
 So StringZilla provides following functions to work with Unicode:
 
-- `sz_edit_distance_utf8` - computes the Levenshtein distance between two UTF-8 strings.
+- `sz_levenshtein_distance_utf8` - computes the Levenshtein distance between two UTF-8 strings.
 - `sz_hamming_distance_utf8` - computes the Hamming distance between two UTF-8 strings.
 
 Java, JavaScript, Python 2, C#, and Objective-C, however, use wide characters (`wchar`) - two byte long codes, instead of the more reasonable fixed-length UTF32 or variable-length UTF8.
diff --git a/include/stringzilla/sort.h b/include/stringzilla/sort.h
index 8e387b70..721ba940 100644
--- a/include/stringzilla/sort.h
+++ b/include/stringzilla/sort.h
@@ -3,11 +3,14 @@
  * @file sort.h
  * @author Ash Vardanian
  *
- * Includes core APIs for `sz_sequence_t` string collections:
+ * Includes core APIs for `sz_sequence_t` string collections with hardware-specific backends:
  *
  * - `sz_sequence_argsort` - to get the sorting permutation of a string collection.
  * - `sz_sequence_join` - to compute the intersection of two arbitrary string collections.
  *
+ * The first can easily be used to implement SORT and GROUPBY operations in SQL, while the second can be used to
+ * implement JOIN operations. Both are essential for implementing efficient database engines.
+ *
  * The core idea of all following string algorithms is to process strings not based on 1 character at a time,
  * but on a larger "Pointer-sized N-grams" fitting in 4 or 8 bytes at once, on 32-bit or 64-bit architectures,
  * respectively. In reality we may not use the full pointer size, but only a few bytes from it, and keep the
@@ -21,7 +24,7 @@
  *
  * Other helpers include:
  *
- * - `sz_pgrams_sort_stable_with_insertion` - for quadratic-complexity sorting of small continuous integer arrays.
+ * - `sz_pgrams_sort_with_insertion` - for quadratic-complexity sorting of small continuous integer arrays.
  * - `sz_sequence_argsort_with_insertion` - for quadratic-complexity sorting of small string collections.
* - `sz_sequence_argsort_stabilize` - updates the sorting permutation to be stable. */ @@ -45,10 +48,11 @@ extern "C" { * * @param[in] sequence Immutable sequence of strings to sort. * @param[in] alloc Optional memory allocator for temporary storage. - * @param[out] order Output permutation that sorts the elements. Must fit at least `sequence->count` integers. + * @param[out] order Output permutation that sorts the elements. * * @retval `sz_success_k` if the operation was successful. * @retval `sz_bad_alloc_k` if the operation failed due to memory allocation failure. + * @pre The @p order array must fit at least `sequence->count` integers. * @post The @p order array will contain a valid permutation of `[0, sequence->count - 1]`. * * Example usage: @@ -60,8 +64,8 @@ extern "C" { * sz_sequence_t sequence; * sz_sequence_from_null_terminated_strings(strings, 3, &sequence); * sz_sorted_idx_t order[3]; - * sz_sequence_argsort(&sequence, NULL, order); - * return order[0] == 1 && order[1] == 0 && order[2] == 2 ? 0 : 1; + * sz_status_t status = sz_sequence_argsort(&sequence, NULL, order); + * return status == sz_success_k && order[0] == 1 && order[1] == 0 && order[2] == 2 ? 0 : 1; * } * @endcode * @@ -69,7 +73,7 @@ extern "C" { * @see https://en.wikipedia.org/wiki/Quicksort * * @note This algorithm is @b unstable: equal elements may change relative order. - * @sa sz_sequence_argsort_stable + * @sa sz_sequence_argsort_stabilize * * @note Selects the fastest implementation at compile- or run-time based on `SZ_DYNAMIC_DISPATCH`. * @sa sz_sequence_argsort_serial, sz_sequence_argsort_skylake, sz_sequence_argsort_sve @@ -84,10 +88,11 @@ SZ_DYNAMIC sz_status_t sz_sequence_argsort(sz_sequence_t const *sequence, sz_mem * @param[inout] pgrams Continuous buffer of unsigned integers to sort in place. * @param[in] count Number of elements in the sequence. * @param[in] alloc Optional memory allocator for temporary storage. - * @param[out] order Output permutation that sorts the elements. Must fit at least @p count integers. + * @param[out] order Output permutation that sorts the elements. * * @retval `sz_success_k` if the operation was successful. * @retval `sz_bad_alloc_k` if the operation failed due to memory allocation failure. + * @pre The @p order array must fit at least `count` integers. * @post The @p order array will contain a valid permutation of `[0, count - 1]`. * * Example usage: @@ -97,17 +102,14 @@ SZ_DYNAMIC sz_status_t sz_sequence_argsort(sz_sequence_t const *sequence, sz_mem * int main() { * sz_pgram_t pgrams[] = {42, 17, 99, 8}; * sz_sorted_idx_t order[4]; - * sz_pgrams_sort(pgrams, 4, NULL, order); - * return order[0] == 3 && order[1] == 1 && order[2] == 0 && order[3] == 2 ? 0 : 1; + * sz_status_t status = sz_pgrams_sort(pgrams, 4, NULL, order); + * return status == sz_success_k && order[0] == 3 && order[1] == 1 && order[2] == 0 && order[3] == 2 ? 0 : 1; * } * @endcode * * @note The algorithm has linear memory complexity, quadratic worst-case and log-linear average time complexity. * @see https://en.wikipedia.org/wiki/Quicksort * - * @note This algorithm is @b unstable: equal elements may change relative order. - * @sa sz_pgrams_sort_stable - * * @note Selects the fastest implementation at compile- or run-time based on `SZ_DYNAMIC_DISPATCH`. 
 * @sa sz_pgrams_sort_serial, sz_pgrams_sort_skylake, sz_pgrams_sort_sve
 */
SZ_DYNAMIC sz_status_t sz_pgrams_sort(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc,
                                      sz_sorted_idx_t *order);

 /**
- * @brief Faster @b arg-sort for an arbitrary @b string sequence, using MergeSort.
- *        Outputs the @p order of elements in the immutable @p sequence, that would sort it.
+ * @brief Intersects two arbitrary @b string sequences, using a hash table.
+ *        Outputs the @p first_positions from the @p first_sequence and @p second_positions from
+ *        the @p second_sequence that contain identical strings.
  *
- * This algorithm guarantees stability, ensuring that the relative order of equal elements is preserved.
- * It uses more memory than `sz_sequence_argsort`, but its performance is more predictable.
- * It's preferred for very large inputs, as most memory access happens in a sequential pattern.
- *
- * @param[in] sequence Immutable sequence of strings to sort.
+ * @param[in] first_sequence First immutable sequence of strings to intersect.
+ * @param[in] second_sequence Second immutable sequence of strings to intersect.
  * @param[in] alloc Optional memory allocator for temporary storage.
- * @param[out] order Output permutation that sorts the elements. Must fit at least `sequence->count` integers.
+ * @param[out] intersection_size Number of identical strings in both sequences.
+ * @param[out] first_positions Offset positions of the identical strings from the @p first_sequence.
+ * @param[out] second_positions Offset positions of the identical strings from the @p second_sequence.
  *
  * @retval `sz_success_k` if the operation was successful.
  * @retval `sz_bad_alloc_k` if the operation failed due to memory allocation failure.
- * @post The @p order array will contain a valid permutation of `[0, sequence->count - 1]`.
+ * @retval `sz_contains_duplicates_k` if any of the sequences contain duplicate strings.
+ * @pre The @p first_positions array must fit at least `min(first_sequence->count, second_sequence->count)` items.
+ * @pre The @p second_positions array must fit at least `min(first_sequence->count, second_sequence->count)` items.
  *
  * Example usage:
  *
  * @code{.c}
  * #include
  * int main() {
- *     char const *strings[] = {"banana", "apple", "cherry"};
- *     sz_sequence_t sequence;
- *     sz_sequence_from_null_terminated_strings(strings, 3, &sequence);
- *     sz_sorted_idx_t order[3];
- *     sz_sequence_argsort_stable(&sequence, NULL, order);
- *     return order[0] == 1 && order[1] == 0 && order[2] == 2 ? 0 : 1;
+ *     char const *first[] = {"banana", "apple", "cherry"};
+ *     char const *second[] = {"cherry", "orange", "pineapple", "banana"};
+ *     sz_sequence_t first_sequence, second_sequence;
+ *     sz_sequence_from_null_terminated_strings(first, 3, &first_sequence);
+ *     sz_sequence_from_null_terminated_strings(second, 4, &second_sequence);
+ *     sz_size_t intersection_size;
+ *     sz_sorted_idx_t first_positions[3], second_positions[3]; //? 3 is the size of the smaller sequence
+ *     sz_status_t status = sz_sequence_join(&first_sequence, &second_sequence, NULL,
+ *                                           &intersection_size, first_positions, second_positions);
+ *     return status == sz_success_k && intersection_size == 2 ? 0 : 1;
  * }
  * @endcode
  *
- * @note The algorithm has linear memory complexity and log-linear time complexity.
- * @see https://en.wikipedia.org/wiki/Merge_sort
- *
- * @note This algorithm is @b stable: equal elements maintain their relative order.
- * @sa sz_sequence_argsort + * @note The algorithm has linear memory complexity and linear time complexity. + * @see https://en.wikipedia.org/wiki/Join_(SQL) * * @note Selects the fastest implementation at compile- or run-time based on `SZ_DYNAMIC_DISPATCH`. - * @sa sz_sequence_argsort_stable_serial, sz_sequence_argsort_stable_skylake, sz_sequence_argsort_stable_sve + * @sa sz_sequence_join_serial, sz_sequence_join_skylake, sz_sequence_join_sve */ -SZ_DYNAMIC sz_status_t sz_sequence_argsort_stable(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order); +SZ_DYNAMIC sz_status_t sz_sequence_join(sz_sequence_t const *first_sequence, sz_sequence_t const *second_sequence, + sz_memory_allocator_t *alloc, sz_size_t *intersection_size, + sz_sorted_idx_t *first_positions, sz_sorted_idx_t *second_positions); /** * @brief Faster @b inplace `std::stable_sort` for a continuous @b unsigned-integer sequence, using MergeSort. @@ -180,7 +187,7 @@ SZ_DYNAMIC sz_status_t sz_sequence_argsort_stable(sz_sequence_t const *sequence, * int main() { * sz_pgram_t pgrams[] = {42, 17, 99, 8}; * sz_sorted_idx_t order[4]; - * sz_pgrams_sort_stable(pgrams, 4, NULL, order); + * sz_pgrams_join(pgrams, 4, NULL, order); * return order[0] == 3 && order[1] == 1 && order[2] == 0 && order[3] == 2 ? 0 : 1; * } * @endcode @@ -192,10 +199,10 @@ SZ_DYNAMIC sz_status_t sz_sequence_argsort_stable(sz_sequence_t const *sequence, * @sa sz_pgrams_sort * * @note Selects the fastest implementation at compile- or run-time based on `SZ_DYNAMIC_DISPATCH`. - * @sa sz_pgrams_sort_stable_serial, sz_pgrams_sort_stable_skylake, sz_pgrams_sort_stable_sve + * @sa sz_pgrams_join_serial, sz_pgrams_join_skylake, sz_pgrams_join_sve */ -SZ_DYNAMIC sz_status_t sz_pgrams_sort_stable(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order); +SZ_DYNAMIC sz_status_t sz_pgrams_join(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); /** @copydoc sz_sequence_argsort */ SZ_PUBLIC sz_status_t sz_sequence_argsort_serial(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, @@ -205,6 +212,8 @@ SZ_PUBLIC sz_status_t sz_sequence_argsort_serial(sz_sequence_t const *sequence, SZ_PUBLIC sz_status_t sz_pgrams_sort_serial(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, sz_sorted_idx_t *order); +#if SZ_USE_SKYLAKE + /** @copydoc sz_sequence_argsort */ SZ_PUBLIC sz_status_t sz_sequence_argsort_skylake(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, sz_sorted_idx_t *order); @@ -213,6 +222,16 @@ SZ_PUBLIC sz_status_t sz_sequence_argsort_skylake(sz_sequence_t const *sequence, SZ_PUBLIC sz_status_t sz_pgrams_sort_skylake(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, sz_sorted_idx_t *order); +/** @copydoc sz_sequence_join */ +SZ_PUBLIC sz_status_t sz_sequence_join_skylake( // + sz_sequence_t const *first_sequence, sz_sequence_t const *second_sequence, // + sz_memory_allocator_t *alloc, sz_size_t *intersection_size, // + sz_sorted_idx_t *first_positions, sz_sorted_idx_t *second_positions); + +#endif + +#if SZ_USE_SVE + /** @copydoc sz_sequence_argsort */ SZ_PUBLIC sz_status_t sz_sequence_argsort_sve(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, sz_sorted_idx_t *order); @@ -221,36 +240,20 @@ SZ_PUBLIC sz_status_t sz_sequence_argsort_sve(sz_sequence_t const *sequence, sz_ SZ_PUBLIC sz_status_t sz_pgrams_sort_sve(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, sz_sorted_idx_t 
*order); -/** @copydoc sz_sequence_argsort_stable */ -SZ_PUBLIC sz_status_t sz_sequence_argsort_stable_serial(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order); - -/** @copydoc sz_pgrams_sort_stable */ -SZ_PUBLIC sz_status_t sz_pgrams_sort_stable_serial(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order); +/** @copydoc sz_sequence_join */ +SZ_PUBLIC sz_status_t sz_sequence_join_sve( // + sz_sequence_t const *first_sequence, sz_sequence_t const *second_sequence, // + sz_memory_allocator_t *alloc, sz_size_t *intersection_size, // + sz_sorted_idx_t *first_positions, sz_sorted_idx_t *second_positions); -/** @copydoc sz_sequence_argsort_stable */ -SZ_PUBLIC sz_status_t sz_sequence_argsort_stable_skylake(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order); - -/** @copydoc sz_pgrams_sort_stable */ -SZ_PUBLIC sz_status_t sz_pgrams_sort_stable_skylake(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order); - -/** @copydoc sz_sequence_argsort_stable */ -SZ_PUBLIC sz_status_t sz_sequence_argsort_stable_sve(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order); - -/** @copydoc sz_pgrams_sort_stable */ -SZ_PUBLIC sz_status_t sz_pgrams_sort_stable_sve(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order); +#endif #pragma endregion #pragma region Generic Public Helpers /** - * @brief Quadratic complexity insertion sort adjust for our @b argsort usecase. + * @brief Quadratic complexity @b stable insertion sort adjust for our @b argsort usecase. * Needs no extra memory and is used as a fallback for small inputs. */ SZ_PUBLIC void sz_sequence_argsort_with_insertion(sz_sequence_t const *sequence, sz_sorted_idx_t *order) { @@ -281,11 +284,11 @@ SZ_PUBLIC void sz_sequence_argsort_with_insertion(sz_sequence_t const *sequence, } /** - * @brief Quadratic complexity insertion sort adjust for our @b pgram-sorting usecase. + * @brief Quadratic complexity @b stable insertion sort adjust for our @b pgram-sorting usecase. * Needs no extra memory and is used as a fallback for small inputs. */ -SZ_PUBLIC void sz_pgrams_sort_stable_with_insertion(sz_pgram_t *pgrams, sz_size_t count, sz_sorted_idx_t *order) { +SZ_PUBLIC void sz_pgrams_sort_with_insertion(sz_pgram_t *pgrams, sz_size_t count, sz_sorted_idx_t *order) { // Assume `order` is already initialized with 0, 1, 2, ... N. for (sz_size_t i = 1; i < count; ++i) { @@ -714,7 +717,7 @@ SZ_PUBLIC sz_status_t sz_pgrams_sort_serial(sz_pgram_t *pgrams, sz_size_t count, * @brief Helper function similar to `std::set_union` over pairs of integers and their original indices. 
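
A simplified scalar sketch of such a merge over two already-sorted runs of (pgram, original index) pairs; preferring the first run on ties is what keeps the surrounding merge sort stable. The function name and shape below are illustrative, not the library's internal implementation:

    static void merge_pgram_runs_sketch(                                                                //
        sz_pgram_t const *first_pgrams, sz_sorted_idx_t const *first_indices, sz_size_t first_count,    //
        sz_pgram_t const *second_pgrams, sz_sorted_idx_t const *second_indices, sz_size_t second_count, //
        sz_pgram_t *result_pgrams, sz_sorted_idx_t *result_indices) {
        sz_size_t i = 0, j = 0, k = 0;
        while (i < first_count && j < second_count) {
            // `<=` prefers the earlier run on ties, preserving the relative order of equal keys.
            if (first_pgrams[i] <= second_pgrams[j]) {
                result_pgrams[k] = first_pgrams[i], result_indices[k] = first_indices[i], ++i;
            }
            else { result_pgrams[k] = second_pgrams[j], result_indices[k] = second_indices[j], ++j; }
            ++k;
        }
        for (; i < first_count; ++i, ++k) result_pgrams[k] = first_pgrams[i], result_indices[k] = first_indices[i];
        for (; j < second_count; ++j, ++k) result_pgrams[k] = second_pgrams[j], result_indices[k] = second_indices[j];
    }
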
* @see https://en.cppreference.com/w/cpp/algorithm/set_union */ -SZ_INTERNAL void _sz_sequence_argsort_stable_serial_merge( // +SZ_INTERNAL void _sz_sequence_join_serial_merge( // sz_pgram_t const *first_pgrams, sz_sorted_idx_t const *first_indices, sz_size_t first_count, // sz_pgram_t const *second_pgrams, sz_sorted_idx_t const *second_indices, sz_size_t second_count, // sz_pgram_t *result_pgrams, sz_sorted_idx_t *result_indices) { @@ -761,8 +764,8 @@ SZ_INTERNAL void _sz_sequence_argsort_stable_serial_merge( _sz_assert(merged_begin[i - 1] <= merged_begin[i] && "The merged pgrams must be in ascending order."); } -SZ_PUBLIC sz_status_t sz_pgrams_sort_stable_serial(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order) { +SZ_PUBLIC sz_status_t sz_pgrams_join_serial(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order) { // First, initialize the `order` with `std::iota`-like behavior. for (sz_size_t i = 0; i != count; ++i) order[i] = i; @@ -770,7 +773,7 @@ SZ_PUBLIC sz_status_t sz_pgrams_sort_stable_serial(sz_pgram_t *pgrams, sz_size_t // On very small collections - just use the quadratic-complexity insertion sort // without any smart optimizations or memory allocations. if (count <= 32) { - sz_pgrams_sort_stable_with_insertion(pgrams, count, order); + sz_pgrams_sort_with_insertion(pgrams, count, order); return sz_success_k; } @@ -779,7 +782,7 @@ SZ_PUBLIC sz_status_t sz_pgrams_sort_stable_serial(sz_pgram_t *pgrams, sz_size_t // For the tail of the array, sort it with insertion sort. sz_size_t const tail_count = count & 7u; - sz_pgrams_sort_stable_with_insertion(pgrams + count - tail_count, tail_count, order + count - tail_count); + sz_pgrams_sort_with_insertion(pgrams + count - tail_count, tail_count, order + count - tail_count); // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. sz_memory_allocator_t global_alloc; @@ -821,7 +824,7 @@ SZ_PUBLIC sz_status_t sz_pgrams_sort_stable_serial(sz_pgram_t *pgrams, sz_size_t else if (i + left_count + right_count > count) { right_count = count - (i + left_count); } // Merge the two runs: - _sz_sequence_argsort_stable_serial_merge( // + _sz_sequence_join_serial_merge( // src_pgrams + i, src_order + i, left_count, // src_pgrams + i + run_size, src_order + i + run_size, right_count, // dst_pgrams + i, dst_order + i); @@ -844,9 +847,11 @@ SZ_PUBLIC sz_status_t sz_pgrams_sort_stable_serial(sz_pgram_t *pgrams, sz_size_t return sz_success_k; } -SZ_PUBLIC sz_status_t sz_sequence_argsort_stable_serial(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order) { - sz_unused(sequence && alloc && order); +SZ_PUBLIC sz_status_t sz_sequence_join_serial( // + sz_sequence_t const *first_sequence, sz_sequence_t const *second_sequence, // + sz_memory_allocator_t *alloc, sz_size_t *intersection_size, // + sz_sorted_idx_t *first_positions, sz_sorted_idx_t *second_positions) { + sz_unused(first_sequence && second_sequence && alloc && intersection_size && first_positions && second_positions); return sz_success_k; } @@ -967,7 +972,7 @@ SZ_INTERNAL void _sz_sequence_argsort_skylake_3way_partition( * @brief Recursive Quick-Sort implementation backing both the `sz_sequence_argsort_skylake` and * `sz_pgrams_sort_skylake`, and using the `_sz_sequence_argsort_skylake_3way_partition` under the hood. 
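
The `_sz_sequence_argsort_skylake_3way_partition` routine referenced here follows the classic three-way ("Dutch national flag") partitioning idea: split a range into less-than, equal-to, and greater-than-pivot groups, moving the `order` entries in lockstep with the pgrams. A scalar sketch of that idea, illustrative only and ignoring the AVX-512 specifics:

    static void three_way_partition_sketch(sz_pgram_t *pgrams, sz_sorted_idx_t *order, sz_size_t count,
                                           sz_pgram_t pivot) {
        sz_size_t less = 0, equal = 0, greater = count;
        while (equal < greater) {
            if (pgrams[equal] < pivot) {
                sz_pgram_t swapped_pgram = pgrams[less];
                pgrams[less] = pgrams[equal], pgrams[equal] = swapped_pgram;
                sz_sorted_idx_t swapped_index = order[less];
                order[less] = order[equal], order[equal] = swapped_index;
                ++less, ++equal;
            }
            else if (pgrams[equal] > pivot) {
                --greater;
                sz_pgram_t swapped_pgram = pgrams[greater];
                pgrams[greater] = pgrams[equal], pgrams[equal] = swapped_pgram;
                sz_sorted_idx_t swapped_index = order[greater];
                order[greater] = order[equal], order[equal] = swapped_index;
            }
            else { ++equal; }
        }
        // On exit: [0, less) < pivot, [less, equal) == pivot, [equal, count) > pivot.
    }
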
*/ -SZ_INTERNAL void _sz_sequence_argsort_skylake_recursively( // +SZ_PUBLIC void _sz_sequence_argsort_skylake_recursively( // sz_pgram_t *initial_pgrams, sz_sorted_idx_t *initial_order, // sz_pgram_t *temporary_pgrams, sz_sorted_idx_t *temporary_order, // sz_size_t const start_in_sequence, sz_size_t const end_in_sequence) { @@ -977,7 +982,7 @@ SZ_INTERNAL void _sz_sequence_argsort_skylake_recursively( // sz_size_t const count = end_in_sequence - start_in_sequence; sz_size_t const pgrams_per_register = sizeof(sz_u512_vec_t) / sizeof(sz_pgram_t); if (count <= pgrams_per_register) { - sz_pgrams_sort_stable_with_insertion( // + sz_pgrams_sort_with_insertion( // initial_pgrams + start_in_sequence, count, initial_order + start_in_sequence); return; } @@ -1040,12 +1045,12 @@ SZ_PUBLIC void _sz_sequence_argsort_skylake_next_pgrams( sz_size_t const start_character) { // Prepare the new range of pgrams - _sz_sequence_argsort_serial_export_next_pgrams(sequence, global_pgrams, global_order, start_in_sequence, - end_in_sequence, start_character); + _sz_sequence_argsort_serial_export_next_pgrams( // + sequence, global_pgrams, global_order, start_in_sequence, end_in_sequence, start_character); // Sort current pgrams with a quicksort - _sz_sequence_argsort_skylake_recursively(global_pgrams, global_order, temporary_pgrams, temporary_order, - start_in_sequence, end_in_sequence); + _sz_sequence_argsort_skylake_recursively( // + global_pgrams, global_order, temporary_pgrams, temporary_order, start_in_sequence, end_in_sequence); // Depending on the architecture, we will export a different number of bytes. // On 32-bit architectures, we will export 3 bytes, and on 64-bit architectures - 7 bytes. @@ -1064,11 +1069,11 @@ SZ_PUBLIC void _sz_sequence_argsort_skylake_next_pgrams( sz_size_t current_pgram_length = (sz_size_t)current_pgram_str[0]; //! 
The byte order was swapped int has_multiple_strings = nested_end - nested_start > 1; int has_more_characters_in_each = current_pgram_length == pgram_capacity; - if (has_multiple_strings && has_more_characters_in_each) { - _sz_sequence_argsort_skylake_next_pgrams(sequence, global_pgrams, global_order, temporary_pgrams, - temporary_order, nested_start, nested_end, - start_character + pgram_capacity); - } + if (has_multiple_strings && has_more_characters_in_each) + _sz_sequence_argsort_skylake_next_pgrams( // + sequence, global_pgrams, global_order, temporary_pgrams, temporary_order, nested_start, nested_end, + start_character + pgram_capacity); + // Move to the next nested_start = nested_end; } @@ -1111,6 +1116,14 @@ SZ_PUBLIC sz_status_t sz_sequence_argsort_skylake(sz_sequence_t const *sequence, return sz_success_k; } +SZ_PUBLIC sz_status_t sz_sequence_join_skylake( // + sz_sequence_t const *first_sequence, sz_sequence_t const *second_sequence, // + sz_memory_allocator_t *alloc, sz_size_t *intersection_size, // + sz_sorted_idx_t *first_positions, sz_sorted_idx_t *second_positions) { + sz_unused(first_sequence && second_sequence && alloc && intersection_size && first_positions && second_positions); + return sz_success_k; +} + #pragma clang attribute pop #pragma GCC pop_options #endif // SZ_USE_SKYLAKE @@ -1144,25 +1157,24 @@ SZ_DYNAMIC sz_status_t sz_pgrams_sort(sz_pgram_t *pgrams, sz_size_t count, sz_me #endif } -SZ_DYNAMIC sz_status_t sz_sequence_argsort_stable(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order) { +SZ_DYNAMIC sz_status_t sz_sequence_join(sz_sequence_t const *first_sequence, sz_sequence_t const *second_sequence, + sz_memory_allocator_t *alloc, sz_size_t *intersection_size, + sz_sorted_idx_t *first_positions, sz_sorted_idx_t *second_positions) { #if SZ_USE_SKYLAKE - return sz_sequence_argsort_skylake(sequence, alloc, order); + return sz_sequence_join_skylake( // + first_sequence, second_sequence, // + alloc, intersection_size, // + first_positions, second_positions); #elif SZ_USE_SVE - return sz_sequence_argsort_sve(sequence, alloc, order); + return sz_sequence_join_sve( // + first_sequence, second_sequence, // + alloc, intersection_size, // + first_positions, second_positions); #else - return sz_sequence_argsort_serial(sequence, alloc, order); -#endif -} - -SZ_DYNAMIC sz_status_t sz_pgrams_sort_stable(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order) { -#if SZ_USE_SKYLAKE - return sz_pgrams_sort_skylake(pgrams, count, alloc, order); -#elif SZ_USE_SVE - return sz_pgrams_sort_sve(pgrams, count, alloc, order); -#else - return sz_pgrams_sort_serial(pgrams, count, alloc, order); + return sz_sequence_join_serial( // + first_sequence, second_sequence, // + alloc, intersection_size, // + first_positions, second_positions); #endif } diff --git a/include/stringzilla/stringzilla.h b/include/stringzilla/stringzilla.h index 7642f5ae..c497d4f1 100644 --- a/include/stringzilla/stringzilla.h +++ b/include/stringzilla/stringzilla.h @@ -64,12 +64,13 @@ typedef enum { sz_cap_haswell_k = 1 << 10, ///< x86 AVX2 capability with FMA and F16C extensions sz_cap_skylake_k = 1 << 11, ///< x86 AVX512 baseline capability - sz_cap_ice_k = 1 << 12, ///< x86 AVX512 capability with advanced integer algos + sz_cap_ice_k = 1 << 12, ///< x86 AVX512 capability with advanced integer algos and AES extensions - sz_cap_neon_k = 1 << 20, ///< ARM NEON baseline capability - sz_cap_sve_k = 1 << 21, ///< ARM SVE baseline capability - 
sz_cap_sve2_k = 1 << 22, ///< ARM SVE2 capability - sz_cap_sve2p1_k = 1 << 23, ///< ARM SVE2p1 capability + sz_cap_neon_k = 1 << 20, ///< ARM NEON baseline capability + sz_cap_neon_aes_k = 1 << 21, ///< ARM NEON baseline capability with AES extensions + sz_cap_sve_k = 1 << 24, ///< ARM SVE baseline capability + sz_cap_sve2_k = 1 << 25, ///< ARM SVE2 capability + sz_cap_sve2_aes_k = 1 << 26, ///< ARM SVE2 capability with AES extensions } sz_capability_t; diff --git a/include/stringzilla/types.h b/include/stringzilla/types.h index 39a6352b..c4f71907 100644 --- a/include/stringzilla/types.h +++ b/include/stringzilla/types.h @@ -82,12 +82,8 @@ */ #if defined(__LP64__) || defined(_LP64) || defined(__x86_64__) || defined(_WIN64) #define _SZ_IS_64_BIT (1) -#define SZ_SIZE_MAX (0xFFFFFFFFFFFFFFFFull) // Largest unsigned integer that fits into 64 bits. -#define SZ_SSIZE_MAX (0x7FFFFFFFFFFFFFFFull) // Largest signed integer that fits into 64 bits. #else #define _SZ_IS_64_BIT (0) -#define SZ_SIZE_MAX (0xFFFFFFFFu) // Largest unsigned integer that fits into 32 bits. -#define SZ_SSIZE_MAX (0x7FFFFFFFu) // Largest signed integer that fits into 32 bits. #endif /** @@ -302,10 +298,9 @@ typedef unsigned long long sz_u64_t; // Always 64 bits typedef unsigned long long sz_size_t; // 64-bit. typedef long long sz_ssize_t; // 64-bit. #else -typedef unsigned sz_size_t; // 32-bit. -typedef unsigned sz_ssize_t; // 32-bit. +typedef unsigned int sz_size_t; // 32-bit. +typedef int sz_ssize_t; // 32-bit. #endif // _SZ_IS_64_BIT - #endif // SZ_AVOID_LIBC /** @@ -774,6 +769,11 @@ SZ_PUBLIC void sz_sequence_from_null_terminated_strings(sz_cptr_t *start, sz_siz * like equality checks and relative order computing. */ #define SZ_CACHE_LINE_WIDTH (64) // bytes +#define SZ_SIZE_MAX ((sz_size_t)(-1)) +#define SZ_SSIZE_MAX ((sz_ssize_t)(SZ_SIZE_MAX >> 1)) + +SZ_INTERNAL sz_size_t _sz_size_max(void) { return SZ_SIZE_MAX; } +SZ_INTERNAL sz_ssize_t _sz_ssize_max(void) { return SZ_SSIZE_MAX; } /** * @brief Similar to `assert`, the `_sz_assert` is used in the `SZ_DEBUG` mode diff --git a/python/lib.c b/python/lib.c index 6e334719..46ed1c51 100644 --- a/python/lib.c +++ b/python/lib.c @@ -208,12 +208,12 @@ static sz_ptr_t temporary_memory_allocate(sz_size_t size, sz_string_view_t *exis static void temporary_memory_free(sz_ptr_t start, sz_size_t size, sz_string_view_t *existing) {} -static sz_cptr_t parts_get_start(sz_sequence_t *seq, sz_size_t i) { - return ((sz_string_view_t const *)seq->handle)[i].start; +static sz_cptr_t parts_get_start(void const *handle, sz_size_t i) { + return ((sz_string_view_t const *)handle)[i].start; } -static sz_size_t parts_get_length(sz_sequence_t *seq, sz_size_t i) { - return ((sz_string_view_t const *)seq->handle)[i].length; +static sz_size_t parts_get_length(void const *handle, sz_size_t i) { + return ((sz_string_view_t const *)handle)[i].length; } void reverse_offsets(sz_sorted_idx_t *array, size_t length) { @@ -236,7 +236,7 @@ void reverse_haystacks(sz_string_view_t *array, size_t length) { } } -void apply_order(sz_string_view_t *array, sz_sorted_idx_t *order, size_t length) { +void permute(sz_string_view_t *array, sz_sorted_idx_t *order, size_t length) { for (size_t i = 0; i < length; ++i) { if (i == order[i]) continue; sz_string_view_t temp = array[i]; @@ -682,7 +682,7 @@ static PyObject *Str_repr(Str *self) { } } -static Py_hash_t Str_hash(Str *self) { return (Py_hash_t)sz_hash(self->memory.start, self->memory.length); } +static Py_hash_t Str_hash(Str *self) { return 
(Py_hash_t)sz_hash(self->memory.start, self->memory.length, 0); } static char const doc_like_hash[] = // "Compute the hash value of the string.\n" @@ -713,7 +713,7 @@ static PyObject *Str_like_hash(PyObject *self, PyObject *args, PyObject *kwargs) return NULL; } - sz_u64_t result = sz_hash(text.start, text.length); + sz_u64_t result = sz_hash(text.start, text.length, 0); return PyLong_FromUnsignedLongLong((unsigned long long)result); } @@ -1837,7 +1837,8 @@ static PyObject *Str_count(PyObject *self, PyObject *args, PyObject *kwargs) { return PyLong_FromSize_t(count); } -static PyObject *_Str_edit_distance(PyObject *self, PyObject *args, PyObject *kwargs, sz_edit_distance_t function) { +static PyObject *_Str_levenshtein_distance(PyObject *self, PyObject *args, PyObject *kwargs, + sz_levenshtein_distance_t function) { int is_member = self != NULL && PyObject_TypeCheck(self, &StrType); Py_ssize_t nargs = PyTuple_Size(args); if (nargs < !is_member + 1 || nargs > !is_member + 2) { @@ -1877,10 +1878,12 @@ static PyObject *_Str_edit_distance(PyObject *self, PyObject *args, PyObject *kw reusing_allocator.free = &temporary_memory_free; reusing_allocator.handle = &temporary_memory; - sz_size_t distance = function(str1.start, str1.length, str2.start, str2.length, bound, &reusing_allocator); + sz_size_t distance; + sz_status_t status = + function(str1.start, str1.length, str2.start, str2.length, bound, &reusing_allocator, &distance); // Check for memory allocation issues - if (distance == SZ_SIZE_MAX) { + if (status != sz_success_k) { PyErr_NoMemory(); return NULL; } @@ -1888,7 +1891,7 @@ static PyObject *_Str_edit_distance(PyObject *self, PyObject *args, PyObject *kw return PyLong_FromSize_t(distance); } -static char const doc_edit_distance[] = // +static char const doc_levenshtein_distance[] = // "Compute the Levenshtein edit distance between two strings.\n" "\n" "Args:\n" @@ -1898,11 +1901,11 @@ static char const doc_edit_distance[] = // "Returns:\n" " int: The edit distance (number of insertions, deletions, substitutions)."; -static PyObject *Str_edit_distance(PyObject *self, PyObject *args, PyObject *kwargs) { - return _Str_edit_distance(self, args, kwargs, &sz_edit_distance); +static PyObject *Str_levenshtein_distance(PyObject *self, PyObject *args, PyObject *kwargs) { + return _Str_levenshtein_distance(self, args, kwargs, &sz_levenshtein_distance); } -static char const doc_edit_distance_unicode[] = // +static char const doc_levenshtein_distance_unicode[] = // "Compute the Levenshtein edit distance between two Unicode strings.\n" "\n" "Args:\n" @@ -1912,8 +1915,8 @@ static char const doc_edit_distance_unicode[] = // "Returns:\n" " int: The edit distance in Unicode characters."; -static PyObject *Str_edit_distance_unicode(PyObject *self, PyObject *args, PyObject *kwargs) { - return _Str_edit_distance(self, args, kwargs, &sz_edit_distance_utf8); +static PyObject *Str_levenshtein_distance_unicode(PyObject *self, PyObject *args, PyObject *kwargs) { + return _Str_levenshtein_distance(self, args, kwargs, &sz_levenshtein_distance_utf8); } static PyObject *_Str_hamming_distance(PyObject *self, PyObject *args, PyObject *kwargs, @@ -1951,10 +1954,11 @@ static PyObject *_Str_hamming_distance(PyObject *self, PyObject *args, PyObject return NULL; } - sz_size_t distance = function(str1.start, str1.length, str2.start, str2.length, (sz_size_t)bound); + sz_size_t distance; + sz_status_t status = function(str1.start, str1.length, str2.start, str2.length, (sz_size_t)bound, &distance); // Check for memory 
allocation issues - if (distance == SZ_SIZE_MAX) { + if (status != sz_success_k) { PyErr_NoMemory(); return NULL; } @@ -1990,7 +1994,7 @@ static PyObject *Str_hamming_distance_unicode(PyObject *self, PyObject *args, Py return _Str_hamming_distance(self, args, kwargs, &sz_hamming_distance_utf8); } -static char const doc_alignment_score[] = // +static char const doc_needleman_wunsch_score[] = // "Compute the Needleman-Wunsch alignment score between two strings.\n" "\n" "Args:\n" @@ -2002,7 +2006,7 @@ static char const doc_alignment_score[] = // "Returns:\n" " int: The alignment score."; -static PyObject *Str_alignment_score(PyObject *self, PyObject *args, PyObject *kwargs) { +static PyObject *Str_needleman_wunsch_score(PyObject *self, PyObject *args, PyObject *kwargs) { int is_member = self != NULL && PyObject_TypeCheck(self, &StrType); Py_ssize_t nargs = PyTuple_Size(args); if (nargs < !is_member + 1 || nargs > !is_member + 2) { @@ -2074,14 +2078,15 @@ static PyObject *Str_alignment_score(PyObject *self, PyObject *args, PyObject *k reusing_allocator.free = &temporary_memory_free; reusing_allocator.handle = &temporary_memory; - sz_ssize_t score = sz_alignment_score(str1.start, str1.length, str2.start, str2.length, substitutions, - (sz_error_cost_t)gap, &reusing_allocator); + sz_ssize_t score; + sz_status_t status = sz_needleman_wunsch_score(str1.start, str1.length, str2.start, str2.length, substitutions, + (sz_error_cost_t)gap, &reusing_allocator, &score); // Don't forget to release the buffer view PyBuffer_Release(&substitutions_view); // Check for memory allocation issues - if (score == SZ_SSIZE_MAX) { + if (status != sz_success_k) { PyErr_NoMemory(); return NULL; } @@ -2259,11 +2264,11 @@ static PyObject *Str_translate(PyObject *self, PyObject *args, PyObject *kwargs) } sz_string_view_t look_up_table_str; - SZ_ALIGN64 char look_up_table[256]; + _SZ_ALIGN64 char look_up_table[256]; if (PyDict_Check(look_up_table_obj)) { // If any character is not defined, it will be replaced with itself: - for (int i = 0; i < 256; i++) { look_up_table[i] = (char)i; } + for (int i = 0; i < 256; i++) look_up_table[i] = (char)i; // Process the dictionary into the look-up table PyObject *key, *value; @@ -2305,7 +2310,7 @@ static PyObject *Str_translate(PyObject *self, PyObject *args, PyObject *kwargs) // Perform the translation using the look-up table if (is_inplace) { - sz_look_up_transform(str.start, str.length, look_up_table, str.start); + sz_lookup(str.start, str.length, str.start, look_up_table); Py_RETURN_NONE; } // Allocate a string of the same size, get it's raw pointer and transform the data into it @@ -2321,7 +2326,7 @@ static PyObject *Str_translate(PyObject *self, PyObject *args, PyObject *kwargs) } sz_ptr_t new_buffer = (sz_ptr_t)PyUnicode_DATA(new_unicode_obj); - sz_look_up_transform(str.start, str.length, look_up_table, new_buffer); + sz_lookup(new_buffer, str.length, str.start, look_up_table); return new_unicode_obj; } else { @@ -2333,7 +2338,7 @@ static PyObject *Str_translate(PyObject *self, PyObject *args, PyObject *kwargs) // Get the buffer and perform the transformation sz_ptr_t new_buffer = (sz_ptr_t)PyBytes_AS_STRING(new_bytes_obj); - sz_look_up_transform(str.start, str.length, look_up_table, new_buffer); + sz_lookup(new_buffer, str.length, str.start, look_up_table); return new_bytes_obj; } } @@ -2354,7 +2359,7 @@ static PyObject *Str_find_first_of(PyObject *self, PyObject *args, PyObject *kwa Py_ssize_t signed_offset; sz_string_view_t text; sz_string_view_t separator; - if 
(!_Str_find_implementation_(self, args, kwargs, &sz_find_char_from, sz_false_k, &signed_offset, &text, + if (!_Str_find_implementation_(self, args, kwargs, &sz_find_byte_from, sz_false_k, &signed_offset, &text, &separator)) return NULL; return PyLong_FromSsize_t(signed_offset); @@ -2375,7 +2380,7 @@ static PyObject *Str_find_first_not_of(PyObject *self, PyObject *args, PyObject Py_ssize_t signed_offset; sz_string_view_t text; sz_string_view_t separator; - if (!_Str_find_implementation_(self, args, kwargs, &sz_find_char_not_from, sz_false_k, &signed_offset, &text, + if (!_Str_find_implementation_(self, args, kwargs, &sz_find_byte_not_from, sz_false_k, &signed_offset, &text, &separator)) return NULL; return PyLong_FromSsize_t(signed_offset); @@ -2396,7 +2401,7 @@ static PyObject *Str_find_last_of(PyObject *self, PyObject *args, PyObject *kwar Py_ssize_t signed_offset; sz_string_view_t text; sz_string_view_t separator; - if (!_Str_find_implementation_(self, args, kwargs, &sz_rfind_char_from, sz_true_k, &signed_offset, &text, + if (!_Str_find_implementation_(self, args, kwargs, &sz_rfind_byte_from, sz_true_k, &signed_offset, &text, &separator)) return NULL; return PyLong_FromSsize_t(signed_offset); @@ -2417,7 +2422,7 @@ static PyObject *Str_find_last_not_of(PyObject *self, PyObject *args, PyObject * Py_ssize_t signed_offset; sz_string_view_t text; sz_string_view_t separator; - if (!_Str_find_implementation_(self, args, kwargs, &sz_rfind_char_not_from, sz_true_k, &signed_offset, &text, + if (!_Str_find_implementation_(self, args, kwargs, &sz_rfind_byte_not_from, sz_true_k, &signed_offset, &text, &separator)) return NULL; return PyLong_FromSsize_t(signed_offset); @@ -2456,7 +2461,7 @@ static SplitIterator *Str_split_iter_(PyObject *text_obj, PyObject *separator_ob /** * @brief Implements the normal order split logic for both string-delimiters and character sets. - * Produuces one of the consecutive layouts - `STRS_CONSECUTIVE_64` or `STRS_CONSECUTIVE_32`. + * Produces one of the consecutive layouts - `STRS_CONSECUTIVE_64` or `STRS_CONSECUTIVE_32`. */ static Strs *Str_split_(PyObject *parent_string, sz_string_view_t const text, sz_string_view_t const separator, int keepseparator, Py_ssize_t maxsplit, sz_find_t finder, sz_size_t match_length) { @@ -2544,7 +2549,7 @@ static Strs *Str_split_(PyObject *parent_string, sz_string_view_t const text, sz /** * @brief Implements the reverse order split logic for both string-delimiters and character sets. - * Unlike the `Str_split_` can't use consecutive layouts and produces a `REAORDERED` one. + * Unlike the `Str_split_` can't use consecutive layouts and produces a `REORDERED` one. */ static Strs *Str_rsplit_(PyObject *parent_string, sz_string_view_t const text, sz_string_view_t const separator, int keepseparator, Py_ssize_t maxsplit, sz_find_t finder, sz_size_t match_length) { @@ -2622,7 +2627,7 @@ static Strs *Str_rsplit_(PyObject *parent_string, sz_string_view_t const text, s } /** - * @brief Proxy routing requests like `Str.split`, `Str.rsplit`, `Str.split_charset` and `Str.rsplit_charset` + * @brief Proxy routing requests like `Str.split`, `Str.rsplit`, `Str.split_byteset` and `Str.rsplit_byteset` * to `Str_split_` and `Str_rsplit_` implementations, parsing function arguments. 
*/ static PyObject *Str_split_with_known_callback(PyObject *self, PyObject *args, PyObject *kwargs, // @@ -2747,7 +2752,7 @@ static PyObject *Str_rsplit(PyObject *self, PyObject *args, PyObject *kwargs) { return Str_split_with_known_callback(self, args, kwargs, &sz_rfind, 0, sz_true_k, sz_false_k); } -static char const doc_split_charset[] = // +static char const doc_split_byteset[] = // "Split a string by a set of character separators.\n" "\n" "Args:\n" @@ -2758,11 +2763,11 @@ static char const doc_split_charset[] = // "Returns:\n" " Strs: A list of strings split by the character set."; -static PyObject *Str_split_charset(PyObject *self, PyObject *args, PyObject *kwargs) { - return Str_split_with_known_callback(self, args, kwargs, &sz_find_char_from, 1, sz_false_k, sz_false_k); +static PyObject *Str_split_byteset(PyObject *self, PyObject *args, PyObject *kwargs) { + return Str_split_with_known_callback(self, args, kwargs, &sz_find_byte_from, 1, sz_false_k, sz_false_k); } -static char const doc_rsplit_charset[] = // +static char const doc_rsplit_byteset[] = // "Split a string by a set of character separators in reverse order.\n" "\n" "Args:\n" @@ -2773,8 +2778,8 @@ static char const doc_rsplit_charset[] = // "Returns:\n" " Strs: A list of strings split by the character set."; -static PyObject *Str_rsplit_charset(PyObject *self, PyObject *args, PyObject *kwargs) { - return Str_split_with_known_callback(self, args, kwargs, &sz_rfind_char_from, 1, sz_true_k, sz_false_k); +static PyObject *Str_rsplit_byteset(PyObject *self, PyObject *args, PyObject *kwargs) { + return Str_split_with_known_callback(self, args, kwargs, &sz_rfind_byte_from, 1, sz_true_k, sz_false_k); } static char const doc_split_iter[] = // @@ -2809,7 +2814,7 @@ static PyObject *Str_rsplit_iter(PyObject *self, PyObject *args, PyObject *kwarg return Str_split_with_known_callback(self, args, kwargs, &sz_rfind, 0, sz_true_k, sz_true_k); } -static char const doc_split_charset_iter[] = // +static char const doc_split_byteset_iter[] = // "Create an iterator for splitting a string by a set of character separators.\n" "\n" "Args:\n" @@ -2819,11 +2824,11 @@ static char const doc_split_charset_iter[] = // "Returns:\n" " iterator: An iterator yielding split substrings."; -static PyObject *Str_split_charset_iter(PyObject *self, PyObject *args, PyObject *kwargs) { - return Str_split_with_known_callback(self, args, kwargs, &sz_find_char_from, 1, sz_false_k, sz_true_k); +static PyObject *Str_split_byteset_iter(PyObject *self, PyObject *args, PyObject *kwargs) { + return Str_split_with_known_callback(self, args, kwargs, &sz_find_byte_from, 1, sz_false_k, sz_true_k); } -static char const doc_rsplit_charset_iter[] = // +static char const doc_rsplit_byteset_iter[] = // "Create an iterator for splitting a string by a set of character separators in reverse order.\n" "\n" "Args:\n" @@ -2833,8 +2838,8 @@ static char const doc_rsplit_charset_iter[] = // "Returns:\n" " iterator: An iterator yielding split substrings in reverse."; -static PyObject *Str_rsplit_charset_iter(PyObject *self, PyObject *args, PyObject *kwargs) { - return Str_split_with_known_callback(self, args, kwargs, &sz_rfind_char_from, 1, sz_true_k, sz_true_k); +static PyObject *Str_rsplit_byteset_iter(PyObject *self, PyObject *args, PyObject *kwargs) { + return Str_split_with_known_callback(self, args, kwargs, &sz_rfind_byte_from, 1, sz_true_k, sz_true_k); } static char const doc_splitlines[] = // @@ -2924,7 +2929,7 @@ static PyObject *Str_splitlines(PyObject *self, PyObject *args, 
PyObject *kwargs sz_string_view_t separator; separator.start = "\x0A\x0B\x0C\x0D\x85\x1C\x1D\x1E"; separator.length = 8; - return Str_split_(text_obj, text, separator, keeplinebreaks, maxsplit, &sz_find_char_from, 1); + return Str_split_(text_obj, text, separator, keeplinebreaks, maxsplit, &sz_find_byte_from, 1); } static PyObject *Str_concat(PyObject *self, PyObject *other) { @@ -3011,23 +3016,24 @@ static PyMethodDef Str_methods[] = { {"hamming_distance", (PyCFunction)Str_hamming_distance, SZ_METHOD_FLAGS, doc_hamming_distance}, {"hamming_distance_unicode", (PyCFunction)Str_hamming_distance_unicode, SZ_METHOD_FLAGS, doc_hamming_distance_unicode}, - {"edit_distance", (PyCFunction)Str_edit_distance, SZ_METHOD_FLAGS, doc_edit_distance}, - {"edit_distance_unicode", (PyCFunction)Str_edit_distance_unicode, SZ_METHOD_FLAGS, doc_edit_distance_unicode}, - {"alignment_score", (PyCFunction)Str_alignment_score, SZ_METHOD_FLAGS, doc_alignment_score}, + {"levenshtein_distance", (PyCFunction)Str_levenshtein_distance, SZ_METHOD_FLAGS, doc_levenshtein_distance}, + {"levenshtein_distance_unicode", (PyCFunction)Str_levenshtein_distance_unicode, SZ_METHOD_FLAGS, + doc_levenshtein_distance_unicode}, + {"needleman_wunsch_score", (PyCFunction)Str_needleman_wunsch_score, SZ_METHOD_FLAGS, doc_needleman_wunsch_score}, // Character search extensions {"find_first_of", (PyCFunction)Str_find_first_of, SZ_METHOD_FLAGS, doc_find_first_of}, {"find_last_of", (PyCFunction)Str_find_last_of, SZ_METHOD_FLAGS, doc_find_last_of}, {"find_first_not_of", (PyCFunction)Str_find_first_not_of, SZ_METHOD_FLAGS, doc_find_first_not_of}, {"find_last_not_of", (PyCFunction)Str_find_last_not_of, SZ_METHOD_FLAGS, doc_find_last_not_of}, - {"split_charset", (PyCFunction)Str_split_charset, SZ_METHOD_FLAGS, doc_split_charset}, - {"rsplit_charset", (PyCFunction)Str_rsplit_charset, SZ_METHOD_FLAGS, doc_rsplit_charset}, + {"split_byteset", (PyCFunction)Str_split_byteset, SZ_METHOD_FLAGS, doc_split_byteset}, + {"rsplit_byteset", (PyCFunction)Str_rsplit_byteset, SZ_METHOD_FLAGS, doc_rsplit_byteset}, // Lazily evaluated iterators {"split_iter", (PyCFunction)Str_split_iter, SZ_METHOD_FLAGS, doc_split_iter}, {"rsplit_iter", (PyCFunction)Str_rsplit_iter, SZ_METHOD_FLAGS, doc_rsplit_iter}, - {"split_charset_iter", (PyCFunction)Str_split_charset_iter, SZ_METHOD_FLAGS, doc_split_charset_iter}, - {"rsplit_charset_iter", (PyCFunction)Str_rsplit_charset_iter, SZ_METHOD_FLAGS, doc_rsplit_charset_iter}, + {"split_byteset_iter", (PyCFunction)Str_split_byteset_iter, SZ_METHOD_FLAGS, doc_split_byteset_iter}, + {"rsplit_byteset_iter", (PyCFunction)Str_rsplit_byteset_iter, SZ_METHOD_FLAGS, doc_rsplit_byteset_iter}, // Dealing with larger-than-memory datasets {"offset_within", (PyCFunction)Str_offset_within, SZ_METHOD_FLAGS, doc_offset_within}, @@ -3181,8 +3187,8 @@ static PyObject *Strs_shuffle(Strs *self, PyObject *args, PyObject *kwargs) { Py_RETURN_NONE; } -static sz_bool_t Strs_sort_(Strs *self, sz_string_view_t **parts_output, sz_sorted_idx_t **order_output, - sz_size_t *count_output) { +static sz_bool_t Strs_argsort_(Strs *self, sz_string_view_t **parts_output, sz_sorted_idx_t **order_output, + sz_size_t *count_output) { // Change the layout if (!prepare_strings_for_reordering(self)) { PyErr_Format(PyExc_TypeError, "Failed to prepare the sequence for sorting"); @@ -3208,17 +3214,15 @@ static sz_bool_t Strs_sort_(Strs *self, sz_string_view_t **parts_output, sz_sort // Call our sorting algorithm sz_sequence_t sequence; sz_fill(&sequence, sizeof(sequence), 0); - 
sequence.order = (sz_sorted_idx_t *)temporary_memory.start; sequence.count = count; sequence.handle = parts; sequence.get_start = parts_get_start; sequence.get_length = parts_get_length; - for (sz_sorted_idx_t i = 0; i != sequence.count; ++i) sequence.order[i] = i; - sz_sequence_argsort(&sequence); + sz_status_t status = sz_sequence_argsort(&sequence, NULL, (sz_sorted_idx_t *)temporary_memory.start); // Export results *parts_output = parts; - *order_output = sequence.order; + *order_output = (sz_sorted_idx_t *)temporary_memory.start; *count_output = sequence.count; return 1; } @@ -3256,18 +3260,18 @@ static PyObject *Strs_sort(Strs *self, PyObject *args, PyObject *kwargs) { sz_string_view_t *parts = NULL; sz_size_t *order = NULL; sz_size_t count = 0; - if (!Strs_sort_(self, &parts, &order, &count)) return NULL; + if (!Strs_argsort_(self, &parts, &order, &count)) return NULL; // Apply the sorting algorithm here, considering the `reverse` value if (reverse) reverse_offsets(order, count); // Apply the new order. - apply_order(parts, order, count); + permute(parts, order, count); Py_RETURN_NONE; } -static PyObject *Strs_order(Strs *self, PyObject *args, PyObject *kwargs) { +static PyObject *Strs_argsort(Strs *self, PyObject *args, PyObject *kwargs) { PyObject *reverse_obj = NULL; // Default is not reversed // Check for positional arguments @@ -3300,7 +3304,7 @@ static PyObject *Strs_order(Strs *self, PyObject *args, PyObject *kwargs) { sz_string_view_t *parts = NULL; sz_sorted_idx_t *order = NULL; sz_size_t count = 0; - if (!Strs_sort_(self, &parts, &order, &count)) return NULL; + if (!Strs_argsort_(self, &parts, &order, &count)) return NULL; // Apply the sorting algorithm here, considering the `reverse` value if (reverse) reverse_offsets(order, count); @@ -3606,11 +3610,11 @@ static PyGetSetDef Strs_getsetters[] = { static PyMethodDef Strs_methods[] = { {"shuffle", Strs_shuffle, SZ_METHOD_FLAGS, "Shuffle (in-place) the elements of the Strs object."}, // {"sort", Strs_sort, SZ_METHOD_FLAGS, "Sort (in-place) the elements of the Strs object."}, // - {"order", Strs_order, SZ_METHOD_FLAGS, "Provides the indexes to achieve sorted order."}, // + {"argsort", Strs_argsort, SZ_METHOD_FLAGS, "Provides the permutation to achieve sorted order."}, // {"sample", Strs_sample, SZ_METHOD_FLAGS, "Provides a random sample of a given size."}, // - // {"to_pylist", Strs_to_pylist, SZ_METHOD_FLAGS, "Exports string-views to a native list of native strings."}, - // // - {NULL, NULL, 0, NULL}}; + // {"to_pylist", Strs_to_pylist, SZ_METHOD_FLAGS, "Exports string-views to a native list of native strings."}, // + {NULL, NULL, 0, NULL} // Sentinel +}; static PyTypeObject StrsType = { PyVarObject_HEAD_INIT(NULL, 0).tp_name = "stringzilla.Strs", @@ -3660,23 +3664,24 @@ static PyMethodDef stringzilla_methods[] = { // Edit distance extensions {"hamming_distance", Str_hamming_distance, SZ_METHOD_FLAGS, doc_hamming_distance}, {"hamming_distance_unicode", Str_hamming_distance_unicode, SZ_METHOD_FLAGS, doc_hamming_distance_unicode}, - {"edit_distance", Str_edit_distance, SZ_METHOD_FLAGS, doc_edit_distance}, - {"edit_distance_unicode", Str_edit_distance_unicode, SZ_METHOD_FLAGS, doc_edit_distance_unicode}, - {"alignment_score", Str_alignment_score, SZ_METHOD_FLAGS, doc_alignment_score}, + {"levenshtein_distance", Str_levenshtein_distance, SZ_METHOD_FLAGS, doc_levenshtein_distance}, + {"levenshtein_distance_unicode", Str_levenshtein_distance_unicode, SZ_METHOD_FLAGS, + doc_levenshtein_distance_unicode}, + {"needleman_wunsch_score", 
Str_needleman_wunsch_score, SZ_METHOD_FLAGS, doc_needleman_wunsch_score}, // Character search extensions {"find_first_of", Str_find_first_of, SZ_METHOD_FLAGS, doc_find_first_of}, {"find_last_of", Str_find_last_of, SZ_METHOD_FLAGS, doc_find_last_of}, {"find_first_not_of", Str_find_first_not_of, SZ_METHOD_FLAGS, doc_find_first_not_of}, {"find_last_not_of", Str_find_last_not_of, SZ_METHOD_FLAGS, doc_find_last_not_of}, - {"split_charset", Str_split_charset, SZ_METHOD_FLAGS, doc_split_charset}, - {"rsplit_charset", Str_rsplit_charset, SZ_METHOD_FLAGS, doc_rsplit_charset}, + {"split_byteset", Str_split_byteset, SZ_METHOD_FLAGS, doc_split_byteset}, + {"rsplit_byteset", Str_rsplit_byteset, SZ_METHOD_FLAGS, doc_rsplit_byteset}, // Lazily evaluated iterators {"split_iter", Str_split_iter, SZ_METHOD_FLAGS, doc_split_iter}, {"rsplit_iter", Str_rsplit_iter, SZ_METHOD_FLAGS, doc_rsplit_iter}, - {"split_charset_iter", Str_split_charset_iter, SZ_METHOD_FLAGS, doc_split_charset_iter}, - {"rsplit_charset_iter", Str_rsplit_charset_iter, SZ_METHOD_FLAGS, doc_rsplit_charset_iter}, + {"split_byteset_iter", Str_split_byteset_iter, SZ_METHOD_FLAGS, doc_split_byteset_iter}, + {"rsplit_byteset_iter", Str_rsplit_byteset_iter, SZ_METHOD_FLAGS, doc_rsplit_byteset_iter}, // Dealing with larger-than-memory datasets {"offset_within", Str_offset_within, SZ_METHOD_FLAGS, doc_offset_within}, @@ -3714,8 +3719,7 @@ PyMODINIT_FUNC PyInit_stringzilla(void) { // Add version metadata { char version_str[50]; - sprintf(version_str, "%d.%d.%d", STRINGZILLA_VERSION_MAJOR, STRINGZILLA_VERSION_MINOR, - STRINGZILLA_VERSION_PATCH); + sprintf(version_str, "%d.%d.%d", sz_version_major(), sz_version_minor(), sz_version_patch()); PyModule_AddStringConstant(m, "__version__", version_str); } @@ -3724,17 +3728,18 @@ PyMODINIT_FUNC PyInit_stringzilla(void) { sz_capability_t caps = sz_capabilities(); char caps_str[512]; char const *serial = (caps & sz_cap_serial_k) ? "serial," : ""; - char const *neon = (caps & sz_cap_arm_neon_k) ? "neon," : ""; - char const *sve = (caps & sz_cap_arm_sve_k) ? "sve," : ""; - char const *avx2 = (caps & sz_cap_x86_avx2_k) ? "avx2," : ""; - char const *avx512f = (caps & sz_cap_x86_avx512f_k) ? "avx512f," : ""; - char const *avx512vl = (caps & sz_cap_x86_avx512vl_k) ? "avx512vl," : ""; - char const *avx512bw = (caps & sz_cap_x86_avx512bw_k) ? "avx512bw," : ""; - char const *avx512vbmi = (caps & sz_cap_x86_avx512vbmi_k) ? "avx512vbmi," : ""; - char const *gfni = (caps & sz_cap_x86_gfni_k) ? "gfni," : ""; - char const *avx512vbmi2 = (caps & sz_cap_x86_avx512vbmi2_k) ? "avx512vbmi2," : ""; - sprintf(caps_str, "%s%s%s%s%s%s%s%s%s%s", serial, neon, sve, avx2, avx512f, avx512vl, avx512bw, avx512vbmi, - avx512vbmi2, gfni); + char const *neon = (caps & sz_cap_neon_k) ? "neon," : ""; + char const *neon_aes = (caps & sz_cap_neon_aes_k) ? "neon_aes," : ""; + char const *sve = (caps & sz_cap_sve_k) ? "sve," : ""; + char const *sve2 = (caps & sz_cap_sve2_k) ? "sve2," : ""; + char const *sve2_aes = (caps & sz_cap_sve2_aes_k) ? "sve2_aes," : ""; + char const *haswell = (caps & sz_cap_haswell_k) ? "haswell," : ""; + char const *skylake = (caps & sz_cap_skylake_k) ? "skylake," : ""; + char const *ice = (caps & sz_cap_ice_k) ? 
"ice," : ""; + sprintf(caps_str, "%s%s%s%s%s%s%s%s%s", // + serial, // + neon, neon_aes, sve, sve2, sve2_aes, // + haswell, skylake, ice); PyModule_AddStringConstant(m, "__capabilities__", caps_str); } diff --git a/rust/lib.rs b/rust/lib.rs index d9d4e237..d5e9a682 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -8,6 +8,61 @@ pub mod sz { + #[repr(C)] + #[derive(Debug, PartialEq)] + pub enum Status { + Success = 0, + BadAlloc = -1, + InvalidUtf8 = -2, + ContainsDuplicates = -3, + } + + #[repr(C)] + #[derive(Debug, Clone, Copy)] + pub struct Byteset { + bits: [u64; 4], + } + + impl Byteset { + /// Initializes a bit‑set to an empty collection (all characters banned). + #[inline] + pub fn new() -> Self { + Self { bits: [0; 4] } + } + + /// Initializes a bit‑set to contain all ASCII characters. + #[inline] + pub fn new_ascii() -> Self { + Self { + bits: [u64::MAX, u64::MAX, 0, 0], + } + } + + /// Adds a byte to the set. + #[inline] + pub fn add_u8(&mut self, c: u8) { + let idx = (c >> 6) as usize; // Divide by 64. + let bit = c & 63; // Remainder modulo 64. + self.bits[idx] |= 1 << bit; + } + + /// Adds a character to the set. + /// + /// This function assumes the character is in the ASCII range. + #[inline] + pub fn add(&mut self, c: char) { + self.add_u8(c as u8); + } + + /// Inverts the bit-set so that all set bits become unset and vice versa. + #[inline] + pub fn invert(&mut self) { + for b in self.bits.iter_mut() { + *b = !*b; + } + } + } + use core::{ffi::c_void, usize}; // Import the functions from the StringZilla C library. @@ -26,83 +81,64 @@ pub mod sz { needle_length: usize, ) -> *const c_void; - fn sz_find_char_from( - haystack: *const c_void, - haystack_length: usize, - needle: *const c_void, - needle_length: usize, - ) -> *const c_void; + fn sz_find_byteset(haystack: *const c_void, haystack_length: usize, byteset: *const c_void) -> *const c_void; - fn sz_rfind_char_from( - haystack: *const c_void, - haystack_length: usize, - needle: *const c_void, - needle_length: usize, - ) -> *const c_void; - - fn sz_find_char_not_from( - haystack: *const c_void, - haystack_length: usize, - needle: *const c_void, - needle_length: usize, - ) -> *const c_void; - - fn sz_rfind_char_not_from( - haystack: *const c_void, - haystack_length: usize, - needle: *const c_void, - needle_length: usize, - ) -> *const c_void; + fn sz_rfind_byteset(haystack: *const c_void, haystack_length: usize, byteset: *const c_void) -> *const c_void; fn sz_bytesum(text: *const c_void, length: usize) -> u64; fn sz_hash(text: *const c_void, length: usize, seed: u64) -> u64; - fn sz_generate(text: *mut c_void, length: usize, seed: u64) -> u64; + fn sz_fill_random(text: *mut c_void, length: usize, seed: u64); - fn sz_edit_distance( - haystack1: *const c_void, - haystack1_length: usize, - haystack2: *const c_void, - haystack2_length: usize, + pub fn sz_levenshtein_distance( + a: *const c_void, + a_length: usize, + b: *const c_void, + b_length: usize, bound: usize, - allocator: *const c_void, - ) -> usize; - - fn sz_edit_distance_utf8( - haystack1: *const c_void, - haystack1_length: usize, - haystack2: *const c_void, - haystack2_length: usize, + alloc: *const c_void, + result: *mut usize, + ) -> Status; + + pub fn sz_levenshtein_distance_utf8( + a: *const c_void, + a_length: usize, + b: *const c_void, + b_length: usize, bound: usize, - allocator: *const c_void, - ) -> usize; - - fn sz_hamming_distance( - haystack1: *const c_void, - haystack1_length: usize, - haystack2: *const c_void, - haystack2_length: usize, + alloc: *const 
c_void, + result: *mut usize, + ) -> Status; + + pub fn sz_hamming_distance( + a: *const c_void, + a_length: usize, + b: *const c_void, + b_length: usize, bound: usize, - ) -> usize; - - fn sz_hamming_distance_utf8( - haystack1: *const c_void, - haystack1_length: usize, - haystack2: *const c_void, - haystack2_length: usize, + result: *mut usize, + ) -> Status; + + pub fn sz_hamming_distance_utf8( + a: *const c_void, + a_length: usize, + b: *const c_void, + b_length: usize, bound: usize, - ) -> usize; - - fn sz_alignment_score( - haystack1: *const c_void, - haystack1_length: usize, - haystack2: *const c_void, - haystack2_length: usize, - matrix: *const c_void, + result: *mut usize, + ) -> Status; + + pub fn sz_needleman_wunsch_score( + a: *const c_void, + a_length: usize, + b: *const c_void, + b_length: usize, + subs: *const i8, gap: i8, - allocator: *const c_void, - ) -> isize; + alloc: *const c_void, + result: *mut isize, + ) -> Status; } @@ -136,21 +172,41 @@ pub mod sz { /// # Arguments /// /// * `text`: The byte slice to compute the checksum for. + /// * `seed` - A 64-bit value that acts as the seed for the hash function. /// /// # Returns /// /// A `u64` representing the hash value of the input byte slice. - pub fn hash(text: T) -> u64 + pub fn hash_with_seed(text: T, seed: u64) -> u64 where T: AsRef<[u8]>, { let text_ref = text.as_ref(); let text_pointer = text_ref.as_ptr() as _; let text_length = text_ref.len(); - let result = unsafe { sz_hash(text_pointer, text_length) }; + let result = unsafe { sz_hash(text_pointer, text_length, seed) }; return result; } + /// Computes a 64-bit AES-based hash value for a given byte slice `text`. + /// This function is designed to provide a high-quality hash value for use in + /// hash tables, data structures, and cryptographic applications. + /// Unlike the bytesum function, the hash function is order-sensitive. + /// + /// # Arguments + /// + /// * `text`: The byte slice to compute the checksum for. + /// + /// # Returns + /// + /// A `u64` representing the hash value of the input byte slice. + pub fn hash(text: T) -> u64 + where + T: AsRef<[u8]>, + { + hash_with_seed(text, 0) + } + /// Locates the first matching substring within `haystack` that equals `needle`. /// This function is similar to the `memmem()` function in LibC, but, unlike `strstr()`, /// it requires the length of both haystack and needle to be known beforehand. @@ -175,14 +231,7 @@ pub mod sz { let haystack_length = haystack_ref.len(); let needle_pointer = needle_ref.as_ptr() as _; let needle_length = needle_ref.len(); - let result = unsafe { - sz_find( - haystack_pointer, - haystack_length, - needle_pointer, - needle_length, - ) - }; + let result = unsafe { sz_find(haystack_pointer, haystack_length, needle_pointer, needle_length) }; if result.is_null() { None @@ -215,14 +264,7 @@ pub mod sz { let haystack_length = haystack_ref.len(); let needle_pointer = needle_ref.as_ptr() as _; let needle_length = needle_ref.len(); - let result = unsafe { - sz_rfind( - haystack_pointer, - haystack_length, - needle_pointer, - needle_length, - ) - }; + let result = unsafe { sz_rfind(haystack_pointer, haystack_length, needle_pointer, needle_length) }; if result.is_null() { None @@ -244,7 +286,7 @@ pub mod sz { /// /// An `Option` representing the index of the first occurrence of any byte from /// `needles` within `haystack`, if found, otherwise `None`. 
- pub fn find_char_from(haystack: H, needles: N) -> Option + pub fn find_byte_from(haystack: H, needles: N) -> Option where H: AsRef<[u8]>, N: AsRef<[u8]>, @@ -253,16 +295,13 @@ pub mod sz { let needles_ref = needles.as_ref(); let haystack_pointer = haystack_ref.as_ptr() as _; let haystack_length = haystack_ref.len(); - let needles_pointer = needles_ref.as_ptr() as _; - let needles_length = needles_ref.len(); - let result = unsafe { - sz_find_char_from( - haystack_pointer, - haystack_length, - needles_pointer, - needles_length, - ) - }; + let mut byteset = Byteset::new(); + for &b in needles_ref { + byteset.add_u8(b); + } + + let result = + unsafe { sz_find_byteset(haystack_pointer, haystack_length, &byteset as *const _ as *const c_void) }; if result.is_null() { None } else { @@ -283,7 +322,7 @@ pub mod sz { /// /// An `Option` representing the index of the last occurrence of any byte from /// `needles` within `haystack`, if found, otherwise `None`. - pub fn rfind_char_from(haystack: H, needles: N) -> Option + pub fn rfind_byte_from(haystack: H, needles: N) -> Option where H: AsRef<[u8]>, N: AsRef<[u8]>, @@ -292,16 +331,13 @@ pub mod sz { let needles_ref = needles.as_ref(); let haystack_pointer = haystack_ref.as_ptr() as _; let haystack_length = haystack_ref.len(); - let needles_pointer = needles_ref.as_ptr() as _; - let needles_length = needles_ref.len(); - let result = unsafe { - sz_rfind_char_from( - haystack_pointer, - haystack_length, - needles_pointer, - needles_length, - ) - }; + let mut byteset = Byteset::new(); + for &b in needles_ref { + byteset.add_u8(b); + } + + let result = + unsafe { sz_rfind_byteset(haystack_pointer, haystack_length, &byteset as *const _ as *const c_void) }; if result.is_null() { None } else { @@ -322,7 +358,7 @@ pub mod sz { /// /// An `Option` representing the index of the first occurrence of any byte not in /// `needles` within `haystack`, if found, otherwise `None`. - pub fn find_char_not_from(haystack: H, needles: N) -> Option + pub fn find_byte_not_from(haystack: H, needles: N) -> Option where H: AsRef<[u8]>, N: AsRef<[u8]>, @@ -331,16 +367,14 @@ pub mod sz { let needles_ref = needles.as_ref(); let haystack_pointer = haystack_ref.as_ptr() as _; let haystack_length = haystack_ref.len(); - let needles_pointer = needles_ref.as_ptr() as _; - let needles_length = needles_ref.len(); - let result = unsafe { - sz_find_char_not_from( - haystack_pointer, - haystack_length, - needles_pointer, - needles_length, - ) - }; + let mut byteset = Byteset::new(); + for &b in needles_ref { + byteset.add_u8(b); + } + byteset.invert(); + + let result = + unsafe { sz_find_byteset(haystack_pointer, haystack_length, &byteset as *const _ as *const c_void) }; if result.is_null() { None } else { @@ -361,7 +395,7 @@ pub mod sz { /// /// An `Option` representing the index of the last occurrence of any byte not in /// `needles` within `haystack`, if found, otherwise `None`. 
- pub fn rfind_char_not_from(haystack: H, needles: N) -> Option + pub fn rfind_byte_not_from(haystack: H, needles: N) -> Option where H: AsRef<[u8]>, N: AsRef<[u8]>, @@ -370,16 +404,14 @@ pub mod sz { let needles_ref = needles.as_ref(); let haystack_pointer = haystack_ref.as_ptr() as _; let haystack_length = haystack_ref.len(); - let needles_pointer = needles_ref.as_ptr() as _; - let needles_length = needles_ref.len(); - let result = unsafe { - sz_rfind_char_not_from( - haystack_pointer, - haystack_length, - needles_pointer, - needles_length, - ) - }; + let mut byteset = Byteset::new(); + for &b in needles_ref { + byteset.add_u8(b); + } + byteset.invert(); + + let result = + unsafe { sz_rfind_byteset(haystack_pointer, haystack_length, &byteset as *const _ as *const c_void) }; if result.is_null() { None } else { @@ -401,7 +433,7 @@ pub mod sz { /// /// A `usize` representing the minimum number of single-character edits (insertions, /// deletions, or substitutions) required to change `first` into `second`. - pub fn edit_distance_bounded(first: F, second: S, bound: usize) -> usize + pub fn levenshtein_distance_bounded(first: F, second: S, bound: usize) -> Result where F: AsRef<[u8]>, S: AsRef<[u8]>, @@ -412,19 +444,22 @@ pub mod sz { let second_length = second_ref.len(); let first_pointer = first_ref.as_ptr() as _; let second_pointer = second_ref.as_ptr() as _; - unsafe { - sz_edit_distance( + let mut result: usize = 0; + let status = unsafe { + sz_levenshtein_distance( first_pointer, first_length, second_pointer, second_length, - // Upper bound on the distance, that allows us to exit early. If zero is - // passed, the maximum possible distance will be equal to the length of - // the longer input. bound, - // Uses the default allocator - core::ptr::null(), + core::ptr::null(), // Uses the default allocator + &mut result as *mut _, ) + }; + if status == Status::Success { + Ok(result) + } else { + Err(status) } } @@ -441,7 +476,7 @@ pub mod sz { /// /// A `usize` representing the minimum number of single-character edits (insertions, /// deletions, or substitutions) required to change `first` into `second`. - pub fn edit_distance_utf8_bounded(first: F, second: S, bound: usize) -> usize + pub fn levenshtein_distance_utf8_bounded(first: F, second: S, bound: usize) -> Result where F: AsRef<[u8]>, S: AsRef<[u8]>, @@ -452,19 +487,22 @@ pub mod sz { let second_length = second_ref.len(); let first_pointer = first_ref.as_ptr() as _; let second_pointer = second_ref.as_ptr() as _; - unsafe { - sz_edit_distance_utf8( + let mut result: usize = 0; + let status = unsafe { + sz_levenshtein_distance_utf8( first_pointer, first_length, second_pointer, second_length, - // Upper bound on the distance, that allows us to exit early. If zero is - // passed, the maximum possible distance will be equal to the length of - // the longer input. bound, - // Uses the default allocator - core::ptr::null(), + core::ptr::null(), // Uses the default allocator + &mut result as *mut _, ) + }; + if status == Status::Success { + Ok(result) + } else { + Err(status) } } @@ -481,12 +519,12 @@ pub mod sz { /// /// A `usize` representing the minimum number of single-character edits (insertions, /// deletions, or substitutions) required to change `first` into `second`. 
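
Both the Python and Rust bindings now surface the same C convention: the result travels through an out-parameter while the return value is a status code. A minimal C sketch of that convention, where a `NULL` allocator selects the default and a zero bound means "unbounded", per the comments above:

    #include <stringzilla/stringzilla.h> // assumed header path

    int main(void) {
        sz_size_t distance;
        sz_status_t status = sz_levenshtein_distance("kitten", 6, "sitting", 7,
                                                     /*bound=*/0,    // zero disables early exit, per the comments above
                                                     /*alloc=*/NULL, // the default allocator
                                                     &distance);
        return status == sz_success_k && distance == 3 ? 0 : 1;
    }
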
- pub fn edit_distance(first: F, second: S) -> usize + pub fn levenshtein_distance(first: F, second: S) -> Result where F: AsRef<[u8]>, S: AsRef<[u8]>, { - edit_distance_bounded(first, second, usize::MAX) + levenshtein_distance_bounded(first, second, usize::MAX) } /// Computes the Levenshtein edit distance between two UTF8 strings, using the Wagner-Fisher @@ -501,12 +539,12 @@ pub mod sz { /// /// A `usize` representing the minimum number of single-character edits (insertions, /// deletions, or substitutions) required to change `first` into `second`. - pub fn edit_distance_utf8(first: F, second: S) -> usize + pub fn levenshtein_distance_utf8(first: F, second: S) -> Result where F: AsRef<[u8]>, S: AsRef<[u8]>, { - edit_distance_utf8_bounded(first, second, usize::MAX) + levenshtein_distance_utf8_bounded(first, second, usize::MAX) } /// Computes the Hamming edit distance between two strings, counting the number of substituted characters. @@ -522,7 +560,7 @@ pub mod sz { /// /// A `usize` representing the minimum number of single-character edits (substitutions) required to /// change `first` into `second`. - pub fn hamming_distance_bounded(first: F, second: S, bound: usize) -> usize + pub fn hamming_distance_bounded(first: F, second: S, bound: usize) -> Result where F: AsRef<[u8]>, S: AsRef<[u8]>, @@ -533,17 +571,21 @@ pub mod sz { let second_length = second_ref.len(); let first_pointer = first_ref.as_ptr() as _; let second_pointer = second_ref.as_ptr() as _; - unsafe { + let mut result: usize = 0; + let status = unsafe { sz_hamming_distance( first_pointer, first_length, second_pointer, second_length, - // Upper bound on the distance, that allows us to exit early. If zero is - // passed, the maximum possible distance will be equal to the length of - // the longer input. bound, + &mut result as *mut _, ) + }; + if status == Status::Success { + Ok(result) + } else { + Err(status) } } @@ -560,7 +602,7 @@ pub mod sz { /// /// A `usize` representing the minimum number of single-character edits (substitutions) required to /// change `first` into `second`. - pub fn hamming_distance_utf8_bounded(first: F, second: S, bound: usize) -> usize + pub fn hamming_distance_utf8_bounded(first: F, second: S, bound: usize) -> Result where F: AsRef<[u8]>, S: AsRef<[u8]>, @@ -571,17 +613,21 @@ pub mod sz { let second_length = second_ref.len(); let first_pointer = first_ref.as_ptr() as _; let second_pointer = second_ref.as_ptr() as _; - unsafe { + let mut result: usize = 0; + let status = unsafe { sz_hamming_distance_utf8( first_pointer, first_length, second_pointer, second_length, - // Upper bound on the distance, that allows us to exit early. If zero is - // passed, the maximum possible distance will be equal to the length of - // the longer input. bound, + &mut result as *mut _, ) + }; + if status == Status::Success { + Ok(result) + } else { + Err(status) } } @@ -597,7 +643,7 @@ pub mod sz { /// /// A `usize` representing the minimum number of single-character edits (substitutions) required to /// change `first` into `second`. - pub fn hamming_distance(first: F, second: S) -> usize + pub fn hamming_distance(first: F, second: S) -> Result where F: AsRef<[u8]>, S: AsRef<[u8]>, @@ -617,7 +663,7 @@ pub mod sz { /// /// A `usize` representing the minimum number of single-character edits (substitutions) required to /// change `first` into `second`. 
- pub fn hamming_distance_utf8(first: F, second: S) -> usize + pub fn hamming_distance_utf8(first: F, second: S) -> Result where F: AsRef<[u8]>, S: AsRef<[u8]>, @@ -642,7 +688,7 @@ pub mod sz { /// An `isize` representing the total alignment score, where higher scores indicate better /// alignment between the two strings, considering the specified gap penalties and /// substitution matrix. - pub fn alignment_score(first: F, second: S, matrix: [[i8; 256]; 256], gap: i8) -> isize + pub fn alignment_score(first: F, second: S, matrix: [[i8; 256]; 256], gap: i8) -> Result where F: AsRef<[u8]>, S: AsRef<[u8]>, @@ -653,16 +699,23 @@ pub mod sz { let second_length = second_ref.len(); let first_pointer = first_ref.as_ptr() as _; let second_pointer = second_ref.as_ptr() as _; - unsafe { - sz_alignment_score( + let mut result: isize = 0; + let status = unsafe { + sz_needleman_wunsch_score( first_pointer, first_length, second_pointer, second_length, matrix.as_ptr() as _, gap, - core::ptr::null(), + core::ptr::null(), // Uses the default allocator + &mut result as *mut _, ) + }; + if status == Status::Success { + Ok(result) + } else { + Err(status) } } @@ -697,42 +750,27 @@ pub mod sz { /// you need to generate random strings or data sequences based on a specific set /// of characters, such as generating random DNA sequences or testing inputs. /// - /// # Type Parameters - /// - /// * `T`: The type of the text to be randomized. Must be mutable and convertible to a byte slice. - /// * `A`: The type of the alphabet. Must be convertible to a byte slice. - /// /// # Arguments /// - /// * `text`: A mutable reference to the data to randomize. This data will be mutated in place. - /// * `alphabet`: A reference to the byte slice representing the alphabet to use for randomization. + /// * `buffer`: A mutable reference to the data to randomize. This data will be mutated in place. + /// * `nonce`: A 64-bit "number used once" (nonce) value to seed the random number generator. /// /// # Examples /// /// ``` /// use stringzilla::sz; - /// let mut my_text = vec![0; 10]; // A buffer to randomize - /// let alphabet = b"ACTG"; // Using a DNA alphabet - /// sz::randomize(&mut my_text, &alphabet); + /// let mut buffer = vec![0; 10]; + /// sz::fill_random(&mut buffer, 42); /// ``` /// - /// After than, `my_text` is filled with random 'A', 'C', 'T', or 'G' values. - pub fn randomize(text: &mut T, alphabet: &A) + /// After than, `buffer` is filled with random byte values from 0 to 255. + pub fn fill_random(buffer: &mut T, nonce: u64) where T: AsMut<[u8]> + ?Sized, // Allows for mutable references to dynamically sized types. - A: AsRef<[u8]> + ?Sized, // Allows for references to dynamically sized types. 
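The alphabet-driven `randomize` is replaced here by the nonce-seeded `fill_random`. For callers that still need alphabet-constrained output, an adapter can be layered on top; the `fill_random_from_alphabet` helper below is purely illustrative and assumes `fill_random` produces uniform bytes:

```rust
use stringzilla::sz;

// Hypothetical adapter: fill `buffer` with random symbols drawn from `alphabet`,
// built on top of the new nonce-seeded `fill_random`. The remapping is slightly
// biased unless the alphabet length divides 256, which is usually fine for tests.
fn fill_random_from_alphabet(buffer: &mut [u8], alphabet: &[u8], nonce: u64) {
    sz::fill_random(&mut *buffer, nonce);
    for byte in buffer.iter_mut() {
        *byte = alphabet[*byte as usize % alphabet.len()];
    }
}

fn main() {
    let mut dna = vec![0u8; 16];
    fill_random_from_alphabet(&mut dna, b"ACTG", 42);
    assert!(dna.iter().all(|b| b"ACTG".contains(b)));
}
```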
{ - let text_slice = text.as_mut(); - let alphabet_slice = alphabet.as_ref(); + let buffer_slice = buffer.as_mut(); unsafe { - sz_generate( - alphabet_slice.as_ptr() as *const c_void, - alphabet_slice.len(), - text_slice.as_mut_ptr() as *mut c_void, - text_slice.len(), - core::ptr::null(), - core::ptr::null_mut(), - ); + sz_fill_random(buffer_slice.as_ptr() as _, buffer_slice.len(), nonce); } } } @@ -757,10 +795,10 @@ impl<'a> Matcher<'a> for MatcherType<'a> { match self { MatcherType::Find(needle) => sz::find(haystack, needle), MatcherType::RFind(needle) => sz::rfind(haystack, needle), - MatcherType::FindFirstOf(needles) => sz::find_char_from(haystack, needles), - MatcherType::FindLastOf(needles) => sz::rfind_char_from(haystack, needles), - MatcherType::FindFirstNotOf(needles) => sz::find_char_not_from(haystack, needles), - MatcherType::FindLastNotOf(needles) => sz::rfind_char_not_from(haystack, needles), + MatcherType::FindFirstOf(needles) => sz::find_byte_from(haystack, needles), + MatcherType::FindLastOf(needles) => sz::rfind_byte_from(haystack, needles), + MatcherType::FindFirstNotOf(needles) => sz::find_byte_not_from(haystack, needles), + MatcherType::FindLastNotOf(needles) => sz::rfind_byte_not_from(haystack, needles), } } @@ -1088,9 +1126,9 @@ where /// use stringzilla::StringZilla; /// /// let haystack = "Hello, world!"; - /// assert_eq!(haystack.sz_find_char_from("aeiou".as_bytes()), Some(1)); + /// assert_eq!(haystack.sz_find_byte_from("aeiou".as_bytes()), Some(1)); /// ``` - fn sz_find_char_from(&self, needles: N) -> Option; + fn sz_find_byte_from(&self, needles: N) -> Option; /// Finds the index of the last character in `self` that is also present in `needles`. /// @@ -1100,9 +1138,9 @@ where /// use stringzilla::StringZilla; /// /// let haystack = "Hello, world!"; - /// assert_eq!(haystack.sz_rfind_char_from("aeiou".as_bytes()), Some(8)); + /// assert_eq!(haystack.sz_rfind_byte_from("aeiou".as_bytes()), Some(8)); /// ``` - fn sz_rfind_char_from(&self, needles: N) -> Option; + fn sz_rfind_byte_from(&self, needles: N) -> Option; /// Finds the index of the first character in `self` that is not present in `needles`. /// @@ -1112,9 +1150,9 @@ where /// use stringzilla::StringZilla; /// /// let haystack = "Hello, world!"; - /// assert_eq!(haystack.sz_find_char_not_from("aeiou".as_bytes()), Some(0)); + /// assert_eq!(haystack.sz_find_byte_not_from("aeiou".as_bytes()), Some(0)); /// ``` - fn sz_find_char_not_from(&self, needles: N) -> Option; + fn sz_find_byte_not_from(&self, needles: N) -> Option; /// Finds the index of the last character in `self` that is not present in `needles`. /// @@ -1124,9 +1162,9 @@ where /// use stringzilla::StringZilla; /// /// let haystack = "Hello, world!"; - /// assert_eq!(haystack.sz_rfind_char_not_from("aeiou".as_bytes()), Some(12)); + /// assert_eq!(haystack.sz_rfind_byte_not_from("aeiou".as_bytes()), Some(12)); /// ``` - fn sz_rfind_char_not_from(&self, needles: N) -> Option; + fn sz_rfind_byte_not_from(&self, needles: N) -> Option; /// Computes the Levenshtein edit distance between `self` and `other`. /// @@ -1137,9 +1175,9 @@ where /// /// let first = "kitten"; /// let second = "sitting"; - /// assert_eq!(first.sz_edit_distance(second.as_bytes()), 3); + /// assert_eq!(first.sz_levenshtein_distance(second.as_bytes()), Ok(3)); /// ``` - fn sz_edit_distance(&self, other: N) -> usize; + fn sz_levenshtein_distance(&self, other: N) -> Result; /// Computes the Levenshtein edit distance between `self` and `other`. 
/// @@ -1150,9 +1188,9 @@ where /// /// let first = "kitten"; /// let second = "sitting"; - /// assert_eq!(first.sz_edit_distance_utf8(second.as_bytes()), 3); + /// assert_eq!(first.sz_levenshtein_distance_utf8(second.as_bytes()), Ok(3)); /// ``` - fn sz_edit_distance_utf8(&self, other: N) -> usize; + fn sz_levenshtein_distance_utf8(&self, other: N) -> Result; /// Computes the bounded Levenshtein edit distance between `self` and `other`. /// @@ -1163,9 +1201,9 @@ where /// /// let first = "kitten"; /// let second = "sitting"; - /// assert_eq!(first.sz_edit_distance_bounded(second.as_bytes()), 3); + /// assert_eq!(first.sz_levenshtein_distance_bounded(second.as_bytes()), Ok(3)); /// ``` - fn sz_edit_distance_bounded(&self, other: N, bound: usize) -> usize; + fn sz_levenshtein_distance_bounded(&self, other: N, bound: usize) -> Result; /// Computes the bounded Levenshtein edit distance between `self` and `other`. /// @@ -1176,9 +1214,9 @@ where /// /// let first = "kitten"; /// let second = "sitting"; - /// assert_eq!(first.sz_edit_distance_utf8_bounded(second.as_bytes()), 3); + /// assert_eq!(first.sz_levenshtein_distance_utf8_bounded(second.as_bytes()), Ok(3)); /// ``` - fn sz_edit_distance_utf8_bounded(&self, other: N, bound: usize) -> usize; + fn sz_levenshtein_distance_utf8_bounded(&self, other: N, bound: usize) -> Result; /// Computes the alignment score between `self` and `other` using the specified /// substitution matrix and gap penalty. @@ -1192,9 +1230,9 @@ where /// let second = "sitting"; /// let matrix = sz::unary_substitution_costs(); /// let gap_penalty = -1; - /// assert_eq!(first.sz_alignment_score(second.as_bytes(), matrix, gap_penalty), -3); + /// assert_eq!(first.sz_needleman_wunsch_score(second.as_bytes(), matrix, gap_penalty), Ok(-3)); /// ``` - fn sz_alignment_score(&self, other: N, matrix: [[i8; 256]; 256], gap: i8) -> isize; + fn sz_needleman_wunsch_score(&self, other: N, matrix: [[i8; 256]; 256], gap: i8) -> Result; /// Returns an iterator over all non-overlapping matches of the given `needle` in `self`. 
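Because the trait methods above now return `Result`, fallible call sites compose naturally with the `?` operator. A sketch, assuming the error type `Status` is exported from the `sz` module and implements `Debug` and `PartialEq` as the updated tests imply; the `similarity` helper is hypothetical:

```rust
use stringzilla::{sz::Status, StringZilla};

// Sketch: normalized similarity in [0, 1], propagating allocation failures with `?`.
fn similarity(a: &str, b: &str) -> Result<f64, Status> {
    let distance = a.sz_levenshtein_distance(b.as_bytes())?;
    let longest = a.len().max(b.len()).max(1);
    Ok(1.0 - distance as f64 / longest as f64)
}

fn main() {
    assert_eq!(similarity("kitten", "kitten"), Ok(1.0));
}
```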
/// @@ -1362,39 +1400,39 @@ where sz::rfind(self, needle) } - fn sz_find_char_from(&self, needles: N) -> Option { - sz::find_char_from(self, needles) + fn sz_find_byte_from(&self, needles: N) -> Option { + sz::find_byte_from(self, needles) } - fn sz_rfind_char_from(&self, needles: N) -> Option { - sz::rfind_char_from(self, needles) + fn sz_rfind_byte_from(&self, needles: N) -> Option { + sz::rfind_byte_from(self, needles) } - fn sz_find_char_not_from(&self, needles: N) -> Option { - sz::find_char_not_from(self, needles) + fn sz_find_byte_not_from(&self, needles: N) -> Option { + sz::find_byte_not_from(self, needles) } - fn sz_rfind_char_not_from(&self, needles: N) -> Option { - sz::rfind_char_not_from(self, needles) + fn sz_rfind_byte_not_from(&self, needles: N) -> Option { + sz::rfind_byte_not_from(self, needles) } - fn sz_edit_distance(&self, other: N) -> usize { - sz::edit_distance(self, other) + fn sz_levenshtein_distance(&self, other: N) -> Result { + sz::levenshtein_distance(self, other) } - fn sz_edit_distance_utf8(&self, other: N) -> usize { - sz::edit_distance_utf8(self, other) + fn sz_levenshtein_distance_utf8(&self, other: N) -> Result { + sz::levenshtein_distance_utf8(self, other) } - fn sz_edit_distance_bounded(&self, other: N, bound: usize) -> usize { - sz::edit_distance_bounded(self, other, bound) + fn sz_levenshtein_distance_bounded(&self, other: N, bound: usize) -> Result { + sz::levenshtein_distance_bounded(self, other, bound) } - fn sz_edit_distance_utf8_bounded(&self, other: N, bound: usize) -> usize { - sz::edit_distance_utf8_bounded(self, other, bound) + fn sz_levenshtein_distance_utf8_bounded(&self, other: N, bound: usize) -> Result { + sz::levenshtein_distance_utf8_bounded(self, other, bound) } - fn sz_alignment_score(&self, other: N, matrix: [[i8; 256]; 256], gap: i8) -> isize { + fn sz_needleman_wunsch_score(&self, other: N, matrix: [[i8; 256]; 256], gap: i8) -> Result { sz::alignment_score(self, other, matrix, gap) } @@ -1415,84 +1453,19 @@ where } fn sz_find_first_of(&'a self, needles: &'a N) -> RangeMatches<'a> { - RangeMatches::new( - self.as_ref(), - MatcherType::FindFirstOf(needles.as_ref()), - true, - ) + RangeMatches::new(self.as_ref(), MatcherType::FindFirstOf(needles.as_ref()), true) } fn sz_find_last_of(&'a self, needles: &'a N) -> RangeRMatches<'a> { - RangeRMatches::new( - self.as_ref(), - MatcherType::FindLastOf(needles.as_ref()), - true, - ) + RangeRMatches::new(self.as_ref(), MatcherType::FindLastOf(needles.as_ref()), true) } fn sz_find_first_not_of(&'a self, needles: &'a N) -> RangeMatches<'a> { - RangeMatches::new( - self.as_ref(), - MatcherType::FindFirstNotOf(needles.as_ref()), - true, - ) + RangeMatches::new(self.as_ref(), MatcherType::FindFirstNotOf(needles.as_ref()), true) } fn sz_find_last_not_of(&'a self, needles: &'a N) -> RangeRMatches<'a> { - RangeRMatches::new( - self.as_ref(), - MatcherType::FindLastNotOf(needles.as_ref()), - true, - ) - } -} - -/// Provides a tool for mutating a byte slice by filling it with random data from a specified alphabet. -/// This trait is especially useful for types that need to be mutable and can reference or be converted to byte slices. 
-/// -/// # Examples -/// -/// Filling a mutable byte buffer with random ASCII letters: -/// -/// ``` -/// use stringzilla::MutableStringZilla; -/// -/// let mut buffer = vec![0u8; 10]; // A buffer to randomize -/// let alphabet = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; // Alphabet to use -/// buffer.sz_randomize(alphabet); -/// -/// println!("Random buffer: {:?}", buffer); -/// // The buffer will now contain random ASCII letters. -/// ``` -pub trait MutableStringZilla -where - A: AsRef<[u8]>, -{ - /// Fills the implementing byte slice with random bytes from the specified `alphabet`. - /// - /// # Examples - /// - /// ``` - /// use stringzilla::MutableStringZilla; - /// - /// let mut text = vec![0; 1000]; // A buffer to randomize - /// let alphabet = b"AGTC"; // Using a DNA alphabet - /// text.sz_randomize(alphabet); - /// - /// // `text` is now filled with random 'A', 'G', 'T', or 'C' values. - /// ``` - fn sz_randomize(&mut self, alphabet: A); -} - -impl MutableStringZilla for T -where - T: AsMut<[u8]>, - A: AsRef<[u8]>, -{ - fn sz_randomize(&mut self, alphabet: A) { - let self_mut = self.as_mut(); - let alphabet_ref = alphabet.as_ref(); - sz::randomize(self_mut, alphabet_ref); + RangeRMatches::new(self.as_ref(), MatcherType::FindLastNotOf(needles.as_ref()), true) } } @@ -1500,50 +1473,46 @@ where mod tests { use std::borrow::Cow; - use crate::sz; - use crate::MutableStringZilla; - use crate::StringZilla; + use crate::sz; // For global functions + use crate::StringZilla; // For member functions #[test] fn hamming() { - assert_eq!(sz::hamming_distance("hello", "hello"), 0); - assert_eq!(sz::hamming_distance("hello", "hell"), 1); - assert_eq!(sz::hamming_distance("abc", "adc"), 1); + assert_eq!(sz::hamming_distance("hello", "hello"), Ok(0)); + assert_eq!(sz::hamming_distance("hello", "hell"), Ok(1)); + assert_eq!(sz::hamming_distance("abc", "adc"), Ok(1)); - assert_eq!(sz::hamming_distance_bounded("abcdefgh", "ABCDEFGH", 2), 2); - assert_eq!(sz::hamming_distance_utf8("αβγδ", "αγγδ"), 1); + assert_eq!(sz::hamming_distance_bounded("abcdefgh", "ABCDEFGH", 2), Ok(2)); + assert_eq!(sz::hamming_distance_utf8("αβγδ", "αγγδ"), Ok(1)); } #[test] fn levenshtein() { - assert_eq!(sz::edit_distance("hello", "hell"), 1); - assert_eq!(sz::edit_distance("hello", "hell"), 1); - assert_eq!(sz::edit_distance("abc", ""), 3); - assert_eq!(sz::edit_distance("abc", "ac"), 1); - assert_eq!(sz::edit_distance("abc", "a_bc"), 1); - assert_eq!(sz::edit_distance("abc", "adc"), 1); - assert_eq!(sz::edit_distance("fitting", "kitty"), 4); - assert_eq!(sz::edit_distance("smitten", "mitten"), 1); - assert_eq!(sz::edit_distance("ggbuzgjux{}l", "gbuzgjux{}l"), 1); - assert_eq!(sz::edit_distance("abcdefgABCDEFG", "ABCDEFGabcdefg"), 14); - - assert_eq!(sz::edit_distance_bounded("fitting", "kitty", 2), 2); - assert_eq!(sz::edit_distance_utf8("façade", "facade"), 1); + assert_eq!(sz::levenshtein_distance("hello", "hell"), Ok(1)); + assert_eq!(sz::levenshtein_distance("hello", "hell"), Ok(1)); + assert_eq!(sz::levenshtein_distance("abc", ""), Ok(3)); + assert_eq!(sz::levenshtein_distance("abc", "ac"), Ok(1)); + assert_eq!(sz::levenshtein_distance("abc", "a_bc"), Ok(1)); + assert_eq!(sz::levenshtein_distance("abc", "adc"), Ok(1)); + assert_eq!(sz::levenshtein_distance("fitting", "kitty"), Ok(4)); + assert_eq!(sz::levenshtein_distance("smitten", "mitten"), Ok(1)); + assert_eq!(sz::levenshtein_distance("ggbuzgjux{}l", "gbuzgjux{}l"), Ok(1)); + assert_eq!(sz::levenshtein_distance("abcdefgABCDEFG", "ABCDEFGabcdefg"), 
Ok(14)); + + assert_eq!(sz::levenshtein_distance_bounded("fitting", "kitty", 2), Ok(2)); + assert_eq!(sz::levenshtein_distance_utf8("façade", "facade"), Ok(1)); } #[test] fn needleman() { let costs_vector = sz::unary_substitution_costs(); - assert_eq!( - sz::alignment_score("listen", "silent", costs_vector, -1), - -4 - ); + assert_eq!(sz::alignment_score("listen", "silent", costs_vector, -1), Ok(-4)); assert_eq!( sz::alignment_score("abcdefgABCDEFG", "ABCDEFGabcdefg", costs_vector, -1), - -14 + Ok(-14) ); - assert_eq!(sz::alignment_score("hello", "hello", costs_vector, -1), 0); - assert_eq!(sz::alignment_score("hello", "hell", costs_vector, -1), -1); + assert_eq!(sz::alignment_score("hello", "hello", costs_vector, -1), Ok(0)); + assert_eq!(sz::alignment_score("hello", "hell", costs_vector, -1), Ok(-1)); } #[test] @@ -1559,41 +1528,37 @@ mod tests { // Use the generic function with a String assert_eq!(my_string.sz_find("world"), Some(7)); assert_eq!(my_string.sz_rfind("world"), Some(7)); - assert_eq!(my_string.sz_find_char_from("world"), Some(2)); - assert_eq!(my_string.sz_rfind_char_from("world"), Some(11)); - assert_eq!(my_string.sz_find_char_not_from("world"), Some(0)); - assert_eq!(my_string.sz_rfind_char_not_from("world"), Some(12)); + assert_eq!(my_string.sz_find_byte_from("world"), Some(2)); + assert_eq!(my_string.sz_rfind_byte_from("world"), Some(11)); + assert_eq!(my_string.sz_find_byte_not_from("world"), Some(0)); + assert_eq!(my_string.sz_rfind_byte_not_from("world"), Some(12)); // Use the generic function with a &str assert_eq!(my_str.sz_find("world"), Some(7)); assert_eq!(my_str.sz_find("world"), Some(7)); - assert_eq!(my_str.sz_find_char_from("world"), Some(2)); - assert_eq!(my_str.sz_rfind_char_from("world"), Some(11)); - assert_eq!(my_str.sz_find_char_not_from("world"), Some(0)); - assert_eq!(my_str.sz_rfind_char_not_from("world"), Some(12)); + assert_eq!(my_str.sz_find_byte_from("world"), Some(2)); + assert_eq!(my_str.sz_rfind_byte_from("world"), Some(11)); + assert_eq!(my_str.sz_find_byte_not_from("world"), Some(0)); + assert_eq!(my_str.sz_rfind_byte_not_from("world"), Some(12)); // Use the generic function with a Cow<'_, str> assert_eq!(my_cow_str.as_ref().sz_find("world"), Some(7)); assert_eq!(my_cow_str.as_ref().sz_find("world"), Some(7)); - assert_eq!(my_cow_str.as_ref().sz_find_char_from("world"), Some(2)); - assert_eq!(my_cow_str.as_ref().sz_rfind_char_from("world"), Some(11)); - assert_eq!(my_cow_str.as_ref().sz_find_char_not_from("world"), Some(0)); - assert_eq!( - my_cow_str.as_ref().sz_rfind_char_not_from("world"), - Some(12) - ); + assert_eq!(my_cow_str.as_ref().sz_find_byte_from("world"), Some(2)); + assert_eq!(my_cow_str.as_ref().sz_rfind_byte_from("world"), Some(11)); + assert_eq!(my_cow_str.as_ref().sz_find_byte_not_from("world"), Some(0)); + assert_eq!(my_cow_str.as_ref().sz_rfind_byte_not_from("world"), Some(12)); } #[test] - fn randomize() { - let mut text: Vec = vec![0; 10]; // A buffer of ten zeros - let alphabet: &[u8] = b"abcd"; // A byte slice alphabet - text.sz_randomize(alphabet); - - // Iterate throught text and check that it only contains letters from the alphabet - assert!(text - .iter() - .all(|&b| b == b'd' || b == b'c' || b == b'b' || b == b'a')); + fn fill_random() { + let mut first_buffer: Vec = vec![0; 10]; // Ten zeros + let mut second_buffer: Vec = vec![1; 10]; // Ten ones + sz::fill_random(&mut first_buffer, 42); + sz::fill_random(&mut second_buffer, 42); + + // Same nonce will produce the same outputs + assert!(first_buffer != 
second_buffer); } mod search_split_iterators { diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 00000000..75306517 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1 @@ +max_width = 120 diff --git a/scripts/bench_memory.cpp b/scripts/bench_memory.cpp index 4f52c282..47a67835 100644 --- a/scripts/bench_memory.cpp +++ b/scripts/bench_memory.cpp @@ -176,7 +176,7 @@ tracked_unary_functions_t transform_functions() { auto wrap_sz = [](auto function) -> unary_function_t { return unary_function_t([function](std::string_view slice) { char *output = const_cast(slice.data()); - function((sz_cptr_t)output, (sz_size_t)slice.size(), (sz_cptr_t)look_up_table, (sz_ptr_t)output); + function((sz_ptr_t)output, (sz_size_t)slice.size(), (sz_cptr_t)output, (sz_cptr_t)look_up_table); return slice.size(); }); }; diff --git a/scripts/bench_similarity.cpp b/scripts/bench_similarity.cpp index ca901a5f..b2d1c9ee 100644 --- a/scripts/bench_similarity.cpp +++ b/scripts/bench_similarity.cpp @@ -38,25 +38,28 @@ tracked_binary_functions_t distance_functions() { }); auto wrap_sz_distance = [alloc](auto function) mutable -> binary_function_t { return binary_function_t([function, alloc](std::string_view a, std::string_view b) mutable -> std::size_t { - return function(a.data(), a.length(), b.data(), b.length(), SZ_SIZE_MAX, &alloc); + sz_size_t result; + function(a.data(), a.length(), b.data(), b.length(), SZ_SIZE_MAX, &alloc, &result); + return result; }); }; auto wrap_sz_scoring = [alloc, costs_ptr](auto function) mutable -> binary_function_t { return binary_function_t( [function, alloc, costs_ptr](std::string_view a, std::string_view b) mutable -> std::size_t { sz_memory_allocator_t *alloc_ptr = &alloc; - sz_ssize_t signed_result = - function(a.data(), a.length(), b.data(), b.length(), costs_ptr, (sz_error_cost_t)-1, alloc_ptr); + sz_ssize_t signed_result; + function(a.data(), a.length(), b.data(), b.length(), costs_ptr, (sz_error_cost_t)-1, alloc_ptr, + &signed_result); return (std::size_t)(-signed_result); }); }; tracked_binary_functions_t result = { {"naive", wrap_baseline}, - {"sz_edit_distance_serial", wrap_sz_distance(sz_edit_distance_serial), true}, - {"sz_alignment_score_serial", wrap_sz_scoring(sz_alignment_score_serial), true}, + {"sz_levenshtein_distance_serial", wrap_sz_distance(sz_levenshtein_distance_serial), true}, + {"sz_needleman_wunsch_score_serial", wrap_sz_scoring(sz_needleman_wunsch_score_serial), true}, #if SZ_USE_ICE - {"sz_edit_distance_ice", wrap_sz_distance(sz_edit_distance_ice), true}, - {"sz_alignment_score_ice", wrap_sz_scoring(sz_alignment_score_ice), true}, + {"sz_levenshtein_distance_ice", wrap_sz_distance(sz_levenshtein_distance_ice), true}, + {"sz_needleman_wunsch_score_ice", wrap_sz_scoring(sz_needleman_wunsch_score_ice), true}, #endif }; return result; diff --git a/scripts/bench_sort.cpp b/scripts/bench_sort.cpp index 22758d95..a045192f 100644 --- a/scripts/bench_sort.cpp +++ b/scripts/bench_sort.cpp @@ -23,13 +23,13 @@ using permute_t = std::vector; #pragma region C callbacks -static char const *get_start(sz_sequence_t const *array_c, sz_size_t i) { - strings_t const &array = *reinterpret_cast(array_c->handle); +static sz_cptr_t get_start(void const *handle, sz_size_t i) { + strings_t const &array = *reinterpret_cast(handle); return array[i].c_str(); } -static sz_size_t get_length(sz_sequence_t const *array_c, sz_size_t i) { - strings_t const &array = *reinterpret_cast(array_c->handle); +static sz_size_t get_length(void const *handle, sz_size_t i) { + strings_t const &array 
= *reinterpret_cast(handle); return array[i].size(); } @@ -112,21 +112,11 @@ int main(int argc, char const **argv) { }); expect_sorted(pgrams, permute); - bench_permute("sz_pgrams_sort_ice", [&]() { + bench_permute("sz_pgrams_sort_skylake", [&]() { std::copy(pgrams.begin(), pgrams.end(), pgrams_sorted.begin()); std::iota(permute.begin(), permute.end(), 0); sz::_with_alloc([&](sz_memory_allocator_t &alloc) { - return sz_pgrams_sort_ice(pgrams_sorted.data(), pgrams_sorted.size(), &alloc, permute.data()); - }); - }); - expect_sorted(pgrams, permute); - - // Unlike the `std::sort` adaptation above, the `sz_pgrams_sort_stable_serial` also sorts the input array inplace - bench_permute("sz_pgrams_sort_stable_serial", [&]() { - std::copy(pgrams.begin(), pgrams.end(), pgrams_sorted.begin()); - std::iota(permute.begin(), permute.end(), 0); - sz::_with_alloc([&](sz_memory_allocator_t &alloc) { - return sz_pgrams_sort_stable_serial(pgrams_sorted.data(), pgrams_sorted.size(), &alloc, permute.data()); + return sz_pgrams_sort_skylake(pgrams_sorted.data(), pgrams_sorted.size(), &alloc, permute.data()); }); }); expect_sorted(pgrams, permute); @@ -151,7 +141,7 @@ int main(int argc, char const **argv) { }); expect_sorted(strings, permute); - bench_permute("sz_sequence_argsort_ice", [&]() { + bench_permute("sz_sequence_argsort_skylake", [&]() { std::iota(permute.begin(), permute.end(), 0); sz_sequence_t array; array.count = strings.size(); @@ -159,7 +149,7 @@ int main(int argc, char const **argv) { array.get_start = get_start; array.get_length = get_length; sz::_with_alloc( - [&](sz_memory_allocator_t &alloc) { return sz_sequence_argsort_ice(&array, &alloc, permute.data()); }); + [&](sz_memory_allocator_t &alloc) { return sz_sequence_argsort_skylake(&array, &alloc, permute.data()); }); }); expect_sorted(strings, permute); diff --git a/scripts/bench_token.cpp b/scripts/bench_token.cpp index 378ad4f0..0d83604b 100644 --- a/scripts/bench_token.cpp +++ b/scripts/bench_token.cpp @@ -61,12 +61,12 @@ tracked_unary_functions_t hash_functions() { return result; } -struct wrapped_incremental_hash { +struct wrap_hash_stream { sz_hash_state_t state; sz_hash_state_stream_t stream; sz_hash_state_fold_t fold; - wrapped_incremental_hash(sz_hash_state_stream_t s, sz_hash_state_fold_t f) : stream(s), fold(f) { + wrap_hash_stream(sz_hash_state_stream_t s, sz_hash_state_fold_t f) : stream(s), fold(f) { sz_hash_state_init(&state, 42); } @@ -78,20 +78,18 @@ struct wrapped_incremental_hash { tracked_unary_functions_t hash_stream_functions() { tracked_unary_functions_t result = { - {"sz_hash_stream_serial", wrapped_incremental_hash(sz_hash_state_stream_serial, sz_hash_state_fold_serial)}, + {"sz_hash_stream_serial", wrap_hash_stream(sz_hash_state_stream_serial, sz_hash_state_fold_serial)}, #if SZ_USE_HASWELL - {"sz_hash_stream_haswell", wrapped_incremental_hash(sz_hash_state_stream_haswell, sz_hash_state_fold_haswell), - true}, + {"sz_hash_stream_haswell", wrap_hash_stream(sz_hash_state_stream_haswell, sz_hash_state_fold_haswell), true}, #endif #if SZ_USE_SKYLAKE - {"sz_hash_stream_skylake", wrapped_incremental_hash(sz_hash_state_stream_skylake, sz_hash_state_fold_skylake), - true}, + {"sz_hash_stream_skylake", wrap_hash_stream(sz_hash_state_stream_skylake, sz_hash_state_fold_skylake), true}, #endif #if SZ_USE_ICE - {"sz_hash_stream_ice", wrapped_incremental_hash(sz_hash_state_stream_ice, sz_hash_state_fold_ice), true}, + {"sz_hash_stream_ice", wrap_hash_stream(sz_hash_state_stream_ice, sz_hash_state_fold_ice), true}, #endif #if 
SZ_USE_NEON - {"sz_hash_stream_neon", wrapped_incremental_hash(sz_hash_state_stream_neon, sz_hash_state_fold_neon), true}, + {"sz_hash_stream_neon", wrap_hash_stream(sz_hash_state_stream_neon, sz_hash_state_fold_neon), true}, #endif }; return result; diff --git a/scripts/test.py b/scripts/test.py index ea95e8d4..eb92252d 100644 --- a/scripts/test.py +++ b/scripts/test.py @@ -178,10 +178,10 @@ def test_unit_split(): assert letters == ["a", "b", "c", "d"] # Splitting using character sets - letters = sz.split_charset("a b_c d", " _") + letters = sz.split_byteset("a b_c d", " _") assert letters == ["a", "b", "c", "d"] - letters = sz.rsplit_charset("a b_c d", " _") + letters = sz.rsplit_byteset("a b_c d", " _") assert letters == ["a", "b", "c", "d"] # Check for equivalence with native Python strings for newline separators @@ -212,17 +212,17 @@ def test_unit_split(): with pytest.raises(ValueError): sz.rsplit(big, "") with pytest.raises(ValueError): - sz.split_charset(big, "") + sz.split_byteset(big, "") with pytest.raises(ValueError): - sz.rsplit_charset(big, "") + sz.rsplit_byteset(big, "") def test_unit_split_iterators(): """ Test the iterator-based split methods. This is slightly different from `split` and `rsplit` in that it returns an iterator instead of a list. - Moreover, the native `rsplit` and even `rsplit_charset` report results in the identical order to `split` - and `split_charset`. Here `rsplit_iter` reports elements in the reverse order, compared to `split_iter`. + Moreover, the native `rsplit` and even `rsplit_byteset` report results in the identical order to `split` + and `split_byteset`. Here `rsplit_iter` reports elements in the reverse order, compared to `split_iter`. """ native = "line1\nline2\nline3" big = Str(native) @@ -244,10 +244,10 @@ def test_unit_split_iterators(): assert letters == ["a", "b", "c", "d"] # Splitting using character sets - letters = list(sz.split_charset_iter("a-b_c-d", "-_")) + letters = list(sz.split_byteset_iter("a-b_c-d", "-_")) assert letters == ["a", "b", "c", "d"] - letters = list(sz.rsplit_charset_iter("a-b_c-d", "-_")) + letters = list(sz.rsplit_byteset_iter("a-b_c-d", "-_")) assert letters == ["d", "c", "b", "a"] # Check for equivalence with native Python strings, including boundary conditions @@ -279,9 +279,9 @@ def rlist(seq): with pytest.raises(ValueError): sz.rsplit_iter(big, "") with pytest.raises(ValueError): - sz.split_charset_iter(big, "") + sz.split_byteset_iter(big, "") with pytest.raises(ValueError): - sz.rsplit_charset_iter(big, "") + sz.rsplit_byteset_iter(big, "") def test_unit_strs_sequence(): @@ -289,7 +289,7 @@ def test_unit_strs_sequence(): big = Str(native) lines = big.splitlines() - assert [2, 1, 0] == list(lines.order()) + assert [2, 1, 0] == list(lines.argsort()) assert "p3" in lines assert "p4" not in lines @@ -301,11 +301,11 @@ def test_unit_strs_sequence(): assert str(Str("a" * 1_000_000).split()).endswith("aaa']") lines.sort() - assert [0, 1, 2] == list(lines.order()) + assert [0, 1, 2] == list(lines.argsort()) assert ["p1", "p2", "p3"] == list(lines) # Reverse order - assert [2, 1, 0] == list(lines.order(reverse=True)) + assert [2, 1, 0] == list(lines.argsort(reverse=True)) lines.sort(reverse=True) assert ["p3", "p2", "p1"] == list(lines) @@ -798,7 +798,7 @@ def test_fuzzy_sorting(list_length: int, part_length: int, variability: int): big_list = big_joined.split(".") native_ordered = sorted(native_list) - native_order = big_list.order() + native_order = big_list.argsort() for i in range(list_length): assert 
native_ordered[i] == native_list[native_order[i]], "Order is wrong" assert native_ordered[i] == str( @@ -826,7 +826,7 @@ def test_fuzzy_sorting(list_length: int, part_length: int, variability: int): big_list = big_joined.split(".") native_ordered = sorted(native_list) - native_order = big_list.order() + native_order = big_list.argsort() for i in range(list_length): assert native_ordered[i] == native_list[native_order[i]], "Order is wrong" assert native_ordered[i] == str( diff --git a/swift/StringProtocol+StringZilla.swift b/swift/StringProtocol+StringZilla.swift index d90c8afc..e573b609 100644 --- a/swift/StringProtocol+StringZilla.swift +++ b/swift/StringProtocol+StringZilla.swift @@ -18,13 +18,13 @@ import StringZillaC // We need to link the standard libraries. #if os(Linux) -import Glibc + import Glibc #else -import Darwin.C + import Darwin.C #endif /// Protocol defining a single-byte data type. -fileprivate protocol SingleByte {} +private protocol SingleByte {} extension UInt8: SingleByte {} extension Int8: SingleByte {} // This would match `CChar` as well. @@ -33,7 +33,7 @@ extension Int8: SingleByte {} // This would match `CChar` as well. enum StringZillaError: Error { case contiguousStorageUnavailable case memoryAllocationFailed - + var localizedDescription: String { switch self { case .contiguousStorageUnavailable: @@ -54,7 +54,7 @@ enum StringZillaError: Error { /// https://developer.apple.com/documentation/swift/stringprotocol/data(using:allowlossyconversion:) public protocol StringZillaViewable: Collection { /// A type that represents a position in the collection. - /// + /// /// Executes a closure with a pointer to the string's UTF8 C representation and its length. /// /// - Parameters: @@ -62,7 +62,7 @@ public protocol StringZillaViewable: Collection { /// - Throws: Can throw an error. /// - Returns: Returns a value of type R, which is the result of the closure. func withStringZillaScope(_ body: (sz_cptr_t, sz_size_t) throws -> R) rethrows -> R - + /// Calculates the offset index for a given byte pointer relative to a start pointer. /// /// - Parameters: @@ -74,24 +74,24 @@ public protocol StringZillaViewable: Collection { extension String: StringZillaViewable { public typealias Index = String.Index - + @_transparent public func withStringZillaScope(_ body: (sz_cptr_t, sz_size_t) throws -> R) rethrows -> R { let cLength = sz_size_t(utf8.count) - return try self.withCString { cString in + return try withCString { cString in try body(cString, cLength) } } - + @_transparent public func stringZillaByteOffset(forByte bytePointer: sz_cptr_t, after startPointer: sz_cptr_t) -> Index { - self.utf8.index(self.utf8.startIndex, offsetBy: bytePointer - startPointer) + utf8.index(utf8.startIndex, offsetBy: bytePointer - startPointer) } } extension Substring.UTF8View: StringZillaViewable { public typealias Index = Substring.UTF8View.Index - + /// Executes a closure with a pointer to the UTF8View's contiguous storage of single-byte elements (UTF-8 code units). /// - Parameters: /// - body: A closure that takes a pointer to the contiguous storage and its size. @@ -106,7 +106,7 @@ extension Substring.UTF8View: StringZillaViewable { throw StringZillaError.contiguousStorageUnavailable }() } - + /// Calculates the offset index for a given byte pointer relative to a start pointer. /// - Parameters: /// - bytePointer: A pointer to the byte for which the offset is calculated. @@ -114,13 +114,13 @@ extension Substring.UTF8View: StringZillaViewable { /// - Returns: The calculated index offset. 
@_transparent public func stringZillaByteOffset(forByte bytePointer: sz_cptr_t, after startPointer: sz_cptr_t) -> Index { - return self.index(self.startIndex, offsetBy: bytePointer - startPointer) + return index(startIndex, offsetBy: bytePointer - startPointer) } } extension String.UTF8View: StringZillaViewable { public typealias Index = String.UTF8View.Index - + /// Executes a closure with a pointer to the UTF8View's contiguous storage of single-byte elements (UTF-8 code units). /// - Parameters: /// - body: A closure that takes a pointer to the contiguous storage and its size. @@ -134,19 +134,18 @@ extension String.UTF8View: StringZillaViewable { throw StringZillaError.contiguousStorageUnavailable }() } - + /// Calculates the offset index for a given byte pointer relative to a start pointer. /// - Parameters: /// - bytePointer: A pointer to the byte for which the offset is calculated. /// - startPointer: The starting pointer for the calculation, previously obtained from `szScope`. /// - Returns: The calculated index offset. public func stringZillaByteOffset(forByte bytePointer: sz_cptr_t, after startPointer: sz_cptr_t) -> Index { - return self.index(self.startIndex, offsetBy: bytePointer - startPointer) + return index(startIndex, offsetBy: bytePointer - startPointer) } } public extension StringZillaViewable { - /// Finds the first occurrence of the specified substring within the receiver. /// - Parameter needle: The substring to search for. /// - Returns: The index of the found occurrence, or `nil` if not found. @@ -163,7 +162,7 @@ public extension StringZillaViewable { } return result } - + /// Finds the last occurrence of the specified substring within the receiver. /// - Parameter needle: The substring to search for. /// - Returns: The index of the found occurrence, or `nil` if not found. @@ -180,7 +179,7 @@ public extension StringZillaViewable { } return result } - + /// Finds the first occurrence of the specified character-set members within the receiver. /// - Parameter characters: A string-like collection of characters to match. /// - Returns: The index of the found occurrence, or `nil` if not found. @@ -190,14 +189,14 @@ public extension StringZillaViewable { var result: Index? withStringZillaScope { hPointer, hLength in characters.withStringZillaScope { nPointer, nLength in - if let matchPointer = sz_find_char_from(hPointer, hLength, nPointer, nLength) { + if let matchPointer = sz_find_byte_from(hPointer, hLength, nPointer, nLength) { result = self.stringZillaByteOffset(forByte: matchPointer, after: hPointer) } } } return result } - + /// Finds the last occurrence of the specified character-set members within the receiver. /// - Parameter characters: A string-like collection of characters to match. /// - Returns: The index of the found occurrence, or `nil` if not found. @@ -207,14 +206,14 @@ public extension StringZillaViewable { var result: Index? withStringZillaScope { hPointer, hLength in characters.withStringZillaScope { nPointer, nLength in - if let matchPointer = sz_rfind_char_from(hPointer, hLength, nPointer, nLength) { + if let matchPointer = sz_rfind_byte_from(hPointer, hLength, nPointer, nLength) { result = self.stringZillaByteOffset(forByte: matchPointer, after: hPointer) } } } return result } - + /// Finds the first occurrence of a character outside of the the given character-set within the receiver. /// - Parameter characters: A string-like collection of characters to exclude. /// - Returns: The index of the found occurrence, or `nil` if not found. 
@@ -224,14 +223,14 @@ public extension StringZillaViewable { var result: Index? withStringZillaScope { hPointer, hLength in characters.withStringZillaScope { nPointer, nLength in - if let matchPointer = sz_find_char_not_from(hPointer, hLength, nPointer, nLength) { + if let matchPointer = sz_find_byte_not_from(hPointer, hLength, nPointer, nLength) { result = self.stringZillaByteOffset(forByte: matchPointer, after: hPointer) } } } return result } - + /// Finds the last occurrence of a character outside of the the given character-set within the receiver. /// - Parameter characters: A string-like collection of characters to exclude. /// - Returns: The index of the found occurrence, or `nil` if not found. @@ -241,40 +240,46 @@ public extension StringZillaViewable { var result: Index? withStringZillaScope { hPointer, hLength in characters.withStringZillaScope { nPointer, nLength in - if let matchPointer = sz_rfind_char_not_from(hPointer, hLength, nPointer, nLength) { + if let matchPointer = sz_rfind_byte_not_from(hPointer, hLength, nPointer, nLength) { result = self.stringZillaByteOffset(forByte: matchPointer, after: hPointer) } } } return result } - - /// Computes the Levenshtein edit distance between this and another string. - /// - Parameter other: A string-like collection of characters to exclude. - /// - Returns: The edit distance, as an unsigned integer. - /// - Throws: If a memory allocation error has happened. - @_specialize(where Self == String, S == String) - @_specialize(where Self == String.UTF8View, S == String.UTF8View) - func editDistance(from other: S, bound: UInt64 = UInt64.max) throws -> UInt64? { - var result: UInt64? - - // Use a do-catch block to handle potential errors - do { - try withStringZillaScope { hPointer, hLength in - try other.withStringZillaScope { nPointer, nLength in - result = UInt64(sz_edit_distance(hPointer, hLength, nPointer, nLength, sz_size_t(bound), nil)) - if result == SZ_SIZE_MAX { - result = nil - throw StringZillaError.memoryAllocationFailed - } - } + + func levenshteinDistance( + from other: S, + bound: UInt? = nil + ) throws -> UInt { + // Prepare a local variable for the result. + var computedResult: sz_size_t = 0 + + // Swift has a ridiculous issue with casting unsigned 64-bit to unsigned 64-bit + // values which results in "Fatal error: Not enough bits to represent the passed value". + // Let's just copy the bytes: https://stackoverflow.com/a/68650250/2766161 + let effectiveBound: sz_size_t = bound.map { sz_size_t($0) } ?? _sz_size_max() + let status = try withStringZillaScope { hPointer, hLength in + try other.withStringZillaScope { nPointer, nLength in + // Pass a mutable pointer for the result. + sz_levenshtein_distance( + hPointer, + hLength, + nPointer, + nLength, + effectiveBound, + nil, // default allocator + &computedResult // out-parameter for the computed distance + ) } - } catch { - // Handle or rethrow the error - throw error } - - return result + + // Check the returned status code. + guard status == sz_success_k else { + // Map the status code to an appropriate Swift error. + throw StringZillaError.memoryAllocationFailed + } + + return UInt(computedResult) } - } diff --git a/swift/Test.swift b/swift/Test.swift index 4b839022..670d758a 100644 --- a/swift/Test.swift +++ b/swift/Test.swift @@ -5,58 +5,57 @@ // Created by Ash Vardanian on 18/1/24. // -import XCTest @testable import StringZilla +import XCTest class StringZillaTests: XCTestCase { - var testString: String! 
- + override func setUp() { super.setUp() testString = "Hello, world! Welcome to StringZilla. 👋" XCTAssertEqual(testString.count, 39) XCTAssertEqual(testString.utf8.count, 42) } - + func testFindFirstSubstring() { let index = testString.findFirst(substring: "world")! XCTAssertEqual(testString[index...], "world! Welcome to StringZilla. 👋") } - + func testFindLastSubstring() { let index = testString.findLast(substring: "o")! XCTAssertEqual(testString[index...], "o StringZilla. 👋") } - + func testFindFirstCharacterFromSet() { let index = testString.findFirst(characterFrom: "aeiou")! XCTAssertEqual(testString[index...], "ello, world! Welcome to StringZilla. 👋") } - + func testFindLastCharacterFromSet() { let index = testString.findLast(characterFrom: "aeiou")! XCTAssertEqual(testString[index...], "a. 👋") } - + func testFindFirstCharacterNotFromSet() { let index = testString.findFirst(characterNotFrom: "aeiou")! XCTAssertEqual(testString[index...], "Hello, world! Welcome to StringZilla. 👋") } - func testFindLastCharacterNotFromSet() { + func testFindLastCharacterNotFromSet() { let index = testString.findLast(characterNotFrom: "aeiou")! XCTAssertEqual(testString.distance(from: testString.startIndex, to: index), 38) XCTAssertEqual(testString[index...], "👋") } - - func testEditDistance() { + + func testLevenshteinDistance() { let otherString = "Hello, world!" - let distance = try? testString.editDistance(from: otherString) + let distance = try? testString.levenshteinDistance(from: otherString) XCTAssertNotNil(distance) XCTAssertEqual(distance, 29) } - + func testFindLastCharacterNotFromSetNoMatch() { let index = "aeiou".findLast(characterNotFrom: "aeiou") XCTAssertNil(index) From 8b396c829d20a59321b60aa6b856bf52f6426795 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Fri, 28 Feb 2025 18:17:39 +0000 Subject: [PATCH 136/751] Fix: `fill_random` test condition --- rust/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/lib.rs b/rust/lib.rs index d5e9a682..58553d75 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -1558,7 +1558,7 @@ mod tests { sz::fill_random(&mut second_buffer, 42); // Same nonce will produce the same outputs - assert!(first_buffer != second_buffer); + assert_eq!(first_buffer, second_buffer); } mod search_split_iterators { From d52bf63a3d74529fa1b6c397b874e52243d3e80a Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 2 Mar 2025 00:33:43 +0000 Subject: [PATCH 137/751] Fix: Detecting caps in dynamic builds --- c/lib.c | 4 ++-- include/stringzilla/types.h | 21 ++++++++++++++++++++- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/c/lib.c b/c/lib.c index f742ad2b..3132a8f5 100644 --- a/c/lib.c +++ b/c/lib.c @@ -158,9 +158,9 @@ SZ_INTERNAL sz_capability_t _sz_capabilities_x86(void) { * @return A bitmask of the SIMD capabilities represented as a `sz_capability_t` enum value. */ SZ_DYNAMIC sz_capability_t sz_capabilities(void) { -#if _SZ_IS_X86 +#if _SZ_IS_X86_64 return _sz_capabilities_x86(); -#elif _SZ_IS_ARM +#elif _SZ_IS_ARM64 return _sz_capabilities_arm(); #else return sz_cap_serial_k; diff --git a/include/stringzilla/types.h b/include/stringzilla/types.h index c4f71907..a15cf116 100644 --- a/include/stringzilla/types.h +++ b/include/stringzilla/types.h @@ -105,6 +105,25 @@ #endif #endif +/** + * @brief Infer the target architecture. + * At this point we only provide optimized backends for x86_64 and ARM64. 
+ */ +#ifndef _SZ_IS_X86_64 +#if defined(__x86_64__) || defined(_M_X64) || defined(__i386__) || defined(_M_IX86) +#define _SZ_IS_X86_64 (1) +#else +#define _SZ_IS_X86_64 (0) +#endif +#endif +#ifndef _SZ_IS_ARM64 +#if defined(__aarch64__) || defined(__arm64__) || defined(__arm64) || defined(_M_ARM64) +#define _SZ_IS_ARM64 (1) +#else +#define _SZ_IS_ARM64 (0) +#endif +#endif + /** * @brief Threshold for switching to SWAR (8-bytes at a time) backend over serial byte-level for-loops. * On very short strings, under 16 bytes long, at most a single word will be processed with SWAR. @@ -230,7 +249,7 @@ */ #if SZ_USE_HASWELL || SZ_USE_SKYLAKE || SZ_USE_ICE #include -#endif // SZ_USE_X86... +#endif // SZ_USE_HASWELL || SZ_USE_SKYLAKE || SZ_USE_ICE #if SZ_USE_NEON #if !defined(_MSC_VER) #include From 8877c82aedf7acc6bba7e667daab3e9e5f36011d Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 2 Mar 2025 00:34:29 +0000 Subject: [PATCH 138/751] Make: Decremental Rust builds SimSIMD uses the same approach --- build.rs | 126 ++++++++++++++++++++++++++----------------------------- 1 file changed, 60 insertions(+), 66 deletions(-) diff --git a/build.rs b/build.rs index bb5fb5cf..9622457f 100644 --- a/build.rs +++ b/build.rs @@ -6,7 +6,12 @@ fn main() { .file("c/lib.c") .include("include") .warnings(false) - .flag_if_supported("-std=c99") + .define("SZ_DYNAMIC_DISPATCH", "1") + .define("SZ_AVOID_LIBC", "0") + .define("SZ_DEBUG", "0") + .flag("-O3") + .flag("-std=c99") // Enforce C99 standard + .flag_if_supported("-fdiagnostics-color=always") .flag_if_supported("-fPIC"); // Cargo will set different environment variables that we can use to properly configure the build. @@ -14,70 +19,6 @@ fn main() { let target_arch = env::var("CARGO_CFG_TARGET_ARCH").unwrap_or_default(); let target_endian = env::var("CARGO_CFG_TARGET_ENDIAN").unwrap_or_default(); - // To get the operating system we can use the TARGET environment variable. - // To check the list of available targets, run `rustc --print target-list`. 
- let target = env::var("TARGET").unwrap_or_default(); - - if target.contains("linux") { - build.flag_if_supported("-fdiagnostics-color=always"); - build.flag_if_supported("-O3"); - build.flag_if_supported("-pedantic"); - - // Set architecture-specific flags and macros - if target_arch == "x86_64" { - build.define("SZ_USE_HASWELL", "1"); - build.define("SZ_USE_SKYLAKE", "1"); - build.define("SZ_USE_ICE", "1"); - } else { - build.define("SZ_USE_HASWELL", "0"); - build.define("SZ_USE_SKYLAKE", "0"); - build.define("SZ_USE_ICE", "0"); - } - - if target_arch == "aarch64" { - build.flag_if_supported("-march=armv8-a+simd"); - build.define("SZ_USE_NEON", "1"); - build.define("SZ_USE_SVE", "1"); - } else { - build.define("SZ_USE_NEON", "0"); - build.define("SZ_USE_SVE", "0"); - } - } else if target.contains("darwin") { - build.flag_if_supported("-fcolor-diagnostics"); - build.flag_if_supported("-O3"); - build.flag_if_supported("-pedantic"); - - if target_arch == "x86_64" { - // Assuming no AVX-512 support for Darwin as per setup.py logic - build.define("SZ_USE_HASWELL", "1"); - build.define("SZ_USE_SKYLAKE", "0"); - build.define("SZ_USE_ICE", "0"); - } else { - build.define("SZ_USE_HASWELL", "0"); - build.define("SZ_USE_SKYLAKE", "0"); - build.define("SZ_USE_ICE", "0"); - } - - if target_arch == "aarch64" { - build.define("SZ_USE_NEON", "1"); - build.define("SZ_USE_SVE", "0"); // Assuming no SVE support for Darwin - } else { - build.define("SZ_USE_NEON", "0"); - build.define("SZ_USE_SVE", "0"); - } - } else if target.contains("windows") { - // Set architecture-specific flags and macros - if target_arch == "x86_64" { - build.define("SZ_USE_HASWELL", "1"); - build.define("SZ_USE_SKYLAKE", "1"); - build.define("SZ_USE_ICE", "1"); - } else { - build.define("SZ_USE_HASWELL", "0"); - build.define("SZ_USE_SKYLAKE", "0"); - build.define("SZ_USE_ICE", "0"); - } - } - // Set endian-specific macro if target_endian == "big" { build.define("SZ_DETECT_BIG_ENDIAN", "1"); @@ -85,9 +26,62 @@ fn main() { build.define("SZ_DETECT_BIG_ENDIAN", "0"); } - build.compile("stringzilla"); + if target_arch == "x86_64" { + build.define("_SZ_IS_X86_64", "1"); + build.define("_SZ_IS_ARM64", "0"); + } else if target_arch == "aarch64" { + build.define("_SZ_IS_X86_64", "0"); + build.define("_SZ_IS_ARM64", "1"); + } + + // At start we will try compiling with all SIMD backends enabled + let flags_to_try = match target_arch.as_str() { + "arm" | "aarch64" => vec![ + // + "SZ_USE_SVE2", + "SZ_USE_SVE", + "SZ_USE_NEON", + ], + _ => vec![ + // + "SZ_USE_ICE", + "SZ_USE_SKYLAKE", + "SZ_USE_HASWELL", + ], + }; + for flag in flags_to_try.iter() { + build.define(flag, "1"); + } + + // If that fails, we will try disabling them one by one + if build.try_compile("stringzilla").is_err() { + print!("cargo:warning=Failed to compile with all SIMD backends..."); + + for flag in flags_to_try.iter() { + build.define(flag, "0"); + if build.try_compile("stringzilla").is_ok() { + break; + } + + // Print the failed configuration + println!( + "cargo:warning=Failed to compile after disabling {}, trying next configuration...", + flag + ); + } + } println!("cargo:rerun-if-changed=c/lib.c"); println!("cargo:rerun-if-changed=rust/lib.rs"); println!("cargo:rerun-if-changed=include/stringzilla/stringzilla.h"); + + // Constituent parts: + println!("cargo:rerun-if-changed=include/stringzilla/compare.h"); + println!("cargo:rerun-if-changed=include/stringzilla/find.h"); + println!("cargo:rerun-if-changed=include/stringzilla/hash.h"); + 
println!("cargo:rerun-if-changed=include/stringzilla/memory.h"); + println!("cargo:rerun-if-changed=include/stringzilla/similarity.h"); + println!("cargo:rerun-if-changed=include/stringzilla/small_string.h"); + println!("cargo:rerun-if-changed=include/stringzilla/sort.h"); + println!("cargo:rerun-if-changed=include/stringzilla/types.h"); } From fbf256aca7b62f3e978334d5554b968fec0a1fa2 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 2 Mar 2025 00:35:09 +0000 Subject: [PATCH 139/751] Make: `cibuildwheel` env variables --- pyproject.toml | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a8dd42e2..ed969673 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,50 +79,50 @@ before-build = ["rd /s /q {project}\\build || echo Done"] [[tool.cibuildwheel.overrides]] select = "*-win_amd64" inherit.environment = "append" -environment.SZ_X86_64 = "1" +environment._SZ_IS_X86_64 = "1" [[tool.cibuildwheel.overrides]] select = "*-manylinux*_x86_64" inherit.environment = "append" -environment.SZ_X86_64 = "1" +environment._SZ_IS_X86_64 = "1" [[tool.cibuildwheel.overrides]] select = "*-musllinux*_x86_64" inherit.environment = "append" -environment.SZ_X86_64 = "1" +environment._SZ_IS_X86_64 = "1" [[tool.cibuildwheel.overrides]] select = "*-macos*_x86_64" inherit.environment = "append" -environment.SZ_X86_64 = "1" +environment._SZ_IS_X86_64 = "1" # Detect ARM 64-bit builds [[tool.cibuildwheel.overrides]] select = "*-win_arm64" inherit.environment = "append" -environment.SZ_ARM64 = "1" +environment._SZ_IS_ARM64 = "1" [[tool.cibuildwheel.overrides]] select = "*-manylinux*_aarch64" inherit.environment = "append" -environment.SZ_ARM64 = "1" +environment._SZ_IS_ARM64 = "1" [[tool.cibuildwheel.overrides]] select = "*-musllinux*_aarch64" inherit.environment = "append" -environment.SZ_ARM64 = "1" +environment._SZ_IS_ARM64 = "1" [[tool.cibuildwheel.overrides]] select = "*-macos*_arm64" inherit.environment = "append" -environment.SZ_ARM64 = "1" +environment._SZ_IS_ARM64 = "1" # Detect MacOS Universal2 builds [[tool.cibuildwheel.overrides]] select = "*-macos*_universal2" inherit.environment = "append" -environment.SZ_X86_64 = "1" -environment.SZ_ARM64 = "1" +environment._SZ_IS_X86_64 = "1" +environment._SZ_IS_ARM64 = "1" [tool.cibuildwheel.macos.environment] MACOSX_DEPLOYMENT_TARGET = "10.11" From a30b5b7f54dd586b703415162707cb3b22e40bd4 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 2 Mar 2025 00:36:26 +0000 Subject: [PATCH 140/751] Improve: Inline most common Rust APIs --- rust/lib.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/rust/lib.rs b/rust/lib.rs index 58553d75..a7f2e0ba 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -153,6 +153,7 @@ pub mod sz { /// # Returns /// /// A `u64` representing the checksum value of the input byte slice. + #[inline(always)] pub fn bytesum(text: T) -> u64 where T: AsRef<[u8]>, @@ -177,6 +178,7 @@ pub mod sz { /// # Returns /// /// A `u64` representing the hash value of the input byte slice. + #[inline(always)] pub fn hash_with_seed(text: T, seed: u64) -> u64 where T: AsRef<[u8]>, @@ -200,6 +202,7 @@ pub mod sz { /// # Returns /// /// A `u64` representing the hash value of the input byte slice. 
+ #[inline(always)] pub fn hash(text: T) -> u64 where T: AsRef<[u8]>, @@ -253,6 +256,7 @@ pub mod sz { /// /// An `Option` representing the starting index of the last occurrence of `needle` /// within `haystack` if found, otherwise `None`. + #[inline(always)] pub fn rfind(haystack: H, needle: N) -> Option where H: AsRef<[u8]>, @@ -286,6 +290,7 @@ pub mod sz { /// /// An `Option` representing the index of the first occurrence of any byte from /// `needles` within `haystack`, if found, otherwise `None`. + #[inline(always)] pub fn find_byte_from(haystack: H, needles: N) -> Option where H: AsRef<[u8]>, From 9a32744b7046c18039e6e6b4c1d33b9832726638 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 2 Mar 2025 00:37:48 +0000 Subject: [PATCH 141/751] Add: Dispatched version API --- c/lib.c | 4 + include/stringzilla/stringzilla.h | 55 ++++++++++++++ python/lib.c | 15 +--- rust/lib.rs | 118 +++++++++++++++++++++++++++++- 4 files changed, 176 insertions(+), 16 deletions(-) diff --git a/c/lib.c b/c/lib.c index 3132a8f5..9c6324dd 100644 --- a/c/lib.c +++ b/c/lib.c @@ -385,9 +385,13 @@ BOOL WINAPI _DllMainCRTStartup(HINSTANCE hints, DWORD forward_reason, LPVOID lp) __attribute__((constructor)) static void sz_dispatch_table_init_on_gcc_or_clang(void) { sz_dispatch_table_init(); } #endif +SZ_DYNAMIC int sz_dynamic_dispatch(void) { return 1; } SZ_DYNAMIC int sz_version_major(void) { return STRINGZILLA_H_VERSION_MAJOR; } SZ_DYNAMIC int sz_version_minor(void) { return STRINGZILLA_H_VERSION_MINOR; } SZ_DYNAMIC int sz_version_patch(void) { return STRINGZILLA_H_VERSION_PATCH; } +SZ_DYNAMIC sz_cptr_t sz_capabilities_to_string(sz_capability_t caps) { + return _sz_capabilities_to_string_implementation(caps); +} SZ_DYNAMIC sz_u64_t sz_bytesum(sz_cptr_t text, sz_size_t length) { return sz_dispatch_table.bytesum(text, length); } diff --git a/include/stringzilla/stringzilla.h b/include/stringzilla/stringzilla.h index c497d4f1..824bacd4 100644 --- a/include/stringzilla/stringzilla.h +++ b/include/stringzilla/stringzilla.h @@ -80,17 +80,72 @@ typedef enum { */ SZ_DYNAMIC sz_capability_t sz_capabilities(void); +/** + * @brief Internal helper function to convert SIMD capabilities to a string. + * @sa sz_capabilities_to_string, sz_capabilities + */ +SZ_INTERNAL sz_cptr_t _sz_capabilities_to_string_implementation(sz_capability_t caps) { + + static char buf[256]; + char *p = buf; + char *const end = buf + sizeof(buf); + + // Mapping each flag to its string literal. + struct { + sz_capability_t flag; + char const *name; + } capability_map[] = { + {sz_cap_serial_k, "serial"}, {sz_cap_haswell_k, "haswell"}, {sz_cap_skylake_k, "skylake"}, + {sz_cap_ice_k, "ice"}, {sz_cap_neon_k, "neon"}, {sz_cap_neon_aes_k, "neon+aes"}, + {sz_cap_sve_k, "sve"}, {sz_cap_sve2_k, "sve2"}, {sz_cap_sve2_aes_k, "sve2+aes"}, + }; + int const capabilities_count = sizeof(capability_map) / sizeof(capability_map[0]); + + // Iterate over each capability flag. + for (int i = 0; i < capabilities_count; i++) { + if (caps & capability_map[i].flag) { + int const is_first = p == buf; + // Add separator if this is not the first capability. + if (!is_first) { + char const sep[3] = {',', ' ', '\0'}; + char const *s = sep; + while (*s && p < end - 1) *p++ = *s++; + } + // Append the capability name character by character. + char const *s = capability_map[i].name; + while (*s && p < end - 1) *p++ = *s++; + } + } + + // If no capability was added, write "none". 
+ int const nothing_detected = p == buf; + if (nothing_detected) { + char const *s = "none"; + while (*s && p < end - 1) *p++ = *s++; + } + + // Null-terminate the string. + *p = '\0'; + return buf; +} + #if defined(SZ_DYNAMIC_DISPATCH) +SZ_DYNAMIC int sz_dynamic_dispatch(void); SZ_DYNAMIC int sz_version_major(void); SZ_DYNAMIC int sz_version_minor(void); SZ_DYNAMIC int sz_version_patch(void); +SZ_DYNAMIC sz_cptr_t sz_capabilities_to_string(sz_capability_t caps); #else +SZ_DYNAMIC int sz_dynamic_dispatch(void) { return 0; } SZ_PUBLIC int sz_version_major(void) { return STRINGZILLA_H_VERSION_MAJOR; } SZ_PUBLIC int sz_version_minor(void) { return STRINGZILLA_H_VERSION_MINOR; } SZ_PUBLIC int sz_version_patch(void) { return STRINGZILLA_H_VERSION_PATCH; } +SZ_PUBLIC sz_cptr_t sz_capabilities_to_string(sz_capability_t caps) { + return _sz_capabilities_to_string_implementation(caps); +} #endif diff --git a/python/lib.c b/python/lib.c index 46ed1c51..cf6ec6fb 100644 --- a/python/lib.c +++ b/python/lib.c @@ -3726,20 +3726,7 @@ PyMODINIT_FUNC PyInit_stringzilla(void) { // Define SIMD capabilities { sz_capability_t caps = sz_capabilities(); - char caps_str[512]; - char const *serial = (caps & sz_cap_serial_k) ? "serial," : ""; - char const *neon = (caps & sz_cap_neon_k) ? "neon," : ""; - char const *neon_aes = (caps & sz_cap_neon_aes_k) ? "neon_aes," : ""; - char const *sve = (caps & sz_cap_sve_k) ? "sve," : ""; - char const *sve2 = (caps & sz_cap_sve2_k) ? "sve2," : ""; - char const *sve2_aes = (caps & sz_cap_sve2_aes_k) ? "sve2_aes," : ""; - char const *haswell = (caps & sz_cap_haswell_k) ? "haswell," : ""; - char const *skylake = (caps & sz_cap_skylake_k) ? "skylake," : ""; - char const *ice = (caps & sz_cap_ice_k) ? "ice," : ""; - sprintf(caps_str, "%s%s%s%s%s%s%s%s%s", // - serial, // - neon, neon_aes, sve, sve2, sve2_aes, // - haswell, skylake, ice); + sz_cptr_t caps_str = sz_capability_to_string(caps); PyModule_AddStringConstant(m, "__capabilities__", caps_str); } diff --git a/rust/lib.rs b/rust/lib.rs index a7f2e0ba..c8dd629f 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -1,7 +1,7 @@ #![cfg_attr(not(test), no_std)] /// The `sz` module provides a collection of string searching and manipulation functionality, -/// designed for high efficiency and compatibility with no_std environments. This module offers +/// designed for high efficiency and compatibility with `no_std` environments. This module offers /// various utilities for byte string manipulation, including search, reverse search, and /// edit-distance calculations, suitable for a wide range of applications from basic string /// processing to complex text analysis tasks. @@ -63,10 +63,19 @@ pub mod sz { } } - use core::{ffi::c_void, usize}; + use core::fmt::{self, Write}; + use core::{ffi::c_void, ffi::CStr, usize}; // Import the functions from the StringZilla C library. extern "C" { + + fn sz_dynamic_dispatch() -> i32; + fn sz_version_major() -> i32; + fn sz_version_minor() -> i32; + fn sz_version_patch() -> i32; + fn sz_capabilities() -> u32; + fn sz_capabilities_to_string(caps: u32) -> *const c_void; + fn sz_find( haystack: *const c_void, haystack_length: usize, @@ -91,6 +100,8 @@ pub mod sz { fn sz_fill_random(text: *mut c_void, length: usize, seed: u64); + // fn sz_sort() -> Status; + pub fn sz_levenshtein_distance( a: *const c_void, a_length: usize, @@ -142,6 +153,103 @@ pub mod sz { } + /// A simple semantic version structure. 
+ #[derive(Debug, Copy, Clone, PartialEq, Eq)] + pub struct SemVer { + pub major: i32, + pub minor: i32, + pub patch: i32, + } + + impl SemVer { + pub const fn new(major: i32, minor: i32, patch: i32) -> Self { + Self { major, minor, patch } + } + } + + /// Checks if the library was compiled with dynamic dispatch enabled. + pub fn dynamic_dispatch() -> bool { + unsafe { sz_dynamic_dispatch() != 0 } + } + + /// Returns the semantic version information. + pub fn version() -> SemVer { + SemVer { + major: unsafe { sz_version_major() }, + minor: unsafe { sz_version_minor() }, + patch: unsafe { sz_version_patch() }, + } + } + + /// A fixed-size, compile-time known C-string buffer type. + /// It keeps track of the number of written bytes (excluding the null terminator). + pub struct FixedCString { + buf: [u8; N], + len: usize, + } + + impl FixedCString { + /// Create a new, empty buffer. + /// The buffer always has a terminating NUL (0) byte at position `len`. + pub const fn new() -> Self { + Self { buf: [0u8; N], len: 0 } + } + + /// Returns the raw pointer to the C string. + pub fn as_ptr(&self) -> *const u8 { + self.buf.as_ptr() + } + + /// Returns a reference as a CStr. + /// # Safety + /// The buffer must be correctly NUL terminated. + pub fn as_c_str(&self) -> &CStr { + // We know buf[..=len] is NUL-terminated because write_str() always sets it. + unsafe { CStr::from_bytes_with_nul_unchecked(&self.buf[..=self.len]) } + } + + /// Returns the current content as a &str. + /// Returns an empty string if the content isn’t valid UTF‑8. + pub fn as_str(&self) -> &str { + core::str::from_utf8(&self.buf[..self.len]).unwrap_or("") + } + } + + impl Write for FixedCString { + fn write_str(&mut self, s: &str) -> fmt::Result { + let bytes = s.as_bytes(); + // Ensure we have room for the new bytes and a NUL terminator. + if self.len + bytes.len() >= N { + return Err(fmt::Error); + } + self.buf[self.len..self.len + bytes.len()].copy_from_slice(bytes); + self.len += bytes.len(); + // Always set a null terminator. + self.buf[self.len] = 0; + Ok(()) + } + } + + pub type SmallCString = FixedCString<256>; + + /// Copies the capabilities C-string into a fixed buffer and returns it. + /// The returned SmallCString is guaranteed to be null-terminated. + pub fn capabilities() -> SmallCString { + let caps = unsafe { sz_capabilities() }; + let caps_ptr = unsafe { sz_capabilities_to_string(caps) }; + // Assume that the external function returns a valid null-terminated C string. + let cstr = unsafe { CStr::from_ptr(caps_ptr as *const i8) }; + let bytes = cstr.to_bytes(); + + let mut buf = SmallCString::new(); + // Use core::fmt::Write to copy the bytes. + // If the string is too long, it will fail. You might want to truncate in a real-world use. + // Here, we assume it fits. + let s = core::str::from_utf8(bytes).unwrap_or(""); + let _ = buf.write_str(s); + buf + } + /// Computes the checksum value of unsigned bytes in a given byte slice `text`. /// This function is useful for verifying data integrity and detecting changes in /// binary data, such as files or network packets. 
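A minimal usage sketch for the metadata helpers introduced above (illustrative only, not part of
the patch; it assumes the crate is published as `stringzilla` and that its `sz` module re-exports
`dynamic_dispatch`, `version`, and `capabilities` as declared in this diff):

    use stringzilla::sz;

    fn print_build_info() {
        // Both calls just forward to the C library through the FFI declarations above.
        let v = sz::version();
        let caps = sz::capabilities();
        if sz::dynamic_dispatch() {
            // Example output: "serial, haswell, skylake, ice" on an Ice Lake machine.
            println!("StringZilla {}.{}.{}: {}", v.major, v.minor, v.patch, caps.as_str());
        }
    }

The `metadata` test added in the next hunk exercises exactly these two entry points.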
@@ -1481,6 +1589,12 @@ mod tests { use crate::sz; // For global functions use crate::StringZilla; // For member functions + #[test] + fn metadata() { + assert!(sz::dynamic_dispatch()); + assert!(sz::capabilities().as_str().len() > 0); + } + #[test] fn hamming() { assert_eq!(sz::hamming_distance("hello", "hello"), Ok(0)); From b2085ccc4ee3cdfbca595f73c2cc9bcfe69e2587 Mon Sep 17 00:00:00 2001 From: Mikayel Grigoryan Date: Sun, 2 Mar 2025 10:43:29 +0400 Subject: [PATCH 142/751] Improve: Exposed sz_move, sz_fill, and sz_copy for Rust --- rust/lib.rs | 55 ++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/rust/lib.rs b/rust/lib.rs index c8dd629f..26a86ae8 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -68,7 +68,9 @@ pub mod sz { // Import the functions from the StringZilla C library. extern "C" { - + fn sz_copy(target: *const c_void, source: *const c_void, length: usize); + fn sz_fill(target: *const c_void, length: usize, value: u8); + fn sz_move(target: *const c_void, source: *const c_void, length: usize); fn sz_dynamic_dispatch() -> i32; fn sz_version_major() -> i32; fn sz_version_minor() -> i32; @@ -273,6 +275,57 @@ pub mod sz { return result; } + /// Moves the contents of `source` into `target`, overwriting the existing contents of `target`. + /// This function is useful for scenarios where you need to replace the contents of a byte slice + /// with the contents of another byte slice. + pub fn move_bytes(target: &mut T, source: &S) + where + T: AsMut<[u8]> + ?Sized, + S: AsRef<[u8]> + ?Sized, + { + let target_slice = target.as_mut(); + let source_slice = source.as_ref(); + unsafe { + sz_move( + target_slice.as_mut_ptr() as *const c_void, + source_slice.as_ptr() as *const c_void, + source_slice.len(), + ); + } + } + + /// Fills the contents of `target` with the specified `value`. This function is useful for + /// scenarios where you need to set all bytes in a byte slice to a specific value, such as + /// zeroing out a buffer or initializing a buffer with a specific byte pattern. + pub fn fill(target: &mut T, value: u8) + where + T: AsMut<[u8]> + ?Sized, + { + let target_slice = target.as_mut(); + unsafe { + sz_fill(target_slice.as_ptr() as *const c_void, target_slice.len(), value); + } + } + + /// Copies the contents of `source` into `target`, overwriting the existing contents of `target`. + /// This function is useful for scenarios where you need to replace the contents of a byte slice + /// with the contents of another byte slice. + pub fn copy(target: &mut T, source: &S) + where + T: AsMut<[u8]> + ?Sized, + S: AsRef<[u8]> + ?Sized, + { + let target_slice = target.as_mut(); + let source_slice = source.as_ref(); + unsafe { + sz_copy( + target_slice.as_mut_ptr() as *mut c_void, + source_slice.as_ptr() as *const c_void, + source_slice.len(), + ); + } + } + /// Computes a 64-bit AES-based hash value for a given byte slice `text`. /// This function is designed to provide a high-quality hash value for use in /// hash tables, data structures, and cryptographic applications. 
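A short illustration of the memory helpers exposed by this patch (a sketch under the assumption
that the `sz` module re-exports `fill`, `copy`, and `move_bytes` with the signatures shown above;
`scrub_and_clone` is a hypothetical caller, not part of the library):

    use stringzilla::sz;

    fn scrub_and_clone(src: &[u8]) -> Vec<u8> {
        // The target must be at least as long as the source for `copy`.
        let mut dst = vec![0u8; src.len()];
        sz::fill(&mut dst, 0xAA); // set every byte to a known pattern
        sz::copy(&mut dst, src);  // then overwrite it with the source bytes
        dst
    }

`move_bytes` follows the same shape but, like `memmove`, is the variant intended for overlapping
ranges.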
From 1757e4ec50d9a177cd5f64d96337d51cbe28d4fc Mon Sep 17 00:00:00 2001 From: Mikayel Grigoryan Date: Sun, 2 Mar 2025 11:08:59 +0400 Subject: [PATCH 143/751] Improve: Expose sz_hash_state_init, sz_hash_state_stream, and sz_hash_state_fold to Rust --- rust/lib.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/rust/lib.rs b/rust/lib.rs index 26a86ae8..784c5454 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -153,6 +153,14 @@ pub mod sz { result: *mut isize, ) -> Status; + /// Initializes a hash state with a given seed value. + fn sz_hash_state_init(state: *const c_void, seed: u64); + + /// Updates the hash state with a new byte slice. + fn sz_hash_state_stream(state: *const c_void, text: *const c_void, length: usize); + + /// Finalizes the hash state and returns the computed hash value. + fn sz_hash_state_fold(state: *const c_void) -> u64; } /// A simple semantic version structure. From 471b0024db2b4c183f3277ca08f59a2af89eec24 Mon Sep 17 00:00:00 2001 From: Mikayel Grigoryan Date: Sun, 2 Mar 2025 11:14:19 +0400 Subject: [PATCH 144/751] Improve: Expose sz_lookup --- rust/lib.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/rust/lib.rs b/rust/lib.rs index 784c5454..14f10c7e 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -161,6 +161,8 @@ pub mod sz { /// Finalizes the hash state and returns the computed hash value. fn sz_hash_state_fold(state: *const c_void) -> u64; + + fn sz_lookup(target: *const c_void, length: usize, source: *const c_void, lut: *const u8) -> *const c_void; } /// A simple semantic version structure. From 9fe25df8abbe428fc7b2255dce19e96acd3d444f Mon Sep 17 00:00:00 2001 From: Mikayel Grigoryan Date: Sun, 2 Mar 2025 11:26:45 +0400 Subject: [PATCH 145/751] Improve: Remove redundant comments from sz_hash_state functions in Rust --- rust/lib.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/rust/lib.rs b/rust/lib.rs index 14f10c7e..b3cbca4e 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -153,13 +153,10 @@ pub mod sz { result: *mut isize, ) -> Status; - /// Initializes a hash state with a given seed value. fn sz_hash_state_init(state: *const c_void, seed: u64); - /// Updates the hash state with a new byte slice. fn sz_hash_state_stream(state: *const c_void, text: *const c_void, length: usize); - /// Finalizes the hash state and returns the computed hash value. 
fn sz_hash_state_fold(state: *const c_void) -> u64; fn sz_lookup(target: *const c_void, length: usize, source: *const c_void, lut: *const u8) -> *const c_void; From c7b841e6eba0ec0600634daaf9b788f8916c425f Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Tue, 4 Mar 2025 07:35:09 +0000 Subject: [PATCH 146/751] Add: Serial JOINs --- include/stringzilla/sort.h | 78 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 74 insertions(+), 4 deletions(-) diff --git a/include/stringzilla/sort.h b/include/stringzilla/sort.h index 721ba940..d619369e 100644 --- a/include/stringzilla/sort.h +++ b/include/stringzilla/sort.h @@ -847,11 +847,81 @@ SZ_PUBLIC sz_status_t sz_pgrams_join_serial(sz_pgram_t *pgrams, sz_size_t count, return sz_success_k; } -SZ_PUBLIC sz_status_t sz_sequence_join_serial( // - sz_sequence_t const *first_sequence, sz_sequence_t const *second_sequence, // - sz_memory_allocator_t *alloc, sz_size_t *intersection_size, // +SZ_PUBLIC sz_status_t sz_sequence_join_serial( // + sz_sequence_t const *first_sequence, sz_sequence_t const *second_sequence, // + sz_memory_allocator_t *alloc, sz_u64_t seed, sz_size_t *intersection_count_ptr, // sz_sorted_idx_t *first_positions, sz_sorted_idx_t *second_positions) { - sz_unused(first_sequence && second_sequence && alloc && intersection_size && first_positions && second_positions); + + // To join to unordered sets of strings, the simplest approach would be to hash them into a dynamically + // allocated hash table and then iterate over the second set, checking for the presence of each element in the + // hash table. This would require O(N) memory and O(N) time complexity, where N is the smaller set. + sz_sequence_t const *small_sequence, *large_sequence; + sz_sorted_idx_t *small_positions, *large_positions; + if (first_sequence->count <= second_sequence->count) { + small_sequence = first_sequence, large_sequence = second_sequence; + small_positions = first_positions, large_positions = second_positions; + } + else { + small_sequence = second_sequence, large_sequence = first_sequence; + small_positions = second_positions, large_positions = first_positions; + } + + // We may very well have nothing to join + if (small_sequence->count == 0) { + *intersection_count_ptr = 0; + return sz_success_k; + } + + // Allocate memory for the hash table and initialize it with 0xFF. + sz_size_t const hash_table_slots = sz_size_bit_ceil(small_sequence->count * 2); + sz_size_t const bytes_per_entry = sizeof(sz_size_t) + sizeof(sz_u64_t); + sz_size_t *table_positions = (sz_size_t *)alloc->allocate(hash_table_slots * bytes_per_entry, alloc); + if (!table_positions) return sz_bad_alloc_k; + sz_u64_t *table_fingerprints = (sz_u64_t *)(table_positions + hash_table_slots); + sz_fill((sz_ptr_t)table_positions, hash_table_slots * bytes_per_entry, 0xFF); + + // Hash the smaller set into the hash table using the default available backend. + for (sz_size_t small_position = 0; small_position < small_sequence->count; ++small_position) { + sz_cptr_t const str = small_sequence->get_start(small_sequence->handle, small_position); + sz_size_t const length = small_sequence->get_length(small_sequence->handle, small_position); + sz_u64_t const hash = sz_hash(str, length, seed); + sz_size_t hash_slot = hash; + // Implement linear probing to resolve collisions. 
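+        // E.g. if two strings hash to the same slot, the second insertion keeps
+        // incrementing `hash_slot`; the `& (hash_table_slots - 1)` mask on every
+        // access wraps the probe around the power-of-two table until an empty
+        // slot (position still 0xFF...FF) is found.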
+ while (table_positions[hash_slot & (hash_table_slots - 1)] != SZ_SIZE_MAX) ++hash_slot; + table_positions[hash_slot & (hash_table_slots - 1)] = small_position; + table_fingerprints[hash_slot & (hash_table_slots - 1)] = hash; + } + + // Iterate over the larger set and check for the presence of each element in the hash table. + sz_size_t intersection_count = 0; + for (sz_size_t large_position = 0; large_position < large_sequence->count; ++large_position) { + sz_cptr_t const str = large_sequence->get_start(large_sequence->handle, large_position); + sz_size_t const length = large_sequence->get_length(large_sequence->handle, large_position); + sz_u64_t const hash = sz_hash(str, length, seed); + sz_size_t hash_slot = hash; + // Implement linear probing to resolve collisions. + for (; table_positions[hash_slot & (hash_table_slots - 1)] != SZ_SIZE_MAX; ++hash_slot) { + sz_u64_t small_hash = table_fingerprints[hash_slot & (hash_table_slots - 1)]; + if (small_hash != hash) continue; + + // The hash matches, compare the strings. + sz_size_t const small_position = table_positions[hash_slot & (hash_table_slots - 1)]; + sz_size_t const small_length = small_sequence->get_length(small_sequence->handle, small_position); + if (length != small_length) continue; + + sz_cptr_t const small_str = small_sequence->get_start(small_sequence->handle, small_position); + sz_bool_t const same = sz_equal(str, small_str, length); + if (same != sz_true_k) continue; + + // Finally, there is a match, store the positions. + small_positions[intersection_count] = small_position; + large_positions[intersection_count] = large_position; + ++intersection_count; + break; + } + } + + *intersection_count_ptr = intersection_count; return sz_success_k; } From 75fabf1ba4c241790b0ecfd69c22c7c0a821ec4f Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Tue, 4 Mar 2025 07:36:02 +0000 Subject: [PATCH 147/751] Fix: Passing `sz_sequence_t::handle` --- include/stringzilla/sort.h | 20 ++++++++++---------- scripts/bench_sort.cpp | 8 ++++---- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/include/stringzilla/sort.h b/include/stringzilla/sort.h index d619369e..16ca31fe 100644 --- a/include/stringzilla/sort.h +++ b/include/stringzilla/sort.h @@ -264,10 +264,10 @@ SZ_PUBLIC void sz_sequence_argsort_with_insertion(sz_sequence_t const *sequence, while (j > 0) { // Get the two strings to compare. sz_sorted_idx_t previous_idx = order[j - 1]; - sz_cptr_t previous_start = sequence->get_start(sequence, previous_idx); - sz_cptr_t current_start = sequence->get_start(sequence, current_idx); - sz_size_t previous_length = sequence->get_length(sequence, previous_idx); - sz_size_t current_length = sequence->get_length(sequence, current_idx); + sz_cptr_t previous_start = sequence->get_start(sequence->handle, previous_idx); + sz_cptr_t current_start = sequence->get_start(sequence->handle, current_idx); + sz_size_t previous_length = sequence->get_length(sequence->handle, previous_idx); + sz_size_t current_length = sequence->get_length(sequence->handle, current_idx); // Use the provided sz_order to compare. sz_ordering_t ordering = sz_order(previous_start, previous_length, current_start, current_length); @@ -470,8 +470,8 @@ SZ_INTERNAL void _sz_sequence_argsort_serial_export_next_pgrams( _sz_assert(partial_order_index == i && "At start this must be an identity permutation."); // Get the string slice in global memory. 
- sz_cptr_t const source_str = sequence->get_start(sequence, partial_order_index); - sz_size_t const length = sequence->get_length(sequence, partial_order_index); + sz_cptr_t const source_str = sequence->get_start(sequence->handle, partial_order_index); + sz_size_t const length = sequence->get_length(sequence->handle, partial_order_index); sz_size_t const remaining_length = length > start_character ? length - start_character : 0; sz_size_t const exported_length = remaining_length > pgram_capacity ? pgram_capacity : remaining_length; @@ -497,10 +497,10 @@ SZ_INTERNAL void _sz_sequence_argsort_serial_export_next_pgrams( for (sz_size_t i = start_in_sequence + 1; i < end_in_sequence; ++i) { sz_pgram_t const previous_pgram = global_pgrams[i - 1]; sz_pgram_t const current_pgram = global_pgrams[i]; - sz_cptr_t const previous_str = sequence->get_start(sequence, i - 1); - sz_size_t const previous_length = sequence->get_length(sequence, i - 1); - sz_cptr_t const current_str = sequence->get_start(sequence, i); - sz_size_t const current_length = sequence->get_length(sequence, i); + sz_cptr_t const previous_str = sequence->get_start(sequence->handle, i - 1); + sz_size_t const previous_length = sequence->get_length(sequence->handle, i - 1); + sz_cptr_t const current_str = sequence->get_start(sequence->handle, i); + sz_size_t const current_length = sequence->get_length(sequence->handle, i); sz_ordering_t const ordering = sz_order( // previous_str, previous_length > pgram_capacity ? pgram_capacity : previous_length, // current_str, current_length > pgram_capacity ? pgram_capacity : current_length); diff --git a/scripts/bench_sort.cpp b/scripts/bench_sort.cpp index a045192f..f32d9909 100644 --- a/scripts/bench_sort.cpp +++ b/scripts/bench_sort.cpp @@ -42,10 +42,10 @@ static int _get_qsort_order(const void *a, const void *b, void *arg) { sz_size_t idx_a = *(sz_size_t *)a; sz_size_t idx_b = *(sz_size_t *)b; - char const *str_a = sequence->get_start(sequence, idx_a); - char const *str_b = sequence->get_start(sequence, idx_b); - sz_size_t len_a = sequence->get_length(sequence, idx_a); - sz_size_t len_b = sequence->get_length(sequence, idx_b); + char const *str_a = sequence->get_start(sequence->handle, idx_a); + char const *str_b = sequence->get_start(sequence->handle, idx_b); + sz_size_t len_a = sequence->get_length(sequence->handle, idx_a); + sz_size_t len_b = sequence->get_length(sequence->handle, idx_b); int res = strncmp(str_a, str_b, len_a < len_b ? len_a : len_b); return res ? res : (int)(len_a - len_b); From ea5dc76c2ff9f8b09913a271c4e1fe3696ca817b Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Fri, 7 Mar 2025 12:26:17 +0000 Subject: [PATCH 148/751] Add: Intersections on Ice Lake --- include/stringzilla/intersect.h | 749 ++++++++++++++++++++++++++++++ include/stringzilla/sort.h | 305 +----------- include/stringzilla/stringzilla.h | 3 +- 3 files changed, 756 insertions(+), 301 deletions(-) create mode 100644 include/stringzilla/intersect.h diff --git a/include/stringzilla/intersect.h b/include/stringzilla/intersect.h new file mode 100644 index 00000000..77033148 --- /dev/null +++ b/include/stringzilla/intersect.h @@ -0,0 +1,749 @@ +/** + * @brief Hardware-accelerated string collection intersections for JOIN-like DBMS operations. 
+ * @file intersect.h + * @author Ash Vardanian + * + * Includes core APIs for `sz_sequence_t` string collections with hardware-specific backends: + * + * - `sz_sequence_intersection` - to compute the strict intersection of two deduplicated string collections. + * - TODO: `sz_sequence_join` - to compute the intersection of two arbitrary string collections. + */ +#ifndef STRINGZILLA_INTERSECT_H_ +#define STRINGZILLA_INTERSECT_H_ + +#include "types.h" + +#include "compare.h" // `sz_compare` +#include "memory.h" // `sz_fill` +#include "hash.h" // `sz_hash` + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @brief The @b power-of-two memory-usage budget @b multiple for the hash table. + * + * The behaviour of hashing-based approaches can often be tuned with different "hyper-parameter" values. + * For "unordered set intersections" implemented here, the @p budget argument controls the balance between + * throughput and memory usage. The higher the budget, the more memory is used, but the fewer collisions + * will be observed + */ +#if !defined(SZ_SEQUENCE_INTERSECT_BUDGET) +#define SZ_SEQUENCE_INTERSECT_BUDGET (1) +#endif + +#pragma region Core API + +/** + * @brief Intersects two @b deduplicated binary @b string sequences, using a hash table. + * Outputs the @p first_positions from the @p first_sequence and @p second_positions from + * the @p second_sequence, that contain matched strings. Missing matches are represented as `SZ_SIZE_MAX`. + * + * @param[in] first_sequence First immutable sequence of strings to intersection. + * @param[in] second_sequence Second immutable sequence of strings to intersection. + * @param[in] semantics JOIN semantics for the intersection, including handling of duplicates. + * @param[in] alloc Optional memory allocator for temporary storage. + * @param[in] seed Optional seed for the hash table to avoid attacks. + * @param[out] intersection_size Number of matching strings in both sequences. + * @param[out] first_positions Offset positions of the matching strings from the @p first_sequence. + * @param[out] second_positions Offset positions of the matching strings from the @p second_sequence. + * + * @retval `sz_success_k` if the operation was successful. + * @retval `sz_bad_alloc_k` if the operation failed due to memory allocation failure. + * @pre The @p first_positions array must fit at least `min(first_sequence->count, second_sequence->count)` items. + * @pre The @p second_positions array must fit at least `min(first_sequence->count, second_sequence->count)` items. + * @warning Doesn't check for duplicates and won't return `sz_contains_duplicates_k`. Duplicates result in UB. + * + * Example usage: + * + * @code{.c} + * #include + * int main() { + * char const *first[] = {"banana", "apple", "cherry"}; + * char const *second[] = {"cherry", "orange", "pineapple", "banana"}; + * sz_sequence_t first_sequence, second_sequence; + * sz_sequence_from_null_terminated_strings(first, 3, &first_sequence); + * sz_sequence_from_null_terminated_strings(second, 4, &second_sequence); + * sz_size_t intersection_size; + * sz_sorted_idx_t first_positions[3], second_positions[3]; //? 3 is the size of the smaller sequence + * sz_status_t status = sz_sequence_intersect(&first_sequence, &second_sequence, + * sz_join_inner_strict_k, NULL, 0, + * &intersection_size, first_positions, second_positions); + * return status == sz_success_k && intersection_size == 2 ? 0 : 1; + * } + * @endcode + * + * @note The algorithm has linear memory complexity and linear time complexity. 
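+ * @note The scratch memory is a power-of-two hash table sized after the smaller sequence:
+ *       roughly `sz_size_bit_ceil(min(count)) << SZ_SEQUENCE_INTERSECT_BUDGET` slots of
+ *       16 bytes each on 64-bit targets, e.g. ~32 MB for a 1M-string smaller side with
+ *       the default budget of 1.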
+ * @see https://en.wikipedia.org/wiki/Join_(SQL) + * + * @note Selects the fastest implementation at compile- or run-time based on `SZ_DYNAMIC_DISPATCH`. + * @sa sz_sequence_intersect_serial, sz_sequence_intersect_ice, sz_sequence_intersect_sve + */ +SZ_DYNAMIC sz_status_t sz_sequence_intersect(sz_sequence_t const *first_sequence, sz_sequence_t const *second_sequence, + sz_memory_allocator_t *alloc, sz_u64_t seed, sz_size_t *intersection_size, + sz_sorted_idx_t *first_positions, sz_sorted_idx_t *second_positions); + +/** + * @brief Defines various JOIN semantics for string sequences, including handling of duplicates. + * @sa sz_join_inner_strict_k, sz_join_inner_k, sz_join_left_outer_k, sz_join_right_outer_k, sz_join_full_outer_k, + * sz_join_cross_k + */ +typedef enum { + /** + * @brief Strict inner join with uniqueness enforcement. + * + * In this mode, only unique matching strings from both sequences are returned. + * If either sequence contains duplicate strings, the operation will fail. + * + * Example: + * - Input: + * first_sequence: { "apple", "banana", "cherry" } + * second_sequence: { "banana", "cherry", "date" } + * - Output: + * Result: { ("banana", "banana"), ("cherry", "cherry") } + * + * SQL equivalent: + * @code{.sql} + * -- Returns unique matching rows only. + * SELECT DISTINCT a.* + * FROM first_sequence a + * INNER JOIN second_sequence b ON a.string = b.string; + * @endcode + */ + sz_join_inner_strict_k = 0, + + /** + * @brief Conventional inner join allowing duplicate entries. + * + * This mode returns all pairs of matching strings from both sequences. + * Each occurrence in the first sequence is paired with every matching occurrence + * in the second sequence. Order stability is not guaranteed. + * + * Example: + * - Input: + * first_sequence: { "apple", "banana", "banana" } + * second_sequence: { "banana", "banana", "cherry" } + * - Output: + * Result: { ("banana", "banana"), ("banana", "banana"), + * ("banana", "banana"), ("banana", "banana") } + * (2 occurrences of "banana" in the first sequence × 2 in the second = 4 pairs) + * + * SQL equivalent: + * @code{.sql} + * SELECT a.*, b.* + * FROM first_sequence a + * INNER JOIN second_sequence b ON a.string = b.string; + * @endcode + */ + sz_join_inner_k = 1, + + /** + * @brief Left outer join preserving all entries from the first sequence. + * + * This mode returns every string from the first sequence along with matching strings + * from the second sequence. If no match is found for an element in the first sequence, + * the corresponding output for the second sequence is NULL (or its equivalent). + * + * Example: + * - Input: + * first_sequence: { "apple", "banana", "cherry" } + * second_sequence: { "banana", "cherry", "date" } + * - Output: + * Result: { ("apple", NULL), ("banana", "banana"), ("cherry", "cherry") } + * + * SQL equivalent: + * @code{.sql} + * SELECT a.*, b.* + * FROM first_sequence a + * LEFT OUTER JOIN second_sequence b ON a.string = b.string; + * @endcode + */ + sz_join_left_outer_k = 2, + + /** + * @brief Right outer join preserving all entries from the second sequence. + * + * This mode returns every string from the second sequence along with matching strings + * from the first sequence. If no match is found for an element in the second sequence, + * the corresponding output for the first sequence is NULL (or its equivalent). 
+ * + * Example: + * - Input: + * first_sequence: { "apple", "banana" } + * second_sequence: { "banana", "cherry", "date" } + * - Output: + * Result: { ("banana", "banana"), (NULL, "cherry"), (NULL, "date") } + * + * SQL equivalent: + * @code{.sql} + * SELECT a.*, b.* + * FROM first_sequence a + * RIGHT OUTER JOIN second_sequence b ON a.string = b.string; + * @endcode + */ + sz_join_right_outer_k = 3, + + /** + * @brief Full outer join combining all entries from both sequences. + * + * This mode returns all matching pairs along with unmatched strings from both sequences. + * For unmatched strings, the corresponding result from the other sequence is NULL. + * + * Example: + * - Input: + * first_sequence: { "apple", "banana" } + * second_sequence: { "banana", "cherry" } + * - Output: + * Result: { ("apple", NULL), ("banana", "banana"), (NULL, "cherry") } + * + * SQL equivalent: + * @code{.sql} + * SELECT a.*, b.* + * FROM first_sequence a + * FULL OUTER JOIN second_sequence b ON a.string = b.string; + * @endcode + */ + sz_join_full_outer_k = 4, + + /** + * @brief Cross join (Cartesian product) of two sequences. + * + * This mode returns the Cartesian product of both sequences, pairing every string in the first sequence + * with every string in the second sequence regardless of any matching condition. + * + * Example: + * - Input: + * first_sequence: { "apple", "banana" } + * second_sequence: { "cherry", "date" } + * - Output: + * Result: { ("apple", "cherry"), ("apple", "date"), + * ("banana", "cherry"), ("banana", "date") } + * + * SQL equivalent: + * @code{.sql} + * SELECT a.*, b.* + * FROM first_sequence a, second_sequence b; + * @endcode + */ + sz_join_cross_k = 5, +} sz_sequence_join_semantics_t; + +#if SZ_USE_ICE + +/** @copydoc sz_sequence_intersect */ +SZ_PUBLIC sz_status_t sz_sequence_intersect_ice( // + sz_sequence_t const *first_sequence, sz_sequence_t const *second_sequence, // + sz_memory_allocator_t *alloc, sz_u64_t seed, sz_size_t *intersection_size, // + sz_sorted_idx_t *first_positions, sz_sorted_idx_t *second_positions); + +#endif + +#if SZ_USE_SVE + +/** @copydoc sz_sequence_intersect */ +SZ_PUBLIC sz_status_t sz_sequence_intersect_sve( // + sz_sequence_t const *first_sequence, sz_sequence_t const *second_sequence, // + sz_memory_allocator_t *alloc, sz_u64_t seed, sz_size_t *intersection_size, // + sz_sorted_idx_t *first_positions, sz_sorted_idx_t *second_positions); + +#endif + +#pragma endregion + +#pragma region Serial Implementation + +SZ_PUBLIC sz_status_t sz_sequence_intersect_serial( // + sz_sequence_t const *first_sequence, sz_sequence_t const *second_sequence, // + sz_memory_allocator_t *alloc, sz_u64_t seed, sz_size_t *intersection_count_ptr, // + sz_sorted_idx_t *first_positions, sz_sorted_idx_t *second_positions) { + + // To join to unordered sets of strings, the simplest approach would be to hash them into a dynamically + // allocated hash table and then iterate over the second set, checking for the presence of each element in the + // hash table. This would require O(N) memory and O(N) time complexity, where N is the smaller set. 
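+    // Hashing the smaller side keeps the scratch table compact and cache-resident:
+    // in a 1-thousand vs 1-million intersection the table holds only ~2048 slots,
+    // and every string of the million-element side merely probes it.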
+ sz_sequence_t const *small_sequence, *large_sequence; + sz_sorted_idx_t *small_positions, *large_positions; + if (first_sequence->count <= second_sequence->count) { + small_sequence = first_sequence, large_sequence = second_sequence; + small_positions = first_positions, large_positions = second_positions; + } + else { + small_sequence = second_sequence, large_sequence = first_sequence; + small_positions = second_positions, large_positions = first_positions; + } + + // We may very well have nothing to join + if (small_sequence->count == 0) { + *intersection_count_ptr = 0; + return sz_success_k; + } + + // Allocate memory for the hash table and initialize it with 0xFF. + // The higher is the `hash_table_slots` multiple - the more memory we will use, + // but the less likely the collisions will be. + sz_size_t const hash_table_slots = sz_size_bit_ceil(small_sequence->count) * (1 << SZ_SEQUENCE_INTERSECT_BUDGET); + sz_size_t const bytes_per_entry = sizeof(sz_size_t) + sizeof(sz_u64_t); + sz_size_t *const table_positions = (sz_size_t *)alloc->allocate(hash_table_slots * bytes_per_entry, alloc); + if (!table_positions) return sz_bad_alloc_k; + sz_u64_t *const table_hashes = (sz_u64_t *)(table_positions + hash_table_slots); + sz_fill((sz_ptr_t)table_positions, hash_table_slots * bytes_per_entry, 0xFF); + + // Hash the smaller set into the hash table using the default available backend. + for (sz_size_t small_position = 0; small_position < small_sequence->count; ++small_position) { + sz_cptr_t const str = small_sequence->get_start(small_sequence->handle, small_position); + sz_size_t const length = small_sequence->get_length(small_sequence->handle, small_position); + sz_u64_t const hash = sz_hash(str, length, seed); + sz_size_t hash_slot = hash & (hash_table_slots - 1); + // Implement linear probing to find the first free slot. + // If we somehow face 2 different strings with same hash, we will export that hash 2 times! + while (table_hashes[hash_slot] != SZ_SIZE_MAX) hash_slot = (hash_slot + 1) & (hash_table_slots - 1); + table_hashes[hash_slot] = hash; + table_positions[hash_slot] = small_position; + } + + // Iterate over the larger set and check for the presence of each element in the hash table. + sz_size_t intersection_count = 0; + for (sz_size_t large_position = 0; large_position < large_sequence->count; ++large_position) { + sz_cptr_t const str = large_sequence->get_start(large_sequence->handle, large_position); + sz_size_t const length = large_sequence->get_length(large_sequence->handle, large_position); + sz_u64_t const hash = sz_hash(str, length, seed); + sz_size_t hash_slot = hash & (hash_table_slots - 1); + + // Implement linear probing to resolve collisions. + for (; table_hashes[hash_slot] != SZ_SIZE_MAX; hash_slot = (hash_slot + 1) & (hash_table_slots - 1)) { + sz_u64_t small_hash = table_hashes[hash_slot]; + if (small_hash != hash) continue; + + // The hash matches, compare the strings. + sz_size_t const small_position = table_positions[hash_slot]; + sz_size_t const small_length = small_sequence->get_length(small_sequence->handle, small_position); + if (length != small_length) continue; + + // Same hash may still imply different strings, so we need to compare them. + sz_cptr_t const small_str = small_sequence->get_start(small_sequence->handle, small_position); + sz_bool_t const same = sz_equal(str, small_str, length); + if (same != sz_true_k) continue; + + // Finally, there is a match, store the positions. 
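+            // The inputs are assumed deduplicated, so each probe chain yields at most
+            // one match, and we can stop probing right after recording it (`break` below).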
+ small_positions[intersection_count] = small_position; + large_positions[intersection_count] = large_position; + ++intersection_count; + break; + } + } + + *intersection_count_ptr = intersection_count; + return sz_success_k; +} + +#pragma endregion // Serial Implementation + +/* AVX512 implementation of the string search algorithms for Ice Lake and newer CPUs. + * Includes extensions: + * - 2017 Skylake: F, CD, ER, PF, VL, DQ, BW, + * - 2018 CannonLake: IFMA, VBMI, + * - 2019 Ice Lake: VPOPCNTDQ, VNNI, VBMI2, BITALG, GFNI, VPCLMULQDQ, VAES. + * + * We are going to use VBMI2 for `_mm256_maskz_compress_epi8`. + */ +#pragma region Ice Lake Implementation +#if SZ_USE_ICE +#pragma GCC push_options +#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vbmi", "avx512vnni", "bmi", "bmi2", \ + "aes", "vaes") +#pragma clang attribute push( \ + __attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512dq,avx512vbmi,avx512vnni,bmi,bmi2,aes,vaes"))), \ + apply_to = function) + +SZ_INTERNAL int _sz_u64x4_contains_collisions_haswell(__m256i v) { + // Assume `v` stores values: [a, b, c, d]. + __m256i cmp1 = _mm256_cmpeq_epi64(v, _mm256_permute4x64_epi64(v, 0xB1)); // 0xB1 produces [b, a, d, c] + __m256i cmp2 = _mm256_cmpeq_epi64(v, _mm256_permute4x64_epi64(v, 0x4E)); // 0x4E produces [c, d, a, b] + __m256i cmp3 = _mm256_cmpeq_epi64(v, _mm256_permute4x64_epi64(v, 0x1B)); // 0x1B produces [d, c, b, a] + + // Combine the results from the three comparisons. + __m256i cmp = _mm256_or_si256(_mm256_or_si256(cmp1, cmp2), cmp3); + + // Each 64-bit lane comparison yields all ones if equal, so the movemask will be nonzero if any pair matched. + int mask = _mm256_movemask_epi8(cmp); + return mask; +} + +SZ_PUBLIC sz_status_t sz_sequence_intersect_ice( // + sz_sequence_t const *first_sequence, sz_sequence_t const *second_sequence, // + sz_memory_allocator_t *alloc, sz_u64_t seed, sz_size_t *intersection_count_ptr, // + sz_sorted_idx_t *first_positions, sz_sorted_idx_t *second_positions) { + + // To join to unordered sets of strings, the simplest approach would be to hash them into a dynamically + // allocated hash table and then iterate over the second set, checking for the presence of each element in the + // hash table. This would require O(N) memory and O(N) time complexity, where N is the smaller set. + sz_sequence_t const *small_sequence, *large_sequence; + sz_sorted_idx_t *small_positions, *large_positions; + if (first_sequence->count <= second_sequence->count) { + small_sequence = first_sequence, large_sequence = second_sequence; + small_positions = first_positions, large_positions = second_positions; + } + else { + small_sequence = second_sequence, large_sequence = first_sequence; + small_positions = second_positions, large_positions = first_positions; + } + + // We may very well have nothing to join + if (small_sequence->count == 0) { + *intersection_count_ptr = 0; + return sz_success_k; + } + + // Allocate memory for the hash table and initialize it with 0xFF. + // The higher is the `hash_table_slots` multiple - the more memory we will use, + // but the less likely the collisions will be. 
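+    // E.g. a small set of 1000 strings with the default budget of 1 gives
+    // bit_ceil(1000) * 2 = 2048 slots of 16 bytes each - a 32 KB table that
+    // fits comfortably in the L2 cache.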
+ sz_size_t const hash_table_slots = sz_size_bit_ceil(small_sequence->count) * (1 << SZ_SEQUENCE_INTERSECT_BUDGET); + sz_size_t const bytes_per_entry = sizeof(sz_size_t) + sizeof(sz_u64_t); + sz_size_t *table_positions = (sz_size_t *)alloc->allocate(hash_table_slots * bytes_per_entry, alloc); + if (!table_positions) return sz_bad_alloc_k; + sz_u64_t *table_hashes = (sz_u64_t *)(table_positions + hash_table_slots); + sz_fill((sz_ptr_t)table_positions, hash_table_slots * bytes_per_entry, 0xFF); + + // Conceptually the Ice Lake variant is similar to the serial one, except it takes advantage of: + // - computing 4x individual high-quality hashes with `_mm512_aesenc_epi128`. + // - gathering values from the hash-table using `_mm256_mmask_i64gather_epi64`. + // + // We still start by hashing the smaller set into the hash table, but we will process 4 entries + // at a time and will separately handle values under 16 bytes fitting into one AES block and the + // larger values. + // + // For larger entries, we will use a separate loop afterwards to decrease the likelihood of collisions + // on the shorter entries, that can benefit from vectorized processing. + _sz_hash_minimal_x4_t batch_hashes_states_initial; + _sz_hash_minimal_x4_init_ice(&batch_hashes_states_initial, seed); + sz_size_t count_longer = 0; + for (sz_size_t small_position = 0; small_position < small_sequence->count;) { + sz_string_view_t batch[4]; + sz_u256_vec_t batch_positions; + sz_size_t batch_size; + for (batch_size = 0; batch_size < 4 && small_position < small_sequence->count; ++small_position) { + sz_size_t length = small_sequence->get_length(small_sequence->handle, small_position); + if (length > 16) { + count_longer++; + continue; + } + sz_cptr_t str = small_sequence->get_start(small_sequence->handle, small_position); + batch[batch_size].start = str; + batch[batch_size].length = length; + batch_positions.u64s[batch_size] = small_position; + ++batch_size; + } + + // If we couldn't populate the whole batch, fall back to the serial solution + if (batch_size != 4) { + for (sz_size_t i = 0; i < batch_size; ++i) { + sz_cptr_t const str = batch[i].start; + sz_size_t const length = batch[i].length; + sz_u64_t const hash = sz_hash(str, length, seed); + sz_size_t hash_slot = hash & (hash_table_slots - 1); + // Implement linear probing to find the first free slot. + // If we somehow face 2 different strings with same hash, we will export that hash 2 times! + while (table_hashes[hash_slot] != SZ_SIZE_MAX) hash_slot = (hash_slot + 1) & (hash_table_slots - 1); + table_hashes[hash_slot] = hash; + table_positions[hash_slot] = batch_positions.u64s[i]; + } + } + // The batch is successfully populated, let's use the vectorized solution + else { + // Now let's load the first bytes of each string. 
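+            // All strings in the batch are at most 16 bytes, and the masked loads below
+            // only enable the first `length` byte lanes of each register, so bytes past
+            // the end of a short string are never faulted on.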
+ sz_u256_vec_t batch_hashes; + sz_u512_vec_t batch_prefixes; + batch_prefixes.xmms[0] = _mm_maskz_loadu_epi8(_sz_u16_mask_until(batch[0].length), batch[0].start); + batch_prefixes.xmms[1] = _mm_maskz_loadu_epi8(_sz_u16_mask_until(batch[1].length), batch[1].start); + batch_prefixes.xmms[2] = _mm_maskz_loadu_epi8(_sz_u16_mask_until(batch[2].length), batch[2].start); + batch_prefixes.xmms[3] = _mm_maskz_loadu_epi8(_sz_u16_mask_until(batch[3].length), batch[3].start); + + // Reuse the already computed state for hashes + _sz_hash_minimal_x4_t batch_hashes_states = batch_hashes_states_initial; + _sz_hash_minimal_x4_update_ice(&batch_hashes_states, batch_prefixes.zmm); + batch_hashes.ymm = _sz_hash_minimal_x4_finalize_ice(&batch_hashes_states, batch[0].length, batch[1].length, + batch[2].length, batch[3].length); + _sz_assert(batch_hashes.u64s[0] == sz_hash(batch[0].start, batch[0].length, seed)); + _sz_assert(batch_hashes.u64s[1] == sz_hash(batch[1].start, batch[1].length, seed)); + _sz_assert(batch_hashes.u64s[2] == sz_hash(batch[2].start, batch[2].length, seed)); + _sz_assert(batch_hashes.u64s[3] == sz_hash(batch[3].start, batch[3].length, seed)); + + // Now let's perform an optimistic hash-table lookup using vectorized gathers + sz_u256_vec_t batch_slots, existing_hashes; + batch_slots.ymm = _mm256_and_si256(batch_hashes.ymm, _mm256_set1_epi64x(hash_table_slots - 1)); + + // In case of very small inputs, it's more likely, that some of the 4x hashes or their slots will collide + int const has_slot_collisions = _sz_u64x4_contains_collisions_haswell(batch_slots.ymm); + + // Before scattering the new positions - gather the pre-existing ones. + // In case of `has_slot_collisions`, this will practically be a "prefetch" operation. + existing_hashes.ymm = + _mm256_mmask_i64gather_epi64(_mm256_setzero_si256(), 0xFF, batch_slots.ymm, table_hashes, 8); + + // Check that we don't have any collisions - in that case each value will be equal to `SZ_SIZE_MAX` + int const all_empty = _mm256_testc_si256(existing_hashes.ymm, _mm256_set1_epi64x(SZ_SIZE_MAX)); + if (all_empty && !has_slot_collisions) { + // Scatter the new positions + _mm256_mask_i64scatter_epi64(table_hashes, 0xFF, batch_slots.ymm, batch_hashes.ymm, 8); + _mm256_mask_i64scatter_epi64(table_positions, 0xFF, batch_slots.ymm, batch_positions.ymm, 8); + } + else { + // We have a collision, let's resolve it with a serial solution + for (sz_size_t i = 0; i < 4; ++i) { + sz_size_t hash_slot = batch_slots.u64s[i] & (hash_table_slots - 1); + // Implement linear probing to find the first free slot. + // If we somehow face 2 different strings with same hash, we will export that hash 2 times! + while (table_hashes[hash_slot] != SZ_SIZE_MAX) hash_slot = (hash_slot + 1) & (hash_table_slots - 1); + table_hashes[hash_slot] = batch_hashes.u64s[i]; + table_positions[hash_slot] = batch_positions.u64s[i]; + } + } + } + } + + // Now, let's cross-reference all shorter values from the larger collection. 
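+    // This lookup pass mirrors the insertion pass above: strings longer than 16 bytes
+    // are skipped here (only tallied in `count_longer`) and resolved by the scalar
+    // pass at the very end of the function.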
+ sz_size_t intersection_count = 0; + for (sz_size_t large_position = 0; large_position < large_sequence->count;) { + sz_string_view_t batch[4]; + sz_u256_vec_t batch_positions; + sz_size_t batch_size; + for (batch_size = 0; batch_size < 4 && large_position < large_sequence->count; ++large_position) { + sz_size_t length = large_sequence->get_length(large_sequence->handle, large_position); + if (length > 16) { + count_longer++; + continue; + } + sz_cptr_t str = large_sequence->get_start(large_sequence->handle, large_position); + batch[batch_size].start = str; + batch[batch_size].length = length; + batch_positions.u64s[batch_size] = large_position; + ++batch_size; + } + + // If we couldn't populate the whole batch, fall back to the serial solution + if (batch_size != 4) { + for (sz_size_t i = 0; i < batch_size; ++i) { + sz_cptr_t const str = batch[i].start; + sz_size_t const length = batch[i].length; + sz_u64_t const hash = sz_hash(str, length, seed); + sz_size_t hash_slot = hash & (hash_table_slots - 1); + // Implement linear probing to resolve collisions. + for (; table_hashes[hash_slot] != SZ_SIZE_MAX; hash_slot = (hash_slot + 1) & (hash_table_slots - 1)) { + sz_u64_t small_hash = table_hashes[hash_slot]; + if (small_hash != hash) continue; + + // The hash matches, compare the strings. + sz_size_t const small_position = table_positions[hash_slot]; + sz_size_t const small_length = small_sequence->get_length(small_sequence->handle, small_position); + if (length != small_length) continue; + + // Same hash may still imply different strings, so we need to compare them. + sz_cptr_t const small_str = small_sequence->get_start(small_sequence->handle, small_position); + sz_bool_t const same = sz_equal(str, small_str, length); + if (same != sz_true_k) continue; + + // Finally, there is a match, store the positions. + small_positions[intersection_count] = small_position; + large_positions[intersection_count] = batch_positions.u64s[i]; + ++intersection_count; + break; + } + } + } + // The batch is successfully populated, let's use the vectorized solution + else { + // Now let's load the first bytes of each string. + sz_u256_vec_t batch_hashes; + sz_u512_vec_t batch_prefixes; + batch_prefixes.xmms[0] = _mm_maskz_loadu_epi8(_sz_u16_mask_until(batch[0].length), batch[0].start); + batch_prefixes.xmms[1] = _mm_maskz_loadu_epi8(_sz_u16_mask_until(batch[1].length), batch[1].start); + batch_prefixes.xmms[2] = _mm_maskz_loadu_epi8(_sz_u16_mask_until(batch[2].length), batch[2].start); + batch_prefixes.xmms[3] = _mm_maskz_loadu_epi8(_sz_u16_mask_until(batch[3].length), batch[3].start); + + // Reuse the already computed state for hashes + _sz_hash_minimal_x4_t batch_hashes_states = batch_hashes_states_initial; + _sz_hash_minimal_x4_update_ice(&batch_hashes_states, batch_prefixes.zmm); + batch_hashes.ymm = _sz_hash_minimal_x4_finalize_ice(&batch_hashes_states, batch[0].length, batch[1].length, + batch[2].length, batch[3].length); + _sz_assert(batch_hashes.u64s[0] == sz_hash(batch[0].start, batch[0].length, seed)); + _sz_assert(batch_hashes.u64s[1] == sz_hash(batch[1].start, batch[1].length, seed)); + _sz_assert(batch_hashes.u64s[2] == sz_hash(batch[2].start, batch[2].length, seed)); + _sz_assert(batch_hashes.u64s[3] == sz_hash(batch[3].start, batch[3].length, seed)); + + // Now let's perform an optimistic hash-table lookup using vectorized gathers. 
+ sz_u256_vec_t batch_slots, existing_hashes; + batch_slots.ymm = _mm256_and_si256(batch_hashes.ymm, _mm256_set1_epi64x(hash_table_slots - 1)); + + // Before scattering the new positions - gather the pre-existing ones. + // This can help us detect values: + // - that are definitely missing in the hash table, if the slot is just NULL-ed + // - that may be present in the hash table, and need to be validated in the loop + existing_hashes.ymm = + _mm256_mmask_i64gather_epi64(_mm256_setzero_si256(), 0xFF, batch_slots.ymm, table_hashes, 8); + + // Check if we already have all of those slots populated with exactly the same values + int const same_hashes = _mm256_movemask_epi8(_mm256_cmpeq_epi64(existing_hashes.ymm, batch_hashes.ymm)); + int const nulled_hashes = + _mm256_movemask_epi8(_mm256_cmpeq_epi64(existing_hashes.ymm, _mm256_set1_epi64x(SZ_SIZE_MAX))); + + // Now for every one of the 4 hashed values we can have several outcomes: + // - it's an "empty" value → no match + // - it's a different hash → continue probing + // - it's the same hash for a different string, so we have a rare collision → continue probing + // - it's the same hash for the same string, so we have a match → export + // + // That logic is too complex to be effectively handled by SIMD, so we switch back to serial code. + for (sz_size_t i = 0; i < 4; ++i) { + sz_cptr_t const str = batch[i].start; + sz_size_t const length = batch[i].length; + sz_u64_t const hash = batch_hashes.u64s[i]; + int const same_hash = (same_hashes >> (8 * i)) & 0xFF; + int const nulled_hash = (nulled_hashes >> (8 * i)) & 0xFF; + if (nulled_hash) continue; + + sz_size_t hash_slot = batch_slots.u64s[i]; + // This optimization may look like just one less memory load, + // but it will help us produce a different set of branches and will affect + // the branch prediction quality on the CPU backend. + if (same_hash) { + // The hash matches, compare the strings. + sz_size_t const small_position = table_positions[hash_slot]; + sz_size_t const small_length = small_sequence->get_length(small_sequence->handle, small_position); + if (length == small_length) { + // Same hash may still imply different strings, so we need to compare them. + sz_cptr_t const small_str = small_sequence->get_start(small_sequence->handle, small_position); + sz_bool_t const same = sz_equal(str, small_str, length); + if (same == sz_true_k) { + // Finally, there is a match, store the positions. + small_positions[intersection_count] = small_position; + large_positions[intersection_count] = batch_positions.u64s[i]; + ++intersection_count; + // Now go to the next value in the batch. + continue; + } + } + // If any of the conditions above didn't hold, just continue probing. + hash_slot = (hash_slot + 1) & (hash_table_slots - 1); + } + + // Implement linear probing to resolve collisions. + for (; table_hashes[hash_slot] != SZ_SIZE_MAX; hash_slot = (hash_slot + 1) & (hash_table_slots - 1)) { + sz_u64_t small_hash = table_hashes[hash_slot]; + if (small_hash != hash) continue; + + // The hash matches, compare the strings. + sz_size_t const small_position = table_positions[hash_slot]; + sz_size_t const small_length = small_sequence->get_length(small_sequence->handle, small_position); + if (length != small_length) continue; + + // Same hash may still imply different strings, so we need to compare them. 
+ sz_cptr_t const small_str = small_sequence->get_start(small_sequence->handle, small_position); + sz_bool_t const same = sz_equal(str, small_str, length); + if (same != sz_true_k) continue; + + // Finally, there is a match, store the positions. + small_positions[intersection_count] = small_position; + large_positions[intersection_count] = batch_positions.u64s[i]; + ++intersection_count; + break; + } + } + } + } + + // TODO: Consider one more level of partitioning, separating the values into [17:64] and [64:] ranges. + if (count_longer) { + // At this point only large values are remaining, let's process them with the code identical to our + // serial solution, but dispatching the right Ice Lake kernel under the hood. + sz_fill((sz_ptr_t)table_positions, hash_table_slots * bytes_per_entry, 0xFF); + + // Hash the smaller set into the hash table using the default available backend. + for (sz_size_t small_position = 0; small_position < small_sequence->count; ++small_position) { + sz_size_t const length = small_sequence->get_length(small_sequence->handle, small_position); + if (length <= 16) continue; //! This is the only difference from the serial solution + sz_cptr_t const str = small_sequence->get_start(small_sequence->handle, small_position); + sz_u64_t const hash = sz_hash(str, length, seed); + sz_size_t hash_slot = hash & (hash_table_slots - 1); + // Implement linear probing to find the first free slot. + // If we somehow face 2 different strings with same hash, we will export that hash 2 times! + while (table_hashes[hash_slot] != SZ_SIZE_MAX) hash_slot = (hash_slot + 1) & (hash_table_slots - 1); + table_hashes[hash_slot] = hash; + table_positions[hash_slot] = small_position; + } + + // Iterate over the larger set and check for the presence of each element in the hash table. + for (sz_size_t large_position = 0; large_position < large_sequence->count; ++large_position) { + sz_size_t const length = large_sequence->get_length(large_sequence->handle, large_position); + if (length <= 16) continue; //! This is the only difference from the serial solution + sz_cptr_t const str = large_sequence->get_start(large_sequence->handle, large_position); + sz_u64_t const hash = sz_hash(str, length, seed); + sz_size_t hash_slot = hash & (hash_table_slots - 1); + + // Implement linear probing to resolve collisions. + for (; table_hashes[hash_slot] != SZ_SIZE_MAX; hash_slot = (hash_slot + 1) & (hash_table_slots - 1)) { + sz_u64_t small_hash = table_hashes[hash_slot]; + if (small_hash != hash) continue; + + // The hash matches, compare the strings. + sz_size_t const small_position = table_positions[hash_slot]; + sz_size_t const small_length = small_sequence->get_length(small_sequence->handle, small_position); + if (length != small_length) continue; + + // Same hash may still imply different strings, so we need to compare them. + sz_cptr_t const small_str = small_sequence->get_start(small_sequence->handle, small_position); + sz_bool_t const same = sz_equal(str, small_str, length); + if (same != sz_true_k) continue; + + // Finally, there is a match, store the positions. + small_positions[intersection_count] = small_position; + large_positions[intersection_count] = large_position; + ++intersection_count; + break; + } + } + } + + // Finalize + *intersection_count_ptr = intersection_count; + return sz_success_k; +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SZ_USE_ICE +#pragma endregion // Ice Lake Implementation + +/* Pick the right implementation for the string search algorithms. 
+ * To override this behavior and precompile all backends - set `SZ_DYNAMIC_DISPATCH` to 1. + */ +#pragma region Compile Time Dispatching +#if !SZ_DYNAMIC_DISPATCH + +SZ_DYNAMIC sz_status_t sz_sequence_intersect(sz_sequence_t const *first_sequence, sz_sequence_t const *second_sequence, + sz_memory_allocator_t *alloc, sz_u64_t seed, sz_size_t *intersection_size, + sz_sorted_idx_t *first_positions, sz_sorted_idx_t *second_positions) { +#if SZ_USE_SKYLAKE + return sz_sequence_intersect_ice( // + first_sequence, second_sequence, // + alloc, seed, intersection_size, // + first_positions, second_positions); +#elif SZ_USE_SVE + return sz_sequence_intersect_sve( // + first_sequence, second_sequence, // + alloc, seed, intersection_size, // + first_positions, second_positions); +#else + return sz_sequence_intersect_serial( // + first_sequence, second_sequence, // + alloc, seed, intersection_size, // + first_positions, second_positions); +#endif +} + +#endif // !SZ_DYNAMIC_DISPATCH +#pragma endregion // Compile Time Dispatching + +#ifdef __cplusplus +} +#endif // __cplusplus +#endif // STRINGZILLA_INTERSECT_H_ diff --git a/include/stringzilla/sort.h b/include/stringzilla/sort.h index 16ca31fe..af808cb5 100644 --- a/include/stringzilla/sort.h +++ b/include/stringzilla/sort.h @@ -1,15 +1,10 @@ /** - * @brief Hardware-accelerated string collection sorting & joins. + * @brief Hardware-accelerated string collection sorting. * @file sort.h * @author Ash Vardanian * - * Includes core APIs for `sz_sequence_t` string collections with hardware-specific backends: - * - * - `sz_sequence_argsort` - to get the sorting permutation of a string collection. - * - `sz_sequence_join` - to compute the intersection of two arbitrary string collections. - * - * The first can easily be used to implement SORT and GROUPBY operations SQL, while the second can be used to - * implement JOIN operations. Both are essential for implementing efficient database engines. + * Provides the @b `sz_sequence_argsort` API to get the sorting permutation of `sz_sequence_t` binary + * string collections in lexicographical order. * * The core idea of all following string algorithms is to process strings not based on 1 character at a time, * but on a larger "Pointer-sized N-grams" fitting in 4 or 8 bytes at once, on 32-bit or 64-bit architectures, @@ -17,7 +12,7 @@ * rest for some metadata. * * That, however, means, that unsigned integer sorting & matching is a constituent part of our sequence - * algorithms and we can expose them as an additional set of APIs for the users: + * algorithms and we can expose them as an additional APIs for the users: * * - `sz_pgrams_sort` - to inplace sort continuous pointer-sized integers. * - `sz_pgrams_join` - to compute the intersection of two arbitrary integer collections. @@ -116,94 +111,6 @@ SZ_DYNAMIC sz_status_t sz_sequence_argsort(sz_sequence_t const *sequence, sz_mem SZ_DYNAMIC sz_status_t sz_pgrams_sort(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, sz_sorted_idx_t *order); -/** - * @brief Intersects two arbitrary @b string sequences, using a hash table. - * Outputs the @p first_positions from the @p first_sequence and @p second_positions from - * the @p second_sequence, that contain identical strings. - * - * - * @param[in] first_sequence First immutable sequence of strings to intersection. - * @param[in] second_sequence Second immutable sequence of strings to intersection. - * @param[in] alloc Optional memory allocator for temporary storage. 
- * @param[out] intersection_size Number of identical strings in both sequences. - * @param[out] first_positions Offset positions of the identical strings from the @p first_sequence. - * @param[out] second_positions Offset positions of the identical strings from the @p second_sequence. - * - * @retval `sz_success_k` if the operation was successful. - * @retval `sz_bad_alloc_k` if the operation failed due to memory allocation failure. - * @retval `sz_contains_duplicates_k` if any of the sequences contain duplicate strings. - * @pre The @p first_positions arrays must fit at least `min(first_sequence->count, second_sequence->count)` items. - * @pre The @p second_positions arrays must fit at least `min(first_sequence->count, second_sequence->count)` items. - * - * Example usage: - * - * @code{.c} - * #include - * int main() { - * char const *first[] = {"banana", "apple", "cherry"}; - * char const *second[] = {"cherry", "orange", "pineapple", "banana"}; - * sz_sequence_t first_sequence, second_sequence; - * sz_sequence_from_null_terminated_strings(first, 3, &first_sequence); - * sz_sequence_from_null_terminated_strings(second, 4, &second_sequence); - * sz_size_t intersection_size; - * sz_sorted_idx_t first_positions[3], second_positions[3]; //? 3 is the size of the smaller sequence - * sz_status_t status = sz_sequence_join(&first_sequence, &second_sequence, NULL, - * &intersection_size, first_positions, second_positions); - * return status == sz_success_k && intersection_size == 2 ? 0 : 1; - * } - * @endcode - * - * @note The algorithm has linear memory complexity and linear time complexity. - * @see https://en.wikipedia.org/wiki/Join_(SQL) - * - * @note Selects the fastest implementation at compile- or run-time based on `SZ_DYNAMIC_DISPATCH`. - * @sa sz_sequence_join_serial, sz_sequence_join_skylake, sz_sequence_join_sve - */ -SZ_DYNAMIC sz_status_t sz_sequence_join(sz_sequence_t const *first_sequence, sz_sequence_t const *second_sequence, - sz_memory_allocator_t *alloc, sz_size_t *intersection_size, - sz_sorted_idx_t *first_positions, sz_sorted_idx_t *second_positions); - -/** - * @brief Faster @b inplace `std::stable_sort` for a continuous @b unsigned-integer sequence, using MergeSort. - * Overwrites the input @p pgrams with the sorted sequence and exports the @p order permutation. - * - * This algorithm guarantees stability, ensuring that the relative order of equal elements is preserved. - * It uses more memory than `sz_pgrams_sort`, but its performance is more predictable. - * It's preferred for very large inputs, as most memory access happens in a sequential pattern. - * - * @param[inout] pgrams Continuous buffer of unsigned integers to sort in place. - * @param[in] count Number of elements in the sequence. - * @param[in] alloc Optional memory allocator for temporary storage. - * @param[out] order Output permutation that sorts the elements. Must fit at least @p count integers. - * - * @retval `sz_success_k` if the operation was successful. - * @retval `sz_bad_alloc_k` if the operation failed due to memory allocation failure. - * @post The @p order array will contain a valid permutation of `[0, count - 1]`. - * - * Example usage: - * - * @code{.c} - * #include - * int main() { - * sz_pgram_t pgrams[] = {42, 17, 99, 8}; - * sz_sorted_idx_t order[4]; - * sz_pgrams_join(pgrams, 4, NULL, order); - * return order[0] == 3 && order[1] == 1 && order[2] == 0 && order[3] == 2 ? 0 : 1; - * } - * @endcode - * - * @note The algorithm has linear memory complexity and log-linear time complexity. 
- * @see [MergeSort Algorithm](https://en.wikipedia.org/wiki/Merge_sort) - * - * @note This algorithm is @b stable: equal elements maintain their relative order. - * @sa sz_pgrams_sort - * - * @note Selects the fastest implementation at compile- or run-time based on `SZ_DYNAMIC_DISPATCH`. - * @sa sz_pgrams_join_serial, sz_pgrams_join_skylake, sz_pgrams_join_sve - */ -SZ_DYNAMIC sz_status_t sz_pgrams_join(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order); - /** @copydoc sz_sequence_argsort */ SZ_PUBLIC sz_status_t sz_sequence_argsort_serial(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, sz_sorted_idx_t *order); @@ -222,12 +129,6 @@ SZ_PUBLIC sz_status_t sz_sequence_argsort_skylake(sz_sequence_t const *sequence, SZ_PUBLIC sz_status_t sz_pgrams_sort_skylake(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, sz_sorted_idx_t *order); -/** @copydoc sz_sequence_join */ -SZ_PUBLIC sz_status_t sz_sequence_join_skylake( // - sz_sequence_t const *first_sequence, sz_sequence_t const *second_sequence, // - sz_memory_allocator_t *alloc, sz_size_t *intersection_size, // - sz_sorted_idx_t *first_positions, sz_sorted_idx_t *second_positions); - #endif #if SZ_USE_SVE @@ -240,12 +141,6 @@ SZ_PUBLIC sz_status_t sz_sequence_argsort_sve(sz_sequence_t const *sequence, sz_ SZ_PUBLIC sz_status_t sz_pgrams_sort_sve(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, sz_sorted_idx_t *order); -/** @copydoc sz_sequence_join */ -SZ_PUBLIC sz_status_t sz_sequence_join_sve( // - sz_sequence_t const *first_sequence, sz_sequence_t const *second_sequence, // - sz_memory_allocator_t *alloc, sz_size_t *intersection_size, // - sz_sorted_idx_t *first_positions, sz_sorted_idx_t *second_positions); - #endif #pragma endregion @@ -717,7 +612,7 @@ SZ_PUBLIC sz_status_t sz_pgrams_sort_serial(sz_pgram_t *pgrams, sz_size_t count, * @brief Helper function similar to `std::set_union` over pairs of integers and their original indices. * @see https://en.cppreference.com/w/cpp/algorithm/set_union */ -SZ_INTERNAL void _sz_sequence_join_serial_merge( // +SZ_INTERNAL void _sz_pgrams_union_serial( // sz_pgram_t const *first_pgrams, sz_sorted_idx_t const *first_indices, sz_size_t first_count, // sz_pgram_t const *second_pgrams, sz_sorted_idx_t const *second_indices, sz_size_t second_count, // sz_pgram_t *result_pgrams, sz_sorted_idx_t *result_indices) { @@ -764,167 +659,6 @@ SZ_INTERNAL void _sz_sequence_join_serial_merge( _sz_assert(merged_begin[i - 1] <= merged_begin[i] && "The merged pgrams must be in ascending order."); } -SZ_PUBLIC sz_status_t sz_pgrams_join_serial(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, - sz_sorted_idx_t *order) { - - // First, initialize the `order` with `std::iota`-like behavior. - for (sz_size_t i = 0; i != count; ++i) order[i] = i; - - // On very small collections - just use the quadratic-complexity insertion sort - // without any smart optimizations or memory allocations. - if (count <= 32) { - sz_pgrams_sort_with_insertion(pgrams, count, order); - return sz_success_k; - } - - // Go through short chunks of 8 elements and sort them with a sorting network. - for (sz_size_t i = 0; i + 8u <= count; i += 8u) _sz_sequence_sorting_network_8x(pgrams + i, order + i); - - // For the tail of the array, sort it with insertion sort. 
- sz_size_t const tail_count = count & 7u; - sz_pgrams_sort_with_insertion(pgrams + count - tail_count, tail_count, order + count - tail_count); - - // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. - sz_memory_allocator_t global_alloc; - if (!alloc) { - sz_memory_allocator_init_default(&global_alloc); - alloc = &global_alloc; - } - - // At this point, the array is partitioned into sorted runs. - // We'll now merge these runs until the whole array is sorted. - // Allocate temporary memory to hold merged results: - // - one block for keys (`sz_pgram_t`) - // - one block for indices (`sz_sorted_idx_t`) - sz_size_t memory_usage = sizeof(sz_pgram_t) * count + sizeof(sz_sorted_idx_t) * count; - sz_pgram_t *pgrams_temporary = (sz_pgram_t *)alloc->allocate(memory_usage, alloc); - sz_sorted_idx_t *order_temporary = (sz_sorted_idx_t *)(pgrams_temporary + count); - if (!pgrams_temporary) return sz_bad_alloc_k; - - // Set initial run size (the sorted chunks). - sz_size_t run_size = 8; - - // Pointers for current source and destination arrays. - sz_pgram_t *src_pgrams = pgrams; - sz_sorted_idx_t *src_order = order; - sz_pgram_t *dst_pgrams = pgrams_temporary; - sz_sorted_idx_t *dst_order = order_temporary; - - // Merge sorted runs in a bottom-up manner until the run size covers the whole array. - while (run_size < count) { - // Process adjacent runs. - for (sz_size_t i = 0; i < count; i += run_size * 2) { - // Determine the number of elements in the left run. - sz_size_t left_count = run_size; - if (i + left_count > count) { left_count = count - i; } - - // Determine the number of elements in the right run. - sz_size_t right_count = run_size; - if (i + left_count >= count) { right_count = 0; } - else if (i + left_count + right_count > count) { right_count = count - (i + left_count); } - - // Merge the two runs: - _sz_sequence_join_serial_merge( // - src_pgrams + i, src_order + i, left_count, // - src_pgrams + i + run_size, src_order + i + run_size, right_count, // - dst_pgrams + i, dst_order + i); - } - - // Swap the roles of the source and destination arrays. - _sz_swap(sz_pgram_t *, src_pgrams, dst_pgrams); - _sz_swap(sz_sorted_idx_t *, src_order, dst_order); - - // Double the run size for the next pass. - run_size *= 2; - } - - // If the final sorted result is not in the original array, copy the sorted results back. - if (src_pgrams != pgrams) - for (sz_size_t i = 0; i < count; ++i) pgrams[i] = src_pgrams[i], order[i] = src_order[i]; - - // Free the temporary memory used for merging. - alloc->free(pgrams_temporary, memory_usage, alloc); - return sz_success_k; -} - -SZ_PUBLIC sz_status_t sz_sequence_join_serial( // - sz_sequence_t const *first_sequence, sz_sequence_t const *second_sequence, // - sz_memory_allocator_t *alloc, sz_u64_t seed, sz_size_t *intersection_count_ptr, // - sz_sorted_idx_t *first_positions, sz_sorted_idx_t *second_positions) { - - // To join to unordered sets of strings, the simplest approach would be to hash them into a dynamically - // allocated hash table and then iterate over the second set, checking for the presence of each element in the - // hash table. This would require O(N) memory and O(N) time complexity, where N is the smaller set. 
- sz_sequence_t const *small_sequence, *large_sequence; - sz_sorted_idx_t *small_positions, *large_positions; - if (first_sequence->count <= second_sequence->count) { - small_sequence = first_sequence, large_sequence = second_sequence; - small_positions = first_positions, large_positions = second_positions; - } - else { - small_sequence = second_sequence, large_sequence = first_sequence; - small_positions = second_positions, large_positions = first_positions; - } - - // We may very well have nothing to join - if (small_sequence->count == 0) { - *intersection_count_ptr = 0; - return sz_success_k; - } - - // Allocate memory for the hash table and initialize it with 0xFF. - sz_size_t const hash_table_slots = sz_size_bit_ceil(small_sequence->count * 2); - sz_size_t const bytes_per_entry = sizeof(sz_size_t) + sizeof(sz_u64_t); - sz_size_t *table_positions = (sz_size_t *)alloc->allocate(hash_table_slots * bytes_per_entry, alloc); - if (!table_positions) return sz_bad_alloc_k; - sz_u64_t *table_fingerprints = (sz_u64_t *)(table_positions + hash_table_slots); - sz_fill((sz_ptr_t)table_positions, hash_table_slots * bytes_per_entry, 0xFF); - - // Hash the smaller set into the hash table using the default available backend. - for (sz_size_t small_position = 0; small_position < small_sequence->count; ++small_position) { - sz_cptr_t const str = small_sequence->get_start(small_sequence->handle, small_position); - sz_size_t const length = small_sequence->get_length(small_sequence->handle, small_position); - sz_u64_t const hash = sz_hash(str, length, seed); - sz_size_t hash_slot = hash; - // Implement linear probing to resolve collisions. - while (table_positions[hash_slot & (hash_table_slots - 1)] != SZ_SIZE_MAX) ++hash_slot; - table_positions[hash_slot & (hash_table_slots - 1)] = small_position; - table_fingerprints[hash_slot & (hash_table_slots - 1)] = hash; - } - - // Iterate over the larger set and check for the presence of each element in the hash table. - sz_size_t intersection_count = 0; - for (sz_size_t large_position = 0; large_position < large_sequence->count; ++large_position) { - sz_cptr_t const str = large_sequence->get_start(large_sequence->handle, large_position); - sz_size_t const length = large_sequence->get_length(large_sequence->handle, large_position); - sz_u64_t const hash = sz_hash(str, length, seed); - sz_size_t hash_slot = hash; - // Implement linear probing to resolve collisions. - for (; table_positions[hash_slot & (hash_table_slots - 1)] != SZ_SIZE_MAX; ++hash_slot) { - sz_u64_t small_hash = table_fingerprints[hash_slot & (hash_table_slots - 1)]; - if (small_hash != hash) continue; - - // The hash matches, compare the strings. - sz_size_t const small_position = table_positions[hash_slot & (hash_table_slots - 1)]; - sz_size_t const small_length = small_sequence->get_length(small_sequence->handle, small_position); - if (length != small_length) continue; - - sz_cptr_t const small_str = small_sequence->get_start(small_sequence->handle, small_position); - sz_bool_t const same = sz_equal(str, small_str, length); - if (same != sz_true_k) continue; - - // Finally, there is a match, store the positions. - small_positions[intersection_count] = small_position; - large_positions[intersection_count] = large_position; - ++intersection_count; - break; - } - } - - *intersection_count_ptr = intersection_count; - return sz_success_k; -} - #pragma endregion // Serial MergeSort Implementation /* AVX512 implementation of the string search algorithms for Ice Lake and newer CPUs. 
@@ -1186,14 +920,6 @@ SZ_PUBLIC sz_status_t sz_sequence_argsort_skylake(sz_sequence_t const *sequence, return sz_success_k; } -SZ_PUBLIC sz_status_t sz_sequence_join_skylake( // - sz_sequence_t const *first_sequence, sz_sequence_t const *second_sequence, // - sz_memory_allocator_t *alloc, sz_size_t *intersection_size, // - sz_sorted_idx_t *first_positions, sz_sorted_idx_t *second_positions) { - sz_unused(first_sequence && second_sequence && alloc && intersection_size && first_positions && second_positions); - return sz_success_k; -} - #pragma clang attribute pop #pragma GCC pop_options #endif // SZ_USE_SKYLAKE @@ -1227,27 +953,6 @@ SZ_DYNAMIC sz_status_t sz_pgrams_sort(sz_pgram_t *pgrams, sz_size_t count, sz_me #endif } -SZ_DYNAMIC sz_status_t sz_sequence_join(sz_sequence_t const *first_sequence, sz_sequence_t const *second_sequence, - sz_memory_allocator_t *alloc, sz_size_t *intersection_size, - sz_sorted_idx_t *first_positions, sz_sorted_idx_t *second_positions) { -#if SZ_USE_SKYLAKE - return sz_sequence_join_skylake( // - first_sequence, second_sequence, // - alloc, intersection_size, // - first_positions, second_positions); -#elif SZ_USE_SVE - return sz_sequence_join_sve( // - first_sequence, second_sequence, // - alloc, intersection_size, // - first_positions, second_positions); -#else - return sz_sequence_join_serial( // - first_sequence, second_sequence, // - alloc, intersection_size, // - first_positions, second_positions); -#endif -} - #endif // !SZ_DYNAMIC_DISPATCH #pragma endregion // Compile Time Dispatching diff --git a/include/stringzilla/stringzilla.h b/include/stringzilla/stringzilla.h index 824bacd4..682f63f0 100644 --- a/include/stringzilla/stringzilla.h +++ b/include/stringzilla/stringzilla.h @@ -48,7 +48,8 @@ #include "find.h" // `sz_find`, `sz_find_byteset`, `sz_rfind` #include "small_string.h" // `sz_string_t`, `sz_string_init`, `sz_string_free` #include "similarity.h" // `sz_levenshtein_distance`, `sz_needleman_wunsch_score` -#include "sort.h" // `sz_sequence_argsort`, `sz_pgrams_sort`, `sz_pgrams_sort_stable` +#include "sort.h" // `sz_sequence_argsort`, `sz_pgrams_sort` +#include "intersect.h" // `sz_sequence_intersect` #ifdef __cplusplus extern "C" { From e25f518b8ff21ad244923a8562ca0325d024fa66 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Fri, 7 Mar 2025 12:27:07 +0000 Subject: [PATCH 149/751] Import: Portable macros for C++ version inference --- include/stringzilla/stringzilla.hpp | 36 ++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/include/stringzilla/stringzilla.hpp b/include/stringzilla/stringzilla.hpp index a1b2de28..eafa448f 100644 --- a/include/stringzilla/stringzilla.hpp +++ b/include/stringzilla/stringzilla.hpp @@ -28,12 +28,36 @@ /* We need to detect the version of the C++ language we are compiled with. * This will affect recent features like `operator<=>` and tests against STL. 
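 *
 * The flags defined below expand to a literal 0 or 1, so they can be tested with `#if`,
 * as done further down in this header around `operator<=>`. A purely illustrative,
 * non-normative sketch of the typical consumption pattern follows; the `MY_CONSTEXPR20`
 * macro name is a hypothetical example, not part of this header:
 *
 *     #if _SZ_IS_CPP20
 *     #define MY_CONSTEXPR20 constexpr
 *     #else
 *     #define MY_CONSTEXPR20
 *     #endif
 *
 *     MY_CONSTEXPR20 int answer() { return 42; }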
*/ -#define _SZ_IS_CPP23 (__cplusplus >= 202101L) -#define _SZ_IS_CPP20 (__cplusplus >= 202002L) -#define _SZ_IS_CPP17 (__cplusplus >= 201703L) -#define _SZ_IS_CPP14 (__cplusplus >= 201402L) -#define _SZ_IS_CPP11 (__cplusplus >= 201103L) -#define _SZ_IS_CPP98 (__cplusplus >= 199711L) +#if __cplusplus >= 202101L +#define _SZ_IS_CPP23 1 +#else +#define _SZ_IS_CPP23 0 +#endif +#if __cplusplus >= 202002L +#define _SZ_IS_CPP20 1 +#else +#define _SZ_IS_CPP20 0 +#endif +#if __cplusplus >= 201703L +#define _SZ_IS_CPP17 1 +#else +#define _SZ_IS_CPP17 0 +#endif +#if __cplusplus >= 201402L +#define _SZ_IS_CPP14 1 +#else +#define _SZ_IS_CPP14 0 +#endif +#if __cplusplus >= 201103L +#define _SZ_IS_CPP11 1 +#else +#define _SZ_IS_CPP11 0 +#endif +#if __cplusplus >= 199711L +#define _SZ_IS_CPP98 1 +#else +#define _SZ_IS_CPP98 0 +#endif /** * @brief Expands to `constexpr` in C++20 and later, and to nothing in older C++ versions. From 0d982a45f842287d7e344f0d8b360f52482017f5 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Fri, 7 Mar 2025 12:53:39 +0000 Subject: [PATCH 150/751] Docs: New formatting in C++ Refreshed the C++ doxygen docstring to match the style of C headers with the new `@sa`, `@p`, and `@retval` tags. --- include/stringzilla/stringzilla.hpp | 1102 +++++++++++++-------------- 1 file changed, 545 insertions(+), 557 deletions(-) diff --git a/include/stringzilla/stringzilla.hpp b/include/stringzilla/stringzilla.hpp index eafa448f..98817dd7 100644 --- a/include/stringzilla/stringzilla.hpp +++ b/include/stringzilla/stringzilla.hpp @@ -118,30 +118,33 @@ using carray = char[count_characters]; #pragma region Memory Operations /** - * @brief Analog to @b `std::memset`, but with a more efficient implementation. - * @param target The pointer to the target memory region. - * @param value The byte value to set. - * @param n The number of bytes to copy. + * @brief Analog to @b `std::memset`, but with a more efficient implementation. + * @param[in] target The pointer to the target memory region. + * @param[in] value The byte value to set. + * @param[in] n The number of bytes to copy. + * @see https://en.cppreference.com/w/cpp/string/byte/memset */ inline void memset(void *target, char value, std::size_t n) noexcept { return sz_fill(reinterpret_cast(target), n, value); } /** - * @brief Analog to @b `std::memmove`, but with a more efficient implementation. - * @param target The pointer to the target memory region. - * @param source The pointer to the source memory region. - * @param n The number of bytes to copy. + * @brief Analog to @b `std::memmove`, but with a more efficient implementation. + * @param[in] target The pointer to the target memory region. + * @param[in] source The pointer to the source memory region. + * @param[in] n The number of bytes to copy. + * @see https://en.cppreference.com/w/cpp/string/byte/memmove */ inline void memmove(void *target, void const *source, std::size_t n) noexcept { return sz_move(reinterpret_cast(target), reinterpret_cast(source), n); } /** - * @brief Analog to @b `std::memcpy`, but with a more efficient implementation. - * @param target The pointer to the target memory region. - * @param source The pointer to the source memory region. - * @param n The number of bytes to copy. + * @brief Analog to @b `std::memcpy`, but with a more efficient implementation. + * @param[in] target The pointer to the target memory region. + * @param[in] source The pointer to the source memory region. 
+ * @param[in] n The number of bytes to copy. + * @see https://en.cppreference.com/w/cpp/string/byte/memcpy */ inline void memcpy(void *target, void const *source, std::size_t n) noexcept { return sz_copy(reinterpret_cast(target), reinterpret_cast(source), n); @@ -152,8 +155,8 @@ inline void memcpy(void *target, void const *source, std::size_t n) noexcept { #pragma region Character Sets /** - * @brief The concatenation of the `ascii_lowercase` and `ascii_uppercase`. This value is not locale-dependent. - * https://docs.python.org/3/library/string.html#string.ascii_letters + * @brief The concatenation of the `ascii_lowercase` and `ascii_uppercase`. This value is not locale-dependent. + * @see https://docs.python.org/3/library/string.html#string.ascii_letters */ inline carray<52> const &ascii_letters() noexcept { static carray<52> const all = { @@ -166,8 +169,8 @@ inline carray<52> const &ascii_letters() noexcept { } /** - * @brief The lowercase letters "abcdefghijklmnopqrstuvwxyz". This value is not locale-dependent. - * https://docs.python.org/3/library/string.html#string.ascii_lowercase + * @brief The lowercase letters "abcdefghijklmnopqrstuvwxyz". This value is not locale-dependent. + * @see https://docs.python.org/3/library/string.html#string.ascii_lowercase */ inline carray<26> const &ascii_lowercase() noexcept { static carray<26> const all = { @@ -179,8 +182,8 @@ inline carray<26> const &ascii_lowercase() noexcept { } /** - * @brief The uppercase letters "ABCDEFGHIJKLMNOPQRSTUVWXYZ". This value is not locale-dependent. - * https://docs.python.org/3/library/string.html#string.ascii_uppercase + * @brief The uppercase letters "ABCDEFGHIJKLMNOPQRSTUVWXYZ". This value is not locale-dependent. + * @see https://docs.python.org/3/library/string.html#string.ascii_uppercase */ inline carray<26> const &ascii_uppercase() noexcept { static carray<26> const all = { @@ -192,9 +195,8 @@ inline carray<26> const &ascii_uppercase() noexcept { } /** - * @brief ASCII characters which are considered printable. - * A combination of `digits`, `ascii_letters`, `punctuation`, and `whitespace`. - * https://docs.python.org/3/library/string.html#string.printable + * @brief Printable ASCII characters, including: `digits`, `ascii_letters`, `punctuation`, and `whitespace`. + * @see https://docs.python.org/3/library/string.html#string.printable */ inline carray<100> const &ascii_printables() noexcept { static carray<100> const all = { @@ -209,8 +211,7 @@ inline carray<100> const &ascii_printables() noexcept { } /** - * @brief Non-printable ASCII control characters. - * Includes all codes from 0 to 31 and 127. + * @brief Non-printable ASCII control characters. Includes all codes from 0 to 31 and 127. */ inline carray<33> const &ascii_controls() noexcept { static carray<33> const all = { @@ -222,8 +223,8 @@ inline carray<33> const &ascii_controls() noexcept { } /** - * @brief The digits "0123456789". - * https://docs.python.org/3/library/string.html#string.digits + * @brief The digits "0123456789". + * @see https://docs.python.org/3/library/string.html#string.digits */ inline carray<10> const &digits() noexcept { static carray<10> const all = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'}; @@ -231,8 +232,8 @@ inline carray<10> const &digits() noexcept { } /** - * @brief The letters "0123456789abcdefABCDEF". - * https://docs.python.org/3/library/string.html#string.hexdigits + * @brief The letters "0123456789abcdefABCDEF". 
+ * @see https://docs.python.org/3/library/string.html#string.hexdigits */ inline carray<22> const &hexdigits() noexcept { static carray<22> const all = { @@ -244,8 +245,8 @@ inline carray<22> const &hexdigits() noexcept { } /** - * @brief The letters "01234567". - * https://docs.python.org/3/library/string.html#string.octdigits + * @brief The letters "01234567". + * @see https://docs.python.org/3/library/string.html#string.octdigits */ inline carray<8> const &octdigits() noexcept { static carray<8> const all = {'0', '1', '2', '3', '4', '5', '6', '7'}; @@ -253,9 +254,8 @@ inline carray<8> const &octdigits() noexcept { } /** - * @brief ASCII characters considered punctuation characters in the C locale: - * !"#$%&'()*+,-./:;<=>?@[\]^_`{|}~. - * https://docs.python.org/3/library/string.html#string.punctuation + * @brief ASCII characters considered punctuation characters in the C locale: @b !"#$%&'()*+,-./:;<=>?@[\]^_`{|}~. + * @see https://docs.python.org/3/library/string.html#string.punctuation */ inline carray<32> const &punctuation() noexcept { static carray<32> const all = { @@ -267,9 +267,8 @@ inline carray<32> const &punctuation() noexcept { } /** - * @brief ASCII characters that are considered whitespace. - * This includes space, tab, linefeed, return, formfeed, and vertical tab. - * https://docs.python.org/3/library/string.html#string.whitespace + * @brief Whitespace ASCII characters, including: space, tab, linefeed, return, formfeed, and vertical tab. + * @see https://docs.python.org/3/library/string.html#string.whitespace */ inline carray<6> const &whitespaces() noexcept { static carray<6> const all = {' ', '\t', '\n', '\r', '\f', '\v'}; @@ -277,8 +276,8 @@ inline carray<6> const &whitespaces() noexcept { } /** - * @brief ASCII characters that are considered line delimiters. - * https://docs.python.org/3/library/stdtypes.html#str.splitlines + * @brief ASCII characters that are considered line delimiters. + * @see https://docs.python.org/3/library/stdtypes.html#str.splitlines */ inline carray<8> const &newlines() noexcept { static carray<8> const all = {'\n', '\r', '\f', '\v', '\x1C', '\x1D', '\x1E', '\x85'}; @@ -286,7 +285,8 @@ inline carray<8> const &newlines() noexcept { } /** - * @brief ASCII characters forming the BASE64 encoding alphabet. + * @brief ASCII characters forming the BASE64 encoding alphabet: a-z, A-Z, 0-9, +, and /. + * @see https://docs.python.org/3/library/base64.html */ inline carray<64> const &base64() noexcept { static carray<64> const all = { @@ -299,7 +299,7 @@ inline carray<64> const &base64() noexcept { } /** - * @brief A set of characters represented as a bitset with 256 slots. + * @brief A set of characters represented as a bitset with 256 slots. */ template class basic_byteset { @@ -383,9 +383,8 @@ inline byteset newlines_set() { return byteset {newlines(), sizeof(newlines())}; inline byteset base64_set() { return byteset {base64(), sizeof(base64())}; } /** - * @brief A look-up table for character replacement operations. - * Exactly 256 bytes for byte-to-byte replacement. - * ! For larger character types should be allocated on the heap. + * @brief A look-up table for character replacement operations. Exactly 256 bytes for byte-to-byte replacement. + * @warning For larger character types should be allocated on the heap. */ template class basic_look_up_table { @@ -415,8 +414,8 @@ class basic_look_up_table { } /** - * @brief Creates a look-up table with a one-to-one mapping of characters to themselves. 
- * Similar to `std::iota` filling, but properly handles signed integer casts. + * @brief Creates a look-up table with a one-to-one mapping of characters to themselves. + * @see Similar to `std::iota` filling, but properly handles signed integer casts. */ static basic_look_up_table identity() noexcept { basic_look_up_table result; @@ -446,7 +445,8 @@ inline static constexpr exclude_overlaps_type exclude_overlaps; #endif /** - * @brief Zero-cost wrapper around the `.find` member function of string-like classes. + * @brief Zero-cost wrapper around the `.find` member function of string-like classes. + * @see https://en.cppreference.com/w/cpp/string/basic_string/find */ template struct matcher_find { @@ -463,7 +463,8 @@ struct matcher_find { }; /** - * @brief Zero-cost wrapper around the `.rfind` member function of string-like classes. + * @brief Zero-cost wrapper around the `.rfind` member function of string-like classes. + * @see https://en.cppreference.com/w/cpp/string/basic_string/rfind */ template struct matcher_rfind { @@ -480,7 +481,8 @@ struct matcher_rfind { }; /** - * @brief Zero-cost wrapper around the `.find_first_of` member function of string-like classes. + * @brief Zero-cost wrapper around the `.find_first_of` member function of string-like classes. + * @see https://en.cppreference.com/w/cpp/string/basic_string/find_first_of */ template struct matcher_find_first_of { @@ -492,7 +494,8 @@ struct matcher_find_first_of { }; /** - * @brief Zero-cost wrapper around the `.find_last_of` member function of string-like classes. + * @brief Zero-cost wrapper around the `.find_last_of` member function of string-like classes. + * @see https://en.cppreference.com/w/cpp/string/basic_string/find_last_of */ template struct matcher_find_last_of { @@ -504,7 +507,8 @@ struct matcher_find_last_of { }; /** - * @brief Zero-cost wrapper around the `.find_first_not_of` member function of string-like classes. + * @brief Zero-cost wrapper around the `.find_first_not_of` member function of string-like classes. + * @see https://en.cppreference.com/w/cpp/string/basic_string/find_first_not_of */ template struct matcher_find_first_not_of { @@ -516,7 +520,8 @@ struct matcher_find_first_not_of { }; /** - * @brief Zero-cost wrapper around the `.find_last_not_of` member function of string-like classes. + * @brief Zero-cost wrapper around the `.find_last_not_of` member function of string-like classes. + * @see https://en.cppreference.com/w/cpp/string/basic_string/find_last_not_of */ template struct matcher_find_last_not_of { @@ -528,9 +533,9 @@ struct matcher_find_last_not_of { }; /** - * @brief A range of string slices representing the matches of a substring search. - * Compatible with C++23 ranges, C++11 string views, and of course, StringZilla. - * Similar to a pair of `boost::algorithm::find_iterator`. + * @brief A range of string slices representing the matches of a substring search. + * @note Compatible with C++23 ranges, C++11 string views, and of course, StringZilla. + * @see Similar to a pair of `boost::algorithm::find_iterator`. */ template class range_matches { @@ -597,17 +602,13 @@ class range_matches { bool empty() const noexcept { return begin() == end_sentinel_type {}; } bool include_overlaps() const noexcept { return matcher_.skip_length() < matcher_.needle_length(); } - /** - * @brief Copies the matches into a container. - */ + /** @brief Copies the matches into a container. 
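     *  For example, `find_all(text, ", ").to(parts)` appends every match to an existing `parts` container.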
*/ template void to(container_ &container) { - for (auto match : *this) { container.push_back(match); } + for (auto match : *this) container.push_back(match); } - /** - * @brief Copies the matches into a consumed container, returning it at the end. - */ + /** @brief Copies the matches into a consumed container, returning it at the end. */ template container_ to() { return container_ {begin(), end()}; @@ -615,9 +616,9 @@ class range_matches { }; /** - * @brief A range of string slices representing the matches of a @b reverse-order substring search. - * Compatible with C++23 ranges, C++11 string views, and of course, StringZilla. - * Similar to a pair of `boost::algorithm::find_iterator`. + * @brief A range of string slices representing the matches of a @b reverse-order substring search. + * @note Compatible with C++23 ranges, C++11 string views, and of course, StringZilla. + * @see Similar to a pair of `boost::algorithm::find_iterator`. */ template class range_rmatches { @@ -695,17 +696,13 @@ class range_rmatches { bool empty() const noexcept { return begin() == end_sentinel_type {}; } bool include_overlaps() const noexcept { return matcher_.skip_length() < matcher_.needle_length(); } - /** - * @brief Copies the matches into a container. - */ + /** @brief Copies the matches into a container. */ template void to(container_ &container) { - for (auto match : *this) { container.push_back(match); } + for (auto match : *this) container.push_back(match); } - /** - * @brief Copies the matches into a consumed container, returning it at the end. - */ + /** @brief Copies the matches into a consumed container, returning it at the end. */ template container_ to() { return container_ {begin(), end()}; @@ -713,9 +710,9 @@ class range_rmatches { }; /** - * @brief A range of string slices for different splits of the data. - * Compatible with C++23 ranges, C++11 string views, and of course, StringZilla. - * Similar to a pair of `boost::algorithm::split_iterator`. + * @brief A range of string slices for different splits of the data. + * @note Compatible with C++23 ranges, C++11 string views, and of course, StringZilla. + * @see Similar to a pair of `boost::algorithm::split_iterator`. * * In some sense, represents the inverse operation to `range_matches`, as it reports not the search matches * but the data between them. Meaning that for `N` search matches, there will be `N+1` elements in the range. @@ -797,28 +794,24 @@ class range_splits { difference_type ssize() const noexcept { return std::distance(begin(), end()); } constexpr bool empty() const noexcept { return false; } - /** - * @brief Copies the matches into a container. - */ + /** @brief Copies the matches into a container. */ template void to(container_ &container) { - for (auto match : *this) { container.push_back(match); } + for (auto match : *this) container.push_back(match); } - /** - * @brief Copies the matches into a consumed container, returning it at the end. - */ + /** @brief Copies the matches into a consumed container, returning it at the end. */ template container_ to(container_ &&container = {}) { - for (auto match : *this) { container.push_back(match); } + for (auto match : *this) container.push_back(match); return std::move(container); } }; /** - * @brief A range of string slices for different splits of the data in @b reverse-order. - * Compatible with C++23 ranges, C++11 string views, and of course, StringZilla. - * Similar to a pair of `boost::algorithm::split_iterator`. 
+ * @brief A range of string slices for different splits of the data in @b reverse-order. + * @note Compatible with C++23 ranges, C++11 string views, and of course, StringZilla. + * @see Similar to a pair of `boost::algorithm::split_iterator`. * * In some sense, represents the inverse operation to `range_matches`, as it reports not the search matches * but the data between them. Meaning that for `N` search matches, there will be `N+1` elements in the range. @@ -906,27 +899,23 @@ class range_rsplits { difference_type ssize() const noexcept { return std::distance(begin(), end()); } constexpr bool empty() const noexcept { return false; } - /** - * @brief Copies the matches into a container. - */ + /** @brief Copies the matches into a container. */ template void to(container_ &container) { - for (auto match : *this) { container.push_back(match); } + for (auto match : *this) container.push_back(match); } - /** - * @brief Copies the matches into a consumed container, returning it at the end. - */ + /** @brief Copies the matches into a consumed container, returning it at the end. */ template container_ to(container_ &&container = {}) { - for (auto match : *this) { container.push_back(match); } + for (auto match : *this) container.push_back(match); return std::move(container); } }; /** - * @brief Find all potentially @b overlapping inclusions of a needle substring. - * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. + * @brief Find all potentially @b overlapping inclusions of a needle substring. + * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. */ template range_matches> find_all(string const &h, string const &n, @@ -935,8 +924,8 @@ range_matches> find_all(stri } /** - * @brief Find all potentially @b overlapping inclusions of a needle substring in @b reverse order. - * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. + * @brief Find all potentially @b overlapping inclusions of a needle substring in @b reverse order. + * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. */ template range_rmatches> rfind_all(string const &h, string const &n, @@ -945,8 +934,8 @@ range_rmatches> rfind_all(s } /** - * @brief Find all @b non-overlapping inclusions of a needle substring. - * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. + * @brief Find all @b non-overlapping inclusions of a needle substring. + * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. */ template range_matches> find_all(string const &h, string const &n, @@ -955,8 +944,8 @@ range_matches> find_all(stri } /** - * @brief Find all @b non-overlapping inclusions of a needle substring in @b reverse order. - * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. + * @brief Find all @b non-overlapping inclusions of a needle substring in @b reverse order. + * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. */ template range_rmatches> rfind_all(string const &h, string const &n, @@ -965,8 +954,8 @@ range_rmatches> rfind_all(s } /** - * @brief Find all inclusions of characters from the second string. - * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. + * @brief Find all inclusions of characters from the second string. 
+ * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. */ template range_matches> find_all_characters(string const &h, string const &n) noexcept { @@ -974,8 +963,8 @@ range_matches> find_all_characters(string } /** - * @brief Find all inclusions of characters from the second string in @b reverse order. - * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. + * @brief Find all inclusions of characters from the second string in @b reverse order. + * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. */ template range_rmatches> rfind_all_characters(string const &h, string const &n) noexcept { @@ -983,8 +972,8 @@ range_rmatches> rfind_all_characters(string } /** - * @brief Find all characters except the ones in the second string. - * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. + * @brief Find all characters except the ones in the second string. + * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. */ template range_matches> find_all_other_characters(string const &h, @@ -993,8 +982,8 @@ range_matches> find_all_other_characte } /** - * @brief Find all characters except the ones in the second string in @b reverse order. - * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. + * @brief Find all characters except the ones in the second string in @b reverse order. + * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. */ template range_rmatches> rfind_all_other_characters(string const &h, @@ -1003,8 +992,8 @@ range_rmatches> rfind_all_other_charact } /** - * @brief Splits a string around every @b non-overlapping inclusion of the second string. - * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. + * @brief Splits a string around every @b non-overlapping inclusion of the second string. + * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. */ template range_splits> split(string const &h, string const &n) noexcept { @@ -1012,8 +1001,8 @@ range_splits> split(string c } /** - * @brief Splits a string around every @b non-overlapping inclusion of the second string in @b reverse order. - * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. + * @brief Splits a string around every @b non-overlapping inclusion of the second string in @b reverse order. + * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. */ template range_rsplits> rsplit(string const &h, string const &n) noexcept { @@ -1021,8 +1010,8 @@ range_rsplits> rsplit(strin } /** - * @brief Splits a string around every character from the second string. - * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. + * @brief Splits a string around every character from the second string. + * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. */ template range_splits> split_characters(string const &h, string const &n) noexcept { @@ -1030,8 +1019,8 @@ range_splits> split_characters(string cons } /** - * @brief Splits a string around every character from the second string in @b reverse order. - * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. 
+ * @brief Splits a string around every character from the second string in @b reverse order. + * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. */ template range_rsplits> rsplit_characters(string const &h, string const &n) noexcept { @@ -1039,8 +1028,8 @@ range_rsplits> rsplit_characters(string con } /** - * @brief Splits a string around every character except the ones from the second string. - * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. + * @brief Splits a string around every character except the ones from the second string. + * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. */ template range_splits> split_other_characters(string const &h, @@ -1049,8 +1038,8 @@ range_splits> split_other_characters(s } /** - * @brief Splits a string around every character except the ones from the second string in @b reverse order. - * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. + * @brief Splits a string around every character except the ones from the second string in @b reverse order. + * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. */ template range_rsplits> rsplit_other_characters(string const &h, @@ -1058,14 +1047,14 @@ range_rsplits> rsplit_other_characters( return {h, n}; } -/** @brief Helper function using `std::advance` iterator and return it back. */ +/** @brief Helper function using `std::advance` iterator and return it back. */ template iterator_type advanced(iterator_type &&it, distance_type n) { std::advance(it, n); return it; } -/** @brief Helper function using `range_length` to compute the unsigned distance. */ +/** @brief Helper function using `range_length` to compute the unsigned distance. */ template std::size_t range_length(iterator_type first, iterator_type last) { return static_cast(std::distance(first, last)); @@ -1172,7 +1161,9 @@ class reversed_iterator_for { }; /** - * @brief An "expression template" for lazy concatenation of strings using the `operator|`. + * @brief An "expression template" for lazy concatenation of strings using the `operator|`. + * @see https://en.wikipedia.org/wiki/Expression_templates + * @sa `concatenate` function for usage examples. */ template struct concatenation { @@ -1218,7 +1209,7 @@ struct concatenation { * with much faster SIMD-accelerated substring search and approximate matching. * Constructors are `constexpr` enabling `_sz` literals. * - * @tparam char_type_ The character type, usually `char const` or `char`. Must be a single byte long. + * @tparam char_type_ The character type, usually `char const` or `char`. Must be a single byte long. */ template class basic_string_slice { @@ -1254,7 +1245,7 @@ class basic_string_slice { using string_view = basic_string_slice; using partition_type = string_partition_result; - /** @brief Special value for missing matches. + /** @brief Special value for missing matches. * * We take the largest 63-bit unsigned integer on 64-bit machines. * We take the largest 31-bit unsigned integer on 32-bit machines. @@ -1298,8 +1289,8 @@ class basic_string_slice { operator std::string() const { return {data(), size()}; } /** - * @brief Formatted output function for compatibility with STL's `std::basic_ostream`. - * @throw `std::ios_base::failure` if an exception occurred during output. + * @brief Formatted output function for compatibility with STL's `std::basic_ostream`. 
+ * @throw `std::ios_base::failure` if an exception occurred during output. */ template friend std::basic_ostream &operator<<(std::basic_ostream &os, @@ -1364,7 +1355,7 @@ class basic_string_slice { } /** - * @brief Signed alternative to `at()`. Handy if you often write `str[str.size() - 2]`. + * @brief Signed alternative to `at()`. Handy if you often write `str[str.size() - 2]`. * @warning The behavior is @b undefined if the position is beyond bounds. */ reference sat(difference_type signed_offset) const noexcept { @@ -1376,6 +1367,7 @@ class basic_string_slice { /** * @brief The slice that would be dropped by `remove_prefix`, that accepts signed arguments * and does no bounds checking. Equivalent to Python's `"abc"[:2]` and `"abc"[:-1]`. + * * @warning The behavior is @b undefined if `n > size() || n < -size() || n == -0`. */ string_slice front(difference_type signed_offset) const noexcept { @@ -1420,49 +1412,49 @@ class basic_string_slice { #pragma region STL Style /** - * @brief Removes the first `n` characters from the view. + * @brief Removes the first @p `n` bytes from the view. * @warning The behavior is @b undefined if `n > size()`. */ void remove_prefix(size_type n) noexcept { assert(n <= size()), start_ += n, length_ -= n; } /** - * @brief Removes the last `n` characters from the view. + * @brief Removes the last @p `n` bytes from the view. * @warning The behavior is @b undefined if `n > size()`. */ void remove_suffix(size_type n) noexcept { assert(n <= size()), length_ -= n; } - /** @brief Added for STL compatibility. */ + /** @brief Added for STL compatibility. */ string_slice substr() const noexcept { return *this; } /** - * @brief Return a slice of this view after first `skip` bytes. - * @throws `std::out_of_range` if `skip > size()`. - * @see `sub` for a cleaner exception-less alternative. + * @brief Return a slice of this view after first @p `n` bytes. + * @throws `std::out_of_range` if `n > size()`. + * @sa `sub` for a cleaner exception-less alternative. */ - string_slice substr(size_type skip) const noexcept(false) { - if (skip > size()) throw std::out_of_range("string_slice::substr"); - return string_slice(start_ + skip, length_ - skip); + string_slice substr(size_type n) const noexcept(false) { + if (n > size()) throw std::out_of_range("string_slice::substr"); + return string_slice(start_ + n, length_ - n); } /** - * @brief Return a slice of this view after first `skip` bytes, taking at most `count` bytes. - * @throws `std::out_of_range` if `skip > size()`. - * @see `sub` for a cleaner exception-less alternative. + * @brief Return a slice of this view after first @p `n` bytes, taking at most `count` bytes. + * @throws `std::out_of_range` if `n > size()`. + * @sa `sub` for a cleaner exception-less alternative. */ - string_slice substr(size_type skip, size_type count) const noexcept(false) { - if (skip > size()) throw std::out_of_range("string_slice::substr"); - return string_slice(start_ + skip, sz_min_of_two(count, length_ - skip)); + string_slice substr(size_type n, size_type count) const noexcept(false) { + if (n > size()) throw std::out_of_range("string_slice::substr"); + return string_slice(start_ + n, sz_min_of_two(count, length_ - n)); } /** - * @brief Exports a slice of this view after first `skip` bytes, taking at most `count` bytes. - * @throws `std::out_of_range` if `skip > size()`. - * @see `sub` for a cleaner exception-less alternative. + * @brief Exports a slice of this view after first @p `n` bytes, taking at most `count` bytes. 
+ * @throws `std::out_of_range` if `n > size()`. + * @sa `sub` for a cleaner exception-less alternative. */ - size_type copy(value_type *destination, size_type count, size_type skip = 0) const noexcept(false) { - if (skip > size()) throw std::out_of_range("string_slice::copy"); - count = sz_min_of_two(count, length_ - skip); - sz_copy((sz_ptr_t)destination, start_ + skip, count); + size_type copy(value_type *destination, size_type count, size_type n = 0) const noexcept(false) { + if (n > size()) throw std::out_of_range("string_slice::copy"); + count = sz_min_of_two(count, length_ - n); + sz_copy((sz_ptr_t)destination, start_ + n, count); return count; } @@ -1475,26 +1467,26 @@ class basic_string_slice { #pragma region Whole String Comparisons /** - * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. + * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. * @return 0 if equal, negative if `*this` is less than `other`, positive if `*this` is greater than `other`. */ int compare(string_view other) const noexcept { return (int)sz_order(data(), size(), other.data(), other.size()); } /** - * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. - * Equivalent to `substr(pos1, count1).compare(other)`. + * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. + * @see Equivalent to `substr(pos1, count1).compare(other)`. * @return 0 if equal, negative if `*this` is less than `other`, positive if `*this` is greater than `other`. - * @throw `std::out_of_range` if `pos1 > size()`. + * @throw `std::out_of_range` if `pos1 > size()`. */ int compare(size_type pos1, size_type count1, string_view other) const noexcept(false) { return substr(pos1, count1).compare(other); } /** - * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. - * Equivalent to `substr(pos1, count1).compare(other.substr(pos2, count2))`. + * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. + * @see Equivalent to `substr(pos1, count1).compare(other.substr(pos2, count2))`. * @return 0 if equal, negative if `*this` is less than `other`, positive if `*this` is greater than `other`. - * @throw `std::out_of_range` if `pos1 > size()` or if `pos2 > other.size()`. + * @throw `std::out_of_range` if `pos1 > size()` or if `pos2 > other.size()`. */ int compare(size_type pos1, size_type count1, string_view other, size_type pos2, size_type count2) const noexcept(false) { @@ -1502,37 +1494,37 @@ class basic_string_slice { } /** - * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. + * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. * @return 0 if equal, negative if `*this` is less than `other`, positive if `*this` is greater than `other`. */ int compare(const_pointer other) const noexcept { return compare(string_view(other)); } /** - * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. - * Equivalent to substr(pos1, count1).compare(other). + * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. + * @see Equivalent to substr(pos1, count1).compare(other). * @return 0 if equal, negative if `*this` is less than `other`, positive if `*this` is greater than `other`. - * @throw `std::out_of_range` if `pos1 > size()`. + * @throw `std::out_of_range` if `pos1 > size()`. 
*/ int compare(size_type pos1, size_type count1, const_pointer other) const noexcept(false) { return substr(pos1, count1).compare(string_view(other)); } /** - * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. - * Equivalent to `substr(pos1, count1).compare({s, count2})`. + * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. + * @see Equivalent to `substr(pos1, count1).compare({s, count2})`. * @return 0 if equal, negative if `*this` is less than `other`, positive if `*this` is greater than `other`. - * @throw `std::out_of_range` if `pos1 > size()`. + * @throw `std::out_of_range` if `pos1 > size()`. */ int compare(size_type pos1, size_type count1, const_pointer other, size_type count2) const noexcept(false) { return substr(pos1, count1).compare(string_view(other, count2)); } - /** @brief Checks if the string is equal to the other string. */ + /** @brief Checks if the string is equal to the other string. */ bool operator==(string_view other) const noexcept { return size() == other.size() && sz_equal(data(), other.data(), other.size()) == sz_true_k; } - /** @brief Checks if the string is equal to a concatenation of two strings. */ + /** @brief Checks if the string is equal to a concatenation of two strings. */ bool operator==(concatenation const &other) const noexcept { return size() == other.size() && sz_equal(data(), other.first.data(), other.first.size()) == sz_true_k && sz_equal(data() + other.first.size(), other.second.data(), other.second.size()) == sz_true_k; @@ -1540,7 +1532,7 @@ class basic_string_slice { #if _SZ_IS_CPP20 - /** @brief Computes the lexicographic ordering between this and the ::other string. */ + /** @brief Computes the lexicographic ordering between this and the ::other string. */ std::strong_ordering operator<=>(string_view other) const noexcept { std::strong_ordering orders[3] {std::strong_ordering::less, std::strong_ordering::equal, std::strong_ordering::greater}; @@ -1549,19 +1541,19 @@ class basic_string_slice { #else - /** @brief Checks if the string is not equal to the other string. */ + /** @brief Checks if the string is not equal to the other string. */ bool operator!=(string_view other) const noexcept { return !operator==(other); } - /** @brief Checks if the string is lexicographically smaller than the other string. */ + /** @brief Checks if the string is lexicographically smaller than the other string. */ bool operator<(string_view other) const noexcept { return compare(other) == sz_less_k; } - /** @brief Checks if the string is lexicographically equal or smaller than the other string. */ + /** @brief Checks if the string is lexicographically equal or smaller than the other string. */ bool operator<=(string_view other) const noexcept { return compare(other) != sz_greater_k; } - /** @brief Checks if the string is lexicographically greater than the other string. */ + /** @brief Checks if the string is lexicographically greater than the other string. */ bool operator>(string_view other) const noexcept { return compare(other) == sz_greater_k; } - /** @brief Checks if the string is lexicographically equal or greater than the other string. */ + /** @brief Checks if the string is lexicographically equal or greater than the other string. 
*/ bool operator>=(string_view other) const noexcept { return compare(other) != sz_less_k; } #endif @@ -1569,41 +1561,41 @@ class basic_string_slice { #pragma endregion #pragma region Prefix and Suffix Comparisons - /** @brief Checks if the string starts with the other string. */ + /** @brief Checks if the string starts with the other string. */ bool starts_with(string_view other) const noexcept { return length_ >= other.size() && sz_equal(start_, other.data(), other.size()) == sz_true_k; } - /** @brief Checks if the string starts with the other string. */ + /** @brief Checks if the string starts with the other string. */ bool starts_with(const_pointer other) const noexcept { auto other_length = null_terminated_length(other); return length_ >= other_length && sz_equal(start_, other, other_length) == sz_true_k; } - /** @brief Checks if the string starts with the other character. */ + /** @brief Checks if the string starts with the other character. */ bool starts_with(value_type other) const noexcept { return length_ && start_[0] == other; } - /** @brief Checks if the string ends with the other string. */ + /** @brief Checks if the string ends with the other string. */ bool ends_with(string_view other) const noexcept { return length_ >= other.size() && sz_equal(start_ + length_ - other.size(), other.data(), other.size()) == sz_true_k; } - /** @brief Checks if the string ends with the other string. */ + /** @brief Checks if the string ends with the other string. */ bool ends_with(const_pointer other) const noexcept { auto other_length = null_terminated_length(other); return length_ >= other_length && sz_equal(start_ + length_ - other_length, other, other_length) == sz_true_k; } - /** @brief Checks if the string ends with the other character. */ + /** @brief Checks if the string ends with the other character. */ bool ends_with(value_type other) const noexcept { return length_ && start_[length_ - 1] == other; } - /** @brief Python-like convenience function, dropping the matching prefix. */ + /** @brief Python-like convenience function, dropping the matching prefix. */ string_slice remove_prefix(string_view other) const noexcept { return starts_with(other) ? string_slice {start_ + other.size(), length_ - other.size()} : *this; } - /** @brief Python-like convenience function, dropping the matching suffix. */ + /** @brief Python-like convenience function, dropping the matching suffix. */ string_slice remove_suffix(string_view other) const noexcept { return ends_with(other) ? string_slice {start_, length_ - other.size()} : *this; } @@ -1620,9 +1612,9 @@ class basic_string_slice { #pragma region Returning offsets /** - * @brief Find the first occurrence of a substring, skipping the first `skip` characters. - * The behavior is @b undefined if `skip > size()`. + * @brief Find the first occurrence of a substring, skipping the first `skip` characters. * @return The offset of the first character of the match, or `npos` if not found. + * @warning The behavior is @b undefined if `skip > size()`. */ size_type find(string_view other, size_type skip = 0) const noexcept { auto ptr = sz_find(start_ + skip, length_ - skip, other.data(), other.size()); @@ -1630,9 +1622,9 @@ class basic_string_slice { } /** - * @brief Find the first occurrence of a character, skipping the first `skip` characters. - * The behavior is @b undefined if `skip > size()`. + * @brief Find the first occurrence of a character, skipping the first `skip` characters. * @return The offset of the match, or `npos` if not found. 
+ * @warning The behavior is @b undefined if `skip > size()`. */ size_type find(value_type character, size_type skip = 0) const noexcept { auto ptr = sz_find_byte(start_ + skip, length_ - skip, &character); @@ -1640,16 +1632,16 @@ class basic_string_slice { } /** - * @brief Find the first occurrence of a substring, skipping the first `skip` characters. - * The behavior is @b undefined if `skip > size()`. + * @brief Find the first occurrence of a substring, skipping the first `skip` characters. * @return The offset of the first character of the match, or `npos` if not found. + * @warning The behavior is @b undefined if `skip > size()`. */ size_type find(const_pointer other, size_type pos, size_type count) const noexcept { return find(string_view(other, count), pos); } /** - * @brief Find the last occurrence of a substring. + * @brief Find the last occurrence of a substring. * @return The offset of the first character of the match, or `npos` if not found. */ size_type rfind(string_view other) const noexcept { @@ -1658,7 +1650,7 @@ class basic_string_slice { } /** - * @brief Find the last occurrence of a substring, within first `until` characters. + * @brief Find the last occurrence of a substring, within first `until` characters. * @return The offset of the first character of the match, or `npos` if not found. */ size_type rfind(string_view other, size_type until) const noexcept(false) { @@ -1666,7 +1658,7 @@ class basic_string_slice { } /** - * @brief Find the last occurrence of a character. + * @brief Find the last occurrence of a character. * @return The offset of the match, or `npos` if not found. */ size_type rfind(value_type character) const noexcept { @@ -1675,7 +1667,7 @@ class basic_string_slice { } /** - * @brief Find the last occurrence of a character, within first `until` characters. + * @brief Find the last occurrence of a character, within first `until` characters. * @return The offset of the match, or `npos` if not found. */ size_type rfind(value_type character, size_type until) const noexcept { @@ -1683,38 +1675,38 @@ class basic_string_slice { } /** - * @brief Find the last occurrence of a substring, within first `until` characters. + * @brief Find the last occurrence of a substring, within first `until` characters. * @return The offset of the first character of the match, or `npos` if not found. */ size_type rfind(const_pointer other, size_type until, size_type count) const noexcept { return rfind(string_view(other, count), until); } - /** @brief Find the first occurrence of a character from a set. */ + /** @brief Find the first occurrence of a character from a set. */ size_type find(byteset set) const noexcept { return find_first_of(set); } - /** @brief Find the last occurrence of a character from a set. */ + /** @brief Find the last occurrence of a character from a set. */ size_type rfind(byteset set) const noexcept { return find_last_of(set); } #pragma endregion #pragma region Returning Partitions - /** @brief Split the string into three parts, before the match, the match itself, and after it. */ + /** @brief Split the string into three parts, before the match, the match itself, and after it. */ partition_type partition(string_view pattern) const noexcept { return partition_(pattern, pattern.length()); } - /** @brief Split the string into three parts, before the match, the match itself, and after it. */ + /** @brief Split the string into three parts, before the match, the match itself, and after it. 
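     *  For example, partitioning "key=value" on '=' yields the three slices "key", "=", and "value".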
*/ partition_type partition(value_type pattern) const noexcept { return partition_(string_view(&pattern, 1), 1); } - /** @brief Split the string into three parts, before the match, the match itself, and after it. */ + /** @brief Split the string into three parts, before the match, the match itself, and after it. */ partition_type partition(byteset pattern) const noexcept { return partition_(pattern, 1); } - /** @brief Split the string into three parts, before the @b last match, the last match itself, and after it. */ + /** @brief Split the string into three parts, before the @b last match, the last match itself, and after it. */ partition_type rpartition(string_view pattern) const noexcept { return rpartition_(pattern, pattern.length()); } - /** @brief Split the string into three parts, before the @b last match, the last match itself, and after it. */ + /** @brief Split the string into three parts, before the @b last match, the last match itself, and after it. */ partition_type rpartition(value_type pattern) const noexcept { return rpartition_(string_view(&pattern, 1), 1); } - /** @brief Split the string into three parts, before the @b last match, the last match itself, and after it. */ + /** @brief Split the string into three parts, before the @b last match, the last match itself, and after it. */ partition_type rpartition(byteset pattern) const noexcept { return rpartition_(pattern, 1); } #pragma endregion @@ -1735,8 +1727,8 @@ class basic_string_slice { #pragma region Character Set Arguments /** - * @brief Find the first occurrence of a character from a set. - * @param skip Number of characters to skip before the search. + * @brief Find the first occurrence of a character from a @p `set`. + * @param[in] skip Number of characters to skip before the search. * @warning The behavior is @b undefined if `skip > size()`. */ size_type find_first_of(byteset set, size_type skip = 0) const noexcept { @@ -1745,8 +1737,8 @@ class basic_string_slice { } /** - * @brief Find the first occurrence of a character outside a set. - * @param skip The number of first characters to be skipped. + * @brief Find the first occurrence of a character outside a @p `set`. + * @param[in] skip The number of first characters to be skipped. * @warning The behavior is @b undefined if `skip > size()`. */ size_type find_first_not_of(byteset set, size_type skip = 0) const noexcept { @@ -1754,7 +1746,7 @@ class basic_string_slice { } /** - * @brief Find the last occurrence of a character from a set. + * @brief Find the last occurrence of a character from a @p `set`. */ size_type find_last_of(byteset set) const noexcept { auto ptr = sz_rfind_byteset(start_, length_, &set.raw()); @@ -1762,13 +1754,13 @@ class basic_string_slice { } /** - * @brief Find the last occurrence of a character outside a set. + * @brief Find the last occurrence of a character outside a @p `set`. */ size_type find_last_not_of(byteset set) const noexcept { return find_last_of(set.inverted()); } /** - * @brief Find the last occurrence of a character from a set. - * @param until The offset of the last character to be considered. + * @brief Find the last occurrence of a character from a @p `set`. + * @param[in] until The offset of the last character to be considered. */ size_type find_last_of(byteset set, size_type until) const noexcept { auto len = sz_min_of_two(until + 1, length_); @@ -1777,8 +1769,8 @@ class basic_string_slice { } /** - * @brief Find the last occurrence of a character outside a set. 
- * @param until The offset of the last character to be considered. + * @brief Find the last occurrence of a character outside a @p `set`. + * @param[in] until The offset of the last character to be considered. */ size_type find_last_not_of(byteset set, size_type until) const noexcept { return find_last_of(set.inverted(), until); @@ -1788,32 +1780,32 @@ class basic_string_slice { #pragma region String Arguments /** - * @brief Find the first occurrence of a character from a ::set. - * @param skip The number of first characters to be skipped. + * @brief Find the first occurrence of a character from the @p `other` string. + * @param[in] skip The number of first characters to be skipped. */ size_type find_first_of(string_view other, size_type skip = 0) const noexcept { return find_first_of(other.as_set(), skip); } /** - * @brief Find the first occurrence of a character outside a ::set. - * @param skip The number of first characters to be skipped. + * @brief Find the first occurrence of a character missing in the @p `other` string. + * @param[in] skip The number of first characters to be skipped. */ size_type find_first_not_of(string_view other, size_type skip = 0) const noexcept { return find_first_not_of(other.as_set(), skip); } /** - * @brief Find the last occurrence of a character from a ::set. - * @param until The offset of the last character to be considered. + * @brief Find the last occurrence of a character from the @p `other` string. + * @param[in] until The offset of the last character to be considered. */ size_type find_last_of(string_view other, size_type until = npos) const noexcept { return find_last_of(other.as_set(), until); } /** - * @brief Find the last occurrence of a character outside a ::set. - * @param until The offset of the last character to be considered. + * @brief Find the last occurrence of a character missing in the @p `other` string. + * @param[in] until The offset of the last character to be considered. */ size_type find_last_not_of(string_view other, size_type until = npos) const noexcept { return find_last_not_of(other.as_set(), until); @@ -1823,8 +1815,8 @@ class basic_string_slice { #pragma region C Style Arguments /** - * @brief Find the first occurrence of a character from a set. - * @param skip The number of first characters to be skipped. + * @brief Find the first occurrence of a character from the @p `other` string. + * @param[in] skip The number of first characters to be skipped. * @warning The behavior is @b undefined if `skip > size()`. */ size_type find_first_of(const_pointer other, size_type skip, size_type count) const noexcept { @@ -1832,8 +1824,8 @@ class basic_string_slice { } /** - * @brief Find the first occurrence of a character outside a set. - * @param skip The number of first characters to be skipped. + * @brief Find the first occurrence of a character missing in the @p `other` string. + * @param[in] skip The number of first characters to be skipped. * @warning The behavior is @b undefined if `skip > size()`. */ size_type find_first_not_of(const_pointer other, size_type skip, size_type count) const noexcept { @@ -1841,16 +1833,16 @@ class basic_string_slice { } /** - * @brief Find the last occurrence of a character from a set. - * @param until The number of first characters to be considered. + * @brief Find the last occurrence of a character from the @p `other` string. + * @param[in] until The number of first characters to be considered. 
*/ size_type find_last_of(const_pointer other, size_type until, size_type count) const noexcept { return find_last_of(string_view(other, count), until); } /** - * @brief Find the last occurrence of a character outside a set. - * @param until The number of first characters to be considered. + * @brief Find the last occurrence of a character missing in the @p `other` string. + * @param[in] until The number of first characters to be considered. */ size_type find_last_not_of(const_pointer other, size_type until, size_type count) const noexcept { return find_last_not_of(string_view(other, count), until); @@ -1860,8 +1852,8 @@ class basic_string_slice { #pragma region Slicing /** - * @brief Python-like convenience function, dropping prefix formed of given characters. - * Similar to `boost::algorithm::trim_left_if(str, is_any_of(set))`. + * @brief Python-like convenience function, dropping prefix formed of given characters. + * @see Similar to `boost::algorithm::trim_left_if(str, is_any_of(set))`. */ string_slice lstrip(byteset set) const noexcept { set = set.inverted(); @@ -1871,8 +1863,8 @@ class basic_string_slice { } /** - * @brief Python-like convenience function, dropping suffix formed of given characters. - * Similar to `boost::algorithm::trim_right_if(str, is_any_of(set))`. + * @brief Python-like convenience function, dropping suffix formed of given characters. + * @see Similar to `boost::algorithm::trim_right_if(str, is_any_of(set))`. */ string_slice rstrip(byteset set) const noexcept { set = set.inverted(); @@ -1881,8 +1873,8 @@ class basic_string_slice { } /** - * @brief Python-like convenience function, dropping both the prefix & the suffix formed of given characters. - * Similar to `boost::algorithm::trim_if(str, is_any_of(set))`. + * @brief Python-like convenience function, dropping both the prefix & the suffix formed of given characters. + * @see Similar to `boost::algorithm::trim_if(str, is_any_of(set))`. */ string_slice strip(byteset set) const noexcept { set = set.inverted(); @@ -1908,22 +1900,22 @@ class basic_string_slice { using find_all_chars_type = range_matches>; using rfind_all_chars_type = range_rmatches>; - /** @brief Find all potentially @b overlapping occurrences of a given string. */ + /** @brief Find all potentially @b overlapping occurrences of a given string. */ find_all_type find_all(string_view needle, include_overlaps_type = {}) const noexcept { return {*this, needle}; } - /** @brief Find all potentially @b overlapping occurrences of a given string in @b reverse order. */ + /** @brief Find all potentially @b overlapping occurrences of a given string in @b reverse order. */ rfind_all_type rfind_all(string_view needle, include_overlaps_type = {}) const noexcept { return {*this, needle}; } - /** @brief Find all @b non-overlapping occurrences of a given string. */ + /** @brief Find all @b non-overlapping occurrences of a given string. */ find_disjoint_type find_all(string_view needle, exclude_overlaps_type) const noexcept { return {*this, needle}; } - /** @brief Find all @b non-overlapping occurrences of a given string in @b reverse order. */ + /** @brief Find all @b non-overlapping occurrences of a given string in @b reverse order. */ rfind_disjoint_type rfind_all(string_view needle, exclude_overlaps_type) const noexcept { return {*this, needle}; } - /** @brief Find all occurrences of given characters. */ + /** @brief Find all occurrences of given characters. 
*/ find_all_chars_type find_all(byteset set) const noexcept { return {*this, {set}}; } - /** @brief Find all occurrences of given characters in @b reverse order. */ + /** @brief Find all occurrences of given characters in @b reverse order. */ rfind_all_chars_type rfind_all(byteset set) const noexcept { return {*this, {set}}; } using split_type = range_splits>; @@ -1932,32 +1924,32 @@ class basic_string_slice { using split_chars_type = range_splits>; using rsplit_chars_type = range_rsplits>; - /** @brief Split around occurrences of a given string. */ + /** @brief Split around occurrences of a given string. */ split_type split(string_view delimiter) const noexcept { return {*this, delimiter}; } - /** @brief Split around occurrences of a given string in @b reverse order. */ + /** @brief Split around occurrences of a given string in @b reverse order. */ rsplit_type rsplit(string_view delimiter) const noexcept { return {*this, delimiter}; } - /** @brief Split around occurrences of given characters. */ + /** @brief Split around occurrences of given characters. */ split_chars_type split(byteset set = whitespaces_set()) const noexcept { return {*this, {set}}; } - /** @brief Split around occurrences of given characters in @b reverse order. */ + /** @brief Split around occurrences of given characters in @b reverse order. */ rsplit_chars_type rsplit(byteset set = whitespaces_set()) const noexcept { return {*this, {set}}; } - /** @brief Split around the occurrences of all newline characters. */ + /** @brief Split around the occurrences of all newline characters. */ split_chars_type splitlines() const noexcept { return split(newlines_set()); } #pragma endregion - /** @brief Hashes the string, equivalent to `std::hash{}(str)`. */ + /** @brief Hashes the string, equivalent to `std::hash{}(str)`. */ size_type hash(std::uint64_t seed = 42) const noexcept { return static_cast(sz_hash(start_, length_, static_cast(seed))); } - /** @brief Aggregates the values of individual bytes of a string. */ + /** @brief Aggregates the values of individual bytes of a string. */ size_type bytesum() const noexcept { return static_cast(sz_bytesum(start_, length_)); } - /** @brief Populate a character set with characters present in this string. */ + /** @brief Populate a character set with characters present in this string. */ byteset as_set() const noexcept { byteset set; for (auto c : *this) set.add(c); @@ -2157,7 +2149,7 @@ class basic_string { basic_string(std::nullptr_t) = delete; - /** @brief Construct a string by repeating a certain ::character ::count times. */ + /** @brief Construct a string by repeating a certain @p character @p count times. */ basic_string(size_type count, value_type character) noexcept(false) { init(count, character); } basic_string(basic_string const &other, size_type pos) noexcept(false) { init(string_view(other).substr(pos)); } @@ -2210,8 +2202,8 @@ class basic_string { operator std::string() const { return view(); } /** - * @brief Formatted output function for compatibility with STL's `std::basic_ostream`. - * @throw `std::ios_base::failure` if an exception occurred during output. + * @brief Formatted output function for compatibility with STL's `std::basic_ostream`. + * @throw `std::ios_base::failure` if an exception occurred during output. 
*/ template friend std::basic_ostream &operator<<(std::basic_ostream &os, @@ -2231,12 +2223,12 @@ class basic_string { template explicit basic_string(concatenation const &expression) noexcept(false) { - _with_alloc([&](sz_alloc_type &alloc) { + raise(_with_alloc([&](sz_alloc_type &alloc) { sz_ptr_t ptr = sz_string_init_length(&string_, expression.length(), &alloc); - if (!ptr) return false; + if (!ptr) return sz_bad_alloc_k; expression.copy(ptr); - return true; - }); + return sz_success_k; + })); } template @@ -2318,21 +2310,21 @@ class basic_string { string_span operator[](std::initializer_list offsets) noexcept { return span()[offsets]; } /** - * @brief Signed alternative to `at()`. Handy if you often write `str[str.size() - 2]`. + * @brief Signed alternative to `at()`. Handy if you often write `str[str.size() - 2]`. * @warning The behavior is @b undefined if the position is beyond bounds. */ value_type sat(difference_type offset) const noexcept { return view().sat(offset); } reference sat(difference_type offset) noexcept { return span().sat(offset); } /** - * @brief The opposite operation to `remove_prefix`, that does no bounds checking. + * @brief The opposite operation to `remove_prefix`, that does no bounds checking. * @warning The behavior is @b undefined if `n > size()`. */ string_view front(difference_type n) const noexcept { return view().front(n); } string_span front(difference_type n) noexcept { return span().front(n); } /** - * @brief The opposite operation to `remove_prefix`, that does no bounds checking. + * @brief The opposite operation to `remove_prefix`, that does no bounds checking. * @warning The behavior is @b undefined if `n > size()`. */ string_view back(difference_type n) const noexcept { return view().back(n); } @@ -2356,7 +2348,7 @@ class basic_string { #pragma region STL Style /** - * @brief Removes the first `n` characters from the view. + * @brief Removes the first `n` characters from the view. * @warning The behavior is @b undefined if `n > size()`. */ void remove_prefix(size_type n) noexcept { @@ -2365,7 +2357,7 @@ class basic_string { } /** - * @brief Removes the last `n` characters from the view. + * @brief Removes the last `n` characters from the view. * @warning The behavior is @b undefined if `n > size()`. */ void remove_suffix(size_type n) noexcept { @@ -2373,27 +2365,27 @@ class basic_string { sz_string_erase(&string_, size() - n, n); } - /** @brief Added for STL compatibility. */ + /** @brief Added for STL compatibility. */ basic_string substr() const noexcept { return *this; } /** - * @brief Return a slice of this view after first `skip` bytes. + * @brief Return a slice of this view after first `skip` bytes. * @throws `std::out_of_range` if `skip > size()`. - * @see `sub` for a cleaner exception-less alternative. + * @sa `sub` for a cleaner exception-less alternative. */ basic_string substr(size_type skip) const noexcept(false) { return view().substr(skip); } /** - * @brief Return a slice of this view after first `skip` bytes, taking at most `count` bytes. + * @brief Return a slice of this view after first `skip` bytes, taking at most `count` bytes. * @throws `std::out_of_range` if `skip > size()`. - * @see `sub` for a cleaner exception-less alternative. + * @sa `sub` for a cleaner exception-less alternative. */ basic_string substr(size_type skip, size_type count) const noexcept(false) { return view().substr(skip, count); } /** - * @brief Exports a slice of this view after first `skip` bytes, taking at most `count` bytes. 
+ * @brief Exports a slice of this view after first `skip` bytes, taking at most `count` bytes. * @throws `std::out_of_range` if `skip > size()`. - * @see `sub` for a cleaner exception-less alternative. + * @sa `sub` for a cleaner exception-less alternative. */ size_type copy(value_type *destination, size_type count, size_type skip = 0) const noexcept(false) { return view().copy(destination, count, skip); @@ -2408,26 +2400,26 @@ class basic_string { #pragma region Whole String Comparisons /** - * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. + * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. * @return 0 if equal, negative if `*this` is less than `other`, positive if `*this` is greater than `other`. */ int compare(string_view other) const noexcept { return view().compare(other); } /** - * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. - * Equivalent to `substr(pos1, count1).compare(other)`. + * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. * @return 0 if equal, negative if `*this` is less than `other`, positive if `*this` is greater than `other`. - * @throw `std::out_of_range` if `pos1 > size()`. + * @throw `std::out_of_range` if `pos1 > size()`. + * @sa Equivalent to `substr(pos1, count1).compare(other)`. */ int compare(size_type pos1, size_type count1, string_view other) const noexcept(false) { return view().compare(pos1, count1, other); } /** - * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. - * Equivalent to `substr(pos1, count1).compare(other.substr(pos2, count2))`. - * @return 0 if equal, negative if `*this` is less than `other`, positive if `*this` is greater than `other`. - * @throw `std::out_of_range` if `pos1 > size()` or if `pos2 > other.size()`. + * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. + * @return 0 if equal, negative if `*this` is less than @p other, positive if `*this` is greater than @p other. + * @throw `std::out_of_range` if `pos1 > size()` or if `pos2 > other.size()`. + * @sa Equivalent to `substr(pos1, count1).compare(other.substr(pos2, count2))`. */ int compare(size_type pos1, size_type count1, string_view other, size_type pos2, size_type count2) const noexcept(false) { @@ -2435,58 +2427,58 @@ class basic_string { } /** - * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. - * @return 0 if equal, negative if `*this` is less than `other`, positive if `*this` is greater than `other`. + * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. + * @return 0 if equal, negative if `*this` is less than @p other, positive if `*this` is greater than @p other. */ int compare(const_pointer other) const noexcept { return view().compare(other); } /** - * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. - * Equivalent to substr(pos1, count1).compare(other). - * @return 0 if equal, negative if `*this` is less than `other`, positive if `*this` is greater than `other`. - * @throw `std::out_of_range` if `pos1 > size()`. + * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. + * @return 0 if equal, negative if `*this` is less than @p other, positive if `*this` is greater than @p other. + * @throw `std::out_of_range` if `pos1 > size()`. 
+ * @sa Equivalent to `substr(pos1, count1).compare(other)`. */ int compare(size_type pos1, size_type count1, const_pointer other) const noexcept(false) { return view().compare(pos1, count1, other); } /** - * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. - * Equivalent to `substr(pos1, count1).compare({s, count2})`. - * @return 0 if equal, negative if `*this` is less than `other`, positive if `*this` is greater than `other`. - * @throw `std::out_of_range` if `pos1 > size()`. + * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. + * @return 0 if equal, negative if `*this` is less than @p other, positive if `*this` is greater than @p other. + * @throw `std::out_of_range` if `pos1 > size()`. + * @sa Equivalent to `substr(pos1, count1).compare({s, count2})`. */ int compare(size_type pos1, size_type count1, const_pointer other, size_type count2) const noexcept(false) { return view().compare(pos1, count1, other, count2); } - /** @brief Checks if the string is equal to the other string. */ + /** @brief Checks if the string is equal to the other string. */ bool operator==(basic_string const &other) const noexcept { return view() == other.view(); } bool operator==(string_view other) const noexcept { return view() == other; } bool operator==(const_pointer other) const noexcept { return view() == string_view(other); } #if _SZ_IS_CPP20 - /** @brief Computes the lexicographic ordering between this and the ::other string. */ + /** @brief Computes the lexicographic ordering between this and the @p other string. */ std::strong_ordering operator<=>(basic_string const &other) const noexcept { return view() <=> other.view(); } std::strong_ordering operator<=>(string_view other) const noexcept { return view() <=> other; } std::strong_ordering operator<=>(const_pointer other) const noexcept { return view() <=> string_view(other); } #else - /** @brief Checks if the string is not equal to the other string. */ + /** @brief Checks if the string is not equal to the other string. */ bool operator!=(string_view other) const noexcept { return !operator==(other); } - /** @brief Checks if the string is lexicographically smaller than the other string. */ + /** @brief Checks if the string is lexicographically smaller than the other string. */ bool operator<(string_view other) const noexcept { return compare(other) == sz_less_k; } - /** @brief Checks if the string is lexicographically equal or smaller than the other string. */ + /** @brief Checks if the string is lexicographically equal or smaller than the other string. */ bool operator<=(string_view other) const noexcept { return compare(other) != sz_greater_k; } - /** @brief Checks if the string is lexicographically greater than the other string. */ + /** @brief Checks if the string is lexicographically greater than the other string. */ bool operator>(string_view other) const noexcept { return compare(other) == sz_greater_k; } - /** @brief Checks if the string is lexicographically equal or greater than the other string. */ + /** @brief Checks if the string is lexicographically equal or greater than the other string. */ bool operator>=(string_view other) const noexcept { return compare(other) != sz_less_k; } #endif @@ -2494,22 +2486,22 @@ class basic_string { #pragma endregion #pragma region Prefix and Suffix Comparisons - /** @brief Checks if the string starts with the other string. */ + /** @brief Checks if the string starts with the other string. 
*/ bool starts_with(string_view other) const noexcept { return view().starts_with(other); } - /** @brief Checks if the string starts with the other string. */ + /** @brief Checks if the string starts with the other string. */ bool starts_with(const_pointer other) const noexcept { return view().starts_with(other); } - /** @brief Checks if the string starts with the other character. */ + /** @brief Checks if the string starts with the other character. */ bool starts_with(value_type other) const noexcept { return view().starts_with(other); } - /** @brief Checks if the string ends with the other string. */ + /** @brief Checks if the string ends with the other string. */ bool ends_with(string_view other) const noexcept { return view().ends_with(other); } - /** @brief Checks if the string ends with the other string. */ + /** @brief Checks if the string ends with the other string. */ bool ends_with(const_pointer other) const noexcept { return view().ends_with(other); } - /** @brief Checks if the string ends with the other character. */ + /** @brief Checks if the string ends with the other character. */ bool ends_with(value_type other) const noexcept { return view().ends_with(other); } #pragma endregion @@ -2524,64 +2516,64 @@ class basic_string { #pragma region Returning offsets /** - * @brief Find the first occurrence of a substring, skipping the first `skip` characters. - * The behavior is @b undefined if `skip > size()`. + * @brief Find the first occurrence of a substring, skipping the first `skip` characters. * @return The offset of the first character of the match, or `npos` if not found. + * @warning The behavior is @b undefined if `skip > size()`. */ size_type find(string_view other, size_type skip = 0) const noexcept { return view().find(other, skip); } /** - * @brief Find the first occurrence of a character, skipping the first `skip` characters. - * The behavior is @b undefined if `skip > size()`. + * @brief Find the first occurrence of a character, skipping the first `skip` characters. * @return The offset of the match, or `npos` if not found. + * @warning The behavior is @b undefined if `skip > size()`. */ size_type find(value_type character, size_type skip = 0) const noexcept { return view().find(character, skip); } /** - * @brief Find the first occurrence of a substring, skipping the first `skip` characters. - * The behavior is @b undefined if `skip > size()`. + * @brief Find the first occurrence of a substring, skipping the first `skip` characters. * @return The offset of the first character of the match, or `npos` if not found. + * @warning The behavior is @b undefined if `skip > size()`. */ size_type find(const_pointer other, size_type pos, size_type count) const noexcept { return view().find(other, pos, count); } /** - * @brief Find the last occurrence of a substring. + * @brief Find the last occurrence of a substring. * @return The offset of the first character of the match, or `npos` if not found. */ size_type rfind(string_view other) const noexcept { return view().rfind(other); } /** - * @brief Find the last occurrence of a substring, within first `until` characters. + * @brief Find the last occurrence of a substring, within first `until` characters. * @return The offset of the first character of the match, or `npos` if not found. */ size_type rfind(string_view other, size_type until) const noexcept { return view().rfind(other, until); } /** - * @brief Find the last occurrence of a character. + * @brief Find the last occurrence of a character. 
* @return The offset of the match, or `npos` if not found. */ size_type rfind(value_type character) const noexcept { return view().rfind(character); } /** - * @brief Find the last occurrence of a character, within first `until` characters. + * @brief Find the last occurrence of a character, within first `until` characters. * @return The offset of the match, or `npos` if not found. */ size_type rfind(value_type character, size_type until) const noexcept { return view().rfind(character, until); } /** - * @brief Find the last occurrence of a substring, within first `until` characters. + * @brief Find the last occurrence of a substring, within first `until` characters. * @return The offset of the first character of the match, or `npos` if not found. */ size_type rfind(const_pointer other, size_type until, size_type count) const noexcept { return view().rfind(other, until, count); } - /** @brief Find the first occurrence of a character from a set. */ + /** @brief Find the first occurrence of a character from a set. */ size_type find(byteset set) const noexcept { return view().find(set); } - /** @brief Find the last occurrence of a character from a set. */ + /** @brief Find the last occurrence of a character from a set. */ size_type rfind(byteset set) const noexcept { return view().rfind(set); } #pragma endregion @@ -2603,40 +2595,36 @@ class basic_string { #pragma region Character Set Arguments /** - * @brief Find the first occurrence of a character from a set. - * @param skip Number of characters to skip before the search. + * @brief Find the first occurrence of a character from a @p `set`. + * @param[in] skip Number of characters to skip before the search. * @warning The behavior is @b undefined if `skip > size()`. */ size_type find_first_of(byteset set, size_type skip = 0) const noexcept { return view().find_first_of(set, skip); } /** - * @brief Find the first occurrence of a character outside a set. - * @param skip The number of first characters to be skipped. + * @brief Find the first occurrence of a character outside a @p `set`. + * @param[in] skip The number of first characters to be skipped. * @warning The behavior is @b undefined if `skip > size()`. */ size_type find_first_not_of(byteset set, size_type skip = 0) const noexcept { return view().find_first_not_of(set, skip); } - /** - * @brief Find the last occurrence of a character from a set. - */ + /** @brief Find the last occurrence of a character from a @p `set`. */ size_type find_last_of(byteset set) const noexcept { return view().find_last_of(set); } - /** - * @brief Find the last occurrence of a character outside a set. - */ + /** @brief Find the last occurrence of a character outside a @p `set`. */ size_type find_last_not_of(byteset set) const noexcept { return view().find_last_not_of(set); } /** - * @brief Find the last occurrence of a character from a set. - * @param until The offset of the last character to be considered. + * @brief Find the last occurrence of a character from a @p `set`. + * @param[in] until The offset of the last character to be considered. */ size_type find_last_of(byteset set, size_type until) const noexcept { return view().find_last_of(set, until); } /** - * @brief Find the last occurrence of a character outside a set. - * @param until The offset of the last character to be considered. + * @brief Find the last occurrence of a character outside a @p `set`. + * @param[in] until The offset of the last character to be considered. 
*/ size_type find_last_not_of(byteset set, size_type until) const noexcept { return view().find_last_not_of(set, until); @@ -2646,32 +2634,32 @@ class basic_string { #pragma region String Arguments /** - * @brief Find the first occurrence of a character from a ::set. - * @param skip The number of first characters to be skipped. + * @brief Find the first occurrence of a character from the @p `other` string. + * @param[in] skip The number of first characters to be skipped. */ size_type find_first_of(string_view other, size_type skip = 0) const noexcept { return view().find_first_of(other, skip); } /** - * @brief Find the first occurrence of a character outside a ::set. - * @param skip The number of first characters to be skipped. + * @brief Find the first occurrence of a character outside the @p `other` string. + * @param[in] skip The number of first characters to be skipped. */ size_type find_first_not_of(string_view other, size_type skip = 0) const noexcept { return view().find_first_not_of(other, skip); } /** - * @brief Find the last occurrence of a character from a ::set. - * @param until The offset of the last character to be considered. + * @brief Find the last occurrence of a character from the @p `other` string. + * @param[in] until The offset of the last character to be considered. */ size_type find_last_of(string_view other, size_type until = npos) const noexcept { return view().find_last_of(other, until); } /** - * @brief Find the last occurrence of a character outside a ::set. - * @param until The offset of the last character to be considered. + * @brief Find the last occurrence of a character outside the @p `other` string. + * @param[in] until The offset of the last character to be considered. */ size_type find_last_not_of(string_view other, size_type until = npos) const noexcept { return view().find_last_not_of(other, until); @@ -2681,8 +2669,8 @@ class basic_string { #pragma region C Style Arguments /** - * @brief Find the first occurrence of a character from a set. - * @param skip The number of first characters to be skipped. + * @brief Find the first occurrence of a character from a set. + * @param[in] skip The number of first characters to be skipped. * @warning The behavior is @b undefined if `skip > size()`. */ size_type find_first_of(const_pointer other, size_type skip, size_type count) const noexcept { @@ -2690,8 +2678,8 @@ class basic_string { } /** - * @brief Find the first occurrence of a character outside a set. - * @param skip The number of first characters to be skipped. + * @brief Find the first occurrence of a character outside a set. + * @param[in] skip The number of first characters to be skipped. * @warning The behavior is @b undefined if `skip > size()`. */ size_type find_first_not_of(const_pointer other, size_type skip, size_type count) const noexcept { @@ -2699,16 +2687,16 @@ class basic_string { } /** - * @brief Find the last occurrence of a character from a set. - * @param until The number of first characters to be considered. + * @brief Find the last occurrence of a character from a set. + * @param[in] until The number of first characters to be considered. */ size_type find_last_of(const_pointer other, size_type until, size_type count) const noexcept { return view().find_last_of(other, until, count); } /** - * @brief Find the last occurrence of a character outside a set. - * @param until The number of first characters to be considered. + * @brief Find the last occurrence of a character outside a set. 
+ * @param[in] until The number of first characters to be considered. */ size_type find_last_not_of(const_pointer other, size_type until, size_type count) const noexcept { return view().find_last_not_of(other, until, count); @@ -2718,8 +2706,8 @@ class basic_string { #pragma region Slicing /** - * @brief Python-like convenience function, dropping prefix formed of given characters. - * Similar to `boost::algorithm::trim_left_if(str, is_any_of(set))`. + * @brief Python-like convenience function, dropping prefix formed of given characters. + * @see Similar to `boost::algorithm::trim_left_if(str, is_any_of(set))`. */ basic_string &lstrip(byteset set) noexcept { auto remaining = view().lstrip(set); @@ -2729,7 +2717,7 @@ class basic_string { /** * @brief Python-like convenience function, dropping suffix formed of given characters. - * Similar to `boost::algorithm::trim_right_if(str, is_any_of(set))`. + * @see Similar to `boost::algorithm::trim_right_if(str, is_any_of(set))`. */ basic_string &rstrip(byteset set) noexcept { auto remaining = view().rstrip(set); @@ -2738,8 +2726,8 @@ class basic_string { } /** - * @brief Python-like convenience function, dropping both the prefix & the suffix formed of given characters. - * Similar to `boost::algorithm::trim_if(str, is_any_of(set))`. + * @brief Python-like convenience function, dropping both the prefix & the suffix formed of given characters. + * @see Similar to `boost::algorithm::trim_if(str, is_any_of(set))`. */ basic_string &strip(byteset set) noexcept { return lstrip(set).rstrip(set); } @@ -2750,15 +2738,15 @@ class basic_string { #pragma region Non STL API /** - * @brief Resizes the string to a specified number of characters, padding with the specified character if needed. - * @param count The new size of the string. - * @param character The character to fill new elements with, if expanding. Defaults to null character. + * @brief Resizes the string to a specified number of characters, padding with the specified character if needed. + * @param[in] count The new size of the string. + * @param[in] character The character to fill new elements with, if expanding. Defaults to null character. * @return `true` if the resizing was successful, `false` otherwise. */ bool try_resize(size_type count, value_type character = '\0') noexcept; /** - * @brief Attempts to reduce memory usage by freeing unused memory. + * @brief Attempts to reduce memory usage by freeing unused memory. * @return `true` if the operation was successful and potentially reduced the memory footprint, `false` otherwise. */ bool try_shrink_to_fit() noexcept { @@ -2766,8 +2754,8 @@ class basic_string { } /** - * @brief Attempts to reserve enough space for a specified number of characters. - * @param capacity The new capacity to reserve. + * @brief Attempts to reserve enough space for a specified number of characters. + * @param[in] capacity The new capacity to reserve. * @return `true` if the reservation was successful, `false` otherwise. */ bool try_reserve(size_type capacity) noexcept { @@ -2775,44 +2763,44 @@ class basic_string { } /** - * @brief Assigns a new value to the string, replacing its current contents. - * @param other The string view whose contents to assign. + * @brief Assigns a new value to the string, replacing its current contents. + * @param[in] other The string view whose contents to assign. * @return `true` if the assignment was successful, `false` otherwise. 
*/ bool try_assign(string_view other) noexcept; /** - * @brief Assigns a concatenated sequence to the string, replacing its current contents. - * @param other The concatenation object representing the sequence to assign. + * @brief Assigns a concatenated sequence to the string, replacing its current contents. + * @param[in] other The concatenation object representing the sequence to assign. * @return `true` if the assignment was successful, `false` otherwise. */ template bool try_assign(concatenation const &other) noexcept; /** - * @brief Attempts to add a single character to the end of the string. - * @param c The character to add. + * @brief Attempts to add a single character to the end of the string. + * @param[in] c The character to add. * @return `true` if the character was successfully added, `false` otherwise. */ bool try_push_back(char_type c) noexcept; /** - * @brief Attempts to append a given character array to the string. - * @param str The pointer to the array of characters to append. - * @param length The number of characters to append. + * @brief Attempts to append a given character array to the string. + * @param[in] str The pointer to the array of characters to append. + * @param[in] length The number of characters to append. * @return `true` if the append operation was successful, `false` otherwise. */ bool try_append(const_pointer str, size_type length) noexcept; /** - * @brief Attempts to append a string view to the string. - * @param str The string view to append. + * @brief Attempts to append a string view to the string. + * @param[in] str The string view to append. * @return `true` if the append operation was successful, `false` otherwise. */ bool try_append(string_view str) noexcept { return try_append(str.data(), str.size()); } /** - * @brief Clears the contents of the string and resets its length to 0. + * @brief Clears the contents of the string and resets its length to 0. * @return Always returns `true` as this operation cannot fail under normal conditions. */ bool try_clear() noexcept { @@ -2821,7 +2809,7 @@ class basic_string { } /** - * @brief Erases @b (in-place) a range of characters defined with signed offsets. + * @brief Erases @b (in-place) a range of characters defined with signed offsets. * @return Number of characters removed. */ size_type try_erase(difference_type signed_start_offset = 0, difference_type signed_end_offset = npos) noexcept { @@ -2833,7 +2821,7 @@ class basic_string { } /** - * @brief Inserts @b (in-place) a range of characters at a given signed offset. + * @brief Inserts @b (in-place) a range of characters at a given signed offset. * @return `true` if the insertion was successful, `false` otherwise. */ bool try_insert(difference_type signed_offset, string_view string) noexcept { @@ -2866,16 +2854,13 @@ class basic_string { #pragma region STL Interfaces - /** - * @brief Clears the string contents, but @b no deallocations happen. - */ + /** @brief Clears the string contents, but @b no deallocations happen. */ void clear() noexcept { sz_string_erase(&string_, 0, SZ_SIZE_MAX); } /** - * @brief Resizes the string to the given size, filling the new space with the given character, - * or NULL-character if nothing is provided. - * @throw `std::length_error` if the string is too long. - * @throw `std::bad_alloc` if the allocation fails. + * @brief Resizes the string to match @p count, filling the new space with the given @p character. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. 
*/ void resize(size_type count, value_type character = '\0') noexcept(false) { if (count > max_size()) throw std::length_error("sz::basic_string::resize"); @@ -2883,16 +2868,16 @@ class basic_string { } /** - * @brief Reclaims the unused memory, if any. - * @throw `std::bad_alloc` if the allocation fails. + * @brief Reclaims the unused memory, if any. + * @throw `std::bad_alloc` if the allocation fails. */ void shrink_to_fit() noexcept(false) { if (!try_shrink_to_fit()) throw std::bad_alloc(); } /** - * @brief Informs the string object of a planned change in size, so that it pre-allocate once. - * @throw `std::length_error` if the string is too long. + * @brief Informs the string object of a planned change in size, so that it can pre-allocate once. + * @throw `std::length_error` if the string is too long. */ void reserve(size_type capacity) noexcept(false) { if (capacity > max_size()) throw std::length_error("sz::basic_string::reserve"); @@ -2900,10 +2885,10 @@ class basic_string { } /** - * @brief Inserts @b (in-place) a ::character multiple times at the given offset. - * @throw `std::out_of_range` if `offset > size()`. - * @throw `std::length_error` if the string is too long. - * @throw `std::bad_alloc` if the allocation fails. + * @brief Inserts @b (in-place) a @p character multiple times at the given offset. + * @throw `std::out_of_range` if `offset > size()`. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. */ basic_string &insert(size_type offset, size_type repeats, char_type character) noexcept(false) { if (offset > size()) throw std::out_of_range("sz::basic_string::insert"); @@ -2916,10 +2901,10 @@ class basic_string { } /** - * @brief Inserts @b (in-place) a range of characters at the given offset. - * @throw `std::out_of_range` if `offset > size()`. - * @throw `std::length_error` if the string is too long. - * @throw `std::bad_alloc` if the allocation fails. + * @brief Inserts @b (in-place) a range of characters at the given offset. + * @throw `std::out_of_range` if `offset > size()`. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. */ basic_string &insert(size_type offset, string_view other) noexcept(false) { if (offset > size()) throw std::out_of_range("sz::basic_string::insert"); @@ -2933,20 +2918,20 @@ class basic_string { } /** - * @brief Inserts @b (in-place) a range of characters at the given offset. - * @throw `std::out_of_range` if `offset > size()`. - * @throw `std::length_error` if the string is too long. - * @throw `std::bad_alloc` if the allocation fails. + * @brief Inserts @b (in-place) a range of characters at the given offset. + * @throw `std::out_of_range` if `offset > size()`. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. */ basic_string &insert(size_type offset, const_pointer start, size_type length) noexcept(false) { return insert(offset, string_view(start, length)); } /** - * @brief Inserts @b (in-place) a slice of another string at the given offset. - * @throw `std::out_of_range` if `offset > size()` or `other_index > other.size()`. - * @throw `std::length_error` if the string is too long. - * @throw `std::bad_alloc` if the allocation fails. + * @brief Inserts @b (in-place) a slice of another string at the given offset. + * @throw `std::out_of_range` if `offset > size()` or `other_index > other.size()`. + * @throw `std::length_error` if the string is too long. 
+ * @throw `std::bad_alloc` if the allocation fails. */ basic_string &insert(size_type offset, string_view other, size_type other_index, size_type count = npos) noexcept(false) { @@ -2966,10 +2951,10 @@ class basic_string { } /** - * @brief Inserts @b (in-place) a ::character multiple times at the given iterator position. - * @throw `std::out_of_range` if `pos > size()` or `other_index > other.size()`. - * @throw `std::length_error` if the string is too long. - * @throw `std::bad_alloc` if the allocation fails. + * @brief Inserts @b (in-place) a @p character multiple times at the given iterator position. + * @throw `std::out_of_range` if `pos > size()` or `other_index > other.size()`. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. */ iterator insert(const_iterator it, size_type repeats, char_type character) noexcept(false) { auto pos = range_length(cbegin(), it); @@ -2978,10 +2963,10 @@ class basic_string { } /** - * @brief Inserts @b (in-place) a range at the given iterator position. - * @throw `std::out_of_range` if `pos > size()` or `other_index > other.size()`. - * @throw `std::length_error` if the string is too long. - * @throw `std::bad_alloc` if the allocation fails. + * @brief Inserts @b (in-place) a range at the given iterator position. + * @throw `std::out_of_range` if `pos > size()` or `other_index > other.size()`. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. */ template iterator insert(const_iterator it, input_iterator first, input_iterator last) noexcept(false) { @@ -3001,19 +2986,19 @@ class basic_string { } /** - * @brief Inserts @b (in-place) an initializer list of characters. - * @throw `std::out_of_range` if `pos > size()` or `other_index > other.size()`. - * @throw `std::length_error` if the string is too long. - * @throw `std::bad_alloc` if the allocation fails. + * @brief Inserts @b (in-place) an initializer list of characters. + * @throw `std::out_of_range` if `pos > size()` or `other_index > other.size()`. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. */ iterator insert(const_iterator it, std::initializer_list list) noexcept(false) { return insert(it, list.begin(), list.end()); } /** - * @brief Erases @b (in-place) the given range of characters. + * @brief Erases @b (in-place) the given range of characters. * @throws `std::out_of_range` if `pos > size()`. - * @see `try_erase_slice` for a cleaner exception-less alternative. + * @sa `try_erase_slice` for a cleaner exception-less alternative. */ basic_string &erase(size_type pos = 0, size_type count = npos) noexcept(false) { if (!count || empty()) return *this; @@ -3023,7 +3008,7 @@ class basic_string { } /** - * @brief Erases @b (in-place) the given range of characters. + * @brief Erases @b (in-place) the given range of characters. * @return Iterator pointing following the erased character, or end() if no such character exists. */ iterator erase(const_iterator first, const_iterator last) noexcept { @@ -3034,16 +3019,16 @@ class basic_string { } /** - * @brief Erases @b (in-place) the one character at a given postion. + * @brief Erases @b (in-place) the one character at a given position. * @return Iterator pointing following the erased character, or end() if no such character exists. 
*/ iterator erase(const_iterator pos) noexcept { return erase(pos, pos + 1); } /** - * @brief Replaces @b (in-place) a range of characters with a given string. + * @brief Replaces @b (in-place) a range of characters with a given string. * @throws `std::out_of_range` if `pos > size()`. * @throws `std::length_error` if the string is too long. - * @see `try_replace` for a cleaner exception-less alternative. + * @sa `try_replace` for a cleaner exception-less alternative. */ basic_string &replace(size_type pos, size_type count, string_view const &str) noexcept(false) { if (pos > size()) throw std::out_of_range("sz::basic_string::replace"); @@ -3054,20 +3039,20 @@ class basic_string { } /** - * @brief Replaces @b (in-place) a range of characters with a given string. + * @brief Replaces @b (in-place) a range of characters with a given string. * @throws `std::out_of_range` if `pos > size()`. * @throws `std::length_error` if the string is too long. - * @see `try_replace` for a cleaner exception-less alternative. + * @sa `try_replace` for a cleaner exception-less alternative. */ basic_string &replace(const_iterator first, const_iterator last, string_view const &str) noexcept(false) { return replace(range_length(cbegin(), first), last - first, str); } /** - * @brief Replaces @b (in-place) a range of characters with a given string. + * @brief Replaces @b (in-place) a range of characters with a given string. * @throws `std::out_of_range` if `pos > size()` or `pos2 > str.size()`. * @throws `std::length_error` if the string is too long. - * @see `try_replace` for a cleaner exception-less alternative. + * @sa `try_replace` for a cleaner exception-less alternative. */ basic_string &replace(size_type pos, size_type count, string_view const &str, size_type pos2, size_type count2 = npos) noexcept(false) { @@ -3075,20 +3060,20 @@ class basic_string { } /** - * @brief Replaces @b (in-place) a range of characters with a given string. + * @brief Replaces @b (in-place) a range of characters with a given string. * @throws `std::out_of_range` if `pos > size()`. * @throws `std::length_error` if the string is too long. - * @see `try_replace` for a cleaner exception-less alternative. + * @sa `try_replace` for a cleaner exception-less alternative. */ basic_string &replace(size_type pos, size_type count, const_pointer cstr, size_type count2) noexcept(false) { return replace(pos, count, string_view(cstr, count2)); } /** - * @brief Replaces @b (in-place) a range of characters with a given string. + * @brief Replaces @b (in-place) a range of characters with a given string. * @throws `std::out_of_range` if `pos > size()`. * @throws `std::length_error` if the string is too long. - * @see `try_replace` for a cleaner exception-less alternative. + * @sa `try_replace` for a cleaner exception-less alternative. */ basic_string &replace(const_iterator first, const_iterator last, const_pointer cstr, size_type count2) noexcept(false) { @@ -3096,30 +3081,30 @@ class basic_string { } /** - * @brief Replaces @b (in-place) a range of characters with a given string. + * @brief Replaces @b (in-place) a range of characters with a given string. * @throws `std::out_of_range` if `pos > size()`. * @throws `std::length_error` if the string is too long. - * @see `try_replace` for a cleaner exception-less alternative. + * @sa `try_replace` for a cleaner exception-less alternative. 
*/ basic_string &replace(size_type pos, size_type count, const_pointer cstr) noexcept(false) { return replace(pos, count, string_view(cstr)); } /** - * @brief Replaces @b (in-place) a range of characters with a given string. + * @brief Replaces @b (in-place) a range of characters with a given string. * @throws `std::out_of_range` if `pos > size()`. * @throws `std::length_error` if the string is too long. - * @see `try_replace` for a cleaner exception-less alternative. + * @sa `try_replace` for a cleaner exception-less alternative. */ basic_string &replace(const_iterator first, const_iterator last, const_pointer cstr) noexcept(false) { return replace(range_length(cbegin(), first), last - first, string_view(cstr)); } /** - * @brief Replaces @b (in-place) a range of characters with a repetition of given characters. + * @brief Replaces @b (in-place) a range of characters with a repetition of given characters. * @throws `std::out_of_range` if `pos > size()`. * @throws `std::length_error` if the string is too long. - * @see `try_replace` for a cleaner exception-less alternative. + * @sa `try_replace` for a cleaner exception-less alternative. */ basic_string &replace(size_type pos, size_type count, size_type count2, char_type character) noexcept(false) { if (pos > size()) throw std::out_of_range("sz::basic_string::replace"); @@ -3130,10 +3115,10 @@ class basic_string { } /** - * @brief Replaces @b (in-place) a range of characters with a repetition of given characters. + * @brief Replaces @b (in-place) a range of characters with a repetition of given characters. * @throws `std::out_of_range` if `pos > size()`. * @throws `std::length_error` if the string is too long. - * @see `try_replace` for a cleaner exception-less alternative. + * @sa `try_replace` for a cleaner exception-less alternative. */ basic_string &replace(const_iterator first, const_iterator last, size_type count2, char_type character) noexcept(false) { @@ -3141,10 +3126,10 @@ class basic_string { } /** - * @brief Replaces @b (in-place) a range of characters with a given string. + * @brief Replaces @b (in-place) a range of characters with a given string. * @throws `std::out_of_range` if `pos > size()`. * @throws `std::length_error` if the string is too long. - * @see `try_replace` for a cleaner exception-less alternative. + * @sa `try_replace` for a cleaner exception-less alternative. */ template basic_string &replace(const_iterator first, const_iterator last, input_iterator first2, @@ -3160,10 +3145,10 @@ class basic_string { } /** - * @brief Replaces @b (in-place) a range of characters with a given initializer list. + * @brief Replaces @b (in-place) a range of characters with a given initializer list. * @throws `std::out_of_range` if `pos > size()`. * @throws `std::length_error` if the string is too long. - * @see `try_replace` for a cleaner exception-less alternative. + * @sa `try_replace` for a cleaner exception-less alternative. */ basic_string &replace(const_iterator first, const_iterator last, std::initializer_list list) noexcept(false) { @@ -3171,9 +3156,9 @@ class basic_string { } /** - * @brief Appends the given character at the end. - * @throw `std::length_error` if the string is too long. - * @throw `std::bad_alloc` if the allocation fails. + * @brief Appends the given character at the end. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. 
*/ void push_back(char_type ch) noexcept(false) { if (size() == max_size()) throw std::length_error("string::push_back"); @@ -3181,16 +3166,16 @@ class basic_string { } /** - * @brief Removes the last character from the string. + * @brief Removes the last character from the string. * @warning The behavior is @b undefined if the string is empty. */ void pop_back() noexcept { sz_string_erase(&string_, size() - 1, 1); } /** - * @brief Overwrites the string with the given string. - * @throw `std::length_error` if the string is too long. - * @throw `std::bad_alloc` if the allocation fails. - * @see `try_assign` for a cleaner exception-less alternative. + * @brief Overwrites the string with the given string. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. + * @sa `try_assign` for a cleaner exception-less alternative. */ basic_string &assign(string_view other) noexcept(false) { if (!try_assign(other)) throw std::bad_alloc(); @@ -3198,10 +3183,10 @@ class basic_string { } /** - * @brief Overwrites the string with the given repeated character. - * @throw `std::length_error` if the string is too long. - * @throw `std::bad_alloc` if the allocation fails. - * @see `try_assign` for a cleaner exception-less alternative. + * @brief Overwrites the string with the given repeated character. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. + * @sa `try_assign` for a cleaner exception-less alternative. */ basic_string &assign(size_type repeats, char_type character) noexcept(false) { resize(repeats, character); @@ -3210,28 +3195,28 @@ class basic_string { } /** - * @brief Overwrites the string with the given string. - * @throw `std::length_error` if the string is too long. - * @throw `std::bad_alloc` if the allocation fails. - * @see `try_assign` for a cleaner exception-less alternative. + * @brief Overwrites the string with the given string. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. + * @sa `try_assign` for a cleaner exception-less alternative. */ basic_string &assign(const_pointer other, size_type length) noexcept(false) { return assign({other, length}); } /** - * @brief Overwrites the string with the given string. - * @throw `std::length_error` if the string is too long or `pos > str.size()`. - * @throw `std::bad_alloc` if the allocation fails. - * @see `try_assign` for a cleaner exception-less alternative. + * @brief Overwrites the string with the given string. + * @throw `std::length_error` if the string is too long or `pos > str.size()`. + * @throw `std::bad_alloc` if the allocation fails. + * @sa `try_assign` for a cleaner exception-less alternative. */ basic_string &assign(string_view str, size_type pos, size_type count = npos) noexcept(false) { return assign(str.substr(pos, count)); } /** - * @brief Overwrites the string with the given iterator range. - * @throw `std::length_error` if the string is too long. - * @throw `std::bad_alloc` if the allocation fails. - * @see `try_assign` for a cleaner exception-less alternative. + * @brief Overwrites the string with the given iterator range. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. + * @sa `try_assign` for a cleaner exception-less alternative. 
*/ template basic_string &assign(input_iterator first, input_iterator last) noexcept(false) { @@ -3241,20 +3226,20 @@ class basic_string { } /** - * @brief Overwrites the string with the given initializer list. - * @throw `std::length_error` if the string is too long. - * @throw `std::bad_alloc` if the allocation fails. - * @see `try_assign` for a cleaner exception-less alternative. + * @brief Overwrites the string with the given initializer list. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. + * @sa `try_assign` for a cleaner exception-less alternative. */ basic_string &assign(std::initializer_list list) noexcept(false) { return assign(list.begin(), list.end()); } /** - * @brief Appends to the end of the current string. - * @throw `std::length_error` if the string is too long. - * @throw `std::bad_alloc` if the allocation fails. - * @see `try_append` for a cleaner exception-less alternative. + * @brief Appends to the end of the current string. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. + * @sa `try_append` for a cleaner exception-less alternative. */ basic_string &append(string_view str) noexcept(false) { if (!try_append(str)) throw std::bad_alloc(); @@ -3262,36 +3247,36 @@ class basic_string { } /** - * @brief Appends to the end of the current string. - * @throw `std::length_error` if the string is too long or `pos > str.size()`. - * @throw `std::bad_alloc` if the allocation fails. - * @see `try_append` for a cleaner exception-less alternative. + * @brief Appends to the end of the current string. + * @throw `std::length_error` if the string is too long or `pos > str.size()`. + * @throw `std::bad_alloc` if the allocation fails. + * @sa `try_append` for a cleaner exception-less alternative. */ basic_string &append(string_view str, size_type pos, size_type length = npos) noexcept(false) { return append(str.substr(pos, length)); } /** - * @brief Appends to the end of the current string. - * @throw `std::length_error` if the string is too long. - * @throw `std::bad_alloc` if the allocation fails. - * @see `try_append` for a cleaner exception-less alternative. + * @brief Appends to the end of the current string. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. + * @sa `try_append` for a cleaner exception-less alternative. */ basic_string &append(const_pointer str, size_type length) noexcept(false) { return append({str, length}); } /** - * @brief Appends to the end of the current string. - * @throw `std::length_error` if the string is too long. - * @throw `std::bad_alloc` if the allocation fails. - * @see `try_append` for a cleaner exception-less alternative. + * @brief Appends to the end of the current string. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. + * @sa `try_append` for a cleaner exception-less alternative. */ basic_string &append(const_pointer str) noexcept(false) { return append(string_view(str)); } /** - * @brief Appends a repeated character to the end of the current string. - * @throw `std::length_error` if the string is too long. - * @throw `std::bad_alloc` if the allocation fails. - * @see `try_append` for a cleaner exception-less alternative. + * @brief Appends a repeated character to the end of the current string. + * @throw `std::length_error` if the string is too long. 
+ * @throw `std::bad_alloc` if the allocation fails. + * @sa `try_append` for a cleaner exception-less alternative. */ basic_string &append(size_type repeats, char_type ch) noexcept(false) { resize(size() + repeats, ch); @@ -3299,20 +3284,20 @@ class basic_string { } /** - * @brief Appends to the end of the current string. - * @throw `std::length_error` if the string is too long. - * @throw `std::bad_alloc` if the allocation fails. - * @see `try_append` for a cleaner exception-less alternative. + * @brief Appends to the end of the current string. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. + * @sa `try_append` for a cleaner exception-less alternative. */ basic_string &append(std::initializer_list other) noexcept(false) { return append(other.begin(), other.end()); } /** - * @brief Appends to the end of the current string. - * @throw `std::length_error` if the string is too long. - * @throw `std::bad_alloc` if the allocation fails. - * @see `try_append` for a cleaner exception-less alternative. + * @brief Appends to the end of the current string. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. + * @sa `try_append` for a cleaner exception-less alternative. */ template basic_string &append(input_iterator first, input_iterator last) noexcept(false) { @@ -3348,16 +3333,16 @@ class basic_string { return result; } - /** @brief Hashes the string, equivalent to `std::hash{}(str)`. */ + /** @brief Hashes the string, equivalent to `std::hash{}(str)`. */ size_type hash() const noexcept { return view().hash(); } - /** @brief Aggregates the values of individual bytes of a string. */ + /** @brief Aggregates the values of individual bytes of a string. */ size_type bytesum() const noexcept { return view().bytesum(); } /** * @brief Overwrites the string with random binary data. * - * @param nonce "Number used ONCE" to initialize the random number generator, @b don't repeat it! + * @param[in] nonce "Number used ONCE" to initialize the random number generator, @b don't repeat it! */ basic_string &randomize(sz_u64_t nonce) noexcept { sz_ptr_t start; @@ -3369,8 +3354,11 @@ class basic_string { /** * @brief Overwrites the string with random binary data. - * Produces the nonce from a static variable, incrementing it each time. - * In this case the undefined behaviour in concurrent environments plays in our favor. + * @sa sz_fill_random + * + * This overload produces the nonce from a static variable, incrementing it each time. + * In this case the undefined behaviour in concurrent environments may play in our favor, + * but it's recommended to use the other overload in such cases. */ basic_string &randomize() noexcept { static sz_u64_t nonce = 42; @@ -3378,27 +3366,25 @@ class basic_string { } /** - * @brief Generate a new random string of given length using `std::rand` as the random generator. - * May throw exceptions if the memory allocation fails. - * - * @param length The length of the generated string. - * @param nonce "Number used ONCE" to initialize the random number generator, @b don't repeat it! + * @brief Generate a new random binary string of given @p length. + * @param[in] length The length of the generated string. + * @param[in] nonce "Number used ONCE" to initialize the random number generator, @b don't repeat it! + * @throw `std::bad_alloc` if the allocation fails. 
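+ *
+ * A minimal illustrative sketch (assuming the `sz::string` alias of this template):
+ * @code{.cpp}
+ * sz::string salt = sz::string::random(16, 42u); // explicit nonce, reproducible output
+ * sz::string token = sz::string::random(32);     // nonce taken from an internal counter
+ * @endcode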
*/ static basic_string random(size_type length, sz_u64_t nonce) noexcept(false) { return basic_string(length, '\0').randomize(nonce); } /** - * @brief Generate a new random string of given length using the provided random number generator. - * May throw exceptions if the memory allocation fails. - * - * @param length The length of the generated string. + * @brief Generate a new random binary string of given @p length. + * @param[in] length The length of the generated string. + * @throw `std::bad_alloc` if the allocation fails. */ static basic_string random(size_type length) noexcept(false) { return basic_string(length, '\0').randomize(); } /** - * @brief Replaces @b (in-place) all occurrences of a given string with the ::replacement string. - * Similar to `boost::algorithm::replace_all` and Python's `str.replace`. + * @brief Replaces @b (in-place) all occurrences of a given string with the ::replacement string. + * @see Similar to `boost::algorithm::replace_all` and Python's `str.replace`. * * The implementation is not as composable, as using search ranges combined with a replacing mapping for matches, * and might be suboptimal, if you are exporting the cleaned-up string to another buffer. @@ -3410,8 +3396,8 @@ class basic_string { } /** - * @brief Replaces @b (in-place) all occurrences of a given character set with the ::replacement string. - * Similar to `boost::algorithm::replace_all` and Python's `str.replace`. + * @brief Replaces @b (in-place) all occurrences of a given character set with the ::replacement string. + * @see Similar to `boost::algorithm::replace_all` and Python's `str.replace`. * * The implementation is not as composable, as using search ranges combined with a replacing mapping for matches, * and might be suboptimal, if you are exporting the cleaned-up string to another buffer. @@ -3423,8 +3409,8 @@ class basic_string { } /** - * @brief Replaces @b (in-place) all occurrences of a given string with the ::replacement string. - * Similar to `boost::algorithm::replace_all` and Python's `str.replace`. + * @brief Replaces @b (in-place) all occurrences of a given string with the ::replacement string. + * @see Similar to `boost::algorithm::replace_all` and Python's `str.replace`. * * The implementation is not as composable, as using search ranges combined with a replacing mapping for matches, * and might be suboptimal, if you are exporting the cleaned-up string to another buffer. @@ -3435,8 +3421,8 @@ class basic_string { } /** - * @brief Replaces @b (in-place) all occurrences of a given character set with the ::replacement string. - * Similar to `boost::algorithm::replace_all` and Python's `str.replace`. + * @brief Replaces @b (in-place) all occurrences of a given character set with the ::replacement string. + * @see Similar to `boost::algorithm::replace_all` and Python's `str.replace`. * * The implementation is not as composable, as using search ranges combined with a replacing mapping for matches, * and might be suboptimal, if you are exporting the cleaned-up string to another buffer. @@ -3447,7 +3433,8 @@ class basic_string { } /** - * @brief Replaces @b (in-place) all characters in the string using the provided lookup table. + * @brief Replaces @b (in-place) all characters in the string using the provided lookup @p table. 
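+ *
+ * An illustrative sketch; the `identity()` factory and the `operator[]` access are assumptions
+ * about how the table is constructed, adjust to the actual `look_up_table` interface:
+ * @code{.cpp}
+ * sz::string text = "Hello, World!";
+ * auto table = sz::look_up_table::identity(); // assumed: start from a byte-identity mapping
+ * table['o'] = '0';
+ * text.transform(table); // "Hell0, W0rld!"
+ * @endcode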
+ * @sa sz_lookup */ basic_string &transform(look_up_table const &table) noexcept { transform(table, data()); @@ -3455,8 +3442,9 @@ class basic_string { } /** - * @brief Maps all characters in the current string into another buffer using the provided lookup table. - * @param output The buffer to write the transformed string into. + * @brief Maps all characters in the current string into the @p output buffer using the provided lookup @p table. + * @param[in] output The buffer to write the transformed string into. + * @sa sz_lookup */ void transform(look_up_table const &table, pointer output) const noexcept { sz_ptr_t start; @@ -3470,8 +3458,8 @@ class basic_string { bool try_replace_all_(pattern_type pattern, string_view replacement) noexcept; /** - * @brief Tries to prepare the string for a replacement of a given range with a new string. - * The allocation may occur, if the replacement is longer than the replaced range. + * @brief Tries to prepare the string for a replacement of a given range with a new string. + * @warning A memory allocation may occur, if the replacement is longer than the replaced range. */ bool try_preparing_replacement(size_type offset, size_type length, size_type new_length) noexcept; }; @@ -3738,8 +3726,8 @@ struct string_view_less { }; /** - * @brief Helper function-like object to check equality between string-view convertible objects with StringZilla. - * @see Similar to `std::equal_to`: https://en.cppreference.com/w/cpp/utility/functional/equal_to + * @brief Helper function-like object to check equality between string-view convertible objects with StringZilla. + * @see Similar to `std::equal_to`: https://en.cppreference.com/w/cpp/utility/functional/equal_to * * Unlike the STL analog, doesn't require C++14 or including the heavy `` header. * Can be used to combine STL classes with StringZilla logic, like: @@ -3750,8 +3738,8 @@ struct string_view_equal_to { }; /** - * @brief Helper function-like object to hash string-view convertible objects with StringZilla. - * @see Similar to `std::hash`: https://en.cppreference.com/w/cpp/utility/functional/hash + * @brief Helper function-like object to hash string-view convertible objects with StringZilla. + * @see Similar to `std::hash`: https://en.cppreference.com/w/cpp/utility/functional/hash * * Unlike the STL analog, doesn't require C++14 or including the heavy `` header. * Can be used to combine STL classes with StringZilla logic, like: @@ -3761,7 +3749,7 @@ struct string_view_hash { std::size_t operator()(string_view str) const noexcept { return str.hash(); } }; -/** @brief SFINAE-type used to infer the resulting type of concatenating multiple string together. */ +/** @brief SFINAE-type used to infer the resulting type of concatenating multiple string together. */ template struct concatenation_result {}; @@ -3776,8 +3764,8 @@ struct concatenation_result { }; /** - * @brief Concatenates two strings into a template expression. - * @see `concatenation` class for more details. + * @brief Concatenates two strings into a template expression. + * @sa `concatenation` class for more details. */ template concatenation concatenate(first_type &&first, second_type &&second) noexcept(false) { @@ -3785,8 +3773,8 @@ concatenation concatenate(first_type &&first, second_ty } /** - * @brief Concatenates two or more strings into a template expression. - * @see `concatenation` class for more details. + * @brief Concatenates two or more strings into a template expression. + * @sa `concatenation` class for more details. 
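+ *
+ * An illustrative sketch; assigning into an `sz::string` is assumed to materialize the lazy expression:
+ * @code{.cpp}
+ * sz::string url = sz::concatenate(sz::string_view("https://"), sz::string_view("example"), sz::string_view(".com"));
+ * @endcode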
*/ template typename concatenation_result::type concatenate( @@ -3806,8 +3794,8 @@ typename concatenation_result::type } /** - * @brief Calculates the Hamming edit distance in @b bytes between two strings. - * @see sz_levenshtein_distance + * @brief Calculates the Hamming edit distance in @b bytes between two strings. + * @sa sz_levenshtein_distance */ template std::size_t hamming_distance( // @@ -3819,8 +3807,8 @@ std::size_t hamming_distance( } /** - * @brief Calculates the Hamming edit distance in @b bytes between two strings. - * @see sz_levenshtein_distance + * @brief Calculates the Hamming edit distance in @b bytes between two strings. + * @sa sz_levenshtein_distance */ template ::type>> std::size_t hamming_distance( // @@ -3830,8 +3818,8 @@ std::size_t hamming_distance( } /** - * @brief Calculates the Hamming edit distance in @b unicode codepoints between two strings. - * @see sz_hamming_distance_utf8 + * @brief Calculates the Hamming edit distance in @b unicode codepoints between two strings. + * @sa sz_hamming_distance_utf8 */ template std::size_t hamming_distance_utf8( // @@ -3842,8 +3830,8 @@ std::size_t hamming_distance_utf8( // } /** - * @brief Calculates the Hamming edit distance in @b unicode codepoints between two strings. - * @see sz_levenshtein_distance + * @brief Calculates the Hamming edit distance in @b unicode codepoints between two strings. + * @sa sz_levenshtein_distance */ template ::type>> std::size_t hamming_distance_utf8( // @@ -3853,8 +3841,8 @@ std::size_t hamming_distance_utf8( // } /** - * @brief Calculates the Levenshtein edit distance in @b bytes between two strings. - * @see sz_levenshtein_distance + * @brief Calculates the Levenshtein edit distance in @b bytes between two strings. + * @sa sz_levenshtein_distance */ template ::type>> std::size_t edit_distance( // @@ -3870,8 +3858,8 @@ std::size_t edit_distance( // } /** - * @brief Calculates the Levenshtein edit distance in @b bytes between two strings. - * @see sz_levenshtein_distance + * @brief Calculates the Levenshtein edit distance in @b bytes between two strings. + * @sa sz_levenshtein_distance */ template > std::size_t edit_distance( // @@ -3881,8 +3869,8 @@ std::size_t edit_distance( } /** - * @brief Calculates the Levenshtein edit distance in @b unicode codepoints between two strings. - * @see sz_levenshtein_distance_utf8 + * @brief Calculates the Levenshtein edit distance in @b unicode codepoints between two strings. + * @sa sz_levenshtein_distance_utf8 */ template ::type>> std::size_t edit_distance_utf8( // @@ -3898,8 +3886,8 @@ std::size_t edit_distance_utf8( } /** - * @brief Calculates the Levenshtein edit distance in @b unicode codepoints between two strings. - * @see sz_levenshtein_distance_utf8 + * @brief Calculates the Levenshtein edit distance in @b unicode codepoints between two strings. + * @sa sz_levenshtein_distance_utf8 */ template > std::size_t edit_distance_utf8( // @@ -3909,8 +3897,8 @@ std::size_t edit_distance_utf8( } /** - * @brief Calculates the Needleman-Wunsch alignment score between two strings. - * @see sz_needleman_wunsch_score + * @brief Calculates the Needleman-Wunsch alignment score between two strings. + * @sa sz_needleman_wunsch_score */ template ::type>> std::ptrdiff_t alignment_score( // @@ -3932,8 +3920,8 @@ std::ptrdiff_t alignment_score( } /** - * @brief Calculates the Needleman-Wunsch alignment score between two strings. - * @see sz_needleman_wunsch_score + * @brief Calculates the Needleman-Wunsch alignment score between two strings. 
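+ *
+ * With unit costs the score is simply the negated Levenshtein distance. A hedged sketch, with the
+ * argument order assumed to be (first, second, substitution costs, gap cost):
+ * @code{.cpp}
+ * signed char costs[256 * 256];
+ * for (int i = 0; i != 256; ++i)
+ *     for (int j = 0; j != 256; ++j) costs[i * 256 + j] = (i == j) ? 0 : -1;
+ * std::ptrdiff_t score = sz::alignment_score("kitten", "sitting", costs, -1); // == -3
+ * @endcode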
+ * @sa sz_needleman_wunsch_score */ template > std::ptrdiff_t alignment_score( // @@ -3943,10 +3931,10 @@ std::ptrdiff_t alignment_score( } /** - * @brief Overwrites the string slice with random characters from the given alphabet using the random generator. - * - * @param string The string to overwrite. - * @param nonce "Number used ONCE" to initialize the random number generator, @b don't repeat it! + * @brief Overwrites the @p string slice with random bytes. + * @param[in] string The string to overwrite. + * @param[in] nonce "Number used ONCE" to initialize the random number generator, @b don't repeat it! + * @sa sz_fill_random */ template void randomize(basic_string_slice string, sz_u64_t nonce) noexcept { @@ -3955,7 +3943,9 @@ void randomize(basic_string_slice string, sz_u64_t nonce) noexcept { } /** - * @brief Replaces @b (in-place) all characters in the string using the provided lookup table. + * @brief Overwrites the @p string slice with random bytes using `std::rand` for the nonce. + * @param[in] string The string to overwrite. + * @sa sz_fill_random */ template void lookup(basic_string_slice string, basic_look_up_table const &table) noexcept { @@ -3964,7 +3954,8 @@ void lookup(basic_string_slice string, basic_look_up_table void lookup( // @@ -3975,11 +3966,8 @@ void lookup( // } /** - * @brief Overwrites the string slice with random characters from the given alphabet - * using `std::rand` as the random generator. - * - * @param string The string to overwrite. - * @param alphabet A string of characters to choose from. + * @brief Replaces @b (in-place) all characters in the string using the provided lookup table. + * @sa sz_lookup */ template void randomize(basic_string_slice string, string_view alphabet = "abcdefghijklmnopqrstuvwxyz") noexcept { @@ -3989,8 +3977,8 @@ void randomize(basic_string_slice string, string_view alphabet = "ab using sorted_idx_t = sz_sorted_idx_t; /** - * @brief Internal data-structure used to forward the arguments to the `sz_sequence_argsort` function. - * @see argsort + * @brief Internal data-structure used to wrap arbitrary sequential containers with a random-order lookup. + * @sa try_argsort, argsort, try_join, join */ template struct _sequence_args { @@ -4064,8 +4052,8 @@ void hashes_fingerprint( // } /** - * @brief Computes the Rabin-Karp-like rolling binary fingerprint of a string. - * @see sz_hashes + * @brief Computes the Rabin-Karp-like rolling binary fingerprint of a string. + * @sa sz_hashes */ template std::bitset hashes_fingerprint( // @@ -4076,8 +4064,8 @@ std::bitset hashes_fingerprint( // } /** - * @brief Computes the Rabin-Karp-like rolling binary fingerprint of a string. - * @see sz_hashes + * @brief Computes the Rabin-Karp-like rolling binary fingerprint of a string. 
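+ *
+ * Such fingerprints act as cheap similarity sketches: two bitsets can be intersected with bitwise
+ * AND and `count()` without re-reading the strings. A hedged sketch, assuming the bit-width is
+ * passed as the leading template argument, as in the overloads below:
+ * @code{.cpp}
+ * std::bitset<512> a = sz::hashes_fingerprint<512>(sz::string_view("the quick brown fox"), 4);
+ * std::bitset<512> b = sz::hashes_fingerprint<512>(sz::string_view("the quick brown dog"), 4);
+ * std::size_t shared = (a & b).count();
+ * @endcode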
+ * @sa sz_hashes */ template std::bitset hashes_fingerprint(basic_string const &str, std::size_t window_length) noexcept { From 407dd2de067bea93c1144e35bc8e3a8ab670ef64 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Fri, 7 Mar 2025 12:54:55 +0000 Subject: [PATCH 151/751] Docs: Ignore C++ docstring updates blame --- .git-blame-ignore-revs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index 3d26edb4..c583f5fb 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -38,3 +38,5 @@ b007ba571860e1d3737d1478c7f8d66ae1839e36 bd547453122e9f8565e5be15f137e7b0de37caca 22e3d1e34d62d68c1e89df7c8bdc201faa18a9de ecb377541d0c706cf8997faff4f026b07e3f76f3 +0d982a45f842287d7e344f0d8b360f52482017f5 + From b6e4406101cda970659c64a9215b7e81072b2168 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Fri, 7 Mar 2025 13:02:30 +0000 Subject: [PATCH 152/751] Docs: Details on the Unicode range --- include/stringzilla/types.h | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/include/stringzilla/types.h b/include/stringzilla/types.h index a15cf116..f5658035 100644 --- a/include/stringzilla/types.h +++ b/include/stringzilla/types.h @@ -368,7 +368,8 @@ typedef enum { } sz_status_t; /** - * @brief Describes the length of a UTF8 @b rune / character / codepoint in bytes. + * @brief Describes the length of a UTF-8 @b rune / character / codepoint in bytes, which can be 1 to 4. + * @see https://en.wikipedia.org/wiki/UTF-8 */ typedef enum { sz_utf8_invalid_k = 0, //!< Invalid UTF8 character. @@ -378,6 +379,16 @@ typedef enum { sz_utf8_rune_4bytes_k = 4, //!< 4-byte UTF8 character. } sz_rune_length_t; +/** + * @brief Stores a single UTF-8 @b rune / character / codepoint unpacked into @b UTF-32. + * @see https://en.wikipedia.org/wiki/UTF-32 + * + * The theoretical capacity of the underlying numeric type is 4 bytes, with over 4 billion possible states, but: + * - UTF-8, however, in its' largest 4-byte form has only 3+6+6+6 = 21 bits of usable space for 2 million states. + * - Unicode, in turn, has only @b 1'114'112 possible code points from U+0000 to U+10FFFF. + * - Of those, in Unicode 16, only @b 155'063 are assigned characters ~ a little over 17 bits of content. + * That's @b 0.004% of the 32-bit space, so sparse data-structures are encouraged for UTF-8 oriented algorithms. + */ typedef sz_u32_t sz_rune_t; /** From 5c02c4edb7b559ed129047fdb6739f2721fc481b Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 9 Mar 2025 05:07:11 +0000 Subject: [PATCH 153/751] Docs: Formatting --- include/stringzilla/types.h | 148 +++++++++++++++++------------------- 1 file changed, 70 insertions(+), 78 deletions(-) diff --git a/include/stringzilla/types.h b/include/stringzilla/types.h index f5658035..6d693347 100644 --- a/include/stringzilla/types.h +++ b/include/stringzilla/types.h @@ -391,8 +391,15 @@ typedef enum { */ typedef sz_u32_t sz_rune_t; +SZ_PUBLIC sz_rune_t sz_rune_perfect_hash(sz_rune_t rune) { + // TODO: A perfect hashing scheme can be constructed to map a 32-bit rune into an 18-bit representation, + // TODO: that can fit all of the unique values in the Unicode 16 standard. + return rune; +} + /** - * @brief Tiny string-view structure. It's POD type, unlike the `std::string_view`. + * @brief Tiny string-view structure. It's Plain-Old Datatype @b (POD) type, unlike the `std::string_view`. 
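+ *
+ * Being a POD, it can be aggregate-initialized and copied with `memcpy` (illustrative sketch):
+ * @code{.c}
+ * sz_string_view_t hello = {"hello", 5};
+ * @endcode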
+ * @see https://en.cppreference.com/w/cpp/named_req/PODType */ typedef struct sz_string_view_t { sz_cptr_t start; @@ -402,8 +409,8 @@ typedef struct sz_string_view_t { #pragma region Character Sets /** - * @brief Bit-set semi-opaque structure for 256 possible byte values. Useful for filtering and search. - * @sa sz_byteset_init, sz_byteset_add, sz_byteset_contains, sz_byteset_invert + * @brief Bit-set semi-opaque structure for 256 possible byte values. Useful for filtering and search. + * @sa sz_byteset_init, sz_byteset_add, sz_byteset_contains, sz_byteset_invert * * Example usage: * @@ -426,22 +433,22 @@ typedef union sz_byteset_t { sz_u8_t _u8s[32]; } sz_byteset_t; -/** @brief Initializes a bit-set to an empty collection, meaning - all characters are banned. */ +/** @brief Initializes a bit-set to an empty collection, meaning - all characters are banned. */ SZ_PUBLIC void sz_byteset_init(sz_byteset_t *s) { s->_u64s[0] = s->_u64s[1] = s->_u64s[2] = s->_u64s[3] = 0; } -/** @brief Initializes a bit-set to all ASCII character. */ +/** @brief Initializes a bit-set to all ASCII character. */ SZ_PUBLIC void sz_byteset_init_ascii(sz_byteset_t *s) { s->_u64s[0] = s->_u64s[1] = 0xFFFFFFFFFFFFFFFFull; s->_u64s[2] = s->_u64s[3] = 0; } -/** @brief Adds a character to the set and accepts @b unsigned integers. */ +/** @brief Adds a character to the set and accepts @b unsigned integers. */ SZ_PUBLIC void sz_byteset_add_u8(sz_byteset_t *s, sz_u8_t c) { s->_u64s[c >> 6] |= (1ull << (c & 63u)); } -/** @brief Adds a character to the set. Consider @b sz_byteset_add_u8. */ +/** @brief Adds a character to the set. Consider @b sz_byteset_add_u8. */ SZ_PUBLIC void sz_byteset_add(sz_byteset_t *s, char c) { sz_byteset_add_u8(s, *(sz_u8_t *)(&c)); } // bitcast -/** @brief Checks if the set contains a given character and accepts @b unsigned integers. */ +/** @brief Checks if the set contains a given character and accepts @b unsigned integers. */ SZ_PUBLIC sz_bool_t sz_byteset_contains_u8(sz_byteset_t const *s, sz_u8_t c) { // Checking the bit can be done in different ways: // - (s->_u64s[c >> 6] & (1ull << (c & 63u))) != 0 @@ -451,12 +458,12 @@ SZ_PUBLIC sz_bool_t sz_byteset_contains_u8(sz_byteset_t const *s, sz_u8_t c) { return (sz_bool_t)((s->_u64s[c >> 6] & (1ull << (c & 63u))) != 0); } -/** @brief Checks if the set contains a given character. Consider @b sz_byteset_contains_u8. */ +/** @brief Checks if the set contains a given character. Consider @b sz_byteset_contains_u8. */ SZ_PUBLIC sz_bool_t sz_byteset_contains(sz_byteset_t const *s, char c) { return sz_byteset_contains_u8(s, *(sz_u8_t *)(&c)); // bitcast } -/** @brief Inverts the contents of the set, so allowed character get disallowed, and vice versa. */ +/** @brief Inverts the contents of the set, so allowed character get disallowed, and vice versa. */ SZ_PUBLIC void sz_byteset_invert(sz_byteset_t *s) { s->_u64s[0] ^= 0xFFFFFFFFFFFFFFFFull, s->_u64s[1] ^= 0xFFFFFFFFFFFFFFFFull, // s->_u64s[2] ^= 0xFFFFFFFFFFFFFFFFull, s->_u64s[3] ^= 0xFFFFFFFFFFFFFFFFull; @@ -472,7 +479,7 @@ typedef void (*sz_memory_free_t)(void *, sz_size_t, void *); /** * @brief Some complex pattern matching algorithms may require memory allocations. * This structure is used to pass the memory allocator to those functions. 
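+ *
+ * A minimal sketch of routing allocations into a caller-owned buffer (see `sz_memory_allocator_init_fixed` below):
+ * @code{.c}
+ * char scratch[4096];
+ * sz_memory_allocator_t alloc;
+ * sz_memory_allocator_init_fixed(&alloc, scratch, sizeof(scratch));
+ * // `&alloc` can now be passed wherever a `sz_memory_allocator_t *` is expected, instead of NULL.
+ * @endcode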
- * @see sz_memory_allocator_init_fixed + * @sa sz_memory_allocator_init_fixed */ typedef struct sz_memory_allocator_t { sz_memory_allocate_t allocate; @@ -481,21 +488,17 @@ typedef struct sz_memory_allocator_t { } sz_memory_allocator_t; /** - * @brief Initializes a memory allocator to use the system default `malloc` and `free`. - * ! The function is not available if the library was compiled with `SZ_AVOID_LIBC`. - * - * @param alloc Memory allocator to initialize. + * @brief Initializes a memory allocator to use the system default `malloc` and `free`. + * @warning The function is not available if the library was compiled with `SZ_AVOID_LIBC`. + * @param[in] alloc Memory allocator to initialize. */ SZ_PUBLIC void sz_memory_allocator_init_default(sz_memory_allocator_t *alloc); /** - * @brief Initializes a memory allocator to use a static-capacity buffer. - * No dynamic allocations will be performed. - * - * @param alloc Memory allocator to initialize. - * @param buffer Buffer to use for allocations. - * @param length Length of the buffer. @b Must be greater than 8 bytes. Different values would be optimal for - * different algorithms and input lengths, but 4096 bytes (one RAM page) is a good default. + * @brief Initializes a memory allocator to use only a static-capacity buffer @b w/out any dynamic allocations. + * @param[in] alloc Memory allocator to initialize. + * @param[in] buffer Buffer to use for allocations. + * @param[in] length Length of the buffer. @b Must be greater than 8, at least 4KB (one RAM page) is recommended. */ SZ_PUBLIC void sz_memory_allocator_init_fixed(sz_memory_allocator_t *alloc, void *buffer, sz_size_t length); @@ -503,66 +506,66 @@ SZ_PUBLIC void sz_memory_allocator_init_fixed(sz_memory_allocator_t *alloc, void #pragma region API Signature Types -/** @brief Signature of `sz_hash`. */ +/** @brief Signature of `sz_hash`. */ typedef sz_u64_t (*sz_hash_t)(sz_cptr_t, sz_size_t, sz_u64_t); -/** @brief Signature of `sz_hash_state_init`. */ +/** @brief Signature of `sz_hash_state_init`. */ typedef void (*sz_hash_state_init_t)(struct sz_hash_state_t *, sz_u64_t); -/** @brief Signature of `sz_hash_state_stream`. */ +/** @brief Signature of `sz_hash_state_stream`. */ typedef void (*sz_hash_state_stream_t)(struct sz_hash_state_t *, sz_cptr_t, sz_size_t); -/** @brief Signature of `sz_hash_state_fold`. */ +/** @brief Signature of `sz_hash_state_fold`. */ typedef sz_u64_t (*sz_hash_state_fold_t)(struct sz_hash_state_t const *); -/** @brief Signature of `sz_bytesum`. */ +/** @brief Signature of `sz_bytesum`. */ typedef sz_u64_t (*sz_bytesum_t)(sz_cptr_t, sz_size_t); -/** @brief Signature of `sz_fill_random`. */ +/** @brief Signature of `sz_fill_random`. */ typedef void (*sz_fill_random_t)(sz_ptr_t, sz_size_t, sz_u64_t); -/** @brief Signature of `sz_equal`. */ +/** @brief Signature of `sz_equal`. */ typedef sz_bool_t (*sz_equal_t)(sz_cptr_t, sz_cptr_t, sz_size_t); -/** @brief Signature of `sz_order`. */ +/** @brief Signature of `sz_order`. */ typedef sz_ordering_t (*sz_order_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); -/** @brief Signature of `sz_lookup`. */ +/** @brief Signature of `sz_lookup`. */ typedef void (*sz_lookup_t)(sz_ptr_t, sz_size_t, sz_cptr_t, sz_cptr_t); -/** @brief Signature of `sz_move`. */ +/** @brief Signature of `sz_move`. */ typedef void (*sz_move_t)(sz_ptr_t, sz_cptr_t, sz_size_t); -/** @brief Signature of `sz_fill`. */ +/** @brief Signature of `sz_fill`. */ typedef void (*sz_fill_t)(sz_ptr_t, sz_size_t, sz_u8_t); -/** @brief Signature of `sz_find_byte`. 
*/ +/** @brief Signature of `sz_find_byte`. */ typedef sz_cptr_t (*sz_find_byte_t)(sz_cptr_t, sz_size_t, sz_cptr_t); -/** @brief Signature of `sz_find`. */ +/** @brief Signature of `sz_find`. */ typedef sz_cptr_t (*sz_find_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); -/** @brief Signature of `sz_find_set`. */ +/** @brief Signature of `sz_find_set`. */ typedef sz_cptr_t (*sz_find_set_t)(sz_cptr_t, sz_size_t, sz_byteset_t const *); -/** @brief Signature of `sz_hamming_distance`. */ +/** @brief Signature of `sz_hamming_distance`. */ typedef sz_status_t (*sz_hamming_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t, sz_size_t *); -/** @brief Signature of `sz_levenshtein_distance`. */ +/** @brief Signature of `sz_levenshtein_distance`. */ typedef sz_status_t (*sz_levenshtein_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t, sz_memory_allocator_t *, sz_size_t *); -/** @brief Signature of `sz_needleman_wunsch_score`. */ +/** @brief Signature of `sz_needleman_wunsch_score`. */ typedef sz_status_t (*sz_needleman_wunsch_score_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_error_cost_t const *, sz_error_cost_t, sz_memory_allocator_t *, sz_ssize_t *); -/** @brief Signature of `sz_sequence_argsort`. */ +/** @brief Signature of `sz_sequence_argsort`. */ typedef sz_status_t (*sz_sequence_argsort_t)(struct sz_sequence_t const *, sz_memory_allocator_t *, sz_sorted_idx_t *); -/** @brief Signature of `sz_pgrams_sort`. */ +/** @brief Signature of `sz_pgrams_sort`. */ typedef sz_status_t (*sz_pgrams_sort_t)(sz_pgram_t *, sz_size_t, sz_memory_allocator_t *, sz_sorted_idx_t *); -/** @brief Signature of `sz_sequence_join`. */ +/** @brief Signature of `sz_sequence_join`. */ typedef sz_status_t (*sz_sequence_join_t)(struct sz_sequence_t const *, struct sz_sequence_t const *, sz_memory_allocator_t *, sz_size_t *, sz_sorted_idx_t *, sz_sorted_idx_t *); @@ -571,8 +574,8 @@ typedef sz_status_t (*sz_sequence_join_t)(struct sz_sequence_t const *, struct s #pragma region Helper Structures /** - * @brief Helper structure to simplify work with 16-bit words. - * @see sz_u16_load + * @brief Helper structure to simplify work with 16-bit words. + * @sa sz_u16_load */ typedef union sz_u16_vec_t { sz_u16_t u16; @@ -580,8 +583,8 @@ typedef union sz_u16_vec_t { } sz_u16_vec_t; /** - * @brief Helper structure to simplify work with 32-bit words. - * @see sz_u32_load + * @brief Helper structure to simplify work with 32-bit words. + * @sa sz_u32_load */ typedef union sz_u32_vec_t { sz_u32_t u32; @@ -590,8 +593,8 @@ typedef union sz_u32_vec_t { } sz_u32_vec_t; /** - * @brief Helper structure to simplify work with 64-bit words. - * @see sz_u64_load + * @brief Helper structure to simplify work with 64-bit words. + * @sa sz_u64_load */ typedef union sz_u64_vec_t { sz_u64_t u64; @@ -662,9 +665,7 @@ typedef union sz_u512_vec_t { #pragma region UTF8 -/** - * @brief Extracts just one UTF8 codepoint from a UTF8 string into a 32-bit unsigned integer. - */ +/** @brief Extracts just one UTF8 codepoint from a UTF8 string into a 32-bit unsigned integer. */ SZ_INTERNAL void _sz_extract_utf8_rune(sz_cptr_t utf8, sz_rune_t *code, sz_rune_length_t *code_length) { sz_u8_t const *current = (sz_u8_t const *)utf8; sz_u8_t leading_byte = *current++; @@ -708,8 +709,8 @@ SZ_INTERNAL void _sz_extract_utf8_rune(sz_cptr_t utf8, sz_rune_t *code, sz_rune_ } /** - * @brief Exports a UTF8 string into a UTF32 buffer. - * ! The result is undefined id the UTF8 string is corrupted. 
+ * @brief Exports a UTF8 string into a UTF32 buffer. + * @warning The result is undefined id the UTF8 string is corrupted. * @return The length in the number of codepoints. */ SZ_INTERNAL sz_size_t _sz_export_utf8_to_utf32(sz_cptr_t utf8, sz_size_t utf8_length, sz_rune_t *utf32) { @@ -771,14 +772,10 @@ SZ_PUBLIC void sz_sequence_from_null_terminated_strings(sz_cptr_t *start, sz_siz ********************************************************************************************************************** */ -/** - * @brief Helper-macro to mark potentially unused variables. - */ +/** @brief Helper-macro to mark potentially unused variables. */ #define sz_unused(x) ((void)(x)) -/** - * @brief Helper-macro casting a variable to another type of the same size. - */ +/** @brief Helper-macro casting a variable to another type of the same size. */ #define sz_bitcast(type, value) (*((type *)&(value))) /** @@ -1024,7 +1021,8 @@ SZ_INTERNAL void sz_ssize_clamp_interval( // } /** - * @brief Compute the logarithm base 2 of a positive integer, rounding down. + * @brief Compute the logarithm base 2 of a positive integer, rounding down. + * @pre Input must be a positive number, as the logarithm of zero is undefined. */ SZ_INTERNAL sz_size_t sz_size_log2i_nonzero(sz_size_t x) { _sz_assert(x > 0 && "Non-positive numbers have no defined logarithm"); @@ -1033,11 +1031,11 @@ SZ_INTERNAL sz_size_t sz_size_log2i_nonzero(sz_size_t x) { } /** - * @brief Compute the smallest power of two greater than or equal to @p x. + * @brief Compute the smallest power of two greater than or equal to @p x. + * @pre Unlike the commonly used trick with `clz` intrinsics, is valid across the whole range of `x`, @b including 0. + * @see https://stackoverflow.com/a/10143264 */ SZ_INTERNAL sz_size_t sz_size_bit_ceil(sz_size_t x) { - // Unlike the commonly used trick with `clz` intrinsics, is valid across the whole range of `x`. - // https://stackoverflow.com/a/10143264 x--; x |= x >> 1; x |= x >> 2; @@ -1052,7 +1050,7 @@ SZ_INTERNAL sz_size_t sz_size_bit_ceil(sz_size_t x) { } /** - * @brief Transposes an 8x8 bit matrix packed in a `sz_u64_t`. + * @brief Transposes an 8x8 bit matrix packed in a `sz_u64_t`. * * There is a well known SWAR sequence for that known to chess programmers, * willing to flip a bit-matrix of pieces along the main A1-H8 diagonal. @@ -1070,9 +1068,7 @@ SZ_INTERNAL sz_u64_t sz_u64_transpose(sz_u64_t x) { return x; } -/** - * @brief Load a 16-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms. - */ +/** @brief Load a 16-bit unsigned integer from a potentially unaligned pointer. Can be expensive on some platforms. */ SZ_INTERNAL sz_u16_vec_t sz_u16_load(sz_cptr_t ptr) { #if !SZ_USE_MISALIGNED_LOADS sz_u16_vec_t result; @@ -1080,7 +1076,7 @@ SZ_INTERNAL sz_u16_vec_t sz_u16_load(sz_cptr_t ptr) { result.u8s[1] = ptr[1]; return result; #elif defined(_MSC_VER) && !defined(__clang__) -#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. +#if defined(_M_IX86) //< The `__unaligned` modifier isn't valid for the x86 platform. return *((sz_u16_vec_t *)ptr); #else return *((__unaligned sz_u16_vec_t *)ptr); @@ -1091,9 +1087,7 @@ SZ_INTERNAL sz_u16_vec_t sz_u16_load(sz_cptr_t ptr) { #endif } -/** - * @brief Load a 32-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms. - */ +/** @brief Load a 32-bit unsigned integer from a potentially unaligned pointer. Can be expensive on some platforms. 
*/ SZ_INTERNAL sz_u32_vec_t sz_u32_load(sz_cptr_t ptr) { #if !SZ_USE_MISALIGNED_LOADS sz_u32_vec_t result; @@ -1103,7 +1097,7 @@ SZ_INTERNAL sz_u32_vec_t sz_u32_load(sz_cptr_t ptr) { result.u8s[3] = ptr[3]; return result; #elif defined(_MSC_VER) && !defined(__clang__) -#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. +#if defined(_M_IX86) //< The `__unaligned` modifier isn't valid for the x86 platform. return *((sz_u32_vec_t *)ptr); #else return *((__unaligned sz_u32_vec_t *)ptr); @@ -1114,9 +1108,7 @@ SZ_INTERNAL sz_u32_vec_t sz_u32_load(sz_cptr_t ptr) { #endif } -/** - * @brief Load a 64-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms. - */ +/** @brief Load a 64-bit unsigned integer from a potentially unaligned pointer. Can be expensive on some platforms. */ SZ_INTERNAL sz_u64_vec_t sz_u64_load(sz_cptr_t ptr) { #if !SZ_USE_MISALIGNED_LOADS sz_u64_vec_t result; @@ -1130,7 +1122,7 @@ SZ_INTERNAL sz_u64_vec_t sz_u64_load(sz_cptr_t ptr) { result.u8s[7] = ptr[7]; return result; #elif defined(_MSC_VER) && !defined(__clang__) -#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. +#if defined(_M_IX86) //< The `__unaligned` modifier isn't valid for the x86 platform. return *((sz_u64_vec_t *)ptr); #else return *((__unaligned sz_u64_vec_t *)ptr); @@ -1141,7 +1133,7 @@ SZ_INTERNAL sz_u64_vec_t sz_u64_load(sz_cptr_t ptr) { #endif } -/** @brief Helper function, using the supplied fixed-capacity buffer to allocate memory. */ +/** @brief Helper function, using the supplied fixed-capacity buffer to allocate memory. */ SZ_INTERNAL sz_ptr_t _sz_memory_allocate_fixed(sz_size_t length, void *handle) { sz_size_t capacity; *(sz_ptr_t)&capacity = *(sz_cptr_t)handle; @@ -1150,7 +1142,7 @@ SZ_INTERNAL sz_ptr_t _sz_memory_allocate_fixed(sz_size_t length, void *handle) { return (sz_ptr_t)handle + consumed_capacity; } -/** @brief Helper "no-op" function, simulating memory deallocation when we use a "static" memory buffer. */ +/** @brief Helper "no-op" function, simulating memory deallocation when we use a "static" memory buffer. 
*/ SZ_INTERNAL void _sz_memory_free_fixed(sz_ptr_t start, sz_size_t length, void *handle) { sz_unused(start && length && handle); } From 8dc4a2c70fb84f87c9e573579dc496a91b5f22e4 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 9 Mar 2025 05:07:51 +0000 Subject: [PATCH 154/751] Fix: Randomization benchmarks --- scripts/bench_token.cpp | 41 +++++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/scripts/bench_token.cpp b/scripts/bench_token.cpp index 0d83604b..112fbc98 100644 --- a/scripts/bench_token.cpp +++ b/scripts/bench_token.cpp @@ -95,27 +95,25 @@ tracked_unary_functions_t hash_stream_functions() { return result; } -tracked_unary_functions_t random_generation_functions(std::size_t token_length) { +tracked_unary_functions_t random_generation_functions() { static std::vector buffer; - if (buffer.size() < token_length) buffer.resize(token_length); - - auto suffix = ", " + std::to_string(token_length) + " chars"; tracked_unary_functions_t result = { - {"std::rand % uint8" + suffix, unary_function_t([token_length](std::string_view alphabet) -> std::size_t { - using max_alphabet_size_t = std::uint8_t; - auto max_alphabet_size = static_cast(alphabet.size()); - for (std::size_t i = 0; i < token_length; ++i) { buffer[i] = alphabet[std::rand() % max_alphabet_size]; } - return token_length; + {"std::rand() & 0xFF", unary_function_t([](std::string_view token) -> std::size_t { + if (buffer.size() < token.size()) buffer.resize(token.size()); + for (std::size_t i = 0; i < token.size(); ++i) buffer[i] = static_cast(std::rand() & 0xFF); + return token.size(); + })}, + {"std::uniform_int", unary_function_t([](std::string_view token) -> std::size_t { + if (buffer.size() < token.size()) buffer.resize(token.size()); + randomize_string(buffer.data(), token.size()); + return token.size(); })}, - {"std::uniform_int" + suffix, unary_function_t([token_length](std::string_view alphabet) -> std::size_t { - randomize_string(buffer.data(), token_length, alphabet.data(), alphabet.size()); - return token_length; + {"sz::randomize", unary_function_t([](std::string_view token) -> std::size_t { + if (buffer.size() < token.size()) buffer.resize(token.size()); + sz::string_span span(buffer.data(), token.size()); + sz::fill_random(span); + return token.size(); })}, - // {"sz::randomize" + suffix, unary_function_t([token_length](std::string_view alphabet) -> std::size_t { - // sz::string_span span(buffer.data(), token_length); - // sz::randomize(span, global_random_generator(), alphabet); - // return token_length; - // })}, }; return result; } @@ -123,11 +121,11 @@ tracked_unary_functions_t random_generation_functions(std::size_t token_length) tracked_binary_functions_t equality_functions() { auto wrap_sz = [](auto function) -> binary_function_t { return binary_function_t([function](std::string_view a, std::string_view b) { - return (a.size() == b.size() && function(a.data(), b.data(), a.size())); + return a.size() == b.size() && function(a.data(), b.data(), a.size()); }); }; tracked_binary_functions_t result = { - {"std::string_view.==", [](std::string_view a, std::string_view b) { return (a == b); }}, + {"std::string_view.==", [](std::string_view a, std::string_view b) { return a == b; }}, {"sz_equal_serial", wrap_sz(sz_equal_serial), true}, #if SZ_USE_HASWELL {"sz_equal_haswell", wrap_sz(sz_equal_haswell), true}, @@ -190,6 +188,7 @@ void bench(strings_type &&strings) { bench_unary_functions(strings, 
hash_stream_functions()); bench_binary_functions(strings, equality_functions()); bench_binary_functions(strings, ordering_functions()); + bench_unary_functions(strings, random_generation_functions()); // Benchmark the cost of converting `std::string` and `sz::string` to `std::string_view`. // ! The results on a mixture of short and long strings should be similar. @@ -208,7 +207,9 @@ void bench_on_input_data(int argc, char const **argv) { std::printf("Benchmarking on real lines:\n"); bench(dataset.lines); std::printf("Benchmarking on entire dataset:\n"); - bench>({dataset.text}); + bench_unary_functions>({dataset.text}, bytesum_functions()); + bench_unary_functions>({dataset.text}, hash_functions()); + bench_unary_functions>({dataset.text}, hash_stream_functions()); // Run benchmarks on tokens of different length for (std::size_t token_length : {1, 2, 3, 4, 5, 6, 7, 8, 16, 32}) { From 1d956019b42c0df9ef32cd546911c931f37febad Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 9 Mar 2025 05:08:24 +0000 Subject: [PATCH 155/751] Improve: Test set intersections --- scripts/test.cpp | 150 +++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 127 insertions(+), 23 deletions(-) diff --git a/scripts/test.cpp b/scripts/test.cpp index a0eac08e..63452df4 100644 --- a/scripts/test.cpp +++ b/scripts/test.cpp @@ -49,6 +49,8 @@ #include // `std::random_device` #include // `std::ostringstream` #include // `std::unordered_map` +#include // `std::unordered_set` +#include // `std::set` #include // `std::vector` #include // Baseline @@ -149,6 +151,31 @@ static void test_arithmetical_utilities() { #endif } +static void test_structural_utilities() { + // Make sure the sequence helper functions work as expected + // for both trivial c-style arrays and + { + sz_sequence_t sequence; + sz_cptr_t strings[] = {"banana", "apple", "cherry"}; + sz_sequence_from_null_terminated_strings(strings, 3, &sequence); + assert(sequence.count == 3); + assert("banana"_sv == sequence.get_start(sequence.handle, 0)); + assert("apple"_sv == sequence.get_start(sequence.handle, 1)); + assert("cherry"_sv == sequence.get_start(sequence.handle, 2)); + } + + // sz_memory_allocator_init_default; + // sz_memory_allocator_init_fixed; + // _sz_extract_utf8_rune; + // sz_byteset_init; + // sz_byteset_init_ascii; + // sz_byteset_add_u8; + // sz_byteset_add; + // sz_byteset_contains_u8; + // sz_byteset_contains; + // sz_byteset_invert; +} + /** * @brief Hashes a string and compares the output between a serial and hardware-specific SIMD backend. * @@ -437,14 +464,14 @@ static void test_memory_utilities( // } #define assert_scoped(init, operation, condition) \ - { \ + do { \ init; \ operation; \ assert(condition); \ - } + } while (0) #define assert_throws(expression, exception_type) \ - { \ + do { \ bool threw = false; \ try { \ sz_unused(expression); \ @@ -453,7 +480,7 @@ static void test_memory_utilities( // threw = true; \ } \ assert(threw); \ - } + } while (0) /** * @brief Invokes different C++ member methods of immutable strings to cover all STL APIs. @@ -1684,9 +1711,10 @@ static void test_levenshtein_distances() { /** * Evaluates the correctness of look-up table transforms using random lookup tables. * - * @param misalignment The number of bytes to misalign the haystack within the cacheline. + * @param lookup_tables_to_try The number of random lookup tables to try. + * @param slices_per_table The number of random inputs to test per lookup table. 
*/ -void test_replacements(std::size_t lookup_tables_to_try = 128, std::size_t slices_per_table = 256) { +void test_replacements(std::size_t lookup_tables_to_try = 32, std::size_t slices_per_table = 16) { std::string body, transformed; body.resize(1024 * 1024); // 1MB @@ -1712,23 +1740,19 @@ void test_replacements(std::size_t lookup_tables_to_try = 128, std::size_t slice } /** - * @brief Tests sorting functionality. + * @brief Tests array sorting functionality, such as `argsort`, `sort`, and `sorted`. + * + * Tries to sort incrementally complex inputs, such as strings of varying lengths, with many equal inputs. + * 1. Basic tests with predetermined orders. + * 2. Test on long strings of identical length. + * 3. Test on random very small strings of varying lengths, likely with many equal inputs. + * 4. Test on random strings of varying lengths. + * 5. Test on random strings of varying lengths with zero characters. */ -static void test_sequence_algorithms() { +static void test_sorting_algorithms() { using strs_t = std::vector; using order_t = std::vector; - // Make sure teh helper functions work as expected. - { - sz_sequence_t sequence; - sz_cptr_t strings[] = {"banana", "apple", "cherry"}; - sz_sequence_from_null_terminated_strings(strings, 3, &sequence); - assert(sequence.count == 3); - assert("banana"_sv == sequence.get_start(sequence.handle, 0)); - assert("apple"_sv == sequence.get_start(sequence.handle, 1)); - assert("cherry"_sv == sequence.get_start(sequence.handle, 2)); - } - // Basic tests with predetermined orders. assert_scoped(strs_t x({"a", "b", "c", "d"}), (void)0, sz::argsort(x) == order_t({0u, 1u, 2u, 3u})); assert_scoped(strs_t x({"b", "c", "d", "a"}), (void)0, sz::argsort(x) == order_t({3u, 0u, 1u, 2u})); @@ -1796,6 +1820,84 @@ static void test_sequence_algorithms() { } } +/** + * @brief Tests array intersection functionality. + */ +static void test_intersecting_algorithms() { + using strs_t = std::vector; + using result_t = sz::intersect_result_t; + + // The mapping aren't guaranteed to be in any specific order, so we will sort them for comparisons. 
+ using idx_pair_t = std::pair; + using idx_pairs_t = std::set; + auto to_pairs = [](result_t const &result) -> idx_pairs_t { + idx_pairs_t pairs; + for (std::size_t i = 0; i < result.first_offsets.size(); ++i) + pairs.insert({result.first_offsets[i], result.second_offsets[i]}); + return pairs; + }; + + // Predetermined simple cases + { + strs_t abcd({"a", "b", "c", "d"}); + strs_t dcba({"d", "c", "b", "a"}); + strs_t abs({"a", "b", "s"}); + strs_t empty; + result_t result; + // Empty sets + { + result = sz::intersect(empty, empty); + assert(result.first_offsets.size() == 0 && result.second_offsets.size() == 0); + result = sz::intersect(abcd, empty); + assert(result.first_offsets.size() == 0 && result.second_offsets.size() == 0); + } + // Identity check + { + result = sz::intersect(abcd, abcd); + assert(result.first_offsets.size() == 4 && result.second_offsets.size() == 4); + assert(to_pairs(result) == idx_pairs_t({{0u, 0u}, {1u, 1u}, {2u, 2u}, {3u, 3u}})); + } + // Identical size, different order + { + result = sz::intersect(abcd, dcba); + assert(result.first_offsets.size() == 4 && result.second_offsets.size() == 4); + assert(to_pairs(result) == idx_pairs_t({{0u, 3u}, {1u, 2u}, {2u, 1u}, {3u, 0u}})); + } + // Different sets + { + result = sz::intersect(abcd, abs); + assert(result.first_offsets.size() == 2 && result.second_offsets.size() == 2); + assert(to_pairs(result) == idx_pairs_t({{0u, 0u}, {1u, 1u}})); + } + } + + // Generate random strings + struct { + std::size_t min_length; + std::size_t max_length; + std::size_t count_strings; + } experiments[] = { + {10, 10, 100}, + {15, 15, 1000}, + {5, 30, 2000}, + }; + for (auto experiment : experiments) { + std::unordered_set random_strings; + while (random_strings.size() < experiment.count_strings) + random_strings.insert(sz::scripts::random_string( + experiment.min_length + std::rand() % (experiment.max_length - experiment.min_length + 1), // + "ab", 2)); + + strs_t all_strings(random_strings.begin(), random_strings.end()); + strs_t first_half(all_strings.begin(), all_strings.begin() + all_strings.size() / 2); + + // Try different joins + result_t result; + result = sz::intersect(all_strings, first_half); + assert(result.first_offsets.size() == first_half.size() && result.second_offsets.size() == first_half.size()); + } +} + /** * @brief Tests constructing STL containers with StringZilla strings. */ @@ -1824,8 +1926,14 @@ int main(int argc, char const **argv) { // Basic utilities test_arithmetical_utilities(); + test_structural_utilities(); test_simd_against_serial(); + // Sequences of strings + test_sorting_algorithms(); + test_intersecting_algorithms(); + test_stl_containers(); + // Core APIs test_ascii_utilities(); test_ascii_utilities(); @@ -1862,10 +1970,6 @@ int main(int argc, char const **argv) { test_search_with_misaligned_repetitions(); #endif - // Sequences of strings - test_sequence_algorithms(); - test_stl_containers(); - std::printf("All tests passed... 
Unbelievable!\n"); return 0; } From de62723158719f184d9d81db3ac786e47883f391 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 9 Mar 2025 05:08:44 +0000 Subject: [PATCH 156/751] Add: Feature-extraction placeholder --- include/stringzilla/features.h | 134 +++++++++++++++++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 include/stringzilla/features.h diff --git a/include/stringzilla/features.h b/include/stringzilla/features.h new file mode 100644 index 00000000..1389a4cf --- /dev/null +++ b/include/stringzilla/features.h @@ -0,0 +1,134 @@ +/** + * @brief Hardware-accelerated feature extractions for string collections. + * @file features.h + * @author Ash Vardanian + * + * The `sklearn.feature_extraction` module for @b TF-IDF, "CountVectorizer", and "HashingVectorizer" + * is one of the most commonly used in the industry due to its extreme flexibility. It can: + * + * - Tokenize by words, N-grams, or in-word N-grams. + * - Use arbitrary Regular Expressions as word separators. + * - Return matrices of different types, normalized or not. + * - Exclude "stop words" and remove ASCII and Unicode accents. + * - Dynamically build a vocabulary or use a fixed list/dictionary. + * + * That level of flexibility is not feasible for a hardware-accelerated SIMD library, but we + * can provide a set of APIs that can be used to build such a library on top of StringZilla. + * That functionality will reuse our @b Trie data-structure for vocabulary building histograms. + * + */ +#ifndef STRINGZILLA_FEATURES_H_ +#define STRINGZILLA_FEATURES_H_ + +#include "types.h" + +#include "compare.h" // `sz_compare` +#include "memory.h" // `sz_copy` + +#ifdef __cplusplus +extern "C" { +#endif + +#pragma region Core API + +/** + * @brief Faster @b arg-sort for an arbitrary @b string sequence, using QuickSort. + * Outputs the @p order of elements in the immutable @p sequence, that would sort it. + * + * @param[in] sequence Immutable sequence of strings to sort. + * @param[in] alloc Optional memory allocator for temporary storage. + * @param[out] order Output permutation that sorts the elements. + * + * @retval `sz_success_k` if the operation was successful. + * @retval `sz_bad_alloc_k` if the operation failed due to memory allocation failure. + * @pre The @p order array must fit at least `sequence->count` integers. + * @post The @p order array will contain a valid permutation of `[0, sequence->count - 1]`. + * + * Example usage: + * + * @code{.c} + * #include + * int main() { + * char const *strings[] = {"banana", "apple", "cherry"}; + * sz_sequence_t sequence; + * sz_sequence_from_null_terminated_strings(strings, 3, &sequence); + * sz_sorted_idx_t order[3]; + * sz_status_t status = sz_sequence_argsort(&sequence, NULL, order); + * return status == sz_success_k && order[0] == 1 && order[1] == 0 && order[2] == 2 ? 0 : 1; + * } + * @endcode + * + * @note The algorithm has linear memory complexity, quadratic worst-case and log-linear average time complexity. + * @see https://en.wikipedia.org/wiki/Quicksort + * + * @note This algorithm is @b unstable: equal elements may change relative order. + * @sa sz_sequence_argsort_stabilize + * + * @note Selects the fastest implementation at compile- or run-time based on `SZ_DYNAMIC_DISPATCH`. 
+ * @sa sz_sequence_argsort_serial, sz_sequence_argsort_skylake, sz_sequence_argsort_sve + */ +SZ_DYNAMIC sz_status_t sz_sequence_argsort(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order); + +enum sz_encoding_t { + sz_encoding_unknown_k = 0, + sz_encoding_ascii_k = 1, + sz_encoding_utf8_k = 2, + sz_encoding_utf16_k = 3, + sz_encoding_utf32_k = 4, + sz_encoding_jwt_k = 5, + sz_encoding_base64_k = 6, + // Low priority encodings: + sz_encoding_utf8bom_k = 7, + sz_encoding_utf16le_k = 8, + sz_encoding_utf16be_k = 9, + sz_encoding_utf32le_k = 10, + sz_encoding_utf32be_k = 11, +}; + +// Character Set Detection is one of the most commonly performed operations in data processing with +// [Chardet](https://github.com/chardet/chardet), [Charset Normalizer](https://github.com/jawah/charset_normalizer), +// [cChardet](https://github.com/PyYoshi/cChardet) being the most commonly used options in the Python ecosystem. +// All of them are notoriously slow. +// +// Moreover, as of October 2024, UTF-8 is the dominant character encoding on the web, used by 98.4% of websites. +// Other have minimal usage, according to [W3Techs](https://w3techs.com/technologies/overview/character_encoding): +// - ISO-8859-1: 1.2% +// - Windows-1252: 0.3% +// - Windows-1251: 0.2% +// - EUC-JP: 0.1% +// - Shift JIS: 0.1% +// - EUC-KR: 0.1% +// - GB2312: 0.1% +// - Windows-1250: 0.1% +// Within programming language implementations and database management systems, 16-bit and 32-bit fixed-width encodings +// are also very popular and we need a way to efficiently differentiate between the most common UTF flavors, ASCII, and +// the rest. +// +// One good solution is the [simdutf](https://github.com/simdutf/simdutf) library, but it depends on the C++ runtime +// and focuses more on incremental validation & transcoding, rather than detection. +// +// So we need a very fast and efficient way of determining +SZ_PUBLIC sz_bool_t sz_detect_encoding(sz_cptr_t text, sz_size_t length) { + // https://github.com/simdutf/simdutf/blob/master/src/icelake/icelake_utf8_validation.inl.cpp + // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_from_utf8.inl.cpp#L81 + // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_utf8_common.inl.cpp#L661 + // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_utf8_common.inl.cpp#L788 + + // We can implement this operation simpler & differently, assuming most of the time continuous chunks of memory + // have identical encoding. With Russian and many European languages, we generally deal with 2-byte codepoints + // with occasional 1-byte punctuation marks. In the case of Chinese, Japanese, and Korean, we deal with 3-byte + // codepoints. In the case of emojis, we deal with 4-byte codepoints. + // We can also use the idea, that misaligned reads are quite cheap on modern CPUs. 
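+    //
+    // A hedged sketch of such signals (none of this is implemented yet, the flags below are placeholders):
+    //   - if every sampled byte is below 0x80, the chunk is plain ASCII;
+    //   - if continuation bytes (0b10xxxxxx) only ever follow valid UTF-8 leading bytes, lean towards UTF-8;
+    //   - if every second byte (or three out of every four) is zero, lean towards UTF-16 (or UTF-32).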
+ int can_be_ascii = 1, can_be_utf8 = 1, can_be_utf16 = 1, can_be_utf32 = 1; + sz_unused(can_be_ascii + can_be_utf8 + can_be_utf16 + can_be_utf32); + sz_unused(text && length); + return sz_false_k; +} + +#pragma endregion // Core API + +#ifdef __cplusplus +} +#endif // __cplusplus +#endif // STRINGZILLA_FEATURES_H_ From 197cd8719017f9fd43d1f4327e0d3f5fed477340 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 9 Mar 2025 05:09:21 +0000 Subject: [PATCH 157/751] Docs: Exploring perfect Unicode hashing --- ...shtein.ipynb => explore_levenshtein.ipynb} | 0 scripts/explore_unicode.ipynb | 646 ++++++++++++++++++ 2 files changed, 646 insertions(+) rename scripts/{test_levenshtein.ipynb => explore_levenshtein.ipynb} (100%) create mode 100644 scripts/explore_unicode.ipynb diff --git a/scripts/test_levenshtein.ipynb b/scripts/explore_levenshtein.ipynb similarity index 100% rename from scripts/test_levenshtein.ipynb rename to scripts/explore_levenshtein.ipynb diff --git a/scripts/explore_unicode.ipynb b/scripts/explore_unicode.ipynb new file mode 100644 index 00000000..86af3fd6 --- /dev/null +++ b/scripts/explore_unicode.ipynb @@ -0,0 +1,646 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Unicode and Perfect Hashing\n", + "\n", + "Generalizing StringZilla from byte-strings to UTF-8 strings requires a deep understanding of Unicode.\n", + "This notebook is a playground to explore Unicode and UTF-8 encoding.\n", + "Most importantly it provides a snippet for finding the perfect-hash for unicode, which allows us to produce more efficient histograms and lookup tables for unicode characters.\n", + "That cab be a constituent part of any UTF-8-aware text-processing algorithm, be it Levenshtein automata or distance calculation, Aho-Corasick automata, or high-level NLP tasks, like feature extraction or text classification." 
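+ ,"\n",
+ "\n",
+ "Concretely, the search in the later cells is for a multiplicative scheme: a 32-bit multiplier `M` and a table size `S` such that `(code_point * M) % S` lands every assigned code point in a distinct slot, i.e. a perfect hash over the assigned range."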
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "%pip install -q numba numpy tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from numba import jit as njit\n", + "from tqdm import tqdm\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloaded 10,047,830 bytes of UCD XML data\n", + "Files in zip: ['ucd.all.flat.xml']\n" + ] + } + ], + "source": [ + "import urllib.request\n", + "import io\n", + "import zipfile\n", + "import xml.etree.ElementTree as ET\n", + "\n", + "# URL for the latest UCD XML archive (flattened)\n", + "ucd_zip_url = \"https://www.unicode.org/Public/UCD/latest/ucdxml/ucd.all.flat.zip\"\n", + "\n", + "# Download the ZIP file\n", + "with urllib.request.urlopen(ucd_zip_url) as response:\n", + " zip_data = response.read()\n", + "print(f\"Downloaded {len(zip_data):,} bytes of UCD XML data\")\n", + "\n", + "# Read the ZIP file from memory\n", + "zip_bytes = io.BytesIO(zip_data)\n", + "with zipfile.ZipFile(zip_bytes) as zf:\n", + " # List files in the zip archive (typically one XML file)\n", + " file_list = zf.namelist()\n", + " print(\"Files in zip:\", file_list)\n", + " # Assuming the first file is the desired XML file\n", + " xml_filename = file_list[0]\n", + " with zf.open(xml_filename) as xml_file:\n", + " # Parse the XML file\n", + " tree = ET.parse(xml_file)\n", + " root = tree.getroot()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The XML structure contains a `` element with many `` elements.\n", + "Each `` element has attributes, including:\n", + "\n", + "- `'cp'`: the code point (as hexadecimal)\n", + "- `'na'`: the character name\n", + "- `'gc'`: the general category" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total code points processed (after expanding ranges): 155,063\n" + ] + } + ], + "source": [ + "# Use a namespace-agnostic search for all elements ending with 'char'\n", + "chars = [elem for elem in root.iter() if elem.tag.endswith('char')]\n", + "\n", + "# List to hold all characters (expanded ranges)\n", + "all_chars = []\n", + "\n", + "def process_char(elem):\n", + " \"\"\"\n", + " Process a element, handling all individual code points (cp)\n", + " and ignoring ranges (first-cp and last-cp). 
Appends each code point to all_chars.\n", + " \"\"\"\n", + " if 'cp' in elem.attrib:\n", + " cp = int(elem.attrib['cp'], 16)\n", + " entry = {\n", + " 'cp': cp,\n", + " 'name': elem.attrib.get('na', '').strip(),\n", + " 'gc': elem.attrib.get('gc', '').strip(),\n", + " 'age': elem.attrib.get('age', '').strip()\n", + " # You can add pull attributes here if needed.\n", + " }\n", + " all_chars.append(entry)\n", + "\n", + "# Process every 'char' element found\n", + "for elem in chars:\n", + " process_char(elem)\n", + "\n", + "print(f\"Total code points processed (after expanding ranges): {len(all_chars):,}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The Unicode standard defines a range of 1,114,112 possible code points (from U+0000 to U+10FFFF), but only a subset of these are actually assigned characters or have specific property data.\n", + "As of Unicode version 16.0, there are 155,063 characters with code points, covering 168 modern and historical scripts, as well as multiple symbol sets, split into [338 blocks](https://en.wikipedia.org/wiki/Unicode_block).\n", + "Let's random sample and print a few:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Example symbols:\n", + "U+0C15: TELUGU LETTER KA (Lo)\n", + "U+1016: MYANMAR LETTER PHA (Lo)\n", + "U+98B4: CJK UNIFIED IDEOGRAPH-# (Lo)\n", + "U+2BE19: CJK UNIFIED IDEOGRAPH-# (Lo)\n", + "U+297D: RIGHT FISH TAIL (Sm)\n", + "U+3D9A: CJK UNIFIED IDEOGRAPH-# (Lo)\n", + "U+2F95: KANGXI RADICAL VALLEY (So)\n", + "U+9527: CJK UNIFIED IDEOGRAPH-# (Lo)\n", + "U+16F6E: MIAO VOWEL SIGN UU (Mc)\n", + "U+28B5F: CJK UNIFIED IDEOGRAPH-# (Lo)\n" + ] + } + ], + "source": [ + "import random\n", + "\n", + "random_chars = random.sample(all_chars, 10)\n", + "\n", + "print(\"Example symbols:\")\n", + "for char in random_chars:\n", + " cp = char['cp']\n", + " na = char['name']\n", + " gc = char['gc']\n", + " print(f\"U+{cp:04X}: {na} ({gc})\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A natural question can be asked, is that set of codepoints dense or does it contain holes?" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Highest code point: U+E01EF (917,999)\n", + "Number of holes: 762,936\n" + ] + } + ], + "source": [ + "highest_code_point = max(char['cp'] for char in all_chars)\n", + "print(f\"Highest code point: U+{highest_code_point:04X} ({highest_code_point:,})\")\n", + "count_holes = highest_code_point - len(all_chars)\n", + "print(f\"Number of holes: {count_holes:,}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The presence of holes means, that simply using the code-point itself as a lookup index would result in a significant \"memory amplification\" factor, lower data locality, and very uneven distribution of data." 
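+ ,"\n",
+ "\n",
+ "As a quick back-of-the-envelope check using the counts above: the highest assigned code point is 917,999 (U+E01EF) while only 155,063 code points are assigned, so indexing a table directly by code point would make it about 5.9 times larger than necessary, which is exactly the ratio computed in the next cell."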
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Memory amplification: 5.9\n" + ] + } + ], + "source": [ + "memory_amplification = 1.0 * highest_code_point / len(all_chars)\n", + "print(f\"Memory amplification: {memory_amplification:.1f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For various hash-functions, we may want to find the smallest buffer size that results in no collisions.\n", + "Moreover, assuming how small code-points can be, we would prefer hash-functions that only rely on 32-bit arithmetic and avoid expensive operations.\n", + "We may want to start by using a power-of-two hash-table size, as the final stage of the hash-function can be a simple bitwise-and operation.\n", + "\n", + "- $2^{17} = 131072$ is the closes power of two to the number of code-points.\n", + "- $2^{18} = 262144$ is the next power of two - the first one that fits all code-points.\n", + "\n", + "The latter would still have a 69% memory amplifications factor with only 59% of the slots filled." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's export all code-points to a flat NumPy array and for efficiency, calculate all hashes at once." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Memory usage for code points: 620,252 bytes\n" + ] + } + ], + "source": [ + "code_points = np.array([char['cp'] for char in all_chars], dtype=np.uint32)\n", + "print(f\"Memory usage for code points: {code_points.nbytes:,} bytes\")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "# ---------------------------------------------------------------------\n", + "# 1. Jenkins One-at-a-Time Hash\n", + "# ---------------------------------------------------------------------\n", + "def hash_all_jenkins(code_points: np.ndarray) -> np.ndarray:\n", + " # Ensure input is np.uint32.\n", + " code_points = code_points.astype(np.uint32)\n", + " h = np.zeros_like(code_points, dtype=np.uint32)\n", + " # Process each of the 4 bytes of the 32-bit integer.\n", + " for shift in (0, 8, 16, 24):\n", + " # Extract one byte at a time.\n", + " b = (code_points >> shift) & np.uint32(0xFF)\n", + " h = (h + b) & np.uint32(0xFFFFFFFF)\n", + " h = (h + (h << np.uint32(10))) & np.uint32(0xFFFFFFFF)\n", + " h = (h ^ (h >> np.uint32(6))) & np.uint32(0xFFFFFFFF)\n", + " h = (h + (h << np.uint32(3))) & np.uint32(0xFFFFFFFF)\n", + " h = (h ^ (h >> np.uint32(11))) & np.uint32(0xFFFFFFFF)\n", + " h = (h + (h << np.uint32(15))) & np.uint32(0xFFFFFFFF)\n", + " return h\n", + "\n", + "# ---------------------------------------------------------------------\n", + "# 2. 
FNV-1a Hash (32-bit)\n", + "# ---------------------------------------------------------------------\n", + "def hash_all_fnv1a(code_points: np.ndarray) -> np.ndarray:\n", + " # FNV-1a 32-bit parameters\n", + " FNV_offset = np.uint32(0x811C9DC5)\n", + " FNV_prime = np.uint32(16777619)\n", + " code_points = code_points.astype(np.uint32)\n", + " h = np.full_like(code_points, FNV_offset, dtype=np.uint32)\n", + " # Process each of the 4 bytes\n", + " for shift in (0, 8, 16, 24):\n", + " byte = (code_points >> shift) & np.uint32(0xFF)\n", + " h = h ^ byte\n", + " h = (h * FNV_prime) & np.uint32(0xFFFFFFFF)\n", + " return h\n", + "\n", + "# ---------------------------------------------------------------------\n", + "# 3. Thomas Wang's 32-bit Integer Hash\n", + "# ---------------------------------------------------------------------\n", + "def hash_all_thomas_wang(code_points: np.ndarray) -> np.ndarray:\n", + " code_points = code_points.astype(np.uint32)\n", + " x = code_points.copy()\n", + " x = (x ^ np.uint32(61)) ^ (x >> np.uint32(16))\n", + " x = (x + (x << np.uint32(3))) & np.uint32(0xFFFFFFFF)\n", + " x = x ^ (x >> np.uint32(4))\n", + " x = (x * np.uint32(0x27d4eb2d)) & np.uint32(0xFFFFFFFF)\n", + " x = x ^ (x >> np.uint32(15))\n", + " return x\n", + "\n", + "# ---------------------------------------------------------------------\n", + "# 4. MurmurHash3 (x86 32-bit variant for 4-byte input)\n", + "# ---------------------------------------------------------------------\n", + "def hash_all_murmur3(code_points: np.ndarray, seed: np.uint32 = np.uint32(0)) -> np.ndarray:\n", + " code_points = code_points.astype(np.uint32)\n", + " c1 = np.uint32(0xcc9e2d51)\n", + " c2 = np.uint32(0x1b873593)\n", + " r1 = np.uint32(15)\n", + " r2 = np.uint32(13)\n", + " m = np.uint32(5)\n", + " n = np.uint32(0xe6546b64)\n", + " \n", + " # Treat each 32-bit integer as 4 bytes of data.\n", + " k = (code_points * c1) & np.uint32(0xFFFFFFFF)\n", + " k = ((k << r1) | (k >> (32 - r1))) & np.uint32(0xFFFFFFFF)\n", + " k = (k * c2) & np.uint32(0xFFFFFFFF)\n", + " \n", + " h = seed ^ k\n", + " h = ((h << r2) | (h >> (32 - r2))) & np.uint32(0xFFFFFFFF)\n", + " h = (h * m + n) & np.uint32(0xFFFFFFFF)\n", + " \n", + " # Since input length is always 4 bytes for a 32-bit integer:\n", + " h ^= np.uint32(4)\n", + " # Finalization mix\n", + " h ^= (h >> np.uint32(16))\n", + " h = (h * np.uint32(0x85ebca6b)) & np.uint32(0xFFFFFFFF)\n", + " h ^= (h >> np.uint32(13))\n", + " h = (h * np.uint32(0xc2b2ae35)) & np.uint32(0xFFFFFFFF)\n", + " h ^= (h >> np.uint32(16))\n", + " return h" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "def count_unique(x: np.ndarray) -> int:\n", + " # This approach is about 50% faster than `len(np.unique(x))`.\n", + " if x.size == 0:\n", + " return 0\n", + " xs = np.sort(x)\n", + " return int(np.count_nonzero(np.diff(xs)) + 1)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "def rotate_left(x: np.ndarray, r: int) -> np.ndarray:\n", + " \"\"\"Rotate left the 32-bit integers in x by r bits.\"\"\"\n", + " return ((x << np.uint32(r)) | (x >> np.uint32(32 - r))) & np.uint32(0xFFFFFFFF)\n", + "\n", + "def hash_custom(code_points: np.ndarray) -> np.ndarray:\n", + " \"\"\"\n", + " Compute a composite hash on an array of 32-bit integers.\n", + " The hash is a combination of multiplications, rotations, and XOR mixing.\n", + " \"\"\"\n", + " # Ensure code_points are treated as 32-bit 
unsigned integers.\n", + " x = code_points.astype(np.uint32)\n", + " \n", + " # First mixing stage:\n", + " # Multiply by a constant and then rotate left.\n", + " x = (x * np.uint32(0xcc9e2d51)) & np.uint32(0xFFFFFFFF)\n", + " x = rotate_left(x, 15)\n", + " \n", + " # Second stage: XOR with a constant.\n", + " x ^= np.uint32(0x1b873593)\n", + " \n", + " # Third stage: Multiply and then rotate.\n", + " x = (x * np.uint32(0x85ebca6b)) & np.uint32(0xFFFFFFFF)\n", + " x = rotate_left(x, 13)\n", + " \n", + " # Fourth stage: Final XOR mix.\n", + " x ^= np.uint32(0xc2b2ae35)\n", + " \n", + " # Optionally, perform one more multiplication to scramble bits further.\n", + " x = (x * np.uint32(0x27d4eb2d)) & np.uint32(0xFFFFFFFF)\n", + " \n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Jenkins One-at-a-Time:\n", + "Unique hashes: 154,839 = 99.8555%\n", + "Unique hashes (modulo size): 97,909 = 63.1414%\n", + "Unique hashes (modulo 262144): 117,278 = 75.6325%\n", + "\n", + "FNV-1a:\n", + "Unique hashes: 155,063 = 100.0000%\n", + "Unique hashes (modulo size): 101,078 = 65.1851%\n", + "Unique hashes (modulo 262144): 104,858 = 67.6228%\n", + "\n", + "Thomas Wang's Hash:\n", + "Unique hashes: 155,063 = 100.0000%\n", + "Unique hashes (modulo size): 98,080 = 63.2517%\n", + "Unique hashes (modulo 262144): 117,089 = 75.5106%\n", + "\n", + "MurmurHash3:\n", + "Unique hashes: 155,063 = 100.0000%\n", + "Unique hashes (modulo size): 98,034 = 63.2220%\n", + "Unique hashes (modulo 262144): 116,970 = 75.4339%\n", + "\n", + "Custom:\n", + "Unique hashes: 155,063 = 100.0000%\n", + "Unique hashes (modulo size): 98,139 = 63.2898%\n", + "Unique hashes (modulo 262144): 116,442 = 75.0933%\n" + ] + } + ], + "source": [ + "for name, func in [\n", + " ('Jenkins One-at-a-Time', hash_all_jenkins),\n", + " ('FNV-1a', hash_all_fnv1a),\n", + " (\"Thomas Wang's Hash\", hash_all_thomas_wang),\n", + " ('MurmurHash3', hash_all_murmur3),\n", + " ('Custom', hash_custom),\n", + "]:\n", + " print(f\"\\n{name}:\")\n", + " hashes = func(code_points)\n", + " \n", + " unique_hashes = count_unique(hashes)\n", + " print(f\"Unique hashes: {unique_hashes:,} = {unique_hashes / len(code_points):.4%}\")\n", + " \n", + " # Lets estimate the number of collisions for different modulo values\n", + " hashes_modulo_valid = hashes % len(code_points)\n", + " unique_hashes_modulo_valid = count_unique(hashes_modulo_valid)\n", + " print(f\"Unique hashes (modulo size): {unique_hashes_modulo_valid:,} = {unique_hashes_modulo_valid / len(code_points):.4%}\")\n", + " \n", + " # Try the next power of 2 for modulo size\n", + " bitceil = 2 ** 18\n", + " hashes_modulo_bitceil = hashes % bitceil\n", + " unique_hashes_modulo_bitceil = count_unique(hashes_modulo_bitceil)\n", + " print(f\"Unique hashes (modulo {bitceil}): {unique_hashes_modulo_bitceil:,} = {unique_hashes_modulo_bitceil / len(code_points):.4%}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We end up with a fairly high collision rate of around 37% with vocabulary-size modulo and slightly more tolerable 25% with the next power of two.\n", + "Still, that's far from perfect-hashing.\n", + "Let's try different multiplicative hash-functions and see if we can find a better one." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "# Take the range of all 32-bit unsigned integers\n", + "# ? 
Random shuffling to simplify the search for the first multiplier is a good idea,\n", + "# ? but it would take forever to run on 4 billion elements in Python.\n", + "# ! all_integers = np.arange(1, 2**32, dtype=np.uint32)\n", + "# ! np.random.shuffle(all_integers)\n", + "all_integers = np.random.permutation(2**32)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Memory usage for all_integers: 17,179,869,184 bytes\n" + ] + } + ], + "source": [ + "all_integers = all_integers.astype(np.uint32)\n", + "print(f\"Memory usage for all_integers: {all_integers.nbytes:,} bytes\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Take the range of all 32-bit unsigned integers\n", + "bitceil = 2 ** 18\n", + "for multiplier in tqdm(all_integers):\n", + " hashes = code_points * multiplier\n", + " \n", + " # Lets estimate the number of collisions for different modulo values\n", + " hashes_modulo_valid = hashes % len(code_points)\n", + " unique_hashes_modulo_valid = count_unique(hashes_modulo_valid)\n", + " if unique_hashes_modulo_valid == len(code_points):\n", + " print(f\"Multiplier (modulo size): {multiplier}\")\n", + " break\n", + " \n", + " # Try the next power of 2 for modulo size\n", + " hashes_modulo_bitceil = hashes % bitceil\n", + " unique_hashes_modulo_bitceil = count_unique(hashes_modulo_bitceil)\n", + " if unique_hashes_modulo_bitceil == len(code_points):\n", + " print(f\"Multiplier (modulo {bitceil}): {multiplier}\")\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from numba import uint32\n", + "\n", + "@njit(nopython=True)\n", + "def check_multiplier(code_points: np.ndarray, multiplier: uint32, seen_flags: np.ndarray) -> bool:\n", + " \"\"\"\n", + " Check if the multiplier produces a perfect hash mapping\n", + " for the given code_points with modulus `len(seen_flags)`.\n", + " Returns True if no collisions are found, False otherwise.\n", + " \"\"\"\n", + " # Create an array of flags for each hash value.\n", + " n = code_points.shape[0]\n", + " modulo = uint32(len(seen_flags))\n", + " for i in range(n):\n", + " # Compute hash value (simulate 32-bit wrap-around implicitly via modulo arithmetic)\n", + " h = uint32(code_points[i] * multiplier) % modulo\n", + " if seen_flags[h] == 1:\n", + " # Collision found.\n", + " return False\n", + " seen_flags[h] = 1\n", + " return True\n" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 13887159/4294967296 [2:11:09<673:54:25, 1764.62it/s]\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[18]\u001b[39m\u001b[32m, line 14\u001b[39m\n\u001b[32m 11\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mMultiplier (modulo size): \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmultiplier\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m 12\u001b[39m \u001b[38;5;28;01mbreak\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m14\u001b[39m 
\u001b[38;5;28;01mif\u001b[39;00m \u001b[43mcheck_multiplier\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcode_points\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmultiplier\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mseen_modulo_bitceil\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[32m 15\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mMultiplier (modulo \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mbitceil\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m): \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmultiplier\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m 16\u001b[39m \u001b[38;5;28;01mbreak\u001b[39;00m\n", + "\u001b[31mKeyboardInterrupt\u001b[39m: " + ] + } + ], + "source": [ + "# Take the range of all 32-bit unsigned integers\n", + "seen_modulo_vocabulary = np.zeros(len(code_points), dtype=np.uint8)\n", + "seen_modulo_bitceil = np.zeros(2 ** 18, dtype=np.uint8)\n", + "\n", + "for multiplier in tqdm(all_integers):\n", + " seen_modulo_vocabulary.fill(0)\n", + " seen_modulo_bitceil.fill(0)\n", + "\n", + " # Lets estimate the number of collisions for different modulo values\n", + " if check_multiplier(code_points, multiplier, seen_modulo_vocabulary):\n", + " print(f\"Multiplier (modulo size): {multiplier}\")\n", + " break\n", + "\n", + " if check_multiplier(code_points, multiplier, seen_modulo_bitceil):\n", + " print(f\"Multiplier (modulo {bitceil}): {multiplier}\")\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 63daa5f56e3ab92a528936d97d55eab379e29af8 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 9 Mar 2025 05:33:52 +0000 Subject: [PATCH 158/751] Add: `status_t` for errors in C++ --- include/stringzilla/stringzilla.hpp | 152 +++++++++++++++++----------- 1 file changed, 93 insertions(+), 59 deletions(-) diff --git a/include/stringzilla/stringzilla.hpp b/include/stringzilla/stringzilla.hpp index 98817dd7..87e14831 100644 --- a/include/stringzilla/stringzilla.hpp +++ b/include/stringzilla/stringzilla.hpp @@ -1062,6 +1062,31 @@ std::size_t range_length(iterator_type first, iterator_type last) { #pragma endregion +#pragma region Helper Types + +enum class status_t { + success_k = sz_success_k, + bad_alloc_k = sz_bad_alloc_k, + invalid_utf8_k = sz_invalid_utf8_k, + contains_duplicates_k = sz_contains_duplicates_k, +}; + +#if !SZ_AVOID_STL +void raise(status_t status) noexcept(false) { + switch (status) { + case status_t::bad_alloc_k: throw std::bad_alloc(); + case status_t::invalid_utf8_k: throw std::invalid_argument("Invalid UTF-8 string"); + case status_t::contains_duplicates_k: throw std::invalid_argument("Array contains identical strings"); + default: break; + } +} + +using sorted_idx_t = sz_sorted_idx_t; + +#endif + +#pragma endregion + #pragma region Global Operations with Dynamic Memory template @@ -1082,16 +1107,20 @@ static sz_u64_t _call_random_generator(void *state) noexcept { } template -static bool _with_alloc(allocator_type_ &allocator, 
allocator_callback_ &&callback) noexcept { +static status_t _with_alloc(allocator_type_ &allocator, allocator_callback_ &&callback) noexcept { sz_memory_allocator_t alloc; alloc.allocate = &_call_allocate; alloc.free = &_call_free; alloc.handle = &allocator; - return callback(alloc); + return static_cast(callback(alloc)); } +/** + * @brief Helper function, wrapping a C++ allocator into a C-style allocator. + * @return Error code or success. All allocating functions may fail. + */ template -static bool _with_alloc(allocator_callback_ &&callback) noexcept { +static status_t _with_alloc(allocator_callback_ &&callback) noexcept { allocator_type_ allocator; return _with_alloc(allocator, std::forward(callback)); } @@ -2038,23 +2067,23 @@ class basic_string { static_assert(std::is_empty::value, "We currently only support stateless allocators"); template - static bool _with_alloc(allocator_callback &&callback) noexcept { + static status_t _with_alloc(allocator_callback &&callback) noexcept { return ashvardanian::stringzilla::_with_alloc(callback); } void init(std::size_t length, char_type value) noexcept(false) { sz_ptr_t start; - if (!_with_alloc( - [&](sz_alloc_type &alloc) { return (start = sz_string_init_length(&string_, length, &alloc)); })) - throw std::bad_alloc(); + raise(_with_alloc([&](sz_alloc_type &alloc) { + return (start = sz_string_init_length(&string_, length, &alloc)) ? sz_success_k : sz_bad_alloc_k; + })); sz_fill(start, length, *(sz_u8_t *)&value); } void init(string_view other) noexcept(false) { sz_ptr_t start; - if (!_with_alloc( - [&](sz_alloc_type &alloc) { return (start = sz_string_init_length(&string_, other.size(), &alloc)); })) - throw std::bad_alloc(); + raise(_with_alloc([&](sz_alloc_type &alloc) { + return (start = sz_string_init_length(&string_, other.size(), &alloc)) ? sz_success_k : sz_bad_alloc_k; + })); sz_copy(start, (sz_cptr_t)other.data(), other.size()); } @@ -2121,7 +2150,7 @@ class basic_string { ~basic_string() noexcept { _with_alloc([&](sz_alloc_type &alloc) { sz_string_free(&string_, &alloc); - return true; + return sz_success_k; }); } @@ -2130,7 +2159,7 @@ class basic_string { if (!is_internal()) { _with_alloc([&](sz_alloc_type &alloc) { sz_string_free(&string_, &alloc); - return true; + return sz_success_k; }); } move(other); @@ -2750,7 +2779,10 @@ class basic_string { * @return `true` if the operation was successful and potentially reduced the memory footprint, `false` otherwise. */ bool try_shrink_to_fit() noexcept { - return _with_alloc([&](sz_alloc_type &alloc) { return sz_string_shrink_to_fit(&string_, &alloc); }); + auto status = _with_alloc([&](sz_alloc_type &alloc) { + return sz_string_shrink_to_fit(&string_, &alloc) ? sz_success_k : sz_bad_alloc_k; + }); + return status == status_t::success_k; } /** @@ -2759,7 +2791,10 @@ class basic_string { * @return `true` if the reservation was successful, `false` otherwise. */ bool try_reserve(size_type capacity) noexcept { - return _with_alloc([&](sz_alloc_type &alloc) { return sz_string_reserve(&string_, capacity, &alloc); }); + auto status = _with_alloc([&](sz_alloc_type &alloc) { + return sz_string_reserve(&string_, capacity, &alloc) ? 
sz_success_k : sz_bad_alloc_k; + }); + return status == status_t::success_k; } /** @@ -2827,9 +2862,10 @@ class basic_string { bool try_insert(difference_type signed_offset, string_view string) noexcept { sz_size_t normalized_offset, normalized_length; sz_ssize_clamp_interval(size(), signed_offset, 0, &normalized_offset, &normalized_length); - if (!_with_alloc([&](sz_alloc_type &alloc) { - return sz_string_expand(&string_, normalized_offset, string.size(), &alloc); - })) + if (_with_alloc([&](sz_alloc_type &alloc) { + return sz_string_expand(&string_, normalized_offset, string.size(), &alloc) ? sz_success_k + : sz_bad_alloc_k; + }) != status_t::success_k) return false; sz_copy(data() + normalized_offset, string.data(), string.size()); @@ -2909,10 +2945,9 @@ class basic_string { basic_string &insert(size_type offset, string_view other) noexcept(false) { if (offset > size()) throw std::out_of_range("sz::basic_string::insert"); if (size() + other.size() > max_size()) throw std::length_error("sz::basic_string::insert"); - if (!_with_alloc( - [&](sz_alloc_type &alloc) { return sz_string_expand(&string_, offset, other.size(), &alloc); })) - throw std::bad_alloc(); - + raise(_with_alloc([&](sz_alloc_type &alloc) { + return sz_string_expand(&string_, offset, other.size(), &alloc) ? sz_success_k : sz_bad_alloc_k; + })); sz_copy(data() + offset, other.data(), other.size()); return *this; } @@ -2977,8 +3012,9 @@ class basic_string { auto added_length = range_length(first, last); if (size() + added_length > max_size()) throw std::length_error("sz::basic_string::insert"); - if (!_with_alloc([&](sz_alloc_type &alloc) { return sz_string_expand(&string_, pos, added_length, &alloc); })) - throw std::bad_alloc(); + raise(_with_alloc([&](sz_alloc_type &alloc) { + return sz_string_expand(&string_, pos, added_length, &alloc) ? sz_success_k : sz_bad_alloc_k; + })); iterator result = begin() + pos; for (iterator output = result; first != last; ++first, ++output) *output = *first; @@ -3327,8 +3363,7 @@ class basic_string { size_type edit_distance(string_view other, size_type bound = 0) const noexcept { size_type result; _with_alloc([&](sz_alloc_type &alloc) { - return sz_levenshtein_distance(data(), size(), other.data(), other.size(), bound, &alloc, &result) != - sz_bad_alloc_k; + return sz_levenshtein_distance(data(), size(), other.data(), other.size(), bound, &alloc, &result); }); return result; } @@ -3485,9 +3520,10 @@ bool basic_string::try_resize(size_type count, value_typ // Allocate more space if needed. if (count >= string_space) { - if (!_with_alloc([&](sz_alloc_type &alloc) { - return sz_string_expand(&string_, SZ_SIZE_MAX, count - string_length, &alloc) != NULL; - })) + if (_with_alloc([&](sz_alloc_type &alloc) { + return sz_string_expand(&string_, SZ_SIZE_MAX, count - string_length, &alloc) ? sz_success_k + : sz_bad_alloc_k; + }) != status_t::success_k) return false; sz_string_unpack(&string_, &string_start, &string_length, &string_space, &string_is_external); } @@ -3525,12 +3561,12 @@ bool basic_string::try_assign(string_view other) noexcep } // In the common case, however, we need to allocate. 
else { - if (!_with_alloc([&](sz_alloc_type &alloc) { + if (_with_alloc([&](sz_alloc_type &alloc) { string_start = sz_string_expand(&string_, SZ_SIZE_MAX, other.length() - string_length, &alloc); - if (!string_start) return false; + if (!string_start) return sz_bad_alloc_k; other.copy(string_start, other.length()); - return true; - })) + return sz_success_k; + }) != status_t::success_k) return false; } return true; @@ -3538,18 +3574,19 @@ bool basic_string::try_assign(string_view other) noexcep template bool basic_string::try_push_back(char_type c) noexcept { - return _with_alloc([&](sz_alloc_type &alloc) { + auto result = _with_alloc([&](sz_alloc_type &alloc) { auto old_size = size(); sz_ptr_t start = sz_string_expand(&string_, SZ_SIZE_MAX, 1, &alloc); - if (!start) return false; + if (!start) return sz_bad_alloc_k; start[old_size] = c; - return true; + return sz_success_k; }); + return result == status_t::success_k; } template bool basic_string::try_append(const_pointer str, size_type length) noexcept { - return _with_alloc([&](sz_alloc_type &alloc) { + auto result = _with_alloc([&](sz_alloc_type &alloc) { // Sometimes we are inserting part of this string into itself. // By the time `sz_string_expand` finished, the old `str` pointer may be invalidated, // so we need to handle that special case separately. @@ -3557,16 +3594,17 @@ bool basic_string::try_append(const_pointer str, size_ty if (str >= this_span.begin() && str < this_span.end()) { auto str_offset_in_this = str - data(); sz_ptr_t start = sz_string_expand(&string_, SZ_SIZE_MAX, length, &alloc); - if (!start) return false; + if (!start) return sz_bad_alloc_k; sz_copy(start + this_span.size(), start + str_offset_in_this, length); } else { sz_ptr_t start = sz_string_expand(&string_, SZ_SIZE_MAX, length, &alloc); - if (!start) return false; + if (!start) return sz_bad_alloc_k; sz_copy(start + this_span.size(), str, length); } - return true; + return sz_success_k; }); + return result == status_t::success_k; } template @@ -3677,12 +3715,12 @@ bool basic_string::try_assign(concatenation::try_preparing_replacement( // assert(offset + length <= size()); // 1. The replacement is the same length as the replaced range. - if (replacement_length == length) { return true; } + if (replacement_length == length) return true; // 2. The replacement is shorter than the replaced range. else if (replacement_length < length) { @@ -3708,9 +3746,11 @@ bool basic_string::try_preparing_replacement( // } // 3. The replacement is longer than the replaced range. An allocation may occur. else { - return _with_alloc([&](sz_alloc_type &alloc) { - return sz_string_expand(&string_, offset + length, replacement_length - length, &alloc); + auto result = _with_alloc([&](sz_alloc_type &alloc) { + return sz_string_expand(&string_, offset + length, replacement_length - length, &alloc) ? 
sz_success_k + : sz_bad_alloc_k; }); + return result == status_t::success_k; } } @@ -3849,11 +3889,9 @@ std::size_t edit_distance( // basic_string_slice const &a, basic_string_slice const &b, std::size_t bound = SZ_SIZE_MAX, allocator_type_ &&allocator = allocator_type_ {}) noexcept(false) { std::size_t result; - if (!_with_alloc(allocator, [&](sz_memory_allocator_t &alloc) { - return sz_levenshtein_distance(a.data(), a.size(), b.data(), b.size(), bound, &alloc, &result) != - sz_bad_alloc_k; - })) - throw std::bad_alloc(); + raise(_with_alloc(allocator, [&](sz_memory_allocator_t &alloc) { + return sz_levenshtein_distance(a.data(), a.size(), b.data(), b.size(), bound, &alloc, &result); + })); return result; } @@ -3877,11 +3915,9 @@ std::size_t edit_distance_utf8( basic_string_slice const &a, basic_string_slice const &b, // std::size_t bound = SZ_SIZE_MAX, allocator_type_ &&allocator = allocator_type_ {}) noexcept(false) { std::size_t result; - if (!_with_alloc(allocator, [&](sz_memory_allocator_t &alloc) { - return sz_levenshtein_distance_utf8(a.data(), a.size(), b.data(), b.size(), bound, &alloc, &result) != - sz_bad_alloc_k; - })) - throw std::bad_alloc(); + raise(_with_alloc(allocator, [&](sz_memory_allocator_t &alloc) { + return sz_levenshtein_distance_utf8(a.data(), a.size(), b.data(), b.size(), bound, &alloc, &result); + })); return result; } @@ -3911,11 +3947,9 @@ std::ptrdiff_t alignment_score( "sz_error_cost_t must be signed."); std::ptrdiff_t result; - if (!_with_alloc(allocator, [&](sz_memory_allocator_t &alloc) { - return sz_needleman_wunsch_score(a.data(), a.size(), b.data(), b.size(), &subs[0][0], gap, &alloc, - &result) != sz_bad_alloc_k; - })) - throw std::bad_alloc(); + raise(_with_alloc(allocator, [&](sz_memory_allocator_t &alloc) { + return sz_needleman_wunsch_score(a.data(), a.size(), b.data(), b.size(), &subs[0][0], gap, &alloc, &result); + })); return result; } From 1ce830bfda97715e126d5cefade0e9aa4c1c1ca2 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 9 Mar 2025 05:37:31 +0000 Subject: [PATCH 159/751] Break: C++ `lookup` and `fill_random` --- include/stringzilla/stringzilla.hpp | 27 ++++++++++++--------------- scripts/test.cpp | 6 +++--- 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/include/stringzilla/stringzilla.hpp b/include/stringzilla/stringzilla.hpp index 87e14831..91b2480b 100644 --- a/include/stringzilla/stringzilla.hpp +++ b/include/stringzilla/stringzilla.hpp @@ -3395,9 +3395,9 @@ class basic_string { * In this case the undefined behaviour in concurrent environments may play in our favor, * but it's recommended to use the other overload in such cases. */ - basic_string &randomize() noexcept { + basic_string &fill_random() noexcept { static sz_u64_t nonce = 42; - return randomize(nonce++); + return fill_random(nonce++); } /** @@ -3407,7 +3407,7 @@ class basic_string { * @throw `std::bad_alloc` if the allocation fails. */ static basic_string random(size_type length, sz_u64_t nonce) noexcept(false) { - return basic_string(length, '\0').randomize(nonce); + return basic_string(length, '\0').fill_random(nonce); } /** @@ -3415,7 +3415,7 @@ class basic_string { * @param[in] length The length of the generated string. * @throw `std::bad_alloc` if the allocation fails. 
*/ - static basic_string random(size_type length) noexcept(false) { return basic_string(length, '\0').randomize(); } + static basic_string random(size_type length) noexcept(false) { return basic_string(length, '\0').fill_random(); } /** * @brief Replaces @b (in-place) all occurrences of a given string with the ::replacement string. @@ -3471,8 +3471,8 @@ class basic_string { * @brief Replaces @b (in-place) all characters in the string using the provided lookup @p table. * @sa sz_lookup */ - basic_string &transform(look_up_table const &table) noexcept { - transform(table, data()); + basic_string &lookup(look_up_table const &table) noexcept { + lookup(table, data()); return *this; } @@ -3481,7 +3481,7 @@ class basic_string { * @param[in] output The buffer to write the transformed string into. * @sa sz_lookup */ - void transform(look_up_table const &table, pointer output) const noexcept { + void lookup(look_up_table const &table, pointer output) const noexcept { sz_ptr_t start; sz_size_t length; sz_string_range(&string_, &start, &length); @@ -3971,7 +3971,7 @@ std::ptrdiff_t alignment_score( * @sa sz_fill_random */ template -void randomize(basic_string_slice string, sz_u64_t nonce) noexcept { +void fill_random(basic_string_slice string, sz_u64_t nonce) noexcept { static_assert(!std::is_const::value, "The string must be mutable."); sz_fill_random(string.data(), string.size(), nonce); } @@ -3982,9 +3982,8 @@ void randomize(basic_string_slice string, sz_u64_t nonce) noexcept { * @sa sz_fill_random */ template -void lookup(basic_string_slice string, basic_look_up_table const &table) noexcept { - static_assert(sizeof(char_type_) == 1, "The character type must be 1 byte long."); - sz_lookup((sz_ptr_t)string.data(), (sz_size_t)string.size(), (sz_cptr_t)string.data(), (sz_cptr_t)table.raw()); +void fill_random(basic_string_slice string) noexcept { + fill_random(string, std::rand()); } /** @@ -4004,12 +4003,10 @@ void lookup( // * @sa sz_lookup */ template -void randomize(basic_string_slice string, string_view alphabet = "abcdefghijklmnopqrstuvwxyz") noexcept { - randomize(string, std::rand, alphabet); +void lookup(basic_string_slice string, basic_look_up_table const &table) noexcept { + lookup(string, table, string.data()); } -using sorted_idx_t = sz_sorted_idx_t; - /** * @brief Internal data-structure used to wrap arbitrary sequential containers with a random-order lookup. * @sa try_argsort, argsort, try_join, join diff --git a/scripts/test.cpp b/scripts/test.cpp index 63452df4..aefcb638 100644 --- a/scripts/test.cpp +++ b/scripts/test.cpp @@ -1079,9 +1079,9 @@ void test_non_stl_extensions_for_updates() { sz::look_up_table invert_case = sz::look_up_table::identity(); for (char c = 'a'; c <= 'z'; c++) invert_case[c] = c - 'a' + 'A'; for (char c = 'A'; c <= 'Z'; c++) invert_case[c] = c - 'A' + 'a'; - assert_scoped(str s = "hello", s.transform(invert_case), s == "HELLO"); - assert_scoped(str s = "HeLLo", s.transform(invert_case), s == "hEllO"); - assert_scoped(str s = "H-lL0", s.transform(invert_case), s == "h-Ll0"); + assert_scoped(str s = "hello", s.lookup(invert_case), s == "HELLO"); + assert_scoped(str s = "HeLLo", s.lookup(invert_case), s == "hEllO"); + assert_scoped(str s = "H-lL0", s.lookup(invert_case), s == "h-Ll0"); // Concatenation. 
assert(str(str("a") | str("b")) == "ab"); From 5ea06981d27510b86d70cbdca37378fdd0b10c6c Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 9 Mar 2025 05:38:39 +0000 Subject: [PATCH 160/751] Add: C++ `argsort`, `intersect` --- include/stringzilla/stringzilla.hpp | 190 +++++++++++++++++++--------- 1 file changed, 133 insertions(+), 57 deletions(-) diff --git a/include/stringzilla/stringzilla.hpp b/include/stringzilla/stringzilla.hpp index 91b2480b..a8c077d6 100644 --- a/include/stringzilla/stringzilla.hpp +++ b/include/stringzilla/stringzilla.hpp @@ -4011,27 +4011,25 @@ void lookup(basic_string_slice string, basic_look_up_table +template struct _sequence_args { - objects_type_ const *begin; - std::size_t count; - sorted_idx_t *order; - string_extractor_ extractor; + container_type_ const &container; + string_extractor_ const &extractor; }; -template -sz_cptr_t _call_sequence_member_start(void const *sequence, sz_size_t i) { - using handle_type = _sequence_args; - handle_type const *args = reinterpret_cast(sequence); - string_view member = args->extractor(args->begin[i]); +template +sz_cptr_t _call_sequence_member_start(void const *sequence_args_ptr, sz_size_t i) { + using sequence_args_t = _sequence_args; + sequence_args_t const *args = reinterpret_cast(sequence_args_ptr); + string_view member = args->extractor(args->container[i]); return member.data(); } -template -sz_size_t _call_sequence_member_length(void const *sequence, sz_size_t i) { - using handle_type = _sequence_args; - handle_type const *args = reinterpret_cast(sequence); - string_view member = args->extractor(args->begin[i]); +template +sz_size_t _call_sequence_member_length(void const *sequence_args_ptr, sz_size_t i) { + using sequence_args_t = _sequence_args; + sequence_args_t const *args = reinterpret_cast(sequence_args_ptr); + string_view member = args->extractor(args->container[i]); return static_cast(member.size()); } @@ -4039,40 +4037,79 @@ sz_size_t _call_sequence_member_length(void const *sequence, sz_size_t i) { * @brief Computes the permutation of an array, that would lead to sorted order. * The elements of the array must be convertible to a `string_view` with the given extractor. * Unlike the `sz_sequence_argsort` C interface, overwrites the output array. + * @sa sz_sequence_argsort * - * @param[in] begin The pointer to the first element of the array. - * @param[in] end The pointer to the element after the last element of the array. - * @param[out] order The pointer to the output array of indices, that will be populated with the permutation. - * @param[in] extractor The function object that extracts the string from the object. - * - * @see sz_sequence_argsort + * @param[in] begin The pointer to the first element of the array. + * @param[in] end The pointer to the element after the last element of the array. + * @param[in] extractor The function object that extracts the string from the object. + * @param[out] order The pointer to the output array of indices, that will be populated with the permutation. */ -template -void argsort(objects_type_ const *begin, objects_type_ const *end, sorted_idx_t *order, - string_extractor_ &&extractor) noexcept { +template +status_t try_argsort(container_type_ const &container, string_extractor_ const &extractor, + sorted_idx_t *order) noexcept { // Pack the arguments into a single structure to reference it from the callback. 
- _sequence_args args = {begin, static_cast(end - begin), order, - std::forward(extractor)}; - // Populate the array with `iota`-style order. - for (std::size_t i = 0; i != args.count; ++i) order[i] = static_cast(i); + using args_t = _sequence_args; + args_t args {container, extractor}; + sz_sequence_t sequence; + sequence.handle = &args; + sequence.count = container.size(); + sequence.get_start = _call_sequence_member_start; + sequence.get_length = _call_sequence_member_length; - sz_sequence_t array; - array.count = args.count; - array.handle = &args; - array.get_start = _call_sequence_member_start; - array.get_length = _call_sequence_member_length; + using sz_alloc_type = sz_memory_allocator_t; + return _with_alloc>( + [&](sz_alloc_type &alloc) { return sz_sequence_argsort(&sequence, &alloc, order); }); +} + +/** + * @brief Locates the positions of the elements in 2 deduplicated string arrays that have identical values. + * @sa sz_sequence_join + * + * @param[in] first_begin The pointer to the first element of the first array. + * @param[in] first_end The pointer to the element after the last element of the first array. + * @param[in] second_begin The pointer to the first element of the second array. + * @param[in] second_end The pointer to the element after the last element of the second array. + * @param[out] first_positions The pointer to the output array of indices from the first array. + * @param[out] second_positions The pointer to the output array of indices from the second array. + * @param[in] first_extractor The function object that extracts the string from the object in the first array. + * @param[in] second_extractor The function object that extracts the string from the object in the second array. + */ +template +status_t try_intersect( // + first_container_ const &first_container, first_extractor_ const &first_extractor, // + second_container_ const &second_container, second_extractor_ const &second_extractor, // + std::uint64_t seed, std::size_t *intersection_size_ptr, // + sorted_idx_t *first_positions, sorted_idx_t *second_positions) noexcept { + + // Pack the arguments into a single structure to reference it from the callback. + using first_t = _sequence_args; + using second_t = _sequence_args; + first_t first_args {first_container, first_extractor}; + second_t second_args {second_container, second_extractor}; + + sz_sequence_t first_sequence, second_sequence; + first_sequence.count = first_container.size(), second_sequence.count = second_container.size(); + first_sequence.handle = &first_args, second_sequence.handle = &second_args; + first_sequence.get_start = _call_sequence_member_start; + first_sequence.get_length = _call_sequence_member_length; + second_sequence.get_start = _call_sequence_member_start; + second_sequence.get_length = _call_sequence_member_length; using sz_alloc_type = sz_memory_allocator_t; - _with_alloc>( - [&](sz_alloc_type &alloc) { return sz_sequence_argsort(&array, &alloc, order); }); + return _with_alloc>([&](sz_alloc_type &alloc) { + static_assert(sizeof(sz_size_t) == sizeof(std::size_t), "sz_size_t must be the same size as std::size_t."); + return sz_sequence_intersect(&first_sequence, &second_sequence, &alloc, static_cast(seed), + reinterpret_cast(intersection_size_ptr), first_positions, + second_positions); + }); } #if !SZ_AVOID_STL #if _SZ_DEPRECATED_FINGERPRINTS /** - * @brief Computes the Rabin-Karp-like rolling binary fingerprint of a string. 
- * @see sz_hashes + * @brief Computes the Rabin-Karp-like rolling binary fingerprint of a string. + * @sa sz_hashes */ template void hashes_fingerprint( // @@ -4105,41 +4142,80 @@ std::bitset hashes_fingerprint(basic_string const &str #endif /** - * @brief Computes the permutation of an array, that would lead to sorted order. + * @brief Computes the permutation of an array, that would lead to sorted order. * @return The array of indices, that will be populated with the permutation. - * @throw `std::bad_alloc` if the allocation fails. + * @throw `std::bad_alloc` if the allocation fails. */ -template +template std::vector argsort( // - objects_type_ const *begin, objects_type_ const *end, string_extractor_ &&extractor) noexcept(false) { - std::vector order(end - begin); - argsort(begin, end, order.data(), std::forward(extractor)); + container_type_ const &container, string_extractor_ const &extractor) noexcept(false) { + std::vector order(container.size()); + status_t status = try_argsort(container, extractor, order.data()); + raise(status); return order; } /** - * @brief Computes the permutation of an array, that would lead to sorted order. + * @brief Computes the permutation of an array, that would lead to sorted order. * @return The array of indices, that will be populated with the permutation. - * @throw `std::bad_alloc` if the allocation fails. + * @throw `std::bad_alloc` if the allocation fails. */ -template -std::vector argsort(string_like_type_ const *begin, string_like_type_ const *end) noexcept(false) { +template +std::vector argsort(container_type_ const &container) noexcept(false) { + using string_like_type = typename container_type_::value_type; static_assert( // - std::is_convertible::value, "The type must be convertible to string_view."); - return argsort(begin, end, [](string_like_type_ const &s) -> string_view { return s; }); + std::is_convertible::value, "The type must be convertible to string_view."); + return argsort(container, [](string_like_type const &s) -> string_view { return s; }); } +struct intersect_result_t { + std::vector first_offsets; + std::vector second_offsets; +}; + /** - * @brief Computes the permutation of an array, that would lead to sorted order. - * @return The array of indices, that will be populated with the permutation. - * @throw `std::bad_alloc` if the allocation fails. + * @brief Locates identical elements in two arrays. + * @return Two arrays of indicies, mapping the elements of the first and the second array that have identical values. + * @throw `std::bad_alloc` if the allocation fails. */ -template -std::vector argsort(std::vector const &array) noexcept(false) { +template +intersect_result_t intersect(first_type_ const &first, second_type_ const &second, + first_extractor_ const &first_extractor, second_extractor_ const &second_extractor, + std::uint64_t seed = 0) noexcept(false) { + + std::size_t const max_count = (std::min)(first.size(), second.size()); + std::vector first_positions(max_count); + std::vector second_positions(max_count); + std::size_t count; + status_t status = try_intersect( // + first, first_extractor, // + second, second_extractor, // + seed, &count, first_positions.data(), second_positions.data()); + raise(status); + first_positions.resize(count); + second_positions.resize(count); + return {std::move(first_positions), std::move(second_positions)}; +} + +/** + * @brief Locates identical elements in two arrays. 
+ * @return Two arrays of indicies, mapping the elements of the first and the second array that have identical values. + * @throw `std::bad_alloc` if the allocation fails. + */ +template +intersect_result_t intersect(first_type_ const &first, second_type_ const &second, + std::uint64_t seed = 0) noexcept(false) { + using first_string_type = typename first_type_::value_type; + using second_string_type = typename second_type_::value_type; + static_assert( // + std::is_convertible::value, "The type must be convertible to string_view."); static_assert( // - std::is_convertible::value, "The type must be convertible to string_view."); - return argsort(array.data(), array.data() + array.size(), - [](string_like_type_ const &s) -> string_view { return s; }); + std::is_convertible::value, "The type must be convertible to string_view."); + return intersect( + first, second, // + [](first_string_type const &s) -> string_view { return s; }, // + [](second_string_type const &s) -> string_view { return s; }, // + seed); } #endif From f656577f60ec12ee4eb1b17888b2fc22b71d6e88 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 9 Mar 2025 05:39:45 +0000 Subject: [PATCH 161/751] Improve: Fix minor inconsistencies --- .clang-format | 1 + .vscode/settings.json | 34 +++++++++++++++- README.md | 6 +-- include/stringzilla/hash.h | 63 ++++++++++++++++++++++++++++- include/stringzilla/memory.h | 63 ++--------------------------- include/stringzilla/similarity.h | 8 ++-- include/stringzilla/stringzilla.hpp | 51 ++++++++++++----------- rust/lib.rs | 38 ++++++++++++----- 8 files changed, 159 insertions(+), 105 deletions(-) diff --git a/.clang-format b/.clang-format index c97feb6f..1ce7d064 100644 --- a/.clang-format +++ b/.clang-format @@ -6,6 +6,7 @@ NamespaceIndentation: None ColumnLimit: 120 ReflowComments: true UseTab: Never +IndentPPDirectives: None AlignConsecutiveAssignments: false AlignConsecutiveDeclarations: false diff --git a/.vscode/settings.json b/.vscode/settings.json index a4925981..678b1305 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,5 +1,6 @@ { "C_Cpp.default.configurationProvider": "ms-vscode.cmake-tools", + "C_Cpp.dimInactiveRegions": false, // This may cause overheating. 
// https://github.com/microsoft/vscode-cpptools/issues/1816 "C_Cpp.workspaceParsingPriority": "low", @@ -17,6 +18,7 @@ }, "cmake.sourceDirectory": "${workspaceRoot}", "cSpell.words": [ + "aesdec", "allowoverlap", "aminoacid", "aminoacids", @@ -24,6 +26,7 @@ "Appleby", "ASAN", "ashvardanian", + "Aumasson", "Baeza", "basicsize", "bigram", @@ -32,10 +35,16 @@ "bioinformatics", "Bitap", "bitcast", + "bitceil", "BLOSUM", + "Borwein", + "Brase", "Brumme", + "Byteset", + "bytesum", "carray", "Cawley", + "chardet", "cheminformatics", "cibuildwheel", "CONCAT", @@ -43,6 +52,7 @@ "copydoc", "Corasick", "cptr", + "DRBG", "endregion", "endswith", "Eron", @@ -51,7 +61,9 @@ "getitem", "getslice", "Giancarlo", + "Giordano", "Gonnet", + "Gotoh", "Haswell", "Heikki", "hexdigits", @@ -65,6 +77,7 @@ "isprintable", "itemsize", "Jaccard", + "Kaitchuck", "Karp", "keeplinebreaks", "keepseparator", @@ -82,13 +95,18 @@ "memcpy", "Merkle-Damgård", "Mersenne", + "misalign", "MODINIT", + "Morten", + "Mosè", "MSVC", "napi", "nargsf", "ndim", "Needleman", "newfunc", + "ngram", + "ngrams", "NOARGS", "noexcept", "NOMINMAX", @@ -97,6 +115,9 @@ "numpy", "octdigits", "octogram", + "pgram", + "pgrams", + "Plouffe", "printables", "pytest", "Pythonic", @@ -104,7 +125,9 @@ "quadgram", "Raita", "readlines", + "Reini", "releasebuffer", + "repr", "rfind", "rfinds", "richcompare", @@ -116,6 +139,7 @@ "rsplits", "rstrip", "SIMD", + "sklearn", "Skylake", "splitlines", "ssize", @@ -138,10 +162,14 @@ "Vardanian", "VBMI", "vectorcallfunc", + "Vectorizer", "Wagner", "whitespaces", "Wunsch", "XDECREF", + "xmms", + "Yann", + "Yaroshevskiy", "Zilla" ], "editor.formatOnSave": true, @@ -149,7 +177,6 @@ 120 ], "files.associations": { - "*.tcc": "cpp", "__bit_reference": "cpp", "__bits": "cpp", "__config": "cpp", @@ -168,12 +195,14 @@ "__tree": "cpp", "__tuple": "cpp", "__verbose_abort": "cpp", + "*.tcc": "cpp", "algorithm": "cpp", "any": "cpp", "array": "cpp", "atomic": "cpp", "bit": "cpp", "bitset": "cpp", + "cassert": "cpp", "cctype": "cpp", "charconv": "c", "chrono": "cpp", @@ -231,6 +260,7 @@ "semaphore": "cpp", "set": "cpp", "shared_mutex": "cpp", + "sort.h": "c", "source_location": "cpp", "span": "cpp", "sstream": "cpp", @@ -269,6 +299,6 @@ "xstring": "cpp", "xtr1common": "cpp", "xtree": "cpp", - "xutility": "cpp", + "xutility": "cpp" } } \ No newline at end of file diff --git a/README.md b/README.md index 18aea8e2..a3121cb4 100644 --- a/README.md +++ b/README.md @@ -1072,11 +1072,11 @@ Similar to Python it also defines the commonly used character sets. 
auto protein = sz::string::random(300, "ARNDCQEGHILKMFPSTWYV"); // static method auto dna = sz::basic_string::random(3_000_000_000, "ACGT"); -dna.randomize("ACGT"); // `noexcept` pre-allocated version -dna.randomize(&std::rand, "ACGT"); // pass any generator, like `std::mt19937` +dna.fill_random("ACGT"); // `noexcept` pre-allocated version +dna.fill_random(&std::rand, "ACGT"); // pass any generator, like `std::mt19937` char uuid[36]; -sz::randomize(sz::string_span(uuid, 36), "0123456789abcdef-"); // Overwrite any buffer +sz::fill_random(sz::string_span(uuid, 36), "0123456789abcdef-"); // Overwrite any buffer ``` ### Bulk Replacements diff --git a/include/stringzilla/hash.h b/include/stringzilla/hash.h index e23b700a..1db8a4b3 100644 --- a/include/stringzilla/hash.h +++ b/include/stringzilla/hash.h @@ -843,7 +843,7 @@ SZ_INTERNAL void _sz_hash_minimal_init_haswell(_sz_hash_minimal_t *state, sz_u64 __m128i k1 = _mm_xor_si128(seed_vec, pi0); __m128i k2 = _mm_xor_si128(seed_vec, pi1); - // The first 128 bits of the "sum" and "AES" blocks are the same + // The first 128 bits of the "sum" and "AES" blocks are the same for the "minimal" and full state state->aes.xmm = k1; state->sum.xmm = k2; } @@ -1559,6 +1559,8 @@ SZ_INTERNAL void _sz_hash_state_update_ice(sz_hash_state_t *state) { SZ_PUBLIC sz_u64_t sz_hash_ice(sz_cptr_t start, sz_size_t length, sz_u64_t seed) { + // For short strings the "masked loads" are identical to Skylake-X and + // the "logic" is identical to Haswell. if (length <= 16) { // Initialize the AES block with a given seed _sz_hash_minimal_t state; @@ -1611,6 +1613,7 @@ SZ_PUBLIC sz_u64_t sz_hash_ice(sz_cptr_t start, sz_size_t length, sz_u64_t seed) _sz_hash_minimal_update_haswell(&state, data3_vec.xmm); return _sz_hash_minimal_finalize_haswell(&state, length); } + // This is where the logic differs from Skylake-X and other pre-Ice Lake CPUs: else { // Use a larger state to handle the main loop and add different offsets // to different lanes of the register @@ -1716,6 +1719,64 @@ SZ_PUBLIC void sz_fill_random_ice(sz_ptr_t output, sz_size_t length, sz_u64_t no } } +/** + * @brief A wider parallel analog of `_sz_hash_minimal_t`, which is not used for computing individual hashes, + * but for parallel hashing of @b short 4x separate strings under 16 bytes long. + * Useful for higher-level Database and Machine Learning operations. + */ +typedef struct _sz_hash_minimal_x4_t { + sz_u512_vec_t aes; + sz_u512_vec_t sum; + sz_u512_vec_t key; +} _sz_hash_minimal_x4_t; + +SZ_INTERNAL void _sz_hash_minimal_x4_init_ice(_sz_hash_minimal_x4_t *state, sz_u64_t seed) { + + // The key is made from the seed and half of it will be mixed with the length in the end + __m512i seed_vec = _mm512_set1_epi64(seed); + state->key.zmm = seed_vec; + + // XOR the user-supplied keys with the two "pi" constants + sz_u64_t const *pi = _sz_hash_pi_constants(); + __m512i pi0 = _mm512_load_si512((__m512i const *)(pi)); + __m512i pi1 = _mm512_load_si512((__m512i const *)(pi + 8)); + // We will load the entire 512-bit values, but will only use the first 128 bits, + // replicating it 4x times across the register. The `_mm512_shuffle_i64x2` is supposed to + // be faster than `_mm512_broadcast_i64x2` on Ice Lake. 
+ pi0 = _mm512_shuffle_i64x2(pi0, pi0, 0); + pi1 = _mm512_shuffle_i64x2(pi1, pi1, 0); + __m512i k1 = _mm512_xor_si512(seed_vec, pi0); + __m512i k2 = _mm512_xor_si512(seed_vec, pi1); + + // The first 128 bits of the "sum" and "AES" blocks are the same for the "minimal" and full state + state->aes.zmm = k1; + state->sum.zmm = k2; +} + +SZ_INTERNAL __m256i _sz_hash_minimal_x4_finalize_ice(_sz_hash_minimal_x4_t const *state, // + sz_size_t length0, sz_size_t length1, sz_size_t length2, + sz_size_t length3) { + __m512i const padded_lengths = _mm512_set_epi64(0, length3, 0, length2, 0, length1, 0, length0); + // Mix the length into the key + __m512i key_with_length = _mm512_add_epi64(state->key.zmm, padded_lengths); + // Combine the "sum" and the "AES" blocks + __m512i mixed_registers = _mm512_aesenc_epi128(state->sum.zmm, state->aes.zmm); + // Make sure the "key" mixes enough with the state, + // as with less than 2 rounds - SMHasher fails + __m512i mixed_within_register = + _mm512_aesenc_epi128(_mm512_aesenc_epi128(mixed_registers, key_with_length), mixed_registers); + // Extract the low 64 bits from each 128-bit lane - weirdly using the `permutexvar` instruction + // is cheaper than compressing instructions like `_mm512_maskz_compress_epi64`. + return _mm512_castsi512_si256( + _mm512_permutexvar_epi64(_mm512_set_epi64(0, 0, 0, 0, 6, 4, 2, 0), mixed_within_register)); +} + +SZ_INTERNAL void _sz_hash_minimal_x4_update_ice(_sz_hash_minimal_x4_t *state, __m512i blocks) { + __m512i const shuffle_mask = _mm512_load_si512((__m512i const *)_sz_hash_u8x16x4_shuffle()); + state->aes.zmm = _mm512_aesenc_epi128(state->aes.zmm, blocks); + state->sum.zmm = _mm512_add_epi64(_mm512_shuffle_epi8(state->sum.zmm, shuffle_mask), blocks); +} + #pragma clang attribute pop #pragma GCC pop_options #endif // SZ_USE_ICE diff --git a/include/stringzilla/memory.h b/include/stringzilla/memory.h index 79cd840c..5b14108c 100644 --- a/include/stringzilla/memory.h +++ b/include/stringzilla/memory.h @@ -5,12 +5,11 @@ * * Includes core APIs for contiguous memory operations: * - * - @b `sz_copy` - analog to `memcpy`, probably the most common operation in a computer - * - @b `sz_move` - analog to `memmove`, allowing overlapping memory regions, often used in string manipulation - * - @b `sz_fill` - analog to `memset`, often used to initialize memory with a constant value, like zero + * - @b `sz_copy` - analog to @b `memcpy`, probably the most common operation in a computer + * - @b `sz_move` - analog to @b `memmove`, allowing overlapping memory regions, often used in string manipulation + * - @b `sz_fill` - analog to @b `memset`, often used to initialize memory with a constant value, like zero * - @b `sz_lookup` - Look-Up Table @b (LUT) transformation of a string, mapping each byte to a new value * - TODO: @b `sz_lookup_utf8` - LUT transformation of a UTF8 string, which can be used for normalization - * - TODO: @b `sz_detect_encoding` - detects the character encoding similar to "iconv" or "chardet" tools * * All of the core APIs receive the target output buffer as the first argument, * and aim to minimize the number of "store" instructions, especially unaligned ones, @@ -1084,62 +1083,6 @@ SZ_PUBLIC void sz_lookup_ice(sz_ptr_t target, sz_size_t length, sz_cptr_t source } } -enum sz_encoding_t { - sz_encoding_unknown_k = 0, - sz_encoding_ascii_k = 1, - sz_encoding_utf8_k = 2, - sz_encoding_utf16_k = 3, - sz_encoding_utf32_k = 4, - sz_encoding_jwt_k = 5, - sz_encoding_base64_k = 6, - // Low priority encodings: - 
sz_encoding_utf8bom_k = 7, - sz_encoding_utf16le_k = 8, - sz_encoding_utf16be_k = 9, - sz_encoding_utf32le_k = 10, - sz_encoding_utf32be_k = 11, -}; - -// Character Set Detection is one of the most commonly performed operations in data processing with -// [Chardet](https://github.com/chardet/chardet), [Charset Normalizer](https://github.com/jawah/charset_normalizer), -// [cChardet](https://github.com/PyYoshi/cChardet) being the most commonly used options in the Python ecosystem. -// All of them are notoriously slow. -// -// Moreover, as of October 2024, UTF-8 is the dominant character encoding on the web, used by 98.4% of websites. -// Other have minimal usage, according to [W3Techs](https://w3techs.com/technologies/overview/character_encoding): -// - ISO-8859-1: 1.2% -// - Windows-1252: 0.3% -// - Windows-1251: 0.2% -// - EUC-JP: 0.1% -// - Shift JIS: 0.1% -// - EUC-KR: 0.1% -// - GB2312: 0.1% -// - Windows-1250: 0.1% -// Within programming language implementations and database management systems, 16-bit and 32-bit fixed-width encodings -// are also very popular and we need a way to efficienly differentiate between the most common UTF flavors, ASCII, and -// the rest. -// -// One good solution is the [simdutf](https://github.com/simdutf/simdutf) library, but it depends on the C++ runtime -// and focuses more on incremental validation & transcoding, rather than detection. -// -// So we need a very fast and efficient way of determining -SZ_PUBLIC sz_bool_t sz_detect_encoding(sz_cptr_t text, sz_size_t length) { - // https://github.com/simdutf/simdutf/blob/master/src/icelake/icelake_utf8_validation.inl.cpp - // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_from_utf8.inl.cpp#L81 - // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_utf8_common.inl.cpp#L661 - // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_utf8_common.inl.cpp#L788 - - // We can implement this operation simpler & differently, assuming most of the time continuous chunks of memory - // have identical encoding. With Russian and many European languages, we generally deal with 2-byte codepoints - // with occasional 1-byte punctuation marks. In the case of Chinese, Japanese, and Korean, we deal with 3-byte - // codepoints. In the case of emojis, we deal with 4-byte codepoints. - // We can also use the idea, that misaligned reads are quite cheap on modern CPUs. - int can_be_ascii = 1, can_be_utf8 = 1, can_be_utf16 = 1, can_be_utf32 = 1; - sz_unused(can_be_ascii + can_be_utf8 + can_be_utf16 + can_be_utf32); - sz_unused(text && length); - return sz_false_k; -} - #pragma clang attribute pop #pragma GCC pop_options #endif // SZ_USE_ICE diff --git a/include/stringzilla/similarity.h b/include/stringzilla/similarity.h index 058b1313..6d65fcbe 100644 --- a/include/stringzilla/similarity.h +++ b/include/stringzilla/similarity.h @@ -309,7 +309,7 @@ SZ_INTERNAL sz_status_t _sz_levenshtein_distance_skewed_diagonals_serial( // } // TODO: Generalize to remove the following asserts! 
- _sz_assert(!bound && "For bounded search the method should only evaluate one band of the matrix."); + _sz_assert(bound >= longer_length && "For bounded search the method should only evaluate one band of the matrix."); _sz_assert(shorter_length == longer_length && "The method hasn't been generalized to different length inputs yet."); sz_unused(longer_length && bound); @@ -860,7 +860,7 @@ SZ_INTERNAL sz_size_t _sz_levenshtein_distance_skewed_diagonals_upto63_ice( // // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound. __mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm); - if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) return longer_length + 1; + if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) return bound; } // Now let's handle the anti-diagonal band of the matrix, between the top and bottom triangles. @@ -891,7 +891,7 @@ SZ_INTERNAL sz_size_t _sz_levenshtein_distance_skewed_diagonals_upto63_ice( // // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound. __mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm); - if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) return longer_length + 1; + if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) return bound; } // Now let's handle the bottom right triangle. @@ -915,7 +915,7 @@ SZ_INTERNAL sz_size_t _sz_levenshtein_distance_skewed_diagonals_upto63_ice( // // Check if we can exit early - if none of the diagonals values are smaller than the upper distance bound. __mmask64 within_bound_mask = _mm512_cmple_epu8_mask(next_vec.zmm, bound_vec.zmm); - if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) return longer_length + 1; + if (_ktestz_mask64_u8(within_bound_mask, next_diagonal_mask) == 1) return bound; // In every following iterations we take use a shorter prefix of each register, // but we don't need to update the `next_diagonal_mask` anymore... except for the early exit. diff --git a/include/stringzilla/stringzilla.hpp b/include/stringzilla/stringzilla.hpp index a8c077d6..6b0d598a 100644 --- a/include/stringzilla/stringzilla.hpp +++ b/include/stringzilla/stringzilla.hpp @@ -652,9 +652,10 @@ class range_rmatches { iterator(string_type haystack, matcher_type matcher) noexcept : matcher_(matcher), remaining_(haystack) { auto position = matcher_(remaining_); - remaining_.remove_suffix(position != string_type::npos - ? remaining_.size() - position - matcher_.needle_length() - : remaining_.size()); + remaining_.remove_suffix( // + position != string_type::npos // + ? remaining_.size() - position - matcher_.needle_length() + : remaining_.size()); } pointer operator->() const noexcept = delete; @@ -665,9 +666,10 @@ class range_rmatches { iterator &operator++() noexcept { remaining_.remove_suffix(matcher_.skip_length()); auto position = matcher_(remaining_); - remaining_.remove_suffix(position != string_type::npos - ? remaining_.size() - position - matcher_.needle_length() - : remaining_.size()); + remaining_.remove_suffix( // + position != string_type::npos // + ? 
remaining_.size() - position - matcher_.needle_length() + : remaining_.size()); return *this; } @@ -1100,12 +1102,10 @@ static void _call_free(void *ptr, sz_size_t n, void *allocator_state) noexcept { return reinterpret_cast(allocator_state)->deallocate(reinterpret_cast(ptr), n); } -template -static sz_u64_t _call_random_generator(void *state) noexcept { - generator_type_ &generator = *reinterpret_cast(state); - return generator(); -} - +/** + * @brief Helper function, wrapping a C++ allocator into a C-style allocator. + * @return Error code or success. All allocating functions may fail. + */ template static status_t _with_alloc(allocator_type_ &allocator, allocator_callback_ &&callback) noexcept { sz_memory_allocator_t alloc; @@ -2018,7 +2018,7 @@ class basic_string_slice { #pragma endregion /** - * @brief Memory-owning string class with a Small String Optimization. + * @brief Memory-owning string class with a Small String Optimization. * * @section API * @@ -2873,7 +2873,7 @@ class basic_string { } /** - * @brief Replaces @b (in-place) a range of characters with a given string. + * @brief Replaces @b (in-place) a range of characters with a given string. * @return `true` if the replacement was successful, `false` otherwise. */ bool try_replace(difference_type signed_start_offset, difference_type signed_end_offset, @@ -2929,9 +2929,9 @@ class basic_string { basic_string &insert(size_type offset, size_type repeats, char_type character) noexcept(false) { if (offset > size()) throw std::out_of_range("sz::basic_string::insert"); if (size() + repeats > max_size()) throw std::length_error("sz::basic_string::insert"); - if (!_with_alloc([&](sz_alloc_type &alloc) { return sz_string_expand(&string_, offset, repeats, &alloc); })) - throw std::bad_alloc(); - + raise(_with_alloc([&](sz_alloc_type &alloc) { + return sz_string_expand(&string_, offset, repeats, &alloc) ? sz_success_k : sz_bad_alloc_k; + })); sz_fill(data() + offset, repeats, character); return *this; } @@ -2974,10 +2974,10 @@ class basic_string { } /** - * @brief Inserts @b (in-place) one ::character at the given iterator position. - * @throw `std::out_of_range` if `pos > size()` or `other_index > other.size()`. - * @throw `std::length_error` if the string is too long. - * @throw `std::bad_alloc` if the allocation fails. + * @brief Inserts @b (in-place) one ::character at the given iterator position. + * @throw `std::out_of_range` if `pos > size()` or `other_index > other.size()`. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. */ iterator insert(const_iterator it, char_type character) noexcept(false) { auto pos = range_length(cbegin(), it); @@ -3375,11 +3375,10 @@ class basic_string { size_type bytesum() const noexcept { return view().bytesum(); } /** - * @brief Overwrites the string with random binary data. - * + * @brief Overwrites the string with random binary data. * @param[in] nonce "Number used ONCE" to initialize the random number generator, @b don't repeat it! */ - basic_string &randomize(sz_u64_t nonce) noexcept { + basic_string &fill_random(sz_u64_t nonce) noexcept { sz_ptr_t start; sz_size_t length; sz_string_range(&string_, &start, &length); @@ -3755,8 +3754,8 @@ bool basic_string::try_preparing_replacement( // } /** - * @brief Helper function-like object to order string-view convertible objects with StringZilla. 
- * @see Similar to `std::less`: https://en.cppreference.com/w/cpp/utility/functional/less + * @brief Helper function-like object to order string-view convertible objects with StringZilla. + * @see Similar to `std::less`: https://en.cppreference.com/w/cpp/utility/functional/less * * Unlike the STL analog, doesn't require C++14 or including the heavy `` header. * Can be used to combine STL classes with StringZilla logic, like: `std::map`. diff --git a/rust/lib.rs b/rust/lib.rs index b3cbca4e..3d32227a 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -23,6 +23,24 @@ pub mod sz { bits: [u64; 4], } + pub type SortedIdx = usize; + + #[repr(C)] + pub struct Sequence { + pub handle: *const c_void, + pub count: usize, + pub get_start: Option *const c_void>, + pub get_length: Option usize>, + } + + /// A simple semantic version structure. + #[derive(Debug, Copy, Clone, PartialEq, Eq)] + pub struct SemVer { + pub major: i32, + pub minor: i32, + pub patch: i32, + } + impl Byteset { /// Initializes a bit‑set to an empty collection (all characters banned). #[inline] @@ -102,7 +120,17 @@ pub mod sz { fn sz_fill_random(text: *mut c_void, length: usize, seed: u64); - // fn sz_sort() -> Status; + pub fn sz_sequence_argsort(sequence: *const Sequence, alloc: *const c_void, order: *mut SortedIdx) -> Status; + + pub fn sz_sequence_intersect( + first_sequence: *const Sequence, + second_sequence: *const Sequence, + alloc: *const c_void, + seed: u64, + intersection_size: *mut usize, + first_positions: *mut SortedIdx, + second_positions: *mut SortedIdx, + ) -> Status; pub fn sz_levenshtein_distance( a: *const c_void, @@ -162,14 +190,6 @@ pub mod sz { fn sz_lookup(target: *const c_void, length: usize, source: *const c_void, lut: *const u8) -> *const c_void; } - /// A simple semantic version structure. - #[derive(Debug, Copy, Clone, PartialEq, Eq)] - pub struct SemVer { - pub major: i32, - pub minor: i32, - pub patch: i32, - } - impl SemVer { pub const fn new(major: i32, minor: i32, patch: i32) -> Self { Self { major, minor, patch } From d19e8b83754b172dc6d57b639b0d8b5feb5b230f Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 9 Mar 2025 09:56:05 +0400 Subject: [PATCH 162/751] Fix: `find_1byte` signature compatibility Without this patch Clang raises "converts to incompatible function type" due to the following flag: `-Wcast-function-type-mismatch`. --- include/stringzilla/find.h | 53 ++++++++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/include/stringzilla/find.h b/include/stringzilla/find.h index d3db653e..1cf99e3b 100644 --- a/include/stringzilla/find.h +++ b/include/stringzilla/find.h @@ -401,14 +401,25 @@ SZ_INTERNAL sz_u64_vec_t _sz_u64_each_2byte_equal(sz_u64_vec_t a, sz_u64_vec_t b return vec; } +SZ_INTERNAL sz_cptr_t _sz_find_1byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { + sz_unused(n_length); //? We keep this argument only for `sz_find_t` signature compatibility. + return sz_find_byte_serial(h, h_length, n); +} + +SZ_INTERNAL sz_cptr_t _sz_rfind_1byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { + sz_unused(n_length); //? We keep this argument only for `sz_rfind_t` signature compatibility. + return sz_rfind_byte_serial(h, h_length, n); +} + /** * @brief Find the first occurrence of a @b two-character needle in an arbitrary length haystack. * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. 
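 *        (A sketch of one such scheme, not necessarily the exact code below: the two-byte needle is
 *        broadcast across a 64-bit word, and a pair-equality primitive like `_sz_u64_each_2byte_equal`
 *        marks matching 2-byte slots in one haystack word for the even offsets and in a shifted word
 *        for the odd ones.)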
*/ -SZ_INTERNAL sz_cptr_t _sz_find_2byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { +SZ_INTERNAL sz_cptr_t _sz_find_2byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { // This is an internal method, and the haystack is guaranteed to be at least 2 bytes long. _sz_assert(h_length >= 2 && "The haystack is too short."); + sz_unused(n_length); //? We keep this argument only for `sz_find_t` signature compatibility. sz_cptr_t const h_end = h + h_length; #if !SZ_USE_MISALIGNED_LOADS @@ -459,10 +470,11 @@ SZ_INTERNAL sz_u64_vec_t _sz_u64_each_4byte_equal(sz_u64_vec_t a, sz_u64_vec_t b * @brief Find the first occurrence of a @b four-character needle in an arbitrary length haystack. * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. */ -SZ_INTERNAL sz_cptr_t _sz_find_4byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { +SZ_INTERNAL sz_cptr_t _sz_find_4byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { // This is an internal method, and the haystack is guaranteed to be at least 4 bytes long. _sz_assert(h_length >= 4 && "The haystack is too short."); + sz_unused(n_length); //? We keep this argument only for `sz_find_t` signature compatibility. sz_cptr_t const h_end = h + h_length; #if !SZ_USE_MISALIGNED_LOADS @@ -523,10 +535,11 @@ SZ_INTERNAL sz_u64_vec_t _sz_u64_each_3byte_equal(sz_u64_vec_t a, sz_u64_vec_t b * @brief Find the first occurrence of a @b three-character needle in an arbitrary length haystack. * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. */ -SZ_INTERNAL sz_cptr_t _sz_find_3byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { +SZ_INTERNAL sz_cptr_t _sz_find_3byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { // This is an internal method, and the haystack is guaranteed to be at least 4 bytes long. _sz_assert(h_length >= 3 && "The haystack is too short."); + sz_unused(n_length); //? We keep this argument only for `sz_find_t` signature compatibility. sz_cptr_t const h_end = h + h_length; #if !SZ_USE_MISALIGNED_LOADS @@ -753,24 +766,24 @@ SZ_PUBLIC sz_cptr_t sz_find_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, #if _SZ_IS_BIG_ENDIAN sz_find_t backends[] = { - (sz_find_t)sz_find_byte_serial, - (sz_find_t)_sz_find_horspool_upto_256bytes_serial, - (sz_find_t)_sz_find_horspool_over_256bytes_serial, + _sz_find_1byte_serial, + _sz_find_horspool_upto_256bytes_serial, + _sz_find_horspool_over_256bytes_serial, }; return backends[(n_length > 1) + (n_length > 256)](h, h_length, n, n_length); #else sz_find_t backends[] = { // For very short strings brute-force SWAR makes sense. - (sz_find_t)sz_find_byte_serial, - (sz_find_t)_sz_find_2byte_serial, - (sz_find_t)_sz_find_3byte_serial, - (sz_find_t)_sz_find_4byte_serial, + _sz_find_1byte_serial, + _sz_find_2byte_serial, + _sz_find_3byte_serial, + _sz_find_4byte_serial, // To avoid constructing the skip-table, let's use the prefixed approach. - (sz_find_t)_sz_find_over_4bytes_serial, + _sz_find_over_4bytes_serial, // For longer needles - use skip tables. 
- (sz_find_t)_sz_find_horspool_upto_256bytes_serial, - (sz_find_t)_sz_find_horspool_over_256bytes_serial, + _sz_find_horspool_upto_256bytes_serial, + _sz_find_horspool_over_256bytes_serial, }; return backends[ @@ -790,16 +803,16 @@ SZ_PUBLIC sz_cptr_t sz_rfind_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n sz_find_t backends[] = { // For very short strings brute-force SWAR makes sense. - (sz_find_t)sz_rfind_byte_serial, + _sz_rfind_1byte_serial, // TODO: implement reverse-order SWAR for 2/3/4 byte variants. - // TODO: (sz_find_t)_sz_rfind_2byte_serial, - // TODO: (sz_find_t)_sz_rfind_3byte_serial, - // TODO: (sz_find_t)_sz_rfind_4byte_serial, + // TODO: _sz_rfind_2byte_serial, + // TODO: _sz_rfind_3byte_serial, + // TODO: _sz_rfind_4byte_serial, // To avoid constructing the skip-table, let's use the prefixed approach. - // (sz_find_t)_sz_rfind_over_4bytes_serial, + // _sz_rfind_over_4bytes_serial, // For longer needles - use skip tables. - (sz_find_t)_sz_rfind_horspool_upto_256bytes_serial, - (sz_find_t)_sz_rfind_horspool_over_256bytes_serial, + _sz_rfind_horspool_upto_256bytes_serial, + _sz_rfind_horspool_over_256bytes_serial, }; return backends[ From 90540d3a29406874fa049438dc1ed8559bd80eb9 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 9 Mar 2025 10:02:47 +0400 Subject: [PATCH 163/751] Fix: Unused Levenshtein tests --- scripts/test.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scripts/test.cpp b/scripts/test.cpp index aefcb638..ebcc01fe 100644 --- a/scripts/test.cpp +++ b/scripts/test.cpp @@ -1970,6 +1970,8 @@ int main(int argc, char const **argv) { test_search_with_misaligned_repetitions(); #endif + test_levenshtein_distances(); + std::printf("All tests passed... Unbelievable!\n"); return 0; } From feb415f74bf86f75f1841f45b88d3c681ae569c1 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 9 Mar 2025 10:03:19 +0400 Subject: [PATCH 164/751] Fix: Variable in C++14 `constexpr` --- include/stringzilla/stringzilla.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/stringzilla/stringzilla.hpp b/include/stringzilla/stringzilla.hpp index 6b0d598a..ab99bc5f 100644 --- a/include/stringzilla/stringzilla.hpp +++ b/include/stringzilla/stringzilla.hpp @@ -341,8 +341,8 @@ class basic_byteset { return *this; } - constexpr basic_byteset operator|(basic_byteset other) const noexcept { - basic_byteset result = *this; + sz_constexpr_if_cpp14 basic_byteset operator|(basic_byteset other) const noexcept { + basic_byteset result = *this; //? 
Variable declaration in a `constexpr` function is a C++14 extension result.bitset_._u64s[0] |= other.bitset_._u64s[0], result.bitset_._u64s[1] |= other.bitset_._u64s[1], result.bitset_._u64s[2] |= other.bitset_._u64s[2], result.bitset_._u64s[3] |= other.bitset_._u64s[3]; return result; From a7b35bad578482b405dc0307ecd62350717c1c9e Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 9 Mar 2025 10:45:22 +0400 Subject: [PATCH 165/751] Make: Don't build `stringzilla_bare` on MacOS --- CMakeLists.txt | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 15145843..1da1e36f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -369,21 +369,24 @@ if(${STRINGZILLA_BUILD_SHARED}) target_compile_definitions(stringzilla_shared PRIVATE "SZ_OVERRIDE_LIBC=1") target_include_directories(stringzilla_shared PUBLIC include) - - # Try compiling a version without linking the LibC - define_shared(stringzilla_bare) - target_compile_definitions(stringzilla_bare PRIVATE "SZ_AVOID_LIBC=1") - target_compile_definitions(stringzilla_bare PRIVATE "SZ_OVERRIDE_LIBC=1") - target_include_directories(stringzilla_bare PUBLIC include) - - # Avoid built-ins on MSVC and other compilers, as that will cause compilation errors - target_compile_options(stringzilla_bare PRIVATE - "$<$:-fno-builtin;-nostdlib>" - "$<$:/Oi-;/GS->") - target_link_options(stringzilla_bare PRIVATE "$<$:-nostdlib>") - target_link_options(stringzilla_bare PRIVATE "$<$:/NODEFAULTLIB>") - + # Try compiling a version without linking the LibC + # ! This is only for Linux and Windows, as on modern Arm-based MacOS machines + # ! we can't legally access Arm's "feature registers" without `sysctl` or `sysctlbyname`. + # So let's check if we are compiling for a Darwin-based OS. 
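+    # (`sysctlbyname` itself is provided by macOS's `libSystem`, so a `-nostdlib` build would have
+    # nothing to call - hence the exclusion below.)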
+ if(NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin") + define_shared(stringzilla_bare) + target_compile_definitions(stringzilla_bare PRIVATE "SZ_AVOID_LIBC=1") + target_compile_definitions(stringzilla_bare PRIVATE "SZ_OVERRIDE_LIBC=1") + target_include_directories(stringzilla_bare PUBLIC include) + + # Avoid built-ins on MSVC and other compilers, as that will cause compilation errors + target_compile_options(stringzilla_bare PRIVATE + "$<$:-fno-builtin;-nostdlib>" + "$<$:/Oi-;/GS->") + target_link_options(stringzilla_bare PRIVATE "$<$:-nostdlib>") + target_link_options(stringzilla_bare PRIVATE "$<$:/NODEFAULTLIB>") + endif() endif() if(STRINGZILLA_INSTALL) From f712de33b4c6fcadac27bf4b59feb7b01234f234 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 9 Mar 2025 18:38:15 +0000 Subject: [PATCH 166/751] Fix: `sz_intersect` signature --- c/lib.c | 18 +++++++++--------- include/stringzilla/types.h | 11 ++++++----- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/c/lib.c b/c/lib.c index 9c6324dd..7f1219ca 100644 --- a/c/lib.c +++ b/c/lib.c @@ -194,7 +194,7 @@ typedef struct sz_implementations_t { sz_needleman_wunsch_score_t alignment_score; sz_sequence_argsort_t sequence_argsort; - sz_sequence_join_t sequence_join; + sz_sequence_intersect_t sequence_intersect; sz_pgrams_sort_t pgrams_sort; } sz_implementations_t; @@ -239,7 +239,7 @@ SZ_DYNAMIC void sz_dispatch_table_init(void) { impl->alignment_score = sz_needleman_wunsch_score_serial; impl->sequence_argsort = sz_sequence_argsort_serial; - impl->sequence_join = sz_sequence_join_serial; + impl->sequence_intersect = sz_sequence_intersect_serial; impl->pgrams_sort = sz_pgrams_sort_serial; #if SZ_USE_HASWELL @@ -291,7 +291,7 @@ SZ_DYNAMIC void sz_dispatch_table_init(void) { impl->bytesum = sz_bytesum_skylake; impl->sequence_argsort = sz_sequence_argsort_skylake; - impl->sequence_join = sz_sequence_join_skylake; + impl->sequence_intersect = sz_sequence_intersect_skylake; impl->pgrams_sort = sz_pgrams_sort_skylake; } #endif @@ -343,7 +343,7 @@ SZ_DYNAMIC void sz_dispatch_table_init(void) { #if SZ_USE_SVE if (caps & sz_cap_sve_k) { impl->sequence_argsort = sz_sequence_argsort_sve; - impl->sequence_join = sz_sequence_join_sve; + impl->sequence_intersect = sz_sequence_intersect_sve; impl->pgrams_sort = sz_pgrams_sort_sve; } #endif @@ -507,11 +507,11 @@ SZ_DYNAMIC sz_status_t sz_sequence_argsort(sz_sequence_t const *array, sz_memory return sz_dispatch_table.sequence_argsort(array, alloc, order); } -SZ_DYNAMIC sz_status_t sz_sequence_join(sz_sequence_t const *first_array, sz_sequence_t const *second_array, - sz_memory_allocator_t *alloc, sz_size_t *intersection_size, - sz_size_t *first_positions, sz_size_t *second_positions) { - return sz_dispatch_table.sequence_join(first_array, second_array, alloc, intersection_size, first_positions, - second_positions); +SZ_DYNAMIC sz_status_t sz_sequence_intersect(sz_sequence_t const *first_array, sz_sequence_t const *second_array, + sz_memory_allocator_t *alloc, sz_u64_t seed, sz_size_t *intersection_size, + sz_size_t *first_positions, sz_size_t *second_positions) { + return sz_dispatch_table.sequence_intersect(first_array, second_array, alloc, seed, intersection_size, + first_positions, second_positions); } // Provide overrides for the libc mem* functions diff --git a/include/stringzilla/types.h b/include/stringzilla/types.h index 6d693347..11366304 100644 --- a/include/stringzilla/types.h +++ b/include/stringzilla/types.h @@ -565,9 +565,10 @@ typedef 
sz_status_t (*sz_sequence_argsort_t)(struct sz_sequence_t const *, sz_me /** @brief Signature of `sz_pgrams_sort`. */ typedef sz_status_t (*sz_pgrams_sort_t)(sz_pgram_t *, sz_size_t, sz_memory_allocator_t *, sz_sorted_idx_t *); -/** @brief Signature of `sz_sequence_join`. */ -typedef sz_status_t (*sz_sequence_join_t)(struct sz_sequence_t const *, struct sz_sequence_t const *, - sz_memory_allocator_t *, sz_size_t *, sz_sorted_idx_t *, sz_sorted_idx_t *); +/** @brief Signature of `sz_sequence_intersect`. */ +typedef sz_status_t (*sz_sequence_intersect_t)(struct sz_sequence_t const *, struct sz_sequence_t const *, + sz_memory_allocator_t *, sz_u64_t, sz_size_t *, sz_sorted_idx_t *, + sz_sorted_idx_t *); #pragma endregion @@ -726,9 +727,9 @@ SZ_INTERNAL sz_size_t _sz_export_utf8_to_utf32(sz_cptr_t utf8, sz_size_t utf8_le #pragma region String Sequences API /** @brief Signature of `sz_sequence_t::get_start` used to get the start of a member string at a given index. */ -typedef sz_cptr_t (*sz_sequence_member_start_t)(void const *, sz_size_t); +typedef sz_cptr_t (*sz_sequence_member_start_t)(void const *, sz_sorted_idx_t); /** @brief Signature of `sz_sequence_t::get_length` used to get the length of a member string at a given index. */ -typedef sz_size_t (*sz_sequence_member_length_t)(void const *, sz_size_t); +typedef sz_size_t (*sz_sequence_member_length_t)(void const *, sz_sorted_idx_t); /** * @brief Structure to represent an ordered collection of strings. From 8bb90e54210bf5d1e3e626de12ff6086c3ad358e Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 9 Mar 2025 18:42:04 +0000 Subject: [PATCH 167/751] Fix: Unused `_sz_capabilities` symbols --- c/lib.c | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/c/lib.c b/c/lib.c index 7f1219ca..7aea9455 100644 --- a/c/lib.c +++ b/c/lib.c @@ -51,6 +51,8 @@ extern void *malloc(size_t length); #include // `DllMain` #endif +#if _SZ_IS_ARM64 + /** * @brief Function to determine the SIMD capabilities of the current 64-bit Arm machine at @b runtime. * @return A bitmask of the SIMD capabilities represented as a `sz_capability_t` enum value. @@ -65,8 +67,8 @@ SZ_INTERNAL sz_capability_t _sz_capabilities_arm(void) { size_t size = sizeof(supports_neon); if (sysctlbyname("hw.optional.neon", &supports_neon, &size, NULL, 0) != 0) supports_neon = 0; - return (sz_capability_t)( // - (sz_cap_arm_neon_k * (supports_neon)) | // + return (sz_capability_t)( // + (sz_cap_neon_k * (supports_neon)) | // (sz_cap_serial_k)); #elif defined(_SZ_IS_LINUX) @@ -107,6 +109,10 @@ SZ_INTERNAL sz_capability_t _sz_capabilities_arm(void) { #endif } +#endif // _SZ_IS_ARM64 + +#if _SZ_IS_X86_64 + SZ_INTERNAL sz_capability_t _sz_capabilities_x86(void) { #if SZ_USE_HASWELL || SZ_USE_SKYLAKE || SZ_USE_ICE @@ -152,6 +158,7 @@ SZ_INTERNAL sz_capability_t _sz_capabilities_x86(void) { return sz_cap_serial_k; #endif } +#endif // _SZ_IS_X86_64 /** * @brief Function to determine the SIMD capabilities of the current 64-bit x86 machine at @b runtime. 
From 63f0368dc869c0d977c308a9015e03211f5a4866 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 9 Mar 2025 18:44:14 +0000 Subject: [PATCH 168/751] Add: Missing SVE placeholder definition Only declarations were present for SVE --- include/stringzilla/hash.h | 18 ++++++++++++++++++ include/stringzilla/intersect.h | 22 ++++++++++++++++++++++ include/stringzilla/sort.h | 23 +++++++++++++++++++++++ 3 files changed, 63 insertions(+) diff --git a/include/stringzilla/hash.h b/include/stringzilla/hash.h index 1db8a4b3..aedbff89 100644 --- a/include/stringzilla/hash.h +++ b/include/stringzilla/hash.h @@ -1833,6 +1833,24 @@ SZ_PUBLIC sz_u64_t sz_hash_state_fold_neon(sz_hash_state_t const *state) { retur #pragma GCC target("arch=armv8.2-a+sve") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve"))), apply_to = function) +SZ_PUBLIC sz_u64_t sz_bytesum_sve(sz_cptr_t text, sz_size_t length) { return sz_bytesum_serial(text, length); } + +SZ_PUBLIC void sz_hash_state_init_sve(sz_hash_state_t *state, sz_u64_t seed) { sz_hash_state_init_serial(state, seed); } + +SZ_PUBLIC void sz_hash_state_stream_sve(sz_hash_state_t *state, sz_cptr_t text, sz_size_t length) { + sz_hash_state_stream_serial(state, text, length); +} + +SZ_PUBLIC sz_u64_t sz_hash_state_fold_sve(sz_hash_state_t const *state) { return sz_hash_state_fold_serial(state); } + +SZ_PUBLIC sz_u64_t sz_hash_sve(sz_cptr_t text, sz_size_t length, sz_u64_t seed) { + return sz_hash_serial(text, length, seed); +} + +SZ_PUBLIC void sz_fill_random_sve(sz_ptr_t text, sz_size_t length, sz_u64_t nonce) { + sz_fill_random_serial(text, length, nonce); +} + #pragma clang attribute pop #pragma GCC pop_options #endif // SZ_USE_SVE diff --git a/include/stringzilla/intersect.h b/include/stringzilla/intersect.h index 77033148..b3610969 100644 --- a/include/stringzilla/intersect.h +++ b/include/stringzilla/intersect.h @@ -713,6 +713,28 @@ SZ_PUBLIC sz_status_t sz_sequence_intersect_ice( #endif // SZ_USE_ICE #pragma endregion // Ice Lake Implementation +#pragma region SVE Implementation +#if SZ_USE_SVE +#pragma GCC push_options +#pragma GCC target("arch=armv8.2-a+sve") +#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve"))), apply_to = function) + +SZ_PUBLIC sz_status_t sz_sequence_intersect_sve(sz_sequence_t const *first_sequence, + sz_sequence_t const *second_sequence, // + sz_memory_allocator_t *alloc, sz_u64_t seed, + sz_size_t *intersection_size, sz_sorted_idx_t *first_positions, + sz_sorted_idx_t *second_positions) { + return sz_sequence_intersect_serial( // + first_sequence, second_sequence, // + alloc, seed, intersection_size, // + first_positions, second_positions); +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SZ_USE_SVE +#pragma endregion // SVE Implementation + /* Pick the right implementation for the string search algorithms. * To override this behavior and precompile all backends - set `SZ_DYNAMIC_DISPATCH` to 1. 
*/ diff --git a/include/stringzilla/sort.h b/include/stringzilla/sort.h index af808cb5..b4d487bd 100644 --- a/include/stringzilla/sort.h +++ b/include/stringzilla/sort.h @@ -925,6 +925,29 @@ SZ_PUBLIC sz_status_t sz_sequence_argsort_skylake(sz_sequence_t const *sequence, #endif // SZ_USE_SKYLAKE #pragma endregion // Ice Lake Implementation +#pragma region SVE Implementation +#if SZ_USE_SVE +#pragma GCC push_options +#pragma GCC target("arch=armv8.2-a+sve") +#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve"))), apply_to = function) + +/** @copydoc sz_sequence_argsort */ +SZ_PUBLIC sz_status_t sz_sequence_argsort_sve(sz_sequence_t const *sequence, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order) { + return sz_sequence_argsort_serial(sequence, alloc, order); +} + +/** @copydoc sz_pgrams_sort */ +SZ_PUBLIC sz_status_t sz_pgrams_sort_sve(sz_pgram_t *pgrams, sz_size_t count, sz_memory_allocator_t *alloc, + sz_sorted_idx_t *order) { + return sz_pgrams_sort_serial(pgrams, count, alloc, order); +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SZ_USE_SVE +#pragma endregion // SVE Implementation + /* Pick the right implementation for the string search algorithms. * To override this behavior and precompile all backends - set `SZ_DYNAMIC_DISPATCH` to 1. */ From 2965502c82a7e96c3252b745e029f713f8c2b1f9 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 9 Mar 2025 18:44:58 +0000 Subject: [PATCH 169/751] Fix: Guard Skylake benchmarks --- scripts/bench_sort.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/scripts/bench_sort.cpp b/scripts/bench_sort.cpp index f32d9909..8cdf97fd 100644 --- a/scripts/bench_sort.cpp +++ b/scripts/bench_sort.cpp @@ -112,6 +112,7 @@ int main(int argc, char const **argv) { }); expect_sorted(pgrams, permute); +#if SZ_USE_SKYLAKE bench_permute("sz_pgrams_sort_skylake", [&]() { std::copy(pgrams.begin(), pgrams.end(), pgrams_sorted.begin()); std::iota(permute.begin(), permute.end(), 0); @@ -120,6 +121,7 @@ int main(int argc, char const **argv) { }); }); expect_sorted(pgrams, permute); +#endif // Sorting strings bench_permute("std::sort(positions)", [&]() { @@ -140,7 +142,7 @@ int main(int argc, char const **argv) { [&](sz_memory_allocator_t &alloc) { return sz_sequence_argsort_serial(&array, &alloc, permute.data()); }); }); expect_sorted(strings, permute); - +#if SZ_USE_SKYLAKE bench_permute("sz_sequence_argsort_skylake", [&]() { std::iota(permute.begin(), permute.end(), 0); sz_sequence_t array; @@ -152,6 +154,7 @@ int main(int argc, char const **argv) { [&](sz_memory_allocator_t &alloc) { return sz_sequence_argsort_skylake(&array, &alloc, permute.data()); }); }); expect_sorted(strings, permute); +#endif #if __linux__ && defined(_GNU_SOURCE) && !defined(__BIONIC__) bench_permute("qsort_r", [&]() { From 4b3847df38feef40e8ad4bf53eab3dc5c6f1c83d Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 9 Mar 2025 19:05:26 +0000 Subject: [PATCH 170/751] Add: Arm NEON hashing --- include/stringzilla/hash.h | 379 ++++++++++++++++++++++++++++++++++-- include/stringzilla/types.h | 12 ++ 2 files changed, 374 insertions(+), 17 deletions(-) diff --git a/include/stringzilla/hash.h b/include/stringzilla/hash.h index aedbff89..eb748ef4 100644 --- a/include/stringzilla/hash.h +++ b/include/stringzilla/hash.h @@ -61,8 +61,6 @@ * @see The serial AES routines are based on Morten Jensen's "tiny-AES-c": 
https://github.com/kokke/tiny-AES-c * @see The "xxHash" C implementation by Yann Collet: https://github.com/Cyan4973/xxHash * @see The "aHash" Rust implementation by Tom Kaitchuck: https://github.com/tkaitchuck/aHash - * @see "Emulating x86 AES Intrinsics on ARMv8-A" by Michael Brase: - * https://blog.michaelbrase.com/2018/05/08/emulating-x86-aes-intrinsics-on-armv8-a/ * * Moreover, the same AES primitives are reused to implement a fast Pseudo-Random Number Generator @b (PRNG) that * is consistent between different implementation backends and has reproducible output with the same "nonce". @@ -877,7 +875,7 @@ SZ_PUBLIC void sz_hash_state_init_haswell(sz_hash_state_t *state, sz_u64_t seed) for (int i = 0; i < 4; ++i) state->aes.xmms[i] = _mm_xor_si128(seed_vec, _mm_load_si128((__m128i const *)(pi + i * 2))); for (int i = 0; i < 4; ++i) - state->sum.xmms[i] = _mm_xor_si128(seed_vec, _mm_load_si128((__m128i const *)(pi + i * 2 + 8))); + state->sum.u64x2s[i] = _mm_xor_si128(seed_vec, _mm_load_si128((__m128i const *)(pi + i * 2 + 8))); // The inputs are zeroed out at the beginning state->ins.xmms[0] = state->ins.xmms[1] = state->ins.xmms[2] = state->ins.xmms[3] = _mm_setzero_si128(); @@ -887,23 +885,23 @@ SZ_PUBLIC void sz_hash_state_init_haswell(sz_hash_state_t *state, sz_u64_t seed) SZ_INTERNAL void _sz_hash_state_update_haswell(sz_hash_state_t *state) { __m128i const shuffle_mask = _mm_load_si128((__m128i const *)_sz_hash_u8x16x4_shuffle()); state->aes.xmms[0] = _mm_aesenc_si128(state->aes.xmms[0], state->ins.xmms[0]); - state->sum.xmms[0] = _mm_add_epi64(_mm_shuffle_epi8(state->sum.xmms[0], shuffle_mask), state->ins.xmms[0]); + state->sum.u64x2s[0] = _mm_add_epi64(_mm_shuffle_epi8(state->sum.u64x2s[0], shuffle_mask), state->ins.xmms[0]); state->aes.xmms[1] = _mm_aesenc_si128(state->aes.xmms[1], state->ins.xmms[1]); - state->sum.xmms[1] = _mm_add_epi64(_mm_shuffle_epi8(state->sum.xmms[1], shuffle_mask), state->ins.xmms[1]); + state->sum.u64x2s[1] = _mm_add_epi64(_mm_shuffle_epi8(state->sum.u64x2s[1], shuffle_mask), state->ins.xmms[1]); state->aes.xmms[2] = _mm_aesenc_si128(state->aes.xmms[2], state->ins.xmms[2]); - state->sum.xmms[2] = _mm_add_epi64(_mm_shuffle_epi8(state->sum.xmms[2], shuffle_mask), state->ins.xmms[2]); + state->sum.u64x2s[2] = _mm_add_epi64(_mm_shuffle_epi8(state->sum.u64x2s[2], shuffle_mask), state->ins.xmms[2]); state->aes.xmms[3] = _mm_aesenc_si128(state->aes.xmms[3], state->ins.xmms[3]); - state->sum.xmms[3] = _mm_add_epi64(_mm_shuffle_epi8(state->sum.xmms[3], shuffle_mask), state->ins.xmms[3]); + state->sum.u64x2s[3] = _mm_add_epi64(_mm_shuffle_epi8(state->sum.u64x2s[3], shuffle_mask), state->ins.xmms[3]); } SZ_INTERNAL sz_u64_t _sz_hash_state_finalize_haswell(sz_hash_state_t const *state) { // Mix the length into the key __m128i key_with_length = _mm_add_epi64(state->key.xmm, _mm_set_epi64x(0, state->ins_length)); // Combine the "sum" and the "AES" blocks - __m128i mixed_registers0 = _mm_aesenc_si128(state->sum.xmms[0], state->aes.xmms[0]); - __m128i mixed_registers1 = _mm_aesenc_si128(state->sum.xmms[1], state->aes.xmms[1]); - __m128i mixed_registers2 = _mm_aesenc_si128(state->sum.xmms[2], state->aes.xmms[2]); - __m128i mixed_registers3 = _mm_aesenc_si128(state->sum.xmms[3], state->aes.xmms[3]); + __m128i mixed_registers0 = _mm_aesenc_si128(state->sum.u64x2s[0], state->aes.xmms[0]); + __m128i mixed_registers1 = _mm_aesenc_si128(state->sum.u64x2s[1], state->aes.xmms[1]); + __m128i mixed_registers2 = _mm_aesenc_si128(state->sum.u64x2s[2], state->aes.xmms[2]); + __m128i 
mixed_registers3 = _mm_aesenc_si128(state->sum.u64x2s[3], state->aes.xmms[3]); // Combine the mixed registers __m128i mixed_registers01 = _mm_aesenc_si128(mixed_registers0, mixed_registers1); __m128i mixed_registers23 = _mm_aesenc_si128(mixed_registers2, mixed_registers3); @@ -1045,7 +1043,7 @@ SZ_PUBLIC sz_u64_t sz_hash_state_fold_haswell(sz_hash_state_t const *state) { _sz_hash_minimal_t minimal_state; minimal_state.key.xmm = state->key.xmm; minimal_state.aes.xmm = state->aes.xmms[0]; - minimal_state.sum.xmm = state->sum.xmms[0]; + minimal_state.sum.xmm = state->sum.u64x2s[0]; // The logic is different depending on the length of the input __m128i const *ins_vecs = (__m128i const *)&state->ins.xmms[0]; @@ -1788,8 +1786,8 @@ SZ_INTERNAL void _sz_hash_minimal_x4_update_ice(_sz_hash_minimal_x4_t *state, __ #pragma region NEON Implementation #if SZ_USE_NEON #pragma GCC push_options -#pragma GCC target("arch=armv8.2-a+simd") -#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd"))), apply_to = function) +#pragma GCC target("arch=armv8.2-a+simd+crypto") +#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+crypto"))), apply_to = function) SZ_PUBLIC sz_u64_t sz_bytesum_neon(sz_cptr_t text, sz_size_t length) { uint64x2_t sum_vec = vdupq_n_u64(0); @@ -1809,15 +1807,362 @@ SZ_PUBLIC sz_u64_t sz_bytesum_neon(sz_cptr_t text, sz_size_t length) { return sum; } +/** + * @brief Emulates the Intel's AES-NI `AESENC` instruction on Arm NEON. + * @see "Emulating x86 AES Intrinsics on ARMv8-A" by Michael Brase: + * https://blog.michaelbrase.com/2018/05/08/emulating-x86-aes-intrinsics-on-armv8-a/ + */ +SZ_INTERNAL uint8x16_t _sz_emulate_aesenc_u8x16_neon(uint8x16_t state_vec, uint8x16_t round_key_vec) { + return veorq_u8(vaesmcq_u8(vaeseq_u8(state_vec, vdupq_n_u8(0))), round_key_vec); +} + +SZ_INTERNAL uint64x2_t _sz_emulate_aesenc_u64x2_neon(uint64x2_t state_vec, uint64x2_t round_key_vec) { + return vreinterpretq_u64_u8( // + _sz_emulate_aesenc_u8x16_neon( // + vreinterpretq_u8_u64(state_vec), // + vreinterpretq_u8_u64(round_key_vec))); +} + +SZ_INTERNAL void _sz_hash_minimal_init_neon(_sz_hash_minimal_t *state, sz_u64_t seed) { + + // The key is made from the seed and half of it will be mixed with the length in the end + uint64x2_t seed_vec = vdupq_n_u64(seed); + state->key.u64x2 = seed_vec; + + // XOR the user-supplied keys with the two "pi" constants + sz_u64_t const *pi = _sz_hash_pi_constants(); + uint64x2_t const pi0 = vld1q_u64(pi); + uint64x2_t const pi1 = vld1q_u64(pi + 8); + uint64x2_t k1 = veorq_u64(seed_vec, pi0); + uint64x2_t k2 = veorq_u64(seed_vec, pi1); + + // The first 128 bits of the "sum" and "AES" blocks are the same for the "minimal" and full state + state->aes.u64x2 = k1; + state->sum.u64x2 = k2; +} + +SZ_INTERNAL sz_u64_t _sz_hash_minimal_finalize_neon(_sz_hash_minimal_t const *state, sz_size_t length) { + // Mix the length into the key + uint64x2_t key_with_length = vaddq_u64(state->key.u64x2, vsetq_lane_u64(length, vdupq_n_u64(0), 0)); + // Combine the "sum" and the "AES" blocks + uint8x16_t mixed_registers = _sz_emulate_aesenc_u8x16_neon(state->sum.u8x16, state->aes.u8x16); + // Make sure the "key" mixes enough with the state, + // as with less than 2 rounds - SMHasher fails + uint8x16_t mixed_within_register = _sz_emulate_aesenc_u8x16_neon( + _sz_emulate_aesenc_u8x16_neon(mixed_registers, vreinterpretq_u8_u64(key_with_length)), mixed_registers); + // Extract the low 64 bits + return vgetq_lane_u64(vreinterpretq_u64_u8(mixed_within_register), 0); 
+} + +SZ_INTERNAL void _sz_hash_minimal_update_neon(_sz_hash_minimal_t *state, uint8x16_t block) { + uint8x16_t const shuffle_mask = vld1q_u8(_sz_hash_u8x16x4_shuffle()); + state->aes.u8x16 = _sz_emulate_aesenc_u8x16_neon(state->aes.u8x16, block); + uint8x16_t sum_shuffled = vqtbl1q_u8(vreinterpretq_u8_u64(state->sum.u64x2), shuffle_mask); + state->sum.u64x2 = vaddq_u64(vreinterpretq_u64_u8(sum_shuffled), vreinterpretq_u64_u8(block)); +} + SZ_PUBLIC void sz_hash_state_init_neon(sz_hash_state_t *state, sz_u64_t seed) { - sz_hash_state_init_serial(state, seed); + // The key is made from the seed and half of it will be mixed with the length in the end + uint64x2_t seed_vec = vdupq_n_u64(seed); + state->key.u64x2 = seed_vec; + + // XOR the user-supplied keys with the two "pi" constants + sz_u64_t const *pi = _sz_hash_pi_constants(); + for (int i = 0; i < 4; ++i) state->aes.u64x2s[i] = veorq_u64(seed_vec, vld1q_u64(pi + i * 2)); + for (int i = 0; i < 4; ++i) state->sum.u64x2s[i] = veorq_u64(seed_vec, vld1q_u64(pi + i * 2 + 8)); + + // The inputs are zeroed out at the beginning + state->ins.u8x16s[0] = state->ins.u8x16s[1] = state->ins.u8x16s[2] = state->ins.u8x16s[3] = vdupq_n_u8(0); + state->ins_length = 0; +} + +SZ_INTERNAL void _sz_hash_state_update_neon(sz_hash_state_t *state) { + uint8x16_t const shuffle_mask = vld1q_u8(_sz_hash_u8x16x4_shuffle()); + state->aes.u8x16s[0] = _sz_emulate_aesenc_u8x16_neon(state->aes.u8x16s[0], state->ins.u8x16s[0]); + uint8x16_t sum_shuffled0 = vqtbl1q_u8(vreinterpretq_u8_u64(state->sum.u64x2s[0]), shuffle_mask); + state->sum.u64x2s[0] = vaddq_u64(vreinterpretq_u64_u8(sum_shuffled0), state->ins.u64x2s[0]); + state->aes.u8x16s[1] = _sz_emulate_aesenc_u8x16_neon(state->aes.u8x16s[1], state->ins.u8x16s[1]); + uint8x16_t sum_shuffled1 = vqtbl1q_u8(vreinterpretq_u8_u64(state->sum.u64x2s[1]), shuffle_mask); + state->sum.u64x2s[1] = vaddq_u64(vreinterpretq_u64_u8(sum_shuffled1), state->ins.u64x2s[1]); + state->aes.u8x16s[2] = _sz_emulate_aesenc_u8x16_neon(state->aes.u8x16s[2], state->ins.u8x16s[2]); + uint8x16_t sum_shuffled2 = vqtbl1q_u8(vreinterpretq_u8_u64(state->sum.u64x2s[2]), shuffle_mask); + state->sum.u64x2s[2] = vaddq_u64(vreinterpretq_u64_u8(sum_shuffled2), state->ins.u64x2s[2]); + state->aes.u8x16s[3] = _sz_emulate_aesenc_u8x16_neon(state->aes.u8x16s[3], state->ins.u8x16s[3]); + uint8x16_t sum_shuffled3 = vqtbl1q_u8(vreinterpretq_u8_u64(state->sum.u64x2s[3]), shuffle_mask); + state->sum.u64x2s[3] = vaddq_u64(vreinterpretq_u64_u8(sum_shuffled3), state->ins.u64x2s[3]); +} + +SZ_INTERNAL sz_u64_t _sz_hash_state_finalize_neon(sz_hash_state_t const *state) { + // Mix the length into the key + uint64x2_t key_with_length = vaddq_u64(state->key.u64x2, vsetq_lane_u64(state->ins_length, vdupq_n_u64(0), 0)); + // Combine the "sum" and the "AES" blocks + uint8x16_t mixed_registers0 = _sz_emulate_aesenc_u8x16_neon(state->sum.u8x16s[0], state->aes.u8x16s[0]); + uint8x16_t mixed_registers1 = _sz_emulate_aesenc_u8x16_neon(state->sum.u8x16s[1], state->aes.u8x16s[1]); + uint8x16_t mixed_registers2 = _sz_emulate_aesenc_u8x16_neon(state->sum.u8x16s[2], state->aes.u8x16s[2]); + uint8x16_t mixed_registers3 = _sz_emulate_aesenc_u8x16_neon(state->sum.u8x16s[3], state->aes.u8x16s[3]); + // Combine the mixed registers + uint8x16_t mixed_registers01 = _sz_emulate_aesenc_u8x16_neon(mixed_registers0, mixed_registers1); + uint8x16_t mixed_registers23 = _sz_emulate_aesenc_u8x16_neon(mixed_registers2, mixed_registers3); + uint8x16_t mixed_registers = 
_sz_emulate_aesenc_u8x16_neon(mixed_registers01, mixed_registers23); + // Make sure the "key" mixes enough with the state, + // as with less than 2 rounds - SMHasher fails + uint8x16_t mixed_within_register = _sz_emulate_aesenc_u8x16_neon( + _sz_emulate_aesenc_u8x16_neon(mixed_registers, vreinterpretq_u8_u64(key_with_length)), mixed_registers); + // Extract the low 64 bits + return vgetq_lane_u64(vreinterpretq_u64_u8(mixed_within_register), 0); } SZ_PUBLIC void sz_hash_state_stream_neon(sz_hash_state_t *state, sz_cptr_t text, sz_size_t length) { - sz_hash_state_stream_serial(state, text, length); + // This whole function is identical to Haswell. + while (length) { + // Append to the internal buffer until it's full + if (state->ins_length % 64 == 0 && length >= 64) { + state->ins.u8x16s[0] = vld1q_u8((sz_u8_t const *)text); + state->ins.u8x16s[1] = vld1q_u8((sz_u8_t const *)(text + 16)); + state->ins.u8x16s[2] = vld1q_u8((sz_u8_t const *)(text + 32)); + state->ins.u8x16s[3] = vld1q_u8((sz_u8_t const *)(text + 48)); + _sz_hash_state_update_neon(state); + state->ins_length += 64; + text += 64; + length -= 64; + } + // If vectorization isn't that trivial - fall back to the serial implementation + else { + sz_size_t progress_in_block = state->ins_length % 64; + sz_size_t to_copy = sz_min_of_two(length, 64 - progress_in_block); + int const will_fill_block = progress_in_block + to_copy == 64; + // Update the metadata before we modify the `to_copy` variable + state->ins_length += to_copy; + length -= to_copy; + // Append to the internal buffer until it's full + while (to_copy--) state->ins.u8s[progress_in_block++] = *text++; + // If we've reached the end of the buffer, update the state + if (will_fill_block) { + _sz_hash_state_update_neon(state); + // Reset to zeros now, so we don't have to overwrite an immutable buffer in the folding state + for (int i = 0; i < 4; ++i) state->ins.u8x16s[i] = vdupq_n_u8(0); + } + } + } } -SZ_PUBLIC sz_u64_t sz_hash_state_fold_neon(sz_hash_state_t const *state) { return sz_hash_state_fold_serial(state); } +SZ_PUBLIC sz_u64_t sz_hash_state_fold_neon(sz_hash_state_t const *state) { + // This whole function is identical to Haswell. 
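+    // Only the intrinsics differ: each x86 `AESENC` round is emulated with `vaeseq_u8` + `vaesmcq_u8`
+    // plus an XOR with the round key, see `_sz_emulate_aesenc_u8x16_neon` above.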
+ sz_size_t length = state->ins_length; + if (length >= 64) return _sz_hash_state_finalize_neon(state); + + // Switch back to a smaller "minimal" state for small inputs + _sz_hash_minimal_t minimal_state; + minimal_state.key.u8x16 = state->key.u8x16; + minimal_state.aes.u8x16 = state->aes.u8x16s[0]; + minimal_state.sum.u8x16 = state->sum.u8x16s[0]; + + // The logic is different depending on the length of the input + uint8x16_t const *ins_vecs = (uint8x16_t const *)&state->ins.u8x16s[0]; + if (length <= 16) { + _sz_hash_minimal_update_neon(&minimal_state, ins_vecs[0]); + return _sz_hash_minimal_finalize_neon(&minimal_state, length); + } + else if (length <= 32) { + _sz_hash_minimal_update_neon(&minimal_state, ins_vecs[0]); + _sz_hash_minimal_update_neon(&minimal_state, ins_vecs[1]); + return _sz_hash_minimal_finalize_neon(&minimal_state, length); + } + else if (length <= 48) { + _sz_hash_minimal_update_neon(&minimal_state, ins_vecs[0]); + _sz_hash_minimal_update_neon(&minimal_state, ins_vecs[1]); + _sz_hash_minimal_update_neon(&minimal_state, ins_vecs[2]); + return _sz_hash_minimal_finalize_neon(&minimal_state, length); + } + else { + _sz_hash_minimal_update_neon(&minimal_state, ins_vecs[0]); + _sz_hash_minimal_update_neon(&minimal_state, ins_vecs[1]); + _sz_hash_minimal_update_neon(&minimal_state, ins_vecs[2]); + _sz_hash_minimal_update_neon(&minimal_state, ins_vecs[3]); + return _sz_hash_minimal_finalize_neon(&minimal_state, length); + } +} + +SZ_PUBLIC sz_u64_t sz_hash_neon(sz_cptr_t start, sz_size_t length, sz_u64_t seed) { + if (length <= 16) { + // Initialize the AES block with a given seed + _sz_hash_minimal_t state; + _sz_hash_minimal_init_neon(&state, seed); + // Load the data and update the state + sz_u128_vec_t data_vec; + data_vec.u8x16 = vdupq_n_u8(0); + for (sz_size_t i = 0; i < length; ++i) data_vec.u8s[i] = start[i]; + _sz_hash_minimal_update_neon(&state, data_vec.u8x16); + return _sz_hash_minimal_finalize_neon(&state, length); + } + else if (length <= 32) { + // Initialize the AES block with a given seed + _sz_hash_minimal_t state; + _sz_hash_minimal_init_neon(&state, seed); + // Load the data and update the state + sz_u128_vec_t data0_vec, data1_vec; + data0_vec.u8x16 = vld1q_u8((sz_u8_t const *)(start)); + data1_vec.u8x16 = vld1q_u8((sz_u8_t const *)(start + length - 16)); + // Let's shift the data within the register to de-interleave the bytes. + _sz_hash_shift_in_register_serial(&data1_vec, 32 - length); + _sz_hash_minimal_update_neon(&state, data0_vec.u8x16); + _sz_hash_minimal_update_neon(&state, data1_vec.u8x16); + return _sz_hash_minimal_finalize_neon(&state, length); + } + else if (length <= 48) { + // Initialize the AES block with a given seed + _sz_hash_minimal_t state; + _sz_hash_minimal_init_neon(&state, seed); + // Load the data and update the state + sz_u128_vec_t data0_vec, data1_vec, data2_vec; + data0_vec.u8x16 = vld1q_u8((sz_u8_t const *)(start)); + data1_vec.u8x16 = vld1q_u8((sz_u8_t const *)(start + 16)); + data2_vec.u8x16 = vld1q_u8((sz_u8_t const *)(start + length - 16)); + // Let's shift the data within the register to de-interleave the bytes. 
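+        // The tail load at `start + length - 16` overlaps the previous 16-byte block by exactly
+        // `48 - length` bytes, which is the shift amount passed below.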
+ _sz_hash_shift_in_register_serial(&data2_vec, 48 - length); + _sz_hash_minimal_update_neon(&state, data0_vec.u8x16); + _sz_hash_minimal_update_neon(&state, data1_vec.u8x16); + _sz_hash_minimal_update_neon(&state, data2_vec.u8x16); + return _sz_hash_minimal_finalize_neon(&state, length); + } + else if (length <= 64) { + // Initialize the AES block with a given seed + _sz_hash_minimal_t state; + _sz_hash_minimal_init_neon(&state, seed); + // Load the data and update the state + sz_u128_vec_t data0_vec, data1_vec, data2_vec, data3_vec; + data0_vec.u8x16 = vld1q_u8((sz_u8_t const *)(start)); + data1_vec.u8x16 = vld1q_u8((sz_u8_t const *)(start + 16)); + data2_vec.u8x16 = vld1q_u8((sz_u8_t const *)(start + 32)); + data3_vec.u8x16 = vld1q_u8((sz_u8_t const *)(start + length - 16)); + // Let's shift the data within the register to de-interleave the bytes. + _sz_hash_shift_in_register_serial(&data3_vec, 64 - length); + _sz_hash_minimal_update_neon(&state, data0_vec.u8x16); + _sz_hash_minimal_update_neon(&state, data1_vec.u8x16); + _sz_hash_minimal_update_neon(&state, data2_vec.u8x16); + _sz_hash_minimal_update_neon(&state, data3_vec.u8x16); + return _sz_hash_minimal_finalize_neon(&state, length); + } + else { + // Use a larger state to handle the main loop and add different offsets + // to different lanes of the register + sz_hash_state_t state; + sz_hash_state_init_neon(&state, seed); + for (; state.ins_length + 64 <= length; state.ins_length += 64) { + state.ins.u8x16s[0] = vld1q_u8((sz_u8_t const *)(start + state.ins_length)); + state.ins.u8x16s[1] = vld1q_u8((sz_u8_t const *)(start + state.ins_length + 16)); + state.ins.u8x16s[2] = vld1q_u8((sz_u8_t const *)(start + state.ins_length + 32)); + state.ins.u8x16s[3] = vld1q_u8((sz_u8_t const *)(start + state.ins_length + 48)); + _sz_hash_state_update_neon(&state); + } + // Handle the tail, resetting the registers to zero first + if (state.ins_length < length) { + state.ins.u8x16s[0] = vdupq_n_u8(0); + state.ins.u8x16s[1] = vdupq_n_u8(0); + state.ins.u8x16s[2] = vdupq_n_u8(0); + state.ins.u8x16s[3] = vdupq_n_u8(0); + for (sz_size_t i = 0; state.ins_length < length; ++i, ++state.ins_length) + state.ins.u8s[i] = start[state.ins_length]; + _sz_hash_state_update_neon(&state); + state.ins_length = length; + } + return _sz_hash_state_finalize_neon(&state); + } +} + +SZ_PUBLIC void sz_fill_random_neon(sz_ptr_t text, sz_size_t length, sz_u64_t nonce) { + sz_u64_t const *pi_ptr = _sz_hash_pi_constants(); + if (length <= 16) { + uint64x2_t input = vdupq_n_u64(nonce); + uint64x2_t pi = vld1q_u64(pi_ptr); + uint64x2_t key = veorq_u64(vdupq_n_u64(nonce), pi); + uint64x2_t generated = _sz_emulate_aesenc_u64x2_neon(input, key); + // Now the tricky part is outputting this data to the user-supplied buffer + // without masked writes, like in AVX-512. + for (sz_size_t i = 0; i < length; ++i) text[i] = ((sz_u8_t *)&generated)[i]; + } + // Assuming the YMM register contains two 128-bit blocks, the input to the generator + // will be more complex, containing the sum of the nonce and the block number. 
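+    // E.g. for a 20-byte buffer two blocks are produced from counters `nonce + 0` and `nonce + 1`:
+    // the first is stored whole, and only the leading 4 bytes of the second are copied out.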
+ else if (length <= 32) { + uint64x2_t inputs[2], pis[2], keys[2], generated[2]; + inputs[0] = vdupq_n_u64(nonce + 0); + inputs[1] = vdupq_n_u64(nonce + 1); + pis[0] = vld1q_u64(pi_ptr + 0); + pis[1] = vld1q_u64(pi_ptr + 2); + keys[0] = veorq_u64(vdupq_n_u64(nonce), pis[0]); + keys[1] = veorq_u64(vdupq_n_u64(nonce), pis[1]); + generated[0] = _sz_emulate_aesenc_u64x2_neon(inputs[0], keys[0]); + generated[1] = _sz_emulate_aesenc_u64x2_neon(inputs[1], keys[1]); + // The first store can easily be vectorized, but the second can be serial for now + vst1q_u64((sz_u64_t *)(text), generated[0]); + for (sz_size_t i = 16; i < length; ++i) text[i] = ((sz_u8_t *)&generated[1])[i - 16]; + } + // The last special case we handle outside of the primary loop is for buffers up to 64 bytes long. + else if (length <= 48) { + uint64x2_t inputs[3], pis[3], keys[3], generated[3]; + inputs[0] = vdupq_n_u64(nonce); + inputs[1] = vdupq_n_u64(nonce + 1); + inputs[2] = vdupq_n_u64(nonce + 2); + pis[0] = vld1q_u64(pi_ptr + 0); + pis[1] = vld1q_u64(pi_ptr + 2); + pis[2] = vld1q_u64(pi_ptr + 4); + keys[0] = veorq_u64(vdupq_n_u64(nonce), pis[0]); + keys[1] = veorq_u64(vdupq_n_u64(nonce), pis[1]); + keys[2] = veorq_u64(vdupq_n_u64(nonce), pis[2]); + generated[0] = _sz_emulate_aesenc_u64x2_neon(inputs[0], keys[0]); + generated[1] = _sz_emulate_aesenc_u64x2_neon(inputs[1], keys[1]); + generated[2] = _sz_emulate_aesenc_u64x2_neon(inputs[2], keys[2]); + // The first store can easily be vectorized, but the second can be serial for now + vst1q_u64((sz_u64_t *)(text + 0), generated[0]); + vst1q_u64((sz_u64_t *)(text + 16), generated[1]); + for (sz_size_t i = 32; i < length; ++i) text[i] = ((sz_u8_t *)generated)[i]; + } + // The final part of the function is the primary loop, which processes the buffer in 64-byte chunks. + else { + uint64x2_t inputs[4], pis[4], keys[4], generated[4]; + inputs[0] = vdupq_n_u64(nonce + 0); + inputs[1] = vdupq_n_u64(nonce + 1); + inputs[2] = vdupq_n_u64(nonce + 2); + inputs[3] = vdupq_n_u64(nonce + 3); + // Load parts of PI into the registers + pis[0] = vld1q_u64(pi_ptr + 0); + pis[1] = vld1q_u64(pi_ptr + 2); + pis[2] = vld1q_u64(pi_ptr + 4); + pis[3] = vld1q_u64(pi_ptr + 6); + // XOR the nonce with the PI constants + keys[0] = veorq_u64(vdupq_n_u64(nonce), pis[0]); + keys[1] = veorq_u64(vdupq_n_u64(nonce), pis[1]); + keys[2] = veorq_u64(vdupq_n_u64(nonce), pis[2]); + keys[3] = veorq_u64(vdupq_n_u64(nonce), pis[3]); + + // Produce the output, fixing the key and enumerating input chunks. + sz_size_t i = 0; + uint64x2_t const increment = vdupq_n_u64(4); + for (; i + 64 <= length; i += 64) { + generated[0] = _sz_emulate_aesenc_u64x2_neon(inputs[0], keys[0]); + generated[1] = _sz_emulate_aesenc_u64x2_neon(inputs[1], keys[1]); + generated[2] = _sz_emulate_aesenc_u64x2_neon(inputs[2], keys[2]); + generated[3] = _sz_emulate_aesenc_u64x2_neon(inputs[3], keys[3]); + vst1q_u64((sz_u64_t *)(text + i + 0), generated[0]); + vst1q_u64((sz_u64_t *)(text + i + 16), generated[1]); + vst1q_u64((sz_u64_t *)(text + i + 32), generated[2]); + vst1q_u64((sz_u64_t *)(text + i + 48), generated[3]); + inputs[0] = vaddq_u64(inputs[0], increment); + inputs[1] = vaddq_u64(inputs[1], increment); + inputs[2] = vaddq_u64(inputs[2], increment); + inputs[3] = vaddq_u64(inputs[3], increment); + } + + // Handle the tail of the buffer. 
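+        // The counters already point past the last full 64-byte chunk, so one more batch is
+        // generated and only the remaining `length - i` bytes of it are written out.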
+ { + generated[0] = _sz_emulate_aesenc_u64x2_neon(inputs[0], keys[0]); + generated[1] = _sz_emulate_aesenc_u64x2_neon(inputs[1], keys[1]); + generated[2] = _sz_emulate_aesenc_u64x2_neon(inputs[2], keys[2]); + generated[3] = _sz_emulate_aesenc_u64x2_neon(inputs[3], keys[3]); + for (sz_size_t j = 0; i < length; ++i, ++j) text[i] = ((sz_u8_t *)generated)[j]; + } + } +} #pragma clang attribute pop #pragma GCC pop_options diff --git a/include/stringzilla/types.h b/include/stringzilla/types.h index 11366304..164932df 100644 --- a/include/stringzilla/types.h +++ b/include/stringzilla/types.h @@ -634,6 +634,12 @@ typedef union sz_u256_vec_t { #if SZ_USE_HASWELL __m256i ymm; __m128i xmms[2]; +#endif +#if SZ_USE_NEON + uint8x16_t u8x16s[2]; + uint16x8_t u16x8s[2]; + uint32x4_t u32x4s[2]; + uint64x2_t u64x2s[2]; #endif sz_u64_t u64s[4]; sz_u32_t u32s[8]; @@ -653,6 +659,12 @@ typedef union sz_u512_vec_t { #if SZ_USE_HASWELL || SZ_USE_SKYLAKE || SZ_USE_ICE __m256i ymms[2]; __m128i xmms[4]; +#endif +#if SZ_USE_NEON + uint8x16_t u8x16s[4]; + uint16x8_t u16x8s[4]; + uint32x4_t u32x4s[4]; + uint64x2_t u64x2s[4]; #endif sz_u64_t u64s[8]; sz_i64_t i64s[8]; From d44beb4377fad8ccf8b8240245dc12bcb16eb346 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 9 Mar 2025 19:27:59 +0000 Subject: [PATCH 171/751] Break: `sz::edit_distance` -> Levenshtein --- include/stringzilla/stringzilla.hpp | 38 +++++++++++---------- include/stringzilla/types.h | 1 + scripts/test.cpp | 51 +++++++++++++++-------------- 3 files changed, 47 insertions(+), 43 deletions(-) diff --git a/include/stringzilla/stringzilla.hpp b/include/stringzilla/stringzilla.hpp index ab99bc5f..b0146a82 100644 --- a/include/stringzilla/stringzilla.hpp +++ b/include/stringzilla/stringzilla.hpp @@ -2030,7 +2030,7 @@ class basic_string_slice { * * `replace`, `insert`, `erase`, `append`, `push_back`, `pop_back`, `resize`, `shrink_to_fit`... from STL, * * `try_` exception-free "try" operations that returning non-zero values on success, * * `replace_all` and `erase_all` similar to Boost, - * * `edit_distance` - Levenshtein distance computation reusing the allocator, + * * `levenshtein_distance` - Levenshtein distance computation reusing the allocator, * * `translate` - character mapping, * * `randomize`, `random` - for fast random string generation. 
* @@ -3360,11 +3360,12 @@ class basic_string { concatenation operator|(string_view other) const noexcept { return {view(), other}; } - size_type edit_distance(string_view other, size_type bound = 0) const noexcept { - size_type result; - _with_alloc([&](sz_alloc_type &alloc) { + size_type levenshtein_distance(string_view other, size_type bound = std::numeric_limits::max()) const + noexcept(false) { + size_type result = std::numeric_limits::max(); + raise(_with_alloc([&](sz_alloc_type &alloc) { return sz_levenshtein_distance(data(), size(), other.data(), other.size(), bound, &alloc, &result); - }); + })); return result; } @@ -3839,7 +3840,7 @@ typename concatenation_result::type template std::size_t hamming_distance( // basic_string_slice const &a, basic_string_slice const &b, // - std::size_t bound = 0) noexcept { + std::size_t bound = SZ_SIZE_MAX) noexcept { std::size_t result; sz_hamming_distance(a.data(), a.size(), b.data(), b.size(), bound, &result); return result; @@ -3852,7 +3853,7 @@ std::size_t hamming_distance( template ::type>> std::size_t hamming_distance( // basic_string const &a, basic_string const &b, // - std::size_t bound = 0) noexcept { + std::size_t bound = SZ_SIZE_MAX) noexcept { return ashvardanian::stringzilla::hamming_distance(a.view(), b.view(), bound); } @@ -3862,7 +3863,8 @@ std::size_t hamming_distance( */ template std::size_t hamming_distance_utf8( // - basic_string_slice const &a, basic_string_slice const &b, std::size_t bound = 0) noexcept { + basic_string_slice const &a, basic_string_slice const &b, + std::size_t bound = SZ_SIZE_MAX) noexcept { std::size_t result; sz_hamming_distance_utf8(a.data(), a.size(), b.data(), b.size(), bound, &result); return result; @@ -3875,7 +3877,7 @@ std::size_t hamming_distance_utf8( // template ::type>> std::size_t hamming_distance_utf8( // basic_string const &a, basic_string const &b, - std::size_t bound = 0) noexcept { + std::size_t bound = SZ_SIZE_MAX) noexcept { return ashvardanian::stringzilla::hamming_distance_utf8(a.view(), b.view(), bound); } @@ -3884,10 +3886,10 @@ std::size_t hamming_distance_utf8( // * @sa sz_levenshtein_distance */ template ::type>> -std::size_t edit_distance( // +std::size_t levenshtein_distance( // basic_string_slice const &a, basic_string_slice const &b, std::size_t bound = SZ_SIZE_MAX, allocator_type_ &&allocator = allocator_type_ {}) noexcept(false) { - std::size_t result; + std::size_t result = SZ_SIZE_MAX; raise(_with_alloc(allocator, [&](sz_memory_allocator_t &alloc) { return sz_levenshtein_distance(a.data(), a.size(), b.data(), b.size(), bound, &alloc, &result); })); @@ -3899,10 +3901,10 @@ std::size_t edit_distance( // * @sa sz_levenshtein_distance */ template > -std::size_t edit_distance( // +std::size_t levenshtein_distance( // basic_string const &a, basic_string const &b, // std::size_t bound = SZ_SIZE_MAX) noexcept(false) { - return ashvardanian::stringzilla::edit_distance(a.view(), b.view(), bound, a.get_allocator()); + return ashvardanian::stringzilla::levenshtein_distance(a.view(), b.view(), bound, a.get_allocator()); } /** @@ -3910,10 +3912,10 @@ std::size_t edit_distance( * @sa sz_levenshtein_distance_utf8 */ template ::type>> -std::size_t edit_distance_utf8( // +std::size_t levenshtein_distance_utf8( // basic_string_slice const &a, basic_string_slice const &b, // std::size_t bound = SZ_SIZE_MAX, allocator_type_ &&allocator = allocator_type_ {}) noexcept(false) { - std::size_t result; + std::size_t result = SZ_SIZE_MAX; raise(_with_alloc(allocator, [&](sz_memory_allocator_t 
&alloc) { return sz_levenshtein_distance_utf8(a.data(), a.size(), b.data(), b.size(), bound, &alloc, &result); })); @@ -3925,10 +3927,10 @@ std::size_t edit_distance_utf8( * @sa sz_levenshtein_distance_utf8 */ template > -std::size_t edit_distance_utf8( // +std::size_t levenshtein_distance_utf8( // basic_string const &a, basic_string const &b, // std::size_t bound = SZ_SIZE_MAX) noexcept(false) { - return ashvardanian::stringzilla::edit_distance_utf8(a.view(), b.view(), bound, a.get_allocator()); + return ashvardanian::stringzilla::levenshtein_distance_utf8(a.view(), b.view(), bound, a.get_allocator()); } /** @@ -3945,7 +3947,7 @@ std::ptrdiff_t alignment_score( static_assert(std::is_signed() == std::is_signed(), "sz_error_cost_t must be signed."); - std::ptrdiff_t result; + std::ptrdiff_t result = SZ_SSIZE_MIN; raise(_with_alloc(allocator, [&](sz_memory_allocator_t &alloc) { return sz_needleman_wunsch_score(a.data(), a.size(), b.data(), b.size(), &subs[0][0], gap, &alloc, &result); })); diff --git a/include/stringzilla/types.h b/include/stringzilla/types.h index 164932df..3a117f61 100644 --- a/include/stringzilla/types.h +++ b/include/stringzilla/types.h @@ -811,6 +811,7 @@ SZ_PUBLIC void sz_sequence_from_null_terminated_strings(sz_cptr_t *start, sz_siz #define SZ_CACHE_LINE_WIDTH (64) // bytes #define SZ_SIZE_MAX ((sz_size_t)(-1)) #define SZ_SSIZE_MAX ((sz_ssize_t)(SZ_SIZE_MAX >> 1)) +#define SZ_SSIZE_MIN ((sz_ssize_t)(-SZ_SSIZE_MAX - 1)) SZ_INTERNAL sz_size_t _sz_size_max(void) { return SZ_SIZE_MAX; } SZ_INTERNAL sz_ssize_t _sz_ssize_max(void) { return SZ_SSIZE_MAX; } diff --git a/scripts/test.cpp b/scripts/test.cpp index ebcc01fe..8dd66dd9 100644 --- a/scripts/test.cpp +++ b/scripts/test.cpp @@ -975,27 +975,28 @@ static void test_non_stl_extensions_for_reads() { assert(sz::hamming_distance_utf8(str("abcdefgh"), str("_bcdefg_")) == 2); // replace ASCI prefix and suffix assert(sz::hamming_distance_utf8(str("αβγδ"), str("αγγδ")) == 1); // replace Beta UTF8 codepoint - assert(sz::edit_distance(str("hello"), str("hello")) == 0); - assert(sz::edit_distance(str("hello"), str("hell")) == 1); - assert(sz::edit_distance(str(""), str("")) == 0); - assert(sz::edit_distance(str(""), str("abc")) == 3); - assert(sz::edit_distance(str("abc"), str("")) == 3); - assert(sz::edit_distance(str("abc"), str("ac")) == 1); // one deletion - assert(sz::edit_distance(str("abc"), str("a_bc")) == 1); // one insertion - assert(sz::edit_distance(str("abc"), str("adc")) == 1); // one substitution - assert(sz::edit_distance(str("ggbuzgjux{}l"), str("gbuzgjux{}l")) == 1); // one insertion (prepended) - assert(sz::edit_distance(str("abcdefgABCDEFG"), str("ABCDEFGabcdefg")) == 14); - - assert(sz::edit_distance_utf8(str("hello"), str("hell")) == 1); // no unicode symbols, just ASCII - assert(sz::edit_distance_utf8(str("𠜎 𠜱 𠝹 𠱓"), str("𠜎𠜱𠝹𠱓")) == 3); // add 3 whitespaces in Chinese - assert(sz::edit_distance_utf8(str("💖"), str("💗")) == 1); - - assert(sz::edit_distance_utf8(str("αβγδ"), str("αγδ")) == 1); // insert Beta - assert(sz::edit_distance_utf8(str("école"), str("école")) == 2); // etter "é" as a single character vs "e" + "´" - assert(sz::edit_distance_utf8(str("façade"), str("facade")) == 1); // "ç" with cedilla vs. plain - assert(sz::edit_distance_utf8(str("Schön"), str("Scho\u0308n")) == 2); // "ö" represented as "o" + "¨" - assert(sz::edit_distance_utf8(str("München"), str("Muenchen")) == 2); // German with umlaut vs. 
transcription - assert(sz::edit_distance_utf8(str("こんにちは世界"), str("こんばんは世界")) == 2); + assert(sz::levenshtein_distance(str("hello"), str("hello")) == 0); + assert(sz::levenshtein_distance(str("hello"), str("hell")) == 1); + assert(sz::levenshtein_distance(str(""), str("")) == 0); + assert(sz::levenshtein_distance(str(""), str("abc")) == 3); + assert(sz::levenshtein_distance(str("abc"), str("")) == 3); + assert(sz::levenshtein_distance(str("abc"), str("ac")) == 1); // one deletion + assert(sz::levenshtein_distance(str("abc"), str("a_bc")) == 1); // one insertion + assert(sz::levenshtein_distance(str("abc"), str("adc")) == 1); // one substitution + assert(sz::levenshtein_distance(str("ggbuzgjux{}l"), str("gbuzgjux{}l")) == 1); // one insertion (prepended) + assert(sz::levenshtein_distance(str("abcdefgABCDEFG"), str("ABCDEFGabcdefg")) == 14); + + assert(sz::levenshtein_distance_utf8(str("hello"), str("hell")) == 1); // no unicode symbols, just ASCII + assert(sz::levenshtein_distance_utf8(str("𠜎 𠜱 𠝹 𠱓"), str("𠜎𠜱𠝹𠱓")) == 3); // add 3 whitespaces in Chinese + assert(sz::levenshtein_distance_utf8(str("💖"), str("💗")) == 1); + + assert(sz::levenshtein_distance_utf8(str("αβγδ"), str("αγδ")) == 1); // insert Beta + assert(sz::levenshtein_distance_utf8(str("école"), str("école")) == + 2); // etter "é" as a single character vs "e" + "´" + assert(sz::levenshtein_distance_utf8(str("façade"), str("facade")) == 1); // "ç" with cedilla vs. plain + assert(sz::levenshtein_distance_utf8(str("Schön"), str("Scho\u0308n")) == 2); // "ö" represented as "o" + "¨" + assert(sz::levenshtein_distance_utf8(str("München"), str("Muenchen")) == 2); // German with umlaut vs. transcription + assert(sz::levenshtein_distance_utf8(str("こんにちは世界"), str("こんばんは世界")) == 2); // Computing alignment scores. 
using matrix_t = std::int8_t[256][256]; @@ -1645,20 +1646,20 @@ static void test_levenshtein_distances() { }; auto test_distance = [&](sz::string const &l, sz::string const &r, std::size_t expected) { - auto received = sz::edit_distance(l, r); + auto received = sz::levenshtein_distance(l, r); auto received_score = sz::alignment_score(l, r, costs, -1); if (received != expected) print_failure("Levenshtein", l, r, expected, received); if ((std::size_t)(-received_score) != expected) print_failure("Scoring", l, r, expected, received_score); // The distance relation commutes - received = sz::edit_distance(r, l); + received = sz::levenshtein_distance(r, l); received_score = sz::alignment_score(r, l, costs, -1); if (received != expected) print_failure("Levenshtein", r, l, expected, received); if ((std::size_t)(-received_score) != expected) print_failure("Scoring", r, l, expected, received_score); // Validate the bounded variants: if (received > 1) { - assert(sz::edit_distance(l, r, received) == received); - assert(sz::edit_distance(r, l, received - 1) >= (std::max)(l.size(), r.size())); + assert(sz::levenshtein_distance(l, r, received) == received); + assert(sz::levenshtein_distance(r, l, received - 1) >= (std::max)(l.size(), r.size())); } }; From af54e933479b50bc3fef0c1cbe0511a384c2fb66 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 9 Mar 2025 20:20:17 +0000 Subject: [PATCH 172/751] Improve: Separate PRNG backends in benchmarks --- scripts/bench_token.cpp | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/scripts/bench_token.cpp b/scripts/bench_token.cpp index 112fbc98..ac632767 100644 --- a/scripts/bench_token.cpp +++ b/scripts/bench_token.cpp @@ -97,7 +97,28 @@ tracked_unary_functions_t hash_stream_functions() { tracked_unary_functions_t random_generation_functions() { static std::vector buffer; + auto wrap_sz = [](auto function) -> unary_function_t { + return unary_function_t([function](std::string_view s) { + if (buffer.size() < s.size()) buffer.resize(s.size()); + function(buffer.data(), s.size(), 0); + return s.size(); + }); + }; + tracked_unary_functions_t result = { + {"sz_fill_random_serial", wrap_sz(sz_fill_random_serial)}, +#if SZ_USE_HASWELL + {"sz_fill_random_haswell", wrap_sz(sz_fill_random_haswell), true}, +#endif +#if SZ_USE_SKYLAKE + {"sz_fill_random_skylake", wrap_sz(sz_fill_random_skylake), true}, +#endif +#if SZ_USE_ICE + {"sz_fill_random_ice", wrap_sz(sz_fill_random_ice), true}, +#endif +#if SZ_USE_NEON + {"sz_fill_random_neon", wrap_sz(sz_fill_random_neon), true}, +#endif {"std::rand() & 0xFF", unary_function_t([](std::string_view token) -> std::size_t { if (buffer.size() < token.size()) buffer.resize(token.size()); for (std::size_t i = 0; i < token.size(); ++i) buffer[i] = static_cast(std::rand() & 0xFF); @@ -108,12 +129,6 @@ tracked_unary_functions_t random_generation_functions() { randomize_string(buffer.data(), token.size()); return token.size(); })}, - {"sz::randomize", unary_function_t([](std::string_view token) -> std::size_t { - if (buffer.size() < token.size()) buffer.resize(token.size()); - sz::string_span span(buffer.data(), token.size()); - sz::fill_random(span); - return token.size(); - })}, }; return result; } From c4f7a0e36b9afd67ce14763d9399535a93994f63 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 9 Mar 2025 20:23:25 +0000 Subject: [PATCH 173/751] Improve: Discarding buffer in streaming hashes --- 
include/stringzilla/hash.h | 44 +++++++++++++++++++------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/include/stringzilla/hash.h b/include/stringzilla/hash.h index eb748ef4..3dab9488 100644 --- a/include/stringzilla/hash.h +++ b/include/stringzilla/hash.h @@ -947,7 +947,7 @@ SZ_PUBLIC sz_u64_t sz_hash_haswell(sz_cptr_t start, sz_size_t length, sz_u64_t s _sz_hash_minimal_init_haswell(&state, seed); // Load the data and update the state sz_u128_vec_t data0_vec, data1_vec, data2_vec; - data0_vec.xmm = _mm_lddqu_si128((__m128i const *)(start)); + data0_vec.xmm = _mm_lddqu_si128((__m128i const *)(start + 0)); data1_vec.xmm = _mm_lddqu_si128((__m128i const *)(start + 16)); data2_vec.xmm = _mm_lddqu_si128((__m128i const *)(start + length - 16)); // Let's shift the data within the register to de-interleave the bytes. @@ -963,7 +963,7 @@ SZ_PUBLIC sz_u64_t sz_hash_haswell(sz_cptr_t start, sz_size_t length, sz_u64_t s _sz_hash_minimal_init_haswell(&state, seed); // Load the data and update the state sz_u128_vec_t data0_vec, data1_vec, data2_vec, data3_vec; - data0_vec.xmm = _mm_lddqu_si128((__m128i const *)(start)); + data0_vec.xmm = _mm_lddqu_si128((__m128i const *)(start + 0)); data1_vec.xmm = _mm_lddqu_si128((__m128i const *)(start + 16)); data2_vec.xmm = _mm_lddqu_si128((__m128i const *)(start + 32)); data3_vec.xmm = _mm_lddqu_si128((__m128i const *)(start + length - 16)); @@ -981,7 +981,7 @@ SZ_PUBLIC sz_u64_t sz_hash_haswell(sz_cptr_t start, sz_size_t length, sz_u64_t s sz_hash_state_t state; sz_hash_state_init_haswell(&state, seed); for (; state.ins_length + 64 <= length; state.ins_length += 64) { - state.ins.xmms[0] = _mm_lddqu_si128((__m128i const *)(start + state.ins_length)); + state.ins.xmms[0] = _mm_lddqu_si128((__m128i const *)(start + state.ins_length + 0)); state.ins.xmms[1] = _mm_lddqu_si128((__m128i const *)(start + state.ins_length + 16)); state.ins.xmms[2] = _mm_lddqu_si128((__m128i const *)(start + state.ins_length + 32)); state.ins.xmms[3] = _mm_lddqu_si128((__m128i const *)(start + state.ins_length + 48)); @@ -1006,7 +1006,7 @@ SZ_PUBLIC void sz_hash_state_stream_haswell(sz_hash_state_t *state, sz_cptr_t te while (length) { // Append to the internal buffer until it's full if (state->ins_length % 64 == 0 && length >= 64) { - state->ins.xmms[0] = _mm_lddqu_si128((__m128i const *)text); + state->ins.xmms[0] = _mm_lddqu_si128((__m128i const *)(text + 0)); state->ins.xmms[1] = _mm_lddqu_si128((__m128i const *)(text + 16)); state->ins.xmms[2] = _mm_lddqu_si128((__m128i const *)(text + 32)); state->ins.xmms[3] = _mm_lddqu_si128((__m128i const *)(text + 48)); @@ -1029,7 +1029,7 @@ SZ_PUBLIC void sz_hash_state_stream_haswell(sz_hash_state_t *state, sz_cptr_t te if (will_fill_block) { _sz_hash_state_update_haswell(state); // Reset to zeros now, so we don't have to overwrite an immutable buffer in the folding state - for (int i = 0; i < 4; ++i) state->ins.xmms[i] = _mm_setzero_si128(); + for (int i = 0; i < 4; ++i) _mm_storeu_si128(&state->ins.xmms[i], _mm_setzero_si128()); } } } @@ -1043,7 +1043,7 @@ SZ_PUBLIC sz_u64_t sz_hash_state_fold_haswell(sz_hash_state_t const *state) { _sz_hash_minimal_t minimal_state; minimal_state.key.xmm = state->key.xmm; minimal_state.aes.xmm = state->aes.xmms[0]; - minimal_state.sum.xmm = state->sum.u64x2s[0]; + minimal_state.sum.xmm = state->sum.xmms[0]; // The logic is different depending on the length of the input __m128i const *ins_vecs = (__m128i const *)&state->ins.xmms[0]; @@ -1088,7 +1088,7 @@ SZ_PUBLIC void 
sz_fill_random_haswell(sz_ptr_t text, sz_size_t length, sz_u64_t __m128i inputs[2], pis[2], keys[2], generated[2]; inputs[0] = _mm_set1_epi64x(nonce); inputs[1] = _mm_set1_epi64x(nonce + 1); - pis[0] = _mm_load_si128((__m128i const *)(pi_ptr)); + pis[0] = _mm_load_si128((__m128i const *)(pi_ptr + 0)); pis[1] = _mm_load_si128((__m128i const *)(pi_ptr + 2)); keys[0] = _mm_xor_si128(_mm_set1_epi64x(nonce), pis[0]); keys[1] = _mm_xor_si128(_mm_set1_epi64x(nonce), pis[1]); @@ -1104,7 +1104,7 @@ SZ_PUBLIC void sz_fill_random_haswell(sz_ptr_t text, sz_size_t length, sz_u64_t inputs[0] = _mm_set1_epi64x(nonce); inputs[1] = _mm_set1_epi64x(nonce + 1); inputs[2] = _mm_set1_epi64x(nonce + 2); - pis[0] = _mm_load_si128((__m128i const *)(pi_ptr)); + pis[0] = _mm_load_si128((__m128i const *)(pi_ptr + 0)); pis[1] = _mm_load_si128((__m128i const *)(pi_ptr + 2)); pis[2] = _mm_load_si128((__m128i const *)(pi_ptr + 4)); keys[0] = _mm_xor_si128(_mm_set1_epi64x(nonce), pis[0]); @@ -1114,7 +1114,7 @@ SZ_PUBLIC void sz_fill_random_haswell(sz_ptr_t text, sz_size_t length, sz_u64_t generated[1] = _mm_aesenc_si128(inputs[1], keys[1]); generated[2] = _mm_aesenc_si128(inputs[2], keys[2]); // The first store can easily be vectorized, but the second can be serial for now - _mm_storeu_si128((__m128i *)text, generated[0]); + _mm_storeu_si128((__m128i *)(text + 0), generated[0]); _mm_storeu_si128((__m128i *)(text + 16), generated[1]); for (sz_size_t i = 32; i < length; ++i) text[i] = ((sz_u8_t *)generated)[i]; } @@ -1126,7 +1126,7 @@ SZ_PUBLIC void sz_fill_random_haswell(sz_ptr_t text, sz_size_t length, sz_u64_t inputs[2] = _mm_set1_epi64x(nonce + 2); inputs[3] = _mm_set1_epi64x(nonce + 3); // Load parts of PI into the registers - pis[0] = _mm_load_si128((__m128i const *)(pi_ptr)); + pis[0] = _mm_load_si128((__m128i const *)(pi_ptr + 0)); pis[1] = _mm_load_si128((__m128i const *)(pi_ptr + 2)); pis[2] = _mm_load_si128((__m128i const *)(pi_ptr + 4)); pis[3] = _mm_load_si128((__m128i const *)(pi_ptr + 6)); @@ -1144,7 +1144,7 @@ SZ_PUBLIC void sz_fill_random_haswell(sz_ptr_t text, sz_size_t length, sz_u64_t generated[1] = _mm_aesenc_si128(inputs[1], keys[1]); generated[2] = _mm_aesenc_si128(inputs[2], keys[2]); generated[3] = _mm_aesenc_si128(inputs[3], keys[3]); - _mm_storeu_si128((__m128i *)(text + i), generated[0]); + _mm_storeu_si128((__m128i *)(text + i + 0), generated[0]); _mm_storeu_si128((__m128i *)(text + i + 16), generated[1]); _mm_storeu_si128((__m128i *)(text + i + 32), generated[2]); _mm_storeu_si128((__m128i *)(text + i + 48), generated[3]); @@ -1389,7 +1389,7 @@ SZ_PUBLIC void sz_hash_state_stream_skylake(sz_hash_state_t *state, sz_cptr_t te if (will_fill_block) { _sz_hash_state_update_haswell(state); // Reset to zeros now, so we don't have to overwrite an immutable buffer in the folding state - state->ins.zmm = _mm512_setzero_si512(); + _mm512_storeu_si512(&state->ins.zmm, _mm512_setzero_si512()); } } } @@ -1803,7 +1803,7 @@ SZ_PUBLIC sz_u64_t sz_bytesum_neon(sz_cptr_t text, sz_size_t length) { // Final reduction of `sum_vec` to a single scalar sz_u64_t sum = vgetq_lane_u64(sum_vec, 0) + vgetq_lane_u64(sum_vec, 1); - if (length) sum += sz_bytesum_serial(text, length); + while (length--) sum += *(sz_u8_t const *)text++; // Same as the scalar version return sum; } @@ -1917,7 +1917,7 @@ SZ_PUBLIC void sz_hash_state_stream_neon(sz_hash_state_t *state, sz_cptr_t text, while (length) { // Append to the internal buffer until it's full if (state->ins_length % 64 == 0 && length >= 64) { - state->ins.u8x16s[0] = 
vld1q_u8((sz_u8_t const *)text); + state->ins.u8x16s[0] = vld1q_u8((sz_u8_t const *)(text + 0)); state->ins.u8x16s[1] = vld1q_u8((sz_u8_t const *)(text + 16)); state->ins.u8x16s[2] = vld1q_u8((sz_u8_t const *)(text + 32)); state->ins.u8x16s[3] = vld1q_u8((sz_u8_t const *)(text + 48)); @@ -1940,7 +1940,7 @@ SZ_PUBLIC void sz_hash_state_stream_neon(sz_hash_state_t *state, sz_cptr_t text, if (will_fill_block) { _sz_hash_state_update_neon(state); // Reset to zeros now, so we don't have to overwrite an immutable buffer in the folding state - for (int i = 0; i < 4; ++i) state->ins.u8x16s[i] = vdupq_n_u8(0); + for (int i = 0; i < 4; ++i) vst1q_u8(state->ins.u8s + i * 16, vdupq_n_u8(0)); } } } @@ -2001,10 +2001,10 @@ SZ_PUBLIC sz_u64_t sz_hash_neon(sz_cptr_t start, sz_size_t length, sz_u64_t seed _sz_hash_minimal_init_neon(&state, seed); // Load the data and update the state sz_u128_vec_t data0_vec, data1_vec; - data0_vec.u8x16 = vld1q_u8((sz_u8_t const *)(start)); + data0_vec.u8x16 = vld1q_u8((sz_u8_t const *)(start + 0)); data1_vec.u8x16 = vld1q_u8((sz_u8_t const *)(start + length - 16)); // Let's shift the data within the register to de-interleave the bytes. - _sz_hash_shift_in_register_serial(&data1_vec, 32 - length); + _sz_hash_shift_in_register_serial(&data1_vec, 32 - length); //! `vextq_u8` requires immediates _sz_hash_minimal_update_neon(&state, data0_vec.u8x16); _sz_hash_minimal_update_neon(&state, data1_vec.u8x16); return _sz_hash_minimal_finalize_neon(&state, length); @@ -2015,11 +2015,11 @@ SZ_PUBLIC sz_u64_t sz_hash_neon(sz_cptr_t start, sz_size_t length, sz_u64_t seed _sz_hash_minimal_init_neon(&state, seed); // Load the data and update the state sz_u128_vec_t data0_vec, data1_vec, data2_vec; - data0_vec.u8x16 = vld1q_u8((sz_u8_t const *)(start)); + data0_vec.u8x16 = vld1q_u8((sz_u8_t const *)(start + 0)); data1_vec.u8x16 = vld1q_u8((sz_u8_t const *)(start + 16)); data2_vec.u8x16 = vld1q_u8((sz_u8_t const *)(start + length - 16)); // Let's shift the data within the register to de-interleave the bytes. - _sz_hash_shift_in_register_serial(&data2_vec, 48 - length); + _sz_hash_shift_in_register_serial(&data2_vec, 48 - length); //! `vextq_u8` requires immediates _sz_hash_minimal_update_neon(&state, data0_vec.u8x16); _sz_hash_minimal_update_neon(&state, data1_vec.u8x16); _sz_hash_minimal_update_neon(&state, data2_vec.u8x16); @@ -2031,12 +2031,12 @@ SZ_PUBLIC sz_u64_t sz_hash_neon(sz_cptr_t start, sz_size_t length, sz_u64_t seed _sz_hash_minimal_init_neon(&state, seed); // Load the data and update the state sz_u128_vec_t data0_vec, data1_vec, data2_vec, data3_vec; - data0_vec.u8x16 = vld1q_u8((sz_u8_t const *)(start)); + data0_vec.u8x16 = vld1q_u8((sz_u8_t const *)(start + 0)); data1_vec.u8x16 = vld1q_u8((sz_u8_t const *)(start + 16)); data2_vec.u8x16 = vld1q_u8((sz_u8_t const *)(start + 32)); data3_vec.u8x16 = vld1q_u8((sz_u8_t const *)(start + length - 16)); // Let's shift the data within the register to de-interleave the bytes. - _sz_hash_shift_in_register_serial(&data3_vec, 64 - length); + _sz_hash_shift_in_register_serial(&data3_vec, 64 - length); //! 
`vextq_u8` requires immediates _sz_hash_minimal_update_neon(&state, data0_vec.u8x16); _sz_hash_minimal_update_neon(&state, data1_vec.u8x16); _sz_hash_minimal_update_neon(&state, data2_vec.u8x16); @@ -2049,7 +2049,7 @@ SZ_PUBLIC sz_u64_t sz_hash_neon(sz_cptr_t start, sz_size_t length, sz_u64_t seed sz_hash_state_t state; sz_hash_state_init_neon(&state, seed); for (; state.ins_length + 64 <= length; state.ins_length += 64) { - state.ins.u8x16s[0] = vld1q_u8((sz_u8_t const *)(start + state.ins_length)); + state.ins.u8x16s[0] = vld1q_u8((sz_u8_t const *)(start + state.ins_length + 0)); state.ins.u8x16s[1] = vld1q_u8((sz_u8_t const *)(start + state.ins_length + 16)); state.ins.u8x16s[2] = vld1q_u8((sz_u8_t const *)(start + state.ins_length + 32)); state.ins.u8x16s[3] = vld1q_u8((sz_u8_t const *)(start + state.ins_length + 48)); From 828263f27903ae9df580531ee2589ea3876cfbaa Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 9 Mar 2025 20:24:59 +0000 Subject: [PATCH 174/751] Improve: Discard state in streaming hash --- scripts/bench_token.cpp | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/scripts/bench_token.cpp b/scripts/bench_token.cpp index ac632767..8418f826 100644 --- a/scripts/bench_token.cpp +++ b/scripts/bench_token.cpp @@ -63,14 +63,15 @@ tracked_unary_functions_t hash_functions() { struct wrap_hash_stream { sz_hash_state_t state; + sz_hash_state_init_t init; sz_hash_state_stream_t stream; sz_hash_state_fold_t fold; - wrap_hash_stream(sz_hash_state_stream_t s, sz_hash_state_fold_t f) : stream(s), fold(f) { - sz_hash_state_init(&state, 42); - } + wrap_hash_stream(sz_hash_state_init_t i, sz_hash_state_stream_t s, sz_hash_state_fold_t f) + : init(i), stream(s), fold(f) {} std::size_t operator()(std::string_view s) noexcept { + init(&state, 42); stream(&state, s.data(), s.size()); return fold(&state); } @@ -78,18 +79,23 @@ struct wrap_hash_stream { tracked_unary_functions_t hash_stream_functions() { tracked_unary_functions_t result = { - {"sz_hash_stream_serial", wrap_hash_stream(sz_hash_state_stream_serial, sz_hash_state_fold_serial)}, + {"sz_hash_stream_serial", + wrap_hash_stream(sz_hash_state_init_serial, sz_hash_state_stream_serial, sz_hash_state_fold_serial)}, #if SZ_USE_HASWELL - {"sz_hash_stream_haswell", wrap_hash_stream(sz_hash_state_stream_haswell, sz_hash_state_fold_haswell), true}, + {"sz_hash_stream_haswell", + wrap_hash_stream(sz_hash_state_init_haswell, sz_hash_state_stream_haswell, sz_hash_state_fold_haswell), true}, #endif #if SZ_USE_SKYLAKE - {"sz_hash_stream_skylake", wrap_hash_stream(sz_hash_state_stream_skylake, sz_hash_state_fold_skylake), true}, + {"sz_hash_stream_skylake", + wrap_hash_stream(sz_hash_state_init_skylake, sz_hash_state_stream_skylake, sz_hash_state_fold_skylake), true}, #endif #if SZ_USE_ICE - {"sz_hash_stream_ice", wrap_hash_stream(sz_hash_state_stream_ice, sz_hash_state_fold_ice), true}, + {"sz_hash_stream_ice", + wrap_hash_stream(sz_hash_state_init_ice, sz_hash_state_stream_ice, sz_hash_state_fold_ice), true}, #endif #if SZ_USE_NEON - {"sz_hash_stream_neon", wrap_hash_stream(sz_hash_state_stream_neon, sz_hash_state_fold_neon), true}, + {"sz_hash_stream_neon", + wrap_hash_stream(sz_hash_state_init_neon, sz_hash_state_stream_neon, sz_hash_state_fold_neon), true}, #endif }; return result; From ff23c3de6851b83908cd81ab9b0e47e26d340f2f Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Sun, 9 Mar 2025 20:27:29 
+0000 Subject: [PATCH 175/751] Fix: `std::string::data` is mutable only since C++17 --- scripts/bench_token.cpp | 2 +- scripts/test.cpp | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/scripts/bench_token.cpp b/scripts/bench_token.cpp index 8418f826..af16e7a4 100644 --- a/scripts/bench_token.cpp +++ b/scripts/bench_token.cpp @@ -132,7 +132,7 @@ tracked_unary_functions_t random_generation_functions() { })}, {"std::uniform_int", unary_function_t([](std::string_view token) -> std::size_t { if (buffer.size() < token.size()) buffer.resize(token.size()); - randomize_string(buffer.data(), token.size()); + randomize_string(&buffer[0], token.size()); return token.size(); })}, }; diff --git a/scripts/test.cpp b/scripts/test.cpp index 8dd66dd9..58756efe 100644 --- a/scripts/test.cpp +++ b/scripts/test.cpp @@ -223,6 +223,13 @@ static void test_hashing_on_platform( // for (auto seed : seeds) for (std::size_t copies = 1; copies != 100; ++copies) // test_on_seed(repeat("abc", copies), seed); + + // Let's try truly random inputs of different lengths: + for (std::size_t length = 0; length != 200; ++length) { + std::string text(length, '\0'); + randomize_string(&text[0], length); + for (auto seed : seeds) test_on_seed(text, seed); + } } /** From f9da4edcbc845ff83f65c62818dd913179470e77 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Mon, 10 Mar 2025 06:00:01 +0000 Subject: [PATCH 176/751] Fix: Composing STL collections --- include/stringzilla/stringzilla.hpp | 20 ++++++++++---------- scripts/bench_container.cpp | 27 +++++++++++++++------------ scripts/test.cpp | 4 ++-- 3 files changed, 27 insertions(+), 24 deletions(-) diff --git a/include/stringzilla/stringzilla.hpp b/include/stringzilla/stringzilla.hpp index b0146a82..03cef8d1 100644 --- a/include/stringzilla/stringzilla.hpp +++ b/include/stringzilla/stringzilla.hpp @@ -1971,7 +1971,7 @@ class basic_string_slice { #pragma endregion /** @brief Hashes the string, equivalent to `std::hash{}(str)`. */ - size_type hash(std::uint64_t seed = 42) const noexcept { + size_type hash(std::uint64_t seed = 0) const noexcept { return static_cast(sz_hash(start_, length_, static_cast(seed))); } @@ -3759,10 +3759,10 @@ bool basic_string::try_preparing_replacement( // * @see Similar to `std::less`: https://en.cppreference.com/w/cpp/utility/functional/less * * Unlike the STL analog, doesn't require C++14 or including the heavy `` header. - * Can be used to combine STL classes with StringZilla logic, like: `std::map`. + * Can be used to combine STL classes with StringZilla logic, like: `std::map`. */ -struct string_view_less { - bool operator()(string_view a, string_view b) const noexcept { return a < b; } +struct less { + inline bool operator()(string_view a, string_view b) const noexcept { return a < b; } }; /** @@ -3771,10 +3771,10 @@ struct string_view_less { * * Unlike the STL analog, doesn't require C++14 or including the heavy `` header. * Can be used to combine STL classes with StringZilla logic, like: - * `std::unordered_map`. + * `std::unordered_map`. */ -struct string_view_equal_to { - bool operator()(string_view a, string_view b) const noexcept { return a == b; } +struct equal_to { + inline bool operator()(string_view a, string_view b) const noexcept { return a == b; } }; /** @@ -3783,10 +3783,10 @@ struct string_view_equal_to { * * Unlike the STL analog, doesn't require C++14 or including the heavy `` header. 
* Can be used to combine STL classes with StringZilla logic, like: - * `std::unordered_map`. + * `std::unordered_map`. */ -struct string_view_hash { - std::size_t operator()(string_view str) const noexcept { return str.hash(); } +struct hash { + inline std::size_t operator()(string_view str) const noexcept { return str.hash(); } }; /** @brief SFINAE-type used to infer the resulting type of concatenating multiple string together. */ diff --git a/scripts/bench_container.cpp b/scripts/bench_container.cpp index 17cd1ec6..ab214517 100644 --- a/scripts/bench_container.cpp +++ b/scripts/bench_container.cpp @@ -51,22 +51,23 @@ void bench_tokens(strings_type const &strings) { auto const &s = strings; // StringZilla structures - bench>("map", s); - bench>("map", s); - bench>("unordered_map", s); - bench>("unordered_map", s); + bench>("std::map", s); + bench>("std::map", s); + bench>("std::umap", s); + bench>("std::umap", s); // Pure STL - bench>("map", s); - bench>("map", s); - bench>("unordered_map", s); - bench>("unordered_map", s); + bench>("std::map", s); + bench>("std::map", s); + bench>("std::umap", s); + bench>("std::umap", s); // STL structures with StringZilla operations - // bench>("map", s); - // bench>("map", s); - // bench>("unordered_map", s); - // bench>("unordered_map", s); + bench>("std::map", s); + bench>("std::map", s); + bench>("std::umap", s); + bench>("std::umap", + s); } int main(int argc, char const **argv) { @@ -77,6 +78,8 @@ int main(int argc, char const **argv) { // Baseline benchmarks for real words, coming in all lengths std::printf("Benchmarking on real words:\n"); bench_tokens(dataset.tokens); + std::printf("Benchmarking on real lines:\n"); + bench_tokens(dataset.lines); // Run benchmarks on tokens of different length for (std::size_t token_length : {1, 2, 3, 4, 5, 6, 7, 8, 16, 32}) { diff --git a/scripts/test.cpp b/scripts/test.cpp index 58756efe..0137053c 100644 --- a/scripts/test.cpp +++ b/scripts/test.cpp @@ -1915,8 +1915,8 @@ static void test_stl_containers() { assert(sorted_words_sz.empty()); assert(words_sz.empty()); - std::map sorted_words_stl; - std::unordered_map words_stl; + std::map sorted_words_stl; + std::unordered_map words_stl; assert(sorted_words_stl.empty()); assert(words_stl.empty()); } From 4bec1e511237b530454ec0e7d3d7b2ff35cd96d1 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Mon, 10 Mar 2025 07:40:00 +0000 Subject: [PATCH 177/751] Fix: Revert to XMM on Haswell --- include/stringzilla/hash.h | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/include/stringzilla/hash.h b/include/stringzilla/hash.h index 3dab9488..5aa40884 100644 --- a/include/stringzilla/hash.h +++ b/include/stringzilla/hash.h @@ -875,7 +875,7 @@ SZ_PUBLIC void sz_hash_state_init_haswell(sz_hash_state_t *state, sz_u64_t seed) for (int i = 0; i < 4; ++i) state->aes.xmms[i] = _mm_xor_si128(seed_vec, _mm_load_si128((__m128i const *)(pi + i * 2))); for (int i = 0; i < 4; ++i) - state->sum.u64x2s[i] = _mm_xor_si128(seed_vec, _mm_load_si128((__m128i const *)(pi + i * 2 + 8))); + state->sum.xmms[i] = _mm_xor_si128(seed_vec, _mm_load_si128((__m128i const *)(pi + i * 2 + 8))); // The inputs are zeroed out at the beginning state->ins.xmms[0] = state->ins.xmms[1] = state->ins.xmms[2] = state->ins.xmms[3] = _mm_setzero_si128(); @@ -885,23 +885,23 @@ SZ_PUBLIC void sz_hash_state_init_haswell(sz_hash_state_t *state, sz_u64_t seed) SZ_INTERNAL void _sz_hash_state_update_haswell(sz_hash_state_t *state) { __m128i 
const shuffle_mask = _mm_load_si128((__m128i const *)_sz_hash_u8x16x4_shuffle()); state->aes.xmms[0] = _mm_aesenc_si128(state->aes.xmms[0], state->ins.xmms[0]); - state->sum.u64x2s[0] = _mm_add_epi64(_mm_shuffle_epi8(state->sum.u64x2s[0], shuffle_mask), state->ins.xmms[0]); + state->sum.xmms[0] = _mm_add_epi64(_mm_shuffle_epi8(state->sum.xmms[0], shuffle_mask), state->ins.xmms[0]); state->aes.xmms[1] = _mm_aesenc_si128(state->aes.xmms[1], state->ins.xmms[1]); - state->sum.u64x2s[1] = _mm_add_epi64(_mm_shuffle_epi8(state->sum.u64x2s[1], shuffle_mask), state->ins.xmms[1]); + state->sum.xmms[1] = _mm_add_epi64(_mm_shuffle_epi8(state->sum.xmms[1], shuffle_mask), state->ins.xmms[1]); state->aes.xmms[2] = _mm_aesenc_si128(state->aes.xmms[2], state->ins.xmms[2]); - state->sum.u64x2s[2] = _mm_add_epi64(_mm_shuffle_epi8(state->sum.u64x2s[2], shuffle_mask), state->ins.xmms[2]); + state->sum.xmms[2] = _mm_add_epi64(_mm_shuffle_epi8(state->sum.xmms[2], shuffle_mask), state->ins.xmms[2]); state->aes.xmms[3] = _mm_aesenc_si128(state->aes.xmms[3], state->ins.xmms[3]); - state->sum.u64x2s[3] = _mm_add_epi64(_mm_shuffle_epi8(state->sum.u64x2s[3], shuffle_mask), state->ins.xmms[3]); + state->sum.xmms[3] = _mm_add_epi64(_mm_shuffle_epi8(state->sum.xmms[3], shuffle_mask), state->ins.xmms[3]); } SZ_INTERNAL sz_u64_t _sz_hash_state_finalize_haswell(sz_hash_state_t const *state) { // Mix the length into the key __m128i key_with_length = _mm_add_epi64(state->key.xmm, _mm_set_epi64x(0, state->ins_length)); // Combine the "sum" and the "AES" blocks - __m128i mixed_registers0 = _mm_aesenc_si128(state->sum.u64x2s[0], state->aes.xmms[0]); - __m128i mixed_registers1 = _mm_aesenc_si128(state->sum.u64x2s[1], state->aes.xmms[1]); - __m128i mixed_registers2 = _mm_aesenc_si128(state->sum.u64x2s[2], state->aes.xmms[2]); - __m128i mixed_registers3 = _mm_aesenc_si128(state->sum.u64x2s[3], state->aes.xmms[3]); + __m128i mixed_registers0 = _mm_aesenc_si128(state->sum.xmms[0], state->aes.xmms[0]); + __m128i mixed_registers1 = _mm_aesenc_si128(state->sum.xmms[1], state->aes.xmms[1]); + __m128i mixed_registers2 = _mm_aesenc_si128(state->sum.xmms[2], state->aes.xmms[2]); + __m128i mixed_registers3 = _mm_aesenc_si128(state->sum.xmms[3], state->aes.xmms[3]); // Combine the mixed registers __m128i mixed_registers01 = _mm_aesenc_si128(mixed_registers0, mixed_registers1); __m128i mixed_registers23 = _mm_aesenc_si128(mixed_registers2, mixed_registers3); From 48d70ea4d859c8624af661f4a10aa654aa46e6a7 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Mon, 10 Mar 2025 07:40:29 +0000 Subject: [PATCH 178/751] Fix: No intersect for Skylake --- c/lib.c | 3 ++- include/stringzilla/intersect.h | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/c/lib.c b/c/lib.c index 7aea9455..afffeeff 100644 --- a/c/lib.c +++ b/c/lib.c @@ -298,7 +298,6 @@ SZ_DYNAMIC void sz_dispatch_table_init(void) { impl->bytesum = sz_bytesum_skylake; impl->sequence_argsort = sz_sequence_argsort_skylake; - impl->sequence_intersect = sz_sequence_intersect_skylake; impl->pgrams_sort = sz_pgrams_sort_skylake; } #endif @@ -319,6 +318,8 @@ SZ_DYNAMIC void sz_dispatch_table_init(void) { impl->hash_state_stream = sz_hash_state_stream_ice; impl->hash_state_fold = sz_hash_state_fold_ice; impl->fill_random = sz_fill_random_ice; + + impl->sequence_intersect = sz_sequence_intersect_ice; } #endif diff --git a/include/stringzilla/intersect.h b/include/stringzilla/intersect.h index b3610969..cf24eb57 100644 --- 
a/include/stringzilla/intersect.h +++ b/include/stringzilla/intersect.h @@ -58,7 +58,7 @@ extern "C" { * Example usage: * * @code{.c} - * #include + * #include * int main() { * char const *first[] = {"banana", "apple", "cherry"}; * char const *second[] = {"cherry", "orange", "pineapple", "banana"}; From 4d955d37530a5e60c73fedd8d057da8a75203b55 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Mon, 10 Mar 2025 08:06:31 +0000 Subject: [PATCH 179/751] Improve: Logging in container benchmarks --- scripts/bench.hpp | 3 ++- scripts/bench_container.cpp | 53 ++++++++++++++++++------------------- 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/scripts/bench.hpp b/scripts/bench.hpp index cbec9bf5..69be9722 100644 --- a/scripts/bench.hpp +++ b/scripts/bench.hpp @@ -199,7 +199,8 @@ inline dataset_t make_dataset_from_path(std::string path) { "Parsed the dataset with:\n" // "- %zu words of mean length ~ %.2f bytes\n" // "- %zu lines of mean length ~ %.2f bytes\n", // - data.tokens.size(), mean_token_bytes, data.lines.size(), mean_line_bytes); + "- %zu bytes in total\n", // + data.tokens.size(), mean_token_bytes, data.lines.size(), mean_line_bytes, data.text.size()); return data; } diff --git a/scripts/bench_container.cpp b/scripts/bench_container.cpp index ab214517..38d92038 100644 --- a/scripts/bench_container.cpp +++ b/scripts/bench_container.cpp @@ -14,25 +14,25 @@ using namespace ashvardanian::stringzilla::scripts; -template -std::vector to(std::vector const &strings) { - std::vector result; +template +std::vector to(std::vector const &strings) { + std::vector result; result.reserve(strings.size()); - for (string_type_from const &string : strings) result.push_back({string.data(), string.size()}); + for (string_from_type_ const &string : strings) result.push_back({string.data(), string.size()}); return result; } /** * @brief Evaluation for search string operations: find. 
*/ -template +template void bench(std::string name, std::vector const &strings) { - using key_type = typename container_at::key_type; + using key_type = typename container_type_::key_type; std::vector keys = to(strings); // Build up the container - container_at container; + container_type_ container; for (key_type const &key : keys) container[key] = 0; tracked_function_gt variant; @@ -45,33 +45,32 @@ void bench(std::string name, std::vector const &strings) { variant.print(); } -template -void bench_tokens(strings_type const &strings) { - if (strings.size() == 0) return; - auto const &s = strings; +template +void bench_tokens(strings_type_ const &s) { + if (s.size() == 0) return; - // StringZilla structures - bench>("std::map", s); - bench>("std::map", s); - bench>("std::umap", s); - bench>("std::umap", s); + // STL containers with StringZilla strings and views + bench>("std::map::find", s); + bench>("std::map::find", s); + bench>("std::umap::find", s); + bench>("std::umap::find", s); - // Pure STL - bench>("std::map", s); - bench>("std::map", s); - bench>("std::umap", s); - bench>("std::umap", s); + // STL containers with STL strings and views + bench>("std::map::find", s); + bench>("std::map::find", s); + bench>("std::umap::find", s); + bench>("std::umap::find", s); // STL structures with StringZilla operations - bench>("std::map", s); - bench>("std::map", s); - bench>("std::umap", s); - bench>("std::umap", - s); + bench>("std::map::find", s); + bench>("std::map::find", s); + bench>("std::umap::find", s); + bench>( + "std::umap::find", s); } int main(int argc, char const **argv) { - std::printf("StringZilla. Starting search benchmarks.\n"); + std::printf("StringZilla. Starting container benchmarks.\n"); dataset_t dataset = prepare_benchmark_environment(argc, argv); From 3b1897ef2acb8d00834c13f8873210244932ee2f Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Mon, 10 Mar 2025 10:15:35 +0000 Subject: [PATCH 180/751] Improve: Token benchmarks In the past, token benchmarks weren't balanced. For equality comparisons and ordering, they would take random strings which are almost always differing in the very first character and in length, making branch prediction trivial and performance identical between backends. The new benchmarks include self-comparisons, which are more similar to hash-table probing or strings sorting workloads. --- scripts/bench.hpp | 26 ++++++--- scripts/bench_token.cpp | 115 +++++++++++++++++++++++++++------------- 2 files changed, 97 insertions(+), 44 deletions(-) diff --git a/scripts/bench.hpp b/scripts/bench.hpp index cbec9bf5..c0877033 100644 --- a/scripts/bench.hpp +++ b/scripts/bench.hpp @@ -41,6 +41,22 @@ struct benchmark_result_t { using unary_function_t = std::function; using binary_function_t = std::function; +/** + * @brief Wraps a binary function to compare all combinations of two tokens. + * Designed to benchmark functions that on-average take very different times to execute + * for the same string or different strings. For equality checks it's similar to a typical + * load when probing a Hash Table. For relative ordering, it's similar to sorting a dense + * array with many similar strings. + */ +template +binary_function_t binary_combinations(function_type_ function) { + return binary_function_t([function](std::string_view a, std::string_view b) { + // Assuming most outputs here will be 0 or 1, we want to combine them to with different + // multiples to ensure a unique output for each combination. 
+ return function(a, b) * 1 + function(a, a) * 2 + function(b, a) * 4 + function(b, b) * 8; + }); +} + /** * @brief Wrapper for a single execution backend. */ @@ -144,9 +160,7 @@ inline std::vector tokenize(std::string_view str, is_separator return words; } -/** - * @brief Splits a string into words, using newlines, tabs, and whitespaces as delimiters. - */ +/** @brief Splits a string into words, using newlines, tabs, and whitespaces as delimiters using @b `std::isspace`. */ inline std::vector tokenize(std::string_view str) { return tokenize(str, [](char c) { return std::isspace(c); }); } @@ -175,11 +189,11 @@ struct dataset_t { inline dataset_t make_dataset_from_path(std::string path) { dataset_t data; data.text = read_file(path); - data.text.resize(bit_floor(data.text.size())); + data.text.resize(bit_floor(data.text.size())); // Shrink to the nearest power of two data.tokens = tokenize(data.text); - data.tokens.resize(bit_floor(data.tokens.size())); + data.tokens.resize(bit_floor(data.tokens.size())); // Shrink to the nearest power of two data.lines = tokenize(data.text, [](char c) { return c == '\n'; }); - data.lines.resize(bit_floor(data.lines.size())); + data.lines.resize(bit_floor(data.lines.size())); // Shrink to the nearest power of two #if !SZ_DEBUG // Shuffle only in release mode auto &generator = global_random_generator(); diff --git a/scripts/bench_token.cpp b/scripts/bench_token.cpp index af16e7a4..57cb4bc2 100644 --- a/scripts/bench_token.cpp +++ b/scripts/bench_token.cpp @@ -11,6 +11,10 @@ using namespace ashvardanian::stringzilla::scripts; +/** + * @brief Provides kernels, each computing the unsigned sum of bytes in given tokens. + * Compares all supported SIMD backed outputs to the serial implementation. + */ tracked_unary_functions_t bytesum_functions() { auto wrap_sz = [](auto function) -> unary_function_t { return unary_function_t([function](std::string_view s) { return function(s.data(), s.size()); }); @@ -38,9 +42,13 @@ tracked_unary_functions_t bytesum_functions() { return result; } +/** + * @brief Provides kernels, each computing the hash of given tokens using the same seed. + * Compares all supported SIMD backed outputs to the serial implementation. + */ tracked_unary_functions_t hash_functions() { auto wrap_sz = [](auto function) -> unary_function_t { - return unary_function_t([function](std::string_view s) { return function(s.data(), s.size(), 42); }); + return unary_function_t([function](std::string_view s) { return function(s.data(), s.size(), 0); }); }; tracked_unary_functions_t result = { {"sz_hash_serial", wrap_sz(sz_hash_serial)}, @@ -61,13 +69,14 @@ tracked_unary_functions_t hash_functions() { return result; } -struct wrap_hash_stream { +/** @brief Wraps hash state initialization, streaming, and folding for streaming benchmarks. */ +struct wrap_sz_hash_stream { sz_hash_state_t state; sz_hash_state_init_t init; sz_hash_state_stream_t stream; sz_hash_state_fold_t fold; - wrap_hash_stream(sz_hash_state_init_t i, sz_hash_state_stream_t s, sz_hash_state_fold_t f) + wrap_sz_hash_stream(sz_hash_state_init_t i, sz_hash_state_stream_t s, sz_hash_state_fold_t f) : init(i), stream(s), fold(f) {} std::size_t operator()(std::string_view s) noexcept { @@ -77,30 +86,40 @@ struct wrap_hash_stream { } }; +/** + * @brief Provides kernels, each computing the hash of given tokens using more expensive "streaming" API. + * Compares all supported SIMD backed outputs to the serial implementation. 
+ */ tracked_unary_functions_t hash_stream_functions() { tracked_unary_functions_t result = { {"sz_hash_stream_serial", - wrap_hash_stream(sz_hash_state_init_serial, sz_hash_state_stream_serial, sz_hash_state_fold_serial)}, + wrap_sz_hash_stream(sz_hash_state_init_serial, sz_hash_state_stream_serial, sz_hash_state_fold_serial)}, #if SZ_USE_HASWELL {"sz_hash_stream_haswell", - wrap_hash_stream(sz_hash_state_init_haswell, sz_hash_state_stream_haswell, sz_hash_state_fold_haswell), true}, + wrap_sz_hash_stream(sz_hash_state_init_haswell, sz_hash_state_stream_haswell, sz_hash_state_fold_haswell), + true}, #endif #if SZ_USE_SKYLAKE {"sz_hash_stream_skylake", - wrap_hash_stream(sz_hash_state_init_skylake, sz_hash_state_stream_skylake, sz_hash_state_fold_skylake), true}, + wrap_sz_hash_stream(sz_hash_state_init_skylake, sz_hash_state_stream_skylake, sz_hash_state_fold_skylake), + true}, #endif #if SZ_USE_ICE {"sz_hash_stream_ice", - wrap_hash_stream(sz_hash_state_init_ice, sz_hash_state_stream_ice, sz_hash_state_fold_ice), true}, + wrap_sz_hash_stream(sz_hash_state_init_ice, sz_hash_state_stream_ice, sz_hash_state_fold_ice), true}, #endif #if SZ_USE_NEON {"sz_hash_stream_neon", - wrap_hash_stream(sz_hash_state_init_neon, sz_hash_state_stream_neon, sz_hash_state_fold_neon), true}, + wrap_sz_hash_stream(sz_hash_state_init_neon, sz_hash_state_stream_neon, sz_hash_state_fold_neon), true}, #endif }; return result; } +/** + * @brief Provides kernels, each generating random bytes for given tokens using the same "nonce". + * Compares all supported SIMD backed outputs to the serial implementation. + */ tracked_unary_functions_t random_generation_functions() { static std::vector buffer; auto wrap_sz = [](auto function) -> unary_function_t { @@ -139,55 +158,75 @@ tracked_unary_functions_t random_generation_functions() { return result; } +/** @brief Wraps string equality check for potentially different length inputs. */ +struct wrap_sz_equal { + sz_equal_t function; + + wrap_sz_equal(sz_equal_t f) : function(f) {} + bool operator()(std::string_view a, std::string_view b) const noexcept { + return a.size() == b.size() && function(a.data(), b.data(), a.size()); + } +}; + +/** @brief Wraps LibC's string equality check for potentially different length inputs. */ +bool memcmp_for_equality(std::string_view a, std::string_view b) noexcept { + return (a.size() == b.size() && memcmp(a.data(), b.data(), a.size()) == 0); +} + +/** + * @brief Provides kernels, each comparing two tokens for equality. + * Compares all supported SIMD backed outputs to the serial implementation. + * In each iteration combines self- and cross-compares to dampen the branch prediction effect, + * assuming most random string would differ in the very first byte. 
+ */ tracked_binary_functions_t equality_functions() { - auto wrap_sz = [](auto function) -> binary_function_t { - return binary_function_t([function](std::string_view a, std::string_view b) { - return a.size() == b.size() && function(a.data(), b.data(), a.size()); - }); - }; tracked_binary_functions_t result = { - {"std::string_view.==", [](std::string_view a, std::string_view b) { return a == b; }}, - {"sz_equal_serial", wrap_sz(sz_equal_serial), true}, + {"sz_equal_serial", binary_combinations(wrap_sz_equal(sz_equal_serial))}, #if SZ_USE_HASWELL - {"sz_equal_haswell", wrap_sz(sz_equal_haswell), true}, + {"sz_equal_haswell", binary_combinations(wrap_sz_equal(sz_equal_haswell)), true}, #endif #if SZ_USE_SKYLAKE - {"sz_equal_skylake", wrap_sz(sz_equal_skylake), true}, + {"sz_equal_skylake", binary_combinations(wrap_sz_equal(sz_equal_skylake)), true}, #endif - {"memcmp", - [](std::string_view a, std::string_view b) { - return (a.size() == b.size() && memcmp(a.data(), b.data(), a.size()) == 0); - }}, +#if SZ_USE_SVE + {"sz_equal_sve", binary_combinations(wrap_sz_equal(sz_equal_sve)), true}, +#endif +#if SZ_USE_NEON + {"sz_equal_neon", binary_combinations(wrap_sz_equal(sz_equal_neon)), true}, +#endif + {"memcmp(equality)", binary_combinations(memcmp_for_equality)}, }; return result; } +/** @brief Wraps LibC's string comparison for potentially different length inputs. */ +int memcmp_for_ordering(std::string_view a, std::string_view b) noexcept { + auto order = memcmp(a.data(), b.data(), a.size() < b.size() ? a.size() : b.size()); + if (order == 0) return a.size() == b.size() ? 0 : (a.size() < b.size() ? -1 : 1); + return order; +} + +/** + * @brief Provides kernels, each computing the relative order of two tokens. + * Compares all supported SIMD backed outputs to the serial implementation. + * In each iteration combines self- and cross-compares to dampen the branch prediction effect, + * assuming most random string would differ in the very first byte. + */ tracked_binary_functions_t ordering_functions() { auto wrap_sz = [](auto function) -> binary_function_t { return binary_function_t([function](std::string_view a, std::string_view b) { - return function(a.data(), a.size(), b.data(), b.size()); + return (int)function(a.data(), a.size(), b.data(), b.size()); }); }; tracked_binary_functions_t result = { - {"std::string_view.compare", - [](std::string_view a, std::string_view b) { - auto order = a.compare(b); - return (order == 0 ? sz_equal_k : (order < 0 ? sz_less_k : sz_greater_k)); - }}, - {"sz_order_serial", wrap_sz(sz_order_serial), true}, + {"sz_order_serial", binary_combinations(wrap_sz(sz_order_serial))}, #if SZ_USE_HASWELL - {"sz_order_haswell", wrap_sz(sz_order_haswell), true}, + {"sz_order_haswell", binary_combinations(wrap_sz(sz_order_haswell)), true}, #endif #if SZ_USE_SKYLAKE - {"sz_order_skylake", wrap_sz(sz_order_skylake), true}, -#endif - {"memcmp", - [](std::string_view a, std::string_view b) { - auto order = memcmp(a.data(), b.data(), a.size() < b.size() ? a.size() : b.size()); - return order != 0 ? (a.size() == b.size() ? (order < 0 ? sz_less_k : sz_greater_k) - : (a.size() < b.size() ? 
sz_less_k : sz_greater_k)) - : sz_equal_k; - }}, + {"sz_order_skylake", binary_combinations(wrap_sz(sz_order_skylake)), true}, +#endif + {"memcmp(ordering)", binary_combinations(memcmp_for_ordering)}, }; return result; } From c31020dfbb0bc40a7bfdf86091add8cae1a2fd0c Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Mon, 10 Mar 2025 10:17:32 +0000 Subject: [PATCH 181/751] Add: Comparisons in SVE This leads to doubling the performance on mixed workloads which may include self-comparisons, where both comparison arguments are the same. --- include/stringzilla/compare.h | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/include/stringzilla/compare.h b/include/stringzilla/compare.h index 494d1442..b6412016 100644 --- a/include/stringzilla/compare.h +++ b/include/stringzilla/compare.h @@ -399,7 +399,26 @@ SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { #pragma GCC target("arch=armv8.2-a+sve") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve"))), apply_to = function) -/* Nothing here for now. */ +SZ_PUBLIC sz_bool_t sz_equal_sve(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { + // Determine the number of bytes in an SVE vector. + do { + svbool_t progress_vec = svwhilelt_b8((sz_size_t)0, length); + svuint8_t a_vec = svld1(progress_vec, (sz_u8_t const *)a); + svuint8_t b_vec = svld1(progress_vec, (sz_u8_t const *)b); + // Compare: generate a predicate marking lanes where a!=b + svbool_t not_equal_vec = svcmpne(progress_vec, a_vec, b_vec); + if (svptest_any(progress_vec, not_equal_vec)) return sz_false_k; + sz_size_t const vector_length = svcntp_b8(svptrue_b8(), progress_vec); + a += vector_length, b += vector_length, length -= vector_length; + } while (length > 0); + return sz_true_k; +} + +SZ_PUBLIC sz_ordering_t sz_order_sve(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { + //! Before optimizing this, read the "Operations Not Worth Optimizing" in Contributions Guide: + //! 
https://github.com/ashvardanian/StringZilla/blob/main/CONTRIBUTING.md#general-performance-observations + return sz_order_serial(a, a_length, b, b_length); +} #pragma clang attribute pop #pragma GCC pop_options @@ -417,6 +436,8 @@ SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { return sz_equal_skylake(a, b, length); #elif SZ_USE_HASWELL return sz_equal_haswell(a, b, length); +#elif SZ_USE_SVE + return sz_equal_sve(a, b, length); #elif SZ_USE_NEON return sz_equal_neon(a, b, length); #else @@ -429,6 +450,8 @@ SZ_DYNAMIC sz_ordering_t sz_order(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, return sz_order_skylake(a, a_length, b, b_length); #elif SZ_USE_HASWELL return sz_order_haswell(a, a_length, b, b_length); +#elif SZ_USE_SVE + return sz_order_sve(a, a_length, b, b_length); #elif SZ_USE_NEON return sz_order_neon(a, a_length, b, b_length); #else From 92b9a569d9b0ad24afd9f3a51ef312ccd130ce71 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Mon, 10 Mar 2025 10:34:28 +0000 Subject: [PATCH 182/751] Docs: Outdated function naming & spelling --- .vscode/settings.json | 9 +++++++++ README.md | 28 ++++++++++++++-------------- 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 678b1305..85f842ea 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -69,6 +69,7 @@ "hexdigits", "Hirschberg's", "Horspool", + "Hutter", "Hyyro", "illformed", "initproc", @@ -100,6 +101,7 @@ "Morten", "Mosè", "MSVC", + "Nadav", "napi", "nargsf", "ndim", @@ -119,6 +121,7 @@ "pgrams", "Plouffe", "printables", + "ptrdiff", "pytest", "Pythonic", "qsort", @@ -134,13 +137,17 @@ "Ritchie", "rmatcher", "rmatches", + "Rotem", "rpartition", "rsplit", "rsplits", "rstrip", + "Sankoff", + "Sergey", "SIMD", "sklearn", "Skylake", + "Slotin", "splitlines", "ssize", "startswith", @@ -152,6 +159,7 @@ "substr", "SWAR", "Tanimoto", + "Taras", "thyrotropin", "Titin", "tparam", @@ -163,6 +171,7 @@ "VBMI", "vectorcallfunc", "Vectorizer", + "Vintsyuk", "Wagner", "whitespaces", "Wunsch", diff --git a/README.md b/README.md index a3121cb4..b544bb1c 100644 --- a/README.md +++ b/README.md @@ -486,9 +486,9 @@ count: int = sz.count("haystack", "needle", start=0, end=sys.maxsize, allowoverl ### Edit Distances ```py -assert sz.edit_distance("apple", "aple") == 1 # skip one ASCII character -assert sz.edit_distance("αβγδ", "αγδ") == 2 # skip two bytes forming one rune -assert sz.edit_distance_unicode("αβγδ", "αγδ") == 1 # one unicode rune +assert sz.levenshtein_distance("apple", "aple") == 1 # skip one ASCII character +assert sz.levenshtein_distance("αβγδ", "αγδ") == 2 # skip two bytes forming one rune +assert sz.levenshtein_distance_unicode("αβγδ", "αγδ") == 1 # one unicode rune ``` Several Python libraries provide edit distance computation. 
@@ -513,7 +513,7 @@ costs = np.zeros((256, 256), dtype=np.int8) costs.fill(-1) np.fill_diagonal(costs, 0) -assert sz.alignment_score("first", "second", substitution_matrix=costs, gap_score=-1) == -sz.edit_distance(a, b) +assert sz.alignment_score("first", "second", substitution_matrix=costs, gap_score=-1) == -sz.levenshtein_distance(a, b) ``` Using the same proteins as for Levenshtein distance benchmarks: @@ -1088,8 +1088,8 @@ Standard library functions may not offer the most efficient or convenient method - `haystack.replace_all(sz::byteset(""), replacement_string)` - `haystack.try_replace_all(needle_string, replacement_string)` - `haystack.try_replace_all(sz::byteset(""), replacement_string)` -- `haystack.transform(sz::look_up_table::identity())` -- `haystack.transform(sz::look_up_table::identity(), haystack.data())` +- `haystack.lookup(sz::look_up_table::identity())` +- `haystack.lookup(sz::look_up_table::identity(), haystack.data())` ### Levenshtein Edit Distance and Alignment Scores @@ -1103,8 +1103,8 @@ sz::hamming_distance(first, second[, upper_bound]) -> std::size_t; sz::hamming_distance_utf8(first, second[, upper_bound]) -> std::size_t; // Count number of insertions, deletions and substitutions -sz::edit_distance(first, second[, upper_bound[, allocator]]) -> std::size_t; -sz::edit_distance_utf8(first, second[, upper_bound[, allocator]]) -> std::size_t; +sz::levenshtein_distance(first, second[, upper_bound[, allocator]]) -> std::size_t; +sz::levenshtein_distance_utf8(first, second[, upper_bound[, allocator]]) -> std::size_t; // Substitution-parametrized Needleman-Wunsch global alignment score std::int8_t costs[256][256]; // Substitution costs matrix @@ -1160,8 +1160,8 @@ The performance of those containers is often limited by the performance of the s StringZilla can be used to accelerate containers with `std::string` keys, by overriding the default comparator and hash functions. ```cpp -std::map sorted_words; -std::unordered_map words; +std::map sorted_words; +std::unordered_map words; ``` Alternatively, a better approach would be to use the `sz::string` class as a key. @@ -1278,19 +1278,19 @@ assert_eq!(my_str.sz_find("world"), Some(7)); assert_eq!(my_cow_str.as_ref().sz_find("world"), Some(7)); ``` -The library also exposes Levenshtein and Hamming edit-distances for byte-arrays and UTF-8 strings, as well as Needleman-Wunch alignment scores. +The library also exposes Levenshtein and Hamming edit-distances for byte-arrays and UTF-8 strings, as well as Needleman-Wunsch alignment scores. ```rust use stringzilla::sz; // Handling arbitrary byte arrays: -sz::edit_distance("Hello, world!", "Hello, world?"); // 1 +sz::levenshtein_distance("Hello, world!", "Hello, world?"); // 1 sz::hamming_distance("Hello, world!", "Hello, world?"); // 1 sz::alignment_score("Hello, world!", "Hello, world?", sz::unary_substitution_costs(), -1); // -1 // Handling UTF-8 strings: sz::hamming_distance_utf8("αβγδ", "αγγδ") // 1 -sz::edit_distance_utf8("façade", "facade") // 1 +sz::levenshtein_distance_utf8("façade", "facade") // 1 ``` [memchr-benchmarks]: https://github.com/ashvardanian/memchr_vs_stringzilla @@ -1465,7 +1465,7 @@ In AVX-512, StringZilla uses non-temporal stores to avoid cache pollution, when Moreover, it handles the unaligned head and the tails of the `target` buffer separately, ensuring that writes in big copies are always aligned to cache-line boundaries. That's true for both AVX2 and AVX-512 backends. 
-StringZilla also contains "drafts" of smarter, but less efficient algorithms, that minimize the number of unaligned loads, perfoming shuffles and permutations. +StringZilla also contains "drafts" of smarter, but less efficient algorithms, that minimize the number of unaligned loads, performing shuffles and permutations. That's a topic for future research, as the performance gains are not yet satisfactory. > § Reading materials. From 298d2146b8839eab71cb40f5837adc687dc1b952 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Mon, 10 Mar 2025 10:51:36 +0000 Subject: [PATCH 183/751] Fix: Extra comma in `printf` --- scripts/bench.hpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/scripts/bench.hpp b/scripts/bench.hpp index a18e47c5..fca1cfea 100644 --- a/scripts/bench.hpp +++ b/scripts/bench.hpp @@ -209,11 +209,11 @@ inline dataset_t make_dataset_from_path(std::string path) { mean_line_bytes /= data.lines.size(); std::setlocale(LC_NUMERIC, ""); - std::printf( // - "Parsed the dataset with:\n" // - "- %zu words of mean length ~ %.2f bytes\n" // - "- %zu lines of mean length ~ %.2f bytes\n", // - "- %zu bytes in total\n", // + std::printf( // + "Parsed the dataset with:\n" // + "- %zu words of mean length ~ %.2f bytes\n" // + "- %zu lines of mean length ~ %.2f bytes\n" // + "- %zu bytes in total\n", // data.tokens.size(), mean_token_bytes, data.lines.size(), mean_line_bytes, data.text.size()); return data; From 467b4b81cb4bc0e9a64844748a417762378918c9 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Mon, 10 Mar 2025 11:10:04 +0000 Subject: [PATCH 184/751] Make: Formatting CMakeLists.txt --- .cmake-format.py | 19 ++ CMakeLists.txt | 666 +++++++++++++++++++++++------------------------ 2 files changed, 340 insertions(+), 345 deletions(-) create mode 100644 .cmake-format.py diff --git a/.cmake-format.py b/.cmake-format.py new file mode 100644 index 00000000..fb56f11b --- /dev/null +++ b/.cmake-format.py @@ -0,0 +1,19 @@ +# ----------------------------- +# Options effecting formatting. +# ----------------------------- +with section("format"): + # How wide to allow formatted cmake files + line_width = 120 + + # How many spaces to tab for indent + tab_size = 4 + + # If true, separate flow control names from their parentheses with a space + separate_ctrl_name_with_space = True + + # If true, separate function names from parentheses with a space + separate_fn_name_with_space = False + + # If a statement is wrapped to more than one line, than dangle the closing + # parenthesis on its own line. + dangle_parens = True diff --git a/CMakeLists.txt b/CMakeLists.txt index 1da1e36f..f1c1a24f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,42 +2,43 @@ # # This file defines several library build & installation targets: # -# - stringzilla_header: A header-only library with the StringZilla C and C++ headers. -# - stringzilla_shared: A shared library with the StringZilla C and C++ headers and dynamic SIMD dispatch. -# - stringzilla_bare: A shared library with the StringZilla headers, but without linking the standard C library. -# +# * stringzilla_header: A header-only library with the StringZilla C and C++ headers. +# * stringzilla_shared: A shared library with the StringZilla C and C++ headers and dynamic SIMD dispatch. +# * stringzilla_bare: A shared library with the StringZilla headers, but without linking the standard C library. 
+# # Tests for different C++ standards: # -# - stringzilla_test_cpp11: C++11 baseline support. -# - stringzilla_test_cpp14: C++14 support with `std::less`-like function objects. -# - stringzilla_test_cpp17: C++17 support with `std::string_view` compatibility. -# - stringzilla_test_cpp20: C++20 support with `<=>` operator and more `constexpr` features. +# * stringzilla_test_cpp11: C++11 baseline support. +# * stringzilla_test_cpp14: C++14 support with `std::less`-like function objects. +# * stringzilla_test_cpp17: C++17 support with `std::string_view` compatibility. +# * stringzilla_test_cpp20: C++20 support with `<=>` operator and more `constexpr` features. # # Tests for different SIMD architectures: # -# - stringzilla_test_cpp20_serial: A test executable for serial execution. -# - stringzilla_test_cpp20_haswell: A test executable for AVX2. -# - stringzilla_test_cpp20_ice: A test executable for AVX-512. -# - stringzilla_test_cpp20_neon: A test executable for ARM Neon. -# - stringzilla_test_cpp20_sve: A test executable for ARM Scalable Vector Extension. +# * stringzilla_test_cpp20_serial: A test executable for serial execution. +# * stringzilla_test_cpp20_haswell: A test executable for AVX2. +# * stringzilla_test_cpp20_ice: A test executable for AVX-512. +# * stringzilla_test_cpp20_neon: A test executable for ARM Neon. +# * stringzilla_test_cpp20_sve: A test executable for ARM Scalable Vector Extension. # # Benchmarks: # -# - stringzilla_bench_search: A benchmark for substring search operations. -# - stringzilla_bench_similarity: A benchmark for similarity operations. -# - stringzilla_bench_sort: A benchmark for sorting operations. -# - stringzilla_bench_token: A benchmark for comparators and hash functions. -# - stringzilla_bench_container: A benchmark for STL containers powered by StringZilla. -# - stringzilla_bench_memory: A benchmark for LibC-style low-level memory operations. +# * stringzilla_bench_search: A benchmark for substring search operations. +# * stringzilla_bench_similarity: A benchmark for similarity operations. +# * stringzilla_bench_sort: A benchmark for sorting operations. +# * stringzilla_bench_token: A benchmark for comparators and hash functions. +# * stringzilla_bench_container: A benchmark for STL containers powered by StringZilla. +# * stringzilla_bench_memory: A benchmark for LibC-style low-level memory operations. # # For higher-level language bindings separate build scripts are provided, native to each toolchain. 
cmake_minimum_required(VERSION 3.14 FATAL_ERROR) project( - stringzilla - VERSION 3.11.3 - LANGUAGES C CXX - DESCRIPTION "SIMD-accelerated string search, sort, hashes, fingerprints, & edit distances" - HOMEPAGE_URL "https://github.com/ashvardanian/stringzilla") + stringzilla + VERSION 3.11.3 + LANGUAGES C CXX + DESCRIPTION "SIMD-accelerated string search, sort, hashes, fingerprints, & edit distances" + HOMEPAGE_URL "https://github.com/ashvardanian/stringzilla" +) set(CMAKE_C_STANDARD 99) set(CMAKE_CXX_STANDARD 11) @@ -55,363 +56,338 @@ message(STATUS "C++ Compiler Version: ${CMAKE_CXX_COMPILER_VERSION}") message(STATUS "C++ Compiler: ${CMAKE_CXX_COMPILER}") message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") -if(CMAKE_SIZEOF_VOID_P EQUAL 8) - message(STATUS "Pointer size: 64-bit") -else() - message(STATUS "Pointer size: 32-bit") -endif() +if (CMAKE_SIZEOF_VOID_P EQUAL 8) + message(STATUS "Pointer size: 64-bit") +else () + message(STATUS "Pointer size: 32-bit") +endif () # Set a default build type to "Release" if none was specified -if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES) - message(STATUS "Setting build type to 'Release' as none was specified.") - set(CMAKE_BUILD_TYPE - Release - CACHE STRING "Choose the type of build." FORCE) - set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" - "MinSizeRel" "RelWithDebInfo") -endif() - -if(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|AMD64|amd64") - SET(SZ_PLATFORM_X86 TRUE) - message(STATUS "Platform: x86") -elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|AARCH64|arm64|ARM64") - SET(SZ_PLATFORM_ARM TRUE) - message(STATUS "Platform: ARM") -endif() - -# Determine if StringZilla is built as a sub-project (using `add_subdirectory`) -# or if it is the main project +if (NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES) + message(STATUS "Setting build type to 'Release' as none was specified.") + set(CMAKE_BUILD_TYPE + Release + CACHE STRING "Choose the type of build." 
FORCE + ) + set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo") +endif () + +if (CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|AMD64|amd64") + set(SZ_PLATFORM_X86 TRUE) + message(STATUS "Platform: x86") +elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|AARCH64|arm64|ARM64") + set(SZ_PLATFORM_ARM TRUE) + message(STATUS "Platform: ARM") +endif () + +# Determine if StringZilla is built as a sub-project (using `add_subdirectory`) or if it is the main project set(STRINGZILLA_IS_MAIN_PROJECT OFF) -if(CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR) - set(STRINGZILLA_IS_MAIN_PROJECT ON) -endif() +if (CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR) + set(STRINGZILLA_IS_MAIN_PROJECT ON) +endif () # Installation options option(STRINGZILLA_INSTALL "Install CMake targets" OFF) -option(STRINGZILLA_BUILD_TEST "Compile a native unit test in C++" - ${STRINGZILLA_IS_MAIN_PROJECT}) -option(STRINGZILLA_BUILD_BENCHMARK "Compile a native benchmark in C++" - ${STRINGZILLA_IS_MAIN_PROJECT}) +option(STRINGZILLA_BUILD_TEST "Compile a native unit test in C++" ${STRINGZILLA_IS_MAIN_PROJECT}) +option(STRINGZILLA_BUILD_BENCHMARK "Compile a native benchmark in C++" ${STRINGZILLA_IS_MAIN_PROJECT}) option(STRINGZILLA_BUILD_SHARED "Compile a dynamic library" ${STRINGZILLA_IS_MAIN_PROJECT}) set(STRINGZILLA_TARGET_ARCH - "" - CACHE STRING "Architecture to tell the compiler to optimize for (-march)") + "" + CACHE STRING "Architecture to tell the compiler to optimize for (-march)" +) # Includes set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake ${CMAKE_MODULE_PATH}) include(ExternalProject) include(CheckCSourceCompiles) -# Allow CMake 3.13+ to override options when using FetchContent / -# add_subdirectory -if(POLICY CMP0077) - cmake_policy(SET CMP0077 NEW) -endif() +# Allow CMake 3.13+ to override options when using FetchContent / add_subdirectory +if (POLICY CMP0077) + cmake_policy(SET CMP0077 NEW) +endif () # Configuration include(GNUInstallDirs) set(STRINGZILLA_INCLUDE_BUILD_DIR "${PROJECT_SOURCE_DIR}/include/") set(STRINGZILLA_INCLUDE_INSTALL_DIR "${CMAKE_INSTALL_INCLUDEDIR}") - -if(${CMAKE_VERSION} VERSION_EQUAL 3.13 OR ${CMAKE_VERSION} VERSION_GREATER 3.13) - include(CTest) - enable_testing() -endif() +if (${CMAKE_VERSION} VERSION_EQUAL 3.13 OR ${CMAKE_VERSION} VERSION_GREATER 3.13) + include(CTest) + enable_testing() +endif () if (MSVC) - # Remove /RTC* from MSVC debug flags by default (it will be added back in the set_compiler_flags function) - # Because /RTC* cannot be used without the crt so it needs to be disabled for that specific target - string(REGEX REPLACE "/RTC[^ ]*" "" CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG}") - string(REGEX REPLACE "/RTC[^ ]*" "" CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG}") -endif() + # Remove /RTC* from MSVC debug flags by default (it will be added back in the set_compiler_flags function) Because + # /RTC* cannot be used without the crt so it needs to be disabled for that specific target + string(REGEX REPLACE "/RTC[^ ]*" "" CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG}") + string(REGEX REPLACE "/RTC[^ ]*" "" CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG}") +endif () # Function to set compiler-specific flags -function(set_compiler_flags target cpp_standard target_arch) - get_target_property(target_type ${target} TYPE) +function (set_compiler_flags target cpp_standard target_arch) + get_target_property(target_type ${target} TYPE) - target_include_directories(${target} PRIVATE scripts) + target_include_directories(${target} 
PRIVATE scripts) - # Set output directory for single-configuration generators (like Make) - set_target_properties(${target} PROPERTIES - RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/$<0:> - ) + # Set output directory for single-configuration generators (like Make) + set_target_properties(${target} PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/$<0:>) - # Set output directory for multi-configuration generators (like Visual Studio) - foreach(config IN LISTS CMAKE_CONFIGURATION_TYPES) - string(TOUPPER ${config} config_upper) - set_target_properties(${target} PROPERTIES - RUNTIME_OUTPUT_DIRECTORY_${config_upper} ${CMAKE_BINARY_DIR}/$<0:> - ) - endforeach() - - # Set the C++ standard - if(NOT ${cpp_standard} STREQUAL "") - set_target_properties(${target} PROPERTIES CXX_STANDARD ${cpp_standard}) - endif() - - # Use the /Zc:__cplusplus flag to correctly define the __cplusplus macro in MSVC - target_compile_options(${target} PRIVATE "$<$:/Zc:__cplusplus>") - - # Maximum warnings level & warnings as error. - # MVC uses numeric values: - # > 4068 for "unknown pragmas". - # > 4146 for "unary minus operator applied to unsigned type, result still unsigned". - # We also specify /utf-8 to properly UTF-8 symbols in tests. - target_compile_options( - ${target} - PRIVATE - "$<$:/Bt;/wd4068;/wd4146;/utf-8;/WX>" - "$<$:-Wall;-Wextra;-pedantic;-Werror;-Wfatal-errors;-Wno-unknown-pragmas;-Wno-cast-function-type;-Wno-unused-function>" - "$<$:-Wall;-Wextra;-pedantic;-Werror;-Wfatal-errors;-Wno-unknown-pragmas>" - "$<$:-Wall;-Wextra;-pedantic;-Werror;-Wfatal-errors;-Wno-unknown-pragmas>" - ) - - # Set optimization options for different compilers differently - target_compile_options( - ${target} - PRIVATE - "$<$,$,$>>:-O3>" - "$<$,$,$>>:-g>" - "$<$,$,$>>:-O3>" - "$<$,$,$>>:-g>" - "$<$,$>:/O2>" - "$<$,$,$>>:/O2>" - "$<$,$,$>>:/Zi>" - ) - - if(NOT target_type STREQUAL "SHARED_LIBRARY") - if(MSVC) - target_compile_options(${target} PRIVATE "$<$:/RTC1>") - endif() - endif() - - # If available, enable Position Independent Code - get_target_property(target_pic ${target} POSITION_INDEPENDENT_CODE) - if(target_pic) - target_compile_options(${target} PRIVATE "$<$:-fPIC>") - target_link_options(${target} PRIVATE "$<$:-fPIC>") - target_compile_definitions(${target} PRIVATE "$<$:SZ_PIC>") - endif() - - # Avoid builtin functions where we know what we are doing. 
- target_compile_options(${target} PRIVATE "$<$:-fno-builtin-memcmp>") - target_compile_options(${target} PRIVATE "$<$:-fno-builtin-memchr>") - target_compile_options(${target} PRIVATE "$<$:-fno-builtin-memcpy>") - target_compile_options(${target} PRIVATE "$<$:-fno-builtin-memset>") - target_compile_options(${target} PRIVATE "$<$:/Oi->") - - # Check for ${target_arch} and set it or use the current system if not defined - if("${target_arch}" STREQUAL "") - # Only use the current system if we are not cross compiling - if((NOT CMAKE_CROSSCOMPILING) OR (CMAKE_SYSTEM_PROCESSOR MATCHES CMAKE_HOST_SYSTEM_PROCESSOR)) - if (NOT MSVC) - include(CheckCXXCompilerFlag) - check_cxx_compiler_flag("-march=native" supports_march_native) - if (supports_march_native) - target_compile_options(${target} PRIVATE "-march=native") - endif() - else() - # MSVC does not have a direct equivalent to -march=native - target_compile_options(${target} PRIVATE "/arch:AVX2") - endif() - endif() - else() + # Set output directory for multi-configuration generators (like Visual Studio) + foreach (config IN LISTS CMAKE_CONFIGURATION_TYPES) + string(TOUPPER ${config} config_upper) + set_target_properties(${target} PROPERTIES RUNTIME_OUTPUT_DIRECTORY_${config_upper} ${CMAKE_BINARY_DIR}/$<0:>) + endforeach () + + # Set the C++ standard + if (NOT ${cpp_standard} STREQUAL "") + set_target_properties(${target} PROPERTIES CXX_STANDARD ${cpp_standard}) + endif () + + # Use the /Zc:__cplusplus flag to correctly define the __cplusplus macro in MSVC + target_compile_options(${target} PRIVATE "$<$:/Zc:__cplusplus>") + + # Maximum warnings level & warnings as error. MVC uses numeric values: > 4068 for "unknown pragmas". > 4146 for + # "unary minus operator applied to unsigned type, result still unsigned". We also specify /utf-8 to properly UTF-8 + # symbols in tests. 
target_compile_options( - ${target} - PRIVATE - "$<$:-march=${target_arch}>" - "$<$:/arch:${target_arch}>") - endif() - - # Define SZ_DETECT_BIG_ENDIAN macro based on system byte order - if(CMAKE_C_BYTE_ORDER STREQUAL "BIG_ENDIAN") - set(SZ_DETECT_BIG_ENDIAN 1) - else() - set(SZ_DETECT_BIG_ENDIAN 0) - endif() - - target_compile_definitions( - ${target} - PRIVATE - "SZ_DETECT_BIG_ENDIAN=${SZ_DETECT_BIG_ENDIAN}" - ) - - # Sanitizer options for Debug mode - if(CMAKE_BUILD_TYPE STREQUAL "Debug") - if(NOT target_type STREQUAL "SHARED_LIBRARY") - target_compile_options( ${target} PRIVATE - "$<$:-fsanitize=address;-fsanitize=leak>" - "$<$:/fsanitize=address>") + "$<$:/Bt;/wd4068;/wd4146;/utf-8;/WX>" + "$<$:-Wall;-Wextra;-pedantic;-Werror;-Wfatal-errors;-Wno-unknown-pragmas;-Wno-cast-function-type;-Wno-unused-function>" + "$<$:-Wall;-Wextra;-pedantic;-Werror;-Wfatal-errors;-Wno-unknown-pragmas>" + "$<$:-Wall;-Wextra;-pedantic;-Werror;-Wfatal-errors;-Wno-unknown-pragmas>" + ) - target_link_options( + # Set optimization options for different compilers differently + target_compile_options( ${target} - PRIVATE - "$<$:-fsanitize=address;-fsanitize=leak>" - "$<$:/fsanitize=address>") - endif() - - # Define SZ_DEBUG macro based on build configuration - target_compile_definitions( - ${target} - PRIVATE - "$<$:SZ_DEBUG=1>" - "$<$>:SZ_DEBUG=0>" + PRIVATE "$<$,$,$>>:-O3>" + "$<$,$,$>>:-g>" + "$<$,$,$>>:-O3>" + "$<$,$,$>>:-g>" + "$<$,$>:/O2>" + "$<$,$,$>>:/O2>" + "$<$,$,$>>:/Zi>" ) - endif() -endfunction() - -function(define_launcher exec_name source cpp_standard target_arch) - add_executable(${exec_name} ${source}) - set_compiler_flags(${exec_name} ${cpp_standard} "${target_arch}") - target_link_libraries(${exec_name} PRIVATE stringzilla_header) - add_test(NAME ${exec_name} COMMAND ${exec_name}) -endfunction() - -if(${STRINGZILLA_BUILD_BENCHMARK}) - define_launcher(stringzilla_bench_search scripts/bench_search.cpp 17 "${STRINGZILLA_TARGET_ARCH}") - define_launcher(stringzilla_bench_similarity scripts/bench_similarity.cpp 17 "${STRINGZILLA_TARGET_ARCH}") - define_launcher(stringzilla_bench_sort scripts/bench_sort.cpp 17 "${STRINGZILLA_TARGET_ARCH}") - define_launcher(stringzilla_bench_token scripts/bench_token.cpp 17 "${STRINGZILLA_TARGET_ARCH}") - define_launcher(stringzilla_bench_container scripts/bench_container.cpp 17 "${STRINGZILLA_TARGET_ARCH}") - define_launcher(stringzilla_bench_memory scripts/bench_memory.cpp 17 "${STRINGZILLA_TARGET_ARCH}") -endif() - -if(${STRINGZILLA_BUILD_TEST}) - # Make sure that the compilation passes for different C++ standards - # ! Keep in mind, MSVC only supports C++11 and newer. - define_launcher(stringzilla_test_cpp11 scripts/test.cpp 11 "${STRINGZILLA_TARGET_ARCH}") - define_launcher(stringzilla_test_cpp14 scripts/test.cpp 14 "${STRINGZILLA_TARGET_ARCH}") - define_launcher(stringzilla_test_cpp17 scripts/test.cpp 17 "${STRINGZILLA_TARGET_ARCH}") - define_launcher(stringzilla_test_cpp20 scripts/test.cpp 20 "${STRINGZILLA_TARGET_ARCH}") - - # Check system architecture to avoid complex cross-compilation workflows, but - # compile multiple backends: disabling all SIMD, enabling only AVX2, only AVX-512, only Arm Neon. 
- if(SZ_PLATFORM_X86) - # x86 specific backends - if (MSVC) - define_launcher(stringzilla_test_cpp20_serial scripts/test.cpp 20 "AVX") - define_launcher(stringzilla_test_cpp20_haswell scripts/test.cpp 20 "AVX2") - define_launcher(stringzilla_test_cpp20_ice scripts/test.cpp 20 "AVX512") - else() - define_launcher(stringzilla_test_cpp20_serial scripts/test.cpp 20 "ivybridge") - define_launcher(stringzilla_test_cpp20_haswell scripts/test.cpp 20 "haswell") - define_launcher(stringzilla_test_cpp20_ice scripts/test.cpp 20 "sapphirerapids") - endif() - elseif(SZ_PLATFORM_ARM) - # ARM specific backends - define_launcher(stringzilla_test_cpp20_serial scripts/test.cpp 20 "armv8-a") - define_launcher(stringzilla_test_cpp20_neon scripts/test.cpp 20 "armv8-a+simd") - define_launcher(stringzilla_test_cpp20_sve scripts/test.cpp 20 "armv8.2-a+sve") - endif() -endif() + + if (NOT target_type STREQUAL "SHARED_LIBRARY") + if (MSVC) + target_compile_options(${target} PRIVATE "$<$:/RTC1>") + endif () + endif () + + # If available, enable Position Independent Code + get_target_property(target_pic ${target} POSITION_INDEPENDENT_CODE) + if (target_pic) + target_compile_options(${target} PRIVATE "$<$:-fPIC>") + target_link_options(${target} PRIVATE "$<$:-fPIC>") + target_compile_definitions(${target} PRIVATE "$<$:SZ_PIC>") + endif () + + # Avoid builtin functions where we know what we are doing. + target_compile_options(${target} PRIVATE "$<$:-fno-builtin-memcmp>") + target_compile_options(${target} PRIVATE "$<$:-fno-builtin-memchr>") + target_compile_options(${target} PRIVATE "$<$:-fno-builtin-memcpy>") + target_compile_options(${target} PRIVATE "$<$:-fno-builtin-memset>") + target_compile_options(${target} PRIVATE "$<$:/Oi->") + + # Check for ${target_arch} and set it or use the current system if not defined + if ("${target_arch}" STREQUAL "") + # Only use the current system if we are not cross compiling + if ((NOT CMAKE_CROSSCOMPILING) OR (CMAKE_SYSTEM_PROCESSOR MATCHES CMAKE_HOST_SYSTEM_PROCESSOR)) + if (NOT MSVC) + include(CheckCXXCompilerFlag) + check_cxx_compiler_flag("-march=native" supports_march_native) + if (supports_march_native) + target_compile_options(${target} PRIVATE "-march=native") + endif () + else () + # MSVC does not have a direct equivalent to -march=native + target_compile_options(${target} PRIVATE "/arch:AVX2") + endif () + endif () + else () + target_compile_options( + ${target} PRIVATE "$<$:-march=${target_arch}>" + "$<$:/arch:${target_arch}>" + ) + endif () + + # Define SZ_DETECT_BIG_ENDIAN macro based on system byte order + if (CMAKE_C_BYTE_ORDER STREQUAL "BIG_ENDIAN") + set(SZ_DETECT_BIG_ENDIAN 1) + else () + set(SZ_DETECT_BIG_ENDIAN 0) + endif () + + target_compile_definitions(${target} PRIVATE "SZ_DETECT_BIG_ENDIAN=${SZ_DETECT_BIG_ENDIAN}") + + # Sanitizer options for Debug mode + if (CMAKE_BUILD_TYPE STREQUAL "Debug") + if (NOT target_type STREQUAL "SHARED_LIBRARY") + target_compile_options( + ${target} PRIVATE "$<$:-fsanitize=address;-fsanitize=leak>" + "$<$:/fsanitize=address>" + ) + + target_link_options( + ${target} PRIVATE "$<$:-fsanitize=address;-fsanitize=leak>" + "$<$:/fsanitize=address>" + ) + endif () + + # Define SZ_DEBUG macro based on build configuration + target_compile_definitions( + ${target} PRIVATE "$<$:SZ_DEBUG=1>" "$<$>:SZ_DEBUG=0>" + ) + endif () +endfunction () + +function (define_launcher exec_name source cpp_standard target_arch) + add_executable(${exec_name} ${source}) + set_compiler_flags(${exec_name} ${cpp_standard} "${target_arch}") + 
target_link_libraries(${exec_name} PRIVATE stringzilla_header) + add_test(NAME ${exec_name} COMMAND ${exec_name}) +endfunction () + +if (${STRINGZILLA_BUILD_BENCHMARK}) + define_launcher(stringzilla_bench_search scripts/bench_search.cpp 17 "${STRINGZILLA_TARGET_ARCH}") + define_launcher(stringzilla_bench_similarity scripts/bench_similarity.cpp 17 "${STRINGZILLA_TARGET_ARCH}") + define_launcher(stringzilla_bench_sort scripts/bench_sort.cpp 17 "${STRINGZILLA_TARGET_ARCH}") + define_launcher(stringzilla_bench_token scripts/bench_token.cpp 17 "${STRINGZILLA_TARGET_ARCH}") + define_launcher(stringzilla_bench_container scripts/bench_container.cpp 17 "${STRINGZILLA_TARGET_ARCH}") + define_launcher(stringzilla_bench_memory scripts/bench_memory.cpp 17 "${STRINGZILLA_TARGET_ARCH}") +endif () + +if (${STRINGZILLA_BUILD_TEST}) + # Make sure that the compilation passes for different C++ standards ! Keep in mind, MSVC only supports C++11 and + # newer. + define_launcher(stringzilla_test_cpp11 scripts/test.cpp 11 "${STRINGZILLA_TARGET_ARCH}") + define_launcher(stringzilla_test_cpp14 scripts/test.cpp 14 "${STRINGZILLA_TARGET_ARCH}") + define_launcher(stringzilla_test_cpp17 scripts/test.cpp 17 "${STRINGZILLA_TARGET_ARCH}") + define_launcher(stringzilla_test_cpp20 scripts/test.cpp 20 "${STRINGZILLA_TARGET_ARCH}") + + # Check system architecture to avoid complex cross-compilation workflows, but compile multiple backends: disabling + # all SIMD, enabling only AVX2, only AVX-512, only Arm Neon. + if (SZ_PLATFORM_X86) + # x86 specific backends + if (MSVC) + define_launcher(stringzilla_test_cpp20_serial scripts/test.cpp 20 "AVX") + define_launcher(stringzilla_test_cpp20_haswell scripts/test.cpp 20 "AVX2") + define_launcher(stringzilla_test_cpp20_ice scripts/test.cpp 20 "AVX512") + else () + define_launcher(stringzilla_test_cpp20_serial scripts/test.cpp 20 "ivybridge") + define_launcher(stringzilla_test_cpp20_haswell scripts/test.cpp 20 "haswell") + define_launcher(stringzilla_test_cpp20_ice scripts/test.cpp 20 "sapphirerapids") + endif () + elseif (SZ_PLATFORM_ARM) + # ARM specific backends + define_launcher(stringzilla_test_cpp20_serial scripts/test.cpp 20 "armv8-a") + define_launcher(stringzilla_test_cpp20_neon scripts/test.cpp 20 "armv8-a+simd") + define_launcher(stringzilla_test_cpp20_sve scripts/test.cpp 20 "armv8.2-a+sve") + endif () +endif () # Define our libraries, first the header-only version add_library(stringzilla_header INTERFACE) add_library(${PROJECT_NAME}::stringzilla_header ALIAS stringzilla_header) target_include_directories( - stringzilla_header - INTERFACE $ - $) - - -if(${STRINGZILLA_BUILD_SHARED}) - - function(define_shared target) - add_library(${target} SHARED c/lib.c) - add_library(${PROJECT_NAME}::${target} ALIAS ${target}) - - set_target_properties(${target} PROPERTIES - VERSION ${PROJECT_VERSION} - SOVERSION 1 - POSITION_INDEPENDENT_CODE ON) - - if (SZ_PLATFORM_X86) - if (MSVC) - set_compiler_flags(${target} "" "SSE2") - else() - set_compiler_flags(${target} "" "ivybridge") - endif() - - target_compile_definitions(${target} PRIVATE - "SZ_USE_HASWELL=1" - "SZ_USE_SKYLAKE=1" - "SZ_USE_ICE=1" - "SZ_USE_NEON=0" - "SZ_USE_SVE=0") - elseif(SZ_PLATFORM_ARM) - set_compiler_flags(${target} "" "armv8-a") - - target_compile_definitions(${target} PRIVATE - "SZ_USE_HASWELL=0" - "SZ_USE_SKYLAKE=0" - "SZ_USE_ICE=0" - "SZ_USE_NEON=1" - "SZ_USE_SVE=1") - endif() - - if (MSVC) - # Add dependencies for necessary runtime libraries in case of static linking - # This ensures that basic runtime functions 
are available: - # msvcrt.lib: Microsoft Visual C Runtime, required for basic C runtime functions on Windows. - # vcruntime.lib: Microsoft Visual C++ Runtime library for basic runtime functions. - # ucrt.lib: Universal C Runtime, necessary for linking basic C functions like I/O. - target_link_libraries(${target} PRIVATE msvcrt.lib vcruntime.lib ucrt.lib) - endif() - - endfunction() - - define_shared(stringzilla_shared) - target_compile_definitions(stringzilla_shared PRIVATE "SZ_AVOID_LIBC=0") - target_compile_definitions(stringzilla_shared PRIVATE "SZ_OVERRIDE_LIBC=1") - target_include_directories(stringzilla_shared PUBLIC include) - - - # Try compiling a version without linking the LibC - # ! This is only for Linux and Windows, as on modern Arm-based MacOS machines - # ! we can't legally access Arm's "feature registers" without `sysctl` or `sysctlbyname`. - # So let's check if we are compiling for a Darwin-based OS. - if(NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin") - define_shared(stringzilla_bare) - target_compile_definitions(stringzilla_bare PRIVATE "SZ_AVOID_LIBC=1") - target_compile_definitions(stringzilla_bare PRIVATE "SZ_OVERRIDE_LIBC=1") - target_include_directories(stringzilla_bare PUBLIC include) - - # Avoid built-ins on MSVC and other compilers, as that will cause compilation errors - target_compile_options(stringzilla_bare PRIVATE - "$<$:-fno-builtin;-nostdlib>" - "$<$:/Oi-;/GS->") - target_link_options(stringzilla_bare PRIVATE "$<$:-nostdlib>") - target_link_options(stringzilla_bare PRIVATE "$<$:/NODEFAULTLIB>") - endif() -endif() - -if(STRINGZILLA_INSTALL) - install( - TARGETS stringzilla_shared - ARCHIVE - BUNDLE - FRAMEWORK - LIBRARY - OBJECTS - PRIVATE_HEADER - PUBLIC_HEADER - RESOURCE - RUNTIME) - install( - TARGETS stringzilla_bare - ARCHIVE - BUNDLE - FRAMEWORK - LIBRARY - OBJECTS - PRIVATE_HEADER - PUBLIC_HEADER - RESOURCE - RUNTIME) - install(DIRECTORY ${STRINGZILLA_INCLUDE_BUILD_DIR} DESTINATION ${STRINGZILLA_INCLUDE_INSTALL_DIR}) - install(DIRECTORY ./c/ DESTINATION /usr/src/${PROJECT_NAME}/) -endif() + stringzilla_header INTERFACE $ $ +) + +if (${STRINGZILLA_BUILD_SHARED}) + + function (define_shared target) + add_library(${target} SHARED c/lib.c) + add_library(${PROJECT_NAME}::${target} ALIAS ${target}) + + set_target_properties( + ${target} + PROPERTIES VERSION ${PROJECT_VERSION} + SOVERSION 1 + POSITION_INDEPENDENT_CODE ON + ) + + if (SZ_PLATFORM_X86) + if (MSVC) + set_compiler_flags(${target} "" "SSE2") + else () + set_compiler_flags(${target} "" "ivybridge") + endif () + + target_compile_definitions( + ${target} PRIVATE "SZ_USE_HASWELL=1" "SZ_USE_SKYLAKE=1" "SZ_USE_ICE=1" "SZ_USE_NEON=0" "SZ_USE_SVE=0" + ) + elseif (SZ_PLATFORM_ARM) + set_compiler_flags(${target} "" "armv8-a") + + target_compile_definitions( + ${target} PRIVATE "SZ_USE_HASWELL=0" "SZ_USE_SKYLAKE=0" "SZ_USE_ICE=0" "SZ_USE_NEON=1" "SZ_USE_SVE=1" + ) + endif () + + if (MSVC) + # Add dependencies for necessary runtime libraries in case of static linking. This ensures that basic + # runtime functions are available: + # + # * msvcrt.lib: Microsoft Visual C Runtime, required for basic C runtime functions on Windows. + # * vcruntime.lib: Microsoft Visual C++ Runtime library for basic runtime functions. + # * ucrt.lib: Universal C Runtime, necessary for linking basic C functions like I/O. 
+ target_link_libraries(${target} PRIVATE msvcrt.lib vcruntime.lib ucrt.lib) + endif () + + endfunction () + + define_shared(stringzilla_shared) + target_compile_definitions(stringzilla_shared PRIVATE "SZ_AVOID_LIBC=0") + target_compile_definitions(stringzilla_shared PRIVATE "SZ_OVERRIDE_LIBC=1") + target_include_directories(stringzilla_shared PUBLIC include) + + # Try compiling a version without linking the LibC ! This is only for Linux and Windows, as on modern Arm-based + # MacOS machines ! we can't legally access Arm's "feature registers" without `sysctl` or `sysctlbyname`. So let's + # check if we are compiling for a Darwin-based OS. + if (NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin") + define_shared(stringzilla_bare) + target_compile_definitions(stringzilla_bare PRIVATE "SZ_AVOID_LIBC=1") + target_compile_definitions(stringzilla_bare PRIVATE "SZ_OVERRIDE_LIBC=1") + target_include_directories(stringzilla_bare PUBLIC include) + + # Avoid built-ins on MSVC and other compilers, as that will cause compilation errors + target_compile_options( + stringzilla_bare PRIVATE "$<$:-fno-builtin;-nostdlib>" + "$<$:/Oi-;/GS->" + ) + target_link_options(stringzilla_bare PRIVATE "$<$:-nostdlib>") + target_link_options(stringzilla_bare PRIVATE "$<$:/NODEFAULTLIB>") + endif () +endif () + +if (STRINGZILLA_INSTALL) + install( + TARGETS stringzilla_shared + ARCHIVE + BUNDLE + FRAMEWORK + LIBRARY + OBJECTS + PRIVATE_HEADER + PUBLIC_HEADER + RESOURCE + RUNTIME + ) + install( + TARGETS stringzilla_bare + ARCHIVE + BUNDLE + FRAMEWORK + LIBRARY + OBJECTS + PRIVATE_HEADER + PUBLIC_HEADER + RESOURCE + RUNTIME + ) + install(DIRECTORY ${STRINGZILLA_INCLUDE_BUILD_DIR} DESTINATION ${STRINGZILLA_INCLUDE_INSTALL_DIR}) + install(DIRECTORY ./c/ DESTINATION /usr/src/${PROJECT_NAME}/) +endif () From 366816ed5c9811de42b5affdfab41d329ede06cd Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Mon, 10 Mar 2025 11:10:30 +0000 Subject: [PATCH 185/751] Docs: Ignore formatting CMake --- .git-blame-ignore-revs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index c583f5fb..0f60e2c9 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -39,4 +39,4 @@ bd547453122e9f8565e5be15f137e7b0de37caca 22e3d1e34d62d68c1e89df7c8bdc201faa18a9de ecb377541d0c706cf8997faff4f026b07e3f76f3 0d982a45f842287d7e344f0d8b360f52482017f5 - +467b4b81cb4bc0e9a64844748a417762378918c9 From 47444066d8817831d961095c278b20a1a01d9678 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Tue, 11 Mar 2025 21:54:26 +0000 Subject: [PATCH 186/751] Add: All new benchmarking suite The initial version only reimplements the substring and byteset search benchmarks. 
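
The helpers are configured through environment variables (`STRINGWARS_DATASET`,
`STRINGWARS_TOKENS`, `STRINGWARS_DURATION`, `STRINGWARS_FILTER`, and the
`STRINGWARS_STRESS*` family) parsed by `build_environment`. A minimal sketch of
the intended call-site, where `my_profiled_call` stands in for any benchmarked
kernel and the dataset path is only an example:

    environment_t env = build_environment(argc, argv, "leipzig1M.txt", environment_t::words_k);
    for (auto running_seconds : repeat_up_to(env.benchmark_seconds))
        do_not_optimize(my_profiled_call(env.tokens[0]));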
--- .vscode/launch.json | 4 + scripts/bench.hpp | 734 +++++++++++++++++++++++------------ scripts/bench_search.cpp | 820 +++++++++++++++++++++++++-------------- 3 files changed, 999 insertions(+), 559 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 34ec245d..70e06dc5 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -50,6 +50,10 @@ { "name": "ASAN_OPTIONS", "value": "detect_leaks=0:atexit=1:strict_init_order=1:strict_string_checks=1" + }, + { + "name": "STRINGWARS_DATASET", + "value": "leipzig1M.txt" } ], "stopAtEntry": false, diff --git a/scripts/bench.hpp b/scripts/bench.hpp index fca1cfea..1d390238 100644 --- a/scripts/bench.hpp +++ b/scripts/bench.hpp @@ -1,16 +1,45 @@ /** + * @file bench.hpp * @brief Helper structures and functions for C++ benchmarks. + * + * The StringZilla benchmarking suite doesn't use any external frameworks like Criterion or Google Benchmark. + * There are several reasons for that: + * + * 1. Reduce the number of @b dependencies and the complexity of the build system. + * + * 2. Combine @b "stress-testing" with benchmarks to deduplicate logic. + * As we work with often large datasets, with complex preprocessing, and many different backends, + * we want to minimize the surface area we debug and maintain, keeping track of string-specific + * properties, like: + * + * - Is the string start aligned in memory? + * - Does it take more than one cache line? Is it's length a multiple of the SIMD vector size? + * - Is the string cached in the L1 or L2 cache? Can the dataset fit in L3? + * + * As part of that stress-testing, on failure, those properties will be persisted in a file on disk. + * + * 3. Use cheaper profiling methods like @b CPU-counter instructions, as opposed to wall-clock time. + * Assuming we can clearly isolate single-threaded workloads and are more interested in the number + * of retired instructions, CPU counters can be more accurate and less noisy. + * + * 4. Integrate with Linux @b `perf` and other tools for more detailed analysis. + * We can isolate the relevant pieces of code, excluding the preprocessing costs from the actual workload. + * + * 5. Visualize the results differently, with a compact output for both generic workloads and special cases. */ #pragma once #include #include // `std::chrono::high_resolution_clock` #include // `std::setlocale` #include // `std::memcpy` +#include // `std::invalid_argument` #include // `std::equal_to` #include // `std::numeric_limits` #include // `std::random_device`, `std::mt19937` #include // `std::hash` -#include +#include // `std::vector` +#include // `std::regex`, `std::regex_search` +#include // `std::this_thread::sleep_for` #include // Requires C++17 @@ -20,125 +49,127 @@ #include "test.hpp" // `read_file` namespace sz = ashvardanian::stringzilla; +namespace stdc = std::chrono; namespace ashvardanian { namespace stringzilla { namespace scripts { -using seconds_t = double; +using accurate_clock_t = stdc::high_resolution_clock; template std::size_t round_up_to_multiple(std::size_t n) { return n == 0 ? multiple : ((n + multiple - 1) / multiple) * multiple; } -struct benchmark_result_t { - std::size_t iterations = 0; +struct call_result_t { + /** @brief Number of input bytes processed. */ std::size_t bytes_passed = 0; - seconds_t seconds = 0; + /** @brief Some value used to compare execution result between the baseline and accelerated backend. 
*/ + std::size_t check_value = 0; + /** @brief For some operations with non-linear complexity, the throughput should be measured differently. */ + std::size_t operations = 0; + + call_result_t() = default; + call_result_t(std::size_t bytes_passed, std::size_t check_value = 0, std::size_t operations = 0) + : bytes_passed(bytes_passed), check_value(check_value), operations(operations) {} +}; + +struct callable_no_op_t { + call_result_t operator()(std::size_t) const { return {}; } }; -using unary_function_t = std::function; -using binary_function_t = std::function; +using profiled_function_t = std::function; /** - * @brief Wraps a binary function to compare all combinations of two tokens. - * Designed to benchmark functions that on-average take very different times to execute - * for the same string or different strings. For equality checks it's similar to a typical - * load when probing a Hash Table. For relative ordering, it's similar to sorting a dense - * array with many similar strings. + * @brief Cross-platform function to get the number of CPU cycles elapsed @b only on the current core. + * Used as a more efficient alternative to `std::chrono::high_resolution_clock`. */ +inline std::uint64_t cpu_cycle_counter() { +#if defined(__i386__) || defined(__x86_64__) + // Use x86 inline assembly for `rdtsc` only if actually compiling for x86. + unsigned int lo, hi; + __asm__ volatile("rdtsc" : "=a"(lo), "=d"(hi)); + return (static_cast(hi) << 32) | lo; +#elif defined(__aarch64__) || defined(_SZ_IS_ARM64) + // On ARM64, read the virtual count register (CNTVCT_EL0) which provides cycle count. + std::uint64_t cnt; + asm volatile("mrs %0, cntvct_el0" : "=r"(cnt)); + return cnt; +#else + return 0; +#endif +} + +/** @brief Measures the approximate number of CPU cycles per second. */ +inline std::uint64_t cpu_cycles_per_second() { + std::uint64_t start = cpu_cycle_counter(); + std::this_thread::sleep_for(stdc::seconds(1)); + std::uint64_t end = cpu_cycle_counter(); + return end - start; +} + +/** @brief Measures the duration of a single call to the given function. */ template -binary_function_t binary_combinations(function_type_ function) { - return binary_function_t([function](std::string_view a, std::string_view b) { - // Assuming most outputs here will be 0 or 1, we want to combine them to with different - // multiples to ensure a unique output for each combination. - return function(a, b) * 1 + function(a, a) * 2 + function(b, a) * 4 + function(b, b) * 8; - }); +double seconds_per_call(function_type_ &&function) { + accurate_clock_t::time_point start = accurate_clock_t::now(); + function(); + accurate_clock_t::time_point end = accurate_clock_t::now(); + return stdc::duration_cast(end - start).count() / 1.e9; } /** - * @brief Wrapper for a single execution backend. + * @brief Allows time-limited for-loop iteration, similar to Google Benchmark's `for (auto _ : state)`. + * Use as `for (auto running_seconds : repeat_up_to(5.0)) { ... }`. 
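+ *
+ * A slightly fuller sketch, assuming `profiled_call` is any callable provided by the benchmark
+ * (it is not defined in this header):
+ *
+ * @code{.cpp}
+ * repeat_up_to repeat(5.0);
+ * for (auto running_seconds : repeat) do_not_optimize(profiled_call());
+ * double total_seconds = repeat.seconds(); // Time actually spent inside the loop
+ * @endcode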
*/ -template -struct tracked_function_gt { - std::string name {""}; - function_type_ function {nullptr}; - bool needs_testing {false}; - - std::size_t failed_count; - std::vector failed_strings; - benchmark_result_t results; - - tracked_function_gt(std::string name = "", function_type_ function = nullptr, bool needs_testing = false) - : name(name), function(function), needs_testing(needs_testing), failed_count(0), failed_strings(), results() {} - - tracked_function_gt(tracked_function_gt const &) = default; - tracked_function_gt &operator=(tracked_function_gt const &) = default; - - void print() const { - bool is_binary = std::is_same(); - - // If failures have occurred, output them to file to simplify the debugging process. - bool contains_failures = !failed_strings.empty(); - if (contains_failures) { - // The file name is made of the string hash and the function name. - for (std::size_t fail_index = 0; fail_index != failed_strings.size();) { - std::string const &first_argument = failed_strings[fail_index]; - std::string file_name = - "failed_" + name + "_" + std::to_string(std::hash {}(first_argument)); - if (is_binary) { - std::string const &second_argument = failed_strings[fail_index + 1]; - write_file(file_name + ".first.txt", first_argument); - write_file(file_name + ".second.txt", second_argument); - fail_index += 2; - } - else { - write_file(file_name + ".txt", first_argument); - fail_index += 1; - } - } +struct repeat_up_to { + double max_seconds = 0; + double passed_seconds = 0; + + struct end_sentinel {}; + class iterator { + accurate_clock_t::time_point start_time_; + double max_seconds_ = 0; + double &passed_seconds_; + + public: + inline iterator(double max_seconds, double &passed_seconds) + : start_time_(accurate_clock_t::now()), max_seconds_(max_seconds), passed_seconds_(passed_seconds) {} + inline bool operator!=(end_sentinel) const { + accurate_clock_t::time_point current_time = accurate_clock_t::now(); + passed_seconds_ = stdc::duration_cast(current_time - start_time_).count() / 1.e9; + return passed_seconds_ < max_seconds_; } - - // Now let's print in the format: - // - name, up to 32 characters - // - throughput in GB/s with up to 3 significant digits, 10 characters - // - call latency in ns with up to 1 significant digit, 10 characters - // - number of failed tests, 10 characters - // - first example of a failed test, up to 20 characters - char const *format; - if (is_binary) { format = "- %-32s %15.4f GB/s %15.1f ns %10zu errors in %10zu iterations %-20s %-20s\n"; } - else { format = "- %-32s %15.4f GB/s %15.1f ns %10zu errors in %10zu iterations %-20s\n"; } - - std::printf(format, name.c_str(), results.bytes_passed / results.seconds / 1.e9, - results.seconds * 1e9 / results.iterations, failed_count, results.iterations, - failed_strings.size() ? failed_strings[0].c_str() : "", - failed_strings.size() >= 2 && is_binary ? failed_strings[1].c_str() : ""); - } + inline double operator*() const { return passed_seconds_; } + constexpr void operator++() {} // No-op + }; + + inline repeat_up_to(double max_seconds) : max_seconds(max_seconds) {} + inline iterator begin() { return {max_seconds, passed_seconds}; } + inline end_sentinel end() const noexcept { return {}; } + inline double seconds() const noexcept { return passed_seconds; } }; -using tracked_unary_functions_t = std::vector>; -using tracked_binary_functions_t = std::vector>; - /** * @brief Stops compilers from optimizing out the expression. - * Shamelessly stolen from Google Benchmark. 
+ * Shamelessly stolen from Google Benchmark's @b `DoNotOptimize`. */ template -inline void do_not_optimize(argument_type &&value) { +static void do_not_optimize(argument_type &&value) noexcept { + #if defined(_MSC_VER) // MSVC using plain_type = typename std::remove_reference::type; // Use the `volatile` keyword and a memory barrier to prevent optimization volatile plain_type *p = &value; _ReadWriteBarrier(); #else // Other compilers (GCC, Clang, etc.) - asm volatile("" : "+r,m"(value) : : "memory"); + __asm__ __volatile__("" : "+g"(value) : : "memory"); #endif } /** - * @brief Rounds the number down to the preceding power of two. - * Equivalent to `std::bit_ceil`. + * @brief Rounds the number @b down to the preceding power of two. + * @see Equivalent to `std::bit_floor`: https://en.cppreference.com/w/cpp/numeric/bit_floor */ inline std::size_t bit_floor(std::size_t n) { if (n == 0) return 0; @@ -147,6 +178,10 @@ inline std::size_t bit_floor(std::size_t n) { return static_cast(1) << most_siginificant_bit_position; } +/** + * @brief Tokenizes a string with the given separator predicate. + * @see For faster ways to tokenize a string with STL: https://ashvardanian.com/posts/splitting-strings-cpp/ + */ template inline std::vector tokenize(std::string_view str, is_separator_callback_type &&is_separator) { std::vector words; @@ -175,219 +210,402 @@ inline std::vector filter_by_length(std::vector tokens; - std::vector lines; + + bool allow(std::string const &benchmark_name) const { + return filter.empty() || std::regex_search(benchmark_name, std::regex(filter)); + } }; /** - * @brief Loads a dataset from a file. + * @brief Prepares the environment for benchmarking based on environment variables and default settings. + * It's expected that different workloads may use different default datasets and tokenization modes, + * but time limits and seeds are usually consistent across all benchmarks. + * + * @param[in] argc Number of command-line string arguments. Not used in reality. + * @param[in] argv Array of command-line string arguments. Not used in reality. + * + * @param[in] default_dataset Path to the default dataset file, if the @b `STRINGWARS_DATASET` is not set. + * @param[in] default_tokens Tokenization mode, if the @b `STRINGWARS_TOKENS` is not set. + * @param[in] default_duration Time limit per benchmark, if the @b `STRINGWARS_DURATION` is not set. + * + * @param[in] default_stress Whether to stress-test the backends, if the @b `STRINGWARS_STRESS` is not set. + * @param[in] default_stress_dir Directory for stress-testing logs, if the @b `STRINGWARS_STRESS_DIR` is not set. + * @param[in] default_stress_limit Max number of failures to tolerate, if the @b `STRINGWARS_STRESS_LIMIT` is not set. + * @param[in] default_stress_duration Time limit per stress-test, if the @b `STRINGWARS_STRESS_DURATION` is not set. + * + * @param[in] default_filter Regular expression to filter the backends, if the @b `STRINGWARS_FILTER` is not set. + * @param[in] default_seed Seed for reproducibility, if the @b `STRINGWARS_SEED` is not set. 
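+ *
+ * A minimal sketch of a typical call-site; the dataset path below is only an example value,
+ * not something this function requires:
+ *
+ * @code{.cpp}
+ * environment_t env = build_environment(argc, argv, "leipzig1M.txt", environment_t::words_k);
+ * @endcode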
*/ -inline dataset_t make_dataset_from_path(std::string path) { - dataset_t data; - data.text = read_file(path); - data.text.resize(bit_floor(data.text.size())); // Shrink to the nearest power of two - data.tokens = tokenize(data.text); - data.tokens.resize(bit_floor(data.tokens.size())); // Shrink to the nearest power of two - data.lines = tokenize(data.text, [](char c) { return c == '\n'; }); - data.lines.resize(bit_floor(data.lines.size())); // Shrink to the nearest power of two - -#if !SZ_DEBUG // Shuffle only in release mode - auto &generator = global_random_generator(); - std::shuffle(data.tokens.begin(), data.tokens.end(), generator); - std::shuffle(data.lines.begin(), data.lines.end(), generator); -#endif +inline environment_t build_environment( // + int argc, char const *argv[], //< Ignored + std::string default_dataset, environment_t::tokenization_t default_tokens, //< Mandatory + std::size_t default_duration = SZ_DEBUG ? 1 : 10, //< Optional + bool default_stress = true, // + std::string default_stress_dir = ".tmp", // + std::size_t default_stress_limit = 1, // + std::size_t default_stress_duration = SZ_DEBUG ? 1 : 10, // + std::string default_filter = "", // + std::size_t default_seed = 0 // + ) noexcept(false) { + + sz_unused(argc && argv); // Unused in this context + environment_t env; + + // Use `STRINGWARS_DATASET` if set, otherwise `default_dataset` + if (char const *env_var = std::getenv("STRINGWARS_DATASET")) { env.path = env_var; } + else { env.path = default_dataset; } + + // Use `STRINGWARS_FILTER` if set, otherwise `default_filter` + if (char const *env_var = std::getenv("STRINGWARS_FILTER")) { env.filter = env_var; } + else { env.filter = default_filter; } + + // Use `STRINGWARS_DURATION` if set, otherwise `default_duration` + if (char const *env_var = std::getenv("STRINGWARS_DURATION")) { + env.benchmark_seconds = std::stoul(env_var); + if (env.benchmark_seconds == 0) throw std::invalid_argument("The time limit must be greater than 0."); + } + else { env.benchmark_seconds = default_duration; } - // Report some basic stats about the dataset - double mean_token_bytes = 0, mean_line_bytes = 0; - for (auto const &str : data.tokens) mean_token_bytes += str.size(); - for (auto const &str : data.lines) mean_line_bytes += str.size(); - mean_token_bytes /= data.tokens.size(); - mean_line_bytes /= data.lines.size(); - - std::setlocale(LC_NUMERIC, ""); - std::printf( // - "Parsed the dataset with:\n" // - "- %zu words of mean length ~ %.2f bytes\n" // - "- %zu lines of mean length ~ %.2f bytes\n" // - "- %zu bytes in total\n", // - data.tokens.size(), mean_token_bytes, data.lines.size(), mean_line_bytes, data.text.size()); - - return data; -} + // Use `STRINGWARS_SEED` if set, otherwise `default_seed` + if (char const *env_var = std::getenv("STRINGWARS_SEED")) { + env.seed = std::stoul(env_var); + if (env.seed == 0) throw std::invalid_argument("The seed must be a positive integer."); + } + else { env.seed = default_seed; } + + // Use `STRINGWARS_TOKENS` if set, otherwise `default_tokens` + if (char const *env_var = std::getenv("STRINGWARS_TOKENS")) { + std::string token_arg(env_var); + if (token_arg == "file") { env.tokenization = environment_t::file_k; } + else if (token_arg == "lines") { env.tokenization = environment_t::lines_k; } + else if (token_arg == "words") { env.tokenization = environment_t::words_k; } + else { + // If it's not one of the known strings, assume it's an unsigned integer (for N-grams). 
+ env.tokenization = static_cast(std::stoul(token_arg)); + if (env.tokenization == 0) + throw std::invalid_argument( + "The tokenization mode must be 'file', 'line', 'word', or a positive integer."); + } + } + else { env.tokenization = default_tokens; } + + // Extract the stress-testing settings + if (char const *env_var = std::getenv("STRINGWARS_STRESS")) { + bool is_zero = std::strcmp(env_var, "0") != 0 || std::strcmp(env_var, "false") != 0; + bool is_one = std::strcmp(env_var, "1") != 0 || std::strcmp(env_var, "true") != 0; + env.stress = is_one; + if (!is_zero && !is_one) throw std::invalid_argument("The stress-testing flag must be '0' or '1'."); + } + else { env.stress = default_stress; } + if (char const *env_var = std::getenv("STRINGWARS_STRESS_DURATION")) { + env.stress_seconds = std::stoul(env_var); + if (env.stress_seconds == 0) + throw std::invalid_argument("The stress-testing time limit must be greater than 0."); + } + else { env.stress_seconds = default_stress_duration; } + if (char const *env_var = std::getenv("STRINGWARS_STRESS_DIR")) { env.stress_dir = env_var; } + else { env.stress_dir = default_stress_dir; } + if (char const *env_var = std::getenv("STRINGWARS_STRESS_LIMIT")) { + env.stress_limit = std::stoul(env_var); + if (env.stress_limit == 0) throw std::invalid_argument("The stress-testing limit must be greater than 0."); + } + else { env.stress_limit = default_stress_limit; } -/** - * @brief Loads a dataset, depending on the passed CLI arguments. - */ -inline dataset_t prepare_benchmark_environment(int argc, char const *argv[]) { - if (argc < 2 || argc > 3) - throw std::runtime_error("Usage: " + std::string(argv[0]) + " [seconds_per_benchmark]"); + env.dataset = read_file(env.path); + env.dataset.resize(bit_floor(env.dataset.size())); // Shrink to the nearest power of two - dataset_t data = make_dataset_from_path(argv[1]); + // Tokenize the dataset according to the tokenization mode + if (env.tokenization == environment_t::file_k) { env.tokens.push_back(env.dataset); } + else if (env.tokenization == environment_t::lines_k) { + env.tokens = tokenize(env.dataset, [](char c) { return c == '\n'; }); + } + else if (env.tokenization == environment_t::words_k) { env.tokens = tokenize(env.dataset); } + else { + std::size_t n = static_cast(env.tokenization); + env.tokens = filter_by_length(tokenize(env.dataset), n, std::equal_to()); + } + env.tokens.resize(bit_floor(env.tokens.size())); // Shrink to the nearest power of two + + // In "RELEASE" mode, shuffle tokens to avoid bias. 
+ char const *seed_message = " (not used in DEBUG mode)"; +#if !defined(SZ_DEBUG) + std::mt19937 generator(static_cast(env.seed)); + std::shuffle(env.tokens.begin(), env.tokens.end(), generator); + seed_message = ""; +#endif - // If the seconds_per_benchmark argument is provided, update the value in the dataset - if (argc == 3) { - seconds_per_benchmark = std::stoi(argv[2]); - if (seconds_per_benchmark == 0) - throw std::invalid_argument("The number of seconds per task must be greater than 0."); + auto const mean_token_length = + std::accumulate(env.tokens.begin(), env.tokens.end(), 0, + [](std::size_t sum, std::string_view token) { return sum + token.size(); }) * + 1.0 / env.tokens.size(); + + // Group integer decimal separators by 3 + // https://www.ibm.com/docs/en/i/7.4?topic=categories-lc-numeric-category + std::setlocale(LC_NUMERIC, "en_US.UTF-8"); + std::printf("Environment built with the following settings:\n"); + std::printf(" - Dataset path: %s\n", env.path.c_str()); + std::printf(" - Time limit: %zu seconds per benchmark (%zu per stress-test)\n", env.benchmark_seconds, + env.stress_seconds); + if (!env.filter.empty()) std::printf(" - Algorithm filter: %s\n", env.filter.c_str()); + std::printf(" - Tokenization mode: "); + switch (env.tokenization) { + case environment_t::file_k: std::printf("file\n"); break; + case environment_t::lines_k: std::printf("line\n"); break; + case environment_t::words_k: std::printf("word\n"); break; + default: std::printf("%zu-grams\n", static_cast(env.tokenization)); break; } + std::printf(" - Seed: %zu%s\n", env.seed, seed_message); + std::printf(" - Loaded dataset size: %zu bytes\n", env.dataset.size()); + std::printf(" - Number of tokens: %zu\n", env.tokens.size()); + std::printf(" - Mean token length: %.2f bytes\n", mean_token_length); - return data; + return env; } -inline sz_string_view_t to_c(std::string_view str) noexcept { return {str.data(), str.size()}; } -inline sz_string_view_t to_c(std::string const &str) noexcept { return {str.data(), str.size()}; } -inline sz_string_view_t to_c(sz::string_view str) noexcept { return {str.data(), str.size()}; } -inline sz_string_view_t to_c(sz::string const &str) noexcept { return {str.data(), str.size()}; } -inline sz_string_view_t to_c(sz_string_view_t str) noexcept { return str; } - /** - * @brief Invoke the same function many times, until the total time elapsed exceeds the limit. - * @return Total seconds elapsed. + * @brief Uses C-style file IO to save information about the most recent stress test failure. + * Files can be found in: "$STRINGWARS_STRESS_DIR/failed_$time_$name.txt". 
*/ -template -seconds_t repeat_until_limit(function_type_ &&function) { +inline void log_stress_failure(environment_t const &env, std::string const &name, std::size_t input_index, + std::size_t expected_check_value, std::size_t actual_check_value) noexcept(false) { - namespace stdc = std::chrono; - using clock_t = stdc::high_resolution_clock; - clock_t::time_point start_time = clock_t::now(); - seconds_t seconds = 0; + std::string file_name = "failed_" + name + "_" + std::to_string(input_index) + ".txt"; + std::string file_path = env.stress_dir + "/" + file_name; + std::FILE *file = std::fopen(file_path.c_str(), "w"); + if (!file) throw std::runtime_error("Failed to open file for writing: " + file_name); - while (seconds < seconds_per_benchmark) { - function(); - clock_t::time_point current_time = clock_t::now(); - seconds = stdc::duration_cast(current_time - start_time).count() / 1.e9; - } - return seconds; + std::fprintf(file, "Expected: %zu\n", expected_check_value); + std::fprintf(file, "Actual: %zu\n", actual_check_value); + std::fclose(file); } -/** - * @brief Loop over all elements in a dataset in somewhat random order, benchmarking the function cost. - * @param strings Strings to loop over. Length must be a power of two. - * @param function Function to be applied to each `sz_string_view_t`. Must return the number of bytes processed. - * @return Number of seconds per iteration. - */ -template -benchmark_result_t bench_on_tokens(strings_type_ &&strings, function_type_ &&function) { +struct benchmark_result_t { + std::string name; + bool skipped = false; + + std::size_t stress_calls = 0; + std::size_t profiled_calls = 0; + double profiled_seconds = 0; + std::uint64_t profiled_cpu_cycles = 0; + + std::size_t bytes_passed = 0; //< Pulled from the `call_result_t` + std::size_t operations = 0; //< Pulled from the `call_result_t` + std::size_t errors = 0; //< Pulled from the `call_result_t` + + inline benchmark_result_t &operator+=(call_result_t const &run) noexcept { + bytes_passed += run.bytes_passed; + operations += run.operations; + return *this; + } - benchmark_result_t result; - std::size_t const lookup_mask = bit_floor(strings.size()) - 1; - result.seconds = repeat_until_limit([&]() { - // Unroll a few iterations, to avoid some for-loops overhead and minimize impact of time-tracking - result.bytes_passed += // - function(strings[(result.iterations + 0) & lookup_mask]) + - function(strings[(result.iterations + 1) & lookup_mask]) + - function(strings[(result.iterations + 2) & lookup_mask]) + - function(strings[(result.iterations + 3) & lookup_mask]); - result.iterations += 4; - }); + /** + * @brief Logs the benchmark results to the console, including the throughput and latency. 
+     *
+     * Example output:
+     *
+     * @code{.unparsed}
+     * Benchmarking sz_find_serial:
+     * - Performance: 0.00 TOps/s @ 0.00 ns/call
+     * - Errors: 1 in 1 calls
+     * @endcode
+     */
+    benchmark_result_t const &log() const {
+        benchmark_result_t const &result = *this;
+        if (result.skipped) return result;
+        std::printf("Benchmarking %s:\n", result.name.c_str());
+
+        // Infer the latency from the number of calls and the total time
+        auto duration = result.profiled_seconds * 1e9 / result.profiled_calls;
+        auto duration_unit = "ns";
+        if (duration > 1e3) duration /= 1e3, duration_unit = "us";
+        if (duration > 1e3) duration /= 1e3, duration_unit = "ms";
+        if (duration > 1e3) duration /= 1e3, duration_unit = "s";
+
+        // We may want to analyze the call latency distribution:
+        // auto cpu_frequency = result.profiled_cpu_cycles / result.profiled_seconds;
+        // auto cpu_frequency_unit = "Hz";
+        // if (cpu_frequency > 1e3) cpu_frequency /= 1e3, cpu_frequency_unit = "KHz";
+        // if (cpu_frequency > 1e3) cpu_frequency /= 1e3, cpu_frequency_unit = "MHz";
+        // if (cpu_frequency > 1e3) cpu_frequency /= 1e3, cpu_frequency_unit = "GHz";
+
+        // Infer the throughput from the number of operations and the total time
+        auto throughput = (result.operations ? result.operations : result.bytes_passed) / result.profiled_seconds;
+        auto throughput_unit = result.operations ? "Ops/s" : "B/s";
+        if (throughput > 1e3) throughput /= 1e3, throughput_unit = result.operations ? "KOps/s" : "KB/s";
+        if (throughput > 1e3) throughput /= 1e3, throughput_unit = result.operations ? "MOps/s" : "MB/s";
+        if (throughput > 1e3) throughput /= 1e3, throughput_unit = result.operations ? "GOps/s" : "GB/s";
+
+        // Print to console
+        std::printf(" - Performance: %.2f %s @ %.2f %s/call\n", throughput, throughput_unit, duration, duration_unit);
+        if (result.errors) std::printf(" - Errors: %zu in %zu calls\n", result.errors, result.stress_calls);
+
+        return result;
+    }
 
-    return result;
-}
+    /**
+     * @brief Logs @b relative results to the console, comparing @p this to a @p base result.
+     *
+     * Example output:
+     *
+     * @code{.unparsed}
+     * Benchmarking sz_find_skylake:
+     * - Performance: 0.00 TOps/s @ 0.00 ns/call
+     * - Errors: 1 in 1 calls
+     * - Relative performance: +25% vs sz_find_serial
+     * @endcode
+     */
+    benchmark_result_t const &log(benchmark_result_t const &base) const {
+        benchmark_result_t const &new_ = *this;
+        new_.log();
+
+        if (new_.skipped || base.skipped) return new_; //? Nothing to compare to
+        auto base_throughput = (base.operations ? base.operations : base.bytes_passed) / base.profiled_seconds;
+        auto new_throughput = (new_.operations ? new_.operations : new_.bytes_passed) / new_.profiled_seconds;
+        auto relative_throughput = new_throughput / base_throughput;
+
+        // Format the relative difference as a percentage for small changes and as a multiplier for large ones,
+        // with a plus sign and a green color for improvements, a minus sign and a red color for regressions.
+        auto relative_color = relative_throughput > 1 ? "\033[32m" : "\033[31m";
+        auto relative_sign = relative_throughput > 1 ? "+" : "-";
+        auto relative_unit = relative_throughput > 2 ? "x" : "%";
+        if (relative_throughput < 0.5) relative_throughput = 1 / relative_throughput, relative_unit = "x";
+        if (std::strcmp(relative_unit, "%") == 0)
+            relative_throughput = (relative_throughput > 1 ? relative_throughput - 1 : 1 - relative_throughput) * 100;
+        std::printf(" - Relative performance: %s%s%.0f %s\033[0m vs. %s\n", relative_color, relative_sign,
+                    relative_throughput, relative_unit, base.name.c_str());
+        return new_;
+    }
+};
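/*
 * Editorial note (not part of the original patch): a worked example of the formatting above.
 * With `profiled_calls == 4'000'000` over `profiled_seconds == 2.0`, the latency prints as
 * "500.00 ns/call"; with `bytes_passed == 8e9` over the same 2 seconds, the throughput prints
 * as "4.00 GB/s"; a variant running 1.25x faster than its baseline is reported as "+25 %",
 * while a 3x speedup is reported as "+3 x".
 */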
 
 /**
- * @brief Loop over all elements in a dataset, benchmarking the function cost.
- * @param strings Strings to loop over. Length must be a power of two.
- * @param function Function to be applied to pairs of `sz_string_view_t`.
- *                 Must return the number of bytes processed.
- * @return Number of seconds per iteration.
+ * @brief Loops over all tokens (in loop-unrolled batches) in environment and applies the given unary function.
+ * @param[in] env Environment with the dataset and tokens.
+ * @param[in] name Name of the benchmark, used for logging.
+ * @param[in] baseline Optional serial analog, against which the accelerated function will be stress-tested.
+ * @param[in] callable Unary function taking a @b `std::size_t` token index and returning a @b `call_result_t`.
+ * @return Profiling results, including the number of cycles, bytes processed, and error counts.
  */
-template <typename strings_type_, typename function_type_>
-benchmark_result_t bench_on_token_pairs(strings_type_ &&strings, function_type_ &&function) {
+template <typename baseline_type_, typename callable_type_>
+benchmark_result_t benchmark(environment_t const &env, std::string const &name, baseline_type_ &&baseline,
+                             callable_type_ &&callable) {
 
     benchmark_result_t result;
-    std::size_t lookup_mask = bit_floor(strings.size()) - 1;
-    std::size_t largest_prime = static_cast<std::size_t>(18446744073709551557ull);
-    result.seconds = repeat_until_limit([&]() {
-        // Unroll a few iterations, to avoid some for-loops overhead and minimize impact of time-tracking
-        auto second_index = (result.iterations * largest_prime) & lookup_mask;
-        result.bytes_passed += //
-            function(strings[(result.iterations + 0) & lookup_mask], strings[second_index]) +
-            function(strings[(result.iterations + 1) & lookup_mask], strings[second_index]) +
-            function(strings[(result.iterations + 2) & lookup_mask], strings[second_index]) +
-            function(strings[(result.iterations + 3) & lookup_mask], strings[second_index]);
-        result.iterations += 4;
-    });
-
-    return result;
-}
-
-/**
- * @brief Evaluation for unary string operations: hashing.
- */
-template <typename strings_type_, typename functions_type>
-void bench_unary_functions(strings_type_ &&strings, functions_type &&variants) {
-
-    for (std::size_t variant_idx = 0; variant_idx != variants.size(); ++variant_idx) {
-        auto &variant = variants[variant_idx];
-
-        // Tests
-        if (variant.function && variant.needs_testing) {
-            bench_on_tokens(strings, [&](auto str) -> std::size_t {
-                auto baseline = variants[0].function(str);
-                auto result = variant.function(str);
-                if (result != baseline) {
-                    ++variant.failed_count;
-                    if (variant.failed_strings.empty()) {
-                        variant.failed_strings.push_back({to_c(str).start, to_c(str).length});
-                    }
-                }
-                return to_c(str).length;
-            });
-        }
+    result.name = name;
+    if (!env.allow(name)) {
+        result.skipped = true;
+        return result;
+    }
 
-        // Benchmarks
-        if (variant.function) {
-            variant.results = bench_on_tokens(strings, [&](auto str) -> std::size_t {
-                do_not_optimize(variant.function(str));
-                return to_c(str).length;
-            });
+    std::size_t const lookup_mask = bit_floor(env.tokens.size()) - 1;
+    if constexpr (!std::is_same<baseline_type_, callable_no_op_t>())
+        for (auto running_seconds : repeat_up_to(env.stress_seconds)) {
+            std::size_t const input_index = (result.stress_calls++) & lookup_mask;
+            call_result_t const accelerated_result = callable(input_index);
+            call_result_t const baseline_result = baseline(input_index);
+            if (accelerated_result.check_value == baseline_result.check_value) continue; // No failures
+
+            // If we got here, the error needs to be reported and investigated.
+            ++result.errors;
+            if (result.errors > env.stress_limit) {
+                std::printf("Too many errors in %s after %.3f seconds. Stopping the test.\n", name.c_str(),
+                            running_seconds);
+                std::terminate();
+            }
+            log_stress_failure(env, name, input_index, baseline_result.check_value, accelerated_result.check_value);
         }
 
-        variant.print();
+    // For profiling, we will first run the benchmark just once to get a rough estimate of the time.
+    // But then we will repeat it in an unrolled fashion for a more accurate measurement.
+    result.profiled_seconds += seconds_per_call([&] {
+        std::uint64_t start_cycle = cpu_cycle_counter();
+        result += callable(0); // First input for debugging
+        std::uint64_t end_cycle = cpu_cycle_counter();
+        result.profiled_calls += 1;
+        result.profiled_cpu_cycles += end_cycle - start_cycle;
+    });
+    if (result.profiled_seconds >= env.benchmark_seconds) return result;
+
+    // Repeat the benchmarks in unrolled batches until the time limit is reached.
+    for (auto running_seconds : repeat_up_to(env.benchmark_seconds - result.profiled_seconds)) {
+        std::uint64_t start_cycle = cpu_cycle_counter();
+        call_result_t r0 = callable((result.profiled_calls + 0) & lookup_mask);
+        call_result_t r1 = callable((result.profiled_calls + 1) & lookup_mask);
+        call_result_t r2 = callable((result.profiled_calls + 2) & lookup_mask);
+        call_result_t r3 = callable((result.profiled_calls + 3) & lookup_mask);
+        std::uint64_t end_cycle = cpu_cycle_counter();
+
+        // Aggregate all of them:
+        result += r0;
+        result += r1;
+        result += r2;
+        result += r3;
+        result.profiled_calls += 4;
+        result.profiled_cpu_cycles += end_cycle - start_cycle;
+        result.profiled_seconds = running_seconds;
     }
+
+    return result;
 }
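/*
 * Editorial note (not part of the original patch): a minimal usage sketch of the `benchmark()`
 * harness above. It assumes `call_result_t` is an aggregate exposing the `bytes_passed`,
 * `operations`, and `check_value` members referenced in this header; `make_callable` is a
 * hypothetical helper, and the `sz_find_*` backends and `SZ_USE_SKYLAKE` macro are the ones
 * referenced elsewhere in these benchmarks.
 *
 * @code{.cpp}
 * auto make_callable = [&](auto find_function) { // hypothetical helper, not part of the patch
 *     return [&, find_function](std::size_t token_index) {
 *         std::string_view needle = env.tokens[token_index];
 *         std::string_view haystack = env.dataset;
 *         sz_cptr_t match = find_function(haystack.data(), haystack.size(), needle.data(), needle.size());
 *         call_result_t call;
 *         call.bytes_passed = haystack.size();
 *         call.operations = haystack.size() * needle.size(); // worst-case character comparisons
 *         call.check_value = match ? (std::size_t)(match - haystack.data()) : haystack.size();
 *         return call;
 *     };
 * };
 * benchmark_result_t serial = benchmark(env, "sz_find_serial", make_callable(sz_find_serial)).log();
 * #if SZ_USE_SKYLAKE
 * benchmark(env, "sz_find_skylake", make_callable(sz_find_serial), make_callable(sz_find_skylake)).log(serial);
 * #endif
 * @endcode
 */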
 
 /**
- * @brief Evaluation for binary string operations: equality, ordering, prefix, suffix, distance.
+ * @brief Loops over all tokens (in loop-unrolled batches) in environment and applies the given unary function.
+ * @param[in] env Environment with the dataset and tokens.
+ * @param[in] name Name of the benchmark, used for logging.
+ * @param[in] callable Unary function taking a @b `std::size_t` token index and returning a @b `call_result_t`.
+ * @return Profiling results, including the number of cycles, bytes processed, and error counts.
  */
-template <typename strings_type_, typename functions_type>
-void bench_binary_functions(strings_type_ &&strings, functions_type &&variants) {
-
-    for (std::size_t variant_idx = 0; variant_idx != variants.size(); ++variant_idx) {
-        auto &variant = variants[variant_idx];
-
-        // Tests
-        if (variant.function && variant.needs_testing) {
-            bench_on_token_pairs(strings, [&](auto str_a, auto str_b) -> std::size_t {
-                auto baseline = variants[0].function(str_a, str_b);
-                auto result = variant.function(str_a, str_b);
-                if (result != baseline) {
-                    ++variant.failed_count;
-                    if (variant.failed_strings.empty()) {
-                        variant.failed_strings.push_back({to_c(str_a).start, to_c(str_a).length});
-                        variant.failed_strings.push_back({to_c(str_b).start, to_c(str_b).length});
-                    }
-                }
-                return to_c(str_a).length + to_c(str_b).length;
-            });
-        }
-
-        // Benchmarks
-        if (variant.function) {
-            variant.results = bench_on_token_pairs(strings, [&](auto str_a, auto str_b) -> std::size_t {
-                do_not_optimize(variant.function(str_a, str_b));
-                return to_c(str_a).length + to_c(str_b).length;
-            });
-        }
-
-        variant.print();
-    }
+template <typename callable_type_>
+benchmark_result_t benchmark(environment_t const &env, std::string const &name, callable_type_ &&callable) {
+    return benchmark(env, name, callable_no_op_t {}, callable);
 }
 
+inline sz_string_view_t to_c(std::string_view str) noexcept { return {str.data(), str.size()}; }
+inline sz_string_view_t to_c(std::string const &str) noexcept { return {str.data(), str.size()}; }
+inline sz_string_view_t to_c(sz::string_view str) noexcept { return {str.data(), str.size()}; }
+inline sz_string_view_t to_c(sz::string const &str) noexcept { return {str.data(), str.size()}; }
+inline sz_string_view_t to_c(sz_string_view_t str) noexcept { return str; }
+
 } // namespace scripts
 } // namespace stringzilla
 } // namespace ashvardanian
\ No newline at end of file
diff --git a/scripts/bench_search.cpp b/scripts/bench_search.cpp
index 6ffd9790..aa76adf5 100644
--- a/scripts/bench_search.cpp
+++ b/scripts/bench_search.cpp
@@ -1,353 +1,571 @@
 /**
  * @file bench_search.cpp
- * @brief Benchmarks for bidirectional string search operations - exact and TODO: approximate.
+ * @brief Benchmarks for bidirectional string search operations.
+ *        The program accepts a file path to a dataset, tokenizes it, and benchmarks the search operations,
+ *        validating the SIMD-accelerated backends against the serial baselines.
  *
- * This file is the sibling of `bench_sort.cpp`, `bench_token.cpp` and `bench_similarity.cpp`.
- * It accepts a file with a list of words, and benchmarks the search operations on them.
- * Outside of present tokens also tries missing tokens.
+ * Benchmarks include:
+ * - Substring search: find all inclusions of a token in the dataset - @b find & @b rfind.
+ * - Byte search: find a specific byte value in each token (word, line, or file) - @b find_byte & @b rfind_byte.
+ * - Byteset search: find any byte value from a set in each token (line or file) - @b find_byteset & @b rfind_byteset.
+ *
+ * For substring search, the number of operations per second is reported as the number of character-level comparisons
+ * happening in the worst case in the naive algorithm, meaning O(N*M) for N characters in the haystack and M in the
+ * needle.
+ *
+ * Instead of CLI arguments, for compatibility with @b StringWa.rs, the following environment variables are used:
+ * - `STRINGWARS_DATASET` : Path to the dataset file.
+ * - `STRINGWARS_TOKENS=word` : Tokenization model ("file", "line", "word", or a positive integer [1:200] for N-grams).
+ * - `STRINGWARS_SEED=42` : Optional seed for shuffling reproducibility.
+ *
+ * Unlike StringWa.rs, the following additional environment variables are supported:
+ * - `STRINGWARS_DURATION=10` : Time limit (in seconds) per benchmark.
+ * - `STRINGWARS_STRESS=1` : Test SIMD-accelerated functions against the serial baselines.
+ * - `STRINGWARS_STRESS_DIR=/.tmp` : Output directory for stress-testing failure logs.
+ * - `STRINGWARS_STRESS_LIMIT=1` : Controls the number of failures we're willing to tolerate.
+ * - `STRINGWARS_STRESS_DURATION=10` : Stress-testing time limit (in seconds) per benchmark.
+ * - `STRINGWARS_FILTER` : Regular Expression pattern to filter algorithm/backend names.
+ *
+ * Here are a few build & run commands:
+ *
+ * @code{.sh}
+ * cmake -D STRINGZILLA_BUILD_BENCHMARK=1 -D CMAKE_BUILD_TYPE=Release -B build_release
+ * cmake --build build_release --config Release --target stringzilla_bench_search
+ * STRINGWARS_DATASET=leipzig1M.txt STRINGWARS_TOKENS=lines build_release/stringzilla_bench_search
+ * @endcode
+ *
+ * Alternatively, if you really want to stress-test a very specific function on inputs of a certain size,
+ * like all Skylake-X and newer kernels on a boundary-condition input length of 64 bytes (exactly 1 cache line),
+ * your last command may look like:
+ *
+ * @code{.sh}
+ * STRINGWARS_DATASET=leipzig1M.txt STRINGWARS_TOKENS=64 STRINGWARS_FILTER=skylake
+ * STRINGWARS_STRESS=1 STRINGWARS_STRESS_DURATION=120 STRINGWARS_STRESS_DIR=logs
+ * build_release/stringzilla_bench_search
+ * @endcode
+ *
+ * Unlike the full-blown StringWa.rs, it doesn't use any external frameworks like Criterion or Google Benchmark.
+ * This file is the sibling of `bench_sort.cpp`, `bench_token.cpp`, `bench_similarity.cpp`, and `bench_memory.cpp`.
+ *
+ * ! It requires more memory than some of the other benchmarks, as every token is re-allocated
+ * ! into a NULL-terminated buffer for compatibility with the C-style string functions.
  */
 #include <string.h>   // `memmem`
 #include <functional> // `std::boyer_moore_searcher`
 
 #define SZ_USE_MISALIGNED_LOADS (1)
-#include 
+#include "bench.hpp"
 
 using namespace ashvardanian::stringzilla::scripts;
 
-tracked_binary_functions_t find_functions() {
-    // ! Despite receiving string-views, following functions are assuming the strings are null-terminated.
-    auto wrap_sz = [](auto function) -> binary_function_t {
-        return binary_function_t([function](std::string_view h, std::string_view n) {
-            sz_cptr_t match = function(h.data(), h.size(), n.data(), n.size());
-            return (match ? match - h.data() : h.size());
-        });
-    };
-    tracked_binary_functions_t result = {
-        {"std::string_view.find",
-         [](std::string_view h, std::string_view n) {
-             auto match = n.size() == 1 ? h.find(n.front()) : h.find(n);
-             return (match == std::string_view::npos ? h.size() : match);
-         }},
-        {"sz_find_serial", wrap_sz(sz_find_serial), true},
-#if SZ_USE_SKYLAKE
-        {"sz_find_skylake", wrap_sz(sz_find_skylake), true},
-#endif
-#if SZ_USE_HASWELL
-        {"sz_find_haswell", wrap_sz(sz_find_haswell), true},
-#endif
-#if SZ_USE_NEON
-        {"sz_find_neon", wrap_sz(sz_find_neon), true},
-#endif
-        {"strstr/strchr",
-         [](std::string_view h, std::string_view n) {
-             sz_cptr_t match = n.size() == 1 ? (sz_cptr_t)strchr(h.data(), n.front()) //
-                                             : (sz_cptr_t)strstr(h.data(), n.data());
-             return (match ? match - h.data() : h.size());
-         }},
-#ifdef _GNU_SOURCE
-        {"memmem/memchr", // Not supported on MSVC
-         [](std::string_view h, std::string_view n) {
-             sz_cptr_t match = n.size() == 1 ? (sz_cptr_t)memchr(h.data(), n.front(), h.size())
-                                             : (sz_cptr_t)memmem(h.data(), h.size(), n.data(), n.size());
-             return (match ? match - h.data() : h.size());
-         }},
+#pragma region Substring Search
+
+/**
+ * @brief Wraps an individual hardware-specific search backend into something similar
+ *        to @b `sz::matcher_find` and compatible with @b `sz::range_matches`.
+ */
+template <sz_find_t find_func_>
+struct matcher_from_sz_find {
+    using size_type = std::size_t;
+    std::string_view needle_;
+
+    inline matcher_from_sz_find(std::string_view needle = {}) noexcept : needle_(needle) {}
+    inline size_type needle_length() const noexcept { return needle_.size(); }
+    inline size_type operator()(std::string_view haystack) const noexcept {
+        auto ptr = find_func_(haystack.data(), haystack.size(), needle_.data(), needle_.size());
+        if (!ptr) return std::string_view::npos; // No match found
+        return ptr - haystack.data();
+    }
+    constexpr size_type skip_length() const noexcept { return 1; }
+};
+
+static std::string strstr_needle_copy_ {}; //! Reuse the same memory for all needles, potentially causing allocations
+
+/**
+ * @brief Wraps the LibC functionality for finding the next occurrence of a NULL-terminated string
+ *        into something similar to @b `sz::matcher_find` and compatible with @b `sz::range_matches`.
+ */
+struct matcher_strstr_t {
+    using size_type = std::size_t;
+
+    inline matcher_strstr_t(std::string_view needle = {}) noexcept(false) { strstr_needle_copy_ = needle; }
+    inline size_type needle_length() const noexcept { return strstr_needle_copy_.size(); }
+    inline size_type operator()(std::string_view haystack) const noexcept {
+        auto ptr = (char *)strstr(haystack.data(), strstr_needle_copy_.c_str());
+        do_not_optimize(ptr);
+        if (!ptr) return std::string_view::npos; // No match found
+        return (size_type)(ptr - haystack.data());
+    }
+    constexpr size_type skip_length() const noexcept { return 1; }
+};
+
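/*
 * Editorial note (not part of the original patch): these wrappers can be exercised directly,
 * relying only on the interface they define (`operator()`, `needle_length()`, `skip_length()`).
 * The sketch below counts all (possibly overlapping) occurrences of `needle` in `haystack`,
 * both assumed to be `std::string_view`s; the instantiation relies on the template parameter
 * being the backend function itself, as reconstructed above.
 *
 * @code{.cpp}
 * matcher_from_sz_find<sz_find_serial> matcher(needle);
 * std::size_t count = 0, offset = 0;
 * while (offset < haystack.size()) {
 *     std::size_t position = matcher(haystack.substr(offset));
 *     if (position == std::string_view::npos) break;
 *     ++count;
 *     offset += position + matcher.skip_length();
 * }
 * @endcode
 */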
+#if defined(_GNU_SOURCE)
+/**
+ * @brief Wraps the LibC functionality for finding the next occurrence of a byte-string in a buffer
+ *        into something similar to @b `sz::matcher_find` and compatible with @b `sz::range_matches`.
+ */
+struct matcher_memmem_t {
+    using size_type = std::size_t;
+    std::string_view needle_;
+
+    inline matcher_memmem_t(std::string_view needle = {}) noexcept : needle_(needle) {}
+    inline size_type needle_length() const noexcept { return needle_.size(); }
+    inline size_type operator()(std::string_view haystack) const noexcept {
+        auto ptr = (char *)memmem(haystack.data(), haystack.size(), needle_.data(), needle_.size());
+        do_not_optimize(ptr);
+        if (!ptr) return std::string_view::npos; // No match found
+        return (size_type)(ptr - haystack.data());
+    }
+    constexpr size_type skip_length() const noexcept { return 1; }
+};
 #endif
-        {"std::search<>",
-         [](std::string_view h, std::string_view n) {
-             auto match = std::search(h.data(), h.data() + h.size(), n.data(), n.data() + n.size());
-             return (match - h.data());
-         }},
+
 #if __cpp_lib_boyer_moore_searcher
-        {"std::search",
-         [](std::string_view h, std::string_view n) {
-             auto match =
-                 std::search(h.data(), h.data() + h.size(), std::boyer_moore_searcher(n.data(), n.data() + n.size()));
-             return (match - h.data());
-         }},
-        {"std::search",
-         [](std::string_view h, std::string_view n) {
-             auto match = std::search(h.data(), h.data() + h.size(),
-                                      std::boyer_moore_horspool_searcher(n.data(), n.data() + n.size()));
-             return (match - h.data());
-         }},
+/**
+ * @brief Wraps the C++20 @b Boyer-Moore algorithms for finding the next occurrence of a string
+ *        into something similar to @b `sz::matcher_find` and compatible with @b `sz::range_matches`.
+ * @tparam searcher_type_ Can be `std::boyer_moore_searcher` or `std::boyer_moore_horspool_searcher`.
+ *         Both should be instantiated with the `std::string_view::const_iterator` type.
+ */
+template <typename searcher_type_>
+struct matcher_from_std_search {
+    using size_type = std::size_t;
+    std::string_view needle_;
+    searcher_type_ searcher_;
+
+    inline matcher_from_std_search(std::string_view needle = {}) noexcept
+        : needle_(needle), searcher_(needle.begin(), needle.end()) {}
+    inline size_type needle_length() const noexcept { return needle_.size(); }
+    inline size_type operator()(std::string_view haystack) const noexcept {
+        auto match = std::search(haystack.begin(), haystack.end(), searcher_);
+        return (size_type)(match - haystack.begin());
+    }
+    constexpr size_type skip_length() const noexcept { return 1; }
+};
+
+template <typename searcher_type_>
+struct rmatcher_from_std_search {
+    using size_type = std::size_t;
+    std::string_view needle_;
+    searcher_type_ searcher_;
+
+    inline rmatcher_from_std_search(std::string_view needle = {}) noexcept
+        : needle_(needle), searcher_(needle.rbegin(), needle.rend()) {}
+    inline size_type needle_length() const noexcept { return needle_.size(); }
+    inline size_type operator()(std::string_view haystack) const noexcept {
+        auto match = std::search(haystack.rbegin(), haystack.rend(), searcher_);
+        if (match == haystack.rend()) return std::string_view::npos; // No match found
+        auto offset_from_end = match - haystack.rbegin();
+        return haystack.size() - offset_from_end - needle_.size();
+    }
+    constexpr size_type skip_length() const noexcept { return 1; }
+};
+
 #endif
 
+template
strcspn(haystack, needles)sz_rfind_charset(haystack, haystack_length, needles_bitset)sz_rfind_byteset(haystack, haystack_length, needles_bitset)
strspn(haystack, needles)sz_find_charset(haystack, haystack_length, needles_bitset)sz_find_byteset(haystack, haystack_length, needles_bitset)
memmem(haystack, haystack_length, needle, needle_length), strstr