# Source code for s3prl.dataio.dataset.diarization

# -*- coding: utf-8 -*- #
"""*********************************************************************************************"""
#   FileName     [ dataset.py ]
#   Synopsis     [ the speaker diarization dataset ]
#   Source       [ Refactored from https://github.com/hitachi-speech/EEND ]
#   Author       [ Jiatong Shi ]
#   Copyright    [ Copyright(c), Johns Hopkins University ]
"""*********************************************************************************************"""


import chunk
import io
import os
import subprocess
import sys

import numpy as np
import pandas as pd
import soundfile as sf

from . import Dataset


def _count_frames(data_len, size, step):
    # no padding at edges, last remaining samples are ignored
    return int((data_len - size + step) / step)


def _gen_frame_indices(
    data_length,
    size=2000,
    step=2000,
    use_last_samples=False,
    label_delay=0,
    subsampling=1,
):
    """Yield ``(start, end)`` index pairs of successive windows.

    Full windows of ``size`` are produced every ``step`` first.  When
    ``use_last_samples`` is set and the full windows stop short of
    ``data_length``, one extra pair running up to ``data_length`` is
    produced, provided the remainder is longer than
    ``subsampling * label_delay``.
    """
    # Stays at -1 when no full window fits, so the tail starts at index 0.
    last = -1
    for last in range(_count_frames(data_length, size, step)):
        begin = last * step
        yield begin, begin + size

    if not use_last_samples:
        return

    covered_until = last * step + size
    tail_begin = (last + 1) * step
    if covered_until < data_length:
        if data_length - tail_begin - subsampling * label_delay > 0:
            yield tail_begin, data_length


def _gen_chunk_indices(data_len, chunk_size):
    step = chunk_size
    start = 0
    while start < data_len:
        end = min(data_len, start + chunk_size)
        yield start, end
        start += step


class DiarizationDataset(Dataset):
    """Chunked speaker-diarization dataset over a Kaldi-style directory.

    Every recording found by ``KaldiData`` is cut into chunks of
    ``chunk_size`` frames.  Item ``i`` holds the waveform of one chunk and
    its frame-level speaker-activity labels.

    Args:
        mode: stored as ``self.mode``; not used further in this class
        data_dir (str): Kaldi-style data directory, loaded with ``KaldiData``
        chunk_size (int): chunk length in (subsampled) frames
        frame_shift (int): number of waveform samples per frame
        subsampling (int): frame subsampling factor
        rate (int): sampling rate used to turn ``reco2dur`` durations into
            frame counts
        use_last_samples (bool): also emit the trailing, shorter chunk
        label_delay (int): forwarded to the chunk index generator
        num_speakers (int): number of label columns; if None it is the
            number of speakers found in each recording
    """

    def __init__(
        self,
        mode,
        data_dir,
        chunk_size=2000,
        frame_shift=256,
        subsampling=1,
        rate=16000,
        use_last_samples=True,
        label_delay=0,
        num_speakers=None,
    ):
        super().__init__()

        self.mode = mode
        self.data_dir = data_dir
        self.dtype = np.float32
        self.chunk_size = chunk_size
        self.frame_shift = frame_shift
        self.subsampling = subsampling
        self.n_speakers = num_speakers
        self.chunk_indices = []
        self.label_delay = label_delay

        self.data = KaldiData(self.data_dir)

        # make chunk indices: recording id, chunk id, start_frame, end_frame
        for rec in self.data.wavs:
            data_len = int(self.data.reco2dur[rec] * rate / frame_shift)
            data_len = int(data_len / self.subsampling)
            frame_ranges = _gen_frame_indices(
                data_len,
                chunk_size,
                chunk_size,
                use_last_samples,
                label_delay=self.label_delay,
                subsampling=self.subsampling,
            )
            for chunk_id, (st, ed) in enumerate(frame_ranges):
                # Stored boundaries are scaled back to un-subsampled frames.
                self.chunk_indices.append(
                    (rec, chunk_id, st * self.subsampling, ed * self.subsampling)
                )

    def __len__(self):
        """Return the total number of chunks over all recordings."""
        return len(self.chunk_indices)

    def getinfo(self, i):
        """Return the recording id and chunk id of item ``i``."""
        rec, chunk_id, _, _ = self.chunk_indices[i]
        return {
            "record_id": rec,
            "chunk_id": chunk_id,
        }

    def __getitem__(self, i):
        """Return the waveform, labels, lengths and ids of chunk ``i``."""
        rec, chunk_id, st, ed = self.chunk_indices[i]
        X, T = self._get_labeled_speech(rec, st, ed, self.n_speakers)

        return {
            "x": X,
            "x_len": len(X),
            "label": T,
            "label_len": len(T),
            "record_id": rec,
            "chunk_id": chunk_id,
        }

    def _get_labeled_speech(
        self, rec, start, end, n_speakers=None, use_speaker_id=False
    ):
        """Extracts speech chunks and corresponding labels

        Extracts speech chunks and corresponding diarization labels for
        given recording id and start/end times

        Args:
            rec (str): recording id
            start (int): start frame index
            end (int): end frame index
            n_speakers (int): number of speakers
                if None, the value is given from data
            use_speaker_id (bool): additionally return labels indexed by
                the sorted speaker ids of ``spk2utt``
        Returns:
            data: speech chunk
                (n_samples)
            T: label
                (n_frames, n_speakers)-shaped np.int32 array.
            S: only when ``use_speaker_id`` is set,
                (n_frames, number of ``spk2utt`` speakers)-shaped np.int32
                array.
        """
        data, rate = self.data.load_wav(
            rec, start * self.frame_shift, end * self.frame_shift
        )
        frame_num = end - start

        # A recording without any segment contains no speech: its labels stay
        # all-zero instead of failing with a KeyError.
        filtered_segments = self.data.segments.get(rec, [])
        speakers = np.unique(
            [self.data.utt2spk[seg["utt"]] for seg in filtered_segments]
        ).tolist()
        if n_speakers is None:
            n_speakers = len(speakers)
        T = np.zeros((frame_num, n_speakers), dtype=np.int32)
        # Column lookup built once instead of list.index() per segment.
        speaker_col = {spk: col for col, spk in enumerate(speakers)}

        if use_speaker_id:
            all_speakers = sorted(self.data.spk2utt.keys())
            all_speaker_col = {spk: col for col, spk in enumerate(all_speakers)}
            S = np.zeros((frame_num, len(all_speakers)), dtype=np.int32)

        for seg in filtered_segments:
            spk = self.data.utt2spk[seg["utt"]]
            start_frame = int(np.rint(seg["st"] * rate / self.frame_shift))
            end_frame = int(np.rint(seg["et"] * rate / self.frame_shift))

            # Clip the segment to the chunk.  This also labels a segment that
            # starts before the chunk and ends after it, which has neither
            # boundary inside [start, end].
            rel_start = max(start_frame, start) - start
            rel_end = min(end_frame, end) - start
            if rel_start >= rel_end:
                # no overlap between this segment and the chunk
                continue

            T[rel_start:rel_end, speaker_col[spk]] = 1
            if use_speaker_id:
                S[rel_start:rel_end, all_speaker_col[spk]] = 1

        if use_speaker_id:
            return data, T, S
        return data, T
#######################
# Kaldi-style Dataset #
#######################
class KaldiData:
    """This class holds data in kaldi-style directory.

    On construction the ``segments``, ``utt2spk``, ``wav.scp``, ``reco2dur``
    and ``spk2utt`` files of ``data_dir`` are read into dictionaries.
    ``utt2spk`` must exist; ``segments``, ``reco2dur`` and ``spk2utt`` load
    as None when their file is missing; without ``wav.scp`` the files in
    ``<data_dir>/wav`` are listed instead.
    """

    def __init__(self, data_dir):
        """Load kaldi data directory."""
        self.data_dir = data_dir
        self.segments = self._load_segments_rechash(
            os.path.join(self.data_dir, "segments")
        )
        self.utt2spk = self._load_utt2spk(os.path.join(self.data_dir, "utt2spk"))
        self.wavs = self._load_wav_scp(os.path.join(self.data_dir, "wav.scp"))
        self.reco2dur = self._load_reco2dur(os.path.join(self.data_dir, "reco2dur"))
        self.spk2utt = self._load_spk2utt(os.path.join(self.data_dir, "spk2utt"))

    def load_wav(self, recid, start=0, end=None):
        """Load wavfile given recid, start sample and end sample."""
        data, rate = self._load_wav(self.wavs[recid], start, end)
        return data, rate

    @staticmethod
    def _read_fields(path, maxsplit=-1):
        """Yield the whitespace-split fields of each non-blank line of a file.

        The file is closed as soon as it has been read completely.
        """
        with open(path) as f:
            for line in f:
                fields = line.strip().split(None, maxsplit)
                if fields:
                    yield fields

    def _load_segments(self, segments_file):
        """Load segments file as array."""
        if not os.path.exists(segments_file):
            return None
        return np.loadtxt(
            segments_file,
            dtype=[("utt", "object"), ("rec", "object"), ("st", "f"), ("et", "f")],
            ndmin=1,
        )

    def _load_segments_hash(self, segments_file):
        """Load segments file as dict with uttid index."""
        if not os.path.exists(segments_file):
            return None
        return {
            utt: (rec, float(st), float(et))
            for utt, rec, st, et in self._read_fields(segments_file)
        }

    def _load_segments_rechash(self, segments_file):
        """Load segments file as dict with recid index."""
        if not os.path.exists(segments_file):
            return None
        ret = {}
        for utt, rec, st, et in self._read_fields(segments_file):
            ret.setdefault(rec, []).append(
                {"utt": utt, "st": float(st), "et": float(et)}
            )
        return ret

    def _load_wav_scp(self, wav_scp_file):
        """Return dictionary { rec: wav_rxfilename }."""
        if os.path.exists(wav_scp_file):
            return {x[0]: x[1] for x in self._read_fields(wav_scp_file, 1)}
        # No wav.scp: every file in <data_dir>/wav is a recording whose id is
        # the file name without its extension.
        wav_dir = os.path.join(self.data_dir, "wav")
        return {
            os.path.splitext(filename)[0]: os.path.join(wav_dir, filename)
            for filename in sorted(os.listdir(wav_dir))
        }

    def _load_wav(self, wav_rxfilename, start=0, end=None):
        """Read an audio rxfilename and return ``(data, samplerate)``.

        Args:
            wav_rxfilename (str): a path, ``"-"`` for stdin, or a shell
                command ending in ``"|"`` whose stdout is the audio
            start (int): first sample to return
            end (int): sample to stop at (exclusive); None reads to the end
        Returns:
            data: the samples between ``start`` and ``end``
            samplerate: sampling rate reported by soundfile
        """
        if wav_rxfilename.endswith("|"):
            # input piped command
            # NOTE: the command comes from wav.scp and is run through the
            # shell, so only load data directories you trust.
            # subprocess.run waits for the command, so no child is left behind.
            proc = subprocess.run(
                wav_rxfilename[:-1],
                shell=True,
                stdout=subprocess.PIPE,
            )
            data, samplerate = sf.read(
                io.BytesIO(proc.stdout),
                dtype="float32",
            )
            # cannot seek
            data = data[start:end]
        elif wav_rxfilename == "-":
            # stdin: read the raw bytes, sys.stdin itself is a text stream
            data, samplerate = sf.read(
                io.BytesIO(sys.stdin.buffer.read()),
                dtype="float32",
            )
            # cannot seek
            data = data[start:end]
        else:
            # normal wav file
            # NOTE(review): no dtype is passed here, unlike the two branches
            # above, so soundfile's default dtype applies - confirm whether
            # float32 is wanted for regular files too.
            data, samplerate = sf.read(wav_rxfilename, start=start, stop=end)
        return data, samplerate

    def _load_utt2spk(self, utt2spk_file):
        """Returns dictionary { uttid: spkid }."""
        return {x[0]: x[1] for x in self._read_fields(utt2spk_file, 1)}

    def _load_spk2utt(self, spk2utt_file):
        """Returns dictionary { spkid: list of uttids }."""
        if not os.path.exists(spk2utt_file):
            return None
        return {x[0]: x[1:] for x in self._read_fields(spk2utt_file)}

    def _load_reco2dur(self, reco2dur_file):
        """Returns dictionary { recid: duration }."""
        if not os.path.exists(reco2dur_file):
            return None
        return {x[0]: float(x[1]) for x in self._read_fields(reco2dur_file, 1)}

    def _process_wav(self, wav_rxfilename, process):
        """This function returns preprocessed wav_rxfilename.

        Args:
            wav_rxfilename: input
            process: command which can be connected via pipe,
                    use stdin and stdout
        Returns:
            wav_rxfilename: output piped command
        """
        if wav_rxfilename.endswith("|"):
            # input piped command
            return wav_rxfilename + process + "|"
        # stdin "-" or normal file
        return "cat {0} | {1} |".format(wav_rxfilename, process)

    def _extract_segments(self, wavs, segments=None):
        """This function returns generator of segmented audio.

        Yields (utterance id, audio array).
        TODO?: sampling rate is not converted.

        Args:
            wavs: dictionary { rec: wav_rxfilename }
            segments: segment records with "utt", "rec", "st" and "et"
                fields; if None every entry of ``wavs`` is yielded whole
        """
        if segments is not None:
            # segments should be sorted by rec-id
            for seg in segments:
                # wavs maps to rxfilenames, so read them with _load_wav;
                # load_wav expects a recording id instead.
                data, samplerate = self._load_wav(wavs[seg["rec"]])
                st_sample = np.rint(seg["st"] * samplerate).astype(int)
                et_sample = np.rint(seg["et"] * samplerate).astype(int)
                yield seg["utt"], data[st_sample:et_sample]
        else:
            # segments file not found,
            # wav.scp is used as segmented audio list
            for rec in wavs:
                data, samplerate = self._load_wav(wavs[rec])
                yield rec, data