feat: random explore

This commit is contained in:
Ting-Jun Wang 2024-04-29 02:19:06 +08:00
parent 89081b6b21
commit 5848e22b1e
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354
3 changed files with 33 additions and 32 deletions

View File

@ -8,7 +8,7 @@ from utils.logger import write_to_record_file
from utils.data import ImageObservationsDB
from parser import parse_args
from env import REVERIENavBatch
from agent import NavGPTAgent
from agent import NavGPTAgent, RandomAgent
def build_dataset(args, data_limit=100):
@ -35,7 +35,7 @@ def build_dataset(args, data_limit=100):
def valid(args, val_envs):
agent = NavGPTAgent(next(iter(val_envs.values())), args)
agent = RandomAgent(next(iter(val_envs.values())), args)
with open(os.path.join(args.log_dir, 'validation_args.json'), 'w') as outf:
json.dump(vars(args), outf, indent=4)

View File

@ -5,6 +5,7 @@ import re
import warnings
import numpy as np
from typing import Any, Callable, List, NamedTuple, Optional, Sequence, Tuple, Dict, Union
import random
from env import REVERIENavBatch
from argparse import Namespace
@ -884,42 +885,46 @@ class RandomAgent(BaseAgent):
global FINAL_STOP_POINT
global TEMP_STEPS_COUNTER
global STEPS_COUNTER
global SUCCESS
FINAL_STOP_POINT = obs[0]['stop']
if TEMP_STEPS_COUNTER != 0:
TEMP_STEPS_COUNTER = 0
print(obs[0].keys())
print(obs[0]['obs'])
print(obs[0]['obs_summary'])
print(obs[0]['objects'])
print(obs[0]['instr_id'])
print(obs[0]['scan'])
print(obs[0]['viewpoint'])
print(obs[0]['heading'])
print(obs[0]['elevation'])
print(obs[0]['candidate'])
print(obs[0]['instruction'])
print(obs[0]['gt_path'])
print(obs[0]['path_id'])
print(obs[0]['stop'])
print(obs[0]['start'])
print(obs[0]['target'])
print("==")
print("=="*20)
# Initialize the trajectory
self.init_trajecotry(obs)
for i, init_ob in enumerate(obs):
for iteration in range(self.config.max_iterations):
next_point = None
print(obs[0].keys())
print(obs[0]['viewpoint'])
for i, init_ob in enumerate(obs):
navigable = [ k for k, v in init_ob['candidate'].items() ]
next_point = random.choice(navigable)
print(next_point)
turned_angle, obs = self.make_equiv_action([next_point])
obs = [obs]
print(f"TEMP_STEPS_COUNTER={TEMP_STEPS_COUNTER}")
print(f"STEPS_COUNTER={STEPS_COUNTER}")
TEMP_STEPS_COUNTER += 1
if next_point == FINAL_STOP_POINT:
print(" SUCCESS")
STEPS_COUNTER += TEMP_STEPS_COUNTER
SUCCESS += 1
TEMP_STEPS_COUNTER = 0
break
print(f"FINAL_STOP_POINT={FINAL_STOP_POINT}")
print(f"SUCCESS={SUCCESS}")
print(f"TEMP_STEPS_COUNTER={TEMP_STEPS_COUNTER}")
print(f"STEPS_COUNTER={STEPS_COUNTER}")
navigable = init_ob['candidate']
heading = np.rad2deg(init_ob['heading'])
elevation = np.rad2deg(init_ob['elevation'])
orientation = f'\nheading: {heading:.2f}, elevation: {elevation:.2f}'
output = self.agent_executor(input)
return self.traj

View File

@ -14,10 +14,6 @@ class BaseAgent(object):
output.append({'instr_id': k, 'trajectory': v['path']})
if detailed_output:
output[-1]['details'] = v['details']
output[-1]['action_plan'] = v['action_plan']
output[-1]['llm_output'] = v['llm_output']
output[-1]['llm_thought'] = v['llm_thought']
output[-1]['llm_observation'] = v['llm_observation']
return output
def rollout(self, **args):