feat: add found in get_obs()

This commit is contained in:
Ting-Jun Wang 2023-12-11 01:42:14 +08:00
parent 93e8b23316
commit bc8bc1b9d4
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354
3 changed files with 4 additions and 2 deletions

View File

@ -87,6 +87,7 @@ def construct_instrs(anno_dir, dataset, splits, tokenizer, max_instr_len=512):
new_item['objId'] = None new_item['objId'] = None
new_item['instruction'] = instr new_item['instruction'] = instr
new_item['instr_encoding'] = item['instr_encodings'][j][:max_instr_len] new_item['instr_encoding'] = item['instr_encodings'][j][:max_instr_len]
new_item['found'] = item['found'][j]
del new_item['instructions'] del new_item['instructions']
del new_item['instr_encodings'] del new_item['instr_encodings']
data.append(new_item) data.append(new_item)

View File

@ -311,6 +311,7 @@ class ReverieObjectNavBatch(object):
'navigableLocations' : state.navigableLocations, 'navigableLocations' : state.navigableLocations,
'instruction' : item['instruction'], 'instruction' : item['instruction'],
'instr_encoding': item['instr_encoding'], 'instr_encoding': item['instr_encoding'],
'gt_found' : item['found'],
'gt_path' : item['path'], 'gt_path' : item['path'],
'gt_end_vps': item.get('end_vps', []), 'gt_end_vps': item.get('end_vps', []),
'gt_obj_id': item['objId'], 'gt_obj_id': item['objId'],

View File

@ -65,8 +65,8 @@ def build_dataset(args, rank=0):
multi_endpoints=args.multi_endpoints, multi_startpoints=args.multi_startpoints, multi_endpoints=args.multi_endpoints, multi_startpoints=args.multi_startpoints,
) )
# val_env_names = ['val_train_seen'] # val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']
val_env_names = ['val_train_seen', 'val_seen', 'val_unseen'] val_env_names = [ 'val_seen', 'val_unseen']
if args.submit: if args.submit:
val_env_names.append('test') val_env_names.append('test')