# adversarial_VLNBERT/r2r_src/model.py

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from param import args
from vlnbert.vlnbert_model import get_vlnbert_models


class VLNBERT(nn.Module):
    def __init__(self, directions=4, feature_size=2048+128):
        super(VLNBERT, self).__init__()
        print('\nInitializing the VLN-BERT model ...')

        self.vln_bert = get_vlnbert_models(config=None)  # initialize the VLN-BERT backbone
        self.vln_bert.config.directions = directions
        hidden_size = self.vln_bert.config.hidden_size
        layer_norm_eps = self.vln_bert.config.layer_norm_eps

        # project the concatenated [state; action angle feature] back to hidden_size
        self.action_state_project = nn.Sequential(
            nn.Linear(hidden_size+args.angle_feat_size, hidden_size), nn.Tanh())
        self.action_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps)

        # object branch: encode the 5-dim object position, then project to hidden_size
        self.obj_pos_encode = nn.Linear(5, args.angle_feat_size, bias=True)
        self.obj_projection = nn.Linear(feature_size+args.angle_feat_size, hidden_size, bias=True)
        self.obj_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps)

        self.drop_env = nn.Dropout(p=args.featdropout)
        self.img_projection = nn.Linear(feature_size, hidden_size, bias=True)
        self.cand_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps)

        self.state_proj = nn.Linear(hidden_size*2, hidden_size, bias=True)
        self.state_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, mode, sentence, token_type_ids=None, attention_mask=None,
                lang_mask=None, vis_mask=None, obj_mask=None,
                position_ids=None, action_feats=None, pano_feats=None, cand_feats=None,
                obj_feats=None, obj_pos=None, already_dropfeat=False):

        if mode == 'language':
            init_state, encoded_sentence = self.vln_bert(mode, sentence, position_ids=position_ids,
                token_type_ids=token_type_ids, attention_mask=attention_mask, lang_mask=lang_mask)
            return init_state, encoded_sentence

        elif mode == 'visual':
            state_action_embed = torch.cat((sentence[:, 0, :], action_feats), 1)
            state_with_action = self.action_state_project(state_action_embed)
            state_with_action = self.action_LayerNorm(state_with_action)
            state_feats = torch.cat((state_with_action.unsqueeze(1), sentence[:, 1:, :]), dim=1)

            if not already_dropfeat:
                cand_feats[..., :-args.angle_feat_size] = self.drop_env(cand_feats[..., :-args.angle_feat_size])
                obj_feats[..., :-args.angle_feat_size] = self.drop_env(obj_feats[..., :-args.angle_feat_size])

            cand_feats_embed = self.img_projection(cand_feats)  # [2176 * 768] projection
            cand_feats_embed = self.cand_LayerNorm(cand_feats_embed)

            obj_feats_embed = self.obj_pos_encode(obj_pos)
            obj_feats_concat = torch.cat((obj_feats[..., :-args.angle_feat_size], obj_feats_embed, obj_feats[..., -args.angle_feat_size:]), dim=-1)
            obj_feats_embed = self.obj_projection(obj_feats_concat)
            obj_feats_embed = self.obj_LayerNorm(obj_feats_embed)

            cand_obj_feats_embed = torch.cat((cand_feats_embed, obj_feats_embed), dim=1)

            # logit is the attention scores over the candidate features
            h_t, logit, logit_obj, attended_visual = self.vln_bert(mode,
                state_feats, attention_mask=attention_mask,
                lang_mask=lang_mask, vis_mask=vis_mask, obj_mask=obj_mask,
                img_feats=cand_obj_feats_embed)

            state_output = torch.cat((h_t, attended_visual), dim=-1)
            state_proj = self.state_proj(state_output)
            state_proj = self.state_LayerNorm(state_proj)

            return state_proj, logit, logit_obj

        else:
            raise NotImplementedError("Unknown mode for VLNBERT.forward: expected 'language' or 'visual', got %r" % mode)
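

# Illustrative call pattern (a hedged sketch added for readability, not part of the original
# repo): the calling agent is expected to run the 'language' pass once per instruction and the
# 'visual' pass once per navigation step, feeding the visual pass the encoded sentence whose
# slot 0 acts as the recurrent state (see sentence[:, 0, :] above). Tensor names below are
# placeholders and the shapes are assumptions based on the defaults here (hidden_size=768,
# feature_size=2048+128, angle_feat_size=128); all masks are built by the caller.
#
#   model = VLNBERT(directions=4)
#   init_state, encoded_sentence = model('language', instr_tokens,
#                                        token_type_ids=token_type_ids,
#                                        attention_mask=attention_mask, lang_mask=lang_mask)
#   state_proj, logit, logit_obj = model('visual', encoded_sentence,
#                                        attention_mask=attention_mask, lang_mask=lang_mask,
#                                        vis_mask=vis_mask, obj_mask=obj_mask,
#                                        action_feats=action_feats,    # (batch, angle_feat_size)
#                                        cand_feats=cand_feats,        # (batch, n_cand, 2176)
#                                        obj_feats=obj_feats,          # (batch, n_obj, 2176)
#                                        obj_pos=obj_pos)              # (batch, n_obj, 5)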


class SoftDotAttention(nn.Module):
    '''Soft Dot Attention.

    Ref: http://www.aclweb.org/anthology/D15-1166
    Adapted from PyTorch OPEN NMT.
    '''

    def __init__(self, query_dim, ctx_dim):
        '''Initialize layer.'''
        super(SoftDotAttention, self).__init__()
        self.linear_in = nn.Linear(query_dim, ctx_dim, bias=False)
        self.sm = nn.Softmax(dim=1)  # softmax over the seq_len dimension
        self.linear_out = nn.Linear(query_dim + ctx_dim, query_dim, bias=False)
        self.tanh = nn.Tanh()

    def forward(self, h, context, mask=None,
                output_tilde=True, output_prob=True, input_project=True):
        '''Propagate h through the network.

        h: batch x dim
        context: batch x seq_len x dim
        mask: batch x seq_len, indices to be masked
        '''
        if input_project:
            target = self.linear_in(h).unsqueeze(2)  # batch x dim x 1
        else:
            target = h.unsqueeze(2)  # batch x dim x 1

        # Get attention
        attn = torch.bmm(context, target).squeeze(2)  # batch x seq_len
        logit = attn

        if mask is not None:
            # -Inf masking prior to the softmax
            attn.masked_fill_(mask, -float('inf'))
        attn = self.sm(attn)
        attn3 = attn.view(attn.size(0), 1, attn.size(1))  # batch x 1 x seq_len

        weighted_context = torch.bmm(attn3, context).squeeze(1)  # batch x dim
        if not output_prob:
            attn = logit
        if output_tilde:
            h_tilde = torch.cat((weighted_context, h), 1)
            h_tilde = self.tanh(self.linear_out(h_tilde))
            return h_tilde, attn
        else:
            return weighted_context, attn
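

# A minimal, self-contained usage sketch of SoftDotAttention (illustrative only; the shapes and
# the _demo_soft_dot_attention name are not from the original code). It shows the expected input
# layout: a (batch, query_dim) query, a (batch, seq_len, ctx_dim) context, and an optional
# boolean (batch, seq_len) mask marking positions to exclude from the softmax.
def _demo_soft_dot_attention():
    batch, seq_len, query_dim, ctx_dim = 2, 5, 768, 2176
    layer = SoftDotAttention(query_dim, ctx_dim)
    h = torch.randn(batch, query_dim)
    context = torch.randn(batch, seq_len, ctx_dim)
    mask = torch.zeros(batch, seq_len, dtype=torch.bool)
    mask[:, -1] = True  # mask out the last context position
    h_tilde, attn = layer(h, context, mask=mask)
    assert h_tilde.shape == (batch, query_dim) and attn.shape == (batch, seq_len)
    return h_tilde, attn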


class BertLayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root)."""
        super(BertLayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias
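

# Side note (an illustrative check added for clarity, not in the original file): with the default
# affine parameters this hand-rolled layer norm matches torch.nn.LayerNorm over the last
# dimension, since both use the biased variance and add eps inside the square root; the separate
# class mainly mirrors the original BERT code.
def _check_bert_layer_norm():
    hidden = 768
    x = torch.randn(4, 10, hidden)
    ours = BertLayerNorm(hidden, eps=1e-12)
    ref = nn.LayerNorm(hidden, eps=1e-12)
    assert torch.allclose(ours(x), ref(x), atol=1e-5)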


class AttnDecoderLSTM(nn.Module):
    ''' Scores navigation candidates by attending from the decoder state over candidate features.
    (Despite the name, this module only holds the candidate soft-dot attention layer.) '''

    def __init__(self, hidden_size, dropout_ratio, feature_size=2048+4):
        super(AttnDecoderLSTM, self).__init__()
        self.drop = nn.Dropout(p=dropout_ratio)
        self.drop_env = nn.Dropout(p=args.featdropout)
        self.candidate_att_layer = SoftDotAttention(768, feature_size)  # 768 is the output feature dimension from BERT

    def forward(self, h_t, cand_feat, already_dropfeat=False):
        if not already_dropfeat:
            cand_feat[..., :-args.angle_feat_size] = self.drop_env(cand_feat[..., :-args.angle_feat_size])
        _, logit = self.candidate_att_layer(h_t, cand_feat, output_prob=False)
        return logit


class Critic(nn.Module):
    def __init__(self):
        super(Critic, self).__init__()
        self.state2value = nn.Sequential(
            nn.Linear(768, args.rnn_dim),
            nn.ReLU(),
            nn.Dropout(args.dropout),
            nn.Linear(args.rnn_dim, 1),
        )

    def forward(self, state):
        return self.state2value(state).squeeze()


class SpeakerEncoder(nn.Module):
    def __init__(self, feature_size, hidden_size, dropout_ratio, bidirectional):
        super().__init__()
        self.num_directions = 2 if bidirectional else 1
        self.hidden_size = hidden_size
        self.num_layers = 1
        self.feature_size = feature_size

        if bidirectional:
            print("BIDIR in speaker encoder!!")

        self.lstm = nn.LSTM(feature_size, self.hidden_size // self.num_directions, self.num_layers,
                            batch_first=True, dropout=dropout_ratio, bidirectional=bidirectional)
        self.drop = nn.Dropout(p=dropout_ratio)
        self.drop3 = nn.Dropout(p=args.featdropout)
        self.attention_layer = SoftDotAttention(self.hidden_size, feature_size)

        self.post_lstm = nn.LSTM(self.hidden_size, self.hidden_size // self.num_directions, self.num_layers,
                                 batch_first=True, dropout=dropout_ratio, bidirectional=bidirectional)

    def forward(self, action_embeds, feature, lengths, already_dropfeat=False):
        """
        :param action_embeds: (batch_size, length, feature_size). The action taken (with the image feature)
        :param feature: (batch_size, length, 36, feature_size). The panoramic view features
        :param lengths: not used here
        :return: context with shape (batch_size, length, hidden_size)
        """
        x = action_embeds
        if not already_dropfeat:
            x[..., :-args.angle_feat_size] = self.drop3(x[..., :-args.angle_feat_size])  # keep the angle features (last angle_feat_size dims) intact

        # LSTM on the action embed
        ctx, _ = self.lstm(x)
        ctx = self.drop(ctx)

        # Attention and shape handling
        batch_size, max_length, _ = ctx.size()
        if not already_dropfeat:
            feature[..., :-args.angle_feat_size] = self.drop3(feature[..., :-args.angle_feat_size])  # Dropout the image feature
        x, _ = self.attention_layer(                                       # Attend to the feature map
            ctx.contiguous().view(-1, self.hidden_size),                   # (batch, length, hidden) --> (batch x length, hidden)
            feature.view(batch_size * max_length, -1, self.feature_size),  # (batch, length, # of images, feature_size) --> (batch x length, # of images, feature_size)
        )
        x = x.view(batch_size, max_length, -1)
        x = self.drop(x)

        # Post LSTM layer
        x, _ = self.post_lstm(x)
        x = self.drop(x)

        return x
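

# Illustrative shapes for SpeakerEncoder.forward (a hedged sketch, not from the original code;
# the hidden size and batch/length values are arbitrary, and feature_size is assumed to be
# 2048 image dims plus args.angle_feat_size angle dims). action_embeds holds one feature per
# taken action and feature holds the 36 panoramic views at each step.
def _demo_speaker_encoder():
    batch, length = 2, 7
    feature_size = 2048 + args.angle_feat_size
    encoder = SpeakerEncoder(feature_size, hidden_size=512, dropout_ratio=0.5, bidirectional=True)
    action_embeds = torch.randn(batch, length, feature_size)
    feature = torch.randn(batch, length, 36, feature_size)
    ctx = encoder(action_embeds, feature, lengths=None)
    assert ctx.shape == (batch, length, 512)
    return ctx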


class SpeakerDecoder(nn.Module):
    def __init__(self, vocab_size, embedding_size, padding_idx, hidden_size, dropout_ratio):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = torch.nn.Embedding(vocab_size, embedding_size, padding_idx)
        self.lstm = nn.LSTM(embedding_size, hidden_size, batch_first=True)
        self.drop = nn.Dropout(dropout_ratio)
        self.attention_layer = SoftDotAttention(hidden_size, hidden_size)
        self.projection = nn.Linear(hidden_size, vocab_size)
        self.baseline_projection = nn.Sequential(
            nn.Linear(hidden_size, 128),
            nn.ReLU(),
            nn.Dropout(dropout_ratio),
            nn.Linear(128, 1)
        )

    def forward(self, words, ctx, ctx_mask, h0, c0):
        embeds = self.embedding(words)
        embeds = self.drop(embeds)
        x, (h1, c1) = self.lstm(embeds, (h0, c0))
        x = self.drop(x)

        # Get the size
        batchXlength = words.size(0) * words.size(1)
        multiplier = batchXlength // ctx.size(0)  # By using this, it also supports the beam search

        # Attention and shape handling:
        #   reshape x (the LSTM output)  (b, l, r) --> (b*l, r)
        #   expand ctx from              (b, a, r) --> (b*l, a, r)
        #   expand ctx_mask from         (b, a)    --> (b*l, a)
        x, _ = self.attention_layer(
            x.contiguous().view(batchXlength, self.hidden_size),
            ctx.unsqueeze(1).expand(-1, multiplier, -1, -1).contiguous().view(batchXlength, -1, self.hidden_size),
            mask=ctx_mask.unsqueeze(1).expand(-1, multiplier, -1).contiguous().view(batchXlength, -1)
        )
        x = x.view(words.size(0), words.size(1), self.hidden_size)

        # Output the prediction logit
        x = self.drop(x)
        logit = self.projection(x)

        return logit, h1, c1
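

# Illustrative shapes for SpeakerDecoder.forward (a hedged sketch, not from the original code;
# vocab size, hidden size and sequence lengths below are arbitrary). ctx is the encoder output,
# ctx_mask marks padded context positions, and (h0, c0) is the (1, batch, hidden) initial LSTM
# state; in the plain (non-beam) case the multiplier above equals the word-sequence length.
def _demo_speaker_decoder():
    vocab_size, embedding_size, hidden_size = 1000, 256, 512
    batch, word_len, ctx_len = 2, 6, 7
    decoder = SpeakerDecoder(vocab_size, embedding_size, padding_idx=0,
                             hidden_size=hidden_size, dropout_ratio=0.5)
    words = torch.randint(1, vocab_size, (batch, word_len))
    ctx = torch.randn(batch, ctx_len, hidden_size)
    ctx_mask = torch.zeros(batch, ctx_len, dtype=torch.bool)  # no padded context positions here
    h0 = torch.zeros(1, batch, hidden_size)
    c0 = torch.zeros(1, batch, hidden_size)
    logit, h1, c1 = decoder(words, ctx, ctx_mask, h0, c0)
    assert logit.shape == (batch, word_len, vocab_size)
    return logit, h1, c1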