
@kouroshHakha
Created October 20, 2023 17:31
Studies the difference in precision when loading LoRA adapter weights in fp16 vs. bf16.
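For context (an addition, not part of the original gist): fp16 spends its 16 bits on a 5-bit exponent and 10-bit mantissa, while bf16 keeps fp32's 8-bit exponent at the cost of a 7-bit mantissa. Small adapter weights can therefore underflow to zero in fp16 yet survive in bf16, which is exactly the tail behavior the histogram below makes visible. A minimal illustration:

import torch

x = torch.tensor(1e-8)
print(x.to(torch.float16))   # underflows to 0: below fp16's smallest subnormal (~6e-8)
print(x.to(torch.bfloat16))  # survives: bf16 shares fp32's 8-bit exponent range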
from safetensors import safe_open

import torch
import numpy as np
import matplotlib.pyplot as plt

# Load every tensor in the LoRA adapter checkpoint into memory.
tensors = {}
model_ckpt = "/home/ray/default/7b-chat-lora-ckpt/adapter_model.safetensors"
with safe_open(model_ckpt, framework="pt") as f:
    for k in f.keys():
        tensors[k] = f.get_tensor(k)

plt.figure()
for dtype in [torch.float16, torch.bfloat16]:
    # Round-trip each tensor's magnitudes through the target dtype and back
    # to float32 so numpy can consume them.
    fp32_np_tensors = {}
    for k in tensors:
        fp32_np_tensors[k] = tensors[k].abs().to(dtype=dtype).to(dtype=torch.float32).numpy()
    all_data = np.concatenate([v.flatten() for v in fp32_np_tensors.values()])

    # The x axis should be log scale, so create logarithmically spaced bins.
    # epsilon guards against log2(0) when a value underflows to zero.
    epsilon = 1e-10
    bin_edges = 2 ** np.linspace(
        np.log2(max(all_data.min(), epsilon)), np.log2(all_data.max()), 100
    )

    # Overlay the magnitude histogram for this dtype.
    plt.hist(all_data, bins=bin_edges, log=True, label=str(dtype), alpha=0.5)

plt.xscale('log', base=2)
plt.legend()
plt.savefig("histogram.png")
breakpoint()  # leftover debugging hook: inspect the tensors interactively
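As a follow-up sketch (not in the original gist), the same comparison can be made numerically rather than visually by measuring the round-trip error per dtype; this assumes the `tensors` dict loaded above is still in scope:

import torch

for dtype in [torch.float16, torch.bfloat16]:
    errs = []
    for t in tensors.values():
        t32 = t.to(torch.float32)
        # Cast to the target dtype and back, then record the absolute error.
        roundtrip = t32.to(dtype).to(torch.float32)
        errs.append((roundtrip - t32).abs().flatten())
    errs = torch.cat(errs)
    print(f"{dtype}: max abs err {errs.max().item():.3e}, "
          f"mean abs err {errs.mean().item():.3e}")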