commit 3496085292: first version

.gitignore (vendored, new file, 26 lines)
@@ -0,0 +1,26 @@
.ftpignore
.ftpconfig

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

connectivity/*
!connectivity/.gitkeep

data/R2R_test.json
data/R2R_train.json
data/R2R_val_seen.json
data/R2R_val_unseen.json
data/prevalent/*
!data/prevalent/.gitkeep

img_features/*
!img_features/.gitkeep

snap/*
!snap/.gitkeep

logs/*
!logs/.gitkeep
README.md (new file, 95 lines)
@@ -0,0 +1,95 @@
# Entity-Graph-VLN

Code of the NeurIPS 2020 paper:
**Language and Visual Entity Relationship Graph for Agent Navigation**<br>
[**Yicong Hong**](http://www.yiconghong.me/), [Cristian Rodriguez-Opazo](https://crodriguezo.github.io/), [Yuankai Qi](https://sites.google.com/site/yuankiqi/home), [Qi Wu](http://www.qi-wu.me/), [Stephen Gould](http://users.cecs.anu.edu.au/~sgould/)<br>

[[Paper](https://papers.nips.cc/paper/2020/hash/56dc0997d871e9177069bb472574eb29-Abstract.html)] [[Supplemental](https://papers.nips.cc/paper/2020/file/56dc0997d871e9177069bb472574eb29-Supplemental.pdf)] [[GitHub](https://github.com/YicongHong/Entity-Graph-VLN)]

<p align="center">
<img src="teaser/f1.png" width="100%">
</p>

## Prerequisites

### Installation

Install the [Matterport3D Simulator](https://github.com/peteanderson80/Matterport3DSimulator).

Please find the versions of the packages in our environment [here](https://github.com/YicongHong/Entity-Graph-VLN/blob/master/entity_graph_vln.yml). In particular, we use:
- Python 3.6.9
- NumPy 1.18.1
- OpenCV 3.4.2
- PyTorch 1.3.0
- Torchvision 0.4.1
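
If you use conda, you can recreate this environment from the file linked above (a minimal sketch; the environment name defined inside `entity_graph_vln.yml` may differ):
```bash
# Recreate the pinned package versions, then activate the environment
# before building the simulator.
conda env create -f entity_graph_vln.yml
conda activate entity_graph_vln
```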

### Data Preparation

Please follow the instructions below to prepare the data in the listed directories:

- `connectivity`
    - Download the [connectivity maps [23.8MB]](https://github.com/peteanderson80/Matterport3DSimulator/tree/master/connectivity).
- `data`
    - Download the [R2R data [5.8MB]](https://github.com/peteanderson80/Matterport3DSimulator/tree/master/tasks/R2R/data).
    - Download the vocabulary and the [augmented data from EnvDrop [79.5MB]](https://github.com/airsplay/R2R-EnvDrop/tree/master/tasks/R2R/data).
- `img_features`
    - Download the [Scene features [4.2GB]](https://www.dropbox.com/s/85tpa6tc3enl5ud/ResNet-152-places365.zip?dl=1) (ResNet-152-Places365).
    - Download the pre-processed [Object features and vocabulary [1.3GB]](https://zenodo.org/record/4310441/files/objects.zip?download=1) ([Caffe Faster-RCNN](https://github.com/peteanderson80/bottom-up-attention)).

### Trained Network Weights

- `snap`
    - Download the trained [network weights [146.0MB]](https://zenodo.org/record/4310441/files/snap.zip?download=1).

## R2R Navigation

Please read Peter Anderson's VLN paper for the [R2R Navigation task](https://arxiv.org/abs/1711.07280).

Our code is based on the code structure of [EnvDrop](https://github.com/airsplay/R2R-EnvDrop).

### Reproduce Testing Results

To replicate the performance reported in our paper, load the trained network weights and run validation:
```bash
bash run/agent.bash
```
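
The `run/` scripts are not included in the diff above. For orientation only, a hypothetical validation script could pass flags defined in `r2r_src/param.py` like this (the values and the snapshot path are placeholders, not the settings used for our reported results):
```bash
# Hypothetical example; every flag below is defined in r2r_src/param.py,
# but the exact values and the checkpoint path are assumptions.
CUDA_VISIBLE_DEVICES=0 python r2r_src/train.py \
    --name eval_agent \
    --vlnbert prevalent \
    --features places365 \
    --feedback argmax \
    --batchSize 8 \
    --load snap/eval_agent/best_val_unseen
```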

### Training

#### Navigator

To train the network from scratch, first train a Navigator on the R2R training split.

Modify `run/agent.bash`: remove the `--load` argument and set `--train listener`. Then,
```bash
bash run/agent.bash
```
The trained Navigator will be saved under `snap/`.

#### Speaker

You also need to train a [Speaker](https://github.com/airsplay/R2R-EnvDrop) for augmented training:
```bash
bash run/speak.bash
```
The trained Speaker will be saved under `snap/`.

#### Augmented Navigator

Finally, keep training the Navigator on a mixture of the original data and the [augmented data](http://www.cs.unc.edu/~airsplay/aug_paths.json):
```bash
bash run/bt_envdrop.bash
```
We apply a one-step learning-rate decay to 1e-5 when training saturates.

## Citation

If you use or discuss our Entity Relationship Graph, please cite our paper:
```
@article{hong2020language,
  title={Language and Visual Entity Relationship Graph for Agent Navigation},
  author={Hong, Yicong and Rodriguez, Cristian and Qi, Yuankai and Wu, Qi and Gould, Stephen},
  journal={Advances in Neural Information Processing Systems},
  volume={33},
  year={2020}
}
```
connectivity (symbolic link, new)
@@ -0,0 +1 @@
/students/u5399302/MatterportData/connectivity
data/R2R_val_train_seen.json (new file, 1 line)
File diff suppressed because one or more lines are too long
data/vocab.txt (new file, 30522 lines)
File diff suppressed because it is too large
img_features (symbolic link, new)
@@ -0,0 +1 @@
/students/u5399302/MatterportData/img_features
r2r_src/agent.py (new file, 584 lines)
@@ -0,0 +1,584 @@
# R2R-EnvDrop, 2019, haotan@cs.unc.edu
# Modified in Recurrent VLN-BERT, 2020, by Yicong.Hong@anu.edu.au

import json
import os
import sys
import numpy as np
import random
import math
import time

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F

from env import R2RBatch
import utils
from utils import padding_idx, print_progress
import model_OSCAR, model_PREVALENT
import param
from param import args
from collections import defaultdict


class BaseAgent(object):
    ''' Base class for an R2R agent to generate and save trajectories. '''

    def __init__(self, env, results_path):
        self.env = env
        self.results_path = results_path
        random.seed(1)
        self.results = {}
        self.losses = []  # For learning agents

    def write_results(self):
        output = [{'instr_id': k, 'trajectory': v} for k, v in self.results.items()]
        with open(self.results_path, 'w') as f:
            json.dump(output, f)

    def get_results(self):
        output = [{'instr_id': k, 'trajectory': v} for k, v in self.results.items()]
        return output

    def rollout(self, **args):
        ''' Return a list of dicts containing instr_id:'xx', path:[(viewpointId, heading_rad, elevation_rad)] '''
        raise NotImplementedError

    @staticmethod
    def get_agent(name):
        return globals()[name+"Agent"]

    def test(self, iters=None, **kwargs):
        self.env.reset_epoch(shuffle=(iters is not None))   # If iters is not None, shuffle the env batch
        self.losses = []
        self.results = {}
        # We rely on env showing the entire batch before repeating anything
        looped = False
        self.loss = 0
        if iters is not None:
            # Each time, run the first 'iters' iterations (the data was shuffled above)
            for i in range(iters):
                for traj in self.rollout(**kwargs):
                    self.loss = 0
                    self.results[traj['instr_id']] = traj['path']
        else:   # Do a full round
            while True:
                for traj in self.rollout(**kwargs):
                    if traj['instr_id'] in self.results:
                        looped = True
                    else:
                        self.loss = 0
                        self.results[traj['instr_id']] = traj['path']
                if looped:
                    break


class Seq2SeqAgent(BaseAgent):
    ''' An agent based on an LSTM seq2seq model with attention. '''

    # For now, the agent can't pick which forward move to make - just the one in the middle
    env_actions = {
        'left':     (0, -1,  0),  # left
        'right':    (0,  1,  0),  # right
        'up':       (0,  0,  1),  # up
        'down':     (0,  0, -1),  # down
        'forward':  (1,  0,  0),  # forward
        '<end>':    (0,  0,  0),  # <end>
        '<start>':  (0,  0,  0),  # <start>
        '<ignore>': (0,  0,  0)   # <ignore>
    }

    def __init__(self, env, results_path, tok, episode_len=20):
        super(Seq2SeqAgent, self).__init__(env, results_path)
        self.tok = tok
        self.episode_len = episode_len
        self.feature_size = self.env.feature_size

        # Models
        if args.vlnbert == 'oscar':
            self.vln_bert = model_OSCAR.VLNBERT(feature_size=self.feature_size + args.angle_feat_size).cuda()
            self.critic = model_OSCAR.Critic().cuda()
        elif args.vlnbert == 'prevalent':
            self.vln_bert = model_PREVALENT.VLNBERT(feature_size=self.feature_size + args.angle_feat_size).cuda()
            self.critic = model_PREVALENT.Critic().cuda()
        self.models = (self.vln_bert, self.critic)

        # Optimizers
        self.vln_bert_optimizer = args.optimizer(self.vln_bert.parameters(), lr=args.lr)
        self.critic_optimizer = args.optimizer(self.critic.parameters(), lr=args.lr)
        self.optimizers = (self.vln_bert_optimizer, self.critic_optimizer)

        # Evaluations
        self.losses = []
        self.criterion = nn.CrossEntropyLoss(ignore_index=args.ignoreid, size_average=False)
        self.ndtw_criterion = utils.ndtw_initialize()

        # Logs
        sys.stdout.flush()
        self.logs = defaultdict(list)

    def _sort_batch(self, obs):
        seq_tensor = np.array([ob['instr_encoding'] for ob in obs])
        seq_lengths = np.argmax(seq_tensor == padding_idx, axis=1)
        seq_lengths[seq_lengths == 0] = seq_tensor.shape[1]
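        # e.g., with padding_idx = 0, an encoding [101, 2054, 102, 0, 0] gives
        # length 3; a row with no padding makes argmax return 0, which the
        # line above then resets to the full sequence length.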

        seq_tensor = torch.from_numpy(seq_tensor)
        seq_lengths = torch.from_numpy(seq_lengths)

        # Sort sequences by length
        seq_lengths, perm_idx = seq_lengths.sort(0, True)   # True -> descending
        sorted_tensor = seq_tensor[perm_idx]
        mask = (sorted_tensor != padding_idx)

        token_type_ids = torch.zeros_like(mask)

        return Variable(sorted_tensor, requires_grad=False).long().cuda(), \
               mask.long().cuda(), token_type_ids.long().cuda(), \
               list(seq_lengths), list(perm_idx)

    def _feature_variable(self, obs):
        ''' Extract precomputed features into a variable. '''
        features = np.empty((len(obs), args.views, self.feature_size + args.angle_feat_size), dtype=np.float32)
        for i, ob in enumerate(obs):
            features[i, :, :] = ob['feature']   # Image feat
        return Variable(torch.from_numpy(features), requires_grad=False).cuda()

    def _candidate_variable(self, obs):
        candidate_leng = [len(ob['candidate']) + 1 for ob in obs]   # +1 is for the end
        candidate_feat = np.zeros((len(obs), max(candidate_leng), self.feature_size + args.angle_feat_size), dtype=np.float32)
        # Note: the candidate_feat at index len(ob['candidate']) is the feature
        # for the END action, which is zero in this implementation
        for i, ob in enumerate(obs):
            for j, cc in enumerate(ob['candidate']):
                candidate_feat[i, j, :] = cc['feature']

        return torch.from_numpy(candidate_feat).cuda(), candidate_leng

    def get_input_feat(self, obs):
        input_a_t = np.zeros((len(obs), args.angle_feat_size), np.float32)
        for i, ob in enumerate(obs):
            input_a_t[i] = utils.angle_feature(ob['heading'], ob['elevation'])
        input_a_t = torch.from_numpy(input_a_t).cuda()
        # f_t = self._feature_variable(obs)      # Pano image features from obs
        candidate_feat, candidate_leng = self._candidate_variable(obs)

        return input_a_t, candidate_feat, candidate_leng

    def _teacher_action(self, obs, ended):
        """
        Extract teacher actions into a variable.
        :param obs: The observations.
        :param ended: Whether the action sequence has ended.
        :return:
        """
        a = np.zeros(len(obs), dtype=np.int64)
        for i, ob in enumerate(obs):
            if ended[i]:                                            # Just ignore this index
                a[i] = args.ignoreid
            else:
                for k, candidate in enumerate(ob['candidate']):
                    if candidate['viewpointId'] == ob['teacher']:   # Next viewpoint
                        a[i] = k
                        break
                else:   # Stop here
                    assert ob['teacher'] == ob['viewpoint']         # The teacher action should be "STAY HERE"
                    a[i] = len(ob['candidate'])
        return torch.from_numpy(a).cuda()

    def make_equiv_action(self, a_t, perm_obs, perm_idx=None, traj=None):
        """
        Interface between the panoramic view and the egocentric view.
        Converts a panoramic-view action a_t into the equivalent sequence of
        egocentric-view actions for the simulator.
        """
        def take_action(i, idx, name):
            if type(name) is int:       # Go to the next viewpoint
                self.env.env.sims[idx].makeAction(name, 0, 0)
            else:                       # Adjust the camera
                self.env.env.sims[idx].makeAction(*self.env_actions[name])

        if perm_idx is None:
            perm_idx = range(len(perm_obs))

        for i, idx in enumerate(perm_idx):
            action = a_t[i]
            if action != -1:            # -1 is the <stop> action
                select_candidate = perm_obs[i]['candidate'][action]
                src_point = perm_obs[i]['viewIndex']
                trg_point = select_candidate['pointId']
                src_level = src_point // 12     # The point index starts from 0
                trg_level = trg_point // 12
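                # e.g., pointId 26 lies in the top row of the panorama
                # (views 24-35 look up), so its level is 26 // 12 = 2.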
                while src_level < trg_level:    # Tune up
                    take_action(i, idx, 'up')
                    src_level += 1
                while src_level > trg_level:    # Tune down
                    take_action(i, idx, 'down')
                    src_level -= 1
                while self.env.env.sims[idx].getState().viewIndex != trg_point:    # Turn right until the target
                    take_action(i, idx, 'right')
                assert select_candidate['viewpointId'] == \
                       self.env.env.sims[idx].getState().navigableLocations[select_candidate['idx']].viewpointId
                take_action(i, idx, select_candidate['idx'])

                state = self.env.env.sims[idx].getState()
                if traj is not None:
                    traj[i]['path'].append((state.location.viewpointId, state.heading, state.elevation))

    def rollout(self, train_ml=None, train_rl=True, reset=True):
        """
        :param train_ml: The weight for training with maximum likelihood.
        :param train_rl: Whether to use RL in training.
        :param reset: Reset the environment.

        :return:
        """
        if self.feedback == 'teacher' or self.feedback == 'argmax':
            train_rl = False

        if reset:   # Reset env
            obs = np.array(self.env.reset())
        else:
            obs = np.array(self.env._get_obs())

        batch_size = len(obs)

        # Language input
        sentence, language_attention_mask, token_type_ids, \
            seq_lengths, perm_idx = self._sort_batch(obs)
        perm_obs = obs[perm_idx]

        ''' Language BERT '''
        language_inputs = {'mode': 'language',
                           'sentence': sentence,
                           'attention_mask': language_attention_mask,
                           'lang_mask': language_attention_mask,
                           'token_type_ids': token_type_ids}
        if args.vlnbert == 'oscar':
            language_features = self.vln_bert(**language_inputs)
        elif args.vlnbert == 'prevalent':
            h_t, language_features = self.vln_bert(**language_inputs)

        # Record the starting point
        traj = [{
            'instr_id': ob['instr_id'],
            'path': [(ob['viewpoint'], ob['heading'], ob['elevation'])],
        } for ob in perm_obs]

        # Init the reward shaping
        last_dist = np.zeros(batch_size, np.float32)
        last_ndtw = np.zeros(batch_size, np.float32)
        for i, ob in enumerate(perm_obs):   # The initial distance from the viewpoint to the target
            last_dist[i] = ob['distance']
            path_act = [vp[0] for vp in traj[i]['path']]
            last_ndtw[i] = self.ndtw_criterion[ob['scan']](path_act, ob['gt_path'], metric='ndtw')

        # Initialize the tracking state
        ended = np.array([False] * batch_size)  # Indices match the permutation of the model, not the env

        # Init the logs
        rewards = []
        hidden_states = []
        policy_log_probs = []
        masks = []
        entropys = []
        ml_loss = 0.

        for t in range(self.episode_len):

            input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs)

            # The first [CLS] token, initialized by the language BERT, serves
            # as the agent's state passing through time steps
            if (t >= 1) or (args.vlnbert == 'prevalent'):
                language_features = torch.cat((h_t.unsqueeze(1), language_features[:, 1:, :]), dim=1)

            visual_temp_mask = (utils.length2mask(candidate_leng) == 0).long()
            visual_attention_mask = torch.cat((language_attention_mask, visual_temp_mask), dim=-1)

            self.vln_bert.vln_bert.config.directions = max(candidate_leng)
            ''' Visual BERT '''
            visual_inputs = {'mode': 'visual',
                             'sentence': language_features,
                             'attention_mask': visual_attention_mask,
                             'lang_mask': language_attention_mask,
                             'vis_mask': visual_temp_mask,
                             'token_type_ids': token_type_ids,
                             'action_feats': input_a_t,
                             # 'pano_feats': f_t,
                             'cand_feats': candidate_feat}
            h_t, logit = self.vln_bert(**visual_inputs)
            hidden_states.append(h_t)

            # Mask outputs where the agent can't move forward
            # Here the logit is [b, max_candidate]
            candidate_mask = utils.length2mask(candidate_leng)
            logit.masked_fill_(candidate_mask, -float('inf'))

            # Supervised training
            target = self._teacher_action(perm_obs, ended)
            ml_loss += self.criterion(logit, target)

            # Determine the next model inputs
            if self.feedback == 'teacher':
                a_t = target                 # teacher forcing
            elif self.feedback == 'argmax':
                _, a_t = logit.max(1)        # student forcing - argmax
                a_t = a_t.detach()
                log_probs = F.log_softmax(logit, 1)                             # Calculate the log_prob here
                policy_log_probs.append(log_probs.gather(1, a_t.unsqueeze(1)))  # Gather the log_prob for each batch
            elif self.feedback == 'sample':
                probs = F.softmax(logit, 1)  # sample an action from the model
                c = torch.distributions.Categorical(probs)
                self.logs['entropy'].append(c.entropy().sum().item())           # For log
                entropys.append(c.entropy())                                    # For optimization
                a_t = c.sample().detach()
                policy_log_probs.append(c.log_prob(a_t))
            else:
                print(self.feedback)
                sys.exit('Invalid feedback option')
            # Prepare the environment action
            # NOTE: the env action is in the perm_obs space
            cpu_a_t = a_t.cpu().numpy()
            for i, next_id in enumerate(cpu_a_t):
                if next_id == (candidate_leng[i]-1) or next_id == args.ignoreid or ended[i]:    # The last action is <end>
                    cpu_a_t[i] = -1             # Change the <end> and ignore actions to -1

            # Make the action and get the new state
            self.make_equiv_action(cpu_a_t, perm_obs, perm_idx, traj)
            obs = np.array(self.env._get_obs())
            perm_obs = obs[perm_idx]            # Perm the obs for the result

            if train_rl:
                # Calculate the mask and reward
                dist = np.zeros(batch_size, np.float32)
                ndtw_score = np.zeros(batch_size, np.float32)
                reward = np.zeros(batch_size, np.float32)
                mask = np.ones(batch_size, np.float32)
                for i, ob in enumerate(perm_obs):
                    dist[i] = ob['distance']
                    path_act = [vp[0] for vp in traj[i]['path']]
                    ndtw_score[i] = self.ndtw_criterion[ob['scan']](path_act, ob['gt_path'], metric='ndtw')

                    if ended[i]:
                        reward[i] = 0.0
                        mask[i] = 0.0
                    else:
                        action_idx = cpu_a_t[i]
                        # Target reward
                        if action_idx == -1:                        # If the action now is end
                            if dist[i] < 3.0:                       # Correct
                                reward[i] = 2.0 + ndtw_score[i] * 2.0
                            else:                                   # Incorrect
                                reward[i] = -2.0
                        else:                                       # The action is not end
                            # Path fidelity rewards (distance & nDTW)
                            reward[i] = - (dist[i] - last_dist[i])
                            ndtw_reward = ndtw_score[i] - last_ndtw[i]
                            if reward[i] > 0.0:                     # Quantification
                                reward[i] = 1.0 + ndtw_reward
                            elif reward[i] < 0.0:
                                reward[i] = -1.0 + ndtw_reward
                            else:
                                raise NameError("The action doesn't change the move")
                            # Miss-the-target penalty
                            if (last_dist[i] <= 1.0) and (dist[i]-last_dist[i] > 0.0):
                                reward[i] -= (1.0 - last_dist[i]) * 2.0
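                            # Worked example (illustrative numbers only): moving
                            # from 5.2m to 4.0m from the goal with an nDTW gain of
                            # 0.05 gives reward = 1.0 + 0.05; stopping within 3.0m
                            # with nDTW 0.8 gives reward = 2.0 + 0.8 * 2.0 = 3.6.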
                rewards.append(reward)
                masks.append(mask)
                last_dist[:] = dist
                last_ndtw[:] = ndtw_score

            # Update the finished actions
            # -1 means ended or ignored (already ended)
            ended[:] = np.logical_or(ended, (cpu_a_t == -1))

            # Early exit if all ended
            if ended.all():
                break

        if train_rl:
            # Last action in A2C
            input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs)

            language_features = torch.cat((h_t.unsqueeze(1), language_features[:, 1:, :]), dim=1)

            visual_temp_mask = (utils.length2mask(candidate_leng) == 0).long()
            visual_attention_mask = torch.cat((language_attention_mask, visual_temp_mask), dim=-1)

            self.vln_bert.vln_bert.config.directions = max(candidate_leng)
            ''' Visual BERT '''
            visual_inputs = {'mode': 'visual',
                             'sentence': language_features,
                             'attention_mask': visual_attention_mask,
                             'lang_mask': language_attention_mask,
                             'vis_mask': visual_temp_mask,
                             'token_type_ids': token_type_ids,
                             'action_feats': input_a_t,
                             # 'pano_feats': f_t,
                             'cand_feats': candidate_feat}
            last_h_, _ = self.vln_bert(**visual_inputs)

            rl_loss = 0.

            # NOW, A2C!!!
            # Calculate the final discounted reward
            last_value__ = self.critic(last_h_).detach()        # The value estimate of the last state; detached for safety
            discount_reward = np.zeros(batch_size, np.float32)  # The initial reward is zero
            for i in range(batch_size):
                if not ended[i]:        # If the episode has not ended, use the value function as the last reward
                    discount_reward[i] = last_value__[i]

            length = len(rewards)
            total = 0
            for t in range(length-1, -1, -1):
                discount_reward = discount_reward * args.gamma + rewards[t]     # If it ended, the reward will be 0
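                # i.e. the standard return recurrence G_t = r_t + gamma * G_{t+1},
                # bootstrapped from the critic's value for unfinished episodes.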
                mask_ = Variable(torch.from_numpy(masks[t]), requires_grad=False).cuda()
                clip_reward = discount_reward.copy()
                r_ = Variable(torch.from_numpy(clip_reward), requires_grad=False).cuda()
                v_ = self.critic(hidden_states[t])
                a_ = (r_ - v_).detach()

                rl_loss += (-policy_log_probs[t] * a_ * mask_).sum()
                rl_loss += (((r_ - v_) ** 2) * mask_).sum() * 0.5   # 1/2 L2 loss
                if self.feedback == 'sample':
                    rl_loss += (- 0.01 * entropys[t] * mask_).sum()
                self.logs['critic_loss'].append((((r_ - v_) ** 2) * mask_).sum().item())

                total = total + np.sum(masks[t])
            self.logs['total'].append(total)

            # Normalize the loss function
            if args.normalize_loss == 'total':
                rl_loss /= total
            elif args.normalize_loss == 'batch':
                rl_loss /= batch_size
            else:
                assert args.normalize_loss == 'none'

            self.loss += rl_loss
            self.logs['RL_loss'].append(rl_loss.item())

        if train_ml is not None:
            self.loss += ml_loss * train_ml / batch_size
            self.logs['IL_loss'].append((ml_loss * train_ml / batch_size).item())

        if type(self.loss) is int:  # For safety; activated if no losses were added
            self.losses.append(0.)
        else:
            self.losses.append(self.loss.item() / self.episode_len)    # This argument is useless.

        return traj

    def test(self, use_dropout=False, feedback='argmax', allow_cheat=False, iters=None):
        ''' Evaluate once on each instruction in the current environment '''
        self.feedback = feedback
        if use_dropout:
            self.vln_bert.train()
            self.critic.train()
        else:
            self.vln_bert.eval()
            self.critic.eval()
        super(Seq2SeqAgent, self).test(iters)

    def zero_grad(self):
        self.loss = 0.
        self.losses = []
        for model, optimizer in zip(self.models, self.optimizers):
            model.train()
            optimizer.zero_grad()

    def accumulate_gradient(self, feedback='teacher', **kwargs):
        if feedback == 'teacher':
            self.feedback = 'teacher'
            self.rollout(train_ml=args.teacher_weight, train_rl=False, **kwargs)
        elif feedback == 'sample':
            self.feedback = 'teacher'
            self.rollout(train_ml=args.ml_weight, train_rl=False, **kwargs)
            self.feedback = 'sample'
            self.rollout(train_ml=None, train_rl=True, **kwargs)
        else:
            assert False

    def optim_step(self):
        self.loss.backward()

        torch.nn.utils.clip_grad_norm(self.vln_bert.parameters(), 40.)

        self.vln_bert_optimizer.step()
        self.critic_optimizer.step()

    def train(self, n_iters, feedback='teacher', **kwargs):
        ''' Train for a given number of iterations '''
        self.feedback = feedback

        self.vln_bert.train()
        self.critic.train()

        self.losses = []
        for iter in range(1, n_iters + 1):

            self.vln_bert_optimizer.zero_grad()
            self.critic_optimizer.zero_grad()

            self.loss = 0

            if feedback == 'teacher':
                self.feedback = 'teacher'
                self.rollout(train_ml=args.teacher_weight, train_rl=False, **kwargs)
            elif feedback == 'sample':  # agent trained in IL and RL separately
                if args.ml_weight != 0:
                    self.feedback = 'teacher'
                    self.rollout(train_ml=args.ml_weight, train_rl=False, **kwargs)
                self.feedback = 'sample'
                self.rollout(train_ml=None, train_rl=True, **kwargs)
            else:
                assert False

            self.loss.backward()

            torch.nn.utils.clip_grad_norm(self.vln_bert.parameters(), 40.)

            self.vln_bert_optimizer.step()
            self.critic_optimizer.step()

            if args.aug is None:
                print_progress(iter, n_iters+1, prefix='Progress:', suffix='Complete', bar_length=50)

    def save(self, epoch, path):
        ''' Snapshot models '''
        the_dir, _ = os.path.split(path)
        os.makedirs(the_dir, exist_ok=True)
        states = {}

        def create_state(name, model, optimizer):
            states[name] = {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
        all_tuple = [("vln_bert", self.vln_bert, self.vln_bert_optimizer),
                     ("critic", self.critic, self.critic_optimizer)]
        for param in all_tuple:
            create_state(*param)
        torch.save(states, path)

    def load(self, path):
        ''' Loads parameters (but not the training state) '''
        states = torch.load(path)

        def recover_state(name, model, optimizer):
            state = model.state_dict()
            model_keys = set(state.keys())
            load_keys = set(states[name]['state_dict'].keys())
            if model_keys != load_keys:
                print("NOTICE: DIFFERENT KEYS IN THE LISTENER")
            state.update(states[name]['state_dict'])
            model.load_state_dict(state)
            if args.loadOptim:
                optimizer.load_state_dict(states[name]['optimizer'])
        all_tuple = [("vln_bert", self.vln_bert, self.vln_bert_optimizer),
                     ("critic", self.critic, self.critic_optimizer)]
        for param in all_tuple:
            recover_state(*param)
        return states['vln_bert']['epoch'] - 1
r2r_src/env.py (new file, 359 lines)
@@ -0,0 +1,359 @@
''' Batched Room-to-Room navigation environment '''

import sys
sys.path.append('buildpy36')
sys.path.append('Matterport_Simulator/build/')
import MatterSim
import csv
import numpy as np
import math
import base64
import utils
import json
import os
import random
import networkx as nx
from param import args

from utils import load_datasets, load_nav_graphs, pad_instr_tokens

csv.field_size_limit(sys.maxsize)


class EnvBatch():
    ''' A simple wrapper for a batch of MatterSim environments,
        using discretized viewpoints and pretrained features '''

    def __init__(self, feature_store=None, batch_size=100):
        """
        1. Load the pretrained image features.
        2. Init the simulators.
        :param feature_store: The name of the file that stores the features.
        :param batch_size: Used to create the simulator list.
        """
        if feature_store:
            if type(feature_store) is dict:     # A simple way to avoid reading the features multiple times
                self.features = feature_store
                self.image_w = 640
                self.image_h = 480
                self.vfov = 60
                self.feature_size = next(iter(self.features.values())).shape[-1]
                print('The feature size is %d' % self.feature_size)
        else:
            print('    Image features not provided - in testing mode')
            self.features = None
            self.image_w = 640
            self.image_h = 480
            self.vfov = 60
        self.sims = []
        for i in range(batch_size):
            sim = MatterSim.Simulator()
            sim.setRenderingEnabled(False)
            sim.setDiscretizedViewingAngles(True)   # Set increment/decrement to 30 degrees (otherwise in radians)
            sim.setCameraResolution(self.image_w, self.image_h)
            sim.setCameraVFOV(math.radians(self.vfov))
            sim.init()
            self.sims.append(sim)

    def _make_id(self, scanId, viewpointId):
        return scanId + '_' + viewpointId

    def newEpisodes(self, scanIds, viewpointIds, headings):
        for i, (scanId, viewpointId, heading) in enumerate(zip(scanIds, viewpointIds, headings)):
            self.sims[i].newEpisode(scanId, viewpointId, heading, 0)

    def getStates(self):
        """
        Get the list of states augmented with the precomputed image features;
        the rgb field will be empty.
        The agent's current view is in [0-35] (set only when viewing angles are discretized):
        [0-11] looking down, [12-23] looking at the horizon, [24-35] looking up.
        :return: [ ((36, 2048), sim_state) ] * batch_size
        """
        feature_states = []
        for i, sim in enumerate(self.sims):
            state = sim.getState()

            long_id = self._make_id(state.scanId, state.location.viewpointId)
            if self.features:
                feature = self.features[long_id]
                feature_states.append((feature, state))
            else:
                feature_states.append((None, state))
        return feature_states

    def makeActions(self, actions):
        ''' Take an action using the full state-dependent action interface (with batched input).
            Every action element should be an (index, heading, elevation) tuple. '''
        for i, (index, heading, elevation) in enumerate(actions):
            self.sims[i].makeAction(index, heading, elevation)


class R2RBatch():
    ''' Implements the Room-to-Room navigation task, using discretized viewpoints and pretrained features '''

    def __init__(self, feature_store, batch_size=100, seed=10, splits=['train'], tokenizer=None,
                 name=None):
        self.env = EnvBatch(feature_store=feature_store, batch_size=batch_size)
        if feature_store:
            self.feature_size = self.env.feature_size
        else:
            self.feature_size = 2048
        self.data = []
        if tokenizer:
            self.tok = tokenizer
        scans = []
        for split in splits:
            for i_item, item in enumerate(load_datasets([split])):
                if args.test_only and i_item == 64:
                    break
                if "/" in split:
                    try:
                        new_item = dict(item)
                        new_item['instr_id'] = item['path_id']
                        new_item['instructions'] = item['instructions'][0]
                        new_item['instr_encoding'] = item['instr_enc']
                        if new_item['instr_encoding'] is not None:  # Filter out the invalid data
                            self.data.append(new_item)
                            scans.append(item['scan'])
                    except:
                        continue
                else:
                    # Split multiple instructions into separate entries
                    for j, instr in enumerate(item['instructions']):
                        try:
                            new_item = dict(item)
                            new_item['instr_id'] = '%s_%d' % (item['path_id'], j)
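                            # e.g., path_id 4370 with three instructions yields
                            # instr_ids '4370_0', '4370_1' and '4370_2'.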
                            new_item['instructions'] = instr

                            ''' BERT tokenizer '''
                            instr_tokens = tokenizer.tokenize(instr)
                            padded_instr_tokens, num_words = pad_instr_tokens(instr_tokens, args.maxInput)
                            new_item['instr_encoding'] = tokenizer.convert_tokens_to_ids(padded_instr_tokens)

                            if new_item['instr_encoding'] is not None:  # Filter out the invalid data
                                self.data.append(new_item)
                                scans.append(item['scan'])
                        except:
                            continue

        if name is None:
            self.name = splits[0] if len(splits) > 0 else "FAKE"
        else:
            self.name = name

        self.scans = set(scans)
        self.splits = splits
        self.seed = seed
        random.seed(self.seed)
        random.shuffle(self.data)

        self.ix = 0
        self.batch_size = batch_size
        self._load_nav_graphs()

        self.angle_feature = utils.get_all_point_angle_feature()
        self.sim = utils.new_simulator()
        self.buffered_state_dict = {}

        # The fake data is equal to the real data in the supervised setup
        self.fake_data = self.data
        print('R2RBatch loaded with %d instructions, using splits: %s' % (len(self.data), ",".join(splits)))

    def size(self):
        return len(self.data)

    def _load_nav_graphs(self):
        """
        Load the connectivity graph for each scan in self.scans, which is
        useful for reasoning about shortest paths.
        Store the graphs {scan_id: graph} in self.graphs,
        the shortest paths {scan_id: {view_id_x: {view_id_y: [path]} } } in self.paths,
        and the distances (same structure) in self.distances.
        :return: None
        """
        print('Loading navigation graphs for %d scans' % len(self.scans))
        self.graphs = load_nav_graphs(self.scans)
        self.paths = {}
        for scan, G in self.graphs.items():     # compute all shortest paths
            self.paths[scan] = dict(nx.all_pairs_dijkstra_path(G))
        self.distances = {}
        for scan, G in self.graphs.items():     # compute all shortest-path lengths
            self.distances[scan] = dict(nx.all_pairs_dijkstra_path_length(G))

    def _next_minibatch(self, tile_one=False, batch_size=None, **kwargs):
        """
        Store the minibatch in 'self.batch'.
        :param tile_one: Tile one datum into the whole batch.
        :return: None
        """
        if batch_size is None:
            batch_size = self.batch_size
        if tile_one:
            batch = [self.data[self.ix]] * batch_size
            self.ix += 1
            if self.ix >= len(self.data):
                random.shuffle(self.data)
                self.ix -= len(self.data)
        else:
            batch = self.data[self.ix: self.ix+batch_size]
            if len(batch) < batch_size:
                random.shuffle(self.data)
                self.ix = batch_size - len(batch)
                batch += self.data[:self.ix]
            else:
                self.ix += batch_size
        self.batch = batch

    def reset_epoch(self, shuffle=False):
        ''' Reset the data index to the beginning of the epoch. Primarily for testing.
            You must still call reset() for a new episode. '''
        if shuffle:
            random.shuffle(self.data)
        self.ix = 0

    def _shortest_path_action(self, state, goalViewpointId):
        ''' Determine the next action on the shortest path to the goal, for supervised training. '''
        if state.location.viewpointId == goalViewpointId:
            return goalViewpointId      # Just stop here
        path = self.paths[state.scanId][state.location.viewpointId][goalViewpointId]
        nextViewpointId = path[1]
        return nextViewpointId

    def make_candidate(self, feature, scanId, viewpointId, viewId):
        def _loc_distance(loc):
            return np.sqrt(loc.rel_heading ** 2 + loc.rel_elevation ** 2)
        base_heading = (viewId % 12) * math.radians(30)
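        # e.g., viewId 26 is in the third column of the panorama grid, so
        # base_heading = (26 % 12) * 30 deg = 60 degrees.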
        adj_dict = {}
        long_id = "%s_%s" % (scanId, viewpointId)
        if long_id not in self.buffered_state_dict:
            for ix in range(36):
                if ix == 0:
                    self.sim.newEpisode(scanId, viewpointId, 0, math.radians(-30))
                elif ix % 12 == 0:
                    self.sim.makeAction(0, 1.0, 1.0)
                else:
                    self.sim.makeAction(0, 1.0, 0)

                state = self.sim.getState()
                assert state.viewIndex == ix

                # Heading and elevation for the viewpoint center
                heading = state.heading - base_heading
                elevation = state.elevation

                visual_feat = feature[ix]

                # get adjacent locations
                for j, loc in enumerate(state.navigableLocations[1:]):
                    # if a loc is visible from multiple views, use the closest
                    # view (in angular distance) as its representation
                    distance = _loc_distance(loc)

                    # Heading and elevation for the loc
                    loc_heading = heading + loc.rel_heading
                    loc_elevation = elevation + loc.rel_elevation
                    angle_feat = utils.angle_feature(loc_heading, loc_elevation)
                    if (loc.viewpointId not in adj_dict or
                            distance < adj_dict[loc.viewpointId]['distance']):
                        adj_dict[loc.viewpointId] = {
                            'heading': loc_heading,
                            'elevation': loc_elevation,
                            "normalized_heading": state.heading + loc.rel_heading,
                            'scanId': scanId,
                            'viewpointId': loc.viewpointId,  # Next viewpoint id
                            'pointId': ix,
                            'distance': distance,
                            'idx': j + 1,
                            'feature': np.concatenate((visual_feat, angle_feat), -1)
                        }
            candidate = list(adj_dict.values())
            self.buffered_state_dict[long_id] = [
                {key: c[key]
                 for key in
                 ['normalized_heading', 'elevation', 'scanId', 'viewpointId',
                  'pointId', 'idx']}
                for c in candidate
            ]
            return candidate
        else:
            candidate = self.buffered_state_dict[long_id]
            candidate_new = []
            for c in candidate:
                c_new = c.copy()
                ix = c_new['pointId']
                normalized_heading = c_new['normalized_heading']
                visual_feat = feature[ix]
                loc_heading = normalized_heading - base_heading
                c_new['heading'] = loc_heading
                angle_feat = utils.angle_feature(c_new['heading'], c_new['elevation'])
                c_new['feature'] = np.concatenate((visual_feat, angle_feat), -1)
                c_new.pop('normalized_heading')
                candidate_new.append(c_new)
            return candidate_new

    def _get_obs(self):
        obs = []
        for i, (feature, state) in enumerate(self.env.getStates()):
            item = self.batch[i]
            base_view_id = state.viewIndex

            if feature is None:
                feature = np.zeros((36, 2048))

            # Full features
            candidate = self.make_candidate(feature, state.scanId, state.location.viewpointId, state.viewIndex)
            # [visual_feature, angle_feature] for views
            feature = np.concatenate((feature, self.angle_feature[base_view_id]), -1)

            obs.append({
                'instr_id': item['instr_id'],
                'scan': state.scanId,
                'viewpoint': state.location.viewpointId,
                'viewIndex': state.viewIndex,
                'heading': state.heading,
                'elevation': state.elevation,
                'feature': feature,
                'candidate': candidate,
                'navigableLocations': state.navigableLocations,
                'instructions': item['instructions'],
                'teacher': self._shortest_path_action(state, item['path'][-1]),
                'gt_path': item['path'],
                'path_id': item['path_id']
            })
            if 'instr_encoding' in item:
                obs[-1]['instr_encoding'] = item['instr_encoding']
            # A2C reward: the negative distance between the state and the final state
            obs[-1]['distance'] = self.distances[state.scanId][state.location.viewpointId][item['path'][-1]]
        return obs

    def reset(self, batch=None, inject=False, **kwargs):
        ''' Load a new minibatch / episodes. '''
        if batch is None:       # Allow the user to explicitly define the batch
            self._next_minibatch(**kwargs)
        else:
            if inject:          # Inject the batch into the next minibatch
                self._next_minibatch(**kwargs)
                self.batch[:len(batch)] = batch
            else:               # Else set the batch to the current batch
                self.batch = batch
        scanIds = [item['scan'] for item in self.batch]
        viewpointIds = [item['path'][0] for item in self.batch]
        headings = [item['heading'] for item in self.batch]
        self.env.newEpisodes(scanIds, viewpointIds, headings)
        return self._get_obs()

    def step(self, actions):
        ''' Take an action (same interface as makeActions) '''
        self.env.makeActions(actions)
        return self._get_obs()

    def get_statistics(self):
        stats = {}
        length = 0
        path = 0
        for datum in self.data:
            length += len(self.tok.split_sentence(datum['instructions']))
            path += self.distances[datum['scan']][datum['path'][0]][datum['path'][-1]]
        stats['length'] = length / len(self.data)
        stats['path'] = path / len(self.data)
        return stats
r2r_src/eval.py (new file, 112 lines)
@@ -0,0 +1,112 @@
''' Evaluation of agent trajectories '''

import json
import os
import sys
from collections import defaultdict
import networkx as nx
import numpy as np
import pprint
pp = pprint.PrettyPrinter(indent=4)

from env import R2RBatch
from utils import load_datasets, load_nav_graphs, ndtw_graphload, DTW
from agent import BaseAgent


class Evaluation(object):
    ''' Results submission format: [{'instr_id': string, 'trajectory': [(viewpoint_id, heading_rads, elevation_rads),]}] '''

    def __init__(self, splits, scans, tok):
        self.error_margin = 3.0
        self.splits = splits
        self.tok = tok
        self.gt = {}
        self.instr_ids = []
        self.scans = []
        for split in splits:
            for item in load_datasets([split]):
                if scans is not None and item['scan'] not in scans:
                    continue
                self.gt[str(item['path_id'])] = item
                self.scans.append(item['scan'])
                self.instr_ids += ['%s_%d' % (item['path_id'], i) for i in range(len(item['instructions']))]
        self.scans = set(self.scans)
        self.instr_ids = set(self.instr_ids)
        self.graphs = load_nav_graphs(self.scans)
        self.distances = {}
        for scan, G in self.graphs.items():     # compute all shortest paths
            self.distances[scan] = dict(nx.all_pairs_dijkstra_path_length(G))

    def _get_nearest(self, scan, goal_id, path):
        near_id = path[0][0]
        near_d = self.distances[scan][near_id][goal_id]
        for item in path:
            d = self.distances[scan][item[0]][goal_id]
            if d < near_d:
                near_id = item[0]
                near_d = d
        return near_id

    def _score_item(self, instr_id, path):
        ''' Calculate the error based on the final position in the trajectory, and also
            the closest position (oracle stopping rule).
            The path contains [view_id, angle, vofv] '''
        gt = self.gt[instr_id.split('_')[-2]]
        start = gt['path'][0]
        assert start == path[0][0], 'Result trajectories should include the start position'
        goal = gt['path'][-1]
        final_position = path[-1][0]    # the first of [view_id, angle, vofv]
        nearest_position = self._get_nearest(gt['scan'], goal, path)
        self.scores['nav_errors'].append(self.distances[gt['scan']][final_position][goal])
        self.scores['oracle_errors'].append(self.distances[gt['scan']][nearest_position][goal])
        self.scores['trajectory_steps'].append(len(path)-1)
        distance = 0    # length of the path in meters
        prev = path[0]
        for curr in path[1:]:
            distance += self.distances[gt['scan']][prev[0]][curr[0]]
            prev = curr
        self.scores['trajectory_lengths'].append(distance)
        self.scores['shortest_lengths'].append(
            self.distances[gt['scan']][start][goal]
        )

    def score(self, output_file):
        ''' Evaluate each agent trajectory based on how close it got to the goal location '''
        self.scores = defaultdict(list)
        instr_ids = set(self.instr_ids)
        if type(output_file) is str:
            with open(output_file) as f:
                results = json.load(f)
        else:
            results = output_file

        print('result length', len(results))
        for item in results:
            # Check against the expected ids
            if item['instr_id'] in instr_ids:
                instr_ids.remove(item['instr_id'])
                self._score_item(item['instr_id'], item['trajectory'])

        if 'train' not in self.splits:  # Exclude training from this check (training eval may be partial)
            assert len(instr_ids) == 0, 'Missing %d of %d instruction ids from %s - not in %s'\
                % (len(instr_ids), len(self.instr_ids), ",".join(self.splits), output_file)
            assert len(self.scores['nav_errors']) == len(self.instr_ids)
        score_summary = {
            'nav_error': np.average(self.scores['nav_errors']),
            'oracle_error': np.average(self.scores['oracle_errors']),
            'steps': np.average(self.scores['trajectory_steps']),
            'lengths': np.average(self.scores['trajectory_lengths'])
        }
        num_successes = len([i for i in self.scores['nav_errors'] if i < self.error_margin])
        score_summary['success_rate'] = float(num_successes)/float(len(self.scores['nav_errors']))
        oracle_successes = len([i for i in self.scores['oracle_errors'] if i < self.error_margin])
        score_summary['oracle_rate'] = float(oracle_successes)/float(len(self.scores['oracle_errors']))

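        # SPL (success weighted by path length, Anderson et al. 2018):
        # SPL = (1/N) * sum_i S_i * l_i / max(l_i, p_i), where S_i marks success,
        # l_i is the shortest-path length and p_i the predicted trajectory
        # length; the 0.01 below guards against division by zero.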
        spl = [float(error < self.error_margin) * l / max(l, p, 0.01)
               for error, p, l in
               zip(self.scores['nav_errors'], self.scores['trajectory_lengths'], self.scores['shortest_lengths'])
               ]
        score_summary['spl'] = np.average(spl)

        return score_summary, self.scores
r2r_src/model_OSCAR.py (new file, 87 lines)
@@ -0,0 +1,87 @@
# Recurrent VLN-BERT, 2020, by Yicong.Hong@anu.edu.au

import torch
import torch.nn as nn
from param import args

from vlnbert.vlnbert_init import get_vlnbert_models


class VLNBERT(nn.Module):
    def __init__(self, feature_size=2048+128):
        super(VLNBERT, self).__init__()
        print('\nInitializing the VLN-BERT model ...')
        self.vln_bert = get_vlnbert_models(args, config=None)   # initialize the VLN-BERT
        self.vln_bert.config.directions = 4                     # a preset placeholder; overwritten per batch

        hidden_size = self.vln_bert.config.hidden_size
        layer_norm_eps = self.vln_bert.config.layer_norm_eps

        self.action_state_project = nn.Sequential(
            nn.Linear(hidden_size+args.angle_feat_size, hidden_size), nn.Tanh())
        self.action_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps)

        self.drop_env = nn.Dropout(p=args.featdropout)
        self.img_projection = nn.Linear(feature_size, hidden_size, bias=True)
        self.cand_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, mode, sentence, token_type_ids=None,
                attention_mask=None, lang_mask=None, vis_mask=None,
                position_ids=None, action_feats=None, pano_feats=None, cand_feats=None):

        if mode == 'language':
            encoded_sentence = self.vln_bert(mode, sentence, position_ids=position_ids,
                                             token_type_ids=token_type_ids, attention_mask=attention_mask)

            return encoded_sentence

        elif mode == 'visual':

            state_action_embed = torch.cat((sentence[:, 0, :], action_feats), 1)
            state_with_action = self.action_state_project(state_action_embed)
            state_with_action = self.action_LayerNorm(state_with_action)
            state_feats = torch.cat((state_with_action.unsqueeze(1), sentence[:, 1:, :]), dim=1)

            cand_feats[..., :-args.angle_feat_size] = self.drop_env(cand_feats[..., :-args.angle_feat_size])

            cand_feats_embed = self.img_projection(cand_feats)  # [2176 -> 768] projection
            cand_feats_embed = self.cand_LayerNorm(cand_feats_embed)

            # logit is the attention scores over the candidate features
            h_t, logit = self.vln_bert(mode, state_feats,
                                       attention_mask=attention_mask, img_feats=cand_feats_embed)

            return h_t, logit

        else:
            raise ModuleNotFoundError("Mode must be 'language' or 'visual'.")


class BertLayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root)."""
        super(BertLayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
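        # Computes weight * (x - mean) / sqrt(var + eps) + bias over the last dimension.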
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias


class Critic(nn.Module):
    def __init__(self):
        super(Critic, self).__init__()
        self.state2value = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(args.dropout),
            nn.Linear(512, 1),
        )

    def forward(self, state):
        return self.state2value(state).squeeze()
r2r_src/model_PREVALENT.py (new file, 96 lines)
@@ -0,0 +1,96 @@
# Recurrent VLN-BERT, 2020, by Yicong.Hong@anu.edu.au

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from param import args

from vlnbert.vlnbert_init import get_vlnbert_models

class VLNBERT(nn.Module):
    def __init__(self, feature_size=2048+128):
        super(VLNBERT, self).__init__()
        print('\nInitializing the VLN-BERT model ...')

        self.vln_bert = get_vlnbert_models(args, config=None)   # initialize the VLN-BERT
        self.vln_bert.config.directions = 4                     # a preset placeholder; overwritten per batch

        hidden_size = self.vln_bert.config.hidden_size
        layer_norm_eps = self.vln_bert.config.layer_norm_eps

        self.action_state_project = nn.Sequential(
            nn.Linear(hidden_size+args.angle_feat_size, hidden_size), nn.Tanh())
        self.action_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps)

        self.drop_env = nn.Dropout(p=args.featdropout)
        self.img_projection = nn.Linear(feature_size, hidden_size, bias=True)
        self.cand_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps)

        self.vis_lang_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps)
        self.state_proj = nn.Linear(hidden_size*2, hidden_size, bias=True)
        self.state_LayerNorm = BertLayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, mode, sentence, token_type_ids=None,
                attention_mask=None, lang_mask=None, vis_mask=None,
                position_ids=None, action_feats=None, pano_feats=None, cand_feats=None):

        if mode == 'language':
            init_state, encoded_sentence = self.vln_bert(mode, sentence, attention_mask=attention_mask, lang_mask=lang_mask)

            return init_state, encoded_sentence

        elif mode == 'visual':

            state_action_embed = torch.cat((sentence[:, 0, :], action_feats), 1)
            state_with_action = self.action_state_project(state_action_embed)
            state_with_action = self.action_LayerNorm(state_with_action)
            state_feats = torch.cat((state_with_action.unsqueeze(1), sentence[:, 1:, :]), dim=1)

            cand_feats[..., :-args.angle_feat_size] = self.drop_env(cand_feats[..., :-args.angle_feat_size])

            # logit is the attention scores over the candidate features;
            # the backbone also returns the language and visual contexts
            # attended by the state, which feed the state update below
            h_t, logit, attended_language, attended_visual = self.vln_bert(mode, state_feats,
                attention_mask=attention_mask, lang_mask=lang_mask, vis_mask=vis_mask, img_feats=cand_feats)

            # update the agent's state: unify history, language and vision by an elementwise product
            vis_lang_feat = self.vis_lang_LayerNorm(attended_language * attended_visual)
            state_output = torch.cat((h_t, vis_lang_feat), dim=-1)
            state_proj = self.state_proj(state_output)
            state_proj = self.state_LayerNorm(state_proj)

            return state_proj, logit

        else:
            raise ModuleNotFoundError("Mode must be 'language' or 'visual'.")


class BertLayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root)."""
        super(BertLayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias


class Critic(nn.Module):
    def __init__(self):
        super(Critic, self).__init__()
        self.state2value = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(args.dropout),
            nn.Linear(512, 1),
        )

    def forward(self, state):
        return self.state2value(state).squeeze()
r2r_src/param.py (new file, 88 lines)
@@ -0,0 +1,88 @@
import argparse
import os
import torch


class Param:
    def __init__(self):
        self.parser = argparse.ArgumentParser(description="")

        # General
        self.parser.add_argument('--test_only', type=int, default=0, help='fast mode for testing')

        self.parser.add_argument('--iters', type=int, default=300000, help='training iterations')
        self.parser.add_argument('--name', type=str, default='default', help='experiment id')
        self.parser.add_argument('--vlnbert', type=str, default='oscar', help='oscar or prevalent')
        self.parser.add_argument('--train', type=str, default='listener')
        self.parser.add_argument('--description', type=str, default='no description\n')

        # Data preparation
        self.parser.add_argument('--maxInput', type=int, default=80, help="max input instruction length")
        self.parser.add_argument('--maxAction', type=int, default=15, help='max action sequence length')
        self.parser.add_argument('--batchSize', type=int, default=8)
        self.parser.add_argument('--ignoreid', type=int, default=-100)
        self.parser.add_argument('--feature_size', type=int, default=2048)
        self.parser.add_argument("--loadOptim", action="store_const", default=False, const=True)

        # Load the model from
        self.parser.add_argument("--load", default=None, help='path of the trained model')

        # Augmented Paths from
        self.parser.add_argument("--aug", default=None)

        # Listener Model Config
        self.parser.add_argument("--zeroInit", dest='zero_init', action='store_const', default=False, const=True)
        self.parser.add_argument("--mlWeight", dest='ml_weight', type=float, default=0.20)
        self.parser.add_argument("--teacherWeight", dest='teacher_weight', type=float, default=1.)
        self.parser.add_argument("--features", type=str, default='places365')

        # Dropout Param
        self.parser.add_argument('--dropout', type=float, default=0.5)
        self.parser.add_argument('--featdropout', type=float, default=0.3)

        # Submission configuration
        self.parser.add_argument("--submit", type=int, default=0)

        # Training Configurations
        self.parser.add_argument('--optim', type=str, default='rms')    # rms, adam
        self.parser.add_argument('--lr', type=float, default=0.00001, help="the learning rate")
        self.parser.add_argument('--decay', dest='weight_decay', type=float, default=0.)
        self.parser.add_argument('--feedback', type=str, default='sample',
                                 help='How to choose next position, one of ``teacher``, ``sample`` and ``argmax``')
        self.parser.add_argument('--teacher', type=str, default='final',
                                 help="How to get supervision. one of ``next`` and ``final`` ")
        self.parser.add_argument('--epsilon', type=float, default=0.1)

        # Model hyper params:
        self.parser.add_argument("--angleFeatSize", dest="angle_feat_size", type=int, default=4)

        # A2C
        self.parser.add_argument("--gamma", default=0.9, type=float)
        self.parser.add_argument("--normalize", dest="normalize_loss", default="total", type=str, help='batch or total')

        self.args = self.parser.parse_args()

        if self.args.optim == 'rms':
            print("Optimizer: Using RMSProp")
            self.args.optimizer = torch.optim.RMSprop
        elif self.args.optim == 'adam':
            print("Optimizer: Using Adam")
            self.args.optimizer = torch.optim.Adam
        elif self.args.optim == 'adamW':
            print("Optimizer: Using AdamW")
            self.args.optimizer = torch.optim.AdamW
        elif self.args.optim == 'sgd':
            print("Optimizer: Using SGD")
            self.args.optimizer = torch.optim.SGD
        else:
            assert False, 'unknown optimizer: %s' % self.args.optim


param = Param()
args = param.args

args.description = args.name
args.IMAGENET_FEATURES = 'img_features/ResNet-152-imagenet.tsv'
args.log_dir = 'snap/%s' % args.name

if not os.path.exists(args.log_dir):
    os.makedirs(args.log_dir)
DEBUG_FILE = open(os.path.join('snap', args.name, "debug.log"), 'w')
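
# Illustrative invocations (flag values are examples, not prescriptions):
#   python r2r_src/train.py --train listener --name my_run --vlnbert prevalent
#   python r2r_src/train.py --train validlistener --name eval_run --load snap/my_run/state_dict/best_val_unseen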
262
r2r_src/train.py
Normal file
@ -0,0 +1,262 @@
import torch

import os
import time
import json
import random
import numpy as np
from collections import defaultdict

from utils import read_vocab, write_vocab, build_vocab, padding_idx, timeSince, read_img_features, print_progress
import utils
from env import R2RBatch
from agent import Seq2SeqAgent
from eval import Evaluation
from param import args

import warnings
warnings.filterwarnings("ignore")
from tensorboardX import SummaryWriter

from vlnbert.vlnbert_init import get_tokenizer


log_dir = 'snap/%s' % args.name
if not os.path.exists(log_dir):
    os.makedirs(log_dir)

IMAGENET_FEATURES = 'img_features/ResNet-152-imagenet.tsv'
PLACE365_FEATURES = 'img_features/ResNet-152-places365.tsv'

if args.features == 'imagenet':
    features = IMAGENET_FEATURES
elif args.features == 'places365':
    features = PLACE365_FEATURES

feedback_method = args.feedback  # teacher or sample

print(args); print('')


''' train the listener '''
def train(train_env, tok, n_iters, log_every=2000, val_envs={}, aug_env=None):
    writer = SummaryWriter(log_dir=log_dir)
    listner = Seq2SeqAgent(train_env, "", tok, args.maxAction)

    record_file = open('./logs/' + args.name + '.txt', 'a')
    record_file.write(str(args) + '\n\n')
    record_file.close()

    start_iter = 0
    if args.load is not None:
        if args.aug is None:
            start_iter = listner.load(os.path.join(args.load))
            print("\nLOAD the model from {}, iteration {}".format(args.load, start_iter))
        else:
            load_iter = listner.load(os.path.join(args.load))
            print("\nLOAD the model from {}, iteration {}".format(args.load, load_iter))

    start = time.time()
    print('\nListener training starts, start iteration: %s' % str(start_iter))

    best_val = {'val_unseen': {"spl": 0., "sr": 0., "state": "", 'update': False}}

    for idx in range(start_iter, start_iter+n_iters, log_every):
        listner.logs = defaultdict(list)
        interval = min(log_every, n_iters-idx)
        iter = idx + interval

        # Train for log_every interval
        if aug_env is None:
            listner.env = train_env
            listner.train(interval, feedback=feedback_method)  # Train interval iters
        else:
            jdx_length = len(range(interval // 2))
            for jdx in range(interval // 2):
                # Train with GT data
                listner.env = train_env
                args.ml_weight = 0.2
                listner.train(1, feedback=feedback_method)

                # Train with Augmented data
                listner.env = aug_env
                args.ml_weight = 0.2
                listner.train(1, feedback=feedback_method)

                print_progress(jdx, jdx_length, prefix='Progress:', suffix='Complete', bar_length=50)

        # Log the training stats to tensorboard
        total = max(sum(listner.logs['total']), 1)
        length = max(len(listner.logs['critic_loss']), 1)
        critic_loss = sum(listner.logs['critic_loss']) / total
        RL_loss = sum(listner.logs['RL_loss']) / max(len(listner.logs['RL_loss']), 1)
        IL_loss = sum(listner.logs['IL_loss']) / max(len(listner.logs['IL_loss']), 1)
        entropy = sum(listner.logs['entropy']) / total
        writer.add_scalar("loss/critic", critic_loss, idx)
        writer.add_scalar("policy_entropy", entropy, idx)
        writer.add_scalar("loss/RL_loss", RL_loss, idx)
        writer.add_scalar("loss/IL_loss", IL_loss, idx)
        writer.add_scalar("total_actions", total, idx)
        writer.add_scalar("max_length", length, idx)
        # print("total_actions", total, ", max_length", length)

        # Run validation
        loss_str = "iter {}".format(iter)
        for env_name, (env, evaluator) in val_envs.items():
            listner.env = env

            # Get validation distance from goal under test evaluation conditions
            listner.test(use_dropout=False, feedback='argmax', iters=None)
            result = listner.get_results()
            score_summary, _ = evaluator.score(result)
            loss_str += ", %s " % env_name
            for metric, val in score_summary.items():
                if metric in ['spl']:
                    writer.add_scalar("spl/%s" % env_name, val, idx)
                    if env_name in best_val:
                        if val > best_val[env_name]['spl']:
                            best_val[env_name]['spl'] = val
                            best_val[env_name]['update'] = True
                        elif (val == best_val[env_name]['spl']) and (score_summary['success_rate'] > best_val[env_name]['sr']):
                            best_val[env_name]['spl'] = val
                            best_val[env_name]['update'] = True
                loss_str += ', %s: %.4f' % (metric, val)

        record_file = open('./logs/' + args.name + '.txt', 'a')
        record_file.write(loss_str + '\n')
        record_file.close()

        for env_name in best_val:
            if best_val[env_name]['update']:
                best_val[env_name]['state'] = 'Iter %d %s' % (iter, loss_str)
                best_val[env_name]['update'] = False
                listner.save(idx, os.path.join("snap", args.name, "state_dict", "best_%s" % (env_name)))
            else:
                listner.save(idx, os.path.join("snap", args.name, "state_dict", "latest_dict"))

        print(('%s (%d %d%%) %s' % (timeSince(start, float(iter)/n_iters),
                                    iter, float(iter)/n_iters*100, loss_str)))

        if iter % 1000 == 0:
            print("BEST RESULT TILL NOW")
            for env_name in best_val:
                print(env_name, best_val[env_name]['state'])

                record_file = open('./logs/' + args.name + '.txt', 'a')
                record_file.write('BEST RESULT TILL NOW: ' + env_name + ' | ' + best_val[env_name]['state'] + '\n')
                record_file.close()

    listner.save(idx, os.path.join("snap", args.name, "state_dict", "LAST_iter%d" % (idx)))
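
# Checkpoints written by train(), following the save() calls above:
#   snap/<name>/state_dict/best_val_unseen   (best SPL on val_unseen, ties broken by SR)
#   snap/<name>/state_dict/latest_dict       (rolling latest weights)
#   snap/<name>/state_dict/LAST_iter<idx>    (final save when training ends)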


def valid(train_env, tok, val_envs={}):
    agent = Seq2SeqAgent(train_env, "", tok, args.maxAction)

    print("Loaded the listener model at iter %d from %s" % (agent.load(args.load), args.load))

    for env_name, (env, evaluator) in val_envs.items():
        agent.logs = defaultdict(list)
        agent.env = env

        iters = None
        agent.test(use_dropout=False, feedback='argmax', iters=iters)
        result = agent.get_results()

        if env_name != '':
            score_summary, _ = evaluator.score(result)
            loss_str = "Env name: %s" % env_name
            for metric, val in score_summary.items():
                loss_str += ', %s: %.4f' % (metric, val)
            print(loss_str)

        if args.submit:
            json.dump(
                result,
                open(os.path.join(log_dir, "submit_%s.json" % env_name), 'w'),
                sort_keys=True, indent=4, separators=(',', ': ')
            )


def setup():
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)
    random.seed(0)
    np.random.seed(0)


def train_val(test_only=False):
    ''' Train on the training set, and validate on seen and unseen splits. '''
    setup()
    tok = get_tokenizer(args)

    feat_dict = read_img_features(features, test_only=test_only)

    if test_only:
        featurized_scans = None
        val_env_names = ['val_train_seen']
    else:
        featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])
        val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']

    train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok)
    from collections import OrderedDict

    if args.submit:
        val_env_names.append('test')

    val_envs = OrderedDict(
        ((split,
          (R2RBatch(feat_dict, batch_size=args.batchSize, splits=[split], tokenizer=tok),
           Evaluation([split], featurized_scans, tok))
          )
         for split in val_env_names
         )
    )

    if args.train == 'listener':
        train(train_env, tok, args.iters, val_envs=val_envs)
    elif args.train == 'validlistener':
        valid(train_env, tok, val_envs=val_envs)
    else:
        assert False


def train_val_augment(test_only=False):
    """
    Train the listener with the augmented data
    """
    setup()

    # Create a batch training environment that will also preprocess text
    tok_bert = get_tokenizer(args)

    # Load the env img features
    feat_dict = read_img_features(features, test_only=test_only)

    if test_only:
        featurized_scans = None
        val_env_names = ['val_train_seen']
    else:
        featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])
        val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']

    # Load the augmentation data
    aug_path = args.aug
    # Create the training environment
    train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok_bert)
    aug_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=[aug_path], tokenizer=tok_bert, name='aug')

    # Setup the validation data
    val_envs = {split: (R2RBatch(feat_dict, batch_size=args.batchSize, splits=[split], tokenizer=tok_bert),
                        Evaluation([split], featurized_scans, tok_bert))
                for split in val_env_names}

    # Start training
    train(train_env, tok_bert, args.iters, val_envs=val_envs, aug_env=aug_env)


if __name__ == "__main__":
    if args.train in ['listener', 'validlistener']:
        train_val(test_only=args.test_only)
    elif args.train == 'auglistener':
        train_val_augment(test_only=args.test_only)
    else:
        assert False
674
r2r_src/utils.py
Normal file
@ -0,0 +1,674 @@
''' Utils for io, language, connectivity graphs etc '''

import os
import sys
import re
sys.path.append('Matterport_Simulator/build/')
import MatterSim
import string
import json
import time
import math
from collections import Counter, defaultdict
import numpy as np
import networkx as nx
from param import args
from numpy.linalg import norm


# padding, unknown word, end of sentence
base_vocab = ['<PAD>', '<UNK>', '<EOS>']
padding_idx = base_vocab.index('<PAD>')


def load_nav_graphs(scans):
    ''' Load the connectivity graph for each scan '''

    def distance(pose1, pose2):
        ''' Euclidean distance between two graph poses '''
        return ((pose1['pose'][3]-pose2['pose'][3])**2
                + (pose1['pose'][7]-pose2['pose'][7])**2
                + (pose1['pose'][11]-pose2['pose'][11])**2)**0.5

    graphs = {}
    for scan in scans:
        with open('connectivity/%s_connectivity.json' % scan) as f:
            G = nx.Graph()
            positions = {}
            data = json.load(f)
            for i, item in enumerate(data):
                if item['included']:
                    for j, conn in enumerate(item['unobstructed']):
                        if conn and data[j]['included']:
                            positions[item['image_id']] = np.array([item['pose'][3],
                                    item['pose'][7], item['pose'][11]])
                            assert data[j]['unobstructed'][i], 'Graph should be undirected'
                            G.add_edge(item['image_id'], data[j]['image_id'], weight=distance(item, data[j]))
            nx.set_node_attributes(G, values=positions, name='position')
            graphs[scan] = G
    return graphs


def load_datasets(splits):
    """
    :param splits: A list of splits.
        If a split is "something@5000", a random subset of 5000 items is drawn from that split.
    :return: the loaded dataset as a list of items
    """
    import random
    data = []
    old_state = random.getstate()
    for split in splits:
        # It only needs some part of the dataset?
        components = split.split("@")
        number = -1
        if len(components) > 1:
            split, number = components[0], int(components[1])

        # Load Json
        # if split in ['train', 'val_seen', 'val_unseen', 'test',
        #              'val_unseen_half1', 'val_unseen_half2', 'val_seen_half1', 'val_seen_half2']:  # Add two halves for sanity check
        if "/" not in split:
            with open('data/R2R_%s.json' % split) as f:
                new_data = json.load(f)
        else:
            print('\nLoading prevalent data for pretraining...')
            with open(split) as f:
                new_data = json.load(f)

        # Partition
        if number > 0:
            random.seed(0)              # Make the data deterministic, additive
            random.shuffle(new_data)
            new_data = new_data[:number]

        # Join
        data += new_data
    random.setstate(old_state)          # Recover the state of the random generator
    return data
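
# Illustrative usage of the "@" syntax (split names are examples):
#   load_datasets(['train'])        # full training split from data/R2R_train.json
#   load_datasets(['train@5000'])   # a deterministic random subset of 5000 items
#   any split containing '/' is treated as a path and loaded as-is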


def pad_instr_tokens(instr_tokens, maxlength=20):

    if len(instr_tokens) <= 2:  # assert len(raw_instr_tokens) > 2
        return None

    if len(instr_tokens) > maxlength - 2:  # -2 for [CLS] and [SEP]
        instr_tokens = instr_tokens[:(maxlength-2)]

    instr_tokens = ['[CLS]'] + instr_tokens + ['[SEP]']
    num_words = len(instr_tokens)  # include [SEP]
    instr_tokens += ['[PAD]'] * (maxlength-len(instr_tokens))

    assert len(instr_tokens) == maxlength

    return instr_tokens, num_words
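
# Example (illustrative): pad_instr_tokens(['walk', 'to', 'door'], maxlength=6)
#   -> (['[CLS]', 'walk', 'to', 'door', '[SEP]', '[PAD]'], 5)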


class Tokenizer(object):
    ''' Class to tokenize and encode a sentence. '''
    SENTENCE_SPLIT_REGEX = re.compile(r'(\W+)')  # Split on any non-alphanumeric character

    def __init__(self, vocab=None, encoding_length=20):
        self.encoding_length = encoding_length
        self.vocab = vocab
        self.word_to_index = {}
        self.index_to_word = {}
        if vocab:
            for i, word in enumerate(vocab):
                self.word_to_index[word] = i
            new_w2i = defaultdict(lambda: self.word_to_index['<UNK>'])
            new_w2i.update(self.word_to_index)
            self.word_to_index = new_w2i
            for key, value in self.word_to_index.items():
                self.index_to_word[value] = key
        old = self.vocab_size()
        self.add_word('<BOS>')
        assert self.vocab_size() == old+1
        print("OLD_VOCAB_SIZE", old)
        print("VOCAB_SIZE", self.vocab_size())
        if vocab:
            print("VOCAB", len(vocab))

    def finalize(self):
        """
        This is used for debug
        """
        self.word_to_index = dict(self.word_to_index)  # To avoid using mis-typed tokens

    def add_word(self, word):
        assert word not in self.word_to_index
        self.word_to_index[word] = self.vocab_size()    # vocab_size() is the index of the next new word
        self.index_to_word[self.vocab_size()] = word

    @staticmethod
    def split_sentence(sentence):
        ''' Break sentence into a list of words and punctuation '''
        toks = []
        for word in [s.strip().lower() for s in Tokenizer.SENTENCE_SPLIT_REGEX.split(sentence.strip()) if len(s.strip()) > 0]:
            # Break up any words containing punctuation only, e.g. '!?', unless it is multiple full stops e.g. '..'
            if all(c in string.punctuation for c in word) and not all(c in '.' for c in word):
                toks += list(word)
            else:
                toks.append(word)
        return toks

    def vocab_size(self):
        return len(self.index_to_word)

    def encode_sentence(self, sentence, max_length=None):
        if max_length is None:
            max_length = self.encoding_length
        if len(self.word_to_index) == 0:
            sys.exit('Tokenizer has no vocab')

        encoding = [self.word_to_index['<BOS>']]
        for word in self.split_sentence(sentence):
            encoding.append(self.word_to_index[word])   # Default Dict
        encoding.append(self.word_to_index['<EOS>'])

        if len(encoding) <= 2:
            return None
        # assert len(encoding) > 2

        if len(encoding) < max_length:
            encoding += [self.word_to_index['<PAD>']] * (max_length-len(encoding))  # Padding
        elif len(encoding) > max_length:
            encoding[max_length - 1] = self.word_to_index['<EOS>']                  # Cut the length with EOS

        return np.array(encoding[:max_length])

    def decode_sentence(self, encoding, length=None):
        sentence = []
        if length is not None:
            encoding = encoding[:length]
        for ix in encoding:
            if ix == self.word_to_index['<PAD>']:
                break
            else:
                sentence.append(self.index_to_word[ix])
        return " ".join(sentence)

    def shrink(self, inst):
        """
        :param inst: The id inst
        :return: the inst with the potential <BOS> and <EOS> removed.
            If there is no <EOS>, an empty list is returned.
        """
        if len(inst) == 0:
            return inst
        end = np.argmax(np.array(inst) == self.word_to_index['<EOS>'])  # If no <EOS>, end is 0 and the slice is empty
        if len(inst) > 1 and inst[0] == self.word_to_index['<BOS>']:
            start = 1
        else:
            start = 0
        # print(inst, start, end)
        return inst[start: end]


def build_vocab(splits=['train'], min_count=5, start_vocab=base_vocab):
    ''' Build a vocab, starting with base vocab containing a few useful tokens. '''
    count = Counter()
    t = Tokenizer()
    data = load_datasets(splits)
    for item in data:
        for instr in item['instructions']:
            count.update(t.split_sentence(instr))
    vocab = list(start_vocab)
    for word, num in count.most_common():
        if num >= min_count:
            vocab.append(word)
        else:
            break
    return vocab


def write_vocab(vocab, path):
    print('Writing vocab of size %d to %s' % (len(vocab), path))
    with open(path, 'w') as f:
        for word in vocab:
            f.write("%s\n" % word)


def read_vocab(path):
    with open(path) as f:
        vocab = [word.strip() for word in f.readlines()]
    return vocab


def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))


def read_img_features(feature_store, test_only=False):
    import csv
    import base64
    from tqdm import tqdm

    print("Start loading the image feature ... (~50 seconds)")
    start = time.time()

    if "detectfeat" in args.features:
        views = int(args.features[10:])
    else:
        views = 36

    args.views = views

    tsv_fieldnames = ['scanId', 'viewpointId', 'image_w', 'image_h', 'vfov', 'features']

    if not test_only:
        features = {}
        with open(feature_store, "r") as tsv_in_file:   # Open the tsv file.
            reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames=tsv_fieldnames)
            for item in reader:
                long_id = item['scanId'] + "_" + item['viewpointId']
                features[long_id] = np.frombuffer(base64.decodestring(item['features'].encode('ascii')),
                                                  dtype=np.float32).reshape((views, -1))  # Feature of long_id is (36, 2048)
    else:
        features = None

    print("Finish Loading the image feature from %s in %0.4f seconds" % (feature_store, time.time() - start))
    return features
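
# Shape sketch (matches the comment above; the key form is an assumption drawn from the code):
#   feat_dict = read_img_features('img_features/ResNet-152-places365.tsv')
#   feat_dict['<scanId>_<viewpointId>'].shape  # -> (36, 2048)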


def read_candidates(candidates_store):
    import csv
    import base64
    from collections import defaultdict
    print("Start loading the candidate feature")

    start = time.time()

    TSV_FIELDNAMES = ['scanId', 'viewpointId', 'heading', 'elevation', 'next', 'pointId', 'idx', 'feature']
    candidates = defaultdict(lambda: list())
    items = 0
    with open(candidates_store, "r") as tsv_in_file:    # Open the tsv file.
        reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames=TSV_FIELDNAMES)
        for item in reader:
            long_id = item['scanId'] + "_" + item['viewpointId']
            candidates[long_id].append(
                {'heading': float(item['heading']),
                 'elevation': float(item['elevation']),
                 'scanId': item['scanId'],
                 'viewpointId': item['next'],
                 'pointId': int(item['pointId']),
                 'idx': int(item['idx']) + 1,   # Because of a bug in the precompute code, the +1 here is important
                 'feature': np.frombuffer(
                     base64.decodestring(item['feature'].encode('ascii')),
                     dtype=np.float32)
                 }
            )
            items += 1

    for long_id in candidates:
        assert (len(candidates[long_id])) != 0

    assert sum(len(candidate) for candidate in candidates.values()) == items

    # candidate = candidates[long_id]
    # print(candidate)
    print("Finish Loading the candidates from %s in %0.4f seconds" % (candidates_store, time.time() - start))
    candidates = dict(candidates)
    return candidates


def add_exploration(paths):
    explore = json.load(open("data/exploration.json", 'r'))
    inst2explore = {path['instr_id']: path['trajectory'] for path in explore}
    for path in paths:
        path['trajectory'] = inst2explore[path['instr_id']] + path['trajectory']
    return paths


def angle_feature(heading, elevation):
    import math
    # twopi = math.pi * 2
    # heading = (heading + twopi) % twopi   # From 0 ~ 2pi
    # It will be the same
    return np.array([math.sin(heading), math.cos(heading),
                     math.sin(elevation), math.cos(elevation)] * (args.angle_feat_size // 4),
                    dtype=np.float32)
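
# Layout sketch: with the default angle_feat_size of 4, the feature is the single block
# [sin(h), cos(h), sin(e), cos(e)]; larger multiples of 4 simply tile this block.
#   angle_feature(0.0, 0.0)  # -> array([0., 1., 0., 1.], dtype=float32)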


def new_simulator():
    import MatterSim
    # Simulator image parameters
    WIDTH = 640
    HEIGHT = 480
    VFOV = 60

    sim = MatterSim.Simulator()
    sim.setRenderingEnabled(False)
    sim.setCameraResolution(WIDTH, HEIGHT)
    sim.setCameraVFOV(math.radians(VFOV))
    sim.setDiscretizedViewingAngles(True)
    sim.init()

    return sim


def get_point_angle_feature(baseViewId=0):
    sim = new_simulator()

    feature = np.empty((36, args.angle_feat_size), np.float32)
    base_heading = (baseViewId % 12) * math.radians(30)
    for ix in range(36):
        if ix == 0:
            sim.newEpisode('ZMojNkEp431', '2f4d90acd4024c269fb0efe49a8ac540', 0, math.radians(-30))
        elif ix % 12 == 0:
            sim.makeAction(0, 1.0, 1.0)
        else:
            sim.makeAction(0, 1.0, 0)

        state = sim.getState()
        assert state.viewIndex == ix

        heading = state.heading - base_heading

        feature[ix, :] = angle_feature(heading, state.elevation)
    return feature


def get_all_point_angle_feature():
    return [get_point_angle_feature(baseViewId) for baseViewId in range(36)]


def add_idx(inst):
    toks = Tokenizer.split_sentence(inst)
    return " ".join([str(idx)+tok for idx, tok in enumerate(toks)])


import signal
class GracefulKiller:
    kill_now = False

    def __init__(self):
        signal.signal(signal.SIGINT, self.exit_gracefully)
        signal.signal(signal.SIGTERM, self.exit_gracefully)

    def exit_gracefully(self, signum, frame):
        self.kill_now = True


from collections import OrderedDict

class Timer:
    def __init__(self):
        self.cul = OrderedDict()
        self.start = {}
        self.iter = 0

    def reset(self):
        self.cul = OrderedDict()
        self.start = {}
        self.iter = 0

    def tic(self, key):
        self.start[key] = time.time()

    def toc(self, key):
        delta = time.time() - self.start[key]
        if key not in self.cul:
            self.cul[key] = delta
        else:
            self.cul[key] += delta

    def step(self):
        self.iter += 1

    def show(self):
        total = sum(self.cul.values())
        for key in self.cul:
            print("%s, total time %0.2f, avg time %0.2f, part of %0.2f" %
                  (key, self.cul[key], self.cul[key]*1./self.iter, self.cul[key]*1./total))
        print(total / self.iter)


stop_word_list = [
    ",", ".", "and", "?", "!"
]


def stop_words_location(inst, mask=False):
    toks = Tokenizer.split_sentence(inst)
    sws = [i for i, tok in enumerate(toks) if tok in stop_word_list]        # The index of the stop words
    if len(sws) == 0 or sws[-1] != (len(toks)-1):                           # Add the index of the last token
        sws.append(len(toks)-1)
    sws = [x for x, y in zip(sws[:-1], sws[1:]) if x+1 != y] + [sws[-1]]    # Filter out adjacent stop words
    sws_mask = np.ones(len(toks), np.int32)                                 # Create the mask
    sws_mask[sws] = 0
    return sws_mask if mask else sws


def get_segments(inst, mask=False):
    toks = Tokenizer.split_sentence(inst)
    sws = [i for i, tok in enumerate(toks) if tok in stop_word_list]        # The index of the stop words
    sws = [-1] + sws + [len(toks)]                                          # Add the <start> and <end> positions
    segments = [toks[sws[i]+1:sws[i+1]] for i in range(len(sws)-1)]         # Slice the segments from the tokens
    segments = list(filter(lambda x: len(x) > 0, segments))                 # Remove empty segments from consecutive stop words
    return segments
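
# Example (illustrative): get_segments("walk straight. stop at the door.")
#   -> [['walk', 'straight'], ['stop', 'at', 'the', 'door']]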


def clever_pad_sequence(sequences, batch_first=True, padding_value=0):
    max_size = sequences[0].size()
    max_len, trailing_dims = max_size[0], max_size[1:]
    max_len = max(seq.size()[0] for seq in sequences)
    if batch_first:
        out_dims = (len(sequences), max_len) + trailing_dims
    else:
        out_dims = (max_len, len(sequences)) + trailing_dims
    if padding_value is not None:
        out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
    for i, tensor in enumerate(sequences):
        length = tensor.size(0)
        # use index notation to prevent duplicate references to the tensor
        if batch_first:
            out_tensor[i, :length, ...] = tensor
        else:
            out_tensor[:length, i, ...] = tensor

    return out_tensor


import torch
def length2mask(length, size=None):
    batch_size = len(length)
    size = int(max(length)) if size is None else size
    mask = (torch.arange(size, dtype=torch.int64).unsqueeze(0).repeat(batch_size, 1)
            > (torch.LongTensor(length) - 1).unsqueeze(1)).cuda()
    return mask
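
# Example (illustrative; True marks padded positions, and a GPU is required by the .cuda() call):
#   length2mask([2, 4])
#   -> tensor([[False, False,  True,  True],
#              [False, False, False, False]], device='cuda:0')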


def average_length(path2inst):
    length = []

    for name in path2inst:
        datum = path2inst[name]
        length.append(len(datum))
    return sum(length) / len(length)


def tile_batch(tensor, multiplier):
    _, *s = tensor.size()
    tensor = tensor.unsqueeze(1).expand(-1, multiplier, *(-1,) * len(s)).contiguous().view(-1, *s)
    return tensor


def viewpoint_drop_mask(viewpoint, seed=None, drop_func=None):
    local_seed = hash(viewpoint) ^ seed
    torch.random.manual_seed(local_seed)
    drop_mask = drop_func(torch.ones(2048).cuda())
    return drop_mask


class FloydGraph:
    def __init__(self):
        self._dis = defaultdict(lambda: defaultdict(lambda: 95959595))
        self._point = defaultdict(lambda: defaultdict(lambda: ""))
        self._visited = set()

    def distance(self, x, y):
        if x == y:
            return 0
        else:
            return self._dis[x][y]

    def add_edge(self, x, y, dis):
        if dis < self._dis[x][y]:
            self._dis[x][y] = dis
            self._dis[y][x] = dis
            self._point[x][y] = ""
            self._point[y][x] = ""

    def update(self, k):
        for x in self._dis:
            for y in self._dis:
                if x != y:
                    if self._dis[x][k] + self._dis[k][y] < self._dis[x][y]:
                        self._dis[x][y] = self._dis[x][k] + self._dis[k][y]
                        self._dis[y][x] = self._dis[x][y]
                        self._point[x][y] = k
                        self._point[y][x] = k
        self._visited.add(k)

    def visited(self, k):
        return (k in self._visited)

    def path(self, x, y):
        """
        :param x: start
        :param y: end
        :return: the path from x to y [v1, v2, ..., v_n, y]
        """
        if x == y:
            return []
        if self._point[x][y] == "":     # Direct edge
            return [y]
        else:
            k = self._point[x][y]
            # print(x, y, k)
            # for x1 in (x, k, y):
            #     for x2 in (x, k, y):
            #         print(x1, x2, "%.4f" % self._dis[x1][x2])
            return self.path(x, k) + self.path(k, y)
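
# Usage sketch (node names are examples): relax paths through each visited node,
# then query distances and shortest paths between connected pairs.
#   g = FloydGraph()
#   g.add_edge('a', 'b', 1.0); g.add_edge('b', 'c', 1.0)
#   g.update('b')               # relax paths passing through 'b'
#   g.distance('a', 'c')        # -> 2.0
#   g.path('a', 'c')            # -> ['b', 'c']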


def print_progress(iteration, total, prefix='', suffix='', decimals=1, bar_length=100):
    """
    Call in a loop to create a terminal progress bar
    @params:
        iteration   - Required  : current iteration (Int)
        total       - Required  : total iterations (Int)
        prefix      - Optional  : prefix string (Str)
        suffix      - Optional  : suffix string (Str)
        decimals    - Optional  : positive number of decimals in percent complete (Int)
        bar_length  - Optional  : character length of bar (Int)
    """
    str_format = "{0:." + str(decimals) + "f}"
    percents = str_format.format(100 * (iteration / float(total)))
    filled_length = int(round(bar_length * iteration / float(total)))
    bar = '█' * filled_length + '-' * (bar_length - filled_length)

    sys.stdout.write('\r%s |%s| %s%s %s' % (prefix, bar, percents, '%', suffix))

    if iteration == total:
        sys.stdout.write('\n')
    sys.stdout.flush()


def ndtw_initialize():
    ndtw_criterion = {}
    scan_gts_dir = '/students/u5399302/MatterportData/data/id_paths.json'
    with open(scan_gts_dir) as f_:
        scan_gts = json.load(f_)
    all_scan_ids = []
    for key in scan_gts:
        path_scan_id = scan_gts[key][0]
        # print('path_scan_id', path_scan_id)
        if path_scan_id not in all_scan_ids:
            all_scan_ids.append(path_scan_id)
            ndtw_graph = ndtw_graphload(path_scan_id)
            ndtw_criterion[path_scan_id] = DTW(ndtw_graph)
    return ndtw_criterion


def ndtw_graphload(scan):
    """Loads a networkx graph for a given scan.
    Args:
        scan: the scan id; the connectivity information is read from
            connectivity/<scan>_connectivity.json.
    Returns:
        A networkx graph.
    """
    connections_file = 'connectivity/{}_connectivity.json'.format(scan)
    with open(connections_file) as f:
        lines = json.load(f)
        nodes = np.array([x['image_id'] for x in lines])
        matrix = np.array([x['unobstructed'] for x in lines])
        mask = np.array([x['included'] for x in lines])

        matrix = matrix[mask][:, mask]
        nodes = nodes[mask]

        pos2d = {x['image_id']: np.array(x['pose'])[[3, 7]] for x in lines}
        pos3d = {x['image_id']: np.array(x['pose'])[[3, 7, 11]] for x in lines}

    graph = nx.from_numpy_matrix(matrix)
    graph = nx.relabel.relabel_nodes(graph, dict(enumerate(nodes)))
    nx.set_node_attributes(graph, pos2d, 'pos2d')
    nx.set_node_attributes(graph, pos3d, 'pos3d')

    weight2d = {(u, v): norm(pos2d[u] - pos2d[v]) for u, v in graph.edges}
    weight3d = {(u, v): norm(pos3d[u] - pos3d[v]) for u, v in graph.edges}
    nx.set_edge_attributes(graph, weight2d, 'weight2d')
    nx.set_edge_attributes(graph, weight3d, 'weight3d')

    return graph


class DTW(object):
    """Dynamic Time Warping (DTW) evaluation metrics.
    Python doctest:
    >>> graph = nx.grid_graph([3, 4])
    >>> prediction = [(0, 0), (1, 0), (2, 0), (3, 0)]
    >>> reference = [(0, 0), (1, 0), (2, 1), (3, 2)]
    >>> dtw = DTW(graph)
    >>> assert np.isclose(dtw(prediction, reference, 'dtw'), 3.0)
    >>> assert np.isclose(dtw(prediction, reference, 'ndtw'), 0.77880078307140488)
    >>> assert np.isclose(dtw(prediction, reference, 'sdtw'), 0.77880078307140488)
    >>> assert np.isclose(dtw(prediction[:2], reference, 'sdtw'), 0.0)
    """

    def __init__(self, graph, weight='weight', threshold=3.0):
        """Initializes a DTW object.
        Args:
            graph: networkx graph for the environment.
            weight: networkx edge weight key (str).
            threshold: distance threshold $d_{th}$ (float).
        """
        self.graph = graph
        self.weight = weight
        self.threshold = threshold
        self.distance = dict(
            nx.all_pairs_dijkstra_path_length(self.graph, weight=self.weight))

    def __call__(self, prediction, reference, metric='sdtw'):
        """Computes DTW metrics.
        Args:
            prediction: list of nodes (str), path predicted by agent.
            reference: list of nodes (str), the ground truth path.
            metric: one of ['ndtw', 'sdtw', 'dtw'].
        Returns:
            the DTW between the prediction and reference path (float).
        """
        assert metric in ['ndtw', 'sdtw', 'dtw']

        dtw_matrix = np.inf * np.ones((len(prediction) + 1, len(reference) + 1))
        dtw_matrix[0][0] = 0
        for i in range(1, len(prediction)+1):
            for j in range(1, len(reference)+1):
                best_previous_cost = min(
                    dtw_matrix[i-1][j], dtw_matrix[i][j-1], dtw_matrix[i-1][j-1])
                cost = self.distance[prediction[i-1]][reference[j-1]]
                dtw_matrix[i][j] = cost + best_previous_cost
        dtw = dtw_matrix[len(prediction)][len(reference)]

        if metric == 'dtw':
            return dtw

        ndtw = np.exp(-dtw/(self.threshold * len(reference)))
        if metric == 'ndtw':
            return ndtw

        success = self.distance[prediction[-1]][reference[-1]] <= self.threshold
        return success * ndtw
283
r2r_src/vlnbert/vlnbert_OSCAR.py
Normal file
@ -0,0 +1,283 @@
# Copyright (c) 2020 Microsoft Corporation. Licensed under the MIT license.
# Modified in Recurrent VLN-BERT, 2020, Yicong.Hong@anu.edu.au

from __future__ import absolute_import, division, print_function, unicode_literals
import logging
import math
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss

import sys
sys.path.append('Oscar/Oscar')
from transformers.pytorch_transformers.modeling_bert import (BertEmbeddings,
        BertSelfAttention, BertAttention, BertEncoder, BertLayer,
        BertSelfOutput, BertIntermediate, BertOutput,
        BertPooler, BertLayerNorm, BertPreTrainedModel,
        BertPredictionHeadTransform)

logger = logging.getLogger(__name__)


class CaptionBertSelfAttention(BertSelfAttention):
    """
    Modified from BertSelfAttention to add support for output_hidden_states.
    """
    def __init__(self, config):
        super(CaptionBertSelfAttention, self).__init__(config)
        self.config = config

    def forward(self, mode, hidden_states, attention_mask, head_mask=None,
                history_state=None):
        if history_state is not None:
            x_states = torch.cat([history_state, hidden_states], dim=1)
            mixed_query_layer = self.query(hidden_states)
            mixed_key_layer = self.key(x_states)
            mixed_value_layer = self.value(x_states)
        else:
            mixed_query_layer = self.query(hidden_states)
            mixed_key_layer = self.key(hidden_states)
            mixed_value_layer = self.value(hidden_states)

        if mode == 'visual':
            mixed_query_layer = mixed_query_layer[:, [0]+list(range(-self.config.directions, 0)), :]

        ''' language features only provide keys and values '''
        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
        attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_scores)

        return outputs
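
# Shape sketch for the 'visual' branch above (dimension names are illustrative): with L
# language tokens and D candidate directions, hidden_states is [batch, 1+L+D, hidden];
# slicing the queries to [0] + range(-D, 0) keeps only the state token and the D visual
# tokens, so attention reads from the full sequence but only state/vision rows are updated.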


class CaptionBertAttention(BertAttention):
    """
    Modified from BertAttention to add support for output_hidden_states.
    """
    def __init__(self, config):
        super(CaptionBertAttention, self).__init__(config)
        self.self = CaptionBertSelfAttention(config)
        self.output = BertSelfOutput(config)
        self.config = config

    def forward(self, mode, input_tensor, attention_mask, head_mask=None,
                history_state=None):
        ''' transformer processing '''
        self_outputs = self.self(mode, input_tensor, attention_mask, head_mask, history_state)

        ''' feed-forward network with residual '''
        if mode == 'visual':
            attention_output = self.output(self_outputs[0], input_tensor[:, [0]+list(range(-self.config.directions, 0)), :])
        if mode == 'language':
            attention_output = self.output(self_outputs[0], input_tensor)

        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them

        return outputs


class CaptionBertLayer(BertLayer):
    """
    Modified from BertLayer to add support for output_hidden_states.
    """
    def __init__(self, config):
        super(CaptionBertLayer, self).__init__(config)
        self.attention = CaptionBertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, mode, hidden_states, attention_mask, head_mask=None,
                history_state=None):

        attention_outputs = self.attention(mode, hidden_states, attention_mask,
                                           head_mask, history_state)

        ''' feed-forward network with residual '''
        attention_output = attention_outputs[0]
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        outputs = (layer_output,) + attention_outputs[1:]

        return outputs


class CaptionBertEncoder(BertEncoder):
    """
    Modified from BertEncoder to add support for output_hidden_states.
    """
    def __init__(self, config):
        super(CaptionBertEncoder, self).__init__(config)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        # 12 Bert layers
        self.layer = nn.ModuleList([CaptionBertLayer(config) for _ in range(config.num_hidden_layers)])
        self.config = config

    def forward(self, mode, hidden_states, attention_mask, head_mask=None,
                encoder_history_states=None):

        if mode == 'visual':
            for i, layer_module in enumerate(self.layer):
                history_state = None if encoder_history_states is None else encoder_history_states[i]

                layer_outputs = layer_module(mode,
                        hidden_states, attention_mask, head_mask[i],
                        history_state)

                concat_layer_outputs = torch.cat((layer_outputs[0][:, 0:1, :], hidden_states[:, 1:-self.config.directions, :], layer_outputs[0][:, 1:self.config.directions+1, :]), 1)
                hidden_states = concat_layer_outputs

                if i == self.config.num_hidden_layers - 1:
                    state_attention_score = layer_outputs[1][:, :, 0, :]
                    lang_attention_score = layer_outputs[1][:, :, -self.config.directions:, 1:-self.config.directions]
                    vis_attention_score = layer_outputs[1][:, :, :, :]

            outputs = (hidden_states, state_attention_score, lang_attention_score, vis_attention_score)

        elif mode == 'language':
            for i, layer_module in enumerate(self.layer):
                history_state = None if encoder_history_states is None else encoder_history_states[i]  # default None

                layer_outputs = layer_module(mode,
                        hidden_states, attention_mask, head_mask[i],
                        history_state)
                hidden_states = layer_outputs[0]

                if i == self.config.num_hidden_layers - 1:
                    slang_attention_score = layer_outputs[1]

            outputs = (hidden_states, slang_attention_score)

        return outputs


class BertImgModel(BertPreTrainedModel):
    """ Expand from BertModel to handle image region features as input
    """
    def __init__(self, config):
        super(BertImgModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = CaptionBertEncoder(config)
        self.pooler = BertPooler(config)

        self.img_dim = config.img_feature_dim
        logger.info('BertImgModel Image Dimension: {}'.format(self.img_dim))

        self.apply(self.init_weights)

    def forward(self, mode, input_ids, token_type_ids=None, attention_mask=None,
                position_ids=None, img_feats=None):

        if attention_mask.dim() == 2:
            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        elif attention_mask.dim() == 3:
            extended_attention_mask = attention_mask.unsqueeze(1)
        else:
            raise NotImplementedError

        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        head_mask = [None] * self.config.num_hidden_layers

        if mode == 'visual':
            language_features = input_ids
            concat_embedding_output = torch.cat((language_features, img_feats), 1)
        elif mode == 'language':
            embedding_output = self.embeddings(input_ids, position_ids=position_ids,
                                               token_type_ids=token_type_ids)
            concat_embedding_output = embedding_output

        ''' pass to the Transformer layers '''
        encoder_outputs = self.encoder(mode, concat_embedding_output,
                                       extended_attention_mask, head_mask=head_mask)

        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)  # We "pool" the model by simply taking the hidden state corresponding to the first token

        # add hidden_states and attentions if they are here
        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]

        return outputs


class VLNBert(BertPreTrainedModel):
    """
    Modified from BertForMultipleChoice to support oscar training.
    """
    def __init__(self, config):
        super(VLNBert, self).__init__(config)
        self.config = config
        self.bert = BertImgModel(config)

        self.vis_lang_LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.state_proj = nn.Linear(config.hidden_size*2, config.hidden_size, bias=True)
        self.state_LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.apply(self.init_weights)

    def forward(self, mode, input_ids, token_type_ids=None, attention_mask=None,
                position_ids=None, img_feats=None):

        outputs = self.bert(mode, input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
                            attention_mask=attention_mask, img_feats=img_feats)

        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)

        pooled_output = outputs[1]

        if mode == 'language':
            return sequence_output

        elif mode == 'visual':
            # attention scores with respect to the agent's state
            language_attentions = outputs[2][:, :, 1:-self.config.directions]
            visual_attentions = outputs[2][:, :, -self.config.directions:]

            language_attention_scores = language_attentions.mean(dim=1)  # mean over the 12 heads
            visual_attention_scores = visual_attentions.mean(dim=1)

            # weighted_feat
            language_attention_probs = nn.Softmax(dim=-1)(language_attention_scores.clone()).unsqueeze(-1)
            visual_attention_probs = nn.Softmax(dim=-1)(visual_attention_scores.clone()).unsqueeze(-1)

            language_seq = sequence_output[:, 1:-self.config.directions, :]
            visual_seq = sequence_output[:, -self.config.directions:, :]

            # residual weighting, final attention to weight the raw inputs
            attended_language = (language_attention_probs * input_ids[:, 1:, :]).sum(1)
            attended_visual = (visual_attention_probs * img_feats).sum(1)

            # update the agent's state: unify history, language and vision by elementwise product
            vis_lang_feat = self.vis_lang_LayerNorm(attended_language * attended_visual)
            state_output = torch.cat((pooled_output, vis_lang_feat), dim=-1)
            state_proj = self.state_proj(state_output)
            state_proj = self.state_LayerNorm(state_proj)

            return state_proj, visual_attention_scores
443
r2r_src/vlnbert/vlnbert_PREVALENT.py
Normal file
@ -0,0 +1,443 @@
|
||||
# PREVALENT, 2020, weituo.hao@duke.edu
|
||||
# Modified in Recurrent VLN-BERT, 2020, Yicong.Hong@anu.edu.au
|
||||
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from io import open
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
sys.path.append('Oscar/Oscar')
|
||||
from transformers.pytorch_transformers.modeling_bert import BertPreTrainedModel, BertConfig
|
||||
import pdb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def gelu(x):
|
||||
"""Implementation of the gelu activation function.
|
||||
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
|
||||
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
Also see https://arxiv.org/abs/1606.08415
|
||||
"""
|
||||
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
|
||||
|
||||
|
||||
def swish(x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
|
||||
|
||||
|
||||
try:
|
||||
from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
|
||||
except (ImportError, AttributeError) as e:
|
||||
logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")
|
||||
BertLayerNorm = torch.nn.LayerNorm
|
||||
|
||||
class BertEmbeddings(nn.Module):
|
||||
"""Construct the embeddings from word, position and token_type embeddings.
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(BertEmbeddings, self).__init__()
|
||||
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
|
||||
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
||||
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
||||
|
||||
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
||||
# any TensorFlow checkpoint file
|
||||
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, position_ids=None):
|
||||
seq_length = input_ids.size(1)
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
|
||||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros_like(input_ids)
|
||||
|
||||
words_embeddings = self.word_embeddings(input_ids)
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||
|
||||
embeddings = words_embeddings + position_embeddings + token_type_embeddings
|
||||
embeddings = self.LayerNorm(embeddings)
|
||||
embeddings = self.dropout(embeddings)
|
||||
return embeddings
|
||||
|
||||
|
||||
class BertSelfAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertSelfAttention, self).__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
"The hidden size (%d) is not a multiple of the number of attention "
|
||||
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
|
||||
self.output_attentions = True
|
||||
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x):
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(*new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(self, hidden_states, attention_mask, head_mask=None):
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
mixed_key_layer = self.key(hidden_states)
|
||||
mixed_value_layer = self.value(hidden_states)
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
key_layer = self.transpose_for_scores(mixed_key_layer)
|
||||
value_layer = self.transpose_for_scores(mixed_value_layer)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
||||
attention_scores = attention_scores + attention_mask
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(*new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_scores) if self.output_attentions else (context_layer,)
|
||||
return outputs
|
||||
|
||||
|
||||
class BertSelfOutput(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertSelfOutput, self).__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, hidden_states, input_tensor):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BertAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertAttention, self).__init__()
|
||||
self.self = BertSelfAttention(config)
|
||||
self.output = BertSelfOutput(config)
|
||||
|
||||
def forward(self, input_tensor, attention_mask, head_mask=None):
|
||||
self_outputs = self.self(input_tensor, attention_mask, head_mask)
|
||||
attention_output = self.output(self_outputs[0], input_tensor)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
return outputs
|
||||
|
||||
|
||||
class BertIntermediate(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertIntermediate, self).__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
|
||||
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
||||
else:
|
||||
self.intermediate_act_fn = config.hidden_act
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BertOutput(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertOutput, self).__init__()
|
||||
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, hidden_states, input_tensor):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BertLayer(nn.Module):
    def __init__(self, config):
        super(BertLayer, self).__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask, head_mask=None):
        attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
        attention_output = attention_outputs[0]
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        outputs = (layer_output,) + attention_outputs[1:]  # add attentions if we output them
        return outputs

class BertPooler(nn.Module):
    def __init__(self, config):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

class BertXAttention(nn.Module):
    def __init__(self, config, ctx_dim=None):
        super().__init__()
        self.att = BertOutAttention(config, ctx_dim=ctx_dim)
        self.output = BertSelfOutput(config)

    def forward(self, input_tensor, ctx_tensor, ctx_att_mask=None):
        output, attention_scores = self.att(input_tensor, ctx_tensor, ctx_att_mask)
        attention_output = self.output(output, input_tensor)
        return attention_output, attention_scores

class BertOutAttention(nn.Module):
    def __init__(self, config, ctx_dim=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # visual_dim = 2048
        if ctx_dim is None:
            ctx_dim = config.hidden_size
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(ctx_dim, self.all_head_size)
        self.value = nn.Linear(ctx_dim, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, context, attention_mask=None):
        # Queries come from the target stream; keys and values come from the
        # context stream (e.g. language tokens when vision attends to language).
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(context)
        mixed_value_layer = self.value(context)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Apply the attention mask (precomputed for all layers in the BertModel forward() function).
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer, attention_scores

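# Editorial sketch (illustrative numbers, not part of this commit) of the
# additive-mask convention used above:
#   scores = [2.0, 1.0, 0.5], mask = [1, 1, 0]   (last position is padding)
#   scores + (1.0 - mask) * -10000.0  ->  [2.0, 1.0, -9999.5]
#   softmax(...)                      ->  [0.73, 0.27, 0.00]
# so padded positions receive (near-)zero attention probability.
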
class LXRTXLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # Language self-attention and FFN sub-layers
        self.lang_self_att = BertAttention(config)
        self.lang_inter = BertIntermediate(config)
        self.lang_output = BertOutput(config)
        # Vision self-attention and FFN sub-layers
        self.visn_self_att = BertAttention(config)
        self.visn_inter = BertIntermediate(config)
        self.visn_output = BertOutput(config)
        # The cross-attention layer
        self.visual_attention = BertXAttention(config)

    def cross_att(self, lang_input, lang_attention_mask, visn_input, visn_attention_mask):
        ''' Cross-attention -- applied to the visual stream only, not to language. '''
        visn_att_output, attention_scores = self.visual_attention(visn_input, lang_input, ctx_att_mask=lang_attention_mask)
        return visn_att_output, attention_scores

    def self_att(self, visn_input, visn_attention_mask):
        ''' Self-attention -- on visual features carrying language clues. '''
        visn_att_output = self.visn_self_att(visn_input, visn_attention_mask)
        return visn_att_output

    def output_fc(self, visn_input):
        ''' Feed-forward sub-layer. '''
        visn_inter_output = self.visn_inter(visn_input)
        visn_output = self.visn_output(visn_inter_output, visn_input)
        return visn_output

    def forward(self, lang_feats, lang_attention_mask,
                visn_feats, visn_attention_mask, tdx):

        ''' visual self-attention with state: prepend the state token (the
        first language position) to the visual sequence '''
        visn_att_output = torch.cat((lang_feats[:, 0:1, :], visn_feats), dim=1)
        state_vis_mask = torch.cat((lang_attention_mask[:, :, :, 0:1], visn_attention_mask), dim=-1)

        ''' state and vision attend to language (the remaining tokens) '''
        visn_att_output, cross_attention_scores = self.cross_att(lang_feats[:, 1:, :], lang_attention_mask[:, :, :, 1:], visn_att_output, state_vis_mask)

        # Row 0 is the state token's cross-attention over the language tokens.
        language_attention_scores = cross_attention_scores[:, :, 0, :]

        state_visn_att_output = self.self_att(visn_att_output, state_vis_mask)
        state_visn_output = self.output_fc(state_visn_att_output[0])

        visn_att_output = state_visn_output[:, 1:, :]
        lang_att_output = torch.cat((state_visn_output[:, 0:1, :], lang_feats[:, 1:, :]), dim=1)

        # The state token's self-attention over the visual positions.
        visual_attention_scores = state_visn_att_output[1][:, :, 0, 1:]

        return lang_att_output, visn_att_output, language_attention_scores, visual_attention_scores

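# Shape notes for the forward pass above (editorial sketch, assuming B
# instructions of length L and V candidate views): lang_feats is (B, L, H)
# with the recurrent state at position 0; visn_feats is (B, V, H); the
# concatenated stream [state; v_1 ... v_V] has length V + 1, so row/column 0
# of the attention maps always refers to the state token -- which is what the
# language_attention_scores and visual_attention_scores slices rely on.
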
class VisionEncoder(nn.Module):
    def __init__(self, vision_size, config):
        super().__init__()
        feat_dim = vision_size

        # Object feature encoding
        self.visn_fc = nn.Linear(feat_dim, config.hidden_size)
        self.visn_layer_norm = BertLayerNorm(config.hidden_size, eps=1e-12)

        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, visn_input):
        feats = visn_input

        x = self.visn_fc(feats)
        x = self.visn_layer_norm(x)

        output = self.dropout(x)
        return output

class VLNBert(BertPreTrainedModel):
    def __init__(self, config):
        super(VLNBert, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.pooler = BertPooler(config)

        self.img_dim = config.img_feature_dim  # 2176
        logger.info('VLNBert Image Dimension: {}'.format(self.img_dim))
        self.img_feature_type = config.img_feature_type  # ''
        self.vl_layers = config.vl_layers  # 4
        self.la_layers = config.la_layers  # 9
        self.lalayer = nn.ModuleList(
            [BertLayer(config) for _ in range(self.la_layers)])
        self.addlayer = nn.ModuleList(
            [LXRTXLayer(config) for _ in range(self.vl_layers)])
        self.vision_encoder = VisionEncoder(self.config.img_feature_dim, self.config)
        self.apply(self.init_weights)

    def forward(self, mode, input_ids, token_type_ids=None,
                attention_mask=None, lang_mask=None, vis_mask=None,
                position_ids=None, head_mask=None, img_feats=None):

        attention_mask = lang_mask

        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # Build the additive attention mask: 0 for real tokens, -10000 for padding.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        head_mask = [None] * self.config.num_hidden_layers

        if mode == 'language':
            ''' LXMERT language branch (in VLN, performed only at initialization) '''
            embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
            text_embeds = embedding_output

            for layer_module in self.lalayer:
                temp_output = layer_module(text_embeds, extended_attention_mask)
                text_embeds = temp_output[0]

            sequence_output = text_embeds
            pooled_output = self.pooler(sequence_output)

            return pooled_output, sequence_output

        elif mode == 'visual':
            ''' LXMERT visual branch (no language processing during navigation);
            here input_ids carries the precomputed language embeddings '''
            text_embeds = input_ids

            text_mask = extended_attention_mask

            img_embedding_output = self.vision_encoder(img_feats)
            img_seq_len = img_feats.shape[1]
            batch_size = text_embeds.size(0)

            img_seq_mask = vis_mask

            extended_img_mask = img_seq_mask.unsqueeze(1).unsqueeze(2)
            extended_img_mask = extended_img_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
            extended_img_mask = (1.0 - extended_img_mask) * -10000.0
            img_mask = extended_img_mask

            lang_output = text_embeds
            visn_output = img_embedding_output

            for tdx, layer_module in enumerate(self.addlayer):
                lang_output, visn_output, language_attention_scores, visual_attention_scores = layer_module(lang_output, text_mask, visn_output, img_mask, tdx)

            sequence_output = lang_output
            pooled_output = self.pooler(sequence_output)

            language_state_scores = language_attention_scores.mean(dim=1)
            visual_action_scores = visual_attention_scores.mean(dim=1)

            return pooled_output, visual_action_scores
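For orientation, here is a minimal, hypothetical driver for the two-mode forward above (the function name and tensor arguments are illustrative, not part of this commit; the agent code builds these tensors itself): the language branch runs once per instruction to produce the initial state and token embeddings, and the visual branch runs at every navigation step with the recurrent state written into the first language position.

import torch

def init_and_step(model, input_ids, lang_mask, vis_mask, img_feats):
    # model: a constructed VLNBert (assumed); input_ids: (B, L) token ids;
    # lang_mask: (B, L) and vis_mask: (B, V) with 1 = keep, 0 = pad;
    # img_feats: (B, V, 2176) candidates (assumed 2048 visual + 128 angle).
    state, lang_embeds = model('language', input_ids, lang_mask=lang_mask)

    # Substitute the recurrent state for the first token, then score the
    # candidate actions with the visual branch.
    lang_embeds = torch.cat((state.unsqueeze(1), lang_embeds[:, 1:, :]), dim=1)
    state, action_logits = model('visual', lang_embeds, lang_mask=lang_mask,
                                 vis_mask=vis_mask, img_feats=img_feats)
    return state, action_logits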
48
r2r_src/vlnbert/vlnbert_init.py
Normal file
@ -0,0 +1,48 @@
# Recurrent VLN-BERT, 2020, by Yicong.Hong@anu.edu.au

import sys
sys.path.append('Oscar/Oscar')

from transformers.pytorch_transformers import (BertConfig, BertTokenizer)


def get_tokenizer(args):
    if args.vlnbert == 'oscar':
        tokenizer_class = BertTokenizer
        model_name_or_path = 'Oscar/pretrained_models/base-no-labels/ep_67_588997'
        tokenizer = tokenizer_class.from_pretrained(model_name_or_path, do_lower_case=True)
    elif args.vlnbert == 'prevalent':
        tokenizer_class = BertTokenizer
        tokenizer = tokenizer_class.from_pretrained('bert-base-uncased')
    return tokenizer


def get_vlnbert_models(args, config=None):
    config_class = BertConfig

    if args.vlnbert == 'oscar':
        from vlnbert.vlnbert_OSCAR import VLNBert
        model_class = VLNBert
        model_name_or_path = 'Oscar/pretrained_models/base-no-labels/ep_67_588997'
        vis_config = config_class.from_pretrained(model_name_or_path, num_labels=2, finetuning_task='vln-r2r')

        vis_config.model_type = 'visual'
        vis_config.finetuning_task = 'vln-r2r'
        vis_config.hidden_dropout_prob = 0.3
        vis_config.hidden_size = 768
        vis_config.img_feature_dim = 2176
        vis_config.num_attention_heads = 12
        vis_config.num_hidden_layers = 12
        visual_model = model_class.from_pretrained(model_name_or_path, from_tf=False, config=vis_config)

    elif args.vlnbert == 'prevalent':
        from vlnbert.vlnbert_PREVALENT import VLNBert
        model_class = VLNBert
        model_name_or_path = 'Prevalent/pretrained_model/pytorch_model.bin'
        vis_config = config_class.from_pretrained('bert-base-uncased')
        vis_config.img_feature_dim = 2176
        vis_config.img_feature_type = ""
        vis_config.vl_layers = 4
        vis_config.la_layers = 9

        visual_model = model_class.from_pretrained(model_name_or_path, config=vis_config)

    return visual_model
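A quick, hypothetical smoke test of these two entry points might look like the following; the SimpleNamespace stands in for the project's argument-parser output, and the PREVALENT weights must already sit at the path hard-coded above.

from types import SimpleNamespace

args = SimpleNamespace(vlnbert='prevalent')

tokenizer = get_tokenizer(args)           # bert-base-uncased vocabulary
tokens = tokenizer.tokenize('walk past the table and wait by the stairs')

# Requires Prevalent/pretrained_model/pytorch_model.bin on disk.
model = get_vlnbert_models(args)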
157
recurrent-vln-bert.yml
Normal file
@ -0,0 +1,157 @@
name: recurrent-vln-bert
channels:
  - bioconda
  - menpo
  - pytorch
  - conda-forge
  - anaconda
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - bcolz=1.2.1=py36h04863e7_0
  - blas=1.0=mkl
  - blosc=1.16.3=hd408876_0
  - bokeh=2.0.1=py36_0
  - bzip2=1.0.8=h7b6447c_0
  - ca-certificates=2020.7.22=0
  - certifi=2020.6.20=py36_0
  - cffi=1.13.1=py36h2e261b9_0
  - click=7.1.1=py_0
  - cloudpickle=1.3.0=py_0
  - cmake=3.14.0=h52cb24c_0
  - cudatoolkit=10.2.89=hfd86e86_1
  - cytoolz=0.10.1=py36h7b6447c_0
  - dask=2.14.0=py_0
  - dask-core=2.14.0=py_0
  - distributed=2.14.0=py36_0
  - docutils=0.15.2=py36_0
  - expat=2.2.6=he6710b0_0
  - ffmpeg=4.0=hcdf2ecd_0
  - freetype=2.9.1=h8a8886c_1
  - fsspec=0.7.1=py_0
  - hdf5=1.10.2=hba1933b_1
  - heapdict=1.0.1=py_0
  - idna=2.10=py_0
  - intel-openmp=2019.4=243
  - jasper=2.0.14=h07fcdf6_1
  - jinja2=2.11.1=py_0
  - jsoncpp=1.8.4=hfd86e86_0
  - krb5=1.17.1=h173b8e3_0
  - libcurl=7.69.1=h20c2e04_0
  - libedit=3.1.20181209=hc058e9b_0
  - libgcc-ng=9.1.0=hdf63c60_0
  - libgfortran-ng=7.3.0=hdf63c60_0
  - libopencv=3.4.2=hb342d67_1
  - libopus=1.3=h7b6447c_0
  - libpng=1.6.37=hbc83047_0
  - libssh2=1.9.0=h1ba5d50_1
  - libstdcxx-ng=9.1.0=hdf63c60_0
  - libtiff=4.0.10=h2733197_2
  - libvpx=1.7.0=h439df22_0
  - locket=0.2.0=py36_1
  - lz4-c=1.8.1.2=h14c3975_0
  - lzo=2.10=h1bfc0ba_1
  - markupsafe=1.1.1=py36h7b6447c_0
  - mkl=2019.4=243
  - mkl-service=2.3.0=py36he904b0f_0
  - mkl_fft=1.0.14=py36ha843d7b_0
  - mkl_random=1.1.0=py36hd6b4f25_0
  - msgpack-python=1.0.0=py36hfd86e86_1
  - ncurses=6.1=he6710b0_1
  - ninja=1.9.0=py36hfd86e86_0
  - numexpr=2.7.1=py36h423224d_0
  - olefile=0.46=py36_0
  - opencv=3.4.2=py36h6fd60c2_1
  - openjdk=8.0.152=h7b6447c_3
  - openssl=1.1.1g=h7b6447c_0
  - packaging=20.3=py_0
  - pandas=1.1.1=py36he6710b0_0
  - partd=1.1.0=py_0
  - pillow=6.2.1=py36h34e0f95_0
  - pip=19.3.1=py36_0
  - psutil=5.7.0=py36h7b6447c_0
  - py-opencv=3.4.2=py36hb342d67_1
  - pybind11=2.4.2=py36hfd86e86_0
  - pycparser=2.19=py36_0
  - pyopenssl=19.1.0=py_1
  - pyparsing=2.4.7=py_0
  - pytables=3.4.4=py36ha205bf6_0
  - python=3.6.9=h265db76_0
  - python-dateutil=2.8.1=py_0
  - pytz=2020.1=py_0
  - pyyaml=5.3.1=py36h7b6447c_0
  - readline=7.0=h7b6447c_5
  - requests=2.24.0=py_0
  - rhash=1.3.8=h1ba5d50_0
  - setuptools=41.6.0=py36_0
  - six=1.15.0=py_0
  - sortedcontainers=2.1.0=py36_0
  - sqlite=3.30.1=h7b6447c_0
  - tblib=1.6.0=py_0
  - tk=8.6.8=hbc83047_0
  - toolz=0.10.0=py_0
  - tornado=6.0.4=py36h7b6447c_1
  - typing_extensions=3.7.4.1=py36_0
  - urllib3=1.25.10=py_0
  - wheel=0.33.6=py36_0
  - xz=5.2.4=h14c3975_4
  - yaml=0.1.7=h96e3832_1
  - zict=2.0.0=py_0
  - zlib=1.2.11=h7b6447c_3
  - zstd=1.3.7=h0b5b093_0
  - java-jdk=7.0.91=1
  - tqdm=4.7.2=py36_0
  - boto3=1.13.14=pyh9f0ad1d_0
  - botocore=1.16.14=pyh9f0ad1d_0
  - brotlipy=0.7.0=py36h8c4c3a4_1000
  - cairo=1.14.12=h80bd089_1005
  - chardet=3.0.4=py36h9f0ad1d_1006
  - cryptography=2.9.2=py36h45558ae_0
  - fontconfig=2.13.1=h2176d3f_1000
  - freeglut=3.0.0=hf484d3e_1005
  - gettext=0.19.8.1=h9745a5d_1001
  - glew=2.1.0=he1b5a44_0
  - glib=2.56.2=had28632_1001
  - graphite2=1.3.13=hf484d3e_1000
  - harfbuzz=1.9.0=he243708_1001
  - htop=2.2.0=hf8c457e_1000
  - icu=58.2=hf484d3e_1000
  - jmespath=0.10.0=pyh9f0ad1d_0
  - libblas=3.8.0=14_mkl
  - libcblas=3.8.0=14_mkl
  - libglu=9.0.0=hf484d3e_1000
  - libiconv=1.15=h14c3975_1004
  - liblapack=3.8.0=14_mkl
  - libuuid=2.32.1=h14c3975_1000
  - libxcb=1.13=h14c3975_1002
  - libxml2=2.9.8=h143f9aa_1005
  - numpy=1.18.4=py36h7314795_0
  - pcre=8.41=hf484d3e_1003
  - pixman=0.34.0=h14c3975_1003
  - pthread-stubs=0.4=h14c3975_1001
  - pysocks=1.7.1=py36h9f0ad1d_1
  - python_abi=3.6=1_cp36m
  - pytorch-pretrained-bert=0.6.2=py36_0
  - regex=2020.5.14=py36h8c4c3a4_0
  - s3transfer=0.3.3=py36h9f0ad1d_1
  - xorg-fixesproto=5.0=h14c3975_1002
  - xorg-inputproto=2.3.2=h14c3975_1002
  - xorg-kbproto=1.0.7=h14c3975_1002
  - xorg-libice=1.0.9=h14c3975_1004
  - xorg-libsm=1.2.3=h4937e3b_1000
  - xorg-libx11=1.6.9=h516909a_0
  - xorg-libxau=1.0.8=h14c3975_1006
  - xorg-libxdmcp=1.1.2=h14c3975_1007
  - xorg-libxext=1.3.3=h14c3975_1004
  - xorg-libxfixes=5.0.3=h14c3975_1004
  - xorg-libxi=1.7.9=h14c3975_1002
  - xorg-libxrender=0.9.10=h14c3975_1002
  - xorg-renderproto=0.11.1=h14c3975_1002
  - xorg-xextproto=7.3.0=h14c3975_1002
  - xorg-xproto=7.0.31=h14c3975_1007
  - jpeg=9b=h024ee3a_2
  - libffi=3.2.1=hd88cf55_4
  - snappy=1.1.7=hbae5bb6_3
  - osmesa=12.2.2.dev=0
  - pytorch=1.6.0=py3.6_cuda10.2.89_cudnn7.6.5_0
  - torchvision=0.7.0=py36_cu102
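If you need to rebuild this environment, `conda env create -f recurrent-vln-bert.yml` should reproduce it; the pinned core stack is Python 3.6.9, PyTorch 1.6.0 with CUDA 10.2, torchvision 0.7.0, and OpenCV 3.4.2.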
26
run/test_agent.bash
Normal file
@ -0,0 +1,26 @@
name=VLNBERT-test

flag="--vlnbert prevalent

      --submit 0
      --test_only 0

      --train validlistener
      --load snap/VLNBERT-PREVALENT-final/state_dict/best_val_unseen

      --features places365
      --maxAction 15
      --batchSize 8
      --feedback sample
      --lr 1e-5
      --iters 300000
      --optim adamW

      --mlWeight 0.20
      --maxInput 80
      --angleFeatSize 128
      --featdropout 0.4
      --dropout 0.5"

mkdir -p snap/$name
CUDA_VISIBLE_DEVICES=1 python r2r_src/train.py $flag --name $name
25
run/train_agent.bash
Normal file
@ -0,0 +1,25 @@
name=VLNBERT-train

flag="--vlnbert prevalent

      --aug data/prevalent/prevalent_aug.json
      --test_only 0

      --train auglistener

      --features places365
      --maxAction 15
      --batchSize 8
      --feedback sample
      --lr 1e-5
      --iters 300000
      --optim adamW

      --mlWeight 0.20
      --maxInput 80
      --angleFeatSize 128
      --featdropout 0.4
      --dropout 0.5"

mkdir -p snap/$name
CUDA_VISIBLE_DEVICES=1 python r2r_src/train.py $flag --name $name
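Note that the two scripts share all optimization and model flags; they differ only in mode: training uses `--train auglistener` with the PREVALENT-augmented instructions (`--aug`), while testing uses `--train validlistener` and loads a saved snapshot via `--load` (with `--submit 0`).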