feat: add target in state & let LLM use target to find
This commit is contained in:
parent
0546814202
commit
fc0ccb1458
@ -197,6 +197,7 @@ class NavGPTAgent(BaseAgent):
|
|||||||
|
|
||||||
self.output_parser = NavGPTOutputParser()
|
self.output_parser = NavGPTOutputParser()
|
||||||
self.agent_executor = self.create_vln_agent()
|
self.agent_executor = self.create_vln_agent()
|
||||||
|
print("AGENT_EXECUTOR: ", type(self.agent_executor))
|
||||||
|
|
||||||
plan_prompt = PromptTemplate(
|
plan_prompt = PromptTemplate(
|
||||||
template=PLANNER_PROMPT,
|
template=PLANNER_PROMPT,
|
||||||
@ -331,11 +332,11 @@ class NavGPTAgent(BaseAgent):
|
|||||||
# Record the navigation path
|
# Record the navigation path
|
||||||
self.traj = [{
|
self.traj = [{
|
||||||
'instr_id': ob['instr_id'],
|
'instr_id': ob['instr_id'],
|
||||||
'path': [[ob['viewpoint']]],
|
'path': [[ob['start']]],
|
||||||
'details': [],
|
'details': [],
|
||||||
} for ob in obs]
|
} for ob in obs]
|
||||||
# Record the history of actions taken
|
# Record the history of actions taken
|
||||||
self.agent_executor.agent.history = [f'Navigation start, no actions taken yet.\nCurrent viewpoint "{obs[0]["viewpoint"]}": Scene from the viewpoint is a {obs[0]["obs_summary"]}']
|
self.agent_executor.agent.history = [f'Navigation start, no actions taken yet.\nCurrent viewpoint "{obs[0]["start"]}": Scene from the viewpoint is a {obs[0]["obs_summary"]}']
|
||||||
|
|
||||||
def _create_make_action_tool(
|
def _create_make_action_tool(
|
||||||
self,
|
self,
|
||||||
@ -585,7 +586,9 @@ class NavGPTAgent(BaseAgent):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
elif self.config.use_single_action:
|
elif self.config.use_single_action:
|
||||||
|
# We will be here
|
||||||
tools = [self.action_maker]
|
tools = [self.action_maker]
|
||||||
|
print(tools)
|
||||||
prompt = PromptTemplate(
|
prompt = PromptTemplate(
|
||||||
template=VLN_GPT4_PROMPT if self.config.llm_model_name == 'gpt-4' else VLN_GPT35_PROMPT,
|
template=VLN_GPT4_PROMPT if self.config.llm_model_name == 'gpt-4' else VLN_GPT35_PROMPT,
|
||||||
input_variables=["action_plan", "init_observation", "agent_scratchpad"],
|
input_variables=["action_plan", "init_observation", "agent_scratchpad"],
|
||||||
@ -658,14 +661,43 @@ class NavGPTAgent(BaseAgent):
|
|||||||
obs = self.env.reset()
|
obs = self.env.reset()
|
||||||
else:
|
else:
|
||||||
obs = self.env._get_obs()
|
obs = self.env._get_obs()
|
||||||
|
|
||||||
|
print(len(obs))
|
||||||
|
|
||||||
|
print(obs[0].keys())
|
||||||
|
print(obs[0]['obs'])
|
||||||
|
print(obs[0]['obs_summary'])
|
||||||
|
print(obs[0]['objects'])
|
||||||
|
print(obs[0]['instr_id'])
|
||||||
|
print(obs[0]['scan'])
|
||||||
|
print(obs[0]['viewpoint'])
|
||||||
|
print(obs[0]['heading'])
|
||||||
|
print(obs[0]['elevation'])
|
||||||
|
print(obs[0]['candidate'])
|
||||||
|
print(obs[0]['instruction'])
|
||||||
|
print(obs[0]['gt_path'])
|
||||||
|
print(obs[0]['path_id'])
|
||||||
|
print(obs[0]['stop'])
|
||||||
|
print(obs[0]['start'])
|
||||||
|
print(obs[0]['target'])
|
||||||
|
|
||||||
|
|
||||||
|
print("==")
|
||||||
|
|
||||||
# Initialize the trajectory
|
# Initialize the trajectory
|
||||||
self.init_trajecotry(obs)
|
self.init_trajecotry(obs)
|
||||||
|
|
||||||
# Load the instruction
|
# Load the instruction
|
||||||
instructions = [ob['instruction'] for ob in obs]
|
# instructions = [ob['instruction'] for ob in obs]
|
||||||
|
targets = [ob['target'] for ob in obs]
|
||||||
|
|
||||||
|
|
||||||
|
print(self.config.load_instruction)
|
||||||
|
print(self.config.load_action_plan)
|
||||||
|
|
||||||
if self.config.load_instruction:
|
if self.config.load_instruction:
|
||||||
action_plans = instructions
|
# action_plans = instructions
|
||||||
|
action_plans = targets
|
||||||
elif self.config.load_action_plan:
|
elif self.config.load_action_plan:
|
||||||
action_plans = [ob['action_plan'] for ob in obs]
|
action_plans = [ob['action_plan'] for ob in obs]
|
||||||
else:
|
else:
|
||||||
@ -673,11 +705,18 @@ class NavGPTAgent(BaseAgent):
|
|||||||
for instruction in instructions:
|
for instruction in instructions:
|
||||||
action_plan = self.plan_chain.run(instruction = instruction)
|
action_plan = self.plan_chain.run(instruction = instruction)
|
||||||
action_plans.append(action_plan)
|
action_plans.append(action_plan)
|
||||||
|
print(action_plans)
|
||||||
|
|
||||||
for i, init_ob in enumerate(obs):
|
for i, init_ob in enumerate(obs):
|
||||||
|
|
||||||
|
# for our work
|
||||||
|
# cur_action_plan is "target object with its location"
|
||||||
self.cur_action_plan = action_plans[i]
|
self.cur_action_plan = action_plans[i]
|
||||||
|
|
||||||
|
print("use_tool_chain:", self.config.use_tool_chain)
|
||||||
|
|
||||||
# Take the first action
|
# Take the first action
|
||||||
if self.config.use_tool_chain:
|
if self.config.use_tool_chain: # we will not HERE
|
||||||
first_obs = self.action_maker('')
|
first_obs = self.action_maker('')
|
||||||
input = {
|
input = {
|
||||||
'action_plan': self.cur_action_plan,
|
'action_plan': self.cur_action_plan,
|
||||||
@ -686,15 +725,20 @@ class NavGPTAgent(BaseAgent):
|
|||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
# Get current feature
|
# Get current feature
|
||||||
|
|
||||||
|
# we are HERE
|
||||||
feature = init_ob['obs']
|
feature = init_ob['obs']
|
||||||
navigable = init_ob['candidate']
|
navigable = init_ob['candidate']
|
||||||
objects = init_ob['objects']
|
objects = init_ob['objects']
|
||||||
heading = np.rad2deg(init_ob['heading'])
|
heading = np.rad2deg(init_ob['heading'])
|
||||||
elevation = np.rad2deg(init_ob['elevation'])
|
elevation = np.rad2deg(init_ob['elevation'])
|
||||||
orientation = f'\nheading: {heading:.2f}, elevation: {elevation:.2f}'
|
orientation = f'\nheading: {heading:.2f}, elevation: {elevation:.2f}'
|
||||||
if self.config.use_relative_angle:
|
|
||||||
|
print("use_relative_angle:", self.config.use_relative_angle)
|
||||||
|
print("use_relative_angle:", self.config.use_navigable)
|
||||||
|
if self.config.use_relative_angle: # True
|
||||||
feature = self.modify_heading_angles(heading, feature, navigable, objects)
|
feature = self.modify_heading_angles(heading, feature, navigable, objects)
|
||||||
if self.config.use_navigable:
|
if self.config.use_navigable: # False
|
||||||
navigable = self.get_navigable_str(heading, elevation, navigable)
|
navigable = self.get_navigable_str(heading, elevation, navigable)
|
||||||
|
|
||||||
if self.config.use_relative_angle:
|
if self.config.use_relative_angle:
|
||||||
@ -708,9 +752,10 @@ class NavGPTAgent(BaseAgent):
|
|||||||
else:
|
else:
|
||||||
init_observation = f"\n\tCurrent Orientation:\n{orientation}\n\tCurrent Viewpoint:\n{feature}"
|
init_observation = f"\n\tCurrent Orientation:\n{orientation}\n\tCurrent Viewpoint:\n{feature}"
|
||||||
|
|
||||||
|
|
||||||
input = {
|
input = {
|
||||||
'action_plan': self.cur_action_plan,
|
'action_plan': self.cur_action_plan, # here will be "object & its location" in our work
|
||||||
'init_observation': init_observation,
|
'init_observation': init_observation, # 8 direction's observation caption & navigable point & objects
|
||||||
}
|
}
|
||||||
output = self.agent_executor(input)
|
output = self.agent_executor(input)
|
||||||
|
|
||||||
|
|||||||
@ -43,9 +43,9 @@ class BaseAgent(object):
|
|||||||
self.results[traj['instr_id']] = traj
|
self.results[traj['instr_id']] = traj
|
||||||
preds_detail = self.get_results(detailed_output=True)
|
preds_detail = self.get_results(detailed_output=True)
|
||||||
json.dump(
|
json.dump(
|
||||||
preds_detail,
|
preds_detail,
|
||||||
open(os.path.join(self.config.log_dir, 'runtime.json'), 'w'),
|
open(os.path.join(self.config.log_dir, 'runtime.json'), 'w'),
|
||||||
sort_keys=True, indent=4, separators=(',', ': ')
|
sort_keys=True, indent=4, separators=(',', ': ')
|
||||||
)
|
)
|
||||||
else: # Do a full round
|
else: # Do a full round
|
||||||
while True:
|
while True:
|
||||||
|
|||||||
@ -35,13 +35,16 @@ class Simulator(object):
|
|||||||
heading: int,
|
heading: int,
|
||||||
elevation: int,
|
elevation: int,
|
||||||
stop: str,
|
stop: str,
|
||||||
start: str):
|
start: str,
|
||||||
|
target: str
|
||||||
|
):
|
||||||
self.heading = heading
|
self.heading = heading
|
||||||
self.elevation = elevation
|
self.elevation = elevation
|
||||||
self.scan_ID = scan_ID
|
self.scan_ID = scan_ID
|
||||||
self.viewpoint_ID = viewpoint_ID
|
self.viewpoint_ID = viewpoint_ID
|
||||||
self.stop = stop
|
self.stop = stop
|
||||||
self.start = start
|
self.start = start
|
||||||
|
self.target = target
|
||||||
# Load navigable dict
|
# Load navigable dict
|
||||||
navigable_path = os.path.join(self.navigable_dir, self.scan_ID + '_navigable.json')
|
navigable_path = os.path.join(self.navigable_dir, self.scan_ID + '_navigable.json')
|
||||||
with open(navigable_path, 'r') as f:
|
with open(navigable_path, 'r') as f:
|
||||||
@ -63,6 +66,7 @@ class Simulator(object):
|
|||||||
'candidate': self.candidate,
|
'candidate': self.candidate,
|
||||||
'stop': self.stop,
|
'stop': self.stop,
|
||||||
'start': self.start,
|
'start': self.start,
|
||||||
|
'target': self.target
|
||||||
}
|
}
|
||||||
return self.state
|
return self.state
|
||||||
|
|
||||||
@ -107,9 +111,9 @@ class EnvBatch(object):
|
|||||||
def _make_id(self, scanId, viewpointId):
|
def _make_id(self, scanId, viewpointId):
|
||||||
return scanId + '_' + viewpointId
|
return scanId + '_' + viewpointId
|
||||||
|
|
||||||
def newEpisodes(self, scanIds, viewpointIds, headings, stops, starts):
|
def newEpisodes(self, scanIds, viewpointIds, headings, stops, starts, targets):
|
||||||
for i, (scanId, viewpointId, heading, stop, start) in enumerate(zip(scanIds, viewpointIds, headings, stops, starts)):
|
for i, (scanId, viewpointId, heading, stop, start, target) in enumerate(zip(scanIds, viewpointIds, headings, stops, starts, targets)):
|
||||||
self.sims[i].newEpisode(scanId, viewpointId, heading, 0, stop, start)
|
self.sims[i].newEpisode(scanId, viewpointId, heading, 0, stop, start, target)
|
||||||
|
|
||||||
def getStates(self):
|
def getStates(self):
|
||||||
"""
|
"""
|
||||||
@ -233,7 +237,8 @@ class REVERIENavBatch(object):
|
|||||||
'gt_path' : item['path'],
|
'gt_path' : item['path'],
|
||||||
'path_id' : item['path_id'],
|
'path_id' : item['path_id'],
|
||||||
'stop': item['stop'],
|
'stop': item['stop'],
|
||||||
'start': item['start']
|
'start': item['start'],
|
||||||
|
'target': item['target']
|
||||||
}
|
}
|
||||||
# RL reward. The negative distance between the state and the final state
|
# RL reward. The negative distance between the state and the final state
|
||||||
# There are multiple gt end viewpoints on REVERIE.
|
# There are multiple gt end viewpoints on REVERIE.
|
||||||
@ -257,7 +262,8 @@ class REVERIENavBatch(object):
|
|||||||
headings = [item['heading'] for item in self.batch]
|
headings = [item['heading'] for item in self.batch]
|
||||||
stops = [item['stop'] for item in self.batch]
|
stops = [item['stop'] for item in self.batch]
|
||||||
starts = [item['start'] for item in self.batch]
|
starts = [item['start'] for item in self.batch]
|
||||||
self.env.newEpisodes(scanIds, viewpointIds, headings, stops, starts)
|
targets = [item['target'] for item in self.batch]
|
||||||
|
self.env.newEpisodes(scanIds, viewpointIds, headings, stops, starts, targets)
|
||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
|
||||||
def step(self, next_viewpoint_IDs):
|
def step(self, next_viewpoint_IDs):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user