from typing import List, Optional

import torch

from s3prl.dataio.encoder.category import CategoryEncoder, CategoryEncoders
from s3prl.dataio.encoder.tokenizer import Tokenizer

from . import Dataset
# Dataset wrappers exported as the public API of this module.
__all__ = [
    "EncodeCategory",
    "EncodeCategories",
    "EncodeMultiLabel",
    "EncodeText",
]
class EncodeCategory(Dataset):
    """Pair each sample's category label with the id produced by an encoder.

    Args:
        labels: One category label per sample.
        encoder: Encoder whose ``encode`` maps a label to its class id.
    """

    def __init__(self, labels: List[str], encoder: CategoryEncoder) -> None:
        super().__init__()
        self.labels = labels
        self.encoder = encoder

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index: int):
        """Return ``{"label": <raw label>, "class_id": <encoded id>}``."""
        raw_label = self.labels[index]
        encoded_id = self.encoder.encode(raw_label)
        return dict(label=raw_label, class_id=encoded_id)
class EncodeCategories(Dataset):
    """Encode each sample's list of category labels into a LongTensor of ids.

    Args:
        labels: For each sample, a list of category labels (presumably one
            entry per encoder in ``encoders`` -- confirm against
            ``CategoryEncoders.encode``).
        encoders: Multi-encoder whose ``encode`` maps a label list to ids.
    """

    def __init__(self, labels: List[List[str]], encoders: CategoryEncoders) -> None:
        super().__init__()
        self.labels = labels
        self.encoders = encoders

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index: int):
        """Return ``{"labels": <raw labels>, "class_ids": LongTensor(ids)}``."""
        sample_labels = self.labels[index]
        encoded_ids = self.encoders.encode(sample_labels)
        return dict(labels=sample_labels, class_ids=torch.LongTensor(encoded_ids))
class EncodeMultiLabel(Dataset):
    """Turn each sample's list of category labels into a multi-hot vector.

    Args:
        labels: For each sample, the list of category labels it carries
            (the list may be empty).
        encoder: Encoder whose ``encode`` maps one label to an integer id;
            ``len(encoder)`` is used as the length of the multi-hot vector.
    """

    def __init__(self, labels: List[List[str]], encoder: CategoryEncoder) -> None:
        super().__init__()
        self.labels = labels
        self.encoder = encoder

    def __len__(self):
        return len(self.labels)

    @staticmethod
    def label_to_binary_vector(label_ids: List[int], num_labels: int) -> torch.Tensor:
        """Build a multi-hot ``torch.float`` vector of length ``num_labels``.

        Positions listed in ``label_ids`` are set to 1.0 and all others are
        0.0. Duplicate ids are allowed.

        Args:
            label_ids: Integer positions to switch on.
            num_labels: Length of the returned vector.

        Returns:
            A 1-D tensor of shape ``(num_labels,)`` with dtype ``torch.float``.
        """
        # Allocate with an explicit dtype for BOTH the empty and non-empty
        # cases. Previously only the empty case pinned torch.float, so the
        # dtype depended on torch.get_default_dtype() and could differ between
        # samples in the same batch.
        binary_labels = torch.zeros((num_labels,), dtype=torch.float)
        if len(label_ids) == 0:
            return binary_labels

        # scatter is kept from the original implementation (rather than
        # advanced indexing, which would wrap negative ids silently).
        binary_labels = binary_labels.scatter(0, torch.tensor(label_ids), 1.0)
        # Internal sanity check: exactly the requested positions are set.
        assert set(torch.where(binary_labels == 1.0)[0].tolist()) == set(label_ids)
        return binary_labels

    def __getitem__(self, index: int):
        """Return the raw labels and their multi-hot encoding.

        Returns:
            ``{"labels": <raw label list>, "binary_labels": <float vector>}``.
        """
        labels = self.labels[index]
        label_ids = [self.encoder.encode(label) for label in labels]
        binary_labels = self.label_to_binary_vector(label_ids, len(self.encoder))
        return {
            "labels": labels,
            "binary_labels": binary_labels,
        }
class EncodeText(Dataset):
    """Tokenize each sample's text into a LongTensor of token ids.

    Args:
        text: One text string per sample.
        tokenizer: Tokenizer providing ``encode`` and ``decode``.
        iob: Optional per-sample strings aligned one-to-one with ``text``
            (presumably IOB slot tags; the format is whatever
            ``tokenizer.encode(text, iob)`` expects). When given, each
            sample's ``labels`` is the tokenizer's decoding of the ids
            produced from ``(text, iob)`` rather than the raw text.

    Raises:
        ValueError: If ``iob`` is given and its length differs from ``text``.
    """

    def __init__(
        self, text: List[str], tokenizer: Tokenizer, iob: Optional[List[str]] = None
    ) -> None:
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if iob is not None and len(text) != len(iob):
            raise ValueError(
                "text and iob must have the same length, "
                f"got {len(text)} and {len(iob)}"
            )
        self.text = text
        self.iob = iob
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.text)

    def __getitem__(self, index: int):
        """Return the label text and its token ids.

        Returns:
            ``{"labels": <text>, "class_ids": LongTensor(token ids)}``.
        """
        text = self.text[index]
        if self.iob is not None:
            tokenized_ids = self.tokenizer.encode(text, self.iob[index])
            # Re-decode so the returned label text matches what the tokenizer
            # produced from (text, iob), not the raw input string.
            text = self.tokenizer.decode(tokenized_ids)
        else:
            tokenized_ids = self.tokenizer.encode(text)
        return {
            "labels": text,
            "class_ids": torch.LongTensor(tokenized_ids),
        }