Last active
February 13, 2018 19:55
-
-
Save hiropppe/db7a721a40d594e01c0e098975424fe2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding:utf8 -*- | |
from __future__ import division | |
from __future__ import unicode_literals | |
import codecs | |
import json | |
import multiprocessing | |
import re | |
import sys | |
import unicodedata | |
from collections import defaultdict | |
from functools import partial | |
from gensim.corpora import wikicorpus | |
from itertools import imap | |
from multiprocessing.pool import Pool | |
from nltk import ngrams | |
from smart_open import smart_open | |
from tqdm import tqdm | |
# Matches strings that are empty or made up only of whitespace; U+3000
# (ideographic space) is listed explicitly in addition to ``\s``.
re_blank = re.compile(r'^[\s\u3000]*$')

# Python 2 only: wrap the byte-oriented standard streams so unicode text is
# encoded as UTF-8 when written.  Python 3 streams already accept text, and
# wrapping them would make every write fail with a TypeError.
if sys.version_info[0] < 3:
    sys.stdout = codecs.getwriter('utf8')(sys.stdout)
    sys.stderr = codecs.getwriter('utf8')(sys.stderr)
# Per-process cache of the mention vocabulary (a set of strings).  It is
# filled on first use so that a process which has not loaded it yet can
# read the mention file itself.
mentions = None


def get_mention_info(mention_file, line, max_span_len=9):
    """Find which known mentions occur in one JSON-encoded article.

    Args:
        mention_file: Path to a UTF-8 text file with one mention per line.
            It is read only when the module-level ``mentions`` cache has
            not been populated yet.
        line: One JSON document containing an ``interlinks`` mapping and a
            ``section_texts`` list of strings.
        max_span_len: Longest character span (inclusive) that is looked up
            in the section texts.  The default of 9 matches the limit that
            used to be hard-coded.

    Returns:
        A dict mapping each mention found in the article to ``1`` when it
        equals a value of ``interlinks``, or to ``0`` when it only occurs
        as a substring of a (stripped) section text.
    """
    global mentions
    if mentions is None:
        with codecs.open(mention_file, mode='r', encoding='utf8') as mf:
            # Strip only the line terminator: slicing off the last character
            # would truncate a final line that has no trailing newline.
            mentions = {m.rstrip('\r\n') for m in mf}
    known = mentions

    data = json.loads(line)
    ret = {}

    # Interlink values that are not in the vocabulary are skipped rather
    # than treated as an error.
    for anchor in data['interlinks'].values():
        if anchor in known:
            ret[anchor] = 1

    for text in data['section_texts']:
        text = text.strip()
        length = len(text)
        for start in range(length):
            longest = min(max_span_len, length - start)
            for size in range(1, longest + 1):
                span = text[start:start + size]
                # A span already recorded as an anchor keeps its flag of 1.
                if span in known and span not in ret:
                    ret[span] = 0
    return ret
def dump_mention_info(segment_wiki,
                      mention_file,
                      out,
                      pool_size=1,
                      chunk_size=10):
    """Count anchor vs. plain-text occurrences of every mention and write them.

    Each line of ``segment_wiki`` is passed to ``get_mention_info``; per
    mention, the number of lines where it was flagged as an anchor and the
    number where it was flagged as plain text are accumulated.

    Args:
        segment_wiki: Path of the dump, opened with ``smart_open`` and read
            one line (one JSON document) at a time.
        mention_file: Path to a UTF-8 text file with one mention per line.
        out: Writable text stream that receives one line per mention:
            ``mention<TAB>anchor_count<TAB>text_count<TAB>ratio`` where
            ``ratio = anchor_count / (anchor_count + text_count)`` with
            three decimals, or ``0.000`` when the mention never occurred.
        pool_size: Number of worker processes; values of 1 or less process
            the dump in the current process.
        chunk_size: ``chunksize`` passed to ``Pool.imap_unordered``.
    """
    global mentions
    with codecs.open(mention_file, mode='r', encoding='utf8') as mf:
        # Strip only the line terminator: slicing off the last character
        # would truncate a final line that has no trailing newline.
        mentions = {m.rstrip('\r\n') for m in mf}

    process_func = partial(get_mention_info, mention_file)
    text_entity_counter = defaultdict(int)
    anchor_entity_counter = defaultdict(int)

    pool = Pool(pool_size) if pool_size > 1 else None
    pbar = tqdm()
    try:
        with smart_open(segment_wiki) as lines:
            if pool is not None:
                results = pool.imap_unordered(
                    process_func, lines, chunksize=chunk_size)
            else:
                results = (process_func(line) for line in lines)

            for mention_info in results:
                for m, is_anchor in mention_info.items():
                    if is_anchor:
                        anchor_entity_counter[m] += 1
                    else:
                        text_entity_counter[m] += 1
                pbar.update(1)
    finally:
        pbar.close()
        if pool is not None:
            # All results have been consumed on the success path, so
            # terminating is safe; on error it stops outstanding work.
            pool.terminate()
            pool.join()

    for m in mentions:
        # .get() avoids inserting a key into the defaultdicts for every
        # mention that never occurred.
        anchor_count = anchor_entity_counter.get(m, 0)
        text_count = text_entity_counter.get(m, 0)
        total = anchor_count + text_count
        # True division: the file enables ``from __future__ import division``.
        ratio = anchor_count / total if total else 0.
        out.write('{:s}\t{:d}\t{:d}\t{:.3f}\n'.format(
            m, anchor_count, text_count, ratio))
def extract_interlinks(segment_wiki, out):
    """Write every interlink of the dump as a tab-separated line.

    Args:
        segment_wiki: Path of the dump, opened with ``smart_open``; each
            line must be a JSON document with an ``interlinks`` mapping.
        out: Writable text stream.  For every ``key: value`` pair of
            ``interlinks`` one line ``value<TAB>key`` is written.
    """
    # The context manager closes the dump even if a line fails to parse.
    with smart_open(segment_wiki) as lines:
        for line in tqdm(lines):
            data = json.loads(line)
            for entity, text in data['interlinks'].items():
                out.write(u'{:s}\t{:s}\n'.format(text, entity))
if __name__ == '__main__':
    # Command-line entry point: parse the options, pick the output stream
    # and dump the per-mention anchor/text statistics.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--segment_wiki_dump', '-s', type=str, default=None, required=True,
        help='Wikipedia segment_wiki dump file (json.gz)')
    parser.add_argument(
        '--mention', '-m', type=str, default=None, required=True,
        help='Mention file (txt)')
    parser.add_argument(
        '--pool_size', '-p', type=int, default=multiprocessing.cpu_count(), required=False,
        help='Number of worker processes (default: number of CPUs).')
    parser.add_argument(
        '--out', '-o', type=str, required=False,
        help='Output file.'
    )
    args = parser.parse_args()

    # Write to the given file, or to standard output when --out is omitted.
    if args.out:
        out = codecs.open(args.out, mode='w', encoding='utf8')
    else:
        out = sys.stdout

    try:
        # To dump the raw interlinks instead, call
        # extract_interlinks(args.segment_wiki_dump, out) here.
        dump_mention_info(args.segment_wiki_dump, args.mention, out, pool_size=args.pool_size)
    finally:
        # Only close what was opened here; never close standard output.
        if out is not sys.stdout:
            out.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment