fix mis-deleted attentions
This commit is contained in:
parent
23e4b9be90
commit
1602aefcb5
@ -51,7 +51,7 @@ class VLNBERT(nn.Module):
|
|||||||
cand_feats[..., :-args.angle_feat_size] = self.drop_env(cand_feats[..., :-args.angle_feat_size])
|
cand_feats[..., :-args.angle_feat_size] = self.drop_env(cand_feats[..., :-args.angle_feat_size])
|
||||||
|
|
||||||
# logit is the attention scores over the candidate features
|
# logit is the attention scores over the candidate features
|
||||||
h_t, logit = self.vln_bert(mode, state_feats,
|
h_t, logit, attended_language, attended_visual = self.vln_bert(mode, state_feats,
|
||||||
attention_mask=attention_mask, lang_mask=lang_mask, vis_mask=vis_mask, img_feats=cand_feats)
|
attention_mask=attention_mask, lang_mask=lang_mask, vis_mask=vis_mask, img_feats=cand_feats)
|
||||||
|
|
||||||
# update agent's state, unify history, language and vision by elementwise product
|
# update agent's state, unify history, language and vision by elementwise product
|
||||||
|
|||||||
@ -439,4 +439,11 @@ class VLNBert(BertPreTrainedModel):
|
|||||||
language_state_scores = language_attention_scores.mean(dim=1)
|
language_state_scores = language_attention_scores.mean(dim=1)
|
||||||
visual_action_scores = visual_attention_scores.mean(dim=1)
|
visual_action_scores = visual_attention_scores.mean(dim=1)
|
||||||
|
|
||||||
return pooled_output, visual_action_scores
|
# weighted_feat
|
||||||
|
language_attention_probs = nn.Softmax(dim=-1)(language_state_scores.clone()).unsqueeze(-1)
|
||||||
|
visual_attention_probs = nn.Softmax(dim=-1)(visual_action_scores.clone()).unsqueeze(-1)
|
||||||
|
|
||||||
|
attended_language = (language_attention_probs * text_embeds[:, 1:, :]).sum(1)
|
||||||
|
attended_visual = (visual_attention_probs * img_embedding_output).sum(1)
|
||||||
|
|
||||||
|
return pooled_output, visual_action_scores, attended_language, attended_visual
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user