Model API#

The Model API provides a set of functions that make your model compatible with the Trainer, Synthesizer, and ModelZoo.

Base TTS Model#

class TTS.model.BaseTrainerModel(*args, **kwargs)[source]#

BaseTrainerModel extends TrainerModel with the functions required by 🐸TTS.

Every new 🐸TTS model must inherit it.

abstract inference(input, aux_input={})[source]#

Forward pass for inference.

It must return a dictionary with the main model output and all auxiliary outputs. The key `model_outputs` is considered the main output, and you can add any other auxiliary outputs as needed.

We don’t use `*kwargs` since it is problematic with the TorchScript API.

Parameters:
  • input (torch.Tensor) – Input tensor for the forward pass.

  • aux_input (Dict) – Auxiliary inputs like speaker embeddings, durations etc.

Returns:

A dictionary with `model_outputs` as the main output and any auxiliary outputs.

Return type:

Dict
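As an illustration, a subclass might implement `inference()` roughly like the minimal sketch below; the `MyModel` class and its `decoder` attribute are hypothetical stand-ins, not part of the API.

```python
from TTS.model import BaseTrainerModel


class MyModel(BaseTrainerModel):
    """Hypothetical model, shown only to illustrate the inference contract."""

    def inference(self, input, aux_input={}):
        # `self.decoder` stands in for whatever layers your model defines.
        outputs = self.decoder(input)
        return {
            "model_outputs": outputs,  # mandatory main output key
            "alignments": None,        # optional auxiliary output
        }
```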

abstract static init_from_config(config)[source]#

Init the model and all its attributes from the given config.

Override this depending on your model.
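A minimal sketch of an override, assuming a model whose constructor takes only the config (the `MyModel` class is hypothetical):

```python
from TTS.model import BaseTrainerModel


class MyModel(BaseTrainerModel):
    def __init__(self, config):
        super().__init__()
        self.config = config  # build layers from config values here

    @staticmethod
    def init_from_config(config):
        # Construct the model and all its attributes from the Coqpit config.
        return MyModel(config)
```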

abstract load_checkpoint(config, checkpoint_path, eval=False, strict=True, cache=False)[source]#

Load a model checkpoint file and get ready for training or inference.

Parameters:
  • config (Coqpit) – Model configuration.

  • checkpoint_path (str) – Path to the model checkpoint file.

  • eval (bool, optional) – If True, initialize the model for inference, else for training. Defaults to False.

  • strict (bool, optional) – Match all checkpoint keys to model’s keys. Defaults to True.

  • cache (bool, optional) – If True, cache the file locally for subsequent calls. It is cached under get_user_data_dir()/tts_cache. Defaults to False.
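A hedged sketch of a typical override; the `"model"` checkpoint key is an assumption about your checkpoint layout, and the `cache` handling is omitted here:

```python
import torch

from TTS.model import BaseTrainerModel


class MyModel(BaseTrainerModel):
    def load_checkpoint(self, config, checkpoint_path, eval=False, strict=True, cache=False):
        # Assumes the checkpoint stores weights under a "model" key.
        state = torch.load(checkpoint_path, map_location="cpu")
        self.load_state_dict(state["model"], strict=strict)
        if eval:
            # Switch to inference mode.
            self.eval()
```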

Base tts Model#

class TTS.tts.models.base_tts.BaseTTS(config, ap, tokenizer, speaker_manager=None, language_manager=None)[source]#

Base tts class. Every new tts model must inherit this.

It defines common tts-specific functions on top of the Model implementation.

format_batch(batch)[source]#

Generic batch formatting for TTSDataset.

You must override this if you use a custom dataset.

Parameters:

batch (Dict) – A batch of items produced by the TTSDataset dataloader.

Returns:

The formatted batch, ready to be consumed by the model.

Return type:

Dict
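For example, a custom dataset that adds extra fields might extend the generic formatting like this sketch; the `pitch` field and the `MyTTSModel` class are hypothetical:

```python
from TTS.tts.models.base_tts import BaseTTS


class MyTTSModel(BaseTTS):
    def format_batch(self, batch):
        # Reuse the generic TTSDataset formatting first.
        batch = super().format_batch(batch)
        # Then adapt any custom fields your dataset adds.
        batch["pitch"] = batch["pitch"].float()  # hypothetical extra field
        return batch
```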

get_aux_input(**kwargs)[source]#

Prepare and return `aux_input` used by `forward()`.

init_multispeaker(config, data=None)[source]#

Initialize a speaker embedding layer if needed and define the expected embedding channel size, which sets the in_channels size of the connected layers.

This implementation yields 3 possible outcomes (see the config sketch below):

  1. If `config.use_speaker_embedding` and `config.use_d_vector_file` are False, do nothing.

  2. If `config.use_d_vector_file` is True, set the expected embedding channel size to `config.d_vector_dim` or 512.

  3. If `config.use_speaker_embedding` is True, initialize a speaker embedding layer with a channel size of `config.d_vector_dim` or 512.

You can override this function for new models.

Parameters:

config (Coqpit) – Model configuration.
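The sketch below shows how the flags map to the three outcomes, using a plain namespace as a stand-in for a Coqpit config; all values are illustrative:

```python
from types import SimpleNamespace

# Outcome 2: use pre-computed d-vectors loaded from a file.
dvector_config = SimpleNamespace(
    use_speaker_embedding=False,
    use_d_vector_file=True,
    d_vector_dim=256,  # embedding channel size; 512 is the documented fallback
)

# Outcome 3: learn a speaker embedding layer during training.
embedding_config = SimpleNamespace(
    use_speaker_embedding=True,
    use_d_vector_file=False,
)
```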

on_init_start(trainer)[source]#

Save the speaker.pth and language_ids.json files at the beginning of training, and update both paths accordingly.

test_run(assets)[source]#

Generic test run for tts models used by Trainer.

You can override this for a different behaviour.

Parameters:

assets (dict) – A dict of training assets. For tts models, it must include `{"audio_processor": ap}`.

Returns:

Test figures and audios to be projected to Tensorboard.

Return type:

Tuple[Dict, Dict]
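An override might look like the sketch below; the `_synthesize` helper is a hypothetical model-specific function, and the use of `config.test_sentences` is an assumption for illustration:

```python
from TTS.tts.models.base_tts import BaseTTS


class MyTTSModel(BaseTTS):
    def test_run(self, assets):
        ap = assets["audio_processor"]  # required asset for tts models
        test_figures, test_audios = {}, {}
        for idx, sentence in enumerate(self.config.test_sentences):
            # `_synthesize` is a hypothetical model-specific helper.
            wav, alignment_fig = self._synthesize(sentence, ap)
            test_audios[f"{idx}-audio"] = wav
            test_figures[f"{idx}-alignment"] = alignment_fig
        return test_figures, test_audios
```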

Base vocoder Model#

class TTS.vocoder.models.base_vocoder.BaseVocoder(config)[source]#

Base vocoder class. Every new vocoder model must inherit this.

It defines vocoder-specific functions on top of the Model implementation.

Notes on input/output tensor shapes:

Any input or output tensor of the model must be shaped as

  • 3D tensors: `batch x time x channels`

  • 2D tensors: `batch x channels`

  • 1D tensors: `batch x 1`
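A quick sanity check of these shape conventions with dummy tensors; the dimension sizes are illustrative:

```python
import torch

batch, time, channels = 8, 100, 80

spec = torch.rand(batch, time, channels)  # 3D: batch x time x channels
emb = torch.rand(batch, channels)         # 2D: batch x channels
gain = torch.rand(batch, 1)               # 1D case: batch x 1

assert spec.shape == (batch, time, channels)
assert emb.shape == (batch, channels)
assert gain.shape == (batch, 1)
```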