OpenAI Whisper Benchmark
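Benchmarks openai/whisper-base transcription on the LibriSpeech clean validation split, comparing generation speed across three backends: PyTorch, TensorFlow eager, and TensorFlow with XLA compilation. Each run transcribes a batch of 16 audio samples with greedy decoding (num_beams=1) and up to 256 new tokens; the script reports the first-call latency separately from the mean over the remaining runs.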
import time
from datetime import timedelta
from functools import wraps

from tqdm import tqdm

# PyTorch imports and settings
import torch
from transformers.testing_utils import torch_device

torch.backends.cuda.matmul.allow_tf32 = True  # all frameworks use TF32, for a fair comparison

# TF imports and settings
import tensorflow as tf

physical_devices = tf.config.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], True)

# Transformers imports last, as they would import the frameworks (and we want to set the options above first)
from transformers import (
    WhisperProcessor, TFWhisperForConditionalGeneration, WhisperForConditionalGeneration
)
from datasets import load_dataset

MODEL_NAME = "openai/whisper-base"
NUM_RUNS = 100
MAX_NEW_TOKENS = 256
NUM_BEAMS = 1
BATCH_SIZE = 16

DS = load_dataset("librispeech_asr", "clean", split="validation")
SPEECH_SAMPLES = DS.select(range(NUM_RUNS * BATCH_SIZE))[:NUM_RUNS * BATCH_SIZE]["audio"]
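
# Illustrative helper (not called by the benchmark): each element of SPEECH_SAMPLES
# is a dict produced by the datasets `Audio` feature, holding the decoded waveform
# under "array" and its rate under "sampling_rate" (16 kHz for LibriSpeech).
def _inspect_first_sample():
    sample = SPEECH_SAMPLES[0]
    print(sample["sampling_rate"])  # 16000
    print(sample["array"].shape)    # 1-D float waveform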

def measure_time(function):
    """Decorator that returns a function's output together with its wall-clock execution time."""
    @wraps(function)
    def decorated(inputs):
        start = time.time()
        fn_output = function(inputs)
        end = time.time()
        duration = timedelta(seconds=end - start)
        return fn_output, duration
    return decorated
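
# Illustrative sketch (not called by the benchmark): the decorator returns the
# wrapped function's output together with a datetime.timedelta of the wall time.
def _measure_time_demo():
    @measure_time
    def _double(x):
        return x * 2

    out, took = _double(21)
    print(out, took.total_seconds())  # 42, followed by the elapsed seconds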

def get_processor(model_name):
    processor = WhisperProcessor.from_pretrained(model_name)
    return processor


def get_inputs(processor, index, return_tensors):
    input_speech = [SPEECH_SAMPLES[(index * BATCH_SIZE) + i]["array"] for i in range(BATCH_SIZE)]
    inputs = processor.feature_extractor(raw_speech=input_speech, return_tensors=return_tensors)
    if return_tensors == "tf":
        # rename to the keyword the TF model expects for the audio features
        inputs["input_ids"] = inputs.pop("input_features")
    return inputs

def get_model(framework, model_name):
    if framework == "tf":
        model = TFWhisperForConditionalGeneration.from_pretrained(model_name, from_pt=True)
    elif framework == "pt":
        model = WhisperForConditionalGeneration.from_pretrained(model_name)
    return model

def print_status(all_outputs, all_durations):
    # durations are stored in microseconds, hence the division by 1000 to get ms;
    # the first call is reported separately and excluded from the mean, since it
    # may include one-off costs (e.g. XLA tracing and compilation)
    print(f"Execution time -- 1st call: {all_durations[0]/1000:.2f} ms")
    all_durations = all_durations[1:]
    print(f"Execution time -- mean: {(sum(all_durations)/len(all_durations))/1000:.2f} ms")
    all_outputs = all_outputs[1:]
    try:
        mean_length = sum([out.shape[1] for out in all_outputs]) / len(all_outputs)
    except AttributeError:
        # generate may return a ModelOutput holding the tokens under `sequences`
        mean_length = sum([out.sequences.shape[1] for out in all_outputs]) / len(all_outputs)
    print(f"Outputs -- mean length: {mean_length:.2f} tokens")

def main_tf_eager():
    model = get_model("tf", MODEL_NAME)
    processor = get_processor(MODEL_NAME)

    @measure_time
    def _generate(inputs):
        inputs.update({"num_beams": NUM_BEAMS, "max_new_tokens": MAX_NEW_TOKENS})
        return model.generate(**inputs)

    all_durations = []
    all_outputs = []
    for i in tqdm(range(int(NUM_RUNS / 10))):  # TF eager is very slow, so run 10x fewer batches
        inputs = get_inputs(processor, index=i, return_tensors="tf")
        gen_out, duration = _generate(inputs)
        all_durations.append(duration.microseconds + duration.seconds * 1e6)
        all_outputs.append(gen_out)

    print_status(all_outputs, all_durations)
    return all_outputs

def main_tf_xla():
    model = get_model("tf", MODEL_NAME)
    processor = get_processor(MODEL_NAME)
    xla_generate = tf.function(model.generate, jit_compile=True)

    @measure_time
    def _generate(inputs):
        inputs.update({"num_beams": NUM_BEAMS, "max_new_tokens": MAX_NEW_TOKENS})
        return xla_generate(**inputs)

    all_durations = []
    all_outputs = []
    for i in tqdm(range(NUM_RUNS)):
        inputs = get_inputs(processor, index=i, return_tensors="tf")
        gen_out, duration = _generate(inputs)
        all_durations.append(duration.microseconds + duration.seconds * 1e6)
        all_outputs.append(gen_out)

    print_status(all_outputs, all_durations)
    return all_outputs
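
# Note on the XLA path above: `tf.function(..., jit_compile=True)` traces and
# compiles on the first call, and re-traces whenever input shapes change, which
# is why print_status reports the 1st call separately and drops it from the mean.
# A minimal standalone sketch (not called by the benchmark) of the same pattern:
def _xla_pattern_demo():
    @tf.function(jit_compile=True)
    def xla_add(a, b):
        return a + b

    xla_add(tf.constant(1.0), tf.constant(2.0))  # first call: trace + XLA compile
    xla_add(tf.constant(3.0), tf.constant(4.0))  # same shapes: reuses the compiled program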

def main_pt():
    model = get_model("pt", MODEL_NAME)
    model.to(torch_device)
    processor = get_processor(MODEL_NAME)

    @measure_time
    def _generate(inputs):
        with torch.no_grad():
            inputs.update({"num_beams": NUM_BEAMS, "max_new_tokens": MAX_NEW_TOKENS})
            return model.generate(**inputs)

    all_durations = []
    all_outputs = []
    for i in tqdm(range(NUM_RUNS)):
        inputs = get_inputs(processor, index=i, return_tensors="pt")
        inputs.to(torch_device)
        gen_out, duration = _generate(inputs)
        all_durations.append(duration.microseconds + duration.seconds * 1e6)
        all_outputs.append(gen_out)

    print_status(all_outputs, all_durations)

    # free GPU memory before the TF benchmarks run
    del model
    torch.cuda.empty_cache()
    return all_outputs

if __name__ == "__main__":
    print("\n\nPYTORCH")
    main_pt()
    print("\n\nTF (NO XLA)")
    main_tf_eager()
    print("\n\nTF (XLA)")
    main_tf_xla()