import torch
import torch.nn as nn

from param import args

from vlnbert.vlnbert_model import get_vlnbert_models


class VLNBERT(nn.Module):
    def __init__(self, directions=4, feature_size=2048+128):
        super(VLNBERT, self).__init__()
        print('\nInitializing the VLN-BERT model ...')

        self.vln_bert = get_vlnbert_models(config=None)  # initialize the VLN-BERT
        self.vln_bert.config.directions = directions

        hidden_size = self.vln_bert.config.hidden_size
        layer_norm_eps = self.vln_bert.config.layer_norm_eps

        self.action_state_project = nn.Sequential(
            nn.Linear(hidden_size+args.angle_feat_size, hidden_size), nn.Tanh())
        self.action_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps)

        self.obj_pos_encode = nn.Linear(5, args.angle_feat_size, bias=True)
        self.obj_projection = nn.Linear(feature_size+args.angle_feat_size, hidden_size, bias=True)
        self.obj_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps)

        self.drop_env = nn.Dropout(p=args.featdropout)
        self.img_projection = nn.Linear(feature_size, hidden_size, bias=True)
        self.cand_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps)

        self.state_proj = nn.Linear(hidden_size*2, hidden_size, bias=True)
        self.state_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, mode, sentence, token_type_ids=None,
                attention_mask=None, lang_mask=None, vis_mask=None, obj_mask=None,
                position_ids=None, action_feats=None, pano_feats=None,
                cand_feats=None, obj_feats=None, obj_pos=None, already_dropfeat=False):

        if mode == 'language':
            init_state, encoded_sentence = self.vln_bert(mode, sentence,
                position_ids=position_ids, token_type_ids=token_type_ids,
                attention_mask=attention_mask, lang_mask=lang_mask)
            return init_state, encoded_sentence

        elif mode == 'visual':
            # fuse the previous agent state with the current action (angle) feature
            state_action_embed = torch.cat((sentence[:, 0, :], action_feats), 1)
            state_with_action = self.action_state_project(state_action_embed)
            state_with_action = self.action_LayerNorm(state_with_action)
            state_feats = torch.cat((state_with_action.unsqueeze(1), sentence[:, 1:, :]), dim=1)

            if not already_dropfeat:
                # dropout applies to the visual part only, never the angle features
                cand_feats[..., :-args.angle_feat_size] = self.drop_env(cand_feats[..., :-args.angle_feat_size])
                obj_feats[..., :-args.angle_feat_size] = self.drop_env(obj_feats[..., :-args.angle_feat_size])

            # project candidate features (feature_size -> hidden_size, e.g. 2176 -> 768)
            cand_feats_embed = self.img_projection(cand_feats)
            cand_feats_embed = self.cand_LayerNorm(cand_feats_embed)

            # encode the 5-dim object bounding-box position and splice it between
            # the visual and angle parts of the object feature
            obj_feats_embed = self.obj_pos_encode(obj_pos)
            obj_feats_concat = torch.cat((obj_feats[..., :-args.angle_feat_size],
                obj_feats_embed, obj_feats[..., -args.angle_feat_size:]), dim=-1)
            obj_feats_embed = self.obj_projection(obj_feats_concat)
            obj_feats_embed = self.obj_LayerNorm(obj_feats_embed)

            cand_obj_feats_embed = torch.cat((cand_feats_embed, obj_feats_embed), dim=1)

            # logit is the attention scores over the candidate features
            h_t, logit, logit_obj, attended_visual = self.vln_bert(mode, state_feats,
                attention_mask=attention_mask, lang_mask=lang_mask, vis_mask=vis_mask,
                obj_mask=obj_mask, img_feats=cand_obj_feats_embed)

            state_output = torch.cat((h_t, attended_visual), dim=-1)
            state_proj = self.state_proj(state_output)
            state_proj = self.state_LayerNorm(state_proj)

            return state_proj, logit, logit_obj

        else:
            raise ValueError(f'Invalid mode: {mode}')
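
# The feature bookkeeping in the visual branch above is easy to misread, so here
# is a minimal shape sketch of the object branch (illustration only; the function
# name is ours, and it assumes the defaults feature_size = 2048 + 128 and
# args.angle_feat_size = 128, with a throwaway linear in place of the trained
# obj_pos_encode):
def _sketch_obj_branch_shapes(batch=2, n_obj=3, angle_feat_size=128, feature_size=2048+128):
    obj_feats = torch.randn(batch, n_obj, feature_size)  # 2048-dim visual + 128-dim angle part
    obj_pos = torch.randn(batch, n_obj, 5)               # 5-dim bounding-box position
    pos_embed = nn.Linear(5, angle_feat_size)(obj_pos)   # (batch, n_obj, 128), like obj_pos_encode
    concat = torch.cat((obj_feats[..., :-angle_feat_size],   # visual part (batch, n_obj, 2048)
                        pos_embed,                           # position    (batch, n_obj, 128)
                        obj_feats[..., -angle_feat_size:]),  # angle part  (batch, n_obj, 128)
                       dim=-1)
    # 2048 + 128 + 128 == feature_size + angle_feat_size, i.e. obj_projection's in_features
    assert concat.shape == (batch, n_obj, feature_size + angle_feat_size)
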
class SoftDotAttention(nn.Module):
    '''Soft Dot Attention.

    Ref: http://www.aclweb.org/anthology/D15-1166
    Adapted from PyTorch OPEN NMT.
    '''

    def __init__(self, query_dim, ctx_dim):
        '''Initialize layer.'''
        super(SoftDotAttention, self).__init__()
        self.linear_in = nn.Linear(query_dim, ctx_dim, bias=False)
        self.sm = nn.Softmax(dim=1)  # softmax over the seq_len dimension
        self.linear_out = nn.Linear(query_dim + ctx_dim, query_dim, bias=False)
        self.tanh = nn.Tanh()

    def forward(self, h, context, mask=None,
                output_tilde=True, output_prob=True, input_project=True):
        '''Propagate h through the network.

        h: batch x dim
        context: batch x seq_len x dim
        mask: batch x seq_len, indices to be masked
        '''
        if input_project:
            target = self.linear_in(h).unsqueeze(2)  # batch x dim x 1
        else:
            target = h.unsqueeze(2)  # batch x dim x 1

        # Get attention
        attn = torch.bmm(context, target).squeeze(2)  # batch x seq_len
        logit = attn

        if mask is not None:
            # -Inf masking prior to the softmax
            attn.masked_fill_(mask, -float('inf'))
        # NOTE: a row whose positions are all masked comes out of the softmax as
        # NaN; this stems from softmax over all -inf entries in PyTorch, not
        # from this module
        attn = self.sm(attn)
        attn3 = attn.view(attn.size(0), 1, attn.size(1))  # batch x 1 x seq_len

        weighted_context = torch.bmm(attn3, context).squeeze(1)  # batch x dim
        if not output_prob:
            attn = logit
        if output_tilde:
            h_tilde = torch.cat((weighted_context, h), 1)
            h_tilde = self.tanh(self.linear_out(h_tilde))
            return h_tilde, attn
        else:
            return weighted_context, attn


class BertLayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root)."""
        super(BertLayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias
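
# A minimal usage sketch of the two helpers above (illustration only; the
# function name and sizes are assumptions, not part of the model). It runs a
# 512-dim query over a padded 640-dim context, masks the padding position, and
# checks that BertLayerNorm matches torch's nn.LayerNorm at eps=1e-12:
def _sketch_attention_and_layernorm(batch=2, seq_len=4, query_dim=512, ctx_dim=640):
    attn_layer = SoftDotAttention(query_dim, ctx_dim)
    h = torch.randn(batch, query_dim)                # query:   batch x dim
    context = torch.randn(batch, seq_len, ctx_dim)   # context: batch x seq_len x dim
    mask = torch.zeros(batch, seq_len, dtype=torch.bool)
    mask[:, -1] = True                               # treat the last position as padding
    h_tilde, attn = attn_layer(h, context, mask=mask)
    assert h_tilde.shape == (batch, query_dim) and attn.shape == (batch, seq_len)
    assert attn[:, -1].eq(0).all()                   # the masked position gets zero weight

    x = torch.randn(batch, query_dim)
    assert torch.allclose(BertLayerNorm(query_dim)(x),
                          nn.LayerNorm(query_dim, eps=1e-12)(x), atol=1e-5)
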
class AttnDecoderLSTM(nn.Module):
    '''An unrolled LSTM with attention over instructions for decoding navigation actions.'''

    def __init__(self, hidden_size, dropout_ratio, feature_size=2048+4):
        super(AttnDecoderLSTM, self).__init__()
        self.drop = nn.Dropout(p=dropout_ratio)
        self.drop_env = nn.Dropout(p=args.featdropout)
        self.candidate_att_layer = SoftDotAttention(768, feature_size)  # 768 is the output feature dimension from BERT

    def forward(self, h_t, cand_feat, already_dropfeat=False):
        if not already_dropfeat:
            cand_feat[..., :-args.angle_feat_size] = self.drop_env(cand_feat[..., :-args.angle_feat_size])
        _, logit = self.candidate_att_layer(h_t, cand_feat, output_prob=False)
        return logit


class Critic(nn.Module):
    def __init__(self):
        super(Critic, self).__init__()
        self.state2value = nn.Sequential(
            nn.Linear(768, args.rnn_dim),
            nn.ReLU(),
            nn.Dropout(args.dropout),
            nn.Linear(args.rnn_dim, 1),
        )

    def forward(self, state):
        return self.state2value(state).squeeze()
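
# A minimal sketch of how the candidate scorer and the critic plug together
# during one decoding step (illustration only; the function name and sizes are
# our assumptions: a 768-dim BERT state, 8 candidate directions, the default
# feature_size = 2048 + 4, and args already configured, since Critic and
# AttnDecoderLSTM read args.rnn_dim / args.dropout / args.featdropout):
def _sketch_decode_step(batch=2, n_cand=8, feature_size=2048+4):
    decoder = AttnDecoderLSTM(hidden_size=768, dropout_ratio=0.5, feature_size=feature_size)
    critic = Critic()
    h_t = torch.randn(batch, 768)                            # recurrent state from VLN-BERT
    cand_feat = torch.randn(batch, n_cand, feature_size)     # candidate view features
    logit = decoder(h_t, cand_feat, already_dropfeat=True)   # batch x n_cand action scores
    value = critic(h_t)                                      # per-sample value estimate (RL baseline)
    assert logit.shape == (batch, n_cand) and value.shape == (batch,)
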
class SpeakerEncoder(nn.Module):
    def __init__(self, feature_size, hidden_size, dropout_ratio, bidirectional):
        super().__init__()
        self.num_directions = 2 if bidirectional else 1
        self.hidden_size = hidden_size
        self.num_layers = 1
        self.feature_size = feature_size

        if bidirectional:
            print("BIDIR in speaker encoder!!")

        self.lstm = nn.LSTM(feature_size, self.hidden_size // self.num_directions, self.num_layers,
                            batch_first=True, dropout=dropout_ratio, bidirectional=bidirectional)
        self.drop = nn.Dropout(p=dropout_ratio)
        self.drop3 = nn.Dropout(p=args.featdropout)
        self.attention_layer = SoftDotAttention(self.hidden_size, feature_size)

        self.post_lstm = nn.LSTM(self.hidden_size, self.hidden_size // self.num_directions, self.num_layers,
                                 batch_first=True, dropout=dropout_ratio, bidirectional=bidirectional)

    def forward(self, action_embeds, feature, lengths, already_dropfeat=False):
        """
        :param action_embeds: (batch_size, length, 2052). The actions taken (with the image feature)
        :param feature: (batch_size, length, 36, 2052). The panoramic view features
        :param lengths: not used here
        :return: context with shape (batch_size, length, hidden_size)
        """
        x = action_embeds
        if not already_dropfeat:
            # do not dropout the angle features
            x[..., :-args.angle_feat_size] = self.drop3(x[..., :-args.angle_feat_size])

        # LSTM on the action embed
        ctx, _ = self.lstm(x)
        ctx = self.drop(ctx)

        # attention over the panorama, reshaped so each time step attends independently
        batch_size, max_length, _ = ctx.size()
        if not already_dropfeat:
            # dropout the image feature
            feature[..., :-args.angle_feat_size] = self.drop3(feature[..., :-args.angle_feat_size])

        x, _ = self.attention_layer(                      # attend to the feature map
            ctx.contiguous().view(-1, self.hidden_size),  # (batch, length, hidden) --> (batch x length, hidden)
            feature.view(batch_size * max_length, -1, self.feature_size),  # (batch, length, # of images, feature_size) --> (batch x length, # of images, feature_size)
        )
        x = x.view(batch_size, max_length, -1)
        x = self.drop(x)

        # post-LSTM layer
        x, _ = self.post_lstm(x)
        x = self.drop(x)

        return x


class SpeakerDecoder(nn.Module):
    def __init__(self, vocab_size, embedding_size, padding_idx, hidden_size, dropout_ratio):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = torch.nn.Embedding(vocab_size, embedding_size, padding_idx)
        self.lstm = nn.LSTM(embedding_size, hidden_size, batch_first=True)
        self.drop = nn.Dropout(dropout_ratio)
        self.attention_layer = SoftDotAttention(hidden_size, hidden_size)
        self.projection = nn.Linear(hidden_size, vocab_size)
        self.baseline_projection = nn.Sequential(
            nn.Linear(hidden_size, 128),
            nn.ReLU(),
            nn.Dropout(dropout_ratio),
            nn.Linear(128, 1)
        )

    def forward(self, words, ctx, ctx_mask, h0, c0):
        embeds = self.embedding(words)
        embeds = self.drop(embeds)
        x, (h1, c1) = self.lstm(embeds, (h0, c0))
        x = self.drop(x)

        # Get the size
        batchXlength = words.size(0) * words.size(1)
        multiplier = batchXlength // ctx.size(0)  # this also supports beam search

        # Att and handle the shapes:
        #   Reshape x           (b, l, r) --> (b(word)*l(word), r)
        #   Expand the ctx      (b, a, r) --> (b(word)*l(word), a, r)
        #   Expand the ctx_mask (b, a)    --> (b(word)*l(word), a)
        x, _ = self.attention_layer(
            x.contiguous().view(batchXlength, self.hidden_size),
            ctx.unsqueeze(1).expand(-1, multiplier, -1, -1).contiguous().view(batchXlength, -1, self.hidden_size),
            mask=ctx_mask.unsqueeze(1).expand(-1, multiplier, -1).contiguous().view(batchXlength, -1)
        )
        x = x.view(words.size(0), words.size(1), self.hidden_size)

        # Output the prediction logit
        x = self.drop(x)
        logit = self.projection(x)

        return logit, h1, c1
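
# A minimal end-to-end sketch of the speaker pipeline (illustration only; the
# function name and sizes are our assumptions: feature_size = 2052 as in the
# docstrings above, a toy 100-word vocabulary, and args already configured,
# since SpeakerEncoder reads args.featdropout):
def _sketch_speaker(batch=2, path_len=5, n_views=36, word_len=7,
                    feature_size=2052, hidden_size=512, vocab_size=100):
    encoder = SpeakerEncoder(feature_size, hidden_size, dropout_ratio=0.5, bidirectional=True)
    decoder = SpeakerDecoder(vocab_size, embedding_size=256, padding_idx=0,
                             hidden_size=hidden_size, dropout_ratio=0.5)

    action_embeds = torch.randn(batch, path_len, feature_size)           # actions taken along the path
    feature = torch.randn(batch, path_len, n_views, feature_size)        # panorama at each step
    ctx = encoder(action_embeds, feature, lengths=None, already_dropfeat=True)  # (batch, path_len, hidden)

    words = torch.randint(1, vocab_size, (batch, word_len))              # teacher-forced word ids
    h0 = torch.zeros(1, batch, hidden_size)
    c0 = torch.zeros(1, batch, hidden_size)
    ctx_mask = torch.zeros(batch, path_len, dtype=torch.bool)            # no path positions masked
    logit, h1, c1 = decoder(words, ctx, ctx_mask, h0, c0)
    assert logit.shape == (batch, word_len, vocab_size)                  # per-word vocabulary logits
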