GAN API

The TTS.vocoder.models.gan.GAN class provides an easy way to implement new GAN-based models. You only need to define the model architectures for the generator and the discriminator networks and give them to the GAN class to do its ✨️.

GAN

class TTS.vocoder.models.gan.GAN(config, ap=None)
eval_log(batch, outputs, logger, assets, steps)

Call _log() for evaluation.

eval_step(batch, criterion, optimizer_idx)

Call train_step() under no_grad(), so no gradients are computed during evaluation.

static format_batch(batch)

Format the batch for training.

Parameters:

batch (List) – Batch out of the dataloader.

Returns:

formatted model inputs.

Return type:

Dict

forward(x)

Run the generator’s forward pass.

Parameters:

x (torch.Tensor) – Input tensor.

Returns:

output of the GAN generator network.

Return type:

torch.Tensor

get_criterion()

Return the criteria (loss functions) for the optimizers.

get_data_loader(config, assets, is_eval, samples, verbose, num_gpus, rank=None)

Initiate and return the GAN dataloader.

Parameters:
  • config (Coqpit) – Model config.

  • ap (AudioProcessor) – Audio processor.

  • is_eval (bool) – Set the dataloader for evaluation if true.

  • samples (List) – Data samples.

  • verbose (bool) – Log information if true.

  • num_gpus (int) – Number of GPUs in use.

  • rank (int) – Rank of the current GPU. Defaults to None.

Returns:

Torch dataloader.

Return type:

DataLoader

get_lr()

Set the initial learning rates for each optimizer.

Returns:

learning rates for each optimizer.

Return type:

List

get_optimizer()

Initiate and return the GAN optimizers based on the config parameters.

It returns two optimizers in a list: the first is for the generator and the second is for the discriminator.

Returns:

optimizers.

Return type:

List

get_scheduler(optimizer)

Set the schedulers for each optimizer.

Parameters:

optimizer (List[torch.optim.Optimizer]) – List of optimizers.

Returns:

Schedulers, one for each optimizer.

Return type:

List
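
get_optimizer(), get_lr(), and get_scheduler() each return a list with one entry per optimizer, so the entries can be paired by position. The sketch below illustrates that pairing in plain Python. FakeOptimizer, FakeScheduler, and get_scheduler_sketch are hypothetical stand-ins, not part of the library, and the halving schedule is an assumption chosen only to make the effect visible.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class FakeOptimizer:
    """Hypothetical stand-in for torch.optim.Optimizer."""
    name: str
    lr: float


@dataclass
class FakeScheduler:
    """Hypothetical stand-in that halves the learning rate on each step."""
    optimizer: FakeOptimizer

    def step(self) -> None:
        self.optimizer.lr *= 0.5


def get_scheduler_sketch(optimizers: List[FakeOptimizer]) -> List[FakeScheduler]:
    # One scheduler per optimizer, in the same order as the input list.
    return [FakeScheduler(opt) for opt in optimizers]


# Order follows the get_optimizer() description: generator first, then discriminator.
optimizers = [FakeOptimizer("generator", 1e-4), FakeOptimizer("discriminator", 2e-4)]
schedulers = get_scheduler_sketch(optimizers)

for scheduler in schedulers:
    scheduler.step()  # each scheduler only touches its own optimizer

learning_rates = [opt.lr for opt in optimizers]
```

Because the lists share an order, index i always refers to the same network across optimizers, learning rates, and schedulers.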

inference(x)

Run the generator’s inference pass.

Parameters:

x (torch.Tensor) – Input tensor.

Returns:

output of the GAN generator network.

Return type:

torch.Tensor

load_checkpoint(config, checkpoint_path, eval=False, cache=False)

Load a GAN checkpoint and initialize model parameters.

Parameters:
  • config (Coqpit) – Model config.

  • checkpoint_path (str) – Checkpoint file path.

  • eval (bool, optional) – If true, load the model for inference. Defaults to False.

on_train_step_start(trainer)

Enable discriminator training based on the steps_to_start_discriminator config value.

Parameters:

trainer (Trainer) – Trainer object.
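
A minimal sketch of how such a step-based gate could work, written in plain Python. The DiscriminatorGate class, the train_disc flag, and the total_steps_done attribute on the trainer are assumptions made for illustration; only steps_to_start_discriminator is named in this reference.

```python
from types import SimpleNamespace


class DiscriminatorGate:
    """Toy stand-in for a model that delays discriminator training."""

    def __init__(self, steps_to_start_discriminator: int):
        self.steps_to_start_discriminator = steps_to_start_discriminator
        self.train_disc = False  # hypothetical flag read by the training step

    def on_train_step_start(self, trainer) -> None:
        # Assumption: the trainer exposes a running count of completed steps.
        self.train_disc = trainer.total_steps_done >= self.steps_to_start_discriminator


gate = DiscriminatorGate(steps_to_start_discriminator=3)
history = []
for step in range(5):
    trainer = SimpleNamespace(total_steps_done=step)  # stand-in for a Trainer
    gate.on_train_step_start(trainer)
    history.append(gate.train_disc)
```

With a threshold of 3, the flag stays off for steps 0 to 2 and turns on from step 3 onward, which gives the generator a head start before adversarial training begins.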

train_log(batch, outputs, logger, assets, steps)

Call _log() for training.

train_step(batch, criterion, optimizer_idx)

Compute the model outputs and the loss values. optimizer_idx selects whether the generator or the discriminator network is trained on the current pass.

Parameters:
  • batch (Dict) – Batch of samples returned by the dataloader.

  • criterion (Dict) – Criterion used to compute the losses.

  • optimizer_idx (int) – ID of the optimizer in use on the current pass.

Raises:

ValueError – optimizer_idx is an unexpected value.

Returns:

model outputs and the computed loss values.

Return type:

Tuple[Dict, Dict]
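
The dispatch on optimizer_idx can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the library's implementation: the index-to-network mapping follows the order given under get_optimizer() and should be checked against the source, and the toy networks, the criterion keys, and the batch keys ("input", "waveform") are hypothetical.

```python
from typing import Callable, Dict, List, Tuple

# Assumption: indices follow the order described under get_optimizer()
# (generator first, discriminator second). Verify before relying on it.
GENERATOR_IDX = 0
DISCRIMINATOR_IDX = 1


def toy_generator(x: List[float]) -> List[float]:
    """Hypothetical generator: doubles every input value."""
    return [2.0 * v for v in x]


def toy_discriminator(y: List[float]) -> float:
    """Hypothetical discriminator: scores a signal by its mean."""
    return sum(y) / len(y)


def train_step_sketch(
    batch: Dict, criterion: Dict[str, Callable], optimizer_idx: int
) -> Tuple[Dict, Dict]:
    if optimizer_idx not in (GENERATOR_IDX, DISCRIMINATOR_IDX):
        raise ValueError(f"Unexpected optimizer_idx: {optimizer_idx}")

    y_hat = toy_generator(batch["input"])
    outputs = {"model_outputs": y_hat}

    if optimizer_idx == GENERATOR_IDX:
        # Generator pass: compare the generated signal with the reference.
        loss_dict = {"loss": criterion["generator"](y_hat, batch["waveform"])}
    else:
        # Discriminator pass: score the generated and the real signal.
        score_fake = toy_discriminator(y_hat)
        score_real = toy_discriminator(batch["waveform"])
        loss_dict = {"loss": criterion["discriminator"](score_fake, score_real)}
    return outputs, loss_dict


criterion = {
    # Mean absolute difference between generated and reference signals.
    "generator": lambda y_hat, y: sum(abs(a - b) for a, b in zip(y_hat, y)) / len(y),
    # Toy objective that decreases as real scores rise above fake scores.
    "discriminator": lambda fake, real: fake - real,
}
batch = {"input": [1.0, 2.0], "waveform": [2.0, 5.0]}

gen_outputs, gen_losses = train_step_sketch(batch, criterion, GENERATOR_IDX)
disc_outputs, disc_losses = train_step_sketch(batch, criterion, DISCRIMINATOR_IDX)
```

Both branches return the same (outputs, losses) pair of dictionaries, matching the Tuple[Dict, Dict] return type, and any other index raises ValueError.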