adversarial_AIRBERT/reverie_src/agent.py


import json
import os
import sys
import numpy as np
import random
import math
import time
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from env import R2RBatch
import utils
from utils import padding_idx, add_idx, Tokenizer, print_progress
from param import args
from collections import defaultdict
import model_OSCAR, model_CA
class BaseAgent(object):
''' Base class for an R2R agent to generate and save trajectories. '''
def __init__(self, env, results_path):
self.env = env
self.results_path = results_path
random.seed(1)
self.results = {}
self.losses = [] # For learning agents
def write_results(self):
output = [{'instr_id': k, 'trajectory': v, 'predObjId': r} for k, (v, r, _) in self.results.items()]  # results store (path, predObjId, found) triples
with open(self.results_path, 'w') as f:
json.dump(output, f)
def get_results(self):
output = [{'instr_id': k, 'trajectory': v, 'predObjId': r, 'found': found} for k, (v,r, found) in self.results.items()]
return output
def rollout(self, **args):
''' Return a list of dicts containing instr_id:'xx', path:[(viewpointId, heading_rad, elevation_rad)] '''
raise NotImplementedError
@staticmethod
def get_agent(name):
return globals()[name+"Agent"]
def test(self, iters=None, **kwargs):
self.env.reset_epoch(shuffle=(iters is not None)) # If iters is not none, shuffle the env batch
self.losses = []
self.results = {}
# We rely on env showing the entire batch before repeating anything
looped = False
self.loss = 0
if iters is not None:
# Only run the first 'iters' iterations over the (already shuffled) env batch
for i in range(iters):
trajs, found = self.rollout(**kwargs)
for index, traj in enumerate(trajs):
self.loss = 0
self.results[traj['instr_id']] = (traj['path'], traj['predObjId'], found[index])
else: # Do a full round
while True:
trajs, found = self.rollout(**kwargs)
for index, traj in enumerate(trajs):
if traj['instr_id'] in self.results:
looped = True
else:
self.loss = 0
self.results[traj['instr_id']] = (traj['path'], traj['predObjId'], found[index])
if looped:
break
class Seq2SeqAgent(BaseAgent):
''' An agent based on an LSTM seq2seq model with attention. '''
# For now, the agent can't pick which forward move to make - just the one in the middle
env_actions = {
'left': (0,-1, 0), # left
'right': (0, 1, 0), # right
'up': (0, 0, 1), # up
'down': (0, 0,-1), # down
'forward': (1, 0, 0), # forward
'<end>': (0, 0, 0), # <end>
'<start>': (0, 0, 0), # <start>
'<ignore>': (0, 0, 0) # <ignore>
}
def __init__(self, env, results_path, tok, episode_len=20):
super(Seq2SeqAgent, self).__init__(env, results_path)
self.tok = tok
self.episode_len = episode_len
self.feature_size = self.env.feature_size
# Models
if args.vlnbert == 'oscar':
self.vln_bert = model_OSCAR.VLNBERT(directions=args.directions,
feature_size=self.feature_size + args.angle_feat_size).cuda()
self.critic = model_OSCAR.Critic().cuda()
elif args.vlnbert == 'vilbert':
self.vln_bert = model_CA.VLNBERT(
feature_size=self.feature_size + args.angle_feat_size).cuda()
self.critic = model_CA.Critic().cuda()
# Optimizers
self.vln_bert_optimizer = args.optimizer(self.vln_bert.parameters(), lr=args.lr)
self.critic_optimizer = args.optimizer(self.critic.parameters(), lr=args.lr)
self.models = (self.vln_bert, self.critic)  # zero_grad() iterates over self.models together with self.optimizers
self.optimizers = (self.vln_bert_optimizer, self.critic_optimizer)
# Evaluations
self.losses = []
self.criterion = nn.CrossEntropyLoss(ignore_index=args.ignoreid, reduction='sum')  # reduction='sum' replaces the deprecated size_average=False
self.criterion_REF = nn.CrossEntropyLoss(ignore_index=args.ignoreid, reduction='sum')
# self.ndtw_criterion = utils.ndtw_initialize()
self.objProposals, self.obj2viewpoint = utils.loadObjProposals()
# Logs
sys.stdout.flush()
self.logs = defaultdict(list)
def _sort_batch(self, obs, sorted_instr=True):
''' Extract instructions from a list of observations and sort by descending
sequence length (to enable PyTorch packing). '''
seq_tensor = np.array([ob['instr_encoding'] for ob in obs])
seq_lengths = np.argmax(seq_tensor == padding_idx, axis=1)
seq_lengths[seq_lengths == 0] = seq_tensor.shape[1] # Full length
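# NOTE: np.argmax over the boolean (seq_tensor == padding_idx) returns the index of the
# first padding token, i.e. the true sequence length; rows without any padding yield 0,
# which the line above resets to the full tensor width.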
seq_tensor = torch.from_numpy(seq_tensor)
seq_lengths = torch.from_numpy(seq_lengths)
# Sort sequences by lengths
if sorted_instr:
seq_lengths, perm_idx = seq_lengths.sort(0, True) # True -> descending
sorted_tensor = seq_tensor[perm_idx]
perm_idx = list(perm_idx)
else:
sorted_tensor = seq_tensor
perm_idx = None
mask = (sorted_tensor != padding_idx) # seq_lengths[0] is the Maximum length
token_type_ids = torch.zeros_like(mask)
visual_mask = torch.ones(args.directions).bool()
visual_mask = visual_mask.unsqueeze(0).repeat(mask.size(0),1)
visual_mask = torch.cat((mask, visual_mask), -1)
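# The joint attention mask is laid out as [text tokens | args.directions visual slots]:
# text positions follow the padding mask, while all visual slots are left visible here
# (per-step candidate/object masks are rebuilt inside rollout()).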
return sorted_tensor.long().cuda(), \
mask.bool().cuda(), token_type_ids.long().cuda(), \
visual_mask.bool().cuda(), \
list(seq_lengths), perm_idx
def _feature_variable(self, obs):
''' Extract precomputed features into variable. '''
features = np.empty((len(obs), args.directions, self.feature_size + args.angle_feat_size), dtype=np.float32)
for i, ob in enumerate(obs):
features[i, :, :] = ob['feature'] # Image feat
return torch.from_numpy(features).cuda()
def _candidate_variable(self, obs):
candidate_leng = [len(ob['candidate']) for ob in obs]
candidate_feat = np.zeros((len(obs), max(candidate_leng), self.feature_size + args.angle_feat_size), dtype=np.float32)
# Note: The candidate_feat at len(ob['candidate']) is the feature for the END
# which is zero in my implementation
for i, ob in enumerate(obs):
for j, cc in enumerate(ob['candidate']):
candidate_feat[i, j, :] = cc['feature']
result = torch.from_numpy(candidate_feat)
'''
for i, ob in enumerate(obs):
result[i, len(ob['candidate']), :] = torch.ones((self.feature_size + args.angle_feat_size), dtype=torch.float32)
'''
result = result.cuda()
return result, candidate_leng
def _object_variable(self, obs):
cand_obj_leng = [len(ob['candidate_obj'][2]) + 1 for ob in obs] # +1 is for no REF
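# The extra (+1) slot acts as a "no target object" option: if the grounding head picks
# the last index, the agent declares that the referred object was not found.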
if args.vlnbert == 'vilbert':
cand_obj_feat = np.zeros((len(obs), max(cand_obj_leng), self.feature_size + 4), dtype=np.float32)
elif args.vlnbert == 'oscar':
cand_obj_feat = np.zeros((len(obs), max(cand_obj_leng), self.feature_size + args.angle_feat_size), dtype=np.float32)
cand_obj_pos = np.zeros((len(obs), max(cand_obj_leng), 5), dtype=np.float32)
for i, ob in enumerate(obs):
obj_local_pos, obj_features, candidate_objId = ob['candidate_obj']
for j, cc in enumerate(candidate_objId):
cand_obj_feat[i, j, :] = obj_features[j]
cand_obj_pos[i, j, :] = obj_local_pos[j]
return torch.from_numpy(cand_obj_feat).cuda(), torch.from_numpy(cand_obj_pos).cuda(), cand_obj_leng
def get_input_feat(self, obs):
input_a_t = np.zeros((len(obs), args.angle_feat_size), np.float32)
for i, ob in enumerate(obs):
input_a_t[i] = utils.angle_feature(ob['heading'], ob['elevation'])
input_a_t = torch.from_numpy(input_a_t).cuda()
# f_t = self._feature_variable(obs) # Image features from obs
f_t = None
candidate_feat, candidate_leng = self._candidate_variable(obs)
obj_feat, obj_pos, obj_leng = self._object_variable(obs)
return input_a_t, f_t, candidate_feat, candidate_leng, obj_feat, obj_pos, obj_leng
def _teacher_action(self, obs, ended, cand_size, candidate_leng):
"""
Extract teacher actions into variable.
:param obs: The observation.
:param ended: Whether the action seq is ended
:return:
"""
a = np.zeros(len(obs), dtype=np.int64)
for i, ob in enumerate(obs):
if ended[i]: # Just ignore this index
a[i] = args.ignoreid
else:
for k, candidate in enumerate(ob['candidate']):
if candidate['viewpointId'] == ob['teacher']: # Next view point
a[i] = k
break
else: # Stop here
assert ob['teacher'] == ob['viewpoint'] # The teacher action should be "STAY HERE"
a[i] = cand_size - 1
'''
if ob['found']:
a[i] = cand_size - 1
else:
a[i] = candidate_leng[i] - 1
'''
return torch.from_numpy(a).cuda()
def _teacher_REF(self, obs, just_ended):
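# REF (referring-expression) supervision is only given at the step where an episode has
# just ended: the label is the index of the ground-truth objId among the candidate
# objects if it was found, and len(candidate_objs) (the "no object" slot) if it was not;
# samples whose target is absent from the candidates, or that have not just ended, are
# set to ignoreid and skipped by the loss.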
a = np.zeros(len(obs), dtype=np.int64)
for i, ob in enumerate(obs):
if not just_ended[i]: # Just ignore this index
a[i] = args.ignoreid
else:
candidate_objs = ob['candidate_obj'][2]
for k, kid in enumerate(candidate_objs):
if kid == ob['objId']:
if ob['found']:
a[i] = k
break
else:
a[i] = len(candidate_objs)
break
else:
a[i] = args.ignoreid
return torch.from_numpy(a).cuda()
def make_equiv_action(self, a_t, perm_obs, perm_idx=None, traj=None):
"""
Interface between Panoramic view and Egocentric view
It converts the panoramic-view action a_t into the equivalent egocentric-view actions for the simulator.
"""
def take_action(i, idx, name):
if isinstance(name, int): # Go to the next view
self.env.env.sims[idx].makeAction([name], [0], [0])
else: # Adjust
action_params = self.env_actions[name]
self.env.env.sims[idx].makeAction([action_params[0]], [action_params[1]], [action_params[2]])
if perm_idx is None:
perm_idx = range(len(perm_obs))
for i, idx in enumerate(perm_idx):
action = a_t[i]
if action != -1 and action != -2: # -1 is the <stop> action
select_candidate = perm_obs[i]['candidate'][action]
src_point = perm_obs[i]['viewIndex']
trg_point = select_candidate['pointId']
src_level = src_point // 12 # the view index starts from 0
trg_level = trg_point // 12
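# The panorama is discretised into 36 views (12 headings x 3 elevation levels), so
# integer division of the view index by 12 gives the elevation row; the loops below
# first align the elevation, then rotate right until the target heading is reached.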
while src_level < trg_level: # Tune up
take_action(i, idx, 'up')
src_level += 1
while src_level > trg_level: # Tune down
take_action(i, idx, 'down')
src_level -= 1
while self.env.env.sims[idx].getState()[0].viewIndex != trg_point: # Turn right until the target
take_action(i, idx, 'right')
assert select_candidate['viewpointId'] == \
self.env.env.sims[idx].getState()[0].navigableLocations[select_candidate['idx']].viewpointId
take_action(i, idx, select_candidate['idx'])
state = self.env.env.sims[idx].getState()[0]
if traj is not None:
traj[i]['path'].append((state.location.viewpointId, state.heading, state.elevation))
def rollout(self, train_ml=None, train_rl=True, reset=True, speaker=None):
"""
:param train_ml: The weight to train with maximum likelihood
:param train_rl: whether to use RL in training
:param reset: Reset the environment
:param speaker: Speaker used in back translation.
If the speaker is not None, use back translation;
otherwise, normal training.
:return:
"""
if self.feedback == 'teacher' or self.feedback == 'argmax':
train_rl = False
if reset: # Reset env
obs = np.array(self.env.reset())
else:
obs = np.array(self.env._get_obs())
batch_size = len(obs)
# Reorder the language input for the encoder (do not ruin the original code)
sentence, language_attention_mask, token_type_ids, \
visual_attention_mask, seq_lengths, perm_idx = self._sort_batch(obs)
perm_obs = obs[perm_idx]
''' Language BERT '''
language_inputs = {'mode': 'language',
'sentence': sentence,
'token_type_ids': token_type_ids}
# (batch_size, seq_len, hidden_size)
if args.vlnbert == 'oscar':
language_inputs['attention_mask'] = language_attention_mask
language_features = self.vln_bert(**language_inputs)
elif args.vlnbert == 'vilbert':
language_inputs['lang_masks'] = language_attention_mask
h_t, language_features = self.vln_bert(**language_inputs)
language_attention_mask = language_attention_mask[:, 1:]
# Record starting point
traj = [{
'instr_id': ob['instr_id'],
'path': [(ob['viewpoint'], ob['heading'], ob['elevation'])],
'predObjId': None
} for ob in perm_obs]
# Init the reward shaping
last_dist = np.zeros(batch_size, np.float32)
# last_ndtw = np.zeros(batch_size, np.float32)
for i, ob in enumerate(perm_obs): # The init distance from the view point to the target
last_dist[i] = ob['distance']
path_act = [vp[0] for vp in traj[i]['path']]
# last_ndtw[i] = self.ndtw_criterion[ob['scan']](path_act, ob['gt_path'], metric='ndtw')
# Initialization the tracking state
ended = np.array([False] * batch_size) # Indices match the permutation of the model, not the env
just_ended = np.array([False] * batch_size)
found = np.array([None] * batch_size)
# Init the logs
rewards = []
hidden_states = []
policy_log_probs = []
masks = []
stop_mask = torch.tensor([False] * batch_size).cuda().unsqueeze(1)
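# stop_mask is a single always-False column appended to candidate_mask inside the step
# loop below, so the extra logit column (the STOP action) is never masked out.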
entropys = []
ml_loss = 0.
ref_loss = 0.
# For test result submission: no backtracking
visited = [set() for _ in range(batch_size)]
for t in range(self.episode_len):
input_a_t, f_t, candidate_feat, candidate_leng, obj_feat, obj_pos, obj_leng = self.get_input_feat(perm_obs)
# the first [CLS] token, initialized by the language BERT, serves
# as the agent's state, carried across time steps
if args.vlnbert != 'vilbert' and t >= 1:
language_features = torch.cat((h_t.unsqueeze(1), language_features[:,1:,:]), dim=1)
visual_temp_mask = (utils.length2mask(candidate_leng) == 0).bool()
obj_temp_mask = (utils.length2mask(obj_leng) == 0).bool()
visual_attention_mask = torch.cat((language_attention_mask, visual_temp_mask, obj_temp_mask), dim=-1)
self.vln_bert.vln_bert.config.directions = max(candidate_leng)
self.vln_bert.vln_bert.config.obj_directions = max(obj_leng)
''' Visual BERT '''
visual_inputs = {'mode': 'visual',
'sentence': language_features,
'token_type_ids': token_type_ids,
'action_feats': input_a_t,
'pano_feats': f_t,
'cand_feats': candidate_feat,
'obj_feats': obj_feat,
'obj_pos': obj_pos,
'already_dropfeat': (speaker is not None)}
if args.vlnbert == 'oscar':
visual_attention_mask = torch.cat((language_attention_mask, visual_temp_mask, obj_temp_mask), dim=-1)
visual_inputs['attention_mask'] = visual_attention_mask
elif args.vlnbert == 'vilbert':
visual_inputs.update({
'h_t': h_t,
'lang_masks': language_attention_mask,
'cand_masks': visual_temp_mask,
'obj_masks': obj_temp_mask,
'act_t': t,
})
h_t, logit, logit_REF = self.vln_bert(**visual_inputs)
hidden_states.append(h_t)
# Mask outputs where agent can't move forward
# Here the logit is [b, max_candidate]
candidate_mask = utils.length2mask(candidate_leng)
candidate_mask = torch.cat((candidate_mask, stop_mask), dim=-1)
logit.masked_fill_(candidate_mask, -float('inf'))
candidate_mask_obj = utils.length2mask(obj_leng)
logit_REF.masked_fill_(candidate_mask_obj, -float('inf'))
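# Filling masked positions with -inf drives their softmax probability to zero, so padded
# candidates/objects can never be chosen by argmax or sampling.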
if train_ml is not None:
# Supervised training
target = self._teacher_action(perm_obs, ended, candidate_mask.size(1), candidate_leng)
ml_loss += self.criterion(logit, target)
# Determine next model inputs
if self.feedback == 'teacher':
a_t = target # teacher forcing
elif self.feedback == 'argmax':
_, a_t = logit.max(1) # student forcing - argmax
a_t = a_t.detach()
log_probs = F.log_softmax(logit, 1) # Calculate the log_prob here
policy_log_probs.append(log_probs.gather(1, a_t.unsqueeze(1))) # Gather the log_prob for each batch
elif self.feedback == 'sample':
probs = F.softmax(logit, 1) # sampling an action from model
c = torch.distributions.Categorical(probs)
self.logs['entropy'].append(c.entropy().sum().item()) # For log
entropys.append(c.entropy()) # For optimization
a_t = c.sample().detach()
policy_log_probs.append(c.log_prob(a_t))
else:
print(self.feedback)
sys.exit('Invalid feedback option')
# Prepare environment action
# NOTE: Env action is in the perm_obs space
cpu_a_t = a_t.cpu().numpy()
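# The loop below flags, per sample, whether this step ends the episode (the appended
# STOP slot or the sample's last candidate index was predicted, or the horizon was
# reached) and, at that stopping step, reads out the referring-expression (REF)
# prediction into traj.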
for i, next_id in enumerate(cpu_a_t):
if ((next_id == visual_temp_mask.size(1)) or (next_id == (candidate_leng[i]-1)) or (t == self.episode_len-1)) \
and (not ended[i]): # just stopped, or forced to stop at the last step
just_ended[i] = True
if self.feedback == 'argmax':
_, ref_t = logit_REF[i].max(0)
if ref_t != obj_leng[i]-1: # decide not to do REF
traj[i]['predObjId'] = perm_obs[i]['candidate_obj'][2][ref_t]
else:
traj[i]['ref'] = 'NOT_FOUND'
if args.submit:
if obj_leng[i] == 1:
traj[i]['predObjId'] = int("0")
else:
_, ref_t = logit_REF[i][:obj_leng[i]-1].max(0)
try:
traj[i]['predObjId'] = int(perm_obs[i]['candidate_obj'][2][ref_t])
except:
import pdb; pdb.set_trace()
else:
just_ended[i] = False
if (next_id == args.ignoreid) or (ended[i]):
cpu_a_t[i] = found[i]
elif (next_id == visual_temp_mask.size(1)):
cpu_a_t[i] = -1
found[i] = -1
if self.feedback == 'argmax':
_, ref_t = logit_REF[i].max(0) # per-sample REF prediction decides whether the object was found
if ref_t == obj_leng[i]-1:
found[i] = -2
else:
found[i] = -1
''' Supervised training for REF '''
if train_ml is not None:
target_obj = self._teacher_REF(perm_obs, just_ended)
ref_loss += self.criterion_REF(logit_REF, target_obj)
# Make action and get the new state
self.make_equiv_action(cpu_a_t, perm_obs, perm_idx, traj)
obs = np.array(self.env._get_obs())
perm_obs = obs[perm_idx] # Permute the obs to match the model ordering
if train_rl:
# Calculate the mask and reward
dist = np.zeros(batch_size, np.float32)
# ndtw_score = np.zeros(batch_size, np.float32)
reward = np.zeros(batch_size, np.float32)
mask = np.ones(batch_size, np.float32)
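# Reward shaping (a sketch of what the loop below implements): a terminal reward of
# +2/-2 depending on whether the agent stops within the 1.0 distance threshold of the
# target, a +/-1 step reward from the sign of the change in distance-to-goal, and an
# extra penalty for moving away again after having been within that threshold.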
for i, ob in enumerate(perm_obs):
dist[i] = ob['distance']
# path_act = [vp[0] for vp in traj[i]['path']]
# ndtw_score[i] = self.ndtw_criterion[ob['scan']](path_act, ob['gt_path'], metric='ndtw')
if ended[i]:
reward[i] = 0.0
mask[i] = 0.0
else:
action_idx = cpu_a_t[i]
if action_idx == -1: # If the action now is end
# navigation success if the target object is visible when STOP
# end_viewpoint_id = ob['scan'] + '_' + ob['viewpoint']
# if self.objProposals.__contains__(end_viewpoint_id):
# if ob['objId'] in self.objProposals[end_viewpoint_id]['objId']:
# reward[i] = 2.0 + ndtw_score[i] * 2.0
# else:
# reward[i] = -2.0
# else:
# reward[i] = -2.0
if dist[i] < 1.0: # Correct
reward[i] = 2.0 # + ndtw_score[i] * 2.0
else: # Incorrect
reward[i] = -2.0
else: # The action is not end
# Change of distance and nDTW reward
reward[i] = - (dist[i] - last_dist[i])
# ndtw_reward = ndtw_score[i] - last_ndtw[i]
if reward[i] > 0.0: # Quantification
reward[i] = 1.0 # + ndtw_reward
elif reward[i] < 0.0:
reward[i] = -1.0 # + ndtw_reward
else:
raise NameError("The action doesn't change the move")
# miss the target penalty
if (last_dist[i] <= 1.0) and (dist[i]-last_dist[i] > 0.0):
reward[i] -= (1.0 - last_dist[i]) * 2.0
rewards.append(reward)
masks.append(mask)
last_dist[:] = dist
# last_ndtw[:] = ndtw_score
# Update the finished actions
# -1 means ended or ignored (already ended)
ended[:] = np.logical_or(ended, (cpu_a_t == -1))
# Early exit if all ended
if ended.all():
break
if train_rl:
# Last action in A2C
input_a_t, f_t, candidate_feat, candidate_leng, obj_feat, obj_pos, obj_leng = self.get_input_feat(perm_obs)
if args.vlnbert != 'vilbert':
language_features = torch.cat((h_t.unsqueeze(1), language_features[:,1:,:]), dim=1)
visual_temp_mask = (utils.length2mask(candidate_leng) == 0).bool()
obj_temp_mask = (utils.length2mask(obj_leng) == 0).bool()
visual_attention_mask = torch.cat((language_attention_mask, visual_temp_mask, obj_temp_mask), dim=-1)
self.vln_bert.vln_bert.config.directions = max(candidate_leng)
self.vln_bert.vln_bert.config.obj_directions = max(obj_leng)
''' Visual BERT '''
visual_inputs = {'mode': 'visual',
'sentence': language_features,
'token_type_ids': token_type_ids,
'action_feats': input_a_t,
'pano_feats': f_t,
'cand_feats': candidate_feat,
'obj_feats': obj_feat,
'obj_pos': obj_pos,
'already_dropfeat': (speaker is not None)}
if args.vlnbert == 'oscar':
visual_attention_mask = torch.cat((language_attention_mask, visual_temp_mask, obj_temp_mask), dim=-1)
visual_inputs['attention_mask'] = visual_attention_mask
elif args.vlnbert == 'vilbert':
visual_inputs.update({
'h_t': h_t,
'lang_masks': language_attention_mask,
'cand_masks': visual_temp_mask,
'obj_masks': obj_temp_mask,
'act_t': len(hidden_states),
})
last_h_, _, _ = self.vln_bert(**visual_inputs)
rl_loss = 0.
# NOW, A2C!!!
# Calculate the final discounted reward
last_value__ = self.critic(last_h_).detach() # value estimate of the last state; detached for safety
discount_reward = np.zeros(batch_size, np.float32) # the initial reward is zero
for i in range(batch_size):
if not ended[i]: # If the action is not ended, use the value function as the last reward
discount_reward[i] = last_value__[i]
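# A2C bootstrapping: episodes still running at the horizon use the critic's value
# estimate of the final state as the tail return, while finished episodes start from
# zero. The loop below then accumulates discounted returns backwards in time,
# R_t = r_t + gamma * R_{t+1}.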
length = len(rewards)
total = 0
for t in range(length-1, -1, -1):
discount_reward = discount_reward * args.gamma + rewards[t] # If it ended, the reward will be 0
mask_ = torch.from_numpy(masks[t]).cuda()
clip_reward = discount_reward.copy()
r_ = torch.from_numpy(clip_reward).cuda()
v_ = self.critic(hidden_states[t])
a_ = (r_ - v_).detach()
# policy-gradient term: -log p(action) * advantage, where advantage = discounted return - value
rl_loss += (-policy_log_probs[t] * a_ * mask_).sum()
rl_loss += (((r_ - v_) ** 2) * mask_).sum() * 0.5 # 1/2 L2 loss
if self.feedback == 'sample':
rl_loss += (- 0.01 * entropys[t] * mask_).sum()
self.logs['critic_loss'].append((((r_ - v_) ** 2) * mask_).sum().item())
total = total + np.sum(masks[t])
self.logs['total'].append(total)
# Normalize the loss function
if args.normalize_loss == 'total':
rl_loss /= total
elif args.normalize_loss == 'batch':
rl_loss /= batch_size
else:
assert args.normalize_loss == 'none'
self.loss += rl_loss
self.logs['RL_loss'].append(rl_loss.item())
if train_ml is not None:
self.loss += ml_loss * train_ml / batch_size
self.logs['IL_loss'].append((ml_loss * train_ml / batch_size).item())
self.loss += ref_loss * args.ref_loss_weight / batch_size
self.logs['REF_loss'].append(ref_loss.item() * args.ref_loss_weight / batch_size)
if type(self.loss) is int: # For safety, it will be activated if no losses are added
self.losses.append(0.)
else:
self.losses.append(self.loss.item() / self.episode_len) # This argument is useless.
return traj, found
def test(self, use_dropout=False, feedback='argmax', allow_cheat=False, iters=None):
''' Evaluate once on each instruction in the current environment '''
self.feedback = feedback
if use_dropout:
self.vln_bert.train()
self.critic.train()
else:
self.vln_bert.eval()
self.critic.eval()
super(Seq2SeqAgent, self).test(iters)
def zero_grad(self):
self.loss = 0.
self.losses = []
for model, optimizer in zip(self.models, self.optimizers):
model.train()
optimizer.zero_grad()
def accumulate_gradient(self, feedback='teacher', **kwargs):
if feedback == 'teacher':
self.feedback = 'teacher'
self.rollout(train_ml=args.teacher_weight, train_rl=False, **kwargs)
elif feedback == 'sample':
self.feedback = 'teacher'
self.rollout(train_ml=args.ml_weight, train_rl=False, **kwargs)
self.feedback = 'sample'
self.rollout(train_ml=None, train_rl=True, **kwargs)
else:
assert False
def optim_step(self):
self.loss.backward()
torch.nn.utils.clip_grad_norm_(self.vln_bert.parameters(), 40.)
self.vln_bert_optimizer.step()
self.critic_optimizer.step()
def train(self, n_iters, feedback='teacher', **kwargs):
''' Train for a given number of iterations '''
self.feedback = feedback
self.vln_bert.train()
self.critic.train()
self.losses = []
for iter in range(1, n_iters + 1):
self.vln_bert_optimizer.zero_grad()
self.critic_optimizer.zero_grad()
self.loss = 0
if feedback == 'teacher':
self.feedback = 'teacher'
self.rollout(train_ml=args.teacher_weight, train_rl=False, **kwargs)
elif feedback == 'sample': # agents in IL and RL separately
if args.ml_weight != 0:
self.feedback = 'teacher'
self.rollout(train_ml=args.ml_weight, train_rl=False, **kwargs)
self.feedback = 'sample'
self.rollout(train_ml=None, train_rl=True, **kwargs)
else:
assert False
self.loss.backward()
torch.nn.utils.clip_grad_norm_(self.vln_bert.parameters(), 40.)
self.vln_bert_optimizer.step()
self.critic_optimizer.step()
if args.aug is None:
print_progress(iter, n_iters, prefix='Progress:', suffix='Complete', bar_length=50)
def save(self, epoch, path):
''' Snapshot models '''
the_dir, _ = os.path.split(path)
os.makedirs(the_dir, exist_ok=True)
states = {}
def create_state(name, model, optimizer):
states[name] = {
'epoch': epoch + 1,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
}
all_tuple = [("vln_bert", self.vln_bert, self.vln_bert_optimizer),
("critic", self.critic, self.critic_optimizer)]
for param in all_tuple:
create_state(*param)
torch.save(states, path)
def load(self, path):
''' Loads parameters (but not training state) '''
states = torch.load(path)
def recover_state(name, model, optimizer):
state = model.state_dict()
model_keys = set(state.keys())
load_keys = set(states[name]['state_dict'].keys())
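# Partial loading: state.update() merges the checkpoint into the model's current
# state_dict before load_state_dict, so parameters missing from the checkpoint simply
# keep their current values instead of raising a missing-key error.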
if model_keys != load_keys:
print("NOTICE: DIFFERENT KEYS IN THE LISTEREN")
state.update(states[name]['state_dict'])
model.load_state_dict(state)
if args.loadOptim:
optimizer.load_state_dict(states[name]['optimizer'])
all_tuple = [("vln_bert", self.vln_bert, self.vln_bert_optimizer),
("critic", self.critic, self.critic_optimizer)]
for param in all_tuple:
recover_state(*param)
return states['vln_bert']['epoch'] - 1
def load_pretrain(self, path):
''' Loads parameters from pretrained network '''
load_states = torch.load(path)
# print(self.vln_bert.state_dict()['candidate_att_layer.linear_in.weight'])
# print(self.vln_bert.state_dict()['visual_bert.bert.encoder.layer.9.intermediate.dense.weight'])
def recover_state(name, model):
state = model.state_dict()
model_keys = set(state.keys())
load_keys = set(load_states[name]['state_dict'].keys())
if model_keys != load_keys:
print("NOTICE: DIFFERENT KEYS FOUND IN MODEL")
for ikey in model_keys:
if ikey not in load_keys:
print('key not in the loaded states: ', ikey)
for ikey in load_keys:
if ikey not in model_keys:
print('key not in the model: ', ikey)
state.update(load_states[name]['state_dict'])
model.load_state_dict(state)
all_tuple = [("vln_bert", self.vln_bert)]
for param in all_tuple:
recover_state(*param)
return load_states['vln_bert']['epoch'] - 1