# coding: utf-8
"""
Torch Hub Interface
"""
from pathlib import Path
from typing import List, NamedTuple, Optional, Union
import numpy as np
import plotly.express as px
from torch import nn
from joeynmt.config import (
BaseConfig,
TestConfig,
_check_options,
load_config,
parse_global_args,
)
from joeynmt.datasets import BaseDataset, SpeechStreamDataset, StreamDataset
from joeynmt.helpers_for_ddp import get_logger
from joeynmt.model import Model
from joeynmt.prediction import predict, prepare
logger = get_logger(__name__)
# Result record produced per input sentence by `TranslatorHubInterface.score`.
# Fields are positional and named, in this order:
#   translation     - hypotheses for the input (or the given reference)
#   tokens          - token sequences belonging to the scored sentences
#   token_probs     - per-token scores (filled when beam_size == 1)
#   sequence_probs  - per-sequence scores (filled when beam_size > 1)
#   attention_probs - attention scores, if any were returned
class PredictionOutput(NamedTuple):
    translation: List[str]
    tokens: Optional[List[List[str]]]
    token_probs: Optional[List[List[float]]]
    sequence_probs: Optional[List[float]]
    attention_probs: Optional[List[List[float]]]
def _check_file_path(
    path: Optional[Union[str, Path]], model_dir: Path
) -> Optional[Path]:
    """Resolve ``path`` to an existing file, falling back to ``model_dir``.

    If ``path`` itself is not an existing file, a file with the same base
    name is looked up directly inside ``model_dir`` (e.g. the torch hub
    cache directory).

    :param path: file location as string or ``Path``; ``None`` is passed through
    :param model_dir: directory searched when ``path`` does not exist as given
    :return: the path of the existing file, or ``None`` if ``path`` is ``None``
    :raises AssertionError: if the file exists in neither location
    """
    if path is None:
        return None

    candidate = Path(path)
    if not candidate.is_file():
        # fall back to a file of the same name inside the model directory
        candidate = model_dir / candidate.name
        if not candidate.is_file():
            # Raised explicitly (rather than via `assert`) so the check is not
            # stripped under `python -O`; the exception type is kept for
            # backward compatibility.
            raise AssertionError(f"file not found: {candidate}")
    return candidate
def _from_pretrained(
    model_name_or_path: Union[str, Path],
    cfg_file: Union[str, Path] = "config.yaml",
    **kwargs,
):
    """Prepare model and data placeholder.

    Loads the configuration, overrides it with ``kwargs`` and with the given
    model directory, rewrites the vocabulary / tokenizer / checkpoint paths
    found in the config via ``_check_file_path``, and then builds the model
    and the test dataset in "translate" mode.

    :param model_name_or_path: directory that holds the model files
    :param cfg_file: configuration file, resolved via ``_check_file_path``
    :param kwargs: top-level config entries overriding the loaded config
    :return: tuple ``(model, test_data, args)``
    """
    # model dir
    if isinstance(model_name_or_path, str):
        model_dir = Path(model_name_or_path)
    else:
        model_dir = model_name_or_path
    assert model_dir.is_dir(), model_dir

    # cfg file
    cfg_file = _check_file_path(cfg_file, model_dir)
    assert cfg_file.is_file(), cfg_file

    cfg = load_config(cfg_file)
    cfg.update(kwargs)
    cfg["model_dir"] = model_dir.as_posix()  # override model_dir

    data_cfg = cfg["data"]
    if "task" in data_cfg:  # for backwards compatibility
        cfg["task"] = data_cfg["task"]
    task = cfg.get("task", "MT").upper()
    _check_options("task", task, ["MT", "S2T"])

    def _localize(file_ref):
        # resolve a configured file reference and return it as a posix string
        return _check_file_path(file_ref, model_dir).as_posix()

    # rewrite paths in cfg
    for side in ["src", "trg"]:
        if task == "S2T" and side == "src":
            # the speech source side has no vocabulary/tokenizer files to rewrite
            assert data_cfg["dataset_type"] == "speech"
            assert data_cfg[side]["tokenizer_type"] == "speech"
            continue

        side_cfg = data_cfg[side]
        side_cfg["voc_file"] = _localize(side_cfg["voc_file"])
        if "tokenizer_cfg" not in side_cfg:
            continue
        tokenizer_cfg = side_cfg["tokenizer_cfg"]
        for key in ["codes", "model_file"]:
            if key in tokenizer_cfg:
                tokenizer_cfg[key] = _localize(tokenizer_cfg[key])

    testing_cfg = cfg["testing"]
    if "load_model" in testing_cfg:
        testing_cfg["load_model"] = _localize(testing_cfg["load_model"])

    # parse args
    args = parse_global_args(cfg, rank=0, mode="translate")

    # load the data
    model, _, _, test_data = prepare(args, rank=0, mode="translate")
    return model, test_data, args
class TranslatorHubInterface(nn.Module):
    """
    PyTorch Hub interface for generating sequences from a pre-trained
    encoder-decoder model.
    """

    def __init__(self, model: Model, dataset: BaseDataset, args: BaseConfig):
        """Store model, dataset placeholder and config; put the model in eval mode.

        :param model: pre-trained model
        :param dataset: dataset object that is filled with the input sentences
            on every call
        :param args: parsed configuration
        """
        super().__init__()
        self.args = args
        self.dataset = dataset
        self.model = model
        if self.args.device.type == "cuda":
            self.model.to(self.args.device)
        self.model.eval()

    def score(
        self,
        src: List[str],
        trg: Optional[List[str]] = None,
        **kwargs,
    ) -> List[PredictionOutput]:
        """Score hypotheses (no ``trg``) or given references (``trg``).

        :param src: list of source sentences
        :param trg: optional list of reference sentences, same length as ``src``
        :param kwargs: overrides for the test config; ``return_prob`` and
            ``return_attention`` are always set by this method
        :return: one ``PredictionOutput`` per source sentence (empty list for
            empty ``src``). When ``trg`` is given, ``translation`` holds the
            reference string itself; otherwise it holds the n-best hypotheses.
            ``token_probs`` is filled for ``beam_size == 1``,
            ``sequence_probs`` for ``beam_size > 1``.
        """
        assert isinstance(src, list), "Please provide a list of sentences!"
        if not src:
            # nothing to score; avoids running prediction with batch size 0
            return []

        kwargs["return_prob"] = "hyp" if trg is None else "ref"
        kwargs["return_attention"] = True

        translations, tokens, probs, attn, test_cfg = self._generate(
            src, trg, **kwargs
        )
        beam_size = test_cfg.get("beam_size", 1)
        n_best = test_cfg.get("n_best", 1)

        out = []
        for i in range(len(src)):
            # the outputs are flat lists holding n_best entries per source sentence
            window = slice(i * n_best, (i + 1) * n_best)
            pred = PredictionOutput(
                translation=trg[i] if trg else translations[window],
                tokens=tokens[window],
                token_probs=probs[window] if beam_size == 1 else None,
                sequence_probs=(
                    [p[0] for p in probs[window]] if beam_size > 1 else None
                ),
                attention_probs=attn[window] if attn else None,
            )
            out.append(pred)
        return out

    def generate(self, src: List[str], **kwargs) -> List[str]:
        """Translate a list of source sentences.

        :param src: list of source sentences
        :param kwargs: overrides for the test config (``return_prob`` is
            forced to ``"none"``), plus optional ``src_prompt``/``trg_prompt``
        :return: list of translations (empty list for empty ``src``)
        """
        assert isinstance(src, list), "Please provide a list of sentences!"
        if not src:
            # nothing to translate; avoids running prediction with batch size 0
            return []

        kwargs["return_prob"] = "none"
        translations, _, _, _, _ = self._generate(src, **kwargs)
        return translations

    def _generate(
        self,
        src: List[str],
        trg: Optional[List[str]] = None,
        src_prompt: Optional[List[str]] = None,
        trg_prompt: Optional[List[str]] = None,
        **kwargs,
    ) -> tuple:
        """Fill the dataset with the given sentences and run prediction.

        :param src: list of source sentences
        :param trg: optional references; if given, ``n_best`` and
            ``beam_size`` are forced to 1 and ``return_prob`` to ``"ref"``
        :param src_prompt: optional per-sentence source prompts
        :param trg_prompt: optional per-sentence target prompts
        :param kwargs: overrides for the test config
        :return: tuple ``(translations, tokens, probs, attention_probs,
            test_cfg)`` where ``test_cfg`` is the effective test config dict
        """
        # overwrite config
        test_cfg = self.args.test._asdict()
        test_cfg.update(kwargs)

        if self.args.task == "MT":
            assert isinstance(self.dataset, StreamDataset), self.dataset
        elif self.args.task == "S2T":
            assert isinstance(self.dataset, SpeechStreamDataset), self.dataset

        # all given sentences are processed as one sentence-based batch
        test_cfg["batch_type"] = "sentence"
        test_cfg["batch_size"] = len(src)

        if src_prompt:
            assert len(src) == len(
                src_prompt
            ), "src and src_prompt must have the same length!"
        else:
            src_prompt = [None] * len(src)

        if trg_prompt:
            assert len(src) == len(
                trg_prompt
            ), "src and trg_prompt must have the same length!"
        else:
            trg_prompt = [None] * len(src)

        self.dataset.reset_cache()  # reset cache
        if trg is not None:
            assert len(src) == len(trg), "src and trg must have the same length!"
            self.dataset.has_trg = True
            test_cfg["n_best"] = 1
            test_cfg["beam_size"] = 1
            test_cfg["return_prob"] = "ref"
            for src_sent, trg_sent, src_p, trg_p in zip(
                src, trg, src_prompt, trg_prompt
            ):
                self.dataset.set_item(src_sent, trg_sent, src_p, trg_p)
        else:
            self.dataset.has_trg = False
            for src_sent, src_p, trg_p in zip(src, src_prompt, trg_prompt):
                self.dataset.set_item(src_sent, None, src_p, trg_p)

        assert len(self.dataset) == len(src), (
            len(self.dataset),
            self.dataset.cache,
        )

        _, _, translations, tokens, probs, attention_probs = predict(
            model=self.model,
            data=self.dataset,
            compute_loss=trg is not None,
            device=self.args.device,
            n_gpu=self.args.n_gpu,
            normalization=self.args.train.normalization,
            num_workers=self.args.num_workers,
            args=TestConfig(**test_cfg),
            autocast=self.args.autocast,
        )
        if translations:
            assert len(src) * test_cfg.get("n_best", 1) == len(translations)

        self.dataset.reset_cache()  # reset cache
        return translations, tokens, probs, attention_probs, test_cfg

    def plot_attention(
        self, src: str, trg: str, attention_scores: np.ndarray
    ) -> None:
        """Show a heatmap of attention scores for one sentence pair.

        :param src: source sentence
        :param trg: target sentence
        :param attention_scores: 2D array indexed ``[trg position, src position]``;
            each axis must be one longer than the corresponding token list,
            as an EOS label is appended on both axes
        """
        # preprocess and tokenize sentences
        self.dataset.reset_cache()  # reset cache
        self.dataset.has_trg = True
        self.dataset.set_item(src, trg)
        src_tokens = self.dataset.get_item(
            idx=0, lang=self.dataset.src_lang, is_train=False
        )
        trg_tokens = self.dataset.get_item(
            idx=0, lang=self.dataset.trg_lang, is_train=False
        )
        self.dataset.reset_cache()  # reset cache

        # +1 for the EOS token appended to the axis labels below
        assert len(src_tokens) + 1 == attention_scores.shape[1]
        assert len(trg_tokens) + 1 == attention_scores.shape[0]

        # plot attention scores
        fig = px.imshow(
            attention_scores,
            labels={
                "x": "Src",
                "y": "Trg",
            },
            x=src_tokens + [self.dataset.tokenizer[self.dataset.src_lang].eos_token],
            y=trg_tokens + [self.dataset.tokenizer[self.dataset.trg_lang].eos_token],
        )
        fig.update_xaxes(side="top", tickangle=270)
        fig.show()