diff --git a/map_nav_src/reverie/agent_obj.py b/map_nav_src/reverie/agent_obj.py index d343d9d..2222493 100644 --- a/map_nav_src/reverie/agent_obj.py +++ b/map_nav_src/reverie/agent_obj.py @@ -324,7 +324,8 @@ class GMapObjectNavAgent(Seq2SeqAgent): ml_loss = 0. og_loss = 0. - for t in range(self.args.max_action_len): + # for t in range(self.args.max_action_len): + for t in range(1): for i, gmap in enumerate(gmaps): if not ended[i]: gmap.node_step_ids[obs[i]['viewpoint']] = t + 1 diff --git a/map_nav_src/reverie/env.py b/map_nav_src/reverie/env.py index 8ee8056..92c83d6 100644 --- a/map_nav_src/reverie/env.py +++ b/map_nav_src/reverie/env.py @@ -339,7 +339,7 @@ class ReverieObjectNavBatch(object): self._next_minibatch(**kwargs) scanIds = [item['scan'] for item in self.batch] - viewpointIds = [item['path'][0] for item in self.batch] + viewpointIds = [item['path'][-1] for item in self.batch] headings = [item['heading'] for item in self.batch] self.env.newEpisodes(scanIds, viewpointIds, headings) return self._get_obs() @@ -357,7 +357,7 @@ class ReverieObjectNavBatch(object): shortest_distances = self.shortest_distances[scan] path = sum(pred_path, []) - assert gt_path[0] == path[0], 'Result trajectories should include the start position' + # assert gt_path[0] == path[0], 'Result trajectories should include the start position' scores['action_steps'] = len(pred_path) - 1 scores['trajectory_steps'] = len(path) - 1