feat: show target stop in state
This commit is contained in:
parent
7454bc15af
commit
9165a802d3
@ -2,25 +2,25 @@ import os
|
||||
import json
|
||||
import time
|
||||
|
||||
from data_utils import construct_instrs
|
||||
from data_utils import construct_reverie_instrs
|
||||
from utils.logger import write_to_record_file
|
||||
|
||||
from utils.data import ImageObservationsDB
|
||||
from parser import parse_args
|
||||
from env import R2RNavBatch
|
||||
from agent import NavAgent
|
||||
from env import REVERIENavBatch
|
||||
from agent import NavGPTAgent
|
||||
|
||||
def build_dataset(args):
|
||||
|
||||
feat_db = ImageObservationsDB(args.obs_dir, args.obs_summary_dir, args.obj_dir)
|
||||
|
||||
dataset_class = R2RNavBatch
|
||||
dataset_class = REVERIENavBatch
|
||||
|
||||
val_env_names = [args.val_env_name]
|
||||
|
||||
val_envs = {}
|
||||
for split in val_env_names:
|
||||
val_instr_data = construct_instrs(
|
||||
val_instr_data = construct_reverie_instrs(
|
||||
args.anno_dir, args.dataset, [split]
|
||||
)
|
||||
val_env = dataset_class(
|
||||
@ -34,7 +34,7 @@ def build_dataset(args):
|
||||
|
||||
def valid(args, val_envs):
|
||||
|
||||
agent = NavAgent(next(iter(val_envs.values())), args)
|
||||
agent = NavGPTAgent(next(iter(val_envs.values())), args)
|
||||
|
||||
with open(os.path.join(args.log_dir, 'validation_args.json'), 'w') as outf:
|
||||
json.dump(vars(args), outf, indent=4)
|
||||
@ -78,7 +78,7 @@ def valid(args, val_envs):
|
||||
|
||||
def valid_from_file(args, val_envs):
|
||||
|
||||
agent = NavAgent(next(iter(val_envs.values())), args)
|
||||
agent = NavGPTAgent(next(iter(val_envs.values())), args)
|
||||
with open(args.valid_file, 'r') as f:
|
||||
preds = json.load(f)
|
||||
|
||||
|
||||
@ -6,7 +6,7 @@ import warnings
|
||||
import numpy as np
|
||||
from typing import Any, Callable, List, NamedTuple, Optional, Sequence, Tuple, Dict, Union
|
||||
|
||||
from env import R2RNavBatch
|
||||
from env import REVERIENavBatch
|
||||
from argparse import Namespace
|
||||
from agent_base import BaseAgent
|
||||
|
||||
@ -142,10 +142,10 @@ class VLNAgent(ZeroShotAgent):
|
||||
full_inputs = {**kwargs, **new_inputs}
|
||||
return full_inputs
|
||||
|
||||
class NavAgent(BaseAgent):
|
||||
class NavGPTAgent(BaseAgent):
|
||||
def __init__(
|
||||
self,
|
||||
env: R2RNavBatch,
|
||||
env: REVERIENavBatch,
|
||||
config: Namespace):
|
||||
"""
|
||||
Initialize the LLM Navigation Agent.
|
||||
|
||||
@ -13,6 +13,7 @@ def load_instr_datasets(anno_dir, dataset, splits):
|
||||
|
||||
return data
|
||||
|
||||
'''
|
||||
def construct_instrs(anno_dir, dataset, splits):
|
||||
data = []
|
||||
if "instr" in splits[0]:
|
||||
@ -28,3 +29,19 @@ def construct_instrs(anno_dir, dataset, splits):
|
||||
del new_item['instr_encodings']
|
||||
data.append(new_item)
|
||||
return data
|
||||
'''
|
||||
def construct_reverie_instrs(anno_dir, dataset, splits):
|
||||
data = []
|
||||
if "instr" in splits[0]:
|
||||
return load_instr_datasets(anno_dir, dataset, splits)
|
||||
|
||||
for i, item in enumerate(load_instr_datasets(anno_dir, dataset, splits)):
|
||||
# Split multiple instructions into separate entries
|
||||
for j, instr in enumerate(item['instructions']):
|
||||
new_item = dict(item)
|
||||
new_item['instr_id'] = '%s_%d' % (item['path_id'], j)
|
||||
new_item['instruction'] = instr
|
||||
del new_item['instructions']
|
||||
del new_item['instr_encodings']
|
||||
data.append(new_item)
|
||||
return data
|
||||
|
||||
@ -33,11 +33,13 @@ class Simulator(object):
|
||||
scan_ID: str,
|
||||
viewpoint_ID: str,
|
||||
heading: int,
|
||||
elevation: int,):
|
||||
elevation: int,
|
||||
stop: str):
|
||||
self.heading = heading
|
||||
self.elevation = elevation
|
||||
self.scan_ID = scan_ID
|
||||
self.viewpoint_ID = viewpoint_ID
|
||||
self.stop = stop
|
||||
# Load navigable dict
|
||||
navigable_path = os.path.join(self.navigable_dir, self.scan_ID + '_navigable.json')
|
||||
with open(navigable_path, 'r') as f:
|
||||
@ -57,6 +59,7 @@ class Simulator(object):
|
||||
'heading': self.heading,
|
||||
'elevation': self.elevation,
|
||||
'candidate': self.candidate,
|
||||
'stop': self.stop
|
||||
}
|
||||
return self.state
|
||||
|
||||
@ -101,9 +104,9 @@ class EnvBatch(object):
|
||||
def _make_id(self, scanId, viewpointId):
|
||||
return scanId + '_' + viewpointId
|
||||
|
||||
def newEpisodes(self, scanIds, viewpointIds, headings):
|
||||
for i, (scanId, viewpointId, heading) in enumerate(zip(scanIds, viewpointIds, headings)):
|
||||
self.sims[i].newEpisode(scanId, viewpointId, heading, 0)
|
||||
def newEpisodes(self, scanIds, viewpointIds, headings, stops):
|
||||
for i, (scanId, viewpointId, heading, stop) in enumerate(zip(scanIds, viewpointIds, headings, stops)):
|
||||
self.sims[i].newEpisode(scanId, viewpointId, heading, 0, stop)
|
||||
|
||||
def getStates(self):
|
||||
"""
|
||||
@ -118,6 +121,7 @@ class EnvBatch(object):
|
||||
|
||||
feature = self.feat_db.get_image_observation(state["scanID"], state["viewpointID"])
|
||||
feature_states.append((feature, state))
|
||||
print(feature_states[-1])
|
||||
return feature_states
|
||||
|
||||
def makeActions(self, next_viewpoint_IDs):
|
||||
@ -126,7 +130,7 @@ class EnvBatch(object):
|
||||
self.sims[i].makeAction(next_viewpoint_ID)
|
||||
|
||||
|
||||
class R2RNavBatch(object):
|
||||
class REVERIENavBatch(object):
|
||||
''' Implements the REVERIE navigation task, using discretized viewpoints and pretrained features '''
|
||||
|
||||
def __init__(
|
||||
@ -140,7 +144,7 @@ class R2RNavBatch(object):
|
||||
self.batch_size = batch_size
|
||||
self.name = name
|
||||
|
||||
self.gt_trajs = self._get_gt_trajs(self.data) # for evaluation
|
||||
#self.gt_trajs = self._get_gt_trajs(self.data) # for evaluation
|
||||
|
||||
# use different seeds in different processes to shuffle data
|
||||
self.seed = seed
|
||||
@ -154,12 +158,14 @@ class R2RNavBatch(object):
|
||||
print('%s loaded with %d instructions, using splits: %s' % (
|
||||
self.__class__.__name__, len(self.data), self.name))
|
||||
|
||||
'''
|
||||
def _get_gt_trajs(self, data):
|
||||
gt_trajs = {
|
||||
x['instr_id']: (x['scan'], x['path']) \
|
||||
for x in data if len(x['path']) > 1
|
||||
}
|
||||
return gt_trajs
|
||||
'''
|
||||
|
||||
def size(self):
|
||||
return len(self.data)
|
||||
@ -197,6 +203,7 @@ class R2RNavBatch(object):
|
||||
else:
|
||||
self.ix += batch_size
|
||||
self.batch = batch
|
||||
print(self.batch)
|
||||
|
||||
def reset_epoch(self, shuffle=False):
|
||||
''' Reset the data index to beginning of epoch. Primarily for testing.
|
||||
@ -227,10 +234,13 @@ class R2RNavBatch(object):
|
||||
}
|
||||
# RL reward. The negative distance between the state and the final state
|
||||
# There are multiple gt end viewpoints on REVERIE.
|
||||
|
||||
'''
|
||||
if ob['instr_id'] in self.gt_trajs:
|
||||
ob['distance'] = self.shortest_distances[ob['scan']][ob['viewpoint']][item['path'][-1]]
|
||||
else:
|
||||
ob['distance'] = 0
|
||||
'''
|
||||
|
||||
obs.append(ob)
|
||||
return obs
|
||||
@ -242,7 +252,8 @@ class R2RNavBatch(object):
|
||||
scanIds = [item['scan'] for item in self.batch]
|
||||
viewpointIds = [item['path'][0] for item in self.batch]
|
||||
headings = [item['heading'] for item in self.batch]
|
||||
self.env.newEpisodes(scanIds, viewpointIds, headings)
|
||||
stops = [item['stop'] for item in self.batch]
|
||||
self.env.newEpisodes(scanIds, viewpointIds, headings, stops)
|
||||
return self._get_obs()
|
||||
|
||||
def step(self, next_viewpoint_IDs):
|
||||
|
||||
@ -30,7 +30,7 @@ def parse_args():
|
||||
# parser.add_argument('--val_env_name', type=str, default='R2R_val_unseen_instr_2')
|
||||
# parser.add_argument('--val_env_name', type=str, default='R2R_val_unseen_instr_3')
|
||||
# parser.add_argument('--val_env_name', type=str, default='R2R_val_unseen_instr_4')
|
||||
parser.add_argument('--val_env_name', type=str, default='R2R_val_unseen_instr')
|
||||
parser.add_argument('--val_env_name', type=str, default='REVERIE_val_unseen_instr')
|
||||
|
||||
parser.add_argument('--load_instruction', action='store_true', default=True)
|
||||
parser.add_argument('--load_action_plan', action='store_true', default=True)
|
||||
@ -57,15 +57,15 @@ def postprocess_args(args):
|
||||
ROOTDIR = args.root_dir
|
||||
|
||||
# Setup input paths
|
||||
args.obs_dir = os.path.join(ROOTDIR, 'R2R', 'observations_list_summarized')
|
||||
args.obs_summary_dir = os.path.join(ROOTDIR, 'R2R', 'observations_summarized')
|
||||
args.obj_dir = os.path.join(ROOTDIR, 'R2R', 'objects_list')
|
||||
args.obs_dir = os.path.join(ROOTDIR, 'REVERIE', 'observations_list_summarized')
|
||||
args.obs_summary_dir = os.path.join(ROOTDIR, 'REVERIE', 'observations_summarized')
|
||||
args.obj_dir = os.path.join(ROOTDIR, 'REVERIE', 'objects_list')
|
||||
|
||||
args.connectivity_dir = os.path.join(ROOTDIR, 'R2R', 'connectivity')
|
||||
args.connectivity_dir = os.path.join(ROOTDIR, 'REVERIE', 'connectivity')
|
||||
args.scan_data_dir = os.path.join(ROOTDIR, 'Matterport3D', 'v1_unzip_scans')
|
||||
|
||||
args.anno_dir = os.path.join(ROOTDIR, 'R2R', 'annotations')
|
||||
args.navigable_dir = os.path.join(ROOTDIR, 'R2R', 'navigable')
|
||||
args.anno_dir = os.path.join(ROOTDIR, 'REVERIE', 'annotations')
|
||||
args.navigable_dir = os.path.join(ROOTDIR, 'REVERIE', 'navigable')
|
||||
|
||||
# Build paths
|
||||
args.log_dir = os.path.join(args.output_dir, 'logs')
|
||||
|
||||
Loading…
Reference in New Issue
Block a user