"""
The backbone run procedure for Speaker Diarization
Authors:
* Leo 2022
"""
import logging
import shutil
from dataclasses import dataclass
from pathlib import Path
from typing import List
import yaml
from s3prl.problem.base import Problem
from s3prl.task.diarization import DiarizationPIT
from .util import csv_to_kaldi_dir, kaldi_dir_to_rttm, make_rttm_and_score
logger = logging.getLogger(__name__)
__all__ = ["Diarization"]
class Diarization(Problem):
    def run(
        self,
        target_dir: str,
        cache_dir: str,
        remove_all_cache: bool = False,
        start: int = 0,
        stop: int = None,
        num_workers: int = 6,
        eval_batch: int = -1,
        device: str = "cuda",
        world_size: int = 1,
        rank: int = 0,
        test_ckpt_dir: str = None,
        num_speaker: int = 2,
        prepare_data: dict = None,
        build_dataset: dict = None,
        build_batch_sampler: dict = None,
        build_collate_fn: dict = None,
        build_upstream: dict = None,
        build_featurizer: dict = None,
        build_downstream: dict = None,
        build_model: dict = None,
        build_task: dict = None,
        build_optimizer: dict = None,
        build_scheduler: dict = None,
        save_model: dict = None,
        save_task: dict = None,
        train: dict = None,
        evaluate: dict = None,
        scoring: dict = None,
    ):
"""
======== ====================
stage description
======== ====================
0 Parse the corpus and save the Kaldi-style data directory for speaker diarization
1 Train the model
2 Inference the prediction
3 Score the prediction
======== ====================
Args:
target_dir (str):
The directory that stores the script result.
cache_dir (str):
The directory that caches the processed data.
Default: /home/user/.cache/s3prl/data
remove_all_cache (bool):
Whether to remove all the cache stored under `cache_dir`.
Default: False
start (int):
The starting stage of the problem script.
Default: 0
stop (int):
The stoping stage of the problem script, set `None` to reach the final stage.
Default: None
num_workers (int): num_workers for all the torch DataLoder
eval_batch (int):
During evaluation (valid or test), limit the number of batch.
This is helpful for the fast development to check everything won't crash.
If is -1, disable this feature and evaluate the entire epoch.
Default: -1
device (str):
The device type for all torch-related operation: "cpu" or "cuda"
Default: "cuda"
world_size (int):
How many processes are running this script simultaneously (in parallel).
Usually this is just 1, however if you are runnig distributed training,
this should be > 1.
Default: 1
rank (int):
When distributed training, world_size > 1. Take :code:`world_size == 8` for
example, this means 8 processes (8 GPUs) are runing in parallel. The script
needs to know which process among 8 processes it is. In this case, :code:`rank`
can range from 0~7. All the 8 processes have the same :code:`world_size` but
different :code:`rank` (process id).
test_ckpt_dir (str):
Specify the checkpoint path for testing. If not, use checkpoints specified by
:code:`test_ckpts_steps`.
num_speaker (int):
How many speakers per utterance
**others:
The other arguments like :code:`prepare_data` and :code:`build_model` are
method specific-arguments for methods like :obj:`prepare_data` and
:obj:`build_model`, and will not be used in the core :obj:`run` logic.
See the specific method documentation for their supported arguments and
meaning
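
        Example:
            A minimal sketch of launching this recipe. In practice you run a concrete
            subclass that implements :obj:`prepare_data` for a specific corpus; the
            paths and values below are hypothetical placeholders::

                problem = Diarization()
                problem.run(
                    target_dir="exp/diarization",
                    cache_dir=None,
                    num_speaker=2,
                )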
"""
        yaml_path = Path(target_dir) / "configs" / f"{self._get_time_tag()}.yaml"
        yaml_path.parent.mkdir(exist_ok=True, parents=True)
        with yaml_path.open("w") as f:
            yaml.safe_dump(self._get_current_arguments(), f)

        cache_dir: str = cache_dir or Path.home() / ".cache" / "s3prl" / "data"
        prepare_data: dict = prepare_data or {}
        build_dataset: dict = build_dataset or {}
        build_batch_sampler: dict = build_batch_sampler or {}
        build_collate_fn: dict = build_collate_fn or {}
        build_upstream: dict = build_upstream or {}
        build_featurizer: dict = build_featurizer or {}
        build_downstream: dict = build_downstream or {}
        build_model: dict = build_model or {}
        build_task: dict = build_task or {}
        build_optimizer: dict = build_optimizer or {}
        build_scheduler: dict = build_scheduler or {}
        save_model: dict = save_model or {}
        save_task: dict = save_task or {}
        train: dict = train or {}
        evaluate: dict = evaluate or {}
        scoring: dict = scoring or {}

        target_dir: Path = Path(target_dir)
        target_dir.mkdir(exist_ok=True, parents=True)
        cache_dir = Path(cache_dir)
        cache_dir.mkdir(exist_ok=True, parents=True)
        if remove_all_cache:
            shutil.rmtree(cache_dir, ignore_errors=True)
        stage_id = 0
        if start <= stage_id:
            logger.info(f"Stage {stage_id}: prepare data")
            train_csv, valid_csv, test_csvs = self.prepare_data(
                prepare_data, target_dir, cache_dir, get_path_only=False
            )
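
        # Re-fetch only the csv paths (get_path_only=True skips the heavy
        # processing), so the paths are available even when stage 0 is skipped.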
        train_csv, valid_csv, test_csvs = self.prepare_data(
            prepare_data, target_dir, cache_dir, get_path_only=True
        )

        def check_fn():
            assert Path(train_csv).is_file() and Path(valid_csv).is_file()
            for test_csv in test_csvs:
                assert Path(test_csv).is_file()

        self._stage_check(stage_id, stop, check_fn)
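
        # Convert each csv into a Kaldi-style data directory (the usual wav.scp /
        # segments / utt2spk layout; see csv_to_kaldi_dir for the exact files).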
        for csv in [train_csv, valid_csv, *test_csvs]:
            data_dir = target_dir / "kaldi_data" / Path(csv).stem
            csv_to_kaldi_dir(csv, data_dir)

        train_data = target_dir / "kaldi_data" / Path(train_csv).stem
        valid_data = target_dir / "kaldi_data" / Path(valid_csv).stem
        test_datas = [target_dir / "kaldi_data" / Path(csv).stem for csv in test_csvs]
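
        # RTTM is the standard diarization reference format: one line per speech
        # segment, roughly "SPEAKER <recording> 1 <onset> <duration> ... <speaker>".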
        test_rttms = []
        for test_data in test_datas:
            logger.info(f"Prepare RTTM for {test_data}")
            test_rttm = target_dir / f"{Path(test_data).stem}.rttm"
            kaldi_dir_to_rttm(test_data, test_rttm)
            test_rttms.append(test_rttm)
        model_output_size = num_speaker
        model = self.build_model(
            build_model,
            model_output_size,
            build_upstream,
            build_featurizer,
            build_downstream,
        )
        frame_shift = model.downsample_rate
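
        # The diarization labels are frame-wise at the upstream's downsample rate;
        # downsample_rate is in waveform samples per frame (e.g. 320 samples,
        # i.e. 20 ms at 16 kHz, for many SSL upstreams).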
        stage_id = 1
        train_dir = target_dir / "train"
        if start <= stage_id:
            logger.info(f"Stage {stage_id}: Train Model")
            train_ds, train_bs = self._build_dataset_and_sampler(
                target_dir,
                cache_dir,
                "train",
                train_csv,
                train_data,
                num_speaker,
                frame_shift,
                build_dataset,
                build_batch_sampler,
            )
            valid_ds, valid_bs = self._build_dataset_and_sampler(
                target_dir,
                cache_dir,
                "valid",
                valid_csv,
                valid_data,
                num_speaker,
                frame_shift,
                build_dataset,
                build_batch_sampler,
            )

            build_model_all_args = dict(
                build_model=build_model,
                model_output_size=model_output_size,
                build_upstream=build_upstream,
                build_featurizer=build_featurizer,
                build_downstream=build_downstream,
            )
            build_task_all_args_except_model = dict(
                build_task=build_task,
            )
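
            # These bundles let the trainer re-create the model and task from
            # scratch (e.g. when loading a checkpoint later via
            # load_model_and_task) without re-passing every keyword argument.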
            self.train(
                train,
                train_dir,
                build_model_all_args,
                build_task_all_args_except_model,
                save_model,
                save_task,
                build_optimizer,
                build_scheduler,
                evaluate,
                train_ds,
                train_bs,
                self.build_collate_fn(build_collate_fn, "train"),
                valid_ds,
                valid_bs,
                self.build_collate_fn(build_collate_fn, "valid"),
                device=device,
                eval_batch=eval_batch,
                num_workers=num_workers,
                world_size=world_size,
                rank=rank,
            )

        def check_fn():
            assert (train_dir / "valid_best").is_dir()

        self._stage_check(stage_id, stop, check_fn)
        stage_id = 2
        test_ckpt_dir: Path = Path(test_ckpt_dir or target_dir / "train" / "valid_best")
        test_dirs = []
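
        # One evaluation folder per test set; the checkpoint's path relative to
        # train_dir is flattened ("/" -> "-") so every evaluated checkpoint gets a
        # distinct, filesystem-safe folder name.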
        for test_idx, test_data in enumerate(test_datas):
            test_name = Path(test_data).stem
            test_dir: Path = (
                target_dir
                / "evaluate"
                / test_ckpt_dir.relative_to(train_dir).as_posix().replace("/", "-")
                / test_name
            )
            test_dirs.append(test_dir)

        if start <= stage_id:
            logger.info(f"Stage {stage_id}: Test model: {test_ckpt_dir}")
            for test_idx, test_data in enumerate(test_datas):
                test_csv = test_csvs[test_idx]
                test_dir = test_dirs[test_idx]
                test_dir.mkdir(exist_ok=True, parents=True)
                logger.info(
                    f"Stage {stage_id}.{test_idx}: Test model on {test_dir} and dump prediction"
                )
                test_ds, test_bs = self._build_dataset_and_sampler(
                    target_dir,
                    cache_dir,
                    "test",
                    test_csv,
                    test_data,
                    num_speaker,
                    frame_shift,
                    build_dataset,
                    build_batch_sampler,
                )
                _, valid_best_task = self.load_model_and_task(test_ckpt_dir)
                logs: dict = self.evaluate(
                    evaluate,
                    "test",
                    valid_best_task,
                    test_ds,
                    test_bs,
                    self.build_collate_fn(build_collate_fn, "test"),
                    eval_batch,
                    test_dir,
                    device,
                    num_workers,
                )
                test_metrics = {name: float(value) for name, value in logs.items()}
                with (test_dir / "result.yaml").open("w") as f:
                    yaml.safe_dump(test_metrics, f)

        def check_fn():
            for test_dir in test_dirs:
                assert (test_dir / "prediction").is_dir()

        self._stage_check(stage_id, stop, check_fn)
        stage_id = 3
        if start <= stage_id:
            logger.info(f"Stage {stage_id}: Score model: {test_ckpt_dir}")
            self.scoring(scoring, stage_id, test_dirs, test_rttms, frame_shift)

        return stage_id
    def scoring(
        self,
        scoring: dict,
        stage_id: int,
        test_dirs: List[str],
        test_rttms: List[str],
        frame_shift: int,
    ):
"""
Score the prediction
Args:
scoring (dict):
==================== ====================
key description
==================== ====================
thresholds (List[int]) - Given the 0~1 (float) soft prediction, the threshold decides \
how to get the 0/1 hard prediction. This list are all the thresholds to try.
median_filters (List[int]) - After getting hard prediction, use median filter to smooth out the \
prediction. This list are all the median filter sizes to try.
==================== ====================
*others:
This method is not designed to be overridden
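
        Example:
            A typical sweep (the values below are hypothetical, not library defaults)
            tries every threshold x median-filter combination and keeps the
            best-scoring pair::

                scoring = dict(
                    thresholds=[0.3, 0.4, 0.5, 0.6, 0.7],
                    median_filters=[1, 11],
                )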
"""
        @dataclass
        class ScoreConfig:
            thresholds: List[float]
            median_filters: List[int]

        conf = ScoreConfig(**scoring)
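
        # make_rttm_and_score sweeps the full thresholds x median_filters grid and
        # returns the best DER with the (threshold, median_filter) pair achieving it.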
        for test_idx, test_dir in enumerate(test_dirs):
            logger.info(
                f"Stage {stage_id}.{test_idx}: Make RTTM and Score from prediction"
            )
            best_der, (best_th, best_med) = make_rttm_and_score(
                test_dir / "prediction",
                test_dir / "score",
                test_rttms[test_idx],
                frame_shift,
                conf.thresholds,
                conf.median_filters,
            )
            logger.info(f"Best dscore DER: {best_der}")
            with (test_dir / "dscore.yaml").open("w") as f:
                yaml.safe_dump(
                    dict(
                        der=best_der,
                        threshold=best_th,
                        median_filter=best_med,
                    ),
                    f,
                )
    def _build_dataset_and_sampler(
        self,
        target_dir: str,
        cache_dir: str,
        mode: str,
        data_csv: str,
        data_dir: str,
        num_speakers: int,
        frame_shift: int,
        build_dataset: dict,
        build_batch_sampler: dict,
    ):
logger.info(f"Build {mode} dataset")
dataset = self.build_dataset(
build_dataset,
target_dir,
cache_dir,
mode,
data_csv,
data_dir,
num_speakers,
frame_shift,
)
logger.info(f"Build {mode} batch sampler")
batch_sampler = self.build_batch_sampler(
build_batch_sampler,
target_dir,
cache_dir,
mode,
data_csv,
data_dir,
dataset,
)
return dataset, batch_sampler
    def build_task(self, build_task: dict, model):
        """
        Build the task, which defines the logic of every train/valid/test forward step
        for the :code:`model`, and how to reduce the batch results from multiple
        train/valid/test steps into metrics.

        By default, build :obj:`DiarizationPIT`

        Args:
            build_task (dict): same as in :obj:`default_config`; no arguments supported for now
            model (torch.nn.Module): the model built by :obj:`build_model`

        Returns:
            Task
        """
        task = DiarizationPIT(model)
        return task