Created
January 8, 2015 04:57
-
-
Save pierzchalski/c2ae38718b53cdfaf33e to your computer and use it in GitHub Desktop.
A quick and dirty demonstration of using implicits within the new Shapeless TypeClass framework. Here, we use 'coproduct.Length' to ensure uniform probabilities over coproduct types.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import shapeless._ | |
import shapeless.ops.coproduct.Length | |
import shapeless.ops.nat.ToInt | |
import scala.util.Random | |
import scalaz.Reader | |
/** A small recursive ADT used as the target type for random generation.
  *
  * `A` is the only leaf; `B` and `C` each contain exactly one further `T`.
  */
sealed trait T

/** Leaf node carrying a string payload. */
final case class A(s: String) extends T

/** Node carrying an integer together with one child. */
final case class B(n: Int, t: T) extends T

/** Node carrying only one child. */
final case class C(t: T) extends T
/** Demo program: samples random `T` values through the implicit `Dist[T]` and prints statistics. */
object DistDemo {

  /** Draws a fixed number of samples, then prints the share of each constructor and the mean size.
    *
    * @param args command-line arguments (not used)
    */
  def main(args: Array[String]): Unit = {
    import Dist._

    val rng = new Random()
    val sampleCount = 1000

    // Generators for the primitive field types that occur inside A and B.
    implicit val dInt: Dist[Int] = Dist(r => r.nextInt(100))
    implicit val dString: Dist[String] = Dist(r => r.alphanumeric.take(5).mkString)

    val samples = List.fill(sampleCount)(generator[T].run(rng))
    val total = samples.length

    // Counts the constructors along the chain from the root down to its `A` leaf.
    def nodes(tree: T): Int = tree match {
      case _: A        => 1
      case B(_, child) => 1 + nodes(child)
      case C(child)    => 1 + nodes(child)
    }

    // Fraction of the samples whose root satisfies the predicate.
    def share(matches: T => Boolean): Double =
      samples.count(matches).toDouble / total

    // Each of the three constructors should show up at the root about a third of the time.
    println("part As: " + share { case _: A => true; case _ => false })
    println(" Bs: " + share { case _: B => true; case _ => false })
    println(" Cs: " + share { case _: C => true; case _ => false })

    // This type happens to have a finite expected size; many interesting
    // recursive types (binary trees, for instance) do not.
    println("Average size: " + samples.map(nodes).sum.toDouble / total)
    println("(Expected size: 3)")
  }
}
/** A way of producing values of type `A` from a source of randomness.
  *
  * @tparam A the type of value produced
  */
trait Dist[A] {

  /** A computation that, when run against a `Random`, yields one `A`. */
  def generator: Reader[Random, A]
}
/** Constructor, summoner and shapeless-based derivation rules for [[Dist]]. */
object Dist {

  /** Wraps a plain sampling function as a `Dist`.
    *
    * @param gen function producing one `A` from the supplied `Random`
    */
  def apply[A](gen: Random => A): Dist[A] = new Dist[A] {
    override def generator: Reader[Random, A] = Reader(gen)
  }

  /** Looks up the implicit `Dist[A]` and returns its generator. */
  def generator[A](implicit da: Dist[A]): Reader[Random, A] = da.generator

  /** The empty product has exactly one value, so the `Random` is ignored. */
  implicit def deriveHNil: Dist[HNil] = Dist(_ => HNil)

  /** Generates a product by sampling its head and then its tail from the same `Random`. */
  implicit def deriveHCons[H, T <: HList](implicit
      dh: Lazy[Dist[H]],
      dt: Lazy[Dist[T]]): Dist[H :: T] =
    Dist { rand =>
      // Head is drawn before tail, matching left-to-right field order.
      val head = dh.value.generator.run(rand)
      val tail = dt.value.generator.run(rand)
      head :: tail
    }

  /** Terminal case of coproduct derivation.
    *
    * `deriveCCons` always selects the left alternative when the remainder has length zero,
    * so this generator is only present to satisfy implicit resolution and throws if run.
    */
  implicit def deriveCNil: Dist[CNil] =
    Dist(_ => throw new Exception("Should never generate CNil"))

  /** Generates a coproduct so that every alternative is equally likely.
    *
    * The left alternative is chosen with probability `1 / (1 + length of R)`; otherwise
    * generation is delegated to the remainder `R`. Applied recursively, each of the
    * alternatives therefore receives the same overall probability.
    *
    * The choice is made with `nextInt(alternatives) == 0` instead of comparing a random
    * double against a rounded floating-point threshold, so the probability is exact.
    *
    * @param len      type-level length of the remaining coproduct `R`
    * @param lenValue value-level view of that length
    */
  implicit def deriveCCons[L, R <: Coproduct, N <: Nat](implicit
      dl: Lazy[Dist[L]],
      dr: Lazy[Dist[R]],
      len: Length.Aux[R, N],
      lenValue: ToInt[N]): Dist[L :+: R] = {
    // Total number of alternatives in `L :+: R`; always >= 1, so a valid bound for nextInt.
    val alternatives = 1 + lenValue()
    Dist { rand =>
      if (rand.nextInt(alternatives) == 0) {
        Inl(dl.value.generator.run(rand))
      } else {
        Inr(dr.value.generator.run(rand))
      }
    }
  }

  /** Derives a `Dist[F]` by generating the generic representation `G` and converting it back. */
  implicit def deriveInstance[F, G](implicit
      gen: Generic.Aux[F, G],
      dg: Lazy[Dist[G]]): Dist[F] =
    Dist(rand => gen.from(dg.value.generator.run(rand)))
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment