Created
May 27, 2024 12:45
-
-
Save tschuchortdev/3f02c32b4a2ddd3dd2158060b7b3bd6b to your computer and use it in GitHub Desktop.
How to check exhaustivity of a match expression with Scala 3 macros
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** Macro implementation of an exhaustive match. It checks at compile time that the
  * pattern-matching lambda `expr` covers every case of `T`, emits a warning listing the cases it
  * does not cover, and expands to `expr(self)`.
  *
  * The expected cases come from `m`: `T` itself for a product mirror, and each of the
  * `MirroredElemTypes` for a sum mirror. A case of `expr` contributes coverage only when it has no
  * guard. An expected case counts as covered when it is a subtype of a type matched by some
  * unguarded case, or when `expected <:< matched` evidence can be summoned.
  *
  * @tparam T   the scrutinee type
  * @param self the value the lambda is applied to
  * @param expr a lambda whose body is a single top-level `match` on its parameter
  * @param m    the mirror of `T`; a sum mirror must expose `MirroredElemTypes` as a refinement
  * @return the expression `expr(self)`
  */
private def matchExhaustivelyImplImpl[T: Type](
  self: Expr[T],
  expr: Expr[T => Any],
  m: Expr[Mirror.Of[T]]
)(using q: Quotes): Expr[Any] =
  import q.reflect.{*, given}

  // The types a match on `T` must cover.
  val expectedCases = m match
    case '{ $m: Mirror.ProductOf[s] } => Seq(TypeRepr.of[T])
    case '{
      type elems <: Tuple;
      $m: Mirror.SumOf[s] { type MirroredElemTypes = `elems` }
    } =>
      // NOTE(review): tupleToTypeReprs is defined elsewhere; presumably it lists the element types of `elems`.
      tupleToTypeReprs[elems]
    case _ =>
      // Without this case, a mirror of any other shape aborted the expansion with a MatchError.
      report.errorAndAbort(
        "Expected a Mirror.ProductOf or a Mirror.SumOf with a known MirroredElemTypes",
        m
      )

  // Extract the cases of the lambda's top-level match. Only this tree shape is accepted: an inlined
  // block defining a one-parameter method whose body matches on that very parameter, wrapped in a
  // closure over that method and cast via `$asInstanceOf$`.
  val caseDefs = expr.asTerm match
    case Inlined(_,
          _,
          TypeApply(
            Select(
              Block(
                List(
                  DefDef(
                    lambdaName,
                    List(TermParamClause(List(ValDef(lambdaParamName, _, _)))),
                    _,
                    Some(Match(Ident(matchVarName), cases))
                  )
                ),
                Closure(Ident(closureName), _)
              ),
              "$asInstanceOf$"
            ),
            _
          ))
      if closureName == lambdaName && matchVarName == lambdaParamName =>
      cases
    case _ => report.errorAndAbort("Must be a lambda with top-level match expression", expr)

  // The types that a single case pattern matches.
  def computeMatchedType(caseDefPattern: Tree): Seq[TypeRepr] = caseDefPattern match
    // `case _ =>` matches every value.
    case Wildcard() => List(TypeRepr.of[Any])
    case Alternatives(patterns) => patterns.flatMap(computeMatchedType)
    case TypedOrTest(_, tpt) => List(tpt.tpe)
    // `case x =>`, `case x: C =>` and `case x @ p =>` match exactly what their inner pattern matches.
    case Bind(_, inner) => computeMatchedType(inner)
    case Unapply(fun, _, _) =>
      fun.tpe.widenTermRefByName match
        // A MethodType is a regular method taking term parameters, a PolyType is a method taking type parameters,
        // a TypeLambda is a method returning a type and not a value. Unapply's type should be a function with no
        // type parameters, with a single value parameter (the match scrutinee) and with an Option[?] return type
        // (no curried function), thus it should be a MethodType.
        case methodType: MethodType =>
          methodType.resType.asType match
            // Also matches Some[] and None in an easy way
            case '[Option[tpe]] =>
              TypeRepr.of[tpe] match
                // For an intersection where one side equals `methodType.param(0)`, keep only the other side.
                case AndType(left, right)
                  if methodType.paramTypes.nonEmpty && left =:= methodType.param(0) => List(right)
                case AndType(left, right)
                  if methodType.paramTypes.nonEmpty && right =:= methodType.param(0) => List(left)
                case extracted => List(extracted)
            case '[tpe] => List(TypeRepr.of[tpe])
        case tpe: TypeRepr => throw AssertionError(
          s"Expected type of Unapply function to be MethodType. Was: ${Printer.TypeReprStructure.show(tpe)}"
        )
    // Stable identifier patterns such as `case None =>` or `case Color.Red =>` match their own singleton type.
    case ref: Ref => List(ref.tpe)
    case literal: Literal => List(literal.tpe)
    case pattern =>
      report.errorAndAbort(
        "Expected pattern of CaseDef to be either Wildcard, Alternative, TypedOrTest, Bind, Unapply, " +
          s"stable identifier or Literal. Was: ${Printer.TreeStructure.show(pattern)}",
        pattern.pos
      )

  // A guard may reject values of its pattern's type, so guarded cases never contribute coverage.
  val caseDefTypes = caseDefs.filter(_.guard.isEmpty).flatMap(caseDef => computeMatchedType(caseDef.pattern))

  // An expected case is covered if it is a subtype of some matched type, or if `<:<` evidence exists.
  def isCovered(expected: TypeRepr): Boolean = expected.asType match
    case '[expectedCase] =>
      caseDefTypes.exists { covered =>
        covered.asType match
          case '[caseDefType] =>
            (TypeRepr.of[expectedCase] <:< TypeRepr.of[caseDefType])
              || Expr.summon[expectedCase <:< caseDefType].isDefined
      }

  val uncoveredCases = expectedCases.filterNot(isCovered)
  if uncoveredCases.nonEmpty then
    val casesString = uncoveredCases
      .map(uncovered => "_: " + Printer.TypeReprCode.show(uncovered))
      .mkString(", ")
    val lambdaPos = expr.asTerm.pos
    report.warning(
      s"Match may not be exhaustive.\n\nIt would fail on case: $casesString",
      // A short span around the start of the lambda, clamped so it never starts before offset 0.
      Position(lambdaPos.sourceFile, start = math.max(0, lambdaPos.start - 1), end = lambdaPos.start + 1)
    )

  '{ $expr($self) }
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment