Created
October 20, 2023 17:31
-
-
Save kouroshHakha/d89e117f60b8b4dc08c282a7b70a31f6 to your computer and use it in GitHub Desktop.
Studies the difference in precision when loading in fp16 versus bf16
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from safetensors import safe_open | |
import torch | |
import numpy as np | |
import matplotlib.pyplot as plt | |
# Read every tensor stored in the adapter checkpoint into a dict keyed by
# the tensor's name in the file.
model_ckpt = "/home/ray/default/7b-chat-lora-ckpt/adapter_model.safetensors"
with safe_open(model_ckpt, framework="pt") as ckpt_file:
    tensors = {name: ckpt_file.get_tensor(name) for name in ckpt_file.keys()}
# Overlay one histogram per reduced-precision dtype on a single figure.
# Each tensor's absolute values are cast to the dtype and back to float32,
# so what gets plotted is the set of magnitudes after the round trip.
epsilon = 1e-10  # floor for the lowest bin edge so log2 stays finite when the minimum is 0
plt.figure()
for dtype in [torch.float16, torch.bfloat16]:
    # Cast back to float32 before exporting to NumPy, then flatten and join
    # every tensor into one 1-D array.
    all_data = np.concatenate(
        [
            t.abs().to(dtype=dtype).to(dtype=torch.float32).numpy().flatten()
            for t in tensors.values()
        ]
    )
    # Use the array's own reductions; the builtin min()/max() would iterate
    # over the array one element at a time in Python.
    lowest = max(all_data.min(), epsilon)
    highest = all_data.max()
    # 100 edges evenly spaced in log2, i.e. 99 logarithmically spaced bins.
    bin_edges = 2 ** np.linspace(np.log2(lowest), np.log2(highest), 100)
    # log=True puts the counts on a log axis; alpha lets the two dtypes overlap.
    plt.hist(all_data, bins=bin_edges, log=True, label=str(dtype), alpha=0.5)
# The x axis is log base 2 to match the bin spacing.
plt.xscale('log', base=2)
plt.legend()
plt.savefig("histogram.png")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment