Skip to content

Commit

Permalink
factor out number_of_partitions_length_max_part and further speedups
Browse files Browse the repository at this point in the history
  • Loading branch information
mantepse committed Nov 19, 2024
1 parent 5149cdc commit 1dbea25
Showing 1 changed file with 68 additions and 86 deletions.
154 changes: 68 additions & 86 deletions src/sage/combinat/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -6173,7 +6173,7 @@ def __classcall_private__(cls, n=None, **kwargs):
if 'max_part' in kwargs:
return PartitionsGreatestLE(n, kwargs['max_part'])
if 'min_part' in kwargs:
return Partitions_parts_length_restricted(n, kwargs['min_part'], n, 0, n)
return Partitions_parts_length_restricted(n, kwargs['min_part'], n, ZZ.zero(), n)
if 'length' in kwargs:
return Partitions_nk(n, kwargs['length'])

Expand Down Expand Up @@ -6203,15 +6203,16 @@ def __classcall_private__(cls, n=None, **kwargs):

if set(kwargs).issubset(['length', 'min_part', 'max_part',
'min_length', 'max_length']):

min_part = max(kwargs.get('min_part', ZZ.one()), ZZ.one())
max_part = max(min(kwargs.get('max_part', n), n), ZZ.zero())
if 'length' in kwargs:
return Partitions_parts_length_restricted(n, kwargs.get('min_part', 1),
kwargs.get('max_part', n),
kwargs['length'],
kwargs['length'])
return Partitions_parts_length_restricted(n, kwargs.get('min_part', 1),
kwargs.get('max_part', n),
kwargs.get('min_length', 0),
kwargs.get('max_length', n))
k = ZZ(kwargs['length'])
return Partitions_parts_length_restricted(n, min_part, max_part, k, k)

min_length = max(kwargs.get('min_length', ZZ.zero()), ZZ.zero())
max_length = min(kwargs.get('max_length', n), n)
return Partitions_parts_length_restricted(n, min_part, max_part, min_length, max_length)

# FIXME: should inherit from IntegerListLex, and implement repr, or _name as a lazy attribute
kwargs['name'] = "Partitions of the integer {} satisfying constraints {}".format(n, ", ".join(["{}={}".format(key, kwargs[key]) for key in sorted(kwargs)]))
Expand Down Expand Up @@ -8866,32 +8867,6 @@ class Partitions_parts_length_restricted(UniqueRepresentation, IntegerListsLex):
sage: [2,2,2,2,2] in Partitions_parts_length_restricted(10, 2, 10, 0, 10)
True
"""
@staticmethod
def __classcall_private__(cls, n, min_part, max_part, min_length, max_length):
"""
Normalize the input to ensure a unique representation.
TESTS::
sage: from sage.combinat.partition import Partitions_parts_length_restricted
sage: P1 = Partitions_parts_length_restricted(9, 0, 20, -1, 10)
sage: P2 = Partitions_parts_length_restricted(9, 1, 9, 0, 9)
sage: P1 is P2
True
"""
n = ZZ(n)
if min_part <= 0:
min_part = ZZ.one()
if max_part > n:
max_part = n
if max_part < 0:
max_part = ZZ.zero()
if min_length < 0:
min_length = ZZ.zero()
if max_length > n:
max_length = n
return super().__classcall__(cls, n, min_part, max_part, min_length, max_length)

def __init__(self, n, min_part, max_part, min_length, max_length):
"""
Initialize ``self``.
Expand All @@ -8903,17 +8878,12 @@ def __init__(self, n, min_part, max_part, min_length, max_length):
sage: TestSuite(p).run()
"""
self._n = n

IntegerListsLex.__init__(self, self._n, max_slope=0,
min_part=min_part,
max_part=max_part,
min_length=min_length,
max_length=max_length)

self._min_part = ZZ.one() if min_part is None else min_part
self._max_part = self._n if max_part is None else max_part
self._min_length = ZZ.zero() if min_length is None else min_length
self._max_length = self._n if max_length is None else max_length
max_length=max_length,
check=False)

def _repr_(self):
"""
Expand All @@ -8922,32 +8892,32 @@ def _repr_(self):
TESTS::
sage: from sage.combinat.partition import Partitions_parts_length_restricted
sage: Partitions_parts_length_restricted(9, 2, 9, 0, 10)
sage: Partitions_parts_length_restricted(9, 2, 9, 0, 9)
Partitions of 9 whose parts are at least 2
sage: Partitions_parts_length_restricted(9, 2, 9, 3, 5)
Partitions of 9 having length between 3 and 5 and whose parts are at least 2
"""
if not self._min_length and self._max_length == self._n:
if not self.min_length and self.max_length == self._n:
length_str = ""
elif self._min_length == self._max_length:
length_str = f"having length {self._min_length}"
elif not self._min_length:
length_str = f"having length at most {self._max_length}"
elif self._max_length == self._n:
length_str = f"having length at least {self._min_length}"
elif self.min_length == self.max_length:
length_str = f"having length {self.min_length}"
elif not self.min_length:
length_str = f"having length at most {self.max_length}"

Check warning on line 8905 in src/sage/combinat/partition.py

View check run for this annotation

Codecov / codecov/patch

src/sage/combinat/partition.py#L8905

Added line #L8905 was not covered by tests
elif self.max_length == self._n:
length_str = f"having length at least {self.min_length}"

Check warning on line 8907 in src/sage/combinat/partition.py

View check run for this annotation

Codecov / codecov/patch

src/sage/combinat/partition.py#L8907

Added line #L8907 was not covered by tests
else:
length_str = f"having length between {self._min_length} and {self._max_length}"
length_str = f"having length between {self.min_length} and {self.max_length}"

if self._min_part == ZZ.one() and self._max_part == self._n:
if self.min_part == ZZ.one() and self.max_part == self._n:
parts_str = ""

Check warning on line 8912 in src/sage/combinat/partition.py

View check run for this annotation

Codecov / codecov/patch

src/sage/combinat/partition.py#L8912

Added line #L8912 was not covered by tests
elif self._min_part == self._max_part:
parts_str = f"having parts equal to {self._min_part}"
elif self._min_part == ZZ.one():
parts_str = f"whose parts are at most {self._max_part}"
elif self._max_part == self._n:
parts_str = f"whose parts are at least {self._min_part}"
elif self.min_part == self.max_part:
parts_str = f"having parts equal to {self.min_part}"

Check warning on line 8914 in src/sage/combinat/partition.py

View check run for this annotation

Codecov / codecov/patch

src/sage/combinat/partition.py#L8914

Added line #L8914 was not covered by tests
elif self.min_part == ZZ.one():
parts_str = f"whose parts are at most {self.max_part}"
elif self.max_part == self._n:
parts_str = f"whose parts are at least {self.min_part}"
else:
parts_str = f"whose parts are between {self._min_part} and {self._max_part}"
parts_str = f"whose parts are between {self.min_part} and {self.max_part}"

if length_str:
if parts_str:
Expand All @@ -8972,37 +8942,22 @@ def cardinality(self):
TESTS::
sage: from itertools import product
sage: P = Partitions_parts_length_restricted
sage: all(P(n, a, b, k, m).cardinality() == len(list(P(n, a, b, k, m)))
sage: P = Partitions
sage: all(P(n, min_part=a, max_part=b, min_length=k, max_length=m).cardinality()
....: == len(list(P(n, min_part=a, max_part=b, min_length=k, max_length=m)))
....: for n, a, b, k, m in product(range(-1, 5), repeat=5))
True
"""
if not self._min_length and self._max_length == self._n and self._min_part == 1:
n = self._n
a = self.min_part - 1
if not self.min_length and self.max_length == n and not a:
# unrestricted length, parts smaller max_part
return ZZ.sum(number_of_partitions_length(self._n, i)
for i in range(self._max_part + 1))

def partitions_len_max_part(n, b, l):
r"""
Return the number of partitions of `n` with exactly `l` parts and
the largest part at most `b`.
"""
if not n:
if not l:
return ZZ.one()
return ZZ.zero()
if not l or l > n or n > b * l:
return ZZ.zero()
if b >= n:
return number_of_partitions_length(n, l)

return ZZ.sum(partitions_len_max_part(n - m, m, l - 1)
for m in range(1, b+1))

return ZZ.sum(partitions_len_max_part(self._n - (self._min_part-1)*ell,
self._max_part - self._min_part + 1,
ell)
for ell in range(self._min_length, self._max_length + 1))
return ZZ.sum(number_of_partitions_length(n, i)
for i in range(self.max_part + 1))

m = self.max_part - self.min_part + 1
return ZZ.sum(number_of_partitions_length_max_part(n - a * ell, ell, m)
for ell in range(self.min_length, self.max_length + 1))

Element = Partition
options = Partitions.options
Expand Down Expand Up @@ -9629,6 +9584,33 @@ def number_of_partitions_length(n, k, algorithm='hybrid'):
return ZZ(libgap.NrPartitions(ZZ(n), ZZ(k)))


@cached_function
def number_of_partitions_length_max_part(n, k, b):
r"""
Return the number of partitions of `n` with exactly `k` parts and
the largest part at most `b`.
EXAMPLES::
sage: from sage.combinat.partition import number_of_partitions_length_max_part
sage: number_of_partitions_length_max_part(10, 5, 3)
3
sage: list(Partitions(10, length=5, max_part=3))
[[3, 3, 2, 1, 1], [3, 2, 2, 2, 1], [2, 2, 2, 2, 2]]
"""
if not n:
if not k:
return ZZ.one()
return ZZ.zero()
if not k or k > n or n > b * k:
return ZZ.zero()
if b >= n:
return number_of_partitions_length(n, k)

return ZZ.sum(number_of_partitions_length_max_part(n - m, k - 1, m)
for m in range(1, b+1))


##########
# issue 14225: Partitions() is frequently used, but only weakly cached.
# Hence, establish a strong reference to it.
Expand Down

0 comments on commit 1dbea25

Please sign in to comment.