Skip to content

Instantly share code, notes, and snippets.

@Cedev
Last active August 29, 2015 14:13
Show Gist options
  • Save Cedev/ee468d046ca0b95b6ec4 to your computer and use it in GitHub Desktop.
Save Cedev/ee468d046ca0b95b6ec4 to your computer and use it in GitHub Desktop.
Compiler from ArrowLike interface for primitive recursive functions to LLVM
module Control.PrimRec (
ArrowLike (..),
PrimRec (..),
module Control.Category,
module Data.Nat
) where
import Control.Category
import Data.Nat
import Prelude hiding (id, (.), fst, snd, succ)
import qualified Prelude (fst, snd)
-- | A 'Category' with pair projections and fan-out.  The class offers the
-- product-related combinators only; it has no method for lifting an
-- arbitrary Haskell function into the arrow.
class Category a => ArrowLike a where
    -- | Project the first component of a pair.
    fst :: a (b, d) b

    -- | Project the second component of a pair.
    snd :: a (d, b) b

    -- | Fan-out: feed the same input to both arrows and pair their results.
    (&&&) :: a b c -> a b c' -> a b (c,c')

    -- | Apply an arrow to the first component, passing the second through.
    first :: a b c -> a (b, d) (c, d)
    first f = f *** id

    -- | Apply an arrow to the second component, passing the first through.
    second :: a b c -> a (d,b) (d,c)
    second g = id *** g

    -- | Apply two arrows side by side to the components of a pair.
    -- The default is built from the projections and '&&&'.
    (***) :: a b c -> a b' c' -> a (b,b') (c,c')
    onLeft *** onRight = (onLeft . fst) &&& (onRight . snd)
-- | Arrows that can express primitive recursion over 'Nat'.
class ArrowLike a => PrimRec a where
    -- | An arrow from any input to a 'Nat'.
    zero :: a b Nat

    -- | An arrow from a 'Nat' to a 'Nat'.
    succ :: a Nat Nat

    -- | Primitive recursion on the 'Nat' component of the input.
    -- The first argument handles the base case and sees only the
    -- environment @e@; the second handles the step case and sees the
    -- result of the recursive call paired with a @(Nat, e)@ value.
    prec :: a e c -> a (c, (Nat,e)) c -> a (Nat, e) c
-- | Plain Haskell functions: projections and fan-out on ordinary tuples.
instance ArrowLike (->) where
    fst pair = Prelude.fst pair
    snd pair = Prelude.snd pair
    (f &&& g) input = (f input, g input)
-- | Reference interpretation of primitive recursion as Haskell functions.
instance PrimRec (->) where
    zero _ = Z
    succ = S
    -- The step function receives the recursive result together with the
    -- predecessor of the counter and the unchanged environment.
    prec base step = recurse
      where
        recurse (counter, env) =
            case counter of
                Z -> base env
                S predecessor ->
                    step (recurse (predecessor, env), (predecessor, env))
module Data.Nat where

-- | Unary natural numbers: 'Z' is zero and 'S' wraps the predecessor.
data Nat
    = Z
    | S Nat
    deriving (Eq, Ord, Show, Read)
; ModuleID = 'main'
; Format string passed to scanf: "%llu" followed by a NUL terminator (5 bytes).
@formatIn = global [5 x i8] c"%llu\00"
; Format string passed to printf: "%llu", a newline (0x0a) and a NUL (6 bytes).
@formatOut = global [6 x i8] c"%llu\0a\00"
; Variadic C library functions called from @main.
declare i32 @scanf(i8*, ...)
declare i32 @printf(i8*, ...)
; @n1(out, in): reads the two i64 fields of %n2 and writes one i64 into %n3.
; Stores 1 when the first input field is 0; otherwise it calls itself on
; (first - 1, second) and then stores 0.
define void @n1({i64}* %n3, {i64, i64}* %n2){
n4:
; Load both fields of the input struct.
%n8 = getelementptr inbounds {i64, i64}* %n2, i32 0, i32 0
%n9 = getelementptr inbounds {i64, i64}* %n2, i32 0, i32 1
%n10 = load i64* %n8
%n11 = load i64* %n9
br label %n5
n5:
; Branch on whether the first field is zero.
%n12 = icmp eq i64 %n10, 0
br i1 %n12, label %n6, label %n7
n6:
; Zero case: the result is 0 + 1.
%n13 = add i64 0, 1
%n14 = getelementptr inbounds {i64}* %n3, i32 0, i32 0
store i64 %n13, i64* %n14
ret void
n7:
; Non-zero case: build (first - 1, second) in fresh stack slots and recurse.
%n15 = sub i64 %n10, 1
%n16 = alloca {i64, i64}
%n17 = alloca {i64}
%n18 = getelementptr inbounds {i64, i64}* %n16, i32 0, i32 0
%n19 = getelementptr inbounds {i64, i64}* %n16, i32 0, i32 1
store i64 %n15, i64* %n18
store i64 %n11, i64* %n19
call void @n1({i64}* %n17, {i64, i64}* %n16)
; The recursive result is loaded but not used below; the stored result is 0.
%n20 = getelementptr inbounds {i64}* %n17, i32 0, i32 0
%n21 = load i64* %n20
%n22 = getelementptr inbounds {i64}* %n3, i32 0, i32 0
store i64 0, i64* %n22
ret void
}
; @n23(out, in): reads the two i64 fields of %n24 and writes one i64 into %n25.
; Stores 0 when the first input field is 0; otherwise it calls itself on
; (first - 1, second) and stores the result of applying @n1 to that value.
define void @n23({i64}* %n25, {i64, i64}* %n24){
n26:
; Load both fields of the input struct.
%n30 = getelementptr inbounds {i64, i64}* %n24, i32 0, i32 0
%n31 = getelementptr inbounds {i64, i64}* %n24, i32 0, i32 1
%n32 = load i64* %n30
%n33 = load i64* %n31
br label %n27
n27:
; Branch on whether the first field is zero.
%n34 = icmp eq i64 %n32, 0
br i1 %n34, label %n28, label %n29
n28:
; Zero case: the result is 0.
%n35 = getelementptr inbounds {i64}* %n25, i32 0, i32 0
store i64 0, i64* %n35
ret void
n29:
; Non-zero case: recurse on (first - 1, second).
%n36 = sub i64 %n32, 1
%n37 = alloca {i64, i64}
%n38 = alloca {i64}
%n39 = getelementptr inbounds {i64, i64}* %n37, i32 0, i32 0
%n40 = getelementptr inbounds {i64, i64}* %n37, i32 0, i32 1
store i64 %n36, i64* %n39
store i64 %n33, i64* %n40
call void @n23({i64}* %n38, {i64, i64}* %n37)
%n41 = getelementptr inbounds {i64}* %n38, i32 0, i32 0
%n42 = load i64* %n41
; Pass the recursive result, duplicated into both fields, to @n1.
%n43 = alloca {i64, i64}
%n44 = alloca {i64}
%n45 = getelementptr inbounds {i64, i64}* %n43, i32 0, i32 0
%n46 = getelementptr inbounds {i64, i64}* %n43, i32 0, i32 1
store i64 %n42, i64* %n45
store i64 %n42, i64* %n46
call void @n1({i64}* %n44, {i64, i64}* %n43)
; Copy @n1's result into this function's output slot.
%n47 = getelementptr inbounds {i64}* %n44, i32 0, i32 0
%n48 = load i64* %n47
%n49 = getelementptr inbounds {i64}* %n25, i32 0, i32 0
store i64 %n48, i64* %n49
ret void
}
; @main: reads one integer with scanf, passes it (duplicated into both fields
; of the argument struct) to @n23, and prints the result with printf.
define void @main(){
mainBlock:
; Read the input value into a stack slot using the @formatIn format string.
%n1 = alloca i64
%n2 = getelementptr inbounds [5 x i8]* @formatIn, i32 0, i32 0
call i32 (i8*, ...)* @scanf(i8* %n2, i64* %n1)
%n3 = load i64* %n1
; Build the (value, value) argument struct and the result slot, then call @n23.
%n4 = alloca {i64, i64}
%n5 = alloca {i64}
%n6 = getelementptr inbounds {i64, i64}* %n4, i32 0, i32 0
%n7 = getelementptr inbounds {i64, i64}* %n4, i32 0, i32 1
store i64 %n3, i64* %n6
store i64 %n3, i64* %n7
call void @n23({i64}* %n5, {i64, i64}* %n4)
; Load the result and print it using the @formatOut format string.
%n8 = getelementptr inbounds {i64}* %n5, i32 0, i32 0
%n9 = load i64* %n8
%n10 = getelementptr inbounds [6 x i8]* @formatOut, i32 0, i32 0
call i32 (i8*, ...)* @printf(i8* %n10, i64 %n9)
ret void
}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeOperators #-}
import GHC.Exts (Constraint)
import Data.Proxy
import Data.Word
import Data.Char (ord)
import Control.PrimRec
import Prelude hiding (
id, (.), fst, snd, succ,
sequence, sequence_, foldr,
add)
import LLVM.General.AST hiding (type')
import LLVM.General.AST.Global
import LLVM.General.AST.Type
import qualified LLVM.General.AST.Constant as C
import qualified LLVM.General.AST.IntegerPredicate as ICmp
import qualified LLVM.General.AST.CallingConvention as CallingConvention
import Data.Monoid
import Data.Foldable
import Data.Traversable
import Control.Applicative
import Control.Monad (forever)
import Control.Monad.Trans.Class
import Control.Monad.Trans.Writer.Strict (tell)
import Data.Functor.Identity
import Control.Monad.Morph
import LLVM.General.Pretty
import Pipes hiding (Proxy, void)
import qualified Pipes as P
import qualified Pipes.Prelude as P
import Pipes.Lift (runWriterP)
-----------------------------
-- Constrained Category
-----------------------------
-- | A code-generation pipe: it awaits 'Name's from upstream and yields
-- values of type @w@ downstream.
type Build w = Pipe Name w

-- | Await the next 'Name' from upstream.
getName :: (Monad m) => Build w m (Name)
getName = await

-- | Yield a single named instruction downstream.
instr :: (Monad m) => Named Instruction -> Build (Named Instruction) m ()
instr namedInstruction = yield namedInstruction
-- | A compiled arrow between register-representable types.  The outer
-- 'Build' yields top-level 'Definition's; its result is a function from
-- input operands to an inner 'Build' that yields instructions and returns
-- the output operands.
data RegisterArrow m x y where
    RegisterArrow
        :: (Registerable x, Registerable y)
        => Build Definition m
               (Operands x -> Build (Named Instruction) m (Operands y))
        -> RegisterArrow m x y
------------------------------
-- Functors
------------------------------
-- | Product of two functors applied to the same element type.
data (:*:) f g a = f a :*: g a

-- | Maps over both halves.
instance (Functor f, Functor g) => Functor (f :*: g) where
    fmap h (l :*: r) = (h <$> l) :*: (h <$> r)

-- | Folds the left half first, then the right half.
instance (Foldable f, Foldable g) => Foldable (f :*: g) where
    foldMap h (l :*: r) = mappend (foldMap h l) (foldMap h r)

-- | Traverses the left half first, then the right half.
instance (Traversable f, Traversable g) => Traversable (f :*: g) where
    traverse h (l :*: r) = liftA2 (:*:) (traverse h l) (traverse h r)
-- | Functors supporting application of a structure of functions to a
-- structure of arguments.  The class has no method for injecting a value.
class Apply f where
    (<.>) :: f (a -> b) -> f a -> f b
    infixl 4 <.>

instance Apply Identity where
    fs <.> xs = fs <*> xs

instance Apply Proxy where
    fs <.> xs = fs <*> xs

-- | Applies each half of the product independently.
instance (Apply f, Apply g) => Apply (f :*: g) where
    (fl :*: fr) <.> (xl :*: xr) = (fl <.> xl) :*: (fr <.> xr)
--------------------------
-- Constraint
--------------------------
-- | Types whose values can be represented as a container of registers.
class (Traversable (RegisterRep a), Apply (RegisterRep a)) => Registerable a where
    -- | The container shape used to hold one item per register.
    type RegisterRep a :: * -> *

    -- | Extra constraints packaged alongside @Registerable a@ in
    -- 'RegisterableDict'.
    type RegisterableCtx a :: Constraint

    -- | Evidence for @Registerable a@ together with @RegisterableCtx a@.
    registerableDict :: Proxy a -> RegisterableDict a

    -- | The LLVM 'Type' of every register in the representation.
    types :: Proxy a -> Registers a Type
-- | A 'Nat' occupies a single 64-bit integer register.
instance Registerable Nat where
    type RegisterRep Nat = Identity
    type RegisterableCtx Nat = ()
    registerableDict _ = Dict
    types _ = Registers (Identity (IntegerType 64))

-- | The unit type occupies no registers.
instance Registerable () where
    type RegisterRep () = Proxy
    type RegisterableCtx () = ()
    registerableDict _ = Dict
    types _ = Registers Proxy

-- | A pair occupies the registers of its first component followed by those
-- of its second; its context records that both components are registerable.
instance (Registerable a, Registerable b) => Registerable (a, b) where
    type RegisterRep (a, b) = Registers a :*: Registers b
    type RegisterableCtx (a, b) = (Registerable a, Registerable b)
    registerableDict _ = Dict
    types _ = Registers (types (Proxy :: Proxy a) :*: types (Proxy :: Proxy b))
--------------------------
-- Constraint Juggling
--------------------------
-- | A value-level witness that the constraint @c@ holds; pattern matching
-- on 'Dict' brings @c@ into scope.
data Dict c where
    Dict :: c => Dict c

-- | Evidence that @a@ is 'Registerable' together with its associated context.
type RegisterableDict a = Dict (Registerable a, RegisterableCtx a)

-- | A compiled arrow that still needs evidence about its input type before
-- it can produce a 'RegisterArrow'.
data PRFCompiled m a b where
    BlockLike
        :: (RegisterableDict a -> RegisterArrow m a b)
        -> PRFCompiled m a b

-- | The register representation of @r@ holding items of type @a@, packaged
-- with the @Registerable r@ instance.
data Registers r a where
    Registers :: Registerable r => RegisterRep r a -> Registers r a

-- | Registers that hold LLVM operands.
type Operands f = Registers f Operand
-- Each instance unwraps 'Registers', delegates to the underlying
-- representation, and (where a structure is produced) wraps it again.
instance Functor (Registers r) where
    fmap g (Registers inner) = Registers (g <$> inner)

instance Foldable (Registers r) where
    foldr step acc (Registers inner) = foldr step acc inner

instance Traversable (Registers r) where
    traverse g (Registers inner) = Registers <$> traverse g inner

instance Apply (Registers r) where
    Registers fs <.> Registers xs = Registers (fs <.> xs)
-- | Combine every element of a traversable, in traversal order, with its
-- position taken from the enumeration that starts at @toEnum 0@.
number :: (Enum e, Traversable t) => (a -> e -> b) -> t a -> t b
number f xs = labelled
  where
    (_, labelled) = mapAccumL step [toEnum 0 ..] xs
    -- Consume one index per element; the index supply starts at toEnum 0.
    step (index : rest) x = (rest, f x index)
-- | Recover the 'Registerable' instances and associated contexts of both
-- endpoint types from an already constructed 'RegisterArrow'.
rarrowDict :: forall m x y. RegisterArrow m x y -> Dict (Registerable x, Registerable y, RegisterableCtx x, RegisterableCtx y)
rarrowDict (RegisterArrow _) =
    -- Matching on the constructor supplies the two Registerable instances;
    -- matching on each dictionary supplies the associated contexts.
    case (registerableDict (Proxy :: Proxy x), registerableDict (Proxy :: Proxy y)) of
        (Dict, Dict) -> Dict
-- | From evidence for a pair, derive evidence for its first component.
fstDict :: forall a b. RegisterableDict (a, b) -> RegisterableDict a
fstDict Dict =
    let component = Proxy :: Proxy a
    in case registerableDict component of
        Dict -> Dict

-- | From evidence for a pair, derive evidence for its second component.
sndDict :: forall a b. RegisterableDict (a, b) -> RegisterableDict b
sndDict Dict =
    let component = Proxy :: Proxy b
    in case registerableDict component of
        Dict -> Dict
--------------------------
-- Instances
--------------------------
-- | Identity passes its operands through unchanged and emits nothing;
-- composition runs the right-hand arrow first and feeds its output
-- operands to the left-hand arrow.
instance (Monad m) => Category (PRFCompiled m) where
    id = BlockLike $ \Dict -> RegisterArrow (return return)
    BlockLike outer . BlockLike inner = BlockLike $ \Dict ->
        case inner Dict of
            innerArrow@(RegisterArrow buildInner) ->
                -- The inner arrow's output evidence is what the outer arrow
                -- needs as its input evidence.
                case rarrowDict innerArrow of
                    Dict ->
                        case outer Dict of
                            RegisterArrow buildOuter -> RegisterArrow $ do
                                -- Definitions of the inner arrow are built first.
                                runInner <- buildInner
                                runOuter <- buildOuter
                                return (\operands -> runOuter =<< runInner operands)
-- | Projections select one half of the register pair without emitting any
-- instruction; fan-out runs both arrows on the same operands, left first.
instance (Monad m) => ArrowLike (PRFCompiled m) where
    fst = BlockLike $ \Dict ->
        RegisterArrow (return (\(Registers (kept :*: _)) -> return kept))
    snd = BlockLike $ \Dict ->
        RegisterArrow (return (\(Registers (_ :*: kept)) -> return kept))
    BlockLike leftOf &&& BlockLike rightOf = BlockLike $ \Dict ->
        case (leftOf Dict, rightOf Dict) of
            (RegisterArrow buildLeft, RegisterArrow buildRight) -> RegisterArrow $ do
                runLeft <- buildLeft
                runRight <- buildRight
                return $ \operands -> do
                    leftOut <- runLeft operands
                    rightOut <- runRight operands
                    return (Registers (leftOut :*: rightOut))
-- | Compile the primitive-recursion primitives to LLVM.
instance (Monad m) => PrimRec (PRFCompiled m) where
    -- Ignores the input operands and yields the 64-bit constant 0.
    zero = BlockLike $ \Dict -> RegisterArrow . return $ \_ ->
        return . Registers . Identity . constant $ C.Int 64 0

    -- Emits one add of the 64-bit constant 1.
    succ = BlockLike $ \Dict -> RegisterArrow . return $ regSucc
      where
        regSucc (Registers op) = traverse opSucc $ Registers op
        opSucc op = bind i64 $ add op (constant $ C.Int 64 1)

    -- Emits a recursive function made of four basic blocks: one that reads
    -- the parameters, one that tests the counter against 0, one for the
    -- base case and one for the step case.
    prec (BlockLike df) (BlockLike dg) = BlockLike $ \d@Dict ->
        case df $ sndDict d of
            RegisterArrow mf ->
                case dg Dict of
                    RegisterArrow mg -> RegisterArrow $ do
                        f <- mf
                        g <- mg
                        defineRecursive $ \go read ret -> do
                            headName <- getName
                            brName <- getName
                            zeroName <- getName
                            succName <- getName
                            Registers (Registers (Identity n) :*: e) <- block headName $ do
                                rs <- read
                                return (br brName, rs)
                            block' brName $ do
                                cmp <- bind i1 $ icmp ICmp.EQ n (constant $ C.Int 64 0)
                                return (condbr cmp zeroName succName)
                            block' zeroName $ do
                                c <- f e
                                ret c
                            block' succName $ do
                                pred <- bind i64 $ sub n (constant $ C.Int 64 1)
                                -- The step arrow must see the predecessor of the
                                -- counter, exactly as the function instance does with
                                -- @g (go (n, d), (n, d))@; passing the original
                                -- parameters would hand it the un-decremented counter.
                                let predArgs = Registers (Registers (Identity pred) :*: e)
                                c <- go predArgs
                                c' <- g (Registers (c :*: predArgs))
                                ret c'
--------------------------
-- Code generating tools
--------------------------
-- | Emit the definition of a function that may call itself, and return the
-- action that generates a call to it.
--
-- The emitted function returns @void@ and takes two pointers: first a
-- pointer to a struct for the results, then a pointer to a struct holding
-- the arguments.  The callback receives, in order: an action generating a
-- recursive call, an action reading the parameters, and an action writing
-- the results and producing the return terminator.  The basic blocks it
-- yields become the function body.
defineRecursive :: forall x y m. (Registerable x, Registerable y, Monad m) =>
    (
        (Operands x -> Build (Named Instruction) m (Operands y)) -> -- recursive call
        Build (Named Instruction) m (Operands x) ->                 -- read parameters
        (Operands y -> Build (Named Instruction) m (Named Terminator)) -> -- return results
        Build (BasicBlock) m ()                                     -- function body
    ) ->
    Build Definition m (
        Operands x -> Build (Named Instruction) m (Operands y))     -- call function
defineRecursive body = do
    fnName <- getName
    argName <- getName
    resName <- getName
    let argStruct = StructureType False (toList (types (Proxy :: Proxy x)))
        resStruct = StructureType False (toList (types (Proxy :: Proxy y)))
        argPtrTy = ptr argStruct
        resPtrTy = ptr resStruct
        callee = constant $
            C.GlobalReference (FunctionType void [resPtrTy, argPtrTy] False) fnName
        -- Call site: allocate both structs, store the arguments, call the
        -- function, then load the results back into operands.
        invoke args = do
            argSlot <- bind argPtrTy $ alloca argStruct
            resSlot <- bind resPtrTy $ alloca resStruct
            writePtr argSlot args
            instr $ call callee [resSlot, argSlot]
            readPtr resSlot
        -- Inside the function: store the results through the result pointer.
        finish results = do
            writePtr (LocalReference resPtrTy resName) results
            return retVoid
        -- Inside the function: load the parameters through the argument pointer.
        fetchArgs = readPtr (LocalReference argPtrTy argName)
    (blocks, _) <- collect (body invoke fetchArgs finish)
    yield $ global $
        define void fnName [(resPtrTy, resName), (argPtrTy, argName)] blocks
    return invoke
----------------------------
-- Store registers in memory
----------------------------
-- | Emit one @getelementptr@ per entry of the type container, addressing
-- consecutive fields (starting at index 0) of the struct behind the given
-- pointer, and return the resulting pointer operands in the same shape.
elemPtrs :: (Monad m, Traversable f) => Operand -> f Type -> Build (Named Instruction) m (f Operand)
elemPtrs struct fieldTypes = sequence (number fieldPtr fieldTypes)
  where
    fieldPtr fieldType index =
        bind (ptr fieldType) $ getelementptr struct [C.Int 32 0, C.Int 32 index]
-- | Load every register of an @r@ from the struct behind the given pointer.
readPtr :: forall r m. (Registerable r, Monad m) => Operand -> Build (Named Instruction) m (Operands r)
readPtr struct = do
    let fieldTypes = types (Proxy :: Proxy r)
    fieldPtrs <- elemPtrs struct fieldTypes
    -- Pair each field type with the load from the matching field pointer.
    let loads = (bind <$> fieldTypes) <.> (load <$> fieldPtrs)
    sequence loads
-- | Store every register of an @r@ into the struct behind the given pointer.
writePtr :: forall r m. (Registerable r, Monad m) => Operand -> Operands r -> Build (Named Instruction) m ()
writePtr struct ops = do
    let fieldTypes = types (Proxy :: Proxy r)
    fieldPtrs <- elemPtrs struct fieldTypes
    -- Pair each value with its field pointer; the stores are emitted unnamed.
    let stores = (store <$> ops) <.> fieldPtrs
    sequence_ ((instr . Do) <$> stores)
--------------------------------------
-- Instructions
--------------------------------------
-- | Encode a string as a constant array of @i8@, one 8-bit integer per
-- character (its code point, converted with 'fromIntegral').
c :: String -> C.Constant
c = C.Array i8 . map (C.Int 8 . fromIntegral . ord)

-- | Use a constant as an operand.
constant :: C.Constant -> Operand
constant k = ConstantOperand k

-- | Wrap a global as a top-level definition.
global :: Global -> Definition
global g = GlobalDefinition g
-- Thin constructors for the instructions used by the compiler.  Each one
-- fills the remaining constructor fields with fixed values and an empty
-- metadata list.

-- | Integer addition.
add :: Operand -> Operand -> Instruction
add lhs rhs = Add False False lhs rhs []

-- | Integer subtraction.
sub :: Operand -> Operand -> Instruction
sub lhs rhs = Sub False False lhs rhs []

-- | Integer comparison with the given predicate.
icmp :: ICmp.IntegerPredicate -> Operand -> Operand -> Instruction
icmp predicate lhs rhs = ICmp predicate lhs rhs []

-- | Stack allocation of one value of the given type.
alloca :: Type -> Instruction
alloca ty = Alloca ty Nothing 0 []

-- | In-bounds @getelementptr@ with constant indices.
getelementptr :: Operand -> [C.Constant] -> Instruction
getelementptr base indices =
    GetElementPtr True base [constant index | index <- indices] []

-- | Load through a pointer.
load :: Operand -> Instruction
load address = Load False address Nothing 0 []

-- | Store a value through a pointer.  Note the argument order: the value
-- comes first, the address second.
store :: Operand -> Operand -> Instruction
store value address = Store False address value Nothing 0 []
-- | An unnamed call with the C calling convention; no attributes are
-- attached to the arguments or to the call.
call :: Operand -> [Operand] -> Named Instruction
call callee args =
    Do $ Call False CallingConvention.C [] (Right callee) [(arg, []) | arg <- args] [] []

-- | Return without a value.
retVoid :: Named Terminator
retVoid = Do (Ret Nothing [])

-- | Unconditional branch to the named block.
br :: Name -> Named Terminator
br target = Do (Br target [])

-- | Conditional branch: the first name is taken when the condition is
-- true, the second otherwise.
condbr :: Operand -> Name -> Name -> Named Terminator
condbr cond whenTrue whenFalse = Do (CondBr cond whenTrue whenFalse [])
-- | A non-variadic function definition with the given return type, name,
-- typed and named parameters, and body.
define :: Type -> Name -> [(Type, Name)] -> [BasicBlock] -> Global
define resultType fnName params blocks = functionDefaults
    { returnType = resultType
    , name = fnName
    , parameters = ([Parameter ty pName [] | (ty, pName) <- params], False)
    , basicBlocks = blocks
    }

-- | A function declaration (no body).  Every parameter is given the name
-- @UnName 0@; the final argument says whether the function is variadic.
declare :: Type -> Name -> [Type] -> Bool -> Global
declare resultType fnName paramTypes vargs = functionDefaults
    { returnType = resultType
    , name = fnName
    , parameters = ([Parameter ty (UnName 0) [] | ty <- paramTypes], vargs)
    }

-- | Readable spelling of 'True' for the variadic flag of 'declare'.
var :: Bool
var = True

-- | A global variable marked constant, with the given initializer.
globalConstant :: Name -> Type -> C.Constant -> Global
globalConstant gName ty initial = globalVariableDefaults
    { name = gName
    , type' = ty
    , initializer = Just initial
    , isConstant = True
    }
--------------------------------------
-- Pipes
-------------------------------------
-- | Emit an instruction under a fresh name and return a local reference of
-- the given type to its result.
bind :: (Monad m) => Type -> Instruction -> Build (Named Instruction) m (Operand)
bind resultType instruction = do
    fresh <- getName
    instr (fresh := instruction)
    return (LocalReference resultType fresh)
-- | Run an instruction builder, gather everything it emits into a single
-- basic block with the given name and the terminator it returns, yield
-- that block, and hand back the builder's extra result.
block :: (Monad m) => Name -> Build (Named Instruction) m (Named Terminator, r) -> Build BasicBlock m r
block label definition = do
    (body, (terminator, result)) <- collect definition
    yield (BasicBlock label body terminator)
    return result

-- | Variant of 'block' for builders that return only the terminator.
block' label definition = block label $ do
    terminator <- definition
    return (terminator, ())
-- | Run a pipe to completion, capturing everything it yields instead of
-- passing it downstream.  Returns the captured values in emission order
-- together with the pipe's own result; upstream input is still awaited
-- from the enclosing pipe.
collect :: (Monad m) => Pipe a b m r -> Pipe a c m ([b], r)
collect subDef = do
    (r, w) <- runWriterP $
        hoist lift subDef >->
        -- Accumulate as a difference list: each element is one function
        -- composition, and the list is materialised once at the end.
        -- Appending with (++ [x]) would re-walk the list on every yield.
        forever (await >>= \x -> lift $ tell (Endo (x :)))
    return (appEndo w [], r)
-- | Run a builder against an endless supply of names @n1@, @n2@, ... and
-- return everything it emitted together with its result.  The supply
-- starts again at @n1@ on every call.
runBuild :: (Monad m) => Build a m r -> m ([a], r)
runBuild subDef = runEffect (supply (1 :: Integer) >-> collect subDef)
  where
    supply k = do
        yield (Name ('n' : show k))
        supply (k + 1)
---------------------------------
-- Prelude
---------------------------------
-- | Format string handed to @scanf@ (5 bytes including the trailing NUL).
formatIn :: Global
formatIn = globalConstant (Name "formatIn") (ArrayType 5 i8) (c "%llu\00")

-- | Format string handed to @printf@ (6 bytes including the trailing NUL).
formatOut :: Global
formatOut = globalConstant (Name "formatOut") (ArrayType 6 i8) (c "%llu\n\00")

-- | Declaration of the variadic C function @scanf@.
scanf :: Global
scanf = declare i32 (Name "scanf") [ptr i8] var

-- | Declaration of the variadic C function @printf@.
printf :: Global
printf = declare i32 (Name "printf") [ptr i8] var

-- | Globals placed at the start of every generated module.
prelude :: [Global]
prelude =
    [ formatIn
    , formatOut
    , scanf
    , printf
    ]
-- | An operand referring to a global by name, typed as a pointer to the
-- global's own type.
globalRef :: Global -> Operand
globalRef g = ConstantOperand (C.GlobalReference (ptr (typeOf g)) (name g))

-- | The type of a global: the stored type for variables and aliases, the
-- function type for functions.
typeOf :: Global -> Type
typeOf (GlobalVariable {type' = storedType}) = storedType
typeOf (GlobalAlias {type' = aliasedType}) = aliasedType
typeOf (Function {returnType = resultType, parameters = signature}) =
    FunctionType resultType (map paramType params) isVarArg
  where
    (params, isVarArg) = signature
    paramType (Parameter ty _ _) = ty
-- | Emit a @getelementptr@ with indices @0, 0@ on the given operand and
-- return the result typed as a pointer to @i8@.
getFormatPtr :: Monad m => Operand -> Build (Named Instruction) m (Operand)
getFormatPtr format = bind (ptr i8) (getelementptr format zeroZero)
  where
    zeroZero = [C.Int 32 0, C.Int 32 0]
-- | Read a number: allocate an @i64@ slot, call @scanf@ with the input
-- format string and that slot, then load the slot.  The value returned by
-- @scanf@ itself is not inspected.
input :: (Monad m) => PRFCompiled m () Nat
input = BlockLike $ \_ -> RegisterArrow (return readNumber)
  where
    readNumber _ = do
        slot <- bind (ptr i64) $ alloca i64
        fmt <- getFormatPtr (globalRef formatIn)
        instr $ call (globalRef scanf) [fmt, slot]
        value <- bind i64 $ load slot
        return (Registers (Identity value))

-- | Print a number: call @printf@ with the output format string and the
-- operand holding the number.
output :: (Monad m) => PRFCompiled m Nat ()
output = BlockLike $ \_ -> RegisterArrow (return printNumber)
  where
    printNumber (Registers (Identity value)) = do
        fmt <- getFormatPtr (globalRef formatOut)
        instr $ call (globalRef printf) [fmt, value]
        return (Registers Proxy)
-- examples
-- | Case analysis on a 'Nat': 'prec' whose step branch drops the result of
-- the recursive call and looks only at the @(Nat, b)@ value it is handed.
match :: PrimRec a => a b c -> a (Nat, b) c -> a (Nat, b) c
match onZero onSucc = prec onZero (onSucc . snd)

-- | The successor of 'zero'.
one :: PrimRec a => a b Nat
one = succ . zero

-- | 'one' for a zero input, 'zero' for any successor.
isZero :: PrimRec a => a Nat Nat
isZero = match one zero . duplicate
  where
    duplicate = id &&& id

-- | Starts from 'zero' and applies 'isZero' to the running result once per
-- successor in the input, so the result alternates between zero and one.
isOdd :: PrimRec a => a Nat Nat
isOdd = prec zero (isZero . fst) . duplicate
  where
    duplicate = id &&& id
--
-- | Compile a closed program into an LLVM module named @main@.  The module
-- holds the prelude globals, every definition emitted while building the
-- program, and a @main@ function whose single block @mainBlock@ contains
-- the program's top-level instructions followed by a void return.
compile :: Monad m => PRFCompiled m () () -> m Module
compile (BlockLike df) = do
    let RegisterArrow mf = df Dict
    -- First pass: emitted definitions plus the top-level instruction builder.
    (defs, buildInstructs) <- runBuild mf
    -- Second pass: the instructions of main, built from empty input registers.
    (instrs, _) <- runBuild (buildInstructs $ Registers Proxy)
    let entryBlock = BasicBlock (Name "mainBlock") instrs retVoid
        entryPoint = define void (Name "main") [] [entryBlock]
        definitions = map global prelude ++ defs ++ [global entryPoint]
    return (Module "main" Nothing Nothing definitions)
-- | Compile a @Nat -> Nat@ program by reading its argument with 'input'
-- and printing its result with 'output'.
compileNatNat :: (Monad m) => PRFCompiled m Nat Nat -> m Module
compileNatNat program = compile (output . program . input)

-- | Print the LLVM module generated for 'isOdd'.
main = putStrLn (ppllvm (runIdentity (compileNatNat isOdd)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment