feat: evaluation in result
This commit is contained in:
parent
64fbce018a
commit
5cbd75711e
@ -272,7 +272,7 @@ class REVERIENavBatch(object):
|
||||
self.batch_size = batch_size
|
||||
self.name = name
|
||||
|
||||
#self.gt_trajs = self._get_gt_trajs(self.data) # for evaluation
|
||||
self.gt_trajs = self._get_gt_trajs(self.data) # for evaluation
|
||||
|
||||
# use different seeds in different processes to shuffle data
|
||||
self.seed = seed
|
||||
@ -288,14 +288,12 @@ class REVERIENavBatch(object):
|
||||
print('%s loaded with %d instructions, using splits: %s' % (
|
||||
self.__class__.__name__, len(self.data), self.name))
|
||||
|
||||
'''
|
||||
def _get_gt_trajs(self, data):
|
||||
gt_trajs = {
|
||||
x['instr_id']: (x['scan'], x['path']) \
|
||||
for x in data if len(x['path']) > 1
|
||||
}
|
||||
return gt_trajs
|
||||
'''
|
||||
|
||||
def size(self):
|
||||
return len(self.data)
|
||||
@ -412,7 +410,7 @@ class REVERIENavBatch(object):
|
||||
shortest_distances = self.shortest_distances[scan]
|
||||
|
||||
path = sum(pred_path, [])
|
||||
assert gt_path[0] == path[0], 'Result trajectories should include the start position'
|
||||
# assert gt_path[0] == path[0], 'Result trajectories should include the start position'
|
||||
|
||||
nearest_position = self._get_nearest(shortest_distances, gt_path[-1], path)
|
||||
|
||||
@ -426,7 +424,7 @@ class REVERIENavBatch(object):
|
||||
gt_lengths = np.sum([shortest_distances[a][b] for a, b in zip(gt_path[:-1], gt_path[1:])])
|
||||
|
||||
scores['success'] = float(scores['nav_error'] < ERROR_MARGIN)
|
||||
scores['spl'] = scores['success'] * gt_lengths / max(scores['trajectory_lengths'], gt_lengths, 0.01)
|
||||
# scores['spl'] = scores['success'] * gt_lengths / max(scores['trajectory_lengths'], gt_lengths, 0.01)
|
||||
scores['oracle_success'] = float(scores['oracle_error'] < ERROR_MARGIN)
|
||||
|
||||
scores.update(
|
||||
@ -459,7 +457,7 @@ class REVERIENavBatch(object):
|
||||
'oracle_error': np.mean(metrics['oracle_error']),
|
||||
'sr': np.mean(metrics['success']) * 100,
|
||||
'oracle_sr': np.mean(metrics['oracle_success']) * 100,
|
||||
'spl': np.mean(metrics['spl']) * 100,
|
||||
# 'spl': np.mean(metrics['spl']) * 100,
|
||||
'nDTW': np.mean(metrics['nDTW']) * 100,
|
||||
'SDTW': np.mean(metrics['SDTW']) * 100,
|
||||
'CLS': np.mean(metrics['CLS']) * 100,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user