Skip to content

Instantly share code, notes, and snippets.

@lewoudar
Created June 22, 2024 11:55
Show Gist options
  • Save lewoudar/4e72d106d67ebc131597347618598eb7 to your computer and use it in GitHub Desktop.
Playing with the ctranslate2 backend
import os
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal

from faster_whisper import WhisperModel

# Workaround for duplicate OpenMP runtimes when openai-whisper /
# transformers libraries are installed alongside ctranslate2.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

model_size = "large-v3"

# Alternative configurations:
#   GPU, FP16: WhisperModel(model_size, device="cuda", compute_type="float16")
#   GPU, INT8: WhisperModel(model_size, device="cuda", compute_type="int8_float16")
# CPU with INT8 quantization:
model = WhisperModel(model_size, device="cpu", compute_type="int8")

path = Path(__file__).parent.parent / 'sample.wav'
@dataclass
class Segment:
    """A transcribed chunk of audio: its text and time span in seconds."""

    start: float  # segment start time, in seconds
    end: float    # segment end time, in seconds
    text: str     # transcribed text for this span
@dataclass
class Transcriber:
    """Transcribes an audio file with faster-whisper and exports subtitles.

    Attributes:
        model: a loaded faster-whisper ``WhisperModel``.
        audio: path of the audio file to transcribe.
    """

    model: WhisperModel
    audio: Path
    # full transcript text, filled by transcribe()
    _text: str = field(init=False, default='')
    # per-utterance segments, filled by transcribe()
    _segments: list[Segment] = field(default_factory=list, init=False)

    def transcribe(self) -> None:
        """Run the model on ``self.audio``, populating the text and segments."""
        segments, info = self.model.transcribe(self.audio.as_posix(), beam_size=5)
        print(f"Detected language '{info.language}' with probability {info.language_probability:f}")
        # `segments` is a lazy generator: collect text pieces and join once
        # instead of quadratic repeated string concatenation.
        pieces = []
        for segment in segments:
            pieces.append(segment.text)
            self._segments.append(Segment(segment.start, segment.end, segment.text))
        self._text += ''.join(pieces)
        if self._segments:  # guard: silent audio yields no segments
            print(self._segments[0])

    @staticmethod
    def _format_timestamp(seconds: float, separator: str = ',') -> str:
        """Format seconds as ``HH:MM:SS<sep>mmm`` (``,`` for SRT, ``.`` for VTT)."""
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        secs = seconds % 60
        return f"{hours:02}:{minutes:02}:{secs:06.3f}".replace('.', separator)

    @property
    def text(self) -> str:
        """The full transcribed text (empty before transcribe() runs)."""
        return self._text

    def _get_writable_path(self, filename: str | None, suffix: Literal['vtt', 'srt']) -> Path:
        """Return the output path, defaulting to the audio path with ``suffix``.

        Raises:
            PermissionError: if ``filename`` (or, for a file that does not
                exist yet, its parent directory) is not writable.
        """
        if filename is not None:
            target = Path(filename)
            # os.access() returns False for a file that does not exist yet,
            # so probe the parent directory in that case.
            probe = target if target.exists() else target.resolve().parent
            if not os.access(probe, os.W_OK):
                raise PermissionError(f'{filename} is not writable')
            return target
        return self.audio.resolve().with_suffix(f'.{suffix}')

    def create_vtt_file(self, filename: str | None = None) -> None:
        """Write the segments as a WebVTT subtitle file."""
        vtt_file = self._get_writable_path(filename, 'vtt')
        # explicit encoding: subtitle text is often non-ASCII and the
        # platform default encoding is not guaranteed to be UTF-8
        with vtt_file.open('w', encoding='utf-8') as f:
            f.write('WEBVTT\n\n')
            for segment in self._segments:
                # WebVTT mandates '.' as the milliseconds separator
                start_time = self._format_timestamp(segment.start, separator='.')
                end_time = self._format_timestamp(segment.end, separator='.')
                f.write(f'{start_time} --> {end_time}\n')
                f.write(f'{segment.text}\n\n')

    def create_srt_file(self, filename: str | None = None) -> None:
        """Write the segments as a SubRip (.srt) subtitle file."""
        srt_file = self._get_writable_path(filename, 'srt')
        with srt_file.open('w', encoding='utf-8') as f:
            for index, segment in enumerate(self._segments, start=1):
                start_time = self._format_timestamp(segment.start)
                end_time = self._format_timestamp(segment.end)
                f.write(f'{index}\n')
                f.write(f'{start_time} --> {end_time}\n')
                f.write(f'{segment.text}\n\n')
transcriber = Transcriber(model=model, audio=path)

start = time.perf_counter()
transcriber.transcribe()
elapsed = time.perf_counter() - start
print(f'duration: {elapsed:.2f}s')

print(transcriber.text)
transcriber.create_vtt_file()
transcriber.create_srt_file()
import os
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal

import psutil
import torch
import whisperx
from whisperx.asr import FasterWhisperPipeline

# BUG FIX: torch identifies NVIDIA GPUs as 'cuda', not 'gpu' — the original
# value would make whisperx.load_model fail on a CUDA machine.
device = "cuda" if torch.cuda.is_available() else "cpu"
audio_file = Path(__file__).parent.parent / 'sample.wav'
# float16 is only available on GPU; int8 keeps CPU inference fast
compute_type = "float16" if device == "cuda" else "int8"
model = whisperx.load_model(
    "large-v3",
    device,
    compute_type=compute_type,
    threads=psutil.cpu_count(logical=False),  # physical cores only
    asr_options={'hotwords': None},
)
@dataclass
class Segment:
    """A transcribed piece of audio and its time span in seconds."""

    start: float  # start time, in seconds
    end: float    # end time, in seconds
    text: str     # transcribed text for this span
@dataclass
class Transcriber:
    """Transcribes audio with whisperx (word-level alignment) and exports subtitles.

    Attributes:
        model: a loaded whisperx ``FasterWhisperPipeline``.
        audio: path of the audio file to transcribe.
        batch_size: inference batch size (reduce if low on GPU memory).
        device: torch device identifier; 'cuda' (not 'gpu') is what torch expects.
    """

    model: FasterWhisperPipeline
    audio: Path
    batch_size: int
    device: Literal['cpu', 'cuda'] = 'cpu'
    # full transcript text, filled by transcribe()
    _text: str = field(init=False, default='')
    # word-level segments, filled by transcribe()
    _segments: list[Segment] = field(default_factory=list, init=False)

    def transcribe(self) -> None:
        """Transcribe ``self.audio`` and build word-level segments."""
        # BUG FIX: read self.audio instead of the module-level audio_file,
        # so the instance actually transcribes the file it was given.
        audio = whisperx.load_audio(self.audio.as_posix())
        result = self.model.transcribe(audio, batch_size=self.batch_size)
        self._text = ''.join(segment['text'] for segment in result['segments'])
        # align the whisper output to obtain word-level timestamps
        model_a, metadata = whisperx.load_align_model(language_code=result['language'], device=self.device)
        # BUG FIX: pass self.device instead of the module-level device variable
        result = whisperx.align(result['segments'], model_a, metadata, audio, self.device, return_char_alignments=False)
        for segment in result['segments']:
            for word in segment['words']:
                # NOTE(review): whisperx can omit 'start'/'end' for words it
                # cannot align (e.g. digits) — confirm against the input data.
                self._segments.append(Segment(start=word['start'], end=word['end'], text=word['word']))

    @property
    def text(self) -> str:
        """The full transcribed text (empty before transcribe() runs)."""
        return self._text

    @staticmethod
    def _format_timestamp(seconds: float, separator: str = ',') -> str:
        """Format seconds as ``HH:MM:SS<sep>mmm`` (``,`` for SRT, ``.`` for VTT)."""
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        secs = seconds % 60
        return f"{hours:02}:{minutes:02}:{secs:06.3f}".replace('.', separator)

    def _get_writable_path(self, filename: str | None, suffix: Literal['vtt', 'srt']) -> Path:
        """Return the output path, defaulting to the audio path with ``suffix``.

        Raises:
            PermissionError: if ``filename`` (or, for a file that does not
                exist yet, its parent directory) is not writable.
        """
        if filename is not None:
            target = Path(filename)
            # os.access() returns False for a file that does not exist yet,
            # so probe the parent directory in that case.
            probe = target if target.exists() else target.resolve().parent
            if not os.access(probe, os.W_OK):
                raise PermissionError(f'{filename} is not writable')
            return target
        return self.audio.resolve().with_suffix(f'.{suffix}')

    def create_vtt_file(self, filename: str | None = None) -> None:
        """Write the segments as a WebVTT subtitle file."""
        vtt_file = self._get_writable_path(filename, 'vtt')
        # explicit encoding: subtitle text is often non-ASCII and the
        # platform default encoding is not guaranteed to be UTF-8
        with vtt_file.open('w', encoding='utf-8') as f:
            f.write('WEBVTT\n\n')
            for segment in self._segments:
                # WebVTT mandates '.' as the milliseconds separator
                start_time = self._format_timestamp(segment.start, separator='.')
                end_time = self._format_timestamp(segment.end, separator='.')
                f.write(f'{start_time} --> {end_time}\n')
                f.write(f'{segment.text}\n\n')

    def create_srt_file(self, filename: str | None = None) -> None:
        """Write the segments as a SubRip (.srt) subtitle file."""
        srt_file = self._get_writable_path(filename, 'srt')
        with srt_file.open('w', encoding='utf-8') as f:
            for index, segment in enumerate(self._segments, start=1):
                start_time = self._format_timestamp(segment.start)
                end_time = self._format_timestamp(segment.end)
                f.write(f'{index}\n')
                f.write(f'{start_time} --> {end_time}\n')
                f.write(f'{segment.text}\n\n')
transcriber = Transcriber(
    model=model,
    audio=audio_file,
    batch_size=16,  # lower this if the GPU runs out of memory
    device=device,  # type: ignore
)

start = time.perf_counter()
transcriber.transcribe()
elapsed = time.perf_counter() - start
print(f'duration: {elapsed:.2f}s')

print(transcriber.text)
transcriber.create_vtt_file()
transcriber.create_srt_file()
import time
import os
from pathlib import Path

import ctranslate2
import transformers

# Workaround for duplicate OpenMP runtimes pulled in by a separate
# transformers / openai-whisper installation.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

# The model must have been converted beforehand with ct2-transformers-converter.
model_path = Path.home() / '.cache' / 'ctranslate2' / 'm2m100_1.2B'
translator = ctranslate2.Translator(model_path.as_posix())
from functools import lru_cache


@lru_cache(maxsize=1)
def _get_tokenizer():
    """Load the M2M100 tokenizer once and cache it — loading it is expensive."""
    return transformers.AutoTokenizer.from_pretrained("facebook/m2m100_1.2B")


def translate(text: str, source_lang: str, target_lang: str) -> str:
    """Translate ``text`` from ``source_lang`` to ``target_lang``.

    Args:
        text: the sentence to translate.
        source_lang: source language code (e.g. ``"hi"``).
        target_lang: target language code (e.g. ``"fr"``).

    Returns:
        The translated sentence.
    """
    # PERF FIX: the original reloaded the tokenizer from disk on every call
    tokenizer = _get_tokenizer()
    tokenizer.src_lang = source_lang
    source_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(text))
    # m2m100 requires the target language token as the decoder prefix
    target_prefix = [tokenizer.lang_code_to_token[target_lang]]
    results = translator.translate_batch([source_tokens], target_prefix=[target_prefix])
    # drop the language prefix token from the hypothesis before decoding
    target_tokens = results[0].hypotheses[0][1:]
    return tokenizer.decode(tokenizer.convert_tokens_to_ids(target_tokens))
# hindi to french
start_time = time.perf_counter()
print(translate("जीवन एक चॉकलेट बॉक्स की तरह है।", "hi", "fr"))
print(f'took {time.perf_counter() - start_time:.2f} seconds')

# chinese to english
start_time = time.perf_counter()
print(translate("生活就像一盒巧克力。", "zh", "en"))
print(f'took {time.perf_counter() - start_time:.2f} seconds')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment