Last active: May 21, 2024 19:20
Save vwxyzjn/ec4e30cd82f2cad14c7412181eddbc7b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Tuple | |
import torch | |
from datasets import load_dataset | |
from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
def first_true_indices(bools: torch.Tensor, dtype=torch.long) -> torch.Tensor:
    """
    Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving
    the position of the first True in each "row" (the last dimension).
    Returns the length of the rows (bools.size(-1)) if no element is True in a given row.

    Args:
        bools (`torch.Tensor`):
            An N-dimensional boolean tensor (N >= 1).
        dtype (`torch.dtype`, optional):
            The desired data type of the output tensor. Defaults to `torch.long`.

    Returns:
        `torch.Tensor`:
            An (N-1)-dimensional tensor of integers indicating the position of the first True
            in each row. If no True value is found in a row, returns the length of the row
            (0 when the last dimension is empty).
    """
    row_len = bools.size(-1)
    if row_len == 0:
        # torch.min cannot reduce over an empty dimension. An empty row contains no True,
        # so by the contract above its answer is its length, which is 0.
        return torch.zeros(bools.shape[:-1], dtype=dtype, device=bools.device)
    # True entries keep their own index i; False entries become row_len + i, which is
    # always >= row_len. The row minimum is therefore the first True index, or exactly
    # row_len (False at index 0) when the row has no True at all.
    zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device)
    return torch.min(zero_or_index, dim=-1).values
def get_reward(
    model: torch.nn.Module, query_responses: torch.Tensor, pad_token_id: int, context_length: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Runs the reward model's backbone and score head over padded token ids and reads off one
    reward per row.

    Args:
        model (`torch.nn.Module`):
            Reward model. It must expose its backbone under the attribute named by
            `model.base_model_prefix` and a `score` head applied to the last hidden states.
        query_responses (`torch.Tensor`):
            Token ids indexed as `[row, position]` (query followed by response). Entries equal to
            `pad_token_id` are treated as padding.
        pad_token_id (`int`):
            Token id that marks padding.
        context_length (`int`):
            Number of leading positions that belong to the query. Padding is searched for only
            at or after this position when locating the end of each response.

    Returns:
        tuple:
            - `reward_logits` (`torch.Tensor`): output of `model.score` for every position.
            - `final_rewards` (`torch.Tensor`): `reward_logits` gathered at `sequence_lengths`
              for each row, with the trailing dimension squeezed (presumably a single score per
              position; confirm against the score head).
            - `sequence_lengths` (`torch.Tensor`): per row, the index of the position just before
              the first pad at or after `context_length`, or the last index when the response
              contains no pad. NOTE(review): with `context_length == 0` and a pad in the first
              position this is -1, which tensor indexing wraps to the last position.
    """
    is_token = query_responses != pad_token_id
    token_counts = is_token.long()
    # Exclusive cumulative sum: each real token's position counts only the real tokens before
    # it, so leading (left) padding does not shift the positions of the query.
    positions = token_counts.cumsum(dim=1) - token_counts
    # Pad ids are swapped for 0 before embedding; the attention mask hides those slots anyway.
    safe_ids = query_responses.masked_fill(~is_token, 0)

    backbone = getattr(model, model.base_model_prefix)
    backbone_out = backbone(
        input_ids=safe_ids,
        attention_mask=is_token,
        position_ids=positions,
        return_dict=True,
        output_hidden_states=True,
        use_cache=False,  # mistral-based reward models error out with the cache enabled
    )
    reward_logits = model.score(backbone_out.hidden_states[-1])

    # Same "last non-pad token" convention as the GPT-2 sequence-classification head:
    # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454
    response_is_pad = query_responses[:, context_length:] == pad_token_id
    sequence_lengths = first_true_indices(response_is_pad) - 1 + context_length

    row_ids = torch.arange(reward_logits.size(0), device=reward_logits.device)
    final_rewards = reward_logits[row_ids, sequence_lengths].squeeze(-1)
    return reward_logits, final_rewards, sequence_lengths
# The reward model and its tokenizer are loaded from the same Hub checkpoint.
REWARD_MODEL_ID = "cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr"

model = AutoModelForSequenceClassification.from_pretrained(REWARD_MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(REWARD_MODEL_ID)
# "[PAD]" becomes the tokenizer's pad token; its id is what get_reward treats as padding below.
# NOTE(review): assumes the dataset's pre-tokenized columns were padded with this same id — confirm.
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
dataset = load_dataset(
    "vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_1706381144",
    split="validation",
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)


def _as_long(rows):
    """Build an int64 tensor on `device` from nested lists of token ids."""
    return torch.tensor(rows, dtype=torch.long, device=device)


first_four = dataset[:4]

with torch.no_grad():
    # Case 1: query and reference response are stored already concatenated, padded on the right.
    _, reward1, _ = get_reward(
        model, _as_long(first_four["query_reference_response_token"]), tokenizer.pad_token_id, 0
    )

    # Case 2: the query is padded on the left, the response on the right; they are joined here
    # and the query width is passed as the context length.
    query_token = _as_long(first_four["query_token"])
    reference_response_token = _as_long(first_four["reference_response_token"])
    query_reference_response_token = torch.cat((query_token, reference_response_token), dim=1)
    _, reward2, _ = get_reward(
        model, query_reference_response_token, tokenizer.pad_token_id, query_token.size(1)
    )

    # Case 3: same layout as case 1, but with a smaller batch (first two rows only).
    _, reward3, _ = get_reward(
        model, _as_long(dataset[:2]["query_reference_response_token"]), tokenizer.pad_token_id, 0
    )

print(f"{reward1=}")
print(f"{reward2=}")
print(f"{reward3=}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.