adversarial_VLNDUET/pretrain_src/data/dataset.py
Shizhe Chen 747cf0587b init
2021-11-24 13:29:08 +01:00

573 lines
25 KiB
Python

'''
Instruction and trajectory (view and object features) dataset
'''
import os
import json
import jsonlines
import numpy as np
import h5py
import math
from .common import load_nav_graphs
from .common import get_angle_fts, get_view_rel_angles
from .common import calculate_vp_rel_pos_fts
from .common import softmax
MAX_DIST = 30 # normalize
MAX_STEP = 10 # normalize
TRAIN_MAX_STEP = 20
class ReverieTextPathData(object):
def __init__(
self, anno_files, img_ft_file, obj_ft_file, scanvp_cands_file, connectivity_dir,
image_feat_size=2048, image_prob_size=1000, angle_feat_size=4,
obj_feat_size=None, obj_prob_size=None, max_objects=20,
max_txt_len=100, in_memory=True, act_visited_node=False
):
self.img_ft_file = img_ft_file
self.obj_ft_file = obj_ft_file
self.image_feat_size = image_feat_size
self.image_prob_size = image_prob_size
self.angle_feat_size = angle_feat_size
self.obj_feat_size = obj_feat_size
self.obj_prob_size = obj_prob_size
self.obj_image_h = 480
self.obj_image_w = 640
self.obj_image_size = 480 * 640
self.max_txt_len = max_txt_len
self.max_objects = max_objects
self.act_visited_node = act_visited_node
self.in_memory = in_memory
if self.in_memory:
self._feature_store = {}
# {scan_vp: {vp: [viewidx, rel_angle_dist, rel_heading, rel_elevation]}}
self.scanvp_cands = json.load(open(scanvp_cands_file))
self.graphs, self.shortest_distances, self.shortest_paths = load_nav_graphs(connectivity_dir)
self.all_point_rel_angles = [get_view_rel_angles(baseViewId=i) for i in range(36)]
self.all_point_angle_fts = [get_angle_fts(x[:, 0], x[:, 1], self.angle_feat_size) for x in self.all_point_rel_angles]
self.data = []
for anno_file in anno_files:
with jsonlines.open(anno_file, 'r') as f:
for item in f:
self.data.append(item)
def __len__(self):
return len(self.data)
def get_scanvp_feature(self, scan, viewpoint):
key = '%s_%s' % (scan, viewpoint)
if self.in_memory and key in self._feature_store:
view_fts, obj_fts, obj_attrs = self._feature_store[key]
else:
with h5py.File(self.img_ft_file, 'r') as f:
view_fts = f[key][...].astype(np.float32)
obj_attrs = {}
obj_fts = np.zeros((0, self.obj_feat_size+self.obj_prob_size), dtype=np.float32)
if self.obj_ft_file is not None:
with h5py.File(self.obj_ft_file, 'r') as f:
if key in f:
obj_fts = f[key][...].astype(np.float32)
obj_fts = obj_fts[:self.max_objects]
for attr_key, attr_value in f[key].attrs.items():
if attr_key in ['directions', 'sizes', 'bboxes', 'obj_ids']:
obj_attrs[attr_key] = attr_value[:self.max_objects]
if self.in_memory:
self._feature_store[key] = (view_fts, obj_fts, obj_attrs)
return view_fts, obj_fts, obj_attrs
def get_obj_label(self, item, last_vp_objids):
gt_obj_id = item['instr_id'].split('_')[1]
for k, obj_id in enumerate(last_vp_objids):
if obj_id == gt_obj_id:
obj_label = k
break
else:
# it occurs when the gt_objid is not in max_objects
obj_label = -100 # ignore
# print('No groundtruth obj_id', item['instr_id'], len(obj_ids))
return obj_label
def get_act_labels(self, end_vp, item, gmap_vpids, gmap_visited_masks, traj_cand_vpids):
scan = item['scan']
pos_vps = item['pos_vps']
if end_vp in pos_vps:
global_act_label = local_act_label = 0
else:
global_act_label = local_act_label = -100
# global: unvisited vp
cand_min_dist = float('inf')
for k, cand_vp in enumerate(gmap_vpids):
if (k > 0) and (not gmap_visited_masks[k]):
min_dist = min([self.shortest_distances[scan][end_vp][cand_vp] \
+ self.shortest_distances[scan][cand_vp][pos_vp] for pos_vp in pos_vps])
if min_dist < cand_min_dist:
cand_min_dist = min_dist
global_act_label = k # [stop] is 0
# local:
cand_min_dist = float('inf')
for k, cand_vp in enumerate(traj_cand_vpids[-1]):
min_dist = min([self.shortest_distances[scan][end_vp][cand_vp] \
+ self.shortest_distances[scan][cand_vp][pos_vp] for pos_vp in pos_vps])
if min_dist < cand_min_dist:
cand_min_dist = min_dist
local_act_label = k + 1 # [stop] is 0
return global_act_label, local_act_label
def get_input(
self, idx, end_vp_type, return_img_probs=False, return_act_label=False,
return_obj_label=False, end_vp=None
):
item = self.data[idx]
scan = item['scan']
start_vp = item['path'][0]
start_heading = item.get('heading', 0)
pos_vps = item['pos_vps']
gt_path = item['path']
if end_vp is None:
if end_vp_type == 'pos':
end_vp = pos_vps[np.random.randint(len(pos_vps))]
elif end_vp_type == 'neg_in_gt_path':
end_vps = [vp for vp in gt_path if vp not in pos_vps]
if len(end_vps) == 0:
end_vps = gt_path
end_vp = end_vps[np.random.randint(len(end_vps))]
elif end_vp_type == 'neg_others':
noneg_vp_set = set(pos_vps + gt_path)
end_vps = [vp for vp in self.graphs[scan].nodes.keys() if vp not in noneg_vp_set]
end_vp = end_vps[np.random.randint(len(end_vps))]
gt_path = self.shortest_paths[scan][start_vp][end_vp]
cur_heading, cur_elevation = self.get_cur_angle(scan, gt_path, start_heading)
if len(gt_path) > TRAIN_MAX_STEP:
# truncate trajectory
gt_path = gt_path[:TRAIN_MAX_STEP] + [end_vp]
traj_view_img_fts, traj_obj_img_fts, traj_loc_fts, traj_nav_types, traj_cand_vpids, \
last_vp_angles, last_vp_objids = self.get_traj_pano_fts(scan, gt_path)
# global: the first token is [stop]
gmap_vpids, gmap_step_ids, gmap_visited_masks, gmap_pos_fts, gmap_pair_dists = \
self.get_gmap_inputs(scan, gt_path, cur_heading, cur_elevation)
# local: the first token is [stop]
vp_pos_fts = self.get_vp_pos_fts(scan, start_vp, end_vp,
traj_cand_vpids[-1], cur_heading, cur_elevation, len(traj_nav_types[-1]))
outs = {
'instr_id': item['instr_id'],
'instr_encoding': item['instr_encoding'][:self.max_txt_len],
'traj_view_img_fts': [x[:, :self.image_feat_size] for x in traj_view_img_fts],
'traj_obj_img_fts': [x[:, :self.obj_feat_size] for x in traj_obj_img_fts],
'traj_loc_fts': traj_loc_fts,
'traj_nav_types': traj_nav_types,
'traj_cand_vpids': traj_cand_vpids,
'traj_vpids': gt_path,
'gmap_vpids': gmap_vpids,
'gmap_step_ids': gmap_step_ids,
'gmap_visited_masks': gmap_visited_masks,
'gmap_pos_fts': gmap_pos_fts,
'gmap_pair_dists': gmap_pair_dists,
'vp_pos_fts': vp_pos_fts,
# 'vp_objids': last_vp_objids,
'vp_angles': last_vp_angles,
}
if return_obj_label:
outs['obj_labels'] = self.get_obj_label(item, last_vp_objids)
if return_act_label:
global_act_label, local_act_label = self.get_act_labels(
end_vp, item, gmap_vpids, gmap_visited_masks, traj_cand_vpids
)
outs['global_act_labels'] = global_act_label
outs['local_act_labels'] = local_act_label
if return_img_probs:
# TODO: whether adding gmap img probs
outs['vp_view_probs'] = softmax(traj_view_img_fts[-1][:, self.image_feat_size:], dim=1)
outs['vp_obj_probs'] = softmax(traj_obj_img_fts[-1][:, self.obj_feat_size:], dim=1)
return outs
def get_cur_angle(self, scan, path, start_heading):
if len(path) < 2:
heading = start_heading
elevation = 0
else:
prev_vp = path[-2]
cur_vp = path[-1]
viewidx = self.scanvp_cands['%s_%s'%(scan, prev_vp)][cur_vp][0]
heading = (viewidx % 12) * math.radians(30)
elevation = (viewidx // 12 - 1) * math.radians(30)
return heading, elevation
def get_traj_pano_fts(self, scan, path):
'''
Tokens in each pano: [cand_views, noncand_views, objs]
Each token consists of (img_fts, loc_fts (ang_fts, box_fts), nav_types)
'''
traj_view_img_fts, traj_obj_img_fts, traj_loc_fts, traj_nav_types, traj_cand_vpids = [], [], [], [], []
for vp in path:
view_fts, obj_img_fts, obj_attrs = self.get_scanvp_feature(scan, vp)
view_img_fts, view_angles, cand_vpids = [], [], []
# cand views
nav_cands = self.scanvp_cands['%s_%s'%(scan, vp)]
used_viewidxs = set()
for k, v in nav_cands.items():
used_viewidxs.add(v[0])
view_img_fts.append(view_fts[v[0]])
# TODO: whether using correct heading at each step
view_angle = self.all_point_rel_angles[12][v[0]]
view_angles.append([view_angle[0] + v[2], view_angle[1] + v[3]])
cand_vpids.append(k)
# non cand views
view_img_fts.extend([view_fts[idx] for idx in range(36) if idx not in used_viewidxs])
view_angles.extend([self.all_point_rel_angles[12][idx] for idx in range(36) if idx not in used_viewidxs])
# combine cand views and noncand views
view_img_fts = np.stack(view_img_fts, 0) # (n_views, dim_ft)
view_angles = np.stack(view_angles, 0)
view_ang_fts = get_angle_fts(view_angles[:, 0], view_angles[:, 1], self.angle_feat_size)
view_box_fts = np.array([[1, 1, 1]] * len(view_img_fts)).astype(np.float32)
# object features
num_objs = obj_img_fts.shape[0]
obj_angles = np.zeros((num_objs, 2), dtype=np.float32)
obj_ang_fts = np.zeros((num_objs, self.angle_feat_size), dtype=np.float32)
obj_box_fts = np.zeros((num_objs, 3), dtype=np.float32)
if num_objs > 0:
for k, (w, h) in enumerate(obj_attrs['sizes']):
obj_angles[k] = obj_attrs['directions'][k]
obj_box_fts[k] = [h/self.obj_image_h, w/self.obj_image_w, (h*w)/self.obj_image_size]
obj_ang_fts = get_angle_fts(obj_angles[:, 0], obj_angles[:, 1], self.angle_feat_size)
# combine pano features
traj_view_img_fts.append(view_img_fts)
traj_obj_img_fts.append(obj_img_fts)
traj_loc_fts.append(
np.concatenate(
[np.concatenate([view_ang_fts, view_box_fts], 1),
np.concatenate([obj_ang_fts, obj_box_fts], 1)], axis=0
)
)
traj_nav_types.append(
[1] * len(cand_vpids) + [0] * (36 - len(used_viewidxs)) + [2] * len(obj_img_fts)
)
traj_cand_vpids.append(cand_vpids)
last_vp_objids = obj_attrs.get('obj_ids', [])
last_vp_angles = np.concatenate([view_angles, obj_angles], 0)
return traj_view_img_fts, traj_obj_img_fts, traj_loc_fts, traj_nav_types, traj_cand_vpids, \
last_vp_angles, last_vp_objids
def get_gmap_inputs(self, scan, path, cur_heading, cur_elevation):
scan_graph = self.graphs[scan]
cur_vp = path[-1]
visited_vpids, unvisited_vpids = {}, {}
for t, vp in enumerate(path):
visited_vpids[vp] = t + 1
if vp in unvisited_vpids:
del unvisited_vpids[vp]
for next_vp in self.scanvp_cands['%s_%s'%(scan, vp)].keys():
if next_vp not in visited_vpids:
unvisited_vpids[next_vp] = 0
# add [stop] token
gmap_vpids = [None] + list(visited_vpids.keys()) + list(unvisited_vpids.keys())
gmap_step_ids = [0] + list(visited_vpids.values()) + list(unvisited_vpids.values())
if self.act_visited_node:
gmap_visited_masks = [0]
for vp in gmap_vpids[1:]:
if vp == path[-1]:
gmap_visited_masks.append(1)
else:
gmap_visited_masks.append(0)
else:
gmap_visited_masks = [0] + [1] * len(visited_vpids) + [0] * len(unvisited_vpids)
# shape=(num_gmap_vpids, 7)
gmap_pos_fts = self.get_gmap_pos_fts(scan, cur_vp, gmap_vpids, cur_heading, cur_elevation)
gmap_pair_dists = np.zeros((len(gmap_vpids), len(gmap_vpids)), dtype=np.float32)
for i in range(1, len(gmap_vpids)):
for j in range(i+1, len(gmap_vpids)):
gmap_pair_dists[i, j] = gmap_pair_dists[j, i] = \
self.shortest_distances[scan][gmap_vpids[i]][gmap_vpids[j]]
return gmap_vpids, gmap_step_ids, gmap_visited_masks, gmap_pos_fts, gmap_pair_dists
def get_gmap_pos_fts(self, scan, cur_vp, gmap_vpids, cur_heading, cur_elevation):
# dim=7 (sin(heading), cos(heading), sin(elevation), cos(elevation),
# line_dist, shortest_dist, shortest_step)
rel_angles, rel_dists = [], []
for vp in gmap_vpids:
if vp is None:
rel_angles.append([0, 0])
rel_dists.append([0, 0, 0])
else:
rel_heading, rel_elevation, rel_dist = calculate_vp_rel_pos_fts(
self.graphs[scan].nodes[cur_vp]['position'],
self.graphs[scan].nodes[vp]['position'],
base_heading=cur_heading, base_elevation=cur_elevation,
)
rel_angles.append([rel_heading, rel_elevation])
rel_dists.append(
[rel_dist / MAX_DIST, self.shortest_distances[scan][cur_vp][vp] / MAX_DIST, \
(len(self.shortest_paths[scan][cur_vp][vp]) - 1) / MAX_STEP]
)
rel_angles = np.array(rel_angles).astype(np.float32)
rel_dists = np.array(rel_dists).astype(np.float32)
rel_ang_fts = get_angle_fts(rel_angles[:, 0], rel_angles[:, 1], self.angle_feat_size)
return np.concatenate([rel_ang_fts, rel_dists], 1)
def get_vp_pos_fts(self, scan, start_vp, cur_vp, cand_vpids, cur_heading, cur_elevation, vp_ft_len):
cur_cand_pos_fts = self.get_gmap_pos_fts(scan, cur_vp, cand_vpids, cur_heading, cur_elevation)
cur_start_pos_fts = self.get_gmap_pos_fts(scan, cur_vp, [start_vp], cur_heading, cur_elevation)
# add [stop] token at beginning
vp_pos_fts = np.zeros((vp_ft_len+1, 14), dtype=np.float32)
vp_pos_fts[:, :7] = cur_start_pos_fts
vp_pos_fts[1:len(cur_cand_pos_fts)+1, 7:] = cur_cand_pos_fts
return vp_pos_fts
class R2RTextPathData(ReverieTextPathData):
def __init__(
self, anno_files, img_ft_file, scanvp_cands_file, connectivity_dir,
image_feat_size=2048, image_prob_size=1000, angle_feat_size=4,
max_txt_len=100, in_memory=True, act_visited_node=False
):
super().__init__(
anno_files, img_ft_file, None, scanvp_cands_file, connectivity_dir,
image_feat_size=image_feat_size, image_prob_size=image_prob_size,
angle_feat_size=angle_feat_size, obj_feat_size=0, obj_prob_size=0,
max_objects=0, max_txt_len=max_txt_len, in_memory=in_memory,
act_visited_node=act_visited_node
)
def get_scanvp_feature(self, scan, viewpoint):
key = '%s_%s' % (scan, viewpoint)
if self.in_memory and key in self._feature_store:
view_fts = self._feature_store[key]
else:
with h5py.File(self.img_ft_file, 'r') as f:
view_fts = f[key][...].astype(np.float32)
if self.in_memory:
self._feature_store[key] = view_fts
return view_fts
def get_act_labels(self, end_vp, end_idx, item, gmap_vpids, traj_cand_vpids):
if end_vp == item['path'][-1]: # stop
global_act_label = local_act_label = 0
else:
global_act_label = local_act_label = -100
# global: unvisited vp
gt_next_vp = item['path'][end_idx + 1]
for k, cand_vp in enumerate(gmap_vpids):
if cand_vp == gt_next_vp:
global_act_label = k
break
# local:
for k, cand_vp in enumerate(traj_cand_vpids[-1]):
if cand_vp == gt_next_vp:
local_act_labels = k + 1 # [stop] is 0
break
return global_act_label, local_act_label
def get_input(
self, idx, end_vp_type, return_img_probs=False, return_act_label=False, end_vp=None
):
item = self.data[idx]
scan = item['scan']
start_vp = item['path'][0]
start_heading = item['heading']
gt_path = item['path']
if end_vp is None:
if end_vp_type == 'pos':
# name convention with REVERIE (last vp)
end_idx = len(gt_path) - 1
end_vp = gt_path[-1]
elif end_vp_type in ['neg_in_gt_path', 'neg_others']:
# name convention with REVERIE (mid vps in the path)
end_vps = gt_path[:-1]
end_idx = np.random.randint(len(end_vps))
end_vp = end_vps[end_idx]
else:
assert end_vp in gt_path
end_idx = gt_path.index(end_vp)
gt_path = gt_path[:end_idx+1]
cur_heading, cur_elevation = self.get_cur_angle(scan, gt_path, start_heading)
if len(gt_path) > TRAIN_MAX_STEP:
# truncate trajectory
gt_path = gt_path[:TRAIN_MAX_STEP] + [end_vp]
traj_view_img_fts, traj_loc_fts, traj_nav_types, traj_cand_vpids, \
last_vp_angles = self.get_traj_pano_fts(scan, gt_path)
# global: the first token is [stop]
gmap_vpids, gmap_step_ids, gmap_visited_masks, gmap_pos_fts, gmap_pair_dists = \
self.get_gmap_inputs(scan, gt_path, cur_heading, cur_elevation)
# local: the first token is [stop]
vp_pos_fts = self.get_vp_pos_fts(scan, start_vp, end_vp,
traj_cand_vpids[-1], cur_heading, cur_elevation, len(traj_nav_types[-1]))
outs = {
'instr_id': item['instr_id'],
'instr_encoding': item['instr_encoding'][:self.max_txt_len],
'traj_view_img_fts': [x[:, :self.image_feat_size] for x in traj_view_img_fts],
'traj_loc_fts': traj_loc_fts,
'traj_nav_types': traj_nav_types,
'traj_cand_vpids': traj_cand_vpids,
'traj_vpids': gt_path,
'gmap_vpids': gmap_vpids,
'gmap_step_ids': gmap_step_ids,
'gmap_visited_masks': gmap_visited_masks,
'gmap_pos_fts': gmap_pos_fts,
'gmap_pair_dists': gmap_pair_dists,
'vp_pos_fts': vp_pos_fts,
'vp_angles': last_vp_angles,
}
if return_act_label:
global_act_label, local_act_label = self.get_act_labels(
end_vp, end_idx, item, gmap_vpids, traj_cand_vpids
)
outs['global_act_labels'] = global_act_label
outs['local_act_labels'] = local_act_label
if return_img_probs:
# TODO: whether adding gmap img probs
outs['vp_view_probs'] = softmax(traj_view_img_fts[-1][:, self.image_feat_size:], dim=1)
return outs
def get_traj_pano_fts(self, scan, path):
'''
Tokens in each pano: [cand_views, noncand_views, objs]
Each token consists of (img_fts, loc_fts (ang_fts, box_fts), nav_types)
'''
traj_view_img_fts, traj_loc_fts, traj_nav_types, traj_cand_vpids = [], [], [], []
for vp in path:
view_fts = self.get_scanvp_feature(scan, vp)
view_img_fts, view_angles, cand_vpids = [], [], []
# cand views
nav_cands = self.scanvp_cands['%s_%s'%(scan, vp)]
used_viewidxs = set()
for k, v in nav_cands.items():
used_viewidxs.add(v[0])
view_img_fts.append(view_fts[v[0]])
# TODO: whether using correct heading at each step
view_angle = self.all_point_rel_angles[12][v[0]]
view_angles.append([view_angle[0] + v[2], view_angle[1] + v[3]])
cand_vpids.append(k)
# non cand views
view_img_fts.extend([view_fts[idx] for idx in range(36) if idx not in used_viewidxs])
view_angles.extend([self.all_point_rel_angles[12][idx] for idx in range(36) if idx not in used_viewidxs])
# combine cand views and noncand views
view_img_fts = np.stack(view_img_fts, 0) # (n_views, dim_ft)
view_angles = np.stack(view_angles, 0)
view_ang_fts = get_angle_fts(view_angles[:, 0], view_angles[:, 1], self.angle_feat_size)
view_box_fts = np.array([[1, 1, 1]] * len(view_img_fts)).astype(np.float32)
# combine pano features
traj_view_img_fts.append(view_img_fts)
traj_loc_fts.append(np.concatenate([view_ang_fts, view_box_fts], 1))
traj_nav_types.append([1] * len(cand_vpids) + [0] * (36 - len(used_viewidxs)))
traj_cand_vpids.append(cand_vpids)
last_vp_angles = view_angles
return traj_view_img_fts, traj_loc_fts, traj_nav_types, traj_cand_vpids, last_vp_angles
class SoonTextPathData(ReverieTextPathData):
def __init__(
self, anno_files, img_ft_file, obj_ft_file, scanvp_cands_file, connectivity_dir,
image_feat_size=2048, image_prob_size=1000, angle_feat_size=4,
obj_feat_size=None, obj_prob_size=None, max_objects=20,
max_txt_len=100, in_memory=True, act_visited_node=False
):
super().__init__(
anno_files, img_ft_file, obj_ft_file, scanvp_cands_file, connectivity_dir,
image_feat_size=image_feat_size, image_prob_size=image_prob_size,
angle_feat_size=angle_feat_size, obj_feat_size=obj_feat_size,
obj_prob_size=obj_prob_size, max_objects=max_objects,
max_txt_len=max_txt_len, in_memory=in_memory,
act_visited_node=act_visited_node
)
self.obj_image_h = self.obj_image_w = 600
self.obj_image_size = 600 * 600
def get_scanvp_feature(self, scan, viewpoint):
key = '%s_%s' % (scan, viewpoint)
if self.in_memory and key in self._feature_store:
view_fts, obj_fts, obj_attrs = self._feature_store[key]
else:
with h5py.File(self.img_ft_file, 'r') as f:
view_fts = f[key][...].astype(np.float32)
obj_attrs = {}
obj_fts = np.zeros((0, self.obj_feat_size+self.obj_prob_size), dtype=np.float32)
if self.obj_ft_file is not None:
with h5py.File(self.obj_ft_file, 'r') as f:
if key in f:
obj_fts = f[key][...].astype(np.float32)
obj_fts = obj_fts[:self.max_objects]
for attr_key, attr_value in f[key].attrs.items():
if attr_key in ['directions', 'bboxes', 'obj_ids']:
obj_attrs[attr_key] = attr_value[:self.max_objects]
obj_attrs['bboxes'] = np.array(obj_attrs['bboxes']).astype(np.float32)
obj_attrs['sizes'] = np.zeros((len(obj_attrs['bboxes']), 2), dtype=np.float32)
obj_attrs['sizes'][:, 0] = obj_attrs['bboxes'][:, 2] - obj_attrs['bboxes'][:, 0]
obj_attrs['sizes'][:, 1] = obj_attrs['bboxes'][:, 3] - obj_attrs['bboxes'][:, 1]
if self.in_memory:
self._feature_store[key] = (view_fts, obj_fts, obj_attrs)
return view_fts, obj_fts, obj_attrs
def get_obj_label(self, item, last_vp_objids):
obj_label = item['obj_pseudo_label']['idx']
if obj_label >= self.max_objects:
obj_label = -100
return obj_label
def get_input(
self, idx, end_vp_type, return_img_probs=False, return_act_label=False,
return_obj_label=False, end_vp=None
):
if end_vp_type == 'pos':
end_vp = self.data[idx]['path'][-1]
return super().get_input(
idx, end_vp_type,
return_img_probs=return_img_probs,
return_act_label=return_act_label,
return_obj_label=return_obj_label,
end_vp=end_vp
)