Skip to content

Instantly share code, notes, and snippets.

@alksl
Created May 10, 2015 17:49
Show Gist options
  • Save alksl/642d5a908f89470666bb to your computer and use it in GitHub Desktop.
Save alksl/642d5a908f89470666bb to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
import os
import sys
import tables
import numpy as np
from scipy import sparse
from scipy import io
from collections import defaultdict
from multiprocessing import Queue, Process
# Number of row_worker processes started by the script.
NUM_WORKERS = 8

# Four positional arguments are required. Fail with a usage message rather
# than an IndexError traceback when any of them is missing.
if len(sys.argv) < 5:
    sys.exit(
        "usage: {0} ARTIST_TAGS_FILE WORDLIST_FILE ARTIST_IDS_FILE "
        "TERM_ARTIST_MATRIX_FILE".format(sys.argv[0])
    )

# All paths are made absolute once, up front.
ARTIST_TAGS_FILE = os.path.abspath(sys.argv[1])  # input, opened with PyTables
WORDLIST_FILE = os.path.abspath(sys.argv[2])  # output, written with savemat
ARTIST_IDS_FILE = os.path.abspath(sys.argv[3])  # path intended for the artist id list
TERM_ARTIST_MATRIX_FILE = os.path.abspath(sys.argv[4])  # output, written with savemat

print("ARTIST_TAGS_FILE: ", ARTIST_TAGS_FILE)
print("WORDLIST_FILE: ", WORDLIST_FILE)
print("ARTIST_IDS_FILE: ", ARTIST_IDS_FILE)
print("TERM_ARTIST_MATRIX_FILE: ", TERM_ARTIST_MATRIX_FILE)
print("NUM_WORKERS: ", NUM_WORKERS)
def create_reverse_mapping(original_list):
    """Return a dict mapping each value of *original_list* to its index.

    If a value occurs more than once, the index of its last occurrence
    is kept, because later entries overwrite earlier ones.
    """
    return {val: idx for idx, val in enumerate(original_list)}
def print_progress(index, total):
    """Rewrite the current terminal line with a ``Progress: index / total`` status.

    A carriage return moves the cursor back to the start of the line and
    the status is printed without a trailing newline, so successive calls
    overwrite each other.
    """
    print("\r", end="")
    # No newline is emitted, so a buffered stdout would hold the text back;
    # flush explicitly so the status is shown as soon as it is printed.
    print("Progress: ", index, "/", total, end="", flush=True)
# Build the lookup tables used by the worker and collector processes, and
# save the word list and artist id list. The context manager closes the
# HDF5 file even if one of the steps below raises.
with tables.open_file(ARTIST_TAGS_FILE, mode="r") as tags_file:
    print("Calculating artist track mapping")
    # artist id -> list of row indices into root.artist_id
    artist_to_track = defaultdict(list)
    for index, artist_id in enumerate(tags_file.root.artist_id):
        artist_to_track[artist_id].append(index)

    print("Getting unique artists")
    # Dict keys are already unique; the list position of an artist id is
    # its column index in the term/artist matrix.
    unique_artist_id = list(artist_to_track)
    reverse_artist_id = create_reverse_mapping(unique_artist_id)

    print("Getting wordlist")
    # Union of the en_terms "term" column and the mbtags "tag" column; the
    # list position of a term is its row index in the term/artist matrix.
    wordlist = list(
        set(tags_file.root.en_terms.cols.term).union(tags_file.root.mbtags.cols.tag)
    )
    reverse_wordlist = create_reverse_mapping(wordlist)

    print("Saving wordlist")
    io.savemat(WORDLIST_FILE, dict(artist_term_wordlist=wordlist))

    print("Saving artist id file")
    # The artist ids go to ARTIST_IDS_FILE, not to the input tags file path.
    io.savemat(ARTIST_IDS_FILE, dict(term_artist_id=unique_artist_id))
def collect_rows(collector_queue):
    """Gather per-artist results into the term/artist matrix and save it.

    Messages are read from *collector_queue* until a ``None`` sentinel is
    received. Every other message is an ``(artist_id, all_terms,
    term_freqs)`` tuple. For each term in ``all_terms``, the matrix cell at
    (term row, artist column) is set to the mean of ``term_freqs[term]``.
    Once the sentinel arrives the matrix is written to
    ``TERM_ARTIST_MATRIX_FILE`` with ``io.savemat``.
    """
    n_artists = len(unique_artist_id)
    n_terms = len(wordlist)
    matrix = sparse.lil_matrix((n_terms, n_artists), dtype=np.float64)
    print("Stating collector")

    received = 0
    message = collector_queue.get()
    while message is not None:
        artist_id, all_terms, term_freqs = message
        column = reverse_artist_id[artist_id]
        for term in all_terms:
            row = reverse_wordlist[term]
            samples = np.array(list(term_freqs[term]))
            matrix[row, column] = np.mean(samples)
        received += 1
        print_progress(received, n_artists)
        message = collector_queue.get()

    print("Saving matrix")
    io.savemat(TERM_ARTIST_MATRIX_FILE, dict(term_artist_matrix=matrix))
    print("Exiting collector")
def row_worker(queue, collector_queue, worker_index):
    """Compute term frequencies for each artist id received on *queue*.

    The worker opens its own read-only handle on ``ARTIST_TAGS_FILE`` and
    processes artist ids until it receives ``None``. For every track of an
    artist it collects the ``freq`` of each ``en_terms`` row and the
    ``count`` of each ``mbtags`` row matching the track id. Each mbtags tag
    then contributes its summed count divided by the artist's total mbtags
    count. The result is put on *collector_queue* as
    ``(artist_id, all_terms, term_freqs)``.

    *worker_index* is only used in the start/exit log lines.
    """
    print("Starting worker ", worker_index)
    with tables.open_file(ARTIST_TAGS_FILE, mode="r") as local_tags_file:
        root = local_tags_file.root
        artist_id = queue.get()
        while artist_id is not None:
            all_terms = set()
            term_freqs = defaultdict(list)
            mb_total = 0.0
            mb_counts = defaultdict(list)

            for track_index in artist_to_track[artist_id]:
                query = "track_id == {0}".format(root.track_id[track_index])
                for en_row in root.en_terms.where(query):
                    term = en_row['term']
                    all_terms.add(term)
                    term_freqs[term].append(en_row['freq'])
                for mb_row in root.mbtags.where(query):
                    tag = mb_row['tag']
                    all_terms.add(tag)
                    mb_counts[tag].append(mb_row['count'])
                    mb_total += mb_row['count']

            # NOTE(review): normalisation is assumed to run once per artist,
            # after all of the artist's tracks have been scanned - confirm.
            for tag, counts in mb_counts.items():
                term_freqs[tag].append(sum(counts) / mb_total)

            collector_queue.put((artist_id, all_terms, term_freqs))
            artist_id = queue.get()
    print("Exiting worker ", worker_index)
# Wire up the pipeline: one collector process that assembles the matrix and
# NUM_WORKERS row workers, each fed through its own input queue.
# NOTE(review): the worker and collector functions read module-level globals
# built above, which presumably relies on the "fork" start method - confirm
# before running on a platform that uses "spawn".
collector_queue = Queue()
collector = Process(target=collect_rows, args=(collector_queue,))
workers = []
for worker_index in range(NUM_WORKERS):
    queue = Queue()
    process = Process(target=row_worker, args=(queue, collector_queue, worker_index))
    workers.append((process, queue))

collector.start()
for process, _ in workers:
    process.start()

# Deal the artist ids out to the workers round-robin.
for ticking_index, artist_id in enumerate(unique_artist_id):
    workers[ticking_index % NUM_WORKERS][1].put(artist_id)

# One sentinel per worker tells it that no more artist ids will follow.
for _, queue in workers:
    queue.put(None)

for process, _ in workers:
    process.join()
print("All workers joined")

# Only stop the collector once every worker has exited, so that all worker
# results are queued ahead of the sentinel. Sending it earlier lets the
# collector quit and save the matrix before the results have arrived.
collector_queue.put(None)
collector.join()
print("Collector joined")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment