# Copyright (c) 2020 Microsoft Corporation. Licensed under the MIT license.
# Modified in Recurrent VLN-BERT, 2020, Yicong.Hong@anu.edu.au

from __future__ import absolute_import, division, print_function, unicode_literals

import logging
import math

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss

# from transformers.pytorch_transformers.modeling_bert import (BertEmbeddings,
#     BertSelfAttention, BertAttention, BertEncoder, BertLayer,
#     BertSelfOutput, BertIntermediate, BertOutput,
#     BertPooler, BertLayerNorm, BertPreTrainedModel,
#     BertPredictionHeadTransform)

from pytorch_transformers.modeling_bert import (BertEmbeddings,
    BertSelfAttention, BertAttention, BertEncoder, BertLayer,
    BertSelfOutput, BertIntermediate, BertOutput,
    BertPooler, BertLayerNorm, BertPreTrainedModel,
    BertPredictionHeadTransform)

logger = logging.getLogger(__name__)

class CaptionBertSelfAttention(BertSelfAttention):
    """
    Modified from BertSelfAttention to add support for output_hidden_states.
    """
    def __init__(self, config):
        super(CaptionBertSelfAttention, self).__init__(config)
        self.config = config

    def forward(self, mode, hidden_states, attention_mask, head_mask=None,
                history_state=None):
        if history_state is not None:
            # Prepend cached states: keys/values attend over history + current
            # tokens, while queries come from the current tokens only.
            x_states = torch.cat([history_state, hidden_states], dim=1)
            mixed_query_layer = self.query(hidden_states)
            mixed_key_layer = self.key(x_states)
            mixed_value_layer = self.value(x_states)
        else:
            mixed_query_layer = self.query(hidden_states)
            mixed_key_layer = self.key(hidden_states)
            mixed_value_layer = self.value(hidden_states)

        if mode == 'visual':
            # The language features only provide keys and values: queries are
            # taken from the state token (position 0) and the candidate
            # direction tokens (the last `directions` positions) alone.
            mixed_query_layer = mixed_query_layer[:, [0] + list(range(-self.config.directions, 0)), :]

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # Apply the attention mask (precomputed for all layers in the BertModel forward() function).
        attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to.
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        # Merge the heads back: (batch, heads, seq, head_size) -> (batch, seq, hidden).
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_scores)

        return outputs
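
    # Note on the visual-mode query selection above: with the sequence laid
    # out as [state, language tokens, direction tokens], the index list
    # [0] + list(range(-directions, 0)) keeps the state row plus the last
    # `directions` rows, so only those positions produce queries (and hence
    # outputs); the language tokens act purely as keys and values.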


class CaptionBertAttention(BertAttention):
    """
    Modified from BertAttention to add support for output_hidden_states.
    """
    def __init__(self, config):
        super(CaptionBertAttention, self).__init__(config)
        self.self = CaptionBertSelfAttention(config)
        self.output = BertSelfOutput(config)
        self.config = config

    def forward(self, mode, input_tensor, attention_mask, head_mask=None,
                history_state=None):
        ''' transformer self-attention '''
        self_outputs = self.self(mode, input_tensor, attention_mask, head_mask, history_state)

        ''' output projection with residual connection '''
        if mode == 'visual':
            # The self-attention only returned outputs for the state and
            # direction tokens, so the residual uses the matching input slice.
            attention_output = self.output(self_outputs[0], input_tensor[:, [0] + list(range(-self.config.directions, 0)), :])
        elif mode == 'language':
            attention_output = self.output(self_outputs[0], input_tensor)

        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them

        return outputs
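
    # BertSelfOutput applies a linear projection and dropout, then
    # LayerNorm(residual + projection), i.e. the standard post-attention
    # residual block of BERT.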


class CaptionBertLayer(BertLayer):
    """
    Modified from BertLayer to add support for output_hidden_states.
    """
    def __init__(self, config):
        super(CaptionBertLayer, self).__init__(config)
        self.attention = CaptionBertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, mode, hidden_states, attention_mask, head_mask=None,
                history_state=None):
        attention_outputs = self.attention(mode, hidden_states, attention_mask,
                                           head_mask, history_state)

        ''' feed-forward network with residual connection '''
        attention_output = attention_outputs[0]
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        outputs = (layer_output,) + attention_outputs[1:]

        return outputs
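
    # BertIntermediate (hidden -> intermediate, GELU) and BertOutput
    # (intermediate -> hidden, dropout, LayerNorm with residual) together form
    # the position-wise feed-forward block.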


class CaptionBertEncoder(BertEncoder):
    """
    Modified from BertEncoder to add support for output_hidden_states.
    """
    def __init__(self, config):
        super(CaptionBertEncoder, self).__init__(config)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        # num_hidden_layers (12 in BERT-base) Transformer layers
        self.layer = nn.ModuleList([CaptionBertLayer(config) for _ in range(config.num_hidden_layers)])
        self.config = config

    def forward(self, mode, hidden_states, attention_mask, head_mask=None,
                encoder_history_states=None):
        if mode == 'visual':
            for i, layer_module in enumerate(self.layer):
                history_state = None if encoder_history_states is None else encoder_history_states[i]

                layer_outputs = layer_module(mode,
                                             hidden_states, attention_mask, head_mask[i],
                                             history_state)

                # Each layer only updates the state token and the direction
                # tokens; the language segment is carried over unchanged.
                concat_layer_outputs = torch.cat(
                    (layer_outputs[0][:, 0:1, :],
                     hidden_states[:, 1:-self.config.directions, :],
                     layer_outputs[0][:, 1:self.config.directions+1, :]), 1)
                hidden_states = concat_layer_outputs

                if i == self.config.num_hidden_layers - 1:
                    state_attention_score = layer_outputs[1][:, :, 0, :]
                    lang_attention_score = layer_outputs[1][:, :, -self.config.directions:, 1:-self.config.directions]
                    vis_attention_score = layer_outputs[1][:, :, :, :]

            outputs = (hidden_states, state_attention_score, lang_attention_score, vis_attention_score)

        elif mode == 'language':
            for i, layer_module in enumerate(self.layer):
                history_state = None if encoder_history_states is None else encoder_history_states[i]  # default None

                layer_outputs = layer_module(mode,
                                             hidden_states, attention_mask, head_mask[i],
                                             history_state)
                hidden_states = layer_outputs[0]

                if i == self.config.num_hidden_layers - 1:
                    slang_attention_score = layer_outputs[1]

            outputs = (hidden_states, slang_attention_score)

        return outputs
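
    # Shapes of the returned tuple (scores come from the last layer):
    # 'visual' mode: (hidden_states, state_attention_score, lang_attention_score,
    #     vis_attention_score) -- row 0 of the score matrix (the state token's
    #     attention over all keys), the direction-token rows restricted to
    #     language keys, and the full (batch, heads, 1+directions, keys) tensor.
    # 'language' mode: (hidden_states, attention_scores).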


class BertImgModel(BertPreTrainedModel):
    """
    Expand from BertModel to handle image region features as input.
    """
    def __init__(self, config):
        super(BertImgModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = CaptionBertEncoder(config)
        self.pooler = BertPooler(config)

        self.img_dim = config.img_feature_dim
        logger.info('BertImgModel Image Dimension: {}'.format(self.img_dim))

        # self.apply(self.init_weights)
        self.init_weights()

    def forward(self, mode, input_ids, token_type_ids=None, attention_mask=None,
                position_ids=None, img_feats=None):
        # Broadcast the attention mask to (batch, heads, query_len, key_len).
        if attention_mask.dim() == 2:
            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        elif attention_mask.dim() == 3:
            extended_attention_mask = attention_mask.unsqueeze(1)
        else:
            raise NotImplementedError

        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
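        # The mask is additive: masked positions get a score of -10000, so the
        # softmax assigns them (near-)zero weight; unmasked positions get 0.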

        head_mask = [None] * self.config.num_hidden_layers

        if mode == 'visual':
            # `input_ids` already holds encoded language features here; append
            # the candidate-direction features to form the full sequence.
            language_features = input_ids
            concat_embedding_output = torch.cat((language_features, img_feats), 1)
        elif mode == 'language':
            embedding_output = self.embeddings(input_ids, position_ids=position_ids,
                                               token_type_ids=token_type_ids)
            concat_embedding_output = embedding_output

        ''' pass to the Transformer layers '''
        encoder_outputs = self.encoder(mode, concat_embedding_output,
                                       extended_attention_mask, head_mask=head_mask)

        sequence_output = encoder_outputs[0]
        # We "pool" the model by simply taking the hidden state corresponding to the first token.
        pooled_output = self.pooler(sequence_output)

        # Add hidden states and attentions if they are here.
        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]

        return outputs
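
# Note: in 'visual' mode, `input_ids` carries the language features produced by
# an earlier 'language' pass (in the full agent, the recurrent state vector is
# written into position 0 before each step), not raw token ids; the embedding
# table is only used in 'language' mode.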


class VLNBert(BertPreTrainedModel):
    """
    Modified from BertForMultipleChoice to support Oscar training.
    """
    def __init__(self, config):
        super(VLNBert, self).__init__(config)
        self.config = config
        self.bert = BertImgModel(config)

        self.vis_lang_LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.state_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=True)
        self.state_LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # self.apply(self.init_weights)
        self.init_weights()

    def forward(self, mode, input_ids, token_type_ids=None, attention_mask=None,
                position_ids=None, img_feats=None):
        outputs = self.bert(mode, input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
                            attention_mask=attention_mask, img_feats=img_feats)

        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)

        pooled_output = outputs[1]

        if mode == 'language':
            return sequence_output

        elif mode == 'visual':
            # Attention scores of the agent's state over language and visual keys.
            language_attentions = outputs[2][:, :, 1:-self.config.directions]
            visual_attentions = outputs[2][:, :, -self.config.directions:]

            language_attention_scores = language_attentions.mean(dim=1)  # mean over the attention heads
            visual_attention_scores = visual_attentions.mean(dim=1)

            # Attention-weighted features.
            language_attention_probs = nn.Softmax(dim=-1)(language_attention_scores.clone()).unsqueeze(-1)
            visual_attention_probs = nn.Softmax(dim=-1)(visual_attention_scores.clone()).unsqueeze(-1)

            language_seq = sequence_output[:, 1:-self.config.directions, :]
            visual_seq = sequence_output[:, -self.config.directions:, :]

            # Residual weighting: the final attention weights the raw inputs.
            attended_language = (language_attention_probs * input_ids[:, 1:, :]).sum(1)
            attended_visual = (visual_attention_probs * img_feats).sum(1)

            # Update the agent's state: unify history, language and vision by
            # element-wise product, then project together with the pooled state.
            vis_lang_feat = self.vis_lang_LayerNorm(attended_language * attended_visual)
            state_output = torch.cat((pooled_output, vis_lang_feat), dim=-1)
            state_proj = self.state_proj(state_output)
            state_proj = self.state_LayerNorm(state_proj)

            return state_proj, visual_attention_scores
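

if __name__ == '__main__':
    # Minimal smoke test of the two-pass interface (an illustrative sketch,
    # not part of the original training code). Assumptions: `directions` and
    # `img_feature_dim` are extra attributes the agent normally sets on the
    # config; candidate-direction features are already projected to
    # `hidden_size` (the direct concatenation in BertImgModel requires it);
    # and the model runs with random weights instead of a pretrained checkpoint.
    from pytorch_transformers import BertConfig

    config = BertConfig()          # bert-base defaults
    config.img_feature_dim = 768   # assumed value; only logged by BertImgModel
    config.directions = 4          # assumed number of candidate directions
    model = VLNBert(config)
    model.eval()

    batch, seq_len = 2, 10
    input_ids = torch.randint(0, config.vocab_size, (batch, seq_len))
    lang_mask = torch.ones(batch, seq_len)

    with torch.no_grad():
        # Pass 1: encode the instruction; position 0 is the state ([CLS]) token.
        lang_feats = model('language', input_ids, attention_mask=lang_mask)

        # Pass 2: one navigation step; the mask must cover the language tokens
        # plus the appended direction tokens.
        img_feats = torch.randn(batch, config.directions, config.hidden_size)
        vis_mask = torch.ones(batch, seq_len + config.directions)
        state, action_scores = model('visual', lang_feats,
                                     attention_mask=vis_mask, img_feats=img_feats)

    print(state.shape, action_scores.shape)  # (2, 768) and (2, 4)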