# adversarial_VLNBERT/r2r_src/train.py
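"""Training entry point for the R2R listener agent.

Dispatches on args.train: 'listener' (train and validate), 'validlistener'
(evaluation only), and 'auglistener' (training with augmented instructions).
Example invocations (hypothetical; the exact flag spellings are defined in
param.py and assumed here):

    python r2r_src/train.py --train listener --features imagenet --name my_run
    python r2r_src/train.py --train auglistener --features places365 --aug <aug_split> --name my_run
"""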

import torch
import os
import time
import json
import random
import numpy as np
from collections import defaultdict, OrderedDict
# from speaker import Speaker
from utils import read_vocab, write_vocab, build_vocab, padding_idx, timeSince, read_img_features, print_progress
# from utils import Tokenizer
import utils
from env import R2RBatch
from agent import Seq2SeqAgent
from eval import Evaluation
from param import args
import warnings
warnings.filterwarnings("ignore")
from tensorboardX import SummaryWriter
from vlnbert.vlnbert_model import get_tokenizer

log_dir = 'snap/%s' % args.name
if not os.path.exists(log_dir):
    os.makedirs(log_dir)

TRAIN_VOCAB = 'data/train_vocab.txt'
TRAINVAL_VOCAB = 'data/trainval_vocab.txt'
IMAGENET_FEATURES = 'img_features/ResNet-152-imagenet.tsv'
PLACE365_FEATURES = 'img_features/ResNet-152-places365.tsv'

if args.features == 'imagenet':
    features = IMAGENET_FEATURES
elif args.features == 'places365':
    features = PLACE365_FEATURES

feedback_method = args.feedback  # 'teacher' or 'sample'

print(args); print('')
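
# train() and valid() below expect val_envs to map a split name to a
# (R2RBatch environment, Evaluation scorer) pair; see train_val() and
# train_val_augment() for how these dicts are assembled.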

''' train the listener '''
def train(train_env, tok, n_iters, log_every=1000, val_envs={}, aug_env=None):
    writer = SummaryWriter(log_dir=log_dir)
    listner = Seq2SeqAgent(train_env, "", tok, args.maxAction)

    with open('./logs/' + args.name + '.txt', 'a') as record_file:
        record_file.write(str(args) + '\n\n')

    start_iter = 0
    if args.load is not None:
        if args.aug is None:
            start_iter = listner.load(os.path.join(args.load))
            print("\nLOAD the model from {}, iteration {}".format(args.load, start_iter))
        else:
            load_iter = listner.load(os.path.join(args.load))
            print("\nLOAD the model from {}, iteration {}".format(args.load, load_iter))
    # elif args.load_pretrain is not None:
    #     print("LOAD the pretrained model from %s" % args.load_pretrain)
    #     listner.load_pretrain(os.path.join(args.load_pretrain))
    #     print("Pretrained model loaded\n")

    start = time.time()
    print('\nListener training starts, start iteration: %s' % str(start_iter))

    best_val = {'val_unseen': {"spl": 0., "sr": 0., "state": "", 'update': False}}

    for idx in range(start_iter, start_iter + n_iters, log_every):
        listner.logs = defaultdict(list)
        interval = min(log_every, n_iters - idx)
        iter = idx + interval

        # Train for log_every interval
        if aug_env is None:  # The default training process
            listner.env = train_env
            listner.train(interval, feedback=feedback_method)  # Train interval iters
            print('-----------default training process no accumulate_grad')
        else:
            if args.accumulate_grad:  # default False
                pass  # the accumulate_grad path is a no-op in this script
            else:
                # Alternate one batch of ground-truth data with one batch of
                # augmented data for each pair of iterations in the interval.
                jdx_length = len(range(interval // 2))
                for jdx in range(interval // 2):
                    # Train with GT data
                    listner.env = train_env
                    args.ml_weight = 0.2
                    listner.train(1, feedback=feedback_method)

                    # Train with Augmented data
                    listner.env = aug_env
                    args.ml_weight = 0.2
                    listner.train(1, feedback=feedback_method)

                    print_progress(jdx, jdx_length, prefix='Progress:', suffix='Complete', bar_length=50)

        # Log the training stats to tensorboard
        total = max(sum(listner.logs['total']), 1)
        length = max(len(listner.logs['critic_loss']), 1)
        critic_loss = sum(listner.logs['critic_loss']) / total  # / length / args.batchSize
        entropy = sum(listner.logs['entropy']) / total  # / length / args.batchSize
        predict_loss = sum(listner.logs['us_loss']) / max(len(listner.logs['us_loss']), 1)
        writer.add_scalar("loss/critic", critic_loss, idx)
        writer.add_scalar("policy_entropy", entropy, idx)
        writer.add_scalar("loss/unsupervised", predict_loss, idx)
        writer.add_scalar("total_actions", total, idx)
        writer.add_scalar("max_length", length, idx)
        print("total_actions", total, ", max_length", length)

        # Run validation
        loss_str = "iter {}".format(iter)
        for env_name, (env, evaluator) in val_envs.items():
            listner.env = env

            # Get validation loss under the same conditions as training
            iters = None if args.fast_train or env_name != 'train' else 20  # 20 * 64 = 1280

            # Get validation distance from goal under test evaluation conditions
            listner.test(use_dropout=False, feedback='argmax', iters=iters)
            result = listner.get_results()
            score_summary, _ = evaluator.score(result)
            loss_str += ", %s " % env_name
            for metric, val in score_summary.items():
                if metric in ['spl']:
                    writer.add_scalar("spl/%s" % env_name, val, idx)
                    if env_name in best_val:
                        # A new best is a strictly higher SPL, or an equal SPL with a
                        # higher success rate. The recorded SR is updated alongside SPL
                        # so the tie-break compares against the current best rather
                        # than its initial value.
                        if val > best_val[env_name]['spl']:
                            best_val[env_name]['spl'] = val
                            best_val[env_name]['sr'] = score_summary['success_rate']
                            best_val[env_name]['update'] = True
                        elif (val == best_val[env_name]['spl']) and (score_summary['success_rate'] > best_val[env_name]['sr']):
                            best_val[env_name]['sr'] = score_summary['success_rate']
                            best_val[env_name]['update'] = True
                loss_str += ', %s: %.4f' % (metric, val)

        with open('./logs/' + args.name + '.txt', 'a') as record_file:
            record_file.write(loss_str + '\n')

        for env_name in best_val:
            if best_val[env_name]['update']:
                best_val[env_name]['state'] = 'Iter %d %s' % (iter, loss_str)
                best_val[env_name]['update'] = False
                listner.save(idx, os.path.join("snap", args.name, "state_dict", "best_%s" % (env_name)))
            else:
                listner.save(idx, os.path.join("snap", args.name, "state_dict", "latest_dict"))

        print('%s (%d %d%%) %s' % (timeSince(start, float(iter) / n_iters),
                                   iter, float(iter) / n_iters * 100, loss_str))

        if iter % 1000 == 0:
            print("BEST RESULT TILL NOW")
            for env_name in best_val:
                print(env_name, best_val[env_name]['state'])
                with open('./logs/' + args.name + '.txt', 'a') as record_file:
                    record_file.write('BEST RESULT TILL NOW: ' + env_name + ' | ' + best_val[env_name]['state'] + '\n')

    listner.save(idx, os.path.join("snap", args.name, "state_dict", "LAST_iter%d" % (idx)))

def valid(train_env, tok, val_envs={}):
    agent = Seq2SeqAgent(train_env, "", tok, args.maxAction)

    print("Loaded the listener model at iter %d from %s" % (agent.load(args.load), args.load))

    for env_name, (env, evaluator) in val_envs.items():
        agent.logs = defaultdict(list)
        agent.env = env

        iters = None
        agent.test(use_dropout=False, feedback='argmax', iters=iters)
        result = agent.get_results()

        if env_name != '':
            score_summary, _ = evaluator.score(result)
            loss_str = "Env name: %s" % env_name
            for metric, val in score_summary.items():
                loss_str += ', %s: %.4f' % (metric, val)
            print(loss_str)

        if args.submit:
            with open(os.path.join(log_dir, "submit_%s.json" % env_name), 'w') as f:
                json.dump(result, f, sort_keys=True, indent=4, separators=(',', ': '))
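
# setup() fixes the RNG seeds so runs are repeatable; note this alone does not
# make CUDA execution fully deterministic (cuDNN nondeterminism is left enabled).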
def setup():
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)
    random.seed(0)
    np.random.seed(0)

    # Check for vocabs
    if not os.path.exists(TRAIN_VOCAB):
        write_vocab(build_vocab(splits=['train']), TRAIN_VOCAB)
    # if not os.path.exists(TRAINVAL_VOCAB):
    #     write_vocab(build_vocab(splits=['train', 'val_seen', 'val_unseen']), TRAINVAL_VOCAB)

def train_val(test_only=False):
    ''' Train on the training set, and validate on seen and unseen splits. '''
    setup()

    # Create a batch training environment that will also preprocess text
    vocab = read_vocab(TRAIN_VOCAB)
    # tok = Tokenizer(vocab=vocab, encoding_length=args.maxInput)
    tok = get_tokenizer()

    feat_dict = read_img_features(features, test_only=test_only)

    if test_only:
        featurized_scans = None
        val_env_names = ['val_train_seen']
    else:
        featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])
        # val_env_names = ['val_seen', 'val_unseen']
        val_env_names = ['train', 'val_unseen']
        # val_env_names = ['val_unseen']

    train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok)

    if args.submit:
        val_env_names.append('test')
    else:
        pass
        # val_env_names.append('train')

    val_envs = OrderedDict(
        (split,
         (R2RBatch(feat_dict, batch_size=args.batchSize, splits=[split], tokenizer=tok),
          Evaluation([split], featurized_scans, tok)))
        for split in val_env_names
    )

    if args.train == 'listener':
        train(train_env, tok, args.iters, val_envs=val_envs)
        # train(train_env, tok, 1000, val_envs=val_envs, log_every=10)
    elif args.train == 'validlistener':
        valid(train_env, tok, val_envs=val_envs)
    else:
        assert False

def train_val_augment(test_only=False):
    """
    Train the listener with the augmented data
    """
    setup()

    # Create a batch training environment that will also preprocess text
    vocab = read_vocab(TRAIN_VOCAB)
    tok_bert = get_tokenizer()

    # Load the env img features
    feat_dict = read_img_features(features, test_only=test_only)

    if test_only:
        featurized_scans = None
        val_env_names = ['val_train_seen']
    else:
        featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])
        val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']

    # Load the augmentation data
    aug_path = args.aug

    # Create the training environment
    train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok_bert)
    aug_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=[aug_path], tokenizer=tok_bert, name='aug')

    # Setup the validation data
    val_envs = {split: (R2RBatch(feat_dict, batch_size=args.batchSize, splits=[split], tokenizer=tok_bert),
                        Evaluation([split], featurized_scans, tok_bert))
                for split in val_env_names}

    # Start training
    train(train_env, tok_bert, args.iters, val_envs=val_envs, aug_env=aug_env)

if __name__ == "__main__":
    if args.train in ['listener', 'validlistener']:
        train_val(test_only=args.test_only)
    elif args.train == 'auglistener':
        train_val_augment(test_only=args.test_only)
    else:
        assert False