import numpy as np
import torch


def pad_tensors(tensors, lens=None, pad=0):
    """Pad a list of B tensors of shape [T_i, ...] into one tensor of shape [B, max_T, ...]."""
    if lens is None:
        lens = [t.size(0) for t in tensors]
    max_len = max(lens)
    bs = len(tensors)
    hid = list(tensors[0].size()[1:])
    size = [bs, max_len] + hid
    dtype = tensors[0].dtype
    device = tensors[0].device
    output = torch.zeros(*size, dtype=dtype, device=device)
    if pad:
        # non-zero padding value requested; positions beyond each length keep this value
        output.fill_(pad)
    for i, (t, l) in enumerate(zip(tensors, lens)):
        output[i, :l, ...] = t
    return output


def gen_seq_masks(seq_lens, max_len=None):
    """Build boolean masks of shape [B, max_len] that are True at valid (non-padded) positions."""
    if max_len is None:
        max_len = max(seq_lens)

    if isinstance(seq_lens, torch.Tensor):
        # torch path: compare a broadcast position index against each sequence length
        device = seq_lens.device
        masks = torch.arange(max_len, device=device).repeat(len(seq_lens), 1) < seq_lens.unsqueeze(1)
        return masks

    # numpy path
    if max_len == 0:
        return np.zeros((len(seq_lens), 0), dtype=bool)
    seq_lens = np.array(seq_lens)
    batch_size = len(seq_lens)
    masks = np.arange(max_len).reshape(-1, max_len).repeat(batch_size, 0)
    masks = masks < seq_lens.reshape(-1, 1)
    return masks
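

# A minimal usage sketch, not part of the original module: it assumes these helpers are
# used to batch variable-length feature sequences (e.g. per-step features of different
# lengths) into one padded tensor plus a validity mask. Names below are illustrative only.
if __name__ == '__main__':
    feats = [torch.randn(3, 4), torch.randn(5, 4), torch.randn(2, 4)]  # B=3 sequences, hidden size 4
    lens = [t.size(0) for t in feats]

    padded = pad_tensors(feats, lens)   # shape [3, 5, 4], zero-padded past each sequence length
    masks = gen_seq_masks(lens)         # shape [3, 5], row i has lens[i] leading True values

    print(padded.shape)                 # torch.Size([3, 5, 4])
    print(masks)                        # boolean numpy array marking valid positions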