@shreyansh26 · forked from huchenxucs/pos_embed.py · created April 24, 2023
T5 relative positional embedding

import math

import torch
import torch.nn as nn
from torch.nn import functional as F


class RelativePositionBias(nn.Module):
    def __init__(self, bidirectional=True, num_buckets=32, max_distance=128, n_heads=2):
        super(RelativePositionBias, self).__init__()
        self.bidirectional = bidirectional
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.n_heads = n_heads
        self.relative_attention_bias = nn.Embedding(self.num_buckets, self.n_heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention.
        The relative position is defined as memory_position - query_position, i.e.
        the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are
        invalid.
        We use smaller buckets for small absolute relative_position and larger buckets
        for larger absolute relative_positions. All relative positions >= max_distance
        map to the same bucket. All relative positions <= -max_distance map to the
        same bucket. This should allow for more graceful generalization to longer
        sequences than the model has been trained on.

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer
        Returns:
            a Tensor with the same shape as relative_position, containing int32
            values in the range [0, num_buckets)
        """
        ret = 0
        n = -relative_position
        if bidirectional:
            num_buckets //= 2
            ret += (n < 0).to(torch.long) * num_buckets  # mtf.to_int32(mtf.less(n, 0)) * num_buckets
            n = torch.abs(n)
        else:
            n = torch.max(n, torch.zeros_like(n))
        # now n is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = n < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).to(torch.long)
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def compute_bias(self, qlen, klen):
        """Compute binned relative position bias."""
        context_position = torch.arange(qlen, dtype=torch.long,
                                        device=self.relative_attention_bias.weight.device)[:, None]
        memory_position = torch.arange(klen, dtype=torch.long,
                                       device=self.relative_attention_bias.weight.device)[None, :]
        relative_position = memory_position - context_position  # shape (qlen, klen)
        """
                  k
              0   1   2   3
        q    -1   0   1   2
             -2  -1   0   1
             -3  -2  -1   0
        """
        rp_bucket = self._relative_position_bucket(
            relative_position,  # shape (qlen, klen)
            bidirectional=self.bidirectional,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
        )
        rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device)
        values = self.relative_attention_bias(rp_bucket)  # shape (qlen, klen, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, qlen, klen)
        return values

    def forward(self, qlen, klen):
        return self.compute_bias(qlen, klen)  # shape (1, num_heads, qlen, klen)
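
# A quick sanity check of the bucketing scheme described in the
# _relative_position_bucket docstring (an illustrative addition, not part of
# the original gist). With the defaults num_buckets=32 and max_distance=128,
# offsets with |distance| < 8 each get their own bucket, larger offsets fall
# into logarithmically sized buckets, and positive offsets are shifted into
# the upper half of the bucket range when bidirectional=True.
if __name__ == "__main__":
    rel = torch.tensor([-64, -8, -1, 0, 1, 8, 64])
    buckets = RelativePositionBias._relative_position_bucket(rel, bidirectional=True)
    print(buckets)  # tensor([14,  8,  1,  0, 17, 24, 30])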

# Excerpt from F.multi_head_attention_forward (torch>=1.5.0), showing where a
# relative position bias can be injected via attn_mask:
"""
    if attn_mask is not None:
        # add relative positional embedding to attn_mask  # shape (N*num_heads, L, S)
        attn_output_weights += attn_mask

    if key_padding_mask is not None:
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        attn_output_weights = attn_output_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2),
            float('-inf'),
        )
        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)

    attn_output_weights = softmax(
        attn_output_weights, dim=-1)
    attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)

    attn_output = torch.bmm(attn_output_weights, v)
    assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
"""