import numpy as np
import torch
def pad_tensors(tensors, lens=None, pad=0):
    """Pad B tensors of shape [T_i, ...] into a single [B, max_T, ...] tensor,
    filling positions beyond each length with `pad`."""
    if lens is None:
        lens = [t.size(0) for t in tensors]
    max_len = max(lens)
    bs = len(tensors)
    hid = list(tensors[0].size()[1:])
    size = [bs, max_len] + hid

    dtype = tensors[0].dtype
    device = tensors[0].device
    output = torch.zeros(*size, dtype=dtype, device=device)
    if pad:
        output.fill_(pad)
    # Copy each tensor into the first `l` timesteps of its row.
    for i, (t, l) in enumerate(zip(tensors, lens)):
        output[i, :l, ...] = t
    return output
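# A minimal usage sketch (illustrative, not part of the original file):
#
#   feats = [torch.randn(3, 5), torch.randn(2, 5)]
#   batch = pad_tensors(feats)          # shape [2, 3, 5]; second row zero-padded
#   batch = pad_tensors(feats, pad=-1)  # padded positions hold -1 instead of 0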


def gen_seq_masks(seq_lens, max_len=None):
    """Build a [B, max_len] boolean mask that is True at valid positions."""
    if max_len is None:
        max_len = max(seq_lens)

    # Torch input -> torch.BoolTensor on the same device as seq_lens.
    if isinstance(seq_lens, torch.Tensor):
        device = seq_lens.device
        masks = torch.arange(max_len).to(device).repeat(len(seq_lens), 1) < seq_lens.unsqueeze(1)
        return masks

    if max_len == 0:
        # Plain `bool` here: `np.bool` was deprecated and removed from NumPy.
        return np.zeros((len(seq_lens), 0), dtype=bool)

    # List/array input -> numpy boolean array via broadcast comparison.
    seq_lens = np.array(seq_lens)
    batch_size = len(seq_lens)
    masks = np.arange(max_len).reshape(-1, max_len).repeat(batch_size, 0)
    masks = masks < seq_lens.reshape(-1, 1)
    return masks
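

# Hedged smoke test (an illustrative addition, not part of the original file):
# exercises both helpers on a toy batch with lengths [3, 1].
if __name__ == "__main__":
    feats = [torch.randn(3, 5), torch.randn(1, 5)]
    batch = pad_tensors(feats)
    print(batch.shape)  # torch.Size([2, 3, 5])

    print(gen_seq_masks([3, 1]))
    # [[ True  True  True]
    #  [ True False False]]
    print(gen_seq_masks(torch.tensor([3, 1])))
    # tensor([[ True,  True,  True],
    #         [ True, False, False]])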