task

(s3prl.task)

Defines how a model is trained and evaluated at each step of the train/valid/test loop.
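Every task exposes a `forward(mode, ...)` step and a `reduction(mode, ...)` that receives `cached_results`. A plausible reading of that pair is "run one step per batch, cache what it returns, then reduce the cache at the end of the epoch". The sketch below illustrates that flow in pure Python; `ToyTask`, `run_epoch`, and the `"valid/loss"` log key are hypothetical and are not part of s3prl.

```python
from typing import Any, Dict, List


class ToyTask:
    """Hypothetical stand-in for a Task: one step per batch, then a reduction."""

    def __call__(self, mode: str, **batch: Any) -> Dict[str, Any]:
        # One step: compute per-batch results that the loop will cache.
        xs = batch["x"]
        return {"loss": sum(xs) / len(xs)}

    def reduction(self, mode: str, cached_results: List[Dict[str, Any]]) -> Dict[str, float]:
        # Aggregate the cached per-step results into epoch-level logs.
        losses = [r["loss"] for r in cached_results]
        return {f"{mode}/loss": sum(losses) / len(losses)}


def run_epoch(task: ToyTask, mode: str, batches: List[Dict[str, Any]]) -> Dict[str, float]:
    cached_results = []
    for batch in batches:
        cached_results.append(task(mode, **batch))
    return task.reduction(mode, cached_results)


logs = run_epoch(ToyTask(), "valid", [{"x": [1.0, 3.0]}, {"x": [5.0, 7.0]}])
print(logs)  # {'valid/loss': 4.0}
```

Keeping the step and the reduction separate lets metrics that are not decomposable per batch (error rates, EER) be computed once over the whole split.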

s3prl.task.base

The abstract Task

s3prl.task.diarization

Diarization Permutation Invariant Task

s3prl.task.dump_feature

Dump feature Task

s3prl.task.event_prediction

s3prl.task.scene_prediction

s3prl.task.speaker_verification_task

Speaker Verification with Softmax-based loss

s3prl.task.speech2text_ctc_task

Speech2Text with CTC loss

s3prl.task.utterance_classification_task

Utterance Classification Tasks

Task

class s3prl.task.Task

Bases: Module

get_state()
set_state(state: dict)
parse_cached_results(cached_results: List[dict])
abstract predict()
forward(mode: str, *args, **kwargs)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

reduction(mode: str, *args, **kwargs)
abstract train_step()
abstract valid_step()
abstract test_step()
abstract train_reduction()
abstract valid_reduction()
abstract test_reduction()
call_super_init: bool = False
dump_patches: bool = False
training: bool
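The base class pairs `forward(mode, ...)` and `reduction(mode, ...)` with abstract `train_`/`valid_`/`test_` step and reduction methods. Assuming the `mode` string selects which of these runs, the dispatch can be mirrored in pure Python as below. `MiniTask` and `CountTask` are hypothetical names, and this is a sketch of the pattern rather than s3prl's implementation (which subclasses `torch.nn.Module`).

```python
import abc
from typing import Any, Dict, List


class MiniTask(abc.ABC):
    """Hypothetical pure-Python mirror of the Task interface (no torch)."""

    def __call__(self, mode: str, *args: Any, **kwargs: Any) -> Any:
        return self.forward(mode, *args, **kwargs)

    def forward(self, mode: str, *args: Any, **kwargs: Any) -> Any:
        # Assumed dispatch: "train" -> train_step, "valid" -> valid_step, ...
        return getattr(self, f"{mode}_step")(*args, **kwargs)

    def reduction(self, mode: str, *args: Any, **kwargs: Any) -> Any:
        # Assumed dispatch: "train" -> train_reduction, ...
        return getattr(self, f"{mode}_reduction")(*args, **kwargs)

    @abc.abstractmethod
    def train_step(self, *args: Any, **kwargs: Any) -> Any: ...

    @abc.abstractmethod
    def valid_step(self, *args: Any, **kwargs: Any) -> Any: ...

    @abc.abstractmethod
    def test_step(self, *args: Any, **kwargs: Any) -> Any: ...

    @abc.abstractmethod
    def train_reduction(self, *args: Any, **kwargs: Any) -> Any: ...

    @abc.abstractmethod
    def valid_reduction(self, *args: Any, **kwargs: Any) -> Any: ...

    @abc.abstractmethod
    def test_reduction(self, *args: Any, **kwargs: Any) -> Any: ...


class CountTask(MiniTask):
    """Counts items per batch and sums the counts in every reduction."""

    def train_step(self, xs: List[Any]) -> Dict[str, int]:
        return {"n": len(xs)}

    def valid_step(self, xs: List[Any]) -> Dict[str, int]:
        return self.train_step(xs)

    def test_step(self, xs: List[Any]) -> Dict[str, int]:
        return self.train_step(xs)

    def train_reduction(self, cached_results: List[Dict[str, int]]) -> Dict[str, int]:
        return {"total": sum(r["n"] for r in cached_results)}

    def valid_reduction(self, cached_results: List[Dict[str, int]]) -> Dict[str, int]:
        return self.train_reduction(cached_results)

    def test_reduction(self, cached_results: List[Dict[str, int]]) -> Dict[str, int]:
        return self.train_reduction(cached_results)


task = CountTask()
results = [task("valid", [1, 2, 3]), task("valid", [4])]
print(task.reduction("valid", results))  # {'total': 4}
```

Leaving any of the six methods unimplemented makes the subclass impossible to instantiate, which surfaces a missing mode early.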

DiarizationPIT

class s3prl.task.DiarizationPIT(model: Module)

Bases: Task

predict(x, x_len)
forward(_mode: str, x, x_len, label, label_len, record_id: str, chunk_id: int, _dump_dir: Optional[str] = None)

reduction(_mode: str, cached_results: List[dict], _dump_dir: Optional[str] = None)
call_super_init: bool = False
dump_patches: bool = False
get_state()
parse_cached_results(cached_results: List[dict])
set_state(state: dict)
abstract test_reduction()
abstract test_step()
abstract train_reduction()
abstract train_step()
abstract valid_reduction()
abstract valid_step()
training: bool
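The task name indicates a permutation invariant training (PIT) objective; its internals are not documented on this page. As an illustration, the sketch below computes PIT binary cross-entropy over frame-level speaker activity, assuming `probs` and `labels` have shape (frames, num_speakers). The helper names `bce` and `pit_bce` are hypothetical.

```python
import itertools
import math
from typing import Sequence, Tuple


def bce(p: float, y: float, eps: float = 1e-7) -> float:
    """Binary cross-entropy for one probability/label pair (clipped for safety)."""
    p = min(max(p, eps), 1.0 - eps)
    return -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))


def pit_bce(
    probs: Sequence[Sequence[float]], labels: Sequence[Sequence[float]]
) -> Tuple[float, Tuple[int, ...]]:
    """Return the minimum mean BCE over speaker permutations and the best permutation.

    perm[k] is the reference speaker column matched to model output k.
    """
    num_spk = len(labels[0])
    best = (math.inf, tuple(range(num_spk)))
    for perm in itertools.permutations(range(num_spk)):
        total = 0.0
        for p_row, y_row in zip(probs, labels):
            total += sum(bce(p_row[k], y_row[perm[k]]) for k in range(num_spk))
        loss = total / (len(labels) * num_spk)
        if loss < best[0]:
            best = (loss, perm)
    return best


probs = [[0.9, 0.1], [0.8, 0.2]]  # model outputs: 2 frames x 2 speakers
labels = [[0, 1], [0, 1]]         # reference activity, speaker order swapped
loss, perm = pit_bce(probs, labels)
print(perm)  # (1, 0)
```

Exhaustive search costs `num_speakers!` evaluations, which is fine for the small speaker counts typical of diarization chunks; larger counts usually call for an assignment solver instead.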

DumpFeature

class s3prl.task.DumpFeature(model: Module, dump_feat_dir: str = 'feat')

Bases: Task

forward(split: str, x, x_len, unique_name, _dump_dir: str)

reduction(split: str, batch_results: list, _dump_dir: str)
call_super_init: bool = False
dump_patches: bool = False
get_state()
parse_cached_results(cached_results: List[dict])
abstract predict()
set_state(state: dict)
abstract test_reduction()
abstract test_step()
abstract train_reduction()
abstract train_step()
abstract valid_reduction()
abstract valid_step()
training: bool
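`DumpFeature.forward` receives a padded batch `x`, per-utterance lengths `x_len`, utterance ids `unique_name`, and a `_dump_dir`. One way such a step could persist features is sketched below, assuming one file per utterance under `_dump_dir/dump_feat_dir` with padding trimmed by `x_len`. The helper `dump_batch` and the JSON file format are illustrative only; the on-disk format s3prl actually uses is not documented here.

```python
import json
import tempfile
from pathlib import Path
from typing import List, Sequence


def dump_batch(
    x: Sequence[Sequence[List[float]]],
    x_len: Sequence[int],
    unique_name: Sequence[str],
    dump_dir: str,
    dump_feat_dir: str = "feat",
) -> List[Path]:
    """Write each utterance's unpadded features to <dump_dir>/<dump_feat_dir>/<name>.json."""
    out_dir = Path(dump_dir) / dump_feat_dir
    out_dir.mkdir(parents=True, exist_ok=True)
    paths = []
    for feat, length, name in zip(x, x_len, unique_name):
        path = out_dir / f"{name}.json"
        # Drop the padded frames beyond this utterance's valid length.
        path.write_text(json.dumps(feat[:length]))
        paths.append(path)
    return paths


# Batch of 2 utterances, padded to 3 frames of 2-dim features.
x = [[[0.1, 0.2], [0.3, 0.4], [0.0, 0.0]], [[0.5, 0.6], [0.0, 0.0], [0.0, 0.0]]]
with tempfile.TemporaryDirectory() as tmp:
    paths = dump_batch(x, [2, 1], ["utt1", "utt2"], tmp)
    print(json.loads(paths[1].read_text()))  # [[0.5, 0.6]]
```

Keying files by `unique_name` keeps the dump reproducible and lets later stages look features up without re-running the model.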

SpeakerVerification

class s3prl.task.SpeakerVerification(model: SpeakerClassifier, category: CategoryEncoder, test_trials: Optional[List[Tuple[int, str, str]]] = None, loss_type: str = 'amsoftmax', loss_conf: Optional[dict] = None)

Bases: Task

model.output_size should match len(category).

Parameters:
  • model (SpeakerClassifier) – the actual model, or a callable config that builds it

  • category (CategoryEncoder) – maps each final prediction label (str) to a numeric class id

  • test_trials (List[Tuple[int, str, str]]) – each tuple in the list consists of (label, enroll_utt, test_utt)

  • loss_type (str) – 'softmax' or 'amsoftmax' (default)

  • loss_conf (dict) – arguments for the loss_type class
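The default `loss_type='amsoftmax'` refers to additive-margin softmax: logits are cosine similarities between the normalized embedding and normalized class weights, a margin is subtracted from the target class, and the result is scaled before cross-entropy. The single-example sketch below shows that computation in pure Python; the function name and the `margin`/`scale` defaults are illustrative assumptions, not the keys or values `loss_conf` actually accepts.

```python
import math
from typing import List, Sequence


def _normalize(v: Sequence[float]) -> List[float]:
    norm = math.sqrt(sum(a * a for a in v))
    return [a / norm for a in v]


def am_softmax_loss(
    embedding: Sequence[float],
    class_weights: Sequence[Sequence[float]],
    target: int,
    margin: float = 0.2,
    scale: float = 30.0,
) -> float:
    """Additive-margin softmax loss for one embedding (hypothetical helper)."""
    e = _normalize(embedding)
    logits = []
    for j, w in enumerate(class_weights):
        cos = sum(a * b for a, b in zip(e, _normalize(w)))
        if j == target:
            cos -= margin  # penalize the target class so it must win by a margin
        logits.append(scale * cos)
    # Numerically stable log-sum-exp, then cross-entropy against the target.
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum - logits[target]


weights = [[1.0, 0.0], [0.0, 1.0]]  # two speaker classes
print(am_softmax_loss([1.0, 0.0], weights, target=0) < am_softmax_loss([1.0, 0.0], weights, target=1))  # True
```

With `margin=0.0` this reduces to a scaled cosine softmax, which is one way to compare the two `loss_type` options.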

get_state()
set_state(state: dict)
predict(x: Tensor, x_len: LongTensor)
Parameters:
  • x (torch.Tensor) – (batch_size, timestamps, input_size)

  • x_len (torch.LongTensor) – (batch_size, )

Returns:

torch.Tensor

(batch_size, output_size)

train_step(x: Tensor, x_len: LongTensor, class_id: LongTensor, unique_name: List[str], _dump_dir: Optional[str] = None)
train_reduction(cached_results: list, _dump_dir: Optional[str] = None)
test_step(x: Tensor, x_len: LongTensor, unique_name: List[str], _dump_dir: str)
Parameters:
  • x (torch.Tensor) – (batch_size, timestamps, input_size)

  • x_len (torch.LongTensor) – (batch_size, )

  • unique_name (List[str]) –

Returns:

  • unique_name (List[str])

  • output (torch.Tensor) – speaker embeddings corresponding to unique_name

test_reduction(cached_results: List[dict], _dump_dir: str)
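Each entry of `test_trials` is `(label, enroll_utt, test_utt)`, and `test_step` returns embeddings keyed by `unique_name`. A common way to score such trials is cosine similarity between the enrollment and test embeddings; the sketch below does that in pure Python. Whether `test_reduction` uses cosine scoring (and which metric, such as EER, it reports afterwards) is an assumption, and `score_trials` is a hypothetical helper.

```python
import math
from typing import Dict, List, Sequence, Tuple


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def score_trials(
    embeddings: Dict[str, Sequence[float]], trials: List[Tuple[int, str, str]]
) -> List[Tuple[int, float]]:
    """Return (label, score) per trial; label 1 = same speaker, 0 = different."""
    return [(label, cosine(embeddings[enroll], embeddings[test])) for label, enroll, test in trials]


embeddings = {"spk1_a": [1.0, 0.0], "spk1_b": [0.9, 0.1], "spk2_a": [0.0, 1.0]}
trials = [(1, "spk1_a", "spk1_b"), (0, "spk1_a", "spk2_a")]
scores = score_trials(embeddings, trials)
print(scores)
```

The resulting `(label, score)` pairs are exactly what a threshold sweep needs to compute false-accept and false-reject rates.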
call_super_init: bool = False
dump_patches: bool = False
forward(mode: str, *args, **kwargs)

parse_cached_results(cached_results: List[dict])
reduction(mode: str, *args, **kwargs)
abstract valid_reduction()
abstract valid_step()
training: bool

Speech2TextCTCTask

class s3prl.task.Speech2TextCTCTask(model: Module, tokenizer: Tokenizer, decoder: Optional[Union[BeamDecoder, dict]] = None, log_metrics: List[str] = ['cer', 'wer'])

Bases: Task

Speech-to-text task with CTC objective

Parameters:
  • model (Speech2TextCTCExample) –

  • tokenizer (Tokenizer) – Text tokenizer.

  • decoder (Union[BeamDecoder, dict], optional) – Beam decoder or decoder’s config. Defaults to None.

  • log_metrics (List[str], optional) – Metrics to be logged. Defaults to [“cer”, “wer”].
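The default `log_metrics` are character and word error rate. Both are edit distance between reference and hypothesis divided by the reference length, at character or word granularity. The sketch below computes them corpus-wide in pure Python; the helper names are hypothetical, and counting spaces as characters for CER is an assumption (toolkits differ on this).

```python
from typing import List, Sequence


def edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Levenshtein distance (substitutions, insertions, deletions all cost 1)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        cur = [i]
        for j, h in enumerate(hyp, 1):
            cur.append(min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + (r != h)))
        prev = cur
    return prev[-1]


def error_rate(refs: List[str], hyps: List[str], unit: str) -> float:
    """Corpus-level error rate; unit is 'char' (CER) or 'word' (WER)."""
    errors = total = 0
    for ref, hyp in zip(refs, hyps):
        r = ref.split() if unit == "word" else list(ref)
        h = hyp.split() if unit == "word" else list(hyp)
        errors += edit_distance(r, h)
        total += len(r)
    return errors / total


refs, hyps = ["hello world"], ["hello word"]
print(error_rate(refs, hyps, "word"))  # 0.5
print(error_rate(refs, hyps, "char"))  # 1 deletion over 11 characters
```

Summing errors and reference lengths over the whole split, rather than averaging per-utterance rates, keeps short utterances from dominating the metric.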

predict(x: Tensor, x_len: LongTensor)
Parameters:
  • x (torch.Tensor) – (batch_size, timestamps, input_size)

  • x_len (torch.LongTensor) – (batch_size, )

Returns:

  • logits (torch.Tensor) – (batch_size, timestamps, output_size)

  • prediction (list) – prediction strings

  • valid_length (torch.LongTensor) – (batch_size, )
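`predict` returns frame-level `logits` together with `valid_length` and decoded `prediction` strings. Without a beam decoder, the usual way to turn CTC logits into text is greedy decoding: take the best token per valid frame, collapse consecutive repeats, then drop blanks. The sketch below shows that for one utterance; the blank id `0` and the toy `vocab` list stand in for the real `Tokenizer` and are assumptions.

```python
from typing import Sequence


def ctc_greedy_decode(
    logits: Sequence[Sequence[float]],
    valid_length: int,
    vocab: Sequence[str],
    blank_id: int = 0,
) -> str:
    """Greedy CTC decoding for one utterance (hypothetical helper)."""
    # Best token per frame, ignoring padded frames beyond valid_length.
    best_ids = [max(range(len(frame)), key=frame.__getitem__) for frame in logits[:valid_length]]
    out = []
    prev = None
    for idx in best_ids:
        # Emit only when the token changes and is not the blank.
        if idx != prev and idx != blank_id:
            out.append(vocab[idx])
        prev = idx
    return "".join(out)


vocab = ["<blank>", "a", "b"]
logits = [
    [0.1, 0.8, 0.1],  # a
    [0.2, 0.7, 0.1],  # a (repeat, collapsed)
    [0.9, 0.05, 0.05],  # blank separates the two a's
    [0.1, 0.8, 0.1],  # a
    [0.1, 0.1, 0.8],  # b
    [0.2, 0.1, 0.7],  # b (repeat, collapsed)
    [0.0, 0.0, 9.0],  # padding frame, ignored when valid_length=6
]
print(ctc_greedy_decode(logits, 6, vocab))  # aab
```

The blank frame is what allows a genuinely repeated token ("aa") to survive the collapse step.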

forward(_mode: str, x: Tensor, x_len: LongTensor, labels: ndarray, class_ids: LongTensor, unique_name: ndarray, beam_decode: bool = False, _dump_dir: Optional[str] = None)

A single forward step of the train/valid/test loop.

Parameters:
  • _mode (str) – one of train / valid / test

  • x (torch.Tensor) – Input waveform or acoustic features. (batch_size, timestamps, input_size)

  • x_len (torch.LongTensor) – Lengths of inputs. (batch_size, )

  • labels (np.ndarray) – Ground truth transcriptions (str). (batch_size, )

  • class_ids (torch.LongTensor) – Tokenized ground truth transcriptions.

  • unique_name (np.ndarray) – Unique names for each sample.

reduction(_mode: str, cached_results: List[dict], _dump_dir: Optional[str] = None)
call_super_init: bool = False
dump_patches: bool = False
get_state()
parse_cached_results(cached_results: List[dict])
set_state(state: dict)
abstract test_reduction()
abstract test_step()
abstract train_reduction()
abstract train_step()
abstract valid_reduction()
abstract valid_step()
training: bool

UtteranceClassificationTask

class s3prl.task.UtteranceClassificationTask(model: UtteranceClassifierExample, category: CategoryEncoder)

Bases: Task

input_size

defined by model.input_size

Type:

int

output_size

defined by len(category)

Type:

int

predict(x: Tensor, x_len: LongTensor)
Parameters:
  • x (torch.Tensor) – (batch_size, timestamps, input_size)

  • x_len (torch.LongTensor) – (batch_size, )

Returns:

  • logits (torch.Tensor) – (batch_size, output_size)

  • prediction (list) – prediction strings
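`predict` returns `logits` of shape (batch_size, output_size) plus prediction strings. Presumably the strings come from taking the arg-max class per utterance and decoding its id back to a label through the `category` encoder. The sketch below does this with a plain list of labels standing in for `CategoryEncoder`, and adds an accuracy helper of the kind a reduction might compute; both helpers are hypothetical.

```python
from typing import List, Sequence


def predict_labels(logits: Sequence[Sequence[float]], labels: Sequence[str]) -> List[str]:
    """Arg-max class per utterance, decoded to its label string."""
    return [labels[max(range(len(row)), key=row.__getitem__)] for row in logits]


def accuracy(predictions: Sequence[str], targets: Sequence[str]) -> float:
    return sum(p == t for p, t in zip(predictions, targets)) / len(targets)


labels = ["neutral", "happy", "sad"]  # stand-in for the category encoder's classes
logits = [[0.1, 2.0, 0.3], [1.5, 0.2, 0.1]]  # (batch_size=2, output_size=3)
preds = predict_labels(logits, labels)
print(preds)  # ['happy', 'neutral']
print(accuracy(preds, ["happy", "sad"]))  # 0.5
```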

forward(_mode: str, x: Tensor, x_len: LongTensor, class_id: LongTensor, label: List[str], unique_name: List[str], _dump_dir: Optional[str] = None)

reduction(_mode: str, cached_results: List[dict], _dump_dir: Optional[str] = None)
call_super_init: bool = False
dump_patches: bool = False
get_state()
parse_cached_results(cached_results: List[dict])
set_state(state: dict)
abstract test_reduction()
abstract test_step()
abstract train_reduction()
abstract train_step()
abstract valid_reduction()
abstract valid_step()
training: bool

UtteranceMultiClassClassificationTask

class s3prl.task.UtteranceMultiClassClassificationTask(model: UtteranceClassifierExample, categories: CategoryEncoders)

Bases: Task

predict(x: Tensor, x_len: LongTensor)
Parameters:
  • x (torch.Tensor) – (batch_size, timestamps, input_size)

  • x_len (torch.LongTensor) – (batch_size, )

Returns:

  • logit (torch.Tensor) – List[(batch_size, sub_output_size)]

  • prediction (np.array) – (batch_size, num_category)
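Here `predict` returns one logit block of shape (batch_size, sub_output_size) per category, and a `prediction` of shape (batch_size, num_category). A natural reading is an independent arg-max per category head, which the sketch below implements with plain lists; the per-category label lists stand in for `CategoryEncoders`, and `predict_multi` is a hypothetical helper.

```python
from typing import List, Sequence


def predict_multi(
    logits_per_category: Sequence[Sequence[Sequence[float]]],
    label_sets: Sequence[Sequence[str]],
) -> List[List[str]]:
    """Arg-max per category head; result has shape (batch_size, num_category)."""
    batch_size = len(logits_per_category[0])
    predictions = []
    for b in range(batch_size):
        row = []
        for logits, labels in zip(logits_per_category, label_sets):
            scores = logits[b]
            row.append(labels[max(range(len(scores)), key=scores.__getitem__)])
        predictions.append(row)
    return predictions


label_sets = [["neu", "hap"], ["F", "M"]]  # two categories, e.g. emotion and gender
logits_per_category = [
    [[0.2, 0.9], [0.8, 0.1]],  # category 0: (batch_size=2, 2)
    [[1.0, 0.0], [0.3, 0.7]],  # category 1: (batch_size=2, 2)
]
print(predict_multi(logits_per_category, label_sets))  # [['hap', 'F'], ['neu', 'M']]
```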

forward(_mode: str, x: Tensor, x_len: LongTensor, class_ids: LongTensor, labels: ndarray, unique_name: List[str], _dump_dir: Optional[str] = None)
Parameters:
  • x – torch.Tensor, (batch_size, timestamps, input_size)

  • x_len – torch.LongTensor, (batch_size)

  • class_ids – torch.LongTensor, (batch_size, num_category)

  • labels – np.ndarray, (batch_size, num_category)

Returns:

  • loss (torch.Tensor)

  • prediction (np.ndarray)

  • label (np.ndarray)
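`forward` takes `class_ids` of shape (batch_size, num_category) and returns a single `loss`. With one classification head per category, one plausible way to form that loss is a cross-entropy per head, averaged over the batch and summed over categories. The sketch below does exactly that in pure Python; whether s3prl sums or averages across categories is an assumption, and both helpers are hypothetical.

```python
import math
from typing import Sequence


def cross_entropy(scores: Sequence[float], target: int) -> float:
    """Softmax cross-entropy for one example, via a stable log-sum-exp."""
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return log_z - scores[target]


def multi_category_loss(
    logits_per_category: Sequence[Sequence[Sequence[float]]],
    class_ids: Sequence[Sequence[int]],
) -> float:
    """Batch-mean cross-entropy per category head, summed over categories."""
    batch_size = len(class_ids)
    total = 0.0
    for c, logits in enumerate(logits_per_category):
        batch_loss = sum(cross_entropy(logits[b], class_ids[b][c]) for b in range(batch_size))
        total += batch_loss / batch_size
    return total


# One utterance, two heads with uniform logits (2 and 3 classes).
loss = multi_category_loss([[[0.0, 0.0]], [[0.0, 0.0, 0.0]]], [[1, 2]])
print(loss)  # log(2) + log(3)
```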

reduction(_mode: str, cached_results: List[dict], _dump_dir: Optional[str] = None)
call_super_init: bool = False
dump_patches: bool = False
get_state()
parse_cached_results(cached_results: List[dict])
set_state(state: dict)
abstract test_reduction()
abstract test_step()
abstract train_reduction()
abstract train_step()
abstract valid_reduction()
abstract valid_step()
training: bool