- Read 4 batches of 10 audio signals sequentially via numpy, without converting them into spectrograms (code not shown in this gist): 30.084 sec
- Read 4 batches of 10 audio signals and convert them into dB mel spectrograms via tf.data.Dataset, sequentially and without optimization: 9.423 sec
- Read 4 batches of 10 audio signals and convert them into dB mel spectrograms via tf.data.Dataset with optimization: 3.922 sec
Last active
September 27, 2020 06:13
-
-
Save nwatab/aa068bd5ff976ab07be337c65e73504c to your computer and use it in GitHub Desktop.
Load mp3 data efficiently (Tensorflow==2.3.0)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import tensorflow as tf | |
import tensorflow_io as tfio | |
# Length, in seconds, of the audio clip fed to the spectrogram.
duration = 5
# Sample rate (Hz) that every file is resampled to.
rate = 44100
# Number of samples in one clip at the target sample rate.
samples = rate * duration
def dbscale(input, top_db, name=None):
    """Convert a magnitude spectrogram to a decibel scale.

    The input is squared to obtain power, mapped to `10 * log10(power)`,
    and then clamped from below so that no value is more than `top_db`
    under the maximum of the result.

    Args:
        input: A spectrogram Tensor.
        top_db: Dynamic range to keep; values below
            `max(10 * log10(S)) - top_db` are raised to that floor.
        name: A name for the operation (optional). Currently unused.

    Returns:
        A tensor with the same shape as `input`, in dB.
    """
    ln10 = tf.math.log(10.0)
    # 10 * log10(x) written with natural logs: 10 * ln(x) / ln(10).
    decibels = 10.0 * (tf.math.log(tf.math.square(input)) / ln10)
    floor = tf.math.reduce_max(decibels) - top_db
    return tf.math.maximum(decibels, floor)
def process_paths(file_paths):
    """Apply `process_path` to every path in a batch.

    Args:
        file_paths: A tensor of audio file paths (one string per element).

    Returns:
        A tuple `(db_mel_spectrograms, labels)` where the first item is a
        tf.float32 tensor stacking the per-file spectrograms and the second
        is a tf.string tensor of the matching labels.
    """
    # fn_output_signature is required because the outputs (float32, string)
    # differ in structure and dtype from the string input.
    return tf.map_fn(
        fn=process_path,
        elems=file_paths,
        fn_output_signature=(tf.float32, tf.string),
    )
def process_path(file_path):
    """Load one audio file and return its dB mel spectrogram and label.

    The audio is resampled to `rate`, averaged over channels to mono,
    zero-padded symmetrically if it is shorter than `samples`
    (`duration` seconds), and then randomly cropped to exactly `samples`
    samples before the STFT is taken.

    Args:
        file_path: Scalar string tensor holding the path of an audio file.
            The name of the file's parent directory is used as the label.

    Returns:
        A tuple `(db_mel_spectrogram, label)`: the dB-scaled mel spectrogram
        (128 mel bands between 1000 Hz and 12000 Hz, `top_db=80`) and the
        label as a string tensor.
    """
    # .../<label>/<file>.mp3 -> the second-to-last path component.
    label = tf.strings.split(file_path, os.sep)[-2]
    audio = tfio.audio.AudioIOTensor(file_path, dtype=tf.float32)
    re_audio = tfio.audio.resample(
        audio.to_tensor(),
        rate_in=tf.cast(audio.rate, tf.int64),
        rate_out=rate
    )
    re_audio_1c = tf.reduce_mean(re_audio, axis=-1)  # [samples, channels] => [samples]
    # Number of zeros needed to reach `samples`; 0 when the clip is long enough.
    zeros = tf.math.maximum(samples - tf.shape(re_audio_1c)[0], 0)
    # Split the padding evenly; the odd sample (if any) goes at the end.
    paddings = [[zeros // 2, zeros // 2 + zeros % 2]]
    pad_audio = tf.pad(re_audio_1c, paddings=paddings, mode='CONSTANT')  # pad if audio is too short
    # Take a random window of exactly `samples` samples.
    cropped_audio = tf.image.random_crop(pad_audio, [samples])
    spectrogram = tf.math.abs(
        tf.signal.stft(
            cropped_audio,
            frame_length=512,
            frame_step=256,
            fft_length=512,
            window_fn=tf.signal.hann_window,
            pad_end=True,
        )
    )
    # Use the module-level `rate` (the rate the audio was resampled to above)
    # instead of a second hard-coded 44100, so the two cannot drift apart.
    mel_spectrogram = tfio.experimental.audio.melscale(
        spectrogram, rate=rate, mels=128, fmin=1000, fmax=12000
    )
    dbscale_mel_spectrogram = tfio.experimental.audio.dbscale(
        mel_spectrogram, top_db=80)
    # NOTE: the original author observed the following error here and
    # suspected it was related to returning a string from this function:
    #   tensorflow.python.framework.errors_impl.InvalidArgumentError:
    #       unable to seek to: 0  [[{{node IO>AudioReadableRead}}]]
    return dbscale_mel_spectrogram, label
if __name__ == '__main__':
    import time
    import matplotlib.pyplot as plt  # not used below; kept from the original script

    glob_path = './data/train/*/*.mp3'
    autotune = tf.data.experimental.AUTOTUNE

    def consume(dataset):
        """Read elements until the 11th has been seen, sleeping 10 ms after each."""
        for step, _ in enumerate(dataset, start=1):
            time.sleep(0.01)
            if step > 10:
                break

    # Optimized way
    # Two interleaved file listings; the range value itself (0 or 1) is ignored.
    path_ds = tf.data.Dataset.range(2).interleave(
        lambda *_unused: tf.data.Dataset.list_files(glob_path),
        num_parallel_calls=autotune,
    )
    data = (
        path_ds
        .batch(4, drop_remainder=True)
        .map(process_paths, num_parallel_calls=autotune)
        .cache()
        .prefetch(autotune)
        # .unbatch()
    )
    started = time.perf_counter()
    consume(data)
    print(time.perf_counter() - started)

    # Normal way
    # NOTE(review): this baseline reads from a different glob than `glob_path`,
    # and its timer also covers building the pipeline (the dataset is created
    # after the clock starts) -- confirm both are intended.
    started = time.perf_counter()
    consume(
        tf.data.Dataset.list_files('./data/xeno_canto_ja/train/*/*.mp3')
        .map(process_path)
        .batch(4)
    )
    print(time.perf_counter() - started)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
$python opt_dataloader.py | |
2020-09-27 00:33:41.905781: I tensorflow_io/core/kernels/cpu_check.cc:128] Your CPU supports instructions that this TensorFlow IO binary was not compiled to use: AVX2 FMA | |
2020-09-27 00:33:42.191928: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN)to use the following CPU instructions in performance-critical operations: AVX2 FMA | |
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags. | |
2020-09-27 00:33:42.202688: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7fb7baafcbc0 initialized for platform Host (this does not guarantee that XLA will be used). Devices: | |
2020-09-27 00:33:42.202706: I tensorflow/compiler/xla/service/service.cc:176] StreamExecutor device (0): Host, Default Version | |
2020-09-27 00:33:49.059625: W tensorflow/core/kernels/data/cache_dataset_ops.cc:798] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. | |
3.9223966150000003 | |
9.423265921999999 | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment