# coding: utf-8
"""
Tokenizer module
"""
import argparse
import shutil
from pathlib import Path
from typing import Callable, Dict, List, Union
import numpy as np
import sentencepiece as sp
from sacrebleu.metrics.bleu import _get_tokenizer
from subword_nmt import apply_bpe
from joeynmt.config import ConfigurationError
from joeynmt.data_augmentation import CMVN, SpecAugment
from joeynmt.helpers import remove_extra_spaces, remove_punctuation, unicode_normalize
from joeynmt.helpers_for_audio import get_features
from joeynmt.helpers_for_ddp import get_logger
logger = get_logger(__name__)
[docs]
class BasicTokenizer:
    """
    Word-level (whitespace) or character-level tokenizer.

    Optionally lowercases, unicode-normalizes and Moses-pretokenizes the input
    in :meth:`pre_process`, splits it in :meth:`__call__`, and reverses the
    process in :meth:`post_process`.

    :param level: tokenization level. ``__call__`` of this class supports
        "word" and "char"; subclasses handle other levels themselves.
    :param lowercase: whether to lowercase the text in ``pre_process``
    :param normalize: whether to normalize unicode / collapse extra spaces
    :param max_length: max number of tokens; a non-positive value disables
        the upper bound
    :param min_length: min number of tokens; a non-positive value disables
        the lower bound
    :param kwargs: ``pretokenizer`` ("none" or "moses", case-insensitive) and
        ``lang`` (language code handed to sacremoses, default "en")
    """
    # pylint: disable=too-many-instance-attributes
    SPACE = chr(32)  # ' ': half-width white space (ascii)
    SPACE_ESCAPE = chr(9601)  # '▁': sentencepiece default

    def __init__(
        self,
        level: str = "word",
        lowercase: bool = False,
        normalize: bool = False,
        max_length: int = -1,
        min_length: int = -1,
        **kwargs,
    ):
        # pylint: disable=unused-argument
        self.level = level
        self.lowercase = lowercase
        self.normalize = normalize

        # filter by length
        self.max_length = max_length
        self.min_length = min_length

        # pretokenizer
        self.pretokenizer = kwargs.get("pretokenizer", "none").lower()
        assert self.pretokenizer in ["none", "moses"], \
            "Currently, we support moses tokenizer only."

        # sacremoses
        if self.pretokenizer == "moses":
            try:
                from sacremoses import (  # pylint: disable=import-outside-toplevel
                    MosesDetokenizer,
                    MosesPunctNormalizer,
                    MosesTokenizer,
                )

                # sacremoses package has to be installed.
                # https://github.com/alvations/sacremoses
            except ImportError as e:
                logger.error(e)
                raise ImportError from e

            self.lang = kwargs.get("lang", "en")
            self.moses_tokenizer = MosesTokenizer(lang=self.lang)
            self.moses_detokenizer = MosesDetokenizer(lang=self.lang)
            if self.normalize:
                self.moses_normalizer = MosesPunctNormalizer()

    def pre_process(self, raw_input: str, allow_empty: bool = False) -> str:
        """
        Pre-process text
            - ex.) Lowercase, Normalize, Remove emojis,
                Pre-tokenize(add extra white space before punc) etc.
            - applied for all inputs both in training and inference.

        :param raw_input: raw input string
        :param allow_empty: whether to allow empty string
        :return: preprocessed input string
        :raises AssertionError: if ``allow_empty`` is False and the input is
            not a non-blank string, or the processed result is empty
        """
        if not allow_empty:
            assert isinstance(raw_input, str) and raw_input.strip() != "", \
                "The input sentence is empty! Please make sure " \
                "that you are feeding a valid input."

        if self.normalize:
            raw_input = remove_extra_spaces(unicode_normalize(raw_input))

        if self.pretokenizer == "moses":
            if self.normalize:
                raw_input = self.moses_normalizer.normalize(raw_input)
            raw_input = self.moses_tokenizer.tokenize(raw_input, return_str=True)

        if self.lowercase:
            raw_input = raw_input.lower()

        if not allow_empty:
            # ensure the string is not empty.
            assert raw_input is not None and len(raw_input) > 0, raw_input
        return raw_input

    def __call__(self, raw_input: str, is_train: bool = False) -> List[str]:
        """
        Tokenize single sentence

        :param raw_input: (pre-processed) sentence, or None
        :param is_train: apply the length filter only when True
        :return: list of tokens; None if the input is None or if the sequence
            is filtered out by length during training
        :raises ValueError: if ``self.level`` is neither "word" nor "char"
        """
        if raw_input is None:
            return None

        if self.level == "word":
            sequence = raw_input.split(self.SPACE)
        elif self.level == "char":
            # keep white spaces visible as a dedicated escape symbol
            sequence = list(raw_input.replace(self.SPACE, self.SPACE_ESCAPE))
        else:
            # fail with a clear message instead of an UnboundLocalError below
            raise ValueError(
                f"{self.__class__.__name__} cannot tokenize at level "
                f"`{self.level}`. Valid options: {{'word', 'char'}}."
            )

        if is_train and self._filter_by_length(len(sequence)):
            return None
        return sequence

    def _filter_by_length(self, length: int) -> bool:
        """
        Check if the given seq length is out of the valid range.

        :param length: (int) number of tokens
        :return: True if the length is invalid(= to be filtered out), False if valid.
        """
        return length > self.max_length > 0 or self.min_length > length > 0

    def _remove_special(self, sequence: List[str], generate_unk: bool = False):
        """
        Drop special tokens from a token list.

        :param sequence: list of tokens
        :param generate_unk: if True, unk tokens are kept; otherwise they are
            removed together with the other special tokens
        :return: filtered token list; ``[unk_token]`` if nothing is left
        """
        specials = self.specials if generate_unk else self.specials + [self.unk_token]
        valid = [token for token in sequence if token not in specials]
        if len(valid) == 0:  # if empty, return <unk>
            valid = [self.unk_token]
        return valid

    def post_process(
        self,
        sequence: Union[List[str], str],
        generate_unk: bool = True,
        cut_at_sep: bool = True
    ) -> str:
        """
        Detokenize

        :param sequence: list of tokens, or an already joined string
        :param generate_unk: whether unk tokens are kept in the output
        :param cut_at_sep: whether to drop everything up to and including the
            first separator token (i.e. the prompt)
        :return: detokenized string
        :raises AssertionError: if the result is empty
        """
        if isinstance(sequence, list):
            if cut_at_sep:
                try:
                    sep_pos = sequence.index(self.sep_token)  # cut off prompt
                    sequence = sequence[sep_pos + 1:]
                except ValueError:
                    # no separator in the sequence (or no separator defined)
                    pass
            sequence = self._remove_special(sequence, generate_unk=generate_unk)

            if self.level == "word":
                if self.pretokenizer == "moses":
                    sequence = self.moses_detokenizer.detokenize(sequence)
                else:
                    sequence = self.SPACE.join(sequence)
            elif self.level == "char":
                sequence = "".join(sequence).replace(self.SPACE_ESCAPE, self.SPACE)

        # Remove extra spaces
        if self.normalize:
            sequence = remove_extra_spaces(sequence)

        # ensure the string is not empty.
        assert sequence is not None and len(sequence) > 0, sequence
        return sequence

    def set_vocab(self, vocab) -> None:
        """
        Set vocab

        Caches the special tokens needed by :meth:`post_process`.

        :param vocab: (Vocabulary)
        """
        # pylint: disable=attribute-defined-outside-init
        self.unk_token = vocab.specials[vocab.unk_index]
        self.eos_token = vocab.specials[vocab.eos_index]
        # truthiness check: a `sep_index` of None or 0 means "no separator"
        self.sep_token = vocab.specials[vocab.sep_index] if vocab.sep_index else None
        specials = vocab.specials + vocab.lang_tags
        # unk is handled separately, see `_remove_special()`
        self.specials = [token for token in specials if token != self.unk_token]
        self.lang_tags = vocab.lang_tags

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(level={self.level}, "
            f"lowercase={self.lowercase}, normalize={self.normalize}, "
            f"filter_by_length=({self.min_length}, {self.max_length}), "
            f"pretokenizer={self.pretokenizer})"
        )
[docs]
class SentencePieceTokenizer(BasicTokenizer):
    """
    Subword tokenizer backed by a trained sentencepiece model.

    :param level: tokenization level, must be "bpe"
    :param lowercase: whether to lowercase the text in ``pre_process``
    :param normalize: whether to normalize unicode / collapse extra spaces
    :param max_length: max number of tokens (non-positive: no upper bound)
    :param min_length: min number of tokens (non-positive: no lower bound)
    :param kwargs: ``model_file`` (required, path to the sentencepiece model),
        ``nbest_size`` (default 5) and ``alpha`` (default 0.0), both forwarded
        to ``sample_encode_as_pieces`` during training when ``alpha > 0``;
        plus the options accepted by :class:`BasicTokenizer`
    """

    def __init__(
        self,
        level: str = "bpe",
        lowercase: bool = False,
        normalize: bool = False,
        max_length: int = -1,
        min_length: int = -1,
        **kwargs,
    ):
        super().__init__(level, lowercase, normalize, max_length, min_length, **kwargs)
        assert self.level == "bpe"

        self.model_file: Path = Path(kwargs["model_file"])
        assert self.model_file.is_file(), f"model file {self.model_file} not found."

        self.spm = sp.SentencePieceProcessor()
        self.spm.load(kwargs["model_file"])

        self.nbest_size: int = kwargs.get("nbest_size", 5)
        self.alpha: float = kwargs.get("alpha", 0.0)

    def __call__(self, raw_input: str, is_train: bool = False) -> List[str]:
        """
        Tokenize

        :param raw_input: (pre-processed) sentence, or None
        :param is_train: enables sampling (if ``alpha > 0``) and the length filter
        :return: list of pieces; None if the input is None or if the sequence
            is filtered out by length during training
        """
        if raw_input is None:
            return None

        if is_train and self.alpha > 0:
            # sampled segmentation, only used during training
            tokenized = self.spm.sample_encode_as_pieces(
                raw_input,
                nbest_size=self.nbest_size,
                alpha=self.alpha,
            )
        else:
            tokenized = self.spm.encode(raw_input, out_type=str)

        if is_train and self._filter_by_length(len(tokenized)):
            return None
        return tokenized

    def post_process(
        self,
        sequence: Union[List[str], str],
        generate_unk: bool = True,
        cut_at_sep: bool = True
    ) -> str:
        """
        Detokenize

        :param sequence: list of pieces, or an already decoded string
        :param generate_unk: whether unk tokens are kept in the output
        :param cut_at_sep: whether to drop everything before the first
            separator token (i.e. the prompt)
        :return: detokenized string
        :raises AssertionError: if the result is empty
        """
        if isinstance(sequence, list):
            if cut_at_sep:
                try:
                    sep_pos = sequence.index(self.sep_token)  # cut off prompt
                    # the separator itself is dropped by `_remove_special()`
                    sequence = sequence[sep_pos:]
                except ValueError:
                    # no separator in the sequence (or no separator defined)
                    pass
            sequence = self._remove_special(sequence, generate_unk=generate_unk)

            # Decode back to str
            sequence = self.spm.decode(sequence)
            sequence = sequence.replace(self.SPACE_ESCAPE, self.SPACE).strip()

        # Apply moses detokenizer
        if self.pretokenizer == "moses":
            sequence = self.moses_detokenizer.detokenize(sequence.split())

        # Remove extra spaces
        if self.normalize:
            sequence = remove_extra_spaces(sequence)

        # ensure the string is not empty.
        assert sequence is not None and len(sequence) > 0, sequence
        return sequence

    def set_vocab(self, vocab) -> None:
        """
        Set vocab

        :param vocab: (Vocabulary) its token list is also handed to the
            sentencepiece processor via ``SetVocabulary``
        """
        super().set_vocab(vocab)
        self.spm.SetVocabulary(vocab._itos)  # pylint: disable=protected-access

    def copy_cfg_file(self, model_dir: Path) -> None:
        """
        Copy config file to model_dir

        An existing file of the same name in ``model_dir`` is left untouched.

        :param model_dir: destination directory
        """
        target = model_dir / self.model_file.name
        if target.is_file():
            logger.warning("%s already exists. Stop copying.", target.as_posix())
            # really stop here, as the message says; copying a file onto
            # itself would also make shutil raise SameFileError
            return
        shutil.copy2(self.model_file, target.as_posix())

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(level={self.level}, "
            f"lowercase={self.lowercase}, normalize={self.normalize}, "
            f"filter_by_length=({self.min_length}, {self.max_length}), "
            f"pretokenizer={self.pretokenizer}, "
            f"tokenizer={self.spm.__class__.__name__}, "
            f"nbest_size={self.nbest_size}, alpha={self.alpha})"
        )
[docs]
class SubwordNMTTokenizer(BasicTokenizer):
    """
    Subword tokenizer backed by subword-nmt's ``apply_bpe.BPE``.

    :param level: tokenization level, must be "bpe"
    :param lowercase: whether to lowercase the text in ``pre_process``
    :param normalize: whether to normalize unicode / collapse extra spaces
    :param max_length: max number of tokens (non-positive: no upper bound)
    :param min_length: min number of tokens (non-positive: no lower bound)
    :param kwargs: ``codes`` (required, path to the BPE codes file),
        ``separator`` (default "@@") and ``dropout`` (default 0.0, applied
        during training only); plus the options accepted by
        :class:`BasicTokenizer`
    """

    def __init__(
        self,
        level: str = "bpe",
        lowercase: bool = False,
        normalize: bool = False,
        max_length: int = -1,
        min_length: int = -1,
        **kwargs,
    ):
        super().__init__(level, lowercase, normalize, max_length, min_length, **kwargs)
        assert self.level == "bpe"

        codes_file = Path(kwargs["codes"])
        assert codes_file.is_file(), f"codes file {codes_file} not found."

        self.separator: str = kwargs.get("separator", "@@")
        self.dropout: float = kwargs.get("dropout", 0.0)

        bpe_parser = apply_bpe.create_parser()
        # pylint: disable-next=protected-access
        for action in bpe_parser._actions:  # workaround to ensure utf8 encoding
            if action.dest == "codes":
                action.type = argparse.FileType('r', encoding='utf8')
        bpe_args = bpe_parser.parse_args([
            "--codes", codes_file.as_posix(), "--separator", self.separator
        ])
        try:
            self.bpe = apply_bpe.BPE(
                bpe_args.codes,
                bpe_args.merges,
                bpe_args.separator,
                None,
                bpe_args.glossaries,
            )
        finally:
            # argparse opened the codes file for us; release the handle.
            # NOTE(review): relies on apply_bpe.BPE reading the whole codes
            # file inside its constructor - confirm against the installed
            # subword-nmt version.
            bpe_args.codes.close()
        self.codes: Path = codes_file

    def __call__(self, raw_input: str, is_train: bool = False) -> List[str]:
        """
        Tokenize

        :param raw_input: (pre-processed) sentence, or None
        :param is_train: enables BPE dropout and the length filter
        :return: list of subwords; None if the input is None or if the
            sequence is filtered out by length during training
        """
        if raw_input is None:
            return None

        dropout = self.dropout if is_train else 0.0
        tokenized = self.bpe.process_line(raw_input, dropout).strip().split()
        if is_train and self._filter_by_length(len(tokenized)):
            return None
        return tokenized

    def post_process(
        self,
        sequence: Union[List[str], str],
        generate_unk: bool = True,
        cut_at_sep: bool = True
    ) -> str:
        """
        Detokenize

        :param sequence: list of subwords, or an already joined string
        :param generate_unk: whether unk tokens are kept in the output
        :param cut_at_sep: whether to drop everything before the first
            separator token (i.e. the prompt)
        :return: detokenized string
        :raises AssertionError: if the result is empty
        """
        if isinstance(sequence, list):
            if cut_at_sep:
                try:
                    sep_pos = sequence.index(self.sep_token)  # cut off prompt
                    # the separator itself is dropped by `_remove_special()`
                    sequence = sequence[sep_pos:]
                except ValueError:
                    # no separator in the sequence (or no separator defined)
                    pass
            sequence = self._remove_special(sequence, generate_unk=generate_unk)

            # Remove separators, join with spaces
            sequence = self.SPACE.join(sequence
                                       ).replace(self.separator + self.SPACE, "")
            # Remove final merge marker.
            if sequence.endswith(self.separator):
                sequence = sequence[:-len(self.separator)]

        # Moses detokenizer
        if self.pretokenizer == "moses":
            sequence = self.moses_detokenizer.detokenize(sequence.split())

        # Remove extra spaces
        if self.normalize:
            sequence = remove_extra_spaces(sequence)

        # ensure the string is not empty.
        assert sequence is not None and len(sequence) > 0, sequence
        return sequence

    def set_vocab(self, vocab) -> None:
        """
        Set vocab

        :param vocab: (Vocabulary) its tokens, minus special tokens and
            language tags, become the vocabulary of the BPE object
        """
        # pylint: disable=protected-access
        super().set_vocab(vocab)
        self.bpe.vocab = set(vocab._itos) - set(vocab.specials) - set(vocab.lang_tags)

    def copy_cfg_file(self, model_dir: Path) -> None:
        """
        Copy config file to model_dir

        :param model_dir: destination directory
        """
        shutil.copy2(self.codes, (model_dir / self.codes.name).as_posix())

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(level={self.level}, "
            f"lowercase={self.lowercase}, normalize={self.normalize}, "
            f"filter_by_length=({self.min_length}, {self.max_length}), "
            f"pretokenizer={self.pretokenizer}, "
            f"tokenizer={self.bpe.__class__.__name__}, "
            f"separator={self.separator}, dropout={self.dropout})"
        )
[docs]
class FastBPETokenizer(SubwordNMTTokenizer):
    """
    Subword tokenizer backed by fastBPE.

    Reuses ``post_process()``, ``copy_cfg_file()`` and ``__repr__()`` of
    :class:`SubwordNMTTokenizer`, but deliberately skips its ``__init__()``
    and ``set_vocab()``. The separator is fixed to "@@" and dropout to 0.0.

    :param level: tokenization level, must be "bpe"
    :param lowercase: whether to lowercase the text in ``pre_process``
    :param normalize: whether to normalize unicode / collapse extra spaces
    :param max_length: max number of tokens (non-positive: no upper bound)
    :param min_length: min number of tokens (non-positive: no lower bound)
    :param kwargs: ``codes`` (required, path to the BPE codes file); plus the
        options accepted by :class:`BasicTokenizer`
    :raises ImportError: if the fastBPE package is not installed
    """

    def __init__(
        self,
        level: str = "bpe",
        lowercase: bool = False,
        normalize: bool = False,
        max_length: int = -1,
        min_length: int = -1,
        **kwargs,
    ):
        try:
            import fastBPE  # pylint: disable=import-outside-toplevel
        except ImportError as e:
            logger.error(e)
            raise ImportError from e

        # skip SubwordNMTTokenizer.__init__(): no subword-nmt object needed
        super(SubwordNMTTokenizer, self
              ).__init__(level, lowercase, normalize, max_length, min_length, **kwargs)
        assert self.level == "bpe"

        # set codes file path
        self.codes: Path = Path(kwargs["codes"])
        assert self.codes.is_file(), f"codes file {self.codes} not found."

        # instantiate fastBPE object
        self.bpe = fastBPE.fastBPE(self.codes.as_posix())
        self.separator = "@@"
        self.dropout = 0.0

    def __call__(self, raw_input: str, is_train: bool = False) -> List[str]:
        """
        Tokenize

        :param raw_input: (pre-processed) sentence, or None
        :param is_train: apply the length filter only when True
        :return: list of subwords; None if the input is None or if the
            sequence is filtered out by length during training
        """
        # same contract as the other tokenizers: None in, None out
        if raw_input is None:
            return None

        # fastBPE.apply()
        tokenized = self.bpe.apply([raw_input])
        tokenized = tokenized[0].strip().split()

        # check if the input sequence length stays within the valid length range
        if is_train and self._filter_by_length(len(tokenized)):
            return None
        return tokenized

    def set_vocab(self, vocab) -> None:
        """
        Set vocab

        Only caches the special tokens; SubwordNMTTokenizer.set_vocab() is
        skipped, so no vocabulary is assigned to the fastBPE object.

        :param vocab: (Vocabulary)
        """
        super(SubwordNMTTokenizer, self).set_vocab(vocab)
[docs]
class SpeechProcessor:
    """
    Feature loader for frame-level (speech) inputs.

    Looks up the features of an utterance, filters or truncates them by
    number of frames, and applies the optional CMVN / SpecAugment transforms.

    :param level: tokenization level (stored as-is)
    :param num_freq: expected size of the second feature dimension
    :param normalize: stored as-is and shown in ``repr``
    :param max_length: max number of frames (non-positive: no upper bound)
    :param min_length: min number of frames (non-positive: no lower bound)
    :param kwargs: optional ``specaugment`` and ``cmvn`` dicts, unpacked into
        the ``SpecAugment`` and ``CMVN`` constructors respectively
    """

    def __init__(
        self,
        level: str = "frame",
        num_freq: int = 80,
        normalize: bool = False,
        max_length: int = -1,
        min_length: int = -1,
        **kwargs,
    ):
        self.level = level
        self.num_freq = num_freq
        self.normalize = normalize

        # filter by length
        self.max_length = max_length
        self.min_length = min_length

        # optional transforms; None means "not configured"
        if "specaugment" in kwargs:
            self.specaugment: Callable = SpecAugment(**kwargs["specaugment"])
        else:
            self.specaugment: Callable = None

        if "cmvn" in kwargs:
            self.cmvn: Callable = CMVN(**kwargs["cmvn"])
        else:
            self.cmvn: Callable = None

        self.root_path = ""  # assigned later in dataset.__init__()

    def __call__(self, line: str, is_train: bool = False) -> np.ndarray:
        """
        get features

        :param line: path to audio file or pre-extracted features
        :param is_train: whether the instance is used for training
        :return: spectrogram in shape (num_frames, num_freq); None if the
            item is filtered out
        """
        # lookup
        features = get_features(self.root_path, line)  # (num_frames, num_freq)
        num_frames, num_freq = features.shape
        assert num_freq == self.num_freq

        # A too short sequence cannot be convolved!
        # -> filter out anyway even in test-dev set.
        if self._filter_too_short_item(num_frames):
            return None

        if self._filter_too_long_item(num_frames):
            # Don't use too long sequence in training.
            if is_train:
                return None
            # in test, truncate the sequence
            features = features[:self.max_length, :]
            num_frames = features.shape[0]
            assert num_frames <= self.max_length

        # cmvn / specaugment
        # pylint: disable=not-callable
        cmvn = self.cmvn
        if cmvn and cmvn.before:
            features = cmvn(features)

        if is_train and self.specaugment:
            features = self.specaugment(features)

        if cmvn and not cmvn.before:
            features = cmvn(features)
        return features

    def _filter_too_short_item(self, length: int) -> bool:
        """Return True if ``length`` is positive but below ``min_length``."""
        return 0 < length < self.min_length

    def _filter_too_long_item(self, length: int) -> bool:
        """Return True if ``length`` exceeds a positive ``max_length``."""
        return 0 < self.max_length < length

    def __repr__(self):
        name = self.__class__.__name__
        length_range = f"({self.min_length}, {self.max_length})"
        return (
            f"{name}(level={self.level}, normalize={self.normalize}, "
            f"filter_by_length={length_range}, "
            f"cmvn={self.cmvn}, specaugment={self.specaugment})"
        )
[docs]
class EvaluationTokenizer(BasicTokenizer):
    """A generic evaluation-time tokenizer, which leverages built-in tokenizers in
    sacreBLEU (https://github.com/mjpost/sacrebleu). It additionally provides
    lowercasing and punctuation removal, which are applied after sacreBLEU
    tokenization. The tokenization level is always "word".

    :param lowercase: (bool) lowercase the text.
    :param tokenize: (str) the type of sacreBLEU tokenizer to apply;
        one of ``ALL_TOKENIZER_TYPES``.
    :param kwargs: ``no_punc`` (bool, default False) remove punctuation.
    """
    ALL_TOKENIZER_TYPES = ["none", "13a", "intl", "zh", "ja-mecab"]

    def __init__(self, lowercase: bool = False, tokenize: str = "13a", **kwargs):
        # no normalization and no length filter at evaluation time
        base_options = {
            "level": "word",
            "lowercase": lowercase,
            "normalize": False,
            "max_length": -1,
            "min_length": -1,
        }
        super().__init__(**base_options)

        assert tokenize in self.ALL_TOKENIZER_TYPES, f"`{tokenize}` not supported."
        tokenizer_cls = _get_tokenizer(tokenize)
        self.tokenizer = tokenizer_cls()
        self.no_punc = kwargs.get("no_punc", False)

    def __call__(self, raw_input: str, is_train: bool = False) -> List[str]:
        """
        Tokenize

        :param raw_input: sentence to tokenize
        :param is_train: unused
        :return: list of tokens
        """
        text = self.tokenizer(raw_input)
        text = text.lower() if self.lowercase else text

        # Remove punctuation (apply this after tokenization!)
        if self.no_punc:
            text = remove_punctuation(text, space=self.SPACE)
        return text.split()

    def __repr__(self):
        fields = (
            f"level={self.level}, "
            f"lowercase={self.lowercase}, "
            f"tokenizer={self.tokenizer}, "
            f"no_punc={self.no_punc}"
        )
        return f"{self.__class__.__name__}({fields})"
def _build_tokenizer(cfg: Dict) -> BasicTokenizer:
    """
    Builds tokenizer.

    :param cfg: data config of one side (src or trg). Reads ``level``
        (required), ``lowercase``, ``normalize``, ``max_length``,
        ``min_length``, ``tokenizer_type`` (or legacy ``bpe_type``),
        ``tokenizer_cfg``, ``num_freq`` (required for level "frame") and
        ``lang`` (required if the pretokenizer is moses).
        Note: ``lang`` is written into ``cfg["tokenizer_cfg"]`` in place when
        the moses pretokenizer is requested.
    :return: tokenizer object (a ``SpeechProcessor`` for level "frame")
    :raises ConfigurationError: on unknown tokenization level or tokenizer type
    """
    tokenizer_cfg = cfg.get("tokenizer_cfg", {})

    # assign lang for moses tokenizer
    # (case-insensitive, in line with BasicTokenizer which lowercases the value)
    if str(tokenizer_cfg.get("pretokenizer", "none")).lower() == "moses":
        tokenizer_cfg["lang"] = cfg["lang"]

    level = cfg["level"]

    # options shared by all tokenizer classes
    shared_kwargs = {
        "level": level,
        "normalize": cfg.get("normalize", False),
        "max_length": cfg.get("max_length", -1),
        "min_length": cfg.get("min_length", -1),
    }

    if level == "frame":
        return SpeechProcessor(
            num_freq=cfg["num_freq"], **shared_kwargs, **tokenizer_cfg
        )

    # text tokenizers additionally take `lowercase`
    text_kwargs = {"lowercase": cfg.get("lowercase", False), **shared_kwargs}

    if level in ["word", "char"]:
        return BasicTokenizer(**text_kwargs, **tokenizer_cfg)

    if level == "bpe":
        # tokenizer type -> (class, required key in tokenizer_cfg)
        bpe_tokenizers = {
            "sentencepiece": (SentencePieceTokenizer, "model_file"),
            "subword-nmt": (SubwordNMTTokenizer, "codes"),
            "fastbpe": (FastBPETokenizer, "codes"),
        }
        tokenizer_type = cfg.get("tokenizer_type", cfg.get("bpe_type", "sentencepiece"))
        if tokenizer_type not in bpe_tokenizers:
            raise ConfigurationError(
                f"{tokenizer_type}: Unknown tokenizer type. "
                "Valid options: {'sentencepiece', 'subword-nmt', 'fastbpe'}."
            )
        tokenizer_cls, required_key = bpe_tokenizers[tokenizer_type]
        assert required_key in tokenizer_cfg
        return tokenizer_cls(**text_kwargs, **tokenizer_cfg)

    raise ConfigurationError(
        f"{level}: Unknown tokenization level. "
        "Valid options: {'word', 'bpe', 'char', 'frame'}."
    )
[docs]
def build_tokenizer(cfg: Dict, task: str) -> Dict[str, BasicTokenizer]:
    """
    Build one tokenizer for the source side and one for the target side.

    :param cfg: data config; reads ``cfg["src"]`` and ``cfg["trg"]``
    :param task: task name; for "MT" the returned dict is keyed by the
        configured language codes, otherwise by "src" and "trg"
    :return: dict mapping the source / target key to its tokenizer
    """
    if task == "MT":
        src_key = cfg["src"]["lang"]
        trg_key = cfg["trg"]["lang"]
    else:
        src_key, trg_key = "src", "trg"

    tokenizers = {}
    for key, side in ((src_key, "src"), (trg_key, "trg")):
        tokenizers[key] = _build_tokenizer(cfg[side])

    for key in (src_key, trg_key):
        logger.info("%s Tokenizer: %s", key, tokenizers[key])
    return tokenizers