feat: evaluation in result

This commit is contained in:
Ting-Jun Wang 2024-05-06 16:42:40 +08:00
parent 64fbce018a
commit 5cbd75711e
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354

View File

@ -272,7 +272,7 @@ class REVERIENavBatch(object):
self.batch_size = batch_size self.batch_size = batch_size
self.name = name self.name = name
#self.gt_trajs = self._get_gt_trajs(self.data) # for evaluation self.gt_trajs = self._get_gt_trajs(self.data) # for evaluation
# use different seeds in different processes to shuffle data # use different seeds in different processes to shuffle data
self.seed = seed self.seed = seed
@ -288,14 +288,12 @@ class REVERIENavBatch(object):
print('%s loaded with %d instructions, using splits: %s' % ( print('%s loaded with %d instructions, using splits: %s' % (
self.__class__.__name__, len(self.data), self.name)) self.__class__.__name__, len(self.data), self.name))
'''
def _get_gt_trajs(self, data): def _get_gt_trajs(self, data):
gt_trajs = { gt_trajs = {
x['instr_id']: (x['scan'], x['path']) \ x['instr_id']: (x['scan'], x['path']) \
for x in data if len(x['path']) > 1 for x in data if len(x['path']) > 1
} }
return gt_trajs return gt_trajs
'''
def size(self): def size(self):
return len(self.data) return len(self.data)
@ -412,7 +410,7 @@ class REVERIENavBatch(object):
shortest_distances = self.shortest_distances[scan] shortest_distances = self.shortest_distances[scan]
path = sum(pred_path, []) path = sum(pred_path, [])
assert gt_path[0] == path[0], 'Result trajectories should include the start position' # assert gt_path[0] == path[0], 'Result trajectories should include the start position'
nearest_position = self._get_nearest(shortest_distances, gt_path[-1], path) nearest_position = self._get_nearest(shortest_distances, gt_path[-1], path)
@ -426,7 +424,7 @@ class REVERIENavBatch(object):
gt_lengths = np.sum([shortest_distances[a][b] for a, b in zip(gt_path[:-1], gt_path[1:])]) gt_lengths = np.sum([shortest_distances[a][b] for a, b in zip(gt_path[:-1], gt_path[1:])])
scores['success'] = float(scores['nav_error'] < ERROR_MARGIN) scores['success'] = float(scores['nav_error'] < ERROR_MARGIN)
scores['spl'] = scores['success'] * gt_lengths / max(scores['trajectory_lengths'], gt_lengths, 0.01) # scores['spl'] = scores['success'] * gt_lengths / max(scores['trajectory_lengths'], gt_lengths, 0.01)
scores['oracle_success'] = float(scores['oracle_error'] < ERROR_MARGIN) scores['oracle_success'] = float(scores['oracle_error'] < ERROR_MARGIN)
scores.update( scores.update(
@ -459,7 +457,7 @@ class REVERIENavBatch(object):
'oracle_error': np.mean(metrics['oracle_error']), 'oracle_error': np.mean(metrics['oracle_error']),
'sr': np.mean(metrics['success']) * 100, 'sr': np.mean(metrics['success']) * 100,
'oracle_sr': np.mean(metrics['oracle_success']) * 100, 'oracle_sr': np.mean(metrics['oracle_success']) * 100,
'spl': np.mean(metrics['spl']) * 100, # 'spl': np.mean(metrics['spl']) * 100,
'nDTW': np.mean(metrics['nDTW']) * 100, 'nDTW': np.mean(metrics['nDTW']) * 100,
'SDTW': np.mean(metrics['SDTW']) * 100, 'SDTW': np.mean(metrics['SDTW']) * 100,
'CLS': np.mean(metrics['CLS']) * 100, 'CLS': np.mean(metrics['CLS']) * 100,