Compare commits
No commits in common. "b2dce6111efffd728419a83bfb6cc46070bd179b" and "93e8b233164bc079a6db48b8a0a78d123ec8de41" have entirely different histories.
b2dce6111e
...
93e8b23316
@ -27,7 +27,7 @@ class BaseAgent(object):
|
||||
def get_results(self, detailed_output=False):
|
||||
output = []
|
||||
for k, v in self.results.items():
|
||||
output.append({'instr_id': k, 'trajectory': v['path'], 'pred_objid': v['pred_objid'], 'found': v['found'], 'gt_found': v['gt_found']})
|
||||
output.append({'instr_id': k, 'trajectory': v['path'], 'pred_objid': v['pred_objid']})
|
||||
if detailed_output:
|
||||
output[-1]['details'] = v['details']
|
||||
return output
|
||||
|
||||
@ -174,14 +174,11 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
||||
|
||||
def _nav_vp_variable(self, obs, gmaps, pano_embeds, cand_vpids, view_lens, obj_lens, nav_types):
|
||||
batch_size = len(obs)
|
||||
# print("PANO shape", pano_embeds.shape)
|
||||
|
||||
# add [stop] token & [NOT FOUND] token
|
||||
# [STOP] 在最前面, [NOT FOUND] 在最後面
|
||||
# add [stop] token
|
||||
vp_img_embeds = torch.cat(
|
||||
[torch.zeros_like(pano_embeds[:, :1]), pano_embeds, torch.ones_like(pano_embeds[:, :1])], 1
|
||||
[torch.zeros_like(pano_embeds[:, :1]), pano_embeds], 1
|
||||
)
|
||||
# print("SHAPE:", vp_img_embeds.shape)
|
||||
|
||||
batch_vp_pos_fts = []
|
||||
for i, gmap in enumerate(gmaps):
|
||||
@ -195,33 +192,19 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
||||
)
|
||||
# add [stop] token at beginning
|
||||
vp_pos_fts = np.zeros((vp_img_embeds.size(1), 14), dtype=np.float32)
|
||||
# print("vp_pos_fts:", vp_pos_fts.shape)
|
||||
|
||||
vp_pos_fts[:, :7] = cur_start_pos_fts
|
||||
# print("vp_pos_fts[:, :7]:", vp_pos_fts[:, :7].shape)
|
||||
# print("cur_start_pos_fts:", cur_start_pos_fts.shape)
|
||||
|
||||
vp_pos_fts[1:len(cur_cand_pos_fts)+1, 7:] = cur_cand_pos_fts
|
||||
# print("vp_pos_fts[1:len(), 7:]:", vp_pos_fts[1:len(cur_cand_pos_fts)+1, 7:].shape)
|
||||
# print("cur_cand_pos_fts:", cur_cand_pos_fts.shape)
|
||||
|
||||
batch_vp_pos_fts.append(torch.from_numpy(vp_pos_fts))
|
||||
|
||||
batch_vp_pos_fts = pad_tensors(batch_vp_pos_fts).cuda()
|
||||
|
||||
# 要把 stop 和 not found 的 mask 補上去
|
||||
# 這邊把 stop 跟 candidate 放一起、把 not found 跟 object 放一起
|
||||
vp_nav_masks = torch.cat([torch.ones(batch_size, 1).bool().cuda(), nav_types == 1, torch.zeros(batch_size, 1).bool().cuda()], 1)
|
||||
vp_obj_masks = torch.cat([torch.zeros(batch_size, 1).bool().cuda(), nav_types == 2, torch.ones(batch_size, 1).bool().cuda()], 1)
|
||||
# print('vp_nav_masks:', vp_nav_masks.shape)
|
||||
# print('vp_obj_masks:', vp_obj_masks.shape)
|
||||
vp_masks = gen_seq_masks(view_lens+obj_lens+2)
|
||||
# print()
|
||||
vp_nav_masks = torch.cat([torch.ones(batch_size, 1).bool().cuda(), nav_types == 1], 1)
|
||||
vp_obj_masks = torch.cat([torch.zeros(batch_size, 1).bool().cuda(), nav_types == 2], 1)
|
||||
|
||||
return {
|
||||
'vp_img_embeds': vp_img_embeds,
|
||||
'vp_pos_fts': batch_vp_pos_fts,
|
||||
'vp_masks': vp_masks,
|
||||
'vp_masks': gen_seq_masks(view_lens+obj_lens+1),
|
||||
'vp_nav_masks': vp_nav_masks,
|
||||
'vp_obj_masks': vp_obj_masks,
|
||||
'vp_cand_vpids': [[None]+x for x in cand_vpids],
|
||||
@ -259,7 +242,7 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
||||
|
||||
return torch.from_numpy(a).cuda()
|
||||
|
||||
def _teacher_object(self, obs, ended, view_lens, obj_logits):
|
||||
def _teacher_object(self, obs, ended, view_lens):
|
||||
targets = np.zeros(len(obs), dtype=np.int64)
|
||||
for i, ob in enumerate(obs):
|
||||
if ended[i]:
|
||||
@ -269,18 +252,12 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
||||
if i_vp not in ob['gt_end_vps']:
|
||||
targets[i] = self.args.ignoreid
|
||||
else:
|
||||
|
||||
i_objids = ob['obj_ids']
|
||||
targets[i] = self.args.ignoreid
|
||||
for j, obj_id in enumerate(i_objids):
|
||||
if str(obj_id) == str(ob['gt_obj_id']):
|
||||
|
||||
if ob['gt_found'] == True: # 可以找得到
|
||||
targets[i] = j + view_lens[i] + 1
|
||||
else:
|
||||
targets[i] = len(obj_logits[i])-1 # 不能找到,
|
||||
targets[i] = j + view_lens[i] + 1
|
||||
break
|
||||
|
||||
return torch.from_numpy(targets).cuda()
|
||||
|
||||
def make_equiv_action(self, a_t, gmaps, obs, traj=None):
|
||||
@ -321,8 +298,7 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
||||
|
||||
batch_size = len(obs)
|
||||
# build graph: keep the start viewpoint
|
||||
|
||||
gmaps = [GraphMap(ob['viewpoint']) for ob in obs] # input the start point
|
||||
gmaps = [GraphMap(ob['viewpoint']) for ob in obs]
|
||||
for i, ob in enumerate(obs):
|
||||
gmaps[i].update_graph(ob)
|
||||
|
||||
@ -331,9 +307,6 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
||||
'instr_id': ob['instr_id'],
|
||||
'path': [[ob['viewpoint']]],
|
||||
'pred_objid': None,
|
||||
'gt_objid': None,
|
||||
'found': None,
|
||||
'gt_found': None,
|
||||
'details': {},
|
||||
} for ob in obs]
|
||||
|
||||
@ -406,22 +379,10 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
||||
i_vp = obs[i]['viewpoint']
|
||||
# update i_vp: stop and object grounding scores
|
||||
i_objids = obs[i]['obj_ids']
|
||||
i_obj_logits = obj_logits[i, pano_inputs['view_lens'][i]+1:] # 最後一個是 not found
|
||||
|
||||
if len(i_objids) > 0:
|
||||
if torch.argmax(i_obj_logits) >= len(i_objids): # not found 那格 logit 最大(會在最後一格)
|
||||
og = -1
|
||||
else:
|
||||
og = i_objids[torch.argmax(i_obj_logits)]
|
||||
else:
|
||||
og = None
|
||||
|
||||
# 如果有找到,og 會是 object id
|
||||
# 如果是 not found,og 會是 -1
|
||||
# 如果這個 viewpoint 看不到物件,og 會是 None
|
||||
i_obj_logits = obj_logits[i, pano_inputs['view_lens'][i]+1:]
|
||||
gmap.node_stop_scores[i_vp] = {
|
||||
'stop': nav_probs[i, 0].data.item(),
|
||||
'og': og,
|
||||
'og': i_objids[torch.argmax(i_obj_logits)] if len(i_objids) > 0 else None,
|
||||
'og_details': {'objids': i_objids, 'logits': i_obj_logits[:len(i_objids)]},
|
||||
}
|
||||
|
||||
@ -442,7 +403,7 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
||||
)
|
||||
ml_loss += self.criterion(nav_outs['local_logits'], local_nav_targets) # local
|
||||
# objec grounding
|
||||
obj_targets = self._teacher_object(obs, ended, pano_inputs['view_lens'], obj_logits)
|
||||
obj_targets = self._teacher_object(obs, ended, pano_inputs['view_lens'])
|
||||
# print(t, obj_targets[6], obj_logits[6], obs[6]['obj_ids'], pano_inputs['view_lens'][i], obs[6]['gt_obj_id'])
|
||||
og_loss += self.criterion(obj_logits, obj_targets)
|
||||
# print(F.cross_entropy(obj_logits, obj_targets, reduction='none'))
|
||||
@ -490,11 +451,9 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
||||
else:
|
||||
cpu_a_t.append(nav_vpids[i][a_t[i]])
|
||||
|
||||
original_gt_founds = [ ob['gt_found'] for ob in obs ]
|
||||
# Make action and get the new state
|
||||
self.make_equiv_action(cpu_a_t, gmaps, obs, traj)
|
||||
for i in range(batch_size):
|
||||
traj[i]['gt_found'] = original_gt_founds[i]
|
||||
if (not ended[i]) and just_ended[i]:
|
||||
stop_node, stop_score = None, {'stop': -float('inf'), 'og': None}
|
||||
for k, v in gmaps[i].node_stop_scores.items():
|
||||
@ -504,10 +463,6 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
||||
if stop_node is not None and obs[i]['viewpoint'] != stop_node:
|
||||
traj[i]['path'].append(gmaps[i].graph.path(obs[i]['viewpoint'], stop_node))
|
||||
traj[i]['pred_objid'] = stop_score['og']
|
||||
if stop_score['og'] == -1:
|
||||
traj[i]['found'] = False
|
||||
else:
|
||||
traj[i]['found'] = True
|
||||
if self.args.detailed_output:
|
||||
for k, v in gmaps[i].node_stop_scores.items():
|
||||
traj[i]['details'][k] = {
|
||||
@ -537,10 +492,4 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
||||
self.logs['IL_loss'].append(ml_loss.item())
|
||||
self.logs['OG_loss'].append(og_loss.item())
|
||||
|
||||
'''
|
||||
print("TRAJ:")
|
||||
for i in traj:
|
||||
print(" GT: {}, PREDICT: {}, SCORE: {}".format(i['gt_found'], i['found'], 1 if i['gt_found']==i['found'] else 0))
|
||||
|
||||
'''
|
||||
return traj
|
||||
|
||||
@ -87,7 +87,6 @@ def construct_instrs(anno_dir, dataset, splits, tokenizer, max_instr_len=512):
|
||||
new_item['objId'] = None
|
||||
new_item['instruction'] = instr
|
||||
new_item['instr_encoding'] = item['instr_encodings'][j][:max_instr_len]
|
||||
new_item['found'] = item['found'][j]
|
||||
del new_item['instructions']
|
||||
del new_item['instr_encodings']
|
||||
data.append(new_item)
|
||||
|
||||
@ -311,7 +311,6 @@ class ReverieObjectNavBatch(object):
|
||||
'navigableLocations' : state.navigableLocations,
|
||||
'instruction' : item['instruction'],
|
||||
'instr_encoding': item['instr_encoding'],
|
||||
'gt_found' : item['found'],
|
||||
'gt_path' : item['path'],
|
||||
'gt_end_vps': item.get('end_vps', []),
|
||||
'gt_obj_id': item['objId'],
|
||||
@ -352,7 +351,7 @@ class ReverieObjectNavBatch(object):
|
||||
|
||||
|
||||
############### Nav Evaluation ###############
|
||||
def _eval_item(self, scan, pred_path, pred_objid, gt_path, gt_objid, pred_found, gt_found):
|
||||
def _eval_item(self, scan, pred_path, pred_objid, gt_path, gt_objid):
|
||||
scores = {}
|
||||
|
||||
shortest_distances = self.shortest_distances[scan]
|
||||
@ -370,10 +369,8 @@ class ReverieObjectNavBatch(object):
|
||||
assert len(goal_viewpoints) > 0, '%s_%s'%(scan, str(gt_objid))
|
||||
|
||||
scores['success'] = float(path[-1] in goal_viewpoints)
|
||||
scores['found_success'] = float(pred_found == gt_found)
|
||||
scores['oracle_success'] = float(any(x in goal_viewpoints for x in path))
|
||||
scores['spl'] = scores['success'] * gt_lengths / max(scores['trajectory_lengths'], gt_lengths, 0.01)
|
||||
scores['sspl'] = scores['spl'] * scores['found_success']
|
||||
|
||||
scores['rgs'] = str(pred_objid) == str(gt_objid)
|
||||
scores['rgspl'] = scores['rgs'] * gt_lengths / max(scores['trajectory_lengths'], gt_lengths, 0.01)
|
||||
@ -383,7 +380,6 @@ class ReverieObjectNavBatch(object):
|
||||
''' Evaluate each agent trajectory based on how close it got to the goal location
|
||||
the path contains [view_id, angle, vofv]'''
|
||||
print('eval %d predictions' % (len(preds)))
|
||||
print(preds[0])
|
||||
|
||||
metrics = defaultdict(list)
|
||||
for item in preds:
|
||||
@ -391,9 +387,7 @@ class ReverieObjectNavBatch(object):
|
||||
traj = item['trajectory']
|
||||
pred_objid = item.get('pred_objid', None)
|
||||
scan, gt_traj, gt_objid = self.gt_trajs[instr_id]
|
||||
pred_found = item['found']
|
||||
gt_found = item['gt_found']
|
||||
traj_scores = self._eval_item(scan, traj, pred_objid, gt_traj, gt_objid, pred_found, gt_found)
|
||||
traj_scores = self._eval_item(scan, traj, pred_objid, gt_traj, gt_objid)
|
||||
for k, v in traj_scores.items():
|
||||
metrics[k].append(v)
|
||||
metrics['instr_id'].append(instr_id)
|
||||
@ -407,8 +401,6 @@ class ReverieObjectNavBatch(object):
|
||||
'spl': np.mean(metrics['spl']) * 100,
|
||||
'rgs': np.mean(metrics['rgs']) * 100,
|
||||
'rgspl': np.mean(metrics['rgspl']) * 100,
|
||||
'sspl': np.mean(metrics['sspl']) * 100,
|
||||
'found_sr': np.mean(metrics['found_success']) * 100,
|
||||
}
|
||||
return avg_metrics, metrics
|
||||
|
||||
|
||||
@ -65,8 +65,8 @@ def build_dataset(args, rank=0):
|
||||
multi_endpoints=args.multi_endpoints, multi_startpoints=args.multi_startpoints,
|
||||
)
|
||||
|
||||
# val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']
|
||||
val_env_names = [ 'val_seen', 'val_unseen']
|
||||
# val_env_names = ['val_train_seen']
|
||||
val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']
|
||||
|
||||
if args.submit:
|
||||
val_env_names.append('test')
|
||||
@ -136,7 +136,7 @@ def train(args, train_env, val_envs, aug_env=None, rank=-1):
|
||||
'\nListener training starts, start iteration: %s' % str(start_iter), record_file
|
||||
)
|
||||
|
||||
best_val = {'val_unseen': {"spl": 0., "sr": 0., "state":"", "sspl": 0., 'found_sr': 0.}}
|
||||
best_val = {'val_unseen': {"spl": 0., "sr": 0., "state":""}}
|
||||
|
||||
for idx in range(start_iter, start_iter+args.iters, args.log_every):
|
||||
listner.logs = defaultdict(list)
|
||||
@ -201,11 +201,9 @@ def train(args, train_env, val_envs, aug_env=None, rank=-1):
|
||||
|
||||
# select model by spl
|
||||
if env_name in best_val:
|
||||
if score_summary['sspl'] >= best_val[env_name]['sspl']:
|
||||
if score_summary['spl'] >= best_val[env_name]['spl']:
|
||||
best_val[env_name]['spl'] = score_summary['spl']
|
||||
best_val[env_name]['sspl'] = score_summary['sspl']
|
||||
best_val[env_name]['sr'] = score_summary['sr']
|
||||
best_val[env_name]['found_sr'] = score_summary['found_sr']
|
||||
best_val[env_name]['state'] = 'Iter %d %s' % (iter, loss_str)
|
||||
listner.save(idx, os.path.join(args.ckpt_dir, "best_%s" % (env_name)))
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user