# adversarial_AIRBERT/reverie_src/train.py


import torch
import os
import time
import json
import random
import numpy as np
import pickle as pkl
from collections import defaultdict, OrderedDict
from utils import read_vocab, write_vocab, build_vocab, padding_idx, timeSince, read_img_features, print_progress
import utils
from env import R2RBatch
from agent import Seq2SeqAgent
from eval import Evaluation
from param import args
# import warnings
# warnings.filterwarnings("ignore")
from tensorboardX import SummaryWriter
from vlnbert.vlnbert_init import get_tokenizer

log_dir = 'snap/%s' % args.name
if not os.path.exists(log_dir):
    os.makedirs(log_dir)

PLACE365_FEATURES = 'img_features/ResNet-152-places365.tsv'
features = PLACE365_FEATURES

feedback_method = args.feedback  # teacher or sample

print(args)
print('')

''' train the listener '''
def train(train_env, tok, n_iters, log_every=2000, val_envs={}, aug_env=None):
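    # train_env: R2RBatch over the GT 'train' split; aug_env: optional R2RBatch
    # over augmented instructions, alternated 1:1 with GT batches when given.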
    writer = SummaryWriter(log_dir=log_dir)
    listner = Seq2SeqAgent(train_env, "", tok, args.maxAction)

    record_file = open(os.path.join(log_dir, 'train_log.txt'), 'a')
    record_file.write(str(args) + '\n\n')
    record_file.close()

    start_iter = 0
    if args.load is not None:
        if args.aug is None:
            start_iter = listner.load(os.path.join(args.load))
            print("\nLOAD the model from {}, iteration {}".format(args.load, start_iter))
        else:
            load_iter = listner.load(os.path.join(args.load))
            print("\nLOAD the model from {}, iteration {}".format(args.load, load_iter))

    start = time.time()
    print('\nListener training starts, start iteration: %s' % str(start_iter))

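    # Best scores on val_unseen so far; checkpointing keys on 'sspl', with
    # success rate ('sr') breaking ties.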
    best_val = {'val_unseen': {"spl": 0., "sr": 0., "sspl": 0., "state": "", 'update': False}}

    for idx in range(start_iter, start_iter + n_iters, log_every):
        listner.logs = defaultdict(list)
        interval = min(log_every, n_iters - idx)
        iter = idx + interval

        # Train for log_every interval
        if aug_env is None:
            listner.env = train_env
            listner.train(interval, feedback=feedback_method)  # Train interval iters
        else:
            jdx_length = interval // 2
            for jdx in range(jdx_length):
                # Train with GT data
                listner.env = train_env
                args.ml_weight = 0.2
                listner.train(1, feedback=feedback_method)

                # Train with Augmented data
                listner.env = aug_env
                args.ml_weight = 0.2
                listner.train(1, feedback=feedback_method)

                print_progress(jdx, jdx_length, prefix='Progress:', suffix='Complete', bar_length=50)

        # Log the training stats to tensorboard
        total = max(sum(listner.logs['total']), 1)
        length = max(len(listner.logs['critic_loss']), 1)
        critic_loss = sum(listner.logs['critic_loss']) / total
        RL_loss = sum(listner.logs['RL_loss']) / max(len(listner.logs['RL_loss']), 1)
        IL_loss = sum(listner.logs['IL_loss']) / max(len(listner.logs['IL_loss']), 1)
        REF_loss = sum(listner.logs['REF_loss']) / max(len(listner.logs['REF_loss']), 1)
        entropy = sum(listner.logs['entropy']) / total
        writer.add_scalar("loss/critic", critic_loss, idx)
        writer.add_scalar("policy_entropy", entropy, idx)
        writer.add_scalar("loss/RL_loss", RL_loss, idx)
        writer.add_scalar("loss/IL_loss", IL_loss, idx)
        writer.add_scalar("loss/REF_loss", REF_loss, idx)
        writer.add_scalar("total_actions", total, idx)
        writer.add_scalar("max_length", length, idx)
        print("total_actions", total, ", max_length", length)

        # Run validation
        loss_str = "iter %d IL_loss %.2f RL_loss %.2f REF_loss %.2f critic_loss %.2f entropy %.2f" % (
            iter, IL_loss, RL_loss, REF_loss, critic_loss, entropy)
        for env_name, (env, evaluator) in val_envs.items():
            listner.env = env

            # Get validation distance from goal under test evaluation conditions
            listner.test(use_dropout=False, feedback='argmax', iters=None)
            result = listner.get_results()
            score_summary, _ = evaluator.score(result)
            loss_str += ", %s " % env_name
            for metric, val in score_summary.items():
                if metric in ['sspl']:
                    writer.add_scalar("sspl/%s" % env_name, val, idx)
                    if env_name in best_val:
                        if val > best_val[env_name]['sspl']:
                            best_val[env_name]['sspl'] = val
                            # also record the success rate so the tie-break below compares
                            # against the current best rather than the initial 0
                            best_val[env_name]['sr'] = score_summary['success_rate']
                            best_val[env_name]['update'] = True
                        elif (val == best_val[env_name]['sspl']) and (score_summary['success_rate'] > best_val[env_name]['sr']):
                            best_val[env_name]['sspl'] = val
                            best_val[env_name]['sr'] = score_summary['success_rate']
                            best_val[env_name]['update'] = True
                loss_str += ', %s: %.4f' % (metric, val)

        record_file = open(os.path.join(log_dir, 'train_log.txt'), 'a')
        record_file.write(loss_str + '\n')
        record_file.close()

        for env_name in best_val:
            if best_val[env_name]['update']:
                best_val[env_name]['state'] = 'Iter %d %s' % (iter, loss_str)
                best_val[env_name]['update'] = False
                listner.save(idx, os.path.join("snap", args.name, "state_dict", "best_%s" % (env_name)))
            else:
                listner.save(idx, os.path.join("snap", args.name, "state_dict", "latest_dict"))

        print(('%s (%d %d%%) %s' % (timeSince(start, float(iter) / n_iters),
                                    iter, float(iter) / n_iters * 100, loss_str)))

        if iter % 1000 == 0:
            print("BEST RESULT TILL NOW")
            for env_name in best_val:
                print(env_name, best_val[env_name]['state'])

                record_file = open(os.path.join(log_dir, 'train_log.txt'), 'a')
                record_file.write('BEST RESULT TILL NOW: ' + env_name + ' | ' + best_val[env_name]['state'] + '\n')
                record_file.close()

    listner.save(idx, os.path.join("snap", args.name, "state_dict", "LAST_iter%d" % (idx)))

def valid(train_env, tok, val_envs={}):
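    # Evaluation only: gradients are disabled globally and actions are decoded
    # greedily (feedback='argmax').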
    torch.set_grad_enabled(False)

    agent = Seq2SeqAgent(train_env, "", tok, args.maxAction)
    if args.load is None:
        args.load = os.path.join('snap', args.name, 'state_dict', 'best_val_unseen')
    print("Loaded the listener model at iter %d from %s" % (agent.load(args.load), args.load))

    record_path = os.path.join(log_dir, 'valid_log.txt')
    record_file = open(record_path, 'a')
    record_file.write(str(args) + '\n\n')
    record_file.close()

    for env_name, (env, evaluator) in val_envs.items():
        out_result_file = os.path.join(log_dir, "submit_%s.json" % (env_name))
        if os.path.exists(out_result_file):
            continue

        agent.logs = defaultdict(list)
        agent.env = env

        iters = None
        agent.test(use_dropout=False, feedback='argmax', iters=iters)
        result = agent.get_results()

        if args.submit:
            json.dump(
                result,
                open(out_result_file, 'w'),
                sort_keys=True, indent=4, separators=(',', ': ')
            )

        if env_name != '' and (not env_name.startswith('test')):
            score_summary, _ = evaluator.score(result)
            loss_str = "Env name: %s" % env_name
            for metric, val in score_summary.items():
                loss_str += ', %s: %.4f' % (metric, val)
            print(loss_str)

            record_file = open(record_path, 'a')
            record_file.write(loss_str + '\n')
            record_file.close()

def setup():
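    # Fix all random seeds so runs are reproducible.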
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)
    random.seed(0)
    np.random.seed(0)

def train_val(test_only=False):
    ''' Train on the training set, and validate on seen and unseen splits. '''
    setup()
    tok = get_tokenizer(args)

    feat_dict = read_img_features(features, test_only=test_only)

    # load object features
    with open('data/REVERIE/BBoxS/REVERIE_obj_feats.pkl', 'rb') as f_obj:
        obj_feats = pkl.load(f_obj)

    if test_only:
        featurized_scans = None
        val_env_names = ['val_train_seen']
    else:
        featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])
        val_env_names = ['val_seen', 'val_unseen']

    train_env = R2RBatch(feat_dict, obj_feats, batch_size=args.batchSize, splits=['train'], tokenizer=tok)

    if args.submit:
        val_env_names.append('test')

    val_envs = OrderedDict(
        (split,
         (R2RBatch(feat_dict, obj_feats, batch_size=args.batchSize, splits=[split], tokenizer=tok),
          Evaluation([split], featurized_scans, tok)))
        for split in val_env_names
    )
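    # Keep only splits that actually loaded data (e.g. when test_only trims the
    # feature set, some splits may come back empty).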
    val_envs = {key: value for key, value in val_envs.items() if len(value[0].data) > 0}

    if args.train == 'listener':
        train(train_env, tok, args.iters, log_every=args.log_every, val_envs=val_envs)
        # train(train_env, tok, args.iters, log_every=100, val_envs=val_envs)
    elif args.train == 'validlistener':
        valid(train_env, tok, val_envs=val_envs)
    else:
        assert False

def train_val_augment(test_only=False):
    """
    Train the listener with the augmented data
    """
    setup()

    # Create a batch training environment that will also preprocess text
    tok_bert = get_tokenizer(args)

    # Load the env img features
    feat_dict = read_img_features(features, test_only=test_only)

    # load object features
    with open('data/REVERIE/BBoxS/REVERIE_obj_feats.pkl', 'rb') as f_obj:
        obj_feats = pkl.load(f_obj)

    if test_only:
        featurized_scans = None
        val_env_names = ['val_train_seen']
    else:
        featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])
        val_env_names = ['val_seen', 'val_unseen']
        # val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']

    # Load the augmentation data
    aug_path = args.aug

    # Create the training environments
    train_env = R2RBatch(feat_dict, obj_feats, batch_size=args.batchSize, splits=['train'], tokenizer=tok_bert)
    aug_env = R2RBatch(feat_dict, obj_feats, batch_size=args.batchSize, splits=[aug_path], tokenizer=tok_bert, name='aug')
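    # The 'aug' environment iterates the extra instruction split passed via
    # args.aug (presumably speaker-generated data); name='aug' keeps its batches
    # separate from the GT 'train' split.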

    # Setup the validation data
    val_envs = {split: (R2RBatch(feat_dict, obj_feats, batch_size=args.batchSize, splits=[split], tokenizer=tok_bert),
                        Evaluation([split], featurized_scans, tok_bert))
                for split in val_env_names}
    val_envs = {key: value for key, value in val_envs.items() if len(value[0].data) > 0}

    # Start training
    train(train_env, tok_bert, args.iters, val_envs=val_envs, aug_env=aug_env,
          log_every=args.log_every)

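# Hypothetical invocations (exact flag spellings are defined in param.py):
#   python train.py --train listener --name my_run --iters 100000
#   python train.py --train auglistener --name my_aug_run --aug <aug_split> --iters 200000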
if __name__ == "__main__":
    if args.train in ['listener', 'validlistener']:
        train_val(test_only=args.test_only)
    elif args.train == 'auglistener':
        train_val_augment(test_only=args.test_only)
    else:
        assert False