From ab5010d32daa1f15fae9072892f8c58ecbe7a00d Mon Sep 17 00:00:00 2001 From: Ting-Jun Wang Date: Mon, 6 Nov 2023 15:51:50 +0800 Subject: [PATCH] feat: vlnbert which can run with adversarial json --- r2r_src/agent.py | 4 ---- r2r_src/train.py | 4 ++-- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/r2r_src/agent.py b/r2r_src/agent.py index 7afbaa4..aae150a 100644 --- a/r2r_src/agent.py +++ b/r2r_src/agent.py @@ -294,9 +294,6 @@ class Seq2SeqAgent(BaseAgent): input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs) - print("input_a_t: ", input_a_t.shape) - print("candidate_feat: ", candidate_feat.shape) - print("candidate_leng: ", candidate_leng) # the first [CLS] token, initialized by the language BERT, serves # as the agent's state passing through time steps @@ -407,7 +404,6 @@ class Seq2SeqAgent(BaseAgent): if ended.all(): break - print() if train_rl: # Last action in A2C diff --git a/r2r_src/train.py b/r2r_src/train.py index aa9b5ef..162d6dd 100644 --- a/r2r_src/train.py +++ b/r2r_src/train.py @@ -195,11 +195,11 @@ def train_val(test_only=False): if test_only: featurized_scans = None - val_env_names = ['val_train_seen'] + val_env_names = ['val_unseen'] else: featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())]) # val_env_names = ['val_train_seen', 'val_seen', 'val_unseen'] - val_env_names = ['val_train_seen'] + val_env_names = ['val_unseen'] train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok) from collections import OrderedDict