@ahue
Created April 14, 2021 20:58
Filter a 3d array along its first dimension using a 2d index matrix to obtain a 2d array
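The rule the function implements is res[i, j] = tensor[mx[i, j], i, j]: each entry of the 2d index matrix mx picks which slice along the first axis supplies the value at that cell. A plain-loop sketch makes the rule explicit. The name filter_tensor_by_mx_loop is hypothetical, and the example uses the same 2x4x3 values as mxt below but a fixed (not random) index matrix so the output is reproducible.

```python
import numpy as np


def filter_tensor_by_mx_loop(tensor, mx):
    # Hypothetical reference version: for every cell (i, j), read
    # tensor[mx[i, j], i, j], i.e. mx chooses the slice along axis 0
    res = np.empty(mx.shape, dtype=tensor.dtype)
    for i in range(mx.shape[0]):
        for j in range(mx.shape[1]):
            res[i, j] = tensor[mx[i, j], i, j]
    return res


tensor = np.arange(1, 25).reshape(2, 4, 3)  # same values as mxt below
mx = np.array([[0, 1, 0], [1, 1, 0], [0, 0, 1], [1, 0, 0]])
result = filter_tensor_by_mx_loop(tensor, mx)
print(result)
# [[ 1 14  3]
#  [16 17  6]
#  [ 7  8 21]
#  [22 11 12]]
```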
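import numpy as np


def filter_tensor_by_mx(tensor, mx):
    """Return res with res[i, j] = tensor[mx[i, j], i, j] (mx selects along axis 0)."""
    row_ix = np.arange(mx.shape[0])
    col_ix = np.arange(mx.shape[1])
    # All (row, col) pairs in row-major order: ix[0] holds rows, ix[1] holds columns
    ix = np.array(np.meshgrid(row_ix, col_ix)).T.reshape(-1, 2).T
    # Flatten mx in the same row-major order so each slice index lines up with its (row, col)
    pane = mx.ravel()
    # Fancy indexing gathers one element per cell; reshape back to mx's 2d shape
    res = tensor[pane, ix[0], ix[1]].reshape(mx.shape)
    return res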
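
# 4x3 index matrix of random 0/1 values: each entry picks the axis-0 slice for that cell
rnd_ix_mx = np.random.randint(0, 2, size=(4, 3))
print(rnd_ix_mx)
# 2x4x3 tensor: two slices of shape 4x3
mxt = np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]],
                [[13, 14, 15], [16, 17, 18], [19, 20, 21], [22, 23, 24]]])
print(mxt)
print(filter_tensor_by_mx(mxt, rnd_ix_mx))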
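The meshgrid step only exists to produce a row index and a column index for every cell. A minimal sketch of two alternatives, assuming NumPy 1.15 or later for np.take_along_axis: broadcasting a column vector of row indices against a row vector of column indices gives the same pairing without building the (2, n*m) array, and np.take_along_axis expresses the same gather along axis 0 directly. Both function names are hypothetical.

```python
import numpy as np


def filter_tensor_broadcast(tensor, mx):
    # rows has shape (n, 1) and cols has shape (1, m); both broadcast against
    # mx (n, m), so output cell (i, j) reads tensor[mx[i, j], i, j]
    rows = np.arange(mx.shape[0])[:, None]
    cols = np.arange(mx.shape[1])[None, :]
    return tensor[mx, rows, cols]


def filter_tensor_take(tensor, mx):
    # take_along_axis needs an index array with the same ndim as tensor,
    # so add a leading length-1 axis, gather along axis 0, then drop that axis
    return np.take_along_axis(tensor, mx[None, :, :], axis=0)[0]


tensor = np.arange(1, 25).reshape(2, 4, 3)
mx = np.array([[0, 1, 0], [1, 1, 0], [0, 0, 1], [1, 0, 0]])
a = filter_tensor_broadcast(tensor, mx)
b = filter_tensor_take(tensor, mx)
print(np.array_equal(a, b))  # True
```

The broadcast version keeps the original's fancy-indexing idea but lets NumPy pair the indices; take_along_axis reads more directly when the selection axis might change.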