import torch
import os
import time
import json
import random
import numpy as np
import pickle as pkl
from collections import defaultdict

from utils import read_vocab, write_vocab, build_vocab, padding_idx, timeSince, read_img_features, print_progress
import utils
from env import R2RBatch
from agent import Seq2SeqAgent
from eval import Evaluation
from param import args

# import warnings
# warnings.filterwarnings("ignore")

from tensorboardX import SummaryWriter

from vlnbert.vlnbert_init import get_tokenizer

log_dir = 'snap/%s' % args.name
if not os.path.exists(log_dir):
    os.makedirs(log_dir)

PLACE365_FEATURES = 'img_features/ResNet-152-places365.tsv'
features = PLACE365_FEATURES

feedback_method = args.feedback  # teacher or sample

print(args); print('')


''' train the listener '''
def train(train_env, tok, n_iters, log_every=2000, val_envs={}, aug_env=None):
    writer = SummaryWriter(log_dir=log_dir)
    listner = Seq2SeqAgent(train_env, "", tok, args.maxAction)

    record_file = open(os.path.join(log_dir, 'train_log.txt'), 'a')
    record_file.write(str(args) + '\n\n')
    record_file.close()

    start_iter = 0
    if args.load is not None:
        if args.aug is None:
            start_iter = listner.load(os.path.join(args.load))
            print("\nLOAD the model from {}, iteration {}".format(args.load, start_iter))
        else:
            load_iter = listner.load(os.path.join(args.load))
            print("\nLOAD the model from {}, iteration {}".format(args.load, load_iter))

    start = time.time()
    print('\nListener training starts, start iteration: %s' % str(start_iter))

    # Track the best validation result (SPL, with success rate as tie-breaker)
    best_val = {'val_unseen': {"spl": 0., "sr": 0., "sspl": 0., "state": "", 'update': False}}

    for idx in range(start_iter, start_iter + n_iters, log_every):
        listner.logs = defaultdict(list)
        interval = min(log_every, n_iters - idx)
        iter = idx + interval

        # Train for log_every interval
        if aug_env is None:
            listner.env = train_env
            listner.train(interval, feedback=feedback_method)  # Train interval iters
        else:
            jdx_length = len(range(interval // 2))
            for jdx in range(interval // 2):
                # Train with GT data
                listner.env = train_env
                args.ml_weight = 0.2
                listner.train(1, feedback=feedback_method)

                # Train with Augmented data
                listner.env = aug_env
                args.ml_weight = 0.2
                listner.train(1, feedback=feedback_method)

                print_progress(jdx, jdx_length, prefix='Progress:', suffix='Complete', bar_length=50)

        # Log the training stats to tensorboard
        total = max(sum(listner.logs['total']), 1)
        length = max(len(listner.logs['critic_loss']), 1)
        critic_loss = sum(listner.logs['critic_loss']) / total
        RL_loss = sum(listner.logs['RL_loss']) / max(len(listner.logs['RL_loss']), 1)
        IL_loss = sum(listner.logs['IL_loss']) / max(len(listner.logs['IL_loss']), 1)
        REF_loss = sum(listner.logs['REF_loss']) / max(len(listner.logs['REF_loss']), 1)
        entropy = sum(listner.logs['entropy']) / total

        writer.add_scalar("loss/critic", critic_loss, idx)
        writer.add_scalar("policy_entropy", entropy, idx)
        writer.add_scalar("loss/RL_loss", RL_loss, idx)
        writer.add_scalar("loss/IL_loss", IL_loss, idx)
        writer.add_scalar("loss/REF_loss", REF_loss, idx)
        writer.add_scalar("total_actions", total, idx)
        writer.add_scalar("max_length", length, idx)
        print("total_actions", total, ", max_length", length)

        # Run validation
        loss_str = "iter %d IL_loss %.2f RL_loss %.2f REF_loss %.2f critic_loss %.2f entropy %.2f" % (
            iter, IL_loss, RL_loss, REF_loss, critic_loss, entropy)
        for env_name, (env, evaluator) in val_envs.items():
            listner.env = env

            # Get validation distance from goal under test evaluation conditions
            listner.test(use_dropout=False, feedback='argmax', iters=None)
            result = listner.get_results()
            score_summary, _ = evaluator.score(result)
            loss_str += ", %s " % env_name
            for metric, val in score_summary.items():
                if metric in ['spl']:
                    writer.add_scalar("spl/%s" % env_name, val, idx)
                    if env_name in best_val:
                        if val > best_val[env_name]['spl']:
                            best_val[env_name]['spl'] = val
                            best_val[env_name]['update'] = True
                        elif (val == best_val[env_name]['spl']) and (score_summary['success_rate'] > best_val[env_name]['sr']):
                            best_val[env_name]['spl'] = val
                            best_val[env_name]['update'] = True
                loss_str += ', %s: %.4f' % (metric, val)

        record_file = open(os.path.join(log_dir, 'train_log.txt'), 'a')
        record_file.write(loss_str + '\n')
        record_file.close()

        for env_name in best_val:
            if best_val[env_name]['update']:
                best_val[env_name]['state'] = 'Iter %d %s' % (iter, loss_str)
                best_val[env_name]['update'] = False
                listner.save(idx, os.path.join("snap", args.name, "state_dict", "best_%s" % (env_name)))
            else:
                listner.save(idx, os.path.join("snap", args.name, "state_dict", "latest_dict"))

        print(('%s (%d %d%%) %s' % (timeSince(start, float(iter)/n_iters), iter, float(iter)/n_iters*100, loss_str)))

        if iter % 1000 == 0:
            print("BEST RESULT TILL NOW")
            for env_name in best_val:
                print(env_name, best_val[env_name]['state'])

                record_file = open(os.path.join(log_dir, 'train_log.txt'), 'a')
                record_file.write('BEST RESULT TILL NOW: ' + env_name + ' | ' + best_val[env_name]['state'] + '\n')
                record_file.close()

    listner.save(idx, os.path.join("snap", args.name, "state_dict", "LAST_iter%d" % (idx)))


def valid(train_env, tok, val_envs={}):
    torch.set_grad_enabled(False)
    agent = Seq2SeqAgent(train_env, "", tok, args.maxAction)

    if args.load is None:
        args.load = os.path.join('snap', args.name, 'state_dict', 'best_val_unseen')
    print("Loaded the listener model at iter %d from %s" % (agent.load(args.load), args.load))

    record_path = os.path.join(log_dir, 'valid_log.txt')
    record_file = open(record_path, 'a')
    record_file.write(str(args) + '\n\n')
    record_file.close()

    for env_name, (env, evaluator) in val_envs.items():
        out_result_file = os.path.join(log_dir, "submit_%s.json" % (env_name))
        if os.path.exists(out_result_file):
            continue

        agent.logs = defaultdict(list)
        agent.env = env

        iters = None
        agent.test(use_dropout=False, feedback='argmax', iters=iters)
        result = agent.get_results()

        if args.submit:
            json.dump(
                result,
                open(out_result_file, 'w'),
                sort_keys=True, indent=4, separators=(',', ': ')
            )

        if env_name != '' and (not env_name.startswith('test')):
            score_summary, _ = evaluator.score(result)
            loss_str = "Env name: %s" % env_name
            for metric, val in score_summary.items():
                loss_str += ', %s: %.4f' % (metric, val)
            print(loss_str)

            record_file = open(record_path, 'a')
            record_file.write(loss_str + '\n')
            record_file.close()


def setup():
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)
    random.seed(0)
    np.random.seed(0)


def train_val(test_only=False):
    ''' Train on the training set, and validate on seen and unseen splits. '''
    setup()
    tok = get_tokenizer(args)

    feat_dict = read_img_features(features, test_only=test_only)
    # load object features
    with open('data/REVERIE/BBoxS/REVERIE_obj_feats.pkl', 'rb') as f_obj:
        obj_feats = pkl.load(f_obj)

    if test_only:
        featurized_scans = None
        val_env_names = ['val_train_seen']
    else:
        featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])
        val_env_names = ['val_seen', 'val_unseen']

    train_env = R2RBatch(feat_dict, obj_feats, batch_size=args.batchSize, splits=['train'], tokenizer=tok)
    from collections import OrderedDict

    if args.submit:
        val_env_names.append('test')
    else:
        pass

    val_envs = OrderedDict(
        ((split,
          (R2RBatch(feat_dict, obj_feats, batch_size=args.batchSize, splits=[split], tokenizer=tok),
           Evaluation([split], featurized_scans, tok)))
         for split in val_env_names)
    )
    val_envs = {key: value for key, value in val_envs.items() if len(value[0].data) > 0}

    if args.train == 'listener':
        train(train_env, tok, args.iters, log_every=args.log_every, val_envs=val_envs)
    elif args.train == 'validlistener':
        valid(train_env, tok, val_envs=val_envs)
    else:
        assert False


def train_val_augment(test_only=False):
    """
    Train the listener with the augmented data
    """
    setup()

    # Create a batch training environment that will also preprocess text
    tok_bert = get_tokenizer(args)

    # Load the env img features
    feat_dict = read_img_features(features, test_only=test_only)
    # load object features
    with open('data/REVERIE/BBoxS/REVERIE_obj_feats.pkl', 'rb') as f_obj:
        obj_feats = pkl.load(f_obj)

    if test_only:
        featurized_scans = None
        val_env_names = ['val_train_seen']
    else:
        featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])
        val_env_names = ['val_seen', 'val_unseen']
        # val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']

    # Load the augmentation data
    aug_path = args.aug

    # Create the training environment
    train_env = R2RBatch(feat_dict, obj_feats, batch_size=args.batchSize, splits=['train'], tokenizer=tok_bert)
    aug_env = R2RBatch(feat_dict, obj_feats, batch_size=args.batchSize, splits=[aug_path], tokenizer=tok_bert, name='aug')

    # Setup the validation data
    val_envs = {split: (R2RBatch(feat_dict, obj_feats, batch_size=args.batchSize, splits=[split], tokenizer=tok_bert),
                        Evaluation([split], featurized_scans, tok_bert))
                for split in val_env_names}
    val_envs = {key: value for key, value in val_envs.items() if len(value[0].data) > 0}

    # Start training
    train(train_env, tok_bert, args.iters, val_envs=val_envs, aug_env=aug_env, log_every=args.log_every)


if __name__ == "__main__":
    if args.train in ['listener', 'validlistener']:
        train_val(test_only=args.test_only)
    elif args.train == 'auglistener':
        train_val_augment(test_only=args.test_only)
    else:
        assert False