Skip to content

Instantly share code, notes, and snippets.

@pao
Last active July 29, 2019 15:15
Show Gist options
  • Save pao/3677645 to your computer and use it in GitHub Desktop.
Save pao/3677645 to your computer and use it in GitHub Desktop.
Work-in-progress, overkill solution to the final 7L7W exercise

Now being developed at https://github.com/pao/Monads.jl and available from Pkg as "Monads".

Full documentation is available at https://monadsjl.rtfd.org/.

module Monads

# types
export Monad, Identity, MList, Maybe, State
# combinators
export mreturn, join, fmap, bind, mcomp, mthen, (>>)
# utilities
export liftM
# do syntax
export @mdo
# MonadPlus
export MonadPlus, mzero, mplus, guard
# State
export runState, put, get, evalState, execState

import Base.*

# Root of the type hierarchy: every monad defined in this module subtypes `Monad`,
# and the generic combinators below dispatch on it.
abstract Monad
# Monads that also supply a zero element and a combining operation
# (`mzero` / `mplus`); `guard` dispatches on this subtype.
abstract MonadPlus <: Monad

## Buy two monad combinators, get the third free!
# The generic `join`, `fmap` and `bind` below are written in terms of one another.
# A concrete monad therefore has to override either `bind`, or both `join` and
# `fmap`, for its own type; with no override the generic methods just call each
# other in a cycle.

# Wrap a plain value in monad type `M` by calling `M`'s constructor.
function mreturn{M<:Monad}(::Type{M}, val)
    M(val)
end

# Generic flattening, expressed as a bind with the identity function.
function join(m::Monad)
    bind(identity, m)
end

# Generic mapping: bind over `m`, apply `f`, and re-wrap the result via `mreturn`
# using `m`'s own type `M`.
function fmap{M<:Monad}(f::Function, m::M)
    bind(x -> mreturn(M, f(x)), m)
end

# Generic bind: map `f` over `m`, then flatten the result with `join`.
function bind(f::Function, m::Monad)
    join(fmap(f, m))
end

## Extra combinators
# Compose two monad-producing functions: the returned closure runs `f` on its
# argument and then binds `g` over what `f` produced.
function mcomp(g::Function, f::Function)
    x -> bind(g, f(x))
end

# Sequencing: bind over `m`, ignore whatever value it passes along, and
# continue with `k`.
function mthen(k::Monad, m::Monad)
    bind(m) do _
        k
    end
end

# Operator form of `mthen` with the operands in reading order: `m >> k`.
(>>)(m::Monad, k::Monad) = mthen(k, m)

## A MonadPlus function
# When the condition `c` holds, produce `nothing` wrapped in `M` via `mreturn`;
# otherwise produce `M`'s zero element via `mzero`.
function guard{M<:MonadPlus}(::Type{M}, c::Bool)
    if c
        mreturn(M, nothing)
    else
        mzero(M)
    end
end

## Friendly monad blocks
# `@mdo MType begin ... end` rewrites the statements of the block into nested
# monadic calls. `mtype` is first threaded into the calls that need it
# (`mdo_patch`), then the statement list is folded into a single expression
# (`mdo_desugar`). The whole result is escaped, so every name in the expansion
# is resolved in the scope where the macro is invoked.
macro mdo(mtype, body)
    patched = mdo_patch(mtype, body)
    desugared = mdo_desugar(patched)
    esc(desugared)
end

## patch up functions to insert the right monad
# Anything that is not an `Expr` (symbols, literals, ...) is passed through untouched.
mdo_patch(mtype, expr) = expr

# Recursively rewrite `expr` in place:
#   * `return x` becomes the call `mreturn(x)`;
#   * calls to `mreturn`, `mzero`, `guard` and `liftM` get `mtype` inserted as
#     their first argument.
# The `return` rewrite runs first so that the freshly created `mreturn` call is
# picked up by the second step as well. Returns the (mutated) `expr`.
function mdo_patch(mtype, expr::Expr)
    # Children first, so nested calls are already patched.
    expr.args = map(expr.args) do child
        mdo_patch(mtype, child)
    end

    if expr.head == :return
        expr.head = :call
        insert(expr.args, 1, :mreturn)
    end

    # Functions whose first argument must be the monad type.
    typed_calls = [:mreturn, :mzero, :guard, :liftM]
    if expr.head == :call && any(expr.args[1] .== typed_calls)
        insert(expr.args, 2, mtype)
    end

    expr
end

## desugaring mdo syntax is a right fold
# The statements of the block (`exprIn.args`) are visited last-to-first; the
# accumulator `rest` holds the code already built for all later statements and
# starts out as the empty expression `:()`.
mdo_desugar(exprIn) = reduce(mdo_desugar_helper, :(), reverse(exprIn.args))
# Statements that are not `Expr`s contribute nothing: the accumulator is returned as is.
mdo_desugar_helper(rest, expr) = rest
# Fold step for one `Expr` statement: combine `expr` with `rest` (the code for
# everything that follows it) according to the kind of statement.
function mdo_desugar_helper(rest, expr::Expr)
    if expr.head == :call && expr.args[1] == :(<-)
        # replace "<-" with monadic binding
        # args[2] is the bound name, args[3] the monadic value; the remaining
        # statements become the body of the function handed to `bind`.
        quote
            bind($(expr.args[3])) do $(expr.args[2])
                $rest
            end
        end
    elseif expr.head == :(=)
        # replace assignment with let binding
        # the assignment and the remaining statements share one `let` scope
        quote
            let
                $expr
                $rest
            end
        end
    elseif expr.head == :line
        # line-number bookkeeping nodes are dropped
        rest
    elseif rest == :()
        # nothing has been accumulated yet, so this statement is used as is
        expr
    else
        # replace with sequencing
        :(mthen($rest, $expr))
    end
end

## Function lifting
# Turn a plain one-argument function `f` into a function on monadic values of
# type `M`: the returned closure binds over its argument `m1`, applies `f` to the
# value it receives, and re-wraps the result with `mreturn(M, ...)`.
# This is the hand-written form of
#     @mdo M begin; x1 <- m1; return f(x1); end
function liftM{M<:Monad}(::Type{M}, f::Function)
    m1 -> bind(x1 -> mreturn(M, f(x1)), m1)
end

## Starting slow: Identity
# A monad that only stores a value.
type Identity{T} <: Monad
    value::T
end

# Binding passes the stored value straight to `f` and returns whatever `f` returns.
function bind(f::Function, m::Identity)
    f(m.value)
end

## List
# List monad: a wrapper around a `Vector` with no element-type restriction.
type MList <: MonadPlus
    value::Vector
end

# Convenience constructor: wrap a single value `v` as a one-element list.
# NOTE(review): a `Vector` argument presumably still reaches the field
# constructor directly, since that method is more specific; confirm.
function MList(v)
    MList([v])
end

# Flatten an MList whose elements are themselves MLists into a single MList
# holding the elements of every non-empty sublist, in order.
#
# Returns `mzero(MList)` when `m` is empty or every sublist is empty; otherwise
# the concatenation wrapped through `mreturn(MList, ...)`.
#
# The accumulator is a *copy* of the first non-empty sublist's vector. Appending
# straight into `arr.value` would grow that sublist's own storage, so that e.g.
# `mplus(m1, m2)` would silently modify `m1`.
function join(m::MList)
    val = nothing
    for arr in m.value
        if !isempty(arr.value)
            if val === nothing
                # start from a fresh vector so the inputs are never mutated or aliased
                val = copy(arr.value)
            else
                append!(val, arr.value)
            end
        end
    end
    # `val` is still `nothing` when there was nothing to collect
    if val === nothing
        mzero(MList)
    else
        mreturn(MList, val)
    end
end
# Apply `f` to every element and wrap the mapped collection in a new MList.
function fmap(f::Function, m::MList)
    mapped = map(f, m.value)
    MList(mapped)
end

# It's also a MonadPlus
# The zero element is the empty list.
function mzero(::Type{MList})
    MList([])
end

# Combine two lists by nesting both in an outer MList and flattening it with `join`.
function mplus(m1::MList, m2::MList)
    nested = MList([m1, m2])
    join(nested)
end

## Maybe
# Holds either a value of type `T` or `nothing`.
type Maybe{T} <: Monad
    value::Union(T, Nothing)
end

# If `m` holds `nothing`, short-circuit with `Maybe(nothing)` and never call `f`;
# otherwise pass the stored value to `f` and return its result.
function bind(f::Function, m::Maybe)
    if isa(m.value, Nothing)
        Maybe(nothing)
    else
        f(m.value)
    end
end

## State
# A stateful computation, stored as a function from a state to a
# (result, new state) pair.
type State <: Monad
    runState::Function # s -> (a, s)
end

# Lower-case constructor alias for `State`.
function state(f)
    State(f)
end

# One-argument form: return the wrapped state-transition function itself.
function runState(s::State)
    s.runState
end

# Two-argument form: run the wrapped function on the state `st`.
function runState(s::State, st)
    transition = s.runState
    transition(st)
end

# Chain two stateful computations: run `s` on the incoming state, pass its
# result to `f` to obtain the next computation, and run that on the state
# that `s` produced.
function bind(f::Function, s::State)
    state() do st
        (result, next_state) = runState(s, st)
        runState(f(result), next_state)
    end
end

# Lift a plain value into `State`: yields `x` and leaves the state untouched.
function mreturn(::Type{State}, x)
    state(st -> (x, st))
end

# Computation that discards the incoming state, installs `newState`, and
# yields `nothing` as its result.
function put(newState)
    state(_ -> (nothing, newState))
end

# Computation that yields the current state as its result and leaves it unchanged.
function get()
    state(st -> (st, st))
end

# Run `s` from the initial state `st` and keep only the result (first component).
function evalState(s::State, st)
    outcome = runState(s, st)
    outcome[1]
end

# Run `s` from the initial state `st` and keep only the final state (second component).
function execState(s::State, st)
    outcome = runState(s, st)
    outcome[2]
end

end
@tautologico
Copy link

When I try

@mdo begin
  a <- Maybe(2)
  return a
end

I get an error saying @mdo is missing an argument. Trying instead

@mdo Maybe begin
a <- Maybe(2)
return a
end

works. However, if I try to do any operation on the values of a and b, like

@mdo begin
a <- Maybe(2)
b <- Maybe(3)
return a + 2*b
end

Then it complains about missing methods. Am I doing something wrong?

@pao
Copy link
Author

pao commented Oct 10, 2012

Should all be fixed now.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment