From 5848e22b1ef23fdab9da6a543942319038de88c5 Mon Sep 17 00:00:00 2001 From: Ting-Jun Wang Date: Mon, 29 Apr 2024 02:19:06 +0800 Subject: [PATCH] feat: random explore --- nav_src/NavGPT.py | 4 +-- nav_src/agent.py | 57 +++++++++++++++++++++++-------------------- nav_src/agent_base.py | 4 --- 3 files changed, 33 insertions(+), 32 deletions(-) diff --git a/nav_src/NavGPT.py b/nav_src/NavGPT.py index 2570df2..cf86ccb 100644 --- a/nav_src/NavGPT.py +++ b/nav_src/NavGPT.py @@ -8,7 +8,7 @@ from utils.logger import write_to_record_file from utils.data import ImageObservationsDB from parser import parse_args from env import REVERIENavBatch -from agent import NavGPTAgent +from agent import NavGPTAgent, RandomAgent def build_dataset(args, data_limit=100): @@ -35,7 +35,7 @@ def build_dataset(args, data_limit=100): def valid(args, val_envs): - agent = NavGPTAgent(next(iter(val_envs.values())), args) + agent = RandomAgent(next(iter(val_envs.values())), args) with open(os.path.join(args.log_dir, 'validation_args.json'), 'w') as outf: json.dump(vars(args), outf, indent=4) diff --git a/nav_src/agent.py b/nav_src/agent.py index 413b43d..2558e60 100644 --- a/nav_src/agent.py +++ b/nav_src/agent.py @@ -5,6 +5,7 @@ import re import warnings import numpy as np from typing import Any, Callable, List, NamedTuple, Optional, Sequence, Tuple, Dict, Union +import random from env import REVERIENavBatch from argparse import Namespace @@ -884,42 +885,46 @@ class RandomAgent(BaseAgent): global FINAL_STOP_POINT global TEMP_STEPS_COUNTER + global STEPS_COUNTER + global SUCCESS FINAL_STOP_POINT = obs[0]['stop'] if TEMP_STEPS_COUNTER != 0: TEMP_STEPS_COUNTER = 0 - print(obs[0].keys()) - print(obs[0]['obs']) - print(obs[0]['obs_summary']) - print(obs[0]['objects']) - print(obs[0]['instr_id']) - print(obs[0]['scan']) - print(obs[0]['viewpoint']) - print(obs[0]['heading']) - print(obs[0]['elevation']) - print(obs[0]['candidate']) - print(obs[0]['instruction']) - print(obs[0]['gt_path']) - print(obs[0]['path_id']) - print(obs[0]['stop']) - print(obs[0]['start']) - print(obs[0]['target']) - - - print("==") + print("=="*20) # Initialize the trajectory self.init_trajecotry(obs) - for i, init_ob in enumerate(obs): - - navigable = init_ob['candidate'] - heading = np.rad2deg(init_ob['heading']) - elevation = np.rad2deg(init_ob['elevation']) - orientation = f'\nheading: {heading:.2f}, elevation: {elevation:.2f}' + for iteration in range(self.config.max_iterations): + next_point = None + print(obs[0].keys()) + print(obs[0]['viewpoint']) + for i, init_ob in enumerate(obs): + navigable = [ k for k, v in init_ob['candidate'].items() ] + next_point = random.choice(navigable) + print(next_point) + turned_angle, obs = self.make_equiv_action([next_point]) + obs = [obs] + + print(f"TEMP_STEPS_COUNTER={TEMP_STEPS_COUNTER}") + print(f"STEPS_COUNTER={STEPS_COUNTER}") + TEMP_STEPS_COUNTER += 1 + + if next_point == FINAL_STOP_POINT: + print(" SUCCESS") + STEPS_COUNTER += TEMP_STEPS_COUNTER + SUCCESS += 1 + TEMP_STEPS_COUNTER = 0 + break + + print(f"FINAL_STOP_POINT={FINAL_STOP_POINT}") + print(f"SUCCESS={SUCCESS}") + print(f"TEMP_STEPS_COUNTER={TEMP_STEPS_COUNTER}") + print(f"STEPS_COUNTER={STEPS_COUNTER}") + - output = self.agent_executor(input) return self.traj diff --git a/nav_src/agent_base.py b/nav_src/agent_base.py index c8aab3e..f0107bb 100644 --- a/nav_src/agent_base.py +++ b/nav_src/agent_base.py @@ -14,10 +14,6 @@ class BaseAgent(object): output.append({'instr_id': k, 'trajectory': v['path']}) if detailed_output: output[-1]['details'] = v['details'] - output[-1]['action_plan'] = v['action_plan'] - output[-1]['llm_output'] = v['llm_output'] - output[-1]['llm_thought'] = v['llm_thought'] - output[-1]['llm_observation'] = v['llm_observation'] return output def rollout(self, **args):