# Tokenizer / VLN-BERT model factory for the R2R agent.
# Imports grouped per PEP 8: stdlib, third-party, project-local.
import sys
import _pickle as cPickle

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import (Dataset, DataLoader, RandomSampler, SequentialSampler, TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from transformers import (WEIGHTS_NAME, BertConfig, BertTokenizer)
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup as WarmupLinearSchedule

# from vlnbert.modeling_bert import LanguageBert
from vlnbert.modeling_visbert import VLNBert

# Path to the PREVALENT pre-trained checkpoint used to initialise VLNBert.
model_name_or_path = 'r2r_src/vlnbert/Prevalent/pretrained_model/pytorch_model.bin'
def get_tokenizer():
    """Return a BertTokenizer loaded from the 'bert-base-uncased' vocabulary."""
    return BertTokenizer.from_pretrained('bert-base-uncased')
def get_vlnbert_models(config=None):
    """Build a VLNBert model initialised from the PREVALENT checkpoint.

    Args:
        config: unused placeholder kept for backward compatibility — the
            configuration is always rebuilt from 'bert-base-uncased' below.
            NOTE(review): confirm no caller relies on passing a config here.

    Returns:
        A VLNBert instance with weights loaded from ``model_name_or_path``.
    """
    vis_config = BertConfig.from_pretrained('bert-base-uncased')

    # all configurations (need to pack into args)
    vis_config.img_feature_dim = 2176
    vis_config.img_feature_type = ""
    vis_config.update_lang_bert = False
    vis_config.update_add_layer = False
    vis_config.vl_layers = 4
    vis_config.la_layers = 9

    # Fix: the original instantiated `VLNBert(vis_config)` here and then
    # immediately discarded it — from_pretrained constructs the model itself,
    # so the extra build wasted time and memory for no effect.
    visual_model = VLNBert.from_pretrained(model_name_or_path, config=vis_config)

    return visual_model