Shizhe Chen 2022-03-26 20:56:29 +01:00
parent 7a8a5dee38
commit 89214a7c44
20 changed files with 696 additions and 31 deletions


@@ -1,6 +1,15 @@
# Think Global, Act Local: Dual-scale Graph Transformer for Vision-and-Language Navigation
This repository is the official implementation of [Think Global, Act Local: Dual-scale GraphTransformer for Vision-and-Language Navigatio]().
This repository is the official implementation of [Think Global, Act Local: Dual-scale Graph Transformer for Vision-and-Language Navigation](https://arxiv.org/abs/2202.11742).
Winner of the [ICCV 2021 Workshop Human Interaction for Robotic Navigation REVERIE & SOON Challenges](https://human-interaction4robotic-navigation.github.io/challenge.html).
Project webpage: [https://cshizhe.github.io/projects/vln_duet.html](https://cshizhe.github.io/projects/vln_duet.html).
Following language instructions to navigate in unseen environments is a challenging problem for autonomous embodied agents. The agent not only needs to ground language in visual scenes, but also to explore the environment to reach its target. In this work, we propose a dual-scale graph transformer (DUET) for joint long-term action planning and fine-grained cross-modal understanding. We build a topological map on the fly to enable efficient exploration in the global action space. To balance the complexity of reasoning over the large action space with fine-grained language grounding, we dynamically combine a fine-scale encoding of local observations and a coarse-scale encoding of the global map via graph transformers. The proposed approach, DUET, significantly outperforms state-of-the-art methods on the goal-oriented vision-and-language navigation (VLN) benchmarks REVERIE and SOON. It also improves the success rate on the fine-grained VLN benchmark R2R.
![framework](files/teaser.png)
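The dynamic fusion of the two scales can be pictured with a minimal sketch (illustrative PyTorch only; the module and tensor names below are placeholders, not the released model code): the coarse-scale branch scores every node of the topological map, the fine-scale branch scores the current panoramic candidates, and a learned gate predicted from the two state vectors weights the resulting action scores.

```python
import torch
import torch.nn as nn

class DualScaleFusionSketch(nn.Module):
    """Illustrative sketch of dynamic dual-scale fusion (placeholder names, not the repo code)."""

    def __init__(self, hidden_size: int = 768):
        super().__init__()
        # scalar gate predicted from the concatenated coarse- and fine-scale states
        self.gate = nn.Sequential(nn.Linear(hidden_size * 2, 1), nn.Sigmoid())

    def forward(self, global_state, local_state, global_logits, local_logits_on_map):
        # global_logits: scores over all map nodes (+ stop) from the coarse-scale encoder
        # local_logits_on_map: fine-scale candidate scores projected onto the same map nodes,
        # so both branches share one action space before fusion
        sigma = self.gate(torch.cat([global_state, local_state], dim=-1))  # (batch, 1)
        return sigma * global_logits + (1.0 - sigma) * local_logits_on_map
```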
## Requirements
@@ -16,7 +25,7 @@ conda activate vlnduet
pip install -r requirements.txt
```
3. Download data from [Dropbox](https://www.dropbox.com/s/7bijvxdw3rf451c/datasets.tar.gz?dl=0), including processed annotations, features and pretrained models. Put the data in `datasets' directory.
3. Download data from [Dropbox](https://www.dropbox.com/s/7bijvxdw3rf451c/datasets.tar.gz?dl=0), including processed annotations, features and pretrained models for the REVERIE, SOON, R2R and R4R datasets. Put the data in the `datasets` directory.
4. Download the pretrained LXMERT model
```
@@ -25,19 +34,17 @@ wget https://nlp.cs.unc.edu/data/model_LXRT.pth -P datasets/pretrained
```
## Pretraining
Combine behavior cloning and auxiliary proxy tasks in pretraining:
```pretrain
cd pretrain_src
bash run_reverie.sh # (run_soon.sh, run_r2r.sh)
bash run_reverie.sh # (run_soon.sh, run_r2r.sh, run_r4r.sh)
```
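Each pretraining config under `pretrain_src/config` lists the proxy tasks (`mlm`, `sap`, plus `mrc`/`og` for some datasets) together with a `mix_ratio`; conceptually, every optimization step draws one task with probability proportional to its ratio and backpropagates that task's loss. A minimal sketch of this sampling, with hypothetical loss callables standing in for the real task heads:

```python
import random

def sample_pretrain_task(task_losses, mix_ratio, batches):
    """Draw one proxy task per step, weighted by mix_ratio, and return its loss.

    task_losses: dict task_name -> callable(batch) returning a scalar loss
    mix_ratio:   weights aligned with the task order in the config (e.g. [1, 1])
    batches:     dict task_name -> next batch from that task's dataloader
    """
    task = random.choices(list(task_losses), weights=mix_ratio, k=1)[0]
    return task, task_losses[task](batches[task])
```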
## Fine-tuning & Evaluation
Combine behavior cloning and auxiliary proxy tasks in pretraining:
Use the pseudo interactive demonstrator to fine-tune the model:
```finetune
cd map_nav_src
bash scripts/run_reverie.sh # (run_soon.sh, run_r2r.sh)
```
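With `--train_alg dagger`, fine-tuning interleaves teacher forcing with rollouts sampled from the agent itself, while the pseudo interactive demonstrator recomputes the expert action from whatever node the agent has actually reached. A minimal sketch of one such step (hypothetical helper, not the `GMapNavAgent` implementation):

```python
import torch
import torch.nn.functional as F

def pseudo_demonstrator_step(policy_logits, expert_action, teacher_forcing=False):
    """policy_logits: (batch, num_candidates) scores over map nodes (+ stop);
    expert_action:  (batch,) expert-chosen candidate at the agent's current node."""
    il_loss = F.cross_entropy(policy_logits, expert_action)  # supervise with the expert
    if teacher_forcing:
        action = expert_action                                # follow the demonstrator
    else:
        action = torch.distributions.Categorical(logits=policy_logits).sample()
    return il_loss, action
```

In the scripts below, `--ml_weight` controls how strongly this imitation loss contributes to the overall training objective.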
## Examples
Video examples can be found [here](https://www.dropbox.com/sh/g8vqygz7fgerg9s/AAAZ3gd9WdReUgRezxLnb1f_a?dl=0).

files/teaser.png Normal file (binary image, 679 KiB, not shown)


@@ -236,6 +236,54 @@ class GMapNavAgent(Seq2SeqAgent):
return torch.from_numpy(a).cuda()
def _teacher_action_r4r(
self, obs, vpids, ended, visited_masks=None, imitation_learning=False, t=None, traj=None
):
"""R4R is not the shortest path. The goal location can be visited nodes.
"""
a = np.zeros(len(obs), dtype=np.int64)
for i, ob in enumerate(obs):
if ended[i]: # Just ignore this index
a[i] = self.args.ignoreid
else:
if imitation_learning:
assert ob['viewpoint'] == ob['gt_path'][t]
if t == len(ob['gt_path']) - 1:
a[i] = 0 # stop
else:
goal_vp = ob['gt_path'][t + 1]
for j, vpid in enumerate(vpids[i]):
if goal_vp == vpid:
a[i] = j
break
else:
if ob['viewpoint'] == ob['gt_path'][-1]:
a[i] = 0 # Stop if arrived
else:
scan = ob['scan']
cur_vp = ob['viewpoint']
min_idx, min_dist = self.args.ignoreid, float('inf')
for j, vpid in enumerate(vpids[i]):
if j > 0 and ((visited_masks is None) or (not visited_masks[i][j])):
if self.args.expert_policy == 'ndtw':
dist = - cal_dtw(
self.env.shortest_distances[scan],
sum(traj[i]['path'], []) + self.env.shortest_paths[scan][ob['viewpoint']][vpid][1:],
ob['gt_path'],
threshold=3.0
)['nDTW']
elif self.args.expert_policy == 'spl':
# dist = min([self.env.shortest_distances[scan][vpid][end_vp] for end_vp in ob['gt_end_vps']])
dist = self.env.shortest_distances[scan][vpid][ob['gt_path'][-1]] \
+ self.env.shortest_distances[scan][cur_vp][vpid]
if dist < min_dist:
min_dist = dist
min_idx = j
a[i] = min_idx
if min_idx == self.args.ignoreid:
print('scan %s: all vps are searched' % (scan))
return torch.from_numpy(a).cuda()
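# Note on the two expert policies used above:
# - 'spl':  each reachable candidate j is scored by
#               shortest_dist(cur_vp, cand_j) + shortest_dist(cand_j, gt_path[-1]),
#           i.e. the teacher picks the neighbour with the smallest total remaining travel;
# - 'ndtw': the trajectory so far is extended with the shortest path to the candidate and
#           scored by nDTW against the annotated gt_path (negated so that min == best).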
def make_equiv_action(self, a_t, gmaps, obs, traj=None):
"""
Interface between Panoramic view and Egocentric view
@@ -355,10 +403,22 @@
if train_ml is not None:
# Supervised training
nav_targets = self._teacher_action(
obs, nav_vpids, ended,
visited_masks=nav_inputs['gmap_visited_masks'] if self.args.fusion != 'local' else None
)
if self.args.dataset == 'r2r':
# nav_targets = self._teacher_action(
# obs, nav_vpids, ended,
# visited_masks=nav_inputs['gmap_visited_masks'] if self.args.fusion != 'local' else None
# )
nav_targets = self._teacher_action_r4r(
obs, nav_vpids, ended,
visited_masks=nav_inputs['gmap_visited_masks'] if self.args.fusion != 'local' else None,
imitation_learning=(self.feedback=='teacher'), t=t, traj=traj
)
elif self.args.dataset == 'r4r':
nav_targets = self._teacher_action_r4r(
obs, nav_vpids, ended,
visited_masks=nav_inputs['gmap_visited_masks'] if self.args.fusion != 'local' else None,
imitation_learning=(self.feedback=='teacher'), t=t, traj=traj
)
# print(t, nav_logits, nav_targets)
ml_loss += self.criterion(nav_logits, nav_targets)
# print(t, 'ml_loss', ml_loss.item(), self.criterion(nav_logits, nav_targets).item())


@@ -19,6 +19,10 @@ def load_instr_datasets(anno_dir, dataset, splits, tokenizer, is_test=True):
if split == 'val_train_seen':
new_data = new_data[:50]
if not is_test:
if dataset == 'r4r' and split == 'val_unseen':
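# R4R's val_unseen split is large, so during training it is evaluated on a random
# subset of 200 episodes; the full split is only used when is_test is True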
ridxs = np.random.permutation(len(new_data))[:200]
new_data = [new_data[ridx] for ridx in ridxs]
else: # augmented data
print('\nLoading augmented data %s for pretraining...' % os.path.basename(split))
with open(split) as f:


@@ -59,8 +59,10 @@ def build_dataset(args, rank=0, is_test=False):
# val_env_names = ['val_train_seen']
val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']
if args.dataset == 'r4r' and (not args.test):
val_env_names[-1] = 'val_unseen_sampled'
if args.submit:
if args.submit and args.dataset != 'r4r':
val_env_names.append('test')
val_envs = {}
@@ -129,6 +131,8 @@
)
best_val = {'val_unseen': {"spl": 0., "sr": 0., "state":""}}
if args.dataset == 'r4r':
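# for R4R, model selection tracks the sampled unseen split built above ('val_unseen_sampled')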
best_val = {'val_unseen_sampled': {"spl": 0., "sr": 0., "state":""}}
for idx in range(start_iter, start_iter+args.iters, args.log_every):
listner.logs = defaultdict(list)


@@ -6,7 +6,7 @@ def parse_args():
parser = argparse.ArgumentParser(description="")
parser.add_argument('--root_dir', type=str, default='../datasets')
parser.add_argument('--dataset', type=str, default='r2r', choices=['r2r'])
parser.add_argument('--dataset', type=str, default='r2r', choices=['r2r', 'r4r'])
parser.add_argument('--output_dir', type=str, default='default', help='experiment id')
parser.add_argument('--seed', type=int, default=0)


@@ -1,3 +1,4 @@
DATA_ROOT=../datasets
train_alg=dagger
@@ -56,7 +57,7 @@ flag="--root_dir ${DATA_ROOT}
# train
CUDA_VISIBLE_DEVICES='0' python r2r/main_nav.py $flag \
--tokenizer bert \
--bert_ckpt_file '' \
--bert_ckpt_file 'put the pretrained model (see pretrain_src) here' \
--eval_first
# test


@@ -0,0 +1,66 @@
train_alg=dagger
features=vitbase
ft_dim=768
obj_features=vitbase
obj_ft_dim=768
ngpus=1
seed=0
name=${train_alg}-${features}
name=${name}-seed.${seed}
name=${name}-init.aug.45k
outdir=${DATA_ROOT}/R2R/exprs_map/finetune/${name}
flag="--root_dir ${DATA_ROOT}
--dataset r4r
--output_dir ${outdir}
--world_size ${ngpus}
--seed ${seed}
--tokenizer bert
--enc_full_graph
--graph_sprels
--fusion dynamic
--expert_policy spl
--train_alg ${train_alg}
--num_l_layers 9
--num_x_layers 4
--num_pano_layers 2
--max_action_len 15
--max_instr_len 200
--batch_size 8
--lr 1e-5
--iters 200000
--log_every 1000
--optim adamW
--features ${features}
--image_feat_size ${ft_dim}
--angle_feat_size 4
--ml_weight 0.2
--feat_dropout 0.4
--dropout 0.5
--gamma 0."
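# Notes on two R4R-specific choices above: --expert_policy spl makes the pseudo
# demonstrator pick, at every step, the neighbour minimising the remaining travel to the
# goal, and --ml_weight 0.2 scales the imitation (teacher) loss within the overall objective.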
# train
CUDA_VISIBLE_DEVICES='0' python r2r/main_nav.py $flag \
--tokenizer bert \
--bert_ckpt_file 'put the pretrained model (see pretrain_src) here' \
--eval_first
# test
CUDA_VISIBLE_DEVICES='0' python r2r/main_nav.py $flag \
--tokenizer bert \
--resume_file ../datasets/R2R/trained_models/best_val_unseen \
--test --submit


@@ -61,7 +61,7 @@ flag="--root_dir ${DATA_ROOT}
# train
CUDA_VISIBLE_DEVICES='0' python reverie/main_nav_obj.py $flag \
--tokenizer bert \
--bert_ckpt_file '' \
--bert_ckpt_file 'put the pretrained model (see pretrain_src) here' \
--eval_first
# test


@@ -1,3 +1,4 @@
DATA_ROOT=../datasets
train_alg=dagger
@@ -60,7 +61,7 @@ flag="--root_dir ${DATA_ROOT}
CUDA_VISIBLE_DEVICES='0' python soon/main.py $flag \
--tokenizer bert \
--bert_ckpt_file '' \
--bert_ckpt_file 'put the pretrained model (see pretrain_src) here' \
--eval_first
# test


@@ -4,8 +4,8 @@
"output_dir": "",
"mrc_mask_prob": 0.15,
"max_txt_len": 200,
"train_batch_size": 16,
"val_batch_size": 16,
"train_batch_size": 64,
"val_batch_size": 64,
"gradient_accumulation_steps": 1,
"learning_rate": 5e-05,
"valid_steps": 2500,


@@ -0,0 +1,48 @@
{
"model_config": "",
"checkpoint": null,
"output_dir": "",
"mrc_mask_prob": 0.15,
"max_txt_len": 200,
"train_batch_size": 32,
"val_batch_size": 32,
"gradient_accumulation_steps": 1,
"learning_rate": 5e-05,
"valid_steps": 5000,
"log_steps": 1000,
"num_train_steps": 100000,
"optim": "adamw",
"betas": [
0.9,
0.98
],
"dropout": 0.1,
"weight_decay": 0.01,
"grad_norm": 5.0,
"warmup_steps": 10000,
"seed": 0,
"fp16": false,
"n_workers": 1,
"pin_mem": true,
"init_pretrained": "lxmert",
"train_datasets": {
"R4R": {
"name": "R4R",
"train_traj_files": ["../datasets/R4R/annotations/pretrain_map/R4R_train_enc.jsonl"],
"val_seen_traj_files": ["../datasets/R4R/annotations/pretrain_map/R4R_val_seen_enc.jsonl"],
"val_unseen_traj_files": ["../datasets/R4R/annotations/pretrain_map/R4R_val_unseen_sampled_enc.jsonl"],
"connectivity_dir": "../datasets/R2R/connectivity",
"img_ft_file": "../datasets/R2R/features/pth_vit_base_patch16_224_imagenet.hdf5",
"scanvp_cands_file": "../datasets/R2R/annotations/scanvp_candview_relangles.json",
"tasks": [
"mlm",
"sap"
],
"mix_ratio": [
1,
1
]
}
}
}


@@ -7,8 +7,8 @@
"nearby_vp_steps": null,
"max_objects": 20,
"max_txt_len": 200,
"train_batch_size": 16,
"val_batch_size": 16,
"train_batch_size": 32,
"val_batch_size": 32,
"gradient_accumulation_steps": 1,
"learning_rate": 5e-05,
"valid_steps": 4000,


@@ -7,13 +7,13 @@
"nearby_vp_steps": null,
"max_objects": 100,
"max_txt_len": 200,
"train_batch_size": 8,
"val_batch_size": 8,
"train_batch_size": 32,
"val_batch_size": 32,
"gradient_accumulation_steps": 1,
"learning_rate": 5e-05,
"valid_steps": 1500,
"valid_steps": 2000,
"log_steps": 1000,
"num_train_steps": 30000,
"num_train_steps": 40000,
"optim": "adamw",
"betas": [
0.9,
@@ -41,10 +41,12 @@
"scanvp_cands_file": "../datasets/R2R/annotations/scanvp_candview_relangles.json",
"tasks": [
"mlm",
"mrc",
"sap",
"og"
],
"mix_ratio": [
1,
1,
1,
1


@@ -387,7 +387,7 @@ class R2RTextPathData(ReverieTextPathData):
# local:
for k, cand_vp in enumerate(traj_cand_vpids[-1]):
if cand_vp == gt_next_vp:
local_act_labels = k + 1 # [stop] is 0
local_act_label = k + 1 # [stop] is 0
break
return global_act_label, local_act_label


@@ -1,10 +1,10 @@
NODE_RANK=0
NUM_GPUS=4
NUM_GPUS=1
outdir=../datasets/R2R/exprs_map/pretrain/cmt-vitbase-mlm.mrc.sap-init.lxmert-aug.speaker
# train
CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch \
CUDA_VISIBLE_DEVICES='0' python -m torch.distributed.launch \
--nproc_per_node=${NUM_GPUS} --node_rank $NODE_RANK \
train_r2r.py --world_size ${NUM_GPUS} \
--vlnbert cmt \

pretrain_src/run_r4r.sh Normal file

@@ -0,0 +1,12 @@
NODE_RANK=0
NUM_GPUS=1
outdir=../datasets/R4R/exprs_map/pretrain/cmt-vitbase-mlm.sap-init.lxmert
# train
CUDA_VISIBLE_DEVICES='0' python -m torch.distributed.launch \
--nproc_per_node=${NUM_GPUS} --node_rank $NODE_RANK \
train_r4r.py --world_size ${NUM_GPUS} \
--vlnbert cmt \
--model_config config/r2r_model_config.json \
--config config/r4r_pretrain.json \
--output_dir $outdir


@@ -1,9 +1,9 @@
NODE_RANK=0
NUM_GPUS=2
NUM_GPUS=1
outdir=../datasets/REVERIE/exprs_map/pretrain/cmt-vitbase-mlm.mrc.sap.og-init.lxmert-aug.speaker
# train
CUDA_VISIBLE_DEVICES='0,1' python -m torch.distributed.launch \
CUDA_VISIBLE_DEVICES='0' python -m torch.distributed.launch \
--nproc_per_node=${NUM_GPUS} --node_rank $NODE_RANK \
train_reverie_obj.py --world_size ${NUM_GPUS} \
--vlnbert cmt \


@@ -1,9 +1,9 @@
NODE_RANK=0
NUM_GPUS=4
NUM_GPUS=1
outdir=../datasets/SOON/exprs_map/pretrain/cmt-vitbase.butdobj-mlm.sap.og-init.lxmert
# train
CUDA_VISIBLE_DEVICES='0,1,2,3' python -m torch.distributed.launch \
CUDA_VISIBLE_DEVICES='0' python -m torch.distributed.launch \
--nproc_per_node=${NUM_GPUS} --node_rank $NODE_RANK \
train_soon_obj.py --world_size ${NUM_GPUS} \
--vlnbert cmt \

pretrain_src/train_r4r.py Normal file

@@ -0,0 +1,460 @@
import os
import sys
import json
import argparse
import time
from collections import defaultdict
from easydict import EasyDict
from tqdm import tqdm
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch.cuda.amp as amp # TODO
from transformers import AutoTokenizer, PretrainedConfig
from transformers import AutoModel
from utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file
from utils.save import ModelSaver, save_training_meta
from utils.misc import NoOp, set_dropout, set_random_seed, set_cuda, wrap_model
from utils.distributed import all_gather
from optim import get_lr_sched
from optim.misc import build_optimizer
from parser import load_parser, parse_with_config
from data.loader import MetaLoader, PrefetchLoader, build_dataloader
from data.dataset import R2RTextPathData
from data.tasks import (
MlmDataset, mlm_collate,
MrcDataset, mrc_collate,
SapDataset, sap_collate)
from model.pretrain_cmt import GlocalTextPathCMTPreTraining
def create_dataloaders(
data_cfg, nav_db, tok, is_train: bool, device: torch.device, opts
):
dataloaders = {}
for k, task_name in enumerate(data_cfg.tasks):
if task_name == 'mlm':
task_dataset = MlmDataset(nav_db, tok)
task_collate_fn = mlm_collate
elif task_name == 'mrc':
task_dataset = MrcDataset(nav_db, tok, opts.mrc_mask_prob, end_vp_pos_ratio=0.2)
task_collate_fn = mrc_collate
elif task_name == 'sap':
task_dataset = SapDataset(nav_db, tok, end_vp_pos_ratio=0.2)
task_collate_fn = sap_collate
else:
raise ValueError(f'Undefined task {task_name}')
LOGGER.info(f"{task_name}: {len(task_dataset)} samples loaded")
task_loader, pre_epoch = build_dataloader(
task_name, task_dataset, task_collate_fn, is_train, opts
)
if is_train:
ratio = data_cfg.mix_ratio[k]
dataloaders[task_name] = (task_loader, ratio, pre_epoch)
else:
dataloaders[task_name] = PrefetchLoader(task_loader, device)
return dataloaders
def main(opts):
default_gpu, n_gpu, device = set_cuda(opts)
if default_gpu:
LOGGER.info(
'device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}'.format(
device, n_gpu, bool(opts.local_rank != -1), opts.fp16
)
)
seed = opts.seed
if opts.local_rank != -1:
seed += opts.rank
set_random_seed(seed)
if default_gpu:
save_training_meta(opts)
TB_LOGGER.create(os.path.join(opts.output_dir, 'logs'))
pbar = tqdm(total=opts.num_train_steps)
model_saver = ModelSaver(os.path.join(opts.output_dir, 'ckpts'))
add_log_to_file(os.path.join(opts.output_dir, 'logs', 'log.txt'))
else:
LOGGER.disabled = True
pbar = NoOp()
model_saver = NoOp()
# Model config
model_config = PretrainedConfig.from_json_file(opts.model_config)
model_config.pretrain_tasks = []
for train_dataset_config in opts.train_datasets.values():
model_config.pretrain_tasks.extend(train_dataset_config['tasks'])
model_config.pretrain_tasks = set(model_config.pretrain_tasks)
tokenizer = AutoTokenizer.from_pretrained(model_config.lang_bert_name)
# Prepare model
if opts.checkpoint:
checkpoint = torch.load(opts.checkpoint, map_location=lambda storage, loc: storage)
else:
checkpoint = {}
if opts.init_pretrained == 'bert':
tmp = AutoModel.from_pretrained(model_config.lang_bert_name)
for param_name, param in tmp.named_parameters():
checkpoint[param_name] = param
if model_config.lang_bert_name == 'xlm-roberta-base':
# embeddings.token_type_embeddings.weight (1 -> 2, the second is for image embedding)
checkpoint['embeddings.token_type_embeddings.weight'] = torch.cat(
[checkpoint['embeddings.token_type_embeddings.weight']] * 2, 0
)
del tmp
elif opts.init_pretrained == 'lxmert':
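# initialise from LXMERT: language layers are renamed into DUET's lang_encoder, and the
# LXMERT cross-modal x_layers are copied twice, once into the local (fine-scale) encoder
# and once into the global (coarse-scale map) encoder, so both branches start from the
# same cross-modal weights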
tmp = torch.load(
'../datasets/pretrained/LXMERT/model_LXRT.pth',
map_location=lambda storage, loc: storage
)
for param_name, param in tmp.items():
param_name = param_name.replace('module.', '')
if 'bert.encoder.layer' in param_name:
param_name = param_name.replace('bert.encoder.layer', 'bert.lang_encoder.layer')
checkpoint[param_name] = param
elif 'bert.encoder.x_layers' in param_name:
param_name1 = param_name.replace('bert.encoder.x_layers', 'bert.local_encoder.encoder.x_layers')
param_name2 = param_name.replace('bert.encoder.x_layers', 'bert.global_encoder.encoder.x_layers')
checkpoint[param_name1] = checkpoint[param_name2] = param
elif 'cls.predictions' in param_name:
param_name = param_name.replace('cls.predictions', 'mlm_head.predictions')
checkpoint[param_name] = param
else:
checkpoint[param_name] = param
del tmp
model_class = GlocalTextPathCMTPreTraining
# update some training configs
model = model_class.from_pretrained(
pretrained_model_name_or_path=None, config=model_config, state_dict=checkpoint
)
model.train()
set_dropout(model, opts.dropout)
model = wrap_model(model, device, opts.local_rank)
del checkpoint
# load data training set
data_cfg = EasyDict(opts.train_datasets['R4R'])
train_nav_db = R2RTextPathData(
data_cfg.train_traj_files, data_cfg.img_ft_file,
data_cfg.scanvp_cands_file, data_cfg.connectivity_dir,
image_prob_size=model_config.image_prob_size,
image_feat_size=model_config.image_feat_size,
angle_feat_size=model_config.angle_feat_size,
max_txt_len=opts.max_txt_len, in_memory=True,
act_visited_node=True
)
val_nav_db = R2RTextPathData(
data_cfg.val_seen_traj_files, data_cfg.img_ft_file,
data_cfg.scanvp_cands_file, data_cfg.connectivity_dir,
image_prob_size=model_config.image_prob_size,
image_feat_size=model_config.image_feat_size,
angle_feat_size=model_config.angle_feat_size,
max_txt_len=opts.max_txt_len, in_memory=True,
act_visited_node=True
)
val2_nav_db = R2RTextPathData(
data_cfg.val_unseen_traj_files, data_cfg.img_ft_file,
data_cfg.scanvp_cands_file, data_cfg.connectivity_dir,
image_prob_size=model_config.image_prob_size,
image_feat_size=model_config.image_feat_size,
angle_feat_size=model_config.angle_feat_size,
max_txt_len=opts.max_txt_len, in_memory=True,
act_visited_node=True
)
# Build data loaders
train_dataloaders = create_dataloaders(
data_cfg, train_nav_db, tokenizer, True, device, opts
)
val_dataloaders = create_dataloaders(
data_cfg, val_nav_db, tokenizer, False, device, opts
)
val2_dataloaders = create_dataloaders(
data_cfg, val2_nav_db, tokenizer, False, device, opts
)
meta_loader = MetaLoader(
train_dataloaders,
accum_steps=opts.gradient_accumulation_steps,
distributed=opts.local_rank != -1,
device=device
)
meta_loader = PrefetchLoader(meta_loader, device)
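# MetaLoader interleaves the task-specific loaders, sampling tasks according to their
# mix_ratio; PrefetchLoader moves each batch onto the target device before the forward pass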
# Prepare optimizer
optimizer = build_optimizer(model, opts)
task2scaler = {t: i for i, t in enumerate(train_dataloaders.keys())}
if opts.fp16:
grad_scaler = amp.GradScaler()
global_step = 0
LOGGER.info(f"***** Running training with {opts.world_size} GPUs *****")
LOGGER.info(" Batch size = %d", opts.train_batch_size if opts.local_rank == -1 else opts.train_batch_size * opts.world_size)
LOGGER.info(" Accumulate steps = %d", opts.gradient_accumulation_steps)
LOGGER.info(" Num steps = %d", opts.num_train_steps)
# to compute training statistics
task2loss = {task: RunningMeter(f'loss/{task}')
for task in train_dataloaders.keys()}
n_examples = defaultdict(int)
n_in_units = defaultdict(int)
n_loss_units = defaultdict(int)
grad_norm = 0
start_time = time.time()
# quick hack for amp delay_unscale bug
optimizer.zero_grad()
optimizer.step()
for step, (name, batch) in enumerate(meta_loader):
# forward pass
n_examples[name] += batch['txt_ids'].size(0)
n_in_units[name] += batch['txt_lens'].sum().item()
task = name.split('_')[0]
# print(name, task)
# for k, v in batch.items():
# print(k, v.size())
# continue
if opts.fp16:
with amp.autocast():
loss = model(batch, task=task, compute_loss=True)
else:
loss = model(batch, task=task, compute_loss=True)
n_loss_units[name] += loss.size(0)
loss = loss.mean() # loss is not normalized in model
# backward pass
if opts.gradient_accumulation_steps > 1:  # average loss over accumulation steps
loss = loss / opts.gradient_accumulation_steps
delay_unscale = (step+1) % opts.gradient_accumulation_steps != 0
if opts.fp16:
grad_scaler.scale(loss).backward()
else:
loss.backward()
task2loss[name](loss.item())
# optimizer update and logging
if (step + 1) % opts.gradient_accumulation_steps == 0:
global_step += 1
# learning rate scheduling
lr_this_step = get_lr_sched(global_step, opts)
for param_group in optimizer.param_groups:
param_group['lr'] = lr_this_step
TB_LOGGER.add_scalar('lr', lr_this_step, global_step)
# log loss
# NOTE: not gathered across GPUs for efficiency
TB_LOGGER.log_scalar_dict({ll.name: ll.val
for ll in task2loss.values()
if ll.val is not None})
TB_LOGGER.step()
# update model params
if opts.grad_norm != -1:
if opts.fp16:
grad_scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(), opts.grad_norm
)
# print(step, name, grad_norm)
# for k, v in model.named_parameters():
# if v.grad is not None:
# v = torch.norm(v).data.item()
# print(k, v)
TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
if opts.fp16:
grad_scaler.step(optimizer)
grad_scaler.update()
else:
optimizer.step()
optimizer.zero_grad()
pbar.update(1)
if global_step % opts.log_steps == 0:
# monitor training throughput
LOGGER.info(f'==============Step {global_step}===============')
for t in train_dataloaders.keys():
tot_ex = n_examples[t]
ex_per_sec = int(tot_ex / (time.time() - start_time))
tot_in = n_in_units[t]
in_per_sec = int(tot_in / (time.time() - start_time))
tot_l = n_loss_units[t]
l_per_sec = int(tot_l / (time.time() - start_time))
LOGGER.info(f'{t}: {tot_ex} examples trained at '
f'{ex_per_sec} ex/s')
TB_LOGGER.add_scalar(f'perf/{t}_ex_per_s', ex_per_sec,
global_step)
TB_LOGGER.add_scalar(f'perf/{t}_in_per_s', in_per_sec,
global_step)
TB_LOGGER.add_scalar(f'perf/{t}_loss_per_s', l_per_sec,
global_step)
LOGGER.info('===============================================')
if global_step % opts.valid_steps == 0:
LOGGER.info(f'------Step {global_step}: start validation seen------')
validate(model, val_dataloaders, setname='_seen')
LOGGER.info(f'------Step {global_step}: start validation unseen------')
validate(model, val2_dataloaders, setname='_unseen')
model_saver.save(model, global_step)
if global_step >= opts.num_train_steps:
break
if global_step % opts.valid_steps != 0:
LOGGER.info(f'------Step {global_step}: start validation seen------')
validate(model, val_dataloaders, setname='_seen')
LOGGER.info(f'------Step {global_step}: start validation unseen------')
validate(model, val2_dataloaders, setname='_unseen')
model_saver.save(model, global_step)
def validate(model, val_dataloaders, setname=''):
model.eval()
for task, loader in val_dataloaders.items():
LOGGER.info(f"validate val{setname} on {task} task")
if task.startswith('mlm'):
val_log = validate_mlm(model, loader)
elif task.startswith('mrc'):
val_log = validate_mrc(model, loader)
elif task.startswith('sap'):
val_log = validate_sap(model, loader)
else:
raise ValueError(f'Undefined task {task}')
val_log = {f'val{setname}_{task}_{k}': v for k, v in val_log.items()}
TB_LOGGER.log_scalar_dict(
{f'valid{setname}_{task}/{k}': v for k, v in val_log.items()}
)
model.train()
@torch.no_grad()
def validate_mlm(model, val_loader):
LOGGER.info("start running MLM validation...")
val_loss = 0
n_correct = 0
n_word = 0
st = time.time()
for i, batch in enumerate(val_loader):
scores = model(batch, task='mlm', compute_loss=False)
labels = batch['txt_labels']
labels = labels[labels != -1]
loss = F.cross_entropy(scores, labels, reduction='sum')
val_loss += loss.item()
n_correct += (scores.max(dim=-1)[1] == labels).sum().item()
n_word += labels.numel()
val_loss = sum(all_gather(val_loss))
n_correct = sum(all_gather(n_correct))
n_word = sum(all_gather(n_word))
tot_time = time.time()-st
val_loss /= n_word
acc = n_correct / n_word
val_log = {'loss': val_loss,
'acc': acc,
'tok_per_s': n_word/tot_time}
LOGGER.info(f"validation finished in {int(tot_time)} seconds, "
f"acc: {acc*100:.2f}")
return val_log
def compute_accuracy_for_soft_targets(out, labels):
outputs = out.max(dim=-1)[1]
labels = labels.max(dim=-1)[1] # argmax
n_correct = (outputs == labels).sum().item()
return n_correct
@torch.no_grad()
def validate_mrc(model, val_loader):
LOGGER.info("start running MRC validation...")
val_loss = 0
n_feat = 0
st = time.time()
tot_score = 0
for i, batch in enumerate(val_loader):
view_logits, view_targets, _, _ = model(batch, task='mrc', compute_loss=False)
view_logprobs = F.log_softmax(view_logits, dim=-1)
loss = F.kl_div(view_logprobs, view_targets, reduction='sum')
tot_score += compute_accuracy_for_soft_targets(view_logits, view_targets)
val_loss += loss.item()
n_feat += batch['vp_view_mrc_masks'].sum().item()
val_loss = sum(all_gather(val_loss))
tot_score = sum(all_gather(tot_score))
n_feat = sum(all_gather(n_feat))
tot_time = time.time()-st
val_loss /= n_feat
val_acc = tot_score / n_feat
val_log = {'loss': val_loss,
'acc': val_acc,
'feat_per_s': n_feat/tot_time}
LOGGER.info(f"validation finished in {int(tot_time)} seconds, "
f"score: {val_acc*100:.2f}")
return val_log
@torch.no_grad()
def validate_sap(model, val_loader):
LOGGER.info("start running SAP validation...")
val_gloss, val_lloss, val_floss = 0, 0, 0
n_gcorrect, n_lcorrect, n_fcorrect = 0, 0, 0
n_data = 0
st = time.time()
for i, batch in enumerate(val_loader):
global_logits, local_logits, fused_logits, global_act_labels, local_act_labels = \
model(batch, task='sap', compute_loss=False)
val_gloss += F.cross_entropy(global_logits, global_act_labels, reduction='sum').data.item()
val_lloss += F.cross_entropy(local_logits, local_act_labels, reduction='sum').data.item()
val_floss += F.cross_entropy(fused_logits, global_act_labels, reduction='sum').data.item()
n_gcorrect += torch.sum(torch.argmax(global_logits, 1) == global_act_labels).item()
n_lcorrect += torch.sum(torch.argmax(local_logits, 1) == local_act_labels).item()
n_fcorrect += torch.sum(torch.argmax(fused_logits, 1) == global_act_labels).item()
n_data += len(global_act_labels)
n_data = sum(all_gather(n_data))
val_gloss = sum(all_gather(val_gloss)) / n_data
val_lloss = sum(all_gather(val_lloss)) / n_data
val_floss = sum(all_gather(val_floss)) / n_data
gacc = sum(all_gather(n_gcorrect)) / n_data
lacc = sum(all_gather(n_lcorrect)) / n_data
facc = sum(all_gather(n_fcorrect)) / n_data
tot_time = time.time()-st
val_log = {'gloss': val_gloss, 'lloss': val_lloss, 'floss': val_floss,
'gacc': gacc, 'lacc': lacc, 'facc': facc,
'tok_per_s': n_data/tot_time}
LOGGER.info(f"validation finished in {int(tot_time)} seconds, "
f"gacc: {gacc*100:.2f}, lacc: {lacc*100:.2f}, facc: {facc*100:.2f}")
return val_log
def build_args():
parser = load_parser()
opts = parse_with_config(parser)
if os.path.exists(opts.output_dir) and os.listdir(opts.output_dir):
LOGGER.warning(
"Output directory ({}) already exists and is not empty.".format(
opts.output_dir
)
)
return opts
if __name__ == '__main__':
args = build_args()
main(args)