feat: show target start point in state
This commit is contained in:
parent
9165a802d3
commit
0546814202
@ -34,12 +34,14 @@ class Simulator(object):
|
|||||||
viewpoint_ID: str,
|
viewpoint_ID: str,
|
||||||
heading: int,
|
heading: int,
|
||||||
elevation: int,
|
elevation: int,
|
||||||
stop: str):
|
stop: str,
|
||||||
|
start: str):
|
||||||
self.heading = heading
|
self.heading = heading
|
||||||
self.elevation = elevation
|
self.elevation = elevation
|
||||||
self.scan_ID = scan_ID
|
self.scan_ID = scan_ID
|
||||||
self.viewpoint_ID = viewpoint_ID
|
self.viewpoint_ID = viewpoint_ID
|
||||||
self.stop = stop
|
self.stop = stop
|
||||||
|
self.start = start
|
||||||
# Load navigable dict
|
# Load navigable dict
|
||||||
navigable_path = os.path.join(self.navigable_dir, self.scan_ID + '_navigable.json')
|
navigable_path = os.path.join(self.navigable_dir, self.scan_ID + '_navigable.json')
|
||||||
with open(navigable_path, 'r') as f:
|
with open(navigable_path, 'r') as f:
|
||||||
@ -59,7 +61,8 @@ class Simulator(object):
|
|||||||
'heading': self.heading,
|
'heading': self.heading,
|
||||||
'elevation': self.elevation,
|
'elevation': self.elevation,
|
||||||
'candidate': self.candidate,
|
'candidate': self.candidate,
|
||||||
'stop': self.stop
|
'stop': self.stop,
|
||||||
|
'start': self.start,
|
||||||
}
|
}
|
||||||
return self.state
|
return self.state
|
||||||
|
|
||||||
@ -104,9 +107,9 @@ class EnvBatch(object):
|
|||||||
def _make_id(self, scanId, viewpointId):
|
def _make_id(self, scanId, viewpointId):
|
||||||
return scanId + '_' + viewpointId
|
return scanId + '_' + viewpointId
|
||||||
|
|
||||||
def newEpisodes(self, scanIds, viewpointIds, headings, stops):
|
def newEpisodes(self, scanIds, viewpointIds, headings, stops, starts):
|
||||||
for i, (scanId, viewpointId, heading, stop) in enumerate(zip(scanIds, viewpointIds, headings, stops)):
|
for i, (scanId, viewpointId, heading, stop, start) in enumerate(zip(scanIds, viewpointIds, headings, stops, starts)):
|
||||||
self.sims[i].newEpisode(scanId, viewpointId, heading, 0, stop)
|
self.sims[i].newEpisode(scanId, viewpointId, heading, 0, stop, start)
|
||||||
|
|
||||||
def getStates(self):
|
def getStates(self):
|
||||||
"""
|
"""
|
||||||
@ -121,7 +124,6 @@ class EnvBatch(object):
|
|||||||
|
|
||||||
feature = self.feat_db.get_image_observation(state["scanID"], state["viewpointID"])
|
feature = self.feat_db.get_image_observation(state["scanID"], state["viewpointID"])
|
||||||
feature_states.append((feature, state))
|
feature_states.append((feature, state))
|
||||||
print(feature_states[-1])
|
|
||||||
return feature_states
|
return feature_states
|
||||||
|
|
||||||
def makeActions(self, next_viewpoint_IDs):
|
def makeActions(self, next_viewpoint_IDs):
|
||||||
@ -203,7 +205,6 @@ class REVERIENavBatch(object):
|
|||||||
else:
|
else:
|
||||||
self.ix += batch_size
|
self.ix += batch_size
|
||||||
self.batch = batch
|
self.batch = batch
|
||||||
print(self.batch)
|
|
||||||
|
|
||||||
def reset_epoch(self, shuffle=False):
|
def reset_epoch(self, shuffle=False):
|
||||||
''' Reset the data index to beginning of epoch. Primarily for testing.
|
''' Reset the data index to beginning of epoch. Primarily for testing.
|
||||||
@ -230,7 +231,9 @@ class REVERIENavBatch(object):
|
|||||||
'candidate': state['candidate'],
|
'candidate': state['candidate'],
|
||||||
'instruction' : item['instruction'],
|
'instruction' : item['instruction'],
|
||||||
'gt_path' : item['path'],
|
'gt_path' : item['path'],
|
||||||
'path_id' : item['path_id']
|
'path_id' : item['path_id'],
|
||||||
|
'stop': item['stop'],
|
||||||
|
'start': item['start']
|
||||||
}
|
}
|
||||||
# RL reward. The negative distance between the state and the final state
|
# RL reward. The negative distance between the state and the final state
|
||||||
# There are multiple gt end viewpoints on REVERIE.
|
# There are multiple gt end viewpoints on REVERIE.
|
||||||
@ -253,7 +256,8 @@ class REVERIENavBatch(object):
|
|||||||
viewpointIds = [item['path'][0] for item in self.batch]
|
viewpointIds = [item['path'][0] for item in self.batch]
|
||||||
headings = [item['heading'] for item in self.batch]
|
headings = [item['heading'] for item in self.batch]
|
||||||
stops = [item['stop'] for item in self.batch]
|
stops = [item['stop'] for item in self.batch]
|
||||||
self.env.newEpisodes(scanIds, viewpointIds, headings, stops)
|
starts = [item['start'] for item in self.batch]
|
||||||
|
self.env.newEpisodes(scanIds, viewpointIds, headings, stops, starts)
|
||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
|
||||||
def step(self, next_viewpoint_IDs):
|
def step(self, next_viewpoint_IDs):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user