Created
March 10, 2019 17:47
-
-
Save gelisam/64674fd783584ef8eeab7a11bb72c2a5 to your computer and use it in GitHub Desktop.
dynamic programming using recursion schemes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Solving a dynamic programming problem in many ways, including using existing
-- recursion schemes and by defining new ones. The challenge of solving this
-- particular problem using recursion schemes was posed by Sandy Maguire.
{-# LANGUAGE FlexibleContexts, RankNTypes, TypeApplications, TypeFamilies, ScopedTypeVariables #-} | |
{-# OPTIONS -Wno-orphans #-} | |
module Dyna where | |
import Test.DocTest | |
import Data.Functor.Foldable (Base, Fix, Recursive(project), Corecursive(embed, ana), hylo, cataA) | |
import Control.Comonad.Trans.Env (EnvT(..), ask) | |
import Control.Monad.State (State, get, modify, evalState) | |
import Data.Array (Array, (!)) | |
import Data.Ix (Ix(inRange)) | |
import Data.List (maximumBy, transpose) | |
import Data.Map (Map) | |
import Data.Maybe (catMaybes, fromMaybe, listToMaybe) | |
import Data.Ord (comparing) | |
import Data.Tree (Tree(Node)) | |
import qualified Data.Array as Array | |
import qualified Data.Map as Map | |
-------------------------------------------------------------------------------- | |
-- type synonyms -- | |
-------------------------------------------------------------------------------- | |
-- | Index of a cell in the grid.
type Pos = (Int, Int)

-- | The contents of the cells along a path, in the order they are visited.
type Path = [Int]

-- | 'Nothing' if there are no decreasing paths to the goal.
type Result = Maybe Path

-- | Memo table from a position to the result already computed for it.
type Cache = Map Pos Result
-------------------------------------------------------------------------------- | |
-- Array helpers -- | |
-------------------------------------------------------------------------------- | |
-- | Build a 1-based array from a list of rows, indexed by
-- @(position within the row, row number)@.
--
-- All rows must have the same length. A ragged input is rejected with an
-- 'error' rather than silently producing an array with missing or dropped
-- elements, which is what handing a transposed ragged list to
-- 'Array.listArray' would do.
--
-- >>> array2D []
-- array ((1,1),(0,0)) []
-- >>> array2D [[],[],[]]
-- array ((1,1),(0,3)) []
-- >>> array2D [[1,2],[3,4],[5,6]]
-- array ((1,1),(2,3)) [((1,1),1),((1,2),3),((1,3),5),((2,1),2),((2,2),4),((2,3),6)]
array2D :: [[a]] -> Array Pos a
array2D xss
  | any ((/= w) . length) xss = error "array2D: rows have differing lengths"
  -- 'transpose' regroups the elements so that 'concat' emits them in the
  -- array's index order (first component outermost).
  | otherwise = Array.listArray ((1, 1), (w, h)) (concat (transpose xss))
  where
    -- The width is taken from the first row; with no rows at all it is 0,
    -- which yields the empty bounds ((1,1),(0,0)).
    w = length (fromMaybe [] (listToMaybe xss))
    h = length xss
-- | Does the index fall within the bounds of the array?
inArray :: Ix i => Array i a -> i -> Bool
inArray arr ix = inRange (Array.bounds arr) ix
-------------------------------------------------------------------------------- | |
-- Decreasing paths -- | |
-------------------------------------------------------------------------------- | |
-- The task is to find the longest decreasing path from one corner of a grid to
-- the other. "Path" means a sequence of cells which are next to each other,
-- either vertically or horizontally. "Decreasing" means that the contents of
-- each cell is strictly smaller than the contents of the previous cell along
-- the path.
-- Instead of parameterizing everything by a grid, we hardcode the grid here in | |
-- order to make the code shorter. Note that the longest path moves in all | |
-- cardinal directions, so unlike with the longest-common-subsequence problem, | |
-- the sub-problems aren't always to the bottom and to the right of the current | |
-- position. If they were, we could use a 'histo' on the Peano encoding of the | |
-- number of positions, as that would give us access to the answers to all our | |
-- sub-problems. But this problem is harder, so we will use memoization instead. | |
-- (Alternatively, we could have sorted the positions by their contents, and | |
-- solved the positions with the smaller contents first) | |
-- | The hardcoded puzzle input, written out row by row.
grid :: Array Pos Int
grid =
  array2D
    [ [30, 29, 28, 27, 26]
    , [ 9, 10, 11, 12, 25]
    , [ 8, 17, 16, 13, 24]
    , [ 7, 18, 15, 14, 23]
    , [ 6, 19, 20, 21, 22]
    , [ 5,  4,  3,  2,  1]
    ]
-- | The two corners a path has to connect: 'start' is the lower bound of
-- 'grid' and 'goal' is its upper bound.
start, goal :: Pos
start = fst (Array.bounds grid)
goal = snd (Array.bounds grid)
-- | Is the content of the second cell strictly smaller than the content of
-- the first? Both positions must lie inside 'grid', because '!' is partial.
decreasing :: Pos -> Pos -> Bool
decreasing from to = grid ! to < grid ! from
-- | The cells adjacent to the given one which are inside the grid and hold a
-- strictly smaller value, i.e. the cells a decreasing path may step to next.
neighbours :: Pos -> [Pos]
neighbours here@(i, j) =
  [ next
  | next <- [(i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1)]
  -- The bounds check must come first: 'decreasing' indexes into the grid.
  , inArray grid next
  , decreasing here next
  ]
-- | Combine the results obtained for a cell's neighbours into the result for
-- the cell itself.
--
-- At the 'goal' the path consists of that cell alone and the sub-results are
-- ignored. Anywhere else the cell's value is prepended to the longest path
-- among the sub-results, or the result is 'Nothing' when none of them
-- contains a path.
extendResult :: Pos -> [Result] -> Result
extendResult ij results
  | ij == goal = Just [here]
  | null candidates = Nothing
  | otherwise = Just (here : longest)
  where
    here :: Int
    here = grid ! ij

    candidates :: [Path]
    candidates = catMaybes results

    -- 'maximumBy' is partial on an empty list; this binding is only demanded
    -- by the last guard, where 'candidates' is known to be non-empty.
    longest :: Path
    longest = maximumBy (comparing length) candidates
-------------------------------------------------------------------------------- | |
-- Without recursion-schemes nor caching -- | |
-------------------------------------------------------------------------------- | |
-- | Longest Decreasing Path, by plain recursion without any caching.
--
-- >>> ldp
-- Just [30,29,28,27,26,25,24,23,22,21,20,19,18,17,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1]
ldp :: Result
ldp = explore start
  where
    -- Try every single path and keep the longest one. Paths are strictly
    -- decreasing, so they cannot contain a cycle and the recursion
    -- terminates.
    explore :: Pos -> Result
    explore here = extendResult here (map explore (neighbours here))
-------------------------------------------------------------------------------- | |
-- Without recursion-schemes, with caching -- | |
-------------------------------------------------------------------------------- | |
-- | Longest decreasing path by explicit recursion, remembering the result for
-- every position which has already been solved.
--
-- >>> ldpCached
-- Just [30,29,28,27,26,25,24,23,22,21,20,19,18,17,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1]
ldpCached :: Result
ldpCached = evalState (lookupOrSolve start) Map.empty
  where
    -- Reuse the cached answer for a position; solve it only on a cache miss.
    lookupOrSolve :: Pos -> State Cache Result
    lookupOrSolve ij = do
      cache <- get
      maybe (solveAndStore ij) pure (Map.lookup ij cache)

    -- Depth-first search: solve the neighbours (through the cache), combine
    -- their results, and record the answer for this position.
    solveAndStore :: Pos -> State Cache Result
    solveAndStore ij = do
      result <- fmap (extendResult ij) (traverse lookupOrSolve (neighbours ij))
      modify (Map.insert ij result)
      pure result
-------------------------------------------------------------------------------- | |
-- With recursion-schemes, no caching -- | |
-------------------------------------------------------------------------------- | |
-- ldp's recursion structure is shaped like a Rose tree, so we can use a 'hylo' | |
-- on that recursive type. | |
-- | One layer of a rose tree: a label paired with a list of children.
type TreeF a = EnvT a []

type instance Base (Tree a) = TreeF a

-- These are orphan instances, which is why the module disables the orphan
-- warning.
-- NOTE(review): newer recursion-schemes releases appear to ship their own
-- instances for 'Tree', which would clash with these; confirm against the
-- version this module is built with.
instance Recursive (Tree a) where
  project (Node label children) = EnvT label children

instance Corecursive (Tree a) where
  embed (EnvT label children) = Node label children
-- | Longest decreasing path as a hylomorphism: 'divide' unfolds a position
-- into its sub-problems and 'conquer' folds the sub-results back together.
--
-- >>> ldpRs
-- Just [30,29,28,27,26,25,24,23,22,21,20,19,18,17,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1]
ldpRs :: Result
ldpRs = solveFrom start
  where
    -- or, equivalently:
    --   cata conquer . ana @(Tree Pos) divide
    solveFrom :: Pos -> Result
    solveFrom = hylo conquer divide
-- | Coalgebra: label a layer with the position and list the positions of its
-- sub-problems, as given by 'neighbours'.
divide :: Pos -> TreeF Pos Pos
divide here = EnvT here (neighbours here)
-- | Algebra: combine the already-computed results of the sub-problems into
-- the result for the position labelling this layer, via 'extendResult'.
conquer :: TreeF Pos Result -> Result
conquer (EnvT here subresults) = extendResult here subresults
-------------------------------------------------------------------------------- | |
-- With recursion-schemes and caching -- | |
-------------------------------------------------------------------------------- | |
-- Instead of using 'cata' to combine results, we can use 'cata' to combine | |
-- 'State' computations, so that the final computation computes the desired | |
-- result. The recently-added 'cataA' recursion scheme guides us in that | |
-- direction by allowing us to choose in which order we want to run the | |
-- sub-computations; or, in the case of caching, whether we want to run the | |
-- sub-computations at all! | |
-- | Longest decreasing path using 'cataA' to fold the search tree into a
-- single 'State' computation which threads the cache.
--
-- >>> ldpRsCached
-- Just [30,29,28,27,26,25,24,23,22,21,20,19,18,17,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1]
ldpRsCached :: Result
ldpRsCached = evalState search Map.empty
  where
    search :: State Cache Result
    search = cataA cachedConquer (ana @(Tree Pos) divide start)
-- | Caching variant of 'conquer': the children are 'State' computations
-- rather than results, so this layer decides whether to run them.
cachedConquer :: TreeF Pos (State Cache Result) -> State Cache Result
cachedConquer (EnvT ij subcomputations) = do
  cache <- get
  -- On a cache hit the sub-computations are not run at all.
  maybe solveAndStore pure (Map.lookup ij cache)
  where
    solveAndStore :: State Cache Result
    solveAndStore = do
      subresults <- sequenceA subcomputations
      let result = conquer (EnvT ij subresults)
      modify (Map.insert ij result)
      pure result
-------------------------------------------------------------------------------- | |
-- Capturing the pattern in a new recursion scheme -- | |
-------------------------------------------------------------------------------- | |
-- The idea of caching the results of a 'cata' in order to skip some of the | |
-- sub-trees is hardly unique to this problem, so it might be useful to capture | |
-- it in a new recursion scheme. I should probably add it to recursion-schemes! | |
-- | Longest decreasing path using the 'cachedCata' recursion scheme, keyed by
-- the position stored in each layer of the search tree.
--
-- >>> ldpCachedCata
-- Just [30,29,28,27,26,25,24,23,22,21,20,19,18,17,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1]
ldpCachedCata :: Result
ldpCachedCata = cachedCata ask sequenceA conquer searchTree
  where
    searchTree :: Tree Pos
    searchTree = ana @(Tree Pos) divide start
-- | A catamorphism which caches the result of every layer under a key and
-- skips the sub-computations of any layer whose key is already cached.
--
-- This generalizes from @Tree k@ to an arbitrary @t@, so the caller supplies:
--
-- * how to compute the key of a layer; the @forall x@ prevents that function
--   from looking at the sub-results;
-- * in which order to run the sub-computations when the key is not in the
--   cache; most of the time this will be 'sequenceA', but that is not
--   presumed;
-- * the algebra to fold with.
cachedCata
  :: forall t k a. (Recursive t, Ord k)
  => (forall x. Base t x -> k)
  -> (forall f x. Applicative f => Base t (f x) -> f (Base t x))
  -> (Base t a -> a)
  -> t
  -> a
cachedCata getKey sequenceEffects fAlgebra tree =
  evalState (cataA step tree) Map.empty
  where
    -- Answer from the cache when possible, otherwise solve the layer.
    step :: Base t (State (Map k a) a) -> State (Map k a) a
    step layer = do
      cache <- get
      maybe (solveAndStore layer) pure (Map.lookup (getKey layer) cache)

    -- Run the sub-computations, apply the algebra, and record the result.
    solveAndStore :: Base t (State (Map k a) a) -> State (Map k a) a
    solveAndStore layer = do
      solved <- fmap fAlgebra (sequenceEffects layer)
      modify (Map.insert (getKey layer) solved)
      pure solved
-------------------------------------------------------------------------------- | |
-- Capturing dynamic programming in a new recursion scheme -- | |
-------------------------------------------------------------------------------- | |
-- Dynamic programming is a specific use case for caching. So we can write a | |
-- specialized recursion-scheme which captures the idea that we can divide a | |
-- problem into sub-problems, and we can cache the result for all the | |
-- sub-problems in order to get better performance. | |
-- | |
-- This works out nicely, as I implemented 'extendResult' and 'neighbours' | |
-- because I wanted to reduce duplication in the other implementations, not | |
-- because I knew in advance that I wanted to implement a recursion scheme with | |
-- dyna's type! | |
-- | |
-- I should probably add 'dyna' to recursion-schemes as well, but under a | |
-- different name, as "dynamorphism" is already an established recursion scheme | |
-- (which isn't provided by recursion-schemes yet). | |
-- | Longest decreasing path using the 'dyna' recursion scheme: 'neighbours'
-- lists the sub-problems and 'extendResult' combines their solutions.
--
-- >>> ldpDyna
-- Just [30,29,28,27,26,25,24,23,22,21,20,19,18,17,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1]
ldpDyna :: Result
ldpDyna = solveFrom start
  where
    solveFrom :: Pos -> Result
    solveFrom = dyna extendResult neighbours
-- | Solve a problem by dividing it into sub-problems, caching the solution of
-- every sub-problem under the sub-problem itself (hence the 'Ord' constraint).
--
-- The first argument combines a problem with the solutions of its
-- sub-problems; the second lists the sub-problems of a problem; the third is
-- the problem to solve.
dyna
  :: forall f a b. (Traversable f, Ord a)
  => (a -> f b -> b)
  -> (a -> f a)
  -> a
  -> b
dyna solve fCoalgebra seed =
  cachedCata ask sequenceA combine (ana @(Fix (EnvT a f)) tag seed)
  where
    -- Algebra: hand the problem and its sub-solutions to the caller.
    combine :: EnvT a f b -> b
    combine (EnvT problem subsolutions) = solve problem subsolutions

    -- Coalgebra: keep the problem in the layer so it can serve as cache key.
    tag :: a -> EnvT a f a
    tag problem = EnvT problem (fCoalgebra problem)
-------------------------------------------------------------------------------- | |
-- Running the doctests -- | |
-------------------------------------------------------------------------------- | |
-- | Run the doctests found in this module's source file.
main :: IO ()
main = doctest sources
  where
    sources :: [String]
    sources = ["src/Dyna.hs"]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This is wicked cool, Sam! Does any of it work if you don't have `start` or `goal` available statically?