fix: some KEY ERROR situation

- situation: in visualization tool, some False Negative case(receive
  adversarial instruction but return True(found)) will get the
  found_objid but None. So it will raise a Key error interrupt.

- Solution: found that there is a condition function in agent_obj.py
  return NOT_FOUND when detecting the found_objId is -1 but it doesn't
  detect None so that it always return FOUND when the found_objId is
  None. It doens't follow the rules. I add a new condition to check
  whether the found_objId is -1 or None

- Other: For run the testing scripts, the user can shoose to replace the
  exists output file.
This commit is contained in:
Ting-Jun Wang 2024-01-19 15:25:26 +08:00
parent b2dce6111e
commit de3326ae85
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354
3 changed files with 15 additions and 4 deletions

View File

@ -27,7 +27,13 @@ class BaseAgent(object):
def get_results(self, detailed_output=False): def get_results(self, detailed_output=False):
output = [] output = []
for k, v in self.results.items(): for k, v in self.results.items():
output.append({'instr_id': k, 'trajectory': v['path'], 'pred_objid': v['pred_objid'], 'found': v['found'], 'gt_found': v['gt_found']}) output.append({
'instr_id': k,
'trajectory': v['path'],
'pred_objid': v['pred_objid'],
'found': v['found'],
'gt_found': v['gt_found']
})
if detailed_output: if detailed_output:
output[-1]['details'] = v['details'] output[-1]['details'] = v['details']
return output return output

View File

@ -504,7 +504,7 @@ class GMapObjectNavAgent(Seq2SeqAgent):
if stop_node is not None and obs[i]['viewpoint'] != stop_node: if stop_node is not None and obs[i]['viewpoint'] != stop_node:
traj[i]['path'].append(gmaps[i].graph.path(obs[i]['viewpoint'], stop_node)) traj[i]['path'].append(gmaps[i].graph.path(obs[i]['viewpoint'], stop_node))
traj[i]['pred_objid'] = stop_score['og'] traj[i]['pred_objid'] = stop_score['og']
if stop_score['og'] == -1: if stop_score['og'] == -1 or stop_score['og'] == None:
traj[i]['found'] = False traj[i]['found'] = False
else: else:
traj[i]['found'] = True traj[i]['found'] = True

View File

@ -69,7 +69,9 @@ def build_dataset(args, rank=0):
val_env_names = [ 'val_seen', 'val_unseen'] val_env_names = [ 'val_seen', 'val_unseen']
if args.submit: if args.submit:
val_env_names.append('test') include_test = input('Include test dataset? (y/n)')
if include_test == 'y' or include_test == 'Y':
val_env_names.append('test')
val_envs = {} val_envs = {}
for split in val_env_names: for split in val_env_names:
@ -239,11 +241,14 @@ def valid(args, train_env, val_envs, rank=-1):
write_to_record_file(str(args) + '\n\n', record_file) write_to_record_file(str(args) + '\n\n', record_file)
for env_name, env in val_envs.items(): for env_name, env in val_envs.items():
print(env_name)
prefix = 'submit' if args.detailed_output is False else 'detail' prefix = 'submit' if args.detailed_output is False else 'detail'
output_file = os.path.join(args.pred_dir, "%s_%s_%s.json" % ( output_file = os.path.join(args.pred_dir, "%s_%s_%s.json" % (
prefix, env_name, args.fusion)) prefix, env_name, args.fusion))
if os.path.exists(output_file): if os.path.exists(output_file):
continue replace = input(f"{output_file} exists. Replace? (y/n): ")
if replace != 'y' and replace != 'Y':
continue
agent.logs = defaultdict(list) agent.logs = defaultdict(list)
agent.env = env agent.env = env