Source code for s3prl.problem.base

"""
The shared backbone of common ML train/test procedure for all problems

Authors:
  * Leo 2022
"""

from __future__ import annotations

import argparse
import inspect
import logging
import math
import os
import pickle
import shutil
import sys
from copy import deepcopy
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from time import time
from typing import Dict, List, Union

import omegaconf
import torch
import yaml
from torch.utils.data import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter
from tqdm import tqdm

from s3prl.dataio.collate_fn import default_collate_fn
from s3prl.dataio.sampler import DistributedBatchSamplerWrapper
from s3prl.nn.upstream import Featurizer, S3PRLUpstream, UpstreamDownstreamModel
from s3prl.task import Task
from s3prl.util.override import parse_overrides
from s3prl.util.seed import fix_random_seeds

logger = logging.getLogger(__name__)

LOGGING_FORMAT = "%(levelname)s | %(asctime)s | %(module)s:%(lineno)d | %(message)s"
ACCEPTABLE_ERRORS = [
    "CUDA out of memory",
    "Unable to find a valid cuDNN algorithm to run convolution",  # Usually caused by CUDA OOM
]
PRIMITIVE_TYPES = (int, float, bool, str)

DEFAULT_CONFIG_FORMAT = """
The default arguments for :obj:`run` in yaml.
Note that for the fields with inner values, like :code:`build_model`,
the outer field name corresponds to a method of the same name, so you can
find the method :obj:`build_model` on the problem class. The values inside
that field are passed directly into that method, so by changing these inner
values you directly affect the behavior of the corresponding method. See the
method documentation for all the supported arguments and their meanings.

The methods affected by the following config are: {:s}

.. code-block:: yaml

{:s}
"""

__all__ = ["Problem"]


class _DistributedDataParallel(torch.nn.parallel.DistributedDataParallel):
    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)


def _force_cacheable(data: dict):
    output = dict()
    for key, value in data.items():
        if isinstance(value, torch.Tensor):
            value = value.detach().cpu()
        output[key] = value
    return output


def _to_device(data, device: str):
    output = dict()
    for key, value in data.items():
        if isinstance(value, torch.Tensor):
            value = value.to(device)
        output[key] = value
    return output


def _doc_default_config(cls: Problem):
    """
    This is used to lay out the :code:`default_config` dictionary in yaml format
    for :code:`default_config`'s docstring.
    """

    def _append_prefix_spaces(docstring: str):
        return "\n".join([f"  {line}" for line in docstring.split("\n")])

    obj = cls()
    try:
        config = obj.default_config()
    except:
        return
    else:
        methods = []
        for k, v in config.items():
            if hasattr(cls, k):
                methods.append(getattr(cls, k))
        method_links = " ".join([f":obj:`{method.__name__}`" for method in methods])

        yaml_str = yaml.dump(config, sort_keys=False, width=float("inf"))
        yaml_str = _append_prefix_spaces(yaml_str)
        cls.default_config.__doc__ = DEFAULT_CONFIG_FORMAT.format(
            method_links, yaml_str
        )


class Problem:
    _store: Dict[str, Problem] = dict()

    def __init_subclass__(cls) -> None:
        super().__init_subclass__()
        cls._store[cls.__name__] = cls
        _doc_default_config(cls)

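    # Example sketch: every subclass is registered in ``_store`` by ``__init_subclass__``,
    # keyed by its class name, so it can later be retrieved with ``get_class_from_name``.
    # ``MyProblem`` below is a hypothetical subclass used only for illustration:
    #
    #     class MyProblem(Problem):
    #         ...
    #
    #     assert Problem.get_class_from_name("MyProblem") is MyProblem
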
    @classmethod
    def get_class_from_name(cls, name: str):
        """
        Args:
            name (str): the :code:`__name__` of the problem class

        Returns:
            Problem
        """
        assert (
            name in cls._store
        ), f"The class '{name}' is either not defined or not imported"
        return cls._store[name]

    def build_collate_fn(self, build_collate_fn: dict, mode: str):
        """
        By default returns :obj:`s3prl.dataio.collate_fn.default_collate_fn`

        Args:
            build_collate_fn (dict): same in :obj:`default_config`, no argument supported for now
            mode (str): train, valid, or test

        Returns:
            callable

            the collate_fn for torch DataLoader in train/valid/test :code:`mode`
        """
        return default_collate_fn

    def build_upstream(self, build_upstream: dict):
        """
        By default build the upstream with :obj:`s3prl.nn.upstream.S3PRLUpstream`

        Args:
            build_upstream (dict): same in :obj:`default_config`, arguments for
                :obj:`s3prl.nn.upstream.S3PRLUpstream`

        Returns:
            :obj:`s3prl.nn.interface.AbsUpstream`

            Return an upstream model, whose forward takes the waveform input and
            returns multiple hidden states as features.
        """
        upstream = S3PRLUpstream(**build_upstream)
        return upstream

    def build_featurizer(self, build_featurizer: dict, upstream):
        """
        By default build the featurizer with :obj:`s3prl.nn.Featurizer`

        Args:
            build_featurizer (dict): same in :obj:`default_config`, arguments for
                :obj:`s3prl.nn.Featurizer`
            upstream (:obj:`AbsUpstream`): the upstream model built by :obj:`build_upstream`

        Returns:
            :obj:`s3prl.nn.interface.AbsFeaturizer`

            Return the featurizer model. The featurizer is used to reduce the multiple
            hidden states returned from the upstream model (built by :obj:`build_upstream`)
            into a single hidden state, so it can be easily fed into the downstream model.
        """
        featurizer = Featurizer(upstream, **build_featurizer)
        return featurizer

    def build_model(
        self,
        build_model: dict,
        model_output_size: int,
        build_upstream: dict,
        build_featurizer: dict,
        build_downstream: dict,
    ):
        """
        By default build the model with :obj:`s3prl.nn.upstream.UpstreamDownstreamModel`

        Args:
            build_model (dict): same in :obj:`default_config`, arguments for
                :obj:`s3prl.nn.upstream.UpstreamDownstreamModel`
            model_output_size (int): the required model's output hidden size
            build_upstream (dict): same in :obj:`default_config`, refer to :obj:`build_upstream`
            build_featurizer (dict): same in :obj:`default_config`, refer to :obj:`build_featurizer`
            build_downstream (dict): same in :obj:`default_config`, refer to :obj:`build_downstream`

        Returns:
            torch.nn.Module

            Return the entire model for the task, which takes the direct items from the
            DataLoader as input. Usually, the components are built by :obj:`build_upstream`,
            :obj:`build_featurizer`, :obj:`build_downstream`, and are concatenated together
            to form the final model. The upstream extracts multiple hidden states, the
            featurizer reduces them into a single hidden state, and the downstream takes
            that hidden state as the feature for the downstream-specific model.
        """
        upstream = self.build_upstream(build_upstream)
        featurizer: Featurizer = self.build_featurizer(build_featurizer, upstream)
        downstream = self.build_downstream(
            build_downstream,
            featurizer.output_size,
            model_output_size,
            featurizer.downsample_rate,
        )
        model = UpstreamDownstreamModel(upstream, featurizer, downstream, **build_model)
        return model

    def build_optimizer(self, build_optimizer: dict, parameters):
        """
        Args:
            build_optimizer (dict): same in :obj:`default_config`, refer to below

                ==================== ====================
                key                  description
                ==================== ====================
                name                 (str) - the optimizer class name in :obj:`torch.optim`
                conf                 (dict) - the arguments for initializing the optimizer class. e.g. :code:`{"lr": 1.0e-4}`
                ==================== ====================

            parameters (iterable): the standard params accepted by :obj:`torch.optim.Optimizer`.

        Returns:
            :obj:`torch.optim.Optimizer`

            An optimizer following standard torch usage
        """

        def _default_build_optimizer(name: str, conf: dict):
            opt_cls = getattr(torch.optim, name)
            opt = opt_cls(parameters, **conf)
            return opt

        return _default_build_optimizer(**build_optimizer)

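    # Example sketch: a ``build_optimizer`` config following the table above. The
    # hypothetical ``problem`` and ``model`` objects are assumed to be already built:
    #
    #     build_optimizer = {"name": "Adam", "conf": {"lr": 1.0e-4}}
    #     optimizer = problem.build_optimizer(build_optimizer, model.parameters())
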
    def build_scheduler(self, build_scheduler: dict, optimizer):
        """
        Args:
            build_scheduler (dict): same in :obj:`default_config`

                ==================== ====================
                key                  description
                ==================== ====================
                name                 (str) - the scheduler class name in :obj:`torch.optim.lr_scheduler`
                conf                 (dict) - the arguments for initializing the scheduler class. e.g. :code:`{"gamma": 0.01}` for :obj:`torch.optim.lr_scheduler.StepLR`
                ==================== ====================

            optimizer: the standard torch optimizer accepted by the schedulers in
                :obj:`torch.optim.lr_scheduler`.

        Returns:
            torch scheduler

            A scheduler following standard torch usage
        """

        def _default_build_scheduler(name: str, conf: dict):
            scheduler_cls = getattr(torch.optim.lr_scheduler, name)
            scheduler = scheduler_cls(optimizer, **conf)
            return scheduler

        return _default_build_scheduler(**build_scheduler)

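    # Example sketch: a ``build_scheduler`` config following the table above, assuming
    # ``optimizer`` was created by ``build_optimizer``; ``step_size`` and ``gamma`` are
    # just illustrative StepLR arguments:
    #
    #     build_scheduler = {"name": "StepLR", "conf": {"step_size": 1000, "gamma": 0.01}}
    #     scheduler = problem.build_scheduler(build_scheduler, optimizer)
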
    def train(
        self,
        train: dict,
        train_dir: str,
        build_model_all_args: dict,
        build_task_all_args_except_model: dict,
        save_model: dict,
        save_task: dict,
        build_optimizer: dict,
        build_scheduler: dict,
        evaluate: dict,
        train_dataset,
        train_batch_sampler,
        train_collate_fn,
        valid_dataset,
        valid_batch_sampler,
        valid_collate_fn,
        num_workers: int,
        world_size: int,
        rank: int,
        eval_batch: int,
        device: str,
        global_config: dict = None,
    ):
        """
        Args:
            train (dict): same in :obj:`default_config`

                ========================== ====================
                key                        description
                ========================== ====================
                total_steps                (int) - the total optimization steps
                log_step                   (int) - logging frequency. log every :code:`log_step` step
                eval_step                  (int) - evaluation frequency. Evaluate every :code:`eval_step` step. Note that you can control how many batches to evaluate (to speed up development) with the :code:`eval_batch` argument in :obj:`run`
                save_step                  (int) - save the checkpoint every :code:`save_step` step.
                gradient_clipping          (float) - clip the gradient. important for RNNs.
                gradient_accumulate        (int) - accumulate multiple steps' gradients before updating the network parameters to simulate large-batch optimization.
                valid_metric               (str) - the metric to select the best valid checkpoint. Different Tasks have different supported valid_metrics. See :obj:`build_task` for the supported metrics.
                valid_higher_better        (bool) - some metrics are higher better, while some are lower better; this affects how the best validation checkpoint is saved.
                auto_resume                (bool) - if there is already a last checkpoint in :code:`target_dir` (see :obj:`run`), whether to resume from it or delete it and start a new training session.
                resume_ckpt_dir            (str) - you can directly specify the checkpoint path to resume, which does not need to be inside :code:`target_dir` (see :obj:`run`).
                seed                       (int) - fix the seed before the training starts
                keep_num_ckpts             (int) - to prevent saving too many checkpoints, only keep the :code:`keep_num_ckpts` latest checkpoints and delete the old ones.
                use_scheduler              (bool) - whether to use the scheduler
                ========================== ====================

            **others:
                only meaningful when you want to override this train method, which is not
                the common case. Hence we skip the documentation for now.
        """

        @dataclass
        class TrainConfig:
            total_steps: int
            log_step: int
            eval_step: int
            save_step: int
            gradient_clipping: float
            gradient_accumulate: int
            valid_metric: str
            valid_higher_better: bool
            auto_resume: bool = True
            resume_ckpt_dir: str = None
            seed: int = 0
            keep_num_ckpts: int = 2
            use_scheduler: bool = False

        conf = TrainConfig(**train)
        fix_random_seeds(conf.seed)

        train_dir: Path = Path(train_dir)
        if not conf.auto_resume and train_dir.is_dir():
            logger.warning(
                f"{train_dir} exists. Delete the directory since auto_resume=False"
            )
            shutil.rmtree(train_dir)
        train_dir.mkdir(exist_ok=True, parents=True)

        ckpt_dirs = [key for key in os.listdir(train_dir) if key.startswith("step_")]
        ckpt_dirs.sort(key=lambda name: int(name.split("_")[-1]), reverse=True)

        resume = False
        if conf.auto_resume:
            if conf.resume_ckpt_dir is not None and Path(conf.resume_ckpt_dir).is_dir():
                resume = True
            if len(ckpt_dirs) > 0:
                resume = True

        if resume:
            resume_ckpt_dir = Path(conf.resume_ckpt_dir or train_dir / ckpt_dirs[0])
            logger.info(f"Loading checkpoints from {resume_ckpt_dir}")

            try:
                _, task = self.load_model_and_task(resume_ckpt_dir)
            except:
                logger.error(
                    f"Failed to load the checkpoint {resume_ckpt_dir}. "
                    "You can set '--train.auto_resume False' to ignore the crashed checkpoint and avoid this error."
                )
                raise

            optimizer_state = torch.load(
                resume_ckpt_dir / "optimizer.pt", map_location="cpu"
            )
            if conf.use_scheduler:
                scheduler_state = torch.load(
                    resume_ckpt_dir / "scheduler.pt", map_location="cpu"
                )
            else:
                scheduler_state = None

            with open(resume_ckpt_dir / "training_stats.yaml", "r") as f:
                training_stats = yaml.load(f, Loader=yaml.FullLoader)

            global_step = int(training_stats["global_step"])
            epoch = int(training_stats["epoch"])
            valid_best_metrics = dict(training_stats["valid_best_metrics"])

        else:
            model = self.build_model(**build_model_all_args)
            task = self.build_task(model=model, **build_task_all_args_except_model)
            optimizer_state = None
            scheduler_state = None
            global_step = 0
            epoch = 0
            valid_best_metrics = dict()

        device = torch.device(device)
        wrapped_task = task.to(device)

        if world_size > 1:
            torch.cuda.set_device(device.index)
            wrapped_task = _DistributedDataParallel(
                task,
                device_ids=[device.index],
                find_unused_parameters=True,
                output_device=device.index,
            )

        optimizer = self.build_optimizer(build_optimizer, task.parameters())
        if optimizer_state:
            optimizer.load_state_dict(optimizer_state)

        scheduler = None
        if conf.use_scheduler:
            scheduler = self.build_scheduler(build_scheduler, optimizer)
            if scheduler_state:
                scheduler.load_state_dict(scheduler_state)

        train_batch_sampler = DistributedBatchSamplerWrapper(
            train_batch_sampler,
            num_replicas=world_size,
            rank=rank,
        )
        train_dataloader = DataLoader(
            train_dataset,
            batch_sampler=train_batch_sampler,
            num_workers=num_workers,
            collate_fn=train_collate_fn,
        )

        tqdm_file = sys.stderr if rank == 0 else open(os.devnull, "w")
        pbar = tqdm(
            total=conf.total_steps,
            dynamic_ncols=True,
            desc="train",
            file=tqdm_file,
        )
        pbar.n = global_step

        if rank == 0:
            tf_dir = train_dir / "tb"
            tf_logger = SummaryWriter(str(tf_dir))

        def _save_ckpts_to_dir(
            ckpts_dir: str,
            task,
            optimizer,
            scheduler,
            build_model_all_args: dict,
            build_task_all_args_except_model: dict,
            save_model: dict,
            save_task: dict,
            training_stats: dict,
            global_config: dict,
        ):
            ckpts_dir: Path = Path(ckpts_dir)
            ckpts_dir.mkdir(exist_ok=True, parents=True)

            model_ckpt_dir = ckpts_dir / "model"
            self.save_model(
                save_model, model_ckpt_dir, build_model_all_args, task.model
            )

            task_ckpt_dir = ckpts_dir / "task"
            self.save_task(
                save_task, task_ckpt_dir, build_task_all_args_except_model, task
            )

            torch.save(optimizer.state_dict(), ckpts_dir / "optimizer.pt")
            if scheduler is not None:
                torch.save(scheduler.state_dict(), ckpts_dir / "scheduler.pt")

            with (ckpts_dir / "training_stats.yaml").open("w") as f:
                yaml.safe_dump(training_stats, f)

            with (ckpts_dir / "config.yaml").open("w") as f:
                yaml.safe_dump(global_config, f)

        backward_steps = 0
        while pbar.n < pbar.total:
            train_batch_sampler.set_epoch(epoch)
            batch_results = []
            logger.info(f"Start epoch {epoch}")
            for batch in train_dataloader:
                # try/except block for forward/backward
                try:
                    if pbar.n >= pbar.total:
                        break
                    global_step = pbar.n + 1

                    wrapped_task.train()
                    batch = _to_device(batch, device)
                    loss, cacheable = wrapped_task("train", **batch)
                    (loss / conf.gradient_accumulate).backward()
                    batch_results.append(_force_cacheable(cacheable))

                except RuntimeError as e:
                    if world_size > 1:
                        raise

                    acceptable = False
                    for acc_err in ACCEPTABLE_ERRORS:
                        if acc_err in str(e):
                            acceptable = True
                            break

                    if not acceptable:
                        raise

                    logger.warning(f"Step {global_step}: {str(e)}")
                    with torch.cuda.device(device):
                        torch.cuda.empty_cache()
                    optimizer.zero_grad()
                    continue

                backward_steps += 1
                if backward_steps % conf.gradient_accumulate > 0:
                    continue

                grad_norm = torch.nn.utils.clip_grad_norm_(
                    wrapped_task.parameters(), conf.gradient_clipping
                )

                if math.isnan(grad_norm):
                    logger.warning(f"[Runner] - grad norm is NaN at step {global_step}")
                else:
                    optimizer.step()
                optimizer.zero_grad()

                if conf.use_scheduler:
                    scheduler.step()

                if rank > 0:
                    batch_results = []
                    pbar.update(1)
                    continue

                def _log_results(
                    split_name: str,
                    logs: dict,
                    tensorboard: SummaryWriter,
                    global_step: int,
                ):
                    logger.info(f"{split_name} at step {global_step}")
                    for name, value in logs.items():
                        value = float(value)
                        logger.info(f"{name}: {value}")
                        tensorboard.add_scalar(
                            f"{split_name}-{name}", value, global_step=global_step
                        )

                if global_step % conf.log_step == 0:
                    logs = wrapped_task.reduction("train", batch_results)
                    _log_results("train", logs, tf_logger, global_step)
                    batch_results = []

                save_names = []
                if global_step % conf.eval_step == 0:
                    assert (
                        valid_dataset is not None and valid_batch_sampler is not None
                    ), "valid dataset is not provided; please set train.eval_step larger than train.total_steps to skip validation"
                    logs: dict = self.evaluate(
                        evaluate,
                        "valid",
                        task,
                        valid_dataset,
                        valid_batch_sampler,
                        valid_collate_fn,
                        eval_batch,
                        train_dir,
                        device,
                        num_workers,
                    )
                    _log_results("valid", logs, tf_logger, global_step)
                    valid_metrics = {k: float(v) for k, v in logs.items()}

                    new_metric = valid_metrics[conf.valid_metric]
                    best_metric = valid_best_metrics.get(conf.valid_metric)
                    if best_metric is None:
                        is_new_best = True
                    elif conf.valid_higher_better:
                        is_new_best = new_metric > best_metric
                    else:
                        is_new_best = new_metric < best_metric

                    if is_new_best:
                        valid_best_metrics = deepcopy(valid_metrics)
                        save_names.append("valid_best")

                if global_step % conf.save_step == 0:
                    ckpt_dirs = [
                        key for key in os.listdir(train_dir) if key.startswith("step_")
                    ]
                    ckpt_dirs.sort(key=lambda stem: int(stem.split("_")[-1]))
                    if (
                        conf.keep_num_ckpts is not None
                        and len(ckpt_dirs) >= conf.keep_num_ckpts
                    ):
                        for ckpt_dir in ckpt_dirs[
                            : len(ckpt_dirs) - conf.keep_num_ckpts + 1
                        ]:
                            shutil.rmtree(train_dir / ckpt_dir)

                    save_names.append(f"step_{global_step}")

                for name in save_names:
                    training_stats = dict(
                        global_step=global_step,
                        epoch=epoch,
                        valid_best_metrics=valid_best_metrics,
                    )
                    _save_ckpts_to_dir(
                        train_dir / name,
                        (
                            task.module
                            if isinstance(task, _DistributedDataParallel)
                            else task
                        ),
                        optimizer,
                        scheduler,
                        build_model_all_args,
                        build_task_all_args_except_model,
                        save_model,
                        save_task,
                        training_stats,
                        global_config,
                    )

                pbar.update(1)
            epoch += 1

        pbar.close()
        if rank == 0:
            tf_logger.close()

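    # Example sketch: a ``train`` dict matching the TrainConfig fields above. The
    # numbers and the metric name are purely illustrative, not recommended defaults:
    #
    #     train = dict(
    #         total_steps=10000,
    #         log_step=100,
    #         eval_step=1000,
    #         save_step=1000,
    #         gradient_clipping=1.0,
    #         gradient_accumulate=1,
    #         valid_metric="accuracy",
    #         valid_higher_better=True,
    #     )
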
    def evaluate(
        self,
        evaluate: dict,
        mode: str,
        task,
        dataset,
        batch_sampler,
        collate_fn,
        eval_batch: int,
        dump_dir: str,
        device: str,
        num_workers: int,
    ):
        """
        The evaluate routine used by :obj:`train` (during the validation phase) and
        :obj:`run` (during the testing phase).

        Args:
            evaluate (dict): same in :obj:`default_config`, no argument supported for now
            **others:
                only meaningful when you want to override this evaluate method, which is
                not the common case. Hence we skip the documentation for now.
        """
        assert mode in ["valid", "test"]

        dataloader = DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=collate_fn,
        )

        task = task.to(device)
        with torch.no_grad():
            batch_results = []
            for batch_idx, batch in enumerate(
                tqdm(dataloader, desc=mode, total=len(dataloader))
            ):
                if batch_idx == eval_batch:
                    break
                batch = _to_device(batch, device)
                task.eval()
                loss, cacheable = task(mode, _dump_dir=dump_dir, **batch)
                batch_results.append(_force_cacheable(cacheable))

        logs = task.reduction(mode, batch_results, _dump_dir=dump_dir)
        return logs

    def save_model(
        self,
        save_model: dict,
        model_ckpt_dir: str,
        build_model_all_args: dict,
        model: torch.nn.Module,
    ):
        """
        Save the model state_dict and the model initialization arguments into the given
        directory. If you override this method, it is highly possible you also need to
        override :obj:`load_model`

        Args:
            save_model (dict): same in :obj:`default_config`, so the user can save
                additional settings, like the configuration of the dataset, by duplicating
                the dataset hypers inside the :code:`save_model` field. You can rely on
                the :code:`omegaconf` package to simplify the duplication.
            model_ckpt_dir (str): save the model into this directory.
            build_model_all_args (dict): all the arguments of :obj:`build_model`. By saving
                this dictionary, you can easily reconstruct the same model by calling
                :obj:`build_model` with the saved dictionary.
            model (torch.nn.Module): the model to be saved.

        Returns:
            None
        """
        model_ckpt_dir: Path = Path(model_ckpt_dir)
        if model_ckpt_dir.is_dir():
            shutil.rmtree(model_ckpt_dir, ignore_errors=True)
        model_ckpt_dir.mkdir(exist_ok=True, parents=True)

        with (model_ckpt_dir / "problem_name").open("w") as f:
            f.write(f"{self.__class__.__name__}")

        torch.save(model.state_dict(), model_ckpt_dir / "state_dict.pt")

        # NOTE: all arguments for building the model should be in simple types (yaml serializable)
        with (model_ckpt_dir / "arguments.yaml").open("w") as f:
            yaml.safe_dump(build_model_all_args, f)

        if len(save_model) > 0:
            with (model_ckpt_dir / "extra_conf.yaml").open("w") as f:
                yaml.safe_dump(save_model, f)

    def load_model(self, model_ckpt_dir: str):
        """
        Return the saved model.

        Args:
            model_ckpt_dir (str): Restore the model with :obj:`build_model` and the
                checkpoint saved in this directory.

        Return:
            :obj:`torch.nn.Module`
        """
        model_ckpt_dir: Path = Path(model_ckpt_dir)

        with (model_ckpt_dir / "arguments.yaml").open("r") as f:
            arguments = yaml.load(f, Loader=yaml.SafeLoader)

        model = self.build_model(**arguments)
        state_dict = torch.load(model_ckpt_dir / "state_dict.pt", map_location="cpu")
        model.load_state_dict(state_dict)
        return model

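    # For reference, the checkpoint layout written by ``save_model`` and read back by
    # ``load_model`` (file names taken directly from the two methods above):
    #
    #     <model_ckpt_dir>/
    #         problem_name     # the Problem subclass name
    #         state_dict.pt    # torch state_dict of the model
    #         arguments.yaml   # all arguments of ``build_model``
    #         extra_conf.yaml  # optional; the ``save_model`` dict if non-empty
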
    def save_task(
        self,
        save_task: dict,
        task_ckpt_dir: str,
        build_task_all_args_except_model: dict,
        task: Task,
    ):
        """
        Save the task's state, :code:`task.get_state()`, and the initialization arguments
        into the given directory. If you override this method, it is highly possible you
        also need to override :obj:`load_task`.

        Args:
            save_task (dict): same in :obj:`default_config`, so the user can save
                additional settings, like the configuration of the dataset, by duplicating
                the dataset hypers inside the :code:`save_task` field. You can rely on the
                :code:`omegaconf` package to simplify the duplication.
            task_ckpt_dir (str): save the task into this directory.
            build_task_all_args_except_model (dict): all the arguments of :obj:`build_task`
                except the :code:`model` argument, since the model should be separately
                saved by :obj:`save_model`. By saving this dictionary, you can easily
                reconstruct the same task by calling :obj:`build_task` with the saved
                dictionary.
            task (Task): the task to be saved.

        Returns:
            None
        """
        task_ckpt_dir: Path = Path(task_ckpt_dir)
        if task_ckpt_dir.is_dir():
            shutil.rmtree(task_ckpt_dir, ignore_errors=True)
        task_ckpt_dir.mkdir(exist_ok=True, parents=True)

        with (task_ckpt_dir / "problem_name").open("w") as f:
            f.write(f"{self.__class__.__name__}")

        torch.save(task.get_state(), task_ckpt_dir / "state.pt")

        # NOTE: each argument is saved independently to prevent a single point of failure,
        # i.e. a single argument which cannot be loaded will not block loading all the
        # other arguments, as it would if every argument lived in one file
        arguments = build_task_all_args_except_model
        arguments_dir = task_ckpt_dir / "arguments"
        arguments_dir.mkdir(exist_ok=True, parents=True)
        for k, v in arguments.items():
            try:
                # If the object is yaml serializable, use yaml for readability
                yaml.safe_dump(v)
            except:
                # If not, fall back to pickle
                with (arguments_dir / f"{k}.pkl").open("wb") as f:
                    pickle.dump(v, f)
            else:
                with (arguments_dir / f"{k}.yaml").open("w") as f:
                    yaml.safe_dump(v, f)

        if len(save_task) > 0:
            with (task_ckpt_dir / "extra_conf.yaml").open("w") as f:
                yaml.safe_dump(save_task, f)

    def load_task(
        self, task_ckpt_dir: str, model: torch.nn.Module, task_overrides: dict = None
    ):
        """
        Return the saved task.

        Args:
            task_ckpt_dir (str): Restore the task with :obj:`build_task` and the checkpoint
                saved in this directory.
            model (torch.nn.Module): the model for the task, since the model is separately
                saved and is required for :obj:`build_task`.
            task_overrides (dict): overrides the saved initialization arguments, so it can
                change the loaded task's behavior, e.g. change the decoding hyperparameters.

        Returns:
            :obj:`s3prl.task.Task`
        """
        task_ckpt_dir: Path = Path(task_ckpt_dir)
        task_overrides = task_overrides or {}

        arguments = task_overrides.copy()
        arguments_dir = task_ckpt_dir / "arguments"
        for filename in os.listdir(arguments_dir):
            filepath = arguments_dir / filename
            key = filepath.stem
            if key in task_overrides:
                # do not load the file (potential crash) if the overrides already have the value
                continue

            if filepath.suffix == ".yaml":
                with filepath.open("r") as f:
                    value = yaml.load(f, Loader=yaml.SafeLoader)
            elif filepath.suffix == ".pkl":
                with filepath.open("rb") as f:
                    value = pickle.load(f)

            assert key not in arguments, (
                f"Unexpected duplicated file stem '{key}' found in {arguments_dir}. "
                "Please delete one of them."
            )
            arguments[key] = value

        task = self.build_task(model=model, **arguments)
        state = torch.load(task_ckpt_dir / "state.pt", map_location="cpu")
        task.set_state(state)
        return task

    def load_model_and_task(self, ckpts_dir: str, task_overrides: dict = None):
        """
        This is a helper method that combines :obj:`load_model` and :obj:`load_task`
        to directly load both the model and the task. This method assumes the model is
        saved under :code:`ckpts_dir / 'model'` and the task is saved under
        :code:`ckpts_dir / 'task'`

        Returns:
            tuple

            1. model (:obj:`torch.nn.Module`)
            2. task (:obj:`s3prl.task.Task`)
        """
        ckpts_dir: Path = Path(ckpts_dir)
        task_overrides = task_overrides or {}

        model = self.load_model(ckpts_dir / "model")
        task = self.load_task(ckpts_dir / "task", model, task_overrides)
        return model, task

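    # Example sketch: reloading a checkpoint saved by ``train`` (which writes
    # ``<ckpt_dir>/model`` and ``<ckpt_dir>/task``). The path and the override key are
    # hypothetical; valid override keys depend on the concrete Task's ``build_task``
    # arguments:
    #
    #     model, task = problem.load_model_and_task(
    #         "result/train/valid_best",
    #         task_overrides={"some_decoding_hparam": 0.5},
    #     )
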
    @staticmethod
    def _get_current_arguments(
        exclude_self_and_cls: bool = True, flatten_dict: Union[str, List[str]] = None
    ) -> dict:
        if isinstance(flatten_dict, str):
            flatten_dict = [flatten_dict]

        frame = inspect.currentframe().f_back
        args, _, _, values = inspect.getargvalues(frame)
        config = {key: values[key] for key in args}

        if exclude_self_and_cls:
            config.pop("self", None)
            config.pop("cls", None)

        if flatten_dict is not None:
            flatten_config = {}
            for k, v in config.items():
                if k in flatten_dict:
                    assert isinstance(v, dict)
                    for _k, _v in v.items():
                        flatten_config[_k] = _v
                else:
                    flatten_config[k] = v
            config = flatten_config

        def assert_no_missing(config: dict):
            omegaconf.OmegaConf.to_container(
                omegaconf.OmegaConf.create(config), throw_on_missing=True
            )

        assert_no_missing(config)
        return config

    @staticmethod
    def _get_time_tag():
        return datetime.fromtimestamp(time()).strftime("%Y_%m_%d_%H_%M_%S")

    @staticmethod
    def _stage_check(stage_id: int, stop: int, check_fn: callable):
        try:
            check_fn()
        except:
            logger.error(
                f"Stage {stage_id} was not done before or is corrupted. Please re-run from this stage."
            )
            raise
        if isinstance(stop, int) and stage_id == stop:
            exit(0)

    def main(self, args: List[str] = None):
        parser = argparse.ArgumentParser()
        parser.add_argument("--verbose", default="INFO")
        parser.add_argument(
            "--config", help="The yaml config path to override the default config"
        )
        parser.add_argument("--print_config", "-p", action="store_true")
        parser.add_argument(
            "--dump_config", "-d", help="The path to dump the default config as yaml"
        )
        args, override = parser.parse_known_args(args)

        if args.print_config:
            print(f"\nDefault config of {self}\n")
            print(yaml.safe_dump(self.default_config()))
            exit(0)

        if args.dump_config is not None:
            with open(args.dump_config, "w") as f:
                yaml.safe_dump(self.default_config(), f)
            exit(0)

        root_logger = logging.getLogger()
        root_logger.handlers = []
        logging.basicConfig(level=getattr(logging, args.verbose), format=LOGGING_FORMAT)

        if args.config is not None:
            with open(args.config) as f:
                yaml_conf = yaml.load(f, Loader=yaml.FullLoader) or dict()
        else:
            yaml_conf = dict()

        override_conf = parse_overrides(override)

        schema = omegaconf.OmegaConf.create(self.default_config())
        config = omegaconf.OmegaConf.merge(schema, yaml_conf, override_conf)
        config = omegaconf.OmegaConf.to_container(
            config, resolve=True, throw_on_missing=True
        )
        logger.info(config)

        self.run(**config)
        return config
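

# A minimal usage sketch: a concrete problem subclass (the ``MyProblem`` name below is
# hypothetical) is typically launched through ``main``, which merges ``default_config``
# with an optional yaml file and dotted CLI overrides parsed by ``parse_overrides``
# (the '--train.auto_resume False' style mentioned in ``train``), then calls ``run``:
#
#     if __name__ == "__main__":
#         MyProblem().main(
#             [
#                 "--config", "my_config.yaml",      # hypothetical yaml path
#                 "--train.total_steps", "1000",     # dotted override of default_config
#             ]
#         )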