Skip to content

Commit

Permalink
[SPARK-18678][ML] Skewed reservoir sampling in SamplingUtils
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

Fix reservoir sampling bias for small k. An off-by-one error meant that the probability of replacement was slightly too high -- k/(l-1) after l element instead of k/l, which matters for small k.

## How was this patch tested?

Existing test plus new test case.

Author: Sean Owen <[email protected]>

Closes apache#16129 from srowen/SPARK-18678.
  • Loading branch information
srowen committed Dec 7, 2016
1 parent b828027 commit 79f5f28
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 5 deletions.
9 changes: 5 additions & 4 deletions R/pkg/inst/tests/testthat/test_mllib.R
Original file line number Diff line number Diff line change
Expand Up @@ -1007,10 +1007,11 @@ test_that("spark.randomForest", {
model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16,
numTrees = 20, seed = 123)
predictions <- collect(predict(model, data))
expect_equal(predictions$prediction, c(60.379, 61.096, 60.636, 62.258,
63.736, 64.296, 64.868, 64.300,
66.709, 67.697, 67.966, 67.252,
68.866, 69.593, 69.195, 69.658),
expect_equal(predictions$prediction, c(60.32820, 61.22315, 60.69025, 62.11070,
63.53160, 64.05470, 65.12710, 64.30450,
66.70910, 67.86125, 68.08700, 67.21865,
68.89275, 69.53180, 69.39640, 69.68250),

tolerance = 1e-4)
stats <- summary(model)
expect_equal(stats$numTrees, 20)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,14 @@ private[spark] object SamplingUtils {
val rand = new XORShiftRandom(seed)
while (input.hasNext) {
val item = input.next()
l += 1
// There are k elements in the reservoir, and the l-th element has been
// consumed. It should be chosen with probability k/l. The expression
// below is a random long chosen uniformly from [0,l)
val replacementIndex = (rand.nextDouble() * l).toLong
if (replacementIndex < k) {
reservoir(replacementIndex.toInt) = item
}
l += 1
}
(reservoir, l)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,19 @@ class SamplingUtilsSuite extends SparkFunSuite {
assert(sample3.length === 10)
}

test("SPARK-18678 reservoirSampleAndCount with tiny input") {
val input = Seq(0, 1)
val counts = new Array[Int](input.size)
for (i <- 0 until 500) {
val (samples, inputSize) = SamplingUtils.reservoirSampleAndCount(input.iterator, 1)
assert(inputSize === 2)
assert(samples.length === 1)
counts(samples.head) += 1
}
// If correct, should be true with prob ~ 0.99999707
assert(math.abs(counts(0) - counts(1)) <= 100)
}

test("computeFraction") {
// test that the computed fraction guarantees enough data points
// in the sample with a failure rate <= 0.0001
Expand Down

0 comments on commit 79f5f28

Please sign in to comment.