# adversarial_VLNDUET/pretrain_src/model/pretrain_cmt.py
from collections import defaultdict
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertPreTrainedModel
from .vilmodel import BertLayerNorm, BertOnlyMLMHead, GlocalTextPathCMT
from .ops import pad_tensors_wgrad, gen_seq_masks

class RegionClassification(nn.Module):
    """Region classification head for MRC(-kl): predicts soft labels for masked regions."""
    def __init__(self, hidden_size, label_dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(hidden_size, hidden_size),
                                 nn.ReLU(),
                                 BertLayerNorm(hidden_size, eps=1e-12),
                                 nn.Linear(hidden_size, label_dim))

    def forward(self, input_):
        output = self.net(input_)
        return output
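
# A minimal shape sketch for RegionClassification (sizes here are hypothetical,
# not taken from the config):
#   head = RegionClassification(hidden_size=768, label_dim=1000)
#   x = torch.zeros(8, 768)      # 8 masked region embeddings
#   logits = head(x)             # -> [8, 1000] soft-label logits for MRC(-kl)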

class ClsPrediction(nn.Module):
    """Single-logit scoring head (used for the SAP and OG objectives)."""
    def __init__(self, hidden_size, input_size=None):
        super().__init__()
        if input_size is None:
            input_size = hidden_size
        self.net = nn.Sequential(nn.Linear(input_size, hidden_size),
                                 nn.ReLU(),
                                 BertLayerNorm(hidden_size, eps=1e-12),
                                 nn.Linear(hidden_size, 1))

    def forward(self, x):
        return self.net(x)

class GlocalTextPathCMTPreTraining(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.config = config
        self.bert = GlocalTextPathCMT(config)

        if 'mlm' in config.pretrain_tasks:
            self.mlm_head = BertOnlyMLMHead(self.config)
        if 'mrc' in config.pretrain_tasks:
            self.image_classifier = RegionClassification(self.config.hidden_size, self.config.image_prob_size)
            if self.config.obj_prob_size > 0 and self.config.obj_prob_size != self.config.image_prob_size:
                self.obj_classifier = RegionClassification(self.config.hidden_size, self.config.obj_prob_size)
            else:
                self.obj_classifier = None
        if 'sap' in config.pretrain_tasks:
            self.global_sap_head = ClsPrediction(self.config.hidden_size)
            self.local_sap_head = ClsPrediction(self.config.hidden_size)
            if config.glocal_fuse:
                self.sap_fuse_linear = ClsPrediction(self.config.hidden_size, input_size=self.config.hidden_size*2)
            else:
                self.sap_fuse_linear = None
        if 'og' in config.pretrain_tasks:
            self.og_head = ClsPrediction(self.config.hidden_size)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        if 'mlm' in self.config.pretrain_tasks:
            self._tie_or_clone_weights(self.mlm_head.predictions.decoder,
                                       self.bert.embeddings.word_embeddings)
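
    # After tie_weights(), the MLM decoder is expected to share its weight matrix
    # with the input word embeddings (standard BERT weight tying), i.e. roughly:
    #   self.mlm_head.predictions.decoder.weight is self.bert.embeddings.word_embeddings.weight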

    def forward(self, batch, task, compute_loss=True):
        batch = defaultdict(lambda: None, batch)
        if task.startswith('mlm'):
            return self.forward_mlm(
                batch['txt_ids'], batch['txt_lens'], batch['traj_view_img_fts'],
                batch['traj_obj_img_fts'], batch['traj_loc_fts'], batch['traj_nav_types'],
                batch['traj_step_lens'], batch['traj_vp_view_lens'], batch['traj_vp_obj_lens'],
                batch['traj_vpids'], batch['traj_cand_vpids'],
                batch['gmap_lens'], batch['gmap_step_ids'], batch['gmap_pos_fts'],
                batch['gmap_pair_dists'], batch['gmap_vpids'], batch['vp_pos_fts'],
                batch['txt_labels'], compute_loss
            )
        elif task.startswith('mrc'):
            return self.forward_mrc(
                batch['txt_ids'], batch['txt_lens'], batch['traj_view_img_fts'],
                batch['traj_obj_img_fts'], batch['traj_loc_fts'], batch['traj_nav_types'],
                batch['traj_step_lens'], batch['traj_vp_view_lens'], batch['traj_vp_obj_lens'],
                batch['traj_vpids'], batch['traj_cand_vpids'],
                batch['gmap_lens'], batch['gmap_step_ids'], batch['gmap_pos_fts'],
                batch['gmap_pair_dists'], batch['gmap_vpids'], batch['vp_pos_fts'],
                batch['vp_view_mrc_masks'], batch['vp_view_probs'],
                batch['vp_obj_mrc_masks'], batch['vp_obj_probs'], compute_loss
            )
        elif task.startswith('sap'):
            return self.forward_sap(
                batch['txt_ids'], batch['txt_lens'], batch['traj_view_img_fts'],
                batch['traj_obj_img_fts'], batch['traj_loc_fts'], batch['traj_nav_types'],
                batch['traj_step_lens'], batch['traj_vp_view_lens'], batch['traj_vp_obj_lens'],
                batch['traj_vpids'], batch['traj_cand_vpids'],
                batch['gmap_lens'], batch['gmap_step_ids'], batch['gmap_pos_fts'],
                batch['gmap_pair_dists'], batch['gmap_vpids'], batch['vp_pos_fts'],
                batch['gmap_visited_masks'],
                batch['global_act_labels'], batch['local_act_labels'], compute_loss
            )
        elif task.startswith('og'):
            return self.forward_og(
                batch['txt_ids'], batch['txt_lens'], batch['traj_view_img_fts'],
                batch['traj_obj_img_fts'], batch['traj_loc_fts'], batch['traj_nav_types'],
                batch['traj_step_lens'], batch['traj_vp_view_lens'], batch['traj_vp_obj_lens'],
                batch['traj_vpids'], batch['traj_cand_vpids'],
                batch['gmap_lens'], batch['gmap_step_ids'], batch['gmap_pos_fts'],
                batch['gmap_pair_dists'], batch['gmap_vpids'], batch['vp_pos_fts'],
                batch['obj_labels'], compute_loss
            )
        elif task.startswith('valid_sap_og'):
            return self.forward_sap_og(
                batch['txt_ids'], batch['txt_lens'], batch['traj_view_img_fts'],
                batch['traj_obj_img_fts'], batch['traj_loc_fts'], batch['traj_nav_types'],
                batch['traj_step_lens'], batch['traj_vp_view_lens'], batch['traj_vp_obj_lens'],
                batch['traj_vpids'], batch['traj_cand_vpids'],
                batch['gmap_lens'], batch['gmap_step_ids'], batch['gmap_pos_fts'],
                batch['gmap_pair_dists'], batch['gmap_vpids'], batch['vp_pos_fts'],
                batch['gmap_visited_masks'], batch['global_act_labels'], batch['local_act_labels'],
                batch['obj_labels']
            )
        else:
            raise ValueError('invalid task')
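
    # Hedged usage sketch (batch keys follow the forward_* signatures above; the
    # exact dataloader format and checkpoint path are assumptions):
    #   model = GlocalTextPathCMTPreTraining.from_pretrained(ckpt_dir, config=config)
    #   mlm_losses = model(batch, task='mlm', compute_loss=True)   # per-masked-token CE losses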

    def forward_mlm(
        self, txt_ids, txt_lens, traj_view_img_fts, traj_obj_img_fts, traj_loc_fts, traj_nav_types,
        traj_step_lens, traj_vp_view_lens, traj_vp_obj_lens, traj_vpids, traj_cand_vpids,
        gmap_lens, gmap_step_ids, gmap_pos_fts, gmap_pair_dists, gmap_vpids, vp_pos_fts,
        txt_labels, compute_loss
    ):
        txt_embeds = self.bert.forward_mlm(
            txt_ids, txt_lens, traj_view_img_fts, traj_obj_img_fts, traj_loc_fts, traj_nav_types,
            traj_step_lens, traj_vp_view_lens, traj_vp_obj_lens, traj_vpids, traj_cand_vpids,
            gmap_lens, gmap_step_ids, gmap_pos_fts, gmap_pair_dists, gmap_vpids, vp_pos_fts,
        )

        # only compute masked tokens for better efficiency
        masked_output = self._compute_masked_hidden(txt_embeds, txt_labels != -1)
        prediction_scores = self.mlm_head(masked_output)

        if compute_loss:
            mask_loss = F.cross_entropy(
                prediction_scores, txt_labels[txt_labels != -1], reduction='none'
            )
            return mask_loss
        else:
            return prediction_scores

    def _compute_masked_hidden(self, hidden, mask):
        '''get only the masked region (don't compute unnecessary hiddens)'''
        mask = mask.unsqueeze(-1).expand_as(hidden)
        hidden_masked = hidden[mask].contiguous().view(-1, hidden.size(-1))
        return hidden_masked
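
    # Shape sketch for _compute_masked_hidden (illustrative): hidden
    # [batch, seq_len, hidden] gathered with a bool mask [batch, seq_len]
    # returns [num_masked, hidden], so the MLM/MRC heads run only on masked positions.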

    def forward_mrc(
        self, txt_ids, txt_lens, traj_view_img_fts, traj_obj_img_fts, traj_loc_fts, traj_nav_types,
        traj_step_lens, traj_vp_view_lens, traj_vp_obj_lens, traj_vpids, traj_cand_vpids,
        gmap_lens, gmap_step_ids, gmap_pos_fts, gmap_pair_dists, gmap_vpids, vp_pos_fts,
        vp_view_mrc_masks, vp_view_probs, vp_obj_mrc_masks, vp_obj_probs, compute_loss=True
    ):
        _, vp_embeds = self.bert(
            txt_ids, txt_lens, traj_view_img_fts, traj_obj_img_fts, traj_loc_fts, traj_nav_types,
            traj_step_lens, traj_vp_view_lens, traj_vp_obj_lens, traj_vpids, traj_cand_vpids,
            gmap_lens, gmap_step_ids, gmap_pos_fts, gmap_pair_dists, gmap_vpids, vp_pos_fts,
            return_gmap_embeds=False
        )

        vp_view_lens = [x[-1] for x in torch.split(traj_vp_view_lens, traj_step_lens)]
        vp_view_embeds = pad_tensors_wgrad(
            [x[1:view_len+1] for x, view_len in zip(vp_embeds, vp_view_lens)]
        )   # [stop] at 0
        # vp_view_mrc_masks = vp_view_mrc_masks[:, :vp_view_embeds.size(1)]

        # only compute masked regions for better efficiency
        view_masked_output = self._compute_masked_hidden(vp_view_embeds, vp_view_mrc_masks)
        view_prediction_soft_labels = self.image_classifier(view_masked_output)
        view_mrc_targets = self._compute_masked_hidden(vp_view_probs, vp_view_mrc_masks)

        if traj_obj_img_fts is not None:
            vp_obj_lens = [x[-1] for x in torch.split(traj_vp_obj_lens, traj_step_lens)]
            vp_obj_embeds = pad_tensors_wgrad(
                [x[view_len+1:view_len+obj_len+1] for x, view_len, obj_len in zip(vp_embeds, vp_view_lens, vp_obj_lens)]
            )
            # vp_obj_mrc_masks = vp_obj_mrc_masks[:, :vp_obj_embeds.size(1)]
            obj_masked_output = self._compute_masked_hidden(vp_obj_embeds, vp_obj_mrc_masks)
            if self.obj_classifier is None:
                obj_prediction_soft_labels = self.image_classifier(obj_masked_output)
            else:
                obj_prediction_soft_labels = self.obj_classifier(obj_masked_output)
            obj_mrc_targets = self._compute_masked_hidden(vp_obj_probs, vp_obj_mrc_masks)
        else:
            obj_prediction_soft_labels, obj_mrc_targets = None, None

        if compute_loss:
            view_prediction_soft_labels = F.log_softmax(view_prediction_soft_labels, dim=-1)
            view_mrc_loss = F.kl_div(view_prediction_soft_labels, view_mrc_targets, reduction='none').sum(dim=1)
            if obj_prediction_soft_labels is None:
                mrc_loss = view_mrc_loss
            else:
                obj_prediction_soft_labels = F.log_softmax(obj_prediction_soft_labels, dim=-1)
                obj_mrc_loss = F.kl_div(obj_prediction_soft_labels, obj_mrc_targets, reduction='none').sum(dim=1)
                mrc_loss = torch.cat([view_mrc_loss, obj_mrc_loss], 0)
            return mrc_loss
        else:
            return view_prediction_soft_labels, view_mrc_targets, obj_prediction_soft_labels, obj_mrc_targets
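
    # Note on the MRC loss: with log-probabilities as input and reduction='none',
    # F.kl_div(log_p, q).sum(dim=1) gives sum_c q_c * (log q_c - log p_c) per masked
    # region, i.e. KL(q || p) against the teacher's soft labels.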

    def forward_sap(
        self, txt_ids, txt_lens, traj_view_img_fts, traj_obj_img_fts, traj_loc_fts, traj_nav_types,
        traj_step_lens, traj_vp_view_lens, traj_vp_obj_lens, traj_vpids, traj_cand_vpids,
        gmap_lens, gmap_step_ids, gmap_pos_fts, gmap_pair_dists, gmap_vpids, vp_pos_fts,
        gmap_visited_masks, global_act_labels, local_act_labels, compute_loss
    ):
        batch_size = txt_ids.size(0)

        gmap_embeds, vp_embeds = self.bert(
            txt_ids, txt_lens, traj_view_img_fts, traj_obj_img_fts, traj_loc_fts, traj_nav_types,
            traj_step_lens, traj_vp_view_lens, traj_vp_obj_lens, traj_vpids, traj_cand_vpids,
            gmap_lens, gmap_step_ids, gmap_pos_fts, gmap_pair_dists, gmap_vpids, vp_pos_fts,
        )

        if self.sap_fuse_linear is None:
            fuse_weights = 0.5
        else:
            fuse_weights = torch.sigmoid(self.sap_fuse_linear(
                torch.cat([gmap_embeds[:, 0], vp_embeds[:, 0]], 1)
            ))

        global_logits = self.global_sap_head(gmap_embeds).squeeze(2) * fuse_weights
        global_logits.masked_fill_(gmap_visited_masks, -float('inf'))
        global_logits.masked_fill_(gen_seq_masks(gmap_lens).logical_not(), -float('inf'))

        local_logits = self.local_sap_head(vp_embeds).squeeze(2) * (1 - fuse_weights)
        vp_nav_masks = pad_tensors_wgrad(
            [x[-1] != 1 for x in torch.split(traj_nav_types, traj_step_lens)]
        )[:, :local_logits.size(1)-1]
        vp_nav_masks = torch.cat(
            [torch.zeros(len(vp_nav_masks), 1).bool().to(vp_nav_masks.device), vp_nav_masks], 1
        )   # add [stop]
        local_logits.masked_fill_(vp_nav_masks, -float('inf'))

        # fusion: merge local candidate scores into the global (map-level) scores
        fused_logits = torch.clone(global_logits)
        fused_logits[:, 0] += local_logits[:, 0]   # stop
        for i in range(batch_size):
            visited_nodes = set([vp for vp, mask in zip(gmap_vpids[i], gmap_visited_masks[i]) if mask])
            tmp = {}
            bw_logits = 0
            # local candidates that lead back to visited nodes contribute a shared
            # "backtrack" score; unvisited candidates keep their own local score
            for j, cand_vpid in enumerate(traj_cand_vpids[i][-1]):
                if cand_vpid in visited_nodes:
                    bw_logits += local_logits[i, j+1]
                else:
                    tmp[cand_vpid] = local_logits[i, j+1]
            for j, vp in enumerate(gmap_vpids[i]):
                if j > 0 and vp not in visited_nodes:
                    if vp in tmp:
                        fused_logits[i, j] += tmp[vp]
                    else:
                        fused_logits[i, j] += bw_logits

        if compute_loss:
            global_losses = F.cross_entropy(global_logits, global_act_labels, reduction='none')
            local_losses = F.cross_entropy(local_logits, local_act_labels, reduction='none')
            fused_losses = F.cross_entropy(fused_logits, global_act_labels, reduction='none')
            losses = global_losses + local_losses + fused_losses
            return losses
        else:
            return global_logits, local_logits, fused_logits, global_act_labels, local_act_labels
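
    # Hedged inference sketch: index 0 of fused_logits is the stop action, the
    # remaining indices score gmap nodes, e.g.
    #   action_probs = torch.softmax(fused_logits, 1)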

    def forward_og(
        self, txt_ids, txt_lens, traj_view_img_fts, traj_obj_img_fts, traj_loc_fts, traj_nav_types,
        traj_step_lens, traj_vp_view_lens, traj_vp_obj_lens, traj_vpids, traj_cand_vpids,
        gmap_lens, gmap_step_ids, gmap_pos_fts, gmap_pair_dists, gmap_vpids, vp_pos_fts,
        obj_labels, compute_loss
    ):
        gmap_embeds, vp_embeds = self.bert.forward(
            txt_ids, txt_lens, traj_view_img_fts, traj_obj_img_fts, traj_loc_fts, traj_nav_types,
            traj_step_lens, traj_vp_view_lens, traj_vp_obj_lens, traj_vpids, traj_cand_vpids,
            gmap_lens, gmap_step_ids, gmap_pos_fts, gmap_pair_dists, gmap_vpids, vp_pos_fts,
            return_gmap_embeds=False
        )

        vp_view_lens = [x[-1] for x in torch.split(traj_vp_view_lens, traj_step_lens, 0)]
        vp_obj_lens = [x[-1] for x in torch.split(traj_vp_obj_lens, traj_step_lens, 0)]
        obj_embeds = pad_tensors_wgrad([
            x[1+view_len: 1+view_len+obj_len] for x, view_len, obj_len in zip(vp_embeds, vp_view_lens, vp_obj_lens)
        ])
        obj_masks = gen_seq_masks(torch.stack(vp_obj_lens, 0))

        obj_logits = self.og_head(obj_embeds).squeeze(2)
        obj_logits.masked_fill_(obj_masks.logical_not(), -float('inf'))

        if compute_loss:
            losses = F.cross_entropy(obj_logits, obj_labels, reduction='none')
            return losses
        else:
            return obj_logits
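
    # Hedged sketch: obj_logits is [batch, max_num_objs] (padded objects masked to
    # -inf); the grounded object could be read out as obj_logits.argmax(dim=1).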

    def forward_sap_og(
        self, txt_ids, txt_lens, traj_view_img_fts, traj_obj_img_fts, traj_loc_fts, traj_nav_types,
        traj_step_lens, traj_vp_view_lens, traj_vp_obj_lens, traj_vpids, traj_cand_vpids,
        gmap_lens, gmap_step_ids, gmap_pos_fts, gmap_pair_dists, gmap_vpids, vp_pos_fts,
        gmap_visited_masks, global_act_labels, local_act_labels, obj_labels
    ):
        batch_size = txt_ids.size(0)

        gmap_embeds, vp_embeds = self.bert(
            txt_ids, txt_lens, traj_view_img_fts, traj_obj_img_fts, traj_loc_fts, traj_nav_types,
            traj_step_lens, traj_vp_view_lens, traj_vp_obj_lens, traj_vpids, traj_cand_vpids,
            gmap_lens, gmap_step_ids, gmap_pos_fts, gmap_pair_dists, gmap_vpids, vp_pos_fts,
        )

        if self.sap_fuse_linear is None:
            fuse_weights = 0.5
        else:
            fuse_weights = torch.sigmoid(self.sap_fuse_linear(
                torch.cat([gmap_embeds[:, 0], vp_embeds[:, 0]], 1)
            ))

        global_logits = self.global_sap_head(gmap_embeds).squeeze(2) * fuse_weights
        global_logits.masked_fill_(gmap_visited_masks, -float('inf'))
        global_logits.masked_fill_(gen_seq_masks(gmap_lens).logical_not(), -float('inf'))

        local_logits = self.local_sap_head(vp_embeds).squeeze(2) * (1 - fuse_weights)
        vp_nav_masks = pad_tensors_wgrad(
            [x[-1] != 1 for x in torch.split(traj_nav_types, traj_step_lens)]
        )[:, :local_logits.size(1)-1]
        vp_nav_masks = torch.cat(
            [torch.zeros(len(vp_nav_masks), 1).bool().to(vp_nav_masks.device), vp_nav_masks], 1
        )   # add [stop]
        local_logits.masked_fill_(vp_nav_masks, -float('inf'))

        # fusion (same scheme as forward_sap)
        fused_logits = torch.clone(global_logits)
        fused_logits[:, 0] += local_logits[:, 0]   # stop
        for i in range(batch_size):
            visited_nodes = set([vp for vp, mask in zip(gmap_vpids[i], gmap_visited_masks[i]) if mask])
            tmp = {}
            bw_logits = 0
            for j, cand_vpid in enumerate(traj_cand_vpids[i][-1]):
                if cand_vpid in visited_nodes:
                    bw_logits += local_logits[i, j+1]
                else:
                    tmp[cand_vpid] = local_logits[i, j+1]
            for j, vp in enumerate(gmap_vpids[i]):
                if j > 0 and vp not in visited_nodes:
                    if vp in tmp:
                        fused_logits[i, j] += tmp[vp]
                    else:
                        fused_logits[i, j] += bw_logits

        vp_view_lens = [x[-1] for x in torch.split(traj_vp_view_lens, traj_step_lens, 0)]
        vp_obj_lens = [x[-1] for x in torch.split(traj_vp_obj_lens, traj_step_lens, 0)]
        obj_embeds = pad_tensors_wgrad([
            x[1+view_len: 1+view_len+obj_len] for x, view_len, obj_len in zip(vp_embeds, vp_view_lens, vp_obj_lens)
        ])
        obj_masks = gen_seq_masks(torch.stack(vp_obj_lens, 0))

        obj_logits = self.og_head(obj_embeds).squeeze(2)
        obj_logits.masked_fill_(obj_masks.logical_not(), -float('inf'))

        return global_logits, local_logits, fused_logits, obj_logits
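
# Hedged validation sketch for the 'valid_sap_og' task (the caller still holds the
# labels; metric names here are assumptions, not part of this module):
#   global_logits, local_logits, fused_logits, obj_logits = model(batch, 'valid_sap_og')
#   nav_acc = (fused_logits.argmax(1) == batch['global_act_labels']).float().mean()
#   og_acc = (obj_logits.argmax(1) == batch['obj_labels']).float().mean()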