feat: random explore

This commit is contained in:
Ting-Jun Wang 2024-04-29 02:19:06 +08:00
parent 89081b6b21
commit 5848e22b1e
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354
3 changed files with 33 additions and 32 deletions

View File

@ -8,7 +8,7 @@ from utils.logger import write_to_record_file
from utils.data import ImageObservationsDB from utils.data import ImageObservationsDB
from parser import parse_args from parser import parse_args
from env import REVERIENavBatch from env import REVERIENavBatch
from agent import NavGPTAgent from agent import NavGPTAgent, RandomAgent
def build_dataset(args, data_limit=100): def build_dataset(args, data_limit=100):
@ -35,7 +35,7 @@ def build_dataset(args, data_limit=100):
def valid(args, val_envs): def valid(args, val_envs):
agent = NavGPTAgent(next(iter(val_envs.values())), args) agent = RandomAgent(next(iter(val_envs.values())), args)
with open(os.path.join(args.log_dir, 'validation_args.json'), 'w') as outf: with open(os.path.join(args.log_dir, 'validation_args.json'), 'w') as outf:
json.dump(vars(args), outf, indent=4) json.dump(vars(args), outf, indent=4)

View File

@ -5,6 +5,7 @@ import re
import warnings import warnings
import numpy as np import numpy as np
from typing import Any, Callable, List, NamedTuple, Optional, Sequence, Tuple, Dict, Union from typing import Any, Callable, List, NamedTuple, Optional, Sequence, Tuple, Dict, Union
import random
from env import REVERIENavBatch from env import REVERIENavBatch
from argparse import Namespace from argparse import Namespace
@ -884,42 +885,46 @@ class RandomAgent(BaseAgent):
global FINAL_STOP_POINT global FINAL_STOP_POINT
global TEMP_STEPS_COUNTER global TEMP_STEPS_COUNTER
global STEPS_COUNTER
global SUCCESS
FINAL_STOP_POINT = obs[0]['stop'] FINAL_STOP_POINT = obs[0]['stop']
if TEMP_STEPS_COUNTER != 0: if TEMP_STEPS_COUNTER != 0:
TEMP_STEPS_COUNTER = 0 TEMP_STEPS_COUNTER = 0
print(obs[0].keys()) print("=="*20)
print(obs[0]['obs'])
print(obs[0]['obs_summary'])
print(obs[0]['objects'])
print(obs[0]['instr_id'])
print(obs[0]['scan'])
print(obs[0]['viewpoint'])
print(obs[0]['heading'])
print(obs[0]['elevation'])
print(obs[0]['candidate'])
print(obs[0]['instruction'])
print(obs[0]['gt_path'])
print(obs[0]['path_id'])
print(obs[0]['stop'])
print(obs[0]['start'])
print(obs[0]['target'])
print("==")
# Initialize the trajectory # Initialize the trajectory
self.init_trajecotry(obs) self.init_trajecotry(obs)
for i, init_ob in enumerate(obs): for iteration in range(self.config.max_iterations):
next_point = None
navigable = init_ob['candidate'] print(obs[0].keys())
heading = np.rad2deg(init_ob['heading']) print(obs[0]['viewpoint'])
elevation = np.rad2deg(init_ob['elevation']) for i, init_ob in enumerate(obs):
orientation = f'\nheading: {heading:.2f}, elevation: {elevation:.2f}' navigable = [ k for k, v in init_ob['candidate'].items() ]
next_point = random.choice(navigable)
print(next_point)
turned_angle, obs = self.make_equiv_action([next_point])
obs = [obs]
print(f"TEMP_STEPS_COUNTER={TEMP_STEPS_COUNTER}")
print(f"STEPS_COUNTER={STEPS_COUNTER}")
TEMP_STEPS_COUNTER += 1
if next_point == FINAL_STOP_POINT:
print(" SUCCESS")
STEPS_COUNTER += TEMP_STEPS_COUNTER
SUCCESS += 1
TEMP_STEPS_COUNTER = 0
break
print(f"FINAL_STOP_POINT={FINAL_STOP_POINT}")
print(f"SUCCESS={SUCCESS}")
print(f"TEMP_STEPS_COUNTER={TEMP_STEPS_COUNTER}")
print(f"STEPS_COUNTER={STEPS_COUNTER}")
output = self.agent_executor(input)
return self.traj return self.traj

View File

@ -14,10 +14,6 @@ class BaseAgent(object):
output.append({'instr_id': k, 'trajectory': v['path']}) output.append({'instr_id': k, 'trajectory': v['path']})
if detailed_output: if detailed_output:
output[-1]['details'] = v['details'] output[-1]['details'] = v['details']
output[-1]['action_plan'] = v['action_plan']
output[-1]['llm_output'] = v['llm_output']
output[-1]['llm_thought'] = v['llm_thought']
output[-1]['llm_observation'] = v['llm_observation']
return output return output
def rollout(self, **args): def rollout(self, **args):