Skip to content

Instantly share code, notes, and snippets.

@redwrasse
Created August 12, 2016 04:06
Show Gist options
  • Save redwrasse/a91cb9fd519741ae083bf7f229d3727c to your computer and use it in GitHub Desktop.
Save redwrasse/a91cb9fd519741ae083bf7f229d3727c to your computer and use it in GitHub Desktop.
Distributed median binning with Spark
/**
* Distributed median binning
*
* See
* "Fast Computation of the Median by Successive Binning"
* https://www.stat.cmu.edu/~ryantibs/papers/median.pdf
*
* For an odd number of elements this returns the exact median; for an even
* number it returns the upper of the two middle elements, not their average.
* See https://github.com/goodsoldiersvejk/medianbinning
*/
import org.apache.spark
.{SparkContext}
import org.apache.spark.rdd
.{RDD}
import org.apache.spark
.{Partitioner}
import org.slf4j
.{Logger, LoggerFactory}
object MedianBinning {
  val logger: Logger = LoggerFactory.getLogger(getClass)

  /**
   * Partitioner that maps an `Int` key to one of `numBins` equal-width bins
   * spanning the closed range [minValue, maxValue].
   *
   * A key k goes to bin floor((k - minValue) / binSize). The key equal to
   * `maxValue` (which would land on index `numBins`) is folded into the last
   * bin, and any other out-of-range index is clamped into [0, numBins - 1].
   * Non-`Int` keys, and all keys when the range is empty (minValue == maxValue),
   * go to partition 0.
   *
   * @param numBins  number of bins, which is also the number of partitions
   * @param minValue smallest key expected
   * @param maxValue largest key expected
   */
  class BinPartitioner(numBins: Int, minValue: Int,
      maxValue: Int) extends Partitioner {
    // Width of one bin. Computed in Double so that maxValue - minValue cannot
    // overflow Int (e.g. for the range Int.MinValue .. Int.MaxValue).
    val binSize: Double = (maxValue.toDouble - minValue) / numBins

    def getPartition(key: Any): Int = key match {
      case k: Int if binSize > 0 =>
        val bin = ((k.toDouble - minValue) / binSize).toInt
        // k == maxValue yields numBins; clamp so the id is always a valid partition.
        math.max(0, math.min(bin, numBins - 1))
      case _ => 0
    }

    def numPartitions: Int = numBins
  }

  /**
   * Demo entry point: computes the median of a small hard-coded sample and
   * prints it next to the known answer. The no-argument SparkContext reads its
   * settings from system properties (e.g. when launched via spark-submit).
   */
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext()
    try {
      val sampleRdd = sc.parallelize(List(6, 1, 1, 4, 12, 7, 8, 25, 8))
      val numBins = 3
      val trueMedian = 7
      val calculatedMedian = findMedian(sampleRdd, numBins)
      println(s"NUMBER OF BINS: $numBins")
      println(s"TRUE MEDIAN: $trueMedian")
      println(s"CALCULATED MEDIAN: $calculatedMedian")
    } finally {
      // Release cluster resources even if the computation fails.
      sc.stop()
    }
  }

  /**
   * Finds the median of `rdd` by successive binning.
   *
   * Each round does the following:
   *  - splits the current candidate values into `numBins` equal-width bins
   *    (via [[BinPartitioner]]);
   *  - keeps only the bin that contains the target rank;
   *  - stops once every remaining candidate has the same value.
   *
   * For an odd element count n the result is the exact median. For an even
   * count it is the upper of the two middle elements, i.e. the (n/2 + 1)-th
   * smallest, not their average.
   *
   * @param rdd     values to take the median of; must be non-empty
   * @param numBins bins per round; must be at least 2 so that every round
   *                discards at least one candidate
   * @return the (n/2 + 1)-th smallest element of `rdd` (1-based), n = rdd.count()
   * @throws IllegalArgumentException      if numBins < 2
   * @throws UnsupportedOperationException if rdd is empty
   */
  def findMedian(rdd: RDD[Int], numBins: Int): Int = {
    require(numBins >= 2, s"numBins must be at least 2, got $numBins")
    val totalCt: Long = rdd.count()
    // Same exception type RDD.min raises on an empty RDD.
    if (totalCt == 0L) throw new UnsupportedOperationException("empty collection")
    // 1-based rank of the element being searched for.
    val halfCt = totalCt / 2 + 1

    /** Smallest and largest element of a non-empty RDD, in a single Spark job. */
    def minMax(data: RDD[Int]): (Int, Int) =
      data.aggregate((Int.MaxValue, Int.MinValue))(
        { case ((lo, hi), x) => (math.min(lo, x), math.max(hi, x)) },
        { case ((lo1, hi1), (lo2, hi2)) => (math.min(lo1, lo2), math.max(hi1, hi2)) })

    /**
     * Narrows `currentRdd` to the bin holding rank `halfCt` until all
     * remaining candidates are equal.
     *
     * @param currentRdd candidate values; the target element is among them
     * @param leftCt     number of elements of `rdd` already discarded because
     *                   they are smaller than every candidate
     */
    @scala.annotation.tailrec
    def narrow(currentRdd: RDD[Int], leftCt: Long): Int = {
      val (minValue, maxValue) = minMax(currentRdd)
      // If every remaining candidate is equal, the target must be that value.
      // This covers the single-candidate case. It also guarantees termination
      // when the median value is duplicated, because such a bin could never split.
      if (minValue == maxValue) minValue
      else {
        val binPartitionedRdd = currentRdd.map(e => (e, e))
          .partitionBy(new BinPartitioner(numBins, minValue, maxValue))
        // collect() returns partitions in index order, so binCounts(b) is the size of bin b.
        val binCounts: Array[Long] = binPartitionedRdd
          .mapPartitionsWithIndex((_, it) => Iterator(it.size.toLong))
          .collect()
        // cumulative(j) = leftCt + total size of bins 0 until j. Since
        // cumulative(0) = leftCt < halfCt, the first j reaching halfCt is >= 1
        // and bin j - 1 contains the target rank.
        val cumulative = binCounts.scanLeft(leftCt)(_ + _)
        val medianBin = cumulative.indexWhere(_ >= halfCt) - 1
        val medianBinRdd = binPartitionedRdd.mapPartitionsWithIndex { (idx, it) =>
          if (idx == medianBin) it.map(_._1) else Iterator.empty
        }
        narrow(medianBinRdd, cumulative(medianBin))
      }
    }

    narrow(rdd, 0L)
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment