Source code for joeynmt.data_augmentation

# coding: utf-8
"""
Data Augmentation
"""
import math
from typing import Optional

import numpy as np

from joeynmt.helpers_for_ddp import get_logger

logger = get_logger(__name__)


[docs] class SpecAugment: """ SpecAugment (https://arxiv.org/abs/1904.08779) cf.) https://github.com/pytorch/fairseq/blob/main/fairseq/data/audio/feature_transforms/specaugment.py """ # noqa def __init__( self, freq_mask_n: int = 2, freq_mask_f: int = 27, time_mask_n: int = 2, time_mask_t: int = 40, time_mask_p: float = 1.0, mask_value: Optional[float] = None ): self.freq_mask_n = freq_mask_n self.freq_mask_f = freq_mask_f self.time_mask_n = time_mask_n self.time_mask_t = time_mask_t self.time_mask_p = time_mask_p self.mask_value = mask_value def __call__(self, spectrogram: np.ndarray) -> np.ndarray: assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor." distorted = spectrogram.copy() # make a copy of input spectrogram. num_frames, num_freqs = spectrogram.shape mask_value = self.mask_value if mask_value is None: # if no value was specified, use local mean. mask_value = spectrogram.mean() if num_frames == 0: return spectrogram if num_freqs < self.freq_mask_f: return spectrogram for _i in range(self.freq_mask_n): f = np.random.randint(0, self.freq_mask_f) f0 = np.random.randint(0, num_freqs - f) if f != 0: distorted[:, f0:f0 + f] = mask_value max_time_mask_t = min( self.time_mask_t, math.floor(num_frames * self.time_mask_p) ) if max_time_mask_t < 1: return distorted for _i in range(self.time_mask_n): t = np.random.randint(0, max_time_mask_t) t0 = np.random.randint(0, num_frames - t) if t != 0: distorted[t0:t0 + t, :] = mask_value assert distorted.shape == spectrogram.shape return distorted def __repr__(self): return ( f"{self.__class__.__name__}(freq_mask_n={self.freq_mask_n}, " f"freq_mask_f={self.freq_mask_f}, time_mask_n={self.time_mask_n}, " f"time_mask_t={self.time_mask_t}, time_mask_p={self.time_mask_p})" )
[docs] class CMVN: """ CMVN: Cepstral Mean and Variance Normalization (Utterance-level) cf.) https://github.com/pytorch/fairseq/blob/main/fairseq/data/audio/feature_transforms/utterance_cmvn.py """ # noqa def __init__( self, norm_means: bool = True, norm_vars: bool = True, before: bool = True ): self.norm_means = norm_means self.norm_vars = norm_vars self.before = before def __call__(self, x: np.ndarray) -> np.ndarray: orig_shape = x.shape mean = x.mean(axis=0) square_sums = (x**2).sum(axis=0) if self.norm_means: x = np.subtract(x, mean) if self.norm_vars: var = square_sums / x.shape[0] - mean**2 std = np.sqrt(np.maximum(var, 1e-10)) x = np.divide(x, std) assert orig_shape == x.shape return x def __repr__(self): return ( f"{self.__class__.__name__}(norm_means={self.norm_means}, " f"norm_vars={self.norm_vars}, before={self.before})" )