# adversarial_VLNDUET/map_nav_src/soon/main.py

import os
import json
import time
import numpy as np
from collections import defaultdict

import torch
from tensorboardX import SummaryWriter

from utils.misc import set_random_seed
from utils.logger import write_to_record_file, print_progress, timeSince
from utils.distributed import init_distributed, is_default_gpu
from utils.distributed import all_gather, merge_dist_results

from models.vlnbert_init import get_tokenizer

from soon.agent_obj import SoonGMapObjectNavAgent
from utils.data import ImageFeaturesDB
from soon.data_utils import ObjectFeatureDB, construct_instrs
from soon.env import SoonObjectNavBatch
from soon.parser import parse_args


def build_dataset(args, rank=0, is_test=True):
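    """Build the training, validation, and (optional) augmentation environments.

    Returns (train_env, val_envs, aug_env): val_envs maps each split name to its
    environment; aug_env is None unless args.aug is given. Validation splits are
    subsampled during training (is_test=False) to speed up evaluation.
    """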
    tok = get_tokenizer(args)

    feat_db = ImageFeaturesDB(args.img_ft_file, args.image_feat_size)
    obj_db = ObjectFeatureDB(args.obj_ft_file, args.obj_feat_size)

    dataset_class = SoonObjectNavBatch

    # Because we do not use a distributed sampler here, each process must
    # shuffle the data with a different seed so that different processes
    # work on different training examples.
    if args.aug is not None:
        aug_instr_data = construct_instrs(
            args.anno_dir, args.dataset, [args.aug], instr_type=args.instr_type,
            tokenizer=args.tokenizer, max_instr_len=args.max_instr_len
        )
        aug_env = dataset_class(
            feat_db, obj_db, aug_instr_data, args.connectivity_dir,
            batch_size=args.batch_size, max_objects=args.max_objects,
            angle_feat_size=args.angle_feat_size, is_train=True,
            seed=args.seed + rank, sel_data_idxs=None, name='aug',
            multi_endpoints=args.multi_endpoints, multi_startpoints=args.multi_startpoints,
        )
    else:
        aug_env = None

    train_instr_data = construct_instrs(
        args.anno_dir, args.dataset, ['train'], instr_type=args.instr_type,
        tokenizer=args.tokenizer, max_instr_len=args.max_instr_len,
    )
    train_env = dataset_class(
        feat_db, obj_db, train_instr_data, args.connectivity_dir,
        batch_size=args.batch_size, max_objects=args.max_objects,
        angle_feat_size=args.angle_feat_size, seed=args.seed + rank,
        sel_data_idxs=None, name='train', is_train=True,
        multi_endpoints=args.multi_endpoints, multi_startpoints=args.multi_startpoints,
    )

    val_env_names = ['val_train', 'val_unseen_instrs', 'val_unseen_house']
    if args.submit:
        val_env_names.append('test')
        val_env_names.append('test_v2')

    val_envs = {}
    for split in val_env_names:
        val_instr_data = construct_instrs(
            args.anno_dir, args.dataset, [split], instr_type=args.instr_type,
            tokenizer=args.tokenizer, max_instr_len=args.max_instr_len,
        )
        if split == 'val_train':
            val_instr_data = val_instr_data[:100:2]
        elif not is_test:
            val_instr_data = val_instr_data[::5]
        val_env = dataset_class(
            feat_db, obj_db, val_instr_data, args.connectivity_dir, batch_size=args.batch_size * 2,
            angle_feat_size=args.angle_feat_size, seed=args.seed + rank,
            sel_data_idxs=None if args.world_size < 2 else (rank, args.world_size), name=split,
            max_objects=None, multi_endpoints=False, multi_startpoints=False, is_train=False,
        )  # evaluation uses all objects
        val_envs[split] = val_env

    return train_env, val_envs, aug_env


def train(args, train_env, val_envs, aug_env=None, rank=-1):
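    """Train the agent for args.iters iterations, alternating ground-truth and
    augmented batches when aug_env is given. Every args.log_every iterations,
    log training stats, run validation, and checkpoint both the latest model
    and the best model per validation split (selected by SPL).
    """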
    default_gpu = is_default_gpu(args)

    if default_gpu:
        with open(os.path.join(args.log_dir, 'training_args.json'), 'w') as outf:
            json.dump(vars(args), outf, indent=4)
        writer = SummaryWriter(log_dir=args.log_dir)
        record_file = os.path.join(args.log_dir, 'train.txt')
        write_to_record_file(str(args) + '\n\n', record_file)

    agent_class = SoonGMapObjectNavAgent
    listner = agent_class(args, train_env, rank=rank)

    # resume file
    start_iter = 0
    if args.resume_file is not None:
        start_iter = listner.load(args.resume_file)
        if default_gpu:
            write_to_record_file(
                "\nLOAD the model from {}, iteration {}".format(args.resume_file, start_iter),
                record_file
            )

    # first evaluation
    if args.eval_first:
        loss_str = "validation before training"
        for env_name, env in val_envs.items():
            listner.env = env
            # Get validation distance from goal under test evaluation conditions
            listner.test(use_dropout=False, feedback='argmax', iters=None)
            preds = listner.get_results()
            # gather distributed results
            preds = merge_dist_results(all_gather(preds))
            if default_gpu:
                score_summary, _ = env.eval_metrics(preds)
                loss_str += ", %s " % env_name
                for metric, val in score_summary.items():
                    loss_str += ', %s: %.2f' % (metric, val)
        if default_gpu:
            write_to_record_file(loss_str, record_file)

    start = time.time()
    if default_gpu:
        write_to_record_file(
            '\nListener training starts, start iteration: %s' % str(start_iter), record_file
        )

    best_val = {
        'val_unseen_instrs': {'spl': 0., 'sr': 0., 'state': ''},
        'val_unseen_house': {'spl': 0., 'sr': 0., 'state': ''},
    }
    for idx in range(start_iter, start_iter + args.iters, args.log_every):
        listner.logs = defaultdict(list)
        interval = min(args.log_every, args.iters - idx)
        iter = idx + interval

        # Train for log_every interval
        if aug_env is None:
            listner.env = train_env
            listner.train(interval, feedback=args.feedback)  # Train interval iters
        else:
            jdx_length = len(range(interval // 2))
            for jdx in range(interval // 2):
                # Train with GT data
                listner.env = train_env
                # args.ml_weight = 0.2
                listner.train(1, feedback=args.feedback)

                # Train with Augmented data
                listner.env = aug_env
                # args.ml_weight = 0.2
                listner.train(1, feedback=args.feedback)

                if default_gpu:
                    print_progress(jdx, jdx_length, prefix='Progress:', suffix='Complete', bar_length=50)

        if default_gpu:
            # Log the training stats to tensorboard
            total = max(sum(listner.logs['total']), 1)         # RL: total valid actions for all examples in the batch
            length = max(len(listner.logs['critic_loss']), 1)  # RL: total (max length) in the batch
            critic_loss = sum(listner.logs['critic_loss']) / total
            policy_loss = sum(listner.logs['policy_loss']) / total
            RL_loss = sum(listner.logs['RL_loss']) / max(len(listner.logs['RL_loss']), 1)
            IL_loss = sum(listner.logs['IL_loss']) / max(len(listner.logs['IL_loss']), 1)
            Point_loss = sum(listner.logs['Point_loss']) / max(len(listner.logs['Point_loss']), 1)
            entropy = sum(listner.logs['entropy']) / total
            writer.add_scalar("loss/critic", critic_loss, idx)
            writer.add_scalar("policy_entropy", entropy, idx)
            writer.add_scalar("loss/RL_loss", RL_loss, idx)
            writer.add_scalar("loss/IL_loss", IL_loss, idx)
            writer.add_scalar("loss/Point_loss", Point_loss, idx)
            writer.add_scalar("total_actions", total, idx)
            writer.add_scalar("max_length", length, idx)
            write_to_record_file(
                "\ntotal_actions %d, max_length %d, entropy %.4f, IL_loss %.4f, Point_loss %.4f, RL_loss %.4f, policy_loss %.4f, critic_loss %.4f" % (
                    total, length, entropy, IL_loss, Point_loss, RL_loss, policy_loss, critic_loss),
                record_file
            )

        # Run validation
        loss_str = "iter {}".format(iter)
        for env_name, env in val_envs.items():
            listner.env = env
            # Get validation distance from goal under test evaluation conditions
            listner.test(use_dropout=False, feedback='argmax', iters=None)
            preds = listner.get_results()
            preds = merge_dist_results(all_gather(preds))
            if default_gpu:
                score_summary, _ = env.eval_metrics(preds)
                loss_str += ", %s " % env_name
                for metric, val in score_summary.items():
                    loss_str += ', %s: %.2f' % (metric, val)
                    writer.add_scalar('%s/%s' % (metric, env_name), score_summary[metric], idx)
                # select model by spl
                if env_name in best_val:
                    if score_summary['spl'] >= best_val[env_name]['spl']:
                        best_val[env_name]['spl'] = score_summary['spl']
                        best_val[env_name]['sr'] = score_summary['sr']
                        best_val[env_name]['state'] = 'Iter %d %s' % (iter, loss_str)
                        listner.save(idx, os.path.join(args.ckpt_dir, "best_%s" % (env_name)))

        if default_gpu:
            listner.save(idx, os.path.join(args.ckpt_dir, "latest_dict"))
            write_to_record_file(
                ('%s (%d %d%%) %s' % (timeSince(start, float(iter) / args.iters), iter, float(iter) / args.iters * 100, loss_str)),
                record_file
            )
            write_to_record_file("BEST RESULT TILL NOW", record_file)
            for env_name in best_val:
                write_to_record_file(env_name + ' | ' + best_val[env_name]['state'], record_file)


def valid(args, train_env, val_envs, rank=-1):
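    """Evaluate the (optionally resumed) agent on every validation environment,
    record the metrics, and, if args.submit is set, dump the predictions as
    submission JSON files in args.pred_dir.
    """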
    default_gpu = is_default_gpu(args)

    agent_class = SoonGMapObjectNavAgent
    agent = agent_class(args, train_env, rank=rank)

    if args.resume_file is not None:
        print("Loaded the listener model at iter %d from %s\n" % (agent.load(args.resume_file), args.resume_file))

    if default_gpu:
        with open(os.path.join(args.log_dir, 'validation_args.json'), 'w') as outf:
            json.dump(vars(args), outf, indent=4)
        record_file = os.path.join(args.log_dir, 'valid.txt')
        write_to_record_file(str(args) + '\n\n', record_file)

    for env_name, env in val_envs.items():
        # Skip splits whose submission file already exists
        if os.path.exists(os.path.join(args.pred_dir, "submit_%s.json" % env_name)):
            continue
        agent.logs = defaultdict(list)
        agent.env = env

        iters = None
        start_time = time.time()
        agent.test(use_dropout=False, feedback='argmax', iters=iters)
        print(env_name, 'cost time: %.2fs' % (time.time() - start_time))
        preds = agent.get_results()
        preds = merge_dist_results(all_gather(preds))

        if default_gpu:
            if 'test_v2' not in env_name:
                score_summary, _ = env.eval_metrics(preds)
                loss_str = "Env name: %s" % env_name
                for metric, val in score_summary.items():
                    loss_str += ', %s: %.2f' % (metric, val)
                write_to_record_file(loss_str + '\n', record_file)

            if args.submit:
                json.dump(
                    preds,
                    open(os.path.join(args.pred_dir, "submit_%s.json" % env_name), 'w'),
                    sort_keys=True, indent=4, separators=(',', ': ')
                )


def main():
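    """Entry point: parse arguments, set up (optionally distributed) devices
    and seeds, build the datasets, then dispatch to train() or valid().
    """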
    args = parse_args()

    if args.world_size > 1:
        rank = init_distributed(args)
        torch.cuda.set_device(args.local_rank)
    else:
        rank = 0

    set_random_seed(args.seed + rank)
    train_env, val_envs, aug_env = build_dataset(args, rank=rank, is_test=args.test)

    if not args.test:
        train(args, train_env, val_envs, aug_env=aug_env, rank=rank)
    else:
        valid(args, train_env, val_envs, rank=rank)


if __name__ == '__main__':
    main()
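
# A minimal launch sketch (hypothetical paths and values; the actual flag names
# are defined in soon/parser.py and may differ):
#
#   python soon/main.py --dataset soon --test \
#       --resume_file <path/to/checkpoint> --submit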