feat: show target stop in state

This commit is contained in:
Ting-Jun Wang 2024-04-28 21:11:50 +08:00
parent 7454bc15af
commit 9165a802d3
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354
5 changed files with 54 additions and 26 deletions

View File

@ -2,25 +2,25 @@ import os
import json
import time
from data_utils import construct_instrs
from data_utils import construct_reverie_instrs
from utils.logger import write_to_record_file
from utils.data import ImageObservationsDB
from parser import parse_args
from env import R2RNavBatch
from agent import NavAgent
from env import REVERIENavBatch
from agent import NavGPTAgent
def build_dataset(args):
feat_db = ImageObservationsDB(args.obs_dir, args.obs_summary_dir, args.obj_dir)
dataset_class = R2RNavBatch
dataset_class = REVERIENavBatch
val_env_names = [args.val_env_name]
val_envs = {}
for split in val_env_names:
val_instr_data = construct_instrs(
val_instr_data = construct_reverie_instrs(
args.anno_dir, args.dataset, [split]
)
val_env = dataset_class(
@ -34,7 +34,7 @@ def build_dataset(args):
def valid(args, val_envs):
agent = NavAgent(next(iter(val_envs.values())), args)
agent = NavGPTAgent(next(iter(val_envs.values())), args)
with open(os.path.join(args.log_dir, 'validation_args.json'), 'w') as outf:
json.dump(vars(args), outf, indent=4)
@ -78,7 +78,7 @@ def valid(args, val_envs):
def valid_from_file(args, val_envs):
agent = NavAgent(next(iter(val_envs.values())), args)
agent = NavGPTAgent(next(iter(val_envs.values())), args)
with open(args.valid_file, 'r') as f:
preds = json.load(f)

View File

@ -6,7 +6,7 @@ import warnings
import numpy as np
from typing import Any, Callable, List, NamedTuple, Optional, Sequence, Tuple, Dict, Union
from env import R2RNavBatch
from env import REVERIENavBatch
from argparse import Namespace
from agent_base import BaseAgent
@ -142,10 +142,10 @@ class VLNAgent(ZeroShotAgent):
full_inputs = {**kwargs, **new_inputs}
return full_inputs
class NavAgent(BaseAgent):
class NavGPTAgent(BaseAgent):
def __init__(
self,
env: R2RNavBatch,
env: REVERIENavBatch,
config: Namespace):
"""
Initialize the LLM Navigation Agent.
@ -725,4 +725,4 @@ class NavAgent(BaseAgent):
self.traj[i]['llm_thought'].append(thought)
self.traj[i]['llm_observation'].append(observation)
return self.traj
return self.traj

View File

@ -13,6 +13,7 @@ def load_instr_datasets(anno_dir, dataset, splits):
return data
'''
def construct_instrs(anno_dir, dataset, splits):
data = []
if "instr" in splits[0]:
@ -27,4 +28,20 @@ def construct_instrs(anno_dir, dataset, splits):
del new_item['instructions']
del new_item['instr_encodings']
data.append(new_item)
return data
return data
'''
def construct_reverie_instrs(anno_dir, dataset, splits):
data = []
if "instr" in splits[0]:
return load_instr_datasets(anno_dir, dataset, splits)
for i, item in enumerate(load_instr_datasets(anno_dir, dataset, splits)):
# Split multiple instructions into separate entries
for j, instr in enumerate(item['instructions']):
new_item = dict(item)
new_item['instr_id'] = '%s_%d' % (item['path_id'], j)
new_item['instruction'] = instr
del new_item['instructions']
del new_item['instr_encodings']
data.append(new_item)
return data

View File

@ -33,11 +33,13 @@ class Simulator(object):
scan_ID: str,
viewpoint_ID: str,
heading: int,
elevation: int,):
elevation: int,
stop: str):
self.heading = heading
self.elevation = elevation
self.scan_ID = scan_ID
self.viewpoint_ID = viewpoint_ID
self.stop = stop
# Load navigable dict
navigable_path = os.path.join(self.navigable_dir, self.scan_ID + '_navigable.json')
with open(navigable_path, 'r') as f:
@ -57,6 +59,7 @@ class Simulator(object):
'heading': self.heading,
'elevation': self.elevation,
'candidate': self.candidate,
'stop': self.stop
}
return self.state
@ -101,9 +104,9 @@ class EnvBatch(object):
def _make_id(self, scanId, viewpointId):
return scanId + '_' + viewpointId
def newEpisodes(self, scanIds, viewpointIds, headings):
for i, (scanId, viewpointId, heading) in enumerate(zip(scanIds, viewpointIds, headings)):
self.sims[i].newEpisode(scanId, viewpointId, heading, 0)
def newEpisodes(self, scanIds, viewpointIds, headings, stops):
for i, (scanId, viewpointId, heading, stop) in enumerate(zip(scanIds, viewpointIds, headings, stops)):
self.sims[i].newEpisode(scanId, viewpointId, heading, 0, stop)
def getStates(self):
"""
@ -118,6 +121,7 @@ class EnvBatch(object):
feature = self.feat_db.get_image_observation(state["scanID"], state["viewpointID"])
feature_states.append((feature, state))
print(feature_states[-1])
return feature_states
def makeActions(self, next_viewpoint_IDs):
@ -126,7 +130,7 @@ class EnvBatch(object):
self.sims[i].makeAction(next_viewpoint_ID)
class R2RNavBatch(object):
class REVERIENavBatch(object):
''' Implements the REVERIE navigation task, using discretized viewpoints and pretrained features '''
def __init__(
@ -140,7 +144,7 @@ class R2RNavBatch(object):
self.batch_size = batch_size
self.name = name
self.gt_trajs = self._get_gt_trajs(self.data) # for evaluation
#self.gt_trajs = self._get_gt_trajs(self.data) # for evaluation
# use different seeds in different processes to shuffle data
self.seed = seed
@ -154,12 +158,14 @@ class R2RNavBatch(object):
print('%s loaded with %d instructions, using splits: %s' % (
self.__class__.__name__, len(self.data), self.name))
'''
def _get_gt_trajs(self, data):
gt_trajs = {
x['instr_id']: (x['scan'], x['path']) \
for x in data if len(x['path']) > 1
}
return gt_trajs
'''
def size(self):
return len(self.data)
@ -197,6 +203,7 @@ class R2RNavBatch(object):
else:
self.ix += batch_size
self.batch = batch
print(self.batch)
def reset_epoch(self, shuffle=False):
''' Reset the data index to beginning of epoch. Primarily for testing.
@ -227,10 +234,13 @@ class R2RNavBatch(object):
}
# RL reward. The negative distance between the state and the final state
# There are multiple gt end viewpoints on REVERIE.
'''
if ob['instr_id'] in self.gt_trajs:
ob['distance'] = self.shortest_distances[ob['scan']][ob['viewpoint']][item['path'][-1]]
else:
ob['distance'] = 0
'''
obs.append(ob)
return obs
@ -242,7 +252,8 @@ class R2RNavBatch(object):
scanIds = [item['scan'] for item in self.batch]
viewpointIds = [item['path'][0] for item in self.batch]
headings = [item['heading'] for item in self.batch]
self.env.newEpisodes(scanIds, viewpointIds, headings)
stops = [item['stop'] for item in self.batch]
self.env.newEpisodes(scanIds, viewpointIds, headings, stops)
return self._get_obs()
def step(self, next_viewpoint_IDs):

View File

@ -30,7 +30,7 @@ def parse_args():
# parser.add_argument('--val_env_name', type=str, default='R2R_val_unseen_instr_2')
# parser.add_argument('--val_env_name', type=str, default='R2R_val_unseen_instr_3')
# parser.add_argument('--val_env_name', type=str, default='R2R_val_unseen_instr_4')
parser.add_argument('--val_env_name', type=str, default='R2R_val_unseen_instr')
parser.add_argument('--val_env_name', type=str, default='REVERIE_val_unseen_instr')
parser.add_argument('--load_instruction', action='store_true', default=True)
parser.add_argument('--load_action_plan', action='store_true', default=True)
@ -57,15 +57,15 @@ def postprocess_args(args):
ROOTDIR = args.root_dir
# Setup input paths
args.obs_dir = os.path.join(ROOTDIR, 'R2R', 'observations_list_summarized')
args.obs_summary_dir = os.path.join(ROOTDIR, 'R2R', 'observations_summarized')
args.obj_dir = os.path.join(ROOTDIR, 'R2R', 'objects_list')
args.obs_dir = os.path.join(ROOTDIR, 'REVERIE', 'observations_list_summarized')
args.obs_summary_dir = os.path.join(ROOTDIR, 'REVERIE', 'observations_summarized')
args.obj_dir = os.path.join(ROOTDIR, 'REVERIE', 'objects_list')
args.connectivity_dir = os.path.join(ROOTDIR, 'R2R', 'connectivity')
args.connectivity_dir = os.path.join(ROOTDIR, 'REVERIE', 'connectivity')
args.scan_data_dir = os.path.join(ROOTDIR, 'Matterport3D', 'v1_unzip_scans')
args.anno_dir = os.path.join(ROOTDIR, 'R2R', 'annotations')
args.navigable_dir = os.path.join(ROOTDIR, 'R2R', 'navigable')
args.anno_dir = os.path.join(ROOTDIR, 'REVERIE', 'annotations')
args.navigable_dir = os.path.join(ROOTDIR, 'REVERIE', 'navigable')
# Build paths
args.log_dir = os.path.join(args.output_dir, 'logs')