base

(s3prl.problem.base)

The shared backbone of the common ML train/test procedure for all problems.

Authors:
  • Leo 2022

Problem

class s3prl.problem.base.Problem

Bases: object

classmethod get_class_from_name(name: str)
Parameters:

name (str) – the __name__ of the problem class

Returns:

Problem
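As a rough illustration of a name-based lookup, the sketch below resolves a subclass from its __name__ by walking the subclass tree. The classes BaseProblem, SpeechTask, and ToyASR and the helper find_subclass_by_name are hypothetical stand-ins; how s3prl actually performs the lookup is not stated on this page.

```python
from typing import Iterator


class BaseProblem:
    """Hypothetical stand-in for a problem base class."""


class SpeechTask(BaseProblem):
    """Hypothetical intermediate class."""


class ToyASR(SpeechTask):
    """Hypothetical concrete problem."""


def iter_subclasses(cls: type) -> Iterator[type]:
    """Yield every direct and indirect subclass of ``cls``."""
    for sub in cls.__subclasses__():
        yield sub
        yield from iter_subclasses(sub)


def find_subclass_by_name(base: type, name: str) -> type:
    """Return the subclass of ``base`` whose __name__ equals ``name``."""
    for sub in iter_subclasses(base):
        if sub.__name__ == name:
            return sub
    raise ValueError(f"No subclass of {base.__name__} is named {name!r}")


found = find_subclass_by_name(BaseProblem, "ToyASR")
```

The lookup only sees classes that have already been imported, so a real registry would need the defining modules to be loaded first.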

build_collate_fn(build_collate_fn: dict, mode: str)

By default, returns s3prl.dataset.base.default_collate_fn.

Parameters:
  • build_collate_fn (dict) – same as in default_config; no arguments are supported for now

  • mode (str) – train, valid, or test

Returns:

callable

the collate_fn for torch DataLoader in train/valid/test mode
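To make the role of a collate function concrete, here is a minimal sketch that merges a list of per-item dictionaries into one batch dictionary. The function collate_dicts and the keys "wav" and "label" are illustrative assumptions; this is not the behavior of s3prl.dataset.base.default_collate_fn, which is not described on this page.

```python
from typing import Any, Dict, List


def collate_dicts(samples: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Merge a list of per-item dicts into one dict of lists."""
    if not samples:
        return {}
    keys = samples[0].keys()
    batch: Dict[str, List[Any]] = {key: [] for key in keys}
    for sample in samples:
        # Every item must expose the same fields, otherwise the batch is ragged.
        if sample.keys() != keys:
            raise KeyError("All samples must share the same keys")
        for key, value in sample.items():
            batch[key].append(value)
    return batch


# Hypothetical items, as a dataset's __getitem__ might return them.
batch = collate_dicts([
    {"wav": [0.1, 0.2], "label": 3},
    {"wav": [0.3], "label": 7},
])
```

A callable of this shape is what the collate_fn argument of a torch DataLoader expects: it receives the list of items for one batch and returns the batched structure.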

build_upstream(build_upstream: dict)

By default, builds the upstream with s3prl.nn.upstream.S3PRLUpstream.

Parameters:

build_upstream (dict) – same as in default_config; arguments for s3prl.nn.upstream.S3PRLUpstream

Returns:

s3prl.nn.interface.AbsUpstream

Returns an upstream model whose forward takes the waveform input and returns multiple hidden states as features.

build_featurizer(build_featurizer: dict, upstream)

By default, builds the featurizer with s3prl.nn.Featurizer.

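Parameters:
  • build_featurizer (dict) – same as in default_config; arguments for s3prl.nn.Featurizer

  • upstream – the upstream model built by build_upstream

Returns: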

s3prl.nn.interface.AbsFeaturizer

Returns the featurizer model. The featurizer reduces the multiple hidden states returned by the upstream model (built by build_upstream) into a single hidden state, so that it can be easily fed into the downstream model.
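One common way to reduce several layers into one is a softmax-weighted sum over layers. The sketch below shows that reduction on plain Python lists; it assumes this weighting scheme for illustration only, and the names softmax and weighted_layer_sum are hypothetical rather than part of s3prl.nn.Featurizer.

```python
import math
from typing import List, Sequence

Frames = List[List[float]]  # [num_frames][hidden_size]


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a short list of logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def weighted_layer_sum(hidden_states: List[Frames], layer_logits: Sequence[float]) -> Frames:
    """Reduce [num_layers][num_frames][hidden_size] to [num_frames][hidden_size]."""
    if len(hidden_states) != len(layer_logits):
        raise ValueError("Need exactly one logit per layer")
    weights = softmax(layer_logits)
    num_frames = len(hidden_states[0])
    hidden_size = len(hidden_states[0][0])
    fused = [[0.0] * hidden_size for _ in range(num_frames)]
    for weight, layer in zip(weights, hidden_states):
        for t in range(num_frames):
            for d in range(hidden_size):
                fused[t][d] += weight * layer[t][d]
    return fused


# Two layers, one frame, hidden size 2; equal logits give a plain average.
layers = [[[1.0, 2.0]], [[3.0, 4.0]]]
fused = weighted_layer_sum(layers, layer_logits=[0.0, 0.0])
```

In a trainable model the per-layer logits would be learned parameters, so the downstream task can decide which layers matter most.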

build_model(build_model: dict, model_output_size: int, build_upstream: dict, build_featurizer: dict, build_downstream: dict)

By default, builds the model with s3prl.nn.upstream.UpstreamDownstreamModel.

Parameters:
  • build_model (dict) – same as in default_config; arguments for s3prl.nn.upstream.UpstreamDownstreamModel

  • model_output_size (int) – the required output hidden size of the model

  • build_upstream (dict) – same as in default_config; refer to build_upstream

  • build_featurizer (dict) – same as in default_config; refer to build_featurizer

  • build_downstream (dict) – same as in default_config; refer to build_downstream

Returns:

torch.nn.Module

Returns the entire model for the task, which takes the items from the DataLoader directly as input. Usually, the components are built by build_upstream, build_featurizer, and build_downstream, and are concatenated to form the final model. The upstream extracts multiple hidden states, the featurizer reduces them into a single hidden state, and the downstream takes that hidden state as the feature for the downstream-specific model.
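The three-stage composition described above can be sketched with plain callables. ToyPipeline and the three toy stages are hypothetical stand-ins meant only to show the data flow; they do not reproduce s3prl.nn.upstream.UpstreamDownstreamModel.

```python
from typing import Callable, List

Frames = List[List[float]]  # [num_frames][hidden_size]


class ToyPipeline:
    """Chain upstream -> featurizer -> downstream (illustrative only)."""

    def __init__(
        self,
        upstream: Callable[[List[float]], List[Frames]],
        featurizer: Callable[[List[Frames]], Frames],
        downstream: Callable[[Frames], List[float]],
    ) -> None:
        self.upstream = upstream
        self.featurizer = featurizer
        self.downstream = downstream

    def __call__(self, wav: List[float]) -> List[float]:
        hidden_states = self.upstream(wav)        # several layers
        feature = self.featurizer(hidden_states)  # one fused layer
        return self.downstream(feature)           # task-specific output


def toy_upstream(wav: List[float]) -> List[Frames]:
    # Two fake "layers": the raw samples and the samples doubled.
    return [[[x] for x in wav], [[2.0 * x] for x in wav]]


def toy_featurizer(hidden_states: List[Frames]) -> Frames:
    # Average the layers frame by frame.
    num_layers = len(hidden_states)
    return [
        [sum(layer[t][0] for layer in hidden_states) / num_layers]
        for t in range(len(hidden_states[0]))
    ]


def toy_downstream(frames: Frames) -> List[float]:
    # Mean-pool over time into an output of size 1.
    return [sum(frame[0] for frame in frames) / len(frames)]


model = ToyPipeline(toy_upstream, toy_featurizer, toy_downstream)
output = model([1.0, 3.0])
```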

build_optimizer(build_optimizer: dict, parameters)
Parameters:
  • build_optimizer (dict) –

    same as in default_config, with the keys below

    key     description
    name    (str) - the optimizer class name in torch.optim
    conf    (dict) - the arguments for initializing the optimizer class, e.g. {"lr": 1.0e-4}

  • parameters (iterable) – the standard params accepted by torch.optim.Optimizer.

Returns:

torch.optim.Optimizer

An optimizer following standard torch usage
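A name/conf pair like the one above is typically resolved by looking the class up on a module and passing conf as keyword arguments. The sketch below shows that pattern with a stand-in namespace so it runs without PyTorch; build_from_name, ToySGD, and toy_optim are hypothetical, and whether s3prl resolves the name exactly this way is an assumption.

```python
from types import SimpleNamespace
from typing import Any, Dict


def build_from_name(namespace: Any, name: str, conf: Dict[str, Any], *args: Any) -> Any:
    """Look up ``name`` on ``namespace`` and call it with ``*args, **conf``."""
    cls = getattr(namespace, name, None)
    if cls is None:
        raise ValueError(f"{name!r} was not found in the given namespace")
    return cls(*args, **conf)


class ToySGD:
    """Stand-in optimizer so the example has no third-party dependency."""

    def __init__(self, params, lr: float = 0.01) -> None:
        self.params = list(params)
        self.lr = lr


# Stand-in for a module such as torch.optim.
toy_optim = SimpleNamespace(ToySGD=ToySGD)

config = {"name": "ToySGD", "conf": {"lr": 1.0e-4}}
optimizer = build_from_name(toy_optim, config["name"], config["conf"], [0.0, 1.0])
```

With PyTorch installed, passing torch.optim as the namespace and model.parameters() as the positional argument would resolve a name such as "Adam" in the same way.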

build_scheduler(build_scheduler: dict, optimizer)
Parameters:
  • build_scheduler (dict) –

    same as in default_config, with the keys below

    key     description
    name    (str) - the scheduler class name in torch.optim.lr_scheduler
    conf    (dict) - the arguments for initializing the scheduler class, e.g. {"gamma": 0.01} for torch.optim.lr_scheduler.StepLR

  • optimizer – the standard torch optimizer accepted by Scheduler in torch.optim.lr_scheduler.

Returns:

torch scheduler

A scheduler following standard torch usage
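To show what a gamma value in conf does, the sketch below reproduces the step-decay rule used by torch.optim.lr_scheduler.StepLR in plain Python: the learning rate is multiplied by gamma once every step_size scheduler steps. Note that StepLR also takes a step_size argument, so a conf for it needs that entry as well. The helper step_lr_schedule and the numbers are illustrative assumptions.

```python
from typing import List


def step_lr_schedule(base_lr: float, gamma: float, step_size: int, num_steps: int) -> List[float]:
    """Learning rate at each scheduler step under step decay."""
    if step_size < 1:
        raise ValueError("step_size must be a positive integer")
    return [base_lr * gamma ** (step // step_size) for step in range(num_steps)]


# Decay by a factor of 10 every 2 scheduler steps.
schedule = step_lr_schedule(base_lr=1.0, gamma=0.1, step_size=2, num_steps=6)
```

The resulting schedule stays flat for step_size steps, then drops by the factor gamma, and repeats.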

train(train: dict, train_dir: str, build_model_all_args: dict, build_task_all_args_except_model: dict, save_model: dict, save_task: dict, build_optimizer: dict, build_scheduler: dict, evaluate: dict, train_dataset, train_batch_sampler, train_collate_fn, valid_dataset, valid_batch_sampler, valid_collate_fn, num_workers: int, world_size: int, rank: int, eval_batch: int, device: str, global_config: Optional[dict] = None)
Parameters:
  • train (dict) –

    same as in default_config, with the keys below

    key                  description
    total_steps          (int) - the total number of optimization steps
    log_step             (int) - logging frequency; log every log_step steps
    eval_step            (int) - evaluation frequency; evaluate every eval_step steps. To speed up development, you can limit how many batches are evaluated with the eval_batch argument in run
    save_step            (int) - save a checkpoint every save_step steps
    gradient_clipping    (float) - clip the gradient; important for RNNs
    gradient_accumulate  (int) - accumulate the gradients of multiple steps before updating the network parameters, to simulate large-batch optimization
    valid_metric         (str) - the metric used to select the best valid checkpoint. Different Tasks support different valid_metrics; see build_task for the supported metrics
    valid_higher_better  (bool) - whether a higher valid_metric is better. Some metrics are better when higher and others when lower; this determines which validation checkpoint is saved as the best
    auto_resume          (bool) - if a last checkpoint already exists in target_dir (see run), whether to resume from it, or to delete it and start a new training session
    resume_ckpt_dir      (str) - a checkpoint path to resume from, specified directly; it does not need to be inside target_dir (see run)
    seed                 (int) - fix the random seed before training starts
    keep_num_ckpts       (int) - to avoid saving too many checkpoints, keep only the keep_num_ckpts latest checkpoints and delete the older ones
    use_scheduler        (bool) - whether to use the scheduler

  • **others – only meaningful when you want to override this train method, which is not the common case. Hence we skip the documentation for now.
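The bookkeeping implied by several of these keys can be sketched without any training framework. The helpers below show which periodic actions fire at a given step, how valid_higher_better decides whether a new score is the best so far, and how keep_num_ckpts prunes old checkpoints. The function names and the toy configuration are hypothetical, and the exact semantics inside train (for example, whether steps are counted from 1) are assumptions.

```python
from typing import Dict, List, Optional


def events_at(step: int, train_conf: Dict[str, int]) -> List[str]:
    """Periodic actions that fire after optimization step ``step`` (1-based)."""
    events = []
    for key, event in (("log_step", "log"), ("eval_step", "evaluate"), ("save_step", "save")):
        if step % train_conf[key] == 0:
            events.append(event)
    return events


def is_better(new: float, best: Optional[float], higher_better: bool) -> bool:
    """Compare a new validation score against the best one seen so far."""
    if best is None:
        return True  # the first evaluation is always the best so far
    return new > best if higher_better else new < best


def prune_checkpoints(ckpt_steps: List[int], keep_num_ckpts: int) -> List[int]:
    """Return the steps whose checkpoints survive, keeping only the latest ones."""
    if keep_num_ckpts < 1:
        raise ValueError("keep_num_ckpts must be at least 1")
    return sorted(ckpt_steps)[-keep_num_ckpts:]


# Toy configuration: log every 2 steps, evaluate and save every 3 steps.
conf = {"total_steps": 6, "log_step": 2, "eval_step": 3, "save_step": 3}
timeline = {step: events_at(step, conf) for step in range(1, conf["total_steps"] + 1)}

kept = prune_checkpoints([100, 200, 300, 400], keep_num_ckpts=2)
accuracy_improved = is_better(0.91, 0.88, higher_better=True)
error_rate_improved = is_better(0.12, 0.10, higher_better=False)
```

Metrics such as accuracy would use valid_higher_better set to true, while error rates would use false.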

evaluate(evaluate: dict, mode: str, task, dataset, batch_sampler, collate_fn, eval_batch: int, dump_dir: str, device: str, num_workers: int)

The evaluate routine used by train (during validation phase) and run (during testing phase).

Parameters:
  • evaluate (dict) – same as in default_config; no arguments are supported for now

  • **others – only meaningful when you want to override this evaluate method, which is not the common case. Hence we skip the documentation for now.

save_model(save_model: dict, model_ckpt_dir: str, build_model_all_args: dict, model: Module)

Save the model state_dict and the model initialization arguments into the given directory. If you override this method, you will very likely also need to override load_model.

Parameters:
  • save_model (dict) – same as in default_config. This lets the user save additional settings, such as the dataset configuration, by duplicating the dataset hyperparameters inside the save_model field. You can rely on the omegaconf package to simplify the duplication.

  • model_ckpt_dir (str) – save the model into this directory.

  • build_model_all_args (dict) – all the arguments of build_model. By saving this dictionary, you can easily reconstruct the same model by calling build_model with the saved dictionary.

  • model (torch.nn.Module) – the model to be saved.

Returns:

None

load_model(model_ckpt_dir: str)

Return the saved model.

Parameters:

model_ckpt_dir (str) – Restore the model with build_model and the checkpoint saved in this directory.

Returns:

torch.nn.Module
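The idea behind the save_model and load_model pair, storing the build arguments next to the state so the same model can be rebuilt and then restored, can be sketched with the standard library alone. ToyModel, the two helper functions, and the file names build_args.json and state.json are hypothetical; the actual on-disk layout used by s3prl is not described on this page.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict


class ToyModel:
    """Stand-in model with a state_dict-like interface."""

    def __init__(self, hidden_size: int, output_size: int) -> None:
        self.weights = [[0.0] * hidden_size for _ in range(output_size)]

    def state_dict(self) -> Dict[str, Any]:
        return {"weights": self.weights}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.weights = state["weights"]


def save_toy_model(ckpt_dir: Path, build_args: Dict[str, Any], model: ToyModel) -> None:
    """Write the build arguments and the state into ``ckpt_dir``."""
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    (ckpt_dir / "build_args.json").write_text(json.dumps(build_args))
    (ckpt_dir / "state.json").write_text(json.dumps(model.state_dict()))


def load_toy_model(ckpt_dir: Path) -> ToyModel:
    """Rebuild the model from the saved arguments, then restore its state."""
    build_args = json.loads((ckpt_dir / "build_args.json").read_text())
    model = ToyModel(**build_args)
    model.load_state_dict(json.loads((ckpt_dir / "state.json").read_text()))
    return model


build_args = {"hidden_size": 2, "output_size": 1}
original = ToyModel(**build_args)
original.weights[0][1] = 0.5  # pretend that training changed a weight

with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = Path(tmp) / "model"
    save_toy_model(ckpt_dir, build_args, original)
    restored = load_toy_model(ckpt_dir)
```

Because the build arguments travel with the checkpoint, the loader needs only the directory path, which matches the single-argument signature of load_model.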

save_task(save_task: dict, task_ckpt_dir: str, build_task_all_args_except_model: dict, task: Task)

Save the task’s state, task.get_state(), and the initialization arguments into the given directory. If you override this method, you will very likely also need to override load_task.

Parameters:
  • save_task (dict) – same as in default_config. This lets the user save additional settings, such as the dataset configuration, by duplicating the dataset hyperparameters inside the save_task field. You can rely on the omegaconf package to simplify the duplication.

  • task_ckpt_dir (str) – save the task into this directory.

  • build_task_all_args_except_model (dict) – all the arguments of build_task except the model argument, since the model should be saved separately by save_model. By saving this dictionary, you can easily reconstruct the same task by calling build_task with the saved dictionary.

  • task (Task) – the task to be saved.

Returns:

None

load_task(task_ckpt_dir: str, model: Module, task_overrides: Optional[dict] = None)

Return the saved task.

Parameters:
  • task_ckpt_dir (str) – Restore the task with build_task and the checkpoint saved in this directory.

  • model (torch.nn.Module) – the model for the task, since the model is separately saved and is required for build_task.

  • task_overrides (dict) – overrides for the saved initialization arguments, which change the loaded task’s behavior, for example the decoding hyperparameters.

Returns:

s3prl.task.Task
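A simple way to picture task_overrides is a shallow merge of the override dictionary onto the saved initialization arguments before the task is rebuilt. The helper apply_overrides and the keys beam_size and lm_weight are hypothetical; whether s3prl merges shallowly or recursively is an assumption here.

```python
from typing import Any, Dict, Optional


def apply_overrides(saved_args: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Return a copy of ``saved_args`` with ``overrides`` applied on top."""
    merged = dict(saved_args)  # copy, so the saved arguments stay untouched
    merged.update(overrides or {})
    return merged


# Hypothetical decoding hyperparameters stored with the task.
saved = {"beam_size": 1, "lm_weight": 0.0}
merged = apply_overrides(saved, {"beam_size": 8})
unchanged = apply_overrides(saved)
```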

load_model_and_task(ckpts_dir: str, task_overrides: Optional[dict] = None)

This is a helper method that combines load_model and load_task to load the model and the task directly. It assumes the model is saved under ckpts_dir / 'model' and the task is saved under ckpts_dir / 'task'.

Returns:

tuple

  1. model (torch.nn.Module)

  2. task (s3prl.task.Task)
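The directory convention above can be sketched as a small wrapper that derives the two sub-directories and calls the loaders in order, model first, since the task needs the model. The function load_both and the fake loaders are hypothetical stand-ins used to keep the example self-contained.

```python
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple


def load_both(
    ckpts_dir: str,
    load_model: Callable[[str], Any],
    load_task: Callable[..., Any],
    task_overrides: Optional[Dict[str, Any]] = None,
) -> Tuple[Any, Any]:
    """Load the model from ckpts_dir/model, then the task from ckpts_dir/task."""
    root = Path(ckpts_dir)
    model = load_model(str(root / "model"))
    task = load_task(str(root / "task"), model, task_overrides)
    return model, task


def fake_load_model(model_ckpt_dir: str) -> Dict[str, Any]:
    # Stand-in loader that only records where it was asked to look.
    return {"kind": "model", "dir": Path(model_ckpt_dir)}


def fake_load_task(task_ckpt_dir: str, model: Any, task_overrides: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    return {
        "kind": "task",
        "dir": Path(task_ckpt_dir),
        "model": model,
        "overrides": task_overrides or {},
    }


model, task = load_both("exp/run1", fake_load_model, fake_load_task, {"beam_size": 8})
```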

main(args: Optional[List[str]] = None)