# Source code for s3prl.util.pseudo_data

"""
Create pseudo data

Authors
  * Leo 2022
"""

import random
import shutil
import tempfile
from pathlib import Path
from typing import Any, List

import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence

# Default sample rate (samples per second) used by the helpers in this module
# when the caller does not pass ``sample_rate`` explicitly.
SAMPLE_RATE = 16000

# Names exported by ``from <this module> import *``.
__all__ = [
    "pseudo_audio",
    "get_pseudo_wavs",
]


class pseudo_audio:
    """
    Context manager that writes random-noise wav files into a temporary directory.

    On construction, one wav file is saved per entry of ``secs``; each file holds
    a ``(1, round(sample_rate * sec))`` tensor drawn from ``torch.randn``.
    Entering the context returns ``filepaths`` (List[str]) and ``num_samples``
    (List[int]), in the same order as ``secs``. Leaving the context deletes the
    temporary directory together with every file in it.

    Args:
        secs (List[float]): duration, in seconds, of each file to create
        sample_rate (int): sample rate passed to ``torchaudio.save``

    Example::

        with pseudo_audio([1.0, 2.5]) as (filepaths, num_samples):
            ...
    """

    def __init__(self, secs: List[float], sample_rate: int = SAMPLE_RATE):
        # mkdtemp creates the directory and leaves its lifetime entirely to us.
        # Taking only the ``.name`` of a discarded TemporaryDirectory object lets
        # its finalizer delete the directory behind our back.
        self.tempdir = Path(tempfile.mkdtemp())
        self.num_samples = []
        self.filepaths = []
        try:
            for n, sec in enumerate(secs):
                wav = torch.randn(1, round(sample_rate * sec))
                filepath = str(self.tempdir / f"audio_{n}.wav")
                torchaudio.save(filepath, wav, sample_rate=sample_rate)
                self.num_samples.append(wav.size(-1))
                self.filepaths.append(filepath)
        except Exception:
            # __exit__ never runs when construction fails, so clean up here
            # instead of leaking a half-populated directory.
            shutil.rmtree(self.tempdir, ignore_errors=True)
            raise

    def __enter__(self):
        return self.filepaths, self.num_samples

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        shutil.rmtree(self.tempdir)
def get_pseudo_wavs(
    seed: int = 0,
    n: int = 2,
    min_secs: float = 1,
    max_secs: float = 3,
    sample_rate: int = SAMPLE_RATE,
    device: str = "cpu",
    padded: bool = False,
):
    """
    Generate reproducible random-noise waveforms of random lengths.

    Note that this re-seeds the global ``random`` and ``torch`` generators with
    ``seed``, so calling it affects random numbers drawn afterwards.

    Args:
        seed (int): seed for both the ``random`` and ``torch`` global generators
        n (int): number of waveforms to generate
        min_secs (float): minimum duration in seconds (inclusive)
        max_secs (float): maximum duration in seconds (inclusive)
        sample_rate (int): samples per second, used to convert seconds to samples
        device (str): device the waveforms are moved to
        padded (bool): whether to return one zero-padded batch instead of a list

    Returns:
        If ``padded`` is False, a list of ``n`` 1-D tensors created with
        ``requires_grad=True``. If ``padded`` is True, a tuple of

        1. the waveforms padded into a single ``(n, max_length)`` tensor
        2. a ``torch.LongTensor`` holding the unpadded length of each waveform
    """
    random.seed(seed)
    torch.manual_seed(seed)

    # randint needs integer bounds; rounding keeps integer inputs unchanged
    # while also accepting fractional durations such as 0.5 seconds.
    min_samples = round(min_secs * sample_rate)
    max_samples = round(max_secs * sample_rate)

    wavs = []
    wavs_len = []
    for _ in range(n):
        wav_length = random.randint(min_samples, max_samples)
        wav = torch.randn(wav_length, requires_grad=True).to(device)
        wavs_len.append(wav_length)
        wavs.append(wav)

    if not padded:
        return wavs
    return pad_sequence(wavs, batch_first=True), torch.LongTensor(wavs_len)