-- Created September 8, 2024 10:52
module Rock.Memo where
import Control.Concurrent.Lifted
import Control.Exception.Lifted
import Control.Monad
import Data.Dependent.HashMap (DHashMap)
import Data.Dependent.HashMap qualified as DHashMap
import Data.Foldable
import Data.GADT.Compare (GEq)
import Data.GADT.Show (GShow)
import Data.HashMap.Lazy (HashMap)
import Data.HashMap.Lazy qualified as HashMap
import Data.Hashable
import Data.IORef.Lifted
import Data.Kind (Type)
import Data.Maybe
import Data.Some
import Data.Typeable
import Effectful (Eff, IOE, raise, (:>))
import Rock
-- * Implicit memoisation
-- | Evidence that every query key of type @f es a@ may assume 'IOE' is part
-- of its effect list @es@.
--
-- 'withIOE' runs an action that needs @'IOE' ':>' es@, discharging that
-- constraint by inspecting the key; 'memoise' relies on it to use its
-- IO-based cache while answering a query.
class HasIOE f where
  withIOE :: forall es a. f es a -> (IOE :> es => Eff es a) -> Eff es a
-- | Remember what @f@ queries have already been performed and their results in
-- a 'DHashMap', and reuse them if a query is performed again a second time.
--
-- The 'DHashMap' should typically not be reused if there has been some change that
-- might make a query return a different result.
--
-- Every key must satisfy 'HasIOE', because the cache lives in an 'IORef' and
-- results are handed between threads through 'MVar's.
--
-- NOTE(review): if @rules key@ throws, the key's 'MVar' is never filled, so
-- any other caller already waiting on it blocks indefinitely.
memoise
  :: forall f
   . (forall es. GEq (f es), forall es a. Hashable (f es a), HasIOE f)
  => IORef (DHashMap (HideEffects f) MVar)
  -> Rules f
  -> Rules f
memoise startedVar rules (key :: f es a) = withIOE key $ do
  -- Fast path: someone already started (or finished) this query.
  maybeValueVar <- DHashMap.lookup (HideEffects key) <$> readIORef startedVar
  case maybeValueVar of
    Nothing -> do
      valueVar <- newEmptyMVar
      -- Re-check atomically: only the thread whose 'MVar' actually gets
      -- inserted runs the query; a thread that lost the race waits on the
      -- winner's 'MVar' instead.
      join $ atomicModifyIORef startedVar $ \started ->
        case DHashMap.alterLookup (Just . fromMaybe valueVar) (HideEffects key) started of
          (Nothing, started') ->
            ( started'
            , do
                value <- rules key
                putMVar valueVar value
                return value
            )
          (Just valueVar', _started') ->
            (started, readMVar valueVar')
    Just valueVar ->
      readMVar valueVar
-- * Explicit memoisation
-- | Marks a query @f effs r@ as one to be answered by an explicit
-- memoisation handler. The wrapped query additionally has 'IOE' available,
-- which the handler needs for its cache.
data MemoQuery f es a where
  MemoQuery :: f effs r -> MemoQuery f (IOE : effs) r
-- | Answer 'MemoQuery' keys by running the wrapped query with the given
-- rules directly, without caching anything.
withoutMemoisation :: Rules f -> Rules (MemoQuery f)
withoutMemoisation runKey (MemoQuery inner) = raise (runKey inner)
-- | Remember what @f@ queries have already been performed and their results in
-- a 'DHashMap', and reuse them if a query is performed again a second time.
--
-- Unlike 'memoise', this handles queries wrapped in 'MemoQuery'; the wrapper
-- adds 'IOE' to the query's effects, so no 'HasIOE' instance is needed.
--
-- The 'DHashMap' should typically not be reused if there has been some change that
-- might make a query return a different result.
--
-- NOTE(review): if @rules key@ throws, the key's 'MVar' is never filled, so
-- any other caller already waiting on it blocks indefinitely.
memoiseExplicit
  :: forall f
   . (forall es. GEq (f es), forall es a. Hashable (f es a))
  => IORef (DHashMap (HideEffects f) MVar)
  -> Rules f
  -> Rules (MemoQuery f)
memoiseExplicit startedVar rules (MemoQuery (key :: f es a)) = do
  -- Fast path: someone already started (or finished) this query.
  maybeValueVar <- DHashMap.lookup (HideEffects key) <$> readIORef startedVar
  case maybeValueVar of
    Nothing -> do
      valueVar <- newEmptyMVar
      -- Re-check atomically: only the thread whose 'MVar' actually gets
      -- inserted runs the query; a thread that lost the race waits on the
      -- winner's 'MVar' instead.
      join $ atomicModifyIORef startedVar $ \started ->
        case DHashMap.alterLookup (Just . fromMaybe valueVar) (HideEffects key) started of
          (Nothing, started') ->
            ( started'
            , do
                value <- raise $ rules key
                putMVar valueVar value
                return value
            )
          (Just valueVar', _started') ->
            (started, readMVar valueVar')
    Just valueVar ->
      readMVar valueVar
-- | Exception raised by 'memoiseWithCycleDetection' when a query depends on
-- itself; it carries the query at which the cycle was detected.
newtype Cyclic f = Cyclic (Some f)
  deriving Show

instance (Typeable f, GShow f) => Exception (Cyclic (f :: Type -> Type))
-- | State of one query in the table kept by 'memoiseWithCycleDetection'.
data MemoEntry a
  = -- | Being computed by the given thread. The first 'MVar' is filled with
    -- @'Just' result@ on success, or 'Nothing' when the computation was
    -- aborted by a 'Cyclic' exception. The second 'MVar' holds the threads
    -- waiting for the result; it becomes 'Nothing' once the result is ready.
    Started !ThreadId !(MVar (Maybe a)) !(MVar (Maybe [ThreadId]))
  | -- | Finished with the given result.
    Done !a
-- | Like 'memoise', but throw @'Cyclic' f@ if a query depends on itself, directly or
-- indirectly.
--
-- Operates on 'MemoQuery'-wrapped keys. The 'HashMap' maps each waiting
-- thread to the thread it is waiting on; it represents dependencies between
-- threads and should not be reused between invocations.
memoiseWithCycleDetection
  :: forall f
   . ( Typeable f
     , forall es a. Show (f es a)
     , forall es. GEq (f es)
     , forall es a. Hashable (f es a)
     )
  => IORef (DHashMap (HideEffects f) MemoEntry)
  -> IORef (HashMap ThreadId ThreadId)
  -> Rules f
  -> Rules (MemoQuery f)
memoiseWithCycleDetection startedVar depsVar rules = rules'
  where
    -- An explicit signature is required: 'rules'' matches on the 'MemoQuery'
    -- GADT and must be polymorphic in the key, but it closes over local
    -- variables, so it would not be generalised otherwise.
    rules' :: Rules (MemoQuery f)
    rules' (MemoQuery (key :: f es a)) = do
      maybeEntry <- DHashMap.lookup (HideEffects key) <$> readIORef startedVar
      case maybeEntry of
        Nothing -> do
          threadId <- myThreadId
          valueVar <- newEmptyMVar
          waitVar <- newMVar $ Just []
          -- Claim the key atomically. If another thread got there first,
          -- wait on its entry instead of computing the query again.
          join $ atomicModifyIORef startedVar $ \started ->
            case DHashMap.alterLookup (Just . fromMaybe (Started threadId valueVar waitVar)) (HideEffects key) started of
              (Nothing, started') ->
                ( started'
                , ( do
                      value <- raise $ rules key
                      -- Close the wait list, then remove the dependency edges
                      -- of every thread that was waiting on this query.
                      join $ modifyMVar waitVar $ \maybeWaitingThreads ->
                        case maybeWaitingThreads of
                          Nothing ->
                            error "impossible"
                          Just waitingThreads ->
                            return
                              ( Nothing
                              , atomicModifyIORef depsVar $ \deps ->
                                  ( foldl' (flip HashMap.delete) deps waitingThreads
                                  , ()
                                  )
                              )
                      atomicModifyIORef startedVar $ \started'' ->
                        (DHashMap.insert (HideEffects key) (Done value) started'', ())
                      putMVar valueVar $ Just value
                      return value
                  )
                    `catch` \(e :: Cyclic (HideEffects f)) -> do
                      -- Forget the aborted entry and wake waiters with
                      -- 'Nothing' (they retry), then propagate the cycle.
                      atomicModifyIORef startedVar $ \started'' ->
                        (DHashMap.delete (HideEffects key) started'', ())
                      putMVar valueVar Nothing
                      throwIO e
                )
              (Just entry, _started') ->
                (started, waitFor entry)
        Just entry -> waitFor entry
      where
        -- Wait for an entry owned by another thread. First record the edge
        -- "this thread waits on the owner"; if that edge closes a cycle,
        -- throw 'Cyclic' instead of waiting.
        waitFor entry =
          case entry of
            Started onThread valueVar waitVar -> do
              threadId <- myThreadId
              modifyMVar_ waitVar $ \maybeWaitingThreads ->
                case maybeWaitingThreads of
                  Nothing ->
                    return maybeWaitingThreads
                  Just waitingThreads -> do
                    join $ atomicModifyIORef depsVar $ \deps ->
                      let deps' = HashMap.insert threadId onThread deps
                       in if detectCycle threadId deps'
                            then
                              ( deps
                              , throwIO $ Cyclic $ Some (HideEffects key)
                              )
                            else
                              ( deps'
                              , return ()
                              )
                    return $ Just $ threadId : waitingThreads
              -- 'Nothing' means the owner aborted on a cycle: start over.
              maybeValue <- readMVar valueVar
              maybe (rules' (MemoQuery key)) return maybeValue
            Done value ->
              return value

    -- Follow the chain of "waits on" edges from @threadId@; there is a cycle
    -- if the chain leads back to @threadId@.
    detectCycle threadId deps =
      go threadId
      where
        go tid =
          case HashMap.lookup tid deps of
            Nothing -> False
            Just dep
              | dep == threadId -> True
              | otherwise -> go dep
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wno-orphans #-}
module Rock where
import Data.Dependent.HashMap (DHashMap)
import Data.Dependent.HashMap qualified as DHashMap
import Data.GADT.Compare (GEq, geq)
import Data.GADT.Show (GShow (gshowsPrec))
import Data.Hashable
import Data.IORef.Lifted
import Data.Kind (Type)
import Data.Some
import Data.Typeable
import Effectful (Dispatch (Static), DispatchOf, Eff, Effect, IOE, Subset, inject, raise, (:>))
import Effectful.Dispatch.Static (SideEffects (NoSideEffects), StaticRep, evalStaticRep, getStaticRep)
import Effectful.Timeout (Timeout, timeout)
import Unsafe.Coerce (unsafeCoerce)
-- * Types
-- | A handler for queries @f@: a key that needs effects @es@ and produces an
-- @a@ is answered by a computation in @'Eff' es@.
type Rules f = forall a es. f es a -> Eff es a
-- | Effect giving access to the rules used by 'fetch' to answer queries @f@.
data Rock (f :: [Effect] -> Type -> Type) :: Effect

-- Statically dispatched; handling it (see 'runRock') does not require 'IOE'.
type instance DispatchOf (Rock f) = Static NoSideEffects

-- The effect's environment is simply the rules for @f@.
newtype instance StaticRep (Rock f) = Rock (forall a es. f es a -> Eff es a)
-- | Handle the 'Rock' @f@ effect: every 'fetch' inside @action@ is answered
-- by @rules@.
runRock :: Rules f -> Eff (Rock f : es) a -> Eff es a
runRock rules action = evalStaticRep (Rock rules) action
-- | Answer a query using the 'Rock' @f@ rules in scope. Every effect the
-- query requires (@xs@) must also be available to the caller (@es@); the
-- query's computation is 'inject'ed into the caller's effect stack.
fetch :: (Subset xs es, Rock f :> es) => f xs a -> Eff es a
fetch key = getStaticRep >>= \(Rock runKey) -> inject (runKey key)
-- * Running tasks
-- | Marks a query as one to be run under a time limit; it needs the
-- 'Timeout' effect and yields 'Nothing' if the limit is hit.
data TimeoutQuery f es a where
  TimeoutQuery :: f effs r -> TimeoutQuery f (Timeout : effs) (Maybe r)
-- | Answer 'TimeoutQuery' keys with the given rules, allowing each query at
-- most one second (1,000,000 microseconds). Use 'timeoutRulesWith' to pick
-- another limit.
timeoutRules :: Rules f -> Rules (TimeoutQuery f)
timeoutRules r key = timeoutRulesWith 1000000 r key

-- | Like 'timeoutRules', with the limit given in microseconds (the unit used
-- by 'timeout'). The query yields 'Nothing' if it does not finish in time.
timeoutRulesWith :: Int -> Rules f -> Rules (TimeoutQuery f)
timeoutRulesWith micros r (TimeoutQuery k) =
  -- 'raise' lifts the inner query past the single 'Timeout' effect the
  -- wrapper adds; no 'Subset' constraint is needed.
  timeout micros (raise (r k))
-- * Task combinators
-- | Wraps a query so that its handler additionally has 'IOE' available.
data IOQuery f es a where
  IOQuery :: f effs r -> IOQuery f (IOE : effs) r
-- | Run @task@ and collect its query dependencies in a 'DHashMap'. For each
-- fetched key and its result, @f@ builds the map entry to store.
--
-- This is 'trackM' with a pure entry builder.
--
-- NOTE(review): 'trackM' runs @task@ via @raise@ beneath a fresh
-- @'Rock' ('IOQuery' f)@ handler, so @task@'s own @'Rock' f@ fetches do not
-- appear to pass through the recorder; confirm the returned map is populated.
track
  :: forall f es k g a
   . (GEq k, Hashable (Some k), IOE :> es, Rock f :> es)
  => (forall es' a'. f es' a' -> a' -> (k a', g a'))
  -> Eff es a
  -> Eff es (a, DHashMap k g)
track f = trackM (\key value -> pure (f key value))
-- | Run @task@ and collect its query dependencies in a 'DHashMap'. For each
-- fetched key and its result, @f@ builds the entry to store, and may use the
-- effects of the query it is called on.
--
-- NOTE(review): @task@ is executed as @raise task@ underneath the new
-- @'Rock' ('IOQuery' f)@ handler installed by 'transRock'. Raising means
-- @task@ cannot see that handler: its @'Rock' f@ fetches still resolve to the
-- handler already in @es@, so @record'@ appears never to run and the
-- returned map stays empty. Consider interposing on @'Rock' f@ itself
-- (e.g. with @localStaticRep@) — confirm intended usage before changing.
trackM
  :: forall f es k g a
   . (GEq k, Hashable (Some k), IOE :> es, Rock f :> es)
  => (forall es' a'. f es' a' -> a' -> Eff es' (k a', g a'))
  -> Eff es a
  -> Eff es (a, DHashMap k g)
trackM f task = do
  depsVar <- newIORef mempty
  let record'
        :: (forall a' es'. f es' a' -> Eff es' a')
        -> (forall a' es'. IOQuery f es' a' -> Eff es' a')
      -- Answer the key with the underlying rules, then store the entry that
      -- @f@ builds from the key and its result. 'IOQuery' supplies the 'IOE'
      -- needed to update the 'IORef'.
      record' fetch' (IOQuery key) = do
        value <- raise $ fetch' key
        (k, g) <- raise $ f key value
        atomicModifyIORef depsVar $ (,()) . DHashMap.insert k g
        pure value
  result <- transRock record' (raise task)
  deps <- readIORef depsVar
  return (result, deps)
-- | Run @m@ with a @'Rock' g@ handler derived from the @'Rock' f@ rules that
-- are already in scope: @transform@ receives the current @f@ rules and must
-- return rules for @g@.
transRock
  :: forall f g es a
   . (Rock f :> es)
  => ( (forall a' es'. f es' a' -> Eff es' a')
       -> (forall a' es'. g es' a' -> Eff es' a')
     )
  -> Eff (Rock g : es) a
  -> Eff es a
transRock transform m = do
  Rock current <- getStaticRep @(Rock f)
  evalStaticRep (Rock (transform current)) m
-- * Utils
-- | Wraps a query key and forgets the effect list it is indexed by, keeping
-- only its result type. This lets keys with different effect lists be used
-- together as keys of one map (e.g. a 'DHashMap').
--
-- The 'GEq' and 'Eq' instances discard the effect index with 'unsafeCoerce'
-- before comparing, so do not rely on that index.
type HideEffects :: ([Effect] -> Type -> Type) -> Type -> Type
data HideEffects f a where
  HideEffects :: forall f es a. f es a -> HideEffects f a
-- | Shows the wrapped key as @HideEffects <key>@, parenthesised at
-- application precedence and above.
instance (forall es a. Show (f es a)) => GShow (HideEffects f) where
  gshowsPrec d (HideEffects inner) =
    showParen (d > 10) $
      showString "HideEffects" . showChar ' ' . showsPrec 11 inner
-- | Compares the wrapped keys after coercing the second one to the first
-- one's effect index.
--
-- NOTE(review): this is only sound if keys that 'geq' reports equal really
-- share the same effect index (true when each constructor fixes @es@, as in
-- the GADTs used here) — confirm for every key type used with it.
instance (forall es. GEq (f es)) => GEq (HideEffects f) where
  geq (HideEffects (x :: f es a)) (HideEffects (y :: f fs b)) =
    let y' = unsafeCoerce y :: f es b
     in case geq x y' of
          Just Refl -> Just Refl
          Nothing -> Nothing
-- | Equality of the wrapped keys, after coercing away any difference in their
-- effect index ('Eq' comes from the 'Hashable' superclass).
instance (forall es. Hashable (f es a)) => Eq (HideEffects f a) where
  (==) (HideEffects x) (HideEffects y) = x == unsafeCoerce y
-- | Hashes the wrapped key.
instance (forall es. Hashable (f es a)) => Hashable (HideEffects f a) where
  hashWithSalt salt (HideEffects inner) = hashWithSalt salt inner
-- | Orphan instance (orphan warnings are disabled for this module): hash a
-- 'Some' by hashing the key inside it. The 'GEq' @f@ constraint presumably
-- serves the 'Eq' superclass of @'Some' f@.
instance (forall a. Hashable (f a), GEq f) => Hashable (Some f) where
  hashWithSalt salt (Some inner) = hashWithSalt salt inner
{-# LANGUAGE TemplateHaskell #-}
{-# HLINT ignore "Use id" #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
module Rock.Test where
import Data.GADT.Compare.TH
import Data.GADT.Show.TH
import Data.Hashable
import Data.IORef.Lifted
import Data.Typeable
import Effectful (Eff, IOE, (:>))
import Effectful.Timeout (runTimeout)
import Generics.Kind.Derive.Hashable
import Generics.Kind.TH (deriveGenericK)
import Rock
import Rock.Memo
-- | Test queries, each indexed by the effects it needs. 'QueryInt' may fetch
-- 'MemoQuery'-wrapped 'Query' keys and 'Query2' keys; 'QueryString' only
-- needs 'IOE'.
data Query es a where
  QueryInt :: Query '[Rock (MemoQuery Query), IOE, Rock Query2] Int
  QueryString :: Query '[IOE] String
-- | A second query type whose only key needs no effects at all.
data Query2 es a where
  Query2Bool :: Bool -> Query2 '[] Bool
-- 'Typeable' is provided automatically by GHC for every type, so it is not
-- derived here (deriving it has no effect and only triggers
-- -Wderiving-typeable).
deriving instance Eq (Query es a)
deriving instance Show (Query es a)
-- Template Haskell: derive the kind-generic representation (used by
-- 'ghashWithSalt'' in the 'Hashable' instance below) and the 'GEq',
-- 'GCompare' and 'GShow' instances for 'Query'.
deriveGenericK ''Query
deriveGEq ''Query
deriveGCompare ''Query
deriveGShow ''Query
-- | Hashing via the kind-generic representation derived above.
instance Hashable (Query es a) where
  hashWithSalt salt key = ghashWithSalt' salt key
-- | Variant of 'Query' whose 'QueryInt'' fetches 'Query'' keys directly (no
-- 'MemoQuery' wrapper); used with implicit memoisation via 'memoise'.
data Query' es a where
  QueryInt' :: Query' '[Rock Query', IOE, Rock Query2] Int
  QueryString' :: Query' '[IOE] String
-- 'Typeable' is provided automatically by GHC for every type, so it is not
-- derived here (deriving it has no effect and only triggers
-- -Wderiving-typeable).
deriving instance Eq (Query' es a)
deriving instance Show (Query' es a)
-- Template Haskell: derive the kind-generic representation (used by
-- 'ghashWithSalt'' in the 'Hashable' instance below) and the 'GEq',
-- 'GCompare' and 'GShow' instances for 'Query''.
deriveGenericK ''Query'
deriveGEq ''Query'
deriveGCompare ''Query'
deriveGShow ''Query'
-- | Hashing via the kind-generic representation derived above.
instance Hashable (Query' es a) where
  hashWithSalt salt key = ghashWithSalt' salt key
-- | Every 'Query'' constructor lists 'IOE' in its effects, so matching on the
-- key is enough to discharge @'IOE' ':>' es@ for the given action.
instance HasIOE Query' where
  withIOE QueryInt' action = action
  withIOE QueryString' action = action
-- | Rules for 'Query'. 'QueryInt' fetches @'MemoQuery' 'QueryString'@ twice
-- and 'QueryString' logs a message, so the number of log lines shows whether
-- the 'MemoQuery' handler caches.
--
-- NOTE(review): 'sayErr' is not among the visible imports (presumably from
-- the @say@ package) — confirm.
testExplicitMemo :: Rules Query
testExplicitMemo QueryInt = do
  s1 <- fetch (MemoQuery QueryString)
  s2 <- fetch (MemoQuery QueryString)
  pure $ length (s1 <> s2)
testExplicitMemo QueryString =
  sayErr "Querying String" >> pure "hello"
-- | Rules for 'Query''. 'QueryInt'' fetches 'QueryString'' twice plus one
-- 'Query2' key; 'QueryString'' logs a message each time it actually runs.
--
-- NOTE(review): 'sayErr' is not among the visible imports (presumably from
-- the @say@ package) — confirm.
testRules' :: Rules Query'
testRules' QueryInt' = do
  s1 <- fetch QueryString'
  s2 <- fetch QueryString'
  flag <- fetch (Query2Bool False)
  pure $ length $ if flag then s1 else s2
testRules' QueryString' =
  sayErr "Querying String" >> pure "hello"
-- | Rules for 'Query2': negate the boolean.
test2Rules :: Rules Query2
test2Rules (Query2Bool b) = pure $ not b
-- | Fetch 'QueryInt'' with 'testRules'' and 'test2Rules', without any
-- memoisation.
test :: (IOE :> es) => Eff es Int
test = runRock testRules' (runRock test2Rules (fetch QueryInt'))
-- test :: (IOE :> es) => Eff es Int
-- test = do
-- memMap <- newIORef mempty
-- memThreadMap <- newIORef mempty
-- runRock testRules
-- . runRock (memoiseWithCycleDetection memMap memThreadMap testRules)
-- . runRock test2Rules
-- $ fetch (MemoQuery QueryInt)
-- | Like 'test', but 'Query'' keys are answered through 'memoise' with a
-- fresh cache, so repeated fetches of the same key reuse the first result.
testImplicitMemo :: (IOE :> es) => Eff es Int
testImplicitMemo = do
  cache <- newIORef mempty
  runRock (memoise cache testRules') $
    runRock test2Rules $
      fetch QueryInt'
-- | Fetch 'QueryInt' through 'timeoutRules' (time-limited). 'MemoQuery' keys
-- are served by 'withoutMemoisation', i.e. without caching. 'runTimeout'
-- discharges the 'Timeout' effect that 'TimeoutQuery' requires.
test2 :: (IOE :> es) => Eff es (Maybe Int)
test2 =
  runTimeout
    . runRock testExplicitMemo
    . runRock (withoutMemoisation testExplicitMemo)
    . runRock (timeoutRules testExplicitMemo)
    . runRock test2Rules
    $ fetch (TimeoutQuery QueryInt)
