Description
The concept is to take a random sample to quickly find two values that almost certainly (99% chance) bracket the target value(s), then make one efficient pass over the whole input, counting the values that fall above/below the bracketed range and explicitly storing only those that fall within it. If the median does not fall within the stored subset, try again with a fresh random sample, up to three times (99.9999% success rate if the randomness is good). If the median does fall within the stored subset, find the exact target value(s) there.
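As a rough sanity check of the quoted probabilities, here is a normal-approximation sketch. It assumes the constants used by the implementation in this post (a sample of k^2 elements drawn with replacement, and a window of about 1.3k ranks on either side of the sample median). The symbols m, B, and \Phi (the standard normal CDF) are introduced here for illustration only.

```latex
\begin{align*}
  % B counts the sampled elements that lie below the true median of the input
  B &\sim \mathrm{Binomial}\!\left(m, \tfrac12\right), \qquad m = k^2, \qquad
      \sigma_B = \tfrac12\sqrt{m} = \tfrac{k}{2} \\
  % the window misses on one side when B is more than 1.3k = 2.6\,\sigma_B away from m/2
  P(\text{one attempt fails}) &\approx 2\,\Phi(-2.6) \approx 2 \times 0.0047 \approx 0.009 \\
  P(\text{all three attempts fail}) &\approx 0.009^{3} < 10^{-6}
\end{align*}
```

Under these assumptions a single attempt succeeds about 99% of the time, and three independent attempts all fail less than once in a million runs.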
Here's a naive implementation that is roughly 4x faster for large inputs and allocates O(n^(2/3)) memory instead of O(n) memory.

    using Statistics

    function my_median(v::AbstractVector)
        # Small inputs: the sampling overhead dominates, so defer to `median`.
        length(v) < 2^12 && return median(v)
        # Sample k^2 ≈ n^(2/3) elements and keep a window of about ±1.3k ranks
        # around the middle of the sorted sample.
        k = round(Int, length(v)^(1/3))
        lo_i = floor(Int, middle(1, k^2) - 1.3k)
        hi_i = ceil(Int, middle(1, k^2) + 1.3k)
        @assert 1 <= lo_i
        for _ in 1:3
            sample = rand(v, k^2)
            middle_of_sample = partialsort!(sample, lo_i:hi_i)
            lo_x, hi_x = first(middle_of_sample), last(middle_of_sample)
            # One pass over v: count values below lo_x, keep values in [lo_x, hi_x).
            number_below = 0
            middle_of_v = similar(v, 0)
            sizehint!(middle_of_v, 3k^2)
            for x in v
                a = x < lo_x
                b = x < hi_x
                number_below += Int(a)
                if a != b
                    push!(middle_of_v, x)
                end
            end
            # Rank of the median within middle_of_v (a half-integer for even lengths).
            target = middle(firstindex(v), lastindex(v)) - number_below
            if isinteger(target)
                target_i = Int(target)
                checkbounds(Bool, middle_of_v, target_i) && return middle(partialsort!(middle_of_v, target_i))
            else
                target_lo = floor(Int, target)
                target_hi = ceil(Int, target)
                checkbounds(Bool, middle_of_v, target_lo:target_hi) && return middle(partialsort!(middle_of_v, target_lo:target_hi))
            end
        end
        # All three attempts failed to bracket the median: fall back to `median`.
        median(v)
    end

I think this is reasonably close to optimal for large inputs, but I paid no heed to optimizing the O(n^(2/3)) factors, so it is likely possible to optimize this further and lower the crossover point at which it becomes more efficient than the current median code.
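To make the O(n^(2/3)) term concrete, the sketch below evaluates the two scratch sizes that appear in `my_median`: the sample of `k^2` elements and the `3k^2` passed to `sizehint!`. The helper name `scratch_elements` is hypothetical, and since `sizehint!` is only a hint, the candidate buffer can in principle grow past it; treat the totals as estimates.

```julia
# Scratch space used by one attempt, counted in elements rather than bytes.
function scratch_elements(n::Integer)
    k = round(Int, n^(1/3))   # same k as in my_median
    sample = k^2              # size of the random sample
    buffer = 3k^2             # capacity requested via sizehint!
    return (k = k, sample = sample, buffer = buffer, total = sample + buffer)
end

for e in 4:8
    s = scratch_elements(10^e)
    println("n = 10^", e, ": k = ", s.k, ", scratch ≈ ", s.total,
            " elements (", round(100 * s.total / 10^e, digits = 2), "% of n)")
end
```

For n = 10^8 this gives k = 464 and roughly 8.6 × 10^5 scratch elements, under 1% of the input, whereas a copying median needs a full O(n) copy.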
This generalizes quite well to `quantiles(n, k)` for short `k`. It has a runtime of `O(n * k)` with a low constant factor. The calls to `partialsort!` can also be replaced with more efficient recursive calls to `quantile`.
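One way the generalization might look for a single target is sketched below. It is not the author's code: `sampled_order_statistic` is a hypothetical name, it returns the `t`-th smallest element rather than an interpolated quantile, it assumes 1-based indexing, and it simply falls back to `partialsort` for small inputs or targets too close to either end. The window width of 1.3k is carried over from `my_median` unchanged.

```julia
using Random

# Sketch: t-th smallest element of v via sampling (hypothetical helper).
function sampled_order_statistic(v::AbstractVector, t::Integer; rng = Random.default_rng())
    n = length(v)
    1 <= t <= n || throw(ArgumentError("t must be in 1:length(v)"))
    n < 2^12 && return partialsort(v, t)          # small input: not worth sampling
    k = round(Int, n^(1/3))
    # Rank in the sorted sample that corresponds to rank t in v.
    center = 1 + (t - 1) / (n - 1) * (k^2 - 1)
    lo_i = floor(Int, center - 1.3k)
    hi_i = ceil(Int, center + 1.3k)
    # Window would leave the sample (target near an end): use the plain method.
    (lo_i < 1 || hi_i > k^2) && return partialsort(v, t)
    for _ in 1:3
        sample = rand(rng, v, k^2)
        window = partialsort!(sample, lo_i:hi_i)
        lo_x, hi_x = first(window), last(window)
        number_below = 0
        candidates = similar(v, 0)
        for x in v
            if x < lo_x
                number_below += 1                 # strictly below the bracket
            elseif x < hi_x
                push!(candidates, x)              # inside [lo_x, hi_x)
            end
        end
        local_t = t - number_below                # rank of the target among candidates
        checkbounds(Bool, candidates, local_t) && return partialsort!(candidates, local_t)
    end
    return partialsort(v, t)                      # every attempt missed: fall back
end

v = rand(10^5)
println(sampled_order_statistic(v, 25_000) == sort(v)[25_000])
```

Because a missed bracket only triggers a retry or the fallback, the result is exact regardless of how the random sample turns out; only the running time depends on it. Handling several targets in one pass would mean tracking one bracket per target, which is where the `O(n * k)` cost comes from.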
Benchmarks
Runtimes measured in clock cycles per element (@ 3.49 GHz)
length | median | my_median |
---|---|---|
10^1 | 16.01 | 30.84 |
10^2 | 15.74 | 40.28 |
10^3 | 14.52 | 17.47 |
10^4 | 9.87 | 8.67 |
10^5 | 8.77 | 5.29 |
10^6 | 11.15 | 3.67 |
10^7 | 14.53 | 3.11 |
10^8 | 13.06 | 2.71 |
10^9 OOMs.
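The speedups implied by the table can be read off with a few lines of Julia. The numbers are copied from the table above; the variable names are only illustrative.

```julia
# Clock cycles per element for lengths 10^1 through 10^8, copied from the table.
median_cycles    = [16.01, 15.74, 14.52, 9.87, 8.77, 11.15, 14.53, 13.06]
my_median_cycles = [30.84, 40.28, 17.47, 8.67, 5.29, 3.67, 3.11, 2.71]

speedup = median_cycles ./ my_median_cycles

for (i, s) in enumerate(speedup)
    println("10^", i, ": ", round(s, digits = 2), "x")
end

# First length at which my_median is faster than median.
crossover = findfirst(>(1), speedup)
println("crossover at 10^", crossover)
```

On these measurements the crossover is at length 10^4, and the speedup grows from about 3x at 10^6 to about 4.8x at 10^8.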
Benchmark code
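
    using BenchmarkTools, Statistics  # @belapsed comes from BenchmarkTools.jl

    # Assumes my_median (defined above) is already loaded.
    println("length | median | my_median")
    println("-------|--------|----------")
    for i in 1:8
        n = 10^i
        print("10^", rpad(i, 2), " | ")
        x = rand(n)
        t0 = @belapsed median($x)
        t0 *= 3.49e9/n  # seconds -> clock cycles per element at 3.49 GHz
        print(rpad(round(t0, digits=2), 4, '0'), " | ")
        t1 = @belapsed my_median($x)
        t1 *= 3.49e9/n
        println(rpad(round(t1, digits=2), 4, '0'))
    end
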
For these benchmarks, I removed the `length(v) < 2^12` fastpath to get accurate results for smaller inputs, and I replaced the `@assert` with `1 <= lo_i || return median(v)`.
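To see why the `@assert` cannot stay once the fast path is gone, the index arithmetic from `my_median` can be evaluated for small lengths. `lo_index` is a hypothetical helper used only for this illustration.

```julia
using Statistics

# Lower sample index computed by my_median for an input of length n.
function lo_index(n::Integer)
    k = round(Int, n^(1/3))
    return floor(Int, middle(1, k^2) - 1.3k)
end

for n in (10, 100, 1000)
    println("n = ", n, ": lo_i = ", lo_index(n))
end
```

For n = 10, k rounds to 2 and `lo_i` comes out as -1, so the assertion would fail; for n = 100 it is already 6. Without the fast path, the smallest benchmarked length therefore has to fall back to `median`.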