Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce multiSelect for ScalarQuantizer #13919

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions lucene/core/src/java/org/apache/lucene/util/ArrayUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,43 @@ protected int comparePivot(int j) {
}.select(from, to, k);
}

/**
 * Partially orders {@code arr[from:to[} so that every offset listed in {@code k} holds the element
 * it would hold if {@code arr[from:to[} were fully sorted. For each such offset, all elements on
 * its left are less than or equal to it, and all elements on its right are greater than or equal
 * to it.
 *
 * <p>This runs in linear time on average and in {@code n log(n)} time in the worst case.
 *
 * @param arr Array to be re-organized.
 * @param from Starting index (inclusive). Elements before this index will be left as is.
 * @param to Ending index (exclusive). Elements at or after this index will be left as is.
 * @param k Offsets of the elements to select. Every value must be greater than or equal to
 *     {@code from} and less than {@code to}.
 * @param comparator Comparator used to order the elements.
 */
public static <T> void multiSelect(
    T[] arr, int from, int to, int[] k, Comparator<? super T> comparator) {
  final IntroSelector selector =
      new IntroSelector() {

        // Value of the element most recently chosen as the pivot.
        T pivotValue;

        @Override
        protected void swap(int a, int b) {
          ArrayUtil.swap(arr, a, b);
        }

        @Override
        protected void setPivot(int index) {
          pivotValue = arr[index];
        }

        @Override
        protected int comparePivot(int index) {
          return comparator.compare(pivotValue, arr[index]);
        }
      };
  selector.multiSelect(from, to, k);
}

/** Copies an array into a new array. */
public static byte[] copyArray(byte[] array) {
return copyOfSubArray(array, 0, array.length);
Expand Down
128 changes: 128 additions & 0 deletions lucene/core/src/java/org/apache/lucene/util/IntroSelector.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ public final void select(int from, int to, int k) {
select(from, to, k, 2 * MathUtil.log(to - from, 2));
}

/**
 * Entry point used by the public multi-k selection: selects the sorted positions
 * {@code k[kFrom:kTo[} within {@code [from, to)}.
 */
@Override
protected final void multiSelect(int from, int to, int[] k, int kFrom, int kTo) {
  // Same depth budget as the single-k select(): twice the base-2 log of the range size.
  final int maxDepth = 2 * MathUtil.log(to - from, 2);
  multiSelect(from, to, k, kFrom, kTo, maxDepth);
}

// Visible for testing.
void select(int from, int to, int k, int maxDepth) {
// This code is inspired from IntroSorter#sort, adapted to loop on a single partition.
Expand Down Expand Up @@ -146,6 +151,129 @@ void select(int from, int to, int k, int maxDepth) {
}
}

// Selects every position in k[kFrom:kTo[ within [from, to). The k-values are expected to be
// sorted (see the contract documented on Selector#multiSelect). The range is partitioned once
// around a pivot, then the method recurses into the lower and upper partitions, handing each one
// only the k-values that fall inside it.
//
// maxDepth is the remaining depth budget; when it is used up the range is shuffled once.
//
// Visible for testing.
void multiSelect(int from, int to, int[] k, int kFrom, int kTo, int maxDepth) {
// If there is only 1 k value to select in this group, then use the single-k select method,
// which does not do recursion
if (kTo - kFrom == 1) {
select(from, to, k[kFrom], maxDepth);
return;
}

// This code is inspired from IntroSorter#sort, adapted to partition once and then recurse into
// the partitions that still contain k-values.

// For efficiency, we must enter the partitioning with at least 4 entries to be able to skip
// some boundary tests during the 3-way partitioning.
int size;
if ((size = to - from) > 3) {

if (--maxDepth == -1) {
// Max recursion depth exceeded: shuffle (only once) and continue.
// maxDepth keeps decreasing in deeper calls, so it equals -1 at most once per call chain.
shuffle(from, to);
}

// Pivot selection based on medians.
int last = to - 1;
int mid = (from + last) >>> 1;
int pivot;
if (size <= IntroSorter.SINGLE_MEDIAN_THRESHOLD) {
// Select the pivot with a single median around the middle element.
// Do not take the median between [from, mid, last] because it hurts performance
// if the order is descending in conjunction with the 3-way partitioning.
int range = size >> 2;
pivot = median(mid - range, mid, mid + range);
} else {
// Select the pivot with a variant of the Tukey's ninther median of medians.
// If the middle k-value is close to the boundaries, select either the lowest or highest
// median (this variant is inspired from the interpolation search).
int range = size >> 3;
int doubleRange = range << 1;
int medianFirst = median(from, from + range, from + doubleRange);
int medianMiddle = median(mid - range, mid, mid + range);
int medianLast = median(last - doubleRange, last - range, last);
// Use the middle entry of the (sorted) k window as the representative k-value.
int middleK = k[(kFrom + kTo - 1) >> 1];
if (middleK - from < range) {
// k is close to 'from': select the lowest median.
pivot = min(medianFirst, medianMiddle, medianLast);
} else if (to - middleK <= range) {
// k is close to 'to': select the highest median.
pivot = max(medianFirst, medianMiddle, medianLast);
} else {
// Otherwise select the median of medians.
pivot = median(medianFirst, medianMiddle, medianLast);
}
}

// Bentley-McIlroy 3-way partitioning.
// The pivot is moved to 'from'. Entries that compare equal to the pivot are first parked at
// both ends of the range (indices below p and above q) and moved to the middle afterwards.
setPivot(pivot);
swap(from, pivot);
int i = from;
int j = to;
int p = from + 1;
int q = last;
while (true) {
int leftCmp, rightCmp;
while ((leftCmp = comparePivot(++i)) > 0) {}
while ((rightCmp = comparePivot(--j)) < 0) {}
if (i >= j) {
if (i == j && rightCmp == 0) {
swap(i, p);
}
break;
}
swap(i, j);
if (rightCmp == 0) {
swap(i, p++);
}
if (leftCmp == 0) {
swap(j, q--);
}
}
// Move the parked pivot-equal entries from both ends into the middle. Afterwards j is the last
// index of the lower partition and i is the first index of the upper partition.
i = j + 1;
for (int l = from; l < p; ) {
swap(l++, j--);
}
for (int l = last; l > q; ) {
swap(l--, i++);
}

// Select the K values contained in the bottom and top partitions.
// k is sorted, so scan it from the top: every value >= i belongs to the top partition, and the
// first value <= j ends the scan because all remaining values belong to the bottom partition.
// Values strictly between j and i are not recursed into.
int topKFrom = kTo;
int bottomKTo = kFrom;
for (int ki = kTo - 1; ki >= kFrom; ki--) {
if (k[ki] >= i) {
topKFrom = ki;
}
if (k[ki] <= j) {
bottomKTo = ki + 1;
break;
}
}
// Recursively select the relevant k-values from the bottom group, if there are any k-values
// to select there
if (bottomKTo > kFrom) {
multiSelect(from, j + 1, k, kFrom, bottomKTo, maxDepth);
}
// Recursively select the relevant k-values from the top group, if there are any k-values to
// select there
if (topKFrom < kTo) {
multiSelect(i, to, k, topKFrom, kTo, maxDepth);
}
}

// Sort the final tiny range (3 entries or less) with a very specialized sort.
// 'size' still holds the size measured on entry, so no case matches after the partitioning
// branch above has run (size > 3).
switch (size) {
case 2:
if (compare(from, from + 1) > 0) {
swap(from, from + 1);
}
break;
case 3:
sort3(from);
break;
}
}

/** Returns the index of the min element among three elements at provided indices. */
private int min(int i, int j, int k) {
if (compare(i, j) <= 0) {
Expand Down
81 changes: 80 additions & 1 deletion lucene/core/src/java/org/apache/lucene/util/RadixSelector.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
*/
package org.apache.lucene.util;

import java.util.ArrayList;
import java.util.Arrays;

/**
* Radix selector.
*
* <p>This implementation works similarly to a MSB radix sort except that it only recurses into the
* sub partition that contains the desired value.
* sub partition that contains the desired value(s).
*
* @lucene.internal
*/
Expand Down Expand Up @@ -124,6 +125,11 @@ public void select(int from, int to, int k) {
select(from, to, k, 0, 0);
}

/**
 * Selects the sorted positions {@code k[kFrom:kTo[} within {@code [from, to)}, starting the radix
 * selection from the first character and the top recursion level.
 */
@Override
protected void multiSelect(int from, int to, int[] k, int kFrom, int kTo) {
  final int firstChar = 0; // character number to compare first
  final int topLevel = 0; // initial level of recursion
  multiSelect(from, to, k, kFrom, kTo, firstChar, topLevel);
}

private void select(int from, int to, int k, int d, int l) {
if (to - from <= LENGTH_THRESHOLD || l >= LEVEL_THRESHOLD) {
getFallbackSelector(d).select(from, to, k);
Expand All @@ -132,6 +138,22 @@ private void select(int from, int to, int k, int d, int l) {
}
}

/**
 * Routes a selection of {@code k[kFrom:kTo[} over {@code [from, to)} to the right strategy: the
 * fallback selector when the range is small enough or the recursion level is too high, the radix
 * implementation otherwise. A window holding exactly one k-value uses the single-k variant.
 *
 * @param d the character number to compare
 * @param l the level of recursion
 */
private void multiSelect(int from, int to, int[] k, int kFrom, int kTo, int d, int l) {
  final boolean singleK = kTo - kFrom == 1;
  final boolean useFallback = to - from <= LENGTH_THRESHOLD || l >= LEVEL_THRESHOLD;

  if (useFallback && singleK) {
    getFallbackSelector(d).select(from, to, k[kFrom]);
  } else if (useFallback) {
    getFallbackSelector(d).multiSelect(from, to, k, kFrom, kTo);
  } else if (singleK) {
    radixSelect(from, to, k[kFrom], d, l);
  } else {
    radixMultiSelect(from, to, k, kFrom, kTo, d, l);
  }
}

/**
* @param d the character number to compare
* @param l the level of recursion
Expand Down Expand Up @@ -171,6 +193,63 @@ private void radixSelect(int from, int to, int k, int d, int l) {
throw new AssertionError("Unreachable code");
}

/**
 * Radix step for several k-values at once: builds the histogram for character {@code d},
 * partitions every bucket that contains at least one k-value of {@code k[kFrom:kTo[}, and then
 * continues the selection inside those buckets on the next character.
 *
 * @param d the character number to compare
 * @param l the level of recursion
 */
private void radixMultiSelect(int from, int to, int[] k, int kFrom, int kTo, int d, int l) {
  final int[] counts = this.histogram;
  Arrays.fill(counts, 0);

  final int sharedPrefix = computeCommonPrefixLengthAndBuildHistogram(from, to, d, counts);
  if (sharedPrefix > 0) {
    // Nothing left to do when there are no more chars to compare, or when all entries fell into
    // the first bucket (which means strings are shorter than d). Otherwise skip the shared
    // prefix and retry at the same recursion level.
    final boolean moreChars = d + sharedPrefix < maxLength;
    if (moreChars && counts[0] < to - from) {
      radixMultiSelect(from, to, k, kFrom, kTo, d + sharedPrefix, l);
    }
    return;
  }
  assert assertHistogram(sharedPrefix, counts);

  int bucketStart = from;
  int kCursor = kFrom;
  final ArrayList<Bucket> pending = new ArrayList<>(kTo - kFrom);
  for (int bucket = 0; bucket < HISTOGRAM_SIZE && kCursor < kTo; ++bucket) {
    final int count = counts[bucket];
    if (count != 0) {
      final int bucketEnd = bucketStart + count;

      // Advance the right side of the k-window past every k-value that lands in this bucket.
      int kEnd = kCursor;
      for (; kEnd < kTo && k[kEnd] < bucketEnd; kEnd++) {}

      // Only buckets that captured at least one k-value need any work.
      if (kCursor < kEnd) {
        partition(from, to, bucket, bucketStart, bucketEnd, d);

        // All elements in bucket 0 are equal, so only other buckets need a deeper pass.
        if (bucket != 0 && d + 1 < maxLength) {
          // Defer the recursion until after the loop so the histogram is not overwritten.
          pending.add(new Bucket(bucketStart, bucketEnd, kCursor, kEnd));
        }
      }
      bucketStart = bucketEnd;
      kCursor = kEnd;
    }
  }

  for (int idx = 0; idx < pending.size(); idx++) {
    final Bucket b = pending.get(idx);
    multiSelect(b.from, b.to, k, b.kFrom, b.kTo, d + 1, l + 1);
  }
}

/**
 * A histogram bucket queued by {@code radixMultiSelect} for a deeper pass: the entry range
 * {@code [from, to)} together with the window {@code [kFrom, kTo)} of k-values that fall in it.
 */
private record Bucket(int from, int to, int kFrom, int kTo) {}

// only used from assert
private boolean assertHistogram(int commonPrefixLength, int[] histogram) {
int numberOfUniqueBytes = 0;
Expand Down
51 changes: 51 additions & 0 deletions lucene/core/src/java/org/apache/lucene/util/Selector.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
package org.apache.lucene.util;

import java.util.Arrays;

/**
* An implementation of a selection algorithm, ie. computing the k-th greatest value from a
* collection.
Expand All @@ -30,6 +32,43 @@ public abstract class Selector {
*/
public abstract void select(int from, int to, int k);

/**
 * Reorder elements of {@code [from, to)} so that every position listed in {@code k} holds the
 * element it would hold if all elements were sorted, with the other elements partitioned around
 * it: for each position {@code k[n]}, elements in {@code [from, k[n])} are less than or equal to
 * the element at {@code k[n]} and elements in {@code (k[n], to)} are greater than or equal to it.
 *
 * @throws IllegalArgumentException if {@code k} is empty, or if any of its values is less than
 *     {@code from} or not less than {@code to}
 */
public void multiSelect(int from, int to, int[] k) {
  // The selection needs the positions in ascending order, so sort a copy of the array.
  final int[] sortedK = ArrayUtil.copyArray(k);
  Arrays.sort(sortedK);
  checkMultiArgs(from, to, sortedK);
  multiSelect(from, to, sortedK, 0, sortedK.length);
}

/**
 * Reorder elements of {@code [from, to)} so that every position in {@code k[kFrom:kTo[} holds the
 * element it would hold if all elements were sorted, with the other elements partitioned around
 * it: for each such position {@code k[n]}, elements in {@code [from, k[n])} are less than or
 * equal to the element at {@code k[n]} and elements in {@code (k[n], to)} are greater than or
 * equal to it.
 *
 * <p>The array {@code k} must be sorted, and {@code kFrom} and {@code kTo} must be referring to
 * the sorted order.
 */
protected void multiSelect(int from, int to, int[] k, int kFrom, int kTo) {
  // Default implementation: one select() per k-value, each on the range to the right of the
  // previously selected position. Simple, but not optimal.
  int lowerBound = from;
  int idx = kFrom;
  while (idx < kTo) {
    final int target = k[idx++];
    // A target below the current lower bound duplicates a position that is already selected.
    if (target >= lowerBound) {
      select(lowerBound, to, target);
      lowerBound = target + 1;
    }
  }
}

void checkArgs(int from, int to, int k) {
if (k < from) {
throw new IllegalArgumentException("k must be >= from");
Expand All @@ -39,6 +78,18 @@ void checkArgs(int from, int to, int k) {
}
}

/**
 * Validates the arguments of a multi-k selection. Only the first and last entries of {@code k}
 * are inspected, so {@code k} is expected to be sorted in ascending order.
 *
 * @throws IllegalArgumentException if {@code k} is empty, if its first value is less than {@code
 *     from}, or if its last value is not less than {@code to}
 */
void checkMultiArgs(int from, int to, int[] k) {
  final int count = k.length;
  if (count == 0) {
    throw new IllegalArgumentException("k must not be empty");
  }
  final int smallest = k[0];
  if (smallest < from) {
    throw new IllegalArgumentException("All k must be >= from");
  }
  final int largest = k[count - 1];
  if (largest >= to) {
    throw new IllegalArgumentException("All k must be < to");
  }
}

/**
 * Swap values at slots <code>i</code> and <code>j</code>. Implementations supply this so that the
 * selection code can reorder the underlying data.
 */
protected abstract void swap(int i, int j);
}
Loading