# Created January 25, 2022 17:59
from typing import Dict, Optional
import torch
import torch.nn as nn
from torchcrf import CRF
from transformers import AutoModel
class TransformerSlidingWindower(nn.Module):
    """Apply a transformer model on a strided sliding window over a long sequence.

    Good for running transformer encoders (e.g. BERT-like models from the
    ``transformers`` library) on sequences longer than their maximum input length.
    The input is split into fixed-length windows that overlap by ``stride``
    tokens, so some context is preserved across window boundaries. Each window is
    passed through the underlying model, and the token features at overlapping
    positions are pooled with ``mean`` or ``max``.

    Example: an input sequence ``I`` of length 8 is covered by two windows ``S``
    and ``T`` of length 5 that overlap by 2 elements::

        input      I1 I2 I3 I4 I5 I6 I7 I8
        window S   S1 S2 S3 S4 S5
        window T            T1 T2 T3 T4 T5
                            ^^^^^ overlap: O4 = pool(S4, T1), O5 = pool(S5, T2)
        output     O1 O2 O3 O4 O5 O6 O7 O8

    Every window except the first gets a [CLS] token prepended. Every window
    except the last gets a [SEP] token appended. The hidden states of these added
    tokens are dropped again before the windows are stitched together.
    NOTE(review): the input presumably already carries the tokenizer's own
    leading [CLS] / trailing [SEP]; confirm for non-BERT tokenizers.

    Args:
        underlying_model (nn.Module): Feature extractor (e.g. BertModel, RobertaModel).
            It must expose ``config.hidden_size``. It must return an object with a
            ``last_hidden_state`` attribute and, optionally, ``pooler_output``.
        window_size (int): Maximum number of tokens fed to the model per window,
            including the two added special tokens.
        stride (int): Overlap length (in tokens) between consecutive windows.
            Must satisfy ``0 <= stride < window_size - 2``.
        stride_aggregation (str): ``"mean"`` or ``"max"``. Controls how hidden
            states at overlapping positions are combined.
        pooler_aggregation (str): ``"mean"`` or ``"max"``. Controls how per-window
            pooler outputs are combined. Windows are folded in pairwise, so with
            ``"mean"`` later windows carry more weight than earlier ones.
        cls_token_id (int): Token id prepended to every non-first window
            (default 101, BERT's [CLS]).
        sep_token_id (int): Token id appended to every non-last window
            (default 102, BERT's [SEP]).

    Raises:
        ValueError: If an aggregation method is unsupported or ``stride`` is out
            of range (an out-of-range stride would make the window loop never
            advance).
    """

    _AGGREGATIONS = ("mean", "max")

    def __init__(
        self,
        underlying_model: nn.Module,
        window_size: int = 512,
        stride: int = 128,
        stride_aggregation: str = "mean",
        pooler_aggregation: str = "mean",
        cls_token_id: int = 101,
        sep_token_id: int = 102,
    ):
        super(TransformerSlidingWindower, self).__init__()
        if stride_aggregation not in self._AGGREGATIONS:
            raise ValueError(
                "Unsupported stride aggregation method. Only [mean, max] are supported"
            )
        if pooler_aggregation not in self._AGGREGATIONS:
            raise ValueError(
                "Unsupported pooler aggregation method. Only [mean, max] are supported"
            )
        # Windows advance by (window_size - 2 - stride) tokens; that step must be positive.
        if not 0 <= stride < window_size - 2:
            raise ValueError(
                f"stride must satisfy 0 <= stride < window_size - 2 "
                f"(got stride={stride}, window_size={window_size})"
            )
        self.underlying_model = underlying_model
        self.hidden_size = underlying_model.config.hidden_size
        # Keep 2 positions free in every window for the added [CLS] and [SEP].
        self.window_size = window_size - 2
        self.stride = stride
        self.stride_aggregation = stride_aggregation
        self.pooler_aggregation = pooler_aggregation
        self.cls_token_id = cls_token_id
        self.sep_token_id = sep_token_id

    def slider(self, sequence_length: int):
        """Yield ``(start, end)`` slice bounds of the windows covering the sequence.

        Consecutive windows overlap by ``self.stride`` positions. The last window
        is clipped to ``sequence_length``. A sequence no longer than one window
        yields the single pair ``(0, sequence_length)``.
        """
        step = self.window_size - self.stride
        start_index = 0
        while True:
            end_index = min(start_index + self.window_size, sequence_length)
            yield start_index, end_index
            if end_index >= sequence_length:
                return
            start_index += step

    def _augment_with_cls(
        self, window_input_ids, window_attention_mask, window_token_type_ids
    ):
        """Prepend ``self.cls_token_id`` to every row of the window.

        The attention mask and token type id of the new position are copied from
        the window's first position. Masks that are ``None`` stay ``None``.
        """
        cls_column = torch.full_like(window_input_ids[:, :1], self.cls_token_id)
        window_input_ids = torch.cat((cls_column, window_input_ids), dim=1)
        if window_attention_mask is not None:
            window_attention_mask = torch.cat(
                (window_attention_mask[:, :1], window_attention_mask), dim=1
            )
        if window_token_type_ids is not None:
            window_token_type_ids = torch.cat(
                (window_token_type_ids[:, :1], window_token_type_ids), dim=1
            )
        return window_input_ids, window_attention_mask, window_token_type_ids

    def _augment_with_sep(
        self, window_input_ids, window_attention_mask, window_token_type_ids
    ):
        """Append ``self.sep_token_id`` to every row of the window.

        The attention mask and token type id of the new position are copied from
        the window's last position. Masks that are ``None`` stay ``None``.
        """
        sep_column = torch.full_like(window_input_ids[:, :1], self.sep_token_id)
        window_input_ids = torch.cat((window_input_ids, sep_column), dim=1)
        if window_attention_mask is not None:
            window_attention_mask = torch.cat(
                (window_attention_mask, window_attention_mask[:, -1:]), dim=1
            )
        if window_token_type_ids is not None:
            window_token_type_ids = torch.cat(
                (window_token_type_ids, window_token_type_ids[:, -1:]), dim=1
            )
        return window_input_ids, window_attention_mask, window_token_type_ids

    def _aggregator(self, x1, x2, selector):
        """Combine two same-shaped tensors element-wise with ``mean`` or ``max``."""
        if selector == "mean":
            return torch.mean(torch.stack((x1, x2)), dim=0)
        elif selector == "max":
            return torch.maximum(x1, x2)
        raise ValueError(f"Unsupported aggregation method {selector}")

    def _aggregate_hidden_states(
        self, start_index, end_index, last_hidden_prev, current_hidden
    ):
        """Merge one window's hidden states into the whole-sequence buffer.

        Args:
            start_index, end_index: Slice bounds of the window in the sequence.
            last_hidden_prev: Whole-sequence buffer, indexed as ``[:, position]``.
            current_hidden: Hidden states of the window, with the added special
                tokens already removed (length ``end_index - start_index``).

        Returns:
            A new buffer. The first ``self.stride`` window positions (the overlap
            with the previous window) are pooled with the previous values; the
            remaining positions are overwritten. The input buffer is cloned rather
            than modified in place.
        """
        merged = last_hidden_prev.clone()
        if start_index == 0:
            merged[:, start_index:end_index] = current_hidden
            return merged
        overlap_end = start_index + self.stride
        merged[:, start_index:overlap_end] = self._aggregator(
            last_hidden_prev[:, start_index:overlap_end],
            current_hidden[:, : self.stride],
            self.stride_aggregation,
        )
        merged[:, overlap_end:end_index] = current_hidden[:, self.stride :]
        return merged

    def _aggregate_pooler_output(self, up2now_out, current_out):
        """Fold one window's pooler output into the running aggregate."""
        return self._aggregator(up2now_out, current_out, self.pooler_aggregation)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None):
        """Run the underlying model window by window and stitch the results.

        Args:
            input_ids: Token ids indexed as ``[batch, position]``.
            attention_mask: Optional mask with the same layout as ``input_ids``.
            token_type_ids: Optional token types with the same layout as ``input_ids``.

        Returns:
            ``(last_hidden_state, pooler_output)``:
            - ``last_hidden_state`` has shape
              ``(batch, sequence_length, hidden_size)``. Its dtype and device
              follow the model output.
            - ``pooler_output`` aggregates the per-window pooler outputs, or is
              ``None`` if the model returns no pooler output.

        Raises:
            ValueError: If ``input_ids`` is not given.
        """
        if input_ids is None:
            raise ValueError("input_ids must be provided")
        model_kwargs = {
            # Incompatible with windowing. DON'T USE THIS FOR SEQUENTIAL DECODING.
            "past_key_values": None,
            "use_cache": False,
            "return_dict": True,
        }
        batch_size, sequence_length = input_ids.shape
        window_attention_mask, window_token_type_ids = None, None
        # Allocated lazily from the first model output so the dtype/device match it.
        last_hidden_state = None
        pooler_output = None

        for start_index, end_index in self.slider(sequence_length):
            window_input_ids = input_ids[:, start_index:end_index]
            if attention_mask is not None:
                window_attention_mask = attention_mask[:, start_index:end_index]
            if token_type_ids is not None:
                window_token_type_ids = token_type_ids[:, start_index:end_index]
            is_first = start_index == 0
            is_last = end_index >= sequence_length
            if not is_first:
                (
                    window_input_ids,
                    window_attention_mask,
                    window_token_type_ids,
                ) = self._augment_with_cls(
                    window_input_ids, window_attention_mask, window_token_type_ids
                )
            if not is_last:
                (
                    window_input_ids,
                    window_attention_mask,
                    window_token_type_ids,
                ) = self._augment_with_sep(
                    window_input_ids, window_attention_mask, window_token_type_ids
                )

            outputs = self.underlying_model(
                input_ids=window_input_ids,
                attention_mask=window_attention_mask,
                token_type_ids=window_token_type_ids,
                **model_kwargs,
            )

            window_pooled = getattr(outputs, "pooler_output", None)
            if window_pooled is not None:
                # The first window initialises the aggregate; it is not pooled with zeros.
                pooler_output = (
                    window_pooled
                    if pooler_output is None
                    else self._aggregate_pooler_output(pooler_output, window_pooled)
                )

            current_last_hidden_state = outputs.last_hidden_state
            if not is_first:
                current_last_hidden_state = current_last_hidden_state[:, 1:, :]  # drop added [CLS]
            if not is_last:
                current_last_hidden_state = current_last_hidden_state[:, :-1, :]  # drop added [SEP]

            if last_hidden_state is None:
                last_hidden_state = current_last_hidden_state.new_zeros(
                    (batch_size, sequence_length, self.hidden_size)
                )
            last_hidden_state = self._aggregate_hidden_states(
                start_index, end_index, last_hidden_state, current_last_hidden_state
            )
        return last_hidden_state, pooler_output
class TransformerDocumentCRF(nn.Module):
    """Multi-task sequence tagger: a sliding-window transformer encoder plus one CRF per task.

    ::

        [          Transformer sliding windower          ]
           O1   O2   O3   ...................   ON
                |              |               |
           [ARG1 CRF]   [Connector CRF]   [ARG2 CRF]     (example tasks)

    Each task has its own linear projection from the shared hidden states to
    that task's tag space, followed by its own CRF.

    Args:
        multitask_num_tags: Mapping ``task name -> number of tags`` (must not be empty).
        multitask_weights: Optional mapping ``task name -> loss weight``. If
            omitted, every task gets weight ``1 / number_of_tasks``.
        pretrained_model: Name/path passed to ``AutoModel.from_pretrained``.
        window_size, stride, stride_aggregation, pooler_aggregation: Forwarded
            to the sliding windower.

    NOTE(review): the windower is built with its default [CLS]/[SEP] token ids.
    Confirm these match ``pretrained_model``'s tokenizer when using non-BERT models.

    Raises:
        ValueError: If ``multitask_num_tags`` is empty.
    """

    def __init__(
        self,
        multitask_num_tags: Dict[str, int],
        multitask_weights: Optional[Dict[str, float]] = None,
        pretrained_model: str = "bert-base-uncased",
        window_size: int = 512,
        stride: int = 128,
        stride_aggregation: str = "mean",
        pooler_aggregation: str = "mean",
    ):
        super(TransformerDocumentCRF, self).__init__()
        if not multitask_num_tags:
            raise ValueError("multitask_num_tags must contain at least one task")
        underlying_model = AutoModel.from_pretrained(pretrained_model)
        self.sliding_bert = TransformerSlidingWindower(
            underlying_model,
            window_size=window_size,
            stride=stride,
            stride_aggregation=stride_aggregation,
            pooler_aggregation=pooler_aggregation,
        )
        self.hidden_size = self.sliding_bert.hidden_size
        self.projectors = nn.ModuleDict(
            {
                task: nn.Linear(self.hidden_size, num_tags)
                for task, num_tags in multitask_num_tags.items()
            }
        )
        self.decoders = nn.ModuleDict(
            {
                task: CRF(num_tags=num_tags, batch_first=True)
                for task, num_tags in multitask_num_tags.items()
            }
        )
        if multitask_weights is None:
            # Same weight for all losses if not provided.
            task_weight = 1.0 / len(multitask_num_tags)
            multitask_weights = {task: task_weight for task in multitask_num_tags}
        self.multitask_weights = multitask_weights

    def forward(self, input_ids, tags, attention_mask=None, token_type_ids=None):
        """Compute the weighted multi-task CRF negative log-likelihood.

        Args:
            input_ids: Token ids passed to the sliding windower.
            tags: Mapping ``task name -> gold tag ids`` aligned with ``input_ids``.
            attention_mask: Optional mask. When given, it also masks the CRF; when
                ``None``, every position is scored.
            token_type_ids: Optional token type ids.

        Returns:
            ``(negative_loglik, task_losses)``: the weighted sum over tasks, plus the
            unweighted per-task negative log-likelihoods.
        """
        emissions, _ = self.sliding_bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        crf_mask = None if attention_mask is None else attention_mask.type(torch.bool)
        negative_loglik = 0
        task_losses = {}
        for task, decoder in self.decoders.items():
            emissions_logits = self.projectors[task](emissions)
            # torchcrf's CRF.forward returns the log-likelihood; negate it so the
            # returned value is a loss to minimise.
            task_loss = -decoder(emissions_logits, tags[task], mask=crf_mask)
            negative_loglik += self.multitask_weights[task] * task_loss
            task_losses[task] = task_loss
        return negative_loglik, task_losses

    def decode(self, input_ids, attention_mask=None, token_type_ids=None):
        """Return the best tag sequence per task as ``{task name: decoded tags}``.

        NOTE: the CRF decode deliberately does not use ``attention_mask`` (the
        original passed no mask). Padded positions are therefore decoded too, and
        every decoded sequence spans the full input length.
        """
        emissions, _ = self.sliding_bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        tags = {}
        for task, decoder in self.decoders.items():
            tags[task] = decoder.decode(self.projectors[task](emissions))
        return tags
if __name__ == "__main__":
    # Manual entry point that only runs when this file is executed directly.
    # NOTE(review): as visible here it just loads the bert-base-uncased
    # tokenizer; the rest of the script may have been cut off.
    from transformers import AutoModel, AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
