Skip to content

Instantly share code, notes, and snippets.

@dpwright
Last active June 21, 2018 03:09
Show Gist options
  • Save dpwright/6474401 to your computer and use it in GitHub Desktop.
Save dpwright/6474401 to your computer and use it in GitHub Desktop.
Monadic operations in C++

Monadic operations in C++

This began as a further attempt to implement the Maybe monad in C++, but quickly spiralled out of control and now includes an implementation of the List monad as well (using std::list!). This is really for my own amusement rather than an attempt to do anything useful with them. It also gave me an excuse to try out C++11 lambda functions for the first time.

My original implementation defined a macro called MBind, which caused a number of problems -- thankfully, PJayB managed to find a way around that, so this version uses C++ lambdas directly (take a look at the changelog if you want to see the changes made).

Please bear in mind I made this purely out of curiosity -- I realise a lot of this is pathological in C++.

I've also included the equivalent Haskell code, so you can see what I based the syntax on. In Haskell you'd probably use `do` notation for a case like this, but I stuck with binding lambda functions manually so that the similarity with the C++ is more obvious.

Next up... The State monad!

//listm.h
//
//The List monad.
//Provide a way to build up lists over a sequence of operations. Adds monadic
//functionality to std::list
#pragma once
#include <list>
#include <type_traits>
#include "monad.h"
/// List-monad "unit" (Haskell's return/pure): wrap one value in a
/// one-element std::list.
///
/// std::list needs its own specialization. The primary Monad template's
/// m<a>(value) does not mean "wrap one value" for std::list:
///  - for an arithmetic `value` it would build a list of that many
///    value-initialized elements;
///  - for most class types it would not compile.
template<> struct Monad<std::list>
{
    /// @param value  the element to wrap
    /// @return a std::list<a> whose only element is `value`. Returned as a
    ///         plain (non-const) value, matching the primary template's m<a>
    ///         return type, so callers are free to move from it.
    template<typename a> static std::list<a> unit(a value)
    {
        return std::list<a>(1, value);
    }
};
/// List-monad bind (Haskell's >>=): apply `f` to every element of `in` and
/// concatenate the resulting lists, preserving order.
///
/// @param in  the source list
/// @param f   callable taking an element of `in` and returning a std::list
///            (possibly const-qualified) of results
/// @return    f(e0) ++ f(e1) ++ ... for the elements e0, e1, ... of `in`;
///            an empty list when `in` is empty
///
/// `f` is taken by const lvalue reference. This still binds the inline
/// lambdas (temporaries) used at call sites, and unlike `const b&&` it also
/// accepts named function objects and plain functions.
template<typename a, typename b>
auto operator>>=(const std::list<a>& in, const b& f) -> decltype(f(in.front()))
{
    // f may return a const-qualified list; strip the const so the
    // accumulator can be spliced into.
    typedef typename std::remove_const<decltype(f(in.front()))>::type tmpList;
    tmpList out;
    for (const a& element : in)
    {
        tmpList current = f(element);
        // Whole-list splice relinks the nodes in O(1); no elements are copied.
        out.splice(out.end(), current);
    }
    return out;
}
//maybe.h
//
//The Maybe type.
//Provides a safer way to return possible failure than NULL values. Of course,
//since C++ doesn't have pattern matching, it is down to you to ensure that you
//always call isJust() to confirm there is a value available before calling
//fromJust() to retrieve it.
#pragma once
#include <cassert>
/// An optional value: either Just(value) or Nothing.
///
/// C++ has no pattern matching, so callers must check isJust() before calling
/// fromJust(); fromJust() asserts this precondition.
template<typename a> class Maybe
{
public:
    /// Build a Maybe that holds `value`.
    static Maybe<a> Just(const a value) { return Maybe(value); }

    /// Build an empty Maybe.
    static Maybe<a> Nothing() { return Maybe(); }

    /// True if this Maybe holds a value.
    bool isJust() const { return m_valid; }

    /// The held value. Precondition: isJust() (checked with assert).
    a fromJust() const
    {
        assert(isJust());
        return m_value;
    }

private:
    // m_value is value-initialized (e.g. 0 for int) so that copying a Nothing
    // never reads an indeterminate value.
    Maybe() : m_value(), m_valid(false) {}
    explicit Maybe(const a value) : m_value(value), m_valid(true) {}

    a m_value;
    // Non-const on purpose: a const member would delete copy/move assignment.
    bool m_valid;
};
/// Free-function shorthand for Maybe<a>::Just. It lets the element type be
/// deduced from the argument: Just(5) yields a Maybe<int>.
template<typename a>
const Maybe<a> Just(const a value)
{
    typedef Maybe<a> maybe_type;
    return maybe_type::Just(value);
}
/// Free-function shorthand for Maybe<a>::Nothing. There is no argument to
/// deduce `a` from, so it must be spelled out: Nothing<int>().
template<typename a>
const Maybe<a> Nothing()
{
    typedef Maybe<a> maybe_type;
    return maybe_type::Nothing();
}
//maybem.h
//
//The Maybe monad.
//Adds monadic behaviour to Maybe. Allows you to sequence a number of
//operations of type Maybe<a>, and drop out at the first point of failure.
#pragma once
#include "monad.h"
#include "maybe.h"
/// Maybe-monad "unit" (Haskell's return/pure): wrap a plain value as Just.
///
/// A specialization is required because Maybe's constructors are private, so
/// the primary Monad template's m<a>(value) would not compile for Maybe.
template<> struct Monad<Maybe>
{
    /// @param value  the value to wrap
    /// @return Just(value), as a plain (non-const) Maybe<a> to match the
    ///         primary template's m<a> return type.
    template<typename a> static Maybe<a> unit(a value)
    {
        return Just(value);
    }
};
/// Maybe-monad bind (Haskell's >>=).
///
/// If `in` holds a value, pass it to `f` and return f's result. Otherwise
/// short-circuit to Nothing without calling `f`.
///
/// @param in  the Maybe being bound
/// @param f   callable taking the value held by `in` and returning a Maybe
/// @return    f(in.fromJust()) when in.isJust(), else Nothing of f's result type
///
/// Both parameters are const lvalue references. They bind the temporaries used
/// at call sites exactly as before, and (unlike const rvalue references) also
/// named Maybe values and named function objects.
template<typename a, typename b>
auto operator>>=(const Maybe<a>& in, const b& f) -> decltype(f(in.fromJust()))
{
    typedef decltype(f(in.fromJust())) maybeType;
    if (!in.isJust())
        return maybeType::Nothing();  // failure propagates; f is never called
    return f(in.fromJust());
}
//monad.h
//
//Defines the two monad operations. Every class which wants to be treated as a
//monad must provide either a specialization of the Monad class or a constructor
//which takes a single parameter of the type of the value it's wrapping. As well
//as that, it must also overload operator>>= for binding purposes.
#pragma once
/// Primary Monad template. It supplies the monadic "unit" (Haskell's
/// return/pure) for any m<a> that can be direct-initialized from one `a`.
/// Types whose one-argument constructor means something else should
/// specialize Monad instead.
template <template <typename a, typename...> class m>
struct Monad
{
    /// Wrap `value` by direct-initializing an m<a> from it.
    template <typename a>
    static m<a> unit(a value)
    {
        typedef m<a> wrapped_type;
        return wrapped_type(value);
    }
};
#include <iostream>
#include <cmath>
#include "monad.h"
#include "maybem.h"
#include "listm.h"
#include "show.h"
using namespace std;
/// Square `x`, or return Nothing when |x| >= 0x8000.
///
/// The bound presumably keeps the square within a 32-bit int
/// (0x7FFF * 0x7FFF == 0x3FFF0001).
///
/// The range test uses two comparisons rather than abs(x) >= 0x8000: abs of the
/// most negative int overflows (undefined behaviour), and the square that
/// followed overflowed as well. For every other argument the result is
/// unchanged. Intended for signed arithmetic types.
template<typename a> const Maybe<a> imSquare(a x) {
    const bool outOfRange = (x >= 0x8000) || (x <= -0x8000);
    return outOfRange ? Maybe<a>::Nothing() : Maybe<a>::Just(x * x);
}
/// C++ transliteration of the Haskell testMaybe. It chains Maybe binds:
///  1. square 0x4000;
///  2. divide the square by the original value;
///  3. bind (and ignore) 0.5f;
///  4. square the quotient.
/// Every step stays under imSquare's bound, so the result is Just(0x4000 * 0x4000).
const Maybe<int> testMaybe()
{
    return Monad<Maybe>::unit(0x4000) >>= [] (const int a) {
        return imSquare(a) >>= [a] (const int b) {
            const int c = b / a;
            // The float is bound only to mirror Haskell's `\d ->`; it is unused.
            return Monad<Maybe>::unit(0.5f) >>= [c] (const float /* d */) {
                return imSquare(c);
            };
        };
    };
}
/// Cartesian product of two lists via the List monad. Mirrors the Haskell
///   xs >>= \x -> ys >>= \y -> return (x, y)
///
/// @param xs  first-component values
/// @param ys  second-component values
/// @return    every (x, y) with x from xs and y from ys, ordered by x and then
///            by y; xs.size() * ys.size() pairs in total
///
/// Both lists are taken by const reference so they are not copied per call.
template<typename a, typename b>
const list< pair<a, b> > cartesianProduct(const list<a>& xs, const list<b>& ys)
{
    return xs >>= [&] (const a x) {
        return ys >>= [&] (const b y) {
            return Monad<list>::unit(pair<a, b>(x, y));
        };
    };
}
/// C++ transliteration of the Haskell testList. It computes the cartesian
/// product of xs = [1..5] and ys = map sq xs.
const list< pair<int, int> > testList()
{
    static const int LIST_SIZE=5;

    list<int> xs_list;
    list<int> ys_list;
    for (int n = 1; n <= LIST_SIZE; ++n)
    {
        xs_list.push_back(n);
        ys_list.push_back(n * n);  // sq n
    }
    return cartesianProduct(xs_list, ys_list);
}
/// Print the Maybe demo result, then the List demo result, one per line.
int main(int argc, char** argv)
{
    (void)argc;  // command-line arguments are not used
    (void)argv;

    const auto maybeText = show(testMaybe());
    cout << maybeText << endl;

    const auto listText = show(testList());
    cout << listText << endl;

    return 0;
}
-- | Square x, or Nothing when abs x >= 0x8000.
-- The guard continuation must be indented; at column 0 it does not parse.
imSquare x | abs x >= 0x8000 = Nothing
           | otherwise       = Just (x * x)
-- | Chain of Maybe binds:
--   1. square 0x4000;
--   2. divide the square by the original value;
--   3. bind (and ignore) 0.5;
--   4. square the quotient.
-- Continuation lines are indented so they belong to this declaration.
testMaybe =
  return 0x4000 >>= \a ->
  imSquare a >>= \b ->
  let c = b `div` a in
  return 0.5 >>= \d ->
  imSquare c
-- | Every pair (x, y) with x from xs and y from ys, built with the list monad.
cartesianProduct xs ys =
  xs >>= \x ->
  ys >>= \y ->
  return (x, y)
-- | Cartesian product of [1..5] with the squares of [1..5].
-- The where-bindings must be indented and aligned to parse.
testList = cartesianProduct xs ys
  where xs = [1..5]
        ys = map sq xs
        sq x = x * x
-- | Print both demo results, one per line.
main = do
  print testMaybe
  print testList
//show.h
//
//Defines the "show" operation. Defaults to using the stream operator.
#pragma once
#include <sstream>
/// Default textual rendering: format `x` with its stream insertion operator.
template<typename t> std::string show(t x)
{
    std::ostringstream rendered;
    rendered << x;
    return rendered.str();
}
/// Render a Maybe Haskell-style: "Nothing", or "Just " followed by the held
/// value formatted with its stream insertion operator.
template<typename a> std::string show(Maybe<a> x)
{
    if (!x.isJust())
        return "Nothing";

    std::ostringstream rendered;
    rendered << "Just " << x.fromJust();
    return rendered.str();
}
/// Render a pair Haskell-style as "(first,second)", streaming each member
/// with its own operator<<.
template<typename a, typename b> std::string show(std::pair<a, b> x)
{
    std::ostringstream rendered;
    rendered << "(" << x.first;
    rendered << "," << x.second << ")";
    return rendered.str();
}
/// Render a list Haskell-style as "[e1,e2,...]"; an empty list gives "[]".
/// Each element is formatted with show(), so pairs, Maybes and nested lists
/// use their own renderings.
///
/// Taken by const reference so that a (possibly large, possibly nested) list
/// is not deep-copied at every level of the recursive show() calls.
template<typename a> std::string show(const std::list<a>& x)
{
    std::ostringstream out;
    out << "[";
    bool first = true;
    for (const a& element : x)
    {
        // The separator precedes every element except the first, which avoids
        // recomputing the last-element iterator on each iteration.
        if (!first)
            out << ",";
        out << show(element);
        first = false;
    }
    out << "]";
    return out.str();
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment