Created
May 10, 2015 17:49
-
-
Save alksl/642d5a908f89470666bb to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import os | |
import sys | |
import tables | |
import numpy as np | |
from scipy import sparse | |
from scipy import io | |
from collections import defaultdict | |
from multiprocessing import Queue, Process | |
# Command-line configuration.  All four paths are required; fail fast with a
# clear usage message instead of an opaque IndexError when any is missing.
if len(sys.argv) < 5:
    sys.exit("usage: {0} ARTIST_TAGS_FILE WORDLIST_FILE ARTIST_IDS_FILE "
             "TERM_ARTIST_MATRIX_FILE".format(sys.argv[0]))

NUM_WORKERS = 8  # number of row_worker processes fed round-robin

ARTIST_TAGS_FILE = os.path.abspath(sys.argv[1])         # input HDF5 (pytables) file
WORDLIST_FILE = os.path.abspath(sys.argv[2])            # output .mat: term vocabulary
ARTIST_IDS_FILE = os.path.abspath(sys.argv[3])          # output .mat: artist id list
TERM_ARTIST_MATRIX_FILE = os.path.abspath(sys.argv[4])  # output .mat: sparse matrix

print("ARTIST_TAGS_FILE: ", ARTIST_TAGS_FILE)
print("WORDLIST_FILE: ", WORDLIST_FILE)
print("ARTIST_IDS_FILE: ", ARTIST_IDS_FILE)
print("TERM_ARTIST_MATRIX_FILE: ", TERM_ARTIST_MATRIX_FILE)
print("NUM_WORKERS: ", NUM_WORKERS)
def create_reverse_mapping(original_list):
    """Return a dict mapping each element of *original_list* to its position."""
    mapping = {}
    for position, element in enumerate(original_list):
        mapping[element] = position
    return mapping
def print_progress(index, total):
    """Redraw an ``index / total`` progress counter in place on the current line."""
    # \r rewinds the cursor so each update overwrites the previous counter.
    print("\rProgress: ", index, "/", total, end="")
tags_file = tables.open_file(ARTIST_TAGS_FILE, mode="r")

print("Calculating artist track mapping")
# Group row indices of the tags file by the artist each track belongs to.
artist_to_track = defaultdict(list)
for index, artist_id in enumerate(tags_file.root.artist_id):
    artist_to_track[artist_id].append(index)

print("Getting unique artists")
unique_artist_id = list(set(artist_to_track.keys()))
reverse_artist_id = create_reverse_mapping(unique_artist_id)

print("Getting wordlist")
# The vocabulary is the union of the Echo Nest terms and the MusicBrainz tags.
wordlist = list(set(tags_file.root.en_terms.cols.term).union(set(tags_file.root.mbtags.cols.tag)))
reverse_wordlist = create_reverse_mapping(wordlist)

print("Saving wordlist")
io.savemat(WORDLIST_FILE, dict(artist_term_wordlist = wordlist))

print("Saving artist id file")
# BUG FIX: this previously wrote to ARTIST_TAGS_FILE, clobbering the input
# HDF5 file (opened read-only above).  The artist id list belongs in
# ARTIST_IDS_FILE, which was parsed from argv but never used.
io.savemat(ARTIST_IDS_FILE, dict(term_artist_id = unique_artist_id))

tags_file.close()
def collect_rows(collector_queue):
    """Collector process body: assemble and save the term x artist matrix.

    Consumes ``(artist_id, all_terms, term_freqs)`` messages from
    *collector_queue*, writing the mean per-track frequency of each term into
    a sparse matrix indexed by the module-level reverse mappings.  Terminates
    on a ``None`` sentinel, then saves the matrix to TERM_ARTIST_MATRIX_FILE.
    """
    artist_count = len(unique_artist_id)
    # lil_matrix supports cheap incremental element assignment; savemat
    # converts it to a compressed sparse format on write.
    term_artist_matrix = sparse.lil_matrix((len(wordlist), artist_count), dtype=np.float64)
    print("Starting collector")  # fixed typo: was "Stating collector"
    i = 0
    while True:
        msg = collector_queue.get()
        if msg is None:
            break
        artist_id, all_terms, term_freqs = msg
        artist_index = reverse_artist_id[artist_id]
        for term in all_terms:
            term_index = reverse_wordlist[term]
            # Average the per-track frequency values for this (term, artist).
            term_artist_matrix[term_index, artist_index] = np.mean(term_freqs[term])
        i += 1
        print_progress(i, artist_count)
    print("Saving matrix")
    io.savemat(TERM_ARTIST_MATRIX_FILE, dict(term_artist_matrix = term_artist_matrix))
    print("Exiting collector")
def row_worker(queue, collector_queue, worker_index):
    """Worker process body: compute per-artist term frequencies.

    Pulls artist ids from *queue* (``None`` is the shutdown sentinel), looks up
    the artist's tracks via the module-level ``artist_to_track`` mapping, and
    sends ``(artist_id, all_terms, term_freqs)`` to *collector_queue*.
    Each worker opens its own read-only handle to the HDF5 file, since
    pytables handles cannot be shared across processes.
    """
    print("Starting worker ", worker_index)
    with tables.open_file(ARTIST_TAGS_FILE, mode="r") as local_tags_file:
        while True:
            msg = queue.get()
            if msg is None:
                break
            artist_id = msg
            all_terms = set()            # every term/tag seen for this artist
            term_freqs = defaultdict(list)   # term -> per-track freq values
            total_mbtags_term_count = 0.0    # normalizer for mbtag counts
            term_counts = defaultdict(list)  # mbtag -> per-track raw counts
            for track_index in artist_to_track[artist_id]:
                # NOTE(review): the track id is interpolated verbatim into a
                # numexpr condition; this assumes its repr is a valid numexpr
                # literal matching the stored column values — verify against
                # the file's track_id dtype.
                query = "track_id == {0}".format(local_tags_file.root.track_id[track_index])
                # Echo Nest terms come with a precomputed frequency.
                for en_row in local_tags_file.root.en_terms.where(query):
                    all_terms.add(en_row['term'])
                    term_freqs[en_row['term']].append(en_row['freq'])
                # MusicBrainz tags only have raw counts; accumulate them and
                # normalize below.
                for mb_row in local_tags_file.root.mbtags.where(query):
                    all_terms.add(mb_row['tag'])
                    term_counts[mb_row['tag']].append(mb_row['count'])
                    total_mbtags_term_count += mb_row['count']
            # Convert mbtag counts to frequencies relative to the artist's
            # total tag count (loop is skipped when no mbtags were found, so
            # the zero divisor cannot be hit).
            for term, count in term_counts.items():
                term_freqs[term].append(sum(count)/total_mbtags_term_count)
            collector_queue.put((artist_id, all_terms, term_freqs))
    print("Exiting worker ", worker_index)
# --- process orchestration --------------------------------------------------
# One collector process aggregates results from NUM_WORKERS worker processes;
# each worker has its own input queue and artists are dealt out round-robin.
collector_queue = Queue()
collector = Process(target=collect_rows, args=(collector_queue,))

workers = []
for worker_index in range(NUM_WORKERS):
    queue = Queue()
    process = Process(target=row_worker, args=(queue, collector_queue, worker_index))
    workers.append((process, queue))

collector.start()
for process, _ in workers:
    process.start()

for ticking_index, artist_id in enumerate(unique_artist_id):
    workers[ticking_index % NUM_WORKERS][1].put(artist_id)

# Tell each worker to finish once its queue drains, then wait for all of them.
for _, queue in workers:
    queue.put(None)
for process, _ in workers:
    process.join()
print("All workers joined")

# BUG FIX: the collector's shutdown sentinel must be sent only AFTER every
# worker has joined.  Previously it was enqueued before the joins, so it could
# race ahead of in-flight worker results in the FIFO collector queue, making
# the collector exit early and silently drop those artists' rows.
collector_queue.put(None)
collector.join()
print("Collector joined")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment