import random
|
|
import math
|
|
import numpy as np
|
|
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
from torch.nn.utils.rnn import pad_sequence
|
|
|
|
from .common import pad_tensors, gen_seq_masks
|
|
|
|
############### Masked Language Modeling ###############
|
|
def random_word(tokens, vocab_range, mask):
    """
    Mask some random tokens for the Language Model task with the
    probabilities used in the original BERT paper.

    :param tokens: list of int, tokenized sentence.
    :param vocab_range: (low, high) half-open range for sampling a random
        replacement token id
    :param mask: int, id of the [MASK] token
    :return: (list of int, list of int), masked tokens and related labels for
        LM prediction; label is -1 at positions that are not predicted
        (ignored by the loss function later)
    """
    output_tokens, output_label = [], []

    for token in tokens:
        prob = random.random()
        # mask token with 15% probability
        if prob < 0.15:
            # rescale to [0, 1) so the sub-branches below split 80/10/10
            prob /= 0.15

            # 80%: replace with the mask token
            if prob < 0.8:
                output_tokens.append(mask)
            # 10%: replace with a random in-vocab token
            elif prob < 0.9:
                # randrange avoids materializing the whole vocab as a list
                output_tokens.append(random.randrange(*vocab_range))
            # rest 10%: keep the current token (model must still predict it)
            else:
                output_tokens.append(token)

            # this position is predicted later
            output_label.append(token)
        else:
            output_tokens.append(token)
            # no masking: ignored by the loss function later
            output_label.append(-1)

    # Guarantee at least one masked position; guard against empty input
    # (the unguarded version raised IndexError on an empty token list).
    if output_label and all(o == -1 for o in output_label):
        output_label[0] = tokens[0]
        output_tokens[0] = mask

    return output_tokens, output_label
|
|
|
|
class MlmDataset(Dataset):
    """Dataset for the Masked Language Modeling pretraining task."""

    def __init__(self, nav_db, tok):
        """
        :param nav_db: navigation database exposing `data` and `get_input`
        :param tok: huggingface tokenizer (bert-base-uncased ids assumed)
        """
        self.nav_db = nav_db
        self.tok = tok

        self.vocab_range = [1996, 29611] #TODO: manually checked in bert-base-uncased
        self.cls_token_id = self.tok.cls_token_id # 101
        self.sep_token_id = self.tok.sep_token_id # 102
        self.mask_token_id = self.tok.mask_token_id # 103
        self.pad_token_id = self.tok.pad_token_id # 0

    def __len__(self):
        # Measure len(nav_db.data) for consistency with the other pretraining
        # datasets in this file (Mrc/Sap/OGDataset all do the same).
        return len(self.nav_db.data)

    def __getitem__(self, idx):
        # MLM always uses the positive (ground-truth) trajectory endpoint.
        inputs = self.nav_db.get_input(idx, 'pos')

        output = {}

        # Randomly mask instruction tokens; labels are -1 at unmasked spots.
        txt_ids, txt_labels = random_word(inputs['instr_encoding'],
            self.vocab_range, self.mask_token_id)
        output['txt_ids'] = torch.LongTensor(txt_ids)
        output['txt_labels'] = torch.LongTensor(txt_labels)

        # Per-step trajectory features (lists of variable-length tensors).
        output['traj_view_img_fts'] = [torch.from_numpy(x) for x in inputs['traj_view_img_fts']]
        if 'traj_obj_img_fts' in inputs:
            output['traj_obj_img_fts'] = [torch.from_numpy(x) for x in inputs['traj_obj_img_fts']]
        output['traj_loc_fts'] = [torch.from_numpy(x) for x in inputs['traj_loc_fts']]
        output['traj_nav_types'] = [torch.LongTensor(x) for x in inputs['traj_nav_types']]
        output['traj_cand_vpids'] = inputs['traj_cand_vpids']
        output['traj_vpids'] = inputs['traj_vpids']

        # Global (topological) map features.
        output['gmap_vpids'] = inputs['gmap_vpids']
        output['gmap_step_ids'] = torch.LongTensor(inputs['gmap_step_ids'])
        output['gmap_visited_masks'] = torch.BoolTensor(inputs['gmap_visited_masks'])
        output['gmap_pos_fts'] = torch.from_numpy(inputs['gmap_pos_fts'])
        output['gmap_pair_dists'] = torch.from_numpy(inputs['gmap_pair_dists'])

        # Current-viewpoint features.
        output['vp_pos_fts'] = torch.from_numpy(inputs['vp_pos_fts'])
        output['vp_angles'] = inputs['vp_angles']
        return output
|
|
|
|
def mlm_collate(inputs):
    """
    Collate a list of MlmDataset items into padded batch tensors.

    :param inputs: list of dicts produced by MlmDataset.__getitem__
    :return: dict of batched/padded tensors plus per-item length tensors
    """
    # Transpose list-of-dicts into dict-of-lists (all items share keys).
    batch = {
        k: [x[k] for x in inputs] for k in inputs[0].keys()
    }
    # text batches
    batch['txt_lens'] = torch.LongTensor([len(x) for x in batch['txt_ids']])
    batch['txt_ids'] = pad_sequence(batch['txt_ids'], batch_first=True, padding_value=0)
    # pad labels with -1 so padded positions are ignored by the MLM loss
    batch['txt_labels'] = pad_sequence(batch['txt_labels'], batch_first=True, padding_value=-1)

    # trajectory batches: traj_cand_vpids, traj_vpids
    # number of steps per trajectory (lengths must be recorded BEFORE the
    # same keys are flattened/padded below)
    batch['traj_step_lens'] = [len(x) for x in batch['traj_view_img_fts']]
    # views per (item, step), flattened across the batch
    batch['traj_vp_view_lens'] = torch.LongTensor(
        sum([[len(y) for y in x] for x in batch['traj_view_img_fts']], [])
    )
    # flatten all steps of all items into one list, then zero-pad
    batch['traj_view_img_fts'] = pad_tensors(sum(batch['traj_view_img_fts'], []))
    if 'traj_obj_img_fts' in batch:
        batch['traj_vp_obj_lens'] = torch.LongTensor(
            sum([[len(y) for y in x] for x in batch['traj_obj_img_fts']], [])
        )
        batch['traj_obj_img_fts'] = pad_tensors(sum(batch['traj_obj_img_fts'], []))
    batch['traj_loc_fts'] = pad_tensors(sum(batch['traj_loc_fts'], []))
    batch['traj_nav_types'] = pad_sequence(sum(batch['traj_nav_types'], []), batch_first=True, padding_value=0)

    # gmap batches: gmap_vpids
    batch['gmap_lens'] = torch.LongTensor([len(x) for x in batch['gmap_step_ids']]) # included [stop]
    batch['gmap_step_ids'] = pad_sequence(batch['gmap_step_ids'], batch_first=True, padding_value=0)
    batch['gmap_visited_masks'] = pad_sequence(batch['gmap_visited_masks'], batch_first=True, padding_value=0)
    batch['gmap_pos_fts'] = pad_tensors(batch['gmap_pos_fts'])
    # pack the per-item square pairwise-distance matrices into one
    # zero-padded (batch, max_len, max_len) tensor
    max_gmap_len = max(batch['gmap_lens'])
    batch_size = len(batch['gmap_lens'])
    gmap_pair_dists = torch.zeros(batch_size, max_gmap_len, max_gmap_len).float()
    for i in range(batch_size):
        gmap_pair_dists[i, :batch['gmap_lens'][i], :batch['gmap_lens'][i]] = batch['gmap_pair_dists'][i]
    batch['gmap_pair_dists'] = gmap_pair_dists

    # vp batches: vp_angles
    # NOTE(review): len(x[-1]) assumes the last axis-0 slice of vp_pos_fts
    # has one entry per viewpoint — verify against nav_db.get_input's shape
    batch['vp_lens'] = torch.LongTensor([len(x[-1]) for x in batch['vp_pos_fts']]) # included [stop]
    batch['vp_pos_fts'] = pad_tensors(batch['vp_pos_fts'])

    return batch
|
|
|
|
|
|
############### Masked Region Modeling ###############
|
|
def _get_img_mask(mask_prob, num_images):
|
|
img_mask = [np.random.rand() < mask_prob for _ in range(num_images)]
|
|
if not any(img_mask):
|
|
# at least mask 1
|
|
img_mask[np.random.randint(num_images)] = True
|
|
img_mask = torch.tensor(img_mask)
|
|
return img_mask
|
|
|
|
def _mask_img_feat(img_feat, img_masks):
|
|
img_masks_ext = img_masks.unsqueeze(-1).expand_as(img_feat)
|
|
img_feat_masked = img_feat.data.masked_fill(img_masks_ext, 0)
|
|
return img_feat_masked
|
|
|
|
def _get_targets(img_soft_label, img_masks):
|
|
soft_label_dim = img_soft_label.size(-1)
|
|
img_masks_ext_for_label = img_masks.unsqueeze(-1).expand_as(img_soft_label)
|
|
label_targets = img_soft_label[img_masks_ext_for_label].contiguous().view(-1, soft_label_dim)
|
|
return label_targets
|
|
|
|
class MrcDataset(Dataset):
    """Dataset for the Masked Region Classification pretraining task.

    Masks random view (and object) image features at the final viewpoint of
    the trajectory and keeps their soft labels as prediction targets.
    """

    def __init__(self, nav_db, tok, mask_prob, end_vp_pos_ratio=1):
        """
        :param nav_db: navigation database exposing `data` and `get_input`
        :param tok: huggingface tokenizer (bert-base-uncased ids assumed)
        :param mask_prob: per-region masking probability
        :param end_vp_pos_ratio: fraction of samples whose endpoint is the
            positive viewpoint; the rest use a negative on the gt path
        """
        self.nav_db = nav_db
        self.tok = tok
        self.mask_prob = mask_prob

        self.cls_token_id = self.tok.cls_token_id # 101
        self.sep_token_id = self.tok.sep_token_id # 102
        self.pad_token_id = self.tok.pad_token_id # 0

        self.end_vp_pos_ratio = end_vp_pos_ratio

    def __len__(self):
        return len(self.nav_db.data)

    def __getitem__(self, idx):
        # Sample the endpoint type for this example.
        r = np.random.rand()
        if r < self.end_vp_pos_ratio:
            end_vp_type = 'pos'
        else:
            end_vp_type = 'neg_in_gt_path'
        inputs = self.nav_db.get_input(idx, end_vp_type, return_img_probs=True)

        output = {}

        output['txt_ids'] = torch.LongTensor(inputs['instr_encoding'])

        output['traj_view_img_fts'] = [torch.from_numpy(x) for x in inputs['traj_view_img_fts']]

        # mask image
        # Only the LAST step's view features are masked; their soft labels
        # become the MRC targets.
        view_mrc_masks = _get_img_mask(self.mask_prob, len(output['traj_view_img_fts'][-1]))
        output['traj_view_img_fts'][-1] = _mask_img_feat(output['traj_view_img_fts'][-1], view_mrc_masks)
        output['vp_view_probs'] = torch.from_numpy(inputs['vp_view_probs']) # no [stop]
        output['vp_view_mrc_masks'] = view_mrc_masks
        output['traj_loc_fts'] = [torch.from_numpy(x) for x in inputs['traj_loc_fts']]
        output['traj_nav_types'] = [torch.LongTensor(x) for x in inputs['traj_nav_types']]
        output['traj_cand_vpids'] = inputs['traj_cand_vpids']
        output['traj_vpids'] = inputs['traj_vpids']

        # Global (topological) map features.
        output['gmap_vpids'] = inputs['gmap_vpids']
        output['gmap_step_ids'] = torch.LongTensor(inputs['gmap_step_ids'])
        output['gmap_visited_masks'] = torch.BoolTensor(inputs['gmap_visited_masks'])
        output['gmap_pos_fts'] = torch.from_numpy(inputs['gmap_pos_fts'])
        output['gmap_pair_dists'] = torch.from_numpy(inputs['gmap_pair_dists'])

        # Current-viewpoint features.
        output['vp_pos_fts'] = torch.from_numpy(inputs['vp_pos_fts'])
        output['vp_angles'] = inputs['vp_angles']

        # Object features (only present for datasets with object annotations).
        if 'traj_obj_img_fts' in inputs:
            output['traj_obj_img_fts'] = [torch.from_numpy(x) for x in inputs['traj_obj_img_fts']]
            if len(output['traj_obj_img_fts'][-1]) > 0:
                obj_mrc_masks = _get_img_mask(self.mask_prob, len(output['traj_obj_img_fts'][-1]))
                output['traj_obj_img_fts'][-1] = _mask_img_feat(output['traj_obj_img_fts'][-1], obj_mrc_masks)
            else:
                # no objects at the final viewpoint: empty mask
                obj_mrc_masks = torch.zeros(0, ).bool()
            output['vp_obj_probs'] = torch.from_numpy(inputs['vp_obj_probs'])
            output['vp_obj_mrc_masks'] = obj_mrc_masks

        return output
|
|
|
|
def mrc_collate(inputs):
    """
    Collate a list of MrcDataset items into padded batch tensors.

    :param inputs: list of dicts produced by MrcDataset.__getitem__
    :return: dict of batched/padded tensors plus per-item length tensors
    """
    # Transpose list-of-dicts into dict-of-lists (all items share keys).
    batch = {
        k: [x[k] for x in inputs] for k in inputs[0].keys()
    }
    # text batches
    batch['txt_lens'] = torch.LongTensor([len(x) for x in batch['txt_ids']])
    batch['txt_ids'] = pad_sequence(batch['txt_ids'], batch_first=True, padding_value=0)

    # trajectory batches: traj_cand_vpids, traj_vpids
    # lengths must be recorded BEFORE the same keys are flattened/padded
    batch['traj_step_lens'] = [len(x) for x in batch['traj_view_img_fts']]
    batch['traj_vp_view_lens'] = torch.LongTensor(
        sum([[len(y) for y in x] for x in batch['traj_view_img_fts']], [])
    )
    # flatten all steps of all items into one list, then zero-pad
    batch['traj_view_img_fts'] = pad_tensors(sum(batch['traj_view_img_fts'], []))
    batch['traj_loc_fts'] = pad_tensors(sum(batch['traj_loc_fts'], []))
    batch['traj_nav_types'] = pad_sequence(sum(batch['traj_nav_types'], []), batch_first=True, padding_value=0)

    # gmap batches: gmap_vpids
    batch['gmap_lens'] = torch.LongTensor([len(x) for x in batch['gmap_step_ids']]) # included [stop]
    batch['gmap_step_ids'] = pad_sequence(batch['gmap_step_ids'], batch_first=True, padding_value=0)
    batch['gmap_visited_masks'] = pad_sequence(batch['gmap_visited_masks'], batch_first=True, padding_value=0)
    batch['gmap_pos_fts'] = pad_tensors(batch['gmap_pos_fts'])
    # pack the per-item square pairwise-distance matrices into one
    # zero-padded (batch, max_len, max_len) tensor
    max_gmap_len = max(batch['gmap_lens'])
    batch_size = len(batch['gmap_lens'])
    gmap_pair_dists = torch.zeros(batch_size, max_gmap_len, max_gmap_len).float()
    for i in range(batch_size):
        gmap_pair_dists[i, :batch['gmap_lens'][i], :batch['gmap_lens'][i]] = batch['gmap_pair_dists'][i]
    batch['gmap_pair_dists'] = gmap_pair_dists

    # vp batches: vp_angles
    # NOTE(review): len(x[-1]) assumes the last axis-0 slice of vp_pos_fts
    # has one entry per viewpoint — verify against nav_db.get_input's shape
    batch['vp_lens'] = torch.LongTensor([len(x[-1]) for x in batch['vp_pos_fts']]) # included [stop]
    batch['vp_pos_fts'] = pad_tensors(batch['vp_pos_fts'])

    # vp labels
    batch['vp_view_mrc_masks'] = pad_sequence(batch['vp_view_mrc_masks'], batch_first=True, padding_value=0)
    batch['vp_view_probs'] = pad_tensors(batch['vp_view_probs'])

    # object branch (only when the dataset provides object annotations)
    if 'traj_obj_img_fts' in batch:
        batch['traj_vp_obj_lens'] = torch.LongTensor(
            sum([[len(y) for y in x] for x in batch['traj_obj_img_fts']], [])
        )
        batch['traj_obj_img_fts'] = pad_tensors(sum(batch['traj_obj_img_fts'], []))
        batch['vp_obj_mrc_masks'] = pad_sequence(batch['vp_obj_mrc_masks'], batch_first=True, padding_value=0)
        batch['vp_obj_probs'] = pad_tensors(batch['vp_obj_probs'])

    return batch
|
|
|
|
|
|
############### Single-step Action Prediction ###############
|
|
class SapDataset(Dataset):
    """Dataset for the single-step action prediction (SAP) pretraining task."""

    def __init__(self, nav_db, tok, end_vp_pos_ratio=0.2):
        """
        :param nav_db: navigation database exposing `data` and `get_input`
        :param tok: huggingface tokenizer (bert-base-uncased ids assumed)
        :param end_vp_pos_ratio: fraction of samples with a positive endpoint
        """
        self.nav_db = nav_db
        self.tok = tok

        self.cls_token_id = self.tok.cls_token_id # 101
        self.sep_token_id = self.tok.sep_token_id # 102
        self.pad_token_id = self.tok.pad_token_id # 0

        self.end_vp_pos_ratio = end_vp_pos_ratio

    def __len__(self):
        return len(self.nav_db.data)

    def __getitem__(self, idx):
        # Sample the endpoint type: positive, negative on the ground-truth
        # path, or some other negative viewpoint.
        draw = np.random.rand()
        if draw < self.end_vp_pos_ratio:
            vp_type = 'pos'
        elif draw < 0.6:
            vp_type = 'neg_in_gt_path'
        else:
            vp_type = 'neg_others'
        inputs = self.nav_db.get_input(idx, vp_type, return_act_label=True)

        output = {'txt_ids': torch.LongTensor(inputs['instr_encoding'])}

        # Per-step trajectory features (lists of variable-length tensors).
        output['traj_view_img_fts'] = [torch.from_numpy(ft) for ft in inputs['traj_view_img_fts']]
        if 'traj_obj_img_fts' in inputs:
            output['traj_obj_img_fts'] = [torch.from_numpy(ft) for ft in inputs['traj_obj_img_fts']]
        output['traj_loc_fts'] = [torch.from_numpy(ft) for ft in inputs['traj_loc_fts']]
        output['traj_nav_types'] = [torch.LongTensor(t) for t in inputs['traj_nav_types']]
        output['traj_cand_vpids'] = inputs['traj_cand_vpids']
        output['traj_vpids'] = inputs['traj_vpids']

        # Global (topological) map features.
        output['gmap_vpids'] = inputs['gmap_vpids']
        output['gmap_step_ids'] = torch.LongTensor(inputs['gmap_step_ids'])
        output['gmap_visited_masks'] = torch.BoolTensor(inputs['gmap_visited_masks'])
        output['gmap_pos_fts'] = torch.from_numpy(inputs['gmap_pos_fts'])
        output['gmap_pair_dists'] = torch.from_numpy(inputs['gmap_pair_dists'])

        # Current-viewpoint features.
        output['vp_pos_fts'] = torch.from_numpy(inputs['vp_pos_fts'])
        output['vp_angles'] = inputs['vp_angles']

        # Action supervision: local (viewpoint) and global (map) labels.
        output['local_act_labels'] = inputs['local_act_labels']
        output['global_act_labels'] = inputs['global_act_labels']
        return output
|
|
|
|
def sap_collate(inputs):
    """
    Collate a list of SapDataset items into padded batch tensors.

    :param inputs: list of dicts produced by SapDataset.__getitem__
    :return: dict of batched/padded tensors plus per-item length tensors
    """
    # Transpose list-of-dicts into dict-of-lists (all items share keys).
    batch = {
        k: [x[k] for x in inputs] for k in inputs[0].keys()
    }
    # text batches
    batch['txt_lens'] = torch.LongTensor([len(x) for x in batch['txt_ids']])
    batch['txt_ids'] = pad_sequence(batch['txt_ids'], batch_first=True, padding_value=0)

    # trajectory batches: traj_cand_vpids, traj_vpids
    # lengths must be recorded BEFORE the same keys are flattened/padded
    batch['traj_step_lens'] = [len(x) for x in batch['traj_view_img_fts']]
    batch['traj_vp_view_lens'] = torch.LongTensor(
        sum([[len(y) for y in x] for x in batch['traj_view_img_fts']], [])
    )
    # flatten all steps of all items into one list, then zero-pad
    batch['traj_view_img_fts'] = pad_tensors(sum(batch['traj_view_img_fts'], []))
    if 'traj_obj_img_fts' in batch:
        batch['traj_vp_obj_lens'] = torch.LongTensor(
            sum([[len(y) for y in x] for x in batch['traj_obj_img_fts']], [])
        )
        batch['traj_obj_img_fts'] = pad_tensors(sum(batch['traj_obj_img_fts'], []))
    batch['traj_loc_fts'] = pad_tensors(sum(batch['traj_loc_fts'], []))
    batch['traj_nav_types'] = pad_sequence(sum(batch['traj_nav_types'], []), batch_first=True, padding_value=0)

    # gmap batches: gmap_vpids
    batch['gmap_lens'] = torch.LongTensor([len(x) for x in batch['gmap_step_ids']]) # included [stop]
    batch['gmap_step_ids'] = pad_sequence(batch['gmap_step_ids'], batch_first=True, padding_value=0)
    batch['gmap_visited_masks'] = pad_sequence(batch['gmap_visited_masks'], batch_first=True, padding_value=0)
    batch['gmap_pos_fts'] = pad_tensors(batch['gmap_pos_fts'])
    # pack the per-item square pairwise-distance matrices into one
    # zero-padded (batch, max_len, max_len) tensor
    max_gmap_len = max(batch['gmap_lens'])
    batch_size = len(batch['gmap_lens'])
    gmap_pair_dists = torch.zeros(batch_size, max_gmap_len, max_gmap_len).float()
    for i in range(batch_size):
        gmap_pair_dists[i, :batch['gmap_lens'][i], :batch['gmap_lens'][i]] = batch['gmap_pair_dists'][i]
    batch['gmap_pair_dists'] = gmap_pair_dists

    # vp batches: vp_angles
    # NOTE(review): len(x[-1]) assumes the last axis-0 slice of vp_pos_fts
    # has one entry per viewpoint — verify against nav_db.get_input's shape
    batch['vp_lens'] = torch.LongTensor([len(x[-1]) for x in batch['vp_pos_fts']]) # included [stop]
    batch['vp_pos_fts'] = pad_tensors(batch['vp_pos_fts'])

    # action labels
    batch['local_act_labels'] = torch.LongTensor(batch['local_act_labels'])
    batch['global_act_labels'] = torch.LongTensor(batch['global_act_labels'])
    return batch
|
|
|
|
|
|
############### Object Grounding ###############
|
|
class OGDataset(Dataset):
    """Dataset for the object grounding pretraining task."""

    def __init__(self, nav_db, tok):
        """
        :param nav_db: navigation database exposing `data` and `get_input`
        :param tok: huggingface tokenizer
        """
        self.nav_db = nav_db
        self.tok = tok

    def __len__(self):
        return len(self.nav_db.data)

    def __getitem__(self, idx):
        # Object grounding always uses the positive (ground-truth) endpoint.
        inputs = self.nav_db.get_input(idx, 'pos', return_obj_label=True)

        output = {'txt_ids': torch.LongTensor(inputs['instr_encoding'])}

        # Per-step trajectory features (lists of variable-length tensors).
        output['traj_view_img_fts'] = [torch.from_numpy(ft) for ft in inputs['traj_view_img_fts']]
        output['traj_obj_img_fts'] = [torch.from_numpy(ft) for ft in inputs['traj_obj_img_fts']]
        output['traj_loc_fts'] = [torch.from_numpy(ft) for ft in inputs['traj_loc_fts']]
        output['traj_nav_types'] = [torch.LongTensor(t) for t in inputs['traj_nav_types']]
        output['traj_cand_vpids'] = inputs['traj_cand_vpids']
        output['traj_vpids'] = inputs['traj_vpids']

        # Global (topological) map features.
        output['gmap_vpids'] = inputs['gmap_vpids']
        output['gmap_step_ids'] = torch.LongTensor(inputs['gmap_step_ids'])
        output['gmap_visited_masks'] = torch.BoolTensor(inputs['gmap_visited_masks'])
        output['gmap_pos_fts'] = torch.from_numpy(inputs['gmap_pos_fts'])
        output['gmap_pair_dists'] = torch.from_numpy(inputs['gmap_pair_dists'])

        # Current-viewpoint features.
        output['vp_pos_fts'] = torch.from_numpy(inputs['vp_pos_fts'])
        output['vp_angles'] = inputs['vp_angles']

        # Supervision: index of the grounded object.
        output['obj_labels'] = inputs['obj_labels']
        return output
|
|
|
|
def og_collate(inputs):
    """
    Collate a list of OGDataset items into padded batch tensors.

    :param inputs: list of dicts produced by OGDataset.__getitem__
    :return: dict of batched/padded tensors plus per-item length tensors
    """
    # Transpose list-of-dicts into dict-of-lists (all items share keys).
    batch = {
        k: [x[k] for x in inputs] for k in inputs[0].keys()
    }
    # text batches
    batch['txt_lens'] = torch.LongTensor([len(x) for x in batch['txt_ids']])
    batch['txt_ids'] = pad_sequence(batch['txt_ids'], batch_first=True, padding_value=0)

    # trajectory batches: traj_cand_vpids, traj_vpids
    # lengths must be recorded BEFORE the same keys are flattened/padded
    batch['traj_step_lens'] = [len(x) for x in batch['traj_view_img_fts']]
    batch['traj_vp_view_lens'] = torch.LongTensor(
        sum([[len(y) for y in x] for x in batch['traj_view_img_fts']], [])
    )
    batch['traj_vp_obj_lens'] = torch.LongTensor(
        sum([[len(y) for y in x] for x in batch['traj_obj_img_fts']], [])
    )
    # flatten all steps of all items into one list, then zero-pad
    batch['traj_view_img_fts'] = pad_tensors(sum(batch['traj_view_img_fts'], []))
    batch['traj_obj_img_fts'] = pad_tensors(sum(batch['traj_obj_img_fts'], []))
    batch['traj_loc_fts'] = pad_tensors(sum(batch['traj_loc_fts'], []))
    batch['traj_nav_types'] = pad_sequence(sum(batch['traj_nav_types'], []), batch_first=True, padding_value=0)

    # gmap batches: gmap_vpids
    batch['gmap_lens'] = torch.LongTensor([len(x) for x in batch['gmap_step_ids']]) # included [stop]
    batch['gmap_step_ids'] = pad_sequence(batch['gmap_step_ids'], batch_first=True, padding_value=0)
    batch['gmap_visited_masks'] = pad_sequence(batch['gmap_visited_masks'], batch_first=True, padding_value=0)
    batch['gmap_pos_fts'] = pad_tensors(batch['gmap_pos_fts'])
    # pack the per-item square pairwise-distance matrices into one
    # zero-padded (batch, max_len, max_len) tensor
    max_gmap_len = max(batch['gmap_lens'])
    batch_size = len(batch['gmap_lens'])
    gmap_pair_dists = torch.zeros(batch_size, max_gmap_len, max_gmap_len).float()
    for i in range(batch_size):
        gmap_pair_dists[i, :batch['gmap_lens'][i], :batch['gmap_lens'][i]] = batch['gmap_pair_dists'][i]
    batch['gmap_pair_dists'] = gmap_pair_dists

    # vp batches: vp_angles
    # NOTE(review): len(x[-1]) assumes the last axis-0 slice of vp_pos_fts
    # has one entry per viewpoint — verify against nav_db.get_input's shape
    batch['vp_lens'] = torch.LongTensor([len(x[-1]) for x in batch['vp_pos_fts']]) # included [stop]
    batch['vp_pos_fts'] = pad_tensors(batch['vp_pos_fts'])

    # vp labels
    batch['obj_labels'] = torch.LongTensor(batch['obj_labels'])
    return batch
|