Redirects Hugging Face paper pages to their arXiv equivalents (via the Redirector browser extension).
Chrome: https://chrome.google.com/webstore/detail/redirector/ocgpenflpmgnfapjedencafcfakcekcd
Firefox: https://addons.mozilla.org/en-US/firefox/addon/redirector/
# Instantiate an HF GPT-NeoX causal-LM in fp16 on GPU, then load the
# word-embedding weights from a DeepSpeed checkpoint saved with tensor
# parallelism of 2 (two "model_0X" shard files per layer).
# NOTE(review): the trailing "| |" markdown-table artifacts were stripped;
# this fragment was previously not valid Python.
config = configuration_gpt_neox.GPTNeoXConfig()
hf_model = modeling_gpt_neox.GPTNeoXForCausalLM(config).half().cuda()

checkpoint_path = "/path/to/global_step150000"  # placeholder: edit to a real checkpoint dir
# layer_00 holds the input embeddings; model_00 / model_01 are the two
# tensor-parallel shards of that layer's state dict.
loaded_tp1 = torch.load(os.path.join(checkpoint_path, "layer_00-model_00-model_states.pt"))
loaded_tp2 = torch.load(os.path.join(checkpoint_path, "layer_00-model_01-model_states.pt"))
# Embedding weights are sharded along dim 0 (presumably the vocab axis —
# TODO confirm against the training TP config); concatenating the two
# shards reconstructs the full embedding matrix for the merged HF model.
hf_model.gpt_neox.embed_in.load_state_dict({"weight": torch.cat([
    loaded_tp1["word_embeddings.weight"],
    loaded_tp2["word_embeddings.weight"],
], dim=0)})
for layer_i in display.trange(44): |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from dataclasses import dataclass | |
import transformers.modeling_bert as modeling_bert | |
import nlpr.shared.torch_utils as torch_utils | |
import nlpr.shared.model_resolution as model_resolution |