# s3prl/nn/speaker_loss.py

"""
Speaker verification loss

Authors:
  * Haibin Wu 2022
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = [
    "softmax",
    "amsoftmax",
]


class softmax(nn.Module):
    """
    The standard softmax loss in a unified interface for all speaker-related softmax losses
    """

    def __init__(self, input_size: int, output_size: int):
        super().__init__()
        self._indim = input_size
        self._outdim = output_size
        self.fc = nn.Linear(input_size, output_size)
        self.criterion = nn.CrossEntropyLoss()

    @property
    def input_size(self):
        return self._indim

    @property
    def output_size(self):
        return self._outdim
    def forward(self, x: torch.Tensor, label: torch.LongTensor):
        """
        Args:
            x (torch.Tensor): (batch_size, input_size)
            label (torch.LongTensor): (batch_size, )

        Returns:
            loss (torch.float)
            logit (torch.Tensor): (batch_size, output_size)
        """
        assert x.size(0) == label.size(0)
        assert x.size(1) == self.input_size

        x = F.normalize(x, dim=1)
        x = self.fc(x)
        loss = self.criterion(x, label)
        return loss, x
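
# A minimal usage sketch (not part of the original module): it shows how
# `softmax` is typically driven with a batch of speaker embeddings of shape
# (batch_size, input_size) and integer speaker labels. The function name
# `_demo_softmax` and the sizes 192 / 10 / 4 are illustrative assumptions,
# not values taken from s3prl.
def _demo_softmax():
    loss_fn = softmax(input_size=192, output_size=10)
    x = torch.randn(4, 192)             # a batch of 4 speaker embeddings
    label = torch.randint(0, 10, (4,))  # ground-truth speaker ids
    loss, logit = loss_fn(x, label)
    assert logit.shape == (4, 10)       # one logit per speaker class
    return loss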
class amsoftmax(nn.Module):
    """
    AMSoftmax

    Args:
        input_size (int): The input feature size
        output_size (int): The output feature size
        margin (float): Hyperparameter denoting the margin to the decision boundary
        scale (float): Hyperparameter that scales the cosine value
    """

    def __init__(
        self, input_size: int, output_size: int, margin: float = 0.2, scale: float = 30
    ):
        super().__init__()
        self._indim = input_size
        self._outdim = output_size
        self.margin = margin
        self.scale = scale
        self.W = torch.nn.Parameter(
            torch.randn(input_size, output_size), requires_grad=True
        )
        self.ce = nn.CrossEntropyLoss()
        nn.init.xavier_normal_(self.W, gain=1)

    @property
    def input_size(self):
        return self._indim

    @property
    def output_size(self):
        return self._outdim
    def forward(self, x: torch.Tensor, label: torch.LongTensor):
        """
        Args:
            x (torch.Tensor): (batch_size, input_size)
            label (torch.LongTensor): (batch_size, )

        Returns:
            loss (torch.float)
            logit (torch.Tensor): (batch_size, output_size)
        """
        assert x.size(0) == label.size(0)
        assert x.size(1) == self.input_size

        # cosine similarity between L2-normalized embeddings and class weights
        x_norm = torch.norm(x, p=2, dim=1, keepdim=True).clamp(min=1e-12)
        x_norm = torch.div(x, x_norm)
        w_norm = torch.norm(self.W, p=2, dim=0, keepdim=True).clamp(min=1e-12)
        w_norm = torch.div(self.W, w_norm)
        costh = torch.mm(x_norm, w_norm)

        # subtract the additive margin from the target-class cosine only
        delta_costh = torch.zeros_like(costh).scatter_(
            1, label.view(-1, 1).to(costh.device), self.margin
        )
        costh_m = costh - delta_costh
        costh_m_s = self.scale * costh_m
        loss = self.ce(costh_m_s, label)
        return loss, costh_m_s
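
# A minimal usage sketch (not part of the original module): `amsoftmax`
# subtracts `margin` from the target-class cosine before scaling, so the
# returned logits are margin-adjusted cosines times `scale`, each bounded
# by scale * (1 + margin) in absolute value. The function name
# `_demo_amsoftmax` and the sizes 192 / 10 / 4 are illustrative assumptions.
def _demo_amsoftmax():
    loss_fn = amsoftmax(input_size=192, output_size=10, margin=0.2, scale=30)
    x = torch.randn(4, 192)             # a batch of 4 speaker embeddings
    label = torch.randint(0, 10, (4,))  # ground-truth speaker ids
    loss, logit = loss_fn(x, label)
    assert logit.shape == (4, 10)
    assert logit.abs().max() <= 30 * (1 + 0.2)
    return loss


if __name__ == "__main__":
    print("softmax demo loss:", _demo_softmax().item())
    print("amsoftmax demo loss:", _demo_amsoftmax().item())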