feat: add found in get_obs()
This commit is contained in:
parent
93e8b23316
commit
bc8bc1b9d4
@ -87,6 +87,7 @@ def construct_instrs(anno_dir, dataset, splits, tokenizer, max_instr_len=512):
|
||||
new_item['objId'] = None
|
||||
new_item['instruction'] = instr
|
||||
new_item['instr_encoding'] = item['instr_encodings'][j][:max_instr_len]
|
||||
new_item['found'] = item['found'][j]
|
||||
del new_item['instructions']
|
||||
del new_item['instr_encodings']
|
||||
data.append(new_item)
|
||||
|
||||
@ -311,6 +311,7 @@ class ReverieObjectNavBatch(object):
|
||||
'navigableLocations' : state.navigableLocations,
|
||||
'instruction' : item['instruction'],
|
||||
'instr_encoding': item['instr_encoding'],
|
||||
'gt_found' : item['found'],
|
||||
'gt_path' : item['path'],
|
||||
'gt_end_vps': item.get('end_vps', []),
|
||||
'gt_obj_id': item['objId'],
|
||||
|
||||
@ -65,8 +65,8 @@ def build_dataset(args, rank=0):
|
||||
multi_endpoints=args.multi_endpoints, multi_startpoints=args.multi_startpoints,
|
||||
)
|
||||
|
||||
# val_env_names = ['val_train_seen']
|
||||
val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']
|
||||
# val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']
|
||||
val_env_names = [ 'val_seen', 'val_unseen']
|
||||
|
||||
if args.submit:
|
||||
val_env_names.append('test')
|
||||
|
||||
Loading…
Reference in New Issue
Block a user