Compare commits

..

No commits in common. "b2dce6111efffd728419a83bfb6cc46070bd179b" and "93e8b233164bc079a6db48b8a0a78d123ec8de41" have entirely different histories.

5 changed files with 18 additions and 80 deletions

View File

@ -27,7 +27,7 @@ class BaseAgent(object):
def get_results(self, detailed_output=False):
output = []
for k, v in self.results.items():
output.append({'instr_id': k, 'trajectory': v['path'], 'pred_objid': v['pred_objid'], 'found': v['found'], 'gt_found': v['gt_found']})
output.append({'instr_id': k, 'trajectory': v['path'], 'pred_objid': v['pred_objid']})
if detailed_output:
output[-1]['details'] = v['details']
return output

View File

@ -174,14 +174,11 @@ class GMapObjectNavAgent(Seq2SeqAgent):
def _nav_vp_variable(self, obs, gmaps, pano_embeds, cand_vpids, view_lens, obj_lens, nav_types):
batch_size = len(obs)
# print("PANO shape", pano_embeds.shape)
# add [stop] token & [NOT FOUND] token
# [STOP] 在最前面, [NOT FOUND] 在最後面
# add [stop] token
vp_img_embeds = torch.cat(
[torch.zeros_like(pano_embeds[:, :1]), pano_embeds, torch.ones_like(pano_embeds[:, :1])], 1
[torch.zeros_like(pano_embeds[:, :1]), pano_embeds], 1
)
# print("SHAPE:", vp_img_embeds.shape)
batch_vp_pos_fts = []
for i, gmap in enumerate(gmaps):
@ -195,33 +192,19 @@ class GMapObjectNavAgent(Seq2SeqAgent):
)
# add [stop] token at beginning
vp_pos_fts = np.zeros((vp_img_embeds.size(1), 14), dtype=np.float32)
# print("vp_pos_fts:", vp_pos_fts.shape)
vp_pos_fts[:, :7] = cur_start_pos_fts
# print("vp_pos_fts[:, :7]:", vp_pos_fts[:, :7].shape)
# print("cur_start_pos_fts:", cur_start_pos_fts.shape)
vp_pos_fts[1:len(cur_cand_pos_fts)+1, 7:] = cur_cand_pos_fts
# print("vp_pos_fts[1:len(), 7:]:", vp_pos_fts[1:len(cur_cand_pos_fts)+1, 7:].shape)
# print("cur_cand_pos_fts:", cur_cand_pos_fts.shape)
batch_vp_pos_fts.append(torch.from_numpy(vp_pos_fts))
batch_vp_pos_fts = pad_tensors(batch_vp_pos_fts).cuda()
# 要把 stop 和 not found 的 mask 補上去
# 這邊把 stop 跟 candidate 放一起、把 not found 跟 object 放一起
vp_nav_masks = torch.cat([torch.ones(batch_size, 1).bool().cuda(), nav_types == 1, torch.zeros(batch_size, 1).bool().cuda()], 1)
vp_obj_masks = torch.cat([torch.zeros(batch_size, 1).bool().cuda(), nav_types == 2, torch.ones(batch_size, 1).bool().cuda()], 1)
# print('vp_nav_masks:', vp_nav_masks.shape)
# print('vp_obj_masks:', vp_obj_masks.shape)
vp_masks = gen_seq_masks(view_lens+obj_lens+2)
# print()
vp_nav_masks = torch.cat([torch.ones(batch_size, 1).bool().cuda(), nav_types == 1], 1)
vp_obj_masks = torch.cat([torch.zeros(batch_size, 1).bool().cuda(), nav_types == 2], 1)
return {
'vp_img_embeds': vp_img_embeds,
'vp_pos_fts': batch_vp_pos_fts,
'vp_masks': vp_masks,
'vp_masks': gen_seq_masks(view_lens+obj_lens+1),
'vp_nav_masks': vp_nav_masks,
'vp_obj_masks': vp_obj_masks,
'vp_cand_vpids': [[None]+x for x in cand_vpids],
@ -259,7 +242,7 @@ class GMapObjectNavAgent(Seq2SeqAgent):
return torch.from_numpy(a).cuda()
def _teacher_object(self, obs, ended, view_lens, obj_logits):
def _teacher_object(self, obs, ended, view_lens):
targets = np.zeros(len(obs), dtype=np.int64)
for i, ob in enumerate(obs):
if ended[i]:
@ -269,18 +252,12 @@ class GMapObjectNavAgent(Seq2SeqAgent):
if i_vp not in ob['gt_end_vps']:
targets[i] = self.args.ignoreid
else:
i_objids = ob['obj_ids']
targets[i] = self.args.ignoreid
for j, obj_id in enumerate(i_objids):
if str(obj_id) == str(ob['gt_obj_id']):
if ob['gt_found'] == True: # 可以找得到
targets[i] = j + view_lens[i] + 1
else:
targets[i] = len(obj_logits[i])-1 # 不能找到,
break
return torch.from_numpy(targets).cuda()
def make_equiv_action(self, a_t, gmaps, obs, traj=None):
@ -321,8 +298,7 @@ class GMapObjectNavAgent(Seq2SeqAgent):
batch_size = len(obs)
# build graph: keep the start viewpoint
gmaps = [GraphMap(ob['viewpoint']) for ob in obs] # input the start point
gmaps = [GraphMap(ob['viewpoint']) for ob in obs]
for i, ob in enumerate(obs):
gmaps[i].update_graph(ob)
@ -331,9 +307,6 @@ class GMapObjectNavAgent(Seq2SeqAgent):
'instr_id': ob['instr_id'],
'path': [[ob['viewpoint']]],
'pred_objid': None,
'gt_objid': None,
'found': None,
'gt_found': None,
'details': {},
} for ob in obs]
@ -406,22 +379,10 @@ class GMapObjectNavAgent(Seq2SeqAgent):
i_vp = obs[i]['viewpoint']
# update i_vp: stop and object grounding scores
i_objids = obs[i]['obj_ids']
i_obj_logits = obj_logits[i, pano_inputs['view_lens'][i]+1:] # 最後一個是 not found
if len(i_objids) > 0:
if torch.argmax(i_obj_logits) >= len(i_objids): # not found 那格 logit 最大(會在最後一格)
og = -1
else:
og = i_objids[torch.argmax(i_obj_logits)]
else:
og = None
# 如果有找到,og 會是 object id
# 如果是 not found,og 會是 -1
# 如果這個 viewpoint 看不到物件,og 會是 None
i_obj_logits = obj_logits[i, pano_inputs['view_lens'][i]+1:]
gmap.node_stop_scores[i_vp] = {
'stop': nav_probs[i, 0].data.item(),
'og': og,
'og': i_objids[torch.argmax(i_obj_logits)] if len(i_objids) > 0 else None,
'og_details': {'objids': i_objids, 'logits': i_obj_logits[:len(i_objids)]},
}
@ -442,7 +403,7 @@ class GMapObjectNavAgent(Seq2SeqAgent):
)
ml_loss += self.criterion(nav_outs['local_logits'], local_nav_targets) # local
# object grounding
obj_targets = self._teacher_object(obs, ended, pano_inputs['view_lens'], obj_logits)
obj_targets = self._teacher_object(obs, ended, pano_inputs['view_lens'])
# print(t, obj_targets[6], obj_logits[6], obs[6]['obj_ids'], pano_inputs['view_lens'][i], obs[6]['gt_obj_id'])
og_loss += self.criterion(obj_logits, obj_targets)
# print(F.cross_entropy(obj_logits, obj_targets, reduction='none'))
@ -490,11 +451,9 @@ class GMapObjectNavAgent(Seq2SeqAgent):
else:
cpu_a_t.append(nav_vpids[i][a_t[i]])
original_gt_founds = [ ob['gt_found'] for ob in obs ]
# Make action and get the new state
self.make_equiv_action(cpu_a_t, gmaps, obs, traj)
for i in range(batch_size):
traj[i]['gt_found'] = original_gt_founds[i]
if (not ended[i]) and just_ended[i]:
stop_node, stop_score = None, {'stop': -float('inf'), 'og': None}
for k, v in gmaps[i].node_stop_scores.items():
@ -504,10 +463,6 @@ class GMapObjectNavAgent(Seq2SeqAgent):
if stop_node is not None and obs[i]['viewpoint'] != stop_node:
traj[i]['path'].append(gmaps[i].graph.path(obs[i]['viewpoint'], stop_node))
traj[i]['pred_objid'] = stop_score['og']
if stop_score['og'] == -1:
traj[i]['found'] = False
else:
traj[i]['found'] = True
if self.args.detailed_output:
for k, v in gmaps[i].node_stop_scores.items():
traj[i]['details'][k] = {
@ -537,10 +492,4 @@ class GMapObjectNavAgent(Seq2SeqAgent):
self.logs['IL_loss'].append(ml_loss.item())
self.logs['OG_loss'].append(og_loss.item())
'''
print("TRAJ:")
for i in traj:
print(" GT: {}, PREDICT: {}, SCORE: {}".format(i['gt_found'], i['found'], 1 if i['gt_found']==i['found'] else 0))
'''
return traj

View File

@ -87,7 +87,6 @@ def construct_instrs(anno_dir, dataset, splits, tokenizer, max_instr_len=512):
new_item['objId'] = None
new_item['instruction'] = instr
new_item['instr_encoding'] = item['instr_encodings'][j][:max_instr_len]
new_item['found'] = item['found'][j]
del new_item['instructions']
del new_item['instr_encodings']
data.append(new_item)

View File

@ -311,7 +311,6 @@ class ReverieObjectNavBatch(object):
'navigableLocations' : state.navigableLocations,
'instruction' : item['instruction'],
'instr_encoding': item['instr_encoding'],
'gt_found' : item['found'],
'gt_path' : item['path'],
'gt_end_vps': item.get('end_vps', []),
'gt_obj_id': item['objId'],
@ -352,7 +351,7 @@ class ReverieObjectNavBatch(object):
############### Nav Evaluation ###############
def _eval_item(self, scan, pred_path, pred_objid, gt_path, gt_objid, pred_found, gt_found):
def _eval_item(self, scan, pred_path, pred_objid, gt_path, gt_objid):
scores = {}
shortest_distances = self.shortest_distances[scan]
@ -370,10 +369,8 @@ class ReverieObjectNavBatch(object):
assert len(goal_viewpoints) > 0, '%s_%s'%(scan, str(gt_objid))
scores['success'] = float(path[-1] in goal_viewpoints)
scores['found_success'] = float(pred_found == gt_found)
scores['oracle_success'] = float(any(x in goal_viewpoints for x in path))
scores['spl'] = scores['success'] * gt_lengths / max(scores['trajectory_lengths'], gt_lengths, 0.01)
scores['sspl'] = scores['spl'] * scores['found_success']
scores['rgs'] = str(pred_objid) == str(gt_objid)
scores['rgspl'] = scores['rgs'] * gt_lengths / max(scores['trajectory_lengths'], gt_lengths, 0.01)
@ -383,7 +380,6 @@ class ReverieObjectNavBatch(object):
''' Evaluate each agent trajectory based on how close it got to the goal location
the path contains [view_id, angle, vofv]'''
print('eval %d predictions' % (len(preds)))
print(preds[0])
metrics = defaultdict(list)
for item in preds:
@ -391,9 +387,7 @@ class ReverieObjectNavBatch(object):
traj = item['trajectory']
pred_objid = item.get('pred_objid', None)
scan, gt_traj, gt_objid = self.gt_trajs[instr_id]
pred_found = item['found']
gt_found = item['gt_found']
traj_scores = self._eval_item(scan, traj, pred_objid, gt_traj, gt_objid, pred_found, gt_found)
traj_scores = self._eval_item(scan, traj, pred_objid, gt_traj, gt_objid)
for k, v in traj_scores.items():
metrics[k].append(v)
metrics['instr_id'].append(instr_id)
@ -407,8 +401,6 @@ class ReverieObjectNavBatch(object):
'spl': np.mean(metrics['spl']) * 100,
'rgs': np.mean(metrics['rgs']) * 100,
'rgspl': np.mean(metrics['rgspl']) * 100,
'sspl': np.mean(metrics['sspl']) * 100,
'found_sr': np.mean(metrics['found_success']) * 100,
}
return avg_metrics, metrics

View File

@ -65,8 +65,8 @@ def build_dataset(args, rank=0):
multi_endpoints=args.multi_endpoints, multi_startpoints=args.multi_startpoints,
)
# val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']
val_env_names = [ 'val_seen', 'val_unseen']
# val_env_names = ['val_train_seen']
val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']
if args.submit:
val_env_names.append('test')
@ -136,7 +136,7 @@ def train(args, train_env, val_envs, aug_env=None, rank=-1):
'\nListener training starts, start iteration: %s' % str(start_iter), record_file
)
best_val = {'val_unseen': {"spl": 0., "sr": 0., "state":"", "sspl": 0., 'found_sr': 0.}}
best_val = {'val_unseen': {"spl": 0., "sr": 0., "state":""}}
for idx in range(start_iter, start_iter+args.iters, args.log_every):
listner.logs = defaultdict(list)
@ -201,11 +201,9 @@ def train(args, train_env, val_envs, aug_env=None, rank=-1):
# select model by spl
if env_name in best_val:
if score_summary['sspl'] >= best_val[env_name]['sspl']:
if score_summary['spl'] >= best_val[env_name]['spl']:
best_val[env_name]['spl'] = score_summary['spl']
best_val[env_name]['sspl'] = score_summary['sspl']
best_val[env_name]['sr'] = score_summary['sr']
best_val[env_name]['found_sr'] = score_summary['found_sr']
best_val[env_name]['state'] = 'Iter %d %s' % (iter, loss_str)
listner.save(idx, os.path.join(args.ckpt_dir, "best_%s" % (env_name)))