run

(s3prl.problem.asv.run)

The backbone run procedure of automatic speaker verification (ASV) tasks.

Authors
  • Haibin Wu 2022

  • Leo 2022

ASV

class s3prl.problem.asv.run.ASV

Bases: Problem

run(target_dir: str, cache_dir: str, remove_all_cache: bool = False, start: int = 0, stop: Optional[int] = None, num_workers: int = 6, eval_batch: int = -1, device: str = 'cuda', world_size: int = 1, rank: int = 0, test_ckpt_dir: Optional[str] = None, test_ckpt_steps: Optional[List[int]] = None, prepare_data: Optional[dict] = None, build_encoder: Optional[dict] = None, build_dataset: Optional[dict] = None, build_batch_sampler: Optional[dict] = None, build_collate_fn: Optional[dict] = None, build_upstream: Optional[dict] = None, build_featurizer: Optional[dict] = None, build_downstream: Optional[dict] = None, build_model: Optional[dict] = None, build_task: Optional[dict] = None, build_optimizer: Optional[dict] = None, build_scheduler: Optional[dict] = None, save_model: Optional[dict] = None, save_task: Optional[dict] = None, train: Optional[dict] = None, evaluate: Optional[dict] = None)

stage | description
----- | -----------
0     | Parse the corpus and save the metadata file (waveform path, label, …)
1     | Build the encoder for encoding the speaker labels
2     | Train the model
3     | Evaluate the model on multiple test sets; multiple checkpoints are evaluated for each test set (see test_ckpt_steps)
4     | Report the best result found on each test set
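The start and stop arguments select a contiguous range of these stages. The plain-Python sketch below illustrates that selection with a hypothetical run_stages helper; treating stop as inclusive (so stop=2 still runs the training stage) is an assumption about ASV.run, not something this page states.

```python
from typing import Callable, List, Optional


def run_stages(stages: List[Callable[[], None]], start: int = 0,
               stop: Optional[int] = None) -> List[int]:
    """Run stages[start] .. stages[stop] and return the executed stage ids.

    Hypothetical helper; the inclusive `stop` is an assumption.
    """
    last = len(stages) - 1 if stop is None else stop  # None: run through the final stage
    if not 0 <= start <= last < len(stages):
        raise ValueError(f"invalid stage range: start={start}, stop={stop}")
    executed = []
    for stage_id in range(start, last + 1):
        stages[stage_id]()
        executed.append(stage_id)
    return executed


if __name__ == "__main__":
    names = ["prepare data", "build encoder", "train", "test", "report"]
    stages = [lambda n=n: print(f"stage: {n}") for n in names]
    # Skip data preparation and stop right after training.
    print(run_stages(stages, start=1, stop=2))
```

Resuming at a later stage this way only makes sense when the outputs of the earlier stages (metadata, encoder, checkpoints) already exist in target_dir.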

Parameters:
  • target_dir (str) – The directory that stores the script result.

  • cache_dir (str) – The directory that caches the processed data. Default: /home/user/.cache/s3prl/data

  • remove_all_cache (bool) – Whether to remove all the cache stored under cache_dir. Default: False

  • start (int) – The starting stage of the problem script. Default: 0

  • stop (int) – The stopping stage of the problem script; set to None to run through the final stage. Default: None

  • num_workers (int) – num_workers for all the torch DataLoaders. Default: 6

  • eval_batch (int) – During evaluation (valid or test), limit the number of batches. This is helpful during development to quickly check that nothing crashes. If -1, this feature is disabled and the entire epoch is evaluated. Default: -1

  • device (str) – The device type for all torch-related operations: “cpu” or “cuda”. Default: “cuda”

  • world_size (int) – How many processes are running this script simultaneously (in parallel). Usually this is 1; if you are running distributed training, it should be > 1. Default: 1

  • rank (int) – The index of this process during distributed training (world_size > 1). For example, with world_size == 8, 8 processes (8 GPUs) run in parallel and the script needs to know which of the 8 processes it is; rank then ranges from 0 to 7. All 8 processes share the same world_size but have different ranks (process ids). Default: 0

  • test_ckpt_dir (str) – The checkpoint path to use for testing. If not given, the checkpoints specified by test_ckpt_steps are used.

  • test_ckpt_steps (List[int]) – Training saves checkpoints at multiple steps. This option specifies which of these checkpoints are used for evaluation.

  • **kwds – The other arguments, like prepare_data and build_model, are method-specific arguments passed to the methods of the same name, and are not used in the core run logic. See the documentation of each method for its supported arguments and their meaning.
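To make eval_batch, world_size, and rank concrete, here is a plain-Python sketch with two hypothetical helpers. limit_batches caps evaluation at eval_batch batches (-1 meaning the whole epoch), and shard_for_rank gives each process a disjoint, strided slice of the data, similar in spirit to torch.utils.data.distributed.DistributedSampler (which additionally pads the data). How s3prl actually splits data across ranks is not stated on this page.

```python
import itertools
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def limit_batches(batches: Iterable[T], eval_batch: int = -1) -> Iterator[T]:
    """Return an iterator over at most `eval_batch` batches; -1 means the entire epoch."""
    if eval_batch == -1:
        return iter(batches)
    return itertools.islice(batches, eval_batch)


def shard_for_rank(indices: List[int], world_size: int = 1, rank: int = 0) -> List[int]:
    """Give each of the `world_size` processes a disjoint, strided subset of the data."""
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    return indices[rank::world_size]


if __name__ == "__main__":
    print(list(limit_batches(range(100), eval_batch=3)))
    for rank in range(4):
        print(rank, shard_for_rank(list(range(10)), world_size=4, rank=rank))
```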

build_task(build_task: dict, model, encoder, test_trials=None)

Build the task, which defines the logic of every train/valid/test forward step of the model, and how to reduce the batch results from multiple train/valid/test steps into metrics.

By default, builds SpeakerVerification.

Parameters:
  • build_task (dict) – same as in default_config; no arguments are supported for now

  • model (torch.nn.Module) – the model built by build_model

  • encoder – the encoder built by build_encoder

  • test_trials (List[Tuple[int, str, str]]) – each tuple in the list consists of (label, enroll_utt_id, test_utt_id), where label is either 0 or 1

Returns:

Task
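The test_trials argument is a list of (label, enroll_utt_id, test_utt_id) tuples. The sketch below builds that structure from whitespace-separated "label enroll test" lines; this text format (as used, for example, by VoxCeleb1 trial lists) and the parse_trials helper are assumptions for illustration, not part of this API.

```python
from typing import List, Tuple


def parse_trials(lines: List[str]) -> List[Tuple[int, str, str]]:
    """Turn 'label enroll_utt_id test_utt_id' lines into test_trials tuples."""
    trials = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        label, enroll_utt_id, test_utt_id = line.split()
        if label not in ("0", "1"):
            raise ValueError(f"label must be 0 or 1, got {label!r}")
        trials.append((int(label), enroll_utt_id, test_utt_id))
    return trials


if __name__ == "__main__":
    # Hypothetical utterance ids.
    lines = ["1 spk1/utt1.wav spk1/utt2.wav", "0 spk1/utt1.wav spk2/utt1.wav"]
    print(parse_trials(lines))
```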

build_collate_fn(build_collate_fn: dict, mode: str)

By default, returns s3prl.dataset.base.default_collate_fn.

Parameters:
  • build_collate_fn (dict) – same as in default_config; no arguments are supported for now

  • mode (str) – train, valid, or test

Returns:

callable

the collate_fn for torch DataLoader in train/valid/test mode
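A collate_fn receives the list of samples that the DataLoader drew for one batch and merges them into a single batch object. The hypothetical pad_collate below, written in plain Python, shows one common pattern: padding a variable-length waveform field and recording the original lengths. The field names "wav" and "wav_len" are assumptions, and this is not the behavior of s3prl.dataset.base.default_collate_fn, which would typically produce tensors rather than lists.

```python
from typing import Any, Dict, List


def pad_collate(samples: List[Dict[str, Any]], pad_value: float = 0.0) -> Dict[str, Any]:
    """Hypothetical collate_fn: pad the 'wav' field, keep other fields as lists."""
    # Turn a list of per-sample dicts into a dict of per-field lists.
    batch: Dict[str, Any] = {key: [s[key] for s in samples] for key in samples[0]}
    lengths = [len(w) for w in batch["wav"]]
    max_len = max(lengths)
    batch["wav"] = [list(w) + [pad_value] * (max_len - len(w)) for w in batch["wav"]]
    batch["wav_len"] = lengths  # original lengths, so padding can be masked later
    return batch


if __name__ == "__main__":
    print(pad_collate([{"wav": [0.1, 0.2, 0.3], "label": "spk1"},
                       {"wav": [0.4], "label": "spk2"}]))
```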

build_featurizer(build_featurizer: dict, upstream)

By default, builds the featurizer with s3prl.nn.Featurizer.

Parameters:
  • build_featurizer (dict) – same as in default_config, arguments for s3prl.nn.Featurizer

  • upstream – the upstream model built by build_upstream

Returns:

s3prl.nn.interface.AbsFeaturizer

Return the featurizer model. The featurizer reduces the multiple hidden states returned by the upstream model (built by build_upstream) into a single hidden state, so that it can easily be fed into the downstream model.
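A common way to reduce multiple hidden states to one (used, for instance, in the SUPERB benchmark) is a learnable weighted sum over layers with softmax-normalized weights. The sketch below shows only that reduction, in plain Python for a single frame; whether s3prl.nn.Featurizer uses exactly this scheme is an assumption here, and a real featurizer would operate on tensors with batch and time dimensions and learn layer_logits by gradient descent.

```python
import math
from typing import List, Sequence


def weighted_sum(hidden_states: Sequence[Sequence[float]],
                 layer_logits: Sequence[float]) -> List[float]:
    """Reduce L per-layer feature vectors to one using softmax(layer_logits) weights."""
    if len(hidden_states) != len(layer_logits):
        raise ValueError("need exactly one logit per hidden state")
    peak = max(layer_logits)
    exps = [math.exp(x - peak) for x in layer_logits]  # numerically stable softmax
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(hidden_states[0])
    return [sum(w * h[i] for w, h in zip(weights, hidden_states)) for i in range(dim)]


if __name__ == "__main__":
    layers = [[1.0, 2.0], [3.0, 4.0]]  # two "layers", 2-dim features
    print(weighted_sum(layers, [0.0, 0.0]))  # equal logits give the layer average
```

Initializing all logits to zero starts from a plain average of the layers and lets training shift weight toward the most useful ones.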

build_model(build_model: dict, model_output_size: int, build_upstream: dict, build_featurizer: dict, build_downstream: dict)

By default, builds the model with s3prl.nn.upstream.UpstreamDownstreamModel.

Parameters:
  • build_model (dict) – same as in default_config, arguments for s3prl.nn.upstream.UpstreamDownstreamModel

  • model_output_size (int) – the required output hidden size of the model

  • build_upstream (dict) – same as in default_config, refer to build_upstream

  • build_featurizer (dict) – same as in default_config, refer to build_featurizer

  • build_downstream (dict) – same as in default_config, refer to build_downstream

Returns:

torch.nn.Module

Return the entire model for the task, which takes the items coming directly from the DataLoader as input. Usually, the components are built by build_upstream, build_featurizer, and build_downstream, and are chained together to form the final model: the upstream extracts multiple hidden states, the featurizer reduces them into a single hidden state, and the downstream uses that hidden state as the feature for the downstream-specific model.
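The data flow described above can be sketched with plain callables. ToyUpstreamDownstreamModel is a hypothetical stand-in, not s3prl.nn.upstream.UpstreamDownstreamModel; it only shows the order in which the three components are applied.

```python
from typing import Callable, List, Sequence

Upstream = Callable[[Sequence[float]], List[List[float]]]
Featurizer = Callable[[List[List[float]]], List[float]]
Downstream = Callable[[List[float]], List[float]]


class ToyUpstreamDownstreamModel:
    """Hypothetical stand-in chaining upstream -> featurizer -> downstream."""

    def __init__(self, upstream: Upstream, featurizer: Featurizer, downstream: Downstream):
        self.upstream = upstream
        self.featurizer = featurizer
        self.downstream = downstream

    def __call__(self, wav: Sequence[float]) -> List[float]:
        hidden_states = self.upstream(wav)        # multiple hidden states
        feature = self.featurizer(hidden_states)  # reduced to a single hidden state
        return self.downstream(feature)           # downstream-specific output


if __name__ == "__main__":
    model = ToyUpstreamDownstreamModel(
        upstream=lambda wav: [list(wav), [2 * x for x in wav]],         # two toy "layers"
        featurizer=lambda hs: [sum(col) / len(hs) for col in zip(*hs)],  # layer average
        downstream=lambda feat: [sum(feat)],                             # one toy score
    )
    print(model([1.0, 2.0]))
```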

build_optimizer(build_optimizer: dict, parameters)
Parameters:
  • build_optimizer (dict) –

    same as in default_config, with the following keys:

    key  | description
    ---- | -----------
    name | (str) the optimizer class name in torch.optim
    conf | (dict) the arguments for initializing the optimizer class, e.g. {"lr": 1.0e-4}

  • parameters (iterable) – the standard params accepted by torch.optim.Optimizer.

Returns:

torch.optim.Optimizer

An optimizer following standard torch usage
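The name/conf pair can be resolved with getattr on the module that holds the classes. The helper below is a hypothetical sketch of that pattern, not s3prl's code; to stay dependency-free, the demo resolves fractions.Fraction from the standard library instead of a torch optimizer.

```python
import fractions
from typing import Any, Dict


def build_from_config(namespace: Any, config: Dict[str, Any], *args: Any) -> Any:
    """Instantiate getattr(namespace, config["name"]) with *args and **config["conf"]."""
    cls = getattr(namespace, config["name"], None)
    if cls is None:
        raise ValueError(f"{config['name']!r} not found in {namespace!r}")
    return cls(*args, **config.get("conf", {}))


if __name__ == "__main__":
    # Stand-in for torch.optim: any module exposing classes by name works.
    print(build_from_config(fractions, {"name": "Fraction"}, 3, 4))
```

With PyTorch installed, the same call shape, build_from_config(torch.optim, {"name": "Adam", "conf": {"lr": 1.0e-4}}, model.parameters()), would construct torch.optim.Adam(model.parameters(), lr=1.0e-4), and passing torch.optim.lr_scheduler with an optimizer covers the build_scheduler case.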

build_scheduler(build_scheduler: dict, optimizer)
Parameters:
  • build_scheduler (dict) –

    same as in default_config, with the following keys:

    key  | description
    ---- | -----------
    name | (str) the scheduler class name in torch.optim.lr_scheduler
    conf | (dict) the arguments for initializing the scheduler class, e.g. {"gamma": 0.01} (plus the required step_size) for torch.optim.lr_scheduler.StepLR

  • optimizer – the standard torch optimizer that the schedulers in torch.optim.lr_scheduler accept.

Returns:

torch scheduler

A scheduler following standard torch usage
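To make the StepLR conf keys concrete, this plain-Python function reproduces StepLR's decay rule: the learning rate is multiplied by gamma once every step_size scheduler steps, i.e. lr = base_lr * gamma ** (step // step_size). Whether the training loop steps the scheduler once per optimization step is an assumption here.

```python
def step_lr(base_lr: float, step: int, step_size: int, gamma: float = 0.1) -> float:
    """Learning rate after `step` scheduler steps under a StepLR-style decay."""
    if step_size <= 0:
        raise ValueError("step_size must be positive")
    return base_lr * gamma ** (step // step_size)


if __name__ == "__main__":
    for step in (0, 9, 10, 25):
        print(step, step_lr(1.0e-4, step, step_size=10, gamma=0.01))
```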

build_upstream(build_upstream: dict)

By default, builds the upstream with s3prl.nn.upstream.S3PRLUpstream.

Parameters:

build_upstream (dict) – same as in default_config, arguments for s3prl.nn.upstream.S3PRLUpstream

Returns:

s3prl.nn.interface.AbsUpstream

Return an upstream model, whose forward takes the waveform input and returns multiple hidden states as features.

evaluate(evaluate: dict, mode: str, task, dataset, batch_sampler, collate_fn, eval_batch: int, dump_dir: str, device: str, num_workers: int)

The evaluate routine used by train (during validation phase) and run (during testing phase).

Parameters:
  • evaluate (dict) – same as in default_config; no arguments are supported for now

  • **others – only meaningful when you want to override this evaluate method, which is not the common case. Hence we skip the documentation for now.

classmethod get_class_from_name(name: str)
Parameters:

name (str) – the __name__ of the problem class

Returns:

Problem
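One way such a name-to-class lookup can work is a recursive search over __subclasses__(). The sketch below uses a toy Problem hierarchy with hypothetical class names; it is an assumption about the mechanism, not s3prl's implementation.

```python
from typing import Optional, Type


class Problem:
    """Toy stand-in for the base class; not s3prl's Problem."""

    @classmethod
    def get_class_from_name(cls, name: str) -> Type["Problem"]:
        """Search cls and all its (transitive) subclasses for a matching __name__."""
        found = cls._find(name)
        if found is None:
            raise ValueError(f"no subclass of {cls.__name__} is named {name!r}")
        return found

    @classmethod
    def _find(cls, name: str) -> Optional[Type["Problem"]]:
        if cls.__name__ == name:
            return cls
        for sub in cls.__subclasses__():
            hit = sub._find(name)
            if hit is not None:
                return hit
        return None


class ToyASV(Problem):  # hypothetical problem classes for the demo
    pass


class MyASVVariant(ToyASV):
    pass


if __name__ == "__main__":
    print(Problem.get_class_from_name("MyASVVariant").__name__)
```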

load_model(model_ckpt_dir: str)

Return the saved model.

Parameters:

model_ckpt_dir (str) – Restore the model with build_model and the checkpoint saved in this directory.

Returns:

torch.nn.Module

load_model_and_task(ckpts_dir: str, task_overrides: Optional[dict] = None)

This is a helper method that combines load_model and load_task to directly load the model and the task. This method assumes the model is saved under ckpts_dir / 'model' and the task is saved under ckpts_dir / 'task'.

Returns:

tuple

  1. model (torch.nn.Module)

  2. task (s3prl.task.Task)
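The directory layout this method assumes can be checked before loading. The hypothetical helper below only locates and validates ckpts_dir / 'model' and ckpts_dir / 'task'; the resulting paths would then go to load_model and load_task in that order, since load_task needs the loaded model.

```python
import tempfile
from pathlib import Path
from typing import Tuple, Union


def model_and_task_dirs(ckpts_dir: Union[str, Path]) -> Tuple[Path, Path]:
    """Locate the 'model' and 'task' sub-directories that load_model_and_task expects."""
    root = Path(ckpts_dir)
    model_dir, task_dir = root / "model", root / "task"
    for sub in (model_dir, task_dir):
        if not sub.is_dir():
            raise FileNotFoundError(f"expected checkpoint directory {sub}")
    return model_dir, task_dir


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        # Empty placeholder layout; real checkpoint directories also contain saved files.
        for name in ("model", "task"):
            (Path(tmp) / name).mkdir()
        model_dir, task_dir = model_and_task_dirs(tmp)
        print(model_dir.name, task_dir.name)  # model task
```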

load_task(task_ckpt_dir: str, model: Module, task_overrides: Optional[dict] = None)

Return the saved task.

Parameters:
  • task_ckpt_dir (str) – Restore the task with build_task and the checkpoint saved in this directory.

  • model (torch.nn.Module) – the model for the task, since the model is separately saved and is required for build_task.

  • task_overrides (dict) – overrides the saved initialization arguments, so that you can change the loaded task’s behavior, e.g. the decoding hyperparameters.

Returns:

s3prl.task.Task

main(args: Optional[List[str]] = None)

save_model(save_model: dict, model_ckpt_dir: str, build_model_all_args: dict, model: Module)

Save the model state_dict and the model initialization arguments into the given directory. If you override this method, you most likely also need to override load_model.

Parameters:
  • save_model (dict) – same as in default_config, so the user can save additional settings, like the dataset configuration, by duplicating the dataset hyperparameters inside the save_model field. You can rely on the omegaconf package to simplify the duplication.

  • model_ckpt_dir (str) – save the model into this directory.

  • build_model_all_args (dict) – all the arguments of build_model. By saving this dictionary, you can easily reconstruct the same model by calling build_model with the saved dictionary.

  • model (torch.nn.Module) – the model to be saved.

Returns:

None
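The reason save_model and load_model must be overridden together is that loading rebuilds the model from the saved build_model arguments and then restores the saved weights, so it depends on exactly what saving wrote. The sketch below shows that pairing with a hypothetical ToyModel and JSON files; s3prl's actual file names and serialization (likely torch-based) are not described here and may differ.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict


class ToyModel:
    """Hypothetical model with a torch-like state_dict interface."""

    def __init__(self, hidden_size: int):
        self.hidden_size = hidden_size
        self.weights = [0.0] * hidden_size

    def state_dict(self) -> Dict[str, Any]:
        return {"weights": list(self.weights)}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.weights = list(state["weights"])


def build_toy_model(hidden_size: int) -> ToyModel:
    return ToyModel(hidden_size)


def save_toy_model(model_ckpt_dir: str, build_model_all_args: Dict[str, Any], model: ToyModel) -> None:
    """Save the init arguments and the weights side by side (file names are hypothetical)."""
    path = Path(model_ckpt_dir)
    path.mkdir(parents=True, exist_ok=True)
    (path / "args.json").write_text(json.dumps(build_model_all_args))
    (path / "state.json").write_text(json.dumps(model.state_dict()))


def load_toy_model(model_ckpt_dir: str) -> ToyModel:
    """Rebuild from the saved arguments, then restore the saved weights."""
    path = Path(model_ckpt_dir)
    model = build_toy_model(**json.loads((path / "args.json").read_text()))
    model.load_state_dict(json.loads((path / "state.json").read_text()))
    return model


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        original = build_toy_model(hidden_size=3)
        original.weights = [0.5, -1.0, 2.0]
        save_toy_model(tmp, {"hidden_size": 3}, original)
        restored = load_toy_model(tmp)
        print(restored.hidden_size, restored.weights)  # 3 [0.5, -1.0, 2.0]
```

The same reasoning applies to save_task and load_task, with build_task_all_args_except_model in place of build_model_all_args.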

save_task(save_task: dict, task_ckpt_dir: str, build_task_all_args_except_model: dict, task: Task)

Save the task’s state, task.get_state(), and the initialization arguments into the given directory. If you override this method, it is highly possible you also need to override load_task.

Parameters:
  • save_task (dict) – same as in default_config, so the user can save additional settings, like the dataset configuration, by duplicating the dataset hyperparameters inside the save_task field. You can rely on the omegaconf package to simplify the duplication.

  • task_ckpt_dir (str) – save the task into this directory.

  • build_task_all_args_except_model (dict) – all the arguments of build_task except the model argument, since the model is saved separately by save_model. By saving this dictionary, you can easily reconstruct the same task by calling build_task with the saved dictionary.

  • task (Task) – the task to be saved.

Returns:

None

train(train: dict, train_dir: str, build_model_all_args: dict, build_task_all_args_except_model: dict, save_model: dict, save_task: dict, build_optimizer: dict, build_scheduler: dict, evaluate: dict, train_dataset, train_batch_sampler, train_collate_fn, valid_dataset, valid_batch_sampler, valid_collate_fn, num_workers: int, world_size: int, rank: int, eval_batch: int, device: str, global_config: Optional[dict] = None)
Parameters:
  • train (dict) –

    same as in default_config, with the following keys:

    key                 | description
    ------------------- | -----------
    total_steps         | (int) the total number of optimization steps
    log_step            | (int) logging frequency: log every log_step steps
    eval_step           | (int) evaluation frequency: evaluate every eval_step steps. To speed up development, you can limit how many batches are evaluated with the eval_batch argument of run
    save_step           | (int) save a checkpoint every save_step steps
    gradient_clipping   | (float) clip the gradient; important for RNNs
    gradient_accumulate | (int) accumulate the gradients of multiple steps before updating the network parameters, to simulate large-batch optimization
    valid_metric        | (str) the metric used to select the best validation checkpoint. Different Tasks support different valid_metrics; see build_task for the supported metrics
    valid_higher_better | (bool) whether a higher valid_metric is better (some metrics are higher-is-better, others lower-is-better); this determines which checkpoint is saved as the best validation checkpoint
    auto_resume         | (bool) if the last checkpoint already exists in target_dir (see run), whether to resume from it or to delete it and start a new training session
    resume_ckpt_dir     | (str) directly specify a checkpoint to resume from; it does not need to be inside target_dir (see run)
    seed                | (int) the random seed fixed before training starts
    keep_num_ckpts      | (int) to avoid saving too many checkpoints, keep only the keep_num_ckpts latest checkpoints and delete older ones
    use_scheduler       | (bool) whether to use the scheduler

  • **others – only meaningful when you want to override this train method, which is not the common case. Hence we skip the documentation for now.
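Several train keys interact in the training loop. The plain-Python sketch below uses hypothetical helpers to show when the periodic actions fire, when the optimizer steps under gradient_accumulate, how valid_higher_better decides the best checkpoint, and which checkpoints keep_num_ckpts would delete. It assumes 1-based optimization steps and that each action fires when the step is divisible by its interval; s3prl's exact step counting may differ.

```python
from typing import Dict, List


def plan_events(total_steps: int, log_step: int, eval_step: int, save_step: int) -> Dict[str, List[int]]:
    """List the optimization steps (1-based) at which each periodic action fires."""
    steps = range(1, total_steps + 1)
    return {
        "log": [s for s in steps if s % log_step == 0],
        "eval": [s for s in steps if s % eval_step == 0],
        "save": [s for s in steps if s % save_step == 0],
    }


def should_update(batch_index: int, gradient_accumulate: int) -> bool:
    """With a 0-based batch_index, step the optimizer after every `gradient_accumulate` batches."""
    return (batch_index + 1) % gradient_accumulate == 0


def is_better(metric: float, best: float, valid_higher_better: bool) -> bool:
    """Decide whether a new validation metric beats the best one seen so far."""
    return metric > best if valid_higher_better else metric < best


def ckpts_to_delete(saved_steps: List[int], keep_num_ckpts: int) -> List[int]:
    """Keep only the `keep_num_ckpts` latest checkpoints; return the steps to delete."""
    if keep_num_ckpts <= 0:
        raise ValueError("keep_num_ckpts must be positive")
    return sorted(saved_steps)[:-keep_num_ckpts]


if __name__ == "__main__":
    print(plan_events(total_steps=10, log_step=2, eval_step=5, save_step=10))
    print([i for i in range(6) if should_update(i, gradient_accumulate=3)])
    print(is_better(0.05, 0.08, valid_higher_better=False))  # e.g. an error rate
    print(ckpts_to_delete([100, 200, 300], keep_num_ckpts=2))
```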