read arxiv vectors: download every article vector from the arXiv MySQL database in chunks, with keyset pagination and automatic resume after dropped connections, then save the results as .npy files.
from nesta.core.orms.orm_utils import db_session, get_mysql_engine
from nesta.core.orms.arxiv_orm import ArticleVector
import numpy as np
import json
import os

# Point the nesta ORM utilities at the MySQL config file
os.environ['MYSQLDB'] = "/path/to/innovation-mapping-5712.config"


def query_and_bundle(session, fields, start, limit, filter_):
    """Fetch one chunk of rows and bundle the ids and vectors into numpy arrays."""
    q = session.query(*fields).order_by(ArticleVector.article_id)
    if filter_ is not None:
        q = q.filter(filter_)  # keyset pagination: cheaper than OFFSET on large tables
    else:
        q = q.offset(start)  # first chunk (or resume point): no cursor yet, so use OFFSET
    ids, vectors = zip(*q.limit(limit))
    return np.array(ids, dtype=np.dtype('U40')), np.array(vectors, dtype=np.float32)


def prefill_inputs():
    """Preallocate arrays large enough to hold every vector in the database."""
    engine = get_mysql_engine("MYSQLDB", "mysqldb", "production")
    with db_session(engine) as session:
        count = session.query(ArticleVector).count()
        a_vector, = session.query(ArticleVector.vector).limit(1).one()
        dim = len(a_vector)
    data = np.empty((count, dim), dtype=np.float32)
    # Zero-initialise (i.e. empty strings) rather than np.empty, since the
    # resume logic detects unfilled rows by comparing against ''
    ids = np.zeros((count,), dtype=np.dtype('U40'))
    return data, ids


def read_data(data, ids, chunksize=10000, start=None, max_chunks=None):
    """Fill `data` and `ids` chunk by chunk, resuming from the first unfilled row."""
    engine = get_mysql_engine("MYSQLDB", "mysqldb", "production")
    fields = (ArticleVector.article_id, ArticleVector.vector)
    count, _ = data.shape
    start = sum(ids != '') if start is None else start  # resume, or take the given value
    filter_ = None
    n_chunks = 0
    while start < count:
        if max_chunks is not None and n_chunks >= max_chunks:
            break
        if start % 100000 == 0:
            print("Collecting row", start)
        limit = chunksize if start + chunksize < count else None  # no limit on the final chunk
        with db_session(engine) as session:
            _ids, _data = query_and_bundle(session, fields, start, limit, filter_)
        filter_ = ArticleVector.article_id > _ids[-1]  # keyset cursor for the next chunk
        ids[start:start+_ids.shape[0]] = _ids
        data[start:start+_data.shape[0]] = _data
        start += chunksize
        n_chunks += 1


if __name__ == "__main__":
    data, ids = prefill_inputs()  # preallocated numpy arrays
    while "reading data":
        try:
            n = sum(ids != '')  # number of rows collected before the connection broke
            if n > 0:
                print("restarting from", n)
            read_data(data, ids)  # start or continue reading
        except json.JSONDecodeError:  # happens if the connection drops, corrupting the JSON payload
            continue  # retry, resuming from the first unfilled row
        else:
            break  # done
    np.save('arxiv_vectors.npy', data)
    np.save('arxiv_vectors_ids.npy', ids)
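
Once the script has run, the saved arrays can be queried without touching the database again. A minimal sketch, assuming the two .npy files above exist; `query_vector` is a hypothetical stand-in for any embedding with the same dimensionality as the stored vectors:

import numpy as np

data = np.load('arxiv_vectors.npy')      # shape (n_articles, dim), float32
ids = np.load('arxiv_vectors_ids.npy')   # shape (n_articles,), article ids

query_vector = data[0]  # hypothetical query: here just the first stored vector

# Rank all articles by cosine similarity to the query
# (the small epsilon guards against any all-zero rows)
norms = np.linalg.norm(data, axis=1) * np.linalg.norm(query_vector) + 1e-8
scores = (data @ query_vector) / norms
top = np.argsort(scores)[::-1][:5]       # five most similar articles
print(list(zip(ids[top], scores[top])))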