feat: validate the baseline model's detection ability
This commit is contained in:
parent
0a533c0647
commit
6e44a6342e
@ -324,7 +324,8 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
||||
ml_loss = 0.
|
||||
og_loss = 0.
|
||||
|
||||
for t in range(self.args.max_action_len):
|
||||
# for t in range(self.args.max_action_len):
|
||||
for t in range(1):
|
||||
for i, gmap in enumerate(gmaps):
|
||||
if not ended[i]:
|
||||
gmap.node_step_ids[obs[i]['viewpoint']] = t + 1
|
||||
|
||||
@ -339,7 +339,7 @@ class ReverieObjectNavBatch(object):
|
||||
self._next_minibatch(**kwargs)
|
||||
|
||||
scanIds = [item['scan'] for item in self.batch]
|
||||
viewpointIds = [item['path'][0] for item in self.batch]
|
||||
viewpointIds = [item['path'][-1] for item in self.batch]
|
||||
headings = [item['heading'] for item in self.batch]
|
||||
self.env.newEpisodes(scanIds, viewpointIds, headings)
|
||||
return self._get_obs()
|
||||
@ -357,7 +357,7 @@ class ReverieObjectNavBatch(object):
|
||||
shortest_distances = self.shortest_distances[scan]
|
||||
|
||||
path = sum(pred_path, [])
|
||||
assert gt_path[0] == path[0], 'Result trajectories should include the start position'
|
||||
# assert gt_path[0] == path[0], 'Result trajectories should include the start position'
|
||||
|
||||
scores['action_steps'] = len(pred_path) - 1
|
||||
scores['trajectory_steps'] = len(path) - 1
|
||||
|
||||
Loading…
Reference in New Issue
Block a user