Source code for s3prl.util.benchmark

Benchmark the timing of a block of code

  * Leo 2022

import logging
from collections import defaultdict
from contextlib import ContextDecorator
from time import perf_counter, time
from typing import Any

import numpy as np
import torch

logger = logging.getLogger(__name__)

# Every measured duration (in seconds), grouped by benchmark name. The mapping
# is module-wide, so separate ``benchmark`` objects created with the same name
# append to the same list and report one shared running average.
_history = defaultdict(list)

__all__ = ["benchmark"]


class benchmark(ContextDecorator):
    """Time a block of code and periodically log how long it took.

    Usable either as a context manager or as a function decorator::

        with benchmark("forward"):
            model(batch)

        @benchmark("step", freq=10)
        def step(): ...

    Each measurement is appended to a module-level history keyed by ``name``.
    After every ``freq``-th measurement recorded under that name, the latest
    duration and the mean of all durations recorded so far are logged at
    WARNING level.

    When CUDA is available, ``torch.cuda.synchronize()`` is called right before
    each timestamp is taken so that pending GPU work is included in the
    measurement; on CPU-only setups the synchronization is skipped.

    Note:
        The start timestamp is stored on the instance (``self.start``), so
        re-entering the *same* instance before it has exited (e.g. a decorated
        function that calls itself) overwrites the outer start time.

    Args:
        name: Key under which measurements are grouped and reported.
        freq: Log once every ``freq`` measurements of ``name``.

    Raises:
        ValueError: If ``freq`` is smaller than 1.
    """

    def __init__(self, name: str, freq: int = 1) -> None:
        super().__init__()
        # Reject a zero frequency up front: it would otherwise surface as a
        # ZeroDivisionError inside __exit__ and mask any exception raised by
        # the code being timed.
        if freq < 1:
            raise ValueError(f"freq must be a positive integer, got {freq!r}")
        self.name = name
        self.freq = freq

    @staticmethod
    def _sync_device() -> None:
        """Wait for queued CUDA work, if there is a CUDA device to wait on."""
        # An unconditional torch.cuda.synchronize() raises on machines without
        # CUDA, which would make the timer unusable for CPU-only runs.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def __enter__(self) -> "benchmark":
        self._sync_device()
        # perf_counter is monotonic and high-resolution, unlike wall-clock time().
        self.start = perf_counter()
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self._sync_device()
        seconds = perf_counter() - self.start

        records = _history[self.name]
        records.append(seconds)
        if len(records) % self.freq == 0:
            logger.warning(
                f"{self.name}: {seconds} secs, avg {np.array(records).mean()} secs"
            )
        # Implicitly returns None, so an exception from the timed block propagates.