Source code for s3prl.dataio.encoder.vocabulary

"""
Create vocabulary (train tokenizer)

Authors:
  * Heng-Jui Chang 2022
"""

import logging
import os
import tempfile
from collections import Counter
from typing import List, Union

logger = logging.getLogger(__name__)

__all__ = ["generate_basic_vocab", "generate_subword_vocab", "generate_vocab"]


def generate_basic_vocab(
    mode: str,
    text_list: List[str],
    vocab_size: int = -1,
    coverage: float = 1.0,
    sort_vocab: bool = True,
) -> List[str]:
    """Generates basic vocabularies, including character- and word-based vocabularies.

    Args:
        mode (str): Vocabulary type (character or word).
        text_list (List[str]): List of text data.
        vocab_size (int, optional): Vocabulary size. If not specified, vocab_size
            would be `coverage * actual vocab size`. Defaults to -1.
        coverage (float, optional): Vocabulary coverage. Defaults to 1.0.
        sort_vocab (bool, optional): Sort vocabularies alphabetically. Defaults to True.

    Returns:
        List[str]: A list of vocabularies.
    """

    assert mode in {"character", "word"}, mode
    assert vocab_size == -1 or vocab_size > 0, vocab_size
    assert 0.0 < coverage <= 1.0, coverage

    logger.info(
        f"Generating vocab (type = {mode}, coverage = {coverage}) "
        f"from {len(text_list)} sentences."
    )

    # Count token frequencies (characters or whitespace-separated words).
    counter = Counter()
    for text in text_list:
        if mode == "character":
            counter.update(text)
        if mode == "word":
            counter.update(text.split())

    if vocab_size < 0:
        vocab_size = int(len(counter) * coverage)
    else:
        vocab_size = min(vocab_size, len(counter))

    if vocab_size < len(counter):
        # Keep only the most frequent tokens.
        vocab_list = sorted(counter.keys(), key=lambda k: counter[k], reverse=True)
        vocab_list = vocab_list[:vocab_size]
    else:
        vocab_list = list(counter.keys())

    if sort_vocab:
        vocab_list = sorted(vocab_list)

    logger.info(f"Generated {vocab_size} {mode} vocabularies.")

    return vocab_list
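
# Usage sketch (not part of the original module): builds a character vocabulary
# from a tiny hypothetical corpus, wrapped in a helper so it does not run at
# import time.
def _demo_generate_basic_vocab():
    vocab = generate_basic_vocab(
        mode="character",
        text_list=["hello world", "hello there"],
    )
    # With the default coverage=1.0 every observed character is kept, so
    # vocab == [' ', 'd', 'e', 'h', 'l', 'o', 'r', 't', 'w'] (sorted).
    return vocab
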
def generate_subword_vocab(
    text_list: List[str] = None,
    text_file: str = None,
    output_file: str = None,
    vocab_size: int = 1000,
    character_coverage: float = 1.0,
) -> str:
    """Generates subword vocabularies based on `sentencepiece`.

    Args:
        text_list (List[str], optional): List of text data. Defaults to None.
        text_file (str, optional): Path to text data. Defaults to None.
        output_file (str, optional): Path prefix for the trained subword model.
            Must be specified. Defaults to None.
        vocab_size (int, optional): Vocabulary size. Defaults to 1000.
        character_coverage (float, optional): Coverage of characters in text data.
            Defaults to 1.0.

    Raises:
        ImportError: If `sentencepiece` is not installed.

    Returns:
        str: Path to `${output_file}.model`.
    """

    try:
        import sentencepiece as splib
    except ImportError:
        raise ImportError(
            "`sentencepiece` cannot be imported, please run `pip install sentencepiece` first"
        )

    assert output_file is not None
    output_file = str(output_file)
    assert vocab_size > 0, vocab_size

    # Training options for the sentencepiece unigram model.
    cmd = (
        "--input={} --model_prefix={} --model_type=unigram "
        "--vocab_size={} --character_coverage={} "
        "--pad_id=0 --eos_id=1 --unk_id=2 --bos_id=-1 "
        "--eos_piece=<eos> --remove_extra_whitespaces=true "
    )

    if text_list is not None:
        assert isinstance(text_list, list)
        assert isinstance(text_list[0], str)

        logger.info(
            f"Generating vocab (type = subword, coverage = {character_coverage}) "
            f"from {len(text_list)} sentences."
        )

        # Write the text data to a temporary file for sentencepiece to read.
        with tempfile.TemporaryDirectory() as directory:
            input_file = os.path.join(directory, "text.txt")
            with open(input_file, "w") as fp:
                for text in text_list:
                    fp.write(text + "\n")

            cmd = cmd.format(
                input_file,
                output_file,
                vocab_size,
                character_coverage,
            )
            splib.SentencePieceTrainer.Train(cmd)

    if text_file is not None:
        logger.info(
            f"Generating vocab (type = subword, coverage = {character_coverage}) "
            f"from {text_file}"
        )

        cmd = cmd.format(
            text_file,
            output_file,
            vocab_size,
            character_coverage,
        )
        splib.SentencePieceTrainer.Train(cmd)

    return output_file + ".model"
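
# Usage sketch (not part of the original module): trains a small sentencepiece
# subword model. The corpus and output prefix below are hypothetical, and
# `sentencepiece` must be installed for this to run.
def _demo_generate_subword_vocab():
    sentences = ["hello world", "hello there", "speech is all you need"] * 100
    model_path = generate_subword_vocab(
        text_list=sentences,
        output_file="/tmp/demo_subword",  # hypothetical prefix; writes .model and .vocab
        vocab_size=30,  # tiny corpora may need a small vocab_size to train at all
    )
    # model_path == "/tmp/demo_subword.model"; the model can then be loaded with
    # sentencepiece.SentencePieceProcessor for encoding/decoding.
    return model_path
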
def generate_vocab(
    mode: str,
    text_list: List[str] = None,
    text_file: str = None,
    read_lines: int = 10000000,
    **vocab_args,
) -> Union[List[str], str]:
    """Generates vocabularies given text data.

    Args:
        mode (str): Vocabulary type (character, word, phoneme, or subword).
        text_list (List[str], optional): List of text data. Defaults to None.
        text_file (str, optional): Path to text data. Defaults to None.
        read_lines (int, optional): Maximum lines to read from `text_file`.
            Defaults to 10000000.
        vocab_args:
            if :code:`mode != subword`, arguments for :obj:`generate_basic_vocab`
            if :code:`mode == subword`, arguments for :obj:`generate_subword_vocab`

    Returns:
        Union[List[str], str]: A list of vocabularies or a path to a trained
            subword `.model` file.
    """

    if text_list is None and mode in {"character", "word", "phoneme"}:
        assert isinstance(text_file, str)
        with open(text_file, "r", encoding="UTF-8") as fp:
            text_list = [
                line.strip("\r\n ") for i, line in enumerate(fp) if i < read_lines
            ]

    if mode == "character":
        return generate_basic_vocab("character", text_list, **vocab_args)

    if mode in {"word", "phoneme"}:
        # Phonemes are treated as whitespace-separated tokens, like words.
        return generate_basic_vocab("word", text_list, **vocab_args)

    if mode == "subword":
        return generate_subword_vocab(
            text_list=text_list, text_file=text_file, **vocab_args
        )
    else:
        raise ValueError(f"Unsupported mode (vocabulary type): {mode}")
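
# Usage sketch (not part of the original module): dispatching through
# generate_vocab, forwarding vocab_args to the underlying generators. The
# file paths below are hypothetical.
def _demo_generate_vocab():
    # Word vocabulary from a text file (one sentence per line), capped at the
    # 5000 most frequent words via generate_basic_vocab.
    word_vocab = generate_vocab(
        mode="word",
        text_file="transcripts.txt",  # hypothetical path
        vocab_size=5000,
    )
    # Subword model trained from the same file; vocab_args are forwarded to
    # generate_subword_vocab, which returns the path to the .model file.
    subword_model = generate_vocab(
        mode="subword",
        text_file="transcripts.txt",
        output_file="subword",  # hypothetical output prefix
        vocab_size=1000,
    )
    return word_vocab, subword_model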