speech2text_ctc_task
(s3prl.task.speech2text_ctc_task)
Speech2Text with CTC loss
- Authors
Heng-Jui Chang 2022
Speech2TextCTCExample
- class s3prl.task.speech2text_ctc_task.Speech2TextCTCExample(input_size=3, output_size=4)[source]
Bases:
Module
An example speech-to-text task with CTC objective
- Parameters:
input_size (int, optional) – Input size. Defaults to 3.
output_size (int, optional) – Output size. Defaults to 4.
Speech2TextCTCTask
- class s3prl.task.speech2text_ctc_task.Speech2TextCTCTask(model: Module, tokenizer: Tokenizer, decoder: Optional[Union[BeamDecoder, dict]] = None, log_metrics: List[str] = ['cer', 'wer'])[source]
Bases:
Task
Speech-to-text task with CTC objective
- Parameters:
model (Speech2TextCTCExample) – Speech-to-text model producing frame-level logits.
tokenizer (Tokenizer) – Text tokenizer.
decoder (Union[BeamDecoder, dict], optional) – Beam decoder or decoder’s config. Defaults to None.
log_metrics (List[str], optional) – Metrics to be logged. Defaults to [“cer”, “wer”].
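The default `log_metrics` of `["cer", "wer"]` are character and word error rates, both computed from the Levenshtein edit distance between the decoded hypothesis and the reference transcription, normalized by reference length. A minimal, self-contained sketch of that computation (illustrative only, not s3prl's actual implementation):

```python
def edit_distance(hyp, ref):
    """Levenshtein distance between two sequences, single-row DP."""
    dp = list(range(len(ref) + 1))
    for i, h in enumerate(hyp, 1):
        prev, dp[0] = dp[0], i
        for j, r in enumerate(ref, 1):
            prev, dp[j] = dp[j], min(
                dp[j] + 1,        # delete h
                dp[j - 1] + 1,    # insert r
                prev + (h != r),  # substitute (free if equal)
            )
    return dp[-1]

def cer(hyp: str, ref: str) -> float:
    """Character error rate: edit distance over characters."""
    return edit_distance(list(hyp), list(ref)) / len(ref)

def wer(hyp: str, ref: str) -> float:
    """Word error rate: edit distance over whitespace-split words."""
    return edit_distance(hyp.split(), ref.split()) / len(ref.split())
```

For example, `wer("hello world", "hello there")` is 0.5 (one substituted word out of two).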
- predict(x: Tensor, x_len: LongTensor)[source]
- Parameters:
x (torch.Tensor) – (batch_size, timestamps, input_size)
x_len (torch.LongTensor) – (batch_size, )
- Returns:
logits (torch.Tensor): (batch_size, timestamps, output_size)
prediction (list): prediction strings
valid_length (torch.LongTensor): (batch_size, )
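The prediction strings come from decoding the logits; under greedy (best-path) CTC decoding, the per-frame argmax ids are collapsed by merging consecutive repeats and then dropping blanks. A pure-Python sketch of that collapse rule (the vocabulary and blank index below are illustrative, not s3prl's):

```python
def ctc_greedy_decode(frame_ids, blank=0):
    """Collapse a per-frame argmax id sequence CTC-style:
    merge consecutive repeats, then remove blank tokens."""
    out, prev = [], None
    for t in frame_ids:
        if t != prev and t != blank:  # new non-blank symbol
            out.append(t)
        prev = t
    return out

# Map collapsed ids to characters with a hypothetical vocabulary.
vocab = {1: "c", 2: "a", 3: "t"}
frames = [0, 1, 1, 0, 2, 2, 2, 0, 3, 3]    # argmax over (timestamps,)
text = "".join(vocab[i] for i in ctc_greedy_decode(frames))  # "cat"
```

Note that a blank between two identical ids keeps both (e.g. `[1, 0, 1]` decodes to `[1, 1]`), which is why CTC can emit doubled letters.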
- forward(_mode: str, x: Tensor, x_len: LongTensor, labels: ndarray, class_ids: LongTensor, unique_name: ndarray, beam_decode: bool = False, _dump_dir: Optional[str] = None)[source]
Each forward step in the training loop
- Parameters:
_mode (str) – train / valid / test
x (torch.Tensor) – Input waveform or acoustic features. (batch_size, timestamps, input_size)
x_len (torch.LongTensor) – Lengths of inputs. (batch_size, )
labels (np.ndarray) – Ground truth transcriptions (str). (batch_size, )
class_ids (torch.LongTensor) – Tokenized ground truth transcriptions.
unique_name (np.ndarray) – Unique names for each sample.