From b2dce6111efffd728419a83bfb6cc46070bd179b Mon Sep 17 00:00:00 2001
From: Ting-Jun Wang
Date: Mon, 11 Dec 2023 14:34:45 +0800
Subject: [PATCH] feat: add new metric (found_sr & SSPL)

---
 map_nav_src/reverie/agent_base.py   |  2 +-
 map_nav_src/reverie/env.py          | 11 +++++++++--
 map_nav_src/reverie/main_nav_obj.py |  6 ++++--
 3 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/map_nav_src/reverie/agent_base.py b/map_nav_src/reverie/agent_base.py
index e14570f..b18d5a6 100644
--- a/map_nav_src/reverie/agent_base.py
+++ b/map_nav_src/reverie/agent_base.py
@@ -27,7 +27,7 @@ class BaseAgent(object):
     def get_results(self, detailed_output=False):
         output = []
         for k, v in self.results.items():
-            output.append({'instr_id': k, 'trajectory': v['path'], 'pred_objid': v['pred_objid']})
+            output.append({'instr_id': k, 'trajectory': v['path'], 'pred_objid': v['pred_objid'], 'found': v['found'], 'gt_found': v['gt_found']})
             if detailed_output:
                 output[-1]['details'] = v['details']
         return output
diff --git a/map_nav_src/reverie/env.py b/map_nav_src/reverie/env.py
index e9fe4af..65e9848 100644
--- a/map_nav_src/reverie/env.py
+++ b/map_nav_src/reverie/env.py
@@ -352,7 +352,7 @@ class ReverieObjectNavBatch(object):
 
     ############### Nav Evaluation ###############
-    def _eval_item(self, scan, pred_path, pred_objid, gt_path, gt_objid):
+    def _eval_item(self, scan, pred_path, pred_objid, gt_path, gt_objid, pred_found, gt_found):
         scores = {}
 
         shortest_distances = self.shortest_distances[scan]
 
@@ -370,8 +370,10 @@
         assert len(goal_viewpoints) > 0, '%s_%s'%(scan, str(gt_objid))
 
         scores['success'] = float(path[-1] in goal_viewpoints)
+        scores['found_success'] = float(pred_found == gt_found)
         scores['oracle_success'] = float(any(x in goal_viewpoints for x in path))
         scores['spl'] = scores['success'] * gt_lengths / max(scores['trajectory_lengths'], gt_lengths, 0.01)
+        scores['sspl'] = scores['spl'] * scores['found_success']
 
         scores['rgs'] = str(pred_objid) == str(gt_objid)
         scores['rgspl'] = scores['rgs'] * gt_lengths / max(scores['trajectory_lengths'], gt_lengths, 0.01)
@@ -381,6 +383,7 @@
         ''' Evaluate each agent trajectory based on how close it got to the goal location
         the path contains [view_id, angle, vofv]'''
         print('eval %d predictions' % (len(preds)))
+        print(preds[0])
 
         metrics = defaultdict(list)
         for item in preds:
@@ -388,7 +391,9 @@ instr_id = item['instr_id']
             traj = item['trajectory']
             pred_objid = item.get('pred_objid', None)
             scan, gt_traj, gt_objid = self.gt_trajs[instr_id]
-            traj_scores = self._eval_item(scan, traj, pred_objid, gt_traj, gt_objid)
+            pred_found = item['found']
+            gt_found = item['gt_found']
+            traj_scores = self._eval_item(scan, traj, pred_objid, gt_traj, gt_objid, pred_found, gt_found)
             for k, v in traj_scores.items():
                 metrics[k].append(v)
             metrics['instr_id'].append(instr_id)
@@ -402,6 +407,8 @@
             'spl': np.mean(metrics['spl']) * 100,
             'rgs': np.mean(metrics['rgs']) * 100,
             'rgspl': np.mean(metrics['rgspl']) * 100,
+            'sspl': np.mean(metrics['sspl']) * 100,
+            'found_sr': np.mean(metrics['found_success']) * 100,
         }
         return avg_metrics, metrics
 
diff --git a/map_nav_src/reverie/main_nav_obj.py b/map_nav_src/reverie/main_nav_obj.py
index aecae16..0acc1f1 100644
--- a/map_nav_src/reverie/main_nav_obj.py
+++ b/map_nav_src/reverie/main_nav_obj.py
@@ -136,7 +136,7 @@ def train(args, train_env, val_envs, aug_env=None, rank=-1):
         '\nListener training starts, start iteration: %s' % str(start_iter), record_file
     )
 
-    best_val = {'val_unseen': {"spl": 0., "sr": 0., "state":""}}
+    best_val = {'val_unseen': {"spl": 0., "sr": 0., "state":"", "sspl": 0., 'found_sr': 0.}}
 
     for idx in range(start_iter, start_iter+args.iters, args.log_every):
         listner.logs = defaultdict(list)
@@ -201,9 +201,11 @@
 
             # select model by spl
             if env_name in best_val:
-                if score_summary['spl'] >= best_val[env_name]['spl']:
+                if score_summary['sspl'] >= best_val[env_name]['sspl']:
                     best_val[env_name]['spl'] = score_summary['spl']
+                    best_val[env_name]['sspl'] = score_summary['sspl']
                     best_val[env_name]['sr'] = score_summary['sr']
+                    best_val[env_name]['found_sr'] = score_summary['found_sr']
                     best_val[env_name]['state'] = 'Iter %d %s' % (iter, loss_str)
                     listner.save(idx, os.path.join(args.ckpt_dir, "best_%s" % (env_name)))
 