# Source code for joeynmt.loss

# coding: utf-8
"""
Loss functions
"""
import logging
from typing import Tuple

import torch
from torch import Tensor, nn
from torch.autograd import Variable
from torch.nn.modules.loss import _Loss

logger = logging.getLogger(__name__)


class XentLoss(nn.Module):
    """
    Cross-Entropy Loss with optional label smoothing
    """

    def __init__(self, pad_index: int, smoothing: float = 0.0):
        """
        :param pad_index: index of the padding token; positions whose target is
            padding do not contribute to the loss
        :param smoothing: probability mass taken away from the reference token
            and spread over the rest of the vocabulary. With ``smoothing <= 0``
            plain negative log-likelihood is used, otherwise a KL-divergence
            loss against smoothed target distributions.
        """
        super().__init__()
        self.smoothing = smoothing
        self.pad_index = pad_index
        self.criterion: _Loss  # (type annotation)
        if self.smoothing <= 0.0:
            # standard xent loss; padding positions are skipped via ignore_index
            self.criterion = nn.NLLLoss(ignore_index=self.pad_index, reduction="sum")
        else:
            # custom label-smoothed loss, computed with KL divergence loss
            self.criterion = nn.KLDivLoss(reduction="sum")
        # flag read by callers (presumably to decide whether a CTC output layer
        # is needed); subclasses that add a CTC term set it to True
        self.require_ctc_layer = False

    def _smooth_targets(self, targets: Tensor, vocab_size: int) -> Tensor:
        """
        Smooth target distribution. All non-reference words get uniform
        probability mass according to "smoothing".

        The reference token receives ``1 - smoothing``. The remaining
        ``smoothing`` mass is spread evenly over the ``vocab_size - 2`` tokens
        that are neither the reference nor padding. Rows whose target is the
        padding token are all zeros, so they contribute nothing to the loss.

        :param targets: target indices, batch*seq_len
        :param vocab_size: size of the output vocabulary
        :return: smoothed target distributions, batch*seq_len x vocab_size
            (detached, no gradient is tracked)
        """
        # batch*seq_len x vocab_size, filled uniformly with the smoothing mass
        smooth_dist = targets.new_zeros((targets.size(0), vocab_size)).float()
        smooth_dist.fill_(self.smoothing / (vocab_size - 2))
        # assign true label the probability of 1-smoothing ("confidence")
        smooth_dist.scatter_(1, targets.unsqueeze(1), 1.0 - self.smoothing)
        # give padding probability of 0 everywhere
        smooth_dist[:, self.pad_index] = 0.0
        # zero out whole rows at padding positions. A broadcast boolean mask
        # avoids the host-device sync of torch.nonzero() and the 0-dim index
        # that squeeze() yields when exactly one position is padding.
        smooth_dist.masked_fill_((targets == self.pad_index).unsqueeze(1), 0.0)
        return smooth_dist.detach()

    def _reshape(self, log_probs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Flatten model outputs and targets into the layout the criterion expects.

        :param log_probs: log probabilities whose last dimension is the vocabulary
        :param targets: target indices with one entry per position of log_probs
        :return: tuple of
            - log_probs flattened to (batch*seq_len x vocab_size)
            - targets: smoothed distributions (batch*seq_len x vocab_size) when
              label smoothing is enabled, otherwise indices (batch*seq_len)
        """
        vocab_size = log_probs.size(-1)
        # reshape log_probs to (batch*seq_len x vocab_size)
        log_probs_flat = log_probs.contiguous().view(-1, vocab_size)
        targets_flat = targets.contiguous().view(-1)

        if self.smoothing > 0:
            # targets: distributions with batch*seq_len x vocab_size
            targets_flat = self._smooth_targets(
                targets=targets_flat, vocab_size=vocab_size
            )
            assert log_probs_flat.size() == targets_flat.size(), (
                log_probs_flat.size(),
                targets_flat.size(),
            )
        else:
            # targets: indices with batch*seq_len
            assert log_probs_flat.size(0) == targets_flat.size(0), (
                log_probs_flat.size(0),
                targets_flat.size(0),
            )
        return log_probs_flat, targets_flat

    def forward(self, log_probs: Tensor, **kwargs) -> Tuple[Tensor]:
        """
        Compute the cross-entropy between log probabilities and targets.

        If label smoothing is used, target distributions are not one-hot, but
        "1-smoothing" for the correct target token and the rest of the
        probability mass is uniformly spread across the other tokens.

        :param log_probs: log probabilities as predicted by model
        :param kwargs: must contain ``trg``, the target indices
        :return: one-element tuple holding the loss summed over all positions
        """
        assert "trg" in kwargs
        log_probs, targets = self._reshape(log_probs, kwargs["trg"])

        # compute loss (criterion uses reduction="sum")
        loss = self.criterion(log_probs, targets)
        return (loss,)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(criterion={self.criterion}, "
            f"smoothing={self.smoothing})"
        )
class XentCTCLoss(XentLoss):
    """
    Cross-Entropy + CTC loss with optional label smoothing
    """

    def __init__(
        self,
        pad_index: int,
        bos_index: int,
        smoothing: float = 0.0,
        zero_infinity: bool = True,
        ctc_weight: float = 0.3,
    ):
        """
        :param pad_index: index of the padding token (excluded from the xent part)
        :param bos_index: index of the BOS token; reused as the CTC blank label
        :param smoothing: label-smoothing mass for the cross-entropy part
        :param zero_infinity: forwarded to ``nn.CTCLoss``; when True, infinite
            CTC losses (and their gradients) are replaced by zero
        :param ctc_weight: interpolation weight; the total loss is
            ``(1 - ctc_weight) * xent_loss + ctc_weight * ctc_loss``
        """
        super().__init__(pad_index=pad_index, smoothing=smoothing)

        self.require_ctc_layer = True
        self.bos_index = bos_index
        self.ctc_weight = ctc_weight
        self.ctc = nn.CTCLoss(
            blank=bos_index, reduction='sum', zero_infinity=zero_infinity
        )

    def forward(self, log_probs: Tensor, **kwargs) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Compute the cross-entropy loss and ctc loss

        :param log_probs: log probabilities as predicted by model
            shape (batch_size, seq_length, vocab_size)
        :param kwargs: must contain ``trg`` (target indices), ``trg_length``
            (per-sequence target lengths), ``src_mask`` and ``ctc_log_probs``
            (log probabilities from the CTC output layer)
        :return:
            - total loss
            - xent loss
            - ctc loss
        """
        assert "trg" in kwargs
        assert "trg_length" in kwargs
        assert "src_mask" in kwargs
        assert "ctc_log_probs" in kwargs

        # reshape tensors for cross_entropy
        log_probs_flat, targets_flat = self._reshape(log_probs, kwargs["trg"])

        # cross_entropy loss
        xent_loss = self.criterion(log_probs_flat, targets_flat)

        # ctc_loss
        # nn.CTCLoss expects log probs as (input_length, batch_size, vocab_size);
        # ctc_log_probs is presumably (batch_size, src_len, vocab_size) and is
        # transposed accordingly -- TODO confirm against the model
        ctc_loss = self.ctc(
            kwargs["ctc_log_probs"].transpose(0, 1).contiguous(),
            # padded target indices, presumably (batch_size, trg_len); padding
            # beyond trg_length is ignored by CTCLoss
            targets=kwargs["trg"],
            # number of non-masked source positions per sequence; assumes
            # src_mask is (batch_size, 1, src_len) -- TODO confirm
            input_lengths=kwargs["src_mask"].squeeze(1).sum(dim=1),
            target_lengths=kwargs["trg_length"],
        )

        # interpolation
        total_loss = (1.0 - self.ctc_weight) * xent_loss + self.ctc_weight * ctc_loss

        assert not total_loss.isnan(), "loss has to be non-NaN value."
        assert total_loss.item() >= 0.0, "loss has to be non-negative."
        return total_loss, xent_loss, ctc_loss

    def __repr__(self):
        # single repr; previously a leftover copy of XentLoss.__repr__ was
        # concatenated in front of this one
        return (
            f"{self.__class__.__name__}("
            f"criterion={self.criterion}, smoothing={self.smoothing}, "
            f"ctc={self.ctc}, ctc_weight={self.ctc_weight})"
        )