Created October 18, 2016 13:44
// These two lines are only for code completion in IDEA,
// don't paste them into spark-shell
val spark: org.apache.spark.sql.SparkSession
import spark.implicits._

import scala.util.Random

@transient val sc = spark.sparkContext

// # of rows 1M-10M. Running time is quadratic,
// so doubling # of rows increases running time by a factor of 4
val numRows = 1000000

val r0 = sc.parallelize(0 until numRows, 1)
val r1 = r0.mapPartitionsWithIndex { case (idx, it) =>
  val rnd = new Random(100500 + idx)
  for (id <- it) yield {
    (id, rnd.nextInt(200), rnd.nextInt(1000), rnd.nextInt(10))
  }
}
val t1 = r1.toDF("id", "bucket", "val1", "val2")
t1.createOrReplaceTempView("t1")
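
// Optional sanity check (an added line, not part of the original gist):
// t1 should have four integer columns: id, bucket, val1, val2.
t1.printSchema()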
// Both Spark Iterables and Scala collections are slow, so we need a custom
// function that extracts the values into primitive arrays
def iterToArray(it: Iterable[(Int, Int)]): (Int, Array[Int], Array[Int]) = {
  val capacity = it.size
  val res1 = new Array[Int](capacity)
  val res2 = new Array[Int](capacity)
  for (((v1, v2), i) <- it.iterator.zipWithIndex) {
    res1(i) = v1
    res2(i) = v2
  }
  (capacity, res1, res2)
}
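
// Small self-check of iterToArray (an added example, not in the original gist;
// the chk* names are throwaway local vals): a two-element bucket should yield
// the element count plus one primitive array per field.
val (chkSize, chkV1, chkV2) = iterToArray(Seq((1, 10), (2, 20)))
assert(chkSize == 2 && chkV1.toSeq == Seq(1, 2) && chkV2.toSeq == Seq(10, 20))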
// Spark SQL version of the query (run three times)
for (_ <- 0 until 3) {
  spark.sql(
    "select a.bucket, sum(a.val2) tot " +
    "from t1 a, t1 b " +
    "where a.bucket=b.bucket and a.val1+b.val1<1000 " +
    "group by a.bucket order by a.bucket").show(10)
}
// Hand-written RDD version of the same query (run three times):
// group by bucket, then do the quadratic pairwise pass inside each bucket
for (_ <- 0 until 3) {
  val res1 = r1
    .map { case (id, bucket, val1, val2) =>
      bucket -> (val1, val2)
    }
    .groupByKey(64)
    .map { case (bucket, it) =>
      var total = 0L
      val (sz, buf1, buf2) = iterToArray(it)
      for {
        a <- 0 until sz
        b <- 0 until sz
      } if (buf1(a) + buf1(b) < 1000) total += buf2(a)
      bucket -> total
    }
  res1.toDF("bucket", "tot").repartition(1).orderBy("bucket").show(10)
}
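
// For reference only: a sketch of the same self-join written with the DataFrame API
// instead of the SQL string. This is an addition, not part of the original benchmark;
// dfJoin is a hypothetical local name, col/sum come from org.apache.spark.sql.functions.
import org.apache.spark.sql.functions.{col, sum}

val dfJoin = t1.as("a")
  .join(t1.as("b"), col("a.bucket") === col("b.bucket") && (col("a.val1") + col("b.val1") < 1000))
  .groupBy(col("a.bucket"))
  .agg(sum(col("a.val2")).as("tot"))
  .orderBy("bucket")
dfJoin.show(10)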