Last active: August 29, 2015 14:13
-
-
Save eickenberg/7e2e7e620b246ca34551 to your computer and use it in GitHub Desktop.
Maxpooling with arbitrary pooling strides and pooling shapes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Max pooling with arbitrary pooling strides and pooling shapes.
# Based on theano.tensor.signal.downsample.max_pool_2d: the non-overlapping
# pooling operation is repeated the minimum number of times necessary to
# account for all stride offsets.
# Author: Michael Eickenberg, michael.eickenberg@nsup.org
import numbers

import numpy as np
import theano
import theano.tensor as T
from theano.tensor.signal.downsample import max_pool_2d
# math.gcd (Python 3.5+) would also do; fractions.gcd was removed in 3.9, so a
# small local implementation keeps this module independent of the version.
def _gcd(num1, num2):
    """Return the greatest common divisor of two integers.

    Uses the modulo form of Euclid's algorithm on absolute values, so zero
    and negative arguments are handled: ``_gcd(0, n) == abs(n)`` and
    ``_gcd(0, 0) == 0``.
    """
    num1, num2 = abs(num1), abs(num2)
    while num2:
        num1, num2 = num2, num1 % num2
    return num1
def _lcm(num1, num2):
    """Return the least common multiple of two positive integers.

    Floor division keeps the result an ``int`` on Python 3; true division
    (``/``) would return a float, which ``range`` and slice steps reject.
    """
    return num1 * num2 // _gcd(num1, num2)
def _as_pair(value, name):
    """Normalize an int or a 1-/2-sequence of positive ints to a 2-tuple."""
    if isinstance(value, numbers.Number):
        value = (value,)
    value = tuple(value)
    if len(value) == 1:
        value = value * 2
    if len(value) != 2:
        raise ValueError("%s must have one or two entries, got %r"
                         % (name, value))
    if any(v <= 0 for v in value):
        raise ValueError("%s entries must be positive, got %r"
                         % (name, value))
    return value


def fancy_max_pool(input_tensor, pool_shape, pool_stride,
                   ignore_border=False):
    """Max-pool the last two axes with arbitrary pooling shapes and strides.

    ``max_pool_2d`` is called here without a stride argument, i.e. it pools
    non-overlapping windows. A general stride is emulated by pooling once
    per stride offset (up to the lcm of window and stride), subsampling each
    result and interleaving the pieces. Obviously suboptimal, but gets the
    work done.

    Parameters
    ----------
    input_tensor : theano tensor
        Pooling is applied to the last two axes. All leading axes are
        flattened into one batch axis (via ``reshape``) and restored at the
        end. Presumably ``ndim >= 2`` is required -- TODO confirm ndim == 2.
    pool_shape : int or sequence of one or two ints
        Pooling window (rows, cols); a single value is used for both axes.
    pool_stride : int or sequence of one or two ints
        Stride (rows, cols); a single value is used for both axes.
    ignore_border : bool
        Forwarded to every ``max_pool_2d`` call.

    Returns
    -------
    theano tensor
        Same ``ndim`` and dtype as ``input_tensor``.

    Raises
    ------
    ValueError
        If ``pool_shape``/``pool_stride`` do not describe two positive sizes.
    """
    pool_shape = _as_pair(pool_shape, "pool_shape")
    pool_stride = _as_pair(pool_stride, "pool_stride")

    # Along each axis, window start positions (multiples of the stride) and
    # the non-overlapping pooling grid (multiples of the window) realign
    # every lcm(window, stride) rows/cols. int() guards against an _lcm that
    # returns a float.
    lcm_h, lcm_w = [int(_lcm(p, s)) for p, s in zip(pool_shape, pool_stride)]
    # Within one pass, keep every (lcm / window)-th pooled window ...
    ds_h, ds_w = lcm_h // pool_shape[0], lcm_w // pool_shape[1]
    # ... and interleave the (lcm / stride) passes in the output.
    n_offsets_h, n_offsets_w = lcm_h // pool_stride[0], lcm_w // pool_stride[1]

    pre_shape = input_tensor.shape[:-2]
    length = T.prod(pre_shape)
    reshaped_input = input_tensor.reshape(
        T.concatenate([[length], input_tensor.shape[-2:]]), ndim=3)

    # sub_pools[i][j] is the pass starting at offset
    # (i * stride_rows, j * stride_cols).
    sub_pools = []
    for sh in range(0, lcm_h, pool_stride[0]):
        row = []
        for sw in range(0, lcm_w, pool_stride[1]):
            full_pool = max_pool_2d(reshaped_input[:, sh:, sw:],
                                    pool_shape, ignore_border=ignore_border)
            ds_pool = full_pool[:, ::ds_h, ::ds_w]
            row.append(ds_pool.reshape(
                T.concatenate([[length], ds_pool.shape[-2:]]), ndim=3))
        sub_pools.append(row)

    output_shape = (length,
                    T.sum([r[0].shape[1] for r in sub_pools]),
                    T.sum([piece.shape[2] for piece in sub_pools[0]]))
    # Match the input dtype; the floatX default may differ and make
    # set_subtensor reject the pieces.
    output = T.zeros(output_shape, dtype=input_tensor.dtype)
    for i, row in enumerate(sub_pools):
        for j, item in enumerate(row):
            output = T.set_subtensor(
                output[:, i::n_offsets_h, j::n_offsets_w], item)
    return output.reshape(T.concatenate([pre_shape, output.shape[1:]]),
                          ndim=input_tensor.ndim)
def _reference_max_pool(ndarray, pool_shape, pool_stride):
    """Max-pool the last two axes of a NumPy array with explicit loops.

    Only windows lying entirely inside the array are used (the
    ``ignore_border=True`` convention), so each pooled axis has
    ``max(0, (size - window) // stride + 1)`` entries. Leading axes are
    kept unchanged.
    """
    (pool_h, pool_w), (stride_h, stride_w) = pool_shape, pool_stride
    height, width = ndarray.shape[-2:]
    out_h = max(0, (height - pool_h) // stride_h + 1)
    out_w = max(0, (width - pool_w) // stride_w + 1)
    out = np.empty(ndarray.shape[:-2] + (out_h, out_w), dtype=ndarray.dtype)
    for i in range(out_h):
        top = i * stride_h
        for j in range(out_w):
            left = j * stride_w
            window = ndarray[..., top:top + pool_h, left:left + pool_w]
            out[..., i, j] = window.max(axis=(-2, -1))
    return out


def _test_fancy_max_pool(ndarray, pool_shape, pool_stride):
    """Check fancy_max_pool (ignore_border=True) against a NumPy reference.

    The reference used to be sklearn's ``extract_patches``, which was removed
    in scikit-learn 0.24; the plain NumPy loop above avoids that dependency.
    ``pool_shape``/``pool_stride`` may be an int or a 1-/2-sequence, as for
    ``fancy_max_pool``. Raises AssertionError on mismatch.
    """
    from numpy.testing import assert_array_equal

    def _pair(value):
        # Same normalization fancy_max_pool applies to its arguments.
        if isinstance(value, numbers.Number):
            return (value, value)
        value = tuple(value)
        return value * 2 if len(value) == 1 else value

    pool_shape, pool_stride = _pair(pool_shape), _pair(pool_stride)
    shared = theano.shared(ndarray)
    pooled = fancy_max_pool(shared, pool_shape, pool_stride,
                            ignore_border=True).eval()
    expected = _reference_max_pool(ndarray, pool_shape, pool_stride)
    assert_array_equal(pooled, expected)
def test_fancy_max_pool(array_shapes=((1, 1, 5, 7), (2, 3, 8, 4)),
                        pool_shapes=((2, 2), (3, 3), (2, 3), (3, 2)),
                        pool_strides=((1, 1), (2, 2), (3, 3),
                                      (1, 2), (2, 1)),
                        random_seed=42, num_tests=2):
    """Run _test_fancy_max_pool over a grid of random configurations.

    Every (array shape, pooling shape, pooling stride) combination is
    checked ``num_tests`` times. Each check uses a fresh standard-normal
    array drawn from one generator seeded with ``random_seed``, so the
    whole run is reproducible.
    """
    from itertools import product

    rng = np.random.RandomState(random_seed)
    for _ in range(num_tests):
        grid = product(array_shapes, pool_shapes, pool_strides)
        for shape, window, stride in grid:
            _test_fancy_max_pool(rng.randn(*shape), window, stride)
if __name__ == "__main__":
    # Running the module directly executes the randomized comparison tests.
    test_fancy_max_pool()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.