feat: add found in get_obs()
This commit is contained in:
parent
93e8b23316
commit
bc8bc1b9d4
@ -87,6 +87,7 @@ def construct_instrs(anno_dir, dataset, splits, tokenizer, max_instr_len=512):
|
|||||||
new_item['objId'] = None
|
new_item['objId'] = None
|
||||||
new_item['instruction'] = instr
|
new_item['instruction'] = instr
|
||||||
new_item['instr_encoding'] = item['instr_encodings'][j][:max_instr_len]
|
new_item['instr_encoding'] = item['instr_encodings'][j][:max_instr_len]
|
||||||
|
new_item['found'] = item['found'][j]
|
||||||
del new_item['instructions']
|
del new_item['instructions']
|
||||||
del new_item['instr_encodings']
|
del new_item['instr_encodings']
|
||||||
data.append(new_item)
|
data.append(new_item)
|
||||||
|
|||||||
@ -311,6 +311,7 @@ class ReverieObjectNavBatch(object):
|
|||||||
'navigableLocations' : state.navigableLocations,
|
'navigableLocations' : state.navigableLocations,
|
||||||
'instruction' : item['instruction'],
|
'instruction' : item['instruction'],
|
||||||
'instr_encoding': item['instr_encoding'],
|
'instr_encoding': item['instr_encoding'],
|
||||||
|
'gt_found' : item['found'],
|
||||||
'gt_path' : item['path'],
|
'gt_path' : item['path'],
|
||||||
'gt_end_vps': item.get('end_vps', []),
|
'gt_end_vps': item.get('end_vps', []),
|
||||||
'gt_obj_id': item['objId'],
|
'gt_obj_id': item['objId'],
|
||||||
|
|||||||
@ -65,8 +65,8 @@ def build_dataset(args, rank=0):
|
|||||||
multi_endpoints=args.multi_endpoints, multi_startpoints=args.multi_startpoints,
|
multi_endpoints=args.multi_endpoints, multi_startpoints=args.multi_startpoints,
|
||||||
)
|
)
|
||||||
|
|
||||||
# val_env_names = ['val_train_seen']
|
# val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']
|
||||||
val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']
|
val_env_names = [ 'val_seen', 'val_unseen']
|
||||||
|
|
||||||
if args.submit:
|
if args.submit:
|
||||||
val_env_names.append('test')
|
val_env_names.append('test')
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user