From f5fa804c0fc4e603061118b3625f6e5d7657809b Mon Sep 17 00:00:00 2001 From: Houston Putman Date: Tue, 15 Oct 2024 09:44:58 -0500 Subject: [PATCH 1/8] Introduce multi-select for scalar quantization --- .../org/apache/lucene/util/IntroSelector.java | 126 ++++++++++++++++++ .../java/org/apache/lucene/util/Selector.java | 25 ++++ .../util/quantization/ScalarQuantizer.java | 40 +++--- .../quantization/TestScalarQuantizer.java | 26 ++-- 4 files changed, 186 insertions(+), 31 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/util/IntroSelector.java b/lucene/core/src/java/org/apache/lucene/util/IntroSelector.java index 2ade7ab43077..0afe66f3c171 100644 --- a/lucene/core/src/java/org/apache/lucene/util/IntroSelector.java +++ b/lucene/core/src/java/org/apache/lucene/util/IntroSelector.java @@ -45,6 +45,12 @@ public final void select(int from, int to, int k) { select(from, to, k, 2 * MathUtil.log(to - from, 2)); } + @Override + public final void select(int from, int to, int[] k) { + checkArgs(from, to, k); + select(from, to, k, 0, k.length, 2 * MathUtil.log(to - from, 2)); + } + // Visible for testing. void select(int from, int to, int k, int maxDepth) { // This code is inspired from IntroSorter#sort, adapted to loop on a single partition. @@ -146,6 +152,126 @@ void select(int from, int to, int k, int maxDepth) { } } + // Visible for testing. + void select(int from, int to, int[] k, int kFrom, int kTo, int maxDepth) { + // If there is only 1 k value to select in this group, then use the single-k select method + if (kTo - kFrom == 1) { + select(from, to, k[kFrom], maxDepth); + return; + } + + // This code is inspired from IntroSorter#sort, adapted to loop on a single partition. + + // For efficiency, we must enter the loop with at least 4 entries to be able to skip + // some boundary tests during the 3-way partitioning. + int size; + if ((size = to - from) > 3) { + + if (--maxDepth == -1) { + // Max recursion depth exceeded: shuffle (only once) and continue. 
+ shuffle(from, to); + } + + // Pivot selection based on medians. + int last = to - 1; + int mid = (from + last) >>> 1; + int pivot; + if (size <= IntroSorter.SINGLE_MEDIAN_THRESHOLD) { + // Select the pivot with a single median around the middle element. + // Do not take the median between [from, mid, last] because it hurts performance + // if the order is descending in conjunction with the 3-way partitioning. + int range = size >> 2; + pivot = median(mid - range, mid, mid + range); + } else { + // Select the pivot with a variant of the Tukey's ninther median of medians. + // If k is close to the boundaries, select either the lowest or highest median (this variant + // is inspired from the interpolation search). + int range = size >> 3; + int doubleRange = range << 1; + int medianFirst = median(from, from + range, from + doubleRange); + int medianMiddle = median(mid - range, mid, mid + range); + int medianLast = median(last - doubleRange, last - range, last); + int middleK = k[(kFrom + kTo - 1) >> 1]; + if (middleK - from < range) { + // k is close to 'from': select the lowest median. + pivot = min(medianFirst, medianMiddle, medianLast); + } else if (to - middleK <= range) { + // k is close to 'to': select the highest median. + pivot = max(medianFirst, medianMiddle, medianLast); + } else { + // Otherwise select the median of medians. + pivot = median(medianFirst, medianMiddle, medianLast); + } + } + + // Bentley-McIlroy 3-way partitioning. 
+ setPivot(pivot); + swap(from, pivot); + int i = from; + int j = to; + int p = from + 1; + int q = last; + while (true) { + int leftCmp, rightCmp; + while ((leftCmp = comparePivot(++i)) > 0) {} + while ((rightCmp = comparePivot(--j)) < 0) {} + if (i >= j) { + if (i == j && rightCmp == 0) { + swap(i, p); + } + break; + } + swap(i, j); + if (rightCmp == 0) { + swap(i, p++); + } + if (leftCmp == 0) { + swap(j, q--); + } + } + i = j + 1; + for (int l = from; l < p; ) { + swap(l++, j--); + } + for (int l = last; l > q; ) { + swap(l--, i++); + } + + // Select the K values contained in the bottom and top partitions. + int topKFrom = kTo; + int bottomKTo = kFrom; + for (int ki = kTo-1; ki >= kFrom; ki--) { + if (k[ki] >= i) { + topKFrom = ki; + } + if (k[ki] <= j) { + bottomKTo = ki + 1; + break; + } + } + // Recursively select the relevant k-values from the bottom group, if there are any k-values to select there + if (bottomKTo > kFrom) { + select(from, j + 1, k, kFrom, bottomKTo, maxDepth); + } + // Recursively select the relevant k-values from the top group, if there are any k-values to select there + if (topKFrom < kTo) { + select(i, to, k, topKFrom, kTo, maxDepth); + } + } + + // Sort the final tiny range (3 entries or less) with a very specialized sort. + switch (size) { + case 2: + if (compare(from, from + 1) > 0) { + swap(from, from + 1); + } + break; + case 3: + sort3(from); + break; + } + } + /** Returns the index of the min element among three elements at provided indices. 
*/ private int min(int i, int j, int k) { if (compare(i, j) <= 0) { diff --git a/lucene/core/src/java/org/apache/lucene/util/Selector.java b/lucene/core/src/java/org/apache/lucene/util/Selector.java index f44966ed371a..fd43bc07c7a3 100644 --- a/lucene/core/src/java/org/apache/lucene/util/Selector.java +++ b/lucene/core/src/java/org/apache/lucene/util/Selector.java @@ -16,6 +16,8 @@ */ package org.apache.lucene.util; +import java.util.Arrays; + /** * An implementation of a selection algorithm, ie. computing the k-th greatest value from a * collection. @@ -30,6 +32,16 @@ public abstract class Selector { */ public abstract void select(int from, int to, int k); + /** + * Reorder elements so that the elements at all positions in {@code k} are the same as if all elements were + * sorted and all other elements are partitioned around it: {@code [from, k[n])} only contains + * elements that are less than or equal to {@code k[n]} and {@code (k[n], to)} only contains elements + * that are greater than or equal to {@code k[n]}. + */ + public void select(int from, int to, int[] k) { + select(from, to, k[0]); + } + void checkArgs(int from, int to, int k) { if (k < from) { throw new IllegalArgumentException("k must be >= from"); @@ -39,6 +51,19 @@ void checkArgs(int from, int to, int k) { } } + void checkArgs(int from, int to, int[] k) { + if (k.length < 1) { + throw new IllegalArgumentException("There must be at least one k to select, none given"); + } + Arrays.sort(k); + if (k[0] < from) { + throw new IllegalArgumentException("All k must be >= from"); + } + if (k[k.length - 1] >= to) { + throw new IllegalArgumentException("All k must be < to"); + } + } + /** Swap values at slots i and j. 
*/ protected abstract void swap(int i, int j); } diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java index 3f7bcf6c5c45..2ffe30749459 100644 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java +++ b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java @@ -439,11 +439,10 @@ private static void extractQuantiles( double[] lowerSum) { assert confidenceIntervals.length == upperSum.length && confidenceIntervals.length == lowerSum.length; + float[][] upperAndLowerQuantiles = getUpperAndLowerQuantiles(quantileGatheringScratch, confidenceIntervals); for (int i = 0; i < confidenceIntervals.length; i++) { - float[] upperAndLower = - getUpperAndLowerQuantile(quantileGatheringScratch, confidenceIntervals[i]); - upperSum[i] += upperAndLower[1]; - lowerSum[i] += upperAndLower[0]; + upperSum[i] += upperAndLowerQuantiles[i][1]; + lowerSum[i] += upperAndLowerQuantiles[i][0]; } } @@ -568,29 +567,34 @@ private static List findNearestNeighbors( * and `95`. 
* * @param arr array of floats - * @param confidenceInterval the configured confidence interval + * @param confidenceIntervals the configured confidence intervals * @return lower and upper quantile values */ - static float[] getUpperAndLowerQuantile(float[] arr, float confidenceInterval) { + static float[][] getUpperAndLowerQuantiles(float[] arr, float[] confidenceIntervals) { assert arr.length > 0; + float[][] minAndMaxPerInterval = new float[confidenceIntervals.length][2]; // If we have 1 or 2 values, we can't calculate the quantiles, simply return the min and max if (arr.length <= 2) { Arrays.sort(arr); - return new float[] {arr[0], arr[arr.length - 1]}; + Arrays.fill(minAndMaxPerInterval, new float[] {arr[0], arr[arr.length - 1]}); + return minAndMaxPerInterval; } - int selectorIndex = (int) (arr.length * (1f - confidenceInterval) / 2f + 0.5f); - if (selectorIndex > 0) { - Selector selector = new FloatSelector(arr); - selector.select(0, arr.length, arr.length - selectorIndex); - selector.select(0, arr.length - selectorIndex, selectorIndex); + // Collect all quantile values to select for together + int[] selectorIndexes = new int[confidenceIntervals.length * 2]; + for (int i = 0; i < confidenceIntervals.length; i++) { + int selectorIndex = (int) (arr.length * (1f - confidenceIntervals[i]) / 2f + 0.5f); + selectorIndexes[2 * i] = selectorIndex; + selectorIndexes[2 * i + 1] = arr.length - selectorIndex - 1; } - float min = Float.POSITIVE_INFINITY; - float max = Float.NEGATIVE_INFINITY; - for (int i = selectorIndex; i < arr.length - selectorIndex; i++) { - min = Math.min(arr[i], min); - max = Math.max(arr[i], max); + Selector selector = new FloatSelector(arr); + selector.select(0, arr.length, selectorIndexes); + + // After the selection process, pick out the given quantile values + for (int i = 0; i < confidenceIntervals.length; i++) { + minAndMaxPerInterval[i][0] = arr[selectorIndexes[2*i]]; + minAndMaxPerInterval[i][1] = arr[selectorIndexes[2*i + 1]]; } - return 
new float[] {min, max}; + return minAndMaxPerInterval; } private static class FloatSelector extends IntroSelector { diff --git a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java index 7f56688b7999..8f7a8c66f198 100644 --- a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java +++ b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java @@ -125,24 +125,24 @@ public void testQuantiles() { percs[i] = (float) i; } shuffleArray(percs); - float[] upperAndLower = ScalarQuantizer.getUpperAndLowerQuantile(percs, 0.9f); - assertEquals(50f, upperAndLower[0], 1e-7); - assertEquals(949f, upperAndLower[1], 1e-7); + float[][] upperAndLower = ScalarQuantizer.getUpperAndLowerQuantiles(percs, new float[]{0.9f}); + assertEquals(50f, upperAndLower[0][0], 1e-7); + assertEquals(949f, upperAndLower[0][1], 1e-7); shuffleArray(percs); - upperAndLower = ScalarQuantizer.getUpperAndLowerQuantile(percs, 0.95f); - assertEquals(25f, upperAndLower[0], 1e-7); - assertEquals(974f, upperAndLower[1], 1e-7); + upperAndLower = ScalarQuantizer.getUpperAndLowerQuantiles(percs, new float[]{0.95f}); + assertEquals(25f, upperAndLower[0][0], 1e-7); + assertEquals(974f, upperAndLower[0][1], 1e-7); shuffleArray(percs); - upperAndLower = ScalarQuantizer.getUpperAndLowerQuantile(percs, 0.99f); - assertEquals(5f, upperAndLower[0], 1e-7); - assertEquals(994f, upperAndLower[1], 1e-7); + upperAndLower = ScalarQuantizer.getUpperAndLowerQuantiles(percs, new float[]{0.99f}); + assertEquals(5f, upperAndLower[0][0], 1e-7); + assertEquals(994f, upperAndLower[0][1], 1e-7); } public void testEdgeCase() { - float[] upperAndLower = - ScalarQuantizer.getUpperAndLowerQuantile(new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, 0.9f); - assertEquals(1f, upperAndLower[0], 1e-7f); - assertEquals(1f, upperAndLower[1], 1e-7f); + float[][] upperAndLower = + 
ScalarQuantizer.getUpperAndLowerQuantiles(new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, new float[]{0.9f}); + assertEquals(1f, upperAndLower[0][0], 1e-7f); + assertEquals(1f, upperAndLower[0][1], 1e-7f); } public void testScalarWithSampling() throws IOException { From f985f75e6228b0eb93da7eb4dce2cd5906f544a8 Mon Sep 17 00:00:00 2001 From: Houston Putman Date: Tue, 15 Oct 2024 13:28:17 -0500 Subject: [PATCH 2/8] Implement radix multiSelect, add tests, rename, make default method --- .../org/apache/lucene/util/IntroSelector.java | 14 ++-- .../org/apache/lucene/util/RadixSelector.java | 78 +++++++++++++++++++ .../java/org/apache/lucene/util/Selector.java | 40 ++++++++-- .../util/quantization/ScalarQuantizer.java | 2 +- .../apache/lucene/util/TestIntroSelector.java | 25 ++++++ .../apache/lucene/util/TestRadixSelector.java | 25 ++++++ 6 files changed, 169 insertions(+), 15 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/util/IntroSelector.java b/lucene/core/src/java/org/apache/lucene/util/IntroSelector.java index 0afe66f3c171..dfb3c4551def 100644 --- a/lucene/core/src/java/org/apache/lucene/util/IntroSelector.java +++ b/lucene/core/src/java/org/apache/lucene/util/IntroSelector.java @@ -46,9 +46,9 @@ public final void select(int from, int to, int k) { } @Override - public final void select(int from, int to, int[] k) { - checkArgs(from, to, k); - select(from, to, k, 0, k.length, 2 * MathUtil.log(to - from, 2)); + public final void multiSelect(int from, int to, int[] k, int kFrom, int kTo) { + checkMultiArgs(from, to, k, kFrom, kTo); + multiSelect(from, to, k, kFrom, kTo, 2 * MathUtil.log(to - from, 2)); } // Visible for testing. @@ -153,8 +153,8 @@ void select(int from, int to, int k, int maxDepth) { } // Visible for testing. 
- void select(int from, int to, int[] k, int kFrom, int kTo, int maxDepth) { - // If there is only 1 k value to select in this group, then use the single-k select method + void multiSelect(int from, int to, int[] k, int kFrom, int kTo, int maxDepth) { + // If there is only 1 k value to select in this group, then use the single-k select method, which does not do recursion if (kTo - kFrom == 1) { select(from, to, k[kFrom], maxDepth); return; @@ -251,11 +251,11 @@ void select(int from, int to, int[] k, int kFrom, int kTo, int maxDepth) { } // Recursively select the relevant k-values from the bottom group, if there are any k-values to select there if (bottomKTo > kFrom) { - select(from, j + 1, k, kFrom, bottomKTo, maxDepth); + multiSelect(from, j + 1, k, kFrom, bottomKTo, maxDepth); } // Recursively select the relevant k-values from the top group, if there are any k-values to select there if (topKFrom < kTo) { - select(i, to, k, topKFrom, kTo, maxDepth); + multiSelect(i, to, k, topKFrom, kTo, maxDepth); } } diff --git a/lucene/core/src/java/org/apache/lucene/util/RadixSelector.java b/lucene/core/src/java/org/apache/lucene/util/RadixSelector.java index 47c3ca6eaf12..121ff4cd05a5 100644 --- a/lucene/core/src/java/org/apache/lucene/util/RadixSelector.java +++ b/lucene/core/src/java/org/apache/lucene/util/RadixSelector.java @@ -16,6 +16,7 @@ */ package org.apache.lucene.util; +import java.util.ArrayList; import java.util.Arrays; /** @@ -124,6 +125,12 @@ public void select(int from, int to, int k) { select(from, to, k, 0, 0); } + @Override + public void multiSelect(int from, int to, int[] k, int kFrom, int kTo) { + checkMultiArgs(from, to, k, kFrom, kTo); + multiSelect(from, to, k, kFrom, kTo, 0, 0); + } + private void select(int from, int to, int k, int d, int l) { if (to - from <= LENGTH_THRESHOLD || l >= LEVEL_THRESHOLD) { getFallbackSelector(d).select(from, to, k); @@ -132,6 +139,22 @@ private void select(int from, int to, int k, int d, int l) { } } + private void 
multiSelect(int from, int to, int[] k, int kFrom, int kTo, int d, int l) { + if (to - from <= LENGTH_THRESHOLD || l >= LEVEL_THRESHOLD) { + if (kTo - kFrom == 1) { + getFallbackSelector(d).select(from, to, k[kFrom]); + } else { + getFallbackSelector(d).multiSelect(from, to, k, kFrom, kTo); + } + } else { + if (kTo - kFrom == 1) { + radixSelect(from, to, k[kFrom], d, l); + } else { + radixMultiSelect(from, to, k, kFrom, kTo, d, l); + } + } + } + /** * @param d the character number to compare * @param l the level of recursion @@ -171,6 +194,61 @@ private void radixSelect(int from, int to, int k, int d, int l) { throw new AssertionError("Unreachable code"); } + /** + * @param d the character number to compare + * @param l the level of recursion + */ + private void radixMultiSelect(int from, int to, int[] k, int kFrom, int kTo, int d, int l) { + final int[] histogram = this.histogram; + Arrays.fill(histogram, 0); + + final int commonPrefixLength = + computeCommonPrefixLengthAndBuildHistogram(from, to, d, histogram); + if (commonPrefixLength > 0) { + // if there are no more chars to compare or if all entries fell into the + // first bucket (which means strings are shorter than d) then we are done + // otherwise recurse + if (d + commonPrefixLength < maxLength && histogram[0] < to - from) { + radixMultiSelect(from, to, k, kFrom, kTo, d + commonPrefixLength, l); + } + return; + } + assert assertHistogram(commonPrefixLength, histogram); + + int bucketFrom = from; + int bucketKFrom = kFrom; + ArrayList bucketsToRecurse = new ArrayList<>(kTo - kFrom); + for (int bucket = 0; bucket < HISTOGRAM_SIZE && bucketKFrom < kTo; ++bucket) { + if (histogram[bucket] == 0) { + continue; + } + final int bucketTo = bucketFrom + histogram[bucket]; + int bucketKTo = bucketKFrom; + // Move the right-side of the k-window up until the k-value is no longer in the current histogram bucket + while (bucketKTo < kTo && k[bucketKTo] < bucketTo) { + bucketKTo++; + } + + // If there are any k-values 
captured in this histogram, continue down this path with those k-values + if (bucketKFrom < bucketKTo) { + partition(from, to, bucket, bucketFrom, bucketTo, d); + + // all elements in bucket 0 are equal so we only need to recurse if bucket != 0 + if (bucket != 0 && d + 1 < maxLength) { + // Recurse after the loop, so that we do not override the histogram + bucketsToRecurse.add(new Bucket(bucketFrom, bucketTo, bucketKFrom, bucketKTo)); + } + } + bucketFrom = bucketTo; + bucketKFrom = bucketKTo; + } + for (Bucket b : bucketsToRecurse) { + multiSelect(b.from, b.to, k, b.kFrom, b.kTo, d + 1, l + 1); + } + } + + private record Bucket(int from, int to, int kFrom, int kTo) {} + // only used from assert private boolean assertHistogram(int commonPrefixLength, int[] histogram) { int numberOfUniqueBytes = 0; diff --git a/lucene/core/src/java/org/apache/lucene/util/Selector.java b/lucene/core/src/java/org/apache/lucene/util/Selector.java index fd43bc07c7a3..b2691075e146 100644 --- a/lucene/core/src/java/org/apache/lucene/util/Selector.java +++ b/lucene/core/src/java/org/apache/lucene/util/Selector.java @@ -38,8 +38,31 @@ public abstract class Selector { * elements that are less than or equal to {@code k[n]} and {@code (k[n], to)} only contains elements * that are greater than or equal to {@code k[n]}. */ - public void select(int from, int to, int[] k) { - select(from, to, k[0]); + public void multiSelect(int from, int to, int[] k) { + multiSelect(from, to, k, 0, k.length); + } + + /** + * Reorder elements so that the elements at all positions in {@code k} are the same as if all elements were + * sorted and all other elements are partitioned around it: {@code [from, k[n])} only contains + * elements that are less than or equal to {@code k[n]} and {@code (k[n], to)} only contains elements + * that are greater than or equal to {@code k[n]}. + * + * The array {@code k} will be sorted, so {@code kFrom} and {@code kTo} must be referring to the sorted order. 
+ */ + public void multiSelect(int from, int to, int[] k, int kFrom, int kTo) { + // Default implementation only uses select(), so it is not optimal + checkMultiArgs(from, to, k, kFrom, kTo); + int nextFrom = from; + for (int i = kFrom; i < kTo; i++) { + int currentK = k[i]; + if (currentK < nextFrom) { + // This is a duplicate k + continue; + } + select(nextFrom, to, currentK); + nextFrom = currentK + 1; + } } void checkArgs(int from, int to, int k) { @@ -51,15 +74,18 @@ void checkArgs(int from, int to, int k) { } } - void checkArgs(int from, int to, int[] k) { - if (k.length < 1) { - throw new IllegalArgumentException("There must be at least one k to select, none given"); + void checkMultiArgs(int from, int to, int[] k, int kFrom, int kTo) { + if (kFrom < 0) { + throw new IllegalArgumentException("kFrom must be >= 0"); + } + if (kTo > k.length) { + throw new IllegalArgumentException("kFrom must be <= k.length"); } Arrays.sort(k); - if (k[0] < from) { + if (k[kFrom] < from) { throw new IllegalArgumentException("All k must be >= from"); } - if (k[k.length - 1] >= to) { + if (k[kTo - 1] >= to) { throw new IllegalArgumentException("All k must be < to"); } } diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java index 2ffe30749459..74bb8f5af3e2 100644 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java +++ b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java @@ -587,7 +587,7 @@ static float[][] getUpperAndLowerQuantiles(float[] arr, float[] confidenceInterv selectorIndexes[2 * i + 1] = arr.length - selectorIndex - 1; } Selector selector = new FloatSelector(arr); - selector.select(0, arr.length, selectorIndexes); + selector.multiSelect(0, arr.length, selectorIndexes); // After the selection process, pick out the given quantile values for (int i = 0; i < confidenceIntervals.length; i++) { diff 
--git a/lucene/core/src/test/org/apache/lucene/util/TestIntroSelector.java b/lucene/core/src/test/org/apache/lucene/util/TestIntroSelector.java index 713451097c56..3f0f03a034ef 100644 --- a/lucene/core/src/test/org/apache/lucene/util/TestIntroSelector.java +++ b/lucene/core/src/test/org/apache/lucene/util/TestIntroSelector.java @@ -80,5 +80,30 @@ protected int comparePivot(int j) { assertTrue(actual[i] >= actual[k]); } } + + final int[] kArr = new int[TestUtil.nextInt(random, 1, 10)]; + for (int i = 0; i < kArr.length; i++) { + kArr[i] = TestUtil.nextInt(random, from, to - 1); + } + selector.multiSelect(from, to, kArr); + + int nextKIdx = 0; + Arrays.sort(kArr); + for (int i = 0; i < actual.length; ++i) { + if (i < from || i >= to) { + assertSame(arr[i], actual[i]); + } else if (nextKIdx < kArr.length) { + if (i == kArr[nextKIdx]) { + assertEquals(expected[i], actual[i]); + while (nextKIdx < kArr.length && i == kArr[nextKIdx]) { + nextKIdx++; + } + } else { + assertTrue(actual[i].compareTo(expected[kArr[nextKIdx]]) <= 0); + } + } else { + assertTrue(actual[i].compareTo(expected[kArr[kArr.length - 1]]) >= 0); + } + } } } diff --git a/lucene/core/src/test/org/apache/lucene/util/TestRadixSelector.java b/lucene/core/src/test/org/apache/lucene/util/TestRadixSelector.java index 1f60dfb3555d..a635860b64d9 100644 --- a/lucene/core/src/test/org/apache/lucene/util/TestRadixSelector.java +++ b/lucene/core/src/test/org/apache/lucene/util/TestRadixSelector.java @@ -108,5 +108,30 @@ protected int byteAt(int i, int k) { assertTrue(actual[i].compareTo(actual[k]) >= 0); } } + + final int[] kArr = new int[TestUtil.nextInt(random(), 1, 10)]; + for (int i = 0; i < kArr.length; i++) { + kArr[i] = TestUtil.nextInt(random(), from, to - 1); + } + selector.multiSelect(from, to, kArr); + + int nextKIdx = 0; + Arrays.sort(kArr); + for (int i = 0; i < actual.length; ++i) { + if (i < from || i >= to) { + assertSame(arr[i], actual[i]); + } else if (nextKIdx < kArr.length) { + if (i == 
kArr[nextKIdx]) { + assertEquals(expected[i], actual[i]); + while (nextKIdx < kArr.length && i == kArr[nextKIdx]) { + nextKIdx++; + } + } else { + assertTrue(actual[i].compareTo(expected[kArr[nextKIdx]]) <= 0); + } + } else { + assertTrue(actual[i].compareTo(expected[kArr[kArr.length - 1]]) >= 0); + } + } } } From 5aa46e5cae4787d1a28bda2f26e144d845b7f848 Mon Sep 17 00:00:00 2001 From: Houston Putman Date: Tue, 15 Oct 2024 13:55:22 -0500 Subject: [PATCH 3/8] Another method, more docs --- .../org/apache/lucene/util/ArrayUtil.java | 36 +++++++++++++++++++ .../org/apache/lucene/util/RadixSelector.java | 2 +- 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/lucene/core/src/java/org/apache/lucene/util/ArrayUtil.java b/lucene/core/src/java/org/apache/lucene/util/ArrayUtil.java index 2cd3cb63cfe5..9ddd0a397a53 100644 --- a/lucene/core/src/java/org/apache/lucene/util/ArrayUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/ArrayUtil.java @@ -615,6 +615,42 @@ protected int comparePivot(int j) { }.select(from, to, k); } + /** + * Reorganize {@code arr[from:to[} so that the elements at the offsets included in {@code k} are at the same position as if + * {@code arr[from:to]} was sorted, and all elements on their left are less than or equal to them, and + * all elements on their right are greater than or equal to them. + * + *

This runs in linear time on average and in {@code n log(n)} time in the worst case. + * + * @param arr Array to be re-organized. + * @param from Starting index for re-organization. Elements before this index will be left as is. + * @param to Ending index. Elements after this index will be left as is. + * @param k Array containing the Indexes of elements to sort from. Values must be less than 'to' and greater than or equal to 'from'. This list will be sorted during the call. + * @param comparator Comparator to use for sorting + */ + public static void multiSelect( + T[] arr, int from, int to, int[] k, Comparator comparator) { + new IntroSelector() { + + T pivot; + + @Override + protected void swap(int i, int j) { + ArrayUtil.swap(arr, i, j); + } + + @Override + protected void setPivot(int i) { + pivot = arr[i]; + } + + @Override + protected int comparePivot(int j) { + return comparator.compare(pivot, arr[j]); + } + }.multiSelect(from, to, k); + } + /** Copies an array into a new array. */ public static byte[] copyArray(byte[] array) { return copyOfSubArray(array, 0, array.length); diff --git a/lucene/core/src/java/org/apache/lucene/util/RadixSelector.java b/lucene/core/src/java/org/apache/lucene/util/RadixSelector.java index 121ff4cd05a5..a04bd96d9757 100644 --- a/lucene/core/src/java/org/apache/lucene/util/RadixSelector.java +++ b/lucene/core/src/java/org/apache/lucene/util/RadixSelector.java @@ -23,7 +23,7 @@ * Radix selector. * *

This implementation works similarly to a MSB radix sort except that it only recurses into the - * sub partition that contains the desired value. + * sub partition that contains the desired value(s). * * @lucene.internal */ From 08196572bfa6dd21063a616ac99fd8d2fecf7c57 Mon Sep 17 00:00:00 2001 From: Houston Putman Date: Tue, 15 Oct 2024 14:53:14 -0500 Subject: [PATCH 4/8] tidy --- .../org/apache/lucene/util/ArrayUtil.java | 9 +++++---- .../org/apache/lucene/util/IntroSelector.java | 11 +++++++---- .../org/apache/lucene/util/RadixSelector.java | 6 ++++-- .../java/org/apache/lucene/util/Selector.java | 19 ++++++++++--------- .../util/quantization/ScalarQuantizer.java | 7 ++++--- .../quantization/TestScalarQuantizer.java | 11 ++++++----- 6 files changed, 36 insertions(+), 27 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/util/ArrayUtil.java b/lucene/core/src/java/org/apache/lucene/util/ArrayUtil.java index 9ddd0a397a53..05be0c3bec66 100644 --- a/lucene/core/src/java/org/apache/lucene/util/ArrayUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/ArrayUtil.java @@ -616,16 +616,17 @@ protected int comparePivot(int j) { } /** - * Reorganize {@code arr[from:to[} so that the elements at the offsets included in {@code k} are at the same position as if - * {@code arr[from:to]} was sorted, and all elements on their left are less than or equal to them, and - * all elements on their right are greater than or equal to them. + * Reorganize {@code arr[from:to[} so that the elements at the offsets included in {@code k} are + * at the same position as if {@code arr[from:to]} was sorted, and all elements on their left are + * less than or equal to them, and all elements on their right are greater than or equal to them. * *

This runs in linear time on average and in {@code n log(n)} time in the worst case. * * @param arr Array to be re-organized. * @param from Starting index for re-organization. Elements before this index will be left as is. * @param to Ending index. Elements after this index will be left as is. - * @param k Array containing the Indexes of elements to sort from. Values must be less than 'to' and greater than or equal to 'from'. This list will be sorted during the call. + * @param k Array containing the Indexes of elements to sort from. Values must be less than 'to' + * and greater than or equal to 'from'. This list will be sorted during the call. * @param comparator Comparator to use for sorting */ public static void multiSelect( diff --git a/lucene/core/src/java/org/apache/lucene/util/IntroSelector.java b/lucene/core/src/java/org/apache/lucene/util/IntroSelector.java index dfb3c4551def..0c847098757b 100644 --- a/lucene/core/src/java/org/apache/lucene/util/IntroSelector.java +++ b/lucene/core/src/java/org/apache/lucene/util/IntroSelector.java @@ -154,7 +154,8 @@ void select(int from, int to, int k, int maxDepth) { // Visible for testing. void multiSelect(int from, int to, int[] k, int kFrom, int kTo, int maxDepth) { - // If there is only 1 k value to select in this group, then use the single-k select method, which does not do recursion + // If there is only 1 k value to select in this group, then use the single-k select method, + // which does not do recursion if (kTo - kFrom == 1) { select(from, to, k[kFrom], maxDepth); return; @@ -240,7 +241,7 @@ void multiSelect(int from, int to, int[] k, int kFrom, int kTo, int maxDepth) { // Select the K values contained in the bottom and top partitions. 
int topKFrom = kTo; int bottomKTo = kFrom; - for (int ki = kTo-1; ki >= kFrom; ki--) { + for (int ki = kTo - 1; ki >= kFrom; ki--) { if (k[ki] >= i) { topKFrom = ki; } @@ -249,11 +250,13 @@ void multiSelect(int from, int to, int[] k, int kFrom, int kTo, int maxDepth) { break; } } - // Recursively select the relevant k-values from the bottom group, if there are any k-values to select there + // Recursively select the relevant k-values from the bottom group, if there are any k-values + // to select there if (bottomKTo > kFrom) { multiSelect(from, j + 1, k, kFrom, bottomKTo, maxDepth); } - // Recursively select the relevant k-values from the top group, if there are any k-values to select there + // Recursively select the relevant k-values from the top group, if there are any k-values to + // select there if (topKFrom < kTo) { multiSelect(i, to, k, topKFrom, kTo, maxDepth); } diff --git a/lucene/core/src/java/org/apache/lucene/util/RadixSelector.java b/lucene/core/src/java/org/apache/lucene/util/RadixSelector.java index a04bd96d9757..16510a1bffcb 100644 --- a/lucene/core/src/java/org/apache/lucene/util/RadixSelector.java +++ b/lucene/core/src/java/org/apache/lucene/util/RadixSelector.java @@ -224,12 +224,14 @@ private void radixMultiSelect(int from, int to, int[] k, int kFrom, int kTo, int } final int bucketTo = bucketFrom + histogram[bucket]; int bucketKTo = bucketKFrom; - // Move the right-side of the k-window up until the k-value is no longer in the current histogram bucket + // Move the right-side of the k-window up until the k-value is no longer in the current + // histogram bucket while (bucketKTo < kTo && k[bucketKTo] < bucketTo) { bucketKTo++; } - // If there are any k-values captured in this histogram, continue down this path with those k-values + // If there are any k-values captured in this histogram, continue down this path with those + // k-values if (bucketKFrom < bucketKTo) { partition(from, to, bucket, bucketFrom, bucketTo, d); diff --git 
a/lucene/core/src/java/org/apache/lucene/util/Selector.java b/lucene/core/src/java/org/apache/lucene/util/Selector.java index b2691075e146..fe7c44aaa890 100644 --- a/lucene/core/src/java/org/apache/lucene/util/Selector.java +++ b/lucene/core/src/java/org/apache/lucene/util/Selector.java @@ -33,22 +33,23 @@ public abstract class Selector { public abstract void select(int from, int to, int k); /** - * Reorder elements so that the elements at all positions in {@code k} are the same as if all elements were - * sorted and all other elements are partitioned around it: {@code [from, k[n])} only contains - * elements that are less than or equal to {@code k[n]} and {@code (k[n], to)} only contains elements - * that are greater than or equal to {@code k[n]}. + * Reorder elements so that the elements at all positions in {@code k} are the same as if all + * elements were sorted and all other elements are partitioned around it: {@code [from, k[n])} + * only contains elements that are less than or equal to {@code k[n]} and {@code (k[n], to)} only + * contains elements that are greater than or equal to {@code k[n]}. */ public void multiSelect(int from, int to, int[] k) { multiSelect(from, to, k, 0, k.length); } /** - * Reorder elements so that the elements at all positions in {@code k} are the same as if all elements were - * sorted and all other elements are partitioned around it: {@code [from, k[n])} only contains - * elements that are less than or equal to {@code k[n]} and {@code (k[n], to)} only contains elements - * that are greater than or equal to {@code k[n]}. + * Reorder elements so that the elements at all positions in {@code k} are the same as if all + * elements were sorted and all other elements are partitioned around it: {@code [from, k[n])} + * only contains elements that are less than or equal to {@code k[n]} and {@code (k[n], to)} only + * contains elements that are greater than or equal to {@code k[n]}. 
* - * The array {@code k} will be sorted, so {@code kFrom} and {@code kTo} must be referring to the sorted order. + *
<p>
The array {@code k} will be sorted, so {@code kFrom} and {@code kTo} must be referring to + * the sorted order. */ public void multiSelect(int from, int to, int[] k, int kFrom, int kTo) { // Default implementation only uses select(), so it is not optimal diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java index 74bb8f5af3e2..3d27992c0457 100644 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java +++ b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java @@ -439,7 +439,8 @@ private static void extractQuantiles( double[] lowerSum) { assert confidenceIntervals.length == upperSum.length && confidenceIntervals.length == lowerSum.length; - float[][] upperAndLowerQuantiles = getUpperAndLowerQuantiles(quantileGatheringScratch, confidenceIntervals); + float[][] upperAndLowerQuantiles = + getUpperAndLowerQuantiles(quantileGatheringScratch, confidenceIntervals); for (int i = 0; i < confidenceIntervals.length; i++) { upperSum[i] += upperAndLowerQuantiles[i][1]; lowerSum[i] += upperAndLowerQuantiles[i][0]; @@ -591,8 +592,8 @@ static float[][] getUpperAndLowerQuantiles(float[] arr, float[] confidenceInterv // After the selection process, pick out the given quantile values for (int i = 0; i < confidenceIntervals.length; i++) { - minAndMaxPerInterval[i][0] = arr[selectorIndexes[2*i]]; - minAndMaxPerInterval[i][1] = arr[selectorIndexes[2*i + 1]]; + minAndMaxPerInterval[i][0] = arr[selectorIndexes[2 * i]]; + minAndMaxPerInterval[i][1] = arr[selectorIndexes[2 * i + 1]]; } return minAndMaxPerInterval; } diff --git a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java index 8f7a8c66f198..9bc6ed5823b7 100644 --- a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java +++ 
b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java @@ -125,22 +125,23 @@ public void testQuantiles() { percs[i] = (float) i; } shuffleArray(percs); - float[][] upperAndLower = ScalarQuantizer.getUpperAndLowerQuantiles(percs, new float[]{0.9f}); + float[][] upperAndLower = ScalarQuantizer.getUpperAndLowerQuantiles(percs, new float[] {0.9f}); assertEquals(50f, upperAndLower[0][0], 1e-7); assertEquals(949f, upperAndLower[0][1], 1e-7); shuffleArray(percs); - upperAndLower = ScalarQuantizer.getUpperAndLowerQuantiles(percs, new float[]{0.95f}); + upperAndLower = ScalarQuantizer.getUpperAndLowerQuantiles(percs, new float[] {0.95f}); assertEquals(25f, upperAndLower[0][0], 1e-7); assertEquals(974f, upperAndLower[0][1], 1e-7); shuffleArray(percs); - upperAndLower = ScalarQuantizer.getUpperAndLowerQuantiles(percs, new float[]{0.99f}); + upperAndLower = ScalarQuantizer.getUpperAndLowerQuantiles(percs, new float[] {0.99f}); assertEquals(5f, upperAndLower[0][0], 1e-7); assertEquals(994f, upperAndLower[0][1], 1e-7); } public void testEdgeCase() { float[][] upperAndLower = - ScalarQuantizer.getUpperAndLowerQuantiles(new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, new float[]{0.9f}); + ScalarQuantizer.getUpperAndLowerQuantiles( + new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, new float[] {0.9f}); assertEquals(1f, upperAndLower[0][0], 1e-7f); assertEquals(1f, upperAndLower[0][1], 1e-7f); } @@ -194,7 +195,7 @@ public void testScalarWithSampling() throws IOException { public void testFromVectorsAutoInterval4Bit() throws IOException { int dims = 128; - int numVecs = 100; + int numVecs = 1000; VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; float[][] floats = randomFloats(numVecs, dims); From cd7bcb35581c25d39e3ec910ea594253b37e558d Mon Sep 17 00:00:00 2001 From: Houston Putman Date: Tue, 15 Oct 2024 15:02:58 -0500 Subject: [PATCH 5/8] Undo test change --- .../apache/lucene/util/quantization/TestScalarQuantizer.java | 2 +- 1 
file changed, 1 insertion(+), 1 deletion(-) diff --git a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java index 9bc6ed5823b7..ad583e8e0728 100644 --- a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java +++ b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java @@ -195,7 +195,7 @@ public void testScalarWithSampling() throws IOException { public void testFromVectorsAutoInterval4Bit() throws IOException { int dims = 128; - int numVecs = 1000; + int numVecs = 100; VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; float[][] floats = randomFloats(numVecs, dims); From b30e96749dc3f63abacf2749c704b211b4b0661b Mon Sep 17 00:00:00 2001 From: Houston Putman Date: Tue, 15 Oct 2024 16:30:46 -0500 Subject: [PATCH 6/8] Small refactor, less checks. Copy k as to keep order of array for caller --- .../org/apache/lucene/util/IntroSelector.java | 3 +-- .../org/apache/lucene/util/RadixSelector.java | 3 +-- .../java/org/apache/lucene/util/Selector.java | 22 +++++++++---------- 3 files changed, 12 insertions(+), 16 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/util/IntroSelector.java b/lucene/core/src/java/org/apache/lucene/util/IntroSelector.java index 0c847098757b..10a45869ca08 100644 --- a/lucene/core/src/java/org/apache/lucene/util/IntroSelector.java +++ b/lucene/core/src/java/org/apache/lucene/util/IntroSelector.java @@ -46,8 +46,7 @@ public final void select(int from, int to, int k) { } @Override - public final void multiSelect(int from, int to, int[] k, int kFrom, int kTo) { - checkMultiArgs(from, to, k, kFrom, kTo); + protected final void multiSelect(int from, int to, int[] k, int kFrom, int kTo) { multiSelect(from, to, k, kFrom, kTo, 2 * MathUtil.log(to - from, 2)); } diff --git a/lucene/core/src/java/org/apache/lucene/util/RadixSelector.java 
b/lucene/core/src/java/org/apache/lucene/util/RadixSelector.java index 16510a1bffcb..fd58bc050cbb 100644 --- a/lucene/core/src/java/org/apache/lucene/util/RadixSelector.java +++ b/lucene/core/src/java/org/apache/lucene/util/RadixSelector.java @@ -126,8 +126,7 @@ public void select(int from, int to, int k) { } @Override - public void multiSelect(int from, int to, int[] k, int kFrom, int kTo) { - checkMultiArgs(from, to, k, kFrom, kTo); + protected void multiSelect(int from, int to, int[] k, int kFrom, int kTo) { multiSelect(from, to, k, kFrom, kTo, 0, 0); } diff --git a/lucene/core/src/java/org/apache/lucene/util/Selector.java b/lucene/core/src/java/org/apache/lucene/util/Selector.java index fe7c44aaa890..c14e5c1c197b 100644 --- a/lucene/core/src/java/org/apache/lucene/util/Selector.java +++ b/lucene/core/src/java/org/apache/lucene/util/Selector.java @@ -39,6 +39,9 @@ public abstract class Selector { * contains elements that are greater than or equal to {@code k[n]}. */ public void multiSelect(int from, int to, int[] k) { + // k needs to be sorted, so copy the array + k = Arrays.copyOf(k, k.length); + checkMultiArgs(from, to, k); multiSelect(from, to, k, 0, k.length); } @@ -48,12 +51,11 @@ public void multiSelect(int from, int to, int[] k) { * only contains elements that are less than or equal to {@code k[n]} and {@code (k[n], to)} only * contains elements that are greater than or equal to {@code k[n]}. * - *
<p>
The array {@code k} will be sorted, so {@code kFrom} and {@code kTo} must be referring to + *
<p>
The array {@code k} must be sorted, and {@code kFrom} and {@code kTo} must be referring to * the sorted order. */ - public void multiSelect(int from, int to, int[] k, int kFrom, int kTo) { + protected void multiSelect(int from, int to, int[] k, int kFrom, int kTo) { // Default implementation only uses select(), so it is not optimal - checkMultiArgs(from, to, k, kFrom, kTo); int nextFrom = from; for (int i = kFrom; i < kTo; i++) { int currentK = k[i]; @@ -75,18 +77,14 @@ void checkArgs(int from, int to, int k) { } } - void checkMultiArgs(int from, int to, int[] k, int kFrom, int kTo) { - if (kFrom < 0) { - throw new IllegalArgumentException("kFrom must be >= 0"); + void checkMultiArgs(int from, int to, int[] k) { + if (k.length == 0) { + throw new IllegalArgumentException("k must not be empty"); } - if (kTo > k.length) { - throw new IllegalArgumentException("kFrom must be <= k.length"); - } - Arrays.sort(k); - if (k[kFrom] < from) { + if (k[0] < from) { throw new IllegalArgumentException("All k must be >= from"); } - if (k[kTo - 1] >= to) { + if (k[k.length - 1] >= to) { throw new IllegalArgumentException("All k must be < to"); } } From d1e021055fb7b68e7711eb9d272b02fcc7115786 Mon Sep 17 00:00:00 2001 From: Houston Putman Date: Tue, 15 Oct 2024 16:43:58 -0500 Subject: [PATCH 7/8] Fix bad api --- lucene/core/src/java/org/apache/lucene/util/Selector.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/util/Selector.java b/lucene/core/src/java/org/apache/lucene/util/Selector.java index c14e5c1c197b..dfbe19b42356 100644 --- a/lucene/core/src/java/org/apache/lucene/util/Selector.java +++ b/lucene/core/src/java/org/apache/lucene/util/Selector.java @@ -16,8 +16,6 @@ */ package org.apache.lucene.util; -import java.util.Arrays; - /** * An implementation of a selection algorithm, ie. computing the k-th greatest value from a * collection. 
@@ -40,7 +38,7 @@ public abstract class Selector { */ public void multiSelect(int from, int to, int[] k) { // k needs to be sorted, so copy the array - k = Arrays.copyOf(k, k.length); + k = ArrayUtil.copyArray(k); checkMultiArgs(from, to, k); multiSelect(from, to, k, 0, k.length); } From 48c3f3548e87004aa33d07f61b98c273805dd83c Mon Sep 17 00:00:00 2001 From: Houston Putman Date: Tue, 15 Oct 2024 17:08:57 -0500 Subject: [PATCH 8/8] Add in the sort that was removed accidentally --- lucene/core/src/java/org/apache/lucene/util/Selector.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lucene/core/src/java/org/apache/lucene/util/Selector.java b/lucene/core/src/java/org/apache/lucene/util/Selector.java index dfbe19b42356..5a05400af590 100644 --- a/lucene/core/src/java/org/apache/lucene/util/Selector.java +++ b/lucene/core/src/java/org/apache/lucene/util/Selector.java @@ -16,6 +16,8 @@ */ package org.apache.lucene.util; +import java.util.Arrays; + /** * An implementation of a selection algorithm, ie. computing the k-th greatest value from a * collection. @@ -39,6 +41,7 @@ public abstract class Selector { public void multiSelect(int from, int to, int[] k) { // k needs to be sorted, so copy the array k = ArrayUtil.copyArray(k); + Arrays.sort(k); checkMultiArgs(from, to, k); multiSelect(from, to, k, 0, k.length); }