

Control how the torch DataLoader groups instances into a batch


For datasets with highly unbalanced classes


Wrap any batch sampler for distributed training


The most commonly used batch sampler; it reproduces the default batch sampler used by the torch DataLoader


Group the data points with the same key into the same batch


Limit the maximum timestamps in a batch to realize dynamic batching.


The most commonly used batch sampler in the S3PRL legacy codebase, which sorts all the data points by length and groups instances of similar length together.


class s3prl.dataio.sampler.BalancedWeightedSampler(labels: List[str], batch_size: int, duplicate: int = 1, seed: int = 12345678)

Bases: object

This batch sampler is always randomized, so it cannot be used for testing.

set_epoch(epoch: int)
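
The general idea behind sampling for highly unbalanced classes can be sketched as follows. This is an illustration only, not the s3prl implementation: the function name `balanced_batches`, the inverse-class-frequency weighting, and the `seed + epoch` seeding are all assumptions made for this example.

```python
import random
from collections import Counter
from typing import Iterator, List


def balanced_batches(
    labels: List[str], batch_size: int, seed: int = 12345678, epoch: int = 0
) -> Iterator[List[int]]:
    """Yield batches of indices drawn with inverse-class-frequency weights.

    Hypothetical helper for illustration; not part of s3prl.
    """
    counts = Counter(labels)
    # Each item weighs 1 / (size of its class), so every class
    # carries the same total weight regardless of how many items it has.
    weights = [1.0 / counts[label] for label in labels]
    # Assumption: reseed with seed + epoch so each epoch draws differently.
    rng = random.Random(seed + epoch)
    num_batches = (len(labels) + batch_size - 1) // batch_size
    for _ in range(num_batches):
        # Sampling with replacement: items of rare classes can repeat.
        yield rng.choices(range(len(labels)), weights=weights, k=batch_size)
```

Because the draw is random by construction, a sampler of this kind gives a different subset of the data on every pass, which is why it is unsuitable for evaluation.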


class s3prl.dataio.sampler.DistributedBatchSamplerWrapper(batch_sampler: BatchSampler, num_replicas: Optional[int] = None, rank: Optional[int] = None, allow_duplicates: bool = False, allow_uneven: bool = False)

Bases: object

set_epoch(epoch: int) -> None
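
One simple way to split the output of a batch sampler across processes is round-robin sharding, sketched below. The helper `shard_batches` is hypothetical and does not reproduce the wrapper's handling of `allow_duplicates` or `allow_uneven`; it only shows how `num_replicas` and `rank` could select a subset of batches.

```python
from typing import List, Sequence


def shard_batches(
    batches: Sequence[List[int]], num_replicas: int, rank: int
) -> List[List[int]]:
    """Return the batches assigned to `rank` in round-robin order.

    Hypothetical helper for illustration; not part of s3prl.
    """
    if not 0 <= rank < num_replicas:
        raise ValueError(f"rank must be in [0, {num_replicas}), got {rank}")
    # Rank r takes batches r, r + num_replicas, r + 2 * num_replicas, ...
    return [list(batch) for batch in batches[rank::num_replicas]]


batches = [[0, 1], [2, 3], [4, 5], [6, 7], [8]]
for rank in range(2):
    print(rank, shard_batches(batches, num_replicas=2, rank=rank))
# 0 [[0, 1], [4, 5], [8]]
# 1 [[2, 3], [6, 7]]
```

In this sketch the ranks can end up with different numbers of batches when the total is not divisible by `num_replicas`; a real wrapper has to decide whether to tolerate that or to pad with repeated batches.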


class s3prl.dataio.sampler.FixedBatchSizeBatchSampler(data_source, batch_size: int, shuffle: bool = False, seed: int = 12345678)

Bases: object

Indices are aggregated into batches of batch_size. If shuffle is True, the indices are shuffled before being aggregated into batches.


data_source – an object that implements __len__

set_epoch(epoch: int) -> None
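
A fixed-size batch sampler with an epoch-dependent shuffle can be sketched in a few lines. The class `FixedSizeBatches` below is a hypothetical stand-in, not the s3prl class; in particular, seeding the shuffle with `seed + epoch` and keeping a smaller final batch are assumptions of this example.

```python
import random
from typing import Iterator, List


class FixedSizeBatches:
    """Hypothetical stand-in for illustration; not the s3prl class."""

    def __init__(
        self, num_items: int, batch_size: int, shuffle: bool = False, seed: int = 12345678
    ) -> None:
        self.num_items = num_items
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        # Assumption: the epoch only changes the shuffling seed.
        self.epoch = epoch

    def __iter__(self) -> Iterator[List[int]]:
        indices = list(range(self.num_items))
        if self.shuffle:
            # Shuffle the indices first, then cut them into batches.
            random.Random(self.seed + self.epoch).shuffle(indices)
        for start in range(0, len(indices), self.batch_size):
            yield indices[start:start + self.batch_size]

    def __len__(self) -> int:
        # The last batch may be smaller than batch_size.
        return (self.num_items + self.batch_size - 1) // self.batch_size


print(list(FixedSizeBatches(num_items=5, batch_size=2)))
# [[0, 1], [2, 3], [4]]
```

An object of this shape, iterable over lists of indices, is what the `batch_sampler` argument of `torch.utils.data.DataLoader` expects.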


class s3prl.dataio.sampler.GroupSameItemSampler(items: List[str])

Bases: object

set_epoch(epoch: int)
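
Grouping data points that share a key into one batch can be sketched as below. The function `group_same_items` is hypothetical, and the order of the resulting batches (first occurrence of each key) is an assumption of this example rather than documented behavior.

```python
from collections import defaultdict
from typing import Dict, List


def group_same_items(items: List[str]) -> List[List[int]]:
    """Return one batch of indices per distinct item.

    Hypothetical helper for illustration; not part of s3prl.
    """
    groups: Dict[str, List[int]] = defaultdict(list)
    for index, item in enumerate(items):
        groups[item].append(index)
    # Dicts preserve insertion order, so batches follow the
    # first occurrence of each key.
    return list(groups.values())


print(group_same_items(["spk1", "spk2", "spk1", "spk3", "spk2"]))
# [[0, 2], [1, 4], [3]]
```

Batch sizes vary with how often each key occurs, so a key that appears only once yields a batch of one.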


class s3prl.dataio.sampler.MaxTimestampBatchSampler(lengths: List[int], max_length: int, shuffle: bool = False, seed: int = 12345678, reduce_func: Optional[callable] = None)

Bases: object

The reduced timestamps of a batch, computed by reduce_func from the lengths of its items, should not exceed max_length. If shuffle is True, the indices are shuffled before being aggregated into batches.

set_epoch(epoch: int)
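
Dynamic batching under a timestamp budget can be sketched as a greedy pass over the indices. The function `max_timestamp_batches` is hypothetical; its default reduction (longest item times batch size, i.e. the padded size of the batch) is an assumption of this example, and shuffling is omitted for brevity.

```python
from typing import Callable, List, Optional


def max_timestamp_batches(
    lengths: List[int],
    max_length: int,
    reduce_func: Optional[Callable[[List[int]], int]] = None,
) -> List[List[int]]:
    """Greedily pack indices so each batch stays within max_length.

    Hypothetical helper for illustration; not part of s3prl.
    """
    if reduce_func is None:
        # Assumed default: cost of a padded batch = longest item * batch size.
        reduce_func = lambda ts: max(ts) * len(ts)
    batches: List[List[int]] = []
    current: List[int] = []
    for index in range(len(lengths)):
        candidate = current + [index]
        cost = reduce_func([lengths[i] for i in candidate])
        if current and cost > max_length:
            # Adding this item would exceed the budget: close the batch.
            batches.append(current)
            current = [index]
        else:
            # An item that alone exceeds the budget still gets its own batch.
            current = candidate
    if current:
        batches.append(current)
    return batches


lengths = [4, 3, 5, 2, 6, 1]
print(max_timestamp_batches(lengths, max_length=10))
# [[0, 1], [2, 3], [4], [5]]
print(max_timestamp_batches(lengths, max_length=10, reduce_func=sum))
# [[0, 1], [2, 3], [4, 5]]
```

The two calls show how the choice of reduction changes the packing: summing the lengths allows the last two items to share a batch, while the padded-size reduction does not.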


class s3prl.dataio.sampler.SortedBucketingSampler(lengths: List[int], batch_size: int, max_length: int = 300000, shuffle: bool = False, in_batch_shuffle: bool = False, seed: int = 12345678)

Bases: object

  • lengths (List[int]) – the length of each item in the dataset

  • batch_size (int) – the default batch size

  • max_length (int) – if a batch contains at least one utterance longer than max_length, halve the batch size

  • get_length_func (callable) – returns the length of each item in the dataset; if None, a default function is used

  • shuffle (bool) – whether to shuffle the batches

  • in_batch_shuffle (bool) – if False, the items within each batch are sorted by length from long to short

set_epoch(epoch: int)
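
The interaction of the length sort, batch_size, and max_length described above can be sketched as follows. The function `sorted_buckets` is a hypothetical simplification, not the s3prl implementation: it assumes a long-to-short sort, assumes the halving is decided by the longest item of each batch, and leaves out both kinds of shuffling.

```python
from typing import List


def sorted_buckets(
    lengths: List[int], batch_size: int, max_length: int = 300000
) -> List[List[int]]:
    """Batch indices of similar length, halving batches with long items.

    Hypothetical helper for illustration; not part of s3prl.
    """
    # Sort indices from the longest item to the shortest.
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    batches: List[List[int]] = []
    start = 0
    while start < len(order):
        size = batch_size
        # order[start] is the longest item of the next batch; if it
        # exceeds max_length, take only half as many items.
        if lengths[order[start]] > max_length:
            size = max(1, batch_size // 2)
        batches.append(order[start:start + size])
        start += size
    return batches


print(sorted_buckets([50, 400, 120, 90, 380, 60], batch_size=4, max_length=300))
# [[1, 4], [2, 3, 5, 0]]
```

Items 1 and 4 exceed the limit, so they form a half-size batch, while the remaining shorter items fill a full batch of four.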


class s3prl.dataio.sampler.SortedSliceSampler(lengths: List[int], batch_size: int, max_length: int = 300000, seed: int = 12345678, in_batch_shuffle: bool = False)

Bases: object

This sampler should only be used for training, so it always runs in random-shuffle mode.

  • lengths (List[int]) – the length of each item in the dataset

  • batch_size (int) – the default batch size

  • max_length (int) – if a batch contains at least one utterance longer than max_length, halve the batch size

  • get_length_func (callable) – returns the length of each item in the dataset; if None, a default function is used

  • in_batch_shuffle (bool) – if False, the items within each batch are sorted by length from long to short

set_epoch(epoch: int)