The logprobs returned during generation and the logprobs recomputed with a separate forward pass differ under bf16.
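For intuition, part of the gap is plain precision: bf16 keeps only 8 mantissa bits, so even just rounding the logits to bf16 perturbs the logprobs far more than fp32 round-off would. A minimal standalone sketch of that precision gap (my addition, not part of the repro script; the vocab size is merely Pythia-sized for realism):

import torch

logits = torch.randn(2, 50304)  # fake logits over a Pythia-sized vocab
lp_fp32 = torch.log_softmax(logits, dim=-1)
lp_bf16 = torch.log_softmax(logits.to(torch.bfloat16), dim=-1).float()
# the max gap is roughly of order 1e-2 nats, orders of magnitude above fp32 round-off
print((lp_fp32 - lp_bf16).abs().max())

A likely compounding factor in the repro below: generate() computes logits step by step with a KV cache, while forward() recomputes the whole sequence in one batched pass, and under bf16 these two code paths need not round identically. The full repro script: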
import argparse

import torch
import torch.nn.functional as F
import transformers

torch.set_printoptions(precision=4, sci_mode=False)

parser = argparse.ArgumentParser()
parser.add_argument("--bf16", action="store_true")
parser.add_argument("--fp16", action="store_true")
parser.add_argument("--fp32", action="store_true")
parser.add_argument("--seed", type=int, default=1)
args = parser.parse_args()
print(args)
torch.manual_seed(args.seed)

tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/pythia-1b-deduped", padding_side="right")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
pad_id = tokenizer.pad_token_id
if args.bf16:
    policy = transformers.AutoModelForCausalLM.from_pretrained(
        "EleutherAI/pythia-1b-deduped", torch_dtype=torch.bfloat16
    )
elif args.fp16:
    policy = transformers.AutoModelForCausalLM.from_pretrained(
        "EleutherAI/pythia-1b-deduped", torch_dtype=torch.float16
    )
elif args.fp32:
    policy = transformers.AutoModelForCausalLM.from_pretrained(
        "EleutherAI/pythia-1b-deduped", torch_dtype=torch.float32
    )
else:
    raise ValueError("pass one of --bf16, --fp16, --fp32")
device = torch.device("cuda")
policy = policy.to(device)
policy.generation_config.pad_token_id = policy.generation_config.eos_token_id

# two identical dummy prompts: left padding followed by a single EOS token
query = torch.tensor(
    [
        [pad_id, pad_id, tokenizer.eos_token_id],
        [pad_id, pad_id, tokenizer.eos_token_id],
    ]
).to(device)
temperature = 0.7
context_length = query.shape[1]
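
# forward() recomputes logits for an already-generated batch. It rebuilds the
# attention mask and position ids from the pad tokens so that the left padding
# does not shift token positions, matching what generate() does internally.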
def forward(model, query_responses, tokenizer):
    attention_mask = query_responses != tokenizer.pad_token_id
    position_ids = attention_mask.cumsum(1) - attention_mask.long()
    input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
    return model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        return_dict=True,
        output_hidden_states=True,
    )

def generate(lm_backbone, queries, tokenizer, generation_config):
    """generate in a way that does not affect padding tokens"""
    context_length = queries.shape[1]
    attention_mask = queries != tokenizer.pad_token_id
    input_ids = torch.masked_fill(queries, ~attention_mask, 0)
    output = lm_backbone.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        # position_ids=attention_mask.cumsum(1) - attention_mask.long(),  # already handled in generation
        generation_config=generation_config,
        return_dict_in_generate=True,
        # output_scores=True,
        output_logits=True,
    )
    # logits = torch.stack(output.scores, 1)
    return torch.cat((queries, output.sequences[:, context_length:]), dim=1), output.logits

# top_k=0 and top_p=1.0 disable filtering, so sampling is a plain draw from
# the temperature-scaled softmax
generation_config = transformers.GenerationConfig(
    max_new_tokens=50,
    min_new_tokens=50,
    temperature=temperature,
    top_k=0,
    top_p=1.0,
    do_sample=True,
)
query_response, logits = generate(policy, query, tokenizer, generation_config)
logits = torch.stack(logits, 1)  # tuple of per-step logits -> (batch, new_tokens, vocab)
logits /= temperature  # apply the same temperature scaling generate used before sampling
response = query_response[:, context_length:]
all_logprob = F.log_softmax(logits, dim=-1)
generation_logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
print(f"{generation_logprob[:,:5]=}")

output = forward(policy, query_response, tokenizer)
# logits at position i predict token i+1, so shift by one to align with the response
logits = output.logits[:, context_length - 1 : -1]
logits /= temperature
all_logprob = F.log_softmax(logits, dim=-1)
forward_logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
print(f"{forward_logprob[:,:5]=}")

# if the two code paths agreed exactly, every ratio would be 1.0
ratio = (generation_logprob - forward_logprob).exp()
print(f"{ratio=}")
print(f"ratio.mean()={ratio.mean().item()}")
print(f"ratio.std()={ratio.std().item()}")
print(f"ratio.max()={ratio.max().item()}")
print(f"ratio.min()={ratio.min().item()}")