"""
Parse VoxCeleb1 corpus for classification
Authors:
* Leo 2022
* Cheng Liang 2022
"""
import logging
import os
from collections import defaultdict
from pathlib import Path
from typing import List
from filelock import FileLock
from joblib import Parallel, delayed
from tqdm import tqdm
from .base import Corpus
logger = logging.getLogger(__name__)
SPLIT_FILE_URL = "https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/iden_split.txt"
CACHE_ROOT = Path.home() / ".cache" / "s3prl"
__all__ = [
"VoxCeleb1SID",
]
class VoxCeleb1SID(Corpus):
    """VoxCeleb1 corpus parser for Speaker Identification (SID).

    Resolves every utterance listed in the official identification split
    file (``iden_split.txt``) to its wav path under ``dataset_root`` and
    assigns it a ``speaker_<index>`` classification label derived from the
    VoxCeleb speaker id.

    Args:
        dataset_root: root directory of the extracted VoxCeleb1 corpus
        n_jobs: number of parallel workers used to locate wav files
        cache_root: directory where the downloaded split file is cached
    """

    def __init__(
        self, dataset_root: str, n_jobs: int = 4, cache_root: str = CACHE_ROOT
    ) -> None:
        self.dataset_root = Path(dataset_root).resolve()
        uid2split = self._get_standard_usage(self.dataset_root, cache_root)

        # Group flattened utterance ids ('/' replaced by '-', extension
        # dropped) by their official split.
        self._split2uids = defaultdict(list)
        for uid, split in uid2split.items():
            self._split2uids[split].append(Path(uid.replace("/", "-")).stem)

        uid2wavpath = self._find_wavs_with_uids(
            self.dataset_root, sorted(uid2split.keys()), n_jobs=n_jobs
        )
        self._data = {
            Path(uid.replace("/", "-")).stem: {
                "wav_path": uid2wavpath[uid],
                "label": self._build_label(uid),
            }
            for uid in uid2split.keys()
        }

    @property
    def all_data(self):
        """dict: flattened uid -> {"wav_path": Path, "label": str} for all utterances."""
        return self._data

    @property
    def data_split_ids(self):
        """Return ``(train_ids, valid_ids, test_ids)`` per the official split."""
        return (
            self._split2uids["train"],
            self._split2uids["valid"],
            self._split2uids["test"],
        )

    @staticmethod
    def _get_standard_usage(dataset_root: Path, cache_root: Path):
        """Download (once) and parse the official ``iden_split.txt``.

        Returns:
            dict: utterance id (corpus-relative path) -> split name in
            ``{"train", "valid", "test"}``
        """
        split_filename = SPLIT_FILE_URL.split("/")[-1]
        split_filepath = Path(cache_root) / split_filename
        # wget -O fails if the cache directory is missing; create it first.
        split_filepath.parent.mkdir(parents=True, exist_ok=True)

        # Check existence INSIDE the lock so concurrent processes trigger
        # at most one download (checking outside the lock is racy).
        with FileLock(str(split_filepath) + ".lock"):
            if not split_filepath.is_file():
                os.system(f"wget {SPLIT_FILE_URL} -O {str(split_filepath)}")

        with open(split_filepath, "r") as split_file:
            standard_usage = [line.strip().split(" ") for line in split_file]

        def code2split(code: int):
            # iden_split.txt codes: 1 = train, 2 = valid, 3 = test
            splits = ["train", "valid", "test"]
            return splits[code - 1]

        standard_usage = {uid: code2split(int(split)) for split, uid in standard_usage}
        return standard_usage

    @staticmethod
    def _find_wavs_with_uids(dataset_root, uids, n_jobs=4):
        """Resolve each relative uid to the single matching wav path on disk."""

        def find_wav_with_uid(uid):
            # The uid is relative to a "<subset>/wav/" directory; exactly
            # one match is expected across subsets.
            found_wavs = list(dataset_root.glob(f"*/wav/{uid}"))
            assert len(found_wavs) == 1
            return uid, found_wavs[0]

        uids_with_wavs = Parallel(n_jobs=n_jobs)(
            delayed(find_wav_with_uid)(uid) for uid in tqdm(uids, desc="Search wavs")
        )
        uids2wav = {uid: wav for uid, wav in uids_with_wavs}
        return uids2wav

    @staticmethod
    def _build_label(uid):
        """Map a uid like ``id10003/.../00001.wav`` to a zero-based label.

        VoxCeleb1 speaker ids start at ``id10001``, so ``id10001`` becomes
        ``speaker_0``, ``id10002`` becomes ``speaker_1``, and so on.
        """
        id_string = uid.split("/")[0]
        label = f"speaker_{int(id_string[2:]) - 10001}"
        return label

    @classmethod
    def download_dataset(
        cls, target_dir: str, splits: List[str] = None
    ) -> None:
        """Download and extract VoxCeleb1 under ``target_dir/Voxceleb1``.

        Args:
            target_dir: existing directory to place the corpus under
            splits: subset of ``["dev", "test"]`` to download;
                ``None`` (default) downloads both

        Raises:
            FileNotFoundError: if ``target_dir`` does not exist
        """
        # None sentinel avoids a mutable default argument.
        splits = ["dev", "test"] if splits is None else splits
        tgt_dir = os.path.abspath(target_dir)
        if not os.path.exists(tgt_dir):
            # Explicit raise instead of `assert`: asserts vanish under -O.
            raise FileNotFoundError(f"Target directory does not exist: {tgt_dir}")

        from zipfile import ZipFile

        import requests

        def unzip_then_delete(filepath: str, split: str):
            # Extract into Voxceleb1/<split>/ then discard the archive.
            assert os.path.exists(filepath), "File not found!"
            with ZipFile(filepath) as zipf:
                zipf.extractall(path=os.path.join(tgt_dir, "Voxceleb1", split))
            os.remove(os.path.abspath(filepath))

        def download_from_url(url: str, split: str):
            filename = url.split("/")[-1].replace(" ", "_")
            filepath = os.path.join(tgt_dir, filename)
            r = requests.get(url, stream=True)
            if r.ok:
                # Lazy %-style args: the broken f-string + positional arg
                # in the original never logged the path.
                logger.info("Saving %s to %s", filename, filepath)
                with open(filepath, "wb") as f:
                    # Stream in 10 MB chunks to bound memory usage.
                    for chunk in r.iter_content(chunk_size=1024 * 1024 * 10):
                        if chunk:
                            f.write(chunk)
                            f.flush()
                            os.fsync(f.fileno())
                logger.info("%s successfully downloaded", filename)
            else:
                logger.error(
                    "Download failed: status code %s\n%s", r.status_code, r.text
                )
            return filepath

        def download_dev():
            # The dev set ships as four zip parts that must be concatenated
            # into a single archive before extraction.
            partpaths = []
            for part in ["a", "b", "c", "d"]:
                partpath = os.path.join(tgt_dir, f"vox1_dev_wav_parta{part}")
                if os.path.exists(partpath):
                    logger.info("vox1_dev_wav_parta%s exists, skip download", part)
                    partpaths.append(partpath)
                    continue
                fp = download_from_url(
                    f"https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_parta{part}",
                    "dev",
                )
                partpaths.append(fp)

            zippath = os.path.join(tgt_dir, "vox1_dev_wav.zip")
            with open(zippath, "wb") as outfile:
                for partpath in partpaths:
                    with open(partpath, "rb") as infile:
                        # Copy fixed-size binary chunks; iterating "lines"
                        # of a binary file (the original) is fragile and
                        # can buffer arbitrarily large pieces.
                        for chunk in iter(lambda: infile.read(1024 * 1024), b""):
                            outfile.write(chunk)
            for partpath in partpaths:
                os.remove(partpath)
            unzip_then_delete(zippath, "dev")

        for split in splits:
            if not os.path.exists(os.path.join(tgt_dir, "Voxceleb1", split, "wav")):
                if split == "dev":
                    download_dev()
                else:
                    filepath = download_from_url(
                        "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_test_wav.zip",
                        "test",
                    )
                    unzip_then_delete(filepath, "test")
        logger.info("Voxceleb1 dataset downloaded. Located at %s/Voxceleb1/", tgt_dir)