Model API#
The Model API provides a set of functions that make your model compatible with the Trainer, Synthesizer, and ModelZoo.
Base TTS Model#
- class TTS.model.BaseTrainerModel(*args, **kwargs)[source]#
BaseTrainerModel expands TrainerModel with the functions required by 🐸TTS.
Every new 🐸TTS model must inherit it.
- abstract inference(input, aux_input={})[source]#
Forward pass for inference.
It must return a dictionary with the main model output and all auxiliary outputs. The key
`model_outputs`
is considered the main output, and you can add any other auxiliary outputs as you want. We don't use `*kwargs` since it is problematic with the TorchScript API.
- Parameters:
input (torch.Tensor) – [description]
aux_input (Dict) – Auxiliary inputs like speaker embeddings, durations etc.
- Returns:
[description]
- Return type:
Dict
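As a hedged sketch of this contract (the class `MyTTSModel` and its toy forward pass below are illustrative, not part of the TTS API; a real model operates on torch.Tensor inputs), an `inference` override could look like:

```python
# Illustrative sketch only: `MyTTSModel` is hypothetical, and plain lists
# stand in for torch tensors.
class MyTTSModel:
    def forward(self, x):
        # Stand-in for the real forward pass.
        return [v * 2 for v in x], [1.0] * len(x)

    def inference(self, input, aux_input={}):
        model_outputs, durations = self.forward(input)
        # "model_outputs" is the required key; any extra keys are free-form
        # auxiliary outputs (e.g. durations, alignments).
        return {"model_outputs": model_outputs, "durations": durations}

outputs = MyTTSModel().inference([1, 2, 3])
```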
- abstract static init_from_config(config)[source]#
Init the model and all its attributes from the given config.
Override this depending on your model.
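A minimal sketch of an override, assuming a hypothetical `MyTTSModel` and a plain class standing in for the Coqpit config:

```python
# Hypothetical sketch: `MyConfig` and `MyTTSModel` are illustrative; a real
# implementation receives a Coqpit configuration object.
class MyConfig:
    hidden_channels = 256

class MyTTSModel:
    def __init__(self, hidden_channels):
        self.hidden_channels = hidden_channels

    @staticmethod
    def init_from_config(config):
        # Build the model (and any attributes such as a tokenizer or
        # speaker manager) purely from the config object.
        return MyTTSModel(hidden_channels=config.hidden_channels)

model = MyTTSModel.init_from_config(MyConfig())
```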
- abstract load_checkpoint(config, checkpoint_path, eval=False, strict=True, cache=False)[source]#
Load a model checkpoint file and get ready for training or inference.
- Parameters:
config (Coqpit) – Model configuration.
checkpoint_path (str) – Path to the model checkpoint file.
eval (bool, optional) – If true, init model for inference else for training. Defaults to False.
strict (bool, optional) – Match all checkpoint keys to model’s keys. Defaults to True.
cache (bool, optional) – If True, cache the file locally for subsequent calls. It is cached under get_user_data_dir()/tts_cache. Defaults to False.
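A hedged sketch of what an implementation typically does, with JSON standing in for torch.load and a simplified version of the `strict` key check (the class and its state layout are illustrative, not the TTS API):

```python
import json
import tempfile

# Illustrative sketch: JSON replaces torch.load, and `self.state` replaces a
# real state_dict.
class MyTTSModel:
    def __init__(self):
        self.state = {"w": 0.0}
        self.training = True

    def load_checkpoint(self, config, checkpoint_path, eval=False, strict=True, cache=False):
        with open(checkpoint_path) as f:
            checkpoint = json.load(f)
        if strict and set(checkpoint) != set(self.state):
            raise KeyError("checkpoint keys do not match the model's keys")
        self.state.update(checkpoint)
        if eval:
            # Disable training-only behaviour (dropout etc. in a real model).
            self.training = False

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump({"w": 1.5}, f)
    path = f.name

model = MyTTSModel()
model.load_checkpoint(config=None, checkpoint_path=path, eval=True)
```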
Base tts Model#
- class TTS.tts.models.base_tts.BaseTTS(config, ap, tokenizer, speaker_manager=None, language_manager=None)[source]#
Base tts class. Every new tts model must inherit this.
It defines common tts specific functions on top of Model implementation.
- format_batch(batch)[source]#
Generic batch formatting for TTSDataset.
You must override this if you use a custom dataset.
- Parameters:
batch (Dict) – [description]
- Returns:
[description]
- Return type:
Dict
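A sketch of a custom override, assuming hypothetical batch keys (real TTSDataset batches carry tensors such as token ids, mel spectrograms, and lengths):

```python
# Hypothetical sketch: the key names below are illustrative, not the actual
# TTSDataset fields.
def format_batch(batch):
    # Map raw dataset fields onto the names the model's training step expects.
    return {
        "text_input": batch["token_ids"],
        "mel_input": batch["mel"],
        "mel_lengths": [len(m) for m in batch["mel"]],
    }

formatted = format_batch({"token_ids": [[1, 2]], "mel": [[0.1, 0.2, 0.3]]})
```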
- init_multispeaker(config, data=None)[source]#
Initialize a speaker embedding layer if needed and define the expected embedding channel size for setting the in_channels size of the connected layers.
This implementation yields 3 possible outcomes:
1. If config.use_speaker_embedding and config.use_d_vector_file are both False, do nothing.
2. If config.use_d_vector_file is True, set the expected embedding channel size to config.d_vector_dim or 512.
3. If config.use_speaker_embedding is True, initialize a speaker embedding layer with a channel size of config.d_vector_dim or 512.
You can override this function for new models.
- Parameters:
config (Coqpit) – Model configuration.
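The three outcomes above can be sketched as follows, with a plain class standing in for the Coqpit config (the helper name is hypothetical, and a real implementation would also create the nn.Embedding layer in the last branch):

```python
# Illustrative sketch of the three outcomes; `MyConfig` stands in for Coqpit.
class MyConfig:
    use_speaker_embedding = False
    use_d_vector_file = True
    d_vector_dim = 0  # falsy, so the size falls back to 512

def embedded_channel_size(config):
    if not config.use_speaker_embedding and not config.use_d_vector_file:
        return 0  # outcome 1: no speaker conditioning
    if config.use_d_vector_file:
        return config.d_vector_dim or 512  # outcome 2
    # outcome 3: a real model also initializes a speaker embedding layer here
    return config.d_vector_dim or 512

size = embedded_channel_size(MyConfig())
```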
- on_init_start(trainer)[source]#
Save the speaker.pth and language_ids.json at the beginning of the training. Also update both paths.
- test_run(assets)[source]#
Generic test run for tts models used by Trainer.
You can override this for a different behaviour.
- Parameters:
assets (dict) – A dict of training assets. For tts models, it must include {‘audio_processor’: ap}.
- Returns:
Test figures and audios to be projected to Tensorboard.
- Return type:
Tuple[Dict, Dict]
Base vocoder Model#
- class TTS.vocoder.models.base_vocoder.BaseVocoder(config)[source]#
Base vocoder class. Every new vocoder model must inherit this.
It defines vocoder specific functions on top of Model.
- Notes on input/output tensor shapes:
Any input or output tensor of the model must be shaped as:
- 3D tensors: batch x time x channels
- 2D tensors: batch x channels
- 1D tensors: batch x 1
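As a small illustration (this helper is not part of the TTS API), the convention maps tensor rank to named dimensions like so:

```python
# Hypothetical helper naming the dimensions a BaseVocoder tensor shape is
# expected to follow.
def describe_vocoder_shape(shape):
    if len(shape) == 3:
        return "batch x time x channels"
    if len(shape) == 2:
        return "batch x channels"
    if len(shape) == 1:
        return "batch x 1"
    raise ValueError("vocoder tensors must be 1D, 2D, or 3D")

kind = describe_vocoder_shape((8, 100, 80))
```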