feat: vlnbert which can run with adversarial json
This commit is contained in:
parent
832c6368dd
commit
ab5010d32d
@ -294,9 +294,6 @@ class Seq2SeqAgent(BaseAgent):
|
||||
|
||||
input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs)
|
||||
|
||||
print("input_a_t: ", input_a_t.shape)
|
||||
print("candidate_feat: ", candidate_feat.shape)
|
||||
print("candidate_leng: ", candidate_leng)
|
||||
|
||||
# the first [CLS] token, initialized by the language BERT, serves
|
||||
# as the agent's state passing through time steps
|
||||
@ -407,7 +404,6 @@ class Seq2SeqAgent(BaseAgent):
|
||||
if ended.all():
|
||||
break
|
||||
|
||||
print()
|
||||
|
||||
if train_rl:
|
||||
# Last action in A2C
|
||||
|
||||
@ -195,11 +195,11 @@ def train_val(test_only=False):
|
||||
|
||||
if test_only:
|
||||
featurized_scans = None
|
||||
val_env_names = ['val_train_seen']
|
||||
val_env_names = ['val_unseen']
|
||||
else:
|
||||
featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])
|
||||
# val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']
|
||||
val_env_names = ['val_train_seen']
|
||||
val_env_names = ['val_unseen']
|
||||
|
||||
train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok)
|
||||
from collections import OrderedDict
|
||||
|
||||
Loading…
Reference in New Issue
Block a user