feat: add target to state & let the LLM use the target to find the goal

Ting-Jun Wang 2024-04-28 22:30:12 +08:00
parent 0546814202
commit fc0ccb1458
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354
3 changed files with 69 additions and 18 deletions
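
In short, this change threads a 'target' field (the REVERIE goal object and where it is) from the dataset batch through the simulator state into each observation, and then hands that target to the LLM as the action plan in place of the full instruction. A minimal, self-contained sketch of the resulting data flow (all values and dict shapes below are invented for illustration; only the key names are taken from the diff):

    # Sketch of the data flow added by this commit (values invented, shapes simplified).
    batch_item = {
        'scan': 'scan_0001',
        'start': 'viewpoint_A',
        'stop': 'viewpoint_Z',
        'target': 'the red armchair next to the window in the living room',
    }

    # REVERIENavBatch.reset() now also collects targets, one per episode ...
    targets = [batch_item['target']]

    # ... Simulator.newEpisode()/getState() carry the target into the per-step state ...
    state = {'start': batch_item['start'], 'stop': batch_item['stop'], 'target': batch_item['target']}

    # ... and the agent uses the target (instead of the instruction) as the LLM action plan.
    action_plan = state['target']
    print(action_plan)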

View File

@@ -197,6 +197,7 @@ class NavGPTAgent(BaseAgent):
         self.output_parser = NavGPTOutputParser()
         self.agent_executor = self.create_vln_agent()
+        print("AGENT_EXECUTOR: ", type(self.agent_executor))

         plan_prompt = PromptTemplate(
             template=PLANNER_PROMPT,
@@ -331,11 +332,11 @@ class NavGPTAgent(BaseAgent):
         # Record the navigation path
         self.traj = [{
             'instr_id': ob['instr_id'],
-            'path': [[ob['viewpoint']]],
+            'path': [[ob['start']]],
             'details': [],
         } for ob in obs]
         # Record the history of actions taken
-        self.agent_executor.agent.history = [f'Navigation start, no actions taken yet.\nCurrent viewpoint "{obs[0]["viewpoint"]}": Scene from the viewpoint is a {obs[0]["obs_summary"]}']
+        self.agent_executor.agent.history = [f'Navigation start, no actions taken yet.\nCurrent viewpoint "{obs[0]["start"]}": Scene from the viewpoint is a {obs[0]["obs_summary"]}']

     def _create_make_action_tool(
         self,
@ -585,7 +586,9 @@ class NavGPTAgent(BaseAgent):
}, },
) )
elif self.config.use_single_action: elif self.config.use_single_action:
# We will be here
tools = [self.action_maker] tools = [self.action_maker]
print(tools)
prompt = PromptTemplate( prompt = PromptTemplate(
template=VLN_GPT4_PROMPT if self.config.llm_model_name == 'gpt-4' else VLN_GPT35_PROMPT, template=VLN_GPT4_PROMPT if self.config.llm_model_name == 'gpt-4' else VLN_GPT35_PROMPT,
input_variables=["action_plan", "init_observation", "agent_scratchpad"], input_variables=["action_plan", "init_observation", "agent_scratchpad"],
@@ -658,14 +661,43 @@ class NavGPTAgent(BaseAgent):
             obs = self.env.reset()
         else:
             obs = self.env._get_obs()
+
+        print(len(obs))
+        print(obs[0].keys())
+        print(obs[0]['obs'])
+        print(obs[0]['obs_summary'])
+        print(obs[0]['objects'])
+        print(obs[0]['instr_id'])
+        print(obs[0]['scan'])
+        print(obs[0]['viewpoint'])
+        print(obs[0]['heading'])
+        print(obs[0]['elevation'])
+        print(obs[0]['candidate'])
+        print(obs[0]['instruction'])
+        print(obs[0]['gt_path'])
+        print(obs[0]['path_id'])
+        print(obs[0]['stop'])
+        print(obs[0]['start'])
+        print(obs[0]['target'])
+        print("==")
+
         # Initialize the trajectory
         self.init_trajecotry(obs)

         # Load the instruction
-        instructions = [ob['instruction'] for ob in obs]
+        # instructions = [ob['instruction'] for ob in obs]
+        targets = [ob['target'] for ob in obs]
+
+        print(self.config.load_instruction)
+        print(self.config.load_action_plan)
+
         if self.config.load_instruction:
-            action_plans = instructions
+            # action_plans = instructions
+            action_plans = targets
         elif self.config.load_action_plan:
             action_plans = [ob['action_plan'] for ob in obs]
         else:
@@ -673,11 +705,18 @@ class NavGPTAgent(BaseAgent):
             for instruction in instructions:
                 action_plan = self.plan_chain.run(instruction = instruction)
                 action_plans.append(action_plan)
+        print(action_plans)

         for i, init_ob in enumerate(obs):
+            # for our work:
+            # cur_action_plan is the target object together with its location
             self.cur_action_plan = action_plans[i]
+            print("use_tool_chain:", self.config.use_tool_chain)
+
             # Take the first action
-            if self.config.use_tool_chain:
+            if self.config.use_tool_chain:  # this branch is not taken in our setup
                 first_obs = self.action_maker('')
                 input = {
                     'action_plan': self.cur_action_plan,
@ -686,15 +725,20 @@ class NavGPTAgent(BaseAgent):
} }
else: else:
# Get current feature # Get current feature
# we are HERE
feature = init_ob['obs'] feature = init_ob['obs']
navigable = init_ob['candidate'] navigable = init_ob['candidate']
objects = init_ob['objects'] objects = init_ob['objects']
heading = np.rad2deg(init_ob['heading']) heading = np.rad2deg(init_ob['heading'])
elevation = np.rad2deg(init_ob['elevation']) elevation = np.rad2deg(init_ob['elevation'])
orientation = f'\nheading: {heading:.2f}, elevation: {elevation:.2f}' orientation = f'\nheading: {heading:.2f}, elevation: {elevation:.2f}'
if self.config.use_relative_angle:
print("use_relative_angle:", self.config.use_relative_angle)
print("use_relative_angle:", self.config.use_navigable)
if self.config.use_relative_angle: # True
feature = self.modify_heading_angles(heading, feature, navigable, objects) feature = self.modify_heading_angles(heading, feature, navigable, objects)
if self.config.use_navigable: if self.config.use_navigable: # False
navigable = self.get_navigable_str(heading, elevation, navigable) navigable = self.get_navigable_str(heading, elevation, navigable)
if self.config.use_relative_angle: if self.config.use_relative_angle:
@@ -708,9 +752,10 @@ class NavGPTAgent(BaseAgent):
                 else:
                     init_observation = f"\n\tCurrent Orientation:\n{orientation}\n\tCurrent Viewpoint:\n{feature}"
                 input = {
-                    'action_plan': self.cur_action_plan,
-                    'init_observation': init_observation,
+                    'action_plan': self.cur_action_plan,    # in our work: the target object and its location
+                    'init_observation': init_observation,   # captions for the 8 viewing directions, plus navigable points and objects
                 }
             output = self.agent_executor(input)
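
For reference, with config.load_instruction enabled the executor input assembled above now pairs the target with the initial observation. A hedged, self-contained example of its shape (both values are invented; only the two keys come from the code):

    # Hypothetical executor input after this change (keys from the code, values invented).
    example_input = {
        # previously the raw REVERIE instruction; now the target object with its location
        'action_plan': 'the black leather sofa in the lounge on level 1',
        # observation captions for the 8 viewing directions, plus navigable points and objects
        'init_observation': '\n\tCurrent Orientation:\n\nheading: 0.00, elevation: 0.00'
                            '\n\tCurrent Viewpoint:\n...',
    }
    print(example_input['action_plan'])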

View File

@@ -43,9 +43,9 @@ class BaseAgent(object):
                     self.results[traj['instr_id']] = traj
                 preds_detail = self.get_results(detailed_output=True)
                 json.dump(
                     preds_detail,
                     open(os.path.join(self.config.log_dir, 'runtime.json'), 'w'),
                     sort_keys=True, indent=4, separators=(',', ': ')
                 )
         else:   # Do a full round
             while True:

View File

@@ -35,13 +35,16 @@ class Simulator(object):
                    heading: int,
                    elevation: int,
                    stop: str,
-                   start: str):
+                   start: str,
+                   target: str
+                   ):
         self.heading = heading
         self.elevation = elevation
         self.scan_ID = scan_ID
         self.viewpoint_ID = viewpoint_ID
         self.stop = stop
         self.start = start
+        self.target = target
         # Load navigable dict
         navigable_path = os.path.join(self.navigable_dir, self.scan_ID + '_navigable.json')
         with open(navigable_path, 'r') as f:
@@ -63,6 +66,7 @@ class Simulator(object):
             'candidate': self.candidate,
             'stop': self.stop,
             'start': self.start,
+            'target': self.target
         }
         return self.state
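
With this addition, the dictionary returned by getState() carries the target alongside the start and stop viewpoints. A hedged sketch of its shape (field values invented; fields not shown in the diff are elided):

    # Hypothetical Simulator.getState() result after this change (values invented).
    example_state = {
        'candidate': [],                                   # navigable viewpoints for the current node
        'stop': 'viewpoint_Z',
        'start': 'viewpoint_A',
        'target': 'the red armchair next to the window',   # newly added field
        # remaining fields (scan, viewpoint, heading, elevation, ...) are unchanged
    }
    print(example_state['target'])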
@@ -107,9 +111,9 @@ class EnvBatch(object):
     def _make_id(self, scanId, viewpointId):
         return scanId + '_' + viewpointId

-    def newEpisodes(self, scanIds, viewpointIds, headings, stops, starts):
-        for i, (scanId, viewpointId, heading, stop, start) in enumerate(zip(scanIds, viewpointIds, headings, stops, starts)):
-            self.sims[i].newEpisode(scanId, viewpointId, heading, 0, stop, start)
+    def newEpisodes(self, scanIds, viewpointIds, headings, stops, starts, targets):
+        for i, (scanId, viewpointId, heading, stop, start, target) in enumerate(zip(scanIds, viewpointIds, headings, stops, starts, targets)):
+            self.sims[i].newEpisode(scanId, viewpointId, heading, 0, stop, start, target)

     def getStates(self):
         """
@@ -233,7 +237,8 @@ class REVERIENavBatch(object):
                 'gt_path' : item['path'],
                 'path_id' : item['path_id'],
                 'stop': item['stop'],
-                'start': item['start']
+                'start': item['start'],
+                'target': item['target']
             }
             # RL reward. The negative distance between the state and the final state
             # There are multiple gt end viewpoints on REVERIE.
@@ -257,7 +262,8 @@ class REVERIENavBatch(object):
         headings = [item['heading'] for item in self.batch]
         stops = [item['stop'] for item in self.batch]
         starts = [item['start'] for item in self.batch]
-        self.env.newEpisodes(scanIds, viewpointIds, headings, stops, starts)
+        targets = [item['target'] for item in self.batch]
+        self.env.newEpisodes(scanIds, viewpointIds, headings, stops, starts, targets)
         return self._get_obs()

     def step(self, next_viewpoint_IDs):
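
Callers of newEpisodes() must now supply one target per episode, zipped alongside the existing per-episode lists. A self-contained usage sketch with a stub batch class (the stub and all values are invented; only the argument order mirrors the diff):

    # Stub illustrating the updated newEpisodes() calling convention (invented class and values).
    class _StubEnvBatch:
        def __init__(self, n):
            self.episodes = [None] * n

        def newEpisodes(self, scanIds, viewpointIds, headings, stops, starts, targets):
            # one target string per episode, zipped with the existing per-episode lists
            for i, episode in enumerate(zip(scanIds, viewpointIds, headings, stops, starts, targets)):
                self.episodes[i] = episode

    env = _StubEnvBatch(1)
    env.newEpisodes(['scan_0001'], ['viewpoint_A'], [0.0],
                    ['viewpoint_Z'], ['viewpoint_A'],
                    ['the red armchair next to the window'])
    print(env.episodes[0][-1])   # the target now travels with the episode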