Last active
August 13, 2020 01:43
Convert a batch of sequence lengths into a boolean mask array
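from typing import Optional

import numpy as np


# lens = [3, 5, 4]
# What we want (returned as a boolean array; 1 = True, 0 = False):
# mask = [[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 0]]
# https://stackoverflow.com/questions/53403306/how-to-batch-convert-sentence-lengths-to-masks-in-pytorch
def len_to_mask(lens: np.ndarray, seq_len: Optional[int] = None) -> np.ndarray:
    lens = np.asarray(lens)  # also accept plain Python lists
    if seq_len is None:
        seq_len = int(lens.max())  # default to the longest sequence
    # Broadcast positions (1, seq_len) against lengths (batch, 1).
    return np.arange(seq_len)[None, :] < lens[:, None]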
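
A minimal usage sketch, assuming NumPy is installed. The function is repeated in compact form so the snippet runs on its own. The `batch` array and the masked-mean step are hypothetical illustrations of what such a mask is typically used for; they are not part of the original gist.

```python
import numpy as np


def len_to_mask(lens, seq_len=None):
    # Same broadcasting comparison as above, repeated so this runs standalone.
    lens = np.asarray(lens)
    if seq_len is None:
        seq_len = int(lens.max())
    return np.arange(seq_len)[None, :] < lens[:, None]


lens = np.array([3, 5, 4])
mask = len_to_mask(lens)  # boolean array of shape (3, 5)

# Hypothetical padded batch: row i holds lens[i] real values, then zeros.
batch = np.arange(15, dtype=float).reshape(3, 5) * mask

# Masked mean: sum only the valid positions, divide by each true length.
masked_mean = (batch * mask).sum(axis=1) / lens

# Passing a larger seq_len pads the mask with extra False columns.
wide_mask = len_to_mask(lens, seq_len=7)  # shape (3, 7)
```

The broadcast compares a row of positions `0..seq_len-1` against a column of lengths, so position `j` in row `i` is `True` exactly when `j < lens[i]`. Call `mask.astype(int)` if a 0/1 array is needed instead of booleans.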