Created
October 19, 2022 19:14
-
-
Save Heimdell/bf6fa1eae56c0b326f53d12313bfa332 to your computer and use it in GitHub Desktop.
Unifier
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE FunctionalDependencies #-} | |
{-# LANGUAGE LambdaCase #-} | |
{-# LANGUAGE ImportQualifiedPost #-} | |
{-# LANGUAGE BlockArguments #-} | |
{-# LANGUAGE FlexibleContexts #-} | |
{-# LANGUAGE ScopedTypeVariables #-} | |
{-# LANGUAGE QuantifiedConstraints #-} | |
{-# LANGUAGE FlexibleInstances #-} | |
{-# LANGUAGE UndecidableInstances #-} | |
module Bind where | |
import Control.Monad (when) | |
import Control.Monad.Except (MonadError(throwError)) | |
import Control.Monad.State | |
import Data.Foldable (for_) | |
import Data.Map qualified as Map | |
import Data.Set qualified as Set | |
import Data.String (IsString(fromString)) | |
import Var | |
import Unify | |
import Free | |
-- | Failures reported by the unifier.
--
--   @n@ is the variable-name type, @t@ the term layer and @e@ an extra,
--   caller-defined error type carried by 'Other'.
data TypeError n t e
  = -- | Two terms whose outer layers could not be coalesced.
    Mismatch (Term t n) (Term t n)
  | -- | A variable that would occur inside the term it is bound to.
    Cycle (Var n) (Term t n)
  | -- | Any other, caller-defined error.
    Other e
-- | Human-readable rendering; every message starts on a fresh line.
instance (Show e, forall x. Show x => Show (t x), Show n) => Show (TypeError n t e) where
  show (Mismatch lhs rhs) =
    "\nmismatch" <> " " <> show lhs <> " =/=" <> " " <> show rhs
  show (Cycle v term) =
    "\ncycle" <> " " <> show v <> " ~" <> " " <> show term
  show (Other err) =
    "\n" <> show err
-- | Monads that store bindings for unification variables.
--
--   Besides allocating variables ('HasVars') and throwing 'TypeError's, an
--   instance can bind a variable to a term and read that binding back.
--   Functional dependencies: @n@ and @m@ fix @t@; @t@ and @m@ fix @n@;
--   @n@ and @t@ fix @e@.
class
  ( HasVars n m
  , MonadError (TypeError n t e) m
  , Unifiable t
  ) =>
  Binds n t e m
    | n m -> t
    , t m -> n
    , n t -> e
  where
  -- | Record a binding for a variable.
  --   NOTE(review): callers rebind already-bound variables, so instances
  --   are presumably expected to overwrite any previous binding — confirm.
  (=:) :: Var n -> Term t n -> m ()

  -- | The current binding of a variable, or 'Nothing' if it is unbound.
  see :: Var n -> m (Maybe (Term t n))
-- | Lift 'Binds' through any monad transformer: binding and lookup are
--   delegated to the underlying monad @m@.
instance {-# OVERLAPPABLE #-}
  ( Binds n t e m
  , MonadTrans mt
  , MonadError (TypeError n t e) (mt m)
  ) =>
  Binds n t e (mt m)
  where
    v =: term = lift (v =: term)
    see v = lift (see v)
-- | Resolve a term to its current representative by chasing variable
--   bindings one variable at a time.
--
--   * A 'Free' node is returned unchanged (children are not resolved; use
--     'update' for a deep resolution).
--   * A variable bound to a 'Free' node resolves to that node.
--   * A variable bound to another variable is followed, and every variable
--     on the chain is rebound directly to the final result (path
--     compression).
--   * An unbound variable resolves to itself.
--
--   So a 'Pure' result always denotes an unbound variable. '(=:=)' relies
--   on this: it binds a 'Pure' result directly, which would overwrite an
--   existing binding if a bound variable were returned here.
--
--   A chain of variables that loops back on itself violates the unifier's
--   invariants and aborts with 'error'.
prune :: (Binds n t e m) => Term t n -> m (Term t n)
prune = \case
  term@Free {} -> return term
  Pure v -> chase [v] v
  where
    -- Invariant: @visited@ already contains @v@.
    chase visited v =
      see v >>= \case
        Just (Pure next) -> do
          when (next `elem` visited) do
            error $ "vars form a cycle: " <> show visited
          end <- chase (next : visited) next
          v =: end
          return end
        Just term@Free {} -> return term
        Nothing -> return (Pure v)
-- | Unify two terms, recording variable bindings through 'Binds'.
--
--   Both sides are first resolved with 'prune'.
--
--   * Two distinct variables are both bound to a freshly allocated variable.
--   * A variable and a structured term: the variable is bound to the term,
--     unless it occurs in the term (see 'assign').
--   * Two structured terms are aligned with 'coalesce'. Failure throws
--     'Mismatch'. Each 'Right' pair of children is unified recursively;
--     'Left' entries are ignored.
(=:=) :: (Binds n t e m) => Term t n -> Term t n -> m ()
l0 =:= r0 = do
  l <- prune l0
  r <- prune r0
  case (l, r) of
    (Pure a, Pure b)
      -- Unifying a variable with itself is a no-op.
      | a == b -> return ()
      | otherwise -> do
          t <- var (fromString "t")
          a =: Pure t
          b =: Pure t
    (Pure a, b) -> assign a b
    (a, Pure b) -> assign b a
    (Free _ t, Free _ u) ->
      case coalesce t u of
        Nothing -> throwError $ Mismatch l r
        Just t' ->
          for_ t' \case
            Left {} -> return ()
            Right (l', r') -> l' =:= r'
  where
    -- Bind @v@ to @term@ after an occurs check. The check runs on the fully
    -- resolved term: the cached variable set of @term@ alone misses
    -- occurrences hidden behind existing bindings (e.g. @a := List b@
    -- followed by @b =:= List a@), which would otherwise create a cyclic
    -- binding.
    assign v term = do
      resolved <- update term
      if occurs v resolved
        then throwError $ Cycle v resolved
        else v =: term
-- | Fully resolve a term.
--
--   Every bound variable, at any depth, is replaced by the (recursively
--   resolved) term it is bound to. Unbound variables are kept. Structured
--   nodes are rebuilt with 'wrap', so their cached variable sets describe
--   the resolved children.
update :: (Binds n t e m) => Term t n -> m (Term t n)
update (Free _ layer) = wrap <$> traverse update layer
update (Pure v) = see v >>= maybe (return (Pure v)) update
-- | Replace every variable of a term with a freshly allocated one (see
--   'fresh').
--
--   Repeated occurrences of the same variable map to the same replacement,
--   because a table from old to new variables is threaded through the
--   traversal.
refreshVarNames :: forall n t e m. (Binds n t e m) => Term t n -> m (Term t n)
refreshVarNames term = evalStateT (traverseFree renameVar term) Map.empty
  where
    renameVar :: Var n -> StateT (Map.Map (Var n) (Var n)) m (Var n)
    renameVar old = do
      table <- get
      case Map.lookup old table of
        Just new -> return new
        Nothing -> do
          new <- fresh old
          put (Map.insert old new table)
          return new
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE LambdaCase #-} | |
{-# LANGUAGE DerivingStrategies #-} | |
{-# LANGUAGE BlockArguments #-} | |
{-# LANGUAGE RankNTypes #-} | |
{-# LANGUAGE QuantifiedConstraints #-} | |
{-# LANGUAGE UndecidableInstances #-} | |
module Free where | |
import Data.Set (Set) | |
import Var (Var) | |
import qualified Data.Set as Set | |
import Data.Traversable (for) | |
-- | A term: either a variable ('Pure') or one layer of structure @t@ whose
--   children are again terms ('Free').
--
--   A 'Free' node also carries a set of variables. 'wrap' fills it with the
--   variables of the node's children, and 'allVars' / 'occurs' read it
--   instead of walking the tree.
data Free t v
  = Free (Set v) (t (Free t v))
  | Pure v

-- | Terms whose variables are 'Var's carrying names of type @n@.
type Term t n = Free t (Var n)
-- | Build a 'Free' node from a layer, caching the union of the variable
--   sets of all of its children.
wrap :: (Ord v, Foldable t) => t (Free t v) -> Free t v
wrap layer = Free vars layer
  where
    vars = foldr (Set.union . allVars) Set.empty layer
-- | Variables of a term: the singleton for a variable, the cached set for a
--   structured node.
allVars :: (Ord v) => Free t v -> Set v
allVars (Pure v) = Set.singleton v
allVars (Free vars _) = vars
-- | Does a variable appear in a term?
--
--   This is a syntactic check on 'allVars' only: variable bindings are not
--   followed, so resolve the term first when bindings matter.
occurs :: (Ord v) => v -> Free t v -> Bool
occurs v = Set.member v . allVars
-- | Map an effectful function over every variable of a term.
--
--   @f@ is applied to the leaf variables only, in the layer's 'traverse'
--   order. Each rebuilt 'Free' node gets its variable set recomputed by
--   'wrap' from its new children.
--
--   Recomputing the set, instead of mapping @f@ over the old cached set as
--   well, keeps the set consistent with the leaves even when @f@ is not
--   idempotent (e.g. allocates a new variable per call). 'occurs' depends
--   on that consistency. The @Ord n@ constraint is kept for compatibility.
traverseFree :: (Monad f, Traversable t, Ord n, Ord m) => (n -> f m) -> Free t n -> f (Free t m)
traverseFree f = \case
  Free _ layer -> wrap <$> traverse (traverseFree f) layer
  Pure v -> Pure <$> f v
-- | Show a term by delegating to its layer or its variable; the cached
--   variable set of a 'Free' node is not shown.
--
--   'showsPrec' is defined (not just 'show') so the precedence context is
--   passed through. A derived 'Show' instance for the layer type then
--   parenthesises nested subterms instead of printing them flat. Top-level
--   output ('show') is unchanged.
instance (forall x. Show x => Show (t x), Show v) => Show (Free t v) where
  showsPrec d = \case
    Free _ layer -> showsPrec d layer
    Pure v -> showsPrec d v
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE RankNTypes #-} | |
{-# LANGUAGE QuantifiedConstraints #-} | |
module Scheme where | |
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Set (Set)
import qualified Data.Set as Set

import Bind
import Free
import Var
-- | A type scheme: a term together with the variables quantified over it.
data Scheme t v = Scheme
  { schemeTypeArgs :: Set (Var v)
  -- ^ The quantified variables.
  , schemeTypeBody :: Term t v
  -- ^ The term they are quantified over.
  }
-- | Turn a term into a scheme.
--
--   The term is fully resolved with 'update', and every variable remaining
--   in the result is quantified.
--
--   NOTE(review): no typing environment is consulted, so variables still
--   shared with an enclosing scope are quantified too. Confirm callers only
--   generalise terms where that is intended.
generalise :: (Binds n t e m) => Term t n -> m (Scheme t n)
generalise t = fmap (\body -> Scheme (allVars body) body) (update t)
-- | Make a fresh copy of a scheme's body.
--
--   Each quantified variable (an element of 'schemeTypeArgs') is replaced by
--   a newly allocated variable of the same name (see 'fresh'). Variables not
--   listed in 'schemeTypeArgs' are left untouched, so a scheme that keeps
--   some variables free instantiates correctly. For schemes produced by
--   'generalise', which quantifies every variable of the body, this renames
--   the same variables as refreshing the whole body.
instantiate :: (Binds n t e m) => Scheme t n -> m (Term t n)
instantiate scheme = do
  renaming <-
    Map.fromList
      <$> traverse
        (\v -> (,) v <$> fresh v)
        (Set.toList (schemeTypeArgs scheme))
  traverseFree
    (\v -> return (Map.findWithDefault v v renaming))
    (schemeTypeBody scheme)
-- | Render as @forall a b. body@, listing the quantified variables in
--   ascending order.
instance (Show v, forall x. Show x => Show (t x)) => Show (Scheme t v) where
  show (Scheme args body) =
    concat ["forall ", unwords [show a | a <- Set.toList args], ". ", show body]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module Unify where | |
-- | Term layers that can be aligned with each other one level at a time.
class Traversable t => Unifiable t where
  -- | Align two layers.
  --
  --   'Nothing' means the layers cannot be matched. Otherwise the result is
  --   a layer of 'Left' entries and 'Right' pairs. The unifier in "Bind"
  --   ignores 'Left' entries and unifies the two terms of each 'Right' pair.
  --   NOTE(review): 'Left' presumably marks positions needing no
  --   unification — confirm against instances.
  coalesce :: t a -> t a -> Maybe (t (Either a (a, a)))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE MultiParamTypeClasses #-} | |
{-# LANGUAGE DerivingStrategies #-} | |
{-# LANGUAGE FlexibleInstances #-} | |
module Var where | |
import Control.Monad.Trans.Class (MonadTrans(..)) | |
import Data.String (IsString) | |
-- | A variable: a numeric index paired with a name.
--
--   The derived 'Ord' compares 'varIndex' first, then 'varName'.
data Var n = Var
  { varIndex :: Int
  -- ^ Numeric index; index 0 is printed without a suffix.
  , varName :: n
  -- ^ The variable's name.
  }
  deriving stock (Eq, Ord)
-- | A variable with index 0 prints as just its (shown) name. Any other
--   index is appended after a @#@, e.g. @"a"#3@ for a 'String' name.
instance Show n => Show (Var n) where
  show v
    | varIndex v == 0 = show (varName v)
    | otherwise = show (varName v) <> "#" <> show (varIndex v)
-- | Monads able to create variables whose names have type @n@.
class (Show n, Ord n, IsString n, Monad m) => HasVars n m where
  -- | Create a variable with the given name.
  --   NOTE(review): instances are presumably expected to choose a new index
  --   on every call — confirm.
  var :: n -> m (Var n)
-- | Create a new variable with the same name as an existing one; the old
--   variable's index is discarded.
fresh :: HasVars n m => Var n -> m (Var n)
fresh v = case v of
  Var _ name -> var name
-- | Lift variable creation through any monad transformer by delegating to
--   the underlying monad @m@.
instance {-# OVERLAPPABLE #-}
  (HasVars n m, MonadTrans t, Monad (t m)) =>
  HasVars n (t m)
  where
    var name = lift (var name)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment