Source code for s3prl.dataio.dataset.util
import json
import logging
from collections import defaultdict
from pathlib import Path
from typing import List

from joblib import Parallel, delayed
from tqdm import tqdm

logger = logging.getLogger(__name__)

__all__ = [
    "get_info",
]
def get_info(dataset, names: List[str], cache_dir: str = None, n_jobs: int = 6):
    """Collect the requested info fields from every item of ``dataset``.

    When ``cache_dir`` is given, the extracted fields of each index are cached
    as a per-index JSON file and reused on later calls.
    """
    logger.info(
        f"Getting info from dataset {dataset.__class__.__qualname__}: {' '.join(names)}"
    )
    # Normalize cache_dir to a Path and make sure the cache directory exists
    if isinstance(cache_dir, (str, Path)):
        logger.info(f"Using cached info in {cache_dir}")
        cache_dir: Path = Path(cache_dir)
        cache_dir.mkdir(parents=True, exist_ok=True)
    # Prefer the dataset's lightweight getinfo() when it provides all the
    # requested fields; otherwise fall back to full __getitem__ loading
    try:
        data = dataset.getinfo(0)
        for name in names:
            assert name in data
    except:
        fn = dataset.__getitem__
    else:
        fn = dataset.getinfo
    def _get(idx):
        # Reuse the cached info for this index if every requested field is present
        if isinstance(cache_dir, (str, Path)):
            cache_path: Path = Path(cache_dir) / f"{idx}.json"
            if cache_path.is_file():
                with cache_path.open() as f:
                    cached = json.load(f)
                all_presented = True
                for name in names:
                    if name not in cached:
                        all_presented = False
                if all_presented:
                    return cached

        # Load the item and keep only the requested fields
        data = fn(idx)
        info = {}
        for name in names:
            info[name] = data[name]

        # Write the extracted fields back to the per-index cache file
        if isinstance(cache_dir, (str, Path)):
            cache_path: Path = Path(cache_dir) / f"{idx}.json"
            with cache_path.open("w") as f:
                json.dump(info, f)

        return info
    infos = Parallel(n_jobs=n_jobs, backend="threading")(
        delayed(_get)(idx) for idx in tqdm(range(len(dataset)))
    )

    # Group the per-item dicts into one list per requested field name
    organized_info = defaultdict(list)
    for info in infos:
        for k, v in info.items():
            organized_info[k].append(v)

    output = []
    for name in names:
        output.append(organized_info[name])

    # Return a single list when only one field was requested, else one list per name
    if len(output) == 1:
        return output[0]
    else:
        return output
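
A minimal usage sketch of ``get_info``, assuming a dataset whose ``__getitem__`` returns dicts
containing the requested keys (``get_info`` then falls back to ``__getitem__`` because no
``getinfo`` method is defined). The ``MyAudioDataset`` class, its field names, and the cache
directory below are hypothetical and only illustrate the call pattern.

# Hypothetical toy dataset; any object with __len__ and __getitem__ returning dicts works here.
from s3prl.dataio.dataset.util import get_info

class MyAudioDataset:
    def __init__(self, records):
        self.records = records  # list of dicts, one per item

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        return self.records[idx]

dataset = MyAudioDataset(
    [{"wav_path": f"clip_{i}.wav", "label": i % 2} for i in range(10)]
)

# Two names requested -> get_info returns two lists (one per field), and each
# item's extracted fields are cached as <idx>.json under ./info_cache
wav_paths, labels = get_info(
    dataset, names=["wav_path", "label"], cache_dir="./info_cache", n_jobs=2
)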