Spectral augmentation using Mozaicing
# Author: Adam Hibble @algomancer
import torch
import torch.nn.functional as F
import torch.nn as nn
import tqdm

def get_padding(padding_type, kernel_size):
    """Return per-dimension padding for 'SAME' or 'VALID' pooling.

    Accepts an int or a tuple of ints for `kernel_size`.
    """
    assert padding_type in ['SAME', 'VALID']
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if padding_type == 'SAME':
        # Preserves the spatial shape for odd kernel sizes at stride 1.
        return tuple((k - 1) // 2 for k in kernel_size)
    return tuple(0 for _ in kernel_size)

def maximum_filter(input, kernel_size=None):
    """Calculate a multidimensional maximum filter via max pooling.

    Returns a tensor with the same shape as `input` (for odd kernel sizes).
    """
    should_squeeze = False
    if len(input.shape) == 2:
        # max_pool2d expects a leading channel dimension; add one temporarily.
        input = input[None, :, :]
        should_squeeze = True
    x = F.max_pool2d(input, kernel_size, stride=1,
                     padding=get_padding('SAME', kernel_size))
    if should_squeeze:
        x = x.squeeze(0)
    return x
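
# Sanity check (not in the original gist): with stride 1 and 'SAME' padding,
# max_pool2d uses implicit negative-infinity padding, so the result should
# match scipy.ndimage.maximum_filter under constant -inf boundary handling:
#
#   import numpy as np
#   from scipy import ndimage
#   x = torch.rand(16, 32)
#   a = maximum_filter(x, kernel_size=3).numpy()
#   b = ndimage.maximum_filter(x.numpy(), size=3, mode='constant', cval=-np.inf)
#   assert np.allclose(a, b)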

class NMFMozaicing(nn.Module):
    def __init__(self, r_width, c_width, polyphony, iterations):
        """
        r_width: Width (odd) of the repeated-activation maximum filter
        c_width: Half length of the time-continuous activation filter
        polyphony: Maximum number of simultaneous voices per frame
        iterations: Number of NMF update iterations
        """
        super(NMFMozaicing, self).__init__()
        self.r_width = r_width
        self.c_width = c_width
        self.polyphony = polyphony
        self.iterations = iterations

    def step(self, activation_matrix, factor):
        K, N = activation_matrix.shape
        # Step 1: Avoid repeated activations -- dampen every activation that
        # is not the maximum of its r_width x r_width neighbourhood.
        activation_filter = maximum_filter(activation_matrix, kernel_size=self.r_width)
        not_peak = activation_matrix < activation_filter
        activation_matrix[not_peak] = activation_matrix[not_peak] * factor
        # Step 2: Restrict the number of simultaneous activations -- keep the
        # `polyphony` largest activations per frame and dampen the rest.
        cut_off = torch.topk(activation_matrix, self.polyphony + 1, dim=0)[0][self.polyphony, :]
        above = activation_matrix > cut_off[None, :]
        activation_matrix[above] = activation_matrix[above] * factor
        # Step 3: Support time-continuous activations -- replace each element
        # by a centred moving sum of length 2 * c_width + 1 along its
        # diagonal, rewarding activations that persist through time.
        front = torch.zeros(self.c_width + 1, device=activation_matrix.device, dtype=activation_matrix.dtype)
        back = torch.zeros(self.c_width, device=activation_matrix.device, dtype=activation_matrix.dtype)
        window = 2 * self.c_width + 1
        di, dj = K - 1, 0  # (di, dj) indexes the first element of diagonal k
        for k in range(-K + 1, N):
            diag = torch.diagonal(activation_matrix, offset=k)
            # Zero-padded prefix sums give the moving sum as the difference
            # of two shifted cumulative sums.
            z = torch.cumsum(torch.cat((front, diag, back)), dim=0)
            x2 = z[window:] - z[:-window]
            idx = torch.arange(len(x2), device=activation_matrix.device)
            activation_matrix[di + idx, dj + idx] = x2
            if di == 0:
                dj += 1
            else:
                di -= 1
        return activation_matrix

    def forward(self, target, template):
        """Factorise `target` (F x N) against `template` (F x K) and return
        the (K, N) activation matrix, using multiplicative KL-divergence NMF
        updates interleaved with the constraints applied in `step`."""
        N, K = target.shape[1], template.shape[1]
        activation_matrix = torch.rand(K, N, device=target.device)
        for i in tqdm.tqdm(range(self.iterations)):
            # factor anneals from just below 1 towards 0, so the constraint
            # dampening in `step` hardens as the factorisation converges.
            factor = 1 - (i + 1) / self.iterations
            activation_matrix = self.step(activation_matrix, factor)
            # Multiplicative KL-NMF update: H <- H * (W^T (V / WH)) / (W^T 1).
            template_weighted = template.matmul(activation_matrix)
            template_weighted[template_weighted == 0] = 1  # guard division by zero
            target_lamda = target / template_weighted
            template_denom = torch.sum(template, 0)
            template_denom[template_denom == 0] = 1
            inner = template.t().matmul(target_lamda) / template_denom[:, None]
            activation_matrix = activation_matrix * inner
        return activation_matrix
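
# Minimal usage sketch (not part of the original gist). The shapes and
# hyper-parameters below are illustrative assumptions: `template` holds K
# magnitude-spectrogram frames from a source corpus (F_bins frequency bins
# each) and `target` holds N frames to be approximated as a mosaic of them.
if __name__ == '__main__':
    F_bins, K, N = 513, 200, 400
    template = torch.rand(F_bins, K)  # corpus spectrogram frames (nonnegative)
    target = torch.rand(F_bins, N)    # target spectrogram frames (nonnegative)
    mozaicer = NMFMozaicing(r_width=3, c_width=3, polyphony=5, iterations=50)
    activations = mozaicer(target, template)  # (K, N) activation matrix
    mosaic = template.matmul(activations)     # mosaic approximation of target
    print(mosaic.shape)  # torch.Size([513, 400])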