import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from param import args

from vlnbert.vlnbert_model import get_vlnbert_models


class VLNBERT(nn.Module):
    def __init__(self, directions=4, feature_size=2048+128):
        super(VLNBERT, self).__init__()
        print('\nInitializing the VLN-BERT model ...')

        self.vln_bert = get_vlnbert_models(config=None)  # initialize the VLN-BERT backbone
        self.vln_bert.config.directions = directions

        hidden_size = self.vln_bert.config.hidden_size
        layer_norm_eps = self.vln_bert.config.layer_norm_eps

        # Fuse the previous state token with the action (angle) feature.
        self.action_state_project = nn.Sequential(
            nn.Linear(hidden_size+args.angle_feat_size, hidden_size), nn.Tanh())
        self.action_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps)

        # Object branch: encode the 5-dim object position, then project the
        # concatenated object feature into the hidden space.
        self.obj_pos_encode = nn.Linear(5, args.angle_feat_size, bias=True)
        self.obj_projection = nn.Linear(feature_size+args.angle_feat_size, hidden_size, bias=True)
        self.obj_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps)

        # Candidate branch: feature dropout, then projection into the hidden space.
        self.drop_env = nn.Dropout(p=args.featdropout)
        self.img_projection = nn.Linear(feature_size, hidden_size, bias=True)
        self.cand_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps)

        self.state_proj = nn.Linear(hidden_size*2, hidden_size, bias=True)
        self.state_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, mode, sentence, token_type_ids=None, attention_mask=None,
                lang_mask=None, vis_mask=None, obj_mask=None,
                position_ids=None, action_feats=None, pano_feats=None, cand_feats=None,
                obj_feats=None, obj_pos=None, already_dropfeat=False):

        if mode == 'language':
            init_state, encoded_sentence = self.vln_bert(mode, sentence, position_ids=position_ids,
                token_type_ids=token_type_ids, attention_mask=attention_mask, lang_mask=lang_mask)

            return init_state, encoded_sentence

        elif mode == 'visual':

            # Build the state token: the previous state (slot 0) fused with the action feature.
            state_action_embed = torch.cat((sentence[:, 0, :], action_feats), 1)
            state_with_action = self.action_state_project(state_action_embed)
            state_with_action = self.action_LayerNorm(state_with_action)
            state_feats = torch.cat((state_with_action.unsqueeze(1), sentence[:, 1:, :]), dim=1)

            if not already_dropfeat:
                # Dropout only the visual part; keep the angle features intact.
                cand_feats[..., :-args.angle_feat_size] = self.drop_env(cand_feats[..., :-args.angle_feat_size])
                obj_feats[..., :-args.angle_feat_size] = self.drop_env(obj_feats[..., :-args.angle_feat_size])

            # Project the feature_size-dim (2048+128 = 2176) candidate features into the hidden space.
            cand_feats_embed = self.img_projection(cand_feats)
            cand_feats_embed = self.cand_LayerNorm(cand_feats_embed)

            # Insert the encoded object position between the visual and angle parts of the object feature.
            obj_feats_embed = self.obj_pos_encode(obj_pos)
            obj_feats_concat = torch.cat((obj_feats[..., :-args.angle_feat_size], obj_feats_embed, obj_feats[..., -args.angle_feat_size:]), dim=-1)
            obj_feats_embed = self.obj_projection(obj_feats_concat)
            obj_feats_embed = self.obj_LayerNorm(obj_feats_embed)

            cand_obj_feats_embed = torch.cat((cand_feats_embed, obj_feats_embed), dim=1)

            # logit holds the attention scores over the candidate features
            h_t, logit, logit_obj, attended_visual = self.vln_bert(mode,
                state_feats, attention_mask=attention_mask,
                lang_mask=lang_mask, vis_mask=vis_mask, obj_mask=obj_mask,
                img_feats=cand_obj_feats_embed)

            # Update the recurrent state from the new state token and the attended visual context.
            state_output = torch.cat((h_t, attended_visual), dim=-1)
            state_proj = self.state_proj(state_output)
            state_proj = self.state_LayerNorm(state_proj)

            return state_proj, logit, logit_obj

        else:
            raise NotImplementedError("mode must be 'language' or 'visual'")
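
# Illustrative usage sketch (not part of the original code; variable names and
# shapes are placeholders). The agent typically makes one 'language' pass to
# encode the instruction, then one 'visual' pass per navigation step, feeding
# the updated state token back in:
#
#   model = VLNBERT(directions=4)
#   # language pass: instr_tokens / instr_mask come from the tokenized instruction
#   init_state, encoded_sentence = model('language', instr_tokens,
#                                        token_type_ids=token_types,
#                                        attention_mask=instr_mask, lang_mask=instr_mask)
#   # visual pass: slot 0 of state_lang_feats holds the current state token, the
#   # rest is the encoded instruction; cand/obj inputs come from the simulator
#   state, logit, logit_obj = model('visual', state_lang_feats,
#                                   attention_mask=joint_mask, lang_mask=instr_mask,
#                                   vis_mask=cand_mask, obj_mask=obj_mask,
#                                   action_feats=action_feats, cand_feats=cand_feats,
#                                   obj_feats=obj_feats, obj_pos=obj_pos)
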

class SoftDotAttention(nn.Module):
    '''Soft Dot Attention.

    Ref: http://www.aclweb.org/anthology/D15-1166
    Adapted from PyTorch OpenNMT.
    '''

    def __init__(self, query_dim, ctx_dim):
        '''Initialize layer.'''
        super(SoftDotAttention, self).__init__()
        self.linear_in = nn.Linear(query_dim, ctx_dim, bias=False)
        self.sm = nn.Softmax(dim=1)  # softmax over the context (seq_len) dimension
        self.linear_out = nn.Linear(query_dim + ctx_dim, query_dim, bias=False)
        self.tanh = nn.Tanh()

    def forward(self, h, context, mask=None,
                output_tilde=True, output_prob=True, input_project=True):
        '''Propagate h through the network.

        h: batch x dim
        context: batch x seq_len x dim
        mask: batch x seq_len, positions to be masked
        '''
        if input_project:
            target = self.linear_in(h).unsqueeze(2)  # batch x dim x 1
        else:
            target = h.unsqueeze(2)  # batch x dim x 1

        # Get the attention scores
        attn = torch.bmm(context, target).squeeze(2)  # batch x seq_len
        logit = attn

        if mask is not None:
            # -inf masking prior to the softmax
            attn.masked_fill_(mask, -float('inf'))
        attn = self.sm(attn)
        attn3 = attn.view(attn.size(0), 1, attn.size(1))  # batch x 1 x seq_len

        weighted_context = torch.bmm(attn3, context).squeeze(1)  # batch x dim
        if not output_prob:
            attn = logit
        if output_tilde:
            h_tilde = torch.cat((weighted_context, h), 1)
            h_tilde = self.tanh(self.linear_out(h_tilde))
            return h_tilde, attn
        else:
            return weighted_context, attn
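
# Illustrative usage sketch (assumed shapes, not from the original code): attend
# a 512-dim query over a batch of 10-step, 512-dim context sequences.
#
#   attn_layer = SoftDotAttention(query_dim=512, ctx_dim=512)
#   h = torch.rand(4, 512)                        # batch x query_dim
#   context = torch.rand(4, 10, 512)              # batch x seq_len x ctx_dim
#   mask = torch.zeros(4, 10, dtype=torch.bool)   # True marks positions to ignore
#   h_tilde, attn = attn_layer(h, context, mask=mask)
#   # h_tilde: batch x query_dim, attn: batch x seq_len (softmax weights)
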

class BertLayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root)."""
        super(BertLayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias
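
# Note (editorial): this module is numerically equivalent to
# torch.nn.LayerNorm(hidden_size, eps=1e-12) in current PyTorch versions; it is
# kept as a separate class for compatibility with the original BERT code.
# Illustrative check on freshly-initialized modules:
#
#   x = torch.rand(2, 5, 768)
#   ours, ref = BertLayerNorm(768), nn.LayerNorm(768, eps=1e-12)
#   assert torch.allclose(ours(x), ref(x), atol=1e-6)
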

class AttnDecoderLSTM(nn.Module):
    ''' Scores candidate actions by attending the BERT state over the candidate features
    (despite the name, this module contains no LSTM). '''

    def __init__(self, hidden_size, dropout_ratio, feature_size=2048+4):
        super(AttnDecoderLSTM, self).__init__()
        self.drop = nn.Dropout(p=dropout_ratio)
        self.drop_env = nn.Dropout(p=args.featdropout)
        self.candidate_att_layer = SoftDotAttention(768, feature_size)  # 768 is the output feature dimension of BERT

    def forward(self, h_t, cand_feat, already_dropfeat=False):
        if not already_dropfeat:
            # Dropout only the visual part of the candidate features; keep the angle features intact.
            cand_feat[..., :-args.angle_feat_size] = self.drop_env(cand_feat[..., :-args.angle_feat_size])

        _, logit = self.candidate_att_layer(h_t, cand_feat, output_prob=False)

        return logit
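
# Illustrative usage sketch (assumed shapes): score 6 candidate views against the
# 768-dim BERT state; 2052 = 2048 image dims + 4 angle dims, matching the default
# feature_size above.
#
#   scorer = AttnDecoderLSTM(hidden_size=512, dropout_ratio=0.5, feature_size=2048+4)
#   h_t = torch.rand(4, 768)
#   cand_feat = torch.rand(4, 6, 2048+4)
#   logit = scorer(h_t, cand_feat)   # batch x num_candidates, unnormalized scores
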

class Critic(nn.Module):
    def __init__(self):
        super(Critic, self).__init__()
        self.state2value = nn.Sequential(
            nn.Linear(768, args.rnn_dim),
            nn.ReLU(),
            nn.Dropout(args.dropout),
            nn.Linear(args.rnn_dim, 1),
        )

    def forward(self, state):
        return self.state2value(state).squeeze()

class SpeakerEncoder(nn.Module):
    def __init__(self, feature_size, hidden_size, dropout_ratio, bidirectional):
        super().__init__()
        self.num_directions = 2 if bidirectional else 1
        self.hidden_size = hidden_size
        self.num_layers = 1
        self.feature_size = feature_size

        if bidirectional:
            print("BIDIR in speaker encoder!!")

        self.lstm = nn.LSTM(feature_size, self.hidden_size // self.num_directions, self.num_layers,
                            batch_first=True, dropout=dropout_ratio, bidirectional=bidirectional)
        self.drop = nn.Dropout(p=dropout_ratio)
        self.drop3 = nn.Dropout(p=args.featdropout)
        self.attention_layer = SoftDotAttention(self.hidden_size, feature_size)

        self.post_lstm = nn.LSTM(self.hidden_size, self.hidden_size // self.num_directions, self.num_layers,
                                 batch_first=True, dropout=dropout_ratio, bidirectional=bidirectional)

    def forward(self, action_embeds, feature, lengths, already_dropfeat=False):
        """
        :param action_embeds: (batch_size, length, feature_size). The action taken (with the image feature)
        :param feature: (batch_size, length, 36, feature_size). The features of the 36 panoramic views
        :param lengths: not used here
        :return: context with shape (batch_size, length, hidden_size)
        """
        x = action_embeds
        if not already_dropfeat:
            # Dropout only the image part; the angle features are left intact.
            x[..., :-args.angle_feat_size] = self.drop3(x[..., :-args.angle_feat_size])

        # LSTM on the action embeddings
        ctx, _ = self.lstm(x)
        ctx = self.drop(ctx)

        # Attention over the panoramic features; flatten (batch, length) to attend in one batch
        batch_size, max_length, _ = ctx.size()
        if not already_dropfeat:
            feature[..., :-args.angle_feat_size] = self.drop3(feature[..., :-args.angle_feat_size])  # Dropout the image features
        x, _ = self.attention_layer(  # Attend to the feature map
            ctx.contiguous().view(-1, self.hidden_size),                    # (batch, length, hidden) --> (batch x length, hidden)
            feature.view(batch_size * max_length, -1, self.feature_size),   # (batch, length, # of views, feature_size) --> (batch x length, # of views, feature_size)
        )
        x = x.view(batch_size, max_length, -1)
        x = self.drop(x)

        # Post-attention LSTM layer
        x, _ = self.post_lstm(x)
        x = self.drop(x)

        return x
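
# Illustrative usage sketch (assumed shapes, following the docstring): a batch of
# 4 trajectories of length 7, with 36 panoramic views per step and
# feature_size = 2048 + 128 per view.
#
#   encoder = SpeakerEncoder(feature_size=2048+128, hidden_size=512,
#                            dropout_ratio=0.5, bidirectional=True)
#   action_embeds = torch.rand(4, 7, 2048+128)
#   feature = torch.rand(4, 7, 36, 2048+128)
#   ctx = encoder(action_embeds, feature, lengths=None)   # (4, 7, 512)
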

class SpeakerDecoder(nn.Module):
    def __init__(self, vocab_size, embedding_size, padding_idx, hidden_size, dropout_ratio):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = torch.nn.Embedding(vocab_size, embedding_size, padding_idx)
        self.lstm = nn.LSTM(embedding_size, hidden_size, batch_first=True)
        self.drop = nn.Dropout(dropout_ratio)
        self.attention_layer = SoftDotAttention(hidden_size, hidden_size)
        self.projection = nn.Linear(hidden_size, vocab_size)
        self.baseline_projection = nn.Sequential(
            nn.Linear(hidden_size, 128),
            nn.ReLU(),
            nn.Dropout(dropout_ratio),
            nn.Linear(128, 1)
        )

    def forward(self, words, ctx, ctx_mask, h0, c0):
        embeds = self.embedding(words)
        embeds = self.drop(embeds)
        x, (h1, c1) = self.lstm(embeds, (h0, c0))

        x = self.drop(x)

        # Get the sizes
        batchXlength = words.size(0) * words.size(1)
        multiplier = batchXlength // ctx.size(0)  # expanding ctx this way also supports beam search

        # Attention; reshape as follows:
        #   x <the LSTM output>      --> (b(word) * l(word), hidden)
        #   ctx      (b, a, hidden)  --> (b(word) * l(word), a, hidden)
        #   ctx_mask (b, a)          --> (b(word) * l(word), a)
        x, _ = self.attention_layer(
            x.contiguous().view(batchXlength, self.hidden_size),
            ctx.unsqueeze(1).expand(-1, multiplier, -1, -1).contiguous().view(batchXlength, -1, self.hidden_size),
            mask=ctx_mask.unsqueeze(1).expand(-1, multiplier, -1).contiguous().view(batchXlength, -1)
        )
        x = x.view(words.size(0), words.size(1), self.hidden_size)

        # Output the prediction logits
        x = self.drop(x)
        logit = self.projection(x)

        return logit, h1, c1
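
# Illustrative usage sketch (assumed sizes): decode a 15-word batch against a
# (batch, length, hidden) context such as the one produced by SpeakerEncoder.
#
#   decoder = SpeakerDecoder(vocab_size=1000, embedding_size=256, padding_idx=0,
#                            hidden_size=512, dropout_ratio=0.5)
#   words = torch.randint(0, 1000, (4, 15))
#   ctx = torch.rand(4, 7, 512)
#   ctx_mask = torch.zeros(4, 7, dtype=torch.bool)
#   h0 = c0 = torch.zeros(1, 4, 512)
#   logit, h1, c1 = decoder(words, ctx, ctx_mask, h0, c0)   # logit: (4, 15, 1000)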