// GitHub gist igor-ramazanov/3bf9e0ca2363973b6cbaffa86f78005d (last active March 29, 2022 11:26).
// Short notes taken while reading Oleg Kiselyov's lecture notes on Tagless Final:
// http://okmij.org/ftp/tagless-final/course/lecture.pdf
// (GitHub flagged this file as possibly containing bidirectional Unicode text;
// review it in an editor that reveals hidden Unicode characters.)
//==============================
//Part 1: Introduction
//==============================
//Introduction to the "initial" encoding (Free monads/applicatives are a well-known example of it)
//and the "final" encoding (also called "Tagless Final").
//Both approaches serve the same purposes:
//1. define strict, statically typed DSLs
//2. write a "description" of a program in such a DSL
//3. (optional) inspect or optimise that "description" before running it
//4. define several ways of interpreting the encoded program
//5. run the program
//We will model a simple, typed arithmetic DSL and embed it
//into the host language (Scala) using different encoding styles.
//"Initial" encoding of the language - uses an ADT (Algebraic Data Type)
// "Initial" encoding: a program is a plain value of this ADT.
// `sealed` lets the compiler check matches over Expression for exhaustiveness,
// so an interpreter that misses a variant is reported at compile time instead of
// failing with a runtime MatchError. A sealed trait's subclasses must live in the
// same source file, so this file must be compiled as one unit (worksheet/script).
// `Product with Serializable` keeps inferred types clean, e.g.
// List(Literal(1), Negation(Literal(2))) is a List[Expression].
sealed trait Expression extends Product with Serializable
final case class Literal(x: Int) extends Expression
final case class Negation(v: Expression) extends Expression
final case class Addition(a: Expression, b: Expression) extends Expression
// The running example, 8 + (-(1 + 2)),
// written in the "initial" encoding as a nested ADT value.
val ti1 = {
  val onePlusTwo = Addition(Literal(1), Literal(2))
  Addition(Literal(8), Negation(onePlusTwo))
}
// Interpreter for the initial encoding: evaluates the tree to an Int by
// structural recursion over the ADT.
// NOTE: only Literal, Negation and Addition are matched; the Mult variant
// added in Part 2 is not, so eval(Mult(...)) throws scala.MatchError.
def eval(e: Expression): Int = e match {
  case Addition(lhs, rhs) => eval(lhs) + eval(rhs)
  case Negation(inner)    => -eval(inner)
  case Literal(n)         => n
}
eval(ti1) // 5: Int
// "Final" encoding, first attempt: each language construct is an ordinary
// function that immediately computes the meaning of the expression.
// The representation is fixed to Int, so only one interpretation exists.
type Representation = Int
def literal(n: Int): Representation = n
def negation(e: Representation): Representation = -e
def addition(a: Representation, b: Representation): Representation = a + b
// Writing the program *is* evaluating it: 8 + (-(1 + 2))
val tf1 = {
  val eight = literal(8)
  val onePlusTwo = addition(literal(1), literal(2))
  addition(eight, negation(onePlusTwo))
} // 5: Int
// A second interpreter for the same initial encoding: pretty-print the tree.
// Every Negation and Addition is wrapped in parentheses.
// NOTE: like eval, it does not match Mult (Part 2) and would throw scala.MatchError.
def view(e: Expression): String = e match {
  case Addition(lhs, rhs) => s"(${view(lhs)}+${view(rhs)})"
  case Negation(inner)    => s"(-${view(inner)})"
  case Literal(n)         => n.toString
}
view(ti1) // (8+(-(1+2))): String
// The function-based "final" encoding above hard-wires Int as the only meaning.
// To allow several interpreters, abstract over the representation type: the
// language becomes a type class, and each instance is one interpreter.
// ("Symantics" is the name used in the tagless-final papers by Carette,
// Kiselyov and Shan - it is intentional, not a typo of "Semantics".)
trait ExprSymantics[Repr] {
  /** An integer literal. */
  def lit(v: Int): Repr

  /** Negation of a sub-expression. */
  def neg(a: Repr): Repr

  /** Sum of two sub-expressions. */
  def add(a: Repr, b: Repr): Repr
}
// Two interpreters, selected purely by the result type.
// Int: evaluate the expression.
implicit val expSymInt: ExprSymantics[Int] = new ExprSymantics[Int] {
  override def add(a: Int, b: Int): Int = a + b
  override def neg(a: Int): Int = -a
  override def lit(v: Int): Int = v
}
// String: pretty-print the expression, fully parenthesised.
implicit val expSymString: ExprSymantics[String] = new ExprSymantics[String] {
  override def add(a: String, b: String): String = Seq(a, b).mkString("(", "+", ")")
  override def neg(a: String): String = s"(-$a)"
  override def lit(v: Int): String = String.valueOf(v)
}
// The running example 8 + (-(1 + 2)), written once against the abstract
// language; the caller chooses the interpreter by choosing Repr.
def expr1[Repr: ExprSymantics]: Repr = {
  val sym = implicitly[ExprSymantics[Repr]]
  val eight = sym.lit(8)
  val onePlusTwo = sym.add(sym.lit(1), sym.lit(2))
  sym.add(eight, sym.neg(onePlusTwo))
}
// The same program run by two different interpreters
expr1[Int]    // 5: Int
expr1[String] // (8+(-(1+2))): String
//==============================
//Part 2: Extending the language
//==============================
//Extending the language in the "initial" encoding means:
//1. adding a new case to the ADT
//2. updating every evaluation function
//which forces changes to, or at least recompilation of, all dependent code.
//This is the "expression problem": with an ADT it is easy to add new operations on the data,
//but hard to add new data variants.
// New data variant for the initial encoding. Neither `eval` nor `view` above
// matches it, so passing a Mult to them throws scala.MatchError until every
// interpreter is updated - the "expression problem" in action.
final case class Mult(a: Expression, b: Expression) extends Expression
//Extending the language in the "final" encoding means:
//1. adding a new type class
//2. adding new interpreters (instances) for it
//3. combining the new interpreters with the existing ones
// Multiplication lives in its own type class, independent of ExprSymantics,
// so adding it does not touch any existing code. It has a single abstract
// method, so instances can also be written as lambdas (SAM conversion).
trait MultSYM[Repr] {
  /** Product of two sub-expressions. */
  def mul(a: Repr, b: Repr): Repr
}
// A program using both language fragments: 7 + (-(1 * 2)).
// Both context bounds are required, so it only compiles for a Repr that has
// interpreters for both fragments.
def expr2[Repr: ExprSymantics : MultSYM]: Repr = {
  val base = implicitly[ExprSymantics[Repr]]
  val times = implicitly[MultSYM[Repr]]
  val seven = base.lit(7)
  val product = times.mul(base.lit(1), base.lit(2))
  base.add(seven, base.neg(product))
}
// Interpreters for the new fragment, one per representation type.
implicit val multSymInt: MultSYM[Int] = new MultSYM[Int] {
  override def mul(a: Int, b: Int): Int = a * b
}
implicit val multSymString: MultSYM[String] = new MultSYM[String] {
  override def mul(a: String, b: String): String = s"($a*$b)"
}
expr2[Int]    // 5
expr2[String] // (7+(-(1*2)))
//As you can see, extending the language
//in the "final" encoding requires no changes to existing code:
//the DSL becomes easily extensible,
//and a program that uses an extension without a matching interpreter is rejected by the type checker.
//==============================
//Part 3: The deserialisation problem
//==============================
//Main statement:
//serialisation is simple (like the conversion of a program to String above),
//but deserialisation is much harder.
//Our target JSON-like format for serialisation:
// Generic, untyped tree used as the serialisation format: a Leaf carries a raw
// string, a Node carries a label and its children.
// Sealed/final so matches over Tree can be checked for exhaustiveness and the
// hierarchy cannot be extended elsewhere; Product with Serializable keeps
// inferred types such as List(Leaf("1"), Node("x", Nil)) as List[Tree].
sealed trait Tree extends Product with Serializable
final case class Leaf(v: String) extends Tree
final case class Node(v: String, ts: List[Tree]) extends Tree
// Serialisation is just one more interpreter: ExprSymantics[Tree] turns each
// language construct into a labelled Node.
implicit val expSymTree: ExprSymantics[Tree] = new ExprSymantics[Tree] {
  override def lit(v: Int): Tree = Node("Literal", Leaf(v.toString) :: Nil)
  override def neg(a: Tree): Tree = Node("Negation", a :: Nil)
  override def add(a: Tree, b: Tree): Tree = Node("Addition", a :: b :: Nil)
}
// Serialise the running example 8 + (-(1 + 2))
val tree = expr1[Tree]
//Node(Addition,List(
//  Node(Literal,List(Leaf(8))),
//  Node(Negation,List(
//    Node(Addition,List(
//      Node(Literal,List(Leaf(1))),
//      Node(Literal,List(Leaf(2))))))))): Tree
// Deserialisation: the input Tree may be malformed, so every step returns
// Either an error message (Left) or a result (Right).
type ErrMsg = String
// Parses a decimal Int; any parse failure becomes a Left with a readable message.
def safeReadInt(s: String): Either[ErrMsg, Int] = {
  import scala.util.{Failure, Success, Try}
  Try(s.toInt) match {
    case Success(n) => Right(n)
    case Failure(_) => Left(s"Couldn't parse to Int: '$s'")
  }
}
// Deserialises a Tree directly into the result of interpreter A
// (the type parameter A selects the interpreter).
// Returns the safeReadInt error for a malformed literal, and Left("Invalid tree")
// for any node shape it does not recognise.
// NOTE: recursion depth equals tree depth (this is not tail-recursive).
def fromTree[A: ExprSymantics](tree: Tree): Either[ErrMsg, A] = {
  val E = implicitly[ExprSymantics[A]]
  import E._
  tree match {
    case Node("Literal", List(Leaf(n))) =>
      // Either is right-biased (2.12+): map directly, no deprecated `.right` projection
      safeReadInt(n).map(lit)
    case Node("Negation", List(subTree)) =>
      // Explicit [A]: recurse with the same interpreter rather than relying on inference
      fromTree[A](subTree).map(neg)
    case Node("Addition", List(leftSubTree, rightSubTree)) =>
      for {
        lt <- fromTree[A](leftSubTree)
        rt <- fromTree[A](rightSubTree)
      } yield add(lt, rt)
    case _ =>
      Left("Invalid tree")
  }
}
fromTree[Int](tree)    // Right(5): Either[ErrMsg, Int]
fromTree[String](tree) // Right((8+(-(1+2)))): Either[ErrMsg,String]
// fromTree commits to one interpreter when it is called. To deserialise once and
// interpret later (possibly many times), wrap the program in a value that is
// itself polymorphic in the interpreter: a trait with a polymorphic method.
trait Wrapped {
  /** Interprets the wrapped program with whichever interpreter A is requested. */
  def value[A: ExprSymantics]: A
}
// An interpreter whose result is a Wrapped program: each construct postpones
// the choice of the target interpreter until `value` is called. Nothing is
// cached, so every call to `value` walks the whole wrapped structure again.
implicit val WrappedInterpreter: ExprSymantics[Wrapped] = new ExprSymantics[Wrapped] {
  override def lit(v: Int): Wrapped = new Wrapped {
    override def value[A: ExprSymantics]: A = {
      val sym = implicitly[ExprSymantics[A]]
      sym.lit(v)
    }
  }
  override def neg(a: Wrapped): Wrapped = new Wrapped {
    override def value[A: ExprSymantics]: A = {
      val sym = implicitly[ExprSymantics[A]]
      sym.neg(a.value[A])
    }
  }
  override def add(a: Wrapped, b: Wrapped): Wrapped = new Wrapped {
    override def value[A: ExprSymantics]: A = {
      val sym = implicitly[ExprSymantics[A]]
      sym.add(a.value[A], b.value[A])
    }
  }
}
// Deserialise once into a Wrapped program, then interpret it as often as needed.
fromTree[Wrapped](tree) match {
  case Right(wrapped) =>
    // the same deserialised value is reused with two different interpreters
    wrapped.value[Int]    // 5
    wrapped.value[String] // (8+(-(1+2)))
  case Left(err) =>
    println(err)
}
//Everything works, but we still lack extensibility:
//to support MultSYM we would have to rewrite and recompile the 'fromTree' function,
//adding a new 'case' clause.
//This problem can be solved with open recursion.
//Let's rewrite 'fromTree' in that style:
// Open-recursion version of fromTree: instead of calling itself, it recurses
// through the by-name parameter `recur`. A later extension can therefore
// handle extra node kinds and delegate everything else here; `fix` ties the knot.
// Returns the safeReadInt error for a malformed literal, and Left("Invalid tree")
// for any node shape it does not recognise.
def fromTreeExt[A: ExprSymantics](
    recur: => (Tree => Either[ErrMsg, A])
): Tree => Either[ErrMsg, A] = {
  val E = implicitly[ExprSymantics[A]]
  import E._
  // `recur` is only forced when a sub-tree is actually visited.
  val parse: Tree => Either[ErrMsg, A] = {
    case Node("Literal", List(Leaf(n))) =>
      // Either is right-biased (2.12+): no deprecated `.right` projection needed
      safeReadInt(n).map(lit)
    case Node("Negation", List(subTree)) =>
      recur(subTree).map(neg)
    case Node("Addition", List(leftSubTree, rightSubTree)) =>
      for {
        lt <- recur(leftSubTree)
        rt <- recur(rightSubTree)
      } yield add(lt, rt)
    case _ =>
      Left("Invalid tree")
  }
  parse
}
// Fixed-point combinator used to tie open-recursive definitions: the result is
// f applied to a lazy reference to the result itself.
// The fixed point is computed once and shared through a lazy val, so forcing
// the by-name argument returns the cached value instead of calling f again on
// every recursive step (this mirrors Haskell's `fix f = let x = f x in x`).
// f must not force its argument while constructing its result; if it does,
// the definition diverges (stack overflow), as with any fixed-point operator.
def fix[A](f: (=> A) => A): A = {
  lazy val knot: A = f(knot)
  knot
}
// fromTree rebuilt from the open-recursive fromTreeExt by tying the knot with fix.
def fromTree2[A: ExprSymantics](t: Tree): Either[ErrMsg, A] = {
  val parse: Tree => Either[ErrMsg, A] = fix(fromTreeExt[A] _)
  parse(t)
}
fromTree2[Int](tree)    // Right(5)
fromTree2[String](tree) // Right((8+(-(1+2))))
// Deserialisation for the extended language, written without touching
// fromTreeExt: this handles the new "Multiplication" node and delegates every
// other tree to fromTreeExt. The same `recur` is passed along, so nested
// multiplications inside additions/negations are handled too.
def fromTreeExt2[A: ExprSymantics: MultSYM](
    recur: => (Tree => Either[ErrMsg, A])
): Tree => Either[ErrMsg, A] = {
  val M = implicitly[MultSYM[A]]
  // Built once per call rather than once per visited node. `recur` is passed
  // on by name, so it is not forced here.
  val base: Tree => Either[ErrMsg, A] = fromTreeExt[A](recur)
  val extended: Tree => Either[ErrMsg, A] = {
    case Node("Multiplication", List(leftSubTree, rightSubTree)) =>
      for {
        lt <- recur(leftSubTree)
        rt <- recur(rightSubTree)
      } yield M.mul(lt, rt)
    case t => base(t)
  }
  extended
}
// Full deserialiser for the extended language: fromTreeExt2 with its knot tied by fix.
def fromTree3[A: ExprSymantics: MultSYM](t: Tree): Either[ErrMsg, A] = {
  val parse: Tree => Either[ErrMsg, A] = fix(fromTreeExt2[A] _)
  parse(t)
}
// Serialiser for the multiplication fragment, written as a lambda (SAM conversion).
implicit val multSymTree: MultSYM[Tree] =
  (a: Tree, b: Tree) => Node("Multiplication", a :: b :: Nil)
// A program using both language fragments: 10 * (2 + 3).
def richProgram[A: ExprSymantics : MultSYM]: A = {
  val base = implicitly[ExprSymantics[A]]
  val times = implicitly[MultSYM[A]]
  val ten = base.lit(10)
  val twoPlusThree = base.add(base.lit(2), base.lit(3))
  times.mul(ten, twoPlusThree)
}
// Serialise it, then deserialise with the extended deserialiser into two interpreters
val treeWithMult = richProgram[Tree]
fromTree3[String](treeWithMult) // Right((10*(2+3)))
fromTree3[Int](treeWithMult)    // Right(50)
// (End of the gist page. Commenting on this gist on GitHub requires signing in.)