Skip to content

Instantly share code, notes, and snippets.

@chaonan99
Created May 23, 2019 01:33
Show Gist options
  • Save chaonan99/20686fb6e94af757ac5a4c28fa0ff018 to your computer and use it in GitHub Desktop.
Save chaonan99/20686fb6e94af757ac5a4c28fa0ff018 to your computer and use it in GitHub Desktop.
Experiment with the TFRecord files provided by Onsets and Frames.
import os
import tensorflow as tf
from magenta.models.onsets_frames_transcription import configs
from magenta.models.onsets_frames_transcription import constants
from magenta.models.onsets_frames_transcription import data
from magenta.models.onsets_frames_transcription import split_audio_and_label_data
from magenta.models.onsets_frames_transcription import train_util
from magenta.music import midi_io
from magenta.protobuf import music_pb2
from magenta.music import sequences_lib
from magenta.common import flatten_maybe_padded_sequences
def main(output_wav='test.wav', output_midi='test_direct.midi'):
    """Dump the first record of the onset training TFRecords to audio + MIDI.

    Reads the TFRecord shards matching
    ``$DATA_DIR/onset_maestro_v1/train.tfrecord*`` (listed with
    ``shuffle=False`` so the shard order is deterministic), parses the first
    serialized example, and writes:

    * its raw ``audio`` bytes to ``output_wav`` (presumably WAV-encoded,
      judging by the default output name -- verify against the data), and
    * its ``sequence`` field, decoded as a ``music_pb2.NoteSequence``, to
      ``output_midi`` as a MIDI file.

    Args:
      output_wav: Path the raw audio bytes are written to.
      output_midi: Path the MIDI rendering of the note sequence is written to.

    Raises:
      KeyError: If the ``DATA_DIR`` environment variable is not set.
    """
    pattern = os.path.join(
        os.environ['DATA_DIR'], 'onset_maestro_v1/train.tfrecord*')
    filenames = tf.data.Dataset.list_files(pattern, shuffle=False)
    dataset = tf.data.TFRecordDataset(filenames)
    iterator = dataset.make_initializable_iterator()
    serialized_example = iterator.get_next()

    # All four fields are declared (even though only two are used) so that
    # parsing still fails loudly on records missing any expected key.
    features = {
        'id': tf.FixedLenFeature(shape=(), dtype=tf.string),
        'sequence': tf.FixedLenFeature(shape=(), dtype=tf.string),
        'audio': tf.FixedLenFeature(shape=(), dtype=tf.string),
        'velocity_range': tf.FixedLenFeature(shape=(), dtype=tf.string),
    }
    record = tf.parse_single_example(serialized_example, features)

    with tf.Session() as sess:
        sess.run(iterator.initializer)
        # Fetch both fields in ONE run call: every run that evaluates a
        # tensor depending on get_next() advances the iterator, so separate
        # evals would pair the audio of one record with the sequence of the
        # next one.
        audio_content, sequence_content = sess.run(
            [record['audio'], record['sequence']])

    with open(output_wav, 'wb') as f:
        f.write(audio_content)

    note_sequence = music_pb2.NoteSequence.FromString(sequence_content)
    midi_io.sequence_proto_to_midi_file(note_sequence, output_midi)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment