feat: add boundary in GPT env

This commit is contained in:
Ting-Jun Wang 2024-04-29 03:05:32 +08:00
parent 89081b6b21
commit 02c957d38f
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354

View File

@ -6,6 +6,7 @@ import numpy as np
import random import random
import networkx as nx import networkx as nx
from collections import defaultdict from collections import defaultdict
from glob import glob
from utils.data import load_nav_graphs from utils.data import load_nav_graphs
from eval_utils import cal_dtw, cal_cls from eval_utils import cal_dtw, cal_cls
@ -13,6 +14,111 @@ from utils.graph_utils import NavGraph
ERROR_MARGIN = 3.0 ERROR_MARGIN = 3.0
def load_floorplan():
region_label_lookup = load_region_label_lookup()
house_files = glob('/home/snsd0805/code/research/VLN/base_dir/v1/scans/*/house_segmentations/*.house')
node_region_lookups = {}
region_room_lookups = {}
region_object_lookups = {}
node_locations_lookups = {}
for house_file in house_files:
scan_id = house_file.split("/")[-3]
regions, floors, node_id_regions, node_id_floors = {}, {}, {}, {}
room_bboxes = {}
node_coors = {}
node_locations = {}
region_objects = defaultdict(list)
object_name_lookup = {}
#print(scan_id, datetime.now())
#house_lines = []
for line in open(house_file):
house_line = line.strip()
#house_lines.append(line.strip())
#for house_line in house_lines[1:]:
house_line_cols = house_line.split()
house_line_type = house_line_cols[0]
house_line_cols = house_line_cols[1:]
if house_line_type=='R':
region_index, level_index, _, _, label, px, py, pz, xlo, ylo, zlo, xhi, yhi, zhi, height,_,_,_,_ = house_line_cols
regions[region_index] = region_label_lookup[label]
floors[region_index] = level_index
room_bboxes[region_index] = {
'name': region_label_lookup[label],
'floor': level_index
}
#for var_name in ['px', 'py', 'pz', 'xlo', 'ylo', 'zlo', 'xhi', 'yhi', 'zhi', 'height']:
# room_bboxes[region_index][var_name] = float(eval(var_name))
if house_line_type=='P':
node_id, panorama_index, region_index, _, px, py, pz, _,_,_,_,_ = house_line_cols
node_id_regions[node_id] = region_index#regions[region_index]
node_locations[node_id] = (px, py, pz)
#node_id_floors[node_id] = int(floors[region_index]) + 1
#node_coors[node_id] = (float(px), float(py), float(pz))
#raise
#if house_line_type=='I':
#break
if house_line_type=='C':
category_index, category_mapping_index, category_mapping_name, mpcat40_index, mpcat40_name, _,_,_,_,_ = house_line_cols
object_name_lookup[category_index] = category_mapping_name
if house_line_type=='O':
object_index, region_index, category_index, px, py, pz, a0x, a0y, a0z, a1x, a1y, a1z, r0, r1, r2, _, _, _, _, _, _, _, _ = house_line_cols
if category_index=='-1' or region_index=='-1':
#print("error")
continue
region_objects[region_index].append(object_name_lookup[category_index])
#room_lookups[scan_id] = node_id_regions
#floor_lookups[scan_id] = node_id_floors
region_room_lookups[scan_id] = room_bboxes
node_region_lookups[scan_id] = node_id_regions
node_locations_lookups[scan_id] = node_locations
region_object_lookups[scan_id] = {k:sorted(v) for k,v in region_objects.items()}
#node_coor_lookups[scan_id] = node_coors
return node_region_lookups, region_room_lookups, region_object_lookups, node_locations_lookups
def load_region_label_lookup():
region_label_lookup = {
'a': 'bathroom',
'b': 'bedroom',
'c': 'closet',
'd': 'dining room',
'e': 'entryway',#/foyer/lobby (should be the front door, not any door)
'f': 'familyroom',# (should be a room that a family hangs out in, not any area with couches)
'g': 'garage',#
'h': 'hallway',#
'i': 'library',# (should be room like a library at a university, not an individual study)
'j': 'laundryroom',#/mudroom (place where people do laundry, etc.)
'k': 'kitchen',#
'l': 'living room',# (should be the main "showcase" living room in a house, not any area with couches)
'm': 'meeting room',#/conferenceroom
'n': 'lounge',# (any area where people relax in comfy chairs/couches that is not the family room or living room
'o': 'office',# (usually for an individual, or a small set of people)
'p': 'porch',#/terrace/deck/driveway (must be outdoors on ground level)
'r': 'recreation',#/game (should have recreational objects, like pool table, etc.)
's': 'stairs',#
't': 'toilet',# (should be a small room with ONLY a toilet)
'u': 'utility room',#/toolroom
'v': 'tv',# (must have theater-style seating)
'w': 'gym',#workout/gym/exercise
'x': 'outdoor',# areas containing grass, plants, bushes, trees, etc.
'y': 'balcony',# (must be outside and must not be on ground floor)
'z': 'other room',# (it is clearly a room, but the function is not clear)
'B': 'bar',#
'C': 'classroom',#
'D': 'dining booth',#
'S': 'spa',#/sauna
'Z': 'junk',# (reflections of mirrors, random points floating in space, etc.)
'-': 'no label',#
}
return region_label_lookup
class Simulator(object): class Simulator(object):
''' A simple simulator in Matterport3D environment ''' ''' A simple simulator in Matterport3D environment '''
@ -28,6 +134,9 @@ class Simulator(object):
self.candidate = {} self.candidate = {}
self.gmap = NavGraph() self.gmap = NavGraph()
self.node_region, self.region_room, self.region_obj, self.node_locations = load_floorplan()
def newEpisode( def newEpisode(
self, self,
scan_ID: str, scan_ID: str,
@ -48,7 +157,20 @@ class Simulator(object):
# Load navigable dict # Load navigable dict
navigable_path = os.path.join(self.navigable_dir, self.scan_ID + '_navigable.json') navigable_path = os.path.join(self.navigable_dir, self.scan_ID + '_navigable.json')
with open(navigable_path, 'r') as f: with open(navigable_path, 'r') as f:
self.navigable_dict = json.load(f) navigable_dict = json.load(f)
self.navigable_dict = {}
for start, v in navigable_dict.items():
self.navigable_dict[start] = {}
print("BEFORE: ", len(navigable_dict[start]))
for to, _v in navigable_dict[start].items():
start_region = self.node_region[scan_ID][start]
to_region = self.node_region[scan_ID][to]
if start_region == to_region:
self.navigable_dict[start][to] = _v
print(start_region, to_region)
print("AFTER: ", len(self.navigable_dict[start]))
# Get candidate # Get candidate
self.getCandidate() self.getCandidate()