import os
import json
import time
import numpy as np
from collections import defaultdict

import torch
from tensorboardX import SummaryWriter

from utils.misc import set_random_seed
from utils.logger import write_to_record_file, print_progress, timeSince
from utils.distributed import init_distributed, is_default_gpu
from utils.distributed import all_gather, merge_dist_results
from utils.data import ImageFeaturesDB

from r2r.data_utils import construct_instrs
from r2r.env import R2RNavBatch
from r2r.parser import parse_args

from models.vlnbert_init import get_tokenizer
from r2r.agent import GMapNavAgent


def build_dataset(args, rank=0, is_test=False):
    tok = get_tokenizer(args)

    feat_db = ImageFeaturesDB(args.img_ft_file, args.image_feat_size)

    dataset_class = R2RNavBatch

    # Because we don't use a distributed sampler here, each process shuffles
    # the data with a different seed so that different processes see
    # different training examples.
    if args.aug is not None:
        aug_instr_data = construct_instrs(
            args.anno_dir, args.dataset, [args.aug],
            tokenizer=args.tokenizer, max_instr_len=args.max_instr_len,
            is_test=is_test
        )
        aug_env = dataset_class(
            feat_db, aug_instr_data, args.connectivity_dir,
            batch_size=args.batch_size, angle_feat_size=args.angle_feat_size,
            seed=args.seed + rank, sel_data_idxs=None, name='aug',
        )
    else:
        aug_env = None

    train_instr_data = construct_instrs(
        args.anno_dir, args.dataset, ['train'],
        tokenizer=args.tokenizer, max_instr_len=args.max_instr_len,
        is_test=is_test
    )
    train_env = dataset_class(
        feat_db, train_instr_data, args.connectivity_dir,
        batch_size=args.batch_size, angle_feat_size=args.angle_feat_size,
        seed=args.seed + rank, sel_data_idxs=None, name='train',
    )

    val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']
    if args.dataset == 'r4r' and (not args.test):
        val_env_names[-1] = 'val_unseen_sampled'

    if args.submit and args.dataset != 'r4r':
        val_env_names.append('test')

    val_envs = {}
    for split in val_env_names:
        val_instr_data = construct_instrs(
            args.anno_dir, args.dataset, [split],
            tokenizer=args.tokenizer, max_instr_len=args.max_instr_len,
            is_test=is_test
        )
        val_env = dataset_class(
            feat_db, val_instr_data, args.connectivity_dir,
            batch_size=args.batch_size, angle_feat_size=args.angle_feat_size,
            seed=args.seed + rank,
            # Shard the validation data across processes in distributed runs.
            sel_data_idxs=None if args.world_size < 2 else (rank, args.world_size),
            name=split,
        )
        val_envs[split] = val_env

    return train_env, val_envs, aug_env


def train(args, train_env, val_envs, aug_env=None, rank=-1):
    default_gpu = is_default_gpu(args)

    if default_gpu:
        with open(os.path.join(args.log_dir, 'training_args.json'), 'w') as outf:
            json.dump(vars(args), outf, indent=4)
        writer = SummaryWriter(log_dir=args.log_dir)
        record_file = os.path.join(args.log_dir, 'train.txt')
        write_to_record_file(str(args) + '\n\n', record_file)

    agent_class = GMapNavAgent
    listener = agent_class(args, train_env, rank=rank)

    # Resume from a checkpoint if one is given.
    start_iter = 0
    if args.resume_file is not None:
        start_iter = listener.load(args.resume_file)
        if default_gpu:
            write_to_record_file(
                "\nLOAD the model from {}, iteration {}".format(args.resume_file, start_iter),
                record_file
            )

    # Evaluate once before training starts.
    if args.eval_first:
        loss_str = "validation before training"
        for env_name, env in val_envs.items():
            listener.env = env
            # Get validation distance from goal under test evaluation conditions
            listener.test(use_dropout=False, feedback='argmax', iters=None)
            preds = listener.get_results()
            # Gather distributed results.
            preds = merge_dist_results(all_gather(preds))
            if default_gpu:
                score_summary, _ = env.eval_metrics(preds)
                loss_str += ", %s " % env_name
                for metric, val in score_summary.items():
                    loss_str += ', %s: %.2f' % (metric, val)
        if default_gpu:
            write_to_record_file(loss_str, record_file)

    start = time.time()
    if default_gpu:
        write_to_record_file(
            '\nListener training starts, start iteration: %s' % str(start_iter), record_file
        )

    # Track the best checkpoint (selected by SPL) on the main unseen split.
    best_val = {'val_unseen': {"spl": 0., "sr": 0., "state": ""}}
    if args.dataset == 'r4r':
        best_val = {'val_unseen_sampled': {"spl": 0., "sr": 0., "state": ""}}

    for idx in range(start_iter, start_iter + args.iters, args.log_every):
        listener.logs = defaultdict(list)
        interval = min(args.log_every, args.iters - idx)
        iter = idx + interval

        # Train for one logging interval.
        if aug_env is None:
            listener.env = train_env
            listener.train(interval, feedback=args.feedback)
        else:
            # Alternate one iteration on ground-truth data with one on augmented data.
            jdx_length = len(range(interval // 2))
            for jdx in range(interval // 2):
                listener.env = train_env
                listener.train(1, feedback=args.feedback)

                listener.env = aug_env
                listener.train(1, feedback=args.feedback)

                if default_gpu:
                    print_progress(jdx, jdx_length, prefix='Progress:', suffix='Complete', bar_length=50)

        if default_gpu:
            # Log the training stats to tensorboard.
            total = max(sum(listener.logs['total']), 1)         # RL: total valid actions for all examples in the batch
            length = max(len(listener.logs['critic_loss']), 1)  # RL: total (max length) in the batch
            critic_loss = sum(listener.logs['critic_loss']) / total
            policy_loss = sum(listener.logs['policy_loss']) / total
            RL_loss = sum(listener.logs['RL_loss']) / max(len(listener.logs['RL_loss']), 1)
            IL_loss = sum(listener.logs['IL_loss']) / max(len(listener.logs['IL_loss']), 1)
            entropy = sum(listener.logs['entropy']) / total
            writer.add_scalar("loss/critic", critic_loss, idx)
            writer.add_scalar("policy_entropy", entropy, idx)
            writer.add_scalar("loss/RL_loss", RL_loss, idx)
            writer.add_scalar("loss/IL_loss", IL_loss, idx)
            writer.add_scalar("total_actions", total, idx)
            writer.add_scalar("max_length", length, idx)
            write_to_record_file(
                "\ntotal_actions %d, max_length %d, entropy %.4f, IL_loss %.4f, RL_loss %.4f, policy_loss %.4f, critic_loss %.4f" % (
                    total, length, entropy, IL_loss, RL_loss, policy_loss, critic_loss),
                record_file
            )

        # Run validation.
        loss_str = "iter {}".format(iter)
        for env_name, env in val_envs.items():
            listener.env = env

            # Get validation distance from goal under test evaluation conditions
            listener.test(use_dropout=False, feedback='argmax', iters=None)
            preds = listener.get_results()
            preds = merge_dist_results(all_gather(preds))

            if default_gpu:
                score_summary, _ = env.eval_metrics(preds)
                loss_str += ", %s " % env_name
                for metric, val in score_summary.items():
                    loss_str += ', %s: %.2f' % (metric, val)
                    writer.add_scalar('%s/%s' % (metric, env_name), score_summary[metric], idx)

                # Select the best model by SPL.
                if env_name in best_val:
                    if score_summary['spl'] >= best_val[env_name]['spl']:
                        best_val[env_name]['spl'] = score_summary['spl']
                        best_val[env_name]['sr'] = score_summary['sr']
                        best_val[env_name]['state'] = 'Iter %d %s' % (iter, loss_str)
                        listener.save(idx, os.path.join(args.ckpt_dir, "best_%s" % (env_name)))

        if default_gpu:
            listener.save(idx, os.path.join(args.ckpt_dir, "latest_dict"))

            write_to_record_file(
                ('%s (%d %d%%) %s' % (timeSince(start, float(iter) / args.iters),
                                      iter, float(iter) / args.iters * 100, loss_str)),
                record_file
            )
            write_to_record_file("BEST RESULT TILL NOW", record_file)
            for env_name in best_val:
                write_to_record_file(
                    env_name + ' | ' + best_val[env_name]['state'], record_file
                )


def valid(args, train_env, val_envs, rank=-1):
    default_gpu = is_default_gpu(args)

    agent_class = GMapNavAgent
    agent = agent_class(args, train_env, rank=rank)

    if args.resume_file is not None:
        print("Loaded the listener model at iter %d from %s" % (
            agent.load(args.resume_file), args.resume_file))

    if default_gpu:
        with open(os.path.join(args.log_dir, 'validation_args.json'), 'w') as outf:
            json.dump(vars(args), outf, indent=4)
        record_file = os.path.join(args.log_dir, 'valid.txt')
        write_to_record_file(str(args) + '\n\n', record_file)

    for env_name, env in val_envs.items():
        prefix = 'submit' if args.detailed_output is False else 'detail'
        # Skip splits whose prediction files already exist.
        if os.path.exists(os.path.join(args.pred_dir, "%s_%s.json" % (prefix, env_name))):
            continue
        agent.logs = defaultdict(list)
        agent.env = env

        iters = None
        start_time = time.time()
        agent.test(use_dropout=False, feedback='argmax', iters=iters)
        print(env_name, 'cost time: %.2fs' % (time.time() - start_time))
        preds = agent.get_results(detailed_output=args.detailed_output)
        preds = merge_dist_results(all_gather(preds))

        if default_gpu:
            if 'test' not in env_name:
                score_summary, _ = env.eval_metrics(preds)
                loss_str = "Env name: %s" % env_name
                for metric, val in score_summary.items():
                    loss_str += ', %s: %.2f' % (metric, val)
                write_to_record_file(loss_str + '\n', record_file)

            if args.submit:
                json.dump(
                    preds,
                    open(os.path.join(args.pred_dir, "%s_%s.json" % (prefix, env_name)), 'w'),
                    sort_keys=True, indent=4, separators=(',', ': ')
                )


def main():
    args = parse_args()

    if args.world_size > 1:
        rank = init_distributed(args)
        torch.cuda.set_device(args.local_rank)
    else:
        rank = 0

    set_random_seed(args.seed + rank)
    train_env, val_envs, aug_env = build_dataset(args, rank=rank, is_test=args.test)

    if not args.test:
        train(args, train_env, val_envs, aug_env=aug_env, rank=rank)
    else:
        valid(args, train_env, val_envs, rank=rank)


if __name__ == '__main__':
    main()
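
# Example invocations (a sketch, not part of the original script): the flag
# names below are assumptions inferred from the `args.*` attributes used
# above; the authoritative definitions live in r2r/parser.py.
#
# Single-process training (runs train()):
#   python r2r/main.py --dataset r2r --iters 100000 --log_every 2000
#
# Evaluation / submission only (runs valid(); writes submit_<split>.json
# when --submit is set):
#   python r2r/main.py --dataset r2r --test --resume_file <ckpt> --submit
#
# Multi-GPU runs (args.world_size > 1) are expected to go through a
# torch.distributed launcher that provides args.local_rank, e.g.:
#   python -m torch.distributed.launch --nproc_per_node=2 r2r/main.py ...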