# Source code for s3prl.task.base

"""
The abstract Task: a ``torch.nn.Module`` base class that dispatches
``forward(mode, ...)`` to ``{mode}_step`` and ``reduction(mode, ...)`` to
``{mode}_reduction``.

Authors
  * Leo 2022
"""

import abc
from collections import defaultdict
from typing import List

import torch

# Only the abstract Task base class is exported by ``from ... import *``.
__all__ = ["Task"]


class Task(torch.nn.Module):
    """
    Abstract base class for tasks.

    ``forward(mode, *args, **kwargs)`` dispatches to ``{mode}_step`` and
    ``reduction(mode, *args, **kwargs)`` dispatches to ``{mode}_reduction``.
    The hooks declared here are ``train``/``valid``/``test`` variants of
    both. Subclasses override the hooks (and ``predict``) they need.

    Note:
        ``torch.nn.Module`` does not use ``abc.ABCMeta``, so the
        ``@abc.abstractmethod`` markers below are documentary only. They are
        NOT enforced at instantiation time. Calling a hook that a subclass
        did not override raises ``NotImplementedError``.
    """

    def __init__(self) -> None:
        super().__init__()

    def get_state(self) -> dict:
        """
        Return extra, task-specific state to be checkpointed.

        Returns:
            dict: empty by default. Subclasses may return additional state.
        """
        # self.model will be separately saved, do not save self.model.state_dict() here
        return {}

    def set_state(self, state: dict) -> None:
        """
        Restore task-specific state produced by :meth:`get_state`.

        The base implementation does nothing.

        Args:
            state (dict): the state dict to restore
        """
        pass

    def parse_cached_results(self, cached_results: List[dict]) -> dict:
        """
        Merge a list of result dicts into a single dict of lists.

        Every dict must have exactly the same set of keys. For each key,
        ``tuple``/``list`` values are concatenated into the output list, and
        any other value is appended as a single element.

        Args:
            cached_results (List[dict]): the result dicts to merge

        Returns:
            dict: maps each key to the list of collected values. An empty
            ``cached_results`` yields an empty dict.

        Raises:
            ValueError: if some dict's keys differ from those of the first dict
        """
        if not cached_results:
            return {}

        # Dict keys are unique, so set equality is equivalent to comparing
        # sorted key lists. It also works for keys that cannot be ordered
        # against each other (e.g. a mix of int and str keys).
        reference_keys = set(cached_results[0].keys())
        dol = defaultdict(list)
        for index, result in enumerate(cached_results):
            if set(result.keys()) != reference_keys:
                raise ValueError(
                    f"cached_results[{index}] has keys {list(result.keys())}, "
                    f"expected the same keys as cached_results[0]: "
                    f"{list(cached_results[0].keys())}"
                )
            for key, value in result.items():
                if isinstance(value, (tuple, list)):
                    dol[key].extend(value)
                else:
                    dol[key].append(value)
        return dict(dol)

    @abc.abstractmethod
    def predict(self):
        raise NotImplementedError

    def forward(self, mode: str, *args, **kwargs):
        """
        Dispatch to the ``{mode}_step`` method, e.g. ``train_step``.

        Args:
            mode (str): prefix of the step method to call
            *args, **kwargs: forwarded unchanged to the step method

        Returns:
            whatever the selected step method returns
        """
        return getattr(self, f"{mode}_step")(*args, **kwargs)

    def reduction(self, mode: str, *args, **kwargs):
        """
        Dispatch to the ``{mode}_reduction`` method, e.g. ``valid_reduction``.

        Args:
            mode (str): prefix of the reduction method to call
            *args, **kwargs: forwarded unchanged to the reduction method

        Returns:
            whatever the selected reduction method returns
        """
        return getattr(self, f"{mode}_reduction")(*args, **kwargs)

    @abc.abstractmethod
    def train_step(self):
        raise NotImplementedError

    @abc.abstractmethod
    def valid_step(self):
        raise NotImplementedError

    @abc.abstractmethod
    def test_step(self):
        raise NotImplementedError

    @abc.abstractmethod
    def train_reduction(self):
        raise NotImplementedError

    @abc.abstractmethod
    def valid_reduction(self):
        raise NotImplementedError

    @abc.abstractmethod
    def test_reduction(self):
        raise NotImplementedError