import json
import os
import sys
import numpy as np
import random
import math
import time
from collections import defaultdict

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP

from utils.distributed import is_default_gpu
from utils.logger import print_progress


class BaseAgent(object):
    ''' Base class for a REVERIE agent to generate and save trajectories. '''

    def __init__(self, env):
        self.env = env
        self.results = {}

    def get_results(self, detailed_output=False):
        output = []
        for k, v in self.results.items():
            output.append({
                'instr_id': k,
                'trajectory': v['path'],
                'pred_objid': v['pred_objid'],
                'found': v['found'],
                'gt_found': v['gt_found'],
            })
            if detailed_output:
                output[-1]['details'] = v['details']
        return output

    def rollout(self, **args):
        ''' Return a list of dicts containing instr_id:'xx', path:[(viewpointId, heading_rad, elevation_rad)] '''
        raise NotImplementedError

    @staticmethod
    def get_agent(name):
        return globals()[name + "Agent"]
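    # Lookup by naming convention: e.g. get_agent('Seq2Seq') resolves to the
    # Seq2SeqAgent class defined below, so the agent class can be selected
    # from a plain config string.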

    def test(self, iters=None, **kwargs):
        self.env.reset_epoch(shuffle=(iters is not None))   # If iters is not None, shuffle the env batch
        self.losses = []
        self.results = {}
        # We rely on the env showing the entire batch before repeating anything
        looped = False
        self.loss = 0
        if iters is not None:
            # Run exactly `iters` rollouts (the batch order was shuffled above)
            for i in range(iters):
                for traj in self.rollout(**kwargs):
                    self.loss = 0
                    self.results[traj['instr_id']] = traj
        else:   # Do a full round over the data
            while True:
                for traj in self.rollout(**kwargs):
                    if traj['instr_id'] in self.results:
                        looped = True
                    else:
                        self.loss = 0
                        self.results[traj['instr_id']] = traj
                if looped:
                    break

    def test_viz(self, iters=None, **kwargs):
        self.env.reset_epoch(shuffle=(iters is not None))   # If iters is not None, shuffle the env batch
        self.losses = []
        self.results = {}
        # We rely on the env showing the entire batch before repeating anything
        looped = False
        self.loss = 0
        if iters is not None:
            # Run exactly `iters` visualization rollouts (the batch order was shuffled above)
            for i in range(iters):
                for traj in self.rollout_viz(**kwargs):
                    self.loss = 0
                    self.results[traj['instr_id']] = traj
        else:   # Do a full round over the data
            while True:
                for traj in self.rollout_viz(**kwargs):
                    if traj['instr_id'] in self.results:
                        looped = True
                    else:
                        self.loss = 0
                        self.results[traj['instr_id']] = traj
                if looped:
                    break


class Seq2SeqAgent(BaseAgent):
    env_actions = {
        'left':     (0, -1,  0),  # left
        'right':    (0,  1,  0),  # right
        'up':       (0,  0,  1),  # up
        'down':     (0,  0, -1),  # down
        'forward':  (1,  0,  0),  # forward
        '<end>':    (0,  0,  0),  # <end>
        '<start>':  (0,  0,  0),  # <start>
        '<ignore>': (0,  0,  0),  # <ignore>
    }
    for k, v in env_actions.items():
        env_actions[k] = [[vx] for vx in v]
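    # Each action component is wrapped into a one-element list; presumably this
    # matches the simulator's batched makeAction interface, which expects a list
    # per argument (one entry per simulator instance).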

    def __init__(self, args, env, rank=0):
        super().__init__(env)
        self.args = args

        self.default_gpu = is_default_gpu(self.args)
        self.rank = rank

        # Models
        self._build_model()

        if self.args.world_size > 1:
            self.vln_bert = DDP(self.vln_bert, device_ids=[self.rank], find_unused_parameters=True)
            self.critic = DDP(self.critic, device_ids=[self.rank], find_unused_parameters=True)
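            # Assumes torch.distributed.init_process_group() has already been called
            # by the launcher; find_unused_parameters=True tolerates parameters that
            # receive no gradient in a given forward pass.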

        self.models = (self.vln_bert, self.critic)
        self.device = torch.device('cuda:%d' % self.rank)

        # Optimizers
        if self.args.optim == 'rms':
            optimizer = torch.optim.RMSprop
        elif self.args.optim == 'adam':
            optimizer = torch.optim.Adam
        elif self.args.optim == 'adamW':
            optimizer = torch.optim.AdamW
        elif self.args.optim == 'sgd':
            optimizer = torch.optim.SGD
        else:
            raise ValueError('Unknown optimizer: %s' % self.args.optim)
        if self.default_gpu:
            print('Optimizer: %s' % self.args.optim)

        self.vln_bert_optimizer = optimizer(self.vln_bert.parameters(), lr=self.args.lr)
        self.critic_optimizer = optimizer(self.critic.parameters(), lr=self.args.lr)
        self.optimizers = (self.vln_bert_optimizer, self.critic_optimizer)

        # Evaluation criterion
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.args.ignoreid, reduction='sum')

        # Logs
        sys.stdout.flush()
        self.logs = defaultdict(list)

    def _build_model(self):
        raise NotImplementedError('child class should implement _build_model: self.vln_bert & self.critic')

    def test(self, use_dropout=False, feedback='argmax', allow_cheat=False, iters=None, viz=False):
        ''' Evaluate once on each instruction in the current environment '''
        self.feedback = feedback
        if use_dropout:
            self.vln_bert.train()
            self.critic.train()
        else:
            self.vln_bert.eval()
            self.critic.eval()
        if viz:
            super().test_viz(iters=iters)
        else:
            super().test(iters=iters)

    def train(self, n_iters, feedback='teacher', **kwargs):
        ''' Train for a given number of iterations '''
        self.feedback = feedback

        self.vln_bert.train()
        self.critic.train()

        self.losses = []
        for iter in range(1, n_iters + 1):

            self.vln_bert_optimizer.zero_grad()
            self.critic_optimizer.zero_grad()

            self.loss = 0

            if self.args.train_alg == 'imitation':
                self.feedback = 'teacher'
                self.rollout(
                    train_ml=1., train_rl=False, **kwargs
                )
            elif self.args.train_alg == 'dagger':
                if self.args.ml_weight != 0:
                    self.feedback = 'teacher'
                    self.rollout(
                        train_ml=self.args.ml_weight, train_rl=False, **kwargs
                    )
                self.feedback = self.args.dagger_sample
                self.rollout(train_ml=1, train_rl=False, **kwargs)
            else:
                if self.args.ml_weight != 0:
                    self.feedback = 'teacher'
                    self.rollout(
                        train_ml=self.args.ml_weight, train_rl=False, **kwargs
                    )
                self.feedback = 'sample'
                self.rollout(train_ml=None, train_rl=True, **kwargs)

            # print(self.rank, iter, self.loss)
            self.loss.backward()

            torch.nn.utils.clip_grad_norm_(self.vln_bert.parameters(), 40.)
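            # Only the policy network's gradients are clipped; the critic's are left untouched.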

            self.vln_bert_optimizer.step()
            self.critic_optimizer.step()

            if self.args.aug is None:
                print_progress(iter, n_iters + 1, prefix='Progress:', suffix='Complete', bar_length=50)

    def save(self, epoch, path):
        ''' Snapshot models '''
        the_dir, _ = os.path.split(path)
        os.makedirs(the_dir, exist_ok=True)
        states = {}

        def create_state(name, model, optimizer):
            states[name] = {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }

        all_tuple = [("vln_bert", self.vln_bert, self.vln_bert_optimizer),
                     ("critic", self.critic, self.critic_optimizer)]
        for param in all_tuple:
            create_state(*param)
        torch.save(states, path)

    def load(self, path):
        ''' Load model parameters (optimizer state is restored only when args.resume_optimizer is set). '''
        states = torch.load(path)

        def recover_state(name, model, optimizer):
            state = model.state_dict()
            model_keys = set(state.keys())
            load_keys = set(states[name]['state_dict'].keys())
            state_dict = states[name]['state_dict']
            if model_keys != load_keys:
                print("NOTICE: DIFFERENT KEYS IN THE MODEL AND THE CHECKPOINT")
                # Align the 'module.' prefix that DistributedDataParallel adds to parameter names.
                if not list(model_keys)[0].startswith('module.') and list(load_keys)[0].startswith('module.'):
                    state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
                if list(model_keys)[0].startswith('module.') and (not list(load_keys)[0].startswith('module.')):
                    state_dict = {'module.' + k: v for k, v in state_dict.items()}
                # Drop checkpoint entries that have no counterpart in the current model.
                same_state_dict = {}
                extra_keys = []
                for k, v in state_dict.items():
                    if k in model_keys:
                        same_state_dict[k] = v
                    else:
                        extra_keys.append(k)
                state_dict = same_state_dict
                print('Extra keys in state_dict: %s' % (', '.join(extra_keys)))
            state.update(state_dict)
            model.load_state_dict(state)
            if self.args.resume_optimizer:
                optimizer.load_state_dict(states[name]['optimizer'])

        all_tuple = [("vln_bert", self.vln_bert, self.vln_bert_optimizer),
                     ("critic", self.critic, self.critic_optimizer)]
        for param in all_tuple:
            recover_state(*param)
        return states['vln_bert']['epoch'] - 1
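

# Usage sketch (illustrative only; the class and attribute names below are
# hypothetical stand-ins for whatever concrete agent builds on this base):
#
#     class ObjectNavAgent(Seq2SeqAgent):
#         def _build_model(self):
#             self.vln_bert = VLNBert(self.args).cuda()
#             self.critic = Critic(self.args).cuda()
#
#     agent = ObjectNavAgent(args, train_env, rank=rank)
#     agent.train(log_every, feedback='sample')
#     agent.save(epoch, 'snapshots/agent.pt')
#     start_epoch = agent.load('snapshots/agent.pt')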