feat: random explore
This commit is contained in:
parent
89081b6b21
commit
5848e22b1e
@ -8,7 +8,7 @@ from utils.logger import write_to_record_file
|
|||||||
from utils.data import ImageObservationsDB
|
from utils.data import ImageObservationsDB
|
||||||
from parser import parse_args
|
from parser import parse_args
|
||||||
from env import REVERIENavBatch
|
from env import REVERIENavBatch
|
||||||
from agent import NavGPTAgent
|
from agent import NavGPTAgent, RandomAgent
|
||||||
|
|
||||||
def build_dataset(args, data_limit=100):
|
def build_dataset(args, data_limit=100):
|
||||||
|
|
||||||
@ -35,7 +35,7 @@ def build_dataset(args, data_limit=100):
|
|||||||
|
|
||||||
def valid(args, val_envs):
|
def valid(args, val_envs):
|
||||||
|
|
||||||
agent = NavGPTAgent(next(iter(val_envs.values())), args)
|
agent = RandomAgent(next(iter(val_envs.values())), args)
|
||||||
|
|
||||||
with open(os.path.join(args.log_dir, 'validation_args.json'), 'w') as outf:
|
with open(os.path.join(args.log_dir, 'validation_args.json'), 'w') as outf:
|
||||||
json.dump(vars(args), outf, indent=4)
|
json.dump(vars(args), outf, indent=4)
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import re
|
|||||||
import warnings
|
import warnings
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Any, Callable, List, NamedTuple, Optional, Sequence, Tuple, Dict, Union
|
from typing import Any, Callable, List, NamedTuple, Optional, Sequence, Tuple, Dict, Union
|
||||||
|
import random
|
||||||
|
|
||||||
from env import REVERIENavBatch
|
from env import REVERIENavBatch
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
@ -884,42 +885,46 @@ class RandomAgent(BaseAgent):
|
|||||||
|
|
||||||
global FINAL_STOP_POINT
|
global FINAL_STOP_POINT
|
||||||
global TEMP_STEPS_COUNTER
|
global TEMP_STEPS_COUNTER
|
||||||
|
global STEPS_COUNTER
|
||||||
|
global SUCCESS
|
||||||
|
|
||||||
FINAL_STOP_POINT = obs[0]['stop']
|
FINAL_STOP_POINT = obs[0]['stop']
|
||||||
|
|
||||||
if TEMP_STEPS_COUNTER != 0:
|
if TEMP_STEPS_COUNTER != 0:
|
||||||
TEMP_STEPS_COUNTER = 0
|
TEMP_STEPS_COUNTER = 0
|
||||||
|
|
||||||
print(obs[0].keys())
|
print("=="*20)
|
||||||
print(obs[0]['obs'])
|
|
||||||
print(obs[0]['obs_summary'])
|
|
||||||
print(obs[0]['objects'])
|
|
||||||
print(obs[0]['instr_id'])
|
|
||||||
print(obs[0]['scan'])
|
|
||||||
print(obs[0]['viewpoint'])
|
|
||||||
print(obs[0]['heading'])
|
|
||||||
print(obs[0]['elevation'])
|
|
||||||
print(obs[0]['candidate'])
|
|
||||||
print(obs[0]['instruction'])
|
|
||||||
print(obs[0]['gt_path'])
|
|
||||||
print(obs[0]['path_id'])
|
|
||||||
print(obs[0]['stop'])
|
|
||||||
print(obs[0]['start'])
|
|
||||||
print(obs[0]['target'])
|
|
||||||
|
|
||||||
|
|
||||||
print("==")
|
|
||||||
|
|
||||||
# Initialize the trajectory
|
# Initialize the trajectory
|
||||||
self.init_trajecotry(obs)
|
self.init_trajecotry(obs)
|
||||||
|
|
||||||
|
for iteration in range(self.config.max_iterations):
|
||||||
|
next_point = None
|
||||||
|
print(obs[0].keys())
|
||||||
|
print(obs[0]['viewpoint'])
|
||||||
for i, init_ob in enumerate(obs):
|
for i, init_ob in enumerate(obs):
|
||||||
|
navigable = [ k for k, v in init_ob['candidate'].items() ]
|
||||||
|
next_point = random.choice(navigable)
|
||||||
|
print(next_point)
|
||||||
|
turned_angle, obs = self.make_equiv_action([next_point])
|
||||||
|
obs = [obs]
|
||||||
|
|
||||||
|
print(f"TEMP_STEPS_COUNTER={TEMP_STEPS_COUNTER}")
|
||||||
|
print(f"STEPS_COUNTER={STEPS_COUNTER}")
|
||||||
|
TEMP_STEPS_COUNTER += 1
|
||||||
|
|
||||||
|
if next_point == FINAL_STOP_POINT:
|
||||||
|
print(" SUCCESS")
|
||||||
|
STEPS_COUNTER += TEMP_STEPS_COUNTER
|
||||||
|
SUCCESS += 1
|
||||||
|
TEMP_STEPS_COUNTER = 0
|
||||||
|
break
|
||||||
|
|
||||||
|
print(f"FINAL_STOP_POINT={FINAL_STOP_POINT}")
|
||||||
|
print(f"SUCCESS={SUCCESS}")
|
||||||
|
print(f"TEMP_STEPS_COUNTER={TEMP_STEPS_COUNTER}")
|
||||||
|
print(f"STEPS_COUNTER={STEPS_COUNTER}")
|
||||||
|
|
||||||
navigable = init_ob['candidate']
|
|
||||||
heading = np.rad2deg(init_ob['heading'])
|
|
||||||
elevation = np.rad2deg(init_ob['elevation'])
|
|
||||||
orientation = f'\nheading: {heading:.2f}, elevation: {elevation:.2f}'
|
|
||||||
|
|
||||||
output = self.agent_executor(input)
|
|
||||||
|
|
||||||
return self.traj
|
return self.traj
|
||||||
|
|||||||
@ -14,10 +14,6 @@ class BaseAgent(object):
|
|||||||
output.append({'instr_id': k, 'trajectory': v['path']})
|
output.append({'instr_id': k, 'trajectory': v['path']})
|
||||||
if detailed_output:
|
if detailed_output:
|
||||||
output[-1]['details'] = v['details']
|
output[-1]['details'] = v['details']
|
||||||
output[-1]['action_plan'] = v['action_plan']
|
|
||||||
output[-1]['llm_output'] = v['llm_output']
|
|
||||||
output[-1]['llm_thought'] = v['llm_thought']
|
|
||||||
output[-1]['llm_observation'] = v['llm_observation']
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def rollout(self, **args):
|
def rollout(self, **args):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user