fix mis-deleted attentions

Yicong Hong 2021-01-14 22:08:12 +11:00
parent 23e4b9be90
commit 1602aefcb5
2 changed files with 9 additions and 2 deletions


@@ -51,7 +51,7 @@ class VLNBERT(nn.Module):
             cand_feats[..., :-args.angle_feat_size] = self.drop_env(cand_feats[..., :-args.angle_feat_size])

             # logit is the attention scores over the candidate features
-            h_t, logit = self.vln_bert(mode, state_feats,
+            h_t, logit, attended_language, attended_visual = self.vln_bert(mode, state_feats,
                 attention_mask=attention_mask, lang_mask=lang_mask, vis_mask=vis_mask, img_feats=cand_feats)

             # update agent's state, unify history, language and vision by elementwise product
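The comment on the last context line refers to the state-update step that consumes the two restored outputs. A minimal sketch of that elementwise-product fusion follows; the layer names (vis_lang_LayerNorm, state_proj) and all shapes are assumptions for illustration, not taken from this diff:

# Hedged sketch: how the caller might fuse the restored attention outputs.
# Layer names and tensor shapes are assumptions, not part of this commit.
import torch
import torch.nn as nn

hidden_size, batch = 768, 4

h_t = torch.randn(batch, hidden_size)                 # recurrent state token from self.vln_bert
attended_language = torch.randn(batch, hidden_size)   # restored output
attended_visual = torch.randn(batch, hidden_size)     # restored output

vis_lang_LayerNorm = nn.LayerNorm(hidden_size)        # hypothetical fusion layers
state_proj = nn.Linear(hidden_size * 2, hidden_size)

# "unify history, language and vision by elementwise product"
vis_lang_feat = vis_lang_LayerNorm(attended_language * attended_visual)
state_output = torch.cat((h_t, vis_lang_feat), dim=-1)
new_state = state_proj(state_output)                  # agent state for the next step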


@@ -439,4 +439,11 @@ class VLNBert(BertPreTrainedModel):
         language_state_scores = language_attention_scores.mean(dim=1)
         visual_action_scores = visual_attention_scores.mean(dim=1)

-        return pooled_output, visual_action_scores
+        # weighted_feat
+        language_attention_probs = nn.Softmax(dim=-1)(language_state_scores.clone()).unsqueeze(-1)
+        visual_attention_probs = nn.Softmax(dim=-1)(visual_action_scores.clone()).unsqueeze(-1)
+
+        attended_language = (language_attention_probs * text_embeds[:, 1:, :]).sum(1)
+        attended_visual = (visual_attention_probs * img_embedding_output).sum(1)
+
+        return pooled_output, visual_action_scores, attended_language, attended_visual
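Read on their own, the restored lines turn the head-averaged cross-attention scores into probabilities and use them to pool the language tokens (skipping the state/[CLS] token at index 0) and the candidate image embeddings. A self-contained sketch with made-up dimensions:

# Hedged sketch of the restored weighted pooling; all sizes are hypothetical.
import torch
import torch.nn as nn

batch, num_lang_tokens, num_cands, hidden = 4, 20, 8, 768

# head-averaged cross-attention scores, as produced by the two *_scores lines above
language_state_scores = torch.randn(batch, num_lang_tokens - 1)   # over language tokens (no state token)
visual_action_scores = torch.randn(batch, num_cands)              # over candidate views

text_embeds = torch.randn(batch, num_lang_tokens, hidden)
img_embedding_output = torch.randn(batch, num_cands, hidden)

language_attention_probs = nn.Softmax(dim=-1)(language_state_scores.clone()).unsqueeze(-1)
visual_attention_probs = nn.Softmax(dim=-1)(visual_action_scores.clone()).unsqueeze(-1)

# probability-weighted sums, each of shape (batch, hidden)
attended_language = (language_attention_probs * text_embeds[:, 1:, :]).sum(1)
attended_visual = (visual_attention_probs * img_embedding_output).sum(1)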