Last active: October 26, 2023 16:11
-
-
Save Gridflare/8ad002cb2a915435c4d612f7f8a10726 to your computer and use it in GitHub Desktop.
NumPy roll with padding
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
def _pad_and_trim(shift):
    """Return ``(pad_width, trim_slice)`` that shift one axis by ``shift``.

    A positive shift pads ``shift`` entries before the data and trims the same
    number from the end, so content moves towards higher indices (the same
    direction as ``np.roll``). A negative shift pads after the data and trims
    from the start. A zero shift pads nothing and keeps the whole axis.
    """
    if shift > 0:
        return (shift, 0), slice(0, -shift)
    return (0, -shift), slice(-shift, None)


def rollpad(array, rollx, rolly, mode='edge', **pad_kwargs):
    """Analogous to numpy.roll but using np.pad for greater flexibility.

    Instead of wrapping elements that fall off one edge around to the other
    edge, the vacated cells are filled by ``np.pad`` according to ``mode``.

    Parameters
    ----------
    array : array_like
        Input with at least two dimensions. The shift is applied to the last
        two axes (rows, then columns); any leading axes are left untouched.
    rollx : int
        Shift along the last axis (columns). Positive values move content
        towards higher column indices.
    rolly : int
        Shift along the second-to-last axis (rows). Positive values move
        content towards higher row indices.
    mode : str or callable, optional
        Padding mode forwarded to ``np.pad`` (default ``'edge'``).
    **pad_kwargs
        Extra keyword arguments forwarded to ``np.pad``, e.g.
        ``constant_values`` when ``mode='constant'``.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``array``.

    Raises
    ------
    ValueError
        If ``array`` has fewer than two dimensions.
    """
    # Accept lists/tuples too; the original crashed on ``array.shape`` for them.
    array = np.asarray(array)
    if array.ndim < 2:
        raise ValueError(
            f"rollpad needs an array with at least 2 dimensions, got {array.ndim}"
        )

    xpad, xtrim = _pad_and_trim(rollx)
    ypad, ytrim = _pad_and_trim(rolly)

    # Leading axes (if any) get zero padding and are kept whole, so an N-d
    # stack is shifted slice-by-slice on its last two axes.
    leading = ((0, 0),) * (array.ndim - 2)
    padded = np.pad(array, leading + (ypad, xpad), mode=mode, **pad_kwargs)
    trimmed = padded[..., ytrim, xtrim]

    # Internal invariant: padding and trimming by the same amount per axis
    # must restore the original shape.
    assert trimmed.shape == array.shape
    return trimmed
if __name__ == '__main__':
    demo = np.arange(16).reshape(4, 4)
    print(demo)

    # (label, rollx, rolly) cases: every sign combination, multi-step
    # shifts, and a zero shift on each axis in turn.
    cases = (
        ('+x +y', 1, 1),
        ('-x -y', -1, -1),
        ('-x +y', -1, 1),
        ('+x -y', 1, -1),
        ('+2x -y', 2, -1),
        ('-2x -y', -2, -1),
        ('+x -2y', 1, -2),
        ('+x +2y', 1, 2),
        ('0x, +y', 0, 1),
        ('-x, 0y', -1, 0),
    )
    for label, dx, dy in cases:
        print(label)
        print(rollpad(demo, dx, dy))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.