Trainer API

The TTS.trainer.Trainer provides a lightweight, extensible, and feature-complete training runtime. We optimized it for 🐸, but it can also be used for any deep-learning training in other domains. It supports distributed multi-GPU and mixed-precision (apex or torch.amp) training.
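
For a quick orientation, here is a minimal, hedged usage sketch. The config and model objects are placeholders (any 🐸TTS model and its Coqpit config would do); only Trainer, TrainingArgs, and fit() come from the API documented below.

    from TTS.trainer import Trainer, TrainingArgs

    # Placeholders: `my_config` is a Coqpit-based training config and
    # `my_model` is a 🐸TTS model built from it. Names are illustrative only.
    args = TrainingArgs()
    trainer = Trainer(
        args,
        config=my_config,
        output_path="output/my_run",
        model=my_model,
    )
    trainer.fit()  # run training and evaluation epochs and save checkpoints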

Trainer

class TTS.trainer.Trainer(args, config, output_path, c_logger=None, tb_logger=None, model=None, cudnn_benchmark=False)[source]
eval_epoch()[source]

Main entry point for the evaluation loop. Run evaluation on all validation samples.

eval_step(batch, step)[source]

Perform an evaluation step on a batch of inputs and log the process.

Parameters
  • batch (Dict) – Input batch.

  • step (int) – Current step number in this epoch.

Returns

Model outputs and losses.

Return type

Tuple[Dict, Dict]

fit()[source]

Where the ✨️magic✨️ happens…

format_batch(batch)[source]

Format the dataloader output and return a batch.

Parameters

batch (List) – Batch returned by the dataloader.

Returns

Formatted batch.

Return type

Dict

static get_criterion(model)[source]

Receive the criterion from the model. The model must implement get_criterion(). A minimal sketch follows this entry.

Parameters

model (nn.Module) – Training model.

Returns

Criterion layer.

Return type

nn.Module
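
A hypothetical model-side sketch of this hook; the model class and the choice of loss are placeholders, only the get_criterion() contract comes from the description above.

    from torch import nn

    from TTS.trainer import Trainer

    class MyModel(nn.Module):  # placeholder model, for illustration only
        def get_criterion(self):
            # Return the loss layer the Trainer should use; L1 is an arbitrary choice.
            return nn.L1Loss()

    criterion = Trainer.get_criterion(MyModel())  # -> an nn.L1Loss instance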

static get_lr(model, config)[source]

Set the initial learning rate from the model if it implements get_lr(); otherwise, try setting the learning rate from the config (see the sketch after this entry).

Parameters
  • model (nn.Module) – Training model.

  • config (Coqpit) – Training configuration.

Returns

A single learning rate or a list of learning rates, one for each optimizer.

Return type

Union[float, List[float]]
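
To make the fallback order concrete, a hypothetical re-statement (not the Trainer's actual code); the no-argument get_lr() hook and the lr config field are assumptions made for illustration.

    def resolve_lr(model, config):
        # Prefer the model-defined learning rate; otherwise fall back to the config.
        if hasattr(model, "get_lr"):
            return model.get_lr()   # assumed no-argument hook on the model
        return config.lr            # assumes the config carries a single `lr` field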

static get_model(config)[source]

Initialize model from config.

Parameters

config (Coqpit) – Model config.

Returns

Initialized model.

Return type

nn.Module

static get_optimizer(model, config)[source]

Receive the optimizer from the model if it implements get_optimizer(); otherwise, check the optimizer parameters in the config and try initializing the optimizer (see the sketch after this entry).

Parameters
  • model (nn.Module) – Training model.

  • config (Coqpit) – Training configuration.

Returns

An optimizer or a list of optimizers. GAN models define a list.

Return type

Union[torch.optim.Optimizer, List]
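
A hypothetical model-side implementation of the hook; the no-argument signature, the dummy layer, and AdamW with these settings are assumptions made for illustration.

    import torch
    from torch import nn

    class MyModel(nn.Module):  # placeholder model, for illustration only
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(80, 80)  # dummy parameters for the optimizer to own

        def get_optimizer(self):
            # Single-optimizer case; a GAN-style model would return a list instead,
            # e.g. [generator_optimizer, discriminator_optimizer].
            return torch.optim.AdamW(self.parameters(), lr=1e-3)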

static get_scheduler(model, config, optimizer)[source]

Receive the scheduler from the model if it implements get_scheduler(); otherwise, check the config and try initializing the scheduler (see the sketch after this entry).

Parameters
  • model (nn.Module) – Training model.

  • config (Coqpit) – Training configuration.

Returns

A scheduler or a list of schedulers, one for each optimizer.

Return type

Union[torch.optim.lr_scheduler._LRScheduler, List]
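
A hypothetical model-side counterpart; the hook signature (receiving the optimizer) and the exponential decay schedule are assumptions made for illustration.

    import torch
    from torch import nn

    class MyModel(nn.Module):  # placeholder model, for illustration only
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(80, 80)

        def get_scheduler(self, optimizer):
            # Return one scheduler per optimizer; exponential decay is an arbitrary example.
            return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)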

get_train_dataloader(ap, data_items, verbose)[source]

Initialize and return a training data loader.

Parameters
  • ap (AudioProcessor) – Audio processor.

  • data_items (List) – Data samples used for training.

  • verbose (bool) – Enable/disable printing loader stats at initialization.

Returns

Initialized training data loader.

Return type

DataLoader

restore_model(config, restore_path, model, optimizer, scaler=None)[source]

Restore training from an old run. It restores model, optimizer, AMP scaler and training stats.

Parameters
  • config (Coqpit) – Model config.

  • restore_path (str) – Path to the restored training run.

  • model (nn.Module) – Model to restore.

  • optimizer (torch.optim.Optimizer) – Optimizer to restore.

  • scaler (torch.cuda.amp.GradScaler, optional) – AMP scaler to restore. Defaults to None.

Returns

The restored model, optimizer, AMP scaler, and the training step to resume from.

Return type

Tuple[nn.Module, torch.optim.Optimizer, torch.cuda.amp.GradScaler, int]

save_best_model()[source]

Save the best model. It only saves if the current target loss is smaller than the previous best.

test_run()[source]

Run the test and log the results. The test run must be defined by the model. The model must return figures and audio samples to be logged on Tensorboard.

train_epoch()[source]

Main entry point for the training loop. Run training on all training samples.

train_step(batch, batch_n_steps, step, loader_start_time)[source]

Perform a training step on a batch of inputs and log the process.

Parameters
  • batch (Dict) – Input batch.

  • batch_n_steps (int) – Number of steps needed to complete an epoch. Needed for logging.

  • step (int) – Current step number in this epoch.

  • loader_start_time (float) – The time when data loading started. Needed for logging.

Returns

Model outputs and losses.

Return type

Tuple[Dict, Dict]

TrainingArgs

class TTS.trainer.TrainingArgs(continue_path='', restore_path='', best_path='', config_path='', rank=0, group_id='')[source]

Trainer arguments to be defined externally. It helps integrate the Trainer with higher-level APIs and sets the values for distributed training.
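
Typical usage together with the Trainer, as a hedged sketch. All paths are placeholders, and `my_config` and `my_model` stand in for any Coqpit config and 🐸TTS model; only the TrainingArgs fields shown come from the signature above.

    from TTS.trainer import Trainer, TrainingArgs

    # Resume an interrupted run from its output folder (placeholder path).
    args = TrainingArgs(continue_path="output/previous_run")

    # Or start a new run initialized from an earlier checkpoint (placeholder path).
    # args = TrainingArgs(restore_path="output/previous_run/checkpoint_30000.pth.tar")

    trainer = Trainer(args, config=my_config, output_path="output/my_run", model=my_model)
    trainer.fit()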