Training LLM-Based + Neural Codec TTS Models
Introduction
Training Text-to-Speech (TTS) models from scratch can be approached in multiple ways, each with its own advantages and trade-offs. The most common approaches include:
- End-to-End Neural TTS: Training models like VITS or FastSpeech 2s that directly generate mel-spectrograms or waveforms from text
- LLM-Based + Neural Codec: Leveraging pretrained language models to generate discrete audio tokens from neural codecs, then decoding them to audio
In this guide, we'll focus on the LLM-based + Neural Codec approach, which has gained significant traction due to its ability to leverage powerful pretrained language models and achieve high-quality speech synthesis. This method treats audio generation as a sequence-to-sequence problem, where the model learns to predict discrete audio tokens (from neural codecs like SNAC or EnCodec) given text input.
Why LLM + Neural Codec?
This approach benefits from:
- Transfer Learning: Leveraging knowledge from large pretrained language models
- Discrete Tokenization: Converting audio to discrete tokens makes it compatible with language model architectures
- Scalability: Can handle long audio sequences efficiently
- Flexibility: Easy to extend to multi-speaker or multilingual scenarios
Overview of the Approach
The LLM-based TTS training pipeline consists of several key steps:
- Audio Tokenization: Convert audio waveforms into discrete tokens using a neural codec (e.g., SNAC, EnCodec)
- Vocabulary Extension: Add audio tokens to the language model's vocabulary
- Data Preparation: Format training data as text-audio token pairs
- Model Training: Fine-tune the language model to predict audio tokens given text
- Inference: Generate audio tokens from text, then decode them back to audio using the neural codec
```mermaid
flowchart TD
    subgraph Data Preparation
        direction LR
        A[Text Input] --> B[Audio Recording]
        B --> C[Neural Codec]
    end
    subgraph Model Training
        direction RL
        D[Extend LLM Vocabulary] --> E[Model Training]
        E --> F[Inference]
    end
    C -- "Discrete Audio Tokens" --> D
    F -- "Decode" --> H[Neural Codec Decoder]
    H -- "Audio" --> I[Audio Output]
```
Main Concepts
Let's understand the key concepts of this approach in detail. We will start with a dataset where we have (audio, text) pairs.
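Below is a minimal sketch of this step using the datasets library (the audio and text column names are assumptions about the dataset schema):

```python
from datasets import load_dataset

# Load the (audio, text) dataset from the Hugging Face Hub
dataset = load_dataset("MrDragonFox/Elise", split="train")

# Inspect the number of samples, the available columns, and one example
print(dataset)
print(dataset.column_names)

sample = dataset[0]
print(sample["text"])
print(sample["audio"]["sampling_rate"])
```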
As you can see, the MrDragonFox/Elise dataset has 1195 samples of (audio, text) pairs, along with metadata like sampling_rate. Now let's pick one random sample and play the audio.
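A small sketch for sampling and playing one clip in a notebook, using IPython.display (variable names are illustrative):

```python
import random

import IPython.display as ipd

# Pick one random sample from the dataset
sample = dataset[random.randint(0, len(dataset) - 1)]

audio_array = sample["audio"]["array"]            # raw waveform as a float array
sampling_rate = sample["audio"]["sampling_rate"]  # e.g. 22050 Hz
print(sample["text"], audio_array.shape, sampling_rate)

# Play the audio inline
ipd.Audio(audio_array, rate=sampling_rate)
```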
Here is the audio (182133 samples at a 22050 Hz sampling rate):
Now let's convert the audio to SNAC codes.
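A sketch of the encoding step with the snac package, assuming the 24 kHz checkpoint hubertsiuzdak/snac_24khz (the waveform is resampled to the codec's rate before encoding):

```python
import torch
import torchaudio
from snac import SNAC

# Load the pretrained SNAC codec
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()

# SNAC expects a (batch, channels, time) float tensor at the codec's sampling rate,
# so resample the 22050 Hz waveform to 24 kHz first
waveform = torch.tensor(audio_array, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
waveform = torchaudio.functional.resample(waveform, orig_freq=sampling_rate, new_freq=24000)

with torch.inference_mode():
    codes = snac_model.encode(waveform)

# `codes` is a list of 3 tensors, one per quantization layer
for i, layer_codes in enumerate(codes, start=1):
    print(f"layer {i}: {layer_codes.shape}")
```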
This returns a list of 3 tensors containing the SNAC codes for the 3 layers. This means we have transformed the audio from ~180k samples to just 97 + 194 + 388 = 679 tokens, a significant reduction in the size of the data! To double-check, we can decode the codes back into audio and listen to confirm it matches the original.
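And a short sketch of the round-trip check:

```python
with torch.inference_mode():
    audio_hat = snac_model.decode(codes)  # (batch, 1, time) waveform at 24 kHz

# Listen and compare with the original recording
ipd.Audio(audio_hat.squeeze().numpy(), rate=24000)
```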
Perfect!
Now, as we can see, we already have a model that can convert any audio into a significantly smaller number of tokens (Audio --> Code) and then convert the codes back to audio (Code --> Audio). But TTS requires converting text to audio (Text --> Audio). This is where we can leverage the power of pretrained language models: the LLM converts text to audio tokens (Text --> Code), and the neural codec takes us from Code --> Audio. This is the approach we will be taking in this guide.
Code Walkthrough
Let's walk through a complete implementation for training a TTS model using Gemma-3 and the SNAC codec. We'll break down the code into logical sections and explain each component. I have created a complete implementation of this approach that you can use as a reference.
Setup and Imports
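The imports might look roughly like this (a sketch; the exact list depends on the reference implementation):

```python
import logging
import re
from dataclasses import dataclass, field
from typing import List

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
)
from trl import SFTConfig, SFTTrainer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
```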
The code imports necessary libraries. The SFTTrainer from the trl library is used for supervised fine-tuning, which is well-suited for causal language modeling tasks.
Configuration Classes
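A sketch of these dataclasses (default values such as the validation split, and the google/gemma-3-270m hub id, are assumptions):

```python
@dataclass
class ModelArguments:
    """Which pretrained base model to fine-tune."""

    model_name_or_path: str = field(
        default="google/gemma-3-270m",
        metadata={"help": "Pretrained language model to fine-tune."},
    )


@dataclass
class DataArguments:
    """Dataset and preprocessing parameters."""

    dataset_name: str = field(
        default="mohitmayank/elise_text_snac_codes",
        metadata={"help": "Dataset with (text, SNAC code) pairs."},
    )
    max_seq_length: int = field(
        default=900,
        metadata={"help": "Maximum number of tokens per training sequence."},
    )
    validation_split_percentage: int = field(
        default=5,
        metadata={"help": "Percentage of data held out for validation."},
    )
```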
These dataclasses define configuration arguments for the model and data. The ModelArguments specifies which pretrained model to use (Gemma-3-270m), while DataArguments defines dataset parameters including the maximum sequence length (here, 900 tokens) and validation split percentage.
SNAC Codec Configuration
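A reduced sketch of the codec configuration (only the fields used later in this walkthrough are shown):

```python
@dataclass
class SNACConfig:
    """Vocabulary layout for the SNAC audio tokens."""

    num_layers: int = 3            # SNAC quantization layers
    codes_per_layer: int = 4096    # each layer has code values 0-4095
    audio_start_token: str = "<audio_start>"
    audio_end_token: str = "<audio_end>"
    # Extra special tokens discovered in the dataset (e.g. <giggling>) are appended here
    extra_special_tokens: List[str] = field(default_factory=list)

    @property
    def total_snac_tokens(self) -> int:
        return self.num_layers * self.codes_per_layer  # 3 * 4096 = 12,288
```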
SNAC uses a multi-scale quantization approach with 3 layers. Each layer has 4096 possible code values (0-4095), resulting in 12,288 total SNAC tokens. The configuration also defines special tokens to mark the start/end of audio sequences and layer separators.
SNAC Multi-Scale Structure
SNAC quantizes audio at different temporal resolutions:
- Layer 1: Coarse temporal structure (lower frame rate)
- Layer 2: Medium temporal structure
- Layer 3: Fine acoustic details (higher frame rate)
This hierarchical approach captures both long-term patterns (prosody, semantics) and short-term details (phonetic features).
Dataset Loading
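A sketch of such a loading function (the text column name and the regex used to discover extra tags are assumptions):

```python
def load_tts_dataset(data_args: DataArguments, snac_config: SNACConfig):
    """Load the (text, SNAC code) dataset and discover extra special tokens."""
    dataset = load_dataset(data_args.dataset_name, split="train")

    # Discover non-SNAC special tokens such as <giggling>, <laughter>, <sigh>
    extra_tokens = set()
    for example in dataset:
        extra_tokens.update(re.findall(r"<[a-z_]+>", example["text"]))
    snac_config.extra_special_tokens = sorted(extra_tokens)

    # Hold out a small validation split
    split = dataset.train_test_split(test_size=data_args.validation_split_percentage / 100)
    return split["train"], split["test"]
```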
This function performs two things:
- Loads the dataset from the Hugging Face Hub. While you can start from a dataset with (audio, text) pairs like MrDragonFox/Elise, I have already transformed the dataset into (text, SNAC code) pairs that you can use here: mohitmayank/elise_text_snac_codes. We will download and use this dataset for training.
- Dynamically discovers any additional special tokens present in the dataset text and adds them to the configuration. This is important because the dataset may contain special tokens that are not part of the SNAC configuration, such as <giggling>, <laughter>, <sigh>, etc.
SNAC Code to Token Conversion
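A sketch of the conversion, assuming each layer's codes are available as flat lists of integers:

```python
def snac_codes_to_tokens(codes, snac_config: SNACConfig) -> str:
    """Flatten the 3 layers of SNAC codes into a single frame-ordered token string."""
    layer1, layer2, layer3 = codes  # flat lists of ints, one list per layer

    tokens = [snac_config.audio_start_token]
    # One time frame = 1 code from layer 1, 2 from layer 2, 4 from layer 3
    for i in range(len(layer1)):
        tokens.append(f"<snac_l1_{layer1[i]}>")
        tokens.append(f"<snac_l2_{layer2[2 * i]}>")
        tokens.append(f"<snac_l2_{layer2[2 * i + 1]}>")
        for j in range(4):
            tokens.append(f"<snac_l3_{layer3[4 * i + j]}>")
    tokens.append(snac_config.audio_end_token)

    return " ".join(tokens)
```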
This function converts SNAC codes (which are integers) into token strings that can be added to the language model vocabulary. In a way, we are transforming the audio tokens (which were 3 lists of tensors) into a single string of tokens. This is done by assuming the SNAC codes follow a time-frame structure, where each frame consists of 7 elements. Each time frame contains codes from all 3 layers in the following manner:
- First element of the frame is code from layer 1
- Second and third elements of the frame are codes from layer 2
- Fourth to seventh elements of the frame are codes from layer 3
One sample time frame will look like this:
<snac_l1_123> <snac_l2_456> <snac_l2_478> <snac_l3_789> <snac_l3_1123> <snac_l3_100> <snac_l3_54>
We do this for the complete audio sequence until we exhaust all the time frames (or the SNAC codes), and then encapsulate the audio within the special tokens <audio_start> and <audio_end>, so that the complete sequence looks like this:
<audio_start> <snac_l1_123> ...<snac_l3_54> <audio_end>
Why follow this complex format?
This format is chosen because it is easy to parse and decode back to audio, and it is easy to train a language model on. As a bonus, this format can help with real-time streaming applications in the future, since the generation output arrives frame by frame.
Data Preprocessing
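A sketch of the preprocessing step (the snac_codes column name is an assumption; labels are left to the data collator, which copies input_ids when mlm=False):

```python
def preprocess_function(examples, tokenizer, snac_config: SNACConfig, max_seq_length: int):
    """Build '{text} {snac_tokens}' sequences and tokenize them for causal LM training."""
    input_texts = [
        f"{text} {snac_codes_to_tokens(codes, snac_config)}"
        for text, codes in zip(examples["text"], examples["snac_codes"])
    ]

    # Labels for causal LM are the input_ids themselves; here the mlm=False data
    # collator creates them at batch time, so only the tokenized inputs are returned
    return tokenizer(
        input_texts,
        max_length=max_seq_length,
        truncation=True,
        padding=False,
    )
```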
The preprocessing function combines text and SNAC tokens into a single sequence. The format is: "{text} {snac_tokens}", where the model learns to predict the SNAC tokens given the text. For causal language modeling, labels are set equal to input_ids, meaning the model learns to predict the next token in the sequence (including both text and audio tokens). One example of the output of the preprocessing function will look like this:
Please have mercy on my dainty, frail body. Your coils are so strong and powerful, and I am powerless to resist <audio_start> <snac_l1_123> ... <snac_l3_54> <audio_end>
This is the format that the model will learn to predict.
Training Format
The model is trained to predict the entire sequence autoregressively. During inference, you provide only the text, and the model generates the SNAC tokens that follow, which can then be decoded back to audio.
Model and Tokenizer Setup
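A sketch of the vocabulary-extension step using standard transformers APIs:

```python
def setup_model_and_tokenizer(model_args: ModelArguments, snac_config: SNACConfig):
    """Load the base LLM and extend its vocabulary with SNAC tokens."""
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        torch_dtype=torch.bfloat16,
    )

    # 1. Special tokens: audio boundaries plus dataset-specific tags like <giggling>
    special_tokens = [
        snac_config.audio_start_token,
        snac_config.audio_end_token,
        *snac_config.extra_special_tokens,
    ]
    tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})

    # 2. SNAC code tokens: <snac_l1_0> ... <snac_l3_4095> (3 * 4096 = 12,288 tokens)
    snac_tokens = [
        f"<snac_l{layer}_{code}>"
        for layer in range(1, snac_config.num_layers + 1)
        for code in range(snac_config.codes_per_layer)
    ]
    tokenizer.add_tokens(snac_tokens)

    # 3. Resize the embedding matrix so the new tokens get trainable embeddings
    model.resize_token_embeddings(len(tokenizer))
    return model, tokenizer
```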
This is a critical function that extends the language model vocabulary to include SNAC tokens:
- Load Base Model: Loads the pretrained Gemma-3-270m model and tokenizer
- Add Special Tokens: Adds SNAC special tokens (<audio_start>, <audio_end>, etc.)
- Add SNAC Code Tokens: Creates tokens for all possible SNAC codes (e.g., <snac_l1_0> through <snac_l1_4095> for each layer)
- Resize Embeddings: Expands the model's embedding layer to accommodate the new tokens
Embedding Resizing
When new tokens are added, the model's embedding matrix must be resized. The new token embeddings are typically initialized randomly or copied from similar existing tokens. This is crucial for the model to learn meaningful representations of the audio tokens.
Training Configuration
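A sketch of the configuration using trl's SFTConfig (epoch count and logging/saving cadence are illustrative):

```python
training_args = SFTConfig(
    output_dir="./gemma3-snac-tts",
    num_train_epochs=10,                # illustrative; tune for your dataset
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,      # effective batch size of 8
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,                   # 10% warmup
    bf16=True,                          # or fp16=True on GPUs without bfloat16 support
    logging_steps=10,
    save_steps=500,
    save_total_limit=2,
    report_to="none",
)
```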
The training arguments configure the fine-tuning process:
- Batch Size: per_device_train_batch_size=2 with gradient_accumulation_steps=4 gives an effective batch size of 8
- Learning Rate: 2e-4 is a common starting point for fine-tuning language models
- Scheduler: Cosine learning rate schedule with 10% warmup
- Mixed Precision: Can use bf16 or fp16 for faster training and lower memory usage
Experimentation
You are invited to experiment with different learning rates, batch sizes, gradient accumulation steps, number of training epochs, etc. to get the best performance.
Training Execution
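A sketch of the training entry point that ties the pieces together (assuming a recent trl version; older releases name the processing_class argument tokenizer):

```python
def main():
    model_args, data_args = ModelArguments(), DataArguments()
    snac_config = SNACConfig()

    # 1. Load the (text, SNAC code) pairs and discover extra special tokens
    train_dataset, eval_dataset = load_tts_dataset(data_args, snac_config)

    # 2. Extend the vocabulary and resize the embeddings
    model, tokenizer = setup_model_and_tokenizer(model_args, snac_config)

    # 3. Convert SNAC codes to tokens and build the training sequences
    #    (eval_dataset can be mapped the same way and passed to the trainer)
    train_dataset = train_dataset.map(
        lambda batch: preprocess_function(batch, tokenizer, snac_config, data_args.max_seq_length),
        batched=True,
        remove_columns=train_dataset.column_names,
    )

    # 4. Collator for causal language modeling (mlm=False -> next-token prediction)
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # 5. Fine-tune with SFTTrainer
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
        processing_class=tokenizer,
    )
    trainer.train()
    trainer.save_model()


if __name__ == "__main__":
    main()
```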
The training pipeline:
- Load Data: Loads text-SNAC code pairs from the dataset
- Setup Model: Extends the vocabulary and resizes embeddings
- Preprocess: Converts SNAC codes to tokens and formats training sequences
- Data Collator: Handles batching and padding for causal language modeling
- Train: Uses SFTTrainer to fine-tune the model
The DataCollatorForLanguageModeling with mlm=False is configured for causal language modeling, where the model predicts the next token in the sequence.
And we are done! Once you run the training, you should see a training log like the one shown below:
But Mohit, it's overfitting!
Yes, sir! This is expected because we are using a very small dataset, and training for multiple epochs. This tutorial is meant to give you a starting point and a reference implementation.
Head over to the conclusion section for next steps and my comments. We are just getting started.
Inference
After training, inference involves:
- Text Input: Provide text to the model
- Token Generation: Model generates SNAC tokens autoregressively
- Token Parsing: Extract SNAC codes from generated tokens
- Audio Decoding: Use SNAC decoder to convert codes back to audio waveform
Here's a complete inference implementation:
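Here is a sketch of such an inference script (the checkpoint path, decoding hyperparameters, and helper names are illustrative; the token-ID threshold and frame layout follow the parsing logic described next):

```python
import soundfile as sf
import torch
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer

CHECKPOINT = "./gemma3-snac-tts"  # path to the fine-tuned model
SNAC_TOKEN_START_ID = 262173      # first SNAC token id in the extended vocabulary
CODES_PER_LAYER = 4096


def parse_snac_codes(token_ids):
    """Convert generated token ids back into the 3-layer SNAC code lists."""
    # Keep only SNAC tokens and shift them to ids relative to the SNAC vocabulary
    snac_ids = [tid - SNAC_TOKEN_START_ID for tid in token_ids if tid >= SNAC_TOKEN_START_ID]

    layer1, layer2, layer3 = [], [], []
    # Process complete frames of 7 codes: 1 from layer 1, 2 from layer 2, 4 from layer 3
    for i in range(0, len(snac_ids) - len(snac_ids) % 7, 7):
        frame = [tid % CODES_PER_LAYER for tid in snac_ids[i : i + 7]]
        layer1.append(frame[0])
        layer2.extend(frame[1:3])
        layer3.extend(frame[3:7])

    # The SNAC decoder expects a list of (batch, time) long tensors
    return [
        torch.tensor(layer1, dtype=torch.long).unsqueeze(0),
        torch.tensor(layer2, dtype=torch.long).unsqueeze(0),
        torch.tensor(layer3, dtype=torch.long).unsqueeze(0),
    ]


def text_to_speech(text, output_path="output.wav"):
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
    model = AutoModelForCausalLM.from_pretrained(CHECKPOINT).eval()
    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()

    # 1. Generate SNAC tokens autoregressively from the text prompt
    inputs = tokenizer(text, return_tensors="pt")
    with torch.inference_mode():
        generated = model.generate(
            **inputs,
            max_new_tokens=800,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
        )

    # 2. Parse only the newly generated ids into SNAC codes
    new_token_ids = generated[0][inputs["input_ids"].shape[1]:].tolist()
    codes = parse_snac_codes(new_token_ids)

    # 3. Decode the codes back into a waveform with the SNAC decoder
    with torch.inference_mode():
        audio = snac_model.decode(codes).squeeze().cpu().numpy()

    sf.write(output_path, audio, 24000)
    return audio


if __name__ == "__main__":
    text_to_speech("Hello there, this is a test of the fine-tuned TTS model.")
```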
Understanding the Parsing Function
The parse_snac_codes function performs several key operations:
- Token Filtering: It filters tokens with IDs >= 262173, which corresponds to the start of the SNAC token vocabulary in the extended model.
- Frame Processing: The function processes tokens in batches of 7, matching the SNAC frame structure:
  - 1 token from layer 1 (coarse structure)
  - 2 tokens from layer 2 (medium detail)
  - 4 tokens from layer 3 (fine detail)
- Code Extraction: For each token, it:
  - Subtracts 262173 to get the relative token ID within the SNAC vocabulary
  - Takes modulo 4096 to extract the actual code value (since each layer has 4096 possible codes: 0-4095)
- Tensor Conversion: Converts the extracted codes into PyTorch tensors with a batch dimension, matching the format expected by the SNAC decoder.
Token ID Threshold
The threshold value 262173 is specific to this model configuration. It represents the starting token ID of SNAC tokens in the extended vocabulary. If you're using a different base model or have a different vocabulary size, you'll need to adjust this threshold accordingly. You can verify this by checking the tokenizer's vocabulary size before and after adding SNAC tokens.
Best Practices
- Data Quality: Ensure high-quality text-audio pairs with accurate transcriptions
- Sequence Length: Balance max_seq_length between model capacity and training efficiency
- Learning Rate: Start with lower learning rates (1e-4 to 5e-4) for fine-tuning
- Gradient Accumulation: Use gradient accumulation to simulate larger batch sizes
- Checkpointing: Save checkpoints regularly to resume training if interrupted
- Evaluation: Monitor both training loss and audio quality metrics (e.g., MOS, SECS)
Conclusion
Training TTS models using LLM-based approaches with neural codecs represents a powerful paradigm that combines the strengths of modern language models with efficient audio representations. By treating audio generation as a sequence-to-sequence problem, we can leverage transfer learning and achieve high-quality speech synthesis. The key is understanding how to convert audio to discrete tokens, extend model vocabularies, and train the model to learn the text-to-audio mapping.
As stated before, this tutorial is meant to give you a starting point and a reference implementation. We are just getting started with the journey of training TTS models. I think there are multiple ways to improve the performance of the model. Some ways are:
- Use a larger dataset (currently the model is trained on ~1000 samples; we need at least 10x that size to get respectable results)
- Data Preprocessing (we can preprocess the data to handle the special tokens and make sure it is in the correct format; we can further break the audio into smaller, cleaner chunks, denoise it, etc.)
- Data Diversification (we need to diversify the data to improve the performance of the model; we can use data augmentation techniques or include audio from multiple speakers speaking at different speeds, pitches, languages, accents, etc.)
- Use Reinforcement Learning to fine-tune the model (we can add explicit rewards like WER, SECS, etc to the training loop to improve the quality of the generated speech)
- and more...
Maybe more on this in future guides!
Do let me know if you have any questions or suggestions. If you want to contribute to this guide, please feel free to submit a pull request. If you want to discuss something, please feel free to reach out to me on LinkedIn.