import json
import os
import sys
import numpy as np
import random
import math
import time

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

from env import R2RBatch
import utils
from utils import padding_idx, add_idx, Tokenizer, print_progress
from param import args
from collections import defaultdict

import model_OSCAR, model_CA


class BaseAgent(object):
    ''' Base class for an R2R agent to generate and save trajectories. '''

    def __init__(self, env, results_path):
        self.env = env
        self.results_path = results_path
        random.seed(1)
        self.results = {}
        self.losses = []  # For learning agents

    def write_results(self):
        output = [{'instr_id': k, 'trajectory': v, 'predObjId': r}
                  for k, (v, r, _) in self.results.items()]
        with open(self.results_path, 'w') as f:
            json.dump(output, f)

    def get_results(self):
        output = [{'instr_id': k, 'trajectory': v, 'predObjId': r, 'found': found}
                  for k, (v, r, found) in self.results.items()]
        return output

    def rollout(self, **args):
        ''' Return a list of dicts containing instr_id:'xx', path:[(viewpointId, heading_rad, elevation_rad)] '''
        raise NotImplementedError

    @staticmethod
    def get_agent(name):
        return globals()[name + "Agent"]

    def test(self, iters=None, **kwargs):
        self.env.reset_epoch(shuffle=(iters is not None))   # If iters is not None, shuffle the env batch
        self.losses = []
        self.results = {}
        # We rely on env showing the entire batch before repeating anything
        looped = False
        self.loss = 0
        if iters is not None:
            # For each epoch, run the first 'iters' iterations (the batch was shuffled above)
            for i in range(iters):
                trajs, found = self.rollout(**kwargs)
                for index, traj in enumerate(trajs):
                    self.loss = 0
                    self.results[traj['instr_id']] = (traj['path'], traj['predObjId'], found[index])
        else:   # Do a full round
            while True:
                trajs, found = self.rollout(**kwargs)
                for index, traj in enumerate(trajs):
                    if traj['instr_id'] in self.results:
                        looped = True
                    else:
                        self.loss = 0
                        self.results[traj['instr_id']] = (traj['path'], traj['predObjId'], found[index])
                if looped:
                    break


class Seq2SeqAgent(BaseAgent):
    ''' An agent that navigates with a VLN-BERT model (OSCAR or ViLBERT backbone)
        and grounds the referred object when it decides to stop. '''

    # For now, the agent can't pick which forward move to make - just the one in the middle
    env_actions = {
        'left':     (0, -1,  0),  # left
        'right':    (0,  1,  0),  # right
        'up':       (0,  0,  1),  # up
        'down':     (0,  0, -1),  # down
        'forward':  (1,  0,  0),  # forward
        '<end>':    (0,  0,  0),  # <end>
        '<start>':  (0,  0,  0),  # <start>
        '<ignore>': (0,  0,  0)   # <ignore>
    }

    def __init__(self, env, results_path, tok, episode_len=20):
        super(Seq2SeqAgent, self).__init__(env, results_path)
        self.tok = tok
        self.episode_len = episode_len
        self.feature_size = self.env.feature_size

        # Models
        if args.vlnbert == 'oscar':
            self.vln_bert = model_OSCAR.VLNBERT(
                directions=args.directions,
                feature_size=self.feature_size + args.angle_feat_size).cuda()
            self.critic = model_OSCAR.Critic().cuda()
        elif args.vlnbert == 'vilbert':
            self.vln_bert = model_CA.VLNBERT(
                feature_size=self.feature_size + args.angle_feat_size).cuda()
            self.critic = model_CA.Critic().cuda()
        self.models = (self.vln_bert, self.critic)

        # Optimizers
        self.vln_bert_optimizer = args.optimizer(self.vln_bert.parameters(), lr=args.lr)
        self.critic_optimizer = args.optimizer(self.critic.parameters(), lr=args.lr)
        self.optimizers = (self.vln_bert_optimizer, self.critic_optimizer)

        # Evaluations
        self.losses = []
        self.criterion = nn.CrossEntropyLoss(ignore_index=args.ignoreid, reduction='sum')
        self.criterion_REF = nn.CrossEntropyLoss(ignore_index=args.ignoreid, reduction='sum')
        # self.ndtw_criterion = utils.ndtw_initialize()
        self.objProposals, self.obj2viewpoint = utils.loadObjProposals()

        # Logs
        sys.stdout.flush()
        self.logs = defaultdict(list)

    def _sort_batch(self, obs, sorted_instr=True):
        ''' Extract instructions from a list of observations and sort by descending
            sequence length (to enable PyTorch packing). '''
        seq_tensor = np.array([ob['instr_encoding'] for ob in obs])
        seq_lengths = np.argmax(seq_tensor == padding_idx, axis=1)
        seq_lengths[seq_lengths == 0] = seq_tensor.shape[1]     # Full length

        seq_tensor = torch.from_numpy(seq_tensor)
        seq_lengths = torch.from_numpy(seq_lengths)

        # Sort sequences by lengths
        if sorted_instr:
            seq_lengths, perm_idx = seq_lengths.sort(0, True)   # True -> descending
            sorted_tensor = seq_tensor[perm_idx]
            perm_idx = list(perm_idx)
        else:
            sorted_tensor = seq_tensor
            perm_idx = None

        mask = (sorted_tensor != padding_idx)   # seq_lengths[0] is the maximum length
        token_type_ids = torch.zeros_like(mask)

        visual_mask = torch.ones(args.directions).bool()
        visual_mask = visual_mask.unsqueeze(0).repeat(mask.size(0), 1)
        visual_mask = torch.cat((mask, visual_mask), -1)

        return sorted_tensor.long().cuda(), \
               mask.bool().cuda(), token_type_ids.long().cuda(), \
               visual_mask.bool().cuda(), \
               list(seq_lengths), perm_idx
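
    # Note on _sort_batch: it returns the (optionally length-sorted) instruction
    # tensor, its padding mask, token-type ids, a joint language+visual attention
    # mask whose visual part covers args.directions panorama slots, the sequence
    # lengths, and the permutation indices later used to reorder the observations.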

    def _feature_variable(self, obs):
        ''' Extract precomputed features into variable. '''
        features = np.empty((len(obs), args.directions, self.feature_size + args.angle_feat_size),
                            dtype=np.float32)
        for i, ob in enumerate(obs):
            features[i, :, :] = ob['feature']   # Image feat
        return torch.from_numpy(features).cuda()

    def _candidate_variable(self, obs):
        candidate_leng = [len(ob['candidate']) for ob in obs]
        candidate_feat = np.zeros((len(obs), max(candidate_leng), self.feature_size + args.angle_feat_size),
                                  dtype=np.float32)
        # Note: The candidate_feat at len(ob['candidate']) is the feature for the END
        # which is zero in my implementation
        for i, ob in enumerate(obs):
            for j, cc in enumerate(ob['candidate']):
                candidate_feat[i, j, :] = cc['feature']
        result = torch.from_numpy(candidate_feat)
        '''
        for i, ob in enumerate(obs):
            result[i, len(ob['candidate']), :] = torch.ones((self.feature_size + args.angle_feat_size), dtype=torch.float32)
        '''
        result = result.cuda()

        return result, candidate_leng

    def _object_variable(self, obs):
        cand_obj_leng = [len(ob['candidate_obj'][2]) + 1 for ob in obs]     # +1 is for no REF
        if args.vlnbert == 'vilbert':
            cand_obj_feat = np.zeros((len(obs), max(cand_obj_leng), self.feature_size + 4),
                                     dtype=np.float32)
        elif args.vlnbert == 'oscar':
            cand_obj_feat = np.zeros((len(obs), max(cand_obj_leng), self.feature_size + args.angle_feat_size),
                                     dtype=np.float32)
        cand_obj_pos = np.zeros((len(obs), max(cand_obj_leng), 5), dtype=np.float32)

        for i, ob in enumerate(obs):
            obj_local_pos, obj_features, candidate_objId = ob['candidate_obj']
            for j, cc in enumerate(candidate_objId):
                cand_obj_feat[i, j, :] = obj_features[j]
                cand_obj_pos[i, j, :] = obj_local_pos[j]

        return torch.from_numpy(cand_obj_feat).cuda(), torch.from_numpy(cand_obj_pos).cuda(), cand_obj_leng

    def get_input_feat(self, obs):
        input_a_t = np.zeros((len(obs), args.angle_feat_size), np.float32)
        for i, ob in enumerate(obs):
            input_a_t[i] = utils.angle_feature(ob['heading'], ob['elevation'])
        input_a_t = torch.from_numpy(input_a_t).cuda()

        # f_t = self._feature_variable(obs)      # Image features from obs
        f_t = None
        candidate_feat, candidate_leng = self._candidate_variable(obs)
        obj_feat, obj_pos, obj_leng = self._object_variable(obs)

        return input_a_t, f_t, candidate_feat, candidate_leng, obj_feat, obj_pos, obj_leng
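
    # get_input_feat gathers everything one visual step consumes: the agent's
    # current heading/elevation encoded as an angle feature, candidate features
    # padded to max(candidate_leng), and object-proposal features/positions
    # padded to max(obj_leng), where the extra object slot is the "no REF" option.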

    def _teacher_action(self, obs, ended, cand_size, candidate_leng):
        """
        Extract teacher actions into variable.
        :param obs: The observation.
        :param ended: Whether the action sequence has ended
        :return:
        """
        a = np.zeros(len(obs), dtype=np.int64)
        for i, ob in enumerate(obs):
            if ended[i]:                                            # Just ignore this index
                a[i] = args.ignoreid
            else:
                for k, candidate in enumerate(ob['candidate']):
                    if candidate['viewpointId'] == ob['teacher']:   # Next view point
                        a[i] = k
                        break
                else:   # Stop here
                    assert ob['teacher'] == ob['viewpoint']         # The teacher action should be "STAY HERE"
                    a[i] = cand_size - 1
                    '''
                    if ob['found']:
                        a[i] = cand_size - 1
                    else:
                        a[i] = candidate_leng[i] - 1
                    '''
        return torch.from_numpy(a).cuda()

    def _teacher_REF(self, obs, just_ended):
        a = np.zeros(len(obs), dtype=np.int64)
        for i, ob in enumerate(obs):
            if not just_ended[i]:               # Just ignore this index
                a[i] = args.ignoreid
            else:
                candidate_objs = ob['candidate_obj'][2]
                for k, kid in enumerate(candidate_objs):
                    if kid == ob['objId']:
                        if ob['found']:
                            a[i] = k
                            break
                        else:
                            a[i] = len(candidate_objs)
                            break
                else:
                    a[i] = args.ignoreid
        return torch.from_numpy(a).cuda()

    def make_equiv_action(self, a_t, perm_obs, perm_idx=None, traj=None):
        """
        Interface between the panoramic view and the egocentric view.
        It converts the panoramic-view action a_t into the equivalent egocentric-view
        actions for the simulator.
        """
        def take_action(i, idx, name):
            if type(name) is int:       # Go to the next viewpoint
                self.env.env.sims[idx].makeAction([name], [0], [0])
            else:                       # Adjust the camera
                action_params = self.env_actions[name]
                self.env.env.sims[idx].makeAction([action_params[0]],
                                                  [action_params[1]],
                                                  [action_params[2]])

        if perm_idx is None:
            perm_idx = range(len(perm_obs))

        for i, idx in enumerate(perm_idx):
            action = a_t[i]
            if action != -1 and action != -2:       # -1 / -2 mark the stop action
                select_candidate = perm_obs[i]['candidate'][action]
                src_point = perm_obs[i]['viewIndex']
                trg_point = select_candidate['pointId']
                src_level = src_point // 12     # The point idx started from 0
                trg_level = trg_point // 12
                while src_level < trg_level:    # Look up
                    take_action(i, idx, 'up')
                    src_level += 1
                while src_level > trg_level:    # Look down
                    take_action(i, idx, 'down')
                    src_level -= 1
                while self.env.env.sims[idx].getState()[0].viewIndex != trg_point:  # Turn right until the target
                    take_action(i, idx, 'right')
                assert select_candidate['viewpointId'] == \
                       self.env.env.sims[idx].getState()[0].navigableLocations[select_candidate['idx']].viewpointId
                take_action(i, idx, select_candidate['idx'])

                state = self.env.env.sims[idx].getState()[0]
                if traj is not None:
                    traj[i]['path'].append((state.location.viewpointId, state.heading, state.elevation))
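
    # Note on the view-index arithmetic above: it assumes the standard Matterport3D
    # simulator discretisation of 36 views per panorama (12 headings x 3 elevation
    # levels), so viewIndex // 12 gives the elevation level and repeated 'right'
    # actions step through headings until the target pointId is reached.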

    def rollout(self, train_ml=None, train_rl=True, reset=True, speaker=None):
        """
        :param train_ml:    The weight to train with maximum likelihood
        :param train_rl:    Whether to use RL in training
        :param reset:       Reset the environment
        :param speaker:     Speaker used in back translation.
                            If the speaker is not None, use back translation.
                            Otherwise, normal training.
        :return:
        """
        if self.feedback == 'teacher' or self.feedback == 'argmax':
            train_rl = False

        if reset:   # Reset env
            obs = np.array(self.env.reset())
        else:
            obs = np.array(self.env._get_obs())

        batch_size = len(obs)

        # Reorder the language input for the encoder (do not ruin the original code)
        sentence, language_attention_mask, token_type_ids, \
            visual_attention_mask, seq_lengths, perm_idx = self._sort_batch(obs)
        perm_obs = obs[perm_idx]

        ''' Language BERT '''
        language_inputs = {'mode': 'language',
                           'sentence': sentence,
                           'token_type_ids': token_type_ids}
        # (batch_size, seq_len, hidden_size)
        if args.vlnbert == 'oscar':
            language_inputs['attention_mask'] = language_attention_mask
            language_features = self.vln_bert(**language_inputs)
        elif args.vlnbert == 'vilbert':
            language_inputs['lang_masks'] = language_attention_mask
            h_t, language_features = self.vln_bert(**language_inputs)
            language_attention_mask = language_attention_mask[:, 1:]

        # Record the starting point
        traj = [{
            'instr_id': ob['instr_id'],
            'path': [(ob['viewpoint'], ob['heading'], ob['elevation'])],
            'predObjId': None
        } for ob in perm_obs]

        # Init the reward shaping
        last_dist = np.zeros(batch_size, np.float32)
        # last_ndtw = np.zeros(batch_size, np.float32)
        for i, ob in enumerate(perm_obs):   # The init distance from the viewpoint to the target
            last_dist[i] = ob['distance']
            path_act = [vp[0] for vp in traj[i]['path']]
            # last_ndtw[i] = self.ndtw_criterion[ob['scan']](path_act, ob['gt_path'], metric='ndtw')

        # Initialize the tracking state
        ended = np.array([False] * batch_size)      # Indices match the permutation of the model, not env
        just_ended = np.array([False] * batch_size)
        found = np.array([None] * batch_size)       # per-episode stop flag set below (-1 or -2)

        # Init the logs
        rewards = []
        hidden_states = []
        policy_log_probs = []
        masks = []
        stop_mask = torch.tensor([False] * batch_size).cuda().unsqueeze(1)
        entropys = []
        ml_loss = 0.
        ref_loss = 0.
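
        # Per-step loop below: (1) build the angle/candidate/object inputs,
        # (2) run a visual VLN-BERT step to update the state h_t and obtain
        # navigation logits and REF (object grounding) logits, (3) mask invalid
        # candidates/objects, (4) accumulate the IL loss against the teacher
        # action, (5) choose an action according to self.feedback, (6) on stop,
        # predict the referred object, and (7) step the simulator and, if
        # train_rl, compute the shaped reward for A2C.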

        # For test result submission: no backtracking
        visited = [set() for _ in range(batch_size)]

        for t in range(self.episode_len):
            input_a_t, f_t, candidate_feat, candidate_leng, obj_feat, obj_pos, obj_leng = self.get_input_feat(perm_obs)

            # The first [CLS] token, initialized by the language BERT, serves
            # as the agent's state passing through time steps
            if args.vlnbert != 'vilbert' and t >= 1:
                language_features = torch.cat((h_t.unsqueeze(1), language_features[:, 1:, :]), dim=1)

            visual_temp_mask = (utils.length2mask(candidate_leng) == 0).bool()
            obj_temp_mask = (utils.length2mask(obj_leng) == 0).bool()
            visual_attention_mask = torch.cat((language_attention_mask, visual_temp_mask, obj_temp_mask), dim=-1)

            self.vln_bert.vln_bert.config.directions = max(candidate_leng)
            self.vln_bert.vln_bert.config.obj_directions = max(obj_leng)

            ''' Visual BERT '''
            visual_inputs = {'mode': 'visual',
                             'sentence': language_features,
                             'token_type_ids': token_type_ids,
                             'action_feats': input_a_t,
                             'pano_feats': f_t,
                             'cand_feats': candidate_feat,
                             'obj_feats': obj_feat,
                             'obj_pos': obj_pos,
                             'already_dropfeat': (speaker is not None)}
            if args.vlnbert == 'oscar':
                visual_attention_mask = torch.cat((language_attention_mask, visual_temp_mask, obj_temp_mask), dim=-1)
                visual_inputs['attention_mask'] = visual_attention_mask
            elif args.vlnbert == 'vilbert':
                visual_inputs.update({
                    'h_t': h_t,
                    'lang_masks': language_attention_mask,
                    'cand_masks': visual_temp_mask,
                    'obj_masks': obj_temp_mask,
                    'act_t': t,
                })
            h_t, logit, logit_REF = self.vln_bert(**visual_inputs)
            hidden_states.append(h_t)

            # Mask outputs where the agent can't move forward
            # Here the logit is [b, max_candidate]
            candidate_mask = utils.length2mask(candidate_leng)
            candidate_mask = torch.cat((candidate_mask, stop_mask), dim=-1)
            logit.masked_fill_(candidate_mask, -float('inf'))
            candidate_mask_obj = utils.length2mask(obj_leng)
            logit_REF.masked_fill_(candidate_mask_obj, -float('inf'))

            if train_ml is not None:    # Supervised training
                target = self._teacher_action(perm_obs, ended, candidate_mask.size(1), candidate_leng)
                ml_loss += self.criterion(logit, target)

            # Determine the next model inputs
            if self.feedback == 'teacher':
                a_t = target                                # teacher forcing
            elif self.feedback == 'argmax':
                _, a_t = logit.max(1)                       # student forcing - argmax
                a_t = a_t.detach()
                log_probs = F.log_softmax(logit, 1)         # Calculate the log_prob here
                policy_log_probs.append(log_probs.gather(1, a_t.unsqueeze(1)))  # Gather the log_prob for each batch
            elif self.feedback == 'sample':
                probs = F.softmax(logit, 1)                 # sampling an action from the model
                c = torch.distributions.Categorical(probs)
                self.logs['entropy'].append(c.entropy().sum().item())   # For log
                entropys.append(c.entropy())                            # For optimization
                a_t = c.sample().detach()
                policy_log_probs.append(c.log_prob(a_t))
            else:
                print(self.feedback)
                sys.exit('Invalid feedback option')

            # Prepare environment action
            # NOTE: Env action is in the perm_obs space
            cpu_a_t = a_t.cpu().numpy()
            for i, next_id in enumerate(cpu_a_t):
                if ((next_id == visual_temp_mask.size(1)) or (next_id == (candidate_leng[i] - 1)) or (t == self.episode_len - 1)) \
                        and (not ended[i]):     # just stopped, or forced to stop at the last step
                    just_ended[i] = True

                    if self.feedback == 'argmax':
                        _, ref_t = logit_REF[i].max(0)
                        if ref_t != obj_leng[i] - 1:    # an object is grounded
                            traj[i]['predObjId'] = perm_obs[i]['candidate_obj'][2][ref_t]
                        else:                           # the 'no REF' slot was chosen
                            traj[i]['ref'] = 'NOT_FOUND'

                        if args.submit:
                            if obj_leng[i] == 1:        # no candidate objects at the stop viewpoint
                                traj[i]['predObjId'] = 0
                            else:
                                _, ref_t = logit_REF[i][:obj_leng[i] - 1].max(0)
                                try:
                                    traj[i]['predObjId'] = int(perm_obs[i]['candidate_obj'][2][ref_t])
                                except:
                                    import pdb; pdb.set_trace()
                else:
                    just_ended[i] = False

                if (next_id == args.ignoreid) or (ended[i]):    # keep the stop flag for already-ended episodes
                    cpu_a_t[i] = found[i]
                elif next_id == visual_temp_mask.size(1):
                    cpu_a_t[i] = -1
                    found[i] = -1
                    if self.feedback == 'argmax':
                        _, ref_t = logit_REF[i].max(0)
                        if ref_t == obj_leng[i] - 1:
                            found[i] = -2
                        else:
                            found[i] = -1

            ''' Supervised training for REF '''
            if train_ml is not None:
                target_obj = self._teacher_REF(perm_obs, just_ended)
                ref_loss += self.criterion_REF(logit_REF, target_obj)

            # Make action and get the new state
            self.make_equiv_action(cpu_a_t, perm_obs, perm_idx, traj)
            obs = np.array(self.env._get_obs())
            perm_obs = obs[perm_idx]        # Permute the obs to match the model order

            if train_rl:
                # Calculate the mask and reward
                dist = np.zeros(batch_size, np.float32)
                # ndtw_score = np.zeros(batch_size, np.float32)
                reward = np.zeros(batch_size, np.float32)
                mask = np.ones(batch_size, np.float32)
                for i, ob in enumerate(perm_obs):
                    dist[i] = ob['distance']
                    # path_act = [vp[0] for vp in traj[i]['path']]
                    # ndtw_score[i] = self.ndtw_criterion[ob['scan']](path_act, ob['gt_path'], metric='ndtw')

                    if ended[i]:    # the episode had already finished before this step
                        reward[i] = 0.0
                        mask[i] = 0.0
                    else:
                        action_idx = cpu_a_t[i]
                        if action_idx == -1:        # If the action now is end
                            # navigation success if the target object is visible when STOP
                            # end_viewpoint_id = ob['scan'] + '_' + ob['viewpoint']
                            # if self.objProposals.__contains__(end_viewpoint_id):
                            #     if ob['objId'] in self.objProposals[end_viewpoint_id]['objId']:
                            #         reward[i] = 2.0 + ndtw_score[i] * 2.0
                            #     else:
                            #         reward[i] = -2.0
                            # else:
                            #     reward[i] = -2.0
                            if dist[i] < 1.0:       # Correct
                                reward[i] = 2.0     # + ndtw_score[i] * 2.0
                            else:                   # Incorrect
                                reward[i] = -2.0
                        else:                       # The action is not end
                            # Change of distance (and nDTW) as the reward
                            reward[i] = - (dist[i] - last_dist[i])
                            # ndtw_reward = ndtw_score[i] - last_ndtw[i]
                            if reward[i] > 0.0:     # Quantification
                                reward[i] = 1.0     # + ndtw_reward
                            elif reward[i] < 0.0:
                                reward[i] = -1.0    # + ndtw_reward
                            else:
                                raise NameError("The action doesn't change the move")
                            # Miss-the-target penalty
                            if (last_dist[i] <= 1.0) and (dist[i] - last_dist[i] > 0.0):
                                reward[i] -= (1.0 - last_dist[i]) * 2.0
                rewards.append(reward)
                masks.append(mask)
                last_dist[:] = dist
                # last_ndtw[:] = ndtw_score

            # Update the finished actions
            # -1 means ended or ignored (already ended)
            ended[:] = np.logical_or(ended, (cpu_a_t == -1))

            # Early exit if all ended
            if ended.all():
                break
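
        # A2C update below: one extra forward pass estimates the value of the last
        # state to bootstrap unfinished episodes, then returns are accumulated
        # backwards as R_t = r_t + gamma * R_{t+1}, with advantage A_t = R_t - V(h_t).
        # The losses are -log pi(a_t) * A_t (policy), 0.5 * (R_t - V(h_t))^2 (critic),
        # plus an entropy bonus when feedback == 'sample'.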
        if train_rl:
            # Last action in A2C
            input_a_t, f_t, candidate_feat, candidate_leng, obj_feat, obj_pos, obj_leng = self.get_input_feat(perm_obs)

            if args.vlnbert != 'vilbert':
                language_features = torch.cat((h_t.unsqueeze(1), language_features[:, 1:, :]), dim=1)

            visual_temp_mask = (utils.length2mask(candidate_leng) == 0).bool()
            obj_temp_mask = (utils.length2mask(obj_leng) == 0).bool()
            visual_attention_mask = torch.cat((language_attention_mask, visual_temp_mask, obj_temp_mask), dim=-1)

            self.vln_bert.vln_bert.config.directions = max(candidate_leng)
            self.vln_bert.vln_bert.config.obj_directions = max(obj_leng)

            ''' Visual BERT '''
            visual_inputs = {'mode': 'visual',
                             'sentence': language_features,
                             'token_type_ids': token_type_ids,
                             'action_feats': input_a_t,
                             'pano_feats': f_t,
                             'cand_feats': candidate_feat,
                             'obj_feats': obj_feat,
                             'obj_pos': obj_pos,
                             'already_dropfeat': (speaker is not None)}
            if args.vlnbert == 'oscar':
                visual_attention_mask = torch.cat((language_attention_mask, visual_temp_mask, obj_temp_mask), dim=-1)
                visual_inputs['attention_mask'] = visual_attention_mask
            elif args.vlnbert == 'vilbert':
                visual_inputs.update({
                    'h_t': h_t,
                    'lang_masks': language_attention_mask,
                    'cand_masks': visual_temp_mask,
                    'obj_masks': obj_temp_mask,
                    'act_t': len(hidden_states),
                })
            last_h_, _, _ = self.vln_bert(**visual_inputs)

            rl_loss = 0.

            # NOW, A2C!!!
            # Calculate the final discounted reward
            last_value__ = self.critic(last_h_).detach()            # The value estimate of the last state; remove the grad for safety
            discount_reward = np.zeros(batch_size, np.float32)      # The initial reward is zero
            for i in range(batch_size):
                if not ended[i]:        # If the episode has not ended, use the value function as the last reward
                    discount_reward[i] = last_value__[i]

            length = len(rewards)
            total = 0
            for t in range(length - 1, -1, -1):
                discount_reward = discount_reward * args.gamma + rewards[t]     # If it ended, the reward will be 0
                mask_ = torch.from_numpy(masks[t]).cuda()
                clip_reward = discount_reward.copy()
                r_ = torch.from_numpy(clip_reward).cuda()
                v_ = self.critic(hidden_states[t])
                a_ = (r_ - v_).detach()

                # r_: The higher, the better. -ln(p(action)) * (discount_reward - value)
                rl_loss += (-policy_log_probs[t] * a_ * mask_).sum()
                rl_loss += (((r_ - v_) ** 2) * mask_).sum() * 0.5   # 1/2 L2 loss
                if self.feedback == 'sample':
                    rl_loss += (-0.01 * entropys[t] * mask_).sum()
                self.logs['critic_loss'].append((((r_ - v_) ** 2) * mask_).sum().item())

                total = total + np.sum(masks[t])
            self.logs['total'].append(total)

            # Normalize the loss function
            if args.normalize_loss == 'total':
                rl_loss /= total
            elif args.normalize_loss == 'batch':
                rl_loss /= batch_size
            else:
                assert args.normalize_loss == 'none'

            self.loss += rl_loss
            self.logs['RL_loss'].append(rl_loss.item())

        if train_ml is not None:
            self.loss += ml_loss * train_ml / batch_size
            self.logs['IL_loss'].append((ml_loss * train_ml / batch_size).item())
            self.loss += ref_loss * args.ref_loss_weight / batch_size
            self.logs['REF_loss'].append(ref_loss.item() * args.ref_loss_weight / batch_size)

        if type(self.loss) is int:      # For safety; triggered if no losses are added
            self.losses.append(0.)
        else:
            self.losses.append(self.loss.item() / self.episode_len)    # logging only; the optimizer uses self.loss

        return traj, found

    def test(self, use_dropout=False, feedback='argmax', allow_cheat=False, iters=None):
        ''' Evaluate once on each instruction in the current environment '''
        self.feedback = feedback
        if use_dropout:
            self.vln_bert.train()
            self.critic.train()
        else:
            self.vln_bert.eval()
            self.critic.eval()
        super(Seq2SeqAgent, self).test(iters)

    def zero_grad(self):
        self.loss = 0.
        self.losses = []
        for model, optimizer in zip(self.models, self.optimizers):
            model.train()
            optimizer.zero_grad()

    def accumulate_gradient(self, feedback='teacher', **kwargs):
        if feedback == 'teacher':
            self.feedback = 'teacher'
            self.rollout(train_ml=args.teacher_weight, train_rl=False, **kwargs)
        elif feedback == 'sample':
            self.feedback = 'teacher'
            self.rollout(train_ml=args.ml_weight, train_rl=False, **kwargs)
            self.feedback = 'sample'
            self.rollout(train_ml=None, train_rl=True, **kwargs)
        else:
            assert False
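
    # Note: zero_grad() / accumulate_gradient() / optim_step() are presumably
    # called by the training script in that order, so gradients from several
    # rollouts can be accumulated before a single optimizer step; that wiring
    # lives outside this file.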

    def optim_step(self):
        self.loss.backward()
        torch.nn.utils.clip_grad_norm_(self.vln_bert.parameters(), 40.)
        self.vln_bert_optimizer.step()
        self.critic_optimizer.step()

    def train(self, n_iters, feedback='teacher', **kwargs):
        ''' Train for a given number of iterations '''
        self.feedback = feedback

        self.vln_bert.train()
        self.critic.train()

        self.losses = []
        for iter in range(1, n_iters + 1):
            self.vln_bert_optimizer.zero_grad()
            self.critic_optimizer.zero_grad()

            self.loss = 0

            if feedback == 'teacher':
                self.feedback = 'teacher'
                self.rollout(train_ml=args.teacher_weight, train_rl=False, **kwargs)
            elif feedback == 'sample':      # agents in IL and RL separately
                if args.ml_weight != 0:
                    self.feedback = 'teacher'
                    self.rollout(train_ml=args.ml_weight, train_rl=False, **kwargs)
                self.feedback = 'sample'
                self.rollout(train_ml=None, train_rl=True, **kwargs)
            else:
                assert False

            self.loss.backward()
            torch.nn.utils.clip_grad_norm_(self.vln_bert.parameters(), 40.)
            self.vln_bert_optimizer.step()
            self.critic_optimizer.step()

            if args.aug is None:
                print_progress(iter, n_iters, prefix='Progress:', suffix='Complete', bar_length=50)

    def save(self, epoch, path):
        ''' Snapshot models '''
        the_dir, _ = os.path.split(path)
        os.makedirs(the_dir, exist_ok=True)
        states = {}

        def create_state(name, model, optimizer):
            states[name] = {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }

        all_tuple = [("vln_bert", self.vln_bert, self.vln_bert_optimizer),
                     ("critic", self.critic, self.critic_optimizer)]
        for param in all_tuple:
            create_state(*param)
        torch.save(states, path)

    def load(self, path):
        ''' Loads parameters (but not training state) '''
        states = torch.load(path)

        def recover_state(name, model, optimizer):
            state = model.state_dict()
            model_keys = set(state.keys())
            load_keys = set(states[name]['state_dict'].keys())
            if model_keys != load_keys:
                print("NOTICE: DIFFERENT KEYS IN THE LISTENER")
            state.update(states[name]['state_dict'])
            model.load_state_dict(state)
            if args.loadOptim:
                optimizer.load_state_dict(states[name]['optimizer'])

        all_tuple = [("vln_bert", self.vln_bert, self.vln_bert_optimizer),
                     ("critic", self.critic, self.critic_optimizer)]
        for param in all_tuple:
            recover_state(*param)
        return states['vln_bert']['epoch'] - 1

    def load_pretrain(self, path):
        ''' Loads parameters from a pretrained network '''
        load_states = torch.load(path)

        def recover_state(name, model):
            state = model.state_dict()
            model_keys = set(state.keys())
            load_keys = set(load_states[name]['state_dict'].keys())
            if model_keys != load_keys:
                print("NOTICE: DIFFERENT KEYS FOUND IN MODEL")
                for ikey in model_keys:
                    if ikey not in load_keys:
                        print('key not in loaded states: ', ikey)
                for ikey in load_keys:
                    if ikey not in model_keys:
                        print('key not in model: ', ikey)
            state.update(load_states[name]['state_dict'])
            model.load_state_dict(state)

        all_tuple = [("vln_bert", self.vln_bert)]
        for param in all_tuple:
            recover_state(*param)
        return load_states['vln_bert']['epoch'] - 1
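

# ---------------------------------------------------------------------------
# Reference sketch (not part of the agent): the discounted-return recurrence
# used in the A2C block of rollout(), written out standalone for clarity.
# The function name and the default gamma below are illustrative only; the
# agent uses args.gamma and keeps this logic inline in rollout().
# ---------------------------------------------------------------------------
def _discounted_returns_sketch(rewards, last_value, gamma=0.9):
    """Illustrative helper, never called by the agent.

    rewards:    list of (batch,) float arrays, one per time step, already zeroed
                for finished episodes (as rollout() does).
    last_value: (batch,) bootstrap values, zero for finished episodes.
    Returns the per-step discounted returns R_t = r_t + gamma * R_{t+1}.
    """
    R = np.array(last_value, dtype=np.float32)
    returns = []
    for r in reversed(rewards):
        R = np.asarray(r, dtype=np.float32) + gamma * R     # backward accumulation
        returns.insert(0, R.copy())
    return returns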