Skip to content

Instantly share code, notes, and snippets.

@alksl
Created July 17, 2015 11:07
Show Gist options
  • Save alksl/b0228fe2d84d1a55994c to your computer and use it in GitHub Desktop.
Save alksl/b0228fe2d84d1a55994c to your computer and use it in GitHub Desktop.
from collections import defaultdict
from collections import namedtuple

import numpy as np
# Record yielded by ArtistTermFreqIterator: one (artist, term) pair together
# with the mean of the frequencies collected for that pair.
ArtistTermFreq = namedtuple('ArtistTermFreq', ['artist_id', 'term', 'freq'])


class ArtistTermFreqIterator:
    """Iterate over the mean frequency of every term attached to every artist.

    On construction the whole term file is read once and every accepted term
    is grouped per artist.  Iterating then yields one ``ArtistTermFreq`` per
    (artist, term) pair, artist by artist, in the order the pairs were first
    seen (``en_terms`` rows first, then ``mbtags`` rows).

    Args:
        artist_term_file: Object exposing ``artist_id`` and, under ``root``,
            ``track_id``, ``en_terms`` and ``mbtags``.  Ids, terms and tags
            are expected to be UTF-8 encoded ``bytes``.
            NOTE(review): presumably a PyTables file -- confirm.  ``artist_id``
            is read from the file object itself while every other node is read
            from ``root``; verify that this asymmetry is intended.
        artist_collection: Iterable of artist ids (assumed to be decoded
            ``str`` values -- TODO confirm) for which the artist -> track ids
            mapping is built.
        term_collection: Object whose ``lookup(term)`` raises ``KeyError`` for
            terms that must be ignored.
    """

    def __init__(self, artist_term_file, artist_collection, term_collection):
        self.term_file = artist_term_file
        self.artist_indexes = self._create_artist_indexes(artist_collection)
        self.track_index = self._create_track_reverse_mapping()
        self.artist_track_mapping = self._create_artist_track_mapping(artist_collection)

        # Populate terms associated with artists:
        # artist id -> term -> list of frequencies.
        self.artist_terms_freq = defaultdict(lambda: defaultdict(list))
        self._populate_from_en_terms(term_collection)
        self._populate_from_mbtags(term_collection)

        # Iteration cursor: position in artist_ids plus position in the term
        # list of the current artist.
        self.artist_ids = list(self.artist_terms_freq)
        self.artist_index = 0
        self.artist_end_index = len(self.artist_ids)
        self.terms = []
        self.term_index = 0
        self.term_end_index = 0
        if self.artist_ids:
            # Guarded so that a file without any accepted term gives an empty
            # iterator instead of an IndexError.
            self._set_terms()

    def __next__(self):
        """Python 3 iterator protocol; delegates to :meth:`next`."""
        return self.next()

    def __iter__(self):
        """Return the iterator itself."""
        return self

    def next(self):
        """Return the next ``ArtistTermFreq``.

        Raises:
            StopIteration: When every (artist, term) pair has been returned.
        """
        # Move on to the next artist for as long as the current one has no
        # unread term left.
        while self.term_index >= self.term_end_index:
            if self.artist_index + 1 >= self.artist_end_index:
                raise StopIteration()
            self.artist_index += 1
            self._set_terms()
        frequency = self._create_frequency(self.artist_index, self.term_index)
        # Advance only after reading so that the first term is not skipped.
        self.term_index += 1
        return frequency

    def _set_terms(self):
        """Load the term list of the current artist and rewind the term cursor."""
        artist_id = self.artist_ids[self.artist_index]
        self.terms = list(self.artist_terms_freq[artist_id])
        self.term_index = 0
        self.term_end_index = len(self.terms)

    def _create_frequency(self, artist_index, term_index):
        """Build the record for one artist and one of the currently loaded terms.

        ``term_index`` refers to ``self.terms``, i.e. the terms loaded by the
        last call to ``_set_terms``.
        """
        artist_id = self.artist_ids[artist_index]
        term = self.terms[term_index]
        freq = np.mean(self.artist_terms_freq[artist_id][term])
        return ArtistTermFreq(artist_id, term, freq)

    def _create_artist_track_mapping(self, artist_collection):
        """Map each artist of ``artist_collection`` to its decoded track ids.

        Artists without any row in the file get no entry.
        """
        mapping = defaultdict(list)
        track_ids = self.term_file.root.track_id
        for artist_id in artist_collection:
            # .get() keeps unknown artists from being inserted into the
            # artist_indexes defaultdict as a side effect of the lookup.
            for index in self.artist_indexes.get(artist_id, ()):
                mapping[artist_id].append(track_ids[index].decode("utf-8"))
        return mapping

    def _create_artist_indexes(self, artist_collection):
        """Map each decoded artist id to the row indexes at which it occurs.

        ``artist_collection`` is accepted for interface compatibility only; it
        is not used and every artist of the file is indexed.
        """
        reverse_mapping = defaultdict(list)
        for index, artist_id in enumerate(self.term_file.artist_id):
            decoded_id = artist_id.decode("utf-8")
            reverse_mapping[decoded_id].append(index)
        return reverse_mapping

    def _create_track_reverse_mapping(self):
        """Map each decoded track id to its row index in the file."""
        return {
            value.decode("utf-8"): index
            for index, value in enumerate(self.term_file.root.track_id)
        }

    def _lookup_artist_id(self, track_id):
        """Return the decoded artist id stored at the row of ``track_id``.

        Raises:
            KeyError: If ``track_id`` is not a track of the file.
        """
        track_index = self.track_index[track_id]
        return self.term_file.artist_id[track_index].decode("utf-8")

    def _populate_from_en_terms(self, term_collection):
        """Collect the ``freq`` of every accepted ``en_terms`` row per artist and term."""
        for row in self.term_file.root.en_terms:
            term = row['term'].decode("utf-8")
            artist_id = self._lookup_artist_id(row['track_id'].decode("utf-8"))
            try:
                term_collection.lookup(term)
            except KeyError:
                # Term is not part of the collection: ignore the row.
                continue
            self.artist_terms_freq[artist_id][term].append(row['freq'])

    def _populate_from_mbtags(self, term_collection):
        """Turn the tag counts of accepted ``mbtags`` rows into frequencies.

        Each count is divided by the sum of all counts collected for the same
        artist and tag, and the result is appended to the frequencies of that
        artist and term.
        """
        artist_mbtags_count = defaultdict(lambda: defaultdict(list))
        for row in self.term_file.root.mbtags:
            tag = row['tag'].decode("utf-8")
            artist_id = self._lookup_artist_id(row['track_id'].decode("utf-8"))
            try:
                term_collection.lookup(tag)
            except KeyError:
                # Tag is not part of the collection: ignore the row.
                continue
            # NOTE(review): the original referenced an undefined name `count`;
            # the value is assumed to live in the row's 'count' column --
            # confirm against the file schema.
            artist_mbtags_count[artist_id][tag].append(row['count'])

        for artist_id, terms in artist_mbtags_count.items():
            for term, term_counts in terms.items():
                total_term_count = sum(term_counts)
                if not total_term_count:
                    # No frequency can be derived from an all-zero total.
                    continue
                for count in term_counts:
                    freq = float(count) / float(total_term_count)
                    self.artist_terms_freq[artist_id][term].append(freq)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment