Documentation for ASR-TTS models #6594
@@ -878,6 +878,113 @@ FastEmit Regularization is supported for the default Numba based WarpRNNT loss.

Refer to the above paper for results and recommendations of ``fastemit_lambda``.

.. _Hybrid-ASR-TTS_model__Config:

Hybrid ASR-TTS Model Configuration
----------------------------------

:ref:`Hybrid ASR-TTS model <Hybrid-ASR-TTS_model>` consists of three parts:

* ASR model (``EncDecCTCModelBPE`` or ``EncDecRNNTBPEModel``)
* TTS Mel Spectrogram Generator (currently, only the :ref:`FastPitch <FastPitch_model>` model is supported)
* Enhancer model (optional)

The config also allows specifying a :ref:`text-only dataset <Hybrid-ASR-TTS_model__Text-Only-Data>`.
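For reference, a text-only manifest is a JSON-lines file. A minimal sketch is below; only the ``text`` field is shown here, and the exact schema (which may require additional fields) is described in the :ref:`text-only dataset <Hybrid-ASR-TTS_model__Text-Only-Data>` section:

```json
{"text": "this is the first training utterance"}
{"text": "textual data needs no paired audio"}
```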
Main parts of the config:

* ASR model

  * ``asr_model_path``: path to the ASR model checkpoint (``.nemo``) file, loaded only once; the config of the ASR model is then stored in the ``asr_model`` field
  * ``asr_model_type``: needed only when training from scratch; ``rnnt_bpe`` corresponds to ``EncDecRNNTBPEModel``, ``ctc_bpe`` to ``EncDecCTCModelBPE``
  * ``asr_model_fuse_bn``: fuses BatchNorm in the pretrained ASR model, which can improve quality in the finetuning scenario

* TTS model

  * ``tts_model_path``: path to the pretrained TTS model checkpoint (``.nemo``) file, loaded only once; the config of the model is then stored in the ``tts_model`` field

* Enhancer model

  * ``enhancer_model_path``: optional path to the enhancer model, loaded only once; the config is stored in the ``enhancer_model`` field

* ``train_ds``

  * ``text_data``: properties related to text-only data

    * ``manifest_filepath``: path (or paths) to :ref:`text-only dataset <Hybrid-ASR-TTS_model__Text-Only-Data>` manifests
    * ``speakers_filepath``: path (or paths) to the text file containing speaker ids for the multi-speaker TTS model (speakers are sampled randomly during training)
    * ``min_words`` and ``max_words``: parameters to filter text-only manifests by the number of words
    * ``tokenizer_workers``: number of workers for initial tokenization (when loading the data); ``num_CPUs / num_GPUs`` is a recommended value

  * ``asr_tts_sampling_technique``, ``asr_tts_sampling_temperature``, ``asr_tts_sampling_probabilities``: sampling parameters for text-only and audio-text data (when both are specified); see the parameters of ``nemo.collections.common.data.ConcatDataset``
  * all other components are similar to conventional ASR models

* ``validation_ds`` and ``test_ds`` correspond to the underlying ASR model
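The ``min_words``/``max_words`` filtering can be sketched as follows. This is a conceptual illustration, not the NeMo implementation; the helper name and the assumption that each manifest line carries a ``text`` field are ours:

```python
import json

def filter_by_word_count(manifest_lines, min_words=1, max_words=45):
    """Keep only manifest entries whose 'text' field contains
    between min_words and max_words words (inclusive)."""
    kept = []
    for line in manifest_lines:
        entry = json.loads(line)
        num_words = len(entry["text"].split())
        if min_words <= num_words <= max_words:
            kept.append(entry)
    return kept

lines = [
    '{"text": "one two three"}',
    '{"text": "one"}',
]
# with min_words=2 only the first entry survives
print(filter_by_word_count(lines, min_words=2))
```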
.. code-block:: yaml

    model:
      sample_rate: 16000

      # asr model
      asr_model_path: ???
      asr_model: null
      asr_model_type: null  # rnnt_bpe or ctc_bpe, needed only if instantiating from config, otherwise the type is inferred automatically
      asr_model_fuse_bn: false  # only ConformerEncoder is supported now, use false for other models

      # tts model
      tts_model_path: ???
      tts_model: null

      # enhancer model
      enhancer_model_path: null
      enhancer_model: null

      train_ds:
        text_data:
          manifest_filepath: ???
          speakers_filepath: ???
          min_words: 1
          max_words: 45  # 45 is the recommended value, ~16.7 sec for LibriSpeech
          tokenizer_workers: 1
        asr_tts_sampling_technique: round-robin  # random, round-robin, temperature
        asr_tts_sampling_temperature: null
        asr_tts_sampling_probabilities: null  # [0.5,0.5] - ASR,TTS
        manifest_filepath: ???
        batch_size: 16  # you may increase batch_size if your memory allows
        # other params
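The ``asr_tts_sampling_technique`` options can be illustrated with a minimal sketch. This mirrors ``ConcatDataset``-style source selection conceptually and is not the NeMo implementation; in particular, the temperature semantics shown (size-proportional weights flattened by ``1/temperature``) are an assumption:

```python
import random

def make_source_picker(technique, num_sources, probabilities=None,
                       temperature=None, sizes=None):
    """Return a zero-argument function yielding the index of the
    dataset (e.g. 0 = audio-text, 1 = text-only) to draw from next."""
    if technique == "round-robin":
        state = {"i": -1}
        def pick():
            # cycle deterministically through the sources
            state["i"] = (state["i"] + 1) % num_sources
            return state["i"]
    elif technique == "random":
        def pick():
            # sample a source with the given fixed probabilities
            return random.choices(range(num_sources), weights=probabilities)[0]
    elif technique == "temperature":
        # higher temperature flattens the size-proportional distribution
        weights = [size ** (1.0 / temperature) for size in sizes]
        def pick():
            return random.choices(range(num_sources), weights=weights)[0]
    else:
        raise ValueError(f"unknown technique: {technique}")
    return pick

pick = make_source_picker("round-robin", num_sources=2)
print([pick() for _ in range(4)])  # -> [0, 1, 0, 1]
```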
Finetuning
~~~~~~~~~~

To finetune an existing ASR model using text-only data, use the ``<NeMo_git_root>/examples/asr/asr_with_tts/speech_to_text_bpe_with_text_finetune.py`` script with the corresponding config ``<NeMo_git_root>/examples/asr/conf/asr_tts/hybrid_asr_tts.yaml``.

Specify the paths to all required models (ASR, TTS, and Enhancer checkpoints), along with ``train_ds.text_data.manifest_filepath`` and ``train_ds.text_data.speakers_filepath``.

.. code-block:: shell

    python speech_to_text_bpe_with_text_finetune.py \
        model.asr_model_path=<path to ASR model> \
        model.tts_model_path=<path to compatible TTS model> \
        model.enhancer_model_path=<optional path to enhancer model> \
        model.asr_model_fuse_bn=<true recommended if ConformerEncoder with BatchNorm, false otherwise> \
        model.train_ds.manifest_filepath=<path to manifest with audio-text pairs or null> \
        model.train_ds.text_data.manifest_filepath=<path(s) to manifest with train text> \
        model.train_ds.text_data.speakers_filepath=<path(s) to speakers list> \
        model.train_ds.text_data.tokenizer_workers=4 \
        model.validation_ds.manifest_filepath=<path to validation manifest> \
        model.train_ds.batch_size=<batch size>

Training from Scratch
~~~~~~~~~~~~~~~~~~~~~

To train an ASR model from scratch using text-only data, use the ``<NeMo_git_root>/examples/asr/asr_with_tts/speech_to_text_bpe_with_text.py`` script with a conventional ASR model config, e.g. ``<NeMo_git_root>/examples/asr/conf/conformer/conformer_ctc_bpe.yaml`` or ``<NeMo_git_root>/examples/asr/conf/conformer/conformer_transducer_bpe.yaml``.

Specify the ASR model type, the path to the TTS model, and (optionally) the enhancer, along with the text-only data-related fields.

.. code-block:: shell

    python speech_to_text_bpe_with_text.py \
        ++asr_model_type=<rnnt_bpe or ctc_bpe> \
        ++tts_model_path=<path to compatible tts model> \
        ++enhancer_model_path=<optional path to enhancer model> \
        ++model.train_ds.text_data.manifest_filepath=<path(s) to manifests with train text> \
        ++model.train_ds.text_data.speakers_filepath=<path(s) to speakers list> \
        ++model.train_ds.text_data.min_words=1 \
        ++model.train_ds.text_data.max_words=45 \
        ++model.train_ds.text_data.tokenizer_workers=4

Fine-tuning Configurations
--------------------------

@@ -26,6 +26,17 @@ If there is a local ``.nemo`` checkpoint that you'd like to load, use the :code:

Where the model base class is the ASR model class of the original checkpoint, or the general ``ASRModel`` class.

Hybrid ASR-TTS Models Checkpoints
---------------------------------

:ref:`Hybrid ASR-TTS model <Hybrid-ASR-TTS_model>` is a transparent wrapper for the ASR model, the text-to-mel-spectrogram generator, and the optional enhancer.
The model is saved as a single ``.nemo`` checkpoint containing all of these parts.
Due to this transparency, the ASR model can be extracted separately after training/finetuning by using the ``asr_model`` attribute (NeMo submodel) with
:code:`hybrid_model.asr_model.save_to(<asr_checkpoint_path>.nemo)`, or by using the convenience wrapper
:code:`hybrid_model.save_asr_model_to(<asr_checkpoint_path>.nemo)`.


NGC Pretrained Checkpoints
--------------------------

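The transparent-wrapper idea can be sketched with plain Python. This is a conceptual mock with hypothetical class names, not the actual NeMo classes:

```python
class MockASRModel:
    """Stand-in for the inner ASR model; real NeMo models provide save_to."""
    def __init__(self):
        self.saved_to = None

    def save_to(self, path):
        # a real model would serialize itself to a .nemo archive here
        self.saved_to = path

class MockHybridASRTTSModel:
    """Transparent wrapper: the ASR model stays accessible as a submodel."""
    def __init__(self, asr_model, tts_model=None, enhancer=None):
        self.asr_model = asr_model
        self.tts_model = tts_model
        self.enhancer = enhancer

    def save_asr_model_to(self, path):
        # convenience method delegating to the submodel's own save_to
        self.asr_model.save_to(path)

hybrid = MockHybridASRTTSModel(MockASRModel())
hybrid.save_asr_model_to("asr_only.nemo")
print(hybrid.asr_model.saved_to)  # -> asr_only.nemo
```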
@@ -26,7 +26,6 @@
 import numpy as np
 import torch
 import torch.utils.data
-from nemo_text_processing.text_normalization.normalize import Normalizer
 from torch.nn.utils.rnn import pad_sequence
 from tqdm.auto import tqdm

@@ -35,6 +34,12 @@
 from nemo.core.classes import Dataset, IterableDataset
 from nemo.utils import logging

+try:
+    from nemo_text_processing.text_normalization.normalize import Normalizer
+except Exception as e:
+    logging.warning(e)
+    logging.warning("nemo_text_processing is not installed")
+
 AnyPath = Union[Path, str]

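The guarded import together with the string type annotation is a general pattern for making a dependency optional while keeping the module importable. A self-contained sketch, using a deliberately nonexistent module name to exercise the fallback path:

```python
import logging

try:
    # hypothetical optional dependency; replace with the real module
    from some_optional_pkg import Normalizer
except Exception:
    Normalizer = None
    # debug rather than warning: absence is expected, not an error
    logging.getLogger(__name__).debug("some_optional_pkg is not installed")

def normalize(text: str, normalizer: "Normalizer" = None) -> str:
    # the string annotation keeps this function definable even when
    # the dependency (and thus the Normalizer class) is missing
    if normalizer is None:
        return text
    return normalizer.normalize(text)

print(normalize("Hello"))  # -> Hello
```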
@@ -176,7 +181,7 @@ def __init__(
         asr_use_start_end_token: bool,
         tts_parser: Callable,
         tts_text_pad_id: int,
-        tts_text_normalizer: Normalizer,
+        tts_text_normalizer: "Normalizer",
         tts_text_normalizer_call_kwargs: Dict,
         min_words: int = 1,
         max_words: int = 1_000_000,

@@ -379,7 +384,7 @@ def __init__(
         asr_use_start_end_token: bool,
         tts_parser: Callable,
         tts_text_pad_id: int,
-        tts_text_normalizer: Normalizer,
+        tts_text_normalizer: "Normalizer",
         tts_text_normalizer_call_kwargs: Dict,
         min_words: int = 1,
         max_words: int = 1_000_000,

@@ -426,7 +431,7 @@ def __init__(
         asr_use_start_end_token: bool,
         tts_parser: Callable,
         tts_text_pad_id: int,
-        tts_text_normalizer: Normalizer,
+        tts_text_normalizer: "Normalizer",
         tts_text_normalizer_call_kwargs: Dict,
         min_words: int = 1,
         max_words: int = 1_000_000,