sampler#
(s3prl.dataio.sampler)

Control how the torch DataLoader groups instances into a batch:

- BalancedWeightedSampler: for datasets with highly unbalanced classes
- DistributedBatchSamplerWrapper: wrap any batch sampler for distributed training
- FixedBatchSizeBatchSampler: the most commonly used batch sampler, recovering the default batch sampler used in the torch DataLoader
- GroupSameItemSampler: group the data points with the same key into the same batch
- MaxTimestampBatchSampler: limit the maximum timestamps in a batch to realize dynamic batching
- SortedBucketingSampler / SortedSliceSampler: sort all the data points by length and group instances with similar lengths together, as in the S3PRL legacy codebase
BalancedWeightedSampler#
DistributedBatchSamplerWrapper#
FixedBatchSizeBatchSampler#
- class s3prl.dataio.sampler.FixedBatchSizeBatchSampler(data_source, batch_size: int, shuffle: bool = False, seed: int = 12345678)[source]#
Bases:
object
Group instances into batches of a fixed batch_size, recovering the default batch sampler used in the torch DataLoader. If shuffle is enabled, the indices are first shuffled before being aggregated into batches.
- Parameters:
data_source – any object implementing __len__
batch_size (int) – the fixed batch size
shuffle (bool) – whether to shuffle the indices before batching
seed (int) – the random seed used when shuffle is True
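To make the behavior concrete, here is a minimal pure-Python sketch of the batch-sampler protocol this class follows: iterating yields one list of dataset indices per batch, and such an object can be passed to the torch DataLoader through its batch_sampler argument. The class name `SimpleFixedBatchSampler` is hypothetical; this is an illustrative re-implementation, not the library's actual code.

```python
import random


class SimpleFixedBatchSampler:
    # Illustrative sketch, not s3prl's implementation: yields lists of
    # dataset indices, one list per batch, like a torch batch sampler.
    def __init__(self, data_source, batch_size: int, shuffle: bool = False, seed: int = 12345678):
        self.indices = list(range(len(data_source)))
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self):
        indices = self.indices[:]
        if self.shuffle:
            # seeded shuffle so the batch order is reproducible across runs
            random.Random(self.seed).shuffle(indices)
        for start in range(0, len(indices), self.batch_size):
            yield indices[start:start + self.batch_size]

    def __len__(self):
        # number of batches, counting the final partial batch
        return (len(self.indices) + self.batch_size - 1) // self.batch_size
```

For example, ten items with batch_size=4 produce three batches, the last holding the two leftover indices.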
GroupSameItemSampler#
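The grouping idea can be sketched in a few lines: collect every index that shares a key into one batch. The function name `group_same_item_batches` is hypothetical and this is only an illustration of the strategy, not the library's implementation.

```python
from collections import defaultdict


def group_same_item_batches(keys):
    # Sketch of the GroupSameItemSampler idea: every index whose item has
    # the same key lands in the same batch, in first-seen key order.
    groups = defaultdict(list)
    for idx, key in enumerate(keys):
        groups[key].append(idx)
    return list(groups.values())
```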
MaxTimestampBatchSampler#
- class s3prl.dataio.sampler.MaxTimestampBatchSampler(lengths: List[int], max_length: int, shuffle: bool = False, seed: int = 12345678, reduce_func: Optional[callable] = None)[source]#
Bases:
object
The reduced timestamps of a batch should not exceed max_length. If shuffle is enabled, the indices are first shuffled before being aggregated into batches.
- Parameters:
lengths (List[int]) – the length (timestamps) of each item in the dataset
max_length (int) – the maximum reduced timestamps allowed in a batch
shuffle (bool) – whether to shuffle the indices before batching
seed (int) – the random seed used when shuffle is True
reduce_func (callable) – reduces the lengths of a candidate batch to a single number compared against max_length; if None, a default function is used
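The dynamic-batching loop can be sketched as follows: greedily grow a batch and flush it once the reduced timestamps would exceed the limit. The function name is hypothetical, and the default reduction here (longest item times batch size, i.e. the padded total) is an assumption; the library's actual default reduce_func may differ.

```python
import random
from typing import Callable, List, Optional


def max_timestamp_batches(
    lengths: List[int],
    max_length: int,
    shuffle: bool = False,
    seed: int = 12345678,
    reduce_func: Optional[Callable[[List[int]], int]] = None,
):
    # Assumed default reduction: padded size = longest item * batch size.
    reduce_func = reduce_func or (lambda ls: max(ls) * len(ls))
    indices = list(range(len(lengths)))
    if shuffle:
        random.Random(seed).shuffle(indices)

    batches, batch = [], []
    for idx in indices:
        candidate = batch + [idx]
        if batch and reduce_func([lengths[i] for i in candidate]) > max_length:
            # adding idx would exceed the budget: flush and start a new batch
            batches.append(batch)
            batch = [idx]
        else:
            batch = candidate
    if batch:
        batches.append(batch)
    return batches
```

With lengths [5, 5, 5, 10] and max_length=15, the first three items fit (5 * 3 = 15) and the long item falls into its own batch.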
SortedBucketingSampler#
- class s3prl.dataio.sampler.SortedBucketingSampler(lengths: List[int], batch_size: int, max_length: int = 300000, shuffle: bool = False, in_batch_shuffle: bool = False, seed: int = 12345678)[source]#
Bases:
object
- Parameters:
lengths (List[int]) – the length of each item in the dataset
batch_size (int) – the default batch size
max_length (int) – if a batch contains at least one utt longer than max_length, halve the batch
shuffle (bool) – whether to shuffle the batches
in_batch_shuffle (bool) – if False, each batch is sorted by length from long to short; if True, the items within a batch are shuffled
seed (int) – the random seed used for shuffling
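The bucketing strategy described above can be sketched like this: sort indices from long to short, cut fixed-size buckets, and halve a bucket that contains an over-long utterance. The function name is hypothetical and batch shuffling is omitted; this is a reading of the docstring, not the library's actual code.

```python
from typing import List


def sorted_bucketing_batches(lengths: List[int], batch_size: int, max_length: int = 300000):
    # Sort indices by length, long to short, so similar lengths are adjacent.
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    batches, start = [], 0
    while start < len(order):
        size = batch_size
        # Halve the batch while it still contains an utt longer than max_length.
        while size > 1 and any(lengths[i] > max_length for i in order[start:start + size]):
            size //= 2
        batches.append(order[start:start + size])
        start += size
    return batches
```

With eight items of lengths 1..8, batch_size=4 and a generous max_length, the buckets are simply the sorted halves; lowering max_length to 7 isolates the longest utterance into a shrunken batch.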
SortedSliceSampler#
- class s3prl.dataio.sampler.SortedSliceSampler(lengths: List[int], batch_size: int, max_length: int = 300000, seed: int = 12345678, in_batch_shuffle: bool = False)[source]#
Bases:
object
This sampler should only be used for training, hence it is always in random-shuffle mode.
- Parameters:
lengths (List[int]) – the length of each item in the dataset
batch_size (int) – the default batch size
max_length (int) – if a batch contains at least one utt longer than max_length, halve the batch
seed (int) – the random seed used for shuffling
in_batch_shuffle (bool) – if False, each batch is sorted by length from long to short; if True, the items within a batch are shuffled
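One possible reading of this sampler, sketched below as an assumption rather than the library's actual algorithm: sort indices by length, then visit randomly ordered anchor positions and slice up to batch_size consecutive not-yet-used indices from the sorted order, so each batch holds items of similar length while the batch order stays random; a slice containing an utt longer than max_length is halved, as in the bucketing sampler. The function name is hypothetical.

```python
import random
from typing import List


def sorted_slice_batches(lengths: List[int], batch_size: int, max_length: int = 300000, seed: int = 12345678):
    rng = random.Random(seed)
    # Sort indices by length, long to short, so a slice has similar lengths.
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    anchors = list(range(len(order)))
    rng.shuffle(anchors)

    batches, used = [], set()
    for pos in anchors:
        if order[pos] in used:
            continue
        # Slice up to batch_size consecutive unused indices from the anchor.
        slice_ = [i for i in order[pos:pos + batch_size] if i not in used]
        if any(lengths[i] > max_length for i in slice_):
            # Over-long utt in the slice: halve the batch (keep at least one).
            slice_ = slice_[: max(1, len(slice_) // 2)]
        used.update(slice_)
        batches.append(slice_)
    return batches
```

Whatever the anchor order, the sketch partitions all indices into batches of at most batch_size items.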