import json
import os
import sys
import numpy as np
import random
import math
import time
from collections import defaultdict

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP

from utils.distributed import is_default_gpu
from utils.logger import print_progress


class BaseAgent(object):
    ''' Base class for a REVERIE agent to generate and save trajectories. '''

    def __init__(self, env):
        self.env = env
        self.results = {}

    def get_results(self, detailed_output=False):
        output = []
        for k, v in self.results.items():
            output.append({'instr_id': k, 'trajectory': v['path'], 'pred_objid': v['pred_objid']})
            if detailed_output:
                output[-1]['details'] = v['details']
        return output

    def rollout(self, **args):
        ''' Return a list of dicts containing instr_id:'xx', path:[(viewpointId, heading_rad, elevation_rad)] '''
        raise NotImplementedError

    @staticmethod
    def get_agent(name):
        return globals()[name+"Agent"]

    def test(self, iters=None, **kwargs):
        self.env.reset_epoch(shuffle=(iters is not None))   # if iters is not None, shuffle the env batch
        self.losses = []
        self.results = {}
        # We rely on the env showing the entire batch before repeating anything
        looped = False
        self.loss = 0
        if iters is not None:
            # Run only the first 'iters' iterations (the env batch was shuffled above)
            for i in range(iters):
                for traj in self.rollout(**kwargs):
                    self.loss = 0
                    self.results[traj['instr_id']] = traj
        else:
            # Do a full round
            while True:
                for traj in self.rollout(**kwargs):
                    if traj['instr_id'] in self.results:
                        looped = True
                    else:
                        self.loss = 0
                        self.results[traj['instr_id']] = traj
                if looped:
                    break

    def test_viz(self, iters=None, **kwargs):
        self.env.reset_epoch(shuffle=(iters is not None))   # if iters is not None, shuffle the env batch
        self.losses = []
        self.results = {}
        # We rely on the env showing the entire batch before repeating anything
        looped = False
        self.loss = 0
        if iters is not None:
            # Run only the first 'iters' iterations (the env batch was shuffled above)
            for i in range(iters):
                for traj in self.rollout(**kwargs):
                    self.loss = 0
                    self.results[traj['instr_id']] = traj
        else:
            # Do a full round
            while True:
                for traj in self.rollout_viz(**kwargs):
                    if traj['instr_id'] in self.results:
                        looped = True
                    else:
                        self.loss = 0
                        self.results[traj['instr_id']] = traj
                if looped:
                    break


class Seq2SeqAgent(BaseAgent):
    # (heading_delta, elevation_delta, forward_step) triples for the simulator;
    # '<end>', '<start>' and '<ignore>' are the usual no-op placeholder actions.
    env_actions = {
        'left': (0, -1, 0),      # left
        'right': (0, 1, 0),      # right
        'up': (0, 0, 1),         # up
        'down': (0, 0, -1),      # down
        'forward': (1, 0, 0),    # forward
        '<end>': (0, 0, 0),      # <end>
        '<start>': (0, 0, 0),    # <start>
        '<ignore>': (0, 0, 0)    # <ignore>
    }
    for k, v in env_actions.items():
        env_actions[k] = [[vx] for vx in v]

    def __init__(self, args, env, rank=0):
        super().__init__(env)
        self.args = args

        self.default_gpu = is_default_gpu(self.args)
        self.rank = rank

        # Models
        self._build_model()

        if self.args.world_size > 1:
            self.vln_bert = DDP(self.vln_bert, device_ids=[self.rank], find_unused_parameters=True)
            self.critic = DDP(self.critic, device_ids=[self.rank], find_unused_parameters=True)

        self.models = (self.vln_bert, self.critic)
        self.device = torch.device('cuda:%d' % self.rank)

        # Optimizers
        if self.args.optim == 'rms':
            optimizer = torch.optim.RMSprop
        elif self.args.optim == 'adam':
            optimizer = torch.optim.Adam
        elif self.args.optim == 'adamW':
            optimizer = torch.optim.AdamW
        elif self.args.optim == 'sgd':
            optimizer = torch.optim.SGD
        else:
            assert False
        if self.default_gpu:
            print('Optimizer: %s' % self.args.optim)

        self.vln_bert_optimizer = optimizer(self.vln_bert.parameters(), lr=self.args.lr)
        self.critic_optimizer = optimizer(self.critic.parameters(), lr=self.args.lr)
        self.optimizers = (self.vln_bert_optimizer, self.critic_optimizer)

        # Evaluations
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.args.ignoreid, reduction='sum')

        # Logs
        sys.stdout.flush()
        self.logs = defaultdict(list)

    def _build_model(self):
        raise NotImplementedError('child class should implement _build_model: self.vln_bert & self.critic')

    def test(self, use_dropout=False, feedback='argmax', allow_cheat=False, iters=None, viz=False):
        ''' Evaluate once on each instruction in the current environment '''
        self.feedback = feedback
        if use_dropout:
            self.vln_bert.train()
            self.critic.train()
        else:
            self.vln_bert.eval()
            self.critic.eval()

        if viz:
            super().test_viz(iters=iters)
        else:
            super().test(iters=iters)

    def train(self, n_iters, feedback='teacher', **kwargs):
        ''' Train for a given number of iterations '''
        self.feedback = feedback

        self.vln_bert.train()
        self.critic.train()

        self.losses = []
        for iter in range(1, n_iters + 1):
            self.vln_bert_optimizer.zero_grad()
            self.critic_optimizer.zero_grad()

            self.loss = 0

            if self.args.train_alg == 'imitation':
                self.feedback = 'teacher'
                self.rollout(train_ml=1., train_rl=False, **kwargs)
            elif self.args.train_alg == 'dagger':
                if self.args.ml_weight != 0:
                    self.feedback = 'teacher'
                    self.rollout(train_ml=self.args.ml_weight, train_rl=False, **kwargs)
                self.feedback = self.args.dagger_sample
                self.rollout(train_ml=1, train_rl=False, **kwargs)
            else:
                if self.args.ml_weight != 0:
                    self.feedback = 'teacher'
                    self.rollout(train_ml=self.args.ml_weight, train_rl=False, **kwargs)
                self.feedback = 'sample'
                self.rollout(train_ml=None, train_rl=True, **kwargs)

            # print(self.rank, iter, self.loss)
            self.loss.backward()

            torch.nn.utils.clip_grad_norm_(self.vln_bert.parameters(), 40.)
            self.vln_bert_optimizer.step()
            self.critic_optimizer.step()

            if self.args.aug is None:
                print_progress(iter, n_iters+1, prefix='Progress:', suffix='Complete', bar_length=50)

    def save(self, epoch, path):
        ''' Snapshot models '''
        the_dir, _ = os.path.split(path)
        os.makedirs(the_dir, exist_ok=True)
        states = {}

        def create_state(name, model, optimizer):
            states[name] = {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }

        all_tuple = [("vln_bert", self.vln_bert, self.vln_bert_optimizer),
                     ("critic", self.critic, self.critic_optimizer)]
        for param in all_tuple:
            create_state(*param)
        torch.save(states, path)

    def load(self, path):
        ''' Loads parameters (but not training state) '''
        states = torch.load(path)

        def recover_state(name, model, optimizer):
            state = model.state_dict()
            model_keys = set(state.keys())
            load_keys = set(states[name]['state_dict'].keys())
            state_dict = states[name]['state_dict']
            if model_keys != load_keys:
                print("NOTICE: DIFFERENT KEYS IN THE LISTENER")
                # Reconcile the 'module.' prefix that DDP adds to parameter names.
                if not list(model_keys)[0].startswith('module.') and list(load_keys)[0].startswith('module.'):
                    state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
                if list(model_keys)[0].startswith('module.') and (not list(load_keys)[0].startswith('module.')):
                    state_dict = {'module.'+k: v for k, v in state_dict.items()}
                # Drop checkpoint keys that the current model does not have.
                same_state_dict = {}
                extra_keys = []
                for k, v in state_dict.items():
                    if k in model_keys:
                        same_state_dict[k] = v
                    else:
                        extra_keys.append(k)
                state_dict = same_state_dict
                print('Extra keys in state_dict: %s' % (', '.join(extra_keys)))
            state.update(state_dict)
            model.load_state_dict(state)
            if self.args.resume_optimizer:
                optimizer.load_state_dict(states[name]['optimizer'])

        all_tuple = [("vln_bert", self.vln_bert, self.vln_bert_optimizer),
                     ("critic", self.critic, self.critic_optimizer)]
        for param in all_tuple:
            recover_state(*param)
        return states['vln_bert']['epoch'] - 1
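

if __name__ == '__main__':
    # Minimal sketch (illustrative only, not used by the trainer): it demonstrates
    # why recover_state() above strips or adds the 'module.' prefix. A checkpoint
    # written from a DDP-wrapped model keys every parameter as 'module.<name>',
    # so it cannot be loaded directly into an unwrapped module. The toy nn.Linear
    # below is a hypothetical stand-in for vln_bert / critic.
    toy = nn.Linear(4, 2)
    ddp_style_ckpt = {'module.' + k: v for k, v in toy.state_dict().items()}
    # Loading ddp_style_ckpt directly would report missing/unexpected keys, so
    # normalize the names the same way recover_state() does before loading.
    stripped = {k.replace('module.', ''): v for k, v in ddp_style_ckpt.items()}
    toy.load_state_dict(stripped)
    print('loaded keys:', list(stripped.keys()))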