feat: vlnbert which can run with adversarial json

This commit is contained in:
Ting-Jun Wang 2023-11-06 15:51:50 +08:00
parent 832c6368dd
commit ab5010d32d
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354
2 changed files with 2 additions and 6 deletions

View File

@ -294,9 +294,6 @@ class Seq2SeqAgent(BaseAgent):
input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs)
print("input_a_t: ", input_a_t.shape)
print("candidate_feat: ", candidate_feat.shape)
print("candidate_leng: ", candidate_leng)
# the first [CLS] token, initialized by the language BERT, serves
# as the agent's state passing through time steps
@ -407,7 +404,6 @@ class Seq2SeqAgent(BaseAgent):
if ended.all():
break
print()
if train_rl:
# Last action in A2C

View File

@ -195,11 +195,11 @@ def train_val(test_only=False):
if test_only:
featurized_scans = None
val_env_names = ['val_train_seen']
val_env_names = ['val_unseen']
else:
featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])
# val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']
val_env_names = ['val_train_seen']
val_env_names = ['val_unseen']
train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok)
from collections import OrderedDict