Quick sort: Replace recursion with custom stack, small improvements #84

Open: wants to merge 1 commit into base: main
250 changes: 149 additions & 101 deletions sort.h
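The gist of the change: QUICK_SORT_RECURSIVE no longer calls itself. It keeps a fixed-size array of pending ranges (sizeof(size_t) * 8 entries, one per bit of size_t) and loops, pushing one side of each partition onto that array and continuing with the other side. Below is a minimal, int-only sketch of that scheme under stated assumptions: the names quick_sort_ints and partition_ints are hypothetical and not part of sort.h, and the sketch leaves out the median-of-three pivot selection, SMALL_SORT and HEAP_SORT fallbacks that the diff keeps.

```c
#include <stddef.h>

/* Lomuto partition over the half-open range [lo, hi); the pivot is taken
 * from the last element and ends up at the returned index. */
static size_t partition_ints(int *a, size_t lo, size_t hi) {
  int pivot = a[hi - 1];
  size_t store = lo;
  size_t i;

  for (i = lo; i + 1 < hi; i++) {
    if (a[i] < pivot) {
      int tmp = a[i];
      a[i] = a[store];
      a[store] = tmp;
      store++;
    }
  }

  a[hi - 1] = a[store];
  a[store] = pivot;
  return store;
}

void quick_sort_ints(int *a, size_t n) {
  /* One slot per bit of size_t is enough: the range we keep looping on is
   * always the smaller side, so it at least halves after every push and
   * at most ~log2(n) ranges are ever pending at once. */
  struct range { size_t lo, hi; } stack[sizeof(size_t) * 8];
  size_t top = 0;
  size_t lo = 0;
  size_t hi = n;

  for (;;) {
    if (hi - lo < 2) {
      if (top == 0) {
        return;             /* no pending ranges left */
      }

      --top;                /* pop the next pending range */
      lo = stack[top].lo;
      hi = stack[top].hi;
      continue;
    }

    {
      size_t p = partition_ints(a, lo, hi);

      /* Push the larger side and keep iterating on the smaller one. */
      if (p - lo > hi - (p + 1)) {
        stack[top].lo = lo;
        stack[top].hi = p;
        ++top;
        lo = p + 1;
      } else {
        stack[top].lo = p + 1;
        stack[top].hi = hi;
        ++top;
        hi = p;
      }
    }
  }
}
```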
@@ -167,7 +167,8 @@ static __inline size_t rbnd(size_t len) {
#define TIM_SORT_MERGE SORT_MAKE_STR(tim_sort_merge)
#define TIM_SORT_COLLAPSE SORT_MAKE_STR(tim_sort_collapse)
#define HEAP_SORT SORT_MAKE_STR(heap_sort)
#define MEDIAN SORT_MAKE_STR(median)
#define MEDIAN3 SORT_MAKE_STR(median3)
#define MEDIAN3_AND_PREP_PARTITION SORT_MAKE_STR(median3_and_prep_partition)
#define QUICK_SORT SORT_MAKE_STR(quick_sort)
#define MERGE_SORT SORT_MAKE_STR(merge_sort)
#define MERGE_SORT_RECURSIVE SORT_MAKE_STR(merge_sort_recursive)
@@ -1276,33 +1277,38 @@ void MERGE_SORT(SORT_TYPE *dst, const size_t size) {
}


static __inline size_t QUICK_SORT_PARTITION(SORT_TYPE *dst, const size_t left,
const size_t right, const size_t pivot) {
SORT_TYPE value = dst[pivot];
size_t index = left;
size_t i;
int not_all_same = 0;
/* move the pivot to the right */
SORT_SWAP(dst[pivot], dst[right]);
static __inline SORT_TYPE * QUICK_SORT_PARTITION(SORT_TYPE *lo,
SORT_TYPE *hi,
SORT_TYPE *piv) {
SORT_TYPE *index;
SORT_TYPE *it;
SORT_TYPE value;
int not_all_same = 0;

value = *hi;
index = lo;
it = lo;

for (i = left; i < right; i++) {
int cmp = SORT_CMP(dst[i], value);
/* check if everything is all the same */
for (; it < hi; ++it) {
int cmp = SORT_CMP(*it, value);
/* check if everything is all the same. NB: this check costs upwards */
/* of 10% performance in other cases */
not_all_same |= cmp;

if (cmp < 0) {
SORT_SWAP(dst[i], dst[index]);
SORT_SWAP(*it, *index);
index++;
}
}

SORT_SWAP(dst[right], dst[index]);

/* avoid degenerate case */
if (not_all_same == 0) {
return SIZE_MAX;
return NULL;
}

*hi = *index;
*index = value;

return index;
}
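The partition now works on raw pointers and reports a degenerate all-equal range by returning NULL instead of SIZE_MAX, so the caller can skip the range entirely. A self-contained int sketch of that contract (hypothetical names; a plain three-way `<`/`>` comparison stands in for SORT_CMP):

```c
#include <stddef.h>
#include <stdio.h>

/* Partition [lo, hi] around the pivot stored at *hi; returns a pointer to
 * the pivot's final position, or NULL when every element equals the pivot. */
static int *partition_or_null(int *lo, int *hi) {
  int pivot = *hi;
  int *store = lo;
  int *it;
  int not_all_same = 0;

  for (it = lo; it < hi; ++it) {
    int cmp = (*it > pivot) - (*it < pivot);  /* -1, 0 or +1 */
    not_all_same |= cmp;                      /* stays 0 only if all equal */

    if (cmp < 0) {
      int tmp = *it;
      *it = *store;
      *store = tmp;
      ++store;
    }
  }

  if (not_all_same == 0) {
    return NULL;          /* degenerate range: nothing left to split */
  }

  *hi = *store;           /* move the displaced element to the end ... */
  *store = pivot;         /* ... and drop the pivot into its final slot */
  return store;
}

int main(void) {
  int mixed[] = { 3, 1, 4, 1, 5, 2 };
  int same[] = { 7, 7, 7, 7 };
  int *p = partition_or_null(mixed, mixed + 5);

  printf("pivot 2 lands at index %td\n", p - mixed);           /* index 2 */
  printf("constant range: %s\n",
         partition_or_null(same, same + 3) ? "split" : "skipped");
  return 0;
}
```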

@@ -1338,115 +1344,156 @@ static __inline size_t QUICK_SORT_HOARE_PARTITION(SORT_TYPE *dst, const size_t l
*/


/* Return the median index of the objects at the three indices. */
static __inline size_t MEDIAN(const SORT_TYPE *dst, const size_t a, const size_t b,
const size_t c) {
const int AB = SORT_CMP(dst[a], dst[b]) < 0;

if (AB) {
/* a < b */
const int BC = SORT_CMP(dst[b], dst[c]) < 0;

if (BC) {
/* a < b < c */
return b;
} else {
/* a < b, c < b */
const int AC = SORT_CMP(dst[a], dst[c]) < 0;

if (AC) {
/* a < c < b */
return c;
} else {
/* c < a < b */
return a;
}
/* Order the three values so that the median ends up at *(c - 1), with */
/* the smallest at *a and the largest at *c, in preparation for */
/* PARTITION (which expects its pivot at the end of the range). */
static __inline void MEDIAN3_AND_PREP_PARTITION(SORT_TYPE *a,
SORT_TYPE *b,
SORT_TYPE *c) {
SORT_CSWAP(*a, *b);
if (SORT_CMP(*b, *c) > 0) {
SORT_SWAP(*b, *c);
if (SORT_CMP(*a, *b) > 0) {
/* Prepare for partition */
SORT_TYPE val = *a;
*a = *b;
*b = *(c - 1);
*(c - 1) = val;
return;
}
} else {
/* b < a */
const int AC = SORT_CMP(dst[a], dst[b]) < 0;
}
/* prepare for partition */
SORT_SWAP(*b, *(c - 1));
}

if (AC) {
/* b < a < c */
return a;
} else {
/* b < a, c < a */
const int BC = SORT_CMP(dst[b], dst[c]) < 0;

if (BC) {
/* b < c < a */
return c;
} else {
/* c < b < a */
return b;
}
/* Return a pointer to the median of the three pointed-to values. */
static __inline SORT_TYPE * MEDIAN3(SORT_TYPE *a, SORT_TYPE *b,
SORT_TYPE *c) {
if (SORT_CMP(*a, *b) > 0) {
SORT_TYPE * tmp = a;
a = b;
b = tmp;
}

/* now *a <= *b */
if (SORT_CMP(*b, *c) > 0) {
if (SORT_CMP(*a, *c) > 0) {
/* *c < *a <= *b */
return a;
}

/* *a <= *c < *b */
return c;
}

/* *a <= *b <= *c */
return b;
}
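For reference, a compact int-only version of the same median-of-three selection (the helper name median3_ints is hypothetical): after making sure *a <= *b, the median is *b unless *b is the largest of the three, in which case it is whichever of *a and *c is larger.

```c
/* Return a pointer to the median of the three pointed-to values.
 * e.g. int v[] = { 9, 2, 5 }; median3_ints(&v[0], &v[1], &v[2])
 * points at v[2] (the 5). */
static int *median3_ints(int *a, int *b, int *c) {
  if (*a > *b) {
    int *tmp = a;
    a = b;
    b = tmp;
  }

  /* now *a <= *b */
  if (*b > *c) {
    /* *b is the largest, so the median is the larger of *a and *c */
    return (*a > *c) ? a : c;
  }

  /* *a <= *b <= *c */
  return b;
}
```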

static void QUICK_SORT_RECURSIVE(SORT_TYPE *dst, const size_t original_left,
const size_t original_right) {
size_t left;
size_t right;
size_t pivot;
size_t new_pivot;
size_t middle;
int loop_count = 0;
const int max_loops = 64 - CLZ(original_right - original_left); /* ~lg N */
left = original_left;
right = original_right;
static void QUICK_SORT_RECURSIVE(SORT_TYPE *dst, size_t size) {
/* State we need between "recursive" calls */
typedef struct range_stack {
SORT_TYPE *begin;
SORT_TYPE *end;
} range_stack_t;

while (1) {
if (right <= left) {
return;
}
/* Maximum "recursive" depth on this machine */
range_stack_t ranges_base[sizeof(size_t) * 8];
range_stack_t *cur_range = &ranges_base[0];
int loop_count;

if ((right - left + 1U) <= SMALL_SORT_BND) {
SMALL_SORT(&dst[left], right - left + 1U);
return;
}
SORT_TYPE *lo;
SORT_TYPE *hi;

if (++loop_count >= max_loops) {
/* we have recursed / looped too many times; switch to heap sort */
HEAP_SORT(&dst[left], right - left + 1U);
return;
}
loop_count = 64 - CLZ(size);
lo = dst;
hi = dst + size;

/* median of 5 */
middle = left + ((right - left) >> 1);
pivot = MEDIAN((const SORT_TYPE *) dst, left, middle, right);
pivot = MEDIAN((const SORT_TYPE *) dst, left + ((middle - left) >> 1), pivot,
middle + ((right - middle) >> 1));
new_pivot = QUICK_SORT_PARTITION(dst, left, right, pivot);
for (;;) {
SORT_TYPE *piv;
SORT_TYPE *new_piv;
size_t remaining_lo;
size_t remaining_hi;
size_t step;

/* median of 5 */
step = (hi - lo) >> 2;
piv = MEDIAN3(lo + step, lo + step * 2, hi - step);
MEDIAN3_AND_PREP_PARTITION(lo, piv, hi);

/* MEDIAN3_AND_PREP_PARTITION ensures the values at lo/hi are on the */
/* correct side of the pivot, so start at lo + 1 / hi - 1. This also */
/* allows us to guarantee that lo < new_piv < hi, which helps us */
/* optimize the conditions for pushing ranges to the stack */
new_piv = QUICK_SORT_PARTITION(lo + 1, hi - 1, piv);
/* check for partition all equal */
if (new_pivot == SIZE_MAX) {
if (new_piv == NULL) {
/* Pop the next range from the stack. It can't be smaller than */
/* SMALL_SORT_BND */
if (cur_range > ranges_base) {
--cur_range;
/* Reset loop count for this range. */
loop_count = 64 - CLZ(size);
lo = cur_range->begin;
hi = cur_range->end;
continue;
}
return;
}

remaining_hi = hi - new_piv - 1U;
remaining_lo = new_piv - 1U - lo;

/* recurse only on the small part to avoid degenerate stack sizes */
/* and manually do tail call on the large part */
if (new_pivot - 1U - left > right - new_pivot - 1U) {
/* left is bigger than right */
QUICK_SORT_RECURSIVE(dst, new_pivot + 1U, right);
/* tail call for left */
right = new_pivot - 1U;
/* this is why range stack size of sizeof(size_t) * 8 works */
if (remaining_lo > remaining_hi) {
/* hi range is smaller so push lo range to the "stack" */
cur_range->begin = lo;
cur_range->end = new_piv - 1U;
++cur_range;
lo = new_piv + 1U;
} else {
/* right is bigger than left */
QUICK_SORT_RECURSIVE(dst, left, new_pivot - 1U);
/* tail call for right */
left = new_pivot + 1U;
/* lo range is smaller so push hi range to the "stack" */
cur_range->begin = new_piv + 1U;
cur_range->end = hi;
++cur_range;
hi = new_piv - 1U;
}

if ((hi - lo) <= SMALL_SORT_BND) {
SMALL_SORT(lo, (hi - lo) + 1);
/* We just pushed to the range stack so we can't be done here */
--cur_range;
lo = cur_range->begin;
hi = cur_range->end;
/* Only the next range we pop can be smaller than SMALL_SORT_BND. */
/* Any other range has already passed this check (or is larger */
/* than a range that passed this check) */
if ((hi - lo) <= SMALL_SORT_BND) {
SMALL_SORT(lo, (hi - lo) + 1);
if (cur_range <= ranges_base) {
return;
}
--cur_range;
lo = cur_range->begin;
hi = cur_range->end;
}
loop_count = 64 - CLZ(hi - lo);
} else if (--loop_count <= 0) {
/* we have recursed / looped too many times; switch to heap sort */
HEAP_SORT(lo, (hi - lo) + 1);
if (cur_range <= ranges_base) {
return;
}
/* Pop the next range. It can't be smaller than SMALL_SORT_BND */
--cur_range;
lo = cur_range->begin;
hi = cur_range->end;
loop_count = 64 - CLZ(hi - lo);
}
}
}
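The loop_count bookkeeping is the usual introsort-style guard: each range gets a budget of roughly log2(range length) partition rounds, computed as 64 - CLZ(length), and once the budget runs out the range is handed to HEAP_SORT so the worst case stays O(n log n). A small illustration of what that budget evaluates to, assuming a GCC/Clang __builtin_clzll stands in for the CLZ macro:

```c
#include <stdio.h>

/* 64 - clz(n) is floor(log2(n)) + 1 for n > 0, i.e. the bit width of n. */
static int depth_budget(unsigned long long n) {
  return 64 - __builtin_clzll(n);   /* undefined for n == 0 */
}

int main(void) {
  printf("%d %d %d\n",
         depth_budget(1),          /* 1  */
         depth_budget(1024),       /* 11 */
         depth_budget(1000000));   /* 20 */
  return 0;
}
```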

void QUICK_SORT(SORT_TYPE *dst, const size_t size) {
/* don't bother sorting an array of size 1 */
if (size <= 1) {
/* Handle small sizes early */
if (size <= SMALL_SORT_BND) {
SMALL_SORT(dst, size);
return;
}

QUICK_SORT_RECURSIVE(dst, 0U, size - 1U);
QUICK_SORT_RECURSIVE(dst, size - 1U);
}


@@ -3009,7 +3056,8 @@ void BUBBLE_SORT(SORT_TYPE *dst, const size_t size) {
#undef SORT_NEW_BUFFER
#undef SORT_DELETE_BUFFER
#undef QUICK_SORT
#undef MEDIAN
#undef MEDIAN3
#undef MEDIAN3_AND_PREP_PARTITION
#undef SORT_CONCAT
#undef SORT_MAKE_STR1
#undef SORT_MAKE_STR