diff --git a/map_nav_src/reverie/agent_base.py b/map_nav_src/reverie/agent_base.py index b18d5a6..e3ed149 100644 --- a/map_nav_src/reverie/agent_base.py +++ b/map_nav_src/reverie/agent_base.py @@ -27,7 +27,13 @@ class BaseAgent(object): def get_results(self, detailed_output=False): output = [] for k, v in self.results.items(): - output.append({'instr_id': k, 'trajectory': v['path'], 'pred_objid': v['pred_objid'], 'found': v['found'], 'gt_found': v['gt_found']}) + output.append({ + 'instr_id': k, + 'trajectory': v['path'], + 'pred_objid': v['pred_objid'], + 'found': v['found'], + 'gt_found': v['gt_found'] + }) if detailed_output: output[-1]['details'] = v['details'] return output diff --git a/map_nav_src/reverie/agent_obj.py b/map_nav_src/reverie/agent_obj.py index 0edcbe1..ce5f458 100644 --- a/map_nav_src/reverie/agent_obj.py +++ b/map_nav_src/reverie/agent_obj.py @@ -504,7 +504,7 @@ class GMapObjectNavAgent(Seq2SeqAgent): if stop_node is not None and obs[i]['viewpoint'] != stop_node: traj[i]['path'].append(gmaps[i].graph.path(obs[i]['viewpoint'], stop_node)) traj[i]['pred_objid'] = stop_score['og'] - if stop_score['og'] == -1: + if stop_score['og'] == -1 or stop_score['og'] == None: traj[i]['found'] = False else: traj[i]['found'] = True diff --git a/map_nav_src/reverie/main_nav_obj.py b/map_nav_src/reverie/main_nav_obj.py index 0acc1f1..641943e 100644 --- a/map_nav_src/reverie/main_nav_obj.py +++ b/map_nav_src/reverie/main_nav_obj.py @@ -69,7 +69,9 @@ def build_dataset(args, rank=0): val_env_names = [ 'val_seen', 'val_unseen'] if args.submit: - val_env_names.append('test') + include_test = input('Include test dataset? (y/n)') + if include_test == 'y' or include_test == 'Y': + val_env_names.append('test') val_envs = {} for split in val_env_names: @@ -239,11 +241,14 @@ def valid(args, train_env, val_envs, rank=-1): write_to_record_file(str(args) + '\n\n', record_file) for env_name, env in val_envs.items(): + print(env_name) prefix = 'submit' if args.detailed_output is False else 'detail' output_file = os.path.join(args.pred_dir, "%s_%s_%s.json" % ( prefix, env_name, args.fusion)) if os.path.exists(output_file): - continue + replace = input(f"{output_file} exists. Replace? (y/n): ") + if replace != 'y' and replace != 'Y': + continue agent.logs = defaultdict(list) agent.env = env