"""
Parse the LibriSpeech corpus
Authors:
* Heng-Jui Chang 2022
"""
import logging
import os
import re
from collections import OrderedDict, defaultdict
from pathlib import Path
from typing import Any, Dict, List
from joblib import Parallel, delayed
from .base import Corpus
logger = logging.getLogger(__name__)
LIBRI_SPLITS = [
"train-clean-100",
"train-clean-360",
"train-other-500",
"dev-clean",
"dev-other",
"test-clean",
"test-other",
]
__all__ = [
"LibriSpeech",
]
def read_text(file: Path) -> str:
    """Return the transcription of a single LibriSpeech utterance.

    LibriSpeech stores one ``<speaker>-<chapter>.trans.txt`` file per chapter
    directory; each line is ``<utterance id> <transcription>``.

    Args:
        file (Path): Path to a ``.flac`` utterance file,
            e.g. ``.../19-198-0000.flac``.

    Returns:
        str: The transcription, or None when no matching line is found
            (a warning is logged in that case).
    """
    # ".../<spk>-<chap>-<utt>.flac" -> ".../<spk>-<chap>.trans.txt"
    src_file = "-".join(str(file).split("-")[:-1]) + ".trans.txt"
    idx = file.stem  # Path.stem already drops the ".flac" suffix

    with open(src_file, "r") as fp:
        for line in fp:
            # Use strip() rather than line[:-1]: the last line of a file may
            # have no trailing newline, and slicing would chop a real character.
            key, _, text = line.strip().partition(" ")
            if key == idx:
                return text
    logger.warning(f"Transcription of {file} not found!")
def check_no_repeat(splits: List[str]) -> bool:
    """Verify that no corpus split name appears more than once.

    Args:
        splits (List[str]): Split names to check.

    Returns:
        bool: True when all names are unique; False otherwise (a warning
            listing the duplicates is logged).
    """
    occurrences = defaultdict(int)
    for name in splits:
        occurrences[name] += 1

    repeated = "".join(
        f" {name} ({times} times)"
        for name, times in occurrences.items()
        if times > 1
    )
    if repeated:
        logger.warning(
            f"Found repeated splits in corpus: {repeated}, which might cause unexpected behaviors."
        )
        return False
    return True
def _parse_spk_to_gender(speaker_file: Path) -> dict:
speaker_file = Path(speaker_file)
with speaker_file.open() as file:
lines = [line.strip() for line in file.readlines()]
for line_id in range(len(lines)):
line = lines[line_id]
if "SEX" in line and "SUBSET" in line and "MINUTES" in line and "NAME" in line:
break
line_id += 1 # first line with speaker info
spk2gender = {}
for line_id in range(line_id, len(lines)):
line = lines[line_id]
line = re.sub("\t+", " ", line)
line = re.sub(" +", " ", line)
parts = line.split("|", maxsplit=4)
ID, SEX, SUBSET, MINUTES, NAME = parts
spk2gender[int(ID)] = SEX.strip()
return spk2gender
class LibriSpeech(Corpus):
    """LibriSpeech Corpus

    Link: https://www.openslr.org/12

    Args:
        dataset_root (str): Path to LibriSpeech corpus directory.
        n_jobs (int, optional): Number of jobs. Defaults to 4.
        train_split (List[str], optional): Training splits. Defaults to ["train-clean-100"].
        valid_split (List[str], optional): Validation splits. Defaults to ["dev-clean"].
        test_split (List[str], optional): Testing splits. Defaults to ["test-clean"].
    """

    def __init__(
        self,
        dataset_root: str,
        n_jobs: int = 4,
        train_split: List[str] = None,
        valid_split: List[str] = None,
        test_split: List[str] = None,
    ) -> None:
        # Mutable lists are unsafe as default arguments; None is the sentinel
        # and is replaced with the documented defaults here (backward
        # compatible with the previous literal-list defaults).
        train_split = ["train-clean-100"] if train_split is None else train_split
        valid_split = ["dev-clean"] if valid_split is None else valid_split
        test_split = ["test-clean"] if test_split is None else test_split

        self.dataset_root = Path(dataset_root).resolve()
        self.train_split = train_split
        self.valid_split = valid_split
        self.test_split = test_split
        self.all_splits = train_split + valid_split + test_split
        assert check_no_repeat(self.all_splits)

        self.data_dict = self._collect_data(dataset_root, self.all_splits, n_jobs)
        self.train = self._data_to_dict(self.data_dict, train_split)
        self.valid = self._data_to_dict(self.data_dict, valid_split)
        self.test = self._data_to_dict(self.data_dict, test_split)

        # Preserve train -> valid -> test sample ordering.
        self._data = OrderedDict()
        self._data.update(self.train)
        self._data.update(self.valid)
        self._data.update(self.test)

    def get_corpus_splits(self, splits: List[str]):
        """Return the sample dict restricted to the given corpus splits."""
        return self._data_to_dict(self.data_dict, splits)

    @property
    def all_data(self):
        """
        Return all the data points in a dict of the format

        .. code-block:: yaml

            data_id1:
                wav_path: (str) The waveform path
                transcription: (str) The transcription
                speaker: (str) The speaker name
                gender: (str) The speaker's gender
                corpus_split: (str) The split of corpus this sample belongs to

            data_id2:
                ...
        """
        return self._data

    @property
    def data_split_ids(self):
        """Return the sample ids of the train / valid / test splits, in order."""
        return (
            list(self.train.keys()),
            list(self.valid.keys()),
            list(self.test.keys()),
        )

    @staticmethod
    def _collect_data(
        dataset_root: str, splits: List[str], n_jobs: int = 4
    ) -> Dict[str, Dict[str, List[Any]]]:
        """Collect utterance metadata for every downloaded split.

        Args:
            dataset_root (str): Path to LibriSpeech corpus directory.
            splits (List[str]): Splits to collect; missing ones are skipped.
            n_jobs (int, optional): Parallel jobs for reading transcriptions.

        Returns:
            Dict[str, Dict[str, List[Any]]]: Per-split parallel lists
                (names, wav paths, transcriptions, speaker ids, genders),
                sorted by utterance name.
        """
        spkr2gender = _parse_spk_to_gender(Path(dataset_root) / "SPEAKERS.TXT")

        data_dict = {}
        for split in splits:
            split_dir = os.path.join(dataset_root, split)
            if not os.path.exists(split_dir):
                logger.info(f"Split {split} is not downloaded. Skip data collection.")
                continue

            wav_list = list(Path(split_dir).rglob("*.flac"))
            if len(wav_list) == 0:
                # Guard: unpacking zip(*...) below fails on an empty list.
                logger.info(f"Split {split} has no .flac files. Skip data collection.")
                continue

            name_list = [file.stem for file in wav_list]  # stem already drops ".flac"
            text_list = Parallel(n_jobs=n_jobs)(
                delayed(read_text)(file) for file in wav_list
            )
            # Utterance names look like "<speaker>-<chapter>-<utterance>".
            spkr_list = [int(name.split("-")[0]) for name in name_list]

            # Sort all parallel lists by utterance name for deterministic order.
            wav_list, name_list, text_list, spkr_list = zip(
                *sorted(
                    zip(wav_list, name_list, text_list, spkr_list),
                    key=lambda x: x[1],
                )
            )

            data_dict[split] = {
                "name_list": list(name_list),
                "wav_list": list(wav_list),
                "text_list": list(text_list),
                "spkr_list": list(spkr_list),
                "gender_list": [spkr2gender[spkr] for spkr in spkr_list],
            }

        return data_dict

    @staticmethod
    def _data_to_dict(
        data_dict: Dict[str, Dict[str, List[Any]]], splits: List[str]
    ) -> dict:
        """Flatten the per-split parallel lists into a sample-id keyed dict.

        Splits that `_collect_data` skipped (not downloaded / empty) are
        ignored here as well instead of raising a KeyError.
        """
        return {
            name: {
                "wav_path": data_dict[split]["wav_list"][i],
                "transcription": data_dict[split]["text_list"][i],
                "speaker": data_dict[split]["spkr_list"][i],
                "gender": data_dict[split]["gender_list"][i],
                "corpus_split": split,
            }
            for split in splits
            if split in data_dict
            for i, name in enumerate(data_dict[split]["name_list"])
        }

    @classmethod
    def download_dataset(
        cls,
        target_dir: str,
        splits: List[str] = None,
    ) -> None:
        """Download and extract LibriSpeech splits from OpenSLR.

        Args:
            target_dir (str): Directory to store the corpus (created if missing).
            splits (List[str], optional): Splits to download. Defaults to
                ["train-clean-100", "dev-clean", "test-clean"].
        """
        import os
        import tarfile

        import requests

        if splits is None:
            splits = ["train-clean-100", "dev-clean", "test-clean"]

        target_dir = Path(target_dir)
        target_dir.mkdir(exist_ok=True, parents=True)

        def unzip_targz_then_delete(filepath: str):
            # Extract the archive in place, then remove it to save disk space.
            # NOTE(review): extractall is vulnerable to path traversal for
            # untrusted archives; OpenSLR is assumed trusted here.
            with tarfile.open(os.path.abspath(filepath)) as tar:
                tar.extractall(path=os.path.abspath(target_dir))
            os.remove(os.path.abspath(filepath))

        def download_from_url(url: str):
            filename = url.split("/")[-1].replace(" ", "_")
            filepath = os.path.join(target_dir, filename)

            r = requests.get(url, stream=True)
            if r.ok:
                logger.info(f"Saving {filename} to {os.path.abspath(filepath)}")
                with open(filepath, "wb") as f:
                    # Stream in 10 MB chunks to bound memory usage.
                    for chunk in r.iter_content(chunk_size=1024 * 1024 * 10):
                        if chunk:
                            f.write(chunk)
                            f.flush()
                            os.fsync(f.fileno())
                logger.info(f"{filename} successfully downloaded")
                unzip_targz_then_delete(filepath)
            else:
                logger.info(f"Download failed: status code {r.status_code}\n{r.text}")

        for split in splits:
            # The OpenSLR archives extract to "<target_dir>/LibriSpeech/<split>".
            # (The previous check used "Librispeech", which never matches on
            # case-sensitive filesystems and forced a re-download every call.)
            if not os.path.exists(
                os.path.join(os.path.abspath(target_dir), "LibriSpeech/" + split)
            ):
                download_from_url(
                    "https://www.openslr.org/resources/12/" + split + ".tar.gz"
                )
        logger.info(
            ", ".join(splits)
            + f" downloaded. Located at {os.path.abspath(target_dir)}/LibriSpeech/"
        )