feat: random explore
This commit is contained in:
parent
89081b6b21
commit
5848e22b1e
@ -8,7 +8,7 @@ from utils.logger import write_to_record_file
|
||||
from utils.data import ImageObservationsDB
|
||||
from parser import parse_args
|
||||
from env import REVERIENavBatch
|
||||
from agent import NavGPTAgent
|
||||
from agent import NavGPTAgent, RandomAgent
|
||||
|
||||
def build_dataset(args, data_limit=100):
|
||||
|
||||
@ -35,7 +35,7 @@ def build_dataset(args, data_limit=100):
|
||||
|
||||
def valid(args, val_envs):
|
||||
|
||||
agent = NavGPTAgent(next(iter(val_envs.values())), args)
|
||||
agent = RandomAgent(next(iter(val_envs.values())), args)
|
||||
|
||||
with open(os.path.join(args.log_dir, 'validation_args.json'), 'w') as outf:
|
||||
json.dump(vars(args), outf, indent=4)
|
||||
|
||||
@ -5,6 +5,7 @@ import re
|
||||
import warnings
|
||||
import numpy as np
|
||||
from typing import Any, Callable, List, NamedTuple, Optional, Sequence, Tuple, Dict, Union
|
||||
import random
|
||||
|
||||
from env import REVERIENavBatch
|
||||
from argparse import Namespace
|
||||
@ -884,42 +885,46 @@ class RandomAgent(BaseAgent):
|
||||
|
||||
global FINAL_STOP_POINT
|
||||
global TEMP_STEPS_COUNTER
|
||||
global STEPS_COUNTER
|
||||
global SUCCESS
|
||||
|
||||
FINAL_STOP_POINT = obs[0]['stop']
|
||||
|
||||
if TEMP_STEPS_COUNTER != 0:
|
||||
TEMP_STEPS_COUNTER = 0
|
||||
|
||||
print(obs[0].keys())
|
||||
print(obs[0]['obs'])
|
||||
print(obs[0]['obs_summary'])
|
||||
print(obs[0]['objects'])
|
||||
print(obs[0]['instr_id'])
|
||||
print(obs[0]['scan'])
|
||||
print(obs[0]['viewpoint'])
|
||||
print(obs[0]['heading'])
|
||||
print(obs[0]['elevation'])
|
||||
print(obs[0]['candidate'])
|
||||
print(obs[0]['instruction'])
|
||||
print(obs[0]['gt_path'])
|
||||
print(obs[0]['path_id'])
|
||||
print(obs[0]['stop'])
|
||||
print(obs[0]['start'])
|
||||
print(obs[0]['target'])
|
||||
|
||||
|
||||
print("==")
|
||||
print("=="*20)
|
||||
|
||||
# Initialize the trajectory
|
||||
self.init_trajecotry(obs)
|
||||
|
||||
for i, init_ob in enumerate(obs):
|
||||
for iteration in range(self.config.max_iterations):
|
||||
next_point = None
|
||||
print(obs[0].keys())
|
||||
print(obs[0]['viewpoint'])
|
||||
for i, init_ob in enumerate(obs):
|
||||
navigable = [ k for k, v in init_ob['candidate'].items() ]
|
||||
next_point = random.choice(navigable)
|
||||
print(next_point)
|
||||
turned_angle, obs = self.make_equiv_action([next_point])
|
||||
obs = [obs]
|
||||
|
||||
print(f"TEMP_STEPS_COUNTER={TEMP_STEPS_COUNTER}")
|
||||
print(f"STEPS_COUNTER={STEPS_COUNTER}")
|
||||
TEMP_STEPS_COUNTER += 1
|
||||
|
||||
if next_point == FINAL_STOP_POINT:
|
||||
print(" SUCCESS")
|
||||
STEPS_COUNTER += TEMP_STEPS_COUNTER
|
||||
SUCCESS += 1
|
||||
TEMP_STEPS_COUNTER = 0
|
||||
break
|
||||
|
||||
print(f"FINAL_STOP_POINT={FINAL_STOP_POINT}")
|
||||
print(f"SUCCESS={SUCCESS}")
|
||||
print(f"TEMP_STEPS_COUNTER={TEMP_STEPS_COUNTER}")
|
||||
print(f"STEPS_COUNTER={STEPS_COUNTER}")
|
||||
|
||||
navigable = init_ob['candidate']
|
||||
heading = np.rad2deg(init_ob['heading'])
|
||||
elevation = np.rad2deg(init_ob['elevation'])
|
||||
orientation = f'\nheading: {heading:.2f}, elevation: {elevation:.2f}'
|
||||
|
||||
output = self.agent_executor(input)
|
||||
|
||||
return self.traj
|
||||
|
||||
@ -14,10 +14,6 @@ class BaseAgent(object):
|
||||
output.append({'instr_id': k, 'trajectory': v['path']})
|
||||
if detailed_output:
|
||||
output[-1]['details'] = v['details']
|
||||
output[-1]['action_plan'] = v['action_plan']
|
||||
output[-1]['llm_output'] = v['llm_output']
|
||||
output[-1]['llm_thought'] = v['llm_thought']
|
||||
output[-1]['llm_observation'] = v['llm_observation']
|
||||
return output
|
||||
|
||||
def rollout(self, **args):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user