feat: show target start point in state
This commit is contained in:
parent
9165a802d3
commit
0546814202
@ -34,12 +34,14 @@ class Simulator(object):
|
||||
viewpoint_ID: str,
|
||||
heading: int,
|
||||
elevation: int,
|
||||
stop: str):
|
||||
stop: str,
|
||||
start: str):
|
||||
self.heading = heading
|
||||
self.elevation = elevation
|
||||
self.scan_ID = scan_ID
|
||||
self.viewpoint_ID = viewpoint_ID
|
||||
self.stop = stop
|
||||
self.start = start
|
||||
# Load navigable dict
|
||||
navigable_path = os.path.join(self.navigable_dir, self.scan_ID + '_navigable.json')
|
||||
with open(navigable_path, 'r') as f:
|
||||
@ -59,7 +61,8 @@ class Simulator(object):
|
||||
'heading': self.heading,
|
||||
'elevation': self.elevation,
|
||||
'candidate': self.candidate,
|
||||
'stop': self.stop
|
||||
'stop': self.stop,
|
||||
'start': self.start,
|
||||
}
|
||||
return self.state
|
||||
|
||||
@ -104,9 +107,9 @@ class EnvBatch(object):
|
||||
def _make_id(self, scanId, viewpointId):
|
||||
return scanId + '_' + viewpointId
|
||||
|
||||
def newEpisodes(self, scanIds, viewpointIds, headings, stops):
|
||||
for i, (scanId, viewpointId, heading, stop) in enumerate(zip(scanIds, viewpointIds, headings, stops)):
|
||||
self.sims[i].newEpisode(scanId, viewpointId, heading, 0, stop)
|
||||
def newEpisodes(self, scanIds, viewpointIds, headings, stops, starts):
|
||||
for i, (scanId, viewpointId, heading, stop, start) in enumerate(zip(scanIds, viewpointIds, headings, stops, starts)):
|
||||
self.sims[i].newEpisode(scanId, viewpointId, heading, 0, stop, start)
|
||||
|
||||
def getStates(self):
|
||||
"""
|
||||
@ -121,7 +124,6 @@ class EnvBatch(object):
|
||||
|
||||
feature = self.feat_db.get_image_observation(state["scanID"], state["viewpointID"])
|
||||
feature_states.append((feature, state))
|
||||
print(feature_states[-1])
|
||||
return feature_states
|
||||
|
||||
def makeActions(self, next_viewpoint_IDs):
|
||||
@ -203,7 +205,6 @@ class REVERIENavBatch(object):
|
||||
else:
|
||||
self.ix += batch_size
|
||||
self.batch = batch
|
||||
print(self.batch)
|
||||
|
||||
def reset_epoch(self, shuffle=False):
|
||||
''' Reset the data index to beginning of epoch. Primarily for testing.
|
||||
@ -230,7 +231,9 @@ class REVERIENavBatch(object):
|
||||
'candidate': state['candidate'],
|
||||
'instruction' : item['instruction'],
|
||||
'gt_path' : item['path'],
|
||||
'path_id' : item['path_id']
|
||||
'path_id' : item['path_id'],
|
||||
'stop': item['stop'],
|
||||
'start': item['start']
|
||||
}
|
||||
# RL reward. The negative distance between the state and the final state
|
||||
# There are multiple gt end viewpoints on REVERIE.
|
||||
@ -253,7 +256,8 @@ class REVERIENavBatch(object):
|
||||
viewpointIds = [item['path'][0] for item in self.batch]
|
||||
headings = [item['heading'] for item in self.batch]
|
||||
stops = [item['stop'] for item in self.batch]
|
||||
self.env.newEpisodes(scanIds, viewpointIds, headings, stops)
|
||||
starts = [item['start'] for item in self.batch]
|
||||
self.env.newEpisodes(scanIds, viewpointIds, headings, stops, starts)
|
||||
return self._get_obs()
|
||||
|
||||
def step(self, next_viewpoint_IDs):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user