# adversarial_VLNDUET/pretrain_src/model/vilmodel.py

import json
import logging
import math
import os
import sys
from io import open
from typing import Callable, List, Tuple
import numpy as np
import copy
import torch
from torch import nn
import torch.nn.functional as F
from torch import Tensor, device, dtype
from transformers import BertPreTrainedModel
from .ops import create_transformer_encoder
from .ops import extend_neg_masks, gen_seq_masks, pad_tensors_wgrad
logger = logging.getLogger(__name__)
try:
from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
except (ImportError, AttributeError) as e:
# logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")
BertLayerNorm = torch.nn.LayerNorm
def gelu(x):
"""Implementation of the gelu activation function.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
Also see https://arxiv.org/abs/1606.08415
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def swish(x):
return x * torch.sigmoid(x)
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
class BertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.
"""
def __init__(self, config):
super(BertEmbeddings, self).__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, input_ids, token_type_ids=None, position_ids=None):
seq_length = input_ids.size(1)
if position_ids is None:
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = words_embeddings + position_embeddings + token_type_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
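# Illustrative usage sketch for BertEmbeddings (the config values below are assumed
# defaults, not this repo's pretraining config):
#   from transformers import BertConfig
#   cfg = BertConfig()                                     # vocab_size=30522, hidden_size=768
#   emb = BertEmbeddings(cfg)
#   input_ids = torch.randint(0, cfg.vocab_size, (2, 16))  # (batch, seq_len)
#   out = emb(input_ids)                                    # -> (2, 16, cfg.hidden_size)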
class BertSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
self.output_attentions = config.output_attentions
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states, attention_mask, head_mask=None):
"""
hidden_states: (N, L_{hidden}, D)
attention_mask: (N, H, L_{hidden}, L_{hidden})
"""
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
        # Recurrent VLN-BERT uses the raw attention scores
outputs = (context_layer, attention_scores) if self.output_attentions else (context_layer,)
return outputs
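# Mask convention and a minimal sketch (values assumed): attention_mask is additive and
# already broadcastable against the (N, H, L, L) scores, typically 0 for tokens to keep
# and a large negative value for padding (see extend_neg_masks in .ops).
#   cfg = BertConfig(hidden_size=768, num_attention_heads=12)
#   attn = BertSelfAttention(cfg)
#   x = torch.randn(2, 5, cfg.hidden_size)
#   mask = torch.zeros(2, 1, 1, 5)                          # all tokens visible
#   ctx = attn(x, mask)[0]                                  # -> (2, 5, cfg.hidden_size)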
class BertSelfOutput(nn.Module):
def __init__(self, config):
super(BertSelfOutput, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.self = BertSelfAttention(config)
self.output = BertSelfOutput(config)
def forward(self, input_tensor, attention_mask, head_mask=None):
self_outputs = self.self(input_tensor, attention_mask, head_mask)
attention_output = self.output(self_outputs[0], input_tensor)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
class BertIntermediate(nn.Module):
def __init__(self, config):
super(BertIntermediate, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class BertOutput(nn.Module):
def __init__(self, config):
super(BertOutput, self).__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.attention = BertAttention(config)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
def forward(self, hidden_states, attention_mask, head_mask=None):
attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
attention_output = attention_outputs[0]
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
return outputs
class BertEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, attention_mask, head_mask=None):
all_hidden_states = ()
all_attentions = ()
for i, layer_module in enumerate(self.layer):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(
hidden_states, attention_mask,
None if head_mask is None else head_mask[i],
)
hidden_states = layer_outputs[0]
if self.output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
# Add last layer
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
outputs = outputs + (all_attentions,)
return outputs # last-layer hidden state, (all hidden states), (all attentions)
class BertPooler(nn.Module):
def __init__(self, config):
super(BertPooler, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class BertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super(BertPredictionHeadTransform, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
if isinstance(config.hidden_act, str):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class BertLMPredictionHead(nn.Module):
def __init__(self, config):
super(BertLMPredictionHead, self).__init__()
self.transform = BertPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size,
config.vocab_size,
bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states) + self.bias
return hidden_states
class BertOnlyMLMHead(nn.Module):
def __init__(self, config):
super(BertOnlyMLMHead, self).__init__()
self.predictions = BertLMPredictionHead(config)
def forward(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
return prediction_scores
class BertOutAttention(nn.Module):
def __init__(self, config, ctx_dim=None):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
if ctx_dim is None:
ctx_dim = config.hidden_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(ctx_dim, self.all_head_size)
self.value = nn.Linear(ctx_dim, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states, context, attention_mask=None):
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(context)
mixed_value_layer = self.value(context)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
if attention_mask is not None:
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer, attention_scores
class BertXAttention(nn.Module):
def __init__(self, config, ctx_dim=None):
super().__init__()
self.att = BertOutAttention(config, ctx_dim=ctx_dim)
self.output = BertSelfOutput(config)
def forward(self, input_tensor, ctx_tensor, ctx_att_mask=None):
output, attention_scores = self.att(input_tensor, ctx_tensor, ctx_att_mask)
attention_output = self.output(output, input_tensor)
return attention_output, attention_scores
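# Cross-attention sketch (shapes assumed): queries come from input_tensor, keys/values
# from ctx_tensor, so the two streams may differ in length and, via ctx_dim, in width.
#   xatt = BertXAttention(cfg)                              # cfg as in the sketches above
#   vis = torch.randn(2, 7, cfg.hidden_size)                # e.g. visual tokens
#   txt = torch.randn(2, 16, cfg.hidden_size)               # e.g. text tokens
#   out, scores = xatt(vis, txt)                            # out: (2, 7, D), scores: (2, H, 7, 16)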
class GraphLXRTXLayer(nn.Module):
def __init__(self, config):
super().__init__()
# Lang self-att and FFN layer
if config.use_lang2visn_attn:
self.lang_self_att = BertAttention(config)
self.lang_inter = BertIntermediate(config)
self.lang_output = BertOutput(config)
# Visn self-att and FFN layer
self.visn_self_att = BertAttention(config)
self.visn_inter = BertIntermediate(config)
self.visn_output = BertOutput(config)
# The cross attention layer
self.visual_attention = BertXAttention(config)
def forward(
self, lang_feats, lang_attention_mask, visn_feats, visn_attention_mask,
graph_sprels=None
):
visn_att_output = self.visual_attention(
visn_feats, lang_feats, ctx_att_mask=lang_attention_mask
)[0]
if graph_sprels is not None:
visn_attention_mask = visn_attention_mask + graph_sprels
visn_att_output = self.visn_self_att(visn_att_output, visn_attention_mask)[0]
visn_inter_output = self.visn_inter(visn_att_output)
visn_output = self.visn_output(visn_inter_output, visn_att_output)
return visn_output
def forward_lang2visn(
self, lang_feats, lang_attention_mask, visn_feats, visn_attention_mask,
):
lang_att_output = self.visual_attention(
lang_feats, visn_feats, ctx_att_mask=visn_attention_mask
)[0]
lang_att_output = self.lang_self_att(
lang_att_output, lang_attention_mask
)[0]
lang_inter_output = self.lang_inter(lang_att_output)
lang_output = self.lang_output(lang_inter_output, lang_att_output)
return lang_output
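# Reading note: in forward(), the visual tokens first cross-attend to the language
# tokens, then run self-attention in which graph_sprels (if given) is added to the
# additive visual mask as a spatial-relation bias of shape (N, 1, L_v, L_v), and
# finally pass through the FFN. forward_lang2visn() is the mirror direction (language
# attends to vision); its lang_* sublayers only exist when config.use_lang2visn_attn
# is set.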
class LanguageEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.num_l_layers = config.num_l_layers
self.update_lang_bert = config.update_lang_bert
self.layer = nn.ModuleList(
[BertLayer(config) for _ in range(self.num_l_layers)]
)
if not self.update_lang_bert:
for name, param in self.layer.named_parameters():
param.requires_grad = False
def forward(self, txt_embeds, txt_masks):
extended_txt_masks = extend_neg_masks(txt_masks)
for layer_module in self.layer:
temp_output = layer_module(txt_embeds, extended_txt_masks)
txt_embeds = temp_output[0]
if not self.update_lang_bert:
txt_embeds = txt_embeds.detach()
return txt_embeds
class CrossmodalEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.num_x_layers = config.num_x_layers
self.x_layers = nn.ModuleList(
[GraphLXRTXLayer(config) for _ in range(self.num_x_layers)]
)
def forward(self, txt_embeds, txt_masks, img_embeds, img_masks, graph_sprels=None):
extended_txt_masks = extend_neg_masks(txt_masks)
extended_img_masks = extend_neg_masks(img_masks) # (N, 1(H), 1(L_q), L_v)
for layer_module in self.x_layers:
img_embeds = layer_module(
txt_embeds, extended_txt_masks,
img_embeds, extended_img_masks,
graph_sprels=graph_sprels
)
return img_embeds
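# Call sketch (illustrative; cross_encoder is an already-built CrossmodalEncoder whose
# config carries num_x_layers and use_lang2visn_attn): txt_masks and img_masks are
# (N, L) sequence masks from gen_seq_masks, and only the visual stream is updated.
#   fused_img = cross_encoder(txt_embeds, txt_masks, img_embeds, img_masks)
#   # fused_img: (N, L_img, hidden_size)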
class ImageEmbeddings(nn.Module):
def __init__(self, config):
super().__init__()
self.img_linear = nn.Linear(config.image_feat_size, config.hidden_size)
self.img_layer_norm = BertLayerNorm(config.hidden_size, eps=1e-12)
self.loc_linear = nn.Linear(config.angle_feat_size + 3, config.hidden_size)
self.loc_layer_norm = BertLayerNorm(config.hidden_size, eps=1e-12)
if config.obj_feat_size > 0 and config.obj_feat_size != config.image_feat_size:
self.obj_linear = nn.Linear(config.obj_feat_size, config.hidden_size)
self.obj_layer_norm = BertLayerNorm(config.hidden_size, eps=1e-12)
else:
self.obj_linear = self.obj_layer_norm = None
# 0: non-navigable, 1: navigable, 2: object
self.nav_type_embedding = nn.Embedding(3, config.hidden_size)
# tf naming convention for layer norm
self.layer_norm = BertLayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
if config.num_pano_layers > 0:
self.pano_encoder = create_transformer_encoder(
config, config.num_pano_layers, norm=True
)
else:
self.pano_encoder = None
def forward(
self, traj_view_img_fts, traj_obj_img_fts, traj_loc_fts, traj_nav_types,
traj_step_lens, traj_vp_view_lens, traj_vp_obj_lens, type_embed_layer
):
device = traj_view_img_fts.device
has_obj = traj_obj_img_fts is not None
traj_view_img_embeds = self.img_layer_norm(self.img_linear(traj_view_img_fts))
if has_obj:
if self.obj_linear is None:
traj_obj_img_embeds = self.img_layer_norm(self.img_linear(traj_obj_img_fts))
else:
traj_obj_img_embeds = self.obj_layer_norm(self.obj_linear(traj_obj_img_fts))
traj_img_embeds = []
for view_embed, obj_embed, view_len, obj_len in zip(
traj_view_img_embeds, traj_obj_img_embeds, traj_vp_view_lens, traj_vp_obj_lens
):
if obj_len > 0:
traj_img_embeds.append(torch.cat([view_embed[:view_len], obj_embed[:obj_len]], 0))
else:
traj_img_embeds.append(view_embed[:view_len])
traj_img_embeds = pad_tensors_wgrad(traj_img_embeds)
traj_vp_lens = traj_vp_view_lens + traj_vp_obj_lens
else:
traj_img_embeds = traj_view_img_embeds
traj_vp_lens = traj_vp_view_lens
traj_embeds = traj_img_embeds + \
self.loc_layer_norm(self.loc_linear(traj_loc_fts)) + \
self.nav_type_embedding(traj_nav_types) + \
type_embed_layer(torch.ones(1, 1).long().to(device))
traj_embeds = self.layer_norm(traj_embeds)
traj_embeds = self.dropout(traj_embeds)
traj_masks = gen_seq_masks(traj_vp_lens)
if self.pano_encoder is not None:
traj_embeds = self.pano_encoder(
traj_embeds, src_key_padding_mask=traj_masks.logical_not()
)
split_traj_embeds = torch.split(traj_embeds, traj_step_lens, 0)
split_traj_vp_lens = torch.split(traj_vp_lens, traj_step_lens, 0)
return split_traj_embeds, split_traj_vp_lens
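# Input sketch for ImageEmbeddings.forward (dimensions are illustrative assumptions):
#   traj_view_img_fts: (total_steps, n_views, image_feat_size), padded per step
#   traj_obj_img_fts:  (total_steps, n_objs, obj_feat_size) or None
#   traj_loc_fts:      (total_steps, n_tokens, angle_feat_size + 3)
#   traj_nav_types:    (total_steps, n_tokens), 0 non-navigable / 1 navigable / 2 object
#   traj_step_lens:    per-sample step counts consumed by torch.split
# The outputs are per-sample tuples of (n_steps, n_tokens, hidden_size) embeddings and
# their valid token lengths.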
class LocalVPEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.vp_pos_embeddings = nn.Sequential(
nn.Linear(config.angle_feat_size*2 + 6, config.hidden_size),
BertLayerNorm(config.hidden_size, eps=1e-12)
)
self.encoder = CrossmodalEncoder(config)
def vp_input_embedding(self, split_traj_embeds, split_traj_vp_lens, vp_pos_fts):
vp_img_embeds = pad_tensors_wgrad([x[-1] for x in split_traj_embeds])
vp_lens = torch.stack([x[-1]+1 for x in split_traj_vp_lens], 0)
vp_masks = gen_seq_masks(vp_lens)
max_vp_len = max(vp_lens)
batch_size, _, hidden_size = vp_img_embeds.size()
device = vp_img_embeds.device
        # add a [stop] token at the beginning
vp_img_embeds = torch.cat(
[torch.zeros(batch_size, 1, hidden_size).to(device), vp_img_embeds], 1
)[:, :max_vp_len]
vp_embeds = vp_img_embeds + self.vp_pos_embeddings(vp_pos_fts)
return vp_embeds, vp_masks
def forward(
self, txt_embeds, txt_masks, split_traj_embeds, split_traj_vp_lens, vp_pos_fts
):
vp_embeds, vp_masks = self.vp_input_embedding(
split_traj_embeds, split_traj_vp_lens, vp_pos_fts
)
vp_embeds = self.encoder(txt_embeds, txt_masks, vp_embeds, vp_masks)
return vp_embeds
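# Reading note: the local branch keeps only the last viewpoint of each trajectory
# (x[-1] in vp_input_embedding), prepends a zero slot for the [stop] action, adds the
# vp_pos_fts position embedding, and cross-attends to the instruction, so vp_embeds is
# (N, max_vp_len, hidden_size) with index 0 reserved for stopping.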
class GlobalMapEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.gmap_pos_embeddings = nn.Sequential(
nn.Linear(config.angle_feat_size + 3, config.hidden_size),
BertLayerNorm(config.hidden_size, eps=1e-12)
)
self.gmap_step_embeddings = nn.Embedding(config.max_action_steps, config.hidden_size)
self.encoder = CrossmodalEncoder(config)
if config.graph_sprels:
self.sprel_linear = nn.Linear(1, 1)
else:
self.sprel_linear = None
def _aggregate_gmap_features(
self, split_traj_embeds, split_traj_vp_lens, traj_vpids, traj_cand_vpids, gmap_vpids
):
batch_size = len(split_traj_embeds)
device = split_traj_embeds[0].device
batch_gmap_img_fts = []
for i in range(batch_size):
visited_vp_fts, unvisited_vp_fts = {}, {}
vp_masks = gen_seq_masks(split_traj_vp_lens[i])
max_vp_len = max(split_traj_vp_lens[i])
i_traj_embeds = split_traj_embeds[i][:, :max_vp_len] * vp_masks.unsqueeze(2)
for t in range(len(split_traj_embeds[i])):
visited_vp_fts[traj_vpids[i][t]] = torch.sum(i_traj_embeds[t], 0) / split_traj_vp_lens[i][t]
for j, vp in enumerate(traj_cand_vpids[i][t]):
if vp not in visited_vp_fts:
unvisited_vp_fts.setdefault(vp, [])
unvisited_vp_fts[vp].append(i_traj_embeds[t][j])
gmap_img_fts = []
for vp in gmap_vpids[i][1:]:
if vp in visited_vp_fts:
gmap_img_fts.append(visited_vp_fts[vp])
else:
gmap_img_fts.append(torch.mean(torch.stack(unvisited_vp_fts[vp], 0), 0))
gmap_img_fts = torch.stack(gmap_img_fts, 0)
batch_gmap_img_fts.append(gmap_img_fts)
batch_gmap_img_fts = pad_tensors_wgrad(batch_gmap_img_fts)
        # add a [stop] token at the beginning
batch_gmap_img_fts = torch.cat(
[torch.zeros(batch_size, 1, batch_gmap_img_fts.size(2)).to(device), batch_gmap_img_fts],
dim=1
)
return batch_gmap_img_fts
def gmap_input_embedding(
self, split_traj_embeds, split_traj_vp_lens, traj_vpids, traj_cand_vpids, gmap_vpids,
gmap_step_ids, gmap_pos_fts, gmap_lens
):
gmap_img_fts = self._aggregate_gmap_features(
split_traj_embeds, split_traj_vp_lens, traj_vpids, traj_cand_vpids, gmap_vpids
)
gmap_embeds = gmap_img_fts + \
self.gmap_step_embeddings(gmap_step_ids) + \
self.gmap_pos_embeddings(gmap_pos_fts)
gmap_masks = gen_seq_masks(gmap_lens)
return gmap_embeds, gmap_masks
def forward(
self, txt_embeds, txt_masks,
split_traj_embeds, split_traj_vp_lens, traj_vpids, traj_cand_vpids, gmap_vpids,
gmap_step_ids, gmap_pos_fts, gmap_lens, graph_sprels=None
):
gmap_embeds, gmap_masks = self.gmap_input_embedding(
split_traj_embeds, split_traj_vp_lens, traj_vpids, traj_cand_vpids, gmap_vpids,
gmap_step_ids, gmap_pos_fts, gmap_lens
)
if self.sprel_linear is not None:
graph_sprels = self.sprel_linear(graph_sprels.unsqueeze(3)).squeeze(3).unsqueeze(1)
else:
graph_sprels = None
gmap_embeds = self.encoder(
txt_embeds, txt_masks, gmap_embeds, gmap_masks,
graph_sprels=graph_sprels
)
return gmap_embeds
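# Reading note: _aggregate_gmap_features builds one feature per map node: visited
# nodes average their observed tokens, unvisited candidate nodes average the partial
# views that saw them, and a zero vector is prepended for the [stop] node. When
# config.graph_sprels is set, gmap_pair_dists enters the cross-modal encoder as an
# additive attention bias through sprel_linear.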
class GlocalTextPathCMT(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.embeddings = BertEmbeddings(config)
self.lang_encoder = LanguageEncoder(config)
self.img_embeddings = ImageEmbeddings(config)
self.local_encoder = LocalVPEncoder(config)
self.global_encoder = GlobalMapEncoder(config)
self.init_weights()
def forward(
self, txt_ids, txt_lens, traj_view_img_fts, traj_obj_img_fts, traj_loc_fts, traj_nav_types,
traj_step_lens, traj_vp_view_lens, traj_vp_obj_lens, traj_vpids, traj_cand_vpids,
gmap_lens, gmap_step_ids, gmap_pos_fts, gmap_pair_dists, gmap_vpids, vp_pos_fts,
return_gmap_embeds=True
):
# text embedding
txt_token_type_ids = torch.zeros_like(txt_ids)
txt_embeds = self.embeddings(txt_ids, token_type_ids=txt_token_type_ids)
txt_masks = gen_seq_masks(txt_lens)
txt_embeds = self.lang_encoder(txt_embeds, txt_masks)
# trajectory embedding
split_traj_embeds, split_traj_vp_lens = self.img_embeddings(
traj_view_img_fts, traj_obj_img_fts, traj_loc_fts, traj_nav_types,
traj_step_lens, traj_vp_view_lens, traj_vp_obj_lens,
self.embeddings.token_type_embeddings
)
# gmap embeds
if return_gmap_embeds:
gmap_embeds = self.global_encoder(
txt_embeds, txt_masks,
split_traj_embeds, split_traj_vp_lens, traj_vpids, traj_cand_vpids, gmap_vpids,
gmap_step_ids, gmap_pos_fts, gmap_lens, graph_sprels=gmap_pair_dists,
)
else:
gmap_embeds = None
# vp embeds
vp_embeds = self.local_encoder(
txt_embeds, txt_masks,
split_traj_embeds, split_traj_vp_lens, vp_pos_fts
)
return gmap_embeds, vp_embeds
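    # Illustrative call sketch (the argument tensors are assumed to be batched as in
    # the encoders above; real batches come from the pretraining dataloaders):
    #   gmap_embeds, vp_embeds = model(
    #       txt_ids, txt_lens, traj_view_img_fts, traj_obj_img_fts, traj_loc_fts,
    #       traj_nav_types, traj_step_lens, traj_vp_view_lens, traj_vp_obj_lens,
    #       traj_vpids, traj_cand_vpids, gmap_lens, gmap_step_ids, gmap_pos_fts,
    #       gmap_pair_dists, gmap_vpids, vp_pos_fts)
    #   # gmap_embeds: (N, max_gmap_len, D) global-map stream (None if
    #   # return_gmap_embeds=False); vp_embeds: (N, max_vp_len, D) local stream.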
def forward_mlm(
self, txt_ids, txt_lens, traj_view_img_fts, traj_obj_img_fts, traj_loc_fts, traj_nav_types,
traj_step_lens, traj_vp_view_lens, traj_vp_obj_lens, traj_vpids, traj_cand_vpids,
gmap_lens, gmap_step_ids, gmap_pos_fts, gmap_pair_dists, gmap_vpids, vp_pos_fts,
):
# text embedding
txt_token_type_ids = torch.zeros_like(txt_ids)
txt_embeds = self.embeddings(txt_ids, token_type_ids=txt_token_type_ids)
txt_masks = gen_seq_masks(txt_lens)
txt_embeds = self.lang_encoder(txt_embeds, txt_masks)
extended_txt_masks = extend_neg_masks(txt_masks)
# trajectory embedding
split_traj_embeds, split_traj_vp_lens = self.img_embeddings(
traj_view_img_fts, traj_obj_img_fts, traj_loc_fts, traj_nav_types,
traj_step_lens, traj_vp_view_lens, traj_vp_obj_lens,
self.embeddings.token_type_embeddings
)
# gmap embeds
gmap_input_embeds, gmap_masks = self.global_encoder.gmap_input_embedding(
split_traj_embeds, split_traj_vp_lens, traj_vpids, traj_cand_vpids, gmap_vpids,
gmap_step_ids, gmap_pos_fts, gmap_lens
)
gmap_txt_embeds = txt_embeds
extended_gmap_masks = extend_neg_masks(gmap_masks)
for layer_module in self.global_encoder.encoder.x_layers:
gmap_txt_embeds = layer_module.forward_lang2visn(
gmap_txt_embeds, extended_txt_masks,
gmap_input_embeds, extended_gmap_masks,
)
# vp embeds
vp_input_embeds, vp_masks = self.local_encoder.vp_input_embedding(
split_traj_embeds, split_traj_vp_lens, vp_pos_fts
)
vp_txt_embeds = txt_embeds
extended_vp_masks = extend_neg_masks(vp_masks)
for layer_module in self.local_encoder.encoder.x_layers:
vp_txt_embeds = layer_module.forward_lang2visn(
vp_txt_embeds, extended_txt_masks,
vp_input_embeds, extended_vp_masks,
)
txt_embeds = gmap_txt_embeds + vp_txt_embeds
return txt_embeds
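    # Reading note for forward_mlm: instead of fusing into the visual streams, the text
    # features attend to the global-map and local-viewpoint inputs via forward_lang2visn
    # in every cross-modal layer, and the two language streams are summed. The returned
    # txt_embeds (N, L_txt, hidden_size) are what an MLM head such as BertOnlyMLMHead
    # above would score.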