sampler#

(s3prl.dataio.sampler)

Control how the torch DataLoader groups instances into batches

s3prl.dataio.sampler.balanced_weighted_sampler

For datasets with highly unbalanced classes

s3prl.dataio.sampler.distributed_sampler

Wrap any batch sampler for distributed training

s3prl.dataio.sampler.fixed_batch_size_batch_sampler

The most commonly used batch sampler; recovers the default batching behavior of the torch DataLoader

s3prl.dataio.sampler.group_same_item_sampler

Group the data points with the same key into the same batch

s3prl.dataio.sampler.max_timestamp_batch_sampler

Limit the maximum timestamps in a batch to realize dynamic batching.

s3prl.dataio.sampler.sorted_sampler

The most commonly used batch sampler in the S3PRL legacy codebase, which sorts all data points by length and groups instances with similar lengths into the same batch.

BalancedWeightedSampler#

class s3prl.dataio.sampler.BalancedWeightedSampler(labels: List[str], batch_size: int, duplicate: int = 1, seed: int = 12345678)[source]#

Bases: object

This batch sampler is always randomized; hence it cannot be used for testing.

set_epoch(epoch: int)[source]#
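The idea behind class-balanced weighting can be sketched in plain Python. Everything below (the helper `balanced_weights` and the toy labels) is a hypothetical illustration of the concept, not the s3prl implementation: each instance is weighted inversely to its class frequency, so rare classes are drawn roughly as often as common ones.

```python
from collections import Counter
import random

def balanced_weights(labels):
    """Weight each instance inversely to its class frequency (sketch)."""
    counts = Counter(labels)
    return [1.0 / counts[label] for label in labels]

labels = ["cat", "cat", "cat", "dog"]
weights = balanced_weights(labels)  # cat: 1/3 each, dog: 1.0

# Sample a batch with replacement using those weights; over many draws
# the cat:dog ratio approaches 1:1 despite the 3:1 class imbalance.
rng = random.Random(12345678)
batch = rng.choices(range(len(labels)), weights=weights, k=4)
```

Because sampling is weighted with replacement, the produced batches change every epoch, which is why such a sampler is unsuitable for evaluation.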

DistributedBatchSamplerWrapper#

class s3prl.dataio.sampler.DistributedBatchSamplerWrapper(batch_sampler: BatchSampler, num_replicas: Optional[int] = None, rank: Optional[int] = None, allow_duplicates: bool = False, allow_uneven: bool = False)[source]#

Bases: object

set_epoch(epoch: int) → None[source]#
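The wrapping idea can be sketched as follows. `shard_batches` is a hypothetical stand-in, showing one plausible strategy (round-robin assignment of pre-built batches to replicas); the actual wrapper may split differently and additionally handles the allow_duplicates / allow_uneven corner cases.

```python
def shard_batches(batches, num_replicas, rank):
    """Round-robin a list of pre-built batches across replicas (sketch)."""
    return [b for i, b in enumerate(batches) if i % num_replicas == rank]

# Four batches produced by any underlying batch sampler
batches = [[0, 1], [2, 3], [4, 5], [6, 7]]

rank0 = shard_batches(batches, num_replicas=2, rank=0)  # [[0, 1], [4, 5]]
rank1 = shard_batches(batches, num_replicas=2, rank=1)  # [[2, 3], [6, 7]]
```

The key property is that the wrapper operates on whole batches rather than individual indices, so any batch sampler on this page can be made distributed without changing its batching logic.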

FixedBatchSizeBatchSampler#

class s3prl.dataio.sampler.FixedBatchSizeBatchSampler(data_source, batch_size: int, shuffle: bool = False, seed: int = 12345678)[source]#

Bases: object

Groups instances into batches of a fixed batch_size, recovering the default batching behavior of the torch DataLoader. If shuffle is True, the indices are first shuffled before being aggregated into batches.

Parameters:

data_source – any object that implements __len__

set_epoch(epoch: int) → None[source]#
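Fixed-size batching can be sketched in a few lines; `fixed_size_batches` is a hypothetical helper illustrating the behavior described above, with the last batch allowed to be smaller than batch_size:

```python
import random

def fixed_size_batches(num_items, batch_size, shuffle=False, seed=12345678):
    """Chunk indices into fixed-size batches; shuffle first if requested (sketch)."""
    indices = list(range(num_items))
    if shuffle:
        random.Random(seed).shuffle(indices)
    return [indices[i:i + batch_size] for i in range(0, num_items, batch_size)]

batches = fixed_size_batches(10, batch_size=4)
# → [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

shuffled = fixed_size_batches(10, batch_size=4, shuffle=True)
```

Note that shuffling permutes the indices before chunking, so batch membership (not just batch order) changes across seeds.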

GroupSameItemSampler#

class s3prl.dataio.sampler.GroupSameItemSampler(items: List[str])[source]#

Bases: object

set_epoch(epoch: int)[source]#
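Grouping by key can be sketched with a plain dictionary; `group_by_key` below is a hypothetical helper showing the idea ("data points with the same key land in the same batch"), not the s3prl code:

```python
from collections import defaultdict

def group_by_key(items):
    """One batch per distinct key, holding every index with that key (sketch)."""
    groups = defaultdict(list)
    for index, key in enumerate(items):
        groups[key].append(index)
    return list(groups.values())

# e.g. several enrollment/test pairs belonging to the same utterance
batches = group_by_key(["uttA", "uttB", "uttA", "uttC", "uttB"])
# → [[0, 2], [1, 4], [3]]
```

Batch sizes here are determined entirely by how many items share each key, which is useful when all views of one utterance must be scored together.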

MaxTimestampBatchSampler#

class s3prl.dataio.sampler.MaxTimestampBatchSampler(lengths: List[int], max_length: int, shuffle: bool = False, seed: int = 12345678, reduce_func: Optional[callable] = None)[source]#

Bases: object

The reduced timestamps of a batch (as computed by reduce_func) should not exceed max_length. If shuffle is True, the indices are first shuffled before being aggregated into batches.

set_epoch(epoch: int)[source]#
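Dynamic batching by a timestamp budget can be sketched as below. `dynamic_batches` is a hypothetical helper that assumes a simple sum reduction over lengths; the real sampler lets reduce_func define the reduction, so this is one possible instantiation:

```python
def dynamic_batches(lengths, max_length):
    """Greedily pack indices until the summed length would exceed max_length (sketch)."""
    batches, current, total = [], [], 0
    for index, length in enumerate(lengths):
        if current and total + length > max_length:
            batches.append(current)   # budget exceeded: flush the batch
            current, total = [], 0
        current.append(index)
        total += length
    if current:
        batches.append(current)
    return batches

batches = dynamic_batches([100, 200, 300, 250], max_length=400)
# → [[0, 1], [2], [3]]
```

Short utterances are packed many to a batch while long ones get small batches, keeping per-batch compute roughly constant.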

SortedBucketingSampler#

class s3prl.dataio.sampler.SortedBucketingSampler(lengths: List[int], batch_size: int, max_length: int = 300000, shuffle: bool = False, in_batch_shuffle: bool = False, seed: int = 12345678)[source]#

Bases: object

Parameters:
  • lengths (List[int]) – the length of each data point

  • batch_size (int) – the default batch size

  • max_length (int) – if a batch contains at least one utterance longer than max_length, halve the batch size

  • get_length_func (callable) – gets the length of each item in the dataset; if None, a default function is used

  • shuffle (bool) – whether to shuffle the batches

  • in_batch_shuffle (bool) – if False, the items in each batch are ordered by length from long to short

set_epoch(epoch: int)[source]#
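The bucketing idea can be sketched as follows. `sorted_buckets` is a hypothetical helper: indices are sorted from longest to shortest, then chunked, and the batch size is halved while the leading (longest) utterance of a batch exceeds max_length. This is one reading of the "half the batch" rule above, not the exact s3prl code.

```python
def sorted_buckets(lengths, batch_size, max_length):
    """Sort by length desc, chunk, and shrink batches of over-long utts (sketch)."""
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    batches, start = [], 0
    while start < len(order):
        size = batch_size
        while size > 1 and lengths[order[start]] > max_length:
            size //= 2               # halve until the long utt fits alone enough
        batches.append(order[start:start + size])
        start += len(batches[-1])
    return batches

lengths = [500, 80, 90, 60, 70]
batches = sorted_buckets(lengths, batch_size=4, max_length=100)
# → [[0], [2, 1, 4, 3]]
```

Sorting first means each batch contains similar lengths, which minimizes padding waste when the batch is collated.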

SortedSliceSampler#

class s3prl.dataio.sampler.SortedSliceSampler(lengths: List[int], batch_size: int, max_length: int = 300000, seed: int = 12345678, in_batch_shuffle: bool = False)[source]#

Bases: object

This sampler is always in random shuffle mode; hence it should only be used for training.

Parameters:
  • lengths (List[int]) – the length of each data point

  • batch_size (int) – the default batch size

  • max_length (int) – if a batch contains at least one utterance longer than max_length, halve the batch size

  • get_length_func (callable) – gets the length of each item in the dataset; if None, a default function is used

  • in_batch_shuffle (bool) – if False, the items in each batch are ordered by length from long to short

set_epoch(epoch: int)[source]#
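One way to combine randomness with similar-length batches, in the spirit of this sampler, is to slice consecutive indices out of the length-sorted order at random start positions. `sorted_slice_batches` below is a hypothetical sketch of that idea only; the actual sampler's slicing and epoch logic may differ.

```python
import random

def sorted_slice_batches(lengths, batch_size, num_batches, seed=12345678):
    """Slice batch_size consecutive indices from the sorted order at random starts (sketch)."""
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    rng = random.Random(seed)
    batches = []
    for _ in range(num_batches):
        start = rng.randrange(len(order))
        batches.append(order[start:start + batch_size])  # may be short near the end
    return batches

batches = sorted_slice_batches([30, 10, 50, 20, 40], batch_size=2, num_batches=3)
```

Each batch still holds items of adjacent lengths (padding stays cheap), while the random start positions provide the shuffling that makes this a training-only sampler.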