feat: show target stop in state

This commit is contained in:
Ting-Jun Wang 2024-04-28 21:11:50 +08:00
parent 7454bc15af
commit 9165a802d3
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354
5 changed files with 54 additions and 26 deletions

View File

@ -2,25 +2,25 @@ import os
import json import json
import time import time
from data_utils import construct_instrs from data_utils import construct_reverie_instrs
from utils.logger import write_to_record_file from utils.logger import write_to_record_file
from utils.data import ImageObservationsDB from utils.data import ImageObservationsDB
from parser import parse_args from parser import parse_args
from env import R2RNavBatch from env import REVERIENavBatch
from agent import NavAgent from agent import NavGPTAgent
def build_dataset(args): def build_dataset(args):
feat_db = ImageObservationsDB(args.obs_dir, args.obs_summary_dir, args.obj_dir) feat_db = ImageObservationsDB(args.obs_dir, args.obs_summary_dir, args.obj_dir)
dataset_class = R2RNavBatch dataset_class = REVERIENavBatch
val_env_names = [args.val_env_name] val_env_names = [args.val_env_name]
val_envs = {} val_envs = {}
for split in val_env_names: for split in val_env_names:
val_instr_data = construct_instrs( val_instr_data = construct_reverie_instrs(
args.anno_dir, args.dataset, [split] args.anno_dir, args.dataset, [split]
) )
val_env = dataset_class( val_env = dataset_class(
@ -34,7 +34,7 @@ def build_dataset(args):
def valid(args, val_envs): def valid(args, val_envs):
agent = NavAgent(next(iter(val_envs.values())), args) agent = NavGPTAgent(next(iter(val_envs.values())), args)
with open(os.path.join(args.log_dir, 'validation_args.json'), 'w') as outf: with open(os.path.join(args.log_dir, 'validation_args.json'), 'w') as outf:
json.dump(vars(args), outf, indent=4) json.dump(vars(args), outf, indent=4)
@ -78,7 +78,7 @@ def valid(args, val_envs):
def valid_from_file(args, val_envs): def valid_from_file(args, val_envs):
agent = NavAgent(next(iter(val_envs.values())), args) agent = NavGPTAgent(next(iter(val_envs.values())), args)
with open(args.valid_file, 'r') as f: with open(args.valid_file, 'r') as f:
preds = json.load(f) preds = json.load(f)

View File

@ -6,7 +6,7 @@ import warnings
import numpy as np import numpy as np
from typing import Any, Callable, List, NamedTuple, Optional, Sequence, Tuple, Dict, Union from typing import Any, Callable, List, NamedTuple, Optional, Sequence, Tuple, Dict, Union
from env import R2RNavBatch from env import REVERIENavBatch
from argparse import Namespace from argparse import Namespace
from agent_base import BaseAgent from agent_base import BaseAgent
@ -142,10 +142,10 @@ class VLNAgent(ZeroShotAgent):
full_inputs = {**kwargs, **new_inputs} full_inputs = {**kwargs, **new_inputs}
return full_inputs return full_inputs
class NavAgent(BaseAgent): class NavGPTAgent(BaseAgent):
def __init__( def __init__(
self, self,
env: R2RNavBatch, env: REVERIENavBatch,
config: Namespace): config: Namespace):
""" """
Initialize the LLM Navigation Agent. Initialize the LLM Navigation Agent.
@ -725,4 +725,4 @@ class NavAgent(BaseAgent):
self.traj[i]['llm_thought'].append(thought) self.traj[i]['llm_thought'].append(thought)
self.traj[i]['llm_observation'].append(observation) self.traj[i]['llm_observation'].append(observation)
return self.traj return self.traj

View File

@ -13,6 +13,7 @@ def load_instr_datasets(anno_dir, dataset, splits):
return data return data
'''
def construct_instrs(anno_dir, dataset, splits): def construct_instrs(anno_dir, dataset, splits):
data = [] data = []
if "instr" in splits[0]: if "instr" in splits[0]:
@ -27,4 +28,20 @@ def construct_instrs(anno_dir, dataset, splits):
del new_item['instructions'] del new_item['instructions']
del new_item['instr_encodings'] del new_item['instr_encodings']
data.append(new_item) data.append(new_item)
return data return data
'''
def construct_reverie_instrs(anno_dir, dataset, splits):
data = []
if "instr" in splits[0]:
return load_instr_datasets(anno_dir, dataset, splits)
for i, item in enumerate(load_instr_datasets(anno_dir, dataset, splits)):
# Split multiple instructions into separate entries
for j, instr in enumerate(item['instructions']):
new_item = dict(item)
new_item['instr_id'] = '%s_%d' % (item['path_id'], j)
new_item['instruction'] = instr
del new_item['instructions']
del new_item['instr_encodings']
data.append(new_item)
return data

View File

@ -33,11 +33,13 @@ class Simulator(object):
scan_ID: str, scan_ID: str,
viewpoint_ID: str, viewpoint_ID: str,
heading: int, heading: int,
elevation: int,): elevation: int,
stop: str):
self.heading = heading self.heading = heading
self.elevation = elevation self.elevation = elevation
self.scan_ID = scan_ID self.scan_ID = scan_ID
self.viewpoint_ID = viewpoint_ID self.viewpoint_ID = viewpoint_ID
self.stop = stop
# Load navigable dict # Load navigable dict
navigable_path = os.path.join(self.navigable_dir, self.scan_ID + '_navigable.json') navigable_path = os.path.join(self.navigable_dir, self.scan_ID + '_navigable.json')
with open(navigable_path, 'r') as f: with open(navigable_path, 'r') as f:
@ -57,6 +59,7 @@ class Simulator(object):
'heading': self.heading, 'heading': self.heading,
'elevation': self.elevation, 'elevation': self.elevation,
'candidate': self.candidate, 'candidate': self.candidate,
'stop': self.stop
} }
return self.state return self.state
@ -101,9 +104,9 @@ class EnvBatch(object):
def _make_id(self, scanId, viewpointId): def _make_id(self, scanId, viewpointId):
return scanId + '_' + viewpointId return scanId + '_' + viewpointId
def newEpisodes(self, scanIds, viewpointIds, headings): def newEpisodes(self, scanIds, viewpointIds, headings, stops):
for i, (scanId, viewpointId, heading) in enumerate(zip(scanIds, viewpointIds, headings)): for i, (scanId, viewpointId, heading, stop) in enumerate(zip(scanIds, viewpointIds, headings, stops)):
self.sims[i].newEpisode(scanId, viewpointId, heading, 0) self.sims[i].newEpisode(scanId, viewpointId, heading, 0, stop)
def getStates(self): def getStates(self):
""" """
@ -118,6 +121,7 @@ class EnvBatch(object):
feature = self.feat_db.get_image_observation(state["scanID"], state["viewpointID"]) feature = self.feat_db.get_image_observation(state["scanID"], state["viewpointID"])
feature_states.append((feature, state)) feature_states.append((feature, state))
print(feature_states[-1])
return feature_states return feature_states
def makeActions(self, next_viewpoint_IDs): def makeActions(self, next_viewpoint_IDs):
@ -126,7 +130,7 @@ class EnvBatch(object):
self.sims[i].makeAction(next_viewpoint_ID) self.sims[i].makeAction(next_viewpoint_ID)
class R2RNavBatch(object): class REVERIENavBatch(object):
''' Implements the REVERIE navigation task, using discretized viewpoints and pretrained features ''' ''' Implements the REVERIE navigation task, using discretized viewpoints and pretrained features '''
def __init__( def __init__(
@ -140,7 +144,7 @@ class R2RNavBatch(object):
self.batch_size = batch_size self.batch_size = batch_size
self.name = name self.name = name
self.gt_trajs = self._get_gt_trajs(self.data) # for evaluation #self.gt_trajs = self._get_gt_trajs(self.data) # for evaluation
# use different seeds in different processes to shuffle data # use different seeds in different processes to shuffle data
self.seed = seed self.seed = seed
@ -154,12 +158,14 @@ class R2RNavBatch(object):
print('%s loaded with %d instructions, using splits: %s' % ( print('%s loaded with %d instructions, using splits: %s' % (
self.__class__.__name__, len(self.data), self.name)) self.__class__.__name__, len(self.data), self.name))
'''
def _get_gt_trajs(self, data): def _get_gt_trajs(self, data):
gt_trajs = { gt_trajs = {
x['instr_id']: (x['scan'], x['path']) \ x['instr_id']: (x['scan'], x['path']) \
for x in data if len(x['path']) > 1 for x in data if len(x['path']) > 1
} }
return gt_trajs return gt_trajs
'''
def size(self): def size(self):
return len(self.data) return len(self.data)
@ -197,6 +203,7 @@ class R2RNavBatch(object):
else: else:
self.ix += batch_size self.ix += batch_size
self.batch = batch self.batch = batch
print(self.batch)
def reset_epoch(self, shuffle=False): def reset_epoch(self, shuffle=False):
''' Reset the data index to beginning of epoch. Primarily for testing. ''' Reset the data index to beginning of epoch. Primarily for testing.
@ -227,10 +234,13 @@ class R2RNavBatch(object):
} }
# RL reward. The negative distance between the state and the final state # RL reward. The negative distance between the state and the final state
# There are multiple gt end viewpoints on REVERIE. # There are multiple gt end viewpoints on REVERIE.
'''
if ob['instr_id'] in self.gt_trajs: if ob['instr_id'] in self.gt_trajs:
ob['distance'] = self.shortest_distances[ob['scan']][ob['viewpoint']][item['path'][-1]] ob['distance'] = self.shortest_distances[ob['scan']][ob['viewpoint']][item['path'][-1]]
else: else:
ob['distance'] = 0 ob['distance'] = 0
'''
obs.append(ob) obs.append(ob)
return obs return obs
@ -242,7 +252,8 @@ class R2RNavBatch(object):
scanIds = [item['scan'] for item in self.batch] scanIds = [item['scan'] for item in self.batch]
viewpointIds = [item['path'][0] for item in self.batch] viewpointIds = [item['path'][0] for item in self.batch]
headings = [item['heading'] for item in self.batch] headings = [item['heading'] for item in self.batch]
self.env.newEpisodes(scanIds, viewpointIds, headings) stops = [item['stop'] for item in self.batch]
self.env.newEpisodes(scanIds, viewpointIds, headings, stops)
return self._get_obs() return self._get_obs()
def step(self, next_viewpoint_IDs): def step(self, next_viewpoint_IDs):

View File

@ -30,7 +30,7 @@ def parse_args():
# parser.add_argument('--val_env_name', type=str, default='R2R_val_unseen_instr_2') # parser.add_argument('--val_env_name', type=str, default='R2R_val_unseen_instr_2')
# parser.add_argument('--val_env_name', type=str, default='R2R_val_unseen_instr_3') # parser.add_argument('--val_env_name', type=str, default='R2R_val_unseen_instr_3')
# parser.add_argument('--val_env_name', type=str, default='R2R_val_unseen_instr_4') # parser.add_argument('--val_env_name', type=str, default='R2R_val_unseen_instr_4')
parser.add_argument('--val_env_name', type=str, default='R2R_val_unseen_instr') parser.add_argument('--val_env_name', type=str, default='REVERIE_val_unseen_instr')
parser.add_argument('--load_instruction', action='store_true', default=True) parser.add_argument('--load_instruction', action='store_true', default=True)
parser.add_argument('--load_action_plan', action='store_true', default=True) parser.add_argument('--load_action_plan', action='store_true', default=True)
@ -57,15 +57,15 @@ def postprocess_args(args):
ROOTDIR = args.root_dir ROOTDIR = args.root_dir
# Setup input paths # Setup input paths
args.obs_dir = os.path.join(ROOTDIR, 'R2R', 'observations_list_summarized') args.obs_dir = os.path.join(ROOTDIR, 'REVERIE', 'observations_list_summarized')
args.obs_summary_dir = os.path.join(ROOTDIR, 'R2R', 'observations_summarized') args.obs_summary_dir = os.path.join(ROOTDIR, 'REVERIE', 'observations_summarized')
args.obj_dir = os.path.join(ROOTDIR, 'R2R', 'objects_list') args.obj_dir = os.path.join(ROOTDIR, 'REVERIE', 'objects_list')
args.connectivity_dir = os.path.join(ROOTDIR, 'R2R', 'connectivity') args.connectivity_dir = os.path.join(ROOTDIR, 'REVERIE', 'connectivity')
args.scan_data_dir = os.path.join(ROOTDIR, 'Matterport3D', 'v1_unzip_scans') args.scan_data_dir = os.path.join(ROOTDIR, 'Matterport3D', 'v1_unzip_scans')
args.anno_dir = os.path.join(ROOTDIR, 'R2R', 'annotations') args.anno_dir = os.path.join(ROOTDIR, 'REVERIE', 'annotations')
args.navigable_dir = os.path.join(ROOTDIR, 'R2R', 'navigable') args.navigable_dir = os.path.join(ROOTDIR, 'REVERIE', 'navigable')
# Build paths # Build paths
args.log_dir = os.path.join(args.output_dir, 'logs') args.log_dir = os.path.join(args.output_dir, 'logs')