Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 32ceca7752 | |||
| 5848e22b1e |
@ -8,7 +8,7 @@ from utils.logger import write_to_record_file
|
|||||||
from utils.data import ImageObservationsDB
|
from utils.data import ImageObservationsDB
|
||||||
from parser import parse_args
|
from parser import parse_args
|
||||||
from env import REVERIENavBatch
|
from env import REVERIENavBatch
|
||||||
from agent import NavGPTAgent
|
from agent import NavGPTAgent, RandomAgent
|
||||||
|
|
||||||
def build_dataset(args, data_limit=100):
|
def build_dataset(args, data_limit=100):
|
||||||
|
|
||||||
@ -35,7 +35,7 @@ def build_dataset(args, data_limit=100):
|
|||||||
|
|
||||||
def valid(args, val_envs):
|
def valid(args, val_envs):
|
||||||
|
|
||||||
agent = NavGPTAgent(next(iter(val_envs.values())), args)
|
agent = RandomAgent(next(iter(val_envs.values())), args)
|
||||||
|
|
||||||
with open(os.path.join(args.log_dir, 'validation_args.json'), 'w') as outf:
|
with open(os.path.join(args.log_dir, 'validation_args.json'), 'w') as outf:
|
||||||
json.dump(vars(args), outf, indent=4)
|
json.dump(vars(args), outf, indent=4)
|
||||||
|
|||||||
111
nav_src/agent.py
111
nav_src/agent.py
@ -5,6 +5,7 @@ import re
|
|||||||
import warnings
|
import warnings
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Any, Callable, List, NamedTuple, Optional, Sequence, Tuple, Dict, Union
|
from typing import Any, Callable, List, NamedTuple, Optional, Sequence, Tuple, Dict, Union
|
||||||
|
import random
|
||||||
|
|
||||||
from env import REVERIENavBatch
|
from env import REVERIENavBatch
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
@ -817,3 +818,113 @@ class NavGPTAgent(BaseAgent):
|
|||||||
|
|
||||||
return self.traj
|
return self.traj
|
||||||
|
|
||||||
|
class RandomAgent(BaseAgent):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
env: REVERIENavBatch,
|
||||||
|
config: Namespace):
|
||||||
|
"""
|
||||||
|
Initialize the LLM Navigation Agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: The Matterport3D environment.
|
||||||
|
config: The configuration.
|
||||||
|
"""
|
||||||
|
super().__init__(env)
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
|
||||||
|
def init_trajecotry(self, obs: List[dict]):
|
||||||
|
"""Initialize the trajectory with the given observation."""
|
||||||
|
# Record the navigation path
|
||||||
|
self.traj = [{
|
||||||
|
'instr_id': ob['instr_id'],
|
||||||
|
'path': [[ob['start']]],
|
||||||
|
'details': [],
|
||||||
|
} for ob in obs]
|
||||||
|
# Record the history of actions taken
|
||||||
|
|
||||||
|
|
||||||
|
def make_equiv_action(self, actions: List[str]) -> str:
|
||||||
|
"""
|
||||||
|
Interface between Panoramic view and Egocentric view
|
||||||
|
Take in the next viewpoint ID and move the agent to that viewpoint
|
||||||
|
return the turned angle and new observation
|
||||||
|
"""
|
||||||
|
def normalize_angle(angle):
|
||||||
|
while angle > 180:
|
||||||
|
angle -= 360
|
||||||
|
while angle <= -180:
|
||||||
|
angle += 360
|
||||||
|
return angle
|
||||||
|
|
||||||
|
def angle_to_left_right(angle):
|
||||||
|
return f"left {-angle:.2f}" if angle < 0 else f"right {angle:.2f}"
|
||||||
|
|
||||||
|
# Get current agent facing angle
|
||||||
|
cur_obs = self.env._get_obs()[0]
|
||||||
|
cur_heading = np.rad2deg(cur_obs['heading'])
|
||||||
|
# Make the action
|
||||||
|
new_obs = self.env.step(actions)[0]
|
||||||
|
new_heading = np.rad2deg(new_obs['heading'])
|
||||||
|
# Record the trajectory
|
||||||
|
self.traj[0]['path'].append(self.env.env.sims[0].gmap.bfs_shortest_path(cur_obs['viewpoint'], actions[0])[1:])
|
||||||
|
# Calculate the turned angle
|
||||||
|
turned_angle = new_heading - cur_heading
|
||||||
|
# Generate action description
|
||||||
|
cur_heading = angle_to_left_right(normalize_angle(cur_heading))
|
||||||
|
new_heading = angle_to_left_right(normalize_angle(new_heading))
|
||||||
|
action_description = f'Turn heading direction {turned_angle:.2f} degrees from {cur_heading} to {new_heading}.'
|
||||||
|
return action_description, new_obs
|
||||||
|
|
||||||
|
def rollout(self, reset=True):
|
||||||
|
if reset: # Reset env
|
||||||
|
obs = self.env.reset()
|
||||||
|
else:
|
||||||
|
obs = self.env._get_obs()
|
||||||
|
|
||||||
|
global FINAL_STOP_POINT
|
||||||
|
global TEMP_STEPS_COUNTER
|
||||||
|
global STEPS_COUNTER
|
||||||
|
global SUCCESS
|
||||||
|
|
||||||
|
FINAL_STOP_POINT = obs[0]['stop']
|
||||||
|
|
||||||
|
if TEMP_STEPS_COUNTER != 0:
|
||||||
|
TEMP_STEPS_COUNTER = 0
|
||||||
|
|
||||||
|
print("=="*20)
|
||||||
|
|
||||||
|
# Initialize the trajectory
|
||||||
|
self.init_trajecotry(obs)
|
||||||
|
|
||||||
|
for iteration in range(self.config.max_iterations):
|
||||||
|
next_point = None
|
||||||
|
print(obs[0].keys())
|
||||||
|
print(obs[0]['viewpoint'])
|
||||||
|
for i, init_ob in enumerate(obs):
|
||||||
|
navigable = [ k for k, v in init_ob['candidate'].items() ]
|
||||||
|
next_point = random.choice(navigable)
|
||||||
|
print(next_point)
|
||||||
|
turned_angle, obs = self.make_equiv_action([next_point])
|
||||||
|
obs = [obs]
|
||||||
|
|
||||||
|
print(f"TEMP_STEPS_COUNTER={TEMP_STEPS_COUNTER}")
|
||||||
|
print(f"STEPS_COUNTER={STEPS_COUNTER}")
|
||||||
|
TEMP_STEPS_COUNTER += 1
|
||||||
|
|
||||||
|
if next_point == FINAL_STOP_POINT:
|
||||||
|
print(" SUCCESS")
|
||||||
|
STEPS_COUNTER += TEMP_STEPS_COUNTER
|
||||||
|
SUCCESS += 1
|
||||||
|
TEMP_STEPS_COUNTER = 0
|
||||||
|
break
|
||||||
|
|
||||||
|
print(f"FINAL_STOP_POINT={FINAL_STOP_POINT}")
|
||||||
|
print(f"SUCCESS={SUCCESS}")
|
||||||
|
print(f"TEMP_STEPS_COUNTER={TEMP_STEPS_COUNTER}")
|
||||||
|
print(f"STEPS_COUNTER={STEPS_COUNTER}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
return self.traj
|
||||||
|
|||||||
@ -14,10 +14,6 @@ class BaseAgent(object):
|
|||||||
output.append({'instr_id': k, 'trajectory': v['path']})
|
output.append({'instr_id': k, 'trajectory': v['path']})
|
||||||
if detailed_output:
|
if detailed_output:
|
||||||
output[-1]['details'] = v['details']
|
output[-1]['details'] = v['details']
|
||||||
output[-1]['action_plan'] = v['action_plan']
|
|
||||||
output[-1]['llm_output'] = v['llm_output']
|
|
||||||
output[-1]['llm_thought'] = v['llm_thought']
|
|
||||||
output[-1]['llm_observation'] = v['llm_observation']
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def rollout(self, **args):
|
def rollout(self, **args):
|
||||||
|
|||||||
@ -136,7 +136,6 @@ class Simulator(object):
|
|||||||
|
|
||||||
self.node_region, self.region_room, self.region_obj, self.node_locations = load_floorplan()
|
self.node_region, self.region_room, self.region_obj, self.node_locations = load_floorplan()
|
||||||
|
|
||||||
|
|
||||||
def newEpisode(
|
def newEpisode(
|
||||||
self,
|
self,
|
||||||
scan_ID: str,
|
scan_ID: str,
|
||||||
@ -171,6 +170,7 @@ class Simulator(object):
|
|||||||
print(start_region, to_region)
|
print(start_region, to_region)
|
||||||
print("AFTER: ", len(self.navigable_dict[start]))
|
print("AFTER: ", len(self.navigable_dict[start]))
|
||||||
|
|
||||||
|
|
||||||
# Get candidate
|
# Get candidate
|
||||||
self.getCandidate()
|
self.getCandidate()
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user