adversarial_VLNDUET/map_nav_src/models/model.py
Shizhe Chen 747cf0587b init
2021-11-24 13:29:08 +01:00

55 lines
1.6 KiB
Python

import numpy as np
import collections
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertPreTrainedModel
from .vlnbert_init import get_vlnbert_models
class VLNBert(nn.Module):
    """Wrapper around the VLN-BERT encoder built by ``get_vlnbert_models``.

    Every call is dispatched to the wrapped encoder by ``mode``. In the
    ``'panorama'`` mode, feature dropout is applied to the image features
    before they reach the encoder.
    """

    def __init__(self, args):
        """Build the wrapped encoder and the feature-dropout layer.

        Args:
            args: experiment configuration. This class reads
                ``args.feat_dropout`` directly and hands ``args`` on to
                ``get_vlnbert_models``.
        """
        super().__init__()
        print('\nInitalizing the VLN-BERT model ...')
        self.args = args

        self.vln_bert = get_vlnbert_models(args, config=None)  # initialize the VLN-BERT
        self.drop_env = nn.Dropout(p=args.feat_dropout)

    def forward(self, mode, batch):
        """Run one stage of the wrapped encoder.

        Args:
            mode: one of ``'language'``, ``'panorama'`` or ``'navigation'``.
            batch: mapping of input names to values. It is shallow-copied
                into a ``defaultdict`` so keys the caller omitted read as
                ``None``. This method does not modify the caller's mapping
                itself.

        Returns:
            For ``'language'`` and ``'navigation'``: whatever the wrapped
            encoder returns. For ``'panorama'``: the two values the encoder
            returns, as ``(pano_embeds, pano_masks)``.

        Raises:
            NotImplementedError: if ``mode`` is not one of the values above.
        """
        # The shallow copy means missing keys read as None, and the dropout
        # assignments below never touch the caller's dict.
        batch = collections.defaultdict(lambda: None, batch)

        if mode == 'language':
            txt_embeds = self.vln_bert(mode, batch)
            return txt_embeds

        elif mode == 'panorama':
            batch['view_img_fts'] = self.drop_env(batch['view_img_fts'])
            # Treat an explicit None like a missing key (the defaultdict
            # convention above), because nn.Dropout cannot take None.
            # ``.get`` does not insert a default entry into the defaultdict.
            if batch.get('obj_img_fts') is not None:
                batch['obj_img_fts'] = self.drop_env(batch['obj_img_fts'])
            pano_embeds, pano_masks = self.vln_bert(mode, batch)
            return pano_embeds, pano_masks

        elif mode == 'navigation':
            outs = self.vln_bert(mode, batch)
            return outs

        else:
            raise NotImplementedError('wrong mode: %s'%mode)


class Critic(nn.Module):
    """Small MLP value head that maps a state vector to one scalar.

    Architecture: Linear(input_size, hidden_size) -> ReLU -> Dropout ->
    Linear(hidden_size, 1).
    """

    def __init__(self, args, input_size=768, hidden_size=512):
        """Build the value head.

        Args:
            args: configuration object; only ``args.dropout`` is read.
            input_size: size of the last dimension of the incoming state
                (default 768, the previously hard-coded value).
            hidden_size: width of the hidden layer (default 512, the
                previously hard-coded value).
        """
        super().__init__()
        # Keep the submodule name and layer order so parameter names stay
        # state2value.0.* and state2value.3.*, exactly as before.
        self.state2value = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(args.dropout),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, state):
        """Return the value estimate for ``state``.

        An input of shape ``(..., input_size)`` yields an output of shape
        ``(...)``: only the trailing size-1 output dimension is removed.
        A bare ``squeeze()`` would also drop a batch dimension of size 1
        and turn a one-element batch into a 0-d tensor.
        """
        return self.state2value(state).squeeze(-1)