Compare commits

...

7 Commits

5 changed files with 80 additions and 18 deletions

View File

@@ -27,7 +27,7 @@ class BaseAgent(object):
def get_results(self, detailed_output=False): def get_results(self, detailed_output=False):
output = [] output = []
for k, v in self.results.items(): for k, v in self.results.items():
output.append({'instr_id': k, 'trajectory': v['path'], 'pred_objid': v['pred_objid']}) output.append({'instr_id': k, 'trajectory': v['path'], 'pred_objid': v['pred_objid'], 'found': v['found'], 'gt_found': v['gt_found']})
if detailed_output: if detailed_output:
output[-1]['details'] = v['details'] output[-1]['details'] = v['details']
return output return output

View File

@@ -174,11 +174,14 @@ class GMapObjectNavAgent(Seq2SeqAgent):
def _nav_vp_variable(self, obs, gmaps, pano_embeds, cand_vpids, view_lens, obj_lens, nav_types): def _nav_vp_variable(self, obs, gmaps, pano_embeds, cand_vpids, view_lens, obj_lens, nav_types):
batch_size = len(obs) batch_size = len(obs)
# print("PANO shape", pano_embeds.shape)
# add [stop] token # add [stop] token & [NOT FOUND] token
# [STOP] 在最前面, [NOT FOUND] 在最後面
vp_img_embeds = torch.cat( vp_img_embeds = torch.cat(
[torch.zeros_like(pano_embeds[:, :1]), pano_embeds], 1 [torch.zeros_like(pano_embeds[:, :1]), pano_embeds, torch.ones_like(pano_embeds[:, :1])], 1
) )
# print("SHAPE:", vp_img_embeds.shape)
batch_vp_pos_fts = [] batch_vp_pos_fts = []
for i, gmap in enumerate(gmaps): for i, gmap in enumerate(gmaps):
@@ -192,19 +195,33 @@ class GMapObjectNavAgent(Seq2SeqAgent):
) )
# add [stop] token at beginning # add [stop] token at beginning
vp_pos_fts = np.zeros((vp_img_embeds.size(1), 14), dtype=np.float32) vp_pos_fts = np.zeros((vp_img_embeds.size(1), 14), dtype=np.float32)
# print("vp_pos_fts:", vp_pos_fts.shape)
vp_pos_fts[:, :7] = cur_start_pos_fts vp_pos_fts[:, :7] = cur_start_pos_fts
# print("vp_pos_fts[:, :7]:", vp_pos_fts[:, :7].shape)
# print("cur_start_pos_fts:", cur_start_pos_fts.shape)
vp_pos_fts[1:len(cur_cand_pos_fts)+1, 7:] = cur_cand_pos_fts vp_pos_fts[1:len(cur_cand_pos_fts)+1, 7:] = cur_cand_pos_fts
# print("vp_pos_fts[1:len(), 7:]:", vp_pos_fts[1:len(cur_cand_pos_fts)+1, 7:].shape)
# print("cur_cand_pos_fts:", cur_cand_pos_fts.shape)
batch_vp_pos_fts.append(torch.from_numpy(vp_pos_fts)) batch_vp_pos_fts.append(torch.from_numpy(vp_pos_fts))
batch_vp_pos_fts = pad_tensors(batch_vp_pos_fts).cuda() batch_vp_pos_fts = pad_tensors(batch_vp_pos_fts).cuda()
vp_nav_masks = torch.cat([torch.ones(batch_size, 1).bool().cuda(), nav_types == 1], 1) # 要把 stop 和 not found 的 mask 補上去
vp_obj_masks = torch.cat([torch.zeros(batch_size, 1).bool().cuda(), nav_types == 2], 1) # 這邊把 stop 跟 candidate 放一起、把 not found 跟 object 放一起
vp_nav_masks = torch.cat([torch.ones(batch_size, 1).bool().cuda(), nav_types == 1, torch.zeros(batch_size, 1).bool().cuda()], 1)
vp_obj_masks = torch.cat([torch.zeros(batch_size, 1).bool().cuda(), nav_types == 2, torch.ones(batch_size, 1).bool().cuda()], 1)
# print('vp_nav_masks:', vp_nav_masks.shape)
# print('vp_obj_masks:', vp_obj_masks.shape)
vp_masks = gen_seq_masks(view_lens+obj_lens+2)
# print()
return { return {
'vp_img_embeds': vp_img_embeds, 'vp_img_embeds': vp_img_embeds,
'vp_pos_fts': batch_vp_pos_fts, 'vp_pos_fts': batch_vp_pos_fts,
'vp_masks': gen_seq_masks(view_lens+obj_lens+1), 'vp_masks': vp_masks,
'vp_nav_masks': vp_nav_masks, 'vp_nav_masks': vp_nav_masks,
'vp_obj_masks': vp_obj_masks, 'vp_obj_masks': vp_obj_masks,
'vp_cand_vpids': [[None]+x for x in cand_vpids], 'vp_cand_vpids': [[None]+x for x in cand_vpids],
@@ -242,7 +259,7 @@ class GMapObjectNavAgent(Seq2SeqAgent):
return torch.from_numpy(a).cuda() return torch.from_numpy(a).cuda()
def _teacher_object(self, obs, ended, view_lens): def _teacher_object(self, obs, ended, view_lens, obj_logits):
targets = np.zeros(len(obs), dtype=np.int64) targets = np.zeros(len(obs), dtype=np.int64)
for i, ob in enumerate(obs): for i, ob in enumerate(obs):
if ended[i]: if ended[i]:
@@ -252,12 +269,18 @@ class GMapObjectNavAgent(Seq2SeqAgent):
if i_vp not in ob['gt_end_vps']: if i_vp not in ob['gt_end_vps']:
targets[i] = self.args.ignoreid targets[i] = self.args.ignoreid
else: else:
i_objids = ob['obj_ids'] i_objids = ob['obj_ids']
targets[i] = self.args.ignoreid targets[i] = self.args.ignoreid
for j, obj_id in enumerate(i_objids): for j, obj_id in enumerate(i_objids):
if str(obj_id) == str(ob['gt_obj_id']): if str(obj_id) == str(ob['gt_obj_id']):
targets[i] = j + view_lens[i] + 1
if ob['gt_found'] == True: # 可以找得到
targets[i] = j + view_lens[i] + 1
else:
targets[i] = len(obj_logits[i])-1 # 不能找到,
break break
return torch.from_numpy(targets).cuda() return torch.from_numpy(targets).cuda()
def make_equiv_action(self, a_t, gmaps, obs, traj=None): def make_equiv_action(self, a_t, gmaps, obs, traj=None):
@@ -298,7 +321,8 @@ class GMapObjectNavAgent(Seq2SeqAgent):
batch_size = len(obs) batch_size = len(obs)
# build graph: keep the start viewpoint # build graph: keep the start viewpoint
gmaps = [GraphMap(ob['viewpoint']) for ob in obs]
gmaps = [GraphMap(ob['viewpoint']) for ob in obs] # input the start point
for i, ob in enumerate(obs): for i, ob in enumerate(obs):
gmaps[i].update_graph(ob) gmaps[i].update_graph(ob)
@@ -307,6 +331,9 @@ class GMapObjectNavAgent(Seq2SeqAgent):
'instr_id': ob['instr_id'], 'instr_id': ob['instr_id'],
'path': [[ob['viewpoint']]], 'path': [[ob['viewpoint']]],
'pred_objid': None, 'pred_objid': None,
'gt_objid': None,
'found': None,
'gt_found': None,
'details': {}, 'details': {},
} for ob in obs] } for ob in obs]
@@ -379,10 +406,22 @@ class GMapObjectNavAgent(Seq2SeqAgent):
i_vp = obs[i]['viewpoint'] i_vp = obs[i]['viewpoint']
# update i_vp: stop and object grounding scores # update i_vp: stop and object grounding scores
i_objids = obs[i]['obj_ids'] i_objids = obs[i]['obj_ids']
i_obj_logits = obj_logits[i, pano_inputs['view_lens'][i]+1:] i_obj_logits = obj_logits[i, pano_inputs['view_lens'][i]+1:] # 最後一個是 not found
if len(i_objids) > 0:
if torch.argmax(i_obj_logits) >= len(i_objids): # not found 那格 logit 最大(會在最後一格)
og = -1
else:
og = i_objids[torch.argmax(i_obj_logits)]
else:
og = None
# 如果有找到,og 會是 object id
# 如果是 not found,og 會是 -1
# 如果這個 viewpoint 看不到物件,og 會是 None
gmap.node_stop_scores[i_vp] = { gmap.node_stop_scores[i_vp] = {
'stop': nav_probs[i, 0].data.item(), 'stop': nav_probs[i, 0].data.item(),
'og': i_objids[torch.argmax(i_obj_logits)] if len(i_objids) > 0 else None, 'og': og,
'og_details': {'objids': i_objids, 'logits': i_obj_logits[:len(i_objids)]}, 'og_details': {'objids': i_objids, 'logits': i_obj_logits[:len(i_objids)]},
} }
@@ -403,7 +442,7 @@ class GMapObjectNavAgent(Seq2SeqAgent):
) )
ml_loss += self.criterion(nav_outs['local_logits'], local_nav_targets) # local ml_loss += self.criterion(nav_outs['local_logits'], local_nav_targets) # local
# object grounding # object grounding
obj_targets = self._teacher_object(obs, ended, pano_inputs['view_lens']) obj_targets = self._teacher_object(obs, ended, pano_inputs['view_lens'], obj_logits)
# print(t, obj_targets[6], obj_logits[6], obs[6]['obj_ids'], pano_inputs['view_lens'][i], obs[6]['gt_obj_id']) # print(t, obj_targets[6], obj_logits[6], obs[6]['obj_ids'], pano_inputs['view_lens'][i], obs[6]['gt_obj_id'])
og_loss += self.criterion(obj_logits, obj_targets) og_loss += self.criterion(obj_logits, obj_targets)
# print(F.cross_entropy(obj_logits, obj_targets, reduction='none')) # print(F.cross_entropy(obj_logits, obj_targets, reduction='none'))
@@ -451,9 +490,11 @@ class GMapObjectNavAgent(Seq2SeqAgent):
else: else:
cpu_a_t.append(nav_vpids[i][a_t[i]]) cpu_a_t.append(nav_vpids[i][a_t[i]])
original_gt_founds = [ ob['gt_found'] for ob in obs ]
# Make action and get the new state # Make action and get the new state
self.make_equiv_action(cpu_a_t, gmaps, obs, traj) self.make_equiv_action(cpu_a_t, gmaps, obs, traj)
for i in range(batch_size): for i in range(batch_size):
traj[i]['gt_found'] = original_gt_founds[i]
if (not ended[i]) and just_ended[i]: if (not ended[i]) and just_ended[i]:
stop_node, stop_score = None, {'stop': -float('inf'), 'og': None} stop_node, stop_score = None, {'stop': -float('inf'), 'og': None}
for k, v in gmaps[i].node_stop_scores.items(): for k, v in gmaps[i].node_stop_scores.items():
@@ -463,6 +504,10 @@ class GMapObjectNavAgent(Seq2SeqAgent):
if stop_node is not None and obs[i]['viewpoint'] != stop_node: if stop_node is not None and obs[i]['viewpoint'] != stop_node:
traj[i]['path'].append(gmaps[i].graph.path(obs[i]['viewpoint'], stop_node)) traj[i]['path'].append(gmaps[i].graph.path(obs[i]['viewpoint'], stop_node))
traj[i]['pred_objid'] = stop_score['og'] traj[i]['pred_objid'] = stop_score['og']
if stop_score['og'] == -1:
traj[i]['found'] = False
else:
traj[i]['found'] = True
if self.args.detailed_output: if self.args.detailed_output:
for k, v in gmaps[i].node_stop_scores.items(): for k, v in gmaps[i].node_stop_scores.items():
traj[i]['details'][k] = { traj[i]['details'][k] = {
@@ -492,4 +537,10 @@ class GMapObjectNavAgent(Seq2SeqAgent):
self.logs['IL_loss'].append(ml_loss.item()) self.logs['IL_loss'].append(ml_loss.item())
self.logs['OG_loss'].append(og_loss.item()) self.logs['OG_loss'].append(og_loss.item())
'''
print("TRAJ:")
for i in traj:
print(" GT: {}, PREDICT: {}, SCORE: {}".format(i['gt_found'], i['found'], 1 if i['gt_found']==i['found'] else 0))
'''
return traj return traj

View File

@@ -87,6 +87,7 @@ def construct_instrs(anno_dir, dataset, splits, tokenizer, max_instr_len=512):
new_item['objId'] = None new_item['objId'] = None
new_item['instruction'] = instr new_item['instruction'] = instr
new_item['instr_encoding'] = item['instr_encodings'][j][:max_instr_len] new_item['instr_encoding'] = item['instr_encodings'][j][:max_instr_len]
new_item['found'] = item['found'][j]
del new_item['instructions'] del new_item['instructions']
del new_item['instr_encodings'] del new_item['instr_encodings']
data.append(new_item) data.append(new_item)

View File

@@ -311,6 +311,7 @@ class ReverieObjectNavBatch(object):
'navigableLocations' : state.navigableLocations, 'navigableLocations' : state.navigableLocations,
'instruction' : item['instruction'], 'instruction' : item['instruction'],
'instr_encoding': item['instr_encoding'], 'instr_encoding': item['instr_encoding'],
'gt_found' : item['found'],
'gt_path' : item['path'], 'gt_path' : item['path'],
'gt_end_vps': item.get('end_vps', []), 'gt_end_vps': item.get('end_vps', []),
'gt_obj_id': item['objId'], 'gt_obj_id': item['objId'],
@@ -351,7 +352,7 @@ class ReverieObjectNavBatch(object):
############### Nav Evaluation ############### ############### Nav Evaluation ###############
def _eval_item(self, scan, pred_path, pred_objid, gt_path, gt_objid): def _eval_item(self, scan, pred_path, pred_objid, gt_path, gt_objid, pred_found, gt_found):
scores = {} scores = {}
shortest_distances = self.shortest_distances[scan] shortest_distances = self.shortest_distances[scan]
@@ -369,8 +370,10 @@ class ReverieObjectNavBatch(object):
assert len(goal_viewpoints) > 0, '%s_%s'%(scan, str(gt_objid)) assert len(goal_viewpoints) > 0, '%s_%s'%(scan, str(gt_objid))
scores['success'] = float(path[-1] in goal_viewpoints) scores['success'] = float(path[-1] in goal_viewpoints)
scores['found_success'] = float(pred_found == gt_found)
scores['oracle_success'] = float(any(x in goal_viewpoints for x in path)) scores['oracle_success'] = float(any(x in goal_viewpoints for x in path))
scores['spl'] = scores['success'] * gt_lengths / max(scores['trajectory_lengths'], gt_lengths, 0.01) scores['spl'] = scores['success'] * gt_lengths / max(scores['trajectory_lengths'], gt_lengths, 0.01)
scores['sspl'] = scores['spl'] * scores['found_success']
scores['rgs'] = str(pred_objid) == str(gt_objid) scores['rgs'] = str(pred_objid) == str(gt_objid)
scores['rgspl'] = scores['rgs'] * gt_lengths / max(scores['trajectory_lengths'], gt_lengths, 0.01) scores['rgspl'] = scores['rgs'] * gt_lengths / max(scores['trajectory_lengths'], gt_lengths, 0.01)
@@ -380,6 +383,7 @@ class ReverieObjectNavBatch(object):
''' Evaluate each agent trajectory based on how close it got to the goal location ''' Evaluate each agent trajectory based on how close it got to the goal location
the path contains [view_id, angle, vofv]''' the path contains [view_id, angle, vofv]'''
print('eval %d predictions' % (len(preds))) print('eval %d predictions' % (len(preds)))
print(preds[0])
metrics = defaultdict(list) metrics = defaultdict(list)
for item in preds: for item in preds:
@@ -387,7 +391,9 @@ class ReverieObjectNavBatch(object):
traj = item['trajectory'] traj = item['trajectory']
pred_objid = item.get('pred_objid', None) pred_objid = item.get('pred_objid', None)
scan, gt_traj, gt_objid = self.gt_trajs[instr_id] scan, gt_traj, gt_objid = self.gt_trajs[instr_id]
traj_scores = self._eval_item(scan, traj, pred_objid, gt_traj, gt_objid) pred_found = item['found']
gt_found = item['gt_found']
traj_scores = self._eval_item(scan, traj, pred_objid, gt_traj, gt_objid, pred_found, gt_found)
for k, v in traj_scores.items(): for k, v in traj_scores.items():
metrics[k].append(v) metrics[k].append(v)
metrics['instr_id'].append(instr_id) metrics['instr_id'].append(instr_id)
@@ -401,6 +407,8 @@ class ReverieObjectNavBatch(object):
'spl': np.mean(metrics['spl']) * 100, 'spl': np.mean(metrics['spl']) * 100,
'rgs': np.mean(metrics['rgs']) * 100, 'rgs': np.mean(metrics['rgs']) * 100,
'rgspl': np.mean(metrics['rgspl']) * 100, 'rgspl': np.mean(metrics['rgspl']) * 100,
'sspl': np.mean(metrics['sspl']) * 100,
'found_sr': np.mean(metrics['found_success']) * 100,
} }
return avg_metrics, metrics return avg_metrics, metrics

View File

@@ -65,8 +65,8 @@ def build_dataset(args, rank=0):
multi_endpoints=args.multi_endpoints, multi_startpoints=args.multi_startpoints, multi_endpoints=args.multi_endpoints, multi_startpoints=args.multi_startpoints,
) )
# val_env_names = ['val_train_seen'] # val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']
val_env_names = ['val_train_seen', 'val_seen', 'val_unseen'] val_env_names = [ 'val_seen', 'val_unseen']
if args.submit: if args.submit:
val_env_names.append('test') val_env_names.append('test')
@@ -136,7 +136,7 @@ def train(args, train_env, val_envs, aug_env=None, rank=-1):
'\nListener training starts, start iteration: %s' % str(start_iter), record_file '\nListener training starts, start iteration: %s' % str(start_iter), record_file
) )
best_val = {'val_unseen': {"spl": 0., "sr": 0., "state":""}} best_val = {'val_unseen': {"spl": 0., "sr": 0., "state":"", "sspl": 0., 'found_sr': 0.}}
for idx in range(start_iter, start_iter+args.iters, args.log_every): for idx in range(start_iter, start_iter+args.iters, args.log_every):
listner.logs = defaultdict(list) listner.logs = defaultdict(list)
@@ -201,9 +201,11 @@ def train(args, train_env, val_envs, aug_env=None, rank=-1):
# select model by spl # select model by spl
if env_name in best_val: if env_name in best_val:
if score_summary['spl'] >= best_val[env_name]['spl']: if score_summary['sspl'] >= best_val[env_name]['sspl']:
best_val[env_name]['spl'] = score_summary['spl'] best_val[env_name]['spl'] = score_summary['spl']
best_val[env_name]['sspl'] = score_summary['sspl']
best_val[env_name]['sr'] = score_summary['sr'] best_val[env_name]['sr'] = score_summary['sr']
best_val[env_name]['found_sr'] = score_summary['found_sr']
best_val[env_name]['state'] = 'Iter %d %s' % (iter, loss_str) best_val[env_name]['state'] = 'Iter %d %s' % (iter, loss_str)
listner.save(idx, os.path.join(args.ckpt_dir, "best_%s" % (env_name))) listner.save(idx, os.path.join(args.ckpt_dir, "best_%s" % (env_name)))