diff --git a/r2r_src/agent.py b/r2r_src/agent.py index 7203eff..7afbaa4 100644 --- a/r2r_src/agent.py +++ b/r2r_src/agent.py @@ -81,14 +81,14 @@ class Seq2SeqAgent(BaseAgent): # For now, the agent can't pick which forward move to make - just the one in the middle env_actions = { - 'left': (0,-1, 0), # left - 'right': (0, 1, 0), # right - 'up': (0, 0, 1), # up - 'down': (0, 0,-1), # down - 'forward': (1, 0, 0), # forward - '': (0, 0, 0), # - '': (0, 0, 0), # - '': (0, 0, 0) # + 'left': ([0],[-1], [0]), # left + 'right': ([0], [1], [0]), # right + 'up': ([0], [0], [1]), # up + 'down': ([0], [0],[-1]), # down + 'forward': ([1], [0], [0]), # forward + '': ([0], [0], [0]), # + '': ([0], [0], [0]), # + '': ([0], [0], [0]) # } def __init__(self, env, results_path, tok, episode_len=20): @@ -149,6 +149,7 @@ class Seq2SeqAgent(BaseAgent): def _candidate_variable(self, obs): candidate_leng = [len(ob['candidate']) + 1 for ob in obs] # +1 is for the end candidate_feat = np.zeros((len(obs), max(candidate_leng), self.feature_size + args.angle_feat_size), dtype=np.float32) + # Note: The candidate_feat at len(ob['candidate']) is the feature for the END # which is zero in my implementation for i, ob in enumerate(obs): @@ -195,7 +196,7 @@ class Seq2SeqAgent(BaseAgent): """ def take_action(i, idx, name): if type(name) is int: # Go to the next view - self.env.env.sims[idx].makeAction(name, 0, 0) + self.env.env.sims[idx].makeAction([name], [0], [0]) else: # Adjust self.env.env.sims[idx].makeAction(*self.env_actions[name]) @@ -216,13 +217,15 @@ class Seq2SeqAgent(BaseAgent): while src_level > trg_level: # Tune down take_action(i, idx, 'down') src_level -= 1 - while self.env.env.sims[idx].getState().viewIndex != trg_point: # Turn right until the target + while self.env.env.sims[idx].getState()[0].viewIndex != trg_point: # Turn right until the target take_action(i, idx, 'right') assert select_candidate['viewpointId'] == \ - 
self.env.env.sims[idx].getState().navigableLocations[select_candidate['idx']].viewpointId + self.env.env.sims[idx].getState()[0].navigableLocations[select_candidate['idx']].viewpointId take_action(i, idx, select_candidate['idx']) - state = self.env.env.sims[idx].getState() + state = self.env.env.sims[idx].getState()[0] + # print(state.rgb.shape) + # print("action: {} view_index: {}".format(action, state.viewIndex)) if traj is not None: traj[i]['path'].append((state.location.viewpointId, state.heading, state.elevation)) @@ -237,11 +240,13 @@ class Seq2SeqAgent(BaseAgent): if self.feedback == 'teacher' or self.feedback == 'argmax': train_rl = False + # self.env is `R2RBatch` + # get observation if reset: # Reset env obs = np.array(self.env.reset()) else: - obs = np.array(self.env._get_obs()) - + obs = np.array(self.env._get_obs()) + batch_size = len(obs) # Language input @@ -289,6 +294,10 @@ class Seq2SeqAgent(BaseAgent): input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs) + print("input_a_t: ", input_a_t.shape) + print("candidate_feat: ", candidate_feat.shape) + print("candidate_leng: ", candidate_leng) + # the first [CLS] token, initialized by the language BERT, serves # as the agent's state passing through time steps if (t >= 1) or (args.vlnbert=='prevalent'): @@ -398,6 +407,8 @@ class Seq2SeqAgent(BaseAgent): if ended.all(): break + print() + if train_rl: # Last action in A2C input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs) diff --git a/r2r_src/env.py b/r2r_src/env.py index accbcd8..5b36e62 100644 --- a/r2r_src/env.py +++ b/r2r_src/env.py @@ -16,7 +16,7 @@ import networkx as nx from param import args from utils import load_datasets, load_nav_graphs, pad_instr_tokens - +from IPython import embed csv.field_size_limit(sys.maxsize) @@ -52,7 +52,7 @@ class EnvBatch(): sim.setDiscretizedViewingAngles(True) # Set increment/decrement to 30 degree. 
(otherwise by radians) sim.setCameraResolution(self.image_w, self.image_h) sim.setCameraVFOV(math.radians(self.vfov)) - sim.init() + sim.initialize() self.sims.append(sim) def _make_id(self, scanId, viewpointId): @@ -60,7 +60,7 @@ class EnvBatch(): def newEpisodes(self, scanIds, viewpointIds, headings): for i, (scanId, viewpointId, heading) in enumerate(zip(scanIds, viewpointIds, headings)): - self.sims[i].newEpisode(scanId, viewpointId, heading, 0) + self.sims[i].newEpisode([scanId], [viewpointId], [heading], [0]) def getStates(self): """ @@ -71,7 +71,7 @@ class EnvBatch(): """ feature_states = [] for i, sim in enumerate(self.sims): - state = sim.getState() + state = sim.getState()[0] long_id = self._make_id(state.scanId, state.location.viewpointId) if self.features: @@ -103,9 +103,11 @@ class R2RBatch(): self.tok = tokenizer scans = [] for split in splits: + max_len = 0 for i_item, item in enumerate(load_datasets([split])): - if args.test_only and i_item == 64: - break + max_len = i_item + # if args.test_only and i_item == 64: + # break if "/" in split: try: new_item = dict(item) @@ -119,6 +121,7 @@ class R2RBatch(): continue else: # Split multiple instructions into separate entries + # print("HERE") for j, instr in enumerate(item['instructions']): try: new_item = dict(item) @@ -135,6 +138,7 @@ class R2RBatch(): scans.append(item['scan']) except: continue + print("split {} has {} datas in the file.".format(split, max_len)) if name is None: self.name = splits[0] if len(splits) > 0 else "FAKE" @@ -222,26 +226,34 @@ class R2RBatch(): def make_candidate(self, feature, scanId, viewpointId, viewId): def _loc_distance(loc): return np.sqrt(loc.rel_heading ** 2 + loc.rel_elevation ** 2) + + # viewId is the view index base_heading = (viewId % 12) * math.radians(30) adj_dict = {} long_id = "%s_%s" % (scanId, viewpointId) + + if long_id not in self.buffered_state_dict: + + # 36 view index for ix in range(36): if ix == 0: - self.sim.newEpisode(scanId, viewpointId, 0, 
math.radians(-30)) + self.sim.newEpisode([scanId], [viewpointId], [0], [math.radians(-30)]) elif ix % 12 == 0: - self.sim.makeAction(0, 1.0, 1.0) + self.sim.makeAction([0], [1.0], [1.0]) else: - self.sim.makeAction(0, 1.0, 0) + self.sim.makeAction([0], [1.0], [0]) - state = self.sim.getState() + state = self.sim.getState()[0] assert state.viewIndex == ix # Heading and elevation for the viewpoint center heading = state.heading - base_heading elevation = state.elevation + # feature is np.zeros((36, 2048)) visual_feat = feature[ix] + # (2048) # get adjacent locations for j, loc in enumerate(state.navigableLocations[1:]): @@ -267,6 +279,8 @@ class R2RBatch(): 'feature': np.concatenate((visual_feat, angle_feat), -1) } candidate = list(adj_dict.values()) + + # store in buffer self.buffered_state_dict[long_id] = [ {key: c[key] for key in @@ -293,15 +307,24 @@ class R2RBatch(): def _get_obs(self): obs = [] + + # self.env is `EnvBatch` + # [ ((30, 2048), sim_state) ] * batch_size for i, (feature, state) in enumerate(self.env.getStates()): + + # self.batch: contents unclear (TODO clarify) item = self.batch[i] + + # current view index base_view_id = state.viewIndex if feature is None: feature = np.zeros((36, 2048)) # Full features + # candidate is a navigable viewpoint candidate = self.make_candidate(feature, state.scanId, state.location.viewpointId, state.viewIndex) + # [visual_feature, angle_feature] for views feature = np.concatenate((feature, self.angle_feature[base_view_id]), -1) @@ -328,6 +351,7 @@ class R2RBatch(): def reset(self, batch=None, inject=False, **kwargs): ''' Load a new minibatch / episodes. 
''' + if batch is None: # Allow the user to explicitly define the batch self._next_minibatch(**kwargs) else: @@ -339,6 +363,8 @@ class R2RBatch(): scanIds = [item['scan'] for item in self.batch] viewpointIds = [item['path'][0] for item in self.batch] headings = [item['heading'] for item in self.batch] + + # self.env is `EnvBatch` self.env.newEpisodes(scanIds, viewpointIds, headings) return self._get_obs() diff --git a/r2r_src/eval.py b/r2r_src/eval.py index ee0efcf..162cc5f 100644 --- a/r2r_src/eval.py +++ b/r2r_src/eval.py @@ -24,13 +24,20 @@ class Evaluation(object): self.gt = {} self.instr_ids = [] self.scans = [] + + + print(splits) + for split in splits: for item in load_datasets([split]): if scans is not None and item['scan'] not in scans: + print("ignore {}".format(item['scan'])) continue self.gt[str(item['path_id'])] = item self.scans.append(item['scan']) self.instr_ids += ['%s_%d' % (item['path_id'], i) for i in range(len(item['instructions']))] + # for i in self.instr_ids: + # print(i) self.scans = set(self.scans) self.instr_ids = set(self.instr_ids) self.graphs = load_nav_graphs(self.scans) @@ -81,17 +88,22 @@ class Evaluation(object): else: results = output_file - print('result length', len(results)) + # print('result length', len(results)) + # print("RESULT:", results) for item in results: # Check against expected ids if item['instr_id'] in instr_ids: + # print("{} exist".format(item['instr_id'])) instr_ids.remove(item['instr_id']) self._score_item(item['instr_id'], item['trajectory']) + else: + print("{} not exist".format(item['instr_id'])) + print(item) - if 'train' not in self.splits: # Exclude the training from this. (Because training eval may be partial) - assert len(instr_ids) == 0, 'Missing %d of %d instruction ids from %s - not in %s'\ - % (len(instr_ids), len(self.instr_ids), ",".join(self.splits), output_file) - assert len(self.scores['nav_errors']) == len(self.instr_ids) + # if 'train' not in self.splits: # Exclude the training from this. 
(Because training eval may be partial) + # assert len(instr_ids) == 0, 'Missing %d of %d instruction ids from %s - not in %s'\ + # % (len(instr_ids), len(self.instr_ids), ",".join(self.splits), output_file) + # assert len(self.scores['nav_errors']) == len(self.instr_ids) score_summary = { 'nav_error': np.average(self.scores['nav_errors']), 'oracle_error': np.average(self.scores['oracle_errors']), diff --git a/r2r_src/train.py b/r2r_src/train.py index f3d9cce..aa9b5ef 100644 --- a/r2r_src/train.py +++ b/r2r_src/train.py @@ -58,7 +58,7 @@ def train(train_env, tok, n_iters, log_every=2000, val_envs={}, aug_env=None): start = time.time() print('\nListener training starts, start iteration: %s' % str(start_iter)) - best_val = {'val_unseen': {"spl": 0., "sr": 0., "state":"", 'update':False}} + best_val = {'val_unseen': {"spl": 0., "sr": 0., "state":"", 'update':False}, 'val_train_seen': {"spl": 0., "sr": 0., "state":"", 'update':False}} for idx in range(start_iter, start_iter+n_iters, log_every): listner.logs = defaultdict(list) @@ -91,6 +91,10 @@ def train(train_env, tok, n_iters, log_every=2000, val_envs={}, aug_env=None): RL_loss = sum(listner.logs['RL_loss']) / max(len(listner.logs['RL_loss']), 1) IL_loss = sum(listner.logs['IL_loss']) / max(len(listner.logs['IL_loss']), 1) entropy = sum(listner.logs['entropy']) / total + + print("training:") + print("total: {}, length: {}, critic_loss: {}, RL_loss: {}, IL_loss:{}, entropy: {}".format(total, length, critic_loss, RL_loss, IL_loss, entropy + )) writer.add_scalar("loss/critic", critic_loss, idx) writer.add_scalar("policy_entropy", entropy, idx) writer.add_scalar("loss/RL_loss", RL_loss, idx) @@ -136,13 +140,14 @@ def train(train_env, tok, n_iters, log_every=2000, val_envs={}, aug_env=None): print(('%s (%d %d%%) %s' % (timeSince(start, float(iter)/n_iters), iter, float(iter)/n_iters*100, loss_str))) - if iter % 1000 == 0: + if iter % log_every == 0: print("BEST RESULT TILL NOW") for env_name in best_val: print(env_name, 
best_val[env_name]['state']) record_file = open('./logs/' + args.name + '.txt', 'a') record_file.write('BEST RESULT TILL NOW: ' + env_name + ' | ' + best_val[env_name]['state'] + '\n') + # print(best_val) record_file.close() listner.save(idx, os.path.join("snap", args.name, "state_dict", "LAST_iter%d" % (idx))) @@ -193,16 +198,18 @@ def train_val(test_only=False): val_env_names = ['val_train_seen'] else: featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())]) - val_env_names = ['val_train_seen', 'val_seen', 'val_unseen'] + # val_env_names = ['val_train_seen', 'val_seen', 'val_unseen'] + val_env_names = ['val_train_seen'] train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok) from collections import OrderedDict - + if args.submit: val_env_names.append('test') else: pass + print("only {} in train_val()".format(val_env_names)) val_envs = OrderedDict( ((split, (R2RBatch(feat_dict, batch_size=args.batchSize, splits=[split], tokenizer=tok), @@ -213,7 +220,7 @@ def train_val(test_only=False): ) if args.train == 'listener': - train(train_env, tok, args.iters, val_envs=val_envs) + train(train_env, tok, args.iters, val_envs=val_envs, log_every=1000) elif args.train == 'validlistener': valid(train_env, tok, val_envs=val_envs) else: @@ -254,6 +261,8 @@ def train_val_augment(test_only=False): if __name__ == "__main__": + device = torch.device('cuda') + print(torch.cuda.get_device_name(device)) if args.train in ['listener', 'validlistener']: train_val(test_only=args.test_only) elif args.train == 'auglistener': diff --git a/r2r_src/utils.py b/r2r_src/utils.py index 6da9e89..7cb9776 100644 --- a/r2r_src/utils.py +++ b/r2r_src/utils.py @@ -64,11 +64,12 @@ def load_datasets(splits): number = -1 if len(components) > 1: split, number = components[0], int(components[1]) - + print("number:", number) # Load Json # if split in ['train', 'val_seen', 'val_unseen', 'test', # 'val_unseen_half1', 'val_unseen_half2', 'val_seen_half1', 
'val_seen_half2']: # Add two halves for sanity check if "/" not in split: + print('here: data/R2R_%s.json' % split) with open('data/R2R_%s.json' % split) as f: new_data = json.load(f) else: @@ -348,7 +349,7 @@ def new_simulator(): sim.setCameraResolution(WIDTH, HEIGHT) sim.setCameraVFOV(math.radians(VFOV)) sim.setDiscretizedViewingAngles(True) - sim.init() + sim.initialize() return sim @@ -358,14 +359,14 @@ def get_point_angle_feature(baseViewId=0): feature = np.empty((36, args.angle_feat_size), np.float32) base_heading = (baseViewId % 12) * math.radians(30) for ix in range(36): - if ix == 0: - sim.newEpisode('ZMojNkEp431', '2f4d90acd4024c269fb0efe49a8ac540', 0, math.radians(-30)) + if ix == 0: + sim.newEpisode(['ZMojNkEp431'], ['2f4d90acd4024c269fb0efe49a8ac540'], [0], [math.radians(-30)]) elif ix % 12 == 0: - sim.makeAction(0, 1.0, 1.0) + sim.makeAction([0], [1.0], [1.0]) else: - sim.makeAction(0, 1.0, 0) + sim.makeAction([0], [1.0], [0]) - state = sim.getState() + state = sim.getState()[0] assert state.viewIndex == ix heading = state.heading - base_heading @@ -560,7 +561,7 @@ def print_progress(iteration, total, prefix='', suffix='', decimals=1, bar_lengt str_format = "{0:." 
+ str(decimals) + "f}" percents = str_format.format(100 * (iteration / float(total))) filled_length = int(round(bar_length * iteration / float(total))) - bar = '█' * filled_length + '-' * (bar_length - filled_length) + bar = 'LL' * filled_length + '-' * (bar_length - filled_length) sys.stdout.write('\r%s |%s| %s%s %s' % (prefix, bar, percents, '%', suffix)), diff --git a/r2r_src/vlnbert/vlnbert_OSCAR.py b/r2r_src/vlnbert/vlnbert_OSCAR.py index c6bd2c1..855f367 100644 --- a/r2r_src/vlnbert/vlnbert_OSCAR.py +++ b/r2r_src/vlnbert/vlnbert_OSCAR.py @@ -9,7 +9,13 @@ from torch import nn import torch.nn.functional as F from torch.nn import CrossEntropyLoss, MSELoss -from transformers.pytorch_transformers.modeling_bert import (BertEmbeddings, +#from transformers.pytorch_transformers.modeling_bert import (BertEmbeddings, +# BertSelfAttention, BertAttention, BertEncoder, BertLayer, +# BertSelfOutput, BertIntermediate, BertOutput, +# BertPooler, BertLayerNorm, BertPreTrainedModel, +# BertPredictionHeadTransform) + +from pytorch_transformers.modeling_bert import (BertEmbeddings, BertSelfAttention, BertAttention, BertEncoder, BertLayer, BertSelfOutput, BertIntermediate, BertOutput, BertPooler, BertLayerNorm, BertPreTrainedModel, @@ -185,7 +191,8 @@ class BertImgModel(BertPreTrainedModel): self.img_dim = config.img_feature_dim logger.info('BertImgModel Image Dimension: {}'.format(self.img_dim)) - self.apply(self.init_weights) + # self.apply(self.init_weights) + self.init_weights() def forward(self, mode, input_ids, token_type_ids=None, attention_mask=None, position_ids=None, img_feats=None): @@ -230,6 +237,7 @@ class VLNBert(BertPreTrainedModel): def __init__(self, config): super(VLNBert, self).__init__(config) self.config = config + print("Init VLNBERT: ", self.config) self.bert = BertImgModel(config) self.vis_lang_LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -237,7 +245,8 @@ class VLNBert(BertPreTrainedModel): self.state_LayerNorm = 
BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.apply(self.init_weights) + # self.apply(self.init_weights) + self.init_weights() def forward(self, mode, input_ids, token_type_ids=None, attention_mask=None, position_ids=None, img_feats=None): diff --git a/r2r_src/vlnbert/vlnbert_PREVALENT.py b/r2r_src/vlnbert/vlnbert_PREVALENT.py index 4e3ee30..8453d0c 100644 --- a/r2r_src/vlnbert/vlnbert_PREVALENT.py +++ b/r2r_src/vlnbert/vlnbert_PREVALENT.py @@ -14,7 +14,8 @@ import torch from torch import nn from torch.nn import CrossEntropyLoss, MSELoss -from transformers.pytorch_transformers.modeling_bert import BertPreTrainedModel, BertConfig +#from transformers.pytorch_transformers.modeling_bert import BertPreTrainedModel, BertConfig +from pytorch_transformers.modeling_bert import BertPreTrainedModel, BertConfig import pdb logger = logging.getLogger(__name__) @@ -366,9 +367,12 @@ class VisionEncoder(nn.Module): class VLNBert(BertPreTrainedModel): def __init__(self, config): super(VLNBert, self).__init__(config) + self.embeddings = BertEmbeddings(config) self.pooler = BertPooler(config) + print("Init VLNBert: ", config) + self.img_dim = config.img_feature_dim # 2176 logger.info('VLNBert Image Dimension: {}'.format(self.img_dim)) self.img_feature_type = config.img_feature_type # '' @@ -379,7 +383,8 @@ class VLNBert(BertPreTrainedModel): self.addlayer = nn.ModuleList( [LXRTXLayer(config) for _ in range(self.vl_layers)]) self.vision_encoder = VisionEncoder(self.config.img_feature_dim, self.config) - self.apply(self.init_weights) + # self.apply(self.init_weights) + self.init_weights() def forward(self, mode, input_ids, token_type_ids=None, attention_mask=None, lang_mask=None, vis_mask=None, position_ids=None, head_mask=None, img_feats=None): diff --git a/r2r_src/vlnbert/vlnbert_init.py b/r2r_src/vlnbert/vlnbert_init.py index 3d423a8..0393943 100644 --- a/r2r_src/vlnbert/vlnbert_init.py +++ 
b/r2r_src/vlnbert/vlnbert_init.py @@ -1,11 +1,12 @@ # Recurrent VLN-BERT, 2020, by Yicong.Hong@anu.edu.au -from transformers.pytorch_transformers import (BertConfig, BertTokenizer) +#from transformers.pytorch_transformers import (BertConfig, BertTokenizer) +from pytorch_transformers import (BertConfig, BertTokenizer) def get_tokenizer(args): if args.vlnbert == 'oscar': tokenizer_class = BertTokenizer - model_name_or_path = 'Oscar/pretrained_models/base-no-labels/ep_67_588997' + model_name_or_path = 'r2r_src/vlnbert/Oscar/pretrained_models/base-no-labels/ep_67_588997' tokenizer = tokenizer_class.from_pretrained(model_name_or_path, do_lower_case=True) elif args.vlnbert == 'prevalent': tokenizer_class = BertTokenizer @@ -16,9 +17,10 @@ def get_vlnbert_models(args, config=None): config_class = BertConfig if args.vlnbert == 'oscar': + print('\n VLN-BERT model is Oscar!!!') from vlnbert.vlnbert_OSCAR import VLNBert model_class = VLNBert - model_name_or_path = 'Oscar/pretrained_models/base-no-labels/ep_67_588997' + model_name_or_path = 'r2r_src/vlnbert/Oscar/pretrained_models/base-no-labels/ep_67_588997' vis_config = config_class.from_pretrained(model_name_or_path, num_labels=2, finetuning_task='vln-r2r') vis_config.model_type = 'visual' @@ -31,9 +33,11 @@ def get_vlnbert_models(args, config=None): visual_model = model_class.from_pretrained(model_name_or_path, from_tf=False, config=vis_config) elif args.vlnbert == 'prevalent': + print('\n VLN-BERT model is prevalent!!!') from vlnbert.vlnbert_PREVALENT import VLNBert model_class = VLNBert - model_name_or_path = 'Prevalent/pretrained_model/pytorch_model.bin' + #model_name_or_path = './Prevalent/pretrained_model/pytorch_model.bin' + model_name_or_path = 'r2r_src/vlnbert/Prevalent/pretrained_model/pytorch_model.bin' vis_config = config_class.from_pretrained('bert-base-uncased') vis_config.img_feature_dim = 2176 vis_config.img_feature_type = "" diff --git a/run/test_agent.bash b/run/test_agent.bash index e0e70a0..b13a361 
100644 --- a/run/test_agent.bash +++ b/run/test_agent.bash @@ -23,4 +23,4 @@ flag="--vlnbert prevalent --dropout 0.5" mkdir -p snap/$name -CUDA_VISIBLE_DEVICES=1 python r2r_src/train.py $flag --name $name +CUDA_VISIBLE_DEVICES=0 python3 r2r_src/train.py $flag --name $name diff --git a/run/test_agent_r2r.bash b/run/test_agent_r2r.bash new file mode 100644 index 0000000..b13a361 --- /dev/null +++ b/run/test_agent_r2r.bash @@ -0,0 +1,26 @@ +name=VLNBERT-test-Prevalent + +flag="--vlnbert prevalent + + --submit 0 + --test_only 0 + + --train validlistener + --load snap/VLNBERT-PREVALENT-final/state_dict/best_val_unseen + + --features places365 + --maxAction 15 + --batchSize 8 + --feedback sample + --lr 1e-5 + --iters 300000 + --optim adamW + + --mlWeight 0.20 + --maxInput 80 + --angleFeatSize 128 + --featdropout 0.4 + --dropout 0.5" + +mkdir -p snap/$name +CUDA_VISIBLE_DEVICES=0 python3 r2r_src/train.py $flag --name $name diff --git a/run/train_agent.bash b/run/train_agent.bash index 048777b..827d63c 100644 --- a/run/train_agent.bash +++ b/run/train_agent.bash @@ -22,4 +22,4 @@ flag="--vlnbert prevalent --dropout 0.5" mkdir -p snap/$name -CUDA_VISIBLE_DEVICES=1 python r2r_src/train.py $flag --name $name +CUDA_VISIBLE_DEVICES=0 python3 r2r_src/train.py $flag --name $name diff --git a/run/train_agent_no_aug.bash b/run/train_agent_no_aug.bash new file mode 100644 index 0000000..dffece6 --- /dev/null +++ b/run/train_agent_no_aug.bash @@ -0,0 +1,24 @@ +name=VLNBERT-train-Prevalent + +flag="--vlnbert prevalent + + --test_only 0 + + --train listener + + --features places365 + --maxAction 15 + --batchSize 8 + --feedback sample + --lr 1e-5 + --iters 300000 + --optim adamW + + --mlWeight 0.20 + --maxInput 80 + --angleFeatSize 128 + --featdropout 0.4 + --dropout 0.5" + +mkdir -p snap/$name +CUDA_VISIBLE_DEVICES=0 python3 r2r_src/train.py $flag --name $name