rnn

(s3prl.nn.rnn)

RNN models used in the SUPERB Benchmark.

Authors:
  • Heng-Jui Chang 2022

  • Leo 2022

RNNEncoder

class s3prl.nn.rnn.RNNEncoder(input_size: int, output_size: int, module: str = 'LSTM', proj_size: int = 1024, hidden_size: List[int] = [1024], dropout: List[float] = [0.0], layer_norm: List[bool] = [False], proj: List[bool] = [True], sample_rate: List[int] = [1], sample_style: str = 'drop', bidirectional: bool = False)

Bases: AbsFrameModel

RNN encoder for sequence-to-sequence modeling, e.g., ASR.

Parameters:
  • input_size (int) – Input size.

  • output_size (int) – Output size.

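  • module (str, optional) – RNN module type. Defaults to “LSTM”.

  • proj_size (int, optional) – Projection size, presumably used by the layers where proj is enabled. Defaults to 1024.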

  • hidden_size (List[int], optional) – Hidden sizes for each layer. Defaults to [1024].

  • dropout (List[float], optional) – Dropout rates for each layer. Defaults to [0.0].

  • layer_norm (List[bool], optional) – Whether to use layer norm for each layer. Defaults to [False].

  • proj (List[bool], optional) – Whether to use projection for each layer. Defaults to [True].

  • sample_rate (List[int], optional) – Downsample rates for each layer. Defaults to [1].

  • sample_style (str, optional) – Downsample style (“drop” or “concat”). Defaults to “drop”.

  • bidirectional (bool, optional) – Whether RNN layers are bidirectional. Defaults to False.
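The per-layer arguments above are lists with one entry per layer. The following is a minimal sketch, not s3prl code, of how such a configuration could be sanity-checked and how the per-layer downsample rates compound. The helper name `plan_layers`, the equal-length requirement, and the floor division per layer are assumptions made for illustration.

```python
from math import prod
from typing import List


def plan_layers(seq_len: int, hidden_size: List[int], sample_rate: List[int]) -> List[int]:
    """Hypothetical helper: return the sequence length after each layer."""
    # Assumption: every per-layer list holds exactly one entry per layer.
    if len(hidden_size) != len(sample_rate):
        raise ValueError("per-layer lists must have the same length")

    lengths = []
    for rate in sample_rate:
        # Assumption: each layer shortens the sequence by floor division.
        seq_len = seq_len // rate
        lengths.append(seq_len)
    return lengths


# Three layers; the last two each halve the time resolution.
sample_rate = [1, 2, 2]
lengths = plan_layers(seq_len=100, hidden_size=[1024, 1024, 1024], sample_rate=sample_rate)
total_rate = prod(sample_rate)

print(lengths)     # sequence length after each layer
print(total_rate)  # overall downsampling factor
```

Under these assumptions, the overall time reduction is the product of the per-layer rates, so the lengths returned alongside the output can be shorter than the input lengths whenever any rate is greater than 1.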

forward(x: Tensor, x_len: LongTensor)
Parameters:
  • x (torch.FloatTensor) – (batch_size, seq_len, input_size)

  • x_len (torch.LongTensor) – (batch_size, )

Returns:

  1. ys (torch.FloatTensor): (batch_size, seq_len, output_size)

  2. ys_len (torch.LongTensor): (batch_size, )

Return type:

tuple

property input_size: int
property output_size: int
call_super_init: bool = False
dump_patches: bool = False
training: bool

SuperbDiarizationModel

class s3prl.nn.rnn.SuperbDiarizationModel(input_size: int, output_size: int, rnn_layers: int, hidden_size: int)

Bases: AbsFrameModel

The exact RNN model used in the SUPERB Benchmark for speaker diarization.

Parameters:
  • input_size (int) – Input size.

  • output_size (int) – Output size.

  • rnn_layers (int) – Number of RNN layers.

  • hidden_size (int) – Hidden size shared by all RNN layers.

property input_size: int
property output_size: int
call_super_init: bool = False
dump_patches: bool = False
forward(xs, xs_len)
Parameters:
  • xs (torch.FloatTensor) – (batch_size, seq_len, input_size)

  • xs_len (torch.LongTensor) – (batch_size, )

Returns:

  1. ys (torch.FloatTensor): (batch_size, seq_len, output_size)

  2. ys_len (torch.LongTensor): (batch_size, )

Return type:

tuple
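The shapes above describe a frame-wise model: one output vector per input frame, with lengths passed through unchanged. Below is a minimal PyTorch sketch of a model with the same interface. It is not the s3prl implementation; the class name `FramewiseRNN`, the choice of LSTM as the recurrent cell, and the single linear output layer are assumptions for illustration only.

```python
import torch


class FramewiseRNN(torch.nn.Module):
    """Hypothetical stand-in with the same constructor arguments and shapes."""

    def __init__(self, input_size: int, output_size: int, rnn_layers: int, hidden_size: int):
        super().__init__()
        # Assumption: a stacked unidirectional LSTM sharing one hidden size.
        self.rnn = torch.nn.LSTM(
            input_size, hidden_size, num_layers=rnn_layers, batch_first=True
        )
        # Maps each frame's hidden state to output_size values.
        self.linear = torch.nn.Linear(hidden_size, output_size)

    def forward(self, xs: torch.Tensor, xs_len: torch.LongTensor):
        hidden, _ = self.rnn(xs)   # (batch_size, seq_len, hidden_size)
        ys = self.linear(hidden)   # (batch_size, seq_len, output_size)
        return ys, xs_len          # no downsampling, so lengths are unchanged


model = FramewiseRNN(input_size=16, output_size=2, rnn_layers=2, hidden_size=32)
xs = torch.randn(4, 10, 16)                 # padded batch of 4 utterances
xs_len = torch.LongTensor([10, 9, 7, 5])    # valid frames per utterance
ys, ys_len = model(xs, xs_len)
print(ys.shape, ys_len)
```

In this sketch the padded frames are still processed by the LSTM; a caller would mask them out using `ys_len` before computing a loss.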

training: bool

RNNLayer

class s3prl.nn.rnn.RNNLayer(input_size: int, hidden_size: int, module: str, dropout: float = 0.0, bidirectional: bool = False, proj: bool = False, layer_norm: bool = False, sample_rate: int = 1, sample_style: str = 'drop')

Bases: Module

RNN Layer

Parameters:
  • input_size (int) – Input size.

  • hidden_size (int) – Hidden size.

  • module (str) – RNN module type (“RNN”, “GRU”, or “LSTM”).

  • dropout (float, optional) – Dropout rate. Defaults to 0.0.

  • bidirectional (bool, optional) – Whether the RNN is bidirectional. Defaults to False.

  • proj (bool, optional) – Whether to use a projection layer. Defaults to False.

  • layer_norm (bool, optional) – Whether to use layer normalization. Defaults to False.

  • sample_rate (int, optional) – Downsampling rate. Defaults to 1.

  • sample_style (str, optional) – Downsampling style (“drop” or “concat”). Defaults to “drop”.
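The two downsampling styles are only named above, so the following pure-Python sketch shows one common reading of them on a single sequence of frames: “drop” keeps every sample_rate-th frame, while “concat” stacks each group of sample_rate consecutive frames into one wider frame. The function `downsample` here is a hypothetical illustration, and the handling of leftover frames is an assumption rather than documented behavior.

```python
from typing import List

Frame = List[float]


def downsample(frames: List[Frame], rate: int, style: str = "drop") -> List[Frame]:
    """Hypothetical illustration of the two styles on one sequence."""
    if style == "drop":
        # Keep every rate-th frame; the feature size is unchanged.
        return frames[::rate]
    if style == "concat":
        # Assumption: trailing frames that do not fill a whole group are discarded.
        usable = len(frames) - len(frames) % rate
        # Concatenate each group of `rate` frames; the feature size grows by `rate`.
        return [
            [value for frame in frames[i:i + rate] for value in frame]
            for i in range(0, usable, rate)
        ]
    raise ValueError(f"unknown sample_style: {style}")


# Five frames with two features each.
frames = [[float(t), float(t) + 0.5] for t in range(5)]
dropped = downsample(frames, rate=2, style="drop")
stacked = downsample(frames, rate=2, style="concat")

print(len(dropped), len(dropped[0]))
print(len(stacked), len(stacked[0]))
```

The trade-off in this reading is that “drop” discards information but keeps the feature size, whereas “concat” keeps every retained frame at the cost of a feature size multiplied by the rate.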

forward(xs: Tensor, xs_len: LongTensor)
Parameters:
  • xs (torch.FloatTensor) – (batch_size, seq_len, input_size)

  • xs_len (torch.LongTensor) – (batch_size, )

Returns:

  1. ys (torch.FloatTensor): (batch_size, seq_len, output_size)

  2. ys_len (torch.LongTensor): (batch_size, )

Return type:

tuple
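The (xs, xs_len) pair above is the usual padded-batch convention: xs is zero-padded to a common seq_len and xs_len holds the number of valid frames per item. As a sketch of how a recurrent layer can respect those lengths, the example below uses PyTorch's packing utilities with a plain `torch.nn.LSTM`. Whether RNNLayer does exactly this internally is an assumption; the sizes are arbitrary.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)

batch_size, seq_len, input_size, hidden_size = 3, 6, 4, 8
xs = torch.randn(batch_size, seq_len, input_size)   # padded batch
xs_len = torch.LongTensor([6, 4, 2])                 # valid frames per item

lstm = torch.nn.LSTM(input_size, hidden_size, batch_first=True)

# Pack so the LSTM skips the padded frames of the shorter sequences.
packed = pack_padded_sequence(xs, xs_len, batch_first=True, enforce_sorted=False)
packed_out, _ = lstm(packed)

# Unpack back to a padded tensor; padded positions are filled with zeros.
ys, ys_len = pad_packed_sequence(packed_out, batch_first=True)

print(ys.shape)   # (batch_size, longest valid length, hidden_size)
print(ys_len)     # the original lengths, in the original batch order
```

Note that the unpacked time dimension equals the longest valid length in the batch, which matches seq_len here because the first item is unpadded.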

property input_size: int
property output_size: int
call_super_init: bool = False
dump_patches: bool = False
training: bool