Compare commits
No commits in common. "a85950f06fd0c3ab5c36a33fa9d19e97f2fbe062" and "68330c51637baef786ed6a1cde975b811246e93c" have entirely different histories.
a85950f06f
...
68330c5163
@ -11,8 +11,7 @@ def dump_json(data, filename):
|
|||||||
json.dump(data, fp)
|
json.dump(data, fp)
|
||||||
|
|
||||||
for f in os.listdir():
|
for f in os.listdir():
|
||||||
if 'navgpt' in f:
|
if 'json' in f:
|
||||||
print(f)
|
|
||||||
data = load_json(f)
|
data = load_json(f)
|
||||||
|
|
||||||
new_data = []
|
new_data = []
|
||||||
@ -20,8 +19,7 @@ for f in os.listdir():
|
|||||||
for index, instr in enumerate(i['instructions']):
|
for index, instr in enumerate(i['instructions']):
|
||||||
new_i = i.copy()
|
new_i = i.copy()
|
||||||
new_i['instruction'] = instr
|
new_i['instruction'] = instr
|
||||||
# new_i['instr_id'] = f'{new_i["id"]}_{index}'
|
new_i['instr_id'] = f'{new_i["id"]}_{index}'
|
||||||
new_i['new_reverie_id'] = f'{new_i["new_reverie_id"]}_{index}'
|
|
||||||
del new_i['instructions']
|
del new_i['instructions']
|
||||||
|
|
||||||
new_data.append(new_i)
|
new_data.append(new_i)
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from parser import parse_args
|
|||||||
from env import REVERIENavBatch
|
from env import REVERIENavBatch
|
||||||
from agent import NavGPTAgent
|
from agent import NavGPTAgent
|
||||||
|
|
||||||
def build_dataset(args):
|
def build_dataset(args, data_limit=100):
|
||||||
|
|
||||||
feat_db = ImageObservationsDB(args.obs_dir, args.obs_summary_dir, args.obj_dir)
|
feat_db = ImageObservationsDB(args.obs_dir, args.obs_summary_dir, args.obj_dir)
|
||||||
print(feat_db)
|
print(feat_db)
|
||||||
@ -26,7 +26,7 @@ def build_dataset(args):
|
|||||||
)
|
)
|
||||||
val_env = dataset_class(
|
val_env = dataset_class(
|
||||||
feat_db, val_instr_data, args.connectivity_dir, args.navigable_dir,
|
feat_db, val_instr_data, args.connectivity_dir, args.navigable_dir,
|
||||||
batch_size=args.batch_size, seed=args.seed, name=split
|
batch_size=args.batch_size, seed=args.seed, name=split, data_limit=data_limit
|
||||||
) # evaluation using all objects
|
) # evaluation using all objects
|
||||||
val_envs[split] = val_env
|
val_envs[split] = val_env
|
||||||
|
|
||||||
@ -96,7 +96,7 @@ def valid_from_file(args, val_envs):
|
|||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
val_envs = build_dataset(args)
|
val_envs = build_dataset(args, data_limit=100)
|
||||||
|
|
||||||
if args.valid_file is not None:
|
if args.valid_file is not None:
|
||||||
valid_from_file(args, val_envs)
|
valid_from_file(args, val_envs)
|
||||||
|
|||||||
@ -46,7 +46,6 @@ EXCEPTION_TOOL_NAME = "_Exception"
|
|||||||
MAX_SCRATCHPAD_LENGTH = 7000
|
MAX_SCRATCHPAD_LENGTH = 7000
|
||||||
|
|
||||||
FINAL_STOP_POINT = ""
|
FINAL_STOP_POINT = ""
|
||||||
FINAL_STATE = ""
|
|
||||||
SUCCESS = 0
|
SUCCESS = 0
|
||||||
TEMP_STEPS_COUNTER = 0
|
TEMP_STEPS_COUNTER = 0
|
||||||
STEPS_COUNTER = 0
|
STEPS_COUNTER = 0
|
||||||
@ -74,7 +73,6 @@ class NavGPTOutputParser(AgentOutputParser):
|
|||||||
global TEMP_STEPS_COUNTER
|
global TEMP_STEPS_COUNTER
|
||||||
global SUCCESS
|
global SUCCESS
|
||||||
global NOW_LOCATION
|
global NOW_LOCATION
|
||||||
global FINAL_STATE
|
|
||||||
includes_answer = FINAL_ANSWER_ACTION in text
|
includes_answer = FINAL_ANSWER_ACTION in text
|
||||||
regex = (
|
regex = (
|
||||||
r"Action\s*\d*\s*:[\s]*(.*?)[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*\"?([a-fA-F0-9]{32})\"?"
|
r"Action\s*\d*\s*:[\s]*(.*?)[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*\"?([a-fA-F0-9]{32})\"?"
|
||||||
@ -102,7 +100,6 @@ class NavGPTOutputParser(AgentOutputParser):
|
|||||||
print(f"SUCCESS = {SUCCESS}")
|
print(f"SUCCESS = {SUCCESS}")
|
||||||
|
|
||||||
NOW_LOCATION = tool_input
|
NOW_LOCATION = tool_input
|
||||||
TEMP_STEPS_COUNTER += 1
|
|
||||||
print(f"NOW_LOCATION = {NOW_LOCATION}")
|
print(f"NOW_LOCATION = {NOW_LOCATION}")
|
||||||
|
|
||||||
|
|
||||||
@ -122,18 +119,7 @@ class NavGPTOutputParser(AgentOutputParser):
|
|||||||
|
|
||||||
return AgentAction(action, tool_input, text)
|
return AgentAction(action, tool_input, text)
|
||||||
elif includes_answer:
|
elif includes_answer:
|
||||||
is_STOP = 'Finished' in text
|
|
||||||
print("FINAL: ", is_STOP)
|
|
||||||
|
|
||||||
if is_STOP:
|
|
||||||
FINAL_STATE = 'stop'
|
|
||||||
else:
|
|
||||||
FINAL_STATE = 'not found'
|
|
||||||
|
|
||||||
|
|
||||||
if NOW_LOCATION == FINAL_STOP_POINT:
|
if NOW_LOCATION == FINAL_STOP_POINT:
|
||||||
STEPS_COUNTER += TEMP_STEPS_COUNTER
|
|
||||||
TEMP_STEPS_COUNTER = 0
|
|
||||||
SUCCESS += 1
|
SUCCESS += 1
|
||||||
print(f"SUCCESS = {SUCCESS}")
|
print(f"SUCCESS = {SUCCESS}")
|
||||||
else:
|
else:
|
||||||
@ -142,7 +128,6 @@ class NavGPTOutputParser(AgentOutputParser):
|
|||||||
print(f"{NOW_LOCATION}_{type(NOW_LOCATION)}")
|
print(f"{NOW_LOCATION}_{type(NOW_LOCATION)}")
|
||||||
print(f"{FINAL_STOP_POINT}_{type(FINAL_STOP_POINT)}")
|
print(f"{FINAL_STOP_POINT}_{type(FINAL_STOP_POINT)}")
|
||||||
print(f"SUCCESS = {SUCCESS}")
|
print(f"SUCCESS = {SUCCESS}")
|
||||||
print(f"STEPS_COUNTER = {STEPS_COUNTER}")
|
|
||||||
return AgentFinish(
|
return AgentFinish(
|
||||||
{"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, text
|
{"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, text
|
||||||
)
|
)
|
||||||
@ -392,7 +377,7 @@ class NavGPTAgent(BaseAgent):
|
|||||||
"""Initialize the trajectory with the given observation."""
|
"""Initialize the trajectory with the given observation."""
|
||||||
# Record the navigation path
|
# Record the navigation path
|
||||||
self.traj = [{
|
self.traj = [{
|
||||||
'instr_id': ob['new_reverie_id'],
|
'instr_id': ob['instr_id'],
|
||||||
'path': [[ob['start']]],
|
'path': [[ob['start']]],
|
||||||
'details': [],
|
'details': [],
|
||||||
} for ob in obs]
|
} for ob in obs]
|
||||||
@ -632,7 +617,7 @@ class NavGPTAgent(BaseAgent):
|
|||||||
|
|
||||||
tools = [
|
tools = [
|
||||||
self.action_maker,
|
self.action_maker,
|
||||||
self.back_tracer,
|
self.back_tracer
|
||||||
]
|
]
|
||||||
|
|
||||||
if self.config.use_tool_chain:
|
if self.config.use_tool_chain:
|
||||||
@ -710,10 +695,7 @@ class NavGPTAgent(BaseAgent):
|
|||||||
new_obs = self.env.step(actions)[0]
|
new_obs = self.env.step(actions)[0]
|
||||||
new_heading = np.rad2deg(new_obs['heading'])
|
new_heading = np.rad2deg(new_obs['heading'])
|
||||||
# Record the trajectory
|
# Record the trajectory
|
||||||
try:
|
|
||||||
self.traj[0]['path'].append(self.env.env.sims[0].gmap.bfs_shortest_path(cur_obs['viewpoint'], actions[0])[1:])
|
self.traj[0]['path'].append(self.env.env.sims[0].gmap.bfs_shortest_path(cur_obs['viewpoint'], actions[0])[1:])
|
||||||
except:
|
|
||||||
None
|
|
||||||
# Calculate the turned angle
|
# Calculate the turned angle
|
||||||
turned_angle = new_heading - cur_heading
|
turned_angle = new_heading - cur_heading
|
||||||
# Generate action description
|
# Generate action description
|
||||||
@ -730,12 +712,9 @@ class NavGPTAgent(BaseAgent):
|
|||||||
|
|
||||||
global FINAL_STOP_POINT
|
global FINAL_STOP_POINT
|
||||||
global TEMP_STEPS_COUNTER
|
global TEMP_STEPS_COUNTER
|
||||||
global STEPS_COUNTER
|
|
||||||
global FINAL_STATE
|
|
||||||
global NOW_LOCATION
|
global NOW_LOCATION
|
||||||
|
|
||||||
FINAL_STOP_POINT = obs[0]['gt_path'][-1]
|
FINAL_STOP_POINT = obs[0]['stop']
|
||||||
FINAL_STATE = ""
|
|
||||||
|
|
||||||
if TEMP_STEPS_COUNTER != 0:
|
if TEMP_STEPS_COUNTER != 0:
|
||||||
TEMP_STEPS_COUNTER = 0
|
TEMP_STEPS_COUNTER = 0
|
||||||
@ -748,6 +727,7 @@ class NavGPTAgent(BaseAgent):
|
|||||||
print(obs[0]['obs'])
|
print(obs[0]['obs'])
|
||||||
print(obs[0]['obs_summary'])
|
print(obs[0]['obs_summary'])
|
||||||
print(obs[0]['objects'])
|
print(obs[0]['objects'])
|
||||||
|
print(obs[0]['instr_id'])
|
||||||
print(obs[0]['scan'])
|
print(obs[0]['scan'])
|
||||||
print(obs[0]['viewpoint'])
|
print(obs[0]['viewpoint'])
|
||||||
print(obs[0]['heading'])
|
print(obs[0]['heading'])
|
||||||
@ -756,9 +736,9 @@ class NavGPTAgent(BaseAgent):
|
|||||||
print(obs[0]['instruction'])
|
print(obs[0]['instruction'])
|
||||||
print(obs[0]['gt_path'])
|
print(obs[0]['gt_path'])
|
||||||
print(obs[0]['path_id'])
|
print(obs[0]['path_id'])
|
||||||
|
print(obs[0]['stop'])
|
||||||
print(obs[0]['start'])
|
print(obs[0]['start'])
|
||||||
print(obs[0]['target'])
|
print(obs[0]['target'])
|
||||||
print(obs[0]['new_reverie_id'])
|
|
||||||
NOW_LOCATION = obs[0]['start']
|
NOW_LOCATION = obs[0]['start']
|
||||||
|
|
||||||
|
|
||||||
@ -839,11 +819,6 @@ class NavGPTAgent(BaseAgent):
|
|||||||
}
|
}
|
||||||
output = self.agent_executor(input)
|
output = self.agent_executor(input)
|
||||||
|
|
||||||
if 'stop' in FINAL_STATE:
|
|
||||||
self.traj[i]['final_state'] = 'stop'
|
|
||||||
else:
|
|
||||||
self.traj[i]['final_state'] = 'not found'
|
|
||||||
|
|
||||||
self.traj[i]['llm_output'] = output['output']
|
self.traj[i]['llm_output'] = output['output']
|
||||||
self.traj[i]['action_plan'] = output['action_plan']
|
self.traj[i]['action_plan'] = output['action_plan']
|
||||||
# extract agent's thought from llm output
|
# extract agent's thought from llm output
|
||||||
|
|||||||
@ -18,7 +18,6 @@ class BaseAgent(object):
|
|||||||
output[-1]['llm_output'] = v['llm_output']
|
output[-1]['llm_output'] = v['llm_output']
|
||||||
output[-1]['llm_thought'] = v['llm_thought']
|
output[-1]['llm_thought'] = v['llm_thought']
|
||||||
output[-1]['llm_observation'] = v['llm_observation']
|
output[-1]['llm_observation'] = v['llm_observation']
|
||||||
output[-1]['final_state'] = v['final_state']
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def rollout(self, **args):
|
def rollout(self, **args):
|
||||||
@ -51,8 +50,6 @@ class BaseAgent(object):
|
|||||||
else: # Do a full round
|
else: # Do a full round
|
||||||
while True:
|
while True:
|
||||||
for traj in self.rollout(**kwargs):
|
for traj in self.rollout(**kwargs):
|
||||||
print(f"ID: {traj['instr_id']}")
|
|
||||||
print(self.results.keys())
|
|
||||||
if traj['instr_id'] in self.results:
|
if traj['instr_id'] in self.results:
|
||||||
looped = True
|
looped = True
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -143,6 +143,7 @@ class Simulator(object):
|
|||||||
viewpoint_ID: str,
|
viewpoint_ID: str,
|
||||||
heading: int,
|
heading: int,
|
||||||
elevation: int,
|
elevation: int,
|
||||||
|
stop: str,
|
||||||
start: str,
|
start: str,
|
||||||
target: str
|
target: str
|
||||||
):
|
):
|
||||||
@ -150,6 +151,7 @@ class Simulator(object):
|
|||||||
self.elevation = elevation
|
self.elevation = elevation
|
||||||
self.scan_ID = scan_ID
|
self.scan_ID = scan_ID
|
||||||
self.viewpoint_ID = viewpoint_ID
|
self.viewpoint_ID = viewpoint_ID
|
||||||
|
self.stop = stop
|
||||||
self.start = start
|
self.start = start
|
||||||
self.target = target
|
self.target = target
|
||||||
# Load navigable dict
|
# Load navigable dict
|
||||||
@ -184,6 +186,7 @@ class Simulator(object):
|
|||||||
'heading': self.heading,
|
'heading': self.heading,
|
||||||
'elevation': self.elevation,
|
'elevation': self.elevation,
|
||||||
'candidate': self.candidate,
|
'candidate': self.candidate,
|
||||||
|
'stop': self.stop,
|
||||||
'start': self.start,
|
'start': self.start,
|
||||||
'target': self.target
|
'target': self.target
|
||||||
}
|
}
|
||||||
@ -230,9 +233,9 @@ class EnvBatch(object):
|
|||||||
def _make_id(self, scanId, viewpointId):
|
def _make_id(self, scanId, viewpointId):
|
||||||
return scanId + '_' + viewpointId
|
return scanId + '_' + viewpointId
|
||||||
|
|
||||||
def newEpisodes(self, scanIds, viewpointIds, headings, starts, targets):
|
def newEpisodes(self, scanIds, viewpointIds, headings, stops, starts, targets):
|
||||||
for i, (scanId, viewpointId, heading, start, target) in enumerate(zip(scanIds, viewpointIds, headings, starts, targets)):
|
for i, (scanId, viewpointId, heading, stop, start, target) in enumerate(zip(scanIds, viewpointIds, headings, stops, starts, targets)):
|
||||||
self.sims[i].newEpisode(scanId, viewpointId, heading, 0, start, target)
|
self.sims[i].newEpisode(scanId, viewpointId, heading, 0, stop, start, target)
|
||||||
|
|
||||||
def getStates(self):
|
def getStates(self):
|
||||||
"""
|
"""
|
||||||
@ -260,7 +263,7 @@ class REVERIENavBatch(object):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, view_db, instr_data, connectivity_dir, navigable_dir,
|
self, view_db, instr_data, connectivity_dir, navigable_dir,
|
||||||
batch_size=1, seed=0, name=None
|
batch_size=1, seed=0, name=None, data_limit=100
|
||||||
):
|
):
|
||||||
self.env = EnvBatch(navigable_dir, feat_db=view_db, batch_size=batch_size)
|
self.env = EnvBatch(navigable_dir, feat_db=view_db, batch_size=batch_size)
|
||||||
self.data = instr_data
|
self.data = instr_data
|
||||||
@ -269,15 +272,14 @@ class REVERIENavBatch(object):
|
|||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
self.gt_trajs = self._get_gt_trajs(self.data) # for evaluation
|
#self.gt_trajs = self._get_gt_trajs(self.data) # for evaluation
|
||||||
|
|
||||||
# use different seeds in different processes to shuffle data
|
# use different seeds in different processes to shuffle data
|
||||||
'''
|
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
random.seed(self.seed)
|
random.seed(self.seed)
|
||||||
random.shuffle(self.data)
|
random.shuffle(self.data)
|
||||||
'''
|
|
||||||
|
|
||||||
|
self.data = self.data[:data_limit]
|
||||||
|
|
||||||
self.ix = 0
|
self.ix = 0
|
||||||
self._load_nav_graphs()
|
self._load_nav_graphs()
|
||||||
@ -286,12 +288,14 @@ class REVERIENavBatch(object):
|
|||||||
print('%s loaded with %d instructions, using splits: %s' % (
|
print('%s loaded with %d instructions, using splits: %s' % (
|
||||||
self.__class__.__name__, len(self.data), self.name))
|
self.__class__.__name__, len(self.data), self.name))
|
||||||
|
|
||||||
|
'''
|
||||||
def _get_gt_trajs(self, data):
|
def _get_gt_trajs(self, data):
|
||||||
gt_trajs = {
|
gt_trajs = {
|
||||||
x['new_reverie_id']: (x['scan'], x['path']) \
|
x['instr_id']: (x['scan'], x['path']) \
|
||||||
for x in data if len(x['path']) > 1
|
for x in data if len(x['path']) > 1
|
||||||
}
|
}
|
||||||
return gt_trajs
|
return gt_trajs
|
||||||
|
'''
|
||||||
|
|
||||||
def size(self):
|
def size(self):
|
||||||
return len(self.data)
|
return len(self.data)
|
||||||
@ -346,7 +350,7 @@ class REVERIENavBatch(object):
|
|||||||
'obs' : feature["detail"],
|
'obs' : feature["detail"],
|
||||||
'obs_summary' : feature["summary"],
|
'obs_summary' : feature["summary"],
|
||||||
'objects' : feature["objects"],
|
'objects' : feature["objects"],
|
||||||
# 'instr_id' : item['instr_id'],
|
'instr_id' : item['instr_id'],
|
||||||
# 'action_plan' : item['action_plan'],
|
# 'action_plan' : item['action_plan'],
|
||||||
'scan' : state['scanID'],
|
'scan' : state['scanID'],
|
||||||
'viewpoint' : state['viewpointID'],
|
'viewpoint' : state['viewpointID'],
|
||||||
@ -356,8 +360,8 @@ class REVERIENavBatch(object):
|
|||||||
'instruction' : item['instruction'],
|
'instruction' : item['instruction'],
|
||||||
'gt_path' : item['path'],
|
'gt_path' : item['path'],
|
||||||
'path_id' : item['path_id'],
|
'path_id' : item['path_id'],
|
||||||
|
'stop': item['stop'],
|
||||||
'start': item['start'],
|
'start': item['start'],
|
||||||
'new_reverie_id': item['new_reverie_id'],
|
|
||||||
'target': item['target']
|
'target': item['target']
|
||||||
}
|
}
|
||||||
# RL reward. The negative distance between the state and the final state
|
# RL reward. The negative distance between the state and the final state
|
||||||
@ -380,9 +384,10 @@ class REVERIENavBatch(object):
|
|||||||
scanIds = [item['scan'] for item in self.batch]
|
scanIds = [item['scan'] for item in self.batch]
|
||||||
viewpointIds = [item['path'][0] for item in self.batch]
|
viewpointIds = [item['path'][0] for item in self.batch]
|
||||||
headings = [item['heading'] for item in self.batch]
|
headings = [item['heading'] for item in self.batch]
|
||||||
|
stops = [item['stop'] for item in self.batch]
|
||||||
starts = [item['start'] for item in self.batch]
|
starts = [item['start'] for item in self.batch]
|
||||||
targets = [item['target'] for item in self.batch]
|
targets = [item['target'] for item in self.batch]
|
||||||
self.env.newEpisodes(scanIds, starts, headings, starts, targets)
|
self.env.newEpisodes(scanIds, starts, headings, stops, starts, targets)
|
||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
|
||||||
def step(self, next_viewpoint_IDs):
|
def step(self, next_viewpoint_IDs):
|
||||||
@ -407,7 +412,7 @@ class REVERIENavBatch(object):
|
|||||||
shortest_distances = self.shortest_distances[scan]
|
shortest_distances = self.shortest_distances[scan]
|
||||||
|
|
||||||
path = sum(pred_path, [])
|
path = sum(pred_path, [])
|
||||||
# assert gt_path[0] == path[0], 'Result trajectories should include the start position'
|
assert gt_path[0] == path[0], 'Result trajectories should include the start position'
|
||||||
|
|
||||||
nearest_position = self._get_nearest(shortest_distances, gt_path[-1], path)
|
nearest_position = self._get_nearest(shortest_distances, gt_path[-1], path)
|
||||||
|
|
||||||
@ -421,7 +426,7 @@ class REVERIENavBatch(object):
|
|||||||
gt_lengths = np.sum([shortest_distances[a][b] for a, b in zip(gt_path[:-1], gt_path[1:])])
|
gt_lengths = np.sum([shortest_distances[a][b] for a, b in zip(gt_path[:-1], gt_path[1:])])
|
||||||
|
|
||||||
scores['success'] = float(scores['nav_error'] < ERROR_MARGIN)
|
scores['success'] = float(scores['nav_error'] < ERROR_MARGIN)
|
||||||
# scores['spl'] = scores['success'] * gt_lengths / max(scores['trajectory_lengths'], gt_lengths, 0.01)
|
scores['spl'] = scores['success'] * gt_lengths / max(scores['trajectory_lengths'], gt_lengths, 0.01)
|
||||||
scores['oracle_success'] = float(scores['oracle_error'] < ERROR_MARGIN)
|
scores['oracle_success'] = float(scores['oracle_error'] < ERROR_MARGIN)
|
||||||
|
|
||||||
scores.update(
|
scores.update(
|
||||||
@ -454,7 +459,7 @@ class REVERIENavBatch(object):
|
|||||||
'oracle_error': np.mean(metrics['oracle_error']),
|
'oracle_error': np.mean(metrics['oracle_error']),
|
||||||
'sr': np.mean(metrics['success']) * 100,
|
'sr': np.mean(metrics['success']) * 100,
|
||||||
'oracle_sr': np.mean(metrics['oracle_success']) * 100,
|
'oracle_sr': np.mean(metrics['oracle_success']) * 100,
|
||||||
# 'spl': np.mean(metrics['spl']) * 100,
|
'spl': np.mean(metrics['spl']) * 100,
|
||||||
'nDTW': np.mean(metrics['nDTW']) * 100,
|
'nDTW': np.mean(metrics['nDTW']) * 100,
|
||||||
'SDTW': np.mean(metrics['SDTW']) * 100,
|
'SDTW': np.mean(metrics['SDTW']) * 100,
|
||||||
'CLS': np.mean(metrics['CLS']) * 100,
|
'CLS': np.mean(metrics['CLS']) * 100,
|
||||||
|
|||||||
@ -7,8 +7,8 @@ def parse_args():
|
|||||||
|
|
||||||
# datasets
|
# datasets
|
||||||
parser.add_argument('--root_dir', type=str, default='../datasets')
|
parser.add_argument('--root_dir', type=str, default='../datasets')
|
||||||
parser.add_argument('--dataset', type=str, default='reverie', choices=['r2r', 'r4r', 'reverie'])
|
parser.add_argument('--dataset', type=str, default='r2r', choices=['r2r', 'r4r'])
|
||||||
parser.add_argument('--output_dir', type=str, default='../datasets/REVERIE/exprs/gpt-3.5-turbo', help='experiment id')
|
parser.add_argument('--output_dir', type=str, default='../datasets/R2R/exprs/gpt-3.5-turbo', help='experiment id')
|
||||||
# parser.add_argument('--output_dir', type=str, default='../datasets/R2R/exprs/LlaMA-2-13b-test', help='experiment id')
|
# parser.add_argument('--output_dir', type=str, default='../datasets/R2R/exprs/LlaMA-2-13b-test', help='experiment id')
|
||||||
parser.add_argument('--seed', type=int, default=0)
|
parser.add_argument('--seed', type=int, default=0)
|
||||||
|
|
||||||
@ -21,7 +21,7 @@ def parse_args():
|
|||||||
parser.add_argument('--max_iterations', type=int, default=25)
|
parser.add_argument('--max_iterations', type=int, default=25)
|
||||||
|
|
||||||
# General config
|
# General config
|
||||||
parser.add_argument('--iters', type=int, default=None, help='number of iterations to run')
|
parser.add_argument('--iters', type=int, default=10, help='number of iterations to run')
|
||||||
# parser.add_argument('--iters', type=int, default=None, help='number of iterations to run')
|
# parser.add_argument('--iters', type=int, default=None, help='number of iterations to run')
|
||||||
parser.add_argument('--max_scratchpad_length', type=int, default=1000, help='max number of steps in an episode')
|
parser.add_argument('--max_scratchpad_length', type=int, default=1000, help='max number of steps in an episode')
|
||||||
parser.add_argument('--test', action='store_true', default=False)
|
parser.add_argument('--test', action='store_true', default=False)
|
||||||
|
|||||||
@ -250,10 +250,9 @@ You will receive a trajectory instruction at the start and will have access to s
|
|||||||
|
|
||||||
Explore the environment while avoiding revisiting viewpoints by comparing current and previously visited IDs. Reach within 3 meters of the instructed destination, and if it's visible but no objects are detected, move closer.
|
Explore the environment while avoiding revisiting viewpoints by comparing current and previously visited IDs. Reach within 3 meters of the instructed destination, and if it's visible but no objects are detected, move closer.
|
||||||
|
|
||||||
At each step, determine if you've reached the destination(If the object is more than three meters away from you, you are not considered to have reached the destination).
|
At each step, determine if you've reached the destination.
|
||||||
If yes, stop and output 'Final Answer: Finished!'.
|
If yes, stop and output 'Final Answer: Finished!'.
|
||||||
If not, continue by considering your location and the next viewpoint based on the instruction, using the action_maker tool.
|
If not, continue by considering your location and the next viewpoint based on the instruction, using the action_maker tool.
|
||||||
And if you explored all room, you think this object doesn't exist in this room. stop and output 'Final Answer: Not found!'.
|
|
||||||
Show your reasoning in the Thought section.
|
Show your reasoning in the Thought section.
|
||||||
|
|
||||||
Follow the given format and use provided tools.
|
Follow the given format and use provided tools.
|
||||||
@ -272,11 +271,6 @@ Observation: the result of the action
|
|||||||
... (this Thought/Action/Action Input/Observation can repeat N times)
|
... (this Thought/Action/Action Input/Observation can repeat N times)
|
||||||
Thought: I have reached the destination, I can stop.
|
Thought: I have reached the destination, I can stop.
|
||||||
Final Answer: Finished!
|
Final Answer: Finished!
|
||||||
|
|
||||||
or
|
|
||||||
|
|
||||||
Thought: I cannot find the object in this room, I should stop.
|
|
||||||
Final Answer: Not found!
|
|
||||||
----
|
----
|
||||||
|
|
||||||
Begin!
|
Begin!
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user