pit

(s3prl.nn.pit)

Permutation Invariant Training (PIT) loss

Authors:
  • Jiatong Shi 2021

pit_loss

s3prl.nn.pit.pit_loss(output, label, length)

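Compute the Permutation Invariant Training (PIT) loss. The loss between output and label is evaluated under every permutation of the label's num_class dimension, and the minimum is returned.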
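One standard way to write a PIT objective uses the parameter names below: B = batch_size, T_b = length[b], C = num_class, \hat{y} for output and y for label. S_C is the set of all C! permutations of the classes and \ell is a per-element loss. The choice of \ell (for example binary cross-entropy on logits) and the normalization by the number of valid elements are assumptions, not taken from the s3prl source. A single permutation is chosen for the whole batch, which is consistent with min_idx being documented as one int.

```latex
\mathcal{L}_{\mathrm{PIT}}
  = \min_{\pi \in S_C}
    \frac{1}{C \sum_{b=1}^{B} T_b}
    \sum_{b=1}^{B} \sum_{t=1}^{T_b} \sum_{c=1}^{C}
    \ell\big(\hat{y}_{b,t,c},\; y_{b,t,\pi(c)}\big)
```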

Parameters:
  • output (torch.FloatTensor) – prediction of shape (batch_size, seq_len, num_class)

  • label (torch.FloatTensor) – label of the same shape as output

  • length (torch.LongTensor) – the valid length of each instance in the batch; output and label share the same valid length

Returns:

  1. loss (torch.FloatTensor): the loss under the best permutation

  2. min_idx (int): the index of the permutation with the minimum loss

  3. all the permutations, which min_idx indexes into

Return type:

tuple
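The following is a minimal sketch of how a PIT loss with this signature and return tuple could be computed. It is not s3prl's implementation: treating output as logits, using binary cross-entropy, averaging over valid (unpadded) elements, and the name pit_loss_sketch are all assumptions for illustration. Because min_idx is documented as a single int, the sketch picks one permutation for the whole batch.

```python
from itertools import permutations

import torch
import torch.nn.functional as F


def pit_loss_sketch(output, label, length):
    """Hypothetical PIT loss for illustration only (not s3prl's code).

    Assumes `output` holds logits and uses binary cross-entropy,
    averaged over the valid (unpadded) elements of the batch.
    """
    batch_size, seq_len, num_class = label.shape
    # Valid-frame mask: 1.0 where t < length[b], shape (B, T, 1)
    frame_idx = torch.arange(seq_len, device=label.device)
    mask = (frame_idx[None, :] < length[:, None]).unsqueeze(-1).to(output.dtype)
    num_valid = mask.sum() * num_class

    perm_list = list(permutations(range(num_class)))
    losses = []
    for perm in perm_list:
        # Reorder the label's class dimension according to this permutation
        permuted = label[:, :, list(perm)]
        elem_loss = F.binary_cross_entropy_with_logits(
            output, permuted, reduction="none"
        )
        # Zero out padded frames, then average over valid elements
        losses.append((elem_loss * mask).sum() / num_valid)

    loss_stack = torch.stack(losses)
    # One permutation for the whole batch, matching the single-int min_idx
    min_idx = int(torch.argmin(loss_stack))
    return loss_stack[min_idx], min_idx, perm_list


# Usage: a 3-frame, 2-class example whose logits match the label
# only after swapping the two classes.
label = torch.tensor([[[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]])  # (1, 3, 2)
output = (label[:, :, [1, 0]] * 2 - 1) * 10.0  # confident logits, classes swapped
loss, min_idx, perms = pit_loss_sketch(output, label, torch.tensor([3]))
print(perms[min_idx])  # (1, 0)

# Frames beyond `length` are ignored: corrupt the last frame, mark it as padding
out_pad = output.clone()
out_pad[:, 2, :] = torch.tensor([-10.0, 10.0])
loss_pad, idx_pad, _ = pit_loss_sketch(out_pad, label, torch.tensor([2]))
```

Choosing one permutation per batch keeps the second return value a single index; a per-instance variant would instead take the minimum over permutations separately for each batch element and return one index per instance.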