feat: show target start point in state

This commit is contained in:
Ting-Jun Wang 2024-04-28 21:45:34 +08:00
parent 9165a802d3
commit 0546814202
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354

View File

@ -34,12 +34,14 @@ class Simulator(object):
viewpoint_ID: str,
heading: int,
elevation: int,
stop: str):
stop: str,
start: str):
self.heading = heading
self.elevation = elevation
self.scan_ID = scan_ID
self.viewpoint_ID = viewpoint_ID
self.stop = stop
self.start = start
# Load navigable dict
navigable_path = os.path.join(self.navigable_dir, self.scan_ID + '_navigable.json')
with open(navigable_path, 'r') as f:
@ -59,7 +61,8 @@ class Simulator(object):
'heading': self.heading,
'elevation': self.elevation,
'candidate': self.candidate,
'stop': self.stop
'stop': self.stop,
'start': self.start,
}
return self.state
@ -104,9 +107,9 @@ class EnvBatch(object):
def _make_id(self, scanId, viewpointId):
return scanId + '_' + viewpointId
def newEpisodes(self, scanIds, viewpointIds, headings, stops):
for i, (scanId, viewpointId, heading, stop) in enumerate(zip(scanIds, viewpointIds, headings, stops)):
self.sims[i].newEpisode(scanId, viewpointId, heading, 0, stop)
def newEpisodes(self, scanIds, viewpointIds, headings, stops, starts):
for i, (scanId, viewpointId, heading, stop, start) in enumerate(zip(scanIds, viewpointIds, headings, stops, starts)):
self.sims[i].newEpisode(scanId, viewpointId, heading, 0, stop, start)
def getStates(self):
"""
@ -121,7 +124,6 @@ class EnvBatch(object):
feature = self.feat_db.get_image_observation(state["scanID"], state["viewpointID"])
feature_states.append((feature, state))
print(feature_states[-1])
return feature_states
def makeActions(self, next_viewpoint_IDs):
@ -203,7 +205,6 @@ class REVERIENavBatch(object):
else:
self.ix += batch_size
self.batch = batch
print(self.batch)
def reset_epoch(self, shuffle=False):
''' Reset the data index to beginning of epoch. Primarily for testing.
@ -230,7 +231,9 @@ class REVERIENavBatch(object):
'candidate': state['candidate'],
'instruction' : item['instruction'],
'gt_path' : item['path'],
'path_id' : item['path_id']
'path_id' : item['path_id'],
'stop': item['stop'],
'start': item['start']
}
# RL reward. The negative distance between the state and the final state
# There are multiple gt end viewpoints on REVERIE.
@ -253,7 +256,8 @@ class REVERIENavBatch(object):
viewpointIds = [item['path'][0] for item in self.batch]
headings = [item['heading'] for item in self.batch]
stops = [item['stop'] for item in self.batch]
self.env.newEpisodes(scanIds, viewpointIds, headings, stops)
starts = [item['start'] for item in self.batch]
self.env.newEpisodes(scanIds, viewpointIds, headings, stops, starts)
return self._get_obs()
def step(self, next_viewpoint_IDs):