Compare commits

...

2 Commits

Author SHA1 Message Date
32ceca7752
feat: random with boundary 2024-04-29 02:39:46 +08:00
5848e22b1e
feat: random explore 2024-04-29 02:19:06 +08:00
4 changed files with 156 additions and 33 deletions

View File

@ -8,7 +8,7 @@ from utils.logger import write_to_record_file
from utils.data import ImageObservationsDB
from parser import parse_args
from env import REVERIENavBatch
from agent import NavGPTAgent
from agent import NavGPTAgent, RandomAgent
def build_dataset(args, data_limit=100):
@ -35,7 +35,7 @@ def build_dataset(args, data_limit=100):
def valid(args, val_envs):
agent = NavGPTAgent(next(iter(val_envs.values())), args)
agent = RandomAgent(next(iter(val_envs.values())), args)
with open(os.path.join(args.log_dir, 'validation_args.json'), 'w') as outf:
json.dump(vars(args), outf, indent=4)

View File

@ -5,6 +5,7 @@ import re
import warnings
import numpy as np
from typing import Any, Callable, List, NamedTuple, Optional, Sequence, Tuple, Dict, Union
import random
from env import REVERIENavBatch
from argparse import Namespace
@ -884,42 +885,46 @@ class RandomAgent(BaseAgent):
global FINAL_STOP_POINT
global TEMP_STEPS_COUNTER
global STEPS_COUNTER
global SUCCESS
FINAL_STOP_POINT = obs[0]['stop']
if TEMP_STEPS_COUNTER != 0:
TEMP_STEPS_COUNTER = 0
print(obs[0].keys())
print(obs[0]['obs'])
print(obs[0]['obs_summary'])
print(obs[0]['objects'])
print(obs[0]['instr_id'])
print(obs[0]['scan'])
print(obs[0]['viewpoint'])
print(obs[0]['heading'])
print(obs[0]['elevation'])
print(obs[0]['candidate'])
print(obs[0]['instruction'])
print(obs[0]['gt_path'])
print(obs[0]['path_id'])
print(obs[0]['stop'])
print(obs[0]['start'])
print(obs[0]['target'])
print("==")
print("=="*20)
# Initialize the trajectory
self.init_trajecotry(obs)
for i, init_ob in enumerate(obs):
navigable = init_ob['candidate']
heading = np.rad2deg(init_ob['heading'])
elevation = np.rad2deg(init_ob['elevation'])
orientation = f'\nheading: {heading:.2f}, elevation: {elevation:.2f}'
for iteration in range(self.config.max_iterations):
next_point = None
print(obs[0].keys())
print(obs[0]['viewpoint'])
for i, init_ob in enumerate(obs):
navigable = [ k for k, v in init_ob['candidate'].items() ]
next_point = random.choice(navigable)
print(next_point)
turned_angle, obs = self.make_equiv_action([next_point])
obs = [obs]
print(f"TEMP_STEPS_COUNTER={TEMP_STEPS_COUNTER}")
print(f"STEPS_COUNTER={STEPS_COUNTER}")
TEMP_STEPS_COUNTER += 1
if next_point == FINAL_STOP_POINT:
print(" SUCCESS")
STEPS_COUNTER += TEMP_STEPS_COUNTER
SUCCESS += 1
TEMP_STEPS_COUNTER = 0
break
print(f"FINAL_STOP_POINT={FINAL_STOP_POINT}")
print(f"SUCCESS={SUCCESS}")
print(f"TEMP_STEPS_COUNTER={TEMP_STEPS_COUNTER}")
print(f"STEPS_COUNTER={STEPS_COUNTER}")
output = self.agent_executor(input)
return self.traj

View File

@ -14,10 +14,6 @@ class BaseAgent(object):
output.append({'instr_id': k, 'trajectory': v['path']})
if detailed_output:
output[-1]['details'] = v['details']
output[-1]['action_plan'] = v['action_plan']
output[-1]['llm_output'] = v['llm_output']
output[-1]['llm_thought'] = v['llm_thought']
output[-1]['llm_observation'] = v['llm_observation']
return output
def rollout(self, **args):

View File

@ -6,6 +6,7 @@ import numpy as np
import random
import networkx as nx
from collections import defaultdict
from glob import glob
from utils.data import load_nav_graphs
from eval_utils import cal_dtw, cal_cls
@ -13,6 +14,111 @@ from utils.graph_utils import NavGraph
ERROR_MARGIN = 3.0
def load_floorplan():
region_label_lookup = load_region_label_lookup()
house_files = glob('/home/snsd0805/code/research/VLN/base_dir/v1/scans/*/house_segmentations/*.house')
node_region_lookups = {}
region_room_lookups = {}
region_object_lookups = {}
node_locations_lookups = {}
for house_file in house_files:
scan_id = house_file.split("/")[-3]
regions, floors, node_id_regions, node_id_floors = {}, {}, {}, {}
room_bboxes = {}
node_coors = {}
node_locations = {}
region_objects = defaultdict(list)
object_name_lookup = {}
#print(scan_id, datetime.now())
#house_lines = []
for line in open(house_file):
house_line = line.strip()
#house_lines.append(line.strip())
#for house_line in house_lines[1:]:
house_line_cols = house_line.split()
house_line_type = house_line_cols[0]
house_line_cols = house_line_cols[1:]
if house_line_type=='R':
region_index, level_index, _, _, label, px, py, pz, xlo, ylo, zlo, xhi, yhi, zhi, height,_,_,_,_ = house_line_cols
regions[region_index] = region_label_lookup[label]
floors[region_index] = level_index
room_bboxes[region_index] = {
'name': region_label_lookup[label],
'floor': level_index
}
#for var_name in ['px', 'py', 'pz', 'xlo', 'ylo', 'zlo', 'xhi', 'yhi', 'zhi', 'height']:
# room_bboxes[region_index][var_name] = float(eval(var_name))
if house_line_type=='P':
node_id, panorama_index, region_index, _, px, py, pz, _,_,_,_,_ = house_line_cols
node_id_regions[node_id] = region_index#regions[region_index]
node_locations[node_id] = (px, py, pz)
#node_id_floors[node_id] = int(floors[region_index]) + 1
#node_coors[node_id] = (float(px), float(py), float(pz))
#raise
#if house_line_type=='I':
#break
if house_line_type=='C':
category_index, category_mapping_index, category_mapping_name, mpcat40_index, mpcat40_name, _,_,_,_,_ = house_line_cols
object_name_lookup[category_index] = category_mapping_name
if house_line_type=='O':
object_index, region_index, category_index, px, py, pz, a0x, a0y, a0z, a1x, a1y, a1z, r0, r1, r2, _, _, _, _, _, _, _, _ = house_line_cols
if category_index=='-1' or region_index=='-1':
#print("error")
continue
region_objects[region_index].append(object_name_lookup[category_index])
#room_lookups[scan_id] = node_id_regions
#floor_lookups[scan_id] = node_id_floors
region_room_lookups[scan_id] = room_bboxes
node_region_lookups[scan_id] = node_id_regions
node_locations_lookups[scan_id] = node_locations
region_object_lookups[scan_id] = {k:sorted(v) for k,v in region_objects.items()}
#node_coor_lookups[scan_id] = node_coors
return node_region_lookups, region_room_lookups, region_object_lookups, node_locations_lookups
def load_region_label_lookup():
region_label_lookup = {
'a': 'bathroom',
'b': 'bedroom',
'c': 'closet',
'd': 'dining room',
'e': 'entryway',#/foyer/lobby (should be the front door, not any door)
'f': 'familyroom',# (should be a room that a family hangs out in, not any area with couches)
'g': 'garage',#
'h': 'hallway',#
'i': 'library',# (should be room like a library at a university, not an individual study)
'j': 'laundryroom',#/mudroom (place where people do laundry, etc.)
'k': 'kitchen',#
'l': 'living room',# (should be the main "showcase" living room in a house, not any area with couches)
'm': 'meeting room',#/conferenceroom
'n': 'lounge',# (any area where people relax in comfy chairs/couches that is not the family room or living room
'o': 'office',# (usually for an individual, or a small set of people)
'p': 'porch',#/terrace/deck/driveway (must be outdoors on ground level)
'r': 'recreation',#/game (should have recreational objects, like pool table, etc.)
's': 'stairs',#
't': 'toilet',# (should be a small room with ONLY a toilet)
'u': 'utility room',#/toolroom
'v': 'tv',# (must have theater-style seating)
'w': 'gym',#workout/gym/exercise
'x': 'outdoor',# areas containing grass, plants, bushes, trees, etc.
'y': 'balcony',# (must be outside and must not be on ground floor)
'z': 'other room',# (it is clearly a room, but the function is not clear)
'B': 'bar',#
'C': 'classroom',#
'D': 'dining booth',#
'S': 'spa',#/sauna
'Z': 'junk',# (reflections of mirrors, random points floating in space, etc.)
'-': 'no label',#
}
return region_label_lookup
class Simulator(object):
''' A simple simulator in Matterport3D environment '''
@ -28,6 +134,8 @@ class Simulator(object):
self.candidate = {}
self.gmap = NavGraph()
self.node_region, self.region_room, self.region_obj, self.node_locations = load_floorplan()
def newEpisode(
self,
scan_ID: str,
@ -48,7 +156,21 @@ class Simulator(object):
# Load navigable dict
navigable_path = os.path.join(self.navigable_dir, self.scan_ID + '_navigable.json')
with open(navigable_path, 'r') as f:
self.navigable_dict = json.load(f)
navigable_dict = json.load(f)
self.navigable_dict = {}
for start, v in navigable_dict.items():
self.navigable_dict[start] = {}
print("BEFORE: ", len(navigable_dict[start]))
for to, _v in navigable_dict[start].items():
start_region = self.node_region[scan_ID][start]
to_region = self.node_region[scan_ID][to]
if start_region == to_region:
self.navigable_dict[start][to] = _v
print(start_region, to_region)
print("AFTER: ", len(self.navigable_dict[start]))
# Get candidate
self.getCandidate()