GAN API
The TTS.vocoder.models.gan.GAN class provides an easy way to implement new GAN-based models. You just need to define the architectures of the generator and the discriminator networks and give them to the GAN class to do its ✨️.
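A minimal, dependency-free sketch of the wrapper idea follows. It assumes toy networks: TinyGAN, upsample_generator and threshold_discriminator are hypothetical names, not part of 🐸TTS, and plain Python functions stand in for the PyTorch modules a real vocoder would use. It only shows that the wrapper holds both networks and that its forward pass runs the generator alone.

```python
from typing import Callable, List


# Hypothetical stand-ins for the two networks (plain functions instead of
# PyTorch modules so the sketch runs without any dependencies).
def upsample_generator(features: List[float]) -> List[float]:
    """Toy 'vocoder': repeat each input frame twice to mimic upsampling."""
    return [v for v in features for _ in range(2)]


def threshold_discriminator(waveform: List[float]) -> float:
    """Toy critic: score a waveform by its mean absolute amplitude."""
    return sum(abs(v) for v in waveform) / len(waveform)


class TinyGAN:
    """Hypothetical wrapper: holds both networks behind one interface."""

    def __init__(self, generator: Callable, discriminator: Callable):
        self.generator = generator
        self.discriminator = discriminator

    def forward(self, x: List[float]) -> List[float]:
        # Only the generator runs in the forward pass.
        return self.generator(x)


model = TinyGAN(upsample_generator, threshold_discriminator)
fake = model.forward([0.5, -1.0])
score = model.discriminator(fake)
print(fake, score)
```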
GAN
- class TTS.vocoder.models.gan.GAN(config, ap=None)
- static format_batch(batch)
Format the batch for training.
- Parameters
batch (List) – Batch out of the dataloader.
- Returns
formatted model inputs.
- Return type
Dict
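The sketch below illustrates what formatting a batch can mean, under stated assumptions: the raw batch is either a [features, waveform] pair or, when the discriminator draws its own samples, two such pairs. The function name format_batch_sketch and the dictionary keys are hypothetical, and strings stand in for the tensors a real dataloader would yield.

```python
from typing import Any, Dict, List


def format_batch_sketch(batch: List[Any]) -> Dict[str, Any]:
    """Hypothetical sketch: name the parts of a raw dataloader batch.

    Assumed layouts:
      [features, waveform]                                  generator only
      [[features_g, waveform_g], [features_d, waveform_d]]  separate discriminator samples
    """
    if isinstance(batch[0], list):
        # Nested layout: one pair for the generator, one for the discriminator.
        (x_g, y_g), (x_d, y_d) = batch
        return {"input": x_g, "waveform": y_g, "input_disc": x_d, "waveform_disc": y_d}
    # Flat layout: a single (features, waveform) pair.
    x, y = batch
    return {"input": x, "waveform": y}


print(format_batch_sketch(["mel", "wav"]))
print(format_batch_sketch([["mel_g", "wav_g"], ["mel_d", "wav_d"]]))
```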
- forward(x)
Run the generator’s forward pass.
- Parameters
x (torch.Tensor) – Input tensor.
- Returns
output of the GAN generator network.
- Return type
torch.Tensor
- get_data_loader(config, assets, is_eval, samples, verbose, num_gpus, rank=None)
Initiate and return the GAN dataloader.
- Parameters
config (Coqpit) – Model config.
assets (Dict) – Training assets, including the audio processor (AudioProcessor).
is_eval (bool) – If true, set up the dataloader for evaluation.
samples (List) – Data samples.
verbose (bool) – Log information if true.
num_gpus (int) – Number of GPUs in use.
rank (int) – Rank of the current GPU. Defaults to None.
- Returns
Torch dataloader.
- Return type
DataLoader
- get_lr()
Return the initial learning rate of each optimizer.
- Returns
learning rates for each optimizer.
- Return type
List
- get_optimizer()
Initiate and return the GAN optimizers based on the config parameters.
It returns two optimizers in a list: the first for the generator and the second for the discriminator.
- Returns
optimizers.
- Return type
List
- get_scheduler(optimizer)
Create a learning-rate scheduler for each optimizer.
- Parameters
optimizer (List[torch.optim.Optimizer]) – List of optimizers.
- Returns
Schedulers, one for each optimizer.
- Return type
List
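Because the optimizers, their initial learning rates and their schedulers all come back as lists, position is what ties them together. The sketch assumes index 0 is the generator and index 1 the discriminator, as get_optimizer documents for the optimizers. ExponentialDecay is a hypothetical pure-Python scheduler and the learning rates and gamma are made-up values; real configs use PyTorch optimizers and schedulers.

```python
from typing import List


class ExponentialDecay:
    """Hypothetical scheduler: multiply the learning rate by gamma on each step()."""

    def __init__(self, initial_lr: float, gamma: float):
        self.lr = initial_lr
        self.gamma = gamma

    def step(self) -> None:
        self.lr *= self.gamma


# Assumed initial learning rates, ordered [generator, discriminator].
initial_lrs: List[float] = [2e-4, 1e-4]
# One scheduler per optimizer, kept in the same order.
schedulers = [ExponentialDecay(lr, gamma=0.5) for lr in initial_lrs]

for _ in range(2):  # advance every scheduler twice
    for scheduler in schedulers:
        scheduler.step()

gen_lr, disc_lr = (s.lr for s in schedulers)
print(f"generator lr={gen_lr:.2e}, discriminator lr={disc_lr:.2e}")
```

Keeping the lists parallel means the training loop can pick the optimizer, its scheduler and its learning rate with the same index.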
- inference(x)
Run the generator’s inference pass.
- Parameters
x (torch.Tensor) – Input tensor.
- Returns
output of the GAN generator network.
- Return type
torch.Tensor
- load_checkpoint(config, checkpoint_path, eval=False, cache=False)
Load a GAN checkpoint and initialize model parameters.
- Parameters
config (Coqpit) – Model config.
checkpoint_path (str) – Checkpoint file path.
eval (bool, optional) – If true, load the model for inference. Defaults to False.
- on_train_step_start(trainer)
Enable discriminator training once the number of completed training steps reaches steps_to_start_discriminator.
- Parameters
trainer (Trainer) – Trainer object.
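A minimal sketch of this warm-up gate, assuming the current step count is passed in directly (the real method receives a Trainer object instead). DiscriminatorWarmup is a hypothetical name; only steps_to_start_discriminator comes from the description above.

```python
class DiscriminatorWarmup:
    """Hypothetical sketch: gate discriminator training on a step count."""

    def __init__(self, steps_to_start_discriminator: int):
        self.steps_to_start_discriminator = steps_to_start_discriminator
        self.train_disc = False

    def on_train_step_start(self, total_steps_done: int) -> None:
        # Train the discriminator only once enough steps have elapsed.
        self.train_disc = total_steps_done >= self.steps_to_start_discriminator


warmup = DiscriminatorWarmup(steps_to_start_discriminator=3)
flags = []
for step in range(5):
    warmup.on_train_step_start(step)
    flags.append(warmup.train_disc)
print(flags)
```

Letting the generator train alone for a while before the discriminator joins in is a common way to give the adversarial loss a reasonable starting point.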
- train_step(batch, criterion, optimizer_idx)
Compute model outputs and the loss values. optimizer_idx selects whether the generator or the discriminator network is trained on the current pass.
- Parameters
batch (Dict) – Batch of samples returned by the dataloader.
criterion (Dict) – Criterion used to compute the losses.
optimizer_idx (int) – ID of the optimizer in use on the current pass.
- Raises
ValueError – If optimizer_idx is not 0 (generator) or 1 (discriminator).
- Returns
model outputs and the computed loss values.
- Return type
Tuple[Dict, Dict]
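The optimizer_idx dispatch can be sketched in pure Python. Everything here is hypothetical: train_step_sketch is not a 🐸TTS function, the toy generator doubles its input, the toy discriminator scores a sample as one tenth of its value, the generator loss is a plain squared reconstruction error, and the discriminator loss follows the least-squares GAN form (D(real) - 1)^2 + D(fake)^2. The train_disc flag plays the role of the warm-up gate set by on_train_step_start.

```python
from typing import Dict, Tuple


def train_step_sketch(batch: Dict[str, float], optimizer_idx: int,
                      train_disc: bool = True) -> Tuple[Dict, Dict]:
    """Hypothetical sketch: pick the network to train from optimizer_idx."""
    fake = 2.0 * batch["input"]  # toy generator output
    if optimizer_idx == 0:
        # Generator pass: squared error against the real waveform.
        loss = (fake - batch["waveform"]) ** 2
        return {"model_outputs": fake}, {"loss": loss}
    if optimizer_idx == 1:
        if not train_disc:
            # Discriminator still warming up: nothing to optimize on this pass.
            return {}, {}
        # Discriminator pass: least-squares GAN loss with a toy scorer.
        real_score = batch["waveform"] / 10.0
        fake_score = fake / 10.0
        loss = (real_score - 1.0) ** 2 + fake_score ** 2
        return {"scores": (real_score, fake_score)}, {"loss": loss}
    raise ValueError(f"optimizer_idx must be 0 or 1, got {optimizer_idx}")


batch = {"input": 1.0, "waveform": 3.0}
print(train_step_sketch(batch, optimizer_idx=0))
print(train_step_sketch(batch, optimizer_idx=1))
```

In a real PyTorch implementation the generator output is normally detached before the discriminator pass, so the discriminator loss does not update the generator.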