base#
(s3prl.problem.base)
The shared backbone of the common ML train/test procedure for all problems
- Authors:
Leo 2022
Problem#
- class s3prl.problem.base.Problem[source]#
Bases:
object
- classmethod get_class_from_name(name: str)[source]#
- Parameters:
name (str) – the __name__ of the problem class
- Returns:
Problem
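For example, the class lookup can be used to retrieve a bundled problem by name. A minimal sketch, assuming the SuperbASR problem class is among the registered Problem subclasses:

```python
# A minimal sketch: look up a Problem subclass by its __name__.
# "SuperbASR" is assumed to be one of the registered problem classes.
from s3prl.problem.base import Problem

problem_cls = Problem.get_class_from_name("SuperbASR")
print(problem_cls.__name__)  # SuperbASR
```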
- build_collate_fn(build_collate_fn: dict, mode: str)[source]#
By default returns s3prl.dataset.base.default_collate_fn
- Parameters:
build_collate_fn (dict) – same in default_config, no arguments supported for now
mode (str) – train, valid, or test
- Returns:
callable
The collate_fn for the torch DataLoader in train/valid/test mode
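If the default collate behavior does not fit, a subclass can override this method and return its own callable. A minimal sketch; the subclass and the collate function below are hypothetical:

```python
# Hypothetical override of build_collate_fn in a Problem subclass.
from s3prl.problem.base import Problem

class MyProblem(Problem):  # hypothetical subclass, for illustration only
    def build_collate_fn(self, build_collate_fn: dict, mode: str):
        def my_collate_fn(samples):
            # naive example: return the raw list of samples unchanged;
            # a real collate_fn would pad and stack tensors into a batch
            return samples
        return my_collate_fn
```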
- build_upstream(build_upstream: dict)[source]#
By default builds the upstream with s3prl.nn.upstream.S3PRLUpstream
- Parameters:
build_upstream (dict) – same in default_config, arguments for s3prl.nn.upstream.S3PRLUpstream
- Returns:
s3prl.nn.interface.AbsUpstream
Return an upstream model whose forward takes the waveform input and returns multiple hidden states as features.
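A sketch of a build_upstream config as it might appear in default_config; the dict holds the arguments for s3prl.nn.upstream.S3PRLUpstream, and "wav2vec2" is only an illustrative choice of upstream:

```python
# Illustrative build_upstream config; the dict holds the arguments for
# s3prl.nn.upstream.S3PRLUpstream ("wav2vec2" is just an example name).
build_upstream = {
    "name": "wav2vec2",
}
```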
- build_featurizer(build_featurizer: dict, upstream)[source]#
By default builds the featurizer with s3prl.nn.Featurizer
- Parameters:
build_featurizer (dict) – same in default_config, arguments for s3prl.nn.Featurizer
upstream (AbsUpstream) – the upstream model built by build_upstream
- Returns:
s3prl.nn.interface.AbsFeaturizer
Return the featurizer model. The featurizer reduces the multiple hidden states returned by the upstream model (built by build_upstream) into a single hidden state, so it can be easily fed into the downstream model.
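The upstream and featurizer builders compose naturally. A sketch, assuming a concrete problem instance and the default builders; SuperbASR and the "wav2vec2" upstream name are illustrative choices:

```python
# Sketch: build the upstream first, then the featurizer on top of it.
from s3prl.problem import SuperbASR

problem = SuperbASR()  # any concrete Problem subclass works the same way
upstream = problem.build_upstream({"name": "wav2vec2"})
featurizer = problem.build_featurizer({}, upstream)  # default Featurizer arguments
```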
- build_model(build_model: dict, model_output_size: int, build_upstream: dict, build_featurizer: dict, build_downstream: dict)[source]#
By default builds the model with s3prl.nn.upstream.UpstreamDownstreamModel
- Parameters:
build_model (dict) – same in default_config, arguments for s3prl.nn.upstream.UpstreamDownstreamModel
model_output_size (int) – the required output hidden size of the model
build_upstream (dict) – same in default_config, refer to build_upstream
build_featurizer (dict) – same in default_config, refer to build_featurizer
build_downstream (dict) – same in default_config, refer to build_downstream
- Returns:
torch.nn.Module
Return the entire model for the task, which takes the items from the DataLoader directly as input. Usually, the components are built by build_upstream, build_featurizer, and build_downstream, and are concatenated together to form the final model. The upstream extracts multiple hidden states, the featurizer reduces them into a single hidden state, and the downstream takes that hidden state as the feature for the downstream-specific model.
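A sketch of building the entire model from the builder configs; the argument values below are illustrative, not the actual defaults:

```python
# Sketch: build the full upstream-downstream model (all values below are
# illustrative; SuperbASR is used only as an example problem).
from s3prl.problem import SuperbASR

problem = SuperbASR()
model = problem.build_model(
    build_model={},                        # arguments for UpstreamDownstreamModel
    model_output_size=42,                  # required output hidden size (illustrative)
    build_upstream={"name": "wav2vec2"},   # see build_upstream
    build_featurizer={},                   # see build_featurizer
    build_downstream={},                   # see build_downstream
)
```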
- build_optimizer(build_optimizer: dict, parameters)[source]#
- Parameters:
build_optimizer (dict) – same in default_config, with the following keys:
  - name (str) - the optimizer class name in torch.optim
  - conf (dict) - the arguments for initializing the optimizer class, e.g. {"lr": 1.0e-4}
parameters (iterable) – the standard params accepted by torch.optim.Optimizer.
- Returns:
torch.optim.Optimizer
An optimizer following standard torch usage
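A sketch of a build_optimizer config using the keys above; Adam is only an illustrative optimizer class name from torch.optim:

```python
# Illustrative build_optimizer config: "name" is a class in torch.optim,
# "conf" holds its constructor arguments.
build_optimizer = {
    "name": "Adam",
    "conf": {"lr": 1.0e-4},
}
```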
- build_scheduler(build_scheduler: dict, optimizer)[source]#
- Parameters:
build_scheduler (dict) – same in default_config, with the following keys:
  - name (str) - the scheduler class name in torch.optim.lr_scheduler
  - conf (dict) - the arguments for initializing the scheduler class, e.g. {"gamma": 0.01} for torch.optim.lr_scheduler.StepLR
optimizer – the standard torch optimizer accepted by schedulers in torch.optim.lr_scheduler.
- Returns:
torch scheduler
A scheduler following standard torch usage
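A sketch of a build_scheduler config using the keys above; StepLR and its arguments are only an illustrative choice:

```python
# Illustrative build_scheduler config: "name" is a class in
# torch.optim.lr_scheduler, "conf" holds its constructor arguments.
build_scheduler = {
    "name": "StepLR",
    "conf": {"step_size": 1000, "gamma": 0.01},
}
```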
- train(train: dict, train_dir: str, build_model_all_args: dict, build_task_all_args_except_model: dict, save_model: dict, save_task: dict, build_optimizer: dict, build_scheduler: dict, evaluate: dict, train_dataset, train_batch_sampler, train_collate_fn, valid_dataset, valid_batch_sampler, valid_collate_fn, num_workers: int, world_size: int, rank: int, eval_batch: int, device: str, global_config: Optional[dict] = None)[source]#
- Parameters:
train (dict) – same in default_config, with the following keys:
  - total_steps (int) - the total optimization steps
  - log_step (int) - logging frequency; log every log_step steps
  - eval_step (int) - evaluation frequency; evaluate every eval_step steps. Note that you can control how many batches to evaluate, to speed up development, via the eval_batch argument of run
  - save_step (int) - save a checkpoint every save_step steps
  - gradient_clipping (float) - clip the gradient; important for RNNs
  - gradient_accumulate (int) - accumulate the gradient over multiple steps before updating the network parameters, to simulate large-batch optimization
  - valid_metric (str) - the metric used to select the best validation checkpoint. Different tasks support different valid metrics; see build_task for the supported ones
  - valid_higher_better (bool) - some metrics are higher-better while others are lower-better; this affects how the best validation checkpoint is saved
  - auto_resume (bool) - if the last checkpoint already exists in target_dir (see run), whether to resume from it or delete it and start a new training session
  - resume_ckpt_dir (str) - directly specify the checkpoint path to resume from; it does not need to be inside target_dir (see run)
  - seed (int) - fix the random seed before training starts
  - keep_num_ckpts (int) - to avoid saving too many checkpoints, keep only the keep_num_ckpts latest checkpoints and delete the older ones
  - use_scheduler (bool) - whether to use the scheduler
**others – only meaningful when you want to override this train method, which is not the common case, so the documentation is skipped for now.
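A sketch of a train config covering the keys listed above; every value is illustrative:

```python
# Illustrative train config; all values are made up for demonstration.
train = {
    "total_steps": 200000,
    "log_step": 100,
    "eval_step": 2000,
    "save_step": 500,
    "gradient_clipping": 1.0,
    "gradient_accumulate": 1,
    "valid_metric": "wer",          # task-dependent; see build_task
    "valid_higher_better": False,   # WER is lower-better
    "auto_resume": True,
    "resume_ckpt_dir": None,
    "seed": 1337,
    "keep_num_ckpts": 2,
    "use_scheduler": False,
}
```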
- evaluate(evaluate: dict, mode: str, task, dataset, batch_sampler, collate_fn, eval_batch: int, dump_dir: str, device: str, num_workers: int)[source]#
The evaluate routine used by train (during the validation phase) and run (during the testing phase).
- Parameters:
evaluate (dict) – same in default_config, no arguments supported for now
**others – only meaningful when you want to override this evaluate method, which is not the common case, so the documentation is skipped for now.
- save_model(save_model: dict, model_ckpt_dir: str, build_model_all_args: dict, model: Module)[source]#
Save the model state_dict and the model initialization arguments into the given directory. If you override this method, you most likely also need to override load_model.
- Parameters:
save_model (dict) – same in default_config, so the user can save additional settings, such as the configuration of the dataset, by duplicating the dataset hyperparameters inside the save_model field. You can rely on the omegaconf package to simplify the duplication.
model_ckpt_dir (str) – save the model into this directory.
build_model_all_args (dict) – all the arguments of build_model. By saving this dictionary, you can easily reconstruct the same model by calling build_model with the saved dictionary.
model (torch.nn.Module) – the model to be saved.
- Returns:
None
- load_model(model_ckpt_dir: str)[source]#
Return the saved model.
- Parameters:
model_ckpt_dir (str) – restore the model with build_model and the checkpoint saved in this directory.
- Returns:
torch.nn.Module
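A sketch of the save/load round trip for the model; the checkpoint directory, the SuperbASR problem class, and the builder arguments are illustrative assumptions:

```python
# Sketch: build a model, save it, and restore it later (all names and
# values below are illustrative).
from s3prl.problem import SuperbASR

problem = SuperbASR()
build_model_all_args = dict(
    build_model={},
    model_output_size=42,
    build_upstream={"name": "wav2vec2"},
    build_featurizer={},
    build_downstream={},
)
model = problem.build_model(**build_model_all_args)

problem.save_model({}, "exp/model", build_model_all_args, model)
restored_model = problem.load_model("exp/model")
```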
- save_task(save_task: dict, task_ckpt_dir: str, build_task_all_args_except_model: dict, task: Task)[source]#
Save the task's state, task.get_state(), and the initialization arguments into the given directory. If you override this method, you most likely also need to override load_task.
- Parameters:
save_task (dict) – same in default_config, so the user can save additional settings, such as the configuration of the dataset, by duplicating the dataset hyperparameters inside the save_task field. You can rely on the omegaconf package to simplify the duplication.
task_ckpt_dir (str) – save the task into this directory.
build_task_all_args_except_model (dict) – all the arguments of build_task except the model argument, since the model should be separately saved by save_model. By saving this dictionary, you can easily reconstruct the same task by calling build_task with the saved dictionary.
task (Task) – the task to be saved.
- Returns:
None
- load_task(task_ckpt_dir: str, model: Module, task_overrides: Optional[dict] = None)[source]#
Return the saved task.
- Parameters:
task_ckpt_dir (str) – restore the task with build_task and the checkpoint saved in this directory.
model (torch.nn.Module) – the model for the task, since the model is separately saved and is required by build_task.
task_overrides (dict) – overrides the saved initialization arguments, so you can change the loaded task's behavior, e.g. change the decoding hyperparameters.
- Returns:
Task
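A sketch of restoring a task while overriding part of its saved initialization arguments; the paths, the SuperbASR problem class, and the override key are all hypothetical:

```python
# Sketch: restore a task from a checkpoint directory and override one of
# its saved initialization arguments (names and values are hypothetical).
from s3prl.problem import SuperbASR

problem = SuperbASR()
model = problem.load_model("exp/model")
task = problem.load_task(
    "exp/task",
    model,
    task_overrides={"some_decoding_hyperparam": 0.5},  # hypothetical key
)
```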
- load_model_and_task(ckpts_dir: str, task_overrides: Optional[dict] = None)[source]#
This is a helper method that combines load_model and load_task to directly load both the model and the task. It assumes the model is saved under ckpts_dir / 'model' and the task is saved under ckpts_dir / 'task'.
- Returns:
tuple
model (torch.nn.Module)
task (s3prl.task.Task)
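A sketch of the combined helper; the checkpoint directory is illustrative and is expected to contain the 'model' and 'task' sub-directories described above:

```python
# Sketch: load both the model and the task saved by a previous run
# (the directory and the SuperbASR problem class are illustrative).
from s3prl.problem import SuperbASR

model, task = SuperbASR().load_model_and_task("result/my_exp")
```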