# Source code for s3prl.task.speech2text_ctc_task

"""
Speech2Text with CTC loss

Authors
  * Heng-Jui Chang 2022
"""

import logging
from pathlib import Path
from typing import List, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from s3prl.dataio.encoder.tokenizer import Tokenizer
from s3prl.metric import cer, per, wer
from s3prl.metric.slot_filling import (
    slot_edit_f1_full,
    slot_edit_f1_part,
    slot_type_f1,
    slot_value_cer,
    slot_value_wer,
)
from s3prl.nn import BeamDecoder

from . import Task

logger = logging.getLogger(__name__)

__all__ = [
    "Speech2TextCTCExample",
    "Speech2TextCTCTask",
]


class Speech2TextCTCExample(nn.Module):
    """An example speech-to-text model for the CTC task.

    Emits random logits of shape (batch_size, timestamps, output_size);
    intended only as a stand-in model for exercising the task wrapper.

    Args:
        input_size (int, optional): Input feature size. Defaults to 3.
        output_size (int, optional): Output vocabulary size. Defaults to 4.
    """

    def __init__(self, input_size: int = 3, output_size: int = 4):
        super().__init__()
        self._input_size = input_size
        self._output_size = output_size

    @property
    def input_size(self) -> int:
        return self._input_size

    @property
    def output_size(self) -> int:
        return self._output_size

    def forward(self, x: torch.Tensor, x_len: torch.LongTensor):
        """
        Args:
            x (torch.Tensor): (batch_size, timestamps, input_size)
            x_len (torch.LongTensor): (batch_size, )

        Return:
            output (torch.Tensor): (batch_size, timestamps, output_size)
            x_len (torch.LongTensor): (batch_size, )
        """
        assert x.size(-1) == self.input_size
        output = torch.randn(x.size(0), x.size(1), self.output_size)
        # BUGFIX: the original ended with `assert output, x_len`, which never
        # returned anything (and truth-testing a multi-element tensor raises
        # a RuntimeError). The lengths are passed through unchanged.
        return output, x_len
class Speech2TextCTCTask(Task):
    """Speech-to-text task with CTC objective.

    Args:
        model (torch.nn.Module): Acoustic model mapping (x, x_len) to
            (logits, logits_len).
        tokenizer (Tokenizer): Text tokenizer.
        decoder (Union[BeamDecoder, dict], optional): Beam decoder or
            decoder's config. Defaults to None.
        log_metrics (List[str], optional): Metrics to be logged.
            Defaults to ["cer", "wer"].
    """

    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: Tokenizer,
        decoder: Union[BeamDecoder, dict] = None,
        log_metrics: List[str] = None,
    ) -> None:
        super().__init__()

        self.model = model

        assert isinstance(tokenizer, Tokenizer)
        self.tokenizer = tokenizer
        # BUGFIX: the original signature used a mutable default argument
        # (log_metrics=["cer", "wer"]). A None sentinel gives the same
        # default without sharing one list across all instances.
        self.log_metrics = log_metrics if log_metrics is not None else ["cer", "wer"]

        # BeamDecoder is None when the optional flashlight backend is not
        # installed, in which case beam decoding is silently disabled.
        if BeamDecoder is None:
            decoder = None
        if isinstance(decoder, dict):
            decoder = BeamDecoder(**decoder)
            logger.info("Using flashlight decoder.")
        self.decoder = decoder

        # The tokenizer's pad index doubles as the CTC blank label.
        self.criterion = nn.CTCLoss(
            blank=self.tokenizer.pad_idx,
            zero_infinity=True,
        )

    def predict(self, x: torch.Tensor, x_len: torch.LongTensor):
        """Run the model and greedily decode its CTC outputs.

        Args:
            x (torch.Tensor): (batch_size, timestamps, input_size)
            x_len (torch.LongTensor): (batch_size, )

        Return:
            logits (torch.Tensor): (batch_size, timestamps, output_size)
            prediction (list): prediction strings
            valid_length (torch.LongTensor): (batch_size, )
        """
        logits, x_len = self.model(x, x_len)
        predicted_tokens = torch.argmax(logits, dim=2).detach().cpu()
        # Standard CTC greedy decoding: collapse consecutive repeats, then
        # drop blank (pad) and eos tokens before detokenizing.
        filtered_tokens = [
            [
                token
                for token in pred_token.unique_consecutive().tolist()
                if token != self.tokenizer.pad_idx and token != self.tokenizer.eos_idx
            ]
            for pred_token in predicted_tokens
        ]
        predictions = [
            self.tokenizer.decode(token_list) for token_list in filtered_tokens
        ]
        return logits, predictions, x_len

    def forward(
        self,
        _mode: str,
        x: torch.Tensor,
        x_len: torch.LongTensor,
        labels: np.ndarray,
        class_ids: torch.LongTensor,
        unique_name: np.ndarray,
        beam_decode: bool = False,
        _dump_dir: str = None,
    ):
        """
        Each forward step in the training loop

        Args:
            _mode (str): train / valid / test
            x (torch.Tensor): Input waveform or acoustic features.
                (batch_size, timestamps, input_size)
            x_len (torch.LongTensor): Lengths of inputs. (batch_size, )
            labels (np.ndarray): Ground truth transcriptions (str).
                (batch_size, )
            class_ids (torch.LongTensor): Tokenized ground truth
                transcriptions, padded with the tokenizer's pad index.
            unique_name (np.ndarray): Unique names for each sample.
            beam_decode (bool, optional): Also run beam decoding when a
                decoder is configured. Defaults to False.
            _dump_dir (str, optional): Unused here; consumed by reduction.

        Return:
            loss (torch.Tensor): scalar CTC loss
            cacheable (dict): per-batch results cached for reduction
        """
        logits, prediction, x_len = self.predict(x, x_len)
        log_probs = F.log_softmax(logits, dim=2)

        y = class_ids
        # Target lengths = number of non-pad tokens per sample.
        y_len = torch.tensor(
            [(ids != self.tokenizer.pad_idx).long().sum() for ids in class_ids],
            dtype=torch.long,
            device=logits.device,
        )

        # nn.CTCLoss expects log-probs as (time, batch, classes).
        loss = self.criterion(log_probs.transpose(0, 1), y, x_len, y_len)

        hyps = None
        if beam_decode and self.decoder is not None:
            hyps = self.decoder.decode(log_probs.detach())

        cacheable = dict(
            loss=loss.detach().cpu().item(),
            prediction=prediction,
            label=labels.tolist(),
            unique_name=unique_name.tolist(),
            hypotheses=hyps,
        )

        return loss, cacheable

    def reduction(self, _mode: str, cached_results: List[dict], _dump_dir: str = None):
        """Aggregate cached batch results into logged metrics.

        Args:
            _mode (str): train / valid / test
            cached_results (List[dict]): the `cacheable` dicts from forward
            _dump_dir (str, optional): If given, write `hyp` and `ref`
                transcription files into this directory.

        Return:
            logs (dict): metric name -> value, always including "loss"
        """
        results = self.parse_cached_results(cached_results)
        losses = results["loss"]
        predictions = results["prediction"]
        labels = results["label"]
        unique_names = results["unique_name"]

        if _dump_dir is not None:
            with (Path(_dump_dir) / "hyp").open("w") as f:
                f.writelines(
                    [f"{uid} {p}\n" for p, uid in zip(predictions, unique_names)]
                )
            with (Path(_dump_dir) / "ref").open("w") as f:
                f.writelines([f"{uid} {p}\n" for p, uid in zip(labels, unique_names)])

        beam_hyps = None
        # BUGFIX: guard against an empty results list before peeking at the
        # first element (the original indexed [0] unconditionally).
        if results["hypotheses"] and results["hypotheses"][0] is not None:
            # Take the top beam's word sequence for each sample.
            beam_hyps = [" ".join(hyp[0].words) for hyp in results["hypotheses"]]

        logs = {}
        logs["loss"] = float(np.mean(losses))

        if "wer" in self.log_metrics:
            logs["wer"] = wer(predictions, labels)
        if "cer" in self.log_metrics:
            logs["cer"] = cer(predictions, labels)
        if "per" in self.log_metrics:
            logs["per"] = per(predictions, labels)
        if "slot_type_f1" in self.log_metrics:
            logs["slot_type_f1"] = slot_type_f1(predictions, labels)
        if "slot_value_cer" in self.log_metrics:
            logs["slot_value_cer"] = slot_value_cer(predictions, labels)
        if "slot_value_wer" in self.log_metrics:
            logs["slot_value_wer"] = slot_value_wer(predictions, labels)
        if "slot_edit_f1_full" in self.log_metrics:
            logs["slot_edit_f1_full"] = slot_edit_f1_full(predictions, labels)
        if "slot_edit_f1_part" in self.log_metrics:
            logs["slot_edit_f1_part"] = slot_edit_f1_part(predictions, labels)

        if beam_hyps is not None:
            logs["wer_beam"] = wer(beam_hyps, labels)
            logs["char_beam"] = cer(beam_hyps, labels)

        return logs