GAN API

TTS.vocoder.models.gan.GAN provides an easy way to implement new GAN-based models. You only need to define the model architectures for the generator and the discriminator networks and give them to the GAN class to do its ✨️.

GAN

class TTS.vocoder.models.gan.GAN(config)

eval_log(ap, batch, outputs)

Call _log() for evaluation.

eval_step(batch, criterion, optimizer_idx)

Call train_step() under no_grad(), so that no gradients are computed.

static format_batch(batch)

Format the batch for training.

Parameters

batch (List) – Batch out of the dataloader.

Returns

formatted model inputs.

Return type

Dict
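As a rough illustration of what formatting a batch can look like, the sketch below turns a list-style batch into a dictionary of named model inputs. The function name, the two-element batch layout, and the key names (`input`, `waveform`) are assumptions made for this example; this page does not specify them.

```python
from typing import Any, Dict, List


def format_batch_sketch(batch: List[Any]) -> Dict[str, Any]:
    """Hypothetical stand-in for GAN.format_batch().

    Assumes the dataloader yields [features, waveform]; the key names
    are illustrative only.
    """
    if len(batch) != 2:
        raise ValueError(f"expected 2 items in the batch, got {len(batch)}")
    features, waveform = batch
    return {"input": features, "waveform": waveform}


# Plain lists stand in for tensors to keep the example dependency-free.
formatted = format_batch_sketch([[0.1, 0.2, 0.3], [0.0, 0.5, -0.5]])
print(sorted(formatted))  # ['input', 'waveform']
```

Returning a dictionary lets later steps look inputs up by name instead of relying on their position in the batch.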

forward(x)

Run the generator’s forward pass.

Parameters

x (torch.Tensor) – Input tensor.

Returns

output of the GAN generator network.

Return type

torch.Tensor

get_criterion()

Return the criteria for the optimizers.

get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus)

Initiate and return the GAN dataloader.

Parameters
  • config (Coqpit) – Model config.

  • ap (AudioProcessor) – Audio processor.

  • is_eval (bool) – Set the dataloader for evaluation if true.

  • data_items (List) – Data samples.

  • verbose (bool) – Log information if true.

  • num_gpus (int) – Number of GPUs in use.

Returns

Torch dataloader.

Return type

DataLoader

get_lr()

Set the initial learning rates for each optimizer.

Returns

learning rates for each optimizer.

Return type

List

get_optimizer()

Initiate and return the GAN optimizers based on the config parameters.

It returns two optimizers in a list. The first one is for the generator and the second one is for the discriminator.

Returns

optimizers.

Return type

List

get_scheduler(optimizer)

Set the schedulers for each optimizer.

Parameters

optimizer (List[torch.optim.Optimizer]) – List of optimizers.

Returns

Schedulers, one for each optimizer.

Return type

List

inference(x)

Run the generator’s inference pass.

Parameters

x (torch.Tensor) – Input tensor.

Returns

output of the GAN generator network.

Return type

torch.Tensor

load_checkpoint(config, checkpoint_path, eval=False)

Load a GAN checkpoint and initialize model parameters.

Parameters
  • config (Coqpit) – Model config.

  • checkpoint_path (str) – Checkpoint file path.

  • eval (bool, optional) – If true, load the model for inference. Defaults to False.

on_train_step_start(trainer)

Enable discriminator training based on steps_to_start_discriminator.

Parameters

trainer (Trainer) – Trainer object.
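The step-based gating described above can be sketched with toy stand-ins. Everything here other than the name `steps_to_start_discriminator` is hypothetical: the `ToyTrainer.total_steps_done` counter, the `train_disc` flag, and the `>=` comparison are assumptions for illustration, not the library's confirmed behavior.

```python
from dataclasses import dataclass


@dataclass
class ToyConfig:
    # Named after the option mentioned above; the value is arbitrary.
    steps_to_start_discriminator: int = 3


@dataclass
class ToyTrainer:
    # Hypothetical attribute: number of training steps completed so far.
    total_steps_done: int = 0


class ToyGAN:
    def __init__(self, config: ToyConfig) -> None:
        self.config = config
        self.train_disc = False  # hypothetical flag read by the training step

    def on_train_step_start(self, trainer: ToyTrainer) -> None:
        # Enable discriminator training once enough steps have passed.
        self.train_disc = (
            trainer.total_steps_done >= self.config.steps_to_start_discriminator
        )


model = ToyGAN(ToyConfig(steps_to_start_discriminator=3))
trainer = ToyTrainer()
history = []
for step in range(5):
    trainer.total_steps_done = step
    model.on_train_step_start(trainer)
    history.append(model.train_disc)
print(history)  # [False, False, False, True, True]
```

Delaying the discriminator this way gives the generator a few steps of training on its own before the adversarial signal is switched on.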

train_log(ap, batch, outputs)

Call _log() for training.

train_step(batch, criterion, optimizer_idx)

Compute the model outputs and the loss values. optimizer_idx selects whether the generator or the discriminator network is used on the current pass.

Parameters
  • batch (Dict) – Batch of samples returned by the dataloader.

  • criterion (Dict) – Criterion used to compute the losses.

  • optimizer_idx (int) – ID of the optimizer in use on the current pass.

Raises

ValueError – optimizer_idx is an unexpected value.

Returns

model outputs and the computed loss values.

Return type

Tuple[Dict, Dict]
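A minimal, dependency-free sketch of the dispatch that train_step() describes is shown below. It is not the library's implementation: the index-to-network mapping is an assumption taken from the order described for get_optimizer() and should be checked against the source, and the batch keys, criterion keys, and toy "networks" are all hypothetical.

```python
from typing import Dict, List, Tuple

# Assumed mapping, following the order described for get_optimizer().
GENERATOR_IDX = 0
DISCRIMINATOR_IDX = 1


def l1(pred: List[float], target: List[float]) -> float:
    """Mean absolute error over two equal-length lists."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(target)


def train_step_sketch(
    batch: Dict, criterion: Dict, optimizer_idx: int
) -> Tuple[Dict, Dict]:
    """Toy dispatch on optimizer_idx; returns (outputs, loss_dict)."""
    if optimizer_idx not in (GENERATOR_IDX, DISCRIMINATOR_IDX):
        raise ValueError(f"unexpected optimizer_idx: {optimizer_idx}")

    outputs: Dict = {}
    loss_dict: Dict = {}
    if optimizer_idx == GENERATOR_IDX:
        # Stand-in generator: copies its input.
        y_hat = list(batch["input"])
        outputs["model_outputs"] = y_hat
        loss_dict["loss"] = criterion["generator"](y_hat, batch["waveform"])
    else:
        # Stand-in discriminator: scores every real sample as 0.5.
        scores = [0.5 for _ in batch["waveform"]]
        outputs["scores"] = scores
        real_labels = [1.0 for _ in scores]
        loss_dict["loss"] = criterion["discriminator"](scores, real_labels)
    return outputs, loss_dict


batch = {"input": [1.0, 2.0], "waveform": [1.0, 4.0]}
criterion = {"generator": l1, "discriminator": l1}

gen_outputs, gen_losses = train_step_sketch(batch, criterion, GENERATOR_IDX)
disc_outputs, disc_losses = train_step_sketch(batch, criterion, DISCRIMINATOR_IDX)
print(gen_losses["loss"], disc_losses["loss"])  # 1.0 0.5
```

Validating optimizer_idx before doing any work means a misconfigured training loop fails immediately rather than silently training the wrong network.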