vlnbert/r2r_src/agent.py

# R2R-EnvDrop, 2019, haotan@cs.unc.edu
# Modified in Recurrent VLN-BERT, 2020, by Yicong.Hong@anu.edu.au
import json
import os
import sys
import numpy as np
import random
import math
import time
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F
from env import R2RBatch
import utils
from utils import padding_idx, print_progress
import model_OSCAR, model_PREVALENT
import param
from param import args
from collections import defaultdict
class BaseAgent(object):
''' Base class for an R2R agent to generate and save trajectories. '''
def __init__(self, env, results_path):
self.env = env
self.results_path = results_path
random.seed(1)
self.results = {}
self.losses = [] # For learning agents
def write_results(self):
output = [{'instr_id':k, 'trajectory': v[0], 'found': v[1]} for k,v in self.results.items()]
with open(self.results_path, 'w') as f:
json.dump(output, f)
def get_results(self):
output = [{'instr_id': k, 'trajectory': v[0], 'found': v[1]} for k, v in self.results.items()]
return output
def rollout(self, **args):
''' Return a list of dicts containing instr_id:'xx', path:[(viewpointId, heading_rad, elevation_rad)] '''
raise NotImplementedError
@staticmethod
def get_agent(name):
return globals()[name+"Agent"]
def test(self, iters=None, **kwargs):
self.env.reset_epoch(shuffle=(iters is not None)) # If iters is not none, shuffle the env batch
self.losses = []
self.results = {}
# We rely on env showing the entire batch before repeating anything
looped = False
self.loss = 0
if iters is not None:
# Run exactly 'iters' rollouts; the environment batch was shuffled above
for i in range(iters):
trajs, found = self.rollout(**kwargs)
print(found)
for index, traj in enumerate(trajs):
self.loss = 0
self.results[traj['instr_id']] = (traj['path'], found[index])
else: # Do a full round
while True:
trajs, found = self.rollout(**kwargs)
print("FOUND: ", found)
for index, traj in enumerate(trajs):
if traj['instr_id'] in self.results:
looped = True
else:
self.loss = 0
self.results[traj['instr_id']] = (traj['path'], found[index])
if looped:
break
class Seq2SeqAgent(BaseAgent):
''' An agent based on an LSTM seq2seq model with attention. '''
# For now, the agent can't pick which forward move to make - just the one in the middle
env_actions = {
'left': ([0],[-1], [0]), # left
'right': ([0], [1], [0]), # right
'up': ([0], [0], [1]), # up
'down': ([0], [0],[-1]), # down
'forward': ([1], [0], [0]), # forward
'<end>': ([0], [0], [0]), # <end>
'<start>': ([0], [0], [0]), # <start>
'<ignore>': ([0], [0], [0]) # <ignore>
}
def __init__(self, env, results_path, tok, episode_len=20):
super(Seq2SeqAgent, self).__init__(env, results_path)
self.tok = tok
self.episode_len = episode_len
self.feature_size = self.env.feature_size
# Models
if args.vlnbert == 'oscar':
self.vln_bert = model_OSCAR.VLNBERT(feature_size=self.feature_size + args.angle_feat_size).cuda()
self.critic = model_OSCAR.Critic().cuda()
elif args.vlnbert == 'prevalent':
self.vln_bert = model_PREVALENT.VLNBERT(feature_size=self.feature_size + args.angle_feat_size).cuda()
self.critic = model_PREVALENT.Critic().cuda()
self.models = (self.vln_bert, self.critic)
# Optimizers
self.vln_bert_optimizer = args.optimizer(self.vln_bert.parameters(), lr=args.lr)
self.critic_optimizer = args.optimizer(self.critic.parameters(), lr=args.lr)
self.optimizers = (self.vln_bert_optimizer, self.critic_optimizer)
# Evaluations
self.losses = []
self.criterion = nn.CrossEntropyLoss(ignore_index=args.ignoreid, reduction='sum')
self.ndtw_criterion = utils.ndtw_initialize()
# Logs
sys.stdout.flush()
self.logs = defaultdict(list)
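# Note on _sort_batch below: instructions are sorted by length (descending) so downstream
# modules can assume a length-sorted batch; perm_idx records the permutation so the
# observations can be re-ordered to match it in rollout().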
def _sort_batch(self, obs):
seq_tensor = np.array([ob['instr_encoding'] for ob in obs])
seq_lengths = np.argmax(seq_tensor == padding_idx, axis=1) # Position of the first padding token
seq_lengths[seq_lengths == 0] = seq_tensor.shape[1] # No padding found -> the sequence fills the full width
seq_tensor = torch.from_numpy(seq_tensor)
seq_lengths = torch.from_numpy(seq_lengths)
# Sort sequences by lengths
seq_lengths, perm_idx = seq_lengths.sort(0, True) # True -> descending
sorted_tensor = seq_tensor[perm_idx]
mask = (sorted_tensor != padding_idx)
token_type_ids = torch.zeros_like(mask)
return Variable(sorted_tensor, requires_grad=False).long().cuda(), \
mask.long().cuda(), token_type_ids.long().cuda(), \
list(seq_lengths), list(perm_idx)
def _feature_variable(self, obs):
''' Extract precomputed features into variable. '''
features = np.empty((len(obs), args.views, self.feature_size + args.angle_feat_size), dtype=np.float32)
for i, ob in enumerate(obs):
features[i, :, :] = ob['feature'] # Image feat
return Variable(torch.from_numpy(features), requires_grad=False).cuda()
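# Candidate tensor layout used throughout this agent: for an observation with K navigable
# candidates, slots 0..K-1 hold the candidate features, slot K is the STOP token (all zeros)
# and slot K+1 is the NOT_FOUND token (all ones); hence candidate_leng = K + 2.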
def _candidate_variable(self, obs):
candidate_leng = [len(ob['candidate']) + 2 for ob in obs] # +2 reserves slots for the STOP and NOT_FOUND tokens
candidate_feat = np.zeros((len(obs), max(candidate_leng), self.feature_size + args.angle_feat_size), dtype=np.float32)
# Note: the slot at index len(ob['candidate']) is the STOP token, which stays all zeros;
# the slot at index len(ob['candidate'])+1 is the NOT_FOUND token, filled with ones below.
for i, ob in enumerate(obs):
for j, cc in enumerate(ob['candidate']):
candidate_feat[i, j, :] = cc['feature']
# Append the NOT_FOUND token as an all-ones feature vector
candidate_feat[i, len(ob['candidate'])+1, :] = np.ones((self.feature_size + args.angle_feat_size))
return torch.from_numpy(candidate_feat).cuda(), candidate_leng
def get_input_feat(self, obs):
input_a_t = np.zeros((len(obs), args.angle_feat_size), np.float32)
for i, ob in enumerate(obs):
input_a_t[i] = utils.angle_feature(ob['heading'], ob['elevation'])
input_a_t = torch.from_numpy(input_a_t).cuda()
# f_t = self._feature_variable(obs) # Pano image features from obs
candidate_feat, candidate_leng = self._candidate_variable(obs)
return input_a_t, candidate_feat, candidate_leng
def _teacher_action(self, obs, ended):
"""
Extract teacher actions into variable.
:param obs: The observation.
:param ended: Whether the action seq is ended
:return:
"""
a = np.zeros(len(obs), dtype=np.int64)
for i, ob in enumerate(obs):
if ended[i]: # Just ignore this index
a[i] = args.ignoreid
else:
for k, candidate in enumerate(ob['candidate']):
if candidate['viewpointId'] == ob['teacher']: # Next view point
a[i] = k
break
else: # No candidate matched the teacher viewpoint: stop here
assert ob['teacher'] == ob['viewpoint'] # The teacher action should be "STAY HERE"
if ob['swap']: # The instruction was swapped, so the teacher action is NOT_FOUND
a[i] = len(ob['candidate']) + 1 # NOT_FOUND slot (candidate_leng - 1)
else: # STOP
a[i] = len(ob['candidate']) # STOP slot (candidate_leng - 2)
print(" ", a)
return torch.from_numpy(a).cuda()
def make_equiv_action(self, a_t, perm_obs, perm_idx=None, traj=None, found=None):
"""
Interface between Panoramic view and Egocentric view
It will convert the action panoramic view action a_t to equivalent egocentric view actions for the simulator
"""
def take_action(i, idx, name):
if type(name) is int: # Go to the next view
self.env.env.sims[idx].makeAction([name], [0], [0])
else: # Adjust
self.env.env.sims[idx].makeAction(*self.env_actions[name])
if perm_idx is None:
perm_idx = range(len(perm_obs))
for i, idx in enumerate(perm_idx):
action = a_t[i]
# print('action: ', action)
if action != -1 and action != -2: # -1 is <stop>, -2 is <not_found>; neither moves the agent
select_candidate = perm_obs[i]['candidate'][action]
src_point = perm_obs[i]['viewIndex']
trg_point = select_candidate['pointId']
src_level = src_point // 12 # The point idx starts from 0; 12 headings per elevation level
trg_level = trg_point // 12
while src_level < trg_level: # Tune up
take_action(i, idx, 'up')
src_level += 1
while src_level > trg_level: # Tune down
take_action(i, idx, 'down')
src_level -= 1
while self.env.env.sims[idx].getState()[0].viewIndex != trg_point: # Turn right until the target
take_action(i, idx, 'right')
assert select_candidate['viewpointId'] == \
self.env.env.sims[idx].getState()[0].navigableLocations[select_candidate['idx']].viewpointId
take_action(i, idx, select_candidate['idx'])
state = self.env.env.sims[idx].getState()[0]
# print(state.rgb.shape)
# print("action: {} view_index: {}".format(action, state.viewIndex))
if traj is not None:
traj[i]['path'].append((state.location.viewpointId, state.heading, state.elevation))
else:
found[i] = action
'''
elif action == -1:
print('<STOP>')
elif action == -2:
print('<NOT_FOUND>')
'''
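# rollout() structure: (1) reset or read the observations, (2) run the language BERT once to
# encode the instruction, (3) at every step run the visual BERT over the candidate features,
# pick an action according to self.feedback (teacher / argmax / sample) and map it to
# simulator moves, (4) optionally accumulate shaped rewards, and (5) after the loop run an
# A2C update on the collected hidden states, log-probs and rewards.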
def rollout(self, train_ml=None, train_rl=True, reset=True):
"""
:param train_ml: The weight to train with maximum likelihood
:param train_rl: whether use RL in training
:param reset: Reset the environment
:return:
"""
print("ROLLOUT!!!")
if self.feedback == 'teacher' or self.feedback == 'argmax':
train_rl = False
# self.env is an `R2RBatch`
# Get the observation
if reset: # Reset env
obs = np.array(self.env.reset())
else:
obs = np.array(self.env._get_obs())
batch_size = len(obs)
# Language input
sentence, language_attention_mask, token_type_ids, \
seq_lengths, perm_idx = self._sort_batch(obs)
perm_obs = obs[perm_idx]
''' Language BERT '''
language_inputs = {'mode': 'language',
'sentence': sentence,
'attention_mask': language_attention_mask,
'lang_mask': language_attention_mask,
'token_type_ids': token_type_ids}
if args.vlnbert == 'oscar':
language_features = self.vln_bert(**language_inputs)
elif args.vlnbert == 'prevalent':
h_t, language_features = self.vln_bert(**language_inputs)
# Record starting point
traj = [{
'instr_id': ob['instr_id'],
'path': [(ob['viewpoint'], ob['heading'], ob['elevation'])],
} for ob in perm_obs]
found = [None for _ in range(len(perm_obs))]
# Init the reward shaping
last_dist = np.zeros(batch_size, np.float32)
last_ndtw = np.zeros(batch_size, np.float32)
for i, ob in enumerate(perm_obs): # The init distance from the view point to the target
last_dist[i] = ob['distance']
path_act = [vp[0] for vp in traj[i]['path']]
last_ndtw[i] = self.ndtw_criterion[ob['scan']](path_act, ob['gt_path'], metric='ndtw')
# Initialization the tracking state
ended = np.array([False] * batch_size) # Indices match the permutation used by the model, not the env
# Init the logs
rewards = []
hidden_states = []
policy_log_probs = []
masks = []
entropys = []
ml_loss = 0.
for t in range(self.episode_len):
input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs)
# the first [CLS] token, initialized by the language BERT, serves
# as the agent's state passing through time steps
if (t >= 1) or (args.vlnbert=='prevalent'):
language_features = torch.cat((h_t.unsqueeze(1), language_features[:,1:,:]), dim=1)
visual_temp_mask = (utils.length2mask(candidate_leng) == 0).long()
visual_attention_mask = torch.cat((language_attention_mask, visual_temp_mask), dim=-1)
self.vln_bert.vln_bert.config.directions = max(candidate_leng)
''' Visual BERT '''
visual_inputs = {'mode': 'visual',
'sentence': language_features,
'attention_mask': visual_attention_mask,
'lang_mask': language_attention_mask,
'vis_mask': visual_temp_mask,
'token_type_ids': token_type_ids,
'action_feats': input_a_t,
# 'pano_feats': f_t,
'cand_feats': candidate_feat}
h_t, logit = self.vln_bert(**visual_inputs)
hidden_states.append(h_t)
# Mask outputs where agent can't move forward
# Here the logit is [b, max_candidate]
# (8, max(candidate))
candidate_mask = utils.length2mask(candidate_leng)
logit.masked_fill_(candidate_mask, -float('inf'))
# Supervised training
target = self._teacher_action(perm_obs, ended)
for i, d in enumerate(target):
# print(perm_obs[i]['swap'], perm_obs[i]['instructions'])
# print(d)
_, at_t = logit.max(1)
'''
if at_t[i].item() == candidate_leng[i]-1:
print("-2")
elif at_t[i].item() == candidate_leng[i]-2:
print("-1")
else:
print(at_t[i].item())
print()
'''
ml_loss += self.criterion(logit, target)
a_predict = None
# Determine next model inputs
if self.feedback == 'teacher':
a_t = target # teacher forcing
_, a_predict = logit.max(1)
a_predict = a_predict.detach()
elif self.feedback == 'argmax':
_, a_t = logit.max(1) # student forcing - argmax
a_t = a_t.detach()
a_predict = a_t.detach()
log_probs = F.log_softmax(logit, 1) # Calculate the log_prob here
policy_log_probs.append(log_probs.gather(1, a_t.unsqueeze(1))) # Gather the log_prob for each batch
elif self.feedback == 'sample':
probs = F.softmax(logit, 1) # sampling an action from model
c = torch.distributions.Categorical(probs)
self.logs['entropy'].append(c.entropy().sum().item()) # For log
entropys.append(c.entropy()) # For optimization
new_c = c.sample()
a_t = new_c.detach()
a_predict = new_c.detach()
policy_log_probs.append(c.log_prob(a_t))
else:
# print(self.feedback)
sys.exit('Invalid feedback option')
# Prepare environment action
# NOTE: Env action is in the perm_obs space
cpu_a_t = a_t.cpu().numpy()
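# Decode the padded logit index into the environment action space:
# indices 0..K-1 keep their candidate meaning, index K (candidate_leng-2) becomes -1 (STOP),
# index K+1 (candidate_leng-1) becomes -2 (NOT_FOUND); already-ended episodes are forced to
# their terminal action.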
for i, next_id in enumerate(cpu_a_t):
if next_id == args.ignoreid or ended[i]: # Already ended or ignored
if found[i] == -1:
cpu_a_t[i] = -1 # Keep the recorded <stop>
else:
cpu_a_t[i] = -2 # Keep the recorded <not_found>
elif next_id == (candidate_leng[i]-2):
cpu_a_t[i] = -1 # Change the <end> and ignore action to -1
elif next_id == (candidate_leng[i]-1):
cpu_a_t[i] = -2
cpu_a_predict = a_predict.cpu().numpy()
for i, next_id in enumerate(cpu_a_predict):
if next_id == (candidate_leng[i]-2):
cpu_a_predict[i] = -1 # Change the <end> and ignore action to -1
elif next_id == (candidate_leng[i]-1):
cpu_a_predict[i] = -2
# Make action and get the new state
print(cpu_a_t)
self.make_equiv_action(cpu_a_t, perm_obs, perm_idx, traj, found=found)
obs = np.array(self.env._get_obs())
perm_obs = obs[perm_idx] # Permute the obs to keep them aligned with the model ordering
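# Reward shaping (per step, RL only): a STOP/NOT_FOUND action earns a target reward based on
# the final distance and nDTW (with swap-aware bonuses and penalties below); a navigation
# action earns a path-fidelity reward of +/-1 plus the change in nDTW, and an extra penalty
# applies when the agent walks away after having come within 1 meter of the target.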
if train_rl:
# Calculate the mask and reward
dist = np.zeros(batch_size, np.float32)
ndtw_score = np.zeros(batch_size, np.float32)
reward = np.zeros(batch_size, np.float32)
mask = np.ones(batch_size, np.float32)
for i, ob in enumerate(perm_obs):
dist[i] = ob['distance']
path_act = [vp[0] for vp in traj[i]['path']]
ndtw_score[i] = self.ndtw_criterion[ob['scan']](path_act, ob['gt_path'], metric='ndtw')
if ended[i]:
reward[i] = 0.0
mask[i] = 0.0
else:
action_idx = cpu_a_t[i]
# Target reward
if action_idx == -1: # If the action now is end
if dist[i] < 3.0: # Correct
reward[i] = 2.0 + ndtw_score[i] * 2.0
if ob['swap']:
reward[i] -= 2
else:
reward[i] += 1
else: # Incorrect
reward[i] = -2.0
elif action_idx == -2: # The NOT_FOUND reward is set here
if dist[i] < 3.0:
reward[i] = 2.0 + ndtw_score[i] * 2.0
if ob['swap']:
reward[i] += 3 # Detected the swapped (wrong) instruction: extra bonus
else:
reward[i] -= 2
else: # Incorrect
reward[i] = -2.0
reward[i] += 1 # distance > 3: nothing was really found, so soften the penalty from -2 to -1
else: # The action is not end
# Path fidelity rewards (distance & nDTW)
reward[i] = - (dist[i] - last_dist[i])
ndtw_reward = ndtw_score[i] - last_ndtw[i]
if reward[i] > 0.0: # Quantification
reward[i] = 1.0 + ndtw_reward
elif reward[i] < 0.0:
reward[i] = -1.0 + ndtw_reward
else:
raise NameError("The action doesn't change the move")
# Miss the target penalty
if (last_dist[i] <= 1.0) and (dist[i]-last_dist[i] > 0.0):
reward[i] -= (1.0 - last_dist[i]) * 2.0
rewards.append(reward)
masks.append(mask)
last_dist[:] = dist
last_ndtw[:] = ndtw_score
# Update the finished actions
# Both -1 (STOP) and -2 (NOT_FOUND) mark the episode as ended
ended[:] = np.logical_or(ended, (cpu_a_t == -1))
ended[:] = np.logical_or(ended, (cpu_a_t == -2))
# Early exit if all ended
if ended.all():
break
# print()
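# A2C update: bootstrap the return of unfinished episodes with the critic's value of the
# last state, then accumulate discounted returns backwards through time,
#     R_t = r_t + gamma * R_{t+1},   A_t = R_t - V(h_t),
# and optimize  -log pi(a_t | h_t) * A_t  for the policy plus  0.5 * (R_t - V(h_t))^2  for
# the critic (with an entropy bonus when actions were sampled).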
if train_rl:
# Last action in A2C
input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs)
language_features = torch.cat((h_t.unsqueeze(1), language_features[:,1:,:]), dim=1)
visual_temp_mask = (utils.length2mask(candidate_leng) == 0).long()
visual_attention_mask = torch.cat((language_attention_mask, visual_temp_mask), dim=-1)
self.vln_bert.vln_bert.config.directions = max(candidate_leng)
''' Visual BERT '''
visual_inputs = {'mode': 'visual',
'sentence': language_features,
'attention_mask': visual_attention_mask,
'lang_mask': language_attention_mask,
'vis_mask': visual_temp_mask,
'token_type_ids': token_type_ids,
'action_feats': input_a_t,
# 'pano_feats': f_t,
'cand_feats': candidate_feat}
last_h_, _ = self.vln_bert(**visual_inputs)
rl_loss = 0.
# NOW, A2C!!!
# Calculate the final discounted reward
last_value__ = self.critic(last_h_).detach() # The value estimate of the last state; detach to be safe
discount_reward = np.zeros(batch_size, np.float32) # The initial discounted reward is zero
for i in range(batch_size):
if not ended[i]: # If the action is not ended, use the value function as the last reward
discount_reward[i] = last_value__[i]
length = len(rewards)
total = 0
for t in range(length-1, -1, -1):
discount_reward = discount_reward * args.gamma + rewards[t] # If it ended, the reward will be 0
mask_ = Variable(torch.from_numpy(masks[t]), requires_grad=False).cuda()
clip_reward = discount_reward.copy()
r_ = Variable(torch.from_numpy(clip_reward), requires_grad=False).cuda()
v_ = self.critic(hidden_states[t])
a_ = (r_ - v_).detach()
rl_loss += (-policy_log_probs[t] * a_ * mask_).sum()
rl_loss += (((r_ - v_) ** 2) * mask_).sum() * 0.5 # 1/2 L2 loss
if self.feedback == 'sample':
rl_loss += (- 0.01 * entropys[t] * mask_).sum()
self.logs['critic_loss'].append((((r_ - v_) ** 2) * mask_).sum().item())
total = total + np.sum(masks[t])
self.logs['total'].append(total)
# Normalize the loss function
if args.normalize_loss == 'total':
rl_loss /= total
elif args.normalize_loss == 'batch':
rl_loss /= batch_size
else:
assert args.normalize_loss == 'none'
self.loss += rl_loss
self.logs['RL_loss'].append(rl_loss.item())
if train_ml is not None:
self.loss += ml_loss * train_ml / batch_size
self.logs['IL_loss'].append((ml_loss * train_ml / batch_size).item())
if type(self.loss) is int: # For safety, it will be activated if no losses are added
self.losses.append(0.)
else:
self.losses.append(self.loss.item() / self.episode_len) # Dividing by episode_len only rescales the logged value.
print("\n\n")
return traj, found
def test(self, use_dropout=False, feedback='argmax', allow_cheat=False, iters=None):
''' Evaluate once on each instruction in the current environment '''
self.feedback = feedback
if use_dropout:
self.vln_bert.train()
self.critic.train()
else:
self.vln_bert.eval()
self.critic.eval()
super(Seq2SeqAgent, self).test(iters)
def zero_grad(self):
self.loss = 0.
self.losses = []
for model, optimizer in zip(self.models, self.optimizers):
model.train()
optimizer.zero_grad()
def accumulate_gradient(self, feedback='teacher', **kwargs):
if feedback == 'teacher':
self.feedback = 'teacher'
self.rollout(train_ml=args.teacher_weight, train_rl=False, **kwargs)
elif feedback == 'sample':
self.feedback = 'teacher'
self.rollout(train_ml=args.ml_weight, train_rl=False, **kwargs)
self.feedback = 'sample'
self.rollout(train_ml=None, train_rl=True, **kwargs)
else:
assert False
def optim_step(self):
self.loss.backward()
torch.nn.utils.clip_grad_norm_(self.vln_bert.parameters(), 40.)
self.vln_bert_optimizer.step()
self.critic_optimizer.step()
def train(self, n_iters, feedback='teacher', **kwargs):
''' Train for a given number of iterations '''
self.feedback = feedback
self.vln_bert.train()
self.critic.train()
self.losses = []
for iter in range(1, n_iters + 1):
self.vln_bert_optimizer.zero_grad()
self.critic_optimizer.zero_grad()
self.loss = 0
if feedback == 'teacher':
self.feedback = 'teacher'
self.rollout(train_ml=args.teacher_weight, train_rl=False, **kwargs)
elif feedback == 'sample': # agents in IL and RL separately
if args.ml_weight != 0:
self.feedback = 'teacher'
self.rollout(train_ml=args.ml_weight, train_rl=False, **kwargs)
self.feedback = 'sample'
self.rollout(train_ml=None, train_rl=True, **kwargs)
else:
assert False
self.loss.backward()
torch.nn.utils.clip_grad_norm_(self.vln_bert.parameters(), 40.)
self.vln_bert_optimizer.step()
self.critic_optimizer.step()
if args.aug is None:
print_progress(iter, n_iters+1, prefix='Progress:', suffix='Complete', bar_length=50)
def save(self, epoch, path):
''' Snapshot models '''
the_dir, _ = os.path.split(path)
os.makedirs(the_dir, exist_ok=True)
states = {}
def create_state(name, model, optimizer):
states[name] = {
'epoch': epoch + 1,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
}
all_tuple = [("vln_bert", self.vln_bert, self.vln_bert_optimizer),
("critic", self.critic, self.critic_optimizer)]
for param in all_tuple:
create_state(*param)
torch.save(states, path)
def load(self, path):
''' Loads parameters (but not training state) '''
states = torch.load(path)
def recover_state(name, model, optimizer):
state = model.state_dict()
model_keys = set(state.keys())
load_keys = set(states[name]['state_dict'].keys())
if model_keys != load_keys:
print("NOTICE: DIFFERENT KEYS IN THE LISTEREN")
state.update(states[name]['state_dict'])
model.load_state_dict(state)
if args.loadOptim:
optimizer.load_state_dict(states[name]['optimizer'])
all_tuple = [("vln_bert", self.vln_bert, self.vln_bert_optimizer),
("critic", self.critic, self.critic_optimizer)]
for param in all_tuple:
recover_state(*param)
return states['vln_bert']['epoch'] - 1
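# A minimal usage sketch (not part of the original training script; assumes an R2RBatch
# environment `train_env`, a tokenizer `tok`, and `args` configured as in the repo's train.py):
#
#   agent = Seq2SeqAgent(train_env, results_path='results/run.json', tok=tok, episode_len=args.maxAction)
#   agent.train(n_iters=100, feedback=args.feedback)   # IL (+ RL when feedback == 'sample')
#   agent.test(use_dropout=False, feedback='argmax', iters=None)
#   agent.write_results()                              # dumps instr_id / trajectory / found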