feat: show target stop in state
This commit is contained in:
parent
7454bc15af
commit
9165a802d3
@ -2,25 +2,25 @@ import os
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from data_utils import construct_instrs
|
from data_utils import construct_reverie_instrs
|
||||||
from utils.logger import write_to_record_file
|
from utils.logger import write_to_record_file
|
||||||
|
|
||||||
from utils.data import ImageObservationsDB
|
from utils.data import ImageObservationsDB
|
||||||
from parser import parse_args
|
from parser import parse_args
|
||||||
from env import R2RNavBatch
|
from env import REVERIENavBatch
|
||||||
from agent import NavAgent
|
from agent import NavGPTAgent
|
||||||
|
|
||||||
def build_dataset(args):
|
def build_dataset(args):
|
||||||
|
|
||||||
feat_db = ImageObservationsDB(args.obs_dir, args.obs_summary_dir, args.obj_dir)
|
feat_db = ImageObservationsDB(args.obs_dir, args.obs_summary_dir, args.obj_dir)
|
||||||
|
|
||||||
dataset_class = R2RNavBatch
|
dataset_class = REVERIENavBatch
|
||||||
|
|
||||||
val_env_names = [args.val_env_name]
|
val_env_names = [args.val_env_name]
|
||||||
|
|
||||||
val_envs = {}
|
val_envs = {}
|
||||||
for split in val_env_names:
|
for split in val_env_names:
|
||||||
val_instr_data = construct_instrs(
|
val_instr_data = construct_reverie_instrs(
|
||||||
args.anno_dir, args.dataset, [split]
|
args.anno_dir, args.dataset, [split]
|
||||||
)
|
)
|
||||||
val_env = dataset_class(
|
val_env = dataset_class(
|
||||||
@ -34,7 +34,7 @@ def build_dataset(args):
|
|||||||
|
|
||||||
def valid(args, val_envs):
|
def valid(args, val_envs):
|
||||||
|
|
||||||
agent = NavAgent(next(iter(val_envs.values())), args)
|
agent = NavGPTAgent(next(iter(val_envs.values())), args)
|
||||||
|
|
||||||
with open(os.path.join(args.log_dir, 'validation_args.json'), 'w') as outf:
|
with open(os.path.join(args.log_dir, 'validation_args.json'), 'w') as outf:
|
||||||
json.dump(vars(args), outf, indent=4)
|
json.dump(vars(args), outf, indent=4)
|
||||||
@ -78,7 +78,7 @@ def valid(args, val_envs):
|
|||||||
|
|
||||||
def valid_from_file(args, val_envs):
|
def valid_from_file(args, val_envs):
|
||||||
|
|
||||||
agent = NavAgent(next(iter(val_envs.values())), args)
|
agent = NavGPTAgent(next(iter(val_envs.values())), args)
|
||||||
with open(args.valid_file, 'r') as f:
|
with open(args.valid_file, 'r') as f:
|
||||||
preds = json.load(f)
|
preds = json.load(f)
|
||||||
|
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import warnings
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Any, Callable, List, NamedTuple, Optional, Sequence, Tuple, Dict, Union
|
from typing import Any, Callable, List, NamedTuple, Optional, Sequence, Tuple, Dict, Union
|
||||||
|
|
||||||
from env import R2RNavBatch
|
from env import REVERIENavBatch
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
from agent_base import BaseAgent
|
from agent_base import BaseAgent
|
||||||
|
|
||||||
@ -142,10 +142,10 @@ class VLNAgent(ZeroShotAgent):
|
|||||||
full_inputs = {**kwargs, **new_inputs}
|
full_inputs = {**kwargs, **new_inputs}
|
||||||
return full_inputs
|
return full_inputs
|
||||||
|
|
||||||
class NavAgent(BaseAgent):
|
class NavGPTAgent(BaseAgent):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
env: R2RNavBatch,
|
env: REVERIENavBatch,
|
||||||
config: Namespace):
|
config: Namespace):
|
||||||
"""
|
"""
|
||||||
Initialize the LLM Navigation Agent.
|
Initialize the LLM Navigation Agent.
|
||||||
@ -725,4 +725,4 @@ class NavAgent(BaseAgent):
|
|||||||
self.traj[i]['llm_thought'].append(thought)
|
self.traj[i]['llm_thought'].append(thought)
|
||||||
self.traj[i]['llm_observation'].append(observation)
|
self.traj[i]['llm_observation'].append(observation)
|
||||||
|
|
||||||
return self.traj
|
return self.traj
|
||||||
|
|||||||
@ -13,6 +13,7 @@ def load_instr_datasets(anno_dir, dataset, splits):
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
'''
|
||||||
def construct_instrs(anno_dir, dataset, splits):
|
def construct_instrs(anno_dir, dataset, splits):
|
||||||
data = []
|
data = []
|
||||||
if "instr" in splits[0]:
|
if "instr" in splits[0]:
|
||||||
@ -27,4 +28,20 @@ def construct_instrs(anno_dir, dataset, splits):
|
|||||||
del new_item['instructions']
|
del new_item['instructions']
|
||||||
del new_item['instr_encodings']
|
del new_item['instr_encodings']
|
||||||
data.append(new_item)
|
data.append(new_item)
|
||||||
return data
|
return data
|
||||||
|
'''
|
||||||
|
def construct_reverie_instrs(anno_dir, dataset, splits):
|
||||||
|
data = []
|
||||||
|
if "instr" in splits[0]:
|
||||||
|
return load_instr_datasets(anno_dir, dataset, splits)
|
||||||
|
|
||||||
|
for i, item in enumerate(load_instr_datasets(anno_dir, dataset, splits)):
|
||||||
|
# Split multiple instructions into separate entries
|
||||||
|
for j, instr in enumerate(item['instructions']):
|
||||||
|
new_item = dict(item)
|
||||||
|
new_item['instr_id'] = '%s_%d' % (item['path_id'], j)
|
||||||
|
new_item['instruction'] = instr
|
||||||
|
del new_item['instructions']
|
||||||
|
del new_item['instr_encodings']
|
||||||
|
data.append(new_item)
|
||||||
|
return data
|
||||||
|
|||||||
@ -33,11 +33,13 @@ class Simulator(object):
|
|||||||
scan_ID: str,
|
scan_ID: str,
|
||||||
viewpoint_ID: str,
|
viewpoint_ID: str,
|
||||||
heading: int,
|
heading: int,
|
||||||
elevation: int,):
|
elevation: int,
|
||||||
|
stop: str):
|
||||||
self.heading = heading
|
self.heading = heading
|
||||||
self.elevation = elevation
|
self.elevation = elevation
|
||||||
self.scan_ID = scan_ID
|
self.scan_ID = scan_ID
|
||||||
self.viewpoint_ID = viewpoint_ID
|
self.viewpoint_ID = viewpoint_ID
|
||||||
|
self.stop = stop
|
||||||
# Load navigable dict
|
# Load navigable dict
|
||||||
navigable_path = os.path.join(self.navigable_dir, self.scan_ID + '_navigable.json')
|
navigable_path = os.path.join(self.navigable_dir, self.scan_ID + '_navigable.json')
|
||||||
with open(navigable_path, 'r') as f:
|
with open(navigable_path, 'r') as f:
|
||||||
@ -57,6 +59,7 @@ class Simulator(object):
|
|||||||
'heading': self.heading,
|
'heading': self.heading,
|
||||||
'elevation': self.elevation,
|
'elevation': self.elevation,
|
||||||
'candidate': self.candidate,
|
'candidate': self.candidate,
|
||||||
|
'stop': self.stop
|
||||||
}
|
}
|
||||||
return self.state
|
return self.state
|
||||||
|
|
||||||
@ -101,9 +104,9 @@ class EnvBatch(object):
|
|||||||
def _make_id(self, scanId, viewpointId):
|
def _make_id(self, scanId, viewpointId):
|
||||||
return scanId + '_' + viewpointId
|
return scanId + '_' + viewpointId
|
||||||
|
|
||||||
def newEpisodes(self, scanIds, viewpointIds, headings):
|
def newEpisodes(self, scanIds, viewpointIds, headings, stops):
|
||||||
for i, (scanId, viewpointId, heading) in enumerate(zip(scanIds, viewpointIds, headings)):
|
for i, (scanId, viewpointId, heading, stop) in enumerate(zip(scanIds, viewpointIds, headings, stops)):
|
||||||
self.sims[i].newEpisode(scanId, viewpointId, heading, 0)
|
self.sims[i].newEpisode(scanId, viewpointId, heading, 0, stop)
|
||||||
|
|
||||||
def getStates(self):
|
def getStates(self):
|
||||||
"""
|
"""
|
||||||
@ -118,6 +121,7 @@ class EnvBatch(object):
|
|||||||
|
|
||||||
feature = self.feat_db.get_image_observation(state["scanID"], state["viewpointID"])
|
feature = self.feat_db.get_image_observation(state["scanID"], state["viewpointID"])
|
||||||
feature_states.append((feature, state))
|
feature_states.append((feature, state))
|
||||||
|
print(feature_states[-1])
|
||||||
return feature_states
|
return feature_states
|
||||||
|
|
||||||
def makeActions(self, next_viewpoint_IDs):
|
def makeActions(self, next_viewpoint_IDs):
|
||||||
@ -126,7 +130,7 @@ class EnvBatch(object):
|
|||||||
self.sims[i].makeAction(next_viewpoint_ID)
|
self.sims[i].makeAction(next_viewpoint_ID)
|
||||||
|
|
||||||
|
|
||||||
class R2RNavBatch(object):
|
class REVERIENavBatch(object):
|
||||||
''' Implements the REVERIE navigation task, using discretized viewpoints and pretrained features '''
|
''' Implements the REVERIE navigation task, using discretized viewpoints and pretrained features '''
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -140,7 +144,7 @@ class R2RNavBatch(object):
|
|||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
self.gt_trajs = self._get_gt_trajs(self.data) # for evaluation
|
#self.gt_trajs = self._get_gt_trajs(self.data) # for evaluation
|
||||||
|
|
||||||
# use different seeds in different processes to shuffle data
|
# use different seeds in different processes to shuffle data
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
@ -154,12 +158,14 @@ class R2RNavBatch(object):
|
|||||||
print('%s loaded with %d instructions, using splits: %s' % (
|
print('%s loaded with %d instructions, using splits: %s' % (
|
||||||
self.__class__.__name__, len(self.data), self.name))
|
self.__class__.__name__, len(self.data), self.name))
|
||||||
|
|
||||||
|
'''
|
||||||
def _get_gt_trajs(self, data):
|
def _get_gt_trajs(self, data):
|
||||||
gt_trajs = {
|
gt_trajs = {
|
||||||
x['instr_id']: (x['scan'], x['path']) \
|
x['instr_id']: (x['scan'], x['path']) \
|
||||||
for x in data if len(x['path']) > 1
|
for x in data if len(x['path']) > 1
|
||||||
}
|
}
|
||||||
return gt_trajs
|
return gt_trajs
|
||||||
|
'''
|
||||||
|
|
||||||
def size(self):
|
def size(self):
|
||||||
return len(self.data)
|
return len(self.data)
|
||||||
@ -197,6 +203,7 @@ class R2RNavBatch(object):
|
|||||||
else:
|
else:
|
||||||
self.ix += batch_size
|
self.ix += batch_size
|
||||||
self.batch = batch
|
self.batch = batch
|
||||||
|
print(self.batch)
|
||||||
|
|
||||||
def reset_epoch(self, shuffle=False):
|
def reset_epoch(self, shuffle=False):
|
||||||
''' Reset the data index to beginning of epoch. Primarily for testing.
|
''' Reset the data index to beginning of epoch. Primarily for testing.
|
||||||
@ -227,10 +234,13 @@ class R2RNavBatch(object):
|
|||||||
}
|
}
|
||||||
# RL reward. The negative distance between the state and the final state
|
# RL reward. The negative distance between the state and the final state
|
||||||
# There are multiple gt end viewpoints on REVERIE.
|
# There are multiple gt end viewpoints on REVERIE.
|
||||||
|
|
||||||
|
'''
|
||||||
if ob['instr_id'] in self.gt_trajs:
|
if ob['instr_id'] in self.gt_trajs:
|
||||||
ob['distance'] = self.shortest_distances[ob['scan']][ob['viewpoint']][item['path'][-1]]
|
ob['distance'] = self.shortest_distances[ob['scan']][ob['viewpoint']][item['path'][-1]]
|
||||||
else:
|
else:
|
||||||
ob['distance'] = 0
|
ob['distance'] = 0
|
||||||
|
'''
|
||||||
|
|
||||||
obs.append(ob)
|
obs.append(ob)
|
||||||
return obs
|
return obs
|
||||||
@ -242,7 +252,8 @@ class R2RNavBatch(object):
|
|||||||
scanIds = [item['scan'] for item in self.batch]
|
scanIds = [item['scan'] for item in self.batch]
|
||||||
viewpointIds = [item['path'][0] for item in self.batch]
|
viewpointIds = [item['path'][0] for item in self.batch]
|
||||||
headings = [item['heading'] for item in self.batch]
|
headings = [item['heading'] for item in self.batch]
|
||||||
self.env.newEpisodes(scanIds, viewpointIds, headings)
|
stops = [item['stop'] for item in self.batch]
|
||||||
|
self.env.newEpisodes(scanIds, viewpointIds, headings, stops)
|
||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
|
||||||
def step(self, next_viewpoint_IDs):
|
def step(self, next_viewpoint_IDs):
|
||||||
|
|||||||
@ -30,7 +30,7 @@ def parse_args():
|
|||||||
# parser.add_argument('--val_env_name', type=str, default='R2R_val_unseen_instr_2')
|
# parser.add_argument('--val_env_name', type=str, default='R2R_val_unseen_instr_2')
|
||||||
# parser.add_argument('--val_env_name', type=str, default='R2R_val_unseen_instr_3')
|
# parser.add_argument('--val_env_name', type=str, default='R2R_val_unseen_instr_3')
|
||||||
# parser.add_argument('--val_env_name', type=str, default='R2R_val_unseen_instr_4')
|
# parser.add_argument('--val_env_name', type=str, default='R2R_val_unseen_instr_4')
|
||||||
parser.add_argument('--val_env_name', type=str, default='R2R_val_unseen_instr')
|
parser.add_argument('--val_env_name', type=str, default='REVERIE_val_unseen_instr')
|
||||||
|
|
||||||
parser.add_argument('--load_instruction', action='store_true', default=True)
|
parser.add_argument('--load_instruction', action='store_true', default=True)
|
||||||
parser.add_argument('--load_action_plan', action='store_true', default=True)
|
parser.add_argument('--load_action_plan', action='store_true', default=True)
|
||||||
@ -57,15 +57,15 @@ def postprocess_args(args):
|
|||||||
ROOTDIR = args.root_dir
|
ROOTDIR = args.root_dir
|
||||||
|
|
||||||
# Setup input paths
|
# Setup input paths
|
||||||
args.obs_dir = os.path.join(ROOTDIR, 'R2R', 'observations_list_summarized')
|
args.obs_dir = os.path.join(ROOTDIR, 'REVERIE', 'observations_list_summarized')
|
||||||
args.obs_summary_dir = os.path.join(ROOTDIR, 'R2R', 'observations_summarized')
|
args.obs_summary_dir = os.path.join(ROOTDIR, 'REVERIE', 'observations_summarized')
|
||||||
args.obj_dir = os.path.join(ROOTDIR, 'R2R', 'objects_list')
|
args.obj_dir = os.path.join(ROOTDIR, 'REVERIE', 'objects_list')
|
||||||
|
|
||||||
args.connectivity_dir = os.path.join(ROOTDIR, 'R2R', 'connectivity')
|
args.connectivity_dir = os.path.join(ROOTDIR, 'REVERIE', 'connectivity')
|
||||||
args.scan_data_dir = os.path.join(ROOTDIR, 'Matterport3D', 'v1_unzip_scans')
|
args.scan_data_dir = os.path.join(ROOTDIR, 'Matterport3D', 'v1_unzip_scans')
|
||||||
|
|
||||||
args.anno_dir = os.path.join(ROOTDIR, 'R2R', 'annotations')
|
args.anno_dir = os.path.join(ROOTDIR, 'REVERIE', 'annotations')
|
||||||
args.navigable_dir = os.path.join(ROOTDIR, 'R2R', 'navigable')
|
args.navigable_dir = os.path.join(ROOTDIR, 'REVERIE', 'navigable')
|
||||||
|
|
||||||
# Build paths
|
# Build paths
|
||||||
args.log_dir = os.path.join(args.output_dir, 'logs')
|
args.log_dir = os.path.join(args.output_dir, 'logs')
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user