# Source code for joeynmt.encoders
# coding: utf-8
"""
Various encoders
"""
from typing import List, Tuple
import torch
from torch import Tensor, nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from joeynmt.helpers import freeze_params, lengths_to_padding_mask, pad
from joeynmt.helpers_for_ddp import get_logger
from joeynmt.transformer_layers import (
ConformerEncoderLayer,
PositionalEncoding,
TransformerEncoderLayer,
)
logger = get_logger(__name__)
# [docs]
class Encoder(nn.Module):
    """
    Common parent of all encoders.

    It only exposes the read-only `output_size` property, which reads the
    private `_output_size` attribute of the instance.
    """

    # pylint: disable=abstract-method

    @property
    def output_size(self):
        """
        Size of the features produced by this encoder.

        :return: the value stored in `self._output_size`
        """
        size = self._output_size
        return size
# [docs]
class RecurrentEncoder(Encoder):
    """Encodes a sequence of word embeddings with a GRU or LSTM."""

    # pylint: disable=unused-argument
    def __init__(
        self,
        rnn_type: str = "gru",
        hidden_size: int = 1,
        emb_size: int = 1,
        num_layers: int = 1,
        dropout: float = 0.0,
        emb_dropout: float = 0.0,
        bidirectional: bool = True,
        freeze: bool = False,
        **kwargs,
    ) -> None:
        """
        Create a new recurrent encoder.

        :param rnn_type: RNN type: `gru` selects `nn.GRU`; any other value
            selects `nn.LSTM`.
        :param hidden_size: Size of each RNN.
        :param emb_size: Size of the word embeddings.
        :param num_layers: Number of encoder RNN layers.
        :param dropout: Is applied between RNN layers.
        :param emb_dropout: Is applied to the RNN input (word embeddings).
        :param bidirectional: Use a bi-directional RNN.
        :param freeze: freeze the parameters of the encoder during training
        :param kwargs: ignored
        """
        super().__init__()

        self.emb_dropout = torch.nn.Dropout(p=emb_dropout, inplace=False)
        self.type = rnn_type
        self.emb_size = emb_size

        rnn = nn.GRU if rnn_type == "gru" else nn.LSTM

        self.rnn = rnn(
            emb_size,
            hidden_size,
            num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            # dropout acts between stacked layers only, so it is switched off
            # when there is a single layer
            dropout=dropout if num_layers > 1 else 0.0,
        )

        # both directions are concatenated along the feature axis
        self._output_size = 2 * hidden_size if bidirectional else hidden_size

        if freeze:
            freeze_params(self)

    def _check_shapes_input_forward(
        self, src_embed: Tensor, src_length: Tensor, mask: Tensor
    ) -> None:
        """
        Make sure the shape of the inputs to `self.forward` are correct.
        Same input semantics as `self.forward`.

        :param src_embed: embedded source tokens
        :param src_length: source length
        :param mask: source mask (accepted for symmetry, not validated)
        """
        # pylint: disable=unused-argument
        assert src_embed.shape[0] == src_length.shape[0]
        assert src_embed.shape[2] == self.emb_size
        assert len(src_length.shape) == 1

    def forward(self, src_embed: Tensor, src_length: Tensor, mask: Tensor,
                **kwargs) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Applies a uni- or bi-directional RNN to a sequence of embeddings.

        The input mini-batch needs to be sorted by src length, since the
        embeddings are packed with `pack_padded_sequence`.

        :param src_embed: embedded src inputs,
            shape (batch_size, src_len, embed_size)
        :param src_length: length of src inputs
            (counting tokens before padding), shape (batch_size)
        :param mask: source padding mask; it is not read by this encoder
        :param kwargs: ignored
        :return:
            - output: hidden states with
                shape (batch_size, max_length, directions*hidden),
            - hidden_concat: last hidden state of the top-most layer with
                shape (batch_size, directions*hidden)
            - None (this encoder does not return a mask)
        """
        self._check_shapes_input_forward(
            src_embed=src_embed, src_length=src_length, mask=mask
        )
        total_length = src_embed.size(1)

        # apply dropout to the rnn input
        src_embed = self.emb_dropout(src_embed)

        # the lengths are moved to the CPU before packing
        packed = pack_padded_sequence(src_embed, src_length.cpu(), batch_first=True)
        output, hidden = self.rnn(packed)

        if isinstance(hidden, tuple):
            # an LSTM returns (hidden state, memory cell); only the hidden
            # state is used downstream
            hidden, _ = hidden

        output, _ = pad_packed_sequence(
            output, batch_first=True, total_length=total_length
        )
        # hidden: dir*layers x batch x hidden
        # output: batch x max_length x directions*hidden
        batch_size = hidden.size(1)
        num_directions = 2 if self.rnn.bidirectional else 1

        # separate final hidden states by layer and direction
        hidden_layerwise = hidden.view(
            self.rnn.num_layers,
            num_directions,
            batch_size,
            self.rnn.hidden_size,
        )
        # hidden_layerwise: layers x directions x batch x hidden

        # Only feed the final state of the top-most layer to the decoder.
        # Thanks to pack_padded_sequence final states don't include padding.
        last_layer = hidden_layerwise[-1]  # directions x batch x hidden

        # Concatenate every available direction along the feature axis.
        # Indexing direction 1 unconditionally would raise an IndexError for
        # a unidirectional RNN, so the directions are unbound generically.
        # pylint: disable=no-member
        hidden_concat = torch.cat(last_layer.unbind(dim=0), dim=1)
        # hidden_concat: batch x directions*hidden

        assert hidden_concat.size(0) == output.size(0), (
            hidden_concat.size(),
            output.size(),
        )
        return output, hidden_concat, None

    def __repr__(self):
        return f"{self.__class__.__name__}(rnn={self.rnn})"
# [docs]
class TransformerEncoder(Encoder):
    """
    Transformer Encoder
    """

    def __init__(
        self,
        hidden_size: int = 512,
        ff_size: int = 2048,
        num_layers: int = 8,
        num_heads: int = 4,
        dropout: float = 0.1,
        emb_dropout: float = 0.1,
        freeze: bool = False,
        **kwargs,
    ):
        """
        Initializes the Transformer.

        :param hidden_size: hidden size and size of embeddings
        :param ff_size: position-wise feed-forward layer size.
        :param num_layers: number of layers
        :param num_heads: number of heads for multi-headed attention
        :param dropout: dropout probability for Transformer layers
        :param emb_dropout: Is applied to the input (word embeddings).
        :param freeze: freeze the parameters of the encoder during training
        :param kwargs: optional settings read here:
            - alpha (default 1.0), layer_norm (default "pre"),
              activation (default "relu"): forwarded to every layer
            - subsample (default False): prepend a `Conv1dSubsampler`;
              requires `in_channels` and `conv_channels`, and optionally
              `conv_kernel_sizes` (default [3, 3])
            - pad_index (default 1): fill value used by `_repad`
        """
        super().__init__()

        self._output_size = hidden_size

        # build all (num_layers) layers
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(
                size=hidden_size,
                ff_size=ff_size,
                num_heads=num_heads,
                dropout=dropout,
                alpha=kwargs.get("alpha", 1.0),
                layer_norm=kwargs.get("layer_norm", "pre"),
                activation=kwargs.get("activation", "relu"),
            ) for _ in range(num_layers)
        ])

        self.pe = PositionalEncoding(hidden_size)
        self.emb_dropout = nn.Dropout(p=emb_dropout)

        # NOTE(review): the layers above default to layer_norm="pre" while
        # this final norm defaults to "post" (i.e. no final norm) when the
        # key is absent. Kept unchanged because altering it would change the
        # parameter set of existing models -- confirm whether it is intended.
        self.layer_norm = (
            nn.LayerNorm(hidden_size, eps=1e-6)
            if kwargs.get("layer_norm", "post") == "pre" else None
        )

        # conv1d subsampling for audio inputs
        self.subsample = kwargs.get("subsample", False)
        # pad_index is only consumed by `_repad`, which `forward` calls only
        # when subsampling is enabled
        self.pad_index = kwargs.get("pad_index", 1)
        if self.subsample:
            self.subsampler = Conv1dSubsampler(
                kwargs["in_channels"], kwargs["conv_channels"], hidden_size,
                kwargs.get("conv_kernel_sizes", [3, 3])
            )
            assert self.pad_index is not None

        # Freeze last, once every sub-module (including the optional
        # subsampler) is registered; freezing earlier would leave modules
        # created afterwards trainable.
        if freeze:
            freeze_params(self)

    def forward(
        self,
        src_embed: Tensor,
        src_length: Tensor,
        mask: Tensor = None,
        **kwargs,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Pass the input (and mask) through each layer in turn.
        Applies a Transformer encoder to sequence of embeddings x.

        :param src_embed: embedded src inputs,
            shape (batch_size, src_len, embed_size)
        :param src_length: length of src inputs
            (counting tokens before padding), shape (batch_size); used by the
            subsampler and to build the mask when `mask` is None
        :param mask: indicates padding areas (zeros where padding), shape
            (batch_size, 1, src_len); built from `src_length` if None
        :param kwargs: optional settings read here:
            - src_prompt_mask: tensor added to the position-encoded input
            - repad, src_max_len: when subsampling is enabled and `repad` is
              truthy, the output and mask are re-padded via `_repad`
        :return:
            - output: hidden states with shape (batch_size, max_length, hidden)
            - None
            - mask
        """
        # pylint: disable=unused-argument
        if self.subsample:
            # shortens the time axis and returns the matching new lengths
            src_embed, src_length = self.subsampler(src_embed, src_length)

        if mask is None:
            mask = lengths_to_padding_mask(src_length).unsqueeze(1)

        x = self.pe(src_embed)  # add position encoding to word embeddings
        if kwargs.get("src_prompt_mask", None) is not None:  # add src_prompt_mask
            x = x + kwargs["src_prompt_mask"]
        x = self.emb_dropout(x)

        for layer in self.layers:
            x = layer(x, mask)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        if kwargs.get('repad', False) and "src_max_len" in kwargs and self.subsample:
            x, mask = self._repad(x, mask, kwargs["src_max_len"])

        assert src_length.size() == (x.size(0), ), (src_length.size(), x.size())
        assert mask.size() == (x.size(0), 1, x.size(1)), (mask.size(), x.size())
        return x, None, mask

    def _repad(self, x, mask, src_max_len):
        """
        Re-pad `x` and `mask` to a common length.

        The target length is `src_max_len` mapped through the subsampler's
        length formula, so that all seqs in parallel gpus have the same len.

        :param x: encoder output, padded along dim 1
        :param mask: source mask, padded along its last dim
        :param src_max_len: maximum source length before subsampling
        :return: padded `x` and padded `mask`
        """
        src_max_len = int(
            self.subsampler.get_out_seq_lens_tensor(
                torch.tensor(src_max_len).float()
            ).item()
        )
        x = pad(x, src_max_len, pad_index=self.pad_index, dim=1)
        # NOTE(review): the mask is filled with `pad_index` as well; verify
        # against `pad` that this marks the new positions as padding.
        mask = pad(mask, src_max_len, pad_index=self.pad_index, dim=-1)
        return x, mask

    def __repr__(self):
        if len(self.layers) == 0:
            # nothing to introspect; indexing layers[0] would raise
            return (
                f"{self.__class__.__name__}(num_layers=0, "
                f'subsample={self.subsample})'
            )
        first_layer = self.layers[0]
        return (
            f"{self.__class__.__name__}(num_layers={len(self.layers)}, "
            f"num_heads={first_layer.src_src_att.num_heads}, "
            f"alpha={first_layer.alpha}, "
            f'layer_norm="{first_layer._layer_norm_position}", '
            f'activation="{first_layer.feed_forward.pwff_layer[1]}", '
            f'subsample={self.subsample})'
        )
# [docs]
class Conv1dSubsampler(nn.Module):
    """
    Convolutional subsampler: a stack of stride-2 1D convolutions (along the
    temporal dimension), each followed by a gated linear unit
    (https://arxiv.org/abs/1911.08460)

    cf.) https://github.com/pytorch/fairseq/blob/main/fairseq/models/speech_to_text/s2t_transformer.py

    :param in_channels: the number of input channels (embed_size = num_freq)
    :param mid_channels: the number of intermediate channels
    :param out_channels: the number of output channels (hidden_size)
    :param kernel_sizes: the kernel size for each convolutional layer
    :return:
        - output tensor
        - sequence length after subsampling
    """  # noqa: E501

    def __init__(
        self,
        in_channels: int,
        mid_channels: int,
        out_channels: int = None,
        kernel_sizes: List[int] = (3, 3)
    ):
        super().__init__()
        self.kernel_sizes = kernel_sizes
        self.n_layers = len(kernel_sizes)

        last_idx = self.n_layers - 1
        convs = []
        for idx, kernel in enumerate(kernel_sizes):
            # the GLU after each conv halves the channel count, hence the
            # `// 2` on the input side and the `* 2` on the final output side
            n_in = in_channels if idx == 0 else mid_channels // 2
            n_out = mid_channels if idx < last_idx else out_channels * 2
            convs.append(
                nn.Conv1d(n_in, n_out, kernel, stride=2, padding=kernel // 2)
            )
        self.conv_layers = nn.ModuleList(convs)

    def get_out_seq_lens_tensor(self, in_seq_lens_tensor):
        """
        Map input sequence lengths to the lengths after all conv layers.

        :param in_seq_lens_tensor: tensor of sequence lengths
        :return: long tensor of the subsampled lengths
        """
        lengths = in_seq_lens_tensor.clone()
        for kernel in self.kernel_sizes:
            padding = kernel // 2
            # conv output length for stride 2 and the padding used above
            numerator = lengths.float() + 2 * padding - (kernel - 1) - 1
            lengths = (numerator / 2 + 1).floor().long()
        return lengths

    def forward(self, src_tokens, src_lengths):
        """
        Subsample the input along the time axis.

        :param src_tokens: input features, B x T x (C x D)
        :param src_lengths: lengths of the inputs, shape (B)
        :return:
            - subsampled features, B x T' x (C x D)
            - subsampled lengths
        """
        # reshape after DataParallel batch split
        longest = torch.max(src_lengths).item()
        assert longest > 0, "empty batch!"
        if src_tokens.size(1) != longest:
            src_tokens = src_tokens[:, :longest, :]
        assert src_tokens.size(1) == longest, (src_tokens.size(), longest, src_lengths)

        _, in_seq_len, _ = src_tokens.size()  # -> B x T x (C x D)
        feats = src_tokens.transpose(1, 2).contiguous()  # -> B x (C x D) x T
        for conv_layer in self.conv_layers:
            feats = nn.functional.glu(conv_layer(feats), dim=1)
        _, _, out_seq_len = feats.size()
        feats = feats.transpose(1, 2).contiguous()  # -> B x T x (C x D)

        out_seq_lens = self.get_out_seq_lens_tensor(src_lengths)
        assert feats.size(1) == torch.max(out_seq_lens).item(), \
            (feats.size(), in_seq_len, out_seq_len, out_seq_lens)
        return feats, out_seq_lens
# [docs]
class ConformerEncoder(TransformerEncoder):
    """
    Conformer Encoder

    Always subsamples its input with a `Conv1dSubsampler` before the
    Conformer layers.
    """

    def __init__(
        self,
        hidden_size: int = 512,
        ff_size: int = 2048,
        num_layers: int = 8,
        num_heads: int = 4,
        dropout: float = 0.1,
        emb_dropout: float = 0.1,
        freeze: bool = False,
        **kwargs,
    ):
        """
        Initializes the Conformer.

        :param hidden_size: hidden size and size of embeddings
        :param ff_size: position-wise feed-forward layer size.
        :param num_layers: number of layers
        :param num_heads: number of heads for multi-headed attention
        :param dropout: dropout probability for Conformer layers
        :param emb_dropout: Is applied to the (subsampled) input.
        :param freeze: freeze the parameters of the encoder during training
        :param kwargs: settings read here:
            - in_channels, conv_channels (required), conv_kernel_sizes
              (default [3, 3]): passed to the `Conv1dSubsampler`
            - alpha (default 1.0), layer_norm (default "pre"),
              depthwise_conv_kernel_size (default 31): forwarded to every layer
            - pad_index (default 1): fill value used by `_repad`
        """
        # Initialise the parent with zero layers: every sub-module it builds
        # is replaced below, so constructing its default layer stack would
        # only allocate and initialise parameters that are thrown away.
        super().__init__(hidden_size=hidden_size, num_layers=0)

        self._output_size = hidden_size

        # build all (num_layers) layers
        self.layers = nn.ModuleList([
            ConformerEncoderLayer(
                size=hidden_size,
                ff_size=ff_size,
                num_heads=num_heads,
                dropout=dropout,
                alpha=kwargs.get("alpha", 1.0),
                layer_norm=kwargs.get("layer_norm", "pre"),
                depthwise_conv_kernel_size=kwargs.get("depthwise_conv_kernel_size", 31)
            ) for _ in range(num_layers)
        ])

        self.pe = PositionalEncoding(hidden_size)
        self.emb_dropout = nn.Dropout(p=emb_dropout)
        self.linear = nn.Linear(hidden_size, hidden_size)

        # conv1d subsampling for audio inputs
        self.subsampler = Conv1dSubsampler(
            kwargs["in_channels"], kwargs["conv_channels"], hidden_size,
            kwargs.get("conv_kernel_sizes", [3, 3])
        )
        self.pad_index = kwargs.get("pad_index", 1)
        assert self.pad_index is not None

        # Freeze last, once every sub-module (including the subsampler) is
        # registered; freezing earlier would leave the subsampler trainable.
        if freeze:
            freeze_params(self)

    def forward(
        self,
        src_embed: Tensor,
        src_length: Tensor,
        mask: Tensor = None,
        **kwargs,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Subsample the input, then pass it through each Conformer layer.

        :param src_embed: input features, shape (batch_size, src_len, in_channels)
        :param src_length: length of src inputs before subsampling,
            shape (batch_size)
        :param mask: ignored; the mask is always recomputed from the
            subsampled lengths
        :param kwargs: optional settings read here:
            - repad, src_max_len: when `repad` is truthy, the output and
              mask are re-padded via `_repad`
        :return:
            - output: hidden states with shape (batch_size, max_length, hidden)
            - None
            - mask: shape (batch_size, 1, max_length)
        """
        x, src_length = self.subsampler(src_embed, src_length)  # always subsample
        mask = lengths_to_padding_mask(src_length).unsqueeze(1)  # recompute src mask

        x = self.pe(x)  # add position encoding to spectrogram features
        x = self.linear(x)
        x = self.emb_dropout(x)

        for layer in self.layers:
            # x stays batch-first (B x T x C), as the asserts below require
            x = layer(x, mask)

        if kwargs.get('repad', False) and "src_max_len" in kwargs:
            x, mask = self._repad(x, mask, kwargs["src_max_len"])

        assert src_length.size() == (x.size(0), ), (src_length.size(), x.size())
        assert mask.size() == (x.size(0), 1, x.size(1)), (mask.size(), x.size())
        return x, None, mask