diff --git a/nav_src/env.py b/nav_src/env.py index 5fa1637..76318ad 100644 --- a/nav_src/env.py +++ b/nav_src/env.py @@ -34,12 +34,14 @@ class Simulator(object): viewpoint_ID: str, heading: int, elevation: int, - stop: str): + stop: str, + start: str): self.heading = heading self.elevation = elevation self.scan_ID = scan_ID self.viewpoint_ID = viewpoint_ID self.stop = stop + self.start = start # Load navigable dict navigable_path = os.path.join(self.navigable_dir, self.scan_ID + '_navigable.json') with open(navigable_path, 'r') as f: @@ -59,7 +61,8 @@ class Simulator(object): 'heading': self.heading, 'elevation': self.elevation, 'candidate': self.candidate, - 'stop': self.stop + 'stop': self.stop, + 'start': self.start, } return self.state @@ -104,9 +107,9 @@ class EnvBatch(object): def _make_id(self, scanId, viewpointId): return scanId + '_' + viewpointId - def newEpisodes(self, scanIds, viewpointIds, headings, stops): - for i, (scanId, viewpointId, heading, stop) in enumerate(zip(scanIds, viewpointIds, headings, stops)): - self.sims[i].newEpisode(scanId, viewpointId, heading, 0, stop) + def newEpisodes(self, scanIds, viewpointIds, headings, stops, starts): + for i, (scanId, viewpointId, heading, stop, start) in enumerate(zip(scanIds, viewpointIds, headings, stops, starts)): + self.sims[i].newEpisode(scanId, viewpointId, heading, 0, stop, start) def getStates(self): """ @@ -121,7 +124,6 @@ class EnvBatch(object): feature = self.feat_db.get_image_observation(state["scanID"], state["viewpointID"]) feature_states.append((feature, state)) - print(feature_states[-1]) return feature_states def makeActions(self, next_viewpoint_IDs): @@ -203,7 +205,6 @@ class REVERIENavBatch(object): else: self.ix += batch_size self.batch = batch - print(self.batch) def reset_epoch(self, shuffle=False): ''' Reset the data index to beginning of epoch. Primarily for testing. @@ -230,7 +231,9 @@ class REVERIENavBatch(object): 'candidate': state['candidate'], 'instruction' : item['instruction'], 'gt_path' : item['path'], - 'path_id' : item['path_id'] + 'path_id' : item['path_id'], + 'stop': item['stop'], + 'start': item['start'] } # RL reward. The negative distance between the state and the final state # There are multiple gt end viewpoints on REVERIE. @@ -253,7 +256,8 @@ class REVERIENavBatch(object): viewpointIds = [item['path'][0] for item in self.batch] headings = [item['heading'] for item in self.batch] stops = [item['stop'] for item in self.batch] - self.env.newEpisodes(scanIds, viewpointIds, headings, stops) + starts = [item['start'] for item in self.batch] + self.env.newEpisodes(scanIds, viewpointIds, headings, stops, starts) return self._get_obs() def step(self, next_viewpoint_IDs):