feat: add target to the observation state & let the LLM use the target to find the target object

This commit is contained in:
Ting-Jun Wang 2024-04-28 22:30:12 +08:00
parent 0546814202
commit fc0ccb1458
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354
3 changed files with 69 additions and 18 deletions

View File

@ -197,6 +197,7 @@ class NavGPTAgent(BaseAgent):
self.output_parser = NavGPTOutputParser()
self.agent_executor = self.create_vln_agent()
print("AGENT_EXECUTOR: ", type(self.agent_executor))
plan_prompt = PromptTemplate(
template=PLANNER_PROMPT,
@ -331,11 +332,11 @@ class NavGPTAgent(BaseAgent):
# Record the navigation path
self.traj = [{
'instr_id': ob['instr_id'],
'path': [[ob['viewpoint']]],
'path': [[ob['start']]],
'details': [],
} for ob in obs]
# Record the history of actions taken
self.agent_executor.agent.history = [f'Navigation start, no actions taken yet.\nCurrent viewpoint "{obs[0]["viewpoint"]}": Scene from the viewpoint is a {obs[0]["obs_summary"]}']
self.agent_executor.agent.history = [f'Navigation start, no actions taken yet.\nCurrent viewpoint "{obs[0]["start"]}": Scene from the viewpoint is a {obs[0]["obs_summary"]}']
def _create_make_action_tool(
self,
@ -585,7 +586,9 @@ class NavGPTAgent(BaseAgent):
},
)
elif self.config.use_single_action:
# We will be here
tools = [self.action_maker]
print(tools)
prompt = PromptTemplate(
template=VLN_GPT4_PROMPT if self.config.llm_model_name == 'gpt-4' else VLN_GPT35_PROMPT,
input_variables=["action_plan", "init_observation", "agent_scratchpad"],
@ -659,13 +662,42 @@ class NavGPTAgent(BaseAgent):
else:
obs = self.env._get_obs()
print(len(obs))
print(obs[0].keys())
print(obs[0]['obs'])
print(obs[0]['obs_summary'])
print(obs[0]['objects'])
print(obs[0]['instr_id'])
print(obs[0]['scan'])
print(obs[0]['viewpoint'])
print(obs[0]['heading'])
print(obs[0]['elevation'])
print(obs[0]['candidate'])
print(obs[0]['instruction'])
print(obs[0]['gt_path'])
print(obs[0]['path_id'])
print(obs[0]['stop'])
print(obs[0]['start'])
print(obs[0]['target'])
print("==")
# Initialize the trajectory
self.init_trajecotry(obs)
# Load the instruction
instructions = [ob['instruction'] for ob in obs]
# instructions = [ob['instruction'] for ob in obs]
targets = [ob['target'] for ob in obs]
print(self.config.load_instruction)
print(self.config.load_action_plan)
if self.config.load_instruction:
action_plans = instructions
# action_plans = instructions
action_plans = targets
elif self.config.load_action_plan:
action_plans = [ob['action_plan'] for ob in obs]
else:
@ -673,11 +705,18 @@ class NavGPTAgent(BaseAgent):
for instruction in instructions:
action_plan = self.plan_chain.run(instruction = instruction)
action_plans.append(action_plan)
print(action_plans)
for i, init_ob in enumerate(obs):
# for our work
# cur_action_plan is "target object with its location"
self.cur_action_plan = action_plans[i]
print("use_tool_chain:", self.config.use_tool_chain)
# Take the first action
if self.config.use_tool_chain:
if self.config.use_tool_chain: # we will not HERE
first_obs = self.action_maker('')
input = {
'action_plan': self.cur_action_plan,
@ -686,15 +725,20 @@ class NavGPTAgent(BaseAgent):
}
else:
# Get current feature
# we are HERE
feature = init_ob['obs']
navigable = init_ob['candidate']
objects = init_ob['objects']
heading = np.rad2deg(init_ob['heading'])
elevation = np.rad2deg(init_ob['elevation'])
orientation = f'\nheading: {heading:.2f}, elevation: {elevation:.2f}'
if self.config.use_relative_angle:
print("use_relative_angle:", self.config.use_relative_angle)
print("use_relative_angle:", self.config.use_navigable)
if self.config.use_relative_angle: # True
feature = self.modify_heading_angles(heading, feature, navigable, objects)
if self.config.use_navigable:
if self.config.use_navigable: # False
navigable = self.get_navigable_str(heading, elevation, navigable)
if self.config.use_relative_angle:
@ -708,9 +752,10 @@ class NavGPTAgent(BaseAgent):
else:
init_observation = f"\n\tCurrent Orientation:\n{orientation}\n\tCurrent Viewpoint:\n{feature}"
input = {
'action_plan': self.cur_action_plan,
'init_observation': init_observation,
'action_plan': self.cur_action_plan, # here will be "object & its location" in our work
'init_observation': init_observation, # 8 direction's observation caption & navigable point & objects
}
output = self.agent_executor(input)

View File

@ -35,13 +35,16 @@ class Simulator(object):
heading: int,
elevation: int,
stop: str,
start: str):
start: str,
target: str
):
self.heading = heading
self.elevation = elevation
self.scan_ID = scan_ID
self.viewpoint_ID = viewpoint_ID
self.stop = stop
self.start = start
self.target = target
# Load navigable dict
navigable_path = os.path.join(self.navigable_dir, self.scan_ID + '_navigable.json')
with open(navigable_path, 'r') as f:
@ -63,6 +66,7 @@ class Simulator(object):
'candidate': self.candidate,
'stop': self.stop,
'start': self.start,
'target': self.target
}
return self.state
@ -107,9 +111,9 @@ class EnvBatch(object):
def _make_id(self, scanId, viewpointId):
return scanId + '_' + viewpointId
def newEpisodes(self, scanIds, viewpointIds, headings, stops, starts):
for i, (scanId, viewpointId, heading, stop, start) in enumerate(zip(scanIds, viewpointIds, headings, stops, starts)):
self.sims[i].newEpisode(scanId, viewpointId, heading, 0, stop, start)
def newEpisodes(self, scanIds, viewpointIds, headings, stops, starts, targets):
for i, (scanId, viewpointId, heading, stop, start, target) in enumerate(zip(scanIds, viewpointIds, headings, stops, starts, targets)):
self.sims[i].newEpisode(scanId, viewpointId, heading, 0, stop, start, target)
def getStates(self):
"""
@ -233,7 +237,8 @@ class REVERIENavBatch(object):
'gt_path' : item['path'],
'path_id' : item['path_id'],
'stop': item['stop'],
'start': item['start']
'start': item['start'],
'target': item['target']
}
# RL reward. The negative distance between the state and the final state
# There are multiple gt end viewpoints on REVERIE.
@ -257,7 +262,8 @@ class REVERIENavBatch(object):
headings = [item['heading'] for item in self.batch]
stops = [item['stop'] for item in self.batch]
starts = [item['start'] for item in self.batch]
self.env.newEpisodes(scanIds, viewpointIds, headings, stops, starts)
targets = [item['target'] for item in self.batch]
self.env.newEpisodes(scanIds, viewpointIds, headings, stops, starts, targets)
return self._get_obs()
def step(self, next_viewpoint_IDs):