import torch
import os
import time
import json
import random
import numpy as np
from collections import defaultdict
# from speaker import Speaker
from utils import read_vocab, write_vocab, build_vocab, padding_idx, timeSince, read_img_features, print_progress
# from utils import Tokenizer
import utils
from env import R2RBatch
from agent import Seq2SeqAgent
from eval import Evaluation
from param import args

import warnings
warnings.filterwarnings("ignore")

from tensorboardX import SummaryWriter
from vlnbert.vlnbert_model import get_tokenizer

# Per-experiment snapshot directory; created eagerly so later saves cannot fail.
log_dir = 'snap/%s' % args.name
if not os.path.exists(log_dir):
    os.makedirs(log_dir)

TRAIN_VOCAB = 'data/train_vocab.txt'
TRAINVAL_VOCAB = 'data/trainval_vocab.txt'
IMAGENET_FEATURES = 'img_features/ResNet-152-imagenet.tsv'
PLACE365_FEATURES = 'img_features/ResNet-152-places365.tsv'

# Select the pre-extracted image-feature file from the command line.
if args.features == 'imagenet':
    features = IMAGENET_FEATURES
elif args.features == 'places365':
    features = PLACE365_FEATURES

feedback_method = args.feedback  # teacher or sample

print(args); print('')


''' train the listener '''
def train(train_env, tok, n_iters, log_every=1000, val_envs=None, aug_env=None):
    """Train the listener agent, validating and checkpointing every ``log_every`` iters.

    Args:
        train_env: R2RBatch environment over the 'train' split.
        tok: tokenizer passed through to the agent.
        n_iters: total number of training iterations.
        log_every: interval (in iters) between logging/validation rounds.
        val_envs: mapping env_name -> (R2RBatch, Evaluation) used for validation.
        aug_env: optional R2RBatch over augmented data; when given, training
            alternates 1 iter of ground-truth data with 1 iter of augmented data.
    """
    # NOTE: was a mutable default ``val_envs={}``; None-sentinel is backward compatible.
    if val_envs is None:
        val_envs = {}

    writer = SummaryWriter(log_dir=log_dir)
    listner = Seq2SeqAgent(train_env, "", tok, args.maxAction)

    # Record the full argument set at the top of this run's log file.
    with open('./logs/' + args.name + '.txt', 'a') as record_file:
        record_file.write(str(args) + '\n\n')

    start_iter = 0
    if args.load is not None:
        if args.aug is None:
            # Resuming plain training: continue counting from the checkpoint's iter.
            start_iter = listner.load(args.load)
            # BUGFIX: format string was missing the second placeholder, so the
            # iteration number was silently dropped from the message.
            print("\nLOAD the model from {}, iteration {}".format(args.load, start_iter))
        else:
            # Aug training restarts the iteration counter; only report the loaded iter.
            load_iter = listner.load(args.load)
            print("\nLOAD the model from {}, iteration {}".format(args.load, load_iter))

    start = time.time()
    print('\nListener training starts, start iteration: %s' % str(start_iter))

    # Track the best val_unseen checkpoint by SPL, tie-broken by success rate.
    best_val = {'val_unseen': {"spl": 0., "sr": 0., "state": "", 'update': False}}

    for idx in range(start_iter, start_iter + n_iters, log_every):
        listner.logs = defaultdict(list)
        interval = min(log_every, n_iters - idx)
        # Renamed from ``iter`` to avoid shadowing the builtin.
        cur_iter = idx + interval

        # Train for log_every interval
        if aug_env is None:
            # The default training process
            listner.env = train_env
            listner.train(interval, feedback=feedback_method)  # Train interval iters
            print('-----------default training process no accumulate_grad')
        else:
            if args.accumulate_grad:  # default False
                # Gradient accumulation is not implemented in this script.
                pass
            else:
                jdx_length = len(range(interval // 2))
                for jdx in range(interval // 2):
                    # Train with GT data
                    listner.env = train_env
                    args.ml_weight = 0.2
                    listner.train(1, feedback=feedback_method)

                    # Train with Augmented data
                    listner.env = aug_env
                    args.ml_weight = 0.2
                    listner.train(1, feedback=feedback_method)

                    print_progress(jdx, jdx_length, prefix='Progress:', suffix='Complete', bar_length=50)

        # Log the training stats to tensorboard.  ``max(..., 1)`` guards the
        # divisions against an empty log (e.g. zero-length interval).
        total = max(sum(listner.logs['total']), 1)
        length = max(len(listner.logs['critic_loss']), 1)
        critic_loss = sum(listner.logs['critic_loss']) / total
        entropy = sum(listner.logs['entropy']) / total
        predict_loss = sum(listner.logs['us_loss']) / max(len(listner.logs['us_loss']), 1)
        writer.add_scalar("loss/critic", critic_loss, idx)
        writer.add_scalar("policy_entropy", entropy, idx)
        writer.add_scalar("loss/unsupervised", predict_loss, idx)
        writer.add_scalar("total_actions", total, idx)
        writer.add_scalar("max_length", length, idx)
        print("total_actions", total, ", max_length", length)

        # Run validation
        loss_str = "iter {}".format(cur_iter)
        for env_name, (env, evaluator) in val_envs.items():
            listner.env = env

            # Get validation loss under the same conditions as training
            iters = None if args.fast_train or env_name != 'train' else 20  # 20 * 64 = 1280

            # Get validation distance from goal under test evaluation conditions
            listner.test(use_dropout=False, feedback='argmax', iters=iters)
            result = listner.get_results()
            score_summary, _ = evaluator.score(result)
            loss_str += ", %s " % env_name
            for metric, val in score_summary.items():
                if metric in ['spl']:
                    writer.add_scalar("spl/%s" % env_name, val, idx)
                    if env_name in best_val:
                        if val > best_val[env_name]['spl']:
                            best_val[env_name]['spl'] = val
                            best_val[env_name]['update'] = True
                        elif (val == best_val[env_name]['spl']) and (score_summary['success_rate'] > best_val[env_name]['sr']):
                            # Equal SPL: prefer the checkpoint with higher success rate.
                            best_val[env_name]['spl'] = val
                            best_val[env_name]['update'] = True
                loss_str += ', %s: %.4f' % (metric, val)

        with open('./logs/' + args.name + '.txt', 'a') as record_file:
            record_file.write(loss_str + '\n')

        # Save a "best" checkpoint per tracked env, otherwise refresh the latest one.
        for env_name in best_val:
            if best_val[env_name]['update']:
                best_val[env_name]['state'] = 'Iter %d %s' % (cur_iter, loss_str)
                best_val[env_name]['update'] = False
                listner.save(idx, os.path.join("snap", args.name, "state_dict", "best_%s" % (env_name)))
            else:
                listner.save(idx, os.path.join("snap", args.name, "state_dict", "latest_dict"))

        print(('%s (%d %d%%) %s' % (timeSince(start, float(cur_iter) / n_iters),
                                    cur_iter, float(cur_iter) / n_iters * 100, loss_str)))

        if cur_iter % 1000 == 0:
            print("BEST RESULT TILL NOW")
            for env_name in best_val:
                print(env_name, best_val[env_name]['state'])
                with open('./logs/' + args.name + '.txt', 'a') as record_file:
                    record_file.write('BEST RESULT TILL NOW: ' + env_name + ' | ' + best_val[env_name]['state'] + '\n')

    listner.save(idx, os.path.join("snap", args.name, "state_dict", "LAST_iter%d" % (idx)))


def valid(train_env, tok, val_envs=None):
    """Evaluate a loaded listener checkpoint on each validation environment.

    Prints per-env metric summaries and, when ``args.submit`` is set, dumps the
    raw trajectories to ``snap/<name>/submit_<env>.json`` for leaderboard upload.
    """
    if val_envs is None:
        val_envs = {}

    agent = Seq2SeqAgent(train_env, "", tok, args.maxAction)

    print("Loaded the listener model at iter %d from %s" % (agent.load(args.load), args.load))

    for env_name, (env, evaluator) in val_envs.items():
        agent.logs = defaultdict(list)
        agent.env = env

        iters = None
        agent.test(use_dropout=False, feedback='argmax', iters=iters)
        result = agent.get_results()

        if env_name != '':
            score_summary, _ = evaluator.score(result)
            loss_str = "Env name: %s" % env_name
            for metric, val in score_summary.items():
                loss_str += ', %s: %.4f' % (metric, val)
            print(loss_str)

        if args.submit:
            json.dump(
                result,
                open(os.path.join(log_dir, "submit_%s.json" % env_name), 'w'),
                sort_keys=True, indent=4, separators=(',', ': ')
            )


def setup():
    """Seed all RNGs for reproducibility and ensure the training vocab exists."""
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)
    random.seed(0)
    np.random.seed(0)
    # Check for vocabs
    if not os.path.exists(TRAIN_VOCAB):
        write_vocab(build_vocab(splits=['train']), TRAIN_VOCAB)
    # if not os.path.exists(TRAINVAL_VOCAB):
    #     write_vocab(build_vocab(splits=['train','val_seen','val_unseen']), TRAINVAL_VOCAB)


def train_val(test_only=False):
    ''' Train on the training set, and validate on seen and unseen splits. '''
    setup()

    # Create a batch training environment that will also preprocess text
    vocab = read_vocab(TRAIN_VOCAB)
    # tok = Tokenizer(vocab=vocab, encoding_length=args.maxInput)
    tok = get_tokenizer()

    feat_dict = read_img_features(features, test_only=test_only)

    if test_only:
        featurized_scans = None
        val_env_names = ['val_train_seen']
    else:
        featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])
        # val_env_names = ['val_seen', 'val_unseen']
        val_env_names = ['train', 'val_seen', 'val_unseen']
        # val_env_names = ['val_unseen']

    train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok)
    from collections import OrderedDict

    if args.submit:
        val_env_names.append('test')
    else:
        pass
        # val_env_names.append('train')

    # OrderedDict keeps validation in the declared split order.
    val_envs = OrderedDict(
        ((split,
          (R2RBatch(feat_dict, batch_size=args.batchSize, splits=[split], tokenizer=tok),
           Evaluation([split], featurized_scans, tok))
          )
         for split in val_env_names
         )
    )

    if args.train == 'listener':
        train(train_env, tok, args.iters, val_envs=val_envs)
        # train(train_env, tok, 1000, val_envs=val_envs, log_every=10)
    elif args.train == 'validlistener':
        valid(train_env, tok, val_envs=val_envs)
    else:
        assert False


def train_val_augment(test_only=False):
    """Train the listener with the augmented data."""
    setup()

    # Create a batch training environment that will also preprocess text
    vocab = read_vocab(TRAIN_VOCAB)
    tok_bert = get_tokenizer()

    # Load the env img features
    feat_dict = read_img_features(features, test_only=test_only)

    if test_only:
        featurized_scans = None
        val_env_names = ['val_train_seen']
    else:
        featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])
        val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']

    # Load the augmentation data
    aug_path = args.aug

    # Create the training environment
    train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok_bert)
    aug_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=[aug_path], tokenizer=tok_bert, name='aug')

    # Setup the validation data
    val_envs = {split: (R2RBatch(feat_dict, batch_size=args.batchSize, splits=[split], tokenizer=tok_bert),
                        Evaluation([split], featurized_scans, tok_bert))
                for split in val_env_names}

    # Start training
    train(train_env, tok_bert, args.iters, val_envs=val_envs, aug_env=aug_env)


if __name__ == "__main__":
    if args.train in ['listener', 'validlistener']:
        train_val(test_only=args.test_only)
    elif args.train == 'auglistener':
        train_val_augment(test_only=args.test_only)
    else:
        assert False