Created
May 19, 2022 13:57
-
-
Save igor-ramazanov/9e2217a35c36b4c8b1c021eaf8d7143a to your computer and use it in GitHub Desktop.
Toy IO/Fiber system
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.concurrent.atomic.AtomicInteger | |
import java.util.concurrent.Executors | |
import java.util.{Timer, TimerTask} | |
import scala.concurrent.duration._ | |
import scala.concurrent.{Await, ExecutionContext, ExecutionContextExecutor, Promise} | |
import scala.util.chaining._ | |
import scala.util.{Failure, Success, Try} | |
/** Runtime environment for [[IO]]: a timer for delayed work plus the pool that runs every step.
  *
  * @param timer      used only to wait out delays; actual work is always handed to `threadPool`
  * @param threadPool executes every evaluation step
  */
class Context(private val timer: Timer, private val threadPool: ExecutionContext) {

  /** Runs `task` on the thread pool once `fd` has elapsed.
    *
    * The timer thread merely forwards the task to the pool, so a slow or failing task never
    * stalls other timer tasks. Negative durations are clamped to zero (run as soon as possible)
    * because `java.util.Timer.schedule` rejects a negative delay with `IllegalArgumentException`.
    */
  def schedule(task: => Unit, fd: FiniteDuration): Unit = {
    val timerTask = new TimerTask {
      def run(): Unit = sendToPool(task)
    }
    timer.schedule(timerTask, math.max(fd.toMillis, 0L))
  }

  /** Submits `task` for execution on the thread pool, discarding its result. */
  def sendToPool[A](task: => A): Unit = threadPool.execute(() => { val _ = task })
}
/** A lazy description of a computation yielding an `A`; nothing happens until it is run. */
sealed trait IO[A] extends Product with Serializable {

  /** Runs this IO and blocks the calling thread until the outcome is known.
    *
    * Built on top of [[runAsync]]. NOTE(review): the work is submitted to the context's pool,
    * so calling this from one of that pool's own threads can deadlock when no other thread is
    * free to run it.
    */
  def runSync()(implicit context: Context): Try[A] = {
    val outcome = Promise[A]()
    runAsync(result => outcome.complete(result))
    // `Await.ready` only returns once the future is completed, so `value` is always defined here.
    Await.ready(outcome.future, Duration.Inf).value.get
  }

  /** Starts this IO on the context's pool and hands its outcome to `callback`.
    *
    * @param callback receives the result of the computation; the interpreter is meant to call
    *                 it once per run.
    */
  def runAsync(callback: Try[A] => Unit)(implicit context: Context): Unit = {
    // The interpreter works on untyped nodes, so the callback is widened to accept any Try.
    val untypedCallback = callback.asInstanceOf[Try[_] => Unit]
    IO.runLoop(this)(untypedCallback)
  }

  def map[B](f: A => B): IO[B] = IO.Map(this, f)

  /** Replaces the successful result with `b`. */
  def as[B](b: B): IO[B] = map(_ => b)

  def flatMap[B](f: A => IO[B]): IO[B] = IO.FlatMap(this, f)

  /** Sequences `f` after this IO, discarding this IO's result. */
  def >>[B](f: IO[B]): IO[B] = flatMap(_ => f)

  def recover(f: PartialFunction[Throwable, A]): IO[A] = IO.Recover(this, f)

  def recoverWith(f: PartialFunction[Throwable, IO[A]]): IO[A] = IO.RecoverWith(this, f)

  /** Starts this IO concurrently when evaluated, yielding a fiber that can later be joined. */
  def fork(): IO[IO.Fiber[A]] = IO.Fork(this)
}
object IO {

  /** Suspends a side effect; `thunk` is re-evaluated every time the resulting IO is run. */
  def apply[A](thunk: => A): IO[A] = IO.Delay(() => thunk)

  /** Lifts an already computed value. */
  def pure[A](a: A): IO[A] = IO.Pure(a)

  /** Bridges a callback-based API.
    *
    * `cb` is given the completion callback. Only its first invocation is honoured, and a
    * non-fatal exception thrown by `cb` before it completes fails the IO instead of hanging it.
    */
  def async[A](cb: (Try[A] => Unit) => Unit): IO[A] = IO.Async(cb)

  /** Completes with `()` after `duration`; the wait is handled by the context's timer, not a pool thread. */
  def sleep(duration: FiniteDuration): IO[Unit] = IO.Sleep(duration)

  /** Fails with `e`. Typed as `IO[Unit]`, so sequence it with `>>`/`flatMap` where another type is needed. */
  def raiseError(e: Throwable): IO[Unit] = IO.Error(e)

  /** Prints `s`, prefixed with the name of the thread that evaluates it. */
  def putStrLn(s: String): IO[Unit] = IO(println(s"${Thread.currentThread().getName}: $s"))

  /** Continues the computation in a fresh task on the context's pool, giving other queued tasks a chance to run. */
  def shift: IO[Unit] = IO.Shift

  final private case class Pure[A](value: A) extends IO[A]
  final private case class Delay[A](thunk: () => A) extends IO[A]
  final private case class Async[A](callback: (Try[A] => Unit) => Unit) extends IO[A]
  final private case class FlatMap[A, B](prev: IO[A], f: A => IO[B]) extends IO[B]
  final private case class Map[A, B](prev: IO[A], f: A => B) extends IO[B]
  final private case class Recover[A](prev: IO[A], f: PartialFunction[Throwable, A]) extends IO[A]
  final private case class RecoverWith[A](prev: IO[A], f: PartialFunction[Throwable, IO[A]]) extends IO[A]
  final private case class Error(e: Throwable) extends IO[Unit]
  // Always completes with `()`, so it is typed IO[Unit] rather than an unconstrained IO[A].
  final private case class Sleep(duration: FiniteDuration) extends IO[Unit]
  final private case class Fork[A](io: IO[A]) extends IO[Fiber[A]]
  final private case class Join[A](fiber: Fiber[A]) extends IO[A]
  final private case object Shift extends IO[Unit]

  /** Handle to a computation started with `fork()`; `join()` waits for its outcome.
    *
    * Joiners registered before completion are notified when it finishes; later joiners receive
    * the stored result immediately.
    */
  class Fiber[A] {
    // Both fields are only read or written while holding this fiber's monitor.
    // A Vector (not a Set) so two registrations of the same function instance both fire.
    private var callbacks = Vector.empty[Try[A] => Unit]
    private var result = Option.empty[Try[A]]

    def join(): IO[A] = IO.Join(this)

    private[IO] def register(cb: Try[A] => Unit): Unit = {
      // Check-and-enqueue atomically so a concurrent `finish` cannot miss this callback.
      val alreadyDone: Option[Try[A]] = synchronized {
        if (result.isEmpty) callbacks = callbacks :+ cb
        result
      }
      // Invoked outside the lock: `cb` runs the joiner's continuation, which may take arbitrarily long.
      alreadyDone.foreach(cb)
    }

    private[IO] def finish(res: Try[A]): Unit = {
      val waiting: Vector[Try[A] => Unit] = synchronized {
        if (result.isDefined) Vector.empty // already finished: the first result wins
        else {
          result = Some(res)
          val snapshot = callbacks
          callbacks = Vector.empty
          snapshot
        }
      }
      // Each callback is taken out exactly once above; run them without holding the monitor.
      waiting.foreach(_(res))
    }
  }

  /** Optimised for maximum throughput, fairness must be ensured by the end developer by using [[IO.shift]]. */
  private def runLoop(io: IO[_])(done: Try[_] => Unit)(implicit context: Context): Unit =
    // Evaluation should run in an intended thread pool since the beginning.
    // Otherwise, the first computations would run in a default 'main' JVM thread.
    context.sendToPool(eval(io)(done))

  /** Interprets `io`, delivering its outcome to `done`.
    *
    * Non-fatal exceptions thrown by user code (thunks, continuations, async registrations) are
    * turned into `Failure`s so that `done` is always reached.
    * NOTE(review): each `prev` layer is evaluated by a nested `eval` call on the current thread,
    * so very deeply nested `map`/`flatMap` chains may overflow the stack.
    */
  private def eval(io: IO[_])(done: Try[_] => Unit)(implicit context: Context): Unit =
    io match {
      case IO.Pure(value) => done(Success(value))
      case IO.Delay(thunk) => done(Try(thunk()))
      case IO.Async(asyncTaskDefinition) =>
        // Enforce the "callback runs at most once" contract even if user code calls it again.
        val completed = new java.util.concurrent.atomic.AtomicBoolean(false)
        val once: Try[_] => Unit = res => if (completed.compareAndSet(false, true)) done(res)
        try asyncTaskDefinition(once)
        catch {
          // The registration threw before completing: fail the IO rather than leave it hanging.
          case scala.util.control.NonFatal(e) if completed.compareAndSet(false, true) => done(Failure(e))
        }
      case IO.FlatMap(prev, f) =>
        eval(prev) {
          case Success(value) =>
            // `f` is user code: a throw here must fail the IO, not escape the interpreter.
            Try(f.asInstanceOf[Any => IO[_]](value)) match {
              case Success(next) => eval(next)(done)
              case Failure(e)    => done(Failure(e))
            }
          case failure => done(failure)
        }
      case IO.Map(prev, f) => eval(prev)(res => done(res.map(f.asInstanceOf[Any => Any])))
      case IO.Recover(prev, f) => eval(prev)(res => done(res.recover(f)))
      case IO.RecoverWith(prev, f) =>
        eval(prev) {
          case Failure(e) if f.isDefinedAt(e) =>
            Try(f(e)) match {
              case Success(next)   => eval(next)(done)
              case Failure(thrown) => done(Failure(thrown))
            }
          case other => done(other)
        }
      case IO.Error(e) => done(Failure(e))
      case IO.Sleep(duration) => context.schedule(done(Success(())), duration)
      case fork: IO.Fork[_] =>
        val fiber = new Fiber[Any]
        context.sendToPool(eval(fork.io)(fiber.finish))
        done(Success(fiber))
      case IO.Join(fiber) => fiber.register(done)
      // Complete on a new pool task; re-evaluating `io` here would resubmit Shift forever.
      case IO.Shift => context.sendToPool(done(Success(())))
    }
}
/** Demo: forks a slow computation, joins it twice, then raises an error and recovers from it. */
object Main {

  def main(args: Array[String]): Unit = {
    // The background fiber prints "1", later "2", and finally yields 42.
    val slowAnswer: IO[Int] =
      IO.sleep(1.second) >> IO.putStrLn("1") >> IO.sleep(3.second) >> IO.putStrLn("2") >> IO.pure(42)

    // The second join happens after the fiber has finished and picks up its stored result.
    val joinTwiceThenFail: IO[Unit] = for {
      fiber  <- slowAnswer.fork()
      _      <- IO.putStrLn("3")
      value  <- fiber.join()
      value2 <- fiber.join()
      _      <- IO.raiseError(new RuntimeException(s"Boom! $value $value2"))
    } yield ()

    // The deliberate failure is caught and its message printed instead.
    val program: IO[Unit] = joinTwiceThenFail.recoverWith { case e => IO.putStrLn(e.getMessage) }

    // Daemon timer plus a single daemon worker thread, so the JVM can exit once main returns.
    implicit val context: Context = {
      val timer = new Timer("Pet IO Timer", true)
      val nextThreadId = new AtomicInteger(0)
      val workerCount = 1
      val workers = Executors.newFixedThreadPool(
        workerCount,
        (r: Runnable) => {
          val worker = new Thread(r)
          worker.setDaemon(true)
          worker.setName(s"Pet IO ${nextThreadId.getAndIncrement()}")
          worker
        }
      )
      new Context(timer, ExecutionContext.fromExecutor(workers))
    }

    println(program.runSync())
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment