# Source code for s3prl.util.download

"""
Thread-safe file downloading and caching

Authors
  * Leo 2022
  * Cheng Liang 2022
"""

import hashlib
import logging
import os
import shutil
import sys
import tempfile
import time
from pathlib import Path
from urllib.request import Request, urlopen

import requests
from filelock import FileLock
from tqdm import tqdm

logger = logging.getLogger(__name__)


_download_dir = Path.home() / ".cache" / "s3prl" / "download"

__all__ = [
    "get_dir",
    "set_dir",
    "download",
    "urls_to_filepaths",
]


def get_dir():
    """
    Return the current download cache directory.

    The directory (and any missing parents) is created on every call if it
    does not exist yet, so callers can write into it immediately.

    Return:
        pathlib.Path: the download cache directory
    """
    cache_dir = _download_dir
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir
def set_dir(d):
    """
    Change the download cache directory.

    The directory itself is not created here; that happens lazily when the
    directory is requested via ``get_dir``.

    Args:
        d (str or os.PathLike): the new download cache directory
    """
    global _download_dir
    new_dir = Path(d)
    _download_dir = new_dir
def _save_chunks_to_file(
    chunks, dst, file_size=None, hash_prefix=None, progress=True
):
    """
    Write an iterable of ``bytes`` chunks to ``dst`` via a temporary file.

    The data is first written to a temporary file created in the directory of
    ``dst``. It is moved onto ``dst`` only after every chunk has been written,
    flushed to disk, and (when ``hash_prefix`` is given) its SHA-256 digest
    has been checked. On any failure the temporary file is removed, so a
    partial download is never left at ``dst``.

    Args:
        chunks: iterable of ``bytes``; empty chunks are skipped
        dst (str): destination file path
        file_size (int or None): expected total number of bytes, only used
            for the progress bar
        hash_prefix (str or None): if given, the hex SHA-256 digest of the
            written data must start with this string
        progress (bool): whether to display a tqdm progress bar

    Raises:
        RuntimeError: if the digest does not start with ``hash_prefix``
    """
    dst_dir = os.path.dirname(dst)
    f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
    try:
        sha256 = hashlib.sha256() if hash_prefix is not None else None
        with tqdm(
            total=file_size,
            disable=not progress,
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
        ) as pbar:
            for chunk in chunks:
                if not chunk:
                    # requests may yield empty keep-alive chunks
                    continue
                f.write(chunk)
                if sha256 is not None:
                    sha256.update(chunk)
                pbar.update(len(chunk))
        # Persist the data once, before the file becomes visible as ``dst``
        f.flush()
        os.fsync(f.fileno())
        f.close()
        if sha256 is not None:
            digest = sha256.hexdigest()
            if digest[: len(hash_prefix)] != hash_prefix:
                raise RuntimeError(
                    'invalid hash value (expected "{}", got "{}")'.format(
                        hash_prefix, digest
                    )
                )
        shutil.move(f.name, dst)
    finally:
        f.close()
        # Only still present if something failed before the move
        if os.path.exists(f.name):
            os.remove(f.name)


def _download_url_to_file(url, dst, hash_prefix=None, progress=True):
    """
    Download ``url`` to ``dst`` using ``urllib``.

    This function is not thread-safe. Please ensure only a single thread or
    process can enter this block at the same time.

    Args:
        url (str): the URL to download
        dst (str or os.PathLike): destination path; ``~`` is expanded
        hash_prefix (str or None): expected prefix of the SHA-256 hex digest
        progress (bool): whether to display a progress bar

    Raises:
        RuntimeError: if ``hash_prefix`` is given and does not match
        urllib.error.URLError: on network / HTTP failures
    """
    req = Request(url, headers={"User-Agent": "torch.hub"})
    # Use the response as a context manager so the connection is always closed
    with urlopen(req) as u:
        file_size = None
        content_length = u.info().get_all("Content-Length")
        if content_length:
            file_size = int(content_length[0])

        dst = os.path.expanduser(dst)
        tqdm.write(f"Downloading: {url}", file=sys.stderr)
        tqdm.write(f"Destination: {dst}", file=sys.stderr)
        _save_chunks_to_file(
            iter(lambda: u.read(8192), b""),
            dst,
            file_size=file_size,
            hash_prefix=hash_prefix,
            progress=progress,
        )


def _download_url_to_file_requests(url, dst, hash_prefix=None, progress=True):
    """
    Alternative download (using ``requests``) when urllib.Request fails.

    Args:
        url (str): the URL to download
        dst (str or os.PathLike): destination path; ``~`` is expanded
        hash_prefix (str or None): expected prefix of the SHA-256 hex digest
        progress (bool): whether to display a progress bar

    Raises:
        RuntimeError: if ``hash_prefix`` is given and does not match
        requests.HTTPError: if the server answers with an error status
    """
    with requests.get(url, stream=True, headers={"User-Agent": "torch.hub"}) as req:
        # Never save (and thereby permanently cache) an HTTP error page
        req.raise_for_status()
        # Content-Length is absent for chunked responses; it only feeds the bar
        content_length = req.headers.get("Content-Length")
        file_size = int(content_length) if content_length else None

        dst = os.path.expanduser(dst)
        tqdm.write(
            "urllib.Request method failed. Trying using another method...",
            file=sys.stderr,
        )
        tqdm.write(f"Downloading: {url}", file=sys.stderr)
        tqdm.write(f"Destination: {dst}", file=sys.stderr)
        _save_chunks_to_file(
            req.iter_content(chunk_size=1024 * 1024 * 10),
            dst,
            file_size=file_size,
            hash_prefix=hash_prefix,
            progress=progress,
        )


def _download(filepath: Path, url, refresh: bool, new_enough_secs: float = 2.0):
    """
    Download ``url`` to ``filepath`` while holding a FileLock on
    ``<filepath>.lock``.

    If the file does not exist it is downloaded. If it exists and refresh is
    False, it is used as-is. If refresh is True, check the latest modified
    time of the filepath: if the file is new enough (no older than
    `new_enough_secs`), then directly use it; if it is older than
    `new_enough_secs`, then re-download the file.
    This function is useful when multi-processes are all downloading the same
    large file.

    Downloading is first attempted with urllib; on any ``Exception`` it falls
    back to ``requests``.

    Args:
        filepath (pathlib.Path or str): destination file
        url (str): the URL to download
        refresh (bool): whether to re-download files older than
            ``new_enough_secs``
        new_enough_secs (float): age threshold in seconds used when refreshing
    """
    # Accept plain strings too: this function is exported publicly as ``download``
    filepath = Path(filepath)
    filepath.parent.mkdir(exist_ok=True, parents=True)
    lock_file = Path(str(filepath) + ".lock")
    logger.info(f"Requesting URL: {url}")
    with FileLock(str(lock_file)):
        if not filepath.is_file() or (
            refresh and (time.time() - os.path.getmtime(filepath)) > new_enough_secs
        ):
            try:
                _download_url_to_file(url, filepath)
            except Exception as e:
                # Narrower than a bare except: Ctrl-C / SystemExit must still abort
                logger.warning(
                    f"urllib download of {url} failed ({e!r}); retrying with requests"
                )
                _download_url_to_file_requests(url, filepath)
    logger.info(f"Using URL's local file: {filepath}")


def _urls_to_filepaths(*args, refresh=False, download: bool = True):
    """
    Preprocess the URLs specified in *args into local file paths after
    downloading.

    Each URL is stored in the download directory as
    ``<sha256 hex digest of the URL>.<last path component of the URL>``, so
    different URLs ending with the same file name do not collide.

    Args:
        *args (str): any number of URLs (at least one)
        refresh (bool): forwarded to ``_download``; re-download cached files
            older than its ``new_enough_secs`` (2 seconds by default)
        download (bool): if False, only compute the paths without downloading

    Return:
        str or list of str: the resolved local path if a single URL is given,
        otherwise a list of paths in the same order as ``*args``
    """

    def _url_to_filepath(url):
        assert isinstance(url, str)
        url_hash = hashlib.sha256(url.encode()).hexdigest()
        filepath = get_dir() / f"{url_hash}.{Path(url).name}"
        if download:
            _download(filepath, url, refresh=refresh)
        return str(filepath.resolve())

    paths = [_url_to_filepath(url) for url in args]
    return paths if len(paths) > 1 else paths[0]


download = _download
urls_to_filepaths = _urls_to_filepaths