feat: runnable vlnbert
This commit is contained in:
parent
266406a161
commit
832c6368dd
@ -81,14 +81,14 @@ class Seq2SeqAgent(BaseAgent):
|
||||
|
||||
# For now, the agent can't pick which forward move to make - just the one in the middle
|
||||
env_actions = {
|
||||
'left': (0,-1, 0), # left
|
||||
'right': (0, 1, 0), # right
|
||||
'up': (0, 0, 1), # up
|
||||
'down': (0, 0,-1), # down
|
||||
'forward': (1, 0, 0), # forward
|
||||
'<end>': (0, 0, 0), # <end>
|
||||
'<start>': (0, 0, 0), # <start>
|
||||
'<ignore>': (0, 0, 0) # <ignore>
|
||||
'left': ([0],[-1], [0]), # left
|
||||
'right': ([0], [1], [0]), # right
|
||||
'up': ([0], [0], [1]), # up
|
||||
'down': ([0], [0],[-1]), # down
|
||||
'forward': ([1], [0], [0]), # forward
|
||||
'<end>': ([0], [0], [0]), # <end>
|
||||
'<start>': ([0], [0], [0]), # <start>
|
||||
'<ignore>': ([0], [0], [0]) # <ignore>
|
||||
}
|
||||
|
||||
def __init__(self, env, results_path, tok, episode_len=20):
|
||||
@ -149,6 +149,7 @@ class Seq2SeqAgent(BaseAgent):
|
||||
def _candidate_variable(self, obs):
|
||||
candidate_leng = [len(ob['candidate']) + 1 for ob in obs] # +1 is for the end
|
||||
candidate_feat = np.zeros((len(obs), max(candidate_leng), self.feature_size + args.angle_feat_size), dtype=np.float32)
|
||||
|
||||
# Note: The candidate_feat at len(ob['candidate']) is the feature for the END
|
||||
# which is zero in my implementation
|
||||
for i, ob in enumerate(obs):
|
||||
@ -195,7 +196,7 @@ class Seq2SeqAgent(BaseAgent):
|
||||
"""
|
||||
def take_action(i, idx, name):
|
||||
if type(name) is int: # Go to the next view
|
||||
self.env.env.sims[idx].makeAction(name, 0, 0)
|
||||
self.env.env.sims[idx].makeAction([name], [0], [0])
|
||||
else: # Adjust
|
||||
self.env.env.sims[idx].makeAction(*self.env_actions[name])
|
||||
|
||||
@ -216,13 +217,15 @@ class Seq2SeqAgent(BaseAgent):
|
||||
while src_level > trg_level: # Tune down
|
||||
take_action(i, idx, 'down')
|
||||
src_level -= 1
|
||||
while self.env.env.sims[idx].getState().viewIndex != trg_point: # Turn right until the target
|
||||
while self.env.env.sims[idx].getState()[0].viewIndex != trg_point: # Turn right until the target
|
||||
take_action(i, idx, 'right')
|
||||
assert select_candidate['viewpointId'] == \
|
||||
self.env.env.sims[idx].getState().navigableLocations[select_candidate['idx']].viewpointId
|
||||
self.env.env.sims[idx].getState()[0].navigableLocations[select_candidate['idx']].viewpointId
|
||||
take_action(i, idx, select_candidate['idx'])
|
||||
|
||||
state = self.env.env.sims[idx].getState()
|
||||
state = self.env.env.sims[idx].getState()[0]
|
||||
# print(state.rgb.shape)
|
||||
# print("action: {} view_index: {}".format(action, state.viewIndex))
|
||||
if traj is not None:
|
||||
traj[i]['path'].append((state.location.viewpointId, state.heading, state.elevation))
|
||||
|
||||
@ -237,11 +240,13 @@ class Seq2SeqAgent(BaseAgent):
|
||||
if self.feedback == 'teacher' or self.feedback == 'argmax':
|
||||
train_rl = False
|
||||
|
||||
# self.env is `R2RBatch`
|
||||
# get observation
|
||||
if reset: # Reset env
|
||||
obs = np.array(self.env.reset())
|
||||
else:
|
||||
obs = np.array(self.env._get_obs())
|
||||
|
||||
obs = np.array(self.env._get_obs())
|
||||
|
||||
batch_size = len(obs)
|
||||
|
||||
# Language input
|
||||
@ -289,6 +294,10 @@ class Seq2SeqAgent(BaseAgent):
|
||||
|
||||
input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs)
|
||||
|
||||
print("input_a_t: ", input_a_t.shape)
|
||||
print("candidate_feat: ", candidate_feat.shape)
|
||||
print("candidate_leng: ", candidate_leng)
|
||||
|
||||
# the first [CLS] token, initialized by the language BERT, serves
|
||||
# as the agent's state passing through time steps
|
||||
if (t >= 1) or (args.vlnbert=='prevalent'):
|
||||
@ -398,6 +407,8 @@ class Seq2SeqAgent(BaseAgent):
|
||||
if ended.all():
|
||||
break
|
||||
|
||||
print()
|
||||
|
||||
if train_rl:
|
||||
# Last action in A2C
|
||||
input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs)
|
||||
|
||||
@ -16,7 +16,7 @@ import networkx as nx
|
||||
from param import args
|
||||
|
||||
from utils import load_datasets, load_nav_graphs, pad_instr_tokens
|
||||
|
||||
from IPython import embed
|
||||
csv.field_size_limit(sys.maxsize)
|
||||
|
||||
|
||||
@ -52,7 +52,7 @@ class EnvBatch():
|
||||
sim.setDiscretizedViewingAngles(True) # Set increment/decrement to 30 degree. (otherwise by radians)
|
||||
sim.setCameraResolution(self.image_w, self.image_h)
|
||||
sim.setCameraVFOV(math.radians(self.vfov))
|
||||
sim.init()
|
||||
sim.initialize()
|
||||
self.sims.append(sim)
|
||||
|
||||
def _make_id(self, scanId, viewpointId):
|
||||
@ -60,7 +60,7 @@ class EnvBatch():
|
||||
|
||||
def newEpisodes(self, scanIds, viewpointIds, headings):
|
||||
for i, (scanId, viewpointId, heading) in enumerate(zip(scanIds, viewpointIds, headings)):
|
||||
self.sims[i].newEpisode(scanId, viewpointId, heading, 0)
|
||||
self.sims[i].newEpisode([scanId], [viewpointId], [heading], [0])
|
||||
|
||||
def getStates(self):
|
||||
"""
|
||||
@ -71,7 +71,7 @@ class EnvBatch():
|
||||
"""
|
||||
feature_states = []
|
||||
for i, sim in enumerate(self.sims):
|
||||
state = sim.getState()
|
||||
state = sim.getState()[0]
|
||||
|
||||
long_id = self._make_id(state.scanId, state.location.viewpointId)
|
||||
if self.features:
|
||||
@ -103,9 +103,11 @@ class R2RBatch():
|
||||
self.tok = tokenizer
|
||||
scans = []
|
||||
for split in splits:
|
||||
max_len = 0
|
||||
for i_item, item in enumerate(load_datasets([split])):
|
||||
if args.test_only and i_item == 64:
|
||||
break
|
||||
max_len = i_item
|
||||
# if args.test_only and i_item == 64:
|
||||
# break
|
||||
if "/" in split:
|
||||
try:
|
||||
new_item = dict(item)
|
||||
@ -119,6 +121,7 @@ class R2RBatch():
|
||||
continue
|
||||
else:
|
||||
# Split multiple instructions into separate entries
|
||||
# print("HERE")
|
||||
for j, instr in enumerate(item['instructions']):
|
||||
try:
|
||||
new_item = dict(item)
|
||||
@ -135,6 +138,7 @@ class R2RBatch():
|
||||
scans.append(item['scan'])
|
||||
except:
|
||||
continue
|
||||
print("split {} has {} datas in the file.".format(split, max_len))
|
||||
|
||||
if name is None:
|
||||
self.name = splits[0] if len(splits) > 0 else "FAKE"
|
||||
@ -222,26 +226,34 @@ class R2RBatch():
|
||||
def make_candidate(self, feature, scanId, viewpointId, viewId):
|
||||
def _loc_distance(loc):
|
||||
return np.sqrt(loc.rel_heading ** 2 + loc.rel_elevation ** 2)
|
||||
|
||||
# viewId is the view index
|
||||
base_heading = (viewId % 12) * math.radians(30)
|
||||
adj_dict = {}
|
||||
long_id = "%s_%s" % (scanId, viewpointId)
|
||||
|
||||
|
||||
if long_id not in self.buffered_state_dict:
|
||||
|
||||
# iterate over the 36 view indices
|
||||
for ix in range(36):
|
||||
if ix == 0:
|
||||
self.sim.newEpisode(scanId, viewpointId, 0, math.radians(-30))
|
||||
self.sim.newEpisode([scanId], [viewpointId], [0], [math.radians(-30)])
|
||||
elif ix % 12 == 0:
|
||||
self.sim.makeAction(0, 1.0, 1.0)
|
||||
self.sim.makeAction([0], [1.0], [1.0])
|
||||
else:
|
||||
self.sim.makeAction(0, 1.0, 0)
|
||||
self.sim.makeAction([0], [1.0], [0])
|
||||
|
||||
state = self.sim.getState()
|
||||
state = self.sim.getState()[0]
|
||||
assert state.viewIndex == ix
|
||||
|
||||
# Heading and elevation for the viewpoint center
|
||||
heading = state.heading - base_heading
|
||||
elevation = state.elevation
|
||||
|
||||
# feature is np.zeros((36, 2048)) when _get_obs received no features
|
||||
visual_feat = feature[ix]
|
||||
# (2048)
|
||||
|
||||
# get adjacent locations
|
||||
for j, loc in enumerate(state.navigableLocations[1:]):
|
||||
@ -267,6 +279,8 @@ class R2RBatch():
|
||||
'feature': np.concatenate((visual_feat, angle_feat), -1)
|
||||
}
|
||||
candidate = list(adj_dict.values())
|
||||
|
||||
# store in the buffer
|
||||
self.buffered_state_dict[long_id] = [
|
||||
{key: c[key]
|
||||
for key in
|
||||
@ -293,15 +307,24 @@ class R2RBatch():
|
||||
|
||||
def _get_obs(self):
|
||||
obs = []
|
||||
|
||||
# self.env is `EnvBatch`
|
||||
# [ ((30, 2048), sim_state) ] * batch_size
|
||||
for i, (feature, state) in enumerate(self.env.getStates()):
|
||||
|
||||
# TODO: self.batch is not understood yet — document what it holds
|
||||
item = self.batch[i]
|
||||
|
||||
# current view index
|
||||
base_view_id = state.viewIndex
|
||||
|
||||
if feature is None:
|
||||
feature = np.zeros((36, 2048))
|
||||
|
||||
# Full features
|
||||
# candidate is a navigable viewpoint
|
||||
candidate = self.make_candidate(feature, state.scanId, state.location.viewpointId, state.viewIndex)
|
||||
|
||||
# [visual_feature, angle_feature] for views
|
||||
feature = np.concatenate((feature, self.angle_feature[base_view_id]), -1)
|
||||
|
||||
@ -328,6 +351,7 @@ class R2RBatch():
|
||||
|
||||
def reset(self, batch=None, inject=False, **kwargs):
|
||||
''' Load a new minibatch / episodes. '''
|
||||
|
||||
if batch is None: # Allow the user to explicitly define the batch
|
||||
self._next_minibatch(**kwargs)
|
||||
else:
|
||||
@ -339,6 +363,8 @@ class R2RBatch():
|
||||
scanIds = [item['scan'] for item in self.batch]
|
||||
viewpointIds = [item['path'][0] for item in self.batch]
|
||||
headings = [item['heading'] for item in self.batch]
|
||||
|
||||
# self.env is `EnvBatch`
|
||||
self.env.newEpisodes(scanIds, viewpointIds, headings)
|
||||
return self._get_obs()
|
||||
|
||||
|
||||
@ -24,13 +24,20 @@ class Evaluation(object):
|
||||
self.gt = {}
|
||||
self.instr_ids = []
|
||||
self.scans = []
|
||||
|
||||
|
||||
print(splits)
|
||||
|
||||
for split in splits:
|
||||
for item in load_datasets([split]):
|
||||
if scans is not None and item['scan'] not in scans:
|
||||
print("ignore {}".format(item['scan']))
|
||||
continue
|
||||
self.gt[str(item['path_id'])] = item
|
||||
self.scans.append(item['scan'])
|
||||
self.instr_ids += ['%s_%d' % (item['path_id'], i) for i in range(len(item['instructions']))]
|
||||
# for i in self.instr_ids:
|
||||
# print(i)
|
||||
self.scans = set(self.scans)
|
||||
self.instr_ids = set(self.instr_ids)
|
||||
self.graphs = load_nav_graphs(self.scans)
|
||||
@ -81,17 +88,22 @@ class Evaluation(object):
|
||||
else:
|
||||
results = output_file
|
||||
|
||||
print('result length', len(results))
|
||||
# print('result length', len(results))
|
||||
# print("RESULT:", results)
|
||||
for item in results:
|
||||
# Check against expected ids
|
||||
if item['instr_id'] in instr_ids:
|
||||
# print("{} exist".format(item['instr_id']))
|
||||
instr_ids.remove(item['instr_id'])
|
||||
self._score_item(item['instr_id'], item['trajectory'])
|
||||
else:
|
||||
print("{} not exist".format(item['instr_id']))
|
||||
print(item)
|
||||
|
||||
if 'train' not in self.splits: # Exclude the training from this. (Because training eval may be partial)
|
||||
assert len(instr_ids) == 0, 'Missing %d of %d instruction ids from %s - not in %s'\
|
||||
% (len(instr_ids), len(self.instr_ids), ",".join(self.splits), output_file)
|
||||
assert len(self.scores['nav_errors']) == len(self.instr_ids)
|
||||
# if 'train' not in self.splits: # Exclude the training from this. (Because training eval may be partial)
|
||||
# assert len(instr_ids) == 0, 'Missing %d of %d instruction ids from %s - not in %s'\
|
||||
# % (len(instr_ids), len(self.instr_ids), ",".join(self.splits), output_file)
|
||||
# assert len(self.scores['nav_errors']) == len(self.instr_ids)
|
||||
score_summary = {
|
||||
'nav_error': np.average(self.scores['nav_errors']),
|
||||
'oracle_error': np.average(self.scores['oracle_errors']),
|
||||
|
||||
@ -58,7 +58,7 @@ def train(train_env, tok, n_iters, log_every=2000, val_envs={}, aug_env=None):
|
||||
start = time.time()
|
||||
print('\nListener training starts, start iteration: %s' % str(start_iter))
|
||||
|
||||
best_val = {'val_unseen': {"spl": 0., "sr": 0., "state":"", 'update':False}}
|
||||
best_val = {'val_unseen': {"spl": 0., "sr": 0., "state":"", 'update':False}, 'val_train_seen': {"spl": 0., "sr": 0., "state":"", 'update':False}}
|
||||
|
||||
for idx in range(start_iter, start_iter+n_iters, log_every):
|
||||
listner.logs = defaultdict(list)
|
||||
@ -91,6 +91,10 @@ def train(train_env, tok, n_iters, log_every=2000, val_envs={}, aug_env=None):
|
||||
RL_loss = sum(listner.logs['RL_loss']) / max(len(listner.logs['RL_loss']), 1)
|
||||
IL_loss = sum(listner.logs['IL_loss']) / max(len(listner.logs['IL_loss']), 1)
|
||||
entropy = sum(listner.logs['entropy']) / total
|
||||
|
||||
print("training:")
|
||||
print("total: {}, length: {}, critic_loss: {}, RL_loss: {}, IL_loss:{}, entropy: {}".format(total, length, critic_loss, RL_loss, IL_loss, entropy
|
||||
))
|
||||
writer.add_scalar("loss/critic", critic_loss, idx)
|
||||
writer.add_scalar("policy_entropy", entropy, idx)
|
||||
writer.add_scalar("loss/RL_loss", RL_loss, idx)
|
||||
@ -136,13 +140,14 @@ def train(train_env, tok, n_iters, log_every=2000, val_envs={}, aug_env=None):
|
||||
print(('%s (%d %d%%) %s' % (timeSince(start, float(iter)/n_iters),
|
||||
iter, float(iter)/n_iters*100, loss_str)))
|
||||
|
||||
if iter % 1000 == 0:
|
||||
if iter % log_every == 0:
|
||||
print("BEST RESULT TILL NOW")
|
||||
for env_name in best_val:
|
||||
print(env_name, best_val[env_name]['state'])
|
||||
|
||||
record_file = open('./logs/' + args.name + '.txt', 'a')
|
||||
record_file.write('BEST RESULT TILL NOW: ' + env_name + ' | ' + best_val[env_name]['state'] + '\n')
|
||||
# print(best_val)
|
||||
record_file.close()
|
||||
|
||||
listner.save(idx, os.path.join("snap", args.name, "state_dict", "LAST_iter%d" % (idx)))
|
||||
@ -193,16 +198,18 @@ def train_val(test_only=False):
|
||||
val_env_names = ['val_train_seen']
|
||||
else:
|
||||
featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])
|
||||
val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']
|
||||
# val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']
|
||||
val_env_names = ['val_train_seen']
|
||||
|
||||
train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok)
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
if args.submit:
|
||||
val_env_names.append('test')
|
||||
else:
|
||||
pass
|
||||
|
||||
print("only {} in train_val()".format(val_env_names))
|
||||
val_envs = OrderedDict(
|
||||
((split,
|
||||
(R2RBatch(feat_dict, batch_size=args.batchSize, splits=[split], tokenizer=tok),
|
||||
@ -213,7 +220,7 @@ def train_val(test_only=False):
|
||||
)
|
||||
|
||||
if args.train == 'listener':
|
||||
train(train_env, tok, args.iters, val_envs=val_envs)
|
||||
train(train_env, tok, args.iters, val_envs=val_envs, log_every=1000)
|
||||
elif args.train == 'validlistener':
|
||||
valid(train_env, tok, val_envs=val_envs)
|
||||
else:
|
||||
@ -254,6 +261,8 @@ def train_val_augment(test_only=False):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
device = torch.device('cuda')
|
||||
print(torch.cuda.get_device_name(device))
|
||||
if args.train in ['listener', 'validlistener']:
|
||||
train_val(test_only=args.test_only)
|
||||
elif args.train == 'auglistener':
|
||||
|
||||
@ -64,11 +64,12 @@ def load_datasets(splits):
|
||||
number = -1
|
||||
if len(components) > 1:
|
||||
split, number = components[0], int(components[1])
|
||||
|
||||
print("number:", number)
|
||||
# Load Json
|
||||
# if split in ['train', 'val_seen', 'val_unseen', 'test',
|
||||
# 'val_unseen_half1', 'val_unseen_half2', 'val_seen_half1', 'val_seen_half2']: # Add two halves for sanity check
|
||||
if "/" not in split:
|
||||
print('here: data/R2R_%s.json' % split)
|
||||
with open('data/R2R_%s.json' % split) as f:
|
||||
new_data = json.load(f)
|
||||
else:
|
||||
@ -348,7 +349,7 @@ def new_simulator():
|
||||
sim.setCameraResolution(WIDTH, HEIGHT)
|
||||
sim.setCameraVFOV(math.radians(VFOV))
|
||||
sim.setDiscretizedViewingAngles(True)
|
||||
sim.init()
|
||||
sim.initialize()
|
||||
|
||||
return sim
|
||||
|
||||
@ -358,14 +359,14 @@ def get_point_angle_feature(baseViewId=0):
|
||||
feature = np.empty((36, args.angle_feat_size), np.float32)
|
||||
base_heading = (baseViewId % 12) * math.radians(30)
|
||||
for ix in range(36):
|
||||
if ix == 0:
|
||||
sim.newEpisode('ZMojNkEp431', '2f4d90acd4024c269fb0efe49a8ac540', 0, math.radians(-30))
|
||||
if ix == 0:
|
||||
sim.newEpisode(['ZMojNkEp431'], ['2f4d90acd4024c269fb0efe49a8ac540'], [0], [math.radians(-30)])
|
||||
elif ix % 12 == 0:
|
||||
sim.makeAction(0, 1.0, 1.0)
|
||||
sim.makeAction([0], [1.0], [1.0])
|
||||
else:
|
||||
sim.makeAction(0, 1.0, 0)
|
||||
sim.makeAction([0], [1.0], [0])
|
||||
|
||||
state = sim.getState()
|
||||
state = sim.getState()[0]
|
||||
assert state.viewIndex == ix
|
||||
|
||||
heading = state.heading - base_heading
|
||||
@ -560,7 +561,7 @@ def print_progress(iteration, total, prefix='', suffix='', decimals=1, bar_lengt
|
||||
str_format = "{0:." + str(decimals) + "f}"
|
||||
percents = str_format.format(100 * (iteration / float(total)))
|
||||
filled_length = int(round(bar_length * iteration / float(total)))
|
||||
bar = '█' * filled_length + '-' * (bar_length - filled_length)
|
||||
bar = 'LL' * filled_length + '-' * (bar_length - filled_length)
|
||||
|
||||
sys.stdout.write('\r%s |%s| %s%s %s' % (prefix, bar, percents, '%', suffix)),
|
||||
|
||||
|
||||
@ -9,7 +9,13 @@ from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from transformers.pytorch_transformers.modeling_bert import (BertEmbeddings,
|
||||
#from transformers.pytorch_transformers.modeling_bert import (BertEmbeddings,
|
||||
# BertSelfAttention, BertAttention, BertEncoder, BertLayer,
|
||||
# BertSelfOutput, BertIntermediate, BertOutput,
|
||||
# BertPooler, BertLayerNorm, BertPreTrainedModel,
|
||||
# BertPredictionHeadTransform)
|
||||
|
||||
from pytorch_transformers.modeling_bert import (BertEmbeddings,
|
||||
BertSelfAttention, BertAttention, BertEncoder, BertLayer,
|
||||
BertSelfOutput, BertIntermediate, BertOutput,
|
||||
BertPooler, BertLayerNorm, BertPreTrainedModel,
|
||||
@ -185,7 +191,8 @@ class BertImgModel(BertPreTrainedModel):
|
||||
self.img_dim = config.img_feature_dim
|
||||
logger.info('BertImgModel Image Dimension: {}'.format(self.img_dim))
|
||||
|
||||
self.apply(self.init_weights)
|
||||
# self.apply(self.init_weights)
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, mode, input_ids, token_type_ids=None, attention_mask=None,
|
||||
position_ids=None, img_feats=None):
|
||||
@ -230,6 +237,7 @@ class VLNBert(BertPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super(VLNBert, self).__init__(config)
|
||||
self.config = config
|
||||
print("Init VLNBERT: ", self.config)
|
||||
self.bert = BertImgModel(config)
|
||||
|
||||
self.vis_lang_LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
@ -237,7 +245,8 @@ class VLNBert(BertPreTrainedModel):
|
||||
self.state_LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.apply(self.init_weights)
|
||||
# self.apply(self.init_weights)
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, mode, input_ids, token_type_ids=None, attention_mask=None,
|
||||
position_ids=None, img_feats=None):
|
||||
|
||||
@ -14,7 +14,8 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from transformers.pytorch_transformers.modeling_bert import BertPreTrainedModel, BertConfig
|
||||
#from transformers.pytorch_transformers.modeling_bert import BertPreTrainedModel, BertConfig
|
||||
from pytorch_transformers.modeling_bert import BertPreTrainedModel, BertConfig
|
||||
import pdb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -366,9 +367,12 @@ class VisionEncoder(nn.Module):
|
||||
class VLNBert(BertPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super(VLNBert, self).__init__(config)
|
||||
|
||||
self.embeddings = BertEmbeddings(config)
|
||||
self.pooler = BertPooler(config)
|
||||
|
||||
print("Init VLNBert: ", config)
|
||||
|
||||
self.img_dim = config.img_feature_dim # 2176
|
||||
logger.info('VLNBert Image Dimension: {}'.format(self.img_dim))
|
||||
self.img_feature_type = config.img_feature_type # ''
|
||||
@ -379,7 +383,8 @@ class VLNBert(BertPreTrainedModel):
|
||||
self.addlayer = nn.ModuleList(
|
||||
[LXRTXLayer(config) for _ in range(self.vl_layers)])
|
||||
self.vision_encoder = VisionEncoder(self.config.img_feature_dim, self.config)
|
||||
self.apply(self.init_weights)
|
||||
# self.apply(self.init_weights)
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, mode, input_ids, token_type_ids=None,
|
||||
attention_mask=None, lang_mask=None, vis_mask=None, position_ids=None, head_mask=None, img_feats=None):
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
# Recurrent VLN-BERT, 2020, by Yicong.Hong@anu.edu.au
|
||||
|
||||
from transformers.pytorch_transformers import (BertConfig, BertTokenizer)
|
||||
#from transformers.pytorch_transformers import (BertConfig, BertTokenizer)
|
||||
from pytorch_transformers import (BertConfig, BertTokenizer)
|
||||
|
||||
def get_tokenizer(args):
|
||||
if args.vlnbert == 'oscar':
|
||||
tokenizer_class = BertTokenizer
|
||||
model_name_or_path = 'Oscar/pretrained_models/base-no-labels/ep_67_588997'
|
||||
model_name_or_path = 'r2r_src/vlnbert/Oscar/pretrained_models/base-no-labels/ep_67_588997'
|
||||
tokenizer = tokenizer_class.from_pretrained(model_name_or_path, do_lower_case=True)
|
||||
elif args.vlnbert == 'prevalent':
|
||||
tokenizer_class = BertTokenizer
|
||||
@ -16,9 +17,10 @@ def get_vlnbert_models(args, config=None):
|
||||
config_class = BertConfig
|
||||
|
||||
if args.vlnbert == 'oscar':
|
||||
print('\n VLN-BERT model is Oscar!!!')
|
||||
from vlnbert.vlnbert_OSCAR import VLNBert
|
||||
model_class = VLNBert
|
||||
model_name_or_path = 'Oscar/pretrained_models/base-no-labels/ep_67_588997'
|
||||
model_name_or_path = 'r2r_src/vlnbert/Oscar/pretrained_models/base-no-labels/ep_67_588997'
|
||||
vis_config = config_class.from_pretrained(model_name_or_path, num_labels=2, finetuning_task='vln-r2r')
|
||||
|
||||
vis_config.model_type = 'visual'
|
||||
@ -31,9 +33,11 @@ def get_vlnbert_models(args, config=None):
|
||||
visual_model = model_class.from_pretrained(model_name_or_path, from_tf=False, config=vis_config)
|
||||
|
||||
elif args.vlnbert == 'prevalent':
|
||||
print('\n VLN-BERT model is prevalent!!!')
|
||||
from vlnbert.vlnbert_PREVALENT import VLNBert
|
||||
model_class = VLNBert
|
||||
model_name_or_path = 'Prevalent/pretrained_model/pytorch_model.bin'
|
||||
#model_name_or_path = './Prevalent/pretrained_model/pytorch_model.bin'
|
||||
model_name_or_path = 'r2r_src/vlnbert/Prevalent/pretrained_model/pytorch_model.bin'
|
||||
vis_config = config_class.from_pretrained('bert-base-uncased')
|
||||
vis_config.img_feature_dim = 2176
|
||||
vis_config.img_feature_type = ""
|
||||
|
||||
@ -23,4 +23,4 @@ flag="--vlnbert prevalent
|
||||
--dropout 0.5"
|
||||
|
||||
mkdir -p snap/$name
|
||||
CUDA_VISIBLE_DEVICES=1 python r2r_src/train.py $flag --name $name
|
||||
CUDA_VISIBLE_DEVICES=0 python3 r2r_src/train.py $flag --name $name
|
||||
|
||||
26
run/test_agent_r2r.bash
Normal file
26
run/test_agent_r2r.bash
Normal file
@ -0,0 +1,26 @@
|
||||
name=VLNBERT-test-Prevalent
|
||||
|
||||
flag="--vlnbert prevalent
|
||||
|
||||
--submit 0
|
||||
--test_only 0
|
||||
|
||||
--train validlistener
|
||||
--load snap/VLNBERT-PREVALENT-final/state_dict/best_val_unseen
|
||||
|
||||
--features places365
|
||||
--maxAction 15
|
||||
--batchSize 8
|
||||
--feedback sample
|
||||
--lr 1e-5
|
||||
--iters 300000
|
||||
--optim adamW
|
||||
|
||||
--mlWeight 0.20
|
||||
--maxInput 80
|
||||
--angleFeatSize 128
|
||||
--featdropout 0.4
|
||||
--dropout 0.5"
|
||||
|
||||
mkdir -p snap/$name
|
||||
CUDA_VISIBLE_DEVICES=0 python3 r2r_src/train.py $flag --name $name
|
||||
@ -22,4 +22,4 @@ flag="--vlnbert prevalent
|
||||
--dropout 0.5"
|
||||
|
||||
mkdir -p snap/$name
|
||||
CUDA_VISIBLE_DEVICES=1 python r2r_src/train.py $flag --name $name
|
||||
CUDA_VISIBLE_DEVICES=0 python3 r2r_src/train.py $flag --name $name
|
||||
|
||||
24
run/train_agent_no_aug.bash
Normal file
24
run/train_agent_no_aug.bash
Normal file
@ -0,0 +1,24 @@
|
||||
name=VLNBERT-train-Prevalent
|
||||
|
||||
flag="--vlnbert prevalent
|
||||
|
||||
--test_only 0
|
||||
|
||||
--train listener
|
||||
|
||||
--features places365
|
||||
--maxAction 15
|
||||
--batchSize 8
|
||||
--feedback sample
|
||||
--lr 1e-5
|
||||
--iters 300000
|
||||
--optim adamW
|
||||
|
||||
--mlWeight 0.20
|
||||
--maxInput 80
|
||||
--angleFeatSize 128
|
||||
--featdropout 0.4
|
||||
--dropout 0.5"
|
||||
|
||||
mkdir -p snap/$name
|
||||
CUDA_VISIBLE_DEVICES=0 python3 r2r_src/train.py $flag --name $name
|
||||
Loading…
Reference in New Issue
Block a user