From 3d8bd4b516f313b835c3d61e6d447eb42430e63a Mon Sep 17 00:00:00 2001 From: Marcus Better Date: Fri, 20 Dec 2024 08:57:53 -0500 Subject: [PATCH] Add an efficient reservoir sampling aggregator This aggregator uses Li's "Algorithm L", a simple yet efficient sampling method, with modifications to support a monoidal setting. A JMH benchmark was added for both this and the old priority-queue algorithm. In a single-threaded benchmark on an Intel Core i9-10885H, the algorithms are roughly on par for a sample rate of 10%, but Algorithm L performs much better at lower sample rates (2x-5x throughput increase observed at various collection sizes). Because of this, the new algorithm was made the default for Aggregator.reservoirSample(). Unit tests were added for both algorithms. These are probabilistic and are expected to fail about 0.1% of the time per test case (the significance level is set to 0.001). --- .../ReservoirSamplingBenchmark.scala | 35 +++ .../com/twitter/algebird/Aggregator.scala | 11 +- .../algebird/mutable/ReservoirSampling.scala | 213 ++++++++++++++++++ .../twitter/algebird/RandomSamplingLaws.scala | 77 +++++++ .../algebird/scalacheck/Distribution.scala | 153 +++++++++++++ .../mutable/ReservoirMonoidTest.scala | 39 ++++ .../mutable/ReservoirSamplingTest.scala | 20 ++ build.sbt | 1 + 8 files changed, 543 insertions(+), 6 deletions(-) create mode 100644 algebird-benchmark/src/main/scala/com/twitter/algebird/benchmark/ReservoirSamplingBenchmark.scala create mode 100644 algebird-core/src/main/scala/com/twitter/algebird/mutable/ReservoirSampling.scala create mode 100644 algebird-test/src/main/scala/com/twitter/algebird/RandomSamplingLaws.scala create mode 100644 algebird-test/src/main/scala/com/twitter/algebird/scalacheck/Distribution.scala create mode 100644 algebird-test/src/test/scala/com/twitter/algebird/mutable/ReservoirMonoidTest.scala create mode 100644 algebird-test/src/test/scala/com/twitter/algebird/mutable/ReservoirSamplingTest.scala diff --git 
a/algebird-benchmark/src/main/scala/com/twitter/algebird/benchmark/ReservoirSamplingBenchmark.scala b/algebird-benchmark/src/main/scala/com/twitter/algebird/benchmark/ReservoirSamplingBenchmark.scala new file mode 100644 index 000000000..08d539aa0 --- /dev/null +++ b/algebird-benchmark/src/main/scala/com/twitter/algebird/benchmark/ReservoirSamplingBenchmark.scala @@ -0,0 +1,35 @@ +package com.twitter.algebird.benchmark + +import com.twitter.algebird.mutable.{PriorityQueueToListAggregator, ReservoirSamplingToListAggregator} +import org.openjdk.jmh.annotations.{Benchmark, Param, Scope, State} +import org.openjdk.jmh.infra.Blackhole + +import scala.util.Random + +object ReservoirSamplingBenchmark { + @State(Scope.Benchmark) + class BenchmarkState { + @Param(Array("100", "10000", "1000000")) + var collectionSize: Int = 0 + + @Param(Array("0.001", "0.01", "0.1")) + var sampleRate: Double = 0.0 + + def samples: Int = (sampleRate * collectionSize).ceil.toInt + } + + val rng = new Random() + implicit val randomSupplier: () => Random = () => rng +} + +class ReservoirSamplingBenchmark { + import ReservoirSamplingBenchmark._ + + @Benchmark + def timeAlgorithmL(state: BenchmarkState, bh: Blackhole): Unit = + bh.consume(new ReservoirSamplingToListAggregator[Int](state.samples).apply(0 until state.collectionSize)) + + @Benchmark + def timePriorityQeueue(state: BenchmarkState, bh: Blackhole): Unit = + bh.consume(new PriorityQueueToListAggregator[Int](state.samples).apply(0 until state.collectionSize)) +} diff --git a/algebird-core/src/main/scala/com/twitter/algebird/Aggregator.scala b/algebird-core/src/main/scala/com/twitter/algebird/Aggregator.scala index a64ce4033..b603a53ac 100644 --- a/algebird-core/src/main/scala/com/twitter/algebird/Aggregator.scala +++ b/algebird-core/src/main/scala/com/twitter/algebird/Aggregator.scala @@ -1,5 +1,7 @@ package com.twitter.algebird +import com.twitter.algebird.mutable.{Reservoir, ReservoirSamplingToListAggregator} + import 
java.util.PriorityQueue import scala.collection.compat._ import scala.collection.generic.CanBuildFrom @@ -286,12 +288,9 @@ object Aggregator extends java.io.Serializable { def reservoirSample[T]( count: Int, seed: Int = DefaultSeed - ): MonoidAggregator[T, PriorityQueue[(Double, T)], Seq[T]] = { - val rng = new java.util.Random(seed) - Preparer[T] - .map(rng.nextDouble() -> _) - .monoidAggregate(sortByTake(count)(_._1)) - .andThenPresent(_.map(_._2)) + ): MonoidAggregator[T, Reservoir[T], Seq[T]] = { + val rng = new scala.util.Random(seed) + new ReservoirSamplingToListAggregator[T](count)(() => rng) } /** diff --git a/algebird-core/src/main/scala/com/twitter/algebird/mutable/ReservoirSampling.scala b/algebird-core/src/main/scala/com/twitter/algebird/mutable/ReservoirSampling.scala new file mode 100644 index 000000000..a80a4c32e --- /dev/null +++ b/algebird-core/src/main/scala/com/twitter/algebird/mutable/ReservoirSampling.scala @@ -0,0 +1,213 @@ +package com.twitter.algebird.mutable + +import com.twitter.algebird.{Monoid, MonoidAggregator} + +import scala.collection.mutable +import scala.util.Random + +/** + * A reservoir of the currently sampled items. + * + * @param capacity + * the reservoir capacity + * @tparam T + * the element type + */ +sealed class Reservoir[T](val capacity: Int) { + var reservoir: mutable.Buffer[T] = mutable.Buffer() + + // When the reservoir is full, w is the threshold for accepting an element into the reservoir, and + // the following invariant holds: The maximum score of the elements in the reservoir is w, + // and the remaining elements are distributed as U[0, w]. + // Scores are not kept explicitly, only their distribution is tracked and sampled from. + // (w = 1 when the reservoir is not full.) 
+ var w: Double = 1 + + require(capacity > 0, "reservoir size must be positive") + private val kInv: Double = 1d / capacity + + def size: Int = reservoir.size + def isEmpty: Boolean = reservoir.isEmpty + def isFull: Boolean = size == capacity + + /** + * Add an element to the reservoir. If the reservoir is full then the element will replace a random element + * in the reservoir, and the threshold
w
is updated. + * + * When adding multiple elements, [[append]] should be used to take advantage of exponential jumps. + * + * @param x + * the element to add + * @param rng + * the random source + */ + def accept(x: T, rng: Random): Unit = { + if (isFull) { + reservoir(rng.nextInt(capacity)) = x + } else { + reservoir.append(x) + } + if (isFull) { + w *= Math.pow(rng.nextDouble, kInv) + } + } + + /** + * Add multiple elements to the reservoir. + * @param xs + * the elements to add + * @param rng + * the random source + * @param prior + * the threshold of the elements being added, such that the added element's value is distributed as + *
U[0, prior]
+ * @return + * this reservoir + */ + def append(xs: TraversableOnce[T], rng: Random, prior: Double = 1): Reservoir[T] = { + // The number of items to skip before accepting the next item is geometrically distributed + // with probability of success w / prior. The prior will be 1 when adding to a single reservoir, + // but when merging reservoirs it will be the threshold of the reservoir being pulled from, + // and in this case we require that w < prior. + def nextAcceptTime = (-rng.self.nextExponential / Math.log1p(-w / prior)).toInt + + var skip = if (isFull) nextAcceptTime else 0 + for (x <- xs) { + if (!isFull) { + // keep adding while reservoir is not full + accept(x, rng) + if (isFull) { + skip = nextAcceptTime + } + } else if (skip > 0) { + skip -= 1 + } else { + accept(x, rng) + skip = nextAcceptTime + } + } + this + } + + override def toString: String = s"Reservoir($capacity, $w, ${reservoir.toList})" +} + +object Reservoir { + implicit def monoid[T](implicit randomSupplier: () => Random): Monoid[Reservoir[T]] = + new ReservoirMonoid()(randomSupplier) +} + +/** + * This is the "Algorithm L" reservoir sampling algorithm [1], with modifications to act as a monoid by + * merging reservoirs. + * + * [1] Kim-Hung Li, "Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))", 1994 + * + * @tparam T + * the item type + */ +class ReservoirMonoid[T](implicit val randomSupplier: () => Random) extends Monoid[Reservoir[T]] { + + /** + * Builds a reservoir with a single item. + * + * @param k + * the reservoir capacity + * @param x + * the item to add + * @return + */ + def build(k: Int, x: T): Reservoir[T] = { + val r = new Reservoir[T](k) + r.accept(x, randomSupplier()) + r + } + + override def zero: Reservoir[T] = new Reservoir(1) + def zero(k: Int): Reservoir[T] = new Reservoir(k) + override def isNonZero(r: Reservoir[T]): Boolean = !r.isEmpty + + /** + * Merge two reservoirs. NOTE: This mutates one or both of the reservoirs. 
They should not be used after + * this operation, except as the return value for further aggregation. + */ + override def plus(left: Reservoir[T], right: Reservoir[T]): Reservoir[T] = + if (left.isEmpty) right + else if (left.size + right.size <= left.capacity) { + // the sum of the sizes does not exceed the capacity of the left reservoir, so we can just merge + left.append(right.reservoir, randomSupplier()) + } else { + val (s1, s2) = if (left.w < right.w) (left, right) else (right, left) + val rng = randomSupplier() + if (s2.isFull) { + // The highest score in s2 is w, and the other scores are distributed as U[0, w]. + // Since s1.w < s2.w, we have to drop the single (sampled) element with the highest score + // unconditionally. The other elements enter the reservoir with probability s1.w / s2.w. + val i = rng.nextInt(s2.size) + s2.reservoir(i) = s2.reservoir.head + s1.append(s2.reservoir.drop(1), rng, s2.w) + } else { + s1.append(s2.reservoir, rng) + } + } +} + +/** + * An aggregator that uses reservoir sampling to sample k elements from a stream of items. Because the + * reservoir is mutable, it is a good idea to copy the result to an immutable view before using it, as is done + * by [[ReservoirSamplingToListAggregator]]. 
+ * + * @param k + * the number of elements to sample + * @param randomSupplier + * the random generator + * @tparam T + * the item type + * @tparam C + * the result type + */ +abstract class ReservoirSamplingAggregator[T, +C](k: Int)(implicit val randomSupplier: () => Random) + extends MonoidAggregator[T, Reservoir[T], C] { + override val monoid: ReservoirMonoid[T] = new ReservoirMonoid + override def prepare(x: T): Reservoir[T] = monoid.build(k, x) + + override def apply(xs: TraversableOnce[T]): C = present(agg(xs)) + + override def applyOption(inputs: TraversableOnce[T]): Option[C] = + if (inputs.isEmpty) None else Some(apply(inputs)) + + override def append(r: Reservoir[T], t: T): Reservoir[T] = r.append(Seq(t), randomSupplier()) + + override def appendAll(r: Reservoir[T], xs: TraversableOnce[T]): Reservoir[T] = + r.append(xs, randomSupplier()) + + override def appendAll(xs: TraversableOnce[T]): Reservoir[T] = agg(xs) + + private def agg(xs: TraversableOnce[T]): Reservoir[T] = + appendAll(monoid.zero(k), xs) +} + +class ReservoirSamplingToListAggregator[T](k: Int)(implicit randomSupplier: () => Random) + extends ReservoirSamplingAggregator[T, List[T]](k)(randomSupplier) { + override def present(r: Reservoir[T]): List[T] = + randomSupplier().shuffle(r.reservoir).toList + + override def andThenPresent[D](f: List[T] => D): MonoidAggregator[T, Reservoir[T], D] = + new AndThenPresent(this, f) +} + +/** + * Aggregator that implements [[andThenPresent]] without ruining the optimized behavior of the aggregator. 
+ */ +protected class AndThenPresent[-A, B, C, +D](val agg: MonoidAggregator[A, B, C], f: C => D) + extends MonoidAggregator[A, B, D] { + override val monoid: Monoid[B] = agg.monoid + override def prepare(a: A): B = agg.prepare(a) + override def present(b: B): D = f(agg.present(b)) + + override def apply(xs: TraversableOnce[A]): D = f(agg(xs)) + override def applyOption(xs: TraversableOnce[A]): Option[D] = agg.applyOption(xs).map(f) + override def append(b: B, a: A): B = agg.append(b, a) + override def appendAll(b: B, as: TraversableOnce[A]): B = agg.appendAll(b, as) + override def appendAll(as: TraversableOnce[A]): B = agg.appendAll(as) +} diff --git a/algebird-test/src/main/scala/com/twitter/algebird/RandomSamplingLaws.scala b/algebird-test/src/main/scala/com/twitter/algebird/RandomSamplingLaws.scala new file mode 100644 index 000000000..40f5fd1cd --- /dev/null +++ b/algebird-test/src/main/scala/com/twitter/algebird/RandomSamplingLaws.scala @@ -0,0 +1,77 @@ +package com.twitter.algebird + +import com.twitter.algebird.scalacheck.Distribution._ +import org.scalacheck.{Gen, Prop} + +object RandomSamplingLaws { + + def sampleOneUniformly[T](newSampler: Int => Aggregator[Int, T, Seq[Int]]): Prop = { + val n = 100 + + "sampleOne" |: forAllSampled(10000, Gen.choose(1, 20))(_ => uniform(n)) { k => + newSampler(k).andThenPresent(_.head).apply(0 until n) + } + } + + def reservoirSizeOne[T](newSampler: Int => Aggregator[Int, T, Seq[Int]]): Prop = { + val n = 100 + + "reservoirSizeOne" |: forAllSampled(10000)(uniform(n)) { + newSampler(1).andThenPresent(_.head).apply(0 until n) + } + } + + def reservoirSizeTwo[T](newSampler: Int => Aggregator[Int, T, Seq[Int]]): Prop = { + val n = 10 + val tuples = for { + i <- 0 until n + j <- 0 until n + if i != j + } yield (i, j) + + "reservoirSizeTwo" |: forAllSampled(10000)(tuples.map(_ -> 1d).toMap) { + newSampler(2).andThenPresent(xs => (xs(0), xs(1))).apply(0 until n) + } + } + + def sampleSpecificItem[T](newSampler: Int => 
Aggregator[Int, T, Seq[Int]]): Prop = { + val sizeAndIndex: Gen[(Int, Int)] = for { + k <- Gen.choose(1, 10) + i <- Gen.choose(0, k - 1) + } yield (k, i) + + val n = 100 + + "sampleAnyItem" |: forAllSampled(10000, sizeAndIndex)(_ => uniform(n)) { case (k, i) => + newSampler(k).andThenPresent(_(i)).apply(0 until n) + } + } + + def sampleTwoItems[T](newSampler: Int => Aggregator[Int, T, Seq[Int]]): Prop = { + val sizeAndIndexes: Gen[(Int, Int, Int)] = for { + k <- Gen.choose(1, 10) + i <- Gen.choose(0, k - 1) + j <- Gen.choose(0, k - 1) + if i != j + } yield (k, i, j) + + val n = 20 + + "sampleTwoItems" |: forAllSampled(10000, sizeAndIndexes)(_ => + (for { + i <- 0 until n + j <- 0 until n + if i != j + } yield (i, j)).map(_ -> 1d).toMap + ) { case (k, i, j) => + newSampler(k).andThenPresent(xs => (xs(i), xs(j))).apply(0 until n) + } + } + + def randomSamplingDistributions[T](newSampler: Int => MonoidAggregator[Int, T, Seq[Int]]): Prop = + sampleOneUniformly(newSampler) && + reservoirSizeOne(newSampler) && + reservoirSizeTwo(newSampler) && + sampleSpecificItem(newSampler) && + sampleTwoItems(newSampler) +} diff --git a/algebird-test/src/main/scala/com/twitter/algebird/scalacheck/Distribution.scala b/algebird-test/src/main/scala/com/twitter/algebird/scalacheck/Distribution.scala new file mode 100644 index 000000000..06442406d --- /dev/null +++ b/algebird-test/src/main/scala/com/twitter/algebird/scalacheck/Distribution.scala @@ -0,0 +1,153 @@ +package com.twitter.algebird.scalacheck + +import org.apache.commons.statistics.inference.ChiSquareTest +import org.scalacheck.Prop.forAllNoShrink +import org.scalacheck.{Gen, Prop} + +import scala.collection.mutable + +/** + * ScalaCheck properties for probabilistic testing. + * + * For randomized algorithms, we want to verify that the output follows the expected distribution. The + * [[forAllSampled]] properties execute the test code a specified number of times, collect results and + * perform a chi-squared test. 
+ * + * These properties do not shrink their input generators, as this is less useful for probabilistic tests. + * + * @param expectedFreq + * A map of outputs (results of the test code) to their expected frequencies. Frequencies do not have to add + * up to 1, only their relative size matters. They can all be equal to 1 if a uniform distribution is + * expected. + * + * @param alpha + * the significance level for the chi-squared test + * + * @tparam T + * the result type of the test code + */ +class Distribution[T](val expectedFreq: Map[T, Double], val alpha: Double) { + private val samples: mutable.Map[T, Long] = mutable.Map().withDefaultValue(0) + + def collect(t: T): Unit = samples(t) += 1 + + def isNotRejected: Boolean = { + val (expected, observed) = expectedFreq.toSeq.map { case (k, v) => (v, samples(k)) }.toArray.unzip + val chi = ChiSquareTest.withDefaults.test(expected, observed) + !chi.reject(alpha) + } +} + +object Distribution { + private val defaultSigLevel = 0.001 + + implicit def propFromDistribution[T](d: Distribution[T]): Prop = + d.isNotRejected + + def uniform(n: Int): Map[Int, Double] = (0 until n).map(_ -> 1d).toMap + + /** + * Runs the code block for the specified number of trials and verifies that the output follows the expected + * distribution. The property passes if a chi-squared test fails to reject the null hypothesis that the + * distribution is the expected one at the given significance level. + * + * @param trials + * the number of iterations + * @param alpha + * the significance level + * @param expect + * A function computing the map of outputs (possible results) to their expected frequencies. For the + * overloaded versions of this method taking generator parameters, this function takes the generated + * values as input. 
+ * @param f + * the test code + * @tparam T + * the result type + * @return + * a [[Distribution]] object that can be used as a ScalaCheck property + */ + def forAllSampled[T](trials: Int, alpha: Double = defaultSigLevel)( + expect: Map[T, Double] + )(f: => T): Prop = { + val d = new Distribution(expect, alpha) + (0 until trials).foreach { _ => + d.collect(f) + } + d + } + + def forAllSampled[T1, T](trials: Int, g1: Gen[T1])(expect: T1 => Map[T, Double])(f: T1 => T): Prop = + forAllNoShrink(g1)(t1 => forAllSampled(trials)(expect(t1))(f(t1))) + + def forAllSampled[T1, T2, T](trials: Int, g1: Gen[T1], g2: Gen[T2])(expect: (T1, T2) => Map[T, Double])( + f: (T1, T2) => T + ): Prop = forAllNoShrink(g1)(t1 => forAllSampled(trials, g2)(expect(t1, _: T2))(f(t1, _: T2))) + + def forAllSampled[T1, T2, T3, T](trials: Int, g1: Gen[T1], g2: Gen[T2], g3: Gen[T3])( + expect: (T1, T2, T3) => Map[T, Double] + )(f: (T1, T2, T3) => T): Prop = + forAllNoShrink(g1, g2)((t1, t2) => forAllSampled(trials, g3)(expect(t1, t2, _: T3))(f(t1, t2, _: T3))) + + def forAllSampled[T1, T2, T3, T4, T](trials: Int, g1: Gen[T1], g2: Gen[T2], g3: Gen[T3], g4: Gen[T4])( + expect: (T1, T2, T3, T4) => Map[T, Double] + )(f: (T1, T2, T3, T4) => T): Prop = forAllNoShrink(g1, g2, g3)((t1, t2, t3) => + forAllSampled(trials, g4)(expect(t1, t2, t3, _: T4))(f(t1, t2, t3, _: T4)) + ) + + def forAllSampled[T1, T2, T3, T4, T5, T]( + trials: Int, + g1: Gen[T1], + g2: Gen[T2], + g3: Gen[T3], + g4: Gen[T4], + g5: Gen[T5] + )(expect: (T1, T2, T3, T4, T5) => Map[T, Double])(f: (T1, T2, T3, T4, T5) => T): Prop = + forAllNoShrink(g1, g2, g3, g4)((t1, t2, t3, t4) => + forAllSampled(trials, g5)(expect(t1, t2, t3, t4, _: T5))(f(t1, t2, t3, t4, _: T5)) + ) + + def forAllSampled[T1, T2, T3, T4, T5, T6, T]( + trials: Int, + g1: Gen[T1], + g2: Gen[T2], + g3: Gen[T3], + g4: Gen[T4], + g5: Gen[T5], + g6: Gen[T6] + )(expect: (T1, T2, T3, T4, T5, T6) => Map[T, Double])(f: (T1, T2, T3, T4, T5, T6) => T): Prop = + forAllNoShrink(g1, 
g2, g3, g4, g5)((t1, t2, t3, t4, t5) => + forAllSampled(trials, g6)(expect(t1, t2, t3, t4, t5, _: T6))(f(t1, t2, t3, t4, t5, _: T6)) + ) + + def forAllSampled[T1, T2, T3, T4, T5, T6, T7, T]( + trials: Int, + g1: Gen[T1], + g2: Gen[T2], + g3: Gen[T3], + g4: Gen[T4], + g5: Gen[T5], + g6: Gen[T6], + g7: Gen[T7] + )(expect: (T1, T2, T3, T4, T5, T6, T7) => Map[T, Double])(f: (T1, T2, T3, T4, T5, T6, T7) => T): Prop = + forAllNoShrink(g1, g2, g3, g4, g5, g6)((t1, t2, t3, t4, t5, t6) => + forAllSampled(trials, g7)(expect(t1, t2, t3, t4, t5, t6, _: T7))(f(t1, t2, t3, t4, t5, t6, _: T7)) + ) + + def forAllSampled[T1, T2, T3, T4, T5, T6, T7, T8, T]( + trials: Int, + g1: Gen[T1], + g2: Gen[T2], + g3: Gen[T3], + g4: Gen[T4], + g5: Gen[T5], + g6: Gen[T6], + g7: Gen[T7], + g8: Gen[T8] + )(expect: (T1, T2, T3, T4, T5, T6, T7, T8) => Map[T, Double])( + f: (T1, T2, T3, T4, T5, T6, T7, T8) => T + ): Prop = forAllNoShrink(g1, g2, g3, g4, g5, g6, g7)((t1, t2, t3, t4, t5, t6, t7) => + forAllSampled(trials, g8)(expect(t1, t2, t3, t4, t5, t6, t7, _: T8))( + f(t1, t2, t3, t4, t5, t6, t7, _: T8) + ) + ) +} diff --git a/algebird-test/src/test/scala/com/twitter/algebird/mutable/ReservoirMonoidTest.scala b/algebird-test/src/test/scala/com/twitter/algebird/mutable/ReservoirMonoidTest.scala new file mode 100644 index 000000000..12ca77288 --- /dev/null +++ b/algebird-test/src/test/scala/com/twitter/algebird/mutable/ReservoirMonoidTest.scala @@ -0,0 +1,39 @@ +package com.twitter.algebird.mutable + +import com.twitter.algebird.scalacheck.Distribution.{forAllSampled, uniform} +import com.twitter.algebird.{CheckProperties, Monoid} +import org.scalacheck.Gen +import org.scalacheck.Prop.forAll + +import scala.util.Random + +class ReservoirMonoidTest extends CheckProperties { + val rng = new Random() + implicit val randomSupplier: () => Random = () => rng + + property("adding empty is no-op") { + val mon = implicitly[Monoid[Reservoir[Int]]] + + forAll(Gen.choose(1, 20)) { m: Int => + val a = new 
Reservoir[Int](m) + val z = new Reservoir[Int](1) + a.accept(1, rng) + mon.plus(a, z) == a && + mon.plus(z, a) == a + } + } + + property("plus produces correct distribution") { + val mon = implicitly[Monoid[Reservoir[Int]]] + + forAllSampled(10000, Gen.choose(1, 20))(n => uniform(2 * n)) { n => + val left = new Reservoir[Int](n) + val right = new Reservoir[Int](n) + (0 until n).foreach(left.accept(_, rng)) + (n until 2 * n).foreach(right.accept(_, rng)) + + val c = mon.plus(left, right) + c.reservoir(rng.nextInt(c.size)) + } + } +} diff --git a/algebird-test/src/test/scala/com/twitter/algebird/mutable/ReservoirSamplingTest.scala b/algebird-test/src/test/scala/com/twitter/algebird/mutable/ReservoirSamplingTest.scala new file mode 100644 index 000000000..961cb51a2 --- /dev/null +++ b/algebird-test/src/test/scala/com/twitter/algebird/mutable/ReservoirSamplingTest.scala @@ -0,0 +1,20 @@ +package com.twitter.algebird.mutable + +import com.twitter.algebird.CheckProperties +import com.twitter.algebird.RandomSamplingLaws._ + +import scala.util.Random + +class ReservoirSamplingTest extends CheckProperties { + + val rng = new Random() + implicit val randomSupplier: () => Random = () => rng + + property("reservoir sampling with Algorithm L works") { + randomSamplingDistributions(new ReservoirSamplingToListAggregator[Int](_)) + } + + property("reservoir sampling with priority queue works") { + randomSamplingDistributions(new PriorityQueueToListAggregator[Int](_)) + } +} diff --git a/build.sbt b/build.sbt index 97e3f1f7f..262a04c97 100644 --- a/build.sbt +++ b/build.sbt @@ -240,6 +240,7 @@ lazy val algebirdTest = module("test") Seq( "org.scalacheck" %% "scalacheck" % scalacheckVersion, "org.scalatest" %% "scalatest" % scalaTestVersion, + "org.apache.commons" % "commons-statistics-inference" % "1.1", "org.scalatestplus" %% "scalatestplus-scalacheck" % scalaTestPlusVersion % "test" ) ++ { if (isScala213x(scalaVersion.value)) {