From 9165a802d3bd0214d879d99defc3962999a9ea9f Mon Sep 17 00:00:00 2001 From: Ting-Jun Wang Date: Sun, 28 Apr 2024 21:11:50 +0800 Subject: [PATCH] feat: show target stop in state --- nav_src/NavGPT.py | 14 +++++++------- nav_src/agent.py | 8 ++++---- nav_src/data_utils.py | 19 ++++++++++++++++++- nav_src/env.py | 25 ++++++++++++++++++------- nav_src/parser.py | 14 +++++++------- 5 files changed, 54 insertions(+), 26 deletions(-) diff --git a/nav_src/NavGPT.py b/nav_src/NavGPT.py index 56eedd5..e836632 100644 --- a/nav_src/NavGPT.py +++ b/nav_src/NavGPT.py @@ -2,25 +2,25 @@ import os import json import time -from data_utils import construct_instrs +from data_utils import construct_reverie_instrs from utils.logger import write_to_record_file from utils.data import ImageObservationsDB from parser import parse_args -from env import R2RNavBatch -from agent import NavAgent +from env import REVERIENavBatch +from agent import NavGPTAgent def build_dataset(args): feat_db = ImageObservationsDB(args.obs_dir, args.obs_summary_dir, args.obj_dir) - dataset_class = R2RNavBatch + dataset_class = REVERIENavBatch val_env_names = [args.val_env_name] val_envs = {} for split in val_env_names: - val_instr_data = construct_instrs( + val_instr_data = construct_reverie_instrs( args.anno_dir, args.dataset, [split] ) val_env = dataset_class( @@ -34,7 +34,7 @@ def build_dataset(args): def valid(args, val_envs): - agent = NavAgent(next(iter(val_envs.values())), args) + agent = NavGPTAgent(next(iter(val_envs.values())), args) with open(os.path.join(args.log_dir, 'validation_args.json'), 'w') as outf: json.dump(vars(args), outf, indent=4) @@ -78,7 +78,7 @@ def valid(args, val_envs): def valid_from_file(args, val_envs): - agent = NavAgent(next(iter(val_envs.values())), args) + agent = NavGPTAgent(next(iter(val_envs.values())), args) with open(args.valid_file, 'r') as f: preds = json.load(f) diff --git a/nav_src/agent.py b/nav_src/agent.py index cc78b91..80fd8ea 100644 --- a/nav_src/agent.py +++ b/nav_src/agent.py @@ -6,7 +6,7 @@ import warnings import numpy as np from typing import Any, Callable, List, NamedTuple, Optional, Sequence, Tuple, Dict, Union -from env import R2RNavBatch +from env import REVERIENavBatch from argparse import Namespace from agent_base import BaseAgent @@ -142,10 +142,10 @@ class VLNAgent(ZeroShotAgent): full_inputs = {**kwargs, **new_inputs} return full_inputs -class NavAgent(BaseAgent): +class NavGPTAgent(BaseAgent): def __init__( self, - env: R2RNavBatch, + env: REVERIENavBatch, config: Namespace): """ Initialize the LLM Navigation Agent. @@ -725,4 +725,4 @@ class NavAgent(BaseAgent): self.traj[i]['llm_thought'].append(thought) self.traj[i]['llm_observation'].append(observation) - return self.traj \ No newline at end of file + return self.traj diff --git a/nav_src/data_utils.py b/nav_src/data_utils.py index a7f3f9f..2ebaf38 100644 --- a/nav_src/data_utils.py +++ b/nav_src/data_utils.py @@ -13,6 +13,7 @@ def load_instr_datasets(anno_dir, dataset, splits): return data +''' def construct_instrs(anno_dir, dataset, splits): data = [] if "instr" in splits[0]: @@ -27,4 +28,20 @@ def construct_instrs(anno_dir, dataset, splits): del new_item['instructions'] del new_item['instr_encodings'] data.append(new_item) - return data \ No newline at end of file + return data +''' +def construct_reverie_instrs(anno_dir, dataset, splits): + data = [] + if "instr" in splits[0]: + return load_instr_datasets(anno_dir, dataset, splits) + + for i, item in enumerate(load_instr_datasets(anno_dir, dataset, splits)): + # Split multiple instructions into separate entries + for j, instr in enumerate(item['instructions']): + new_item = dict(item) + new_item['instr_id'] = '%s_%d' % (item['path_id'], j) + new_item['instruction'] = instr + del new_item['instructions'] + del new_item['instr_encodings'] + data.append(new_item) + return data diff --git a/nav_src/env.py b/nav_src/env.py index 6da0356..5fa1637 100644 --- a/nav_src/env.py +++ b/nav_src/env.py @@ -33,11 +33,13 @@ class Simulator(object): scan_ID: str, viewpoint_ID: str, heading: int, - elevation: int,): + elevation: int, + stop: str): self.heading = heading self.elevation = elevation self.scan_ID = scan_ID self.viewpoint_ID = viewpoint_ID + self.stop = stop # Load navigable dict navigable_path = os.path.join(self.navigable_dir, self.scan_ID + '_navigable.json') with open(navigable_path, 'r') as f: @@ -57,6 +59,7 @@ class Simulator(object): 'heading': self.heading, 'elevation': self.elevation, 'candidate': self.candidate, + 'stop': self.stop } return self.state @@ -101,9 +104,9 @@ class EnvBatch(object): def _make_id(self, scanId, viewpointId): return scanId + '_' + viewpointId - def newEpisodes(self, scanIds, viewpointIds, headings): - for i, (scanId, viewpointId, heading) in enumerate(zip(scanIds, viewpointIds, headings)): - self.sims[i].newEpisode(scanId, viewpointId, heading, 0) + def newEpisodes(self, scanIds, viewpointIds, headings, stops): + for i, (scanId, viewpointId, heading, stop) in enumerate(zip(scanIds, viewpointIds, headings, stops)): + self.sims[i].newEpisode(scanId, viewpointId, heading, 0, stop) def getStates(self): """ @@ -118,6 +121,7 @@ class EnvBatch(object): feature = self.feat_db.get_image_observation(state["scanID"], state["viewpointID"]) feature_states.append((feature, state)) + print(feature_states[-1]) return feature_states def makeActions(self, next_viewpoint_IDs): @@ -126,7 +130,7 @@ class EnvBatch(object): self.sims[i].makeAction(next_viewpoint_ID) -class R2RNavBatch(object): +class REVERIENavBatch(object): ''' Implements the REVERIE navigation task, using discretized viewpoints and pretrained features ''' def __init__( @@ -140,7 +144,7 @@ class R2RNavBatch(object): self.batch_size = batch_size self.name = name - self.gt_trajs = self._get_gt_trajs(self.data) # for evaluation + #self.gt_trajs = self._get_gt_trajs(self.data) # for evaluation # use different seeds in different processes to shuffle data self.seed = seed @@ -154,12 +158,14 @@ class R2RNavBatch(object): print('%s loaded with %d instructions, using splits: %s' % ( self.__class__.__name__, len(self.data), self.name)) + ''' def _get_gt_trajs(self, data): gt_trajs = { x['instr_id']: (x['scan'], x['path']) \ for x in data if len(x['path']) > 1 } return gt_trajs + ''' def size(self): return len(self.data) @@ -197,6 +203,7 @@ class R2RNavBatch(object): else: self.ix += batch_size self.batch = batch + print(self.batch) def reset_epoch(self, shuffle=False): ''' Reset the data index to beginning of epoch. Primarily for testing. @@ -227,10 +234,13 @@ class R2RNavBatch(object): } # RL reward. The negative distance between the state and the final state # There are multiple gt end viewpoints on REVERIE. + + ''' if ob['instr_id'] in self.gt_trajs: ob['distance'] = self.shortest_distances[ob['scan']][ob['viewpoint']][item['path'][-1]] else: ob['distance'] = 0 + ''' obs.append(ob) return obs @@ -242,7 +252,8 @@ class R2RNavBatch(object): scanIds = [item['scan'] for item in self.batch] viewpointIds = [item['path'][0] for item in self.batch] headings = [item['heading'] for item in self.batch] - self.env.newEpisodes(scanIds, viewpointIds, headings) + stops = [item['stop'] for item in self.batch] + self.env.newEpisodes(scanIds, viewpointIds, headings, stops) return self._get_obs() def step(self, next_viewpoint_IDs): diff --git a/nav_src/parser.py b/nav_src/parser.py index 13b8fff..f4428c8 100644 --- a/nav_src/parser.py +++ b/nav_src/parser.py @@ -30,7 +30,7 @@ def parse_args(): # parser.add_argument('--val_env_name', type=str, default='R2R_val_unseen_instr_2') # parser.add_argument('--val_env_name', type=str, default='R2R_val_unseen_instr_3') # parser.add_argument('--val_env_name', type=str, default='R2R_val_unseen_instr_4') - parser.add_argument('--val_env_name', type=str, default='R2R_val_unseen_instr') + parser.add_argument('--val_env_name', type=str, default='REVERIE_val_unseen_instr') parser.add_argument('--load_instruction', action='store_true', default=True) parser.add_argument('--load_action_plan', action='store_true', default=True) @@ -57,15 +57,15 @@ def postprocess_args(args): ROOTDIR = args.root_dir # Setup input paths - args.obs_dir = os.path.join(ROOTDIR, 'R2R', 'observations_list_summarized') - args.obs_summary_dir = os.path.join(ROOTDIR, 'R2R', 'observations_summarized') - args.obj_dir = os.path.join(ROOTDIR, 'R2R', 'objects_list') + args.obs_dir = os.path.join(ROOTDIR, 'REVERIE', 'observations_list_summarized') + args.obs_summary_dir = os.path.join(ROOTDIR, 'REVERIE', 'observations_summarized') + args.obj_dir = os.path.join(ROOTDIR, 'REVERIE', 'objects_list') - args.connectivity_dir = os.path.join(ROOTDIR, 'R2R', 'connectivity') + args.connectivity_dir = os.path.join(ROOTDIR, 'REVERIE', 'connectivity') args.scan_data_dir = os.path.join(ROOTDIR, 'Matterport3D', 'v1_unzip_scans') - args.anno_dir = os.path.join(ROOTDIR, 'R2R', 'annotations') - args.navigable_dir = os.path.join(ROOTDIR, 'R2R', 'navigable') + args.anno_dir = os.path.join(ROOTDIR, 'REVERIE', 'annotations') + args.navigable_dir = os.path.join(ROOTDIR, 'REVERIE', 'navigable') # Build paths args.log_dir = os.path.join(args.output_dir, 'logs')