import torch

import os
import time
import json
import random
import numpy as np
from collections import defaultdict, OrderedDict

from utils import read_vocab, write_vocab, build_vocab, padding_idx, timeSince, read_img_features, print_progress
import utils
from env import R2RBatch
from agent import Seq2SeqAgent
from eval import Evaluation
from param import args

import warnings
warnings.filterwarnings("ignore")

from tensorboardX import SummaryWriter

from vlnbert.vlnbert_init import get_tokenizer

# Per-run snapshot directory for checkpoints / tensorboard events.
log_dir = 'snap/%s' % args.name
if not os.path.exists(log_dir):
    os.makedirs(log_dir)

# Plain-text run logs are appended under ./logs by train()/_append_log;
# create the directory up front instead of failing on the first write.
if not os.path.exists('./logs'):
    os.makedirs('./logs')

IMAGENET_FEATURES = 'img_features/ResNet-152-imagenet.tsv'
PLACE365_FEATURES = 'img_features/ResNet-152-places365.tsv'

if args.features == 'imagenet':
    features = IMAGENET_FEATURES
elif args.features == 'places365':
    features = PLACE365_FEATURES
else:
    # BUGFIX: the original chain left `features` undefined for any other
    # value, producing a confusing NameError much later in train_val().
    raise ValueError("Unknown feature set: %s" % args.features)

feedback_method = args.feedback  # teacher or sample

print(args)
print('')


def _append_log(text):
    """Append `text` to this run's plain-text log file (closed immediately)."""
    with open('./logs/' + args.name + '.txt', 'a') as record_file:
        record_file.write(text)


''' train the listener '''
def train(train_env, tok, n_iters, log_every=2000, val_envs={}, aug_env=None):
    """Train the listener agent, validating and checkpointing every `log_every` iters.

    Args:
        train_env: R2RBatch over the ground-truth 'train' split.
        tok: tokenizer passed through to the agent.
        n_iters: total number of training iterations to run.
        log_every: iterations between tensorboard logging / validation rounds.
        val_envs: mapping env_name -> (R2RBatch, Evaluation) for validation.
        aug_env: optional R2RBatch of augmented instructions; when given,
            training alternates one GT step and one augmented step.
    """
    writer = SummaryWriter(log_dir=log_dir)
    listner = Seq2SeqAgent(train_env, "", tok, args.maxAction)

    _append_log(str(args) + '\n\n')

    start_iter = 0
    if args.load is not None:
        if args.aug is None:
            start_iter = listner.load(os.path.join(args.load))
            # BUGFIX: the format string had one placeholder for two arguments,
            # so the iteration number was silently dropped from the message.
            print("\nLOAD the model from {}, iteration {}".format(args.load, start_iter))
        else:
            load_iter = listner.load(os.path.join(args.load))
            print("\nLOAD the model from {}, iteration {}".format(args.load, load_iter))

    start = time.time()
    print('\nListener training starts, start iteration: %s' % str(start_iter))

    # Best-so-far validation metrics per tracked split; 'update' flags that a
    # new best checkpoint should be written this round.
    best_val = {'val_unseen': {"spl": 0., "sr": 0., "state": "", 'update': False},
                'val_train_seen': {"spl": 0., "sr": 0., "state": "", 'update': False}}

    for idx in range(start_iter, start_iter + n_iters, log_every):
        listner.logs = defaultdict(list)
        interval = min(log_every, n_iters - idx)
        cur_iter = idx + interval  # renamed from `iter` to avoid shadowing the builtin

        # Train for `interval` iterations.
        if aug_env is None:
            listner.env = train_env
            listner.train(interval, feedback=feedback_method)
        else:
            # Alternate GT and augmented batches, one step each.
            jdx_length = len(range(interval // 2))
            for jdx in range(interval // 2):
                # Train with GT data
                listner.env = train_env
                args.ml_weight = 0.2
                listner.train(1, feedback=feedback_method)

                # Train with Augmented data
                listner.env = aug_env
                args.ml_weight = 0.2
                listner.train(1, feedback=feedback_method)

                print_progress(jdx, jdx_length, prefix='Progress:', suffix='Complete', bar_length=50)

        # Log the training stats to tensorboard (guard denominators against 0).
        total = max(sum(listner.logs['total']), 1)
        length = max(len(listner.logs['critic_loss']), 1)
        critic_loss = sum(listner.logs['critic_loss']) / total
        RL_loss = sum(listner.logs['RL_loss']) / max(len(listner.logs['RL_loss']), 1)
        IL_loss = sum(listner.logs['IL_loss']) / max(len(listner.logs['IL_loss']), 1)
        entropy = sum(listner.logs['entropy']) / total
        print("training:")
        print("total: {}, length: {}, critic_loss: {}, RL_loss: {}, IL_loss:{}, entropy: {}".format(
            total, length, critic_loss, RL_loss, IL_loss, entropy))
        writer.add_scalar("loss/critic", critic_loss, idx)
        writer.add_scalar("policy_entropy", entropy, idx)
        writer.add_scalar("loss/RL_loss", RL_loss, idx)
        writer.add_scalar("loss/IL_loss", IL_loss, idx)
        writer.add_scalar("total_actions", total, idx)
        writer.add_scalar("max_length", length, idx)

        # Run validation
        loss_str = "iter {}".format(cur_iter)
        for env_name, (env, evaluator) in val_envs.items():
            listner.env = env

            # Get validation distance from goal under test evaluation conditions
            listner.test(use_dropout=False, feedback='argmax', iters=None)
            result = listner.get_results()
            score_summary, _ = evaluator.score(result)
            loss_str += ", %s " % env_name
            for metric, val in score_summary.items():
                if metric in ['spl']:
                    writer.add_scalar("spl/%s" % env_name, val, idx)
                    if env_name in best_val:
                        sr = score_summary['success_rate']
                        if val > best_val[env_name]['spl']:
                            best_val[env_name]['spl'] = val
                            # BUGFIX: also record SR; originally it stayed 0.0
                            # forever, breaking the SPL tie-break below.
                            best_val[env_name]['sr'] = sr
                            best_val[env_name]['update'] = True
                        elif (val == best_val[env_name]['spl']) and (sr > best_val[env_name]['sr']):
                            best_val[env_name]['spl'] = val
                            best_val[env_name]['sr'] = sr
                            best_val[env_name]['update'] = True
                loss_str += ', %s: %.4f' % (metric, val)

        _append_log(loss_str + '\n')

        # Checkpoint: one "best" snapshot per tracked split, plus latest weights.
        for env_name in best_val:
            if best_val[env_name]['update']:
                best_val[env_name]['state'] = 'Iter %d %s' % (cur_iter, loss_str)
                best_val[env_name]['update'] = False
                listner.save(idx, os.path.join("snap", args.name, "state_dict", "best_%s" % (env_name)))
            else:
                listner.save(idx, os.path.join("snap", args.name, "state_dict", "latest_dict"))

        print(('%s (%d %d%%) %s' % (timeSince(start, float(cur_iter) / n_iters),
                                    cur_iter, float(cur_iter) / n_iters * 100, loss_str)))

        if cur_iter % log_every == 0:
            print("BEST RESULT TILL NOW")
            for env_name in best_val:
                print(env_name, best_val[env_name]['state'])
                _append_log('BEST RESULT TILL NOW: ' + env_name + ' | ' + best_val[env_name]['state'] + '\n')

    listner.save(idx, os.path.join("snap", args.name, "state_dict", "LAST_iter%d" % (idx)))


def valid(train_env, tok, val_envs={}):
    """Evaluate a trained listener (weights from args.load) on each val split.

    Prints a metric summary per split and, with --submit, dumps per-split
    result JSON files into the snapshot directory.
    """
    agent = Seq2SeqAgent(train_env, "", tok, args.maxAction)

    print("Loaded the listener model at iter %d from %s" % (agent.load(args.load), args.load))

    for env_name, (env, evaluator) in val_envs.items():
        agent.logs = defaultdict(list)
        agent.env = env

        iters = None
        agent.test(use_dropout=False, feedback='argmax', iters=iters)
        result = agent.get_results()

        if env_name != '':
            score_summary, _ = evaluator.score(result)
            loss_str = "Env name: %s" % env_name
            for metric, val in score_summary.items():
                loss_str += ', %s: %.4f' % (metric, val)
            print(loss_str)

        if args.submit:
            # BUGFIX: use a context manager; the original passed a bare
            # open(...) to json.dump and never closed the handle.
            with open(os.path.join(log_dir, "submit_%s.json" % env_name), 'w') as f:
                json.dump(result, f, sort_keys=True, indent=4, separators=(',', ': '))


def setup():
    """Seed torch / CUDA / random / numpy RNGs for reproducibility."""
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)
    random.seed(0)
    np.random.seed(0)


def train_val(test_only=False):
    ''' Train on the training set, and validate on seen and unseen splits. '''
    setup()
    tok = get_tokenizer(args)

    feat_dict = read_img_features(features, test_only=test_only)

    if test_only:
        featurized_scans = None
        val_env_names = ['val_train_seen']
    else:
        featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])
        # val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']
        val_env_names = ['val_train_seen']

    train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok)

    if args.submit:
        val_env_names.append('test')

    print("only {} in train_val()".format(val_env_names))

    val_envs = OrderedDict(
        ((split,
          (R2RBatch(feat_dict, batch_size=args.batchSize, splits=[split], tokenizer=tok),
           Evaluation([split], featurized_scans, tok)))
         for split in val_env_names)
    )

    if args.train == 'listener':
        train(train_env, tok, args.iters, val_envs=val_envs, log_every=1000)
    elif args.train == 'validlistener':
        valid(train_env, tok, val_envs=val_envs)
    else:
        assert False


def train_val_augment(test_only=False):
    """
    Train the listener with the augmented data
    """
    setup()

    # Create a batch training environment that will also preprocess text
    tok_bert = get_tokenizer(args)

    # Load the env img features
    feat_dict = read_img_features(features, test_only=test_only)

    if test_only:
        featurized_scans = None
        val_env_names = ['val_train_seen']
    else:
        featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])
        val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']

    # Load the augmentation data
    aug_path = args.aug

    # Create the training environment
    train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok_bert)
    aug_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=[aug_path], tokenizer=tok_bert, name='aug')

    # Setup the validation data
    val_envs = {split: (R2RBatch(feat_dict, batch_size=args.batchSize, splits=[split], tokenizer=tok_bert),
                        Evaluation([split], featurized_scans, tok_bert))
                for split in val_env_names}

    # Start training
    train(train_env, tok_bert, args.iters, val_envs=val_envs, aug_env=aug_env)


if __name__ == "__main__":
    device = torch.device('cuda')
    print(torch.cuda.get_device_name(device))

    if args.train in ['listener', 'validlistener']:
        train_val(test_only=args.test_only)
    elif args.train == 'auglistener':
        train_val_augment(test_only=args.test_only)
    else:
        assert False