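# Contrastive-style decoding with RWKV (gist: saharNooby/c4f341cb14d3f9f6c826fa25dd3484d4).
# At each step, the SEARCH_K most probable tokens are re-scored by penalizing the
# cosine similarity of their hidden representations to those of all previous tokens,
# in the spirit of contrastive search's degeneration penalty. RWKV_model and
# util.sampling are project-local modules assumed to be available alongside this script.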
import json
import time
from typing import List, Tuple, Union

import numpy as np
import torch
from numpy.linalg import norm
from torch.nn import functional as F

from RWKV_model import tokenizer, RWKV_RNN, args_430M
from util.sampling import tail_free_sampling
########################################################################################################

context = "..."

args = args_430M

NUM_GENERATIONS = 1
TOKENS_PER_GENERATION = 200
# Generation will be SEARCH_K times slower
SEARCH_K = 5
# Weight of the similarity (degeneration) penalty relative to model confidence
ALPHA = 0.5
USE_TAIL_FREE_SAMPLING = False
DEBUG = False
########################################################################################################

def debug(*args):
    if DEBUG:
        print(*args)

model = RWKV_RNN(args)
model.warm_up()

# Hidden representation of every token processed so far; candidate tokens are
# penalized by their maximum cosine similarity to these
representations = []

print('Preprocessing context')

start = time.time()

prompt_tokens = tokenizer.encode(context).ids
prompt_token_count = len(prompt_tokens)

init_out, init_state = None, None

# Feed the prompt through the model one token at a time, saving each token's representation
for i in range(prompt_token_count):
    init_out, init_state = model.forward(prompt_tokens[i], init_state, save_representation=True)
    representations.append(model.representation)

    if prompt_token_count < 5 or i % (prompt_token_count // 5) == 0:
        print(f'{i}/{prompt_token_count}')

delay = time.time() - start

print('Took %.3f sec, %d tokens in context, %d ms per token' % (delay, prompt_token_count, delay / prompt_token_count * 1000))
def cosine_similarity(x: np.ndarray, y: np.ndarray):
    return np.dot(x, y) / (norm(x) * norm(y))
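
# Decoding loop: at each step, take the SEARCH_K most probable tokens, advance a
# copy of the model state one step for each of them, and pick the candidate with
# the best trade-off between probability and dissimilarity to all previous tokens.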
for GENERATION in range(NUM_GENERATIONS):
    print(f'\n--- Generation {GENERATION} ---\n')
    print(context, end='[')

    start = time.time()

    out, state = init_out.clone(), init_state.clone()

    for TOKEN in range(TOKENS_PER_GENERATION):
        debug('')
        debug('out', out)

        if USE_TAIL_FREE_SAMPLING:
            out = tail_free_sampling(out)
            debug('out', out)

            probs = F.softmax(out, dim=-1).cpu().numpy()
            debug('probs', probs)

            top_tokens = (-probs).argsort()
            debug('top_tokens', top_tokens)

            # Keep only candidates before the first zero-probability token
            # (the tail removed by tail-free sampling), capped at SEARCH_K
            first_zero_index = np.where(probs[top_tokens] == 0.0)[0][0]
            debug('first_zero_index', first_zero_index)

            top_tokens = top_tokens[:min(max(1, first_zero_index), SEARCH_K)]
        else:
            probs = F.softmax(out, dim=-1).cpu().numpy()
            top_tokens = (-probs).argsort()[:SEARCH_K]

        debug('top_tokens', top_tokens)

        top_tokens_probs = probs[top_tokens]
        debug('top_tokens_probs', top_tokens_probs)

        top_tokens_max_similarities = np.zeros_like(top_tokens, dtype=float)
        next_states_and_representations: List[Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], None]] = [None] * len(top_tokens)
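
        # Evaluate each candidate token by running the model one step ahead on a
        # copy of the current state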
        for j in range(len(top_tokens)):
            token = top_tokens[j]

            candidate_out, candidate_state = model.forward(token, state.clone(), save_representation=True)
            candidate_representation = model.representation

            # Maximum cosine similarity between the candidate's representation
            # and the representations of all tokens seen so far
            max_similarity = 0

            for representation in representations:
                max_similarity = max(max_similarity, cosine_similarity(representation.cpu().numpy(), candidate_representation.cpu().numpy()))

            top_tokens_max_similarities[j] = max_similarity
            next_states_and_representations[j] = (candidate_out, candidate_state, candidate_representation)

        debug('top_tokens_max_similarities', top_tokens_max_similarities)
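
        # Contrastive score: (1 - ALPHA) rewards model confidence, ALPHA penalizes
        # similarity to what was already generated (degeneration penalty)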
        top_tokens_scores = (1 - ALPHA) * top_tokens_probs - ALPHA * top_tokens_max_similarities
        debug('top_tokens_scores', top_tokens_scores)

        selected_token_index = np.argmax(top_tokens_scores)
        debug('selected_token_index', selected_token_index)

        token = top_tokens[selected_token_index]
        debug('token', token)

        # Adopt the chosen candidate's output, state and representation
        out, state, representation = next_states_and_representations[selected_token_index]
        representations.append(representation)

        if DEBUG:
            print(json.dumps(tokenizer.decode([token])))
        else:
            print(tokenizer.decode([token]), end='')

    delay = time.time() - start

    print(']\n\nTook %.3f sec, %d ms per token' % (delay, delay / TOKENS_PER_GENERATION * 1000))