Source code for s3prl.problem.common.run

"""
The backbone run procedure for the common train/valid/test workflow

Authors
  * Leo 2022
"""

import logging
import pickle
import shutil
from pathlib import Path

import pandas as pd
import torch
import yaml

from s3prl.problem.base import Problem
from s3prl.task.utterance_classification_task import UtteranceClassificationTask

logger = logging.getLogger(__name__)

__all__ = ["Common"]


class Common(Problem):
    def run(
        self,
        target_dir: str,
        cache_dir: str = None,
        remove_all_cache: bool = False,
        start: int = 0,
        stop: int = None,
        num_workers: int = 6,
        eval_batch: int = -1,
        device: str = "cuda",
        world_size: int = 1,
        rank: int = 0,
        test_ckpt_dir: str = None,
        prepare_data: dict = None,
        build_encoder: dict = None,
        build_dataset: dict = None,
        build_batch_sampler: dict = None,
        build_collate_fn: dict = None,
        build_upstream: dict = None,
        build_featurizer: dict = None,
        build_downstream: dict = None,
        build_model: dict = None,
        build_task: dict = None,
        build_optimizer: dict = None,
        build_scheduler: dict = None,
        save_model: dict = None,
        save_task: dict = None,
        train: dict = None,
        evaluate: dict = None,
    ):
        """
        ======== ====================
        stage    description
        ======== ====================
        0        Parse the corpus and save the metadata file (waveform path, label...)
        1        Build the encoder to encode the labels
        2        Train the model
        3        Evaluate the model on multiple test sets
        ======== ====================

        Args:
            target_dir (str): The directory that stores the script result.
            cache_dir (str): The directory that caches the processed data.
                Default: /home/user/.cache/s3prl/data
            remove_all_cache (bool): Whether to remove all the cache stored under `cache_dir`.
                Default: False
            start (int): The starting stage of the problem script.
                Default: 0
            stop (int): The stopping stage of the problem script; set `None` to run through the final stage.
                Default: None
            num_workers (int): num_workers for all the torch DataLoader
            eval_batch (int): During evaluation (valid or test), limit the number of batches.
                This is helpful during fast development to check that nothing will crash.
                If -1, disable this feature and evaluate the entire epoch.
                Default: -1
            device (str): The device type for all torch-related operations: "cpu" or "cuda"
                Default: "cuda"
            world_size (int): How many processes are running this script simultaneously (in parallel).
                Usually this is just 1, but if you are running distributed training, this should be > 1.
                Default: 1
            rank (int): During distributed training, world_size > 1. Take :code:`world_size == 8`
                for example: this means 8 processes (8 GPUs) are running in parallel, and the script
                needs to know which of the 8 processes it is. In this case, :code:`rank` can range
                from 0 to 7. All 8 processes have the same :code:`world_size` but a different
                :code:`rank` (process id).
            test_ckpt_dir (str): Specify the checkpoint path for testing. If not given, use the
                validation-best checkpoint under the given :code:`target_dir` directory.
            **kwds: The other arguments like :code:`prepare_data` and :code:`build_model` are
                method-specific arguments for methods like :obj:`prepare_data` and :obj:`build_model`,
                and are not used in the core :obj:`run` logic.
                See each method's documentation for its supported arguments and their meaning.
        """
        yaml_path = Path(target_dir) / "configs" / f"{self._get_time_tag()}.yaml"
        yaml_path.parent.mkdir(exist_ok=True, parents=True)
        with yaml_path.open("w") as f:
            yaml.safe_dump(self._get_current_arguments(), f)

        cache_dir: str = cache_dir or Path.home() / ".cache" / "s3prl" / "data"
        prepare_data: dict = prepare_data or {}
        build_encoder: dict = build_encoder or {}
        build_dataset: dict = build_dataset or {}
        build_batch_sampler: dict = build_batch_sampler or {}
        build_collate_fn: dict = build_collate_fn or {}
        build_upstream: dict = build_upstream or {}
        build_featurizer: dict = build_featurizer or {}
        build_downstream: dict = build_downstream or {}
        build_model: dict = build_model or {}
        build_task: dict = build_task or {}
        build_optimizer: dict = build_optimizer or {}
        build_scheduler: dict = build_scheduler or {}
        save_model: dict = save_model or {}
        save_task: dict = save_task or {}
        train: dict = train or {}
        evaluate = evaluate or {}

        target_dir: Path = Path(target_dir)
        target_dir.mkdir(exist_ok=True, parents=True)
        cache_dir = Path(cache_dir)
        cache_dir.mkdir(exist_ok=True, parents=True)
        if remove_all_cache:
            shutil.rmtree(cache_dir)

        stage_id = 0
        if start <= stage_id:
            logger.info(f"Stage {stage_id}: prepare data")
            train_csv, valid_csv, test_csvs = self.prepare_data(
                prepare_data, target_dir, cache_dir, get_path_only=False
            )

        train_csv, valid_csv, test_csvs = self.prepare_data(
            prepare_data, target_dir, cache_dir, get_path_only=True
        )

        def check_fn():
            assert Path(train_csv).is_file() and Path(valid_csv).is_file()
            for test_csv in test_csvs:
                assert Path(test_csv).is_file()

        self._stage_check(stage_id, stop, check_fn)

        stage_id = 1
        if start <= stage_id:
            logger.info(f"Stage {stage_id}: build encoder")
            encoder_path = self.build_encoder(
                build_encoder,
                target_dir,
                cache_dir,
                train_csv,
                valid_csv,
                test_csvs,
                get_path_only=False,
            )

        encoder_path = self.build_encoder(
            build_encoder,
            target_dir,
            cache_dir,
            train_csv,
            valid_csv,
            test_csvs,
            get_path_only=True,
        )

        def check_fn():
            assert Path(encoder_path).is_file()

        self._stage_check(stage_id, stop, check_fn)

        with open(encoder_path, "rb") as f:
            encoder = pickle.load(f)

        model_output_size = len(encoder)
        model = self.build_model(
            build_model,
            model_output_size,
            build_upstream,
            build_featurizer,
            build_downstream,
        )
        frame_shift = model.downsample_rate

        stage_id = 2
        train_dir = target_dir / "train"
        if start <= stage_id:
            logger.info(f"Stage {stage_id}: Train Model")
            train_ds, train_bs = self._build_dataset_and_sampler(
                target_dir,
                cache_dir,
                "train",
                train_csv,
                encoder_path,
                frame_shift,
                build_dataset,
                build_batch_sampler,
            )
            valid_ds, valid_bs = self._build_dataset_and_sampler(
                target_dir,
                cache_dir,
                "valid",
                valid_csv,
                encoder_path,
                frame_shift,
                build_dataset,
                build_batch_sampler,
            )

            with Path(encoder_path).open("rb") as f:
                encoder = pickle.load(f)

            build_model_all_args = dict(
                build_model=build_model,
                model_output_size=len(encoder),
                build_upstream=build_upstream,
                build_featurizer=build_featurizer,
                build_downstream=build_downstream,
            )
            build_task_all_args_except_model = dict(
                build_task=build_task,
                encoder=encoder,
                valid_df=pd.read_csv(valid_csv),
            )
            self.train(
                train,
                train_dir,
                build_model_all_args,
                build_task_all_args_except_model,
                save_model,
                save_task,
                build_optimizer,
                build_scheduler,
                evaluate,
                train_ds,
                train_bs,
                self.build_collate_fn(build_collate_fn, "train"),
                valid_ds,
                valid_bs,
                self.build_collate_fn(build_collate_fn, "valid"),
                device=device,
                eval_batch=eval_batch,
                num_workers=num_workers,
                world_size=world_size,
                rank=rank,
            )

        def check_fn():
            assert (train_dir / "valid_best").is_dir()

        self._stage_check(stage_id, stop, check_fn)

        stage_id = 3
        if start <= stage_id:
            test_ckpt_dir: Path = Path(
                test_ckpt_dir or target_dir / "train" / "valid_best"
            )
            assert test_ckpt_dir.is_dir()

            logger.info(f"Stage {stage_id}: Test model: {test_ckpt_dir}")
            for test_idx, test_csv in enumerate(test_csvs):
                test_name = Path(test_csv).stem
                test_dir: Path = (
                    target_dir
                    / "evaluate"
                    / test_ckpt_dir.relative_to(train_dir).as_posix().replace("/", "-")
                    / test_name
                )
                test_dir.mkdir(exist_ok=True, parents=True)

                logger.info(f"Stage {stage_id}.{test_idx}: Test model on {test_csv}")
                test_ds, test_bs = self._build_dataset_and_sampler(
                    target_dir,
                    cache_dir,
                    "test",
                    test_csv,
                    encoder_path,
                    frame_shift,
                    build_dataset,
                    build_batch_sampler,
                )

                _, valid_best_task = self.load_model_and_task(
                    test_ckpt_dir, task_overrides={"test_df": pd.read_csv(test_csv)}
                )
                logs = self.evaluate(
                    evaluate,
                    "test",
                    valid_best_task,
                    test_ds,
                    test_bs,
                    self.build_collate_fn(build_collate_fn, "test"),
                    eval_batch,
                    test_dir,
                    device,
                    num_workers,
                )
                test_metrics = {name: float(value) for name, value in logs.items()}

                logger.info(f"test results: {test_metrics}")

                with (test_dir / "result.yaml").open("w") as f:
                    yaml.safe_dump(test_metrics, f)
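    # Usage sketch (not part of the original module): ``Common`` is not run
    # directly; concrete problems such as :obj:`SuperbSID` subclass it and
    # inherit this ``run`` procedure. Assuming such a subclass and hypothetical
    # paths, resuming from the training stage and stopping before testing
    # would look like:
    #
    #     from s3prl.problem import SuperbSID
    #
    #     SuperbSID().run(
    #         target_dir="result/sid",  # hypothetical result directory
    #         start=2,                  # skip stages 0-1, reuse cached csv/encoder
    #         stop=2,                   # stop after training; run stage 3 later
    #         device="cuda",
    #     )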
    def _build_dataset_and_sampler(
        self,
        target_dir: str,
        cache_dir: str,
        mode: str,
        data_csv: str,
        encoder_path: str,
        frame_shift: int,
        build_dataset: dict,
        build_batch_sampler: dict,
    ):
        logger.info(f"Build {mode} dataset")
        dataset = self.build_dataset(
            build_dataset,
            target_dir,
            cache_dir,
            mode,
            data_csv,
            encoder_path,
            frame_shift,
        )
        logger.info(f"Build {mode} batch sampler")
        batch_sampler = self.build_batch_sampler(
            build_batch_sampler,
            target_dir,
            cache_dir,
            mode,
            data_csv,
            dataset,
        )
        return dataset, batch_sampler
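    # Sketch of how the returned (dataset, batch_sampler) pair is typically
    # consumed (assumed; the actual DataLoader construction happens inside
    # ``self.train`` and ``self.evaluate``, not in this helper):
    #
    #     from torch.utils.data import DataLoader
    #
    #     loader = DataLoader(
    #         dataset,
    #         batch_sampler=batch_sampler,
    #         collate_fn=self.build_collate_fn(build_collate_fn, mode),
    #         num_workers=num_workers,
    #     )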
    def build_task(
        self,
        build_task: dict,
        model: torch.nn.Module,
        encoder,
        valid_df: pd.DataFrame = None,
        test_df: pd.DataFrame = None,
    ):
        """
        Build the task, which defines the logic of every train/valid/test forward step
        for the :code:`model`, and how to reduce the batch results from multiple
        train/valid/test steps into metrics.

        By default, build :obj:`UtteranceClassificationTask`.

        Args:
            build_task (dict): same as in :obj:`default_config`; no argument supported for now
            model (torch.nn.Module): the model built by :obj:`build_model`
            encoder: the encoder built by :obj:`build_encoder`
            valid_df (pd.DataFrame): the validation metadata read from the valid csv (optional)
            test_df (pd.DataFrame): the test metadata read from the test csv (optional)

        Returns:
            Task
        """
        task = UtteranceClassificationTask(model, encoder)
        return task
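# Override sketch (assumed, for illustration only): a subclass can swap in a
# different task by redefining ``build_task``. ``MySequenceTask`` below is
# hypothetical; any Task subclass with matching train/valid/test semantics
# would work:
#
#     class MyProblem(Common):
#         def build_task(
#             self, build_task: dict, model, encoder, valid_df=None, test_df=None
#         ):
#             return MySequenceTask(model, encoder, **build_task)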