"""
Authors:
- Leo (2022)
"""
from typing import Any, Iterator, List, Tuple
import pandas as pd
import torch
from .base import Dataset
from .load_audio import LoadAudio
__all__ = [
"chunking",
"scale_labels_secs",
"get_chunk_labels",
"chunk_labels_to_frame_tensor_label",
"FrameLabelDataset",
]


def chunking(
start_sec: float,
end_sec: float,
chunk_secs: float,
step_secs: float,
use_unfull_chunks: bool = True,
) -> Iterator[Tuple[float, float]]:
"""
Produce chunks (start, end points) from a given start, end seconds
Args:
start_sec (float): The start second of the utterance
end_sec (float): The end second of the utterance
chunk_secs (float): The length (in seconds) of a chunked chunk
step_secs (float): The stride seconds between chunks
use_unfull_chunks (bool): Whether to produce chunks shorter than :code:`chunk_secs`
at the end of the recording
Returns:
List[Tuple[float, float]]: Each tuple describes the starting point (in sec)
and the ending point (in sec) of each chunk in order
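
    Example:
        A small illustration of the sliding-window behaviour, including the
        final unfull chunk:

        >>> list(chunking(0.0, 2.5, chunk_secs=1.0, step_secs=1.0))
        [(0.0, 1.0), (1.0, 2.0), (2.0, 2.5)]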
"""
    assert step_secs > 0, f"step_secs must be positive, got {step_secs}"
    start, end = start_sec, end_sec
    while end - start > 0:
        if end - start >= chunk_secs:
            # A full-length chunk still fits in the remaining span
            yield start, start + chunk_secs
        elif use_unfull_chunks:
            # The remaining span is shorter than chunk_secs
            yield start, end
        start = start + step_secs


def scale_labels_secs(
    labels: List[Tuple[Any, float, float]], ratio: float
) -> List[Tuple[Any, float, float]]:
"""
When the recording length is changed due to like pitch or speed manipulation,
the start/end timestamp (in seconds) should also be changed
Args:
labels (List[Tuple[Any, float, float]]): each chunk label is in (label, start_sec, end_sec)
ratio (float): the scaling ratio
Returns:
List[Tuple[Any, float, float]]: the scaled labels
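
    Example:
        Speeding a recording up to twice the speed halves every timestamp
        (ratio 0.5):

        >>> scale_labels_secs([("speech", 1.0, 2.0)], ratio=0.5)
        [('speech', 0.5, 1.0)]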
"""
assert ratio > 0
return [(label, start * ratio, end * ratio) for label, start, end in labels]


def get_chunk_labels(
    start_sec: float,
    end_sec: float,
    labels: List[Tuple[Any, float, float]],
) -> Iterator[Tuple[Any, float, float]]:
"""
Given a pair a start, end points, filter out the relevant labels from the given :code:`labels`
and refine the start/end points of each label to reside between :code:`start_sec` and :code:`end_sec`
Args:
start_sec (float): the starting point
end_sec (float): the ending point
labels (List[Tuple[Any, float, float]]): the chunk labels
Returns:
List[Tuple[str, float, float]]: filtered labels. Only the labels relevant to the assigned
start/end point are left
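
    Example:
        Labels outside the span are dropped; overlapping ones are clipped:

        >>> labels = [("a", 0.0, 2.0), ("b", 2.5, 4.0), ("c", 5.0, 6.0)]
        >>> list(get_chunk_labels(1.0, 3.0, labels))
        [('a', 1.0, 2.0), ('b', 2.5, 3.0)]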
"""
for label, start, end in labels:
assert start < end, f"start ({start}) >= end ({end})"
if start >= end_sec:
continue
if end <= start_sec:
continue
yield label, max(start_sec, start), min(end_sec, end)


def chunk_labels_to_frame_tensor_label(
    start_sec: float,
    end_sec: float,
    labels: List[Tuple[int, float, float]],
    num_class: int,
    frame_shift: int,
    sample_rate: int = 16000,
) -> torch.Tensor:
"""
Produce frame-level labels for the given chunk labels
Args:
start_sec (float): the starting point of the chunk
end_sec (float): the ending point of the chunk
labels (List[Tuple[int, float, float]]): the chunk labels, each label is a tuple
in (class_id, start_sec, end_sec)
num_class (int): number of classes
frame_shift (int): produce a frame per :code:`frame_shift` samples
sample_rate (int): the sample rate of the recording. default: 16000
Returns:
torch.FloatTensor: shape (num_frames, num_class).
the binary frame labels for the given :code:`labels`
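
    Example:
        With a 160-sample shift at 16 kHz (one frame per 10 ms), a 1-second chunk
        yields 100 frames, and a label covering 0.0-0.5 sec activates frames
        0 through 50 (the end frame is inclusive):

        >>> y = chunk_labels_to_frame_tensor_label(
        ...     0.0, 1.0, [(0, 0.0, 0.5)], num_class=2, frame_shift=160
        ... )
        >>> tuple(y.shape), int(y[:, 0].sum())
        ((100, 2), 51)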
"""
labels = get_chunk_labels(start_sec, end_sec, labels)
duration = end_sec - start_sec
num_frames = len(range(0, round(duration * sample_rate), frame_shift))
frame_labels = torch.zeros(num_frames, num_class)
for class_id, start, end in labels:
assert start >= start_sec, f"{start} < {start_sec}"
assert end >= start_sec, f"{end} < {start_sec}"
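        # Convert seconds to frame indices relative to the chunk start; the end frame is inclusive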
start_frame = round((start - start_sec) * sample_rate) // frame_shift
end_frame = round((end - start_sec) * sample_rate) // frame_shift
frame_labels[start_frame : end_frame + 1, class_id] = 1.0
return frame_labels


class FrameLabelDataset(Dataset):
"""
Args:
df (pd.DataFrame): the dataframe should have the following columns
record_id (str), wav_path (str), duration (float), utt_id (str),
label (int), start_sec (float), end_sec (float)
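
    Example:
        The dataframe has one row per labelled segment, with the recording-level
        columns repeated across rows (the paths below are hypothetical):

        >>> df = pd.DataFrame({
        ...     "record_id": ["rec1", "rec1"],
        ...     "wav_path": ["/data/rec1.wav", "/data/rec1.wav"],
        ...     "duration": [10.0, 10.0],
        ...     "utt_id": ["utt1", "utt2"],
        ...     "label": [0, 1],
        ...     "start_sec": [0.0, 4.0],
        ...     "end_sec": [3.0, 9.0],
        ... })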
"""

    def __init__(
self,
df: pd.DataFrame,
num_class: int,
frame_shift: int,
chunk_secs: float,
step_secs: float,
use_unfull_chunks: bool = True,
load_audio_conf: dict = None,
sample_rate: int = 16000,
) -> None:
super().__init__()
self.num_class = num_class
self.frame_shift = frame_shift
self.sample_rate = sample_rate
recording_df = df[["record_id", "wav_path", "duration"]].drop_duplicates()
record_ids = recording_df["record_id"].tolist()
record_to_labels = {}
for record_id in record_ids:
subset_df = df[df["record_id"] == record_id]
labels = list(
zip(subset_df["label"], subset_df["start_sec"], subset_df["end_sec"])
)
record_to_labels[record_id] = labels
self.chunked_utts = []
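        # Slide a window over each recording and attach the labels clipped to every chunk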
for _, row in recording_df.iterrows():
chunks = chunking(
0.0,
row["duration"],
chunk_secs,
step_secs,
use_unfull_chunks,
)
for chunk_id, (start, end) in enumerate(chunks):
labels = list(
get_chunk_labels(start, end, record_to_labels[row["record_id"]])
)
self.chunked_utts.append(
{
"record_id": row["record_id"],
"chunk_id": chunk_id,
"wav_path": row["wav_path"],
"start_sec": start,
"end_sec": end,
"unique_name": f"{row['record_id']}-{start}-{end}",
"labels": labels,
}
)

        def flatten(data: List[dict], key: str) -> list:
return [item[key] for item in data]
self.audio_loader = LoadAudio(
flatten(self.chunked_utts, "wav_path"),
flatten(self.chunked_utts, "start_sec"),
flatten(self.chunked_utts, "end_sec"),
**(load_audio_conf or {}),
)

    def __len__(self) -> int:
return len(self.chunked_utts)

    def getinfo(self, index: int) -> dict:
return self.chunked_utts[index]

    def __getitem__(self, index: int) -> dict:
info = self.getinfo(index)
audio = self.audio_loader[index]
label = chunk_labels_to_frame_tensor_label(
info["start_sec"],
info["end_sec"],
info["labels"],
self.num_class,
self.frame_shift,
self.sample_rate,
)
return {
"x": audio["wav"],
"x_len": audio["wav_len"],
"y": label,
"y_len": len(label),
"unique_name": info["unique_name"],
"labels": info["labels"],
"record_id": info["record_id"],
"chunk_id": info["chunk_id"],
}