# adversarial_AIRBERT/r2r_src/agent.py
# R2R-EnvDrop, 2019, haotan@cs.unc.edu
# Modified in Recurrent VLN-BERT, 2020, by Yicong.Hong@anu.edu.au
import os
import sys
import json
import numpy as np
import random
import math
import time
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from env import R2RBatch
import utils
from utils import padding_idx, print_progress
import model_OSCAR, model_PREVALENT, model_CA
from param import args
from collections import defaultdict
class BaseAgent(object):
''' Base class for an R2R agent to generate and save trajectories. '''
def __init__(self, env, results_path):
self.env = env
self.results_path = results_path
random.seed(1)
self.results = {}
self.detailed_results = {}
self.losses = [] # For learning agents
def write_results(self):
output = [{'instr_id': k, 'trajectory': v} for k, v in self.results.items()]
with open(self.results_path, 'w') as f:
json.dump(output, f)
def get_results(self, detailed=False):
if detailed:
output = [v for v in self.detailed_results.values()]
else:
output = [{'instr_id': k, 'trajectory': v} for k, v in self.results.items()]
return output
def rollout(self, **kwargs):
''' Return a list of dicts containing
instr_id:'xx', path:[(viewpointId, heading_rad, elevation_rad)] '''
raise NotImplementedError
@staticmethod
def get_agent(name):
return globals()[name+"Agent"]
def test(self, iters=None, **kwargs):
self.env.reset_epoch(shuffle=(iters is not None)) # If iters is not none, shuffle the env batch
self.losses = []
self.results = {}
self.detailed_results = {}
# We rely on env showing the entire batch before repeating anything
looped = False
self.loss = 0
if iters is not None:
# Run only the first 'iters' batches (the env batch order was shuffled above)
for i in range(iters):
for traj in self.rollout(**kwargs):
self.loss = 0
self.results[traj['instr_id']] = traj['path']
self.detailed_results[traj['instr_id']] = traj
else: # Do a full round
while True:
for traj in self.rollout(**kwargs):
if traj['instr_id'] in self.results:
looped = True
else:
self.loss = 0
self.results[traj['instr_id']] = traj['path']
self.detailed_results[traj['instr_id']] = traj
if looped:
break
class Seq2SeqAgent(BaseAgent):
''' An agent based on an LSTM seq2seq model with attention. '''
# For now, the agent can't pick which forward move to make - just the one in the middle
env_actions = {
'left': (0, -1, 0), # left
'right': (0, 1, 0), # right
'up': (0, 0, 1), # up
'down': (0, 0, -1), # down
'forward': (1, 0, 0), # forward
'<end>': (0, 0, 0), # <end>
'<start>': (0, 0, 0), # <start>
'<ignore>': (0, 0, 0) # <ignore>
}
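# Each tuple above is (navigable-location index, heading delta, elevation delta),
# in the argument order that gets unpacked into the simulator's makeAction call
# in make_equiv_action below.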
def __init__(self, env, results_path, tok, episode_len=20):
super(Seq2SeqAgent, self).__init__(env, results_path)
self.tok = tok
self.episode_len = episode_len
self.feature_size = self.env.feature_size
# Models
if args.vlnbert == 'oscar':
self.vln_bert = model_OSCAR.VLNBERT(feature_size=self.feature_size + args.angle_feat_size).cuda()
self.critic = model_OSCAR.Critic().cuda()
elif args.vlnbert == 'prevalent':
self.vln_bert = model_PREVALENT.VLNBERT(feature_size=self.feature_size + args.angle_feat_size).cuda()
self.critic = model_PREVALENT.Critic().cuda()
elif args.vlnbert == 'vilbert':
self.vln_bert = model_CA.VLNBERT(feature_size=self.feature_size + args.angle_feat_size).cuda()
self.critic = model_CA.Critic().cuda()
self.models = (self.vln_bert, self.critic)
# Optimizers
self.vln_bert_optimizer = args.optimizer(self.vln_bert.parameters(), lr=args.lr)
self.critic_optimizer = args.optimizer(self.critic.parameters(), lr=args.lr)
self.optimizers = (self.vln_bert_optimizer, self.critic_optimizer)
# Evaluations
self.losses = []
self.criterion = nn.CrossEntropyLoss(ignore_index=args.ignoreid, reduction='sum')
self.ndtw_criterion = utils.ndtw_initialize()
# Logs
sys.stdout.flush()
self.logs = defaultdict(list)
def _instruction_variable(self, obs):
seq_tensor = np.array([ob['instr_encoding'] for ob in obs])
# Find the first padding position; this assumes padding_idx matches the pad token id used by the BERT tokenizer
seq_lengths = np.argmax(seq_tensor == padding_idx, axis=1)
seq_lengths[seq_lengths == 0] = seq_tensor.shape[1] # No padding found: full length
max_seq_len = np.max(seq_lengths)
seq_tensor = torch.from_numpy(seq_tensor[:, :max_seq_len]).cuda()
seq_mask = (seq_tensor != padding_idx)
token_type_ids = torch.zeros_like(seq_tensor)
return seq_tensor, seq_mask, token_type_ids, seq_lengths
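# Returned values: token ids of shape (batch, max_len) on GPU, a boolean
# attention mask of the same shape, all-zero token_type_ids, and a numpy array
# of per-sample sequence lengths.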
def _feature_variable(self, obs):
''' Extract precomputed features into variable. '''
features = np.empty((len(obs), args.views, self.feature_size + args.angle_feat_size), dtype=np.float32)
for i, ob in enumerate(obs):
features[i, :, :] = ob['feature'] # Image feat
return torch.from_numpy(features).cuda()
def _candidate_variable(self, obs):
candidate_leng = [len(ob['candidate']) + 1 for ob in obs] # +1 is for the end
candidate_feat = np.zeros((len(obs), max(candidate_leng), self.feature_size + args.angle_feat_size), dtype=np.float32)
# Note: The candidate_feat at len(ob['candidate']) is the feature for the END
# which is zero in my implementation
for i, ob in enumerate(obs):
for j, cc in enumerate(ob['candidate']):
candidate_feat[i, j, :] = cc['feature']
return torch.from_numpy(candidate_feat).cuda(), candidate_leng
def get_input_feat(self, obs):
input_a_t = np.zeros((len(obs), args.angle_feat_size), np.float32)
for i, ob in enumerate(obs):
input_a_t[i] = utils.angle_feature(ob['heading'], ob['elevation'])
input_a_t = torch.from_numpy(input_a_t).cuda()
# Note: only the navigable-candidate features are fed to the model; the full panoramic features below are unused
# f_t = self._feature_variable(obs) # Pano image features from obs
candidate_feat, candidate_leng = self._candidate_variable(obs)
return input_a_t, candidate_feat, candidate_leng
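# input_a_t: (batch, angle_feat_size) encoding of the current heading/elevation.
# candidate_feat: (batch, max_candidates + 1, feature_size + angle_feat_size),
# where the trailing all-zero slot stands for the STOP action.
# candidate_leng: per-sample candidate counts (STOP slot included).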
def _teacher_action(self, obs, ended):
"""
Extract teacher actions into variable.
:param obs: The observation.
:param ended: Whether the action seq is ended
:return:
"""
a = np.zeros(len(obs), dtype=np.int64)
for i, ob in enumerate(obs):
if ended[i]: # Just ignore this index
a[i] = args.ignoreid
else:
for k, candidate in enumerate(ob['candidate']):
if candidate['viewpointId'] == ob['teacher']: # Next view point
a[i] = k
break
else: # Stop here
assert ob['teacher'] == ob['viewpoint'] # The teacher action should be "STAY HERE"
a[i] = len(ob['candidate'])
return torch.from_numpy(a).cuda()
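# Illustrative label layout: with 4 navigable candidates, labels 0..3 select a
# candidate viewpoint, label 4 (== len(ob['candidate'])) is the STOP slot added
# in _candidate_variable, and args.ignoreid marks already-ended episodes.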
def make_equiv_action(self, a_t, obs, traj=None):
"""
Interface between Panoramic view and Egocentric view
It will convert the action panoramic view action a_t to equivalent egocentric view actions for the simulator
"""
def take_action(i, name):
if type(name) is int: # Go to the next view
self.env.env.sims[i].makeAction(name, 0, 0)
else: # Adjust
self.env.env.sims[i].makeAction(*self.env_actions[name])
for i in range(len(obs)):
action = a_t[i]
if action != -1: # -1 is the <stop> action
select_candidate = obs[i]['candidate'][action]
src_point = obs[i]['viewIndex']
trg_point = select_candidate['pointId']
src_level = src_point // 12 # view indices start at 0, with 12 headings per elevation level
trg_level = trg_point // 12
while src_level < trg_level: # Tune up
take_action(i, 'up')
src_level += 1
while src_level > trg_level: # Tune down
take_action(i, 'down')
src_level -= 1
while self.env.env.sims[i].getState().viewIndex != trg_point: # Turn right until the target
take_action(i, 'right')
assert select_candidate['viewpointId'] == \
self.env.env.sims[i].getState().navigableLocations[select_candidate['idx']].viewpointId
take_action(i, select_candidate['idx'])
state = self.env.env.sims[i].getState()
if traj is not None:
traj[i]['path'].append((state.location.viewpointId, state.heading, state.elevation))
def rollout(self, train_ml=None, train_rl=True, reset=True):
"""
:param train_ml: The weight to train with maximum likelihood
:param train_rl: whether use RL in training
:param reset: Reset the environment
:return:
traj
:update:
self.loss, self.losses
"""
if self.feedback == 'teacher' or self.feedback == 'argmax':
train_rl = False
if reset: # Reset env
obs = np.array(self.env.reset())
else:
obs = np.array(self.env._get_obs())
batch_size = len(obs)
# Language input
sentence, language_attention_mask, token_type_ids, seq_lengths = self._instruction_variable(obs)
''' Language BERT '''
language_inputs = {'mode': 'language',
'sentence': sentence,
'attention_mask': language_attention_mask,
'lang_mask': language_attention_mask,
'token_type_ids': token_type_ids}
# (batch_size, seq_len, hidden_size)
if args.vlnbert == 'oscar':
language_features = self.vln_bert(**language_inputs)
elif args.vlnbert == 'prevalent':
h_t, language_features = self.vln_bert(**language_inputs)
elif args.vlnbert == 'vilbert':
h_t, language_features = self.vln_bert(**language_inputs)
language_attention_mask = language_attention_mask[:, 1:]
# Record starting point
traj = [{
'instr_id': ob['instr_id'],
'path': [(ob['viewpoint'], ob['heading'], ob['elevation'])],
'txt_attns': [],
'img_attns': [],
'act_cand_logits': [],
'act_cand_viewpoints': [],
} for ob in obs]
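# Besides the path, each traj entry buffers per-step attention weights and
# candidate logits; these lists are only filled in when args.submit is set
# (see the "record probabilities in results" block below).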
# Init the reward shaping
last_dist = np.zeros(batch_size, np.float32)
last_ndtw = np.zeros(batch_size, np.float32)
for i, ob in enumerate(obs): # The init distance from the view point to the target
last_dist[i] = ob['distance']
path_act = [vp[0] for vp in traj[i]['path']]
if ob['scan'] in self.ndtw_criterion:
last_ndtw[i] = self.ndtw_criterion[ob['scan']](path_act, ob['gt_path'], metric='ndtw')
# For test result submission: no backtracking
visited = [set() for _ in obs]
# Initialize the tracking state
ended = np.array([False] * batch_size) # Indices match the permutation of the model, not the env
# Init the logs
rewards = []
hidden_states = []
policy_log_probs = []
masks = []
entropys = []
ml_loss = 0.
for t in range(self.episode_len):
input_a_t, candidate_feat, candidate_leng = self.get_input_feat(obs)
# the first [CLS] token, initialized by the language BERT, serves
# as the agent's state passing through time steps
if (t >= 1 and args.vlnbert != 'vilbert') or (args.vlnbert == 'prevalent'):
language_features = torch.cat((h_t.unsqueeze(1), language_features[:,1:,:]), dim=1)
visual_temp_mask = (utils.length2mask(candidate_leng) == 0).bool()
visual_attention_mask = torch.cat((language_attention_mask, visual_temp_mask), dim=-1)
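# The joint attention mask is laid out as [language tokens | candidate slots],
# matching the language features and candidate features passed to the visual
# step below.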
self.vln_bert.vln_bert.config.directions = max(candidate_leng)
''' Visual BERT '''
visual_inputs = {'mode': 'visual',
'sentence': language_features,
'attention_mask': visual_attention_mask,
'lang_mask': language_attention_mask,
'token_type_ids': token_type_ids,
'action_feats': input_a_t,
# 'pano_feats': f_t,
'cand_feats': candidate_feat,
'cand_mask': visual_temp_mask,
}
if args.vlnbert == 'vilbert':
visual_inputs['h_t'] = h_t
h_t, logit, txt_attn_probs, img_attn_probs = self.vln_bert(**visual_inputs)
hidden_states.append(h_t)
# Mask outputs where agent can't move forward
# Here the logit is [b, max_candidate]
candidate_mask = utils.length2mask(candidate_leng)
logit.masked_fill_(candidate_mask.bool(), -float('inf'))
# Supervised training
target = self._teacher_action(obs, ended)
ml_loss += self.criterion(logit, target)
# Determine next model inputs
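# Feedback modes: 'teacher' = teacher forcing with ground-truth actions,
# 'argmax' = greedy student forcing, 'sample' = sample from the policy
# (used for the RL rollouts).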
if self.feedback == 'teacher':
a_t = target # teacher forcing
elif self.feedback == 'argmax':
_, a_t = logit.max(1) # student forcing - argmax
a_t = a_t.detach()
log_probs = F.log_softmax(logit, 1) # Calculate the log_prob here
policy_log_probs.append(log_probs.gather(1, a_t.unsqueeze(1))) # Gather the log_prob for each batch
elif self.feedback == 'sample':
probs = F.softmax(logit, 1) # sampling an action from model
c = torch.distributions.Categorical(probs)
self.logs['entropy'].append(c.entropy().sum().item()) # For log
entropys.append(c.entropy()) # For optimization
a_t = c.sample().detach()
policy_log_probs.append(c.log_prob(a_t))
else:
print(self.feedback)
sys.exit('Invalid feedback option')
# record probabilities in results
if args.submit:
cpu_img_attn_probs = img_attn_probs.data.cpu().numpy().tolist()
cpu_txt_attn_probs = txt_attn_probs.data.cpu().numpy().tolist()
cpu_logits = logit.data.cpu().numpy().tolist()
for ib in range(batch_size):
if not ended[ib]:
traj[ib]['img_attns'].append(cpu_img_attn_probs[ib][:candidate_leng[ib]])
traj[ib]['txt_attns'].append(cpu_txt_attn_probs[ib][:seq_lengths[ib]])
traj[ib]['act_cand_logits'].append(cpu_logits[ib][:candidate_leng[ib]])
traj[ib]['act_cand_viewpoints'].append([x['viewpointId'] for x in obs[ib]['candidate']])
# Prepare environment action
cpu_a_t = a_t.cpu().numpy()
for i, next_id in enumerate(cpu_a_t):
if next_id == (candidate_leng[i]-1) or next_id == args.ignoreid or ended[i]: # The last action is <end>
cpu_a_t[i] = -1 # Change the <end> and ignore action to -1
# Make action and get the new state
self.make_equiv_action(cpu_a_t, obs, traj)
obs = np.array(self.env._get_obs())
if train_rl:
# Calculate the mask and reward
dist = np.zeros(batch_size, np.float32)
ndtw_score = np.zeros(batch_size, np.float32)
reward = np.zeros(batch_size, np.float32)
mask = np.ones(batch_size, np.float32)
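# Reward shaping, as implemented below: stopping within 3m of the goal earns
# 2 + 2 * nDTW, stopping elsewhere earns -2; a non-stop action earns +1 or -1
# for moving toward/away from the goal plus the change in nDTW, and moving away
# after having been within 1m of the goal incurs an extra penalty.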
for i, ob in enumerate(obs):
dist[i] = ob['distance']
# only maintain viewpoints
path_act = [traj[i]['path'][0][0]]
for vp in traj[i]['path'][1:]:
if vp[0] != path_act[-1]:
path_act.append(vp[0])
ndtw_score[i] = self.ndtw_criterion[ob['scan']](path_act, ob['gt_path'], metric='ndtw')
if ended[i]:
reward[i] = 0.0
mask[i] = 0.0
else:
action_idx = cpu_a_t[i]
# Target reward
if action_idx == -1: # If the action now is end
if dist[i] < 3.0: # Correct
reward[i] = 2.0 + ndtw_score[i] * 2.0
else: # Incorrect
reward[i] = -2.0
else: # The action is not end
# Path fidelity rewards (distance & nDTW)
reward[i] = - (dist[i] - last_dist[i])
ndtw_reward = ndtw_score[i] - last_ndtw[i]
if reward[i] > 0.0: # Quantification
reward[i] = 1.0 + ndtw_reward
elif reward[i] < 0.0:
reward[i] = -1.0 + ndtw_reward
else:
raise NameError("The action did not change the distance to the target")
# Miss the target penalty
if (last_dist[i] <= 1.0) and (dist[i]-last_dist[i] > 0.0):
reward[i] -= (1.0 - last_dist[i]) * 2.0
rewards.append(reward)
masks.append(mask)
last_dist[:] = dist
last_ndtw[:] = ndtw_score
# Update the finished actions
# -1 means ended or ignored (already ended)
ended[:] = np.logical_or(ended, (cpu_a_t == -1))
# Early exit if all ended
if ended.all():
break
if train_rl:
# Last action in A2C: used for reward calculation when the path is not ended
input_a_t, candidate_feat, candidate_leng = self.get_input_feat(obs)
if args.vlnbert != 'vilbert':
language_features = torch.cat((h_t.unsqueeze(1), language_features[:,1:,:]), dim=1)
visual_temp_mask = (utils.length2mask(candidate_leng) == 0).bool()
visual_attention_mask = torch.cat((language_attention_mask, visual_temp_mask), dim=-1)
self.vln_bert.vln_bert.config.directions = max(candidate_leng)
''' Visual BERT '''
visual_inputs = {'mode': 'visual',
'sentence': language_features,
'attention_mask': visual_attention_mask,
'lang_mask': language_attention_mask,
'token_type_ids': token_type_ids,
'action_feats': input_a_t,
# 'pano_feats': f_t,
'cand_feats': candidate_feat,
'cand_mask': visual_temp_mask,
}
if args.vlnbert == 'vilbert':
visual_inputs['h_t'] = h_t
last_h_, _, _, _ = self.vln_bert(**visual_inputs)
rl_loss = 0. # policy-gradient loss plus critic loss (and an optional entropy bonus)
# A2C update
# Calculate the final discounted reward
last_value__ = self.critic(last_h_).detach() # Value estimate of the last state; detached so no gradient flows through it
discount_reward = np.zeros(batch_size, np.float32) # The initial discounted reward is zero
for i in range(batch_size):
if not ended[i]: # If the action is not ended, use the value function as the last reward
discount_reward[i] = last_value__[i]
length = len(rewards)
total = 0
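# A2C update, as implemented below: the return is accumulated backwards as
#   R_t = r_t + gamma * R_{t+1}
# (bootstrapped from the critic for unfinished episodes), the advantage is
# A_t = R_t - V(h_t) (detached), the policy loss is -log pi(a_t) * A_t, and the
# critic loss is 0.5 * (R_t - V(h_t))^2; both are masked for finished episodes.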
for t in range(length-1, -1, -1):
discount_reward = discount_reward * args.gamma + rewards[t] # If it ended, the reward will be 0
mask_ = torch.from_numpy(masks[t]).cuda()
clip_reward = discount_reward.copy()
r_ = torch.from_numpy(clip_reward).cuda()
v_ = self.critic(hidden_states[t])
a_ = (r_ - v_).detach()
# Policy-gradient term: -log p(action) * advantage, where advantage = discounted reward - value (higher reward is better)
rl_loss += (-policy_log_probs[t] * a_ * mask_).sum()
rl_loss += (((r_ - v_) ** 2) * mask_).sum() * 0.5 # 1/2 L2 loss
if self.feedback == 'sample':
rl_loss += (- 0.01 * entropys[t] * mask_).sum()
self.logs['critic_loss'].append((((r_ - v_) ** 2) * mask_).sum().item())
total = total + np.sum(masks[t])
self.logs['total'].append(total)
# Normalize the loss function
if args.normalize_loss == 'total':
rl_loss /= total
elif args.normalize_loss == 'batch':
rl_loss /= batch_size
else:
assert args.normalize_loss == 'none'
self.loss += rl_loss
self.logs['RL_loss'].append(rl_loss.item())
if train_ml is not None:
self.loss += ml_loss * train_ml / batch_size
self.logs['IL_loss'].append((ml_loss * train_ml / batch_size).item())
if type(self.loss) is int: # No loss was added this rollout, so self.loss is still the int 0
self.losses.append(0.)
else:
self.losses.append(self.loss.item() / self.episode_len) # Note: dividing by episode_len only affects the logged value
return traj
def test(self, use_dropout=False, feedback='argmax', allow_cheat=False, iters=None):
''' Evaluate once on each instruction in the current environment '''
self.feedback = feedback
if use_dropout:
self.vln_bert.train()
self.critic.train()
else:
self.vln_bert.eval()
self.critic.eval()
super(Seq2SeqAgent, self).test(iters)
def zero_grad(self):
self.loss = 0.
self.losses = []
for model, optimizer in zip(self.models, self.optimizers):
model.train()
optimizer.zero_grad()
def accumulate_gradient(self, feedback='teacher', **kwargs):
if feedback == 'teacher':
self.feedback = 'teacher'
self.rollout(train_ml=args.teacher_weight, train_rl=False, **kwargs)
elif feedback == 'sample':
self.feedback = 'teacher'
self.rollout(train_ml=args.ml_weight, train_rl=False, **kwargs)
self.feedback = 'sample'
self.rollout(train_ml=None, train_rl=True, **kwargs)
else:
assert False
def optim_step(self):
self.loss.backward()
torch.nn.utils.clip_grad_norm_(self.vln_bert.parameters(), 40.)
self.vln_bert_optimizer.step()
self.critic_optimizer.step()
def train(self, n_iters, feedback='teacher', **kwargs):
''' Train for a given number of iterations '''
self.feedback = feedback
self.vln_bert.train()
self.critic.train()
self.losses = []
for iter in range(1, n_iters + 1):
self.vln_bert_optimizer.zero_grad()
self.critic_optimizer.zero_grad()
self.loss = 0
if feedback == 'teacher':
self.feedback = 'teacher'
self.rollout(train_ml=args.teacher_weight, train_rl=False, **kwargs)
elif feedback == 'sample': # agents in IL and RL separately
if args.ml_weight != 0:
self.feedback = 'teacher'
self.rollout(train_ml=args.ml_weight, train_rl=False, **kwargs)
self.feedback = 'sample'
self.rollout(train_ml=None, train_rl=True, **kwargs)
else:
assert False
self.loss.backward()
torch.nn.utils.clip_grad_norm_(self.vln_bert.parameters(), 40.)
self.vln_bert_optimizer.step()
self.critic_optimizer.step()
if args.aug is None:
print_progress(iter, n_iters+1, prefix='Progress:', suffix='Complete', bar_length=50)
def save(self, epoch, path):
''' Snapshot models '''
the_dir, _ = os.path.split(path)
os.makedirs(the_dir, exist_ok=True)
states = {}
def create_state(name, model, optimizer):
states[name] = {
'epoch': epoch + 1,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
}
all_tuple = [("vln_bert", self.vln_bert, self.vln_bert_optimizer),
("critic", self.critic, self.critic_optimizer)]
for param in all_tuple:
create_state(*param)
torch.save(states, path)
def load(self, path):
''' Loads parameters (but not training state) '''
states = torch.load(path)
def recover_state(name, model, optimizer):
state = model.state_dict()
model_keys = set(state.keys())
load_keys = set(states[name]['state_dict'].keys())
if model_keys != load_keys:
print("NOTICE: DIFFERENT KEYS IN THE LISTEREN")
state.update(states[name]['state_dict'])
model.load_state_dict(state)
if args.loadOptim:
optimizer.load_state_dict(states[name]['optimizer'])
all_tuple = [("vln_bert", self.vln_bert, self.vln_bert_optimizer),
("critic", self.critic, self.critic_optimizer)]
for param in all_tuple:
recover_state(*param)
return states['vln_bert']['epoch'] - 1
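# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original pipeline; `train_env`,
# `tok` and the paths below are assumptions about how the surrounding code
# constructs the environment and tokenizer):
#
#   agent = Seq2SeqAgent(train_env, results_path='results/val.json', tok=tok, episode_len=20)
#   agent.train(n_iters=100, feedback='sample')              # IL + RL updates
#   agent.save(epoch=0, path='snap/agent/state_dict/latest') # snapshot models
#   agent.test(use_dropout=False, feedback='argmax')         # greedy evaluation
#   agent.write_results()
# ---------------------------------------------------------------------------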