Source code for s3prl.util.audio_info

import json
from pathlib import Path
from typing import List

import torchaudio
from joblib import Parallel, delayed
from tqdm import tqdm

_default_cache_dir = Path.home() / ".cache" / "s3prl" / "audio_info"


__all__ = [
    "get_cache_dir",
    "set_cache_dir",
    "get_audio_info",
]


[docs]def get_cache_dir(): _default_cache_dir.mkdir(exist_ok=True, parents=True) return _default_cache_dir
[docs]def set_cache_dir(cache_dir: str): global _default_cache_dir _default_cache_dir = Path(cache_dir)
[docs]def get_audio_info( audio_paths: List[str], audio_ids: List[str], cache_dir: str = None, num_workers: int = 6, ) -> List[dict]: """ Use :code:`torchaudio.info` to retrieve the metadata from audio paths. The retrieved metadata is cached in :code:`cache_dir` """ cache_dir = cache_dir or get_cache_dir() cache_dir: Path = Path(cache_dir) def _get_info(audio_path: str, audio_id: str): cache_file = cache_dir / f"{audio_id}.json" if cache_file.is_file(): with cache_file.open() as f: info = json.load(f) return info torchaudio.set_audio_backend("sox_io") torchaudio_info = torchaudio.info(audio_path) info = { "sample_rate": torchaudio_info.sample_rate, "num_frames": torchaudio_info.num_frames, "num_channels": torchaudio_info.num_channels, "bits_per_sample": torchaudio_info.bits_per_sample, "encoding": torchaudio_info.encoding, } return info infos = Parallel(n_jobs=num_workers)( delayed(_get_info)(path, idx) for path, idx in tqdm(zip(audio_paths, audio_ids), desc="Get audio metadata") ) return infos