base#
(s3prl.problem.base)
The shared backbone of common ML train/test procedure for all problems
- Authors:
Leo 2022
Problem#
- class s3prl.problem.base.Problem[source][source]#
Bases:
object

- classmethod get_class_from_name(name: str)[source][source]#
- Parameters:
name (str) – the __name__ of the problem class
- Returns:
Problem
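A minimal usage sketch, assuming a concrete Problem subclass is available (the name "SuperbASR" below is illustrative only):

```python
from s3prl.problem.base import Problem

# "SuperbASR" is an illustrative subclass name; substitute any Problem
# subclass defined in your installation of s3prl.
problem_cls = Problem.get_class_from_name("SuperbASR")
problem = problem_cls()
```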
- build_collate_fn(build_collate_fn: dict, mode: str)[source][source]#
By default returns s3prl.dataset.base.default_collate_fn
- Parameters:
build_collate_fn (dict) – same in default_config, no argument supported for now
mode (str) – train, valid, or test
- Returns:
callable
the collate_fn for torch DataLoader in train/valid/test mode
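A hedged sketch of how the returned callable is typically plugged into a torch DataLoader; "problem", "train_dataset", and "train_batch_sampler" are assumed to have been built elsewhere by the problem's other build steps:

```python
from torch.utils.data import DataLoader

# Build the collate_fn for the training split; the empty dict mirrors the
# "no argument supported for now" note above.
collate_fn = problem.build_collate_fn({}, mode="train")
loader = DataLoader(
    train_dataset,
    batch_sampler=train_batch_sampler,
    collate_fn=collate_fn,
    num_workers=4,
)
```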
- build_upstream(build_upstream: dict)[source][source]#
By default build the upstream with s3prl.nn.upstream.S3PRLUpstream
- Parameters:
build_upstream (dict) – same in default_config, arguments for s3prl.nn.upstream.S3PRLUpstream
- Returns:
s3prl.nn.interface.AbsUpstream
Return an upstream model, whose forward takes the waveform input and returns multiple hidden states as features.
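A hedged example of the build_upstream subtree of default_config; the keys are forwarded as arguments to s3prl.nn.upstream.S3PRLUpstream, and "hubert" is just an illustrative upstream name:

```python
# Forwarded as keyword arguments to s3prl.nn.upstream.S3PRLUpstream.
build_upstream = {
    "name": "hubert",  # illustrative; any upstream name your s3prl installation supports
}
upstream = problem.build_upstream(build_upstream)
```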
- build_featurizer(build_featurizer: dict, upstream)[source][source]#
By default build the featurizer with s3prl.nn.Featurizer
- Parameters:
build_featurizer (dict) – same in default_config, arguments for s3prl.nn.Featurizer
upstream (AbsUpstream) – the upstream model built by build_upstream
- Returns:
s3prl.nn.interface.AbsFeaturizer
Return the featurizer model. The featurizer is used to reduce the multiple hidden states returned from the upstream model (built by build_upstream) into a single hidden state, so it can be easily fed into the downstream model.
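A rough sketch of how the upstream and featurizer compose, assuming the default S3PRLUpstream/Featurizer interfaces in which the upstream returns every layer's hidden states and the featurizer reduces them (typically by a learned weighted sum) into a single feature sequence:

```python
import torch

upstream = problem.build_upstream(build_upstream)
featurizer = problem.build_featurizer(build_featurizer, upstream)

wavs = torch.randn(2, 16000)                    # (batch, samples), illustrative input
wavs_len = torch.LongTensor([16000, 12000])

all_hs, all_hs_len = upstream(wavs, wavs_len)   # multiple hidden states, one per layer
hs, hs_len = featurizer(all_hs, all_hs_len)     # reduced to a single hidden state
```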
- build_model(build_model: dict, model_output_size: int, build_upstream: dict, build_featurizer: dict, build_downstream: dict)[source][source]#
By default build the model with s3prl.nn.upstream.UpstreamDownstreamModel
- Parameters:
build_model (dict) – same in default_config, arguments for s3prl.nn.upstream.UpstreamDownstreamModel
model_output_size (int) – the required model’s output hidden size
build_upstream (dict) – same in default_config, refer to build_upstream
build_featurizer (dict) – same in default_config, refer to build_featurizer
build_downstream (dict) – same in default_config, refer to build_downstream
- Returns:
torch.nn.Module
Return the entire model for the task, which takes the direct items from the DataLoader as input. Usually, the components are built by build_upstream, build_featurizer, and build_downstream, and are concatenated together to form the final model. The upstream extracts multiple hidden states, the featurizer reduces them into a single hidden state, and the downstream takes that single hidden state as the feature for the downstream-specific model.
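A hedged sketch of the composition that s3prl.nn.upstream.UpstreamDownstreamModel expresses; the class below is illustrative only and is meant to show the upstream → featurizer → downstream data flow, not the exact s3prl implementation:

```python
import torch.nn as nn

class IllustrativeUpstreamDownstreamModel(nn.Module):
    """Illustrative only: composes the three components built by the Problem."""

    def __init__(self, upstream, featurizer, downstream):
        super().__init__()
        self.upstream = upstream
        self.featurizer = featurizer
        self.downstream = downstream

    def forward(self, wavs, wavs_len, *args, **kwargs):
        all_hs, all_hs_len = self.upstream(wavs, wavs_len)   # many hidden states
        hs, hs_len = self.featurizer(all_hs, all_hs_len)     # reduced to one
        return self.downstream(hs, hs_len, *args, **kwargs)  # task-specific head
```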
- build_optimizer(build_optimizer: dict, parameters)[source][source]#
- Parameters:
build_optimizer (dict) –
same in default_config, refer to below

| key | description |
|-----|-------------|
| name | (str) - the optimizer class name in torch.optim |
| conf | (dict) - the arguments for initializing the optimizer class, e.g. {"lr": 1.0e-4} |

parameters (iterable) – the standard params accepted by torch.optim.Optimizer
- Returns:
torch.optim.Optimizer
An optimizer following standard torch usage
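For example, a build_optimizer subtree like the following (written as a plain dict) resolves to torch.optim.Adam with a 1.0e-4 learning rate; the resolution code shown is only a sketch of the equivalent behavior, not s3prl's internal implementation:

```python
import torch

build_optimizer = {
    "name": "Adam",           # any optimizer class name in torch.optim
    "conf": {"lr": 1.0e-4},   # kwargs for that optimizer class
}

# Roughly what the default build_optimizer does, where "model" is the
# torch.nn.Module returned by build_model:
optimizer_cls = getattr(torch.optim, build_optimizer["name"])
optimizer = optimizer_cls(model.parameters(), **build_optimizer["conf"])
```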
- build_scheduler(build_scheduler: dict, optimizer)[source][source]#
- Parameters:
build_scheduler (dict) –
same in default_config

| key | description |
|-----|-------------|
| name | (str) - the scheduler class name in torch.optim.lr_scheduler |
| conf | (dict) - the arguments for initializing the scheduler class, e.g. {"gamma": 0.01} for torch.optim.lr_scheduler.StepLR |

optimizer – the standard torch optimizer accepted by the schedulers in torch.optim.lr_scheduler
- Returns:
torch scheduler
A scheduler following standard torch usage
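A similar hedged example for build_scheduler; note that torch.optim.lr_scheduler.StepLR also requires step_size, so a real config would supply it alongside gamma:

```python
import torch

build_scheduler = {
    "name": "StepLR",                              # any class in torch.optim.lr_scheduler
    "conf": {"step_size": 10000, "gamma": 0.01},   # kwargs for that scheduler class
}

# Roughly what the default build_scheduler does, given an already-built optimizer:
scheduler_cls = getattr(torch.optim.lr_scheduler, build_scheduler["name"])
scheduler = scheduler_cls(optimizer, **build_scheduler["conf"])
```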
- train(train: dict, train_dir: str, build_model_all_args: dict, build_task_all_args_except_model: dict, save_model: dict, save_task: dict, build_optimizer: dict, build_scheduler: dict, evaluate: dict, train_dataset, train_batch_sampler, train_collate_fn, valid_dataset, valid_batch_sampler, valid_collate_fn, num_workers: int, world_size: int, rank: int, eval_batch: int, device: str, global_config: dict = None)[source][source]#
- Parameters:
train (dict) –
same in default_config

| key | description |
|-----|-------------|
| total_steps | (int) - the total optimization steps |
| log_step | (int) - logging frequency. Log every log_step steps |
| eval_step | (int) - evaluation frequency. Evaluate every eval_step steps. Note that you can control how many batches to evaluate, to speed up development, via the eval_batch argument in run |
| save_step | (int) - save the checkpoint every save_step steps |
| gradient_clipping | (float) - clip the gradient. Important for RNNs |
| gradient_accumulate | (int) - accumulate multiple steps' gradients before updating the network parameters to simulate large-batch optimization |
| valid_metric | (str) - the metric used to select the best valid checkpoint. Different Tasks support different valid_metrics; see build_task for the supported metrics |
| valid_higher_better | (bool) - some metrics are higher-better while others are lower-better; this affects how the best validation checkpoint is saved |
| auto_resume | (bool) - if the last checkpoint already exists in target_dir (see run), whether to resume from it or delete it and start a new training session |
| resume_ckpt_dir | (str) - directly specify the checkpoint path to resume from, which does not need to be inside target_dir (see run) |
| seed | (int) - fix the random seed before training starts |
| keep_num_ckpts | (int) - to avoid saving too many checkpoints, keep only the keep_num_ckpts latest checkpoints and delete the older ones |
| use_scheduler | (bool) - whether to use the scheduler |

**others – only meaningful when you want to override this train method, which is not the common case. Hence we skip the documentation for now.
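A hedged example of the train subtree of default_config using the keys documented above; all values are illustrative and should be tuned per problem:

```python
train = {
    "total_steps": 200000,
    "log_step": 100,
    "eval_step": 2000,
    "save_step": 500,
    "gradient_clipping": 1.0,
    "gradient_accumulate": 1,
    "valid_metric": "accuracy",    # must be a metric supported by the Task
    "valid_higher_better": True,
    "auto_resume": True,
    "resume_ckpt_dir": None,
    "seed": 1337,
    "keep_num_ckpts": 2,
    "use_scheduler": False,
}
```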
- evaluate(evaluate: dict, mode: str, task, dataset, batch_sampler, collate_fn, eval_batch: int, dump_dir: str, device: str, num_workers: int)[source][source]#
The evaluate routine used by train (during the validation phase) and run (during the testing phase).
- Parameters:
evaluate (dict) – same in default_config, no argument supported for now
**others – only meaningful when you want to override this evaluate method, which is not the common case. Hence we skip the documentation for now.
- save_model(save_model: dict, model_ckpt_dir: str, build_model_all_args: dict, model: Module)[source][source]#
Save the model state_dict and the model initialization arguments into the given directory. If you override this method, it is highly possible you also need to override load_model.
- Parameters:
save_model (dict) – same in default_config, so the user can save additional settings, like the configuration of the dataset, by duplicating the dataset hypers inside the save_model field. You can rely on the omegaconf package to simplify the duplication.
model_ckpt_dir (str) – save the model into this directory.
build_model_all_args (dict) – all the arguments of build_model. By saving this dictionary, you can easily reconstruct the same model by calling build_model with the saved dictionary.
model (torch.nn.Module) – the model to be saved.
- Returns:
None
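A hedged sketch of a matching save/load pair, assuming the simplest possible layout where the state_dict and build_model arguments go into a single file via torch.save; the file name is illustrative, not necessarily what s3prl writes:

```python
from pathlib import Path
import torch

def save_model_sketch(model_ckpt_dir: str, build_model_all_args: dict, model: torch.nn.Module):
    ckpt_dir = Path(model_ckpt_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    torch.save(
        {
            "build_model_all_args": build_model_all_args,  # enough to rebuild the model
            "state_dict": model.state_dict(),
        },
        ckpt_dir / "model.ckpt",  # illustrative file name
    )

def load_model_sketch(problem, model_ckpt_dir: str) -> torch.nn.Module:
    ckpt = torch.load(Path(model_ckpt_dir) / "model.ckpt", map_location="cpu")
    model = problem.build_model(**ckpt["build_model_all_args"])
    model.load_state_dict(ckpt["state_dict"])
    return model
```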
- load_model(model_ckpt_dir: str)[source][source]#
Return the saved model.
- Parameters:
model_ckpt_dir (str) – Restore the model with build_model and the checkpoint saved in this directory.
- Returns:
torch.nn.Module
- save_task(save_task: dict, task_ckpt_dir: str, build_task_all_args_except_model: dict, task: Task)[source][source]#
Save the task’s state, task.get_state(), and the initialization arguments into the given directory. If you override this method, it is highly possible you also need to override load_task.
- Parameters:
save_task (dict) – same in default_config, so the user can save additional settings, like the configuration of the dataset, by duplicating the dataset hypers inside the save_task field. You can rely on the omegaconf package to simplify the duplication.
task_ckpt_dir (str) – save the task into this directory.
build_task_all_args_except_model (dict) – all the arguments of build_task except the model argument, since the model should be separately saved by save_model. By saving this dictionary, you can easily reconstruct the same task by calling build_task with the saved dictionary.
task (Task) – the task to be saved.
- Returns:
None
- load_task(task_ckpt_dir: str, model: Module, task_overrides: dict = None)[source][source]#
Return the saved task.
- Parameters:
task_ckpt_dir (str) – Restore the task with build_task and the checkpoint saved in this directory.
model (torch.nn.Module) – the model for the task, since the model is separately saved and is required for build_task.
task_overrides (dict) – overrides the saved initialization arguments, so you can change the loaded task’s behavior, e.g. change the decoding hyperparameters.
- Returns:
Task
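A hedged usage sketch for task_overrides; the "decode_beam_size" key is hypothetical and only illustrates overriding one of the Task's saved initialization arguments at load time:

```python
# "decode_beam_size" is a hypothetical Task argument used purely for illustration.
task = problem.load_task(
    task_ckpt_dir,
    model,
    task_overrides={"decode_beam_size": 20},
)
```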
- load_model_and_task(ckpts_dir: str, task_overrides: dict = None)[source][source]#
This is a helper method that combines load_model and load_task to directly load the model and the task. This method assumes the model is saved under ckpts_dir / 'model' and the task is saved under ckpts_dir / 'task'.
- Returns:
tuple
model (torch.nn.Module)
task (s3prl.task.Task)
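A hedged usage sketch, assuming ckpts_dir points at a checkpoint directory produced by a previous training run (the path below is illustrative):

```python
# "problem" is an instance of a concrete Problem subclass; the path is illustrative.
model, task = problem.load_model_and_task("exp/my_run/checkpoints/best")
model.eval()
```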