Skip to content

Instantly share code, notes, and snippets.

@dayyass
Created October 13, 2021 12:04
Show Gist options
  • Save dayyass/b64cdf552576b64c282c3cf06f39b062 to your computer and use it in GitHub Desktop.
Save dayyass/b64cdf552576b64c282c3cf06f39b062 to your computer and use it in GitHub Desktop.
Inverse function for torch.nn.utils.rnn.pack_sequence.
import torch
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence
def unpack_sequence(packed_sequences):
    """Unpack a PackedSequence into a list of variable-length Tensors.

    This is the inverse of ``torch.nn.utils.rnn.pack_sequence``.

    Args:
        packed_sequences: the ``PackedSequence`` to unpack.

    Returns:
        A list with one tensor per sequence, each of shape ``(length_i, *)``,
        in the order the sequences were originally packed. Each tensor is a
        slice of a freshly padded tensor, not of the packed input's storage.
    """
    # pad_packed_sequence restores the original (pre-sort) order and returns
    # each sequence's true length; batch_first=True puts sequences on dim 0.
    padded_sequences, lengths = pad_packed_sequence(packed_sequences, batch_first=True)
    # A prefix slice per row: no mask allocation and no CPU/GPU device mixing,
    # unlike boolean indexing with a CPU ``arange``-based mask.
    return [seq[:length] for seq, length in zip(padded_sequences, lengths.tolist())]
# Round-trip demo: pack three variable-length sequences and check that
# unpack_sequence recovers them exactly, in their original order.
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5])
c = torch.tensor([6])
sequences = [a, b, c]
packed_sequences = pack_sequence(sequences)
unpacked_sequences = unpack_sequence(packed_sequences)
# zip() stops at the shorter list, so the count must be checked explicitly;
# torch.equal also compares shapes, whereas allclose would broadcast.
assert len(unpacked_sequences) == len(sequences)
assert all(torch.equal(orig, got) for orig, got in zip(sequences, unpacked_sequences))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment