A fast combinations calculation in JAX.
The combinadic implementation is based on https://jamesmccaffrey.wordpress.com/2022/06/28/generating-the-mth-lexicographical-element-of-a-combination-using-the-combinadic, and further background is available at https://en.wikipedia.org/wiki/Combinatorial_number_system. Some of the details from those sources are copied and summarized below.
The following code demonstrates the combinations calculation with NumPy and via combinadics:
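import itertools
import math

import jax.numpy as jnp
import numpy as np

# calculateMth is the combinadic routine provided by this repository.

# setup
n = 4
k = 3
totalcount = math.comb(n, k)

# numpy
print(f'Calculate combinations "{n} choose {k}" in numpy:')
for comb in itertools.combinations(np.arange(start=0, stop=n, dtype=jnp.int32), k):
    print(comb)

# combinadics
print("Calculate via combinadics:")
# One index per combination (0..totalcount-1), mapped to its dual before the lookup.
actual = n - 1 - calculateMth(n, k, totalcount - 1 - jnp.arange(start=0, stop=totalcount, dtype=jnp.int32))
for comb in actual:
    print(comb)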
Running the code produces the following output:
Calculate combinations "4 choose 3" in numpy:
(0, 1, 2)
(0, 1, 3)
(0, 2, 3)
(1, 2, 3)
Calculate via combinadics:
[0 1 2]
[0 1 3]
[0 2 3]
[1 2 3]
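You can think of a combinadic as an alternate representation of an integer. Consider the integer 859, which can be written as a sum of powers of 10: $859 = 8 \cdot 10^2 + 5 \cdot 10^1 + 9 \cdot 10^0$.
The combinadic of an integer is its representation based on a variable base corresponding to the values of the binomial coefficient $\binom{n}{k}$. For example, with $(n=7, k=4)$ the integer 27 can be written as $27 = \binom{6}{4} + \binom{5}{3} + \binom{2}{2} + \binom{1}{1} = 15 + 10 + 1 + 1$, so its combinadic is $(6\ 5\ 2\ 1)$.
With $(n, k)$ fixed, any integer $m$ with $0 \le m < \binom{n}{k}$ can be written uniquely as $m = \binom{c_1}{k} + \binom{c_2}{k-1} + \dots + \binom{c_k}{1}$,
where $n > c_1 > c_2 > \dots > c_k \ge 0$.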
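To make the definition concrete, here is a small sketch in plain Python that evaluates a combinadic back to its integer. The function name `combinadic_value` is hypothetical and is not part of this repository; the block only assumes the standard form $m = \binom{c_1}{k} + \binom{c_2}{k-1} + \dots + \binom{c_k}{1}$ with strictly decreasing digits.

```python
from math import comb


def combinadic_value(digits):
    """Evaluate comb(c_1, k) + comb(c_2, k-1) + ... + comb(c_k, 1)."""
    k = len(digits)
    # A valid combinadic has strictly decreasing, non-negative digits.
    if any(a <= b for a, b in zip(digits, digits[1:])) or any(c < 0 for c in digits):
        raise ValueError("digits must be strictly decreasing and non-negative")
    # The i-th digit (0-based) is paired with the lower index k - i.
    return sum(comb(c, k - i) for i, c in enumerate(digits))


print(combinadic_value([6, 5, 2, 1]))  # 15 + 10 + 1 + 1 = 27
print(combinadic_value([4, 3, 1]))     # 4 + 3 + 1 = 8
```

Note that `math.comb(c, j)` returns 0 when `c < j`, which is why digits such as `(2 1 0)` evaluate to 0.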
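Here’s an example of how a combinadic is calculated. Suppose you are working with $(n=7, k=4)$ combinations and want the combinadic of $m = 8$.
The combinadic of 8 will have the form: $8 = \binom{c_1}{4} + \binom{c_2}{3} + \binom{c_3}{2} + \binom{c_4}{1}$.
The first step is to determine the value of $c_1$. Trying $c_1 = 6$, the largest value below $n = 7$, gives $\binom{6}{4} = 15$, which exceeds 8. Trying $c_1 = 5$ gives $\binom{5}{4} = 5 \le 8$, so $c_1 = 5$.
At this point we have used up 5 of the original 8, so 3 remain. For $c_2$, the value 4 gives $\binom{4}{3} = 4$, which exceeds 3, while 3 gives $\binom{3}{3} = 1$, so $c_2 = 3$.
We used up 1 of the remaining 3, so 2 remain. In the same way, $c_3 = 2$ with $\binom{2}{2} = 1$, leaving 1, and $c_4 = 1$ with $\binom{1}{1} = 1$. The combinadic of 8 for $(n=7, k=4)$ is therefore $(5\ 3\ 2\ 1)$.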
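The greedy search described above can be sketched in plain Python. This is a minimal illustration and not the repository's `calculateMth`; the function name `combinadic` is hypothetical, and the parameters n = 7, k = 4 are chosen only for the example.

```python
from math import comb


def combinadic(n, k, m):
    """Return [c_1, ..., c_k] with n > c_1 > ... > c_k >= 0 that represents m."""
    if not 0 <= m < comb(n, k):
        raise ValueError("m must satisfy 0 <= m < comb(n, k)")
    digits = []
    remaining = m
    upper = n  # each digit must be strictly smaller than the previous one
    for j in range(k, 0, -1):
        c = upper - 1
        # Step down until comb(c, j) fits into what is left of m.
        while comb(c, j) > remaining:
            c -= 1
        digits.append(c)
        remaining -= comb(c, j)
        upper = c
    return digits


print(combinadic(7, 4, 8))   # [5, 3, 2, 1]
print(combinadic(7, 4, 27))  # [6, 5, 2, 1]
```

The inner loop always terminates because `comb(c, j)` is 0 once `c < j`, which never exceeds the non-negative remainder.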
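Suppose $(n=7, k=4)$. There are $\binom{7}{4} = 35$ combination elements, indexed from 0 to 34, and the dual of index $m$ is the index at the opposite end of that range: $\text{dual}(m) = \binom{n}{k} - 1 - m$. Indexes 0 and 34 are duals, as are 1 and 33.
Now, continuing the first example above for the number $m = 27$ with $(n=7, k=4)$: its combinadic is $(6\ 5\ 2\ 1)$, and mapping each $c_i$ to $(n-1) - c_i$ gives $\{0, 1, 4, 5\}$, which is combination element $[7]$, the element at index $\text{dual}(27) = 34 - 27 = 7$.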
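Putting the two ideas together, the m-th lexicographic combination can be obtained by taking the combinadic of the dual index and mapping each digit $c_i$ to $(n-1) - c_i$. The sketch below is plain Python for illustration only; `combinadic` and `mth_element` are hypothetical names, not the repository's JAX implementation, and the result is cross-checked against `itertools.combinations`.

```python
from itertools import combinations
from math import comb


def combinadic(n, k, m):
    """Greedy combinadic digits [c_1, ..., c_k] of m for (n, k)."""
    digits, upper = [], n
    for j in range(k, 0, -1):
        c = upper - 1
        while comb(c, j) > m:
            c -= 1
        digits.append(c)
        m -= comb(c, j)
        upper = c
    return digits


def mth_element(n, k, m):
    """m-th k-combination of range(n) in lexicographic order."""
    dual = comb(n, k) - 1 - m  # index at the opposite end of the range
    return [n - 1 - c for c in combinadic(n, k, dual)]


# Cross-check every index for (n=5, k=3) against itertools.
n, k = 5, 3
for m, expected in enumerate(combinations(range(n), k)):
    assert tuple(mth_element(n, k, m)) == expected

print(mth_element(5, 3, 4))  # [0, 2, 4]
print(mth_element(7, 4, 7))  # [0, 1, 4, 5]
```

Each index is handled independently of the others, which is what makes the approach a natural fit for vectorized evaluation.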
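The table below shows the relationships among $m$, the dual of $m$, combination element $[m]$, the combinadic of $m$, and $(n-1) - c_i$ for $(n=5, k=3)$. In each row, the last column equals the combination element at index dual($m$).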
 m    dual(m)   Element(m)   combinadic(m)   (n-1) - ci
=======================================================
[0]      9      { 0 1 2 }      ( 2 1 0 )     ( 2 3 4 )
[1]      8      { 0 1 3 }      ( 3 1 0 )     ( 1 3 4 )
[2]      7      { 0 1 4 }      ( 3 2 0 )     ( 1 2 4 )
[3]      6      { 0 2 3 }      ( 3 2 1 )     ( 1 2 3 )
[4]      5      { 0 2 4 }      ( 4 1 0 )     ( 0 3 4 )
[5]      4      { 0 3 4 }      ( 4 2 0 )     ( 0 2 4 )
[6]      3      { 1 2 3 }      ( 4 2 1 )     ( 0 2 3 )
[7]      2      { 1 2 4 }      ( 4 3 0 )     ( 0 1 4 )
[8]      1      { 1 3 4 }      ( 4 3 1 )     ( 0 1 3 )
[9]      0      { 2 3 4 }      ( 4 3 2 )     ( 0 1 2 )
64-bit numbers
Performance of a single GPU