commit 89214a7c44 (parent 7a8a5dee38)

    init

README.md (21 lines changed)

@@ -1,6 +1,15 @@
# Think Global, Act Local: Dual-scale Graph Transformer for Vision-and-Language Navigation

-This repository is the official implementation of [Think Global, Act Local: Dual-scale Graph Transformer for Vision-and-Language Navigation]().
+This repository is the official implementation of [Think Global, Act Local: Dual-scale Graph Transformer for Vision-and-Language Navigation](https://arxiv.org/abs/2202.11742).

Winner of the [ICCV 2021 Workshop Human Interaction for Robotic Navigation REVERIE & SOON Challenges](https://human-interaction4robotic-navigation.github.io/challenge.html).

Project webpage: [https://cshizhe.github.io/projects/vln_duet.html](https://cshizhe.github.io/projects/vln_duet.html).

Following language instructions to navigate in unseen environments is a challenging problem for autonomous embodied agents. The agent not only needs to ground language in visual scenes, but also has to explore the environment to reach its target. In this work, we propose a dual-scale graph transformer (DUET) for joint long-term action planning and fine-grained cross-modal understanding. We build a topological map on the fly to enable efficient exploration in the global action space. To balance the complexity of reasoning over a large action space against fine-grained language grounding, we dynamically combine a fine-scale encoding of local observations and a coarse-scale encoding of the global map via graph transformers. The proposed approach, DUET, significantly outperforms state-of-the-art methods on the goal-oriented vision-and-language navigation (VLN) benchmarks REVERIE and SOON. It also improves the success rate on the fine-grained VLN benchmark R2R.

![](files/teaser.png)

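As a rough sketch of the dual-scale idea (illustrative only: the shapes, names, and gating scheme below are assumptions, not DUET's actual architecture), a coarse branch scores every node of the topological map, a fine branch scores the candidate views visible from the current node, and a learned scalar gate fuses the two score vectors:

```python
import torch
import torch.nn as nn

class DualScaleFusionSketch(nn.Module):
    """Illustrative only: fuse coarse map-node scores with fine local-view scores."""

    def __init__(self, d):
        super().__init__()
        self.global_head = nn.Linear(d, 1)   # coarse score per map node
        self.local_head = nn.Linear(d, 1)    # fine score per visible candidate
        self.gate = nn.Sequential(nn.Linear(2 * d, 1), nn.Sigmoid())

    def forward(self, global_x, local_x, cand_node_idx):
        # global_x: (num_map_nodes, d); local_x: (num_candidates, d)
        # cand_node_idx: map-node index of each local candidate (LongTensor)
        g = self.global_head(global_x).squeeze(-1)
        l = torch.zeros_like(g)
        l[cand_node_idx] = self.local_head(local_x).squeeze(-1)
        # dynamic scalar gate computed from pooled features of both scales
        s = self.gate(torch.cat([global_x.mean(0), local_x.mean(0)], dim=-1))
        return s * g + (1 - s) * l  # fused navigation scores over map nodes
```

A dynamic fusion of this kind lets the agent back-track to a promising but distant map node when the local candidates all score poorly.
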
## Requirements

@@ -16,7 +25,7 @@ conda activate vlnduet
pip install -r requirements.txt
```

-3. Download data from [Dropbox](https://www.dropbox.com/s/7bijvxdw3rf451c/datasets.tar.gz?dl=0), including processed annotations, features and pretrained models. Put the data in the `datasets` directory.
+3. Download data from [Dropbox](https://www.dropbox.com/s/7bijvxdw3rf451c/datasets.tar.gz?dl=0), including processed annotations, features and pretrained models of the REVERIE, SOON, R2R and R4R datasets. Put the data in the `datasets` directory.

4. Download pretrained LXMERT
```
@@ -25,19 +34,17 @@ wget https://nlp.cs.unc.edu/data/model_LXRT.pth -P datasets/pretrained
```

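As a quick sanity check that the LXMERT weights downloaded above are intact, the checkpoint should load as a plain dict of parameter tensors (path taken from the wget command; the expected entry count is indicative, not guaranteed):

```python
import torch

# load on CPU; the file is a plain state dict, not a full model object
ckpt = torch.load('datasets/pretrained/model_LXRT.pth', map_location='cpu')
print(type(ckpt), len(ckpt))  # expect a dict with several hundred entries
```
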
## Pretraining

Combine behavior cloning and auxiliary proxy tasks in pretraining:
```pretrain
cd pretrain_src
-bash run_reverie.sh # (run_soon.sh, run_r2r.sh)
+bash run_reverie.sh # (run_soon.sh, run_r2r.sh, run_r4r.sh)
```

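Pretraining alternates between the behavior-cloning objective (single-step action prediction, SAP) and auxiliary proxy tasks such as masked language modeling (MLM), in proportion to the `mix_ratio` fields of the configs under `pretrain_src/config`. A minimal sketch of that task mixing, assuming a hypothetical `model(batch, task, compute_loss)` interface like the one `train_r4r.py` uses:

```python
import random

def mix_tasks(model, optimizer, loaders, mix_ratio, num_steps):
    """Sample a task per step, weighted by mix_ratio, and train on one batch of it."""
    names = [t for t, r in zip(loaders, mix_ratio) for _ in range(r)]
    iters = {t: iter(dl) for t, dl in loaders.items()}
    for _ in range(num_steps):
        task = random.choice(names)
        try:
            batch = next(iters[task])
        except StopIteration:            # restart an exhausted task loader
            iters[task] = iter(loaders[task])
            batch = next(iters[task])
        loss = model(batch, task=task, compute_loss=True).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
```
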
## Fine-tuning & Evaluation

-Combine behavior cloning and auxiliary proxy tasks in pretraining:
+Use the pseudo-interactive demonstrator to fine-tune the model:
```finetune
cd map_nav_src
bash scripts/run_reverie.sh # (run_soon.sh, run_r2r.sh)
```

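The pseudo-interactive demonstrator is a DAgger-style expert (`train_alg=dagger` in the run scripts): the agent rolls out its own policy, and at every visited state the expert labels the best next action, so supervision also covers states the expert alone would never reach. A hypothetical sketch of one such rollout step (names are illustrative; the repository's expert is `_teacher_action` / `_teacher_action_r4r`):

```python
import torch

def dagger_rollout_step(policy_logits, expert_action, criterion, sample=True):
    """One step of student-driven rollout with expert supervision."""
    # supervise the student with the expert's label at the *student's* state
    loss = criterion(policy_logits, expert_action)
    # but move according to the student's own (sampled) prediction
    if sample:
        probs = torch.softmax(policy_logits, dim=-1)
        a_t = torch.multinomial(probs, 1).squeeze(-1)
    else:
        a_t = policy_logits.argmax(dim=-1)
    return a_t, loss
```
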
## Examples

Video examples can be found [here](https://www.dropbox.com/sh/g8vqygz7fgerg9s/AAAZ3gd9WdReUgRezxLnb1f_a?dl=0).

BIN files/teaser.png (new file, 679 KiB; binary file not shown)
@@ -236,6 +236,54 @@ class GMapNavAgent(Seq2SeqAgent):

        return torch.from_numpy(a).cuda()

    def _teacher_action_r4r(
        self, obs, vpids, ended, visited_masks=None, imitation_learning=False, t=None, traj=None
    ):
        """R4R is not the shortest path. The goal location can be visited nodes."""
        a = np.zeros(len(obs), dtype=np.int64)
        for i, ob in enumerate(obs):
            if ended[i]:  # Just ignore this index
                a[i] = self.args.ignoreid
            else:
                if imitation_learning:
                    assert ob['viewpoint'] == ob['gt_path'][t]
                    if t == len(ob['gt_path']) - 1:
                        a[i] = 0  # stop
                    else:
                        goal_vp = ob['gt_path'][t + 1]
                        for j, vpid in enumerate(vpids[i]):
                            if goal_vp == vpid:
                                a[i] = j
                                break
                else:
                    if ob['viewpoint'] == ob['gt_path'][-1]:
                        a[i] = 0  # Stop if arrived
                    else:
                        scan = ob['scan']
                        cur_vp = ob['viewpoint']
                        min_idx, min_dist = self.args.ignoreid, float('inf')
                        for j, vpid in enumerate(vpids[i]):
                            if j > 0 and ((visited_masks is None) or (not visited_masks[i][j])):
                                if self.args.expert_policy == 'ndtw':
                                    dist = -cal_dtw(
                                        self.env.shortest_distances[scan],
                                        sum(traj[i]['path'], []) + self.env.shortest_paths[scan][ob['viewpoint']][vpid][1:],
                                        ob['gt_path'],
                                        threshold=3.0
                                    )['nDTW']
                                elif self.args.expert_policy == 'spl':
                                    # dist = min([self.env.shortest_distances[scan][vpid][end_vp] for end_vp in ob['gt_end_vps']])
                                    dist = self.env.shortest_distances[scan][vpid][ob['gt_path'][-1]] \
                                        + self.env.shortest_distances[scan][cur_vp][vpid]
                                if dist < min_dist:
                                    min_dist = dist
                                    min_idx = j
                        a[i] = min_idx
                        if min_idx == self.args.ignoreid:
                            print('scan %s: all vps are searched' % (scan))
        return torch.from_numpy(a).cuda()

    def make_equiv_action(self, a_t, gmaps, obs, traj=None):
        """
        Interface between Panoramic view and Egocentric view
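The `'ndtw'` expert above scores a candidate by the normalized Dynamic Time Warping (Ilharco et al., 2019) between the extended trajectory and the ground-truth path. For reference, a standalone sketch of nDTW, assuming a `dist(a, b)` shortest-path distance function in place of the repository's precomputed tables and `cal_dtw` helper:

```python
import numpy as np

def ndtw(pred_path, gt_path, dist, threshold=3.0):
    """Normalized Dynamic Time Warping between two viewpoint sequences.

    dist(a, b) is a hypothetical stand-in for env.shortest_distances lookups.
    """
    n, m = len(pred_path), len(gt_path)
    dtw = np.full((n + 1, m + 1), np.inf)
    dtw[0, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            cost = dist(pred_path[i - 1], gt_path[j - 1])
            dtw[i, j] = cost + min(dtw[i - 1, j], dtw[i, j - 1], dtw[i - 1, j - 1])
    # normalize by reference length and success threshold, squash into (0, 1]
    return np.exp(-dtw[n, m] / (threshold * m))
```
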
@@ -355,10 +403,22 @@

            if train_ml is not None:
                # Supervised training
-               nav_targets = self._teacher_action(
-                   obs, nav_vpids, ended,
-                   visited_masks=nav_inputs['gmap_visited_masks'] if self.args.fusion != 'local' else None
-               )
+               if self.args.dataset == 'r2r':
+                   # nav_targets = self._teacher_action(
+                   #     obs, nav_vpids, ended,
+                   #     visited_masks=nav_inputs['gmap_visited_masks'] if self.args.fusion != 'local' else None
+                   # )
+                   nav_targets = self._teacher_action_r4r(
+                       obs, nav_vpids, ended,
+                       visited_masks=nav_inputs['gmap_visited_masks'] if self.args.fusion != 'local' else None,
+                       imitation_learning=(self.feedback == 'teacher'), t=t, traj=traj
+                   )
+               elif self.args.dataset == 'r4r':
+                   nav_targets = self._teacher_action_r4r(
+                       obs, nav_vpids, ended,
+                       visited_masks=nav_inputs['gmap_visited_masks'] if self.args.fusion != 'local' else None,
+                       imitation_learning=(self.feedback == 'teacher'), t=t, traj=traj
+                   )
                # print(t, nav_logits, nav_targets)
                ml_loss += self.criterion(nav_logits, nav_targets)
                # print(t, 'ml_loss', ml_loss.item(), self.criterion(nav_logits, nav_targets).item())

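`ml_loss` accumulates cross-entropy against the expert's action at each step; episodes that already ended carry the `ignoreid` label and must not contribute. A minimal sketch of that masking, assuming the criterion is a cross-entropy built with `ignore_index=args.ignoreid` (the actual construction lives elsewhere in the agent):

```python
import torch
import torch.nn as nn

ignoreid = -100  # hypothetical value of args.ignoreid
criterion = nn.CrossEntropyLoss(ignore_index=ignoreid, reduction='sum')

nav_logits = torch.randn(3, 5)                # 3 episodes, 5 candidate actions
nav_targets = torch.tensor([2, ignoreid, 0])  # episode 1 has already ended
loss = criterion(nav_logits, nav_targets)     # only episodes 0 and 2 contribute
```
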
@@ -19,6 +19,10 @@ def load_instr_datasets(anno_dir, dataset, splits, tokenizer, is_test=True):
                if split == 'val_train_seen':
                    new_data = new_data[:50]

+               if not is_test:
+                   if dataset == 'r4r' and split == 'val_unseen':
+                       ridxs = np.random.permutation(len(new_data))[:200]
+                       new_data = [new_data[ridx] for ridx in ridxs]
        else:  # augmented data
            print('\nLoading augmented data %s for pretraining...' % os.path.basename(split))
            with open(split) as f:

@@ -59,8 +59,10 @@ def build_dataset(args, rank=0, is_test=False):

    # val_env_names = ['val_train_seen']
    val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']
+   if args.dataset == 'r4r' and (not args.test):
+       val_env_names[-1] = 'val_unseen_sampled'

-   if args.submit:
+   if args.submit and args.dataset != 'r4r':
        val_env_names.append('test')

    val_envs = {}

@@ -129,6 +131,8 @@ def train(args, train_env, val_envs, aug_env=None, rank=-1):
    )

    best_val = {'val_unseen': {"spl": 0., "sr": 0., "state": ""}}
+   if args.dataset == 'r4r':
+       best_val = {'val_unseen_sampled': {"spl": 0., "sr": 0., "state": ""}}

    for idx in range(start_iter, start_iter + args.iters, args.log_every):
        listner.logs = defaultdict(list)

@@ -6,7 +6,7 @@ def parse_args():
    parser = argparse.ArgumentParser(description="")

    parser.add_argument('--root_dir', type=str, default='../datasets')
-   parser.add_argument('--dataset', type=str, default='r2r', choices=['r2r'])
+   parser.add_argument('--dataset', type=str, default='r2r', choices=['r2r', 'r4r'])
    parser.add_argument('--output_dir', type=str, default='default', help='experiment id')
    parser.add_argument('--seed', type=int, default=0)

@@ -1,3 +1,4 @@
+DATA_ROOT=../datasets

train_alg=dagger

@@ -56,7 +57,7 @@ flag="--root_dir ${DATA_ROOT}
# train
CUDA_VISIBLE_DEVICES='0' python r2r/main_nav.py $flag \
      --tokenizer bert \
-     --bert_ckpt_file '' \
+     --bert_ckpt_file 'put the pretrained model (see pretrain_src) here' \
      --eval_first

# test

map_nav_src/scripts/run_r4r.sh (new file, 66 lines)

@@ -0,0 +1,66 @@
DATA_ROOT=../datasets  # not visible in the rendered diff; assumed from the parallel run_r2r.sh

train_alg=dagger

features=vitbase
ft_dim=768
obj_features=vitbase
obj_ft_dim=768

ngpus=1
seed=0

name=${train_alg}-${features}
name=${name}-seed.${seed}
name=${name}-init.aug.45k

outdir=${DATA_ROOT}/R2R/exprs_map/finetune/${name}

flag="--root_dir ${DATA_ROOT}
      --dataset r4r
      --output_dir ${outdir}
      --world_size ${ngpus}
      --seed ${seed}
      --tokenizer bert

      --enc_full_graph
      --graph_sprels
      --fusion dynamic

      --expert_policy spl
      --train_alg ${train_alg}

      --num_l_layers 9
      --num_x_layers 4
      --num_pano_layers 2

      --max_action_len 15
      --max_instr_len 200

      --batch_size 8
      --lr 1e-5
      --iters 200000
      --log_every 1000
      --optim adamW

      --features ${features}
      --image_feat_size ${ft_dim}
      --angle_feat_size 4

      --ml_weight 0.2

      --feat_dropout 0.4
      --dropout 0.5

      --gamma 0."

# train
CUDA_VISIBLE_DEVICES='0' python r2r/main_nav.py $flag \
      --tokenizer bert \
      --bert_ckpt_file 'put the pretrained model (see pretrain_src) here' \
      --eval_first

# test
CUDA_VISIBLE_DEVICES='0' python r2r/main_nav.py $flag \
      --tokenizer bert \
      --resume_file ../datasets/R2R/trained_models/best_val_unseen \
      --test --submit

@@ -61,7 +61,7 @@ flag="--root_dir ${DATA_ROOT}
# train
CUDA_VISIBLE_DEVICES='0' python reverie/main_nav_obj.py $flag \
      --tokenizer bert \
-     --bert_ckpt_file '' \
+     --bert_ckpt_file 'put the pretrained model (see pretrain_src) here' \
      --eval_first

# test

@@ -1,3 +1,4 @@
+DATA_ROOT=../datasets

train_alg=dagger

@@ -60,7 +61,7 @@ flag="--root_dir ${DATA_ROOT}

CUDA_VISIBLE_DEVICES='0' python soon/main.py $flag \
      --tokenizer bert \
-     --bert_ckpt_file '' \
+     --bert_ckpt_file 'put the pretrained model (see pretrain_src) here' \
      --eval_first

# test

@@ -4,8 +4,8 @@
    "output_dir": "",
    "mrc_mask_prob": 0.15,
    "max_txt_len": 200,
-   "train_batch_size": 16,
-   "val_batch_size": 16,
+   "train_batch_size": 64,
+   "val_batch_size": 64,
    "gradient_accumulation_steps": 1,
    "learning_rate": 5e-05,
    "valid_steps": 2500,

pretrain_src/config/r4r_pretrain.json (new file, 48 lines)

@@ -0,0 +1,48 @@
{
  "model_config": "",
  "checkpoint": null,
  "output_dir": "",
  "mrc_mask_prob": 0.15,
  "max_txt_len": 200,
  "train_batch_size": 32,
  "val_batch_size": 32,
  "gradient_accumulation_steps": 1,
  "learning_rate": 5e-05,
  "valid_steps": 5000,
  "log_steps": 1000,
  "num_train_steps": 100000,
  "optim": "adamw",
  "betas": [0.9, 0.98],
  "dropout": 0.1,
  "weight_decay": 0.01,
  "grad_norm": 5.0,
  "warmup_steps": 10000,
  "seed": 0,
  "fp16": false,
  "n_workers": 1,
  "pin_mem": true,
  "init_pretrained": "lxmert",

  "train_datasets": {
    "R4R": {
      "name": "R4R",
      "train_traj_files": ["../datasets/R4R/annotations/pretrain_map/R4R_train_enc.jsonl"],
      "val_seen_traj_files": ["../datasets/R4R/annotations/pretrain_map/R4R_val_seen_enc.jsonl"],
      "val_unseen_traj_files": ["../datasets/R4R/annotations/pretrain_map/R4R_val_unseen_sampled_enc.jsonl"],
      "connectivity_dir": "../datasets/R2R/connectivity",
      "img_ft_file": "../datasets/R2R/features/pth_vit_base_patch16_224_imagenet.hdf5",
      "scanvp_cands_file": "../datasets/R2R/annotations/scanvp_candview_relangles.json",
      "tasks": ["mlm", "sap"],
      "mix_ratio": [1, 1]
    }
  }
}

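train_r4r.py (added below) reads this file through `parse_with_config`, imported from the repository's `pretrain_src/parser.py`. A minimal sketch of that merge pattern, assuming the parser defines a `--config` flag; explicit command-line flags win over config values:

```python
import argparse
import json
import sys

def parse_with_config(parser: argparse.ArgumentParser) -> argparse.Namespace:
    """Merge a JSON config file into argparse results; CLI flags take precedence."""
    args = parser.parse_args()
    if args.config is not None:
        with open(args.config) as f:
            config_args = json.load(f)
        # keys the user set explicitly on the command line
        override_keys = {arg[2:].split('=')[0] for arg in sys.argv[1:]
                         if arg.startswith('--')}
        for k, v in config_args.items():
            if k not in override_keys:
                setattr(args, k, v)
    return args
```
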
@@ -7,8 +7,8 @@
    "nearby_vp_steps": null,
    "max_objects": 20,
    "max_txt_len": 200,
-   "train_batch_size": 16,
-   "val_batch_size": 16,
+   "train_batch_size": 32,
+   "val_batch_size": 32,
    "gradient_accumulation_steps": 1,
    "learning_rate": 5e-05,
    "valid_steps": 4000,

@@ -7,13 +7,13 @@
    "nearby_vp_steps": null,
    "max_objects": 100,
    "max_txt_len": 200,
-   "train_batch_size": 8,
-   "val_batch_size": 8,
+   "train_batch_size": 32,
+   "val_batch_size": 32,
    "gradient_accumulation_steps": 1,
    "learning_rate": 5e-05,
-   "valid_steps": 1500,
+   "valid_steps": 2000,
    "log_steps": 1000,
-   "num_train_steps": 30000,
+   "num_train_steps": 40000,
    "optim": "adamw",
    "betas": [
        0.9,

@@ -41,10 +41,12 @@
    "scanvp_cands_file": "../datasets/R2R/annotations/scanvp_candview_relangles.json",
    "tasks": [
        "mlm",
+       "mrc",
        "sap",
        "og"
    ],
    "mix_ratio": [
        1,
+       1,
        1,
        1

@@ -387,7 +387,7 @@ class R2RTextPathData(ReverieTextPathData):
        # local:
        for k, cand_vp in enumerate(traj_cand_vpids[-1]):
            if cand_vp == gt_next_vp:
-               local_act_labels = k + 1  # [stop] is 0
+               local_act_label = k + 1  # [stop] is 0
                break
        return global_act_label, local_act_label

@@ -1,10 +1,10 @@

NODE_RANK=0
-NUM_GPUS=4
+NUM_GPUS=1
outdir=../datasets/R2R/exprs_map/pretrain/cmt-vitbase-mlm.mrc.sap-init.lxmert-aug.speaker

# train
-CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch \
+CUDA_VISIBLE_DEVICES='0' python -m torch.distributed.launch \
    --nproc_per_node=${NUM_GPUS} --node_rank $NODE_RANK \
    train_r2r.py --world_size ${NUM_GPUS} \
    --vlnbert cmt \

pretrain_src/run_r4r.sh (new file, 12 lines)

@@ -0,0 +1,12 @@
NODE_RANK=0
NUM_GPUS=1
outdir=../datasets/R4R/exprs_map/pretrain/cmt-vitbase-mlm.sap-init.lxmert

# train
CUDA_VISIBLE_DEVICES='0' python -m torch.distributed.launch \
    --nproc_per_node=${NUM_GPUS} --node_rank $NODE_RANK \
    train_r4r.py --world_size ${NUM_GPUS} \
    --vlnbert cmt \
    --model_config config/r2r_model_config.json \
    --config config/r4r_pretrain.json \
    --output_dir $outdir

@@ -1,9 +1,9 @@
NODE_RANK=0
-NUM_GPUS=2
+NUM_GPUS=1
outdir=../datasets/REVERIE/exprs_map/pretrain/cmt-vitbase-mlm.mrc.sap.og-init.lxmert-aug.speaker

# train
-CUDA_VISIBLE_DEVICES='0,1' python -m torch.distributed.launch \
+CUDA_VISIBLE_DEVICES='0' python -m torch.distributed.launch \
    --nproc_per_node=${NUM_GPUS} --node_rank $NODE_RANK \
    train_reverie_obj.py --world_size ${NUM_GPUS} \
    --vlnbert cmt \

@@ -1,9 +1,9 @@
NODE_RANK=0
-NUM_GPUS=4
+NUM_GPUS=1
outdir=../datasets/SOON/exprs_map/pretrain/cmt-vitbase.butdobj-mlm.sap.og-init.lxmert

# train
-CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch \
+CUDA_VISIBLE_DEVICES='0' python -m torch.distributed.launch \
    --nproc_per_node=${NUM_GPUS} --node_rank $NODE_RANK \
    train_soon_obj.py --world_size ${NUM_GPUS} \
    --vlnbert cmt \

pretrain_src/train_r4r.py (new file, 460 lines)

@@ -0,0 +1,460 @@
import os
import sys
import json
import argparse
import time
from collections import defaultdict
from easydict import EasyDict
from tqdm import tqdm

import torch
import torch.nn.functional as F
import torch.distributed as dist

import torch.cuda.amp as amp  # TODO

from transformers import AutoTokenizer, PretrainedConfig
from transformers import AutoModel

from utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file
from utils.save import ModelSaver, save_training_meta
from utils.misc import NoOp, set_dropout, set_random_seed, set_cuda, wrap_model
from utils.distributed import all_gather

from optim import get_lr_sched
from optim.misc import build_optimizer

from parser import load_parser, parse_with_config

from data.loader import MetaLoader, PrefetchLoader, build_dataloader
from data.dataset import R2RTextPathData
from data.tasks import (
    MlmDataset, mlm_collate,
    MrcDataset, mrc_collate,
    SapDataset, sap_collate)

from model.pretrain_cmt import GlocalTextPathCMTPreTraining

def create_dataloaders(
    data_cfg, nav_db, tok, is_train: bool, device: torch.device, opts
):
    dataloaders = {}
    for k, task_name in enumerate(data_cfg.tasks):
        if task_name == 'mlm':
            task_dataset = MlmDataset(nav_db, tok)
            task_collate_fn = mlm_collate
        elif task_name == 'mrc':
            task_dataset = MrcDataset(nav_db, tok, opts.mrc_mask_prob, end_vp_pos_ratio=0.2)
            task_collate_fn = mrc_collate
        elif task_name == 'sap':
            task_dataset = SapDataset(nav_db, tok, end_vp_pos_ratio=0.2)
            task_collate_fn = sap_collate
        else:
            raise ValueError(f'Undefined task {task_name}')

        LOGGER.info(f"{task_name}: {len(task_dataset)} samples loaded")

        task_loader, pre_epoch = build_dataloader(
            task_name, task_dataset, task_collate_fn, is_train, opts
        )

        if is_train:
            ratio = data_cfg.mix_ratio[k]
            dataloaders[task_name] = (task_loader, ratio, pre_epoch)
        else:
            dataloaders[task_name] = PrefetchLoader(task_loader, device)
    return dataloaders

def main(opts):
    default_gpu, n_gpu, device = set_cuda(opts)

    if default_gpu:
        LOGGER.info(
            'device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}'.format(
                device, n_gpu, bool(opts.local_rank != -1), opts.fp16
            )
        )

    seed = opts.seed
    if opts.local_rank != -1:
        seed += opts.rank
    set_random_seed(seed)

    if default_gpu:
        save_training_meta(opts)
        TB_LOGGER.create(os.path.join(opts.output_dir, 'logs'))
        pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(os.path.join(opts.output_dir, 'ckpts'))
        add_log_to_file(os.path.join(opts.output_dir, 'logs', 'log.txt'))
    else:
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()

    # Model config
    model_config = PretrainedConfig.from_json_file(opts.model_config)
    model_config.pretrain_tasks = []
    for train_dataset_config in opts.train_datasets.values():
        model_config.pretrain_tasks.extend(train_dataset_config['tasks'])
    model_config.pretrain_tasks = set(model_config.pretrain_tasks)

    tokenizer = AutoTokenizer.from_pretrained(model_config.lang_bert_name)

    # Prepare model
    if opts.checkpoint:
        checkpoint = torch.load(opts.checkpoint, map_location=lambda storage, loc: storage)
    else:
        checkpoint = {}
        if opts.init_pretrained == 'bert':
            tmp = AutoModel.from_pretrained(model_config.lang_bert_name)
            for param_name, param in tmp.named_parameters():
                checkpoint[param_name] = param
            if model_config.lang_bert_name == 'xlm-roberta-base':
                # embeddings.token_type_embeddings.weight (1 -> 2, the second is for image embedding)
                checkpoint['embeddings.token_type_embeddings.weight'] = torch.cat(
                    [checkpoint['embeddings.token_type_embeddings.weight']] * 2, 0
                )
            del tmp
        elif opts.init_pretrained == 'lxmert':
            tmp = torch.load(
                '../datasets/pretrained/LXMERT/model_LXRT.pth',
                map_location=lambda storage, loc: storage
            )
            for param_name, param in tmp.items():
                param_name = param_name.replace('module.', '')
                if 'bert.encoder.layer' in param_name:
                    param_name = param_name.replace('bert.encoder.layer', 'bert.lang_encoder.layer')
                    checkpoint[param_name] = param
                elif 'bert.encoder.x_layers' in param_name:
                    param_name1 = param_name.replace('bert.encoder.x_layers', 'bert.local_encoder.encoder.x_layers')
                    param_name2 = param_name.replace('bert.encoder.x_layers', 'bert.global_encoder.encoder.x_layers')
                    checkpoint[param_name1] = checkpoint[param_name2] = param
                elif 'cls.predictions' in param_name:
                    param_name = param_name.replace('cls.predictions', 'mlm_head.predictions')
                    checkpoint[param_name] = param
                else:
                    checkpoint[param_name] = param
            del tmp

    model_class = GlocalTextPathCMTPreTraining

    # update some training configs
    model = model_class.from_pretrained(
        pretrained_model_name_or_path=None, config=model_config, state_dict=checkpoint
    )
    model.train()
    set_dropout(model, opts.dropout)
    model = wrap_model(model, device, opts.local_rank)
    del checkpoint

    # load data training set
    data_cfg = EasyDict(opts.train_datasets['R4R'])
    train_nav_db = R2RTextPathData(
        data_cfg.train_traj_files, data_cfg.img_ft_file,
        data_cfg.scanvp_cands_file, data_cfg.connectivity_dir,
        image_prob_size=model_config.image_prob_size,
        image_feat_size=model_config.image_feat_size,
        angle_feat_size=model_config.angle_feat_size,
        max_txt_len=opts.max_txt_len, in_memory=True,
        act_visited_node=True
    )
    val_nav_db = R2RTextPathData(
        data_cfg.val_seen_traj_files, data_cfg.img_ft_file,
        data_cfg.scanvp_cands_file, data_cfg.connectivity_dir,
        image_prob_size=model_config.image_prob_size,
        image_feat_size=model_config.image_feat_size,
        angle_feat_size=model_config.angle_feat_size,
        max_txt_len=opts.max_txt_len, in_memory=True,
        act_visited_node=True
    )
    val2_nav_db = R2RTextPathData(
        data_cfg.val_unseen_traj_files, data_cfg.img_ft_file,
        data_cfg.scanvp_cands_file, data_cfg.connectivity_dir,
        image_prob_size=model_config.image_prob_size,
        image_feat_size=model_config.image_feat_size,
        angle_feat_size=model_config.angle_feat_size,
        max_txt_len=opts.max_txt_len, in_memory=True,
        act_visited_node=True
    )

    # Build data loaders
    train_dataloaders = create_dataloaders(
        data_cfg, train_nav_db, tokenizer, True, device, opts
    )
    val_dataloaders = create_dataloaders(
        data_cfg, val_nav_db, tokenizer, False, device, opts
    )
    val2_dataloaders = create_dataloaders(
        data_cfg, val2_nav_db, tokenizer, False, device, opts
    )
    meta_loader = MetaLoader(
        train_dataloaders,
        accum_steps=opts.gradient_accumulation_steps,
        distributed=opts.local_rank != -1,
        device=device
    )
    meta_loader = PrefetchLoader(meta_loader, device)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
    task2scaler = {t: i for i, t in enumerate(train_dataloaders.keys())}

    if opts.fp16:
        grad_scaler = amp.GradScaler()

    global_step = 0
    LOGGER.info(f"***** Running training with {opts.world_size} GPUs *****")
    LOGGER.info("  Batch size = %d", opts.train_batch_size if opts.local_rank == -1 else opts.train_batch_size * opts.world_size)
    LOGGER.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    # to compute training statistics
    task2loss = {task: RunningMeter(f'loss/{task}')
                 for task in train_dataloaders.keys()}

    n_examples = defaultdict(int)
    n_in_units = defaultdict(int)
    n_loss_units = defaultdict(int)
    grad_norm = 0

    start_time = time.time()
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    for step, (name, batch) in enumerate(meta_loader):
        # forward pass
        n_examples[name] += batch['txt_ids'].size(0)
        n_in_units[name] += batch['txt_lens'].sum().item()
        task = name.split('_')[0]
        # print(name, task)
        # for k, v in batch.items():
        #     print(k, v.size())
        # continue
        if opts.fp16:
            with amp.autocast():
                loss = model(batch, task=task, compute_loss=True)
        else:
            loss = model(batch, task=task, compute_loss=True)

        n_loss_units[name] += loss.size(0)
        loss = loss.mean()  # loss is not normalized in model

        # backward pass
        if opts.gradient_accumulation_steps > 1:  # average loss
            loss = loss / opts.gradient_accumulation_steps

        delay_unscale = (step + 1) % opts.gradient_accumulation_steps != 0
        if opts.fp16:
            grad_scaler.scale(loss).backward()
        else:
            loss.backward()

        task2loss[name](loss.item())

        # optimizer update and logging
        if (step + 1) % opts.gradient_accumulation_steps == 0:
            global_step += 1

            # learning rate scheduling
            lr_this_step = get_lr_sched(global_step, opts)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_this_step
            TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

            # log loss
            # NOTE: not gathered across GPUs for efficiency
            TB_LOGGER.log_scalar_dict({ll.name: ll.val
                                       for ll in task2loss.values()
                                       if ll.val is not None})
            TB_LOGGER.step()

            # update model params
            if opts.grad_norm != -1:
                if opts.fp16:
                    grad_scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), opts.grad_norm
                )
                # print(step, name, grad_norm)
                # for k, v in model.named_parameters():
                #     if v.grad is not None:
                #         v = torch.norm(v).data.item()
                #         print(k, v)
                TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
            if opts.fp16:
                grad_scaler.step(optimizer)
                grad_scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()
            pbar.update(1)

            if global_step % opts.log_steps == 0:
                # monitor training throughput
                LOGGER.info(f'==============Step {global_step}===============')
                for t in train_dataloaders.keys():
                    tot_ex = n_examples[t]
                    ex_per_sec = int(tot_ex / (time.time() - start_time))
                    tot_in = n_in_units[t]
                    in_per_sec = int(tot_in / (time.time() - start_time))
                    tot_l = n_loss_units[t]
                    l_per_sec = int(tot_l / (time.time() - start_time))
                    LOGGER.info(f'{t}: {tot_ex} examples trained at '
                                f'{ex_per_sec} ex/s')
                    TB_LOGGER.add_scalar(f'perf/{t}_ex_per_s', ex_per_sec,
                                         global_step)
                    TB_LOGGER.add_scalar(f'perf/{t}_in_per_s', in_per_sec,
                                         global_step)
                    TB_LOGGER.add_scalar(f'perf/{t}_loss_per_s', l_per_sec,
                                         global_step)
                LOGGER.info('===============================================')

            if global_step % opts.valid_steps == 0:
                LOGGER.info(f'------Step {global_step}: start validation seen------')
                validate(model, val_dataloaders, setname='_seen')
                LOGGER.info(f'------Step {global_step}: start validation unseen------')
                validate(model, val2_dataloaders, setname='_unseen')
                model_saver.save(model, global_step)
            if global_step >= opts.num_train_steps:
                break
    if global_step % opts.valid_steps != 0:
        LOGGER.info(f'------Step {global_step}: start validation seen------')
        validate(model, val_dataloaders, setname='_seen')
        LOGGER.info(f'------Step {global_step}: start validation unseen------')
        validate(model, val2_dataloaders, setname='_unseen')
        model_saver.save(model, global_step)

def validate(model, val_dataloaders, setname=''):
    model.eval()
    for task, loader in val_dataloaders.items():
        LOGGER.info(f"validate val{setname} on {task} task")
        if task.startswith('mlm'):
            val_log = validate_mlm(model, loader)
        elif task.startswith('mrc'):
            val_log = validate_mrc(model, loader)
        elif task.startswith('sap'):
            val_log = validate_sap(model, loader)
        else:
            raise ValueError(f'Undefined task {task}')
        val_log = {f'val{setname}_{task}_{k}': v for k, v in val_log.items()}
        TB_LOGGER.log_scalar_dict(
            {f'valid{setname}_{task}/{k}': v for k, v in val_log.items()}
        )
    model.train()

@torch.no_grad()
def validate_mlm(model, val_loader):
    LOGGER.info("start running MLM validation...")
    val_loss = 0
    n_correct = 0
    n_word = 0
    st = time.time()
    for i, batch in enumerate(val_loader):
        scores = model(batch, task='mlm', compute_loss=False)
        labels = batch['txt_labels']
        labels = labels[labels != -1]
        loss = F.cross_entropy(scores, labels, reduction='sum')
        val_loss += loss.item()
        n_correct += (scores.max(dim=-1)[1] == labels).sum().item()
        n_word += labels.numel()
    val_loss = sum(all_gather(val_loss))
    n_correct = sum(all_gather(n_correct))
    n_word = sum(all_gather(n_word))
    tot_time = time.time() - st
    val_loss /= n_word
    acc = n_correct / n_word
    val_log = {'loss': val_loss,
               'acc': acc,
               'tok_per_s': n_word / tot_time}
    LOGGER.info(f"validation finished in {int(tot_time)} seconds, "
                f"acc: {acc*100:.2f}")
    return val_log

def compute_accuracy_for_soft_targets(out, labels):
    outputs = out.max(dim=-1)[1]
    labels = labels.max(dim=-1)[1]  # argmax
    n_correct = (outputs == labels).sum().item()
    return n_correct

@torch.no_grad()
def validate_mrc(model, val_loader):
    LOGGER.info("start running MRC validation...")
    val_loss = 0
    n_feat = 0
    st = time.time()
    tot_score = 0
    for i, batch in enumerate(val_loader):
        view_logits, view_targets, _, _ = model(batch, task='mrc', compute_loss=False)
        view_logprobs = F.log_softmax(view_logits, dim=-1)
        loss = F.kl_div(view_logprobs, view_targets, reduction='sum')
        tot_score += compute_accuracy_for_soft_targets(view_logits, view_targets)
        val_loss += loss.item()
        n_feat += batch['vp_view_mrc_masks'].sum().item()
    val_loss = sum(all_gather(val_loss))
    tot_score = sum(all_gather(tot_score))
    n_feat = sum(all_gather(n_feat))
    tot_time = time.time() - st
    val_loss /= n_feat
    val_acc = tot_score / n_feat
    val_log = {'loss': val_loss,
               'acc': val_acc,
               'feat_per_s': n_feat / tot_time}
    LOGGER.info(f"validation finished in {int(tot_time)} seconds, "
                f"score: {val_acc*100:.2f}")
    return val_log

@torch.no_grad()
def validate_sap(model, val_loader):
    LOGGER.info("start running SAP validation...")
    val_gloss, val_lloss, val_floss = 0, 0, 0
    n_gcorrect, n_lcorrect, n_fcorrect = 0, 0, 0
    n_data = 0
    st = time.time()
    for i, batch in enumerate(val_loader):
        global_logits, local_logits, fused_logits, global_act_labels, local_act_labels = \
            model(batch, task='sap', compute_loss=False)
        val_gloss += F.cross_entropy(global_logits, global_act_labels, reduction='sum').data.item()
        val_lloss += F.cross_entropy(local_logits, local_act_labels, reduction='sum').data.item()
        val_floss += F.cross_entropy(fused_logits, global_act_labels, reduction='sum').data.item()
        n_gcorrect += torch.sum(torch.argmax(global_logits, 1) == global_act_labels).item()
        n_lcorrect += torch.sum(torch.argmax(local_logits, 1) == local_act_labels).item()
        n_fcorrect += torch.sum(torch.argmax(fused_logits, 1) == global_act_labels).item()
        n_data += len(global_act_labels)

    n_data = sum(all_gather(n_data))
    val_gloss = sum(all_gather(val_gloss)) / n_data
    val_lloss = sum(all_gather(val_lloss)) / n_data
    val_floss = sum(all_gather(val_floss)) / n_data
    gacc = sum(all_gather(n_gcorrect)) / n_data
    lacc = sum(all_gather(n_lcorrect)) / n_data
    facc = sum(all_gather(n_fcorrect)) / n_data

    tot_time = time.time() - st
    val_log = {'gloss': val_gloss, 'lloss': val_lloss, 'floss': val_floss,
               'gacc': gacc, 'lacc': lacc, 'facc': facc,
               'tok_per_s': n_data / tot_time}
    LOGGER.info(f"validation finished in {int(tot_time)} seconds, "
                f"gacc: {gacc*100:.2f}, lacc: {lacc*100:.2f}, facc: {facc*100:.2f}")
    return val_log

def build_args():
    parser = load_parser()

    opts = parse_with_config(parser)

    if os.path.exists(opts.output_dir) and os.listdir(opts.output_dir):
        LOGGER.warning(
            "Output directory ({}) already exists and is not empty.".format(
                opts.output_dir
            )
        )

    return opts


if __name__ == '__main__':
    args = build_args()
    main(args)