Created
May 13, 2017 18:31
-
-
Save tforgione/d24d4b279615a5bc29b101e766f4ca67 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <fstream>
#include <iostream>
#include <memory>

#include "math/Matrix.hpp"
/// Names the four children of a quadtree node, listed clockwise starting
/// from the top-left corner.
///
/// The numeric value of each enumerator is what QTree::get_branch casts to
/// std::size_t and uses as the index into the node's branch array, so the
/// values must stay in the contiguous range 0..3.
enum class BranchIndex
{
    TopLeft     = 0,
    TopRight    = 1,
    BottomRight = 2,
    BottomLeft  = 3
};
template<typename T, typename Value> | |
class QTree | |
{ | |
public: | |
template<typename BranchCallback, typename LeafCallback> | |
QTree(vl::MatrixView<T>& matrix, BranchCallback b, LeafCallback l) : | |
matrix(matrix), value(nullptr), branches() | |
{ | |
if (matrix.row_number() == 0 || matrix.col_number() == 0) { | |
value = std::make_unique<Value>(b(std::array<Value*, 4>{})); | |
} else if (matrix.row_number() == 1 && matrix.col_number() == 1) { | |
value = std::make_unique<Value>(l(matrix[0])); | |
} else { | |
std::size_t mid_height = matrix.row_number() / 2; | |
std::size_t mid_width = matrix.col_number() / 2; | |
std::size_t cpl_height = matrix.row_number() - mid_height; | |
std::size_t cpl_width = matrix.col_number() - mid_width; | |
std::array<vl::MatrixView<T>, 4> views {{ | |
matrix.submat(0, 0, mid_height, mid_width), | |
matrix.submat(mid_height, 0, cpl_height, mid_width), | |
matrix.submat(mid_height, mid_width, cpl_height, cpl_width), | |
matrix.submat(0, mid_width, mid_height, cpl_width) | |
}}; | |
std::transform( | |
std::begin(views), std::end(views), std::begin(branches), | |
[&b, &l](vl::MatrixView<T>& view) { | |
return std::make_unique<QTree<T, Value>>(view, b, l); | |
} | |
); | |
value = std::make_unique<Value>(b(std::array<Value*, 4>{{ | |
branches[0]->value.get(), | |
branches[1]->value.get(), | |
branches[2]->value.get(), | |
branches[3]->value.get() | |
}})); | |
} | |
} | |
public: | |
QTree& get_branch(BranchIndex b) { | |
return *branches[static_cast<std::size_t>(b)]; | |
} | |
vl::MatrixView<T> matrix; | |
std::unique_ptr<Value> value; | |
std::array<std::unique_ptr<QTree<T, Value>>, 4> branches; | |
}; | |
int main(int argc, char *argv[]) | |
{ | |
vl::Matrix<double> matrix(481, 321, 1.0); | |
vl::MatrixView<double> view{matrix}; | |
using Value = double; | |
auto branch_function = [] (std::array<Value*, 4> branches) -> Value { | |
Value v = 0.0; | |
for (auto const& branch : branches) { | |
if (branch != nullptr) { | |
v += *branch; | |
} | |
} | |
return v; | |
}; | |
auto leaf_function = [] (double i) { return i; }; | |
QTree<double, Value> tree(view, branch_function, leaf_function); | |
std::cout << *tree.value << std::endl; | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment