"""
Parse the Google Speech Commands V1 corpus
Authors:
* Leo 2022
* Cheng Liang 2022
"""
import hashlib
import logging
import re
from collections import OrderedDict
from pathlib import Path
from typing import List, Tuple, Union
from .base import Corpus
logger = logging.getLogger(__name__)
# Label set used by the Speech Commands V1 benchmark: the ten scored
# command words, plus two catch-all labels — "_unknown_" for any other
# spoken word and "_silence_" for background-noise clips.
CLASSES = [
"yes",
"no",
"up",
"down",
"left",
"right",
"on",
"off",
"stop",
"go",
"_unknown_",
"_silence_",
]
# Public API of this module.
__all__ = [
"SpeechCommandsV1",
]
class SpeechCommandsV1(Corpus):
    """Parser for the Google Speech Commands V1 corpus.

    Args:
        gsc1 (str): root of the training/validation archive
            (``speech_commands_v0.01``); contains one sub-folder per class
            plus ``_background_noise_``
        gsc1_test (str): root of the official testing archive
            (``speech_commands_test_set_v0.01``); contains one sub-folder
            per class
        n_jobs (int): kept for interface compatibility; not used here
    """

    def __init__(self, gsc1: str, gsc1_test: str, n_jobs: int = 4) -> None:
        train_dataset_root = Path(gsc1)
        test_dataset_root = Path(gsc1_test)

        # Deterministically split the training archive into train/valid.
        # The archive's own test portion is discarded because the official
        # test set ships as a separate archive (gsc1_test).
        train_list, valid_list = self.split_dataset(train_dataset_root)
        train_list = self.parse_train_valid_data_list(train_list, train_dataset_root)
        valid_list = self.parse_train_valid_data_list(valid_list, train_dataset_root)
        test_list = self.parse_test_data_list(test_dataset_root)

        self.train = self.list_to_dict(train_list)
        self.valid = self.list_to_dict(valid_list)
        self.test = self.list_to_dict(test_list)

        # Insertion order (train, valid, test) is preserved for all_data.
        self._data = OrderedDict()
        self._data.update(self.train)
        self._data.update(self.valid)
        self._data.update(self.test)

    @staticmethod
    def split_dataset(
        root_dir: Union[str, Path], max_uttr_per_class=2**27 - 1
    ) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]]]:
        """Split Speech Commands into train/valid lists.

        Uses the hash-based assignment from the original Speech Commands
        recipe: the ``_nohash_*`` suffix is stripped so all utterances of
        one speaker land in the same split, then the SHA1 of the remaining
        name is mapped to a percentage bucket.

        Args:
            root_dir: speech commands dataset root dir
            max_uttr_per_class: predefined value in the original paper

        Return:
            train_list: [(class_name, audio_path), ...]
            valid_list: as above
        """
        train_list, valid_list = [], []
        for entry in Path(root_dir).iterdir():
            # Skip files and the noise folder; noise is added separately
            # as "_silence_" by parse_train_valid_data_list.
            if not entry.is_dir() or entry.name == "_background_noise_":
                continue
            for audio_path in entry.glob("*.wav"):
                speaker_hashed = re.sub(r"_nohash_.*$", "", audio_path.name)
                hashed_again = hashlib.sha1(speaker_hashed.encode("utf-8")).hexdigest()
                percentage_hash = (int(hashed_again, 16) % (max_uttr_per_class + 1)) * (
                    100.0 / max_uttr_per_class
                )
                if percentage_hash < 10:
                    valid_list.append((entry.name, audio_path))
                elif percentage_hash < 20:
                    pass  # testing set is discarded (official test archive is used instead)
                else:
                    train_list.append((entry.name, audio_path))
        return train_list, valid_list

    @staticmethod
    def parse_train_valid_data_list(data_list, train_dataset_root: Path):
        """Map raw (class, path) pairs onto the benchmark label set.

        Classes outside CLASSES collapse to "_unknown_"; every clip in
        ``_background_noise_`` is appended once as "_silence_".
        """
        data = [
            (class_name, audio_path)
            if class_name in CLASSES
            else ("_unknown_", audio_path)
            for class_name, audio_path in data_list
        ]
        data += [
            ("_silence_", audio_path)
            for audio_path in Path(train_dataset_root, "_background_noise_").glob(
                "*.wav"
            )
        ]
        return data

    @staticmethod
    def parse_test_data_list(test_dataset_root: Path):
        """Collect (class_name, audio_path) pairs from the test archive.

        The official test archive is already organized one folder per
        final label (including _unknown_ / _silence_), so folder names are
        used verbatim.
        """
        return [
            (class_dir.name, audio_path)
            for class_dir in Path(test_dataset_root).iterdir()
            if class_dir.is_dir()
            for audio_path in class_dir.glob("*.wav")
        ]

    @staticmethod
    def path_to_unique_name(path: str):
        """Return the last two path components ("class_dir/file.wav") as a stable id."""
        return "/".join(Path(path).parts[-2:])

    @classmethod
    def list_to_dict(cls, data_list):
        """Convert [(class_name, audio_path), ...] into {unique_id: info} form."""
        # dict comprehension directly; the original wrapped it in a redundant dict()
        return {
            cls.path_to_unique_name(audio_path): {
                "wav_path": audio_path,
                "class_name": class_name,
            }
            for class_name, audio_path in data_list
        }

    @property
    def all_data(self):
        """
        Return:
            Container: id (str)
                wav_path (str)
                class_name (str)
        """
        return self._data

    @property
    def data_split_ids(self):
        """Return (train_ids, valid_ids, test_ids) keying into all_data."""
        return list(self.train.keys()), list(self.valid.keys()), list(self.test.keys())

    @classmethod
    def download_dataset(cls, tgt_dir: str) -> None:
        """Download and extract both V1 archives under tgt_dir/CORPORA_DIR/.

        Skips an archive whose extraction folder already exists.

        Raises:
            AssertionError: if ``tgt_dir`` does not exist.
        """
        import os
        import tarfile

        import requests

        assert os.path.exists(
            os.path.abspath(tgt_dir)
        ), "Target directory does not exist"

        def unzip_targz_then_delete(filepath: str, filename: str):
            # Extract next to the download, then remove the archive.
            file_path = os.path.join(
                os.path.abspath(tgt_dir), "CORPORA_DIR", filename.replace(".tar.gz", "")
            )
            os.makedirs(file_path)
            with tarfile.open(os.path.abspath(filepath)) as tar:
                # NOTE(review): extractall on a downloaded archive is
                # vulnerable to path traversal for untrusted archives;
                # consider tarfile's extraction filters (Python 3.12+).
                tar.extractall(path=os.path.abspath(file_path))
            os.remove(os.path.abspath(filepath))

        def download_from_url(url: str):
            filename = url.split("/")[-1].replace(" ", "_")
            filepath = os.path.join(tgt_dir, filename)
            r = requests.get(url, stream=True)
            if r.ok:
                # Fix: the original passed the path as a stray positional
                # arg with no %s placeholder, so it never reached the log.
                logger.info("Saving %s to %s", filename, os.path.abspath(filepath))
                with open(filepath, "wb") as f:
                    # Stream in 10 MiB chunks to bound memory usage.
                    for chunk in r.iter_content(chunk_size=1024 * 1024 * 10):
                        if chunk:
                            f.write(chunk)
                            f.flush()
                            os.fsync(f.fileno())
                logger.info("%s successfully downloaded", filename)
                unzip_targz_then_delete(filepath, filename)
            else:
                logger.info(f"Download failed: status code {r.status_code}\n{r.text}")

        if not os.path.exists(
            os.path.join(os.path.abspath(tgt_dir), "CORPORA_DIR/speech_commands_v0.01")
        ):
            download_from_url(
                "http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz"
            )
        if not os.path.exists(
            os.path.join(
                os.path.abspath(tgt_dir), "CORPORA_DIR/speech_commands_test_set_v0.01"
            )
        ):
            download_from_url(
                "http://download.tensorflow.org/data/speech_commands_test_set_v0.01.tar.gz"
            )
        logger.info(
            f"Speech commands dataset downloaded. Located at {os.path.abspath(tgt_dir)}/CORPORA_DIR/"
        )