adversarial_VLNDUET/pretrain_src/data/tasks.py
Shizhe Chen 747cf0587b init
2021-11-24 13:29:08 +01:00

438 lines
19 KiB
Python

import random
import math
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from .common import pad_tensors, gen_seq_masks
############### Masked Language Modeling ###############
def random_word(tokens, vocab_range, mask):
"""
Masking some random tokens for Language Model task with probabilities as in
the original BERT paper.
:param tokens: list of int, tokenized sentence.
:param vocab_range: for choosing a random word
:return: (list of int, list of int), masked tokens and related labels for
LM prediction
"""
output_tokens, output_label = [], []
for i, token in enumerate(tokens):
prob = random.random()
# mask token with 15% probability
if prob < 0.15:
prob /= 0.15
# 80% randomly change token to mask token
if prob < 0.8:
output_tokens.append(mask)
# 10% randomly change token to random token
elif prob < 0.9:
output_tokens.append(random.choice(list(range(*vocab_range))))
# -> rest 10% randomly keep current token
else:
output_tokens.append(token)
# append current token to output (we will predict these later)
output_label.append(token)
else:
output_tokens.append(token)
# no masking token (will be ignored by loss function later)
output_label.append(-1)
if all(o == -1 for o in output_label):
# at least mask 1
output_label[0] = tokens[0]
output_tokens[0] = mask
return output_tokens, output_label
class MlmDataset(Dataset):
def __init__(self, nav_db, tok):
self.nav_db = nav_db
self.tok = tok
self.vocab_range = [1996, 29611] #TODO: manually checked in bert-base-uncased
self.cls_token_id = self.tok.cls_token_id # 101
self.sep_token_id = self.tok.sep_token_id # 102
self.mask_token_id = self.tok.mask_token_id # 103
self.pad_token_id = self.tok.pad_token_id # 0
def __len__(self):
return len(self.nav_db)
def __getitem__(self, idx):
inputs = self.nav_db.get_input(idx, 'pos')
output = {}
txt_ids, txt_labels = random_word(inputs['instr_encoding'],
self.vocab_range, self.mask_token_id)
output['txt_ids'] = torch.LongTensor(txt_ids)
output['txt_labels'] = torch.LongTensor(txt_labels)
output['traj_view_img_fts'] = [torch.from_numpy(x) for x in inputs['traj_view_img_fts']]
if 'traj_obj_img_fts' in inputs:
output['traj_obj_img_fts'] = [torch.from_numpy(x) for x in inputs['traj_obj_img_fts']]
output['traj_loc_fts'] = [torch.from_numpy(x) for x in inputs['traj_loc_fts']]
output['traj_nav_types'] = [torch.LongTensor(x) for x in inputs['traj_nav_types']]
output['traj_cand_vpids'] = inputs['traj_cand_vpids']
output['traj_vpids'] = inputs['traj_vpids']
output['gmap_vpids'] = inputs['gmap_vpids']
output['gmap_step_ids'] = torch.LongTensor(inputs['gmap_step_ids'])
output['gmap_visited_masks'] = torch.BoolTensor(inputs['gmap_visited_masks'])
output['gmap_pos_fts'] = torch.from_numpy(inputs['gmap_pos_fts'])
output['gmap_pair_dists'] = torch.from_numpy(inputs['gmap_pair_dists'])
output['vp_pos_fts'] = torch.from_numpy(inputs['vp_pos_fts'])
output['vp_angles'] = inputs['vp_angles']
return output
def mlm_collate(inputs):
batch = {
k: [x[k] for x in inputs] for k in inputs[0].keys()
}
# text batches
batch['txt_lens'] = torch.LongTensor([len(x) for x in batch['txt_ids']])
batch['txt_ids'] = pad_sequence(batch['txt_ids'], batch_first=True, padding_value=0)
batch['txt_labels'] = pad_sequence(batch['txt_labels'], batch_first=True, padding_value=-1)
# trajectory batches: traj_cand_vpids, traj_vpids
batch['traj_step_lens'] = [len(x) for x in batch['traj_view_img_fts']]
batch['traj_vp_view_lens'] = torch.LongTensor(
sum([[len(y) for y in x] for x in batch['traj_view_img_fts']], [])
)
batch['traj_view_img_fts'] = pad_tensors(sum(batch['traj_view_img_fts'], []))
if 'traj_obj_img_fts' in batch:
batch['traj_vp_obj_lens'] = torch.LongTensor(
sum([[len(y) for y in x] for x in batch['traj_obj_img_fts']], [])
)
batch['traj_obj_img_fts'] = pad_tensors(sum(batch['traj_obj_img_fts'], []))
batch['traj_loc_fts'] = pad_tensors(sum(batch['traj_loc_fts'], []))
batch['traj_nav_types'] = pad_sequence(sum(batch['traj_nav_types'], []), batch_first=True, padding_value=0)
# gmap batches: gmap_vpids
batch['gmap_lens'] = torch.LongTensor([len(x) for x in batch['gmap_step_ids']]) # included [stop]
batch['gmap_step_ids'] = pad_sequence(batch['gmap_step_ids'], batch_first=True, padding_value=0)
batch['gmap_visited_masks'] = pad_sequence(batch['gmap_visited_masks'], batch_first=True, padding_value=0)
batch['gmap_pos_fts'] = pad_tensors(batch['gmap_pos_fts'])
max_gmap_len = max(batch['gmap_lens'])
batch_size = len(batch['gmap_lens'])
gmap_pair_dists = torch.zeros(batch_size, max_gmap_len, max_gmap_len).float()
for i in range(batch_size):
gmap_pair_dists[i, :batch['gmap_lens'][i], :batch['gmap_lens'][i]] = batch['gmap_pair_dists'][i]
batch['gmap_pair_dists'] = gmap_pair_dists
# vp batches: vp_angles
batch['vp_lens'] = torch.LongTensor([len(x[-1]) for x in batch['vp_pos_fts']]) # included [stop]
batch['vp_pos_fts'] = pad_tensors(batch['vp_pos_fts'])
return batch
############### Masked Region Modeling ###############
def _get_img_mask(mask_prob, num_images):
img_mask = [np.random.rand() < mask_prob for _ in range(num_images)]
if not any(img_mask):
# at least mask 1
img_mask[np.random.randint(num_images)] = True
img_mask = torch.tensor(img_mask)
return img_mask
def _mask_img_feat(img_feat, img_masks):
img_masks_ext = img_masks.unsqueeze(-1).expand_as(img_feat)
img_feat_masked = img_feat.data.masked_fill(img_masks_ext, 0)
return img_feat_masked
def _get_targets(img_soft_label, img_masks):
soft_label_dim = img_soft_label.size(-1)
img_masks_ext_for_label = img_masks.unsqueeze(-1).expand_as(img_soft_label)
label_targets = img_soft_label[img_masks_ext_for_label].contiguous().view(-1, soft_label_dim)
return label_targets
class MrcDataset(Dataset):
def __init__(self, nav_db, tok, mask_prob, end_vp_pos_ratio=1):
self.nav_db = nav_db
self.tok = tok
self.mask_prob = mask_prob
self.cls_token_id = self.tok.cls_token_id # 101
self.sep_token_id = self.tok.sep_token_id # 102
self.pad_token_id = self.tok.pad_token_id # 0
self.end_vp_pos_ratio = end_vp_pos_ratio
def __len__(self):
return len(self.nav_db.data)
def __getitem__(self, idx):
r = np.random.rand()
if r < self.end_vp_pos_ratio:
end_vp_type = 'pos'
else:
end_vp_type = 'neg_in_gt_path'
inputs = self.nav_db.get_input(idx, end_vp_type, return_img_probs=True)
output = {}
output['txt_ids'] = torch.LongTensor(inputs['instr_encoding'])
output['traj_view_img_fts'] = [torch.from_numpy(x) for x in inputs['traj_view_img_fts']]
# mask image
view_mrc_masks = _get_img_mask(self.mask_prob, len(output['traj_view_img_fts'][-1]))
output['traj_view_img_fts'][-1] = _mask_img_feat(output['traj_view_img_fts'][-1], view_mrc_masks)
output['vp_view_probs'] = torch.from_numpy(inputs['vp_view_probs']) # no [stop]
output['vp_view_mrc_masks'] = view_mrc_masks
output['traj_loc_fts'] = [torch.from_numpy(x) for x in inputs['traj_loc_fts']]
output['traj_nav_types'] = [torch.LongTensor(x) for x in inputs['traj_nav_types']]
output['traj_cand_vpids'] = inputs['traj_cand_vpids']
output['traj_vpids'] = inputs['traj_vpids']
output['gmap_vpids'] = inputs['gmap_vpids']
output['gmap_step_ids'] = torch.LongTensor(inputs['gmap_step_ids'])
output['gmap_visited_masks'] = torch.BoolTensor(inputs['gmap_visited_masks'])
output['gmap_pos_fts'] = torch.from_numpy(inputs['gmap_pos_fts'])
output['gmap_pair_dists'] = torch.from_numpy(inputs['gmap_pair_dists'])
output['vp_pos_fts'] = torch.from_numpy(inputs['vp_pos_fts'])
output['vp_angles'] = inputs['vp_angles']
if 'traj_obj_img_fts' in inputs:
output['traj_obj_img_fts'] = [torch.from_numpy(x) for x in inputs['traj_obj_img_fts']]
if len(output['traj_obj_img_fts'][-1]) > 0:
obj_mrc_masks = _get_img_mask(self.mask_prob, len(output['traj_obj_img_fts'][-1]))
output['traj_obj_img_fts'][-1] = _mask_img_feat(output['traj_obj_img_fts'][-1], obj_mrc_masks)
else:
obj_mrc_masks = torch.zeros(0, ).bool()
output['vp_obj_probs'] = torch.from_numpy(inputs['vp_obj_probs'])
output['vp_obj_mrc_masks'] = obj_mrc_masks
return output
def mrc_collate(inputs):
batch = {
k: [x[k] for x in inputs] for k in inputs[0].keys()
}
# text batches
batch['txt_lens'] = torch.LongTensor([len(x) for x in batch['txt_ids']])
batch['txt_ids'] = pad_sequence(batch['txt_ids'], batch_first=True, padding_value=0)
# trajectory batches: traj_cand_vpids, traj_vpids
batch['traj_step_lens'] = [len(x) for x in batch['traj_view_img_fts']]
batch['traj_vp_view_lens'] = torch.LongTensor(
sum([[len(y) for y in x] for x in batch['traj_view_img_fts']], [])
)
batch['traj_view_img_fts'] = pad_tensors(sum(batch['traj_view_img_fts'], []))
batch['traj_loc_fts'] = pad_tensors(sum(batch['traj_loc_fts'], []))
batch['traj_nav_types'] = pad_sequence(sum(batch['traj_nav_types'], []), batch_first=True, padding_value=0)
# gmap batches: gmap_vpids
batch['gmap_lens'] = torch.LongTensor([len(x) for x in batch['gmap_step_ids']]) # included [stop]
batch['gmap_step_ids'] = pad_sequence(batch['gmap_step_ids'], batch_first=True, padding_value=0)
batch['gmap_visited_masks'] = pad_sequence(batch['gmap_visited_masks'], batch_first=True, padding_value=0)
batch['gmap_pos_fts'] = pad_tensors(batch['gmap_pos_fts'])
max_gmap_len = max(batch['gmap_lens'])
batch_size = len(batch['gmap_lens'])
gmap_pair_dists = torch.zeros(batch_size, max_gmap_len, max_gmap_len).float()
for i in range(batch_size):
gmap_pair_dists[i, :batch['gmap_lens'][i], :batch['gmap_lens'][i]] = batch['gmap_pair_dists'][i]
batch['gmap_pair_dists'] = gmap_pair_dists
# vp batches: vp_angles
batch['vp_lens'] = torch.LongTensor([len(x[-1]) for x in batch['vp_pos_fts']]) # included [stop]
batch['vp_pos_fts'] = pad_tensors(batch['vp_pos_fts'])
# vp labels
batch['vp_view_mrc_masks'] = pad_sequence(batch['vp_view_mrc_masks'], batch_first=True, padding_value=0)
batch['vp_view_probs'] = pad_tensors(batch['vp_view_probs'])
if 'traj_obj_img_fts' in batch:
batch['traj_vp_obj_lens'] = torch.LongTensor(
sum([[len(y) for y in x] for x in batch['traj_obj_img_fts']], [])
)
batch['traj_obj_img_fts'] = pad_tensors(sum(batch['traj_obj_img_fts'], []))
batch['vp_obj_mrc_masks'] = pad_sequence(batch['vp_obj_mrc_masks'], batch_first=True, padding_value=0)
batch['vp_obj_probs'] = pad_tensors(batch['vp_obj_probs'])
return batch
############### Single-step Action Prediction ###############
class SapDataset(Dataset):
def __init__(self, nav_db, tok, end_vp_pos_ratio=0.2):
'''Instruction Trajectory Matching'''
self.nav_db = nav_db
self.tok = tok
self.cls_token_id = self.tok.cls_token_id # 101
self.sep_token_id = self.tok.sep_token_id # 102
self.pad_token_id = self.tok.pad_token_id # 0
self.end_vp_pos_ratio = end_vp_pos_ratio
def __len__(self):
return len(self.nav_db.data)
def __getitem__(self, idx):
r = np.random.rand()
if r < self.end_vp_pos_ratio:
end_vp_type = 'pos'
elif r < 0.6:
end_vp_type = 'neg_in_gt_path'
else:
end_vp_type = 'neg_others'
inputs = self.nav_db.get_input(idx, end_vp_type, return_act_label=True)
output = {}
output['txt_ids'] = torch.LongTensor(inputs['instr_encoding'])
output['traj_view_img_fts'] = [torch.from_numpy(x) for x in inputs['traj_view_img_fts']]
if 'traj_obj_img_fts' in inputs:
output['traj_obj_img_fts'] = [torch.from_numpy(x) for x in inputs['traj_obj_img_fts']]
output['traj_loc_fts'] = [torch.from_numpy(x) for x in inputs['traj_loc_fts']]
output['traj_nav_types'] = [torch.LongTensor(x) for x in inputs['traj_nav_types']]
output['traj_cand_vpids'] = inputs['traj_cand_vpids']
output['traj_vpids'] = inputs['traj_vpids']
output['gmap_vpids'] = inputs['gmap_vpids']
output['gmap_step_ids'] = torch.LongTensor(inputs['gmap_step_ids'])
output['gmap_visited_masks'] = torch.BoolTensor(inputs['gmap_visited_masks'])
output['gmap_pos_fts'] = torch.from_numpy(inputs['gmap_pos_fts'])
output['gmap_pair_dists'] = torch.from_numpy(inputs['gmap_pair_dists'])
output['vp_pos_fts'] = torch.from_numpy(inputs['vp_pos_fts'])
output['vp_angles'] = inputs['vp_angles']
output['local_act_labels'] = inputs['local_act_labels']
output['global_act_labels'] = inputs['global_act_labels']
return output
def sap_collate(inputs):
batch = {
k: [x[k] for x in inputs] for k in inputs[0].keys()
}
# text batches
batch['txt_lens'] = torch.LongTensor([len(x) for x in batch['txt_ids']])
batch['txt_ids'] = pad_sequence(batch['txt_ids'], batch_first=True, padding_value=0)
# trajectory batches: traj_cand_vpids, traj_vpids
batch['traj_step_lens'] = [len(x) for x in batch['traj_view_img_fts']]
batch['traj_vp_view_lens'] = torch.LongTensor(
sum([[len(y) for y in x] for x in batch['traj_view_img_fts']], [])
)
batch['traj_view_img_fts'] = pad_tensors(sum(batch['traj_view_img_fts'], []))
if 'traj_obj_img_fts' in batch:
batch['traj_vp_obj_lens'] = torch.LongTensor(
sum([[len(y) for y in x] for x in batch['traj_obj_img_fts']], [])
)
batch['traj_obj_img_fts'] = pad_tensors(sum(batch['traj_obj_img_fts'], []))
batch['traj_loc_fts'] = pad_tensors(sum(batch['traj_loc_fts'], []))
batch['traj_nav_types'] = pad_sequence(sum(batch['traj_nav_types'], []), batch_first=True, padding_value=0)
# gmap batches: gmap_vpids
batch['gmap_lens'] = torch.LongTensor([len(x) for x in batch['gmap_step_ids']]) # included [stop]
batch['gmap_step_ids'] = pad_sequence(batch['gmap_step_ids'], batch_first=True, padding_value=0)
batch['gmap_visited_masks'] = pad_sequence(batch['gmap_visited_masks'], batch_first=True, padding_value=0)
batch['gmap_pos_fts'] = pad_tensors(batch['gmap_pos_fts'])
max_gmap_len = max(batch['gmap_lens'])
batch_size = len(batch['gmap_lens'])
gmap_pair_dists = torch.zeros(batch_size, max_gmap_len, max_gmap_len).float()
for i in range(batch_size):
gmap_pair_dists[i, :batch['gmap_lens'][i], :batch['gmap_lens'][i]] = batch['gmap_pair_dists'][i]
batch['gmap_pair_dists'] = gmap_pair_dists
# vp batches: vp_angles
batch['vp_lens'] = torch.LongTensor([len(x[-1]) for x in batch['vp_pos_fts']]) # included [stop]
batch['vp_pos_fts'] = pad_tensors(batch['vp_pos_fts'])
# action labels
batch['local_act_labels'] = torch.LongTensor(batch['local_act_labels'])
batch['global_act_labels'] = torch.LongTensor(batch['global_act_labels'])
return batch
############### Object Grounding ###############
class OGDataset(Dataset):
def __init__(self, nav_db, tok):
self.nav_db = nav_db
self.tok = tok
def __len__(self):
return len(self.nav_db.data)
def __getitem__(self, idx):
inputs = self.nav_db.get_input(idx, 'pos', return_obj_label=True)
output = {}
output['txt_ids'] = torch.LongTensor(inputs['instr_encoding'])
output['traj_view_img_fts'] = [torch.from_numpy(x) for x in inputs['traj_view_img_fts']]
output['traj_obj_img_fts'] = [torch.from_numpy(x) for x in inputs['traj_obj_img_fts']]
output['traj_loc_fts'] = [torch.from_numpy(x) for x in inputs['traj_loc_fts']]
output['traj_nav_types'] = [torch.LongTensor(x) for x in inputs['traj_nav_types']]
output['traj_cand_vpids'] = inputs['traj_cand_vpids']
output['traj_vpids'] = inputs['traj_vpids']
output['gmap_vpids'] = inputs['gmap_vpids']
output['gmap_step_ids'] = torch.LongTensor(inputs['gmap_step_ids'])
output['gmap_visited_masks'] = torch.BoolTensor(inputs['gmap_visited_masks'])
output['gmap_pos_fts'] = torch.from_numpy(inputs['gmap_pos_fts'])
output['gmap_pair_dists'] = torch.from_numpy(inputs['gmap_pair_dists'])
output['vp_pos_fts'] = torch.from_numpy(inputs['vp_pos_fts'])
output['vp_angles'] = inputs['vp_angles']
output['obj_labels'] = inputs['obj_labels']
return output
def og_collate(inputs):
batch = {
k: [x[k] for x in inputs] for k in inputs[0].keys()
}
# text batches
batch['txt_lens'] = torch.LongTensor([len(x) for x in batch['txt_ids']])
batch['txt_ids'] = pad_sequence(batch['txt_ids'], batch_first=True, padding_value=0)
# trajectory batches: traj_cand_vpids, traj_vpids
batch['traj_step_lens'] = [len(x) for x in batch['traj_view_img_fts']]
batch['traj_vp_view_lens'] = torch.LongTensor(
sum([[len(y) for y in x] for x in batch['traj_view_img_fts']], [])
)
batch['traj_vp_obj_lens'] = torch.LongTensor(
sum([[len(y) for y in x] for x in batch['traj_obj_img_fts']], [])
)
batch['traj_view_img_fts'] = pad_tensors(sum(batch['traj_view_img_fts'], []))
batch['traj_obj_img_fts'] = pad_tensors(sum(batch['traj_obj_img_fts'], []))
batch['traj_loc_fts'] = pad_tensors(sum(batch['traj_loc_fts'], []))
batch['traj_nav_types'] = pad_sequence(sum(batch['traj_nav_types'], []), batch_first=True, padding_value=0)
# gmap batches: gmap_vpids
batch['gmap_lens'] = torch.LongTensor([len(x) for x in batch['gmap_step_ids']]) # included [stop]
batch['gmap_step_ids'] = pad_sequence(batch['gmap_step_ids'], batch_first=True, padding_value=0)
batch['gmap_visited_masks'] = pad_sequence(batch['gmap_visited_masks'], batch_first=True, padding_value=0)
batch['gmap_pos_fts'] = pad_tensors(batch['gmap_pos_fts'])
max_gmap_len = max(batch['gmap_lens'])
batch_size = len(batch['gmap_lens'])
gmap_pair_dists = torch.zeros(batch_size, max_gmap_len, max_gmap_len).float()
for i in range(batch_size):
gmap_pair_dists[i, :batch['gmap_lens'][i], :batch['gmap_lens'][i]] = batch['gmap_pair_dists'][i]
batch['gmap_pair_dists'] = gmap_pair_dists
# vp batches: vp_angles
batch['vp_lens'] = torch.LongTensor([len(x[-1]) for x in batch['vp_pos_fts']]) # included [stop]
batch['vp_pos_fts'] = pad_tensors(batch['vp_pos_fts'])
# vp labels
batch['obj_labels'] = torch.LongTensor(batch['obj_labels'])
return batch