@aboSamoor
Last active December 23, 2015 06:59
{
"metadata": {
"name": ""
},
"nbformat": 3,
"nbformat_minor": 0,
"worksheets": [
{
"cells": [
{
"cell_type": "code",
"collapsed": false,
"input": [
"from io import open\n",
"from glob import glob\n",
"from numpy import asarray, dot"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 2
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"files = glob('/media/jarra/word2vec/Word2VecExample/vectors/*')\n",
"words = []\n",
"embeddings = []\n",
"for file in files:\n",
" for line in open(file):\n",
" ws = line.strip().split()\n",
" if len(ws) != 251:\n",
" continue\n",
" words.append(ws[0])\n",
" embeddings.append([float(x) for x in ws[1:]])\n",
"embeddings = asarray(embeddings)"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 54
},
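The loading loop above (one word followed by 250 floats per line, skipping malformed lines) can be sketched as a standalone function. This is a minimal sketch; the 3-dimensional toy vectors and the `parse_vectors` name are illustrative, not part of the notebook:

```python
import numpy as np

def parse_vectors(lines, dim=250):
    """Parse word2vec-style text lines: a word followed by `dim` floats.

    Lines with the wrong number of fields (e.g. truncated writes) are
    skipped, mirroring the `len(ws) != 251` guard in the cell above.
    """
    words, vecs = [], []
    for line in lines:
        ws = line.strip().split()
        if len(ws) != dim + 1:
            continue
        words.append(ws[0])
        vecs.append([float(x) for x in ws[1:]])
    return words, np.asarray(vecs)

words, vecs = parse_vectors(["apple 0.1 0.2 0.3", "broken 0.1"], dim=3)
```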
{
"cell_type": "code",
"collapsed": false,
"input": [
"from sklearn.neighbors import NearestNeighbors\n",
"\n",
"knn = NearestNeighbors(n_neighbors=20, algorithm='ball_tree', p=2)\n",
"knn.fit(embeddings)\n",
"\n",
"word_id = {w:i for i, w in enumerate(words)}\n",
"id_word = {i:w for i, w in enumerate(words)}"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 74
},
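A minimal sketch of the same `NearestNeighbors` query on toy 2-D points. Note that recent scikit-learn versions require a 2-D array for `kneighbors` (the notebook's era accepted a bare vector), so the single query point is reshaped here:

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

# three toy 2-D "embeddings"
pts = np.array([[0.0, 0.0], [1.0, 0.0], [5.0, 5.0]])
knn = NearestNeighbors(n_neighbors=2, algorithm='ball_tree').fit(pts)

# query with the first point; reshape(1, -1) makes it a 1-row matrix
dist, idx = knn.kneighbors(pts[0].reshape(1, -1))
```

As in the results below, the query point itself comes back first with distance 0.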
{
"cell_type": "code",
"collapsed": false,
"input": [
"def get_neighbors(word):\n",
" id_ = word_id[word]\n",
" point = embeddings[id_]\n",
" get_knn_point(point)\n",
"\n",
"def get_knn_point(point):\n",
" distances, indices = knn.kneighbors(point)\n",
" for i, (index, d) in enumerate(zip(indices[0], distances[0])):\n",
" print '{: <3}{: <20}{: <10}'.format(i, id_word[index], d)\n",
" \n",
"def get_knn_phrase(phrase):\n",
" accumulator = get_phrase_point(phrase)\n",
" get_knn_point(accumulator)\n",
" \n",
"def get_phrase_point(phrase):\n",
" accumulator = 0\n",
" for word in phrase.strip().split():\n",
" accumulator += embeddings[word_id[word]]\n",
" return accumulator"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 82
},
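The accumulator in `get_phrase_point` relies on NumPy broadcasting: `0 + vector` yields a vector. A self-contained toy version (the 2-D embeddings and two-word vocabulary here are made up for illustration):

```python
import numpy as np

word_id = {"very": 0, "good": 1}
embeddings = np.array([[1.0, 2.0],
                       [3.0, 4.0]])

def phrase_point(phrase):
    # start from scalar 0; the first += turns it into a vector
    accumulator = 0
    for word in phrase.strip().split():
        accumulator += embeddings[word_id[word]]
    return accumulator

point = phrase_point("very good")
```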
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Nearest Neighbours of words\n",
"============================\n",
"\n",
"This is a typical experiment where we sort the words by their Euclidean distance.\n",
"\n",
"**Observations**\n",
"\n",
"* apple is mainly a company in these embeddings.\n",
"\n",
"* book and books are close to each other. The space is partitioned by meaning, not by part of speech.\n",
"\n",
"* The neighbors of \"the\" are not quite expected."
]
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"get_neighbors('apple')"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"0 apple 0.0 \n",
"1 ibook 2.45441387259\n",
"2 ilife 2.6324920856\n",
"3 iifx 2.68339841114\n",
"4 laserwriter 2.68894626153\n",
"5 macintoshes 2.72100867622\n",
"6 macintosh 2.7465602912\n",
"7 hypercard 2.74976202767\n",
"8 emac 2.76713004891\n",
"9 pcjr 2.76880134334\n",
"10 iix 2.77605432771\n",
"11 filemaker 2.78940634298\n",
"12 iigs 2.80644647114\n",
"13 macwrite 2.82640246101\n",
"14 appleworks 2.8268557242\n",
"15 iici 2.82994404995\n",
"16 ultrix 2.83341137466\n",
"17 desknote 2.8353197512\n",
"18 idvd 2.84045199398\n",
"19 garageband 2.84403010264\n"
]
}
],
"prompt_number": 86
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"get_neighbors('january')"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"0 january 0.0 \n",
"1 march 0.990577776379\n",
"2 february 0.995405360592\n",
"3 december 1.00770528005\n",
"4 april 1.06711900891\n",
"5 november 1.12036214401\n",
"6 october 1.15229411791\n",
"7 july 1.20908610312\n",
"8 august 1.22016681435\n",
"9 september 1.33152866717\n",
"10 june 1.4797197978\n",
"11 one 1.80048644633\n",
"12 nine 1.8731231034\n",
"13 eight 1.95164583491\n",
"14 two 1.9520374482\n",
"15 seven 1.95837319942\n",
"16 six 1.97104882071\n",
"17 four 2.01665230496\n",
"18 ogumahideo 2.02079829584\n",
"19 five 2.02127420304\n"
]
}
],
"prompt_number": 87
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"get_neighbors('the')"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"0 the 0.0 \n",
"1 omobe 1.11880177048\n",
"2 zandramas 1.1371528579\n",
"3 nyabinghi 1.13799240546\n",
"4 smilga 1.13822478466\n",
"5 armisael 1.13952230814\n",
"6 braiterman 1.14003221416\n",
"7 hoshanah 1.14125947399\n",
"8 azetbur 1.14194853261\n",
"9 kollupitiya 1.14295614341\n",
"10 librian 1.14349397418\n",
"11 grunitzky 1.14365235389\n",
"12 hokage 1.14478282816\n",
"13 zerchi 1.14522446395\n",
"14 sercia 1.14551989932\n",
"15 besigye 1.14561000423\n",
"16 kentrat 1.14628596163\n",
"17 boedromion 1.14709324335\n",
"18 morlun 1.14856902356\n",
"19 asrar 1.14892649238\n"
]
}
],
"prompt_number": 88
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"get_neighbors('book')"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"0 book 0.0 \n",
"1 books 1.88765684011\n",
"2 variorum 1.90695920674\n",
"3 miscellanies 1.91086332941\n",
"4 blackly 1.91939505269\n",
"5 reprinting 1.92737927322\n",
"6 chapbook 1.93713267613\n",
"7 aryabhatiya 1.93724649155\n",
"8 unexpurgated 1.94173677311\n",
"9 helaman 1.95698111353\n",
"10 loompanics 1.9578002733\n",
"11 invisibles 1.96178201394\n",
"12 jasher 1.96534553896\n",
"13 mythadventures 1.96677258553\n",
"14 disneywar 1.97259642784\n",
"15 pallisers 1.97346406837\n",
"16 ripliad 1.9841095887\n",
"17 forewords 1.98496842113\n",
"18 epigraph 1.98529029477\n",
"19 barsetshire 1.98571507086\n"
]
}
],
"prompt_number": 102
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Nearest Neighbours of phrases\n",
"==============================\n",
"Here, I sum up the vectors of the phrase and then look at the nearest neighbours.\n",
"\n",
"**Observations**\n",
"\n",
"* The words that make up the phrase are still the closest to the phrase.\n",
"* Notice that not all the words have the same influence: in the phrase \"second month of the year\", only {second, month, year} are close to the phrase; it does not seem that {of, the} changed the phrase embedding that much."
]
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"get_knn_phrase(\"very good\")"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"0 very 2.02757268595\n",
"1 good 2.14511392\n",
"2 quite 2.41730050239\n",
"3 extremely 2.63383599952\n",
"4 too 2.63744652432\n",
"5 fairly 2.65976621896\n",
"6 bad 2.79797117654\n",
"7 enough 2.8268833239\n",
"8 relatively 2.87095658184\n",
"9 exceedingly 2.87283047902\n",
"10 natured 2.87803582614\n",
"11 comparatively 2.87915718488\n",
"12 extraordinarily 2.92273364493\n",
"13 remarkably 2.99119514728\n",
"14 reasonably 2.99165171247\n",
"15 poor 2.99549033148\n",
"16 tough 2.99972125818\n",
"17 surprisingly 3.00423470826\n",
"18 exceptionally 3.00542705902\n",
"19 so 3.01978297284\n"
]
}
],
"prompt_number": 89
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"get_knn_phrase(\"not good\")"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"0 not 2.02757268595\n",
"1 good 2.22214639454\n",
"2 never 2.86305661863\n",
"3 indeed 2.8889430765\n",
"4 bad 2.91523993861\n",
"5 whatever 2.92094027292\n",
"6 nothing 2.92411460062\n",
"7 so 2.93375335615\n",
"8 always 2.95124772163\n",
"9 really 2.98023192259\n",
"10 actually 2.98412447259\n",
"11 therefore 2.9869738569\n",
"12 anyway 2.99071470444\n",
"13 only 2.99642716642\n",
"14 merely 2.99831283526\n",
"15 something 3.00975725443\n",
"16 otherwise 3.02370475405\n",
"17 perfectly 3.02575200851\n",
"18 cannot 3.04011813612\n",
"19 thing 3.04675479849\n"
]
}
],
"prompt_number": 90
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"get_knn_phrase(\"not bad\")"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"0 bad 2.22214639454\n",
"1 not 2.41156583302\n",
"2 good 2.74208441996\n",
"3 really 3.05039139775\n",
"4 never 3.07487082936\n",
"5 wrong 3.10786612541\n",
"6 nothing 3.11383084188\n",
"7 indeed 3.11558700764\n",
"8 something 3.11843176179\n",
"9 ugly 3.13169091161\n",
"10 actually 3.13367017276\n",
"11 anyway 3.1372355043\n",
"12 so 3.15331458279\n",
"13 simply 3.16526963427\n",
"14 thing 3.19701706287\n",
"15 whatever 3.19772031188\n",
"16 merely 3.19839313588\n",
"17 luck 3.20110591184\n",
"18 obviously 3.20233166983\n",
"19 it 3.20701670968\n"
]
}
],
"prompt_number": 91
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"get_knn_phrase(\"extremely bad\")"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"0 extremely 2.41156583302\n",
"1 bad 2.43449467574\n",
"2 very 2.83048791316\n",
"3 too 3.01777278005\n",
"4 quite 3.02441198258\n",
"5 good 3.07320611164\n",
"6 ugly 3.07887115698\n",
"7 incredibly 3.15114734449\n",
"8 fairly 3.17279002574\n",
"9 extraordinarily 3.18893680906\n",
"10 relatively 3.20866241465\n",
"11 tough 3.21635598437\n",
"12 exceedingly 3.21955500629\n",
"13 overly 3.23500915437\n",
"14 dangerous 3.25019820497\n",
"15 natured 3.27363652607\n",
"16 unpredictable 3.28204861682\n",
"17 exceptionally 3.29938492336\n",
"18 unusually 3.29947611282\n",
"19 poor 3.3073176501\n"
]
}
],
"prompt_number": 92
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"get_knn_phrase(\"second month of the year\")"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"0 year 5.0029650795\n",
"1 month 5.04977245949\n",
"2 week 5.58140471293\n",
"3 second 5.6983698843\n",
"4 months 5.74828646333\n",
"5 days 5.78854236335\n",
"6 fifth 5.82172430487\n",
"7 day 5.8519089308\n",
"8 fourth 5.85947580542\n",
"9 third 5.92042757004\n",
"10 sixth 5.92176988188\n",
"11 weeks 6.0653731202\n",
"12 eighth 6.08155219206\n",
"13 elul 6.08231423041\n",
"14 ten 6.08644369931\n",
"15 sabbatical 6.09076069834\n",
"16 decade 6.1035932397\n",
"17 fortnight 6.10438356555\n",
"18 last 6.10476745342\n",
"19 intercalary 6.10619566139\n"
]
}
],
"prompt_number": 93
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"get_knn_phrase(\"the month after january\")"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"0 month 3.98166400675\n",
"1 january 4.20514862572\n",
"2 december 4.33863770453\n",
"3 february 4.35422865987\n",
"4 march 4.35696733076\n",
"5 april 4.35786995855\n",
"6 months 4.37816463422\n",
"7 november 4.38389342062\n",
"8 days 4.42426412924\n",
"9 july 4.43140140413\n",
"10 october 4.43658944455\n",
"11 september 4.44553008295\n",
"12 year 4.4474416121\n",
"13 august 4.44804864964\n",
"14 week 4.47481899869\n",
"15 after 4.47831544409\n",
"16 june 4.60817926598\n",
"17 before 4.6250205963\n",
"18 weeks 4.68978529386\n",
"19 day 4.7146339223\n"
]
}
],
"prompt_number": 94
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Decompose a sentence\n",
"=====================\n",
"\n",
"Given that a phrase is the sum of the vectors of its words, to reverse the procedure we have to solve the linear system\n",
"```python\n",
"Phrase = dot(Occurrence Matrix, Embeddings Matrix)\n",
"Occurrence Matrix = dot(Phrase, Embedding Matrix Inverse)\n",
"```\n",
"Therefore, we need to calculate the pseudo-inverse of the embeddings matrix and look at the largest components of the solution.\n",
"\n",
"I am surprised it works really well! Of course, it is harder to decide at which component we should stop. However, we can build a partial solution, compare it to the sentence, and stop when we start diverging from the sentence :).\n"
]
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"from numpy.linalg import pinv\n",
"\n",
"embeddings_inverse = pinv(embeddings)"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 34
},
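A toy version of the decomposition idea, small enough to check by hand. With three "words" whose third embedding is the sum of the first two, the pseudo-inverse still recovers exact coefficients for a phrase built from words 0 and 2 (the matrix here is illustrative, not from the notebook's data):

```python
import numpy as np

# toy 2-D embeddings for three "words"; row 2 = row 0 + row 1
E = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0]])

phrase = E[0] + E[2]                 # "phrase" = word 0 + word 2
coeffs = phrase @ np.linalg.pinv(E)  # least-squares occurrence vector
top = np.argsort(coeffs)[::-1]       # largest components first
```

Here `coeffs` comes out exactly `[1, 0, 1]`, so the two constituent words are the top components, as in the `decompose` outputs below.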
{
"cell_type": "code",
"collapsed": false,
"input": [
"def decompose(phrase):\n",
" words = dot(get_phrase_point(phrase), embeddings_inverse)\n",
" print '\\n'.join(['{: <3}{: <20}{: <10}'.format(i, id_word[x], words[x]) for i, x in enumerate(reversed(words.argsort()[-15:]))])"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 98
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"decompose('not good')"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"0 not 0.00454078592865\n",
"1 good 0.00289725566233\n",
"2 couples 0.00276267718622\n",
"3 hth 0.00239669137487\n",
"4 bad 0.00231876458505\n",
"5 neither 0.00231822672031\n",
"6 never 0.00226820676811\n",
"7 nowrap 0.00212882344283\n",
"8 nor 0.00210485322845\n",
"9 nothing 0.00202502414415\n",
"10 cannot 0.00196498505034\n",
"11 poorly 0.00192701451615\n",
"12 luck 0.00189786580745\n",
"13 unable 0.00188411258273\n",
"14 dornier 0.00184270732512\n"
]
}
],
"prompt_number": 99
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"decompose('second month of the year')"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"0 year 0.00818805499276\n",
"1 month 0.00751032803789\n",
"2 gregorian 0.00636465953281\n",
"3 calendar 0.00546181072521\n",
"4 second 0.00539502185446\n",
"5 months 0.00538200401731\n",
"6 week 0.0052643120389\n",
"7 equinox 0.00497019081682\n",
"8 fifth 0.00486776245734\n",
"9 vernal 0.00484171622925\n",
"10 years 0.00478449424801\n",
"11 days 0.00476471815298\n",
"12 day 0.00470655757576\n",
"13 fourth 0.00468227683716\n",
"14 staggered 0.00467989278889\n"
]
}
],
"prompt_number": 100
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"decompose('president of united states is barrack obama')"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"0 united 0.0117301004546\n",
"1 states 0.0112717923368\n",
"2 wisconsinaccording 0.0100026352853\n",
"3 bureau 0.00920992570253\n",
"4 kingdom 0.00792875990308\n",
"5 naturalized 0.00737911901475\n",
"6 lyndon 0.00713012544686\n",
"7 nominees 0.0066013885303\n",
"8 retrospectively 0.00656759081726\n",
"9 president 0.00651313494474\n",
"10 tempore 0.0064619825011\n",
"11 emirates 0.00613937108443\n",
"12 vice 0.00582700772234\n",
"13 senate 0.00565314532246\n",
"14 presidential 0.00559438495095\n"
]
}
],
"prompt_number": 101
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Influence of Frequency on Vector Norm\n",
"======================================\n",
"\n",
"The results do not show anything obvious; again, we do not know how these vectors were trained."
]
},
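The norm computation in the next cell is the row-wise Euclidean norm, equivalent to `np.linalg.norm(..., axis=1)`; a tiny sanity check with made-up vectors:

```python
import numpy as np

v = np.array([[3.0, 4.0],
              [1.0, 0.0]])
norms = (v ** 2).sum(axis=1) ** 0.5  # row-wise L2 norm: [5.0, 1.0]
```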
{
"cell_type": "code",
"collapsed": false,
"input": [
"norms = (embeddings ** 2).sum(axis=1) ** 0.5\n",
"indices = norms.argsort()\n",
"print \"Largest vectors\"\n",
"print '\\n'.join(['{: <3}{: <20}{: <10}'.format(i, id_word[x], norms[x]) for i, x in enumerate(reversed(indices[-50:]))])\n",
"\n",
"print \"\\n\\nSmallest vectors\"\n",
"print '\\n'.join(['{: <3}{: <20}{: <10}'.format(i, id_word[x], norms[x]) for i, x in enumerate(indices[:50])])\n",
"\n",
"\n"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"Largest vectors\n",
"0 multilicensewithcc 10.3268512519\n",
"1 jdforrester 10.0154157796\n",
"2 hth 10.0003442263\n",
"3 uploads 9.7868881845\n",
"4 wapcaplet 9.67973261777\n",
"5 retrospectively 9.56892967102\n",
"6 householder 9.33118443663\n",
"7 islander 9.05014823454\n",
"8 chordata 8.93656075554\n",
"9 classis 8.76696996107\n",
"10 namespace 8.6128387668\n",
"11 latino 8.5979219294\n",
"12 ordo 8.17901131632\n",
"13 hereby 8.15611378317\n",
"14 familia 7.95104060729\n",
"15 attribution 7.9373325256\n",
"16 jpeg 7.91507147932\n",
"17 households 7.67802718897\n",
"18 residing 7.52167119742\n",
"19 hispanic 7.48807752896\n",
"20 capita 7.47718165488\n",
"21 mathbf 7.45371968805\n",
"22 animalia 7.45246890654\n",
"23 cdot 7.42279401665\n",
"24 makeup 7.41894034817\n",
"25 couples 7.34979126253\n",
"26 frac 7.16273976919\n",
"27 rangle 7.094790385\n",
"28 bgcolor 7.08054397099\n",
"29 wisconsinaccording 6.96499767777\n",
"30 qquad 6.91234108173\n",
"31 licence 6.88492717492\n",
"32 mammalia 6.88148760443\n",
"33 phylum 6.87925932348\n",
"34 photographed 6.83425987876\n",
"35 cdots 6.79569890312\n",
"36 aves 6.79556269479\n",
"37 mbox 6.73414158759\n",
"38 median 6.72063447693\n",
"39 mathrm 6.70333544242\n",
"40 nowrap 6.69788441221\n",
"41 eeeeaa 6.60665266213\n",
"42 kommunedata 6.57814991347\n",
"43 courtesy 6.57624628835\n",
"44 morwen 6.55818726435\n",
"45 langle 6.55069524732\n",
"46 fdl 6.49573204626\n",
"47 regnum 6.48539039332\n",
"48 kmd 6.47439187599\n",
"49 template 6.47184635727\n",
"\n",
"\n",
"Smallest vectors\n",
"0 tapairu 0.0170473963115\n",
"1 forestburg 0.0174300506597\n",
"2 wilcoxes 0.0175761717959\n",
"3 ramn 0.0176507450551\n",
"4 rotaviruses 0.0177284992033\n",
"5 chahinkapa 0.0178666974564\n",
"6 samogon 0.0179913211299\n",
"7 blinne 0.0180032805344\n",
"8 woldumar 0.0180729706745\n",
"9 kobach 0.0182432231801\n",
"10 lifferth 0.0182973014677\n",
"11 cityofgoleta 0.0184365699901\n",
"12 takere 0.0185601906779\n",
"13 bgct 0.0186364083986\n",
"14 strataflash 0.018769974667\n",
"15 dicephalus 0.0188615523221\n",
"16 </s> 0.0188826574136\n",
"17 varima 0.018995082232\n",
"18 kruczak 0.0190906993324\n",
"19 tusten 0.0215875820554\n",
"20 wurtsboro 0.0224239146449\n",
"21 narrowsburg 0.0243615605617\n",
"22 cochecton 0.0245859408402\n",
"23 braeswood 0.0249381597356\n",
"24 armaou 0.0265067232792\n",
"25 ardenti 0.0274897423233\n",
"26 shandaken 0.0279073472226\n",
"27 takrur 0.0279678326118\n",
"28 allomothers 0.0302510569072\n",
"29 covertext 0.0303092984907\n",
"30 poliphilo 0.0348113839282\n",
"31 yelabuga 0.0358172907267\n",
"32 ictineu 0.0380123283949\n",
"33 epfcg 0.0385076033401\n",
"34 spinn 0.0386281221651\n",
"35 rifton 0.040149943462\n",
"36 kenu 0.0405585154561\n",
"37 hakudoshi 0.0408182057298\n",
"38 crumhorns 0.041133521804\n",
"39 schmiedeleut 0.0411884093648\n",
"40 lehrerleut 0.0425683737181\n",
"41 rublei 0.0435910749466\n",
"42 surrett 0.0443516885361\n",
"43 zaires 0.0451846235129\n",
"44 loveppears 0.0453572166033\n",
"45 prasanthi 0.0455711129994\n",
"46 ritonavir 0.0455812074983\n",
"47 polytechnicien 0.0462011677125\n",
"48 kuebler 0.0469128573314\n",
"49 joydev 0.0471839763585\n"
]
}
],
"prompt_number": 81
}
],
"metadata": {}
}
]
}