Last active
July 30, 2020 02:15
-
-
Save redwrasse/937ad45951158329b97d8203a50971b7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Train a minimal autoregressive generative model on the torchaudio YESNO dataset.

Training shows a decreasing loss. Example log from one run (the epoch loss is
the sum of the per-chunk losses over that epoch):

*** epoch: 0 epoch loss: 276.47448682785034
*** epoch: 1 epoch loss: 216.9058997631073
*** epoch: 2 epoch loss: 190.01888144016266
*** epoch: 3 epoch loss: 171.68642991781235
*** epoch: 4 epoch loss: 157.7317717075348
*** epoch: 5 epoch loss: 145.89844578504562
...
...
*** epoch: 90 epoch loss: 11.323879387229681
*** epoch: 91 epoch loss: 11.176103946752846
*** epoch: 92 epoch loss: 11.033554057590663
*** epoch: 93 epoch loss: 10.898204608820379
""" | |
import torchaudio | |
import torch | |
import torch.nn.functional as F | |
def train_ar_generative_model(nepochs=10**3, kernel_size=100, lr=1e-1,
                              root='./'):
    """
    Train an auto-regressive generative model on the YESNO dataset.

    The simplest model is a single 1-D convolution over one-hot mu-law-encoded
    audio, trained with a per-sample cross-entropy (softmax) over the 256
    quantization levels. The input is left-padded with kernel_size - 1 zeros,
    so output position t only depends on samples t - kernel_size + 1 .. t.
    It is trained to predict sample t + 1.

    Args:
        nepochs: number of passes over the dataset.
        kernel_size: convolution width (receptive field), in samples.
        lr: SGD learning rate.
        root: directory the YESNO dataset is downloaded to / read from.

    Prints the summed chunk loss after every epoch; returns None.
    """
    quantization_channels = 256

    yesno_data = torchaudio.datasets.YESNO(root, download=True)
    data_loader = torch.utils.data.DataLoader(yesno_data,
                                              batch_size=1,
                                              shuffle=True,
                                              num_workers=1)

    conv_layer = torch.nn.Conv1d(in_channels=quantization_channels,
                                 out_channels=quantization_channels,
                                 kernel_size=kernel_size)
    optimizer = torch.optim.SGD(conv_layer.parameters(), lr=lr)
    encoding = torchaudio.transforms.MuLawEncoding(
        quantization_channels=quantization_channels)
    loss_fn = torch.nn.CrossEntropyLoss()

    def left_pad(x):
        # Left pad x with kernel_size - 1 zeros: the convolution becomes causal
        # and its output has the same length as the unpadded input.
        return F.pad(x, pad=[kernel_size - 1, 0], mode='constant', value=0)

    def loss_criterion(output, target):
        # output[..., t] has seen inputs up to t, so it is scored against
        # target[..., t + 1]. The first kernel_size - 1 outputs are skipped
        # because their receptive fields overlap the zero padding.
        modified_output = output[:, :, kernel_size - 1:-1]
        modified_target = torch.squeeze(target[:, :, kernel_size:], dim=1)
        return loss_fn(modified_output, modified_target)

    # Chunk length. It must exceed kernel_size so every chunk yields at least
    # one prediction target (an empty target makes the mean loss NaN).
    n_sample = max(int(kernel_size * 1.5), kernel_size + 1)

    for epoch in range(nepochs):
        epoch_loss = 0.
        for waveform, _sample_rate, _labels in data_loader:
            n = waveform.shape[-1]
            # Break each waveform into consecutive full-length chunks so each
            # optimization step stays cheap. A trailing partial chunk is dropped.
            for j in range(0, n - n_sample + 1, n_sample):
                waveform_chunk = waveform[:, :, j:j + n_sample]
                categorical_input = encoding(waveform_chunk)
                assert categorical_input.shape[:2] == (1, 1)
                # (1, 1, L) integer codes -> (1, 256, L) float one-hot.
                one_hot_input = torch.squeeze(
                    F.one_hot(categorical_input, quantization_channels),
                    dim=0).permute(0, 2, 1).float()
                assert one_hot_input.shape[:2] == (1, quantization_channels)

                optimizer.zero_grad()
                lp_input = left_pad(one_hot_input)
                assert lp_input.shape[:2] == (1, quantization_channels)
                assert lp_input.shape[-1] == (one_hot_input.shape[-1]
                                              + kernel_size - 1)
                output = conv_layer(lp_input)
                assert output.shape[:2] == (1, quantization_channels)
                assert output.shape[-1] == one_hot_input.shape[-1]

                loss = loss_criterion(output, categorical_input)
                loss.backward()
                epoch_loss += loss.item()
                optimizer.step()
        print(f'*** epoch: {epoch} epoch loss: {epoch_loss}')
def download_process_data():
    """
    Download the YESNO dataset and sanity-check its format.

    Following the torchaudio docs (https://pytorch.org/audio/datasets.html),
    each item is a tuple (waveform, sample_rate, labels). Batched through a
    DataLoader with batch_size=1, each waveform has shape [1, 1, n]. Here n
    was observed to vary between roughly 45000 and 55000 (single-channel
    audio of variable length), and the sample rate is 8000 for every item.

    For every item this checks:
    - the sample rate is 8000;
    - the waveform lies strictly inside (-1, 1), which mu-law encoding needs;
    - the one-hot mu-law quantized waveform has shape (1, 256, n).

    Finally it checks that exactly 60 items were seen.

    The checks raise explicitly instead of using ``assert``, so they still run
    under ``python -O``.

    Raises:
        ValueError: if any of the checks above fails.
    """
    quantization_channels = 256
    yesno_data = torchaudio.datasets.YESNO('./', download=True)
    data_loader = torch.utils.data.DataLoader(yesno_data,
                                              batch_size=1,
                                              shuffle=True,
                                              num_workers=1)
    # Loop-invariant objects: build them once rather than once per item.
    encoding = torchaudio.transforms.MuLawEncoding(
        quantization_channels=quantization_channels)
    expected_rate = torch.LongTensor([8000])

    sample_ct = 0
    for waveform, sample_rate, _labels in data_loader:
        if not torch.equal(sample_rate, expected_rate):
            raise ValueError(f'expected sample rate 8000, got {sample_rate}')

        # The signal is assumed to already lie within [-1, 1], which mu-law
        # encoding requires; verify it is strictly inside that range.
        lo, hi = waveform.min().item(), waveform.max().item()
        if not (-1. < lo < 1. and -1. < hi < 1.):
            raise ValueError(
                f'waveform outside (-1, 1): min={lo}, max={hi}')

        # mu-law quantize, one-hot encode, and move the class axis to dim 1.
        quantized_waveform = torch.squeeze(
            F.one_hot(encoding(waveform), quantization_channels),
            dim=0).permute(0, 2, 1)
        if quantized_waveform.shape[:2] != (1, quantization_channels):
            raise ValueError(
                f'unexpected quantized shape {tuple(quantized_waveform.shape)}')
        sample_ct += 1

    if sample_ct != 60:
        raise ValueError('expected 60 samples')
if __name__ == "__main__":
    # Entry point: train only when executed as a script, not on import.
    train_ar_generative_model()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment