Last active
October 7, 2022 12:15
-
-
Save gante/b7fde5df116e95f22f8716062dfa5079 to your computer and use it in GitHub Desktop.
OpenAI Whisper Benchmark
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
from datetime import timedelta | |
from functools import wraps | |
from tqdm import tqdm | |
# PyTorch imports and settings | |
import torch | |
from transformers.testing_utils import torch_device | |
torch.backends.cuda.matmul.allow_tf32 = True # All frameworks using TF32 | |
# TF imports and settings | |
import tensorflow as tf | |
physical_devices = tf.config.list_physical_devices('GPU') | |
tf.config.experimental.set_memory_growth(physical_devices[0], True) | |
# Transformers imports last, as they would import the frameworks (and we want to set options before)
from transformers import ( | |
WhisperProcessor, TFWhisperForConditionalGeneration, WhisperForConditionalGeneration | |
) | |
from datasets import load_dataset | |
MODEL_NAME = "openai/whisper-base"  # Hugging Face checkpoint to benchmark
NUM_RUNS = 100  # number of timed generate() calls per framework
MAX_NEW_TOKENS = 256  # cap on generated sequence length
NUM_BEAMS = 1  # 1 beam == greedy decoding
BATCH_SIZE = 16  # audio clips per generate() call
# Benchmark corpus: LibriSpeech ASR, clean validation split.
DS = load_dataset("librispeech_asr", "clean", split="validation")
# Pre-select enough audio samples to cover every run's full batch.
SPEECH_SAMPLES = DS.select(range(NUM_RUNS * BATCH_SIZE))[:NUM_RUNS * BATCH_SIZE]["audio"]
def measure_time(function):
    """Decorator that times a call and returns ``(result, duration)``.

    The wrapped function's return value becomes the first element of a
    2-tuple; the second element is the wall-clock duration of the call
    as a ``datetime.timedelta``.
    """
    @wraps(function)
    def decorated(*args, **kwargs):
        # perf_counter is monotonic and higher-resolution than time.time,
        # which makes it the correct clock for benchmarking.
        start = time.perf_counter()
        fn_output = function(*args, **kwargs)
        duration = timedelta(seconds=time.perf_counter() - start)
        return fn_output, duration

    return decorated
def get_processor(model_name):
    """Load and return the Whisper processor for ``model_name``."""
    return WhisperProcessor.from_pretrained(model_name)
def get_inputs(processor, index, return_tensors):
    """Build one batch of feature-extractor inputs for benchmark run ``index``.

    Slices BATCH_SIZE consecutive audio arrays out of SPEECH_SAMPLES and
    runs them through the processor's feature extractor. For TF, the
    feature key is renamed to ``input_ids`` as the TF model expects.
    """
    offset = index * BATCH_SIZE
    raw_audio = [SPEECH_SAMPLES[offset + j]["array"] for j in range(BATCH_SIZE)]
    batch = processor.feature_extractor(raw_speech=raw_audio, return_tensors=return_tensors)
    if return_tensors == "tf":
        batch["input_ids"] = batch.pop("input_features")
    return batch
def get_model(framework, model_name):
    """Load the Whisper model for the requested framework.

    Args:
        framework: ``"tf"`` or ``"pt"``.
        model_name: Hugging Face model identifier.

    Returns:
        The loaded (TF)WhisperForConditionalGeneration model.

    Raises:
        ValueError: if ``framework`` is neither ``"tf"`` nor ``"pt"``
            (the original code raised an opaque UnboundLocalError here).
    """
    if framework == "tf":
        # No native TF checkpoint for this model, so convert from PyTorch.
        return TFWhisperForConditionalGeneration.from_pretrained(model_name, from_pt=True)
    if framework == "pt":
        return WhisperForConditionalGeneration.from_pretrained(model_name)
    raise ValueError(f"Unknown framework: {framework!r} (expected 'tf' or 'pt')")
def print_status(all_outputs, all_durations):
    """Print timing and output-length statistics for a benchmark run.

    Durations are expected in microseconds and printed in milliseconds.
    The first call is reported separately -- it typically includes
    tracing/compilation overhead -- and is excluded from the mean.

    Args:
        all_outputs: per-run generate() outputs (tensors, or objects
            carrying the tensor under ``.sequences``).
        all_durations: per-run durations in microseconds.
    """
    print(f"Execution time -- 1st call: {all_durations[0]/1000:.2f} ms")
    # Drop the warm-up call from the aggregate statistics.
    steady_durations = all_durations[1:]
    steady_outputs = all_outputs[1:]
    if not steady_durations:
        # A single measurement has no steady-state mean to report
        # (the original code raised ZeroDivisionError here).
        return
    print(f"Execution time -- mean: {(sum(steady_durations)/len(steady_durations))/1000:.2f} ms")
    # Outputs are either plain tensors with .shape, or ModelOutput-style
    # objects holding the tensor under .sequences; getattr covers both
    # per-item instead of the original all-or-nothing bare except.
    lengths = [getattr(out, "sequences", out).shape[1] for out in steady_outputs]
    print(f"Outputs -- mean length: {sum(lengths)/len(lengths):.2f} tokens")
def main_tf_eager():
    """Benchmark TF Whisper generation in eager mode (no XLA)."""
    model = get_model("tf", MODEL_NAME)
    processor = get_processor(MODEL_NAME)

    @measure_time
    def _generate(inputs):
        inputs.update({"num_beams": NUM_BEAMS, "max_new_tokens": MAX_NEW_TOKENS})
        return model.generate(**inputs)

    all_durations, all_outputs = [], []
    # Eager TF generation is very slow, so only a tenth of the runs.
    num_eager_runs = int(NUM_RUNS / 10)
    for run_idx in tqdm(range(num_eager_runs)):
        batch = get_inputs(processor, index=run_idx, return_tensors="tf")
        gen_out, elapsed = _generate(batch)
        all_durations.append(elapsed.microseconds + elapsed.seconds * 1e6)
        all_outputs.append(gen_out)
    print_status(all_outputs, all_durations)
    return all_outputs
def main_tf_xla():
    """Benchmark TF Whisper generation compiled with XLA."""
    model = get_model("tf", MODEL_NAME)
    processor = get_processor(MODEL_NAME)
    # jit_compile=True traces generate() into a single XLA program.
    xla_generate = tf.function(model.generate, jit_compile=True)

    @measure_time
    def _generate(inputs):
        inputs.update({"num_beams": NUM_BEAMS, "max_new_tokens": MAX_NEW_TOKENS})
        return xla_generate(**inputs)

    all_durations, all_outputs = [], []
    for run_idx in tqdm(range(NUM_RUNS)):
        batch = get_inputs(processor, index=run_idx, return_tensors="tf")
        gen_out, elapsed = _generate(batch)
        all_durations.append(elapsed.microseconds + elapsed.seconds * 1e6)
        all_outputs.append(gen_out)
    print_status(all_outputs, all_durations)
    return all_outputs
def main_pt():
    """Benchmark PyTorch Whisper generation on ``torch_device``."""
    model = get_model("pt", MODEL_NAME)
    model.to(torch_device)
    processor = get_processor(MODEL_NAME)

    @measure_time
    def _generate(inputs):
        # Inference only -- no gradients needed.
        with torch.no_grad():
            inputs.update({"num_beams": NUM_BEAMS, "max_new_tokens": MAX_NEW_TOKENS})
            return model.generate(**inputs)

    all_durations, all_outputs = [], []
    for run_idx in tqdm(range(NUM_RUNS)):
        batch = get_inputs(processor, index=run_idx, return_tensors="pt")
        batch.to(torch_device)
        gen_out, elapsed = _generate(batch)
        all_durations.append(elapsed.microseconds + elapsed.seconds * 1e6)
        all_outputs.append(gen_out)
    print_status(all_outputs, all_durations)
    # Release the model's GPU memory before the TF benchmarks run.
    del model
    torch.cuda.empty_cache()
    return all_outputs
if __name__ == "__main__":
    # PyTorch runs first so its GPU memory is freed (see main_pt) before
    # TensorFlow starts allocating.
    print("\n\nPYTORCH")
    main_pt()
    print("\n\nTF (NO XLA)")
    main_tf_eager()
    print("\n\nTF (XLA)")
    main_tf_xla()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment