Source code for s3prl.dataio.dataset.encode

from typing import List, Optional

import torch

from s3prl.dataio.encoder.category import CategoryEncoder, CategoryEncoders
from s3prl.dataio.encoder.tokenizer import Tokenizer

from . import Dataset

__all__ = [
    "EncodeCategory",
    "EncodeCategories",
    "EncodeMultiLabel",
    "EncodeText",
]


class EncodeCategory(Dataset):
    """Dataset that pairs each label with its encoded class id.

    Args:
        labels: one label string per item.
        encoder: object whose ``encode`` maps a label to its class id.

    Each item is a dict with the original ``"label"`` and the value returned
    by ``encoder.encode`` under ``"class_id"``.
    """

    def __init__(self, labels: List[str], encoder: CategoryEncoder) -> None:
        super().__init__()
        self.labels = labels
        self.encoder = encoder

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index: int):
        raw_label = self.labels[index]
        encoded = self.encoder.encode(raw_label)
        return dict(label=raw_label, class_id=encoded)
class EncodeCategories(Dataset):
    """Dataset that encodes a list of labels per item with a set of encoders.

    Args:
        labels: for each item, a list of label strings.
        encoders: object whose ``encode`` maps one item's label list to a
            sequence of class ids.

    Each item is a dict with the original ``"labels"`` list and the encoded ids
    as a ``torch.LongTensor`` under ``"class_ids"``.
    """

    def __init__(self, labels: List[List[str]], encoders: CategoryEncoders) -> None:
        super().__init__()
        self.labels = labels
        self.encoders = encoders

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index: int):
        item_labels = self.labels[index]
        ids = self.encoders.encode(item_labels)
        return dict(labels=item_labels, class_ids=torch.LongTensor(ids))
class EncodeMultiLabel(Dataset):
    """Dataset that turns each item's label list into a multi-hot float vector.

    Args:
        labels: for each item, the list of label strings that apply to it.
        encoder: object whose ``encode`` maps one label to its id; ``len(encoder)``
            is used as the length of the output vector.

    Each item is a dict with the original ``"labels"`` list and the multi-hot
    vector under ``"binary_labels"``.
    """

    def __init__(self, labels: List[List[str]], encoder: CategoryEncoder) -> None:
        super().__init__()
        self.labels = labels
        self.encoder = encoder

    def __len__(self):
        return len(self.labels)

    @staticmethod
    def label_to_binary_vector(label_ids: List[int], num_labels: int) -> torch.Tensor:
        """Return a ``(num_labels,)`` ``torch.float`` vector with 1.0 at each id.

        Args:
            label_ids: positions to set to 1.0; may be empty, duplicates are harmless.
            num_labels: length of the returned vector.

        Returns:
            A ``torch.float`` tensor of shape ``(num_labels,)``. The dtype is the
            same for empty and non-empty ``label_ids``.

        Out-of-range ids (negative or ``>= num_labels``) make ``Tensor.scatter``
        raise (a RuntimeError on CPU, as in the previous implementation).
        """
        # A long-typed index makes the empty list a valid no-op scatter index
        # (``torch.tensor([])`` alone would be float and rejected), so no special
        # case is needed. Pinning ``dtype=torch.float`` keeps the output dtype
        # independent of ``torch.get_default_dtype()``.
        index = torch.tensor(label_ids, dtype=torch.long)
        return torch.zeros((num_labels,), dtype=torch.float).scatter(0, index, 1.0)

    def __getitem__(self, index: int):
        labels = self.labels[index]
        label_ids = [self.encoder.encode(label) for label in labels]
        binary_labels = self.label_to_binary_vector(label_ids, len(self.encoder))
        return {
            "labels": labels,
            "binary_labels": binary_labels,
        }
class EncodeText(Dataset):
    """Dataset that tokenizes each text into a ``torch.LongTensor`` of ids.

    Args:
        text: one string per item.
        tokenizer: object providing ``encode`` (and ``decode`` when ``iob`` is used).
        iob: optional per-item strings passed as a second argument to
            ``tokenizer.encode``; must have the same length as ``text``.

    Raises:
        ValueError: if ``iob`` is given and ``len(iob) != len(text)``.

    Each item is a dict with the text under ``"labels"`` and the token ids under
    ``"class_ids"``. When ``iob`` is given, ``"labels"`` holds
    ``tokenizer.decode`` of the ids rather than the raw input string.
    """

    def __init__(
        self, text: List[str], tokenizer: Tokenizer, iob: Optional[List[str]] = None
    ) -> None:
        super().__init__()
        # Explicit check instead of ``assert`` so the validation still runs
        # under ``python -O``.
        if iob is not None and len(text) != len(iob):
            raise ValueError(
                f"text and iob must have the same length, got {len(text)} and {len(iob)}"
            )
        self.text = text
        self.iob = iob
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.text)

    def __getitem__(self, index: int):
        text = self.text[index]
        if self.iob is not None:
            tokenized_ids = self.tokenizer.encode(text, self.iob[index])
            # Replace the raw text with the decoded round-trip of the ids.
            text = self.tokenizer.decode(tokenized_ids)
        else:
            tokenized_ids = self.tokenizer.encode(text)
        return {
            "labels": text,
            "class_ids": torch.LongTensor(tokenized_ids),
        }