Created
November 15, 2021 04:59
-
-
Save purple4reina/500d042eeece2f8794da4b2c0e06131e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# https://realpython.com/generative-adversarial-networks/ | |
# https://salu133445.github.io/lakh-pianoroll-dataset/ | |
import torch | |
from torch import nn | |
import time | |
import json | |
import math | |
import matplotlib.pyplot as plt | |
import pypianoroll as ppr | |
import numpy as np | |
import os | |
os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide" | |
import pygame | |
# --- Hyperparameters -------------------------------------------------------
lr = 0.001            # learning rate shared by both Adam optimizers
num_epochs = 10000
data_length = 1024    # number of pianoroll time steps kept per piece
samples = 32          # quota of training pieces to collect

# Per-file tempo metadata for the LPD dataset.
with open('final-project/datasets/lpd/lpd_cleansed/midi_info_v2.json') as f:
    metadata = json.load(f)

row_length = data_length * 128  # flattened length: time steps x 128 pitches
train_data = []
for file, info in metadata.items():
    # Keep only pieces with a single constant tempo of 96 BPM so every
    # training example shares the same time resolution.
    if not info['constant_tempo']:
        continue
    if info['tempo'] != 96:
        continue
    filename = f'final-project/datasets/lpd/lpd_full/{file[0]}/{file}.npz'
    mtrack = ppr.load(filename)
    for track in mtrack.tracks:
        if 'clarinet' not in track.name.lower():
            continue
        piece = track.pianoroll
        if piece.shape[0] < data_length:
            continue
        # Flatten the first `data_length` time steps into one long row.
        row = piece[:data_length].reshape(row_length).astype(np.float32)
        train_data.append(row)
        samples -= 1
        if not samples:
            break
    # BUG FIX: the inner `break` only leaves the track loop. Without this
    # guard the file loop keeps collecting past the quota, `samples` goes
    # negative, and the assert below fails.
    if not samples:
        break
assert not samples, samples
train_data_length = len(train_data)
# Labels are never used by the GAN training loop; zeros are placeholders so
# the DataLoader yields (sample, label) pairs.
train_labels = torch.zeros(train_data_length)
train_set = [
    (train_data[i], train_labels[i]) for i in range(train_data_length)
]
batch_size = 32
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True
)
class Discriminator(nn.Module):
    """Binary classifier: flattened pianoroll -> probability it is real."""

    def __init__(self):
        super().__init__()
        # Funnel the flattened pianoroll down to a single probability.
        # Dropout after every hidden layer regularizes the classifier.
        widths = (row_length, 256, 128, 64)
        layers = []
        for n_in, n_out in zip(widths, widths[1:]):
            layers += [nn.Linear(n_in, n_out), nn.ReLU(), nn.Dropout(0.3)]
        layers += [nn.Linear(widths[-1], 1), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return a real/fake probability in (0, 1) for each row of `x`."""
        return self.model(x)

discriminator = Discriminator()
class Generator(nn.Module):
    """Maps a 1024-dim latent vector to a flattened pianoroll in [0, 1]."""

    def __init__(self):
        super().__init__()
        # Squeeze the latent vector through a bottleneck, then expand to
        # the full flattened pianoroll; Sigmoid keeps outputs in [0, 1].
        dims = (1024, 512, 128, 32)
        blocks = []
        for idx in range(len(dims) - 1):
            blocks.append(nn.Linear(dims[idx], dims[idx + 1]))
            blocks.append(nn.ReLU())
        blocks.append(nn.Linear(dims[-1], row_length))
        blocks.append(nn.Sigmoid())
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        """Generate one flattened pianoroll per latent row in `x`."""
        return self.model(x)

generator = Generator()
loss_function = nn.BCELoss()
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=lr)
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=lr)

for epoch in range(num_epochs):
    for real_samples, _ in train_loader:
        # ROBUSTNESS FIX: size labels and latents from the actual batch so
        # a partial final batch (dataset size not a multiple of batch_size)
        # no longer breaks torch.cat / BCELoss shape matching.
        n = real_samples.size(0)

        # --- Discriminator step: real samples vs. current fakes ----------
        real_samples_labels = torch.ones((n, 1))
        latent_space_samples = torch.randn((n, 1024))
        generated_samples = generator(latent_space_samples)
        generated_samples_labels = torch.zeros((n, 1))
        all_samples = torch.cat((real_samples, generated_samples))
        all_samples_labels = torch.cat(
            (real_samples_labels, generated_samples_labels)
        )
        discriminator.zero_grad()
        output_discriminator = discriminator(all_samples)
        loss_discriminator = loss_function(
            output_discriminator, all_samples_labels)
        loss_discriminator.backward()
        optimizer_discriminator.step()

        # --- Generator step: maximize discriminator's "real" output ------
        latent_space_samples = torch.randn((n, 1024))
        generator.zero_grad()
        generated_samples = generator(latent_space_samples)
        output_discriminator_generated = discriminator(generated_samples)
        loss_generator = loss_function(
            output_discriminator_generated, real_samples_labels
        )
        loss_generator.backward()
        optimizer_generator.step()

    # Report the losses from the last batch of the epoch.
    print(f"Epoch: {epoch} Loss D.: {loss_discriminator} Loss G.: {loss_generator}")
# --- Sample one piece from the trained generator, write it as MIDI, play it.
latent_space_samples = torch.randn(1, 1024)
with torch.no_grad():
    generated_samples = generator(latent_space_samples)
piece = generated_samples.reshape((data_length, 128))
# BUG FIX: pypianoroll's BinaryTrack expects a boolean numpy array, not a
# float torch tensor. Threshold the sigmoid outputs (same 0.99 cutoff as
# before) and convert; the note-on pattern is unchanged.
pianoroll = (piece > 0.99).numpy()
track = ppr.BinaryTrack(pianoroll=pianoroll, name='piano')
mtrack = ppr.Multitrack(tracks=[track])
fname = 'final-project/original.mid'
mtrack.write(fname)

# Play the rendered MIDI while showing the pianoroll plot.
pygame.init()
pygame.mixer.music.load(fname)
pygame.mixer.music.play()
mtrack.plot()
plt.show()
# Block until playback finishes so the script does not exit early.
while pygame.mixer.music.get_busy():
    time.sleep(0.5)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment