Last active
April 18, 2017 04:03
-
-
Save niklio/898f7771b41ce8bbeb94beea77774e68 to your computer and use it in GitHub Desktop.
Reorder arbitrary partition of a .wav file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import os | |
import pdb | |
import matplotlib.pyplot as plt | |
from munkres import Munkres # Because big data + I only remember Hungarian | |
import numpy as np | |
from scipy.io import wavfile | |
from tqdm import tqdm | |
# Run configuration: which recording to process and where its files live.
media_id = 1080
max_amp = 5  # peak amplitude used when normalizing loaded signals

# Input locations are derived from the media id.
initial_path = 'media_%s_initial.wav' % media_id
segment_dir = 'meeting_%s_media' % media_id

# Absolute path of the directory holding the cached intermediate results.
checkpoint_dir = os.path.abspath(os.path.join(os.curdir, 'checkpoints'))
def load_wav(f, amp_max=None):
    """Read a WAV file and return its first channel scaled to a fixed peak.

    Args:
        f: Path of the WAV file to read.
        amp_max: Peak absolute amplitude of the returned signal. When left as
            ``None`` the module-level ``max_amp`` setting is used, which is
            the value calls without this argument were already scaled to.

    Returns:
        A ``(wav, freq)`` tuple: the float64 signal (first channel only for
        multi-channel files) and the sample rate reported by
        ``scipy.io.wavfile.read``.

    Raises:
        AssertionError: If the decoded audio has more than two dimensions.
    """
    if amp_max is None:
        amp_max = max_amp

    # Load audio file
    freq, wav = wavfile.read(f)

    # Expecting wav in R^1 or R^2
    dim = wav.ndim
    dim_err_msg = "Malformed audio file, expecting 2 dimensions, got {}".format(dim)
    assert dim < 3, dim_err_msg

    # Handle multi channel audio: keep the first channel only
    if dim == 2:
        wav = wav[:, 0]

    # Convert before taking magnitudes: in a fixed-width integer dtype the
    # absolute value of the most negative sample wraps around, which would
    # understate the peak.
    wav = wav.astype(np.float64)

    # Normalize signal. A silent or empty signal has no peak to scale by, so
    # it is returned unscaled instead of being turned into NaNs.
    amp = np.abs(wav).max() if wav.size else 0.0
    if amp == 0:
        return wav, freq
    return wav * (amp_max / amp), freq
def split(wav, freq, seg_sec):
    """Cut a 1-D signal into equal-length segments, zero-padding the tail.

    Args:
        wav: 1-D array of samples.
        freq: Sample rate, in samples per second.
        seg_sec: Length of each segment, in seconds.

    Returns:
        A float64 array of shape ``(seg_count, seg_smp)`` with one segment
        per row. Only the last row can contain padding, and that padding is
        zeros.
    """
    # Wav stat
    seg_smp = int(math.ceil(seg_sec * freq))  # Number of samples in a segment
    wav_smp = wav.shape[0]                    # Number of samples in wav
    wav_sec = math.ceil(wav_smp / freq)       # Length of wav in seconds
    # Samples needed to reach the next multiple of seg_smp; zero when the
    # signal already fills a whole number of segments.
    pad_smp = (-wav_smp) % seg_smp
    seg_count = (wav_smp + pad_smp) // seg_smp  # Number of segments

    # Log info
    info = {'wav_sec': wav_sec, 'seg_count': seg_count, 'seg_sec': seg_sec}
    info_msg = "Splitting {wav_sec}s signal into {seg_count} {seg_sec}s segments".format(**info)
    print(info_msg)

    # Pad the end of the signal with silence. np.resize is not used here
    # because it fills the extra space with repeated copies of the input.
    padded = np.zeros(seg_count * seg_smp)
    padded[:wav_smp] = wav

    # Split: each row is one consecutive run of seg_smp samples
    return padded.reshape(seg_count, seg_smp)
def load_segments():
    """Load every file in ``segment_dir`` into one zero-padded 2-D array.

    Files are read in sorted name order so that row indices mean the same
    thing on every run; ``os.listdir`` alone gives no ordering guarantee.

    Returns:
        A ``(seg_wav_vec, seg_wav_sec)`` tuple. ``seg_wav_vec`` has one row
        per file and as many columns as the longest segment has samples;
        shorter segments are right-padded with zeros. ``seg_wav_sec`` is the
        duration in seconds of the longest segment (0 if there are no files).
    """
    # Segment dir stat
    seg_files = sorted(os.listdir(segment_dir))
    seg_count = len(seg_files)

    # Info
    info = {'num': seg_count, 'path': segment_dir}
    info_msg = 'Loading {num} files from {path}'.format(**info)
    print(info_msg)

    # Max samples, max seconds
    seg_wav_smp = seg_wav_sec = 0

    # Load files first and size the matrix afterwards. Growing a 2-D array in
    # place with ndarray.resize re-reads the flat buffer under the new row
    # length, which would shift samples of already-stored segments across rows.
    loaded = []
    for seg_fname in tqdm(seg_files):
        seg_path = os.path.join(segment_dir, seg_fname)
        seg_wav, seg_freq = load_wav(seg_path)
        seg_wav_smp = max(seg_wav.shape[0], seg_wav_smp)
        seg_wav_sec = max(seg_wav.shape[0] / seg_freq, seg_wav_sec)
        loaded.append(seg_wav)

    # Store each segment left-aligned in its own row
    seg_wav_vec = np.zeros((seg_count, seg_wav_smp))
    for i, seg_wav in enumerate(loaded):
        seg_wav_vec[i, :seg_wav.shape[0]] = seg_wav

    return seg_wav_vec, seg_wav_sec
""" This function isn't right, but the matching alg seems to be """ | |
def plot_segments(init_wav_vec, seg_wav_vec):
    """Plot the initial signal above the segment signal on a shared x-axis.

    Each input array is flattened into a single trace, drawn on its own
    subplot, and the figure is displayed with ``plt.show()``.

    Args:
        init_wav_vec: Array holding the initial signal.
        seg_wav_vec: Array holding the segment signal.
    """
    # Info
    print('Allocating memory to matplotlib, this may take a minute')

    # Chunk size used by the Agg backend when rendering long paths
    plt.rcParams['agg.path.chunksize'] = 5000

    traces = (
        ('Initial', init_wav_vec.flatten()),
        ('Segments', seg_wav_vec.flatten()),
    )

    _, axes = plt.subplots(2, sharex=True, figsize=(40, 5))
    for axis, (title, trace) in zip(axes, traces):
        axis.plot(trace)
        axis.set_title(title)

    plt.show()
# Load segments
segment_wav_vec, segment_wav_sec = load_segments()

# FFT is expensive so let's use a checkpoint: reuse the spectra saved by an
# earlier run when the file exists, otherwise compute and save them.
segment_fft_checkpoint = os.path.join(checkpoint_dir, 'segment_fft.npy')
if os.path.isfile(segment_fft_checkpoint):
    # Info
    info_msg = 'Loading segment Fourier transforms from {}'.format(segment_fft_checkpoint)
    print(info_msg)

    # Load checkpoint
    segment_fft_vec = np.load(segment_fft_checkpoint)
else:
    # FFT Info
    segment_smp = segment_wav_vec.shape[0] * segment_wav_vec.shape[1]
    info_msg = 'Calculating Fourier transform of {} sample signal'.format(segment_smp)
    print(info_msg)

    # Apply FFT to every segment (one row per segment, stored as complex64)
    segment_fft_vec = np.empty(segment_wav_vec.shape, dtype=np.complex64)
    for i in tqdm(range(segment_fft_vec.shape[0])):
        segment_fft_vec[i, :] = np.fft.fft(segment_wav_vec[i])

    # Save. np.save does not create missing directories, so make sure the
    # checkpoint directory exists before writing into it.
    os.makedirs(checkpoint_dir, exist_ok=True)
    np.save(segment_fft_checkpoint, segment_fft_vec)

    # Save info
    info_msg = 'Saved Fourier transform to {}'.format(segment_fft_checkpoint)
    print(info_msg)
# FFT is still expensive, so the initial recording gets a checkpoint as well.
# The checkpoint file holds a single array: index 0 is the segmented initial
# signal and index 1 is its per-segment Fourier transform.
initial_fft_checkpoint = os.path.join(checkpoint_dir, 'initial_fft.npy')
if os.path.isfile(initial_fft_checkpoint):
    # Info
    info_msg = 'Loading initial Fourier transforms from {}'.format(initial_fft_checkpoint)
    print(info_msg)

    # Load checkpoint. Saving the (signal, spectrum) pair stacks both into one
    # complex array, so the signal half is converted back to real values here.
    initial_checkpoint = np.load(initial_fft_checkpoint)
    initial_wav_vec = np.real(initial_checkpoint[0])
    initial_fft_vec = initial_checkpoint[1]
else:
    # Load and split initial into segments as long as the longest loaded one
    initial_wav, initial_freq = load_wav(initial_path)
    initial_wav_vec = split(initial_wav, initial_freq, segment_wav_sec)

    # FFT info
    initial_smp = initial_wav.shape[0]
    info_msg = 'Calculating Fourier transform of {} sample signal'.format(initial_smp)
    print(info_msg)

    # Apply FFT to every segment (one row per segment, stored as complex64)
    initial_fft_vec = np.empty(initial_wav_vec.shape, dtype=np.complex64)
    for i in tqdm(range(initial_wav_vec.shape[0])):
        initial_fft_vec[i] = np.fft.fft(initial_wav_vec[i])

    # Save. np.save does not create missing directories, so make sure the
    # checkpoint directory exists before writing into it.
    os.makedirs(checkpoint_dir, exist_ok=True)
    np.save(initial_fft_checkpoint, (initial_wav_vec, initial_fft_vec))

    # Save info
    info_msg = 'Saved Fourier transform to {}'.format(initial_fft_checkpoint)
    print(info_msg)
# FFT similarity matrix: entry [i, j] is the distance between the
# energy-weighted spectrum of initial segment i and of loaded segment j.

# Bring segment_fft_vec to the shape of initial_fft_vec by cropping or
# zero-padding, keeping every row aligned with its segment. An in-place
# ndarray.resize is avoided because it re-reads the flat buffer under the new
# row length, which shifts data across rows whenever the column count changes.
if segment_fft_vec.shape != initial_fft_vec.shape:
    conformed_fft_vec = np.zeros(initial_fft_vec.shape, dtype=segment_fft_vec.dtype)
    keep_rows = min(initial_fft_vec.shape[0], segment_fft_vec.shape[0])
    keep_cols = min(initial_fft_vec.shape[1], segment_fft_vec.shape[1])
    conformed_fft_vec[:keep_rows, :keep_cols] = segment_fft_vec[:keep_rows, :keep_cols]
    segment_fft_vec = conformed_fft_vec
assert segment_fft_vec.shape == initial_fft_vec.shape

fft_similarity_checkpoint = os.path.join(checkpoint_dir, 'fft_similarity.npy')
if os.path.isfile(fft_similarity_checkpoint):
    # Info
    info_msg = 'Loading Fourier similarity matrix from {}'.format(fft_similarity_checkpoint)
    print(info_msg)

    # Load checkpoint
    fft_similarity_mat = np.load(fft_similarity_checkpoint)
else:
    # FFT vec stat
    n = segment_fft_vec.shape[0]

    # Similarity matrix info
    info_msg = 'Building {0}x{0} fourier similarity matrix'.format(n)
    print(info_msg)

    # Signal energy (sum of squared samples) of every segment, computed once
    # instead of once per matrix entry. np.real guards against a signal array
    # that was restored from a complex-valued checkpoint.
    initial_energy = np.sum(np.real(initial_wav_vec) ** 2, axis=1)
    # Rows that exist only as padding have no samples and keep zero energy.
    segment_energy = np.zeros(n)
    known_rows = min(n, segment_wav_vec.shape[0])
    segment_energy[:known_rows] = np.sum(segment_wav_vec[:known_rows] ** 2, axis=1)

    # Energy-weighted spectra, one row per segment
    initial_weighted = initial_energy[:, np.newaxis] * initial_fft_vec
    segment_weighted = segment_energy[:, np.newaxis] * segment_fft_vec

    # Build similarity matrix from fft norms, one row of distances at a time
    fft_similarity_mat = np.empty((n, n), dtype=np.float64)
    for i in tqdm(range(n)):
        fft_similarity_mat[i] = np.linalg.norm(initial_weighted[i] - segment_weighted, axis=1)

    # Save. np.save does not create missing directories, so make sure the
    # checkpoint directory exists before writing into it.
    os.makedirs(checkpoint_dir, exist_ok=True)
    np.save(fft_similarity_checkpoint, fft_similarity_mat)

    # Save info
    info_msg = 'Saved similarity matrix to {}'.format(fft_similarity_checkpoint)
    print(info_msg)
# Bipartite match: pair every initial segment (row) with one loaded segment
# (column) so that the summed distance over all pairs is minimal.
bipartite_match_checkpoint = os.path.join(checkpoint_dir, 'bipartite_match.npy')
if os.path.isfile(bipartite_match_checkpoint):
    # Info
    info_msg = 'Loading bipartite match from {}'.format(bipartite_match_checkpoint)
    print(info_msg)

    # Load checkpoint
    bipartite_match = np.load(bipartite_match_checkpoint)
else:
    # Bipartite match info
    print('Starting Munkres bipartite match')

    # Munkres matching alg. The solver is documented for a list of lists, so
    # it is given a plain-Python copy rather than the NumPy array itself;
    # this also keeps fft_similarity_mat out of the solver's hands.
    m = Munkres()
    bipartite_match = m.compute(fft_similarity_mat.tolist())

    # Save. np.save does not create missing directories, so make sure the
    # checkpoint directory exists before writing into it.
    os.makedirs(checkpoint_dir, exist_ok=True)
    np.save(bipartite_match_checkpoint, bipartite_match)

    # Save info
    info_msg = 'Saved bipartite match results to {}'.format(bipartite_match_checkpoint)
    print(info_msg)
# Rearrange the loaded segments into the order chosen by the match: row i of
# the result is the segment paired with initial segment i. Rows that receive
# no segment stay silent (zeros) rather than holding uninitialized memory.
ordered_segment_wav_vec = np.zeros(segment_wav_vec.shape)
segment_rows = segment_wav_vec.shape[0]
for i, j in bipartite_match:
    # The match can reference rows that exist only as padding in the
    # similarity matrix; those have no loaded segment to copy.
    if i < segment_rows and j < segment_rows:
        ordered_segment_wav_vec[i] = segment_wav_vec[j]

plot_segments(initial_wav_vec, ordered_segment_wav_vec)

# Report every pairing with 1-based indices
for row in bipartite_match:
    print(row[0] + 1, ':', row[1] + 1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment