# Source code for joeynmt.helpers_for_ddp

# coding: utf-8
"""
helper functions for DDP
"""
import logging
import math
import os
from typing import Optional, Union

import torch
import torch.distributed as dist
from torch import Tensor
from torch.utils.data import Dataset, SequentialSampler
from torch.utils.data.distributed import DistributedSampler


def ddp_setup(
    rank: int,
    world_size: int,
    master_addr: str = "localhost",
    master_port: int = 12355,
) -> None:
    """
    Join the NCCL process group and select this process's CUDA device.

    Values already present in the ``MASTER_ADDR`` / ``MASTER_PORT`` environment
    variables win; the arguments are only written when a variable is unset.

    :param rank: unique identifier of this process (also passed to
        ``torch.cuda.set_device``)
    :param world_size: total number of processes
    :param master_addr: fallback value for ``MASTER_ADDR``
    :param master_port: fallback value for ``MASTER_PORT``
    """
    if not dist.is_available():
        return
    os.environ.setdefault("MASTER_ADDR", master_addr)
    os.environ.setdefault("MASTER_PORT", str(master_port))
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
def use_ddp() -> bool:
    """Return True only if torch.distributed is built in and a process group is initialized."""
    if not dist.is_available():
        return False
    return dist.is_initialized()
def ddp_cleanup() -> None:
    """Destroy the default process group; does nothing when DDP is not in use."""
    if not use_ddp():
        return
    dist.destroy_process_group()
def ddp_synchronize() -> None:
    """Wait at a barrier for all processes; does nothing when DDP is not in use."""
    if not use_ddp():
        return
    dist.barrier()
def ddp_merge(data: Optional[Tensor], pad_index: int = 1) -> Optional[Tensor]:
    """
    Merge tensors from multiple devices by concatenating them along dim 0.

    `dist.all_gather()` requires the same tensor shape on every process, but the
    local tensors may differ in every dimension. Each process therefore pads its
    tensor with `pad_index` up to the element-wise maximum shape. After gathering,
    the padding-only rows appended along dim 0 are removed again. Padding added
    along the other dimensions is kept.

    ex) world_size = 2, pad_index = -1
        rank 0: [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]   shape (2, 5)
        rank 1: [[1, 2, 3], [4, 5, 6], [7, 8, 9]]     shape (3, 3)
        merged: [[1, 2, 3, 4, 5],
                 [6, 7, 8, 9, 10],
                 [1, 2, 3, -1, -1],
                 [4, 5, 6, -1, -1],
                 [7, 8, 9, -1, -1]]                  shape (5, 5)

    When no process group is initialized, `data` is returned unchanged.

    :param data: tensor to merge (must have at least one dim when DDP is active)
    :param pad_index: fill value used for padding
    :return: merged tensor, or None if `data` is None
    :raises ValueError: if `data` is a 0-dim tensor while DDP is active
    """
    if data is None:
        return None
    assert torch.is_tensor(data), data
    if not use_ddp():
        return data

    if data.dim() < 1:
        raise ValueError(
            "ddp_merge() requires a tensor with at least one dimension, "
            f"got shape {tuple(data.size())}."
        )

    world_size = dist.get_world_size()

    # check tensor size in each device
    local_size = torch.tensor(data.size(), device=data.device)
    all_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
    dist.all_gather(all_sizes, local_size)

    # pad to the element-wise maximum shape over all processes
    max_size = torch.stack(all_sizes).max(dim=0).values.tolist()
    padded = torch.full(max_size, pad_index, device=data.device, dtype=data.dtype)
    padded[tuple(slice(0, s) for s in data.size())] = data

    # gather the equally-shaped tensors
    gathered = [torch.zeros_like(padded) for _ in range(world_size)]
    dist.all_gather(gathered, padded)

    # omit the padding-only rows every process appended along dim 0
    valid = []
    for chunk, size in zip(gathered, all_sizes):
        n_rows = int(size[0].item())
        # sanity check: rows beyond the local batch were never written by `data`
        assert torch.all(chunk[n_rows:] == pad_index).item()
        valid.append(chunk[:n_rows])
    merged = torch.cat(valid, dim=0)

    valid_batch_size = sum(int(size[0].item()) for size in all_sizes)
    assert merged.size(0) == valid_batch_size, (merged.size(), valid_batch_size)
    return merged
def ddp_reduce(data: Union[Tensor, int], device=None, dtype=None) -> Tensor:
    """
    Sum a value across all processes.

    A non-tensor `data` is first converted with
    ``torch.tensor(data, device=device, dtype=dtype)``, so `device` and `dtype`
    must be given in that case. When DDP is active, a 0-dim tensor is viewed
    as 1-dim and summed in place by ``dist.all_reduce``. A tensor passed in by
    the caller is therefore modified as well. Without DDP the (possibly
    converted) value is returned as is.

    :param data: tensor or Python number to reduce; None is passed through
    :param device: device for the conversion of a non-tensor `data`
    :param dtype: dtype for the conversion of a non-tensor `data`
    :return: reduced tensor, or None if `data` is None
    """
    if data is None:
        return None

    value = data
    if not torch.is_tensor(value):
        assert device is not None and dtype is not None
        value = torch.tensor(value, device=device, dtype=dtype)

    if not use_ddp():
        return value

    if value.dim() == 0:
        value = value.unsqueeze(0)
    dist.all_reduce(value, op=dist.ReduceOp.SUM)
    return value
class MultiProcessAdapter(logging.LoggerAdapter):
    """
    An adapter to assist with logging in multiprocess.

    taken from Huggingface's Accelerate logger

    Every logging call accepts the extra keyword ``master_only`` (default True).
    If True, the record is emitted only by the rank-0 process, or by the single
    process when DDP is not in use. If False, every process emits it.
    """

    def log(self, level, msg, *args, **kwargs):
        """
        Delegates logger call after checking if we should log.

        :param level: logging level of the record
        :param msg: message (format string)
        :param master_only: (keyword, popped before delegation) restrict output
            to rank 0 when True, log on every process when False
        """
        master_only = kwargs.pop("master_only", True)
        if master_only:
            rank = dist.get_rank() if use_ddp() else 0
            should_log = rank == 0
        else:
            # not restricted to the master process: every rank logs
            should_log = True

        if should_log and self.isEnabledFor(level):
            msg, kwargs = self.process(msg, kwargs)
            self.logger.log(level, msg, *args, **kwargs)
def get_logger(name: str = "", log_file: Optional[str] = None) -> logging.Logger:
    """
    Create a logger for logging the training/testing process.

    The first time a logger name is configured, it gets a DEBUG level, an INFO
    stream handler, a DEBUG file handler (if `log_file` is given) and
    ``propagate = False``. When `log_file` is given, every other
    already-created ``joeynmt.*`` logger with fewer than two handlers also gets
    a file handler for that file.

    :param name: logger name.
    :param log_file: path to file where log is stored as well
    :return: the logger wrapped in a MultiProcessAdapter
    """
    formatter = logging.Formatter(
        "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
    )

    def _add_filehandler(logger: logging.Logger, log_file: str) -> None:
        fh = logging.FileHandler(log_file, encoding="utf-8")
        fh.setLevel(level=logging.DEBUG)
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    def _add_streamhandler(logger: logging.Logger) -> None:
        sh = logging.StreamHandler()
        sh.setLevel(logging.INFO)
        sh.setFormatter(formatter)
        logger.addHandler(sh)

    # Configure the requested logger BEFORE the file-handler sweep below.
    # If the sweep ran first, a pre-existing handler-less ``joeynmt.*`` logger
    # would receive a file handler and this branch would be skipped. The logger
    # would then have no stream handler, no DEBUG level, and would still
    # propagate to root.
    current_logger = logging.getLogger(name)
    if len(current_logger.handlers) == 0:
        current_logger.setLevel(level=logging.DEBUG)
        _add_streamhandler(current_logger)
        if log_file is not None:
            _add_filehandler(current_logger, log_file)
        current_logger.propagate = False  # otherwise root logger prints things again

    # assign file handler whenever `log_file` arg is provided
    if log_file is not None:
        # iterate over a snapshot; the logger registry must not change under us
        for logger_name, logger in list(logging.root.manager.loggerDict.items()):
            # skip PlaceHolder entries (intermediate names never requested);
            # materializing them here would leave them half-configured
            if not isinstance(logger, logging.Logger):
                continue
            if logger_name.startswith("joeynmt.") and len(logger.handlers) < 2:
                _add_filehandler(logger, log_file)

    return MultiProcessAdapter(current_logger, {})
class DistributedSubsetSampler(DistributedSampler):
    """
    DistributedSampler with random subsampling.

    `drop_last` logic is simplified. If `len(dataset)` is not divisible by
    `world_size`, the leftovers are cut off when `drop_last` is True; otherwise
    an error is raised.

    .. warning::
        Token-based batch sampling is not supported in distributed learning.

    The sampler reads and rewrites ``dataset.indices``, and uses
    ``dataset.random_subset`` and ``dataset.reset_indices()``.

    Note: unlike torch's DistributedSampler, `num_samples` here is the total
    number of currently selected indices, not the per-replica count.

    :param data_source (Dataset): dataset to sample from
    :param num_replicas (int): ddp world size
    :param rank (int): ddp local rank
    :param shuffle (bool): whether to permute or not
    :param drop_last (bool): must be true!
    :param generator (Generator): Generator used in sampling.
    """

    def __init__(
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        drop_last: bool = True,
        generator: Optional[torch.Generator] = None,
    ):
        # pylint: disable=super-init-not-called
        # The parent __init__ is intentionally not called; every attribute this
        # class uses is set below.
        if num_replicas is None:
            if not use_ddp():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not use_ddp():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval"
                f" [0, {num_replicas - 1}]"
            )

        self.data_source = dataset  # alias
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.generator = generator

    @property
    def num_samples(self) -> int:
        """total size"""
        return len(self.data_source.indices)

    def __iter__(self):
        current = self.data_source.indices
        n_current = len(current)

        # positions into `current`, in sampling order
        if self.shuffle:
            order = torch.randperm(n_current, generator=self.generator).tolist()
        else:
            order = list(range(n_current))

        if n_current % self.num_replicas != 0 and not self.drop_last:
            # set `random_subset` with a divisible value or enable drop_last
            raise RuntimeError("`len(dataset)` must be divisible by `world_size`.")

        # remove tail of data to make it evenly divisible.
        total_samples = (n_current // self.num_replicas) * self.num_replicas
        order = order[:total_samples]
        indices = [current[pos] for pos in order]
        assert len(indices) % self.num_replicas == 0, (
            len(indices), n_current, self.num_replicas
        )

        # Persist the dropped leftovers, but keep the surviving indices in their
        # original order. Writing back the shuffled list would leave
        # data_source.indices permuted (e.g. no longer sorted after _subsample).
        self.data_source.indices = [current[pos] for pos in sorted(order)]

        # distribute samples: replica r takes every num_replicas-th index from r
        indices_per_replica = indices[self.rank::self.num_replicas]
        assert len(indices_per_replica) == total_samples // self.num_replicas
        return iter(indices_per_replica)

    def _subsample(self):
        """get random subset; indices are still sorted (no permutation!)"""
        orig_len = len(self.data_source)
        subset_len = self.data_source.random_subset
        if 0 < subset_len < orig_len:
            subset = torch.randperm(
                n=orig_len, generator=self.generator
            ).tolist()[:subset_len]
            self.data_source.indices = sorted(subset)
            assert len(subset) == self.num_samples

    def reset(self):
        """Restore the full index list via ``data_source.reset_indices()``."""
        self.data_source.reset_indices()

    def set_seed(self, seed: int) -> None:
        """set seed and resample"""
        self.generator.manual_seed(seed)
        self._subsample()
class RandomSubsetSampler(SequentialSampler):
    """
    Sample the currently selected subset of a dataset without replacement.

    The subset lives in ``data_source.indices``. With ``shuffle=True`` it is
    yielded in a random order drawn from `generator`; otherwise it is yielded
    as stored.

    :param data_source (Dataset): dataset to sample from
    :param shuffle (bool): whether to permute or not
    :param generator (Generator): Generator used in sampling.
    """

    def __init__(self, data_source: Dataset, shuffle: bool, generator: torch.Generator):
        super().__init__(data_source)
        self.shuffle = shuffle
        self.generator = generator

    @property
    def num_samples(self) -> int:
        """Number of indices currently selected in ``data_source``."""
        return len(self.data_source.indices)

    def __len__(self) -> int:
        return self.num_samples

    def __iter__(self):
        current = self.data_source.indices
        if not self.shuffle:
            return iter(current)
        order = torch.randperm(n=len(current), generator=self.generator).tolist()
        return iter([current[pos] for pos in order])

    def _subsample(self):
        """Draw a random subset of ``random_subset`` items; stored in ascending order."""
        full_len = len(self.data_source)
        target_len = self.data_source.random_subset
        if not (0 < target_len < full_len):
            return
        picked = torch.randperm(n=full_len, generator=self.generator).tolist()[:target_len]
        self.data_source.indices = sorted(picked)
        assert len(picked) == self.num_samples

    def reset(self):
        """Restore the full index list via ``data_source.reset_indices()``."""
        self.data_source.reset_indices()

    def set_seed(self, seed: int) -> None:
        """set seed and resample"""
        self.generator.manual_seed(seed)
        self._subsample()