feat: validate the baseline model's detection ability

This commit is contained in:
Ting-Jun Wang 2024-07-16 17:01:03 +08:00
parent 0a533c0647
commit 6e44a6342e
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354
2 changed files with 4 additions and 3 deletions

View File

@ -324,7 +324,8 @@ class GMapObjectNavAgent(Seq2SeqAgent):
ml_loss = 0.
og_loss = 0.
for t in range(self.args.max_action_len):
# for t in range(self.args.max_action_len):
for t in range(1):
for i, gmap in enumerate(gmaps):
if not ended[i]:
gmap.node_step_ids[obs[i]['viewpoint']] = t + 1

View File

@ -339,7 +339,7 @@ class ReverieObjectNavBatch(object):
self._next_minibatch(**kwargs)
scanIds = [item['scan'] for item in self.batch]
viewpointIds = [item['path'][0] for item in self.batch]
viewpointIds = [item['path'][-1] for item in self.batch]
headings = [item['heading'] for item in self.batch]
self.env.newEpisodes(scanIds, viewpointIds, headings)
return self._get_obs()
@ -357,7 +357,7 @@ class ReverieObjectNavBatch(object):
shortest_distances = self.shortest_distances[scan]
path = sum(pred_path, [])
assert gt_path[0] == path[0], 'Result trajectories should include the start position'
# assert gt_path[0] == path[0], 'Result trajectories should include the start position'
scores['action_steps'] = len(pred_path) - 1
scores['trajectory_steps'] = len(path) - 1