feat: add target in state & let LLM use target to find
commit fc0ccb1458
parent 0546814202
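Overview (editorial sketch): the commit threads a per-episode 'target' string (the goal object and where it is located) from the REVERIE batch items through the simulator state and into the agent's observations, and then feeds that string to the LLM as the action plan. The Python sketch below is a minimal, self-contained illustration of the env-side flow, not the repository's code; the class names and the 'target' key follow the hunks below, everything else (including the sample target text) is made up.

# Editorial sketch, not the repository's code: a minimal illustration of how the
# new per-episode 'target' string travels from the batch items into the state
# dict, mirroring the Simulator / EnvBatch / REVERIENavBatch hunks below.
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class SimulatorSketch:
    """Stand-in for Simulator: keeps 'target' next to 'stop' and 'start'."""
    scan_ID: str = ''
    viewpoint_ID: str = ''
    heading: float = 0.0
    elevation: float = 0.0
    stop: str = ''
    start: str = ''
    target: str = ''

    def newEpisode(self, scan_ID, viewpoint_ID, heading, elevation, stop, start, target):
        self.scan_ID, self.viewpoint_ID = scan_ID, viewpoint_ID
        self.heading, self.elevation = heading, elevation
        self.stop, self.start = stop, start
        self.target = target                      # new field carried per episode

    def getState(self) -> Dict:
        # Mirrors the getState() hunk: 'target' is now part of the state dict.
        return {
            'scan': self.scan_ID,
            'viewpoint': self.viewpoint_ID,
            'heading': self.heading,
            'elevation': self.elevation,
            'stop': self.stop,
            'start': self.start,
            'target': self.target,
        }


class EnvBatchSketch:
    """Stand-in for EnvBatch: fans the extra 'targets' list out to each simulator."""
    def __init__(self, batch_size: int):
        self.sims = [SimulatorSketch() for _ in range(batch_size)]

    def newEpisodes(self, scanIds, viewpointIds, headings, stops, starts, targets):
        for i, (scan, vp, heading, stop, start, target) in enumerate(
                zip(scanIds, viewpointIds, headings, stops, starts, targets)):
            self.sims[i].newEpisode(scan, vp, heading, 0, stop, start, target)

    def getStates(self) -> List[Dict]:
        return [sim.getState() for sim in self.sims]


if __name__ == '__main__':
    # One hypothetical REVERIE-style batch item; the 'target' text is made up.
    batch = [{'scan': 'scan0', 'viewpoint': 'vp0', 'heading': 0.0,
              'stop': 'vp9', 'start': 'vp0',
              'target': 'the blue cushion on the sofa in the living room'}]
    env = EnvBatchSketch(batch_size=len(batch))
    env.newEpisodes([b['scan'] for b in batch],
                    [b['viewpoint'] for b in batch],
                    [b['heading'] for b in batch],
                    [b['stop'] for b in batch],
                    [b['start'] for b in batch],
                    [b['target'] for b in batch])
    print(env.getStates()[0]['target'])           # the target string reaches the obs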
@@ -197,6 +197,7 @@ class NavGPTAgent(BaseAgent):
         self.output_parser = NavGPTOutputParser()
         self.agent_executor = self.create_vln_agent()
+        print("AGENT_EXECUTOR: ", type(self.agent_executor))

         plan_prompt = PromptTemplate(
             template=PLANNER_PROMPT,
@@ -331,11 +332,11 @@ class NavGPTAgent(BaseAgent):
         # Record the navigation path
         self.traj = [{
             'instr_id': ob['instr_id'],
-            'path': [[ob['viewpoint']]],
+            'path': [[ob['start']]],
             'details': [],
         } for ob in obs]
         # Record the history of actions taken
-        self.agent_executor.agent.history = [f'Navigation start, no actions taken yet.\nCurrent viewpoint "{obs[0]["viewpoint"]}": Scene from the viewpoint is a {obs[0]["obs_summary"]}']
+        self.agent_executor.agent.history = [f'Navigation start, no actions taken yet.\nCurrent viewpoint "{obs[0]["start"]}": Scene from the viewpoint is a {obs[0]["obs_summary"]}']

     def _create_make_action_tool(
         self,
@@ -585,7 +586,9 @@ class NavGPTAgent(BaseAgent):
                 },
             )
         elif self.config.use_single_action:
+            # We will be here
             tools = [self.action_maker]
+            print(tools)
             prompt = PromptTemplate(
                 template=VLN_GPT4_PROMPT if self.config.llm_model_name == 'gpt-4' else VLN_GPT35_PROMPT,
                 input_variables=["action_plan", "init_observation", "agent_scratchpad"],
@@ -659,13 +662,42 @@ class NavGPTAgent(BaseAgent):
         else:
             obs = self.env._get_obs()

+        print(len(obs))
+
+        print(obs[0].keys())
+        print(obs[0]['obs'])
+        print(obs[0]['obs_summary'])
+        print(obs[0]['objects'])
+        print(obs[0]['instr_id'])
+        print(obs[0]['scan'])
+        print(obs[0]['viewpoint'])
+        print(obs[0]['heading'])
+        print(obs[0]['elevation'])
+        print(obs[0]['candidate'])
+        print(obs[0]['instruction'])
+        print(obs[0]['gt_path'])
+        print(obs[0]['path_id'])
+        print(obs[0]['stop'])
+        print(obs[0]['start'])
+        print(obs[0]['target'])
+
+
+        print("==")
+
         # Initialize the trajectory
         self.init_trajecotry(obs)

         # Load the instruction
-        instructions = [ob['instruction'] for ob in obs]
+        # instructions = [ob['instruction'] for ob in obs]
+        targets = [ob['target'] for ob in obs]
+
+
+        print(self.config.load_instruction)
+        print(self.config.load_action_plan)
+
         if self.config.load_instruction:
-            action_plans = instructions
+            # action_plans = instructions
+            action_plans = targets
         elif self.config.load_action_plan:
             action_plans = [ob['action_plan'] for ob in obs]
         else:
@@ -673,11 +705,18 @@ class NavGPTAgent(BaseAgent):
             for instruction in instructions:
                 action_plan = self.plan_chain.run(instruction = instruction)
                 action_plans.append(action_plan)
+        print(action_plans)

         for i, init_ob in enumerate(obs):
+
+            # for our work
+            # cur_action_plan is "target object with its location"
             self.cur_action_plan = action_plans[i]
+
+            print("use_tool_chain:", self.config.use_tool_chain)
+
             # Take the first action
-            if self.config.use_tool_chain:
+            if self.config.use_tool_chain: # we will not be here
                 first_obs = self.action_maker('')
                 input = {
                     'action_plan': self.cur_action_plan,
@@ -686,15 +725,20 @@ class NavGPTAgent(BaseAgent):
                 }
             else:
                 # Get current feature
+
+                # we are HERE
                 feature = init_ob['obs']
                 navigable = init_ob['candidate']
                 objects = init_ob['objects']
                 heading = np.rad2deg(init_ob['heading'])
                 elevation = np.rad2deg(init_ob['elevation'])
                 orientation = f'\nheading: {heading:.2f}, elevation: {elevation:.2f}'
-                if self.config.use_relative_angle:
+
+                print("use_relative_angle:", self.config.use_relative_angle)
+                print("use_navigable:", self.config.use_navigable)
+                if self.config.use_relative_angle: # True
                     feature = self.modify_heading_angles(heading, feature, navigable, objects)
-                if self.config.use_navigable:
+                if self.config.use_navigable: # False
                     navigable = self.get_navigable_str(heading, elevation, navigable)

                 if self.config.use_relative_angle:
@@ -708,9 +752,10 @@ class NavGPTAgent(BaseAgent):
                 else:
                     init_observation = f"\n\tCurrent Orientation:\n{orientation}\n\tCurrent Viewpoint:\n{feature}"

+
                 input = {
-                    'action_plan': self.cur_action_plan,
-                    'init_observation': init_observation,
+                    'action_plan': self.cur_action_plan, # here will be "object & its location" in our work
+                    'init_observation': init_observation, # 8 direction's observation caption & navigable point & objects
                 }
             output = self.agent_executor(input)

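On the agent side, when load_instruction is set, action_plans now holds the target strings, so the 'action_plan' slot of the prompt receives the target object with its location instead of a step-by-step instruction. The sketch below only illustrates the shape of that input; the prompt template is made up, while the real templates in the repository are VLN_GPT4_PROMPT and VLN_GPT35_PROMPT.

# Illustrative sketch only, not the repository's code.
EXAMPLE_PROMPT = (
    "You are navigating an indoor scene.\n"
    "Goal (action_plan): {action_plan}\n"
    "Initial observation:\n{init_observation}\n"
)


def build_agent_input(ob: dict, init_observation: str) -> dict:
    # With this commit, 'action_plan' carries the REVERIE target
    # ("object and its location") rather than a step-by-step instruction.
    return {
        'action_plan': ob['target'],
        'init_observation': init_observation,
    }


if __name__ == '__main__':
    ob = {'target': 'the potted plant next to the bathtub in the bathroom'}
    agent_input = build_agent_input(ob, 'heading: 0.00, elevation: 0.00, ...')
    print(EXAMPLE_PROMPT.format(**agent_input))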
@@ -43,9 +43,9 @@ class BaseAgent(object):
                     self.results[traj['instr_id']] = traj
                     preds_detail = self.get_results(detailed_output=True)
                     json.dump(
-                        preds_detail,
-                        open(os.path.join(self.config.log_dir, 'runtime.json'), 'w'),
-                        sort_keys=True, indent=4, separators=(',', ': ')
+                        preds_detail,
+                        open(os.path.join(self.config.log_dir, 'runtime.json'), 'w'),
+                        sort_keys=True, indent=4, separators=(',', ': ')
                     )
         else:   # Do a full round
             while True:
@@ -35,13 +35,16 @@ class Simulator(object):
                    heading: int,
                    elevation: int,
                    stop: str,
-                   start: str):
+                   start: str,
+                   target: str
+                   ):
         self.heading = heading
         self.elevation = elevation
         self.scan_ID = scan_ID
         self.viewpoint_ID = viewpoint_ID
         self.stop = stop
         self.start = start
+        self.target = target
         # Load navigable dict
         navigable_path = os.path.join(self.navigable_dir, self.scan_ID + '_navigable.json')
         with open(navigable_path, 'r') as f:
@@ -63,6 +66,7 @@ class Simulator(object):
             'candidate': self.candidate,
             'stop': self.stop,
             'start': self.start,
+            'target': self.target
         }
         return self.state

@@ -107,9 +111,9 @@ class EnvBatch(object):
     def _make_id(self, scanId, viewpointId):
         return scanId + '_' + viewpointId

-    def newEpisodes(self, scanIds, viewpointIds, headings, stops, starts):
-        for i, (scanId, viewpointId, heading, stop, start) in enumerate(zip(scanIds, viewpointIds, headings, stops, starts)):
-            self.sims[i].newEpisode(scanId, viewpointId, heading, 0, stop, start)
+    def newEpisodes(self, scanIds, viewpointIds, headings, stops, starts, targets):
+        for i, (scanId, viewpointId, heading, stop, start, target) in enumerate(zip(scanIds, viewpointIds, headings, stops, starts, targets)):
+            self.sims[i].newEpisode(scanId, viewpointId, heading, 0, stop, start, target)

     def getStates(self):
         """
@@ -233,7 +237,8 @@ class REVERIENavBatch(object):
                 'gt_path' : item['path'],
                 'path_id' : item['path_id'],
                 'stop': item['stop'],
-                'start': item['start']
+                'start': item['start'],
+                'target': item['target']
             }
             # RL reward. The negative distance between the state and the final state
             # There are multiple gt end viewpoints on REVERIE.
@@ -257,7 +262,8 @@ class REVERIENavBatch(object):
         headings = [item['heading'] for item in self.batch]
         stops = [item['stop'] for item in self.batch]
         starts = [item['start'] for item in self.batch]
-        self.env.newEpisodes(scanIds, viewpointIds, headings, stops, starts)
+        targets = [item['target'] for item in self.batch]
+        self.env.newEpisodes(scanIds, viewpointIds, headings, stops, starts, targets)
         return self._get_obs()

     def step(self, next_viewpoint_IDs):
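A data-side implication, sketched under the assumption that the loader reads items as plain dicts: every REVERIE item consumed by REVERIENavBatch now needs a 'target' field alongside 'stop' and 'start'. The example item below is hypothetical; only the key names come from the diff.

# Hypothetical example item, not taken from the dataset: only the key names
# ('stop', 'start', 'target', ...) follow the fields read in the hunks above.
import json

example_item = {
    "scan": "scan0",
    "path_id": 42,
    "path": ["vp0", "vp1", "vp2"],
    "heading": 0.0,
    "instr_id": "42_0",
    "instruction": "Go to the bathroom and wipe the mirror.",
    "stop": "vp2",
    "start": "vp0",
    "target": "the mirror above the sink in the bathroom",  # new required field
}
print(json.dumps(example_item, indent=2))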