Skip to content

Instantly share code, notes, and snippets.

@bhuemer
Last active August 29, 2015 14:01
Show Gist options
  • Save bhuemer/657f59312165d63feaf6 to your computer and use it in GitHub Desktop.
Save bhuemer/657f59312165d63feaf6 to your computer and use it in GitHub Desktop.
package at.bhuemer.fpis.chapter07
import java.util.concurrent._
import java.util.concurrent.atomic.AtomicReference
/**
 * A parallel computation that, given an executor, yields a
 * `java.util.concurrent.Future` for its result. It is modelled as a dedicated
 * trait instead of a plain function type so that concrete representations
 * (such as `UnitPar`) can be recognised by pattern matching.
 */
sealed trait Par[A] { def apply(ex: ExecutorService): Future[A] }
/**
 * A [[Par]] holding an already-evaluated value. Because it is a case class it
 * can be recognised by pattern matching (e.g. to fold two constants together).
 * Named `UnitPar` rather than `Unit` to avoid clashing with `scala.Unit`.
 */
case class UnitPar[A](a: A) extends Par[A] {
  /** Logs the value and hands back a future that is already done. */
  override def apply(ex: ExecutorService): Future[A] = {
    val message = "- unit " + a
    println(message)
    UnitFuture(() => a)
  }
}
/**
 * An already-completed `Future` whose result is produced by `f`.
 *
 * `f` is invoked on every `get`; the timeout arguments are ignored because
 * there is never anything to wait for. Since the future is always done it can
 * never be cancelled, so `cancel` returns `false`, consistent with `isCancelled`.
 */
case class UnitFuture[A](f: () => A) extends Future[A] {
  override def get(timeout: Long, unit: TimeUnit): A = f()
  override def get(): A = f()
  override def isDone: Boolean = true
  override def isCancelled: Boolean = false
  // The Future contract says cancel returns false for an already-completed
  // task; returning true would also contradict isCancelled == false.
  override def cancel(mayInterruptIfRunning: Boolean): Boolean = false
}
/**
 * Combines two futures with `f`. It is done once both inputs are done, counts
 * as cancelled if either input is, and applies `f` afresh on every `get`.
 */
case class Map2Future[A,B,C](fa: Future[A], fb: Future[B], f: (A, B) => C) extends Future[C] {
  override def get(): C = {
    val left = fa.get()
    val right = fb.get()
    f(left, right)
  }

  /** Waits for both inputs using one shared [[Timeout]] budget, then applies `f`. */
  override def get(timeout: Long, unit: TimeUnit): C = {
    val budget = new Timeout(timeout, unit)
    val left = budget { (remaining: Long, remainingUnit: TimeUnit) => fa.get(remaining, remainingUnit) }
    val right = budget { (remaining: Long, remainingUnit: TimeUnit) => fb.get(remaining, remainingUnit) }
    f(left, right)
  }

  override def isDone: Boolean = fa.isDone && fb.isDone

  override def isCancelled: Boolean = fa.isCancelled || fb.isCancelled

  /** Tries to cancel both inputs (no short-circuit) and reports whether both succeeded. */
  override def cancel(mayInterruptIfRunning: Boolean): Boolean = {
    val outcomes = List(fa.cancel(mayInterruptIfRunning), fb.cancel(mayInterruptIfRunning))
    outcomes.forall(identity)
  }
}
/**
 * Shares one overall time budget across several consecutive blocking calls.
 *
 * Each `apply` hands the budget still remaining to `f` and then deducts the
 * time `f` actually took, so e.g. two sequential `Future.get` calls together
 * wait no longer than the original timeout.
 *
 * The budget is tracked in nanoseconds via `System.nanoTime`, which (unlike
 * `System.currentTimeMillis`) is monotonic and does not truncate
 * sub-millisecond timeouts to zero. Instances hold unsynchronized mutable
 * state, so use one per caller.
 *
 * @param timeout total budget, in `unit`s
 * @param unit    time unit of `timeout`
 */
class Timeout(private val timeout: Long, private val unit: TimeUnit) {
  private var remainingNanos: Long = unit.toNanos(timeout)

  /**
   * Calls `f` with the remaining budget and deducts the elapsed time from it.
   *
   * @param f blocking action that accepts a timeout as a (duration, unit) pair
   * @return whatever `f` returns
   */
  def apply[A](f: (Long, TimeUnit) => A): A = {
    val start = System.nanoTime
    val result = f(remainingNanos, TimeUnit.NANOSECONDS)
    // Never go below zero: a zero timeout already means "do not wait".
    remainingNanos = math.max(0L, remainingNanos - (System.nanoTime - start))
    result
  }
}
/**
 * Poor man's promise built on Java's `Future` interface.
 *
 * `set` stores the value and releases every thread waiting in `get`. This
 * promise has no failure state and cannot be cancelled.
 */
class Promise[A] extends Future[A] {
  private val countDownLatch = new CountDownLatch(1)
  private val reference = new AtomicReference[A]()

  /**
   * Completes the promise with `a` and wakes up all waiting readers.
   * NOTE(review): calling this more than once overwrites the stored value.
   */
  def set(a: A): Unit = {
    // Store before counting down so readers released by the latch see the value.
    reference.set(a)
    countDownLatch.countDown()
  }

  /**
   * Waits at most the given time for the value.
   *
   * @throws TimeoutException if the promise was not completed in time
   */
  override def get(timeout: Long, unit: TimeUnit): A = {
    // await returns false on timeout; reading the reference then would yield null.
    if (!countDownLatch.await(timeout, unit))
      throw new TimeoutException(s"Promise not completed within $timeout $unit")
    reference.get()
  }

  /** Blocks until `set` has been called, then returns the value. */
  override def get(): A = {
    countDownLatch.await()
    reference.get()
  }

  override def isDone: Boolean = countDownLatch.getCount == 0
  override def isCancelled: Boolean = false
  override def cancel(mayInterruptIfRunning: Boolean): Boolean = false
}
object Par {
  /**
   * Wraps a plain function as a [[Par]], so call sites can still treat a
   * `Par[A]` as if it were defined as
   * `type Par[A] = ExecutorService => Future[A]`.
   */
  private def asPar[A](body: ExecutorService => Future[A]): Par[A] = new Par[A] {
    override def apply(ex: ExecutorService): Future[A] = body(ex)
  }

  /** Lifts an already-computed value; represented as [[UnitPar]] so `map2` can fold it. */
  def unit[A](a: A): Par[A] = UnitPar(a)

  /**
   * Marks `a` for evaluation on the executor instead of the calling thread.
   *
   * The returned future is a [[Promise]] that gets completed once the forked
   * computation's future reports `isDone`. Rather than blocking an executor
   * thread on that future, completion is polled by re-submitting a task, so no
   * executor thread is held while waiting.
   *
   * NOTE(review): if `a(ex)` or the inner future's `get()` throws, the exception
   * is captured only in the (discarded) future returned by `ex.submit`; the
   * promise is never completed and callers of `get()` block indefinitely.
   * Fixing that needs a failure path on [[Promise]].
   */
  def fork[A](a: => Par[A]): Par[A] =
    asPar { ex =>
      println("- fork ")
      val promise = new Promise[A]

      // Completes the promise if `future` is done; otherwise re-queues itself.
      def evaluate(future: Future[A]): Unit = {
        if (future.isDone) {
          promise.set(future.get())
        } else {
          ex.submit(new Runnable {
            override def run(): Unit =
              evaluate(future)
          })
        }
      }

      ex.submit(new Runnable {
        override def run(): Unit = {
          // Build the forked computation's future exactly once, on the executor;
          // later polls go through `evaluate` and reuse it.
          val future = a(ex)
          evaluate(future)
        }
      })
      promise
    }

  /** Defers evaluating `a` until the forked task runs on the executor. */
  def lazyUnit[A](a: => A): Par[A] = fork { unit(a) }

  /** Turns `f` into a function whose result is computed asynchronously. */
  def asyncF[A,B](f: A => B)(a: A): Par[B] = lazyUnit(f(a))

  /**
   * Caches the future produced by the first run of `a`; every later run returns
   * that same future (even on a different executor).
   *
   * The check-and-store is done under a lock so concurrent first runs cannot
   * each call `a(ex)` and start the computation more than once.
   */
  def memoize[A](a: => Par[A]): Par[A] = {
    val futureReference = new AtomicReference[Future[A]]()
    asPar { ex =>
      // Fast path without locking once the future has been cached.
      Option(futureReference.get()).getOrElse {
        futureReference.synchronized {
          Option(futureReference.get()).getOrElse {
            val future = a(ex)
            futureReference.set(future)
            future
          }
        }
      }
    }
  }

  /** Applies `f` to the result of `pa`; `map(unit(x))(f)` folds to `unit(f(x))`. */
  def map[A,B](pa: Par[A])(f: A => B): Par[B] =
    map2(pa, unit(()))((a, _) => f(a))

  /** Combines the results of `pa` and `pb` with `f`. */
  def map2[A,B,C](pa: Par[A], pb: Par[B])(f: (A, B) => C): Par[C] =
    (pa, pb) match {
      // Both sides are constants: fold them now, much like a compiler turns "3 + 4" into 7.
      case (UnitPar(a), UnitPar(b)) => unit { f(a, b) }
      // Otherwise there is nothing to simplify: run both and combine their futures.
      case _ => asPar { ex => Map2Future(pa(ex), pb(ex), f) }
    }

  /** Collects a list of computations into one computation of a list (order preserved). */
  def sequence[A](pas: List[Par[A]]): Par[List[A]] = pas match {
    case Nil => unit(Nil)
    case head :: tail =>
      map2(head, sequence(tail))(_ :: _)
  }

  /** Maps `f` over `as`, computing each element asynchronously. */
  def parMap[A,B](as: List[A])(f: A => B): Par[List[B]] =
    sequence(as.map(asyncF(f)))
}
object ParApp {
  private val executorService = Executors.newSingleThreadExecutor()

  /** Runs `p` on this app's single-thread executor and blocks for its result. */
  def run[A](p: Par[A]): A = p(executorService).get()

  /** Demonstrates constant folding, forking and memoization, printing each result. */
  def main(args: Array[String]): Unit = {
    try {
      val e = Par.map(Par.unit(5))(_ + 1) // will be compressed into Par.unit(6)
      val f = Par.map2(Par.unit(5), Par.unit(3))(_ + _) // will be compressed into Par.unit(8)

      // when running this you'll see:
      // - fork
      // - unit 4 (most probably this comes first)
      // lazy run
      // - unit 1
      // 5 (the final result)
      val g = Par.map2(Par.lazyUnit {
        println("lazy run")
        1
      }, Par.unit(4))(_ + _)

      // when running this you'll see:
      // - fork
      // - fork
      // - unit 4
      // - unit 3
      // 7 (the final result)
      val h = Par.map2(
        Par.lazyUnit(4),
        Par.lazyUnit(3)
      )(_ + _)

      // when running this you'll see:
      // - fork
      // - fork
      // - fork
      // - unit 8
      // - unit 10
      // 18
      val i = Par.map2(
        Par.fork {
          Par.fork {
            // will be compressed into Par.unit(10)
            Par.map2(
              Par.unit(9),
              Par.unit(1)
            )(_ + _)
          }
        },
        Par.fork {
          // will be compressed into Par.unit(8)
          Par.map2(
            Par.unit(3),
            Par.unit(5)
          )(_ + _)
        }
      )(_ + _)

      // j will only be evaluated once!
      val j = Par.memoize {
        Par.lazyUnit {
          println("call!")
          1
        }
      }
      val k = Par.map2(j, j)(_ + _)

      println(run(e))
      println(run(f))
      println(run(g))
      println(run(h))
      println(run(i))
      println(run(j))
      println(run(k))
    } finally {
      // Always release the executor: its worker thread is non-daemon and would
      // otherwise keep the JVM alive if any of the runs above throws.
      executorService.shutdown()
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment