Created
July 16, 2022 21:15
-
-
Save Heimdell/1fbad9c325253928c9ca8c6342847083 to your computer and use it in GitHub Desktop.
Unification ex-di
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE LambdaCase #-} | |
{-# LANGUAGE BlockArguments #-} | |
{-# LANGUAGE DerivingStrategies #-} | |
{-# LANGUAGE StandaloneDeriving #-} | |
{-# LANGUAGE UndecidableInstances #-} | |
{-# LANGUAGE DeriveDataTypeable #-} | |
{-# LANGUAGE DeriveGeneric #-} | |
{-# LANGUAGE DeriveFunctor #-} | |
{-# LANGUAGE DeriveFoldable #-} | |
{-# LANGUAGE DeriveTraversable #-} | |
module Fixpoint where | |
import Control.Monad ((<=<)) | |
import Data.Data (Data, Typeable) | |
import GHC.Generics (Generic) | |
import Data.Foldable (fold) | |
-- | Fixpoint of a functor @f@: a recursive structure in which every layer is
-- an @f@ whose recursive positions hold further 'Fix' layers.
--
-- There is no extra annotation or context stored at the nodes; a value is
-- nothing but nested @f@ layers (compare 'Term', which may also hold leaves).
newtype Fix f = Fix
  { unFix :: f (Fix f)
    -- ^ Peel off the outermost layer.
  }
  deriving stock (Generic)

-- The instances below need the instance for the unrolled layer @f (Fix f)@,
-- hence the flexible (and undecidable) contexts.
deriving stock instance (Show (f (Fix f))) => Show (Fix f)
deriving stock instance (Eq (f (Fix f))) => Eq (Fix f)
deriving stock instance (Ord (f (Fix f))) => Ord (Fix f)
deriving stock instance (Data (f (Fix f)), Typeable f) => Data (Fix f)
-- | Eliminator for 'Fix': a bottom-up fold.
--
-- Every child layer is collapsed first (with the same algebra), and the
-- resulting layer of results is handed to @alg@.
cataFix :: (Functor f) => (f a -> a) -> Fix f -> a
cataFix alg = go
  where
    go (Fix layer) = alg (fmap go layer)
-- | Monadic eliminator for 'Fix'.
--
-- Children are folded first via 'traverse' (so their effects run in the
-- order given by the layer's 'Traversable' instance), then @alg@ runs on the
-- layer of results.
cataFixM :: (Traversable f, Monad m) => (f a -> m a) -> Fix f -> m a
cataFixM alg = go
  where
    go (Fix layer) = traverse go layer >>= alg
-- | A variant of 'Fix' in which some positions hold a leaf of type @a@
-- instead of a further @f@ layer.
--
-- Structurally this is the free monad over @f@: 'pure' is 'Leaf', and '>>='
-- replaces every leaf with the term produced for it (substitution).
data Term f a
  = Node (f (Term f a))
    -- ^ An inner node: one @f@ layer whose children are subterms.
  | Leaf a
    -- ^ A leaf carrying an @a@.
  deriving stock (Generic, Functor, Foldable, Traversable)

-- | 'pure' builds a leaf. '<*>' applies the function found at each leaf of
-- the left term, grafting (a mapped copy of) the right term in its place.
instance (Functor f) => Applicative (Term f) where
  pure = Leaf
  Leaf g <*> t = fmap g t
  Node layer <*> t = Node (fmap (<*> t) layer)

-- | Substitution: every 'Leaf' is replaced by the term the continuation
-- builds for its payload; 'Node' layers are kept as they are.
instance (Functor f) => Monad (Term f) where
  Leaf a >>= k = k a
  Node layer >>= k = Node (fmap (>>= k) layer)

-- | A 'String' whose 'Show' instance prints it verbatim (no quotes). Used to
-- splice already-rendered subterms into the 'show' output of a parent layer.
newtype Unshow = Unshow { unShow :: String }

instance Show Unshow where
  show = unShow

-- | A leaf is shown as its payload; a node is shown as its @f@ layer with
-- every child pre-rendered at application precedence (11), so compound
-- children come out parenthesised: @Just (Just 1)@, not @Just Just 1@.
instance (Show a, Show (f Unshow), Functor f) => Show (Term f a) where
  showsPrec d = \case
    Leaf a -> showsPrec d a
    Node layer -> showsPrec d (fmap (\child -> Unshow (showsPrec 11 child "")) layer)

-- The instances below need the instance for the unrolled layer
-- @f (Term f a)@, hence the flexible (and undecidable) contexts.
deriving stock instance (Eq a, Eq (f (Term f a))) => Eq (Term f a)
deriving stock instance (Ord a, Ord (f (Term f a))) => Ord (Term f a)
deriving stock instance (Data a, Data (f (Term f a)), Typeable f) => Data (Term f a)
-- | Fold a 'Term' bottom-up.
--
-- @node@ collapses a layer whose children have already been folded, and
-- @leaf@ handles each 'Leaf'.
cataTerm :: (Functor f) => (f b -> b) -> (a -> b) -> Term f a -> b
cataTerm node leaf = go
  where
    go (Leaf a) = leaf a
    go (Node layer) = node (fmap go layer)
-- | Monadic 'cataTerm'.
--
-- Children are folded first via 'traverse' (effects in the layer's
-- 'Traversable' order), then @node@ runs on the layer of results; each
-- 'Leaf' is handled by @leaf@.
cataTermM :: (Traversable f, Monad m) => (f b -> m b) -> (a -> m b) -> Term f a -> m b
cataTermM node leaf = go
  where
    go (Leaf a) = leaf a
    go (Node layer) = traverse go layer >>= node
-- | Embed a 'Fix' into 'Term' as nested 'Node's.
--
-- No 'Leaf' is ever produced, so the leaf type @a@ can be chosen freely.
unfreeze :: (Functor f) => Fix f -> Term f a
unfreeze (Fix layer) = Node (fmap unfreeze layer)
-- | Convert a 'Term' back into a 'Fix'.
--
-- Returns 'Nothing' as soon as a 'Leaf' is found anywhere in the term.
freeze :: (Traversable f) => Term f a -> Maybe (Fix f)
freeze = \case
  Leaf _ -> Nothing
  Node layer -> Fix <$> traverse freeze layer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
cabal-version: 2.4 | |
name: unification-xd | |
version: 0.1.0.0 | |
-- A short (one-line) description of the package. | |
-- synopsis: | |
-- A longer description of the package. | |
-- description: | |
-- A URL where users can report bugs. | |
-- bug-reports: | |
-- The license under which the package is released. | |
-- license: | |
author: Kirill Andreev | |
maintainer: Kirill.Andreev@kaspersky.com | |
-- A copyright notice. | |
-- copyright: | |
-- category: | |
extra-source-files: CHANGELOG.md | |
library
  -- NOTE(review): the library searches both src and app, and app is also
  -- where the executable's Main.hs lives. Confirm that app is really meant
  -- to be a library source directory.
  hs-source-dirs:     src app
  exposed-modules:    Unification, Fixpoint
  build-depends:
      base
    , mtl
    , transformers
    , containers
    , microlens-platform
    , shower
  default-language:   Haskell2010
  default-extensions:
    BlockArguments
    DefaultSignatures
    DeriveAnyClass
    DeriveDataTypeable
    DeriveFoldable
    DeriveFunctor
    DeriveGeneric
    DeriveTraversable
    DerivingStrategies
    ExplicitForAll
    FlexibleContexts
    FlexibleInstances
    FunctionalDependencies
    GeneralizedNewtypeDeriving
    ImportQualifiedPost
    LambdaCase
    MultiParamTypeClasses
    ScopedTypeVariables
    StandaloneDeriving
    TemplateHaskell
    TypeApplications
    TypeOperators
    UndecidableInstances
executable unification-xd
  main-is:          Main.hs
  hs-source-dirs:   app
  -- NOTE(review): the library is not listed here; if Main imports
  -- Unification or Fixpoint, add unification-xd to build-depends.
  build-depends:    base ^>=4.14.3.0
  default-language: Haskell2010
  -- Modules included in this executable, other than Main:
  -- other-modules:
  -- LANGUAGE extensions used by modules in this package:
  -- other-extensions:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module Unification where | |
import Control.Monad.Writer | |
import Control.Monad.Except | |
import Control.Monad.State | |
import Control.Monad.Identity | |
import Data.IntMap qualified as IntMap | |
import Data.IntMap (IntMap) | |
import Data.IntSet qualified as IntSet | |
import Data.IntSet (IntSet) | |
import Data.Functor.Compose | |
import Data.Foldable (fold) | |
import Data.Traversable (for) | |
import GHC.Generics | |
import Lens.Micro.Platform | |
import Shower | |
import Fixpoint | |
-- | Functors whose outermost layers can be compared structurally.
class Traversable f => Unifiable f where
  -- | Compare the outermost layers of two values.
  --
  -- Returns 'Nothing' when the shapes differ: different constructors
  -- (':+:') or unequal constants ('K1'). Otherwise it returns the common
  -- shape, in which every position holds either @'Right' (l, r)@, the
  -- corresponding children of both arguments, or @'Left' l@, a child taken
  -- from the left argument alone.
  --
  -- The default goes through the 'Generic1' representation.
  match :: f a -> f a -> Maybe (f (Either a (a, a)))
  default match
    :: (Generic1 f, Unifiable (Rep1 f))
    => f a -> f a -> Maybe (f (Either a (a, a)))
  match l r = to1 <$> match (from1 l) (from1 r)

-- Empty and nullary layers have no positions, so they always match.
instance Unifiable V1 where
  match v _ = Just (fmap Left v)

instance Unifiable U1 where
  match u _ = Just (fmap Left u)

-- A parameter position: pair up the two children.
instance Unifiable Par1 where
  match (Par1 l) (Par1 r) = Just (Par1 (Right (l, r)))

instance (Unifiable f) => Unifiable (Rec1 f) where
  match (Rec1 l) (Rec1 r) = fmap Rec1 (match l r)

-- Constant fields must be equal.
instance (Eq c) => Unifiable (K1 i c) where
  match (K1 l) (K1 r) = if l == r then Just (K1 l) else Nothing

instance (Unifiable f) => Unifiable (M1 i c f) where
  match (M1 l) (M1 r) = fmap M1 (match l r)

-- Sums match only when both sides chose the same alternative.
instance (Unifiable f, Unifiable g) => Unifiable (f :+: g) where
  match (L1 l) (L1 r) = L1 <$> match l r
  match (R1 l) (R1 r) = R1 <$> match l r
  match _ _ = Nothing

instance (Unifiable f, Unifiable g) => Unifiable (f :*: g) where
  match (l1 :*: l2) (r1 :*: r2) = (:*:) <$> match l1 r1 <*> match l2 r2

-- Composition: match the outer layer, then match each pair of inner layers
-- it produced; a lone inner layer has all its children marked 'Left'.
instance (Unifiable f, Unifiable g) => Unifiable (f :.: g) where
  match (Comp1 outerL) (Comp1 outerR) =
      match outerL outerR >>= fmap Comp1 . traverse descend
    where
      descend (Left inner) = Just (Left <$> inner)
      descend (Right (innerL, innerR)) = match innerL innerR
-- | Variables that convert to and from an 'Int' id.
--
-- The binding store is keyed by 'getVarId', and 'getFreeVars' rebuilds
-- variables from their ids with 'makeVar'. The two are therefore presumably
-- expected to be inverses — TODO confirm for every instance.
class Variable v where
  -- | Build a variable from its id.
  makeVar :: Int -> v
  -- | Recover a variable's id.
  getVarId :: v -> Int
-- | Monads that maintain a store of variable bindings for terms over the
-- functor @t@ with variables @v@.
--
-- The functional dependencies say that, for a given monad, the term functor
-- determines the variable type and vice versa.
class
  ( Variable v
  , Unifiable t
  , Monad m
  )
  => BindingMonad m t v
  | m t -> v
  , m v -> t
  where
  -- | Look up the term a variable is bound to, if any.
  find :: v -> m (Maybe (Term t v))
  -- | Allocate a new variable.
  fresh :: m v
  -- | Allocate a new variable and bind it to the given term.
  new :: Term t v -> m v
  -- | Bind a variable to a term, returning that term.
  (=:) :: v -> Term t v -> m (Term t v)

  new t = do
    v <- fresh
    -- The bound term is returned by (=:) but not needed here.
    _ <- v =: t
    return v

-- | Lift every binding operation through a monad transformer @h@, so that
-- e.g. @StateT s m@ is a 'BindingMonad' whenever @m@ is.
--
-- All four methods delegate to @m@, including 'new', so a base monad that
-- overrides 'new' keeps that behaviour when lifted.
instance {-# OVERLAPPABLE #-}
  ( BindingMonad m t v
  , MonadTrans h
  , Monad (h m)
  ) => BindingMonad (h m) t v
  where
  find v = lift (find v)
  fresh = lift fresh
  new t = lift (new t)
  v =: t = lift (v =: t)
-- | State of the binding store used by 'UnificationT'.
data UnifState t v = UnifState
  { _usMap :: IntMap (Term t v)
    -- ^ Current bindings, keyed by 'getVarId' of the bound variable.
  , _usCounter :: Int
    -- ^ Id that the next 'fresh' variable will receive.
  }

deriving stock instance (Show a, Show (f Unshow), Functor f) => Show (UnifState f a)

-- Generates the lenses 'usMap' and 'usCounter'.
makeLenses ''UnifState

-- | The empty store: no bindings, and ids start from 0.
startUnifState :: UnifState t v
startUnifState = UnifState { _usMap = IntMap.empty, _usCounter = 0 }
-- | Errors raised during unification.
data UnificationError t v
  = Occurs v (Term t v)
    -- ^ Thrown by 'seenAs' when a variable id is already recorded; carries
    -- the variable and the previously recorded layer wrapped in 'Node'.
  | Mismatch (Term t v) (Term t v)
    -- ^ NOTE(review): not constructed by any code visible alongside this
    -- type; presumably for two terms whose layers fail to 'match'.

deriving stock instance (Show a, Show (f Unshow), Functor f) => Show (UnificationError f a)
-- | 'UnificationT' over 'Identity'.
type Unification t v = UnificationT t v Identity

-- | Unification computations: the binding store ('UnifState') lives in
-- 'StateT', which sits on 'ExceptT' for 'UnificationError's, which sits on
-- the base monad @m@.
newtype UnificationT t v m a = UnificationT
  { runUnificationT
      :: StateT (UnifState t v) (ExceptT (UnificationError t v) m) a
  }
  deriving newtype
    ( Functor
    , Applicative
    , Monad
    , MonadError (UnificationError t v)
    )

-- | Lift a base-monad action through both the state and the error layer.
instance MonadTrans (UnificationT t v) where
  lift action = UnificationT (lift (lift action))

-- | Run a pure unification starting from 'startUnifState'.
--
-- On success this returns the result together with the final store. Because
-- the state layer sits above 'ExceptT', a thrown error discards the store.
runUnification :: forall t v a. Unification t v a -> Either (UnificationError t v) (a, UnifState t v)
runUnification =
  runIdentity . runExceptT . (`runStateT` startUnifState) . runUnificationT
-- | The concrete binding store.
--
-- Bindings live in '_usMap', keyed by 'getVarId'; fresh ids are taken from
-- '_usCounter', which is then incremented.
instance
  (Variable v, Unifiable t, Monad m)
  => BindingMonad (UnificationT t v m) t v
  where
  find var = UnificationT (IntMap.lookup (getVarId var) <$> use usMap)

  fresh = UnificationT do
    next <- use usCounter
    usCounter += 1
    pure (makeVar next)

  var =: term = UnificationT do
    usMap %= IntMap.insert (getVarId var) term
    pure term
-- | Resolve the top level of a term through the binding store.
--
-- * A 'Node' is returned unchanged; its children are not visited.
-- * An unbound 'Leaf' is returned as is.
-- * For a bound 'Leaf', the bound term is pruned recursively, the variable
--   is rebound directly to that result (path compression), and the result
--   is returned.
--
-- A variable bound, directly or through a chain, to itself makes this loop.
prune :: BindingMonad m t v => Term t v -> m (Term t v)
prune term = case term of
    Node{} -> pure term
    Leaf var -> find var >>= maybe (pure term) (\bound -> prune bound >>= (var =:))
-- | Follow a chain of variable-to-variable bindings to its last variable,
-- which is either unbound or bound to a 'Node'.
--
-- Every variable passed on the way is rebound directly to that last
-- variable, and the last variable is returned as a 'Leaf'. A 'Node', or a
-- 'Leaf' whose variable is unbound or bound to a 'Node', is returned
-- unchanged.
semiprune :: BindingMonad m t v => Term t v -> m (Term t v)
semiprune term = case term of
    Node{} -> pure term
    Leaf start -> chase start term
  where
    chase var current = find var >>= \case
      Just next@(Leaf nextVar) -> chase nextVar next >>= (var =:)
      _ -> pure current
-- | A transformer that records, per variable id, the layer each variable was
-- seen as (see 'seenAs').
newtype OccursCheckT t v m a = OccursCheckT
  { runOccursCheckT :: StateT (IntMap (t (Term t v))) m a
  }
  deriving newtype
    ( Functor
    , Applicative
    , Monad
    , MonadError e
    )

instance MonadTrans (OccursCheckT t v) where
  lift action = OccursCheckT (lift action)

-- | Monads that can record that a variable was seen as a given layer.
class (Monad m) => MonadOccurs m t v where
  seenAs :: v -> t (Term t v) -> m ()

-- | Records the layer the first time a variable id is seen. If the id is
-- already recorded, it throws 'Occurs' carrying the previously recorded
-- layer (not the new one).
--
-- NOTE(review): entries are never removed here, so a variable reached twice
-- along independent (shared, non-cyclic) paths is reported as well. Confirm
-- that callers scope this per path.
instance (MonadError (UnificationError t v) m, Variable v) => MonadOccurs (OccursCheckT t v m) t v where
  seenAs var layer = do
    let key = getVarId var
    recorded <- OccursCheckT get
    case IntMap.lookup key recorded of
      Just earlier -> throwError (Occurs var (Node earlier))
      Nothing -> OccursCheckT (modify (IntMap.insert key layer))
-- | Collect the variables that remain unbound in the given terms, following
-- bindings through the store.
--
-- Each variable is expanded at most once, so cycles that pass through a
-- 'Node' terminate. A pure variable-to-variable cycle would still loop
-- inside 'semiprune'. The result is in ascending order of 'getVarId', with
-- each variable rebuilt by 'makeVar'.
--
-- Because this calls 'semiprune', it may compress variable chains in the
-- store as a side effect.
getFreeVars
  :: forall m t v list
  . (BindingMonad m t v, Traversable list)
  => list (Term t v) -> m [v]
getFreeVars list = do
  idSet <- evalStateT (fold <$> traverse loop list) IntSet.empty
  return $ map makeVar $ IntSet.toList idSet
  where
    -- Ids of the free variables reachable from one term. The state holds
    -- the ids already visited.
    loop :: Term t v -> StateT IntSet m IntSet
    loop term = semiprune term >>= \case
      Node layer -> fold <$> traverse loop layer
      Leaf v -> do
        let vId = getVarId v
        done <- gets (IntSet.member vId)
        if done
          then return mempty
          else do
            modify (IntSet.insert vId)
            find v >>= \case
              -- Unbound: a free variable. vId is already the Int id, so
              -- it is used directly (no second getVarId on an Int).
              Nothing -> return (IntSet.singleton vId)
              Just bound -> loop bound
-- applyBindings | |
-- :: forall m t v list | |
-- . (BindingMonad m t v, Traversable list) | |
-- => list (Term t v) -> m (list (Term t v)) | |
-- applyBindings list = do | |
-- evalStateT () IntMap.empty | |
-------------------------------------------------------------------------------- | |
-- | Example term functor used by 'test' and 'test1'. 'Unifiable' comes from
-- the generic default.
data Type self
  = TArr self self
  | TCon String
  | TSet self
  | TMap self self
  | TRec [(String, self)]
  deriving stock (Generic1, Functor, Foldable, Traversable, Show)
  deriving anyclass (Unifiable)

-- Generic-derived 'Unifiable' instances for base functors. The 'TRec' field
-- is a list of pairs, so the generic instance for 'Type' relies on the ones
-- for [] and ((,) a).
deriving anyclass instance (Eq a) => Unifiable ((,) a)
deriving anyclass instance Unifiable []
deriving anyclass instance Unifiable Maybe

-- | An 'Int' is its own variable id.
instance Variable Int where
  getVarId n = n
  makeVar n = n
-- | Build the chain a -> b -> c, with c bound to @TCon "Int"@, then print
-- the result of running 'semiprune' on @a@ (together with the final store,
-- or the error).
test :: IO ()
test = printer do
  runUnification @Type @Int do
    a <- fresh
    b <- fresh
    c <- new $ Node $ TCon "Int"
    -- (=:) returns the bound term; it is not needed here.
    _ <- a =: Leaf b
    _ <- b =: Leaf c
    semiprune (Leaf a)
-- | Same bindings as 'test', then print the free variables of
-- @TArr a 42@.
--
-- Only ids 0-2 are allocated here, so variable 42 is unbound, while @a@
-- resolves through b and c to a 'Node'.
test1 :: IO ()
test1 = printer do
  runUnification @Type @Int do
    a <- fresh
    b <- fresh
    c <- new $ Node $ TCon "Int"
    -- (=:) returns the bound term; it is not needed here.
    _ <- a =: Leaf b
    _ <- b =: Leaf c
    getFreeVars [Node $ TArr (Leaf a) (Leaf 42)]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment