Skip to content

Instantly share code, notes, and snippets.

@alksl
Created July 17, 2015 12:08
Show Gist options
  • Save alksl/5b9cab996efc00079cce to your computer and use it in GitHub Desktop.
Save alksl/5b9cab996efc00079cce to your computer and use it in GitHub Desktop.
import sys
from collections import namedtuple
from collections import defaultdict
import numpy as np
import logging
from progress_bar import ProgressBar
# Record yielded by ArtistTermFreqIterator: one (artist, term) pair together
# with the mean of all frequencies collected for that pair.
ArtistTermFreq = namedtuple('ArtistTermFreq', ['artist_id', 'term', 'freq'])


class ArtistTermFreqIterator:
    """Iterate over the mean term frequency of every (artist, term) pair.

    On construction the iterator reads the ``en_terms`` and ``mbtags`` tables
    of the given file, keeps only the terms accepted by ``term_collection``
    and groups the resulting frequencies per artist and term. Iterating then
    yields one :class:`ArtistTermFreq` per (artist, term) pair, artist by
    artist.
    """

    # Totals handed to the progress bars; presumably the row counts of the
    # ``en_terms`` and ``mbtags`` tables -- TODO confirm against the data file.
    EN_TERMS_TOTAL = 28979585
    MBTAGS_TOTAL = 1147174

    def __init__(self, artist_term_file, term_collection):
        """Load and group all term frequencies.

        Args:
            artist_term_file: Object exposing ``root.track_id``,
                ``root.artist_id``, ``root.en_terms`` and ``root.mbtags``
                (looks like an open PyTables file -- TODO confirm).
            term_collection: Object with a ``lookup(term)`` method; terms for
                which ``lookup`` raises ``KeyError`` are skipped.
        """
        self.term_file = artist_term_file
        self.track_index = self._create_track_reverse_mapping()

        # artist_id -> term -> list of frequencies observed for that pair.
        self.artist_terms_freq = defaultdict(lambda: defaultdict(list))
        self._populate_from_en_terms(term_collection)
        self._populate_from_mbtags(term_collection)

        self.artist_ids = list(self.artist_terms_freq.keys())
        self.artist_index = 0
        self.artist_end_index = len(self.artist_ids)
        # Also initialises term_index / term_end_index for the first artist.
        self._set_terms()
        logging.info("Initialization complete")

    def __next__(self):
        """Python 3 iterator protocol; delegates to :meth:`next`."""
        return self.next()

    def __iter__(self):
        """Return ``self``; the object is its own (single-pass) iterator."""
        return self

    def next(self):
        """Return the next :class:`ArtistTermFreq`.

        Raises:
            StopIteration: When every term of every artist has been returned.
        """
        # Advance to the next artist that still has terms left. A loop (not
        # an ``if``) so artists with an empty term list are skipped as well.
        while self.term_index >= self.term_end_index:
            self.artist_index += 1
            if self.artist_index >= self.artist_end_index:
                raise StopIteration()
            self._set_terms()

        result = self._create_frequency(self.artist_index, self.term_index)
        # Increment only after reading so the first term is not skipped.
        self.term_index += 1
        return result

    def _set_terms(self):
        """Load the term list of the current artist and rewind the term cursor.

        When ``artist_index`` is past the last artist (e.g. no artists were
        collected at all) the term list is set to empty.
        """
        if self.artist_index < self.artist_end_index:
            artist_id = self.artist_ids[self.artist_index]
            self.terms = list(self.artist_terms_freq[artist_id].keys())
        else:
            self.terms = []
        self.term_index = 0
        self.term_end_index = len(self.terms)

    def _create_frequency(self, artist_index, term_index):
        """Build the record for one (artist, term) pair.

        Args:
            artist_index: Position of the artist in ``self.artist_ids``.
            term_index: Position of the term in ``self.terms``, which holds
                the terms of the artist selected by the last ``_set_terms``.

        Returns:
            ArtistTermFreq whose ``freq`` is the mean of all frequencies
            collected for the pair.
        """
        artist_id = self.artist_ids[artist_index]
        term = self.terms[term_index]
        freq = np.mean(np.array(self.artist_terms_freq[artist_id][term]))
        return ArtistTermFreq(artist_id, term, freq)

    def _create_track_reverse_mapping(self):
        """Return a dict mapping each decoded track id to its row index."""
        logging.info("Create track reverse mapping")
        return {
            track_id.decode("utf-8"): index
            for index, track_id in enumerate(self.term_file.root.track_id)
        }

    def _lookup_artist_id(self, track_id):
        """Return the decoded artist id stored at the row of ``track_id``.

        Raises:
            KeyError: If ``track_id`` is not present in the track index.
        """
        track_index = self.track_index[track_id]
        return self.term_file.root.artist_id[track_index].decode("utf-8")

    def _populate_from_en_terms(self, term_collection):
        """Collect the ``freq`` of every accepted row of ``en_terms``."""
        logging.info("Populate from EN terms")
        progress = ProgressBar(sys.stdout, self.EN_TERMS_TOTAL)
        for row in self.term_file.root.en_terms:
            term = row['term'].decode("utf-8")
            artist_id = self._lookup_artist_id(row['track_id'].decode("utf-8"))
            try:
                # Only the lookup is guarded: an unknown term is expected and
                # skipped, any other KeyError should surface.
                term_collection.lookup(term)
            except KeyError:
                pass
            else:
                self.artist_terms_freq[artist_id][term].append(row['freq'])
            progress.tick()

    def _populate_from_mbtags(self, term_collection):
        """Collect the counts of accepted ``mbtags`` rows as frequencies."""
        logging.info("Populate from mbtags")
        artist_mbtags_count = defaultdict(lambda: defaultdict(list))
        progress = ProgressBar(sys.stdout, self.MBTAGS_TOTAL)
        for row in self.term_file.root.mbtags:
            tag = row['tag'].decode("utf-8")
            artist_id = self._lookup_artist_id(row['track_id'].decode("utf-8"))
            try:
                term_collection.lookup(tag)
            except KeyError:
                pass
            else:
                # NOTE(review): the column name 'count' is assumed (the
                # original referenced an undefined name here) -- confirm
                # against the table schema.
                artist_mbtags_count[artist_id][tag].append(row['count'])
            progress.tick()

        logging.info("Convert mbtags count to freq")
        self._append_mbtag_freqs(artist_mbtags_count)

    def _append_mbtag_freqs(self, artist_mbtags_count):
        """Convert grouped tag counts to frequencies and store them.

        Args:
            artist_mbtags_count: Mapping ``artist_id -> tag -> list of counts``.
        """
        for artist_id, terms in artist_mbtags_count.items():
            for term, term_counts in terms.items():
                # NOTE(review): each count is normalised by the total of the
                # same tag's counts, as in the original code; verify that a
                # per-artist total across all tags was not intended.
                total_term_count = sum(term_counts)
                if not total_term_count:
                    # Nothing to normalise by; avoid a ZeroDivisionError.
                    continue
                for count in term_counts:
                    freq = float(count) / float(total_term_count)
                    self.artist_terms_freq[artist_id][term].append(freq)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment