Source code for s3prl.dataio.encoder.tokenizer

"""
Load tokenizer to encode & decode

Modified from tensorflow_datasets.features.text.*
Reference: https://www.tensorflow.org/datasets/api_docs/python/tfds/features/text_lib

Authors:
  * Heng-Jui Chang 2022
"""

import abc
import re
import tempfile
from typing import List

# Replace the 2 tokens right before English starts with <eos> & <unk>
BERT_FIRST_IDX = 997
# Drop the rest of the tokens
BERT_LAST_IDX = 29635

# Default vocabularies
CHARACTER_VOCAB = list(" 'ABCDEFGHIJKLMNOPQRSTUVWXYZ")
PHONEME_VOCAB = "SIL SPN AA0 AA1 AA2 AE0 AE1 AE2 AH0 AH1 AH2 AO0 AO1 AO2 AW0 AW1 AW2 AY0 AY1 AY2 B CH D DH EH0 EH1 EH2 ER0 ER1 ER2 EY0 EY1 EY2 F G HH IH0 IH1 IH2 IY0 IY1 IY2 JH K L M N NG OW0 OW1 OW2 OY0 OY1 OY2 P R S SH T TH UH0 UH1 UH2 UW0 UW1 UW2 V W Y Z ZH".split(
    " "
)


__all__ = [
    "CharacterTokenizer",
    "CharacterSlotTokenizer",
    "SubwordTokenizer",
    "SubwordSlotTokenizer",
    "WordTokenizer",
    "PhonemeTokenizer",
    "load_tokenizer",
    "default_phoneme_tokenizer",
]


class Tokenizer:
    def __init__(self):
        super().__init__()

    @abc.abstractmethod
    def encode(self, text: str, iob: str = None) -> List[int]:
        raise NotImplementedError

    @abc.abstractmethod
    def decode(self, idxs: List[int], ignore_repeat: bool = False) -> str:
        raise NotImplementedError

    def __len__(self):
        return self.vocab_size

    @abc.abstractproperty
    def vocab_size(self) -> int:
        raise NotImplementedError

    @abc.abstractproperty
    def token_type(self) -> str:
        raise NotImplementedError

    @abc.abstractclassmethod
    def load_from_file(cls, vocab_file: str):
        raise NotImplementedError

    @property
    def pad_idx(self) -> int:
        return 0

    @property
    def eos_idx(self) -> int:
        return 1

    @property
    def unk_idx(self) -> int:
        return 2

    def __repr__(self) -> str:
        return "<{} vocab_size={}>".format(type(self).__name__, self.vocab_size)


class CharacterTokenizer(Tokenizer):
    """Character tokenizer."""

    def __init__(self, vocab_list: List[str] = None):
        super().__init__()

        if vocab_list is None:
            vocab_list = CHARACTER_VOCAB

        for tok in ["<pad>", "<eos>", "<unk>"]:
            # Note that vocab_list must not contain <pad>, <eos> and <unk>
            assert tok not in vocab_list

        # <pad> = 0, <eos> = 1, <unk> = 2
        self._vocab_list = ["<pad>", "<eos>", "<unk>"] + vocab_list
        self._vocab2idx = {v: idx for idx, v in enumerate(self._vocab_list)}

    def encode(self, s: str) -> List[int]:
        # Always strip trailing space, \r and \n
        s = s.strip("\r\n ")
        # Manually append eos to the end
        return [self.vocab_to_idx(v) for v in s] + [self.eos_idx]

    def decode(self, idxs: List[int], ignore_repeat: bool = False) -> str:
        vocabs = []
        for t, idx in enumerate(idxs):
            v = self.idx_to_vocab(idx)
            if idx == self.pad_idx or (ignore_repeat and t > 0 and idx == idxs[t - 1]):
                continue
            elif idx == self.eos_idx:
                break
            else:
                vocabs.append(v)
        return "".join(vocabs)

    @classmethod
    def load_from_file(cls, vocab_file: str = None, vocab_list: List[str] = None):
        if vocab_file is not None:
            with open(vocab_file, "r") as f:
                # Do not strip space because a character-based text encoder
                # should have a space token
                vocab_list = [line.strip("\r\n") for line in f]
        elif vocab_list is not None:
            pass
        else:
            raise ValueError(
                "No vocabulary information given, please specify either vocab_file or vocab_list."
            )
        return cls(vocab_list)

    @property
    def vocab_size(self) -> int:
        return len(self._vocab_list)

    @property
    def token_type(self) -> str:
        return "character"

    def vocab_to_idx(self, vocab):
        return self._vocab2idx.get(vocab, self.unk_idx)

    def idx_to_vocab(self, idx):
        return self._vocab_list[idx]
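

# Illustrative usage sketch (not part of the original module): builds a
# CharacterTokenizer from the default CHARACTER_VOCAB and round-trips a
# sample sentence. The sentence is an arbitrary made-up example.
def _example_character_tokenizer():
    tokenizer = CharacterTokenizer()
    # Characters map to ids starting at 3 (<pad>=0, <eos>=1, <unk>=2);
    # encode() appends the <eos> id automatically.
    ids = tokenizer.encode("HELLO WORLD")
    assert tokenizer.decode(ids) == "HELLO WORLD"
    return ids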


class CharacterSlotTokenizer(Tokenizer):
    """Character tokenizer with slots."""

    def __init__(self, vocab_list: List[str], slots: List[str]):
        super().__init__()

        for tok in ["<pad>", "<eos>", "<unk>"]:
            # Note that vocab_list must not contain <pad>, <eos> and <unk>
            assert tok not in vocab_list

        # <pad> = 0, <eos> = 1, <unk> = 2
        self._vocab_list = ["<pad>", "<eos>", "<unk>"] + vocab_list
        self._vocab2idx = {v: idx for idx, v in enumerate(self._vocab_list)}
        self.space_idx = self.vocab_to_idx(" ")
        self.slots = slots
        self.slot2id = {
            self.slots[i]: (i + len(self._vocab_list)) for i in range(len(self.slots))
        }
        self.id2slot = {
            (i + len(self._vocab_list)): self.slots[i] for i in range(len(self.slots))
        }

    def encode(self, sent: str, iobs: str) -> List[int]:
        # Always strip trailing space, \r and \n
        sent = sent.strip("\r\n ")
        iobs = iobs.strip("\r\n ")
        sent = re.sub(" +", " ", sent).strip(" ")
        sent = sent.split(" ")
        iobs = iobs.split(" ")
        assert len(sent) == len(
            iobs
        ), "transcription and iobs should have the same number of words (split by space)"
        if sent[0] == "BOS":
            sent = sent[1:]
            iobs = iobs[1:]
        if sent[-1] == "EOS":
            sent = sent[:-1]
            iobs = iobs[:-1]
        tokens = []
        for i, (wrd, iob) in enumerate(zip(sent, iobs)):
            if iob != "O" and (i == 0 or iobs[i - 1] != iob):
                tokens.append(self.slot2id["B-" + iob])
            tokens += [self.vocab_to_idx(v) for v in wrd]
            if iob != "O" and (i == len(sent) - 1 or iobs[i + 1] != iob):
                tokens.append(self.slot2id["E-" + iob])
            if i == (len(sent) - 1):
                tokens.append(self.eos_idx)
            else:
                if len(tokens) > 0 and tokens[-1] != self.space_idx:
                    tokens.append(self.space_idx)
        assert tokens[-1] == self.eos_idx
        return tokens

    def decode(self, idxs: List[int], ignore_repeat: bool = False) -> str:
        vocabs = []
        for t, idx in enumerate(idxs):
            v = self.idx_to_vocab(idx)
            if idx == self.pad_idx or (ignore_repeat and t > 0 and idx == idxs[t - 1]):
                continue
            elif idx == self.eos_idx:
                break
            else:
                vocabs.append(v)
        return "".join(vocabs)

    @classmethod
    def load_from_file(cls, vocab_file: str, slots_file: str):
        with open(vocab_file, "r") as f:
            # Do not strip space because a character-based text encoder
            # should have a space token
            vocab_list = [line.strip("\r\n") for line in f]
        org_slots = open(slots_file).read().split("\n")
        slots = []
        for slot in [slot for slot in org_slots if slot != "O"]:
            slots.append("B-" + slot)
            slots.append("E-" + slot)
        return cls(vocab_list, slots)

    @property
    def vocab_size(self) -> int:
        return len(self._vocab_list) + len(self.slots)

    @property
    def token_type(self) -> str:
        return "character-slot"

    def vocab_to_idx(self, vocab):
        return self._vocab2idx.get(vocab, self.unk_idx)

    def idx_to_vocab(self, idx):
        idx = int(idx)
        if idx < len(self._vocab_list):
            return self._vocab_list[idx]
        else:
            token = self.id2slot[idx]
            if token[0] == "B":
                return token + " "
            elif token[0] == "E":
                return " " + token
            else:
                raise ValueError("id2slot get:", token)
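

# Illustrative sketch (not part of the original module): encodes a sentence
# together with word-level IOB-style slot labels. The vocabulary, slot name
# ("CITY"), and sentence below are made-up examples; real vocab/slot files
# are normally loaded through load_from_file().
def _example_character_slot_tokenizer():
    vocab = list(" 'ABCDEFGHIJKLMNOPQRSTUVWXYZ")
    slots = ["B-CITY", "E-CITY"]
    tokenizer = CharacterSlotTokenizer(vocab, slots)
    # Each slot span is bracketed by its B-/E- ids; "O" words stay as characters.
    ids = tokenizer.encode("FLY TO BOSTON", "O O CITY")
    return ids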


class SubwordTokenizer(Tokenizer):
    """Subword tokenizer using sentencepiece."""

    def __init__(self, spm):
        super().__init__()
        if spm.pad_id() != 0 or spm.eos_id() != 1 or spm.unk_id() != 2:
            raise ValueError(
                "Please train the sentencepiece model with the following arguments:\n"
                "--pad_id=0 --eos_id=1 --unk_id=2 --bos_id=-1 --model_type=bpe --eos_piece=<eos>"
            )
        self.spm = spm

    def encode(self, s: str) -> List[int]:
        tokens = self.spm.encode_as_ids(s)
        assert tokens[-1] == self.eos_idx
        return tokens

    def decode(self, idxs: List[int], ignore_repeat: bool = False) -> str:
        crop_idx = []
        for t, idx in enumerate(idxs):
            if idx == self.eos_idx:
                break
            elif idx == self.pad_idx or (
                ignore_repeat and t > 0 and idx == idxs[t - 1]
            ):
                continue
            else:
                crop_idx.append(idx)
        return self.spm.decode_ids(crop_idx)

    @classmethod
    def load_from_file(cls, filepath: str):
        import sentencepiece as splib

        spm = splib.SentencePieceProcessor()
        spm.load(filepath)
        spm.set_encode_extra_options("eos")
        return cls(spm)

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.spm.set_encode_extra_options("eos")

    @property
    def vocab_size(self) -> int:
        return len(self.spm)

    @property
    def token_type(self) -> str:
        return "subword"
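

# Illustrative sketch (not part of the original module): trains a small
# sentencepiece BPE model with the id layout this class requires
# (<pad>=0, <eos>=1, <unk>=2, no <bos>) and loads it. The corpus path, model
# prefix, vocab size, and sample sentence are made up for the example.
def _example_subword_tokenizer(corpus_path="corpus.txt", model_prefix="subword"):
    import sentencepiece as splib

    splib.SentencePieceTrainer.Train(
        f"--input={corpus_path} --model_prefix={model_prefix} --vocab_size=500 "
        "--pad_id=0 --eos_id=1 --unk_id=2 --bos_id=-1 "
        "--model_type=bpe --eos_piece=<eos>"
    )
    tokenizer = SubwordTokenizer.load_from_file(f"{model_prefix}.model")
    ids = tokenizer.encode("hello world")  # ends with the <eos> id (1)
    return tokenizer.decode(ids)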


class SubwordSlotTokenizer(Tokenizer):
    """Subword tokenizer with slots."""

    def __init__(self, spm, slots):
        super().__init__()
        if spm.pad_id() != 0 or spm.eos_id() != 1 or spm.unk_id() != 2:
            raise ValueError(
                "Please train the sentencepiece model with the following arguments:\n"
                "--pad_id=0 --eos_id=1 --unk_id=2 --bos_id=-1 --model_type=bpe --eos_piece=<eos>"
            )
        self.spm = spm
        self.slots = slots
        self.slot2id = {
            self.slots[i]: (i + len(self.spm)) for i in range(len(self.slots))
        }
        self.id2slot = {
            (i + len(self.spm)): self.slots[i] for i in range(len(self.slots))
        }

    def encode(self, sent: str, iobs: str) -> List[int]:
        # Always strip trailing space, \r and \n
        sent = sent.strip("\r\n ")
        iobs = iobs.strip("\r\n ")
        sent = re.sub(" +", " ", sent).strip(" ")
        sent = sent.split(" ")
        iobs = iobs.split(" ")
        assert len(sent) == len(
            iobs
        ), "transcription and iobs should have the same number of words (split by space)"
        if sent[0] == "BOS":
            sent = sent[1:]
            iobs = iobs[1:]
        if sent[-1] == "EOS":
            sent = sent[:-1]
            iobs = iobs[:-1]
        tokens = []
        for i, (wrd, iob) in enumerate(zip(sent, iobs)):
            if iob != "O" and (i == 0 or iobs[i - 1] != iob):
                tokens.append(self.slot2id["B-" + iob])
            encoded = self.spm.encode_as_ids(wrd)
            assert encoded[-1] == self.eos_idx
            tokens += encoded[:-1]  # drop eos
            if iob != "O" and (i == len(sent) - 1 or iobs[i + 1] != iob):
                tokens.append(self.slot2id["E-" + iob])
        assert tokens[-1] != self.eos_idx
        tokens.append(self.eos_idx)
        return tokens

    def decode(self, idxs: List[int], ignore_repeat: bool = False) -> str:
        crop_idx = []
        for t, idx in enumerate(idxs):
            if idx == self.eos_idx:
                break
            elif idx == self.pad_idx or (
                ignore_repeat and t > 0 and idx == idxs[t - 1]
            ):
                continue
            else:
                crop_idx.append(idx)

        sent, ret = [], []
        for i, x in enumerate(crop_idx):
            if x >= len(self.spm):
                # x is a slot token
                slot = self.id2slot[x]
                ret.append(slot)
                if len(sent) > 0:
                    decoded = self.spm.decode_ids(sent)
                    ret.insert(-1, decoded)
                    sent = []
            else:
                # x is a regular token interpretable by spm
                sent.append(x)
        return " ".join(ret)

    @classmethod
    def load_from_file(cls, filepath: str, slots_file: str):
        import sentencepiece as splib

        spm = splib.SentencePieceProcessor()
        spm.load(filepath)
        spm.set_encode_extra_options(":eos")
        org_slots = open(slots_file).read().split("\n")
        slots = []
        for slot in [slot for slot in org_slots if slot != "O"]:
            slots.append("B-" + slot)
            slots.append("E-" + slot)
        return cls(spm, slots)

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.spm.set_encode_extra_options("eos")

    @property
    def vocab_size(self) -> int:
        return len(self.spm) + len(self.slots)

    @property
    def token_type(self) -> str:
        return "subword-slot"


class WordTokenizer(CharacterTokenizer):
    """Word tokenizer."""

    def encode(self, s: str) -> List[int]:
        # Always strip trailing space, \r and \n
        s = s.strip("\r\n ")
        # Space is the delimiter between words
        words = s.split(" ")
        # Manually append eos to the end
        return [self.vocab_to_idx(v) for v in words] + [self.eos_idx]

    def decode(self, idxs: List[int], ignore_repeat: bool = False) -> str:
        vocabs = []
        for t, idx in enumerate(idxs):
            v = self.idx_to_vocab(idx)
            if idx == self.eos_idx:
                break
            elif idx == self.pad_idx or (
                ignore_repeat and t > 0 and idx == idxs[t - 1]
            ):
                continue
            else:
                vocabs.append(v)
        return " ".join(vocabs)

    @property
    def token_type(self) -> str:
        return "word"


class PhonemeTokenizer(WordTokenizer):
    """Phoneme tokenizer."""

    @property
    def token_type(self) -> str:
        return "phoneme"
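

# Illustrative sketch (not part of the original module): a WordTokenizer built
# directly from an in-memory vocabulary list; out-of-vocabulary words map to
# <unk>. The word list and sentence are made-up examples. PhonemeTokenizer
# behaves the same way over space-separated phoneme symbols.
def _example_word_tokenizer():
    tokenizer = WordTokenizer.load_from_file(vocab_list=["HELLO", "WORLD"])
    ids = tokenizer.encode("HELLO THERE WORLD")
    # "THERE" is not in the vocabulary, so it decodes back as "<unk>"
    assert tokenizer.decode(ids) == "HELLO <unk> WORLD"
    return ids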


class BertTokenizer(Tokenizer):
    """Bert Tokenizer.

    https://github.com/huggingface/pytorch-transformers/blob/master/pytorch_transformers/tokenization_bert.py
    """

    def __init__(self, tokenizer):
        super().__init__()
        self._tokenizer = tokenizer
        self._tokenizer.pad_token = "<pad>"
        self._tokenizer.eos_token = "<eos>"
        self._tokenizer.unk_token = "<unk>"

    def encode(self, s: str) -> List[int]:
        # Reduce vocab size manually
        reduced_idx = []
        for idx in self._tokenizer.encode(s):
            try:
                r_idx = idx - BERT_FIRST_IDX
                assert r_idx > 0
                reduced_idx.append(r_idx)
            except AssertionError:
                reduced_idx.append(self.unk_idx)
        reduced_idx.append(self.eos_idx)
        return reduced_idx

    def decode(self, idxs: List[int], ignore_repeat: bool = False) -> str:
        crop_idx = []
        for t, idx in enumerate(idxs):
            if idx == self.eos_idx:
                break
            elif idx == self.pad_idx or (
                ignore_repeat and t > 0 and idx == idxs[t - 1]
            ):
                continue
            else:
                # Shift to correct idx for bert tokenizer
                crop_idx.append(idx + BERT_FIRST_IDX)
        return self._tokenizer.decode(crop_idx)

    @classmethod
    def load_from_file(cls, vocab_file: str):
        from pytorch_transformers import BertTokenizer as bert_tokenizer

        return cls(bert_tokenizer.from_pretrained(vocab_file))

    @property
    def vocab_size(self) -> int:
        return BERT_LAST_IDX - BERT_FIRST_IDX + 1

    @property
    def token_type(self) -> str:
        return "bert"


def load_tokenizer(
    mode: str,
    vocab_file: str = None,
    vocab_list: List[str] = None,
    slots_file: str = None,
) -> Tokenizer:
    """Load a text tokenizer.

    Args:
        mode (str): Mode ("character", "character-slot", "subword", "subword-slot", "word", "bert-...")
        vocab_file (str, optional): Path to vocabularies. Defaults to None.
        vocab_list (List[str], optional): List of vocabularies. Defaults to None.
        slots_file (str, optional): Path to slots. Defaults to None.

    Raises:
        NotImplementedError: If mode is not implemented.

    Returns:
        Tokenizer: Text tokenizer.
    """
    assert (
        int(vocab_file is not None) + int(vocab_list is not None) <= 1
    ), "At most one of 'vocab_file' and 'vocab_list' can be provided"

    with tempfile.NamedTemporaryFile("w") as f:
        if vocab_list is not None:
            f.writelines([f"{vocab}\n" for vocab in vocab_list])
            f.flush()
            vocab_file = f.name

        if slots_file is not None and not mode.endswith("slot"):
            mode = f"{mode}-slot"

        if mode == "character":
            return CharacterTokenizer.load_from_file(vocab_file)
        elif mode == "character-slot":
            return CharacterSlotTokenizer.load_from_file(vocab_file, slots_file)
        elif mode == "subword":
            return SubwordTokenizer.load_from_file(vocab_file)
        elif mode == "subword-slot":
            return SubwordSlotTokenizer.load_from_file(vocab_file, slots_file)
        elif mode == "word":
            return WordTokenizer.load_from_file(vocab_file)
        elif mode == "phoneme":
            return PhonemeTokenizer.load_from_file(vocab_file)
        elif mode.startswith("bert-"):
            return BertTokenizer.load_from_file(mode)
        else:
            raise NotImplementedError("`{}` is not yet supported.".format(mode))
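

# Illustrative sketch (not part of the original module): load_tokenizer() is
# the usual entry point and dispatches on `mode`. The file path below is a
# placeholder; "bert-..." modes are forwarded to BertTokenizer.load_from_file.
def _example_load_tokenizer():
    # From an in-memory vocabulary list
    char_tokenizer = load_tokenizer("character", vocab_list=list(" 'ABC"))
    # From a vocabulary file on disk (placeholder path)
    word_tokenizer = load_tokenizer("word", vocab_file="word_vocab.txt")
    return char_tokenizer, word_tokenizer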


def default_phoneme_tokenizer() -> PhonemeTokenizer:
    """Returns a default LibriSpeech phoneme tokenizer.

    Returns:
        PhonemeTokenizer: Vocab includes 71 phonemes
    """

    return PhonemeTokenizer.load_from_file(vocab_list=PHONEME_VOCAB)
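

# Illustrative sketch (not part of the original module): the default phoneme
# tokenizer maps space-separated ARPAbet symbols to ids; the phoneme string
# below is a made-up example.
def _example_default_phoneme_tokenizer():
    tokenizer = default_phoneme_tokenizer()  # 71 phonemes + <pad>/<eos>/<unk>
    ids = tokenizer.encode("HH AH0 L OW1")
    assert tokenizer.decode(ids) == "HH AH0 L OW1"
    return ids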