@niklio
Last active April 18, 2017 04:03
Reorder arbitrary partition of a .wav file.
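The script below recovers the original order of a set of shuffled .wav segments: it loads the scrambled segment files from `segment_dir` and the original recording from `initial_path`, splits the original into equally long chunks, takes the FFT of every chunk, builds a pairwise cost matrix from the energy-weighted spectra, and solves the resulting assignment problem with the Hungarian (Munkres) algorithm. It then plots the original signal against the reordered segments and prints the recovered 1-based mapping.
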
import math
import os
import pdb
import matplotlib.pyplot as plt
from munkres import Munkres # Because big data + I only remember Hungarian
import numpy as np
from scipy.io import wavfile
from tqdm import tqdm
# params
media_id = 1080
max_amp = 5
initial_path = 'media_{}_initial.wav'.format(media_id)
segment_dir = 'meeting_{}_media'.format(media_id)
checkpoint_dir = os.path.abspath('checkpoints')
def load_wav(f, amp_max=max_amp):
    # File stat
    f_size_b = os.path.getsize(f)
    f_size_mb = f_size_b / 1024.0 / 1024.0
    f_path = os.path.relpath(f)
    _, f_ext = os.path.splitext(f)

    # Load audio file
    freq, wav = wavfile.read(f)

    # Expecting wav in R^1 or R^2
    dim = len(wav.shape)
    dim_err_msg = "Malformed audio file, expecting 1 or 2 dimensions, got {}".format(dim)
    assert dim < 3, dim_err_msg

    # Handle multi-channel audio by keeping only the first channel
    if len(wav.shape) == 2:
        wav = wav[:, 0]

    # Normalize signal to a peak amplitude of amp_max
    amp = np.abs(wav).max()
    wav = wav * (amp_max / amp)
    return wav, freq
def split(wav, freq, seg_sec):
    # Wav stat
    seg_smp = math.ceil(seg_sec * freq)          # Number of samples in a segment
    wav_smp = wav.shape[0]                       # Number of samples in wav
    wav_sec = math.ceil(wav_smp / freq)          # Length of wav in seconds
    pad_smp = seg_smp - wav_smp % seg_smp        # Number of samples appended as padding
    seg_count = (wav_smp + pad_smp) // seg_smp   # Number of segments

    # Log info
    info = {'wav_sec': wav_sec, 'seg_count': seg_count, 'seg_sec': seg_sec}
    info_msg = "Splitting {wav_sec}s signal into {seg_count} {seg_sec}s segments".format(**info)
    print(info_msg)

    # Zero-pad the end of the signal so it divides evenly into segments
    # (np.resize would repeat the signal instead of padding with zeros)
    wav = np.pad(wav, (0, pad_smp), mode='constant')

    # Split into a (seg_count, seg_smp) matrix, one segment per row
    seg_vec = np.zeros((seg_count, seg_smp))
    for i in tqdm(range(seg_count)):
        seg_vec[i] = wav[i * seg_smp:(i + 1) * seg_smp]
    return seg_vec
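
# Worked example of the padding arithmetic above (illustrative numbers only):
# a 10.5 s signal at 16 kHz has wav_smp = 168000 samples; with 4 s segments
# (seg_smp = 64000) we get pad_smp = 64000 - 168000 % 64000 = 24000 and
# seg_count = (168000 + 24000) // 64000 = 3 segments.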
def load_segments():
    # Segment dir stat
    seg_files = os.listdir(segment_dir)
    seg_count = len(seg_files)

    # Info
    info = {'num': seg_count, 'path': segment_dir}
    info_msg = 'Loading {num} files from {path}'.format(**info)
    print(info_msg)

    # Max samples, max seconds seen so far
    seg_wav_smp = seg_wav_sec = 0

    # Store the wav and sample frequency of each segment
    seg_wav_vec = np.zeros((seg_count, seg_wav_smp))
    seg_freq_vec = np.zeros(seg_count)

    # Load files
    for i, seg_fname in enumerate(tqdm(seg_files)):
        seg_path = os.path.join(segment_dir, seg_fname)
        seg_wav, seg_freq = load_wav(seg_path)
        seg_wav_smp = max(seg_wav.shape[0], seg_wav_smp)
        seg_wav_sec = max(seg_wav.shape[0] / seg_freq, seg_wav_sec)

        # Grow the matrix if we found a bigger segment, keeping the rows already
        # filled aligned (ndarray.resize would scramble them when the column count changes)
        if seg_wav_vec.shape[1] != seg_wav_smp:
            grown = np.zeros((seg_count, seg_wav_smp))
            grown[:, :seg_wav_vec.shape[1]] = seg_wav_vec
            seg_wav_vec = grown

        seg_wav_vec[i, 0:seg_wav.shape[0]] = seg_wav
        seg_freq_vec[i] = seg_freq

    return seg_wav_vec, seg_wav_sec
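
# Note on load_segments: segments shorter than the longest one stay zero-padded
# on the right in seg_wav_vec, so their FFTs below are taken over the padded row.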
""" This function isn't right, but the matching alg seems to be """
def plot_segments(init_wav_vec, seg_wav_vec):
# Info
print('Allocating memory to matplotlib, this may take a minute')
# plt.rcParams['agg.path.chunksize'] = 20000
plt.rcParams['agg.path.chunksize'] = 5000
init_wav = init_wav_vec.flatten()
seg_wav = seg_wav_vec.flatten()
f, (ax1, ax2) = plt.subplots(2, sharex=True, figsize=(40, 5))
ax1.plot(init_wav)
ax1.set_title('Initial')
ax2.plot(seg_wav)
ax2.set_title('Segments')
plt.show()
# Load segments
segment_wav_vec, segment_wav_sec = load_segments()

# FFT is expensive so let's use a checkpoint
segment_fft_checkpoint = os.path.join(checkpoint_dir, 'segment_fft.npy')
if os.path.isfile(segment_fft_checkpoint):
    # Info
    info_msg = 'Loading segment Fourier transforms from {}'.format(segment_fft_checkpoint)
    print(info_msg)

    # Load checkpoint
    segment_fft_vec = np.load(segment_fft_checkpoint)
else:
    # FFT info
    segment_smp = segment_wav_vec.shape[0] * segment_wav_vec.shape[1]
    info_msg = 'Calculating Fourier transform of {} sample signal'.format(segment_smp)
    print(info_msg)

    # Apply FFT to every segment (one row at a time)
    segment_fft_vec = np.empty(segment_wav_vec.shape, dtype=np.complex64)
    for i in tqdm(range(segment_fft_vec.shape[0])):
        segment_fft_vec[i, :] = np.fft.fft(segment_wav_vec[i])

    # Save
    np.save(segment_fft_checkpoint, segment_fft_vec)

    # Save info
    info_msg = 'Saved Fourier transform to {}'.format(segment_fft_checkpoint)
    print(info_msg)
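
# Note: np.fft.fft returns complex128; storing the rows into a complex64 array
# halves the precision in exchange for a smaller in-memory/checkpoint footprint.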
# FFT is still expensive
initial_fft_checkpoint = os.path.join(checkpoint_dir, 'initial_fft.npy')
if os.path.isfile(initial_fft_checkpoint):
    # Info
    info_msg = 'Loading initial Fourier transforms from {}'.format(initial_fft_checkpoint)
    print(info_msg)

    # Load checkpoint
    initial_wav_vec, initial_fft_vec = np.load(initial_fft_checkpoint)
else:
    # Load and split initial
    initial_wav, initial_freq = load_wav(initial_path)
    initial_wav_vec = split(initial_wav, initial_freq, segment_wav_sec)

    # FFT info
    initial_smp = initial_wav.shape[0]
    info_msg = 'Calculating Fourier transform of {} sample signal'.format(initial_smp)
    print(info_msg)

    # Apply FFT to every segment of the initial signal
    initial_fft_vec = np.empty(initial_wav_vec.shape, dtype=np.complex64)
    for i in tqdm(range(initial_wav_vec.shape[0])):
        initial_fft_vec[i] = np.fft.fft(initial_wav_vec[i])

    # Save
    np.save(initial_fft_checkpoint, (initial_wav_vec, initial_fft_vec))

    # Save info
    info_msg = 'Saved Fourier transform to {}'.format(initial_fft_checkpoint)
    print(info_msg)
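
# Caveat: np.save of the (initial_wav_vec, initial_fft_vec) tuple stacks the two
# equal-shaped arrays into one complex array, so after a reload from this checkpoint
# initial_wav_vec comes back complex-valued. Saving the arrays separately (e.g. with
# np.savez) would avoid that; left as-is here to match the original checkpoint layout.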
# FFT similarity matrix for convex optimization
# At this point, segment_fft_vec is at least as large as initial_fft_vec
if segment_fft_vec.shape != initial_fft_vec.shape:
    segment_fft_vec.resize(initial_fft_vec.shape)
assert segment_fft_vec.shape == initial_fft_vec.shape

fft_similarity_checkpoint = os.path.join(checkpoint_dir, 'fft_similarity.npy')
if os.path.isfile(fft_similarity_checkpoint):
    # Info
    info_msg = 'Loading Fourier similarity matrix from {}'.format(fft_similarity_checkpoint)
    print(info_msg)

    # Load checkpoint
    fft_similarity_mat = np.load(fft_similarity_checkpoint)
else:
    # FFT vec stat
    n = segment_fft_vec.shape[0]

    # Similarity matrix info
    info_msg = 'Building {0}x{0} Fourier similarity matrix'.format(n)
    print(info_msg)

    # Build similarity matrix from FFT norms: entry (i, j) is the norm of the
    # difference between the energy-weighted FFTs of initial chunk i and segment j
    fft_similarity_mat = np.empty((n, n), dtype=np.float64)
    for i in tqdm(range(n)):
        for j in range(n):
            initial = np.sum(initial_wav_vec[i] ** 2) * initial_fft_vec[i]
            segment = np.sum(segment_wav_vec[j] ** 2) * segment_fft_vec[j]
            fft_similarity_mat[i, j] = np.linalg.norm(initial - segment)

    # Save
    np.save(fft_similarity_checkpoint, fft_similarity_mat)

    # Save info
    info_msg = 'Saved similarity matrix to {}'.format(fft_similarity_checkpoint)
    print(info_msg)
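
# The cost above compares energy-weighted spectra: entry (i, j) is
# || E_i * FFT(initial_i) - E_j * FFT(segment_j) ||, where E is the row's total
# energy (sum of squared samples). A vectorized sketch of the same matrix,
# assuming the wav and FFT matrices all share the same shape (memory grows as
# n^2 * seg_smp, so only practical for small n):
#   A = (initial_wav_vec ** 2).sum(axis=1)[:, None] * initial_fft_vec
#   B = (segment_wav_vec ** 2).sum(axis=1)[:, None] * segment_fft_vec
#   fft_similarity_mat = np.linalg.norm(A[:, None, :] - B[None, :, :], axis=2)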
# Bipartite match
bipartite_match_checkpoint = os.path.join(checkpoint_dir, 'bipartite_match.npy')
if os.path.isfile(bipartite_match_checkpoint):
    # Info
    info_msg = 'Loading bipartite match from {}'.format(bipartite_match_checkpoint)
    print(info_msg)

    # Load checkpoint
    bipartite_match = np.load(bipartite_match_checkpoint)
else:
    # Bipartite match info
    print('Starting Munkres bipartite match')

    # Munkres (Hungarian) assignment; munkres expects a plain list-of-lists cost matrix
    m = Munkres()
    bipartite_match = m.compute(fft_similarity_mat.tolist())

    # Save
    np.save(bipartite_match_checkpoint, bipartite_match)

    # Save info
    info_msg = 'Saved bipartite match results to {}'.format(bipartite_match_checkpoint)
    print(info_msg)
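
# Munkres returns the list of (row, column) pairs that minimises the total cost,
# i.e. shuffled segment j is assigned to initial position i. For large n,
# scipy.optimize.linear_sum_assignment solves the same assignment problem much faster:
#   from scipy.optimize import linear_sum_assignment
#   rows, cols = linear_sum_assignment(fft_similarity_mat)
#   bipartite_match = list(zip(rows, cols))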
# Reorder the shuffled segments according to the matching
ordered_segment_wav_vec = np.empty(segment_wav_vec.shape)
for i, j in bipartite_match:
    ordered_segment_wav_vec[i] = segment_wav_vec[j]

plot_segments(initial_wav_vec, ordered_segment_wav_vec)

# for i, row in enumerate(fft_similarity_mat):
#     pairs = []
#     for j, x in enumerate(row):
#         if x == 0:
#             pairs.append((j, x))
#
#     print(i, ': ', pairs)

# Print the recovered ordering as 1-based (initial position : segment index) pairs
for row in bipartite_match:
    print(row[0] + 1, ':', row[1] + 1)
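
# A minimal sketch (not part of the original gist) of writing the reordered audio
# back to disk with scipy; the 44100 Hz sample rate and the output filename are
# assumptions, not values taken from the script above.
#   reordered = ordered_segment_wav_vec.flatten()
#   reordered = (reordered / np.abs(reordered).max() * np.iinfo(np.int16).max).astype(np.int16)
#   wavfile.write('media_{}_reordered.wav'.format(media_id), 44100, reordered)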