feat: random choose the final stop node
This commit is contained in:
parent
68330c5163
commit
82e2c7e053
125
nav_src/agent.py
125
nav_src/agent.py
@ -1,4 +1,5 @@
|
||||
"""Agent that interacts with Matterport3D simulator via a hierarchical planning approach."""
|
||||
import random
|
||||
import json
|
||||
import yaml
|
||||
import re
|
||||
@ -711,123 +712,29 @@ class NavGPTAgent(BaseAgent):
|
||||
obs = self.env._get_obs()
|
||||
|
||||
global FINAL_STOP_POINT
|
||||
global TEMP_STEPS_COUNTER
|
||||
global NOW_LOCATION
|
||||
global SUCCESS
|
||||
|
||||
FINAL_STOP_POINT = obs[0]['stop']
|
||||
|
||||
if TEMP_STEPS_COUNTER != 0:
|
||||
TEMP_STEPS_COUNTER = 0
|
||||
|
||||
print(f"HAVE SET FINAL_STOP_POINT = {FINAL_STOP_POINT}")
|
||||
|
||||
print(len(obs))
|
||||
|
||||
print(obs[0].keys())
|
||||
print(obs[0]['obs'])
|
||||
print(obs[0]['obs_summary'])
|
||||
print(obs[0]['objects'])
|
||||
print(obs[0]['instr_id'])
|
||||
print(obs[0]['scan'])
|
||||
print(obs[0]['viewpoint'])
|
||||
print(obs[0]['heading'])
|
||||
print(obs[0]['elevation'])
|
||||
print(obs[0]['candidate'])
|
||||
print(obs[0]['instruction'])
|
||||
print(obs[0]['gt_path'])
|
||||
print(obs[0]['path_id'])
|
||||
print(obs[0]['stop'])
|
||||
print(obs[0]['start'])
|
||||
print(obs[0]['target'])
|
||||
NOW_LOCATION = obs[0]['start']
|
||||
|
||||
|
||||
print("==")
|
||||
|
||||
# Initialize the trajectory
|
||||
self.init_trajecotry(obs)
|
||||
|
||||
# Load the instruction
|
||||
# instructions = [ob['instruction'] for ob in obs]
|
||||
targets = [ob['target'] for ob in obs]
|
||||
print(obs[0].keys())
|
||||
print(obs[0]['start'])
|
||||
print(obs[0]['stop'])
|
||||
print(obs[0]['target'])
|
||||
candidates = self.env.env.sims[0].getNodesInTheRoom()
|
||||
candidates.remove(obs[0]['start'])
|
||||
|
||||
next_point = random.choice(candidates)
|
||||
print(next_point)
|
||||
|
||||
|
||||
print(self.config.load_instruction)
|
||||
print(self.config.load_action_plan)
|
||||
if next_point == FINAL_STOP_POINT:
|
||||
print(" SUCCESS")
|
||||
SUCCESS += 1
|
||||
|
||||
if self.config.load_instruction:
|
||||
# action_plans = instructions
|
||||
action_plans = targets
|
||||
elif self.config.load_action_plan:
|
||||
action_plans = [ob['action_plan'] for ob in obs]
|
||||
else:
|
||||
action_plans = []
|
||||
for instruction in instructions:
|
||||
action_plan = self.plan_chain.run(instruction = instruction)
|
||||
action_plans.append(action_plan)
|
||||
print(action_plans)
|
||||
|
||||
for i, init_ob in enumerate(obs):
|
||||
|
||||
# for our work
|
||||
# cur_action_plan is "target object with its location"
|
||||
self.cur_action_plan = action_plans[i]
|
||||
|
||||
print("use_tool_chain:", self.config.use_tool_chain)
|
||||
|
||||
# Take the first action
|
||||
if self.config.use_tool_chain: # we will not HERE
|
||||
first_obs = self.action_maker('')
|
||||
input = {
|
||||
'action_plan': self.cur_action_plan,
|
||||
'init_observation': init_ob['obs_summary'],
|
||||
'observation': first_obs,
|
||||
}
|
||||
else:
|
||||
# Get current feature
|
||||
|
||||
# we are HERE
|
||||
feature = init_ob['obs']
|
||||
navigable = init_ob['candidate']
|
||||
objects = init_ob['objects']
|
||||
heading = np.rad2deg(init_ob['heading'])
|
||||
elevation = np.rad2deg(init_ob['elevation'])
|
||||
orientation = f'\nheading: {heading:.2f}, elevation: {elevation:.2f}'
|
||||
|
||||
print("use_relative_angle:", self.config.use_relative_angle)
|
||||
print("use_relative_angle:", self.config.use_navigable)
|
||||
if self.config.use_relative_angle: # True
|
||||
feature = self.modify_heading_angles(heading, feature, navigable, objects)
|
||||
if self.config.use_navigable: # False
|
||||
navigable = self.get_navigable_str(heading, elevation, navigable)
|
||||
|
||||
if self.config.use_relative_angle:
|
||||
if self.config.use_navigable:
|
||||
init_observation = f"\n\tCurrent Viewpoint:\n{feature}\n\tNavigable Viewpoints:\n{navigable}"
|
||||
else:
|
||||
init_observation = f"\n\tCurrent Viewpoint:\n{feature}"
|
||||
else:
|
||||
if self.config.use_navigable:
|
||||
init_observation = f"\n\tCurrent Orientation:\n{orientation}\n\tCurrent Viewpoint:\n{feature}\n\tNavigable Viewpoints:\n{navigable}"
|
||||
else:
|
||||
init_observation = f"\n\tCurrent Orientation:\n{orientation}\n\tCurrent Viewpoint:\n{feature}"
|
||||
|
||||
|
||||
input = {
|
||||
'action_plan': self.cur_action_plan, # here will be "object & its location" in our work
|
||||
'init_observation': init_observation, # 8 direction's observation caption & navigable point & objects
|
||||
}
|
||||
output = self.agent_executor(input)
|
||||
|
||||
self.traj[i]['llm_output'] = output['output']
|
||||
self.traj[i]['action_plan'] = output['action_plan']
|
||||
# extract agent's thought from llm output
|
||||
intermediate_steps = output['intermediate_steps']
|
||||
self.traj[i]['llm_thought'] = []
|
||||
self.traj[i]['llm_observation'] = []
|
||||
for action, observation in intermediate_steps:
|
||||
thought = action.log
|
||||
self.traj[i]['llm_thought'].append(thought)
|
||||
self.traj[i]['llm_observation'].append(observation)
|
||||
print(f"SUCCESS={SUCCESS}")
|
||||
|
||||
return self.traj
|
||||
|
||||
|
||||
@ -38,15 +38,18 @@ class BaseAgent(object):
|
||||
if iters is not None:
|
||||
# For each time, it will run the first 'iters' iterations. (It was shuffled before)
|
||||
for i in range(iters):
|
||||
print(i)
|
||||
for traj in self.rollout(**kwargs):
|
||||
self.loss = 0
|
||||
self.results[traj['instr_id']] = traj
|
||||
'''
|
||||
preds_detail = self.get_results(detailed_output=True)
|
||||
json.dump(
|
||||
preds_detail,
|
||||
open(os.path.join(self.config.log_dir, 'runtime.json'), 'w'),
|
||||
sort_keys=True, indent=4, separators=(',', ': ')
|
||||
)
|
||||
'''
|
||||
else: # Do a full round
|
||||
while True:
|
||||
for traj in self.rollout(**kwargs):
|
||||
|
||||
@ -174,6 +174,15 @@ class Simulator(object):
|
||||
# Get candidate
|
||||
self.getCandidate()
|
||||
|
||||
def getNodesInTheRoom(self):
|
||||
ans = []
|
||||
start_region = self.node_region[self.scan_ID][self.viewpoint_ID]
|
||||
for node, region_id in self.node_region[self.scan_ID].items():
|
||||
if region_id == start_region:
|
||||
ans.append(node)
|
||||
return ans
|
||||
|
||||
|
||||
def updateGraph(self):
|
||||
# build graph
|
||||
for candidate in self.candidate.keys():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user