Source code for s3prl.dataio.sampler.fixed_batch_size_batch_sampler

"""
The most commonly used batch sampler, recover the default batch sampler used
in torch DataLoader

Authors:
  * Leo 2022
"""

import torch
from torch.utils.data import BatchSampler, RandomSampler, SequentialSampler

__all__ = ["FixedBatchSizeBatchSampler"]


class FixedBatchSizeBatchSampler:
    """The most commonly used batch sampler, grouping consecutive indices into
    fixed-size batches (recovering the default batching of torch DataLoader).

    If ``shuffle`` is True, indices are first shuffled and then aggregated into
    batches. Call :meth:`set_epoch` before each epoch to get a deterministic
    yet epoch-varying shuffle order.

    Args:
        data_source: any object implementing ``__len__``
        batch_size (int): maximum number of indices per batch (the last batch
            may be smaller, since the tail is kept — ``drop_last=False``)
        shuffle (bool): whether to shuffle indices before batching
        seed (int): base seed combined with the epoch in :meth:`set_epoch`

    Raises:
        ValueError: if ``batch_size`` is not a positive integer
    """

    def __init__(
        self,
        data_source,
        batch_size: int,
        shuffle: bool = False,
        seed: int = 12345678,
    ) -> None:
        # Validate eagerly; otherwise BatchSampler would only raise on the
        # first __iter__ / __len__, far from the construction site.
        if batch_size <= 0:
            raise ValueError(
                "batch_size should be a positive integer value, "
                f"but got batch_size={batch_size}"
            )

        self.batch_size = batch_size
        self.seed = seed
        self.shuffle = shuffle

        if shuffle:
            # Dedicated generator so set_epoch can deterministically control
            # the shuffle order of each epoch.
            self.generator = torch.Generator()
            self.sampler = RandomSampler(data_source, generator=self.generator)
        else:
            self.sampler = SequentialSampler(data_source)

    def set_epoch(self, epoch: int) -> None:
        """Re-seed the shuffle generator for the given epoch.

        No-op when ``shuffle=False``.
        """
        if self.shuffle:
            self.generator.manual_seed(self.seed + epoch)

    def __iter__(self):
        batch_sampler = BatchSampler(
            self.sampler, batch_size=self.batch_size, drop_last=False
        )
        return iter(batch_sampler)

    def __len__(self):
        # ceil(num_indices / batch_size): drop_last=False keeps the tail
        # batch. O(1) instead of materializing every batch.
        return (len(self.sampler) + self.batch_size - 1) // self.batch_size