fix mis-deleted attentions
This commit is contained in:
parent
23e4b9be90
commit
1602aefcb5
@ -51,7 +51,7 @@ class VLNBERT(nn.Module):
|
||||
cand_feats[..., :-args.angle_feat_size] = self.drop_env(cand_feats[..., :-args.angle_feat_size])
|
||||
|
||||
# logit is the attention scores over the candidate features
|
||||
h_t, logit = self.vln_bert(mode, state_feats,
|
||||
h_t, logit, attended_language, attended_visual = self.vln_bert(mode, state_feats,
|
||||
attention_mask=attention_mask, lang_mask=lang_mask, vis_mask=vis_mask, img_feats=cand_feats)
|
||||
|
||||
# update agent's state, unify history, language and vision by elementwise product
|
||||
|
||||
@ -439,4 +439,11 @@ class VLNBert(BertPreTrainedModel):
|
||||
language_state_scores = language_attention_scores.mean(dim=1)
|
||||
visual_action_scores = visual_attention_scores.mean(dim=1)
|
||||
|
||||
return pooled_output, visual_action_scores
|
||||
# weighted_feat
|
||||
language_attention_probs = nn.Softmax(dim=-1)(language_state_scores.clone()).unsqueeze(-1)
|
||||
visual_attention_probs = nn.Softmax(dim=-1)(visual_action_scores.clone()).unsqueeze(-1)
|
||||
|
||||
attended_language = (language_attention_probs * text_embeds[:, 1:, :]).sum(1)
|
||||
attended_visual = (visual_attention_probs * img_embedding_output).sum(1)
|
||||
|
||||
return pooled_output, visual_action_scores, attended_language, attended_visual
|
||||
|
||||
Loading…
Reference in New Issue
Block a user