feat: add new metric (found_sr & SSPL)
This commit is contained in:
parent
fb82daf16a
commit
b2dce6111e
@ -27,7 +27,7 @@ class BaseAgent(object):
|
||||
def get_results(self, detailed_output=False):
|
||||
output = []
|
||||
for k, v in self.results.items():
|
||||
output.append({'instr_id': k, 'trajectory': v['path'], 'pred_objid': v['pred_objid']})
|
||||
output.append({'instr_id': k, 'trajectory': v['path'], 'pred_objid': v['pred_objid'], 'found': v['found'], 'gt_found': v['gt_found']})
|
||||
if detailed_output:
|
||||
output[-1]['details'] = v['details']
|
||||
return output
|
||||
|
||||
@ -352,7 +352,7 @@ class ReverieObjectNavBatch(object):
|
||||
|
||||
|
||||
############### Nav Evaluation ###############
|
||||
def _eval_item(self, scan, pred_path, pred_objid, gt_path, gt_objid):
|
||||
def _eval_item(self, scan, pred_path, pred_objid, gt_path, gt_objid, pred_found, gt_found):
|
||||
scores = {}
|
||||
|
||||
shortest_distances = self.shortest_distances[scan]
|
||||
@ -370,8 +370,10 @@ class ReverieObjectNavBatch(object):
|
||||
assert len(goal_viewpoints) > 0, '%s_%s'%(scan, str(gt_objid))
|
||||
|
||||
scores['success'] = float(path[-1] in goal_viewpoints)
|
||||
scores['found_success'] = float(pred_found == gt_found)
|
||||
scores['oracle_success'] = float(any(x in goal_viewpoints for x in path))
|
||||
scores['spl'] = scores['success'] * gt_lengths / max(scores['trajectory_lengths'], gt_lengths, 0.01)
|
||||
scores['sspl'] = scores['spl'] * scores['found_success']
|
||||
|
||||
scores['rgs'] = str(pred_objid) == str(gt_objid)
|
||||
scores['rgspl'] = scores['rgs'] * gt_lengths / max(scores['trajectory_lengths'], gt_lengths, 0.01)
|
||||
@ -381,6 +383,7 @@ class ReverieObjectNavBatch(object):
|
||||
''' Evaluate each agent trajectory based on how close it got to the goal location
|
||||
the path contains [view_id, angle, vofv]'''
|
||||
print('eval %d predictions' % (len(preds)))
|
||||
print(preds[0])
|
||||
|
||||
metrics = defaultdict(list)
|
||||
for item in preds:
|
||||
@ -388,7 +391,9 @@ class ReverieObjectNavBatch(object):
|
||||
traj = item['trajectory']
|
||||
pred_objid = item.get('pred_objid', None)
|
||||
scan, gt_traj, gt_objid = self.gt_trajs[instr_id]
|
||||
traj_scores = self._eval_item(scan, traj, pred_objid, gt_traj, gt_objid)
|
||||
pred_found = item['found']
|
||||
gt_found = item['gt_found']
|
||||
traj_scores = self._eval_item(scan, traj, pred_objid, gt_traj, gt_objid, pred_found, gt_found)
|
||||
for k, v in traj_scores.items():
|
||||
metrics[k].append(v)
|
||||
metrics['instr_id'].append(instr_id)
|
||||
@ -402,6 +407,8 @@ class ReverieObjectNavBatch(object):
|
||||
'spl': np.mean(metrics['spl']) * 100,
|
||||
'rgs': np.mean(metrics['rgs']) * 100,
|
||||
'rgspl': np.mean(metrics['rgspl']) * 100,
|
||||
'sspl': np.mean(metrics['sspl']) * 100,
|
||||
'found_sr': np.mean(metrics['found_success']) * 100,
|
||||
}
|
||||
return avg_metrics, metrics
|
||||
|
||||
|
||||
@ -136,7 +136,7 @@ def train(args, train_env, val_envs, aug_env=None, rank=-1):
|
||||
'\nListener training starts, start iteration: %s' % str(start_iter), record_file
|
||||
)
|
||||
|
||||
best_val = {'val_unseen': {"spl": 0., "sr": 0., "state":""}}
|
||||
best_val = {'val_unseen': {"spl": 0., "sr": 0., "state":"", "sspl": 0., 'found_sr': 0.}}
|
||||
|
||||
for idx in range(start_iter, start_iter+args.iters, args.log_every):
|
||||
listner.logs = defaultdict(list)
|
||||
@ -201,9 +201,11 @@ def train(args, train_env, val_envs, aug_env=None, rank=-1):
|
||||
|
||||
# select model by spl
|
||||
if env_name in best_val:
|
||||
if score_summary['spl'] >= best_val[env_name]['spl']:
|
||||
if score_summary['sspl'] >= best_val[env_name]['sspl']:
|
||||
best_val[env_name]['spl'] = score_summary['spl']
|
||||
best_val[env_name]['sspl'] = score_summary['sspl']
|
||||
best_val[env_name]['sr'] = score_summary['sr']
|
||||
best_val[env_name]['found_sr'] = score_summary['found_sr']
|
||||
best_val[env_name]['state'] = 'Iter %d %s' % (iter, loss_str)
|
||||
listner.save(idx, os.path.join(args.ckpt_dir, "best_%s" % (env_name)))
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user