# Source code for s3prl.dataio.dataset.load_audio

import random
from typing import List, Optional, Tuple

import librosa
import torch
import torchaudio

from . import Dataset

# Select torchaudio's "sox_io" audio backend.
# NOTE(review): ``set_audio_backend`` may be absent in newer torchaudio
# releases (where it was deprecated in favour of the backend dispatcher), so
# only call it when it exists instead of failing at import time.
if hasattr(torchaudio, "set_audio_backend"):
    torchaudio.set_audio_backend("sox_io")


class LoadAudio(Dataset):
    """Dataset that loads waveforms from audio files.

    Each item is read with ``librosa.load`` at ``sample_rate`` (optionally only
    the segment between its start and end time). It is then optionally run
    through a sox effect chain and resampled back to ``sample_rate`` if the
    effect chain reported a different rate. Finally it is optionally cropped
    at a random offset to ``max_secs`` and flattened to a 1-D tensor.

    Args:
        filepaths: paths of the audio files, one per item.
        start_secs: per-item segment start in seconds (a ``None`` entry is
            treated as 0.0). Use None to load every file from the start.
            Must be given together with ``end_secs``.
        end_secs: per-item segment end in seconds (a ``None`` entry loads to
            the end of that file). Use None to load every file to the end.
            Must be given together with ``start_secs``.
        sox_effects: a single sox effect chain applied to every item.
            Mutually exclusive with ``individual_sox_effects``.
        individual_sox_effects: one sox effect chain per item, indexed like
            ``filepaths``. Mutually exclusive with ``sox_effects``.
        max_secs: if given, waveforms longer than this are cropped to
            ``round(max_secs * sample_rate)`` samples at a random offset.
        generator: random source for the crop offset. Defaults to
            ``random.Random(12345678)`` so crops are reproducible.
        sample_rate: sample rate of every returned waveform.
    """

    def __init__(
        self,
        filepaths: List[str],
        start_secs: Optional[List[Optional[float]]] = None,
        end_secs: Optional[List[Optional[float]]] = None,
        sox_effects: Optional[Tuple[Tuple[str]]] = None,
        individual_sox_effects: Optional[List[Tuple[Tuple[str]]]] = None,
        max_secs: Optional[float] = None,
        generator: Optional[random.Random] = None,
        sample_rate: int = 16000,
    ) -> None:
        super().__init__()
        assert (start_secs is None) == (
            end_secs is None
        ), "start_secs and end_secs must both be given if anyone is given"
        assert (
            sox_effects is None or individual_sox_effects is None
        ), "sox_effects and individual_sox_effects are mutually exclusive"

        self.filepaths = filepaths
        self.start_secs = start_secs
        self.end_secs = end_secs
        self.max_secs = max_secs
        self.sample_rate = sample_rate
        # A fixed default seed keeps the random crops reproducible across runs.
        self.generator = random.Random(12345678) if generator is None else generator

        if sox_effects is not None:
            # Share the single effect chain across every item.
            individual_sox_effects = [sox_effects] * len(filepaths)
        self.individual_sox_effects = individual_sox_effects

    def __len__(self):
        """Return the number of audio files in the dataset."""
        return len(self.filepaths)

    def __getitem__(self, index: int):
        """Load, post-process and return item ``index``.

        Returns:
            dict with ``"wav_path"`` (the source file path), ``"wav"`` (a 1-D
            float tensor at ``self.sample_rate``) and ``"wav_len"`` (the
            number of elements in ``"wav"``).
        """
        path = self.filepaths[index]

        start_sec = None if self.start_secs is None else self.start_secs[index]
        start_sec = start_sec or 0.0
        end_sec = None if self.end_secs is None else self.end_secs[index]
        duration = None if end_sec is None else end_sec - start_sec

        y, sr = librosa.load(
            path,
            sr=self.sample_rate,
            offset=start_sec,
            duration=duration,
        )
        assert sr == self.sample_rate
        # (channels=1, samples): the 2-D layout the sox effect call receives.
        wav = torch.FloatTensor(y).view(1, -1)

        if self.individual_sox_effects is not None:
            wav, sr = torchaudio.sox_effects.apply_effects_tensor(
                wav, sr, effects=self.individual_sox_effects[index]
            )

        # The effect chain may report a different sample rate; bring it back.
        # ``Resample`` returns only the resampled tensor (not a (wav, sr) pair),
        # so its result must not be unpacked.
        if sr != self.sample_rate:
            wav = torchaudio.transforms.Resample(sr, self.sample_rate)(wav)

        if self.max_secs is not None:
            num_samples = wav.size(-1)
            if num_samples / self.sample_rate > self.max_secs:
                max_samples = round(self.max_secs * self.sample_rate)
                start = self.generator.randint(0, num_samples - max_samples)
                wav = wav[:, start : start + max_samples]

        # Flatten to 1-D. ``reshape`` matches ``view`` for the single-channel
        # case and also handles the non-contiguous slice produced when the
        # effect chain returned more than one channel and a crop was applied.
        # NOTE(review): multiple channels end up concatenated end-to-end here.
        wav = wav.reshape(-1)
        return {
            "wav_path": path,
            "wav_len": len(wav),
            "wav": wav,
        }