Compare commits
10 Commits
1602aefcb5
...
ce1ac697cd
| Author | SHA1 | Date | |
|---|---|---|---|
| ce1ac697cd | |||
| 52b90f5298 | |||
|
|
ad147f464a | ||
|
|
885a17c433 | ||
|
|
ee8022c050 | ||
|
|
9e8511956e | ||
|
|
5d5a8c5e92 | ||
|
|
d8bbabdce2 | ||
|
|
ced7c35ce7 | ||
|
|
a8b36d7184 |
22
LICENSE
Normal file
22
LICENSE
Normal file
@ -0,0 +1,22 @@
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2021 Yicong Hong, Qi Wu, Yuankai Qi,
|
||||
Cristian Rodriguez-Opazo, Stephen Gould
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
19
README.md
19
README.md
@ -1,16 +1,19 @@
|
||||
# Recurrent VLN-BERT
|
||||
|
||||
Code of the Recurrent-VLN-BERT paper:
|
||||
Code of the **CVPR 2021 Oral** paper:<br>
|
||||
**A Recurrent Vision-and-Language BERT for Navigation**<br>
|
||||
[**Yicong Hong**](http://www.yiconghong.me/), [Qi Wu](http://www.qi-wu.me/), [Yuankai Qi](https://sites.google.com/site/yuankiqi/home), [Cristian Rodriguez-Opazo](https://crodriguezo.github.io/), [Stephen Gould](http://users.cecs.anu.edu.au/~sgould/)<br>
|
||||
|
||||
[[Paper & Appendices](https://arxiv.org/abs/2011.13922) | [GitHub](https://github.com/YicongHong/Recurrent-VLN-BERT)]
|
||||
[[Paper & Appendices](https://arxiv.org/abs/2011.13922)] [[GitHub](https://github.com/YicongHong/Recurrent-VLN-BERT)]
|
||||
|
||||
"*Neo : Are you saying I have to choose whether Trinity lives or dies? The Oracle : No, you've already made the choice. Now you have to understand it.*" --- [The Matrix Reloaded (2003)](https://www.imdb.com/title/tt0234215/).
|
||||
|
||||
## Prerequisites
|
||||
|
||||
### Installation
|
||||
|
||||
Install the [Matterport3D Simulator](https://github.com/peteanderson80/Matterport3DSimulator).
|
||||
Install the [Matterport3D Simulator](https://github.com/peteanderson80/Matterport3DSimulator). Notice that this code uses the [old version (v0.1)](https://github.com/peteanderson80/Matterport3DSimulator/tree/v0.1) of the simulator, but you can easily change to the latest version which supports batches of agents and it is much more efficient.
|
||||
|
||||
Please find the versions of packages in our environment [here](https://github.com/YicongHong/Recurrent-VLN-BERT/blob/main/recurrent-vln-bert.yml).
|
||||
|
||||
Install the [Pytorch-Transformers](https://github.com/huggingface/transformers).
|
||||
@ -69,10 +72,12 @@ The trained Navigator will be saved under `snap/`.
|
||||
## Citation
|
||||
If you use or discuss our Recurrent VLN-BERT, please cite our paper:
|
||||
```
|
||||
@article{hong2020recurrent,
|
||||
title={A Recurrent Vision-and-Language BERT for Navigation},
|
||||
@InProceedings{Hong_2021_CVPR,
|
||||
author = {Hong, Yicong and Wu, Qi and Qi, Yuankai and Rodriguez-Opazo, Cristian and Gould, Stephen},
|
||||
journal={arXiv preprint arXiv:2011.13922},
|
||||
year={2020}
|
||||
title = {A Recurrent Vision-and-Language BERT for Navigation},
|
||||
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
||||
month = {June},
|
||||
year = {2021},
|
||||
pages = {1643-1653}
|
||||
}
|
||||
```
|
||||
|
||||
254
r2r_src/agent.py
254
r2r_src/agent.py
@ -1,5 +1,3 @@
|
||||
# R2R-EnvDrop, 2019, haotan@cs.unc.edu
|
||||
# Modified in Recurrent VLN-BERT, 2020, by Yicong.Hong@anu.edu.au
|
||||
|
||||
import json
|
||||
import os
|
||||
@ -16,9 +14,9 @@ from torch import optim
|
||||
import torch.nn.functional as F
|
||||
|
||||
from env import R2RBatch
|
||||
from utils import padding_idx, add_idx, Tokenizer, print_progress
|
||||
import utils
|
||||
from utils import padding_idx, print_progress
|
||||
import model_OSCAR, model_PREVALENT
|
||||
import model
|
||||
import param
|
||||
from param import args
|
||||
from collections import defaultdict
|
||||
@ -35,12 +33,12 @@ class BaseAgent(object):
|
||||
self.losses = [] # For learning agents
|
||||
|
||||
def write_results(self):
|
||||
output = [{'instr_id':k, 'trajectory': v} for k,v in self.results.items()]
|
||||
output = [{'instr_id': k, 'trajectory': v, 'ref': r} for k, (v,r) in self.results.items()]
|
||||
with open(self.results_path, 'w') as f:
|
||||
json.dump(output, f)
|
||||
|
||||
def get_results(self):
|
||||
output = [{'instr_id': k, 'trajectory': v} for k, v in self.results.items()]
|
||||
output = [{'instr_id': k, 'trajectory': v, 'ref': r} for k, (v,r) in self.results.items()]
|
||||
return output
|
||||
|
||||
def rollout(self, **args):
|
||||
@ -63,7 +61,7 @@ class BaseAgent(object):
|
||||
for i in range(iters):
|
||||
for traj in self.rollout(**kwargs):
|
||||
self.loss = 0
|
||||
self.results[traj['instr_id']] = traj['path']
|
||||
self.results[traj['instr_id']] = (traj['path'], traj['ref'])
|
||||
else: # Do a full round
|
||||
while True:
|
||||
for traj in self.rollout(**kwargs):
|
||||
@ -71,7 +69,7 @@ class BaseAgent(object):
|
||||
looped = True
|
||||
else:
|
||||
self.loss = 0
|
||||
self.results[traj['instr_id']] = traj['path']
|
||||
self.results[traj['instr_id']] = (traj['path'], traj['ref'])
|
||||
if looped:
|
||||
break
|
||||
|
||||
@ -81,14 +79,14 @@ class Seq2SeqAgent(BaseAgent):
|
||||
|
||||
# For now, the agent can't pick which forward move to make - just the one in the middle
|
||||
env_actions = {
|
||||
'left': (0,-1, 0), # left
|
||||
'right': (0, 1, 0), # right
|
||||
'up': (0, 0, 1), # up
|
||||
'down': (0, 0,-1), # down
|
||||
'forward': (1, 0, 0), # forward
|
||||
'<end>': (0, 0, 0), # <end>
|
||||
'<start>': (0, 0, 0), # <start>
|
||||
'<ignore>': (0, 0, 0) # <ignore>
|
||||
'left': ([0],[-1], [0]), # left
|
||||
'right': ([0], [1], [0]), # right
|
||||
'up': ([0], [0], [1]), # up
|
||||
'down': ([0], [0],[-1]), # down
|
||||
'forward': ([1], [0], [0]), # forward
|
||||
'<end>': ([0], [0], [0]), # <end>
|
||||
'<start>': ([0], [0], [0]), # <start>
|
||||
'<ignore>': ([0], [0], [0]) # <ignore>
|
||||
}
|
||||
|
||||
def __init__(self, env, results_path, tok, episode_len=20):
|
||||
@ -98,12 +96,9 @@ class Seq2SeqAgent(BaseAgent):
|
||||
self.feature_size = self.env.feature_size
|
||||
|
||||
# Models
|
||||
if args.vlnbert == 'oscar':
|
||||
self.vln_bert = model_OSCAR.VLNBERT(feature_size=self.feature_size + args.angle_feat_size).cuda()
|
||||
self.critic = model_OSCAR.Critic().cuda()
|
||||
elif args.vlnbert == 'prevalent':
|
||||
self.vln_bert = model_PREVALENT.VLNBERT(feature_size=self.feature_size + args.angle_feat_size).cuda()
|
||||
self.critic = model_PREVALENT.Critic().cuda()
|
||||
self.vln_bert = model.VLNBERT(directions=args.directions,
|
||||
feature_size=self.feature_size + args.angle_feat_size).cuda()
|
||||
self.critic = model.Critic().cuda()
|
||||
self.models = (self.vln_bert, self.critic)
|
||||
|
||||
# Optimizers
|
||||
@ -114,16 +109,21 @@ class Seq2SeqAgent(BaseAgent):
|
||||
# Evaluations
|
||||
self.losses = []
|
||||
self.criterion = nn.CrossEntropyLoss(ignore_index=args.ignoreid, size_average=False)
|
||||
self.criterion_REF = nn.CrossEntropyLoss(ignore_index=args.ignoreid, size_average=False)
|
||||
self.ndtw_criterion = utils.ndtw_initialize()
|
||||
self.objProposals, self.obj2viewpoint = utils.loadObjProposals()
|
||||
|
||||
# Logs
|
||||
sys.stdout.flush()
|
||||
self.logs = defaultdict(list)
|
||||
|
||||
def _sort_batch(self, obs):
|
||||
''' Extract instructions from a list of observations and sort by descending
|
||||
sequence length (to enable PyTorch packing). '''
|
||||
|
||||
seq_tensor = np.array([ob['instr_encoding'] for ob in obs])
|
||||
seq_lengths = np.argmax(seq_tensor == padding_idx, axis=1)
|
||||
seq_lengths[seq_lengths == 0] = seq_tensor.shape[1]
|
||||
seq_lengths[seq_lengths == 0] = seq_tensor.shape[1] # Full length
|
||||
|
||||
seq_tensor = torch.from_numpy(seq_tensor)
|
||||
seq_lengths = torch.from_numpy(seq_lengths)
|
||||
@ -131,23 +131,30 @@ class Seq2SeqAgent(BaseAgent):
|
||||
# Sort sequences by lengths
|
||||
seq_lengths, perm_idx = seq_lengths.sort(0, True) # True -> descending
|
||||
sorted_tensor = seq_tensor[perm_idx]
|
||||
mask = (sorted_tensor != padding_idx)
|
||||
mask = (sorted_tensor != padding_idx) # seq_lengths[0] is the Maximum length
|
||||
|
||||
token_type_ids = torch.zeros_like(mask)
|
||||
|
||||
visual_mask = torch.ones(args.directions).bool()
|
||||
visual_mask = visual_mask.unsqueeze(0).repeat(mask.size(0),1)
|
||||
visual_mask = torch.cat((mask, visual_mask), -1)
|
||||
|
||||
return Variable(sorted_tensor, requires_grad=False).long().cuda(), \
|
||||
mask.long().cuda(), token_type_ids.long().cuda(), \
|
||||
visual_mask.long().cuda(), \
|
||||
list(seq_lengths), list(perm_idx)
|
||||
|
||||
def _feature_variable(self, obs):
|
||||
''' Extract precomputed features into variable. '''
|
||||
features = np.empty((len(obs), args.views, self.feature_size + args.angle_feat_size), dtype=np.float32)
|
||||
features = np.empty((len(obs), args.directions, self.feature_size + args.angle_feat_size), dtype=np.float32)
|
||||
|
||||
for i, ob in enumerate(obs):
|
||||
features[i, :, :] = ob['feature'] # Image feat
|
||||
|
||||
return Variable(torch.from_numpy(features), requires_grad=False).cuda()
|
||||
|
||||
def _candidate_variable(self, obs):
|
||||
candidate_leng = [len(ob['candidate']) + 1 for ob in obs] # +1 is for the end
|
||||
candidate_leng = [len(ob['candidate']) for ob in obs]
|
||||
candidate_feat = np.zeros((len(obs), max(candidate_leng), self.feature_size + args.angle_feat_size), dtype=np.float32)
|
||||
# Note: The candidate_feat at len(ob['candidate']) is the feature for the END
|
||||
# which is zero in my implementation
|
||||
@ -157,17 +164,33 @@ class Seq2SeqAgent(BaseAgent):
|
||||
|
||||
return torch.from_numpy(candidate_feat).cuda(), candidate_leng
|
||||
|
||||
def _object_variable(self, obs):
|
||||
cand_obj_leng = [len(ob['candidate_obj'][2]) + 1 for ob in obs] # +1 is for no REF
|
||||
cand_obj_feat = np.zeros((len(obs), max(cand_obj_leng), self.feature_size + args.angle_feat_size), dtype=np.float32)
|
||||
cand_obj_pos = np.zeros((len(obs), max(cand_obj_leng), 5), dtype=np.float32)
|
||||
|
||||
for i, ob in enumerate(obs):
|
||||
obj_local_pos, obj_features, candidate_objId = ob['candidate_obj']
|
||||
for j, cc in enumerate(candidate_objId):
|
||||
cand_obj_feat[i, j, :] = obj_features[j]
|
||||
cand_obj_pos[i, j, :] = obj_local_pos[j]
|
||||
|
||||
return torch.from_numpy(cand_obj_feat).cuda(), torch.from_numpy(cand_obj_pos).cuda(), cand_obj_leng
|
||||
|
||||
def get_input_feat(self, obs):
|
||||
input_a_t = np.zeros((len(obs), args.angle_feat_size), np.float32)
|
||||
for i, ob in enumerate(obs):
|
||||
input_a_t[i] = utils.angle_feature(ob['heading'], ob['elevation'])
|
||||
input_a_t = torch.from_numpy(input_a_t).cuda()
|
||||
# f_t = self._feature_variable(obs) # Pano image features from obs
|
||||
|
||||
f_t = self._feature_variable(obs) # Image features from obs
|
||||
candidate_feat, candidate_leng = self._candidate_variable(obs)
|
||||
|
||||
return input_a_t, candidate_feat, candidate_leng
|
||||
obj_feat, obj_pos, obj_leng = self._object_variable(obs)
|
||||
|
||||
def _teacher_action(self, obs, ended):
|
||||
return input_a_t, f_t, candidate_feat, candidate_leng, obj_feat, obj_pos, obj_leng
|
||||
|
||||
def _teacher_action(self, obs, ended, cand_size):
|
||||
"""
|
||||
Extract teacher actions into variable.
|
||||
:param obs: The observation.
|
||||
@ -185,7 +208,22 @@ class Seq2SeqAgent(BaseAgent):
|
||||
break
|
||||
else: # Stop here
|
||||
assert ob['teacher'] == ob['viewpoint'] # The teacher action should be "STAY HERE"
|
||||
a[i] = len(ob['candidate'])
|
||||
a[i] = cand_size - 1
|
||||
return torch.from_numpy(a).cuda()
|
||||
|
||||
def _teacher_REF(self, obs, just_ended):
|
||||
a = np.zeros(len(obs), dtype=np.int64)
|
||||
for i, ob in enumerate(obs):
|
||||
if not just_ended[i]: # Just ignore this index
|
||||
a[i] = args.ignoreid
|
||||
else:
|
||||
candidate_objs = ob['candidate_obj'][2]
|
||||
for k, kid in enumerate(candidate_objs):
|
||||
if kid == ob['objId']:
|
||||
a[i] = k
|
||||
break
|
||||
else:
|
||||
a[i] = args.ignoreid
|
||||
return torch.from_numpy(a).cuda()
|
||||
|
||||
def make_equiv_action(self, a_t, perm_obs, perm_idx=None, traj=None):
|
||||
@ -195,7 +233,7 @@ class Seq2SeqAgent(BaseAgent):
|
||||
"""
|
||||
def take_action(i, idx, name):
|
||||
if type(name) is int: # Go to the next view
|
||||
self.env.env.sims[idx].makeAction(name, 0, 0)
|
||||
self.env.env.sims[idx].makeAction([name], [0], [0])
|
||||
else: # Adjust
|
||||
self.env.env.sims[idx].makeAction(*self.env_actions[name])
|
||||
|
||||
@ -216,22 +254,24 @@ class Seq2SeqAgent(BaseAgent):
|
||||
while src_level > trg_level: # Tune down
|
||||
take_action(i, idx, 'down')
|
||||
src_level -= 1
|
||||
while self.env.env.sims[idx].getState().viewIndex != trg_point: # Turn right until the target
|
||||
while self.env.env.sims[idx].getState()[0].viewIndex != trg_point: # Turn right until the target
|
||||
take_action(i, idx, 'right')
|
||||
assert select_candidate['viewpointId'] == \
|
||||
self.env.env.sims[idx].getState().navigableLocations[select_candidate['idx']].viewpointId
|
||||
self.env.env.sims[idx].getState()[0].navigableLocations[select_candidate['idx']].viewpointId
|
||||
take_action(i, idx, select_candidate['idx'])
|
||||
|
||||
state = self.env.env.sims[idx].getState()
|
||||
state = self.env.env.sims[idx].getState()[0]
|
||||
if traj is not None:
|
||||
traj[i]['path'].append((state.location.viewpointId, state.heading, state.elevation))
|
||||
|
||||
def rollout(self, train_ml=None, train_rl=True, reset=True):
|
||||
def rollout(self, train_ml=None, train_rl=True, reset=True, speaker=None):
|
||||
"""
|
||||
:param train_ml: The weight to train with maximum likelihood
|
||||
:param train_rl: whether use RL in training
|
||||
:param reset: Reset the environment
|
||||
|
||||
:param speaker: Speaker used in back translation.
|
||||
If the speaker is not None, use back translation.
|
||||
O.w., normal training
|
||||
:return:
|
||||
"""
|
||||
if self.feedback == 'teacher' or self.feedback == 'argmax':
|
||||
@ -244,9 +284,9 @@ class Seq2SeqAgent(BaseAgent):
|
||||
|
||||
batch_size = len(obs)
|
||||
|
||||
# Language input
|
||||
# Reorder the language input for the encoder (do not ruin the original code)
|
||||
sentence, language_attention_mask, token_type_ids, \
|
||||
seq_lengths, perm_idx = self._sort_batch(obs)
|
||||
visual_attention_mask, seq_lengths, perm_idx = self._sort_batch(obs)
|
||||
perm_obs = obs[perm_idx]
|
||||
|
||||
''' Language BERT '''
|
||||
@ -255,15 +295,13 @@ class Seq2SeqAgent(BaseAgent):
|
||||
'attention_mask': language_attention_mask,
|
||||
'lang_mask': language_attention_mask,
|
||||
'token_type_ids': token_type_ids}
|
||||
if args.vlnbert == 'oscar':
|
||||
language_features = self.vln_bert(**language_inputs)
|
||||
elif args.vlnbert == 'prevalent':
|
||||
h_t, language_features = self.vln_bert(**language_inputs)
|
||||
|
||||
# Record starting point
|
||||
traj = [{
|
||||
'instr_id': ob['instr_id'],
|
||||
'path': [(ob['viewpoint'], ob['heading'], ob['elevation'])],
|
||||
'ref': None
|
||||
} for ob in perm_obs]
|
||||
|
||||
# Init the reward shaping
|
||||
@ -276,48 +314,65 @@ class Seq2SeqAgent(BaseAgent):
|
||||
|
||||
# Initialization the tracking state
|
||||
ended = np.array([False] * batch_size) # Indices match permuation of the model, not env
|
||||
just_ended = np.array([False] * batch_size)
|
||||
|
||||
# Init the logs
|
||||
rewards = []
|
||||
hidden_states = []
|
||||
policy_log_probs = []
|
||||
masks = []
|
||||
stop_mask = torch.tensor([False] * batch_size).cuda().unsqueeze(1)
|
||||
entropys = []
|
||||
ml_loss = 0.
|
||||
ref_loss = 0.
|
||||
|
||||
for t in range(self.episode_len):
|
||||
|
||||
input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs)
|
||||
input_a_t, f_t, candidate_feat, candidate_leng, obj_feat, obj_pos, obj_leng = self.get_input_feat(perm_obs)
|
||||
|
||||
# the first [CLS] token, initialized by the language BERT, serves
|
||||
# the first [CLS] token, initialized by the language BERT, servers
|
||||
# as the agent's state passing through time steps
|
||||
if (t >= 1) or (args.vlnbert=='prevalent'):
|
||||
language_features = torch.cat((h_t.unsqueeze(1), language_features[:,1:,:]), dim=1)
|
||||
|
||||
visual_temp_mask = (utils.length2mask(candidate_leng) == 0).long()
|
||||
visual_attention_mask = torch.cat((language_attention_mask, visual_temp_mask), dim=-1)
|
||||
obj_temp_mask = (utils.length2mask(obj_leng) == 0).long()
|
||||
visual_attention_mask = torch.cat((language_attention_mask, visual_temp_mask, obj_temp_mask), dim=-1)
|
||||
|
||||
self.vln_bert.vln_bert.config.directions = max(candidate_leng)
|
||||
self.vln_bert.vln_bert.config.obj_directions = max(obj_leng)
|
||||
|
||||
''' Visual BERT '''
|
||||
visual_inputs = {'mode': 'visual',
|
||||
'sentence': language_features,
|
||||
'attention_mask': visual_attention_mask,
|
||||
'lang_mask': language_attention_mask,
|
||||
'vis_mask': visual_temp_mask,
|
||||
'obj_mask': obj_temp_mask,
|
||||
'token_type_ids': token_type_ids,
|
||||
'action_feats': input_a_t,
|
||||
# 'pano_feats': f_t,
|
||||
'cand_feats': candidate_feat}
|
||||
h_t, logit = self.vln_bert(**visual_inputs)
|
||||
'pano_feats': f_t,
|
||||
'cand_feats': candidate_feat,
|
||||
'obj_feats': obj_feat,
|
||||
'obj_pos': obj_pos,
|
||||
'already_dropfeat': (speaker is not None)}
|
||||
h_t, logit, logit_REF = self.vln_bert(**visual_inputs)
|
||||
hidden_states.append(h_t)
|
||||
|
||||
# print('time step', t)
|
||||
# import pdb; pdb.set_trace()
|
||||
|
||||
# Mask outputs where agent can't move forward
|
||||
# Here the logit is [b, max_candidate]
|
||||
if train_ml is not None:
|
||||
candidate_mask = utils.length2mask(candidate_leng)
|
||||
candidate_mask = torch.cat((candidate_mask, stop_mask), dim=-1)
|
||||
candidate_mask_obj = utils.length2mask(obj_leng)
|
||||
|
||||
logit.masked_fill_(candidate_mask, -float('inf'))
|
||||
logit_REF.masked_fill_(candidate_mask_obj, -float('inf'))
|
||||
|
||||
# Supervised training
|
||||
target = self._teacher_action(perm_obs, ended)
|
||||
target = self._teacher_action(perm_obs, ended, candidate_mask.size(1))
|
||||
ml_loss += self.criterion(logit, target)
|
||||
|
||||
# Determine next model inputs
|
||||
@ -338,13 +393,37 @@ class Seq2SeqAgent(BaseAgent):
|
||||
else:
|
||||
print(self.feedback)
|
||||
sys.exit('Invalid feedback option')
|
||||
|
||||
# print('a_t', a_t)
|
||||
|
||||
# Prepare environment action
|
||||
# NOTE: Env action is in the perm_obs space
|
||||
cpu_a_t = a_t.cpu().numpy()
|
||||
for i, next_id in enumerate(cpu_a_t):
|
||||
if next_id == (candidate_leng[i]-1) or next_id == args.ignoreid or ended[i]: # The last action is <end>
|
||||
if ((next_id == visual_temp_mask.size(1)) or (t == self.episode_len-1)) and (not ended[i]): # just stopped and forced stopped
|
||||
just_ended[i] = True
|
||||
if self.feedback == 'argmax':
|
||||
_, ref_t = logit_REF[i].max(0)
|
||||
if ref_t != obj_leng[i]-1: # decide not to do REF
|
||||
traj[i]['ref'] = perm_obs[i]['candidate_obj'][2][ref_t]
|
||||
else:
|
||||
just_ended[i] = False
|
||||
|
||||
if (next_id == visual_temp_mask.size(1)) or (next_id == args.ignoreid) or (ended[i]): # The last action is <end>
|
||||
cpu_a_t[i] = -1 # Change the <end> and ignore action to -1
|
||||
|
||||
''' Supervised training for REF '''
|
||||
if train_ml is not None:
|
||||
target_obj = self._teacher_REF(perm_obs, just_ended)
|
||||
ref_loss += self.criterion_REF(logit_REF, target_obj)
|
||||
|
||||
# print('logit', logit)
|
||||
# print('logit_REF', logit_REF)
|
||||
# print('just_ended', just_ended)
|
||||
# print('ended', ended)
|
||||
# print('cpu_a_t', cpu_a_t)
|
||||
# import pdb; pdb.set_trace()
|
||||
|
||||
# Make action and get the new state
|
||||
self.make_equiv_action(cpu_a_t, perm_obs, perm_idx, traj)
|
||||
obs = np.array(self.env._get_obs())
|
||||
@ -358,22 +437,29 @@ class Seq2SeqAgent(BaseAgent):
|
||||
mask = np.ones(batch_size, np.float32)
|
||||
for i, ob in enumerate(perm_obs):
|
||||
dist[i] = ob['distance']
|
||||
path_act = [vp[0] for vp in traj[i]['path']]
|
||||
ndtw_score[i] = self.ndtw_criterion[ob['scan']](path_act, ob['gt_path'], metric='ndtw')
|
||||
|
||||
# path_act = [vp[0] for vp in traj[i]['path']]
|
||||
# ndtw_score[i] = self.ndtw_criterion[ob['scan']](path_act, ob['gt_path'], metric='ndtw')
|
||||
if ended[i]:
|
||||
reward[i] = 0.0
|
||||
mask[i] = 0.0
|
||||
else:
|
||||
action_idx = cpu_a_t[i]
|
||||
# Target reward
|
||||
if action_idx == -1: # If the action now is end
|
||||
if dist[i] < 3.0: # Correct
|
||||
# navigation success if the target object is visible when STOP
|
||||
# end_viewpoint_id = ob['scan'] + '_' + ob['viewpoint']
|
||||
# if self.objProposals.__contains__(end_viewpoint_id):
|
||||
# if ob['objId'] in self.objProposals[end_viewpoint_id]['objId']:
|
||||
# reward[i] = 2.0 + ndtw_score[i] * 2.0
|
||||
# else:
|
||||
# reward[i] = -2.0
|
||||
# else:
|
||||
# reward[i] = -2.0
|
||||
if dist[i] < 1.0: # Correct
|
||||
reward[i] = 2.0 + ndtw_score[i] * 2.0
|
||||
else: # Incorrect
|
||||
reward[i] = -2.0
|
||||
else: # The action is not end
|
||||
# Path fidelity rewards (distance & nDTW)
|
||||
# Change of distance and nDTW reward
|
||||
reward[i] = - (dist[i] - last_dist[i])
|
||||
ndtw_reward = ndtw_score[i] - last_ndtw[i]
|
||||
if reward[i] > 0.0: # Quantification
|
||||
@ -382,7 +468,7 @@ class Seq2SeqAgent(BaseAgent):
|
||||
reward[i] = -1.0 + ndtw_reward
|
||||
else:
|
||||
raise NameError("The action doesn't change the move")
|
||||
# Miss the target penalty
|
||||
# miss the target penalty
|
||||
if (last_dist[i] <= 1.0) and (dist[i]-last_dist[i] > 0.0):
|
||||
reward[i] -= (1.0 - last_dist[i]) * 2.0
|
||||
rewards.append(reward)
|
||||
@ -400,25 +486,31 @@ class Seq2SeqAgent(BaseAgent):
|
||||
|
||||
if train_rl:
|
||||
# Last action in A2C
|
||||
input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs)
|
||||
input_a_t, f_t, candidate_feat, candidate_leng, obj_feat, obj_pos, obj_leng = self.get_input_feat(perm_obs)
|
||||
|
||||
language_features = torch.cat((h_t.unsqueeze(1), language_features[:,1:,:]), dim=1)
|
||||
|
||||
visual_temp_mask = (utils.length2mask(candidate_leng) == 0).long()
|
||||
visual_attention_mask = torch.cat((language_attention_mask, visual_temp_mask), dim=-1)
|
||||
obj_temp_mask = (utils.length2mask(obj_leng) == 0).long()
|
||||
visual_attention_mask = torch.cat((language_attention_mask, visual_temp_mask, obj_temp_mask), dim=-1)
|
||||
|
||||
self.vln_bert.vln_bert.config.directions = max(candidate_leng)
|
||||
self.vln_bert.vln_bert.config.obj_directions = max(obj_leng)
|
||||
''' Visual BERT '''
|
||||
visual_inputs = {'mode': 'visual',
|
||||
'sentence': language_features,
|
||||
'attention_mask': visual_attention_mask,
|
||||
'lang_mask': language_attention_mask,
|
||||
'vis_mask': visual_temp_mask,
|
||||
'obj_mask': obj_temp_mask,
|
||||
'token_type_ids': token_type_ids,
|
||||
'action_feats': input_a_t,
|
||||
# 'pano_feats': f_t,
|
||||
'cand_feats': candidate_feat}
|
||||
last_h_, _ = self.vln_bert(**visual_inputs)
|
||||
'pano_feats': f_t,
|
||||
'cand_feats': candidate_feat,
|
||||
'obj_feats': obj_feat,
|
||||
'obj_pos': obj_pos,
|
||||
'already_dropfeat': (speaker is not None)}
|
||||
last_h_, _, _ = self.vln_bert(**visual_inputs)
|
||||
|
||||
rl_loss = 0.
|
||||
|
||||
@ -440,6 +532,7 @@ class Seq2SeqAgent(BaseAgent):
|
||||
v_ = self.critic(hidden_states[t])
|
||||
a_ = (r_ - v_).detach()
|
||||
|
||||
# r_: The higher, the better. -ln(p(action)) * (discount_reward - value)
|
||||
rl_loss += (-policy_log_probs[t] * a_ * mask_).sum()
|
||||
rl_loss += (((r_ - v_) ** 2) * mask_).sum() * 0.5 # 1/2 L2 loss
|
||||
if self.feedback == 'sample':
|
||||
@ -458,17 +551,18 @@ class Seq2SeqAgent(BaseAgent):
|
||||
assert args.normalize_loss == 'none'
|
||||
|
||||
self.loss += rl_loss
|
||||
self.logs['RL_loss'].append(rl_loss.item())
|
||||
|
||||
if train_ml is not None:
|
||||
self.loss += ml_loss * train_ml / batch_size
|
||||
self.logs['IL_loss'].append((ml_loss * train_ml / batch_size).item())
|
||||
self.loss += ref_loss / batch_size
|
||||
|
||||
if type(self.loss) is int: # For safety, it will be activated if no losses are added
|
||||
self.losses.append(0.)
|
||||
else:
|
||||
self.losses.append(self.loss.item() / self.episode_len) # This argument is useless.
|
||||
|
||||
# import pdb; pdb.set_trace()
|
||||
|
||||
return traj
|
||||
|
||||
def test(self, use_dropout=False, feedback='argmax', allow_cheat=False, iters=None):
|
||||
@ -544,7 +638,7 @@ class Seq2SeqAgent(BaseAgent):
|
||||
self.critic_optimizer.step()
|
||||
|
||||
if args.aug is None:
|
||||
print_progress(iter, n_iters+1, prefix='Progress:', suffix='Complete', bar_length=50)
|
||||
print_progress(iter, n_iters, prefix='Progress:', suffix='Complete', bar_length=50)
|
||||
|
||||
def save(self, epoch, path):
|
||||
''' Snapshot models '''
|
||||
@ -582,3 +676,33 @@ class Seq2SeqAgent(BaseAgent):
|
||||
for param in all_tuple:
|
||||
recover_state(*param)
|
||||
return states['vln_bert']['epoch'] - 1
|
||||
|
||||
def load_pretrain(self, path):
|
||||
''' Loads parameters from pretrained network '''
|
||||
load_states = torch.load(path)
|
||||
# print(self.vln_bert.state_dict()['candidate_att_layer.linear_in.weight'])
|
||||
# print(self.vln_bert.state_dict()['visual_bert.bert.encoder.layer.9.intermediate.dense.weight'])
|
||||
|
||||
def recover_state(name, model):
|
||||
state = model.state_dict()
|
||||
model_keys = set(state.keys())
|
||||
load_keys = set(load_states[name]['state_dict'].keys())
|
||||
if model_keys != load_keys:
|
||||
print("NOTICE: DIFFERENT KEYS FOUND IN MODEL")
|
||||
for ikey in model_keys:
|
||||
if ikey not in load_keys:
|
||||
print('key not in model: ', ikey)
|
||||
for ikey in load_keys:
|
||||
if ikey not in model_keys:
|
||||
print('key not in loaded states: ', ikey)
|
||||
|
||||
state.update(load_states[name]['state_dict'])
|
||||
model.load_state_dict(state)
|
||||
|
||||
all_tuple = [("vln_bert", self.vln_bert)]
|
||||
for param in all_tuple:
|
||||
recover_state(*param)
|
||||
# print(self.vln_bert.state_dict()['candidate_att_layer.linear_in.weight'])
|
||||
# print(self.vln_bert.state_dict()['visual_bert.bert.encoder.layer.9.intermediate.dense.weight'])
|
||||
|
||||
return load_states['vln_bert']['epoch'] - 1
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
''' Batched Room-to-Room navigation environment '''
|
||||
|
||||
import sys
|
||||
sys.path.append('buildpy36')
|
||||
sys.path.append('Matterport_Simulator/build/')
|
||||
import MatterSim
|
||||
import csv
|
||||
import numpy as np
|
||||
@ -12,6 +10,7 @@ import utils
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import pickle as pkl
|
||||
import networkx as nx
|
||||
from param import args
|
||||
|
||||
@ -45,6 +44,7 @@ class EnvBatch():
|
||||
self.image_w = 640
|
||||
self.image_h = 480
|
||||
self.vfov = 60
|
||||
# self.featurized_scans = set([key.split("_")[0] for key in list(self.features.keys())])
|
||||
self.sims = []
|
||||
for i in range(batch_size):
|
||||
sim = MatterSim.Simulator()
|
||||
@ -52,7 +52,7 @@ class EnvBatch():
|
||||
sim.setDiscretizedViewingAngles(True) # Set increment/decrement to 30 degree. (otherwise by radians)
|
||||
sim.setCameraResolution(self.image_w, self.image_h)
|
||||
sim.setCameraVFOV(math.radians(self.vfov))
|
||||
sim.init()
|
||||
sim.initialize()
|
||||
self.sims.append(sim)
|
||||
|
||||
def _make_id(self, scanId, viewpointId):
|
||||
@ -60,7 +60,9 @@ class EnvBatch():
|
||||
|
||||
def newEpisodes(self, scanIds, viewpointIds, headings):
|
||||
for i, (scanId, viewpointId, heading) in enumerate(zip(scanIds, viewpointIds, headings)):
|
||||
self.sims[i].newEpisode(scanId, viewpointId, heading, 0)
|
||||
# print("New episode %d" % i)
|
||||
# sys.stdout.flush()
|
||||
self.sims[i].newEpisode([scanId], [viewpointId], [heading], [0])
|
||||
|
||||
def getStates(self):
|
||||
"""
|
||||
@ -71,11 +73,11 @@ class EnvBatch():
|
||||
"""
|
||||
feature_states = []
|
||||
for i, sim in enumerate(self.sims):
|
||||
state = sim.getState()
|
||||
state = sim.getState()[0]
|
||||
|
||||
long_id = self._make_id(state.scanId, state.location.viewpointId)
|
||||
if self.features:
|
||||
feature = self.features[long_id]
|
||||
feature = self.features[long_id] # Get feature for
|
||||
feature_states.append((feature, state))
|
||||
else:
|
||||
feature_states.append((None, state))
|
||||
@ -87,7 +89,6 @@ class EnvBatch():
|
||||
for i, (index, heading, elevation) in enumerate(actions):
|
||||
self.sims[i].makeAction(index, heading, elevation)
|
||||
|
||||
|
||||
class R2RBatch():
|
||||
''' Implements the Room to Room navigation task, using discretized viewpoints and pretrained features '''
|
||||
|
||||
@ -104,38 +105,35 @@ class R2RBatch():
|
||||
scans = []
|
||||
for split in splits:
|
||||
for i_item, item in enumerate(load_datasets([split])):
|
||||
if args.test_only and i_item == 64:
|
||||
break
|
||||
if "/" in split:
|
||||
try:
|
||||
new_item = dict(item)
|
||||
new_item['instr_id'] = item['path_id']
|
||||
new_item['instructions'] = item['instructions'][0]
|
||||
new_item['instr_encoding'] = item['instr_enc']
|
||||
if new_item['instr_encoding'] is not None: # Filter the wrong data
|
||||
self.data.append(new_item)
|
||||
scans.append(item['scan'])
|
||||
except:
|
||||
continue
|
||||
else:
|
||||
# if args.test_only and i_item == 64:
|
||||
# break
|
||||
# Split multiple instructions into separate entries
|
||||
for j, instr in enumerate(item['instructions']):
|
||||
# if item['scan'] not in self.env.featurized_scans: # For fast training
|
||||
# continue
|
||||
try:
|
||||
new_item = dict(item)
|
||||
new_item['instr_id'] = '%s_%d' % (item['path_id'], j)
|
||||
new_item['instr_id'] = '%s_%d' % (item['id'], j)
|
||||
new_item['instructions'] = instr
|
||||
|
||||
''' BERT tokenizer '''
|
||||
instr_tokens = tokenizer.tokenize(instr)
|
||||
padded_instr_tokens, num_words = pad_instr_tokens(instr_tokens, args.maxInput)
|
||||
padded_instr_tokens = pad_instr_tokens(instr_tokens, args.maxInput)
|
||||
new_item['instr_encoding'] = tokenizer.convert_tokens_to_ids(padded_instr_tokens)
|
||||
|
||||
# if tokenizer:
|
||||
# new_item['instr_encoding'] = tokenizer.encode_sentence(instr)
|
||||
# if not tokenizer or new_item['instr_encoding'] is not None: # Filter the wrong data
|
||||
if new_item['instr_encoding'] is not None: # Filter the wrong data
|
||||
self.data.append(new_item)
|
||||
scans.append(item['scan'])
|
||||
except:
|
||||
continue
|
||||
|
||||
# load object features
|
||||
with open('data/BBoxS/REVERIE_obj_feats.pkl', 'rb') as f_obj:
|
||||
self.obj_feats = pkl.load(f_obj)
|
||||
|
||||
if name is None:
|
||||
self.name = splits[0] if len(splits) > 0 else "FAKE"
|
||||
else:
|
||||
@ -228,13 +226,13 @@ class R2RBatch():
|
||||
if long_id not in self.buffered_state_dict:
|
||||
for ix in range(36):
|
||||
if ix == 0:
|
||||
self.sim.newEpisode(scanId, viewpointId, 0, math.radians(-30))
|
||||
self.sim.newEpisode([scanId], [viewpointId], [0], [math.radians(-30)])
|
||||
elif ix % 12 == 0:
|
||||
self.sim.makeAction(0, 1.0, 1.0)
|
||||
self.sim.makeAction([0], [1.0], [1.0])
|
||||
else:
|
||||
self.sim.makeAction(0, 1.0, 0)
|
||||
self.sim.makeAction([0], [1.0], [0])
|
||||
|
||||
state = self.sim.getState()
|
||||
state = self.sim.getState()[0]
|
||||
assert state.viewIndex == ix
|
||||
|
||||
# Heading and elevation for the viewpoint center
|
||||
@ -302,8 +300,22 @@ class R2RBatch():
|
||||
|
||||
# Full features
|
||||
candidate = self.make_candidate(feature, state.scanId, state.location.viewpointId, state.viewIndex)
|
||||
# [visual_feature, angle_feature] for views
|
||||
feature = np.concatenate((feature, self.angle_feature[base_view_id]), -1)
|
||||
# (visual_feature, angel_feature) for views
|
||||
|
||||
directional_feature = self.angle_feature[base_view_id]
|
||||
feature = np.concatenate((feature, directional_feature), -1)
|
||||
centered_feature = utils.get_centered_visual_features(feature, base_view_id)
|
||||
|
||||
try:
|
||||
obj_local_pos = []; obj_features = []; candidate_objId = []
|
||||
# prepare object features
|
||||
for vis_pos, objects in self.obj_feats[state.scanId][state.location.viewpointId].items():
|
||||
for objId, obj in objects.items():
|
||||
candidate_objId.append(objId)
|
||||
obj_local_pos.append(utils.get_obj_local_pos(obj['boxes'].toarray())) # xyxy
|
||||
obj_features.append(np.concatenate((obj['features'].toarray().squeeze(), directional_feature[int(vis_pos)]), -1))
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
obs.append({
|
||||
'instr_id' : item['instr_id'],
|
||||
@ -312,13 +324,15 @@ class R2RBatch():
|
||||
'viewIndex' : state.viewIndex,
|
||||
'heading' : state.heading,
|
||||
'elevation' : state.elevation,
|
||||
'feature' : feature,
|
||||
'feature' : centered_feature,
|
||||
'candidate': candidate,
|
||||
'navigableLocations' : state.navigableLocations,
|
||||
'instructions' : item['instructions'],
|
||||
'teacher' : self._shortest_path_action(state, item['path'][-1]),
|
||||
'gt_path' : item['path'],
|
||||
'path_id' : item['path_id']
|
||||
'path_id' : item['id'],
|
||||
'objId': str(item['objId']), # target objId
|
||||
'candidate_obj': (obj_local_pos, obj_features, candidate_objId)
|
||||
})
|
||||
if 'instr_encoding' in item:
|
||||
obs[-1]['instr_encoding'] = item['instr_encoding']
|
||||
|
||||
162
r2r_src/eval.py
162
r2r_src/eval.py
@ -10,6 +10,7 @@ import pprint
|
||||
pp = pprint.PrettyPrinter(indent=4)
|
||||
|
||||
from env import R2RBatch
|
||||
import utils
|
||||
from utils import load_datasets, load_nav_graphs, ndtw_graphload, DTW
|
||||
from agent import BaseAgent
|
||||
|
||||
@ -28,15 +29,28 @@ class Evaluation(object):
|
||||
for item in load_datasets([split]):
|
||||
if scans is not None and item['scan'] not in scans:
|
||||
continue
|
||||
self.gt[str(item['path_id'])] = item
|
||||
self.gt[str(item['id'])] = item
|
||||
self.scans.append(item['scan'])
|
||||
self.instr_ids += ['%s_%d' % (item['path_id'], i) for i in range(len(item['instructions']))]
|
||||
self.instr_ids += ['%s_%d' % (item['id'], i) for i in range(len(item['instructions']))]
|
||||
self.scans = set(self.scans)
|
||||
self.instr_ids = set(self.instr_ids)
|
||||
self.graphs = load_nav_graphs(self.scans)
|
||||
self.distances = {}
|
||||
for scan,G in self.graphs.items(): # compute all shortest paths
|
||||
self.distances[scan] = dict(nx.all_pairs_dijkstra_path_length(G))
|
||||
self.objProposals, self.obj2viewpoint = utils.loadObjProposals()
|
||||
|
||||
# self.ndtw_criterion = {}
|
||||
# scan_gts_dir = '/home/yicong/research/selfmonitoring-agent/tasks/Env-back/data/id_paths.json'
|
||||
# with open(scan_gts_dir) as f_:
|
||||
# self.scan_gts = json.load(f_)
|
||||
# all_scan_ids = []
|
||||
# for key in self.scan_gts:
|
||||
# path_scan_id = self.scan_gts[key][0]
|
||||
# if path_scan_id not in all_scan_ids:
|
||||
# all_scan_ids.append(path_scan_id)
|
||||
# ndtw_graph = ndtw_graphload(path_scan_id)
|
||||
# self.ndtw_criterion[path_scan_id] = DTW(ndtw_graph)
|
||||
|
||||
def _get_nearest(self, scan, goal_id, path):
|
||||
near_id = path[0][0]
|
||||
@ -48,20 +62,21 @@ class Evaluation(object):
|
||||
near_d = d
|
||||
return near_id
|
||||
|
||||
def _score_item(self, instr_id, path):
|
||||
def _score_item(self, instr_id, path, ref_objId):
|
||||
''' Calculate error based on the final position in trajectory, and also
|
||||
the closest position (oracle stopping rule).
|
||||
The path contains [view_id, angle, vofv] '''
|
||||
gt = self.gt[instr_id.split('_')[-2]]
|
||||
gt = self.gt[instr_id[:-2]]
|
||||
start = gt['path'][0]
|
||||
assert start == path[0][0], 'Result trajectories should include the start position'
|
||||
goal = gt['path'][-1]
|
||||
final_position = path[-1][0] # the first of [view_id, angle, vofv]
|
||||
nearest_position = self._get_nearest(gt['scan'], goal, path)
|
||||
self.scores['nav_errors'].append(self.distances[gt['scan']][final_position][goal])
|
||||
self.scores['oracle_errors'].append(self.distances[gt['scan']][nearest_position][goal])
|
||||
# self.scores['nav_errors'].append(self.distances[gt['scan']][final_position][goal])
|
||||
# self.scores['oracle_errors'].append(self.distances[gt['scan']][nearest_position][goal])
|
||||
self.scores['trajectory_steps'].append(len(path)-1)
|
||||
distance = 0 # length of the path in meters
|
||||
|
||||
distance = 0 # Work out the length of the path in meters
|
||||
prev = path[0]
|
||||
for curr in path[1:]:
|
||||
distance += self.distances[gt['scan']][prev[0]][curr[0]]
|
||||
@ -71,6 +86,55 @@ class Evaluation(object):
|
||||
self.distances[gt['scan']][start][goal]
|
||||
)
|
||||
|
||||
# REF sucess or not
|
||||
if ref_objId == str(gt['objId']):
|
||||
self.scores['rgs'].append(1)
|
||||
else:
|
||||
self.scores['rgs'].append(0)
|
||||
# navigation - success or not
|
||||
end_viewpoint_id = gt['scan'] + '_' + final_position
|
||||
if self.objProposals.__contains__(end_viewpoint_id):
|
||||
if str(gt['objId']) in self.objProposals[end_viewpoint_id]['objId']:
|
||||
self.scores['visible'].append(1)
|
||||
else:
|
||||
self.scores['visible'].append(0)
|
||||
else:
|
||||
self.scores['visible'].append(0)
|
||||
# navigation - oracle success or not
|
||||
oracle_succ = 0
|
||||
for passvp in path:
|
||||
oracle_viewpoint_id = gt['scan'] + '_' + passvp[0]
|
||||
if self.objProposals.__contains__(oracle_viewpoint_id):
|
||||
if str(gt['objId']) in self.objProposals[oracle_viewpoint_id]['objId']:
|
||||
oracle_succ = 1
|
||||
break
|
||||
self.scores['oracle_visible'].append(oracle_succ)
|
||||
|
||||
|
||||
# # if self.scores['nav_errors'][-1] < self.error_margin:
|
||||
# # print('item', item)
|
||||
# ndtw_path = [k[0] for k in item['trajectory']]
|
||||
# # print('path', ndtw_path)
|
||||
#
|
||||
# path_id = item['instr_id'][:-2]
|
||||
# # print('path id', path_id)
|
||||
# path_scan_id, path_ref = self.scan_gts[path_id]
|
||||
# # print('path_scan_id', path_scan_id)
|
||||
# # print('path_ref', path_ref)
|
||||
#
|
||||
# path_act = []
|
||||
# for jdx, pid in enumerate(ndtw_path):
|
||||
# if jdx != 0:
|
||||
# if pid != path_act[-1]:
|
||||
# path_act.append(pid)
|
||||
# else:
|
||||
# path_act.append(pid)
|
||||
# # print('path act', path_act)
|
||||
#
|
||||
# ndtw_score = self.ndtw_criterion[path_scan_id](path_act, path_ref, metric='ndtw')
|
||||
# ndtw_scores.append(ndtw_score)
|
||||
# print('nDTW score: ', np.average(ndtw_scores))
|
||||
|
||||
def score(self, output_file):
|
||||
''' Evaluate each agent trajectory based on how close it got to the goal location '''
|
||||
self.scores = defaultdict(list)
|
||||
@ -86,27 +150,89 @@ class Evaluation(object):
|
||||
# Check against expected ids
|
||||
if item['instr_id'] in instr_ids:
|
||||
instr_ids.remove(item['instr_id'])
|
||||
self._score_item(item['instr_id'], item['trajectory'])
|
||||
self._score_item(item['instr_id'], item['trajectory'], item['ref'])
|
||||
|
||||
if 'train' not in self.splits: # Exclude the training from this. (Because training eval may be partial)
|
||||
assert len(instr_ids) == 0, 'Missing %d of %d instruction ids from %s - not in %s'\
|
||||
% (len(instr_ids), len(self.instr_ids), ",".join(self.splits), output_file)
|
||||
assert len(self.scores['nav_errors']) == len(self.instr_ids)
|
||||
assert len(self.scores['visible']) == len(self.instr_ids)
|
||||
|
||||
score_summary = {
|
||||
'nav_error': np.average(self.scores['nav_errors']),
|
||||
'oracle_error': np.average(self.scores['oracle_errors']),
|
||||
'steps': np.average(self.scores['trajectory_steps']),
|
||||
'lengths': np.average(self.scores['trajectory_lengths'])
|
||||
}
|
||||
num_successes = len([i for i in self.scores['nav_errors'] if i < self.error_margin])
|
||||
score_summary['success_rate'] = float(num_successes)/float(len(self.scores['nav_errors']))
|
||||
oracle_successes = len([i for i in self.scores['oracle_errors'] if i < self.error_margin])
|
||||
score_summary['oracle_rate'] = float(oracle_successes)/float(len(self.scores['oracle_errors']))
|
||||
end_successes = sum(self.scores['visible'])
|
||||
score_summary['success_rate'] = float(end_successes) / float(len(self.scores['visible']))
|
||||
oracle_successes = sum(self.scores['oracle_visible'])
|
||||
score_summary['oracle_rate'] = float(oracle_successes) / float(len(self.scores['oracle_visible']))
|
||||
|
||||
spl = [float(error < self.error_margin) * l / max(l, p, 0.01)
|
||||
for error, p, l in
|
||||
zip(self.scores['nav_errors'], self.scores['trajectory_lengths'], self.scores['shortest_lengths'])
|
||||
spl = [float(visible == 1) * l / max(l, p, 0.01)
|
||||
for visible, p, l in
|
||||
zip(self.scores['visible'], self.scores['trajectory_lengths'], self.scores['shortest_lengths'])
|
||||
]
|
||||
score_summary['spl'] = np.average(spl)
|
||||
|
||||
assert len(self.scores['rgs']) == len(self.instr_ids)
|
||||
num_rgs = sum(self.scores['rgs'])
|
||||
score_summary['rgs'] = float(num_rgs)/float(len(self.scores['rgs']))
|
||||
|
||||
rgspl = [float(rgsi == 1) * l / max(l, p)
|
||||
for rgsi, p, l in
|
||||
zip(self.scores['rgs'], self.scores['trajectory_lengths'], self.scores['shortest_lengths'])
|
||||
]
|
||||
score_summary['rgspl'] = np.average(rgspl)
|
||||
|
||||
return score_summary, self.scores
|
||||
|
||||
def bleu_score(self, path2inst):
|
||||
from bleu import compute_bleu
|
||||
refs = []
|
||||
candidates = []
|
||||
for path_id, inst in path2inst.items():
|
||||
path_id = str(path_id)
|
||||
assert path_id in self.gt
|
||||
# There are three references
|
||||
refs.append([self.tok.split_sentence(sent) for sent in self.gt[path_id]['instructions']])
|
||||
candidates.append([self.tok.index_to_word[word_id] for word_id in inst])
|
||||
|
||||
tuple = compute_bleu(refs, candidates, smooth=False)
|
||||
bleu_score = tuple[0]
|
||||
precisions = tuple[1]
|
||||
|
||||
return bleu_score, precisions
|
||||
|
||||
|
||||
RESULT_DIR = 'tasks/R2R/results/'
|
||||
|
||||
def eval_simple_agents():
|
||||
''' Run simple baselines on each split. '''
|
||||
for split in ['train', 'val_seen', 'val_unseen', 'test']:
|
||||
env = R2RBatch(None, batch_size=1, splits=[split])
|
||||
ev = Evaluation([split])
|
||||
|
||||
for agent_type in ['Stop', 'Shortest', 'Random']:
|
||||
outfile = '%s%s_%s_agent.json' % (RESULT_DIR, split, agent_type.lower())
|
||||
agent = BaseAgent.get_agent(agent_type)(env, outfile)
|
||||
agent.test()
|
||||
agent.write_results()
|
||||
score_summary, _ = ev.score(outfile)
|
||||
print('\n%s' % agent_type)
|
||||
pp.pprint(score_summary)
|
||||
|
||||
|
||||
def eval_seq2seq():
|
||||
''' Eval sequence to sequence models on val splits (iteration selected from training error) '''
|
||||
outfiles = [
|
||||
RESULT_DIR + 'seq2seq_teacher_imagenet_%s_iter_5000.json',
|
||||
RESULT_DIR + 'seq2seq_sample_imagenet_%s_iter_20000.json'
|
||||
]
|
||||
for outfile in outfiles:
|
||||
for split in ['val_seen', 'val_unseen']:
|
||||
ev = Evaluation([split])
|
||||
score_summary, _ = ev.score(outfile % split)
|
||||
print('\n%s' % outfile)
|
||||
pp.pprint(score_summary)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
eval_simple_agents()
|
||||
|
||||
273
r2r_src/model.py
Normal file
273
r2r_src/model.py
Normal file
@ -0,0 +1,273 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.autograd import Variable
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
|
||||
from param import args
|
||||
|
||||
from vlnbert.vlnbert_model import get_vlnbert_models
|
||||
|
||||
class VLNBERT(nn.Module):
|
||||
def __init__(self, directions=4, feature_size=2048+128):
|
||||
super(VLNBERT, self).__init__()
|
||||
print('\nInitalizing the VLN-BERT model ...')
|
||||
|
||||
self.vln_bert = get_vlnbert_models(config=None) # initialize the VLN-BERT
|
||||
self.vln_bert.config.directions = directions
|
||||
|
||||
hidden_size = self.vln_bert.config.hidden_size
|
||||
layer_norm_eps = self.vln_bert.config.layer_norm_eps
|
||||
|
||||
self.action_state_project = nn.Sequential(
|
||||
nn.Linear(hidden_size+args.angle_feat_size, hidden_size), nn.Tanh())
|
||||
self.action_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps)
|
||||
|
||||
self.obj_pos_encode = nn.Linear(5, args.angle_feat_size, bias=True)
|
||||
self.obj_projection = nn.Linear(feature_size+args.angle_feat_size, hidden_size, bias=True)
|
||||
self.obj_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps)
|
||||
|
||||
self.drop_env = nn.Dropout(p=args.featdropout)
|
||||
self.img_projection = nn.Linear(feature_size, hidden_size, bias=True)
|
||||
self.cand_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps)
|
||||
|
||||
self.state_proj = nn.Linear(hidden_size*2, hidden_size, bias=True)
|
||||
self.state_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps)
|
||||
|
||||
def forward(self, mode, sentence, token_type_ids=None, attention_mask=None,
|
||||
lang_mask=None, vis_mask=None, obj_mask=None,
|
||||
position_ids=None, action_feats=None, pano_feats=None, cand_feats=None,
|
||||
obj_feats=None, obj_pos=None, already_dropfeat=False):
|
||||
|
||||
if mode == 'language':
|
||||
init_state, encoded_sentence = self.vln_bert(mode, sentence, position_ids=position_ids,
|
||||
token_type_ids=token_type_ids, attention_mask=attention_mask, lang_mask=lang_mask)
|
||||
|
||||
return init_state, encoded_sentence
|
||||
|
||||
elif mode == 'visual':
|
||||
|
||||
state_action_embed = torch.cat((sentence[:,0,:], action_feats), 1)
|
||||
state_with_action = self.action_state_project(state_action_embed)
|
||||
state_with_action = self.action_LayerNorm(state_with_action)
|
||||
state_feats = torch.cat((state_with_action.unsqueeze(1), sentence[:,1:,:]), dim=1)
|
||||
|
||||
if not already_dropfeat:
|
||||
cand_feats[..., :-args.angle_feat_size] = self.drop_env(cand_feats[..., :-args.angle_feat_size])
|
||||
obj_feats[..., :-args.angle_feat_size] = self.drop_env(obj_feats[..., :-args.angle_feat_size])
|
||||
|
||||
cand_feats_embed = self.img_projection(cand_feats) # [2176 * 768] projection
|
||||
cand_feats_embed = self.cand_LayerNorm(cand_feats_embed)
|
||||
|
||||
obj_feats_embed = self.obj_pos_encode(obj_pos)
|
||||
obj_feats_concat = torch.cat((obj_feats[..., :-args.angle_feat_size], obj_feats_embed, obj_feats[..., -args.angle_feat_size:]), dim=-1)
|
||||
obj_feats_embed = self.obj_projection(obj_feats_concat)
|
||||
obj_feats_embed = self.obj_LayerNorm(obj_feats_embed)
|
||||
|
||||
cand_obj_feats_embed = torch.cat((cand_feats_embed, obj_feats_embed), dim=1)
|
||||
|
||||
# logit is the attention scores over the candidate features
|
||||
h_t, logit, logit_obj, attended_visual = self.vln_bert(mode,
|
||||
state_feats, attention_mask=attention_mask,
|
||||
lang_mask=lang_mask, vis_mask=vis_mask, obj_mask=obj_mask,
|
||||
img_feats=cand_obj_feats_embed)
|
||||
|
||||
state_output = torch.cat((h_t, attended_visual), dim=-1)
|
||||
state_proj = self.state_proj(state_output)
|
||||
state_proj = self.state_LayerNorm(state_proj)
|
||||
|
||||
return state_proj, logit, logit_obj
|
||||
|
||||
else:
|
||||
ModuleNotFoundError
|
||||
|
||||
class SoftDotAttention(nn.Module):
|
||||
'''Soft Dot Attention.
|
||||
|
||||
Ref: http://www.aclweb.org/anthology/D15-1166
|
||||
Adapted from PyTorch OPEN NMT.
|
||||
'''
|
||||
|
||||
def __init__(self, query_dim, ctx_dim):
|
||||
'''Initialize layer.'''
|
||||
super(SoftDotAttention, self).__init__()
|
||||
self.linear_in = nn.Linear(query_dim, ctx_dim, bias=False)
|
||||
self.sm = nn.Softmax()
|
||||
self.linear_out = nn.Linear(query_dim + ctx_dim, query_dim, bias=False)
|
||||
self.tanh = nn.Tanh()
|
||||
|
||||
def forward(self, h, context, mask=None,
|
||||
output_tilde=True, output_prob=True, input_project=True):
|
||||
'''Propagate h through the network.
|
||||
|
||||
h: batch x dim
|
||||
context: batch x seq_len x dim
|
||||
mask: batch x seq_len indices to be masked
|
||||
'''
|
||||
if input_project:
|
||||
target = self.linear_in(h).unsqueeze(2) # batch x dim x 1
|
||||
else:
|
||||
target = h.unsqueeze(2) # batch x dim x 1
|
||||
|
||||
# Get attention
|
||||
attn = torch.bmm(context, target).squeeze(2) # batch x seq_len
|
||||
logit = attn
|
||||
|
||||
if mask is not None:
|
||||
# -Inf masking prior to the softmax
|
||||
attn.masked_fill_(mask, -float('inf'))
|
||||
attn = self.sm(attn) # There will be a bug here, but it's actually a problem in torch source code.
|
||||
attn3 = attn.view(attn.size(0), 1, attn.size(1)) # batch x 1 x seq_len
|
||||
|
||||
weighted_context = torch.bmm(attn3, context).squeeze(1) # batch x dim
|
||||
if not output_prob:
|
||||
attn = logit
|
||||
if output_tilde:
|
||||
h_tilde = torch.cat((weighted_context, h), 1)
|
||||
h_tilde = self.tanh(self.linear_out(h_tilde))
|
||||
return h_tilde, attn
|
||||
else:
|
||||
return weighted_context, attn
|
||||
|
||||
class BertLayerNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-12):
|
||||
"""Construct a layernorm module in the TF style (epsilon inside the square root).
|
||||
"""
|
||||
super(BertLayerNorm, self).__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.bias = nn.Parameter(torch.zeros(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, x):
|
||||
u = x.mean(-1, keepdim=True)
|
||||
s = (x - u).pow(2).mean(-1, keepdim=True)
|
||||
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
|
||||
return self.weight * x + self.bias
|
||||
|
||||
class AttnDecoderLSTM(nn.Module):
|
||||
''' An unrolled LSTM with attention over instructions for decoding navigation actions. '''
|
||||
|
||||
def __init__(self, hidden_size, dropout_ratio, feature_size=2048+4):
|
||||
super(AttnDecoderLSTM, self).__init__()
|
||||
self.drop = nn.Dropout(p=dropout_ratio)
|
||||
self.drop_env = nn.Dropout(p=args.featdropout)
|
||||
self.candidate_att_layer = SoftDotAttention(768, feature_size) # 768 is the output feature dimension from BERT
|
||||
|
||||
def forward(self, h_t, cand_feat,
|
||||
already_dropfeat=False):
|
||||
|
||||
if not already_dropfeat:
|
||||
cand_feat[..., :-args.angle_feat_size] = self.drop_env(cand_feat[..., :-args.angle_feat_size])
|
||||
|
||||
_, logit = self.candidate_att_layer(h_t, cand_feat, output_prob=False)
|
||||
|
||||
return logit
|
||||
|
||||
|
||||
class Critic(nn.Module):
|
||||
def __init__(self):
|
||||
super(Critic, self).__init__()
|
||||
self.state2value = nn.Sequential(
|
||||
nn.Linear(768, args.rnn_dim),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(args.dropout),
|
||||
nn.Linear(args.rnn_dim, 1),
|
||||
)
|
||||
|
||||
def forward(self, state):
|
||||
return self.state2value(state).squeeze()
|
||||
|
||||
class SpeakerEncoder(nn.Module):
|
||||
def __init__(self, feature_size, hidden_size, dropout_ratio, bidirectional):
|
||||
super().__init__()
|
||||
self.num_directions = 2 if bidirectional else 1
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = 1
|
||||
self.feature_size = feature_size
|
||||
|
||||
if bidirectional:
|
||||
print("BIDIR in speaker encoder!!")
|
||||
|
||||
self.lstm = nn.LSTM(feature_size, self.hidden_size // self.num_directions, self.num_layers,
|
||||
batch_first=True, dropout=dropout_ratio, bidirectional=bidirectional)
|
||||
self.drop = nn.Dropout(p=dropout_ratio)
|
||||
self.drop3 = nn.Dropout(p=args.featdropout)
|
||||
self.attention_layer = SoftDotAttention(self.hidden_size, feature_size)
|
||||
|
||||
self.post_lstm = nn.LSTM(self.hidden_size, self.hidden_size // self.num_directions, self.num_layers,
|
||||
batch_first=True, dropout=dropout_ratio, bidirectional=bidirectional)
|
||||
|
||||
def forward(self, action_embeds, feature, lengths, already_dropfeat=False):
|
||||
"""
|
||||
:param action_embeds: (batch_size, length, 2052). The feature of the view
|
||||
:param feature: (batch_size, length, 36, 2052). The action taken (with the image feature)
|
||||
:param lengths: Not used in it
|
||||
:return: context with shape (batch_size, length, hidden_size)
|
||||
"""
|
||||
x = action_embeds
|
||||
if not already_dropfeat:
|
||||
x[..., :-args.angle_feat_size] = self.drop3(x[..., :-args.angle_feat_size]) # Do not dropout the spatial features
|
||||
|
||||
# LSTM on the action embed
|
||||
ctx, _ = self.lstm(x)
|
||||
ctx = self.drop(ctx)
|
||||
|
||||
# Att and Handle with the shape
|
||||
batch_size, max_length, _ = ctx.size()
|
||||
if not already_dropfeat:
|
||||
feature[..., :-args.angle_feat_size] = self.drop3(feature[..., :-args.angle_feat_size]) # Dropout the image feature
|
||||
x, _ = self.attention_layer( # Attend to the feature map
|
||||
ctx.contiguous().view(-1, self.hidden_size), # (batch, length, hidden) --> (batch x length, hidden)
|
||||
feature.view(batch_size * max_length, -1, self.feature_size), # (batch, length, # of images, feature_size) --> (batch x length, # of images, feature_size)
|
||||
)
|
||||
x = x.view(batch_size, max_length, -1)
|
||||
x = self.drop(x)
|
||||
|
||||
# Post LSTM layer
|
||||
x, _ = self.post_lstm(x)
|
||||
x = self.drop(x)
|
||||
|
||||
return x
|
||||
|
||||
class SpeakerDecoder(nn.Module):
|
||||
def __init__(self, vocab_size, embedding_size, padding_idx, hidden_size, dropout_ratio):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.embedding = torch.nn.Embedding(vocab_size, embedding_size, padding_idx)
|
||||
self.lstm = nn.LSTM(embedding_size, hidden_size, batch_first=True)
|
||||
self.drop = nn.Dropout(dropout_ratio)
|
||||
self.attention_layer = SoftDotAttention(hidden_size, hidden_size)
|
||||
self.projection = nn.Linear(hidden_size, vocab_size)
|
||||
self.baseline_projection = nn.Sequential(
|
||||
nn.Linear(hidden_size, 128),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(dropout_ratio),
|
||||
nn.Linear(128, 1)
|
||||
)
|
||||
|
||||
def forward(self, words, ctx, ctx_mask, h0, c0):
|
||||
embeds = self.embedding(words)
|
||||
embeds = self.drop(embeds)
|
||||
x, (h1, c1) = self.lstm(embeds, (h0, c0))
|
||||
|
||||
x = self.drop(x)
|
||||
|
||||
# Get the size
|
||||
batchXlength = words.size(0) * words.size(1)
|
||||
multiplier = batchXlength // ctx.size(0) # By using this, it also supports the beam-search
|
||||
|
||||
# Att and Handle with the shape
|
||||
# Reshaping x <the output> --> (b(word)*l(word), r)
|
||||
# Expand the ctx from (b, a, r) --> (b(word)*l(word), a, r)
|
||||
# Expand the ctx_mask (b, a) --> (b(word)*l(word), a)
|
||||
x, _ = self.attention_layer(
|
||||
x.contiguous().view(batchXlength, self.hidden_size),
|
||||
ctx.unsqueeze(1).expand(-1, multiplier, -1, -1).contiguous(). view(batchXlength, -1, self.hidden_size),
|
||||
mask=ctx_mask.unsqueeze(1).expand(-1, multiplier, -1).contiguous().view(batchXlength, -1)
|
||||
)
|
||||
x = x.view(words.size(0), words.size(1), self.hidden_size)
|
||||
|
||||
# Output the prediction logit
|
||||
x = self.drop(x)
|
||||
logit = self.projection(x)
|
||||
|
||||
return logit, h1, c1
|
||||
@ -7,45 +7,58 @@ class Param:
|
||||
self.parser = argparse.ArgumentParser(description="")
|
||||
|
||||
# General
|
||||
self.parser.add_argument('--test_only', type=int, default=0, help='fast mode for testing')
|
||||
|
||||
self.parser.add_argument('--iters', type=int, default=300000, help='training iterations')
|
||||
self.parser.add_argument('--name', type=str, default='default', help='experiment id')
|
||||
self.parser.add_argument('--vlnbert', type=str, default='oscar', help='oscar or prevalent')
|
||||
self.parser.add_argument('--train', type=str, default='listener')
|
||||
self.parser.add_argument('--iters', type=int, default=100000)
|
||||
self.parser.add_argument('--name', type=str, default='default')
|
||||
self.parser.add_argument('--train', type=str, default='speaker')
|
||||
self.parser.add_argument("--load_pretrain", default=None)
|
||||
# --load_pretrain snap/vlnp-v3.0.0/state_dict/Iter_100000
|
||||
self.parser.add_argument('--test_only', type=int, default=0)
|
||||
self.parser.add_argument('--description', type=str, default='no description\n')
|
||||
|
||||
# Data preparation
|
||||
self.parser.add_argument('--maxInput', type=int, default=80, help="max input instruction")
|
||||
self.parser.add_argument('--maxAction', type=int, default=15, help='Max Action sequence')
|
||||
self.parser.add_argument('--batchSize', type=int, default=8)
|
||||
# self.parser.add_argument('--maxDecode', type=int, default=120, help="max input instruction")
|
||||
self.parser.add_argument('--maxAction', type=int, default=20, help='Max Action sequence')
|
||||
self.parser.add_argument('--batchSize', type=int, default=64)
|
||||
self.parser.add_argument('--ignoreid', type=int, default=-100)
|
||||
self.parser.add_argument('--directions', type=int, default=4, help='agent-centered visual directions') # fix to 4 for now
|
||||
self.parser.add_argument('--feature_size', type=int, default=2048)
|
||||
self.parser.add_argument("--loadOptim",action="store_const", default=False, const=True)
|
||||
|
||||
# Load the model from
|
||||
self.parser.add_argument("--load", default=None, help='path of the trained model')
|
||||
self.parser.add_argument("--speaker", default=None)
|
||||
self.parser.add_argument("--listener", default=None)
|
||||
self.parser.add_argument("--load", default=None)
|
||||
# snap/vlnb-v3.0.0/state_dict/Iter_100000
|
||||
|
||||
# Augmented Paths from
|
||||
# More Paths from
|
||||
self.parser.add_argument("--aug", default=None)
|
||||
|
||||
# Listener Model Config
|
||||
self.parser.add_argument("--zeroInit", dest='zero_init', action='store_const', default=False, const=True)
|
||||
self.parser.add_argument("--mlWeight", dest='ml_weight', type=float, default=0.20)
|
||||
self.parser.add_argument("--mlWeight", dest='ml_weight', type=float, default=0.05)
|
||||
self.parser.add_argument("--teacherWeight", dest='teacher_weight', type=float, default=1.)
|
||||
self.parser.add_argument("--features", type=str, default='places365')
|
||||
self.parser.add_argument("--accumulateGrad", dest='accumulate_grad', type=int, default=0)
|
||||
self.parser.add_argument("--features", type=str, default='imagenet')
|
||||
|
||||
# Dropout Param
|
||||
self.parser.add_argument('--dropout', type=float, default=0.5)
|
||||
# Env Dropout Param
|
||||
self.parser.add_argument('--featdropout', type=float, default=0.3)
|
||||
|
||||
# SSL configuration
|
||||
self.parser.add_argument("--selfTrain", dest='self_train', action='store_const', default=False, const=True)
|
||||
|
||||
# Submision configuration
|
||||
self.parser.add_argument("--submit", type=int, default=0)
|
||||
self.parser.add_argument("--candidates", type=int, default=1)
|
||||
self.parser.add_argument("--paramSearch", dest='param_search', action='store_const', default=False, const=True)
|
||||
self.parser.add_argument("--submit", action='store_const', default=False, const=True)
|
||||
self.parser.add_argument("--beam", action="store_const", default=False, const=True)
|
||||
self.parser.add_argument("--alpha", type=float, default=0.5)
|
||||
|
||||
# Training Configurations
|
||||
self.parser.add_argument('--optim', type=str, default='rms') # rms, adam
|
||||
self.parser.add_argument('--lr', type=float, default=0.00001, help="the learning rate")
|
||||
self.parser.add_argument('--lr', type=float, default=0.0001, help="The learning rate")
|
||||
self.parser.add_argument('--decay', dest='weight_decay', type=float, default=0.)
|
||||
self.parser.add_argument('--dropout', type=float, default=0.5)
|
||||
self.parser.add_argument('--feedback', type=str, default='sample',
|
||||
help='How to choose next position, one of ``teacher``, ``sample`` and ``argmax``')
|
||||
self.parser.add_argument('--teacher', type=str, default='final',
|
||||
@ -53,6 +66,20 @@ class Param:
|
||||
self.parser.add_argument('--epsilon', type=float, default=0.1)
|
||||
|
||||
# Model hyper params:
|
||||
self.parser.add_argument('--rnnDim', dest="rnn_dim", type=int, default=512)
|
||||
self.parser.add_argument('--wemb', type=int, default=256)
|
||||
self.parser.add_argument('--aemb', type=int, default=64)
|
||||
self.parser.add_argument('--proj', type=int, default=512)
|
||||
self.parser.add_argument("--fast", dest="fast_train", action="store_const", default=False, const=True)
|
||||
self.parser.add_argument("--valid", action="store_const", default=False, const=True)
|
||||
self.parser.add_argument("--candidate", dest="candidate_mask",
|
||||
action="store_const", default=False, const=True)
|
||||
|
||||
self.parser.add_argument("--bidir", type=bool, default=True) # This is not full option
|
||||
self.parser.add_argument("--encode", type=str, default="word") # sub, word, sub_ctx
|
||||
self.parser.add_argument("--subout", dest="sub_out", type=str, default="tanh") # tanh, max
|
||||
self.parser.add_argument("--attn", type=str, default="soft") # soft, mono, shift, dis_shift
|
||||
|
||||
self.parser.add_argument("--angleFeatSize", dest="angle_feat_size", type=int, default=4)
|
||||
|
||||
# A2C
|
||||
@ -78,11 +105,16 @@ class Param:
|
||||
|
||||
param = Param()
|
||||
args = param.args
|
||||
args.TRAIN_VOCAB = 'tasks/R2R/data/train_vocab.txt'
|
||||
args.TRAINVAL_VOCAB = 'tasks/R2R/data/trainval_vocab.txt'
|
||||
|
||||
args.description = args.name
|
||||
args.IMAGENET_FEATURES = 'img_features/ResNet-152-imagenet.tsv'
|
||||
args.CANDIDATE_FEATURES = 'img_features/ResNet-152-candidate.tsv'
|
||||
args.features_fast = 'img_features/ResNet-152-imagenet-fast.tsv'
|
||||
args.log_dir = 'snap/%s' % args.name
|
||||
|
||||
args.directions = args.directions * 3 # times 3 for up, horizon and bottom
|
||||
|
||||
if not os.path.exists(args.log_dir):
|
||||
os.makedirs(args.log_dir)
|
||||
DEBUG_FILE = open(os.path.join('snap', args.name, "debug.log"), 'w')
|
||||
|
||||
@ -6,8 +6,10 @@ import json
|
||||
import random
|
||||
import numpy as np
|
||||
from collections import defaultdict
|
||||
# from speaker import Speaker
|
||||
|
||||
from utils import read_vocab, write_vocab, build_vocab, padding_idx, timeSince, read_img_features, print_progress
|
||||
# from utils import Tokenizer
|
||||
import utils
|
||||
from env import R2RBatch
|
||||
from agent import Seq2SeqAgent
|
||||
@ -18,12 +20,15 @@ import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
from vlnbert.vlnbert_init import get_tokenizer
|
||||
from vlnbert.vlnbert_model import get_tokenizer
|
||||
|
||||
log_dir = 'snap/%s' % args.name
|
||||
if not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir)
|
||||
|
||||
TRAIN_VOCAB = 'data/train_vocab.txt'
|
||||
TRAINVAL_VOCAB = 'data/trainval_vocab.txt'
|
||||
|
||||
IMAGENET_FEATURES = 'img_features/ResNet-152-imagenet.tsv'
|
||||
PLACE365_FEATURES = 'img_features/ResNet-152-places365.tsv'
|
||||
|
||||
@ -38,7 +43,7 @@ print(args); print('')
|
||||
|
||||
|
||||
''' train the listener '''
|
||||
def train(train_env, tok, n_iters, log_every=2000, val_envs={}, aug_env=None):
|
||||
def train(train_env, tok, n_iters, log_every=1000, val_envs={}, aug_env=None):
|
||||
writer = SummaryWriter(log_dir=log_dir)
|
||||
listner = Seq2SeqAgent(train_env, "", tok, args.maxAction)
|
||||
|
||||
@ -54,6 +59,10 @@ def train(train_env, tok, n_iters, log_every=2000, val_envs={}, aug_env=None):
|
||||
else:
|
||||
load_iter = listner.load(os.path.join(args.load))
|
||||
print("\nLOAD the model from {}, iteration ".format(args.load, load_iter))
|
||||
# elif args.load_pretrain is not None:
|
||||
# print("LOAD the pretrained model from %s" % args.load_pretrain)
|
||||
# listner.load_pretrain(os.path.join(args.load_pretrain))
|
||||
# print("Pretrained model loaded\n")
|
||||
|
||||
start = time.time()
|
||||
print('\nListener training starts, start iteration: %s' % str(start_iter))
|
||||
@ -66,9 +75,13 @@ def train(train_env, tok, n_iters, log_every=2000, val_envs={}, aug_env=None):
|
||||
iter = idx + interval
|
||||
|
||||
# Train for log_every interval
|
||||
if aug_env is None:
|
||||
if aug_env is None: # The default training process
|
||||
listner.env = train_env
|
||||
listner.train(interval, feedback=feedback_method) # Train interval iters
|
||||
print('-----------default training process no accumulate_grad')
|
||||
else:
|
||||
if args.accumulate_grad: # default False
|
||||
None
|
||||
else:
|
||||
jdx_length = len(range(interval // 2))
|
||||
for jdx in range(interval // 2):
|
||||
@ -87,25 +100,26 @@ def train(train_env, tok, n_iters, log_every=2000, val_envs={}, aug_env=None):
|
||||
# Log the training stats to tensorboard
|
||||
total = max(sum(listner.logs['total']), 1)
|
||||
length = max(len(listner.logs['critic_loss']), 1)
|
||||
critic_loss = sum(listner.logs['critic_loss']) / total
|
||||
RL_loss = sum(listner.logs['RL_loss']) / max(len(listner.logs['RL_loss']), 1)
|
||||
IL_loss = sum(listner.logs['IL_loss']) / max(len(listner.logs['IL_loss']), 1)
|
||||
entropy = sum(listner.logs['entropy']) / total
|
||||
critic_loss = sum(listner.logs['critic_loss']) / total #/ length / args.batchSize
|
||||
entropy = sum(listner.logs['entropy']) / total #/ length / args.batchSize
|
||||
predict_loss = sum(listner.logs['us_loss']) / max(len(listner.logs['us_loss']), 1)
|
||||
writer.add_scalar("loss/critic", critic_loss, idx)
|
||||
writer.add_scalar("policy_entropy", entropy, idx)
|
||||
writer.add_scalar("loss/RL_loss", RL_loss, idx)
|
||||
writer.add_scalar("loss/IL_loss", IL_loss, idx)
|
||||
writer.add_scalar("loss/unsupervised", predict_loss, idx)
|
||||
writer.add_scalar("total_actions", total, idx)
|
||||
writer.add_scalar("max_length", length, idx)
|
||||
# print("total_actions", total, ", max_length", length)
|
||||
print("total_actions", total, ", max_length", length)
|
||||
|
||||
# Run validation
|
||||
loss_str = "iter {}".format(iter)
|
||||
for env_name, (env, evaluator) in val_envs.items():
|
||||
listner.env = env
|
||||
|
||||
# Get validation loss under the same conditions as training
|
||||
iters = None if args.fast_train or env_name != 'train' else 20 # 20 * 64 = 1280
|
||||
|
||||
# Get validation distance from goal under test evaluation conditions
|
||||
listner.test(use_dropout=False, feedback='argmax', iters=None)
|
||||
listner.test(use_dropout=False, feedback='argmax', iters=iters)
|
||||
result = listner.get_results()
|
||||
score_summary, _ = evaluator.score(result)
|
||||
loss_str += ", %s " % env_name
|
||||
@ -125,6 +139,7 @@ def train(train_env, tok, n_iters, log_every=2000, val_envs={}, aug_env=None):
|
||||
record_file.write(loss_str + '\n')
|
||||
record_file.close()
|
||||
|
||||
|
||||
for env_name in best_val:
|
||||
if best_val[env_name]['update']:
|
||||
best_val[env_name]['state'] = 'Iter %d %s' % (iter, loss_str)
|
||||
@ -175,16 +190,26 @@ def valid(train_env, tok, val_envs={}):
|
||||
sort_keys=True, indent=4, separators=(',', ': ')
|
||||
)
|
||||
|
||||
|
||||
def setup():
|
||||
torch.manual_seed(1)
|
||||
torch.cuda.manual_seed(1)
|
||||
random.seed(0)
|
||||
np.random.seed(0)
|
||||
# Check for vocabs
|
||||
if not os.path.exists(TRAIN_VOCAB):
|
||||
write_vocab(build_vocab(splits=['train']), TRAIN_VOCAB)
|
||||
# if not os.path.exists(TRAINVAL_VOCAB):
|
||||
# write_vocab(build_vocab(splits=['train','val_seen','val_unseen']), TRAINVAL_VOCAB)
|
||||
|
||||
|
||||
def train_val(test_only=False):
|
||||
''' Train on the training set, and validate on seen and unseen splits. '''
|
||||
setup()
|
||||
tok = get_tokenizer(args)
|
||||
# Create a batch training environment that will also preprocess text
|
||||
vocab = read_vocab(TRAIN_VOCAB)
|
||||
# tok = Tokenizer(vocab=vocab, encoding_length=args.maxInput)
|
||||
tok = get_tokenizer()
|
||||
|
||||
feat_dict = read_img_features(features, test_only=test_only)
|
||||
|
||||
@ -193,15 +218,18 @@ def train_val(test_only=False):
|
||||
val_env_names = ['val_train_seen']
|
||||
else:
|
||||
featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])
|
||||
val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']
|
||||
val_env_names = ['val_seen', 'val_unseen']
|
||||
# val_env_names = ['val_unseen']
|
||||
|
||||
train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok)
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
if args.submit:
|
||||
val_env_names.append('test')
|
||||
else:
|
||||
pass
|
||||
#val_env_names.append('train')
|
||||
|
||||
val_envs = OrderedDict(
|
||||
((split,
|
||||
@ -226,7 +254,8 @@ def train_val_augment(test_only=False):
|
||||
setup()
|
||||
|
||||
# Create a batch training environment that will also preprocess text
|
||||
tok_bert = get_tokenizer(args)
|
||||
vocab = read_vocab(TRAIN_VOCAB)
|
||||
tok_bert = get_tokenizer()
|
||||
|
||||
# Load the env img features
|
||||
feat_dict = read_img_features(features, test_only=test_only)
|
||||
|
||||
@ -3,7 +3,6 @@
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
sys.path.append('Matterport_Simulator/build/')
|
||||
import MatterSim
|
||||
import string
|
||||
import json
|
||||
@ -69,10 +68,9 @@ def load_datasets(splits):
|
||||
# if split in ['train', 'val_seen', 'val_unseen', 'test',
|
||||
# 'val_unseen_half1', 'val_unseen_half2', 'val_seen_half1', 'val_seen_half2']: # Add two halves for sanity check
|
||||
if "/" not in split:
|
||||
with open('data/R2R_%s.json' % split) as f:
|
||||
with open('data/REVERIE_%s.json' % split) as f:
|
||||
new_data = json.load(f)
|
||||
else:
|
||||
print('\nLoading prevalent data for pretraining...')
|
||||
with open(split) as f:
|
||||
new_data = json.load(f)
|
||||
|
||||
@ -97,12 +95,11 @@ def pad_instr_tokens(instr_tokens, maxlength=20):
|
||||
instr_tokens = instr_tokens[:(maxlength-2)]
|
||||
|
||||
instr_tokens = ['[CLS]'] + instr_tokens + ['[SEP]']
|
||||
num_words = len(instr_tokens) # - 1 # include [SEP]
|
||||
instr_tokens += ['[PAD]'] * (maxlength-len(instr_tokens))
|
||||
|
||||
assert len(instr_tokens) == maxlength
|
||||
|
||||
return instr_tokens, num_words
|
||||
return instr_tokens
|
||||
|
||||
|
||||
class Tokenizer(object):
|
||||
@ -127,7 +124,6 @@ class Tokenizer(object):
|
||||
assert self.vocab_size() == old+1
|
||||
print("OLD_VOCAB_SIZE", old)
|
||||
print("VOCAB_SIZE", self.vocab_size())
|
||||
print("VOACB", len(vocab))
|
||||
|
||||
def finalize(self):
|
||||
"""
|
||||
@ -253,7 +249,7 @@ def read_img_features(feature_store, test_only=False):
|
||||
import base64
|
||||
from tqdm import tqdm
|
||||
|
||||
print("Start loading the image feature ... (~50 seconds)")
|
||||
print("Start loading the image feature ... (~30 seconds)")
|
||||
start = time.time()
|
||||
|
||||
if "detectfeat" in args.features:
|
||||
@ -348,7 +344,7 @@ def new_simulator():
|
||||
sim.setCameraResolution(WIDTH, HEIGHT)
|
||||
sim.setCameraVFOV(math.radians(VFOV))
|
||||
sim.setDiscretizedViewingAngles(True)
|
||||
sim.init()
|
||||
sim.initialize()
|
||||
|
||||
return sim
|
||||
|
||||
@ -359,13 +355,13 @@ def get_point_angle_feature(baseViewId=0):
|
||||
base_heading = (baseViewId % 12) * math.radians(30)
|
||||
for ix in range(36):
|
||||
if ix == 0:
|
||||
sim.newEpisode('ZMojNkEp431', '2f4d90acd4024c269fb0efe49a8ac540', 0, math.radians(-30))
|
||||
sim.newEpisode(['ZMojNkEp431'], ['2f4d90acd4024c269fb0efe49a8ac540'], [0], [math.radians(-30)])
|
||||
elif ix % 12 == 0:
|
||||
sim.makeAction(0, 1.0, 1.0)
|
||||
sim.makeAction([0], [1.0], [1.0])
|
||||
else:
|
||||
sim.makeAction(0, 1.0, 0)
|
||||
sim.makeAction([0], [1.0], [0])
|
||||
|
||||
state = sim.getState()
|
||||
state = sim.getState()[0]
|
||||
assert state.viewIndex == ix
|
||||
|
||||
heading = state.heading - base_heading
|
||||
@ -376,6 +372,32 @@ def get_point_angle_feature(baseViewId=0):
|
||||
def get_all_point_angle_feature():
|
||||
return [get_point_angle_feature(baseViewId) for baseViewId in range(36)]
|
||||
|
||||
def get_centered_visual_features(features, baseViewId):
|
||||
# [0-11 up, 12-23 horizon, 24-35 down]
|
||||
centered_features = np.concatenate((features[24:,:], features[:24,:]), 0)
|
||||
|
||||
baseviewid = baseViewId % 12
|
||||
|
||||
viewid_up = [(baseviewid+delta_viewid)%12 for delta_viewid in [0,3,6,9]]
|
||||
viewid_horizon = [id_+12 for id_ in viewid_up]
|
||||
viewid_down = [id_+12 for id_ in viewid_horizon]
|
||||
|
||||
views_up = centered_features[viewid_up, :]
|
||||
views_horizon = centered_features[viewid_horizon, :]
|
||||
views_down = centered_features[viewid_down, :]
|
||||
|
||||
centered_features = np.concatenate((views_up, views_horizon, views_down), 0) # [12, 2176]
|
||||
|
||||
return centered_features
|
||||
|
||||
def get_obj_local_pos(raw_obj_pos):
|
||||
x1, y1, x2, y2 = raw_obj_pos[0]
|
||||
w = x2 - x1; h = y2 - y1
|
||||
assert (w>0) and (h>0)
|
||||
|
||||
obj_local_pos = np.array([x1/640, y1/480, x2/640, y2/480, w*h/(640*480)])
|
||||
return obj_local_pos
|
||||
|
||||
def add_idx(inst):
|
||||
toks = Tokenizer.split_sentence(inst)
|
||||
return " ".join([str(idx)+tok for idx, tok in enumerate(toks)])
|
||||
@ -560,7 +582,7 @@ def print_progress(iteration, total, prefix='', suffix='', decimals=1, bar_lengt
|
||||
str_format = "{0:." + str(decimals) + "f}"
|
||||
percents = str_format.format(100 * (iteration / float(total)))
|
||||
filled_length = int(round(bar_length * iteration / float(total)))
|
||||
bar = '█' * filled_length + '-' * (bar_length - filled_length)
|
||||
bar = 'LL' * filled_length + '-' * (bar_length - filled_length)
|
||||
|
||||
sys.stdout.write('\r%s |%s| %s%s %s' % (prefix, bar, percents, '%', suffix)),
|
||||
|
||||
@ -570,7 +592,7 @@ def print_progress(iteration, total, prefix='', suffix='', decimals=1, bar_lengt
|
||||
|
||||
def ndtw_initialize():
|
||||
ndtw_criterion = {}
|
||||
scan_gts_dir = 'data/id_paths.json'
|
||||
scan_gts_dir = './data/id_paths.json'
|
||||
with open(scan_gts_dir) as f_:
|
||||
scan_gts = json.load(f_)
|
||||
all_scan_ids = []
|
||||
@ -672,3 +694,45 @@ class DTW(object):
|
||||
|
||||
success = self.distance[prediction[-1]][reference[-1]] <= self.threshold
|
||||
return success * ndtw
|
||||
|
||||
import os.path as osp
|
||||
def loadObjProposals():
|
||||
bboxDir = 'data/BBox'
|
||||
objProposals = {}
|
||||
obj2viewpoint = {}
|
||||
|
||||
for efile in os.listdir(bboxDir):
|
||||
if efile.endswith('.json'):
|
||||
with open(osp.join(bboxDir, efile)) as f:
|
||||
scan = efile.split('_')[0]
|
||||
scanvp, _ = efile.split('.')
|
||||
data = json.load(f)
|
||||
|
||||
# for a viewpoint (for loop not needed)
|
||||
for vp, vv in data.items():
|
||||
# for all visible objects at that viewpoint
|
||||
for objid, objinfo in vv.items():
|
||||
|
||||
if objinfo['visible_pos']:
|
||||
# if such object not already in the dict
|
||||
if obj2viewpoint.__contains__(scan+'_'+objid):
|
||||
if vp not in obj2viewpoint[scan+'_'+objid]:
|
||||
obj2viewpoint[scan+'_'+objid].append(vp)
|
||||
else:
|
||||
obj2viewpoint[scan+'_'+objid] = [vp,]
|
||||
|
||||
# if such object not already in the dict
|
||||
if objProposals.__contains__(scanvp):
|
||||
for ii, bbox in enumerate(objinfo['bbox2d']):
|
||||
objProposals[scanvp]['bbox'].append(bbox)
|
||||
objProposals[scanvp]['visible_pos'].append(objinfo['visible_pos'][ii])
|
||||
objProposals[scanvp]['objId'].append(objid)
|
||||
|
||||
else:
|
||||
objProposals[scanvp] = {'bbox': objinfo['bbox2d'],
|
||||
'visible_pos': objinfo['visible_pos']}
|
||||
objProposals[scanvp]['objId'] = []
|
||||
for _ in objinfo['visible_pos']:
|
||||
objProposals[scanvp]['objId'].append(objid)
|
||||
|
||||
return objProposals, obj2viewpoint
|
||||
|
||||
404
r2r_src/vlnbert/modeling_bert.py
Normal file
404
r2r_src/vlnbert/modeling_bert.py
Normal file
@ -0,0 +1,404 @@
|
||||
# Copyright (c) 2020 Microsoft Corporation. Licensed under the MIT license.
|
||||
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
import logging
|
||||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
import sys
|
||||
|
||||
from transformers.pytorch_transformers.modeling_bert import (BertEmbeddings,
|
||||
BertSelfAttention, BertAttention, BertEncoder, BertLayer,
|
||||
BertSelfOutput, BertIntermediate, BertOutput,
|
||||
BertPooler, BertLayerNorm, BertPreTrainedModel,
|
||||
BertPredictionHeadTransform)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CaptionBertSelfAttention(BertSelfAttention):
|
||||
"""
|
||||
Modified from BertSelfAttention to add support for output_hidden_states.
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(CaptionBertSelfAttention, self).__init__(config)
|
||||
|
||||
def forward(self, hidden_states, attention_mask, head_mask=None,
|
||||
history_state=None):
|
||||
if history_state is not None:
|
||||
x_states = torch.cat([history_state, hidden_states], dim=1)
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
mixed_key_layer = self.key(x_states)
|
||||
mixed_value_layer = self.value(x_states)
|
||||
else: # default
|
||||
mixed_query_layer = self.query(hidden_states) # [24, 95, 768]
|
||||
mixed_key_layer = self.key(hidden_states) # [24, 95, 768]
|
||||
mixed_value_layer = self.value(hidden_states) # [24, 95, 768]
|
||||
|
||||
# transpose into shape [24, 12, 95, 64] as [batch_size, num_heads, seq_length, feature_dim]
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
key_layer = self.transpose_for_scores(mixed_key_layer)
|
||||
value_layer = self.transpose_for_scores(mixed_value_layer)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) # [24, 12, 95, 95]
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
||||
attention_scores = attention_scores + attention_mask
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs) # [24, 12, 95, 95]
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer) # [24, 12, 95, 64]
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(*new_context_layer_shape) # [24, 95, 768]
|
||||
|
||||
outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class CaptionBertAttention(BertAttention):
|
||||
"""
|
||||
Modified from BertAttention to add support for output_hidden_states.
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(CaptionBertAttention, self).__init__(config)
|
||||
self.self = CaptionBertSelfAttention(config)
|
||||
self.output = BertSelfOutput(config)
|
||||
|
||||
def forward(self, input_tensor, attention_mask, head_mask=None,
|
||||
history_state=None):
|
||||
''' transformer processing '''
|
||||
self_outputs = self.self(input_tensor, attention_mask, head_mask, history_state)
|
||||
''' feed-forward network with residule '''
|
||||
attention_output = self.output(self_outputs[0], input_tensor)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class CaptionBertEncoder(BertEncoder):
|
||||
"""
|
||||
Modified from BertEncoder to add support for output_hidden_states.
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(CaptionBertEncoder, self).__init__(config)
|
||||
self.output_attentions = config.output_attentions
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
# 12 Bert layers
|
||||
self.layer = nn.ModuleList([CaptionBertLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
def forward(self, hidden_states, attention_mask, head_mask=None,
|
||||
encoder_history_states=None):
|
||||
all_hidden_states = ()
|
||||
all_attentions = ()
|
||||
|
||||
# iterate over the 12 Bert layers
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
# if self.output_hidden_states: # default False
|
||||
# all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
history_state = None if encoder_history_states is None else encoder_history_states[i] # default None
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states, attention_mask, head_mask[i],
|
||||
history_state)
|
||||
hidden_states = layer_outputs[0] # the output features [24, 95, 768]
|
||||
|
||||
# if self.output_attentions: # default False
|
||||
# all_attentions = all_attentions + (layer_outputs[1],)
|
||||
|
||||
# Add last layer
|
||||
# if self.output_hidden_states: # default False
|
||||
# all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
# if self.output_hidden_states: # default False
|
||||
# outputs = outputs + (all_hidden_states,)
|
||||
# if self.output_attentions: # default False
|
||||
# outputs = outputs + (all_attentions,)
|
||||
|
||||
return outputs # outputs, (hidden states), (attentions)
|
||||
|
||||
|
||||
class CaptionBertLayer(BertLayer):
|
||||
"""
|
||||
Modified from BertLayer to add support for output_hidden_states.
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(CaptionBertLayer, self).__init__(config)
|
||||
self.attention = CaptionBertAttention(config) # one layer of transformer
|
||||
self.intermediate = BertIntermediate(config) # [768 * 3072]
|
||||
self.output = BertOutput(config) # [3072 * 768]
|
||||
|
||||
def forward(self, hidden_states, attention_mask, head_mask=None,
|
||||
history_state=None):
|
||||
attention_outputs = self.attention(hidden_states, attention_mask,
|
||||
head_mask, history_state)
|
||||
|
||||
attention_output = attention_outputs[0]
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
layer_output = self.output(intermediate_output, attention_output)
|
||||
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class BertImgModel(BertPreTrainedModel):
|
||||
""" Expand from BertModel to handle image region features as input
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(BertImgModel, self).__init__(config)
|
||||
self.embeddings = BertEmbeddings(config)
|
||||
self.encoder = CaptionBertEncoder(config)
|
||||
self.pooler = BertPooler(config)
|
||||
|
||||
self.img_dim = config.img_feature_dim
|
||||
logger.info('BertImgModel Image Dimension: {}'.format(self.img_dim))
|
||||
# self.img_feature_type = config.img_feature_type
|
||||
# if hasattr(config, 'use_img_layernorm'):
|
||||
# self.use_img_layernorm = config.use_img_layernorm
|
||||
# else:
|
||||
# self.use_img_layernorm = None
|
||||
|
||||
# if config.img_feature_type == 'dis_code':
|
||||
# self.code_embeddings = nn.Embedding(config.code_voc, config.code_dim, padding_idx=0)
|
||||
# self.img_embedding = nn.Linear(config.code_dim, self.config.hidden_size, bias=True)
|
||||
# elif config.img_feature_type == 'dis_code_t': # transpose
|
||||
# self.code_embeddings = nn.Embedding(config.code_voc, config.code_dim, padding_idx=0)
|
||||
# self.img_embedding = nn.Linear(config.code_size, self.config.hidden_size, bias=True)
|
||||
# elif config.img_feature_type == 'dis_code_scale': # scaled
|
||||
# self.input_embeddings = nn.Linear(config.code_dim, config.code_size, bias=True)
|
||||
# self.code_embeddings = nn.Embedding(config.code_voc, config.code_dim, padding_idx=0)
|
||||
# self.img_embedding = nn.Linear(config.code_dim, self.config.hidden_size, bias=True)
|
||||
# else:
|
||||
self.img_projection = nn.Linear(self.img_dim, self.config.hidden_size, bias=True)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
# if self.use_img_layernorm:
|
||||
# self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.img_layer_norm_eps)
|
||||
|
||||
self.apply(self.init_weights)
|
||||
|
||||
# def _resize_token_embeddings(self, new_num_tokens):
|
||||
# old_embeddings = self.embeddings.word_embeddings
|
||||
# new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
|
||||
# self.embeddings.word_embeddings = new_embeddings
|
||||
# return self.embeddings.word_embeddings
|
||||
|
||||
# def _prune_heads(self, heads_to_prune):
|
||||
# """ Prunes heads of the model.
|
||||
# heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
||||
# See base class PreTrainedModel
|
||||
# """
|
||||
# for layer, heads in heads_to_prune.items():
|
||||
# self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None,
|
||||
position_ids=None, head_mask=None, img_feats=None,
|
||||
encoder_history_states=None):
|
||||
# if attention_mask is None:
|
||||
# attention_mask = torch.ones_like(input_ids)
|
||||
# if token_type_ids is None:
|
||||
# token_type_ids = torch.zeros_like(input_ids)
|
||||
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
if attention_mask.dim() == 2:
|
||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||
elif attention_mask.dim() == 3:
|
||||
extended_attention_mask = attention_mask.unsqueeze(1)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and -10000.0 for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
# if head_mask is not None:
|
||||
# if head_mask.dim() == 1:
|
||||
# head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
||||
# head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) # 12 heads
|
||||
# elif head_mask.dim() == 2:
|
||||
# head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
|
||||
# # switch to float if needed + fp16 compatibility
|
||||
# head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
|
||||
# else:
|
||||
# head_mask = [None] * self.config.num_hidden_layers # 12 heads
|
||||
head_mask = [None] * self.config.num_hidden_layers # 12 heads
|
||||
|
||||
# language embeddings [24, 55, 768]
|
||||
# embedding_output = self.embeddings(input_ids, position_ids=position_ids,
|
||||
# token_type_ids=token_type_ids)
|
||||
language_features = input_ids
|
||||
# if encoder_history_states:
|
||||
# assert img_feats is None, "Cannot take image features while using encoder history states"
|
||||
|
||||
# if img_feats is not None:
|
||||
# if self.img_feature_type == 'dis_code':
|
||||
# code_emb = self.code_embeddings(img_feats)
|
||||
# img_embedding_output = self.img_embedding(code_emb)
|
||||
# elif self.img_feature_type == 'dis_code_t': # transpose
|
||||
# code_emb = self.code_embeddings(img_feats)
|
||||
# code_emb = code_emb.permute(0, 2, 1)
|
||||
# img_embedding_output = self.img_embedding(code_emb)
|
||||
# elif self.img_feature_type == 'dis_code_scale': # left scaled
|
||||
# code_emb = self.code_embeddings(img_feats)
|
||||
# img_embedding_output = self.img_embedding(code_emb)
|
||||
# else: # faster r-cnn
|
||||
img_embedding_output = self.img_projection(img_feats) # [2054 * 768] projection
|
||||
# if self.use_img_layernorm:
|
||||
# img_embedding_output = self.LayerNorm(img_embedding_output)
|
||||
# add dropout on image embedding
|
||||
img_embedding_output = self.dropout(img_embedding_output)
|
||||
|
||||
# concatenate two embeddings
|
||||
concat_embedding_output = torch.cat((language_features, img_embedding_output), 1) # [24, 55+40, 768]
|
||||
|
||||
''' pass to the Transformer layers '''
|
||||
encoder_outputs = self.encoder(concat_embedding_output,
|
||||
extended_attention_mask, head_mask=head_mask,
|
||||
encoder_history_states=encoder_history_states) # [24, 95, 768]
|
||||
|
||||
sequence_output = encoder_outputs[0] # [24, 95, 768]
|
||||
pooled_output = self.pooler(sequence_output) # We "pool" the model by simply taking the hidden state corresponding to the first token [24, 768]
|
||||
|
||||
# add hidden_states and attentions if they are here
|
||||
outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]
|
||||
|
||||
return outputs
|
||||
|
||||
class BertLanguageOnlyModel(BertPreTrainedModel):
|
||||
""" Expand from BertModel to handle image region features as input
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(BertLanguageOnlyModel, self).__init__(config)
|
||||
self.embeddings = BertEmbeddings(config)
|
||||
self.encoder = CaptionBertEncoder(config)
|
||||
self.pooler = BertPooler(config)
|
||||
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
self.apply(self.init_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None,
|
||||
position_ids=None, head_mask=None, img_feats=None):
|
||||
|
||||
if attention_mask.dim() == 2:
|
||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||
elif attention_mask.dim() == 3:
|
||||
extended_attention_mask = attention_mask.unsqueeze(1)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
|
||||
head_mask = [None] * self.config.num_hidden_layers # 12 heads
|
||||
|
||||
# language embeddings [24, 55, 768]
|
||||
embedding_output = self.embeddings(input_ids, position_ids=position_ids,
|
||||
token_type_ids=token_type_ids)
|
||||
|
||||
concat_embedding_output = embedding_output
|
||||
|
||||
''' pass to the Transformer layers '''
|
||||
encoder_outputs = self.encoder(concat_embedding_output,
|
||||
extended_attention_mask, head_mask=head_mask) # [24, 95, 768]
|
||||
|
||||
sequence_output = encoder_outputs[0] # [24, 95, 768]
|
||||
pooled_output = self.pooler(sequence_output) # We "pool" the model by simply taking the hidden state corresponding to the first token [24, 768]
|
||||
|
||||
# add hidden_states and attentions if they are here
|
||||
outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class LanguageBert(BertPreTrainedModel):
|
||||
"""
|
||||
Modified from BertForMultipleChoice to support oscar training.
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(LanguageBert, self).__init__(config)
|
||||
# self.loss_type = config.loss_type
|
||||
# if config.img_feature_dim > 0: # default for nlvr
|
||||
self.config = config
|
||||
if config.model_type == 'language':
|
||||
self.bert = BertLanguageOnlyModel(config) # LanuageOnlyBERT
|
||||
elif config.model_type == 'visual':
|
||||
self.bert = BertImgModel(config) # LanguageVisualBERT
|
||||
else:
|
||||
ModelTypeNotImplemented
|
||||
# else:
|
||||
# self.bert = BertModel(config) # original BERT
|
||||
|
||||
# the classifier for downstream tasks
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
# if hasattr(config, 'classifier'):
|
||||
# if not hasattr(config, 'cls_hidden_scale'): config.cls_hidden_scale = 2
|
||||
# if config.classifier == 'linear':
|
||||
# self.classifier = nn.Linear(config.num_choice*config.hidden_size, self.config.num_labels)
|
||||
# elif config.classifier == 'mlp':
|
||||
# self.classifier = nn.Sequential(
|
||||
# nn.Linear(config.num_choice*config.hidden_size, config.hidden_size*config.cls_hidden_scale),
|
||||
# nn.ReLU(),
|
||||
# nn.Linear(config.hidden_size*config.cls_hidden_scale, self.config.num_labels)
|
||||
# )
|
||||
# else:
|
||||
# self.classifier = nn.Linear(config.num_choice*config.hidden_size, self.config.num_labels) # original
|
||||
|
||||
self.apply(self.init_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
|
||||
position_ids=None, head_mask=None, img_feats=None):
|
||||
# num_choices = input_ids.shape[1]
|
||||
#
|
||||
# flat_input_ids = input_ids.view(-1, input_ids.size(-1)) # [24, 55]
|
||||
# flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
||||
# flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None # [24, 55]
|
||||
# flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None # [24, 95]
|
||||
#
|
||||
# flat_img_feats = img_feats.view(-1, img_feats.size(-2), img_feats.size(-1)) if img_feats is not None else None # [24, 40, 2054]
|
||||
|
||||
# if isinstance(self.bert, BertImgModel):
|
||||
outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
|
||||
attention_mask=attention_mask, head_mask=head_mask, img_feats=img_feats)
|
||||
# else:
|
||||
# outputs = self.bert(flat_input_ids, position_ids=flat_position_ids, token_type_ids=flat_token_type_ids,
|
||||
# attention_mask=flat_attention_mask, head_mask=head_mask)
|
||||
# outputs - the squence output
|
||||
|
||||
sequence_output = outputs[0]
|
||||
sequence_output = self.dropout(sequence_output)
|
||||
# We "pool" the model by simply taking the hidden state corresponding to the first token [batch_size, 768]
|
||||
pooled_output = outputs[1]
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
|
||||
if self.config.model_type == 'language':
|
||||
return sequence_output
|
||||
elif self.config.model_type == 'visual':
|
||||
return pooled_output
|
||||
885
r2r_src/vlnbert/modeling_utils.py
Normal file
885
r2r_src/vlnbert/modeling_utils.py
Normal file
@ -0,0 +1,885 @@
|
||||
from __future__ import (absolute_import, division, print_function,
|
||||
unicode_literals)
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from io import open
|
||||
|
||||
import six
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn import functional as F
|
||||
|
||||
#from .file_utils import cached_path
|
||||
from pytorch_pretrained_bert.file_utils import cached_path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CONFIG_NAME = "config.json"
|
||||
WEIGHTS_NAME = "pytorch_model.bin"
|
||||
TF_WEIGHTS_NAME = 'model.ckpt'
|
||||
|
||||
|
||||
# Older PyTorch releases do not ship ``nn.Identity``; fall back to a local
# re-implementation so the rest of the file can use the name unconditionally.
try:
    from torch.nn import Identity
except ImportError:
    # Older PyTorch compatibility
    class Identity(nn.Module):
        r"""A placeholder identity operator that is argument-insensitive.
        """
        def __init__(self, *args, **kwargs):
            # Constructor arguments are accepted and ignored on purpose, to
            # mirror torch.nn.Identity's signature.
            super(Identity, self).__init__()

        def forward(self, input):
            # Return the input unchanged.
            return input
|
||||
|
||||
|
||||
# Docstring-composition decorators. On Python 2 the interpreter forbids
# assigning to ``__doc__``, so both decorators degrade to no-ops there.
if six.PY2:
    # Not possible to update class docstrings on python2
    def add_start_docstrings(*docstr):
        def docstring_decorator(fn):
            return fn
        return docstring_decorator

    def add_end_docstrings(*docstr):
        def docstring_decorator(fn):
            return fn
        return docstring_decorator
else:
    def add_start_docstrings(*docstr):
        """Decorator factory: prepend the joined ``docstr`` fragments to the
        decorated function's ``__doc__``."""
        def docstring_decorator(fn):
            fn.__doc__ = ''.join(docstr) + fn.__doc__
            return fn
        return docstring_decorator

    def add_end_docstrings(*docstr):
        """Decorator factory: append the joined ``docstr`` fragments to the
        decorated function's ``__doc__``."""
        def docstring_decorator(fn):
            fn.__doc__ = fn.__doc__ + ''.join(docstr)
            return fn
        return docstring_decorator
|
||||
|
||||
|
||||
class PretrainedConfig(object):
    r""" Base class for all configuration classes.
        Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations.

        Note:
            A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to initialize a model does **not** load the model weights.
            It only affects the model's configuration.

        Class attributes (overridden by derived classes):
            - ``pretrained_config_archive_map``: a python ``dict`` of with `short-cut-names` (string) as keys and `url` (string) of associated pretrained model configurations as values.

        Parameters:
            ``finetuning_task``: string, default `None`. Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint.
            ``num_labels``: integer, default `2`. Number of classes to use when the model is a classification model (sequences/tokens)
            ``output_attentions``: boolean, default `False`. Should the model returns attentions weights.
            ``output_hidden_states``: string, default `False`. Should the model returns all hidden-states.
            ``torchscript``: string, default `False`. Is the model used with Torchscript.
    """
    # Mapping from shortcut names to config URLs; derived classes override it.
    pretrained_config_archive_map = {}

    def __init__(self, **kwargs):
        # Known common options are popped from kwargs; any remaining keys are
        # silently ignored here (derived-class __init__s consume their own).
        self.finetuning_task = kwargs.pop('finetuning_task', None)
        self.num_labels = kwargs.pop('num_labels', 2)
        self.output_attentions = kwargs.pop('output_attentions', False)
        self.output_hidden_states = kwargs.pop('output_hidden_states', False)
        self.torchscript = kwargs.pop('torchscript', False)

    def save_pretrained(self, save_directory):
        """ Save a configuration object to the directory `save_directory`, so that it
            can be re-loaded using the :func:`~pytorch_transformers.PretrainedConfig.from_pretrained` class method.
        """
        assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"

        # If we save using the predefined names, we can load using `from_pretrained`
        output_config_file = os.path.join(save_directory, CONFIG_NAME)

        self.to_json_file(output_config_file)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        r""" Instantiate a :class:`~pytorch_transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration.

        Parameters:
            pretrained_model_name_or_path: either:
                - a string with the `shortcut name` of a pre-trained model configuration to load from cache or download, e.g.: ``bert-base-uncased``.
                - a path to a `directory` containing a configuration file saved using the :func:`~pytorch_transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``.
                - a path or url to a saved configuration JSON `file`, e.g.: ``./my_model_directory/configuration.json``.
            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.
            kwargs: (`optional`) dict: key/value pairs with which to update the configuration object after loading.
                - The values in kwargs of any keys which are configuration attributes will be used to override the loaded values.
                - Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled by the `return_unused_kwargs` keyword parameter.
            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.
            return_unused_kwargs: (`optional`) bool:
                - If False, then this function returns just the final configuration object.
                - If True, then this functions returns a tuple `(config, unused_kwargs)` where `unused_kwargs` is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: ie the part of kwargs which has not been used to update `config` and is otherwise ignored.

        Examples::

            # We can't instantiate directly the base class `PretrainedConfig` so let's show the examples on a
            # derived class: BertConfig
            config = BertConfig.from_pretrained('bert-base-uncased')   # Download configuration from S3 and cache.
            config = BertConfig.from_pretrained('./test/saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
            config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
            config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
            assert config.output_attention == True
            config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True,
                                                               foo=False, return_unused_kwargs=True)
            assert config.output_attention == True
            assert unused_kwargs == {'foo': False}
        """
        cache_dir = kwargs.pop('cache_dir', None)
        force_download = kwargs.pop('force_download', False)
        proxies = kwargs.pop('proxies', None)
        return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)

        # Resolve the argument to a concrete config-file location:
        # shortcut name -> archive URL, directory -> <dir>/config.json,
        # otherwise treat it as a direct file path or URL.
        if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
            config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
        elif os.path.isdir(pretrained_model_name_or_path):
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
        else:
            config_file = pretrained_model_name_or_path
        # redirect to the cache, if necessary
        try:
            resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
        except EnvironmentError as e:
            if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
                logger.error(
                    "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
                        config_file))
            else:
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find any file "
                    "associated to this path or url.".format(
                        pretrained_model_name_or_path,
                        ', '.join(cls.pretrained_config_archive_map.keys()),
                        config_file))
            raise e
        if resolved_config_file == config_file:
            logger.info("loading configuration file {}".format(config_file))
        else:
            logger.info("loading configuration file {} from cache at {}".format(
                config_file, resolved_config_file))

        # Load config
        config = cls.from_json_file(resolved_config_file)

        # Update config with kwargs if needed; keys consumed here are removed
        # from kwargs so that (with return_unused_kwargs) only truly unused
        # pairs are returned to the caller.
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
                to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)

        logger.info("Model config %s", config)
        if return_unused_kwargs:
            return config, kwargs
        else:
            return config

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `Config` from a Python dictionary of parameters."""
        # NOTE(review): the vocab_size_or_config_json_file argument is aimed at
        # derived-class __init__ signatures; this base __init__ ignores it.
        config = cls(vocab_size_or_config_json_file=-1)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `BertConfig` from a json file of parameters."""
        with open(json_file, "r", encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __eq__(self, other):
        # NOTE(review): defining __eq__ without __hash__ makes instances
        # unhashable on Python 3 — fine as long as configs are never used as
        # dict keys / set members.
        return self.__dict__ == other.__dict__

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path):
        """ Save this instance to a json file."""
        with open(json_file_path, "w", encoding='utf-8') as writer:
            writer.write(self.to_json_string())
|
||||
|
||||
|
||||
class PreTrainedModel(nn.Module):
    r""" Base class for all models.
        :class:`~pytorch_transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
        as well as a few methods commons to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.

        Class attributes (overridden by derived classes):
            - ``config_class``: a class derived from :class:`~pytorch_transformers.PretrainedConfig` to use as configuration class for this model architecture.
            - ``pretrained_model_archive_map``: a python ``dict`` of with `short-cut-names` (string) as keys and `url` (string) of associated pretrained weights as values.
            - ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:
                - ``model``: an instance of the relevant subclass of :class:`~pytorch_transformers.PreTrainedModel`,
                - ``config``: an instance of the relevant subclass of :class:`~pytorch_transformers.PretrainedConfig`,
                - ``path``: a path (string) to the TensorFlow checkpoint.
            - ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
    """
    config_class = None
    pretrained_model_archive_map = {}
    # Default TF loader is a no-op; subclasses provide the real implementation.
    load_tf_weights = lambda model, config, path: None
    base_model_prefix = ""

    def __init__(self, config, *inputs, **kwargs):
        super(PreTrainedModel, self).__init__()
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
                "To create a model from a pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        # Save config in model
        self.config = config

    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
        """ Build a resized Embedding Module from a provided token Embedding Module.
            Increasing the size will add newly initialized vectors at the end
            Reducing the size will remove vectors from the end

        Args:
            new_num_tokens: (`optional`) int
                New number of tokens in the embedding matrix.
                Increasing the size will add newly initialized vectors at the end
                Reducing the size will remove vectors from the end
                If not provided or None: return the provided token Embedding Module.
        Return: ``torch.nn.Embeddings``
            Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
        """
        if new_num_tokens is None:
            return old_embeddings

        old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
        if old_num_tokens == new_num_tokens:
            return old_embeddings

        # Build new embeddings
        new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
        new_embeddings.to(old_embeddings.weight.device)

        # initialize all new embeddings (in particular added tokens)
        self.init_weights(new_embeddings)

        # Copy word embeddings from the previous weights
        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
        new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]

        return new_embeddings

    def _tie_or_clone_weights(self, first_module, second_module):
        """ Tie or clone module weights depending of weither we are using TorchScript or not
        """
        # TorchScript cannot trace shared parameters, so clone instead of tie.
        if self.config.torchscript:
            first_module.weight = nn.Parameter(second_module.weight.clone())
        else:
            first_module.weight = second_module.weight

    def resize_token_embeddings(self, new_num_tokens=None):
        """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
            Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

        Arguments:
            new_num_tokens: (`optional`) int:
                New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
                If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model.

        Return: ``torch.nn.Embeddings``
            Pointer to the input tokens Embeddings Module of the model
        """
        base_model = getattr(self, self.base_model_prefix, self)  # get the base model if needed
        model_embeds = base_model._resize_token_embeddings(new_num_tokens)
        if new_num_tokens is None:
            return model_embeds

        # Update base model and current model config
        self.config.vocab_size = new_num_tokens
        base_model.vocab_size = new_num_tokens

        # Tie weights again if needed
        if hasattr(self, 'tie_weights'):
            self.tie_weights()

        return model_embeds

    def prune_heads(self, heads_to_prune):
        """ Prunes heads of the base model.

        Arguments:
            heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
        """
        base_model = getattr(self, self.base_model_prefix, self)  # get the base model if needed
        base_model._prune_heads(heads_to_prune)

    def save_pretrained(self, save_directory):
        """ Save a model and its configuration file to a directory, so that it
            can be re-loaded using the `:func:`~pytorch_transformers.PreTrainedModel.from_pretrained`` class method.
        """
        assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"

        # Only save the model it-self if we are using distributed training
        model_to_save = self.module if hasattr(self, 'module') else self

        # Save configuration file
        model_to_save.config.save_pretrained(save_directory)

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(save_directory, WEIGHTS_NAME)

        torch.save(model_to_save.state_dict(), output_model_file)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.

        The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated)
        To train the model, you should first set it back in training mode with ``model.train()``

        The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
        It is up to you to train those weights with a downstream fine-tuning task.

        The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.

        Parameters:
            pretrained_model_name_or_path: either:
                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a path to a `directory` containing model weights saved using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
            model_args: (`optional`) Sequence of positional arguments:
                All remaning positional arguments will be passed to the underlying model's ``__init__`` method
            config: (`optional`) instance of a class derived from :class:`~pytorch_transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when:
                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory.
                - the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
            state_dict: (`optional`) dict:
                an optional state dictionnary for the model to use instead of a state dictionary loaded from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained` and :func:`~pytorch_transformers.PreTrainedModel.from_pretrained` is not a simpler option.
            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.
            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.
            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:
                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~pytorch_transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.

        Examples::

            model = BertModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = BertModel.from_pretrained('./test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
            assert model.config.output_attention == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
            model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
        """
        config = kwargs.pop('config', None)
        state_dict = kwargs.pop('state_dict', None)
        cache_dir = kwargs.pop('cache_dir', None)
        from_tf = kwargs.pop('from_tf', False)
        force_download = kwargs.pop('force_download', False)
        proxies = kwargs.pop('proxies', None)
        output_loading_info = kwargs.pop('output_loading_info', False)

        # Load config
        if config is None:
            # NOTE(review): *model_args is forwarded here although
            # PretrainedConfig.from_pretrained accepts no extra positionals —
            # looks like this only works when model_args is empty or the
            # config_class overrides from_pretrained; confirm against callers.
            config, model_kwargs = cls.config_class.from_pretrained(
                pretrained_model_name_or_path, *model_args,
                cache_dir=cache_dir, return_unused_kwargs=True,
                force_download=force_download,
                **kwargs
            )
        else:
            model_kwargs = kwargs

        # Load model: shortcut name -> archive URL, directory -> predefined
        # weight file inside it, otherwise a direct path/URL.
        if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
            archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
        elif os.path.isdir(pretrained_model_name_or_path):
            if from_tf:
                # Directly load from a TensorFlow checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
            else:
                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
        else:
            if from_tf:
                # Directly load from a TensorFlow checkpoint
                archive_file = pretrained_model_name_or_path + ".index"
            else:
                archive_file = pretrained_model_name_or_path
        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
        except EnvironmentError as e:
            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                logger.error(
                    "Couldn't reach server at '{}' to download pretrained weights.".format(
                        archive_file))
            else:
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find any file "
                    "associated to this path or url.".format(
                        pretrained_model_name_or_path,
                        ', '.join(cls.pretrained_model_archive_map.keys()),
                        archive_file))
            raise e
        if resolved_archive_file == archive_file:
            logger.info("loading weights file {}".format(archive_file))
        else:
            logger.info("loading weights file {} from cache at {}".format(
                archive_file, resolved_archive_file))

        # Instantiate model.
        model = cls(config, *model_args, **model_kwargs)

        if state_dict is None and not from_tf:
            state_dict = torch.load(resolved_archive_file, map_location='cpu')
        if from_tf:
            # Directly load from a TensorFlow checkpoint
            return cls.load_tf_weights(model, config, resolved_archive_file[:-6])  # Remove the '.index'

        # Convert old format to new format if needed from a PyTorch state_dict
        # (historical checkpoints used 'gamma'/'beta' for LayerNorm params).
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        # Load from a PyTorch state_dict
        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            # Recursively load each submodule's parameters; error/missing-key
            # lists are shared across the whole recursion via closure.
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        # Make sure we are able to load base models as well as derived models (with heads)
        start_prefix = ''
        model_to_load = model
        if not hasattr(model, cls.base_model_prefix) and any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
            start_prefix = cls.base_model_prefix + '.'
        if hasattr(model, cls.base_model_prefix) and not any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
            model_to_load = getattr(model, cls.base_model_prefix)

        load(model_to_load, prefix=start_prefix)
        if len(missing_keys) > 0:
            logger.info("Weights of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, missing_keys))
        if len(unexpected_keys) > 0:
            logger.info("Weights from pretrained model not used in {}: {}".format(
                model.__class__.__name__, unexpected_keys))
        if len(error_msgs) > 0:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                model.__class__.__name__, "\n\t".join(error_msgs)))

        if hasattr(model, 'tie_weights'):
            model.tie_weights()  # make sure word embedding weights are still tied

        # Set model in evaluation mode to desactivate DropOut modules by default
        model.eval()

        if output_loading_info:
            loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
            return model, loading_info

        return model
|
||||
|
||||
|
||||
class Conv1D(nn.Module):
    """Conv1D layer as defined by Radford et al. for OpenAI GPT (and also used
    in GPT-2).

    Basically works like a Linear layer but the weights are transposed:
    ``out = x @ weight + bias`` with ``weight`` of shape ``(nx, nf)``.
    """

    def __init__(self, nf, nx):
        super(Conv1D, self).__init__()
        self.nf = nf
        # Weight stored as (in_features, out_features), i.e. transposed
        # relative to nn.Linear, and initialized N(0, 0.02).
        weight = torch.empty(nx, nf)
        nn.init.normal_(weight, std=0.02)
        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        # Flatten all leading dimensions, apply the affine map, then restore
        # the leading dimensions with the feature dim replaced by nf.
        out_shape = x.size()[:-1] + (self.nf,)
        flat = x.view(-1, x.size(-1))
        out = torch.addmm(self.bias, flat, self.weight)
        return out.view(*out_shape)
|
||||
|
||||
|
||||
class PoolerStartLogits(nn.Module):
    """Compute SQuAD start_logits from sequence hidden states."""

    def __init__(self, config):
        super(PoolerStartLogits, self).__init__()
        # One scalar logit per token position.
        self.dense = nn.Linear(config.hidden_size, 1)

    def forward(self, hidden_states, p_mask=None):
        """Args:
            **p_mask**: (`optional`) ``torch.FloatTensor`` of shape `(batch_size, seq_len)`
                invalid position mask such as query and special symbols (PAD, SEP, CLS)
                1.0 means token should be masked.
        """
        logits = self.dense(hidden_states).squeeze(-1)
        if p_mask is None:
            return logits
        # Zero out masked logits and push them toward -inf so downstream
        # softmaxes ignore those positions.
        return logits * (1 - p_mask) - 1e30 * p_mask
|
||||
|
||||
|
||||
class PoolerEndLogits(nn.Module):
    """Compute SQuAD end_logits from sequence hidden states and start token
    hidden state.
    """

    def __init__(self, config):
        super(PoolerEndLogits, self).__init__()
        # Consumes [token_state ; start_state] concatenated on the feature dim.
        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.activation = nn.Tanh()
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dense_1 = nn.Linear(config.hidden_size, 1)

    def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None):
        """Args:
            One of ``start_states``, ``start_positions`` should be not None.
            If both are set, ``start_positions`` overrides ``start_states``.
            **start_states**: ``torch.LongTensor`` of shape identical to hidden_states
                hidden states of the first tokens for the labeled span.
            **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
                position of the first token for the labeled span.
            **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
                Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
                1.0 means token should be masked.
        """
        assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
        if start_positions is not None:
            # Gather the start token's hidden state and broadcast it across
            # every sequence position.
            seq_len, hidden_dim = hidden_states.shape[-2:]
            gather_index = start_positions[:, None, None].expand(-1, -1, hidden_dim)  # (bsz, 1, hsz)
            start_states = hidden_states.gather(-2, gather_index)                     # (bsz, 1, hsz)
            start_states = start_states.expand(-1, seq_len, -1)                       # (bsz, slen, hsz)

        features = torch.cat([hidden_states, start_states], dim=-1)
        features = self.LayerNorm(self.activation(self.dense_0(features)))
        logits = self.dense_1(features).squeeze(-1)

        if p_mask is not None:
            # Drive masked positions toward -inf so softmax ignores them.
            logits = logits * (1 - p_mask) - 1e30 * p_mask

        return logits
|
||||
|
||||
|
||||
class PoolerAnswerClass(nn.Module):
    """Compute SQuAD 2.0 answer class from classification and start tokens
    hidden states."""

    def __init__(self, config):
        super(PoolerAnswerClass, self).__init__()
        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.activation = nn.Tanh()
        # No bias: a single scalar score per example.
        self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)

    def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None):
        """Args:
            One of ``start_states``, ``start_positions`` should be not None.
            If both are set, ``start_positions`` overrides ``start_states``.
            **start_states**: ``torch.LongTensor`` of shape identical to ``hidden_states``.
                hidden states of the first tokens for the labeled span.
            **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
                position of the first token for the labeled span.
            **cls_index**: torch.LongTensor of shape ``(batch_size,)``
                position of the CLS token. If None, take the last token.

        note(Original repo):
            no dependency on end_feature so that we can obtain one single
            `cls_logits` for each sample
        """
        hidden_dim = hidden_states.shape[-1]
        assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
        if start_positions is not None:
            gather_index = start_positions[:, None, None].expand(-1, -1, hidden_dim)  # (bsz, 1, hsz)
            start_states = hidden_states.gather(-2, gather_index).squeeze(-2)         # (bsz, hsz)

        if cls_index is None:
            cls_token_state = hidden_states[:, -1, :]  # (bsz, hsz)
        else:
            gather_cls = cls_index[:, None, None].expand(-1, -1, hidden_dim)      # (bsz, 1, hsz)
            cls_token_state = hidden_states.gather(-2, gather_cls).squeeze(-2)    # (bsz, hsz)

        score = self.activation(self.dense_0(torch.cat([start_states, cls_token_state], dim=-1)))
        return self.dense_1(score).squeeze(-1)
|
||||
|
||||
|
||||
class SQuADHead(nn.Module):
    r""" A SQuAD head inspired by XLNet.
    Parameters:
        config (:class:`~pytorch_transformers.XLNetConfig`): Model configuration class with all the parameters of the model.
    Inputs:
        **hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)``
            hidden states of sequence tokens
        **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
            position of the first token for the labeled span.
        **end_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
            position of the last token for the labeled span.
        **cls_index**: torch.LongTensor of shape ``(batch_size,)``
            position of the CLS token. If None, take the last token.
        **is_impossible**: ``torch.LongTensor`` of shape ``(batch_size,)``
            Whether the question has a possible answer in the paragraph or not.
        **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
            Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
            1.0 means token should be masked.
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
        **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
            Log probabilities for the top config.start_n_top start token possibilities (beam-search).
        **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
            Indices for the top config.start_n_top start token possibilities (beam-search).
        **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
            Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
        **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
            Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
        **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.FloatTensor`` of shape ``(batch_size,)``
            Log probabilities for the ``is_impossible`` label of the answers.
    """
    def __init__(self, config):
        super(SQuADHead, self).__init__()
        # Beam widths for the inference-time start/end search.
        self.start_n_top = config.start_n_top
        self.end_n_top = config.end_n_top

        self.start_logits = PoolerStartLogits(config)
        self.end_logits = PoolerEndLogits(config)
        self.answer_class = PoolerAnswerClass(config)

    def forward(self, hidden_states, start_positions=None, end_positions=None,
                cls_index=None, is_impossible=None, p_mask=None):
        # Training path returns (total_loss,); inference path returns the
        # 5-tuple of beam-search outputs documented on the class.
        outputs = ()

        start_logits = self.start_logits(hidden_states, p_mask=p_mask)

        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, let's remove the dimension added by batch splitting
            for x in (start_positions, end_positions, cls_index, is_impossible):
                if x is not None and x.dim() > 1:
                    x.squeeze_(-1)

            # during training, compute the end logits based on the ground truth of the start position
            end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)

            loss_fct = CrossEntropyLoss()
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

            if cls_index is not None and is_impossible is not None:
                # Predict answerability from the representation of CLS and START
                cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
                loss_fct_cls = nn.BCEWithLogitsLoss()
                cls_loss = loss_fct_cls(cls_logits, is_impossible)

                # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
                total_loss += cls_loss * 0.5

            outputs = (total_loss,) + outputs

        else:
            # during inference, compute the end logits based on beam search
            bsz, slen, hsz = hidden_states.size()
            start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)

            # Top start candidates, then gather their hidden states and tile
            # them along the sequence so each position can score as an end.
            start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1) # shape (bsz, start_n_top)
            start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
            start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
            start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)

            hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states) # shape (bsz, slen, start_n_top, hsz)
            p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
            end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
            # softmax over the sequence dim (dim=1): end distribution per start beam
            end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)

            end_top_log_probs, end_top_index = torch.topk(end_log_probs, self.end_n_top, dim=1) # shape (bsz, end_n_top, start_n_top)
            end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
            end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)

            # Soft start representation: probability-weighted sum over positions.
            start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
            cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)

            outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs

        # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
        # or (if labels are provided) (total_loss,)
        return outputs
|
||||
|
||||
|
||||
class SequenceSummary(nn.Module):
    r""" Compute a single vector summary of a sequence hidden states according to various possibilities:
        Args of the config class:
            summary_type:
                - 'last' => [default] take the last token hidden state (like XLNet)
                - 'first' => take the first token hidden state (like Bert)
                - 'mean' => take the mean of all tokens hidden states
                - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2)
                - 'attn' => Not implemented now, use multi-head attention
            summary_use_proj: Add a projection after the vector extraction
            summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
            summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default
            summary_first_dropout: Add a dropout before the projection and activation
            summary_last_dropout: Add a dropout after the projection and activation
    """
    def __init__(self, config):
        super(SequenceSummary, self).__init__()

        # Bug fix: gate the summary type on `summary_type` itself. The previous
        # check `hasattr(config, 'summary_use_proj')` silently ignored an
        # explicit `summary_type` whenever `summary_use_proj` was absent, and
        # raised AttributeError when `summary_use_proj` was present without
        # `summary_type`.
        self.summary_type = getattr(config, 'summary_type', 'last')
        if self.summary_type == 'attn':
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

        self.summary = Identity()
        if hasattr(config, 'summary_use_proj') and config.summary_use_proj:
            if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = nn.Linear(config.hidden_size, num_classes)

        self.activation = Identity()
        if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh':
            self.activation = nn.Tanh()

        self.first_dropout = Identity()
        if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0:
            self.first_dropout = nn.Dropout(config.summary_first_dropout)

        self.last_dropout = Identity()
        if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0:
            self.last_dropout = nn.Dropout(config.summary_last_dropout)

    def forward(self, hidden_states, cls_index=None):
        """ hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer.
            cls_index: [optional] position of the classification token if summary_type == 'cls_index',
                shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
                if summary_type == 'cls_index' and cls_index is None:
                    we take the last token of the sequence as classification token
        """
        if self.summary_type == 'last':
            output = hidden_states[:, -1]
        elif self.summary_type == 'first':
            output = hidden_states[:, 0]
        elif self.summary_type == 'mean':
            output = hidden_states.mean(dim=1)
        elif self.summary_type == 'cls_index':
            if cls_index is None:
                # Default to the last sequence position for every sample.
                cls_index = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2]-1, dtype=torch.long)
            else:
                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
                cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
        elif self.summary_type == 'attn':
            raise NotImplementedError

        output = self.first_dropout(output)
        output = self.summary(output)
        output = self.activation(output)
        output = self.last_dropout(output)

        return output
|
||||
|
||||
|
||||
def prune_linear_layer(layer, index, dim=0):
    """ Prune a linear layer (a model parameters) to keep only entries in index.
        Return the pruned layer as a new layer with requires_grad=True.
        Used to remove heads.
    """
    index = index.to(layer.weight.device)
    pruned_weight = layer.weight.index_select(dim, index).clone().detach()
    pruned_bias = None
    if layer.bias is not None:
        # The bias is only sliced when pruning output features (dim == 0);
        # pruning input features (dim == 1) leaves it untouched.
        pruned_bias = layer.bias.clone().detach() if dim == 1 else layer.bias[index].clone().detach()
    out_shape = list(layer.weight.size())
    out_shape[dim] = len(index)
    new_layer = nn.Linear(out_shape[1], out_shape[0], bias=layer.bias is not None).to(layer.weight.device)
    # Copy the kept entries without tracking gradients; the fresh parameters
    # keep their default requires_grad=True afterwards.
    with torch.no_grad():
        new_layer.weight.copy_(pruned_weight.contiguous())
        if pruned_bias is not None:
            new_layer.bias.copy_(pruned_bias.contiguous())
    new_layer.weight.requires_grad = True
    if layer.bias is not None:
        new_layer.bias.requires_grad = True
    return new_layer
|
||||
|
||||
|
||||
def prune_conv1d_layer(layer, index, dim=1):
    """ Prune a Conv1D layer (a model parameters) to keep only entries in index.
        A Conv1D work as a Linear layer (see e.g. BERT) but the weights are transposed.
        Return the pruned layer as a new layer with requires_grad=True.
        Used to remove heads.
    """
    index = index.to(layer.weight.device)
    pruned_weight = layer.weight.index_select(dim, index).clone().detach()
    # Conv1D weights are transposed w.r.t. nn.Linear, so the bias follows the
    # *second* dimension: it is sliced when dim == 1 and kept whole when dim == 0.
    if dim == 0:
        pruned_bias = layer.bias.clone().detach()
    else:
        pruned_bias = layer.bias[index].clone().detach()
    out_shape = list(layer.weight.size())
    out_shape[dim] = len(index)
    new_layer = Conv1D(out_shape[1], out_shape[0]).to(layer.weight.device)
    new_layer.weight.requires_grad = False
    new_layer.weight.copy_(pruned_weight.contiguous())
    new_layer.weight.requires_grad = True
    new_layer.bias.requires_grad = False
    new_layer.bias.copy_(pruned_bias.contiguous())
    new_layer.bias.requires_grad = True
    return new_layer
|
||||
|
||||
|
||||
def prune_layer(layer, index, dim=None):
    """ Prune a Conv1D or nn.Linear layer (a model parameters) to keep only entries in index.
        Return the pruned layer as a new layer with requires_grad=True.
        Used to remove heads.
    """
    # Default pruning axis differs: Linear stores (out, in), Conv1D stores (in, out).
    if isinstance(layer, nn.Linear):
        axis = 0 if dim is None else dim
        return prune_linear_layer(layer, index, dim=axis)
    if isinstance(layer, Conv1D):
        axis = 1 if dim is None else dim
        return prune_conv1d_layer(layer, index, dim=axis)
    raise ValueError("Can't prune layer of class {}".format(layer.__class__))
|
||||
458
r2r_src/vlnbert/modeling_visbert.py
Normal file
458
r2r_src/vlnbert/modeling_visbert.py
Normal file
@ -0,0 +1,458 @@
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from io import open
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrainedModel,
|
||||
prune_linear_layer, add_start_docstrings)
|
||||
|
||||
from transformers import BertPreTrainedModel,BertConfig
|
||||
import pdb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def gelu(x):
    """Exact (erf-based) Gaussian Error Linear Unit.
    For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    Also see https://arxiv.org/abs/1606.08415
    """
    # x * Phi(x), where Phi is the standard normal CDF.
    normal_cdf = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
    return x * normal_cdf
|
||||
|
||||
|
||||
def swish(x):
    """Swish activation (a.k.a. SiLU): ``x * sigmoid(x)``."""
    return torch.sigmoid(x) * x
|
||||
|
||||
|
||||
# Map the `config.hidden_act` string to its activation callable.
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
|
||||
|
||||
|
||||
# Prefer apex's fused LayerNorm kernel when apex is installed; otherwise fall
# back to the native PyTorch implementation under the same name.
try:
    from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
except (ImportError, AttributeError) as e:
    logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")
    BertLayerNorm = torch.nn.LayerNorm
|
||||
|
||||
class BertEmbeddings(nn.Module):
    """Sum word, position and token-type embeddings, then LayerNorm + dropout."""

    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        if position_ids is None:
            # Default positions: 0..seq_len-1 broadcast over the batch.
            seq_length = input_ids.size(1)
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        summed = (self.word_embeddings(input_ids)
                  + self.position_embeddings(position_ids)
                  + self.token_type_embeddings(token_type_ids))
        return self.dropout(self.LayerNorm(summed))
|
||||
|
||||
|
||||
class BertSelfAttention(nn.Module):
    """Multi-head self-attention; always returns (context, attention_scores)."""

    def __init__(self, config):
        super(BertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.output_attentions = True

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        # (bsz, slen, all_head) -> (bsz, heads, slen, head_size)
        split_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.view(*split_shape).permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask, head_mask=None):
        q = self.transpose_for_scores(self.query(hidden_states))
        k = self.transpose_for_scores(self.key(hidden_states))
        v = self.transpose_for_scores(self.value(hidden_states))

        # Scaled dot-product scores with the additive attention mask
        # (precomputed for all layers in the model's forward) applied.
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.attention_head_size)
        scores = scores + attention_mask

        # Dropout on the probabilities drops whole tokens to attend to, as in
        # the original Transformer paper.
        probs = self.dropout(nn.Softmax(dim=-1)(scores))
        if head_mask is not None:
            probs = probs * head_mask

        context = torch.matmul(probs, v).permute(0, 2, 1, 3).contiguous()
        context = context.view(*(context.size()[:-2] + (self.all_head_size,)))

        return (context, scores) if self.output_attentions else (context,)
|
||||
|
||||
|
||||
class BertSelfOutput(nn.Module):
    """Output projection + dropout, then residual add and LayerNorm."""

    def __init__(self, config):
        super(BertSelfOutput, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        # Residual connection with the block's input, normalized.
        return self.LayerNorm(projected + input_tensor)
|
||||
|
||||
|
||||
class BertAttention(nn.Module):
    """Self-attention sub-block: BertSelfAttention followed by BertSelfOutput."""

    def __init__(self, config):
        super(BertAttention, self).__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(self, input_tensor, attention_mask, head_mask=None):
        self_outputs = self.self(input_tensor, attention_mask, head_mask)
        attention_output = self.output(self_outputs[0], input_tensor)
        # Keep any attention scores after the attention output.
        return (attention_output,) + self_outputs[1:]
|
||||
|
||||
|
||||
class BertIntermediate(nn.Module):
    """Position-wise feed-forward expansion with the configured activation."""

    def __init__(self, config):
        super(BertIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # `config.hidden_act` may be a name (looked up in ACT2FN) or a callable.
        if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))
|
||||
|
||||
|
||||
class BertOutput(nn.Module):
    """Feed-forward contraction + dropout, then residual add and LayerNorm."""

    def __init__(self, config):
        super(BertOutput, self).__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        contracted = self.dropout(self.dense(hidden_states))
        # Residual connection with the block's input, normalized.
        return self.LayerNorm(contracted + input_tensor)
|
||||
|
||||
|
||||
class BertLayer(nn.Module):
    """One transformer encoder layer: attention -> intermediate -> output."""

    def __init__(self, config):
        super(BertLayer, self).__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask, head_mask=None):
        attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
        attended = attention_outputs[0]
        layer_output = self.output(self.intermediate(attended), attended)
        # Keep any attention scores after the layer output.
        return (layer_output,) + attention_outputs[1:]
|
||||
|
||||
|
||||
class BertPooler(nn.Module):
    """Pool the sequence by projecting the first token's state through tanh."""

    def __init__(self, config):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # The first token's hidden state stands in for the whole sequence.
        cls_state = hidden_states[:, 0]
        return self.activation(self.dense(cls_state))
|
||||
|
||||
|
||||
class BertXAttention(nn.Module):
    """Cross-attention sub-block: attend from the input to a context sequence."""

    def __init__(self, config, ctx_dim=None):
        super().__init__()
        self.att = BertOutAttention(config, ctx_dim=ctx_dim)
        self.output = BertSelfOutput(config)

    def forward(self, input_tensor, ctx_tensor, ctx_att_mask=None):
        cross_output, attention_scores = self.att(input_tensor, ctx_tensor, ctx_att_mask)
        attended = self.output(cross_output, input_tensor)
        return attended, attention_scores
|
||||
|
||||
|
||||
class BertOutAttention(nn.Module):
    """Multi-head attention whose keys/values come from a separate context tensor."""

    def __init__(self, config, ctx_dim=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # The context may have a different feature width than hidden_size
        # (e.g. raw visual features).
        if ctx_dim is None:
            ctx_dim = config.hidden_size
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(ctx_dim, self.all_head_size)
        self.value = nn.Linear(ctx_dim, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        # (bsz, slen, all_head) -> (bsz, heads, slen, head_size)
        split_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.view(*split_shape).permute(0, 2, 1, 3)

    def forward(self, hidden_states, context, attention_mask=None):
        q = self.transpose_for_scores(self.query(hidden_states))
        k = self.transpose_for_scores(self.key(context))
        v = self.transpose_for_scores(self.value(context))

        # Scaled dot-product scores; the additive mask is optional here.
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            scores = scores + attention_mask

        # Dropout on the probabilities drops whole tokens to attend to, as in
        # the original Transformer paper.
        probs = self.dropout(nn.Softmax(dim=-1)(scores))

        context_layer = torch.matmul(probs, v).permute(0, 2, 1, 3).contiguous()
        context_layer = context_layer.view(*(context_layer.size()[:-2] + (self.all_head_size,)))
        return context_layer, scores
|
||||
|
||||
|
||||
class LXRTXLayer(nn.Module):
    """One LXMERT-style cross-modality layer specialised for VLN: the recurrent
    state token (lang_feats[:, 0]) is concatenated onto the visual features,
    the combined sequence cross-attends to the language tokens, then runs
    visual self-attention and a feed-forward block.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        # Lang self-att and FFN layer
        # NOTE(review): lang_self_att / lang_inter / lang_output are built but
        # not used in forward() below — presumably kept for checkpoint
        # compatibility; confirm before removing.
        self.lang_self_att = BertAttention(config)
        self.lang_inter = BertIntermediate(config)
        self.lang_output = BertOutput(config)
        # Visn self-att and FFN layer
        self.visn_self_att = BertAttention(config)
        self.visn_inter = BertIntermediate(config)
        self.visn_output = BertOutput(config)
        # The cross attention layer
        self.visual_attention = BertXAttention(config)

    def cross_att(self, lang_input, lang_attention_mask, visn_input, visn_attention_mask):
        ''' Cross Attention -- cross for vision not for language '''
        visn_att_output, attention_scores = self.visual_attention(visn_input, lang_input, ctx_att_mask=lang_attention_mask)
        return visn_att_output, attention_scores

    def self_att(self, visn_input, visn_attention_mask):
        ''' Self Attention -- on visual features with language clues '''
        visn_att_output = self.visn_self_att(visn_input, visn_attention_mask)
        return visn_att_output

    def output_fc(self, visn_input):
        ''' Feed forward '''
        visn_inter_output = self.visn_inter(visn_input)
        visn_output = self.visn_output(visn_inter_output, visn_input)
        return visn_output

    def forward(self, lang_feats, lang_attention_mask,
                visn_feats, visn_attention_mask, tdx):
        # NOTE(review): `tdx` is unused in this body — presumably the layer
        # index; confirm callers rely only on the signature.

        ''' visual self-attention with state '''
        # Prepend the state token (language position 0) to the visual sequence.
        visn_att_output = torch.cat((lang_feats[:, 0:1, :], visn_feats), dim=1) # [8, cand_dir+1, 768] vision with states
        state_vis_mask = torch.cat((lang_attention_mask[:,:,:,0:1], visn_attention_mask), dim=-1)

        ''' state and vision attend to language '''
        # Cross-attend [state; vision] to the language tokens (state excluded
        # from the language side).
        visn_att_output, cross_attention_scores = self.cross_att(lang_feats[:, 1:, :], lang_attention_mask[:, :, :, 1:], visn_att_output, state_vis_mask)

        # Attention of the state token (row 0) over the language sequence.
        language_attention_scores = cross_attention_scores[:, :, 0, :]

        state_visn_att_output = self.self_att(visn_att_output, state_vis_mask)

        state_visn_output = self.output_fc(state_visn_att_output[0])

        # Split the state token back out of the visual stream.
        visn_att_output = state_visn_output[:, 1:, :]
        lang_att_output = torch.cat((state_visn_output[:, 0:1, :], lang_feats[:,1:,:]), dim=1) # [8, 80, 768]

        # State-token self-attention over the visual candidates (state itself dropped).
        visual_attention_scores = state_visn_att_output[1][:, :, 0, 1:]

        return lang_att_output, visn_att_output, language_attention_scores, visual_attention_scores
|
||||
|
||||
|
||||
class VisionEncoder(nn.Module):
    """Project raw visual features into the transformer's hidden space."""

    def __init__(self, vision_size, config):
        super().__init__()
        # Object feature encoding: linear projection + LayerNorm + dropout.
        self.visn_fc = nn.Linear(vision_size, config.hidden_size)
        self.visn_layer_norm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, visn_input):
        normalized = self.visn_layer_norm(self.visn_fc(visn_input))
        return self.dropout(normalized)
|
||||
|
||||
|
||||
class VLNBert(BertPreTrainedModel):
    """Recurrent VLN-BERT model with two modes.

    'language' mode runs the text through the language-only BERT layers
    (done once, at episode initialization); 'visual' mode runs the
    cross-modal layers over the (state + language) embeddings and the
    candidate/object visual features at every navigation step.
    """

    def __init__(self, config):
        # config is expected to carry img_feature_dim, img_feature_type,
        # vl_layers, la_layers in addition to the standard BERT fields.
        super(VLNBert, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.pooler = BertPooler(config)

        self.img_dim = config.img_feature_dim  # 2176
        logger.info('VLNBert Image Dimension: {}'.format(self.img_dim))
        self.img_feature_type = config.img_feature_type  # ''
        self.vl_layers = config.vl_layers  # 4
        self.la_layers = config.la_layers  # 9
        # NOTE(review): these two flags are hard-coded True here, overriding
        # whatever the config's defaults say ("default False") — confirm intended.
        self.update_lang_bert = True  # default False
        self.update_add_layer = True  # default False
        # Language-only encoder stack (used in 'language' mode).
        self.lalayer = nn.ModuleList(
            [BertLayer(config) for _ in range(self.la_layers)])
        # Cross-modal encoder stack (used in 'visual' mode).
        self.addlayer = nn.ModuleList(
            [LXRTXLayer(config) for _ in range(self.vl_layers)])
        # self.vision_encoder = VisionEncoder(self.config.img_feature_dim, self.config)
        # self.apply(self.init_weights)
        self.init_weights()

    def forward(self, mode, input_ids, token_type_ids=None,
                attention_mask=None, lang_mask=None, vis_mask=None, obj_mask=None,
                position_ids=None, head_mask=None, img_feats=None):
        """Dispatch on `mode`.

        mode == 'language': input_ids are token ids; returns
            (pooled_output, sequence_output).
        mode == 'visual': input_ids actually carries the precomputed
            language embeddings (state + instruction); returns
            (pooled_output, visual_action_scores, attention_scores_obj,
             attended_visual).
        """
        # The caller-supplied lang_mask always replaces attention_mask.
        attention_mask = lang_mask

        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # Standard additive attention mask: 0 for real tokens,
        # -10000 for padding, broadcastable to [batch, heads, seq, seq].
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Any head_mask passed in is discarded; all heads are active.
        head_mask = [None] * self.config.num_hidden_layers

        if mode == 'language':
            ''' LXMERT language branch (in VLN only perform this at initialization) '''

            embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
            text_embeds = embedding_output

            # Run the language-only BERT stack; each layer returns a tuple,
            # the hidden states are element 0.
            for layer_module in self.lalayer:
                temp_output = layer_module(text_embeds, extended_attention_mask)
                text_embeds = temp_output[0]  # [8, 80, 768]

            sequence_output = text_embeds
            pooled_output = self.pooler(sequence_output)

            return pooled_output, sequence_output

        elif mode == 'visual':
            ''' LXMERT visual branch (no language processing during navigation) '''
            # In this mode input_ids already holds language embeddings.
            text_embeds = input_ids
            text_mask = extended_attention_mask

            # Visual features are used as-is (the VisionEncoder call is disabled).
            img_embedding_output = img_feats  # self.vision_encoder()
            # NOTE(review): img_seq_len and batch_size are computed but unused below.
            img_seq_len = img_feats.shape[1]
            batch_size = text_embeds.size(0)

            # Candidate-direction mask followed by object mask, in that order;
            # the slicing of visual_attention_scores below relies on this layout.
            img_seq_mask = torch.cat((vis_mask, obj_mask), dim=1)

            extended_img_mask = img_seq_mask.unsqueeze(1).unsqueeze(2)
            extended_img_mask = extended_img_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
            extended_img_mask = (1.0 - extended_img_mask) * -10000.0
            img_mask = extended_img_mask

            lang_output = text_embeds
            visn_output = img_embedding_output

            # Cross-modal layers; only the attention scores from the LAST
            # layer survive the loop and are used for action scoring.
            for tdx, layer_module in enumerate(self.addlayer):
                lang_output, visn_output, language_attention_scores, visual_attention_scores = layer_module(lang_output, text_mask, visn_output, img_mask, tdx)

            sequence_output = lang_output  # [8, 80, 768]
            pooled_output = self.pooler(sequence_output)  # [8, 768]

            # attentions over all objects, mean over the 12 heads
            # (the last config.obj_directions slots of the visual sequence).
            attention_scores_obj = visual_attention_scores[:, :, -self.config.obj_directions:].mean(dim=1)
            # use attention on objects for stopping
            stop_scores, _ = attention_scores_obj.max(1)

            # language_state_scores = language_attention_scores.mean(dim=1)
            # Head-averaged attention over the first config.directions slots
            # (the navigable candidate directions).
            candidate_scores = visual_attention_scores[:, :, :self.config.directions].mean(dim=1)

            # Action logits: one score per candidate direction plus a STOP score.
            visual_action_scores = torch.cat((candidate_scores, stop_scores.unsqueeze(1)), dim=-1)

            # Attention-weighted sum of candidate features (STOP excluded),
            # used as the attended visual context for the recurrent state.
            visual_attention_probs = nn.Softmax(dim=-1)(candidate_scores.clone()).unsqueeze(-1)
            attended_visual = (visual_attention_probs * img_embedding_output[:, :self.config.directions, :]).sum(1)

            return pooled_output, visual_action_scores, attention_scores_obj, attended_visual
|
||||
@ -9,7 +9,13 @@ from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from transformers.pytorch_transformers.modeling_bert import (BertEmbeddings,
|
||||
#from transformers.pytorch_transformers.modeling_bert import (BertEmbeddings,
|
||||
# BertSelfAttention, BertAttention, BertEncoder, BertLayer,
|
||||
# BertSelfOutput, BertIntermediate, BertOutput,
|
||||
# BertPooler, BertLayerNorm, BertPreTrainedModel,
|
||||
# BertPredictionHeadTransform)
|
||||
|
||||
from pytorch_transformers.modeling_bert import (BertEmbeddings,
|
||||
BertSelfAttention, BertAttention, BertEncoder, BertLayer,
|
||||
BertSelfOutput, BertIntermediate, BertOutput,
|
||||
BertPooler, BertLayerNorm, BertPreTrainedModel,
|
||||
@ -185,7 +191,8 @@ class BertImgModel(BertPreTrainedModel):
|
||||
self.img_dim = config.img_feature_dim
|
||||
logger.info('BertImgModel Image Dimension: {}'.format(self.img_dim))
|
||||
|
||||
self.apply(self.init_weights)
|
||||
# self.apply(self.init_weights)
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, mode, input_ids, token_type_ids=None, attention_mask=None,
|
||||
position_ids=None, img_feats=None):
|
||||
@ -237,7 +244,8 @@ class VLNBert(BertPreTrainedModel):
|
||||
self.state_LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.apply(self.init_weights)
|
||||
# self.apply(self.init_weights)
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, mode, input_ids, token_type_ids=None, attention_mask=None,
|
||||
position_ids=None, img_feats=None):
|
||||
|
||||
@ -14,7 +14,8 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from transformers.pytorch_transformers.modeling_bert import BertPreTrainedModel, BertConfig
|
||||
#from transformers.pytorch_transformers.modeling_bert import BertPreTrainedModel, BertConfig
|
||||
from pytorch_transformers.modeling_bert import BertPreTrainedModel, BertConfig
|
||||
import pdb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -379,7 +380,8 @@ class VLNBert(BertPreTrainedModel):
|
||||
self.addlayer = nn.ModuleList(
|
||||
[LXRTXLayer(config) for _ in range(self.vl_layers)])
|
||||
self.vision_encoder = VisionEncoder(self.config.img_feature_dim, self.config)
|
||||
self.apply(self.init_weights)
|
||||
# self.apply(self.init_weights)
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, mode, input_ids, token_type_ids=None,
|
||||
attention_mask=None, lang_mask=None, vis_mask=None, position_ids=None, head_mask=None, img_feats=None):
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
# Recurrent VLN-BERT, 2020, by Yicong.Hong@anu.edu.au
|
||||
|
||||
from transformers.pytorch_transformers import (BertConfig, BertTokenizer)
|
||||
#from transformers.pytorch_transformers import (BertConfig, BertTokenizer)
|
||||
# from pytorch_transformers import (BertConfig, BertTokenizer)
|
||||
from pytorch_transformers import (BertConfig, BertTokenizer)
|
||||
|
||||
def get_tokenizer(args):
|
||||
if args.vlnbert == 'oscar':
|
||||
tokenizer_class = BertTokenizer
|
||||
model_name_or_path = 'Oscar/pretrained_models/base-no-labels/ep_67_588997'
|
||||
model_name_or_path = 'r2r_src/vlnbert/Oscar/pretrained_models/base-no-labels/ep_67_588997'
|
||||
tokenizer = tokenizer_class.from_pretrained(model_name_or_path, do_lower_case=True)
|
||||
elif args.vlnbert == 'prevalent':
|
||||
tokenizer_class = BertTokenizer
|
||||
@ -16,9 +18,10 @@ def get_vlnbert_models(args, config=None):
|
||||
config_class = BertConfig
|
||||
|
||||
if args.vlnbert == 'oscar':
|
||||
print('\n VLN-BERT model is Oscar!!!')
|
||||
from vlnbert.vlnbert_OSCAR import VLNBert
|
||||
model_class = VLNBert
|
||||
model_name_or_path = 'Oscar/pretrained_models/base-no-labels/ep_67_588997'
|
||||
model_name_or_path = 'r2r_src/vlnbert/Oscar/pretrained_models/base-no-labels/ep_67_588997'
|
||||
vis_config = config_class.from_pretrained(model_name_or_path, num_labels=2, finetuning_task='vln-r2r')
|
||||
|
||||
vis_config.model_type = 'visual'
|
||||
@ -31,9 +34,11 @@ def get_vlnbert_models(args, config=None):
|
||||
visual_model = model_class.from_pretrained(model_name_or_path, from_tf=False, config=vis_config)
|
||||
|
||||
elif args.vlnbert == 'prevalent':
|
||||
print('\n VLN-BERT model is prevalent!!!')
|
||||
from vlnbert.vlnbert_PREVALENT import VLNBert
|
||||
model_class = VLNBert
|
||||
model_name_or_path = 'Prevalent/pretrained_model/pytorch_model.bin'
|
||||
#model_name_or_path = './Prevalent/pretrained_model/pytorch_model.bin'
|
||||
model_name_or_path = 'r2r_src/vlnbert/Prevalent/pretrained_model/pytorch_model.bin'
|
||||
vis_config = config_class.from_pretrained('bert-base-uncased')
|
||||
vis_config.img_feature_dim = 2176
|
||||
vis_config.img_feature_type = ""
|
||||
|
||||
40
r2r_src/vlnbert/vlnbert_model.py
Normal file
40
r2r_src/vlnbert/vlnbert_model.py
Normal file
@ -0,0 +1,40 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import (Dataset, DataLoader, RandomSampler, SequentialSampler, TensorDataset)
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
import _pickle as cPickle
|
||||
|
||||
import sys
|
||||
from transformers import (WEIGHTS_NAME, BertConfig, BertTokenizer)
|
||||
from transformers import AdamW
|
||||
from transformers import get_linear_schedule_with_warmup as WarmupLinearSchedule
|
||||
|
||||
# from vlnbert.modeling_bert import LanguageBert
|
||||
from vlnbert.modeling_visbert import VLNBert
|
||||
|
||||
model_name_or_path = 'r2r_src/vlnbert/Prevalent/pretrained_model/pytorch_model.bin'
|
||||
|
||||
def get_tokenizer():
|
||||
tokenizer_class = BertTokenizer
|
||||
tokenizer = tokenizer_class.from_pretrained('bert-base-uncased')
|
||||
return tokenizer
|
||||
|
||||
def get_vlnbert_models(config=None):
|
||||
config_class = BertConfig
|
||||
model_class = VLNBert
|
||||
vis_config = config_class.from_pretrained('bert-base-uncased')
|
||||
|
||||
# all configurations (need to pack into args)
|
||||
vis_config.img_feature_dim = 2176
|
||||
vis_config.img_feature_type = ""
|
||||
vis_config.update_lang_bert = False
|
||||
vis_config.update_add_layer = False
|
||||
vis_config.vl_layers = 4
|
||||
vis_config.la_layers = 9
|
||||
visual_model = VLNBert(vis_config)
|
||||
|
||||
visual_model = model_class.from_pretrained(model_name_or_path, config=vis_config)
|
||||
|
||||
return visual_model
|
||||
@ -23,4 +23,4 @@ flag="--vlnbert prevalent
|
||||
--dropout 0.5"
|
||||
|
||||
mkdir -p snap/$name
|
||||
CUDA_VISIBLE_DEVICES=1 python r2r_src/train.py $flag --name $name
|
||||
CUDA_VISIBLE_DEVICES=0 python3 r2r_src/train.py $flag --name $name
|
||||
|
||||
Loading…
Reference in New Issue
Block a user