Python / 11.6K lines of code. Brand new TTS solution
URL: https://github.com/fishaudio/fish-speech
Explain the major function of the project and how it works by referencing the most important and relevant code snippets.
The project `fishaudio/fish-speech` is an open-source text-to-speech (TTS) model that leverages VQ-GAN and LLAMA to generate speech from text. It supports multiple languages, including Chinese, Japanese, and English.
Here's a breakdown of its key functionalities and how it works:
1. Encoding:

   - The project uses a VQ-GAN model to encode the input audio into a sequence of semantic tokens. This process is handled by the `tools/vqgan/extract_vq.py` script.
   - The VQ-GAN model is defined in `fish_speech/models/vqgan/lit_module.py` and utilizes a WaveNet encoder (`fish_speech/models/vqgan/modules/wavenet.py`) and a quantizer (`fish_speech/models/vqgan/modules/fsq.py`).
   - The `encode` method in `fish_speech/models/vqgan/lit_module.py` performs the encoding process:
```python
def encode(self, audios, audio_lengths):
    audios = audios.float()

    # Convert the waveforms to mel spectrograms and mask out padding
    mels = self.encode_mel_transform(audios)
    mel_lengths = audio_lengths // self.encode_mel_transform.hop_length
    mel_masks = sequence_mask(mel_lengths, mels.shape[2])
    mel_masks_float_conv = mel_masks[:, None, :].float()
    mels = mels * mel_masks_float_conv

    # Encode the mel features and quantize them into discrete semantic tokens
    encoded_features = self.encoder(mels) * mel_masks_float_conv
    feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)

    return self.quantizer.encode(encoded_features), feature_lengths
```
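To make the data flow concrete, here is a minimal usage sketch. The model loading is elided, and the tensor shapes are assumptions inferred from the method body above rather than a documented API:

```python
import torch

# Hypothetical usage sketch -- `vqgan` is assumed to be a loaded VQ-GAN
# module exposing the `encode` method shown above.
audios = torch.randn(2, 1, 96000)             # (batch, channel, samples); layout assumed
audio_lengths = torch.tensor([96000, 72000])  # true sample counts, used to build masks

with torch.no_grad():
    indices, feature_lengths = vqgan.encode(audios, audio_lengths)

# `indices` holds the discrete semantic tokens consumed by the LLAMA stage;
# `feature_lengths` records how many valid tokens each clip produced.
```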
2. Text-to-Semantic Generation:

   - The project utilizes a LLAMA model to generate semantic tokens from the input text. This is handled by the `tools/llama/generate.py` script.
   - The LLAMA model is defined in `fish_speech/models/text2semantic/llama.py` and uses a transformer architecture.
   - The `generate` function in `tools/llama/generate.py` performs the generation process:
```python
@torch.no_grad()
@torch.inference_mode()
def generate(
    *,
    model: NaiveTransformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    im_end_id: int = 4,
    decode_one_token=decode_one_token_naive,
    **sampling_kwargs,
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate
    as many tokens as requested.
    """
    T = prompt.size(1)

    # Clamp max_new_tokens so the sequence fits in the model's context window
    if max_new_tokens:
        if T + max_new_tokens > model.config.max_seq_len:
            max_new_tokens = model.config.max_seq_len - T
            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")

        T_new = T + max_new_tokens
    else:
        T_new = model.config.max_seq_len
        max_new_tokens = T_new - T

    device, dtype = prompt.device, prompt.dtype
    with torch.device(device):
        model.setup_caches(
            max_batch_size=1, max_seq_len=T_new, dtype=next(model.parameters()).dtype
        )

    codebook_dim = 1 + model.config.num_codebooks

    # Create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
    empty[:, :T] = prompt
    seq = empty
    input_pos = torch.arange(0, T, device=device)

    # Use non-accelerated version for prefill, to avoid compilation overhead
    prefill_decode = (
        decode_one_token_naive
        if isinstance(model, NaiveTransformer)
        else decode_one_token_ar
    )

    next_token = prefill_decode(
        model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
    )
    seq[:, T : T + 1] = next_token

    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    x = decode_n_tokens(
        model,
        next_token.view(1, codebook_dim, -1),
        input_pos,
        max_new_tokens - 1,
        im_end_id=im_end_id,
        decode_one_token=decode_one_token,
        **sampling_kwargs,
    )

    seq = seq[:, : T + 1 + x.size(1)]
    seq[:, T + 1 :] = x

    return seq
```
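Here is a hedged sketch of how `generate` might be driven, based only on the signature above. The prompt construction and checkpoint loading are elided, and the sampling parameter names are assumptions forwarded via `**sampling_kwargs`:

```python
import torch

# Hypothetical driver -- `model` is assumed to be a loaded NaiveTransformer, and
# `prompt` a (1 + num_codebooks, T) tensor that interleaves text tokens with the
# VQ-GAN semantic tokens of a reference voice.
seq = generate(
    model=model,
    prompt=prompt,
    max_new_tokens=1024,
    im_end_id=4,          # end-of-message token id (default from the signature)
    temperature=0.7,      # assumed sampling kwargs, passed through **sampling_kwargs
    top_p=0.7,
)

# The first row is the text/control stream; the remaining num_codebooks rows
# are the newly generated semantic codes for the VQ-GAN decoder.
codes = seq[1:, prompt.size(1):]
```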
3. Decoding:

   - The generated semantic tokens are then decoded back into audio using the VQ-GAN decoder. This is handled by the `tools/vqgan/inference.py` script.
   - The VQ-GAN decoder is defined in `fish_speech/models/vqgan/lit_module.py` and utilizes a WaveNet decoder (`fish_speech/models/vqgan/modules/wavenet.py`).
   - The `decode` method in `fish_speech/models/vqgan/lit_module.py` performs the decoding process:
```python
def decode(self, indices, feature_lengths, return_audios=False):
    factor = math.prod(self.quantizer.downsample_factor)
    mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
    mel_masks_float_conv = mel_masks[:, None, :].float()

    # Look up the quantized feature vectors for the semantic token indices
    z = self.quantizer.decode(indices) * mel_masks_float_conv

    # Add a quality-conditioning offset (driven here by a constant value of 2)
    z = (
        z
        + self.quality_projection(torch.ones(z.shape[0], 1, device=z.device) * 2)[
            :, :, None
        ]
    )

    # Generate the mel spectrogram from noise, conditioned on the quantized features
    gen_mel = (
        self.decoder(
            torch.randn_like(z) * mel_masks_float_conv,
            condition=z,
        )
        * mel_masks_float_conv
    )

    if return_audios:
        return self.vocoder(gen_mel)

    return gen_mel
```
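Continuing the earlier sketch, the generated codes can be turned back into a waveform. The shapes are assumptions inferred from the indexing in `decode` (`indices.shape[2]` implies a `(batch, num_codebooks, frames)` layout):

```python
import torch

# Hypothetical continuation: `vqgan` is the same module as in the encoding sketch,
# and `codes` the (num_codebooks, T) semantic tokens from the LLAMA stage.
indices = codes[None]  # add a batch dimension -> (1, num_codebooks, T)
feature_lengths = torch.tensor([indices.shape[-1]], device=indices.device)

with torch.no_grad():
    fake_audios = vqgan.decode(indices, feature_lengths, return_audios=True)

# `fake_audios` is the synthesized waveform emitted by the vocoder.
```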
4. Fine-tuning:

   - The project allows for fine-tuning the LLAMA model on a specific dataset to improve its performance for a particular speaker or voice style. This is done using the `tools/llama/build_dataset.py` and `fish_speech/train.py` scripts.
   - The fine-tuning process involves preparing the dataset in a specific format, extracting semantic tokens from the audio, and then training the LLAMA model with LoRA (Low-Rank Adaptation) to adapt its weights to the new dataset; a minimal sketch of the LoRA idea is shown after this list.
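The LoRA idea itself is easy to sketch. The snippet below is a generic, illustrative PyTorch implementation of a low-rank adapter around a frozen linear layer; it is not the project's actual fine-tuning code:

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Illustrative LoRA wrapper: y = W x + (alpha / r) * B A x."""

    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16):
        super().__init__()
        self.base = base
        # Freeze the pretrained weights; only the adapters below are trained
        self.base.weight.requires_grad_(False)
        if self.base.bias is not None:
            self.base.bias.requires_grad_(False)

        # Low-rank factors A (in -> r) and B (r -> out)
        self.lora_a = nn.Linear(base.in_features, r, bias=False)
        self.lora_b = nn.Linear(r, base.out_features, bias=False)
        nn.init.normal_(self.lora_a.weight, std=0.01)
        nn.init.zeros_(self.lora_b.weight)  # adapter starts as a no-op
        self.scale = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scale * self.lora_b(self.lora_a(x))

# Example: adapt a frozen 512 -> 512 projection with rank-8 updates
layer = LoRALinear(nn.Linear(512, 512))
out = layer(torch.randn(4, 512))
```

Because the base weights stay frozen, only the small `lora_a`/`lora_b` matrices need to be stored per voice, which keeps speaker adaptation cheap.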
Overall, the project combines VQ-GAN and LLAMA to achieve high-quality TTS generation. It provides a flexible framework for encoding, generating, and decoding speech, and allows for fine-tuning to customize the model for specific voice styles.
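As a final orientation, the full inference pipeline can be summarized in a few lines of pseudocode. Every name here (`build_prompt`, the `vqgan`/`llama` handles) is a placeholder for illustration, not the project's real API:

```python
import torch

def synthesize(text: str, ref_audios, ref_lengths, vqgan, llama):
    # 1. Encoding: compress reference audio into semantic tokens (the voice prompt)
    prompt_codes, _ = vqgan.encode(ref_audios, ref_lengths)

    # 2. Text-to-semantic: LLAMA continues text + voice prompt into new codes
    # (build_prompt is a hypothetical helper that interleaves text and codes)
    seq = generate(model=llama, prompt=build_prompt(text, prompt_codes),
                   max_new_tokens=1024)
    codes = seq[1:]  # keep only the codebook rows

    # 3. Decoding: VQ-GAN turns the codes back into audio
    lengths = torch.tensor([codes.shape[-1]])
    return vqgan.decode(codes[None], lengths, return_audios=True)
```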