speech2text_ctc_task

(s3prl.task.speech2text_ctc_task)

Speech2Text with CTC loss

Authors
  • Heng-Jui Chang 2022

Speech2TextCTCExample

class s3prl.task.speech2text_ctc_task.Speech2TextCTCExample(input_size=3, output_size=4)

Bases: Module

An example speech-to-text task with CTC objective

Parameters:
  • input_size (int, optional) – Input size. Defaults to 3.

  • output_size (int, optional) – Output size. Defaults to 4.

property input_size
property output_size
forward(x, x_len)
Parameters:
  • x (torch.Tensor) – (batch_size, timesteps, input_size)

  • x_len (torch.LongTensor) – (batch_size, )

Returns:
  • y (torch.Tensor) – (batch_size, output_size)

  • y_len (torch.LongTensor) – (batch_size, )
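
The x and x_len arguments follow the usual padded-batch convention: utterances of different lengths are padded to a common number of timesteps, and the original lengths are passed alongside. The following is a minimal pure-Python sketch of that convention using nested lists rather than tensors. The helper name pad_batch is hypothetical and is not part of s3prl.

```python
from typing import List, Tuple

Frames = List[List[float]]  # one utterance: (timesteps, input_size)


def pad_batch(seqs: List[Frames], pad_value: float = 0.0) -> Tuple[List[Frames], List[int]]:
    """Pad variable-length utterances to a common number of timesteps.

    Returns the padded batch, shaped (batch_size, timesteps, input_size),
    and the original lengths x_len, shaped (batch_size,).
    Hypothetical helper for illustration only.
    """
    x_len = [len(s) for s in seqs]
    max_len = max(x_len)
    input_size = len(seqs[0][0])
    x = [
        [list(frame) for frame in s]
        + [[pad_value] * input_size for _ in range(max_len - len(s))]
        for s in seqs
    ]
    return x, x_len


# Two utterances with input_size=3: the first has 2 frames, the second has 1.
batch = [
    [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
    [[0.7, 0.8, 0.9]],
]
x, x_len = pad_batch(batch)
print(len(x), len(x[0]), len(x[0][0]))  # 2 2 3
print(x_len)  # [2, 1]
```

With tensors, the same layout is typically produced by torch.nn.utils.rnn.pad_sequence with batch_first=True.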

call_super_init: bool = False
dump_patches: bool = False
training: bool

Speech2TextCTCTask

class s3prl.task.speech2text_ctc_task.Speech2TextCTCTask(model: Module, tokenizer: Tokenizer, decoder: Optional[Union[BeamDecoder, dict]] = None, log_metrics: List[str] = ['cer', 'wer'])

Bases: Task

Speech-to-text task with CTC objective

Parameters:
  • model (Speech2TextCTCExample) –

  • tokenizer (Tokenizer) – Text tokenizer.

  • decoder (Union[BeamDecoder, dict], optional) – Beam decoder or decoder’s config. Defaults to None.

  • log_metrics (List[str], optional) – Metrics to be logged. Defaults to ['cer', 'wer'].
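
The default metrics, cer and wer, are conventionally the character error rate and word error rate: the edit distance between hypothesis and reference, divided by the reference length. The sketch below shows that standard definition in pure Python. It is not taken from the s3prl source, and the library's own implementation may differ in details such as text normalization or how scores are aggregated over a corpus.

```python
from typing import List, Sequence


def edit_distance(ref: Sequence, hyp: Sequence) -> int:
    """Levenshtein distance between two sequences (characters or words)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1]


def error_rate(refs: List[Sequence], hyps: List[Sequence]) -> float:
    """Total edit distance divided by total reference length."""
    errors = sum(edit_distance(r, h) for r, h in zip(refs, hyps))
    total = sum(len(r) for r in refs)
    return errors / total


def cer(refs: List[str], hyps: List[str]) -> float:
    # A string is already a sequence of characters.
    return error_rate(refs, hyps)


def wer(refs: List[str], hyps: List[str]) -> float:
    # Split on whitespace to compare word sequences.
    return error_rate([r.split() for r in refs], [h.split() for h in hyps])


refs = ["hello world"]
hyps = ["hello word"]
print(cer(refs, hyps))  # 1 edit over 11 reference characters
print(wer(refs, hyps))  # 1 edit over 2 reference words
```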

predict(x: Tensor, x_len: LongTensor)
Parameters:
  • x (torch.Tensor) – (batch_size, timesteps, input_size)

  • x_len (torch.LongTensor) – (batch_size, )

Returns:
  • logits (torch.Tensor) – (batch_size, timesteps, output_size)

  • prediction (list) – prediction strings

  • valid_length (torch.LongTensor) – (batch_size, )
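
To illustrate how per-timestep logits relate to prediction strings, here is a minimal sketch of greedy CTC decoding for a single utterance: take the best class at each valid frame, collapse consecutive repeats, and drop blanks. This is the generic textbook procedure, not the s3prl code path. The function name, the vocab list, and the choice of index 0 as the blank symbol are assumptions made for the example.

```python
from typing import List


def ctc_greedy_decode(logits: List[List[float]], valid_length: int,
                      vocab: List[str], blank: int = 0) -> str:
    """Greedy CTC decoding of one utterance.

    logits: (timesteps, output_size) scores for a single utterance.
    valid_length: number of non-padded frames to decode.
    vocab and blank are illustrative assumptions, not the s3prl tokenizer.
    """
    # Best class per valid frame; padded frames are ignored.
    best = [max(range(len(frame)), key=frame.__getitem__)
            for frame in logits[:valid_length]]
    out = []
    prev = None
    for idx in best:
        # Emit a symbol only when it differs from the previous frame
        # and is not the blank.
        if idx != prev and idx != blank:
            out.append(vocab[idx])
        prev = idx
    return "".join(out)


vocab = ["<blank>", "a", "b"]
logits = [
    [0.1, 0.8, 0.1],  # a
    [0.1, 0.8, 0.1],  # a (repeat, collapsed)
    [0.9, 0.0, 0.1],  # blank
    [0.2, 0.7, 0.1],  # a (new, because a blank intervened)
    [0.1, 0.1, 0.8],  # b
    [0.1, 0.1, 0.8],  # b (repeat, collapsed)
    [0.0, 1.0, 0.0],  # padding frame, beyond valid_length
]
print(ctc_greedy_decode(logits, valid_length=6, vocab=vocab))  # aab
```

When a decoder is supplied to the task, beam search replaces this per-frame argmax.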

forward(_mode: str, x: Tensor, x_len: LongTensor, labels: ndarray, class_ids: LongTensor, unique_name: ndarray, beam_decode: bool = False, _dump_dir: Optional[str] = None)

Each forward step in the training loop

Parameters:
  • _mode (str) – train / valid / test

  • x (torch.Tensor) – Input waveform or acoustic features. (batch_size, timesteps, input_size)

  • x_len (torch.LongTensor) – Lengths of inputs. (batch_size, )

  • labels (np.ndarray) – Ground truth transcriptions (str). (batch_size, )

  • class_ids (torch.LongTensor) – Tokenized ground truth transcriptions.

  • unique_name (np.ndarray) – Unique names for each sample.
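
As background for the CTC objective, the sketch below computes the CTC negative log-likelihood of one utterance with the standard forward algorithm: it sums the probability of every frame-level alignment that collapses to the target token sequence. How the task computes its loss internally is not shown on this page, so treat this as an illustration of the objective rather than the library's implementation. The function name and the blank index of 0 are assumptions. Practical implementations, such as torch.nn.CTCLoss, work in log space for numerical stability.

```python
import math
from typing import List


def ctc_neg_log_likelihood(probs: List[List[float]], target: List[int],
                           blank: int = 0) -> float:
    """CTC loss for one utterance via the forward algorithm.

    probs[t][k] is the probability of symbol k at frame t (each row sums to 1).
    target is the tokenized transcription (like one row of class_ids).
    Illustrative only; plain probabilities are used, not log probabilities.
    """
    # Interleave blanks: [blank, y1, blank, y2, ..., blank]
    ext = [blank]
    for y in target:
        ext += [y, blank]
    S, T = len(ext), len(probs)

    # alpha[s]: total probability of all partial alignments ending at ext[s].
    alpha = [0.0] * S
    alpha[0] = probs[0][ext[0]]
    if S > 1:
        alpha[1] = probs[0][ext[1]]

    for t in range(1, T):
        new = [0.0] * S
        for s in range(S):
            total = alpha[s]                 # stay on the same symbol
            if s >= 1:
                total += alpha[s - 1]        # advance by one position
            # Skipping a blank is allowed only between different labels.
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                total += alpha[s - 2]
            new[s] = total * probs[t][ext[s]]
        alpha = new

    # Valid alignments end on the last label or the trailing blank.
    likelihood = alpha[-1] + (alpha[-2] if S > 1 else 0.0)
    if likelihood == 0.0:
        return math.inf  # target cannot be produced from this many frames
    return -math.log(likelihood)


# Two frames, vocabulary {0: blank, 1: "a"}, uniform probabilities.
# Alignments for target [1]: "a-", "-a", "aa", each with probability 0.25.
uniform = [[0.5, 0.5], [0.5, 0.5]]
loss = ctc_neg_log_likelihood(uniform, target=[1])
print(loss)
```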

reduction(_mode: str, cached_results: List[dict], _dump_dir: Optional[str] = None)
call_super_init: bool = False
dump_patches: bool = False
get_state()
parse_cached_results(cached_results: List[dict])
set_state(state: dict)
abstract test_reduction()
abstract test_step()
abstract train_reduction()
abstract train_step()
abstract valid_reduction()
abstract valid_step()
training: bool