pit
(s3prl.nn.pit)
Permutation Invariant Training (PIT) loss
- Authors:
Jiatong Shi 2021
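For reference, the quantity that PIT minimizes can be written as follows. This is a generic formulation, not a transcription of the s3prl code. Here \hat{y} is `output`, y is `label`, C is `num_class`, T_b is the valid `length` of instance b, and \mathcal{S}_C is the set of all permutations of the C classes. The per-element loss \ell is an assumption (binary cross-entropy on logits is a common choice for multi-speaker activity targets), and so is the normalization by the total number of valid frames in the second line.

$$
\mathcal{L}_b = \min_{\pi \in \mathcal{S}_C} \sum_{t=1}^{T_b} \sum_{c=1}^{C} \ell\left(\hat{y}_{b,t,c},\; y_{b,t,\pi(c)}\right),
\qquad
\mathcal{L} = \frac{\sum_{b} \mathcal{L}_b}{\sum_{b} T_b}
$$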
pit_loss
- s3prl.nn.pit.pit_loss(output, label, length)
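Compute the Permutation Invariant Training (PIT) loss: the loss is evaluated under every permutation of the label classes, and the minimum is kept.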
- Parameters:
output (torch.FloatTensor) – prediction in (batch_size, seq_len, num_class)
label (torch.FloatTensor) – label in the same shape as output
length (torch.LongTensor) – the valid length of each instance; output and label share the same valid length
- Returns:
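loss (torch.FloatTensor): the PIT loss under the best permutation
min_idx (int): the index of the permutation with the minimum loss
all the permutations: the label permutations that min_idx indexes into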
- Return type:
tuple
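The function body is not reproduced on this page. As a rough illustration only, here is a minimal PyTorch sketch of how a PIT loss with this signature could work. `pit_loss_sketch` is a hypothetical re-implementation, not s3prl's source. It assumes that `output` holds logits and `label` holds 0/1 targets, that each permutation is scored with binary cross-entropy summed over classes and valid frames, that the best permutation is picked per instance, and that the result is normalized by the total number of valid frames.

```python
from itertools import permutations

import torch
import torch.nn.functional as F


def pit_loss_sketch(output, label, length):
    """Hypothetical PIT loss for illustration (not the s3prl implementation).

    output: logits, shape (batch_size, seq_len, num_class)
    label:  0/1 targets, same shape as output
    length: valid number of frames per instance, shape (batch_size,)
    """
    batch_size, seq_len, num_class = label.shape
    # Float mask of shape (batch_size, seq_len): 1.0 for valid frames, 0.0 for padding.
    frames = torch.arange(seq_len, device=output.device)
    mask = (frames[None, :] < length.to(output.device)[:, None]).float()

    perms = list(permutations(range(num_class)))
    per_perm = []
    for p in perms:
        # Reorder the label channels: channel c of output is compared with label channel p[c].
        label_perm = label[..., list(p)]
        bce = F.binary_cross_entropy_with_logits(output, label_perm, reduction="none")
        # Sum over classes, zero out padded frames, then sum over time.
        per_perm.append((bce.sum(dim=-1) * mask).sum(dim=-1))

    # Shape (batch_size, num_perms): loss of every instance under every permutation.
    losses = torch.stack(per_perm, dim=1)
    # Choose the best permutation independently for each instance.
    min_loss, min_idx = losses.min(dim=1)
    # Assumed normalization: divide by the total number of valid frames in the batch.
    loss = min_loss.sum() / length.sum()
    return loss, min_idx, perms


# Usage: the output matches the label with its two channels swapped,
# so the swapped permutation (1, 0) should be selected.
label = torch.tensor([[[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]])
output = 10.0 * (2.0 * label[..., [1, 0]] - 1.0)
loss, min_idx, perms = pit_loss_sketch(output, label, torch.tensor([3]))
print(min_idx.tolist(), perms[min_idx[0]])
```

Because the number of permutations grows as C!, this brute-force search is only practical for a small number of classes, such as a few speakers. The per-instance `min_idx` can be used to look up the winning permutation in `perms` and reorder each instance's label channels before computing metrics.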