Source code for joeynmt.tokenizers

# coding: utf-8
"""
Tokenizer module
"""
import argparse
import shutil
from pathlib import Path
from typing import Callable, Dict, List, Union

import numpy as np
import sentencepiece as sp
from sacrebleu.metrics.bleu import _get_tokenizer
from subword_nmt import apply_bpe

from joeynmt.config import ConfigurationError
from joeynmt.data_augmentation import CMVN, SpecAugment
from joeynmt.helpers import remove_extra_spaces, remove_punctuation, unicode_normalize
from joeynmt.helpers_for_audio import get_features
from joeynmt.helpers_for_ddp import get_logger

# Module-level logger, created via the project's DDP helper (joeynmt.helpers_for_ddp).
logger = get_logger(__name__)


class BasicTokenizer:
    """
    Word- or character-level tokenizer with optional pre-processing.

    - level "word": the input is split on the single ascii space character.
    - level "char": the input is split into characters; spaces are replaced
      by ``SPACE_ESCAPE`` first so they survive as visible tokens.

    Subclasses reuse the pre-processing / length-filtering / special-token
    handling defined here and override ``__call__`` and ``post_process``.

    :param level: tokenization level
    :param lowercase: whether ``pre_process`` lowercases the text
    :param normalize: whether to apply unicode normalization and
        extra-space removal
    :param max_length: maximum number of tokens; values <= 0 disable the check
    :param min_length: minimum number of tokens; values <= 0 disable the check
    :param kwargs: optional ``pretokenizer`` ("none" or "moses") and ``lang``
        (language passed to the moses tools, default "en")
    """
    # pylint: disable=too-many-instance-attributes
    SPACE = chr(32)  # ' ': half-width white space (ascii)
    SPACE_ESCAPE = chr(9601)  # '▁': sentencepiece default

    def __init__(
        self,
        level: str = "word",
        lowercase: bool = False,
        normalize: bool = False,
        max_length: int = -1,
        min_length: int = -1,
        **kwargs,
    ):
        # pylint: disable=unused-argument
        self.level = level
        self.lowercase = lowercase
        self.normalize = normalize

        # filter by length
        self.max_length = max_length
        self.min_length = min_length

        # pretokenizer (name is compared case-insensitively)
        self.pretokenizer = kwargs.get("pretokenizer", "none").lower()
        assert self.pretokenizer in ["none", "moses"], \
            "Currently, we support moses tokenizer only."

        # sacremoses
        if self.pretokenizer == "moses":
            try:
                from sacremoses import (  # pylint: disable=import-outside-toplevel
                    MosesDetokenizer,
                    MosesPunctNormalizer,
                    MosesTokenizer,
                )

                # sacremoses package has to be installed.
                # https://github.com/alvations/sacremoses
            except ImportError as e:
                logger.error(e)
                raise ImportError from e

            self.lang = kwargs.get("lang", "en")
            self.moses_tokenizer = MosesTokenizer(lang=self.lang)
            self.moses_detokenizer = MosesDetokenizer(lang=self.lang)
            if self.normalize:
                self.moses_normalizer = MosesPunctNormalizer()

    def pre_process(self, raw_input: str, allow_empty: bool = False) -> str:
        """
        Pre-process text
            - ex.) Lowercase, Normalize, Remove emojis,
                Pre-tokenize(add extra white space before punc) etc.
            - applied for all inputs both in training and inference.

        :param raw_input: raw input string
        :param allow_empty: whether to allow empty string
        :return: preprocessed input string
        :raises AssertionError: if ``allow_empty`` is False and the input is
            not a non-blank string, or the processed result is empty
        """
        if not allow_empty:
            assert isinstance(raw_input, str) and raw_input.strip() != "", \
                "The input sentence is empty! Please make sure " \
                "that you are feeding a valid input."

        if self.normalize:
            raw_input = remove_extra_spaces(unicode_normalize(raw_input))

        if self.pretokenizer == "moses":
            if self.normalize:
                raw_input = self.moses_normalizer.normalize(raw_input)
            raw_input = self.moses_tokenizer.tokenize(raw_input, return_str=True)

        if self.lowercase:
            raw_input = raw_input.lower()

        if not allow_empty:
            # ensure the string is not empty.
            assert raw_input is not None and len(raw_input) > 0, raw_input
        return raw_input

    def __call__(self, raw_input: str, is_train: bool = False) -> List[str]:
        """
        Tokenize single sentence

        :param raw_input: (pre-processed) input string; ``None`` is passed
            through as ``None``
        :param is_train: if True, sequences outside the valid length range
            are dropped (``None`` is returned)
        :return: list of tokens, or None
        :raises ValueError: if ``self.level`` is neither "word" nor "char"
        """
        if raw_input is None:
            return None

        if self.level == "word":
            sequence = raw_input.split(self.SPACE)
        elif self.level == "char":
            sequence = list(raw_input.replace(self.SPACE, self.SPACE_ESCAPE))
        else:
            # Without this branch `sequence` would be unbound and the call
            # would die with an opaque UnboundLocalError.
            raise ValueError(
                f"{self.__class__.__name__} cannot tokenize at level "
                f"`{self.level}`; expected 'word' or 'char'."
            )

        if is_train and self._filter_by_length(len(sequence)):
            return None
        return sequence

    def _filter_by_length(self, length: int) -> bool:
        """
        Check if the given seq length is out of the valid range.

        :param length: (int) number of tokens
        :return: True if the length is invalid(= to be filtered out), False if valid.
        """
        return length > self.max_length > 0 or self.min_length > length > 0

    def _remove_special(self, sequence: List[str], generate_unk: bool = False):
        """
        Drop special tokens (and language tags) from a token list.

        :param sequence: list of tokens
        :param generate_unk: if True, unk tokens are kept; otherwise they are
            removed together with the other specials
        :return: filtered token list; ``[unk_token]`` if nothing is left
        """
        specials = self.specials if generate_unk else self.specials + [self.unk_token]
        valid = [token for token in sequence if token not in specials]
        if not valid:  # if empty, return <unk>
            valid = [self.unk_token]
        return valid

    def post_process(
        self,
        sequence: Union[List[str], str],
        generate_unk: bool = True,
        cut_at_sep: bool = True
    ) -> str:
        """
        Detokenize

        Requires ``set_vocab()`` to have been called when ``sequence`` is a
        list, since the special tokens are read from the instance.

        :param sequence: list of tokens, or an already joined string
        :param generate_unk: whether unk tokens are kept in the output
        :param cut_at_sep: whether to drop everything up to and including the
            first separator token (i.e. the prompt)
        :return: detokenized string
        :raises AssertionError: if the resulting string is empty
        """
        if isinstance(sequence, list):
            if cut_at_sep:
                try:
                    sep_pos = sequence.index(self.sep_token)  # cut off prompt
                    sequence = sequence[sep_pos + 1:]
                except ValueError:
                    # no separator in the sequence -> nothing to cut
                    pass
            sequence = self._remove_special(sequence, generate_unk=generate_unk)
            if self.level == "word":
                if self.pretokenizer == "moses":
                    sequence = self.moses_detokenizer.detokenize(sequence)
                else:
                    sequence = self.SPACE.join(sequence)
            elif self.level == "char":
                sequence = "".join(sequence).replace(self.SPACE_ESCAPE, self.SPACE)

        # Remove extra spaces
        if self.normalize:
            sequence = remove_extra_spaces(sequence)

        # ensure the string is not empty.
        assert sequence is not None and len(sequence) > 0, sequence
        return sequence

    def set_vocab(self, vocab) -> None:
        """
        Set vocab

        Caches the special tokens of the vocabulary on the tokenizer.
        ``self.specials`` holds all specials and language tags except the
        unk token, which is handled separately in ``_remove_special``.

        :param vocab: (Vocabulary)
        """
        # pylint: disable=attribute-defined-outside-init
        self.unk_token = vocab.specials[vocab.unk_index]
        self.eos_token = vocab.specials[vocab.eos_index]
        self.sep_token = vocab.specials[vocab.sep_index] if vocab.sep_index else None
        specials = vocab.specials + vocab.lang_tags
        self.specials = [token for token in specials if token != self.unk_token]
        self.lang_tags = vocab.lang_tags

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(level={self.level}, "
            f"lowercase={self.lowercase}, normalize={self.normalize}, "
            f"filter_by_length=({self.min_length}, {self.max_length}), "
            f"pretokenizer={self.pretokenizer})"
        )
class SentencePieceTokenizer(BasicTokenizer):
    """
    Subword tokenizer backed by a sentencepiece model.

    :param level: tokenization level; must be "bpe"
    :param lowercase: see ``BasicTokenizer``
    :param normalize: see ``BasicTokenizer``
    :param max_length: see ``BasicTokenizer``
    :param min_length: see ``BasicTokenizer``
    :param kwargs: ``model_file`` (required, path to the sentencepiece model),
        ``nbest_size`` (default 5) and ``alpha`` (default 0.0) for sampling
        during training, plus the kwargs accepted by ``BasicTokenizer``
    """

    def __init__(
        self,
        level: str = "bpe",
        lowercase: bool = False,
        normalize: bool = False,
        max_length: int = -1,
        min_length: int = -1,
        **kwargs,
    ):
        super().__init__(level, lowercase, normalize, max_length, min_length, **kwargs)
        assert self.level == "bpe"

        self.model_file: Path = Path(kwargs["model_file"])
        assert self.model_file.is_file(), f"model file {self.model_file} not found."

        self.spm = sp.SentencePieceProcessor()
        self.spm.load(kwargs["model_file"])

        self.nbest_size: int = kwargs.get("nbest_size", 5)
        self.alpha: float = kwargs.get("alpha", 0.0)

    def __call__(self, raw_input: str, is_train: bool = False) -> List[str]:
        """
        Tokenize

        :param raw_input: (pre-processed) input string; ``None`` is passed
            through as ``None``
        :param is_train: if True, pieces are sampled when ``alpha > 0`` and
            sequences outside the valid length range are dropped
        :return: list of pieces, or None
        """
        if raw_input is None:
            return None

        if is_train and self.alpha > 0:
            # sampled segmentation, only during training
            tokenized = self.spm.sample_encode_as_pieces(
                raw_input,
                nbest_size=self.nbest_size,
                alpha=self.alpha,
            )
        else:
            tokenized = self.spm.encode(raw_input, out_type=str)

        if is_train and self._filter_by_length(len(tokenized)):
            return None
        return tokenized

    def post_process(
        self,
        sequence: Union[List[str], str],
        generate_unk: bool = True,
        cut_at_sep: bool = True
    ) -> str:
        """
        Detokenize

        :param sequence: list of pieces, or an already decoded string
        :param generate_unk: whether unk tokens are kept in the output
        :param cut_at_sep: whether to drop everything up to and including the
            first separator token (i.e. the prompt)
        :return: detokenized string
        :raises AssertionError: if the resulting string is empty
        """
        if isinstance(sequence, list):
            if cut_at_sep:
                try:
                    sep_pos = sequence.index(self.sep_token)  # cut off prompt
                    # Skip the separator itself, as BasicTokenizer does. The
                    # separator is a special token and would be stripped by
                    # `_remove_special` anyway, so the output is unchanged.
                    sequence = sequence[sep_pos + 1:]
                except ValueError:
                    # no separator in the sequence -> nothing to cut
                    pass
            sequence = self._remove_special(sequence, generate_unk=generate_unk)

            # Decode back to str
            sequence = self.spm.decode(sequence)
            sequence = sequence.replace(self.SPACE_ESCAPE, self.SPACE).strip()

        # Apply moses detokenizer
        if self.pretokenizer == "moses":
            sequence = self.moses_detokenizer.detokenize(sequence.split())

        # Remove extra spaces
        if self.normalize:
            sequence = remove_extra_spaces(sequence)

        # ensure the string is not empty.
        assert sequence is not None and len(sequence) > 0, sequence
        return sequence

    def set_vocab(self, vocab) -> None:
        """Set vocab (also restricts the sentencepiece vocabulary to it)"""
        super().set_vocab(vocab)
        self.spm.SetVocabulary(vocab._itos)  # pylint: disable=protected-access

    def copy_cfg_file(self, model_dir: Path) -> None:
        """
        Copy the sentencepiece model file to model_dir

        If a file with the same name already exists in ``model_dir``, a
        warning is logged and nothing is copied.

        :param model_dir: destination directory
        """
        target = model_dir / self.model_file.name
        if target.is_file():
            logger.warning("%s already exists. Stop copying.", target.as_posix())
            # Actually stop here: copying on would overwrite the existing file
            # or raise shutil.SameFileError if source and target are the same.
            return
        shutil.copy2(self.model_file, target.as_posix())

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(level={self.level}, "
            f"lowercase={self.lowercase}, normalize={self.normalize}, "
            f"filter_by_length=({self.min_length}, {self.max_length}), "
            f"pretokenizer={self.pretokenizer}, "
            f"tokenizer={self.spm.__class__.__name__}, "
            f"nbest_size={self.nbest_size}, alpha={self.alpha})"
        )
class SubwordNMTTokenizer(BasicTokenizer):
    """
    Subword tokenizer backed by subword-nmt (``apply_bpe.BPE``).

    :param level: tokenization level; must be "bpe"
    :param lowercase: see ``BasicTokenizer``
    :param normalize: see ``BasicTokenizer``
    :param max_length: see ``BasicTokenizer``
    :param min_length: see ``BasicTokenizer``
    :param kwargs: ``codes`` (required, path to the BPE codes file),
        ``separator`` (default "@@"), ``dropout`` (default 0.0, applied in
        training only), plus the kwargs accepted by ``BasicTokenizer``
    """

    def __init__(
        self,
        level: str = "bpe",
        lowercase: bool = False,
        normalize: bool = False,
        max_length: int = -1,
        min_length: int = -1,
        **kwargs,
    ):
        super().__init__(level, lowercase, normalize, max_length, min_length, **kwargs)
        assert self.level == "bpe"

        codes_file = Path(kwargs["codes"])
        assert codes_file.is_file(), f"codes file {codes_file} not found."

        self.separator: str = kwargs.get("separator", "@@")
        self.dropout: float = kwargs.get("dropout", 0.0)

        bpe_parser = apply_bpe.create_parser()
        for action in bpe_parser._actions:  # pylint: disable=protected-access
            # workaround to ensure utf8 encoding
            if action.dest == "codes":
                action.type = argparse.FileType('r', encoding='utf8')
        bpe_args = bpe_parser.parse_args([
            "--codes", codes_file.as_posix(), "--separator", self.separator
        ])
        try:
            self.bpe = apply_bpe.BPE(
                bpe_args.codes,
                bpe_args.merges,
                bpe_args.separator,
                None,
                bpe_args.glossaries,
            )
        finally:
            # argparse.FileType opened the codes file and nothing else closes
            # it. NOTE(review): this relies on BPE.__init__ reading the codes
            # eagerly (as subword-nmt does) -- confirm for the pinned version.
            bpe_args.codes.close()
        self.codes: Path = codes_file

    def __call__(self, raw_input: str, is_train: bool = False) -> List[str]:
        """
        Tokenize

        :param raw_input: (pre-processed) input string; ``None`` is passed
            through as ``None``
        :param is_train: if True, BPE dropout is applied and sequences outside
            the valid length range are dropped
        :return: list of subword tokens, or None
        """
        if raw_input is None:
            return None

        dropout = self.dropout if is_train else 0.0
        tokenized = self.bpe.process_line(raw_input, dropout).strip().split()
        if is_train and self._filter_by_length(len(tokenized)):
            return None
        return tokenized

    def post_process(
        self,
        sequence: Union[List[str], str],
        generate_unk: bool = True,
        cut_at_sep: bool = True
    ) -> str:
        """
        Detokenize

        :param sequence: list of subword tokens, or an already joined string
        :param generate_unk: whether unk tokens are kept in the output
        :param cut_at_sep: whether to drop everything up to and including the
            first separator token (i.e. the prompt)
        :return: detokenized string
        :raises AssertionError: if the resulting string is empty
        """
        if isinstance(sequence, list):
            if cut_at_sep:
                try:
                    sep_pos = sequence.index(self.sep_token)  # cut off prompt
                    # Skip the separator itself, as BasicTokenizer does. The
                    # separator is a special token and would be stripped by
                    # `_remove_special` anyway, so the output is unchanged.
                    sequence = sequence[sep_pos + 1:]
                except ValueError:
                    # no separator in the sequence -> nothing to cut
                    pass
            sequence = self._remove_special(sequence, generate_unk=generate_unk)

            # Remove separators, join with spaces
            sequence = self.SPACE.join(sequence
                                       ).replace(self.separator + self.SPACE, "")
            # Remove final merge marker.
            if sequence.endswith(self.separator):
                sequence = sequence[:-len(self.separator)]

        # Moses detokenizer
        if self.pretokenizer == "moses":
            sequence = self.moses_detokenizer.detokenize(sequence.split())

        # Remove extra spaces
        if self.normalize:
            sequence = remove_extra_spaces(sequence)

        # ensure the string is not empty.
        assert sequence is not None and len(sequence) > 0, sequence
        return sequence

    def set_vocab(self, vocab) -> None:
        """Set vocab (BPE vocab = all entries minus specials and lang tags)"""
        # pylint: disable=protected-access
        super().set_vocab(vocab)
        self.bpe.vocab = set(vocab._itos) - set(vocab.specials) - set(vocab.lang_tags)

    def copy_cfg_file(self, model_dir: Path) -> None:
        """Copy the BPE codes file to model_dir"""
        shutil.copy2(self.codes, (model_dir / self.codes.name).as_posix())

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(level={self.level}, "
            f"lowercase={self.lowercase}, normalize={self.normalize}, "
            f"filter_by_length=({self.min_length}, {self.max_length}), "
            f"pretokenizer={self.pretokenizer}, "
            f"tokenizer={self.bpe.__class__.__name__}, "
            f"separator={self.separator}, dropout={self.dropout})"
        )
class FastBPETokenizer(SubwordNMTTokenizer):
    """
    Subword tokenizer backed by fastBPE.

    Inherits detokenization (``post_process``), ``copy_cfg_file`` and
    ``__repr__`` from ``SubwordNMTTokenizer``, but deliberately bypasses that
    class's ``__init__`` and ``set_vocab`` because it does not build a
    subword-nmt ``BPE`` object.

    :param level: tokenization level; must be "bpe"
    :param lowercase: see ``BasicTokenizer``
    :param normalize: see ``BasicTokenizer``
    :param max_length: see ``BasicTokenizer``
    :param min_length: see ``BasicTokenizer``
    :param kwargs: ``codes`` (required, path to the BPE codes file), plus the
        kwargs accepted by ``BasicTokenizer``
    :raises ImportError: if the fastBPE package is not installed
    """

    def __init__(
        self,
        level: str = "bpe",
        lowercase: bool = False,
        normalize: bool = False,
        max_length: int = -1,
        min_length: int = -1,
        **kwargs,
    ):
        try:
            import fastBPE  # pylint: disable=import-outside-toplevel
        except ImportError as e:
            logger.error(e)
            raise ImportError from e

        # Skip SubwordNMTTokenizer.__init__ and run BasicTokenizer.__init__.
        super(SubwordNMTTokenizer,
              self).__init__(level, lowercase, normalize, max_length, min_length, **kwargs)
        assert self.level == "bpe"

        # set codes file path
        self.codes: Path = Path(kwargs["codes"])
        assert self.codes.is_file(), f"codes file {self.codes} not found."

        # instantiate fastBPE object
        self.bpe = fastBPE.fastBPE(self.codes.as_posix())
        # fixed values; read by the inherited post_process() / __repr__()
        self.separator = "@@"
        self.dropout = 0.0

    def __call__(self, raw_input: str, is_train: bool = False) -> List[str]:
        """
        Tokenize

        :param raw_input: (pre-processed) input string; ``None`` is passed
            through as ``None``, like in the sibling tokenizers
        :param is_train: if True, sequences outside the valid length range
            are dropped
        :return: list of subword tokens, or None
        """
        if raw_input is None:
            return None

        # fastBPE.apply() works on a batch of sentences
        tokenized = self.bpe.apply([raw_input])
        tokenized = tokenized[0].strip().split()

        # check if the input sequence length stays within the valid length range
        if is_train and self._filter_by_length(len(tokenized)):
            return None
        return tokenized

    def set_vocab(self, vocab) -> None:
        """Set vocab (BasicTokenizer behaviour only; no BPE vocab is set)"""
        super(SubwordNMTTokenizer, self).set_vocab(vocab)
class SpeechProcessor:
    """
    Loads audio features and applies length filtering, CMVN and SpecAugment.

    :param level: processing level (default "frame")
    :param num_freq: expected size of the frequency axis of the features
    :param normalize: stored on the instance and shown in ``__repr__``
    :param max_length: maximum number of frames; values <= 0 disable the check
    :param min_length: minimum number of frames; values <= 0 disable the check
    :param kwargs: optional ``specaugment`` and ``cmvn`` dicts, forwarded as
        keyword arguments to ``SpecAugment`` and ``CMVN`` respectively
    """

    def __init__(
        self,
        level: str = "frame",
        num_freq: int = 80,
        normalize: bool = False,
        max_length: int = -1,
        min_length: int = -1,
        **kwargs,
    ):
        self.level = level
        self.num_freq = num_freq
        self.normalize = normalize

        # filter by length
        self.max_length = max_length
        self.min_length = min_length

        # optional augmentation / normalization steps (None = disabled)
        self.specaugment: Callable = None
        if "specaugment" in kwargs:
            self.specaugment = SpecAugment(**kwargs["specaugment"])

        self.cmvn: Callable = None
        if "cmvn" in kwargs:
            self.cmvn = CMVN(**kwargs["cmvn"])

        self.root_path = ""  # assigned later in dataset.__init__()

    def __call__(self, line: str, is_train: bool = False) -> np.ndarray:
        """
        get features

        :param line: path to audio file or pre-extracted features
        :param is_train: whether this is called during training
        :return: spectrogram in shape (num_frames, num_freq), or None if the
            item is filtered out
        """
        # lookup
        features = get_features(self.root_path, line)  # (num_frames, num_freq)

        n_frames, n_bins = features.shape
        assert n_bins == self.num_freq

        # A too short sequence cannot be convolved!
        # -> filter out anyway even in test-dev set.
        if self._filter_too_short_item(n_frames):
            return None

        if self._filter_too_long_item(n_frames):
            # Don't use too long sequence in training.
            if is_train:
                return None
            # in test, truncate the sequence
            features = features[:self.max_length, :]
            n_frames = features.shape[0]
            assert n_frames <= self.max_length

        # cmvn / specaugment
        # pylint: disable=not-callable
        if self.cmvn and self.cmvn.before:
            features = self.cmvn(features)

        if is_train and self.specaugment:
            features = self.specaugment(features)

        if self.cmvn and not self.cmvn.before:
            features = self.cmvn(features)

        return features

    def _filter_too_short_item(self, length: int) -> bool:
        """True if `length` is positive but below `min_length`."""
        return 0 < length < self.min_length

    def _filter_too_long_item(self, length: int) -> bool:
        """True if `max_length` is positive and `length` exceeds it."""
        return 0 < self.max_length < length

    def __repr__(self):
        length_range = f"({self.min_length}, {self.max_length})"
        return (
            f"{self.__class__.__name__}("
            f"level={self.level}, normalize={self.normalize}, "
            f"filter_by_length={length_range}, "
            f"cmvn={self.cmvn}, specaugment={self.specaugment})"
        )
class EvaluationTokenizer(BasicTokenizer):
    """Evaluation-time tokenizer built on the tokenizers shipped with
    sacreBLEU (https://github.com/mjpost/sacrebleu).

    On top of the sacreBLEU tokenization it can lowercase the text and strip
    punctuation; both are applied after the sacreBLEU tokenizer has run. The
    result is split on whitespace.

    :param lowercase: (bool) lowercase the text.
    :param tokenize: (str) the type of sacreBLEU tokenizer to apply;
        one of ``ALL_TOKENIZER_TYPES``.
    :param kwargs: optional ``no_punc`` (bool) to remove punctuation.
    """
    ALL_TOKENIZER_TYPES = ["none", "13a", "intl", "zh", "ja-mecab"]

    def __init__(self, lowercase: bool = False, tokenize: str = "13a", **kwargs):
        # Always word level, no normalization, no length filtering.
        super().__init__("word", lowercase, False, -1, -1)

        assert tokenize in self.ALL_TOKENIZER_TYPES, f"`{tokenize}` not supported."
        tokenizer_cls = _get_tokenizer(tokenize)
        self.tokenizer = tokenizer_cls()
        self.no_punc = kwargs.get("no_punc", False)

    def __call__(self, raw_input: str, is_train: bool = False) -> List[str]:
        """
        Tokenize a single sentence for evaluation.

        :param raw_input: input string
        :param is_train: unused
        :return: list of tokens
        """
        text = self.tokenizer(raw_input)
        text = text.lower() if self.lowercase else text

        # Remove punctuation (apply this after tokenization!)
        if self.no_punc:
            text = remove_punctuation(text, space=self.SPACE)
        return text.split()

    def __repr__(self):
        name = self.__class__.__name__
        return (
            f"{name}(level={self.level}, "
            f"lowercase={self.lowercase}, "
            f"tokenizer={self.tokenizer}, "
            f"no_punc={self.no_punc})"
        )
def _build_tokenizer(cfg: Dict) -> BasicTokenizer:
    """
    Builds tokenizer.

    :param cfg: per-language data config. Keys read here: ``level``
        (required), ``lowercase``, ``normalize``, ``max_length``,
        ``min_length``, ``tokenizer_type`` (falls back to ``bpe_type``, then
        "sentencepiece"), ``tokenizer_cfg``, ``lang`` (required when the moses
        pretokenizer is requested) and ``num_freq`` (required for level
        "frame").
    :return: tokenizer object; a ``SpeechProcessor`` for level "frame"
    :raises ConfigurationError: for an unknown level or tokenizer type
    :raises AssertionError: if the file path required by the selected subword
        tokenizer is missing from ``tokenizer_cfg``
    """
    tokenizer_cfg = cfg.get("tokenizer_cfg", {})

    # assign lang for moses tokenizer
    # The tokenizer lowercases the pretokenizer name, so compare the same way;
    # otherwise e.g. "Moses" would silently fall back to the default language.
    # (This writes into the caller's `tokenizer_cfg`, as before.)
    if tokenizer_cfg.get("pretokenizer", "none").lower() == "moses":
        tokenizer_cfg["lang"] = cfg["lang"]

    level = cfg["level"]
    length_kwargs = {
        "max_length": cfg.get("max_length", -1),
        "min_length": cfg.get("min_length", -1),
    }
    # arguments shared by all text tokenizers
    text_kwargs = {
        "level": level,
        "lowercase": cfg.get("lowercase", False),
        "normalize": cfg.get("normalize", False),
        **length_kwargs,
    }

    if level in ["word", "char"]:
        return BasicTokenizer(**text_kwargs, **tokenizer_cfg)

    if level == "bpe":
        tokenizer_type = cfg.get("tokenizer_type", cfg.get("bpe_type", "sentencepiece"))
        # pick the class and the file path key it needs in `tokenizer_cfg`
        if tokenizer_type == "sentencepiece":
            tokenizer_cls, required_key = SentencePieceTokenizer, "model_file"
        elif tokenizer_type == "subword-nmt":
            tokenizer_cls, required_key = SubwordNMTTokenizer, "codes"
        elif tokenizer_type == "fastbpe":
            tokenizer_cls, required_key = FastBPETokenizer, "codes"
        else:
            raise ConfigurationError(
                f"{tokenizer_type}: Unknown tokenizer type. "
                "Valid options: {'sentencepiece', 'subword-nmt', 'fastbpe'}."
            )
        assert required_key in tokenizer_cfg
        return tokenizer_cls(**text_kwargs, **tokenizer_cfg)

    if level == "frame":
        return SpeechProcessor(
            level=level,
            num_freq=cfg["num_freq"],
            normalize=cfg.get("normalize", False),
            **length_kwargs,
            **tokenizer_cfg,
        )

    raise ConfigurationError(
        f"{level}: Unknown tokenization level. "
        "Valid options: {'word', 'bpe', 'char', 'frame'}."
    )
def build_tokenizer(cfg: Dict, task: str) -> Dict[str, BasicTokenizer]:
    """
    Build the source and target tokenizers.

    :param cfg: data config with ``src`` and ``trg`` sub-configs
    :param task: task name; for "MT" the tokenizers are keyed by the
        configured language names, otherwise by "src" and "trg"
    :return: dict mapping the source / target key to its tokenizer
    """
    is_mt = task == "MT"
    src_lang = cfg["src"]["lang"] if is_mt else "src"
    trg_lang = cfg["trg"]["lang"] if is_mt else "trg"

    tokenizer = {}
    tokenizer[src_lang] = _build_tokenizer(cfg["src"])
    tokenizer[trg_lang] = _build_tokenizer(cfg["trg"])

    for lang in (src_lang, trg_lang):
        logger.info("%s Tokenizer: %s", lang, tokenizer[lang])
    return tokenizer