Created
January 23, 2019 10:26
-
-
Save taras-sereda/130ef385382577ee7f8c6e9b2ce9dd90 to your computer and use it in GitHub Desktop.
WaveGlow CPU inference
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/glow_old.py b/glow_old.py
index 0de2375..5895300 100644
--- a/glow_old.py
+++ b/glow_old.py
@@ -183,7 +183,7 @@ class WaveGlow(torch.nn.Module):
self.n_remaining_channels,
spect.size(2)).normal_()
else:
-                audio = torch.cuda.FloatTensor(spect.size(0),
+                audio = torch.FloatTensor(spect.size(0),
self.n_remaining_channels,
spect.size(2)).normal_()
@@ -215,7 +215,7 @@ class WaveGlow(torch.nn.Module):
self.n_early_size,
spect.size(2)).normal_()
else:
-                z = torch.cuda.FloatTensor(spect.size(0),
+                z = torch.FloatTensor(spect.size(0),
diff --git a/glow_old.py b/glow_old.py
index 0de2375..5895300 100644
--- a/glow_old.py
+++ b/glow_old.py
@@ -183,7 +183,7 @@ class WaveGlow(torch.nn.Module):
self.n_remaining_channels,
spect.size(2)).normal_()
else:
-                audio = torch.cuda.FloatTensor(spect.size(0),
+                audio = torch.FloatTensor(spect.size(0),
self.n_remaining_channels,
spect.size(2)).normal_()
@@ -215,7 +215,7 @@ class WaveGlow(torch.nn.Module):
self.n_early_size,
spect.size(2)).normal_()
else:
-                z = torch.cuda.FloatTensor(spect.size(0),
+                z = torch.FloatTensor(spect.size(0),
self.n_early_size,
spect.size(2)).normal_()
audio = torch.cat((sigma*z, audio),1)
diff --git a/inference.py b/inference.py
index 2c67605..61cf6f2 100644
--- a/inference.py
+++ b/inference.py
@@ -32,9 +32,9 @@ from mel2samp import files_to_list, MAX_WAV_VALUE
def main(mel_files, waveglow_path, sigma, output_dir, sampling_rate, is_fp16):
mel_files = files_to_list(mel_files)
-    waveglow = torch.load(waveglow_path)['model']
+    waveglow = torch.load(waveglow_path, map_location=lambda storage, loc: storage)['model']
waveglow = waveglow.remove_weightnorm(waveglow)
-    waveglow.cuda().eval()
+    waveglow.eval()
if is_fp16:
waveglow.half()
for k in waveglow.convinv:
@@ -43,7 +43,7 @@ def main(mel_files, waveglow_path, sigma, output_dir, sampling_rate, is_fp16):
for i, file_path in enumerate(mel_files):
file_name = os.path.splitext(os.path.basename(file_path))[0]
mel = torch.load(file_path)
-        mel = torch.autograd.Variable(mel.cuda())
+        mel = torch.autograd.Variable(mel)
mel = torch.unsqueeze(mel, 0)
mel = mel.half() if is_fp16 else mel
with torch.no_grad():
diff --git a/tacotron2 b/tacotron2
--- a/tacotron2
+++ b/tacotron2
@@ -1 +1 @@
-Subproject commit fc0cf6a89a47166350b65daa1beaa06979e4cddf
+Subproject commit fc0cf6a89a47166350b65daa1beaa06979e4cddf-dirty
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment