fix: some KEY ERROR situation
- situation: in visualization tool, some False Negative case(receive adversarial instruction but return True(found)) will get the found_objid but None. So it will raise a Key error interrupt. - Solution: found that there is a condition function in agent_obj.py return NOT_FOUND when detecting the found_objId is -1 but it doesn't detect None so that it always return FOUND when the found_objId is None. It doens't follow the rules. I add a new condition to check whether the found_objId is -1 or None - Other: For run the testing scripts, the user can shoose to replace the exists output file.
This commit is contained in:
parent
b2dce6111e
commit
de3326ae85
@ -27,7 +27,13 @@ class BaseAgent(object):
|
||||
def get_results(self, detailed_output=False):
|
||||
output = []
|
||||
for k, v in self.results.items():
|
||||
output.append({'instr_id': k, 'trajectory': v['path'], 'pred_objid': v['pred_objid'], 'found': v['found'], 'gt_found': v['gt_found']})
|
||||
output.append({
|
||||
'instr_id': k,
|
||||
'trajectory': v['path'],
|
||||
'pred_objid': v['pred_objid'],
|
||||
'found': v['found'],
|
||||
'gt_found': v['gt_found']
|
||||
})
|
||||
if detailed_output:
|
||||
output[-1]['details'] = v['details']
|
||||
return output
|
||||
|
||||
@ -504,7 +504,7 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
||||
if stop_node is not None and obs[i]['viewpoint'] != stop_node:
|
||||
traj[i]['path'].append(gmaps[i].graph.path(obs[i]['viewpoint'], stop_node))
|
||||
traj[i]['pred_objid'] = stop_score['og']
|
||||
if stop_score['og'] == -1:
|
||||
if stop_score['og'] == -1 or stop_score['og'] == None:
|
||||
traj[i]['found'] = False
|
||||
else:
|
||||
traj[i]['found'] = True
|
||||
|
||||
@ -69,7 +69,9 @@ def build_dataset(args, rank=0):
|
||||
val_env_names = [ 'val_seen', 'val_unseen']
|
||||
|
||||
if args.submit:
|
||||
val_env_names.append('test')
|
||||
include_test = input('Include test dataset? (y/n)')
|
||||
if include_test == 'y' or include_test == 'Y':
|
||||
val_env_names.append('test')
|
||||
|
||||
val_envs = {}
|
||||
for split in val_env_names:
|
||||
@ -239,11 +241,14 @@ def valid(args, train_env, val_envs, rank=-1):
|
||||
write_to_record_file(str(args) + '\n\n', record_file)
|
||||
|
||||
for env_name, env in val_envs.items():
|
||||
print(env_name)
|
||||
prefix = 'submit' if args.detailed_output is False else 'detail'
|
||||
output_file = os.path.join(args.pred_dir, "%s_%s_%s.json" % (
|
||||
prefix, env_name, args.fusion))
|
||||
if os.path.exists(output_file):
|
||||
continue
|
||||
replace = input(f"{output_file} exists. Replace? (y/n): ")
|
||||
if replace != 'y' and replace != 'Y':
|
||||
continue
|
||||
agent.logs = defaultdict(list)
|
||||
agent.env = env
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user