feat: evaluation in result
This commit is contained in:
parent
64fbce018a
commit
5cbd75711e
@ -272,7 +272,7 @@ class REVERIENavBatch(object):
|
|||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
#self.gt_trajs = self._get_gt_trajs(self.data) # for evaluation
|
self.gt_trajs = self._get_gt_trajs(self.data) # for evaluation
|
||||||
|
|
||||||
# use different seeds in different processes to shuffle data
|
# use different seeds in different processes to shuffle data
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
@ -288,14 +288,12 @@ class REVERIENavBatch(object):
|
|||||||
print('%s loaded with %d instructions, using splits: %s' % (
|
print('%s loaded with %d instructions, using splits: %s' % (
|
||||||
self.__class__.__name__, len(self.data), self.name))
|
self.__class__.__name__, len(self.data), self.name))
|
||||||
|
|
||||||
'''
|
|
||||||
def _get_gt_trajs(self, data):
|
def _get_gt_trajs(self, data):
|
||||||
gt_trajs = {
|
gt_trajs = {
|
||||||
x['instr_id']: (x['scan'], x['path']) \
|
x['instr_id']: (x['scan'], x['path']) \
|
||||||
for x in data if len(x['path']) > 1
|
for x in data if len(x['path']) > 1
|
||||||
}
|
}
|
||||||
return gt_trajs
|
return gt_trajs
|
||||||
'''
|
|
||||||
|
|
||||||
def size(self):
|
def size(self):
|
||||||
return len(self.data)
|
return len(self.data)
|
||||||
@ -412,7 +410,7 @@ class REVERIENavBatch(object):
|
|||||||
shortest_distances = self.shortest_distances[scan]
|
shortest_distances = self.shortest_distances[scan]
|
||||||
|
|
||||||
path = sum(pred_path, [])
|
path = sum(pred_path, [])
|
||||||
assert gt_path[0] == path[0], 'Result trajectories should include the start position'
|
# assert gt_path[0] == path[0], 'Result trajectories should include the start position'
|
||||||
|
|
||||||
nearest_position = self._get_nearest(shortest_distances, gt_path[-1], path)
|
nearest_position = self._get_nearest(shortest_distances, gt_path[-1], path)
|
||||||
|
|
||||||
@ -426,7 +424,7 @@ class REVERIENavBatch(object):
|
|||||||
gt_lengths = np.sum([shortest_distances[a][b] for a, b in zip(gt_path[:-1], gt_path[1:])])
|
gt_lengths = np.sum([shortest_distances[a][b] for a, b in zip(gt_path[:-1], gt_path[1:])])
|
||||||
|
|
||||||
scores['success'] = float(scores['nav_error'] < ERROR_MARGIN)
|
scores['success'] = float(scores['nav_error'] < ERROR_MARGIN)
|
||||||
scores['spl'] = scores['success'] * gt_lengths / max(scores['trajectory_lengths'], gt_lengths, 0.01)
|
# scores['spl'] = scores['success'] * gt_lengths / max(scores['trajectory_lengths'], gt_lengths, 0.01)
|
||||||
scores['oracle_success'] = float(scores['oracle_error'] < ERROR_MARGIN)
|
scores['oracle_success'] = float(scores['oracle_error'] < ERROR_MARGIN)
|
||||||
|
|
||||||
scores.update(
|
scores.update(
|
||||||
@ -459,7 +457,7 @@ class REVERIENavBatch(object):
|
|||||||
'oracle_error': np.mean(metrics['oracle_error']),
|
'oracle_error': np.mean(metrics['oracle_error']),
|
||||||
'sr': np.mean(metrics['success']) * 100,
|
'sr': np.mean(metrics['success']) * 100,
|
||||||
'oracle_sr': np.mean(metrics['oracle_success']) * 100,
|
'oracle_sr': np.mean(metrics['oracle_success']) * 100,
|
||||||
'spl': np.mean(metrics['spl']) * 100,
|
# 'spl': np.mean(metrics['spl']) * 100,
|
||||||
'nDTW': np.mean(metrics['nDTW']) * 100,
|
'nDTW': np.mean(metrics['nDTW']) * 100,
|
||||||
'SDTW': np.mean(metrics['SDTW']) * 100,
|
'SDTW': np.mean(metrics['SDTW']) * 100,
|
||||||
'CLS': np.mean(metrics['CLS']) * 100,
|
'CLS': np.mean(metrics['CLS']) * 100,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user