Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose constructors for interval types #1325

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 43 additions & 33 deletions core/src/main/scala/spire/math/Interval.scala
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ sealed abstract class Interval[A] extends Serializable { lhs =>
case Below(upper, uf) =>
List(Above(upper, upperFlagToLower(reverseUpperFlag(uf))))
case Point(p) =>
List(Interval.below(p), Interval.above(p))
List(Interval.below[A](p), Interval.above[A](p))
case Bounded(lower, upper, flags) =>
val lx = lowerFlagToUpper(reverseLowerFlag(lowerFlag(flags)))
val ux = upperFlagToLower(reverseUpperFlag(upperFlag(flags)))
Expand All @@ -251,7 +251,7 @@ sealed abstract class Interval[A] extends Serializable { lhs =>
}

def split(t: A)(implicit o: Order[A]): (Interval[A], Interval[A]) =
(this.intersect(Interval.below(t)), this.intersect(Interval.above(t)))
(this.intersect(Interval.below[A](t)), this.intersect(Interval.above[A](t)))

def splitAtZero(implicit o: Order[A], ev: AdditiveMonoid[A]): (Interval[A], Interval[A]) =
split(ev.zero)
Expand Down Expand Up @@ -295,7 +295,7 @@ sealed abstract class Interval[A] extends Serializable { lhs =>
else if (upper > x) Bounded(m.zero, upper, upperFlag(fs))
else Bounded(m.zero, x, lowerFlagToUpper(fs) & upperFlag(fs))
case _ => // Above or Below
Interval.atOrAbove(m.zero)
Interval.atOrAbove[A](m.zero)
}
} else if (hasBelow(m.zero)) {
-this
Expand Down Expand Up @@ -565,7 +565,7 @@ sealed abstract class Interval[A] extends Serializable { lhs =>
if (k < 0) {
throw new IllegalArgumentException(s"negative exponent: $k")
} else if (k == 0) {
Interval.point(r.one)
Interval.point[A](r.one)
} else if (k == 1) {
this
} else if ((k & 1) == 0) {
Expand Down Expand Up @@ -632,7 +632,7 @@ sealed abstract class Interval[A] extends Serializable { lhs =>
* result = { p(x) | x ∈ interval }
*/
def translate(p: Polynomial[A])(implicit o: Order[A], ev: Field[A]): Interval[A] = {
val terms2 = p.terms.map { case Term(c, e) => Term(Interval.point(c), e) }
val terms2 = p.terms.map { case Term(c, e) => Term(Interval.point[A](c): Interval[A], e) }
val p2 = Polynomial(terms2)
p2(this)
}
Expand Down Expand Up @@ -840,13 +840,23 @@ object Interval {
else
Interval.empty[A]

def empty[A: Order]: Interval[A] = Empty[A]()
// Old methods with a return type of Interval, kept for binary compatibility
private[Interval] def empty[A: Order, B]: Interval[A] = empty[A]
private[Interval] def point[A: Order, B](a: A): Interval[A] = point[A](a)
private[Interval] def zero[A: Order, B](implicit r: Semiring[A]): Interval[A] = zero[A]
private[Interval] def all[A: Order, B]: Interval[A] = all[A]
private[Interval] def above[A: Order, B](a: A): Interval[A] = above[A](a)
private[Interval] def below[A: Order, B](a: A): Interval[A] = below[A](a)
private[Interval] def atOrAbove[A: Order, B](a: A): Interval[A] = atOrAbove[A](a)
private[Interval] def atOrBelow[A: Order, B](a: A): Interval[A] = atOrBelow[A](a)

def point[A: Order](a: A): Interval[A] = Point(a)
def empty[A: Order]: Empty[A] = Empty[A]()

def zero[A: Order](implicit r: Semiring[A]): Interval[A] = Point(r.zero)
def point[A: Order](a: A): Point[A] = Point(a)

def all[A: Order]: Interval[A] = All[A]()
def zero[A: Order](implicit r: Semiring[A]): Point[A] = Point(r.zero)

def all[A: Order]: All[A] = All[A]()

def apply[A: Order](lower: A, upper: A): Interval[A] = closed(lower, upper)

Expand All @@ -867,9 +877,9 @@ object Interval {
*/
def errorBounds(d: Double): Interval[Rational] =
if (d == Double.PositiveInfinity) {
Interval.above(Double.MaxValue)
Interval.above[Rational](Double.MaxValue)
} else if (d == Double.NegativeInfinity) {
Interval.below(Double.MinValue)
Interval.below[Rational](Double.MinValue)
} else if (isNaN(d)) {
Interval.empty[Rational]
} else {
Expand Down Expand Up @@ -902,32 +912,32 @@ object Interval {
*/
private[spire] def fromOrderedBounds[A: Order](lower: Bound[A], upper: Bound[A]): Interval[A] =
(lower, upper) match {
case (EmptyBound(), EmptyBound()) => empty
case (EmptyBound(), EmptyBound()) => empty[A]
case (Closed(x), Closed(y)) => Bounded(x, y, closedLowerFlags | closedUpperFlags)
case (Open(x), Open(y)) => Bounded(x, y, openLowerFlags | openUpperFlags)
case (Unbound(), Open(y)) => below(y)
case (Open(x), Unbound()) => above(x)
case (Unbound(), Closed(y)) => atOrBelow(y)
case (Closed(x), Unbound()) => atOrAbove(x)
case (Unbound(), Open(y)) => below[A](y)
case (Open(x), Unbound()) => above[A](x)
case (Unbound(), Closed(y)) => atOrBelow[A](y)
case (Closed(x), Unbound()) => atOrAbove[A](x)
case (Closed(x), Open(y)) => Bounded(x, y, closedLowerFlags | openUpperFlags)
case (Open(x), Closed(y)) => Bounded(x, y, openLowerFlags | closedUpperFlags)
case (Unbound(), Unbound()) => all
case (Unbound(), Unbound()) => all[A]
case (EmptyBound(), _) | (_, EmptyBound()) =>
throw new IllegalArgumentException("invalid empty bound")
}

def fromBounds[A: Order](lower: Bound[A], upper: Bound[A]): Interval[A] =
(lower, upper) match {
case (EmptyBound(), EmptyBound()) => empty
case (EmptyBound(), EmptyBound()) => empty[A]
case (Closed(x), Closed(y)) => closed(x, y)
case (Open(x), Open(y)) => open(x, y)
case (Unbound(), Open(y)) => below(y)
case (Open(x), Unbound()) => above(x)
case (Unbound(), Closed(y)) => atOrBelow(y)
case (Closed(x), Unbound()) => atOrAbove(x)
case (Unbound(), Open(y)) => below[A](y)
case (Open(x), Unbound()) => above[A](x)
case (Unbound(), Closed(y)) => atOrBelow[A](y)
case (Closed(x), Unbound()) => atOrAbove[A](x)
case (Closed(x), Open(y)) => openUpper(x, y)
case (Open(x), Closed(y)) => openLower(x, y)
case (Unbound(), Unbound()) => all
case (Unbound(), Unbound()) => all[A]
case (EmptyBound(), _) | (_, EmptyBound()) =>
throw new IllegalArgumentException("invalid empty bound")
}
Expand All @@ -944,10 +954,10 @@ object Interval {
if (lower < upper) Bounded(lower, upper, 1) else Interval.empty[A]
def openUpper[A: Order](lower: A, upper: A): Interval[A] =
if (lower < upper) Bounded(lower, upper, 2) else Interval.empty[A]
def above[A: Order](a: A): Interval[A] = Above(a, 1)
def below[A: Order](a: A): Interval[A] = Below(a, 2)
def atOrAbove[A: Order](a: A): Interval[A] = Above(a, 0)
def atOrBelow[A: Order](a: A): Interval[A] = Below(a, 0)
def above[A: Order](a: A): Above[A] = Above(a, 1)
def below[A: Order](a: A): Below[A] = Below(a, 2)
def atOrAbove[A: Order](a: A): Above[A] = Above(a, 0)
def atOrBelow[A: Order](a: A): Below[A] = Below(a, 0)

private val NullRe = "^ *\\( *Ø *\\) *$".r
private val SingleRe = "^ *\\[ *([^,]+) *\\] *$".r
Expand All @@ -956,14 +966,14 @@ object Interval {
def apply(s: String): Interval[Rational] =
s match {
case NullRe() => Interval.empty[Rational]
case SingleRe(x) => Interval.point(Rational(x))
case SingleRe(x) => Interval.point[Rational](Rational(x))
case PairRe(left, x, y, right) =>
(left, x, y, right) match {
case ("(", "-∞", "∞", ")") => Interval.all[Rational]
case ("(", "-∞", y, ")") => Interval.below(Rational(y))
case ("(", "-∞", y, "]") => Interval.atOrBelow(Rational(y))
case ("(", x, "∞", ")") => Interval.above(Rational(x))
case ("[", x, "∞", ")") => Interval.atOrAbove(Rational(x))
case ("(", "-∞", y, ")") => Interval.below[Rational](Rational(y))
case ("(", "-∞", y, "]") => Interval.atOrBelow[Rational](Rational(y))
case ("(", x, "∞", ")") => Interval.above[Rational](Rational(x))
case ("[", x, "∞", ")") => Interval.atOrAbove[Rational](Rational(x))
case ("[", x, y, "]") => Interval.closed(Rational(x), Rational(y))
case ("(", x, y, ")") => Interval.open(Rational(x), Rational(y))
case ("[", x, y, ")") => Interval.openUpper(Rational(x), Rational(y))
Expand All @@ -980,7 +990,7 @@ object Interval {

implicit def semiring[A](implicit ev: Ring[A], o: Order[A]): Semiring[Interval[A]] =
new Semiring[Interval[A]] {
def zero: Interval[A] = Interval.point(ev.zero)
def zero: Interval[A] = Interval.point[A](ev.zero)
def plus(x: Interval[A], y: Interval[A]): Interval[A] = x + y
def times(x: Interval[A], y: Interval[A]): Interval[A] = x * y
override def pow(x: Interval[A], k: Int): Interval[A] = x.pow(k)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class IntervalGeometricPartialOrderSuite extends munit.FunSuite {
test("[2, 3] cannot be compared to empty") { assert(closed(2, 3).partialCompare(open(2, 2)).isNaN) }
test("Minimal and maximal elements of {[1], [2, 3], [2, 4]}") {
val intervals = Seq(point(1), closed(2, 3), closed(2, 4))
assertEquals(intervals.pmin.toSet, Set(point(1)))
assertEquals(intervals.pmax.toSet, Set(closed(2, 3), closed(2, 4)))
assertEquals(intervals.pmin.toSet, Set[Interval[Int]](point(1)))
assertEquals(intervals.pmax.toSet, Set[Interval[Int]](closed(2, 3), closed(2, 4)))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ class IntervalScalaCheckSuite extends munit.ScalaCheckSuite {

import spire.algebra.{Order, PartialOrder}
forAll { (x: Rational, y: Rational) =>
val a = Interval.point(x)
val b = Interval.point(y)
val a: Interval[Rational] = Interval.point(x)
val b: Interval[Rational] = Interval.point(y)
val order = PartialOrder[Interval[Rational]].tryCompare(a, b).get == Order[Rational].compare(x, y)
val min = a.pmin(b) match {
case Some(Point(vmin)) => vmin == x.min(y)
Expand All @@ -183,8 +183,8 @@ class IntervalScalaCheckSuite extends munit.ScalaCheckSuite {
forAll { (a: Rational, w: Positive[Rational]) =>
val b = a + w.num
// a < b
val i = Interval.atOrBelow(a)
val j = Interval.atOrAbove(b)
val i: Interval[Rational] = Interval.atOrBelow(a)
val j: Interval[Rational] = Interval.atOrAbove(b)
(i < j) &&
!(i >= j) &&
(j > i) &&
Expand All @@ -197,8 +197,8 @@ class IntervalScalaCheckSuite extends munit.ScalaCheckSuite {
forAll { (a: Rational, w: NonNegative[Rational]) =>
val b = a - w.num
// a >= b
val i = Interval.atOrBelow(a)
val j = Interval.atOrAbove(b)
val i: Interval[Rational] = Interval.atOrBelow(a)
val j: Interval[Rational] = Interval.atOrAbove(b)
i.partialCompare(j).isNaN &&
j.partialCompare(i).isNaN
}
Expand All @@ -207,8 +207,8 @@ class IntervalScalaCheckSuite extends munit.ScalaCheckSuite {
property("(-inf, inf) does not compare with [a, b]") {
import spire.optional.intervalGeometricPartialOrder._
forAll { (a: Rational, b: Rational) =>
val i = Interval.all[Rational]
val j = Interval.closed(a, b)
val i: Interval[Rational] = Interval.all[Rational]
val j: Interval[Rational] = Interval.closed(a, b)
i.partialCompare(j).isNaN &&
j.partialCompare(i).isNaN
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class IntervalSubsetPartialOrderSuite extends munit.FunSuite {

test("Minimal and maximal elements of {[1, 3], [3], [2], [1]} by subset partial order") {
val intervals = Seq(closed(1, 3), point(3), point(2), point(1))
assertEquals(intervals.pmin.toSet, Set(point(1), point(2), point(3)))
assertEquals(intervals.pmax.toSet, Set(closed(1, 3)))
assertEquals(intervals.pmin.toSet, Set(point(1), point(2), point(3)): Set[Interval[Int]])
assertEquals(intervals.pmax.toSet, Set(closed(1, 3)): Set[Interval[Int]])
}
}
Loading