Packing and splitting OBELICS-style documents for IDEFICS training
import torch
import numpy as np
import logging

IMAGE_TOKEN = "<image>"
FAKE_TOKEN_AROUND_IMAGE_V2 = "<fake_token_around_image>"

_MIN_LENGTH_DOCUMENTS_TO_PACK = (
    5  # Minimum length of documents to pack together (length is measured in number of tokens)
)
_IMAGE_BONUS_VALUE = 2  # The bonus value for tokens preceding the image token

logger = logging.getLogger(__name__)
def incremental_to_binary_attention_mask(incremental_mask, num_classes=-1):
    # This function converts: [-1, 0, 1] => [[0, 0], [1, 0], [0, 1]]

    # If any of the image indices is greater than or equal to num_classes, set it to -1.
    # Words seen after the maximum number of allowed images don't attend to anything.
    if num_classes != -1:
        incremental_mask[incremental_mask >= num_classes] = -1

    negatives = incremental_mask == -1
    incremental_mask[negatives] = 0
    attn_mask = torch.nn.functional.one_hot(incremental_mask, num_classes=num_classes)
    attn_mask[negatives, :] = 0
    return attn_mask
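
# Illustrative sketch (not part of the original gist): a tiny demo of the conversion
# above. The helper name `_demo_incremental_to_binary` is made up for illustration
# and is never called by the rest of the file.
def _demo_incremental_to_binary():
    # With num_classes=2, the incremental indices [-1, 0, 1] become the one-hot rows
    # [[0, 0], [1, 0], [0, 1]]: -1 attends to no image, 0 attends to the first image,
    # 1 attends to the second image. Note that the helper modifies its input in place.
    incremental = torch.tensor([[-1, 0, 1]])
    binary = incremental_to_binary_attention_mask(incremental, num_classes=2)
    assert binary.tolist() == [[[0, 0], [1, 0], [0, 1]]]
    return binary
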
def image_attention_mask_for_packed_input_ids(input_ids, tokenizer):
    image_attention_mask = torch.full_like(input_ids, fill_value=-1)
    next_image_attention_mask = torch.full_like(input_ids, fill_value=-1)
    image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
    eod_token_id = tokenizer.eos_token_id

    # First pass (left to right): each token gets the index of the last image seen before (or at) it.
    # Tokens that come after an end-of-document token and before the next image attend to nothing (-1).
    for batch_idx in range(input_ids.size(0)):
        count = -1
        seen_eod = False
        for idx, token_id in enumerate(input_ids[batch_idx]):
            if token_id == image_token_id:
                count += 1
                image_attention_mask[batch_idx][idx] = count
                seen_eod = False
            else:
                image_attention_mask[batch_idx][idx] = count

            if seen_eod:
                image_attention_mask[batch_idx][idx] = -1

            if token_id == eod_token_id:
                seen_eod = True

    # Second pass (right to left): each token gets the index of the next image at or after it
    # within its own document, counted from the end; the counts are then re-expressed as indices
    # from the start of the sequence so they point into the image list.
    for batch_idx in range(input_ids.size(0)):
        count = -1
        seen_eod = False
        for idx in range(input_ids[batch_idx].size(0) - 1, -1, -1):
            token_id = input_ids[batch_idx][idx]
            if token_id == image_token_id:
                count += 1
                next_image_attention_mask[batch_idx][idx] = count
                seen_eod = False
            else:
                next_image_attention_mask[batch_idx][idx] = count

            if token_id == eod_token_id:
                seen_eod = True

            if seen_eod:
                next_image_attention_mask[batch_idx][idx] = -1

        non_negative_indices = next_image_attention_mask[batch_idx] != -1
        next_image_attention_mask[batch_idx][non_negative_indices] -= count
        next_image_attention_mask[batch_idx][non_negative_indices] *= -1

    return image_attention_mask, next_image_attention_mask
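
# Illustrative sketch (not part of the original gist): a self-contained check of the
# two-pass mask construction above. `_StubTokenizer` is a hypothetical stand-in that
# exposes only the two attributes the function actually uses.
def _demo_image_attention_masks():
    class _StubTokenizer:
        eos_token_id = 2

        def convert_tokens_to_ids(self, token):
            return 5 if token == IMAGE_TOKEN else 0

    # <image> tok tok <eos> <image> tok  (two packed documents)
    ids = torch.tensor([[5, 10, 11, 2, 5, 12]])
    fwd, nxt = image_attention_mask_for_packed_input_ids(ids, _StubTokenizer())
    # Forward mask: each token points to the last image seen before (or at) it.
    assert fwd.tolist() == [[0, 0, 0, 0, 1, 1]]
    # Next mask: each token points to the next image at or after it within the same document, -1 otherwise.
    assert nxt.tolist() == [[0, -1, -1, -1, 1, -1]]
    return fwd, nxt
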
def split_pack_and_pad(
    sample,
    tokenizer,
    max_seq_len,
    max_num_images,
    max_num_samples_per_document=10,
    prefix_seed=(0, 0),
    add_begin_of_doc_token=False,
    add_end_of_doc_token=True,
    max_num_images_per_document=None,
):
    """
    Return a batch of samples in the format expected by the model, which includes
    `input_ids`, `pixel_values`, `attention_mask`, `image_attention_mask`,
    and `next_image_attention_mask`. The `input_ids` are sampled from the document so that
    each sample has `max_seq_len` tokens; shorter documents are packed together.

    For each document, we sample at most `max_num_samples_per_document` or `max_num_samples_for_curr_document`
    (where the latter is proportional to the length of the document and inversely proportional to the length of the sub-sequences, i.e. `max_seq_len`)
    `input_ids` of sequence length `max_seq_len` from the document. This means that each sampled
    sub-sequence can have a different start index. Based on the start index that has been drawn,
    we also sample at most `max_num_images` images from the document.
    If there are fewer than `max_num_images` images in the document, the images are padded with zeros.
    The start indices are skewed towards sub-sequences that contain images.

    Args:
        sample (Dict): A sample object containing the document with images and texts.
            Each of the keys contains a list of interleaved elements.
            For instance, a given document is represented by two lists `images` and `texts` of the same length,
            where for each position only one element of the two lists can be not None:
            `images=[image1, None, None, image2, None]`, `texts=[None, text1, text2, None, text3]`.
        tokenizer (PretrainedTokenizer): Text tokenizer to be used.
        max_seq_len (int): Maximum sequence length of the returned text tokens.
        max_num_images (int): Maximum number of images to be sampled per sample. If fewer, they are padded with zeros.
        max_num_samples_per_document (int, optional): Maximum number of samples per document to be sampled. Defaults to 10.
        prefix_seed: Prefix seed sequence for "reproducible randomness" in calls to `np.random.choice`.
        add_begin_of_doc_token (bool, optional): Whether to prepend `bos_token_id` to each document. Defaults to False.
        add_end_of_doc_token (bool, optional): Whether to append `eos_token_id` to each document. Defaults to True.
        max_num_images_per_document (int, optional): If set, documents containing more images are skipped. Defaults to None.

    Returns:
        Dict: A batch with `input_ids`, `attention_mask`, `image_attention_mask`,
            `next_image_attention_mask`, `num_images`, `num_text_tokens` and `pixel_values` tensors.
    """
text_batch = sample["texts"] | |
image_batch = sample.get("images", None) | |
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) | |
last_was_image = False | |
all_images = [] | |
all_texts = [] | |
for raw_images, raw_texts in zip(image_batch, text_batch): | |
# Filter ones that don't have either one image and one text word | |
if not any(raw_images) or not any(raw_texts): | |
continue | |
if max_num_images_per_document: | |
num_images = sum([1 if image is not None else 0 for image in raw_images]) | |
if num_images > max_num_images_per_document: | |
continue | |
splitted_raw_images, splitted_raw_texts = [raw_images], [raw_texts] | |
for s_r_ims, s_r_txts in zip(splitted_raw_images, splitted_raw_texts): | |
images, web_text = [], "" | |
for image, text in zip(s_r_ims, s_r_txts): | |
if text is None and image is None: | |
continue | |
if image is not None: | |
web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}{IMAGE_TOKEN}" | |
images.append(torch.tensor(image)) | |
last_was_image = True | |
elif text is not None: | |
if last_was_image: | |
web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}{text}" | |
last_was_image = False | |
else: | |
web_text += f" {text}" if web_text != "" else text | |
if last_was_image: | |
web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}" | |
web_text = web_text.strip(" ") | |
# This is mostly a sanity check. Cases like that should not happen at that point. | |
if web_text == "" or len(images) == 0: | |
continue | |
images = torch.stack(images) | |
all_images.append(images) | |
web_text_ids = tokenizer.encode(web_text, add_special_tokens=False) | |
if add_end_of_doc_token: | |
web_text_ids += [tokenizer.eos_token_id] | |
if add_begin_of_doc_token: | |
web_text_ids = [tokenizer.bos_token_id] + web_text_ids | |
all_texts.append(web_text_ids) | |
    output_input_ids = []
    output_images = []
    output_attention_masks = []
    output_num_images = []
    output_num_text_tokens = []

    input_ids_to_pack = []
    images_to_pack = []
    for images, text in zip(all_images, all_texts):
        # We save all the documents which are shorter than max_seq_len to pack them together.
        if len(text) <= max_seq_len:
            if len(text) < _MIN_LENGTH_DOCUMENTS_TO_PACK:  # Filter out extremely short sequences
                continue
            input_ids_to_pack.extend(text)
            images_to_pack.extend(images)
        else:
            # Compute bonus scores for tokens near images to skew the sampling towards them.
            # The main idea is to give a bonus to tokens that closely precede an image token, so that these tokens have a higher chance of being sampled.
            # Bonuses are computed for each image, which means a given token can receive bonuses from multiple images if it closely precedes multiple images.
            # We sum all the bonuses and L1-normalize along the seq_len axis to get a probability distribution.
            # Each token starts with a regular bonus of 1, which corresponds to the uniform distribution over the sequence when no bonuses are added.
            # The remaining question is which preceding tokens we distribute bonuses to.
            # We first observe that for a sampled sub-sequence to be considered valid (i.e. the sub-sequence contains an image), the start index can only be in [image_idx - max_seq_len + 1, image_idx].
            # For the sake of the explanation, let's split the [image_idx - max_seq_len + 1, image_idx] interval into 3 parts: left, middle and right (in increasing order).
            # If we give bonuses to the tokens just before the image (right part), then we are favoring p_next=0 because only the tokens after the image have an image to attend to.
            # In practice, images will tend to be at the beginning of the sampled sub-sequence.
            # If we give bonuses very far before the image (left part), then we are favoring p_next=1 because only the tokens before the image have an image to attend to.
            # In practice, images will tend to be at the end of the sampled sub-sequence.
            # To avoid favoring either p_next=0 or p_next=1, we can give bonuses to the tokens in the middle part.
            # In practice, images will tend to be in the middle of the sampled sub-sequence.
            # Ultimately, we don't want to skew the distribution fed to the model in that way (i.e. whether images are at the beginning, middle or end of the sampled sub-sequence),
            # and we want all these cases represented equally in the data. So the easiest is to distribute a bonus to all of the max_seq_len tokens preceding the image.
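            # A small worked example (not from the original gist, added for illustration):
            # with max_seq_len = 4 and a single image token at index 6 of a 10-token document,
            # the window below adds _IMAGE_BONUS_VALUE = 2 to indices 2..6, giving
            # all_scores = [1, 1, 3, 3, 3, 3, 3, 1, 1, 1]. After trimming the last
            # _MIN_LENGTH_DOCUMENTS_TO_PACK positions and L1-normalizing, start indices
            # close to (and including) the image become the most likely to be drawn.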
            all_scores = np.array([1] * len(text))
            for img_token_idx in np.where(np.array(text) == image_token_id)[0]:
                all_scores[max(0, img_token_idx - max_seq_len) : img_token_idx + 1] += _IMAGE_BONUS_VALUE
            # all_scores = np.clip(all_scores, a_min=1, a_max=3 * _IMAGE_BONUS_VALUE * max_num_images + 1)  # We can optionally clip the bonuses to avoid overly high values (i.e. outlier documents)
            all_scores = all_scores[:-_MIN_LENGTH_DOCUMENTS_TO_PACK]

            # The number of samples is proportional to the length of the text and inversely proportional to the maximum sequence length
            max_num_samples_for_curr_document = len(text) // max_seq_len
            # Set "reproducible randomness" by creating an np.default_rng seeded by (main seed, epoch, rank_idx, worker_idx, mapped_batch_index, text len)
            choices = np.random.default_rng(seed=list(prefix_seed) + [len(text)]).choice(
                range(len(text) - _MIN_LENGTH_DOCUMENTS_TO_PACK),  # Shorter sub-sequences are reserved for packing
                min(
                    len(text) - max_seq_len, 2 * max_num_samples_per_document
                ),  # Sample more than necessary, then break out of the for loop once we have enough samples
                p=all_scores / np.linalg.norm(all_scores, ord=1),
                replace=False,
            )

            nb_effective_sequences_out_of_sampling = 0
            for start_index in choices:
                image_start_index = text[:start_index].count(image_token_id)
                text_sub_sequence = text[start_index : start_index + max_seq_len]
                image_count = text_sub_sequence.count(image_token_id)
                if image_count == 0:
                    # Skip if there are no images in the sub-sequence
                    continue

                if len(text_sub_sequence) < max_seq_len:
                    # If the sub-sequence is shorter than max_seq_len, we reserve it for packing.
                    # It necessarily means that the sub-sequence was sampled towards the end of the document,
                    # which implies that we only need the `image_start_index` and not the `image_end_index`.
                    if text_sub_sequence.count(image_token_id) != len(images[image_start_index:]):
                        # A safeguard for this
                        logger.warning(
                            "Skipping this sample because of mismatch in actual number of images and "
                            "the '<image>' tokens in the text"
                        )
                        continue
                    input_ids_to_pack.extend(text_sub_sequence)
                    images_to_pack.extend(images[image_start_index:])
                    continue

                current_images = images[image_start_index : image_start_index + min(max_num_images, image_count)]
                if len(current_images) != min(max_num_images, image_count):
                    # A safeguard for something being off about this document: maybe an `<image>` tag
                    # left there from before, or some issue in parsing the image?
                    logger.warning(
                        "Skipping this sample because of mismatch in actual number of images and "
                        "the '<image>' tokens in the text"
                    )
                    break
                padded_image_tensor = torch.zeros(max_num_images, *images.size()[1:])
                padded_image_tensor[: min(max_num_images, image_count)] = current_images
                output_images.append(padded_image_tensor)
                output_num_images.append(min(max_num_images, image_count))

                output_input_ids.append(torch.tensor(text_sub_sequence))
                output_num_text_tokens.append(len(text_sub_sequence))

                attention_mask = torch.ones((max_seq_len,), dtype=torch.long)
                output_attention_masks.append(attention_mask)

                nb_effective_sequences_out_of_sampling += 1
                if nb_effective_sequences_out_of_sampling >= min(
                    max_num_samples_for_curr_document, max_num_samples_per_document
                ):
                    # We got all the samples we need for this document, so break out
                    break
    # Pack the remaining sequences from `input_ids_to_pack` x `images_to_pack`
    if input_ids_to_pack:
        image_counter = 0
        for i in range(0, len(input_ids_to_pack), max_seq_len):
            current_input_ids = input_ids_to_pack[i : i + max_seq_len]
            unpadded_seq_len = len(current_input_ids)
            num_images = current_input_ids.count(image_token_id)
            if num_images == 0:
                continue
            current_images = images_to_pack[image_counter : image_counter + num_images]
            image_counter += num_images
            if unpadded_seq_len < max_seq_len:
                padded_input_ids = [tokenizer.pad_token_id] * max_seq_len
                padded_input_ids[:unpadded_seq_len] = current_input_ids
                current_input_ids = padded_input_ids
            elif unpadded_seq_len > max_seq_len:
                # This case has no purpose other than as a safeguard
                continue
            try:
                current_images = torch.stack(current_images)[:max_num_images]
            except Exception:
                continue
            padded_image_tensor = torch.zeros(max_num_images, *current_images.size()[1:])
            padded_image_tensor[: current_images.size(0)] = current_images
            attention_mask = torch.zeros((max_seq_len,), dtype=torch.long)
            attention_mask[:unpadded_seq_len] = 1

            output_images.append(padded_image_tensor)
            output_input_ids.append(torch.tensor(current_input_ids))

            output_num_text_tokens.append(unpadded_seq_len)
            output_num_images.append(min(max_num_images, num_images))

            output_attention_masks.append(attention_mask)
    if len(output_images) == 0 or len(output_input_ids) == 0:
        result = {
            "input_ids": torch.tensor([], dtype=torch.long),
            "attention_mask": torch.tensor([], dtype=torch.bool),
            "image_attention_mask": torch.tensor([], dtype=torch.bool),
            "next_image_attention_mask": torch.tensor([], dtype=torch.bool),
            "num_images": torch.tensor([], dtype=torch.long),
            "num_text_tokens": torch.tensor([], dtype=torch.long),
            "pixel_values": torch.tensor([], dtype=torch.float32),
        }
        return result

    output_input_ids = torch.stack(output_input_ids)
    output_images = torch.stack(output_images)
    output_attention_masks = torch.stack(output_attention_masks)

    # We create two image attention masks: normal and next.
    # In the normal one, a given text token can only attend to an image that precedes it.
    # In the next one, a given text token can only attend to an image that follows it.
    # During training, only one of these image attention masks is fed to the model (as `image_attention_mask`);
    # we flip a coin to decide which one and ditch the other.
    image_attention_mask, next_image_attention_mask = image_attention_mask_for_packed_input_ids(
        output_input_ids, tokenizer
    )
    image_attention_mask = incremental_to_binary_attention_mask(image_attention_mask, num_classes=max_num_images)
    next_image_attention_mask = incremental_to_binary_attention_mask(
        next_image_attention_mask, num_classes=max_num_images
    )

    result = {
        "input_ids": output_input_ids,
        "attention_mask": output_attention_masks,
        "image_attention_mask": image_attention_mask,
        "next_image_attention_mask": next_image_attention_mask,
        "num_images": torch.tensor(output_num_images),
        "num_text_tokens": torch.tensor(output_num_text_tokens),
        "pixel_values": output_images,
    }
    return result
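
# Hypothetical usage sketch (not part of the original gist). The real IDEFICS pipeline
# maps this function over OBELICS shards with its own tokenizer; here we assume the
# `transformers` library is available, use "gpt2" as an arbitrary stand-in tokenizer to
# which the two special tokens are added, and feed a single toy document whose image is
# a nested list standing in for a preprocessed pixel array.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.add_special_tokens(
        {"additional_special_tokens": [FAKE_TOKEN_AROUND_IMAGE_V2, IMAGE_TOKEN]}
    )

    fake_image = np.zeros((3, 8, 8)).tolist()  # placeholder for a 3 x 8 x 8 pixel array
    sample = {
        "images": [[None, fake_image, None]],
        "texts": [["A short caption.", None, "Some trailing text."]],
    }

    batch = split_pack_and_pad(
        sample,
        tokenizer,
        max_seq_len=32,
        max_num_images=2,
    )
    for key, value in batch.items():
        print(key, tuple(value.shape))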