diff --git a/README.md b/README.md index 7579876..a11cbe5 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,15 @@ # Think Global, Act Local: Dual-scale GraphTransformer for Vision-and-Language Navigation -This repository is the official implementation of [Think Global, Act Local: Dual-scale GraphTransformer for Vision-and-Language Navigatio](). +This repository is the official implementation of [Think Global, Act Local: Dual-scale GraphTransformer for Vision-and-Language Navigation](https://arxiv.org/abs/2202.11742). + +Winner of the [ICCV 2021 Workshop Human Interaction for Robotic Navigation REVERIE & SOON Challenges](https://human-interaction4robotic-navigation.github.io/challenge.html). + +Project webpage: [https://cshizhe.github.io/projects/vln_duet.html](https://cshizhe.github.io/projects/vln_duet.html). + +Following language instructions to navigate in unseen environments is a challenging problem for autonomous embodied agents. The agent not only needs to ground languages in visual scenes, but also should explore the environment to reach its target. In this work, we propose a dual-scale graph transformer (DUET) for joint long-term action planning and fine-grained cross-modal understanding. We build a topological map on-the-fly to enable efficient exploration in global action space. To balance the complexity of large action space reasoning and fine-grained language grounding, we dynamically combine a fine-scale encoding over local observations and a coarse-scale encoding on a global map via graph transformers. The proposed approach, DUET, significantly outperforms state-of-the-art methods on goal-oriented vision-and-language navigation (VLN) benchmarks REVERIE and SOON. It also improves the success rate on the fine-grained VLN benchmark R2R. + +![framework](files/teaser.png) + ## Requirements @@ -16,7 +25,7 @@ conda activate vlnduet pip install -r requirements.txt ``` -3. Download data from [Dropbox](https://www.dropbox.com/s/7bijvxdw3rf451c/datasets.tar.gz?dl=0), including processed annotations, features and pretrained models. Put the data in `datasets' directory. +3. Download data from [Dropbox](https://www.dropbox.com/s/7bijvxdw3rf451c/datasets.tar.gz?dl=0), including processed annotations, features and pretrained models of REVERIE, SOON, R2R and R4R datasets. Put the data in `datasets' directory. 4. Download pretrained lxmert ``` @@ -25,19 +34,17 @@ wget https://nlp.cs.unc.edu/data/model_LXRT.pth -P datasets/pretrained ``` ## Pretraining + Combine behavior cloning and auxiliary proxy tasks in pretraining: ```pretrain cd pretrain_src -bash run_reverie.sh # (run_soon.sh, run_r2r.sh) +bash run_reverie.sh # (run_soon.sh, run_r2r.sh, run_r4r.sh) ``` ## Fine-tuning & Evaluation -Combine behavior cloning and auxiliary proxy tasks in pretraining: +Use pseudo interative demonstrator to fine-tune the model: ```finetune cd map_nav_src bash scripts/run_reverie.sh # (run_soon.sh, run_r2r.sh) ``` - -## Examples -Video examples can be found [here](https://www.dropbox.com/sh/g8vqygz7fgerg9s/AAAZ3gd9WdReUgRezxLnb1f_a?dl=0). diff --git a/files/teaser.png b/files/teaser.png new file mode 100644 index 0000000..83a061c Binary files /dev/null and b/files/teaser.png differ diff --git a/map_nav_src/r2r/agent.py b/map_nav_src/r2r/agent.py index 35e2cb4..f43e7ba 100644 --- a/map_nav_src/r2r/agent.py +++ b/map_nav_src/r2r/agent.py @@ -236,6 +236,54 @@ class GMapNavAgent(Seq2SeqAgent): return torch.from_numpy(a).cuda() + def _teacher_action_r4r( + self, obs, vpids, ended, visited_masks=None, imitation_learning=False, t=None, traj=None + ): + """R4R is not the shortest path. The goal location can be visited nodes. + """ + a = np.zeros(len(obs), dtype=np.int64) + for i, ob in enumerate(obs): + if ended[i]: # Just ignore this index + a[i] = self.args.ignoreid + else: + if imitation_learning: + assert ob['viewpoint'] == ob['gt_path'][t] + if t == len(ob['gt_path']) - 1: + a[i] = 0 # stop + else: + goal_vp = ob['gt_path'][t + 1] + for j, vpid in enumerate(vpids[i]): + if goal_vp == vpid: + a[i] = j + break + else: + if ob['viewpoint'] == ob['gt_path'][-1]: + a[i] = 0 # Stop if arrived + else: + scan = ob['scan'] + cur_vp = ob['viewpoint'] + min_idx, min_dist = self.args.ignoreid, float('inf') + for j, vpid in enumerate(vpids[i]): + if j > 0 and ((visited_masks is None) or (not visited_masks[i][j])): + if self.args.expert_policy == 'ndtw': + dist = - cal_dtw( + self.env.shortest_distances[scan], + sum(traj[i]['path'], []) + self.env.shortest_paths[scan][ob['viewpoint']][vpid][1:], + ob['gt_path'], + threshold=3.0 + )['nDTW'] + elif self.args.expert_policy == 'spl': + # dist = min([self.env.shortest_distances[scan][vpid][end_vp] for end_vp in ob['gt_end_vps']]) + dist = self.env.shortest_distances[scan][vpid][ob['gt_path'][-1]] \ + + self.env.shortest_distances[scan][cur_vp][vpid] + if dist < min_dist: + min_dist = dist + min_idx = j + a[i] = min_idx + if min_idx == self.args.ignoreid: + print('scan %s: all vps are searched' % (scan)) + return torch.from_numpy(a).cuda() + def make_equiv_action(self, a_t, gmaps, obs, traj=None): """ Interface between Panoramic view and Egocentric view @@ -355,10 +403,22 @@ class GMapNavAgent(Seq2SeqAgent): if train_ml is not None: # Supervised training - nav_targets = self._teacher_action( - obs, nav_vpids, ended, - visited_masks=nav_inputs['gmap_visited_masks'] if self.args.fusion != 'local' else None - ) + if self.args.dataset == 'r2r': + # nav_targets = self._teacher_action( + # obs, nav_vpids, ended, + # visited_masks=nav_inputs['gmap_visited_masks'] if self.args.fusion != 'local' else None + # ) + nav_targets = self._teacher_action_r4r( + obs, nav_vpids, ended, + visited_masks=nav_inputs['gmap_visited_masks'] if self.args.fusion != 'local' else None, + imitation_learning=(self.feedback=='teacher'), t=t, traj=traj + ) + elif self.args.dataset == 'r4r': + nav_targets = self._teacher_action_r4r( + obs, nav_vpids, ended, + visited_masks=nav_inputs['gmap_visited_masks'] if self.args.fusion != 'local' else None, + imitation_learning=(self.feedback=='teacher'), t=t, traj=traj + ) # print(t, nav_logits, nav_targets) ml_loss += self.criterion(nav_logits, nav_targets) # print(t, 'ml_loss', ml_loss.item(), self.criterion(nav_logits, nav_targets).item()) diff --git a/map_nav_src/r2r/data_utils.py b/map_nav_src/r2r/data_utils.py index 384eff1..abbdec8 100644 --- a/map_nav_src/r2r/data_utils.py +++ b/map_nav_src/r2r/data_utils.py @@ -19,6 +19,10 @@ def load_instr_datasets(anno_dir, dataset, splits, tokenizer, is_test=True): if split == 'val_train_seen': new_data = new_data[:50] + if not is_test: + if dataset == 'r4r' and split == 'val_unseen': + ridxs = np.random.permutation(len(new_data))[:200] + new_data = [new_data[ridx] for ridx in ridxs] else: # augmented data print('\nLoading augmented data %s for pretraining...' % os.path.basename(split)) with open(split) as f: diff --git a/map_nav_src/r2r/main_nav.py b/map_nav_src/r2r/main_nav.py index 4ea9f8c..5484f61 100644 --- a/map_nav_src/r2r/main_nav.py +++ b/map_nav_src/r2r/main_nav.py @@ -59,8 +59,10 @@ def build_dataset(args, rank=0, is_test=False): # val_env_names = ['val_train_seen'] val_env_names = ['val_train_seen', 'val_seen', 'val_unseen'] + if args.dataset == 'r4r' and (not args.test): + val_env_names[-1] == 'val_unseen_sampled' - if args.submit: + if args.submit and args.dataset != 'r4r': val_env_names.append('test') val_envs = {} @@ -129,6 +131,8 @@ def train(args, train_env, val_envs, aug_env=None, rank=-1): ) best_val = {'val_unseen': {"spl": 0., "sr": 0., "state":""}} + if args.dataset == 'r4r': + best_val = {'val_unseen_sampled': {"spl": 0., "sr": 0., "state":""}} for idx in range(start_iter, start_iter+args.iters, args.log_every): listner.logs = defaultdict(list) diff --git a/map_nav_src/r2r/parser.py b/map_nav_src/r2r/parser.py index 5efc214..b35cbcb 100644 --- a/map_nav_src/r2r/parser.py +++ b/map_nav_src/r2r/parser.py @@ -6,7 +6,7 @@ def parse_args(): parser = argparse.ArgumentParser(description="") parser.add_argument('--root_dir', type=str, default='../datasets') - parser.add_argument('--dataset', type=str, default='r2r', choices=['r2r']) + parser.add_argument('--dataset', type=str, default='r2r', choices=['r2r', 'r4r']) parser.add_argument('--output_dir', type=str, default='default', help='experiment id') parser.add_argument('--seed', type=int, default=0) diff --git a/map_nav_src/scripts/run_r2r.sh b/map_nav_src/scripts/run_r2r.sh index ea433bb..c6e0a42 100644 --- a/map_nav_src/scripts/run_r2r.sh +++ b/map_nav_src/scripts/run_r2r.sh @@ -1,3 +1,4 @@ +DATA_ROOT=../datasets train_alg=dagger @@ -56,7 +57,7 @@ flag="--root_dir ${DATA_ROOT} # train CUDA_VISIBLE_DEVICES='0' python r2r/main_nav.py $flag \ --tokenizer bert \ - --bert_ckpt_file '' \ + --bert_ckpt_file 'put the pretrained model (see pretrain_src) here' \ --eval_first # test diff --git a/map_nav_src/scripts/run_r4r.sh b/map_nav_src/scripts/run_r4r.sh new file mode 100644 index 0000000..ad77c19 --- /dev/null +++ b/map_nav_src/scripts/run_r4r.sh @@ -0,0 +1,66 @@ + +train_alg=dagger + +features=vitbase +ft_dim=768 +obj_features=vitbase +obj_ft_dim=768 + +ngpus=1 +seed=0 + +name=${train_alg}-${features} +name=${name}-seed.${seed} +name=${name}-init.aug.45k + +outdir=${DATA_ROOT}/R2R/exprs_map/finetune/${name} + +flag="--root_dir ${DATA_ROOT} + --dataset r4r + --output_dir ${outdir} + --world_size ${ngpus} + --seed ${seed} + --tokenizer bert + + --enc_full_graph + --graph_sprels + --fusion dynamic + + --expert_policy spl + --train_alg ${train_alg} + + --num_l_layers 9 + --num_x_layers 4 + --num_pano_layers 2 + + --max_action_len 15 + --max_instr_len 200 + + --batch_size 8 + --lr 1e-5 + --iters 200000 + --log_every 1000 + --optim adamW + + --features ${features} + --image_feat_size ${ft_dim} + --angle_feat_size 4 + + --ml_weight 0.2 + + --feat_dropout 0.4 + --dropout 0.5 + + --gamma 0." + +# train +CUDA_VISIBLE_DEVICES='0' python r2r/main_nav.py $flag \ + --tokenizer bert \ + --bert_ckpt_file 'put the pretrained model (see pretrain_src) here' \ + --eval_first + +# test +CUDA_VISIBLE_DEVICES='0' python r2r/main_nav.py $flag \ + --tokenizer bert \ + --resume_file ../datasets/R2R/trained_models/best_val_unseen \ + --test --submit \ No newline at end of file diff --git a/map_nav_src/scripts/run_reverie.sh b/map_nav_src/scripts/run_reverie.sh index 418127f..69edfb7 100644 --- a/map_nav_src/scripts/run_reverie.sh +++ b/map_nav_src/scripts/run_reverie.sh @@ -61,7 +61,7 @@ flag="--root_dir ${DATA_ROOT} # train CUDA_VISIBLE_DEVICES='0' python reverie/main_nav_obj.py $flag \ --tokenizer bert \ - --bert_ckpt_file '' \ + --bert_ckpt_file 'put the pretrained model (see pretrain_src) here' \ --eval_first # test diff --git a/map_nav_src/scripts/run_soon.sh b/map_nav_src/scripts/run_soon.sh index 330f595..3ebb048 100644 --- a/map_nav_src/scripts/run_soon.sh +++ b/map_nav_src/scripts/run_soon.sh @@ -1,3 +1,4 @@ +DATA_ROOT=../datasets train_alg=dagger @@ -60,7 +61,7 @@ flag="--root_dir ${DATA_ROOT} CUDA_VISIBLE_DEVICES='0' python soon/main.py $flag \ --tokenizer bert \ - --bert_ckpt_file '' \ + --bert_ckpt_file 'put the pretrained model (see pretrain_src) here' \ --eval_first # test diff --git a/pretrain_src/config/r2r_pretrain.json b/pretrain_src/config/r2r_pretrain.json index e9608a1..b6dc8de 100644 --- a/pretrain_src/config/r2r_pretrain.json +++ b/pretrain_src/config/r2r_pretrain.json @@ -4,8 +4,8 @@ "output_dir": "", "mrc_mask_prob": 0.15, "max_txt_len": 200, - "train_batch_size": 16, - "val_batch_size": 16, + "train_batch_size": 64, + "val_batch_size": 64, "gradient_accumulation_steps": 1, "learning_rate": 5e-05, "valid_steps": 2500, diff --git a/pretrain_src/config/r4r_pretrain.json b/pretrain_src/config/r4r_pretrain.json new file mode 100644 index 0000000..ef6eeab --- /dev/null +++ b/pretrain_src/config/r4r_pretrain.json @@ -0,0 +1,48 @@ +{ + "model_config": "", + "checkpoint": null, + "output_dir": "", + "mrc_mask_prob": 0.15, + "max_txt_len": 200, + "train_batch_size": 32, + "val_batch_size": 32, + "gradient_accumulation_steps": 1, + "learning_rate": 5e-05, + "valid_steps": 5000, + "log_steps": 1000, + "num_train_steps": 100000, + "optim": "adamw", + "betas": [ + 0.9, + 0.98 + ], + "dropout": 0.1, + "weight_decay": 0.01, + "grad_norm": 5.0, + "warmup_steps": 10000, + "seed": 0, + "fp16": false, + "n_workers": 1, + "pin_mem": true, + "init_pretrained": "lxmert", + + "train_datasets": { + "R4R": { + "name": "R4R", + "train_traj_files": ["../datasets/R4R/annotations/pretrain_map/R4R_train_enc.jsonl"], + "val_seen_traj_files": ["../datasets/R4R/annotations/pretrain_map/R4R_val_seen_enc.jsonl"], + "val_unseen_traj_files": ["../datasets/R4R/annotations/pretrain_map/R4R_val_unseen_sampled_enc.jsonl"], + "connectivity_dir": "../datasets/R2R/connectivity", + "img_ft_file": "../datasets/R2R/features/pth_vit_base_patch16_224_imagenet.hdf5", + "scanvp_cands_file": "../datasets/R2R/annotations/scanvp_candview_relangles.json", + "tasks": [ + "mlm", + "sap" + ], + "mix_ratio": [ + 1, + 1 + ] + } + } +} diff --git a/pretrain_src/config/reverie_obj_pretrain.json b/pretrain_src/config/reverie_obj_pretrain.json index 98bd3b3..6d2a179 100644 --- a/pretrain_src/config/reverie_obj_pretrain.json +++ b/pretrain_src/config/reverie_obj_pretrain.json @@ -7,8 +7,8 @@ "nearby_vp_steps": null, "max_objects": 20, "max_txt_len": 200, - "train_batch_size": 16, - "val_batch_size": 16, + "train_batch_size": 32, + "val_batch_size": 32, "gradient_accumulation_steps": 1, "learning_rate": 5e-05, "valid_steps": 4000, diff --git a/pretrain_src/config/soon_obj_pretrain.json b/pretrain_src/config/soon_obj_pretrain.json index eca5186..380296a 100644 --- a/pretrain_src/config/soon_obj_pretrain.json +++ b/pretrain_src/config/soon_obj_pretrain.json @@ -7,13 +7,13 @@ "nearby_vp_steps": null, "max_objects": 100, "max_txt_len": 200, - "train_batch_size": 8, - "val_batch_size": 8, + "train_batch_size": 32, + "val_batch_size": 32, "gradient_accumulation_steps": 1, "learning_rate": 5e-05, - "valid_steps": 1500, + "valid_steps": 2000, "log_steps": 1000, - "num_train_steps": 30000, + "num_train_steps": 40000, "optim": "adamw", "betas": [ 0.9, @@ -41,10 +41,12 @@ "scanvp_cands_file": "../datasets/R2R/annotations/scanvp_candview_relangles.json", "tasks": [ "mlm", + "mrc", "sap", "og" ], "mix_ratio": [ + 1, 1, 1, 1 diff --git a/pretrain_src/data/dataset.py b/pretrain_src/data/dataset.py index 64d712b..822d5b7 100644 --- a/pretrain_src/data/dataset.py +++ b/pretrain_src/data/dataset.py @@ -387,7 +387,7 @@ class R2RTextPathData(ReverieTextPathData): # local: for k, cand_vp in enumerate(traj_cand_vpids[-1]): if cand_vp == gt_next_vp: - local_act_labels = k + 1 # [stop] is 0 + local_act_label = k + 1 # [stop] is 0 break return global_act_label, local_act_label diff --git a/pretrain_src/run_r2r.sh b/pretrain_src/run_r2r.sh index ec826b1..4fb682b 100644 --- a/pretrain_src/run_r2r.sh +++ b/pretrain_src/run_r2r.sh @@ -1,10 +1,10 @@ NODE_RANK=0 -NUM_GPUS=4 +NUM_GPUS=1 outdir=../datasets/R2R/exprs_map/pretrain/cmt-vitbase-mlm.mrc.sap-init.lxmert-aug.speaker # train -CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch \ +CUDA_VISIBLE_DEVICES='0' python -m torch.distributed.launch \ --nproc_per_node=${NUM_GPUS} --node_rank $NODE_RANK \ train_r2r.py --world_size ${NUM_GPUS} \ --vlnbert cmt \ diff --git a/pretrain_src/run_r4r.sh b/pretrain_src/run_r4r.sh new file mode 100644 index 0000000..4328120 --- /dev/null +++ b/pretrain_src/run_r4r.sh @@ -0,0 +1,12 @@ +NODE_RANK=0 +NUM_GPUS=1 +outdir=../datasets/R4R/exprs_map/pretrain/cmt-vitbase-mlm.sap-init.lxmert + +# train +CUDA_VISIBLE_DEVICES='0' python -m torch.distributed.launch \ + --nproc_per_node=${NUM_GPUS} --node_rank $NODE_RANK \ + train_r4r.py --world_size ${NUM_GPUS} \ + --vlnbert cmt \ + --model_config config/r2r_model_config.json \ + --config config/r4r_pretrain.json \ + --output_dir $outdir diff --git a/pretrain_src/run_reverie.sh b/pretrain_src/run_reverie.sh index 2ea0a8e..a00d118 100644 --- a/pretrain_src/run_reverie.sh +++ b/pretrain_src/run_reverie.sh @@ -1,9 +1,9 @@ NODE_RANK=0 -NUM_GPUS=2 +NUM_GPUS=1 outdir=../datasets/REVERIE/exprs_map/pretrain/cmt-vitbase-mlm.mrc.sap.og-init.lxmert-aug.speaker # train -CUDA_VISIBLE_DEVICES='0,1' python -m torch.distributed.launch \ +CUDA_VISIBLE_DEVICES='0' python -m torch.distributed.launch \ --nproc_per_node=${NUM_GPUS} --node_rank $NODE_RANK \ train_reverie_obj.py --world_size ${NUM_GPUS} \ --vlnbert cmt \ diff --git a/pretrain_src/run_soon.sh b/pretrain_src/run_soon.sh index f668cf7..aedbb15 100644 --- a/pretrain_src/run_soon.sh +++ b/pretrain_src/run_soon.sh @@ -1,9 +1,9 @@ NODE_RANK=0 -NUM_GPUS=4 +NUM_GPUS=1 outdir=../datasets/SOON/exprs_map/pretrain/cmt-vitbase.butdobj-mlm.sap.og-init.lxmert # train -CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch \ +CUDA_VISIBLE_DEVICES='0' python -m torch.distributed.launch \ --nproc_per_node=${NUM_GPUS} --node_rank $NODE_RANK \ train_soon_obj.py --world_size ${NUM_GPUS} \ --vlnbert cmt \ diff --git a/pretrain_src/train_r4r.py b/pretrain_src/train_r4r.py new file mode 100644 index 0000000..d7f130e --- /dev/null +++ b/pretrain_src/train_r4r.py @@ -0,0 +1,460 @@ +import os +import sys +import json +import argparse +import time +from collections import defaultdict +from easydict import EasyDict +from tqdm import tqdm + +import torch +import torch.nn.functional as F +import torch.distributed as dist + +import torch.cuda.amp as amp # TODO + +from transformers import AutoTokenizer, PretrainedConfig +from transformers import AutoModel + +from utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file +from utils.save import ModelSaver, save_training_meta +from utils.misc import NoOp, set_dropout, set_random_seed, set_cuda, wrap_model +from utils.distributed import all_gather + +from optim import get_lr_sched +from optim.misc import build_optimizer + +from parser import load_parser, parse_with_config + +from data.loader import MetaLoader, PrefetchLoader, build_dataloader +from data.dataset import R2RTextPathData +from data.tasks import ( + MlmDataset, mlm_collate, + MrcDataset, mrc_collate, + SapDataset, sap_collate) + +from model.pretrain_cmt import GlocalTextPathCMTPreTraining + + +def create_dataloaders( + data_cfg, nav_db, tok, is_train: bool, device: torch.device, opts +): + dataloaders = {} + for k, task_name in enumerate(data_cfg.tasks): + if task_name == 'mlm': + task_dataset = MlmDataset(nav_db, tok) + task_collate_fn = mlm_collate + elif task_name == 'mrc': + task_dataset = MrcDataset(nav_db, tok, opts.mrc_mask_prob, end_vp_pos_ratio=0.2) + task_collate_fn = mrc_collate + elif task_name == 'sap': + task_dataset = SapDataset(nav_db, tok, end_vp_pos_ratio=0.2) + task_collate_fn = sap_collate + else: + raise ValueError(f'Undefined task {task}') + + LOGGER.info(f"{task_name}: {len(task_dataset)} samples loaded") + + task_loader, pre_epoch = build_dataloader( + task_name, task_dataset, task_collate_fn, is_train, opts + ) + + if is_train: + ratio = data_cfg.mix_ratio[k] + dataloaders[task_name] = (task_loader, ratio, pre_epoch) + else: + dataloaders[task_name] = PrefetchLoader(task_loader, device) + return dataloaders + + +def main(opts): + default_gpu, n_gpu, device = set_cuda(opts) + + if default_gpu: + LOGGER.info( + 'device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}'.format( + device, n_gpu, bool(opts.local_rank != -1), opts.fp16 + ) + ) + + seed = opts.seed + if opts.local_rank != -1: + seed += opts.rank + set_random_seed(seed) + + if default_gpu: + save_training_meta(opts) + TB_LOGGER.create(os.path.join(opts.output_dir, 'logs')) + pbar = tqdm(total=opts.num_train_steps) + model_saver = ModelSaver(os.path.join(opts.output_dir, 'ckpts')) + add_log_to_file(os.path.join(opts.output_dir, 'logs', 'log.txt')) + else: + LOGGER.disabled = True + pbar = NoOp() + model_saver = NoOp() + + # Model config + model_config = PretrainedConfig.from_json_file(opts.model_config) + model_config.pretrain_tasks = [] + for train_dataset_config in opts.train_datasets.values(): + model_config.pretrain_tasks.extend(train_dataset_config['tasks']) + model_config.pretrain_tasks = set(model_config.pretrain_tasks) + + tokenizer = AutoTokenizer.from_pretrained(model_config.lang_bert_name) + + # Prepare model + if opts.checkpoint: + checkpoint = torch.load(opts.checkpoint, map_location=lambda storage, loc: storage) + else: + checkpoint = {} + if opts.init_pretrained == 'bert': + tmp = AutoModel.from_pretrained(model_config.lang_bert_name) + for param_name, param in tmp.named_parameters(): + checkpoint[param_name] = param + if model_config.lang_bert_name == 'xlm-roberta-base': + # embeddings.token_type_embeddings.weight (1 -> 2, the second is for image embedding) + checkpoint['embeddings.token_type_embeddings.weight'] = torch.cat( + [checkpoint['embeddings.token_type_embeddings.weight']] * 2, 0 + ) + del tmp + elif opts.init_pretrained == 'lxmert': + tmp = torch.load( + '../datasets/pretrained/LXMERT/model_LXRT.pth', + map_location=lambda storage, loc: storage + ) + for param_name, param in tmp.items(): + param_name = param_name.replace('module.', '') + if 'bert.encoder.layer' in param_name: + param_name = param_name.replace('bert.encoder.layer', 'bert.lang_encoder.layer') + checkpoint[param_name] = param + elif 'bert.encoder.x_layers' in param_name: + param_name1 = param_name.replace('bert.encoder.x_layers', 'bert.local_encoder.encoder.x_layers') + param_name2 = param_name.replace('bert.encoder.x_layers', 'bert.global_encoder.encoder.x_layers') + checkpoint[param_name1] = checkpoint[param_name2] = param + elif 'cls.predictions' in param_name: + param_name = param_name.replace('cls.predictions', 'mlm_head.predictions') + checkpoint[param_name] = param + else: + checkpoint[param_name] = param + del tmp + + model_class = GlocalTextPathCMTPreTraining + + # update some training configs + model = model_class.from_pretrained( + pretrained_model_name_or_path=None, config=model_config, state_dict=checkpoint + ) + model.train() + set_dropout(model, opts.dropout) + model = wrap_model(model, device, opts.local_rank) + del checkpoint + + # load data training set + data_cfg = EasyDict(opts.train_datasets['R4R']) + train_nav_db = R2RTextPathData( + data_cfg.train_traj_files, data_cfg.img_ft_file, + data_cfg.scanvp_cands_file, data_cfg.connectivity_dir, + image_prob_size=model_config.image_prob_size, + image_feat_size=model_config.image_feat_size, + angle_feat_size=model_config.angle_feat_size, + max_txt_len=opts.max_txt_len, in_memory=True, + act_visited_node=True + ) + val_nav_db = R2RTextPathData( + data_cfg.val_seen_traj_files, data_cfg.img_ft_file, + data_cfg.scanvp_cands_file, data_cfg.connectivity_dir, + image_prob_size=model_config.image_prob_size, + image_feat_size=model_config.image_feat_size, + angle_feat_size=model_config.angle_feat_size, + max_txt_len=opts.max_txt_len, in_memory=True, + act_visited_node=True + ) + val2_nav_db = R2RTextPathData( + data_cfg.val_unseen_traj_files, data_cfg.img_ft_file, + data_cfg.scanvp_cands_file, data_cfg.connectivity_dir, + image_prob_size=model_config.image_prob_size, + image_feat_size=model_config.image_feat_size, + angle_feat_size=model_config.angle_feat_size, + max_txt_len=opts.max_txt_len, in_memory=True, + act_visited_node=True + ) + + # Build data loaders + train_dataloaders = create_dataloaders( + data_cfg, train_nav_db, tokenizer, True, device, opts + ) + val_dataloaders = create_dataloaders( + data_cfg, val_nav_db, tokenizer, False, device, opts + ) + val2_dataloaders = create_dataloaders( + data_cfg, val2_nav_db, tokenizer, False, device, opts + ) + meta_loader = MetaLoader( + train_dataloaders, + accum_steps=opts.gradient_accumulation_steps, + distributed=opts.local_rank != -1, + device=device + ) + meta_loader = PrefetchLoader(meta_loader, device) + + # Prepare optimizer + optimizer = build_optimizer(model, opts) + task2scaler = {t: i for i, t in enumerate(train_dataloaders.keys())} + + if opts.fp16: + grad_scaler = amp.GradScaler() + + global_step = 0 + LOGGER.info(f"***** Running training with {opts.world_size} GPUs *****") + LOGGER.info(" Batch size = %d", opts.train_batch_size if opts.local_rank == -1 else opts.train_batch_size * opts.world_size) + LOGGER.info(" Accumulate steps = %d", opts.gradient_accumulation_steps) + LOGGER.info(" Num steps = %d", opts.num_train_steps) + + # to compute training statistics + task2loss = {task: RunningMeter(f'loss/{task}') + for task in train_dataloaders.keys()} + + n_examples = defaultdict(int) + n_in_units = defaultdict(int) + n_loss_units = defaultdict(int) + grad_norm = 0 + + start_time = time.time() + # quick hack for amp delay_unscale bug + optimizer.zero_grad() + optimizer.step() + for step, (name, batch) in enumerate(meta_loader): + # forward pass + n_examples[name] += batch['txt_ids'].size(0) + n_in_units[name] += batch['txt_lens'].sum().item() + task = name.split('_')[0] + # print(name, task) + # for k, v in batch.items(): + # print(k, v.size()) + # continue + if opts.fp16: + with amp.autocast(): + loss = model(batch, task=task, compute_loss=True) + else: + loss = model(batch, task=task, compute_loss=True) + + n_loss_units[name] += loss.size(0) + loss = loss.mean() # loss is not normalized in model + + # backward pass + if args.gradient_accumulation_steps > 1: # average loss + loss = loss / args.gradient_accumulation_steps + + delay_unscale = (step+1) % opts.gradient_accumulation_steps != 0 + if opts.fp16: + grad_scaler.scale(loss).backward() + else: + loss.backward() + + task2loss[name](loss.item()) + + # optimizer update and logging + if (step + 1) % opts.gradient_accumulation_steps == 0: + global_step += 1 + + # learning rate scheduling + lr_this_step = get_lr_sched(global_step, opts) + for param_group in optimizer.param_groups: + param_group['lr'] = lr_this_step + TB_LOGGER.add_scalar('lr', lr_this_step, global_step) + + # log loss + # NOTE: not gathered across GPUs for efficiency + TB_LOGGER.log_scalar_dict({ll.name: ll.val + for ll in task2loss.values() + if ll.val is not None}) + TB_LOGGER.step() + + # update model params + if opts.grad_norm != -1: + if opts.fp16: + grad_scaler.unscale_(optimizer) + grad_norm = torch.nn.utils.clip_grad_norm_( + model.parameters(), opts.grad_norm + ) + # print(step, name, grad_norm) + # for k, v in model.named_parameters(): + # if v.grad is not None: + # v = torch.norm(v).data.item() + # print(k, v) + TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step) + if opts.fp16: + grad_scaler.step(optimizer) + grad_scaler.update() + else: + optimizer.step() + optimizer.zero_grad() + pbar.update(1) + + if global_step % opts.log_steps == 0: + # monitor training throughput + LOGGER.info(f'==============Step {global_step}===============') + for t in train_dataloaders.keys(): + tot_ex = n_examples[t] + ex_per_sec = int(tot_ex / (time.time() - start_time)) + tot_in = n_in_units[t] + in_per_sec = int(tot_in / (time.time() - start_time)) + tot_l = n_loss_units[t] + l_per_sec = int(tot_l / (time.time() - start_time)) + LOGGER.info(f'{t}: {tot_ex} examples trained at ' + f'{ex_per_sec} ex/s') + TB_LOGGER.add_scalar(f'perf/{t}_ex_per_s', ex_per_sec, + global_step) + TB_LOGGER.add_scalar(f'perf/{t}_in_per_s', in_per_sec, + global_step) + TB_LOGGER.add_scalar(f'perf/{t}_loss_per_s', l_per_sec, + global_step) + LOGGER.info('===============================================') + + if global_step % opts.valid_steps == 0: + LOGGER.info(f'------Step {global_step}: start validation seen------') + validate(model, val_dataloaders, setname='_seen') + LOGGER.info(f'------Step {global_step}: start validation unseen------') + validate(model, val2_dataloaders, setname='_unseen') + model_saver.save(model, global_step) + if global_step >= opts.num_train_steps: + break + if global_step % opts.valid_steps != 0: + LOGGER.info(f'------Step {global_step}: start validation seen------') + validate(model, val_dataloaders, setname='_seen') + LOGGER.info(f'------Step {global_step}: start validation unseen------') + validate(model, val2_dataloaders, setname='_unseen') + model_saver.save(model, global_step) + + +def validate(model, val_dataloaders, setname=''): + model.eval() + for task, loader in val_dataloaders.items(): + LOGGER.info(f"validate val{setname} on {task} task") + if task.startswith('mlm'): + val_log = validate_mlm(model, loader) + elif task.startswith('mrc'): + val_log = validate_mrc(model, loader) + elif task.startswith('sap'): + val_log = validate_sap(model, loader) + else: + raise ValueError(f'Undefined task {task}') + val_log = {f'val{setname}_{task}_{k}': v for k, v in val_log.items()} + TB_LOGGER.log_scalar_dict( + {f'valid{setname}_{task}/{k}': v for k, v in val_log.items()} + ) + model.train() + + +@torch.no_grad() +def validate_mlm(model, val_loader): + LOGGER.info("start running MLM validation...") + val_loss = 0 + n_correct = 0 + n_word = 0 + st = time.time() + for i, batch in enumerate(val_loader): + scores = model(batch, task='mlm', compute_loss=False) + labels = batch['txt_labels'] + labels = labels[labels != -1] + loss = F.cross_entropy(scores, labels, reduction='sum') + val_loss += loss.item() + n_correct += (scores.max(dim=-1)[1] == labels).sum().item() + n_word += labels.numel() + val_loss = sum(all_gather(val_loss)) + n_correct = sum(all_gather(n_correct)) + n_word = sum(all_gather(n_word)) + tot_time = time.time()-st + val_loss /= n_word + acc = n_correct / n_word + val_log = {'loss': val_loss, + 'acc': acc, + 'tok_per_s': n_word/tot_time} + LOGGER.info(f"validation finished in {int(tot_time)} seconds, " + f"acc: {acc*100:.2f}") + return val_log + +def compute_accuracy_for_soft_targets(out, labels): + outputs = out.max(dim=-1)[1] + labels = labels.max(dim=-1)[1] # argmax + n_correct = (outputs == labels).sum().item() + return n_correct + +@torch.no_grad() +def validate_mrc(model, val_loader): + LOGGER.info("start running MRC validation...") + val_loss = 0 + n_feat = 0 + st = time.time() + tot_score = 0 + for i, batch in enumerate(val_loader): + view_logits, view_targets, _, _ = model(batch, task='mrc', compute_loss=False) + view_logprobs = F.log_softmax(view_logits, dim=-1) + loss = F.kl_div(view_logprobs, view_targets, reduction='sum') + tot_score += compute_accuracy_for_soft_targets(view_logits, view_targets) + val_loss += loss.item() + n_feat += batch['vp_view_mrc_masks'].sum().item() + val_loss = sum(all_gather(val_loss)) + tot_score = sum(all_gather(tot_score)) + n_feat = sum(all_gather(n_feat)) + tot_time = time.time()-st + val_loss /= n_feat + val_acc = tot_score / n_feat + val_log = {'loss': val_loss, + 'acc': val_acc, + 'feat_per_s': n_feat/tot_time} + LOGGER.info(f"validation finished in {int(tot_time)} seconds, " + f"score: {val_acc*100:.2f}") + return val_log + +@torch.no_grad() +def validate_sap(model, val_loader): + LOGGER.info("start running SAP validation...") + val_gloss, val_lloss, val_floss = 0, 0, 0 + n_gcorrect, n_lcorrect, n_fcorrect = 0, 0, 0 + n_data = 0 + st = time.time() + for i, batch in enumerate(val_loader): + global_logits, local_logits, fused_logits, global_act_labels, local_act_labels = \ + model(batch, task='sap', compute_loss=False) + val_gloss += F.cross_entropy(global_logits, global_act_labels, reduction='sum').data.item() + val_lloss += F.cross_entropy(local_logits, local_act_labels, reduction='sum').data.item() + val_floss += F.cross_entropy(fused_logits, global_act_labels, reduction='sum').data.item() + n_gcorrect += torch.sum(torch.argmax(global_logits, 1) == global_act_labels).item() + n_lcorrect += torch.sum(torch.argmax(local_logits, 1) == local_act_labels).item() + n_fcorrect += torch.sum(torch.argmax(fused_logits, 1) == global_act_labels).item() + n_data += len(global_act_labels) + + n_data = sum(all_gather(n_data)) + val_gloss = sum(all_gather(val_gloss)) / n_data + val_lloss = sum(all_gather(val_lloss)) / n_data + val_floss = sum(all_gather(val_floss)) / n_data + gacc = sum(all_gather(n_gcorrect)) / n_data + lacc = sum(all_gather(n_lcorrect)) / n_data + facc = sum(all_gather(n_fcorrect)) / n_data + + tot_time = time.time()-st + val_log = {'gloss': val_gloss, 'lloss': val_lloss, 'floss': val_floss, + 'gacc': gacc, 'lacc': lacc, 'facc': facc, + 'tok_per_s': n_data/tot_time} + LOGGER.info(f"validation finished in {int(tot_time)} seconds, " + f"gacc: {gacc*100:.2f}, lacc: {lacc*100:.2f}, facc: {facc*100:.2f}") + return val_log + +def build_args(): + parser = load_parser() + + opts = parse_with_config(parser) + + if os.path.exists(opts.output_dir) and os.listdir(opts.output_dir): + LOGGER.warning( + "Output directory ({}) already exists and is not empty.".format( + opts.output_dir + ) + ) + + return opts + +if __name__ == '__main__': + args = build_args() + main(args)