feat: show target start point in state

This commit is contained in:
Ting-Jun Wang 2024-04-28 21:45:34 +08:00
parent 9165a802d3
commit 0546814202
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354

View File

@ -34,12 +34,14 @@ class Simulator(object):
viewpoint_ID: str, viewpoint_ID: str,
heading: int, heading: int,
elevation: int, elevation: int,
stop: str): stop: str,
start: str):
self.heading = heading self.heading = heading
self.elevation = elevation self.elevation = elevation
self.scan_ID = scan_ID self.scan_ID = scan_ID
self.viewpoint_ID = viewpoint_ID self.viewpoint_ID = viewpoint_ID
self.stop = stop self.stop = stop
self.start = start
# Load navigable dict # Load navigable dict
navigable_path = os.path.join(self.navigable_dir, self.scan_ID + '_navigable.json') navigable_path = os.path.join(self.navigable_dir, self.scan_ID + '_navigable.json')
with open(navigable_path, 'r') as f: with open(navigable_path, 'r') as f:
@ -59,7 +61,8 @@ class Simulator(object):
'heading': self.heading, 'heading': self.heading,
'elevation': self.elevation, 'elevation': self.elevation,
'candidate': self.candidate, 'candidate': self.candidate,
'stop': self.stop 'stop': self.stop,
'start': self.start,
} }
return self.state return self.state
@ -104,9 +107,9 @@ class EnvBatch(object):
def _make_id(self, scanId, viewpointId): def _make_id(self, scanId, viewpointId):
return scanId + '_' + viewpointId return scanId + '_' + viewpointId
def newEpisodes(self, scanIds, viewpointIds, headings, stops): def newEpisodes(self, scanIds, viewpointIds, headings, stops, starts):
for i, (scanId, viewpointId, heading, stop) in enumerate(zip(scanIds, viewpointIds, headings, stops)): for i, (scanId, viewpointId, heading, stop, start) in enumerate(zip(scanIds, viewpointIds, headings, stops, starts)):
self.sims[i].newEpisode(scanId, viewpointId, heading, 0, stop) self.sims[i].newEpisode(scanId, viewpointId, heading, 0, stop, start)
def getStates(self): def getStates(self):
""" """
@ -121,7 +124,6 @@ class EnvBatch(object):
feature = self.feat_db.get_image_observation(state["scanID"], state["viewpointID"]) feature = self.feat_db.get_image_observation(state["scanID"], state["viewpointID"])
feature_states.append((feature, state)) feature_states.append((feature, state))
print(feature_states[-1])
return feature_states return feature_states
def makeActions(self, next_viewpoint_IDs): def makeActions(self, next_viewpoint_IDs):
@ -203,7 +205,6 @@ class REVERIENavBatch(object):
else: else:
self.ix += batch_size self.ix += batch_size
self.batch = batch self.batch = batch
print(self.batch)
def reset_epoch(self, shuffle=False): def reset_epoch(self, shuffle=False):
''' Reset the data index to beginning of epoch. Primarily for testing. ''' Reset the data index to beginning of epoch. Primarily for testing.
@ -230,7 +231,9 @@ class REVERIENavBatch(object):
'candidate': state['candidate'], 'candidate': state['candidate'],
'instruction' : item['instruction'], 'instruction' : item['instruction'],
'gt_path' : item['path'], 'gt_path' : item['path'],
'path_id' : item['path_id'] 'path_id' : item['path_id'],
'stop': item['stop'],
'start': item['start']
} }
# RL reward. The negative distance between the state and the final state # RL reward. The negative distance between the state and the final state
# There are multiple gt end viewpoints on REVERIE. # There are multiple gt end viewpoints on REVERIE.
@ -253,7 +256,8 @@ class REVERIENavBatch(object):
viewpointIds = [item['path'][0] for item in self.batch] viewpointIds = [item['path'][0] for item in self.batch]
headings = [item['heading'] for item in self.batch] headings = [item['heading'] for item in self.batch]
stops = [item['stop'] for item in self.batch] stops = [item['stop'] for item in self.batch]
self.env.newEpisodes(scanIds, viewpointIds, headings, stops) starts = [item['start'] for item in self.batch]
self.env.newEpisodes(scanIds, viewpointIds, headings, stops, starts)
return self._get_obs() return self._get_obs()
def step(self, next_viewpoint_IDs): def step(self, next_viewpoint_IDs):