"""
The setting fo Superb SD
Authors:
* Jiatong Shi 2021
* Leo 2022
"""
from dataclasses import dataclass
from pathlib import Path
from omegaconf import MISSING
from s3prl.dataio.dataset import DiarizationDataset, get_info
from s3prl.dataio.sampler import FixedBatchSizeBatchSampler, GroupSameItemSampler
from s3prl.nn.rnn import SuperbDiarizationModel
from .run import Diarization
from .util import kaldi_dir_to_csv
__all__ = [
"SuperbSD",
]
[docs]class SuperbSD(Diarization):
[docs] def default_config(self):
return dict(
start=0,
stop=None,
target_dir=MISSING,
cache_dir=None,
remove_all_cache=False,
prepare_data=dict(
data_dir=MISSING,
),
build_dataset=dict(
chunk_size=2000,
subsampling=1,
rate=16000,
use_last_samples=True,
label_delay=0,
),
build_batch_sampler=dict(
train=dict(
batch_size=8,
shuffle=True,
),
valid=dict(
batch_size=1,
),
),
build_upstream=dict(
name=MISSING,
),
build_featurizer=dict(
layer_selections=None,
normalize=False,
),
build_downstream=dict(
hidden_size=512,
rnn_layers=1,
),
build_model=dict(
upstream_trainable=False,
),
build_optimizer=dict(
name="Adam",
conf=dict(
lr=1.0e-4,
),
),
build_scheduler=dict(
name="ExponentialLR",
gamma=0.9,
),
save_model=dict(
extra_conf=dict(
build_downstream_conf="${build_downstream}"
), # This is redundant for ASR. Just to show how to clone other fields
),
save_task=dict(),
train=dict(
total_steps=30000,
log_step=500,
eval_step=500,
save_step=500,
gradient_clipping=1.0,
gradient_accumulate=4,
valid_metric="der",
valid_higher_better=False,
auto_resume=True,
resume_ckpt_dir=None,
),
scoring=dict(
thresholds=[0.3, 0.4, 0.5, 0.6, 0.7],
median_filters=[1, 11],
),
)
[docs] def prepare_data(
self, prepare_data: dict, target_dir: str, cache_dir: str, get_path_only=False
):
"""
Prepare the task-specific data metadata (path, labels...).
Args:
prepare_data (dict): same in :obj:`default_config`
==================== ====================
key description
==================== ====================
data_dir (str) - the standard Kaldi data directory
==================== ====================
target_dir (str): Parse your corpus and save the csv file into this directory
cache_dir (str): If the parsing or preprocessing takes too long time, you can save
the temporary files into this directory. This directory is expected to be shared
across different training sessions (different hypers and :code:`target_dir`)
get_path_only (str): Directly return the filepaths no matter they exist or not.
Returns:
tuple
1. train_path (str)
2. valid_path (str)
3. test_paths (List[str])
Each path (str) should be a csv file containing the following columns:
==================== ====================
column description
==================== ====================
record_id (str) - the id for the recording
duration (float) - the total seconds of the recording
wav_path (str) - the absolute path of the recording
utt_id (str) - the id for the segmented utterance, should be \
globally unique across all recordings instead of just \
unique in a recording
speaker (str) - the speaker label for the segmented utterance
start_sec (float) - segment start second in the recording
end_sec (float) - segment end second in the recording
==================== ====================
Instead of one waveform file per row, the above file format is one segment per row,
and a waveform file can have multiple overlapped segments uttered by different speakers.
"""
@dataclass
class Config:
data_dir: str
conf = Config(**prepare_data)
target_dir: Path = Path(target_dir)
train_csv = target_dir / "train.csv"
valid_csv = target_dir / "valid.csv"
test_csv = target_dir / "test.csv"
if get_path_only:
return train_csv, valid_csv, [test_csv]
kaldi_dir_to_csv(Path(conf.data_dir) / "train", train_csv)
kaldi_dir_to_csv(Path(conf.data_dir) / "dev", valid_csv)
kaldi_dir_to_csv(Path(conf.data_dir) / "test", test_csv)
return train_csv, valid_csv, [test_csv]
[docs] def build_dataset(
self,
build_dataset: dict,
target_dir: str,
cache_dir: str,
mode: str,
data_csv: str,
data_dir: str,
num_speakers: int,
frame_shift: int,
):
"""
Build the dataset for train/valid/test.
Args:
build_dataset (dict): same in :obj:`default_config`, supports arguments for :obj:`DiarizationDataset`
target_dir (str): Current experiment directory
cache_dir (str): If the preprocessing takes too long time, you can save
the temporary files into this directory. This directory is expected to be shared
across different training sessions (different hypers and :code:`target_dir`)
mode (str): train/valid/test
data_csv (str): The metadata csv file for the specific :code:`mode`
data_dir (str): The converted kaldi data directory from :code:`data_csv`
num_speakers (int): The number of speaker per utterance
frame_shift (int): The frame shift of the upstream model (downsample rate from 16 KHz)
Returns:
torch Dataset
For all train/valid/test mode, the dataset should return each item as a dictionary
containing the following keys:
==================== ====================
key description
==================== ====================
x (torch.FloatTensor) - the waveform in (seq_len, 1)
x_len (int) - the waveform length :code:`seq_len`
label (torch.LongTensor) - the binary label for each upstream frame, \
shape: :code:`(upstream_len, 2)`
label_len (int) - the upstream feature's seq length :code:`upstream_len`
record_id (str) - the unique id for the recording
chunk_id (int) - since recording can be chunked into several segments \
for efficient training, this field indicate the segment's \
original position (order, 0-index) in the recording. This \
field is only useful during the testing stage
==================== ====================
"""
dataset = DiarizationDataset(
mode,
data_dir,
frame_shift=frame_shift,
num_speakers=num_speakers,
**build_dataset,
)
return dataset
[docs] def build_batch_sampler(
self,
build_batch_sampler: dict,
target_dir: str,
cache_dir: str,
mode: str,
data_csv: str,
data_dir: str,
dataset,
):
"""
Return the batch sampler for torch DataLoader.
Args:
build_batch_sampler (dict): same in :obj:`default_config`
==================== ====================
key description
==================== ====================
train (dict) - arguments for :obj:`FixedBatchSizeBatchSampler`
valid (dict) - arguments for :obj:`FixedBatchSizeBatchSampler`
test (dict) - arguments for :obj:`GroupSameItemSampler`, should always \
use this batch sampler for the testing stage
==================== ====================
target_dir (str): Current experiment directory
cache_dir (str): If the preprocessing takes too long time, save
the temporary files into this directory. This directory is expected to be shared
across different training sessions (different hypers and :code:`target_dir`)
mode (str): train/valid/test
data_csv (str): The metadata csv file for the specific :code:`mode`
data_dir (str): The converted kaldi data directory from :code:`data_csv`
dataset: the dataset from :obj:`build_dataset`
Returns:
batch sampler for torch DataLoader
"""
@dataclass
class Config:
train: dict = None
valid: dict = None
conf = Config(**build_batch_sampler)
if mode == "train":
return FixedBatchSizeBatchSampler(dataset, **(conf.train or {}))
elif mode == "valid":
return FixedBatchSizeBatchSampler(dataset, **(conf.valid or {}))
elif mode == "test":
record_ids = get_info(dataset, ["record_id"])
return GroupSameItemSampler(record_ids)
else:
raise ValueError(f"Unsupported mode: {mode}")
[docs] def build_downstream(
self,
build_downstream: dict,
downstream_input_size: int,
downstream_output_size: int,
downstream_input_stride: int,
):
"""
Return the task-specific downstream model.
By default build the :obj:`SuperbDiarizationModel` model
Args:
build_downstream (dict): same in :obj:`default_config`, support arguments of :obj:`SuperbDiarizationModel`
downstream_input_size (int): the required input size of the model
downstream_output_size (int): the required output size of the model
downstream_input_stride (int): the input feature's stride (from 16 KHz)
Returns:
:obj:`s3prl.nn.interface.AbsFrameModel`
"""
return SuperbDiarizationModel(
downstream_input_size, downstream_output_size, **build_downstream
)