Created
March 30, 2020 15:09
-
-
Save mrdrozdov/1815cbd096a77f5e10f20479044eeb69 to your computer and use it in GitHub Desktop.
context_insensitive_word_embeddings.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import hashlib | |
from allennlp.commands.elmo import ElmoEmbedder | |
from allennlp.data.token_indexers.elmo_indexer import ELMoCharacterMapper | |
import numpy as np | |
def save_elmo_cache(path, vectors):
    """Persist the ELMo embedding matrix to disk in numpy's .npy format."""
    np.save(file=path, arr=vectors)
def load_elmo_cache(path):
    """Load a previously cached ELMo embedding matrix from a .npy file."""
    return np.load(path)
def hash_vocab(tokens, version='v1.0.0'):
    """Return a deterministic SHA-256 hex digest of a token list.

    The version string is folded into the hash first, so bumping it
    invalidates every previously cached key. Token order matters.
    """
    digest = hashlib.sha256(version.encode())
    for token in tokens:
        digest.update(token.encode())
    return digest.hexdigest()
def context_insensitive_character_embeddings(weights_path, options_path, tokens, cuda=False, cache_dir=None):
    """Compute context-insensitive ELMo character-CNN embeddings for a vocabulary.

    The result has one row per input token, in input order. When `cache_dir`
    is given, the vectors are cached on disk keyed by a hash of the vocabulary
    (see `hash_vocab`), so repeat calls with the same token list skip the
    expensive model pass.

    Args:
        weights_path: path to the ELMo weight file.
        options_path: path to the ELMo options file.
        tokens: vocabulary list. Precondition (enforced by asserts below):
            tokens[0] is the ELMo BOS token, tokens[1] is the ELMo EOS token,
            and tokens[2] is the literal string '_PAD_'.
        cuda: if True run the model on GPU device 0, otherwise on CPU.
        cache_dir: optional directory for caching the computed vectors.

    Returns:
        A numpy array of shape (len(tokens), embedding_dim) — BOS row first,
        then EOS, then the character-CNN embedding of each remaining token.
    """
    cache_path = None
    if cache_dir is not None:
        key = hash_vocab(tokens)
        cache_path = os.path.join(cache_dir, 'elmo_{}.npy'.format(key))
        if os.path.exists(cache_path):
            print('Loading cached elmo vectors: {}'.format(cache_path))
            return load_elmo_cache(cache_path)

    # cuda_device follows the allennlp convention: 0 = first GPU, -1 = CPU.
    device = 0 if cuda else -1
    elmo = ElmoEmbedder(options_file=options_path, weight_file=weights_path, cuda_device=device)

    # Validate the required token ordering. NOTE(review): assert is stripped
    # under `python -O`; kept as asserts to preserve the original
    # AssertionError contract for callers.
    assert tokens[0] == ELMoCharacterMapper.bos_token  # <S>
    assert tokens[1] == ELMoCharacterMapper.eos_token  # </S>
    assert tokens[2] == '_PAD_'

    # Run every non-special token (including '_PAD_') through the character
    # CNN once; the results are stored on the bilm as a lookup embedding.
    elmo.elmo_bilm.create_cached_cnn_embeddings(tokens[2:])

    # BOS/EOS have dedicated learned vectors; reshape to (1, dim) rows so
    # they stack with the word-embedding matrix.
    bos_vector = elmo.elmo_bilm._bos_embedding.numpy().reshape(1, -1)
    eos_vector = elmo.elmo_bilm._eos_embedding.numpy().reshape(1, -1)
    word_vectors = elmo.elmo_bilm._word_embedding.weight.numpy()
    vectors = np.concatenate([bos_vector, eos_vector, word_vectors])

    if cache_path is not None:
        print('Saving cached elmo vectors: {}'.format(cache_path))
        save_elmo_cache(cache_path, vectors)

    return vectors
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment