# Source code for s3prl.nn.interface

"""
Model interfaces

Authors:
  * Leo 2022
"""

from typing import List, Tuple

import torch
import torch.nn as nn

# Public API of this module: the four abstract model interfaces defined below.
__all__ = ["AbsUpstream", "AbsFeaturizer", "AbsFrameModel", "AbsUtteranceModel"]


class AbsUpstream(nn.Module):
    """
    The upstream model should follow this interface. Please subclass it.

    Every member defined here only raises ``NotImplementedError``;
    subclasses are expected to override all of them.
    """

    @property
    def num_layer(self) -> int:
        """
        Number of hidden states.
        """
        raise NotImplementedError

    @property
    def hidden_sizes(self) -> List[int]:
        """
        Hidden size of each hidden state.
        """
        raise NotImplementedError

    @property
    def downsample_rates(self) -> List[int]:
        """
        Downsample rate from 16 KHz waveforms for each hidden state.
        """
        raise NotImplementedError

    def forward(
        self, wavs: torch.FloatTensor, wavs_len: torch.LongTensor
    ) -> Tuple[List[torch.FloatTensor], List[torch.LongTensor]]:
        """
        Args:
            wavs (torch.FloatTensor): (batch_size, seq_len, 1)
            wavs_len (torch.LongTensor): (batch_size, )

        Returns:
            tuple:

            1. all_hs (List[torch.FloatTensor]): all the hidden states
            2. all_hs_len (List[torch.LongTensor]): the lengths for all the hidden states

        Raises:
            NotImplementedError: always, until overridden by a subclass.
        """
        raise NotImplementedError
class AbsFeaturizer(nn.Module):
    """
    The featurizer should follow this interface. Please subclass it.

    The featurizer's mission is to reduce (standardize) the multiple hidden
    states from :obj:`AbsUpstream` into a single hidden state, so that
    the downstream model can use it as a conventional representation.

    Every member defined here only raises ``NotImplementedError``;
    subclasses are expected to override all of them.
    """

    @property
    def output_size(self) -> int:
        """
        The output size after hidden states reduction.
        """
        raise NotImplementedError

    @property
    def downsample_rate(self) -> int:
        """
        The downsample rate from 16 KHz waveform of the reduced single hidden state.
        """
        raise NotImplementedError

    def forward(
        self, all_hs: List[torch.FloatTensor], all_hs_len: List[torch.LongTensor]
    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        """
        Args:
            all_hs (List[torch.FloatTensor]): all the hidden states
            all_hs_len (List[torch.LongTensor]): the lengths for all the hidden states

        Returns:
            tuple:

            1. hs (torch.FloatTensor)
            2. hs_len (torch.LongTensor)

        Raises:
            NotImplementedError: always, until overridden by a subclass.
        """
        raise NotImplementedError
class AbsFrameModel(nn.Module):
    """
    The frame-level model interface.

    Every member defined here only raises ``NotImplementedError``;
    subclasses are expected to override all of them.
    """

    @property
    def input_size(self) -> int:
        """
        Size of the input accepted by the model (to be defined by subclasses).
        """
        raise NotImplementedError

    @property
    def output_size(self) -> int:
        """
        Size of the output produced by the model (to be defined by subclasses).
        """
        raise NotImplementedError

    def forward(
        self, x: torch.FloatTensor, x_len: torch.LongTensor
    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        """
        Args:
            x (torch.FloatTensor): the input features
                (exact shape is not specified by this interface)
            x_len (torch.LongTensor): presumably the valid length of each
                item in ``x`` -- TODO confirm against concrete subclasses

        Returns:
            tuple: a ``torch.FloatTensor`` and a ``torch.LongTensor``,
            per the return annotation

        Raises:
            NotImplementedError: always, until overridden by a subclass.
        """
        raise NotImplementedError
class AbsUtteranceModel(nn.Module):
    """
    The utterance-level model interface, which pools the temporal dimension.

    Every member defined here only raises ``NotImplementedError``;
    subclasses are expected to override all of them.
    """

    @property
    def input_size(self) -> int:
        """
        Size of the input accepted by the model (to be defined by subclasses).
        """
        raise NotImplementedError

    @property
    def output_size(self) -> int:
        """
        Size of the output produced by the model (to be defined by subclasses).
        """
        raise NotImplementedError

    def forward(
        self, x: torch.FloatTensor, x_len: torch.LongTensor
    ) -> torch.FloatTensor:
        """
        Args:
            x (torch.FloatTensor): the input features
                (exact shape is not specified by this interface)
            x_len (torch.LongTensor): presumably the valid length of each
                item in ``x`` -- TODO confirm against concrete subclasses

        Returns:
            torch.FloatTensor: a single tensor, per the return annotation

        Raises:
            NotImplementedError: always, until overridden by a subclass.
        """
        raise NotImplementedError