feat: vlnbert which can run with adversarial json
This commit is contained in:
parent
832c6368dd
commit
ab5010d32d
@ -294,9 +294,6 @@ class Seq2SeqAgent(BaseAgent):
|
|||||||
|
|
||||||
input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs)
|
input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs)
|
||||||
|
|
||||||
print("input_a_t: ", input_a_t.shape)
|
|
||||||
print("candidate_feat: ", candidate_feat.shape)
|
|
||||||
print("candidate_leng: ", candidate_leng)
|
|
||||||
|
|
||||||
# the first [CLS] token, initialized by the language BERT, serves
|
# the first [CLS] token, initialized by the language BERT, serves
|
||||||
# as the agent's state passing through time steps
|
# as the agent's state passing through time steps
|
||||||
@ -407,7 +404,6 @@ class Seq2SeqAgent(BaseAgent):
|
|||||||
if ended.all():
|
if ended.all():
|
||||||
break
|
break
|
||||||
|
|
||||||
print()
|
|
||||||
|
|
||||||
if train_rl:
|
if train_rl:
|
||||||
# Last action in A2C
|
# Last action in A2C
|
||||||
|
|||||||
@ -195,11 +195,11 @@ def train_val(test_only=False):
|
|||||||
|
|
||||||
if test_only:
|
if test_only:
|
||||||
featurized_scans = None
|
featurized_scans = None
|
||||||
val_env_names = ['val_train_seen']
|
val_env_names = ['val_unseen']
|
||||||
else:
|
else:
|
||||||
featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])
|
featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])
|
||||||
# val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']
|
# val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']
|
||||||
val_env_names = ['val_train_seen']
|
val_env_names = ['val_unseen']
|
||||||
|
|
||||||
train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok)
|
train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok)
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user