feat: add new metric (found_sr & SSPL)
This commit is contained in:
parent
fb82daf16a
commit
b2dce6111e
@ -27,7 +27,7 @@ class BaseAgent(object):
|
|||||||
def get_results(self, detailed_output=False):
|
def get_results(self, detailed_output=False):
|
||||||
output = []
|
output = []
|
||||||
for k, v in self.results.items():
|
for k, v in self.results.items():
|
||||||
output.append({'instr_id': k, 'trajectory': v['path'], 'pred_objid': v['pred_objid']})
|
output.append({'instr_id': k, 'trajectory': v['path'], 'pred_objid': v['pred_objid'], 'found': v['found'], 'gt_found': v['gt_found']})
|
||||||
if detailed_output:
|
if detailed_output:
|
||||||
output[-1]['details'] = v['details']
|
output[-1]['details'] = v['details']
|
||||||
return output
|
return output
|
||||||
|
|||||||
@ -352,7 +352,7 @@ class ReverieObjectNavBatch(object):
|
|||||||
|
|
||||||
|
|
||||||
############### Nav Evaluation ###############
|
############### Nav Evaluation ###############
|
||||||
def _eval_item(self, scan, pred_path, pred_objid, gt_path, gt_objid):
|
def _eval_item(self, scan, pred_path, pred_objid, gt_path, gt_objid, pred_found, gt_found):
|
||||||
scores = {}
|
scores = {}
|
||||||
|
|
||||||
shortest_distances = self.shortest_distances[scan]
|
shortest_distances = self.shortest_distances[scan]
|
||||||
@ -370,8 +370,10 @@ class ReverieObjectNavBatch(object):
|
|||||||
assert len(goal_viewpoints) > 0, '%s_%s'%(scan, str(gt_objid))
|
assert len(goal_viewpoints) > 0, '%s_%s'%(scan, str(gt_objid))
|
||||||
|
|
||||||
scores['success'] = float(path[-1] in goal_viewpoints)
|
scores['success'] = float(path[-1] in goal_viewpoints)
|
||||||
|
scores['found_success'] = float(pred_found == gt_found)
|
||||||
scores['oracle_success'] = float(any(x in goal_viewpoints for x in path))
|
scores['oracle_success'] = float(any(x in goal_viewpoints for x in path))
|
||||||
scores['spl'] = scores['success'] * gt_lengths / max(scores['trajectory_lengths'], gt_lengths, 0.01)
|
scores['spl'] = scores['success'] * gt_lengths / max(scores['trajectory_lengths'], gt_lengths, 0.01)
|
||||||
|
scores['sspl'] = scores['spl'] * scores['found_success']
|
||||||
|
|
||||||
scores['rgs'] = str(pred_objid) == str(gt_objid)
|
scores['rgs'] = str(pred_objid) == str(gt_objid)
|
||||||
scores['rgspl'] = scores['rgs'] * gt_lengths / max(scores['trajectory_lengths'], gt_lengths, 0.01)
|
scores['rgspl'] = scores['rgs'] * gt_lengths / max(scores['trajectory_lengths'], gt_lengths, 0.01)
|
||||||
@ -381,6 +383,7 @@ class ReverieObjectNavBatch(object):
|
|||||||
''' Evaluate each agent trajectory based on how close it got to the goal location
|
''' Evaluate each agent trajectory based on how close it got to the goal location
|
||||||
the path contains [view_id, angle, vofv]'''
|
the path contains [view_id, angle, vofv]'''
|
||||||
print('eval %d predictions' % (len(preds)))
|
print('eval %d predictions' % (len(preds)))
|
||||||
|
print(preds[0])
|
||||||
|
|
||||||
metrics = defaultdict(list)
|
metrics = defaultdict(list)
|
||||||
for item in preds:
|
for item in preds:
|
||||||
@ -388,7 +391,9 @@ class ReverieObjectNavBatch(object):
|
|||||||
traj = item['trajectory']
|
traj = item['trajectory']
|
||||||
pred_objid = item.get('pred_objid', None)
|
pred_objid = item.get('pred_objid', None)
|
||||||
scan, gt_traj, gt_objid = self.gt_trajs[instr_id]
|
scan, gt_traj, gt_objid = self.gt_trajs[instr_id]
|
||||||
traj_scores = self._eval_item(scan, traj, pred_objid, gt_traj, gt_objid)
|
pred_found = item['found']
|
||||||
|
gt_found = item['gt_found']
|
||||||
|
traj_scores = self._eval_item(scan, traj, pred_objid, gt_traj, gt_objid, pred_found, gt_found)
|
||||||
for k, v in traj_scores.items():
|
for k, v in traj_scores.items():
|
||||||
metrics[k].append(v)
|
metrics[k].append(v)
|
||||||
metrics['instr_id'].append(instr_id)
|
metrics['instr_id'].append(instr_id)
|
||||||
@ -402,6 +407,8 @@ class ReverieObjectNavBatch(object):
|
|||||||
'spl': np.mean(metrics['spl']) * 100,
|
'spl': np.mean(metrics['spl']) * 100,
|
||||||
'rgs': np.mean(metrics['rgs']) * 100,
|
'rgs': np.mean(metrics['rgs']) * 100,
|
||||||
'rgspl': np.mean(metrics['rgspl']) * 100,
|
'rgspl': np.mean(metrics['rgspl']) * 100,
|
||||||
|
'sspl': np.mean(metrics['sspl']) * 100,
|
||||||
|
'found_sr': np.mean(metrics['found_success']) * 100,
|
||||||
}
|
}
|
||||||
return avg_metrics, metrics
|
return avg_metrics, metrics
|
||||||
|
|
||||||
|
|||||||
@ -136,7 +136,7 @@ def train(args, train_env, val_envs, aug_env=None, rank=-1):
|
|||||||
'\nListener training starts, start iteration: %s' % str(start_iter), record_file
|
'\nListener training starts, start iteration: %s' % str(start_iter), record_file
|
||||||
)
|
)
|
||||||
|
|
||||||
best_val = {'val_unseen': {"spl": 0., "sr": 0., "state":""}}
|
best_val = {'val_unseen': {"spl": 0., "sr": 0., "state":"", "sspl": 0., 'found_sr': 0.}}
|
||||||
|
|
||||||
for idx in range(start_iter, start_iter+args.iters, args.log_every):
|
for idx in range(start_iter, start_iter+args.iters, args.log_every):
|
||||||
listner.logs = defaultdict(list)
|
listner.logs = defaultdict(list)
|
||||||
@ -201,9 +201,11 @@ def train(args, train_env, val_envs, aug_env=None, rank=-1):
|
|||||||
|
|
||||||
# select model by spl
|
# select model by spl
|
||||||
if env_name in best_val:
|
if env_name in best_val:
|
||||||
if score_summary['spl'] >= best_val[env_name]['spl']:
|
if score_summary['sspl'] >= best_val[env_name]['sspl']:
|
||||||
best_val[env_name]['spl'] = score_summary['spl']
|
best_val[env_name]['spl'] = score_summary['spl']
|
||||||
|
best_val[env_name]['sspl'] = score_summary['sspl']
|
||||||
best_val[env_name]['sr'] = score_summary['sr']
|
best_val[env_name]['sr'] = score_summary['sr']
|
||||||
|
best_val[env_name]['found_sr'] = score_summary['found_sr']
|
||||||
best_val[env_name]['state'] = 'Iter %d %s' % (iter, loss_str)
|
best_val[env_name]['state'] = 'Iter %d %s' % (iter, loss_str)
|
||||||
listner.save(idx, os.path.join(args.ckpt_dir, "best_%s" % (env_name)))
|
listner.save(idx, os.path.join(args.ckpt_dir, "best_%s" % (env_name)))
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user