"""
Common pooling methods
Authors:
* Leo 2022
* Haibin Wu 2022
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = [
"MeanPooling",
"TemporalAveragePooling",
"TemporalStatisticsPooling",
"SelfAttentivePooling",
"AttentiveStatisticsPooling",
]


class MeanPooling(nn.Module):
    """
    Temporal Average Pooling: mean pooling over the time (frame) axis.
    """

    def __init__(self, input_size: int):
        super().__init__()
        self._in_size = input_size

    @property
    def input_size(self) -> int:
        return self._in_size

    @property
    def output_size(self) -> int:
        return self._in_size

    def forward(self, xs: torch.Tensor, xs_len: torch.LongTensor):
        """
        Args:
            xs (torch.Tensor): Input tensor (#batch, frames, input_size).
            xs_len (torch.LongTensor): Lengths (#batch,), the number of valid frames per sample.

        Returns:
            torch.Tensor: Output tensor (#batch, output_size).
        """
        pooled_list = []
        for x, x_len in zip(xs, xs_len):
            # Average only the valid frames, ignoring any padding.
            pooled = torch.mean(x[:x_len], dim=0)
            pooled_list.append(pooled)
        return torch.stack(pooled_list)


# Alias: Temporal Average Pooling is the same operation as MeanPooling.
TemporalAveragePooling = MeanPooling
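
# Minimal usage sketch (illustrative only; shapes and values are hypothetical):
#
#     pool = MeanPooling(input_size=8)
#     xs = torch.randn(2, 10, 8)          # (#batch=2, frames=10, input_size=8)
#     xs_len = torch.LongTensor([10, 6])  # the second sample has 6 valid frames
#     out = pool(xs, xs_len)              # -> (2, 8); padded frames are excluded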


class TemporalStatisticsPooling(nn.Module):
    """
    TemporalStatisticsPooling
    Paper: X-vectors: Robust DNN Embeddings for Speaker Recognition
    Link: http://www.danielpovey.com/files/2018_icassp_xvectors.pdf
    """

    def __init__(self, input_size: int):
        super().__init__()
        self._input_size = input_size

    @property
    def input_size(self) -> int:
        return self._input_size

    @property
    def output_size(self) -> int:
        return self._input_size * 2

    def forward(self, xs: torch.Tensor, xs_len: torch.LongTensor):
        """
        Computes Temporal Statistics Pooling: the concatenation of the
        per-dimension mean and standard deviation over the valid frames.

        Args:
            xs (torch.Tensor): Input tensor (#batch, frames, input_size).
            xs_len (torch.LongTensor): Lengths (#batch,), the number of valid frames per sample.

        Returns:
            torch.Tensor: Output tensor (#batch, output_size).
        """
        pooled_list = []
        for x, x_len in zip(xs, xs_len):
            # Mean and standard deviation over the valid frames only.
            mean = torch.mean(x[:x_len], dim=0)
            std = torch.std(x[:x_len], dim=0)
            pooled = torch.cat((mean, std), dim=-1)
            pooled_list.append(pooled)
        return torch.stack(pooled_list)
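
# Minimal usage sketch (illustrative only; shapes are hypothetical):
#
#     pool = TemporalStatisticsPooling(input_size=8)
#     xs = torch.randn(2, 10, 8)
#     xs_len = torch.LongTensor([10, 6])
#     out = pool(xs, xs_len)  # -> (2, 16): [mean (8) || std (8)] per sample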


class SelfAttentivePooling(nn.Module):
    """
    SelfAttentivePooling
    Paper: Self-Attentive Speaker Embeddings for Text-Independent Speaker Verification
    Link: https://danielpovey.com/files/2018_interspeech_xvector_attention.pdf
    """

    def __init__(self, input_size: int):
        super().__init__()
        self._indim = input_size
        self.sap_linear = nn.Linear(input_size, input_size)
        self.attention = nn.Parameter(torch.empty(input_size, 1))
        # torch.FloatTensor(input_size, 1) would leave the weights
        # uninitialized; draw them from a Xavier-normal distribution instead.
        nn.init.xavier_normal_(self.attention)

    @property
    def input_size(self) -> int:
        return self._indim

    @property
    def output_size(self) -> int:
        return self._indim

    def forward(self, xs: torch.Tensor, xs_len: torch.LongTensor):
        """
        Computes Self-Attentive Pooling: a learned attention weight per frame,
        followed by an attention-weighted average over the valid frames.

        Args:
            xs (torch.Tensor): Input tensor (#batch, frames, input_size).
            xs_len (torch.LongTensor): Lengths (#batch,), the number of valid frames per sample.

        Returns:
            torch.Tensor: Output tensor (#batch, output_size).
        """
        pooled_list = []
        for x, x_len in zip(xs, xs_len):
            # Keep only the valid frames and add a batch axis: (1, x_len, input_size).
            x = x[:x_len].unsqueeze(0)
            # Frame-wise attention scores: h = tanh(W x + b), w = softmax(h @ v).
            h = torch.tanh(self.sap_linear(x))
            w = torch.matmul(h, self.attention).squeeze(dim=2)
            w = F.softmax(w, dim=1).view(x.size(0), x.size(1), 1)
            # Attention-weighted average over frames.
            x = torch.sum(x * w, dim=1)
            pooled_list.append(x.squeeze(0))
        return torch.stack(pooled_list)
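
# Minimal usage sketch (illustrative only; shapes are hypothetical):
#
#     pool = SelfAttentivePooling(input_size=8)
#     xs = torch.randn(2, 10, 8)
#     xs_len = torch.LongTensor([10, 6])
#     out = pool(xs, xs_len)  # -> (2, 8): attention-weighted mean over frames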


class AttentiveStatisticsPooling(nn.Module):
    """
    AttentiveStatisticsPooling
    Paper: Attentive Statistics Pooling for Deep Speaker Embedding
    Link: https://arxiv.org/pdf/1803.10963.pdf
    """

    def __init__(self, input_size: int):
        super().__init__()
        self._indim = input_size
        self.sap_linear = nn.Linear(input_size, input_size)
        self.attention = nn.Parameter(torch.empty(input_size, 1))
        # torch.FloatTensor(input_size, 1) would leave the weights
        # uninitialized; draw them from a Xavier-normal distribution instead.
        nn.init.xavier_normal_(self.attention)

    @property
    def input_size(self) -> int:
        return self._indim

    @property
    def output_size(self) -> int:
        return self._indim * 2

    def forward(self, xs: torch.Tensor, xs_len: torch.LongTensor):
        """
        Computes Attentive Statistics Pooling: the concatenation of the
        attention-weighted mean and standard deviation over the valid frames.

        Args:
            xs (torch.Tensor): Input tensor (#batch, frames, input_size).
            xs_len (torch.LongTensor): Lengths (#batch,), the number of valid frames per sample.

        Returns:
            torch.Tensor: Output tensor (#batch, output_size).
        """
        pooled_list = []
        for x, x_len in zip(xs, xs_len):
            # Keep only the valid frames and add a batch axis: (1, x_len, input_size).
            x = x[:x_len].unsqueeze(0)
            # Frame-wise attention weights, as in SelfAttentivePooling.
            h = torch.tanh(self.sap_linear(x))
            w = torch.matmul(h, self.attention).squeeze(dim=2)
            w = F.softmax(w, dim=1).view(x.size(0), x.size(1), 1)
            # Attention-weighted mean and standard deviation; the clamp keeps
            # the variance estimate non-negative before the square root.
            mu = torch.sum(x * w, dim=1)
            rh = torch.sqrt((torch.sum((x**2) * w, dim=1) - mu**2).clamp(min=1e-5))
            x = torch.cat((mu, rh), 1).squeeze(0)
            pooled_list.append(x)
        return torch.stack(pooled_list)
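

if __name__ == "__main__":
    # Quick smoke test (illustrative sketch; sizes are arbitrary): pool a batch
    # of 2 padded sequences with each method and print the output shapes.
    xs = torch.randn(2, 10, 8)          # (#batch, frames, input_size)
    xs_len = torch.LongTensor([10, 6])  # valid frames per sample
    for pool in (
        MeanPooling(8),
        TemporalStatisticsPooling(8),
        SelfAttentivePooling(8),
        AttentiveStatisticsPooling(8),
    ):
        out = pool(xs, xs_len)
        print(pool.__class__.__name__, tuple(out.shape))
        assert out.shape == (2, pool.output_size)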