GAN API
The TTS.vocoder.models.gan.GAN class provides an easy way to implement new GAN-based models. You only need to define the model architectures for the generator and the discriminator networks and pass them to the GAN class to let it do its ✨️.
GAN
- class TTS.vocoder.models.gan.GAN(config, ap=None)
- static format_batch(batch)
Format the batch for training.
- Parameters:
batch (List) – Batch out of the dataloader.
- Returns:
formatted model inputs.
- Return type:
Dict
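As a rough illustration of what "formatting" means here, the sketch below turns a two-item batch into a dictionary of named model inputs. The key names `input` and `waveform` and the (features, waveform) batch layout are assumptions made for the example; the documentation above does not specify them.

```python
from typing import Dict, List


def format_batch_sketch(batch: List) -> Dict:
    """Hypothetical stand-in for GAN.format_batch.

    Assumes the dataloader yields a (features, waveform) pair; the real
    batch layout and key names may differ.
    """
    features, waveform = batch
    # Give each item a name so later steps can look it up by key.
    return {"input": features, "waveform": waveform}


# Plain lists stand in for tensors to keep the sketch dependency-free.
formatted = format_batch_sketch([[0.1, 0.2, 0.3], [0.0, 0.5, -0.5]])
print(sorted(formatted))
```

Returning a dictionary rather than a positional list lets the training step refer to inputs by name.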
- forward(x)
Run the generator’s forward pass.
- Parameters:
x (torch.Tensor) – Input tensor.
- Returns:
output of the GAN generator network.
- Return type:
torch.Tensor
- get_data_loader(config, assets, is_eval, samples, verbose, num_gpus, rank=None)
Initiate and return the GAN dataloader.
- Parameters:
config (Coqpit) – Model config.
assets (Dict) – Training assets passed by the trainer.
is_eval (bool) – Set the dataloader for evaluation if true.
samples (List) – Data samples.
verbose (bool) – Log information if true.
num_gpus (int) – Number of GPUs in use.
rank (int) – Rank of the current GPU. Defaults to None.
- Returns:
Torch dataloader.
- Return type:
DataLoader
- get_lr()
Set the initial learning rates for each optimizer.
- Returns:
learning rates for each optimizer.
- Return type:
List
- get_optimizer()
Initiate and return the GAN optimizers based on the config parameters.
It returns two optimizers in a list: the first is for the generator and the second is for the discriminator.
- Returns:
optimizers.
- Return type:
List
- get_scheduler(optimizer)
Set the schedulers for each optimizer.
- Parameters:
optimizer (List[torch.optim.Optimizer]) – List of optimizers.
- Returns:
Schedulers, one for each optimizer.
- Return type:
List
- inference(x)
Run the generator’s inference pass.
- Parameters:
x (torch.Tensor) – Input tensor.
- Returns:
output of the GAN generator network.
- Return type:
torch.Tensor
- load_checkpoint(config, checkpoint_path, eval=False, cache=False)
Load a GAN checkpoint and initialize model parameters.
- Parameters:
config (Coqpit) – Model config.
checkpoint_path (str) – Checkpoint file path.
eval (bool, optional) – If true, load the model for inference. Defaults to False.
- on_train_step_start(trainer)
Enable discriminator training based on steps_to_start_discriminator.
- Parameters:
trainer (Trainer) – Trainer object.
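The step-based switch can be pictured with the minimal sketch below. Only `steps_to_start_discriminator` comes from the documentation above; the `total_steps_done` attribute, the `train_disc` flag, the `>=` comparison, and all `Sketch*` classes are assumptions made for illustration.

```python
from dataclasses import dataclass


@dataclass
class SketchConfig:
    # Name taken from the docs above; the value is made up for the example.
    steps_to_start_discriminator: int = 1000


@dataclass
class SketchTrainer:
    # Hypothetical attribute holding the number of completed training steps.
    total_steps_done: int = 0


class SketchGAN:
    def __init__(self, config: SketchConfig):
        self.config = config
        self.train_disc = False  # discriminator training starts disabled

    def on_train_step_start(self, trainer: SketchTrainer) -> None:
        # Assumed rule: enable the discriminator once enough steps have run.
        self.train_disc = (
            trainer.total_steps_done >= self.config.steps_to_start_discriminator
        )


model = SketchGAN(SketchConfig(steps_to_start_discriminator=1000))
states = []
for step in (0, 999, 1000, 5000):
    model.on_train_step_start(SketchTrainer(total_steps_done=step))
    states.append(model.train_disc)
print(states)
```

Delaying the discriminator this way gives the generator time to produce plausible output before adversarial training begins.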
- train_step(batch, criterion, optimizer_idx)
Compute model outputs and the loss values. optimizer_idx selects whether the generator or the discriminator network is trained on the current pass.
- Parameters:
batch (Dict) – Batch of samples returned by the dataloader.
criterion (Dict) – Criterion used to compute the losses.
optimizer_idx (int) – ID of the optimizer in use on the current pass.
- Raises:
ValueError – If optimizer_idx has an unexpected value.
- Returns:
model outputs and the computed loss values.
- Return type:
Tuple[Dict, Dict]
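The dispatch on optimizer_idx might look roughly like the toy sketch below, which uses plain floats instead of tensors. Which index maps to which network is an assumption here (check the implementation for the actual order), and the batch keys, criterion keys, and helper functions are all hypothetical.

```python
from typing import Dict, Tuple

# Hypothetical index assignment; the real order may differ.
GENERATOR_IDX = 0
DISCRIMINATOR_IDX = 1


def train_step_sketch(batch: Dict, criterion: Dict, optimizer_idx: int) -> Tuple[Dict, Dict]:
    """Toy stand-in for GAN.train_step: returns (outputs, loss_dict)."""
    outputs, loss_dict = {}, {}
    if optimizer_idx == GENERATOR_IDX:
        y_hat = [2.0 * v for v in batch["input"]]  # toy "generator"
        outputs["model_outputs"] = y_hat
        loss_dict["loss"] = criterion["generator"](y_hat, batch["waveform"])
    elif optimizer_idx == DISCRIMINATOR_IDX:
        scores = [abs(v) for v in batch["waveform"]]  # toy "discriminator"
        outputs["scores"] = scores
        loss_dict["loss"] = criterion["discriminator"](scores)
    else:
        raise ValueError(f"Unexpected optimizer_idx: {optimizer_idx}")
    return outputs, loss_dict


def l1(pred, target):
    # Mean absolute error between two equal-length sequences.
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(target)


def mean(values):
    return sum(values) / len(values)


criterion = {"generator": l1, "discriminator": mean}
batch = {"input": [1.0, 2.0], "waveform": [2.0, 5.0]}

# Alternate between the two networks, one pass per optimizer index.
for idx in (GENERATOR_IDX, DISCRIMINATOR_IDX):
    outputs, losses = train_step_sketch(batch, criterion, idx)
    print(idx, losses["loss"])
```

Any index other than the two expected values raises ValueError, matching the behavior documented above.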