Skip to content

Instantly share code, notes, and snippets.

@aleksandr-smechov
Last active November 17, 2023 05:09
Show Gist options
  • Save aleksandr-smechov/c789caa0b65772865a3dc1e60e0f2c5d to your computer and use it in GitHub Desktop.
Client-side distil-whisper streaming script
import datetime
import platform
import subprocess
import sys
import asyncio
import websockets
import numpy as np
from typing import Tuple, Optional, Union
def _ffmpeg_stream(ffmpeg_command, buflen: int):
"""
Internal function to create the generator of data through ffmpeg
"""
bufsize = 2**24
try:
with subprocess.Popen(ffmpeg_command, stdout=subprocess.PIPE, bufsize=bufsize) as ffmpeg_process:
while True:
raw = ffmpeg_process.stdout.read(buflen)
if raw == b"":
break
yield raw
except FileNotFoundError as error:
raise ValueError("ffmpeg was not found but is required to stream audio files from filename") from error
def ffmpeg_microphone(
    sampling_rate: int,
    chunk_length_s: float,
    format_for_conversion: str = "f32le",
):
    """Helper generator that reads raw microphone data through ffmpeg.

    Args:
        sampling_rate: capture sample rate in Hz.
        chunk_length_s: duration of each yielded chunk, in seconds.
        format_for_conversion: raw sample format, ``"s16le"`` or ``"f32le"``.

    Yields:
        Raw audio bytes, ``chunk_length_s`` seconds per item.

    Raises:
        ValueError: for an unhandled sample format, an unsupported operating
            system, or when ffmpeg is not installed.
    """
    ar = f"{sampling_rate}"
    ac = "1"  # mono capture
    if format_for_conversion == "s16le":
        size_of_sample = 2
    elif format_for_conversion == "f32le":
        size_of_sample = 4
    else:
        raise ValueError(f"Unhandled format `{format_for_conversion}`. Please use `s16le` or `f32le`")

    system = platform.system()
    if system == "Linux":
        format_ = "alsa"
        input_ = "default"
    elif system == "Darwin":
        format_ = "avfoundation"
        input_ = ":0"
    elif system == "Windows":
        format_ = "dshow"
        input_ = "default"
    else:
        # Previously an unknown platform fell through with `format_`/`input_`
        # unassigned and crashed below with a NameError; fail explicitly.
        raise ValueError(f"Unsupported platform for microphone capture: `{system}`")

    ffmpeg_command = [
        "ffmpeg",
        "-f",
        format_,
        "-i",
        input_,
        "-ac",
        ac,
        "-ar",
        ar,
        "-f",
        format_for_conversion,
        "-fflags",
        "nobuffer",
        "-hide_banner",
        "-loglevel",
        "quiet",
        "pipe:1",
    ]
    # Bytes per chunk = samples per chunk * bytes per sample.
    chunk_len = int(round(sampling_rate * chunk_length_s)) * size_of_sample
    yield from _ffmpeg_stream(ffmpeg_command, chunk_len)
def chunk_bytes_iter(iterator, chunk_len: int, stride: Tuple[int, int], stream: bool = False):
    """Re-chunk a byte iterator into items of exactly ``chunk_len`` bytes.

    Consecutive chunks overlap by ``stride`` = (left, right) bytes so a
    consumer can discard the overlapping borders. With ``stream=True``,
    partial buffers are surfaced early, flagged via the ``"partial"`` key.

    Raises:
        ValueError: if the combined stride is not strictly smaller than
            ``chunk_len``.
    """
    stride_left, stride_right = stride
    if stride_left + stride_right >= chunk_len:
        raise ValueError(
            f"Stride needs to be strictly smaller than chunk_len: ({stride_left}, {stride_right}) vs {chunk_len}"
        )
    buffer = b""
    # The very first chunk has no data to its left, so its left stride is 0.
    current_left = 0
    for raw in iterator:
        buffer += raw
        if stream and len(buffer) < chunk_len:
            # Not enough data for a full chunk yet; surface what we have.
            yield {"raw": buffer[:chunk_len], "stride": (current_left, 0), "partial": True}
            continue
        while len(buffer) >= chunk_len:
            item = {"raw": buffer[:chunk_len], "stride": (current_left, stride_right)}
            if stream:
                item["partial"] = False
            yield item
            current_left = stride_left
            # Keep the overlapping tail around for the next chunk.
            buffer = buffer[chunk_len - stride_left - stride_right:]
    # Flush whatever remains, provided it extends past the left overlap.
    if len(buffer) > stride_left:
        item = {"raw": buffer, "stride": (current_left, 0)}
        if stream:
            item["partial"] = False
        yield item
def ffmpeg_microphone_live(
    sampling_rate: int,
    chunk_length_s: float,
    stream_chunk_s: Optional[int] = None,
    stride_length_s: Optional[Union[Tuple[float, float], float]] = None,
    format_for_conversion: str = "f32le",
):
    """Stream live microphone audio as overlapping numpy chunks.

    Args:
        sampling_rate: capture sample rate in Hz.
        chunk_length_s: length of each logical chunk, in seconds.
        stream_chunk_s: if given, how often (seconds) partial data is read
            from the microphone; defaults to ``chunk_length_s``.
        stride_length_s: overlap between chunks, either one value for both
            sides or a (left, right) pair; defaults to ``chunk_length_s / 6``.
        format_for_conversion: raw sample format, ``"s16le"`` or ``"f32le"``.

    Yields:
        Dicts with ``raw`` (numpy array), ``stride`` (in samples),
        ``sampling_rate`` and ``partial`` keys.

    Raises:
        ValueError: for an unhandled sample format.
    """
    chunk_s = stream_chunk_s if stream_chunk_s is not None else chunk_length_s
    microphone = ffmpeg_microphone(sampling_rate, chunk_s, format_for_conversion=format_for_conversion)

    if format_for_conversion == "s16le":
        dtype, size_of_sample = np.int16, 2
    elif format_for_conversion == "f32le":
        dtype, size_of_sample = np.float32, 4
    else:
        raise ValueError(f"Unhandled format `{format_for_conversion}`. Please use `s16le` or `f32le`")

    if stride_length_s is None:
        stride_length_s = chunk_length_s / 6
    if isinstance(stride_length_s, (int, float)):
        stride_length_s = [stride_length_s, stride_length_s]

    # All lengths below are in bytes.
    chunk_len = int(round(sampling_rate * chunk_length_s)) * size_of_sample
    stride_left = int(round(sampling_rate * stride_length_s[0])) * size_of_sample
    stride_right = int(round(sampling_rate * stride_length_s[1])) * size_of_sample

    audio_time = datetime.datetime.now()
    delta = datetime.timedelta(seconds=chunk_s)
    for item in chunk_bytes_iter(microphone, chunk_len, stride=(stride_left, stride_right), stream=True):
        item["raw"] = np.frombuffer(item["raw"], dtype=dtype)
        # Convert the stride from bytes back into sample counts.
        left, right = item["stride"]
        item["stride"] = (left // size_of_sample, right // size_of_sample)
        item["sampling_rate"] = sampling_rate
        audio_time += delta
        # Drop chunks once we fall more than ten chunk-durations behind
        # real time, so the consumer catches up instead of lagging forever.
        if datetime.datetime.now() > audio_time + 10 * delta:
            continue
        yield item
async def send_audio(websocket, mic_capture):
    """Continuously pull chunks from ``mic_capture`` and send them on ``websocket``.

    Each item is expected to be a dict whose ``"raw"`` value supports
    ``tobytes()`` (e.g. a numpy array). Runs until the iterator or socket
    fails; the error is reported and the coroutine returns.
    """
    try:
        while True:
            # ``next`` blocks on the microphone pipe, so keep it off the event loop.
            chunk = await asyncio.to_thread(next, mic_capture)
            await websocket.send(chunk["raw"].tobytes())
    except Exception as error:
        print(f"Error sending audio: {error}")
async def display_transcription(websocket, typing_speed=0.05):
    """Render transcriptions from ``websocket`` in place with a typing effect.

    Each incoming message replaces the previous one on the same terminal
    line; only the characters after the first difference are "typed" one by
    one, pausing ``typing_speed`` seconds per character.

    Args:
        websocket: async iterable yielding transcription strings.
        typing_speed: delay between typed characters, in seconds.
    """
    displayed_text = ""
    async for transcription in websocket:
        # Unchanged text: leave the line on screen. (Previously the line was
        # erased here but never redrawn, blanking the display.)
        if transcription == displayed_text:
            continue
        # First index at which the new text differs from what is on screen.
        min_length = min(len(displayed_text), len(transcription))
        diff_index = next(
            (i for i in range(min_length) if displayed_text[i] != transcription[i]),
            min_length,
        )
        # Erase the old line, re-print the common prefix at once, then type
        # the remainder character by character.
        sys.stdout.write('\r' + ' ' * len(displayed_text) + '\r')
        sys.stdout.flush()
        sys.stdout.write(transcription[:diff_index])
        for char in transcription[diff_index:]:
            sys.stdout.write(char)
            sys.stdout.flush()
            await asyncio.sleep(typing_speed)
        displayed_text = transcription
    print()
async def receive_transcription(websocket, display_func):
    """Forward every message received on ``websocket`` to ``display_func``.

    ``display_func`` is awaited as ``display_func(websocket, message)``.
    Loops until the connection raises, which is reported and swallowed.
    """
    try:
        while True:
            message = await websocket.recv()
            await display_func(websocket, message)
    except Exception as error:
        print(f"Error receiving transcription: {error}")
async def send_audio_and_receive_transcription(uri, sampling_rate, chunk_length_s, stream_chunk_s):
    """Connect to the transcription service at ``uri`` and stream the microphone.

    Sends live microphone audio over the websocket while rendering the
    transcriptions the server sends back, until either side fails.

    Args:
        uri: websocket endpoint, e.g. ``ws://host:port/ws/transcribe``.
        sampling_rate: capture sample rate in Hz.
        chunk_length_s: logical chunk length in seconds.
        stream_chunk_s: partial-read interval in seconds.
    """
    async with websockets.connect(uri) as websocket:
        # ffmpeg_microphone_live is a generator function: calling it only
        # builds a lazy generator, so no thread hand-off is needed here.
        mic_capture = ffmpeg_microphone_live(
            sampling_rate=sampling_rate,
            chunk_length_s=chunk_length_s,
            stream_chunk_s=stream_chunk_s,
        )
        send_task = asyncio.create_task(send_audio(websocket, mic_capture))
        # display_transcription consumes the websocket itself. The previous
        # version additionally passed the display *task* object as a callable
        # to receive_transcription, which raised TypeError on the first
        # message and made two coroutines compete for the same socket.
        display_task = asyncio.create_task(display_transcription(websocket))
        await asyncio.gather(send_task, display_task)
if __name__ == "__main__":
    # Entry point: replace IP:PORT with the transcription server's address.
    # (The original ran unconditionally at import time and used a pointless
    # f-string with no placeholders.)
    asyncio.run(send_audio_and_receive_transcription("ws://IP:PORT/ws/transcribe", 16000, 10.0, 1))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment