
@oscartangqh
Created July 26, 2024 02:40
fishaudio/fish-speech. Continue this conversation at https://readfm.vercel.app?gist=6a6816cc86590f62b6861cc89c58173a

[github] fishaudio/fish-speech

Source

Python / 11.6K lines of code. Brand new TTS solution

URL: https://github.com/fishaudio/fish-speech

Conversation

▲ Explain the major function of the project and how it works by referencing the most important and relevant code snippets.

The project fishaudio/fish-speech is an open-source text-to-speech (TTS) system that combines a VQ-GAN with a LLAMA-style transformer to generate speech from text. It supports multiple languages, including Chinese, Japanese, and English.

Here's a breakdown of its key functionalities and how it works:

1. Encoding:

  • The project uses a VQ-GAN model to encode the input audio into a sequence of semantic tokens. This process is handled by the tools/vqgan/extract_vq.py script.
  • The VQ-GAN model is defined in fish_speech/models/vqgan/lit_module.py and utilizes a WaveNet encoder (fish_speech/models/vqgan/modules/wavenet.py) and a finite-scalar quantizer (fish_speech/models/vqgan/modules/fsq.py).
  • The encode method in fish_speech/models/vqgan/lit_module.py performs the encoding process (a usage sketch follows the snippet):
    def encode(self, audios, audio_lengths):
        audios = audios.float()

        mels = self.encode_mel_transform(audios)
        mel_lengths = audio_lengths // self.encode_mel_transform.hop_length
        mel_masks = sequence_mask(mel_lengths, mels.shape[2])
        mel_masks_float_conv = mel_masks[:, None, :].float()
        mels = mels * mel_masks_float_conv

        # Encode
        encoded_features = self.encoder(mels) * mel_masks_float_conv
        feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)

        return self.quantizer.encode(encoded_features), feature_lengths
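
For orientation, a minimal sketch of calling this method directly is shown below; in the repository, tools/vqgan/extract_vq.py wraps this step for whole datasets. The audio loading, tensor shapes, file name, and sample rate here are assumptions for illustration, not the project's exact preprocessing.

    import torch
    import torchaudio

    # `model` is assumed to be an already-loaded VQ-GAN module exposing the
    # `encode` method shown above; file name and sample rate are hypothetical.
    audio, sr = torchaudio.load("reference.wav")              # (channels, samples)
    audio = audio.mean(dim=0, keepdim=True)                   # mix down to mono
    audio = torchaudio.functional.resample(audio, sr, 44100)  # assumed model sample rate

    audios = audio                         # treated as (batch=1, samples); the module may instead expect (batch, 1, samples)
    audio_lengths = torch.tensor([audios.shape[-1]])

    with torch.no_grad():
        indices, feature_lengths = model.encode(audios, audio_lengths)
    # `indices` holds the discrete semantic tokens consumed by the LLAMA stage below.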

2. Text-to-Semantic Generation:

  • The project utilizes a LLAMA model to generate semantic tokens from the input text. This is handled by the tools/llama/generate.py script.
  • The LLAMA model is defined in fish_speech/models/text2semantic/llama.py and uses a decoder-only transformer architecture.
  • The generate function in tools/llama/generate.py performs the generation process (a hypothetical call is sketched after the snippet):
@torch.no_grad()
@torch.inference_mode()
def generate(
    *,
    model: NaiveTransformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    im_end_id: int = 4,
    decode_one_token=decode_one_token_naive,
    **sampling_kwargs,
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
    """

    # create an empty tensor of the expected final shape and fill in the current tokens
    T = prompt.size(1)

    if max_new_tokens:
        if T + max_new_tokens > model.config.max_seq_len:
            max_new_tokens = model.config.max_seq_len - T
            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")

        T_new = T + max_new_tokens
    else:
        T_new = model.config.max_seq_len
        max_new_tokens = T_new - T

    device, dtype = prompt.device, prompt.dtype
    with torch.device(device):
        model.setup_caches(
            max_batch_size=1, max_seq_len=T_new, dtype=next(model.parameters()).dtype
        )

    codebook_dim = 1 + model.config.num_codebooks
    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
    empty[:, :T] = prompt
    seq = empty
    input_pos = torch.arange(0, T, device=device)

    # Use non-accelerated version for now, to avoid compilation overhead
    prefill_decode = (
        decode_one_token_naive
        if isinstance(model, NaiveTransformer)
        else decode_one_token_ar
    )

    next_token = prefill_decode(
        model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
    )
    seq[:, T : T + 1] = next_token

    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    x = decode_n_tokens(
        model,
        next_token.view(1, codebook_dim, -1),
        input_pos,
        max_new_tokens - 1,
        im_end_id=im_end_id,
        decode_one_token=decode_one_token,
        **sampling_kwargs,
    )
    # x = torch.cat(generated_tokens, dim=1)
    seq = seq[:, : T + 1 + x.size(1)]
    seq[:, T + 1 :] = x

    return seq
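
As a rough illustration of how this function is driven, the call below uses the signature shown above; the prompt contents, prompt length, and sampling keyword names (temperature, top_p) are assumptions. In the repository, tools/llama/generate.py builds the real prompt from the text tokenizer (and optional reference-voice tokens).

    import torch

    # `model` is assumed to be a loaded NaiveTransformer checkpoint; the prompt
    # below is a placeholder with 1 text row plus `num_codebooks` semantic rows.
    device = next(model.parameters()).device
    prompt_len = 32                                       # hypothetical prompt length
    codebook_dim = 1 + model.config.num_codebooks
    prompt = torch.zeros((codebook_dim, prompt_len), dtype=torch.int, device=device)

    tokens = generate(
        model=model,
        prompt=prompt,
        max_new_tokens=1024,
        temperature=0.7,   # forwarded through **sampling_kwargs (names are assumptions)
        top_p=0.7,
    )

    # Drop the prompt and the text row; the remaining rows are the new semantic tokens.
    semantic_tokens = tokens[1:, prompt_len:]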

3. Decoding:

  • The generated semantic tokens are then decoded back into audio using the VQ-GAN decoder. This is handled by the tools/vqgan/inference.py script.
  • The VQ-GAN decoder is defined in fish_speech/models/vqgan/lit_module.py and utilizes a WaveNet decoder (fish_speech/models/vqgan/modules/wavenet.py).
  • The decode method in fish_speech/models/vqgan/lit_module.py performs the decoding process (a usage sketch follows the snippet):
    def decode(self, indices, feature_lengths, return_audios=False):
        factor = math.prod(self.quantizer.downsample_factor)
        mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
        mel_masks_float_conv = mel_masks[:, None, :].float()

        z = self.quantizer.decode(indices) * mel_masks_float_conv
        z = (
            z
            + self.quality_projection(torch.ones(z.shape[0], 1, device=z.device) * 2)[
                :, :, None
            ]
        )

        gen_mel = (
            self.decoder(
                torch.randn_like(z) * mel_masks_float_conv,
                condition=z,
            )
            * mel_masks_float_conv
        )

        if return_audios:
            return self.vocoder(gen_mel)

        return gen_mel
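
A matching sketch of the decoding call is shown below; in the repository, tools/vqgan/inference.py performs this step. The token tensor shape, device handling, and output sample rate are assumptions.

    import torch
    import torchaudio

    # `model` is the same loaded VQ-GAN module as in the encoding sketch, and
    # `semantic_tokens` is assumed to have shape (num_codebooks, length) and to
    # live on the model's device.
    indices = semantic_tokens[None, :, :].long()               # (batch=1, codebooks, length)
    feature_lengths = torch.tensor([indices.shape[-1]], device=indices.device)

    with torch.no_grad():
        fake_audios = model.decode(indices, feature_lengths, return_audios=True)

    # Output sample rate is an assumption; check the model config for the real value.
    torchaudio.save("generated.wav", fake_audios[0].float().cpu(), 44100)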

4. Fine-tuning:

  • The project allows for fine-tuning the LLAMA model on a specific dataset to improve its performance for a particular speaker or voice style. This is done using the tools/llama/build_dataset.py and fish_speech/train.py scripts.
  • The fine-tuning process involves preparing the dataset in a specific format, extracting semantic tokens from the audio, and then training the LLAMA model with LoRA (Low-Rank Adaptation) to adapt its weights to the new dataset (the LoRA mechanism is sketched below).
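
To illustrate the LoRA idea mentioned above (not the project's actual implementation), the sketch below wraps a frozen linear layer with a trainable low-rank update; the rank and scaling values are arbitrary.

    import torch
    import torch.nn as nn

    class LoRALinear(nn.Module):
        # Illustrative low-rank adapter around a frozen linear layer.
        def __init__(self, base: nn.Linear, rank: int = 8, alpha: float = 16.0):
            super().__init__()
            self.base = base
            for p in self.base.parameters():
                p.requires_grad = False            # the pretrained weight stays frozen
            self.lora_a = nn.Linear(base.in_features, rank, bias=False)
            self.lora_b = nn.Linear(rank, base.out_features, bias=False)
            nn.init.zeros_(self.lora_b.weight)     # adapter starts as a no-op
            self.scale = alpha / rank

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.base(x) + self.scale * self.lora_b(self.lora_a(x))

    # During fine-tuning only lora_a / lora_b receive gradients, so the small
    # adapter can later be merged into (or shipped separately from) the base model.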

Overall, the project combines VQ-GAN and LLAMA to achieve high-quality TTS generation. It provides a flexible framework for encoding, generating, and decoding speech, and allows for fine-tuning to customize the model for specific voice styles.
