Skip to content

Instantly share code, notes, and snippets.

@zesterer
Created August 21, 2024 14:41
Show Gist options
  • Save zesterer/de62cbe06a47efee779859159394529c to your computer and use it in GitHub Desktop.
Save zesterer/de62cbe06a47efee779859159394529c to your computer and use it in GitHub Desktop.
//! Type checking and inference in 100 lines of Rust
//! ----------------------------------
//! (if you don't count comments)
#![allow(dead_code)]
/// The ID of a type variable.
///
/// The wrapped `usize` is the index of the variable's entry in `Solver::vars`.
#[derive(Copy, Clone, Debug, PartialEq)]
struct TyVar(usize);
/// Possibly-incomplete information known about a type variable's type.
///
/// One `TyInfo` is stored per type variable in `Solver::vars`.
#[derive(Copy, Clone, Debug)]
enum TyInfo {
/// No information is known about the type.
Unknown,
/// The type is equal to the type of another type variable.
Ref(TyVar),
/// The type is an `Int`.
Int,
/// The type is a `Bool`.
Bool,
/// The type is a function, `A -> B`: the first variable is the argument type `A`, the second is
/// the return type `B`.
Func(TyVar, TyVar),
}
/// An expression in the AST of a programming language.
///
/// Variable and parameter names are stored as `&'a str` slices.
#[derive(Debug)]
enum Expr<'a> {
/// Integer literal.
Int(u64),
/// Boolean literal.
Bool(bool),
/// Variable, referenced by name.
Var(&'a str),
/// Let binding, `let lhs = rhs; then`: `rhs` is bound to the name `lhs` while evaluating `then`.
Let { lhs: &'a str, rhs: Box<Self>, then: Box<Self> },
/// Inline function/lambda/closure, `fn(arg) body`, taking exactly one parameter named `arg`.
Func { arg: &'a str, body: Box<Self> },
/// Function application/call, `func(arg)`, passing exactly one argument.
Apply { func: Box<Self>, arg: Box<Self> }
}
/// The final type of an expression.
///
/// Unlike `TyInfo`, a `Ty` is fully resolved: it contains no unknowns and no references to type
/// variables. It is what `Solver::solve` returns.
#[derive(Debug)]
enum Ty {
/// The expression has type `Int`.
Int,
/// The expression has type `Bool`.
Bool,
/// The expression is a function from type `A` (first field) to type `B` (second field).
Func(Box<Self>, Box<Self>),
}
/// Contains the state of the type solver.
///
/// Each entry of `vars` records what is currently known about one type variable; a `TyVar` is an
/// index into this vector.
#[derive(Default)]
struct Solver {
    vars: Vec<TyInfo>,
}

impl Solver {
    /// Create a new type variable in the type solver's environment, with the given information.
    fn create_ty(&mut self, info: TyInfo) -> TyVar {
        let var = TyVar(self.vars.len());
        self.vars.push(info);
        var
    }

    /// Follow `TyInfo::Ref` links starting at `var` and return the first variable whose
    /// information is not a reference (its representative).
    fn resolve(&self, mut var: TyVar) -> TyVar {
        while let TyInfo::Ref(next) = self.vars[var.0] {
            var = next;
        }
        var
    }

    /// Report whether the representative `needle` appears anywhere inside the type currently
    /// known for `var` (the "occurs check").
    fn occurs(&self, needle: TyVar, var: TyVar) -> bool {
        let var = self.resolve(var);
        if var == needle {
            return true;
        }
        match self.vars[var.0] {
            TyInfo::Func(i, o) => self.occurs(needle, i) || self.occurs(needle, o),
            TyInfo::Unknown | TyInfo::Ref(_) | TyInfo::Int | TyInfo::Bool => false,
        }
    }

    /// Make the unknown representative `unknown` refer to the representative `target`.
    ///
    /// # Panics
    ///
    /// Panics if `unknown` occurs inside `target`, because the binding would describe an
    /// infinitely large type.
    fn bind(&mut self, unknown: TyVar, target: TyVar) {
        assert!(
            !self.occurs(unknown, target),
            "Infinite type: {unknown:?} occurs within {:?}",
            self.vars[target.0],
        );
        self.vars[unknown.0] = TyInfo::Ref(target);
    }

    /// Unify two type variables together, forcing them to be equal.
    ///
    /// Both variables are first resolved to their representatives. Resolving up front (rather
    /// than binding whichever side happens to be `Unknown`) guarantees that a variable is never
    /// made to reference itself, directly or through a chain of `Ref`s.
    ///
    /// # Panics
    ///
    /// Panics with a "Type mismatch" message if the two types cannot be equal, and with an
    /// "Infinite type" message if unification would make a type contain itself.
    fn unify(&mut self, a: TyVar, b: TyVar) {
        let (a, b) = (self.resolve(a), self.resolve(b));
        // Already the same variable: nothing to do, and binding it would create a self-cycle.
        if a == b {
            return;
        }
        match (self.vars[a.0], self.vars[b.0]) {
            (TyInfo::Unknown, _) => self.bind(a, b),
            (_, TyInfo::Unknown) => self.bind(b, a),
            (TyInfo::Int, TyInfo::Int) | (TyInfo::Bool, TyInfo::Bool) => {}
            (TyInfo::Func(a_i, a_o), TyInfo::Func(b_i, b_o)) => {
                self.unify(a_i, b_i);
                self.unify(a_o, b_o);
            }
            (a, b) => panic!("Type mismatch between {a:?} and {b:?}"),
        }
    }

    /// Type-check an expression, returning a type variable representing its type, with the given
    /// environment/scope.
    ///
    /// `env` is used as a stack of `(name, type variable)` bindings; every binding pushed while
    /// checking `expr` is popped again before this function returns.
    ///
    /// # Panics
    ///
    /// Panics if a variable is not found in `env`, or if `unify` panics.
    fn check<'ast>(&mut self, expr: &Expr<'ast>, env: &mut Vec<(&'ast str, TyVar)>) -> TyVar {
        match expr {
            // Literal expressions are easy, their type doesn't need inferring.
            Expr::Int(_) => self.create_ty(TyInfo::Int),
            Expr::Bool(_) => self.create_ty(TyInfo::Bool),
            // Search the environment backward so that the innermost binding of a name wins.
            Expr::Var(name) => {
                env.iter()
                    .rev()
                    .find(|(n, _)| n == name)
                    .expect("No such variable in scope")
                    .1
            }
            // In a let expression, `lhs` is bound to the type variable of `rhs` in the
            // environment used to type-check `then`.
            Expr::Let { lhs, rhs, then } => {
                let rhs_ty = self.check(rhs, env);
                env.push((*lhs, rhs_ty));
                let out = self.check(then, env);
                env.pop();
                out
            }
            // In a function, the argument becomes an unknown type in the environment used to
            // type-check `body`.
            Expr::Func { arg, body } => {
                let arg_ty = self.create_ty(TyInfo::Unknown);
                env.push((*arg, arg_ty));
                let body_ty = self.check(body, env);
                env.pop();
                self.create_ty(TyInfo::Func(arg_ty, body_ty))
            }
            // During function application, both function and argument are type-checked, then
            // the function is forced to have type `arg -> out` for a fresh unknown `out`.
            Expr::Apply { func, arg } => {
                let func_ty = self.check(func, env);
                let arg_ty = self.check(arg, env);
                let out = self.create_ty(TyInfo::Unknown);
                let expected = self.create_ty(TyInfo::Func(arg_ty, out));
                self.unify(expected, func_ty);
                out
            }
        }
    }

    /// Convert a type variable into a final type once type-checking has finished.
    ///
    /// # Panics
    ///
    /// Panics with "Cannot infer type" if the variable, or any part of a function type it
    /// resolves to, is still unknown.
    pub fn solve(&self, var: TyVar) -> Ty {
        match self.vars[var.0] {
            TyInfo::Unknown => panic!("Cannot infer type"),
            TyInfo::Ref(var) => self.solve(var),
            TyInfo::Int => Ty::Int,
            TyInfo::Bool => Ty::Bool,
            TyInfo::Func(i, o) => Ty::Func(Box::new(self.solve(i)), Box::new(self.solve(o))),
        }
    }
}
/// Consume the next token if it equals `expected`.
///
/// On success `tokens` is advanced past the matched token. On failure `tokens` is left untouched
/// and the error names the expected token and what was found instead (a token or end of input).
fn expect(tokens: &mut &[Token], expected: Token) -> Result<(), String> {
    match tokens.split_first() {
        None => Err(format!("Expected {expected:?}, found end of input")),
        Some((tok, tail)) if *tok == expected => {
            *tokens = tail;
            Ok(())
        }
        Some((tok, _)) => Err(format!("Expected {expected:?}, found {tok:?}")),
    }
}
/// Parse one or more comma-separated items, using `f` to parse each item.
///
/// Parsing stops successfully when the input is exhausted or when only a single trailing comma
/// remains (that comma is not consumed). At least one item is always required, because `f` is
/// called before anything else is inspected. Any token other than a comma following an item is an
/// error, as is any error returned by `f`.
fn parse_list<'a, R>(
    tokens: &mut &'a [Token],
    mut f: impl FnMut(&mut &'a [Token]) -> Result<R, String>,
) -> Result<Vec<R>, String> {
    // The first item is mandatory; every further item must be introduced by a comma.
    let mut items = vec![f(tokens)?];
    loop {
        match *tokens {
            [] | [Token::Comma] => return Ok(items),
            [Token::Comma, rest @ ..] => {
                *tokens = rest;
                items.push(f(tokens)?);
            }
            [tok, ..] => return Err(format!("Expected argument, found {tok:?}")),
        }
    }
}
/// Parse a single identifier token and return its name.
///
/// On success `tokens` is advanced past the identifier. If the next token is not an identifier,
/// or the input is empty, `tokens` is left untouched and an error is returned.
fn parse_ident<'a>(tokens: &mut &'a [Token]) -> Result<&'a str, String> {
    let remaining = *tokens;
    match remaining.split_first() {
        Some((Token::Ident(ident), tail)) => {
            *tokens = tail;
            Ok(ident)
        }
        Some((tok, _)) => Err(format!("Expected ident, found {tok:?}")),
        None => Err("Expected ident, found end of input".to_string()),
    }
}
/// Parse one expression from the front of `tokens`, advancing `tokens` past what was consumed.
///
/// A leading expression is parsed first: a variable, an integer literal, a
/// `let lhs = rhs; then` binding, or a `fn(params) body` function. It may then be followed by any
/// number of parenthesised argument lists, each of which applies the expression parsed so far.
///
/// Multiple parameters and multiple arguments are curried: `fn(a, b) body` becomes a function of
/// `a` returning a function of `b`, and `f(x, y)` becomes `f` applied to `x`, then to `y`.
fn parse_expr<'a>(tokens: &mut &'a [Token]) -> Result<Expr<'a>, String> {
    let mut expr = match *tokens {
        [] => return Err("Expected expression, found end of input".to_string()),
        [Token::Int(x), rest @ ..] => {
            *tokens = rest;
            Expr::Int(*x)
        }
        [Token::Ident(name), rest @ ..] => {
            *tokens = rest;
            Expr::Var(name)
        }
        [Token::Let, Token::Ident(lhs), rest @ ..] => {
            *tokens = rest;
            expect(tokens, Token::Eq)?;
            let rhs = parse_expr(tokens)?;
            expect(tokens, Token::Semicolon)?;
            let then = parse_expr(tokens)?;
            Expr::Let { lhs, rhs: Box::new(rhs), then: Box::new(then) }
        }
        [Token::Fn, Token::Parens(params), rest @ ..] => {
            *tokens = rest;
            let params = parse_list(&mut &params[..], parse_ident)?;
            // Wrap the body from the last parameter outward so the first parameter ends up as
            // the outermost function.
            let mut func = parse_expr(tokens)?;
            for arg in params.into_iter().rev() {
                func = Expr::Func { arg, body: Box::new(func) };
            }
            func
        }
        [tok, ..] => return Err(format!("Expected expression, found {tok:?}")),
    };

    // Each following parenthesised group is a call on whatever has been parsed so far.
    while let [Token::Parens(args), rest @ ..] = *tokens {
        *tokens = rest;
        for arg in parse_list(&mut &args[..], parse_expr)? {
            expr = Expr::Apply { func: Box::new(expr), arg: Box::new(arg) };
        }
    }
    Ok(expr)
}
/// A lexical token. Identifier names borrow from the source text with lifetime `'a`.
#[derive(Debug, PartialEq)]
enum Token<'a> {
/// Integer literal.
Int(u64),
/// Identifier: any alphanumeric word that is not a keyword.
Ident(&'a str),
/// The `let` keyword.
Let,
/// The `=` symbol.
Eq,
/// The `fn` keyword.
Fn,
/// The `;` symbol.
Semicolon,
/// Comma separating the items of a list.
Comma,
/// The tokens found between a `(` and its closing `)`, as a nested token tree.
Parens(Vec<Self>),
}
/// Split off and return the longest prefix of `s` whose characters all satisfy `f`.
///
/// `s` is advanced to the remainder, which starts at the first character rejected by `f` or is
/// empty if every character was accepted.
fn take<'a>(s: &mut &'a str, mut f: impl FnMut(&char) -> bool) -> &'a str {
    let whole: &'a str = *s;
    // Byte offset of the first rejected character, or the full length if none is rejected.
    let split = whole
        .char_indices()
        .find(|(_, c)| !f(c))
        .map_or(whole.len(), |(idx, _)| idx);
    let (taken, rest) = whole.split_at(split);
    *s = rest;
    taken
}
/// Advance `s` past its first character, if it has one. An empty string is left unchanged.
fn skip(s: &mut &str) {
    let whole = *s;
    // Width in bytes of the first character, or 0 when the string is empty.
    let first_len = whole.chars().next().map_or(0, char::len_utf8);
    *s = &whole[first_len..];
}
/// Split source text into a tree of tokens, advancing `src` past everything consumed.
///
/// Parenthesised groups become nested `Token::Parens` trees. Whitespace is discarded.
///
/// # Errors
///
/// Returns an error for a character that starts no token, an integer literal that does not fit
/// in a `u64`, a `)` with no matching `(`, or a `(` that is never closed.
fn lex<'a>(src: &mut &'a str) -> Result<Vec<Token<'a>>, String> {
    lex_group(src, false)
}

/// Lex tokens until the end of the current group.
///
/// When `nested` is true the group was opened by a `(` and must be terminated by a `)`, which is
/// consumed. When `nested` is false the group is the whole input and must run to its end.
fn lex_group<'a>(src: &mut &'a str, nested: bool) -> Result<Vec<Token<'a>>, String> {
    let mut tokens = Vec::new();
    loop {
        let token = match src.chars().next() {
            Some(c) if c.is_ascii_digit() => {
                let digits = take(src, char::is_ascii_digit);
                // `digits` is a non-empty run of ASCII digits, so parsing can only fail when
                // the value is too large for a `u64`.
                let value = digits
                    .parse::<u64>()
                    .map_err(|err| format!("Invalid integer literal {digits}: {err}"))?;
                Token::Int(value)
            }
            Some(c) if c.is_ascii_alphabetic() => {
                match take(src, char::is_ascii_alphanumeric) {
                    "let" => Token::Let,
                    "fn" => Token::Fn,
                    ident => Token::Ident(ident),
                }
            }
            Some('=') => {
                skip(src);
                Token::Eq
            }
            Some(';') => {
                skip(src);
                Token::Semicolon
            }
            Some(',') => {
                skip(src);
                Token::Comma
            }
            Some('(') => {
                skip(src);
                Token::Parens(lex_group(src, true)?)
            }
            Some(c) if c.is_whitespace() => {
                skip(src);
                continue;
            }
            Some(')') if nested => {
                skip(src);
                return Ok(tokens);
            }
            Some(')') => return Err("Unexpected `)` without a matching `(`".to_string()),
            None if nested => return Err("Expected `)`, found end of input".to_string()),
            None => return Ok(tokens),
            Some(c) => return Err(format!("Unexpected character {c:?}")),
        };
        tokens.push(token);
    }
}
/// Lex, parse and type-check a small example program, printing its AST and its inferred type.
fn main() {
    let mut source = "let f = fn(x) x; f(42)";
    let tokens = lex(&mut source).unwrap();

    let mut remaining = tokens.as_slice();
    let expr = parse_expr(&mut remaining).unwrap();
    println!("{expr:?}");

    let mut solver = Solver::default();
    let mut env = Vec::new();
    let program_ty = solver.check(&expr, &mut env);
    let ty = solver.solve(program_ty);
    println!("The expression outputs type `{:?}`", ty);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment