Probably the dumbest, (almost) no-dependencies, pure Python implementation of 20B_tokenizer.json (a BPE tokenizer for the GPT-NeoX model); the only non-stdlib requirement is the regex package (tokenizers is needed only for the optional self-test at the bottom).
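For reference, a minimal usage sketch of the two functions returned by load_tokenizer, which is defined in the file below (the path to 20B_tokenizer.json is an assumption; point it at your local copy):

# Hypothetical usage; load_tokenizer is defined in the implementation below.
encode, decode = load_tokenizer('./20B_tokenizer.json')

ids = encode('Hello world!')   # list of GPT-NeoX token ids
text = decode(ids)             # back to the (NFC-normalized) original string

assert text == 'Hello world!'

The implementation: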
import regex
import json
import unicodedata
from typing import Tuple, Callable, Union


# Parses the tokenizer config and returns encode and decode functions.
def load_tokenizer(config_path: str) -> Tuple[Callable[[str], list[int]], Callable[[list[int]], str]]:
    # Maps any byte 0..255 to a printable Unicode character.
    byte_to_unicode: dict[int, str] = {
        33: "!",
        34: "\"",
        35: "#",
        36: "$",
        37: "%",
        38: "&",
        39: "\'",
        40: "(",
        41: ")",
        42: "*",
        43: "+",
        44: ",",
        45: "-",
        46: ".",
        47: "/",
        48: "0",
        49: "1",
        50: "2",
        51: "3",
        52: "4",
        53: "5",
        54: "6",
        55: "7",
        56: "8",
        57: "9",
        58: ":",
        59: ";",
        60: "<",
        61: "=",
        62: ">",
        63: "?",
        64: "@",
        65: "A",
        66: "B",
        67: "C",
        68: "D",
        69: "E",
        70: "F",
        71: "G",
        72: "H",
        73: "I",
        74: "J",
        75: "K",
        76: "L",
        77: "M",
        78: "N",
        79: "O",
        80: "P",
        81: "Q",
        82: "R",
        83: "S",
        84: "T",
        85: "U",
        86: "V",
        87: "W",
        88: "X",
        89: "Y",
        90: "Z",
        91: "[",
        92: "\\",
        93: "]",
        94: "^",
        95: "_",
        96: "`",
        97: "a",
        98: "b",
        99: "c",
        100: "d",
        101: "e",
        102: "f",
        103: "g",
        104: "h",
        105: "i",
        106: "j",
        107: "k",
        108: "l",
        109: "m",
        110: "n",
        111: "o",
        112: "p",
        113: "q",
        114: "r",
        115: "s",
        116: "t",
        117: "u",
        118: "v",
        119: "w",
        120: "x",
        121: "y",
        122: "z",
        123: "{",
        124: "|",
        125: "}",
        126: "~",
        161: "¡",
        162: "¢",
        163: "£",
        164: "¤",
        165: "¥",
        166: "¦",
        167: "§",
        168: "¨",
        169: "©",
        170: "ª",
        171: "«",
        172: "¬",
        174: "®",
        175: "¯",
        176: "°",
        177: "±",
        178: "²",
        179: "³",
        180: "´",
        181: "µ",
        182: "¶",
        183: "·",
        184: "¸",
        185: "¹",
        186: "º",
        187: "»",
        188: "¼",
        189: "½",
        190: "¾",
        191: "¿",
        192: "À",
        193: "Á",
        194: "Â",
        195: "Ã",
        196: "Ä",
        197: "Å",
        198: "Æ",
        199: "Ç",
        200: "È",
        201: "É",
        202: "Ê",
        203: "Ë",
        204: "Ì",
        205: "Í",
        206: "Î",
        207: "Ï",
        208: "Ð",
        209: "Ñ",
        210: "Ò",
        211: "Ó",
        212: "Ô",
        213: "Õ",
        214: "Ö",
        215: "×",
        216: "Ø",
        217: "Ù",
        218: "Ú",
        219: "Û",
        220: "Ü",
        221: "Ý",
        222: "Þ",
        223: "ß",
        224: "à",
        225: "á",
        226: "â",
        227: "ã",
        228: "ä",
        229: "å",
        230: "æ",
        231: "ç",
        232: "è",
        233: "é",
        234: "ê",
        235: "ë",
        236: "ì",
        237: "í",
        238: "î",
        239: "ï",
        240: "ð",
        241: "ñ",
        242: "ò",
        243: "ó",
        244: "ô",
        245: "õ",
        246: "ö",
        247: "÷",
        248: "ø",
        249: "ù",
        250: "ú",
        251: "û",
        252: "ü",
        253: "ý",
        254: "þ",
        255: "ÿ",
        0: "Ā",
        1: "ā",
        2: "Ă",
        3: "ă",
        4: "Ą",
        5: "ą",
        6: "Ć",
        7: "ć",
        8: "Ĉ",
        9: "ĉ",
        10: "Ċ",
        11: "ċ",
        12: "Č",
        13: "č",
        14: "Ď",
        15: "ď",
        16: "Đ",
        17: "đ",
        18: "Ē",
        19: "ē",
        20: "Ĕ",
        21: "ĕ",
        22: "Ė",
        23: "ė",
        24: "Ę",
        25: "ę",
        26: "Ě",
        27: "ě",
        28: "Ĝ",
        29: "ĝ",
        30: "Ğ",
        31: "ğ",
        32: "Ġ",
        127: "ġ",
        128: "Ģ",
        129: "ģ",
        130: "Ĥ",
        131: "ĥ",
        132: "Ħ",
        133: "ħ",
        134: "Ĩ",
        135: "ĩ",
        136: "Ī",
        137: "ī",
        138: "Ĭ",
        139: "ĭ",
        140: "Į",
        141: "į",
        142: "İ",
        143: "ı",
        144: "IJ",
        145: "ij",
        146: "Ĵ",
        147: "ĵ",
        148: "Ķ",
        149: "ķ",
        150: "ĸ",
        151: "Ĺ",
        152: "ĺ",
        153: "Ļ",
        154: "ļ",
        155: "Ľ",
        156: "ľ",
        157: "Ŀ",
        158: "ŀ",
        159: "Ł",
        160: "ł",
        173: "Ń"
    }
    # Reverse of byte_to_unicode.
    unicode_to_bytes: dict[str, int] = {byte_to_unicode[b]: b for b in byte_to_unicode.keys()}

    with open(config_path, 'r') as f:
        config = json.loads(f.read())

    vocab: dict[str, int] = config['model']['vocab']
    merges: list[str] = config['model']['merges']
    added_tokens: dict[int, str] = {t['id']: t['content'] for t in config['added_tokens']}

    for added_token_id in added_tokens:
        encoded_added_token = ''.join([byte_to_unicode[b] for b in added_tokens[added_token_id].encode('utf-8')])
        vocab[encoded_added_token] = added_token_id

    vocab_reversed: dict[int, str] = {vocab[t]: t for t in vocab.keys()}
    # Replaces occurrences of subsequence a in lst with b, in place.
    def replace_subsequence(lst: list, a: list, b: list) -> None:
        for i in range(len(lst)):
            if lst[i:i + len(a)] == a:
                lst[i:i + len(a)] = b
    # Splits a string into word-like chunks using the GPT-2 style pre-tokenization regex.
    def split_words(s: str) -> list[str]:
        result: list[str] = []

        pattern: regex.Pattern = regex.compile(r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+")

        for m in regex.finditer(pattern, s):
            result.append(m.group())

        return result
    # Splits a string into a list of plain text parts (str) and added token ids (int),
    # scanning left to right for the earliest occurrence of any added token.
    def encode_added_tokens(s: str) -> list[Union[int, str]]:
        result: list[Union[int, str]] = []

        remainder: str = s

        while len(remainder) > 0:
            nearest_pos = len(remainder)
            nearest_token = -1

            for added_token_id in added_tokens:
                pos = remainder.find(added_tokens[added_token_id])

                if pos != -1 and pos < nearest_pos:
                    nearest_pos = pos
                    nearest_token = added_token_id

            if nearest_pos == len(remainder):
                result.append(remainder)
                break

            if nearest_pos != 0:
                result.append(remainder[:nearest_pos])

            result.append(nearest_token)

            remainder = remainder[nearest_pos + len(added_tokens[nearest_token]):]

        return result
    # Converts a string to a list of tokens.
    def encode(s: str) -> list[int]:
        s = unicodedata.normalize('NFC', s)

        result: list[int] = []

        for part in encode_added_tokens(s):
            # Added tokens are already token ids; pass them through as-is.
            if type(part) == int:
                result.append(part)
                continue

            for word in split_words(part):
                # Start from one token per byte of the word.
                tokens = [vocab[byte_to_unicode[b]] for b in word.encode('utf-8')]

                for added_token_id in added_tokens.keys():
                    added_token_tokens = [vocab[byte_to_unicode[b]] for b in added_tokens[added_token_id].encode('utf-8')]

                    replace_subsequence(tokens, added_token_tokens, [added_token_id])

                # Apply BPE merges in the order they are listed in the config.
                for merge in merges:
                    space = merge.find(' ')

                    assert space != -1

                    token_a = vocab[merge[0:space]]
                    token_b = vocab[merge[space + 1:]]
                    token_merged = vocab[merge[0:space] + merge[space + 1:]]

                    for i in range(len(tokens) - 1):
                        if i + 1 < len(tokens) and tokens[i] == token_a and tokens[i + 1] == token_b:
                            # Replace and shift
                            tokens[i] = token_merged
                            tokens[i + 1:] = tokens[i + 2:]

                result += tokens

        return result
    # Converts list of tokens to a string.
    def decode(tokens: list[int]) -> str:
        result = bytes()

        for token in tokens:
            result += bytes([unicode_to_bytes[c] for c in vocab_reversed[token]])

        return result.decode('utf-8')

    return encode, decode
# Code below is to test correctness of the tokenizer.
# It may safely be removed.
def test() -> None:
    config_path = r"./20B_tokenizer.json"

    encode, decode = load_tokenizer(config_path)

    from tokenizers import Tokenizer

    tokenizer = Tokenizer.from_file(config_path)

    # ---

    test_strings = [
        '\n a',
        # An ambiguous edge case: should tokenize into ["\n", " ~"], not ["\n ", "~"].
        # This test will fail unless the tokenizer splits words by the regex above.
        '\n ~',
        '\n \u597d',
        # Special tokens
        '-> <|endoftext|><|padding|> int',
        # Just some Unicode stuff
        'I\'ll \'d test блабла 以下は、]) -> <|endoftext|><|padding|> int',
        # This test will fail unless the tokenizer normalizes the input to NFC.
        "κόσμε"
    ]
    import random

    r = random.Random(42)

    for i in range(256):
        test_strings += [' ' * i]

    for i in range(256):
        x = chr(r.randrange(0, 256))
        x = x * r.randrange(1, 32)

        try:
            x.encode('utf-8')
            test_strings += [x]
        except:
            pass

    for i in range(256):
        x = chr(r.randrange(0, 1114112))
        x = x * r.randrange(1, 4)

        try:
            x.encode('utf-8')
            test_strings += [x]
        except:
            pass

    for test_string in test_strings:
        print()
        print(json.dumps(test_string))

        encoded_expected = tokenizer.encode(test_string).ids

        print('expect', encoded_expected)

        encoded_actual = encode(test_string)

        print('actual', encoded_actual)

        assert str(encoded_expected) == str(encoded_actual)

        decoded_actual = decode(encoded_actual)

        print(json.dumps(decoded_actual))

        assert unicodedata.normalize('NFC', test_string) == decoded_actual


if __name__ == '__main__':
    test()