diff --git a/r2r_src/model_PREVALENT.py b/r2r_src/model_PREVALENT.py
index 1d24998..2100a71 100644
--- a/r2r_src/model_PREVALENT.py
+++ b/r2r_src/model_PREVALENT.py
@@ -51,7 +51,7 @@ class VLNBERT(nn.Module):
             cand_feats[..., :-args.angle_feat_size] = self.drop_env(cand_feats[..., :-args.angle_feat_size])

         # logit is the attention scores over the candidate features
-        h_t, logit = self.vln_bert(mode, state_feats,
+        h_t, logit, attended_language, attended_visual = self.vln_bert(mode, state_feats,
                         attention_mask=attention_mask, lang_mask=lang_mask, vis_mask=vis_mask, img_feats=cand_feats)

         # update agent's state, unify history, language and vision by elementwise product
diff --git a/r2r_src/vlnbert/vlnbert_PREVALENT.py b/r2r_src/vlnbert/vlnbert_PREVALENT.py
index 84ee44f..4e3ee30 100644
--- a/r2r_src/vlnbert/vlnbert_PREVALENT.py
+++ b/r2r_src/vlnbert/vlnbert_PREVALENT.py
@@ -439,4 +439,11 @@ class VLNBert(BertPreTrainedModel):
         language_state_scores = language_attention_scores.mean(dim=1)
         visual_action_scores = visual_attention_scores.mean(dim=1)

-        return pooled_output, visual_action_scores
+        # weighted_feat
+        language_attention_probs = nn.Softmax(dim=-1)(language_state_scores.clone()).unsqueeze(-1)
+        visual_attention_probs = nn.Softmax(dim=-1)(visual_action_scores.clone()).unsqueeze(-1)
+
+        attended_language = (language_attention_probs * text_embeds[:, 1:, :]).sum(1)
+        attended_visual = (visual_attention_probs * img_embedding_output).sum(1)
+
+        return pooled_output, visual_action_scores, attended_language, attended_visual
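
Note: the new return values are attention-pooled summaries of each modality: the per-token scores are softmaxed and used as weights for a sum over the language and visual embeddings. The standalone sketch below illustrates that pooling with dummy tensors; the shapes and variable values are illustrative assumptions, not taken from the repository.

import torch
import torch.nn as nn

# Illustrative shapes (batch, tokens, hidden); not the repository's actual dimensions.
batch, n_lang, n_vis, hidden = 2, 10, 5, 768
language_state_scores = torch.randn(batch, n_lang)    # stand-in for attention scores averaged over heads
visual_action_scores = torch.randn(batch, n_vis)
text_embeds = torch.randn(batch, n_lang + 1, hidden)  # index 0 is the state/[CLS] token, skipped below
img_embedding_output = torch.randn(batch, n_vis, hidden)

# Normalize the scores into per-token weights.
language_attention_probs = nn.Softmax(dim=-1)(language_state_scores).unsqueeze(-1)  # (batch, n_lang, 1)
visual_attention_probs = nn.Softmax(dim=-1)(visual_action_scores).unsqueeze(-1)     # (batch, n_vis, 1)

# Weighted sum over the sequence dimension gives one pooled vector per modality.
attended_language = (language_attention_probs * text_embeds[:, 1:, :]).sum(1)  # (batch, hidden)
attended_visual = (visual_attention_probs * img_embedding_output).sum(1)       # (batch, hidden)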