feat: add new metrics (found_sr & SSPL)

Author: Ting-Jun Wang
Date: 2023-12-11 14:34:45 +08:00
Parent: fb82daf16a
Commit: b2dce6111e
Signed by: snsd0805 (GPG Key ID: 48D331A3D6160354)
3 changed files with 14 additions and 5 deletions
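In short: per episode, found_success is 1.0 when the agent's declared "found" flag matches the ground-truth flag, SSPL multiplies SPL by that indicator, and found_sr reports the mean of found_success as a percentage. A minimal standalone sketch of the per-episode definitions, using the names from the diffs below:

    # Per-episode quantities added by this commit (standalone sketch).
    def found_success(pred_found, gt_found):
        # 1.0 iff the agent's "found" declaration matches ground truth
        return float(pred_found == gt_found)

    def sspl(spl, pred_found, gt_found):
        # SPL gated by a correct "found" declaration
        return spl * found_success(pred_found, gt_found)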


@@ -27,7 +27,7 @@ class BaseAgent(object):
     def get_results(self, detailed_output=False):
         output = []
         for k, v in self.results.items():
-            output.append({'instr_id': k, 'trajectory': v['path'], 'pred_objid': v['pred_objid']})
+            output.append({'instr_id': k, 'trajectory': v['path'], 'pred_objid': v['pred_objid'], 'found': v['found'], 'gt_found': v['gt_found']})
             if detailed_output:
                 output[-1]['details'] = v['details']
         return output
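Each entry returned by get_results now carries the agent's stop declaration next to the trajectory. An illustrative entry (all values hypothetical):

    # Hypothetical get_results() entry after this change; values are made up.
    entry = {
        'instr_id': '4774_267',   # hypothetical instruction id
        'trajectory': [...],      # the episode's path, as before
        'pred_objid': '42',       # predicted object id, as before
        'found': True,            # agent's own "found" declaration
        'gt_found': True,         # ground-truth flag it is scored against
    }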


@@ -352,7 +352,7 @@ class ReverieObjectNavBatch(object):
     ############### Nav Evaluation ###############
-    def _eval_item(self, scan, pred_path, pred_objid, gt_path, gt_objid):
+    def _eval_item(self, scan, pred_path, pred_objid, gt_path, gt_objid, pred_found, gt_found):
         scores = {}
         shortest_distances = self.shortest_distances[scan]
@@ -370,8 +370,10 @@ class ReverieObjectNavBatch(object):
         assert len(goal_viewpoints) > 0, '%s_%s'%(scan, str(gt_objid))
         scores['success'] = float(path[-1] in goal_viewpoints)
+        scores['found_success'] = float(pred_found == gt_found)
         scores['oracle_success'] = float(any(x in goal_viewpoints for x in path))
         scores['spl'] = scores['success'] * gt_lengths / max(scores['trajectory_lengths'], gt_lengths, 0.01)
+        scores['sspl'] = scores['spl'] * scores['found_success']
         scores['rgs'] = str(pred_objid) == str(gt_objid)
         scores['rgspl'] = scores['rgs'] * gt_lengths / max(scores['trajectory_lengths'], gt_lengths, 0.01)
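A quick check of the gating behaviour (the SPL value is illustrative): a correct declaration leaves SPL untouched, while a wrong one zeroes SSPL even if navigation itself succeeded.

    # sspl gating sanity check; the spl value is illustrative
    spl = 0.8
    assert spl * float(True == True) == 0.8    # declaration correct: sspl == spl
    assert spl * float(False == True) == 0.0   # declaration wrong:   sspl == 0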
@@ -381,6 +383,7 @@ class ReverieObjectNavBatch(object):
         ''' Evaluate each agent trajectory based on how close it got to the goal location
         the path contains [view_id, angle, vofv]'''
         print('eval %d predictions' % (len(preds)))
+        print(preds[0])

         metrics = defaultdict(list)
         for item in preds:
@@ -388,7 +391,9 @@ class ReverieObjectNavBatch(object):
             traj = item['trajectory']
             pred_objid = item.get('pred_objid', None)
             scan, gt_traj, gt_objid = self.gt_trajs[instr_id]
-            traj_scores = self._eval_item(scan, traj, pred_objid, gt_traj, gt_objid)
+            pred_found = item['found']
+            gt_found = item['gt_found']
+            traj_scores = self._eval_item(scan, traj, pred_objid, gt_traj, gt_objid, pred_found, gt_found)
             for k, v in traj_scores.items():
                 metrics[k].append(v)
             metrics['instr_id'].append(instr_id)
@@ -402,6 +407,8 @@ class ReverieObjectNavBatch(object):
             'spl': np.mean(metrics['spl']) * 100,
             'rgs': np.mean(metrics['rgs']) * 100,
             'rgspl': np.mean(metrics['rgspl']) * 100,
+            'sspl': np.mean(metrics['sspl']) * 100,
+            'found_sr': np.mean(metrics['found_success']) * 100,
         }
         return avg_metrics, metrics
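A toy check of the aggregation just added, with made-up per-episode lists (mirroring how eval_metrics fills metrics):

    # Toy aggregation check; the per-episode values are made up.
    import numpy as np
    metrics = {'found_success': [1.0, 0.0, 1.0], 'sspl': [0.8, 0.0, 0.5]}
    print(np.mean(metrics['found_success']) * 100)   # 66.66...: found_sr
    print(np.mean(metrics['sspl']) * 100)            # 43.33...: sspl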


@@ -136,7 +136,7 @@ def train(args, train_env, val_envs, aug_env=None, rank=-1):
         '\nListener training starts, start iteration: %s' % str(start_iter), record_file
     )
-    best_val = {'val_unseen': {"spl": 0., "sr": 0., "state":""}}
+    best_val = {'val_unseen': {"spl": 0., "sr": 0., "state":"", "sspl": 0., 'found_sr': 0.}}

     for idx in range(start_iter, start_iter+args.iters, args.log_every):
         listner.logs = defaultdict(list)
@@ -201,9 +201,11 @@ def train(args, train_env, val_envs, aug_env=None, rank=-1):
                 # select model by spl
                 if env_name in best_val:
-                    if score_summary['spl'] >= best_val[env_name]['spl']:
+                    if score_summary['sspl'] >= best_val[env_name]['sspl']:
                         best_val[env_name]['spl'] = score_summary['spl']
+                        best_val[env_name]['sspl'] = score_summary['sspl']
                         best_val[env_name]['sr'] = score_summary['sr']
+                        best_val[env_name]['found_sr'] = score_summary['found_sr']
                         best_val[env_name]['state'] = 'Iter %d %s' % (iter, loss_str)
                         listner.save(idx, os.path.join(args.ckpt_dir, "best_%s" % (env_name)))
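Checkpoint selection now keys on sspl rather than spl. The selection logic above, extracted into a standalone sketch (the helper name and score values are hypothetical):

    # Hypothetical helper mirroring the updated best-checkpoint rule.
    best_val = {'val_unseen': {'spl': 0., 'sr': 0., 'state': '', 'sspl': 0., 'found_sr': 0.}}

    def maybe_update_best(env_name, score_summary, state):
        # Returns True when the caller should save a new "best" checkpoint.
        if env_name in best_val and score_summary['sspl'] >= best_val[env_name]['sspl']:
            for k in ('spl', 'sspl', 'sr', 'found_sr'):
                best_val[env_name][k] = score_summary[k]
            best_val[env_name]['state'] = state
            return True
        return False

    # Illustrative call with made-up validation scores:
    maybe_update_best('val_unseen', {'spl': 50.0, 'sspl': 42.0, 'sr': 55.0, 'found_sr': 80.0}, 'Iter 1000')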